├── slime
├── __init__.py
├── ray
│ ├── __init__.py
│ ├── ray_actor.py
│ └── utils.py
├── rollout
│ ├── __init__.py
│ ├── filter_hub
│ │ ├── __init__.py
│ │ ├── base_types.py
│ │ └── dynamic_sampling_filters.py
│ ├── generate_hub
│ │ ├── __init__.py
│ │ └── benchmarkers.py
│ ├── sleep_rollout.py
│ ├── base_types.py
│ ├── rm_hub
│ │ ├── deepscaler.py
│ │ ├── f1.py
│ │ └── __init__.py
│ └── sft_rollout.py
├── router
│ ├── __init__.py
│ └── middleware_hub
│ │ └── __init__.py
├── backends
│ ├── __init__.py
│ ├── sglang_utils
│ │ └── __init__.py
│ ├── fsdp_utils
│ │ ├── kernels
│ │ │ └── __init__.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ └── qwen3_moe_hf.py
│ │ └── __init__.py
│ └── megatron_utils
│ │ ├── update_weight
│ │ │ ├── __init__.py
│ │ │ ├── hf_weight_iterator_base.py
│ │ │ └── hf_weight_iterator_bridge.py
│ │ ├── megatron_to_hf
│ │ │ ├── processors
│ │ │ │ ├── __init__.py
│ │ │ │ └── padding_remover.py
│ │ │ ├── llama.py
│ │ │ └── mimo.py
│ │ ├── misc_utils.py
│ │ ├── sglang.py
│ │ ├── arguments.py
│ │ ├── __init__.py
│ │ └── checkpoint.py
└── utils
│ ├── debug_utils
│ │ ├── __init__.py
│ │ ├── replay_reward_fn.py
│ │ ├── send_to_sglang.py
│ │ └── display_debug_rollout_data.py
│ ├── external_utils
│ │ └── __init__.py
│ ├── __init__.py
│ ├── ray_utils.py
│ ├── context_utils.py
│ ├── logging_utils.py
│ ├── train_dump_utils.py
│ ├── tracking_utils.py
│ ├── megatron_bridge_utils.py
│ ├── async_utils.py
│ ├── metric_checker.py
│ ├── rocm_checkpoint_writer.py
│ ├── iter_utils.py
│ ├── memory_utils.py
│ ├── train_metric_utils.py
│ ├── tensorboard_utils.py
│ ├── timer.py
│ ├── typer_utils.py
│ ├── fp8_kernel.py
│ └── misc.py
├── examples
├── __init__.py
├── multi_agent
│ ├── __init__.py
│ ├── rollout_with_multi_agents.py
│ ├── README.md
│ └── prompts.py
├── eval
│ ├── __init__.py
│ ├── nemo_skills
│ │ ├── __init__.py
│ │ ├── config
│ │ │ └── local_cluster.yaml
│ │ ├── skills_config.py
│ │ └── skills_client.py
│ ├── scripts
│ │ └── multi_tasks.yaml
│ └── README.md
├── retool
│ ├── requirements.txt
│ ├── rl_data_preprocess.py
│ ├── sft_data_processing.py
│ └── README.md
├── strands-agents
│ ├── requirements.txt
│ └── README.md
├── true_on_policy
│ └── src
│ │ ├── aime.png
│ │ ├── raw_reward.png
│ │ ├── step_time.png
│ │ ├── rollout_time.png
│ │ └── train_rollout_abs_diff.png
├── true_on_policy_vlm
│ ├── diff.png
│ └── README.md
├── geo3k_vlm
│ └── fsdp_vs_megatron.png
├── eval_multi_task
│ ├── requirements_ifbench.txt
│ ├── multi_task.yaml
│ └── README.md
├── formal_math
│ └── single_round
│ │ └── README.md
├── tau-bench
│ ├── sglang_tool_parser.py
│ ├── tau1_mock.py
│ └── README.md
├── train_infer_mismatch_helper
│ └── mis.yaml
├── on_policy_distillation
│ └── on_policy_distillation.py
├── search-r1
│ └── local_dense_retriever
│ │ └── download.py
├── reproducibility
│ └── README.md
└── fully_async
│ └── README.md
├── slime_plugins
├── __init__.py
├── models
│ ├── __init__.py
│ └── glm4.py
├── megatron_bridge
│ └── __init__.py
├── rollout_buffer
│ ├── generator
│ │ └── __init__.py
│ ├── README_zh.md
│ └── README.md
└── mbridge
│ └── __init__.py
├── tests
├── ci
│ ├── github_runner
│ │ ├── .gitignore
│ │ ├── .env.example
│ │ └── docker-compose.yml
│ └── README.md
├── test_fsdp_import.py
├── test_gspo.sh
├── test_chunked_gae.py
└── test_qwen3_0.6B_fsdp_colocated_2xGPU.py
├── docker
├── version.txt
├── README.md
├── justfile
└── amd_patch
│ ├── latest
│ │ └── amd_megatron_fused_kernels_init.patch
│ └── sglv0.5.0rc0
│ │ └── amd_megatron_fused_kernels_init.patch
├── imgs
└── arch.png
├── docs
├── _static
│ ├── image
│ │ ├── logo.ico
│ │ ├── logo.jpg
│ │ └── blogs
│ │ │ └── release_v0.1.0
│ │ │ │ ├── cuda_vmm.png
│ │ │ │ └── overrall.png
│ └── css
│ │ ├── readthedocs.css
│ │ └── custom_log.css
├── requirements.txt
├── build.sh
├── zh
│ ├── advanced
│ │ ├── fault-torlance.md
│ │ ├── speculative-decoding.md
│ │ └── arch-support-beyond-megatron.md
│ ├── index.rst
│ ├── developer_guide
│ │ └── debug.md
│ ├── get_started
│ │ └── qa.md
│ └── examples
│ │ ├── qwen3-4b-base-openhermes.md
│ │ ├── qwen3-30B-A3B.md
│ │ └── qwen3-next-80B-A3B.md
├── serve.sh
├── README.md
├── en
│ ├── advanced
│ │ ├── fault-tolerance.md
│ │ └── speculative-decoding.md
│ └── index.rst
└── build_all.sh
├── scripts
└── models
│ ├── deepseek-v3-5layer.sh
│ ├── deepseek-v3-20layer.sh
│ ├── qwen3-4B-Instruct-2507.sh
│ ├── qwen2.5-0.5B.sh
│ ├── qwen2.5-1.5B.sh
│ ├── qwen2.5-3B.sh
│ ├── qwen3-0.6B.sh
│ ├── qwen3-1.7B.sh
│ ├── qwen2.5-32B.sh
│ ├── qwen2.5-7B.sh
│ ├── qwen3-4B.sh
│ ├── qwen3-8B.sh
│ ├── qwen3-14B.sh
│ ├── qwen3-32B.sh
│ ├── llama3.2-3B-Instruct.sh
│ ├── llama3.2-3B-Instruct-amd.sh
│ ├── mimo-7B-rl.sh
│ ├── llama3.1-8B-Instruct.sh
│ ├── glm4-9B.sh
│ ├── glm4-32B.sh
│ ├── qwen3-30B-A3B.sh
│ ├── glm4.5-106B-A12B.sh
│ ├── qwen3-235B-A22B.sh
│ ├── glm4.5-355B-A32B.sh
│ ├── qwen3-next-80B-A3B.sh
│ ├── kimi-k2.sh
│ ├── kimi-k2-thinking.sh
│ ├── deepseek-v3.sh
│ └── moonlight.sh
├── requirements.txt
├── .github
└── workflows
│ ├── pre-commit.yml
│ ├── generate_github_workflows.py
│ ├── release-docs.yaml
│ └── conda-ci.yml
├── .pre-commit-config.yaml
├── setup.py
├── pyproject.toml
├── README_zh.md
└── train_async.py
/slime/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/ray/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/rollout/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/router/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime_plugins/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/multi_agent/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/backends/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/slime_plugins/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/rollout/filter_hub/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/utils/debug_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/utils/external_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/ci/github_runner/.gitignore:
--------------------------------------------------------------------------------
1 | .env
--------------------------------------------------------------------------------
/docker/version.txt:
--------------------------------------------------------------------------------
1 | nightly-dev-20251222a
--------------------------------------------------------------------------------
/slime/backends/sglang_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/router/middleware_hub/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime_plugins/megatron_bridge/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/backends/fsdp_utils/kernels/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/backends/fsdp_utils/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/backends/megatron_utils/update_weight/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/backends/megatron_utils/megatron_to_hf/processors/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utility package root for Slime."""
2 |
--------------------------------------------------------------------------------
/imgs/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/HEAD/imgs/arch.png
--------------------------------------------------------------------------------
/examples/eval/__init__.py:
--------------------------------------------------------------------------------
1 | """Evaluation helpers and example configs."""
2 |
--------------------------------------------------------------------------------
/examples/eval/nemo_skills/__init__.py:
--------------------------------------------------------------------------------
1 | """NeMo Skills evaluation helpers."""
2 |
--------------------------------------------------------------------------------
/examples/retool/requirements.txt:
--------------------------------------------------------------------------------
1 | jinja2>=3.0.0
2 | psutil>=5.8.0
3 | pytest>=7.0.0
4 |
--------------------------------------------------------------------------------
/examples/strands-agents/requirements.txt:
--------------------------------------------------------------------------------
1 | camel-ai
2 | strands-agents
3 | strands-agents-tools
4 |
--------------------------------------------------------------------------------
/docs/_static/image/logo.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/HEAD/docs/_static/image/logo.ico
--------------------------------------------------------------------------------
/docs/_static/image/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/HEAD/docs/_static/image/logo.jpg
--------------------------------------------------------------------------------
/slime/rollout/generate_hub/__init__.py:
--------------------------------------------------------------------------------
1 | # TODO: maybe move `sglang_rollout::generate` to this folder
2 |
--------------------------------------------------------------------------------
/examples/true_on_policy/src/aime.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/true_on_policy/src/aime.png
--------------------------------------------------------------------------------
/examples/true_on_policy_vlm/diff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/true_on_policy_vlm/diff.png
--------------------------------------------------------------------------------
/scripts/models/deepseek-v3-5layer.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS_NUM_LAYERS=5 source "$(dirname -- "${BASH_SOURCE[0]}")/deepseek-v3.sh"
2 |
--------------------------------------------------------------------------------
/examples/geo3k_vlm/fsdp_vs_megatron.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/geo3k_vlm/fsdp_vs_megatron.png
--------------------------------------------------------------------------------
/scripts/models/deepseek-v3-20layer.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS_NUM_LAYERS=20 source "$(dirname -- "${BASH_SOURCE[0]}")/deepseek-v3.sh"
2 |
--------------------------------------------------------------------------------
/scripts/models/qwen3-4B-Instruct-2507.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS_ROTARY_BASE=5000000 source "$(dirname -- "${BASH_SOURCE[0]}")/qwen3-4B.sh"
--------------------------------------------------------------------------------
/examples/eval_multi_task/requirements_ifbench.txt:
--------------------------------------------------------------------------------
1 | emoji
2 | immutabledict
3 | nltk
4 | numpy==1.26.4
5 | spacy==3.7.4
6 | syllapy
7 |
--------------------------------------------------------------------------------
/examples/true_on_policy/src/raw_reward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/true_on_policy/src/raw_reward.png
--------------------------------------------------------------------------------
/examples/true_on_policy/src/step_time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/true_on_policy/src/step_time.png
--------------------------------------------------------------------------------
/tests/ci/github_runner/.env.example:
--------------------------------------------------------------------------------
1 | GITHUB_RUNNER_URL=https://github.com/slimerl/slime
2 | GITHUB_RUNNER_TOKEN=paste-your-token-here
--------------------------------------------------------------------------------
/examples/true_on_policy/src/rollout_time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/true_on_policy/src/rollout_time.png
--------------------------------------------------------------------------------
/docs/_static/image/blogs/release_v0.1.0/cuda_vmm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/HEAD/docs/_static/image/blogs/release_v0.1.0/cuda_vmm.png
--------------------------------------------------------------------------------
/docs/_static/image/blogs/release_v0.1.0/overrall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/HEAD/docs/_static/image/blogs/release_v0.1.0/overrall.png
--------------------------------------------------------------------------------
/examples/true_on_policy/src/train_rollout_abs_diff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/true_on_policy/src/train_rollout_abs_diff.png
--------------------------------------------------------------------------------
/slime/rollout/filter_hub/base_types.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 |
4 | @dataclass
5 | class DynamicFilterOutput:
6 | keep: bool
7 | reason: str | None = None
8 |
--------------------------------------------------------------------------------
/slime/utils/ray_utils.py:
--------------------------------------------------------------------------------
1 | class Box:
2 | def __init__(self, inner):
3 | self._inner = inner
4 |
5 | @property
6 | def inner(self):
7 | return self._inner
8 |
--------------------------------------------------------------------------------
/slime_plugins/rollout_buffer/generator/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_generator import BaseGenerator, query_single_turn
2 |
3 | __all__ = [
4 | "BaseGenerator",
5 | "query_single_turn",
6 | ]
7 |
--------------------------------------------------------------------------------
/docs/_static/css/readthedocs.css:
--------------------------------------------------------------------------------
1 | table.autosummary td {
2 | width: 50%
3 | }
4 |
5 | img.align-center {
6 | display: block;
7 | margin-left: auto;
8 | margin-right: auto;
9 | }
10 |
--------------------------------------------------------------------------------
/slime/backends/megatron_utils/misc_utils.py:
--------------------------------------------------------------------------------
1 | def strip_param_name_prefix(name: str):
2 | prefix = "module."
3 | while name.startswith(prefix):
4 | name = name.removeprefix(prefix)
5 | return name
6 |
--------------------------------------------------------------------------------
/slime_plugins/mbridge/__init__.py:
--------------------------------------------------------------------------------
1 | from .glm4 import GLM4Bridge
2 | from .glm4moe import GLM4MoEBridge
3 | from .mimo import MimoBridge
4 | from .qwen3_next import Qwen3NextBridge
5 |
6 | __all__ = ["GLM4Bridge", "GLM4MoEBridge", "Qwen3NextBridge", "MimoBridge"]
7 |
--------------------------------------------------------------------------------
/tests/test_fsdp_import.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | def test_fsdp_import():
5 | try:
6 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7 | except ImportError:
8 | pytest.skip("FSDP not available in this environment")
9 | assert FSDP is not None
10 |
--------------------------------------------------------------------------------
/slime/rollout/sleep_rollout.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 |
4 | logger = logging.getLogger(__name__)
5 |
6 |
7 | def sleep(args, rollout_id, data_source, evaluation=False):
8 | count = 0
9 | while True:
10 | time.sleep(3600)
11 | count += 1
12 | logger.info(f"rollout sleep for {count} hours")
13 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | gguf>=0.10.0
2 | ipykernel
3 | ipywidgets
4 | jupyter_client
5 | markdown>=3.4.0
6 | matplotlib
7 | myst-parser
8 | nbconvert
9 | nbsphinx
10 | nbstripout
11 | pandoc
12 | pillow
13 | pydantic
14 | sphinx
15 | sphinx-autobuild
16 | sphinx-book-theme
17 | sphinx-copybutton
18 | sphinx-tabs
19 | sphinxcontrib-mermaid
20 | urllib3<2.0.0
21 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | blobfile
3 | datasets
4 | httpx[http2]
5 | mcp[cli]
6 | memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml
7 | omegaconf
8 | pillow
9 | pylatexenc
10 | pyyaml
11 | qwen_vl_utils # for VLM
12 | ray[default]
13 | ring_flash_attn
14 | sglang-router>=0.2.3
15 | tensorboard
16 | transformers
17 | wandb
18 |
--------------------------------------------------------------------------------
/slime/utils/context_utils.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 |
3 |
4 | def with_defer(deferred_func):
5 | def decorator(fn):
6 | @wraps(fn)
7 | def wrapper(*args, **kwargs):
8 | try:
9 | return fn(*args, **kwargs)
10 | finally:
11 | deferred_func()
12 |
13 | return wrapper
14 |
15 | return decorator
16 |
--------------------------------------------------------------------------------
/docs/build.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
4 | LANG=$1
5 |
6 | # make sure language is only en or zh
7 | if [ "$LANG" != "en" ] && [ "$LANG" != "zh" ]; then
8 | echo "Language must be en or zh"
9 | exit 1
10 | fi
11 |
12 | cd $SCRIPT_DIR
13 | SLIME_DOC_LANG=$LANG sphinx-build -b html -D language=$LANG --conf-dir ./ ./$LANG ./build/$LANG
--------------------------------------------------------------------------------
/slime/ray/ray_actor.py:
--------------------------------------------------------------------------------
1 | from slime.utils.misc import get_current_node_ip, get_free_port
2 |
3 |
4 | class RayActor:
5 | @staticmethod
6 | def _get_current_node_ip_and_free_port(start_port=10000, consecutive=1):
7 | return get_current_node_ip(), get_free_port(start_port=start_port, consecutive=consecutive)
8 |
9 | def get_master_addr_and_port(self):
10 | return self.master_addr, self.master_port
11 |
--------------------------------------------------------------------------------
/scripts/models/qwen2.5-0.5B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 24
4 | --hidden-size 896
5 | --ffn-hidden-size 4864
6 | --num-attention-heads 14
7 | --use-rotary-position-embeddings
8 | --disable-bias-linear
9 | --add-qkv-bias
10 | --normalization "RMSNorm"
11 | --norm-epsilon 1e-6
12 | --rotary-base 1000000
13 | --group-query-attention
14 | --num-query-groups 2
15 | --vocab-size 151936
16 | )
--------------------------------------------------------------------------------
/scripts/models/qwen2.5-1.5B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 28
4 | --hidden-size 1536
5 | --ffn-hidden-size 8960
6 | --num-attention-heads 12
7 | --use-rotary-position-embeddings
8 | --disable-bias-linear
9 | --add-qkv-bias
10 | --normalization "RMSNorm"
11 | --norm-epsilon 1e-6
12 | --rotary-base 10000
13 | --group-query-attention
14 | --num-query-groups 2
15 | --vocab-size 151936
16 | )
--------------------------------------------------------------------------------
/scripts/models/qwen2.5-3B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 36
4 | --hidden-size 2048
5 | --ffn-hidden-size 11008
6 | --num-attention-heads 16
7 | --use-rotary-position-embeddings
8 | --disable-bias-linear
9 | --add-qkv-bias
10 | --normalization "RMSNorm"
11 | --norm-epsilon 1e-6
12 | --rotary-base 1000000
13 | --group-query-attention
14 | --num-query-groups 2
15 | --vocab-size 151936
16 | )
17 |
--------------------------------------------------------------------------------
/scripts/models/qwen3-0.6B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 28
4 | --hidden-size 1024
5 | --ffn-hidden-size 3072
6 | --num-attention-heads 16
7 | --group-query-attention
8 | --num-query-groups 8
9 | --use-rotary-position-embeddings
10 | --disable-bias-linear
11 | --normalization "RMSNorm"
12 | --norm-epsilon 1e-6
13 | --rotary-base 1000000
14 | --vocab-size 151936
15 | --kv-channels 128
16 | --qk-layernorm
17 | )
--------------------------------------------------------------------------------
/scripts/models/qwen3-1.7B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 28
4 | --hidden-size 2048
5 | --ffn-hidden-size 6144
6 | --num-attention-heads 16
7 | --group-query-attention
8 | --num-query-groups 8
9 | --use-rotary-position-embeddings
10 | --disable-bias-linear
11 | --normalization "RMSNorm"
12 | --norm-epsilon 1e-6
13 | --rotary-base 1000000
14 | --vocab-size 151936
15 | --kv-channels 128
16 | --qk-layernorm
17 | )
--------------------------------------------------------------------------------
/scripts/models/qwen2.5-32B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 64
4 | --hidden-size 5120
5 | --ffn-hidden-size 27648
6 | --num-attention-heads 40
7 | --group-query-attention
8 | --num-query-groups 8
9 | --use-rotary-position-embeddings
10 | --disable-bias-linear
11 | --add-qkv-bias
12 | --normalization "RMSNorm"
13 | --norm-epsilon 1e-5
14 | --rotary-base 1000000
15 | --vocab-size 152064
16 | --untie-embeddings-and-output-weights
17 | )
--------------------------------------------------------------------------------
/scripts/models/qwen2.5-7B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 28
4 | --hidden-size 3584
5 | --ffn-hidden-size 18944
6 | --num-attention-heads 28
7 | --group-query-attention
8 | --num-query-groups 4
9 | --use-rotary-position-embeddings
10 | --disable-bias-linear
11 | --add-qkv-bias
12 | --normalization "RMSNorm"
13 | --norm-epsilon 1e-06
14 | --rotary-base 1000000
15 | --vocab-size 152064
16 | --untie-embeddings-and-output-weights
17 | )
--------------------------------------------------------------------------------
/scripts/models/qwen3-4B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 36
4 | --hidden-size 2560
5 | --ffn-hidden-size 9728
6 | --num-attention-heads 32
7 | --group-query-attention
8 | --num-query-groups 8
9 | --use-rotary-position-embeddings
10 | --disable-bias-linear
11 | --normalization "RMSNorm"
12 | --norm-epsilon 1e-6
13 | --rotary-base "${MODEL_ARGS_ROTARY_BASE:-1000000}"
14 | --vocab-size 151936
15 | --kv-channels 128
16 | --qk-layernorm
17 | )
--------------------------------------------------------------------------------
/slime/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | _LOGGER_CONFIGURED = False
4 |
5 |
6 | # ref: SGLang
7 | def configure_logger(prefix: str = ""):
8 | global _LOGGER_CONFIGURED
9 | if _LOGGER_CONFIGURED:
10 | return
11 |
12 | _LOGGER_CONFIGURED = True
13 |
14 | logging.basicConfig(
15 | level=logging.INFO,
16 | format=f"[%(asctime)s{prefix}] %(filename)s:%(lineno)d - %(message)s",
17 | datefmt="%Y-%m-%d %H:%M:%S",
18 | force=True,
19 | )
20 |
--------------------------------------------------------------------------------
/scripts/models/qwen3-8B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 36
4 | --hidden-size 4096
5 | --ffn-hidden-size 12288
6 | --num-attention-heads 32
7 | --group-query-attention
8 | --num-query-groups 8
9 | --use-rotary-position-embeddings
10 | --disable-bias-linear
11 | --normalization "RMSNorm"
12 | --norm-epsilon 1e-6
13 | --rotary-base 1000000
14 | --vocab-size 151936
15 | --kv-channels 128
16 | --qk-layernorm
17 | --untie-embeddings-and-output-weights
18 | )
--------------------------------------------------------------------------------
/scripts/models/qwen3-14B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 40
4 | --hidden-size 5120
5 | --ffn-hidden-size 17408
6 | --num-attention-heads 40
7 | --group-query-attention
8 | --num-query-groups 8
9 | --use-rotary-position-embeddings
10 | --disable-bias-linear
11 | --normalization "RMSNorm"
12 | --norm-epsilon 1e-6
13 | --rotary-base 1000000
14 | --vocab-size 151936
15 | --kv-channels 128
16 | --qk-layernorm
17 | --untie-embeddings-and-output-weights
18 | )
19 |
--------------------------------------------------------------------------------
/scripts/models/qwen3-32B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 64
4 | --hidden-size 5120
5 | --ffn-hidden-size 25600
6 | --num-attention-heads 64
7 | --group-query-attention
8 | --num-query-groups 8
9 | --use-rotary-position-embeddings
10 | --disable-bias-linear
11 | --normalization "RMSNorm"
12 | --norm-epsilon 1e-6
13 | --rotary-base 1000000
14 | --vocab-size 151936
15 | --kv-channels 128
16 | --qk-layernorm
17 | --untie-embeddings-and-output-weights
18 | )
19 |
--------------------------------------------------------------------------------
/scripts/models/llama3.2-3B-Instruct.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 28
4 | --hidden-size 3072
5 | --ffn-hidden-size 8192
6 | --num-attention-heads 24
7 | --group-query-attention
8 | --num-query-groups 8
9 | --max-position-embeddings 131072
10 | --use-rotary-position-embeddings
11 | --disable-bias-linear
12 | --normalization "RMSNorm"
13 | --norm-epsilon 1e-5
14 | --rotary-base 500000
15 | --vocab-size 128256
16 | --kv-channels 128
17 | --use-rope-scaling
18 | --rotary-scaling-factor 32.0
19 | )
--------------------------------------------------------------------------------
/slime/backends/megatron_utils/megatron_to_hf/processors/padding_remover.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from slime.backends.megatron_utils.misc_utils import strip_param_name_prefix
4 |
5 |
6 | def remove_padding(name: str, param: torch.Tensor, vocab_size: int) -> torch.Tensor:
7 | """
8 | Remove vocab padding: param[:vocab_size] for embedding/output layers, else unchanged.
9 | """
10 | if strip_param_name_prefix(name) in {"embedding.word_embeddings.weight", "output_layer.weight"}:
11 | return param[:vocab_size]
12 | return param
13 |
--------------------------------------------------------------------------------
/scripts/models/llama3.2-3B-Instruct-amd.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 28
4 | --hidden-size 3072
5 | --ffn-hidden-size 8192
6 | --num-attention-heads 24
7 | --group-query-attention
8 | --num-query-groups 8
9 | --max-position-embeddings 131072
10 | --use-rotary-position-embeddings
11 | --disable-bias-linear
12 | --normalization "RMSNorm"
13 | --norm-epsilon 1e-5
14 | --rotary-base 500000
15 | --vocab-size 128256
16 | --kv-channels 128
17 | --use-rope-scaling
18 | --rotary-scaling-factor 32.0
19 | )
20 |
--------------------------------------------------------------------------------
/scripts/models/mimo-7B-rl.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 36
4 | --hidden-size 4096
5 | --ffn-hidden-size 11008
6 | --num-attention-heads 32
7 | --group-query-attention
8 | --num-query-groups 8
9 | --use-rotary-position-embeddings
10 | --disable-bias-linear
11 | --add-qkv-bias
12 | --normalization "RMSNorm"
13 | --norm-epsilon 1e-05
14 | --rotary-base 640000
15 | --vocab-size 151680
16 | --untie-embeddings-and-output-weights
17 | --max-position-embeddings 32768
18 | --mtp-num-layers 1
19 | )
20 |
--------------------------------------------------------------------------------
/docs/zh/advanced/fault-torlance.md:
--------------------------------------------------------------------------------
1 | # Fault Tolerance
2 |
3 | To ensure long-term, stable RL training, slime enables a certain level of fault tolerance by default. This section introduces some of the design ideas behind fault tolerance in slime.
4 |
5 | The fault tolerance mechanism can be enabled with `--use-fault-tolerance`.
6 |
7 | ## Rollout Fault Tolerance
8 |
9 | During rollout, slime periodically sends heartbeat requests (`/health_generate`) to all SGLang servers. If a heartbeat times out, that SGLang server is stopped, then restarted after the current rollout round finishes and its parameters are updated correctly.
10 |
11 | - `--rollout-health-check-first-wait`: since some large MoE models need to handle some compilation on their first run, slime waits `rollout_health_check_first_wait` seconds before the first rollout before it starts sending heartbeats; defaults to 300s.
12 | - `--rollout-health-check-interval`: interval between heartbeat checks; defaults to 10s.
13 | - `--rollout-health-check-timeout`: timeout limit for a heartbeat request; defaults to 5s.
14 |
--------------------------------------------------------------------------------
/scripts/models/llama3.1-8B-Instruct.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --swiglu
3 | --num-layers 32
4 | --hidden-size 4096
5 | --ffn-hidden-size 14336
6 | --num-attention-heads 32
7 | --group-query-attention
8 | --num-query-groups 8
9 | --max-position-embeddings 131072
10 | --use-rotary-position-embeddings
11 | --disable-bias-linear
12 | --normalization "RMSNorm"
13 | --norm-epsilon 1e-5
14 | --rotary-base 500000
15 | --vocab-size 128256
16 | --kv-channels 128
17 | --use-rope-scaling
18 | --rotary-scaling-factor 8.0
19 | --untie-embeddings-and-output-weights
20 | )
--------------------------------------------------------------------------------
/slime/rollout/filter_hub/dynamic_sampling_filters.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from slime.rollout.filter_hub.base_types import DynamicFilterOutput
4 | from slime.utils.types import Sample
5 |
6 | __all__ = ["check_reward_nonzero_std"]
7 |
8 |
9 | def check_reward_nonzero_std(args, samples: list[Sample], **kwargs):
10 | rewards = [sample.get_reward_value(args) for sample in samples]
11 | keep = torch.tensor(rewards, dtype=torch.float).std() > 0.0
12 | return DynamicFilterOutput(
13 | keep=keep,
14 | reason=None if keep else f"zero_std_{round(rewards[0], 1)}",
15 | )
16 |
--------------------------------------------------------------------------------
/examples/eval/nemo_skills/config/local_cluster.yaml:
--------------------------------------------------------------------------------
1 | # Minimal cluster config for running `ns eval` directly on the current host
2 | # without spinning up containers or mount checks.
3 |
4 | executor: none
5 |
6 | # Provide stub container entries so pipeline code can reference them even
7 | # though we do not launch actual containers when executor == "none".
8 | containers:
9 | nemo-skills: ""
10 | sglang: ""
11 | trtllm: ""
12 | vllm: ""
13 |
14 | # No mount enforcement in "none" mode, but keep the field for completeness.
15 | mounts: []
16 |
17 | # Optional default env vars propagated to ns jobs. Leave empty unless needed.
18 | env_vars: []
19 |
--------------------------------------------------------------------------------
/scripts/models/glm4-9B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --spec "slime_plugins.models.glm4" "get_glm_spec"
3 | --swiglu
4 | --num-layers 40
5 | --hidden-size 4096
6 | --ffn-hidden-size 13696
7 | --num-attention-heads 32
8 | --group-query-attention
9 | --num-query-groups 2
10 | --use-rotary-position-embeddings
11 | --disable-bias-linear
12 | --add-qkv-bias
13 | --normalization "RMSNorm"
14 | --norm-epsilon 1e-5
15 | --rotary-base 10000
16 | --vocab-size 151552
17 | --post-self-attn-layernorm
18 | --post-mlp-layernorm
19 | --rotary-interleaved
20 | --rotary-percent 0.5
21 | --no-rope-fusion
22 | --untie-embeddings-and-output-weights
23 | )
--------------------------------------------------------------------------------
/slime_plugins/models/glm4.py:
--------------------------------------------------------------------------------
1 | from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
2 |
3 |
4 | def get_glm_spec(args, config, vp_stage):
5 | transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
6 | num_experts=args.num_experts,
7 | moe_grouped_gemm=args.moe_grouped_gemm,
8 | qk_layernorm=args.qk_layernorm,
9 | multi_latent_attention=args.multi_latent_attention,
10 | moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
11 | post_self_attn_layernorm=args.post_self_attn_layernorm,
12 | post_mlp_layernorm=args.post_mlp_layernorm,
13 | )
14 | return transformer_layer_spec
15 |
--------------------------------------------------------------------------------
/scripts/models/glm4-32B.sh:
--------------------------------------------------------------------------------
1 | MODEL_ARGS=(
2 | --spec "slime_plugins.models.glm4" "get_glm_spec"
3 | --swiglu
4 | --num-layers 64
5 | --hidden-size 6144
6 | --ffn-hidden-size 23040
7 | --num-attention-heads 48
8 | --max-position-embeddings 32768
9 | --seq-length 32768
10 | --use-rotary-position-embeddings
11 | --disable-bias-linear
12 | --normalization "RMSNorm"
13 | --norm-epsilon 1e-5
14 | --rotary-base 10000
15 | --group-query-attention
16 | --num-query-groups 8
17 | --vocab-size 151552
18 | --post-self-attn-layernorm
19 | --post-mlp-layernorm
20 | --rotary-interleaved
21 | --rotary-percent 0.5
22 | --no-rope-fusion
23 | --untie-embeddings-and-output-weights
24 | )
--------------------------------------------------------------------------------
/examples/eval_multi_task/multi_task.yaml:
--------------------------------------------------------------------------------
1 | eval:
2 | defaults:
3 | max_response_len: 16384
4 | top_p: 0.7
5 | datasets:
6 | - name: aime
7 | path: /root/aime-2024/aime-2024.jsonl
8 | rm_type: deepscaler
9 | n_samples_per_eval_prompt: 16
10 | - name: gpqa # huggingface-cli download --repo-type dataset zyzshishui0627/gpqa_diamond --local-dir /root/gpqa
11 | path: /root/gpqa/gpqa_eval.jsonl
12 | rm_type: gpqa
13 | n_samples_per_eval_prompt: 2
14 | - name: ifbench # huggingface-cli download --repo-type dataset zyzshishui0627/IFBench --local-dir /root/ifbench
15 | path: /root/ifbench/IFBench_eval.jsonl
16 | rm_type: ifbench
17 | n_samples_per_eval_prompt: 1
18 |
--------------------------------------------------------------------------------
/examples/formal_math/single_round/README.md:
--------------------------------------------------------------------------------
1 | # Usage
2 |
3 | For the minimal demo:
4 |
5 | ```shell
6 | # install dependencies
7 | apt update && apt install -y docker-cli
8 | pip install kimina-client polars
9 |
10 | # prepare data
11 | python examples/formal_math/single_round/prepare_data.py --output-name minimal_demo
12 |
13 | # prepare ray, model, test dataset, etc
14 | # normally you would just use this script, but here we want to demonstrate run_minimal.py, so we skip the ray-submit part
15 | SLIME_SCRIPT_ENABLE_RAY_SUBMIT=0 python examples/formal_math/single_round/run.py
16 |
17 | # run
18 | python examples/formal_math/single_round/run_minimal.py
19 | ```
20 |
21 | The code also supports more complicated cases, e.g.:
22 |
23 | * SFT + RL
24 | * Data filter + RL
25 |
--------------------------------------------------------------------------------
/examples/retool/rl_data_preprocess.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 |
3 | # Load the original dataset
4 | ds = load_dataset("BytedTsinghua-SIA/DAPO-Math-17k", split="train")
5 |
6 |
7 | # Map to extract the ground_truth from the reward_model dict and create a new 'label' field
8 | def transform(example):
9 | return {
10 | "prompt": example["prompt"][0]["content"] if example["prompt"] else None,
11 | "label": example["reward_model"]["ground_truth"],
12 | }
13 |
14 |
15 | ds2 = ds.map(transform, remove_columns=ds.column_names)
16 |
17 | # Optionally, verify the first few entries
18 | print(ds2[0])
19 |
20 | # save to jsonl
21 | ds2.to_json("/root/dapo-math-17k-processed/dapo_math_17k_cleaned.jsonl", orient="records", lines=True)
22 |
--------------------------------------------------------------------------------
/slime/utils/train_dump_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 |
4 | import torch
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | def save_debug_train_data(args, *, rollout_id, rollout_data):
10 | if (path_template := args.save_debug_train_data) is not None:
11 | rank = torch.distributed.get_rank()
12 | path = Path(path_template.format(rollout_id=rollout_id, rank=rank))
13 | logger.info(f"Save debug train data to {path}")
14 | path.parent.mkdir(parents=True, exist_ok=True)
15 | torch.save(
16 | dict(
17 | rollout_id=rollout_id,
18 | rank=rank,
19 | rollout_data=rollout_data,
20 | ),
21 | path,
22 | )
23 |
--------------------------------------------------------------------------------
/slime/utils/tracking_utils.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | from slime.utils.tensorboard_utils import _TensorboardAdapter
3 |
4 | from . import wandb_utils
5 |
6 |
7 | def init_tracking(args, primary: bool = True, **kwargs):
8 | if primary:
9 | wandb_utils.init_wandb_primary(args, **kwargs)
10 | else:
11 | wandb_utils.init_wandb_secondary(args, **kwargs)
12 |
13 |
14 | # TODO further refactor, e.g. put TensorBoard init to the "init" part
15 | def log(args, metrics, step_key: str):
16 | if args.use_wandb:
17 | wandb.log(metrics)
18 |
19 | if args.use_tensorboard:
20 | metrics_except_step = {k: v for k, v in metrics.items() if k != step_key}
21 | _TensorboardAdapter(args).log(data=metrics_except_step, step=metrics[step_key])
22 |
--------------------------------------------------------------------------------
/slime/utils/megatron_bridge_utils.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 |
3 | try:
4 | from megatron.core.utils import unwrap_model
5 | except ImportError:
6 | unwrap_model = None
7 |
8 |
9 | @contextmanager
10 | def patch_megatron_model(model):
11 | unwrapped_model = unwrap_model(model)[0]
12 | model_config = unwrapped_model.config
13 | attribute_was_added = False
14 | if not hasattr(model_config, "share_embeddings_and_output_weights"):
15 | model_config.share_embeddings_and_output_weights = unwrapped_model.share_embeddings_and_output_weights
16 | attribute_was_added = True
17 |
18 | try:
19 | yield
20 | finally:
21 | if attribute_was_added:
22 | delattr(model_config, "share_embeddings_and_output_weights")
23 |
--------------------------------------------------------------------------------
/slime/rollout/base_types.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any
3 |
4 | from slime.utils.types import Sample
5 |
6 |
7 | @dataclass
8 | class RolloutFnTrainOutput:
9 | samples: list[list[Sample]]
10 | metrics: dict[str, Any] = None
11 |
12 |
13 | @dataclass
14 | class RolloutFnEvalOutput:
15 | data: dict[str, dict[str, Any]]
16 | metrics: dict[str, Any] = None
17 |
18 |
19 | def call_rollout_fn(fn, *args, evaluation: bool, **kwargs):
20 | output = fn(*args, **kwargs, evaluation=evaluation)
21 |
22 | # compatibility for legacy version
23 | if not isinstance(output, (RolloutFnTrainOutput, RolloutFnEvalOutput)):
24 | output = RolloutFnEvalOutput(data=output) if evaluation else RolloutFnTrainOutput(samples=output)
25 |
26 | return output
27 |
--------------------------------------------------------------------------------
/examples/eval_multi_task/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Task Evaluation Example
2 |
3 | ## Configuring `multi_task.yaml`
4 | - `eval.defaults` defines inference parameters shared by every dataset entry. Override them inside an individual dataset block if needed (see the sketch below).
5 | - `eval.datasets` enumerates the datasets to evaluate. Each entry should specify:
6 | - `name`: a short identifier that appears in logs and dashboards.
7 | - `path`: the path to the dataset JSONL file.
8 | - `rm_type`: which reward function to use for scoring.
9 | - `n_samples_per_eval_prompt`: how many candidate completions to generate per prompt.
10 |
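11 | Below is a minimal, hypothetical sketch of how the `defaults` block can be merged into each dataset entry. It is not slime's actual loader; it only assumes PyYAML (already listed in `requirements.txt`) and the config file shipped with this example.
12 |
13 | ```python
14 | import yaml
15 |
16 | # Load the eval section and apply `defaults` to every dataset entry,
17 | # letting per-dataset keys override the shared values.
18 | with open("examples/eval_multi_task/multi_task.yaml") as f:
19 |     cfg = yaml.safe_load(f)["eval"]
20 |
21 | datasets = [{**cfg.get("defaults", {}), **entry} for entry in cfg["datasets"]]
22 | for entry in datasets:
23 |     print(entry["name"], entry["rm_type"], entry["n_samples_per_eval_prompt"], entry["top_p"])
24 | ```
25 |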
26 | ## IFBench Notes
27 | - When `ifbench` is used, `slime/rollout/rm_hub/ifbench.py` automatically prepares the scoring environment, so no additional manual setup is required beyond providing the dataset path.
28 |
--------------------------------------------------------------------------------
/examples/retool/sft_data_processing.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 |
3 | ds = load_dataset("JoeYing/ReTool-SFT")["train"]
4 |
5 |
6 | def convert(sample):
7 | conversations = sample["messages"]
8 |
9 | def convert_role(role):
10 | if role == "user":
11 | return "user"
12 | elif role == "assistant":
13 | return "assistant"
14 | elif role == "system":
15 | return "system"
16 | else:
17 | raise ValueError(f"Unknown role: {role}")
18 |
19 | messages = [
20 | {
21 | "role": convert_role(turn["role"]),
22 | "content": turn["content"],
23 | }
24 | for turn in conversations
25 | ]
26 |
27 | return {"messages": messages}
28 |
29 |
30 | ds = ds.map(convert)
31 | ds.to_parquet("./data/retool/ReTool-SFT.parquet")
32 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | # Docker release rule
2 |
3 | We publish two kinds of Docker images:
4 | 1. Stable versions, which are based on official sglang releases. We store the patches for those versions.
5 | 2. Latest versions, which align with `lmsysorg/sglang:latest`.
6 |
7 | The current stable version is:
8 | - sglang nightly-dev-20251208-5e2cda61 (5e2cda6158e670e64b926a9985d65826c537ac82), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133)
9 |
10 | Previous versions:
11 | - sglang v0.5.5.post1 (303cc957e62384044dfa8e52d7d8af8abe12f0ac), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133)
12 |
13 | The command to build:
14 |
15 | ```bash
16 | just release
17 | ```
18 |
19 | Before each update, we test the following models on 64xH100:
20 |
21 | - Qwen3-4B sync
22 | - Qwen3-4B async
23 | - Qwen3-30B-A3B sync
24 | - Qwen3-30B-A3B fp8 sync
25 | - GLM-4.5-355B-A32B sync
26 |
--------------------------------------------------------------------------------
/docs/serve.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
4 | LANG="${1:-all}"
5 | PORT="${PORT:-8000}"
6 |
7 | cd "$SCRIPT_DIR"
8 |
9 | if [ "$LANG" = "all" ]; then
10 | # Expect both builds present
11 | if [ ! -d build/en ] || [ ! -d build/zh ]; then
12 | echo "[serve] Missing build/en or build/zh. Run ./build_all.sh first." >&2
13 | fi
14 | echo "[serve] Serving multi-language docs root on http://localhost:$PORT (en/, zh/)"
15 | python -m http.server -d ./build "$PORT"
16 | exit $?
17 | fi
18 |
19 | if [ "$LANG" != "en" ] && [ "$LANG" != "zh" ]; then
20 | echo "Usage: $0 [en|zh|all]" >&2
21 | exit 1
22 | fi
23 |
24 | if [ ! -d "build/$LANG" ]; then
25 | echo "[serve] build/$LANG not found. Run ./build.sh $LANG first." >&2
26 | exit 1
27 | fi
28 | echo "[serve] Serving $LANG docs on http://localhost:$PORT"
29 | python -m http.server -d ./build/$LANG "$PORT"
--------------------------------------------------------------------------------
/slime/rollout/generate_hub/benchmarkers.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | from argparse import Namespace
4 | from copy import deepcopy
5 | from typing import Any
6 |
7 | from slime.rollout.sglang_rollout import generate as _generate_base
8 | from slime.utils.types import Sample
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | async def generate_with_random_osl(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample:
14 | # TODO: make it configurable after we have an enhanced arg parser
15 | min_osl = 32 * 1024
16 | max_osl = 64 * 1024
17 |
18 | modified_sampling_params = deepcopy(sampling_params)
19 | modified_sampling_params["ignore_eos"] = True
20 | modified_sampling_params["max_new_tokens"] = random.randrange(min_osl, max_osl)
21 |
22 | ans = await _generate_base(args, sample, modified_sampling_params)
23 |
24 | logger.info(f"generate_with_random_osl {ans.response_length=}")
25 | return ans
26 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # slime Documentation
2 |
3 | We recommend that new contributors start by writing documentation, which helps you quickly understand the slime codebase.
4 | Most documentation files are located under the `docs/` folder.
5 |
6 | ## Docs Workflow
7 |
8 | ### Install Dependencies
9 |
10 | ```bash
11 | apt-get update && apt-get install -y pandoc parallel retry
12 | pip install -r requirements.txt
13 | ```
14 |
15 | ### Update Documentation
16 |
17 | You can update the documentation in the `en` and `zh` folders by adding Markdown or Jupyter Notebook files to the appropriate subdirectories. If you create new files, make sure to update `index.rst` (or any other relevant `.rst` files) accordingly.
18 |
19 | ## Build and Render
20 |
21 | ```bash
22 | # build english version
23 | bash ./build.sh en
24 | bash ./serve.sh en
25 |
26 | # build chinese version
27 | bash ./build.sh zh
28 | bash ./serve.sh zh
29 | ```
30 |
31 | You can then visit `http://localhost:8000` to view the documentation.
--------------------------------------------------------------------------------
/slime/utils/async_utils.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import threading
3 |
4 | __all__ = ["get_async_loop", "run"]
5 |
6 |
7 | # Create a background event loop thread
8 | class AsyncLoopThread:
9 | def __init__(self):
10 | self.loop = asyncio.new_event_loop()
11 | self._thread = threading.Thread(target=self._start_loop, daemon=True)
12 | self._thread.start()
13 |
14 | def _start_loop(self):
15 | asyncio.set_event_loop(self.loop)
16 | self.loop.run_forever()
17 |
18 | def run(self, coro):
19 | # Schedule a coroutine onto the loop and block until it's done
20 | return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
21 |
22 |
23 | # Create one global instance
24 | async_loop = None
25 |
26 |
27 | def get_async_loop():
28 | global async_loop
29 | if async_loop is None:
30 | async_loop = AsyncLoopThread()
31 | return async_loop
32 |
33 |
34 | def run(coro):
35 | """Run a coroutine in the background event loop."""
36 | return get_async_loop().run(coro)
37 |
--------------------------------------------------------------------------------
/examples/true_on_policy_vlm/README.md:
--------------------------------------------------------------------------------
1 | # True On-Policy between Training and Inference for VLM
2 |
3 | This example demonstrates true on-policy training with the Qwen3-VL dense model on the FSDP backend. The core concepts and expected observations are the same as [true_on_policy](../true_on_policy/README.md).
4 |
5 |
6 |
7 |
8 |
9 | ## Usage
10 |
11 | ```bash
12 | SLIME_SCRIPT_NUM_GPUS=8 python examples/true_on_policy_vlm/run_simple.py
13 | ```
14 |
15 | ## How it is Implemented
16 |
17 | For the text backbone, please refer to [true_on_policy for the text-only model](../true_on_policy/README.md).
18 |
19 | For the VLM, we only need to ensure that the image encoder behaves as expected. Please refer to [SGLang#14636](https://github.com/sgl-project/sglang/pull/14636). We need to align numeric operation details between the two systems, so that the ViT forward pass matches the behavior in both SGLang and transformers.
20 |
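21 | As a rough illustration of the kind of check this alignment enables (hypothetical code, not part of this example; it assumes you have already constructed the two ViT modules and a matching `pixel_values` batch):
22 |
23 | ```python
24 | import torch
25 |
26 | @torch.no_grad()
27 | def max_abs_diff(vit_a: torch.nn.Module, vit_b: torch.nn.Module, pixel_values: torch.Tensor) -> float:
28 |     """Feed the same image batch to both ViT implementations and report the largest elementwise gap."""
29 |     return (vit_a(pixel_values).float() - vit_b(pixel_values).float()).abs().max().item()
30 | ```
31 |
32 | When the numeric details (kernels, reduction order, dtypes) are aligned, this gap should stay at or very near zero.
33 |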
34 | ## Notes
35 |
36 | It is expected that the true-on-policy version is slower.
--------------------------------------------------------------------------------
/slime/backends/megatron_utils/update_weight/hf_weight_iterator_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
4 | class HfWeightIteratorBase(ABC):
5 | @staticmethod
6 | def create(args, model, **kwargs):
7 | from .hf_weight_iterator_bridge import HfWeightIteratorBridge
8 | from .hf_weight_iterator_direct import HfWeightIteratorDirect
9 |
10 | c = {
11 | "raw": HfWeightIteratorDirect,
12 | "bridge": HfWeightIteratorBridge,
13 | }[args.megatron_to_hf_mode]
14 |
15 | return c(args, model, **kwargs)
16 |
17 | def __init__(self, args, model, model_name, quantization_config):
18 | self.args = args
19 | self.model = model
20 | self.model_name = model_name
21 | self.quantization_config = quantization_config
22 |
23 | @abstractmethod
24 | def get_hf_weight_chunks(self, megatron_local_weights):
25 | """
26 | Mental model of the API:
27 | megatron_model.to_hf_magically().named_parameters()
28 | """
29 | raise NotImplementedError
30 |
--------------------------------------------------------------------------------
/slime/utils/metric_checker.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logger = logging.getLogger(__name__)
4 |
5 |
6 | class MetricChecker:
7 | @staticmethod
8 | def maybe_create(args):
9 | if args.ci_test and (args.ci_metric_checker_key is not None):
10 | return MetricChecker(args)
11 | return None
12 |
13 | def __init__(self, args):
14 | self.args = args
15 | self._exists_check_success = False
16 |
17 | def on_eval(self, metrics: dict[str, float]):
18 | actual_value = metrics.get(self.args.ci_metric_checker_key)
19 | assert actual_value is not None, f"{metrics=} {self.args.ci_metric_checker_key=}"
20 |
21 | check_success = actual_value >= self.args.ci_metric_checker_threshold
22 | logger.info(f"[MetricChecker] {check_success=} {actual_value=} {self.args.ci_metric_checker_threshold=}")
23 |
24 | self._exists_check_success |= check_success
25 |
26 | def dispose(self):
27 | assert self._exists_check_success, "[MetricChecker] accuracy check failed"
28 | logger.info("[MetricChecker] pass dispose check")
29 |
--------------------------------------------------------------------------------
/docs/en/advanced/fault-tolerance.md:
--------------------------------------------------------------------------------
1 | # Fault Tolerance
2 |
3 | To ensure long-term, stable RL training, slime enables a certain level of fault tolerance by default. This section introduces the design philosophy behind fault tolerance in slime.
4 |
5 | To enable the fault tolerance function in slime, please set `--use-fault-tolerance`.
6 |
7 | ## Rollout Fault Tolerance
8 |
9 | During the rollout process, slime periodically sends heartbeat requests (`/health_generate`) to all SGLang servers. If a heartbeat times out, that SGLang server will be stopped. After the current rollout round is complete, the server will be restarted and its parameters will be correctly updated.
10 |
11 | - `--rollout-health-check-first-wait`: Since some large MoE models require compilation on their first run, slime will wait for `rollout_health_check_first_wait` seconds before the first rollout to start sending heartbeats. Defaults to 300s.
12 | - `--rollout-health-check-interval`: The interval between heartbeat checks. Defaults to 10s.
13 | - `--rollout-health-check-timeout`: The timeout limit for a heartbeat request. Defaults to 5s.
14 |
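15 | As a rough sketch of the heartbeat loop described above (this is not slime's actual implementation; the helper name and server handling are illustrative, and `httpx` from `requirements.txt` is assumed):
16 |
17 | ```python
18 | import time
19 |
20 | import httpx
21 |
22 | def watch_server(base_url: str, first_wait: float = 300.0, interval: float = 10.0, timeout: float = 5.0) -> None:
23 |     """Ping `/health_generate` periodically; report the server as unhealthy on the first failed heartbeat."""
24 |     time.sleep(first_wait)  # let large MoE models finish their first-run compilation
25 |     while True:
26 |         try:
27 |             httpx.get(f"{base_url}/health_generate", timeout=timeout).raise_for_status()
28 |         except httpx.HTTPError:
29 |             print(f"heartbeat failed for {base_url}: stop it now, restart it after this rollout finishes")
30 |             return
31 |         time.sleep(interval)
32 | ```
33 |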
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: pre-commit
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | types: [opened, synchronize, reopened, ready_for_review]
8 |
9 | permissions:
10 | contents: read
11 |
12 | jobs:
13 | run-pre-commit:
14 | name: Run pre-commit
15 | runs-on: ubuntu-latest
16 | steps:
17 | - name: Checkout repository
18 | uses: actions/checkout@v4
19 | with:
20 | fetch-depth: 0
21 |
22 | - name: Set up Python
23 | uses: actions/setup-python@v5
24 | with:
25 | python-version: '3.10'
26 | cache: 'pip'
27 |
28 | - name: Install pre-commit
29 | run: pip install --upgrade pip pre-commit
30 |
31 | - name: Cache pre-commit environments
32 | uses: actions/cache@v4
33 | with:
34 | path: ~/.cache/pre-commit
35 | key: pre-commit-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}
36 | restore-keys: |
37 | pre-commit-${{ runner.os }}-
38 |
39 | - name: Run pre-commit on all files
40 | run: pre-commit run --all-files --show-diff-on-failure --color=always
41 |
42 |
--------------------------------------------------------------------------------
/slime/backends/fsdp_utils/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | try:
5 |     # torch_memory_saver is the dependency named in the warning below.
6 |     import torch_memory_saver  # noqa: F401
7 |
8 |     _TORCH_MEMORY_SAVER_AVAILABLE = True
9 | except ImportError:
10 |     logging.warning("torch_memory_saver is not installed, refer to : https://github.com/fzyzcjy/torch_memory_saver")
11 |     _TORCH_MEMORY_SAVER_AVAILABLE = False
12 |
13 | try:
14 |     # Probe for FSDP2 support; fully_shard is assumed here to be the gating dependency.
15 |     from torch.distributed.fsdp import fully_shard  # noqa: F401
16 |
17 |     _FSDP_AVAILABLE = True
18 | except ImportError as e:
19 |     logging.warning(f"FSDP backend dependencies not available: {e}")
20 |     _FSDP_AVAILABLE = False
21 |
22 | if _FSDP_AVAILABLE:
23 |     from .actor import FSDPTrainRayActor
24 |     from .arguments import load_fsdp_args
25 | else:
26 |
27 |     def _raise_import_error(*args, **kwargs):
28 |         raise ImportError(
29 |             "FSDP backend is not available. "
30 |             "Please ensure PyTorch with FSDP2 support is installed. "
31 |             "For installation instructions, refer to: https://pytorch.org/docs/stable/distributed.fsdp.fully_shard.html"
32 |         )
33 |
34 |     FSDPTrainRayActor = _raise_import_error
35 |     load_fsdp_args = _raise_import_error
36 |
37 | __all__ = ["load_fsdp_args", "FSDPTrainRayActor"]
38 |
39 | logging.getLogger().setLevel(logging.WARNING)
40 |
--------------------------------------------------------------------------------
/slime/backends/megatron_utils/sglang.py:
--------------------------------------------------------------------------------
1 | # the file to manage all sglang deps in the megatron actor
2 | try:
3 | from sglang.srt.layers.quantization.fp8_utils import quant_weight_ue8m0, transform_scale_ue8m0
4 | from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0
5 | except ImportError:
6 | quant_weight_ue8m0 = None
7 | transform_scale_ue8m0 = None
8 | should_deepgemm_weight_requant_ue8m0 = None
9 |
10 | try:
11 | from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
12 | except ImportError:
13 | from sglang.srt.patch_torch import monkey_patch_torch_reductions
14 |
15 |
16 | from sglang.srt.utils import MultiprocessingSerializer
17 |
18 |
19 | try:
20 | from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket # type: ignore[import]
21 | except ImportError:
22 | from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import]
23 |
24 | __all__ = [
25 | "quant_weight_ue8m0",
26 | "transform_scale_ue8m0",
27 | "should_deepgemm_weight_requant_ue8m0",
28 | "monkey_patch_torch_reductions",
29 | "MultiprocessingSerializer",
30 | "FlattenedTensorBucket",
31 | ]
32 |
--------------------------------------------------------------------------------
/examples/tau-bench/sglang_tool_parser.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from sglang.srt.function_call.function_call_parser import FunctionCallParser
4 | from sglang.srt.managers.io_struct import Function, Tool
5 |
6 |
7 | def parse_tools(response: str, tools: list[dict[str, Any]], parser: str = "qwen25"):
8 | """
9 | This function mimics the function call parser API from
10 | https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py#L952
11 | but runs locally.
12 | """
13 | tools_list = [
14 | Tool(
15 | function=Function(
16 | name=tool["function"]["name"],
17 | description=tool["function"]["description"],
18 | parameters=tool["function"]["parameters"],
19 | ),
20 | type=tool["type"],
21 | )
22 | for tool in tools
23 | ]
24 | parser = FunctionCallParser(tools=tools_list, tool_call_parser=parser)
25 |
26 | normal_text, calls = parser.parse_non_stream(response)
27 |
28 | return {
29 | "normal_text": normal_text,
30 | "calls": [call.model_dump() for call in calls], # Convert pydantic objects to dictionaries
31 | }
32 |
--------------------------------------------------------------------------------
/scripts/models/qwen3-30B-A3B.sh:
--------------------------------------------------------------------------------
1 | NLAYERS=48
2 | FIRST_K_DENSE_REPLACE=0
3 |
4 | arr=()
5 | for ((i=0; i list[Sample]:
17 |
18 | tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
19 | max_context_length = args.rollout_max_context_len if not evaluation else args.eval_max_context_len
20 |
21 | args.sampling_params = sampling_params
22 | args.rollout_max_context_len = max_context_length
23 | args.tokenizer = tokenizer
24 |
25 | for key, value in MULTI_AGENT_CONFIGS.items():
26 | setattr(args, key, value)
27 |
28 | custom_multi_agent_func = load_function(args.custom_multi_agent_function_path)
29 | samples = await custom_multi_agent_func(args, sample)
30 |
31 | random.shuffle(samples)
32 |
33 | return samples
34 |
--------------------------------------------------------------------------------
/.github/workflows/generate_github_workflows.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import jinja2
3 |
4 |
5 | def main():
6 | """
7 | Generates GitHub workflow YAML files from Jinja2 templates.
8 | """
9 | workflows_dir = Path(__file__).parent
10 | print(f"Scan dir: {workflows_dir}")
11 | env = jinja2.Environment(
12 | loader=jinja2.FileSystemLoader(str(workflows_dir)),
13 | block_start_string="<%",
14 | block_end_string="%>",
15 | variable_start_string="<<",
16 | variable_end_string=">>",
17 | )
18 |
19 | for template_path in workflows_dir.glob("*.yml.j2"):
20 | template = env.get_template(template_path.name)
21 | content = template.render()
22 |
23 | yaml_path = template_path.with_suffix("")
24 | with open(yaml_path, "w") as f:
25 | f.write(
26 | "#" * 80
27 | + "\n# This file is auto-generated from the .j2 file via generate_github_workflows.py. Do not edit manually.\n"
28 | + "#" * 80
29 | + "\n"
30 | )
31 | f.write(content)
32 |
33 | print(f"Generated {yaml_path} from {template_path}")
34 |
35 |
36 | if __name__ == "__main__":
37 | main()
38 |
--------------------------------------------------------------------------------
/scripts/models/qwen3-235B-A22B.sh:
--------------------------------------------------------------------------------
1 | # qwen3-235B-a22B
2 | NLAYERS=94
3 | FIRST_K_DENSE_REPLACE=0
4 |
5 | arr=()
6 | for ((i=0; i
25 | sh -c "
26 | cd /data/slime_ci &&
27 | /home/runner/config.sh --url ${GITHUB_RUNNER_URL} --token ${GITHUB_RUNNER_TOKEN} --unattended --work /data/slime_ci/runner_$(hostname) --disableupdate &&
28 | /home/runner/run.sh
29 | "
30 |
--------------------------------------------------------------------------------
/slime/backends/megatron_utils/arguments.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from megatron.training.arguments import parse_args, validate_args
4 | from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
5 |
6 | __all__ = ["validate_args", "parse_args", "set_default_megatron_args"]
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | def set_default_megatron_args(args):
12 | # always use zero optimizer
13 | args.use_distributed_optimizer = True
14 | # TODO: maybe change this after megatron has good fp8 support
15 | args.bf16 = not args.fp16
16 | # placeholders
17 | args.seq_length = 4096
18 | args.max_position_embeddings = args.seq_length
19 | # compatible for megatron
20 | if hasattr(args, "rope_type") and args.rope_type is None:
21 | args.rope_type = "yarn" if args.multi_latent_attention else "rope"
22 |
23 | if args.vocab_size and not args.padded_vocab_size:
24 | args.padded_vocab_size = _vocab_size_with_padding(args.vocab_size, args)
25 |
26 | if not args.tokenizer_model and not args.tokenizer_type:
27 | logger.info("--tokenizer-model not set, use --hf-checkpoint as tokenizer model.")
28 | args.tokenizer_model = args.hf_checkpoint
29 | args.tokenizer_type = "HuggingFaceTokenizer"
30 | return args
31 |
--------------------------------------------------------------------------------
/scripts/models/glm4.5-355B-A32B.sh:
--------------------------------------------------------------------------------
1 | N_DENSE_LAYERS=3
2 | N_MOE_LAYERS=89
3 |
4 | # glm4.5-355B-A32B
5 | MODEL_ARGS=(
6 | --disable-bias-linear
7 | --qk-layernorm
8 | --group-query-attention
9 | --num-attention-heads 96
10 | --num-query-groups 8
11 | --kv-channels 128
12 | --num-layers $((N_DENSE_LAYERS + N_MOE_LAYERS))
13 | --hidden-size 5120
14 | --ffn-hidden-size 12288
15 |
16 | --add-qkv-bias
17 | --normalization RMSNorm
18 | --position-embedding-type rope
19 | --rotary-percent 0.5
20 | --swiglu
21 | --untie-embeddings-and-output-weights
22 | --vocab-size 151552
23 |
24 | --rotary-base 1000000
25 |
26 | # moe
27 | --moe-ffn-hidden-size 1536
28 | --moe-shared-expert-intermediate-size 1536
29 | --moe-router-pre-softmax
30 | --moe-router-score-function sigmoid
31 | --moe-router-enable-expert-bias
32 | --moe-router-bias-update-rate 0
33 | --moe-router-load-balancing-type seq_aux_loss
34 | --moe-token-dispatcher-type alltoall
35 | --moe-router-topk 8
36 | --moe-router-topk-scaling-factor 2.5
37 | --moe-layer-freq [0]*$N_DENSE_LAYERS+[1]*$N_MOE_LAYERS
38 | --num-experts 160
39 | --moe-grouped-gemm
40 | --moe-router-dtype fp32
41 | --moe-permute-fusion
42 | --moe-aux-loss-coeff 0
43 | )
44 |
--------------------------------------------------------------------------------
/slime/utils/rocm_checkpoint_writer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync
3 |
4 |
5 | class ROCmFileSystemWriterAsync(FileSystemWriterAsync):
6 | """
7 | FileSystemWriterAsync wrapper for ROCm compatibility.
8 |
9 | On ROCm/HIP, using non_blocking=True causes tensors to be stored in pinned memory,
10 | which triggers segmentation faults when forking subprocesses afterward.
11 | """
12 |
13 | @staticmethod
14 | def preload_tensors(*args, **kwargs):
15 | # Change argument non_blocking to False on HIP platform
16 | # The tensors will be stored in pinned memory if non_blocking=True
17 | # Currently on the ROCm platform, forking a subprocess afterward
18 | # with pinned_memory=True will trigger segmentation fault
19 | if torch.version.hip:
20 | print("HIP/ROCm detected: setting non_blocking=False in preload_tensors")
21 | if "non_blocking" in kwargs:
22 | kwargs["non_blocking"] = False
23 | elif len(args) > 1 and isinstance(args[-1], bool):
24 | # non_blocking is typically the last argument
25 | args = args[:-1] + (False,)
26 |
27 | return FileSystemWriterAsync.preload_tensors(*args, **kwargs)
28 |
--------------------------------------------------------------------------------
/slime/utils/iter_utils.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from collections.abc import Callable, Iterable
3 | from typing import Any
4 |
5 | import torch
6 |
7 |
8 | # details: https://stackoverflow.com/questions/773/how-do-i-use-itertools-groupby
9 | def group_by(iterable, key=None):
10 | """Similar to itertools.groupby, but do not require iterable to be sorted"""
11 | ret = defaultdict(list)
12 | for item in iterable:
13 | ret[key(item) if key is not None else item].append(item)
14 | return dict(ret)
15 |
16 |
17 | # TODO fsdp can also use this
18 | def chunk_named_params_by_size(named_params: Iterable[tuple[str, torch.Tensor]], chunk_size: int):
19 | return _chunk_by_size(
20 | named_params,
21 | compute_size=lambda named_weight: named_weight[1].nbytes,
22 | chunk_size=chunk_size,
23 | )
24 |
25 |
26 | def _chunk_by_size(objects: Iterable[Any], compute_size: Callable[[Any], int], chunk_size: int):
27 | bucket: list[Any] = []
28 | bucket_size = 0
29 |
30 | for obj in objects:
31 | obj_size = compute_size(obj)
32 |
33 | if bucket and (bucket_size + obj_size) >= chunk_size:
34 | yield bucket
35 | bucket = []
36 | bucket_size = 0
37 |
38 | bucket.append(obj)
39 | bucket_size += obj_size
40 |
41 | if bucket:
42 | yield bucket
43 |
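44 | # Illustrative example (hypothetical sizes): with chunk_size=4 and object sizes [1, 2, 2, 5],
45 | # _chunk_by_size yields buckets of sizes [1, 2], [2], [5] -- a bucket is flushed before adding an
46 | # object that would bring it to or past chunk_size, so a single oversized object still gets its own bucket.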
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | ci:
5 | autofix_prs: true
6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
7 | autoupdate_schedule: quarterly
8 |
9 | repos:
10 | - repo: https://github.com/pre-commit/pre-commit-hooks
11 | rev: v4.5.0
12 | hooks:
13 | - id: check-yaml
14 | - id: check-case-conflict
15 | - id: detect-private-key
16 | - id: check-added-large-files
17 | args: ['--maxkb=1000']
18 | - id: requirements-txt-fixer
19 |
20 | - repo: https://github.com/astral-sh/ruff-pre-commit
21 | rev: v0.14.7
22 | hooks:
23 | - id: ruff-check
24 | args: [ --fix ]
25 |
26 | - repo: https://github.com/PyCQA/autoflake
27 | rev: v2.0.2
28 | hooks:
29 | - id: autoflake
30 | args: [--remove-all-unused-imports, --in-place]
31 |
32 | - repo: https://github.com/pycqa/isort
33 | rev: 5.13.2 # pick a stable version
34 | hooks:
35 | - id: isort
36 | args:
37 | - "--profile=black" # common choice: align the style with Black
38 | - "--filter-files" # skip files that are ignored via .gitignore
39 | additional_dependencies: [] # add plugin dependencies here when needed
40 |
41 | - repo: https://github.com/psf/black
42 | rev: 24.3.0
43 | hooks:
44 | - id: black
45 | name: Format code
46 | additional_dependencies: ['click==8.0.2']
47 |
--------------------------------------------------------------------------------
/docs/zh/advanced/speculative-decoding.md:
--------------------------------------------------------------------------------
1 | # 投机采样
2 |
3 | 投机采样是加速 rollout 的重要优化手段。推理过程中不再让昂贵的 Target Model 逐个 token 进行 decode,而是先由一个轻量级的 draft model 先进行 decode,生成多个 token 后,再由大模型进行批量验证。
4 |
5 | ## 使用投机采样加速推理
6 |
7 | 对于有 MTP 层的模型(例如 GLM-4.6、Deepseek-V3/R1),只需要添加:
8 |
9 | ```bash
10 | --sglang-speculative-algorithm EAGLE
11 | --sglang-speculative-num-steps 3
12 | --sglang-speculative-eagle-topk 1
13 | --sglang-speculative-num-draft-tokens 4
14 | ```
15 |
16 | 如果要使用单独训练的 draft model(例如 [SpecForge](https://docs.sglang.ai/SpecForge/) 训练的),还需要额外设置:
17 |
18 | ```bash
19 | --sglang-speculative-draft-model-path /your/draft/model/path
20 | ```
21 |
22 | 详细参数含义及配置方法,请参考 SGLang 的 speculative decoding [文档](https://docs.sglang.ai/advanced_features/speculative_decoding.html)
23 |
24 | ## 在线 SFT draft model
25 |
26 | 随着 RL 流程的进行,draft model 和 target model 的采样概率差异逐渐增大,能通过验证的 draft token 逐渐减少,spec 甚至可能造成负收益。
27 |
28 | 目前,slime 支持了在 RL 流程中在线训练 MTP 层,随着训练的进行同步更新 draft model,稳定提高了采样速度,相关原理可参见 [blog](https://www.notion.so/jiajunli-guapisolo/Power-Up-Speculative-Decoding-In-Reinforcement-Learning-2a92d24a293b802d9c73dbae429e581e)。使用方法如下:
29 |
30 | ```bash
31 | --mtp-num-layers 1
32 | --enable-mtp-training
33 | --mtp-loss-scaling-factor 0.2
34 | ```
35 |
36 | 注意 MTP 训练需要一个包含了 MTP 权重的 checkpoint,所以在将 huggingface checkpoint 转为 torch dist 时,也需要加上 `--mtp-num-layers 1`。
37 |
38 | 外部 draft model 的训练还在 WIP。
39 |
--------------------------------------------------------------------------------
/.github/workflows/release-docs.yaml:
--------------------------------------------------------------------------------
1 | name: Release Documentation
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | paths:
8 | - "docs/**"
9 | - "examples/**"
10 | - "version.txt"
11 | workflow_dispatch:
12 |
13 | concurrency:
14 | group: release-docs-${{ github.ref }}
15 | cancel-in-progress: true
16 |
17 | jobs:
18 | deploy:
19 | runs-on: ubuntu-latest
20 | if: github.repository == 'THUDM/slime'
21 | permissions:
22 | contents: write
23 | steps:
24 | - name: Checkout code
25 | uses: actions/checkout@v4
26 |
27 | - name: Setup Python
28 | uses: actions/setup-python@v5
29 | with:
30 | python-version: '3.13'
31 |
32 |
33 | - name: Install dependencies
34 | run: |
35 | sudo apt-get update && sudo apt-get install -y pandoc parallel retry
36 | pip install -r docs/requirements.txt
37 |
38 | - name: Build documentation
39 | run: |
40 | cd docs
41 | bash ./build.sh en
42 | bash ./build.sh zh
43 | mv ./build/zh ./build/en/
44 | env:
45 | LC_ALL: "en_US.UTF-8"
46 | LC_CTYPE: "en_US.UTF-8"
47 |
48 |
49 | - name: Deploy
50 | uses: peaceiris/actions-gh-pages@v4
51 | with:
52 | github_token: ${{ secrets.GITHUB_TOKEN }}
53 | publish_dir: ./docs/build/en
--------------------------------------------------------------------------------
/examples/eval/scripts/multi_tasks.yaml:
--------------------------------------------------------------------------------
1 | eval:
2 | defaults:
3 | n_samples_per_eval_prompt: 1
4 | temperature: 0.6
5 | top_p: 0.95
6 | top_k: -1
7 | max_response_len: 24576
8 | datasets: # these eval tasks go through slime dataset config and default rollout function (slime.rollout.sglang_rollout.generate_rollout)
9 | - name: gpqa # huggingface-cli download --repo-type dataset zyzshishui0627/gpqa_diamond --local-dir /root/gpqa
10 | path: /root/gpqa/gpqa_eval.jsonl
11 | rm_type: gpqa
12 | n_samples_per_eval_prompt: 2
13 | - name: ifbench # huggingface-cli download --repo-type dataset zyzshishui0627/IFBench --local-dir /root/ifbench
14 | path: /root/ifbench/IFBench_eval.jsonl
15 | rm_type: ifbench
16 | n_samples_per_eval_prompt: 1
17 | delegate: # these tasks go through delegate eval function (examples.eval.eval_delegate_rollout.generate_rollout)
18 | - name: skills
19 | # this url should align with env docker network alias
20 | url: http://skills_server:9050/evaluate
21 | timeout_secs: 7200
22 | max_retries: 5
23 | headers: {}
24 | datasets:
25 | - name: aime25
26 | max_response_len: 8192
27 | n_samples_per_eval_prompt: 8
28 | - name: arena-hard
29 | n_samples_per_eval_prompt: 2
30 | - name: hle
31 | max_response_len: 32768
32 |
33 |
--------------------------------------------------------------------------------
/slime/utils/memory_utils.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import logging
3 |
4 | import torch
5 | import torch.distributed as dist
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def clear_memory(clear_host_memory: bool = False):
11 | torch.cuda.synchronize()
12 | gc.collect()
13 | torch.cuda.empty_cache()
14 | if clear_host_memory:
15 | torch._C._host_emptyCache()
16 |
17 |
18 | def available_memory():
19 | device = torch.cuda.current_device()
20 | free, total = torch.cuda.mem_get_info(device)
21 | return {
22 | "gpu": str(device),
23 | "total_GB": _byte_to_gb(total),
24 | "free_GB": _byte_to_gb(free),
25 | "used_GB": _byte_to_gb(total - free),
26 | "allocated_GB": _byte_to_gb(torch.cuda.memory_allocated(device)),
27 | "reserved_GB": _byte_to_gb(torch.cuda.memory_reserved(device)),
28 | }
29 |
30 |
31 | def _byte_to_gb(n: int):
32 | return round(n / (1024**3), 2)
33 |
34 |
35 | def print_memory(msg, clear_before_print: bool = False):
36 | if clear_before_print:
37 | clear_memory()
38 |
39 | memory_info = available_memory()
40 | # Need to print for all ranks, b/c different rank can have different behaviors
41 | logger.info(
42 | f"[Rank {dist.get_rank()}] Memory-Usage {msg}{' (cleared before print)' if clear_before_print else ''}: {memory_info}"
43 | )
44 | return memory_info
45 |
--------------------------------------------------------------------------------
/slime_plugins/rollout_buffer/README_zh.md:
--------------------------------------------------------------------------------
1 | # Rollout Buffer
2 |
3 | ## 概述
4 |
5 | Rollout Buffer 是用于辅助纯异步 agent 训练的独立组件,其主要功能是使用 slime 训练启动的 LLM OpenAI Server 进行智能体轨迹的生成。
6 |
7 | ### 工作流程
8 |
9 | ```
10 | slime Training Process ←─── HTTP API ───→ Rollout Buffer
11 | ↓ ↓
12 | LLM Server ←─────── HTTP Requests ─────── Agent Framework
13 | ↓ ↓
14 | Model Response ──────────────────────→ Trajectory Generation
15 | ```
16 |
17 | 对于每一个不同的 Agent 任务,都应该对应一个独立的 Generator 类,负责生成该类任务的轨迹。Rollout Buffer 会自动读取并加载不同类型的 Generator。
18 |
19 | ## 快速开始
20 |
21 | ### 基本使用流程
22 |
23 | 1. **复制模板**:将 `base_generator.py` 作为模板进行复制
24 | 2. **修改任务类型**:将 `TASK_TYPE` 修改为您的任务名称(不能与其他 Generator 重复)
25 | 3. **实现核心函数**:实现 `run_rollout()` 函数
26 | 4. **可选定制**:根据需要重写五个可选函数
27 |
28 |
29 | Generator 文件必须以 `_generator.py` 结尾,并放置在 `generator/` 目录下:
30 |
31 | ```
32 | generator/
33 | ├── base_generator.py # Math 任务实现(默认模板)
34 | └── your_task_generator.py # 您的自定义任务
35 | ```
36 |
37 | 每个 Generator 文件必须定义 `TASK_TYPE` 与 `run_rollout()`。
38 |
39 | 此外,Rollout Buffer 还提供了一些可自定义的函数来满足不同任务的特殊需求。如果不提供自定义实现,系统将使用默认实现(位于 `slime_plugins/rollout_buffer/default_func.py`)。
40 |
41 | ### 示例脚本
42 |
43 | 请仿照 [示例:Qwen3-4B 模型](../../docs/zh/models/qwen3-4B.md) 文档中配置好 slime 的运行环境,下载数据,并转换模型 ckpt。之后分别运行
44 |
45 | ```bash
46 | cd slime_plugins/rollout_buffer
47 | bash rollout_buffer_example.sh
48 |
49 | # In a different terminal
50 | python buffer.py
51 | ```
52 |
--------------------------------------------------------------------------------
/scripts/models/qwen3-next-80B-A3B.sh:
--------------------------------------------------------------------------------
1 | NLAYERS=48
2 | FIRST_K_DENSE_REPLACE=0
3 |
4 | arr=()
--------------------------------------------------------------------------------
/slime/rollout/rm_hub/deepscaler.py:
--------------------------------------------------------------------------------
5 | if "</think>" in response:
6 | model_solution = response.split("</think>")[-1]
7 | elif "###Response" in response:
8 | model_solution = response.split("###Response")[1]
9 | else:
10 | return 0
11 |
12 | model_answer = extract_answer(model_solution)
13 | if model_answer is None:
14 | return 0
15 | if label == "":
16 | return 0
17 |
18 | # Convert single answer to list for uniform processing
19 | assert isinstance(label, (str, float, int))
20 | ground_truths = [label]
21 |
22 | # Process each ground truth
23 | processed_ground_truths = []
24 | for truth in ground_truths:
25 | truth = str(truth)
26 | if "\\boxed" in truth:
27 | processed_truth = extract_answer(truth)
28 | if processed_truth is not None:
29 | processed_ground_truths.append(processed_truth)
30 | else:
31 | processed_ground_truths.append(truth)
32 |
33 | if not processed_ground_truths:
34 | return 0
35 |
36 | # Check against all possible correct answers
37 | for ground_truth in processed_ground_truths:
38 | is_correct = grade_answer_mathd(model_answer, ground_truth) or grade_answer_sympy(model_answer, ground_truth)
39 | if is_correct:
40 | return 1
41 |
42 | return 0
43 |
--------------------------------------------------------------------------------
/examples/train_infer_mismatch_helper/mis.yaml:
--------------------------------------------------------------------------------
1 | # Enable importance sampling; for details, refer to the comments of compute_mis_weights in mis.py
2 | use_tis: true
3 | use_rs: true
4 |
5 | # Aggregation level for importance sampling weights:
6 | # token: per-token
7 | # sequence: product over tokens
8 | # geometric: geometric mean
9 | tis_level: "token"
10 | rs_level: "token"
11 |
12 | # Handling mode for IS weights:
13 | # truncate: cap to upper bound, TIS
14 | # mask: zero outside [lower, upper], MIS
15 | # clip: clip to [lower, upper], CIS
16 | tis_mode: "truncate"
17 |
18 | # For clip mode, the lower bound of the IS weights.
19 | # For truncate mode, it will not be used.
20 | # If not set, it will be set to 1.0 / mis_upper_bound
21 | # For the geometric level, the lower bound should be 0.9999
22 | tis_lower_bound: 0.5
23 |
24 | # For truncate or clip mode, the upper bound of the IS weights
25 | # For the geometric level, the upper bound should be 1.0001
26 | tis_upper_bound: 2.0
27 |
28 | # Lower and upper bound for rejection sampling.
29 | # If not set, it will be same as tis_lower_bound and tis_upper_bound.
30 | rs_lower_bound: null
31 | rs_upper_bound: null
32 |
33 | # Per-token veto threshold. If any token ratio < this, the entire sequence weight is zeroed, so those sequences get no gradient
34 | # Note: float numbers must be written with a dot, e.g. 1.0e-4, not 1e-4
35 | rs_veto_threshold: 1.0e-4
36 |
37 | # Batch normalization: normalize IS weights to mean=1.0 across entire batch
38 | # This reduces variance in gradient updates
39 | tis_batch_normalize: true
40 |
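41 | # Illustrative arithmetic (not a config option): with per-token ratios [0.5, 1.2, 3.0],
42 | # truncate mode and tis_upper_bound=2.0:
43 | #   token level     -> [0.5, 1.2, 2.0]        (each ratio capped at the upper bound)
44 | #   sequence level  -> 0.5 * 1.2 * 3.0 = 1.8  (product over tokens, capped at 2.0)
45 | #   geometric level -> 1.8 ** (1/3) ~= 1.22   (geometric mean of the ratios)
46 | # The exact semantics are defined by compute_mis_weights in mis.py.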
--------------------------------------------------------------------------------
/examples/on_policy_distillation/on_policy_distillation.py:
--------------------------------------------------------------------------------
1 | import aiohttp
2 | import torch
3 |
4 | from slime.utils.types import Sample
5 |
6 |
7 | async def reward_func(args, sample, **kwargs):
8 | payload = {
9 | "text": sample.prompt + sample.response,
10 | "sampling_params": {
11 | "temperature": 0,
12 | "max_new_tokens": 0,
13 | "skip_special_tokens": False,
14 | },
15 | "return_logprob": True,
16 | "logprob_start_len": 0,
17 | }
18 | session_kwargs = {}
19 | async with aiohttp.ClientSession(**session_kwargs) as session:
20 | async with session.post(args.rm_url, json=payload) as resp:
21 | resp.raise_for_status()
22 | return await resp.json()
23 |
24 |
25 | def post_process_rewards(args, samples: list[Sample], **kwargs):
26 | rewards = [sample.get_reward_value(args) for sample in samples]
27 | response_lengths = [sample.response_length for sample in samples]
28 | teacher_log_probs = [
29 | torch.tensor([item[0] for item in reward["meta_info"]["input_token_logprobs"][1:]], dtype=torch.float32)
30 | for reward in rewards
31 | ]
32 | teacher_log_probs = [
33 | t_log_prob[-response_length:]
34 | for t_log_prob, response_length in zip(teacher_log_probs, response_lengths, strict=False)
35 | ]
36 |
37 | for sample, t_log_probs in zip(samples, teacher_log_probs, strict=False):
38 | sample.teacher_log_probs = t_log_probs
39 |
40 | return teacher_log_probs, teacher_log_probs
41 |
--------------------------------------------------------------------------------
/tests/ci/README.md:
--------------------------------------------------------------------------------
1 | # Doc about CI
2 |
3 | ## Configure GitHub secrets
4 |
5 | https://github.com/slimerl/slime/settings/secrets/actions
6 |
7 | * `WANDB_API_KEY`: get from https://wandb.ai/authorize
8 |
9 | ## Setup new GitHub runners
10 |
11 | ### Step 1: Env
12 |
13 | Write `.env` mimicking `.env.example`.
14 | The token can be found at https://github.com/slimerl/slime/settings/actions/runners/new?arch=x64&os=linux.
15 |
16 | WARN: The `GITHUB_RUNNER_TOKEN` changes after a while.
17 |
18 | ### Step 2: Prepare `/home/runner/externals`
19 |
20 | ```shell
21 | docker run --rm -it --privileged --pid=host -v /:/host_root ubuntu /bin/bash -c 'rm -rf /host_root/home/runner/externals && mkdir -p /host_root/home/runner/externals && chmod -R 777 /host_root/home/runner/externals'
22 | docker run -d --name temp-runner ghcr.io/actions/actions-runner:2.328.0 tail -f /dev/null
23 | docker cp temp-runner:/home/runner/externals/. /home/runner/externals
24 | docker rm -f temp-runner
25 | ls -alh /home/runner/externals
26 | ```
27 |
28 | ### Step 3: Run
29 |
30 | ```shell
31 | cd /mnt/data/tom/primary_synced/slime/tests/ci/github_runner
32 | docker compose up -d
33 | ```
34 |
35 | ### Debugging
36 |
37 | Logs
38 |
39 | ```shell
40 | # All containers
41 | docker compose logs -f
42 |
43 | # One container
44 | docker logs -f github_runner-runner-1
45 | ```
46 |
47 | Exec
48 |
49 | ```shell
50 | docker exec -it github_runner-runner-1 /bin/bash
51 | ```
52 |
53 | An example of iterating quickly
54 |
55 | ```shell
56 | docker compose down -v && docker compose up -d && docker logs -f github_runner-runner-1
57 | ```
58 |
--------------------------------------------------------------------------------
/examples/tau-bench/tau1_mock.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | from tau_bench.envs import get_env
6 | from tau_bench.types import RunConfig
7 |
8 | ALL_DATA_MAPPINGS = {"retail": ["train", "test", "dev"], "airline": ["test"]}
9 |
10 |
11 | def main():
12 | parser = argparse.ArgumentParser(description="Tau1 Mock Script")
13 | parser.add_argument("--local_dir", required=True, help="Path to the local directory")
14 | args = parser.parse_args()
15 |
16 | local_dir = args.local_dir
17 | if not os.path.isdir(local_dir):
18 | os.makedirs(local_dir)
19 | config = RunConfig(model_provider="mock", user_model_provider="mock", user_strategy="human", model="mock")
20 | for env, split in ALL_DATA_MAPPINGS.items():
21 | for s in split:
22 | config.env = env
23 | config.task_split = s
24 | env_instance = get_env(
25 | env_name=config.env,
26 | user_strategy=config.user_strategy,
27 | user_model=config.user_model,
28 | task_split=config.task_split,
29 | )
30 | output_path = os.path.join(local_dir, f"{env}_{s}_tasks.jsonl")
31 | with open(output_path, "w") as f:
32 | for i, task in enumerate(env_instance.tasks):
33 | row = {"index": i, "metadata": task.model_dump()}
34 | f.write(json.dumps(row) + "\n") # <-- one JSON object per line
35 | print(f"Saved preprocessed task indices for {env} ({s}) to {output_path}")
36 |
37 |
38 | if __name__ == "__main__":
39 | main()
40 |
--------------------------------------------------------------------------------
/examples/multi_agent/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Agent RL
2 |
3 | This directory provides an example of running multi-agent reinforcement learning (RL) with slime.
4 |
5 | ## Environment Setup
6 |
7 | The environment setup is identical to the standard RL setup used in slime.
8 |
9 | ## Running the Script
10 |
11 | You can either define your own multi-agent system or use the provided default configuration.
12 |
13 | ```python
14 | MULTI_AGENT_CONFIGS = {
15 | "custom_multi_agent_function_path": "examples.multi_agent.agent_system.run_agent_system",
16 | "num_parallel": 5,
17 | "incorrect_reward_weight": 0.8,
18 | "correct_reward_weight": 1.2,
19 | }
20 | ```
21 |
22 | To start a run, execute:
23 |
24 | ```bash
25 | cd slime/
26 | bash examples/multi_agent/run-qwen3-30B-A3B-multi-agent.sh
27 | ```
28 |
29 | ## New Arguments
30 |
31 | - Specify the agent rollout function with the `--custom-generate-function-path` argument.
32 | - Set the `--rollout-max-context-len` argument according to your model’s context window.
33 |
34 | ```bash
35 | ROLLOUT_ARGS=(
36 | --custom-generate-function-path examples.multi_agent.rollout_with_multi_agents.generate_with_multi_agents
37 | --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
38 | --input-key prompt
39 | --label-key label
40 | --apply-chat-template
41 | --rollout-shuffle
42 | --rm-type deepscaler
43 | --num-rollout 3000
44 | --rollout-batch-size 32
45 | --n-samples-per-prompt 8
46 | --rollout-max-context-len 16384
47 | --rollout-max-response-len 8192
48 | --rollout-temperature 0.8
49 |
50 | --global-batch-size 256
51 | --balance-data
52 | )
53 | ```
--------------------------------------------------------------------------------
/docs/zh/index.rst:
--------------------------------------------------------------------------------
1 | slime 文档
2 | ====================
3 |
4 | slime 是一个面向 RL Scaling 的 LLM 后训练框架,提供两大核心能力:
5 |
6 | - 高性能训练:通过连接 Megatron 与 SGLang,支持多种模式下的高效训练;
7 | - 灵活的数据生成:通过自定义数据生成接口与基于服务器的引擎,实现任意训练数据生成流程。
8 |
9 | slime 是 GLM-4.5 与 GLM-4.6 背后的 RL 训练框架。除此之外,slime 还支持:
10 |
11 | - Qwen3 系列 (Qwen3Next, Qwen3MoE, Qwen3), Qwen2.5 系列;
12 | - DeepSeek V3 系列 (DeepSeek V3, V3.1, DeepSeek R1);
13 | - Llama 3。
14 |
15 | .. toctree::
16 | :maxdepth: 1
17 | :caption: 开始使用
18 |
19 | get_started/quick_start.md
20 | get_started/usage.md
21 | get_started/qa.md
22 |
23 | .. toctree::
24 | :maxdepth: 1
25 | :caption: Dense
26 |
27 | examples/qwen3-4B.md
28 | examples/glm4-9B.md
29 |
30 | .. toctree::
31 | :maxdepth: 1
32 | :caption: MoE
33 |
34 | examples/qwen3-30B-A3B.md
35 | examples/glm4.5-355B-A32B.md
36 | examples/deepseek-r1.md
37 |
38 | .. toctree::
39 | :maxdepth: 1
40 | :caption: 高级特性
41 |
42 | _examples_synced/reproducibility/README.md
43 | advanced/speculative-decoding.md
44 | advanced/fault-tolerance.md
45 | advanced/arch-support-beyond-megatron.md
46 |
47 | .. toctree::
48 | :maxdepth: 1
49 | :caption: 其他用法
50 |
51 | examples/qwen3-4b-base-openhermes.md
52 | _examples_synced/search-r1/README.md
53 | _examples_synced/fully_async/README.md
54 | _examples_synced/retool/README.md
55 | _examples_synced/multi_agent/README.md
56 |
57 | .. toctree::
58 | :maxdepth: 1
59 | :caption: 开发指南
60 |
61 | developer_guide/debug.md
62 |
63 | .. toctree::
64 | :maxdepth: 1
65 | :caption: 博客
66 |
67 | blogs/release_v0.1.0.md
68 | blogs/introducing_slime.md
69 |
--------------------------------------------------------------------------------
/docker/justfile:
--------------------------------------------------------------------------------
1 | release-primary:
2 | ARG_TAG_POSTFIX="" ARG_BUILD_EXTRA_ARGS="" just _release-raw
3 |
4 | # Should be executed on ARM machines
5 | release-cu129-arm64:
6 | ARG_TAG_POSTFIX="-cu129-arm64" ARG_BUILD_EXTRA_ARGS='--build-arg SGLANG_IMAGE_TAG=v0.5.5.post3-cu129-arm64 --build-arg ENABLE_SGLANG_PATCH=0' just _release-raw
7 |
8 | # Should be executed on ARM machines
9 | release-cu13-arm64:
10 | ARG_TAG_POSTFIX="-cu13-arm64" ARG_BUILD_EXTRA_ARGS='--build-arg SGLANG_IMAGE_TAG=dev-arm64-cu13-20251122 --build-arg ENABLE_CUDA_13=1 --build-arg ENABLE_SGLANG_PATCH=0' just _release-raw
11 |
12 | _release-raw:
13 | #!/bin/bash
14 | set -euxo pipefail
15 | cd ..
16 |
17 | VERSION="$(cat docker/version.txt | tr -d '\n')"
18 | IMAGE_TAG=${VERSION}${ARG_TAG_POSTFIX}
19 |
20 | docker build -f docker/Dockerfile . --build-arg HTTP_PROXY="$http_proxy" --build-arg HTTPS_PROXY="$https_proxy" --build-arg NO_PROXY="localhost,127.0.0.1" $ARG_BUILD_EXTRA_ARGS -t slimerl/slime:$IMAGE_TAG
21 | docker push slimerl/slime:$IMAGE_TAG
22 |
23 | if [ -z "${ARG_TAG_POSTFIX}" ]; then
24 | docker tag slimerl/slime:$IMAGE_TAG slimerl/slime:latest
25 | docker push slimerl/slime:latest
26 | fi
27 |
28 | debug:
29 | #!/bin/bash
30 | set -euxo pipefail
31 | cd ..
32 |
33 | VERSION="$(cat docker/version.txt | tr -d '\n')"
34 | IMAGE_TAG=${VERSION}
35 |
36 | docker build -f docker/Dockerfile . --build-arg HTTP_PROXY="$http_proxy" --build-arg HTTPS_PROXY="$https_proxy" --build-arg NO_PROXY="localhost,127.0.0.1" -t slimerl/slime-test:$IMAGE_TAG
37 | docker push slimerl/slime-test:$IMAGE_TAG
38 |
--------------------------------------------------------------------------------
/docs/build_all.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -euo pipefail
3 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
4 | cd "$SCRIPT_DIR"
5 |
6 | echo "[slime-docs] Building EN..."
7 | ./build.sh en
8 | echo "[slime-docs] Building ZH..."
9 | ./build.sh zh
10 |
11 | # Create a lightweight root index with auto redirect based on localStorage (done client side)
12 | ROOT_INDEX=build/index.html
13 | cat > "$ROOT_INDEX" <<'EOF'
14 | <!DOCTYPE html>
15 | <html lang="en">
16 | <head>
17 | <meta charset="utf-8">
18 | <title>slime docs</title>
19 | <style>
20 | body { font-family: sans-serif; max-width: 40rem; margin: 3rem auto; }
21 | a { margin-right: 1rem; }
22 | </style>
23 | <script>
24 | // Client-side auto-redirect to the previously chosen language, if one was stored.
25 | // The localStorage key name below is illustrative.
26 | var lang = localStorage.getItem('slime-docs-lang');
27 | if (lang === 'en' || lang === 'zh') { window.location.replace(lang + '/'); }
28 | function chooseLang(l) { localStorage.setItem('slime-docs-lang', l); }
29 | </script>
30 | </head>
31 | <body>
32 | <h1>slime Documentation</h1>
33 | <p>Select language:</p>
34 | <p>
35 | <a href="en/" onclick="chooseLang('en')">English</a>
36 | <a href="zh/" onclick="chooseLang('zh')">中文</a>
37 | </p>
38 | <p><small>Auto-redirect uses your last choice if stored; else pick above.</small></p>
39 | </body>
40 | </html>
41 |
42 | EOF
43 |
44 | echo "[slime-docs] Done. Root landing page at build/index.html"
--------------------------------------------------------------------------------
/slime/utils/debug_utils/replay_reward_fn.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Annotated
3 |
4 | import ray
5 | import torch
6 | import typer
7 |
8 | from slime.utils.misc import load_function
9 | from slime.utils.types import Sample
10 |
11 |
12 | def _truncate(text, max_len=200):
13 | """Truncate text and add ellipsis if too long."""
14 | if text is None:
15 | return None
16 | text = str(text).replace("\n", "\\n")
17 | if len(text) > max_len:
18 | return text[:max_len] + "..."
19 | return text
20 |
21 |
22 | def main(
23 | rollout_data_path: Annotated[str, typer.Option()],
24 | custom_rm_path: Annotated[str, typer.Option()],
25 | ):
26 | if not ray.is_initialized():
27 | ray.init()
28 |
29 | pack = torch.load(rollout_data_path)
30 | samples = [Sample.from_dict(s) for s in pack["samples"]]
31 | asyncio.run(_main_async(samples=samples, custom_rm_path=custom_rm_path))
32 |
33 |
34 | async def _main_async(samples, custom_rm_path):
35 | rm_function = load_function(custom_rm_path)
36 | rewards = await asyncio.gather(*[rm_function(None, sample) for sample in samples])
37 |
38 | for i, (sample, reward) in enumerate(zip(samples, rewards, strict=True)):
39 | print("-" * 60)
40 | print(f"Sample {i + 1}/{len(samples)}")
41 | print(f" Index: {sample.index}")
42 | print(f" Status: {sample.status}")
43 | print(f" Reward: {reward}")
44 | print(f" Prompt: {_truncate(sample.prompt, 200)}")
45 | print(f" Response: {_truncate(sample.response, 200)}")
46 | print("-" * 60)
47 |
48 |
49 | if __name__ == "__main__":
50 | typer.run(main)
51 |
--------------------------------------------------------------------------------
/slime/rollout/rm_hub/f1.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 | from collections import Counter
4 |
5 |
6 | def normalize_answer(s):
7 |
8 | def remove_articles(text):
9 | return re.sub(r"\b(a|an|the)\b", " ", text)
10 |
11 | def white_space_fix(text):
12 | return " ".join(text.split())
13 |
14 | def remove_punc(text):
15 | exclude = set(string.punctuation)
16 | return "".join(ch for ch in text if ch not in exclude)
17 |
18 | def lower(text):
19 | return text.lower()
20 |
21 | return white_space_fix(remove_articles(remove_punc(lower(s))))
22 |
23 |
24 | def f1_score(prediction, ground_truth):
25 | ZERO_METRIC = (0, 0, 0)
26 |
27 | if prediction is None:
28 | return ZERO_METRIC
29 |
30 | normalized_prediction = normalize_answer(prediction)
31 | normalized_ground_truth = normalize_answer(ground_truth)
32 |
33 | if normalized_prediction in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth:
34 | return ZERO_METRIC
35 | if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth:
36 | return ZERO_METRIC
37 |
38 | prediction_tokens = normalized_prediction.split()
39 | ground_truth_tokens = normalized_ground_truth.split()
40 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
41 | num_same = sum(common.values())
42 | if num_same == 0:
43 | return ZERO_METRIC
44 | precision = 1.0 * num_same / len(prediction_tokens)
45 | recall = 1.0 * num_same / len(ground_truth_tokens)
46 | f1 = (2 * precision * recall) / (precision + recall)
47 | return f1, precision, recall
48 |
--------------------------------------------------------------------------------
/docs/zh/advanced/arch-support-beyond-megatron.md:
--------------------------------------------------------------------------------
1 | # 在 Megatron-LM 中快速支持新模型架构
2 |
3 | Megatron-LM 框架虽然并行效率高,但在支持日新月异的新模型架构(如 Qwen3Next)时,其灵活性有所欠缺。若要原生支持这些模型的特殊结构(例如 Gated-Delta-Net),往往需要对 Megatron 的核心代码进行侵入性较大、开发周期较长的改造。
4 |
5 | 为了能快速跟进这些前沿模型,`slime` 提出了一种更敏捷的方案:**与其深度改造 Megatron,不如直接引入并封装模型官方的 HuggingFace 实现**,将其作为一个“黑盒模块”无缝嵌入到 Megatron 的并行训练流程中。
6 |
7 | 本文以 Qwen3Next 80B-A3B 为例,介绍这一实现思路。
8 |
9 | ## 实现原理与核心组件
10 |
11 | Megatron 的模型实例化分为两步:首先根据配置生成“层规格”(`ModuleSpec`),再依据该规格实例化具体的 PyTorch 模块。
12 |
13 | `slime` 正是利用这一机制,在**生成 Spec 的阶段“劫持”并替换掉 Megatron 的原生模块**,从而将外部实现(此处为 HuggingFace 模块)无缝嵌入。这一过程主要涉及三个核心组件的协同:
14 |
15 | 1. **替换 Megatron 模块规格 (Spec)**
16 | 这是整个方案的入口。我们通过一个自定义函数(例如 `get_qwen3_next_spec`)来修改标准的 `ModuleSpec`,用我们自己的封装层换掉 Megatron 的原生 Attention 层。
17 | * **具体操作**:获取标准的 Decoder Block Spec,将其 `self_attention` 字段指向我们的自定义模块,并按需开启 `qk_layernorm` 等模型特有配置。
18 | * **对应文件**: `slime_plugins/models/qwen3_next.py`
19 |
20 | 2. **封装 HuggingFace 实现**
21 | 上一步的 Spec 会指向一个封装层,例如 `HuggingfaceAttention`。它继承了 Megatron 的 `MegatronModule`,核心职责是作为桥梁,处理好并行策略所需的数据对齐(如序列并行),然后在内部直接调用从 HuggingFace 加载的原生 `Qwen3NextAttention` 模块。
22 | * **对应文件**: `slime_plugins/models/hf_attention.py`
23 |
24 | 3. **对齐模型权重**
25 | 模型结构跑通后,还需要确保权重能正确加载。我们借助 [mbridge](https://github.com/ISEEKYAN/mbridge) 库,通过 `Qwen3NextBridge` 建立了 HuggingFace Checkpoint 与 Megatron 参数之间的命名映射关系,实现双向互通。
26 | * **对应文件**: `slime_plugins/mbridge/qwen3_next.py`
27 |
28 | 通过这三层协同,我们成功地将一个 Megatron 原本不支持的复杂模型结构(以其 HuggingFace 实现为载体),运行在了 Megatron 的并行框架之上,并完整保留了模型并行、MoE 加速、流水线调度等全部关键能力。
29 |
30 | ## 当前限制
31 |
32 | * 本方案暂不支持被替换模块(如此处的 Attention 层)自身的张量并行(TP)。
33 | * **影响**:在大多数大规模 MoE 模型中,Attention 层的参数量占比较小,因此该限制对显存占用和训练吞吐的影响通常有限。
34 | * **替代方案**:如果该模块的 TP 至关重要,则需要回归到侵入式修改 Megatron 的原生实现方案。
35 |
--------------------------------------------------------------------------------
/scripts/models/kimi-k2.sh:
--------------------------------------------------------------------------------
1 | NLAYERS=61
2 | FIRST_K_DENSE_REPLACE=1
3 |
4 | arr=()
5 | for ((i=0; i=2.0",
42 | ]
43 | },
44 | python_requires=">=3.10",
45 | classifiers=[
46 | "Programming Language :: Python :: 3.10",
47 | "Programming Language :: Python :: 3.11",
48 | "Programming Language :: Python :: 3.12",
49 | "Environment :: GPU :: NVIDIA CUDA",
50 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
51 | "Topic :: System :: Distributed Computing",
52 | ],
53 | cmdclass={"bdist_wheel": bdist_wheel},
54 | )
55 |
--------------------------------------------------------------------------------
/docs/_static/css/custom_log.css:
--------------------------------------------------------------------------------
1 | .output_area {
2 | color: #615656;
3 | }
4 |
5 | table.autosummary td {
6 | width: 50%
7 | }
8 |
9 | img.align-center {
10 | display: block;
11 | margin-left: auto;
12 | margin-right: auto;
13 | }
14 |
15 | .output_area.stderr {
16 | color: #d3d3d3 !important;
17 | }
18 |
19 | .output_area.stdout {
20 | color: #d3d3d3 !important;
21 | }
22 |
23 | div.output_area.stderr {
24 | color: #d3d3d3 !important;
25 | }
26 |
27 | div.output_area.stdout {
28 | color: #d3d3d3 !important;
29 | }
30 |
31 | /* Language toggle button styling */
32 | .lang-toggle-btn {
33 | --lt-border: var(--pst-color-border, #d0d7de);
34 | --lt-bg: var(--pst-color-surface, #f6f8fa);
35 | --lt-bg-hover: var(--pst-color-on-surface, #e6ebf1);
36 | --lt-active: var(--pst-color-primary, #0969da);
37 | display: inline-flex;
38 | align-items: center;
39 | padding: 2px 10px;
40 | line-height: 1.1;
41 | font-size: 0.72rem;
42 | font-weight: 600;
43 | border: 1px solid var(--lt-border);
44 | border-radius: 6px;
45 | background: var(--lt-bg);
46 | cursor: pointer;
47 | gap: 2px;
48 | letter-spacing: .5px;
49 | }
50 | .lang-toggle-btn:hover {
51 | background: var(--lt-bg-hover);
52 | }
53 | .lang-toggle-btn .lang-seg {
54 | opacity: .55;
55 | transition: opacity .15s;
56 | }
57 | .lang-toggle-btn[data-current="en"] .lang-seg[data-lang="en"],
58 | .lang-toggle-btn[data-current="zh"] .lang-seg[data-lang="zh"] {
59 | opacity: 1;
60 | color: var(--lt-active);
61 | }
62 | .lang-toggle-btn .lang-sep { opacity: .35; }
63 |
64 | @media (prefers-color-scheme: dark) {
65 | .lang-toggle-btn { --lt-border: #30363d; --lt-bg:#161b22; --lt-bg-hover:#1c2128; }
66 | .lang-toggle-btn .lang-seg { color: #adbac7; }
67 | }
68 |
--------------------------------------------------------------------------------
/scripts/models/deepseek-v3.sh:
--------------------------------------------------------------------------------
1 | NLAYERS="${MODEL_ARGS_NUM_LAYERS:-61}"
2 | FIRST_K_DENSE_REPLACE=3
3 |
4 | arr=()
23 | - if int(bare_metal_major) >= 11:
24 | - cc_flag.append('-gencode')
25 | - cc_flag.append('arch=compute_80,code=sm_80')
26 | - if int(bare_metal_minor) >= 8:
27 | + if torch.cuda.is_available() and torch.version.cuda:
28 | + # Check if cuda 11 is installed for compute capability 8.0
29 | + cc_flag = []
30 | + _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
31 | + cpp_extension.CUDA_HOME
32 | + )
33 | + if int(bare_metal_major) >= 11:
34 | cc_flag.append('-gencode')
35 | - cc_flag.append('arch=compute_90,code=sm_90')
36 | + cc_flag.append('arch=compute_80,code=sm_80')
37 | + if int(bare_metal_minor) >= 8:
38 | + cc_flag.append('-gencode')
39 | + cc_flag.append('arch=compute_90,code=sm_90')
40 |
41 | - # Build path
42 | - srcpath = pathlib.Path(__file__).parent.absolute()
43 | - buildpath = srcpath / "build"
44 | - _create_build_dir(buildpath)
45 | + # Build path
46 | + srcpath = pathlib.Path(__file__).parent.absolute()
47 | + buildpath = srcpath / "build"
48 | + _create_build_dir(buildpath)
49 |
50 | # Helper function to build the kernels.
51 | def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
52 |
--------------------------------------------------------------------------------
/scripts/models/moonlight.sh:
--------------------------------------------------------------------------------
1 | MOE_SHARED_EXPERTS=2
2 | MOE_FFN_HIDDEN=1408
3 | MOE_SHARED_EXPERT_INTERMEDIATE_SIZE=$(($MOE_FFN_HIDDEN * $MOE_SHARED_EXPERTS))
4 | MOE_ROUTER_TOPK_SCALING_FACTOR=2.446
5 | NLAYERS=27
6 | FIRST_K_DENSE_REPLACE=1
7 |
8 | arr=()
23 | - if int(bare_metal_major) >= 11:
24 | - cc_flag.append('-gencode')
25 | - cc_flag.append('arch=compute_80,code=sm_80')
26 | - if int(bare_metal_minor) >= 8:
27 | + if torch.cuda.is_available() and torch.version.cuda:
28 | + # Check if cuda 11 is installed for compute capability 8.0
29 | + cc_flag = []
30 | + _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
31 | + cpp_extension.CUDA_HOME
32 | + )
33 | + if int(bare_metal_major) >= 11:
34 | cc_flag.append('-gencode')
35 | - cc_flag.append('arch=compute_90,code=sm_90')
36 | + cc_flag.append('arch=compute_80,code=sm_80')
37 | + if int(bare_metal_minor) >= 8:
38 | + cc_flag.append('-gencode')
39 | + cc_flag.append('arch=compute_90,code=sm_90')
40 |
41 | - # Build path
42 | - srcpath = pathlib.Path(__file__).parent.absolute()
43 | - buildpath = srcpath / "build"
44 | - _create_build_dir(buildpath)
45 | + # Build path
46 | + srcpath = pathlib.Path(__file__).parent.absolute()
47 | + buildpath = srcpath / "build"
48 | + _create_build_dir(buildpath)
49 |
50 | # Helper function to build the kernels.
51 | def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
52 |
--------------------------------------------------------------------------------
/docs/en/index.rst:
--------------------------------------------------------------------------------
1 | slime Documentation
2 | ====================
3 |
4 | slime is an LLM post-training framework for RL scaling, providing two core capabilities:
5 |
6 | - High-Performance Training: Supports efficient training in various modes by connecting Megatron with SGLang;
7 | - Flexible Data Generation: Enables arbitrary training data generation workflows through custom data generation interfaces and server-based engines.
8 |
9 | slime is the RL framework behind GLM-4.5 and GLM-4.6. Apart from models from Z.ai, slime also supports the following models:
10 |
11 | - Qwen3 series (Qwen3Next, Qwen3MoE, Qwen3), Qwen2.5 series;
12 | - DeepSeek V3 series (DeepSeek V3, V3.1, DeepSeek R1);
13 | - Llama 3.
14 |
15 | .. toctree::
16 | :maxdepth: 1
17 | :caption: Get Started
18 |
19 | get_started/quick_start.md
20 | get_started/usage.md
21 | get_started/qa.md
22 |
23 | .. toctree::
24 | :maxdepth: 1
25 | :caption: Dense
26 |
27 | examples/qwen3-4B.md
28 | examples/glm4-9B.md
29 |
30 | .. toctree::
31 | :maxdepth: 1
32 | :caption: MoE
33 |
34 | examples/qwen3-30B-A3B.md
35 | examples/glm4.5-355B-A32B.md
36 | examples/deepseek-r1.md
37 |
38 | .. toctree::
39 | :maxdepth: 1
40 | :caption: Advanced Features
41 |
42 | _examples_synced/reproducibility/README.md
43 | advanced/speculative-decoding.md
44 | advanced/fault-tolerance.md
45 | advanced/arch-support-beyond-megatron.md
46 |
47 | .. toctree::
48 | :maxdepth: 1
49 | :caption: Other Usage
50 |
51 | examples/qwen3-4b-base-openhermes.md
52 | _examples_synced/search-r1/README.md
53 | _examples_synced/fully_async/README.md
54 | _examples_synced/retool/README.md
55 | _examples_synced/multi_agent/README.md
56 |
57 | .. toctree::
58 | :maxdepth: 1
59 | :caption: Developer Guide
60 |
61 | developer_guide/debug.md
62 |
63 | .. toctree::
64 | :maxdepth: 1
65 | :caption: Hardware Platforms
66 |
67 | platform_support/amd_tutorial.md
68 |
69 | .. toctree::
70 | :maxdepth: 1
71 | :caption: Blogs
72 |
73 | blogs/release_v0.1.0.md
74 | blogs/introducing_slime.md
75 |
--------------------------------------------------------------------------------
/docs/zh/developer_guide/debug.md:
--------------------------------------------------------------------------------
1 | # Debug 指南
2 |
3 | ## 对齐精度
4 |
5 | 在开发 slime 的过程中,经常会需要检查模型的精度是否正确,可以通过以下方式检查:
6 |
7 | 1. 训练第一步
8 | 1. rollout 的生成是否是人话,如果不是,有以下 2 种可能:
9 | - 参数没有正常加载。需要查看是否有 megatron 成功加载 ckpt 的日志;
10 | - 更新参数有误。可以查看是不是所有的参数都做了转换和参数对应,或者参数名是不是根据并行做了转换(例如 pp_size > 1 时,第二个 stage 提供的参数的 layer id 是不是正确的)。一个比较彻底的方法是在对应模型的 sglang 实现的 `load_weights` 中保存所有的参数,查看和加载的 ckpt 中是否一致;
11 | - 如果所有参数更新都正确,还出现问题,有可能是 sglang 里有一些特殊的 buffer 在 release 的时候被释放了;
12 | - 如果是用 pretrain 模型进行的测试,可以换成同结构模型的 instruct 版本,查看这种乱码是不是 pretrain 模型特有的。
13 | 2. 查看打印的 rollout stats 的 `log_probs` 和 `ref_log_probs` 是否完全相等(即第一步 kl=0),且值较小
14 | - 如果不是完全相等的,一般是 transformer engine 中的某些 non-deterministic kernel 导致的,例如:
15 | - 在某些版本的 te 里,megatron 需要 `--attention-backend flash`,来强制使用 flash attention,从而避免 CP 下 fused attention 的数值不稳定;
16 | - 如果数值较大(例如 >1),一般有 2 种可能:
17 | - 如果值非常大,应该是训练配置有问题;
18 | - 如果值只是比 sft loss 的状态略大,例如 instruct 模型的 logprob 到了 0.8,有可能是数据不符合训练的 chat template,或者不符合冷启动的分布。
19 | 3. 查看在推一训一(`num_steps_per_rollout == 1`),kl 是否为 0,grad_norm 是否较小
20 | - 基本上就是一些 megatron / te 相关的 bug,例如:
21 | - moe 需要开启 `--moe-permute-fusion`。
22 |
23 | 2. 训练第二步
24 | 1. 对于训推一体,查看是否能正确加载第二步,是否会 OOM;
25 |
26 | ## 训练推理单独 debug
27 |
28 | slime 支持将训练部分和推理部分分开进行调试,从而实现:
29 |
30 | - 在调优/debug 推理部分时,只用少量卡就可以启动任务;
31 | - 在调优/debug 训练部分时,可以保证模型输入固定,去除 rollout 的随机性。
32 |
33 | 具体来说,目前 slime 提供了如下的参数来进行分离调试:
34 |
35 | 1. `--debug-rollout-only`
36 |
37 | 开启后,slime 将不会加载 megatron,只初始化 sglang ,可以用这个方法来进行推理部分的调试。
38 |
39 | 1. `--debug-train-only`
40 |
41 | 开启后,slime 将不会加载 sglang,只初始化 megatron ,可以用这个方法来进行训练部分的调试。
42 |
43 | 2. `--save-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt`
44 |
45 | 开启后,会保存每次 rollout 的结果,可以和 `--debug-rollout-only` 配合使用。注意保存的方式为 `args.save_debug_rollout_data.format(rollout_id=rollout_id)`。
46 |
47 | 3. `--load-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt`
48 |
49 | 开启后,会从 `args.load_debug_rollout_data.format(rollout_id=rollout_id)` 来加载数据,并且不会初始化 sglang(自动设置 `debug_train_only=True`)。可以以这种方式来固定训练部分的输入,对训练部分进行调优,例如切换各种并行。
50 |
--------------------------------------------------------------------------------
/slime/utils/train_metric_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from argparse import Namespace
3 | from collections.abc import Callable
4 | from copy import deepcopy
5 |
6 | from slime.utils import tracking_utils
7 | from slime.utils.metric_utils import compute_rollout_step
8 | from slime.utils.timer import Timer
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def log_perf_data_raw(
14 | rollout_id: int, args: Namespace, is_primary_rank: bool, compute_total_fwd_flops: Callable
15 | ) -> None:
16 | timer_instance = Timer()
17 | log_dict_raw = deepcopy(timer_instance.log_dict())
18 | timer_instance.reset()
19 |
20 | if not is_primary_rank:
21 | return
22 |
23 | log_dict = {f"perf/{key}_time": val for key, val in log_dict_raw.items()}
24 |
25 | if ("perf/actor_train_time" in log_dict) and (compute_total_fwd_flops is not None):
26 | total_fwd_flops = compute_total_fwd_flops(seq_lens=timer_instance.seq_lens)
27 |
28 | if "perf/log_probs_time" in log_dict:
29 | log_dict["perf/log_probs_tflops"] = total_fwd_flops / log_dict["perf/log_probs_time"]
30 |
31 | if "perf/ref_log_probs_time" in log_dict:
32 | log_dict["perf/ref_log_probs_tflops"] = total_fwd_flops / log_dict["perf/ref_log_probs_time"]
33 |
34 | if log_dict["perf/actor_train_time"] > 0:
35 | log_dict["perf/actor_train_tflops"] = 3 * total_fwd_flops / log_dict["perf/actor_train_time"]
36 | log_dict["perf/actor_train_tok_per_s"] = sum(timer_instance.seq_lens) / log_dict["perf/actor_train_time"]
37 |
38 | if "perf/train_wait_time" in log_dict and "perf/train_time" in log_dict:
39 | total_time = log_dict["perf/train_wait_time"] + log_dict["perf/train_time"]
40 | if total_time > 0:
41 | log_dict["perf/step_time"] = total_time
42 | log_dict["perf/wait_time_ratio"] = log_dict["perf/train_wait_time"] / total_time
43 |
44 | logger.info(f"perf {rollout_id}: {log_dict}")
45 |
46 | step = compute_rollout_step(args, rollout_id)
47 | log_dict["rollout/step"] = step
48 | tracking_utils.log(args, log_dict, step_key="rollout/step")
49 |
--------------------------------------------------------------------------------
/examples/eval/nemo_skills/skills_config.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Mapping
4 | from dataclasses import dataclass, field
5 | from typing import Any
6 |
7 | from examples.eval.eval_delegate import EvalEnvConfig, EvalEnvDatasetConfig
8 |
9 |
10 | @dataclass
11 | class SkillsEvalEnvDatasetConfig(EvalEnvDatasetConfig):
12 | """Dataset configuration shared by the Skills client/server."""
13 |
14 | def __post_init__(self):
15 | name = (self.name or "").strip()
16 | self.name = name
17 | if not name:
18 | raise ValueError("Each Skills dataset entry must include a non-empty `name`.")
19 | if ":" in name:
20 | raise ValueError(
21 | "Colon in dataset name is not allowed; use `n_samples_per_eval_prompt` to configure samples per prompt."
22 | )
23 |
24 | @property
25 | def runtime_name(self) -> str:
26 | if self.n_samples_per_eval_prompt is None:
27 | return self.name
28 | return f"{self.name}:{self.n_samples_per_eval_prompt}"
29 |
30 | @classmethod
31 | def parse(cls, args, dataset_cfg: Mapping[str, Any], defaults: Mapping[str, Any]):
32 | return super().parse(args, dataset_cfg, defaults)
33 |
34 |
35 | @dataclass
36 | class SkillsEvalEnvConfig(EvalEnvConfig):
37 | """Environment configuration shared by the Skills client/server."""
38 |
39 | datasets: list[SkillsEvalEnvDatasetConfig] = field(default_factory=list)
40 |
41 | @classmethod
42 | def parse(cls, args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]) -> SkillsEvalEnvConfig:
43 | base_cfg: SkillsEvalEnvConfig = super().parse(raw_env_config, defaults)
44 | datasets = raw_env_config.get("datasets") or []
45 | base_cfg.datasets = [
46 | SkillsEvalEnvDatasetConfig.parse(args, dataset_cfg, base_cfg.defaults) for dataset_cfg in datasets
47 | ]
48 | return base_cfg
49 |
50 |
51 | def build_skills_eval_env_config(args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]):
52 | return SkillsEvalEnvConfig.parse(args, raw_env_config, defaults)
53 |
--------------------------------------------------------------------------------
/slime/backends/fsdp_utils/models/qwen3_moe_hf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def apply_fsdp_moe_patch():
6 |
7 | from transformers.models.qwen3_moe import modeling_qwen3_moe
8 |
9 | def _forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
10 | batch_size, sequence_length, hidden_dim = hidden_states.shape
11 | hidden_states = hidden_states.view(-1, hidden_dim)
12 | router_logits = self.gate(hidden_states)
13 |
14 | routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
15 | routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
16 | if self.norm_topk_prob:
17 | routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
18 | routing_weights = routing_weights.to(hidden_states.dtype)
19 |
20 | final_hidden_states = torch.zeros(
21 | (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
22 | )
23 |
24 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
25 |
26 | # Loop over all experts
27 | for expert_idx in range(self.num_experts):
28 | expert_layer = self.experts[expert_idx]
29 | idx, top_x = torch.where(expert_mask[expert_idx])
30 |
31 | if top_x.numel() > 0:
32 | current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
33 | current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
34 | final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
35 | else:
36 | # force experts to participate in computation graph
37 | dummy_output = expert_layer(hidden_states[:1]) * 0.0
38 | final_hidden_states[:1] = final_hidden_states[:1] + dummy_output
39 |
40 | final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
41 | return final_hidden_states, router_logits
42 |
43 | modeling_qwen3_moe.Qwen3MoeSparseMoeBlock.forward = _forward
44 |
--------------------------------------------------------------------------------
/tests/test_chunked_gae.py:
--------------------------------------------------------------------------------
1 | import time
2 | import pytest
3 | import torch
4 |
5 | from slime.utils.ppo_utils import chunked_gae, vanilla_gae
6 |
7 |
8 | @pytest.mark.parametrize(
9 | "B,T",
10 | [
11 | (16, 4096),
12 | (32, 8192),
13 | (256, 128 * 1024),
14 | ],
15 | )
16 | @pytest.mark.parametrize("chunk_size", [64, 128, 256])
17 | def test_gae_parallel_matches_serial(B, T, chunk_size):
18 | """
19 | Test that chunked_gae (parallel-scan) matches vanilla_gae (batch-serial)
20 |     under various shapes and chunk sizes.
21 | """
22 | device = "cuda" if torch.cuda.is_available() else "cpu"
23 | torch.manual_seed(0)
24 |
25 | rewards = torch.randn(B, T, device=device, dtype=torch.float32)
26 | values = torch.randn(B, T, device=device, dtype=torch.float32)
27 |
28 | gamma, lam = 0.99, 0.95
29 |
30 | # ---------- Serial ----------
31 | if device == "cuda":
32 | torch.cuda.synchronize()
33 | t0 = time.time()
34 | adv_s, ret_s = vanilla_gae(rewards, values, gamma, lam)
35 | if device == "cuda":
36 | torch.cuda.synchronize()
37 | t1 = time.time()
38 | serial_time = t1 - t0
39 |
40 | # ---------- Parallel-scan ----------
41 | if device == "cuda":
42 | torch.cuda.synchronize()
43 | t0 = time.time()
44 | adv_p, ret_p = chunked_gae(rewards, values, gamma, lam, chunk_size=chunk_size)
45 | if device == "cuda":
46 | torch.cuda.synchronize()
47 | t1 = time.time()
48 | parallel_time = t1 - t0
49 |
50 | # ---------- Accuracy ----------
51 | adv_err = (adv_s - adv_p).abs().max().item()
52 | ret_err = (ret_s - ret_p).abs().max().item()
53 |
54 | atol = 1e-5
55 | assert adv_err < atol, f"adv error too large: {adv_err}"
56 | assert ret_err < atol, f"ret error too large: {ret_err}"
57 |
58 | # ---------- logging ----------
59 | print(f"\n[GAE Test] B={B}, T={T}, chunk={chunk_size}")
60 | print(f" Serial : {serial_time:.6f} s")
61 | print(f" Parallel : {parallel_time:.6f} s")
62 | print(f" Speedup : x{serial_time / parallel_time:.2f}")
63 | print(f" Max diff adv={adv_err:.3e}, ret={ret_err:.3e}")
64 |
--------------------------------------------------------------------------------
/examples/strands-agents/README.md:
--------------------------------------------------------------------------------
1 | # Slime x Strands-Agents
2 |
3 | This is a running example that connects the [Strands-Agents](https://github.com/strands-agents/sdk-python) agent scaffolding framework with Slime for RL training.
4 |
5 | ## Install Dependencies
6 |
7 | 1. Pull the `slimerl/slime:latest` image and enter it
8 | 2. Go to the slime folder: `cd /root/slime` (clone the repository first if it is not already there: `cd /root && git clone https://github.com/THUDM/slime.git`)
9 | 3. Install Slime: `pip install -e .`
10 | 4. Go to the example folder: `cd /root/slime/examples/strands-agents`
11 | 5. Install other dependencies: `pip install -r requirements.txt`
12 |
13 | > NOTE: we use camel-ai's subprocess code interpreter for Python code execution, which is NOT good practice; it is used here only for the convenience of this example, and the dependencies for solving math problems are usually already available in `slime`'s Docker image.
14 |
15 | ## Prepare Model
16 |
17 | ```bash
18 | # hf checkpoint
19 | huggingface-cli download Qwen/Qwen3-4B-Instruct-2507 --local-dir /root/models/Qwen/Qwen3-4B-Instruct-2507
20 |
21 | # mcore checkpoint
22 | cd /root/slime
23 | source scripts/models/qwen3-4B.sh
24 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
25 | ${MODEL_ARGS[@]} \
26 | --hf-checkpoint /root/models/Qwen/Qwen3-4B-Instruct-2507 \
27 | --save /root/models/Qwen/Qwen3-4B-Instruct-2507_torch_dist
28 | ```
29 |
30 | ## Prepare Dataset
31 |
32 | Following [Retool](https://arxiv.org/abs/2504.11536), we used `dapo-math-17k` as training data:
33 |
34 | ```python
35 | from datasets import load_dataset
36 | ds = load_dataset("zhuzilin/dapo-math-17k", split="train")
37 | ds.to_json("/root/data/dapo-math-17k.jsonl", orient="records", lines=True)
38 | ```
39 |
40 | and `aime-2024` as eval data:
41 |
42 | ```python
43 | from datasets import load_dataset
44 | ds = load_dataset("zhuzilin/aime-2024", split="train")
45 | ds.to_json("/root/data/aime-2024.jsonl", orient="records", lines=True)
46 | ```
47 |
48 | ## Run Training
49 |
50 | Assuming `/root/slime` is up to date (if this PR has not been merged yet, you may need to switch to the corresponding branch):
51 |
52 | ```bash
53 | cd /root/slime
54 | export WANDB_KEY=$your_wandb_key
55 | bash examples/strands-agents/strands_qwen3_4b.sh
56 | ```
57 |
--------------------------------------------------------------------------------
/slime_plugins/rollout_buffer/README.md:
--------------------------------------------------------------------------------
1 | # Rollout Buffer
2 |
3 | ## Overview
4 |
5 | Rollout Buffer is an independent component for asynchronous agent trajectory generation. Its main job is to use the LLM OpenAI server launched by slime training to generate agent trajectories.
6 |
7 | ### Workflow
8 |
9 | ```
10 | slime Training Process ←─── HTTP API ───→ Rollout Buffer
11 | ↓ ↓
12 | LLM Server ←─────── HTTP Requests ─────── Agent Framework
13 | ↓ ↓
14 | Model Response ──────────────────────→ Trajectory Generation
15 | ```
16 |
17 | Each agent task should have its own independent Generator class responsible for generating trajectories for that task type. Rollout Buffer automatically discovers and loads the different Generators.
18 |
19 | ## Quick Start
20 |
21 | ### Basic Usage Process
22 |
23 | 1. **Copy Template**: Copy `base_generator.py` as a template
24 | 2. **Modify Task Type**: Change `TASK_TYPE` to your task name (it must not duplicate any other Generator's)
25 | 3. **Implement Core Function**: Implement the `run_rollout()` function
26 | 4. **Optional Customization**: Override the five optional functions as needed
27 |
28 |
29 | Generator files must end with `_generator.py` and be placed in the `generator/` directory:
30 |
31 | ```
32 | generator/
33 | ├── base_generator.py # Math task implementation (default template)
34 | └── your_task_generator.py # Your custom task
35 | ```
36 |
37 | Each Generator file must define `TASK_TYPE` and `run_rollout()`.
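
For illustration only, a hypothetical `your_task_generator.py` might look like the sketch below; the exact `run_rollout()` signature, payload fields, and return format are defined by `base_generator.py`, so treat this purely as a shape reference:

```python
# generator/your_task_generator.py -- illustrative sketch, not the real template.
# Follow base_generator.py for the signature and return format Rollout Buffer expects.
import requests

TASK_TYPE = "your_task"  # must be unique across all generators


def run_rollout(data: dict, llm_server_url: str, **kwargs):
    """Generate one trajectory by querying the OpenAI-compatible server slime launches."""
    resp = requests.post(
        f"{llm_server_url}/v1/chat/completions",
        json={
            "model": "default",
            "messages": [{"role": "user", "content": data["prompt"]}],
        },
        timeout=300,
    )
    resp.raise_for_status()
    answer = resp.json()["choices"][0]["message"]["content"]
    # Attach your task-specific reward here before returning the trajectory.
    return {"prompt": data["prompt"], "response": answer, "reward": 0.0}
```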
38 |
39 | In addition, Rollout Buffer provides several customizable functions to cover task-specific needs. If no custom implementation is provided, the system falls back to the default implementations in `slime_plugins/rollout_buffer/default_func.py`.
40 |
41 | ### Example Script
42 |
43 | First, follow [Example: Qwen3-4B Model](../../docs/en/models/qwen3-4B.md) to configure the environment, download the data, and convert the model checkpoint. Then run the following scripts:
44 | ```bash
45 | cd slime_plugins/rollout_buffer
46 | bash rollout_buffer_example.sh
47 |
48 | # In a different terminal
49 | python buffer.py
50 | ```
51 |
--------------------------------------------------------------------------------
/examples/tau-bench/README.md:
--------------------------------------------------------------------------------
1 | # Tau bench
2 | This example shows slime training in an agentic multi-turn tool use environment.
3 |
4 |
5 | ## Environment Setup
6 | Use the `zhuzilin/slime:latest` image and initialize the environment required for tau-bench:
7 |
8 | ```bash
9 | cd /root/
10 | git clone https://github.com/THUDM/slime.git
11 | cd slime
12 | pip install -e .
13 | # for tau bench
14 | cd /root/
15 | git clone https://github.com/JD-ETH/tau-bench.git
16 | cd tau-bench
17 | git checkout feature/litellm-retry
18 | pip install -e .
19 | ```
20 |
21 | Use the following script to generate mock data for slime training.
22 |
23 | ```bash
24 | cd /root/slime/examples/tau-bench
25 | python tau1_mock.py --local_dir /root/tau-bench/
26 | ```
27 |
28 | Initialize the Qwen3-4B-Instruct-2507 model needed for tool use:
29 |
30 | ```bash
31 | # hf checkpoint
32 | huggingface-cli download Qwen/Qwen3-4B-Instruct-2507 --local-dir /root/Qwen3-4B-Instruct-2507
33 |
34 | # mcore checkpoint
35 | cd /root/slime
36 | source scripts/models/qwen3-4B-Instruct-2507.sh
37 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
38 | ${MODEL_ARGS[@]} \
39 | --hf-checkpoint /root/Qwen3-4B-Instruct-2507 \
40 | --save /root/Qwen3-4B-Instruct-2507_torch_dist
41 | ```
42 |
43 | ## Running the Script
44 |
45 | You need to configure your litellm API in `generate_with_tau.py` for user simulation:
46 |
47 | ```python
48 | TAU_CONFIGS = {
49 | "env": "retail", # Select between ["retail", "airline"]
50 | "agent": "tool-calling", # Select between ["tool-calling", "act", "react", "few-shot"], only tool-calling implemented for now
51 | "user_model": "gemini-2.0-flash-lite", # Cheap Model for user simulator
52 | "user_model_provider": "gemini",
53 | "task_split": "train", # Select between ["train", "test", "dev"] for retail, ["test"] for airline
54 | "user_strategy": "llm", # Select between ["llm", "react", "verify", "reflection"]
55 | "model_provider": "auto_router", # Unused, required
56 |     "model": "qwen3-4b", # Unused, required
57 | }
58 | # Replace with your actual API key for user sim
59 | GEMINI_API_KEY = "YOUR KEY"
60 | ```
61 |
62 | And run:
63 |
64 |
65 | ```bash
66 | cd /root/slime
67 | bash examples/tau-bench/run_qwen3_4B.sh
68 | ```
--------------------------------------------------------------------------------
/examples/eval/README.md:
--------------------------------------------------------------------------------
1 | # NeMo Skills Evaluation
2 |
3 | ## Prerequisites
4 | - A writable host directory for cached data (`/data/.cache`)
5 | - Choose descriptive container names to replace the placeholders (`<slime_container>`, `<skills_container>`) in the commands below.
6 |
7 | ## 1) Prepare host network
8 | ```bash
9 | docker network create skills-net
10 | ```
11 |
12 | ## 2) Launch the slime container
13 | ```bash
14 | docker run \
15 | -itd \
16 | --shm-size 32g \
17 | --gpus all \
18 | -v /data/.cache:/root/.cache \
19 | -v /dev/shm:/shm \
20 | --ipc=host \
21 | --privileged \
22 | --network skills-net \
23 |     --name <slime_container> \
24 | slimerl/slime:latest \
25 | /bin/bash
26 | ```
27 |
28 | ## 3) Launch the Skills container
29 | ```bash
30 | docker run \
31 | -itd \
32 | --shm-size 32g \
33 | --gpus all \
34 | -v /data/.cache:/root/.cache \
35 | -v /dev/shm:/shm \
36 | --ipc=host \
37 | --privileged \
38 | --network skills-net \
39 |     --name <skills_container> \
40 | --network-alias skills_server \
41 | guapisolo/nemoskills:0.7.1 \
42 | /bin/bash
43 | ```
44 |
45 | ## 4) Inside the Skills container
46 | Clone repos and install the Skills package:
47 | ```bash
48 | git clone -b slime_skills https://github.com/guapisolo/slime.git /opt/slime
49 | git clone -b slime https://github.com/guapisolo/Skills.git /opt/Skills
50 |
51 | cd /opt/Skills
52 | pip install -e .
53 | ```
54 |
55 | Download/prepare datasets:
56 | ```bash
57 | cd /opt/Skills/nemo_skills/dataset
58 | python3 aime25/prepare.py
59 | python3 hle/prepare.py
60 | python3 arena-hard/prepare.py
61 | ```
62 |
63 | Start the skills server:
64 | ```bash
65 | cd /opt/slime
66 | python examples/eval/nemo_skills/skills_server.py \
67 | --host 0.0.0.0 \
68 | --port 9050 \
69 | --output-root /opt/skills-eval \
70 | --config-dir examples/eval/nemo_skills/config \
71 | --cluster local_cluster \
72 | --max-concurrent-requests 512 \
73 | --openai-model-name slime-openai-model
74 | ```
75 |
76 | You can now connect to the server at `skills_server:9050` from within the `skills-net` Docker network. The server always proxies evaluation traffic to an OpenAI-compatible sglang router (slime starts and manages the router), so adjust `--openai-model-name` and `--max-concurrent-requests` as needed for your deployment.
77 |
--------------------------------------------------------------------------------
/examples/fully_async/README.md:
--------------------------------------------------------------------------------
1 | ## Fully Asynchronous Rollout Example
2 |
3 | This example shows a simple way to make rollout generation **fully asynchronous**: a single global worker is created once and then keeps running in the background, continuously pulling prompts and launching generation tasks. Training only needs to fetch already finished results. This removes the per‑step wait that happens in the normal synchronous style.
4 |
5 | ### Files
6 | * `fully_async_rollout.py`: global async worker + `generate_rollout_fully_async` entry.
7 | * `run-qwen3-4b-fully_async.sh`: example launch script with Qwen3‑4B.
8 |
9 | ### Prerequisite
10 | First set up model & environment following the Qwen3-4B example.
11 |
12 | ### Quick Start
13 | ```bash
14 | cd slime
15 | bash examples/fully_async/run-qwen3-4b-fully_async.sh
16 | ```
17 | You should see log lines like:
18 | ```
19 | Creating new global async worker...
20 | Continuous async rollout worker started
21 | ```
22 |
23 | ### How It Works (Very Short)
24 | * First call: create `AsyncRolloutWorker` (thread + asyncio loop).
25 | * Loop keeps up to `--rollout-batch-size` tasks in flight using `generate_and_rm_group`.
26 | * Completed groups are pushed into a queue; caller drains until it has enough samples.
27 | * Worker is stopped automatically at process exit.
28 |
29 | ### Limitations
30 | * No evaluation mode.
31 | * Ordering is best effort (sorted at the end by index).
32 | * Minimal error handling.
33 |
34 | ### Config Differences (2 Key Points)
35 | To enable the fully async pattern there are only two changes compared to a normal run:
36 |
37 | 1. Use the async training driver: `train_async.py` (not `train.py`).
38 | 2. Set the rollout function path:
39 | ```bash
40 | --rollout-function-path fully_async_rollout.generate_rollout_fully_async
41 | ```
42 |
43 | Why is it still "fully" async although `train_async.py` itself schedules rollouts step‑by‑step?
44 |
45 | Because the real generation work is done by a **persistent background worker** created in `generate_rollout_fully_async`. Each call from `train_async.py` only drains already completed samples from the worker's output queue; the worker has been generating continuously since the first call. Thus rollout production (model inference) and training consumption proceed in parallel with minimal waiting.
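
Below is a minimal, self-contained sketch of this producer/consumer pattern (illustrative only: `AsyncWorkerSketch` and its methods are hypothetical and do not call slime APIs):

```python
import asyncio
import queue
import threading


class AsyncWorkerSketch:
    """Persistent worker: a daemon thread runs an asyncio loop that keeps tasks in flight
    and pushes finished results into a queue; callers only drain what is already done."""

    def __init__(self, max_in_flight: int = 8):
        self.results: queue.Queue = queue.Queue()
        self._max_in_flight = max_in_flight
        self._stop = threading.Event()
        self._thread = threading.Thread(target=lambda: asyncio.run(self._loop()), daemon=True)
        self._thread.start()

    async def _loop(self):
        in_flight, index = set(), 0
        while not self._stop.is_set():
            # Keep up to `max_in_flight` generation tasks running at all times.
            while len(in_flight) < self._max_in_flight:
                in_flight.add(asyncio.create_task(self._generate_one(index)))
                index += 1
            done, in_flight = await asyncio.wait(in_flight, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                self.results.put(task.result())

    async def _generate_one(self, index: int):
        await asyncio.sleep(0.01)  # stand-in for the real generate_and_rm_group() call
        return {"index": index}

    def drain(self, n: int):
        # Training-side call: block only until n already-produced samples are available.
        return sorted((self.results.get() for _ in range(n)), key=lambda s: s["index"])

    def stop(self):
        self._stop.set()
```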
46 |
--------------------------------------------------------------------------------
/slime/utils/debug_utils/send_to_sglang.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | from typing import Annotated
4 |
5 | import typer
6 | from openai import AsyncOpenAI
7 |
8 | from slime.utils.data import read_file
9 |
10 |
11 | # can unify w/ sglang_rollout.py later, e.g. add RM, if needed
12 | def main(
13 | prompt_data: Annotated[str, typer.Option()],
14 | url: Annotated[str, typer.Option()] = "http://localhost:30000/v1",
15 | input_key: Annotated[str, typer.Option()] = "input",
16 | n_samples_per_prompt: Annotated[int, typer.Option()] = 1,
17 | rollout_max_response_len: Annotated[int, typer.Option()] = 1024,
18 | rollout_temperature: Annotated[float, typer.Option()] = 1.0,
19 | rollout_top_p: Annotated[float, typer.Option()] = 1.0,
20 | ):
21 | """
22 | Minimally send prompts to SGLang using OpenAI endpoints with arguments in the same format as main Slime.
23 |
24 | Example usage:
25 | python -m slime.utils.debug_utils.send_to_sglang --prompt-data /root/datasets/aime-2024/aime-2024.jsonl --input-key prompt --n-samples-per-prompt 16 --rollout-max-response-len 32768 --rollout-temperature 0.8 --rollout-top-p 0.7
26 | """
27 |
28 | async def _main_async():
29 | tasks = [
30 | asyncio.create_task(_run_one(row, row_index=row_index, repeat_index=repeat_index))
31 | for row_index, row in enumerate(read_file(prompt_data))
32 | for repeat_index in range(n_samples_per_prompt)
33 | ]
34 | outputs = await asyncio.gather(*tasks)
35 | for output in outputs:
36 | print(json.dumps(output))
37 |
38 | async def _run_one(row, row_index: int, repeat_index: int):
39 | resp = await client.chat.completions.create(
40 | messages=row[input_key],
41 | model="dummy_model",
42 | max_tokens=rollout_max_response_len,
43 | temperature=rollout_temperature,
44 | top_p=rollout_top_p,
45 | )
46 | return dict(
47 | row_index=row_index,
48 | repeat_index=repeat_index,
49 | **row,
50 | response=resp.choices[0].message.content,
51 | )
52 |
53 | client = AsyncOpenAI(api_key="dummy_key", base_url=url)
54 | asyncio.run(_main_async())
55 |
56 |
57 | if __name__ == "__main__":
58 | typer.run(main)
59 |
--------------------------------------------------------------------------------
/slime/utils/tensorboard_utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import os
4 | from slime.utils.misc import SingletonMeta
5 |
6 | try:
7 | from torch.utils.tensorboard import SummaryWriter
8 | except ImportError:
9 | SummaryWriter = None
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class _TensorboardAdapter(metaclass=SingletonMeta):
15 |     """Singleton wrapper around SummaryWriter: every call on a rank returns the same instance.
16 | 
17 |     Usage example:
18 |         tb = _TensorboardAdapter(args)  # initialized on the first call
19 |         tb.log({"Loss": 0.1}, step=1)
20 | 
21 |     In other files:
22 |         from tensorboard_utils import _TensorboardAdapter
23 |         tb = _TensorboardAdapter(args)  # no new writer is created; the existing instance is returned
24 |         tb.log({"Accuracy": 0.9}, step=1)
25 |     """
26 |     _writer = None
27 |
28 | def __init__(self, args):
29 | assert args.use_tensorboard, f"{args.use_tensorboard=}"
30 | tb_project_name = args.tb_project_name
31 | tb_experiment_name = args.tb_experiment_name
32 | if tb_project_name is not None or os.environ.get("TENSORBOARD_DIR", None):
33 | if tb_project_name is not None and tb_experiment_name is None:
34 | tb_experiment_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
35 | self._initialize(tb_project_name, tb_experiment_name)
36 | else:
37 | raise ValueError("tb_project_name and tb_experiment_name, or TENSORBOARD_DIR are required")
38 |
39 | def _initialize(self, tb_project_name, tb_experiment_name):
40 | """Actual initialization logic"""
41 | # Get tensorboard directory from environment variable or use default path
42 | tensorboard_dir = os.environ.get("TENSORBOARD_DIR", f"tensorboard_log/{tb_project_name}/{tb_experiment_name}")
43 | os.makedirs(tensorboard_dir, exist_ok=True)
44 | logger.info(f"Saving tensorboard log to {tensorboard_dir}.")
45 | self._writer = SummaryWriter(tensorboard_dir)
46 |
47 | def log(self, data, step):
48 | """Log data to tensorboard
49 |
50 | Args:
51 | data (dict): Dictionary containing metric names and values
52 | step (int): Current step/epoch number
53 | """
54 | for key in data:
55 | self._writer.add_scalar(key, data[key], step)
56 |
57 | def finish(self):
58 | """Close the tensorboard writer"""
59 | self._writer.close()
60 |
--------------------------------------------------------------------------------
/slime/utils/debug_utils/display_debug_rollout_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from types import SimpleNamespace
4 | from typing import Annotated
5 |
6 | import torch
7 | import typer
8 |
9 | from slime.ray.rollout import compute_metrics_from_samples
10 | from slime.utils.types import Sample
11 |
12 | _WHITELIST_KEYS = [
13 | "group_index",
14 | "index",
15 | "prompt",
16 | "response",
17 | "response_length",
18 | "label",
19 | "reward",
20 | "status",
21 | "metadata",
22 | ]
23 |
24 |
25 | def main(
26 | # Deliberately make this name consistent with main training arguments
27 | load_debug_rollout_data: Annotated[str, typer.Option()],
28 | show_metrics: bool = True,
29 | show_samples: bool = True,
30 | category: list[str] = None,
31 | ):
32 | if category is None:
33 | category = ["train", "eval"]
34 | for rollout_id, path in _get_rollout_dump_paths(load_debug_rollout_data, category):
35 | print("-" * 80)
36 | print(f"{rollout_id=} {path=}")
37 | print("-" * 80)
38 |
39 | pack = torch.load(path)
40 | sample_dicts = pack["samples"]
41 |
42 | if show_metrics:
43 | # TODO read these configs from dumps
44 | args = SimpleNamespace(
45 | advantage_estimator="grpo",
46 | reward_key=None,
47 | log_reward_category=None,
48 | )
49 | sample_objects = [Sample.from_dict(s) for s in sample_dicts]
50 | metrics = compute_metrics_from_samples(args, sample_objects)
51 | print("metrics", metrics)
52 |
53 | if show_samples:
54 | for sample in sample_dicts:
55 | print(json.dumps({k: v for k, v in sample.items() if k in _WHITELIST_KEYS}))
56 |
57 |
58 | def _get_rollout_dump_paths(load_debug_rollout_data: str, categories: list[str]):
59 | # may improve later
60 | for rollout_id in range(1000):
61 | for category in categories:
62 | prefix = {
63 | "train": "",
64 | "eval": "eval_",
65 | }[category]
66 | path = Path(load_debug_rollout_data.format(rollout_id=f"{prefix}{rollout_id}"))
67 | if path.exists():
68 | yield rollout_id, path
69 |
70 |
71 | if __name__ == "__main__":
72 | """python -m slime.utils.debug_utils.display_debug_rollout_data --load-debug-rollout-data ..."""
73 | typer.run(main)
74 |
--------------------------------------------------------------------------------
/slime/utils/timer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from contextlib import contextmanager
3 | from functools import wraps
4 | from time import time
5 |
6 | import torch.distributed
7 |
8 | from .misc import SingletonMeta
9 |
10 | __all__ = ["Timer", "timer"]
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class Timer(metaclass=SingletonMeta):
16 | def __init__(self):
17 | self.timers = {}
18 | self.start_time = {}
19 |
20 | def start(self, name):
21 | assert name not in self.start_time, f"Timer {name} already started."
22 | self.start_time[name] = time()
23 | if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
24 | logger.info(f"Timer {name} start")
25 |
26 | def end(self, name):
27 | assert name in self.start_time, f"Timer {name} not started."
28 | elapsed_time = time() - self.start_time[name]
29 | self.add(name, elapsed_time)
30 | del self.start_time[name]
31 | if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
32 | logger.info(f"Timer {name} end (elapsed: {elapsed_time:.1f}s)")
33 |
34 | def reset(self, name=None):
35 | if name is None:
36 | self.timers = {}
37 | elif name in self.timers:
38 | del self.timers[name]
39 |
40 | def add(self, name, elapsed_time):
41 | self.timers[name] = self.timers.get(name, 0) + elapsed_time
42 |
43 | def log_dict(self):
44 | return self.timers
45 |
46 | @contextmanager
47 | def context(self, name):
48 | self.start(name)
49 | try:
50 | yield
51 | finally:
52 | self.end(name)
53 |
54 |
55 | def timer(name_or_func):
56 | """
57 | Can be used either as a decorator or a context manager:
58 |
59 | @timer
60 | def func():
61 | ...
62 |
63 | or
64 |
65 | with timer("block_name"):
66 | ...
67 | """
68 | # When used as a context manager
69 | if isinstance(name_or_func, str):
70 | name = name_or_func
71 | return Timer().context(name)
72 |
73 | func = name_or_func
74 |
75 | @wraps(func)
76 | def wrapper(*args, **kwargs):
77 | with Timer().context(func.__name__):
78 | return func(*args, **kwargs)
79 |
80 | return wrapper
81 |
82 |
83 | @contextmanager
84 | def inverse_timer(name):
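    """Temporarily pause the running timer `name`: end it on entry and restart it on exit,
    so the time spent inside this block is excluded from `name`'s accumulated total."""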
85 | Timer().end(name)
86 | try:
87 | yield
88 | finally:
89 | Timer().start(name)
90 |
--------------------------------------------------------------------------------
/slime/utils/typer_utils.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import inspect
3 | from typing import Annotated
4 |
5 | import typer
6 |
7 |
8 | def dataclass_cli(func, env_var_prefix: str = "SLIME_SCRIPT_"):
9 | """Modified from https://github.com/fastapi/typer/issues/154#issuecomment-1544876144"""
10 |
11 | # The dataclass type is the first argument of the function.
12 | sig = inspect.signature(func)
13 | param = list(sig.parameters.values())[0]
14 | dataclass_cls = param.annotation
15 | assert dataclasses.is_dataclass(dataclass_cls)
16 |
17 | # To construct the signature, we remove the first argument (self)
18 | # from the dataclass __init__ signature.
19 | signature = inspect.signature(dataclass_cls.__init__)
20 | old_parameters = list(signature.parameters.values())
21 | if len(old_parameters) > 0 and old_parameters[0].name == "self":
22 | del old_parameters[0]
23 |
24 | new_parameters = []
25 | for param in old_parameters:
26 | env_var_name = f"{env_var_prefix}{param.name.upper()}"
27 | new_annotation = Annotated[param.annotation, typer.Option(envvar=env_var_name)]
28 | new_parameters.append(param.replace(annotation=new_annotation))
29 |
30 | def wrapped(**kwargs):
31 | data = dataclass_cls(**kwargs)
32 | print(f"Execute command with args: {data}")
33 | return func(data)
34 |
35 | wrapped.__signature__ = signature.replace(parameters=new_parameters)
36 | wrapped.__doc__ = func.__doc__
37 | wrapped.__name__ = func.__name__
38 | wrapped.__qualname__ = func.__qualname__
39 |
40 | return wrapped
41 |
42 |
43 | # unit test
44 | if __name__ == "__main__":
45 | from typer.testing import CliRunner
46 |
47 | @dataclasses.dataclass
48 | class DemoArgs:
49 | name: str
50 | count: int = 1
51 |
52 | app = typer.Typer()
53 |
54 | @app.command()
55 | @dataclass_cli
56 | def main(args: DemoArgs):
57 | print(f"{args.name}|{args.count}")
58 |
59 | runner = CliRunner()
60 |
61 | res1 = runner.invoke(app, [], env={"SLIME_SCRIPT_NAME": "EnvName", "SLIME_SCRIPT_COUNT": "10"})
62 | print(f"{res1.stdout=}")
63 | assert res1.exit_code == 0
64 | assert "EnvName|10" in res1.stdout.strip()
65 |
66 | res2 = runner.invoke(app, ["--count", "999"], env={"SLIME_SCRIPT_NAME": "EnvName"})
67 | print(f"{res2.stdout=}")
68 | assert res2.exit_code == 0
69 | assert "EnvName|999" in res2.stdout.strip()
70 |
71 | print("✅ All Tests Passed!")
72 |
--------------------------------------------------------------------------------
/slime/utils/fp8_kernel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 |
5 | fp8_dtype = torch.float8_e4m3fn
6 | fp8_max = torch.finfo(fp8_dtype).max
7 | fp8_min = -fp8_max
8 |
9 |
10 | def ceil_div(x: int, y: int) -> int:
11 | """
12 | Perform ceiling division of two integers.
13 |
14 | Args:
15 | x: the dividend.
16 | y: the divisor.
17 |
18 | Returns:
19 | The result of the ceiling division.
20 | """
21 | return (x + y - 1) // y
22 |
23 |
24 | @triton.jit
25 | def _blockwise_cast_to_fp8_triton(
26 | X,
27 | Y,
28 | S,
29 | stride_xm,
30 | stride_xn,
31 | stride_ym,
32 | stride_yn,
33 | stride_sm,
34 | stride_sn,
35 | M,
36 | N,
37 | eps,
38 | fp8_min,
39 | fp8_max,
40 | BLOCK_M: tl.constexpr = 32,
41 | BLOCK_N: tl.constexpr = 128,
42 | ):
43 | pid_m = tl.cast(tl.program_id(axis=0), tl.int64)
44 | pid_n = tl.cast(tl.program_id(axis=1), tl.int64)
45 | off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
46 | off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
47 | mask_m = off_m < M
48 | mask_n = off_n < N
49 | mask = mask_m[:, None] & mask_n[None, :]
50 |
51 | x = tl.load(X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn, mask=mask, other=0.0).to(tl.float32)
52 | _absmax = tl.maximum(tl.max(tl.abs(x)), eps)
53 | x_s = _absmax / fp8_max
54 | s_inv = 1.0 / x_s
55 | y_q = tl.clamp(x * s_inv, fp8_min, fp8_max).to(Y.dtype.element_ty)
56 |
57 | tl.store(Y + off_m[:, None] * stride_ym + off_n[None, :] * stride_yn, y_q, mask=mask)
58 | tl.store(S + pid_m * stride_sm + pid_n * stride_sn, x_s)
59 |
60 |
61 | def blockwise_cast_to_fp8_triton(x: torch.Tensor, block_size=None) -> tuple[torch.Tensor, torch.Tensor]:
62 | BLOCK_M, BLOCK_N = 128, 128
63 | if block_size:
64 | BLOCK_M, BLOCK_N = block_size[0], block_size[1]
65 | M, N = x.shape
66 | y = torch.empty(M, N, device=x.device, dtype=torch.float8_e4m3fn)
67 | s = torch.empty(ceil_div(M, BLOCK_M), ceil_div(N, BLOCK_N), dtype=torch.float32, device=x.device)
68 |
69 | def grid(meta):
70 | return (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]))
71 |
72 | if x.is_contiguous():
73 | kwargs = {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "num_warps": 8, "num_stages": 2}
74 | else:
75 | kwargs = {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "num_warps": 1, "num_stages": 4}
76 | _blockwise_cast_to_fp8_triton[grid](
77 | x, y, s, *x.stride(), *y.stride(), *s.stride(), M, N, 1e-10, fp8_min, fp8_max, **kwargs
78 | )
79 | return y, s
80 |
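# Example usage (sketch; requires a CUDA device with Triton available):
#   x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
#   y, s = blockwise_cast_to_fp8_triton(x)  # y: float8_e4m3fn values, same shape as x
#   # s: float32 scales of shape (ceil(1024/128), ceil(4096/128)), one per 128x128 block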
--------------------------------------------------------------------------------
/slime/ray/utils.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ray/utils.py#L1
2 | import os
3 |
4 | import ray
5 | import torch
6 | from slime.ray.ray_actor import RayActor
7 |
8 |
9 | # Refer to
10 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96
11 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103
12 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95
13 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117
14 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109
15 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172
16 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98
17 | NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [
18 | "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
19 | "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES",
20 | "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES",
21 | "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES",
22 | "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES",
23 | "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS",
24 | "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR",
25 | ]
26 |
27 |
28 | def ray_noset_visible_devices(env_vars=os.environ):
29 | return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST)
30 |
31 |
32 | def get_physical_gpu_id():
33 | device = torch.cuda.current_device()
34 | props = torch.cuda.get_device_properties(device)
35 | return str(props.uuid)
36 |
37 |
38 | @ray.remote
39 | class Lock(RayActor):
40 | def __init__(self):
41 | self._locked = False # False: unlocked, True: locked
42 |
43 | def acquire(self):
44 | """
45 | Try to acquire the lock. Returns True if acquired, False otherwise.
46 | Caller should retry until it returns True.
47 | """
48 | if not self._locked:
49 | self._locked = True
50 | return True
51 | return False
52 |
53 | def release(self):
54 | """Release the lock, allowing others to acquire."""
55 | assert self._locked, "Lock is not acquired, cannot release."
56 | self._locked = False
57 |
--------------------------------------------------------------------------------
/docs/zh/get_started/qa.md:
--------------------------------------------------------------------------------
1 | # Frequently Asked Questions
2 | 
3 | 1. **Why does garbled output appear during training?**
4 | 
5 |    This usually means the Megatron checkpoint was not loaded correctly. Check whether `--load` or `--ref-load` points to a valid ckpt. Note that Megatron can only load a directory that contains `latest_checkpointed_iteration.txt`.
6 | 
7 |    If you need to load a specific iteration, check the usage of your current Megatron version; normally the step can be specified with `--ckpt-step`.
8 | 
9 | 1. **Why does my job hang on the Ray submission page?**
10 | 
11 |    First check whether the job is colocated (training and inference sharing GPUs) or disaggregated.
12 | 
13 |    For colocated training and inference, i.e. training and inference share GPUs, check that
14 | 
15 |    - the `--colocate` flag is set to enable colocation;
16 |    - the total number of GPUs for the job is at least `actor_num_nodes * actor_num_gpus_per_node`
17 | 
18 |    For disaggregated training and inference, check that:
19 | 
20 |    - the total number of GPUs for the job is at least `actor_num_nodes * actor_num_gpus_per_node + rollout_num_gpus`
21 | 
22 | 1. **Why do I hit OOM partway through training? What is `max_tokens_per_gpu` for?**
23 | 
24 |    OOM is usually caused by setting `max_tokens_per_gpu` too high. `max_tokens_per_gpu` is the maximum number of tokens each GPU can hold during training. If you are worried about OOM, first set it to `rollout_max_response_len / cp_size`, and only increase it later to improve training efficiency. `--max-tokens-per-gpu` only takes effect when `--use-dynamic-batch-size` is enabled.
25 | 
26 |    If OOM still occurs even with a small `max_tokens_per_gpu`, check whether a single generation is too long and context parallelism needs to be enabled (`--context-parallel-size`). If you use custom data generation, check whether the total length generated over multiple turns is much longer than expected.
27 | 
28 | 1. **During multi-node training, what should I do if the transformers library reports that it cannot find a model?**
29 | 
30 |    This usually happens because multiple processes read local files at the same time via `AutoConfig.from_pretrained` or `AutoModelForCausalLM.from_pretrained`, causing write conflicts in the file system. Setting `--model-name` can mitigate this.
31 | 
32 | 1. **How do I resume training?**
33 | 
34 |    Simply set `--load` to the directory used for `--save`.
35 | 
36 | 1. **How is the batch size computed?**
37 | 
38 |    One rollout uses `rollout_batch_size` prompts, and each prompt is sampled `n_samples_per_prompt` times, so one rollout contains `rollout_batch_size * n_samples_per_prompt` samples in total.
39 | 
40 |    You can use `--num-steps-per-rollout` to decide how many training steps each rollout takes. This is equivalent to setting `global_batch_size` to `rollout_batch_size * n_samples_per_prompt // num_steps_per_rollout`.
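
   For example (hypothetical numbers): with `rollout_batch_size = 32` and `n_samples_per_prompt = 8`, one rollout produces 32 * 8 = 256 samples; adding `--num-steps-per-rollout 4` makes `global_batch_size` = 256 / 4 = 64.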
41 |
42 | 1. **Does slime do data packing / varlen handling?**
43 | 
44 |    Data packing means concatenating samples of different lengths during training to improve utilization. slime does this by default.
45 | 
46 | 1. **What should I do about `Max retries exceeded with url: /get_model_info (Caused by NewConnectionError` errors from sglang?**
47 | 
48 |    This problem mainly comes from port conflicts caused by multiple sglang servers on one machine; we are still working with the sglang team to resolve it. A temporary mitigation is to reduce the number of sglang servers per machine as much as possible, e.g. by setting tp=8.
49 | 
50 | 1. **The grad norm is very high and training collapses. What should I do?**
51 | 
52 |    First make sure the data and the model match: for example, if the data already has the chat template applied, check whether that chat template matches the original model. If the data is correct, see the [debug guide](../developer_guide/debug.md) for deeper analysis.
53 | 
54 | 1. **sglang generation takes extremely long, GPU power is maxed out, and nothing is produced for a long time. Why?**
55 | 
56 |    Check whether the model pointed to by `--hf-checkpoint` has the stop token set correctly; if not, you can set it with `--rollout-stop` or `--rollout-stop-token-ids`.
57 | 
58 | 1. **sglang reports "an illegal memory access was encountered"**
59 | 
60 |    According to the sglang documentation (https://docs.sglang.ai/references/troubleshooting.html), this may be an OOM; consider lowering `--sglang-mem-fraction-static`.
61 | 
62 | 1. **A torch compile/inductor `JSONDecodeError` appears**
63 | 
64 |    This is usually a problem with torch compile reading/writing its cache. Consider adding `"TORCHINDUCTOR_FORCE_DISABLE_CACHES": "1"` to ray's env_var.
65 | 
66 | 1. **Training produces NaN or Inf gradients**
67 | 
68 |    You can set `--no-check-for-nan-in-loss-and-grad` to try to skip the affected training steps.
69 |
--------------------------------------------------------------------------------
/slime/utils/misc.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import subprocess
3 |
4 | import ray
5 |
6 | from slime.utils.http_utils import is_port_available
7 |
8 |
9 | def load_function(path):
10 | """
11 | Load a function from a module.
12 | :param path: The path to the function, e.g. "module.submodule.function".
13 | :return: The function object.
14 | """
15 | module_path, _, attr = path.rpartition(".")
16 | module = importlib.import_module(module_path)
17 | return getattr(module, attr)
18 |
19 |
20 | class SingletonMeta(type):
21 | """
22 | A metaclass for creating singleton classes.
23 | """
24 |
25 | _instances = {}
26 |
27 | def __call__(cls, *args, **kwargs):
28 | if cls not in cls._instances:
29 | instance = super().__call__(*args, **kwargs)
30 | cls._instances[cls] = instance
31 | return cls._instances[cls]
32 |
33 |
34 | def exec_command(cmd: str, capture_output: bool = False) -> str | None:
35 | print(f"EXEC: {cmd}", flush=True)
36 |
37 | try:
38 | result = subprocess.run(
39 | ["bash", "-c", cmd],
40 | shell=False,
41 | check=True,
42 | capture_output=capture_output,
43 | **(dict(text=True) if capture_output else {}),
44 | )
45 | except subprocess.CalledProcessError as e:
46 | if capture_output:
47 | print(f"{e.stdout=} {e.stderr=}")
48 | raise
49 |
50 | if capture_output:
51 | print(f"Captured stdout={result.stdout} stderr={result.stderr}")
52 | return result.stdout
53 |
54 |
55 | def get_current_node_ip():
56 | address = ray._private.services.get_node_ip_address()
57 | # strip ipv6 address
58 | address = address.strip("[]")
59 | return address
60 |
61 |
62 | def get_free_port(start_port=10000, consecutive=1):
63 | # find the port where port, port + 1, port + 2, ... port + consecutive - 1 are all available
64 | port = start_port
65 | while not all(is_port_available(port + i) for i in range(consecutive)):
66 | port += 1
67 | return port
68 |
69 |
70 | def should_run_periodic_action(
71 | rollout_id: int,
72 | interval: int | None,
73 | num_rollout_per_epoch: int | None = None,
74 | num_rollout: int | None = None,
75 | ) -> bool:
76 | """
77 | Return True when a periodic action (eval/save/checkpoint) should run.
78 |
79 | Args:
80 | rollout_id: The current rollout index (0-based).
81 |         interval: Desired cadence; disables checks when None.
82 |         num_rollout_per_epoch / num_rollout: Optional epoch boundary and total number of rollouts; the final rollout (id num_rollout - 1) always triggers.
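
    Example (hypothetical numbers): interval=5 triggers after rollouts 4, 9, 14, ...;
    with num_rollout=12 the final rollout (id 11) also triggers.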
83 | """
84 | if interval is None:
85 | return False
86 |
87 | if num_rollout is not None and rollout_id == num_rollout - 1:
88 | return True
89 |
90 | step = rollout_id + 1
91 | return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0)
92 |
--------------------------------------------------------------------------------
/docs/zh/examples/qwen3-4b-base-openhermes.md:
--------------------------------------------------------------------------------
1 | # SFT Qwen3-4B-Base
2 |
3 | ## Environment Setup
4 | 
5 | First, following [Example: Qwen3-4B Model](qwen3-4B.md), set up the image environment and convert the `Qwen3-4B-Base` model.
6 | 
7 | Next, prepare the SFT data. Here we take the classic [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5) as an example. First convert the data into a format slime can load; the script below adds a column in the OpenAI messages format and saves the result to `/root/openhermes2_5.parquet`.
8 |
9 | ```python
10 | from datasets import load_dataset
11 |
12 | ds = load_dataset("teknium/OpenHermes-2.5")["train"]
13 |
14 | def convert(sample):
15 | conversations = sample["conversations"]
16 |
17 | def convert_role(role):
18 | if role == "human":
19 | return "user"
20 | elif role == "gpt":
21 | return "assistant"
22 | elif role == "system":
23 | return "system"
24 | else:
25 | raise ValueError(f"Unknown role: {role}")
26 |
27 | messages = [
28 | {
29 | "role": convert_role(turn["from"]),
30 | "content": turn["value"],
31 | }
32 | for turn in conversations
33 | ]
34 |
35 | return {"messages": messages}
36 |
37 | ds = ds.map(convert)
38 | ds.to_parquet("/root/openhermes2_5.parquet")
39 | ```
40 |
41 | ## Run Training
42 | 
43 | Run the training:
44 |
45 | ```bash
46 | cd /root/slime
47 | bash scripts/run-qwen3-4B-base-sft.sh
48 | ```
49 |
50 | ### Overview of the Arguments
51 | 
52 | Compare [run-qwen3-4B-base-sft.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-4B-base-sft.sh) with [run-qwen3-4B.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-4B.sh). Besides switching from the instruct model to the base model, the main adjustments are the following:
53 | 
54 | 1. `SGLANG_ARGS` and `GRPO_ARGS` are removed, because SFT does not need to launch sglang or do any GRPO-related configuration;
55 | 
56 | 2. `ROLLOUT_ARGS` is renamed to `SFT_ARGS` and configured as:
57 |
58 | ```bash
59 | SFT_ARGS=(
60 | --rollout-function-path slime.rollout.sft_rollout.generate_rollout
61 | --prompt-data /root/openhermes2_5.parquet
62 | --input-key messages
63 | --rollout-shuffle
64 | --num-epoch 3
65 | --rollout-batch-size 128
66 | --global-batch-size 128
67 |
68 | --loss-type sft_loss
69 | --calculate-per-token-loss
70 | --disable-compute-advantages-and-returns
71 | --debug-train-only
72 | )
73 | ```
74 |
75 | SFT in slime actually reuses slime's custom rollout mechanism: `--rollout-function-path` switches the data-generation part from the sglang-based RL rollout to an SFT version that reads data from a file, namely `slime.rollout.sft_rollout.generate_rollout`.
76 | 
77 | For SFT it is recommended to set `rollout_batch_size` equal to `global_batch_size` and not to configure `n_samples_per_prompt`; this amounts to training on each batch right after it is read.
78 | 
79 | slime also supports different loss types; the SFT loss is enabled via `--loss-type sft_loss`.
80 | 
81 | As for `--calculate-per-token-loss`: slime defaults to GRPO's per-sample mean, whereas SFT is usually averaged over all unmasked tokens in a batch, so this flag is recommended.
82 | 
83 | Finally, `--disable-compute-advantages-and-returns` means log probs do not need to be precomputed during SFT, and `--debug-train-only` means sglang does not need to be initialized.
84 | 
85 | 3. `train_async.py` is used instead of `train.py`, in order to leverage the asynchronous training flow for data prefetching.
86 |
--------------------------------------------------------------------------------
/slime/backends/megatron_utils/megatron_to_hf/llama.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 |
4 |
5 | def convert_llama_to_hf(args, name, param):
6 | if name == "module.module.embedding.word_embeddings.weight":
7 | return [("model.embed_tokens.weight", param)]
8 | if name == "module.module.output_layer.weight":
9 | return [("lm_head.weight", param)]
10 | if name == "module.module.decoder.final_layernorm.weight":
11 | return [("model.norm.weight", param)]
12 |
13 | try:
14 | head_dim = args.kv_channels if args.kv_channels is not None else args.hidden_size // args.num_attention_heads
15 | except AttributeError:
16 | head_dim = args.hidden_size // args.num_attention_heads
17 | value_num_per_group = args.num_attention_heads // args.num_query_groups
18 |
19 | decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)"
20 | match = re.match(decoder_layers_pattern, name)
21 | if match:
22 | layer_idx, rest = match.groups()
23 | if rest == "self_attention.linear_proj.weight":
24 | return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)]
25 | elif rest == "self_attention.linear_qkv.weight":
26 | # Split QKV weight for Llama
27 | param = param.view(args.num_query_groups, -1, head_dim, args.hidden_size)
28 | q_param, k_param, v_param = torch.split(param, split_size_or_sections=[value_num_per_group, 1, 1], dim=1)
29 | q_param = q_param.reshape(-1, args.hidden_size)
30 | k_param = k_param.reshape(-1, args.hidden_size)
31 | v_param = v_param.reshape(-1, args.hidden_size)
32 | return [
33 | (f"model.layers.{layer_idx}.self_attn.q_proj.weight", q_param),
34 | (f"model.layers.{layer_idx}.self_attn.k_proj.weight", k_param),
35 | (f"model.layers.{layer_idx}.self_attn.v_proj.weight", v_param),
36 | ]
37 | elif rest == "mlp.linear_fc1.weight":
38 | # Split gate and up projections for SwiGLU
39 | gate_weight, up_weight = param.chunk(2, dim=0)
40 | return [
41 | (f"model.layers.{layer_idx}.mlp.gate_proj.weight", gate_weight),
42 | (f"model.layers.{layer_idx}.mlp.up_proj.weight", up_weight),
43 | ]
44 | elif rest == "mlp.linear_fc2.weight":
45 | return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)]
46 | elif rest == "self_attention.linear_qkv.layer_norm_weight":
47 | return [(f"model.layers.{layer_idx}.input_layernorm.weight", param)]
48 | elif rest == "mlp.linear_fc1.layer_norm_weight":
49 | return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)]
50 | elif rest == "pre_mlp_layernorm.weight":
51 | return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)]
52 |
53 | raise ValueError(f"Unknown parameter name: {name}")
54 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "packaging",
4 | "setuptools >= 49.4.0",
5 | "wheel",
6 | ]
7 | build-backend = "setuptools.build_meta"
8 |
9 | [tool.isort]
10 | profile = "black" # black-compatible
11 | line_length = 119 # should match black parameters
12 | ignore_whitespace = true # ignore whitespace for compatibility with the initial style
13 | py_version = 310 # python 3.10 as a target version
14 | sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
15 | default_section = "THIRDPARTY"
16 | extend_skip = ["setup.py", "docs/source/conf.py"]
17 | known_first_party = ["slime", "slime_plugins"]
18 | known_third_party = ["megatron", "wandb", "ray", "transformers"]
19 | src_paths = ["slime", "slime_plugins"]
20 |
21 |
22 | [tool.black]
23 | line_length = 119
24 |
25 | [tool.ruff]
26 | line-length = 320 # TODO
27 | select = [
28 | "E", # Pycodestyle Errors (Structural/Fundamental Errors like bad indentation)
29 | "F", # Pyflakes (Core Errors: Unused imports, undefined names)
30 | "B", # Flake8-Bugbear (Logic Bugs: Variable shadowing, dangerous default arguments)
31 | "UP", # pyupgrade (Modernization and compatibility issues)
32 | ]
33 | ignore = [
34 | "E402", # module-import-not-at-top-of-file
35 | "E501", # Line too long # TODO handle it later
36 | ]
37 |
38 | [tool.pytest.ini_options]
39 | # durations=0 will display the execution time of all tests, sorted starting from the slowest one.
40 | # -vv will also display tests with duration = 0.00s
41 | addopts = "--verbose --pyargs --durations=0 --strict-markers" # always add these arguments to pytest
42 | testpaths = ["./tests"] # must be an explicit path to avoid importing another "tests" module
43 | # directories to ignore when discovering tests
44 | norecursedirs = [
45 | "external",
46 | "examples",
47 | "docs",
48 | "scripts",
49 | "tools",
50 | "tutorials",
51 | "*.egg",
52 | ".*",
53 | "_darcs",
54 | "build",
55 | "CVS",
56 | "dist",
57 | "venv",
58 | "{arch}",
59 | ]
60 | # markers to select tests, use `pytest --markers` to see all available markers, `pytest -m ""` to select tests
61 | markers = [
62 | "unit: marks unit test, i.e. testing a single, well isolated functionality (deselect with '-m \"not unit\"')",
63 | "integration: marks test checking the elements when integrated into subsystems (deselect with '-m \"not integration\"')",
64 | "system: marks test working at the highest integration level (deselect with '-m \"not system\"')",
65 | "acceptance: marks test checking whether the developed product/model passes the user defined acceptance criteria (deselect with '-m \"not acceptance\"')",
66 | "docs: mark tests related to documentation (deselect with '-m \"not docs\"')",
67 |     "skipduringci: marks tests that are skipped in CI as they are addressed by Jenkins jobs but should be run to test user setups",
68 | "pleasefixme: marks tests that are broken and need fixing",
69 | ]
70 |
--------------------------------------------------------------------------------
/slime/backends/megatron_utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import re
4 | from pathlib import Path
5 |
6 | # TODO: may need to copy those 2 functions and do refactoring.
7 | from megatron.training.checkpointing import load_checkpoint as _load_checkpoint_megatron
8 | from megatron.training.checkpointing import save_checkpoint
9 | from megatron.training.global_vars import get_args
10 |
11 | from slime.utils import megatron_bridge_utils
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | __all__ = ["save_checkpoint"]
16 |
17 |
18 | def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_context, skip_load_to_model_and_opt):
19 | # ref: how megatron `load_checkpoint` gets directory
20 | args = get_args()
21 | load_path = args.load
22 |
23 | assert Path(load_path).exists() and _is_dir_nonempty(
24 | load_path
25 | ), f"{args.load=} does not exist or is an empty directory. Did you specify the wrong folder?"
26 |
27 | if _is_megatron_checkpoint(load_path):
28 | return _load_checkpoint_megatron(
29 | ddp_model=ddp_model,
30 | optimizer=optimizer,
31 | opt_param_scheduler=opt_param_scheduler,
32 | checkpointing_context=checkpointing_context,
33 | skip_load_to_model_and_opt=skip_load_to_model_and_opt,
34 | )
35 | else:
36 | return _load_checkpoint_hf(
37 | ddp_model=ddp_model,
38 | optimizer=optimizer,
39 | args=args,
40 | load_path=load_path,
41 | )
42 |
43 |
44 | def _is_megatron_checkpoint(path: str | Path) -> bool:
45 | return (Path(path) / "latest_checkpointed_iteration.txt").is_file() or bool(
46 | re.fullmatch(r"iter_\d{7}", Path(path).name)
47 | )
48 |
49 |
50 | def _load_checkpoint_hf(ddp_model, optimizer, args, load_path: str):
51 | assert args.megatron_to_hf_mode == "bridge", "Only bridge mode is supported for loading HF checkpoint"
52 | from megatron.bridge import AutoBridge
53 |
54 | import slime_plugins.megatron_bridge # noqa: F401
55 |
56 | logger.info(f"Load checkpoint from HuggingFace model into Megatron (path={load_path})")
57 |
58 | with megatron_bridge_utils.patch_megatron_model(ddp_model):
59 | bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
60 | bridge.load_hf_weights(ddp_model)
61 |
62 | # Copied from Megatron-core :: load_checkpoint (with simplifications)
63 | if (args.fp16 or args.bf16) and optimizer is not None:
64 | assert not args.load_main_params_from_ckpt
65 | optimizer.reload_model_params()
66 |
67 |     # When loading a Megatron checkpoint we see `successfully loaded checkpoint from ... [ t 1/2, p 1/1 ] at iteration 0`,
68 |     # so report iteration 0 here as well
69 | iteration = 0
70 | num_floating_point_operations_so_far = 0
71 | return iteration, num_floating_point_operations_so_far
72 |
73 |
74 | def _is_dir_nonempty(path):
75 | with os.scandir(path) as it:
76 | return any(it)
77 |
--------------------------------------------------------------------------------
/slime/rollout/sft_rollout.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from slime.utils.mask_utils import MultiTurnLossMaskGenerator
4 | from slime.utils.processing_utils import load_processor, load_tokenizer, prepare_model_inputs
5 |
6 | __all__ = ["generate_rollout"]
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | TOKENIZER = None
12 | PROCESSOR = None
13 | MASK_GENERATOR = None
14 | SAMPLE_PRINTED = False
15 |
16 |
17 | def generate_rollout(args, rollout_id, data_buffer, evaluation=False):
18 |     """An example implementation of the generate_rollout function for SFT-style (file-based) rollout generation.
19 |
20 | Args:
21 | args: the whole args
22 | rollout_id: int, the id of the rollout, used for deterministic data generation
23 | data_buffer: the data buffer to store the generated samples
24 | evaluation: bool, whether the rollout is for evaluation or not
25 |
26 | Returns:
27 | list[Sample]: a list of samples generated by the rollout
28 | """
29 | assert not evaluation
30 | assert args.rollout_global_dataset
31 |
32 | global TOKENIZER, PROCESSOR, MASK_GENERATOR, SAMPLE_PRINTED
33 | if TOKENIZER is None:
34 | TOKENIZER = load_tokenizer(args.hf_checkpoint, trust_remote_code=True)
35 |
36 | if PROCESSOR is None:
37 | PROCESSOR = load_processor(args.hf_checkpoint, trust_remote_code=True)
38 |
39 | if MASK_GENERATOR is None:
40 | MASK_GENERATOR = MultiTurnLossMaskGenerator(TOKENIZER, tokenizer_type=args.loss_mask_type)
41 |
42 | samples = data_buffer.get_samples(args.rollout_batch_size)
43 |
44 | for i, sample in enumerate(samples):
45 | (sample,) = sample
46 | messages = sample.prompt
47 | tools = sample.metadata.get("tools", None)
48 |
49 | input_ids, extra_info = prepare_model_inputs(
50 | messages, TOKENIZER, PROCESSOR, sample.metadata,
51 | args.apply_chat_template, args.apply_chat_template_kwargs
52 | )
53 |
54 | has_multimodal = bool(extra_info.get("images") or extra_info.get("videos"))
55 | if has_multimodal:
56 | sample.multimodal_inputs = extra_info["multimodal_inputs"]
57 | token_ids, loss_mask = MASK_GENERATOR.get_loss_mask_with_multimodal_alignment(
58 | messages, input_ids, tools=tools
59 | )
60 | else:
61 | token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools)
62 |
63 | response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0]
64 |
65 | sample.tokens = token_ids
66 | sample.response_length = response_length
67 | sample.reward = 0
68 | sample.loss_mask = loss_mask[-response_length:]
69 |
70 | if i == 0 and not SAMPLE_PRINTED:
71 | logger.info(
72 | f"sft_rollout::generate_rollout example data: {sample=} (raw){messages=} (raw){token_ids=} (raw){loss_mask=} {response_length=}"
73 | )
74 | SAMPLE_PRINTED = True
75 |
76 | return samples
77 |
--------------------------------------------------------------------------------
/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | from slime.utils import megatron_bridge_utils
4 | from slime.utils.iter_utils import chunk_named_params_by_size
5 |
6 | from ..megatron_to_hf import postprocess_hf_param
7 | from ..misc_utils import strip_param_name_prefix
8 | from .hf_weight_iterator_base import HfWeightIteratorBase
9 |
10 |
11 | class HfWeightIteratorBridge(HfWeightIteratorBase):
12 | def __init__(self, *args, **kwargs):
13 | super().__init__(*args, **kwargs)
14 |
15 | from megatron.bridge import AutoBridge
16 | import slime_plugins.megatron_bridge # noqa: F401
17 |
18 | self._bridge = AutoBridge.from_hf_pretrained(self.args.hf_checkpoint)
19 |
20 | def get_hf_weight_chunks(self, megatron_local_weights):
21 | # TODO support quantization (e.g. modify megatron-bridge to provide megatron param name)
22 | renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()}
23 | with megatron_bridge_utils.patch_megatron_model(self.model):
24 | conversion_tasks = self._bridge.get_conversion_tasks(self.model)
25 | conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights)
26 |
27 | named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks)
28 |
29 | named_weights = (
30 | (
31 | hf_param_name,
32 | postprocess_hf_param(
33 | args=self.args,
34 | megatron_param_name=megatron_param_name,
35 | hf_param_name=hf_param_name,
36 | param=weight,
37 | ),
38 | )
39 | for hf_param_name, weight, megatron_param_name in named_weights
40 | )
41 |
42 | yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size)
43 |
44 |
45 | def _process_conversion_tasks(vanilla_conversion_tasks, new_weight_dict):
46 | def _handle_one(task):
47 | if task.param_weight is None:
48 | return task
49 |
50 | weight_dict_key = f"vp_stages.{task.vp_stage}.{task.param_name}"
51 | assert (
52 | weight_dict_key in new_weight_dict
53 | ), f"{weight_dict_key=} not in new_weight_dict ({task.vp_stage=}, {task.param_name=}, {list(new_weight_dict)=})"
54 |
55 | new_param_weight = new_weight_dict[weight_dict_key]
56 | new_param_weight = new_param_weight.cuda()
57 | return dataclasses.replace(task, param_weight=new_param_weight)
58 |
59 | return _MapWithLen(_handle_one, vanilla_conversion_tasks)
60 |
61 |
62 | class _MapWithLen:
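    """Lazily apply `fn` to each element of `xs` while preserving `len()`, so consumers can still query the number of conversion tasks."""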
63 | def __init__(self, fn, xs):
64 | self.fn = fn
65 | self.xs = xs
66 |
67 | def __len__(self):
68 | return len(self.xs)
69 |
70 | def __iter__(self):
71 | for x in self.xs:
72 | yield self.fn(x)
73 |
--------------------------------------------------------------------------------
/examples/eval/nemo_skills/skills_client.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | from typing import Any
4 |
5 | import requests
6 | from examples.eval.eval_delegate import EvalClient, EvalDelegateError
7 | from examples.eval.nemo_skills.skills_config import SkillsEvalEnvConfig
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class SkillsEvalClient(EvalClient):
13 | """HTTP client that proxies evaluation requests to the NeMo Skills server."""
14 |
15 | def __init__(self, config: SkillsEvalEnvConfig, router_url: str):
16 | super().__init__(config.name or "skills")
17 | self._config = config
18 | self._router_url = router_url.rstrip("/")
19 | self._endpoint = (config.url or "").rstrip("/")
20 | self._timeout_secs = float(config.timeout_secs)
21 | self._max_retries = max(1, int(config.max_retries))
22 | self._headers = dict(config.headers or {})
23 | self._session = requests.Session()
24 |
25 | @classmethod
26 | def from_config(cls, config: SkillsEvalEnvConfig, router_url: str):
27 | if not config.url:
28 | return None
29 | return cls(config, router_url)
30 |
31 | def evaluate(self, args, rollout_id: int) -> tuple[dict[str, Any], dict[str, Any]]:
32 | if not self._config.datasets:
33 | logger.warning("No Skills datasets configured; skipping delegate evaluation.")
34 | return {}, {}
35 |
36 | payload = self._build_payload(args, rollout_id)
37 | response = self._request(payload)
38 | metrics = response["raw_metrics"]
39 | return metrics, response
40 |
41 | def _build_payload(self, args, rollout_id: int) -> dict[str, Any]:
42 | benchmarks = [cfg.to_payload() for cfg in self._config.datasets]
43 | benchmarks = [cfg for cfg in benchmarks if cfg]
44 | return {
45 | "rollout_id": rollout_id,
46 | "router_url": self._router_url,
47 | "benchmarks": benchmarks,
48 | }
49 |
50 | def _request(self, payload: dict[str, Any]) -> dict[str, Any]:
51 | last_error: Exception | None = None
52 | for attempt in range(1, self._max_retries + 1):
53 | try:
54 | response = self._session.post(
55 | self._endpoint,
56 | json=payload,
57 | timeout=self._timeout_secs,
58 | headers=self._headers,
59 | )
60 | response.raise_for_status()
61 | if not response.content:
62 | return {}
63 | return response.json()
64 | except requests.RequestException as exc:
65 | last_error = exc
66 | logger.warning(
67 | "Skills eval delegate request failed (attempt %s/%s): %s", attempt, self._max_retries, exc
68 | )
69 | if attempt < self._max_retries:
70 | time.sleep(min(2**attempt, 30))
71 | raise EvalDelegateError("Skills evaluation request failed") from last_error
72 |
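`_request` retries the POST with capped exponential backoff (2, 4, 8, ... seconds, capped at 30) before surfacing `EvalDelegateError`. A standalone sketch of the same retry pattern, with a hypothetical `do_post` callable standing in for the `session.post` call:

```python
import time


def post_with_backoff(do_post, max_retries: int = 3):
    """Call do_post() with capped exponential backoff between attempts."""
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            return do_post()
        except Exception as exc:  # the client above narrows this to requests.RequestException
            last_error = exc
            if attempt < max_retries:
                time.sleep(min(2**attempt, 30))  # 2s, 4s, 8s, ... capped at 30s
    raise RuntimeError("request failed after retries") from last_error
```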
--------------------------------------------------------------------------------
/slime/backends/megatron_utils/megatron_to_hf/mimo.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | from .qwen2 import convert_qwen2_to_hf
4 |
5 |
6 | def convert_mimo_to_hf(args, name, param):
7 | """
8 | Convert MiMo model parameters from Megatron to HuggingFace format.
9 |
10 | MiMo extends Qwen2 with MTP (Multi-Token Prediction) layers.
11 | """
12 |
13 | if "mtp" in name:
14 | return convert_mimo_mtp_param(args, name, param)
15 |
16 | return convert_qwen2_to_hf(args, name, param)
17 |
18 |
19 | def convert_mimo_mtp_param(args, name, param):
20 | """
21 | Convert MTP layer parameters from Megatron to HuggingFace format.
22 |
23 | MTP layers in MiMo contain:
24 | - LayerNorms (token_layernorm, hidden_layernorm, final_layernorm)
25 | - Input projection (input_proj)
26 | - Self attention (reuses Qwen2 attention structure)
27 | - MLP (reuses Qwen2 MLP structure)
28 |
29 | Based on MimoBridge._convert_mtp_param logic (reverse mapping)
30 | """
31 | mtp_pattern = r"module\.module\.mtp\.layers\.(\d+)\.(.+)"
32 | match = re.match(mtp_pattern, name)
33 |
34 | if not match:
35 | raise ValueError(f"Invalid MTP parameter name: {name}")
36 |
37 | layer_idx, component = match.groups()
38 |
39 | # Direct mappings for MTP-specific components (Megatron -> HF)
40 | # Based on MimoBridge direct_name_mapping (reversed)
41 | direct_mappings = {
42 | "enorm.weight": f"model.mtp_layers.{layer_idx}.token_layernorm.weight",
43 | "hnorm.weight": f"model.mtp_layers.{layer_idx}.hidden_layernorm.weight",
44 | "eh_proj.weight": f"model.mtp_layers.{layer_idx}.input_proj.weight",
45 | "final_layernorm.weight": f"model.mtp_layers.{layer_idx}.final_layernorm.weight",
46 | }
47 | if component == "eh_proj.weight":
48 | first_half, second_half = param.chunk(2, dim=1)
49 | param = torch.cat([second_half, first_half], dim=1)
50 |
51 | # Check direct mappings first
52 | if component in direct_mappings:
53 | return [(direct_mappings[component], param)]
54 |
55 | # Handle transformer_layer components
56 | if component.startswith("transformer_layer."):
57 | # Remove "transformer_layer." prefix
58 | transformer_component = component[len("transformer_layer.") :]
59 |
60 | # Create proxy name for reusing existing Qwen2 conversion functions
61 | proxy_name = f"module.module.decoder.layers.{layer_idx}.{transformer_component}"
62 |
63 | # Use existing convert_qwen2_to_hf function for transformer components
64 | results = convert_qwen2_to_hf(args, proxy_name, param)
65 |
66 | # Replace model.layers with mtp_layers in results
67 | converted_results = []
68 | for hf_name, hf_param in results:
69 | # Replace model.layers.{idx} with mtp_layers.{idx}
70 | hf_name = hf_name.replace(f"model.layers.{layer_idx}", f"model.mtp_layers.{layer_idx}")
71 | converted_results.append((hf_name, hf_param))
72 |
73 | return converted_results
74 |
75 | raise ValueError(f"Unknown MTP component: {component} in {name}")
76 |
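For `eh_proj.weight`, the conversion above swaps the two column halves of the weight before emitting it in the HF layout. A toy illustration of that swap on a hypothetical 2x4 tensor:

```python
import torch

w = torch.arange(8).reshape(2, 4)             # per row: columns [0 1 | 2 3]
first_half, second_half = w.chunk(2, dim=1)
swapped = torch.cat([second_half, first_half], dim=1)
print(swapped)                                # per row: columns reordered to [2 3 | 0 1]
```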
--------------------------------------------------------------------------------
/slime/rollout/rm_hub/__init__.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import random
3 |
4 | import aiohttp
5 |
6 | from slime.utils.misc import load_function
7 | from slime.utils.types import Sample
8 |
9 | from .deepscaler import get_deepscaler_rule_based_reward
10 | from .f1 import f1_score
11 | from .gpqa import compute_gpqa_reward
12 | from .math_dapo_utils import compute_score as compute_score_dapo
13 | from .math_utils import extract_answer as extract_boxed_answer
14 | from .math_utils import grade_answer_verl
15 |
16 |
17 | async def remote_rm(args, sample: Sample):
18 | payload = {
19 | "prompt": sample.prompt,
20 | "response": sample.response,
21 | "label": sample.label,
22 | }
23 | session_kwargs = {}
24 | async with aiohttp.ClientSession(**session_kwargs) as session:
25 | async with session.post(args.rm_url, json=payload) as resp:
26 | resp.raise_for_status()
27 | return await resp.json()
28 |
29 |
30 | async def async_rm(args, sample: Sample, **kwargs):
31 | if args.custom_rm_path is not None:
32 | rm_function = load_function(args.custom_rm_path)
33 | return await rm_function(args, sample, **kwargs)
34 |
35 | metadata = sample.metadata if isinstance(sample.metadata, dict) else {}
36 | rm_type = (metadata.get("rm_type") or args.rm_type or "").strip()
37 | response = sample.response
38 | label = sample.label
39 | if rm_type.startswith("boxed_"):
40 | response = extract_boxed_answer(response) or ""
41 | rm_type = rm_type[len("boxed_") :]
42 |
43 | # This function is intended for remote or time-consuming reward model evaluation.
44 | # Implement the actual logic as needed.
45 | if rm_type == "remote_rm":
46 | return await remote_rm(args, sample)
47 | elif rm_type == "deepscaler":
48 | return get_deepscaler_rule_based_reward(response, label)
49 | elif rm_type == "dapo":
50 | return compute_score_dapo(response, label)
51 | elif rm_type == "math":
52 | return 1 if grade_answer_verl(response, label) else 0
53 | elif rm_type == "f1":
54 | return f1_score(response, label)[0]
55 | elif rm_type == "gpqa":
56 | return compute_gpqa_reward(response, label, metadata=metadata)
57 | elif rm_type == "ifbench":
58 | from .ifbench import compute_ifbench_reward
59 |
60 | return compute_ifbench_reward(response, label, metadata=metadata)
61 | elif rm_type == "random":
62 | return random.randint(0, 1)
63 | elif rm_type:
64 | raise NotImplementedError(f"Rule-based RM for {rm_type} is not implemented.")
65 | else:
66 | raise NotImplementedError("Rule-based RM type is not specified.")
67 |
68 |
69 | async def batched_async_rm(
70 | args,
71 | samples: list[Sample],
72 | **kwargs,
73 | ) -> list[int | float]:
74 | if args.custom_rm_path is not None:
75 | # Ensure the custom reward function is implemented in batch mode
76 | rm_function = load_function(args.custom_rm_path)
77 | return await rm_function(args, samples, **kwargs)
78 | tasks = [async_rm(args, sample, **kwargs) for sample in samples]
79 | rewards = await asyncio.gather(*tasks)
80 | return rewards
81 |
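When `--custom-rm-path` is set, `async_rm` (and `batched_async_rm`) load the referenced function via `load_function` and await it. A minimal sketch of such a per-sample function (hypothetical module name and reward logic):

```python
# Hypothetical module that --custom-rm-path could point to, e.g. my_rewards.exact_match_rm.
async def exact_match_rm(args, sample, **kwargs):
    # sample is a slime.utils.types.Sample; return a numeric reward.
    return 1 if sample.response.strip() == str(sample.label).strip() else 0
```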
--------------------------------------------------------------------------------
/.github/workflows/conda-ci.yml:
--------------------------------------------------------------------------------
1 | name: conda CI
2 |
3 | on:
4 | pull_request:
5 | branches: [main]
6 |
7 | concurrency:
8 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
9 | cancel-in-progress: true
10 |
11 | jobs:
12 | build-conda:
13 | if: contains(github.event.pull_request.title, '[release]')
14 | runs-on: self-hosted
15 | container:
16 | image: lmsysorg/sglang:v0.5.0rc0-cu126
17 | options: --gpus all --ipc=host --shm-size=16g --ulimit memlock=-1 --ulimit stack=67108864 --memory=0 --memory-swap=0 -v /mnt/nvme0n1/models:/root/models -v /mnt/nvme0n1/datasets:/root/datasets
18 |
19 | defaults:
20 | run:
21 | working-directory: ${{ github.workspace }}
22 |
23 | steps:
24 | - name: Checkout repository
25 | uses: actions/checkout@v4
26 |
27 | - name: Construct Conda
28 | run: |
29 | echo "📦 Installing slime..."
30 | cd $GITHUB_WORKSPACE
31 | echo "Current directory: $(pwd)"
32 |
33 | mkdir -p /root/
34 | BASE_DIR=/root bash build_conda.sh
35 | shell: bash
36 |
37 | - name: Download model and dataset
38 | run: |
39 |           echo "🔗 Downloading model and dataset..."
40 |
41 | # Create cache directories if they don't exist
42 | mkdir -p /root/models /root/datasets
43 |
44 | echo "Downloading Qwen3-30B-A3B..."
45 | hf download Qwen/Qwen3-30B-A3B --local-dir /root/models/Qwen3-30B-A3B
46 | hf download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/models/Qwen3-30B-A3B-FP8
47 |
48 | hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir /root/datasets/dapo-math-17k
49 |
50 | hf download --repo-type dataset zhuzilin/aime-2024 --local-dir /root/datasets/aime-2024
51 | shell: bash
52 |
53 | - name: Convert checkpoint
54 | run: |
55 | echo "🔄 Converting model checkpoint..."
56 | cd $GITHUB_WORKSPACE
57 | echo "Current directory: $(pwd)"
58 |
59 | source ~/.bashrc
60 | micromamba activate slime
61 | export CUDA_HOME="$CONDA_PREFIX"
62 |
63 | source scripts/models/qwen3-30B-A3B.sh
64 | PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node 8 tools/convert_hf_to_torch_dist.py \
65 | ${MODEL_ARGS[@]} \
66 | --hf-checkpoint /root/models/Qwen3-30B-A3B \
67 | --save /root/Qwen3-30B-A3B_torch_dist
68 | shell: bash
69 |
70 | - name: Run tests
71 | run: |
72 | echo "🧪 Running tests..."
73 | cd $GITHUB_WORKSPACE
74 | echo "Current directory: $(pwd)"
75 |
76 | source ~/.bashrc
77 | micromamba activate slime
78 | export CUDA_HOME="$CONDA_PREFIX"
79 |
80 | SLIME_TEST_USE_DEEPEP=0 SLIME_TEST_USE_FP8_ROLLOUT=0 python tests/test_qwen3_30B_A3B.py
81 | shell: bash
82 |
83 | - name: Cleanup
84 | if: always()
85 | run: |
86 | echo "🧹 Cleaning up..."
87 | pkill -9 ray || true
88 | ray stop --force || true
89 | pkill -9 python || true
90 | shell: bash
91 |
--------------------------------------------------------------------------------
/examples/multi_agent/prompts.py:
--------------------------------------------------------------------------------
1 | ## Prompt templates used to generate the different prompts
2 |
3 |
4 | SOLVER_PROMPT_TEMPLATE = """{problem_statement}"""
5 |
6 |
7 | def generate_rewriter_template(num_solutions: int) -> str:
8 | """Dynamically generate rewrite templates based on the number of solutions."""
9 | solution_sections = []
10 | for i in range(num_solutions):
11 | solution_sections.append(f"#### Solution {i+1}\n{{solution{i+1}}}\n\n---")
12 |
13 | solutions_text = "\n".join(solution_sections)
14 |
15 | return f"""### Task: Solution Rewriting Based on Previous Solutions ###
16 | You are being reactivated to revise your mathematical proof. You are provided with two documents:
17 | 1. The problem you need to solve.
18 | 2. Your {num_solutions} different "Previous Solutions".
19 |
20 | Your sole task is to generate a new, correct version of your solution based on your previous discoveries in the provided {num_solutions} solutions.
21 |
22 | Refer to the following {num_solutions} solutions and solve the problem.
23 | ---
24 |
25 | ### Problem
26 |
27 | {{problem_statement}}
28 |
29 | ---
30 |
31 | ### Candidate Solutions
32 | {solutions_text}
33 | """
34 |
35 |
36 | def generate_select_template(num_solutions: int) -> str:
37 | """Dynamically generate select templates based on the number of solutions."""
38 | solution_sections = []
39 | for i in range(num_solutions):
40 | solution_sections.append(f"#### Solution {i+1}\n{{solution{i+1}}}\n\n---")
41 |
42 | solutions_text = "\n".join(solution_sections)
43 |
44 | return f"""You will be given a challenging math problem followed by {num_solutions} solutions.
45 | Your task is to systematically analyze these solutions to identify the most mathematically sound approach.
46 |
47 | You are provided with two documents:
48 | 1. The problem you need to solve.
49 | 2. Your {num_solutions} "Candidate Solutions".
50 |
51 | Evaluation Process:
52 | 1. Initial Screening
53 | - Group solutions by their final answers
54 | - Identify and explain mathematical contradictions between different answers
55 | - Eliminate solutions with clear mathematical errors
56 |
57 | 2. Detailed Analysis
58 | For remaining solutions, evaluate:
59 | - Mathematical precision and accuracy
60 | - Logical progression of steps
61 | - Completeness of mathematical reasoning
62 | - Handling of edge cases or special conditions
63 | - For solutions containing and addressing errors, evaluate the error identification and correction methodology.
64 |
65 | 3. Solution Comparison
66 | Compare viable solutions based on:
67 | - Efficiency of approach
68 | - Clarity of mathematical reasoning
69 | - Sophistication of method
70 | - Robustness of solution (works for all cases)
71 |
72 | Your response should include:
73 | 1. Brief analysis of conflicting answers
74 | 2. Detailed evaluation of mathematically sound solutions
75 | 3. Justification for eliminating incorrect solutions
76 | 4. Clear explanation for selecting the best approach
77 |
78 | End your evaluation with exactly:
79 | Judgment: IDX
80 | where IDX is the index 1-{num_solutions} of the best solution
81 |
82 | ### Problem
83 |
84 | {{problem_statement}}
85 |
86 | ---
87 |
88 | ### Candidate Solutions
89 | {solutions_text}
90 | """
91 |
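A minimal usage sketch (toy problem and solutions): the generated template leaves `{problem_statement}` and `{solutionN}` as ordinary `str.format` fields, so filling it in looks like this:

```python
template = generate_select_template(2)
prompt = template.format(
    problem_statement="Compute 1 + 1.",
    solution1="By counting, the answer is 2.",
    solution2="I claim the answer is 3.",
)
print(prompt)
```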
--------------------------------------------------------------------------------
/docs/zh/examples/qwen3-30B-A3B.md:
--------------------------------------------------------------------------------
1 | # Training Qwen3-30B-A3B on 8xH100
2 | 
3 | ## Environment Setup
4 | 
5 | Environment setup, model and data download, and ckpt conversion are the same as for the Qwen3-4B model; see [Example: Qwen3-4B](qwen3-4B.md) and replace the Qwen3-4B parts with Qwen3-30B-A3B.
6 | 
7 | You can convert the huggingface checkpoint to the torch_dist format as follows:
8 |
9 | ```bash
10 | cd slime/
11 | pip install -e .
12 | source scripts/models/qwen3-30B-A3B.sh
13 | PYTHONPATH=/root/Megatron-LM/ torchrun --nproc-per-node 8 \
14 | tools/convert_hf_to_torch_dist.py \
15 | ${MODEL_ARGS[@]} \
16 | --hf-checkpoint /root/Qwen3-30B-A3B/ \
17 | --save /root/Qwen3-30B-A3B_torch_dist/
18 | ```
19 |
20 | ## Run Training
21 | 
22 | Launch the training:
23 |
24 | ```bash
25 | cd /root/slime
26 | bash scripts/run-qwen3-30B-A3B.sh
27 | ```
28 |
29 | ### Argument Overview
30 | 
31 | Here we briefly introduce the MoE-related parts of the script [run-qwen3-30B-A3B.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-30B-A3B.sh).
32 | 
33 | 1. To run Qwen3-30B-A3B on 8xH800, we need to enable megatron's CPU Adam to save GPU memory. The corresponding configuration is:
34 |
35 | ```bash
36 | OPTIMIZER_ARGS=(
37 | ...
38 | --optimizer-cpu-offload
39 | --overlap-cpu-optimizer-d2h-h2d
40 | --use-precision-aware-optimizer
41 | )
42 | ```
43 |
44 | 2. Enable the MoE optimizations supported by megatron; the current configuration is tp4, ep8:
45 |
46 | ```bash
47 | PERF_ARGS=(
48 | --tensor-model-parallel-size 4
49 | --sequence-parallel
50 | --pipeline-model-parallel-size 1
51 | --context-parallel-size 1
52 | --expert-model-parallel-size 8
53 | --expert-tensor-parallel-size 1
54 | ...
55 | )
56 | ```
57 |
58 | 3. Enable the MoE optimizations supported by sglang; the current configuration is ep8:
59 |
60 | ```bash
61 | SGLANG_ARGS=(
62 | --rollout-num-gpus-per-engine 8
63 | --sglang-mem-fraction-static 0.7
64 | --sglang-enable-ep-moe
65 | --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256)
66 | )
67 | ```
68 |
69 | Similarly, dp attention can also be added, for example with:
70 |
71 | ```bash
72 | --sglang-enable-dp-attention
73 | --sglang-dp-size 8
74 | ```
75 |
76 | ### bf16 Training with fp8 Rollout
77 | 
78 | slime also supports bf16 training with fp8 rollout. For the Qwen3-30B-A3B model, simply download the following model:
79 |
80 | ```bash
81 | hf download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/Qwen3-30B-A3B-FP8
82 | ```
83 |
84 | and replace `--hf-checkpoint` with:
85 |
86 | ```bash
87 | #--hf-checkpoint /root/Qwen3-30B-A3B
88 | --hf-checkpoint /root/Qwen3-30B-A3B-FP8
89 | ```
90 |
91 | This is enough to enable fp8 rollout. Currently the bf16 weights are cast directly to fp8; quantization schemes with less impact on accuracy will be added over time.
92 | 
93 | ⚠️ The megatron checkpoint used for training still needs to be converted from the original bf16 huggingface checkpoint.
94 |
95 | ### Multi-Node Support
96 | 
97 | For multi-node environments, the following changes are needed:
98 | - Place the training model and data on paths that all machines can access;
99 | - Set a `MASTER_ADDR` that every machine can reach;
100 | - Remove the CPU Adam related configuration; because the distributed optimizer is used, the optimizer's share of GPU memory drops significantly in multi-node setups.
101 | 
102 | In addition, the following changes can also be made:
103 | 
104 | - When the total number of experts is not evenly divisible across all GPUs, you can use `--sglang-ep-num-redundant-experts` to add redundant experts; for example, for a 24-GPU setup you can configure (a quick divisibility check follows the block below):
105 |
106 | ```bash
107 | SGLANG_ARGS=(
108 | --rollout-num-gpus-per-engine 24
109 | --sglang-mem-fraction-static 0.7
110 | --sglang-enable-ep-moe
111 | --sglang-enable-dp-attention
112 | --sglang-dp-size 3
113 |
114 | --sglang-moe-dense-tp-size 1
115 | --sglang-enable-dp-lm-head
116 | --sglang-ep-num-redundant-experts 16
117 | )
118 | ```
119 |
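The redundant-expert count is chosen so that the total number of experts divides evenly across the rollout GPUs. A quick sanity check for the 24-GPU example above (assuming Qwen3-30B-A3B's 128 routed experts; adjust if your config differs):

```python
num_experts = 128   # routed experts in Qwen3-30B-A3B (assumption; check the HF config)
num_gpus = 24       # --rollout-num-gpus-per-engine in the example above
redundant = 16      # --sglang-ep-num-redundant-experts
assert (num_experts + redundant) % num_gpus == 0  # 144 experts -> 6 experts per GPU
```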
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 | # slime
2 |
3 | [English](./README.md)
4 |
5 | [](https://thudm.github.io/slime/)
6 | [](https://deepwiki.com/THUDM/slime)
7 |
8 | **slime** is an LLM post-training framework designed for RL scaling, offering two core capabilities:
9 | 
10 | 1. **High-performance training**: efficient training in a variety of modes by connecting Megatron with SGLang;
11 | 2. **Flexible data generation**: arbitrary training data generation workflows via custom data generation interfaces and a server-based engine.
12 | 
13 | slime is the RL training framework behind [GLM-4.5](https://z.ai/blog/glm-4.5) and [GLM-4.6](https://z.ai/blog/glm-4.6). In addition, slime also supports:
14 | - the Qwen3 series (Qwen3Next, Qwen3MoE, Qwen3) and the Qwen2.5 series;
15 | - the DeepSeek V3 series (DeepSeek V3, V3.1, DeepSeek R1);
16 | - Llama 3.
17 |
18 | ## Blog Posts
19 |
20 | - Our vision: [slime: An SGLang-Native Post-Training Framework for RL Scaling](https://thudm.github.io/slime/zh/blogs/introducing_slime.html)
21 | - Some thoughts on fully asynchronous agentic training: [Agent-Oriented Design: An Asynchronous and Decoupled Framework for Agentic RL](https://www.notion.so/Agent-Oriented-Design-An-Asynchronous-and-Decoupled-Framework-for-Agentic-RL-2278e692d081802cbdd5d37cef76a547)
22 | - v0.1.0 release notes: [slime v0.1.0: Redefining High-Performance RL Training Frameworks](https://zhuanlan.zhihu.com/p/1945237948166547268)
23 |
24 |
25 | ## Table of Contents
26 | 
27 | - [Architecture Overview](#architecture-overview)
28 | - [Quick Start](#quick-start)
29 | - [Checkpoint Format Conversion](#checkpoint-format-conversion)
30 | - [Launching the Training Workflow](#launching-the-training-workflow)
31 | - [Arguments](#arguments)
32 | - [Developer Guide](#developer-guide)
33 | - [Q&A and Acknowledgements](#qa-and-acknowledgements)
34 | 
35 | ## Architecture Overview
36 | 
37 | ![Architecture Overview](./imgs/arch.png)
38 | 
39 | **Module descriptions**:
40 | 
41 | - **training (Megatron)**: runs the main training loop, reads data from the Data Buffer, and syncs parameters to the rollout module after training;
42 | - **rollout (SGLang + router)**: generates new data (including reward/verifier) and stores it in the Data Buffer;
43 | - **data buffer**: the bridge module that manages prompt initialization, custom data, and rollout generation methods.
44 |
45 | ## Quick Start
46 | 
47 | For a complete quick start guide covering environment setup, data preparation, launching training, and a walkthrough of the key code, see:
48 | 
49 | - [Quick Start Guide](./docs/zh/get_started/quick_start.md)
50 | 
51 | We also provide some usage examples that are not covered in the quick start; see [examples](examples/).
52 | 
53 | ## Arguments
54 | 
55 | Arguments fall into three categories:
56 | 
57 | 1. **megatron arguments**: slime reads all arguments defined by the megatron found on `PYTHONPATH`; megatron can be configured by passing arguments such as `--tensor-model-parallel-size 2`;
58 | 2. **sglang arguments**: all arguments of the sglang installed in the environment are supported; they must be prefixed with `--sglang`, e.g. `--mem-fraction-static` is passed as `--sglang-mem-fraction-static`;
59 | 3. **slime's own arguments**: see [slime/utils/arguments.py](slime/utils/arguments.py).
60 | 
61 | For the full usage guide, see the [usage documentation](docs/zh/get_started/usage.md).
62 |
63 | ## Developer Guide
64 | 
65 | - **Contributions are welcome!** If you have feature suggestions, performance tuning ideas, or usability feedback, feel free to open an Issue / PR 😊
66 | 
67 | - Use [pre-commit](https://pre-commit.com/) to keep the committed code style consistent:
68 |
69 | ```bash
70 | apt install pre-commit -y
71 | pre-commit install
72 |
73 |   # run pre-commit to enforce the code style
74 | pre-commit run --all-files --show-diff-on-failure --color=always
75 | ```
76 |
77 | - For debugging tips, see the [debug guide](docs/zh/developer_guide/debug.md)
78 | 
79 | ## Q&A and Acknowledgements
80 | 
81 | - For frequently asked questions, see the [Q&A](docs/zh/get_started/qa.md)
82 | - Special thanks to the following projects & communities: SGLang, Megatron‑LM, mbridge, OpenRLHF, veRL, Pai-Megatron-Patch, and more.
83 | 
84 | - To cite slime, please use:
85 | ```bibtex
86 | @misc{slime_github,
87 | author = {Zilin Zhu and Chengxing Xie and Xin Lv and slime Contributors},
88 | title = {slime: An LLM post-training framework for RL Scaling},
89 | year = {2025},
90 | howpublished = {\url{https://github.com/THUDM/slime}},
91 | note = {GitHub repository. Corresponding author: Xin Lv},
92 | urldate = {2025-06-19}
93 | }
94 | ```
95 |
--------------------------------------------------------------------------------
/examples/retool/README.md:
--------------------------------------------------------------------------------
1 | # Retool: from SFT to RL
2 |
3 | This example demonstrates how to use the retool functionality for tool-enabled language model generation.
4 |
5 | ## Overview
6 |
7 | The retool example provides:
8 | - Safe Python code execution in a sandbox environment
9 | - Tool registry for managing available tools
10 | - Integration with language model generation
11 | - Reward calculation for tool usage
12 |
13 | ## Files
14 |
15 | - `generate_with_retool.py`: Main generation function with tool support
16 | - `tool_sandbox.py`: Tool execution and safety management
17 | - `sft_data_processing.py`: Process SFT dataset
18 |
19 | ## Usage
20 |
21 | 1. Setup and download datasets:
22 | ```bash
23 | cd slime
24 | pip install -e .
25 | # For the SFT part. Alternatively, skip SFT and run RL directly with the SFT model downloaded below.
26 | hf download --repo-type dataset JoeYing/ReTool-SFT --local-dir /root/JoeYing/ReTool-SFT
27 | hf download Qwen/Qwen3-4B-Instruct-2507 --local-dir /root/Qwen/Qwen3-4B-Instruct-2507
28 |
29 | # For RL part
30 | hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir /root/dapo-math-17k
31 | hf download --repo-type dataset zhuzilin/aime-2024 --local-dir /root/aime-2024
32 | # download our SFT model if you want to skip SFT
33 | hf download font-info/qwen3-4b-sft-SGLang-RL --local-dir /root/font-info/qwen3-4b-sft
34 | ```
35 |
36 | 2. Create the torch_dist checkpoints
37 | For SFT:
38 | ```bash
39 | source scripts/models/qwen3-4B.sh
40 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
41 | ${MODEL_ARGS[@]} \
42 | --hf-checkpoint /root/Qwen/Qwen3-4B-Instruct-2507 \
43 | --rotary-base 5000000 \
44 | --save /root/Qwen/Qwen3-4B-Instruct-2507_torch_dist
45 | ```
46 |
47 | Or, for RL only:
48 | ```bash
49 | source scripts/models/qwen3-4B.sh
50 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
51 | ${MODEL_ARGS[@]} \
52 | --hf-checkpoint /root/font-info/qwen3-4b-sft \
53 | --rotary-base 5000000 \
54 | --save /root/font-info/qwen3-4b-sft_torch_dist
55 |
56 | ```
57 |
58 | 3. SFT:
59 | ```bash
60 | python examples/retool/sft_data_processing.py
61 | bash examples/retool/retool_qwen3_4b_sft.sh
62 | ```
63 |
64 | 4. RL:
65 | ```bash
66 | bash examples/retool/retool_qwen3_4b_rl.sh
67 | ```
68 |
69 | 5. Use in your training scripts by importing the generate function:
70 | ```python
71 | from generate_with_retool import generate, reward_func
72 | ```
73 |
74 | ## Tool Format
75 |
76 | The system uses the following tool format:
77 |
78 | ```
79 | You may call one or more functions to assist with the user query.
80 |
81 | You are provided with function signatures within <tools></tools> XML tags:
82 | <tools>
83 | {"type": "function", "function": {"name": "code_interpreter", "description": "A tool for executing code.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The code to execute."}}, "required": ["code"]}}}
84 | </tools>
85 | 
86 | For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
87 | <tool_call>
88 | {"name": <function-name>, "arguments": <args-json-object>}
89 | </tool_call>
90 | ```
91 |
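A hypothetical helper (not part of this example's code) for pulling the tool-call JSON out of a completion that follows the format above, assuming the `<tool_call></tool_call>` tag convention shown:

```python
import json
import re


def parse_tool_call(text: str):
    """Return the first <tool_call> JSON payload in text, or None if there is none."""
    match = re.search(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", text, re.DOTALL)
    return json.loads(match.group(1)) if match else None
```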
92 | ## Safety Features
93 |
94 | - Code execution in isolated sandbox
95 | - Memory and time limits
96 | - Dangerous operation detection
97 | - Allowed module restrictions
98 |
--------------------------------------------------------------------------------
/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py:
--------------------------------------------------------------------------------
1 | import os
2 | import slime.utils.external_utils.command_utils as U
3 |
4 | MODEL_NAME = "Qwen3-0.6B"
5 |
6 |
7 | def prepare():
8 | U.exec_command("mkdir -p /root/models /root/datasets")
9 | U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}")
10 | U.hf_download_dataset("zhuzilin/gsm8k")
11 |
12 |
13 | def execute():
14 | ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} "
15 |
16 | rollout_args = (
17 | "--prompt-data /root/datasets/gsm8k/train.parquet "
18 | "--input-key messages "
19 | "--label-key label "
20 | "--apply-chat-template "
21 | "--rollout-shuffle "
22 | "--rm-type math "
23 | f"--num-rollout {3000 if U.get_env_enable_infinite_run() else 60} "
24 | "--rollout-batch-size 32 "
25 | "--n-samples-per-prompt 8 "
26 | "--rollout-max-response-len 1024 "
27 | "--rollout-temperature 0.8 "
28 | "--over-sampling-batch-size 64 "
29 | "--dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std "
30 | "--global-batch-size 256 "
31 | )
32 |
33 | eval_args = (
34 | "--eval-interval 20 "
35 | "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet "
36 | "--n-samples-per-eval-prompt 1 "
37 | "--eval-max-response-len 1024 "
38 | "--eval-top-k 1 "
39 | )
40 |
41 | grpo_args = (
42 | "--advantage-estimator grpo "
43 | # "--use-kl-loss "
44 | "--kl-loss-coef 0.00 "
45 | "--kl-loss-type low_var_kl "
46 | "--kl-coef 0.00 "
47 | "--entropy-coef 0.00 "
48 | "--eps-clip 0.2 "
49 | "--eps-clip-high 0.28 "
50 | )
51 |
52 | optimizer_args = (
53 | "--optimizer adam "
54 | "--lr 1e-6 "
55 | "--lr-decay-style constant "
56 | "--weight-decay 0.1 "
57 | "--adam-beta1 0.9 "
58 | "--adam-beta2 0.98 "
59 | )
60 |
61 | sglang_args = "--rollout-num-gpus-per-engine 2 " "--sglang-decode-log-interval 1000 " "--sglang-enable-metrics "
62 |
63 | fsdp_args = (
64 | # Set the bucket size for weight update
65 | "--update-weight-buffer-size 536870912 " # 512MB
66 | )
67 |
68 | ci_args = (
69 | "--ci-test "
70 | "--ci-disable-kl-checker "
71 | "--ci-metric-checker-key eval/gsm8k "
72 | "--ci-metric-checker-threshold 0.71 " # loose threshold at 60 step
73 | )
74 |
75 | misc_args = "--actor-num-nodes 1 " "--actor-num-gpus-per-node 2 " "--colocate " "--train-backend fsdp "
76 |
77 | train_args = (
78 | f"{ckpt_args} "
79 | f"{rollout_args} "
80 | f"{optimizer_args} "
81 | f"{grpo_args} "
82 | f"{sglang_args} "
83 | f"{U.get_default_wandb_args(__file__)} "
84 | f"{eval_args} "
85 | f"{fsdp_args} "
86 | f"{ci_args} "
87 | f"{misc_args} "
88 | )
89 |
90 | U.execute_train(
91 | train_args=train_args,
92 | num_gpus_per_node=2,
93 | megatron_model_type=None,
94 | )
95 |
96 |
97 | if __name__ == "__main__":
98 | prepare()
99 |     os.environ.pop("http_proxy", None)
100 |     os.environ.pop("https_proxy", None)
101 |     os.environ.pop("HTTP_PROXY", None)
102 |     os.environ.pop("HTTPS_PROXY", None)
103 | execute()
104 |
--------------------------------------------------------------------------------
/train_async.py:
--------------------------------------------------------------------------------
1 | import ray
2 |
3 | from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models
4 | from slime.utils.arguments import parse_args
5 | from slime.utils.logging_utils import configure_logger
6 | from slime.utils.misc import should_run_periodic_action
7 | from slime.utils.tracking_utils import init_tracking
8 |
9 |
10 | # The framework supports other asynchronous approaches such as fully async training (shown in examples/fully_async).
11 | def train(args):
12 | assert not args.colocate, "Colocation is not supported for async training."
13 | configure_logger()
14 | # allocate the GPUs
15 | pgs = create_placement_groups(args)
16 | init_tracking(args)
17 |
18 | # create the rollout manager, with sglang engines inside.
19 | # need to initialize rollout manager first to calculate num_rollout
20 | rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"])
21 |
22 | # create the actor and critic models
23 | actor_model, critic_model = create_training_models(args, pgs, rollout_manager)
24 |
25 | # always update weight first so that sglang has the loaded weights from training.
26 | actor_model.update_weights()
27 |
28 | if args.check_weight_update_equal:
29 | ray.get(rollout_manager.check_weights.remote(action="compare"))
30 |
31 | # async train loop.
32 | rollout_data_next_future = rollout_manager.generate.remote(args.start_rollout_id)
33 | for rollout_id in range(args.start_rollout_id, args.num_rollout):
34 | # Sync the last generation
35 | if rollout_data_next_future is not None:
36 | rollout_data_curr_ref = ray.get(rollout_data_next_future)
37 |
38 | # Start the next rollout early.
39 | if rollout_id + 1 < args.num_rollout:
40 | rollout_data_next_future = rollout_manager.generate.remote(rollout_id + 1)
41 |
42 | if args.use_critic:
43 | critic_train_handle = critic_model.async_train(rollout_id, rollout_data_curr_ref)
44 | if rollout_id >= args.num_critic_only_steps:
45 | ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref))
46 | ray.get(critic_train_handle)
47 | else:
48 | ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref))
49 |
50 | if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout):
51 | actor_model.save_model(rollout_id, force_sync=rollout_id == args.num_rollout - 1)
52 | if args.use_critic:
53 | critic_model.save_model(rollout_id, force_sync=rollout_id == args.num_rollout - 1)
54 | if args.rollout_global_dataset:
55 | ray.get(rollout_manager.save.remote(rollout_id))
56 |
57 | if (rollout_id + 1) % args.update_weights_interval == 0:
58 | # sync generate before update weights to prevent update weight in the middle of generation
59 | rollout_data_curr_ref = ray.get(x) if (x := rollout_data_next_future) is not None else None
60 | rollout_data_next_future = None
61 | actor_model.update_weights()
62 |
63 | if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch):
64 | ray.get(rollout_manager.eval.remote(rollout_id))
65 |
66 | ray.get(rollout_manager.dispose.remote())
67 |
68 |
69 | if __name__ == "__main__":
70 | args = parse_args()
71 | train(args)
72 |
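The loop above overlaps generation of rollout `N+1` with training on rollout `N`. A stripped-down sketch of that one-step overlap, with hypothetical `generate`/`train` callables standing in for the Ray remote calls:

```python
from concurrent.futures import ThreadPoolExecutor


def run(num_rollout, generate, train):
    with ThreadPoolExecutor(max_workers=1) as pool:
        next_future = pool.submit(generate, 0)                       # kick off the first rollout
        for rollout_id in range(num_rollout):
            data = next_future.result()                              # sync the last generation
            if rollout_id + 1 < num_rollout:
                next_future = pool.submit(generate, rollout_id + 1)  # start the next rollout early
            train(rollout_id, data)                                  # overlaps with the ongoing generation
```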
--------------------------------------------------------------------------------
/docs/zh/examples/qwen3-next-80B-A3B.md:
--------------------------------------------------------------------------------
1 | # Training Qwen3-Next-80B-A3B on 8xH100
2 | 
3 | ## Environment Setup
4 | 
5 | Environment setup, model and data download, and ckpt conversion are the same as for the Qwen3-4B model; see [Example: Qwen3-4B](./qwen3-4B.md) and replace the Qwen3-4B parts with
6 | Qwen3-next-80B-A3B-Instruct.
7 | 
8 | The full procedure below converts the huggingface checkpoint to the torch_dist format:
9 |
10 | ```bash
11 | export BASE_FOLDER=./models/
12 | # Download the model weights (Qwen3-Next-80B-A3B-Thinking)
13 | hf download Qwen/Qwen3-Next-80B-A3B-Thinking --local-dir ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking
14 | ```
15 |
16 | ```shell
17 | cd slime/
18 | pip install -e .
19 |
20 | # (for acceleration)
21 | cd .. # and find a proper folder
22 | git clone https://github.com/fla-org/flash-linear-attention
23 | cd flash-linear-attention
24 | git checkout 9714c595
25 | pip install -e .
26 |
27 | wget https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.4/causal_conv1d-1.5.4+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
28 | pip install ./causal_conv1d-1.5.4+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
29 | ```
30 |
31 | ## [Optional] Fix a bug in triton compilation on Blackwell (sm100)
32 |
33 | See the discussions at https://github.com/triton-lang/triton/issues/8695
34 | and https://github.com/fla-org/flash-linear-attention/issues/638.
35 |
36 | We need to apply a patch to fix the bug.
37 | Go to the flash-linear-attention folder you just installed, and apply the following patch:
38 |
39 | ```diff
40 | diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py
41 | index c5119dcf..838f5e4e 100644
42 | --- a/fla/ops/gated_delta_rule/wy_fast.py
43 | +++ b/fla/ops/gated_delta_rule/wy_fast.py
44 | @@ -198,7 +198,14 @@ def prepare_wy_repr_bwd_kernel(
45 | b_A += tl.dot(b_kb, tl.trans(b_k))
46 | b_dkb = tl.dot(b_dA, b_k)
47 | b_db += tl.sum(b_dkb * b_k, 1)
48 | - b_dk += tl.dot(tl.trans(b_dA), b_kb)
49 | + b_dk += tl.inline_asm_elementwise(
50 | + asm="mov.f32 $0, $1;",
51 | + constraints="=r,r",
52 | + args=[tl.dot(tl.trans(b_dA), b_kb)],
53 | + dtype=tl.float32,
54 | + is_pure=True,
55 | + pack=1,
56 | + )
57 | b_dk += b_dkb * b_b[:, None]
58 | tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
59 | tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))
60 |
61 | ```
62 |
63 | Save it as `patch.diff` (remember to keep the trailing empty line in the file!) and run `git apply patch.diff`.
64 |
65 | ## Run Training (Megatron)
66 | 
67 | **Blackwell is currently not supported.**
68 | 
69 | Convert the model weights:
70 |
71 | ```bash
72 | source scripts/models/qwen3-next-80B-A3B.sh
73 | PYTHONPATH=/root/Megatron-LM/ torchrun --nproc-per-node 8 \
74 | tools/convert_hf_to_torch_dist.py \
75 | ${MODEL_ARGS[@]} \
76 | --hf-checkpoint /root/Qwen3-Next-80B-A3B-Thinking/ \
77 | --save /root/Qwen3-Next-80B-A3B-Thinking_torch_dist/
78 | ```
79 |
80 | Single node, 8 GPUs:
81 |
82 | ```bash
83 | cd /root/slime
84 | export BASE_FOLDER=/root
85 | export MASTER_ADDR=127.0.0.1
86 | bash scripts/run-qwen3-next-80B-A3B-8gpus.sh
87 | ```
88 | If GPU memory is insufficient, consider disabling `--accumulate-allreduce-grads-in-fp32` and enabling `--grad-reduce-in-bf16`.
89 |
90 |
91 | Multi-node (4x8):
92 |
93 | ```bash
94 | cd /root/slime
95 | export BASE_FOLDER=/root
96 | export MASTER_ADDR=your_master_addr
97 | bash scripts/run-qwen3-next-80B-A3B.sh
98 | ```
99 |
100 | ## Run Training (FSDP)
101 |
102 | ```bash
103 | export BASE_FOLDER=./models/
104 | export MASTER_ADDR=127.0.0.1
105 |
106 | bash scripts/run-qwen3-next-80B-A3B-fsdp.sh
107 | ```
108 |
109 |
--------------------------------------------------------------------------------