├── .github
├── ISSUE_TEMPLATE
│ ├── bug_report.md
│ └── feature_request.md
└── workflows
│ ├── build.yml
│ ├── pylint.yml
│ ├── python-publish.yml
│ └── unit_test.yml
├── .gitignore
├── .pylintrc
├── LICENSE
├── Makefile
├── README.md
├── README_CN.md
├── chatlearn
├── __init__.py
├── algorithm
│ ├── base_algo.py
│ └── grpo.py
├── chatlearn.py
├── checkpoint
│ ├── __init__.py
│ └── checkpoint_manager.py
├── configs
│ └── common.py
├── data
│ ├── __init__.py
│ ├── data.py
│ └── sampler.py
├── hooks.py
├── launcher
│ ├── __init__.py
│ ├── dlc_utils.py
│ └── initialize.py
├── models
│ ├── __init__.py
│ ├── base_module.py
│ ├── fsdp_module.py
│ ├── megatron
│ │ ├── __init__.py
│ │ ├── memory_manager
│ │ │ ├── __init__.py
│ │ │ ├── base_trainer.py
│ │ │ ├── inference.py
│ │ │ ├── trainer_v1v2.py
│ │ │ ├── trainer_v3.py
│ │ │ ├── trainer_v4.py
│ │ │ └── trainer_v5.py
│ │ └── ops
│ │ │ ├── __init__.py
│ │ │ └── policy_gradient.py
│ ├── megatron_module.py
│ ├── patches
│ │ ├── monkey_patch.py
│ │ └── transformers
│ │ │ ├── qwen2_patch.py
│ │ │ └── qwen3_patch.py
│ ├── torch_module.py
│ ├── vllm
│ │ ├── __init__.py
│ │ ├── hooks
│ │ │ ├── __init__.py
│ │ │ └── vllm_0_8_5
│ │ │ │ ├── __init__.py
│ │ │ │ ├── async_llm_engine.py
│ │ │ │ ├── llm.py
│ │ │ │ ├── llm_engine.py
│ │ │ │ ├── logits_processor.py
│ │ │ │ ├── ray_distributed_executor.py
│ │ │ │ └── worker_base.py
│ │ └── inference.py
│ └── vllm_module.py
├── runtime
│ ├── __init__.py
│ ├── decorator.py
│ ├── dist_actor.py
│ ├── engine.py
│ ├── environment.py
│ ├── evaluator.py
│ ├── executor.py
│ ├── model_flow.py
│ ├── trainer.py
│ └── utils.py
├── schedule
│ ├── __init__.py
│ ├── metric_manager.py
│ ├── model_manager.py
│ ├── port_manager.py
│ └── resource_manager.py
├── synchronizer
│ ├── __init__.py
│ ├── base.py
│ ├── megatron_megatron.py
│ ├── megatron_vllm.py
│ ├── parameter_sync.py
│ ├── parameter_sync_fsdp.py
│ └── scheduler.py
├── tools
│ ├── check_parameter_sync.py
│ ├── megatron_checkpoint_utils.py
│ └── megatron_to_hf.py
└── utils
│ ├── __init__.py
│ ├── arguments.py
│ ├── communication_op.py
│ ├── constant.py
│ ├── dist_utils.py
│ ├── error_monitor.py
│ ├── flat_tensors.py
│ ├── future.py
│ ├── global_vars.py
│ ├── log_monitor.py
│ ├── logger.py
│ ├── megatron_import_helper.py
│ ├── megatron_import_hook_helper.py
│ ├── megatron_import_memory_helper.py
│ ├── megatron_import_transformer_helper.py
│ ├── megatron_utils.py
│ ├── timer.py
│ ├── utils.py
│ ├── version.py
│ └── vllm_utils.py
├── docker
└── torch
│ ├── Dockerfile.torch2.3.0
│ ├── Dockerfile.torch2.5.1.vllm066
│ └── Dockerfile.torch2.6.0.vllm085
├── docs
├── .gitignore
├── Makefile
├── README.md
├── en
│ ├── .readthedocs.yaml
│ ├── Makefile
│ ├── advanced.rst
│ ├── api
│ │ ├── config.rst
│ │ ├── engine.rst
│ │ ├── index.rst
│ │ └── module.rst
│ ├── chatlearn.md
│ ├── conf.py
│ ├── config_yaml.md
│ ├── faq.md
│ ├── index.rst
│ ├── installation.md
│ ├── programming.md
│ └── tutorial
│ │ ├── check_sync_parameter.md
│ │ ├── continue_train.md
│ │ ├── custom_model_flow.md
│ │ ├── data.md
│ │ ├── ems.md
│ │ ├── evaluator.md
│ │ ├── profile.md
│ │ ├── run.md
│ │ ├── tutorial_grpo_fsdp.md
│ │ ├── tutorial_grpo_mcore.md
│ │ ├── tutorial_llama2.md
│ │ └── tutorial_qwen.md
├── images
│ ├── arch.png
│ ├── class.png
│ ├── dlc_1.jpg
│ ├── dlc_2.jpg
│ ├── engine.jpg
│ ├── engine_class.png
│ ├── fault.png
│ ├── logo.jpg
│ ├── perf.png
│ ├── rlhf.png
│ └── yaml.jpg
├── requirements.txt
└── zh
│ ├── .readthedocs.yaml
│ ├── Makefile
│ ├── advanced.rst
│ ├── api
│ ├── config.rst
│ ├── engine.rst
│ ├── index.rst
│ └── module.rst
│ ├── chatlearn.md
│ ├── conf.py
│ ├── config_yaml.md
│ ├── faq.md
│ ├── index.rst
│ ├── installation.md
│ ├── programming.md
│ └── tutorial
│ ├── continue_train.md
│ ├── custom_model_flow.md
│ ├── data.md
│ ├── ems.md
│ ├── evaluator.md
│ ├── profile.md
│ ├── run.md
│ ├── tutorial_grpo_fsdp.md
│ ├── tutorial_grpo_mcore.md
│ ├── tutorial_llama2.md
│ └── tutorial_qwen.md
├── examples
├── __init__.py
├── fsdp
│ ├── configs
│ │ └── grpo
│ │ │ ├── base.yaml
│ │ │ ├── grpo.yaml
│ │ │ ├── log.yaml
│ │ │ ├── policy_trainer.yaml
│ │ │ ├── reference.yaml
│ │ │ └── vllm_policy_inference.yaml
│ ├── data
│ │ ├── data_preprocess
│ │ │ ├── gsm8k.py
│ │ │ └── math_lighteval.py
│ │ └── prompt_dataset.py
│ ├── entry
│ │ └── train_grpo.py
│ ├── models
│ │ ├── grpo
│ │ │ ├── __init__.py
│ │ │ ├── loss_gallery.py
│ │ │ └── policy_trainer.py
│ │ ├── rule_reward.py
│ │ └── vllm_policy_inference.py
│ ├── scripts
│ │ ├── base_env.sh
│ │ ├── train_grpo_qwen2_5.sh
│ │ └── train_grpo_qwen3.sh
│ └── utils
│ │ ├── __init__.py
│ │ └── rule_reward_score
│ │ ├── __init__.py
│ │ └── math.py
├── mcore
│ ├── configs
│ │ └── grpo
│ │ │ ├── base.yaml
│ │ │ ├── grpo_qwen2_5.yaml
│ │ │ ├── grpo_qwen3.yaml
│ │ │ ├── log.yaml
│ │ │ ├── model_qwen2_5.yaml
│ │ │ ├── model_qwen3.yaml
│ │ │ ├── policy_trainer_qwen2_5.yaml
│ │ │ ├── policy_trainer_qwen3.yaml
│ │ │ └── vllm_policy_inference.yaml
│ ├── entry
│ │ └── train_grpo.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── policy_model.py
│ │ ├── policy_trainer.py
│ │ ├── train_helper.py
│ │ └── utils.py
│ ├── scripts
│ │ ├── train_grpo_qwen2_5.sh
│ │ └── train_grpo_qwen3.sh
│ └── tokenizer
│ │ └── __init__.py
└── tests
│ ├── barrier.py
│ ├── benchmark_vllm.py
│ └── benchmark_vllm.sh
├── requirements.txt
├── setup.py
└── tests
├── __init__.py
├── base
├── test_exp.py
├── test_multi_dataloader.py
└── test_send_recv.py
├── configs
├── base.yaml
├── eval.yaml
├── exp1.yaml
├── exp2.yaml
├── grpo.yaml
├── model.yaml
├── o1.yaml
├── parameter_sync.yaml
├── rlhf.yaml
├── rlhf2.yaml
├── rlhf_cpu.yaml
├── rlhf_eval.yaml
├── sprl.yaml
├── test_eval.yaml
└── test_eval2.yaml
├── parameter_sync
├── __init__.py
├── test_balanced_hep.py
├── test_hep_ep_vllm_tp.py
├── test_hep_eppp_vllm_tp.py
├── test_hep_eppp_vllm_tppp.py
├── test_hep_eptp_vllm_tp.py
├── test_hep_eptp_vllm_tp_2.py
├── test_hep_eptppp_vllm_tp.py
├── test_hep_eptppp_vllm_tp_2.py
├── test_hep_eptppp_vllm_tppp.py
├── test_hep_tp_vllm_tp.py
└── test_unbalanced_tp.py
├── rlhf
├── __init__.py
├── test_ckpt.py
├── test_data.py
├── test_grpo.py
├── test_indivisible_batchsz.py
├── test_model_flow.py
├── test_placement.py
├── test_placement_colocate.py
├── test_relay_buffer.py
├── test_rlhf_custom.py
├── test_rlhf_placement_colocate.py
└── test_rlhf_replica.py
├── run_tests.sh
├── test_main.py
├── unittests
├── test_flat_tensors.py
├── test_sampler.py
├── test_scheduler.py
└── test_utils.py
└── utils.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: "[BUG]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 |
16 | **Expected behavior**
17 | A clear and concise description of what you expected to happen.
18 |
19 | **Screenshots**
20 | If applicable, add screenshots to help explain your problem.
21 |
22 | **Additional context**
23 | Add any other context about the problem here.
24 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: "[Feature]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Daily Building Script Execution
2 |
3 | on:
4 | workflow_dispatch:
5 | schedule:
6 | # Runs at 00:30 every day (UTC+8)
7 | - cron: '30 16 * * *'
8 |
9 | jobs:
10 | run-shell-script:
11 | runs-on: self-hosted
12 |
13 | steps:
14 | - name: Checkout code
15 | uses: actions/checkout@v3
16 |
17 | - name: Run shell script
18 | run: |
19 | tar_name=chatlearn-$(date +%F).tar.gz
20 | tar czvf /tmp/${tar_name} .
21 | ossutil64 -i ${{ secrets.OSS_AK_ID }} -k ${{ secrets.OSS_AK_SECRET }} -e ${{ secrets.OSS_ENDPOINT }} cp -r /tmp/${tar_name} ${{ secrets.OSS_BUCKET }}/regression/chatlearn/src/
22 |
23 |
--------------------------------------------------------------------------------
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
1 | name: Pylint
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches:
7 | - main
8 |
9 | jobs:
10 | build:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: ["3.10"]
15 | steps:
16 | - uses: actions/checkout@v4
17 | - name: Set up Python ${{ matrix.python-version }}
18 | uses: actions/setup-python@v3
19 | with:
20 | python-version: ${{ matrix.python-version }}
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install pylint==2.16.1
25 | - name: Analysing the code with pylint
26 | run: |
27 | make lint
28 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 |
23 | steps:
24 | - uses: actions/checkout@v4
25 | - name: Set up Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: '3.x'
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 | - name: Build package
34 | run: python -m build
35 | - name: Publish package
36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 | with:
38 | user: __token__
39 | password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/.github/workflows/unit_test.yml:
--------------------------------------------------------------------------------
1 | name: Unit Tests
2 |
3 | on:
4 | workflow_dispatch:
5 | pull_request:
6 | branches:
7 | - main
8 | - dev
9 | paths-ignore:
10 | - 'docs/**'
11 | push:
12 | branches:
13 | - main
14 | paths-ignore:
15 | - 'docs/**'
16 | tags:
17 | - v[0-9]+.[0-9]+.[0-9]+
18 |
19 | concurrency:
20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
21 | cancel-in-progress: true
22 |
23 | jobs:
24 | run-shell-script:
25 | runs-on: self-hosted
26 |
27 | steps:
28 |
29 | - name: Checkout code
30 | uses: actions/checkout@v4
31 |
32 | - name: Run unit test
33 | run: |
34 | containers=$(docker ps -aqf "name=^chatlearn_ut_")
35 | if [[ -n "$containers" ]]; then
36 | docker rm $containers -f
37 | fi
38 | docker pull $UT_IMAGE
39 | docker run -v $PWD:$PWD -w $PWD --name chatlearn_ut_$(date '+%d_%m_%Y_%H_%M_%S') --net host --ipc host --shm-size 80G -t --rm --gpus all $UT_IMAGE bash -c 'make test'
40 | env:
41 | UT_IMAGE: dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312
42 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | examples/megatron/step3_rlhf/None
2 | .ipynb_checkpoints/
3 | output/
4 | build/
5 | dist/
6 | *.diff
7 | *.pyc
8 | *.idea
9 | *.DS_Store
10 | *.swp
11 | .nfs*
12 | *.dot
13 | *.egg-info/
14 | docs/zh/api/doc/*.rst
15 | *._*py
16 | .vscode/
17 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 | # Disable the message, report, category or checker with the given id(s). You
3 | # can either give multiple identifiers separated by comma (,) or put this
4 | # option multiple times (only on the command line, not in the configuration
5 | # file where it should appear only once). You can also use "--disable=all" to
6 | # disable everything first and then reenable specific checks. For example, if
7 | # you want to run only the similarities checker, you can use "--disable=all
8 | # --enable=similarities". If you want to run only the classes checker, but have
9 | # no Warning level messages displayed, use "--disable=all --enable=classes
10 | # --disable=W"
11 | disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,superfluous-parens,consider-using-f-string,invalid-name,missing-function-docstring,inconsistent-return-statements,logging-fstring-interpolation,no-else-return,broad-exception-raised,broad-exception-caught,protected-access,too-many-lines
12 |
13 | # Maximum number of characters on a single line.
14 | max-line-length=150
15 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | PYTHON ?= python3
2 | ADDITIONAL_DEPS ?=
3 | current_dir := $(shell pwd | sed 's:/*$$::')
4 |
5 | .PHONY: build
6 | build: $(LIB)
7 | $(PYTHON) setup.py bdist_wheel --universal
8 | @printf "\033[0;32mPIP package built\033[0m: "
9 | @ls dist/*.whl
10 |
11 | .PHONY: test
12 | test: $(LIB)
13 | cd tests; bash run_tests.sh
14 |
15 | .PHONY: lint
16 | lint:
17 | git config --global --add safe.directory $(current_dir)
18 | @$(PYTHON) -m pip install pylint==2.16.1 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
19 | @$(PYTHON) -m pylint \
20 | --rcfile=.pylintrc --output-format=parseable --jobs=8 \
21 | $(shell git ls-tree --full-tree --name-only -r HEAD chatlearn | grep \.py$) \
22 | $(shell git diff --cached --name-only chatlearn | grep \.py$) \
23 | $(shell git ls-tree --full-tree --name-only -r HEAD examples/megatron/ | grep \.py$) \
24 | $(shell git diff --cached --name-only examples/megatron/ | grep \.py$) \
25 |
26 | .PHONY: doc
27 | doc:
28 | cd docs; make html
29 |
30 |
31 | .DEFAULT_GOAL := lint
32 |
--------------------------------------------------------------------------------
/README_CN.md:
--------------------------------------------------------------------------------
1 |
2 | [文档](https://chatlearn.readthedocs.io/zh-cn/latest/)
3 | [License](https://github.com/alibaba/ChatLearn/blob/main/LICENSE)
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | 灵活、易用、高效的大语言模型(LLMs)强化学习训练框架
13 |
14 |
15 | English  |  中文 
16 |
17 |
18 | ---
19 |
20 | *最新进展* 🔥
21 | - [2025/5] 训练支持Mcore框架!基于Mcore和vLLM,我们提供了Qwen2.5模型的端到端GRPO训练[教程](docs/zh/tutorial/tutorial_grpo_mcore.md)!🔥
22 | - [2025/5] 训练支持FSDP框架!基于FSDP和vLLM,我们提供了Qwen3模型的端到端GRPO训练[教程](docs/zh/tutorial/tutorial_grpo_fsdp.md)!🔥
23 | - [2024/8] 正式开源 ChatLearn,更多介绍请参考我们的 [文档](docs/zh/chatlearn.md)。
24 |
25 | ---
26 |
27 | ChatLearn 是阿里云PAI团队开发的大规模LLMs强化学习训练框架。ChatLearn 通过对模型计算逻辑的抽象,解耦了模型和计算 backend、分布式策略的绑定,提供灵活的资源调度机制,可以支持灵活的资源分配和并行调度策略。
28 |
29 | 
30 |
31 | ChatLearn的特点如下:
32 | 1. 🚀**易用的编程接口**: ChatLearn提供通用的编程抽象,用户只需要封装几个函数即可完成模型构造。用户只需要专注于单模型的编程,系统负责资源调度、数据流传输、控制流传输、分布式执行等。
33 | 2. 🔧**高可扩展的训练方式**: ChatLearn 提供 RLHF、DPO、OnlineDPO、GRPO 等强化学习训练,同时也支持用户自定义 model 的执行 flow,使定制化训练流程变得非常便捷。
34 | 3. 🔄**多种分布式加速引擎**: 用户可以使用不同的计算 backend 进行模型建模,如 Megatron-LM、DeepSpeed、vLLM 等。用户也可以组合使用不同的 backend,如用 Megatron-LM 来进行加速训练,用 vLLM 来加速推理。
35 | 4. 🎯**灵活的并行策略和资源分配**: ChatLearn 支持不同模型配置不同的并行策略,可以结合各模型计算、显存、通信的特点来制定不同的并行策略。同时 ChatLearn 支持灵活的资源调度机制,支持各模型的资源独占或复用,通过系统调度策略支持高效的串行/并行执行和高效的显存共享。
36 | 5. ⚡**高性能**: 相较于当前的 SOTA 系统,ChatLearn 在 7B+7B (Policy+Reward) 规模性能提升52%,70B+70B 规模性能提升 137%。同时,ChatLearn 支持更大规模的 Alignment 训练,例如:300B+300B。
37 |
38 | # 快速开始
39 |
40 | 请参考 [文档](https://chatlearn.readthedocs.io/zh-cn/latest/) 快速开始。
41 |
42 | 1. [环境和代码准备](docs/zh/installation.md)
43 | 2. [基于 FSDP + vLLM的Qwen3模型端到端GRPO训练流程](docs/zh/tutorial/tutorial_grpo_fsdp.md)
44 | 3. [基于 LLaMA/LLaMA2 模型的端到端训练教程](docs/zh/tutorial/tutorial_llama2.md)
45 |
46 |
47 | # 性能评估
48 |
49 | 我们比较了不同参数量规模模型的 RLHF 训练吞吐量,我们采取 N+N 的模型配置,即 Policy 模型和 Reward 模型采用相同大小的参数量。我们和 DeepSpeed-Chat、OpenRLHF 对比了 7B 和 70B 的模型配置,在 8 GPUs 7B+7B 规模,有 115% 的加速,在 32 GPUs 70B+70B 规模,有 208% 的加速。规模越大,加速效果越明显。同时ChatLearn还能支持更大规模的 Alignment 训练,例如:300B+300B 规模。
50 |
51 |
52 | ![性能对比](docs/images/perf.png)
53 |
54 | 注:DeepSpeed-Chat和OpenRLHF性能已经优化过。
55 |
56 | # 功能列表
57 |
58 | - 支持 RLHF、DPO、OnlineDPO、GRPO 以及用户自定义的RL训练;
59 | - 支持 Megatron-LM,FSDP 作为训练的 backend,支持 vLLM 作为推理的 backend;
60 | - 支持 各模型独立配置并行策略,并支持模型间高效参数同步,自动进行并行策略转换;
61 | - 支持 EMS(Efficient Memory Sharing) 功能,支持模型间显存高效共享;
62 | - 支持模型的资源类型:GPU、CPU,例如定义纯 CPU 的 Math Reward 模型;
63 | - 支持 Megatron-Core 格式模型;
64 |
65 | # Roadmap
66 |
67 | ChatLearn 接下来会支持以下特性:
68 | - [ ] 简化参数配置
69 | - [ ] 提供MoE模型强化学习训练的教程
70 | - [ ] 支持更多的模型
71 | - [ ] 性能优化
72 | - [ ] 支持更多的强化学习算法
73 |
74 |
75 | 我们欢迎社区小伙伴参与进来合作开发,也欢迎加入钉钉群:98090003312 参与讨论。我们在持续招聘中,欢迎联系我们或者投递简历到[email](mailto:wanglin.zj@alibaba-inc.com)。
76 |
--------------------------------------------------------------------------------
/chatlearn/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """init"""
16 |
17 | import importlib
18 |
19 | from chatlearn import hooks
20 | from chatlearn.launcher.initialize import init
21 | from chatlearn.models.base_module import BaseModule
22 | from chatlearn.models.megatron_module import MegatronModule
23 | from chatlearn.models.torch_module import TorchModule
24 | from chatlearn.models.fsdp_module import FSDPModule
25 | from chatlearn.runtime.engine import Engine, RLHFEngine
26 | from chatlearn.runtime.engine import Environment
27 | from chatlearn.runtime.engine import Trainer
28 | from chatlearn.runtime.evaluator import Evaluator
29 | from chatlearn.runtime.model_flow import ControlDependencies
30 | from chatlearn.utils.future import get
31 | from chatlearn.utils.global_vars import get_args
32 | from chatlearn.utils.logger import logger
33 |
--------------------------------------------------------------------------------
/chatlearn/algorithm/base_algo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """base algorithm"""
16 |
17 | from abc import ABC, abstractmethod
18 |
19 | class BaseAlgorithm(ABC):
20 | """Abstract interface for an algorithm: concrete subclasses must implement ``run``."""
21 |
22 | @abstractmethod
23 | def run(self):
24 | """
25 | Run the algorithm. Abstract: no behavior is defined here, subclasses provide it.
26 | """
27 |
--------------------------------------------------------------------------------
/chatlearn/checkpoint/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/chatlearn/configs/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """common configs"""
16 |
17 | from dataclasses import dataclass, field
18 | from omegaconf import MISSING
19 |
20 |
21 | @dataclass
22 | class RuntimeEnvConfig:
23 | """Runtime environment settings; currently a single ``platform`` field (default "DLC")."""
24 |
25 | platform: str = field(
26 | default="DLC",
27 | metadata={"help": "Platform to run the model. Default is DLC."}
28 | )
29 |
30 |
31 | @dataclass
32 | class BaseModelConfig:
33 | """Base class for the model configs in this module; holds the shared random ``seed``."""
34 |
35 | seed: int = field(
36 | default=1234,
37 | metadata={"help": "Random seed. Default is 1234."}
38 | )
39 |
40 |
41 | @dataclass
42 | class PolicyConfig(BaseModelConfig):
43 | """Policy model config: GPU count and a ``trainable`` flag, plus the inherited ``seed``."""
44 |
45 | num_gpus: int = field(
46 | default=1,
47 | metadata={"help": "Number of GPUs to use. Default is 1."}
48 | )
49 | trainable: bool = field(
50 | default=False,
51 | metadata={"help": "Whether the policy is trainable. Default is False."}
52 | )
53 |
54 |
55 | @dataclass
56 | class RewardConfig(BaseModelConfig):
57 | """RewardConfig"""
58 |
59 | num_cpus: int = field(
60 | default=2,
61 | metadata={"help": "Number of CPUs to use. Default is 1."}
62 | )
63 |
64 |
65 | @dataclass
66 | class RefPolicyConfig(BaseModelConfig):
67 | """Reference policy config: adds ``fsdp_size`` (default -1) to the inherited ``seed``."""
68 |
69 | fsdp_size: int = field(
70 | default=-1,
71 | metadata={"help": "FSDP size. Default is -1."}
72 | )
73 |
74 |
75 | @dataclass
76 | class PolicyTrainerConfig(BaseModelConfig):
77 | """Policy trainer config: adds the ``free_memory`` flag (default True) to the inherited ``seed``."""
78 |
79 | free_memory: bool = field(
80 | default=True,
81 | metadata={"help": "Whether to free memory. Default is True."}
82 | )
83 |
84 |
85 | @dataclass
86 | class RuntimeConfig:
87 | """Runtime config: modules to colocate, plus train and eval data paths that default to ``MISSING``."""
88 |
89 | colocation: list[str] = field(
90 | default_factory=list,
91 | metadata={"help": "List of modules to colocate. Default is empty."}
92 | )
93 | data_path: str = field(
94 | default=MISSING,
95 | metadata={"help": "Path to the data file. Required."}
96 | )
97 | eval_data_path: str = field(
98 | default=MISSING,
99 | metadata={"help": "Path to the evaluation data file. Required."}
100 | )
101 |
--------------------------------------------------------------------------------
/chatlearn/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/chatlearn/hooks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """hooks"""
16 |
17 | from chatlearn.models.vllm import hooks as vllm_hooks # pylint: disable=unused-import
18 |
--------------------------------------------------------------------------------
/chatlearn/launcher/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/chatlearn/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Module related."""
16 |
--------------------------------------------------------------------------------
/chatlearn/models/megatron/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """megatron"""
16 |
--------------------------------------------------------------------------------
/chatlearn/models/megatron/memory_manager/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Memory manager for Megatron modules which provides utilities to free memory when unused."""
16 |
17 | from chatlearn.models.megatron.memory_manager.base_trainer import create_trainer_memory_manager
18 | from chatlearn.models.megatron.memory_manager.inference import InferenceMemoryManager
19 |
--------------------------------------------------------------------------------
/chatlearn/models/megatron/memory_manager/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Inference memory manager for Megatron."""
16 | from typing import Optional, List
17 |
18 | from chatlearn.utils.flat_tensors import BucketizedFlatTensors
19 | from chatlearn.utils.logger import log_rank_0
20 | from chatlearn.utils.megatron_import_helper import DistributedDataParallel
21 |
22 |
class InferenceMemoryManager:
    """
    Offload and onload the weights of a (non-DistributedDataParallel) inference model.

    The model parameters are grouped by dtype and each group is wrapped in a
    ``BucketizedFlatTensors`` whose primary store device is ``'cpu'``. The groups
    are built lazily on the first offload and reused afterwards.
    """

    def __init__(self, model, bucket_size_mb=0):
        """
        Args:
            model: the module whose parameters are managed; must not be a
                ``DistributedDataParallel`` instance.
            bucket_size_mb: bucket size forwarded to ``BucketizedFlatTensors``.
        """
        self._model = model

        assert not isinstance(
            model, (DistributedDataParallel,)
        ), f'Only support model type non-DistributedDataParallel, current type is {str(type(model))}.'

        self._bucket_size_mb = bucket_size_mb
        # Built on the first call to offload_weights().
        self._group_flat_weights: Optional[List[BucketizedFlatTensors]] = None
        self._weights_offloaded = False

    def offload_weights(self):
        """
        Copy every weight group to its primary store; a no-op (with a log line)
        when the weights are already offloaded.
        """
        if self._weights_offloaded:
            log_rank_0('Call offload_weights when already offloaded. Ignore it.')
            return

        if self._group_flat_weights is None:
            # Group parameters by dtype: one flat-tensor wrapper per dtype.
            params_by_dtype = {}
            for param in self._model.parameters():
                params_by_dtype.setdefault(param.dtype, []).append(param)

            self._group_flat_weights = []
            for same_dtype_params in params_by_dtype.values():
                flat_group = BucketizedFlatTensors(
                    same_dtype_params, primary_store_device='cpu', bucket_size_mb=self._bucket_size_mb
                )
                self._group_flat_weights.append(flat_group)

        for flat_group in self._group_flat_weights:
            flat_group.copy_to_primary_store()

        self._weights_offloaded = True

    def onload_weights(self):
        """
        Copy every weight group back to its GPU buffer; a no-op (with a log line)
        when the weights are not currently offloaded.
        """
        if not self._weights_offloaded:
            log_rank_0('Call onload_weights when already onloaded. Ignore it.')
            return

        for flat_group in self._group_flat_weights:
            flat_group.copy_to_gpu_buffer()

        self._weights_offloaded = False
78 |
--------------------------------------------------------------------------------
/chatlearn/models/megatron/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """ops."""
16 |
--------------------------------------------------------------------------------
/chatlearn/models/patches/monkey_patch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Apply patches for different model architectures"""
def apply_sp_monkey_patch(model_config):
    """Register the sequence-parallel attention forward matching the model architecture.

    Args:
        model_config: object exposing ``architectures``; its first entry selects the patch.

    Raises:
        ValueError: if the first architecture is neither ``Qwen2ForCausalLM``
            nor ``Qwen3ForCausalLM``.
    """
    print(f"apply sequence parallel patches for {model_config.architectures}")
    architecture = model_config.architectures[0]

    # Reject unknown architectures up front so the import branches below stay flat.
    if architecture not in ("Qwen2ForCausalLM", "Qwen3ForCausalLM"):
        raise ValueError(f"Unsupported model architecture: {model_config.architectures}")

    # The patch modules are imported lazily, only for the architecture in use.
    if architecture == "Qwen2ForCausalLM":
        from chatlearn.models.patches.transformers.qwen2_patch import register_sp_attention_forward \
            # pylint: disable=import-outside-toplevel
    else:
        from chatlearn.models.patches.transformers.qwen3_patch import register_sp_attention_forward \
            # pylint: disable=import-outside-toplevel
    register_sp_attention_forward()
28 |
--------------------------------------------------------------------------------
/chatlearn/models/vllm/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """vLLM related."""
16 |
--------------------------------------------------------------------------------
/chatlearn/models/vllm/hooks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """vLLM Hooks."""
16 |
# `import importlib` alone does not guarantee that the `importlib.util`
# submodule is loaded; import it explicitly so `find_spec` is always available.
import importlib.util
import os
import warnings


# Install the version-specific hooks only when vLLM is importable.
if importlib.util.find_spec("vllm"):

    from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion

    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_8_5:
        from .vllm_0_8_5 import *
    else:
        # Adjacent string literals are used instead of a backslash continuation
        # so that source indentation does not leak into the message.
        raise RuntimeError(
            f"vLLM version expected in {list(member.value for member in VLLMVersion)}, while {CURRENT_VLLM_VERSION}. "
            "if you want to use previous vllm version, please git checkout 4ad5912306df5d4a814dc2dd5567fcb26f5d473b")
32 |
--------------------------------------------------------------------------------
/chatlearn/models/vllm/hooks/vllm_0_8_5/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Additional hooks of vllm-0.8.5."""
16 |
17 | from . import async_llm_engine
18 | from . import llm
19 | from . import llm_engine
20 | from . import ray_distributed_executor
21 | from . import worker_base
22 | from . import logits_processor
23 |
--------------------------------------------------------------------------------
/chatlearn/models/vllm/hooks/vllm_0_8_5/async_llm_engine.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """del init_ray_cluster in AsyncLLMEngine."""
16 |
17 | from typing import Dict, Optional
18 |
19 | # pylint: disable=unused-import,wildcard-import,unused-argument,not-callable
20 | from vllm.config import VllmConfig
21 | from vllm.engine import async_llm_engine
22 | from vllm.engine.arg_utils import AsyncEngineArgs
23 | from vllm.engine.metrics_types import StatLoggerBase
24 | from vllm.usage.usage_lib import UsageContext
25 |
@classmethod
def from_engine_args(cls,
                     engine_args: AsyncEngineArgs,
                     engine_config: Optional[VllmConfig] = None,
                     start_engine_loop: bool = True,
                     usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
                     stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
                     ) -> "AsyncLLMEngine":
    """Build an async LLM engine from engine arguments.

    Args:
        engine_args: arguments used to derive the engine config and logging flags.
        engine_config: ready-made config; when ``None`` it is created from
            ``engine_args`` and ``usage_context``.
        start_engine_loop: forwarded to the engine constructor.
        usage_context: forwarded to config creation and to the constructor.
        stat_loggers: forwarded to the engine constructor.

    Returns:
        The engine instance created by ``cls``.

    NOTE(review): per the module docstring this replacement exists to drop the
    ``init_ray_cluster`` step of the upstream implementation - confirm against vLLM.
    """
    vllm_config = engine_config
    if vllm_config is None:
        # No config supplied: derive one from the engine arguments.
        vllm_config = engine_args.create_engine_config(usage_context)

    return cls(
        vllm_config=vllm_config,
        executor_class=cls._get_executor_cls(vllm_config),
        log_requests=not engine_args.disable_log_requests,
        log_stats=not engine_args.disable_log_stats,
        start_engine_loop=start_engine_loop,
        usage_context=usage_context,
        stat_loggers=stat_loggers,
    )


# Replace the upstream constructor helper with the version above.
async_llm_engine.AsyncLLMEngine.from_engine_args = from_engine_args
54 |
--------------------------------------------------------------------------------
/chatlearn/models/vllm/hooks/vllm_0_8_5/llm_engine.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Hooks of vllm-0.8.5 llm_engine remove __reduce__ function."""
16 |
17 | import inspect
18 | from typing import Dict, Optional
19 |
20 | # pylint: disable=unused-import,wildcard-import,unused-argument,wrong-import-order
21 | from chatlearn.utils.vllm_utils import vllm_use_v1
22 | from vllm.engine.metrics_types import StatLoggerBase
23 | from vllm.usage.usage_lib import UsageContext
# Select the engine module that matches the vLLM engine generation in use.
if vllm_use_v1():
    from vllm.v1.engine import llm_engine
else:
    from vllm.engine import llm_engine
    # Look at the upstream ``__reduce__`` source to decide whether it has to be removed.
    source = inspect.getsource(llm_engine.LLMEngine.__reduce__)
    if 'RuntimeError' in source:
        # NOTE(review): this local ``__reduce__`` is never attached to the class in
        # this module; only the ``del`` below takes effect - confirm that is intended.
        def __reduce__(self):
            # This is to ensure that the LLMEngine can be referenced in
            # the closure used to initialize Ray worker actors
            pass

        # Remove the class-level ``__reduce__`` whose source mentions RuntimeError,
        # so attribute lookup falls back to the inherited implementation.
        del llm_engine.LLMEngine.__reduce__
36 |
37 |
@classmethod
def from_engine_args(
    cls,
    engine_args,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
    """Build an LLM engine from engine arguments, always using the Ray distributed executor.

    Args:
        engine_args: arguments used to create the engine config and the ``log_stats`` flag.
        usage_context: forwarded to config creation and to the constructor.
        stat_loggers: forwarded to the engine constructor.

    Returns:
        The engine instance created by ``cls``.
    """
    vllm_config = engine_args.create_engine_config(usage_context)

    # The executor class lives in a different module for the v1 engine.
    if vllm_use_v1():
        from vllm.v1.executor.ray_distributed_executor import RayDistributedExecutor # pylint: disable=import-outside-toplevel
    else:
        from vllm.executor.ray_distributed_executor import RayDistributedExecutor # pylint: disable=import-outside-toplevel

    return cls( # pylint: disable=not-callable
        vllm_config=vllm_config,
        executor_class=RayDistributedExecutor,
        log_stats=not engine_args.disable_log_stats,
        usage_context=usage_context,
        stat_loggers=stat_loggers,
    )


# Replace the upstream constructor helper with the version above.
llm_engine.LLMEngine.from_engine_args = from_engine_args
65 |
--------------------------------------------------------------------------------
/chatlearn/models/vllm/hooks/vllm_0_8_5/logits_processor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Hooks of vllm-0.8.5 logits_processor to allgather logits of all ranks."""
16 |
17 | import inspect
18 |
19 | # pylint: disable=wildcard-import,ungrouped-imports
20 | from vllm.model_executor.layers import logits_processor
21 |
22 |
# Patch only when the installed ``_gather_logits`` still refers to
# ``tensor_model_parallel_gather``; otherwise leave the upstream method alone.
source = inspect.getsource(logits_processor.LogitsProcessor._gather_logits)
if 'tensor_model_parallel_gather' in source:
    import torch
    def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor: # pylint: disable=unused-argument
        """Return ``tensor_model_parallel_all_gather(logits)``.

        NOTE(review): per the module docstring the intent is that every rank ends
        up with the gathered logits - confirm against the vLLM distributed API.
        """
        from vllm.distributed import tensor_model_parallel_all_gather # pylint: disable=import-outside-toplevel
        return tensor_model_parallel_all_gather(logits)

    # Install the all-gather variant on the upstream class.
    logits_processor.LogitsProcessor._gather_logits = _gather_logits
31 |
--------------------------------------------------------------------------------
/chatlearn/models/vllm/hooks/vllm_0_8_5/worker_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Hooks of vllm-0.8.5 worker_base to update execute_method."""
16 |
17 | # pylint: disable=unused-import,wildcard-import
18 | from typing import Union
19 | from vllm.worker import worker_base
20 | from vllm.worker.worker_base import logger
21 | from vllm.utils import run_method
22 |
23 |
# Remove the wrapper's ``__getattr__`` so that attribute lookups on the wrapper
# are no longer redirected; ``execute_method`` below picks the target explicitly.
del worker_base.WorkerWrapperBase.__getattr__


def execute_method(self, method: Union[str, bytes], *args, **kwargs):
    """Run ``method`` on the wrapper itself or on the wrapped worker.

    Target resolution:
      * no worker yet (``self.worker is None``): the wrapper;
      * ``method`` is a string naming an attribute of the wrapper: the wrapper;
      * anything else (including non-string methods): ``self.worker``.

    Args:
        method: method name, or a non-string method object passed on to ``run_method``.
        *args: positional arguments passed on to ``run_method``.
        **kwargs: keyword arguments passed on to ``run_method``.

    Returns:
        Whatever ``run_method`` returns.

    Raises:
        Exception: any error raised while resolving or running the method is
            logged and re-raised unchanged.
    """
    try:
        # `hasattr` only accepts string names and raises TypeError for anything
        # else (e.g. bytes), so the name check is limited to strings; every
        # other method is dispatched to the worker.
        if self.worker is None or (isinstance(method, str) and hasattr(self, method)):
            target = self
        else:
            target = self.worker
        return run_method(target, method, args, kwargs)
    except Exception:
        # if the driver worker also execute methods,
        # exceptions in the rest worker may cause deadlock in rpc like ray
        # see https://github.com/vllm-project/vllm/issues/3455
        # print the error and inform the user to solve the error
        msg = (f"Error executing method {method!r}. "
               "This might cause deadlock in distributed execution.")
        logger.exception(msg)
        raise


worker_base.WorkerWrapperBase.execute_method = execute_method
50 |
--------------------------------------------------------------------------------
/chatlearn/models/vllm/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Inference memory manager for Megatron."""
16 | from typing import Optional, List
17 |
18 | from chatlearn.utils.flat_tensors import BucketizedFlatTensors
19 | from chatlearn.utils.logger import log_rank_0
20 |
21 |
class InferenceMemoryManager:
    """
    Offload and onload the weights of an inference model.

    Parameters are grouped by dtype; each group is wrapped in a
    ``BucketizedFlatTensors`` with ``'cpu'`` as primary store device. The groups
    are created on the first offload and reused by later calls.
    """

    def __init__(self, model, bucket_size_mb=0):
        """
        Args:
            model: the module whose parameters are managed.
            bucket_size_mb: bucket size forwarded to ``BucketizedFlatTensors``.
        """
        self._model = model
        self._bucket_size_mb = bucket_size_mb
        # Created lazily by offload_weights().
        self._group_flat_weights: Optional[List[BucketizedFlatTensors]] = None
        self._weights_offloaded = False

    def offload_weights(self):
        """
        Copy all weight groups to their primary store. Logs and returns
        immediately if the weights are already offloaded.
        """
        if self._weights_offloaded:
            log_rank_0('Call offload_weights when already offloaded. Ignore it.')
            return

        if self._group_flat_weights is None:
            # Bucket the parameters per dtype, one wrapper per dtype.
            grouped = {}
            for weight in self._model.parameters():
                grouped.setdefault(weight.dtype, []).append(weight)

            self._group_flat_weights = []
            for dtype_params in grouped.values():
                wrapper = BucketizedFlatTensors(
                    dtype_params, primary_store_device='cpu', bucket_size_mb=self._bucket_size_mb
                )
                self._group_flat_weights.append(wrapper)

        for wrapper in self._group_flat_weights:
            wrapper.copy_to_primary_store()

        self._weights_offloaded = True

    def onload_weights(self):
        """
        Copy all weight groups back to their GPU buffers. Logs and returns
        immediately if the weights are not offloaded.
        """
        if not self._weights_offloaded:
            log_rank_0('Call onload_weights when already onloaded. Ignore it.')
            return

        for wrapper in self._group_flat_weights:
            wrapper.copy_to_gpu_buffer()

        self._weights_offloaded = False
72 |
--------------------------------------------------------------------------------
/chatlearn/runtime/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/chatlearn/runtime/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """runtime utils"""
16 |
17 | import ast
18 | import textwrap
19 | import inspect
20 | from collections import defaultdict
21 |
def encode_data(mb, data):
    """Pack ``mb`` and ``data`` into a dict under the keys ``"iter"`` and ``"data"``."""
    packed = {}
    packed["iter"] = mb
    packed["data"] = data
    return packed
24 |
25 |
def decode_data(data):
    """Unpack a dict produced by ``encode_data`` into the tuple ``(iter, data)``."""
    return data["iter"], data["data"]
30 |
31 |
def parse_assign_target(line):
    """Return the ``id`` of every target of the assignment node ``line``, in order."""
    return [target.id for target in line.targets]
37 |
38 |
def parse_expr(line):
    """Split a method-call statement into method name, model name and argument names.

    Args:
        line: an ``ast.Assign`` or ``ast.Expr`` node whose value is a call of the
            form ``model.method(a, b)`` or ``obj.model.method(a, b)``.

    Returns:
        ``(func_name, model_name, func_args)`` where ``func_args`` lists the
        ``id`` of each positional argument.
    """
    call = line.value
    func = call.func
    func_name = func.attr
    func_args = [arg.id for arg in call.args]
    receiver = func.value
    # `model.method(...)` has a plain name as receiver; for `obj.model.method(...)`
    # the last attribute of the receiver names the model.
    model_name = receiver.id if isinstance(receiver, ast.Name) else receiver.attr
    return func_name, model_name, func_args


class FlowParser:
    """Flow Parser: collects, per model object, the methods a flow function calls on it."""

    def __init__(self):
        # model object -> list of method names called on it, in source order
        self.model_to_call_funcs = defaultdict(list)
        # name -> object mapping used to resolve the model names found in the source
        self.global_models = {}

    def parse_assign(self, line):
        """Record the method called by ``line`` under the model it is called on.

        Raises:
            KeyError: if the model name is not present in ``global_models``.
        """
        func_name, model_name, _ = parse_expr(line)
        model = self.global_models[model_name]
        self.model_to_call_funcs[model].append(func_name)

    def _parse_statements(self, statements):
        """Record every call statement in ``statements``.

        Assignments and expression statements whose value is not a call
        (docstrings, constants, ...) contain no model call and are skipped.
        """
        for line in statements:
            if isinstance(line, (ast.Assign, ast.Expr)) and isinstance(line.value, ast.Call):
                self.parse_assign(line)

    def visit_func(self, node):
        """Scan the body of a function definition, looking one level into ``with`` blocks."""
        for line in node.body:
            statements = line.body if isinstance(line, ast.With) else (line,)
            self._parse_statements(statements)

    def parse(self, func):
        """Parse a flow function and return the model -> called-methods mapping.

        Args:
            func: a function object, or a string holding the function source.
                For a function object the names are resolved against its globals
                and nonlocals. A source string has no closure, so names are
                resolved against the current ``global_models`` mapping, which
                the caller may fill in beforehand.

        Returns:
            ``self.model_to_call_funcs``, which accumulates across calls.
        """
        if isinstance(func, str):
            # `inspect.getclosurevars` only accepts functions, so it must not
            # be called for source strings.
            code = textwrap.dedent(func)
        else:
            closure_vars = inspect.getclosurevars(func)
            # nonlocals take precedence over globals of the same name
            self.global_models = {**closure_vars.globals, **closure_vars.nonlocals}
            code = textwrap.dedent(inspect.getsource(func))
        node_iter = ast.NodeVisitor()
        node_iter.visit_FunctionDef = self.visit_func
        node_iter.visit(ast.parse(code))
        return self.model_to_call_funcs
85 |
--------------------------------------------------------------------------------
/chatlearn/schedule/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/chatlearn/schedule/port_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """port manager"""
16 |
17 | from multiprocessing import Lock
18 | import ray
19 |
@ray.remote
class PortManager:
    """port manager: hands out ports from a fixed list, tracked separately per address."""

    def __init__(self, port_list):
        """
        Args:
            port_list: the ports that may be handed out for each address.
        """
        self._lock = Lock()
        # address -> index into the port list of the next port to hand out
        self._address_to_port_index = {}
        self._port_list = port_list

    def get_free_port(self, address):
        """Return the next unused port for ``address`` and advance its index.

        Raises:
            AssertionError: if every port in the list was already handed out
                for this address.
        """
        with self._lock:
            next_index = self._address_to_port_index.get(address, 0)
            assert next_index < len(self._port_list)
            self._address_to_port_index[address] = next_index + 1
            return self._port_list[next_index]
40 |
--------------------------------------------------------------------------------
/chatlearn/synchronizer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """synchronizer"""
16 |
17 | from transformers import AutoConfig
18 | from chatlearn.models.megatron_module import MegatronModule
19 | from chatlearn.models.vllm_module import VLLMModule
20 | from chatlearn.runtime.dist_actor import DistModel
21 | from .base import BaseSync
22 | from .megatron_megatron import MegatronMegatronSync
23 | from .megatron_vllm import(
24 | MegatronVllmQWenSync,
25 | MegatronVllmQWen2Sync,
26 | MegatronVllmLlamaSync,
27 | MegatronVllmMoonlightSync,
28 | MegatronVllmQWen2MCoreSync
29 | )
30 |
def get_synchronizer(src_model, dst_model):
    """Pick the synchronizer matching the source and destination model types.

    The decision is made on the model of the first replica of each ``DistModel``:

    * Megatron -> Megatron: ``MegatronMegatronSync``.
    * Megatron -> vLLM: chosen from the first architecture listed in the
      config loaded from the destination's ``tokenizer`` path.
    * anything else: ``BaseSync``.

    Args:
        src_model: ``DistModel`` providing the parameters.
        dst_model: ``DistModel`` receiving the parameters.

    Returns:
        A synchronizer instance built from the two underlying models.

    Raises:
        RuntimeError: for a Megatron -> vLLM pair whose architecture is not supported.
    """
    assert isinstance(src_model, DistModel)
    assert isinstance(dst_model, DistModel)
    src_model = src_model.replicas[0].model
    dst_model = dst_model.replicas[0].model
    if isinstance(src_model, MegatronModule) and isinstance(dst_model, MegatronModule):
        return MegatronMegatronSync(src_model, dst_model)
    elif isinstance(src_model, MegatronModule) and isinstance(dst_model, VLLMModule):
        config_dir = dst_model.module_args.args_dict["tokenizer"]
        config = AutoConfig.from_pretrained(config_dir, trust_remote_code=True)
        model_class_name = config.architectures[0]
        if model_class_name == "QWenLMHeadModel":
            return MegatronVllmQWenSync(src_model, dst_model)
        elif model_class_name in ["Qwen2ForCausalLM", "Qwen2MoeForCausalLM"]:
            # NOTE: check if the model is mcore or not
            if src_model.module_args.args_dict.get("use_legacy_models", True):
                return MegatronVllmQWen2Sync(src_model, dst_model)
            return MegatronVllmQWen2MCoreSync(src_model, dst_model)
        elif model_class_name == "LlamaForCausalLM":
            return MegatronVllmLlamaSync(src_model, dst_model)
        elif model_class_name in ["DeepseekV3ForCausalLM", "Qwen3ForCausalLM"]:
            return MegatronVllmMoonlightSync(src_model, dst_model)
        else:
            # Keep this list in step with the branches above.
            raise RuntimeError(
                f"Unsupported model {model_class_name}, Expect QWenLMHeadModel, Qwen2ForCausalLM, "
                "Qwen2MoeForCausalLM, LlamaForCausalLM, DeepseekV3ForCausalLM or Qwen3ForCausalLM.")
    else:
        return BaseSync(src_model, dst_model)
58 |
--------------------------------------------------------------------------------
/chatlearn/synchronizer/megatron_megatron.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """megatron to megatron synchronizer"""
16 |
17 | from chatlearn.utils import future
18 | from .base import BaseSync
19 |
class MegatronMegatronSync(BaseSync):
    """megatron to megatron synchronizer"""

    def _get_dst_name(self, src_name, src_prefix, dst_prefix):
        """Translate one name: strip ``src_prefix`` when set, otherwise prepend ``dst_prefix``."""
        if not src_prefix:
            return dst_prefix + src_name
        return src_name[len(src_prefix):]

    def set_model_prefix(self, src_names, dst_names):
        """Find the prefix by which source and destination names differ.

        The first pair where one name contains the other decides the result.

        Returns:
            ``(None, dst_prefix)`` when a source name is contained in a destination
            name, or ``(src_prefix, None)`` in the opposite case.

        Raises:
            RuntimeError: if no pair of names contains one another.
        """
        for sname in src_names:
            for dname in dst_names:
                if sname in dname:
                    # destination name = prefix + source name
                    return None, dname[:dname.index(sname)]
                if dname in sname:
                    # source name = prefix + destination name
                    return sname[:sname.index(dname)], None
        raise RuntimeError("Cannot find prefix")

    def map_name_from_src_to_dst(self, send_actor, recv_actor, src_names, dst_names):
        """Rewrite the entries of ``dst_names`` using the detected prefix.

        The prefix is detected by comparing ``src_names`` with the parameter
        names reported by ``recv_actor``.

        Returns:
            ``(src_names, mapped_dst_names)``.
        """
        recv_names = future.get(recv_actor.get_parameter_names.remote(requires_grad=False))
        src_prefix, dst_prefix = self.set_model_prefix(src_names, recv_names)
        mapped_names = [
            self._get_dst_name(name, src_prefix, dst_prefix)
            for name in dst_names
        ]
        return src_names, mapped_names
52 |
--------------------------------------------------------------------------------
/chatlearn/synchronizer/parameter_sync_fsdp.py:
--------------------------------------------------------------------------------
1 | """fsdp to vllm parameter sync group"""
2 | import ray
3 | from chatlearn.utils import future
4 | from chatlearn.runtime.dist_actor import DistModel
5 | from chatlearn.utils.error_monitor import ErrorSignalActor
6 |
7 |
def flatten(lst: list, reverse=False):
    """Concatenate the sub-lists of ``lst`` into one new list.

    When ``reverse`` is true, every sub-list is reversed before it is appended;
    the order of the sub-lists themselves is kept.
    """
    flat = []
    for chunk in lst:
        flat.extend(chunk[::-1] if reverse else chunk)
    return flat
16 |
17 |
class FSDP2VllmParameterSyncGroup:
    """Parameter sync group that pushes FSDP weights to vLLM actors."""

    def __init__(
        self,
        src_model: DistModel,
        dst_model: DistModel,
        group_name: str,
        frequency: int,
        error_signal: ErrorSignalActor,
    ):
        self.src_model, self.dst_model = src_model, dst_model
        self.group_name = group_name
        self.error_signal = error_signal
        self.frequency = frequency

        self.setup_collective_group()

    def setup_collective_group(self):
        """Assign consecutive rank offsets to every replica of both models."""
        next_rank = 0
        # src_model goes first, so the ranks of the training model stay unchanged
        for model in (self.src_model, self.dst_model):
            for replica in model.replicas:
                replica._setup_ranks(next_rank)
                next_rank += replica.actor_num

    def sync(self, *args, **kwargs): # pylint: disable=unused-argument
        """
        Sync every FSDP parameter from the source actors to the paired vLLM actors.

        Source ranks are paired position by position with the destination ranks;
        each destination sub-list is reversed first. One parameter is sent at a
        time, and all pairs must finish before the next parameter starts.
        """
        # for fsdp to vllm, we only need the src and dst actors that are on the same GPU
        src_ranks = flatten(self.src_model.all_ranks)
        # adapt for model manager: models_to_revert
        dst_ranks = flatten(self.dst_model.all_ranks, reverse=True)
        rank_pairs = list(zip(src_ranks, dst_ranks))

        param_names = ray.get(self.src_model.get_actor(0).get_fsdp_param_name.remote())

        for param_name in param_names:
            pending = []
            for src_rank, dst_rank in rank_pairs:
                sender = self.src_model.get_actor(src_rank)
                receiver = self.dst_model.get_actor(dst_rank)
                handles_ref = sender.get_weight_ipc_handles_by_name.remote(param_name)
                pending.append(receiver.update_weights_from_ipc_handles.remote(handles_ref))
            future.wait(pending, return_output=True)
68 |
--------------------------------------------------------------------------------
/chatlearn/synchronizer/scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """
16 | CollectiveTaskScheduler uses two queues to schedule collective tasks:
17 | - TodoQueue maintains all tasks that are yet to be executed
18 | - PendingQueue maintains the tasks being executed
19 |
20 | This scheduler reorders the remote tasks to avoid collective operator hangs.
21 | """
22 | from queue import Queue
23 | import concurrent
24 | import traceback
25 | from concurrent.futures import ThreadPoolExecutor
26 |
27 | from chatlearn.utils.logger import logger
28 |
class CollectiveTask:
    """CollectiveTask represents a group of actors to execute a collective task"""
    def __init__(self, actors, group):
        # collective_task_scheduler treats actors[0] as the sender and actors[1:] as the receivers
        self.actors = actors
        # group identifier; parallel_execute_collective_tasks includes it in its log line
        self.group = group
34 |
def collective_task_scheduler(tasks):
    """Yield batches of tasks that share no actor and can therefore run together.

    Each round walks the remaining tasks in order. A task joins the current
    batch when neither its sender (``actors[0]``) nor any of its receivers
    (``actors[1:]``) is already used by a task in that batch; otherwise it is
    deferred to the next round, keeping its relative order.
    """
    remaining = list(tasks)

    while remaining:
        busy_actors = set()
        batch = []
        deferred = []
        for task in remaining:
            send = task.actors[0]
            recvs = task.actors[1:]
            if send in busy_actors or any(recv in busy_actors for recv in recvs):
                # conflicts with a task already in this batch: retry next round
                deferred.append(task)
            else:
                batch.append(task)
                busy_actors.add(send)
                busy_actors.update(recvs)
        remaining = deferred
        if batch:
            yield batch
62 |
def parallel_execute_collective_tasks(tasks, submit_func):
    """Run ``submit_func`` on every task, one conflict-free batch at a time.

    Batches come from ``collective_task_scheduler``; the tasks of one batch are
    submitted to a thread pool sized to the batch, and the next batch starts
    only after the pool has shut down.

    Args:
        tasks: iterable of tasks exposing ``actors`` and ``group``.
        submit_func: callable invoked with a single task.

    Raises:
        RuntimeError: if ``submit_func`` raises for any task; the original
            exception is attached as ``__cause__``.
    """
    scheduler = collective_task_scheduler(tasks)
    for parallel_tasks in scheduler:
        logger.info(f"DEBUG parallel_execute_tasks: {[task.group for task in parallel_tasks]}")
        with ThreadPoolExecutor(max_workers=len(parallel_tasks)) as executor:
            futures = [executor.submit(submit_func, task) for task in parallel_tasks]
            for _future in concurrent.futures.as_completed(futures):
                try:
                    _future.result()
                except Exception as e:
                    traceback.print_exc()
                    # chain explicitly so the root cause stays attached to the error
                    raise RuntimeError(f"ParameterSync warmup failed: {e}") from e
75 |
--------------------------------------------------------------------------------
/chatlearn/tools/check_parameter_sync.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Check ParameterSync"""
16 |
17 | import argparse
18 | import os
19 | import torch
20 |
def chatlearn_compare(expected_dir, actural_dir):
    """Compare dumped tensors under ``actural_dir`` with those under ``expected_dir``.

    Both directories are read as ``<dir>/<tp_rank>/<param>``. Only files found
    under ``actural_dir`` are visited. One line is printed per file
    (``PASS``, ``DIFF`` or ``NOT_EXISTS``), followed by a summary line with the
    totals.

    Args:
        expected_dir: directory holding the reference tensors.
        actural_dir: directory holding the tensors to check.
    """
    total = 0
    diff = 0
    not_exists = 0
    for tp_rank in os.listdir(actural_dir):
        for param in os.listdir(os.path.join(actural_dir, tp_rank)):
            actual_fname = os.path.join(actural_dir, tp_rank, param)
            expected_fname = os.path.join(expected_dir, tp_rank, param)
            message = f"{tp_rank}|{param}"
            total += 1
            if not os.path.exists(expected_fname):
                print(f"NOT_EXISTS|{message}|NOT_EXISTS", flush=True)
                not_exists += 1
                continue
            ta = torch.load(actual_fname, map_location="cpu")
            tb = torch.load(expected_fname, map_location="cpu")
            # Check shapes first: allclose raises on shapes that cannot be broadcast
            # and would silently broadcast the rest; either way it is a difference.
            if ta.shape != tb.shape or not torch.allclose(ta, tb):
                diff += 1
                print(f"DIFF|{message}|{ta.shape}|{ta.mean()}|{tb.shape}|{tb.mean()}", flush=True)
            else:
                print(f"PASS|{message}")
    print(f"ALL: {total}, DIFF: {diff}, NOT_EXISTS: {not_exists}")
42 |
43 |
if __name__ == "__main__":
    # CLI entry point: compare the parameter dumps found under --root_dir.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--root_dir",
        type=str,
        required=True,
        help="Root dir to check the dumped parameters",
    )
    root_dir = arg_parser.parse_args().root_dir
    # The "before" dump is passed as the expected side, the "after" dump as the actual side.
    chatlearn_compare(
        os.path.join(root_dir, "before_sync_paramter"),
        os.path.join(root_dir, "after_sync_paramter"),
    )
55 |
--------------------------------------------------------------------------------
/chatlearn/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """utils"""
16 |
17 | from .utils import to_device
18 |
--------------------------------------------------------------------------------
/chatlearn/utils/constant.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """constants."""
16 |
17 | import importlib
18 | from enum import Enum
19 |
# Regroup

# String tags; presumably used as keys when regrouping/indexing data - confirm against callers.
CHATLEARN_REGROUP_TAG = "chatlearn_regroup_tag"
INDEX_TAG = "data_index"

# NOTE(review): looks like a marker/prefix for ChatLearn log lines - confirm.
LOG_START = "chatlearn_log"

# Names of fused query/key/value layers.
QKV_LAYER_NAME = ["query_key_value"]
28 |
29 |
# vLLM version
# ``import importlib`` alone does not guarantee that the ``importlib.util``
# submodule is loaded, so import it explicitly before calling ``find_spec``.
import importlib.util  # pylint: disable=wrong-import-position

# "major.minor.patch" of the installed vLLM, or None when vLLM is not installed.
CURRENT_VLLM_VERSION = None
if importlib.util.find_spec("vllm"):
    import vllm
    if hasattr(vllm, "__version_tuple__"):
        # keep only the first three components of the version tuple
        version_tuple = vllm.__version_tuple__
        CURRENT_VLLM_VERSION = ".".join([str(ele) for ele in version_tuple[:3]])
    else:
        CURRENT_VLLM_VERSION = vllm.__version__
39 |
40 |
class VLLMVersion(str, Enum):
    """Supported vLLM versions; each member's value is the version string."""
    v_0_8_5 = "0.8.5"
44 |
45 |
class QwenVersion(float, Enum):
    """Qwen versions; each member's value is the version number as a float."""
    v_1 = 1.0
    v_2 = 2.0
50 |
51 |
class RAY_PG_STRATEGY(Enum):
    """Ray placement group strategy; each value is the strategy name as a string."""
    PACK = "PACK"
    SPREAD = "SPREAD"
56 |
57 |
class PARAM_SYNC_COMM_TYPE(str, Enum):
    """Communication type used for parameter sync (string-valued)."""
    BROADCAST = "broadcast"
    P2P = "p2p"
62 |
63 |
class ROUTED_EXPERT_REGROUPING_COMM_TYPE(str, Enum):
    """Communication type used for routed expert regrouping (string-valued)."""
    ALLTOALL = "alltoall"
    ALLGATHER = "allgather"
68 |
--------------------------------------------------------------------------------
/chatlearn/utils/error_monitor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Error monitor"""
16 |
17 | import time
18 |
19 | import ray
20 | import ray.util.collective as col
21 |
22 | from chatlearn.utils import future
23 |
24 |
@ray.remote
class ErrorMonitor:
    """Ray actor that polls an error signal and tears everything down once it is set."""

    def __init__(self, error_signal, remote_models, group_names):
        self.error_signal = error_signal
        self.remote_models = remote_models
        self.collective_groups = group_names

    def monitor(self):
        """Block until the error signal is set, then clean up and raise.

        The signal is polled every 2 seconds; a failed poll counts as "not set".
        Once it is set, every collective group is destroyed, every remote model
        is terminated, and an ``Exception`` carrying the recorded message and
        addresses is raised.
        """
        error_seen = False
        while not error_seen:
            try:
                error_seen = future.get(self.error_signal.is_set.remote())
            except Exception:
                error_seen = False
            if not error_seen:
                time.sleep(2)

        for name in self.collective_groups:
            col.destroy_collective_group(name)
        for remote_model in self.remote_models:
            remote_model.terminate()

        error_msg = future.get(self.error_signal.error_msg.remote())
        error_address = future.get(self.error_signal.error_address.remote())
        raise Exception(f"Catch an exception in {error_address}, error msg: {error_msg}")
50 |
51 |
@ray.remote(num_cpus=0)
class ErrorSignalActor:
    """Ray actor holding a shared error flag, an error message and reporter addresses."""

    def __init__(self):
        self._address_list = []
        self.err_msg = None
        self.error_state = False

    def set(self, err_msg=None):
        """Raise the error flag; a non-None ``err_msg`` replaces the stored message."""
        self.error_state = True
        if err_msg is None:
            return
        self.err_msg = err_msg

    def set_address(self, address):
        """Record ``address`` unless it is already recorded."""
        if address in self._address_list:
            return
        self._address_list.append(address)

    def is_set(self):
        """Return the current error flag."""
        return self.error_state

    def error_msg(self):
        """Return the stored error message (None if none was given)."""
        return self.err_msg

    def error_address(self):
        """Return the list of recorded addresses."""
        return self._address_list
77 |
--------------------------------------------------------------------------------
/chatlearn/utils/future.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """get remote object"""
16 |
17 | import ray
18 |
19 | from chatlearn.utils.logger import logging_tqdm
20 | from chatlearn.utils.utils import flatten
21 |
22 |
def check_nested_2_level_list(refs):
    """
    Check whether ``refs`` is a list of lists of object refs (nesting level 2).

    e.g.
        [[ref0, ref1], [ref2, ref3]]     returns True, [2, 2]
        [ref0, ref1]                     returns False, None
        [[ref0], [ref1, ref2]]           returns True, [1, 2]

    Only the first element of each non-empty sub-list is inspected; an empty
    sub-list is accepted and contributes a length of 0.

    Returns a tuple containing two elements:
        - A boolean value indicating if the list is a nested 2-level list
        - A list with the length of each sub-list, or None when the check fails
    """
    lengths = []
    for entry in refs:
        if not isinstance(entry, list):
            return False, None
        if len(entry) > 0 and not isinstance(entry[0], ray.ObjectRef):
            return False, None
        lengths.append(len(entry))
    return True, lengths
48 |
49 |
def wait(refs, desc=None, return_output=False):
    """
    wait until all computation finish

    Args:
        refs: a single ``ray.ObjectRef``, a list of refs, or a 2-level nested list of refs.
        desc: progress-bar description; a progress bar is created only when it is not None.
        return_output: when true, return ``ray.get`` of the flattened refs.

    Returns:
        The fetched outputs when ``return_output`` is true and ``refs`` is a non-empty
        list; otherwise None (the single-ObjectRef and empty-list paths return early).
    """
    if isinstance(refs, ray.ObjectRef):
        ray.get(refs)
        return
    if len(refs) == 0:
        return
    nested2, sublist_lens = check_nested_2_level_list(refs)
    refs = flatten(refs)
    if desc is not None:
        # nested input: the bar counts sub-lists rather than individual refs
        total = len(refs) if not nested2 else len(sublist_lens)
        pbar = logging_tqdm(total=total, desc=desc)
    i = 0
    wait_refs = refs.copy()
    while wait_refs:
        # nested input: each ray.wait call asks for as many refs as the i-th sub-list holds
        # NOTE(review): an empty sub-list yields num_returns=0 - confirm ray.wait accepts that
        num_returns = 1 if not nested2 else sublist_lens[i]
        done, wait_refs = ray.wait(wait_refs, num_returns=num_returns)
        i += 1
        if desc is not None:
            done_size = len(done) if not nested2 else 1
            pbar.update(done_size)
    if return_output:
        outputs = ray.get(refs)
    if desc is not None:
        pbar.close()
    if return_output:
        return outputs
79 |
80 |
def get(data):
    """Resolve every ``ray.ObjectRef`` found in ``data`` and return the result.

    Dicts, lists and tuples are rebuilt with the same container type and their
    elements resolved recursively. Any other value is fetched with ``ray.get``
    for as long as it is still an ObjectRef, then returned.
    """
    if isinstance(data, dict):
        return {key: get(value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        container = type(data)
        return container(get(element) for element in data)
    resolved = data
    while isinstance(resolved, ray.ObjectRef):
        resolved = ray.get(resolved)
    return resolved
92 |
--------------------------------------------------------------------------------
/chatlearn/utils/global_vars.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """global vars."""
16 |
# Module-level state, read and written through the functions defined in this module.
_GLOBAL_ARGS = None  # set by set_global_variables(), read by get_args()
_EXIT_ACTOR = None
_EXIT_ACTOR_NAME = "ChatLearnExitActor"
_DECORATED_MODELS = None  # replaced by an empty set in set_global_variables()
_DECORATED_OUTER_TO_INNER = {}  # wrapper func -> wrapped func, filled by set_wrap_func()
_DEPENDENCIES = None  # managed by set_dependencies() / reset_dependencies()
_VLLM_ACTORS = None  # managed by set_vllm_actors() / get_vllm_actors()
24 |
25 |
def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None.

    Raises AssertionError with the message "<name> is not initialized." otherwise.
    """
    assert var is not None, '{} is not initialized.'.format(name)
29 |
def get_args():
    """Return arguments; asserts that they have been set first."""
    _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
    return _GLOBAL_ARGS
34 |
def set_global_variables(args):
    """Store ``args`` as the global arguments and reset the decorated-model set.

    ``args`` must not be None (checked with an assert).
    """
    global _GLOBAL_ARGS, _DECORATED_MODELS
    assert args is not None
    _GLOBAL_ARGS = args
    _DECORATED_MODELS = set()
42 |
def set_decorated(model_name):
    """Mark ``model_name`` as decorated by adding it to the decorated-model set."""
    _DECORATED_MODELS.add(model_name)
45 |
def is_decorated(model_name):
    """Return whether ``model_name`` was marked as decorated.

    Asserts that the decorated-model set has been initialized.
    """
    _ensure_var_is_initialized(_DECORATED_MODELS, 'decorated_models')
    found = model_name in _DECORATED_MODELS
    return bool(found)
49 |
def unwrap_func(func, level=None):
    """
    Follow the wrapper -> wrapped mapping starting at ``func``.

    func: func to unwrap
    level: maximum number of wrapper layers to strip; None unwraps down to the
        original func, a value <= 0 returns ``func`` unchanged
    """
    if func not in _DECORATED_OUTER_TO_INNER:
        # not a registered wrapper: nothing left to strip
        return func
    inner = _DECORATED_OUTER_TO_INNER[func]
    if level is None:
        return unwrap_func(inner, None)
    if level > 0:
        return unwrap_func(inner, level - 1)
    return func
63 |
def set_wrap_func(func, new_func):
    """Register ``new_func`` as a wrapper of ``func``.

    Asserts that ``new_func`` has not been registered as a wrapper before.
    """
    assert new_func not in _DECORATED_OUTER_TO_INNER
    _DECORATED_OUTER_TO_INNER[new_func] = func
67 |
def set_dependencies(dependencies):
    """Store ``dependencies`` globally; asserts that none are currently stored."""
    global _DEPENDENCIES
    assert _DEPENDENCIES is None
    _DEPENDENCIES = dependencies
72 |
def reset_dependencies():
    """Clear the stored dependencies, so set_dependencies() may be called again."""
    global _DEPENDENCIES
    _DEPENDENCIES = None
76 |
def get_dependencies():
    """Return the stored dependencies (None when unset)."""
    return _DEPENDENCIES
79 |
def set_vllm_actors(actors):
    """Store ``actors`` as the global vLLM actors, replacing any previous value."""
    global _VLLM_ACTORS
    _VLLM_ACTORS = actors
83 |
def get_vllm_actors():
    """Return the stored vLLM actors (None when unset)."""
    return _VLLM_ACTORS
86 |
--------------------------------------------------------------------------------
/chatlearn/utils/megatron_import_hook_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """"Version compatibility for hook"""
16 |
17 | # pylint: disable=unused-import,wildcard-import
18 |
19 | # megatron.text_generation.*
20 | try:
21 | from megatron.text_generation import generation
22 | from megatron.text_generation.generation import *
23 | from megatron.text_generation.generation import _build_attention_mask_and_position_ids
24 | from megatron.text_generation.generation import generate_tokens_probs_and_return_on_first_stage
25 | except ImportError:
26 | from megatron.inference.text_generation import generation
27 | from megatron.inference.text_generation.generation import *
28 | from megatron.inference.text_generation.generation import _build_attention_mask_and_position_ids
29 | from megatron.inference.text_generation.generation import generate_tokens_probs_and_return_on_first_stage
30 |
31 | # pylint: enable=unused-import,wildcard-import
32 |
--------------------------------------------------------------------------------
/chatlearn/utils/megatron_import_memory_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """
16 | Version compatibility utilities for Megatron memory management of gradients and parameter weights.
17 | Base on how Megatron uses buffers to manage memory, we support 3 different versions.
18 | """
19 | import os
20 |
21 | from enum import Enum, auto
22 | from typing import List
23 |
24 | __all__ = ['MegatronVersion', 'get_megatron_version', 'check_megatron_versions']
25 |
26 |
class MegatronVersion(Enum):
    """
    Megatron versions distinguished by this module (five members, V1 to V5).
    """

    V1 = auto()  # use `MemoryBuffer` to manage gradients
    V2 = auto()  # use `GradBuffer` to manage gradients
    V3 = auto()  # use `ParamAndGradBuffer` to manage parameter weights and gradients
    V4 = auto()  # for compatibility with temporary version for Qwen-MoE
    V5 = auto()  # use refactored `ParamAndGradBuffer` to manage parameter weights and gradients
37 |
38 |
def get_megatron_version():
    """Detect the Megatron version.

    ``QWEN_VERSION=qwen_moe_v1`` in the environment selects V4. Otherwise the
    buffer classes are probed from newest to oldest, and the first import that
    succeeds decides the version; V1 is the fallback.
    """
    # for compatibility with temporary version for Qwen-MoE
    if os.environ.get("QWEN_VERSION", '') == 'qwen_moe_v1':
        return MegatronVersion.V4

    try:
        # pylint: disable-next=import-outside-toplevel, unused-import
        from megatron.core.distributed.distributed_data_parallel import _ParamAndGradBuffer
    except ImportError:
        pass
    else:
        return MegatronVersion.V5

    try:
        # pylint: disable-next=import-outside-toplevel, unused-import
        from megatron.core.distributed import ParamAndGradBuffer
    except ImportError:
        pass
    else:
        return MegatronVersion.V3

    try:
        # pylint: disable-next=import-outside-toplevel, unused-import
        from megatron.core.distributed import GradBuffer
    except ImportError:
        pass
    else:
        return MegatronVersion.V2

    return MegatronVersion.V1
65 |
66 |
def check_megatron_versions(targets: List[MegatronVersion]):
    """Assert that the detected Megatron version is one of ``targets``."""
    version = get_megatron_version()
    assert version in targets, f'Different Megatron version {version} from targets: {targets}.'
70 |
71 |
# Detected once at import time; the conditional import below depends on it.
_version = get_megatron_version()

# pylint: disable=unused-import

# BufferType is imported, and added to __all__, only for V3 and V5.
if _version in [MegatronVersion.V3, MegatronVersion.V5]:
    from megatron.core.distributed.param_and_grad_buffer import BufferType

    __all__.append('BufferType')

# pylint: enable=unused-import
82 |
--------------------------------------------------------------------------------
/chatlearn/utils/megatron_import_transformer_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """"Version compatibility for hook"""
16 |
17 | # pylint: disable=unused-import,wildcard-import
18 |
19 | # megatron.model.transformer.*
20 | try:
21 | from megatron.model import transformer
22 | from megatron.model.transformer import ParallelAttention
23 | from megatron.model.transformer import *
24 | except ImportError:
25 | from megatron.legacy.model import transformer
26 | from megatron.legacy.model.transformer import ParallelAttention
27 | from megatron.legacy.model.transformer import *
28 |
29 | # pylint: enable=unused-import,wildcard-import
30 |
--------------------------------------------------------------------------------
/chatlearn/utils/version.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """version"""
16 |
# Version string of the package.
VERSION = "1.1.0"
18 |
--------------------------------------------------------------------------------
/docker/torch/Dockerfile.torch2.3.0:
--------------------------------------------------------------------------------
1 | # docker build -t your_docker_image -f Dockerfile.torch2.3.0 .
2 | FROM pytorch/pytorch:2.3.0-cuda12.1-cudnn8-devel
3 |
4 | LABEL com.nvidia.volumes.needed="nvidia_driver"
5 | LABEL com.nvidia.cuda.version=
6 | ENV NVIDIA_VISIBLE_DEVICES= \
7 | NVIDIA_REQUIRE_CUDA="cuda>=11.0" \
8 | LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib:/usr/local/cuda/lib64
9 |
10 | # install common libs
11 | RUN pip install --no-cache-dir -U \
12 | ray[default]==2.32.0 \
13 | transformers==4.42.0 \
14 | pynvml==11.4.1 \
15 | deepspeed==0.14.4 \
16 | vllm==0.5.1 \
17 | accelerate \
18 | jsonlines \
19 | torchtyping \
20 | tensorboard \
21 | cupy
22 |
23 | # install apex
24 | RUN apt-get update && apt-get install git vim -y
25 | WORKDIR /tmp/third_party
26 | RUN git clone https://github.com/NVIDIA/apex
27 | WORKDIR /tmp/third_party/apex
28 | RUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
29 | RUN rm -rf /tmp/third_party
30 |
31 | # install transformer engine v1.2.1
32 | RUN MAX_JOBS=4 pip install git+https://github.com/NVIDIA/TransformerEngine.git@v1.2.1
33 |
34 | # env
35 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64 \
36 | CUDA_DEVICE_MAX_CONNECTIONS=1
37 |
--------------------------------------------------------------------------------
/docker/torch/Dockerfile.torch2.5.1.vllm066:
--------------------------------------------------------------------------------
1 | # currently support fsdp
2 | FROM nvcr.io/nvidia/pytorch:24.10-py3
3 |
4 | ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
5 | ENV PIP_TRUSTED_HOST=mirrors.aliyun.com
6 |
7 | RUN pip install --no-cache-dir -U \
8 | opencv-python-headless==4.5.4.58 \
9 | vllm==0.6.6 \
10 | wandb==0.19.3 \
11 | ray[default]==2.40.0 \
12 | transformers==4.51.3 \
13 | modelscope==1.26.0 \
14 | datasets==3.6.0 \
15 | deepspeed==0.14.4 \
16 | grpcio==1.70.0 \
17 | setuptools==69.5.1
18 |
19 | RUN pip uninstall -y flash_attn && pip install -U flash_attn==2.4.2 --no-cache-dir --no-build-isolation
20 | RUN pip uninstall -y apex && pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/apex/torch2.5.1-cuda12x/apex-0.1-cp310-cp310-linux_x86_64.whl --no-cache-dir
21 | RUN pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/transformer_engine/torch2.5.1-cuda12x/transformer_engine-1.13.0%2Be5edd6c-cp310-cp310-linux_x86_64.whl --no-cache-dir
--------------------------------------------------------------------------------
/docker/torch/Dockerfile.torch2.6.0.vllm085:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:24.12-py3
2 | RUN unset NCCL_DEBUG
3 | ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
4 | ENV PIP_TRUSTED_HOST=mirrors.aliyun.com
5 |
6 | RUN pip install --no-cache-dir -U \
7 | vllm==0.8.5.post1 \
8 | wandb==0.19.3 \
9 | transformers==4.51.3 \
10 | modelscope==1.26.0 \
11 | datasets==3.6.0 \
12 | deepspeed==0.16.7 \
13 | grpcio==1.70.0 \
14 | nvidia-modelopt==0.27.0 \
15 | nvidia-modelopt-core==0.27.0 \
16 | ray[default]==2.46.0
17 |
18 | RUN pip install --no-cache-dir -U setuptools==69.5.1
19 |
20 | RUN pip uninstall -y flash_attn && pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/flash-attention/torch2.6.0-cu12x/flash_attn-2.4.2-cp312-cp312-linux_x86_64.whl --no-cache-dir
21 | RUN pip uninstall -y apex && pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/apex/torch2.6.0-cuda12x/apex-0.1-cp312-cp312-linux_x86_64.whl --no-cache-dir
22 | RUN pip uninstall -y transformer_engine && pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/transformer_engine/torch2.6.0-cuda12x/transformer_engine-1.13.0%2Be5edd6c-cp312-cp312-linux_x86_64.whl --no-cache-dir
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | make.bat
2 | _build
3 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 |
3 | # You can set these variables from the command line, and also
4 | # from the environment for the first two.
5 | SPHINXOPTS ?=
6 | SPHINXBUILD ?= sphinx-build
7 | SOURCEDIR = zh
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
20 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Setup
2 |
3 | ```
4 | apt-get install latexmk texlive-xetex fonts-noto fonts-freefont-otf xindy latex-cjk-all
5 | pip install -r requirements.txt
6 | ```
7 |
8 | # build pdf
9 |
10 | ```
11 | # cd en for english doc
12 | cd zh
13 | # make latexpdf
14 | sphinx-build -b latex . _build/latex
15 | cd _build/latex
16 |
17 | # modify chatlearn.tex for auto wrap of text in the table
18 | # Find the table with `stream\\_data\\_loader\\_type`, replace `\begin{tabulary}{\linewidth}[t]{TTT}` with `\begin{tabularx}{\linewidth}[t]{|l|l|X|}`
19 | # and replace the corresponding `\end`
20 | # save the change, and execute
21 | make all-pdf
22 | ```
23 |
24 |
25 | # build html
26 | ```
27 | # cd en for english doc
28 | cd zh
29 | make html
30 | ```
31 |
32 |
--------------------------------------------------------------------------------
/docs/en/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the OS, Python version and other tools you might need
9 | build:
10 | os: ubuntu-22.04
11 | tools:
12 | python: "3.10"
13 | # You can also specify other tool versions:
14 | # nodejs: "19"
15 | # rust: "1.64"
16 | # golang: "1.19"
17 |
18 | # Build documentation in the "docs/" directory with Sphinx
19 | sphinx:
20 | configuration: docs/en/conf.py
21 |
22 | # Optionally build your docs in additional formats such as PDF and ePub
23 | formats:
24 | - pdf
25 | - epub
26 |
27 | # Optional but recommended, declare the Python requirements required
28 | # to build your documentation
29 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
30 | python:
31 | install:
32 | - requirements: docs/requirements.txt
33 |
--------------------------------------------------------------------------------
/docs/en/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 |
3 | # You can set these variables from the command line, and also
4 | # from the environment for the first two.
5 | SPHINXOPTS ?=
6 | SPHINXBUILD ?= sphinx-build
7 | SOURCEDIR = .
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
20 |
--------------------------------------------------------------------------------
/docs/en/api/config.rst:
--------------------------------------------------------------------------------
1 | Config
2 | ======
3 |
4 | .. autoclass:: chatlearn.utils.arguments.RuntimeEnvConfig
5 | :members:
6 | :undoc-members:
7 |
8 | .. autoclass:: chatlearn.utils.arguments.RuntimeConfig
9 | :members:
10 | :undoc-members:
11 |
12 | .. autoclass:: chatlearn.utils.arguments.ModelConfig
13 | :members:
14 | :undoc-members:
15 |
16 |
17 | .. autoclass:: chatlearn.utils.arguments.BatchGenerationConfig
18 | :members:
19 | :undoc-members:
20 |
21 |
22 | .. autoclass:: chatlearn.utils.arguments.LoraConfig
23 | :members:
24 | :undoc-members:
25 |
26 |
27 |
--------------------------------------------------------------------------------
/docs/en/api/engine.rst:
--------------------------------------------------------------------------------
1 | Engine
2 | ======
3 |
4 | .. autoclass:: chatlearn.DPOEngine
5 | :members:
6 | :undoc-members:
7 |
8 | .. autoclass:: chatlearn.OnlineDPOEngine
9 | :members:
10 | :undoc-members:
11 |
12 | .. autoclass:: chatlearn.RLHFEngine
13 | :members:
14 | :undoc-members:
15 |
16 | .. autoclass:: chatlearn.EvalEngine
17 | :members:
18 | :undoc-members:
19 |
20 | .. autoclass:: chatlearn.Evaluator
21 | :members: set_dataset,set_post_process_func,eval
22 | :undoc-members:
23 |
--------------------------------------------------------------------------------
/docs/en/api/index.rst:
--------------------------------------------------------------------------------
1 | API
2 | =======================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 |
7 | engine.rst
8 | module.rst
9 | config.rst
10 |
--------------------------------------------------------------------------------
/docs/en/api/module.rst:
--------------------------------------------------------------------------------
1 | RLHF Module
2 | ===========
3 |
4 | .. autoclass:: chatlearn.models.base_module.BaseModule
5 | :members:
6 | :undoc-members:
7 |
8 | .. autoclass:: chatlearn.models.torch_module.TorchModule
9 | :members:
10 | :undoc-members:
11 | :show-inheritance:
12 |
13 | .. autoclass:: chatlearn.models.megatron_module.MegatronModule
14 | :members:
15 | :undoc-members:
16 | :show-inheritance:
17 |
18 |
--------------------------------------------------------------------------------
/docs/en/index.rst:
--------------------------------------------------------------------------------
1 | ChatLearn Documentation
2 | =======================
3 |
4 |
5 | .. toctree::
6 | :maxdepth: 1
7 | :caption: Introduction
8 |
9 | chatlearn
10 |
11 |
12 | .. toctree::
13 | :maxdepth: 1
14 | :caption: Installation
15 |
16 | installation
17 |
18 | .. toctree::
19 | :maxdepth: 1
20 | :caption: Tutorial
21 |
22 | tutorial/data
23 | tutorial/run
24 | tutorial/tutorial_llama2
25 | tutorial/tutorial_qwen
26 | tutorial/evaluator
27 | tutorial/continue_train
28 | tutorial/custom_model_flow
29 | tutorial/ems
30 | tutorial/profile
31 |
32 | .. toctree::
33 | :maxdepth: 1
34 | :caption: Programming
35 |
36 | programming
37 | config_yaml
38 | advanced
39 |
40 | .. toctree::
41 | :maxdepth: 1
42 | :caption: API Documentation
43 |
44 | api/index
45 |
46 | .. toctree::
47 | :maxdepth: 1
48 | :caption: FAQ
49 |
50 | faq
51 |
--------------------------------------------------------------------------------
/docs/en/installation.md:
--------------------------------------------------------------------------------
1 | # Environment and Code Setup
2 |
3 | 1. Docker Image Preparation
4 |
5 | It is recommended to refer to `https://github.com/alibaba/ChatLearn/tree/master/docker/torch/Dockerfile.torch2.3.0` for preparing the docker image.
6 | If you're training on the PAI DLC/DSW environment, we suggest using the pre-built image provided below:
7 |
8 | ```bash
9 | registry.cn-wulanchabu.aliyuncs.com/pai-dlc/pytorch-training:2.4.0-gpu-py3.10-cu12.5-ngc24.06-ubuntu22.04
10 | ```
11 |
12 | 2. Code Preparation: Users need to download the ChatLearn framework code.
13 |
14 | ```
15 | # Clone ChatLearn code
16 | git clone https://github.com/alibaba/ChatLearn.git
17 | ```
18 |
19 | 3. If you need to run the alignment training program based on the Megatron-LM framework, you also need to download the `Megatron-LM` code.
20 |
21 | ```
22 | # Clone Megatron-LM
23 | git clone https://github.com/NVIDIA/Megatron-LM.git
24 | cd Megatron-LM && git checkout core_r0.8.0
25 | ```
26 |
27 | > [!NOTE]
28 | > If you are using Megatron-LM version `core_r0.8.0`, you may encounter an issue in converting checkpoints: `ValueError: Default process group has not been initialized, please make sure to call init_process_group`. Please refer to the solution in the [FAQ: Failure when converting checkpoint](faq.md#failure-when-converting-checkpoint).
29 |
--------------------------------------------------------------------------------
/docs/en/tutorial/check_sync_parameter.md:
--------------------------------------------------------------------------------
1 | # Debugging Parameter Synchronization
2 |
3 | The ParameterSync module ensures weight consistency between the training (trainer) and serving (inference) components. This guide helps debug synchronization issues that may arise when using different distributed strategies, such as:
4 |
5 | - Trainer side: Megatron-LM (expert parallel + tensor parallel + pipeline parallel)
6 | - Inference side: vLLM (tensor parallel)
7 |
8 | ## Step-by-Step Debugging Guide
9 |
10 | ### 1. Set Up Environment Variable
11 |
12 | Set the `DEBUG_SYNC_PARAMETERS_PATH` environment variable to the directory where parameter snapshots should be dumped before and after synchronization, and specify the Hugging Face format checkpoint used by vLLM:
13 |
14 | ``` bash
15 | export DEBUG_SYNC_PARAMETERS_PATH=/path/to/dump_directory
16 | export vllm_load_format=auto
17 | export policy_load=/workspace/hf_ckp/QWen2-Max
18 | ```
19 |
20 | ### 2. Launch the ChatLearn Job
21 |
22 | Start your training job (e.g., DPO fine-tuning for Llama 2) using [tutorial_llama2.md](/docs/en/tutorial/tutorial_llama2.md):
23 | ```bash
24 | bash scripts/train_online_dpo_llama.sh
25 | ```
26 |
27 | This will generate two directories:
28 | before_sync_parameter: Parameters before synchronization.
29 | after_sync_parameter: Parameters after synchronization.
30 | Each directory contains subfolders for every TP rank.
31 |
32 | ### 3. Check Dumped Files
33 | Verify the dumped parameter files:
34 | ``` bash
35 | tree /path/to/dump_directory
36 | ```
37 |
38 | Example output:
39 | ``` text
40 | /workspace/debug_sync_params
41 | ├── before_sync_parameter
42 | │ ├── 0 # Parameters from TP rank 0
43 | │ ├── 1 # Parameters from TP rank 1
44 | │ └── ...
45 | └── after_sync_parameter
46 | ├── 0
47 | ├── 1
48 | └── ...
49 | ```
50 |
51 | ### 4. Run the Parameter Check Script
52 | Use the `check_parameter_sync.py` tool to compare parameters before/after synchronization:
53 | ``` bash
54 | python chatlearn/tools/check_parameter_sync.py --root_dir /path/to/dump_directory | tee check.log
55 | ```
56 |
57 | This script compares the shapes and values of the parameters dumped before and after synchronization.
58 |
59 | The `tee` command above also saves the detailed results to a log file (`check.log`).
60 |
61 | ### 5. Interpret the Log File
62 |
63 | Successful Sync:
64 | ```text
65 | PASS|1|model.layers.5.self_attn.qkv_proj.weight
66 | ```
67 |
68 | Mismatch detected (reported with shapes and mean values):
69 | ```text
70 | DIFF|1|model.layers.3.mlp.shared_expert.gate_up_proj.weight|torch.Size([2816, 2048])|tensor(0.5247)|torch.Size([2816, 2048])|tensor(8.231)
71 | ```
--------------------------------------------------------------------------------
/docs/en/tutorial/continue_train.md:
--------------------------------------------------------------------------------
1 | # Resuming and Fault Tolerance
2 |
3 | The alignment task involves the computation and interaction of multiple models. As the model scale and computational resources increase, occasional exceptions may occur due to the dependent software stack and hardware environment, leading to task interruption.
4 |
5 | To ensure that interrupted tasks can automatically resume their state, ChatLearn provides a resume function which, combined with PAI-DLC's AIMaster, can automatically detect errors and resume the task.
6 |
7 | ## Configuring ChatLearn Resuming
8 |
9 | Resuming an alignment task requires consideration of the following points:
10 |
11 | 1. Recording and restoring data progress: For recording data status, users need to configure `data_checkpoint_path` in the training configuration master file, such as `rlhf.yaml`. If `data_checkpoint_path` is not empty, ChatLearn will record the current data progress and store the data checkpoint during each `save_checkpoint`.
12 |
13 | 2. Restoring training states such as episodes and iterations: When users configure `data_checkpoint_path` and the corresponding data checkpoint exists in the folder, ChatLearn will automatically restore the training state to the latest checkpoint status and set the `resume_training` variable to `True`.
14 |
15 | 3. Loading checkpoints: When `resume_training==True`, the checkpoints for `reference` and `reward` in RLHF remain unchanged. However, `ppo_policy` and `ppo_value` need to load the checkpoints stored during training, rather than the original initialized checkpoints. Therefore, special processing needs to be done in the setup phase.
16 |
17 | ```python
18 | if self.resume_training:
19 | self.args.load = get_args().save
20 | self.args.load_iteration = -1
21 | self.args.no_load_optim = False
22 | self.args.no_load_rng = False
23 | self.args.no_load_args = False
24 | self.args.no_load_scheduler = False
25 | self.args.finetune = False
26 | ```
27 |
28 | For more details, refer to `examples/megatron/scripts/train_rlhf_llama.sh`.
29 |
30 | If a user configures `data_checkpoint_path` in the program but does not want to enable the resuming function, they can also disable this functionality by configuring `enable_resume_training: False`.
31 |
32 | ## Combining with DLC AIMaster to Achieve Fault Tolerance and Automatic Resuming
33 |
34 | DLC provides fault tolerance monitoring based on AIMaster. AIMaster is a task-level component that, when enabled for fault tolerance monitoring, launches an AIMaster instance to run with other task instances, serving the roles of task monitoring, fault judgment, and resource control.
35 |
36 | Users can combine AIMaster's fault tolerance functionality with ChatLearn's resuming functionality to achieve automatic resuming of training tasks.
37 |
38 | The following is an example of fault tolerance monitoring configuration, which includes enabling hang detection and error detection. When the hang exceeds 1 hour or AIMaster detects an error, the task will be automatically restarted, with a maximum number of restarts being 3 times.
39 |
40 | 
41 |
42 | For more fault tolerance configuration, please refer to the DLC [Fault Tolerance Documentation](https://help.aliyun.com/zh/pai/user-guide/fault-tolerance-monitoring-based-on-aimaster?spm=a2c4g.11186623.0.0.12011976WAncyo).
43 |
--------------------------------------------------------------------------------
/docs/en/tutorial/data.md:
--------------------------------------------------------------------------------
1 | # Data
2 |
3 | This document describes the data preparation process for different stages: SFT, Reward, RLHF, DPO, OnlineDPO and GRPO.
4 |
5 |
6 | **The following is a collection of general environment variables used in this tutorial script:**
7 |
8 | | ENV | Explanation |
9 | | --- | --- |
10 | | `CHATLEARN` | The location where the ChatLearn code is cloned [https://github.com/alibaba/ChatLearn.git](https://github.com/alibaba/ChatLearn.git) |
11 | | `DATASET_ROOT` | The root directory for storing the SFT/Reward/RLHF/DPO/OnlineDPO/GRPO training dataset collection. |
12 |
13 |
14 |
15 | ## 1 Prepare SFT Training Data
16 |
17 | Organize the question-response pairs of SFT data into a jsonl file, where each line of the jsonl file represents a SFT data sample in the following Python dictionary format:
18 |
19 | ```
20 | {'query': question, 'response': reply}
21 | ```
22 |
23 | Taking the example of Anthropic's helpful&harmless data, use the following code to store it in `$DATASET_ROOT/sft/train.jsonl`.
24 |
25 | ```bash
26 | cd ${CHATLEARN}/examples/megatron/
27 | DATASET_ROOT=$path_to_dataset_root
28 | python data/prepare_data_sft.py $DATASET_ROOT
29 | ```
30 |
31 | ## 2 Prepare Reward Training Data
32 |
33 | 1. First, pair each question with several different responses and organize the data into a jsonl file. Each line in the jsonl file represents a Reward model training data sample in the following Python dictionary format:
34 |
35 | ```
36 | {'query': question, 'response': [reply 1, reply 2, ...], 'score': [score1, score2, ...]}
37 | ```
38 |
39 | The score value indicates the quality of the corresponding response, with higher scores indicating higher quality and closer to human preference.
40 |
41 | 2. Taking the example of Anthropic's helpful&harmless data, use the following code to store it in `$DATASET_ROOT/rm/train.jsonl` and `$DATASET_ROOT/rm/dev.jsonl`.
42 |
43 | ```bash
44 | cd ${CHATLEARN}/examples/megatron/
45 | DATASET_ROOT=path-to-dataset-root
46 | python data/prepare_data_reward.py $DATASET_ROOT
47 | ```
48 |
49 | ## 3 Prepare Alignment Training Data
50 |
51 | ChatLearn supports multiple alignment methods: RLHF, DPO, OnlineDPO, and GRPO.
52 |
53 | 1. First, prepare a dataset of instructions to be explored and organize it into a JSON Lines (jsonl) file. Each line in the file should represent a prompt in the following format:
54 |
55 | ```
56 | {"prompt": prompt}
57 | ```
58 |
59 | 2. Taking Anthropic's helpful & harmless data as an example, use the following code to store the dataset in `$DATASET_ROOT/alignment/train.jsonl` and `$DATASET_ROOT/alignment/dev.jsonl`:
60 |
61 | ```bash
62 | cd ${CHATLEARN}/examples/megatron/
63 | DATASET_ROOT=path-to-dataset-root
64 | python data/prepare_data_alignment.py $DATASET_ROOT
65 | ```
66 | ## 4 Prepare Math Training Data
67 |
68 | 1. First, prepare a dataset of math problems to be explored and organize it into a JSON Lines (jsonl) file. Each line in the file should represent a prompt in the following format:
69 |
70 | ```
71 | {"eval_func": "math_rule", "prompt": prompt, 'answer': answer}
72 | ```
73 |
74 | 2. Taking openai/gsm8k data as an example, use the following code to store the dataset in `$DATASET_ROOT/math/train.jsonl`:
75 |
76 | ```bash
77 | cd ${CHATLEARN}/examples/megatron/
78 | DATASET_ROOT=path-to-dataset-root
79 | python data/prepare_data_math.py $DATASET_ROOT
80 | ```
--------------------------------------------------------------------------------
/docs/en/tutorial/ems.md:
--------------------------------------------------------------------------------
1 | # Efficient Memory Sharing (EMS)
2 |
3 | ChatLearn provides the Efficient Memory Sharing (EMS) feature to significantly reduce GPU memory usage during alignment training.
4 | It makes the most of limited resources, allowing larger-scale models to be trained, or improving overall training efficiency by using the saved GPU memory for a better parallel strategy and a larger batch size.
5 |
6 | When multiple models in ChatLearn share the same resources for training or inference, enabling the EMS feature allows these models to sequentially share GPU memory:
7 | - After each model is initialized, tensors/buffers that constantly reside in GPU memory (such as weights, gradient buffers, and optimization states) are unloaded to the RAM or freed to release their occupied GPU memory.
8 | - Before training or inference for a specific model, the tensors/buffers are loaded from the RAM or reconstructed, and then training or inference takes place.
9 | - Once the training or inference is complete, the tensors/buffers are again unloaded to the RAM or freed to release their occupied GPU memory.
10 |
11 | By repeating the above process, multiple models sequentially share GPU memory, maximizing the efficiency of GPU memory usage.
12 |
13 | ## Usage
14 | Users can specify whether to enable the EMS feature by configuring the `free_memory` (bool type, default is False) parameter for each model. This can be directly modified in the `rlhf.yaml` for each model. For example, to enable the EMS feature for the policy model:
15 | ```yaml
16 | policy:
17 | model_config_file: old_policy_inference.yaml
18 | ...
19 | free_memory: ${free_memory_policy:True}
20 | ```
21 | Alternatively, it can also be configured in the training script using environment variables:
22 | - Policy model: `export free_memory_policy=True`
23 | - Reference model: `export free_memory_reference=True`
24 | - Reward model: `export free_memory_reward=True`
25 | - Value model: `export free_memory_value=True`
26 | - PPO policy model: `export free_memory_ppo_policy=True`
27 | - PPO value model: `export free_memory_ppo_value=True`
28 |
29 | A complete example can be found in the [llama2 configuration](https://github.com/alibaba/ChatLearn/blob/main/examples/megatron/configs/llama2/rlhf.yaml).
--------------------------------------------------------------------------------
/docs/en/tutorial/evaluator.md:
--------------------------------------------------------------------------------
1 | # Evaluator
2 |
3 | This document will introduce how to perform model evaluation. Users can use `EvalEngine` to evaluate models independently or configure the evaluator within the training engine to perform evaluations during training.
4 |
5 | ```python
6 | def eval_flow(batch):
7 | p = policy.forward_step(batch)
8 | r = reward.eval_step(p)
9 | r1 = reward2.eval_step(p)
10 | return r, r1
11 |
12 | evaluator = Evaluator(eval_flow)
13 | evaluator.set_dataset(prompts)
14 | results = evaluator.eval()
15 | ```
16 |
17 | In the above example, we constructed an evaluation flow for three models. Users can customize the evaluation execution flow through the `eval_flow`.
18 |
19 | The result returned by `evaluator.eval` is of type `dict`, where the key is `model_name` and the value is a `list` containing the results of the computations for each batch.
20 |
21 | In the above example, the result returned by `eval` will be `{"reward": [batch0, batch1, batch2], "reward2": [batch0, batch1, batch2]}`.
22 |
--------------------------------------------------------------------------------
/docs/en/tutorial/profile.md:
--------------------------------------------------------------------------------
1 | # Profile
2 | ChatLearn provides two ways to profile performance:
3 | 1. Torch profiler
4 | 2. nsys
5 |
6 | Note: For large models, the profile result can be very large. It is recommended to reduce the model size when profiling.
7 | ## Torch Profiler
8 | Users can enable the Torch profiler by configuring the rlhf setting `profiler_dir: path_to_profile_dir` in the main configuration file of the system.
9 | ```yaml
10 | profiler_dir: path_to_profile_dir
11 | ```
12 |
13 | ## nsys
14 | Users can enable the nsys profiler by configuring the rlhf setting `nsys: True` in the main configuration file of the system.
15 | ```yaml
16 | runtime:
17 | nsys: True
18 | ```
19 | When launching the program, nsys startup parameters need to be added before the execution command, as shown in the following example:
20 | ```bash
21 | nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s none --capture-range=cudaProfilerApi --capture-range-end=stop-shutdown --cudabacktrace=true -x true --force-overwrite true -o my_profile \
22 | python train_rlhf.py XXX
23 | ```
--------------------------------------------------------------------------------
/docs/en/tutorial/run.md:
--------------------------------------------------------------------------------
1 | # Distributed Execution
2 |
3 | This document will provide instructions on how to execute a distributed training task.
4 |
5 | ## PAI DLC Distributed Execution
6 |
7 | [Aliyun PAI DLC](https://www.aliyun.com/activity/bigdata/pai-dlc) [1] can conveniently and efficiently support training for various tasks.
8 |
9 | Screenshots of the PAI-DLC task creation page are shown below.
10 | Select the job type as `PyTorch` and paste the command into the `Execution Command` window.
11 |
12 | 
13 |
14 | 
15 |
16 |
17 |
18 | ## Non-PAI-DLC environment
19 |
20 | If you want to submit distributed training in a non-PAI-DLC environment,
21 | the following environment variables need to be configured on each node before executing the script:
22 |
23 | ```bash
24 | export MASTER_ADDR=xxx
25 | export MASTER_PORT=xxx
26 | export WORLD_SIZE=xxx
27 | export GPUS_PER_NODE=8
28 | export RANK=xx
29 | ```
30 |
31 | ## Reference
32 |
33 | 1. Aliyun Machine Learning PAI-DLC: [https://www.aliyun.com/activity/bigdata/pai-dlc](https://www.aliyun.com/activity/bigdata/pai-dlc)
34 |
--------------------------------------------------------------------------------
/docs/en/tutorial/tutorial_grpo_fsdp.md:
--------------------------------------------------------------------------------
1 | # End-to-End GRPO Training Tutorial with FSDP
2 |
3 | This document provides instructions for end-to-end GRPO training using the ChatLearn, PyTorch FSDP, and vLLM frameworks with the Qwen3 model.
4 |
5 | ## Environment Setup
6 | 1. Docker Image Preparation
7 |
8 | We recommend running the following example in PAI [DSW](https://help.aliyun.com/zh/pai/user-guide/create-and-manage-dsw-instances/)/[DLC](https://help.aliyun.com/zh/pai/user-guide/create-a-training-task?spm=a2c4g.11186623.help-menu-30347.d_3_3_5_5.2dfb1925l3QjwG). You need to use the following image to launch the instance.
9 | ```bash
10 | dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312
11 | ```
12 |
13 | You can use a VPC address to accelerate image pulling. The image address should be adjusted based on the current region. For example, if you need to launch a DSW instance in Shanghai, you can use the following image `dsw-registry-vpc.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312`.
14 |
15 | 2. Code Preparation
16 |
17 | ```bash
18 | git clone https://github.com/alibaba/ChatLearn.git && cd ChatLearn
19 | ```
20 |
21 | ## Data Preparation
22 | We take [MATH-lighteval](https://www.modelscope.cn/datasets/AI-ModelScope/MATH-lighteval) as an example.
23 | ```bash
24 | # download dataset
25 | mkdir -p dataset
26 | modelscope download --dataset AI-ModelScope/MATH-lighteval --local_dir dataset/MATH-lighteval
27 | # preprocess dataset
28 | python examples/fsdp/data/data_preprocess/math_lighteval.py --input_dir dataset/MATH-lighteval --local_dir dataset/MATH-lighteval
29 | # download model weight
30 | modelscope download --model Qwen/Qwen3-8B --local_dir Qwen3-8B
31 | ```
32 |
33 | ## Training
34 | You can run the following command to start training:
35 |
36 | ```bash
37 | bash examples/fsdp/scripts/train_grpo_qwen3.sh
38 | ```
39 |
40 | ## Using Wandb
41 | If you want to use Wandb to log the training process, you need to modify the following configuration in [train_grpo_qwen3.sh](../../../examples/fsdp/scripts/train_grpo_qwen3.sh):
42 |
43 | ```bash
44 | export enable_wandb=True
45 | export wandb_project="Your-Wandb-Project-Name"
46 | export WANDB_API_KEY="Your-Wandb-api-key"
47 | ```
--------------------------------------------------------------------------------
/docs/en/tutorial/tutorial_qwen.md:
--------------------------------------------------------------------------------
1 | # End-to-end training tutorial based on the Qwen model
2 |
3 | This document describes DPO training based on the ChatLearn and DeepSpeed frameworks and the Qwen model.
4 |
5 | **The following is a collection of common environment variables used in this tutorial script:**
6 | | ENV | Meaning |
7 | | --- |-------------------------------------------------------------------------------------------------------------------------------|
8 | | `CHATLEARN` | Location where the ChatLearn code repository is cloned [https://github.com/alibaba/ChatLearn.git](https://github.com/alibaba/ChatLearn.git) |
9 | | `DATASET_ROOT` | Root directory where the training datasets are stored |
10 |
11 | ## Setup: Image, Code, and Data Preparation
12 |
13 | ### Image / Code
14 |
15 | Please refer to [Environment and Code Setup](../installation.md).
16 |
17 | ### Data
18 | The data format required by qwen2 is chatml:
19 | ```
20 | {"type": "chatml", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Tell me something about large language models."}, {"role": "assistant", "content": "Large language models are a type of language model that is trained on a large corpus of text data. They are capable of generating human-like text and are used in a variety of natural language processing tasks..."}], "source": "unknown"}
21 | ```
22 | The following script can convert `Dahoas/full-hh-rlhf` to data in chatml format and store it in the file `$DATASET_ROOT/alignment/train.jsonl`:
23 | ```bash
24 | cd ${CHATLEARN}/examples/huggingface/
25 | DATASET_ROOT=path-to-dataset-root
26 | python data/preprocess_data_chatml.py $DATASET_ROOT
27 | ```
28 |
29 | ## DPO
30 | Here is an example of DPO training for Qwen2-7B.
31 | In this example, the user needs to set `policy_model_path` to the initialization model checkpoint path, and the Policy model and Reference model will be initialized with this checkpoint.
32 | ```
33 | export CHATLEARN=path-to-chatlearn
34 | export DATASET_PATH=$DATASET_ROOT/alignment/train.jsonl
35 | export policy_model_path=path-to-qwen2-ckpt
36 | cd ${CHATLEARN}/examples/huggingface/
37 | bash scripts/train_dpo_qwen.sh
38 | ```
39 |
--------------------------------------------------------------------------------
/docs/images/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/arch.png
--------------------------------------------------------------------------------
/docs/images/class.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/class.png
--------------------------------------------------------------------------------
/docs/images/dlc_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/dlc_1.jpg
--------------------------------------------------------------------------------
/docs/images/dlc_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/dlc_2.jpg
--------------------------------------------------------------------------------
/docs/images/engine.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/engine.jpg
--------------------------------------------------------------------------------
/docs/images/engine_class.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/engine_class.png
--------------------------------------------------------------------------------
/docs/images/fault.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/fault.png
--------------------------------------------------------------------------------
/docs/images/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/logo.jpg
--------------------------------------------------------------------------------
/docs/images/perf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/perf.png
--------------------------------------------------------------------------------
/docs/images/rlhf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/rlhf.png
--------------------------------------------------------------------------------
/docs/images/yaml.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/yaml.jpg
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | sphinx_rtd_theme
3 | recommonmark
4 | sphinx-markdown-tables
5 | myst-parser
6 | sphinx-markdown-builder
7 | sphinx_markdown_checkbox
8 |
--------------------------------------------------------------------------------
/docs/zh/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the OS, Python version and other tools you might need
9 | build:
10 | os: ubuntu-22.04
11 | tools:
12 | python: "3.10"
13 | # You can also specify other tool versions:
14 | # nodejs: "19"
15 | # rust: "1.64"
16 | # golang: "1.19"
17 |
18 | # Build documentation in the "docs/" directory with Sphinx
19 | sphinx:
20 | configuration: docs/zh/conf.py
21 |
22 | # Optionally build your docs in additional formats such as PDF and ePub
23 | formats:
24 | - pdf
25 | - epub
26 |
27 | # Optional but recommended, declare the Python requirements required
28 | # to build your documentation
29 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
30 | python:
31 | install:
32 | - requirements: docs/requirements.txt
33 |
--------------------------------------------------------------------------------
/docs/zh/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 |
3 | # You can set these variables from the command line, and also
4 | # from the environment for the first two.
5 | SPHINXOPTS ?=
6 | SPHINXBUILD ?= sphinx-build
7 | SOURCEDIR = .
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
20 |
--------------------------------------------------------------------------------
/docs/zh/api/config.rst:
--------------------------------------------------------------------------------
1 | Config
2 | ======
3 |
4 | .. autoclass:: chatlearn.utils.arguments.RuntimeEnvConfig
5 | :members:
6 | :undoc-members:
7 |
8 | .. autoclass:: chatlearn.utils.arguments.RuntimeConfig
9 | :members:
10 | :undoc-members:
11 |
12 | .. autoclass:: chatlearn.utils.arguments.ModelConfig
13 | :members:
14 | :undoc-members:
15 |
16 |
17 | .. autoclass:: chatlearn.utils.arguments.BatchGenerationConfig
18 | :members:
19 | :undoc-members:
20 |
21 |
22 | .. autoclass:: chatlearn.utils.arguments.LoraConfig
23 | :members:
24 | :undoc-members:
25 |
26 |
27 |
--------------------------------------------------------------------------------
/docs/zh/api/engine.rst:
--------------------------------------------------------------------------------
1 | Engine
2 | ======
3 |
4 | .. autoclass:: chatlearn.DPOEngine
5 | :members:
6 | :undoc-members:
7 |
8 | .. autoclass:: chatlearn.OnlineDPOEngine
9 | :members:
10 | :undoc-members:
11 |
12 | .. autoclass:: chatlearn.RLHFEngine
13 | :members:
14 | :undoc-members:
15 |
16 | .. autoclass:: chatlearn.EvalEngine
17 | :members:
18 | :undoc-members:
19 |
20 | .. autoclass:: chatlearn.Evaluator
21 | :members: set_dataset,set_post_process_func,eval
22 | :undoc-members:
23 |
--------------------------------------------------------------------------------
/docs/zh/api/index.rst:
--------------------------------------------------------------------------------
1 | API
2 | =======================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 |
7 | engine.rst
8 | module.rst
9 | config.rst
10 |
--------------------------------------------------------------------------------
/docs/zh/api/module.rst:
--------------------------------------------------------------------------------
1 | RLHF Module
2 | ===========
3 |
4 | .. autoclass:: chatlearn.models.base_module.BaseModule
5 | :members:
6 | :undoc-members:
7 |
8 | .. autoclass:: chatlearn.models.torch_module.TorchModule
9 | :members:
10 | :undoc-members:
11 | :show-inheritance:
12 |
13 | .. autoclass:: chatlearn.models.megatron_module.MegatronModule
14 | :members:
15 | :undoc-members:
16 | :show-inheritance:
17 |
18 |
--------------------------------------------------------------------------------
/docs/zh/config_yaml.md:
--------------------------------------------------------------------------------
1 | # 配置文件
2 | ## 训练配置文件
3 |
4 | 用户需要一个程序主 yaml 配置来设置运行环境、模型配置和 RLHF 训练流程相关的配置。同时,用户也可能需要为每个模型配置单独的模型配置。
5 |
6 | RLHF 的训练配置包括三部分
7 |
8 | 1. runtime_env: 运行环境配置
9 | 2. models: 模型配置。每一个模型都可以单独配置模型参数。通过`model_name`来区分不同的模型。这里`model_name`对应主文件中定义模型时传入的`model_name`。
10 | 3. runtime: 训练配置
11 |
12 | 以下为一个训练配置的示例。具体的配置项含义可以参考 [Config API 文档](api/config.rst).
13 |
14 | 为了方便配置不同的超参数,我们也支持从环境变量读取参数。格式如下
15 |
16 | ```
17 | param: ${env_name:default_value}
18 | ```
19 | `param`为参数名,`env_name`为环境变量名,`default_value`为默认值 (可选)。在以下例子中,如果设置了环境变量`ref_generation_batch_size`, 则会从环境变量中读取赋值给`reference`的`generation_batch_size`,如果没有设置环境变量`ref_generation_batch_size`,则使用默认值 4。
20 |
21 | ```yaml
22 | runtime_env:
23 | platform: DLC
24 | excludes:
25 | - "*pt"
26 | - "logs"
27 | - "tensorboards"
28 | - ".nfs*"
29 |
30 |
31 | models:
32 | policy:
33 | model_config_file: policy_inference.yaml
34 | num_gpu: 8
35 | trainable: False
36 |
37 | reference:
38 | model_config_file: reference.yaml
39 | num_gpu: 8
40 | trainable: False
41 | generation_batch_size: ${ref_generation_batch_size:4}
42 |
43 | reward:
44 | model_config_file: reward_inference.yaml
45 | num_gpu: 8
46 | trainable: False
47 |
48 | value:
49 | model_config_file: old_value_inference.yaml
50 | num_gpu: 8
51 | trainable: False
52 |
53 | ppo_policy:
54 | model_config_file: ppo_policy.yaml
55 | num_gpu: 8
56 | trainable: True
57 |
58 | ppo_value:
59 | model_config_file: ppo_value.yaml
60 | num_gpu: ${num_gpu}
61 | trainable: True
62 |
63 | runtime:
64 | colocation:
65 | - policy,ppo_policy,reward,reference,value,ppo_value
66 | generation_batch_size: ${generation_batch_size:4}
67 | train_micro_batch_size: 2
68 | train_global_batch_size: ${train_global_batch_size:512}
69 | num_episode: 200
70 | sample_per_episode: ${sample_per_episode:1024}
71 | num_training_epoch: 1
72 | save_episode_interval: ${save_episode_interval:50}
73 | data_path: ${data_path}
74 | eval_episode_interval: ${eval_episode_interval:100}
75 | ```
76 |
77 |
78 | ## 模型配置 YAML
79 |
80 | 本框架支持对每个模型配置单独的配置文件,用于配置不同模型的超参数,并行化策略,checkpoint 初始化等。模型配置文件格式为 yaml 文件。下面是一个简单的模型配置例子。
81 |
82 | ```yaml
83 | num_layers: 6
84 | hidden_size: 768
85 | num_attention_heads: 12
86 | bf16: True
87 | seq_length: 2048
88 | tensor_model_parallel_size: 8
89 | pipeline_model_parallel_size: 2
90 | load: path-to-ckpt
91 | ```
92 |
93 | 为了简化不同模型的共享配置,我们拓展了 yaml 的语法,通过 include 的字段来集成 base 配置文件的配置。在下面这个例子中,`policy_inference.yaml`和`ppo_policy.yaml`共享`num_layers`/`hidden_size`等参数,同时,两个模型配置了各自不同的`pipeline_model_parallel_size`。
94 |
95 | 
96 |
--------------------------------------------------------------------------------
/docs/zh/index.rst:
--------------------------------------------------------------------------------
1 | ChatLearn 使用文档
2 | =======================
3 |
4 |
5 | .. toctree::
6 | :maxdepth: 1
7 | :caption: 简介
8 |
9 | chatlearn
10 |
11 |
12 | .. toctree::
13 | :maxdepth: 1
14 | :caption: 安装
15 |
16 | installation
17 |
18 |
19 | .. toctree::
20 | :maxdepth: 1
21 | :caption: 使用教程
22 |
23 | tutorial/data
24 | tutorial/run
25 | tutorial/tutorial_llama2
26 | tutorial/tutorial_qwen
27 | tutorial/evaluator
28 | tutorial/continue_train
29 | tutorial/custom_model_flow
30 | tutorial/ems
31 | tutorial/profile
32 |
33 |
34 | .. toctree::
35 | :maxdepth: 1
36 | :caption: 编程接口
37 |
38 | programming
39 | config_yaml
40 | advanced
41 |
42 | .. toctree::
43 | :maxdepth: 1
44 | :caption: API 文档
45 |
46 | api/index
47 |
48 |
49 | .. toctree::
50 | :maxdepth: 1
51 | :caption: 常见问题
52 |
53 | faq
--------------------------------------------------------------------------------
/docs/zh/installation.md:
--------------------------------------------------------------------------------
1 | # 环境和代码准备
2 |
3 | 1. 镜像准备
4 |
5 | 可以参考 `https://github.com/alibaba/ChatLearn/tree/master/docker/torch/Dockerfile.torch2.3.0` 准备镜像。
6 | 如果在 PAI DLC/DSW 环境上训练,推荐使用我们准备好的镜像:
7 |
8 | ```bash
9 | registry.cn-wulanchabu.aliyuncs.com/pai-dlc/pytorch-training:2.4.0-gpu-py3.10-cu12.5-ngc24.06-ubuntu22.04
10 | ```
11 |
12 | 2. 代码准备: 用户需要下载 `ChatLearn` 框架代码。
13 |
14 | ```
15 | # 下载 ChatLearn 代码
16 | git clone https://github.com/alibaba/ChatLearn.git
17 | ```
18 |
19 | 3. 如果您需要运行基于 Megatron-LM 框架的 alignment 训练程序,您也需要下载 `Megatron-LM` 代码。
20 |
21 | ```
22 | # 下载 Megatron-LM
23 | git clone https://github.com/NVIDIA/Megatron-LM.git
24 | cd Megatron-LM && git checkout core_r0.8.0
25 | ```
26 |
27 | > [!NOTE]
28 | > 若使用 Megatron-LM core_r0.8.0,您可能在转换 checkpoint 时遇到错误:`ValueError: Default process group has not been initialized, please make sure to call init_process_group.`,您可以参考 [FAQ:转换 Checkpoint 失败](faq.md#转换-checkpoint-失败) 中的解决方案。
29 |
--------------------------------------------------------------------------------
/docs/zh/tutorial/continue_train.md:
--------------------------------------------------------------------------------
1 | # 续跑和容错
2 |
3 | Alignment 任务涉及到多模型的计算和交互,随着模型规模的增大和计算资源的增加,由于依赖的软件栈和硬件环境都有可能出现偶发异常,会导致任务停止运行。
4 | 为了保障被中断的任务可以恢复状态进行自动续跑,ChatLearn提供了续跑的功能,结合 PAI-DLC 的 AIMaster,可以实现自动错误检测和续跑功能。
5 |
6 | ## 配置 ChatLearn 续跑
7 |
8 | 任务的续跑需要考虑以下几点:
9 | 1. 数据进度的记录和恢复; 对于数据状态的记录,用户需要在训练配置主文件如 `rlhf.yaml` 中配置 `data_checkpoint_path`。
10 | 如果 `data_checkpoint_path` 不为空,则 ChatLearn 会记录当前的数据进度,并在每次 `save_checkpoint` 的同时存储 data checkpoint。
11 | 2. 训练状态比如 episode、iteration 等信息的恢复;当用户配置了 `data_checkpoint_path`,同时文件夹中存在对应的 data checkpoint,ChatLearn 会自动恢复训练状态到当前最新的 checkpoint 状态,
12 | 并将模型的 `resume_training` 变量设为 `True`。
13 | 3. checkpoint的加载;当 `resume_training==True`, 对于 RLHF 中的几个模型,`reference` 和 `reward` 加载的 checkpoint 保持不变。
14 | `ppo_policy` 和 `ppo_value` 需要加载训练中存储的checkpoint,而不是原始初始化的checkpoint。 因此需要在 `setup` 阶段做特殊处理。
15 |
16 | ```python
17 | if self.resume_training:
18 | self.args.load = get_args().save
19 | self.args.load_iteration = -1
20 | self.args.no_load_optim = False
21 | self.args.no_load_rng = False
22 | self.args.no_load_args = False
23 | self.args.no_load_scheduler = False
24 | self.args.finetune = False
25 | ```
26 |
27 | 更多详情可以参考 `examples/megatron/scripts/train_rlhf_llama.sh` 。
28 |
29 | 如果用户在程序中配置了 `data_checkpoint_path` ,但是不想打开续跑功能,则也可以通过配置 `enable_resume_training: False` 来关闭此功能。
30 |
31 | ## 和 DLC AIMaster 结合实现容错和自动续跑
32 |
33 | DLC提供了基于AIMaster的容错监控功能。AIMaster是一个任务级别的组件,当任务开启AIMaster的容错监控功能后,
34 | 会拉起一个AIMaster实例和任务其他实例一起运行,起到任务监控、容错判断、资源控制的作用。
35 |
36 | 用户可以通过结合 AIMaster 的容错功能和 ChatLearn 的续跑功能来实现训练任务的自动续跑。
37 |
38 | 以下为容错监控的配置示例,在这个配置中打开了 hang 检测和错误检测,
39 | 当hang 超过 1 个小时或者当 AIMaster 检测到错误,会将任务自动重启,最大重启次数为3次。
40 |
41 | 
42 |
43 | 更多的容错配置请参考 DLC [容错文档](https://help.aliyun.com/zh/pai/user-guide/fault-tolerance-monitoring-based-on-aimaster?spm=a2c4g.11186623.0.0.12011976WAncyo) 。
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
/docs/zh/tutorial/data.md:
--------------------------------------------------------------------------------
1 | # 数据准备
2 |
3 | 本文档介绍不同阶段 SFT, Reward,RLHF,DPO, OnlineDPO 和 GRPO 的数据准备流程。
4 |
5 | **以下是这个 Tutorial 脚本中使用的通用环境变量集合:**
6 |
7 | | ENV | 含义 |
8 | | --- | --- |
9 | | `CHATLEARN` | ChatLearn 代码仓库 clone 存放的位置 [https://github.com/alibaba/ChatLearn.git](https://github.com/alibaba/ChatLearn.git) |
10 | | `DATASET_ROOT` | 存放SFT/Reward/RLHF/DPO/OnlineDPO/GRPO训练数据集合的根目录 |
11 |
12 | ## 准备 SFT 训练数据
13 |
14 | 将 SFT 数据的问题 - 回复配对的样本,整理到一个 jsonl 文件中,其中 jsonl 文件中每一行为一条 SFT 数据,形式为如下的 Python 字典格式:
15 |
16 | ```
17 | {'query': 问题,'response': 回复}
18 | ```
19 |
20 | 以 Anthropic 的 helpful&harmless 的数据为例,使用如下代码,会存一个 `$DATASET_ROOT/sft/train.jsonl`.
21 |
22 | ```bash
23 | cd ${CHATLEARN}/examples/megatron/
24 | DATASET_ROOT=$path_to_dataset_root
25 | python data/prepare_data_sft.py $DATASET_ROOT
26 | ```
27 |
28 | ## 准备 Reward 训练数据
29 |
30 | 1. 首先准备问题 - 不同回复配对的样本,整理到一个 jsonl 文件中,其中 jsonl 文件中每一行为一条 Reward 模型训练数据,形式为如下的 Python 字典格式:
31 |
32 | ```
33 | {'query': 问题,'response': [回复 1, 回复 2, .....], 'score': [score1, score2, .....]}
34 | ```
35 |
36 | 其中 score 的值越高意味着对应回复的质量越高,越贴近人类偏好。
37 |
38 | 2. 以 Anthropic 的 helpful&harmless 的数据为例,使用如下代码,会存一个 `$DATASET_ROOT/rm/train.jsonl` 和 `$DATASET_ROOT/rm/dev.jsonl`.
39 |
40 | ```bash
41 | cd ${CHATLEARN}/examples/megatron/
42 | DATASET_ROOT=path-to-dataset-root
43 | python data/prepare_data_reward.py $DATASET_ROOT
44 | ```
45 |
46 | ## 准备 Alignment 训练数据
47 |
48 | ChatLearn中支持多种Alignment的训练模式:RLHF, DPO, OnlineDPO, GRPO
49 |
50 | 其中RLHF/OnlineDPO/GRPO数据格式相同。
51 |
52 |
53 | ### RLHF/OnlineDPO/GRPO
54 |
55 | 1. 首先准备一个需要被探索的指令数据集,整理到一个 jsonl 文件中,其中 jsonl 文件中每一行为一条指令,格式为
56 |
57 | ```
58 | {"prompt": 问题}
59 | ```
60 |
61 | 2. 以 Anthropic 的 helpful&harmless 的数据为例,使用如下代码,会存一个`$DATASET_ROOT/alignment/train.jsonl` 和`$DATASET_ROOT/alignment/dev.jsonl`:
62 |
63 | ```bash
64 | cd ${CHATLEARN}/examples/megatron/
65 | DATASET_ROOT=path-to-dataset-root
66 | python data/prepare_data_alignment.py $DATASET_ROOT
67 | ```
68 |
69 | ### DPO
70 |
71 | 准备一个需要被探索的指令数据集,整理到一个 jsonl 文件中,其中 jsonl 文件中每一行为一条指令,格式为:prompt+chosen+rejected,例如:
72 |
73 | ```
74 | {"prompt": 问题, "chosen": 正偏好回答, "rejected": 负偏好回答}
75 | ```
76 |
77 | 其中prompt字段内容分为两种场景:
78 | 1. 单轮对话:仅包括单轮对话的`问题`;
79 | 2. 多轮对话:包含前几轮对话的问答及最后一轮的提问。
80 |
81 | 开源Anthropic的helpful&harmless的数据满足DPO训练需求,可参考RLHF章节下载相应训练数据。
82 |
83 | ### Math
84 |
85 | 首先,准备一个要训练的数学数据集,并将其组织成一个 jsonl 文件。jsonl 文件中的每一行应该表示一个样本,格式如下:
86 |
87 | ```
88 | {"eval_func": "math_rule", "prompt": prompt, "answer": answer}
89 | ```
90 |
91 | 以 `openai/gsm8k` 数据为例,使用以下代码将数据集存储在 `$DATASET_ROOT/math/train.jsonl` 中:
92 |
93 | ```
94 | cd ${CHATLEARN}/examples/megatron/
95 | DATASET_ROOT=path-to-dataset-root
96 | python data/prepare_data_math.py $DATASET_ROOT
97 | ```
98 |
99 |
--------------------------------------------------------------------------------
/docs/zh/tutorial/ems.md:
--------------------------------------------------------------------------------
1 | # 高效显存复用(EMS)
2 |
3 | ChatLearn 中提供高效显存复用 (Efficient Memory Sharing, EMS) 功能来大幅减少训练过程中的显存占用。
4 | EMS 功能可以充分利用有限资源来训练更大规模的模型,也可以利用节约的显存来调整模型的并行策略或者增大 batch size,从而提升整体的训练效率。
5 |
6 | ChatLearn 中多个模型共享相同的资源进行训练或推理时,打开 EMS 功能,可以让这些模型按序共享使用显存:
7 | - 每个模型初始化完成后,将常驻显存的各类 tensor/buffer(包括 weight, grad buffer, optim states 等)卸载到内存或者直接释放,清空该模型占用的显存;
8 | - 某个模型训练或推理前,先从内存中加载或者重建 tensor/buffer,然后进行训练或推理;
9 | - 训练或推理完成后,将常驻显存的 tensor/buffer 卸载到内存或者直接释放,再次清空该模型占用的显存。
10 |
11 | 重复如上流程,多个模型间按序共享使用显存,最大化显存利用效率。
12 |
13 | ## 功能用法
14 | 用户通过配置每个模型的 `free_memory` (bool 类型, 默认为 False) 参数来指定是否开启 EMS 功能。
15 | 可以直接修改 `rlhf.yaml` 中每个模型的 `free_memory` 配置,例如打开 policy 模型的 EMS 功能:
16 |
17 | ```yaml
18 | policy:
19 | model_config_file: old_policy_inference.yaml
20 | ...
21 | free_memory: ${free_memory_policy:True}
22 | ```
23 |
24 | 用户也可以在训练脚本中通过配置环境变量来启动 EMS 功能:
25 | - policy 模型:`export free_memory_policy=True`
26 | - reference 模型:`export free_memory_reference=True`
27 | - reward 模型:`export free_memory_reward=True`
28 | - value 模型:`export free_memory_value=True`
29 | - ppo_policy 模型:`export free_memory_ppo_policy=True`
30 | - ppo_value 模型:`export free_memory_ppo_value=True`
31 |
32 | 完整示例可以参考 [llama2 配置](https://github.com/alibaba/ChatLearn/blob/main/examples/megatron/configs/llama2/rlhf.yaml)。
33 |
--------------------------------------------------------------------------------
/docs/zh/tutorial/evaluator.md:
--------------------------------------------------------------------------------
1 | # Evaluator
2 |
3 | 本文档将介绍如何进行模型评估。用户可以使用 `EvalEngine` 单独对模型进行评估,也可以在训练 Engine 里配置 evaluator 在训练的过程中进行评估。
4 |
5 | ```python
6 | def eval_flow(batch):
7 | p = policy.forward_step(batch)
8 | r = reward.eval_step(p)
9 | r1 = reward2.eval_step(p)
10 | return r, r1
11 | evaluator = Evaluator(eval_flow)
12 | evaluator.set_dataset(prompts)
13 | results = evaluator.eval()
14 | ```
15 | 在上述例子中,我们构建了一个包含三个模型的评估 flow,用户可以自定义 evaluation 的执行 flow。
16 | evaluator.eval 返回的结果是一个 dict 类型,key 是 model_name,value 是一个 list,包含 batch 的计算结果。
17 | 在上述例子中,eval 返回的结果为 {"reward": [batch0, batch1, batch2], "reward2": [batch0, batch1, batch2]}
18 |
--------------------------------------------------------------------------------
/docs/zh/tutorial/profile.md:
--------------------------------------------------------------------------------
1 | # Profile
2 |
3 | ChatLearn 提供了两种 Profile 的方式:
4 | 1. torch profiler
5 | 2. nsys
6 |
7 | 注意:对于大模型,profile 的结果会非常大,建议在 profile 的时候减小模型尺寸。
8 |
9 | ## Torch Profiler
10 |
11 | 用户可以在系统的主配置文件中配置 rlhf 配置 `profiler_dir: path_to_profile_dir` 来开启 Torch profiler。
12 |
13 | ```yaml
14 | profiler_dir: path_to_profile_dir
15 | ```
16 |
17 | ## nsys
18 |
19 | 用户可以在系统的主配置文件中配置 rlhf 配置 `nsys: True` 来开启 nsys 的 profiler。
20 |
21 | ```yaml
22 | runtime:
23 | nsys: True
24 | ```
25 |
26 | 在启动程序的时候,需要在执行命令前加上 nsys 的启动参数,可以参考下述命令
27 |
28 | ```bash
29 | nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s none --capture-range=cudaProfilerApi --capture-range-end=stop-shutdown --cudabacktrace=true -x true --force-overwrite true -o my_profile \
30 | python train_rlhf.py XXX
31 | ```
32 |
33 |
--------------------------------------------------------------------------------
/docs/zh/tutorial/run.md:
--------------------------------------------------------------------------------
1 | # 分布式执行
2 |
3 | 本文档将介绍如何执行一个分布式训练任务。
4 |
5 | ## PAI DLC 分布式执行
6 |
7 | [阿里云 PAI DLC](https://www.aliyun.com/activity/bigdata/pai-dlc) [1]可以非常便捷高效地支持各种任务的训练。
8 |
9 | 以下为 PAI-DLC 创建任务的页面截图,选择作业类型为 `PyTorch`, 同时将上述命令修改后粘贴到`执行命令`窗口中, 设置节点镜像为 ChatLearn 编译后的镜像地址。在这个例子中,我们申请了2个节点,每个节点配置 8 卡 GPU。
10 |
11 | 
12 |
13 | 
14 |
15 |
16 |
17 | ## 其他环境分布式执行
18 |
19 | 如果您需要在非 PAI DLC 环境执行分布式任务,您需要配置以下环境变量。
20 |
21 | ```bash
22 | export MASTER_ADDR=xxx
23 | export MASTER_PORT=xxx
24 | export WORLD_SIZE=xxx
25 | export GPUS_PER_NODE=8
26 | export RANK=xx
27 | ```
28 |
29 | ## reference
30 |
31 | 1. 阿里云机器学习 PAI-DLC:[https://www.aliyun.com/activity/bigdata/pai-dlc](https://www.aliyun.com/activity/bigdata/pai-dlc)
32 |
--------------------------------------------------------------------------------
/docs/zh/tutorial/tutorial_grpo_fsdp.md:
--------------------------------------------------------------------------------
1 | # 基于 FSDP 的端到端GRPO训练流程
2 |
3 | 本文档提供使用 ChatLearn、PyTorch FSDP 和 vLLM 框架来对Qwen3模型进行GRPO训练的快速开始指南。
4 |
5 | ## 环境配置
6 | 1. Docker镜像准备
7 | 我们建议在PAI [DSW](https://help.aliyun.com/zh/pai/user-guide/create-and-manage-dsw-instances/)/[DLC](https://help.aliyun.com/zh/pai/user-guide/create-a-training-task?spm=a2c4g.11186623.help-menu-30347.d_3_3_5_5.2dfb1925l3QjwG)中运行该示例,你需要填写如下镜像地址来启动实例:
8 | ```bash
9 | dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312
10 | ```
11 |
12 | 可以使用vpc地址来加速镜像拉取速度,需要根据当前region信息来更改镜像地址。比如,启动在上海的DSW实例,可以使用如下镜像`dsw-registry-vpc.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312`。
13 |
14 | 2. 代码准备
15 |
16 | ```bash
17 | git clone https://github.com/alibaba/ChatLearn.git && cd ChatLearn
18 | ```
19 |
20 | ## 数据准备
21 | 以[MATH-lighteval](https://www.modelscope.cn/datasets/AI-ModelScope/MATH-lighteval)数据集作为示例.
22 | ```bash
23 | # 下载数据集
24 | mkdir -p dataset
25 | modelscope download --dataset AI-ModelScope/MATH-lighteval --local_dir dataset/MATH-lighteval
26 | # 数据集预处理
27 | python examples/fsdp/data/data_preprocess/math_lighteval.py --input_dir dataset/MATH-lighteval --local_dir dataset/MATH-lighteval
28 | # 下载模型权重
29 | modelscope download --model Qwen/Qwen3-8B --local_dir Qwen3-8B
30 | ```
31 |
32 | ## 训练
33 | 运行以下命令开始训练:
34 |
35 | ```bash
36 | bash examples/fsdp/scripts/train_grpo_qwen3.sh
37 | ```
38 |
39 | ## 使用 Wandb 监控
40 | 如需使用 Wandb 记录训练过程,请修改[train_grpo_qwen3.sh](../../../examples/fsdp/scripts/train_grpo_qwen3.sh)中的配置:
41 |
42 | ```bash
43 | export enable_wandb=True
44 | export wandb_project="Your-Wandb-Project-Name"
45 | export WANDB_API_KEY="Your-Wandb-api-key"
46 | ```
--------------------------------------------------------------------------------
/docs/zh/tutorial/tutorial_grpo_mcore.md:
--------------------------------------------------------------------------------
1 | # 基于 Mcore 的端到端GRPO训练流程
2 |
3 | 本文档提供使用 ChatLearn、Mcore 和 vLLM 框架来对Qwen2.5模型进行GRPO训练的快速开始指南。
4 |
5 | ## 环境配置
6 | 1. Docker镜像准备
7 | 我们建议在PAI [DSW](https://help.aliyun.com/zh/pai/user-guide/create-and-manage-dsw-instances/)/[DLC](https://help.aliyun.com/zh/pai/user-guide/create-a-training-task?spm=a2c4g.11186623.help-menu-30347.d_3_3_5_5.2dfb1925l3QjwG)中运行该示例,你需要填写如下镜像地址来启动实例:
8 | ```bash
9 | dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312
10 | ```
11 |
12 | 可以使用vpc地址来加速镜像拉取速度,需要根据当前region信息来更改镜像地址。比如,启动在上海的DSW实例,可以使用如下镜像`dsw-registry-vpc.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312`。
13 |
14 | 2. 代码准备
15 |
16 | ```bash
17 | git clone https://github.com/NVIDIA/Megatron-LM.git
18 | cd Megatron-LM && git checkout 6ba97dd37150a6bfba03d31808674211cf2a4d0d
19 | git clone https://github.com/alibaba/ChatLearn.git && cd ChatLearn
20 | ```
21 |
22 | ## 数据准备
23 | 以[MATH-lighteval](https://www.modelscope.cn/datasets/AI-ModelScope/MATH-lighteval)数据集作为示例.
24 | ```bash
25 | # 下载数据集
26 | mkdir -p dataset
27 | modelscope download --dataset AI-ModelScope/MATH-lighteval --local_dir dataset/MATH-lighteval
28 | # 数据集预处理
29 | python examples/fsdp/data/data_preprocess/math_lighteval.py --input_dir dataset/MATH-lighteval --local_dir dataset/MATH-lighteval
30 | # 下载模型权重
31 | modelscope download --model Qwen/Qwen2.5-7B-Instruct --local_dir Qwen2.5-7B-Instruct
32 | ```
33 |
34 | ## 模型转换
35 |
36 | 模型格式转换可以参考 [Pai-Megatron-Patch](https://github.com/alibaba/Pai-Megatron-Patch) 项目提供的转换脚本。
37 | 高性能分布式权重转换可以参考:https://github.com/alibaba/Pai-Megatron-Patch/tree/main/toolkits/distributed_checkpoints_convertor
38 |
39 | 运行`hf2mcore_qwen2.5_convertor.sh`脚本,需要传入的参数列表如下
40 | ```
41 | MODEL_SIZE=$1 # 模型参数:0.5B/1.5B/3B/7B/14B/32B/72B
42 | SOURCE_CKPT_PATH=$2 # 源路径
43 | TARGET_CKPT_PATH=$3 # 目标路径
44 | TP=$4 # 模型并行度
45 | PP=$5 # 流水并行度
46 | PR=$6 # 转换精度
47 | USE_TE=$7 # 是否使用Transformer Engine建模
48 | mg2hf=$8 # 是否执行mcore2hf转换
49 | HG_CKPT_PATH=$9 # HF的CKPT的路径
50 | ```
51 |
52 | 例如,使用下述脚本将7B量级的Qwen2.5的Huggingface格式的模型转换到MCore格式
53 | ```bash
54 | git clone --recurse-submodules https://github.com/alibaba/Pai-Megatron-Patch.git
55 | cd ~/Pai-Megatron-Patch/toolkits/model_checkpoints_convertor/qwen
56 | bash hf2mcore_qwen2.5_convertor.sh \
57 | 7B \
58 | /mnt/qwen-ckpts/Qwen2.5-7B-Instruct \
59 | /mnt/qwen-ckpts/Qwen2.5-7B-Instruct-hf-to-mcore-tp4-pp1 \
60 | 4 \
61 | 1 \
62 | bf16 \
63 | true \
64 | false
65 | ```
66 |
67 | ## 训练
68 | 运行以下命令开始训练:
69 |
70 | ```bash
71 | export MEGATRON_PATH="your megatron path"
72 | bash examples/mcore/scripts/train_grpo_qwen2_5.sh
73 | ```
74 |
75 | ## 使用 Wandb 监控
76 | 如需使用 Wandb 记录训练过程,请修改[train_grpo_qwen2_5.sh](../../../examples/mcore/scripts/train_grpo_qwen2_5.sh)中的配置:
77 |
78 | ```bash
79 | export enable_wandb=True
80 | export wandb_project="Your-Wandb-Project-Name"
81 | export WANDB_API_KEY="Your-Wandb-api-key"
82 | ```
--------------------------------------------------------------------------------
/docs/zh/tutorial/tutorial_qwen.md:
--------------------------------------------------------------------------------
1 | # 基于 Qwen 模型的端到端训练教程
2 |
3 | 本文档介绍如何基于 ChatLearn、DeepSpeed 框架和 Qwen 模型进行 DPO 训练。
4 |
5 | **以下是这个 Tutorial 脚本中使用的通用环境变量集合:**
6 |
7 | | ENV | 含义 |
8 | | --- |-------------------------------------------------------------------------------------------------------------------------------|
9 | | `CHATLEARN` | ChatLearn 代码仓库 clone 存放的位置 [https://github.com/alibaba/ChatLearn.git](https://github.com/alibaba/ChatLearn.git) |
10 | | `DATASET_ROOT` | 存放训练数据集合的根目录 |
11 |
12 |
13 | ## Setup: 镜像、代码、数据准备
14 |
15 | ### 镜像和代码
16 |
17 | 请参考 [镜像和代码准备](../installation.md)。
18 |
19 | ### 数据
20 |
21 | qwen2 要求的数据格式为chatml
22 |
23 | ```
24 | {"type": "chatml", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Tell me something about large language models."}, {"role": "assistant", "content": "Large language models are a type of language model that is trained on a large corpus of text data. They are capable of generating human-like text and are used in a variety of natural language processing tasks..."}], "source": "unknown"}
25 | ```
26 | 通过以下脚本可以将 `Dahoas/full-hh-rlhf` 转换为 chatml 格式的数据,并存储为 `$DATASET_ROOT/alignment/train.jsonl` 文件。
27 |
28 | ```bash
29 | cd ${CHATLEARN}/examples/huggingface/
30 | DATASET_ROOT=path-to-dataset-root
31 | python data/preprocess_data_chatml.py $DATASET_ROOT
32 | ```
33 |
34 |
35 | ### DPO
36 |
37 | 以下是一个 Qwen2-7B 的 DPO 训练范例。
38 | 在这个例子中,用户需要设置 `policy_model_path` 为 初始化模型 checkpoint 路径,Policy 模型和 Reference 模型将以这个 checkpoint 初始化。
39 |
40 | ```
41 | export CHATLEARN=path-to-chatlearn
42 | export DATASET_PATH=$DATASET_ROOT/alignment/train.jsonl
43 | export policy_model_path=path-to-qwen2-ckpt
44 | cd ${CHATLEARN}/examples/huggingface/
45 | bash scripts/train_dpo_qwen.sh
46 | ```
47 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/examples/__init__.py
--------------------------------------------------------------------------------
/examples/fsdp/configs/grpo/base.yaml:
--------------------------------------------------------------------------------
1 | seed: 1234
2 |
--------------------------------------------------------------------------------
/examples/fsdp/configs/grpo/grpo.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 |
9 |
10 | models:
11 | policy:
12 | model_config_file: vllm_policy_inference.yaml
13 | num_gpu: ${num_device:1}
14 | trainable: False
15 | batch_generation:
16 | ranking: ${batch_generation_ranking:False}
17 | min_prompt_length: ${batch_generation_min_prompt_length:0}
18 | free_memory: ${free_memory_policy:True}
19 | generation_batch_size: ${vllm_generation_batch_size:256}
20 |
21 | reward:
22 | model_config_file: base.yaml
23 | num_cpu: 2 # must set devices
24 | generation_batch_size: ${vllm_generation_batch_size:256}
25 |
26 | ref_policy:
27 | model_config_file: reference.yaml
28 | num_gpu: ${num_device:1}
29 | gpu_per_process: 1
30 | fsdp_size: -1
31 | trainable: False
32 | free_memory: ${free_memory_reference:True}
33 | generation_batch_size: ${ref_generation_batch_size:8}
34 |
35 |
36 | policy_trainer:
37 | model_config_file: policy_trainer.yaml
38 | num_gpu: ${num_device:1}
39 | gpu_per_process: 1
40 | fsdp_size: -1
41 | sp_size: ${sp_size:1}
42 | trainable: True
43 | free_memory: ${free_memory_ppo_policy:True}
44 | generation_batch_size: ${trainer_generation_batch_size:8}
45 |
46 | runtime:
47 | colocation:
48 | - policy,policy_trainer,ref_policy
49 | data_path: ${train_data_path}
50 | eval_data_path: ${eval_data_path}
51 | output_dir: ${output_dir}
52 | exp_name: ${exp_name}
53 | num_episode: ${num_episode:200}
54 | sample_per_episode: ${sample_per_episode:1024}
55 | train_micro_batch_size: ${train_micro_batch_size:1}
56 | train_global_batch_size: ${train_global_batch_size:256}
57 | save_episode_interval: ${save_episode_interval:20}
58 | max_relay_episode: 2 # needed to enable GRPO advantage computation
59 | eval_episode_interval: ${eval_episode_interval:5}
60 | log_config_file: log.yaml
61 | data_checkpoint_path: ${data_checkpoint_path}
62 | enable_eval_before_training: ${enable_eval_before_training:False}
63 |
--------------------------------------------------------------------------------
/examples/fsdp/configs/grpo/log.yaml:
--------------------------------------------------------------------------------
1 | log_dir: ${log_dir:}
2 |
3 | enable_tensorboard: ${enable_tensorboard:False}
4 | tensorboard_dir: ${tensorboard_dir:<>}
5 |
6 | enable_wandb: ${enable_wandb:False}
7 | wandb_dir: ${wandb_dir}
8 | wandb_project: ${wandb_project}
9 | wandb_id: ${exp_name}
10 | wandb_name: ${exp_name}
11 | wandb_resume: ${WANDB_RESUME:allow}
12 |
13 | # export WANDB_DISABLE_CODE="true"
14 | # export WANDB_IGNORE_GLOBS="*.patch"
15 |
--------------------------------------------------------------------------------
/examples/fsdp/configs/grpo/policy_trainer.yaml:
--------------------------------------------------------------------------------
1 | includes:
2 | - base.yaml
3 | pretrain_or_model: ${model_path}
4 | learning_rate: 2e-6
5 | grad_clip: 1
6 | gradient_checkpointing: True
7 | # for grpo algorithm
8 | pos_clip_ratio: 0.2
9 | negative_clip_ratio: 0.2
10 | save_hf: True
--------------------------------------------------------------------------------
/examples/fsdp/configs/grpo/reference.yaml:
--------------------------------------------------------------------------------
1 | includes:
2 | - base.yaml
3 | pretrain_or_model: ${model_path}
--------------------------------------------------------------------------------
/examples/fsdp/configs/grpo/vllm_policy_inference.yaml:
--------------------------------------------------------------------------------
1 | includes:
2 | - base.yaml
3 |
4 | # model partition
5 | tensor_model_parallel_size: ${tensor_model_parallel_size:2}
6 | pipeline_model_parallel_size: 1
7 | num_inference_per_prompt: ${num_inference_per_prompt:8}
8 | seq_length: ${seq_length:1024}
9 | max_new_tokens: ${max_new_tokens:1023}
10 |
11 | # sampling params
12 | temperature: ${policy_temperature:1.0}
13 | top_p: ${policy_top_p:0.9}
14 | top_k: ${policy_top_k:-1}
15 | presence_penalty: ${policy_presence_penalty:0.0}
16 | frequency_penalty: ${policy_frequency_penalty:0.0}
17 | repetition_penalty: ${policy_repetition_penalty:1.0}
18 |
19 | eval_temperature: ${policy_eval_temperature:0.6}
20 | eval_top_k: ${policy_eval_top_k:-1}
21 | eval_top_p: ${policy_eval_top_p:0.9}
22 | eval_presence_penalty: ${policy_eval_presence_penalty:0.0}
23 | eval_frequency_penalty: ${policy_eval_frequency_penalty:0.0}
24 | eval_repetition_penalty: ${policy_eval_repetition_penalty:1.0}
25 |
26 | # dataset
27 | vllm_prompt_key: ${vllm_prompt_key:prompt}
28 | vllm_input_ids_key: ${vllm_input_ids_key:input_ids}
29 | enable_thinking: ${enable_thinking:False}
30 |
31 | # scheduler config
32 | max_num_batched_tokens: ${max_num_batched_tokens:32768}
33 | max_seq_len_to_capture: ${max_seq_len_to_capture:32768}
34 | enable_stage_resume: ${enable_policy_stage_resume:False}
35 |
36 | # cache config
37 | gpu_memory_utilization: ${gpu_memory_utilization:0.85}
38 | enforce_eager: False
39 |
40 | # tokenizer path
41 | tokenizer: ${model_path}
--------------------------------------------------------------------------------
/examples/fsdp/data/data_preprocess/gsm8k.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess the GSM8k dataset to json format
3 | """
4 | import re
5 | import os
6 | import argparse
7 |
8 | import datasets
9 |
10 |
def extract_solution(solution_str):
    """Return the final numeric answer that follows the ``#### `` marker.

    The first ``#### <number>`` occurrence in ``solution_str`` is located and
    the number is returned as a string with every comma removed.

    Raises:
        AssertionError: if ``solution_str`` contains no ``#### <number>`` marker.
    """
    match = re.search(r"#### (\-?[0-9\.\,]+)", solution_str)
    assert match is not None
    # The capture group is exactly the text after the "#### " marker.
    return match.group(1).replace(',', '')
17 |
18 |
if __name__ == '__main__':
    # Command-line entry point: load GSM8k, attach a prompt / ground-truth
    # record to every example, and dump the train and test splits as JSON.
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', default=None)
    parser.add_argument('--local_dir', default='~/data/gsm8k')

    args = parser.parse_args()

    data_source = 'openai/gsm8k'
    # Load from --input_dir when given, otherwise use the dataset identifier.
    data_dir = data_source if args.input_dir is None else args.input_dir

    dataset = datasets.load_dataset(data_dir, 'main')

    train_dataset = dataset['train']
    test_dataset = dataset['test']

    # Alternative instruction kept for reference:
    # instruction_following = "Let's think step by step and output the final answer after \"####\"."
    instruction_following = "Let's think step by step and output the final answer within \\boxed{}."

    def make_map_fn(split):
        """Build the per-example mapping function for the given split name."""

        def process_fn(example, idx):
            """Convert one raw example into the prompt / reward record."""
            question_raw = example.pop('question')
            question = question_raw + ' ' + instruction_following

            answer_raw = example.pop('answer')
            solution = extract_solution(answer_raw)
            return {
                "data_source": data_source,
                "prompt": [{
                    "role": "user",
                    "content": question,
                }],
                "ability": "math",
                "reward_model": {
                    "style": "rule",
                    "ground_truth": solution
                },
                "extra_info": {
                    'split': split,
                    # idx is the example's index within its split
                    'index': idx,
                    'answer': answer_raw,
                    "question": question_raw,
                }
            }

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
    test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)

    # Expand "~" (present in the default value) so the files are not written
    # under a literal "~" directory, and make sure the target directory exists.
    local_dir = os.path.expanduser(args.local_dir)
    os.makedirs(local_dir, exist_ok=True)

    train_dataset.to_json(os.path.join(local_dir, 'train.json'))
    test_dataset.to_json(os.path.join(local_dir, 'test.json'))
76 |
--------------------------------------------------------------------------------
/examples/fsdp/data/prompt_dataset.py:
--------------------------------------------------------------------------------
1 | """prompt dataset"""
2 |
3 | import copy
4 | from collections import defaultdict
5 | from typing import List, Dict
6 |
7 | import torch
8 | from torch.utils.data import Dataset
9 | import torch.nn.functional as F
10 | from transformers import AutoTokenizer
11 |
12 | from chatlearn.utils.utils import multi_thread_data_processing
13 |
14 |
class VLLMPromptPipeline(Dataset):
    """
    Prompt dataset built from records of this format
    {
        "data_source": data_source,
        "prompt": [{
            "role": "user",
            "content": question,
        }],
        "ability": "math",
        "reward_model": {
            "style": "rule",
            "ground_truth": solution
        },
        "extra_info": {
            'split': split,
            'index': idx,
            'answer': answer_raw,
            "question": question_raw,
        }
    }
    Each entry of self.data has the format
        {"input_ids": prompt_ids, "prompt": prompt, "data_source": data_source, "ground_truth": ground_truth}
    """

    def __init__(self, data_list: List[Dict], seq_length: int, tokenizer: AutoTokenizer = None, num_inference_per_prompt: int = 1, enable_thinking = False):
        """Tokenize every record and keep those shorter than ``seq_length``.

        A ``prompt`` given as a list of chat messages is first rendered to text
        with the tokenizer's chat template. Each kept record is repeated
        ``num_inference_per_prompt`` times.
        """
        super().__init__()

        self.tokenizer = tokenizer
        self.data = []

        for record in data_list:
            prompt_text = record["prompt"]
            source = record.get("data_source", "")
            answer = record['reward_model']['ground_truth']

            if isinstance(prompt_text, list):
                prompt_text = self.tokenizer.apply_chat_template(
                    prompt_text,
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=enable_thinking,
                )
            token_ids = self.tokenizer.encode(prompt_text)

            # Prompts whose token count reaches seq_length are dropped.
            if len(token_ids) >= seq_length:
                continue

            sample = {
                "input_ids": token_ids,
                "prompt": prompt_text,
                "data_source": source,
                "ground_truth": answer,
            }
            # NOTE: the repeated entries all reference the same dict object.
            self.data += [sample] * num_inference_per_prompt

    def __getitem__(self, ix: int):
        """Return the processed sample stored at position ``ix``."""
        return self.data[ix]

    def __len__(self) -> int:
        """Return the number of stored samples, repetitions included."""
        return len(self.data)

    def collate_fn(self, samples):
        """Group a list of sample dicts into a dict mapping each key to a list of values."""
        batch = defaultdict(list)
        for sample in samples:
            for key, value in sample.items():
                batch[key].append(value)
        return batch
--------------------------------------------------------------------------------
/examples/fsdp/models/grpo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/examples/fsdp/models/grpo/__init__.py
--------------------------------------------------------------------------------
/examples/fsdp/models/grpo/loss_gallery.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 |
def calculate_grpo_loss(
    log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    diff_clip_ratio: float = 10,
    pos_clip_ratio: float = 0.2,
    negative_clip_ratio: float = 0.2,
    final_clip_ratio: float = 0.01,
):
    """Compute the clipped policy-gradient loss element-wise.

    Args:
        log_probs: log-probabilities under the current policy.
        old_log_probs: log-probabilities under the old policy.
        advantages: advantages; a trailing dimension is added before they are
            multiplied with the probability ratio.
        diff_clip_ratio: upper bound applied to ``log_probs - old_log_probs``
            before exponentiation.
        pos_clip_ratio: the ratio is clamped to at most ``1 + pos_clip_ratio``.
        negative_clip_ratio: the ratio is clamped to at least
            ``1 - negative_clip_ratio``.
        final_clip_ratio: upper bound applied to the resulting loss values.

    Returns:
        The contiguous, unreduced loss tensor.

    Raises:
        AssertionError: if the loss contains NaN.
    """
    neg_adv = -advantages.unsqueeze(-1)

    # Bound the log-ratio from above before exp() so it cannot overflow.
    log_ratio = torch.clamp(log_probs - old_log_probs, max=diff_clip_ratio)
    prob_ratio = torch.exp(log_ratio)

    unclipped = neg_adv * prob_ratio
    # Same objective with the ratio restricted to [1 - neg, 1 + pos].
    ratio_clipped = neg_adv * torch.clamp(
        prob_ratio, 1 - negative_clip_ratio, 1 + pos_clip_ratio
    )
    pessimistic = torch.maximum(unclipped, ratio_clipped)

    # Final cap on the loss values themselves.
    cap = torch.ones_like(unclipped) * final_clip_ratio
    loss = torch.minimum(pessimistic, cap)

    assert not torch.isnan(loss).any(), "pg loss is nan"

    return loss.contiguous()
31 |
--------------------------------------------------------------------------------
/examples/fsdp/models/rule_reward.py:
--------------------------------------------------------------------------------
1 |
2 | # Copyright 2024 Alibaba-inc. and/or its affiliates
3 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
16 |
17 | import argparse
18 | import copy
19 | import os
20 | from collections import defaultdict
21 | from typing import Dict
22 |
23 | import torch
24 | import numpy as np
25 |
26 | from chatlearn import BaseModule
27 | from examples.fsdp.utils.rule_reward_score import math
28 |
29 |
class RuleReward(BaseModule):
    """Reward module that scores generated outputs with rule-based functions."""

    def setup(self):
        """Initialise the stats dict and the metric prefix."""
        self.stats = {}
        self._metric_prefix = "rulereward"

    def _forward_step(self, data: Dict) -> Dict:
        """Score every output in ``data`` against its ground truth.

        Args:
            data: dict with the parallel lists ``"str_outputs"``,
                ``"data_source"`` and ``"ground_truth"``.

        Returns:
            ``{"rule_rewards": tensor}`` where the tensor is float32 with
            shape ``[len(str_outputs), 1]``.

        Raises:
            NotImplementedError: if a data source has no scoring function.
        """
        str_outputs_list = data["str_outputs"]
        data_source_list = data["data_source"]
        ground_truth_list = data["ground_truth"]
        self._logger.info(f"RuleReward _forward_step Num of request: {len(str_outputs_list)}")

        reward_tensor = torch.zeros([len(str_outputs_list), 1], dtype=torch.float32)

        for i, str_output in enumerate(str_outputs_list):
            # The scoring function is chosen per sample from its data source.
            compute_score_fn = self.select_rule_reward_score_fn(data_source_list[i])
            reward_tensor[i] = compute_score_fn(str_output, ground_truth_list[i])

        return {"rule_rewards": reward_tensor}

    def forward_step(self, data: Dict, iteration=0) -> Dict:
        """Score ``data`` and record the mean reward in the metric list."""
        res_dict = self._forward_step(data)

        # collect stats
        train_reward_stats = {
            "train_reward_score": res_dict["rule_rewards"].mean().item(),
        }
        self._metric_list.append(train_reward_stats)
        return res_dict

    def eval_forward(self, data: Dict) -> Dict:
        """Score ``data`` without recording any metric."""
        return self._forward_step(data)

    def select_rule_reward_score_fn(self, data_source: str):
        """Return the scoring function registered for ``data_source``.

        Raises:
            NotImplementedError: if ``data_source`` is not supported.
        """
        if data_source in ['openai/gsm8k', 'DigitalLearningGmbH/MATH-lighteval']:
            return math.compute_score
        raise NotImplementedError(
            f"No rule reward score function for data_source: {data_source!r}"
        )
77 |
--------------------------------------------------------------------------------
/examples/fsdp/scripts/base_env.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Shared environment setup sourced by the FSDP example training scripts.
# Stops any running Ray instance, fills in single-node defaults for unset
# distributed variables, and extends PYTHONPATH. Requires CHATLEARN to be set.

ray stop

export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_DEBUG=WARN

# Defaults apply only when the variable is unset or empty.
[ -z "$MASTER_ADDR" ] && export MASTER_ADDR=localhost
[ -z "$WORLD_SIZE" ] && export WORLD_SIZE=1
[ -z "$GPUS_PER_NODE" ] && export GPUS_PER_NODE=${KUBERNETES_CONTAINER_RESOURCE_GPU:-$(python -c "import torch; print(torch.cuda.device_count())")}
[ -z "$RANK" ] && export RANK=0
[ -z "$MASTER_PORT" ] && export MASTER_PORT=12456
[ -z "$NNODES" ] && export NNODES=${WORLD_SIZE:-1}
[ -z "$NODE_RANK" ] && export NODE_RANK=${RANK:-0}
if [ -z "${CUSTOM_PORTS}" ]; then
  # Build "30010;30011;...;30050" with tracing off to keep the xtrace output short.
  set +x
  ports="30010"
  for i in $(seq 30011 30050); do
    ports="${ports};${i}"
  done
  set -x
  export CUSTOM_PORTS=$ports
  [ -z "$LOCAL_MASTER_ADDR" ] && export LOCAL_MASTER_ADDR=$MASTER_ADDR
  # Print the value actually in effect (it may have been preset by the caller).
  echo LOCAL_MASTER_ADDR=$LOCAL_MASTER_ADDR
fi

if [ -z "$CHATLEARN" ]; then
  echo "please set CHATLEARN path"
  exit 1
fi

# -f keeps rm quiet when no file matching core* exists.
rm -f core*

export PYTHONPATH=${CHATLEARN}:${CHATLEARN}/examples/fsdp:${PYTHONPATH}
export num_device=$(($WORLD_SIZE * $GPUS_PER_NODE))
36 |
--------------------------------------------------------------------------------
/examples/fsdp/scripts/train_grpo_qwen2_5.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Launch GRPO training of Qwen2.5-7B-Instruct with the FSDP example.
# Run from the ChatLearn repository root: CHATLEARN is taken from $(pwd).
set -x

# set path
export CHATLEARN=$(pwd)
export model_path="${CHATLEARN}/Qwen2.5-7B-Instruct"
export exp_name=qwen2.5-grpo
export output_dir=${CHATLEARN}/output/${exp_name}
export train_data_path=${CHATLEARN}/dataset/MATH-lighteval/train.json
export eval_data_path=${CHATLEARN}/dataset/MATH-lighteval/test.json
export data_checkpoint_path=${output_dir}/data_checkpoint_path/
# RANK only receives its default inside base_env.sh (sourced below), so
# default it here too to avoid a log file named "log_.log".
log_file=$output_dir/log_${RANK:-0}.log
mkdir -p $output_dir/
export log_dir=${output_dir}
export wandb_dir=${output_dir}

cd $CHATLEARN/examples/fsdp/
source scripts/base_env.sh

# Env setup
# export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export RAY_DEDUP_LOGS=1
export NCCL_NVLS_ENABLE=0

# log setup
export enable_wandb=False
export wandb_project="grpo-exp"
export WANDB_API_KEY="wandb-api-key"

# Setup sequence_parallel
export sp_size=1

# VLLM setup
export VLLM_USE_RAY_SPMD_WORKER=1
export VLLM_USE_RAY_COMPILED_DAG=1

export tensor_model_parallel_size=2
export policy_temperature=1.0
export policy_top_p=1.0
export policy_top_k=-1
export policy_eval_temperature=0.6
export policy_eval_top_p=0.95
export policy_eval_top_k=20

export seq_length=2048
export max_new_tokens=2048
export max_seq_len_to_capture=2348
export num_inference_per_prompt=32
export train_global_batch_size=2048
export sample_per_episode=2048
export vllm_generation_batch_size=128
export train_micro_batch_size=16
export gpu_memory_utilization=0.85

export enable_eval_before_training=True
export num_episode=20
export eval_episode_interval=5
export save_episode_interval=20

# log_file already ends in ".log"; exit with the status of python, not of tee.
python entry/train_grpo.py -c configs/grpo/grpo.yaml 2>&1 | tee ${log_file} ; exit ${PIPESTATUS[0]}
61 |
--------------------------------------------------------------------------------
/examples/fsdp/scripts/train_grpo_qwen3.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Launch GRPO training of Qwen3-8B with the FSDP example.
# Run from the ChatLearn repository root: CHATLEARN is taken from $(pwd).
set -x

# set path
export CHATLEARN=$(pwd)
export model_path="${CHATLEARN}/Qwen3-8B"
export exp_name=qwen3-grpo
export output_dir=${CHATLEARN}/output/${exp_name}
export train_data_path=${CHATLEARN}/dataset/MATH-lighteval/train.json
export eval_data_path=${CHATLEARN}/dataset/MATH-lighteval/test.json
export data_checkpoint_path=${output_dir}/data_checkpoint_path/
# RANK only receives its default inside base_env.sh (sourced below), so
# default it here too to avoid a log file named "log_.log".
log_file=$output_dir/log_${RANK:-0}.log
mkdir -p $output_dir/
export log_dir=${output_dir}
export wandb_dir=${output_dir}

cd $CHATLEARN/examples/fsdp/
source scripts/base_env.sh

# Env setup
# export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export RAY_DEDUP_LOGS=1
export NCCL_NVLS_ENABLE=0

# log setup
export enable_wandb=False
export wandb_project="grpo-exp"
export WANDB_API_KEY="wandb-api-key"

# Setup sequence_parallel
export sp_size=1

# VLLM setup
export VLLM_USE_RAY_SPMD_WORKER=1
export VLLM_USE_RAY_COMPILED_DAG=1

export tensor_model_parallel_size=1
export policy_temperature=1.0
export policy_top_p=1.0
export policy_top_k=-1
export policy_eval_temperature=0.6
export policy_eval_top_p=0.95
export policy_eval_top_k=20

export seq_length=2048
export max_new_tokens=2048
export max_seq_len_to_capture=2348
export num_inference_per_prompt=32
export train_global_batch_size=2048
export sample_per_episode=2048
export vllm_generation_batch_size=256
export train_micro_batch_size=8
export gpu_memory_utilization=0.80

export enable_eval_before_training=False
export num_episode=200
export eval_episode_interval=5
export save_episode_interval=400
# Qwen3-specific switch; read as ${enable_thinking} by vllm_policy_inference.yaml.
export enable_thinking=False

# log_file already ends in ".log"; exit with the status of python, not of tee.
python entry/train_grpo.py -c configs/grpo/grpo.yaml 2>&1 | tee ${log_file} ; exit ${PIPESTATUS[0]}
63 |
--------------------------------------------------------------------------------
/examples/fsdp/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/examples/fsdp/utils/__init__.py
--------------------------------------------------------------------------------
/examples/fsdp/utils/rule_reward_score/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/examples/fsdp/utils/rule_reward_score/__init__.py
--------------------------------------------------------------------------------
/examples/mcore/configs/grpo/base.yaml:
--------------------------------------------------------------------------------
1 | includes:
2 | - log.yaml
3 |
4 | seed: 1234
5 | tokenizer_type: ${tokenizer_type:NullTokenizer}
6 | patch_tokenizer_type: ${patch_tokenizer_type:NullTokenizer}
7 | tokenizer_model: ${tokenizer_model:"None"}
8 | vocab_file: ${vocab_file:"None"}
9 | merge_file: ${merge_file:"None"}
10 | vocab_size: ${vocab_size:32000}
11 | extra_vocab_size: ${extra_vocab_size:421}
12 | make_vocab_size_divisible_by: ${make_vocab_size_divisible_by:128}
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/examples/mcore/configs/grpo/grpo_qwen2_5.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 |
9 |
10 | models:
11 | policy:
12 | model_config_file: vllm_policy_inference.yaml
13 | num_gpu: ${num_device:1}
14 | trainable: False
15 | batch_generation:
16 | ranking: ${batch_generation_ranking:False}
17 | min_prompt_length: ${batch_generation_min_prompt_length:0}
18 | free_memory: ${free_memory_policy:True}
19 | generation_batch_size: ${vllm_generation_batch_size:8}
20 |
21 | reward:
22 | model_config_file: base.yaml
23 | num_cpu: 2 # must set devices
24 | generation_batch_size: ${vllm_generation_batch_size:8}
25 |
26 | ref_policy:
27 | model_config_file: policy_trainer_qwen2_5.yaml
28 | num_gpu: ${num_device:1}
29 | gpu_per_process: 1
30 | trainable: False
31 | free_memory: ${free_memory_reference:True}
32 | sync_frequency: ${sync_frequency:-1}
33 | generation_batch_size: ${trainer_generation_batch_size:8}
34 |
35 | policy_trainer:
36 | model_config_file: policy_trainer_qwen2_5.yaml
37 | num_gpu: ${num_device:1}
38 | gpu_per_process: 1
39 | trainable: True
40 | free_memory: ${free_memory_ppo_policy:True}
41 | generation_batch_size: ${trainer_generation_batch_size:8}
42 |
43 | runtime:
44 | colocation:
45 | - policy,policy_trainer,ref_policy
46 | data_path: ${train_data_path}
47 | eval_data_path: ${eval_data_path}
48 | output_dir: ${output_dir}
49 | exp_name: ${exp_name}
50 | num_episode: ${num_episode:200}
51 | sample_per_episode: ${sample_per_episode:1024}
52 | train_micro_batch_size: ${train_micro_batch_size:1}
53 | train_global_batch_size: ${train_global_batch_size:256}
54 | save_episode_interval: ${save_episode_interval:20}
55 | max_relay_episode: 2 # to enable GRPO advantage computation
56 | eval_episode_interval: ${eval_episode_interval:5}
57 | log_config_file: log.yaml
58 | data_checkpoint_path: ${data_checkpoint_path}
59 |
--------------------------------------------------------------------------------
/examples/mcore/configs/grpo/grpo_qwen3.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 |
9 |
10 | models:
11 | policy:
12 | model_config_file: vllm_policy_inference.yaml
13 | num_gpu: ${num_device:1}
14 | trainable: False
15 | batch_generation:
16 | ranking: ${batch_generation_ranking:False}
17 | min_prompt_length: ${batch_generation_min_prompt_length:0}
18 | free_memory: ${free_memory_policy:True}
19 | generation_batch_size: ${vllm_generation_batch_size:8}
20 |
21 | reward:
22 | model_config_file: base.yaml
23 | num_cpu: 2 # must set devices
24 | generation_batch_size: ${vllm_generation_batch_size:8}
25 |
26 | ref_policy:
27 | model_config_file: policy_trainer_qwen3.yaml
28 | num_gpu: ${num_device:1}
29 | gpu_per_process: 1
30 | trainable: False
31 | free_memory: ${free_memory_reference:True}
32 | sync_frequency: ${sync_frequency:-1}
33 | generation_batch_size: ${trainer_generation_batch_size:8}
34 |
35 | policy_trainer:
36 | model_config_file: policy_trainer_qwen3.yaml
37 | num_gpu: ${num_device:1}
38 | gpu_per_process: 1
39 | trainable: True
40 | free_memory: ${free_memory_ppo_policy:True}
41 | generation_batch_size: ${trainer_generation_batch_size:8}
42 |
43 | runtime:
44 | colocation:
45 | - policy,policy_trainer,ref_policy
46 | data_path: ${train_data_path}
47 | eval_data_path: ${eval_data_path}
48 | output_dir: ${output_dir}
49 | exp_name: ${exp_name}
50 | num_episode: ${num_episode:200}
51 | sample_per_episode: ${sample_per_episode:1024}
52 | train_micro_batch_size: ${train_micro_batch_size:1}
53 | train_global_batch_size: ${train_global_batch_size:256}
54 | save_episode_interval: ${save_episode_interval:20}
55 | max_relay_episode: 2 # to enable GRPO advantage computation
56 | eval_episode_interval: ${eval_episode_interval:5}
57 | log_config_file: log.yaml
58 | data_checkpoint_path: ${data_checkpoint_path}
59 |
--------------------------------------------------------------------------------
/examples/mcore/configs/grpo/log.yaml:
--------------------------------------------------------------------------------
1 | log_dir: ${log_dir:"None"}
2 | exp_name: ${exp_name:test}
3 | tensorboard_dir: ${tensorboard_dir:"None"}
4 | enable_tensorboard: ${enable_tensorboard:False}
5 |
6 | enable_wandb: ${enable_wandb:False}
7 | wandb_dir: ${wandb_dir}
8 | wandb_project: ${wandb_project}
9 | wandb_id: ${exp_name}
10 | wandb_name: ${exp_name}
11 | wandb_resume: ${WANDB_RESUME:allow}
--------------------------------------------------------------------------------
/examples/mcore/configs/grpo/model_qwen2_5.yaml:
--------------------------------------------------------------------------------
1 | attention_dropout: ${attention_dropout:0.0}
2 | hidden_dropout: ${hidden_dropout:0.0}
3 | num_layers: ${policy_num_layers:28}
4 | hidden_size: ${policy_hidden_size:3584}
5 | num_attention_heads: ${policy_num_attention_heads:28}
6 | ffn_hidden_size: ${policy_ffn_hidden_size:18944}
7 | num_query_groups: ${policy_num_query_groups:4}
8 | seq_length: ${seq_length:2048}
9 | max_position_embeddings: ${max_position_embeddings:131072}
10 | swiglu: True
11 | normalization: ${normalization:RMSNorm}
12 | norm_epsilon: ${RMS_NORM_EPS:1e-6}
13 | use_rotary_position_embeddings: True
14 | position_embedding_type: ${position_embedding_type:rope}
15 | add_qkv_bias: true
16 | add_bias_linear: false
17 | rotary_percent: 1.0
18 | rotary_base: 1000000
19 | rotary_seq_len_interpolation_factor: 1
20 | group_query_attention: True
21 | use_legacy_models: false
22 | untie_embeddings_and_output_weights: True
--------------------------------------------------------------------------------
/examples/mcore/configs/grpo/model_qwen3.yaml:
--------------------------------------------------------------------------------
1 | attention_dropout: ${attention_dropout:0.0}
2 | hidden_dropout: ${hidden_dropout:0.0}
3 | num_layers: ${policy_num_layers:28}
4 | hidden_size: ${policy_hidden_size:3584}
5 | num_attention_heads: ${policy_num_attention_heads:28}
6 | ffn_hidden_size: ${policy_ffn_hidden_size:18944}
7 | num_query_groups: ${policy_num_query_groups:4}
8 | seq_length: ${seq_length:2048}
9 | max_position_embeddings: ${max_position_embeddings:131072}
10 | swiglu: True
11 | normalization: ${normalization:RMSNorm}
12 | norm_epsilon: ${RMS_NORM_EPS:1e-6}
13 | use_rotary_position_embeddings: True
14 | position_embedding_type: ${position_embedding_type:rope}
15 | add_qkv_bias: false
16 | add_bias_linear: false
17 | qk_layernorm: true
18 | kv_channels: 128
19 | rotary_percent: 1.0
20 | rotary_base: 1000000
21 | rotary_seq_len_interpolation_factor: 1
22 | group_query_attention: True
23 | use_legacy_models: false
24 | untie_embeddings_and_output_weights: True
--------------------------------------------------------------------------------
/examples/mcore/configs/grpo/policy_trainer_qwen2_5.yaml:
--------------------------------------------------------------------------------
1 | includes:
2 | - base.yaml
3 | - model_qwen2_5.yaml
4 |
5 | load: ${load}
6 | save: ${save_dir}
7 | save_interval: ${save_interval:10000}
8 | train_iters: ${train_iters:12000}
9 | tensor_model_parallel_size: ${tensor_model_parallel_size:2}
10 | pipeline_model_parallel_size: ${pipeline_model_parallel_size:1}
11 | distributed_backend: nccl
12 |
13 | clip_grad: ${clip_grad:1.0}
14 | log_interval: 1
15 |
16 | bf16: True
17 | use_checkpoint_opt_param_scheduler: False
18 | adam_beta1: 0.9
19 | adam_beta2: 0.95
20 | num_workers: 8
21 | init_method_std: 0.006
22 |
23 | #recompute_method: uniform
24 | #recompute_granularity: full
25 | recompute_granularity: selective
26 | sequence_parallel: True
27 |
28 | no_load_optim: True
29 | no_load_rng: True
30 | no_load_args: True
31 | no_load_scheduler: True
32 | finetune: True
33 | dummy: 0
34 |
35 |
36 | lr_decay_iters: ${lr_decay_iters:12000}
37 | lr_warmup_iters: ${policy_lr_warmup_iters:100}
38 | lr: ${policy_lr:0.00000008}
39 | min_lr: ${policy_min_lr:0.000000008}
40 | lr_decay_style: ${policy_lr_decay_style:linear}
41 | weight_decay: 0.01
42 | lr_freeze_episodes: ${policy_lr_freeze_episodes:0}
43 |
44 |
45 | init_kl_coef: ${init_kl_coef:0.0}
46 | target: 6
47 | horizon: 10000
48 | gamma: 1
49 | lam: 0.95
50 | cliprange: 0.2
51 | diff_clip_ratio: ${diff_clip_ratio:10}
52 | pos_clip_ratio: ${pos_clip_ratio:0.2}
53 | neg_clip_ratio: ${neg_clip_ratio:0.2}
54 | final_clip_ratio: ${final_clip_ratio:3}
55 | logprob_cliprange: 10
56 | neg_cliprange2: 3
57 | cliprange_value: ${cliprange_value:0.1}
58 | scale_reward: ${scale_reward:null}
59 | clip_onlineness: ${clip_onlineness:0}
60 | cliprange_onlineness: ${cliprange_onlineness:0}
61 | train_to_compare_num_responses: ${train_to_compare_num_responses:2}
--------------------------------------------------------------------------------
/examples/mcore/configs/grpo/policy_trainer_qwen3.yaml:
--------------------------------------------------------------------------------
1 | includes:
2 | - base.yaml
3 | - model_qwen3.yaml
4 |
5 | load: ${load}
6 | save: ${save_dir}
7 | save_interval: ${save_interval:10000}
8 | train_iters: ${train_iters:12000}
9 | tensor_model_parallel_size: ${tensor_model_parallel_size:2}
10 | pipeline_model_parallel_size: ${pipeline_model_parallel_size:1}
11 | distributed_backend: nccl
12 |
13 | clip_grad: ${clip_grad:1.0}
14 | log_interval: 1
15 |
16 | bf16: True
17 | use_checkpoint_opt_param_scheduler: False
18 | adam_beta1: 0.9
19 | adam_beta2: 0.95
20 | num_workers: 8
21 | init_method_std: 0.006
22 |
23 | #recompute_method: uniform
24 | #recompute_granularity: full
25 | recompute_granularity: selective
26 | sequence_parallel: True
27 |
28 | no_load_optim: True
29 | no_load_rng: True
30 | no_load_args: True
31 | no_load_scheduler: True
32 | finetune: True
33 | dummy: 0
34 |
35 |
36 | lr_decay_iters: ${lr_decay_iters:12000}
37 | lr_warmup_iters: ${policy_lr_warmup_iters:100}
38 | lr: ${policy_lr:0.00000008}
39 | min_lr: ${policy_min_lr:0.000000008}
40 | lr_decay_style: ${policy_lr_decay_style:linear}
41 | weight_decay: 0.01
42 | lr_freeze_episodes: ${policy_lr_freeze_episodes:0}
43 |
44 |
45 | init_kl_coef: ${init_kl_coef:0.0}
46 | target: 6
47 | horizon: 10000
48 | gamma: 1
49 | lam: 0.95
50 | cliprange: 0.2
51 | diff_clip_ratio: ${diff_clip_ratio:10}
52 | pos_clip_ratio: ${pos_clip_ratio:0.2}
53 | neg_clip_ratio: ${neg_clip_ratio:0.2}
54 | final_clip_ratio: ${final_clip_ratio:3}
55 | logprob_cliprange: 10
56 | neg_cliprange2: 3
57 | cliprange_value: ${cliprange_value:0.1}
58 | scale_reward: ${scale_reward:null}
59 | clip_onlineness: ${clip_onlineness:0}
60 | cliprange_onlineness: ${cliprange_onlineness:0}
61 | train_to_compare_num_responses: ${train_to_compare_num_responses:2}
--------------------------------------------------------------------------------
/examples/mcore/configs/grpo/vllm_policy_inference.yaml:
--------------------------------------------------------------------------------
1 | includes:
2 | - base.yaml
3 |
4 | # model partition
5 | tensor_model_parallel_size: ${tensor_model_parallel_size:2}
6 | pipeline_model_parallel_size: 1
7 | num_inference_per_prompt: ${num_inference_per_prompt:8}
8 | seq_length: ${seq_length:1024}
9 | max_new_tokens: ${max_new_tokens:1023}
10 |
11 | # sampling params
12 | temperature: ${policy_temperature:1.0}
13 | top_p: ${policy_top_p:0.9}
14 | top_k: ${policy_top_k:-1}
15 | presence_penalty: ${policy_presence_penalty:0.0}
16 | frequency_penalty: ${policy_frequency_penalty:0.0}
17 | repetition_penalty: ${policy_repetition_penalty:1.0}
18 |
19 | eval_temperature: ${policy_eval_temperature:0.6}
20 | eval_top_k: ${policy_eval_top_k:-1}
21 | eval_top_p: ${policy_eval_top_p:0.9}
22 | eval_presence_penalty: ${policy_eval_presence_penalty:0.0}
23 | eval_frequency_penalty: ${policy_eval_frequency_penalty:0.0}
24 | eval_repetition_penalty: ${policy_eval_repetition_penalty:1.0}
25 |
26 | # dataset
27 | vllm_prompt_key: ${vllm_prompt_key:prompt}
28 | vllm_input_ids_key: ${vllm_input_ids_key:input_ids}
29 | enable_thinking: ${enable_thinking:False}
30 |
31 | # scheduler config
32 | max_num_batched_tokens: ${max_num_batched_tokens:32768}
33 | max_seq_len_to_capture: ${max_seq_len_to_capture:32768}
34 | enable_stage_resume: ${enable_policy_stage_resume:False}
35 |
36 | # cache config
37 | gpu_memory_utilization: ${gpu_memory_utilization:0.85}
38 | enforce_eager: False
39 |
40 | tokenizer: ${tokenizer_path}
41 |
42 |
--------------------------------------------------------------------------------
/examples/mcore/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
--------------------------------------------------------------------------------
/examples/mcore/models/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | import torch
17 |
def pad_to_max_len(all_tokens_right_padded, max_len, pad_value):
    """Right-pad dimension 1 of a tensor up to ``max_len`` with ``pad_value``.

    The input tensor is returned unchanged (same object) when its size along
    dimension 1 is already ``max_len`` or larger.
    """
    missing = max_len - all_tokens_right_padded.size(1)
    if missing > 0:
        # (0, missing): nothing added on the left, `missing` entries on the right.
        return torch.nn.functional.pad(
            all_tokens_right_padded, (0, missing), mode='constant', value=pad_value
        )
    return all_tokens_right_padded
25 |
def generate_loss_mask_position_ids(tokens: torch.Tensor, prompt_token_length: list, response_token_length:list):
    """Build the loss mask and position ids for a 2-D batch of tokens.

    Args:
        tokens: 2-D token tensor; only its shape and device are used.
        prompt_token_length: prompt length for each row.
        response_token_length: response length for each row.

    Returns:
        ``(loss_mask, position_ids)``. ``loss_mask`` is int32 with ones on the
        response span ``[prompt_len, prompt_len + response_len)`` of each row
        and zeros elsewhere. ``position_ids`` is a long tensor holding
        ``0..seq_len-1`` in every row.
    """
    loss_mask = torch.zeros_like(tokens, dtype=torch.int32, device=tokens.device)
    for row, prompt_len in enumerate(prompt_token_length):
        response_end = prompt_len + response_token_length[row]
        loss_mask[row, prompt_len:response_end] = 1.0

    _, seq_len = tokens.size()
    positions = torch.arange(seq_len, dtype=torch.long, device=tokens.device)
    position_ids = positions[None, :].expand_as(tokens)

    return loss_mask, position_ids
--------------------------------------------------------------------------------
/examples/mcore/scripts/train_grpo_qwen2_5.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Launch GRPO training of Qwen2.5-7B-Instruct through the Megatron-Core example
# entry point (entry/train_grpo.py with configs/grpo/grpo_qwen2_5.yaml).
#
# Run from the ChatLearn repository root: CHATLEARN is taken from $(pwd).
# MEGATRON_PATH is read from the caller's environment; WORLD_SIZE, RANK and
# MASTER_ADDR are optional and default to a single local node.
set -x
ray stop
rm -rf /tmp/ray/*

# environment
export CUDA_DEVICE_MAX_CONNECTIONS=1
export RAY_num_server_call_thread=1
export VLLM_USE_RAY_SPMD_WORKER=1
export VLLM_USE_RAY_COMPILED_DAG=1
export CHATLEARN="$(pwd)"
export PYTHONPATH=${MEGATRON_PATH}:${CHATLEARN}:${CHATLEARN}/examples:$PYTHONPATH
export WORLD_SIZE=${WORLD_SIZE:-1}
export RANK=${RANK:-0}
export LOCAL_MASTER_ADDR=${MASTER_ADDR:-localhost}
# CUSTOM_PORTS is a ';'-separated list of ports 30000..30050.
ports="30000"
for i in $(seq 30001 30050); do
    ports="${ports};${i}"
done
export CUSTOM_PORTS=$ports
# 8 devices per node is hard-coded here.
export num_device=$((WORLD_SIZE * 8))

# data
export train_data_path="/mnt/data/datasets/MATH-lighteval/train.json"
export eval_data_path="/mnt/data/datasets/MATH-lighteval/test.json"
export patch_tokenizer_type=Qwen2Tokenizer
export extra_vocab_size=421
export tokenizer_path="/mnt/data/qwen-ckpts/Qwen2.5-7B-Instruct"
export load="/mnt/data/qwen-ckpts/Qwen2.5-7B-Instruct-hf-to-mcore-tp4-pp2"

# model
export max_position_embedding=131072
export policy_num_layers=28
export policy_hidden_size=3584
export policy_num_attention_heads=28
export policy_num_query_groups=4
export policy_ffn_hidden_size=18944
export tensor_model_parallel_size=4
export pipeline_model_parallel_size=2

# training
export final_clip_ratio=3
export clip_grad=1.0
export seed=3407
export policy_lr=2e-6
export policy_min_lr=2e-6
export eval_episode_interval=1
export save_interval=100000
export save_episode_interval=10000
export num_episode=200
export sample_per_episode=2048
export train_micro_batch_size=8
export train_global_batch_size=2048
export vllm_generation_batch_size=128
export trainer_generation_batch_size=8
# Integer division: total samples over all episodes / global batch size.
export train_iters=$(( ${num_episode} * ${sample_per_episode} / ${train_global_batch_size} ))
export policy_lr_warmup_iters=0
export lr_decay_iters=160000
export max_num_batched_tokens=65536
export gpu_memory_utilization=0.85

# vllm
export seq_length=2048
export max_new_tokens=2048
export max_seq_len_to_capture=2348
export num_inference_per_prompt=32
export policy_temperature=1.0
export policy_top_p=1.0
export policy_top_k=-1
export policy_eval_temperature=0.6
export policy_eval_top_p=0.95
export policy_eval_top_k=20

# logging and saving
export enable_tensorboard=True
export enable_wandb=False
export WANDB_API_KEY="wandb-api-key"
export exp_name=qwen2_5_7B_lr${policy_lr}_mbs${train_micro_batch_size}_gbs${train_global_batch_size}_tp${tensor_model_parallel_size}_pp${pipeline_model_parallel_size}_${WORLD_SIZE}nodes
export output_dir=${CHATLEARN}/output/${exp_name}
mkdir -p "${output_dir}/"
export log_dir=${output_dir}/logs
mkdir -p "${log_dir}"
log_file=${log_dir}/${exp_name}_rank${RANK}.log
export tensorboard_dir=${output_dir}/tensorboard
export wandb_dir=${output_dir}
export save_dir=${output_dir}

# Abort instead of launching python from the wrong directory.
cd "${CHATLEARN}/examples/mcore" || exit 1
# PIPESTATUS[0] propagates python's exit code rather than tee's.
python entry/train_grpo.py -c configs/grpo/grpo_qwen2_5.yaml 2>&1 | tee "${log_file}" ; exit "${PIPESTATUS[0]}"
91 |
92 |
--------------------------------------------------------------------------------
/examples/mcore/scripts/train_grpo_qwen3.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Launch GRPO training of Qwen3-8B through the Megatron-Core example entry
# point (entry/train_grpo.py with configs/grpo/grpo_qwen3.yaml).
#
# Run from the ChatLearn repository root: CHATLEARN is taken from $(pwd).
# MEGATRON_PATH is read from the caller's environment; WORLD_SIZE, RANK and
# MASTER_ADDR are optional and default to a single local node.
set -x
ray stop
rm -rf /tmp/ray/*

# environment
export CUDA_DEVICE_MAX_CONNECTIONS=1
export RAY_num_server_call_thread=1
export VLLM_USE_RAY_SPMD_WORKER=1
export VLLM_USE_RAY_COMPILED_DAG=1
export CHATLEARN="$(pwd)"
export PYTHONPATH=${MEGATRON_PATH}:${CHATLEARN}:${CHATLEARN}/examples:$PYTHONPATH
export WORLD_SIZE=${WORLD_SIZE:-1}
export RANK=${RANK:-0}
export LOCAL_MASTER_ADDR=${MASTER_ADDR:-localhost}
# CUSTOM_PORTS is a ';'-separated list of ports 30000..30050.
ports="30000"
for i in $(seq 30001 30050); do
    ports="${ports};${i}"
done
export CUSTOM_PORTS=$ports
# 8 devices per node is hard-coded here.
export num_device=$((WORLD_SIZE * 8))

# data
export train_data_path="/mnt/data/datasets/MATH-lighteval/train.json"
export eval_data_path="/mnt/data/datasets/MATH-lighteval/test.json"
export patch_tokenizer_type=Qwen2Tokenizer
export extra_vocab_size=293
export tokenizer_path="/mnt/data/qwen-ckpts/Qwen3-8B"
export load="/mnt/data/qwen-ckpts/Qwen3-8B-to-mcore/"

# model
export max_position_embedding=40960
export policy_num_layers=36
export policy_hidden_size=4096
export policy_num_attention_heads=32
export policy_num_query_groups=8
export policy_ffn_hidden_size=12288
export tensor_model_parallel_size=4
export pipeline_model_parallel_size=2

# training
export final_clip_ratio=3
export clip_grad=1.0
export seed=3407
export policy_lr=2e-6
export policy_min_lr=2e-6
export eval_episode_interval=1
export save_interval=100000
export save_episode_interval=10000
export num_episode=200
export sample_per_episode=2048
export train_micro_batch_size=8
export train_global_batch_size=2048
export vllm_generation_batch_size=128
export trainer_generation_batch_size=8
# Integer division: total samples over all episodes / global batch size.
export train_iters=$(( ${num_episode} * ${sample_per_episode} / ${train_global_batch_size} ))
export policy_lr_warmup_iters=0
export lr_decay_iters=160000
export max_num_batched_tokens=65536
export gpu_memory_utilization=0.85
# Qwen3 only: toggles the enable_thinking setting.
export enable_thinking=False

# vllm
export seq_length=2048
export max_new_tokens=2048
export max_seq_len_to_capture=2348
export num_inference_per_prompt=32
export policy_temperature=1.0
export policy_top_p=1.0
export policy_top_k=-1
export policy_eval_temperature=0.6
export policy_eval_top_p=0.95
export policy_eval_top_k=20

# logging and saving
export enable_tensorboard=True
export enable_wandb=False
export WANDB_API_KEY="wandb-api-key"
export exp_name=qwen3_8B_lr${policy_lr}_mbs${train_micro_batch_size}_gbs${train_global_batch_size}_tp${tensor_model_parallel_size}_pp${pipeline_model_parallel_size}_${WORLD_SIZE}nodes
export output_dir=${CHATLEARN}/output/${exp_name}
mkdir -p "${output_dir}/"
export log_dir=${output_dir}/logs
mkdir -p "${log_dir}"
log_file=${log_dir}/${exp_name}_rank${RANK}.log
export tensorboard_dir=${output_dir}/tensorboard
export wandb_dir=${output_dir}
export save_dir=${output_dir}

# Abort instead of launching python from the wrong directory.
cd "${CHATLEARN}/examples/mcore" || exit 1
# PIPESTATUS[0] propagates python's exit code rather than tee's.
python entry/train_grpo.py -c configs/grpo/grpo_qwen3.yaml 2>&1 | tee "${log_file}" ; exit "${PIPESTATUS[0]}"
93 |
94 |
--------------------------------------------------------------------------------
/examples/tests/barrier.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | import argparse
17 | from datetime import timedelta
18 |
19 | import torch.distributed as dist
20 |
# Command-line interface: a single optional --timeout flag, given in whole
# minutes. It defaults to None when the flag is omitted.
parser = argparse.ArgumentParser(description="Barrier")
parser.add_argument(
    '--timeout', type=int, default=None,
    help="Timeout in minutes"
)
26 |
def main():
    """Create the default NCCL process group and immediately destroy it.

    Reads ``--timeout`` (minutes) from the command line and, when given,
    forwards it to ``init_process_group`` as a ``timedelta``.
    """
    args = parser.parse_args()

    # Forward ``timeout`` only when the user supplied one. Older PyTorch
    # releases require a ``timedelta`` and reject an explicit ``None``, whereas
    # omitting the argument selects the library default on every version.
    init_kwargs = {}
    if args.timeout is not None:
        init_kwargs["timeout"] = timedelta(minutes=args.timeout)

    dist.init_process_group(backend="nccl", **init_kwargs)
    dist.destroy_process_group()
37 |
# Run main() only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
40 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ray[default]==2.32.0
2 | transformers==4.42.0
3 | pynvml==11.4.1
4 | deepspeed==0.14.4
5 | vllm==0.5.1
6 | accelerate
7 | jsonlines
8 | torchtyping
9 | tensorboard
10 | cupy
11 | # math related
12 | word2number
13 | timeout-decorator
14 | latex2sympy2==1.9.0
15 |
16 | # install apex if needed
17 | # git clone https://github.com/NVIDIA/apex
18 | # cd apex
19 | # pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
20 |
21 | # install transformer engine if needed
22 | # git+https://github.com/NVIDIA/TransformerEngine.git@v1.2.1
23 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | import sys
17 | from setuptools import setup, find_packages
18 |
# Load VERSION by executing chatlearn/utils/version.py directly from its path.
# ``SourceFileLoader.load_module()`` is deprecated, so the spec-based importlib
# API is used instead. The former Python 2 ``imp`` branch is dropped because
# ``python_requires`` in the setup() call already restricts installs to Python 3.
import importlib.util

_version_spec = importlib.util.spec_from_file_location(
    "chatlearn.version", "chatlearn/utils/version.py"
)
_version_module = importlib.util.module_from_spec(_version_spec)
_version_spec.loader.exec_module(_version_module)
VERSION = _version_module.VERSION
26 |
# Package metadata for the ``pai-chatlearn`` distribution.
# ``find_packages()`` collects every package found under the current directory.
# ``install_requires`` is empty, so installing this package pulls in no
# dependencies by itself (presumably they come from requirements.txt -- confirm).
setup(
    name='pai-chatlearn',
    version=VERSION,
    python_requires='>3.6.0',
    packages=find_packages(),
    include_package_data=True,
    install_requires=[],
    long_description="PAI ChatLearn",
    author='Alibaba Group',
    license='Apache 2.0',
)
38 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/tests/__init__.py
--------------------------------------------------------------------------------
/tests/base/test_send_recv.py:
--------------------------------------------------------------------------------
1 |
2 | import ray
3 | import ray.util.collective as collective
4 | import torch
5 |
6 | import chatlearn
7 |
8 |
@ray.remote(num_gpus=1)
class Worker:
    """Ray actor holding one GPU and small bf16 buffers for send/recv tests.

    All collective calls use the group named ``"8"``.
    """

    def __init__(self):
        from chatlearn.launcher.initialize import patch_ray
        patch_ray()
        self.send1 = torch.ones((4,), dtype=torch.bfloat16, device="cuda")
        self.send1 = torch.nn.parameter.Parameter(self.send1)
        self.recv1 = torch.zeros((4,), dtype=torch.bfloat16, device="cuda")
        self.recv1 = torch.nn.parameter.Parameter(self.recv1)
        # ``compute2`` and ``recv2`` read these buffers. They were previously
        # never created, so calling either method raised AttributeError.
        self.send2 = torch.ones((4,), dtype=torch.bfloat16, device="cuda")
        self.send2 = torch.nn.parameter.Parameter(self.send2)
        self.recv2 = torch.zeros((4,), dtype=torch.bfloat16, device="cuda")
        self.recv2 = torch.nn.parameter.Parameter(self.recv2)
        # -1 until ``setup`` assigns the real rank.
        self.rank = -1

    def setup(self, world_size, rank):
        """Record ``rank`` and join the NCCL collective group ``"8"``; returns True."""
        self.rank = rank
        collective.init_collective_group(world_size, rank, "nccl", "8")
        return True

    def compute(self, src, tgt):
        """Rank 0 sends ``send1 * 2`` to ``tgt``; other ranks receive from ``src``.

        Returns:
            ``self.recv1`` (unchanged on rank 0, which only sends).
        """
        if self.rank == 0:
            collective.send(self.send1*2, tgt, "8")
        else:
            collective.recv(self.recv1, src, "8")
        return self.recv1

    def compute2(self):
        """Multi-GPU variant: rank 0 sends ``send2 * 4``, others receive into ``recv1``.

        Returns:
            ``self.recv1``.
        """
        if self.rank == 0:
            collective.send_multigpu(self.send2 * 4, 1, 0, "8")
        else:
            collective.recv_multigpu(self.recv1, 0, 1, "8")
        return self.recv1

    def recv(self, src_rank, src_gpu):
        """Receive into ``recv1`` from ``(src_rank, src_gpu)``; returns nothing."""
        collective.recv_multigpu(self.recv1, src_rank, src_gpu, "8")

    def recv2(self, src_rank, src_gpu):
        """Receive into the ``recv2`` buffer from ``(src_rank, src_gpu)`` and return it."""
        collective.recv_multigpu(self.recv2, src_rank, src_gpu, "8")
        return self.recv2

    def destroy(self):
        """Tear down the collective group ``"8"``."""
        collective.destroy_collective_group("8")
49 |
50 |
def test_send_recv():
    """Exercise point-to-point send/recv from rank 0 to rank 1, then to rank 2.

    Three ``Worker`` actors join one collective group. The received tensors are
    printed rather than asserted.
    """
    num_workers = 3
    workers = []
    init_rets = []
    # Create each actor and schedule its setup right away, in rank order.
    for rank in range(num_workers):
        actor = Worker.remote()
        workers.append(actor)
        init_rets.append(actor.setup.remote(num_workers, rank))
    w0, w1, w2 = workers

    a = ray.get(init_rets)
    print('================== init done', a, flush=True)

    first_round = [w0.compute.remote(0, 1), w1.compute.remote(0, 1)]
    print(ray.get(first_round))

    print('send from w0 to w2', flush=True)
    second_round = [w0.compute.remote(0, 2), w2.compute.remote(0, 2)]
    print(ray.get(second_round))

    ray.get([w.destroy.remote() for w in workers])
72 |
# Test functions exposed by this module (presumably collected by the test
# runner via this name -- confirm against test_main.py).
TEST_CASE = [test_send_recv]
--------------------------------------------------------------------------------
/tests/configs/base.yaml:
--------------------------------------------------------------------------------
1 | models:
2 | policy:
3 | model_config_file: model.yaml
4 | num_gpu: 1
5 | gpu_per_process: 1
6 | trainable: False
7 |
8 | reference:
9 | model_config_file: model.yaml
10 | num_gpu: 1
11 | gpu_per_process: 1
12 | trainable: False
13 |
14 |
15 | runtime:
16 | num_rollout_worker: 1
17 | num_iteration: 5000
18 | sample_per_episode: 1000
19 | num_training_epoch: ${num_training_epoch:3}
20 | unknown_args: "test_unknown"
21 |
--------------------------------------------------------------------------------
/tests/configs/eval.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 |
9 |
10 | models:
11 | policy:
12 | num_gpu: 2
13 | gpu_per_process: 1
14 | trainable: False
15 |
16 | reference:
17 | num_gpu: 1
18 | gpu_per_process: 1
19 | trainable: False
20 |
21 | reward:
22 | num_gpu: 2
23 | gpu_per_process: 1
24 | trainable: False
25 |
26 | value:
27 | num_gpu: 1
28 | gpu_per_process: 1
29 | trainable: False
30 |
31 | ppo_policy:
32 | num_gpu: 1
33 | gpu_per_process: 1
34 | trainable: True
35 | lora:
36 | enable_lora: ${enable_lora_policy:False}
37 | lora_dim: 64
38 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
39 | column_only_qkv: False
40 | lora_dropout: 0.05
41 |
42 | ppo_value:
43 | num_gpu: 1
44 | gpu_per_process: 1
45 | trainable: True
46 |
47 | runtime:
48 | debug: True
49 | generation_batch_size: ${batch_size:4}
50 | train_micro_batch_size: 5
51 | train_global_batch_size: 10
52 | num_episode: 2
53 | sample_per_episode: 16
54 | num_training_epoch: 1
55 | save_episode_interval: 200
56 | eval_data_num_limit: 4
57 | eval_episode_interval: 1
58 |
--------------------------------------------------------------------------------
/tests/configs/exp1.yaml:
--------------------------------------------------------------------------------
1 | generate_config:
2 | num_beams: 1
3 | num_return_sequences: 1
4 | temperature: 1.0
5 | do_sample: True
6 | early_stopping: True
7 | top_k: 1
8 | top_p: 0.9
9 | repetition_penalty: 1.0
10 | length_penalty: 1.0
11 | min_length: 5
12 | max_length: 4096
13 | no_repeat_ngram_size: 2
14 | eos_token_id: 102
15 |
--------------------------------------------------------------------------------
/tests/configs/exp2.yaml:
--------------------------------------------------------------------------------
1 | test: 123
2 | model_config:
3 | attention_probs_dropout_prob: 0.3
4 | generate_config:
5 | eos_token_id: 103
6 |
--------------------------------------------------------------------------------
/tests/configs/grpo.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 |
9 |
10 | models:
11 | policy:
12 | num_gpu: 1
13 | gpu_per_process: 1
14 | trainable: False
15 |
16 | reference:
17 | num_gpu: 1
18 | gpu_per_process: 1
19 | trainable: False
20 |
21 | reward:
22 | num_gpu: 1
23 | gpu_per_process: 1
24 | trainable: False
25 |
26 | ppo_policy:
27 | num_gpu: 1
28 | gpu_per_process: 1
29 | trainable: True
30 | lora:
31 | enable_lora: ${enable_lora_policy:False}
32 | lora_dim: 64
33 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
34 | column_only_qkv: False
35 | lora_dropout: 0.05
36 |
37 | runtime:
38 | debug: True
39 | generation_batch_size: ${batch_size:4}
40 | train_micro_batch_size: 5
41 | train_global_batch_size: 10
42 | num_episode: 2
43 | sample_per_episode: 16
44 | num_training_epoch: 1
45 | save_episode_interval: 200
46 |
--------------------------------------------------------------------------------
/tests/configs/model.yaml:
--------------------------------------------------------------------------------
1 | includes:
2 | - exp1.yaml
3 | - exp2.yaml
4 |
5 | model_config:
6 | attention_probs_dropout_prob: 0.1
7 | attention_type: "self"
8 | hidden_act: "gelu"
9 | hidden_dropout_prob: 0.1
10 | hidden_size: 768
11 | initializer_range: 0.02
12 | intermediate_size: 768
13 | layer_norm_eps: 1e-12
14 | layernorm_epsilon: 1e-12
15 | max_position_embeddings: 2048
16 | model_type: "gpt"
17 | num_attention_heads: 12
18 | num_hidden_layers: 12
19 | transformers_version: "4.22.0"
20 | type_vocab_size: 2
21 | vocab_size: 25600
22 |
--------------------------------------------------------------------------------
/tests/configs/o1.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 | - "mypipe"
9 | - "*.pt.trace.json"
10 | - "*.nsys-rep"
11 |
12 | models:
13 | mcts:
14 | model_config_file: model.yaml
15 | num_cpu: 4
16 | cpu_per_process: 1
17 | trainable: False
18 | policy:
19 | model_config_file: model.yaml
20 | num_gpu: 1
21 | gpu_per_process: 1
22 | trainable: False
23 |
24 | reward:
25 | model_config_file: model.yaml
26 | num_gpu: 1
27 | trainable: False
28 |
29 | reward1:
30 | model_config_file: model.yaml
31 | num_cpu: 1
32 | cpu_per_process: 1
33 | trainable: False
34 |
35 |
36 | runtime:
37 | generation_batch_size: ${generation_batch_size:1}
38 | num_episode: ${num_episode:1}
39 | sample_per_episode: ${sample_per_episode:4}
40 | num_training_epoch: 1
41 | save_episode_interval: ${save_episode_interval:1000}
42 | query_key: ${query_key:query}
43 | data_path: ${data_path:/path/to/data}
44 | training_data_num_limit: ${training_data_num_limit:-1}
45 | eval_episode_interval: ${eval_episode_interval:0}
46 | eval_data_num_limit: 20
47 | nsys: False
48 | free_sync_collective_group: ${free_sync_collective_group:False}
49 | param_sync_comm_type: ${param_sync_comm_type:broadcast}
50 | validate_param_sync: ${validate_param_sync:False}
51 |
--------------------------------------------------------------------------------
/tests/configs/parameter_sync.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 |
9 |
10 | models:
11 | policy:
12 | num_gpu: 1
13 | gpu_per_process: 1
14 | trainable: False
15 |
16 | ppo_policy:
17 | num_gpu: 1
18 | gpu_per_process: 1
19 | trainable: True
20 | lora:
21 | enable_lora: ${enable_lora_policy:False}
22 | lora_dim: 64
23 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
24 | column_only_qkv: False
25 | lora_dropout: 0.05
26 |
27 | runtime:
28 | debug: False
29 | generation_batch_size: ${batch_size:4}
30 | train_micro_batch_size: 5
31 | train_global_batch_size: 10
32 | num_episode: 2
33 | sample_per_episode: 16
34 | num_training_epoch: 1
35 | save_episode_interval: 200
36 |
--------------------------------------------------------------------------------
/tests/configs/rlhf.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 |
9 |
10 | models:
11 | policy:
12 | num_gpu: 1
13 | gpu_per_process: 1
14 | trainable: False
15 |
16 | reference:
17 | num_gpu: 1
18 | gpu_per_process: 1
19 | trainable: False
20 |
21 | reward:
22 | num_gpu: 1
23 | gpu_per_process: 1
24 | trainable: False
25 |
26 | value:
27 | num_gpu: 1
28 | gpu_per_process: 1
29 | trainable: False
30 |
31 | ppo_policy:
32 | num_gpu: 1
33 | gpu_per_process: 1
34 | trainable: True
35 | lora:
36 | enable_lora: ${enable_lora_policy:False}
37 | lora_dim: 64
38 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
39 | column_only_qkv: False
40 | lora_dropout: 0.05
41 |
42 | ppo_value:
43 | num_gpu: 1
44 | gpu_per_process: 1
45 | trainable: True
46 |
47 | runtime:
48 | debug: True
49 | generation_batch_size: ${batch_size:4}
50 | train_micro_batch_size: 5
51 | train_global_batch_size: 10
52 | num_episode: 2
53 | sample_per_episode: 16
54 | num_training_epoch: 1
55 | save_episode_interval: 200
56 | data_shuffle: False
57 | data_rerank: False
58 |
--------------------------------------------------------------------------------
/tests/configs/rlhf2.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 |
9 |
10 | models:
11 | policy:
12 | num_gpu: 1
13 | gpu_per_process: 1
14 | trainable: False
15 |
16 | reference:
17 | num_gpu: 1
18 | gpu_per_process: 1
19 | trainable: False
20 |
21 | reward:
22 | num_gpu: 1
23 | gpu_per_process: 1
24 | trainable: False
25 |
26 | value:
27 | num_gpu: 1
28 | gpu_per_process: 1
29 | trainable: True
30 |
31 | ppo_policy:
32 | num_gpu: 1
33 | gpu_per_process: 1
34 | trainable: True
35 | lora:
36 | enable_lora: ${enable_lora_policy:False}
37 | lora_dim: 64
38 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
39 | column_only_qkv: False
40 | lora_dropout: 0.05
41 |
42 | runtime:
43 | debug: True
44 | generation_batch_size: ${batch_size:4}
45 | train_micro_batch_size: 5
46 | train_global_batch_size: 10
47 | num_episode: 2
48 | sample_per_episode: 16
49 | num_training_epoch: 1
50 | save_episode_interval: 200
51 |
--------------------------------------------------------------------------------
/tests/configs/rlhf_cpu.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 |
9 |
10 | models:
11 | policy:
12 | num_gpu: 1
13 | gpu_per_process: 1
14 | trainable: False
15 |
16 | reference:
17 | num_gpu: 1
18 | gpu_per_process: 1
19 | trainable: False
20 |
21 | reward:
22 | num_cpu: 2
23 | cpu_per_process: 1
24 | trainable: False
25 |
26 | value:
27 | num_gpu: 1
28 | gpu_per_process: 1
29 | trainable: False
30 |
31 | ppo_policy:
32 | num_gpu: 1
33 | gpu_per_process: 1
34 | trainable: True
35 | lora:
36 | enable_lora: ${enable_lora_policy:False}
37 | lora_dim: 64
38 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
39 | column_only_qkv: False
40 | lora_dropout: 0.05
41 |
42 | ppo_value:
43 | num_gpu: 1
44 | gpu_per_process: 1
45 | trainable: True
46 |
47 | runtime:
48 | debug: True
49 | generation_batch_size: ${batch_size:4}
50 | train_micro_batch_size: 5
51 | train_global_batch_size: 10
52 | num_episode: 2
53 | sample_per_episode: 16
54 | num_training_epoch: 1
55 | save_episode_interval: 200
56 |
--------------------------------------------------------------------------------
/tests/configs/rlhf_eval.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 |
9 |
10 | models:
11 | policy:
12 | num_gpu: 2
13 | gpu_per_process: 1
14 | trainable: False
15 |
16 | reference:
17 | num_gpu: 1
18 | gpu_per_process: 1
19 | trainable: False
20 |
21 | reward:
22 | num_gpu: 2
23 | gpu_per_process: 1
24 | trainable: False
25 |
26 | value:
27 | num_gpu: 1
28 | gpu_per_process: 1
29 | trainable: False
30 |
31 | ppo_policy:
32 | num_gpu: 1
33 | gpu_per_process: 1
34 | trainable: True
35 | lora:
36 | enable_lora: ${enable_lora_policy:False}
37 | lora_dim: 64
38 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
39 | column_only_qkv: False
40 | lora_dropout: 0.05
41 |
42 | ppo_value:
43 | num_gpu: 1
44 | gpu_per_process: 1
45 | trainable: True
46 |
47 | runtime:
48 | debug: True
49 | generation_batch_size: ${batch_size:4}
50 | train_micro_batch_size: 5
51 | train_global_batch_size: 10
52 | num_episode: 2
53 | sample_per_episode: 16
54 | num_training_epoch: 1
55 | save_episode_interval: 200
56 | eval_data_num_limit: 4
57 | eval_episode_interval: 1
58 |
--------------------------------------------------------------------------------
/tests/configs/sprl.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 | - "mypipe"
9 | - "*.pt.trace.json"
10 | - "*.nsys-rep"
11 |
12 | models:
13 | sprl:
14 | model_config_file: model.yaml
15 | num_cpu: 1
16 | cpu_per_process: 1
17 | trainable: False
18 | actor:
19 | model_config_file: model.yaml
20 | num_gpu: 1
21 | gpu_per_process: 1
22 | trainable: False
23 |
24 | critic:
25 | model_config_file: model.yaml
26 | num_gpu: 1
27 | gpu_per_process: 1
28 | trainable: False
29 |
30 | value:
31 | model_config_file: model.yaml
32 | num_gpu: 1
33 | gpu_per_process: 1
34 | trainable: False
35 |
36 | prm:
37 | model_config_file: model.yaml
38 | num_gpu: 1
39 | gpu_per_process: 1
40 | trainable: False
41 |
42 |
43 | runtime:
44 | generation_batch_size: ${generation_batch_size:1}
45 | num_episode: ${num_episode:2}
46 | sample_per_episode: ${sample_per_episode:4}
47 | train_global_batch_size: ${train_global_batch_size:1}
48 | train_micro_batch_size: ${train_micro_batch_size:1}
49 | num_training_epoch: 1
50 | save_episode_interval: ${save_episode_interval:1000}
51 | query_key: ${query_key:query}
52 | data_path: ${data_path:/path/to/data}
53 | training_data_num_limit: ${training_data_num_limit:-1}
54 | eval_episode_interval: ${eval_episode_interval:0}
55 | eval_data_num_limit: 20
56 | nsys: False
57 | free_sync_collective_group: ${free_sync_collective_group:False}
58 | param_sync_comm_type: ${param_sync_comm_type:broadcast}
59 | validate_param_sync: ${validate_param_sync:False}
60 |
--------------------------------------------------------------------------------
/tests/configs/test_eval.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 |
9 |
10 | models:
11 | policy:
12 | num_gpu: 2
13 | trainable: False
14 |
15 | runtime:
16 | debug: True
17 | generation_batch_size: 4
18 |
--------------------------------------------------------------------------------
/tests/configs/test_eval2.yaml:
--------------------------------------------------------------------------------
1 | runtime_env:
2 | platform: DLC
3 | excludes:
4 | - "*pt"
5 | - "logs"
6 | - "tensorboards"
7 | - ".nfs*"
8 |
9 |
10 | models:
11 | policy:
12 | num_gpu: 2
13 | trainable: False
14 |
15 | reward:
16 | num_gpu: 2
17 | trainable: False
18 |
19 | reward2:
20 | num_cpu: 2
21 | trainable: False
22 |
23 | runtime:
24 | debug: True
25 | generation_batch_size: 4
26 |
--------------------------------------------------------------------------------
/tests/parameter_sync/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/tests/parameter_sync/__init__.py
--------------------------------------------------------------------------------
/tests/rlhf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/tests/rlhf/__init__.py
--------------------------------------------------------------------------------
/tests/rlhf/test_rlhf_replica.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from torch.utils.data import Dataset
6 |
7 | import chatlearn
8 | from chatlearn.utils import future
9 | from chatlearn import RLHFEngine
10 | from chatlearn import TorchModule
11 |
12 | from utils import CustomDataset, PolicyModel, ReferenceModel, RewardModel, ValueModel, PPOPolicy, PPOValue
13 |
14 |
def test_rlhf_replica():
    """Check replica count and GPU placement with two policy replicas configured."""
    # NOTE(review): the listing this was taken from had lost its indentation;
    # the extent of the `if` body below is reconstructed -- confirm upstream.
    chatlearn.get_args().models["policy"].num_replica = 2
    policy = PolicyModel("policy")
    reference = ReferenceModel("reference")
    reward = RewardModel("reward")
    value = ValueModel("value")
    ppo_policy = PPOPolicy("ppo_policy")
    ppo_value = PPOValue("ppo_value")

    engine = RLHFEngine(policy, reference, reward, value, ppo_policy, ppo_value)
    #assert policy.num_replica == 2

    # 35 identical samples of 1024 ones serve as the dataset.
    data = torch.ones([1024])
    engine.set_dataset([data] * 35)
    engine.setup()
    if policy.num_replica == 2:
        assert reference.num_replica == 1
        data = torch.ones([1024])
        assert len(engine.env._all_datasets[0]) == 35, len(engine.env._all_datasets[0])
        # The first policy replica is expected on GPU 0, the second on GPU 1.
        visible_devices = engine.models[0].replicas[0].get_visible_gpus()
        visible_devices = future.get(visible_devices)
        assert visible_devices == [[0]], visible_devices
        visible_devices = engine.models[0].replicas[1].get_visible_gpus()
        visible_devices = future.get(visible_devices)
        assert visible_devices == [[1]], visible_devices
    engine.stop()
41 |
def test_rlhf_replica_2():
    """Same checks as ``test_rlhf_replica``, with ``value`` also set to two replicas."""
    # NOTE(review): the listing this was taken from had lost its indentation;
    # the extent of the `if` body below is reconstructed -- confirm upstream.
    chatlearn.get_args().models["policy"].num_replica = 2
    chatlearn.get_args().models["value"].num_replica = 2
    policy = PolicyModel("policy")
    reference = ReferenceModel("reference")
    reward = RewardModel("reward")
    value = ValueModel("value")
    ppo_policy = PPOPolicy("ppo_policy")
    ppo_value = PPOValue("ppo_value")

    engine = RLHFEngine(policy, reference, reward, value, ppo_policy, ppo_value)
    #assert policy.num_replica == 2

    # 35 identical samples of 1024 ones serve as the dataset.
    data = torch.ones([1024])
    engine.set_dataset([data] * 35)
    engine.setup()
    if policy.num_replica == 2:
        assert reference.num_replica == 1
        data = torch.ones([1024])
        # The dataset is set a second time, after setup.
        engine.set_dataset([data] * 35)
        assert len(engine.env._all_datasets[0]) == 35, len(engine.env._all_datasets[0])
        # The first policy replica is expected on GPU 0, the second on GPU 1.
        visible_devices = engine.models[0].replicas[0].get_visible_gpus()
        visible_devices = future.get(visible_devices)
        assert visible_devices == [[0]], visible_devices
        visible_devices = engine.models[0].replicas[1].get_visible_gpus()
        visible_devices = future.get(visible_devices)
        assert visible_devices == [[1]], visible_devices
    engine.stop()
70 |
71 |
# Test functions exported by this module. tests/test_main.py reads the
# module-level TEST_CASE attribute of every test_*.py file it imports.
TEST_CASE = [test_rlhf_replica, test_rlhf_replica_2]
--------------------------------------------------------------------------------
/tests/run_tests.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Environment preparation for the test driver: resolves paths, cleans stale
# artifacts and fills in defaults for the distributed-launch variables.

# Run from the directory that holds this script so that the relative paths
# used below and by the test loop (../, checkpoint*, configs/, test_main.py)
# resolve regardless of the caller's working directory.
CDIR="$(cd "$(dirname "$0")" ; pwd -P)"
cd "${CDIR}" || exit 1

export PYTHONPATH="$(cd ../ && pwd):${PWD}:${PYTHONPATH}"
LOGFILE=/tmp/pytorch_py_test.log

# Remove core dumps and Ray session data left over from earlier runs.
rm -rf core*
rm -rf /tmp/ray/*

# Defaults for a single-node run; values already present in the environment win.
[ -z "$MASTER_ADDR" ] && export MASTER_ADDR=localhost
[ -z "$WORLD_SIZE" ] && export WORLD_SIZE=1
[ -z "$GPUS_PER_NODE" ] && export GPUS_PER_NODE=8
[ -z "$RANK" ] && export RANK=0

# Without CUSTOM_PORTS, publish ports 30000-30050 as a ';'-separated list.
if [ -z "${CUSTOM_PORTS}" ]; then
    ports="30000"
    for i in $(seq 30001 30050); do
        ports="${ports};${i}"
    done
    export CUSTOM_PORTS="${ports}"
    [ -z "$LOCAL_MASTER_ADDR" ] && export LOCAL_MASTER_ADDR="${MASTER_ADDR}"
    # Print the value actually in effect; it differs from MASTER_ADDR when
    # LOCAL_MASTER_ADDR was already set by the caller.
    echo "LOCAL_MASTER_ADDR=${LOCAL_MASTER_ADDR}"
fi

# Drop checkpoint directories produced by previous runs.
for ckpt_dir in checkpoint checkpoint2; do
    if [ -d "${ckpt_dir}" ]; then
        rm -r "${ckpt_dir}"
    fi
done


# No getopts call precedes this line, so OPTIND normally still has its
# initial value of 1 and nothing is shifted.
shift $((OPTIND - 1))
32 |
33 |
# run_test CMD [ARGS...]
# Clears core dumps, force-stops any running Ray instance, then executes the
# given command under `time`. The function's status is that of the command.
run_test() {
    rm -rf core*
    ray stop --force
    time "$@"
}
39 |
# Trace every command from here on.
set -x

# Suites passed to test_main.py; each one uses configs/<suite>.yaml.
# Commented-out entries are currently disabled.
TEST_CASES=(
    "unittest" # passed
    "base" # passed
    "rlhf" # partial passed
    #"parameter_sync" # partial passed
    #"eval" # to be fixed
    #"o1" # to be fixed
    #"sprl" # to be fixed
)

# Run every suite in TEST_CASES and abort with status 1 on the first failure.
for test_case in "${TEST_CASES[@]}"; do
    if ! run_test python test_main.py -t "${test_case}" -c "configs/${test_case}.yaml"; then
        exit 1
    fi
done

# Usage: run a single case by name as <suite>.<case_name>
#run_test python test_main.py -t "rlhf.test_rlhf_ckpt" -c "configs/rlhf.yaml"

# Final cleanup of the Ray instance.
ray stop --force
61 |
--------------------------------------------------------------------------------
/tests/test_main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 | import argparse
4 | import glob
5 | import importlib
6 | from pathlib import Path
7 |
8 | from utils import run_test
9 |
10 | import chatlearn
11 |
12 |
def _test_func(path, case_name):
    """Collect the ``TEST_CASE`` entries below *path* and run them.

    Every ``test_*.py`` file directly under *path* is imported as
    ``<path>.<module>`` and its module-level ``TEST_CASE`` list is appended to
    the set of cases. Modules that cannot be found or that lack ``TEST_CASE``
    are reported with a "加载失败" (load failed) line and skipped.

    Args:
        path: Directory to scan; also used as the package prefix for imports.
        case_name: Forwarded unchanged to ``run_test`` (``None`` when the
            caller did not name a single case).

    Returns:
        Whatever ``run_test`` returns for the collected cases.
    """
    # Stems of all matching files, in sorted order.
    module_names = []
    for filename in glob.glob(f"{path}/test_*.py"):
        stem = Path(filename).stem
        if stem.startswith("__"):
            continue
        module_names.append(stem)
    module_names.sort()

    collected = []
    for name in module_names:
        try:
            loaded = importlib.import_module(f"{path}.{name}")
            collected.extend(getattr(loaded, "TEST_CASE"))
        except (ModuleNotFoundError, AttributeError) as err:
            print(f"加载失败: {name} ({type(err).__name__})")

    # The framework is initialised once, after all modules were imported.
    chatlearn.init()

    return run_test(collected, case_name)
32 |
def _parse_args():
    """Parse the command line of the test driver.

    Only ``-t/--test_case`` (required) is declared; everything else on the
    command line is left untouched and handed back as unknown arguments.

    Returns:
        A ``(namespace, remaining_args)`` pair as produced by
        ``ArgumentParser.parse_known_args``.
    """
    arg_parser = argparse.ArgumentParser(description="")
    arg_parser.add_argument(
        "-t",
        "--test_case",
        required=True,
        type=str,
        help="test_case or test_case.case_name",
    )
    known, remaining = arg_parser.parse_known_args()
    return known, remaining
39 |
40 | def _run_unit_tests(test_dir):
41 | import unittest
42 | import os
43 | discover_config = {
44 | "start_dir": test_dir,
45 | "pattern": "test_*.py",
46 | "top_level_dir": None
47 | }
48 |
49 | test_suite = unittest.defaultTestLoader.discover(**discover_config)
50 |
51 | runner = unittest.TextTestRunner(
52 | verbosity=2,
53 | failfast=False,
54 | buffer=False,
55 | resultclass=None
56 | )
57 |
58 | test_result = runner.run(test_suite)
59 |
60 | print(f"Total unittest cases: {test_result.testsRun}")
61 | print(f"Failed unittest cases: {len(test_result.failures)}")
62 | print(f"Error unittest cases: {len(test_result.errors)}")
63 | if len(test_result.failures) == 0 and len(test_result.errors) == 0:
64 | return 0 # UT Passed
65 | return 1
66 |
if __name__ == "__main__":
    # "-t" accepts either "<suite>" or "<suite>.<case_name>"; only a value
    # with exactly one dot is split into the two parts.
    parsed_args, _unknown = _parse_args()
    test_case, case_name = parsed_args.test_case, None
    name_parts = test_case.split('.')
    if len(name_parts) == 2:
        test_case, case_name = name_parts

    # The "unittest" suite goes through unittest discovery; every other
    # suite is loaded from the directory of the same name.
    if test_case == "unittest":
        exit_code = _run_unit_tests("unittests")
    else:
        exit_code = _test_func(test_case, case_name)
    sys.exit(exit_code)
--------------------------------------------------------------------------------
/tests/unittests/test_flat_tensors.py:
--------------------------------------------------------------------------------
1 | import random
2 | import unittest
3 |
4 | import torch
5 |
6 | from chatlearn.utils.flat_tensors import FlatTensors, BucketizedFlatTensors
7 |
8 |
9 | # pylint: disable=missing-class-docstring
class TestFlatTensors(unittest.TestCase):
    """Round-trip checks for FlatTensors and BucketizedFlatTensors on CUDA.

    Both tests share one scenario: random CUDA tensors are wrapped by the
    object under test, moved to the primary store and back to the GPU buffer,
    and their contents are compared against clones taken up front.
    """

    @staticmethod
    def almost_same_memory_usage(t1, t2, eps):
        """Return True when ``t1`` and ``t2`` differ by less than ``eps`` in relative terms.

        The comparison is ``abs(t1 / t2 - 1) < eps``, so ``t2`` must be non-zero.
        """
        return abs(t1 / t2 - 1) < eps

    def run_flat_tensors_test_with_constructor(self, constructor):
        """Run the shared offload/onload scenario.

        Args:
            constructor: Callable taking the list of tensors and returning an
                object that provides ``copy_to_primary_store()`` and
                ``copy_to_gpu_buffer()``.
        """
        # Fixed seeds keep the generated shapes and values reproducible.
        seed = 0
        random.seed(seed)
        torch.manual_seed(seed)

        measure1 = torch.cuda.memory_allocated()
        # Randomly generate some tensors.
        # Four tensors with 1 to 4 dimensions each; every extent is drawn from
        # 0..8, so empty tensors are possible.
        n = 4
        n_dims = [random.randint(1, 4) for _ in range(n)]
        shapes = [
            [random.randint(0, 8) for _ in range(dim)]
            for dim in n_dims
        ]

        tensors = [
            torch.rand(size=shape, device='cuda')
            for shape in shapes
        ]
        measure2 = torch.cuda.memory_allocated()
        tensors_usage = measure2 - measure1

        # Clone tensors for comparison.
        cloned = [
            tensor.detach().clone() for tensor in tensors
        ]
        measure3 = torch.cuda.memory_allocated()
        cloned_usage = measure3 - measure2
        # NOTE(review): the result of this call is discarded, so the memory
        # comparison is computed but never asserted (same for the two calls
        # further down).
        self.almost_same_memory_usage(cloned_usage, tensors_usage, 1e-3)

        # Check after creating FlatTensors
        flat_tensor = constructor(tensors)
        for t, t_copied in zip(tensors, cloned):
            assert torch.equal(t, t_copied)

        # Check after offloaded.
        # Every original tensor is expected to have shape [0] afterwards.
        flat_tensor.copy_to_primary_store()
        for t in tensors:
            assert t.shape == torch.Size([0])

        measure4 = torch.cuda.memory_allocated()
        offloaded_memory = measure3 - measure4
        self.almost_same_memory_usage(offloaded_memory, tensors_usage, 1e-3)

        # Check after onloaded.
        flat_tensor.copy_to_gpu_buffer()
        measure5 = torch.cuda.memory_allocated()
        onloaded = measure5 - measure4
        self.almost_same_memory_usage(onloaded, tensors_usage, 1e-3)

        # The tensors must hold their original values again.
        for t, t_copied in zip(tensors, cloned):
            assert torch.equal(t, t_copied)

    def test_flat_tensors(self):
        """Run the scenario with FlatTensors using a CPU primary store."""
        self.run_flat_tensors_test_with_constructor(
            lambda tensors: FlatTensors(tensors, primary_store_device='cpu')
        )
        torch.cuda.synchronize()

    def test_bucketized_flat_tensors(self):
        """Run the scenario with BucketizedFlatTensors (CPU primary store, bucket_size_mb=16)."""
        self.run_flat_tensors_test_with_constructor(
            lambda tensors: BucketizedFlatTensors(
                tensors, primary_store_device='cpu', bucket_size_mb=16
            )
        )
        torch.cuda.synchronize()
81 |
82 |
# Run the tests of this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
85 |
--------------------------------------------------------------------------------
/tests/unittests/test_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """UnitTest for CollectiveTaskScheduler ."""
16 |
17 | from unittest import TestCase
18 | from chatlearn.synchronizer.scheduler import CollectiveTask, collective_task_scheduler, parallel_execute_collective_tasks
19 |
20 |
class TestCollectiveTaskScheduler(TestCase):
    """Checks the batching order of the scheduler and the parallel executor."""

    def test_scheduler(self):
        """Schedule four tasks and verify the actors of each yielded batch.

        Expected order: the first batch starts with the tasks on actors
        ``[1, 2]`` and ``[4, 5]``, the second batch starts with ``[1, 3]`` and
        the third with ``[2, 1]``. Afterwards all tasks are executed through
        ``parallel_execute_collective_tasks`` with a callback that checks the
        group of the task whose first actor is 2.
        """
        task_specs = (
            ([1, 2], "group1"),
            ([1, 3], "group1"),
            ([2, 1], "group2"),
            ([4, 5], "group2"),
        )
        tasks = [CollectiveTask(actors, group) for actors, group in task_specs]
        schedule = collective_task_scheduler(tasks)

        first_batch = next(schedule)
        self.assertEqual([1, 2], first_batch[0].actors)
        self.assertEqual([4, 5], first_batch[1].actors)

        second_batch = next(schedule)
        self.assertEqual([1, 3], second_batch[0].actors)

        third_batch = next(schedule)
        self.assertEqual([2, 1], third_batch[0].actors)

        def task_func(task):
            # Only the task led by actor 2 is checked; it belongs to "group2".
            if task.actors[0] != 2:
                return
            self.assertEqual("group2", task.group)

        parallel_execute_collective_tasks(tasks, task_func)
44 |
--------------------------------------------------------------------------------
/tests/unittests/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """UT for utils."""
16 |
17 | import unittest
18 | import ray
19 | from chatlearn.utils.utils import parse_function_return_num
20 | from chatlearn import get
21 | from chatlearn.utils.utils import split_index
22 |
23 |
24 | # pylint: disable=missing-class-docstring
class TestDataset(unittest.TestCase):
    """Unit tests for helpers from ``chatlearn.utils.utils`` and ``chatlearn.get``."""

    def test_function_return(self):
        """Check parse_function_return_num on four local sample functions.

        NOTE(review): the helper presumably inspects each function's return
        statements, so the sample functions are kept exactly as written --
        confirm before reformatting them.
        """

        def func():
            return 1, 2, 3

        # Three comma-separated return values.
        res = parse_function_return_num(func)
        self.assertEqual(res, 3)

        def func2(aaa):
            if aaa > 0:
                return 1, 2
            else:
                return 3, 4

        # Both branches return two values.
        res = parse_function_return_num(func2)
        self.assertEqual(res, 2)

        def func3():
            res = [1, 2, 3]
            return res

        # A single returned name counts as one value, although it holds a list.
        res = parse_function_return_num(func3)
        self.assertEqual(res, 1)

        def func4():
            res = [1, 2, 3]
            return res, 1

        # A name plus a literal counts as two values.
        res = parse_function_return_num(func4)
        self.assertEqual(res, 2)

    def _test_get(self):
        """Check that ``get`` resolves Ray object refs nested in tuples and lists.

        The name does not start with ``test``, so unittest's default loader
        does not collect this method. It calls ``ray.init()``.
        """
        ray.init()
        value = ray.put(1)
        # Object ref directly inside a tuple.
        data = (value, {1:1})
        data1 = get(data)
        self.assertEqual(data1, (1, {1:1}))
        # Object ref inside a list inside a tuple.
        data = ([value], {1:1})
        data1 = get(data)
        self.assertEqual(data1, ([1], {1:1}))
        # The referenced object is itself a dict.
        value = ray.put({"a":2})
        data = ([value], {1:1})
        data1 = get(data)
        self.assertEqual(data1, ([{"a":2}], {1:1}))

    def test_split_index(self):
        """Check that split_index(10, 3) yields the ranges (0, 4), (4, 7), (7, 10)."""
        length = 10
        num_splits = 3
        res = split_index(length,num_splits)
        self.assertEqual(res, [(0, 4), (4, 7), (7, 10)])
77 |
78 | # pylint: enable=missing-class-docstring
79 |
80 |
# Run the tests of this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
83 |
--------------------------------------------------------------------------------