├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── README_zh.md
├── docker
    ├── Dockerfile
    ├── Dockerfile.rocm
    └── patch
    │   ├── megatron-sandwich_norm.patch
    │   └── sglang.patch
├── docs
    ├── en
    │   ├── agent_training.md
    │   ├── amd_tutorial.md
    │   ├── build.md
    │   ├── debug.md
    │   ├── models
    │   │   ├── glm4-9B.md
    │   │   ├── qwen3-30B-A3B.md
    │   │   └── qwen3-4B.md
    │   ├── qa.md
    │   ├── rollout_buffer_usage.md
    │   ├── sft.md
    │   └── usage.md
    └── zh
    │   ├── agent_training.md
    │   ├── build.md
    │   ├── debug.md
    │   ├── models
    │       ├── glm4-9B.md
    │       ├── qwen3-30B-A3B.md
    │       └── qwen3-4B.md
    │   ├── qa.md
    │   ├── rollout_buffer_usage.md
    │   ├── sft.md
    │   └── usage.md
├── examples
    └── search-r1
    │   ├── README.md
    │   ├── README_zh.md
    │   ├── generate_with_search.py
    │   ├── google_search_server.py
    │   ├── qa_em_format.py
    │   └── run_qwen2.5_3B.sh
├── imgs
    └── arch.png
├── pyproject.toml
├── requirements.txt
├── scripts
    ├── agent-example.sh
    ├── models
    │   ├── glm4-32B.sh
    │   ├── glm4-9B.sh
    │   ├── moonlight.sh
    │   ├── qwen2.5-1.5B.sh
    │   ├── qwen2.5-32B.sh
    │   ├── qwen2.5-3B.sh
    │   ├── qwen2.5-7B.sh
    │   ├── qwen3-0.6B.sh
    │   ├── qwen3-235B-A22B.sh
    │   ├── qwen3-30B-A3B.sh
    │   ├── qwen3-4B.sh
    │   └── qwen3-8B.sh
    ├── run-glm4-9B.sh
    ├── run-qwen3-235B-A22B-sft.sh
    ├── run-qwen3-235B-A22B.sh
    ├── run-qwen3-30B-A3B.sh
    ├── run-qwen3-4B-amd.sh
    ├── run-qwen3-4B-base-sft.sh
    ├── run-qwen3-4B-rf-baseline.sh
    ├── run-qwen3-4B-rf.sh
    ├── run-qwen3-4B.sh
    └── run_agent.sh
├── setup.py
├── slime
    ├── __init__.py
    ├── backends
    │   ├── megatron_utils
    │   │   ├── __init__.py
    │   │   ├── arguments.py
    │   │   ├── checkpoint.py
    │   │   ├── cp_utils.py
    │   │   ├── data.py
    │   │   ├── initialize.py
    │   │   ├── loss.py
    │   │   ├── model.py
    │   │   ├── models
    │   │   │   └── __init__.py
    │   │   └── update_weight_utils.py
    │   └── sglang_utils
    │   │   ├── __init__.py
    │   │   ├── arguments.py
    │   │   ├── http_server_engine.py
    │   │   └── sglang_engine.py
    ├── ray
    │   ├── __init__.py
    │   ├── buffer.py
    │   ├── placement_group.py
    │   ├── ppo_actor.py
    │   ├── ray_actor.py
    │   ├── rollout.py
    │   └── utils.py
    ├── rollout
    │   ├── agent_rollout.py
    │   ├── filter_hub
    │   │   ├── __init__.py
    │   │   ├── dynamic_sampling_filters.py
    │   │   └── over_sampling_filters.py
    │   ├── rm_hub
    │   │   ├── __init__.py
    │   │   ├── deepscaler.py
    │   │   ├── f1.py
    │   │   ├── math_dapo_utils.py
    │   │   └── math_utils.py
    │   ├── sft_example.py
    │   └── sglang_example.py
    └── utils
    │   ├── __init__.py
    │   ├── arguments.py
    │   ├── async_utils.py
    │   ├── data.py
    │   ├── distributed_utils.py
    │   ├── flops_utils.py
    │   ├── http_utils.py
    │   ├── mask_utils.py
    │   ├── memory_utils.py
    │   ├── misc.py
    │   ├── ppo_utils.py
    │   ├── seqlen_balancing.py
    │   ├── timer.py
    │   └── types.py
├── slime_plugins
    ├── __init__.py
    ├── mbridge
    │   ├── __init__.py
    │   └── glm4.py
    ├── models
    │   ├── __init__.py
    │   └── glm4.py
    └── rollout_buffer
    │   ├── buffer.py
    │   ├── generator
    │       ├── __init__.py
    │       ├── base_generator.py
    │       ├── reward_utils
    │       │   ├── __init__.py
    │       │   └── math_utils.py
    │       └── utils
    │       │   ├── arguments.py
    │       │   └── default_func.py
    │   └── tools
    │       ├── assign_instance_id.py
    │       └── visualizer.py
├── tests
    └── test_qwen3_0.6B.sh
├── tools
    ├── convert_hf_to_torch_dist.py
    ├── convert_to_hf.py
    └── convert_torch_dist_to_hf.py
├── train.py
└── train_async.py


/.gitignore:
--------------------------------------------------------------------------------
  1 | # Byte-compiled / optimized / DLL files
  2 | __pycache__/
  3 | *.py[cod]
  4 | *$py.class
  5 | 
  6 | # C extensions
  7 | *.so
  8 | 
  9 | # Distribution / packaging
 10 | .Python
 11 | build/
 12 | develop-eggs/
 13 | dist/
 14 | downloads/
 15 | eggs/
 16 | .eggs/
 17 | lib/
 18 | lib64/
 19 | parts/
 20 | sdist/
 21 | var/
 22 | wheels/
 23 | share/python-wheels/
 24 | *.egg-info/
 25 | .installed.cfg
 26 | *.egg
 27 | MANIFEST
 28 | 
 29 | # PyInstaller
 30 | #  Usually these files are written by a python script from a template
 31 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 32 | *.manifest
 33 | *.spec
 34 | 
 35 | # Installer logs
 36 | pip-log.txt
 37 | pip-delete-this-directory.txt
 38 | 
 39 | # Unit test / coverage reports
 40 | htmlcov/
 41 | .tox/
 42 | .nox/
 43 | .coverage
 44 | .coverage.*
 45 | .cache
 46 | nosetests.xml
 47 | coverage.xml
 48 | *.cover
 49 | *.py,cover
 50 | .hypothesis/
 51 | .pytest_cache/
 52 | cover/
 53 | 
 54 | # Translations
 55 | *.mo
 56 | *.pot
 57 | 
 58 | # Django stuff:
 59 | *.log
 60 | local_settings.py
 61 | db.sqlite3
 62 | db.sqlite3-journal
 63 | 
 64 | # Flask stuff:
 65 | instance/
 66 | .webassets-cache
 67 | 
 68 | # Scrapy stuff:
 69 | .scrapy
 70 | 
 71 | # Sphinx documentation
 72 | docs/_build/
 73 | 
 74 | # PyBuilder
 75 | .pybuilder/
 76 | target/
 77 | 
 78 | # Jupyter Notebook
 79 | .ipynb_checkpoints
 80 | 
 81 | # IPython
 82 | profile_default/
 83 | ipython_config.py
 84 | 
 85 | # pyenv
 86 | #   For a library or package, you might want to ignore these files since the code is
 87 | #   intended to run in multiple environments; otherwise, check them in:
 88 | # .python-version
 89 | 
 90 | # pipenv
 91 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 92 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
 93 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
 94 | #   install all needed dependencies.
 95 | #Pipfile.lock
 96 | 
 97 | # UV
 98 | #   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
 99 | #   This is especially recommended for binary packages to ensure reproducibility, and is more
100 | #   commonly ignored for libraries.
101 | #uv.lock
102 | 
103 | # poetry
104 | #   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | #   This is especially recommended for binary packages to ensure reproducibility, and is more
106 | #   commonly ignored for libraries.
107 | #   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 | 
110 | # pdm
111 | #   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | #   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | #   in version control.
115 | #   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 | 
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 | 
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 | 
127 | # SageMath parsed files
128 | *.sage.py
129 | 
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 | 
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 | 
143 | # Rope project settings
144 | .ropeproject
145 | 
146 | # mkdocs documentation
147 | /site
148 | 
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 | 
154 | # Pyre type checker
155 | .pyre/
156 | 
157 | # pytype static type analyzer
158 | .pytype/
159 | 
160 | # Cython debug symbols
161 | cython_debug/
162 | 
163 | # PyCharm
164 | #  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | #  and can be added to the global gitignore or merged into this file.  For a more nuclear
167 | #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 | 
170 | # Ruff stuff:
171 | .ruff_cache/
172 | 
173 | # PyPI configuration file
174 | .pypirc
175 | 
176 | wandb/
177 | outputs/
178 | local/
179 | **/rollout_data/
180 | **/buffer_stats/
181 | *.out
182 | *.pkl
183 | 


--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
 1 | default_language_version:
 2 |   python: python3
 3 | 
 4 | ci:
 5 |   autofix_prs: true
 6 |   autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
 7 |   autoupdate_schedule: quarterly
 8 | 
 9 | repos:
10 |   - repo: https://github.com/pre-commit/pre-commit-hooks
11 |     rev: v4.5.0
12 |     hooks:
13 |       - id: check-yaml
14 |       - id: check-case-conflict
15 |       - id: detect-private-key
16 |       - id: check-added-large-files
17 |         args: ['--maxkb=1000']
18 |       - id: requirements-txt-fixer
19 | 
20 |   - repo: https://github.com/PyCQA/autoflake
21 |     rev: v2.0.2
22 |     hooks:
23 |       - id: autoflake
24 |         args: [--remove-all-unused-imports, --in-place]
25 | 
26 |   - repo: https://github.com/PyCQA/isort
27 |     rev: 5.13.2
28 |     hooks:
29 |       - id: isort
30 |         name: Format imports
31 |         exclude: docs/
32 | 
33 |   - repo: https://github.com/psf/black
34 |     rev: 24.3.0
35 |     hooks:
36 |       - id: black
37 |         name: Format code
38 |         additional_dependencies: ['click==8.0.2']
39 | 


--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
  1 | # slime
  2 | 
  3 | [English](./README.md)
  4 | 
  5 | **slime** 是为 RL scaling 设计的 LLM post‑training 框架,提供两大核心能力:
  6 | 
  7 | 1. **高性能训练**:通过连接 Megatron 与 SGLang,支持各种模式的高效训练;
  8 | 2. **灵活的数据生成**:通过自定义数据生成接口以及 server based engine,实现任意的训练数据生成流程。
  9 | 
 10 | ## 目录
 11 | 
 12 | - [架构总览](#架构总览)
 13 | - [快速开始](#快速开始)
 14 |   - [环境准备](#环境准备)
 15 |   - [示例](#示例)
 16 |     - [Dense 模型示例:GLM-4-9B 与 Qwen3-4B](#Dense-模型示例GLM-4-9B-与-Qwen3-4B)
 17 |     - [MoE 模型示例:Qwen3-30B-A3B](#MoE-模型示例Qwen3-30B-A3B)
 18 |     - [多轮对话 + 工具调用示例:Search-R1 lite](#多轮对话--工具调用示例Search-R1-lite)
 19 |     - [SFT 示例:Qwen3-4B-Base + OpenHermes-2.5](#SFT-示例Qwen3-4B-Base--OpenHermes-25)
 20 | - [Checkpoint 格式转换](#checkpoint-格式转换)
 21 | - [启动训练流程](#启动训练流程)
 22 | - [参数说明](#参数说明)
 23 | - [开发指南](#开发指南)
 24 | - [常见 Q&A 与致谢](#常见-qa-与致谢)
 25 | 
 26 | ## 架构总览
 27 | 
 28 | ![arch](./imgs/arch.png)
 29 | 
 30 | **模块说明**:
 31 | 
 32 | - **training (Megatron)**:负责主训练流程,从 Data Buffer 读取数据,训练完后将参数同步至 rollout 模块;
 33 | - **rollout (SGLang + router)**:生成新数据(含 reward/verifier),存储至 Data Buffer;
 34 | - **data buffer**:桥梁模块,管理 prompt 初始化、自定义数据与 rollout 生成方法。
 35 | 
 36 | ## 快速开始
 37 | 
 38 | ### 环境准备
 39 | 
 40 | 基于镜像 zhuzilin/slime:latest(已预装 SGLang 0.4.7 和 Megatron):
 41 | 
 42 | ```bash
 43 | docker run --rm --gpus all --ipc=host --shm-size=16g \
 44 |   --ulimit memlock=-1 --ulimit stack=67108864 \
 45 |   -it zhuzilin/slime:latest /bin/bash
 46 | 
 47 | git clone https://github.com/THUDM/slime.git
 48 | cd slime
 49 | pip install -e .
 50 | ```
 51 | 
 52 | - 对于不方便使用 docker 的场景,请参考 [从零搭建环境](./docs/zh/build.md);
 53 | - 对于 AMD 支持,请参考 [AMD 使用教程](./docs/en/amd_tutorial.md)。
 54 | 
 55 | ### 示例
 56 | 
 57 | #### Dense 模型示例:GLM-4-9B 与 Qwen3-4B
 58 | 
 59 | 我们提供了 [GLM-4-9B](https://huggingface.co/THUDM/GLM-Z1-9B-0414) 和 [Qwen3-4B](https://huggingface.co/Qwen/Qwen3-4B) 的使用示例,可以通过他们对 slime 的使用方法有个基本的了解:
 60 | 
 61 | - [示例:GLM-4-9B 模型](docs/zh/models/glm4-9B.md)
 62 | - [示例:Qwen3-4B 模型](docs/zh/models/qwen3-4B.md)
 63 | 
 64 | #### MoE 模型示例:Qwen3-30B-A3B
 65 | 
 66 | 我们也提供了 MoE 模型的示例,请查看:
 67 | 
 68 | - [示例:Qwen3-30B-A3B 模型](docs/zh/models/qwen3-30B-A3B.md)
 69 | 
 70 | #### 多轮对话 + 工具调用示例:Search-R1 lite
 71 | 
 72 | 针对多轮对话和工具调用场景,我们提供了一个简化版的 Search-R1 复现,请查看:
 73 | 
 74 | - [示例:Search-R1 lite](examples/search-r1/README_zh.md)
 75 | 
 76 | #### SFT 示例:Qwen3-4B-Base + OpenHermes-2.5
 77 | 
 80 | slime 不仅仅是一个 RL 框架,我们还支持了各种后训练流程。如果想使用 SFT,请参看:
 81 | 
 82 | - [示例: Qwen3-4B-Base + OpenHermes-2.5](docs/zh/sft.md).
 83 | 
 84 | ### Checkpoint 格式转换
 85 | 
 86 | 由于 slime 使用 megatron,而 megatron 不支持加载 huggingface checkpoint,我们需要将模型转换至 megatron 可以支持的 torch_dist 格式。
 87 | 
 88 | #### HF → Megatron torch_dist ckpt
 89 | 
 90 | 使用 [mbridge](https://github.com/ISEEKYAN/mbridge.git) 转换:
 91 | 
 92 | ```bash
 93 | cd slime/
 94 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
 95 |     --hf-checkpoint /root/GLM-Z1-9B-0414 \
 96 |     --save /root/GLM-Z1-9B-0414_torch_dist
 97 | ```
 98 | 
 99 | 在遇到 mbridge 暂时不支持的模型的时候,可以考虑使用 [Pai-Megatron-Patch](https://github.com/alibaba/Pai-Megatron-Patch) 进行转换。
100 | 
101 | ⚠️  如果出现找不到 slime 的问题,请在 slime 目录下 `pip install -e .`。
102 | 
103 | #### Megatron torch_dist → HF ckpt
104 | 
105 | 将训练过程中存储的 torch_dist ckpt 转为 hf ckpt:
106 | 
107 | ```bash
108 | cd slime/
109 | PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf.py \
110 |   --input-dir /path/to/torch_dist_ckpt/iter_xxx/ \
111 |   --output-dir /root/GLM-Z1-9B-0414-iter_xxx \
112 |   --origin-hf-dir /root/GLM-Z1-9B-0414
113 | ```
114 | 
115 | ⚠️ 由于 mbridge 转换的 torch_dist ckpt 目前不保存 args,不能基于上一步的 torch_dist ckpt 反转回 HF。
116 | 
117 | #### 任意 Megatron ckpt → HF
118 | 
119 | 适用于自定义保存格式(如 `--ckpt-format torch`)。
120 | 
121 | 转换的原理是直接复用训练中从 megatron 向 sglang 更新参数的函数,也就是复用训练脚本,将原先的:
122 | 
123 | ```bash
124 | ray job submit --address="http://127.0.0.1:8265" \
125 |    --runtime-env-json='{
126 |      "env_vars": { ...}
127 |    }' \
128 |    -- python3 train.py \
129 |    ... # 其他训练 args
130 | ```
131 | 
132 | 改成:
133 | 
134 | ```bash
135 | torchrun --nproc_per_node ${NUM_GPU} tools/convert_to_hf.py \
136 |    --load /your/saved/megatron_ckpt \
137 |    --output-dir /your/converted/hf_ckpt \
138 |    ... # 其他训练 args
139 | ```
140 | 
141 | 即,保持所有其他参数不变,只需:
142 | 
143 | 1. 把任务启动从 ray 换成 torchrun,把 gpu 数量设置为 megatron 并行中不含 dp 的最小 gpu 数,例如 tp4 时就设成 4;
144 | 2. 确认把 `--load` 改成了需要 load 的路径;
145 | 3. 增加 `--output-dir` 对应要保存的 hf_ckpt。
146 | 
147 | ## 启动训练流程
148 | 
149 | 整个程序需要使用 ray 进行启动,首先需要启动一个 ray 集群,即在 node 0 运行:
150 | 
151 | ```bash
152 | # Node0(HEAD)
153 | ray start --head --node-ip-address ${MASTER_ADDR} \
154 |   --num-gpus 8 --disable-usage-stats
155 | 
156 | # 其他 Node
157 | ray start --address=${MASTER_ADDR}:6379 --num-gpus 8
158 | ```
159 | 
160 | 在 ray 集群启动后,可以在 node 0 提交任务,例如:
161 | 
162 | ```bash
163 | ray job submit --address="http://127.0.0.1:8265" \
164 |    --runtime-env-json='{
165 |      "env_vars": {
166 |         "PYTHONPATH": "/root/Megatron-LM/",
167 |         ... # e.g. no_proxy、接口变量等
168 |      }
169 |    }' \
170 |    -- python3 train.py \
171 |    --...(其他 Megatron/SGLang/slime 参数)
172 | ```
173 | 
174 | ## 参数说明
175 | 
176 | 参数分为三类:
177 | 
178 | 1. **megatron 参数**:slime 会读取 `PYTHONPATH` 中的 megatron 里设置的所有参数,可以通过传入如 `--tensor-model-parallel-size 2` 的方式配置 megatron;
179 | 2. **sglang 参数**:支持环境中安装的 sglang 的所有参数,这些参数需要以 `--sglang` 起始,例如 `--mem-fraction-static` 需要通过 `--sglang-mem-fraction-static` 传入。
180 | 3. **slime 自身的参数**:请见:[slime/utils/arguments.py](slime/utils/arguments.py)
181 | 
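下面是一个把三类参数放在同一条启动命令里的简化示意(仅作说明,具体取值请以实际训练脚本为准):

```bash
# megatron 参数:--tensor-model-parallel-size
# sglang 参数:--sglang-mem-fraction-static(对应 sglang 的 --mem-fraction-static)
# slime 参数:--rollout-batch-size 等,完整列表见 slime/utils/arguments.py
ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json='{"env_vars": {"PYTHONPATH": "/root/Megatron-LM/"}}' \
   -- python3 train.py \
   --tensor-model-parallel-size 2 \
   --sglang-mem-fraction-static 0.8 \
   --rollout-batch-size 32 \
   ... # 其他 megatron / sglang / slime 参数
```
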
182 | 完整使用说明请查阅 [使用文档](docs/zh/usage.md)。
183 | 
184 | ## 开发指南
185 | 
186 | - **欢迎贡献!** 若有功能建议、性能调优或使用体验反馈,欢迎提交 Issue / PR 😊
187 | 
188 | - 使用 [pre-commit](https://pre-commit.com/) 保证提交代码风格:
189 | 
190 |   ```bash
191 |   apt install pre-commit -y
192 |   pre-commit install
193 |   ```
194 | 
195 | - 调试技巧请参考 [debug 指南](docs/zh/debug.md)
196 | 
197 | ## 常见 Q&A 与致谢
198 | 
199 | - 常见问题请见 [Q&A](docs/zh/qa.md)
200 | - 特别感谢以下项目 & 社区:SGLang、Megatron‑LM、mbridge、OpenRLHF、veRL 等。
201 | 


--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
 1 | FROM lmsysorg/sglang:dev AS base
 2 | 
 3 | # TODO: change to pip install sglang-router after it has a new release
 4 | RUN pip install sglang-router --force-reinstall
 5 | RUN pip install ray[default]
 6 | RUN pip install httpx[http2] wandb pylatexenc blobfile accelerate "mcp[cli]"
 7 | RUN pip install git+https://github.com/zhuzilin/cumem_allocator.git
 8 | 
 9 | # mbridge
10 | RUN pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps
11 | 
12 | RUN TORCH_CUDA_ARCH_LIST="8.0;8.9;9.0;9.0a" pip install git+https://github.com/fanshiqing/grouped_gemm@v1.1.4
13 | # apex
14 | RUN NVCC_APPEND_FLAGS="--threads 4" \
15 |   pip -v install --disable-pip-version-check --no-cache-dir \
16 |   --no-build-isolation \
17 |   --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" git+https://github.com/NVIDIA/apex.git
18 | # transformer engine
19 | RUN pip -v install transformer_engine[pytorch]
20 | # flash attn
21 | # the newest version megatron supports is v2.7.4.post1
22 | RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
23 | 
24 | WORKDIR /root/
25 | RUN git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \
26 |     cd Megatron-LM && \
27 |     pip install -e .
28 | 
 29 | # sandwich norm for GLM models
30 | COPY patch/megatron-sandwich_norm.patch /root/Megatron-LM/
31 | RUN cd Megatron-LM && \
32 |     git apply megatron-sandwich_norm.patch --3way && \
33 |     if grep -R -n '^<<<<<<< ' .; then \
34 |       echo "Patch failed to apply cleanly. Please resolve conflicts." && \
35 |       exit 1; \
36 |     fi && \
37 |     rm megatron-sandwich_norm.patch
38 | 
39 | # sglang patch
40 | COPY patch/sglang.patch /sgl-workspace/sglang/
41 | RUN cd /sgl-workspace/sglang && \
42 |   git apply sglang.patch && \
43 |   if grep -R -n '^<<<<<<< ' .; then \
44 |     echo "Patch failed to apply cleanly. Please resolve conflicts." && \
45 |     exit 1; \
46 |   fi && \
47 |   rm sglang.patch
48 | 


--------------------------------------------------------------------------------
/docker/patch/sglang.patch:
--------------------------------------------------------------------------------
 1 | diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
 2 | index 25104bd..6853784 100644
 3 | --- a/python/sglang/srt/configs/model_config.py
 4 | +++ b/python/sglang/srt/configs/model_config.py
 5 | @@ -405,14 +405,14 @@ class ModelConfig:
 6 |              quant_method = quant_cfg.get("quant_method", "").lower()
 7 |  
 8 |              # Detect which checkpoint is it
 9 | -            for _, method in QUANTIZATION_METHODS.items():
10 | -                quantization_override = method.override_quantization_method(
11 | -                    quant_cfg, self.quantization
12 | -                )
13 | -                if quantization_override:
14 | -                    quant_method = quantization_override
15 | -                    self.quantization = quantization_override
16 | -                    break
17 | +            #for _, method in QUANTIZATION_METHODS.items():
18 | +            #    quantization_override = method.override_quantization_method(
19 | +            #        quant_cfg, self.quantization
20 | +            #    )
21 | +            #    if quantization_override:
22 | +            #        quant_method = quantization_override
23 | +            #        self.quantization = quantization_override
24 | +            #        break
25 |  
26 |              # Verify quantization configurations.
27 |              if self.quantization is None:
28 | diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
29 | index f2c0d61..e548f39 100644
30 | --- a/python/sglang/srt/layers/quantization/fp8.py
31 | +++ b/python/sglang/srt/layers/quantization/fp8.py
32 | @@ -340,10 +340,10 @@ class Fp8LinearMethod(LinearMethodBase):
33 |                  return
34 |              else:
35 |                  weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
36 | -            layer.weight = torch.nn.Parameter(weight, requires_grad=False)
37 | -            layer.weight_scale_inv = torch.nn.Parameter(
38 | -                weight_scale, requires_grad=False
39 | -            )
40 | +            # layer.weight = torch.nn.Parameter(weight, requires_grad=False)
41 | +            # layer.weight_scale_inv = torch.nn.Parameter(
42 | +            #     weight_scale, requires_grad=False
43 | +            # )
44 |              return
45 |  
46 |          layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
47 | 


--------------------------------------------------------------------------------
/docs/en/build.md:
--------------------------------------------------------------------------------
 1 | # Setting up the Environment from Scratch
 2 | 
 3 | [中文版](../zh/build.md)
 4 | 
 5 | If it is inconvenient to directly use our pre-built image, we provide the following solution for setting up the environment:
 6 | 
 7 | ## Setting up the environment based on anaconda / mamba
 8 | 
 9 | Here, we take micromamba as an example to build a conda environment named `slime` within the official sglang image `lmsysorg/sglang:latest`:
10 | 
11 | ```bash
12 | ####################
13 | # create conda
14 | ####################
15 | yes '' | "${SHELL}" <(curl -L micro.mamba.pm/install.sh)
16 | source ~/.bashrc
17 | micromamba self-update
18 | 
19 | micromamba create -n slime python=3.10 pip -c conda-forge -y
20 | # install cuda-12.6.0 as this is the default cuda version for pytorch
21 | # and apex need this alignment.
22 | micromamba install -n slime cuda cuda-nvtx cuda-nvtx-dev -c nvidia/label/cuda-12.6.0 -y
23 | micromamba install -n slime -c conda-forge cudnn -y
24 | micromamba run -n slime pip install cmake ninja
25 | 
26 | ####################
27 | # sglang deps
28 | ####################
29 | cd /root/
30 | git clone https://github.com/sgl-project/sglang.git --branch v0.4.9 --depth 1
31 | cd /root/sglang/
32 | micromamba run -n slime pip -v install -e "python[all]"
33 | # TODO: change to pip install sglang-router after it has a new release
34 | micromamba run -n slime pip install https://github.com/zhuzilin/sgl-router/releases/download/dev/sglang_router-0.1.4-cp310-cp310-linux_x86_64.whl --force-reinstall
35 | 
36 | ####################
37 | # megatron deps
38 | ####################
39 | TORCH_CUDA_ARCH_LIST="9.0;9.0a" micromamba run -n slime \
40 |   pip -v install --no-build-isolation \
41 |   git+https://github.com/fanshiqing/grouped_gemm@v1.1.4
42 | # apex
43 | TORCH_CUDA_ARCH_LIST="9.0;9.0a" NVCC_APPEND_FLAGS="--threads 4" \
44 | micromamba run -n slime \
45 |   pip -v install --disable-pip-version-check --no-cache-dir \
46 |   --no-build-isolation \
47 |   --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" git+https://github.com/NVIDIA/apex.git
48 | # transformer engine
49 | TORCH_CUDA_ARCH_LIST="9.0;9.0a" micromamba run -n slime \
50 |   pip -v install transformer_engine[pytorch]
51 | # flash attn
52 | # the newest version megatron supports is v2.7.4.post1
53 | micromamba run -n slime pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
54 | # megatron
55 | cd /root/
56 | git clone https://github.com/NVIDIA/Megatron-LM.git
57 | cd Megatron-LM/
58 | micromamba run -n slime pip install -e .
59 | 
60 | ####################
61 | # other deps
62 | ####################
63 | micromamba run -n slime pip install git+https://github.com/zhuzilin/cumem_allocator.git --no-build-isolation
64 | 
65 | ####################
66 | # slime
67 | ####################
68 | cd /root/
69 | git clone https://github.com/THUDM/slime.git
70 | cd slime/
71 | micromamba run -n slime pip install -e .
72 | # apply patch
73 | cd /root/sglang
74 | git apply /root/slime/docker/patch/sglang.patch
75 | ```
76 | 
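As a quick sanity check after the steps above, you can verify that the main components import inside the `slime` environment (a sketch; the package names correspond to what was installed above):

```bash
micromamba run -n slime python -c "import megatron.core, sglang, slime; print('environment OK')"
```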


--------------------------------------------------------------------------------
/docs/en/debug.md:
--------------------------------------------------------------------------------
 1 | # Debugging Guide
 2 | 
 3 | [中文版](../zh/debug.md)
 4 | 
 5 | ## Aligning Precision
 6 | 
 7 | During the development of slime, it is often necessary to check if the model's precision is correct. This can be verified in the following ways:
 8 | 
 9 | 1.  **First Training Step**
 10 |     1.  Check if the generated `rollout` is coherent. If not, there are a few possible causes:
11 |         * Parameters were not loaded correctly. You need to check the logs for a confirmation that Megatron successfully loaded the checkpoint (ckpt).
12 |         * There was an error in updating the parameters. You can check if all parameters were converted and mapped correctly, or if the parameter names were converted according to the parallelization strategy (e.g., when `pp_size > 1`, check if the layer IDs for the parameters provided by the second stage are correct). A thorough method is to save all parameters in the `load_weights` implementation of the corresponding model in SGLang and verify that they are consistent with the loaded checkpoint.
 13 |         * If all parameters are updated correctly and the problem persists, it's possible that some special buffers in SGLang were freed when the engine released its memory.
14 |         * If you are testing with a pretrained model, you can switch to an instruct version of a model with the same architecture to see if this garbled output is specific to the pretrained model.
15 | 
16 |     2.  Check the printed rollout stats to see if `log_probs` and `ref_log_probs` are exactly equal (meaning KL divergence is 0 in the first step) and their values are small.
17 |         * If they are not exactly equal, it is usually caused by certain non-deterministic kernels in the Transformer Engine, for example:
18 |             * In some versions of Transformer Engine (TE), Megatron requires `--attention-backend flash` to enforce the use of Flash Attention, thereby avoiding numerical instability from the fused attention under Context Parallelism (CP).
19 |         * If the values are large (e.g., > 1), there are generally two possibilities:
20 |             * If the value is extremely large, there is likely a problem with the training configuration.
21 |             * If the value is only slightly larger than the SFT loss, for example, if the log probability of an instruct model reaches 0.8, it might be because the data does not conform to the trained chat template or does not match the cold-start distribution.
22 | 
23 |     3.  When running one inference step per training step (`num_steps_per_rollout == 1`), check if the KL divergence is 0 and if the `grad_norm` is small.
24 |         * This is basically due to some Megatron / TE related bugs, for example:
25 |             * Mixture of Experts (MoE) requires enabling `--moe-permute-fusion`.
26 | 
27 | 2.  **Second Training Step**
28 |     1.  For integrated training and inference, check if the second step can be loaded correctly and whether it results in an Out of Memory (OOM) error.
29 | 
30 | ## Separate Debugging for Training and Inference
31 | 
32 | slime supports debugging the training and inference parts separately, which allows for the following:
33 | 
34 | * When tuning/debugging the inference part, you can start the task with only a few GPUs.
35 | * When tuning/debugging the training part, you can ensure the model input is fixed, removing the randomness of rollouts.
36 | 
37 | Specifically, slime currently provides the following parameters for separate debugging:
38 | 
39 | 1.  `--debug-rollout-only`
40 | 
41 |     When enabled, slime will not load Megatron and will only initialize SGLang. You can use this method to debug the inference part.
42 | 
43 | 2.  `--debug-train-only`
44 | 
45 |     When enabled, slime will not load SGLang and will only initialize Megatron. You can use this method to debug the training part.
46 | 
47 | 3.  `--save-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt`
48 | 
49 |     When enabled, the results of each rollout will be saved. This can be used in conjunction with `--debug-rollout-only`. Note that the data is saved using the format: `args.save_debug_rollout_data.format(rollout_id=rollout_id)`.
50 | 
51 | 4.  `--load-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt`
52 | 
53 |     When enabled, data will be loaded from `args.load_debug_rollout_data.format(rollout_id=rollout_id)`, and SGLang will not be initialized (automatically setting `debug_train_only=True`). This method allows you to fix the input for the training part to tune it, for example, by switching between different parallelization strategies.
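
For example, a typical separated-debugging workflow might combine these flags as follows (a sketch; the surrounding launch command and the other Megatron/SGLang/slime arguments are the same as for a normal training run and are elided, and the data path is illustrative):

```bash
# Step 1: debug only the rollout/inference side and dump every rollout to disk
python3 train.py \
   --debug-rollout-only \
   --save-debug-rollout-data /root/debug/data_{rollout_id}.pt \
   ... # other args

# Step 2: replay the saved rollouts with fixed inputs to debug only the training side
python3 train.py \
   --load-debug-rollout-data /root/debug/data_{rollout_id}.pt \
   ... # other args
```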


--------------------------------------------------------------------------------
/docs/en/models/qwen3-30B-A3B.md:
--------------------------------------------------------------------------------
 1 | # Example: Qwen3-30B-A3B Model
 2 | 
 3 | [中文版](../../zh/models/qwen3-30B-A3B.md)
 4 | 
 5 | ## Environment Preparation
 6 | 
 7 | The environment setup, model download, data, and checkpoint conversion are the same as for the Qwen3-4B model. You can refer to [Example: Qwen3-4B Model](./qwen3-4B.md), replacing mentions of Qwen3-4B with Qwen3-30B-A3B.
 8 | 
 9 | ## Run Training
10 | 
11 | Execute the training script:
12 | 
13 | ```bash
14 | cd /root/slime
15 | bash scripts/run-qwen3-30B-A3B.sh
16 | ```
17 | 
18 | ### Parameter Introduction
19 | 
20 | Here, we will briefly introduce the MoE-related parts in the [run-qwen3-30B-A3B.sh](../../../scripts/run-qwen3-30B-A3B.sh) script.
21 | 
22 | 1.  To support running Qwen3-30B-A3B in an 8xH800 environment, we need to enable Megatron's CPU Adam to save GPU memory. The corresponding configuration is:
23 | 
24 |     ```bash
25 |     OPTIMIZER_ARGS=(
26 |        ...
27 |        --optimizer-cpu-offload
28 |        --overlap-cpu-optimizer-d2h-h2d
29 |        --use-precision-aware-optimizer
30 |     )
31 |     ```
32 | 
33 | 2.  Enable MoE optimization supported by Megatron. The current configuration is tp4, ep8:
34 | 
35 |     ```bash
36 |     PERF_ARGS=(
37 |        --tensor-model-parallel-size 4
38 |        --sequence-parallel
39 |        --pipeline-model-parallel-size 1
40 |        --context-parallel-size 1
41 |        --expert-model-parallel-size 8
42 |        --expert-tensor-parallel-size 1
43 |        ...
44 |     )
45 |     ```
46 | 
47 | 3.  Enable MoE optimization supported by SGLang. The current configuration is ep8:
48 | 
49 |     ```bash
50 |     SGLANG_ARGS=(
51 |        --rollout-num-gpus-per-engine 8
52 |        --sglang-mem-fraction-static 0.5
53 |        --sglang-enable-ep-moe
54 |        --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256)
55 |     )
56 |     ```
57 | 
58 |     Similarly, you can also add DP attention, for example, by configuring:
59 | 
60 |     ```bash
61 |        --sglang-enable-dp-attention
62 |        --sglang-dp-size 8
63 |     ```
64 | 
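Putting the two together, a merged `SGLANG_ARGS` with DP attention enabled might look like the following (a sketch; the values are the ones used above and may need tuning for your own setup):

```bash
SGLANG_ARGS=(
   --rollout-num-gpus-per-engine 8
   --sglang-mem-fraction-static 0.5
   --sglang-enable-ep-moe
   --sglang-enable-dp-attention
   --sglang-dp-size 8
   --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256)
)
```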


--------------------------------------------------------------------------------
/docs/en/qa.md:
--------------------------------------------------------------------------------
 1 | # slime FAQ
 2 | 
 3 | [中文版](../zh/qa.md)
 4 | 
 5 | 1.  **Why do I see garbled text during training?**
 6 | 
 7 |     This situation generally occurs because Megatron is not loaded correctly. Please check if there is a corresponding checkpoint in the directory specified by `--load` or `--ref-load`. Note that Megatron can only load a directory that contains a `latest_checkpointed_iteration.txt` file.
 8 | 
 9 |     If you need to specify a particular iteration, you can refer to the current Megatron usage instructions. Generally, you can specify the step number using `--ckpt-step`.
10 | 
11 | 2.  **Why is my task stuck on the Ray submission page?**
12 | 
13 |     Please check whether your task is set up for co-located training and inference or decoupled training and inference.
14 | 
15 |     If it's **co-located** (training and inference share the same GPUs), please check:
16 | 
17 |       * Whether the `--colocate` parameter is set to enable co-located mode.
18 |       * Whether the total number of GPUs for the current task is greater than or equal to `actor_num_nodes * actor_num_gpus_per_node`.
19 | 
20 |     If it's **decoupled**, please check:
21 | 
22 |       * Whether the total number of GPUs for the current task is greater than or equal to `actor_num_nodes * actor_num_gpus_per_node + rollout_num_gpus`.
23 | 
24 | 3.  **Why did I encounter an Out-of-Memory (OOM) error during training? What is `max_tokens_per_gpu` for?**
25 | 
26 |     OOM errors often happen because `max_tokens_per_gpu` is set too high. This parameter defines the maximum number of tokens that can be processed on each GPU during training. If you are concerned about OOM, you can initially set this value to `rollout_max_response_len / cp_size` and then increase it later to improve training efficiency. Note that `--max-tokens-per-gpu` is only active when `--use-dynamic-batch-size` is enabled.
27 | 
28 |     If you still experience OOM with a small `max_tokens_per_gpu`, check if the data generated in a single pass is too long. You may need to enable context parallelism (CP) with `--context-parallel-size`. If you are using custom data generation, check if the total length of multi-turn generations is much longer than expected.
29 | 
30 | 4.  **During multi-node training, what should I do if the `transformers` library reports it cannot find a model?**
31 | 
32 |     This usually happens when multiple processes try to read local files simultaneously using methods like `AutoConfig.from_pretrained` or `AutoModelForCausalLM.from_pretrained`, causing file system write conflicts. You can mitigate this issue by setting the `--model-name` argument.
33 | 
34 | 5.  **How do I resume training?**
35 | 
36 |     Simply set the `--load` directory to your `--save` directory.
37 | 
38 | 6.  **How is the batch size calculated?**
39 | 
40 |     A single rollout uses `rollout_batch_size` prompts. For each prompt, `n_samples_per_prompt` samples are generated. Therefore, one rollout contains a total of `rollout_batch_size * n_samples_per_prompt` data entries.
41 | 
42 |     You can use `--num-steps-per-rollout` to determine how many steps to run per rollout. This is equivalent to setting the `global_batch_size` to `rollout_batch_size * n_samples_per_prompt // num_steps_per_rollout`.
43 | 
44 | 7.  **Does slime perform data packing / variable-length (varlen) processing?**
45 | 
46 |     Yes. Data packing refers to the process of concatenating samples of varying lengths during training to improve GPU utilization. slime performs this operation by default.
47 | 
48 | 8.  **What should I do if the sglang component shows a `Max retries exceeded with url: /get_model_info (Caused by NewConnectionError)` error?**
49 | 
50 |     This issue primarily stems from port conflicts caused by multiple sglang servers running on a single machine. We are currently working with the sglang team to resolve this. A temporary workaround is to minimize the number of sglang servers on a single machine, for example, by setting `tp=8`.
51 | 
52 | 9.  **My gradient norm is very high and the training crashes. What should I do?**
53 | 
54 |     First, ensure that your data and model are compatible. For example, if your data already uses a chat template, check if this template matches the one used by the original model. If the data is correct, please refer to our [Debug Guide](./debug.md) for a more in-depth analysis.
55 | 
56 | 10. **My sglang generation takes an extremely long time, GPU power is maxed out, and there's no output for a long while. Why?**
57 | 
58 |     Please verify that the model corresponding to `--hf-checkpoint` has its stop tokens configured correctly. If not, you can set them using the `--rollout-stop` or `--rollout-stop-token-ids` arguments.
59 | 
60 | 11. **Sglang shows an `an illegal memory access was encountered` error.**
61 | 
62 |     According to the sglang documentation ([https://docs.sglang.ai/references/troubleshooting.html](https://docs.sglang.ai/references/troubleshooting.html)), this could be an OOM error. Consider reducing the value of `--sglang-mem-fraction-static`.
63 | 
64 | 12. **A `JSONDecodeError` occurs related to torch compile/inductor.**
65 | 
 66 |     This is generally an issue with the torch compiler's cache read/write operations. You can try adding `"TORCHINDUCTOR_FORCE_DISABLE_CACHES": "1"` to the `env_vars` in your Ray configuration. A sketch of where this fits in the Ray launch command is given at the end of this FAQ.
67 | 
68 | 13. **Gradient becomes NaN or Inf during training.**
69 | 
70 |     You can try setting the `--no-check-for-nan-in-loss-and-grad` flag to skip the corresponding training steps.
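
For item 12 above, here is a sketch of where the environment variable fits in the Ray launch command (the address and `PYTHONPATH` follow the usual submission command; the remaining training arguments are elided):

```bash
ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json='{
     "env_vars": {
        "PYTHONPATH": "/root/Megatron-LM/",
        "TORCHINDUCTOR_FORCE_DISABLE_CACHES": "1"
     }
   }' \
   -- python3 train.py ... # other Megatron/SGLang/slime args
```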


--------------------------------------------------------------------------------
/docs/en/sft.md:
--------------------------------------------------------------------------------
 1 | # Example: Qwen3-4B-Base with OpenHermes-2.5
 2 | 
 3 | [中文版](../zh/sft.md)
 4 | 
 5 | ## Environment Preparation
 6 | 
  7 | First, we need to set up the environment from the provided image and convert the `Qwen3-4B-Base` model by following the [Example: Qwen3-4B Model](./models/qwen3-4B.md).
 8 | 
 9 | After that, we will process the SFT data. Here, we use the classic [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5) as an example. First, we process the data into a format suitable for `slime` to load. You can use the following script to add a column that conforms to the OpenAI message format and save it to `/root/openhermes2_5.parquet`.
10 | 
11 | ```python
12 | from datasets import load_dataset
13 | 
14 | ds = load_dataset("teknium/OpenHermes-2.5")["train"]
15 | 
16 | def convert(sample):
17 |     conversations = sample["conversations"]
18 | 
19 |     def convert_role(role):
20 |         if role == "human":
21 |             return "user"
22 |         elif role == "gpt":
23 |             return "assistant"
24 |         elif role == "system":
25 |             return "system"
26 |         else:
27 |             raise ValueError(f"Unknown role: {role}")
28 | 
29 |     messages = [
30 |         {
31 |             "role": convert_role(turn["from"]),
32 |             "content": turn["value"],
33 |         }
34 |         for turn in conversations
35 |     ]
36 | 
37 |     return {"messages": messages}
38 | 
39 | ds = ds.map(convert)
40 | ds.to_parquet("/root/openhermes2_5.parquet")
41 | ```
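
Optionally, you can run a quick check that the parquet file was written in the expected format (a sketch; it assumes `pandas` and `pyarrow` are available in the image):

```bash
python -c "import pandas as pd; print(pd.read_parquet('/root/openhermes2_5.parquet')['messages'][0])"
```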
42 | 
43 | ## Execute Training
44 | 
45 | Execute the training:
46 | 
47 | ```bash
48 | cd /root/slime
 49 | bash scripts/run-qwen3-4B-base-sft.sh
50 | ```
51 | 
52 | ### Parameter Introduction
53 | 
 54 | You can compare [run-qwen3-4B-base-sft.sh](../../scripts/run-qwen3-4B-base-sft.sh) with [run-qwen3-4B.sh](../../scripts/run-qwen3-4B.sh). You will find that besides changing the model from the instruct version to the base model, the main adjustments are as follows:
55 | 
56 | 1.  Removed `SGLANG_ARGS` and `GRPO_ARGS`. This is because it is not necessary to start SGLang or configure GRPO-related settings during the SFT process.
57 | 
58 | 2.  Renamed `ROLLOUT_ARGS` to `SFT_ARGS` and configured it as follows:
59 | 
60 |     ```bash
61 |     SFT_ARGS=(
62 |        --rollout-function-path slime.rollout.sft_example.generate_rollout
63 |        --prompt-data /root/openhermes2_5.parquet
64 |        --input-key messages
65 |        --rollout-shuffle
66 |        --num-epoch 3
67 |        --rollout-batch-size 128
68 |        --global-batch-size 128
69 | 
70 |        --loss-type sft_loss
71 |        --calculate-per-token-loss
72 |        --disable-compute-advantages-and-returns
73 |        --debug-train-only
74 |     )
75 |     ```
76 | 
77 |     SFT actually reuses the custom rollout functionality of slime. By using `--rollout-function-path`, the data generation part is switched from the RL rollout that uses `sglang` to the SFT version that reads data from a file, which is `slime.rollout.sft_example.generate_rollout`.
78 | 
79 |     For SFT, it is recommended to set `rollout_batch_size` and `global_batch_size` to the same value and not to configure `n_samples_per_prompt`. This is equivalent to training one batch right after reading one batch.
80 | 
81 |     `slime` also supports different loss types, and we configure the SFT loss using `--loss-type sft_loss`.
82 | 
83 |     As for `--calculate-per-token-loss`, this is because `slime` defaults to calculating the per-sample mean for GRPO. In general SFT training, the average is taken over all unmasked tokens in a batch, so it is recommended to configure this.
84 | 
85 |     Finally, `--disable-compute-advantages-and-returns` indicates that there is no need to pre-calculate log probabilities during the SFT process, and `--debug-train-only` means that `sglang` does not need to be initialized.
86 | 
87 | 3.  Used `train_async.py` instead of `train.py`. This is to leverage the asynchronous training process to implement data prefetching.
88 | 


--------------------------------------------------------------------------------
/docs/zh/build.md:
--------------------------------------------------------------------------------
 1 | # 从零搭建环境
 2 | 
 3 | [English](../en/build.md)
 4 | 
 5 | 在不方便直接使用我们预先准备的镜像的情况下,我们提供了如下的搭建环境的方案:
 6 | 
 7 | ## 基于 anaconda / mamba 搭建环境
 8 | 
 9 | 这里我们以 micromamba 为例,在 sglang 的官方镜像 `lmsysorg/sglang:latest` 中搭建一个名为 slime 的 conda 环境:
10 | 
11 | 
12 | ```bash
13 | ####################
14 | # create conda
15 | ####################
16 | yes '' | "${SHELL}" <(curl -L micro.mamba.pm/install.sh)
17 | source ~/.bashrc
18 | micromamba self-update
19 | 
20 | micromamba create -n slime python=3.10 pip -c conda-forge -y
21 | # install cuda-12.6.0 as this is the default cuda version for pytorch
22 | # and apex need this alignment.
23 | micromamba install -n slime cuda cuda-nvtx cuda-nvtx-dev -c nvidia/label/cuda-12.6.0 -y
24 | micromamba install -n slime -c conda-forge cudnn -y
25 | micromamba run -n slime pip install cmake ninja
26 | 
27 | ####################
28 | # sglang deps
29 | ####################
30 | cd /root/
31 | git clone https://github.com/sgl-project/sglang.git --branch v0.4.9 --depth 1
32 | cd /root/sglang/
33 | micromamba run -n slime pip -v install -e "python[all]"
34 | # TODO: change to pip install sglang-router after it has a new release
35 | micromamba run -n slime pip install https://github.com/zhuzilin/sgl-router/releases/download/dev/sglang_router-0.1.4-cp310-cp310-linux_x86_64.whl --force-reinstall
36 | 
37 | ####################
38 | # megatron deps
39 | ####################
40 | TORCH_CUDA_ARCH_LIST="9.0;9.0a" micromamba run -n slime \
41 |   pip -v install --no-build-isolation \
42 |   git+https://github.com/fanshiqing/grouped_gemm@v1.1.4
43 | # apex
44 | TORCH_CUDA_ARCH_LIST="9.0;9.0a" NVCC_APPEND_FLAGS="--threads 4" \
45 | micromamba run -n slime \
46 |   pip -v install --disable-pip-version-check --no-cache-dir \
47 |   --no-build-isolation \
48 |   --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" git+https://github.com/NVIDIA/apex.git
49 | # transformer engine
50 | TORCH_CUDA_ARCH_LIST="9.0;9.0a" micromamba run -n slime \
51 |   pip -v install transformer_engine[pytorch]
52 | # flash attn
53 | # the newest version megatron supports is v2.7.4.post1
54 | micromamba run -n slime pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
55 | # megatron
56 | cd /root/
57 | git clone https://github.com/NVIDIA/Megatron-LM.git
58 | cd Megatron-LM/
59 | micromamba run -n slime pip install -e .
60 | 
61 | ####################
62 | # other deps
63 | ####################
64 | micromamba run -n slime pip install git+https://github.com/zhuzilin/cumem_allocator.git --no-build-isolation
65 | 
66 | ####################
67 | # slime
68 | ####################
69 | cd /root/
70 | git clone https://github.com/THUDM/slime.git
71 | cd slime/
72 | micromamba run -n slime pip install -e .
73 | # apply patch
74 | cd /root/sglang
75 | git apply /root/slime/docker/patch/sglang.patch
76 | ```
77 | 


--------------------------------------------------------------------------------
/docs/zh/debug.md:
--------------------------------------------------------------------------------
 1 | # Debug 指南
 2 | 
 3 | [English](../en/debug.md)
 4 | 
 5 | ## 对齐精度
 6 | 
 7 | 在开发 slime 的过程中,经常会需要检查模型的精度是否正确,可以通过以下方式检查:
 8 | 
 9 | 1. 训练第一步
10 |    1. rollout 的生成是否是人话,如果不是,有以下几种可能:
11 |       - 参数没有正常加载。需要查看是否有 megatron 成功加载 ckpt 的日志;
12 |       - 更新参数有误。可以查看是不是所有的参数都做了转换和参数对应,或者参数名是不是根据并行做了转换(例如 pp_size > 1 时,第二个 stage 提供的参数的 layer id 是不是正确的)。一个比较彻底的方法是在对应模型的 sglang 实现的 `load_weights` 中保存所有的参数,查看和加载的 ckpt 中是否一致;
13 |       - 如果所有参数更新都正确,还出现问题,有可能是 sglang 里有一些特殊的 buffer 在 release 的时候被释放了;
14 |       - 如果是用 pretrain 模型进行的测试,可以换成同结构模型的 instruct 版本,查看这种乱码是不是 pretrain 模型特有的。
15 |    2. 查看打印的 rollout stats 的 `log_probs` 和 `ref_log_probs` 是否完全相等(即第一步 kl=0),且值较小
16 |       - 如果不是完全相等的,一般是 transformer engine 中的某些 non-deterministic kernel 导致的,例如:
17 |         - 在某些版本的 te 里,megatron 需要 `--attention-backend flash`,来强制使用 flash attention,从而避免 CP 下 fused attention 的数值不稳定;
18 |       - 如果数值较大(例如 >1),一般有 2 种可能:
19 |         - 如果值非常大,应该是训练配置有问题;
20 |         - 如果值只是比 sft loss 的状态略大,例如 instruct 模型的 logprob 到了 0.8,有可能是数据不符合训练的 chat template,或者不符合冷启动的分布。
21 |    3. 查看在推一训一(`num_steps_per_rollout == 1`),kl 是否为 0,grad_norm 是否较小
22 |       - 基本上就是一些 megatron / te 相关的 bug,例如:
23 |         - moe 需要开启 `--moe-permute-fusion`。
24 | 
25 | 2. 训练第二步
26 |    1. 对于训推一体,查看是否能正确加载第二步,是否会 OOM;
27 | 
28 | ## 训练推理单独 debug
29 | 
30 | slime 支持将训练部分和推理部分分开进行调试,从而实现:
31 | 
32 | - 在调优/debug 推理部分时,只用少量卡就可以启动任务;
33 | - 在调优/debug 训练部分时,可以保证模型输入固定,去除 rollout 的随机性。
34 | 
35 | 具体来说,目前 slime 提供了如下的参数来进行分离调试:
36 | 
37 | 1. `--debug-rollout-only`
38 | 
39 |    开启后,slime 将不会加载 megatron,只初始化 sglang ,可以用这个方法来进行推理部分的调试。
40 | 
41 | 2. `--debug-train-only`
42 | 
43 |    开启后,slime 将不会加载 sglang,只初始化 megatron ,可以用这个方法来进行训练部分的调试。
44 | 
45 | 3. `--save-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt`
46 | 
47 |    开启后,会保存每次 rollout 的结果,可以和 `--debug-rollout-only` 配合使用。注意保存的方式为 `args.save_debug_rollout_data.format(rollout_id=rollout_id)`。
48 | 
49 | 4. `--load-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt`
50 | 
51 |    开启后,会从 `args.load_debug_rollout_data.format(rollout_id=rollout_id)` 来加载数据,并且不会初始化 sglang(自动设置 `debug_train_only=True`)。可以以这种方式来固定训练部分的输入,对训练部分进行调优,例如切换各种并行。
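
例如,可以按如下方式组合这些参数(仅作示意,外层的启动方式以及其他 Megatron/SGLang/slime 参数与常规训练一致,这里省略;数据路径仅为示例):

```bash
# 第一步:只调试推理部分,并把每次 rollout 的结果保存下来
python3 train.py \
   --debug-rollout-only \
   --save-debug-rollout-data /root/debug/data_{rollout_id}.pt \
   ... # 其他参数

# 第二步:加载保存的 rollout 数据,固定输入,只调试训练部分
python3 train.py \
   --load-debug-rollout-data /root/debug/data_{rollout_id}.pt \
   ... # 其他参数
```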
52 | 


--------------------------------------------------------------------------------
/docs/zh/models/qwen3-30B-A3B.md:
--------------------------------------------------------------------------------
 1 | # 示例:Qwen3-30B-A3B 模型
 2 | 
 3 | [English](../../en/models/qwen3-30B-A3B.md)
 4 | 
 5 | ## 环境准备
 6 | 
 7 | 搭建环境、下载模型、数据与 ckpt 转换均与 Qwen3-4B 模型相同,可以参考 [示例:Qwen3-4B 模型](./qwen3-4B.md),将文中 Qwen3-4B 的部分转换为 Qwen3-30B-A3B 即可。
 8 | 
 9 | ## 执行训练
10 | 
11 | 执行训练:
12 | 
13 | ```bash
14 | cd /root/slime
15 | bash scripts/run-qwen3-30B-A3B.sh
16 | ```
17 | 
18 | ### 参数简介
19 | 
20 | 这里我们简单介绍一下脚本 [run-qwen3-30B-A3B.sh](../../../scripts/run-qwen3-30B-A3B.sh) 中与 MoE 相关的部分。
21 | 
22 | 1. 为了支持在 8xH800 环境中运行 Qwen3-30B-A3B,我们需要开启 megatron 的 CPU Adam 以节省显存,对应配置为:
23 | 
24 |    ```bash
25 |    OPTIMIZER_ARGS=(
26 |       ...
27 |       --optimizer-cpu-offload
28 |       --overlap-cpu-optimizer-d2h-h2d
29 |       --use-precision-aware-optimizer
30 |    )
31 |    ```
32 | 
33 | 2. 开启 megatron 支持的 moe 优化,当前配置为 tp4, ep8:
34 | 
35 |    ```bash
36 |    PERF_ARGS=(
37 |       --tensor-model-parallel-size 4
38 |       --sequence-parallel
39 |       --pipeline-model-parallel-size 1
40 |       --context-parallel-size 1
41 |       --expert-model-parallel-size 8
42 |       --expert-tensor-parallel-size 1
43 |       ...
44 |    )
45 |    ```
46 | 
47 | 3. 开启 sglang 支持的 moe 优化,当前配置为 ep8:
48 | 
49 |    ```bash
50 |    SGLANG_ARGS=(
51 |       --rollout-num-gpus-per-engine 8
52 |       --sglang-mem-fraction-static 0.5
53 |       --sglang-enable-ep-moe
54 |       --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256)
55 |    )
56 |    ```
57 | 
58 |    类似地,也可以加入 dp attention,例如配置上:
59 | 
60 |    ```bash
61 |       --sglang-enable-dp-attention
62 |       --sglang-dp-size 8
63 |    ```
64 | 


--------------------------------------------------------------------------------
/docs/zh/qa.md:
--------------------------------------------------------------------------------
 1 | # slime 常见 Q&A
 2 | 
 3 | [English](../en/qa.md)
 4 | 
 5 | 1. **训练过程中为什么会出现乱码?**
 6 | 
 7 |    一般来说这种情况是 megatron 没有被正确加载。请检查 `--load` 或 `--ref-load` 是否有对应的 ckpt。注意 megatron 只能加载其中有 `latest_checkpointed_iteration.txt` 的目录。
 8 | 
 9 |    如果需要指定某个特定的 iter,可以查看当前 megatron 的使用方法,一般是可以通过 `--ckpt-step` 来指定步数。
10 | 
11 | 1. **为什么我的任务一直卡在 ray 提交的页面上?**
12 | 
13 |    请先检查你需要跑的任务是训推一体的,还是训推分离的。
14 | 
15 |    如果是训推一体,即训练和推理共用 GPU,请检查
16 | 
17 |    - 是否设置了 `--colocate` 参数开启训推一体;
18 |    - 当前任务的总卡数是否大于等于 `actor_num_nodes * actor_num_gpus_per_node`
19 | 
20 |    如果是训推分离,请检查:
21 | 
22 |    - 当前任务的总卡数是否大于等于 `actor_num_nodes * actor_num_gpus_per_node + rollout_num_gpus`
23 | 
24 | 1. **为什么训着训着 OOM 了?`max_tokens_per_gpu` 是干什么用的?**
25 | 
26 |    OOM 往往是因为 `max_tokens_per_gpu` 设置过高了。 `max_tokens_per_gpu` 是指在训练过程中,每张 GPU 上最多可以放多少 token。如果担心 OOM 的话,可以先把这个值设成 `rollout_max_response_len / cp_size`,之后再为了提升训练效率来增大这个值。`--max-tokens-per-gpu` 只有在开启 `--use-dynamic-batch-size` 的情况下才会启用。
27 | 
28 |    如果 `max_tokens_per_gpu` 很小,还会 oom,可以检查一下是否单次生成的数据太长了,需要开启 cp(`--context-parallel-size`)。如果进行了自定义的数据生成,可以看一下是否在多轮生成的情况下,生成的总长度比预期的长很多。
29 | 
30 | 1. **多机训练的时候,遇到了 transformers 库找不到某个模型的错误该怎么办?**
31 | 
32 |    这种情况一般是因为多个进程都在通过类似于 `AutoConfig.from_pretrained` 或者 `AutoModelForCausalLM.from_pretrained` 的方式读取本地文件,出现了文件系统的写冲突。可以通过设置 `--model-name` 缓解这一问题。
33 | 
34 | 1. **如何续训?**
35 | 
36 |    直接将 `--load` 设置为 `--save` 的目录即可。
37 | 
38 | 1. **batch size 是如何计算的?**
39 | 
40 |    一个 rollout 会用 `rollout_batch_size` 条 prompt,每一条会采 `n_samples_per_prompt` 条,所以一个 rollout 共 `rollout_batch_size * n_samples_per_prompt` 条数据。
41 | 
42 |    可以用 `--num-steps-per-rollout` 来决定每一个 rollout 跑多少步。这相当于是把 `global_batch_size` 设置成 `rollout_batch_size * n_samples_per_prompt // num_steps_per_rollout`。
43 | 
44 | 1. **slime 是否进行了 data packing / varlen 处理?**
45 | 
46 |    data packing 是指在训练过程中,将长短不一的 sample 拼接到一起,从而提升训练的利用率。slime 默认会进行这样的操作。
47 | 
48 | 1. **sglang 部分出现 `Max retries exceeded with url: /get_model_info (Caused by NewConnectionError` 的问题怎么办?**
49 | 
50 |    这个问题主要来源于单机内多个 sglang server 导致的端口冲突,目前我们仍在和 sglang 团队一起解决这个问题。一个临时的缓解方案是尽可能减少单机内的 sglang server 数量,例如设置 tp=8。
51 | 
52 | 1. **grad norm 好高,训练训崩了怎么办?**
53 | 
54 |    首先请确保数据和模型是匹配的,例如说,如果数据是事先已经做好 chat template 的,这个 chat template 是否和原模型一致。如果数据正确的话,可以参考 [debug 指南](./debug.md) 进行更深入的分析。
55 | 
56 | 1. **我的 sglang 生成时间特别特别久,gpu 功率都打满了,跑了好久都没有输出是为什么?**
57 | 
58 |    请确认一下 `--hf-checkpoint` 对应的模型是否正确设置了 stop token,如果没有,可以通过 `--rollout-stop` 或者 `--rollout-stop-token-ids` 来进行设置。
59 | 
60 | 1. **sglang 出现 an illegal memory access was encountered**
61 | 
62 |    根据 sglang 的文档(https://docs.sglang.ai/references/troubleshooting.html),有可能是 OOM 了,可以考虑缩小 `--sglang-mem-fraction-static`。
63 | 
64 | 1. **出现 torch compile/inductor 的 `JSONDecodeError`**
65 | 
66 |    一般是 torch compile 读写 cache 出现的问题。可以考虑在 ray 的 env_var 里加上 `"TORCHINDUCTOR_FORCE_DISABLE_CACHES": "1"`。
67 | 
68 | 1. **训练出现 grad NaN 或者 Inf 的情况**
69 | 
70 |    可以通过设置 `--no-check-for-nan-in-loss-and-grad` 来尝试跳过对应的训练步。
71 | 


--------------------------------------------------------------------------------
/docs/zh/sft.md:
--------------------------------------------------------------------------------
 1 | # 示例:Qwen3-4B-Base + OpenHermes-2.5
 2 | 
 3 | [English](../en/sft.md)
 4 | 
 5 | ## 环境准备
 6 | 
 7 | 首先需要我们仿照 [示例:Qwen3-4B 模型](./models/qwen3-4B.md) 创建镜像环境与转换 `Qwen3-4B-Base` 模型。
 8 | 
 9 | 之后,我们处理 sft 数据。这里我们以经典的 [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5) 为例,首先把数据处理成适合 slime 加载的格式,可以用如下的脚本进行处理,增加一个符合 openai message 格式的列,并保存在 `/root/openhermes2_5.parquet`。
10 | 
11 | ```python
12 | from datasets import load_dataset
13 | 
14 | ds = load_dataset("teknium/OpenHermes-2.5")["train"]
15 | 
16 | def convert(sample):
17 |     conversations = sample["conversations"]
18 | 
19 |     def convert_role(role):
20 |         if role == "human":
21 |             return "user"
22 |         elif role == "gpt":
23 |             return "assistant"
24 |         elif role == "system":
25 |             return "system"
26 |         else:
27 |             raise ValueError(f"Unknown role: {role}")
28 | 
29 |     messages = [
30 |         {
31 |             "role": convert_role(turn["from"]),
32 |             "content": turn["value"],
33 |         }
34 |         for turn in conversations
35 |     ]
36 | 
37 |     return {"messages": messages}
38 | 
39 | ds = ds.map(convert)
40 | ds.to_parquet("/root/openhermes2_5.parquet")
41 | ```
42 | 
43 | ## 执行训练
44 | 
45 | 执行训练:
46 | 
47 | ```bash
48 | cd /root/slime
49 | bash scripts/run-qwen3-4B-base-sft.sh
50 | ```
51 | 
52 | ### 参数简介
53 | 
54 | 可以将 [run-qwen3-4B-base-sft.sh](../../scripts/run-qwen3-4B-base-sft.sh) 与 [run-qwen3-4B.sh](../../scripts/run-qwen3-4B.sh) 进行对比。会发现除了我们将模型由 instruct 模型换为了 base 模型之外,主要进行了如下的几个调整:
55 | 
56 | 1. 移除了 `SGLANG_ARGS` 和 `GRPO_ARGS`。这是因为 sft 的过程中不需要启动 sglang 或者做 grpo 相关的配置;
57 | 
58 | 2. 将 `ROLLOUT_ARGS` 改名为了 `SFT_ARGS`,并配置为:
59 | 
60 |    ```bash
61 |    SFT_ARGS=(
62 |       --rollout-function-path slime.rollout.sft_example.generate_rollout
63 |       --prompt-data /root/openhermes2_5.parquet
64 |       --input-key messages
65 |       --rollout-shuffle
66 |       --num-epoch 3
67 |       --rollout-batch-size 128
68 |       --global-batch-size 128
69 |    
70 |       --loss-type sft_loss
71 |       --calculate-per-token-loss
72 |       --disable-compute-advantages-and-returns
73 |       --debug-train-only
74 |    )
75 |    ```
76 | 
77 |    slime 中的 sft 实际上是复用了 slime 的 custom rollout 功能,通过 `--rollout-function-path` 将数据生成部分从使用 sglang 的 RL rollout,切换成了从文件中读取数据的 sft 版本,即 `slime.rollout.sft_example.generate_rollout`。
78 | 
79 |    对于 sft 来说,建议将 `rollout_batch_size` 与 `global_batch_size` 设置成相同的,并不要配置 `n_samples_per_prompt`,这样相当于是读一个 batch 就训一个 batch。
80 | 
81 |    slime 还支持不同的 loss 类型,我们就是通过 `--loss-type sft_loss` 配置上 sft loss 的。
82 | 
83 |    至于 `--calculate-per-token-loss`,这是因为 slime 默认是以 GRPO 的 per sample mean 进行计算的,而一般 sft 训练都是按一个 batch 的所有不被 mask 的 token 取平均,所以建议配置上。
84 | 
85 |    最后 `--disable-compute-advantages-and-returns` 表示 sft 的过程中不需要预先计算 log prob,`--debug-train-only` 表示不需要初始化 sglang。
86 | 
87 | 3. 使用了 `train_async.py` 而不是 `train.py`。这是为了利用异步训练的流程,来实现数据 prefetch。
88 | 


--------------------------------------------------------------------------------
/examples/search-r1/README.md:
--------------------------------------------------------------------------------
 1 | # Example: Search-R1 lite
 2 | 
 3 | [中文版](./README_zh.md)
 4 | 
 5 | This is a minimal reproduction of [Search-R1](https://github.com/PeterGriffinJin/Search-R1) and an example of using multi-turn conversation and tool-calling in slime.
 6 | 
 7 | ## Environment Setup
 8 | 
 9 | Use the `zhuzilin/slime:latest` image and initialize the environment required for Search-R1:
10 | 
11 | ```bash
12 | cd /root/
13 | git clone https://github.com/THUDM/slime.git
14 | cd slime/ && pip install -e .
15 | # for Search R1
16 | pip install chardet
17 | ```
18 | 
19 | Please refer to the script provided in Search-R1 to download the data:
20 | 
21 | ```bash
22 | git clone https://github.com/PeterGriffinJin/Search-R1.git
23 | cd Search-R1/
24 | python scripts/data_process/nq_search.py --local_dir /root/nq_search/
25 | ```
26 | 
27 | Initialize the Qwen2.5-3B model:
28 | 
29 | ```bash
30 | # hf checkpoint
31 | huggingface-cli download Qwen/Qwen2.5-3B --local-dir /root/Qwen2.5-3B
32 | 
33 | # mcore checkpoint
34 | cd /root/slime
35 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
36 |     --hf-checkpoint /root/Qwen2.5-3B \
37 |     --save /root/Qwen2.5-3B_torch_dist
38 | ```
39 | 
40 | ## Running the Script
41 | 
42 | You need to configure your serper.dev API key in `generate_with_search.py`:
43 | 
44 | ```python
45 | SEARCH_R1_CONFIGS = {
46 |     "max_turns": 3,
47 |     "topk": 3,
48 |     "google_api_key": "YOUR_API_KEY",  # Replace with your actual API key
49 |     "snippet_only": True,  # Set to True to only return snippets
50 |     "proxy": None,  # Set to your proxy if needed
51 |     "search_concurrency": 256,
52 |     # rm
53 |     "format_score": 0.2,
54 | }
55 | ```
56 | 
57 | And run:
58 | 
59 | ```bash
60 | cd slime/
61 | bash examples/search-r1/run_qwen2.5_3B.sh
62 | ```
63 | 
64 | ## Code Structure
65 | 
66 | To support multi-turn conversation and tool-calling in slime, you only need to provide a custom data generation function and a task-specific reward model. These correspond to the following 2 configuration items in the startup script:
67 | 
68 | ```bash
69 | CUSTOM_ARGS=(
70 |    --custom-generate-function-path generate_with_search.generate
71 |    --custom-rm-path generate_with_search.reward_func
72 | )
73 | ```
74 | 
75 | These are the `generate` and `reward_func` functions in `generate_with_search.py`.
76 | 
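As a reference for writing your own task, here is a minimal sketch of the shape of these two entry points. The signatures mirror the ones used in `generate_with_search.py`; the bodies are illustrative placeholders, not the actual implementation.

```python
from slime.utils.types import Sample


async def generate(args, sample: Sample, sampling_params) -> Sample:
    # Drive the multi-turn conversation / tool calls here, then fill in
    # sample.tokens, sample.response, sample.response_length and sample.loss_masks
    # (with 0s masking out tool/observation tokens) before returning.
    ...
    return sample


async def reward_func(args, sample: Sample, **kwargs) -> float:
    # Score the finished rollout, e.g. exact match of the final answer against sample.label.
    return 0.0
```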


--------------------------------------------------------------------------------
/examples/search-r1/README_zh.md:
--------------------------------------------------------------------------------
 1 | # 示例:Search-R1 lite
 2 | 
 3 | [English](./README.md)
 4 | 
 5 | 这里是一个对 [Search-R1](https://github.com/PeterGriffinJin/Search-R1) 的简单复现,以及是一个在 slime 中使用多轮对话和工具调用的样例。
 6 | 
 7 | ## 配置环境
 8 | 
 9 | 使用 `zhuzilin/slime:latest` 镜像,并初始化 Search-R1 需要的环境:
10 | 
11 | ```bash
12 | cd /root/
13 | git clone https://github.com/THUDM/slime.git
14 | pip install -e ./slime
15 | # for Search R1
16 | pip install chardet
17 | ```
18 | 
19 | 请参照 Search-R1 中提供的脚本下载数据:
20 | 
21 | ```bash
22 | git clone https://github.com/PeterGriffinJin/Search-R1.git
23 | cd Search-R1/
24 | python scripts/data_process/nq_search.py --local_dir /root/nq_search/
25 | ```
26 | 
27 | 初始化 Qwen2.5-3B 模型:
28 | 
29 | ```bash
30 | # hf checkpoint
31 | huggingface-cli download Qwen/Qwen2.5-3B --local-dir /root/Qwen2.5-3B
32 | 
33 | # mcore checkpoint
34 | cd /root/slime
35 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
36 |     --hf-checkpoint /root/Qwen2.5-3B \
37 |     --save /root/Qwen2.5-3B_torch_dist
38 | ```
39 | 
40 | ## 运行脚本
41 | 
42 | 需要将你的 serper.dev API 配置在 `generate_with_search.py` 中:
43 | 
44 | ```python
45 | SEARCH_R1_CONFIGS = {
46 |     "max_turns": 3,
47 |     "topk": 3,
48 |     "google_api_key": "YOUR_API_KEY",  # Replace with your actual API key
49 |     "snippet_only": True,  # Set to True to only return snippets
50 |     "proxy": None,  # Set to your proxy if needed
51 |     "search_concurrency": 256,
52 |     # rm
53 |     "format_score": 0.2,
54 | }
55 | ```
56 | 
57 | 并运行:
58 | 
59 | ```bash
60 | cd slime/
61 | bash examples/search-r1/run_qwen2.5_3B.sh
62 | ```
63 | 
64 | ## 代码结构
65 | 
66 | 为了实现多轮 + 工具调用,在 slime 中只需要实现一个自定义的数据生成函数,以及一个任务所需的 reward model,对应启动脚本中的这 2 个配置项:
67 | 
68 | ```bash
69 | CUSTOM_ARGS=(
70 |    --custom-generate-function-path generate_with_search.generate
71 |    --custom-rm-path generate_with_search.reward_func
72 | )
73 | ```
74 | 
75 | 也就是 `generate_with_search.py` 中的 `generate` 和 `reward_func` 两个函数。
76 | 


--------------------------------------------------------------------------------
/examples/search-r1/generate_with_search.py:
--------------------------------------------------------------------------------
  1 | # Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/ceee7b89655ed52f205b9beb98e1190c3eedcfb0/search_r1/llm_agent/generation.py
  2 | import asyncio
  3 | import re
  4 | 
  5 | from google_search_server import google_search
  6 | from qa_em_format import compute_score_em
  7 | 
  8 | from slime.rollout.sglang_example import GenerateState
  9 | from slime.utils.http_utils import post
 10 | from slime.utils.types import Sample
 11 | 
 12 | SEARCH_R1_CONFIGS = {
 13 |     "max_turns": 3,
 14 |     "topk": 3,
 15 |     "google_api_key": "YOUR_API_KEY",  # Replace with your actual API key
 16 |     "snippet_only": True,  # Set to True to only return snippets
 17 |     "proxy": None,  # Set to your proxy if needed
 18 |     "search_concurrency": 256,
 19 |     # rm
 20 |     "format_score": 0.2,
 21 | }
 22 | 
 23 | 
 24 | SEMAPHORE = asyncio.Semaphore(SEARCH_R1_CONFIGS["search_concurrency"])
 25 | 
 26 | 
 27 | def _passages2string(retrieval_result):
 28 |     format_reference = ""
 29 |     for idx, doc_item in enumerate(retrieval_result):
 30 | 
 31 |         content = doc_item["document"]["contents"]
 32 |         title = content.split("\n")[0]
 33 |         text = "\n".join(content.split("\n")[1:])
 34 |         format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
 35 | 
 36 |     return format_reference
 37 | 
 38 | 
 39 | async def search(query: str) -> str:
 40 |     result = await google_search(
 41 |         SEARCH_R1_CONFIGS["google_api_key"],
 42 |         query,
 43 |         SEARCH_R1_CONFIGS["topk"],
 44 |         snippet_only=SEARCH_R1_CONFIGS["snippet_only"],
 45 |         proxy=SEARCH_R1_CONFIGS["proxy"],
 46 |     )
 47 |     return _passages2string(result)
 48 | 
 49 | 
 50 | def postprocess_responses(resp: str) -> str:
 51 |     return (
 52 |         resp.split("</search>")[0] + "</search>"
 53 |         if "</search>" in resp
 54 |         else resp.split("</answer>")[0] + "</answer>" if "</answer>" in resp else resp
 55 |     )
 56 | 
 57 | 
 58 | def postprocess_predictions(prediction: str):
 59 |     pattern = r"<(search|answer)>(.*?)</\1>"
 60 |     match = re.search(pattern, prediction, re.DOTALL)
 61 |     if match:
 62 |         content = match.group(2).strip()  # Return only the content inside the tags
 63 |         action = match.group(1)
 64 |     else:
 65 |         content = ""
 66 |         action = None
 67 | 
 68 |     return action, content
 69 | 
 70 | 
 71 | async def execute_predictions(prediction: str) -> str:
 72 |     action, content = postprocess_predictions(prediction)
 73 | 
 74 |     if action == "search":
 75 |         search_query = content
 76 |         async with SEMAPHORE:
 77 |             search_results = await search(search_query)
 78 |         next_obs = f"\n\n<information>{search_results.strip()}</information>\n\n"
 79 |         done = False
 80 |     elif action == "answer":
 81 |         next_obs = ""
 82 |         done = True
 83 |     else:
 84 |         next_obs = f"\nMy previous action is invalid. \
 85 | If I want to search, I should put the query between <search> and </search>. \
 86 | If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n"
 87 |         done = False
 88 | 
 89 |     return next_obs, done
 90 | 
 91 | 
 92 | async def generate(args, sample: Sample, sampling_params) -> Sample:
 93 |     assert not args.partial_rollout, "Partial rollout is not supported for this function at the moment."
 94 | 
 95 |     state = GenerateState(args)
 96 | 
 97 |     url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
 98 | 
 99 |     # Tokenize the prompt, then generate the response turn by turn (partial rollout is not supported here)
100 |     prompt = sample.prompt
101 |     prompt_tokens_ids = state.tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]
102 |     response = ""
103 |     response_token_ids = []
104 |     loss_masks = []
105 |     for _ in range(SEARCH_R1_CONFIGS["max_turns"]):
106 |         payload = {
107 |             "text": prompt + response,
108 |             "sampling_params": sampling_params,
109 |         }
110 |         output = await post(url, payload, use_http2=args.use_http2)
111 | 
112 |         # abort
113 |         if output["meta_info"]["finish_reason"]["type"] == "abort":
114 |             sample.status = Sample.Status.ABORTED
115 |             return sample
116 | 
117 |         cur_response = output["text"]
118 |         cur_response = postprocess_responses(cur_response)
119 | 
120 |         cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]
121 |         response += cur_response
122 |         response_token_ids += cur_response_token_ids
123 |         loss_masks += [1] * len(cur_response_token_ids)
124 | 
125 |         if output["meta_info"]["finish_reason"]["type"] == "length":
126 |             break
127 | 
128 |         next_obs, done = await execute_predictions(cur_response)
129 |         if done:
130 |             break
131 | 
132 |         assert next_obs != "", "Next observation should not be empty."
133 |         obs_tokens_ids = state.tokenizer(next_obs, add_special_tokens=False)["input_ids"]
134 |         response += next_obs
135 |         response_token_ids += obs_tokens_ids
136 |         loss_masks += [0] * len(obs_tokens_ids)
137 | 
138 |     sample.tokens = prompt_tokens_ids + response_token_ids
139 |     sample.response_length = len(response_token_ids)
140 |     sample.response = response
141 |     sample.loss_masks = loss_masks
142 |     match output["meta_info"]["finish_reason"]["type"]:
143 |         case "length":
144 |             sample.status = Sample.Status.TRUNCATED
145 |         case "abort":
146 |             sample.status = Sample.Status.ABORTED
147 |         case "stop":
148 |             sample.status = Sample.Status.COMPLETED
149 | 
150 |     return sample
151 | 
152 | 
153 | async def reward_func(args, sample, **kwargs):
154 |     """The reward function for retrieval-based question answering.
155 | 
156 |     Args:
157 |         args: the arguments
158 |         sample: the sample to evaluate
159 |     """
160 |     if not isinstance(sample, Sample):
161 |         raise TypeError("Sample must be an instance of Sample class.")
162 | 
163 |     score = compute_score_em(
164 |         solution_str=sample.prompt + sample.response,
165 |         ground_truth=sample.label["ground_truth"],
166 |         format_score=SEARCH_R1_CONFIGS["format_score"],
167 |     )
168 | 
169 |     return score
170 | 


--------------------------------------------------------------------------------
/examples/search-r1/google_search_server.py:
--------------------------------------------------------------------------------
  1 | import asyncio
  2 | import os
  3 | import random
  4 | import re
  5 | from typing import Dict, List
  6 | 
  7 | import aiohttp
  8 | import chardet
  9 | 
 10 | 
 11 | # --- Utilities ---
 12 | def parse_snippet(snippet: str) -> List[str]:
 13 |     segments = snippet.split("...")
 14 |     return [s.strip() for s in segments if len(s.strip().split()) > 5]
 15 | 
 16 | 
 17 | def sanitize_search_query(query: str) -> str:
 18 |     # Remove or replace special characters that might cause issues.
 19 |     # This is a basic example; you might need to add more characters or patterns.
 20 |     sanitized_query = re.sub(r"[^\w\s]", " ", query)  # Replace non-alphanumeric and non-whitespace with spaces.
 21 |     sanitized_query = re.sub(
 22 |         r"[\t\r\f\v\n]", " ", sanitized_query
 23 |     )  # replace tab, return, formfeed, vertical tab with spaces.
 24 |     sanitized_query = re.sub(
 25 |         r"\s+", " ", sanitized_query
 26 |     ).strip()  # remove duplicate spaces, and trailing/leading spaces.
 27 | 
 28 |     return sanitized_query
 29 | 
 30 | 
 31 | def filter_links(search_results: List[Dict]) -> List[str]:
 32 |     links = []
 33 |     for result in search_results:
 34 |         for item in result.get("items", []):
 35 |             if "mime" in item:
 36 |                 continue
 37 |             ext = os.path.splitext(item["link"])[1]
 38 |             if ext in ["", ".html", ".htm", ".shtml"]:
 39 |                 links.append(item["link"])
 40 |     return links
 41 | 
 42 | 
 43 | async def fetch(session: aiohttp.ClientSession, url: str, semaphore: asyncio.Semaphore) -> str:
 44 |     if url == "":
 45 |         return ""
 46 |     user_agents = [
 47 |         "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P)...",
 48 |         "Mozilla/5.0 AppleWebKit/537.36...",
 49 |         "Mozilla/5.0 (compatible; Googlebot/2.1; +https://www.google.com/bot.html)",
 50 |     ]
 51 |     headers = {"User-Agent": random.choice(user_agents)}
 52 | 
 53 |     async with semaphore:
 54 |         try:
 55 |             async with session.get(url, headers=headers) as response:
 56 |                 raw = await response.read()
 57 |                 detected = chardet.detect(raw)
 58 |                 encoding = detected["encoding"] or "utf-8"
 59 |                 return raw.decode(encoding, errors="ignore")
 60 |         except (aiohttp.ClientError, asyncio.TimeoutError):
 61 |             return ""
 62 | 
 63 | 
 64 | async def fetch_all(urls: List[str], limit: int = 8) -> List[str]:
 65 |     semaphore = asyncio.Semaphore(limit)
 66 |     timeout = aiohttp.ClientTimeout(total=5)
 67 |     connector = aiohttp.TCPConnector(limit_per_host=limit, force_close=True)
 68 | 
 69 |     async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
 70 |         tasks = [fetch(session, url, semaphore) for url in urls]
 71 |         return await asyncio.gather(*tasks)
 72 | 
 73 | 
 74 | def collect_context(snippet: str, doc: str) -> str:
 75 |     snippets = parse_snippet(snippet)
 76 |     ctx_paras = []
 77 | 
 78 |     for s in snippets:
 79 |         pos = doc.replace("\n", " ").find(s)
 80 |         if pos == -1:
 81 |             continue
 82 |         sta = pos
 83 |         while sta > 0 and doc[sta] != "\n":
 84 |             sta -= 1
 85 |         end = pos + len(s)
 86 |         while end < len(doc) and doc[end] != "\n":
 87 |             end += 1
 88 |         para = doc[sta:end].strip()
 89 |         if para not in ctx_paras:
 90 |             ctx_paras.append(para)
 91 | 
 92 |     return "\n".join(ctx_paras)
 93 | 
 94 | 
 95 | async def google_search(api_key, query, top_k=5, timeout: int = 60, proxy=None, snippet_only=False) -> List[Dict]:
 96 |     timeout_obj = aiohttp.ClientTimeout(total=timeout)
 97 |     session_kwargs = {}
 98 |     if proxy:
 99 |         session_kwargs["proxy"] = proxy
100 |     async with aiohttp.ClientSession(**session_kwargs) as session:
101 |         async with session.post(
102 |             "https://google.serper.dev/search",
103 |             json={
104 |                 "q": query,
105 |                 "num": top_k,
106 |                 "gl": "us",
107 |                 "hl": "en",
108 |             },
109 |             headers={
110 |                 "Content-Type": "application/json",
111 |                 "X-API-KEY": api_key,
112 |             },
113 |             timeout=timeout_obj,
114 |         ) as resp:
115 |             resp.raise_for_status()
116 |             response = await resp.json()
117 |             items = response.get("organic", [])
118 | 
119 |     contexts = []
120 |     if snippet_only:
121 |         for item in items:
122 |             title = item.get("title", "")
123 |             context = " ".join(parse_snippet(item.get("snippet", "")))
124 |             if title != "" or context != "":
125 |                 title = "No title." if not title else title
126 |                 context = "No snippet available." if not context else context
127 |                 contexts.append(
128 |                     {
129 |                         "document": {"contents": f'"{title}"\n{context}'},
130 |                     }
131 |                 )
132 |     else:
133 |         links = [item.get("link", "") for item in items if "link" in item]
134 |         web_contents = await fetch_all(links)
135 |         contexts = []
136 |         for i, item in enumerate(items):
137 |             title = item.get("title", "")
138 |             snippet = item.get("snippet", "")
139 | 
140 |             context = collect_context(snippet, web_contents[i])
141 |             if title != "" or context != "":
142 |                 title = "No title." if not title else title
143 |                 context = "No snippet available." if not context else context
144 |                 contexts.append(
145 |                     {
146 |                         "document": {"contents": f'"{title}"\n{context}'},
147 |                     }
148 |                 )
149 | 
150 |     return contexts
151 | 


--------------------------------------------------------------------------------
/examples/search-r1/run_qwen2.5_3B.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | # for rerun the task
  4 | pkill -9 sglang
  5 | sleep 3
  6 | ray stop --force
  7 | pkill -9 ray
  8 | pkill -9 python
  9 | sleep 3
 10 | pkill -9 ray
 11 | pkill -9 python
 12 | 
 13 | set -ex
 14 | 
 15 | # will prevent ray from buffering stdout/stderr
 16 | export PYTHONUNBUFFERED=16
 17 | 
 18 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 19 | source "${SCRIPT_DIR}/../../scripts/models/qwen2.5-3B.sh"
 20 | 
 21 | CKPT_ARGS=(
 22 |    --hf-checkpoint /root/Qwen2.5-3B/
 23 |    --ref-load /root/Qwen2.5-3B_torch_dist/
 24 |    --load /root/Qwen2.5-3B_slime/
 25 |    --save /root/Qwen2.5-3B_slime/
 26 |    --save-interval 20
 27 | )
 28 | 
 29 | ROLLOUT_ARGS=(
 30 |    --prompt-data /root/nq_search/train.parquet
 31 |    --input-key prompt
 32 |    --label-key reward_model
 33 |    --apply-chat-template
 34 |    --rollout-shuffle
 35 |    --num-rollout 3000
 36 |    --rollout-batch-size 32
 37 |    --n-samples-per-prompt 8
 38 |    --rollout-max-response-len 512
 39 |    --rollout-temperature 0.8
 40 | 
 41 |    --global-batch-size 256
 42 |    --balance-data
 43 | )
 44 | 
 45 | PERF_ARGS=(
 46 |    --tensor-model-parallel-size 2
 47 |    --sequence-parallel
 48 |    --pipeline-model-parallel-size 1
 49 |    --context-parallel-size 1
 50 |    --expert-model-parallel-size 1
 51 |    --expert-tensor-parallel-size 1
 52 | 
 53 |    --recompute-granularity full
 54 |    --recompute-method uniform
 55 |    --recompute-num-layers 1
 56 | 
 57 |    # --micro-batch-size 1
 58 |    --use-dynamic-batch-size
 59 |    --max-tokens-per-gpu 9216
 60 | )
 61 | 
 62 | GRPO_ARGS=(
 63 |    --advantage-estimator grpo
 64 |    --use-kl-loss
 65 |    --kl-loss-coef 0.00
 66 |    --kl-loss-type low_var_kl
 67 |    --kl-coef 0.00
 68 |    --entropy-coef 0.00
 69 |    --eps-clip 0.2
 70 |    --eps-clip-high 0.28
 71 | )
 72 | 
 73 | OPTIMIZER_ARGS=(
 74 |    --optimizer adam
 75 |    --lr 1e-6
 76 |    --lr-decay-style constant
 77 |    --weight-decay 0.1
 78 |    --adam-beta1 0.9
 79 |    --adam-beta2 0.98
 80 | )
 81 | 
 82 | WANDB_ARGS=(
 83 |    # --use-wandb
 84 |    # --wandb-project slime-dev
 85 |    # --wandb-group search-r1_qwen2.5-3B-test
 86 |    # --wandb-key ${WANDB_KEY}
 87 | )
 88 | 
 89 | SGLANG_ARGS=(
 90 |    --rollout-num-gpus-per-engine 2
 91 |    --sglang-mem-fraction-static 0.7
 92 | )
 93 | 
 94 | MISC_ARGS=(
 95 |    # default dropout in megatron is 0.1
 96 |    --attention-dropout 0.0
 97 |    --hidden-dropout 0.0
 98 |    # should be good for model performance
 99 |    --accumulate-allreduce-grads-in-fp32
100 |    --attention-softmax-in-fp32
101 |    # need to comment this when using model with MLA
102 |    --attention-backend flash
103 | )
104 | 
105 | CUSTOM_ARGS=(
106 |    --custom-generate-function-path generate_with_search.generate
107 |    --custom-rm-path generate_with_search.reward_func
108 | )
109 | 
110 | # launch the master node of ray in container
111 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
112 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
113 | 
114 | RUNTIME_ENV_JSON="{
115 |   \"env_vars\": {
116 |     \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\",
117 |     \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\"
118 |   }
119 | }"
120 | 
121 | ray job submit --address="http://127.0.0.1:8265" \
122 |    --runtime-env-json="${RUNTIME_ENV_JSON}" \
123 |    -- python3 train.py \
124 |    --actor-num-nodes 1 \
125 |    --actor-num-gpus-per-node 4 \
126 |    --rollout-num-gpus 4 \
127 |    --colocate \
128 |    ${MODEL_ARGS[@]} \
129 |    ${CKPT_ARGS[@]} \
130 |    ${ROLLOUT_ARGS[@]} \
131 |    ${OPTIMIZER_ARGS[@]} \
132 |    ${GRPO_ARGS[@]} \
133 |    ${DISTRIBUTED_ARGS[@]} \
134 |    ${WANDB_ARGS[@]} \
135 |    ${PERF_ARGS[@]} \
136 |    ${SGLANG_ARGS[@]} \
137 |    ${MISC_ARGS[@]} \
138 |    ${CUSTOM_ARGS[@]}
139 | 


--------------------------------------------------------------------------------
/imgs/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/imgs/arch.png


--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [build-system]
 2 | requires = [
 3 |     "packaging",
 4 |     "setuptools >= 49.4.0",
 5 |     "wheel",
 6 | ]
 7 | build-backend = "setuptools.build_meta"
 8 | 
 9 | [tool.isort]
10 | profile = "black"  # black-compatible
11 | line_length = 119  # should match black parameters
12 | ignore_whitespace = true  # ignore whitespace for compatibility with the initial style
13 | py_version = 310  # python 3.10 as a target version
14 | sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
15 | default_section = "THIRDPARTY"
16 | extend_skip = ["setup.py", "docs/source/conf.py"]
17 | 
18 | 
19 | [tool.black]
20 | line_length = 119
21 | 
22 | [tool.ruff]
23 | line-length = 119
24 | 
25 | [tool.pytest.ini_options]
26 | # durations=0 will display the execution time of all tests, sorted starting from the slowest one.
27 | # -vv will also display tests with duration = 0.00s
28 | addopts = "--verbose --pyargs --durations=0 --strict-markers"  # always add these arguments to pytest
29 | testpaths = ["./tests"]  # must be an explicit path to avoid importing another "tests" module
30 | # directories to ignore when discovering tests
31 | norecursedirs = [
32 |     "external",
33 |     "examples",
34 |     "docs",
35 |     "scripts",
36 |     "tools",
37 |     "tutorials",
38 |     "*.egg",
39 |     ".*",
40 |     "_darcs",
41 |     "build",
42 |     "CVS",
43 |     "dist",
44 |     "venv",
45 |     "{arch}",
46 | ]
47 | # markers to select tests, use `pytest --markers` to see all available markers, `pytest -m "<marker>"` to select tests
48 | markers = [
49 |     "unit: marks unit test, i.e. testing a single, well isolated functionality (deselect with '-m \"not unit\"')",
50 |     "integration: marks test checking the elements when integrated into subsystems (deselect with '-m \"not integration\"')",
51 |     "system: marks test working at the highest integration level (deselect with '-m \"not system\"')",
52 |     "acceptance: marks test checking whether the developed product/model passes the user defined acceptance criteria (deselect with '-m \"not acceptance\"')",
53 |     "docs: mark tests related to documentation (deselect with '-m \"not docs\"')",
54 |     "skipduringci: marks tests that are skipped ci as they are addressed by Jenkins jobs but should be run to test user setups",
55 |     "pleasefixme: marks tests that are broken and need fixing",
56 | ]
57 | 


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | httpx[http2]
3 | mcp[cli]
4 | pylatexenc
5 | ray[default]
6 | sglang
7 | torch
8 | wandb
9 | 


--------------------------------------------------------------------------------
/scripts/agent-example.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | # for rerun the task
  4 | pkill -9 sglang
  5 | sleep 3
  6 | ray stop --force
  7 | pkill -9 ray
  8 | pkill -9 python
  9 | sleep 3
 10 | pkill -9 ray
 11 | pkill -9 python
 12 | 
 13 | set -ex
 14 | 
 15 | export PYTHONUNBUFFERED=16
 16 | 
 17 | export TP_SIZE=2
 18 | export PP_SIZE=1
 19 | export CP_SIZE=1
 20 | 
 21 | export HF_MODEL_PATH=/root/hf_models/deepseek-ai--DeepSeek-R1-Distill-Qwen-7B
 22 | export MCORE_MODEL_PATH=/root/megatron_model/DeepSeek-R1-Distill-Qwen-7B-25.02
 23 | export PROMPT_DATA=/root/dapo-math-17k/dapo-math-17k_processed.jsonl
 24 | export MCORE_MODEL_PATH_SAVE=/root/megatron_model/DeepSeek-R1-Distill-Qwen-7B-25.02_save
 25 | 
 26 | # DeepSeek-R1-Distill-Qwen-7B
 27 | MODEL_ARGS=(
 28 |    --swiglu
 29 |    --num-layers 28
 30 |    --hidden-size 3584
 31 |    --ffn-hidden-size 18944
 32 |    --num-attention-heads 28
 33 |    --group-query-attention
 34 |    --num-query-groups 4
 35 |    --max-position-embeddings 131072
 36 |    --seq-length 4096
 37 |    --use-rotary-position-embeddings
 38 |    --disable-bias-linear
 39 |    --add-qkv-bias
 40 |    --normalization "RMSNorm"
 41 |    --norm-epsilon 1e-06
 42 |    --rotary-base 10000
 43 |    --vocab-size 152064
 44 |    --accumulate-allreduce-grads-in-fp32
 45 |    --attention-softmax-in-fp32
 46 |    --attention-backend flash
 47 |    --moe-token-dispatcher-type alltoall
 48 |    --untie-embeddings-and-output-weights
 49 |    --attention-dropout 0.0
 50 |    --hidden-dropout 0.0
 51 | )
 52 | 
 53 | CKPT_ARGS=(
 54 |    --hf-checkpoint ${HF_MODEL_PATH}
 55 |    --ref-load ${MCORE_MODEL_PATH}
 56 |    --save-interval 100
 57 |    --save ${MCORE_MODEL_PATH_SAVE}
 58 | )
 59 | 
 60 | ROLLOUT_ARGS=(
 61 |    --rollout-function-path slime.rollout.agent_rollout.generate_rollout
 62 |    --rm-type deepscaler
 63 |    --prompt-data ${PROMPT_DATA}
 64 |    --label-key label
 65 |    --num-rollout 3000
 66 |    --rollout-batch-size 128
 67 |    --rollout-max-response-len 8192
 68 |    --rollout-temperature 0.8
 69 |    --rollout-shuffle
 70 |    --n-samples-per-prompt 8
 71 |    --global-batch-size 1024
 72 |    --micro-batch-size 8
 73 |    --ref-micro-batch-size 8
 74 |    --use-dynamic-batch-size
 75 |    --max-tokens-per-gpu 9216
 76 |    --balance-data
 77 | )
 78 | 
 79 | DISTRIBUTED_ARGS=(
 80 |    --tensor-model-parallel-size ${TP_SIZE}
 81 |    --pipeline-model-parallel-size ${PP_SIZE}
 82 |    --context-parallel-size ${CP_SIZE}
 83 |    --sequence-parallel
 84 | )
 85 | 
 86 | PERF_ARGS=(
 87 |    --recompute-granularity full
 88 |    --recompute-method uniform
 89 |    --recompute-num-layers 1
 90 | )
 91 | 
 92 | GRPO_ARGS=(
 93 |    --advantage-estimator grpo
 94 |    --use-kl-loss
 95 |    --kl-loss-coef 0.001
 96 |    --kl-loss-type low_var_kl
 97 |    --kl-coef 0.00
 98 |    --entropy-coef 0.00
 99 | )
100 | 
101 | OPTIMIZER_ARGS=(
102 |    --lr 1e-6
103 |    --lr-decay-style constant
104 |    --weight-decay 0.1
105 |    --adam-beta1 0.9
106 |    --adam-beta2 0.98
107 | )
108 | 
109 | WANDB_ARGS=(
110 |    # --use-wandb \
111 | )
112 | 
113 | # launch the master node of ray in container
114 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
115 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
116 | 
117 | ray job submit --address="http://127.0.0.1:8265" \
118 |    --runtime-env-json='{
119 |      "env_vars": {
120 |         "PYTHONPATH": "/root/Megatron-LM/",
121 |         "CUDA_DEVICE_MAX_CONNECTIONS": "1",
122 |         "NCCL_CUMEM_ENABLE": "0"
123 |      }
124 |    }' \
125 |    -- python3 train_async.py \
126 |    --actor-num-nodes 1 \
127 |    --actor-num-gpus-per-node 4 \
128 |    --rollout-num-gpus 4 \
129 |    --rollout-num-gpus-per-engine 1 \
130 |    --sglang-mem-fraction-static 0.8 \
131 |    ${MODEL_ARGS[@]} \
132 |    ${CKPT_ARGS[@]} \
133 |    ${ROLLOUT_ARGS[@]} \
134 |    ${OPTIMIZER_ARGS[@]} \
135 |    ${GRPO_ARGS[@]} \
136 |    ${DISTRIBUTED_ARGS[@]} \
137 |    ${WANDB_ARGS[@]} \
138 |    ${PERF_ARGS[@]} \
139 |    --agent-rollout-buffer-url http://${MASTER_ADDR}:8889 \
140 |    --keep-old-actor \
141 |    --disable-rewards-normalization \
142 |    --offload-old-actor \
143 |    --offload-ref \
144 |    --loss-mask-type distill_qwen \
145 |    --sglang-log-level error \
146 |    --input-key prompt \
147 |    --log-passrate
148 | 


--------------------------------------------------------------------------------
/scripts/models/glm4-32B.sh:
--------------------------------------------------------------------------------
 1 | MODEL_ARGS=(
 2 |    --spec "slime_plugins.models.glm4" "get_glm_spec"
 3 |    --swiglu
 4 |    --num-layers 64
 5 |    --hidden-size 6144
 6 |    --ffn-hidden-size 23040
 7 |    --num-attention-heads 48
 8 |    --max-position-embeddings 32768
 9 |    --seq-length 32768
10 |    --use-rotary-position-embeddings
11 |    --disable-bias-linear
12 |    --normalization "RMSNorm"
13 |    --norm-epsilon 1e-5
14 |    --rotary-base 10000
15 |    --group-query-attention
16 |    --num-query-groups 8
17 |    --vocab-size 151552
18 |    --post-self-attn-layernorm
19 |    --post-mlp-layernorm
20 |    --rotary-interleaved
21 |    --rotary-percent 0.5
22 |    --no-rope-fusion
23 |    --untie-embeddings-and-output-weights
24 | )


--------------------------------------------------------------------------------
/scripts/models/glm4-9B.sh:
--------------------------------------------------------------------------------
 1 | MODEL_ARGS=(
 2 |    --spec "slime_plugins.models.glm4" "get_glm_spec"
 3 |    --swiglu
 4 |    --num-layers 40
 5 |    --hidden-size 4096
 6 |    --ffn-hidden-size 13696
 7 |    --num-attention-heads 32
 8 |    --group-query-attention
 9 |    --num-query-groups 2
10 |    --use-rotary-position-embeddings
11 |    --disable-bias-linear
12 |    --add-qkv-bias
13 |    --normalization "RMSNorm"
14 |    --norm-epsilon 1e-5
15 |    --rotary-base 10000
16 |    --vocab-size 151552
17 |    --post-self-attn-layernorm
18 |    --post-mlp-layernorm
19 |    --rotary-interleaved
20 |    --rotary-percent 0.5
21 |    --no-rope-fusion
22 |    --untie-embeddings-and-output-weights
23 | )


--------------------------------------------------------------------------------
/scripts/models/moonlight.sh:
--------------------------------------------------------------------------------
 1 | MOE_SHARED_EXPERTS=2
 2 | MOE_FFN_HIDDEN=1408
 3 | MOE_SHARED_EXPERT_INTERMEDIATE_SIZE=$(($MOE_FFN_HIDDEN * $MOE_SHARED_EXPERTS))
 4 | MOE_ROUTER_TOPK_SCALING_FACTOR=2.446
 5 | NLAYERS=27
 6 | FIRST_K_DENSE_REPLACE=1
 7 | 
 8 | arr=()
 9 | for ((i=0; i<NLAYERS; i++)); do
10 |   if (( i < FIRST_K_DENSE_REPLACE )); then
11 |     arr+=(0)
12 |   else
13 |     arr+=(1)
14 |   fi
15 | done
16 | 
17 | printf -v MOE_LAYER_FREQ "[%s]" "$(IFS=', '; echo "${arr[*]}")"
18 | 
19 | # moonlight
20 | MODEL_ARGS=(
21 |     --disable-bias-linear
22 |     --num-layers 27
23 |     --hidden-size 2048
24 |     --ffn-hidden-size 11264
25 |     --num-attention-heads 16
26 |     --kv-channels 128
27 |     --normalization RMSNorm
28 |     --position-embedding-type rope
29 |     --norm-epsilon 1e-5
30 |     --rotary-percent 1.0
31 |     --swiglu
32 |     --untie-embeddings-and-output-weights
33 |     --no-masked-softmax-fusion
34 |     --vocab-size 163840
35 | 
36 |     --multi-latent-attention
37 |     --kv-lora-rank 512
38 |     --qk-head-dim 128
39 |     --qk-pos-emb-head-dim 64
40 |     --v-head-dim 128
41 |     --qk-layernorm
42 |     --rotary-scaling-factor 1
43 |     --rotary-base 50000
44 |     --mscale 1.0
45 |     --mscale-all-dim 1.0
46 |     --attention-softmax-in-fp32
47 |     --no-rope-fusion
48 | 
49 |     # moe
50 |     --num-experts 64
51 |     --moe-layer-freq $MOE_LAYER_FREQ
52 |     --moe-ffn-hidden-size $MOE_FFN_HIDDEN
53 |     --moe-router-topk 6
54 |     --moe-shared-expert-intermediate-size $MOE_SHARED_EXPERT_INTERMEDIATE_SIZE
55 |     --moe-router-pre-softmax
56 |     --moe-router-score-function sigmoid
57 |     --moe-router-enable-expert-bias
58 |     --moe-router-load-balancing-type seq_aux_loss
59 |     --moe-token-dispatcher-type alltoall
60 |     --moe-aux-loss-coeff 0
61 |     --moe-router-bias-update-rate 0
62 |     --moe-router-group-topk 1
63 |     --moe-router-num-groups 1
64 |     --moe-grouped-gemm
65 |     --moe-router-topk-scaling-factor $MOE_ROUTER_TOPK_SCALING_FACTOR
66 |     --moe-token-drop-policy probs
67 |     --moe-router-dtype fp32
68 |     --moe-permute-fusion
69 | )


--------------------------------------------------------------------------------
/scripts/models/qwen2.5-1.5B.sh:
--------------------------------------------------------------------------------
 1 | MODEL_ARGS=(
 2 |    --swiglu
 3 |    --num-layers 28
 4 |    --hidden-size 1536
 5 |    --ffn-hidden-size 8960
 6 |    --num-attention-heads 12
 7 |    --use-rotary-position-embeddings
 8 |    --disable-bias-linear
 9 |    --add-qkv-bias
10 |    --normalization "RMSNorm"
11 |    --norm-epsilon 1e-6
12 |    --rotary-base 10000
13 |    --group-query-attention
14 |    --num-query-groups 2
15 |    --vocab-size 151936
16 |    --untie-embeddings-and-output-weights
17 | )


--------------------------------------------------------------------------------
/scripts/models/qwen2.5-32B.sh:
--------------------------------------------------------------------------------
 1 | MODEL_ARGS=(
 2 |    --swiglu
 3 |    --num-layers 64
 4 |    --hidden-size 5120
 5 |    --ffn-hidden-size 27648
 6 |    --num-attention-heads 40
 7 |    --group-query-attention
 8 |    --num-query-groups 8
 9 |    --use-rotary-position-embeddings
10 |    --disable-bias-linear
11 |    --add-qkv-bias
12 |    --normalization "RMSNorm"
13 |    --norm-epsilon 1e-5
14 |    --rotary-base 1000000
15 |    --vocab-size 152064
16 |    --untie-embeddings-and-output-weights
17 | )


--------------------------------------------------------------------------------
/scripts/models/qwen2.5-3B.sh:
--------------------------------------------------------------------------------
 1 | MODEL_ARGS=(
 2 |    --swiglu
 3 |    --num-layers 36
 4 |    --hidden-size 2048
 5 |    --ffn-hidden-size 11008
 6 |    --num-attention-heads 16
 7 |    --use-rotary-position-embeddings
 8 |    --disable-bias-linear
 9 |    --add-qkv-bias
10 |    --normalization "RMSNorm"
11 |    --norm-epsilon 1e-6
12 |    --rotary-base 10000
13 |    --group-query-attention
14 |    --num-query-groups 2
15 |    --vocab-size 151936
16 | )


--------------------------------------------------------------------------------
/scripts/models/qwen2.5-7B.sh:
--------------------------------------------------------------------------------
 1 | MODEL_ARGS=(
 2 |    --swiglu
 3 |    --num-layers 28
 4 |    --hidden-size 3584
 5 |    --ffn-hidden-size 18944
 6 |    --num-attention-heads 28
 7 |    --group-query-attention
 8 |    --num-query-groups 4
 9 |    --use-rotary-position-embeddings
10 |    --disable-bias-linear
11 |    --add-qkv-bias
12 |    --normalization "RMSNorm"
13 |    --norm-epsilon 1e-06
14 |    --rotary-base 1000000
15 |    --vocab-size 152064
16 |    --untie-embeddings-and-output-weights
17 | )


--------------------------------------------------------------------------------
/scripts/models/qwen3-0.6B.sh:
--------------------------------------------------------------------------------
 1 | MODEL_ARGS=(
 2 |    --swiglu
 3 |    --num-layers 28
 4 |    --hidden-size 1024
 5 |    --ffn-hidden-size 3072
 6 |    --num-attention-heads 16
 7 |    --group-query-attention
 8 |    --num-query-groups 8
 9 |    --use-rotary-position-embeddings
10 |    --disable-bias-linear
11 |    --normalization "RMSNorm"
12 |    --norm-epsilon 1e-6
13 |    --rotary-base 1000000
14 |    --vocab-size 151936
15 |    --kv-channels 128
16 |    --qk-layernorm
17 | )


--------------------------------------------------------------------------------
/scripts/models/qwen3-235B-A22B.sh:
--------------------------------------------------------------------------------
 1 | # qwen3-235B-a22B
 2 | MODEL_ARGS=(
 3 |    --disable-bias-linear
 4 |    --qk-layernorm
 5 |    --group-query-attention
 6 |    --num-attention-heads 64
 7 |    --num-query-groups 4
 8 |    --kv-channels 128
 9 |    --num-layers 94
10 |    --hidden-size 4096
11 |    --ffn-hidden-size 12288
12 | 
13 |    --normalization RMSNorm
14 |    --position-embedding-type rope
15 |    --norm-epsilon 1e-6
16 |    --rotary-percent 1.0
17 |    --swiglu
18 |    --untie-embeddings-and-output-weights
19 |    --vocab-size 151936
20 | 
21 |    --rotary-base 1000000
22 | 
23 |    # moe
24 |    --moe-ffn-hidden-size 1536
25 |    --moe-router-score-function softmax
26 |    --moe-token-dispatcher-type alltoall
27 |    --moe-router-topk 8
28 |    --moe-layer-freq "'([1]*94)'"
29 |    --num-experts 128
30 |    --moe-grouped-gemm
31 |    --moe-token-drop-policy probs
32 |    --moe-router-dtype fp32
33 |    --moe-permute-fusion
34 |    --moe-aux-loss-coeff 0
35 | )


--------------------------------------------------------------------------------
/scripts/models/qwen3-30B-A3B.sh:
--------------------------------------------------------------------------------
 1 | MODEL_ARGS=(
 2 |    --disable-bias-linear
 3 |    --qk-layernorm
 4 |    --group-query-attention
 5 |    --num-attention-heads 32
 6 |    --num-query-groups 4
 7 |    --kv-channels 128
 8 |    --num-layers 48
 9 |    --hidden-size 2048
10 |    --ffn-hidden-size 6144
11 | 
12 |    --normalization RMSNorm
13 |    --position-embedding-type rope
14 |    --norm-epsilon 1e-6
15 |    --rotary-percent 1.0
16 |    --swiglu
17 |    --untie-embeddings-and-output-weights
18 |    --vocab-size 151936
19 | 
20 |    --rotary-base 1000000
21 | 
22 |    # moe
23 |    --moe-ffn-hidden-size 768
24 |    --moe-router-score-function softmax
25 |    --moe-token-dispatcher-type alltoall
26 |    --moe-router-topk 8
27 |    --moe-layer-freq "'([1]*48)'"
28 |    --num-experts 128
29 |    --moe-grouped-gemm
30 |    --moe-token-drop-policy probs
31 |    --moe-router-dtype fp32
32 |    --moe-permute-fusion
33 |    --moe-aux-loss-coeff 0
34 | )


--------------------------------------------------------------------------------
/scripts/models/qwen3-4B.sh:
--------------------------------------------------------------------------------
 1 | MODEL_ARGS=(
 2 |    --swiglu
 3 |    --num-layers 36
 4 |    --hidden-size 2560
 5 |    --ffn-hidden-size 9728
 6 |    --num-attention-heads 32
 7 |    --group-query-attention
 8 |    --num-query-groups 8
 9 |    --use-rotary-position-embeddings
10 |    --disable-bias-linear
11 |    --normalization "RMSNorm"
12 |    --norm-epsilon 1e-6
13 |    --rotary-base 1000000
14 |    --vocab-size 151936
15 |    --kv-channels 128
16 |    --qk-layernorm
17 | )


--------------------------------------------------------------------------------
/scripts/models/qwen3-8B.sh:
--------------------------------------------------------------------------------
 1 | MODEL_ARGS=(
 2 |    --swiglu
 3 |    --num-layers 36
 4 |    --hidden-size 4096
 5 |    --ffn-hidden-size 12288
 6 |    --num-attention-heads 32
 7 |    --group-query-attention
 8 |    --num-query-groups 8
 9 |    --use-rotary-position-embeddings
10 |    --disable-bias-linear
11 |    --normalization "RMSNorm"
12 |    --norm-epsilon 1e-6
13 |    --rotary-base 1000000
14 |    --vocab-size 151936
15 |    --kv-channels 128
16 |    --qk-layernorm
17 |    --untie-embeddings-and-output-weights
18 | )


--------------------------------------------------------------------------------
/scripts/run-glm4-9B.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | # for rerun the task
  4 | pkill -9 sglang
  5 | sleep 3
  6 | ray stop --force
  7 | pkill -9 ray
  8 | pkill -9 python
  9 | sleep 3
 10 | pkill -9 ray
 11 | pkill -9 python
 12 | 
 13 | set -ex
 14 | 
 15 | # will prevent ray from buffering stdout/stderr
 16 | export PYTHONUNBUFFERED=16
 17 | 
 18 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
 19 | if [ "$NVLINK_COUNT" -gt 0 ]; then
 20 |     HAS_NVLINK=1
 21 | else
 22 |     HAS_NVLINK=0
 23 | fi
 24 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
 25 | 
 26 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 27 | source "${SCRIPT_DIR}/models/glm4-9B.sh"
 28 | 
 29 | CKPT_ARGS=(
 30 |    --hf-checkpoint /root/GLM-Z1-9B-0414/
 31 |    --ref-load /root/GLM-Z1-9B-0414_torch_dist
 32 |    --load /root/GLM-Z1-9B-0414_slime/
 33 |    --save /root/GLM-Z1-9B-0414_slime/
 34 |    --save-interval 20
 35 | )
 36 | 
 37 | ROLLOUT_ARGS=(
 38 |    --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
 39 |    --input-key prompt
 40 |    --label-key label
 41 |    --apply-chat-template
 42 |    --rollout-shuffle
 43 | 
 44 |    --rm-type deepscaler
 45 | 
 46 |    --num-rollout 3000
 47 |    --rollout-batch-size 32
 48 |    --n-samples-per-prompt 8
 49 |    --rollout-max-response-len 8192
 50 |    --rollout-temperature 0.8
 51 | 
 52 |    --global-batch-size 256
 53 |    --balance-data
 54 | )
 55 | 
 56 | EVAL_ARGS=(
 57 |    --eval-interval 20
 58 |    --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
 59 |    --n-samples-per-eval-prompt 16
 60 |    --eval-max-response-len 16384
 61 |    --eval-top-p 0.7
 62 | )
 63 | 
 64 | PERF_ARGS=(
 65 |    --tensor-model-parallel-size 2
 66 |    --sequence-parallel
 67 |    --pipeline-model-parallel-size 1
 68 |    --context-parallel-size 2
 69 |    --expert-model-parallel-size 1
 70 |    --expert-tensor-parallel-size 1
 71 | 
 72 |    --recompute-granularity full
 73 |    --recompute-method uniform
 74 |    --recompute-num-layers 1
 75 | 
 76 |    # --micro-batch-size 1
 77 |    --use-dynamic-batch-size
 78 |    --max-tokens-per-gpu 4608
 79 | )
 80 | 
 81 | GRPO_ARGS=(
 82 |    --advantage-estimator grpo
 83 |    --use-kl-loss
 84 |    --kl-loss-coef 0.00
 85 |    --kl-loss-type low_var_kl
 86 |    --kl-coef 0.00
 87 |    --entropy-coef 0.00
 88 |    --eps-clip 0.2
 89 |    --eps-clip-high 0.28
 90 | )
 91 | 
 92 | OPTIMIZER_ARGS=(
 93 |    --optimizer adam
 94 |    --lr 1e-6
 95 |    --lr-decay-style constant
 96 |    --weight-decay 0.1
 97 |    --adam-beta1 0.9
 98 |    --adam-beta2 0.98
 99 | )
100 | 
101 | WANDB_ARGS=(
102 |    #--use-wandb
103 |    # --wandb-project slime-dev
104 |    # --wandb-group qwen3-4B-test
105 |    # --wandb-key ${WANDB_KEY}
106 | )
107 | 
108 | SGLANG_ARGS=(
109 |    --rollout-num-gpus-per-engine 2
110 | )
111 | 
112 | MISC_ARGS=(
113 |    # default dropout in megatron is 0.1
114 |    --attention-dropout 0.0
115 |    --hidden-dropout 0.0
116 |    # should be good for model performance
117 |    --accumulate-allreduce-grads-in-fp32
118 |    --attention-softmax-in-fp32
119 |    # need to comment this when using model with MLA
120 |    --attention-backend flash
121 | )
122 | 
123 | # launch the master node of ray in container
124 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
125 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
126 | 
127 | # Build the runtime environment JSON with proper variable substitution
128 | RUNTIME_ENV_JSON="{
129 |   \"env_vars\": {
130 |     \"PYTHONPATH\": \"/root/Megatron-LM/\",
131 |     \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
132 |     \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
133 |   }
134 | }"
135 | 
136 | ray job submit --address="http://127.0.0.1:8265" \
137 |    --runtime-env-json="${RUNTIME_ENV_JSON}" \
138 |    -- python3 train.py \
139 |    --actor-num-nodes 1 \
140 |    --actor-num-gpus-per-node 4 \
141 |    --rollout-num-gpus 4 \
142 |    ${MODEL_ARGS[@]} \
143 |    ${CKPT_ARGS[@]} \
144 |    ${ROLLOUT_ARGS[@]} \
145 |    ${OPTIMIZER_ARGS[@]} \
146 |    ${GRPO_ARGS[@]} \
147 |    ${DISTRIBUTED_ARGS[@]} \
148 |    ${WANDB_ARGS[@]} \
149 |    ${PERF_ARGS[@]} \
150 |    ${EVAL_ARGS[@]} \
151 |    ${SGLANG_ARGS[@]} \
152 |    ${MISC_ARGS[@]}
153 | 


--------------------------------------------------------------------------------
/scripts/run-qwen3-235B-A22B-sft.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | # for rerun the task
  4 | pkill -9 sglang
  5 | sleep 3
  6 | ray stop --force
  7 | pkill -9 ray
  8 | pkill -9 python
  9 | sleep 3
 10 | pkill -9 ray
 11 | pkill -9 python
 12 | 
 13 | set -ex
 14 | 
 15 | # if base folder not set raise error
 16 | if [ -z "${BASE_FOLDER}" ]; then
 17 |   echo "BASE_FOLDER is not set. Please set it to the base directory of your checkpoints."
 18 |   exit 1
 19 | fi
 20 | 
 21 | if [ -z "${MASTER_ADDR}" ]; then
 22 |   echo "MASTER_ADDR is not set. Please set it to the master node address."
 23 |   exit 1
 24 | fi
 25 | 
 26 | # will prevent ray from buffering stdout/stderr
 27 | export PYTHONUNBUFFERED=16
 28 | 
 29 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
 30 | if [ "$NVLINK_COUNT" -gt 0 ]; then
 31 |     HAS_NVLINK=1
 32 | else
 33 |     HAS_NVLINK=0
 34 | fi
 35 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
 36 | 
 37 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 38 | source "${SCRIPT_DIR}/models/qwen3-235B-A22B.sh"
 39 | 
 40 | CKPT_ARGS=(
 41 |    --hf-checkpoint ${BASE_FOLDER}/Qwen3-235B-A22B
 42 |    --ref-load ${BASE_FOLDER}/Qwen3-235B-A22B_torch_dist
 43 |    --load ${BASE_FOLDER}/Qwen3-235B-A22B_slime/
 44 |    --save ${BASE_FOLDER}/Qwen3-235B-A22B_slime/
 45 |    --save-interval 1000
 46 | )
 47 | 
 48 | SFT_ARGS=(
 49 |    --rollout-function-path slime.rollout.sft_example.generate_rollout
 50 |    --prompt-data ${BASE_FOLDER}/openhermes2_5.parquet
 51 |    --input-key messages
 52 |    --rollout-shuffle
 53 |    --num-epoch 3
 54 |    --rollout-batch-size 128
 55 |    --global-batch-size 128
 56 | 
 57 |    --loss-type sft_loss
 58 |    --calculate-per-token-loss
 59 |    --disable-compute-advantages-and-returns
 60 |    --debug-train-only
 61 | )
 62 | 
 63 | PERF_ARGS=(
 64 |    --tensor-model-parallel-size 4
 65 |    --sequence-parallel
 66 |    --pipeline-model-parallel-size 1
 67 |    --context-parallel-size 1
 68 |    --expert-model-parallel-size 32
 69 |    --expert-tensor-parallel-size 1
 70 | 
 71 |    --recompute-granularity full
 72 |    --recompute-method uniform
 73 |    --recompute-num-layers 1
 74 | 
 75 |    # --micro-batch-size 1
 76 |    --use-dynamic-batch-size
 77 |    --max-tokens-per-gpu 9216
 78 | )
 79 | 
 80 | OPTIMIZER_ARGS=(
 81 |    --optimizer adam
 82 |    --lr 1e-5
 83 |    --lr-warmup-iters 128
 84 |    --lr-decay-style cosine
 85 |    --min-lr 1e-6
 86 |    --lr-warmup-fraction 0.9
 87 |    --weight-decay 0.1
 88 |    --adam-beta1 0.9
 89 |    --adam-beta2 0.98
 90 | 
 91 |    --optimizer-cpu-offload
 92 |    --overlap-cpu-optimizer-d2h-h2d
 93 |    --use-precision-aware-optimizer
 94 | )
 95 | 
 96 | WANDB_ARGS=(
 97 |    # --use-wandb
 98 |    # --wandb-project slime-dev
 99 |    # --wandb-group qwen3-235B-sft
100 | )
101 | 
102 | MISC_ARGS=(
103 |    # default dropout in megatron is 0.1
104 |    --attention-dropout 0.0
105 |    --hidden-dropout 0.0
106 |    # should be good for model performance
107 |    --accumulate-allreduce-grads-in-fp32
108 |    --attention-softmax-in-fp32
109 |    # need to comment this when using model with MLA
110 |    --attention-backend flash
111 | )
112 | 
113 | # launch the master node of ray in container
114 | export no_proxy="127.0.0.1,${MASTER_ADDR}"
115 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
116 | for WORKER_IP in $(awk '{print $1}' /root/mpi_rack_hostfile); do
117 |   if [[ "$WORKER_IP" == "$MLP_WORKER_0_HOST" ]]; then
118 |     continue
119 |   fi
120 |   echo "Starting Ray worker on ${WORKER_IP}"
121 |   ssh root@"${WORKER_IP}" \
122 |     "pkill -9 sglang ; ray stop --force ; pkill -9 python ; ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 --node-ip-address ${WORKER_IP} --disable-usage-stats" &
123 | done
124 | wait
125 | 
126 | 
127 | # Build the runtime environment JSON with proper variable substitution
128 | RUNTIME_ENV_JSON="{
129 |   \"env_vars\": {
130 |     \"PYTHONPATH\": \"/root/Megatron-LM/\",
131 |     \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
132 |     \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
133 |     \"no_proxy\": \"${no_proxy}\",
134 |     \"MASTER_ADDR\": \"${MASTER_ADDR}\"
135 |   }
136 | }"
137 | 
138 | ray job submit --address="http://127.0.0.1:8265" \
139 |    --runtime-env-json="${RUNTIME_ENV_JSON}" \
140 |    -- python3 train_async.py \
141 |    --actor-num-nodes 4 \
142 |    --actor-num-gpus-per-node 8 \
143 |    ${MODEL_ARGS[@]} \
144 |    ${CKPT_ARGS[@]} \
145 |    ${SFT_ARGS[@]} \
146 |    ${OPTIMIZER_ARGS[@]} \
147 |    ${DISTRIBUTED_ARGS[@]} \
148 |    ${WANDB_ARGS[@]} \
149 |    ${PERF_ARGS[@]} \
150 |    ${EVAL_ARGS[@]} \
151 |    ${MISC_ARGS[@]}
152 | 


--------------------------------------------------------------------------------
/scripts/run-qwen3-235B-A22B.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | # for rerun the task
  4 | pkill -9 sglang
  5 | sleep 3
  6 | ray stop --force
  7 | pkill -9 ray
  8 | pkill -9 python
  9 | sleep 3
 10 | pkill -9 ray
 11 | pkill -9 python
 12 | 
 13 | set -ex
 14 | 
 15 | # if base folder not set raise error
 16 | if [ -z "${BASE_FOLDER}" ]; then
 17 |   echo "BASE_FOLDER is not set. Please set it to the base directory of your checkpoints."
 18 |   exit 1
 19 | fi
 20 | 
 21 | if [ -z "${MASTER_ADDR}" ]; then
 22 |   echo "MASTER_ADDR is not set. Please set it to the master node address."
 23 |   exit 1
 24 | fi
 25 | 
 26 | # will prevent ray from buffering stdout/stderr
 27 | export PYTHONUNBUFFERED=16
 28 | 
 29 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
 30 | if [ "$NVLINK_COUNT" -gt 0 ]; then
 31 |     HAS_NVLINK=1
 32 | else
 33 |     HAS_NVLINK=0
 34 | fi
 35 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
 36 | 
 37 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 38 | source "${SCRIPT_DIR}/models/qwen3-235B-A22B.sh"
 39 | 
 40 | CKPT_ARGS=(
 41 |    --hf-checkpoint ${BASE_FOLDER}/Qwen3-235B-A22B
 42 |    #--hf-checkpoint ${BASE_FOLDER}/Qwen3-235B-A22B-FP8
 43 |    --ref-load ${BASE_FOLDER}/Qwen3-235B-A22B_torch_dist
 44 |    --load ${BASE_FOLDER}/Qwen3-235B-A22B_slime/
 45 |    --save ${BASE_FOLDER}/Qwen3-235B-A22B_slime/
 46 |    --save-interval 20
 47 | )
 48 | 
 49 | ROLLOUT_ARGS=(
 50 |    --prompt-data ${BASE_FOLDER}/dapo-math-17k/dapo-math-17k.jsonl
 51 |    --input-key prompt
 52 |    --label-key label
 53 |    --apply-chat-template
 54 |    --rollout-shuffle
 55 | 
 56 |    --rm-type deepscaler
 57 | 
 58 |    --num-rollout 3000
 59 |    --rollout-batch-size 8
 60 |    --n-samples-per-prompt 8
 61 |    --rollout-max-response-len 8192
 62 |    --rollout-temperature 0.8
 63 | 
 64 |    --global-batch-size 64
 65 |    --balance-data
 66 | )
 67 | 
 68 | EVAL_ARGS=(
 69 |    #--eval-interval 20
 70 |    --eval-prompt-data aime ${BASE_FOLDER}/aime-2024/aime-2024.jsonl
 71 |    --n-samples-per-eval-prompt 16
 72 |    --eval-max-response-len 16384
 73 |    --eval-top-p 0.7
 74 | )
 75 | 
 76 | PERF_ARGS=(
 77 |    --tensor-model-parallel-size 4
 78 |    --sequence-parallel
 79 |    --pipeline-model-parallel-size 4
 80 |    --context-parallel-size 2
 81 |    --expert-model-parallel-size 8
 82 |    --expert-tensor-parallel-size 1
 83 |    --decoder-last-pipeline-num-layers 22
 84 | 
 85 |    --recompute-granularity full
 86 |    --recompute-method uniform
 87 |    --recompute-num-layers 1
 88 | 
 89 |    # --micro-batch-size 1
 90 |    --use-dynamic-batch-size
 91 |    --max-tokens-per-gpu 4096
 92 | )
 93 | 
 94 | GRPO_ARGS=(
 95 |    --advantage-estimator grpo
 96 |    #--use-kl-loss
 97 |    --kl-loss-coef 0.00
 98 |    --kl-loss-type low_var_kl
 99 |    --kl-coef 0.00
100 |    --entropy-coef 0.00
101 |    --eps-clip 0.2
102 |    --eps-clip-high 0.28
103 | )
104 | 
105 | OPTIMIZER_ARGS=(
106 |    --optimizer adam
107 |    --lr 1e-6
108 |    --lr-decay-style constant
109 |    --weight-decay 0.1
110 |    --adam-beta1 0.9
111 |    --adam-beta2 0.98
112 | 
113 |    --optimizer-cpu-offload
114 |    --overlap-cpu-optimizer-d2h-h2d
115 |    --use-precision-aware-optimizer
116 | )
117 | 
118 | WANDB_ARGS=(
119 |    # --use-wandb
120 |    # --wandb-project slime-dev
121 |    # --wandb-group qwen3-235B-sft
122 | )
123 | 
124 | SGLANG_ARGS=(
125 |    --rollout-num-gpus-per-engine 32
126 |    --sglang-mem-fraction-static 0.5
127 |    --sglang-enable-ep-moe
128 |    --sglang-enable-dp-attention
129 |    --sglang-dp-size 4
130 |    --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256)
131 | )
132 | 
133 | MISC_ARGS=(
134 |    # default dropout in megatron is 0.1
135 |    --attention-dropout 0.0
136 |    --hidden-dropout 0.0
137 |    # should be good for model performance
138 |    --accumulate-allreduce-grads-in-fp32
139 |    --attention-softmax-in-fp32
140 |    # need to comment this when using model with MLA
141 |    --attention-backend flash
142 | )
143 | 
144 | # launch the master node of ray in container
145 | export no_proxy="127.0.0.1,${MASTER_ADDR}"
146 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
147 | for WORKER_IP in $(awk '{print $1}' /root/mpi_rack_hostfile); do
148 |   if [[ "$WORKER_IP" == "$MLP_WORKER_0_HOST" ]]; then
149 |     continue
150 |   fi
151 |   echo "Starting Ray worker on ${WORKER_IP}"
152 |   ssh root@"${WORKER_IP}" \
153 |     "pkill -9 sglang ; ray stop --force ; pkill -9 python ; ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 --node-ip-address ${WORKER_IP} --disable-usage-stats" &
154 | done
155 | wait
156 | 
157 | 
158 | # Build the runtime environment JSON with proper variable substitution
159 | RUNTIME_ENV_JSON="{
160 |   \"env_vars\": {
161 |     \"PYTHONPATH\": \"/root/Megatron-LM/\",
162 |     \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
163 |     \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
164 |     \"no_proxy\": \"${no_proxy}\",
165 |     \"MASTER_ADDR\": \"${MASTER_ADDR}\"
166 |   }
167 | }"
168 | 
169 | ray job submit --address="http://127.0.0.1:8265" \
170 |    --runtime-env-json="${RUNTIME_ENV_JSON}" \
171 |    -- python3 train.py \
172 |    --actor-num-nodes 4 \
173 |    --actor-num-gpus-per-node 8 \
174 |    --colocate \
175 |    ${MODEL_ARGS[@]} \
176 |    ${CKPT_ARGS[@]} \
177 |    ${ROLLOUT_ARGS[@]} \
178 |    ${OPTIMIZER_ARGS[@]} \
179 |    ${GRPO_ARGS[@]} \
180 |    ${DISTRIBUTED_ARGS[@]} \
181 |    ${WANDB_ARGS[@]} \
182 |    ${PERF_ARGS[@]} \
183 |    ${EVAL_ARGS[@]} \
184 |    ${SGLANG_ARGS[@]} \
185 |    ${MISC_ARGS[@]}
186 | 


--------------------------------------------------------------------------------
/scripts/run-qwen3-30B-A3B.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | # for rerun the task
  4 | pkill -9 sglang
  5 | sleep 3
  6 | ray stop --force
  7 | pkill -9 ray
  8 | pkill -9 python
  9 | sleep 3
 10 | pkill -9 ray
 11 | pkill -9 python
 12 | 
 13 | set -ex
 14 | 
 15 | # will prevent ray from buffering stdout/stderr
 16 | export PYTHONUNBUFFERED=16
 17 | 
 18 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
 19 | if [ "$NVLINK_COUNT" -gt 0 ]; then
 20 |     HAS_NVLINK=1
 21 | else
 22 |     HAS_NVLINK=0
 23 | fi
 24 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
 25 | 
 26 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 27 | source "${SCRIPT_DIR}/models/qwen3-30B-A3B.sh"
 28 | 
 29 | CKPT_ARGS=(
 30 |    --hf-checkpoint /root/Qwen3-30B-A3B
 31 |    #--hf-checkpoint /root/Qwen3-30B-A3B-FP8
 32 |    --ref-load /root/Qwen3-30B-A3B_torch_dist
 33 |    --load /root/Qwen3-30B-A3B_slime/
 34 |    --save /root/Qwen3-30B-A3B_slime/
 35 |    --save-interval 20
 36 | )
 37 | 
 38 | ROLLOUT_ARGS=(
 39 |    --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
 40 |    --input-key prompt
 41 |    --label-key label
 42 |    --apply-chat-template
 43 |    --rollout-shuffle
 44 | 
 45 |    --rm-type deepscaler
 46 | 
 47 |    --num-rollout 3000
 48 |    --rollout-batch-size 32
 49 |    --n-samples-per-prompt 8
 50 |    --rollout-max-response-len 8192
 51 |    --rollout-temperature 0.8
 52 | 
 53 |    --global-batch-size 256
 54 |    --balance-data
 55 | )
 56 | 
 57 | EVAL_ARGS=(
 58 |    --eval-interval 20
 59 |    --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
 60 |    --n-samples-per-eval-prompt 16
 61 |    --eval-max-response-len 16384
 62 |    --eval-top-p 0.7
 63 | )
 64 | 
 65 | PERF_ARGS=(
 66 |    --tensor-model-parallel-size 4
 67 |    --sequence-parallel
 68 |    --pipeline-model-parallel-size 1
 69 |    --context-parallel-size 1
 70 |    --expert-model-parallel-size 8
 71 |    --expert-tensor-parallel-size 1
 72 | 
 73 |    --recompute-granularity full
 74 |    --recompute-method uniform
 75 |    --recompute-num-layers 1
 76 | 
 77 |    # --micro-batch-size 1
 78 |    --use-dynamic-batch-size
 79 |    --max-tokens-per-gpu 20480
 80 | )
 81 | 
 82 | GRPO_ARGS=(
 83 |    --advantage-estimator grpo
 84 |    --use-kl-loss
 85 |    --kl-loss-coef 0.00
 86 |    --kl-loss-type low_var_kl
 87 |    --kl-coef 0.00
 88 |    --entropy-coef 0.00
 89 |    --eps-clip 0.2
 90 |    --eps-clip-high 0.28
 91 | )
 92 | 
 93 | OPTIMIZER_ARGS=(
 94 |    --optimizer adam
 95 |    --lr 1e-6
 96 |    --lr-decay-style constant
 97 |    --weight-decay 0.1
 98 |    --adam-beta1 0.9
 99 |    --adam-beta2 0.98
100 | 
101 |    --optimizer-cpu-offload
102 |    --overlap-cpu-optimizer-d2h-h2d
103 |    --use-precision-aware-optimizer
104 | )
105 | 
106 | WANDB_ARGS=(
107 |    #--use-wandb
108 |    # --wandb-project slime-dev
109 |    # --wandb-group qwen3-4B-test
110 |    # --wandb-key ${WANDB_KEY}
111 | )
112 | 
113 | SGLANG_ARGS=(
114 |    --rollout-num-gpus-per-engine 8
115 |    --sglang-mem-fraction-static 0.5
116 |    --sglang-enable-ep-moe
117 |    --sglang-enable-dp-attention
118 |    --sglang-dp-size 8
119 |    --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256)
120 | )
121 | 
122 | MISC_ARGS=(
123 |    # default dropout in megatron is 0.1
124 |    --attention-dropout 0.0
125 |    --hidden-dropout 0.0
126 |    # should be good for model performance
127 |    --accumulate-allreduce-grads-in-fp32
128 |    --attention-softmax-in-fp32
129 |    # need to comment this when using model with MLA
130 |    --attention-backend flash
131 | )
132 | 
133 | # launch the master node of ray in container
134 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
135 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
136 | 
137 | # Build the runtime environment JSON with proper variable substitution
138 | RUNTIME_ENV_JSON="{
139 |   \"env_vars\": {
140 |     \"PYTHONPATH\": \"/root/Megatron-LM/\",
141 |     \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
142 |     \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
143 |   }
144 | }"
145 | 
146 | ray job submit --address="http://127.0.0.1:8265" \
147 |    --runtime-env-json="${RUNTIME_ENV_JSON}" \
148 |    -- python3 train.py \
149 |    --actor-num-nodes 1 \
150 |    --actor-num-gpus-per-node 8 \
151 |    --colocate \
152 |    ${MODEL_ARGS[@]} \
153 |    ${CKPT_ARGS[@]} \
154 |    ${ROLLOUT_ARGS[@]} \
155 |    ${OPTIMIZER_ARGS[@]} \
156 |    ${GRPO_ARGS[@]} \
157 |    ${DISTRIBUTED_ARGS[@]} \
158 |    ${WANDB_ARGS[@]} \
159 |    ${PERF_ARGS[@]} \
160 |    ${EVAL_ARGS[@]} \
161 |    ${SGLANG_ARGS[@]} \
162 |    ${MISC_ARGS[@]}
163 | 


--------------------------------------------------------------------------------
/scripts/run-qwen3-4B-amd.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | 
  4 | # bash scripts/run-qwen3-4B-amd.sh
  5 | 
  6 | 
  7 | ####clear before training
  8 | pkill -9 sglang
  9 | sleep 3
 10 | ray stop --force
 11 | pkill -9 ray
 12 | pkill -9 python
 13 | sleep 3
 14 | pkill -9 ray
 15 | pkill -9 python
 16 | 
 17 | 
 18 | set -euxo pipefail
 19 | 
 20 | 
 21 | ### AMD Support ###
 22 | SLIME_DIR="/home/yushensu/projects/slime" # Need to change to your own path
 23 | export SLIME_DIR=$SLIME_DIR
 24 | 
 25 | MODEL_DIR="/home/yushensu/projects/model" # Need to change to your own path
 26 | export MODEL_DIR=$MODEL_DIR
 27 | 
 28 | DATA_DIR="/home/yushensu/projects/data"  # Need to change to your own path
 29 | export DATA_DIR=$DATA_DIR
 30 | 
 31 | # For AMD GPU
 32 | export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1
 33 | export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} # You can choose which GPUs to use
 34 | ####################
 35 | 
 36 | 
 37 | # ### AMD Support ### (If these are not installed yet, install them as follows)
 38 | # # # Clone and install Megatron-LM-amd_version
 39 | # export MAX_JOBS=512
 40 | # cd $SLIME_DIR
 41 | # pip uninstall megatron-core -y
 42 | # if [ ! -d "Megatron-LM-amd_version" ]; then
 43 | #     git clone git@github.com:yushengsu-thu/Megatron-LM-amd_version.git
 44 | # else
 45 | #     echo "Megatron-LM-amd_version directory already exists, skipping clone"
 46 | # fi
 47 | # cd Megatron-LM-amd_version
 48 | # pip install -vvv -e . 
 49 | # cd $SLIME_DIR
 50 | 
 51 | # # Install slime
 52 | # pip install -e .
 53 | # ####################
 54 | 
 55 | 
 56 | 
 57 | # will prevent ray from buffering stdout/stderr
 58 | export PYTHONUNBUFFERED=1
 59 | 
 60 | 
 61 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 62 | source "${SCRIPT_DIR}/models/qwen3-4B.sh"
 63 | 
 64 | CKPT_ARGS=(
 65 |    --hf-checkpoint ${MODEL_DIR}/Qwen3-4B
 66 |    #--hf-checkpoint /root/Qwen3-4B-FP8
 67 |    --ref-load ${MODEL_DIR}/Qwen3-4B_torch
 68 |    --load ${MODEL_DIR}/Qwen3-4B_slime/
 69 |    --save ${MODEL_DIR}/Qwen3-4B_slime/
 70 |    --save-interval 20
 71 | )
 72 | 
 73 | ROLLOUT_ARGS=(
 74 |    --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl
 75 |    --input-key prompt
 76 |    --label-key label
 77 |    --apply-chat-template
 78 |    --rollout-shuffle
 79 | 
 80 |    --rm-type deepscaler
 81 | 
 82 |    --num-rollout 3000
 83 |    --rollout-batch-size 32
 84 |    --n-samples-per-prompt 8
 85 |    --rollout-max-response-len 8192
 86 |    --rollout-temperature 0.8
 87 | 
 88 |    --global-batch-size 256
 89 |    --balance-data
 90 | )
 91 | 
 92 | EVAL_ARGS=(
 93 |    --eval-interval 20
 94 |    --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl
 95 |    --n-samples-per-eval-prompt 16
 96 |    --eval-max-response-len 16384
 97 |    --eval-top-p 0.7
 98 | )
 99 | 
100 | PERF_ARGS=(
101 |    --tensor-model-parallel-size 2
102 |    --sequence-parallel
103 |    --pipeline-model-parallel-size 1
104 |    --context-parallel-size 1
105 |    --expert-model-parallel-size 1
106 |    --expert-tensor-parallel-size 1
107 | 
108 |    --recompute-granularity full
109 |    --recompute-method uniform
110 |    --recompute-num-layers 1
111 | 
112 |    # --micro-batch-size 1
113 |    --use-dynamic-batch-size
114 |    --max-tokens-per-gpu 9216
115 | )
116 | 
117 | GRPO_ARGS=(
118 |    --advantage-estimator grpo
119 |    --use-kl-loss
120 |    --kl-loss-coef 0.00
121 |    --kl-loss-type low_var_kl
122 |    --kl-coef 0.00
123 |    --entropy-coef 0.00
124 |    --eps-clip 0.2
125 |    --eps-clip-high 0.28
126 | )
127 | 
128 | OPTIMIZER_ARGS=(
129 |    --optimizer adam
130 |    --lr 1e-6
131 |    --lr-decay-style constant
132 |    --weight-decay 0.1
133 |    --adam-beta1 0.9
134 |    --adam-beta2 0.98
135 | )
136 | 
137 | WANDB_ARGS=(
138 |    #--use-wandb
139 |    # --wandb-project slime-dev
140 |    # --wandb-group qwen3-4B-test
141 |    # --wandb-key ${WANDB_KEY}
142 | )
143 | 
144 | ### AMD Support ###
145 | # torch_memory_saver has an unresolved issue on ROCm; once it is fixed, a larger --sglang-mem-fraction-static can be used:
146 | # SGLANG_ARGS=(
147 | #    --rollout-num-gpus-per-engine 2
148 | #    --sglang-mem-fraction-static 0.7
149 | # )
150 | SGLANG_ARGS=(
151 |    --rollout-num-gpus-per-engine 2
152 |    --sglang-mem-fraction-static 0.4
153 | )
154 | ####################
155 | 
156 | 
157 | MISC_ARGS=(
158 |    # default dropout in megatron is 0.1
159 |    --attention-dropout 0.0
160 |    --hidden-dropout 0.0
161 |    # should be good for model performance
162 |    --accumulate-allreduce-grads-in-fp32
163 |    --attention-softmax-in-fp32
164 |    # need to comment this when using model with MLA
165 |    --attention-backend flash
166 |    ### AMD Support ###
167 |    # disable gradient accumulation fusion: Need to add apex to enable this
168 |    --no-gradient-accumulation-fusion
169 |    ###################
170 | )
171 | 
172 | # launch the master node of ray in container
173 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
174 | 
175 | NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l)
176 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats
177 | 
178 | 
179 | # "PYTHONPATH": "/workspace/Megatron-LM-amd_version/",
180 | MEGATRON_LM_PATH=$(pip list | grep megatron-core | awk '{print $NF}')
181 | 
182 | ray job submit --address="http://127.0.0.1:8265" \
183 |    --runtime-env-json='{
184 |      "env_vars": {
185 |         "PYTHONPATH": "/workspace/Megatron-LM-amd_version/",
186 |         "CUDA_DEVICE_MAX_CONNECTIONS": "1"
187 |      }
188 |    }' \
189 |    -- python3 train.py \
190 |    --actor-num-nodes 1 \
191 |    --actor-num-gpus-per-node 8 \
192 |    --colocate \
193 |    ${MODEL_ARGS[@]} \
194 |    ${CKPT_ARGS[@]} \
195 |    ${ROLLOUT_ARGS[@]} \
196 |    ${OPTIMIZER_ARGS[@]} \
197 |    ${GRPO_ARGS[@]} \
198 |    ${DISTRIBUTED_ARGS[@]} \
199 |    ${WANDB_ARGS[@]} \
200 |    ${PERF_ARGS[@]} \
201 |    ${EVAL_ARGS[@]} \
202 |    ${SGLANG_ARGS[@]} \
203 |    ${MISC_ARGS[@]}
204 | 
205 | 
206 | 
207 | ####clear after training
208 | 
209 | pkill -9 sglang
210 | sleep 3
211 | ray stop --force
212 | pkill -9 ray
213 | pkill -9 python
214 | sleep 3
215 | pkill -9 ray
216 | pkill -9 python
217 | 
218 | 
219 | 
220 | 
221 | 
222 | 
223 | 
224 | 
225 | 
226 | 
227 | 


--------------------------------------------------------------------------------
/scripts/run-qwen3-4B-base-sft.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | # for rerun the task
  4 | pkill -9 sglang
  5 | sleep 3
  6 | ray stop --force
  7 | pkill -9 ray
  8 | pkill -9 python
  9 | sleep 3
 10 | pkill -9 ray
 11 | pkill -9 python
 12 | 
 13 | set -ex
 14 | 
 15 | # will prevent ray from buffering stdout/stderr
 16 | export PYTHONUNBUFFERED=1
 17 | 
 18 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
 19 | if [ "$NVLINK_COUNT" -gt 0 ]; then
 20 |     HAS_NVLINK=1
 21 | else
 22 |     HAS_NVLINK=0
 23 | fi
 24 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
 25 | 
 26 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 27 | source "${SCRIPT_DIR}/models/qwen3-4B.sh"
 28 | 
 29 | CKPT_ARGS=(
 30 |    --hf-checkpoint /root/Qwen3-4B-Base/
 31 |    --ref-load /root/Qwen3-4B-Base_torch_dist
 32 |    --load /root/Qwen3-4B-Base_slime/
 33 |    --save /root/Qwen3-4B-Base_slime/
 34 |    --save-interval 1000
 35 | )
 36 | 
 37 | SFT_ARGS=(
 38 |    --rollout-function-path slime.rollout.sft_example.generate_rollout
 39 |    --prompt-data /root/openhermes2_5.parquet
 40 |    --input-key messages
 41 |    --rollout-shuffle
 42 |    --num-epoch 3
 43 |    --rollout-batch-size 128
 44 |    --global-batch-size 128
 45 | 
 46 |    --loss-type sft_loss
 47 |    --calculate-per-token-loss
 48 |    --disable-compute-advantages-and-returns
 49 |    --debug-train-only
 50 | )
 51 | 
 52 | PERF_ARGS=(
 53 |    --tensor-model-parallel-size 1
 54 |    --sequence-parallel
 55 |    --pipeline-model-parallel-size 1
 56 |    --context-parallel-size 1
 57 |    --expert-model-parallel-size 1
 58 |    --expert-tensor-parallel-size 1
 59 | 
 60 |    --recompute-granularity full
 61 |    --recompute-method uniform
 62 |    --recompute-num-layers 1
 63 | 
 64 |    # --micro-batch-size 1
 65 |    --use-dynamic-batch-size
 66 |    --max-tokens-per-gpu 9216
 67 | )
 68 | 
 69 | OPTIMIZER_ARGS=(
 70 |    --optimizer adam
 71 |    --lr 1e-5
 72 |    --lr-warmup-iters 128
 73 |    --lr-decay-style cosine
 74 |    --min-lr 1e-6
 75 |    # --lr-warmup-fraction 0.9  (conflicts with --lr-warmup-iters; set only one warmup option)
 76 |    --weight-decay 0.1
 77 |    --adam-beta1 0.9
 78 |    --adam-beta2 0.95
 79 | )
 80 | 
 81 | WANDB_ARGS=(
 82 |    # --use-wandb
 83 |    # --wandb-project slime-dev
 84 |    # --wandb-group qwen3-4B-base-sft
 85 |    # --wandb-key ${WANDB_KEY}
 86 | )
 87 | 
 88 | MISC_ARGS=(
 89 |    # default dropout in megatron is 0.1
 90 |    --attention-dropout 0.0
 91 |    --hidden-dropout 0.0
 92 |    # should be good for model performance
 93 |    --accumulate-allreduce-grads-in-fp32
 94 |    --attention-softmax-in-fp32
 95 |    # need to comment this when using model with MLA
 96 |    --attention-backend flash
 97 | )
 98 | 
 99 | # launch the master node of ray in container
100 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
101 | export no_proxy="127.0.0.1,${MASTER_ADDR}"
102 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
103 | 
104 | 
105 | # Build the runtime environment JSON with proper variable substitution
106 | RUNTIME_ENV_JSON="{
107 |   \"env_vars\": {
108 |     \"PYTHONPATH\": \"/root/Megatron-LM/\",
109 |     \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
110 |     \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
111 |   }
112 | }"
113 | 
114 | ray job submit --address="http://127.0.0.1:8265" \
115 |    --runtime-env-json="${RUNTIME_ENV_JSON}" \
116 |    -- python3 train_async.py \
117 |    --actor-num-nodes 1 \
118 |    --actor-num-gpus-per-node 8 \
119 |    ${MODEL_ARGS[@]} \
120 |    ${CKPT_ARGS[@]} \
121 |    ${SFT_ARGS[@]} \
122 |    ${OPTIMIZER_ARGS[@]} \
123 |    ${DISTRIBUTED_ARGS[@]} \
124 |    ${WANDB_ARGS[@]} \
125 |    ${PERF_ARGS[@]} \
126 |    ${EVAL_ARGS[@]} \
127 |    ${MISC_ARGS[@]}
128 | 


--------------------------------------------------------------------------------
/scripts/run-qwen3-4B-rf-baseline.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | # for rerun the task
  4 | pkill -9 sglang
  5 | sleep 3
  6 | ray stop --force
  7 | pkill -9 ray
  8 | pkill -9 python
  9 | sleep 3
 10 | pkill -9 ray
 11 | pkill -9 python
 12 | 
 13 | set -ex
 14 | 
 15 | # will prevent ray from buffering stdout/stderr
 16 | export PYTHONUNBUFFERED=1
 17 | 
 18 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
 19 | if [ "$NVLINK_COUNT" -gt 0 ]; then
 20 |     HAS_NVLINK=1
 21 | else
 22 |     HAS_NVLINK=0
 23 | fi
 24 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
 25 | 
 26 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 27 | source "${SCRIPT_DIR}/models/qwen3-4B.sh"
 28 | 
 29 | CKPT_ARGS=(
 30 |    --hf-checkpoint /root/Qwen3-4B
 31 |    #--hf-checkpoint /root/Qwen3-4B-FP8
 32 |    --ref-load /root/Qwen3-4B_torch_dist
 33 |    --load /root/Qwen3-4B_slime/
 34 |    --save /root/Qwen3-4B_slime/
 35 |    --save-interval 20
 36 | )
 37 | 
 38 | ROLLOUT_ARGS=(
 39 |    --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
 40 |    --input-key prompt
 41 |    --label-key label
 42 |    --apply-chat-template
 43 |    --rollout-shuffle
 44 |    --rm-type deepscaler
 45 |    --num-rollout 3000
 46 |    --rollout-batch-size 32
 47 |    --n-samples-per-prompt 8
 48 |    --rollout-max-response-len 8192
 49 |    --rollout-temperature 0.8
 50 | 
 51 |    --global-batch-size 256
 52 |    --balance-data
 53 | )
 54 | 
 55 | EVAL_ARGS=(
 56 |    --eval-interval 20
 57 |    --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
 58 |    --n-samples-per-eval-prompt 16
 59 |    --eval-max-response-len 16384
 60 |    --eval-top-p 0.7
 61 | )
 62 | 
 63 | PERF_ARGS=(
 64 |    --tensor-model-parallel-size 1
 65 |    --sequence-parallel
 66 |    --pipeline-model-parallel-size 1
 67 |    --context-parallel-size 1
 68 |    --expert-model-parallel-size 1
 69 |    --expert-tensor-parallel-size 1
 70 | 
 71 |    --recompute-granularity full
 72 |    --recompute-method uniform
 73 |    --recompute-num-layers 1
 74 | 
 75 |    # --micro-batch-size 1
 76 |    --use-dynamic-batch-size
 77 |    --max-tokens-per-gpu 9216
 78 | )
 79 | 
 80 | GRPO_ARGS=(
 81 |    --advantage-estimator reinforce_plus_plus_baseline
 82 |    --use-kl-loss
 83 |    --kl-loss-coef 0.00
 84 |    --kl-loss-type low_var_kl
 85 |    --kl-coef 0.00
 86 |    --entropy-coef 0.00
 87 |    --eps-clip 0.2
 88 |    --eps-clip-high 0.28
 89 |    --normalize-advantages
 90 | )
 91 | 
 92 | OPTIMIZER_ARGS=(
 93 |    --optimizer adam
 94 |    --lr 1e-6
 95 |    --lr-decay-style constant
 96 |    --weight-decay 0.1
 97 |    --adam-beta1 0.9
 98 |    --adam-beta2 0.98
 99 | )
100 | 
101 | WANDB_ARGS=(
102 | #    --use-wandb
103 | #    --wandb-project slime-dev
104 | #    --wandb-group qwen3-4B-test-reinforce_plus_plus-baseline
105 | #    --wandb-key ${WANDB_KEY}
106 | )
107 | 
108 | SGLANG_ARGS=(
109 |    --rollout-num-gpus-per-engine 2
110 |    --sglang-mem-fraction-static 0.7
111 | )
112 | 
113 | MISC_ARGS=(
114 |    # default dropout in megatron is 0.1
115 |    --attention-dropout 0.0
116 |    --hidden-dropout 0.0
117 |    # should be good for model performance
118 |    --accumulate-allreduce-grads-in-fp32
119 |    --attention-softmax-in-fp32
120 |    # need to comment this when using model with MLA
121 |    --attention-backend flash
122 | )
123 | 
124 | # launch the master node of ray in container
125 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
126 | export MASTER_PORT=${MASTER_PORT:-"12345"}
127 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
128 | 
129 | # Build the runtime environment JSON with proper variable substitution
130 | RUNTIME_ENV_JSON="{
131 |   \"env_vars\": {
132 |     \"PYTHONPATH\": \"/root/Megatron-LM/\",
133 |     \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
134 |     \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
135 |   }
136 | }"
137 | 
138 | ray job submit --address="http://127.0.0.1:8265" \
139 |    --runtime-env-json="${RUNTIME_ENV_JSON}" \
140 |    -- python3 train.py \
141 |    --actor-num-nodes 1 \
142 |    --actor-num-gpus-per-node 8 \
143 |    --colocate \
144 |    ${MODEL_ARGS[@]} \
145 |    ${CKPT_ARGS[@]} \
146 |    ${ROLLOUT_ARGS[@]} \
147 |    ${OPTIMIZER_ARGS[@]} \
148 |    ${GRPO_ARGS[@]} \
149 |    ${DISTRIBUTED_ARGS[@]} \
150 |    ${WANDB_ARGS[@]} \
151 |    ${PERF_ARGS[@]} \
152 |    ${EVAL_ARGS[@]} \
153 |    ${SGLANG_ARGS[@]} \
154 |    ${MISC_ARGS[@]}


--------------------------------------------------------------------------------
/scripts/run-qwen3-4B-rf.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | # for rerun the task
  4 | pkill -9 sglang
  5 | sleep 3
  6 | ray stop --force
  7 | pkill -9 ray
  8 | pkill -9 python
  9 | sleep 3
 10 | pkill -9 ray
 11 | pkill -9 python
 12 | 
 13 | set -ex
 14 | 
 15 | # will prevent ray from buffering stdout/stderr
 16 | export PYTHONUNBUFFERED=1
 17 | 
 18 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
 19 | if [ "$NVLINK_COUNT" -gt 0 ]; then
 20 |     HAS_NVLINK=1
 21 | else
 22 |     HAS_NVLINK=0
 23 | fi
 24 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
 25 | 
 26 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 27 | source "${SCRIPT_DIR}/models/qwen3-4B.sh"
 28 | 
 29 | CKPT_ARGS=(
 30 |    --hf-checkpoint /root/Qwen3-4B
 31 |    #--hf-checkpoint /root/Qwen3-4B-FP8
 32 |    --ref-load /root/Qwen3-4B_torch_dist
 33 |    --load /root/Qwen3-4B_slime/
 34 |    --save /root/Qwen3-4B_slime/
 35 |    --save-interval 20
 36 | )
 37 | 
 38 | ROLLOUT_ARGS=(
 39 |    --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
 40 |    --input-key prompt
 41 |    --label-key label
 42 |    --apply-chat-template
 43 |    --rollout-shuffle
 44 |    --rm-type deepscaler
 45 |    --num-rollout 3000
 46 |    --rollout-batch-size 32
 47 |    --n-samples-per-prompt 8
 48 |    --rollout-max-response-len 8192
 49 |    --rollout-temperature 0.8
 50 | 
 51 |    --global-batch-size 256
 52 |    --balance-data
 53 | )
 54 | 
 55 | EVAL_ARGS=(
 56 |    --eval-interval 20
 57 |    --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
 58 |    --n-samples-per-eval-prompt 16
 59 |    --eval-max-response-len 16384
 60 |    --eval-top-p 0.7
 61 | )
 62 | 
 63 | PERF_ARGS=(
 64 |    --tensor-model-parallel-size 1
 65 |    --sequence-parallel
 66 |    --pipeline-model-parallel-size 1
 67 |    --context-parallel-size 1
 68 |    --expert-model-parallel-size 1
 69 |    --expert-tensor-parallel-size 1
 70 | 
 71 |    --recompute-granularity full
 72 |    --recompute-method uniform
 73 |    --recompute-num-layers 1
 74 | 
 75 |    # --micro-batch-size 1
 76 |    --use-dynamic-batch-size
 77 |    --max-tokens-per-gpu 9216
 78 | )
 79 | 
 80 | GRPO_ARGS=(
 81 |    --advantage-estimator reinforce_plus_plus
 82 |    --use-kl-loss
 83 |    --kl-loss-coef 0.00
 84 |    --kl-loss-type low_var_kl
 85 |    --kl-coef 0.005
 86 |    --entropy-coef 0.00
 87 |    --eps-clip 0.2
 88 |    --eps-clip-high 0.28
 89 |    --gamma 1.0
 90 |    --normalize-advantages
 91 | )
 92 | 
 93 | OPTIMIZER_ARGS=(
 94 |    --optimizer adam
 95 |    --lr 5e-7
 96 |    --lr-decay-style constant
 97 |    --weight-decay 0.1
 98 |    --adam-beta1 0.9
 99 |    --adam-beta2 0.98
100 | )
101 | 
102 | WANDB_ARGS=(
103 | #    --use-wandb
104 | #    --wandb-project slime-dev
105 | #    --wandb-group qwen3-4B-test-reinforce_plus_plus
106 | #    --wandb-key ${WANDB_KEY}
107 | )
108 | 
109 | SGLANG_ARGS=(
110 |    --rollout-num-gpus-per-engine 2
111 |    --sglang-mem-fraction-static 0.7
112 | )
113 | 
114 | MISC_ARGS=(
115 |    # default dropout in megatron is 0.1
116 |    --attention-dropout 0.0
117 |    --hidden-dropout 0.0
118 |    # should be good for model performance
119 |    --accumulate-allreduce-grads-in-fp32
120 |    --attention-softmax-in-fp32
121 |    # need to comment this when using model with MLA
122 |    --attention-backend flash
123 | )
124 | 
125 | # launch the master node of ray in container
126 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
127 | export MASTER_PORT=${MASTER_PORT:-"12345"}
128 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
129 | 
130 | # Build the runtime environment JSON with proper variable substitution
131 | RUNTIME_ENV_JSON="{
132 |   \"env_vars\": {
133 |     \"PYTHONPATH\": \"/root/Megatron-LM/\",
134 |     \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
135 |     \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
136 |   }
137 | }"
138 | 
139 | ray job submit --address="http://127.0.0.1:8265" \
140 |    --runtime-env-json="${RUNTIME_ENV_JSON}" \
141 |    -- python3 train.py \
142 |    --actor-num-nodes 1 \
143 |    --actor-num-gpus-per-node 8 \
144 |    --colocate \
145 |    ${MODEL_ARGS[@]} \
146 |    ${CKPT_ARGS[@]} \
147 |    ${ROLLOUT_ARGS[@]} \
148 |    ${OPTIMIZER_ARGS[@]} \
149 |    ${GRPO_ARGS[@]} \
150 |    ${DISTRIBUTED_ARGS[@]} \
151 |    ${WANDB_ARGS[@]} \
152 |    ${PERF_ARGS[@]} \
153 |    ${EVAL_ARGS[@]} \
154 |    ${SGLANG_ARGS[@]} \
155 |    ${MISC_ARGS[@]}


--------------------------------------------------------------------------------
/scripts/run-qwen3-4B.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | # for rerun the task
  4 | pkill -9 sglang
  5 | sleep 3
  6 | ray stop --force
  7 | pkill -9 ray
  8 | pkill -9 python
  9 | sleep 3
 10 | pkill -9 ray
 11 | pkill -9 python
 12 | 
 13 | set -ex
 14 | 
 15 | # will prevent ray from buffering stdout/stderr
 16 | export PYTHONUNBUFFERED=1
 17 | 
 18 | NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
 19 | if [ "$NVLINK_COUNT" -gt 0 ]; then
 20 |     HAS_NVLINK=1
 21 | else
 22 |     HAS_NVLINK=0
 23 | fi
 24 | echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
 25 | 
 26 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 27 | source "${SCRIPT_DIR}/models/qwen3-4B.sh"
 28 | 
 29 | CKPT_ARGS=(
 30 |    --hf-checkpoint /root/Qwen3-4B
 31 |    #--hf-checkpoint /root/Qwen3-4B-FP8
 32 |    --ref-load /root/Qwen3-4B_torch_dist
 33 |    --load /root/Qwen3-4B_slime/
 34 |    --save /root/Qwen3-4B_slime/
 35 |    --save-interval 20
 36 | )
 37 | 
 38 | ROLLOUT_ARGS=(
 39 |    --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
 40 |    --input-key prompt
 41 |    --label-key label
 42 |    --apply-chat-template
 43 |    --rollout-shuffle
 44 |    --rm-type deepscaler
 45 |    --num-rollout 3000
 46 |    --rollout-batch-size 32
 47 |    --n-samples-per-prompt 8
 48 |    --rollout-max-response-len 8192
 49 |    --rollout-temperature 0.8
 50 | 
 51 |    --global-batch-size 256
 52 |    --balance-data
 53 | )
 54 | 
 55 | EVAL_ARGS=(
 56 |    --eval-interval 20
 57 |    --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
 58 |    --n-samples-per-eval-prompt 16
 59 |    --eval-max-response-len 16384
 60 |    --eval-top-p 0.7
 61 | )
 62 | 
 63 | PERF_ARGS=(
 64 |    --tensor-model-parallel-size 2
 65 |    --sequence-parallel
 66 |    --pipeline-model-parallel-size 1
 67 |    --context-parallel-size 1
 68 |    --expert-model-parallel-size 1
 69 |    --expert-tensor-parallel-size 1
 70 | 
 71 |    --recompute-granularity full
 72 |    --recompute-method uniform
 73 |    --recompute-num-layers 1
 74 | 
 75 |    # --micro-batch-size 1
 76 |    --use-dynamic-batch-size
 77 |    --max-tokens-per-gpu 9216
 78 | )
 79 | 
 80 | GRPO_ARGS=(
 81 |    --advantage-estimator grpo
 82 |    --use-kl-loss
 83 |    --kl-loss-coef 0.00
 84 |    --kl-loss-type low_var_kl
 85 |    --kl-coef 0.00
 86 |    --entropy-coef 0.00
 87 |    --eps-clip 0.2
 88 |    --eps-clip-high 0.28
 89 | )
 90 | 
 91 | OPTIMIZER_ARGS=(
 92 |    --optimizer adam
 93 |    --lr 1e-6
 94 |    --lr-decay-style constant
 95 |    --weight-decay 0.1
 96 |    --adam-beta1 0.9
 97 |    --adam-beta2 0.98
 98 | )
 99 | 
100 | WANDB_ARGS=(
101 |    # --use-wandb
102 |    # --wandb-project slime-dev
103 |    # --wandb-group qwen3-4B-test
104 |    # --wandb-key ${WANDB_KEY}
105 | )
106 | 
107 | SGLANG_ARGS=(
108 |    --rollout-num-gpus-per-engine 2
109 |    --sglang-mem-fraction-static 0.7
110 | )
111 | 
112 | MISC_ARGS=(
113 |    # default dropout in megatron is 0.1
114 |    --attention-dropout 0.0
115 |    --hidden-dropout 0.0
116 |    # should be good for model performance
117 |    --accumulate-allreduce-grads-in-fp32
118 |    --attention-softmax-in-fp32
119 |    # need to comment this when using model with MLA
120 |    --attention-backend flash
121 | )
122 | 
123 | # launch the master node of ray in container
124 | export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
125 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
126 | 
127 | # Build the runtime environment JSON with proper variable substitution
128 | RUNTIME_ENV_JSON="{
129 |   \"env_vars\": {
130 |     \"PYTHONPATH\": \"/root/Megatron-LM/\",
131 |     \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
132 |     \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
133 |   }
134 | }"
135 | 
136 | ray job submit --address="http://127.0.0.1:8265" \
137 |    --runtime-env-json="${RUNTIME_ENV_JSON}" \
138 |    -- python3 train.py \
139 |    --actor-num-nodes 1 \
140 |    --actor-num-gpus-per-node 8 \
141 |    --colocate \
142 |    ${MODEL_ARGS[@]} \
143 |    ${CKPT_ARGS[@]} \
144 |    ${ROLLOUT_ARGS[@]} \
145 |    ${OPTIMIZER_ARGS[@]} \
146 |    ${GRPO_ARGS[@]} \
147 |    ${DISTRIBUTED_ARGS[@]} \
148 |    ${WANDB_ARGS[@]} \
149 |    ${PERF_ARGS[@]} \
150 |    ${EVAL_ARGS[@]} \
151 |    ${SGLANG_ARGS[@]} \
152 |    ${MISC_ARGS[@]}


--------------------------------------------------------------------------------
/scripts/run_agent.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | 
 4 | set -e
 5 | 
 6 | SESSION_NAME="slime_run"
 7 | WINDOW_1="slime"
 8 | WINDOW_2="buffer"
 9 | 
10 | if tmux has-session -t $SESSION_NAME 2>/dev/null; then
11 |     echo "Killing existing tmux session: $SESSION_NAME"
12 |     tmux kill-session -t $SESSION_NAME
13 | fi
14 | 
15 | tmux new-session -d -s $SESSION_NAME -n $WINDOW_1
16 | tmux send-keys -t ${SESSION_NAME}:${WINDOW_1} "cd $(pwd)" C-m
17 | tmux send-keys -t ${SESSION_NAME}:${WINDOW_1} "bash ./scripts/agent-example.sh" C-m
18 | 
19 | tmux new-window -t $SESSION_NAME -n $WINDOW_2
20 | tmux send-keys -t ${SESSION_NAME}:${WINDOW_2} "sleep 30 && cd slime_plugins/rollout_buffer && python buffer.py" C-m
21 | 
22 | tmux attach-session -t $SESSION_NAME


--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | import platform
 3 | 
 4 | from setuptools import find_packages, setup
 5 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 6 | 
 7 | 
 8 | def _fetch_requirements(path):
 9 |     with open(path, "r") as fd:
10 |         return [r.strip() for r in fd.readlines()]
11 | 
12 | 
13 | # Custom wheel class to modify the wheel name
14 | class bdist_wheel(_bdist_wheel):
15 |     def finalize_options(self):
16 |         _bdist_wheel.finalize_options(self)
17 |         self.root_is_pure = False
18 | 
19 |     def get_tag(self):
20 |         python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
21 |         abi_tag = f"{python_version}"
22 | 
23 |         if platform.system() == "Linux":
24 |             platform_tag = "manylinux1_x86_64"
25 |         else:
26 |             platform_tag = platform.system().lower()
27 | 
28 |         return python_version, abi_tag, platform_tag
29 | 
30 | 
31 | # Setup configuration
32 | setup(
33 |     author="slime Team",
34 |     name="slime",
35 |     version="0.0.1",
36 |     packages=find_packages(include=["slime*", "slime_plugins*"]),
37 |     include_package_data=True,
38 |     install_requires=_fetch_requirements("requirements.txt"),
39 |     python_requires=">=3.10",
40 |     classifiers=[
41 |         "Programming Language :: Python :: 3.10",
42 |         "Programming Language :: Python :: 3.11",
43 |         "Programming Language :: Python :: 3.12",
44 |         "Environment :: GPU :: NVIDIA CUDA",
45 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
46 |         "Topic :: System :: Distributed Computing",
47 |     ],
48 |     cmdclass={"bdist_wheel": bdist_wheel},
49 | )
50 | 
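The custom get_tag above pins the wheel to the running interpreter and, on Linux, to the manylinux1_x86_64 platform; the returned triple becomes part of the wheel filename. A small sketch mirroring that logic:

    import platform
    import sys

    # mirrors bdist_wheel.get_tag: (python tag, abi tag, platform tag)
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_tag = "manylinux1_x86_64" if platform.system() == "Linux" else platform.system().lower()
    print((python_version, python_version, platform_tag))
    # e.g. ('cp311', 'cp311', 'manylinux1_x86_64') -> slime-0.0.1-cp311-cp311-manylinux1_x86_64.whl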


--------------------------------------------------------------------------------
/slime/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/slime/__init__.py


--------------------------------------------------------------------------------
/slime/backends/megatron_utils/__init__.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | 
 3 | from .arguments import _vocab_size_with_padding, parse_args, validate_args
 4 | from .checkpoint import load_checkpoint, save_checkpoint
 5 | from .data import (
 6 |     get_batch,
 7 |     get_data_iterator,
 8 |     log_eval_data,
 9 |     log_multi_turn_data,
10 |     log_passrate,
11 |     log_perf_data,
12 |     log_rollout_data,
13 |     process_rollout_data,
14 |     set_metadata,
15 | )
16 | from .initialize import get_gloo_group, init
17 | from .loss import compute_advantages_and_returns, get_log_probs_and_entropy, loss_function
18 | from .model import forward_only, initialize_model_and_optimizer, save, train
19 | 
20 | logging.getLogger().setLevel(logging.WARNING)
21 | 
22 | 
23 | __all__ = [
24 |     "parse_args",
25 |     "validate_args",
26 |     "load_checkpoint",
27 |     "save_checkpoint",
28 |     "get_batch",
29 |     "get_data_iterator",
30 |     "get_gloo_group",
31 |     "process_rollout_data",
32 |     "init",
33 |     "set_metadata",
34 |     "get_log_probs_and_entropy",
35 |     "log_rollout_data",
36 |     "log_passrate",
37 |     "log_multi_turn_data",
38 |     "log_eval_data",
39 |     "log_perf_data",
40 |     "compute_advantages_and_returns",
41 |     "loss_function",
42 |     "forward_only",
43 |     "train",
44 |     "save",
45 |     "initialize_model_and_optimizer",
46 |     "_vocab_size_with_padding",
47 | ]
48 | 


--------------------------------------------------------------------------------
/slime/backends/megatron_utils/arguments.py:
--------------------------------------------------------------------------------
1 | from megatron.training.arguments import parse_args, validate_args
2 | from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
3 | 
4 | __all__ = ["validate_args", "parse_args", "_vocab_size_with_padding"]
5 | 


--------------------------------------------------------------------------------
/slime/backends/megatron_utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # TODO: may need to copy those 2 functions and do refactoring.
2 | from megatron.training.checkpointing import load_checkpoint, save_checkpoint
3 | 
4 | __all__ = ["load_checkpoint", "save_checkpoint"]
5 | 


--------------------------------------------------------------------------------
/slime/backends/megatron_utils/cp_utils.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | from megatron.core import mpu
  3 | 
  4 | 
  5 | def get_logits_and_tokens_offset_with_cp(
  6 |     total_length: int,
  7 |     response_length: int,
  8 | ):
  9 |     """
 10 |     All offsets start from the beginning of the prompt.
 11 |     """
 12 |     cp_rank = mpu.get_context_parallel_rank()
 13 |     cp_size = mpu.get_context_parallel_world_size()
 14 |     assert cp_size > 1
 15 | 
 16 |     prompt_length = total_length - response_length
 17 |     chunk_size = (total_length + 2 * cp_size - 1) // (2 * cp_size)
 18 | 
 19 |     # the offset of 2 chunks
 20 |     chunk_0 = (cp_rank * chunk_size, (cp_rank + 1) * chunk_size)
 21 |     chunk_1 = ((2 * cp_size - cp_rank - 1) * chunk_size, (2 * cp_size - cp_rank) * chunk_size)
 22 | 
 23 |     # the offset of 2 logits, note that the logits need a "-1".
 24 |     logits_0 = (max(chunk_0[0], prompt_length - 1), min(chunk_0[1], total_length - 1))
 25 |     logits_1 = (max(chunk_1[0], prompt_length - 1), min(chunk_1[1], total_length - 1))
 26 | 
 27 |     # when the sequence is empty, make an empty slice to continue the gradient flow.
 28 |     if logits_0[0] < logits_0[1]:
 29 |         token_0 = (logits_0[0] + 1, logits_0[1] + 1)
 30 |     else:
 31 |         logits_0 = (0, 0)
 32 |         token_0 = (0, 0)
 33 | 
 34 |     if logits_1[0] < logits_1[1]:
 35 |         token_1 = (logits_1[0] + 1, logits_1[1] + 1)
 36 |     else:
 37 |         logits_1 = (0, 0)
 38 |         token_1 = (0, 0)
 39 | 
 40 |     return chunk_size, (chunk_0, chunk_1), (logits_0, logits_1), (token_0, token_1)
 41 | 
 42 | 
 43 | def get_sum_of_sample_mean(
 44 |     total_lengths,
 45 |     response_lengths,
 46 |     loss_masks,
 47 |     calculate_per_token_loss: bool = False,
 48 | ):
 49 |     """
 50 |     Build a reduction that computes the correct sum of per-sample means (or per-token sums) under context parallelism (CP).
 51 |     """
 52 |     cp_size = mpu.get_context_parallel_world_size()
 53 |     if cp_size == 1:
 54 | 
 55 |         def sum_of_sample_mean(x: torch.Tensor):
 56 |             return sum(
 57 |                 [
 58 |                     (x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1)
 59 |                     for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks)
 60 |                 ]
 61 |             )
 62 | 
 63 |         def sum_of_token(x: torch.Tensor):
 64 |             return sum(
 65 |                 [(x_i * loss_mask_i).sum() for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks)]
 66 |             )
 67 | 
 68 |     else:
 69 |         cp_chunk_lengths = []
 70 |         chunked_loss_masks = []
 71 |         for i, (total_length, response_length, loss_mask) in enumerate(
 72 |             zip(total_lengths, response_lengths, loss_masks)
 73 |         ):
 74 |             prompt_length = total_length - response_length
 75 |             _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp(total_length, response_length)
 76 |             loss_mask_0 = loss_mask[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length]
 77 |             loss_mask_1 = loss_mask[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length]
 78 |             chunked_loss_masks.append(torch.cat([loss_mask_0, loss_mask_1], dim=0))
 79 |             cp_chunk_lengths.append(chunked_loss_masks[i].size(0))
 80 | 
 81 |         def sum_of_sample_mean(x):
 82 |             return sum(
 83 |                 [
 84 |                     (x_i * chunked_loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)
 85 |                     for x_i, chunked_loss_mask, loss_mask in zip(
 86 |                         x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, loss_masks
 87 |                     )
 88 |                 ]
 89 |             )
 90 | 
 91 |         def sum_of_token(x: torch.Tensor):
 92 |             return sum(
 93 |                 [
 94 |                     (x_i * chunked_loss_mask).sum()
 95 |                     for x_i, chunked_loss_mask in zip(x.split(cp_chunk_lengths, dim=0), chunked_loss_masks)
 96 |                 ]
 97 |             )
 98 | 
 99 |     return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token
100 | 
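The two chunks computed above implement the load-balanced ("zigzag") context-parallel split: rank r owns chunk r and chunk 2*cp_size-1-r, so every rank holds one early and one late piece of the sequence. A standalone sketch of the same chunk arithmetic, with cp_rank and cp_size passed explicitly instead of read from mpu (illustration only):

    # zigzag CP split: each position 0..total_length-1 is covered exactly once
    def chunk_offsets(total_length, cp_rank, cp_size):
        chunk_size = (total_length + 2 * cp_size - 1) // (2 * cp_size)
        chunk_0 = (cp_rank * chunk_size, (cp_rank + 1) * chunk_size)
        chunk_1 = ((2 * cp_size - cp_rank - 1) * chunk_size, (2 * cp_size - cp_rank) * chunk_size)
        return chunk_0, chunk_1

    for rank in range(2):
        print(rank, chunk_offsets(12, rank, 2))
    # rank 0 -> ((0, 3), (9, 12)); rank 1 -> ((3, 6), (6, 9))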


--------------------------------------------------------------------------------
/slime/backends/megatron_utils/models/__init__.py:
--------------------------------------------------------------------------------
  1 | # Adapt from https://github.com/NVIDIA/Megatron-LM/blob/b1efb3c7126ef7615e8c333432d76e08038e17ff/pretrain_gpt.py
  2 | import inspect
  3 | from contextlib import nullcontext
  4 | 
  5 | from megatron.core.enums import ModelType
  6 | from megatron.core.models.gpt import GPTModel
  7 | from megatron.core.models.gpt.gpt_layer_specs import (
  8 |     get_gpt_decoder_block_spec,
  9 |     get_gpt_layer_local_spec,
 10 |     get_gpt_layer_with_transformer_engine_spec,
 11 | )
 12 | from megatron.core.transformer.spec_utils import import_module
 13 | from megatron.training import get_args
 14 | from megatron.training.arguments import core_transformer_config_from_args
 15 | 
 16 | 
 17 | def model_provider(pre_process=True, post_process=True) -> GPTModel:
 18 |     """Builds the model.
 19 | 
 20 |     If use_legacy_models is set to True, this returns the legacy GPT model; otherwise it returns the mcore GPT model.
 21 | 
 22 |     Args:
 23 |         pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
 24 |         post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.
 25 | 
 26 | 
 27 |     Returns:
 28 |         Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
 29 |     """
 30 |     args = get_args()
 31 |     use_te = args.transformer_impl == "transformer_engine"
 32 | 
 33 |     # Build the transformer config from the parsed command-line arguments
 34 |     config = core_transformer_config_from_args(args)
 35 | 
 36 |     if args.spec is not None:
 37 |         transformer_layer_spec = import_module(args.spec)
 38 |         # Allow the spec to be a function so that user can use customized Megatron easier.
 39 |         if callable(transformer_layer_spec):
 40 |             transformer_layer_spec = transformer_layer_spec(args)
 41 |     else:
 42 |         if args.num_experts:
 43 |             # Define the decoder block spec
 44 |             transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te)
 45 |         else:
 46 |             # Define the decoder layer spec
 47 |             if use_te:
 48 |                 transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
 49 |                     args.num_experts,
 50 |                     args.moe_grouped_gemm,
 51 |                     args.qk_layernorm,
 52 |                     args.multi_latent_attention,
 53 |                     args.moe_use_legacy_grouped_gemm,
 54 |                 )
 55 |             else:
 56 |                 transformer_layer_spec = get_gpt_layer_local_spec(
 57 |                     args.num_experts,
 58 |                     args.moe_grouped_gemm,
 59 |                     args.qk_layernorm,
 60 |                     args.multi_latent_attention,
 61 |                     args.moe_use_legacy_grouped_gemm,
 62 |                 )
 63 | 
 64 |     build_model_context = nullcontext
 65 |     build_model_context_args = {}
 66 |     if args.fp8_param_gather:
 67 |         try:
 68 |             from transformer_engine.pytorch import fp8_model_init
 69 | 
 70 |             build_model_context = fp8_model_init
 71 |             build_model_context_args["enabled"] = True
 72 | 
 73 |             # Check if fp8_model_init supports preserve_high_precision_init_val
 74 |             if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters:
 75 |                 build_model_context_args["preserve_high_precision_init_val"] = True
 76 |         except ImportError:
 77 |             raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")
 78 | 
 79 |     kwargs = {
 80 |         "config": config,
 81 |         "transformer_layer_spec": transformer_layer_spec,
 82 |         "vocab_size": args.padded_vocab_size,
 83 |         "max_sequence_length": args.max_position_embeddings,
 84 |         "pre_process": pre_process,
 85 |         "post_process": post_process,
 86 |         "fp16_lm_cross_entropy": args.fp16_lm_cross_entropy,
 87 |         "parallel_output": True,
 88 |         "share_embeddings_and_output_weights": not args.untie_embeddings_and_output_weights,
 89 |         "position_embedding_type": args.position_embedding_type,
 90 |         "rotary_percent": args.rotary_percent,
 91 |         "rotary_base": args.rotary_base,
 92 |         "rope_scaling": args.use_rope_scaling,
 93 |     }
 94 | 
 95 |     if getattr(args, "mtp_num_layers", None):
 96 |         from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
 97 | 
 98 |         mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)
 99 |         kwargs["mtp_block_spec"] = mtp_block_spec
100 | 
101 |     with build_model_context(**build_model_context_args):
102 |         model = GPTModel(**kwargs)
103 | 
104 |     return model
105 | 
106 | 
107 | def get_model_provider_and_type():
108 |     return model_provider, ModelType.encoder_or_decoder
109 | 
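The fp8_param_gather branch above probes the signature of fp8_model_init so that preserve_high_precision_init_val is only passed to TransformerEngine versions that accept it. The same pattern in isolation, with hypothetical callables old_api and new_api standing in for the two versions:

    import inspect

    def call_with_optional_kwarg(fn, **base_kwargs):
        # pass "preserve_init" only when the callee's signature accepts it
        kwargs = dict(base_kwargs)
        if "preserve_init" in inspect.signature(fn).parameters:
            kwargs["preserve_init"] = True
        return fn(**kwargs)

    def old_api(enabled):                       # older version: no such kwarg
        return ("old", enabled)

    def new_api(enabled, preserve_init=False):  # newer version: accepts it
        return ("new", enabled, preserve_init)

    print(call_with_optional_kwarg(old_api, enabled=True))  # ('old', True)
    print(call_with_optional_kwarg(new_api, enabled=True))  # ('new', True, True)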


--------------------------------------------------------------------------------
/slime/backends/sglang_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/slime/backends/sglang_utils/__init__.py


--------------------------------------------------------------------------------
/slime/backends/sglang_utils/arguments.py:
--------------------------------------------------------------------------------
  1 | from sglang.srt.server_args import ServerArgs
  2 | 
  3 | 
  4 | def add_sglang_router_arguments(parser):
  5 |     """
  6 |     Add arguments to the parser for the SGLang router.
  7 |     """
  8 |     parser.add_argument(
  9 |         "--sglang-router-ip",
 10 |         type=str,
 11 |         default=None,
 12 |         help="IP address of the SGLang router",
 13 |     )
 14 |     parser.add_argument(
 15 |         "--sglang-router-port",
 16 |         type=int,
 17 |         default=None,
 18 |         help="Port of the SGLang router",
 19 |     )
 20 |     return parser
 21 | 
 22 | 
 23 | def add_sglang_arguments(parser):
 24 |     """
 25 |     Add arguments to the parser for the SGLang server.
 26 |     """
 27 |     parser = add_sglang_router_arguments(parser)
 28 |     parser.add_argument("--sglang-server-concurrency", type=int, default=512)
 29 | 
 30 |     old_add_argument = parser.add_argument
 31 | 
 32 |     skipped_args = [
 33 |         "model_path",
 34 |         "dtype",
 35 |         "trust_remote_code",
 36 |         "random_seed",
 37 |         # memory
 38 |         "enable_memory_saver",
 39 |         # distributed
 40 |         "tp_size",
 41 |         "port",
 42 |         "nnodes",
 43 |         "node_rank",
 44 |         "dist_init_addr",
 45 |         "gpu_id_step",
 46 |         "base_gpu_id",
 47 |         "nccl_port",
 48 |         "skip_server_warmup",
 49 |     ]
 50 | 
 51 |     def new_add_argument_wrapper(*name_or_flags, **kwargs):
 52 |         """
 53 |         Add arguments to the parser, ensuring that the server arguments are prefixed and skippable.
 54 |         """
 55 |         # Determine the canonical name for skip check (e.g., "model_path")
 56 |         canonical_name_for_skip_check = None
 57 |         if "dest" in kwargs:
 58 |             canonical_name_for_skip_check = kwargs["dest"]
 59 |         else:
 60 |             for flag_name_candidate in name_or_flags:
 61 |                 if isinstance(flag_name_candidate, str) and flag_name_candidate.startswith("--"):
 62 |                     # Derive from first long flag: --foo-bar -> foo_bar
 63 |                     stem = flag_name_candidate[2:]
 64 |                     canonical_name_for_skip_check = stem.replace("-", "_")
 65 |                     break
 66 |             # If no long flag and no dest, skip logic might not catch it unless short flags imply a dest.
 67 | 
 68 |         if canonical_name_for_skip_check and canonical_name_for_skip_check in skipped_args:
 69 |             return  # Skip this entire argument definition
 70 | 
 71 |         # If not skipped, proceed to prefix flags and dest
 72 |         new_name_or_flags_list = []
 73 |         for item_flag in name_or_flags:
 74 |             if isinstance(item_flag, str) and item_flag.startswith("-"):
 75 |                 original_flag_stem = item_flag.lstrip("-")  # "foo-bar" from "--foo-bar", or "f" from "-f"
 76 |                 prefixed_item = f"--sglang-{original_flag_stem}"
 77 |                 new_name_or_flags_list.append(prefixed_item)
 78 |             else:
 79 |                 # Positional arguments or non-string items
 80 |                 new_name_or_flags_list.append(item_flag)
 81 | 
 82 |         # Prepare kwargs for the actual add_argument call.
 83 |         # Make a copy to avoid modifying the original kwargs dict.
 84 |         final_kwargs = kwargs.copy()
 85 | 
 86 |         # If 'dest' is explicitly provided and is a string, prefix it.
 87 |         # This ensures the attribute on the args namespace becomes, e.g., args.sglang_dest_name.
 88 |         if "dest" in final_kwargs and isinstance(final_kwargs["dest"], str):
 89 |             original_dest = final_kwargs["dest"]
 90 |             # Avoid double prefixing if dest somehow already starts with sglang_
 91 |             if not original_dest.startswith("sglang_"):
 92 |                 final_kwargs["dest"] = f"sglang_{original_dest}"
 93 |         # If 'dest' is not explicitly provided (or is None/not a string),
 94 |         # argparse will derive 'dest' from the (now prefixed) flag names.
 95 |         # E.g., if the first flag is "--sglang-foo-bar", argparse sets dest to "sglang_foo_bar".
 96 | 
 97 |         old_add_argument(*new_name_or_flags_list, **final_kwargs)
 98 | 
 99 |     parser.add_argument = new_add_argument_wrapper
100 |     ServerArgs.add_cli_args(parser)
101 |     parser.add_argument = old_add_argument
102 | 
103 |     return parser
104 | 
105 | 
106 | def validate_args(args):
107 |     # sglang
108 |     args.sglang_tp_size = args.rollout_num_gpus_per_engine
109 |     args.sglang_dp_size = args.sglang_data_parallel_size
110 |     args.sglang_pp_size = args.sglang_pipeline_parallel_size
111 |     args.sglang_ep_size = args.sglang_expert_parallel_size
112 | 
113 |     if args.sglang_dp_size > 1:
114 |         assert args.sglang_enable_dp_attention
115 | 
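The wrapper above re-registers every sglang server flag under a --sglang- prefix, and argparse then derives a matching sglang_* dest for each. A toy illustration of the resulting convention (simplified; the real flags are added by ServerArgs.add_cli_args):

    import argparse

    # simplified: a server flag "--mem-fraction-static" is re-registered as
    # "--sglang-mem-fraction-static", so the parsed value lands on
    # args.sglang_mem_fraction_static
    parser = argparse.ArgumentParser()
    parser.add_argument("--sglang-mem-fraction-static", type=float, default=0.9)
    args = parser.parse_args(["--sglang-mem-fraction-static", "0.7"])
    print(args.sglang_mem_fraction_static)  # 0.7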


--------------------------------------------------------------------------------
/slime/backends/sglang_utils/sglang_engine.py:
--------------------------------------------------------------------------------
  1 | import dataclasses
  2 | 
  3 | import os
  4 | from typing import TYPE_CHECKING
  5 | 
  6 | from sglang.srt.server_args import ServerArgs
  7 | from slime.utils.http_utils import get_host_info
  8 | from .http_server_engine import HttpServerEngineAdapter
  9 | 
 10 | if TYPE_CHECKING:
 11 |     pass
 12 | 
 13 | 
 14 | def get_base_gpu_id(args, rank):
 15 |     num_gpus = min(8, args.rollout_num_gpus_per_engine)
 16 |     if args.colocate:
 17 |         start_index = (rank * num_gpus) % 8
 18 |     else:
 19 |         num_actor_gpus = args.actor_num_gpus_per_node * args.actor_num_nodes
 20 |         start_index = (num_actor_gpus + rank * num_gpus) % 8
 21 |     return start_index
 22 | 
 23 | 
 24 | class SglangEngine:
 25 | 
 26 |     def __init__(self, args, rank, dist_init_addr, port, nccl_port):
 27 |         self.args = args
 28 | 
 29 |         # remove the CUDA_VISIBLE_DEVICES set by ray and use base_gpu_id
 30 |         os.environ.pop("CUDA_VISIBLE_DEVICES", None)
 31 | 
 32 |         nnodes = max(1, args.rollout_num_gpus_per_engine // 8)
 33 |         node_rank = rank % nnodes
 34 |         kwargs = {
 35 |             "model_path": args.hf_checkpoint,
 36 |             "trust_remote_code": True,
 37 |             "random_seed": args.seed + rank,
 38 |             # memory
 39 |             "enable_memory_saver": args.offload,
 40 |             # distributed
 41 |             "host": get_host_info()[1],
 42 |             "port": port,
 43 |             "nccl_port": nccl_port,
 44 |             "nnodes": nnodes,
 45 |             "node_rank": node_rank,
 46 |             "dist_init_addr": dist_init_addr,
 47 |             "gpu_id_step": 1,
 48 |             "base_gpu_id": get_base_gpu_id(args, rank),
 49 |             # parallel
 50 |             "tp_size": args.rollout_num_gpus_per_engine,
 51 |             "dp_size": args.sglang_dp_size,
 52 |             "pp_size": args.sglang_pp_size,
 53 |             "ep_size": args.sglang_ep_size,
 54 |             # always skip warmup to prevent warmup timeout.
 55 |             "skip_server_warmup": True,
 56 |         }
 57 | 
 58 |         unused_keys = set(kwargs.keys())
 59 |         for attr in dataclasses.fields(ServerArgs):
 60 |             if hasattr(args, f"sglang_{attr.name}") and attr.name not in kwargs:
 61 |                 kwargs[attr.name] = getattr(args, f"sglang_{attr.name}")
 62 |             unused_keys.discard(attr.name)
 63 | 
 64 |         # for compatibility with old args
 65 |         if len(unused_keys) > 0:
 66 |             print(f"Warning: The following arguments is not supported in the current sglang: {unused_keys}.")
 67 |             for key in unused_keys:
 68 |                 kwargs.pop(key)
 69 | 
 70 |         self.llm = HttpServerEngineAdapter(
 71 |             router_ip=args.sglang_router_ip, router_port=args.sglang_router_port, **kwargs
 72 |         )
 73 | 
 74 |     def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend):
 75 |         return self.llm.init_weights_update_group(
 76 |             master_address, master_port, rank_offset, world_size, group_name, backend
 77 |         )
 78 | 
 79 |     def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
 80 |         self.llm.update_weights_from_distributed(names, dtypes, shapes, group_name)
 81 |         return
 82 | 
 83 |     def update_weights_from_tensor(self, ipc_handles):
 84 |         self.llm.update_weights_from_tensor(ipc_handles)
 85 |         return
 86 | 
 87 |     def reset_prefix_cache(self):
 88 |         self.llm.flush_cache()
 89 | 
 90 |     def sleep(self, level=1):
 91 |         # Adhoc solution to ensure no running requests
 92 |         self.llm.flush_cache()
 93 |         self.llm.release_memory_occupation()
 94 | 
 95 |     def wake_up(self):
 96 |         self.llm.resume_memory_occupation()
 97 | 
 98 |     def pause_generation(self):
 99 |         self.llm.pause_generation()
100 | 
101 |     def continue_generation(self):
102 |         self.llm.continue_generation()
103 | 
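get_base_gpu_id above picks the first local GPU index for each rollout engine: colocated engines pack from GPU 0, while disaggregated engines start after the actor GPUs, wrapping modulo 8. A standalone re-statement with explicit parameters (illustration only; the example values are made up):

    def base_gpu_id(rank, gpus_per_engine, colocate, actor_gpus_per_node=4, actor_nodes=1):
        num_gpus = min(8, gpus_per_engine)
        if colocate:
            return (rank * num_gpus) % 8
        return (actor_gpus_per_node * actor_nodes + rank * num_gpus) % 8

    # colocated, 2 GPUs per engine: engines 0..3 start at GPUs 0, 2, 4, 6
    print([base_gpu_id(r, 2, True) for r in range(4)])   # [0, 2, 4, 6]
    # disaggregated with 4 actor GPUs on the node: engines start at 4, 6, then wrap to 0, 2
    print([base_gpu_id(r, 2, False) for r in range(4)])  # [4, 6, 0, 2]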


--------------------------------------------------------------------------------
/slime/ray/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/slime/ray/__init__.py


--------------------------------------------------------------------------------
/slime/ray/placement_group.py:
--------------------------------------------------------------------------------
  1 | import socket
  2 | import ray
  3 | from ray.util.placement_group import placement_group
  4 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
  5 | 
  6 | from .ppo_actor import RayTrainGroup
  7 | from .rollout import RolloutGroup
  8 | 
  9 | 
 10 | @ray.remote(num_gpus=1)
 11 | class InfoActor:
 12 |     def get_ip_and_gpu_id(self):
 13 |         return ray.util.get_node_ip_address(), ray.get_gpu_ids()[0]
 14 | 
 15 | 
 16 | def sort_key(x):
 17 |     index, node_identifier, gpu_id = x
 18 |     # Sort by node IP number and then by GPU ID
 19 |     try:
 20 |         # try to parse it as an IP address.
 21 |         ip_address = node_identifier
 22 |         node_ip_parts = list(map(int, ip_address.split(".")))
 23 |     except ValueError:
 24 |         # Try to resolve the hostname to an IP address.
 25 |         try:
 26 |             ip_address = socket.gethostbyname(node_identifier)
 27 |             node_ip_parts = list(map(int, ip_address.split(".")))
 28 |         except (socket.gaierror, TypeError):
 29 |             # Instead, we convert each character of the original identifier string
 30 |             # to its ASCII value. This provides a stable and consistent numerical
 31 |             # representation that allows for sorting.
 32 |             node_ip_parts = [ord(c) for c in node_identifier]
 33 | 
 34 |     return (node_ip_parts, gpu_id)
 35 | 
 36 | 
 37 | def _create_placement_group(num_gpus):
 38 |     """Create a placement group with the specified number of GPUs."""
 39 |     bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)]
 40 |     pg = placement_group(bundles, strategy="PACK")
 41 |     num_bundles = len(bundles)
 42 | 
 43 |     ray.get(pg.ready())
 44 |     # use info actor to get the GPU id
 45 |     info_actors = []
 46 |     for i in range(num_bundles):
 47 |         info_actors.append(
 48 |             InfoActor.options(
 49 |                 scheduling_strategy=PlacementGroupSchedulingStrategy(
 50 |                     placement_group=pg,
 51 |                     placement_group_bundle_index=i,
 52 |                 )
 53 |             ).remote()
 54 |         )
 55 |     gpu_ids = ray.get([actor.get_ip_and_gpu_id.remote() for actor in info_actors])
 56 |     for actor in info_actors:
 57 |         ray.kill(actor)
 58 | 
 59 |     bundle_infos = [(i, gpu_ids[i][0], gpu_ids[i][1]) for i in range(num_bundles)]
 60 |     pg_reordered_bundle_indices = [bundle_info[0] for bundle_info in sorted(bundle_infos, key=sort_key)]
 61 |     for i in range(num_bundles):
 62 |         actual_bundle_index = pg_reordered_bundle_indices[i]
 63 |         print(
 64 |             f"  bundle {i:4}, actual_bundle_index: {actual_bundle_index:4}, "
 65 |             f"node: {gpu_ids[actual_bundle_index][0]}, gpu: {gpu_ids[actual_bundle_index][1]}"
 66 |         )
 67 | 
 68 |     return pg, pg_reordered_bundle_indices
 69 | 
 70 | 
 71 | def create_placement_groups(args):
 72 |     """Create placement groups for actor and rollout engines."""
 73 | 
 74 |     num_gpus = 0
 75 |     if args.debug_train_only:
 76 |         num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node
 77 |         rollout_offset = 0
 78 |     elif args.debug_rollout_only:
 79 |         num_gpus = args.rollout_num_gpus
 80 |         rollout_offset = 0
 81 |     elif args.colocate:
 82 |         num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node
 83 |         rollout_offset = 0
 84 |     else:
 85 |         num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node + args.rollout_num_gpus
 86 |         rollout_offset = args.actor_num_nodes * args.actor_num_gpus_per_node
 87 | 
 88 |     print(f"Creating placement group with {num_gpus} GPUs...")
 89 |     pg, actor_pg_reordered_bundle_indices = _create_placement_group(num_gpus)
 90 | 
 91 |     rollout_pg_reordered_bundle_indices = actor_pg_reordered_bundle_indices[rollout_offset:]
 92 | 
 93 |     return {
 94 |         "actor": (pg, actor_pg_reordered_bundle_indices),
 95 |         "rollout": (pg, rollout_pg_reordered_bundle_indices),
 96 |     }
 97 | 
 98 | 
 99 | def allocate_train_group(num_nodes, num_gpus_per_node, pg):
100 |     return RayTrainGroup(
101 |         num_nodes=num_nodes,
102 |         num_gpus_per_node=num_gpus_per_node,
103 |         pg=pg,
104 |         num_gpus_per_actor=0.8,
105 |     )
106 | 
107 | 
108 | def create_actor_group(args, pg):
109 |     actor_model = allocate_train_group(
110 |         num_nodes=args.actor_num_nodes,
111 |         num_gpus_per_node=args.actor_num_gpus_per_node,
112 |         pg=pg,
113 |     )
114 |     return actor_model
115 | 
116 | 
117 | def create_rollout_group(args, pg):
118 |     return RolloutGroup(args, pg)
119 | 


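A minimal usage sketch for the helpers above (not part of the file), assuming a Ray cluster with enough GPUs is already running; `args` here is a hypothetical namespace whose field names mirror the ones read by `create_placement_groups`:

    from types import SimpleNamespace

    from slime.ray.placement_group import create_placement_groups

    # hypothetical arguments: 1 training node with 8 GPUs plus 8 rollout GPUs
    args = SimpleNamespace(
        debug_train_only=False,
        debug_rollout_only=False,
        colocate=False,
        actor_num_nodes=1,
        actor_num_gpus_per_node=8,
        rollout_num_gpus=8,
    )

    pgs = create_placement_groups(args)
    actor_pg, actor_bundles = pgs["actor"]        # all 16 bundles, sorted by (node, GPU id)
    rollout_pg, rollout_bundles = pgs["rollout"]  # the last 8 bundles, offset past the actor GPUs
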
--------------------------------------------------------------------------------
/slime/ray/ray_actor.py:
--------------------------------------------------------------------------------
 1 | import ray
 2 | from slime.utils.http_utils import is_port_available
 3 | 
 4 | 
 5 | class RayActor:
 6 |     @staticmethod
 7 |     def _get_current_node_ip_and_free_port(start_port=10000, consecutive=1):
 8 |         address = ray._private.services.get_node_ip_address()
 9 |         # strip ipv6 address
10 |         address = address.strip("[]")
11 | 
12 |         # find the port where port, port + 1, port + 2, ... port + consecutive - 1 are all available
13 |         port = start_port
14 |         while not all(is_port_available(port + i) for i in range(consecutive)):
15 |             port += 1
16 | 
17 |         return address, port
18 | 
19 |     def get_master_addr_and_port(self):
20 |         return self.master_addr, self.master_port
21 | 


--------------------------------------------------------------------------------
/slime/ray/utils.py:
--------------------------------------------------------------------------------
 1 | # Adapted from https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ray/utils.py#L1
 2 | import os
 3 | 
 4 | import ray
 5 | import torch
 6 | 
 7 | 
 8 | def ray_noset_visible_devices(env_vars=os.environ):
 9 |     # Refer to
10 |     # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96
11 |     # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103
12 |     # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95
13 |     # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117
14 |     # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109
15 |     # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172
16 |     # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98
17 |     NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [
18 |         "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
19 |         "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES",
20 |         "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES",
21 |         "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES",
22 |         "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES",
23 |         "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS",
24 |         "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR",
25 |     ]
26 |     return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST)
27 | 
28 | 
29 | def get_physical_gpu_id():
30 |     device = torch.cuda.current_device()
31 |     props = torch.cuda.get_device_properties(device)
32 |     return str(props.uuid)
33 | 
34 | 
35 | @ray.remote
36 | class Lock:
37 |     def __init__(self):
38 |         self._locked = False  # False: unlocked, True: locked
39 | 
40 |     def acquire(self):
41 |         """
42 |         Try to acquire the lock. Returns True if acquired, False otherwise.
43 |         Caller should retry until it returns True.
44 |         """
45 |         if not self._locked:
46 |             self._locked = True
47 |             return True
48 |         return False
49 | 
50 |     def release(self):
51 |         """Release the lock, allowing others to acquire."""
52 |         assert self._locked, "Lock is not acquired, cannot release."
53 |         self._locked = False
54 | 


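A brief usage sketch for the `Lock` actor above (assumes a Ray runtime; as the docstring says, callers poll `acquire` until it succeeds):

    import time

    import ray

    from slime.ray.utils import Lock

    lock = Lock.remote()

    # spin until the lock is granted, then do the critical work and release it
    while not ray.get(lock.acquire.remote()):
        time.sleep(0.1)
    try:
        pass  # critical section
    finally:
        ray.get(lock.release.remote())
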
--------------------------------------------------------------------------------
/slime/rollout/filter_hub/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/slime/rollout/filter_hub/__init__.py


--------------------------------------------------------------------------------
/slime/rollout/filter_hub/dynamic_sampling_filters.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from slime.utils.types import Sample
 3 | 
 4 | 
 5 | __all__ = ["check_reward_nonzero_std"]
 6 | 
 7 | 
 8 | def check_reward_nonzero_std(args, samples: list[Sample], **kwargs):
 9 |     rewards = [sample.reward for sample in samples]
10 |     return torch.tensor(rewards, dtype=torch.float).std() > 0.0
11 | 


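For illustration, a group whose rewards are all identical is dropped by this filter, while a mixed group passes (`args` is not used by the check, so `None` is passed here):

    from slime.rollout.filter_hub.dynamic_sampling_filters import check_reward_nonzero_std
    from slime.utils.types import Sample

    uniform = [Sample(reward=1.0), Sample(reward=1.0)]
    mixed = [Sample(reward=0.0), Sample(reward=1.0)]

    assert not check_reward_nonzero_std(None, uniform)  # zero std -> group is filtered out
    assert check_reward_nonzero_std(None, mixed)        # nonzero std -> group is kept
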
--------------------------------------------------------------------------------
/slime/rollout/filter_hub/over_sampling_filters.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from slime.utils.types import Sample
 3 | 
 4 | 
 5 | __all__ = ["sort_by_reward_std"]
 6 | 
 7 | 
 8 | def sort_by_reward_std(args, samples: list[list[Sample]], **kwargs) -> list[list[Sample]]:
 9 |     samples_with_std = []
10 |     for group in samples:
11 |         rewards = [item.reward for item in group]
12 |         std = torch.tensor(rewards, dtype=torch.float).std()
13 |         samples_with_std.append((group, std))
14 |     # python sort is stable, so the order of samples with the same std is preserved
15 |     samples_with_std.sort(key=lambda x: x[1], reverse=True)
16 |     return [item[0] for item in samples_with_std]
17 | 


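A small sketch of the over-sampling filter: groups are reordered so the highest reward-variance group comes first, and ties keep their original order because the sort is stable (`args` is unused and passed as `None`):

    from slime.rollout.filter_hub.over_sampling_filters import sort_by_reward_std
    from slime.utils.types import Sample

    groups = [
        [Sample(reward=1.0), Sample(reward=1.0)],  # std 0.0
        [Sample(reward=0.0), Sample(reward=1.0)],  # std ~0.71
    ]
    ordered = sort_by_reward_std(None, groups)
    assert ordered[0] is groups[1]  # the high-variance group moves to the front
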
--------------------------------------------------------------------------------
/slime/rollout/rm_hub/__init__.py:
--------------------------------------------------------------------------------
 1 | import asyncio
 2 | from typing import Union
 3 | 
 4 | import aiohttp
 5 | 
 6 | from slime.utils.misc import load_function
 7 | from slime.utils.types import Sample
 8 | 
 9 | from .deepscaler import get_deepscaler_rule_based_reward
10 | from .f1 import f1_score
11 | from .math_dapo_utils import compute_score as compute_score_dapo
12 | from .math_utils import extract_answer as extract_boxed_answer
13 | from .math_utils import grade_answer_verl
14 | 
15 | 
16 | async def remote_rm(args, sample: Sample):
17 |     payload = {
18 |         "prompt": sample.prompt,
19 |         "response": sample.response,
20 |         "label": sample.label,
21 |     }
22 |     session_kwargs = {}
23 |     async with aiohttp.ClientSession(**session_kwargs) as session:
24 |         async with session.post(args.rm_url, json=payload) as resp:
25 |             resp.raise_for_status()
26 |             return await resp.json()
27 | 
28 | 
29 | async def async_rm(args, sample: Sample, **kwargs):
30 |     if args.custom_rm_path is not None:
31 |         rm_function = load_function(args.custom_rm_path)
32 |         return await rm_function(args, sample, **kwargs)
33 | 
34 |     rm_type = args.rm_type
35 |     response = sample.response
36 |     label = sample.label
37 |     if rm_type.startswith("boxed_"):
38 |         response = extract_boxed_answer(response)
39 |         rm_type = rm_type[len("boxed_") :]
40 | 
41 |     # This function is intended for remote or time-consuming reward model evaluation.
42 |     # Implement the actual logic as needed.
43 |     if rm_type == "remote_rm":
44 |         return await remote_rm(args, sample)
45 |     elif rm_type == "deepscaler":
46 |         return get_deepscaler_rule_based_reward(response, label)
47 |     elif rm_type == "dapo":
48 |         return compute_score_dapo(response, label)
49 |     elif rm_type == "math":
50 |         return 1 if grade_answer_verl(response, label) else 0
51 |     elif rm_type == "f1":
52 |         return f1_score(response, label)[0]
53 |     else:
54 |         raise NotImplementedError(f"Rule-based RM for {rm_type} is not implemented.")
55 | 
56 | 
57 | async def batched_async_rm(
58 |     args,
59 |     samples: list[Sample],
60 |     **kwargs,
61 | ) -> list[Union[int, float]]:
62 |     if args.custom_rm_path is not None:
63 |         rm_function = load_function(args.custom_rm_path)
64 |         return await rm_function(args, samples, **kwargs)
65 | 
66 |     rm_type = args.rm_type
67 |     prompts = [sample.prompt for sample in samples]
68 |     responses = [sample.response for sample in samples]
69 |     labels = [sample.label for sample in samples]
70 |     if labels is None:
71 |         labels = [None] * len(prompts)
72 |     tasks = [
73 |         # async_rm expects (args, sample) and reads prompt/response/label itself
74 |         async_rm(args, sample, **kwargs) for sample in samples
75 |     ]
76 |     rewards = await asyncio.gather(*tasks)
77 |     return rewards
78 | 


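A hedged sketch of the dispatch above: the `boxed_` prefix strips a `\boxed{...}` answer from the response before grading with the remaining rule. The `args` namespace is hypothetical and only carries the two fields read here:

    import asyncio
    from types import SimpleNamespace

    from slime.rollout.rm_hub import async_rm
    from slime.utils.types import Sample

    args = SimpleNamespace(custom_rm_path=None, rm_type="boxed_math")
    sample = Sample(
        prompt="What is 1 + 1?",
        response=r"The answer is \boxed{2}.",
        label="2",
    )
    reward = asyncio.run(async_rm(args, sample))
    print(reward)  # 1 if the boxed answer matches the label, otherwise 0
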
--------------------------------------------------------------------------------
/slime/rollout/rm_hub/deepscaler.py:
--------------------------------------------------------------------------------
 1 | from .math_utils import extract_answer, grade_answer_mathd, grade_answer_sympy
 2 | 
 3 | 
 4 | def get_deepscaler_rule_based_reward(response, label):
 5 |     if "</think>" in response:
 6 |         model_solution = response.split("</think>")[1]
 7 |     elif "###Response" in response:
 8 |         model_solution = response.split("###Response")[1]
 9 |     else:
10 |         return 0
11 | 
12 |     model_answer = extract_answer(model_solution)
13 |     if model_answer is None:
14 |         return 0
15 |     if label == "":
16 |         return 0
17 | 
18 |     # Convert single answer to list for uniform processing
19 |     assert isinstance(label, (str, float, int))
20 |     ground_truths = [label]
21 | 
22 |     # Process each ground truth
23 |     processed_ground_truths = []
24 |     for truth in ground_truths:
25 |         truth = str(truth)
26 |         if "\\boxed" in truth:
27 |             processed_truth = extract_answer(truth)
28 |             if processed_truth is not None:
29 |                 processed_ground_truths.append(processed_truth)
30 |         else:
31 |             processed_ground_truths.append(truth)
32 | 
33 |     if not processed_ground_truths:
34 |         return 0
35 | 
36 |     # Check against all possible correct answers
37 |     for ground_truth in processed_ground_truths:
38 |         is_correct = grade_answer_mathd(model_answer, ground_truth) or grade_answer_sympy(model_answer, ground_truth)
39 |         if is_correct:
40 |             return 1
41 | 
42 |     return 0
43 | 


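For illustration, the rule only grades text that follows a `</think>` (or `###Response`) marker, and the answer is read from `\boxed{...}`:

    from slime.rollout.rm_hub.deepscaler import get_deepscaler_rule_based_reward

    response = r"Let me think.</think> The answer is \boxed{42}."
    print(get_deepscaler_rule_based_reward(response, "42"))          # 1: boxed answer matches
    print(get_deepscaler_rule_based_reward(response, "41"))          # 0: boxed answer differs
    print(get_deepscaler_rule_based_reward("no marker here", "42"))  # 0: missing </think>
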
--------------------------------------------------------------------------------
/slime/rollout/rm_hub/f1.py:
--------------------------------------------------------------------------------
 1 | import re
 2 | import string
 3 | from collections import Counter
 4 | 
 5 | 
 6 | def normalize_answer(s):
 7 | 
 8 |     def remove_articles(text):
 9 |         return re.sub(r"\b(a|an|the)\b", " ", text)
10 | 
11 |     def white_space_fix(text):
12 |         return " ".join(text.split())
13 | 
14 |     def remove_punc(text):
15 |         exclude = set(string.punctuation)
16 |         return "".join(ch for ch in text if ch not in exclude)
17 | 
18 |     def lower(text):
19 |         return text.lower()
20 | 
21 |     return white_space_fix(remove_articles(remove_punc(lower(s))))
22 | 
23 | 
24 | def f1_score(prediction, ground_truth):
25 |     ZERO_METRIC = (0, 0, 0)
26 | 
27 |     if prediction is None:
28 |         return ZERO_METRIC
29 | 
30 |     normalized_prediction = normalize_answer(prediction)
31 |     normalized_ground_truth = normalize_answer(ground_truth)
32 | 
33 |     if normalized_prediction in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth:
34 |         return ZERO_METRIC
35 |     if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth:
36 |         return ZERO_METRIC
37 | 
38 |     prediction_tokens = normalized_prediction.split()
39 |     ground_truth_tokens = normalized_ground_truth.split()
40 |     common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
41 |     num_same = sum(common.values())
42 |     if num_same == 0:
43 |         return ZERO_METRIC
44 |     precision = 1.0 * num_same / len(prediction_tokens)
45 |     recall = 1.0 * num_same / len(ground_truth_tokens)
46 |     f1 = (2 * precision * recall) / (precision + recall)
47 |     return f1, precision, recall
48 | 


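A quick illustration of the token-level F1 above: answers are normalized (lower-cased, punctuation and articles removed) before comparison, and yes/no/noanswer predictions must match the ground truth exactly:

    from slime.rollout.rm_hub.f1 import f1_score

    print(f1_score("The Eiffel Tower", "eiffel tower"))  # (1.0, 1.0, 1.0) after normalization
    print(f1_score("yes", "no"))                         # (0, 0, 0): yes/no must match exactly
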
--------------------------------------------------------------------------------
/slime/rollout/sft_example.py:
--------------------------------------------------------------------------------
 1 | from transformers import AutoTokenizer
 2 | 
 3 | from slime.utils.mask_utils import MultiTurnLossMaskGenerator
 4 | 
 5 | __all__ = ["generate_rollout"]
 6 | 
 7 | 
 8 | TOKENIZER = None
 9 | MASK_GENERATOR = None
10 | 
11 | 
12 | def generate_rollout(args, rollout_id, data_buffer, evaluation=False):
13 |     """An example implementation of the generate_rollout function for SFT-style (loss-masked) rollout generation.
14 | 
15 |     Args:
16 |         args: the whole args
17 |         rollout_id: int, the id of the rollout, used for deterministic data generation
18 |         data_buffer: the data buffer to store the generated samples
19 |         evaluation: bool, whether the rollout is for evaluation or not
20 | 
21 |     Returns:
22 |         list[Sample]: a list of samples generated by the rollout
23 |     """
24 |     assert not evaluation
25 |     assert args.rollout_global_dataset
26 | 
27 |     global TOKENIZER, MASK_GENERATOR
28 |     if TOKENIZER is None:
29 |         TOKENIZER = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
30 | 
31 |     if MASK_GENERATOR is None:
32 |         MASK_GENERATOR = MultiTurnLossMaskGenerator(TOKENIZER, tokenizer_type=args.loss_mask_type)
33 | 
34 |     samples = data_buffer.get_samples(args.rollout_batch_size)
35 | 
36 |     for sample in samples:
37 |         (sample,) = sample  # each buffer entry is a group containing a single sample
38 |         messages = sample.prompt
39 |         token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages)
40 |         response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0]
41 | 
42 |         sample.tokens = token_ids
43 |         sample.response_length = response_length
44 |         sample.reward = 0
45 |         sample.loss_mask = loss_mask[-response_length:]
46 | 
47 |     return samples
48 | 


--------------------------------------------------------------------------------
/slime/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .arguments import parse_args
2 | 
3 | __all__ = ["parse_args"]
4 | 


--------------------------------------------------------------------------------
/slime/utils/async_utils.py:
--------------------------------------------------------------------------------
 1 | import asyncio
 2 | import threading
 3 | 
 4 | __all__ = ["get_async_loop", "run"]
 5 | 
 6 | 
 7 | # Create a background event loop thread
 8 | class AsyncLoopThread:
 9 |     def __init__(self):
10 |         self.loop = asyncio.new_event_loop()
11 |         self._thread = threading.Thread(target=self._start_loop, daemon=True)
12 |         self._thread.start()
13 | 
14 |     def _start_loop(self):
15 |         asyncio.set_event_loop(self.loop)
16 |         self.loop.run_forever()
17 | 
18 |     def run(self, coro):
19 |         # Schedule a coroutine onto the loop and block until it's done
20 |         return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
21 | 
22 | 
23 | # Create one global instance
24 | async_loop = None
25 | 
26 | 
27 | def get_async_loop():
28 |     global async_loop
29 |     if async_loop is None:
30 |         async_loop = AsyncLoopThread()
31 |     return async_loop
32 | 
33 | 
34 | def run(coro):
35 |     """Run a coroutine in the background event loop."""
36 |     return get_async_loop().run(coro)
37 | 


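`run` is the synchronous entry point: it ships a coroutine to the background loop thread and blocks until the result is ready, so async helpers stay callable from ordinary code. A minimal sketch:

    import asyncio

    from slime.utils.async_utils import run

    async def fetch_answer():
        await asyncio.sleep(0.1)  # stand-in for real async work (e.g. an HTTP call)
        return 42

    print(run(fetch_answer()))  # blocks the caller and returns 42
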
--------------------------------------------------------------------------------
/slime/utils/data.py:
--------------------------------------------------------------------------------
 1 | import random
 2 | 
 3 | import pandas as pd
 4 | 
 5 | from slime.utils.types import Sample
 6 | 
 7 | __all__ = ["Dataset"]
 8 | 
 9 | 
10 | # TODO: don't read the whole file into memory.
11 | def read_file(path):
12 |     if path.endswith(".jsonl"):
13 |         df = pd.read_json(path, lines=True)
14 |     elif path.endswith(".parquet"):
15 |         df = pd.read_parquet(path)
16 |     else:
17 |         raise ValueError(f"Unsupported file format: {path}. Supported formats are .jsonl and .parquet.")
18 | 
19 |     for _, row in df.iterrows():
20 |         yield row.to_dict()
21 | 
22 | 
23 | class Dataset:
24 |     def __init__(
25 |         self,
26 |         path,
27 |         tokenizer,
28 |         max_length,
29 |         *,
30 |         prompt_key="text",
31 |         label_key=None,
32 |         tool_key=None,
33 |         metadata_key="metadata",
34 |         seed=42,
35 |         apply_chat_template=False,
36 |     ):
37 |         self.origin_samples = []
38 |         for data in read_file(path):
39 |             prompt = data[prompt_key]
40 |             if apply_chat_template:
41 |                 if tool_key is not None:
42 |                     tools = data[tool_key]
43 |                 else:
44 |                     tools = None
45 |                 prompt = tokenizer.apply_chat_template(prompt, tools, tokenize=False, add_generation_prompt=True)
46 | 
47 |             # TODO: this is slow.
48 |             if max_length is not None:
49 |                 if len(tokenizer(prompt)["input_ids"]) > max_length:
50 |                     continue
51 | 
52 |             self.origin_samples.append(
53 |                 Sample(
54 |                     prompt=prompt,
55 |                     label=data[label_key] if label_key is not None else None,
56 |                     metadata=data.get(metadata_key) or {},
57 |                 )
58 |             )
59 | 
60 |         self.epoch_id = -1
61 |         self.seed = seed
62 |         self.samples = self.origin_samples
63 | 
64 |     def shuffle(self, new_epoch_id):
65 |         if self.epoch_id == new_epoch_id:
66 |             return
67 | 
68 |         random.seed(self.seed + new_epoch_id)
69 |         permutation = list(range(len(self.samples)))
70 |         random.shuffle(permutation)
71 |         self.samples = [self.origin_samples[i] for i in permutation]
72 |         self.epoch_id = new_epoch_id
73 | 
74 |     def __getitem__(self, idx):
75 |         return self.samples[idx]
76 | 
77 |     def __len__(self):
78 |         return len(self.samples)
79 | 


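A hedged usage sketch; the checkpoint name and `prompts.jsonl` are placeholders, and `shuffle` is keyed by epoch id so calling it twice with the same id is a no-op:

    from transformers import AutoTokenizer

    from slime.utils.data import Dataset

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")  # hypothetical checkpoint
    dataset = Dataset(
        "prompts.jsonl",   # hypothetical .jsonl file with {"text": ..., "label": ...} rows
        tokenizer,
        max_length=2048,
        prompt_key="text",
        label_key="label",
    )
    dataset.shuffle(new_epoch_id=0)  # deterministic given seed + epoch id
    print(len(dataset), dataset[0].prompt)
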
--------------------------------------------------------------------------------
/slime/utils/distributed_utils.py:
--------------------------------------------------------------------------------
  1 | from datetime import timedelta
  2 | from typing import Any, Optional, Union
  3 | 
  4 | import torch
  5 | import torch.distributed as dist
  6 | from torch.distributed.distributed_c10d import (
  7 |     Backend,
  8 |     PrefixStore,
  9 |     Store,
 10 |     _new_process_group_helper,
 11 |     _world,
 12 |     default_pg_timeout,
 13 |     rendezvous,
 14 | )
 15 | from megatron.core import mpu
 16 | 
 17 | 
 18 | # Copy from pytorch to allow creating multiple main groups.
 19 | # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
 20 | def init_process_group(
 21 |     backend: Union[str, Backend] = None,
 22 |     init_method: Optional[str] = None,
 23 |     timeout: Optional[timedelta] = None,
 24 |     world_size: int = -1,
 25 |     rank: int = -1,
 26 |     store: Optional[Store] = None,
 27 |     group_name: str = None,
 28 |     pg_options: Optional[Any] = None,
 29 | ):
 30 |     assert (store is None) or (init_method is None), "Cannot specify both init_method and store."
 31 | 
 32 |     if store is not None:
 33 |         assert world_size > 0, "world_size must be positive if using store"
 34 |         assert rank >= 0, "rank must be non-negative if using store"
 35 |     elif init_method is None:
 36 |         init_method = "env://"
 37 | 
 38 |     if backend:
 39 |         backend = Backend(backend)
 40 |     else:
 41 |         backend = Backend("undefined")
 42 | 
 43 |     if timeout is None:
 44 |         timeout = default_pg_timeout
 45 | 
 46 |     # backward compatible API
 47 |     if store is None:
 48 |         rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
 49 |         store, rank, world_size = next(rendezvous_iterator)
 50 |         store.set_timeout(timeout)
 51 | 
 52 |         # Use a PrefixStore to avoid accidental overrides of keys used by
 53 |         # different systems (e.g. RPC) in case the store is multi-tenant.
 54 |         store = PrefixStore(group_name, store)
 55 | 
 56 |     # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
 57 |     # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
 58 |     # We need to determine the appropriate parameter name based on PyTorch version
 59 |     pg_options_param_name = "backend_options" if tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 6) else "pg_options"
 60 |     pg, _ = _new_process_group_helper(
 61 |         world_size,
 62 |         rank,
 63 |         [],
 64 |         backend,
 65 |         store,
 66 |         group_name=group_name,
 67 |         **{pg_options_param_name: pg_options},
 68 |         timeout=timeout,
 69 |     )
 70 | 
 71 |     _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
 72 | 
 73 |     return pg
 74 | 
 75 | 
 76 | def distributed_masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True, epsilon: float = 1e-8):
 77 |     """
 78 |     Performs whitening on a tensor using global statistics from all participating GPUs.
 79 | 
 80 |     It calculates the global mean and variance across all ranks in the default
 81 |     process group (the WORLD) and uses these global statistics to normalize the
 82 |     local data on each rank.
 83 | 
 84 |     Args:
 85 |         values (torch.Tensor): The local tensor of values to whiten.
 86 |         mask (torch.Tensor): The local mask corresponding to the values.
 87 |         shift_mean (bool): If True, the output is zero-mean. Defaults to True.
 88 |         epsilon (float): A small value for numerical stability.
 89 | 
 90 |     Returns:
 91 |         torch.Tensor: The locally whitened tensor using global statistics.
 92 |     """
 93 |     # Calculate local intermediate statistics
 94 |     local_sum = (values * mask).sum()
 95 |     local_sum_sq = ((values**2) * mask).sum()
 96 |     local_mask_sum = mask.sum()
 97 | 
 98 |     stats_tensor = torch.tensor(
 99 |         [local_sum, local_sum_sq, local_mask_sum],
100 |         device=values.device,
101 |         dtype=torch.float32,
102 |     )
103 | 
104 |     # Aggregate via all_reduce within the DP group
105 |     dist.all_reduce(stats_tensor)
106 | 
107 |     # Calculate global stats from aggregated results
108 |     global_sum, global_sum_sq, global_mask_sum = stats_tensor
109 | 
110 |     if global_mask_sum.item() == 0:
111 |         raise ValueError("The global mask sum across all participating GPUs is zero.")
112 | 
113 |     global_mean = global_sum / global_mask_sum
114 |     global_mean_sq = global_sum_sq / global_mask_sum
115 |     global_var = global_mean_sq - global_mean**2
116 |     
117 |     # Bessel's correction for unbiased estimate
118 |     if global_mask_sum.item() >= 2:
119 |         bessel_correction = global_mask_sum / (global_mask_sum - 1)
120 |         global_var = global_var * bessel_correction
121 |     
122 |     # Whiten local data using global stats
123 |     whitened_values = (values - global_mean) * torch.rsqrt(global_var + epsilon)
124 | 
125 |     if not shift_mean:
126 |         whitened_values += global_mean
127 |     
128 |     return whitened_values


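A single-process sketch of `distributed_masked_whiten` (world size 1 over the gloo backend, where the all_reduce degenerates to the local masked statistics); the address and port are placeholders, and the module itself still needs the full slime/Megatron environment to import:

    import os

    import torch
    import torch.distributed as dist

    from slime.utils.distributed_utils import distributed_masked_whiten

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    values = torch.tensor([[1.0, 2.0, 3.0, 9.9]])
    mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # last position is padding, excluded from stats

    whitened = distributed_masked_whiten(values, mask)
    print(whitened)  # masked positions become zero-mean, unit-variance

    dist.destroy_process_group()
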
--------------------------------------------------------------------------------
/slime/utils/flops_utils.py:
--------------------------------------------------------------------------------
  1 | def calculate_embedding_flops(seqlen, hidden_size):
  2 |     return 2 * seqlen * hidden_size
  3 | 
  4 | 
  5 | def calculate_lm_head_flops(seqlen, hidden_size, vocab_size):
  6 |     return 2 * seqlen * hidden_size * vocab_size
  7 | 
  8 | 
  9 | def calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num_query_groups):
 10 |     head_dim = hidden_size // num_attention_heads
 11 |     n_q_heads = num_attention_heads
 12 |     n_kv_heads = num_query_groups
 13 |     q_flops = 2 * seqlen * hidden_size * n_q_heads * head_dim
 14 |     kv_flops = 2 * seqlen * hidden_size * n_kv_heads * head_dim * 2
 15 |     return q_flops + kv_flops
 16 | 
 17 | 
 18 | def calculate_attention_flops(seqlen, num_attention_heads, head_dim):
 19 |     # QK^T
 20 |     flops = 2 * num_attention_heads * seqlen * seqlen * head_dim
 21 |     # A*V
 22 |     flops += 2 * num_attention_heads * seqlen * seqlen * head_dim
 23 |     return flops
 24 | 
 25 | 
 26 | def calculate_output_flops(seqlen, hidden_size):
 27 |     return 2 * seqlen * hidden_size * hidden_size
 28 | 
 29 | 
 30 | def calculate_mlp_flops(seqlen, hidden_size, ffn_hidden_size):
 31 |     return 2 * seqlen * hidden_size * ffn_hidden_size * 3
 32 | 
 33 | 
 34 | def calculate_layer_flops(seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size):
 35 |     head_dim = hidden_size // num_attention_heads
 36 |     return (
 37 |         calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num_query_groups)
 38 |         + calculate_attention_flops(seqlen, num_attention_heads, head_dim)
 39 |         + calculate_output_flops(seqlen, hidden_size)
 40 |         + calculate_mlp_flops(seqlen, hidden_size, ffn_hidden_size)
 41 |     )
 42 | 
 43 | 
 44 | def calculate_fwd_flops(
 45 |     seqlens,
 46 |     args,
 47 | ):
 48 |     hidden_size = args.hidden_size
 49 |     num_attention_heads = args.num_attention_heads
 50 |     num_query_groups = args.num_query_groups
 51 |     vocab_size = args.vocab_size
 52 | 
 53 |     total_flops = 0
 54 | 
 55 |     dense_ffn = args.ffn_hidden_size
 56 |     if args.num_experts is None:
 57 |         num_dense_layers = args.num_layers
 58 |         num_moe_layers = 0
 59 |     else:
 60 |         shared_expert_ffn = getattr(args, "moe_shared_expert_intermediate_size", None)
 61 |         if shared_expert_ffn is None:
 62 |             shared_expert_ffn = 0
 63 | 
 64 |         moe_ffn = args.moe_ffn_hidden_size * args.moe_router_topk + shared_expert_ffn
 65 |         if hasattr(args, "moe_layer_freq"):
 66 |             if isinstance(args.moe_layer_freq, list):
 67 |                 num_dense_layers = sum(1 for freq in args.moe_layer_freq if freq == 0)
 68 |                 num_moe_layers = sum(1 for freq in args.moe_layer_freq if freq > 0)
 69 |             else:
 70 |                 num_dense_layers = sum(1 for i in range(args.num_layers) if i % args.moe_layer_freq != 0)
 71 |                 num_moe_layers = sum(1 for i in range(args.num_layers) if i % args.moe_layer_freq == 0)
 72 |         else:
 73 |             num_dense_layers = 0
 74 |             num_moe_layers = args.num_layers
 75 | 
 76 |     for seqlen in seqlens:
 77 |         if num_dense_layers > 0:
 78 |             total_flops += (
 79 |                 calculate_layer_flops(
 80 |                     seqlen,
 81 |                     hidden_size,
 82 |                     num_attention_heads,
 83 |                     num_query_groups,
 84 |                     dense_ffn,
 85 |                 )
 86 |                 * num_dense_layers
 87 |             )
 88 | 
 89 |         if num_moe_layers > 0:
 90 |             total_flops += (
 91 |                 calculate_layer_flops(
 92 |                     seqlen,
 93 |                     hidden_size,
 94 |                     num_attention_heads,
 95 |                     num_query_groups,
 96 |                     moe_ffn,
 97 |                 )
 98 |                 * num_moe_layers
 99 |             )
100 | 
101 |         total_flops += calculate_lm_head_flops(seqlen, hidden_size, vocab_size)
102 | 
103 |     return total_flops
104 | 


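As a worked example, a hedged estimate for a hypothetical dense model; the attribute names mirror the ones read by `calculate_fwd_flops`, and `num_experts=None` selects the dense path:

    from types import SimpleNamespace

    from slime.utils.flops_utils import calculate_fwd_flops

    # hypothetical ~8B dense config
    config = SimpleNamespace(
        hidden_size=4096,
        num_attention_heads=32,
        num_query_groups=8,
        ffn_hidden_size=14336,
        vocab_size=128256,
        num_layers=32,
        num_experts=None,
    )

    flops = calculate_fwd_flops([4096], config)  # one sequence of 4096 tokens
    print(f"{flops / 1e12:.1f} TFLOPs for a single forward pass")
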
--------------------------------------------------------------------------------
/slime/utils/http_utils.py:
--------------------------------------------------------------------------------
  1 | import asyncio
  2 | import multiprocessing
  3 | import random
  4 | import socket
  5 | 
  6 | import httpx
  7 | 
  8 | 
  9 | def find_available_port(base_port: int):
 10 |     port = base_port + random.randint(100, 1000)
 11 |     while True:
 12 |         if is_port_available(port):
 13 |             return port
 14 |         if port < 60000:
 15 |             port += 42
 16 |         else:
 17 |             port -= 43
 18 | 
 19 | 
 20 | def is_port_available(port):
 21 |     """Return whether a port is available."""
 22 |     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
 23 |         try:
 24 |             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 25 |             s.bind(("", port))
 26 |             s.listen(1)
 27 |             return True
 28 |         except socket.error:
 29 |             return False
 30 |         except OverflowError:
 31 |             return False
 32 | 
 33 | 
 34 | def get_host_info():
 35 |     hostname = socket.gethostname()
 36 | 
 37 |     local_ip = socket.gethostbyname(hostname)
 38 | 
 39 |     return hostname, local_ip
 40 | 
 41 | 
 42 | def run_router(args):
 43 |     try:
 44 |         from sglang_router.launch_router import launch_router
 45 | 
 46 |         router = launch_router(args)
 47 |         if router is None:
 48 |             return 1
 49 |         return 0
 50 |     except Exception as e:
 51 |         print(e)
 52 |         return 1
 53 | 
 54 | 
 55 | def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None:
 56 |     """Terminate a process gracefully, with forced kill as fallback.
 57 | 
 58 |     Args:
 59 |         process: The process to terminate
 60 |         timeout: Seconds to wait for graceful termination before forcing kill
 61 |     """
 62 |     if not process.is_alive():
 63 |         return
 64 | 
 65 |     process.terminate()
 66 |     process.join(timeout=timeout)
 67 |     if process.is_alive():
 68 |         process.kill()
 69 |         process.join()
 70 | 
 71 | 
 72 | async def post(url, payload, use_http2=False, max_retries=60):
 73 |     # never timeout
 74 |     timeout = httpx.Timeout(None)
 75 |     # retries below are bounded by the max_retries argument
 76 |     retry_count = 0
 77 |     while retry_count < max_retries:
 78 |         try:
 79 |             async with httpx.AsyncClient(http1=not use_http2, http2=use_http2, timeout=timeout) as client:
 80 |                 response = await client.post(url, json=payload or {})
 81 |                 response.raise_for_status()
 82 |                 try:
 83 |                     output = response.json()
 84 |                 except:
 85 |                     output = response.text
 86 |         except Exception as e:
 87 |             retry_count += 1
 88 |             print(f"Error: {e}, retrying... (attempt {retry_count}/{max_retries})")
 89 |             if retry_count >= max_retries:
 90 |                 print(f"Max retries ({max_retries}) reached, failing...")
 91 |                 raise e
 92 |             await asyncio.sleep(1)
 93 |             continue
 94 |         break
 95 | 
 96 |     return output
 97 | 
 98 | 
 99 | async def get(url, use_http2=False):
100 |     # never timeout
101 |     timeout = httpx.Timeout(None)
102 |     async with httpx.AsyncClient(http1=not use_http2, http2=use_http2, timeout=timeout) as client:
103 |         response = await client.get(url)
104 |         response.raise_for_status()
105 |         output = response.json()
106 |     return output
107 | 


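A minimal sketch of the retrying `post` helper; the endpoint and payload are placeholders, and `max_retries` bounds how many times a failed request is retried:

    import asyncio

    from slime.utils.http_utils import post

    async def main():
        # hypothetical endpoint; retries up to max_retries times before raising
        return await post("http://localhost:8000/generate", {"prompt": "hello"}, max_retries=3)

    print(asyncio.run(main()))
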
--------------------------------------------------------------------------------
/slime/utils/mask_utils.py:
--------------------------------------------------------------------------------
  1 | from typing import Dict, List, Tuple
  2 | 
  3 | from transformers import AutoTokenizer
  4 | 
  5 | 
  6 | class MultiTurnLossMaskGenerator:
  7 |     def __init__(self, tokenizer: AutoTokenizer, tokenizer_type: str = "qwen"):
  8 |         self.tokenizer = tokenizer
  9 |         self.system_message_length, self.gen_token_length = self.get_system_message_length()
 10 |         self.tokenizer_type = tokenizer_type
 11 | 
 12 |     def get_response_lengths(self, loss_masks: List[List[int]]) -> List[int]:
 13 |         return [len(mask[mask.index(1) :]) if 1 in mask else 0 for mask in loss_masks]
 14 | 
 15 |     def find_all_sublist_indices(self, main_list, sublist):
 16 |         sublist_len = len(sublist)
 17 |         indices = []
 18 |         for i in range(len(main_list) - sublist_len + 1):
 19 |             if main_list[i : i + sublist_len] == sublist:
 20 |                 indices.append(i)
 21 |         return indices
 22 | 
 23 |     def get_system_message_length(self) -> Tuple[int, int]:
 24 |         test_string = "FOR TESTING ONLY"
 25 |         test_messages = [
 26 |             {"role": "user", "content": test_string},
 27 |             {"role": "user", "content": test_string},
 28 |         ]
 29 |         raw_token_ids = self.tokenizer(test_string, add_special_tokens=False)["input_ids"]
 30 |         chat_template_token = self.tokenizer.apply_chat_template(
 31 |             test_messages, add_special_tokens=False, tokenize=False
 32 |         )
 33 |         chat_template_token_ids = self.tokenizer(chat_template_token, add_special_tokens=False)["input_ids"]
 34 |         idx_1, idx_2 = self.find_all_sublist_indices(chat_template_token_ids, raw_token_ids)
 35 |         end_interval = len(chat_template_token_ids) - len(raw_token_ids) - idx_2
 36 |         gen_token_length = len(
 37 |             self.tokenizer.apply_chat_template(
 38 |                 test_messages, add_special_tokens=False, tokenize=True, add_generation_prompt=True
 39 |             )
 40 |         ) - len(chat_template_token_ids)
 41 | 
 42 |         system_message_length = idx_1 - ((idx_2 - idx_1) - end_interval - len(raw_token_ids))
 43 |         return system_message_length, gen_token_length
 44 | 
 45 |     def gen_multi_turn_loss_mask_qwen(self, messages: List[Dict]) -> Tuple[List[int], List[int]]:
 46 |         all_loss_masks = []
 47 |         all_token_ids = []
 48 | 
 49 |         for i, message in enumerate(messages):
 50 |             message_ids = self.tokenizer.apply_chat_template([message], tokenize=True)
 51 | 
 52 |             if message["role"] != "system" and i > 0:
 53 |                 message_ids = message_ids[self.system_message_length :]
 54 | 
 55 |             if message["role"] == "assistant":
 56 |                 loss_mask = [0] * self.gen_token_length + [1] * (len(message_ids) - self.gen_token_length)
 57 |             else:
 58 |                 loss_mask = [0] * len(message_ids)
 59 | 
 60 |             all_loss_masks.extend(loss_mask)
 61 |             all_token_ids.extend(message_ids)
 62 | 
 63 |         return all_token_ids, all_loss_masks
 64 | 
 65 |     def gen_multi_turn_loss_mask_distill_qwen(self, messages: List[Dict]) -> Tuple[List[int], List[int]]:
 66 |         prompt = self.tokenizer.apply_chat_template(messages[:1], tokenize=False, add_generation_prompt=True)
 67 |         response = messages[-1]["content"]
 68 |         prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
 69 |         response_tokens = self.tokenizer(response, add_special_tokens=False)["input_ids"]
 70 | 
 71 |         response_length = len(response_tokens)
 72 |         token_ids = prompt_tokens + response_tokens
 73 |         loss_mask = [0] * len(prompt_tokens) + [1] * response_length
 74 |         return token_ids, loss_mask
 75 | 
 76 |     def get_loss_mask(self, messages: List[Dict]) -> List[int]:
 77 |         if self.tokenizer_type == "qwen":
 78 |             if "<|Assistant|>" in self.tokenizer.get_added_vocab():
 79 |                 return self.gen_multi_turn_loss_mask_distill_qwen(messages)
 80 | 
 81 |             return self.gen_multi_turn_loss_mask_qwen(messages)
 82 |         elif self.tokenizer_type == "distill_qwen":
 83 |             return self.gen_multi_turn_loss_mask_distill_qwen(messages)
 84 |         else:
 85 |             raise ValueError(f"Unsupported tokenizer type: {self.tokenizer_type}")
 86 | 
 87 |     def get_text_from_loss_mask(self, token_ids: List[int], loss_masks: List[int]) -> List[str]:
 88 |         selected_texts = []
 89 |         current_tokens = []
 90 | 
 91 |         for idx, mask in enumerate(loss_masks):
 92 |             if mask == 1:
 93 |                 current_tokens.append(token_ids[idx])
 94 |             elif current_tokens:
 95 |                 selected_texts.append(self.tokenizer.decode(current_tokens))
 96 |                 current_tokens = []
 97 | 
 98 |         if current_tokens:
 99 |             selected_texts.append(self.tokenizer.decode(current_tokens))
100 | 
101 |         return selected_texts
102 | 


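A hedged usage sketch (the checkpoint name is a placeholder; `tokenizer_type="qwen"` selects the multi-turn path above, in which only assistant tokens are unmasked):

    from transformers import AutoTokenizer

    from slime.utils.mask_utils import MultiTurnLossMaskGenerator

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")  # hypothetical checkpoint
    mask_gen = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="qwen")

    messages = [
        {"role": "user", "content": "What is 1 + 1?"},
        {"role": "assistant", "content": "2"},
    ]
    token_ids, loss_mask = mask_gen.get_loss_mask(messages)
    response_length = mask_gen.get_response_lengths([loss_mask])[0]
    print(len(token_ids), response_length)  # only assistant tokens carry a loss of 1
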
--------------------------------------------------------------------------------
/slime/utils/memory_utils.py:
--------------------------------------------------------------------------------
 1 | import gc
 2 | import torch
 3 | import torch.distributed as dist
 4 | 
 5 | 
 6 | def clear_memory():
 7 |     torch.cuda.synchronize()
 8 |     gc.collect()
 9 |     torch.cuda.empty_cache()
10 | 
11 | 
12 | def available_memory():
13 |     free, total = torch.cuda.mem_get_info(torch.cuda.current_device())
14 |     return {
15 |         "gpu": str(torch.cuda.current_device()),
16 |         "total_GB": round(total / (1024**3), 2),
17 |         "free_GB": round(free / (1024**3), 2),
18 |         "used_GB": round((total - free) / (1024**3), 2),
19 |     }
20 | 
21 | 
22 | def print_memory(msg):
23 |     if dist.get_rank() == 0:
24 |         print(f"Memory-Usage {msg}:", available_memory())
25 | 


--------------------------------------------------------------------------------
/slime/utils/misc.py:
--------------------------------------------------------------------------------
 1 | import importlib
 2 | 
 3 | 
 4 | def load_function(path):
 5 |     """
 6 |     Load a function from a module.
 7 |     :param path: The path to the function, e.g. "module.submodule.function".
 8 |     :return: The function object.
 9 |     """
10 |     module_path, _, attr = path.rpartition(".")
11 |     module = importlib.import_module(module_path)
12 |     return getattr(module, attr)
13 | 
14 | 
15 | class SingletonMeta(type):
16 |     """
17 |     A metaclass for creating singleton classes.
18 |     """
19 | 
20 |     _instances = {}
21 | 
22 |     def __call__(cls, *args, **kwargs):
23 |         if cls not in cls._instances:
24 |             instance = super().__call__(*args, **kwargs)
25 |             cls._instances[cls] = instance
26 |         return cls._instances[cls]
27 | 


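For example, the dotted path is split into a module part and an attribute part, so any importable callable can be referenced by string:

    from slime.utils.misc import load_function

    # resolves importlib.import_module("os.path") and then getattr(..., "join")
    join = load_function("os.path.join")
    print(join("a", "b"))  # "a/b" on POSIX
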
--------------------------------------------------------------------------------
/slime/utils/timer.py:
--------------------------------------------------------------------------------
 1 | from contextlib import contextmanager
 2 | from functools import wraps
 3 | from time import time
 4 | 
 5 | from .misc import SingletonMeta
 6 | 
 7 | __all__ = ["Timer", "timer"]
 8 | 
 9 | 
10 | class Timer(metaclass=SingletonMeta):
11 |     def __init__(self):
12 |         self.timers = {}
13 |         self.start_time = {}
14 | 
15 |     def start(self, name):
16 |         assert name not in self.timers, f"Timer {name} already started."
17 |         self.start_time[name] = time()
18 | 
19 |     def end(self, name):
20 |         assert name in self.start_time, f"Timer {name} not started."
21 |         elapsed_time = time() - self.start_time[name]
22 |         self.add(name, elapsed_time)
23 |         del self.start_time[name]
24 | 
25 |     def reset(self, name=None):
26 |         if name is None:
27 |             self.timers = {}
28 |         elif name in self.timers:
29 |             del self.timers[name]
30 | 
31 |     def add(self, name, elapsed_time):
32 |         if name not in self.timers:
33 |             self.timers[name] = elapsed_time
34 |         else:
35 |             self.timers[name] += elapsed_time
36 | 
37 |     def log_dict(self):
38 |         return self.timers
39 | 
40 |     @contextmanager
41 |     def context(self, name):
42 |         self.start(name)
43 |         try:
44 |             yield
45 |         finally:
46 |             self.end(name)
47 | 
48 | 
49 | def timer(name_or_func):
50 |     """
51 |     Can be used either as a decorator or a context manager:
52 | 
53 |     @timer
54 |     def func():
55 |         ...
56 | 
57 |     or
58 | 
59 |     with timer("block_name"):
60 |         ...
61 |     """
62 |     # When used as a context manager
63 |     if isinstance(name_or_func, str):
64 |         name = name_or_func
65 |         return Timer().context(name)
66 | 
67 |     func = name_or_func
68 | 
69 |     @wraps(func)
70 |     def wrapper(*args, **kwargs):
71 |         with Timer().context(func.__name__):
72 |             return func(*args, **kwargs)
73 | 
74 |     return wrapper
75 | 


--------------------------------------------------------------------------------
/slime/utils/types.py:
--------------------------------------------------------------------------------
 1 | from dataclasses import dataclass, field
 2 | from enum import Enum
 3 | from typing import Optional, Union
 4 | 
 5 | import torch
 6 | 
 7 | 
 8 | @dataclass
 9 | class Sample:
10 |     """The sample generated"""
11 | 
12 |     index: Optional[int] = None
13 |     # prompt
14 |     prompt: Union[str, list[dict[str, str]]] = ""
15 |     tokens: list[int] = field(default_factory=list)
16 |     # response
17 |     response: str = ""
18 |     response_length: int = 0
19 |     label: Optional[str] = None
20 |     reward: Optional[Union[float, dict[str, float]]] = None
21 |     loss_mask: Optional[list[int]] = None
22 | 
23 |     class Status(Enum):
24 |         PENDING = "pending"
25 |         COMPLETED = "completed"
26 |         TRUNCATED = "truncated"
27 |         ABORTED = "aborted"
28 | 
29 |     status: Status = Status.PENDING
30 |     metadata: dict = field(default_factory=dict)
31 | 
32 |     def to_dict(self):
33 |         value = self.__dict__.copy()
34 |         value["status"] = self.status.value
35 |         return value
36 | 
37 |     @staticmethod
38 |     def from_dict(data: dict):
39 |         data["status"] = Sample.Status(data["status"])
40 |         return Sample(**data)
41 | 
42 | 
43 | @dataclass
44 | class ParamInfo:
45 |     name: str
46 |     dtype: torch.dtype
47 |     shape: torch.Size
48 |     attrs: dict
49 |     size: int
50 |     src_rank: int
51 | 


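`to_dict`/`from_dict` round-trip a sample through a plain dictionary (the status enum is stored by its string value), which keeps samples easy to serialize. A small sketch:

    from slime.utils.types import Sample

    sample = Sample(prompt="What is 1 + 1?", response="2", reward=1.0)
    payload = sample.to_dict()            # status becomes the plain string "pending"
    restored = Sample.from_dict(payload)  # and is parsed back into Sample.Status
    assert restored.status == Sample.Status.PENDING
    assert restored.response == "2"
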
--------------------------------------------------------------------------------
/slime_plugins/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/slime_plugins/__init__.py


--------------------------------------------------------------------------------
/slime_plugins/mbridge/__init__.py:
--------------------------------------------------------------------------------
1 | from .glm4 import GLM4Bridge


--------------------------------------------------------------------------------
/slime_plugins/mbridge/glm4.py:
--------------------------------------------------------------------------------
  1 | from mbridge.core import LLMBridge, register_model
  2 | from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
  3 | 
  4 | 
  5 | @register_model("glm4")
  6 | class GLM4Bridge(LLMBridge):
  7 |     """
  8 |     Bridge implementation for GLM4 models.
  9 | 
 10 |     This class extends LLMBridge to provide specific configurations and
 11 |     optimizations for GLM4 models, handling the conversion between
 12 |     Hugging Face GLM4 format and Megatron-Core.
 13 |     """
 14 | 
 15 |     _DIRECT_MAPPING = {
 16 |         "embedding.word_embeddings.weight": "model.embed_tokens.weight",
 17 |         "decoder.final_layernorm.weight": "model.norm.weight",
 18 |         "output_layer.weight": "lm_head.weight",
 19 |     }
 20 |     _ATTENTION_MAPPING = {
 21 |         "self_attention.linear_proj.weight": [
 22 |             "model.layers.{layer_number}.self_attn.o_proj.weight"
 23 |         ],
 24 |         "self_attention.linear_qkv.layer_norm_weight": [
 25 |             "model.layers.{layer_number}.input_layernorm.weight"
 26 |         ],
 27 |         "self_attention.q_layernorm.weight": [
 28 |             "model.layers.{layer_number}.self_attn.q_norm.weight"
 29 |         ],
 30 |         "self_attention.k_layernorm.weight": [
 31 |             "model.layers.{layer_number}.self_attn.k_norm.weight"
 32 |         ],
 33 |         "self_attention.linear_qkv.weight": [
 34 |             "model.layers.{layer_number}.self_attn.q_proj.weight",
 35 |             "model.layers.{layer_number}.self_attn.k_proj.weight",
 36 |             "model.layers.{layer_number}.self_attn.v_proj.weight",
 37 |         ],
 38 |         "self_attention.linear_qkv.bias": [
 39 |             "model.layers.{layer_number}.self_attn.q_proj.bias",
 40 |             "model.layers.{layer_number}.self_attn.k_proj.bias",
 41 |             "model.layers.{layer_number}.self_attn.v_proj.bias",
 42 |         ],
 43 |     }
 44 |     _MLP_MAPPING = {
 45 |         "mlp.linear_fc1.weight": [
 46 |             "model.layers.{layer_number}.mlp.gate_up_proj.weight",
 47 |         ],
 48 |         "mlp.linear_fc1.layer_norm_weight": [
 49 |             "model.layers.{layer_number}.post_attention_layernorm.weight"
 50 |         ],
 51 |         "mlp.linear_fc2.weight": ["model.layers.{layer_number}.mlp.down_proj.weight"],
 52 |     }
 53 | 
 54 |     def _build_config(self):
 55 |         """
 56 |         Build the configuration for GLM4 models.
 57 | 
 58 |         Configures GLM4-specific parameters such as QKV bias settings and
 59 |         layer normalization options.
 60 | 
 61 |         Returns:
 62 |             TransformerConfig: Configuration object for GLM4 models
 63 |         """
 64 |         return self._build_base_config(
 65 |             # glm4: QKV bias, no QK layernorm, post-attn/post-MLP (sandwich) norms
 66 |             add_qkv_bias=True,
 67 |             qk_layernorm=False,
 68 |             post_mlp_layernorm=True,
 69 |             post_self_attn_layernorm=True,
 70 |             rotary_interleaved=True,
 71 |         )
 72 | 
 73 |     def _get_transformer_layer_spec(self):
 74 |         """
 75 |         Gets the transformer layer specification.
 76 | 
 77 |         Creates and returns a specification for the transformer layers based on
 78 |         the current configuration.
 79 | 
 80 |         Returns:
 81 |             TransformerLayerSpec: Specification for transformer layers
 82 | 
 83 |         Raises:
 84 |             AssertionError: If normalization is not RMSNorm
 85 |         """
 86 |         transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
 87 |             post_self_attn_layernorm=True,
 88 |             post_mlp_layernorm=True,
 89 |         )
 90 |         return transformer_layer_spec
 91 | 
 92 |     def _weight_name_mapping_mcore_to_hf(self, mcore_weights_name: str) -> list[str]:
 93 |         """
 94 |         Map MCore weight names to Hugging Face weight names.
 95 | 
 96 |         Args:
 97 |             mcore_weights_name: MCore weight name
 98 | 
 99 |         Returns:
100 |             list: Corresponding Hugging Face weight names
101 |         """
102 |         assert (
103 |             "_extra_state" not in mcore_weights_name
104 |         ), "extra_state should not be loaded"
105 | 
106 |         if mcore_weights_name in self._DIRECT_MAPPING:
107 |             return [self._DIRECT_MAPPING[mcore_weights_name]]
108 | 
109 |         if "post_self_attn_layernorm" in mcore_weights_name:
110 |             layer_number = mcore_weights_name.split(".")[2]
111 |             return [
112 |                 f"model.layers.{layer_number}.post_self_attn_layernorm.weight"
113 |             ]
114 |         elif "post_mlp_layernorm" in mcore_weights_name:
115 |             layer_number = mcore_weights_name.split(".")[2]
116 |             return [
117 |                 f"model.layers.{layer_number}.post_mlp_layernorm.weight"
118 |             ]
119 |         elif "self_attention" in mcore_weights_name:
120 |             return self._weight_name_mapping_attention(mcore_weights_name)
121 |         elif "mlp" in mcore_weights_name:
122 |             return self._weight_name_mapping_mlp(mcore_weights_name)
123 |         else:
124 |             raise NotImplementedError(
125 |                 f"Unsupported parameter name: {mcore_weights_name}"
126 |             )
127 | 


--------------------------------------------------------------------------------
/slime_plugins/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/1097adeb5f84f75a0494adf5357bfec7f8b2f5df/slime_plugins/models/__init__.py


--------------------------------------------------------------------------------
/slime_plugins/models/glm4.py:
--------------------------------------------------------------------------------
 1 | from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 2 | 
 3 | 
 4 | def get_glm_spec(args):
 5 |     transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
 6 |         args.num_experts,
 7 |         args.moe_grouped_gemm,
 8 |         args.qk_layernorm,
 9 |         args.multi_latent_attention,
10 |         args.moe_use_legacy_grouped_gemm,
11 |         post_self_attn_layernorm=args.post_self_attn_layernorm,
12 |         post_mlp_layernorm=args.post_mlp_layernorm,
13 |     )
14 |     return transformer_layer_spec
15 | 


--------------------------------------------------------------------------------
/slime_plugins/rollout_buffer/generator/__init__.py:
--------------------------------------------------------------------------------
 1 | from .base_generator import BaseGenerator, query_single_turn
 2 | from .reward_utils import get_rule_based_math_reward
 3 | from .utils.arguments import add_arguments
 4 | 
 5 | __all__ = [
 6 |     "BaseGenerator",
 7 |     "query_single_turn",
 8 |     "get_rule_based_math_reward",
 9 |     "add_arguments",
10 | ]
11 | 


--------------------------------------------------------------------------------
/slime_plugins/rollout_buffer/generator/reward_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .math_utils import get_rule_based_math_reward
2 | 
3 | __all__ = ["get_rule_based_math_reward"]
4 | 


--------------------------------------------------------------------------------
/slime_plugins/rollout_buffer/generator/utils/arguments.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | 
 3 | 
 4 | def add_arguments(add_task_arguments=None):
 5 |     parser = argparse.ArgumentParser()
 6 |     parser.add_argument("--task_type", type=str, default="math")
 7 |     parser.add_argument("--input_file", type=str, default="None")
 8 |     parser.add_argument("--num_epoch", type=int, default=1)
 9 |     parser.add_argument("--num_repeat_per_sample", type=int, default=4)
10 |     parser.add_argument("--num_process", type=int, default=4)
11 |     parser.add_argument("--remote_engine_url", type=str, default="http://0.0.0.0:8000/v1")
12 |     parser.add_argument("--remote_buffer_url", type=str, default="http://localhost:8888")
13 |     parser.add_argument("--max_tokens", type=int, default=4096)
14 |     parser.add_argument("--num_repeats", type=int, default=20)
15 | 
16 |     if add_task_arguments is not None:
17 |         parser = add_task_arguments(parser)
18 | 
19 |     args = parser.parse_args()
20 |     return args
21 | 


--------------------------------------------------------------------------------
/slime_plugins/rollout_buffer/tools/assign_instance_id.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import json
 3 | from pathlib import Path
 4 | 
 5 | 
 6 | def main(input_path, task_type="math", output_path=None):
 7 |     input_path = Path(input_path)
 8 |     if output_path is None:
 9 |         output_path = str(input_path).replace(".jsonl", "_processed.jsonl")
10 |     used_ids = set()
11 |     processed = []
12 | 
13 |     # First pass: load all lines and collect existing instance_ids
14 |     with open(input_path, "r", encoding="utf-8") as f:
15 |         for line in f:
16 |             item = json.loads(line)
17 |             if "instance_id" in item:
18 |                 used_ids.add(item["instance_id"])
19 |             processed.append(item)
20 | 
21 |     # Second pass: assign missing instance_ids
22 |     counter = 0
23 |     for item in processed:
24 |         if "instance_id" not in item:
25 |             # Find unused id
26 |             while True:
27 |                 candidate_id = f"{task_type}_{counter}"
28 |                 counter += 1
29 |                 if candidate_id not in used_ids:
30 |                     item["instance_id"] = candidate_id
31 |                     used_ids.add(candidate_id)
32 |                     break
33 | 
34 |     # Save to new jsonl file
35 |     with open(output_path, "w", encoding="utf-8") as f:
36 |         for item in processed:
37 |             f.write(json.dumps(item, ensure_ascii=False) + "\n")
38 | 
39 |     print(f"✅ Processed {len(processed)} items. Saved to {output_path}")
40 | 
41 | 
42 | if __name__ == "__main__":
43 |     parser = argparse.ArgumentParser()
44 |     parser.add_argument("--input_path", type=str, help="Path to input JSONL file")
45 |     parser.add_argument(
46 |         "--task_type",
47 |         type=str,
48 |         default="math",
49 |         help="Task type prefix for new instance_id",
50 |     )
51 |     parser.add_argument("--output_path", type=str, default=None, help="Optional path to output file")
52 |     args = parser.parse_args()
53 | 
54 |     main(args.input_path, args.task_type, args.output_path)
55 | 


--------------------------------------------------------------------------------
/slime_plugins/rollout_buffer/tools/visualizer.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import threading
  3 | import time
  4 | 
  5 | import matplotlib.pyplot as plt
  6 | 
  7 | 
  8 | class BufferStatsVisualizer:
  9 |     def __init__(self, time_window=60):
 10 |         """
 11 |         Initialize the buffer statistics visualizer
 12 | 
 13 |         Args:
 14 |             time_window (int): Time window in seconds for each data point (default: 60s)
 15 |         """
 16 |         self.time_window = time_window
 17 |         self.data_points = []  # List to store data points
 18 |         self.timestamps = []  # List to store timestamps
 19 |         self.start_time = time.time()
 20 |         self.last_window_start = self.start_time
 21 |         self.window_count = 0  # Counter for current window
 22 |         self.args_dict = None  # Store args for filename
 23 | 
 24 |         # Initialize the plot
 25 |         plt.ion()  # Enable interactive mode
 26 |         self.fig, self.ax = plt.subplots(figsize=(12, 6))
 27 |         (self.line,) = self.ax.plot([], [], "b-", label="Data Points per Window")
 28 | 
 29 |         # Set up the plot
 30 |         self.ax.set_xlabel("Time (minutes)")
 31 |         self.ax.set_ylabel("Data Points per 60s Window")
 32 |         self.ax.set_title("Buffer Statistics - Data Points per 60s Window")
 33 |         self.ax.grid(True)
 34 |         self.ax.legend()
 35 | 
 36 |         # Start the update thread
 37 |         self.running = True
 38 |         self.update_thread = threading.Thread(target=self._update_plot)
 39 |         self.update_thread.daemon = True
 40 |         self.update_thread.start()
 41 | 
 42 |     def set_args(self, args_dict):
 43 |         """Set the args dictionary for filename generation"""
 44 |         self.args_dict = args_dict
 45 | 
 46 |     def add_data_point(self, _):
 47 |         """Record one data point; the argument is ignored, only the per-window count matters."""
 48 |         current_time = time.time()
 49 |         self.window_count += 1  # Increment counter for current window
 50 | 
 51 |         # Check if we've reached the end of a time window
 52 |         if current_time - self.last_window_start >= self.time_window:
 53 |             # Calculate the time in minutes since start
 54 |             time_in_minutes = (current_time - self.start_time) / 60
 55 | 
 56 |             # Add the data point and timestamp
 57 |             self.data_points.append(self.window_count)
 58 |             self.timestamps.append(time_in_minutes)
 59 | 
 60 |             # Save the plot after adding new data point
 61 |             self.save_plot()
 62 | 
 63 |             # Reset for next window
 64 |             self.last_window_start = current_time
 65 |             self.window_count = 0
 66 | 
 67 |     def _update_plot(self):
 68 |         """Update the plot periodically"""
 69 |         while self.running:
 70 |             if self.data_points:  # Only update if we have data
 71 |                 self.line.set_data(self.timestamps, self.data_points)
 72 |                 self.ax.relim()
 73 |                 self.ax.autoscale_view()
 74 |                 self.fig.canvas.draw()
 75 |                 self.fig.canvas.flush_events()
 76 |             time.sleep(1)  # Update every second
 77 | 
 78 |     def save_plot(self):
 79 |         """Save the current plot to a file"""
 80 |         timestamp = self.start_time
 81 | 
 82 |         # Create filename based on args and timestamp
 83 |         if self.args_dict:
 84 |             # Extract key parameters from args
 85 |             key_params = []
 86 |             for key in ["task_type", "num_repeat_per_sample", "group_size"]:
 87 |                 if key in self.args_dict:
 88 |                     key_params.append(f"{key}_{self.args_dict[key]}")
 89 | 
 90 |             filename = f"buffer_stats_{'_'.join(key_params)}_{timestamp}.png"
 91 |         else:
 92 |             filename = f"buffer_stats_{timestamp}.png"
 93 | 
 94 |         # Create directory if it doesn't exist
 95 |         os.makedirs("buffer_stats", exist_ok=True)
 96 |         filepath = os.path.join("buffer_stats", filename)
 97 | 
 98 |         # Save the plot
 99 |         plt.savefig(filepath, dpi=300, bbox_inches="tight")
100 |         print(f"Plot saved to {filepath}")
101 | 
102 |     def close(self):
103 |         """Close the visualizer and clean up"""
104 |         self.running = False
105 |         if self.update_thread.is_alive():
106 |             self.update_thread.join()
107 |         plt.close(self.fig)
108 | 
109 | 
110 | # Example usage:
111 | if __name__ == "__main__":
112 |     visualizer = BufferStatsVisualizer(time_window=60)
113 |     visualizer.set_args({"task_type": "test", "num_repeat_per_sample": 16})
114 | 
115 |     # Simulate some data
116 |     try:
117 |         for i in range(1000):
118 |             visualizer.add_data_point(1)  # Just increment the counter
119 |             time.sleep(0.1)
120 |     except KeyboardInterrupt:
121 |         pass
122 |     finally:
123 |         visualizer.close()
124 | 
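Note that the class enables matplotlib's interactive mode (plt.ion()) and redraws the figure from a background thread, which assumes a GUI-capable backend. On a headless training node one workaround, not something the module configures itself, is to force the non-interactive Agg backend and rely on the PNGs periodically written to buffer_stats/:

    MPLBACKEND=Agg python slime_plugins/rollout_buffer/tools/visualizer.py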


--------------------------------------------------------------------------------
/tests/test_qwen3_0.6B.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | # for rerunning the task: clean up leftover processes first
  4 | pkill -9 sglang
  5 | sleep 3
  6 | ray stop --force
  7 | pkill -9 ray
  8 | pkill -9 python
  9 | sleep 3
 10 | pkill -9 ray
 11 | pkill -9 python
 12 | 
 13 | set -ex
 14 | 
 15 | # prevent python from buffering stdout/stderr so ray worker logs stream promptly
 16 | export PYTHONUNBUFFERED=1
 17 | 
 18 | SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
 19 | source "${SCRIPT_DIR}/../scripts/models/qwen3-0.6B.sh"
 20 | 
 21 | CKPT_ARGS=(
 22 |    --hf-checkpoint /root/Qwen3-0.6B
 23 |    --ref-load /root/Qwen3-0.6B_torch_dist
 24 | )
 25 | 
 26 | ROLLOUT_ARGS=(
 27 |    --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
 28 |    --input-key prompt
 29 |    --label-key label
 30 |    --apply-chat-template
 31 |    --rollout-shuffle
 32 |    --rm-type deepscaler
 33 |    --num-rollout 3000
 34 |    --rollout-batch-size 32
 35 |    --n-samples-per-prompt 8
 36 |    --rollout-max-response-len 8192
 37 |    --rollout-temperature 0.8
 38 | 
 39 |    --over-sampling-batch-size 64
 40 |    --dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std
 41 |    #--partial-rollout
 42 | 
 43 |    --global-batch-size 256
 44 |    #--balance-data
 45 | )
 46 | 
 47 | EVAL_ARGS=(
 48 |    --eval-interval 20
 49 |    --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
 50 |    --n-samples-per-eval-prompt 1
 51 |    --eval-max-response-len 16384
 52 |    --eval-temperature 0
 53 | )
 54 | 
 55 | PERF_ARGS=(
 56 |    --tensor-model-parallel-size 1
 57 |    --sequence-parallel
 58 |    --pipeline-model-parallel-size 1
 59 |    --context-parallel-size 1
 60 |    --expert-model-parallel-size 1
 61 |    --expert-tensor-parallel-size 1
 62 | 
 63 |    --recompute-granularity full
 64 |    --recompute-method uniform
 65 |    --recompute-num-layers 1
 66 | 
 67 |    # --micro-batch-size 1
 68 |    --use-dynamic-batch-size
 69 |    --max-tokens-per-gpu 9216
 70 | )
 71 | 
 72 | GRPO_ARGS=(
 73 |    --advantage-estimator grpo
 74 |    --use-kl-loss
 75 |    --kl-loss-coef 0.00
 76 |    --kl-loss-type low_var_kl
 77 |    --kl-coef 0.00
 78 |    --entropy-coef 0.00
 79 |    --eps-clip 0.2
 80 |    --eps-clip-high 0.28
 81 | )
 82 | 
 83 | OPTIMIZER_ARGS=(
 84 |    --optimizer adam
 85 |    --lr 1e-6
 86 |    --lr-decay-style constant
 87 |    --weight-decay 0.1
 88 |    --adam-beta1 0.9
 89 |    --adam-beta2 0.98
 90 | )
 91 | 
 92 | WANDB_ARGS=(
 93 |    #--use-wandb
 94 |    --wandb-project slime-test
 95 |    --wandb-group test-qwen-3-0.6B
 96 | )
 97 | 
 98 | SGLANG_ARGS=(
 99 |    --rollout-num-gpus-per-engine 1
100 |    --sglang-mem-fraction-static 0.7
101 | )
102 | 
103 | MISC_ARGS=(
104 |    # default dropout in megatron is 0.1
105 |    --attention-dropout 0.0
106 |    --hidden-dropout 0.0
107 |    # keep grad all-reduce and attention softmax in fp32 for numerical stability
108 |    --accumulate-allreduce-grads-in-fp32
109 |    --attention-softmax-in-fp32
110 |    # comment this out when using a model with MLA
111 |    --attention-backend flash
112 | )
113 | 
114 | # launch the master node of ray in container
115 | ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
116 | 
117 | ray job submit --address="http://127.0.0.1:8265" \
118 |    --runtime-env-json='{
119 |      "env_vars": {
120 |         "PYTHONPATH": "/root/Megatron-LM",
121 |         "CUDA_DEVICE_MAX_CONNECTIONS": "1"
122 |      }
123 |    }' \
124 |    -- python3 train.py \
125 |    --actor-num-nodes 1 \
126 |    --actor-num-gpus-per-node 1 \
127 |    --colocate \
128 |    ${MODEL_ARGS[@]} \
129 |    ${CKPT_ARGS[@]} \
130 |    ${ROLLOUT_ARGS[@]} \
131 |    ${OPTIMIZER_ARGS[@]} \
132 |    ${GRPO_ARGS[@]} \
133 |    ${DISTRIBUTED_ARGS[@]} \
134 |    ${WANDB_ARGS[@]} \
135 |    ${PERF_ARGS[@]} \
136 |    ${EVAL_ARGS[@]} \
137 |    ${SGLANG_ARGS[@]} \
138 |    ${MISC_ARGS[@]}
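
Before running this test, the assets referenced above must exist: the HF checkpoint at /root/Qwen3-0.6B, its torch_dist conversion at /root/Qwen3-0.6B_torch_dist (see tools/convert_hf_to_torch_dist.py below), and the dapo-math-17k / aime-2024 JSONL files at the paths given in ROLLOUT_ARGS and EVAL_ARGS. With those in place the script is self-contained, since it starts its own Ray head node:

    bash tests/test_qwen3_0.6B.sh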


--------------------------------------------------------------------------------
/tools/convert_hf_to_torch_dist.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import shutil
 3 | import inspect
 4 | 
 5 | import torch
 6 | import mbridge
 7 | from mbridge import AutoBridge
 8 | 
 9 | import slime_plugins.mbridge
10 | from megatron.core import parallel_state as mpu
11 | from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
12 | from megatron.training.arguments import parse_args
13 | from megatron.training.checkpointing import get_checkpoint_name, get_checkpoint_tracker_filename, save_checkpoint
14 | from megatron.training.global_vars import set_args
15 | 
16 | 
17 | def init_distributed():
18 |     """Initialize distributed environment"""
19 |     os.environ["RANK"] = "0"
20 |     os.environ["WORLD_SIZE"] = "1"
21 |     os.environ["MASTER_ADDR"] = "localhost"
22 |     os.environ["MASTER_PORT"] = "12355"
23 |     torch.distributed.init_process_group("nccl")
24 |     mpu.initialize_model_parallel(
25 |         tensor_model_parallel_size=1,
26 |         virtual_pipeline_model_parallel_size=None,
27 |         context_parallel_size=1,
28 |         expert_model_parallel_size=1,
29 |     )
30 |     model_parallel_cuda_manual_seed(0)
31 | 
32 | 
33 | def add_conversion_args(parser):
34 |     """Add conversion arguments to the parser"""
35 |     parser.add_argument("--hf-checkpoint", type=str, required=True, help="HuggingFace model path")
36 |     return parser
37 | 
38 | 
39 | def main():
40 |     # Parse command line arguments
41 |     args = parse_args(add_conversion_args)
42 |     args.use_dist_ckpt = args.ckpt_format != "torch"
43 |     set_args(args)
44 | 
45 | 
46 |     # Initialize distributed environment
47 |     init_distributed()
48 | 
49 |     # Load model
50 |     hf_model_path = args.hf_checkpoint
51 |     bridge = AutoBridge.from_pretrained(hf_model_path)
52 |     model = bridge.get_model()
53 |     bridge.load_weights(model, hf_model_path)
54 |     print(f"Model loaded: {hf_model_path}")
55 | 
56 |     save_checkpoint(1, model, None, None, 0)
57 |     # change to release ckpt
58 |     tracker_filename = get_checkpoint_tracker_filename(args.save)
59 |     with open(tracker_filename, "w") as f:
60 |         f.write("release")
61 |     source_dir = get_checkpoint_name(args.save, 1, False, return_base_dir=True)
62 |     target_dir = get_checkpoint_name(args.save, -1, True, return_base_dir=True)
63 |     shutil.move(source_dir, target_dir)
64 | 
65 | 
66 | if __name__ == "__main__":
67 |     main()
68 | 
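Because the tool parses its arguments through Megatron's parse_args, it expects the model's hyperparameter flags in addition to --hf-checkpoint and the standard --save checkpoint flag. A sketch of the conversion that would produce the /root/Qwen3-0.6B_torch_dist checkpoint used by the test above; sourcing MODEL_ARGS from scripts/models/qwen3-0.6B.sh and the PYTHONPATH value are assumptions borrowed from that test script:

    source scripts/models/qwen3-0.6B.sh
    PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
        ${MODEL_ARGS[@]} \
        --hf-checkpoint /root/Qwen3-0.6B \
        --save /root/Qwen3-0.6B_torch_dist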


--------------------------------------------------------------------------------
/tools/convert_to_hf.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.distributed as dist
  3 | from megatron.core import mpu
  4 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
  5 | 
  6 | import slime.backends.megatron_utils as megatron_utils
  7 | from slime.backends.megatron_utils import update_weight_utils
  8 | from slime.utils.arguments import parse_args
  9 | 
 10 | 
 11 | def add_checkpoint_args(parser):
 12 |     parser.add_argument(
 13 |         "--output-dir",
 14 |         type=str,
 15 |         default=None,
 16 |         help="Directory to save the converted HF model.",
 17 |     )
 18 |     parser.add_argument(
 19 |         "--check-same",
 20 |         action="store_true",
 21 |         default=False,
 22 |         help="Check if the converted model is the same as the original model.",
 23 |     )
 24 |     return parser
 25 | 
 26 | 
 27 | def main(args):
 28 |     megatron_utils.init(args)
 29 | 
 30 |     pp_size = mpu.get_pipeline_model_parallel_world_size()
 31 |     ep_size = mpu.get_expert_model_parallel_world_size()
 32 | 
 33 |     is_save_rank = (
 34 |         mpu.get_data_parallel_rank(with_context_parallel=True) == 0 and mpu.get_tensor_model_parallel_rank() == 0
 35 |     )
 36 | 
 37 |     # Setup the model and optimizer
 38 |     args.no_load_optim = True
 39 |     args.no_load_rng = True
 40 |     model, _, _, _ = megatron_utils.initialize_model_and_optimizer(args)
 41 | 
 42 |     hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
 43 |     model_name = type(hf_config).__name__.lower()
 44 | 
 45 |     tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
 46 | 
 47 |     vocab_size = tokenizer.vocab_size if args.vocab_size is None else args.vocab_size
 48 | 
 49 |     param_infos = update_weight_utils.get_param_infos(args, model)
 50 | 
 51 |     state_dict = {}
 52 |     rank = dist.get_rank()
 53 |     for info in param_infos:
 54 |         if dist.get_rank() == info.src_rank:
 55 |             for name_, param_ in update_weight_utils.named_parameters(args, model):
 56 |                 if name_ == info.name:
 57 |                     param = param_
 58 |                     break
 59 |         else:
 60 |             param = torch.empty(info.shape, dtype=info.dtype, device=torch.cuda.current_device())
 61 | 
 62 |         if pp_size > 1:
 63 |             if info.src_rank in dist.get_process_group_ranks(mpu.get_pipeline_model_parallel_group()):
 64 |                 torch.distributed.broadcast(param, src=info.src_rank, group=mpu.get_pipeline_model_parallel_group())
 65 | 
 66 |         # broadcast params across ep ranks
 67 |         if ep_size > 1:
 68 |             if ".experts." in info.name:
 69 |                 src_rank = (
 70 |                     info.src_rank
 71 |                     if info.src_rank in dist.get_process_group_ranks(mpu.get_expert_model_parallel_group())
 72 |                     else rank
 73 |                 )
 74 |                 torch.distributed.broadcast(param, src=src_rank, group=mpu.get_expert_model_parallel_group())
 75 | 
 76 |         for key, value in info.attrs.items():
 77 |             setattr(param, key, value)
 78 | 
 79 |         param = update_weight_utils.all_gather_param(info.name, param)
 80 |         param = update_weight_utils.remove_padding(info.name, param, vocab_size)
 81 |         # only the save rank accumulates the converted tensors into the HF state dict
 82 |         if is_save_rank:
 83 |             converted_named_tensors = update_weight_utils.convert_to_hf(args, model_name, info.name, param)
 84 |             for name, param in converted_named_tensors:
 85 |                 state_dict[name] = param.cpu()
 86 |         del param
 87 | 
 88 |     if is_save_rank:
 89 |         hf_model = AutoModelForCausalLM.from_pretrained(
 90 |             args.hf_checkpoint, torch_dtype="auto", device_map="cpu", trust_remote_code=True
 91 |         )
 92 | 
 93 |         if args.check_same:
 94 |             for name, param in hf_model.named_parameters():
 95 |                 if name in state_dict:
 96 |                     assert (
 97 |                         param.shape == state_dict[name].shape
 98 |                     ), f"Shape mismatch for {name}: {param.shape} vs {state_dict[name].shape}"
 99 |                     assert torch.all(param == state_dict[name]), f"Value mismatch for {name}"
100 |                 else:
101 |                     print(f"Warning: {name} not found in state_dict")
102 | 
103 |         if args.output_dir:
104 |             tokenizer.save_pretrained(args.output_dir)
105 |             print(hf_model.load_state_dict(state_dict, strict=False))
106 |             hf_model.save_pretrained(args.output_dir)
107 | 
108 |     dist.barrier()
109 | 
110 | 
111 | if __name__ == "__main__":
112 |     args = parse_args(add_custom_arguments=add_checkpoint_args)
113 |     main(args)
114 | 
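This tool runs the reverse conversion: it initializes the Megatron model through slime's argument parser, gathers the sharded parameters across PP/EP ranks, maps them to HuggingFace names, and writes an HF checkpoint from the save rank. A sketch of an invocation; the MODEL_ARGS sourcing, the --load flag, and the output path are assumptions, while --hf-checkpoint, --output-dir, and --check-same are the flags defined in this file:

    source scripts/models/qwen3-0.6B.sh
    PYTHONPATH=/root/Megatron-LM python tools/convert_to_hf.py \
        ${MODEL_ARGS[@]} \
        --hf-checkpoint /root/Qwen3-0.6B \
        --load /root/Qwen3-0.6B_torch_dist \
        --output-dir /root/Qwen3-0.6B_hf \
        --check-same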


--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
 1 | import ray
 2 | 
 3 | from slime.ray.placement_group import create_actor_group, create_placement_groups, create_rollout_group
 4 | from slime.utils.arguments import parse_args
 5 | 
 6 | 
 7 | def train(args):
 8 |     # allocate the GPUs
 9 |     pgs = create_placement_groups(args)
10 | 
11 |     actor_model = create_actor_group(args, pgs["actor"])
12 | 
13 |     # create the rollout generator, with sglang engines inside.
14 |     rollout_generator = create_rollout_group(args, pgs["rollout"])
15 | 
16 |     # calculate num_rollout from num_epoch
17 |     num_rollout_per_epoch = None
18 |     if args.num_rollout is None:
19 |         num_rollout_per_epoch = ray.get(rollout_generator.data_buffer.get_num_rollout_per_epoch.remote())
20 |         args.num_rollout = num_rollout_per_epoch * args.num_epoch
21 |     assert args.num_rollout > 0
22 | 
23 |     # sync the initialization (model initialization, checkpoint loading, etc.)
24 |     # Note that we initialize it earlier because megatron ckpt loading can have a very large peak memory usage.
25 |     start_rollout_ids = ray.get(
26 |         actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss)
27 |     )
28 |     assert len(set(start_rollout_ids)) == 1
29 |     if args.start_rollout_id is None:
30 |         args.start_rollout_id = start_rollout_ids[0]
31 | 
32 |     if args.rollout_global_dataset:
33 |         ray.get(rollout_generator.data_buffer.load.remote(args.start_rollout_id - 1))
34 | 
35 |     # initialize the connection for weight update during training
36 |     ray.get(actor_model.async_init_weight_update_connections(rollout_generator))
37 | 
38 |     if args.offload:
39 |         ray.get(rollout_generator.async_onload())
40 | 
41 |     # always update weights first so that sglang serves the weights produced by training.
42 |     ray.get(actor_model.async_update_weights())
43 | 
44 |     # train loop.
45 |     # note that for async training, one can change the position of the sync operation (ray.get).
46 |     for rollout_id in range(args.start_rollout_id, args.num_rollout):
47 |         if args.eval_interval is not None and rollout_id == 0:
48 |             ray.get(rollout_generator.async_generate(rollout_id, evaluation=True))
49 |             ray.get(actor_model.async_eval(rollout_id))
50 | 
51 |         ray.get(rollout_generator.async_generate(rollout_id))
52 | 
53 |         if args.offload:
54 |             ray.get(rollout_generator.async_offload())
55 | 
56 |         ray.get(actor_model.async_train(rollout_id))
57 | 
58 |         if args.save_interval is not None and (
59 |             (rollout_id + 1) % args.save_interval == 0
60 |             or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0)
61 |         ):
62 |             ray.get(actor_model.async_save_model(rollout_id))
63 |             if args.rollout_global_dataset:
64 |                 ray.get(rollout_generator.data_buffer.save.remote(rollout_id))
65 | 
66 |         if args.offload:
67 |             ray.get(actor_model.async_offload())
68 |             ray.get(rollout_generator.async_onload())
69 | 
70 |         ray.get(actor_model.async_update_weights())
71 | 
72 |         if args.eval_interval is not None and (
73 |             (rollout_id + 1) % args.eval_interval == 0
74 |             or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0)
75 |         ):
76 |             ray.get(rollout_generator.async_generate(rollout_id, evaluation=True))
77 |             ray.get(actor_model.async_eval(rollout_id))
78 | 
79 | 
80 | if __name__ == "__main__":
81 |     args = parse_args()
82 |     train(args)
83 | 


--------------------------------------------------------------------------------
/train_async.py:
--------------------------------------------------------------------------------
 1 | import ray
 2 | 
 3 | from slime.ray.placement_group import create_actor_group, create_placement_groups, create_rollout_group
 4 | from slime.utils.arguments import parse_args
 5 | 
 6 | 
 7 | def train(args):
 8 |     assert not args.colocate, "Colocation is not supported for async training."
 9 |     # allocate the GPUs
10 |     pgs = create_placement_groups(args)
11 | 
12 |     actor_model = create_actor_group(args, pgs["actor"])
13 | 
14 |     # create the rollout generator, with sglang engines inside.
15 |     rollout_generator = create_rollout_group(args, pgs["rollout"])
16 | 
17 |     # calculate num_rollout from num_epoch
18 |     num_rollout_per_epoch = None
19 |     if args.num_rollout is None:
20 |         num_rollout_per_epoch = ray.get(rollout_generator.data_buffer.get_num_rollout_per_epoch.remote())
21 |         args.num_rollout = num_rollout_per_epoch * args.num_epoch
22 |     assert args.num_rollout > 0
23 | 
24 |     # sync the initialization (model initialization, checkpoint loading, etc.)
25 |     # Note that we initialize it earlier because megatron ckpt loading can have a very large peak memory usage.
26 |     start_rollout_ids = ray.get(
27 |         actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss)
28 |     )
29 |     assert len(set(start_rollout_ids)) == 1
30 |     if args.start_rollout_id is None:
31 |         args.start_rollout_id = start_rollout_ids[0]
32 | 
33 |     if args.rollout_global_dataset:
34 |         ray.get(rollout_generator.data_buffer.load.remote(args.start_rollout_id - 1))
35 | 
36 |     # initialize the connection for weight update during training
37 |     ray.get(actor_model.async_init_weight_update_connections(rollout_generator))
38 | 
39 |     # always update weights first so that sglang serves the weights produced by training.
40 |     ray.get(actor_model.async_update_weights())
41 | 
42 |     # async train loop.
43 |     generation_handles = rollout_generator.async_generate(args.start_rollout_id)
44 |     for rollout_id in range(args.start_rollout_id, args.num_rollout):
45 |         # Sync the last generation
46 |         ray.get(generation_handles)
47 | 
48 |         # This is a synchronous call to ensure that the rollout data is ready
49 |         actor_model.get_rollout_data(rollout_id)
50 | 
51 |         # Start the next rollout early.
52 |         if rollout_id + 1 < args.num_rollout:
53 |             generation_handles = rollout_generator.async_generate(rollout_id + 1)
54 | 
55 |         ray.get(actor_model.async_train(rollout_id, with_data_fetching=False))
56 | 
57 |         if args.save_interval is not None and (
58 |             (rollout_id + 1) % args.save_interval == 0
59 |             or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0)
60 |         ):
61 |             ray.get(actor_model.async_save_model(rollout_id))
62 |             if args.rollout_global_dataset:
63 |                 ray.get(rollout_generator.data_buffer.save.remote(rollout_id))
64 | 
65 |         if (rollout_id + 1) % args.update_weights_interval == 0:
66 |             # sync generation before updating weights, so weights are never updated in the middle of a generation
67 |             ray.get(generation_handles)
68 |             ray.get(actor_model.async_update_weights())
69 | 
70 |         if args.eval_interval is not None and (
71 |             (rollout_id + 1) % args.eval_interval == 0
72 |             or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0)
73 |         ):
74 |             ray.get(rollout_generator.async_generate(rollout_id, evaluation=True))
75 |             ray.get(actor_model.async_eval(rollout_id))
76 | 
77 | 
78 | if __name__ == "__main__":
79 |     args = parse_args()
80 |     train(args)
81 | 
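train_async.py is launched with the same ray start / ray job submit pattern as train.py (see tests/test_qwen3_0.6B.sh), with two differences visible in the code: --colocate must not be set, and weight synchronization is gated by --update-weights-interval instead of happening after every rollout. A trimmed sketch of the submit line; the GPU-allocation flags for the separate rollout group are omitted, and the argument groups are assumed to be the same bash arrays as in the test script:

    ray job submit --address="http://127.0.0.1:8265" \
       --runtime-env-json='{"env_vars": {"PYTHONPATH": "/root/Megatron-LM", "CUDA_DEVICE_MAX_CONNECTIONS": "1"}}' \
       -- python3 train_async.py \
       --update-weights-interval 1 \
       ${MODEL_ARGS[@]} ${CKPT_ARGS[@]} ${ROLLOUT_ARGS[@]} ${OPTIMIZER_ARGS[@]} \
       ${GRPO_ARGS[@]} ${PERF_ARGS[@]} ${EVAL_ARGS[@]} ${SGLANG_ARGS[@]} ${MISC_ARGS[@]}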


--------------------------------------------------------------------------------