├── slime ├── __init__.py ├── ray │ ├── __init__.py │ ├── ray_actor.py │ └── utils.py ├── rollout │ ├── __init__.py │ ├── filter_hub │ │ ├── __init__.py │ │ ├── base_types.py │ │ └── dynamic_sampling_filters.py │ ├── generate_hub │ │ ├── __init__.py │ │ └── benchmarkers.py │ ├── sleep_rollout.py │ ├── base_types.py │ ├── rm_hub │ │ ├── deepscaler.py │ │ ├── f1.py │ │ └── __init__.py │ └── sft_rollout.py ├── router │ ├── __init__.py │ └── middleware_hub │ │ └── __init__.py ├── backends │ ├── __init__.py │ ├── sglang_utils │ │ └── __init__.py │ ├── fsdp_utils │ │ ├── kernels │ │ │ └── __init__.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── qwen3_moe_hf.py │ │ └── __init__.py │ └── megatron_utils │ │ ├── update_weight │ │ ├── __init__.py │ │ ├── hf_weight_iterator_base.py │ │ └── hf_weight_iterator_bridge.py │ │ ├── megatron_to_hf │ │ ├── processors │ │ │ ├── __init__.py │ │ │ └── padding_remover.py │ │ ├── llama.py │ │ └── mimo.py │ │ ├── misc_utils.py │ │ ├── sglang.py │ │ ├── arguments.py │ │ ├── __init__.py │ │ └── checkpoint.py └── utils │ ├── debug_utils │ ├── __init__.py │ ├── replay_reward_fn.py │ ├── send_to_sglang.py │ └── display_debug_rollout_data.py │ ├── external_utils │ └── __init__.py │ ├── __init__.py │ ├── ray_utils.py │ ├── context_utils.py │ ├── logging_utils.py │ ├── train_dump_utils.py │ ├── tracking_utils.py │ ├── megatron_bridge_utils.py │ ├── async_utils.py │ ├── metric_checker.py │ ├── rocm_checkpoint_writer.py │ ├── iter_utils.py │ ├── memory_utils.py │ ├── train_metric_utils.py │ ├── tensorboard_utils.py │ ├── timer.py │ ├── typer_utils.py │ ├── fp8_kernel.py │ └── misc.py ├── examples ├── __init__.py ├── multi_agent │ ├── __init__.py │ ├── rollout_with_multi_agents.py │ ├── README.md │ └── prompts.py ├── eval │ ├── __init__.py │ ├── nemo_skills │ │ ├── __init__.py │ │ ├── config │ │ │ └── local_cluster.yaml │ │ ├── skills_config.py │ │ └── skills_client.py │ ├── scripts │ │ └── multi_tasks.yaml │ └── README.md ├── retool │ ├── requirements.txt │ ├── rl_data_preprocess.py │ ├── sft_data_processing.py │ └── README.md ├── strands-agents │ ├── requirements.txt │ └── README.md ├── true_on_policy │ └── src │ │ ├── aime.png │ │ ├── raw_reward.png │ │ ├── step_time.png │ │ ├── rollout_time.png │ │ └── train_rollout_abs_diff.png ├── true_on_policy_vlm │ ├── diff.png │ └── README.md ├── geo3k_vlm │ └── fsdp_vs_megatron.png ├── eval_multi_task │ ├── requirements_ifbench.txt │ ├── multi_task.yaml │ └── README.md ├── formal_math │ └── single_round │ │ └── README.md ├── tau-bench │ ├── sglang_tool_parser.py │ ├── tau1_mock.py │ └── README.md ├── train_infer_mismatch_helper │ └── mis.yaml ├── on_policy_distillation │ └── on_policy_distillation.py ├── search-r1 │ └── local_dense_retriever │ │ └── download.py ├── reproducibility │ └── README.md └── fully_async │ └── README.md ├── slime_plugins ├── __init__.py ├── models │ ├── __init__.py │ └── glm4.py ├── megatron_bridge │ └── __init__.py ├── rollout_buffer │ ├── generator │ │ └── __init__.py │ ├── README_zh.md │ └── README.md └── mbridge │ └── __init__.py ├── tests ├── ci │ ├── github_runner │ │ ├── .gitignore │ │ ├── .env.example │ │ └── docker-compose.yml │ └── README.md ├── test_fsdp_import.py ├── test_gspo.sh ├── test_chunked_gae.py └── test_qwen3_0.6B_fsdp_colocated_2xGPU.py ├── docker ├── version.txt ├── README.md ├── justfile └── amd_patch │ ├── latest │ └── amd_megatron_fused_kernels_init.patch │ └── sglv0.5.0rc0 │ └── amd_megatron_fused_kernels_init.patch ├── imgs └── arch.png ├── docs ├── _static │ ├── image │ │ ├── 
logo.ico │ │ ├── logo.jpg │ │ └── blogs │ │ │ └── release_v0.1.0 │ │ │ ├── cuda_vmm.png │ │ │ └── overrall.png │ └── css │ │ ├── readthedocs.css │ │ └── custom_log.css ├── requirements.txt ├── build.sh ├── zh │ ├── advanced │ │ ├── fault-torlance.md │ │ ├── speculative-decoding.md │ │ └── arch-support-beyond-megatron.md │ ├── index.rst │ ├── developer_guide │ │ └── debug.md │ ├── get_started │ │ └── qa.md │ └── examples │ │ ├── qwen3-4b-base-openhermes.md │ │ ├── qwen3-30B-A3B.md │ │ └── qwen3-next-80B-A3B.md ├── serve.sh ├── README.md ├── en │ ├── advanced │ │ ├── fault-tolerance.md │ │ └── speculative-decoding.md │ └── index.rst └── build_all.sh ├── scripts └── models │ ├── deepseek-v3-5layer.sh │ ├── deepseek-v3-20layer.sh │ ├── qwen3-4B-Instruct-2507.sh │ ├── qwen2.5-0.5B.sh │ ├── qwen2.5-1.5B.sh │ ├── qwen2.5-3B.sh │ ├── qwen3-0.6B.sh │ ├── qwen3-1.7B.sh │ ├── qwen2.5-32B.sh │ ├── qwen2.5-7B.sh │ ├── qwen3-4B.sh │ ├── qwen3-8B.sh │ ├── qwen3-14B.sh │ ├── qwen3-32B.sh │ ├── llama3.2-3B-Instruct.sh │ ├── llama3.2-3B-Instruct-amd.sh │ ├── mimo-7B-rl.sh │ ├── llama3.1-8B-Instruct.sh │ ├── glm4-9B.sh │ ├── glm4-32B.sh │ ├── qwen3-30B-A3B.sh │ ├── glm4.5-106B-A12B.sh │ ├── qwen3-235B-A22B.sh │ ├── glm4.5-355B-A32B.sh │ ├── qwen3-next-80B-A3B.sh │ ├── kimi-k2.sh │ ├── kimi-k2-thinking.sh │ ├── deepseek-v3.sh │ └── moonlight.sh ├── requirements.txt ├── .github └── workflows │ ├── pre-commit.yml │ ├── generate_github_workflows.py │ ├── release-docs.yaml │ └── conda-ci.yml ├── .pre-commit-config.yaml ├── setup.py ├── pyproject.toml ├── README_zh.md └── train_async.py /slime/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/ray/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/rollout/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/router/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime_plugins/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/multi_agent/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/backends/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slime_plugins/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/rollout/filter_hub/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/slime/utils/debug_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/utils/external_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/ci/github_runner/.gitignore: -------------------------------------------------------------------------------- 1 | .env -------------------------------------------------------------------------------- /docker/version.txt: -------------------------------------------------------------------------------- 1 | nightly-dev-20251222a -------------------------------------------------------------------------------- /slime/backends/sglang_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/router/middleware_hub/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime_plugins/megatron_bridge/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/backends/fsdp_utils/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/backends/fsdp_utils/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/update_weight/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/megatron_to_hf/processors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility package root for Slime.""" 2 | -------------------------------------------------------------------------------- /imgs/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/HEAD/imgs/arch.png -------------------------------------------------------------------------------- /examples/eval/__init__.py: -------------------------------------------------------------------------------- 1 | """Evaluation helpers and example configs.""" 2 | -------------------------------------------------------------------------------- /examples/eval/nemo_skills/__init__.py: -------------------------------------------------------------------------------- 1 | """NeMo Skills evaluation helpers.""" 2 | -------------------------------------------------------------------------------- /examples/retool/requirements.txt: -------------------------------------------------------------------------------- 1 | jinja2>=3.0.0 2 | psutil>=5.8.0 3 | pytest>=7.0.0 4 | 
-------------------------------------------------------------------------------- /examples/strands-agents/requirements.txt: -------------------------------------------------------------------------------- 1 | camel-ai 2 | strands-agents 3 | strands-agents-tools 4 | -------------------------------------------------------------------------------- /docs/_static/image/logo.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/HEAD/docs/_static/image/logo.ico -------------------------------------------------------------------------------- /docs/_static/image/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/HEAD/docs/_static/image/logo.jpg -------------------------------------------------------------------------------- /slime/rollout/generate_hub/__init__.py: -------------------------------------------------------------------------------- 1 | # TODO: maybe move `sglang_rollout::generate` to this folder 2 | -------------------------------------------------------------------------------- /examples/true_on_policy/src/aime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/true_on_policy/src/aime.png -------------------------------------------------------------------------------- /examples/true_on_policy_vlm/diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/true_on_policy_vlm/diff.png -------------------------------------------------------------------------------- /scripts/models/deepseek-v3-5layer.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS_NUM_LAYERS=5 source "$(dirname -- "${BASH_SOURCE[0]}")/deepseek-v3.sh" 2 | -------------------------------------------------------------------------------- /examples/geo3k_vlm/fsdp_vs_megatron.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/geo3k_vlm/fsdp_vs_megatron.png -------------------------------------------------------------------------------- /scripts/models/deepseek-v3-20layer.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS_NUM_LAYERS=20 source "$(dirname -- "${BASH_SOURCE[0]}")/deepseek-v3.sh" 2 | -------------------------------------------------------------------------------- /scripts/models/qwen3-4B-Instruct-2507.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS_ROTARY_BASE=5000000 source "$(dirname -- "${BASH_SOURCE[0]}")/qwen3-4B.sh" -------------------------------------------------------------------------------- /examples/eval_multi_task/requirements_ifbench.txt: -------------------------------------------------------------------------------- 1 | emoji 2 | immutabledict 3 | nltk 4 | numpy==1.26.4 5 | spacy==3.7.4 6 | syllapy 7 | -------------------------------------------------------------------------------- /examples/true_on_policy/src/raw_reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/true_on_policy/src/raw_reward.png -------------------------------------------------------------------------------- 
/examples/true_on_policy/src/step_time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/true_on_policy/src/step_time.png -------------------------------------------------------------------------------- /tests/ci/github_runner/.env.example: -------------------------------------------------------------------------------- 1 | GITHUB_RUNNER_URL=https://github.com/slimerl/slime 2 | GITHUB_RUNNER_TOKEN=paste-your-token-here -------------------------------------------------------------------------------- /examples/true_on_policy/src/rollout_time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/true_on_policy/src/rollout_time.png -------------------------------------------------------------------------------- /docs/_static/image/blogs/release_v0.1.0/cuda_vmm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/HEAD/docs/_static/image/blogs/release_v0.1.0/cuda_vmm.png -------------------------------------------------------------------------------- /docs/_static/image/blogs/release_v0.1.0/overrall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/HEAD/docs/_static/image/blogs/release_v0.1.0/overrall.png -------------------------------------------------------------------------------- /examples/true_on_policy/src/train_rollout_abs_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/slime/HEAD/examples/true_on_policy/src/train_rollout_abs_diff.png -------------------------------------------------------------------------------- /slime/rollout/filter_hub/base_types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class DynamicFilterOutput: 6 | keep: bool 7 | reason: str | None = None 8 | -------------------------------------------------------------------------------- /slime/utils/ray_utils.py: -------------------------------------------------------------------------------- 1 | class Box: 2 | def __init__(self, inner): 3 | self._inner = inner 4 | 5 | @property 6 | def inner(self): 7 | return self._inner 8 | -------------------------------------------------------------------------------- /slime_plugins/rollout_buffer/generator/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_generator import BaseGenerator, query_single_turn 2 | 3 | __all__ = [ 4 | "BaseGenerator", 5 | "query_single_turn", 6 | ] 7 | -------------------------------------------------------------------------------- /docs/_static/css/readthedocs.css: -------------------------------------------------------------------------------- 1 | table.autosummary td { 2 | width: 50% 3 | } 4 | 5 | img.align-center { 6 | display: block; 7 | margin-left: auto; 8 | margin-right: auto; 9 | } 10 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | def strip_param_name_prefix(name: str): 2 | prefix = "module." 
3 | while name.startswith(prefix): 4 | name = name.removeprefix(prefix) 5 | return name 6 | -------------------------------------------------------------------------------- /slime_plugins/mbridge/__init__.py: -------------------------------------------------------------------------------- 1 | from .glm4 import GLM4Bridge 2 | from .glm4moe import GLM4MoEBridge 3 | from .mimo import MimoBridge 4 | from .qwen3_next import Qwen3NextBridge 5 | 6 | __all__ = ["GLM4Bridge", "GLM4MoEBridge", "Qwen3NextBridge", "MimoBridge"] 7 | -------------------------------------------------------------------------------- /tests/test_fsdp_import.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_fsdp_import(): 5 | try: 6 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 7 | except ImportError: 8 | pytest.skip("FSDP not available in this environment") 9 | assert FSDP is not None 10 | -------------------------------------------------------------------------------- /slime/rollout/sleep_rollout.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | def sleep(args, rollout_id, data_source, evaluation=False): 8 | count = 0 9 | while True: 10 | time.sleep(3600) 11 | count += 1 12 | logger.info(f"rollout sleep for {count} hours") 13 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | gguf>=0.10.0 2 | ipykernel 3 | ipywidgets 4 | jupyter_client 5 | markdown>=3.4.0 6 | matplotlib 7 | myst-parser 8 | nbconvert 9 | nbsphinx 10 | nbstripout 11 | pandoc 12 | pillow 13 | pydantic 14 | sphinx 15 | sphinx-autobuild 16 | sphinx-book-theme 17 | sphinx-copybutton 18 | sphinx-tabs 19 | sphinxcontrib-mermaid 20 | urllib3<2.0.0 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | blobfile 3 | datasets 4 | httpx[http2] 5 | mcp[cli] 6 | memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml 7 | omegaconf 8 | pillow 9 | pylatexenc 10 | pyyaml 11 | qwen_vl_utils # for VLM 12 | ray[default] 13 | ring_flash_attn 14 | sglang-router>=0.2.3 15 | tensorboard 16 | transformers 17 | wandb 18 | -------------------------------------------------------------------------------- /slime/utils/context_utils.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | 4 | def with_defer(deferred_func): 5 | def decorator(fn): 6 | @wraps(fn) 7 | def wrapper(*args, **kwargs): 8 | try: 9 | return fn(*args, **kwargs) 10 | finally: 11 | deferred_func() 12 | 13 | return wrapper 14 | 15 | return decorator 16 | -------------------------------------------------------------------------------- /docs/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 4 | LANG=$1 5 | 6 | # make sure language is only en or zh 7 | if [ "$LANG" != "en" ] && [ "$LANG" != "zh" ]; then 8 | echo "Language must be en or zh" 9 | exit 1 10 | fi 11 | 12 | cd $SCRIPT_DIR 13 | SLIME_DOC_LANG=$LANG sphinx-build -b html -D language=$LANG 
--conf-dir ./ ./$LANG ./build/$LANG -------------------------------------------------------------------------------- /slime/ray/ray_actor.py: -------------------------------------------------------------------------------- 1 | from slime.utils.misc import get_current_node_ip, get_free_port 2 | 3 | 4 | class RayActor: 5 | @staticmethod 6 | def _get_current_node_ip_and_free_port(start_port=10000, consecutive=1): 7 | return get_current_node_ip(), get_free_port(start_port=start_port, consecutive=consecutive) 8 | 9 | def get_master_addr_and_port(self): 10 | return self.master_addr, self.master_port 11 | -------------------------------------------------------------------------------- /scripts/models/qwen2.5-0.5B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 24 4 | --hidden-size 896 5 | --ffn-hidden-size 4864 6 | --num-attention-heads 14 7 | --use-rotary-position-embeddings 8 | --disable-bias-linear 9 | --add-qkv-bias 10 | --normalization "RMSNorm" 11 | --norm-epsilon 1e-6 12 | --rotary-base 1000000 13 | --group-query-attention 14 | --num-query-groups 2 15 | --vocab-size 151936 16 | ) -------------------------------------------------------------------------------- /scripts/models/qwen2.5-1.5B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 28 4 | --hidden-size 1536 5 | --ffn-hidden-size 8960 6 | --num-attention-heads 12 7 | --use-rotary-position-embeddings 8 | --disable-bias-linear 9 | --add-qkv-bias 10 | --normalization "RMSNorm" 11 | --norm-epsilon 1e-6 12 | --rotary-base 10000 13 | --group-query-attention 14 | --num-query-groups 2 15 | --vocab-size 151936 16 | ) -------------------------------------------------------------------------------- /scripts/models/qwen2.5-3B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 36 4 | --hidden-size 2048 5 | --ffn-hidden-size 11008 6 | --num-attention-heads 16 7 | --use-rotary-position-embeddings 8 | --disable-bias-linear 9 | --add-qkv-bias 10 | --normalization "RMSNorm" 11 | --norm-epsilon 1e-6 12 | --rotary-base 1000000 13 | --group-query-attention 14 | --num-query-groups 2 15 | --vocab-size 151936 16 | ) 17 | -------------------------------------------------------------------------------- /scripts/models/qwen3-0.6B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 28 4 | --hidden-size 1024 5 | --ffn-hidden-size 3072 6 | --num-attention-heads 16 7 | --group-query-attention 8 | --num-query-groups 8 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --normalization "RMSNorm" 12 | --norm-epsilon 1e-6 13 | --rotary-base 1000000 14 | --vocab-size 151936 15 | --kv-channels 128 16 | --qk-layernorm 17 | ) -------------------------------------------------------------------------------- /scripts/models/qwen3-1.7B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 28 4 | --hidden-size 2048 5 | --ffn-hidden-size 6144 6 | --num-attention-heads 16 7 | --group-query-attention 8 | --num-query-groups 8 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --normalization "RMSNorm" 12 | --norm-epsilon 1e-6 13 | --rotary-base 1000000 14 | --vocab-size 151936 15 | --kv-channels 128 16 | --qk-layernorm 17 | ) 
-------------------------------------------------------------------------------- /scripts/models/qwen2.5-32B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 64 4 | --hidden-size 5120 5 | --ffn-hidden-size 27648 6 | --num-attention-heads 40 7 | --group-query-attention 8 | --num-query-groups 8 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --add-qkv-bias 12 | --normalization "RMSNorm" 13 | --norm-epsilon 1e-5 14 | --rotary-base 1000000 15 | --vocab-size 152064 16 | --untie-embeddings-and-output-weights 17 | ) -------------------------------------------------------------------------------- /scripts/models/qwen2.5-7B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 28 4 | --hidden-size 3584 5 | --ffn-hidden-size 18944 6 | --num-attention-heads 28 7 | --group-query-attention 8 | --num-query-groups 4 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --add-qkv-bias 12 | --normalization "RMSNorm" 13 | --norm-epsilon 1e-06 14 | --rotary-base 1000000 15 | --vocab-size 152064 16 | --untie-embeddings-and-output-weights 17 | ) -------------------------------------------------------------------------------- /scripts/models/qwen3-4B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 36 4 | --hidden-size 2560 5 | --ffn-hidden-size 9728 6 | --num-attention-heads 32 7 | --group-query-attention 8 | --num-query-groups 8 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --normalization "RMSNorm" 12 | --norm-epsilon 1e-6 13 | --rotary-base "${MODEL_ARGS_ROTARY_BASE:-1000000}" 14 | --vocab-size 151936 15 | --kv-channels 128 16 | --qk-layernorm 17 | ) -------------------------------------------------------------------------------- /slime/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | _LOGGER_CONFIGURED = False 4 | 5 | 6 | # ref: SGLang 7 | def configure_logger(prefix: str = ""): 8 | global _LOGGER_CONFIGURED 9 | if _LOGGER_CONFIGURED: 10 | return 11 | 12 | _LOGGER_CONFIGURED = True 13 | 14 | logging.basicConfig( 15 | level=logging.INFO, 16 | format=f"[%(asctime)s{prefix}] %(filename)s:%(lineno)d - %(message)s", 17 | datefmt="%Y-%m-%d %H:%M:%S", 18 | force=True, 19 | ) 20 | -------------------------------------------------------------------------------- /scripts/models/qwen3-8B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 36 4 | --hidden-size 4096 5 | --ffn-hidden-size 12288 6 | --num-attention-heads 32 7 | --group-query-attention 8 | --num-query-groups 8 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --normalization "RMSNorm" 12 | --norm-epsilon 1e-6 13 | --rotary-base 1000000 14 | --vocab-size 151936 15 | --kv-channels 128 16 | --qk-layernorm 17 | --untie-embeddings-and-output-weights 18 | ) -------------------------------------------------------------------------------- /scripts/models/qwen3-14B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 40 4 | --hidden-size 5120 5 | --ffn-hidden-size 17408 6 | --num-attention-heads 40 7 | --group-query-attention 8 | --num-query-groups 8 9 | --use-rotary-position-embeddings 10 | 
--disable-bias-linear 11 | --normalization "RMSNorm" 12 | --norm-epsilon 1e-6 13 | --rotary-base 1000000 14 | --vocab-size 151936 15 | --kv-channels 128 16 | --qk-layernorm 17 | --untie-embeddings-and-output-weights 18 | ) 19 | -------------------------------------------------------------------------------- /scripts/models/qwen3-32B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 64 4 | --hidden-size 5120 5 | --ffn-hidden-size 25600 6 | --num-attention-heads 64 7 | --group-query-attention 8 | --num-query-groups 8 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --normalization "RMSNorm" 12 | --norm-epsilon 1e-6 13 | --rotary-base 1000000 14 | --vocab-size 151936 15 | --kv-channels 128 16 | --qk-layernorm 17 | --untie-embeddings-and-output-weights 18 | ) 19 | -------------------------------------------------------------------------------- /scripts/models/llama3.2-3B-Instruct.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 28 4 | --hidden-size 3072 5 | --ffn-hidden-size 8192 6 | --num-attention-heads 24 7 | --group-query-attention 8 | --num-query-groups 8 9 | --max-position-embeddings 131072 10 | --use-rotary-position-embeddings 11 | --disable-bias-linear 12 | --normalization "RMSNorm" 13 | --norm-epsilon 1e-5 14 | --rotary-base 500000 15 | --vocab-size 128256 16 | --kv-channels 128 17 | --use-rope-scaling 18 | --rotary-scaling-factor 32.0 19 | ) -------------------------------------------------------------------------------- /slime/backends/megatron_utils/megatron_to_hf/processors/padding_remover.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from slime.backends.megatron_utils.misc_utils import strip_param_name_prefix 4 | 5 | 6 | def remove_padding(name: str, param: torch.Tensor, vocab_size: int) -> torch.Tensor: 7 | """ 8 | Remove vocab padding: param[:vocab_size] for embedding/output layers, else unchanged. 
9 | """ 10 | if strip_param_name_prefix(name) in {"embedding.word_embeddings.weight", "output_layer.weight"}: 11 | return param[:vocab_size] 12 | return param 13 | -------------------------------------------------------------------------------- /scripts/models/llama3.2-3B-Instruct-amd.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 28 4 | --hidden-size 3072 5 | --ffn-hidden-size 8192 6 | --num-attention-heads 24 7 | --group-query-attention 8 | --num-query-groups 8 9 | --max-position-embeddings 131072 10 | --use-rotary-position-embeddings 11 | --disable-bias-linear 12 | --normalization "RMSNorm" 13 | --norm-epsilon 1e-5 14 | --rotary-base 500000 15 | --vocab-size 128256 16 | --kv-channels 128 17 | --use-rope-scaling 18 | --rotary-scaling-factor 32.0 19 | ) 20 | -------------------------------------------------------------------------------- /scripts/models/mimo-7B-rl.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 36 4 | --hidden-size 4096 5 | --ffn-hidden-size 11008 6 | --num-attention-heads 32 7 | --group-query-attention 8 | --num-query-groups 8 9 | --use-rotary-position-embeddings 10 | --disable-bias-linear 11 | --add-qkv-bias 12 | --normalization "RMSNorm" 13 | --norm-epsilon 1e-05 14 | --rotary-base 640000 15 | --vocab-size 151680 16 | --untie-embeddings-and-output-weights 17 | --max-position-embeddings 32768 18 | --mtp-num-layers 1 19 | ) 20 | -------------------------------------------------------------------------------- /docs/zh/advanced/fault-torlance.md: -------------------------------------------------------------------------------- 1 | # 容灾 2 | 3 | 为了保证长期稳定的 RL 训练,slime 会默认开始一定程度的容灾机制。这里主要介绍一下 slime 中容灾的一些设计思路。 4 | 5 | 可以通过 `--use-fault-tolerance` 开启容灾机制。 6 | 7 | ## rollout 容灾 8 | 9 | slime 会在 rollout 过程中,定期向所有 SGLang server 发送心跳请求(`/health_generate`),如果心跳超时,则会停止这个 SGLang server。并在这轮 rollout 完成之后进行重启和正确的参数更新。 10 | 11 | - `--rollout-health-check-first-wait`:由于一些大的 MoE 模型在第一次运行时需要处理一些编译,我们会在第一次 rollout 前等待 `rollout_health_check_first_wait` 秒再开始发送心跳,默认为 300s; 12 | - `--rollout-health-check-interval`:心跳检查间隔,默认为 10s; 13 | - `--rollout-health-check-timeout`:心跳超时限额,默认为 5s。 14 | -------------------------------------------------------------------------------- /scripts/models/llama3.1-8B-Instruct.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --swiglu 3 | --num-layers 32 4 | --hidden-size 4096 5 | --ffn-hidden-size 14336 6 | --num-attention-heads 32 7 | --group-query-attention 8 | --num-query-groups 8 9 | --max-position-embeddings 131072 10 | --use-rotary-position-embeddings 11 | --disable-bias-linear 12 | --normalization "RMSNorm" 13 | --norm-epsilon 1e-5 14 | --rotary-base 500000 15 | --vocab-size 128256 16 | --kv-channels 128 17 | --use-rope-scaling 18 | --rotary-scaling-factor 8.0 19 | --untie-embeddings-and-output-weights 20 | ) -------------------------------------------------------------------------------- /slime/rollout/filter_hub/dynamic_sampling_filters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from slime.rollout.filter_hub.base_types import DynamicFilterOutput 4 | from slime.utils.types import Sample 5 | 6 | __all__ = ["check_reward_nonzero_std"] 7 | 8 | 9 | def check_reward_nonzero_std(args, samples: list[Sample], **kwargs): 10 | rewards = [sample.get_reward_value(args) for sample in 
samples] 11 | keep = torch.tensor(rewards, dtype=torch.float).std() > 0.0 12 | return DynamicFilterOutput( 13 | keep=keep, 14 | reason=None if keep else f"zero_std_{round(rewards[0], 1)}", 15 | ) 16 | -------------------------------------------------------------------------------- /examples/eval/nemo_skills/config/local_cluster.yaml: -------------------------------------------------------------------------------- 1 | # Minimal cluster config for running `ns eval` directly on the current host 2 | # without spinning up containers or mount checks. 3 | 4 | executor: none 5 | 6 | # Provide stub container entries so pipeline code can reference them even 7 | # though we do not launch actual containers when executor == "none". 8 | containers: 9 | nemo-skills: "" 10 | sglang: "" 11 | trtllm: "" 12 | vllm: "" 13 | 14 | # No mount enforcement in "none" mode, but keep the field for completeness. 15 | mounts: [] 16 | 17 | # Optional default env vars propagated to ns jobs. Leave empty unless needed. 18 | env_vars: [] 19 | -------------------------------------------------------------------------------- /scripts/models/glm4-9B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --spec "slime_plugins.models.glm4" "get_glm_spec" 3 | --swiglu 4 | --num-layers 40 5 | --hidden-size 4096 6 | --ffn-hidden-size 13696 7 | --num-attention-heads 32 8 | --group-query-attention 9 | --num-query-groups 2 10 | --use-rotary-position-embeddings 11 | --disable-bias-linear 12 | --add-qkv-bias 13 | --normalization "RMSNorm" 14 | --norm-epsilon 1e-5 15 | --rotary-base 10000 16 | --vocab-size 151552 17 | --post-self-attn-layernorm 18 | --post-mlp-layernorm 19 | --rotary-interleaved 20 | --rotary-percent 0.5 21 | --no-rope-fusion 22 | --untie-embeddings-and-output-weights 23 | ) -------------------------------------------------------------------------------- /slime_plugins/models/glm4.py: -------------------------------------------------------------------------------- 1 | from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec 2 | 3 | 4 | def get_glm_spec(args, config, vp_stage): 5 | transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( 6 | num_experts=args.num_experts, 7 | moe_grouped_gemm=args.moe_grouped_gemm, 8 | qk_layernorm=args.qk_layernorm, 9 | multi_latent_attention=args.multi_latent_attention, 10 | moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, 11 | post_self_attn_layernorm=args.post_self_attn_layernorm, 12 | post_mlp_layernorm=args.post_mlp_layernorm, 13 | ) 14 | return transformer_layer_spec 15 | -------------------------------------------------------------------------------- /scripts/models/glm4-32B.sh: -------------------------------------------------------------------------------- 1 | MODEL_ARGS=( 2 | --spec "slime_plugins.models.glm4" "get_glm_spec" 3 | --swiglu 4 | --num-layers 64 5 | --hidden-size 6144 6 | --ffn-hidden-size 23040 7 | --num-attention-heads 48 8 | --max-position-embeddings 32768 9 | --seq-length 32768 10 | --use-rotary-position-embeddings 11 | --disable-bias-linear 12 | --normalization "RMSNorm" 13 | --norm-epsilon 1e-5 14 | --rotary-base 10000 15 | --group-query-attention 16 | --num-query-groups 8 17 | --vocab-size 151552 18 | --post-self-attn-layernorm 19 | --post-mlp-layernorm 20 | --rotary-interleaved 21 | --rotary-percent 0.5 22 | --no-rope-fusion 23 | --untie-embeddings-and-output-weights 24 | ) 
-------------------------------------------------------------------------------- /examples/eval_multi_task/multi_task.yaml: -------------------------------------------------------------------------------- 1 | eval: 2 | defaults: 3 | max_response_len: 16384 4 | top_p: 0.7 5 | datasets: 6 | - name: aime 7 | path: /root/aime-2024/aime-2024.jsonl 8 | rm_type: deepscaler 9 | n_samples_per_eval_prompt: 16 10 | - name: gpqa # huggingface-cli download --repo-type dataset zyzshishui0627/gpqa_diamond --local-dir /root/gpqa 11 | path: /root/gpqa/gpqa_eval.jsonl 12 | rm_type: gpqa 13 | n_samples_per_eval_prompt: 2 14 | - name: ifbench # huggingface-cli download --repo-type dataset zyzshishui0627/IFBench --local-dir /root/ifbench 15 | path: /root/ifbench/IFBench_eval.jsonl 16 | rm_type: ifbench 17 | n_samples_per_eval_prompt: 1 18 | -------------------------------------------------------------------------------- /examples/formal_math/single_round/README.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | For the minimal demo: 4 | 5 | ```shell 6 | # install dependencies 7 | apt update && apt install -y docker-cli 8 | pip install kimina-client polars 9 | 10 | # prepare data 11 | python examples/formal_math/single_round/prepare_data.py --output-name minimal_demo 12 | 13 | # prepare ray, model, test dataset, etc 14 | # normally just use this script, but here we want to demonstrate run_minimal.py, thus skip ray-submit part 15 | SLIME_SCRIPT_ENABLE_RAY_SUBMIT=0 python examples/formal_math/single_round/run.py 16 | 17 | # run 18 | python examples/formal_math/single_round/run_minimal.py 19 | ``` 20 | 21 | The code also supports more complicated cases, e.g.: 22 | 23 | * SFT + RL 24 | * Data filter + RL 25 | -------------------------------------------------------------------------------- /examples/retool/rl_data_preprocess.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | # Load the original dataset 4 | ds = load_dataset("BytedTsinghua-SIA/DAPO-Math-17k", split="train") 5 | 6 | 7 | # Map to extract the ground_truth from the reward_model dict and create a new 'label' field 8 | def transform(example): 9 | return { 10 | "prompt": example["prompt"][0]["content"] if example["prompt"] else None, 11 | "label": example["reward_model"]["ground_truth"], 12 | } 13 | 14 | 15 | ds2 = ds.map(transform, remove_columns=ds.column_names) 16 | 17 | # Optionally, verify the first few entries 18 | print(ds2[0]) 19 | 20 | # save to jsonl 21 | ds2.to_json("/root/dapo-math-17k-processed/dapo_math_17k_cleaned.jsonl", orient="records", lines=True) 22 | -------------------------------------------------------------------------------- /slime/utils/train_dump_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def save_debug_train_data(args, *, rollout_id, rollout_data): 10 | if (path_template := args.save_debug_train_data) is not None: 11 | rank = torch.distributed.get_rank() 12 | path = Path(path_template.format(rollout_id=rollout_id, rank=rank)) 13 | logger.info(f"Save debug train data to {path}") 14 | path.parent.mkdir(parents=True, exist_ok=True) 15 | torch.save( 16 | dict( 17 | rollout_id=rollout_id, 18 | rank=rank, 19 | rollout_data=rollout_data, 20 | ), 21 | path, 22 | ) 23 | 
-------------------------------------------------------------------------------- /slime/utils/tracking_utils.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from slime.utils.tensorboard_utils import _TensorboardAdapter 3 | 4 | from . import wandb_utils 5 | 6 | 7 | def init_tracking(args, primary: bool = True, **kwargs): 8 | if primary: 9 | wandb_utils.init_wandb_primary(args, **kwargs) 10 | else: 11 | wandb_utils.init_wandb_secondary(args, **kwargs) 12 | 13 | 14 | # TODO further refactor, e.g. put TensorBoard init to the "init" part 15 | def log(args, metrics, step_key: str): 16 | if args.use_wandb: 17 | wandb.log(metrics) 18 | 19 | if args.use_tensorboard: 20 | metrics_except_step = {k: v for k, v in metrics.items() if k != step_key} 21 | _TensorboardAdapter(args).log(data=metrics_except_step, step=metrics[step_key]) 22 | -------------------------------------------------------------------------------- /slime/utils/megatron_bridge_utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | try: 4 | from megatron.core.utils import unwrap_model 5 | except ImportError: 6 | unwrap_model = None 7 | 8 | 9 | @contextmanager 10 | def patch_megatron_model(model): 11 | unwrapped_model = unwrap_model(model)[0] 12 | model_config = unwrapped_model.config 13 | attribute_was_added = False 14 | if not hasattr(model_config, "share_embeddings_and_output_weights"): 15 | model_config.share_embeddings_and_output_weights = unwrapped_model.share_embeddings_and_output_weights 16 | attribute_was_added = True 17 | 18 | try: 19 | yield 20 | finally: 21 | if attribute_was_added: 22 | delattr(model_config, "share_embeddings_and_output_weights") 23 | -------------------------------------------------------------------------------- /slime/rollout/base_types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any 3 | 4 | from slime.utils.types import Sample 5 | 6 | 7 | @dataclass 8 | class RolloutFnTrainOutput: 9 | samples: list[list[Sample]] 10 | metrics: dict[str, Any] = None 11 | 12 | 13 | @dataclass 14 | class RolloutFnEvalOutput: 15 | data: dict[str, dict[str, Any]] 16 | metrics: dict[str, Any] = None 17 | 18 | 19 | def call_rollout_fn(fn, *args, evaluation: bool, **kwargs): 20 | output = fn(*args, **kwargs, evaluation=evaluation) 21 | 22 | # compatibility for legacy version 23 | if not isinstance(output, (RolloutFnTrainOutput, RolloutFnEvalOutput)): 24 | output = RolloutFnEvalOutput(data=output) if evaluation else RolloutFnTrainOutput(samples=output) 25 | 26 | return output 27 | -------------------------------------------------------------------------------- /examples/eval_multi_task/README.md: -------------------------------------------------------------------------------- 1 | # Multi-Task Evaluation Example 2 | 3 | ## Configuring `multi_task.yaml` 4 | - `eval.defaults` defines inference parameters shared by every dataset entry. Override them inside an individual dataset block if needed. 5 | - `eval.datasets` enumerates the datasets to evaluate. Each entry should specify: 6 | - `name`: a short identifier that appears in logs and dashboards. 7 | - `path`: the path to the dataset JSONL file. 8 | - `rm_type`: which reward function to use for scoring. 9 | - `n_samples_per_eval_prompt`: how many candidate completions to generate per prompt. 
10 | 11 | ## IFBench Notes 12 | - When `ifbench` is used, `slime/rollout/rm_hub/ifbench.py` will automatically prepare the scoring environment, so no additional manual setup is required beyond providing the dataset path. 13 | -------------------------------------------------------------------------------- /examples/retool/sft_data_processing.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | ds = load_dataset("JoeYing/ReTool-SFT")["train"] 4 | 5 | 6 | def convert(sample): 7 | conversations = sample["messages"] 8 | 9 | def convert_role(role): 10 | if role == "user": 11 | return "user" 12 | elif role == "assistant": 13 | return "assistant" 14 | elif role == "system": 15 | return "system" 16 | else: 17 | raise ValueError(f"Unknown role: {role}") 18 | 19 | messages = [ 20 | { 21 | "role": convert_role(turn["role"]), 22 | "content": turn["content"], 23 | } 24 | for turn in conversations 25 | ] 26 | 27 | return {"messages": messages} 28 | 29 | 30 | ds = ds.map(convert) 31 | ds.to_parquet("./data/retool/ReTool-SFT.parquet") 32 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Docker release rule 2 | 3 | We will publish 2 kinds of docker images: 4 | 1. stable version, which is based on an official sglang release. We will store the patches for those versions. 5 | 2. latest version, which aligns with `lmsysorg/sglang:latest`. 6 | 7 | The current stable version is: 8 | - sglang nightly-dev-20251208-5e2cda61 (5e2cda6158e670e64b926a9985d65826c537ac82), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) 9 | 10 | Previous versions: 11 | - sglang v0.5.5.post1 (303cc957e62384044dfa8e52d7d8af8abe12f0ac), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) 12 | 13 | The command to build: 14 | 15 | ```bash 16 | just release 17 | ``` 18 | 19 | Before each update, we will test the following models with 64xH100: 20 | 21 | - Qwen3-4B sync 22 | - Qwen3-4B async 23 | - Qwen3-30B-A3B sync 24 | - Qwen3-30B-A3B fp8 sync 25 | - GLM-4.5-355B-A32B sync 26 | -------------------------------------------------------------------------------- /docs/serve.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 4 | LANG="${1:-all}" 5 | PORT="${PORT:-8000}" 6 | 7 | cd "$SCRIPT_DIR" 8 | 9 | if [ "$LANG" = "all" ]; then 10 | # Expect both builds present 11 | if [ ! -d build/en ] || [ ! -d build/zh ]; then 12 | echo "[serve] Missing build/en or build/zh. Run ./build_all.sh first." >&2 13 | fi 14 | echo "[serve] Serving multi-language docs root on http://localhost:$PORT (en/, zh/)" 15 | python -m http.server -d ./build "$PORT" 16 | exit $? 17 | fi 18 | 19 | if [ "$LANG" != "en" ] && [ "$LANG" != "zh" ]; then 20 | echo "Usage: $0 [en|zh|all]" >&2 21 | exit 1 22 | fi 23 | 24 | if [ ! -d "build/$LANG" ]; then 25 | echo "[serve] build/$LANG not found. Run ./build.sh $LANG first."
>&2 26 | exit 1 27 | fi 28 | echo "[serve] Serving $LANG docs on http://localhost:$PORT" 29 | python -m http.server -d ./build/$LANG "$PORT" -------------------------------------------------------------------------------- /slime/rollout/generate_hub/benchmarkers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from argparse import Namespace 4 | from copy import deepcopy 5 | from typing import Any 6 | 7 | from slime.rollout.sglang_rollout import generate as _generate_base 8 | from slime.utils.types import Sample 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | async def generate_with_random_osl(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample: 14 | # TODO: make it configurable after we have an enhanced arg parser 15 | min_osl = 32 * 1024 16 | max_osl = 64 * 1024 17 | 18 | modified_sampling_params = deepcopy(sampling_params) 19 | modified_sampling_params["ignore_eos"] = True 20 | modified_sampling_params["max_new_tokens"] = random.randrange(min_osl, max_osl) 21 | 22 | ans = await _generate_base(args, sample, modified_sampling_params) 23 | 24 | logger.info(f"generate_with_random_osl {ans.response_length=}") 25 | return ans 26 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # slime Documentation 2 | 3 | We recommend new contributors start from writing documentation, which helps you quickly understand slime codebase. 4 | Most documentation files are located under the `docs/` folder. 5 | 6 | ## Docs Workflow 7 | 8 | ### Install Dependency 9 | 10 | ```bash 11 | apt-get update && apt-get install -y pandoc parallel retry 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ### Update Documentation 16 | 17 | You can update the documentation in the en and zh folders by adding Markdown or Jupyter Notebook files to the appropriate subdirectories. If you create new files, make sure to update index.rst (or any other relevant .rst files) accordingly. 18 | 19 | ## Build and Render 20 | 21 | ```bash 22 | # build english version 23 | bash ./build.sh en 24 | bash ./serve.sh en 25 | 26 | # build chinese version 27 | bash ./build.sh zh 28 | bash ./serve.sh zh 29 | ``` 30 | 31 | You can then visit `http://localhost:8000` to view the documentation. 
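If you want to build and preview both languages at once, here is a minimal sketch (it assumes `docs/build_all.sh` builds both language trees, which is what `serve.sh` checks for before serving the combined root):

```bash
# Build the English and Chinese docs, then serve the combined root
# (with en/ and zh/ subdirectories) on http://localhost:8000.
bash ./build_all.sh
bash ./serve.sh all
```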
-------------------------------------------------------------------------------- /slime/utils/async_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import threading 3 | 4 | __all__ = ["get_async_loop", "run"] 5 | 6 | 7 | # Create a background event loop thread 8 | class AsyncLoopThread: 9 | def __init__(self): 10 | self.loop = asyncio.new_event_loop() 11 | self._thread = threading.Thread(target=self._start_loop, daemon=True) 12 | self._thread.start() 13 | 14 | def _start_loop(self): 15 | asyncio.set_event_loop(self.loop) 16 | self.loop.run_forever() 17 | 18 | def run(self, coro): 19 | # Schedule a coroutine onto the loop and block until it's done 20 | return asyncio.run_coroutine_threadsafe(coro, self.loop).result() 21 | 22 | 23 | # Create one global instance 24 | async_loop = None 25 | 26 | 27 | def get_async_loop(): 28 | global async_loop 29 | if async_loop is None: 30 | async_loop = AsyncLoopThread() 31 | return async_loop 32 | 33 | 34 | def run(coro): 35 | """Run a coroutine in the background event loop.""" 36 | return get_async_loop().run(coro) 37 | -------------------------------------------------------------------------------- /examples/true_on_policy_vlm/README.md: -------------------------------------------------------------------------------- 1 | # True On-Policy between Training and Inference for VLM 2 | 3 | This example demonstrates true on-policy training with Qwen3-VL dense model on FSDP. The core concepts and expected observations are the same as [true_on_policy](../true_on_policy/README.md). 4 | 5 |

6 | *(Figure: Training Inference Log Prob Diff)* 7 |

8 | 9 | ## Usage 10 | 11 | ```bash 12 | SLIME_SCRIPT_NUM_GPUS=8 python examples/true_on_policy_vlm/run_simple.py 13 | ``` 14 | 15 | ## How it is Implemented 16 | 17 | For the text backbone, please refer to [true_on_policy for the text-only model](../true_on_policy/README.md). 18 | 19 | For the VLM, we only need to ensure that the image encoder behaves as expected. Please refer to [SGLang#14636](https://github.com/sgl-project/sglang/pull/14636). We need to align numeric operation details between the two systems, so that the ViT forward pass matches the behavior in both SGLang and transformers. 20 | 21 | ## Notes 22 | 23 | It is expected that the true-on-policy version is slower. -------------------------------------------------------------------------------- /slime/backends/megatron_utils/update_weight/hf_weight_iterator_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class HfWeightIteratorBase(ABC): 5 | @staticmethod 6 | def create(args, model, **kwargs): 7 | from .hf_weight_iterator_bridge import HfWeightIteratorBridge 8 | from .hf_weight_iterator_direct import HfWeightIteratorDirect 9 | 10 | c = { 11 | "raw": HfWeightIteratorDirect, 12 | "bridge": HfWeightIteratorBridge, 13 | }[args.megatron_to_hf_mode] 14 | 15 | return c(args, model, **kwargs) 16 | 17 | def __init__(self, args, model, model_name, quantization_config): 18 | self.args = args 19 | self.model = model 20 | self.model_name = model_name 21 | self.quantization_config = quantization_config 22 | 23 | @abstractmethod 24 | def get_hf_weight_chunks(self, megatron_local_weights): 25 | """ 26 | Mental model of the API: 27 | megatron_model.to_hf_magically().named_parameters() 28 | """ 29 | raise NotImplementedError 30 | -------------------------------------------------------------------------------- /slime/utils/metric_checker.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger(__name__) 4 | 5 | 6 | class MetricChecker: 7 | @staticmethod 8 | def maybe_create(args): 9 | if args.ci_test and (args.ci_metric_checker_key is not None): 10 | return MetricChecker(args) 11 | return None 12 | 13 | def __init__(self, args): 14 | self.args = args 15 | self._exists_check_success = False 16 | 17 | def on_eval(self, metrics: dict[str, float]): 18 | actual_value = metrics.get(self.args.ci_metric_checker_key) 19 | assert actual_value is not None, f"{metrics=} {self.args.ci_metric_checker_key=}" 20 | 21 | check_success = actual_value >= self.args.ci_metric_checker_threshold 22 | logger.info(f"[MetricChecker] {check_success=} {actual_value=} {self.args.ci_metric_checker_threshold=}") 23 | 24 | self._exists_check_success |= check_success 25 | 26 | def dispose(self): 27 | assert self._exists_check_success, "[MetricChecker] accuracy check failed" 28 | logger.info("[MetricChecker] pass dispose check") 29 | -------------------------------------------------------------------------------- /docs/en/advanced/fault-tolerance.md: -------------------------------------------------------------------------------- 1 | # Fault Tolerance 2 | 3 | To ensure long-term, stable RL training, slime enables a certain level of fault tolerance by default. This section introduces the design philosophy behind fault tolerance in slime. 4 | 5 | To enable the fault tolerance function in slime, please set `--use-fault-tolerance`. 
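As an illustration only, the fault-tolerance switches from this page could be appended to a launch command. This is a sketch, not slime's actual CLI: the entry point and the argument-array name are placeholders, and the heartbeat values are simply the documented defaults (the individual flags are described in the next section):

```bash
# Hypothetical sketch: enable fault tolerance and keep the default heartbeat timing (seconds).
FAULT_TOLERANCE_ARGS=(
  --use-fault-tolerance
  --rollout-health-check-first-wait 300
  --rollout-health-check-interval 10
  --rollout-health-check-timeout 5
)
python train_async.py "${FAULT_TOLERANCE_ARGS[@]}"  # plus the usual model/rollout arguments
```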
6 | 7 | ## Rollout Fault Tolerance 8 | 9 | During the rollout process, slime periodically sends heartbeat requests (`/health_generate`) to all SGLang servers. If a heartbeat times out, that SGLang server will be stopped. After the current rollout round is complete, the server will be restarted and its parameters will be correctly updated. 10 | 11 | - `--rollout-health-check-first-wait`: Since some large MoE models require compilation on their first run, slime will wait for `rollout_health_check_first_wait` seconds before the first rollout to start sending heartbeats. Defaults to 300s. 12 | - `--rollout-health-check-interval`: The interval between heartbeat checks. Defaults to 10s. 13 | - `--rollout-health-check-timeout`: The timeout limit for a heartbeat request. Defaults to 5s. 14 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | types: [opened, synchronize, reopened, ready_for_review] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | run-pre-commit: 14 | name: Run pre-commit 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Checkout repository 18 | uses: actions/checkout@v4 19 | with: 20 | fetch-depth: 0 21 | 22 | - name: Set up Python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: '3.10' 26 | cache: 'pip' 27 | 28 | - name: Install pre-commit 29 | run: pip install --upgrade pip pre-commit 30 | 31 | - name: Cache pre-commit environments 32 | uses: actions/cache@v4 33 | with: 34 | path: ~/.cache/pre-commit 35 | key: pre-commit-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }} 36 | restore-keys: | 37 | pre-commit-${{ runner.os }}- 38 | 39 | - name: Run pre-commit on all files 40 | run: pre-commit run --all-files --show-diff-on-failure --color=always 41 | 42 | -------------------------------------------------------------------------------- /slime/backends/fsdp_utils/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | try: 5 | _TORCH_MEMORY_SAVER_AVAILABLE = True 6 | except ImportError: 7 | logging.warning("torch_memory_saver is not installed, refer to : https://github.com/fzyzcjy/torch_memory_saver") 8 | _TORCH_MEMORY_SAVER_AVAILABLE = False 9 | 10 | try: 11 | _FSDP_AVAILABLE = True 12 | except ImportError as e: 13 | logging.warning(f"FSDP backend dependencies not available: {e}") 14 | _FSDP_AVAILABLE = False 15 | 16 | if _FSDP_AVAILABLE: 17 | from .actor import FSDPTrainRayActor 18 | from .arguments import load_fsdp_args 19 | else: 20 | 21 | def _raise_import_error(*args, **kwargs): 22 | raise ImportError( 23 | "FSDP backend is not available. " 24 | "Please ensure PyTorch with FSDP2 support is installed. 
" 25 | "For installation instructions, refer to: https://pytorch.org/docs/stable/distributed.fsdp.fully_shard.html" 26 | ) 27 | 28 | FSDPTrainRayActor = _raise_import_error 29 | load_fsdp_args = _raise_import_error 30 | 31 | __all__ = ["load_fsdp_args", "FSDPTrainRayActor"] 32 | 33 | logging.getLogger().setLevel(logging.WARNING) 34 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/sglang.py: -------------------------------------------------------------------------------- 1 | # the file to manage all sglang deps in the megatron actor 2 | try: 3 | from sglang.srt.layers.quantization.fp8_utils import quant_weight_ue8m0, transform_scale_ue8m0 4 | from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 5 | except ImportError: 6 | quant_weight_ue8m0 = None 7 | transform_scale_ue8m0 = None 8 | should_deepgemm_weight_requant_ue8m0 = None 9 | 10 | try: 11 | from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions 12 | except ImportError: 13 | from sglang.srt.patch_torch import monkey_patch_torch_reductions 14 | 15 | 16 | from sglang.srt.utils import MultiprocessingSerializer 17 | 18 | 19 | try: 20 | from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket # type: ignore[import] 21 | except ImportError: 22 | from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] 23 | 24 | __all__ = [ 25 | "quant_weight_ue8m0", 26 | "transform_scale_ue8m0", 27 | "should_deepgemm_weight_requant_ue8m0", 28 | "monkey_patch_torch_reductions", 29 | "MultiprocessingSerializer", 30 | "FlattenedTensorBucket", 31 | ] 32 | -------------------------------------------------------------------------------- /examples/tau-bench/sglang_tool_parser.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from sglang.srt.function_call.function_call_parser import FunctionCallParser 4 | from sglang.srt.managers.io_struct import Function, Tool 5 | 6 | 7 | def parse_tools(response: str, tools: list[dict[str, Any]], parser: str = "qwen25"): 8 | """ 9 | This function mimics the function call parser API from 10 | https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py#L952 11 | But running locally 12 | """ 13 | tools_list = [ 14 | Tool( 15 | function=Function( 16 | name=tool["function"]["name"], 17 | description=tool["function"]["description"], 18 | parameters=tool["function"]["parameters"], 19 | ), 20 | type=tool["type"], 21 | ) 22 | for tool in tools 23 | ] 24 | parser = FunctionCallParser(tools=tools_list, tool_call_parser=parser) 25 | 26 | normal_text, calls = parser.parse_non_stream(response) 27 | 28 | return { 29 | "normal_text": normal_text, 30 | "calls": [call.model_dump() for call in calls], # Convert pydantic objects to dictionaries 31 | } 32 | -------------------------------------------------------------------------------- /scripts/models/qwen3-30B-A3B.sh: -------------------------------------------------------------------------------- 1 | NLAYERS=48 2 | FIRST_K_DENSE_REPLACE=0 3 | 4 | arr=() 5 | for ((i=0; i list[Sample]: 17 | 18 | tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True) 19 | max_context_length = args.rollout_max_context_len if not evaluation else args.eval_max_context_len 20 | 21 | args.sampling_params = sampling_params 22 | args.rollout_max_context_len = max_context_length 23 | args.tokenizer = tokenizer 24 | 25 | for key, value in 
MULTI_AGENT_CONFIGS.items(): 26 | setattr(args, key, value) 27 | 28 | custom_multi_agent_func = load_function(args.custom_multi_agent_function_path) 29 | samples = await custom_multi_agent_func(args, sample) 30 | 31 | random.shuffle(samples) 32 | 33 | return samples 34 | -------------------------------------------------------------------------------- /.github/workflows/generate_github_workflows.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import jinja2 3 | 4 | 5 | def main(): 6 | """ 7 | Generates GitHub workflow YAML files from Jinja2 templates. 8 | """ 9 | workflows_dir = Path(__file__).parent 10 | print(f"Scan dir: {workflows_dir}") 11 | env = jinja2.Environment( 12 | loader=jinja2.FileSystemLoader(str(workflows_dir)), 13 | block_start_string="<%", 14 | block_end_string="%>", 15 | variable_start_string="<<", 16 | variable_end_string=">>", 17 | ) 18 | 19 | for template_path in workflows_dir.glob("*.yml.j2"): 20 | template = env.get_template(template_path.name) 21 | content = template.render() 22 | 23 | yaml_path = template_path.with_suffix("") 24 | with open(yaml_path, "w") as f: 25 | f.write( 26 | "#" * 80 27 | + "\n# This file is auto-generated from the .j2 file via generate_github_workflows.py. Do not edit manually.\n" 28 | + "#" * 80 29 | + "\n" 30 | ) 31 | f.write(content) 32 | 33 | print(f"Generated {yaml_path} from {template_path}") 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /scripts/models/qwen3-235B-A22B.sh: -------------------------------------------------------------------------------- 1 | # qwen3-235B-a22B 2 | NLAYERS=94 3 | FIRST_K_DENSE_REPLACE=0 4 | 5 | arr=() 6 | for ((i=0; i 25 | sh -c " 26 | cd /data/slime_ci && 27 | /home/runner/config.sh --url ${GITHUB_RUNNER_URL} --token ${GITHUB_RUNNER_TOKEN} --unattended --work /data/slime_ci/runner_$(hostname) --disableupdate && 28 | /home/runner/run.sh 29 | " 30 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/arguments.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from megatron.training.arguments import parse_args, validate_args 4 | from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding 5 | 6 | __all__ = ["validate_args", "parse_args", "set_default_megatron_args"] 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def set_default_megatron_args(args): 12 | # always use zero optimizer 13 | args.use_distributed_optimizer = True 14 | # TODO: maybe change this after megatron has good fp8 support 15 | args.bf16 = not args.fp16 16 | # placeholders 17 | args.seq_length = 4096 18 | args.max_position_embeddings = args.seq_length 19 | # compatible for megatron 20 | if hasattr(args, "rope_type") and args.rope_type is None: 21 | args.rope_type = "yarn" if args.multi_latent_attention else "rope" 22 | 23 | if args.vocab_size and not args.padded_vocab_size: 24 | args.padded_vocab_size = _vocab_size_with_padding(args.vocab_size, args) 25 | 26 | if not args.tokenizer_model and not args.tokenizer_type: 27 | logger.info("--tokenizer-model not set, use --hf-checkpoint as tokenizer model.") 28 | args.tokenizer_model = args.hf_checkpoint 29 | args.tokenizer_type = "HuggingFaceTokenizer" 30 | return args 31 | -------------------------------------------------------------------------------- /scripts/models/glm4.5-355B-A32B.sh: 
-------------------------------------------------------------------------------- 1 | N_DENSE_LAYERS=3 2 | N_MOE_LAYERS=89 3 | 4 | # glm4.5-355B-A32B 5 | MODEL_ARGS=( 6 | --disable-bias-linear 7 | --qk-layernorm 8 | --group-query-attention 9 | --num-attention-heads 96 10 | --num-query-groups 8 11 | --kv-channels 128 12 | --num-layers $((N_DENSE_LAYERS + N_MOE_LAYERS)) 13 | --hidden-size 5120 14 | --ffn-hidden-size 12288 15 | 16 | --add-qkv-bias 17 | --normalization RMSNorm 18 | --position-embedding-type rope 19 | --rotary-percent 0.5 20 | --swiglu 21 | --untie-embeddings-and-output-weights 22 | --vocab-size 151552 23 | 24 | --rotary-base 1000000 25 | 26 | # moe 27 | --moe-ffn-hidden-size 1536 28 | --moe-shared-expert-intermediate-size 1536 29 | --moe-router-pre-softmax 30 | --moe-router-score-function sigmoid 31 | --moe-router-enable-expert-bias 32 | --moe-router-bias-update-rate 0 33 | --moe-router-load-balancing-type seq_aux_loss 34 | --moe-token-dispatcher-type alltoall 35 | --moe-router-topk 8 36 | --moe-router-topk-scaling-factor 2.5 37 | --moe-layer-freq [0]*$N_DENSE_LAYERS+[1]*$N_MOE_LAYERS 38 | --num-experts 160 39 | --moe-grouped-gemm 40 | --moe-router-dtype fp32 41 | --moe-permute-fusion 42 | --moe-aux-loss-coeff 0 43 | ) 44 | -------------------------------------------------------------------------------- /slime/utils/rocm_checkpoint_writer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync 3 | 4 | 5 | class ROCmFileSystemWriterAsync(FileSystemWriterAsync): 6 | """ 7 | FileSystemWriterAsync wrapper for ROCm compatibility. 8 | 9 | On ROCm/HIP, using non_blocking=True causes tensors to be stored in pinned memory, 10 | which triggers segmentation faults when forking subprocesses afterward. 
11 | """ 12 | 13 | @staticmethod 14 | def preload_tensors(*args, **kwargs): 15 | # Change argument non_blocking to False on HIP platform 16 | # The tensors will be stored in pinned memory if non_blocking=True 17 | # Currently on the ROCm platform, forking a subprocess afterward 18 | # with pinned_memory=True will trigger segmentation fault 19 | if torch.version.hip: 20 | print("HIP/ROCm detected: setting non_blocking=False in preload_tensors") 21 | if "non_blocking" in kwargs: 22 | kwargs["non_blocking"] = False 23 | elif len(args) > 1 and isinstance(args[-1], bool): 24 | # non_blocking is typically the last argument 25 | args = args[:-1] + (False,) 26 | 27 | return FileSystemWriterAsync.preload_tensors(*args, **kwargs) 28 | -------------------------------------------------------------------------------- /slime/utils/iter_utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from collections.abc import Callable, Iterable 3 | from typing import Any 4 | 5 | import torch 6 | 7 | 8 | # details: https://stackoverflow.com/questions/773/how-do-i-use-itertools-groupby 9 | def group_by(iterable, key=None): 10 | """Similar to itertools.groupby, but do not require iterable to be sorted""" 11 | ret = defaultdict(list) 12 | for item in iterable: 13 | ret[key(item) if key is not None else item].append(item) 14 | return dict(ret) 15 | 16 | 17 | # TODO fsdp can also use this 18 | def chunk_named_params_by_size(named_params: Iterable[tuple[str, torch.Tensor]], chunk_size: int): 19 | return _chunk_by_size( 20 | named_params, 21 | compute_size=lambda named_weight: named_weight[1].nbytes, 22 | chunk_size=chunk_size, 23 | ) 24 | 25 | 26 | def _chunk_by_size(objects: Iterable[Any], compute_size: Callable[[Any], int], chunk_size: int): 27 | bucket: list[Any] = [] 28 | bucket_size = 0 29 | 30 | for obj in objects: 31 | obj_size = compute_size(obj) 32 | 33 | if bucket and (bucket_size + obj_size) >= chunk_size: 34 | yield bucket 35 | bucket = [] 36 | bucket_size = 0 37 | 38 | bucket.append(obj) 39 | bucket_size += obj_size 40 | 41 | if bucket: 42 | yield bucket 43 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 7 | autoupdate_schedule: quarterly 8 | 9 | repos: 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v4.5.0 12 | hooks: 13 | - id: check-yaml 14 | - id: check-case-conflict 15 | - id: detect-private-key 16 | - id: check-added-large-files 17 | args: ['--maxkb=1000'] 18 | - id: requirements-txt-fixer 19 | 20 | - repo: https://github.com/astral-sh/ruff-pre-commit 21 | rev: v0.14.7 22 | hooks: 23 | - id: ruff-check 24 | args: [ --fix ] 25 | 26 | - repo: https://github.com/PyCQA/autoflake 27 | rev: v2.0.2 28 | hooks: 29 | - id: autoflake 30 | args: [--remove-all-unused-imports, --in-place] 31 | 32 | - repo: https://github.com/pycqa/isort 33 | rev: 5.13.2 # 选一个稳定版本 34 | hooks: 35 | - id: isort 36 | args: 37 | - "--profile=black" # 常见:与 Black 对齐风格 38 | - "--filter-files" # 忽略已在 .gitignore 的文件 39 | additional_dependencies: [] # 需要插件时在这里加 40 | 41 | - repo: https://github.com/psf/black 42 | rev: 24.3.0 43 | hooks: 44 | - id: black 45 | name: Format code 46 | additional_dependencies: ['click==8.0.2'] 47 | 
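# Local usage sketch (assumes pre-commit is installed via pip; the CI workflow
# above runs the same hooks with `pre-commit run --all-files`):
#   pip install pre-commit
#   pre-commit install          # register the git hook once
#   pre-commit run --all-files  # run every hook against the whole repo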
-------------------------------------------------------------------------------- /docs/zh/advanced/speculative-decoding.md: -------------------------------------------------------------------------------- 1 | # 投机采样 2 | 3 | 投机采样是加速 rollout 的重要优化手段。推理过程中不再让昂贵的 Target Model 逐个 token 进行 decode,而是先由一个轻量级的 draft model 先进行 decode,生成多个 token 后,再由大模型进行批量验证。 4 | 5 | ## 使用投机采样加速推理 6 | 7 | 对于有 MTP 层的模型(例如 GLM-4.6、Deepseek-V3/R1),只需要添加: 8 | 9 | ```bash 10 | --sglang-speculative-algorithm EAGLE 11 | --sglang-speculative-num-steps 3 12 | --sglang-speculative-eagle-topk 1 13 | --sglang-speculative-num-draft-tokens 4 14 | ``` 15 | 16 | 如果要使用单独训练的 draft model(例如 [SpecForge](https://docs.sglang.ai/SpecForge/) 训练的),还需要额外设置: 17 | 18 | ```bash 19 | --sglang-speculative-draft-model-path /your/draft/model/path 20 | ``` 21 | 22 | 详细参数含义及配置方法,请参考 SGLang 的 speculative decoding [文档](https://docs.sglang.ai/advanced_features/speculative_decoding.html) 23 | 24 | ## 在线 SFT draft model 25 | 26 | 随着 RL 流程的进行,draft model 和 target model 的采样概率差异逐渐增大,能通过验证的 draft token 逐渐减少,spec 甚至可能造成负收益。 27 | 28 | 目前,slime 支持了在 RL 流程中在线训练 MTP 层,随着训练的进行同步更新 draft model,稳定提高了采样速度,相关原理可参见 [blog](https://www.notion.so/jiajunli-guapisolo/Power-Up-Speculative-Decoding-In-Reinforcement-Learning-2a92d24a293b802d9c73dbae429e581e)。使用方法如下: 29 | 30 | ```bash 31 | --mtp-num-layers 1 32 | --enable-mtp-training 33 | --mtp-loss-scaling-factor 0.2 34 | ``` 35 | 36 | 注意 MTP 训练需要一个包含了 MTP 权重的 checkpoint,所以在将 huggingface checkpoint 转为 torch dist 时,也需要加上 `--mtp-num-layers 1`。 37 | 38 | 外部 draft model 的训练还在 WIP。 39 | -------------------------------------------------------------------------------- /.github/workflows/release-docs.yaml: -------------------------------------------------------------------------------- 1 | name: Release Documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "docs/**" 9 | - "examples/**" 10 | - "version.txt" 11 | workflow_dispatch: 12 | 13 | concurrency: 14 | group: release-docs-${{ github.ref }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | deploy: 19 | runs-on: ubuntu-latest 20 | if: github.repository == 'THUDM/slime' 21 | permissions: 22 | contents: write 23 | steps: 24 | - name: Checkout code 25 | uses: actions/checkout@v4 26 | 27 | - name: Setup Python 28 | uses: actions/setup-python@v5 29 | with: 30 | python-version: '3.13' 31 | 32 | 33 | - name: Install dependencies 34 | run: | 35 | apt-get update && apt-get install -y pandoc parallel retry 36 | pip install -r docs/requirements.txt 37 | 38 | - name: Build documentation 39 | run: | 40 | cd docs 41 | bash ./build.sh en 42 | bash ./build.sh zh 43 | mv ./build/zh ./build/en/ 44 | env: 45 | LC_ALL: "en_US.UTF-8" 46 | LC_CTYPE: "en_US.UTF-8" 47 | 48 | 49 | - name: Deploy 50 | uses: peaceiris/actions-gh-pages@v4 51 | with: 52 | github_token: ${{ secrets.GITHUB_TOKEN }} 53 | publish_dir: ./docs/build/en -------------------------------------------------------------------------------- /examples/eval/scripts/multi_tasks.yaml: -------------------------------------------------------------------------------- 1 | eval: 2 | defaults: 3 | n_samples_per_eval_prompt: 1 4 | temperature: 0.6 5 | top_p: 0.95 6 | top_k: -1 7 | max_response_len: 24576 8 | datasets: # these eval tasks go through slime dataset config and default rollout function (slime.rollout.sglang_rollout.generate_rollout) 9 | - name: gpqa # huggingface-cli download --repo-type dataset zyzshishui0627/gpqa_diamond --local-dir /root/gpqa 10 | path: /root/gpqa/gpqa_eval.jsonl 11 | rm_type: gpqa 12 | 
n_samples_per_eval_prompt: 2 13 | - name: ifbench # huggingface-cli download --repo-type dataset zyzshishui0627/IFBench --local-dir /root/ifbench 14 | path: /root/ifbench/IFBench_eval.jsonl 15 | rm_type: ifbench 16 | n_samples_per_eval_prompt: 1 17 | delegate: # these tasks go through delegate eval function (examples.eval.eval_delegate_rollout.generate_rollout) 18 | - name: skills 19 | # this url should align with env docker network alias 20 | url: http://skills_server:9050/evaluate 21 | timeout_secs: 7200 22 | max_retries: 5 23 | headers: {} 24 | datasets: 25 | - name: aime25 26 | max_response_len: 8192 27 | n_samples_per_eval_prompt: 8 28 | - name: arena-hard 29 | n_samples_per_eval_prompt: 2 30 | - name: hle 31 | max_response_len: 32768 32 | 33 | -------------------------------------------------------------------------------- /slime/utils/memory_utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def clear_memory(clear_host_memory: bool = False): 11 | torch.cuda.synchronize() 12 | gc.collect() 13 | torch.cuda.empty_cache() 14 | if clear_host_memory: 15 | torch._C._host_emptyCache() 16 | 17 | 18 | def available_memory(): 19 | device = torch.cuda.current_device() 20 | free, total = torch.cuda.mem_get_info(device) 21 | return { 22 | "gpu": str(device), 23 | "total_GB": _byte_to_gb(total), 24 | "free_GB": _byte_to_gb(free), 25 | "used_GB": _byte_to_gb(total - free), 26 | "allocated_GB": _byte_to_gb(torch.cuda.memory_allocated(device)), 27 | "reserved_GB": _byte_to_gb(torch.cuda.memory_reserved(device)), 28 | } 29 | 30 | 31 | def _byte_to_gb(n: int): 32 | return round(n / (1024**3), 2) 33 | 34 | 35 | def print_memory(msg, clear_before_print: bool = False): 36 | if clear_before_print: 37 | clear_memory() 38 | 39 | memory_info = available_memory() 40 | # Need to print for all ranks, b/c different rank can have different behaviors 41 | logger.info( 42 | f"[Rank {dist.get_rank()}] Memory-Usage {msg}{' (cleared before print)' if clear_before_print else ''}: {memory_info}" 43 | ) 44 | return memory_info 45 | -------------------------------------------------------------------------------- /slime_plugins/rollout_buffer/README_zh.md: -------------------------------------------------------------------------------- 1 | # Rollout Buffer 2 | 3 | ## 概述 4 | 5 | Rollout Buffer 是用于辅助纯异步 agent 训练的独立组件,其主要功能是使用 slime 训练启动的 LLM OpenAI Server 进行智能体轨迹的生成。 6 | 7 | ### 工作流程 8 | 9 | ``` 10 | slime Training Process ←─── HTTP API ───→ Rollout Buffer 11 | ↓ ↓ 12 | LLM Server ←─────── HTTP Requests ─────── Agent Framework 13 | ↓ ↓ 14 | Model Response ──────────────────────→ Trajectory Generation 15 | ``` 16 | 17 | 对于每一个不同的 Agent 任务,都应该对应一个独立的 Generator 类,负责生成该类任务的轨迹。Rollout Buffer 会自动读取并加载不同类型的 Generator。 18 | 19 | ## 快速开始 20 | 21 | ### 基本使用流程 22 | 23 | 1. **复制模板**:将 `base_generator.py` 作为模板进行复制 24 | 2. **修改任务类型**:将 `TASK_TYPE` 修改为您的任务名称(不能与其他 Generator 重复) 25 | 3. **实现核心函数**:实现 `run_rollout()` 函数 26 | 4. 
**可选定制**:根据需要重写五个可选函数 27 | 28 | 29 | Generator 文件必须以 `_generator.py` 结尾,并放置在 `generator/` 目录下: 30 | 31 | ``` 32 | generator/ 33 | ├── base_generator.py # Math 任务实现(默认模板) 34 | └── your_task_generator.py # 您的自定义任务 35 | ``` 36 | 37 | 每个 Generator 文件必须定义 `TASK_TYPE` 与 `run_rollout()`。 38 | 39 | 此外,Rollout Buffer 还提供了一些可自定义的函数来满足不同任务的特殊需求。如果不提供自定义实现,系统将使用默认实现(位于 `slime_plugins/rollout_buffer/default_func.py`)。 40 | 41 | ### 示例脚本 42 | 43 | 请仿照 [示例:Qwen3-4B 模型](../../docs/zh/models/qwen3-4B.md) 文档中配置好 slime 的运行环境,下载数据,并转换模型 ckpt。之后分别运行 44 | 45 | ```bash 46 | cd slime_plugins/rollout_buffer 47 | bash rollout_buffer_example.sh 48 | 49 | # In a different terminal 50 | python buffer.py 51 | ``` 52 | -------------------------------------------------------------------------------- /scripts/models/qwen3-next-80B-A3B.sh: -------------------------------------------------------------------------------- 1 | NLAYERS=48 2 | FIRST_K_DENSE_REPLACE=0 3 | 4 | arr=() 5 | for ((i=0; i" in response: 6 | model_solution = response.split("")[-1] 7 | elif "###Response" in response: 8 | model_solution = response.split("###Response")[1] 9 | else: 10 | return 0 11 | 12 | model_answer = extract_answer(model_solution) 13 | if model_answer is None: 14 | return 0 15 | if label == "": 16 | return 0 17 | 18 | # Convert single answer to list for uniform processing 19 | assert isinstance(label, (str, float, int)) 20 | ground_truths = [label] 21 | 22 | # Process each ground truth 23 | processed_ground_truths = [] 24 | for truth in ground_truths: 25 | truth = str(truth) 26 | if "\\boxed" in truth: 27 | processed_truth = extract_answer(truth) 28 | if processed_truth is not None: 29 | processed_ground_truths.append(processed_truth) 30 | else: 31 | processed_ground_truths.append(truth) 32 | 33 | if not processed_ground_truths: 34 | return 0 35 | 36 | # Check against all possible correct answers 37 | for ground_truth in processed_ground_truths: 38 | is_correct = grade_answer_mathd(model_answer, ground_truth) or grade_answer_sympy(model_answer, ground_truth) 39 | if is_correct: 40 | return 1 41 | 42 | return 0 43 | -------------------------------------------------------------------------------- /examples/train_infer_mismatch_helper/mis.yaml: -------------------------------------------------------------------------------- 1 | # Enable importance sampling, details refer to the comments of compute_mis_weights in mis.py 2 | use_tis: true 3 | use_rs: true 4 | 5 | # Aggregation level for importance sampling weights: 6 | # token: per-token 7 | # sequence: product over tokens 8 | # geometric: geometric mean 9 | tis_level: "token" 10 | rs_level: "token" 11 | 12 | # Handling mode for IS weights: 13 | # truncate: cap to upper bound, TIS 14 | # mask: zero outside [lower, upper], MIS 15 | # clip: clip to [lower, upper], CIS 16 | tis_mode: "truncate" 17 | 18 | # For clip mode, the lower bound of the IS weights. 19 | # For truncate mode, it will not be used. 20 | # If not set, it will be set to 1.0 / mis_upper_bound 21 | # For Geometry level, the lower bound should be 0.9999 22 | tis_lower_bound: 0.5 23 | 24 | # For truncate or clip mode, the upper bound of the IS weights 25 | # For Geometry level, the upper bound should be 1.0001 26 | tis_upper_bound: 2.0 27 | 28 | # Lower and upper bound for rejection sampling. 29 | # If not set, it will be same as tis_lower_bound and tis_upper_bound. 30 | rs_lower_bound: null 31 | rs_upper_bound: null 32 | 33 | # Per-token veto threshold. 
If any token ratio < this, zero the entire sequence weight, the sequences won't have gradient 34 | # Note: float number must be written with dot e.g. 1.0e-4, not 1e-4 35 | rs_veto_threshold: 1.0e-4 36 | 37 | # Batch normalization: normalize IS weights to mean=1.0 across entire batch 38 | # This reduces variance in gradient updates 39 | tis_batch_normalize: true 40 | -------------------------------------------------------------------------------- /examples/on_policy_distillation/on_policy_distillation.py: -------------------------------------------------------------------------------- 1 | import aiohttp 2 | import torch 3 | 4 | from slime.utils.types import Sample 5 | 6 | 7 | async def reward_func(args, sample, **kwargs): 8 | payload = { 9 | "text": sample.prompt + sample.response, 10 | "sampling_params": { 11 | "temperature": 0, 12 | "max_new_tokens": 0, 13 | "skip_special_tokens": False, 14 | }, 15 | "return_logprob": True, 16 | "logprob_start_len": 0, 17 | } 18 | session_kwargs = {} 19 | async with aiohttp.ClientSession(**session_kwargs) as session: 20 | async with session.post(args.rm_url, json=payload) as resp: 21 | resp.raise_for_status() 22 | return await resp.json() 23 | 24 | 25 | def post_process_rewards(args, samples: list[Sample], **kwargs): 26 | rewards = [sample.get_reward_value(args) for sample in samples] 27 | response_lengths = [sample.response_length for sample in samples] 28 | teacher_log_probs = [ 29 | torch.tensor([item[0] for item in reward["meta_info"]["input_token_logprobs"][1:]], dtype=torch.float32) 30 | for reward in rewards 31 | ] 32 | teacher_log_probs = [ 33 | t_log_prob[-response_length:] 34 | for t_log_prob, response_length in zip(teacher_log_probs, response_lengths, strict=False) 35 | ] 36 | 37 | for sample, t_log_probs in zip(samples, teacher_log_probs, strict=False): 38 | sample.teacher_log_probs = t_log_probs 39 | 40 | return teacher_log_probs, teacher_log_probs 41 | -------------------------------------------------------------------------------- /tests/ci/README.md: -------------------------------------------------------------------------------- 1 | # Doc about CI 2 | 3 | ## Configure GitHub secrets 4 | 5 | https://github.com/slimerl/slime/settings/secrets/actions 6 | 7 | * `WANDB_API_KEY`: get from https://wandb.ai/authorize 8 | 9 | ## Setup new GitHub runners 10 | 11 | ### Step 1: Env 12 | 13 | Write `.env` mimicking `.env.example`. 14 | The token can be found at https://github.com/slimerl/slime/settings/actions/runners/new?arch=x64&os=linux. 15 | 16 | WARN: The `GITHUB_RUNNER_TOKEN` changes after a while. 17 | 18 | ### Step 2: Prepare `/home/runner/externals` 19 | 20 | ```shell 21 | docker run --rm -it --privileged --pid=host -v /:/host_root ubuntu /bin/bash -c 'rm -rf /host_root/home/runner/externals && mkdir -p /host_root/home/runner/externals && chmod -R 777 /host_root/home/runner/externals' 22 | docker run -d --name temp-runner ghcr.io/actions/actions-runner:2.328.0 tail -f /dev/null 23 | docker cp temp-runner:/home/runner/externals/. 
/home/runner/externals 24 | docker rm -f temp-runner 25 | ls -alh /home/runner/externals 26 | ``` 27 | 28 | ### Step 3: Run 29 | 30 | ```shell 31 | cd /mnt/data/tom/primary_synced/slime/tests/ci/github_runner 32 | docker compose up -d 33 | ``` 34 | 35 | ### Debugging 36 | 37 | Logs 38 | 39 | ```shell 40 | # All containers 41 | docker compose logs -f 42 | 43 | # One container 44 | docker logs -f github_runner-runner-1 45 | ``` 46 | 47 | Exec 48 | 49 | ```shell 50 | docker exec -it github_runner-runner-1 /bin/bash 51 | ``` 52 | 53 | An example of quickly iterate 54 | 55 | ```shell 56 | docker compose down -v && docker compose up -d && docker logs -f github_runner-runner-1 57 | ``` 58 | -------------------------------------------------------------------------------- /examples/tau-bench/tau1_mock.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | from tau_bench.envs import get_env 6 | from tau_bench.types import RunConfig 7 | 8 | ALL_DATA_MAPPINGS = {"retail": ["train", "test", "dev"], "airline": ["test"]} 9 | 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser(description="Tau1 Mock Script") 13 | parser.add_argument("--local_dir", required=True, help="Path to the local directory") 14 | args = parser.parse_args() 15 | 16 | local_dir = args.local_dir 17 | if not os.path.isdir(local_dir): 18 | os.makedirs(local_dir) 19 | config = RunConfig(model_provider="mock", user_model_provider="mock", user_strategy="human", model="mock") 20 | for env, split in ALL_DATA_MAPPINGS.items(): 21 | for s in split: 22 | config.env = env 23 | config.task_split = s 24 | env_instance = get_env( 25 | env_name=config.env, 26 | user_strategy=config.user_strategy, 27 | user_model=config.user_model, 28 | task_split=config.task_split, 29 | ) 30 | output_path = os.path.join(local_dir, f"{env}_{s}_tasks.jsonl") 31 | with open(output_path, "w") as f: 32 | for i, task in enumerate(env_instance.tasks): 33 | row = {"index": i, "metadata": task.model_dump()} 34 | f.write(json.dumps(row) + "\n") # <-- one JSON object per line 35 | print(f"Saved preprocessed task indices for {env} ({s}) to {output_path}") 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /examples/multi_agent/README.md: -------------------------------------------------------------------------------- 1 | # Multi-Agent RL 2 | 3 | This directory provides an example of running multi-agent reinforcement learning (RL) with slime. 4 | 5 | ## Environment Setup 6 | 7 | The environment setup is identical to the standard RL setup used in slime. 8 | 9 | ## Running the Script 10 | 11 | You can either define your own multi-agent system or use the provided default configuration. 12 | 13 | ```python 14 | MULTI_AGENT_CONFIGS = { 15 | "custom_multi_agent_function_path": "examples.multi_agent.agent_system.run_agent_system", 16 | "num_parallel": 5, 17 | "incorrect_reward_weight": 0.8, 18 | "correct_reward_weight": 1.2, 19 | } 20 | ``` 21 | 22 | To start a run, execute: 23 | 24 | ```bash 25 | cd slime/ 26 | bash examples/multi_agent/run-qwen3-30B-A3B-multi-agent.sh 27 | ``` 28 | 29 | ## New Arguments 30 | 31 | - Specify the agent rollout function with the `--custom-generate-function-path` argument. 32 | - Set the `--rollout-max-context-len` argument according to your model’s context window. 
33 | 34 | ```bash 35 | ROLLOUT_ARGS=( 36 | --custom-generate-function-path examples.multi_agent.rollout_with_multi_agents.generate_with_multi_agents 37 | --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl 38 | --input-key prompt 39 | --label-key label 40 | --apply-chat-template 41 | --rollout-shuffle 42 | --rm-type deepscaler 43 | --num-rollout 3000 44 | --rollout-batch-size 32 45 | --n-samples-per-prompt 8 46 | --rollout-max-context-len 16384 47 | --rollout-max-response-len 8192 48 | --rollout-temperature 0.8 49 | 50 | --global-batch-size 256 51 | --balance-data 52 | ) 53 | ``` -------------------------------------------------------------------------------- /docs/zh/index.rst: -------------------------------------------------------------------------------- 1 | slime 文档 2 | ==================== 3 | 4 | slime 是一个面向 RL Scaling 的 LLM 后训练框架,提供两大核心能力: 5 | 6 | - 高性能训练:通过连接 Megatron 与 SGLang,支持多种模式下的高效训练; 7 | - 灵活的数据生成:通过自定义数据生成接口与基于服务器的引擎,实现任意训练数据生成流程。 8 | 9 | slime 是 GLM-4.5 与 GLM-4.6 背后的 RL 训练框架。除此之外,slime 还支持: 10 | 11 | - Qwen3 系列 (Qwen3Next, Qwen3MoE, Qwen3), Qwen2.5 系列; 12 | - DeepSeek V3 系列 (DeepSeek V3, V3.1, DeepSeek R1); 13 | - Llama 3。 14 | 15 | .. toctree:: 16 | :maxdepth: 1 17 | :caption: 开始使用 18 | 19 | get_started/quick_start.md 20 | get_started/usage.md 21 | get_started/qa.md 22 | 23 | .. toctree:: 24 | :maxdepth: 1 25 | :caption: Dense 26 | 27 | examples/qwen3-4B.md 28 | examples/glm4-9B.md 29 | 30 | .. toctree:: 31 | :maxdepth: 1 32 | :caption: MoE 33 | 34 | examples/qwen3-30B-A3B.md 35 | examples/glm4.5-355B-A32B.md 36 | examples/deepseek-r1.md 37 | 38 | .. toctree:: 39 | :maxdepth: 1 40 | :caption: 高级特性 41 | 42 | _examples_synced/reproducibility/README.md 43 | advanced/speculative-decoding.md 44 | advanced/fault-torlance.md 45 | advanced/arch-support-beyond-megatron.md 46 | 47 | .. toctree:: 48 | :maxdepth: 1 49 | :caption: 其他用法 50 | 51 | examples/qwen3-4b-base-openhermes.md 52 | _examples_synced/search-r1/README.md 53 | _examples_synced/fully_async/README.md 54 | _examples_synced/retool/README.md 55 | _examples_synced/multi_agent/README.md 56 | 57 | .. toctree:: 58 | :maxdepth: 1 59 | :caption: 开发指南 60 | 61 | developer_guide/debug.md 62 | 63 | .. toctree:: 64 | :maxdepth: 1 65 | :caption: 博客 66 | 67 | blogs/release_v0.1.0.md 68 | blogs/introducing_slime.md 69 | -------------------------------------------------------------------------------- /docker/justfile: -------------------------------------------------------------------------------- 1 | release-primary: 2 | ARG_TAG_POSTFIX="" ARG_BUILD_EXTRA_ARGS="" just _release-raw 3 | 4 | # Should be executed on ARM machines 5 | release-cu129-arm64: 6 | ARG_TAG_POSTFIX="-cu129-arm64" ARG_BUILD_EXTRA_ARGS='--build-arg SGLANG_IMAGE_TAG=v0.5.5.post3-cu129-arm64 --build-arg ENABLE_SGLANG_PATCH=0' just _release-raw 7 | 8 | # Should be executed on ARM machines 9 | release-cu13-arm64: 10 | ARG_TAG_POSTFIX="-cu13-arm64" ARG_BUILD_EXTRA_ARGS='--build-arg SGLANG_IMAGE_TAG=dev-arm64-cu13-20251122 --build-arg ENABLE_CUDA_13=1 --build-arg ENABLE_SGLANG_PATCH=0' just _release-raw 11 | 12 | _release-raw: 13 | #!/bin/bash 14 | set -euxo pipefail 15 | cd .. 16 | 17 | VERSION="$(cat docker/version.txt | tr -d '\n')" 18 | IMAGE_TAG=${VERSION}${ARG_TAG_POSTFIX} 19 | 20 | docker build -f docker/Dockerfile . 
--build-arg HTTP_PROXY="$http_proxy" --build-arg HTTPS_PROXY="$https_proxy" --build-arg NO_PROXY="localhost,127.0.0.1" $ARG_BUILD_EXTRA_ARGS -t slimerl/slime:$IMAGE_TAG 21 | docker push slimerl/slime:$IMAGE_TAG 22 | 23 | if [ -z "${ARG_TAG_POSTFIX}" ]; then 24 | docker tag slimerl/slime:$IMAGE_TAG slimerl/slime:latest 25 | docker push slimerl/slime:latest 26 | fi 27 | 28 | debug: 29 | #!/bin/bash 30 | set -euxo pipefail 31 | cd .. 32 | 33 | VERSION="$(cat docker/version.txt | tr -d '\n')" 34 | IMAGE_TAG=${VERSION} 35 | 36 | docker build -f docker/Dockerfile . --build-arg HTTP_PROXY="$http_proxy" --build-arg HTTPS_PROXY="$https_proxy" --build-arg NO_PROXY="localhost,127.0.0.1" -t slimerl/slime-test:$IMAGE_TAG 37 | docker push slimerl/slime-test:$IMAGE_TAG 38 | -------------------------------------------------------------------------------- /docs/build_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -euo pipefail 3 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 4 | cd "$SCRIPT_DIR" 5 | 6 | echo "[slime-docs] Building EN..." 7 | ./build.sh en 8 | echo "[slime-docs] Building ZH..." 9 | ./build.sh zh 10 | 11 | # Create a lightweight root index with auto redirect based on localStorage (done client side) 12 | ROOT_INDEX=build/index.html 13 | cat > "$ROOT_INDEX" <<'EOF' 14 | 15 | 16 | 17 | 18 | slime docs 19 | 20 | 26 | 34 | 35 | 36 |

<h1>slime Documentation</h1>
<p>Select language:</p>
<p><a href="en/">English</a> | <a href="zh/">中文</a></p>
<p>Auto-redirect uses your last choice if stored; else pick above.</p>

40 | 41 | 42 | EOF 43 | 44 | echo "[slime-docs] Done. Root landing page at build/index.html" -------------------------------------------------------------------------------- /slime/utils/debug_utils/replay_reward_fn.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Annotated 3 | 4 | import ray 5 | import torch 6 | import typer 7 | 8 | from slime.utils.misc import load_function 9 | from slime.utils.types import Sample 10 | 11 | 12 | def _truncate(text, max_len=200): 13 | """Truncate text and add ellipsis if too long.""" 14 | if text is None: 15 | return None 16 | text = str(text).replace("\n", "\\n") 17 | if len(text) > max_len: 18 | return text[:max_len] + "..." 19 | return text 20 | 21 | 22 | def main( 23 | rollout_data_path: Annotated[str, typer.Option()], 24 | custom_rm_path: Annotated[str, typer.Option()], 25 | ): 26 | if not ray.is_initialized(): 27 | ray.init() 28 | 29 | pack = torch.load(rollout_data_path) 30 | samples = [Sample.from_dict(s) for s in pack["samples"]] 31 | asyncio.run(_main_async(samples=samples, custom_rm_path=custom_rm_path)) 32 | 33 | 34 | async def _main_async(samples, custom_rm_path): 35 | rm_function = load_function(custom_rm_path) 36 | rewards = await asyncio.gather(*[rm_function(None, sample) for sample in samples]) 37 | 38 | for i, (sample, reward) in enumerate(zip(samples, rewards, strict=True)): 39 | print("-" * 60) 40 | print(f"Sample {i + 1}/{len(samples)}") 41 | print(f" Index: {sample.index}") 42 | print(f" Status: {sample.status}") 43 | print(f" Reward: {reward}") 44 | print(f" Prompt: {_truncate(sample.prompt, 200)}") 45 | print(f" Response: {_truncate(sample.response, 200)}") 46 | print("-" * 60) 47 | 48 | 49 | if __name__ == "__main__": 50 | typer.run(main) 51 | -------------------------------------------------------------------------------- /slime/rollout/rm_hub/f1.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | from collections import Counter 4 | 5 | 6 | def normalize_answer(s): 7 | 8 | def remove_articles(text): 9 | return re.sub(r"\b(a|an|the)\b", " ", text) 10 | 11 | def white_space_fix(text): 12 | return " ".join(text.split()) 13 | 14 | def remove_punc(text): 15 | exclude = set(string.punctuation) 16 | return "".join(ch for ch in text if ch not in exclude) 17 | 18 | def lower(text): 19 | return text.lower() 20 | 21 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 22 | 23 | 24 | def f1_score(prediction, ground_truth): 25 | ZERO_METRIC = (0, 0, 0) 26 | 27 | if prediction is None: 28 | return ZERO_METRIC 29 | 30 | normalized_prediction = normalize_answer(prediction) 31 | normalized_ground_truth = normalize_answer(ground_truth) 32 | 33 | if normalized_prediction in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth: 34 | return ZERO_METRIC 35 | if normalized_ground_truth in ["yes", "no", "noanswer"] and normalized_prediction != normalized_ground_truth: 36 | return ZERO_METRIC 37 | 38 | prediction_tokens = normalized_prediction.split() 39 | ground_truth_tokens = normalized_ground_truth.split() 40 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 41 | num_same = sum(common.values()) 42 | if num_same == 0: 43 | return ZERO_METRIC 44 | precision = 1.0 * num_same / len(prediction_tokens) 45 | recall = 1.0 * num_same / len(ground_truth_tokens) 46 | f1 = (2 * precision * recall) / (precision + recall) 47 | return f1, precision, recall 48 | 
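# Usage sketch (illustrative, not part of the module): f1_score returns a
# (f1, precision, recall) tuple, falling back to ZERO_METRIC = (0, 0, 0) when
# the prediction is None, when a yes/no/noanswer label disagrees with the
# prediction, or when there is no token overlap after normalization.
#
#   f1, precision, recall = f1_score("Paris is the capital", "the capital is Paris")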
-------------------------------------------------------------------------------- /docs/zh/advanced/arch-support-beyond-megatron.md: -------------------------------------------------------------------------------- 1 | # 在 Megatron-LM 中快速支持新模型架构 2 | 3 | Megatron-LM 框架虽然并行效率高,但在支持日新月异的新模型架构(如 Qwen3Next)时,其灵活性有所欠缺。若要原生支持这些模型的特殊结构(例如 Gated-Delta-Net),往往需要对 Megatron 的核心代码进行侵入性较大、开发周期较长的改造。 4 | 5 | 为了能快速跟进这些前沿模型,`slime` 提出了一种更敏捷的方案:**与其深度改造 Megatron,不如直接引入并封装模型官方的 HuggingFace 实现**,将其作为一个“黑盒模块”无缝嵌入到 Megatron 的并行训练流程中。 6 | 7 | 本文以 Qwen3Next 80B-A3B 为例,介绍这一实现思路。 8 | 9 | ## 实现原理与核心组件 10 | 11 | Megatron 的模型实例化分为两步:首先根据配置生成“层规格”(`ModuleSpec`),再依据该规格实例化具体的 PyTorch 模块。 12 | 13 | `slime` 正是利用这一机制,在**生成 Spec 的阶段“劫持”并替换掉 Megatron 的原生模块**,从而将外部实现(此处为 HuggingFace 模块)无缝嵌入。这一过程主要涉及三个核心组件的协同: 14 | 15 | 1. **替换 Megatron 模块规格 (Spec)** 16 | 这是整个方案的入口。我们通过一个自定义函数(例如 `get_qwen3_next_spec`)来修改标准的 `ModuleSpec`,用我们自己的封装层换掉 Megatron 的原生 Attention 层。 17 | * **具体操作**:获取标准的 Decoder Block Spec,将其 `self_attention` 字段指向我们的自定义模块,并按需开启 `qk_layernorm` 等模型特有配置。 18 | * **对应文件**: `slime_plugins/models/qwen3_next.py` 19 | 20 | 2. **封装 HuggingFace 实现** 21 | 上一步的 Spec 会指向一个封装层,例如 `HuggingfaceAttention`。它继承了 Megatron 的 `MegatronModule`,核心职责是作为桥梁,处理好并行策略所需的数据对齐(如序列并行),然后在内部直接调用从 HuggingFace 加载的原生 `Qwen3NextAttention` 模块。 22 | * **对应文件**: `slime_plugins/models/hf_attention.py` 23 | 24 | 3. **对齐模型权重** 25 | 模型结构跑通后,还需要确保权重能正确加载。我们借助 [mbridge](https://github.com/ISEEKYAN/mbridge) 库,通过 `Qwen3NextBridge` 建立了 HuggingFace Checkpoint 与 Megatron 参数之间的命名映射关系,实现双向互通。 26 | * **对应文件**: `slime_plugins/mbridge/qwen3_next.py` 27 | 28 | 通过这三层协同,我们成功地将一个 Megatron 原本不支持的复杂模型结构(以其 HuggingFace 实现为载体),运行在了 Megatron 的并行框架之上,并完整保留了模型并行、MoE 加速、流水线调度等全部关键能力。 29 | 30 | ## 当前限制 31 | 32 | * 本方案暂不支持被替换模块(如此处的 Attention 层)自身的张量并行(TP)。 33 | * **影响**:在大多数大规模 MoE 模型中,Attention 层的参数量占比较小,因此该限制对显存占用和训练吞吐的影响通常有限。 34 | * **替代方案**:如果该模块的 TP 至关重要,则需要回归到侵入式修改 Megatron 的原生实现方案。 35 | -------------------------------------------------------------------------------- /scripts/models/kimi-k2.sh: -------------------------------------------------------------------------------- 1 | NLAYERS=61 2 | FIRST_K_DENSE_REPLACE=1 3 | 4 | arr=() 5 | for ((i=0; i=2.0", 42 | ] 43 | }, 44 | python_requires=">=3.10", 45 | classifiers=[ 46 | "Programming Language :: Python :: 3.10", 47 | "Programming Language :: Python :: 3.11", 48 | "Programming Language :: Python :: 3.12", 49 | "Environment :: GPU :: NVIDIA CUDA", 50 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 51 | "Topic :: System :: Distributed Computing", 52 | ], 53 | cmdclass={"bdist_wheel": bdist_wheel}, 54 | ) 55 | -------------------------------------------------------------------------------- /docs/_static/css/custom_log.css: -------------------------------------------------------------------------------- 1 | .output_area { 2 | color: #615656; 3 | } 4 | 5 | table.autosummary td { 6 | width: 50% 7 | } 8 | 9 | img.align-center { 10 | display: block; 11 | margin-left: auto; 12 | margin-right: auto; 13 | } 14 | 15 | .output_area.stderr { 16 | color: #d3d3d3 !important; 17 | } 18 | 19 | .output_area.stdout { 20 | color: #d3d3d3 !important; 21 | } 22 | 23 | div.output_area.stderr { 24 | color: #d3d3d3 !important; 25 | } 26 | 27 | div.output_area.stdout { 28 | color: #d3d3d3 !important; 29 | } 30 | 31 | /* Language toggle button styling */ 32 | .lang-toggle-btn { 33 | --lt-border: var(--pst-color-border, #d0d7de); 34 | --lt-bg: var(--pst-color-surface, #f6f8fa); 35 | --lt-bg-hover: var(--pst-color-on-surface, #e6ebf1); 36 | 
--lt-active: var(--pst-color-primary, #0969da); 37 | display: inline-flex; 38 | align-items: center; 39 | padding: 2px 10px; 40 | line-height: 1.1; 41 | font-size: 0.72rem; 42 | font-weight: 600; 43 | border: 1px solid var(--lt-border); 44 | border-radius: 6px; 45 | background: var(--lt-bg); 46 | cursor: pointer; 47 | gap: 2px; 48 | letter-spacing: .5px; 49 | } 50 | .lang-toggle-btn:hover { 51 | background: var(--lt-bg-hover); 52 | } 53 | .lang-toggle-btn .lang-seg { 54 | opacity: .55; 55 | transition: opacity .15s; 56 | } 57 | .lang-toggle-btn[data-current="en"] .lang-seg[data-lang="en"], 58 | .lang-toggle-btn[data-current="zh"] .lang-seg[data-lang="zh"] { 59 | opacity: 1; 60 | color: var(--lt-active); 61 | } 62 | .lang-toggle-btn .lang-sep { opacity: .35; } 63 | 64 | @media (prefers-color-scheme: dark) { 65 | .lang-toggle-btn { --lt-border: #30363d; --lt-bg:#161b22; --lt-bg-hover:#1c2128; } 66 | .lang-toggle-btn .lang-seg { color: #adbac7; } 67 | } 68 | -------------------------------------------------------------------------------- /scripts/models/deepseek-v3.sh: -------------------------------------------------------------------------------- 1 | NLAYERS="${MODEL_ARGS_NUM_LAYERS:-61}" 2 | FIRST_K_DENSE_REPLACE=3 3 | 4 | arr=() 5 | for ((i=0; i= 11: 24 | - cc_flag.append('-gencode') 25 | - cc_flag.append('arch=compute_80,code=sm_80') 26 | - if int(bare_metal_minor) >= 8: 27 | + if torch.cuda.is_available() and torch.version.cuda: 28 | + # Check if cuda 11 is installed for compute capability 8.0 29 | + cc_flag = [] 30 | + _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( 31 | + cpp_extension.CUDA_HOME 32 | + ) 33 | + if int(bare_metal_major) >= 11: 34 | cc_flag.append('-gencode') 35 | - cc_flag.append('arch=compute_90,code=sm_90') 36 | + cc_flag.append('arch=compute_80,code=sm_80') 37 | + if int(bare_metal_minor) >= 8: 38 | + cc_flag.append('-gencode') 39 | + cc_flag.append('arch=compute_90,code=sm_90') 40 | 41 | - # Build path 42 | - srcpath = pathlib.Path(__file__).parent.absolute() 43 | - buildpath = srcpath / "build" 44 | - _create_build_dir(buildpath) 45 | + # Build path 46 | + srcpath = pathlib.Path(__file__).parent.absolute() 47 | + buildpath = srcpath / "build" 48 | + _create_build_dir(buildpath) 49 | 50 | # Helper function to build the kernels. 
51 | def _cpp_extention_load_helper(name, sources, extra_cuda_flags): 52 | -------------------------------------------------------------------------------- /scripts/models/moonlight.sh: -------------------------------------------------------------------------------- 1 | MOE_SHARED_EXPERTS=2 2 | MOE_FFN_HIDDEN=1408 3 | MOE_SHARED_EXPERT_INTERMEDIATE_SIZE=$(($MOE_FFN_HIDDEN * $MOE_SHARED_EXPERTS)) 4 | MOE_ROUTER_TOPK_SCALING_FACTOR=2.446 5 | NLAYERS=27 6 | FIRST_K_DENSE_REPLACE=1 7 | 8 | arr=() 9 | for ((i=0; i= 11: 24 | - cc_flag.append('-gencode') 25 | - cc_flag.append('arch=compute_80,code=sm_80') 26 | - if int(bare_metal_minor) >= 8: 27 | + if torch.cuda.is_available() and torch.version.cuda: 28 | + # Check if cuda 11 is installed for compute capability 8.0 29 | + cc_flag = [] 30 | + _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( 31 | + cpp_extension.CUDA_HOME 32 | + ) 33 | + if int(bare_metal_major) >= 11: 34 | cc_flag.append('-gencode') 35 | - cc_flag.append('arch=compute_90,code=sm_90') 36 | + cc_flag.append('arch=compute_80,code=sm_80') 37 | + if int(bare_metal_minor) >= 8: 38 | + cc_flag.append('-gencode') 39 | + cc_flag.append('arch=compute_90,code=sm_90') 40 | 41 | - # Build path 42 | - srcpath = pathlib.Path(__file__).parent.absolute() 43 | - buildpath = srcpath / "build" 44 | - _create_build_dir(buildpath) 45 | + # Build path 46 | + srcpath = pathlib.Path(__file__).parent.absolute() 47 | + buildpath = srcpath / "build" 48 | + _create_build_dir(buildpath) 49 | 50 | # Helper function to build the kernels. 51 | def _cpp_extention_load_helper(name, sources, extra_cuda_flags): 52 | -------------------------------------------------------------------------------- /docs/en/index.rst: -------------------------------------------------------------------------------- 1 | slime Documentation 2 | ==================== 3 | 4 | slime is an LLM post-training framework for RL scaling, providing two core capabilities: 5 | 6 | - High-Performance Training: Supports efficient training in various modes by connecting Megatron with SGLang; 7 | - Flexible Data Generation: Enables arbitrary training data generation workflows through custom data generation interfaces and server-based engines. 8 | 9 | slime is the RL-framework behind GLM-4.5 and GLM-4.6. Apart from models from Z.ai, we also supports the following models: 10 | 11 | - Qwen3 series (Qwen3Next, Qwen3MoE, Qwen3), Qwen2.5 series; 12 | - DeepSeek V3 series (DeepSeek V3, V3.1, DeepSeek R1); 13 | - Llama 3. 14 | 15 | .. toctree:: 16 | :maxdepth: 1 17 | :caption: Get Started 18 | 19 | get_started/quick_start.md 20 | get_started/usage.md 21 | get_started/qa.md 22 | 23 | .. toctree:: 24 | :maxdepth: 1 25 | :caption: Dense 26 | 27 | examples/qwen3-4B.md 28 | examples/glm4-9B.md 29 | 30 | .. toctree:: 31 | :maxdepth: 1 32 | :caption: MoE 33 | 34 | examples/qwen3-30B-A3B.md 35 | examples/glm4.5-355B-A32B.md 36 | examples/deepseek-r1.md 37 | 38 | .. toctree:: 39 | :maxdepth: 1 40 | :caption: Advanced Features 41 | 42 | _examples_synced/reproducibility/README.md 43 | advanced/speculative-decoding.md 44 | advanced/fault-tolerance.md 45 | advanced/arch-support-beyond-megatron.md 46 | 47 | .. toctree:: 48 | :maxdepth: 1 49 | :caption: Other Usage 50 | 51 | examples/qwen3-4b-base-openhermes.md 52 | _examples_synced/search-r1/README.md 53 | _examples_synced/fully_async/README.md 54 | _examples_synced/retool/README.md 55 | _examples_synced/multi_agent/README.md 56 | 57 | .. 
toctree:: 58 | :maxdepth: 1 59 | :caption: Developer Guide 60 | 61 | developer_guide/debug.md 62 | 63 | .. toctree:: 64 | :maxdepth: 1 65 | :caption: Hardware Platforms 66 | 67 | platform_support/amd_tutorial.md 68 | 69 | .. toctree:: 70 | :maxdepth: 1 71 | :caption: Blogs 72 | 73 | blogs/release_v0.1.0.md 74 | blogs/introducing_slime.md 75 | -------------------------------------------------------------------------------- /docs/zh/developer_guide/debug.md: -------------------------------------------------------------------------------- 1 | # Debug 指南 2 | 3 | ## 对齐精度 4 | 5 | 在开发 slime 的过程中,经常会需要检查模型的精度是否正确,可以通过以下方式检查: 6 | 7 | 1. 训练第一步 8 | 1. rollout 的生成是否是人话,如果不是,有以下 2 种可能: 9 | - 参数没有正常加载。需要查看是否有 megatron 成功加载 ckpt 的日志; 10 | - 更新参数有误。可以查看是不是所有的参数都做了转换和参数对应,或者参数名是不是根据并行做了转换(例如 pp_size > 1 时,第二个 stage 提供的参数的 layer id 是不是正确的)。一个比较彻底的方法是在对应模型的 sglang 实现的 `load_weights` 中保存所有的参数,查看和加载的 ckpt 中是否一致; 11 | - 如果所有参数更新都正确,还出现问题,有可能是 sglang 里有一些特殊的 buffer 在 release 的时候被释放了; 12 | - 如果是用 pretrain 模型进行的测试,可以换成同结构模型的 instruct 版本,查看这种乱码是不是 pretrain 模型特有的。 13 | 2. 查看打印的 rollout stats 的 `log_probs` 和 `ref_log_probs` 是否完全相等(即第一步 kl=0),且值较小 14 | - 如果不是完全相等的,一般是 transformer engine 中的某些 non-deterministic kernel 导致的,例如: 15 | - 在某些版本的 te 里,megatron 需要 `--attention-backend flash`,来强制使用 flash attention,从而避免 CP 下 fused attention 的数值不稳定; 16 | - 如果数值较大(例如 >1),一般有 2 种可能: 17 | - 如果值非常大,应该是训练配置有问题; 18 | - 如果值只是比 sft loss 的状态略大,例如 instruct 模型的 logprob 到了 0.8,有可能是数据不符合训练的 chat template,或者不符合冷启动的分布。 19 | 3. 查看在推一训一(`num_steps_per_rollout == 1`),kl 是否为 0,grad_norm 是否较小 20 | - 基本上就是一些 megatron / te 相关的 bug,例如: 21 | - moe 需要开启 `--moe-permute-fusion`。 22 | 23 | 2. 训练第二步 24 | 1. 对于训推一体,查看是否能正确加载第二步,是否会 OOM; 25 | 26 | ## 训练推理单独 debug 27 | 28 | slime 支持将训练部分和推理部分分开进行调试,从而实现: 29 | 30 | - 在调优/debug 推理部分时,只用少量卡就可以启动任务; 31 | - 在调优/debug 训练部分时,可以保证模型输入固定,去除 rollout 的随机性。 32 | 33 | 具体来说,目前 slime 提供了如下的参数来进行分离调试: 34 | 35 | 1. `--debug-rollout-only` 36 | 37 | 开启后,slime 将不会加载 megatron,只初始化 sglang ,可以用这个方法来进行推理部分的调试。 38 | 39 | 1. `--debug-train-only` 40 | 41 | 开启后,slime 将不会加载 sglang,只初始化 megatron ,可以用这个方法来进行训练部分的调试。 42 | 43 | 2. `--save-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt` 44 | 45 | 开启后,会保存每次 rollout 的结果,可以和 `--debug-rollout-only` 配合使用。注意保存的方式为 `args.save_debug_rollout_data.format(rollout_id=rollout_id)`。 46 | 47 | 3. 
`--load-debug-rollout-data /your/saved/debug/data_{rollout_id}.pt` 48 | 49 | 开启后,会从 `args.load_debug_rollout_data.format(rollout_id=rollout_id)` 来加载数据,并且不会初始化 sglang(自动设置 `debug_train_only=True`)。可以以这种方式来固定训练部分的输入,对训练部分进行调优,例如切换各种并行。 50 | -------------------------------------------------------------------------------- /slime/utils/train_metric_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from argparse import Namespace 3 | from collections.abc import Callable 4 | from copy import deepcopy 5 | 6 | from slime.utils import tracking_utils 7 | from slime.utils.metric_utils import compute_rollout_step 8 | from slime.utils.timer import Timer 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def log_perf_data_raw( 14 | rollout_id: int, args: Namespace, is_primary_rank: bool, compute_total_fwd_flops: Callable 15 | ) -> None: 16 | timer_instance = Timer() 17 | log_dict_raw = deepcopy(timer_instance.log_dict()) 18 | timer_instance.reset() 19 | 20 | if not is_primary_rank: 21 | return 22 | 23 | log_dict = {f"perf/{key}_time": val for key, val in log_dict_raw.items()} 24 | 25 | if ("perf/actor_train_time" in log_dict) and (compute_total_fwd_flops is not None): 26 | total_fwd_flops = compute_total_fwd_flops(seq_lens=timer_instance.seq_lens) 27 | 28 | if "perf/log_probs_time" in log_dict: 29 | log_dict["perf/log_probs_tflops"] = total_fwd_flops / log_dict["perf/log_probs_time"] 30 | 31 | if "perf/ref_log_probs_time" in log_dict: 32 | log_dict["perf/ref_log_probs_tflops"] = total_fwd_flops / log_dict["perf/ref_log_probs_time"] 33 | 34 | if log_dict["perf/actor_train_time"] > 0: 35 | log_dict["perf/actor_train_tflops"] = 3 * total_fwd_flops / log_dict["perf/actor_train_time"] 36 | log_dict["perf/actor_train_tok_per_s"] = sum(timer_instance.seq_lens) / log_dict["perf/actor_train_time"] 37 | 38 | if "perf/train_wait_time" in log_dict and "perf/train_time" in log_dict: 39 | total_time = log_dict["perf/train_wait_time"] + log_dict["perf/train_time"] 40 | if total_time > 0: 41 | log_dict["perf/step_time"] = total_time 42 | log_dict["perf/wait_time_ratio"] = log_dict["perf/train_wait_time"] / total_time 43 | 44 | logger.info(f"perf {rollout_id}: {log_dict}") 45 | 46 | step = compute_rollout_step(args, rollout_id) 47 | log_dict["rollout/step"] = step 48 | tracking_utils.log(args, log_dict, step_key="rollout/step") 49 | -------------------------------------------------------------------------------- /examples/eval/nemo_skills/skills_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Mapping 4 | from dataclasses import dataclass, field 5 | from typing import Any 6 | 7 | from examples.eval.eval_delegate import EvalEnvConfig, EvalEnvDatasetConfig 8 | 9 | 10 | @dataclass 11 | class SkillsEvalEnvDatasetConfig(EvalEnvDatasetConfig): 12 | """Dataset configuration shared by the Skills client/server.""" 13 | 14 | def __post_init__(self): 15 | name = (self.name or "").strip() 16 | self.name = name 17 | if not name: 18 | raise ValueError("Each Skills dataset entry must include a non-empty `name`.") 19 | if ":" in name: 20 | raise ValueError( 21 | "Colon in dataset name is not allowed; use `n_samples_per_eval_prompt` to configure samples per prompt." 
22 | ) 23 | 24 | @property 25 | def runtime_name(self) -> str: 26 | if self.n_samples_per_eval_prompt is None: 27 | return self.name 28 | return f"{self.name}:{self.n_samples_per_eval_prompt}" 29 | 30 | @classmethod 31 | def parse(cls, args, dataset_cfg: Mapping[str, Any], defaults: Mapping[str, Any]): 32 | return super().parse(args, dataset_cfg, defaults) 33 | 34 | 35 | @dataclass 36 | class SkillsEvalEnvConfig(EvalEnvConfig): 37 | """Environment configuration shared by the Skills client/server.""" 38 | 39 | datasets: list[SkillsEvalEnvDatasetConfig] = field(default_factory=list) 40 | 41 | @classmethod 42 | def parse(cls, args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]) -> SkillsEvalEnvConfig: 43 | base_cfg: SkillsEvalEnvConfig = super().parse(raw_env_config, defaults) 44 | datasets = raw_env_config.get("datasets") or [] 45 | base_cfg.datasets = [ 46 | SkillsEvalEnvDatasetConfig.parse(args, dataset_cfg, base_cfg.defaults) for dataset_cfg in datasets 47 | ] 48 | return base_cfg 49 | 50 | 51 | def build_skills_eval_env_config(args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]): 52 | return SkillsEvalEnvConfig.parse(args, raw_env_config, defaults) 53 | -------------------------------------------------------------------------------- /slime/backends/fsdp_utils/models/qwen3_moe_hf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def apply_fsdp_moe_patch(): 6 | 7 | from transformers.models.qwen3_moe import modeling_qwen3_moe 8 | 9 | def _forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 10 | batch_size, sequence_length, hidden_dim = hidden_states.shape 11 | hidden_states = hidden_states.view(-1, hidden_dim) 12 | router_logits = self.gate(hidden_states) 13 | 14 | routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) 15 | routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) 16 | if self.norm_topk_prob: 17 | routing_weights /= routing_weights.sum(dim=-1, keepdim=True) 18 | routing_weights = routing_weights.to(hidden_states.dtype) 19 | 20 | final_hidden_states = torch.zeros( 21 | (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device 22 | ) 23 | 24 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) 25 | 26 | # Loop over all experts 27 | for expert_idx in range(self.num_experts): 28 | expert_layer = self.experts[expert_idx] 29 | idx, top_x = torch.where(expert_mask[expert_idx]) 30 | 31 | if top_x.numel() > 0: 32 | current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) 33 | current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] 34 | final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) 35 | else: 36 | # force experts to participate in computation graph 37 | dummy_output = expert_layer(hidden_states[:1]) * 0.0 38 | final_hidden_states[:1] = final_hidden_states[:1] + dummy_output 39 | 40 | final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) 41 | return final_hidden_states, router_logits 42 | 43 | modeling_qwen3_moe.Qwen3MoeSparseMoeBlock.forward = _forward 44 | -------------------------------------------------------------------------------- /tests/test_chunked_gae.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pytest 3 | import 
torch 4 | 5 | from slime.utils.ppo_utils import chunked_gae, vanilla_gae 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "B,T", 10 | [ 11 | (16, 4096), 12 | (32, 8192), 13 | (256, 128 * 1024), 14 | ], 15 | ) 16 | @pytest.mark.parametrize("chunk_size", [64, 128, 256]) 17 | def test_gae_parallel_matches_serial(B, T, chunk_size): 18 | """ 19 | Test that chunked_gae (parallel-scan) matches vanilla_gae (batch-serial) 20 | under various shapes, chunk sizes and dtypes. 21 | """ 22 | device = "cuda" if torch.cuda.is_available() else "cpu" 23 | torch.manual_seed(0) 24 | 25 | rewards = torch.randn(B, T, device=device, dtype=torch.float32) 26 | values = torch.randn(B, T, device=device, dtype=torch.float32) 27 | 28 | gamma, lam = 0.99, 0.95 29 | 30 | # ---------- Serial ---------- 31 | if device == "cuda": 32 | torch.cuda.synchronize() 33 | t0 = time.time() 34 | adv_s, ret_s = vanilla_gae(rewards, values, gamma, lam) 35 | if device == "cuda": 36 | torch.cuda.synchronize() 37 | t1 = time.time() 38 | serial_time = t1 - t0 39 | 40 | # ---------- Parallel-scan ---------- 41 | if device == "cuda": 42 | torch.cuda.synchronize() 43 | t0 = time.time() 44 | adv_p, ret_p = chunked_gae(rewards, values, gamma, lam, chunk_size=chunk_size) 45 | if device == "cuda": 46 | torch.cuda.synchronize() 47 | t1 = time.time() 48 | parallel_time = t1 - t0 49 | 50 | # ---------- Accuracy ---------- 51 | adv_err = (adv_s - adv_p).abs().max().item() 52 | ret_err = (ret_s - ret_p).abs().max().item() 53 | 54 | atol = 1e-5 55 | assert adv_err < atol, f"adv error too large: {adv_err}" 56 | assert ret_err < atol, f"ret error too large: {ret_err}" 57 | 58 | # ---------- logging ---------- 59 | print(f"\n[GAE Test] B={B}, T={T}, chunk={chunk_size}") 60 | print(f" Serial : {serial_time:.6f} s") 61 | print(f" Parallel : {parallel_time:.6f} s") 62 | print(f" Speedup : x{serial_time / parallel_time:.2f}") 63 | print(f" Max diff adv={adv_err:.3e}, ret={ret_err:.3e}") 64 | -------------------------------------------------------------------------------- /examples/strands-agents/README.md: -------------------------------------------------------------------------------- 1 | # Slime x Strands-Agents 2 | 3 | This is a running example that connects the [Strands-Agents](https://github.com/strands-agents/sdk-python) agent scaffolding framework with Slime for RL training. 4 | 5 | ## Install Dependencies 6 | 7 | 1. Pull the `slimerl/slime:latest` image and enter it 8 | 2. Goes to slime folder: `cd /root/slime` (Clone the repository if not already there: `cd /root && git clone https://github.com/THUDM/slime.git`) 9 | 3. Install Slime: `pip install -e .` 10 | 4. Goes to the example folder: `cd /root/slime/examples/strands-agents` 11 | 5. 
Install other dependencies: `pip install -r requirements.txt` 12 | 13 | > NOTE: we use camel-ai's subprocess code interpreter for python code execution, which is NOT a good practice; it's just for convenience of this example and the dependencies for solving math problems are usually ready in `slime`'s docker 14 | 15 | ## Prepare Model 16 | 17 | ```bash 18 | # hf checkpoint 19 | huggingface-cli download Qwen/Qwen3-4B-Instruct-2507 --local-dir /root/models/Qwen/Qwen3-4B-Instruct-2507 20 | 21 | # mcore checkpoint 22 | cd /root/slime 23 | source scripts/models/qwen3-4B.sh 24 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ 25 | ${MODEL_ARGS[@]} \ 26 | --hf-checkpoint /root/models/Qwen/Qwen3-4B-Instruct-2507 \ 27 | --save /root/models/Qwen/Qwen3-4B-Instruct-2507_torch_dist 28 | ``` 29 | 30 | ## Prepare Dataset 31 | 32 | Following [Retool](https://arxiv.org/abs/2504.11536), we used `dapo-math-17k` as training data: 33 | 34 | ``` 35 | from datasets import load_dataset 36 | ds = load_dataset("zhuzilin/dapo-math-17k", split="train") 37 | ds.to_json("/root/data/dapo-math-17k.jsonl", orient="records", lines=True) 38 | ``` 39 | 40 | and `aime-2024` as eval data: 41 | 42 | ``` 43 | from datasets import load_dataset 44 | ds = load_dataset("zhuzilin/aime-2024", split="train") 45 | ds.to_json("/root/data/aime-2024.jsonl", orient="records", lines=True) 46 | ``` 47 | 48 | ## Run Training 49 | 50 | Assuming `/root/slime` is up-to-date (if this PR is not merged you may need to switch branch): 51 | 52 | ``` 53 | cd /root/slime 54 | export WANDB_KEY=$your_wandb_key 55 | bash examples/strands-agents/strands_qwen3_4b.sh 56 | ``` 57 | -------------------------------------------------------------------------------- /slime_plugins/rollout_buffer/README.md: -------------------------------------------------------------------------------- 1 | # Rollout Buffer 2 | 3 | ## Overview 4 | 5 | Rollout Buffer is an independent component for asynchronous agent trajectory generation, with the main function of using the LLM OpenAI Server launched by slime training to generate agent trajectories. 6 | 7 | ### Workflow 8 | 9 | ``` 10 | slime Training Process ←─── HTTP API ───→ Rollout Buffer 11 | ↓ ↓ 12 | LLM Server ←─────── HTTP Requests ─────── Agent Framework 13 | ↓ ↓ 14 | Model Response ──────────────────────→ Trajectory Generation 15 | ``` 16 | 17 | For each different Agent task, there should be a corresponding independent Generator class, responsible for generating trajectories for that type of task. Rollout Buffer automatically reads and loads different types of Generators. 18 | 19 | ## Quick Start 20 | 21 | ### Basic Usage Process 22 | 23 | 1. **Copy Template**: Copy `base_generator.py` as a template 24 | 2. **Modify Task Type**: Change `TASK_TYPE` to your task name (cannot duplicate with other Generators) 25 | 3. **Implement Core Function**: Implement the `run_rollout()` function 26 | 4. **Optional Customization**: Rewrite five optional functions as needed 27 | 28 | 29 | Generator files must end with `_generator.py` and be placed in the `generator/` directory: 30 | 31 | ``` 32 | generator/ 33 | ├── base_generator.py # Math task implementation (default template) 34 | └── your_task_generator.py # Your custom task 35 | ``` 36 | 37 | Each Generator file must define `TASK_TYPE` and `run_rollout()`. 38 | 39 | In addition, Rollout Buffer also provides some customizable functions to meet special needs of different tasks. 
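As an illustration, a minimal custom generator file might look roughly like the sketch below. The file name, task name, and the `run_rollout` signature and body here are assumptions made for this sketch only; copy `base_generator.py` to get the actual template and signatures.

```python
# generator/my_task_generator.py  (hypothetical file name)

TASK_TYPE = "my_task"  # must not clash with the TASK_TYPE of any other generator


def run_rollout(data):
    # Generate trajectories for this task type here, e.g. by sending requests
    # to the OpenAI-compatible LLM server that slime exposes during training.
    raise NotImplementedError
```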
If no custom implementation is provided, the system will use default implementations (located in `slime_plugins/rollout_buffer/default_func.py`). 40 | 41 | ### Example Script 42 | 43 | First, you need to follow [Example: Qwen3-4B Model](../../docs/en/models/qwen3-4B.md) to configure the environment, download the data, and convert the model checkpoints, then run the following scripts: 44 | ```bash 45 | cd slime_plugins/rollout_buffer 46 | bash rollout_buffer_example.sh 47 | 48 | # In a different terminal 49 | python buffer.py 50 | ``` 51 | -------------------------------------------------------------------------------- /examples/tau-bench/README.md: -------------------------------------------------------------------------------- 1 | # Tau bench 2 | This example shows slime training in an agentic multi-turn tool-use environment. 3 | 4 | 5 | ## Environment Setup 6 | Use the `zhuzilin/slime:latest` image and initialize the environment required for tau-bench: 7 | 8 | ```bash 9 | cd /root/ 10 | git clone https://github.com/THUDM/slime.git 11 | cd slime 12 | pip install -e . 13 | # for tau bench 14 | cd /root/ 15 | git clone https://github.com/JD-ETH/tau-bench.git 16 | cd tau-bench 17 | git checkout feature/litellm-retry 18 | pip install -e . 19 | ``` 20 | 21 | Use the following script to generate mock data for slime training: 22 | 23 | ```bash 24 | cd /root/slime/examples/tau-bench 25 | python tau1_mock.py --local_dir /root/tau-bench/ 26 | ``` 27 | 28 | Initialize the Qwen3-4B-Instruct-2507 model needed for tool use: 29 | 30 | ```bash 31 | # hf checkpoint 32 | huggingface-cli download Qwen/Qwen3-4B-Instruct-2507 --local-dir /root/Qwen3-4B-Instruct-2507 33 | 34 | # mcore checkpoint 35 | cd /root/slime 36 | source scripts/models/qwen3-4B-Instruct-2507.sh 37 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ 38 | ${MODEL_ARGS[@]} \ 39 | --hf-checkpoint /root/Qwen3-4B-Instruct-2507 \ 40 | --save /root/Qwen3-4B-Instruct-2507_torch_dist 41 | ``` 42 | 43 | ## Running the Script 44 | 45 | You need to configure your litellm API in `generate_with_tau.py` for user simulation: 46 | 47 | ```python 48 | TAU_CONFIGS = { 49 | "env": "retail", # Select between ["retail", "airline"] 50 | "agent": "tool-calling", # Select between ["tool-calling", "act", "react", "few-shot"], only tool-calling implemented for now 51 | "user_model": "gemini-2.0-flash-lite", # Cheap model for the user simulator 52 | "user_model_provider": "gemini", 53 | "task_split": "train", # Select between ["train", "test", "dev"] for retail, ["test"] for airline 54 | "user_strategy": "llm", # Select between ["llm", "react", "verify", "reflection"] 55 | "model_provider": "auto_router", # Unused, required 56 | "model": "qwen3-4b", # Unused, required 57 | } 58 | # Replace with your actual API key for user sim 59 | GEMINI_API_KEY = "YOUR KEY" 60 | ``` 61 | 62 | And run: 63 | 64 | 65 | ```bash 66 | cd /root/slime 67 | bash examples/tau-bench/run_qwen3_4B.sh 68 | ``` -------------------------------------------------------------------------------- /examples/eval/README.md: -------------------------------------------------------------------------------- 1 | # Docs 2 | 3 | ## Prerequisites 4 | - A writable host directory for cached data (`/data/.cache`) 5 | - Choose descriptive container names to replace the placeholders (`<slime-container-name>`, `<skills-container-name>`).
6 | 7 | ## 1) Prepare host network 8 | ```bash 9 | docker network create skills-net 10 | ``` 11 | 12 | ## 2) Launch the slime container 13 | ```bash 14 | docker run \ 15 | -itd \ 16 | --shm-size 32g \ 17 | --gpus all \ 18 | -v /data/.cache:/root/.cache \ 19 | -v /dev/shm:/shm \ 20 | --ipc=host \ 21 | --privileged \ 22 | --network skills-net \ 23 | --name <slime-container-name> \ 24 | slimerl/slime:latest \ 25 | /bin/bash 26 | ``` 27 | 28 | ## 3) Launch the Skills container 29 | ```bash 30 | docker run \ 31 | -itd \ 32 | --shm-size 32g \ 33 | --gpus all \ 34 | -v /data/.cache:/root/.cache \ 35 | -v /dev/shm:/shm \ 36 | --ipc=host \ 37 | --privileged \ 38 | --network skills-net \ 39 | --name <skills-container-name> \ 40 | --network-alias skills_server \ 41 | guapisolo/nemoskills:0.7.1 \ 42 | /bin/bash 43 | ``` 44 | 45 | ## 4) Inside the Skills container 46 | Clone repos and install the Skills package: 47 | ```bash 48 | git clone -b slime_skills https://github.com/guapisolo/slime.git /opt/slime 49 | git clone -b slime https://github.com/guapisolo/Skills.git /opt/Skills 50 | 51 | cd /opt/Skills 52 | pip install -e . 53 | ``` 54 | 55 | Download/prepare datasets: 56 | ```bash 57 | cd /opt/Skills/nemo_skills/dataset 58 | python3 aime25/prepare.py 59 | python3 hle/prepare.py 60 | python3 arena-hard/prepare.py 61 | ``` 62 | 63 | Start the skills server: 64 | ```bash 65 | cd /opt/slime 66 | python examples/eval/nemo_skills/skills_server.py \ 67 | --host 0.0.0.0 \ 68 | --port 9050 \ 69 | --output-root /opt/skills-eval \ 70 | --config-dir examples/eval/nemo_skills/config \ 71 | --cluster local_cluster \ 72 | --max-concurrent-requests 512 \ 73 | --openai-model-name slime-openai-model 74 | ``` 75 | 76 | You can now connect to the server at `skills_server:9050` from within the `skills-net` Docker network. The server always proxies evaluation traffic to an OpenAI-compatible sglang router (Slime starts and manages the router), so adjust `--openai-model-name` and `--max-concurrent-requests` as needed for your deployment. 77 | -------------------------------------------------------------------------------- /examples/fully_async/README.md: -------------------------------------------------------------------------------- 1 | ## Fully Asynchronous Rollout Example 2 | 3 | This example shows a simple way to make rollout generation **fully asynchronous**: a single global worker is created once and then keeps running in the background, continuously pulling prompts and launching generation tasks. Training only needs to fetch already finished results. This removes the per‑step wait that happens in the normal synchronous style. 4 | 5 | ### Files 6 | * `fully_async_rollout.py`: global async worker + `generate_rollout_fully_async` entry. 7 | * `run-qwen3-4b-fully_async.sh`: example launch script with Qwen3‑4B. 8 | 9 | ### Prerequisite 10 | First set up the model & environment following the Qwen3-4B example. 11 | 12 | ### Quick Start 13 | ```bash 14 | cd slime 15 | bash examples/fully_async/run-qwen3-4b-fully_async.sh 16 | ``` 17 | You should see log lines like: 18 | ``` 19 | Creating new global async worker... 20 | Continuous async rollout worker started 21 | ``` 22 | 23 | ### How It Works (Very Short) 24 | * First call: create `AsyncRolloutWorker` (thread + asyncio loop). 25 | * Loop keeps up to `--rollout-batch-size` tasks in flight using `generate_and_rm_group`. 26 | * Completed groups are pushed into a queue; caller drains until it has enough samples. 27 | * Worker is stopped automatically at process exit. 28 | 29 | ### Limitations 30 | * No evaluation mode.
31 | * Ordering is best effort (sorted at the end by index). 32 | * Minimal error handling. 33 | 34 | ### Config Differences (2 Key Points) 35 | To enable the fully async pattern there are only two changes compared to a normal run: 36 | 37 | 1. Use the async training driver: `train_async.py` (not `train.py`). 38 | 2. Set the rollout function path: 39 | ```bash 40 | --rollout-function-path fully_async_rollout.generate_rollout_fully_async 41 | ``` 42 | 43 | Why is it still "fully" async although `train_async.py` itself schedules rollouts step‑by‑step? 44 | 45 | Because the real generation work is done by a **persistent background worker** created in `generate_rollout_fully_async`. Each call from `train_async.py` only drains already completed samples from the worker's output queue; the worker has been continuously generating since the first call. Thus rollout production (model inference) and training consume happen in parallel with minimal waiting. 46 | -------------------------------------------------------------------------------- /slime/utils/debug_utils/send_to_sglang.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | from typing import Annotated 4 | 5 | import typer 6 | from openai import AsyncOpenAI 7 | 8 | from slime.utils.data import read_file 9 | 10 | 11 | # can unify w/ sglang_rollout.py later, e.g. add RM, if needed 12 | def main( 13 | prompt_data: Annotated[str, typer.Option()], 14 | url: Annotated[str, typer.Option()] = "http://localhost:30000/v1", 15 | input_key: Annotated[str, typer.Option()] = "input", 16 | n_samples_per_prompt: Annotated[int, typer.Option()] = 1, 17 | rollout_max_response_len: Annotated[int, typer.Option()] = 1024, 18 | rollout_temperature: Annotated[float, typer.Option()] = 1.0, 19 | rollout_top_p: Annotated[float, typer.Option()] = 1.0, 20 | ): 21 | """ 22 | Minimally send prompts to SGLang using OpenAI endpoints with arguments in the same format as main Slime. 
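    Each row's `--input-key` field must already be an OpenAI-style messages list, since it is passed directly as `messages` to `chat.completions.create`.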
23 | 24 | Example usage: 25 | python -m slime.utils.debug_utils.send_to_sglang --prompt-data /root/datasets/aime-2024/aime-2024.jsonl --input-key prompt --n-samples-per-prompt 16 --rollout-max-response-len 32768 --rollout-temperature 0.8 --rollout-top-p 0.7 26 | """ 27 | 28 | async def _main_async(): 29 | tasks = [ 30 | asyncio.create_task(_run_one(row, row_index=row_index, repeat_index=repeat_index)) 31 | for row_index, row in enumerate(read_file(prompt_data)) 32 | for repeat_index in range(n_samples_per_prompt) 33 | ] 34 | outputs = await asyncio.gather(*tasks) 35 | for output in outputs: 36 | print(json.dumps(output)) 37 | 38 | async def _run_one(row, row_index: int, repeat_index: int): 39 | resp = await client.chat.completions.create( 40 | messages=row[input_key], 41 | model="dummy_model", 42 | max_tokens=rollout_max_response_len, 43 | temperature=rollout_temperature, 44 | top_p=rollout_top_p, 45 | ) 46 | return dict( 47 | row_index=row_index, 48 | repeat_index=repeat_index, 49 | **row, 50 | response=resp.choices[0].message.content, 51 | ) 52 | 53 | client = AsyncOpenAI(api_key="dummy_key", base_url=url) 54 | asyncio.run(_main_async()) 55 | 56 | 57 | if __name__ == "__main__": 58 | typer.run(main) 59 | -------------------------------------------------------------------------------- /slime/utils/tensorboard_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | from slime.utils.misc import SingletonMeta 5 | 6 | try: 7 | from torch.utils.tensorboard import SummaryWriter 8 | except ImportError: 9 | SummaryWriter = None 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class _TensorboardAdapter(metaclass=SingletonMeta): 15 | _writer = None 16 | 17 | """ 18 | # Usage example: This will return the same instance every rank 19 | # tb = _TensorboardAdapter(args) # Initialize on first call 20 | # tb.log({"Loss": 0.1}, step=1) 21 | 22 | # In other files: 23 | # from tensorboard_utils import _TensorboardAdapter 24 | # tb = _TensorboardAdapter(args) # No parameters needed to get existing instance 25 | # tb.log({"Accuracy": 0.9}, step=1) 26 | """ 27 | 28 | def __init__(self, args): 29 | assert args.use_tensorboard, f"{args.use_tensorboard=}" 30 | tb_project_name = args.tb_project_name 31 | tb_experiment_name = args.tb_experiment_name 32 | if tb_project_name is not None or os.environ.get("TENSORBOARD_DIR", None): 33 | if tb_project_name is not None and tb_experiment_name is None: 34 | tb_experiment_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 35 | self._initialize(tb_project_name, tb_experiment_name) 36 | else: 37 | raise ValueError("tb_project_name and tb_experiment_name, or TENSORBOARD_DIR are required") 38 | 39 | def _initialize(self, tb_project_name, tb_experiment_name): 40 | """Actual initialization logic""" 41 | # Get tensorboard directory from environment variable or use default path 42 | tensorboard_dir = os.environ.get("TENSORBOARD_DIR", f"tensorboard_log/{tb_project_name}/{tb_experiment_name}") 43 | os.makedirs(tensorboard_dir, exist_ok=True) 44 | logger.info(f"Saving tensorboard log to {tensorboard_dir}.") 45 | self._writer = SummaryWriter(tensorboard_dir) 46 | 47 | def log(self, data, step): 48 | """Log data to tensorboard 49 | 50 | Args: 51 | data (dict): Dictionary containing metric names and values 52 | step (int): Current step/epoch number 53 | """ 54 | for key in data: 55 | self._writer.add_scalar(key, data[key], step) 56 | 57 | def finish(self): 58 | """Close the 
tensorboard writer""" 59 | self._writer.close() 60 | -------------------------------------------------------------------------------- /slime/utils/debug_utils/display_debug_rollout_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from types import SimpleNamespace 4 | from typing import Annotated 5 | 6 | import torch 7 | import typer 8 | 9 | from slime.ray.rollout import compute_metrics_from_samples 10 | from slime.utils.types import Sample 11 | 12 | _WHITELIST_KEYS = [ 13 | "group_index", 14 | "index", 15 | "prompt", 16 | "response", 17 | "response_length", 18 | "label", 19 | "reward", 20 | "status", 21 | "metadata", 22 | ] 23 | 24 | 25 | def main( 26 | # Deliberately make this name consistent with main training arguments 27 | load_debug_rollout_data: Annotated[str, typer.Option()], 28 | show_metrics: bool = True, 29 | show_samples: bool = True, 30 | category: list[str] = None, 31 | ): 32 | if category is None: 33 | category = ["train", "eval"] 34 | for rollout_id, path in _get_rollout_dump_paths(load_debug_rollout_data, category): 35 | print("-" * 80) 36 | print(f"{rollout_id=} {path=}") 37 | print("-" * 80) 38 | 39 | pack = torch.load(path) 40 | sample_dicts = pack["samples"] 41 | 42 | if show_metrics: 43 | # TODO read these configs from dumps 44 | args = SimpleNamespace( 45 | advantage_estimator="grpo", 46 | reward_key=None, 47 | log_reward_category=None, 48 | ) 49 | sample_objects = [Sample.from_dict(s) for s in sample_dicts] 50 | metrics = compute_metrics_from_samples(args, sample_objects) 51 | print("metrics", metrics) 52 | 53 | if show_samples: 54 | for sample in sample_dicts: 55 | print(json.dumps({k: v for k, v in sample.items() if k in _WHITELIST_KEYS})) 56 | 57 | 58 | def _get_rollout_dump_paths(load_debug_rollout_data: str, categories: list[str]): 59 | # may improve later 60 | for rollout_id in range(1000): 61 | for category in categories: 62 | prefix = { 63 | "train": "", 64 | "eval": "eval_", 65 | }[category] 66 | path = Path(load_debug_rollout_data.format(rollout_id=f"{prefix}{rollout_id}")) 67 | if path.exists(): 68 | yield rollout_id, path 69 | 70 | 71 | if __name__ == "__main__": 72 | """python -m slime.utils.debug_utils.display_debug_rollout_data --load-debug-rollout-data ...""" 73 | typer.run(main) 74 | -------------------------------------------------------------------------------- /slime/utils/timer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import contextmanager 3 | from functools import wraps 4 | from time import time 5 | 6 | import torch.distributed 7 | 8 | from .misc import SingletonMeta 9 | 10 | __all__ = ["Timer", "timer"] 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class Timer(metaclass=SingletonMeta): 16 | def __init__(self): 17 | self.timers = {} 18 | self.start_time = {} 19 | 20 | def start(self, name): 21 | assert name not in self.start_time, f"Timer {name} already started." 22 | self.start_time[name] = time() 23 | if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: 24 | logger.info(f"Timer {name} start") 25 | 26 | def end(self, name): 27 | assert name in self.start_time, f"Timer {name} not started." 
28 | elapsed_time = time() - self.start_time[name] 29 | self.add(name, elapsed_time) 30 | del self.start_time[name] 31 | if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: 32 | logger.info(f"Timer {name} end (elapsed: {elapsed_time:.1f}s)") 33 | 34 | def reset(self, name=None): 35 | if name is None: 36 | self.timers = {} 37 | elif name in self.timers: 38 | del self.timers[name] 39 | 40 | def add(self, name, elapsed_time): 41 | self.timers[name] = self.timers.get(name, 0) + elapsed_time 42 | 43 | def log_dict(self): 44 | return self.timers 45 | 46 | @contextmanager 47 | def context(self, name): 48 | self.start(name) 49 | try: 50 | yield 51 | finally: 52 | self.end(name) 53 | 54 | 55 | def timer(name_or_func): 56 | """ 57 | Can be used either as a decorator or a context manager: 58 | 59 | @timer 60 | def func(): 61 | ... 62 | 63 | or 64 | 65 | with timer("block_name"): 66 | ... 67 | """ 68 | # When used as a context manager 69 | if isinstance(name_or_func, str): 70 | name = name_or_func 71 | return Timer().context(name) 72 | 73 | func = name_or_func 74 | 75 | @wraps(func) 76 | def wrapper(*args, **kwargs): 77 | with Timer().context(func.__name__): 78 | return func(*args, **kwargs) 79 | 80 | return wrapper 81 | 82 | 83 | @contextmanager 84 | def inverse_timer(name): 85 | Timer().end(name) 86 | try: 87 | yield 88 | finally: 89 | Timer().start(name) 90 | -------------------------------------------------------------------------------- /slime/utils/typer_utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import inspect 3 | from typing import Annotated 4 | 5 | import typer 6 | 7 | 8 | def dataclass_cli(func, env_var_prefix: str = "SLIME_SCRIPT_"): 9 | """Modified from https://github.com/fastapi/typer/issues/154#issuecomment-1544876144""" 10 | 11 | # The dataclass type is the first argument of the function. 12 | sig = inspect.signature(func) 13 | param = list(sig.parameters.values())[0] 14 | dataclass_cls = param.annotation 15 | assert dataclasses.is_dataclass(dataclass_cls) 16 | 17 | # To construct the signature, we remove the first argument (self) 18 | # from the dataclass __init__ signature. 
19 | signature = inspect.signature(dataclass_cls.__init__) 20 | old_parameters = list(signature.parameters.values()) 21 | if len(old_parameters) > 0 and old_parameters[0].name == "self": 22 | del old_parameters[0] 23 | 24 | new_parameters = [] 25 | for param in old_parameters: 26 | env_var_name = f"{env_var_prefix}{param.name.upper()}" 27 | new_annotation = Annotated[param.annotation, typer.Option(envvar=env_var_name)] 28 | new_parameters.append(param.replace(annotation=new_annotation)) 29 | 30 | def wrapped(**kwargs): 31 | data = dataclass_cls(**kwargs) 32 | print(f"Execute command with args: {data}") 33 | return func(data) 34 | 35 | wrapped.__signature__ = signature.replace(parameters=new_parameters) 36 | wrapped.__doc__ = func.__doc__ 37 | wrapped.__name__ = func.__name__ 38 | wrapped.__qualname__ = func.__qualname__ 39 | 40 | return wrapped 41 | 42 | 43 | # unit test 44 | if __name__ == "__main__": 45 | from typer.testing import CliRunner 46 | 47 | @dataclasses.dataclass 48 | class DemoArgs: 49 | name: str 50 | count: int = 1 51 | 52 | app = typer.Typer() 53 | 54 | @app.command() 55 | @dataclass_cli 56 | def main(args: DemoArgs): 57 | print(f"{args.name}|{args.count}") 58 | 59 | runner = CliRunner() 60 | 61 | res1 = runner.invoke(app, [], env={"SLIME_SCRIPT_NAME": "EnvName", "SLIME_SCRIPT_COUNT": "10"}) 62 | print(f"{res1.stdout=}") 63 | assert res1.exit_code == 0 64 | assert "EnvName|10" in res1.stdout.strip() 65 | 66 | res2 = runner.invoke(app, ["--count", "999"], env={"SLIME_SCRIPT_NAME": "EnvName"}) 67 | print(f"{res2.stdout=}") 68 | assert res2.exit_code == 0 69 | assert "EnvName|999" in res2.stdout.strip() 70 | 71 | print("✅ All Tests Passed!") 72 | -------------------------------------------------------------------------------- /slime/utils/fp8_kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | fp8_dtype = torch.float8_e4m3fn 6 | fp8_max = torch.finfo(fp8_dtype).max 7 | fp8_min = -fp8_max 8 | 9 | 10 | def ceil_div(x: int, y: int) -> int: 11 | """ 12 | Perform ceiling division of two integers. 13 | 14 | Args: 15 | x: the dividend. 16 | y: the divisor. 17 | 18 | Returns: 19 | The result of the ceiling division. 
20 | """ 21 | return (x + y - 1) // y 22 | 23 | 24 | @triton.jit 25 | def _blockwise_cast_to_fp8_triton( 26 | X, 27 | Y, 28 | S, 29 | stride_xm, 30 | stride_xn, 31 | stride_ym, 32 | stride_yn, 33 | stride_sm, 34 | stride_sn, 35 | M, 36 | N, 37 | eps, 38 | fp8_min, 39 | fp8_max, 40 | BLOCK_M: tl.constexpr = 32, 41 | BLOCK_N: tl.constexpr = 128, 42 | ): 43 | pid_m = tl.cast(tl.program_id(axis=0), tl.int64) 44 | pid_n = tl.cast(tl.program_id(axis=1), tl.int64) 45 | off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 46 | off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 47 | mask_m = off_m < M 48 | mask_n = off_n < N 49 | mask = mask_m[:, None] & mask_n[None, :] 50 | 51 | x = tl.load(X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn, mask=mask, other=0.0).to(tl.float32) 52 | _absmax = tl.maximum(tl.max(tl.abs(x)), eps) 53 | x_s = _absmax / fp8_max 54 | s_inv = 1.0 / x_s 55 | y_q = tl.clamp(x * s_inv, fp8_min, fp8_max).to(Y.dtype.element_ty) 56 | 57 | tl.store(Y + off_m[:, None] * stride_ym + off_n[None, :] * stride_yn, y_q, mask=mask) 58 | tl.store(S + pid_m * stride_sm + pid_n * stride_sn, x_s) 59 | 60 | 61 | def blockwise_cast_to_fp8_triton(x: torch.Tensor, block_size=None) -> tuple[torch.Tensor, torch.Tensor]: 62 | BLOCK_M, BLOCK_N = 128, 128 63 | if block_size: 64 | BLOCK_M, BLOCK_N = block_size[0], block_size[1] 65 | M, N = x.shape 66 | y = torch.empty(M, N, device=x.device, dtype=torch.float8_e4m3fn) 67 | s = torch.empty(ceil_div(M, BLOCK_M), ceil_div(N, BLOCK_N), dtype=torch.float32, device=x.device) 68 | 69 | def grid(meta): 70 | return (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"])) 71 | 72 | if x.is_contiguous(): 73 | kwargs = {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "num_warps": 8, "num_stages": 2} 74 | else: 75 | kwargs = {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "num_warps": 1, "num_stages": 4} 76 | _blockwise_cast_to_fp8_triton[grid]( 77 | x, y, s, *x.stride(), *y.stride(), *s.stride(), M, N, 1e-10, fp8_min, fp8_max, **kwargs 78 | ) 79 | return y, s 80 | -------------------------------------------------------------------------------- /slime/ray/utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ray/utils.py#L1 2 | import os 3 | 4 | import ray 5 | import torch 6 | from slime.ray.ray_actor import RayActor 7 | 8 | 9 | # Refer to 10 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96 11 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103 12 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95 13 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117 14 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109 15 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172 16 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98 17 | NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ 18 | "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", 19 | 
"RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", 20 | "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", 21 | "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", 22 | "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", 23 | "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", 24 | "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", 25 | ] 26 | 27 | 28 | def ray_noset_visible_devices(env_vars=os.environ): 29 | return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) 30 | 31 | 32 | def get_physical_gpu_id(): 33 | device = torch.cuda.current_device() 34 | props = torch.cuda.get_device_properties(device) 35 | return str(props.uuid) 36 | 37 | 38 | @ray.remote 39 | class Lock(RayActor): 40 | def __init__(self): 41 | self._locked = False # False: unlocked, True: locked 42 | 43 | def acquire(self): 44 | """ 45 | Try to acquire the lock. Returns True if acquired, False otherwise. 46 | Caller should retry until it returns True. 47 | """ 48 | if not self._locked: 49 | self._locked = True 50 | return True 51 | return False 52 | 53 | def release(self): 54 | """Release the lock, allowing others to acquire.""" 55 | assert self._locked, "Lock is not acquired, cannot release." 56 | self._locked = False 57 | -------------------------------------------------------------------------------- /docs/zh/get_started/qa.md: -------------------------------------------------------------------------------- 1 | # 常见 Q&A 2 | 3 | 1. **训练过程中为什么会出现乱码?** 4 | 5 | 一般来说这种情况是 megatron 没有被正确加载。请检查 `--load` 或 `--ref-load` 是否有对应的 ckpt。注意 megatron 只能加载其中有 `latest_checkpointed_iteration.txt` 的目录。 6 | 7 | 如果需要指定某个特定的 iter,可以查看当前 megatron 的使用方法,一般是可以通过 `--ckpt-step` 来指定步数。 8 | 9 | 1. **为什么我的任务一直卡在 ray 提交的页面上?** 10 | 11 | 请先检查你需要跑的任务是训推一体的,还是训推分离的。 12 | 13 | 如果是训推一体,即训练和推理共用 GPU,请检查 14 | 15 | - 是否设置了 `--colocate` 参数开启训推一体; 16 | - 当前任务的总卡数是否大于等于 `actor_num_nodes * actor_num_gpus_per_node` 17 | 18 | 如果是训推分离,请检查: 19 | 20 | - 当前任务的总卡数是否大于等于 `actor_num_nodes * actor_num_gpus_per_node + rollout_num_gpus` 21 | 22 | 1. **为什么训着训着 OOM 了?`max_tokens_per_gpu` 是干什么用的?** 23 | 24 | OOM 往往是因为 `max_tokens_per_gpu` 设置过高了。 `max_tokens_per_gpu` 是指在训练过程中,每张 GPU 上最多可以放多少 token。如果担心 OOM 的话,可以先把这个值设成 `rollout_max_response_len / cp_size`,之后再为了提升训练效率来增大这个值。`--max-tokens-per-gpu` 只有在开启 `--use-dynamic-batch-size` 的情况下才会启用。 25 | 26 | 如果 `max_tokens_per_gpu` 很小,还会 oom,可以检查一下是否单次生成的数据太长了,需要开启 cp(`--context-parallel-size`)。如果进行了自定义的数据生成,可以看一下是否在多轮生成的情况下,生成的总长度比预期的长很多。 27 | 28 | 1. **多机训练的时候,遇到了 transformers 库找不到某个模型的错误该怎么办?** 29 | 30 | 这种情况一般是因为多个进程都在通过类似于 `AutoConfig.from_pretrained` 或者 `AutoModelForCausalLM.from_pretrained` 的方式读取本地文件,出现了文件系统的写冲突。可以通过设置 `--model-name` 缓解这一问题。 31 | 32 | 1. **如何续训?** 33 | 34 | 直接将 `--load` 设置为 `--save` 的目录即可。 35 | 36 | 1. **batch size 是如何计算的?** 37 | 38 | 一个 rollout 会用 `rollout_batch_size` 条 prompt,每一条会采 `n_samples_per_prompt` 条,所以一个 rollout 共 `rollout_batch_size * n_samples_per_prompt` 条数据。 39 | 40 | 可以用 `--num-steps-per-rollout` 来决定每一个 rollout 跑多少步。这相当于是把 `global_batch_size` 设置成 `rollout_batch_size * n_samples_per_prompt // num_steps_per_rollout`。 41 | 42 | 1. **slime 是否进行了 data packing / varlen 处理?** 43 | 44 | data packing 是指在训练过程中,将长短不一的 sample 拼接到一起,从而提升训练的利用率。slime 默认会进行这样的操作。 45 | 46 | 1. **sglang 部分出现 `Max retries exceeded with url: /get_model_info (Caused by NewConnectionError` 的问题怎么办?** 47 | 48 | 这个问题主要来源于单机内多个 sglang server 导致的端口冲突,目前我们仍在和 sglang 团队一起解决这个问题。一个临时的缓解方案是尽可能减少单机内的 sglang server 数量,例如设置 tp=8。 49 | 50 | 1. 
**grad norm 好高,训练训崩了怎么办?** 51 | 52 | 首先请确保数据和模型是匹配的,例如说,如果数据是实现已经做好 chat template 的了,这个 chat template 是否和原模型一致。如果数据正确的话,可以参考 [debug 指南](../developer_guide/debug.md) 进行更深入的分析。 53 | 54 | 1. **我的 sglang 生成时间特别特别久,gpu 功率都打满了,跑了好久好没有输出是为什么?** 55 | 56 | 请确认一下 `--hf-checkpoint` 对应的模型是否正确设置了 stop token,如果没有,可以通过 `--rollout-stop` 或者 `--rollout-stop-token-ids` 来进行设置。 57 | 58 | 1. **sglang 出现 an illegal memory access was encountered** 59 | 60 | 根据 sglang 的文档(https://docs.sglang.ai/references/troubleshooting.html),有可能是 OOM 了,可以考虑缩小 `--sglang-mem-fraction-static`。 61 | 62 | 1. **出现 torch compile/inducer 的 `JSONDecodeError`** 63 | 64 | 一般是 torch compile 读写 cache 出现的问题。可以考虑在 ray 的 env_var 里加上 `"TORCHINDUCTOR_FORCE_DISABLE_CACHES": "1"`。 65 | 66 | 1. **训练出现 grad NaN 或者 Inf 的情况** 67 | 68 | 可以通过设置 `--no-check-for-nan-in-loss-and-grad` 来尝试跳过对应的训练步。 69 | -------------------------------------------------------------------------------- /slime/utils/misc.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import subprocess 3 | 4 | import ray 5 | 6 | from slime.utils.http_utils import is_port_available 7 | 8 | 9 | def load_function(path): 10 | """ 11 | Load a function from a module. 12 | :param path: The path to the function, e.g. "module.submodule.function". 13 | :return: The function object. 14 | """ 15 | module_path, _, attr = path.rpartition(".") 16 | module = importlib.import_module(module_path) 17 | return getattr(module, attr) 18 | 19 | 20 | class SingletonMeta(type): 21 | """ 22 | A metaclass for creating singleton classes. 23 | """ 24 | 25 | _instances = {} 26 | 27 | def __call__(cls, *args, **kwargs): 28 | if cls not in cls._instances: 29 | instance = super().__call__(*args, **kwargs) 30 | cls._instances[cls] = instance 31 | return cls._instances[cls] 32 | 33 | 34 | def exec_command(cmd: str, capture_output: bool = False) -> str | None: 35 | print(f"EXEC: {cmd}", flush=True) 36 | 37 | try: 38 | result = subprocess.run( 39 | ["bash", "-c", cmd], 40 | shell=False, 41 | check=True, 42 | capture_output=capture_output, 43 | **(dict(text=True) if capture_output else {}), 44 | ) 45 | except subprocess.CalledProcessError as e: 46 | if capture_output: 47 | print(f"{e.stdout=} {e.stderr=}") 48 | raise 49 | 50 | if capture_output: 51 | print(f"Captured stdout={result.stdout} stderr={result.stderr}") 52 | return result.stdout 53 | 54 | 55 | def get_current_node_ip(): 56 | address = ray._private.services.get_node_ip_address() 57 | # strip ipv6 address 58 | address = address.strip("[]") 59 | return address 60 | 61 | 62 | def get_free_port(start_port=10000, consecutive=1): 63 | # find the port where port, port + 1, port + 2, ... port + consecutive - 1 are all available 64 | port = start_port 65 | while not all(is_port_available(port + i) for i in range(consecutive)): 66 | port += 1 67 | return port 68 | 69 | 70 | def should_run_periodic_action( 71 | rollout_id: int, 72 | interval: int | None, 73 | num_rollout_per_epoch: int | None = None, 74 | num_rollout: int | None = None, 75 | ) -> bool: 76 | """ 77 | Return True when a periodic action (eval/save/checkpoint) should run. 78 | 79 | Args: 80 | rollout_id: The current rollout index (0-based). 81 | interval: Desired cadence; disables checks when None. 82 | num_rollout_per_epoch: Optional epoch boundary to treat as a trigger. 
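        num_rollout: Total number of rollouts; when set, the final rollout (rollout_id == num_rollout - 1) always triggers the action.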
83 | """ 84 | if interval is None: 85 | return False 86 | 87 | if num_rollout is not None and rollout_id == num_rollout - 1: 88 | return True 89 | 90 | step = rollout_id + 1 91 | return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0) 92 | -------------------------------------------------------------------------------- /docs/zh/examples/qwen3-4b-base-openhermes.md: -------------------------------------------------------------------------------- 1 | # SFT Qwen3-4B-Base 2 | 3 | ## 环境准备 4 | 5 | 首先需要我们仿照 [示例:Qwen3-4B 模型](qwen3-4B.md) 创建镜像环境与转换 `Qwen3-4B-Base` 模型。 6 | 7 | 之后,我们处理 sft 数据。这里我们以经典的 [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5) 为例,首先把数据处理成适合 slime 加载的格式,可以用如下的脚本进行处理,增加一个符合 openai message 格式的列,并保存在 `/root/openhermes2_5.parquet`。 8 | 9 | ```python 10 | from datasets import load_dataset 11 | 12 | ds = load_dataset("teknium/OpenHermes-2.5")["train"] 13 | 14 | def convert(sample): 15 | conversations = sample["conversations"] 16 | 17 | def convert_role(role): 18 | if role == "human": 19 | return "user" 20 | elif role == "gpt": 21 | return "assistant" 22 | elif role == "system": 23 | return "system" 24 | else: 25 | raise ValueError(f"Unknown role: {role}") 26 | 27 | messages = [ 28 | { 29 | "role": convert_role(turn["from"]), 30 | "content": turn["value"], 31 | } 32 | for turn in conversations 33 | ] 34 | 35 | return {"messages": messages} 36 | 37 | ds = ds.map(convert) 38 | ds.to_parquet("/root/openhermes2_5.parquet") 39 | ``` 40 | 41 | ## 执行训练 42 | 43 | 执行训练: 44 | 45 | ```bash 46 | cd /root/slime 47 | bash script/run-qwen3-4B-base-sft.sh 48 | ``` 49 | 50 | ### 参数简介 51 | 52 | 可以将 [run-qwen3-4B-base-sft.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-4B-base-sft.sh) 与 [run-qwen3-4B.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-4B.sh) 进行对比。会发现除了我们将模型由 instruct 模型换为了 base 模型之外,主要进行了如下的几个调整: 53 | 54 | 1. 移除了 `SGLANG_ARGS` 和 `GRPO_ARGS`。这是因为 sft 的过程中不需要启动 sglang 或者做 grpo 相关的配置; 55 | 56 | 2. 将 `ROLLOUT_ARGS` 改名为了 `SFT_ARGS`,并配置为: 57 | 58 | ```bash 59 | SFT_ARGS=( 60 | --rollout-function-path slime.rollout.sft_rollout.generate_rollout 61 | --prompt-data /root/openhermes2_5.parquet 62 | --input-key messages 63 | --rollout-shuffle 64 | --num-epoch 3 65 | --rollout-batch-size 128 66 | --global-batch-size 128 67 | 68 | --loss-type sft_loss 69 | --calculate-per-token-loss 70 | --disable-compute-advantages-and-returns 71 | --debug-train-only 72 | ) 73 | ``` 74 | 75 | slime 中的 sft 实际上是复用了 slime 的 custom rollout 功能,通过 `--rollout-function-path` 将数据生成部分从使用 sglang 的 RL rollout,切换成了从文件中读取数据的 sft 版本,即 `slime.rollout.sft_rollout.generate_rollout`。 76 | 77 | 对于 sft 来说,建议将 `rollout_batch_size` 与 `global_batch_size` 设置成相同的,并不要配置 `n_samples_per_prompt`,这样相当于是读一个 batch 就训一个 batch。 78 | 79 | slime 还支持不同的 loss 类型,我们就是通过 `--loss-type sft_loss` 配置上 sft loss 的。 80 | 81 | 至于 `--calculate-per-token-loss`,这是因为 slime 默认是以 GRPO 的 per sample mean 进行计算的,而一般 sft 训练都是按一个 batch 的所有不被 mask 的 token 取平均,所以建议配置上。 82 | 83 | 最后 `--disable-compute-advantages-and-returns` 表示 sft 的过程中不需要预先计算 log prob,`--debug-train-only` 表示不需要初始化 sglang。 84 | 85 | 3. 
使用了 `train_async.py` 而不是 `train.py`。这是为了利用异步训练的流程,来实现数据 prefetch。 86 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/megatron_to_hf/llama.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | 4 | 5 | def convert_llama_to_hf(args, name, param): 6 | if name == "module.module.embedding.word_embeddings.weight": 7 | return [("model.embed_tokens.weight", param)] 8 | if name == "module.module.output_layer.weight": 9 | return [("lm_head.weight", param)] 10 | if name == "module.module.decoder.final_layernorm.weight": 11 | return [("model.norm.weight", param)] 12 | 13 | try: 14 | head_dim = args.kv_channels if args.kv_channels is not None else args.hidden_size // args.num_attention_heads 15 | except AttributeError: 16 | head_dim = args.hidden_size // args.num_attention_heads 17 | value_num_per_group = args.num_attention_heads // args.num_query_groups 18 | 19 | decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)" 20 | match = re.match(decoder_layers_pattern, name) 21 | if match: 22 | layer_idx, rest = match.groups() 23 | if rest == "self_attention.linear_proj.weight": 24 | return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)] 25 | elif rest == "self_attention.linear_qkv.weight": 26 | # Split QKV weight for Llama 27 | param = param.view(args.num_query_groups, -1, head_dim, args.hidden_size) 28 | q_param, k_param, v_param = torch.split(param, split_size_or_sections=[value_num_per_group, 1, 1], dim=1) 29 | q_param = q_param.reshape(-1, args.hidden_size) 30 | k_param = k_param.reshape(-1, args.hidden_size) 31 | v_param = v_param.reshape(-1, args.hidden_size) 32 | return [ 33 | (f"model.layers.{layer_idx}.self_attn.q_proj.weight", q_param), 34 | (f"model.layers.{layer_idx}.self_attn.k_proj.weight", k_param), 35 | (f"model.layers.{layer_idx}.self_attn.v_proj.weight", v_param), 36 | ] 37 | elif rest == "mlp.linear_fc1.weight": 38 | # Split gate and up projections for SwiGLU 39 | gate_weight, up_weight = param.chunk(2, dim=0) 40 | return [ 41 | (f"model.layers.{layer_idx}.mlp.gate_proj.weight", gate_weight), 42 | (f"model.layers.{layer_idx}.mlp.up_proj.weight", up_weight), 43 | ] 44 | elif rest == "mlp.linear_fc2.weight": 45 | return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)] 46 | elif rest == "self_attention.linear_qkv.layer_norm_weight": 47 | return [(f"model.layers.{layer_idx}.input_layernorm.weight", param)] 48 | elif rest == "mlp.linear_fc1.layer_norm_weight": 49 | return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] 50 | elif rest == "pre_mlp_layernorm.weight": 51 | return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] 52 | 53 | raise ValueError(f"Unknown parameter name: {name}") 54 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "packaging", 4 | "setuptools >= 49.4.0", 5 | "wheel", 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [tool.isort] 10 | profile = "black" # black-compatible 11 | line_length = 119 # should match black parameters 12 | ignore_whitespace = true # ignore whitespace for compatibility with the initial style 13 | py_version = 310 # python 3.10 as a target version 14 | sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] 15 | default_section = 
"THIRDPARTY" 16 | extend_skip = ["setup.py", "docs/source/conf.py"] 17 | known_first_party = ["slime", "slime_plugins"] 18 | known_third_party = ["megatron", "wandb", "ray", "transformers"] 19 | src_paths = ["slime", "slime_plugins"] 20 | 21 | 22 | [tool.black] 23 | line_length = 119 24 | 25 | [tool.ruff] 26 | line-length = 320 # TODO 27 | select = [ 28 | "E", # Pycodestyle Errors (Structural/Fundamental Errors like bad indentation) 29 | "F", # Pyflakes (Core Errors: Unused imports, undefined names) 30 | "B", # Flake8-Bugbear (Logic Bugs: Variable shadowing, dangerous default arguments) 31 | "UP", # pyupgrade (Modernization and compatibility issues) 32 | ] 33 | ignore = [ 34 | "E402", # module-import-not-at-top-of-file 35 | "E501", # Line too long # TODO handle it later 36 | ] 37 | 38 | [tool.pytest.ini_options] 39 | # durations=0 will display all tests execution time, sorted in ascending order starting from from the slowest one. 40 | # -vv will also display tests with duration = 0.00s 41 | addopts = "--verbose --pyargs --durations=0 --strict-markers" # always add these arguments to pytest 42 | testpaths = ["./tests"] # must be an explicit path to avoid importing another "tests" module 43 | # directories to ignore when discovering tests 44 | norecursedirs = [ 45 | "external", 46 | "examples", 47 | "docs", 48 | "scripts", 49 | "tools", 50 | "tutorials", 51 | "*.egg", 52 | ".*", 53 | "_darcs", 54 | "build", 55 | "CVS", 56 | "dist", 57 | "venv", 58 | "{arch}", 59 | ] 60 | # markers to select tests, use `pytest --markers` to see all available markers, `pytest -m ""` to select tests 61 | markers = [ 62 | "unit: marks unit test, i.e. testing a single, well isolated functionality (deselect with '-m \"not unit\"')", 63 | "integration: marks test checking the elements when integrated into subsystems (deselect with '-m \"not integration\"')", 64 | "system: marks test working at the highest integration level (deselect with '-m \"not system\"')", 65 | "acceptance: marks test checking whether the developed product/model passes the user defined acceptance criteria (deselect with '-m \"not acceptance\"')", 66 | "docs: mark tests related to documentation (deselect with '-m \"not docs\"')", 67 | "skipduringci: marks tests that are skipped ci as they are addressed by Jenkins jobs but should be run to test user setups", 68 | "pleasefixme: marks tests that are broken and need fixing", 69 | ] 70 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | from pathlib import Path 5 | 6 | # TODO: may need to copy those 2 functions and do refactoring. 7 | from megatron.training.checkpointing import load_checkpoint as _load_checkpoint_megatron 8 | from megatron.training.checkpointing import save_checkpoint 9 | from megatron.training.global_vars import get_args 10 | 11 | from slime.utils import megatron_bridge_utils 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | __all__ = ["save_checkpoint"] 16 | 17 | 18 | def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, checkpointing_context, skip_load_to_model_and_opt): 19 | # ref: how megatron `load_checkpoint` gets directory 20 | args = get_args() 21 | load_path = args.load 22 | 23 | assert Path(load_path).exists() and _is_dir_nonempty( 24 | load_path 25 | ), f"{args.load=} does not exist or is an empty directory. Did you specify the wrong folder?" 
26 | 27 | if _is_megatron_checkpoint(load_path): 28 | return _load_checkpoint_megatron( 29 | ddp_model=ddp_model, 30 | optimizer=optimizer, 31 | opt_param_scheduler=opt_param_scheduler, 32 | checkpointing_context=checkpointing_context, 33 | skip_load_to_model_and_opt=skip_load_to_model_and_opt, 34 | ) 35 | else: 36 | return _load_checkpoint_hf( 37 | ddp_model=ddp_model, 38 | optimizer=optimizer, 39 | args=args, 40 | load_path=load_path, 41 | ) 42 | 43 | 44 | def _is_megatron_checkpoint(path: str | Path) -> bool: 45 | return (Path(path) / "latest_checkpointed_iteration.txt").is_file() or bool( 46 | re.fullmatch(r"iter_\d{7}", Path(path).name) 47 | ) 48 | 49 | 50 | def _load_checkpoint_hf(ddp_model, optimizer, args, load_path: str): 51 | assert args.megatron_to_hf_mode == "bridge", "Only bridge mode is supported for loading HF checkpoint" 52 | from megatron.bridge import AutoBridge 53 | 54 | import slime_plugins.megatron_bridge # noqa: F401 55 | 56 | logger.info(f"Load checkpoint from HuggingFace model into Megatron (path={load_path})") 57 | 58 | with megatron_bridge_utils.patch_megatron_model(ddp_model): 59 | bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) 60 | bridge.load_hf_weights(ddp_model) 61 | 62 | # Copied from Megatron-core :: load_checkpoint (with simplifications) 63 | if (args.fp16 or args.bf16) and optimizer is not None: 64 | assert not args.load_main_params_from_ckpt 65 | optimizer.reload_model_params() 66 | 67 | # We can see `successfully loaded checkpoint from ... [ t 1/2, p 1/1 ] at iteration 0` 68 | # when loading Megatron, thus it is 0 69 | iteration = 0 70 | num_floating_point_operations_so_far = 0 71 | return iteration, num_floating_point_operations_so_far 72 | 73 | 74 | def _is_dir_nonempty(path): 75 | with os.scandir(path) as it: 76 | return any(it) 77 | -------------------------------------------------------------------------------- /slime/rollout/sft_rollout.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from slime.utils.mask_utils import MultiTurnLossMaskGenerator 4 | from slime.utils.processing_utils import load_processor, load_tokenizer, prepare_model_inputs 5 | 6 | __all__ = ["generate_rollout"] 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | TOKENIZER = None 12 | PROCESSOR = None 13 | MASK_GENERATOR = None 14 | SAMPLE_PRINTED = False 15 | 16 | 17 | def generate_rollout(args, rollout_id, data_buffer, evaluation=False): 18 | """An example to implement the generate_rollout function for an rule based rm rollout generation. 
19 | 20 | Args: 21 | args: the whole args 22 | rollout_id: int, the id of the rollout, used for deterministic data generation 23 | data_buffer: the data buffer to store the generated samples 24 | evaluation: bool, whether the rollout is for evaluation or not 25 | 26 | Returns: 27 | list[Sample]: a list of samples generated by the rollout 28 | """ 29 | assert not evaluation 30 | assert args.rollout_global_dataset 31 | 32 | global TOKENIZER, PROCESSOR, MASK_GENERATOR, SAMPLE_PRINTED 33 | if TOKENIZER is None: 34 | TOKENIZER = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) 35 | 36 | if PROCESSOR is None: 37 | PROCESSOR = load_processor(args.hf_checkpoint, trust_remote_code=True) 38 | 39 | if MASK_GENERATOR is None: 40 | MASK_GENERATOR = MultiTurnLossMaskGenerator(TOKENIZER, tokenizer_type=args.loss_mask_type) 41 | 42 | samples = data_buffer.get_samples(args.rollout_batch_size) 43 | 44 | for i, sample in enumerate(samples): 45 | (sample,) = sample 46 | messages = sample.prompt 47 | tools = sample.metadata.get("tools", None) 48 | 49 | input_ids, extra_info = prepare_model_inputs( 50 | messages, TOKENIZER, PROCESSOR, sample.metadata, 51 | args.apply_chat_template, args.apply_chat_template_kwargs 52 | ) 53 | 54 | has_multimodal = bool(extra_info.get("images") or extra_info.get("videos")) 55 | if has_multimodal: 56 | sample.multimodal_inputs = extra_info["multimodal_inputs"] 57 | token_ids, loss_mask = MASK_GENERATOR.get_loss_mask_with_multimodal_alignment( 58 | messages, input_ids, tools=tools 59 | ) 60 | else: 61 | token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools) 62 | 63 | response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0] 64 | 65 | sample.tokens = token_ids 66 | sample.response_length = response_length 67 | sample.reward = 0 68 | sample.loss_mask = loss_mask[-response_length:] 69 | 70 | if i == 0 and not SAMPLE_PRINTED: 71 | logger.info( 72 | f"sft_rollout::generate_rollout example data: {sample=} (raw){messages=} (raw){token_ids=} (raw){loss_mask=} {response_length=}" 73 | ) 74 | SAMPLE_PRINTED = True 75 | 76 | return samples 77 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | from slime.utils import megatron_bridge_utils 4 | from slime.utils.iter_utils import chunk_named_params_by_size 5 | 6 | from ..megatron_to_hf import postprocess_hf_param 7 | from ..misc_utils import strip_param_name_prefix 8 | from .hf_weight_iterator_base import HfWeightIteratorBase 9 | 10 | 11 | class HfWeightIteratorBridge(HfWeightIteratorBase): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | 15 | from megatron.bridge import AutoBridge 16 | import slime_plugins.megatron_bridge # noqa: F401 17 | 18 | self._bridge = AutoBridge.from_hf_pretrained(self.args.hf_checkpoint) 19 | 20 | def get_hf_weight_chunks(self, megatron_local_weights): 21 | # TODO support quantization (e.g. 
modify megatron-bridge to provide megatron param name) 22 | renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()} 23 | with megatron_bridge_utils.patch_megatron_model(self.model): 24 | conversion_tasks = self._bridge.get_conversion_tasks(self.model) 25 | conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) 26 | 27 | named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) 28 | 29 | named_weights = ( 30 | ( 31 | hf_param_name, 32 | postprocess_hf_param( 33 | args=self.args, 34 | megatron_param_name=megatron_param_name, 35 | hf_param_name=hf_param_name, 36 | param=weight, 37 | ), 38 | ) 39 | for hf_param_name, weight, megatron_param_name in named_weights 40 | ) 41 | 42 | yield from chunk_named_params_by_size(named_weights, chunk_size=self.args.update_weight_buffer_size) 43 | 44 | 45 | def _process_conversion_tasks(vanilla_conversion_tasks, new_weight_dict): 46 | def _handle_one(task): 47 | if task.param_weight is None: 48 | return task 49 | 50 | weight_dict_key = f"vp_stages.{task.vp_stage}.{task.param_name}" 51 | assert ( 52 | weight_dict_key in new_weight_dict 53 | ), f"{weight_dict_key=} not in new_weight_dict ({task.vp_stage=}, {task.param_name=}, {list(new_weight_dict)=})" 54 | 55 | new_param_weight = new_weight_dict[weight_dict_key] 56 | new_param_weight = new_param_weight.cuda() 57 | return dataclasses.replace(task, param_weight=new_param_weight) 58 | 59 | return _MapWithLen(_handle_one, vanilla_conversion_tasks) 60 | 61 | 62 | class _MapWithLen: 63 | def __init__(self, fn, xs): 64 | self.fn = fn 65 | self.xs = xs 66 | 67 | def __len__(self): 68 | return len(self.xs) 69 | 70 | def __iter__(self): 71 | for x in self.xs: 72 | yield self.fn(x) 73 | -------------------------------------------------------------------------------- /examples/eval/nemo_skills/skills_client.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from typing import Any 4 | 5 | import requests 6 | from examples.eval.eval_delegate import EvalClient, EvalDelegateError 7 | from examples.eval.nemo_skills.skills_config import SkillsEvalEnvConfig 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class SkillsEvalClient(EvalClient): 13 | """HTTP client that proxies evaluation requests to the NeMo Skills server.""" 14 | 15 | def __init__(self, config: SkillsEvalEnvConfig, router_url: str): 16 | super().__init__(config.name or "skills") 17 | self._config = config 18 | self._router_url = router_url.rstrip("/") 19 | self._endpoint = (config.url or "").rstrip("/") 20 | self._timeout_secs = float(config.timeout_secs) 21 | self._max_retries = max(1, int(config.max_retries)) 22 | self._headers = dict(config.headers or {}) 23 | self._session = requests.Session() 24 | 25 | @classmethod 26 | def from_config(cls, config: SkillsEvalEnvConfig, router_url: str): 27 | if not config.url: 28 | return None 29 | return cls(config, router_url) 30 | 31 | def evaluate(self, args, rollout_id: int) -> tuple[dict[str, Any], dict[str, Any]]: 32 | if not self._config.datasets: 33 | logger.warning("No Skills datasets configured; skipping delegate evaluation.") 34 | return {}, {} 35 | 36 | payload = self._build_payload(args, rollout_id) 37 | response = self._request(payload) 38 | metrics = response["raw_metrics"] 39 | return metrics, response 40 | 41 | def _build_payload(self, args, rollout_id: int) -> dict[str, Any]: 42 | 
benchmarks = [cfg.to_payload() for cfg in self._config.datasets] 43 | benchmarks = [cfg for cfg in benchmarks if cfg] 44 | return { 45 | "rollout_id": rollout_id, 46 | "router_url": self._router_url, 47 | "benchmarks": benchmarks, 48 | } 49 | 50 | def _request(self, payload: dict[str, Any]) -> dict[str, Any]: 51 | last_error: Exception | None = None 52 | for attempt in range(1, self._max_retries + 1): 53 | try: 54 | response = self._session.post( 55 | self._endpoint, 56 | json=payload, 57 | timeout=self._timeout_secs, 58 | headers=self._headers, 59 | ) 60 | response.raise_for_status() 61 | if not response.content: 62 | return {} 63 | return response.json() 64 | except requests.RequestException as exc: 65 | last_error = exc 66 | logger.warning( 67 | "Skills eval delegate request failed (attempt %s/%s): %s", attempt, self._max_retries, exc 68 | ) 69 | if attempt < self._max_retries: 70 | time.sleep(min(2**attempt, 30)) 71 | raise EvalDelegateError("Skills evaluation request failed") from last_error 72 | -------------------------------------------------------------------------------- /slime/backends/megatron_utils/megatron_to_hf/mimo.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from .qwen2 import convert_qwen2_to_hf 4 | 5 | 6 | def convert_mimo_to_hf(args, name, param): 7 | """ 8 | Convert MiMo model parameters from Megatron to HuggingFace format. 9 | 10 | MiMo extends Qwen2 with MTP (Multi-Token Prediction) layers. 11 | """ 12 | 13 | if "mtp" in name: 14 | return convert_mimo_mtp_param(args, name, param) 15 | 16 | return convert_qwen2_to_hf(args, name, param) 17 | 18 | 19 | def convert_mimo_mtp_param(args, name, param): 20 | """ 21 | Convert MTP layer parameters from Megatron to HuggingFace format. 22 | 23 | MTP layers in MiMo contain: 24 | - LayerNorms (token_layernorm, hidden_layernorm, final_layernorm) 25 | - Input projection (input_proj) 26 | - Self attention (reuses Qwen2 attention structure) 27 | - MLP (reuses Qwen2 MLP structure) 28 | 29 | Based on MimoBridge._convert_mtp_param logic (reverse mapping) 30 | """ 31 | mtp_pattern = r"module\.module\.mtp\.layers\.(\d+)\.(.+)" 32 | match = re.match(mtp_pattern, name) 33 | 34 | if not match: 35 | raise ValueError(f"Invalid MTP parameter name: {name}") 36 | 37 | layer_idx, component = match.groups() 38 | 39 | # Direct mappings for MTP-specific components (Megatron -> HF) 40 | # Based on MimoBridge direct_name_mapping (reversed) 41 | direct_mappings = { 42 | "enorm.weight": f"model.mtp_layers.{layer_idx}.token_layernorm.weight", 43 | "hnorm.weight": f"model.mtp_layers.{layer_idx}.hidden_layernorm.weight", 44 | "eh_proj.weight": f"model.mtp_layers.{layer_idx}.input_proj.weight", 45 | "final_layernorm.weight": f"model.mtp_layers.{layer_idx}.final_layernorm.weight", 46 | } 47 | if component == "eh_proj.weight": 48 | first_half, second_half = param.chunk(2, dim=1) 49 | param = torch.cat([second_half, first_half], dim=1) 50 | 51 | # Check direct mappings first 52 | if component in direct_mappings: 53 | return [(direct_mappings[component], param)] 54 | 55 | # Handle transformer_layer components 56 | if component.startswith("transformer_layer."): 57 | # Remove "transformer_layer." 
prefix 58 | transformer_component = component[len("transformer_layer.") :] 59 | 60 | # Create proxy name for reusing existing Qwen2 conversion functions 61 | proxy_name = f"module.module.decoder.layers.{layer_idx}.{transformer_component}" 62 | 63 | # Use existing convert_qwen2_to_hf function for transformer components 64 | results = convert_qwen2_to_hf(args, proxy_name, param) 65 | 66 | # Replace model.layers with mtp_layers in results 67 | converted_results = [] 68 | for hf_name, hf_param in results: 69 | # Replace model.layers.{idx} with mtp_layers.{idx} 70 | hf_name = hf_name.replace(f"model.layers.{layer_idx}", f"model.mtp_layers.{layer_idx}") 71 | converted_results.append((hf_name, hf_param)) 72 | 73 | return converted_results 74 | 75 | raise ValueError(f"Unknown MTP component: {component} in {name}") 76 | -------------------------------------------------------------------------------- /slime/rollout/rm_hub/__init__.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import random 3 | 4 | import aiohttp 5 | 6 | from slime.utils.misc import load_function 7 | from slime.utils.types import Sample 8 | 9 | from .deepscaler import get_deepscaler_rule_based_reward 10 | from .f1 import f1_score 11 | from .gpqa import compute_gpqa_reward 12 | from .math_dapo_utils import compute_score as compute_score_dapo 13 | from .math_utils import extract_answer as extract_boxed_answer 14 | from .math_utils import grade_answer_verl 15 | 16 | 17 | async def remote_rm(args, sample: Sample): 18 | payload = { 19 | "prompt": sample.prompt, 20 | "response": sample.response, 21 | "label": sample.label, 22 | } 23 | session_kwargs = {} 24 | async with aiohttp.ClientSession(**session_kwargs) as session: 25 | async with session.post(args.rm_url, json=payload) as resp: 26 | resp.raise_for_status() 27 | return await resp.json() 28 | 29 | 30 | async def async_rm(args, sample: Sample, **kwargs): 31 | if args.custom_rm_path is not None: 32 | rm_function = load_function(args.custom_rm_path) 33 | return await rm_function(args, sample, **kwargs) 34 | 35 | metadata = sample.metadata if isinstance(sample.metadata, dict) else {} 36 | rm_type = (metadata.get("rm_type") or args.rm_type or "").strip() 37 | response = sample.response 38 | label = sample.label 39 | if rm_type.startswith("boxed_"): 40 | response = extract_boxed_answer(response) or "" 41 | rm_type = rm_type[len("boxed_") :] 42 | 43 | # This function is intended for remote or time-consuming reward model evaluation. 44 | # Implement the actual logic as needed. 
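    # Dispatch on rm_type: each branch below computes a rule-based (or remote) reward for the sample.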
45 | if rm_type == "remote_rm": 46 | return await remote_rm(args, sample) 47 | elif rm_type == "deepscaler": 48 | return get_deepscaler_rule_based_reward(response, label) 49 | elif rm_type == "dapo": 50 | return compute_score_dapo(response, label) 51 | elif rm_type == "math": 52 | return 1 if grade_answer_verl(response, label) else 0 53 | elif rm_type == "f1": 54 | return f1_score(response, label)[0] 55 | elif rm_type == "gpqa": 56 | return compute_gpqa_reward(response, label, metadata=metadata) 57 | elif rm_type == "ifbench": 58 | from .ifbench import compute_ifbench_reward 59 | 60 | return compute_ifbench_reward(response, label, metadata=metadata) 61 | elif rm_type == "random": 62 | return random.randint(0, 1) 63 | elif rm_type: 64 | raise NotImplementedError(f"Rule-based RM for {rm_type} is not implemented.") 65 | else: 66 | raise NotImplementedError("Rule-based RM type is not specified.") 67 | 68 | 69 | async def batched_async_rm( 70 | args, 71 | samples: list[Sample], 72 | **kwargs, 73 | ) -> list[int | float]: 74 | if args.custom_rm_path is not None: 75 | # Ensure the custom reward function is implemented in batch mode 76 | rm_function = load_function(args.custom_rm_path) 77 | return await rm_function(args, samples, **kwargs) 78 | tasks = [async_rm(args, sample, **kwargs) for sample in samples] 79 | rewards = await asyncio.gather(*tasks) 80 | return rewards 81 | -------------------------------------------------------------------------------- /.github/workflows/conda-ci.yml: -------------------------------------------------------------------------------- 1 | name: conda CI 2 | 3 | on: 4 | pull_request: 5 | branches: [main] 6 | 7 | concurrency: 8 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 9 | cancel-in-progress: true 10 | 11 | jobs: 12 | build-conda: 13 | if: contains(github.event.pull_request.title, '[release]') 14 | runs-on: self-hosted 15 | container: 16 | image: lmsysorg/sglang:v0.5.0rc0-cu126 17 | options: --gpus all --ipc=host --shm-size=16g --ulimit memlock=-1 --ulimit stack=67108864 --memory=0 --memory-swap=0 -v /mnt/nvme0n1/models:/root/models -v /mnt/nvme0n1/datasets:/root/datasets 18 | 19 | defaults: 20 | run: 21 | working-directory: ${{ github.workspace }} 22 | 23 | steps: 24 | - name: Checkout repository 25 | uses: actions/checkout@v4 26 | 27 | - name: Construct Conda 28 | run: | 29 | echo "📦 Installing slime..." 30 | cd $GITHUB_WORKSPACE 31 | echo "Current directory: $(pwd)" 32 | 33 | mkdir -p /root/ 34 | BASE_DIR=/root bash build_conda.sh 35 | shell: bash 36 | 37 | - name: Download model and dataset 38 | run: | 39 | echo "🔗 Downloading up model and dataset..." 40 | 41 | # Create cache directories if they don't exist 42 | mkdir -p /root/models /root/datasets 43 | 44 | echo "Downloading Qwen3-30B-A3B..." 45 | hf download Qwen/Qwen3-30B-A3B --local-dir /root/models/Qwen3-30B-A3B 46 | hf download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/models/Qwen3-30B-A3B-FP8 47 | 48 | hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir /root/datasets/dapo-math-17k 49 | 50 | hf download --repo-type dataset zhuzilin/aime-2024 --local-dir /root/datasets/aime-2024 51 | shell: bash 52 | 53 | - name: Convert checkpoint 54 | run: | 55 | echo "🔄 Converting model checkpoint..." 
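          # Convert the downloaded HF checkpoint into Megatron's torch_dist format; the resulting
          # /root/Qwen3-30B-A3B_torch_dist directory is intended to be loaded by the training test in the next step.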
56 | cd $GITHUB_WORKSPACE 57 | echo "Current directory: $(pwd)" 58 | 59 | source ~/.bashrc 60 | micromamba activate slime 61 | export CUDA_HOME="$CONDA_PREFIX" 62 | 63 | source scripts/models/qwen3-30B-A3B.sh 64 | PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node 8 tools/convert_hf_to_torch_dist.py \ 65 | ${MODEL_ARGS[@]} \ 66 | --hf-checkpoint /root/models/Qwen3-30B-A3B \ 67 | --save /root/Qwen3-30B-A3B_torch_dist 68 | shell: bash 69 | 70 | - name: Run tests 71 | run: | 72 | echo "🧪 Running tests..." 73 | cd $GITHUB_WORKSPACE 74 | echo "Current directory: $(pwd)" 75 | 76 | source ~/.bashrc 77 | micromamba activate slime 78 | export CUDA_HOME="$CONDA_PREFIX" 79 | 80 | SLIME_TEST_USE_DEEPEP=0 SLIME_TEST_USE_FP8_ROLLOUT=0 python tests/test_qwen3_30B_A3B.py 81 | shell: bash 82 | 83 | - name: Cleanup 84 | if: always() 85 | run: | 86 | echo "🧹 Cleaning up..." 87 | pkill -9 ray || true 88 | ray stop --force || true 89 | pkill -9 python || true 90 | shell: bash 91 | -------------------------------------------------------------------------------- /examples/multi_agent/prompts.py: -------------------------------------------------------------------------------- 1 | ## Defines the prompt templates used to generate the different prompts. 2 | 3 | 4 | SOLVER_PROMPT_TEMPLATE = """{problem_statement}""" 5 | 6 | 7 | def generate_rewriter_template(num_solutions: int) -> str: 8 | """Dynamically generate the rewriter template based on the number of solutions.""" 9 | solution_sections = [] 10 | for i in range(num_solutions): 11 | solution_sections.append(f"#### Solution {i+1}\n{{solution{i+1}}}\n\n---") 12 | 13 | solutions_text = "\n".join(solution_sections) 14 | 15 | return f"""### Task: Solution Rewriting Based on Previous Solutions ### 16 | You are being reactivated to revise your mathematical proof. You are provided with two documents: 17 | 1. The problem you need to solve. 18 | 2. Your {num_solutions} different "Previous Solutions". 19 | 20 | Your sole task is to generate a new, correct version of your solution based on your previous discoveries in the provided {num_solutions} solutions. 21 | 22 | Refer to the following {num_solutions} solutions and solve the problem. 23 | --- 24 | 25 | ### Problem 26 | 27 | {{problem_statement}} 28 | 29 | --- 30 | 31 | ### Candidate Solutions 32 | {solutions_text} 33 | """ 34 | 35 | 36 | def generate_select_template(num_solutions: int) -> str: 37 | """Dynamically generate the select template based on the number of solutions.""" 38 | solution_sections = [] 39 | for i in range(num_solutions): 40 | solution_sections.append(f"#### Solution {i+1}\n{{solution{i+1}}}\n\n---") 41 | 42 | solutions_text = "\n".join(solution_sections) 43 | 44 | return f"""You will be given a challenging math problem followed by {num_solutions} solutions. 45 | Your task is to systematically analyze these solutions to identify the most mathematically sound approach. 46 | 47 | You are provided with two documents: 48 | 1. The problem you need to solve. 49 | 2. Your {num_solutions} "Candidate Solutions". 50 | 51 | Evaluation Process: 52 | 1. Initial Screening 53 | - Group solutions by their final answers 54 | - Identify and explain mathematical contradictions between different answers 55 | - Eliminate solutions with clear mathematical errors 56 | 57 | 2. 
Detailed Analysis 58 | For remaining solutions, evaluate: 59 | - Mathematical precision and accuracy 60 | - Logical progression of steps 61 | - Completeness of mathematical reasoning 62 | - Handling of edge cases or special conditions 63 | - For solutions containing and addressing errors, evaluate the error identification and correction methodology. 64 | 65 | 3. Solution Comparison 66 | Compare viable solutions based on: 67 | - Efficiency of approach 68 | - Clarity of mathematical reasoning 69 | - Sophistication of method 70 | - Robustness of solution (works for all cases) 71 | 72 | Your response should include: 73 | 1. Brief analysis of conflicting answers 74 | 2. Detailed evaluation of mathematically sound solutions 75 | 3. Justification for eliminating incorrect solutions 76 | 4. Clear explanation for selecting the best approach 77 | 78 | End your evaluation with exactly: 79 | Judgment: IDX 80 | where IDX is the index 1-{num_solutions} of the best solution 81 | 82 | ### Problem 83 | 84 | {{problem_statement}} 85 | 86 | --- 87 | 88 | ### Candidate Solutions 89 | {solutions_text} 90 | """ 91 | -------------------------------------------------------------------------------- /docs/zh/examples/qwen3-30B-A3B.md: -------------------------------------------------------------------------------- 1 | # 8xH100 训练 Qwen3-30B-A3B 2 | 3 | ## 环境准备 4 | 5 | 搭建环境、下载模型、数据与 ckpt 转换均与 Qwen3-4B 模型相同,可以参考 [示例:Qwen3-4B](qwen3-4B.md),将文中 Qwen3-4B 的部分转换为 Qwen3-30B-A3B 即可。 6 | 7 | 可以用如下方法把 huggingface checkpoint 转化为 torch_dist 格式: 8 | 9 | ```bash 10 | cd slime/ 11 | pip install -e . 12 | source scripts/models/qwen3-30B-A3B.sh 13 | PYTHONPATH=/root/Megatron-LM/ torchrun --nproc-per-node 8 \ 14 | tools/convert_hf_to_torch_dist.py \ 15 | ${MODEL_ARGS[@]} \ 16 | --hf-checkpoint /root/Qwen3-30B-A3B/ \ 17 | --save /root/Qwen3-30B-A3B_torch_dist/ 18 | ``` 19 | 20 | ## 执行训练 21 | 22 | 执行训练: 23 | 24 | ```bash 25 | cd /root/slime 26 | bash scripts/run-qwen3-30B-A3B.sh 27 | ``` 28 | 29 | ### 参数简介 30 | 31 | 这里我们简单介绍一下脚本 [run-qwen3-30B-A3B.sh](https://github.com/THUDM/slime/blob/main/scripts/run-qwen3-30B-A3B.sh) 中与 MoE 相关的部分。 32 | 33 | 1. 为了支持在 8xH800 环境中运行 Qwen3-30B-A3B,我们需要开启 megatron 的 CPU Adam 以节省显存,对应配置为: 34 | 35 | ```bash 36 | OPTIMIZER_ARGS=( 37 | ... 38 | --optimizer-cpu-offload 39 | --overlap-cpu-optimizer-d2h-h2d 40 | --use-precision-aware-optimizer 41 | ) 42 | ``` 43 | 44 | 2. 开启 megatron 支持的 moe 优化,当前配置为 tp4, ep8: 45 | 46 | ```bash 47 | PERF_ARGS=( 48 | --tensor-model-parallel-size 4 49 | --sequence-parallel 50 | --pipeline-model-parallel-size 1 51 | --context-parallel-size 1 52 | --expert-model-parallel-size 8 53 | --expert-tensor-parallel-size 1 54 | ... 55 | ) 56 | ``` 57 | 58 | 3. 
开启 sglang 支持的 moe 优化,当前配置为 ep8: 59 | 60 | ```bash 61 | SGLANG_ARGS=( 62 | --rollout-num-gpus-per-engine 8 63 | --sglang-mem-fraction-static 0.7 64 | --sglang-enable-ep-moe 65 | --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) 66 | ) 67 | ``` 68 | 69 | 类似地,也可以加入 dp attention,例如配置上: 70 | 71 | ```bash 72 | --sglang-enable-dp-attention 73 | --sglang-dp-size 8 74 | ``` 75 | 76 | ### bf16 训练 fp8 推理 77 | 78 | slime 还支持 bf16 训练,fp8 推理。对于 Qwen3-30B-A3B 模型,只需要下载如下模型: 79 | 80 | ```bash 81 | hf download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/Qwen3-30B-A3B-FP8 82 | ``` 83 | 84 | 并将 `--hf-checkpoint` 替换为: 85 | 86 | ```bash 87 | #--hf-checkpoint /root/Qwen3-30B-A3B 88 | --hf-checkpoint /root/Qwen3-30B-A3B-FP8 89 | ``` 90 | 91 | 即可触发 fp8 训练。目前我们会将 bf16 权重直接 cast 为 fp8,后续会逐渐添加对精度影响更小的量化方案。 92 | 93 | ⚠️ 训练的 megatron checkpoint 还需要是最开始用 bf16 的 huggingface 转换的。 94 | 95 | ### 多机支持 96 | 97 | 对于多机环境,需要进行如下的几点修改: 98 | - 将训练模型,数据放在所有机器都可以访问到的路径上; 99 | - 设置各台机器都可以访问到的 `MASTER_ADDR` 之外; 100 | - 去掉 CPU adam 相关的配置,因为使用了 distributed optimizer,所以多机环境下 optimizer 的显存占比会明显下降。 101 | 102 | 除此之外,还可以进行如下的修改: 103 | 104 | - 当总卡数并不能被 expert 总数乘除时,可以使用 `--sglang-ep-num-redundant-experts` 来增加冗余的 expert,例如对于 24 卡的场景,可以配置: 105 | 106 | ```bash 107 | SGLANG_ARGS=( 108 | --rollout-num-gpus-per-engine 24 109 | --sglang-mem-fraction-static 0.7 110 | --sglang-enable-ep-moe 111 | --sglang-enable-dp-attention 112 | --sglang-dp-size 3 113 | 114 | --sglang-moe-dense-tp-size 1 115 | --sglang-enable-dp-lm-head 116 | --sglang-ep-num-redundant-experts 16 117 | ) 118 | ``` 119 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # slime 2 | 3 | [English](./README.md) 4 | 5 | [![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://thudm.github.io/slime/) 6 | [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/THUDM/slime) 7 | 8 | **slime** 是为 RL scaling 设计的 LLM post‑training 框架,提供两大核心能力: 9 | 10 | 1. **高性能训练**:通过连接 Megatron 与 SGLang,支持各种模式的高效训练; 11 | 2. 
**灵活的数据生成**:通过自定义数据生成接口以及 server based engine,实现任意的数据训练数据生成流程。 12 | 13 | slime 是 [GLM-4.5](https://z.ai/blog/glm-4.5) 与 [GLM-4.6](https://z.ai/blog/glm-4.6) 背后的 RL 训练框架,除此之外,slime 还支持: 14 | - Qwen3 系列 (Qwen3Next, Qwen3MoE, Qwen3), Qwen2.5 系列; 15 | - DeepSeek V3 系列 (DeepSeek V3, V3.1, DeepSeek R1); 16 | - Llama 3。 17 | 18 | ## 博文 19 | 20 | - 我们的愿景:[slime:为 RL Scaling 设计的 SGLang-Native 后训练框架](https://thudm.github.io/slime/zh/blogs/introducing_slime.html) 21 | - 关于纯异步 agentic 训练的一些想法:[Agent-Oriented Design: An Asynchronous and Decoupled Framework for Agentic RL](https://www.notion.so/Agent-Oriented-Design-An-Asynchronous-and-Decoupled-Framework-for-Agentic-RL-2278e692d081802cbdd5d37cef76a547) 22 | - v0.1.0 日志:[slime v0.1.0: 重新定义高性能 RL 训练框架](https://zhuanlan.zhihu.com/p/1945237948166547268) 23 | 24 | 25 | ## 目录 26 | 27 | - [架构总览](#架构总览) 28 | - [快速开始](#快速开始) 29 | - [Checkpoint 格式转换](#checkpoint-格式转换) 30 | - [启动训练流程](#启动训练流程) 31 | - [参数说明](#参数说明) 32 | - [开发指南](#开发指南) 33 | - [常见 Q&A 与致谢](#常见-qa-与致谢) 34 | 35 | ## 架构总览 36 | 37 | ![arch](./imgs/arch.png) 38 | 39 | **模块说明**: 40 | 41 | - **training (Megatron)**:负责主训练流程,从 Data Buffer 读取数据,训练完后将参数同步至 rollout 模块; 42 | - **rollout (SGLang + router)**:生成新数据(含 reward/verifier),存储至 Data Buffer; 43 | - **data buffer**:桥梁模块,管理 prompt 初始化、自定义数据与 rollout 生成方法。 44 | 45 | ## 快速开始 46 | 47 | 有关环境配置、数据准备、训练启动和关键代码分析的完整快速开始指南,请参考: 48 | 49 | - [快速开始指南](./docs/zh/get_started/quick_start.md) 50 | 51 | 我们还提供了一些未在快速开始中覆盖的使用示例,请查看 [examples](examples/)。 52 | 53 | ## 参数说明 54 | 55 | 参数分为三类: 56 | 57 | 1. **megatron 参数**:slime 会读取 `PYTHONPATH` 中的 megatron 里设置的所有参数,可以通过传入如 `--tensor-model-parallel-size 2` 的方式配置 megatron; 58 | 2. **sglang 参数**:支持环境中安装的 sglang 的所有参数,这些参数需要以 `--sglang` 起始,例如 `--mem-fraction-static` 需要通过 `--sglang-mem-fraction-static` 传入。 59 | 3. **slime 自身的参数**:请见:[slime/utils/arguments.py](slime/utils/arguments.py) 60 | 61 | 完整使用说明请查阅 [使用文档](docs/zh/get_started/usage.md)。 62 | 63 | ## 开发指南 64 | 65 | - **欢迎贡献!** 若有功能建议、性能调优或使用体验反馈,欢迎提交 Issue / PR 😊 66 | 67 | - 使用 [pre-commit](https://pre-commit.com/) 保证提交代码风格: 68 | 69 | ```bash 70 | apt install pre-commit -y 71 | pre-commit install 72 | 73 | # 运行 pre-commit 保证代码风格 74 | pre-commit run --all-files --show-diff-on-failure --color=always 75 | ``` 76 | 77 | - 调试技巧请参考 [debug 指南](docs/zh/developer_guide/debug.md) 78 | 79 | ## 常见 Q&A 与致谢 80 | 81 | - 常见问题请见 [Q&A](docs/zh/get_started/qa.md) 82 | - 特别感谢以下项目 & 社区:SGLang、Megatron‑LM、mbridge、OpenRLHF、veRL、Pai-Megatron-Patch 等。 83 | 84 | - 引用 slime 请使用: 85 | ```bibtex 86 | @misc{slime_github, 87 | author = {Zilin Zhu and Chengxing Xie and Xin Lv and slime Contributors}, 88 | title = {slime: An LLM post-training framework for RL Scaling}, 89 | year = {2025}, 90 | howpublished = {\url{https://github.com/THUDM/slime}}, 91 | note = {GitHub repository. Corresponding author: Xin Lv}, 92 | urldate = {2025-06-19} 93 | } 94 | ``` 95 | -------------------------------------------------------------------------------- /examples/retool/README.md: -------------------------------------------------------------------------------- 1 | # Retool: from SFT to RL 2 | 3 | This example demonstrates how to use the retool functionality for tool-enabled language model generation. 
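Before the file-by-file overview below, here is a deliberately minimal sketch of the two moving parts this example combines: extracting a JSON tool call from a model response and executing the requested code under a hard time limit. The helper names (`extract_tool_call`, `run_code_tool`) are illustrative only and are not the actual APIs of `generate_with_retool.py` or `tool_sandbox.py`, which add a tool registry, memory limits, and dangerous-operation checks on top of this idea.

```python
import json
import re
import subprocess
import sys

# Grab the first JSON object that appears in the model output (the tool call in this example).
TOOL_CALL_RE = re.compile(r"\{.*\}", re.DOTALL)


def extract_tool_call(model_output: str) -> dict | None:
    """Return {"name": ..., "arguments": {...}} if the response contains a tool call."""
    match = TOOL_CALL_RE.search(model_output)
    if match is None:
        return None
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return None


def run_code_tool(code: str, timeout_secs: float = 5.0) -> str:
    """Run the code in a separate interpreter process with a wall-clock limit."""
    try:
        proc = subprocess.run(
            [sys.executable, "-c", code],
            capture_output=True,
            text=True,
            timeout=timeout_secs,
        )
        return proc.stdout if proc.returncode == 0 else proc.stderr
    except subprocess.TimeoutExpired:
        return "Execution timed out."


if __name__ == "__main__":
    response = '{"name": "code_interpreter", "arguments": {"code": "print(1 + 1)"}}'
    call = extract_tool_call(response)
    if call and call["name"] == "code_interpreter":
        print(run_code_tool(call["arguments"]["code"]))  # prints "2"
```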
4 | 5 | ## Overview 6 | 7 | The retool example provides: 8 | - Safe Python code execution in a sandbox environment 9 | - Tool registry for managing available tools 10 | - Integration with language model generation 11 | - Reward calculation for tool usage 12 | 13 | ## Files 14 | 15 | - `generate_with_retool.py`: Main generation function with tool support 16 | - `tool_sandbox.py`: Tool execution and safety management 17 | - `sft_data_processing.py`: Process SFT dataset 18 | 19 | ## Usage 20 | 21 | 1. Setup and download datasets: 22 | ```bash 23 | cd slime 24 | pip install -e . 25 | # For SFT part, you can use later model to RL directly and skip SFT. 26 | hf download --repo-type dataset JoeYing/ReTool-SFT --local-dir /root/JoeYing/ReTool-SFT 27 | hf download Qwen/Qwen3-4B-Instruct-2507 --local-dir /root/Qwen/Qwen3-4B-Instruct-2507 28 | 29 | # For RL part 30 | hf download --repo-type dataset zhuzilin/dapo-math-17k --local-dir /root/dapo-math-17k 31 | hf download --repo-type dataset zhuzilin/aime-2024 --local-dir /root/aime-2024 32 | # download our SFT model if you want to skip SFT 33 | hf download font-info/qwen3-4b-sft-SGLang-RL --local-dir /root/font-info/qwen3-4b-sft 34 | ``` 35 | 36 | 2. Create torch dist 37 | For SFT 38 | ```bash 39 | source scripts/models/qwen3-4B.sh 40 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ 41 | ${MODEL_ARGS[@]} \ 42 | --hf-checkpoint /root/Qwen/Qwen3-4B-Instruct-2507 \ 43 | --rotary-base 5000000 \ 44 | --save /root/Qwen/Qwen3-4B-Instruct-2507_torch_dist 45 | ``` 46 | 47 | Or RL only 48 | ```bash 49 | source scripts/models/qwen3-4B.sh 50 | PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ 51 | ${MODEL_ARGS[@]} \ 52 | --hf-checkpoint /root/font-info/qwen3-4b-sft \ 53 | --rotary-base 5000000 \ 54 | --save /root/font-info/qwen3-4b-sft_torch_dist 55 | 56 | ``` 57 | 58 | 3. SFT: 59 | ```bash 60 | python examples/retool/sft_data_processing.py 61 | bash examples/retool/retool_qwen3_4b_sft.sh 62 | ``` 63 | 64 | 4. RL: 65 | ```bash 66 | bash examples/retool/retool_qwen3_4b_rl.sh 67 | ``` 68 | 69 | 5. Use in your training scripts by importing the generate function: 70 | ```python 71 | from generate_with_retool import generate, reward_func 72 | ``` 73 | 74 | ## Tool Format 75 | 76 | The system uses the following tool format: 77 | 78 | ``` 79 | You may call one or more functions to assist with the user query. 
80 | 81 | You are provided with function signatures within XML tags: 82 | 83 | {"type": "function", "function": {"name": "code_interpreter", "description": "A tool for executing code.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The code to execute."}}, "required": ["code"]}}} 84 | 85 | 86 | For each function call, return a json object with function name and arguments within XML tags: 87 | 88 | {"name": , "arguments": } 89 | 90 | ``` 91 | 92 | ## Safety Features 93 | 94 | - Code execution in isolated sandbox 95 | - Memory and time limits 96 | - Dangerous operation detection 97 | - Allowed module restrictions 98 | -------------------------------------------------------------------------------- /tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py: -------------------------------------------------------------------------------- 1 | import os 2 | import slime.utils.external_utils.command_utils as U 3 | 4 | MODEL_NAME = "Qwen3-0.6B" 5 | 6 | 7 | def prepare(): 8 | U.exec_command("mkdir -p /root/models /root/datasets") 9 | U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") 10 | U.hf_download_dataset("zhuzilin/gsm8k") 11 | 12 | 13 | def execute(): 14 | ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " 15 | 16 | rollout_args = ( 17 | "--prompt-data /root/datasets/gsm8k/train.parquet " 18 | "--input-key messages " 19 | "--label-key label " 20 | "--apply-chat-template " 21 | "--rollout-shuffle " 22 | "--rm-type math " 23 | f"--num-rollout {3000 if U.get_env_enable_infinite_run() else 60} " 24 | "--rollout-batch-size 32 " 25 | "--n-samples-per-prompt 8 " 26 | "--rollout-max-response-len 1024 " 27 | "--rollout-temperature 0.8 " 28 | "--over-sampling-batch-size 64 " 29 | "--dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " 30 | "--global-batch-size 256 " 31 | ) 32 | 33 | eval_args = ( 34 | "--eval-interval 20 " 35 | "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " 36 | "--n-samples-per-eval-prompt 1 " 37 | "--eval-max-response-len 1024 " 38 | "--eval-top-k 1 " 39 | ) 40 | 41 | grpo_args = ( 42 | "--advantage-estimator grpo " 43 | # "--use-kl-loss " 44 | "--kl-loss-coef 0.00 " 45 | "--kl-loss-type low_var_kl " 46 | "--kl-coef 0.00 " 47 | "--entropy-coef 0.00 " 48 | "--eps-clip 0.2 " 49 | "--eps-clip-high 0.28 " 50 | ) 51 | 52 | optimizer_args = ( 53 | "--optimizer adam " 54 | "--lr 1e-6 " 55 | "--lr-decay-style constant " 56 | "--weight-decay 0.1 " 57 | "--adam-beta1 0.9 " 58 | "--adam-beta2 0.98 " 59 | ) 60 | 61 | sglang_args = "--rollout-num-gpus-per-engine 2 " "--sglang-decode-log-interval 1000 " "--sglang-enable-metrics " 62 | 63 | fsdp_args = ( 64 | # Set the bucket size for weight update 65 | "--update-weight-buffer-size 536870912 " # 512MB 66 | ) 67 | 68 | ci_args = ( 69 | "--ci-test " 70 | "--ci-disable-kl-checker " 71 | "--ci-metric-checker-key eval/gsm8k " 72 | "--ci-metric-checker-threshold 0.71 " # loose threshold at 60 step 73 | ) 74 | 75 | misc_args = "--actor-num-nodes 1 " "--actor-num-gpus-per-node 2 " "--colocate " "--train-backend fsdp " 76 | 77 | train_args = ( 78 | f"{ckpt_args} " 79 | f"{rollout_args} " 80 | f"{optimizer_args} " 81 | f"{grpo_args} " 82 | f"{sglang_args} " 83 | f"{U.get_default_wandb_args(__file__)} " 84 | f"{eval_args} " 85 | f"{fsdp_args} " 86 | f"{ci_args} " 87 | f"{misc_args} " 88 | ) 89 | 90 | U.execute_train( 91 | train_args=train_args, 92 | num_gpus_per_node=2, 93 | megatron_model_type=None, 94 | 
) 95 | 96 | 97 | if __name__ == "__main__": 98 | prepare() 99 | os.environ.pop("http_proxy") 100 | os.environ.pop("https_proxy") 101 | os.environ.pop("HTTP_PROXY") 102 | os.environ.pop("HTTPS_PROXY") 103 | execute() 104 | -------------------------------------------------------------------------------- /train_async.py: -------------------------------------------------------------------------------- 1 | import ray 2 | 3 | from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models 4 | from slime.utils.arguments import parse_args 5 | from slime.utils.logging_utils import configure_logger 6 | from slime.utils.misc import should_run_periodic_action 7 | from slime.utils.tracking_utils import init_tracking 8 | 9 | 10 | # The framework supports other asynchronous approaches such as fully async (which is shown in examples/full_async). 11 | def train(args): 12 | assert not args.colocate, "Colocation is not supported for async training." 13 | configure_logger() 14 | # allocate the GPUs 15 | pgs = create_placement_groups(args) 16 | init_tracking(args) 17 | 18 | # create the rollout manager, with sglang engines inside. 19 | # need to initialize rollout manager first to calculate num_rollout 20 | rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"]) 21 | 22 | # create the actor and critic models 23 | actor_model, critic_model = create_training_models(args, pgs, rollout_manager) 24 | 25 | # always update weight first so that sglang has the loaded weights from training. 26 | actor_model.update_weights() 27 | 28 | if args.check_weight_update_equal: 29 | ray.get(rollout_manager.check_weights.remote(action="compare")) 30 | 31 | # async train loop. 32 | rollout_data_next_future = rollout_manager.generate.remote(args.start_rollout_id) 33 | for rollout_id in range(args.start_rollout_id, args.num_rollout): 34 | # Sync the last generation 35 | if rollout_data_next_future is not None: 36 | rollout_data_curr_ref = ray.get(rollout_data_next_future) 37 | 38 | # Start the next rollout early. 
39 | if rollout_id + 1 < args.num_rollout: 40 | rollout_data_next_future = rollout_manager.generate.remote(rollout_id + 1) 41 | 42 | if args.use_critic: 43 | critic_train_handle = critic_model.async_train(rollout_id, rollout_data_curr_ref) 44 | if rollout_id >= args.num_critic_only_steps: 45 | ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref)) 46 | ray.get(critic_train_handle) 47 | else: 48 | ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref)) 49 | 50 | if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout): 51 | actor_model.save_model(rollout_id, force_sync=rollout_id == args.num_rollout - 1) 52 | if args.use_critic: 53 | critic_model.save_model(rollout_id, force_sync=rollout_id == args.num_rollout - 1) 54 | if args.rollout_global_dataset: 55 | ray.get(rollout_manager.save.remote(rollout_id)) 56 | 57 | if (rollout_id + 1) % args.update_weights_interval == 0: 58 | # sync generation before updating weights, so weights are never updated in the middle of a generation 59 | rollout_data_curr_ref = ray.get(x) if (x := rollout_data_next_future) is not None else None 60 | rollout_data_next_future = None 61 | actor_model.update_weights() 62 | 63 | if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch): 64 | ray.get(rollout_manager.eval.remote(rollout_id)) 65 | 66 | ray.get(rollout_manager.dispose.remote()) 67 | 68 | 69 | if __name__ == "__main__": 70 | args = parse_args() 71 | train(args) 72 | -------------------------------------------------------------------------------- /docs/zh/examples/qwen3-next-80B-A3B.md: -------------------------------------------------------------------------------- 1 | # 8xH100 训练 Qwen3-Next-80B-A3B 2 | 3 | ## 环境准备 4 | 5 | 搭建环境、下载模型、数据与 ckpt 转换均与 Qwen3-4B 模型相同,可以参考 [示例:Qwen3-4B](./qwen3-4B.md),将文中 Qwen3-4B 的部分转换为 6 | Qwen3-next-80B-A3B-Instruct 即可。 7 | 8 | 可以用如下完整方法把 huggingface checkpoint 转化为 torch_dist 格式: 9 | 10 | ```bash 11 | export BASE_FOLDER=./models/ 12 | # 下载模型权重 (Qwen3-Next-80B-A3B-Thinking) 13 | hf download Qwen/Qwen3-Next-80B-A3B-Thinking --local-dir ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking 14 | ``` 15 | 16 | ```shell 17 | cd slime/ 18 | pip install -e . 19 | 20 | # (for acceleration) 21 | cd .. # and find a proper folder 22 | git clone https://github.com/fla-org/flash-linear-attention 23 | cd flash-linear-attention 24 | git checkout 9714c595 25 | pip install -e . 26 | 27 | wget https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.4/causal_conv1d-1.5.4+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl 28 | pip install ./causal_conv1d-1.5.4+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl 29 | ``` 30 | 31 | ## [Optional] Fix a bug in triton compilation on Blackwell (sm100) 32 | 33 | See the discussion here https://github.com/triton-lang/triton/issues/8695 34 | and https://github.com/fla-org/flash-linear-attention/issues/638 35 | 36 | We need to apply a patch to fix the bug. 
37 | Go to the flash-linear-attention folder you just installed, and apply the following patch: 38 | 39 | ```diff 40 | diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py 41 | index c5119dcf..838f5e4e 100644 42 | --- a/fla/ops/gated_delta_rule/wy_fast.py 43 | +++ b/fla/ops/gated_delta_rule/wy_fast.py 44 | @@ -198,7 +198,14 @@ def prepare_wy_repr_bwd_kernel( 45 | b_A += tl.dot(b_kb, tl.trans(b_k)) 46 | b_dkb = tl.dot(b_dA, b_k) 47 | b_db += tl.sum(b_dkb * b_k, 1) 48 | - b_dk += tl.dot(tl.trans(b_dA), b_kb) 49 | + b_dk += tl.inline_asm_elementwise( 50 | + asm="mov.f32 $0, $1;", 51 | + constraints="=r,r", 52 | + args=[tl.dot(tl.trans(b_dA), b_kb)], 53 | + dtype=tl.float32, 54 | + is_pure=True, 55 | + pack=1, 56 | + ) 57 | b_dk += b_dkb * b_b[:, None] 58 | tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) 59 | tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) 60 | 61 | ``` 62 | 63 | save it as `patch.diff` (Please remember to copy the last empty line to the file!) and do `git apply patch.diff` 64 | 65 | ## 执行训练 (Megatron) 66 | 67 | **当前暂不支持Blackwell** 68 | 69 | 转换模型权重: 70 | 71 | ```bash 72 | source scripts/models/qwen3-next-80B-A3B.sh 73 | PYTHONPATH=/root/Megatron-LM/ torchrun --nproc-per-node 8 \ 74 | tools/convert_hf_to_torch_dist.py \ 75 | ${MODEL_ARGS[@]} \ 76 | --hf-checkpoint /root/Qwen3-Next-80B-A3B-Thinking/ \ 77 | --save /root/Qwen3-Next-80B-A3B-Thinking_torch_dist/ 78 | ``` 79 | 80 | 单机8卡 81 | 82 | ```bash 83 | cd /root/slime 84 | export BASE_FOLDER=/root 85 | export MASTER_ADDR=127.0.0.1 86 | bash scripts/run-qwen3-next-80B-A3B-8gpus.sh 87 | ``` 88 | 如果显存不够,考虑disable `--accumulate-allreduce-grads-in-fp32`,enable `--grad-reduce-in-bf16` 89 | 90 | 91 | 多机(4x8) 92 | 93 | ```bash 94 | cd /root/slime 95 | export BASE_FOLDER=/root 96 | export MASTER_ADDR=your_master_addr 97 | bash scripts/run-qwen3-next-80B-A3B.sh 98 | ``` 99 | 100 | ## 执行训练 (FSDP) 101 | 102 | ```bash 103 | export BASE_FOLDER=./models/ 104 | export MASTER_ADDR=127.0.0.1 105 | 106 | bash scripts/run-qwen3-next-80B-A3B-fsdp.sh 107 | ``` 108 | 109 | --------------------------------------------------------------------------------
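As an optional pre-flight check before launching either the Megatron or the FSDP scripts above, the snippet below verifies that the acceleration dependencies installed during environment setup are importable. It assumes the packages expose their usual module names (`fla` for flash-linear-attention and `causal_conv1d` for causal-conv1d); this is an assumption, so adjust the names if your builds differ.

```python
# Minimal import check for the optional acceleration dependencies used by Qwen3-Next training.
# Assumption: flash-linear-attention installs the `fla` module and causal-conv1d installs
# `causal_conv1d`; edit the module list below if your environment differs.
import importlib

for module_name in ("torch", "fla", "causal_conv1d"):
    try:
        module = importlib.import_module(module_name)
        version = getattr(module, "__version__", "unknown version")
        print(f"{module_name}: OK ({version})")
    except ImportError as exc:
        print(f"{module_name}: MISSING ({exc})")
```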