├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── build.yml │ ├── pylint.yml │ ├── python-publish.yml │ └── unit_test.yml ├── .gitignore ├── .pylintrc ├── LICENSE ├── Makefile ├── README.md ├── README_CN.md ├── chatlearn ├── __init__.py ├── algorithm │ ├── base_algo.py │ └── grpo.py ├── chatlearn.py ├── checkpoint │ ├── __init__.py │ └── checkpoint_manager.py ├── configs │ └── common.py ├── data │ ├── __init__.py │ ├── data.py │ └── sampler.py ├── hooks.py ├── launcher │ ├── __init__.py │ ├── dlc_utils.py │ └── initialize.py ├── models │ ├── __init__.py │ ├── base_module.py │ ├── fsdp_module.py │ ├── megatron │ │ ├── __init__.py │ │ ├── memory_manager │ │ │ ├── __init__.py │ │ │ ├── base_trainer.py │ │ │ ├── inference.py │ │ │ ├── trainer_v1v2.py │ │ │ ├── trainer_v3.py │ │ │ ├── trainer_v4.py │ │ │ └── trainer_v5.py │ │ └── ops │ │ │ ├── __init__.py │ │ │ └── policy_gradient.py │ ├── megatron_module.py │ ├── patches │ │ ├── monkey_patch.py │ │ └── transformers │ │ │ ├── qwen2_patch.py │ │ │ └── qwen3_patch.py │ ├── torch_module.py │ ├── vllm │ │ ├── __init__.py │ │ ├── hooks │ │ │ ├── __init__.py │ │ │ └── vllm_0_8_5 │ │ │ │ ├── __init__.py │ │ │ │ ├── async_llm_engine.py │ │ │ │ ├── llm.py │ │ │ │ ├── llm_engine.py │ │ │ │ ├── logits_processor.py │ │ │ │ ├── ray_distributed_executor.py │ │ │ │ └── worker_base.py │ │ └── inference.py │ └── vllm_module.py ├── runtime │ ├── __init__.py │ ├── decorator.py │ ├── dist_actor.py │ ├── engine.py │ ├── environment.py │ ├── evaluator.py │ ├── executor.py │ ├── model_flow.py │ ├── trainer.py │ └── utils.py ├── schedule │ ├── __init__.py │ ├── metric_manager.py │ ├── model_manager.py │ ├── port_manager.py │ └── resource_manager.py ├── synchronizer │ ├── __init__.py │ ├── base.py │ ├── megatron_megatron.py │ ├── megatron_vllm.py │ ├── parameter_sync.py │ ├── parameter_sync_fsdp.py │ └── scheduler.py ├── tools │ ├── check_parameter_sync.py │ ├── megatron_checkpoint_utils.py │ └── 
megatron_to_hf.py └── utils │ ├── __init__.py │ ├── arguments.py │ ├── communication_op.py │ ├── constant.py │ ├── dist_utils.py │ ├── error_monitor.py │ ├── flat_tensors.py │ ├── future.py │ ├── global_vars.py │ ├── log_monitor.py │ ├── logger.py │ ├── megatron_import_helper.py │ ├── megatron_import_hook_helper.py │ ├── megatron_import_memory_helper.py │ ├── megatron_import_transformer_helper.py │ ├── megatron_utils.py │ ├── timer.py │ ├── utils.py │ ├── version.py │ └── vllm_utils.py ├── docker └── torch │ ├── Dockerfile.torch2.3.0 │ ├── Dockerfile.torch2.5.1.vllm066 │ └── Dockerfile.torch2.6.0.vllm085 ├── docs ├── .gitignore ├── Makefile ├── README.md ├── en │ ├── .readthedocs.yaml │ ├── Makefile │ ├── advanced.rst │ ├── api │ │ ├── config.rst │ │ ├── engine.rst │ │ ├── index.rst │ │ └── module.rst │ ├── chatlearn.md │ ├── conf.py │ ├── config_yaml.md │ ├── faq.md │ ├── index.rst │ ├── installation.md │ ├── programming.md │ └── tutorial │ │ ├── check_sync_parameter.md │ │ ├── continue_train.md │ │ ├── custom_model_flow.md │ │ ├── data.md │ │ ├── ems.md │ │ ├── evaluator.md │ │ ├── profile.md │ │ ├── run.md │ │ ├── tutorial_grpo_fsdp.md │ │ ├── tutorial_grpo_mcore.md │ │ ├── tutorial_llama2.md │ │ └── tutorial_qwen.md ├── images │ ├── arch.png │ ├── class.png │ ├── dlc_1.jpg │ ├── dlc_2.jpg │ ├── engine.jpg │ ├── engine_class.png │ ├── fault.png │ ├── logo.jpg │ ├── perf.png │ ├── rlhf.png │ └── yaml.jpg ├── requirements.txt └── zh │ ├── .readthedocs.yaml │ ├── Makefile │ ├── advanced.rst │ ├── api │ ├── config.rst │ ├── engine.rst │ ├── index.rst │ └── module.rst │ ├── chatlearn.md │ ├── conf.py │ ├── config_yaml.md │ ├── faq.md │ ├── index.rst │ ├── installation.md │ ├── programming.md │ └── tutorial │ ├── continue_train.md │ ├── custom_model_flow.md │ ├── data.md │ ├── ems.md │ ├── evaluator.md │ ├── profile.md │ ├── run.md │ ├── tutorial_grpo_fsdp.md │ ├── tutorial_grpo_mcore.md │ ├── tutorial_llama2.md │ └── tutorial_qwen.md ├── examples ├── __init__.py ├── 
fsdp │ ├── configs │ │ └── grpo │ │ │ ├── base.yaml │ │ │ ├── grpo.yaml │ │ │ ├── log.yaml │ │ │ ├── policy_trainer.yaml │ │ │ ├── reference.yaml │ │ │ └── vllm_policy_inference.yaml │ ├── data │ │ ├── data_preprocess │ │ │ ├── gsm8k.py │ │ │ └── math_lighteval.py │ │ └── prompt_dataset.py │ ├── entry │ │ └── train_grpo.py │ ├── models │ │ ├── grpo │ │ │ ├── __init__.py │ │ │ ├── loss_gallery.py │ │ │ └── policy_trainer.py │ │ ├── rule_reward.py │ │ └── vllm_policy_inference.py │ ├── scripts │ │ ├── base_env.sh │ │ ├── train_grpo_qwen2_5.sh │ │ └── train_grpo_qwen3.sh │ └── utils │ │ ├── __init__.py │ │ └── rule_reward_score │ │ ├── __init__.py │ │ └── math.py ├── mcore │ ├── configs │ │ └── grpo │ │ │ ├── base.yaml │ │ │ ├── grpo_qwen2_5.yaml │ │ │ ├── grpo_qwen3.yaml │ │ │ ├── log.yaml │ │ │ ├── model_qwen2_5.yaml │ │ │ ├── model_qwen3.yaml │ │ │ ├── policy_trainer_qwen2_5.yaml │ │ │ ├── policy_trainer_qwen3.yaml │ │ │ └── vllm_policy_inference.yaml │ ├── entry │ │ └── train_grpo.py │ ├── models │ │ ├── __init__.py │ │ ├── policy_model.py │ │ ├── policy_trainer.py │ │ ├── train_helper.py │ │ └── utils.py │ ├── scripts │ │ ├── train_grpo_qwen2_5.sh │ │ └── train_grpo_qwen3.sh │ └── tokenizer │ │ └── __init__.py └── tests │ ├── barrier.py │ ├── benchmark_vllm.py │ └── benchmark_vllm.sh ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── base ├── test_exp.py ├── test_multi_dataloader.py └── test_send_recv.py ├── configs ├── base.yaml ├── eval.yaml ├── exp1.yaml ├── exp2.yaml ├── grpo.yaml ├── model.yaml ├── o1.yaml ├── parameter_sync.yaml ├── rlhf.yaml ├── rlhf2.yaml ├── rlhf_cpu.yaml ├── rlhf_eval.yaml ├── sprl.yaml ├── test_eval.yaml └── test_eval2.yaml ├── parameter_sync ├── __init__.py ├── test_balanced_hep.py ├── test_hep_ep_vllm_tp.py ├── test_hep_eppp_vllm_tp.py ├── test_hep_eppp_vllm_tppp.py ├── test_hep_eptp_vllm_tp.py ├── test_hep_eptp_vllm_tp_2.py ├── test_hep_eptppp_vllm_tp.py ├── test_hep_eptppp_vllm_tp_2.py ├── test_hep_eptppp_vllm_tppp.py 
├── test_hep_tp_vllm_tp.py └── test_unbalanced_tp.py ├── rlhf ├── __init__.py ├── test_ckpt.py ├── test_data.py ├── test_grpo.py ├── test_indivisible_batchsz.py ├── test_model_flow.py ├── test_placement.py ├── test_placement_colocate.py ├── test_relay_buffer.py ├── test_rlhf_custom.py ├── test_rlhf_placement_colocate.py └── test_rlhf_replica.py ├── run_tests.sh ├── test_main.py ├── unittests ├── test_flat_tensors.py ├── test_sampler.py ├── test_scheduler.py └── test_utils.py └── utils.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Screenshots** 20 | If applicable, add screenshots to help explain your problem. 21 | 22 | **Additional context** 23 | Add any other context about the problem here. 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 
18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Daily Building Script Execution 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | # Runs at 00:30 every day (UTC+8) 7 | - cron: '30 16 * * *' 8 | 9 | jobs: 10 | run-shell-script: 11 | runs-on: self-hosted 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v3 16 | 17 | - name: Run shell script 18 | run: | 19 | tar_name=chatlearn-$(date +%F).tar.gz 20 | tar czvf /tmp/${tar_name} . 21 | ossutil64 -i ${{ secrets.OSS_AK_ID }} -k ${{ secrets.OSS_AK_SECRET }} -e ${{ secrets.OSS_ENDPOINT }} cp -r /tmp/${tar_name} ${{ secrets.OSS_BUCKET }}/regression/chatlearn/src/ 22 | 23 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10"] 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install pylint==2.16.1 25 | - name: Analysing the code with pylint 26 | run: | 27 | make lint 28 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more 
information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.github/workflows/unit_test.yml: -------------------------------------------------------------------------------- 1 | name: Unit Tests 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | branches: 7 | - main 8 | - dev 9 | paths-ignore: 10 | - 'docs/**' 11 | push: 12 | branches: 13 | - main 14 | paths-ignore: 15 | - 'docs/**' 16 | tags: 17 | - v[0-9]+.[0-9]+.[0-9]+ 18 | 19 | concurrency: 20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 21 | cancel-in-progress: true 22 | 23 | jobs: 24 | run-shell-script: 25 | runs-on: self-hosted 26 | 27 | steps: 28 | 29 | - name: Checkout code 30 | uses: actions/checkout@v4 31 | 32 | - name: Run unit test 33 | run: | 34 | containers=$(docker ps -aqf "name=^chatlearn_ut_") 35 | if [[ -n "$containers" ]]; then 36 | docker rm $containers -f 37 | fi 38 | docker pull $UT_IMAGE 39 | 
docker run -v $PWD:$PWD -w $PWD --name chatlearn_ut_$(date '+%d_%m_%Y_%H_%M_%S') --net host --ipc host --shm-size 80G -t --rm --gpus all $UT_IMAGE bash -c 'make test' 40 | env: 41 | UT_IMAGE: dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | examples/megatron/step3_rlhf/None 2 | .ipynb_checkpoints/ 3 | output/ 4 | build/ 5 | dist/ 6 | *.diff 7 | *.pyc 8 | *.idea 9 | *.DS_Store 10 | *.swp 11 | .nfs* 12 | *.dot 13 | *.egg-info/ 14 | docs/zh/api/doc/*.rst 15 | *._*py 16 | .vscode/ 17 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | # Disable the message, report, category or checker with the given id(s). You 3 | # can either give multiple identifiers separated by comma (,) or put this 4 | # option multiple times (only on the command line, not in the configuration 5 | # file where it should appear only once).You can also use "--disable=all" to 6 | # disable everything first and then reenable specific checks. For example, if 7 | # you want to run only the similarities checker, you can use "--disable=all 8 | # --enable=similarities". 
If you want to run only the classes checker, but have 9 | # no Warning level messages displayed, use"--disable=all --enable=classes 10 | # --disable=W" 11 | disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,superfluous-parens,consider-using-f-string,invalid-name,missing-function-docstring,inconsistent-return-statements,logging-fstring-interpolation,no-else-return,broad-exception-raised,broad-exception-caught,protected-access,too-many-lines 12 | 13 | # Maximum number of characters on a single line. 14 | max-line-length=150 15 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PYTHON ?= python3 2 | ADDITIONAL_DEPS ?= 3 | current_dir := $(shell pwd | sed 's:/*$$::') 4 | 5 | .PHONY: build 6 | build: $(LIB) 7 | $(PYTHON) setup.py bdist_wheel --universal 8 | @printf "\033[0;32mPIP package built\033[0m: " 9 | @ls dist/*.whl 10 | 11 | .PHONY: test 12 | test: $(LIB) 13 | cd tests; bash run_tests.sh 14 | 15 | .PHONY: lint 16 | lint: 17 | git config --global --add safe.directory $(current_dir) 18 | @$(PYTHON) -m pip install pylint==2.16.1 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com 19 | @$(PYTHON) -m pylint \ 20 | --rcfile=.pylintrc --output-format=parseable --jobs=8 \ 21 | $(shell git ls-tree --full-tree --name-only -r HEAD chatlearn | grep \.py$) \ 22 | $(shell git diff --cached --name-only chatlearn | grep \.py$) \ 23 | $(shell git ls-tree --full-tree --name-only -r HEAD examples/megatron/ | grep \.py$) \ 24 | $(shell git diff --cached --name-only examples/megatron/ | grep \.py$) \ 25 | 26 | .PHONY: doc 27 | doc: 28 | cd docs; make 
html 29 | 30 | 31 | .DEFAULT_GOAL := lint 32 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | 2 | [![docs](https://img.shields.io/badge/docs-latest-brightgreen.svg)](https://chatlearn.readthedocs.io/zh-cn/latest/) 3 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/alibaba/ChatLearn/blob/main/LICENSE) 4 | 5 |

6 | 7 | ChatLearn 8 | 9 |

10 | 11 |

12 | 灵活、易用、高效的大语言模型(LLMs)强化学习训练框架 13 |

14 |

15 | English  |  中文  16 |

17 | 18 | --- 19 | 20 | *最新进展* 🔥 21 | - [2025/5] 训练支持Mcore框架!基于Mcore和vLLM,我们提供了Qwen2.5模型的端到端GRPO训练[教学](docs/en/tutorial/tutorial_grpo_mcore.md)!🔥 22 | - [2025/5] 训练支持FSDP框架!基于FSDP和vLLM,我们提供了Qwen3模型的端到端GRPO训练[教学](docs/en/tutorial/tutorial_grpo_fsdp.md)!🔥 23 | - [2024/8] 正式开源 ChatLearn,更多介绍请参考我们的 [文档](docs/zh/chatlearn.md)。 24 | 25 | --- 26 | 27 | ChatLearn 是阿里云PAI团队开发的大规模LLMs强化学习训练框架。ChatLearn 通过对模型计算逻辑的抽象,解耦了模型和计算 backend、分布式策略的绑定,提供灵活的资源调度机制,可以支持灵活的资源分配和并行调度策略。 28 | 29 | ![RLHF Flow](docs/images/rlhf.png) 30 | 31 | ChatLearn的特点如下: 32 | 1. 🚀**易用的编程接口**: ChatLearn提供通用的编程抽象,用户只需要封装几个函数即可完成模型构造。用户只需要专注于单模型的编程,系统负责资源调度、数据流传输、控制流传输、分布式执行等。 33 | 2. 🔧**高可扩展的训练方式**: ChatLearn 提供 RLHF、DPO、OnlineDPO、GRPO 等 强化学习训练,同时也支持用户自定义 model 的执行 flow,使定制化训练流程变得非常便捷。 34 | 3. 🔄**多种分布式加速引擎**: 用户可以使用不同的计算 backend 进行模型建模,如 Megatron-LM、DeepSpeed、vLLM 等。用户也可以组合使用不同的 backend,如用 Megatron-LM 来进行加速训练,用 vLLM 来加速推理。 35 | 4. 🎯**灵活的并行策略和资源分配**: ChatLearn 支持不同模型配置不同的并行策略,可以结合各模型计算、显存、通信的特点来制定不同的并行策略。同时 ChatLearn 支持灵活的资源调度机制,支持各模型的资源独占或复用,通过系统调度策略支持高效的串行/并行执行和高效的显存共享。 36 | 5. ⚡**高性能**: 相较于当前的 SOTA 系统,ChatLearn 在 7B+7B (Policy+Reward) 规模性能提升52%,70B+70B 规模性能提升 137%。同时,ChatLearn 支持更大规模的 Alignment 训练,例如:300B+300B。 37 | 38 | # 快速开始 39 | 40 | 请参考 [文档](https://chatlearn.readthedocs.io/zh-cn/latest/) 快速开始. 41 | 42 | 1. [环境和代码准备](docs/zh/installation.md) 43 | 2. [基于 FSDP + vLLM的Qwen3模型端到端GRPO训练流程](docs/zh/tutorial/tutorial_grpo_fsdp.md) 44 | 3. 
[基于 LLaMA/LLaMA2 模型的端到端训练教程](docs/zh/tutorial/tutorial_llama2.md) 45 | 46 | 47 | # 性能评估 48 | 49 | 我们比较了不同参数量规模模型的 RLHF 训练吞吐量,我们采取 N+N 的模型配置,即 Policy 模型和 Reward 模型采用相同大小的参数量。我们和 DeepSpeed-Chat、OpenRLHF 对比了 7B 和 70B 的模型配置,在 8 GPUs 7B+7B 规模,有 115% 的加速,在 32 GPUs 70B+70B 规模,有 208% 的加速。规模越大,加速效果越明显。同时ChatLearn还能支持更大规模的 Alignment 训练,例如:300B+300B 规模。 50 | 51 | 52 | ![Compare Performance](docs/images/perf.png) 53 | 54 | 注:DeepSpeed-Chat和OpenRLHF性能已经优化过。 55 | 56 | # 功能列表 57 | 58 | - 支持 RLHF、DPO、OnlineDPO、GRPO 以及用户自定义的RL训练; 59 | - 支持 Megatron-LM,FSDP 作为训练的 backend,支持 vLLM 作为推理的 backend; 60 | - 支持 各模型独立配置并行策略,并支持模型间高效参数同步,自动进行并行策略转换; 61 | - 支持 EMS(Efficient Memory Sharing) 功能,支持模型间显存高效共享; 62 | - 支持模型的资源类型:GPU、CPU,例如定义纯 CPU 的 Math Reward 模型; 63 | - 支持 Megatron-Core 格式模型; 64 | 65 | # Roadmap 66 | 67 | ChatLearn 接下来会支持以下特性: 68 | - [ ] 简化参数配置 69 | - [ ] 提供MoE模型强化学习训练的教程 70 | - [ ] 支持更多的模型 71 | - [ ] 性能优化 72 | - [ ] 支持更多的强化学习算法 73 | 74 |

75 | 我们欢迎社区小伙伴参与进来合作开发,也欢迎加入钉钉群:98090003312 参与讨论。我们在持续招聘中,欢迎联系我们或者投递简历到[email](mailto:wanglin.zj@alibaba-inc.com)。 76 | -------------------------------------------------------------------------------- /chatlearn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """init""" 16 | 17 | import importlib 18 | 19 | from chatlearn import hooks 20 | from chatlearn.launcher.initialize import init 21 | from chatlearn.models.base_module import BaseModule 22 | from chatlearn.models.megatron_module import MegatronModule 23 | from chatlearn.models.torch_module import TorchModule 24 | from chatlearn.models.fsdp_module import FSDPModule 25 | from chatlearn.runtime.engine import Engine, RLHFEngine 26 | from chatlearn.runtime.engine import Environment 27 | from chatlearn.runtime.engine import Trainer 28 | from chatlearn.runtime.evaluator import Evaluator 29 | from chatlearn.runtime.model_flow import ControlDependencies 30 | from chatlearn.utils.future import get 31 | from chatlearn.utils.global_vars import get_args 32 | from chatlearn.utils.logger import logger 33 | -------------------------------------------------------------------------------- /chatlearn/algorithm/base_algo.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """base algorithm""" 16 | 17 | from abc import ABC, abstractmethod 18 | 19 | class BaseAlgorithm(ABC): 20 | """BaseAlgorithm""" 21 | 22 | @abstractmethod 23 | def run(self): 24 | """ 25 | Run the algorithm. 26 | """ 27 | -------------------------------------------------------------------------------- /chatlearn/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /chatlearn/configs/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """common configs""" 16 | 17 | from dataclasses import dataclass, field 18 | from omegaconf import MISSING 19 | 20 | 21 | @dataclass 22 | class RuntimeEnvConfig: 23 | """RuntimeEnvConfig""" 24 | 25 | platform: str = field( 26 | default="DLC", 27 | metadata={"help": "Platform to run the model. Default is DLC."} 28 | ) 29 | 30 | 31 | @dataclass 32 | class BaseModelConfig: 33 | """BaseModelConfig""" 34 | 35 | seed: int = field( 36 | default=1234, 37 | metadata={"help": "Random seed. Default is 1234."} 38 | ) 39 | 40 | 41 | @dataclass 42 | class PolicyConfig(BaseModelConfig): 43 | """PolicyConfig""" 44 | 45 | num_gpus: int = field( 46 | default=1, 47 | metadata={"help": "Number of GPUs to use. Default is 1."} 48 | ) 49 | trainable: bool = field( 50 | default=False, 51 | metadata={"help": "Whether the policy is trainable. 
Default is False."} 52 | ) 53 | 54 | 55 | @dataclass 56 | class RewardConfig(BaseModelConfig): 57 | """RewardConfig""" 58 | 59 | num_cpus: int = field( 60 | default=2, 61 | metadata={"help": "Number of CPUs to use. Default is 2."} 62 | ) 63 | 64 | 65 | @dataclass 66 | class RefPolicyConfig(BaseModelConfig): 67 | """RefPolicyConfig""" 68 | 69 | fsdp_size: int = field( 70 | default=-1, 71 | metadata={"help": "FSDP size. Default is -1."} 72 | ) 73 | 74 | 75 | @dataclass 76 | class PolicyTrainerConfig(BaseModelConfig): 77 | """PolicyTrainerConfig""" 78 | 79 | free_memory: bool = field( 80 | default=True, 81 | metadata={"help": "Whether to free memory. Default is True."} 82 | ) 83 | 84 | 85 | @dataclass 86 | class RuntimeConfig: 87 | """RuntimeConfig""" 88 | 89 | colocation: list[str] = field( 90 | default_factory=list, 91 | metadata={"help": "List of modules to colocate. Default is empty."} 92 | ) 93 | data_path: str = field( 94 | default=MISSING, 95 | metadata={"help": "Path to the data file. Required."} 96 | ) 97 | eval_data_path: str = field( 98 | default=MISSING, 99 | metadata={"help": "Path to the evaluation data file. Required."} 100 | ) 101 | -------------------------------------------------------------------------------- /chatlearn/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /chatlearn/hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """hooks""" 16 | 17 | from chatlearn.models.vllm import hooks as vllm_hooks # pylint: disable=unused-import 18 | -------------------------------------------------------------------------------- /chatlearn/launcher/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /chatlearn/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Module related.""" 16 | -------------------------------------------------------------------------------- /chatlearn/models/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """megatron""" 16 | -------------------------------------------------------------------------------- /chatlearn/models/megatron/memory_manager/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Memory manager for Megatron modules which provides utilities to free memory when unused.""" 16 | 17 | from chatlearn.models.megatron.memory_manager.base_trainer import create_trainer_memory_manager 18 | from chatlearn.models.megatron.memory_manager.inference import InferenceMemoryManager 19 | -------------------------------------------------------------------------------- /chatlearn/models/megatron/memory_manager/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Inference Memory manager for Megatron.""" 16 | from typing import Optional, List 17 | 18 | from chatlearn.utils.flat_tensors import BucketizedFlatTensors 19 | from chatlearn.utils.logger import log_rank_0 20 | from chatlearn.utils.megatron_import_helper import DistributedDataParallel 21 | 22 | 23 | class InferenceMemoryManager: 24 | """ 25 | Memory manager for Megatron inference modules which provides utilities to free memory when unused. 
26 | """ 27 | 28 | def __init__(self, model, bucket_size_mb=0): 29 | self._model = model 30 | 31 | assert not isinstance( 32 | model, (DistributedDataParallel,) 33 | ), f'Only support model type non-DistributedDataParallel, current type is {str(type(model))}.' 34 | 35 | self._weights_offloaded = False 36 | self._group_flat_weights: Optional[List[BucketizedFlatTensors]] = None 37 | self._bucket_size_mb = bucket_size_mb 38 | 39 | def offload_weights(self): 40 | """ 41 | offload weights 42 | """ 43 | if self._weights_offloaded: 44 | log_rank_0('Call offload_weights when already offloaded. Ignore it.') 45 | return 46 | 47 | if self._group_flat_weights is None: 48 | dtype_to_params = {} 49 | for p in self._model.parameters(): 50 | dtype = p.dtype 51 | if dtype not in dtype_to_params: 52 | dtype_to_params[dtype] = [] 53 | dtype_to_params[dtype].append(p) 54 | 55 | self._group_flat_weights = [] 56 | for params in dtype_to_params.values(): 57 | self._group_flat_weights.append( 58 | BucketizedFlatTensors(params, primary_store_device='cpu', bucket_size_mb=self._bucket_size_mb) 59 | ) 60 | 61 | for flat_weights in self._group_flat_weights: 62 | flat_weights.copy_to_primary_store() 63 | 64 | self._weights_offloaded = True 65 | 66 | def onload_weights(self): 67 | """ 68 | onload weights 69 | """ 70 | if not self._weights_offloaded: 71 | log_rank_0('Call onload_weights when already onloaded. Ignore it.') 72 | return 73 | 74 | for flat_weights in self._group_flat_weights: 75 | flat_weights.copy_to_gpu_buffer() 76 | 77 | self._weights_offloaded = False 78 | -------------------------------------------------------------------------------- /chatlearn/models/megatron/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """ops.""" 16 | -------------------------------------------------------------------------------- /chatlearn/models/patches/monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""Apply patches for different model architectures"""
def apply_sp_monkey_patch(model_config):
    """Install the sequence-parallel attention patch matching the model architecture.

    Raises:
        ValueError: if the first entry of ``model_config.architectures`` is not
            a supported architecture.
    """
    print(f"apply sequence parallel patches for {model_config.architectures}")
    architecture = model_config.architectures[0]
    if architecture == "Qwen2ForCausalLM":
        from chatlearn.models.patches.transformers.qwen2_patch import register_sp_attention_forward  # pylint: disable=import-outside-toplevel
    elif architecture == "Qwen3ForCausalLM":
        from chatlearn.models.patches.transformers.qwen3_patch import register_sp_attention_forward  # pylint: disable=import-outside-toplevel
    else:
        raise ValueError(f"Unsupported model architecture: {model_config.architectures}")
    register_sp_attention_forward()
"""vLLM Hooks."""

import importlib
import os
import warnings
# NOTE(review): `os` and `warnings` appear unused in this module — confirm
# before removing, other modules may rely on them via wildcard import.


# Install the hooks only when vLLM is importable in this environment.
if importlib.util.find_spec("vllm"):

    from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion

    # Each supported vLLM release has its own hook subpackage; any other
    # version is rejected with a pointer to the last commit that supported
    # older releases.
    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_8_5:
        from .vllm_0_8_5 import *
    else:
        raise RuntimeError(
            f"vLLM version expected in {list(member.value for member in VLLMVersion)}, while {CURRENT_VLLM_VERSION}. \
            if you want to use previous vllm version, please git checkout 4ad5912306df5d4a814dc2dd5567fcb26f5d473b")
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional hooks of vllm-0.8.5.""" 16 | 17 | from . import async_llm_engine 18 | from . import llm 19 | from . import llm_engine 20 | from . import ray_distributed_executor 21 | from . import worker_base 22 | from . import logits_processor 23 | -------------------------------------------------------------------------------- /chatlearn/models/vllm/hooks/vllm_0_8_5/async_llm_engine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""del init_ray_cluster in AsyncLLMEngine."""

from typing import Dict, Optional

# pylint: disable=unused-import,wildcard-import,unused-argument,not-callable
from vllm.config import VllmConfig
from vllm.engine import async_llm_engine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.usage.usage_lib import UsageContext

@classmethod
def from_engine_args(cls,
                     engine_args: AsyncEngineArgs,
                     engine_config: Optional[VllmConfig] = None,
                     start_engine_loop: bool = True,
                     usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
                     stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
                     ) -> "AsyncLLMEngine":
    """Creates an async LLM engine from the engine arguments.

    Replacement for the upstream classmethod; per the module docstring it
    omits the Ray cluster initialization step — presumably because ChatLearn
    manages the Ray cluster itself (confirm against upstream vLLM 0.8.5).
    """
    # Create the engine configs.
    if engine_config is None:
        engine_config = engine_args.create_engine_config(usage_context)

    executor_class = cls._get_executor_cls(engine_config)

    # Create the async LLM engine.
    engine = cls(
        vllm_config=engine_config,
        executor_class=executor_class,
        log_requests=not engine_args.disable_log_requests,
        log_stats=not engine_args.disable_log_stats,
        start_engine_loop=start_engine_loop,
        usage_context=usage_context,
        stat_loggers=stat_loggers,
    )
    return engine

# Monkey-patch the upstream classmethod with the variant defined above.
async_llm_engine.AsyncLLMEngine.from_engine_args = from_engine_args
"""Hooks of vllm-0.8.5 llm_engine remove __reduce__ function."""

import inspect
from typing import Dict, Optional

# pylint: disable=unused-import,wildcard-import,unused-argument,wrong-import-order
from chatlearn.utils.vllm_utils import vllm_use_v1
from vllm.engine.metrics_types import StatLoggerBase
from vllm.usage.usage_lib import UsageContext
# Pick the v1 or legacy engine module depending on the runtime vLLM config.
if vllm_use_v1():
    from vllm.v1.engine import llm_engine
else:
    from vllm.engine import llm_engine
# If upstream's __reduce__ mentions RuntimeError (presumably raising to forbid
# pickling), delete it so the engine can appear in Ray actor-init closures.
source = inspect.getsource(llm_engine.LLMEngine.__reduce__)
if 'RuntimeError' in source:
    # NOTE(review): this local __reduce__ is defined but never assigned to the
    # class — only the `del` below takes effect.
    def __reduce__(self):
        # This is to ensure that the LLMEngine can be referenced in
        # the closure used to initialize Ray worker actors
        pass

    del llm_engine.LLMEngine.__reduce__


@classmethod
def from_engine_args(
    cls,
    engine_args,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
    """Creates an LLM engine from the engine arguments.

    Replacement for the upstream classmethod: always selects the Ray
    distributed executor (v1 or legacy variant per vllm_use_v1()).
    """
    # Create the engine configs.
    engine_config = engine_args.create_engine_config(usage_context)
    if vllm_use_v1():
        from vllm.v1.executor.ray_distributed_executor import RayDistributedExecutor  # pylint: disable=import-outside-toplevel
        executor_class = RayDistributedExecutor
    else:
        from vllm.executor.ray_distributed_executor import RayDistributedExecutor  # pylint: disable=import-outside-toplevel
        executor_class = RayDistributedExecutor
    # Create the LLM engine.
    engine = cls(  # pylint: disable=not-callable
        vllm_config=engine_config,
        executor_class=executor_class,
        log_stats=not engine_args.disable_log_stats,
        usage_context=usage_context,
        stat_loggers=stat_loggers,
    )

    return engine
# Monkey-patch whichever LLMEngine module was selected above.
llm_engine.LLMEngine.from_engine_args = from_engine_args
"""Hooks of vllm-0.8.5 logits_processor to allgather logits of all ranks."""

import inspect

# pylint: disable=wildcard-import,ungrouped-imports
from vllm.model_executor.layers import logits_processor


# Patch only builds whose _gather_logits still uses tensor_model_parallel_gather;
# per the module docstring the goal is an all-gather so every TP rank holds the
# full logits tensor.
source = inspect.getsource(logits_processor.LogitsProcessor._gather_logits)
if 'tensor_model_parallel_gather' in source:
    import torch
    def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:  # pylint: disable=unused-argument
        """All-gather logits across the tensor-parallel group."""
        from vllm.distributed import tensor_model_parallel_all_gather  # pylint: disable=import-outside-toplevel
        return tensor_model_parallel_all_gather(logits)

    logits_processor.LogitsProcessor._gather_logits = _gather_logits
"""Hooks of vllm-0.8.5 worker_base to update execute_method."""

# pylint: disable=unused-import,wildcard-import
from typing import Union
from vllm.worker import worker_base
from vllm.worker.worker_base import logger
from vllm.utils import run_method


# Drop the attribute forwarding to self.worker; execute_method below resolves
# the call target explicitly instead.
del worker_base.WorkerWrapperBase.__getattr__

def execute_method(self, method: Union[str, bytes], *args, **kwargs):
    """Run `method` on the wrapper when it defines it, otherwise on the wrapped worker."""
    try:
        # method resolution order:
        # if a method is defined in this class, it will be called directly;
        # otherwise it is dispatched to the wrapped worker.
        use_wrapper = self.worker is None or hasattr(self, method)
        target = self if use_wrapper else self.worker
        return run_method(target, method, args, kwargs)
    except Exception as e:
        # if the driver worker also execute methods,
        # exceptions in the rest worker may cause deadlock in rpc like ray
        # see https://github.com/vllm-project/vllm/issues/3455
        # print the error and inform the user to solve the error
        msg = (f"Error executing method {method!r}. "
               "This might cause deadlock in distributed execution.")
        logger.exception(msg)
        raise e

worker_base.WorkerWrapperBase.execute_method = execute_method
"""Inference Memory manager for Megatron.

NOTE(review): this module lives under chatlearn/models/vllm/ but is a
near-duplicate of chatlearn/models/megatron/memory_manager/inference.py
(minus the DistributedDataParallel check) — the "for Megatron" wording looks
copied; consider sharing one implementation.
"""
from typing import Optional, List

from chatlearn.utils.flat_tensors import BucketizedFlatTensors
from chatlearn.utils.logger import log_rank_0


class InferenceMemoryManager:
    """
    Memory manager for Megatron inference modules which provides utilities to free memory when unused.
    """

    def __init__(self, model, bucket_size_mb=0):
        # Module whose parameters() are flattened and moved between CPU/GPU.
        self._model = model
        # Guards against redundant offload/onload calls.
        self._weights_offloaded = False
        # Flat buffers grouped by parameter dtype; built lazily on first offload.
        self._group_flat_weights: Optional[List[BucketizedFlatTensors]] = None
        self._bucket_size_mb = bucket_size_mb

    def offload_weights(self):
        """
        offload weights

        Copies every weight into its primary (CPU) store; no-op when already offloaded.
        """
        if self._weights_offloaded:
            log_rank_0('Call offload_weights when already offloaded. Ignore it.')
            return

        if self._group_flat_weights is None:
            # Bucket parameters by dtype: each flat buffer holds one dtype.
            dtype_to_params = {}
            for p in self._model.parameters():
                dtype = p.dtype
                if dtype not in dtype_to_params:
                    dtype_to_params[dtype] = []
                dtype_to_params[dtype].append(p)

            self._group_flat_weights = []
            for params in dtype_to_params.values():
                self._group_flat_weights.append(
                    BucketizedFlatTensors(params, primary_store_device='cpu', bucket_size_mb=self._bucket_size_mb)
                )

        for flat_weights in self._group_flat_weights:
            flat_weights.copy_to_primary_store()

        self._weights_offloaded = True

    def onload_weights(self):
        """
        onload weights

        Copies weights back into GPU buffers; no-op when not currently offloaded.
        """
        if not self._weights_offloaded:
            log_rank_0('Call onload_weights when already onloaded. Ignore it.')
            return

        for flat_weights in self._group_flat_weights:
            flat_weights.copy_to_gpu_buffer()

        self._weights_offloaded = False
"""runtime utils"""

import ast
import textwrap
import inspect
from collections import defaultdict

def encode_data(mb, data):
    """Wrap a micro-batch index and its payload into a routing dict."""
    return {"iter": mb, "data": data}


def decode_data(data):
    """Inverse of encode_data: return (micro-batch index, payload)."""
    return data["iter"], data["data"]


def parse_assign_target(line):
    """Return the target names of an ast.Assign node (e.g. ['x'] for `x = ...`)."""
    return [target.id for target in line.targets]


def parse_expr(line):
    """Decompose a `model.func(args...)` call node.

    Args:
        line: an ast.Assign or ast.Expr whose value is a Call on an attribute.

    Returns:
        (func_name, model_name, func_args): the called attribute name, the name
        of the object it is called on (the innermost attribute for chained
        access like `self.policy.forward`), and the positional argument names.
    """
    func = line.value.func
    func_name = func.attr
    func_args = [arg.id for arg in line.value.args]
    if isinstance(func.value, ast.Name):
        model_name = func.value.id
    else:
        model_name = func.value.attr
    return func_name, model_name, func_args


class FlowParser:
    """Flow Parser: maps each model object to the functions invoked on it."""

    def __init__(self):
        self.model_to_call_funcs = defaultdict(list)
        # name -> model object, populated by parse() from closure/global vars.
        self.global_models = {}

    def parse_assign(self, line):
        """Record one `model.func(...)` statement into model_to_call_funcs."""
        func_name, model_name, _ = parse_expr(line)
        model = self.global_models[model_name]
        self.model_to_call_funcs[model].append(func_name)

    def visit_func(self, node):
        """Visit a FunctionDef body, including statements one `with` level deep."""
        for line in node.body:
            if isinstance(line, (ast.Assign, ast.Expr)):
                self.parse_assign(line)
            elif isinstance(line, ast.With):
                for line0 in line.body:
                    if isinstance(line0, (ast.Assign, ast.Expr)):
                        self.parse_assign(line0)

    def parse(self, func):
        """Parse a flow function (callable or source string).

        Returns a dict-like mapping of model object -> list of called function names.
        """
        self.global_models = {}
        if isinstance(func, str):
            # Bugfix: source strings have no closure variables; the old code
            # called inspect.getclosurevars(func) unconditionally and raised
            # TypeError before ever reaching the str branch below.
            code = textwrap.dedent(func)
        else:
            closure_vars = inspect.getclosurevars(func)
            if closure_vars.globals:
                self.global_models.update(closure_vars.globals)
            if closure_vars.nonlocals:
                self.global_models.update(closure_vars.nonlocals)
            code = textwrap.dedent(inspect.getsource(func))
        node_iter = ast.NodeVisitor()
        node_iter.visit_FunctionDef = self.visit_func
        node_iter.visit(ast.parse(code))
        return self.model_to_call_funcs
"""port manager"""

from multiprocessing import Lock
import ray

@ray.remote
class PortManager:
    """Ray actor that hands out ports from a fixed pool, per node address."""

    def __init__(self, port_list):
        # Pool of candidate ports, consumed in order for each address.
        self._port_list = port_list
        # address -> index of the next unconsumed port for that address.
        self._address_to_port_index = {}
        self._lock = Lock()

    def get_free_port(self, address):
        """Return the next unused port for `address`; serialized by a lock."""
        with self._lock:
            next_index = self._address_to_port_index.get(address, 0)
            # The pool must not be exhausted for this address.
            assert next_index < len(self._port_list)
            self._address_to_port_index[address] = next_index + 1
            return self._port_list[next_index]
"""synchronizer"""

from transformers import AutoConfig
from chatlearn.models.megatron_module import MegatronModule
from chatlearn.models.vllm_module import VLLMModule
from chatlearn.runtime.dist_actor import DistModel
from .base import BaseSync
from .megatron_megatron import MegatronMegatronSync
from .megatron_vllm import (
    MegatronVllmQWenSync,
    MegatronVllmQWen2Sync,
    MegatronVllmLlamaSync,
    MegatronVllmMoonlightSync,
    MegatronVllmQWen2MCoreSync
)

def get_synchronizer(src_model, dst_model):
    """Pick the parameter synchronizer matching the (src, dst) module types.

    Args:
        src_model: DistModel wrapping the training-side module.
        dst_model: DistModel wrapping the inference-side module.

    Returns:
        A BaseSync (subclass) instance for syncing parameters src -> dst.

    Raises:
        RuntimeError: for a Megatron->vLLM pair whose HuggingFace architecture
            is not one of the supported classes.
    """
    assert isinstance(src_model, DistModel)
    assert isinstance(dst_model, DistModel)
    src_model = src_model.replicas[0].model
    dst_model = dst_model.replicas[0].model
    if isinstance(src_model, MegatronModule) and isinstance(dst_model, MegatronModule):
        return MegatronMegatronSync(src_model, dst_model)
    elif isinstance(src_model, MegatronModule) and isinstance(dst_model, VLLMModule):
        # The destination architecture is read from the HF config stored
        # alongside the tokenizer.
        config_dir = dst_model.module_args.args_dict["tokenizer"]
        config = AutoConfig.from_pretrained(config_dir, trust_remote_code=True)
        model_class_name = config.architectures[0]
        if model_class_name == "QWenLMHeadModel":
            return MegatronVllmQWenSync(src_model, dst_model)
        elif model_class_name in ["Qwen2ForCausalLM", "Qwen2MoeForCausalLM"]:
            # NOTE: check if the model is mcore or not
            if src_model.module_args.args_dict.get("use_legacy_models", True):
                return MegatronVllmQWen2Sync(src_model, dst_model)
            return MegatronVllmQWen2MCoreSync(src_model, dst_model)
        elif model_class_name == "LlamaForCausalLM":
            return MegatronVllmLlamaSync(src_model, dst_model)
        elif model_class_name in ["DeepseekV3ForCausalLM", "Qwen3ForCausalLM"]:
            return MegatronVllmMoonlightSync(src_model, dst_model)
        else:
            # Bugfix: the previous message omitted DeepseekV3ForCausalLM and
            # Qwen3ForCausalLM even though both are handled above.
            raise RuntimeError(
                f"Unsupported model {model_class_name}, Expect QWenLMHeadModel, Qwen2ForCausalLM, "
                "Qwen2MoeForCausalLM, LlamaForCausalLM, DeepseekV3ForCausalLM or Qwen3ForCausalLM.")
    else:
        return BaseSync(src_model, dst_model)
"""megatron to megatron synchronizer"""

from chatlearn.utils import future
from .base import BaseSync

class MegatronMegatronSync(BaseSync):
    """megatron to megatron synchronizer"""

    def _get_dst_name(self, src_name, src_prefix, dst_prefix):
        # Exactly one of the two prefixes is expected to be set (see
        # set_model_prefix): strip it from the src name, or prepend it.
        # NOTE(review): an empty src_prefix ('') is falsy and would fall into
        # the dst_prefix branch with dst_prefix=None (TypeError) — confirm
        # names never match with a zero-length prefix.
        if src_prefix:
            dst_name = src_name[len(src_prefix):]
        else:
            dst_name = dst_prefix + src_name
        return dst_name

    def set_model_prefix(self, src_names, dst_names):
        """Infer which side carries an extra naming prefix.

        Scans src/dst name pairs for the first pair where one name is a
        substring of the other and returns (src_prefix, dst_prefix); the side
        with the longer names gets its prefix set, the other stays None.
        Raises RuntimeError when no pair matches.
        """
        dst_prefix = None
        src_prefix = None
        for sname in src_names:
            for dname in dst_names:
                if sname in dname:
                    # dst names look like "<prefix><src name>"
                    prefix = dname[:dname.index(sname)]
                    dst_prefix = prefix
                    return src_prefix, dst_prefix
                elif dname in sname:
                    # src names look like "<prefix><dst name>"
                    prefix = sname[:sname.index(dname)]
                    src_prefix = prefix
                    return src_prefix, dst_prefix
        if dst_prefix is None and src_prefix is None:
            raise RuntimeError("Cannot find prefix")
        return src_prefix, dst_prefix

    def map_name_from_src_to_dst(self, send_actor, recv_actor, src_names, dst_names):
        """Translate src parameter names into the recv actor's naming scheme."""
        # Prefixes are inferred against the recv actor's actual parameter names.
        dst_names_ref = future.get(recv_actor.get_parameter_names.remote(requires_grad=False))
        src_prefix, dst_prefix = self.set_model_prefix(src_names, dst_names_ref)
        # NOTE(review): the prefix is applied to the incoming `dst_names`
        # argument rather than dst_names_ref — verify callers pass src-style
        # names in that parameter.
        dst_names = [self._get_dst_name(name, src_prefix, dst_prefix) for name in dst_names]
        return src_names, dst_names
class FSDP2VllmParameterSyncGroup:
    """fsdp to vllm parameter sync group
    """
    def __init__(
        self,
        src_model: DistModel,
        dst_model: DistModel,
        group_name: str,
        frequency: int,
        error_signal: ErrorSignalActor,
    ):
        # src_model: the FSDP training model (parameter source).
        # dst_model: the vLLM inference model (parameter destination).
        # group_name / frequency / error_signal are stored for the manager;
        # only the models are used by sync() itself.
        self.src_model = src_model
        self.dst_model = dst_model
        self.group_name = group_name
        self.error_signal = error_signal
        self.frequency = frequency

        self.setup_collective_group()

    def setup_collective_group(self):
        """Assign consecutive global ranks across both models' replicas."""
        # we put src_model first, so we don't need to change the rank of training model
        models = [self.src_model, self.dst_model]

        rank_offset = 0
        for model in models:
            for replica in model.replicas:
                replica._setup_ranks(rank_offset)
                rank_offset += replica.actor_num

    def sync(self, *args, **kwargs):  # pylint: disable=unused-argument
        """
        sync function for fsdp to vllm
        """
        # for fsdp to vllm, we only need to find the src and dst actors that are on the same GPU.
        src_model_ranks = flatten(self.src_model.all_ranks)
        # adapt for model manager: models_to_revert
        # NOTE(review): dst ranks are reversed within each replica, presumably
        # to line up colocated src/dst actors after model-manager reordering --
        # confirm against model_manager.models_to_revert.
        dst_model_ranks = flatten(self.dst_model.all_ranks, reverse=True)

        # Query the parameter-name list from rank 0 of the source model.
        param_name_list = ray.get(self.src_model.get_actor(0).get_fsdp_param_name.remote())

        for param_name in param_name_list:

            refs = []
            for src_rank, dst_rank in zip(src_model_ranks, dst_model_ranks):
                src_actor = self.src_model.get_actor(src_rank)
                dst_actor = self.dst_model.get_actor(dst_rank)
                # Source actor exports IPC handles for the named weight; the
                # destination actor consumes them to update its own copy.
                reduce_data_ref = src_actor.get_weight_ipc_handles_by_name.remote(param_name)
                ref = dst_actor.update_weights_from_ipc_handles.remote(reduce_data_ref)
                refs.append(ref)
            # Wait for this parameter to finish on all pairs before moving on.
            future.wait(refs, return_output=True)
class CollectiveTask:
    """CollectiveTask represents a group of actors to execute a collective task.

    Attributes:
        actors: actors[0] is the sender, actors[1:] are the receivers.
        group: name of the collective communication group.
    """
    def __init__(self, actors, group):
        self.actors = actors
        self.group = group

def collective_task_scheduler(tasks):
    """Yield batches of tasks that can safely run in parallel.

    Two tasks conflict when they share any actor (as sender or receiver);
    running conflicting collective ops concurrently can hang. Each yielded
    batch contains pairwise non-conflicting tasks; conflicting tasks are
    deferred to a later batch.
    """
    todo_queue = Queue()
    pending_queue = []
    # Idiom fix: plain loop instead of a side-effect list comprehension.
    for task in tasks:
        todo_queue.put(task)

    while not todo_queue.empty():
        send_actors_set = set()
        recv_actors_set = set()
        list_count = todo_queue.qsize()
        # re-put it if conflict, otherwise put it to PendingQueue
        for _ in range(list_count):
            task = todo_queue.get()
            send = task.actors[0]
            recvs = task.actors[1:]
            if send not in send_actors_set and send not in recv_actors_set and \
                    all(recv not in send_actors_set for recv in recvs) and all(recv not in recv_actors_set for recv in recvs):
                pending_queue.append(task)
                send_actors_set.add(send)
                recv_actors_set.update(recvs)
            else:
                todo_queue.put(task)
        if pending_queue:
            yield pending_queue
            # The actor sets are re-initialized at the top of the while loop;
            # only the batch list needs resetting here.
            pending_queue = []

def parallel_execute_collective_tasks(tasks, submit_func):
    """Execute tasks batch by batch; tasks within a batch run in threads.

    Args:
        tasks: iterable of CollectiveTask.
        submit_func: callable invoked with one task; executed in a thread.

    Raises:
        RuntimeError: if any submitted task raises.
    """
    scheduler = collective_task_scheduler(tasks)
    for parallel_tasks in scheduler:
        logger.info(f"DEBUG parallel_execute_tasks: {[task.group for task in parallel_tasks]}")
        with ThreadPoolExecutor(max_workers=len(parallel_tasks)) as executor:
            futures = [executor.submit(submit_func, task) for task in parallel_tasks]
            for _future in concurrent.futures.as_completed(futures):
                try:
                    _future.result()
                except Exception as e:
                    traceback.print_exc()
                    raise RuntimeError(f"ParameterSync warmup failed: {e}") # pylint: disable=raise-missing-from
"""Check ParameterSync"""

import argparse
import os
import torch

def chatlearn_compare(expected_dir, actural_dir):
    """Compare dumped parameters under two directories.

    Directory layout: ``<dir>/<tp_rank>/<param_file>``. Every file found
    under ``actural_dir`` is compared against the file at the same relative
    path under ``expected_dir``. One line per parameter is printed
    (PASS / DIFF / NOT_EXISTS) followed by a summary line.

    Returns:
        (total, diff, not_exists) counters (backward-compatible addition;
        the function previously always returned None).
    """
    total = 0
    diff = 0
    not_exists = 0
    for tp_rank in os.listdir(actural_dir):
        for param in os.listdir(os.path.join(actural_dir, tp_rank)):
            actual_fname = os.path.join(actural_dir, tp_rank, param)
            expected_fname = os.path.join(expected_dir, tp_rank, param)
            message = f"{tp_rank}|{param}"
            total += 1
            if not os.path.exists(expected_fname):
                print(f"NOT_EXISTS|{message}|NOT_EXISTS", flush=True)
                not_exists += 1
                continue
            ta = torch.load(actual_fname, map_location="cpu")
            tb = torch.load(expected_fname, map_location="cpu")
            if not torch.allclose(ta, tb):
                # Bug fix: the DIFF counter was previously never incremented.
                diff += 1
                print(f"DIFF|{message}|{ta.shape}|{ta.mean()}|{tb.shape}|{tb.mean()}", flush=True)
            else:
                print(f"PASS|{message}")
    # Bug fix: the summary previously interpolated the *builtin* `all`
    # instead of the `total` counter.
    print(f"ALL: {total}, DIFF: {diff}, NOT_EXISTS: {not_exists}")
    return total, diff, not_exists


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root_dir",
        type=str,
        required=True,
        help="Root dir to check the dumped parameters")
    args = parser.parse_args()
    # NOTE(review): "paramter" looks like a typo, but these strings must match
    # the directory names produced by the dump step -- do not fix here alone.
    dir1 = os.path.join(args.root_dir, "before_sync_paramter")
    dir2 = os.path.join(args.root_dir, "after_sync_paramter")
    chatlearn_compare(dir1, dir2)
"""constants."""

# Bug fix: a bare `import importlib` does not guarantee the `importlib.util`
# submodule is loaded; import it explicitly before using find_spec.
import importlib.util
from enum import Enum

# Regroup: tags attached to batches when regrouping data across replicas.

CHATLEARN_REGROUP_TAG = "chatlearn_regroup_tag"
INDEX_TAG = "data_index"

# Marker prepended to ChatLearn log lines.
LOG_START = "chatlearn_log"

# Layer-name fragments treated as fused QKV projections.
QKV_LAYER_NAME = ["query_key_value"]


# vLLM version string detected at import time; None when vLLM is absent.
CURRENT_VLLM_VERSION = None
if importlib.util.find_spec("vllm"):
    import vllm
    if hasattr(vllm, "__version_tuple__"):
        version_tuple = vllm.__version_tuple__
        CURRENT_VLLM_VERSION = ".".join([str(ele) for ele in version_tuple[:3]])
    else:
        CURRENT_VLLM_VERSION = vllm.__version__


class VLLMVersion(str, Enum):
    """support versions of vLLM."""
    v_0_8_5 = "0.8.5"


class QwenVersion(float, Enum):
    """qwen version"""
    v_1 = 1.0
    v_2 = 2.0


class RAY_PG_STRATEGY(Enum):
    """ray placement group strategy"""
    PACK = "PACK"
    SPREAD = "SPREAD"


class PARAM_SYNC_COMM_TYPE(str, Enum):
    """parameter sync communication type"""
    BROADCAST = "broadcast"
    P2P = "p2p"


class ROUTED_EXPERT_REGROUPING_COMM_TYPE(str, Enum):
    """communication type of routed expert regrouping."""
    ALLTOALL = "alltoall"
    ALLGATHER = "allgather"
@ray.remote
class ErrorMonitor:
    """Error Monitor

    Long-running actor that polls an ErrorSignalActor; once an error is
    signalled, it tears down the collective groups, terminates the model
    actors and raises with the collected error details.
    """

    def __init__(self, error_signal, remote_models, group_names):
        # error_signal: ErrorSignalActor handle polled for failures.
        # remote_models: models to terminate once an error is observed.
        # group_names: collective groups to destroy on error.
        self.error_signal = error_signal
        self.remote_models = remote_models
        self.collective_groups = group_names

    def monitor(self):
        """Block until an error is signalled, then tear down and raise."""
        while True:
            try:
                catch_err = future.get(self.error_signal.is_set.remote())
            except Exception:
                # The signal actor may be temporarily unreachable; keep polling.
                catch_err = False
            if catch_err:
                break
            time.sleep(2)
        # Error observed: destroy collective groups, then terminate models.
        for group_name in self.collective_groups:
            col.destroy_collective_group(group_name)
        for model in self.remote_models:
            model.terminate()
        error_msg = future.get(self.error_signal.error_msg.remote())
        error_address = future.get(self.error_signal.error_address.remote())
        raise Exception(f"Catch an exception in {error_address}, error msg: {error_msg}")
def check_nested_2_level_list(refs):
    """
    Checks if a list is a nested list with a nested level of 2.
    e.g.
        [[ref0, ref1], [ref2, ref3]] returns True, [2, 2]
        [ref0, ref1] returns False, None
        [[ref0], [ref1, ref2]] returns True, [1, 2]

    Returns a tuple containing two elements:
    - A boolean value indicating if the list is a nested 2-level list
    - A list of integers containing the length of each sublist
    """
    sublist_lens = []
    for sublist in refs:
        if isinstance(sublist, list):
            if len(sublist) == 0:
                sublist_lens.append(0)
            else:
                # Only the first element is inspected; sublists are assumed
                # homogeneous (all ObjectRefs or none) -- TODO confirm.
                if isinstance(sublist[0], ray.ObjectRef):
                    sublist_lens.append(len(sublist))
                else:
                    return False, None
        else:
            return False, None
    return True, sublist_lens


def wait(refs, desc=None, return_output=False):
    """
    wait until all computation finish

    Args:
        refs: a single ObjectRef, a flat list of ObjectRefs, or a 2-level
            nested list of ObjectRefs.
        desc: optional progress-bar description; when set, progress ticks
            once per ref (flat input) or once per sublist (nested input).
        return_output: when True, return ray.get results for all refs.
    """
    if isinstance(refs, ray.ObjectRef):
        ray.get(refs)
        return
    if len(refs) == 0:
        return
    nested2, sublist_lens = check_nested_2_level_list(refs)
    refs = flatten(refs)
    if desc is not None:
        # Nested input: one progress tick per sublist, not per ref.
        total = len(refs) if not nested2 else len(sublist_lens)
        pbar = logging_tqdm(total=total, desc=desc)
    i = 0
    wait_refs = refs.copy()
    while wait_refs:
        # Nested input: wait for as many refs as sublist i contains
        # (ray.wait may return any completed refs, not specifically
        # those of sublist i -- the counts, not identities, drive progress).
        num_returns = 1 if not nested2 else sublist_lens[i]
        done, wait_refs = ray.wait(wait_refs, num_returns=num_returns)
        i += 1
        if desc is not None:
            done_size = len(done) if not nested2 else 1
            pbar.update(done_size)
    if return_output:
        outputs = ray.get(refs)
    if desc is not None:
        pbar.close()
    if return_output:
        return outputs


def get(data):
    """get remote data

    Recursively resolves ObjectRefs inside lists, tuples and dicts; chained
    refs (a ref whose value is itself a ref) are followed until a concrete
    value is reached.
    """
    if isinstance(data, (list, tuple)):
        dtype = type(data)
        ret = dtype(get(item) for item in data)
        return ret
    if isinstance(data, dict):
        return {key: get(value) for key, value in data.items()}
    while isinstance(data, ray.ObjectRef):
        data = ray.get(data)
    return data
"""global vars."""

_GLOBAL_ARGS = None
_EXIT_ACTOR = None
_EXIT_ACTOR_NAME = "ChatLearnExitActor"
_DECORATED_MODELS = None
# Maps a decorated (outer) function to the function it wraps (inner).
_DECORATED_OUTER_TO_INNER = {}
_DEPENDENCIES = None
_VLLM_ACTORS = None


def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    # Idiom fix: f-string instead of str.format (message text unchanged).
    assert var is not None, f'{name} is not initialized.'

def get_args():
    """Return arguments."""
    _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
    return _GLOBAL_ARGS

def set_global_variables(args):
    """Set global vars and reset the decorated-model registry."""
    assert args is not None
    global _GLOBAL_ARGS
    _GLOBAL_ARGS = args
    global _DECORATED_MODELS
    _DECORATED_MODELS = set()

def set_decorated(model_name):
    """Record that `model_name` has been decorated."""
    _DECORATED_MODELS.add(model_name)

def is_decorated(model_name):
    """Return True if `model_name` has been decorated."""
    _ensure_var_is_initialized(_DECORATED_MODELS, 'decorated_models')
    # `in` already yields a bool; the previous bool(...) wrapper was redundant.
    return model_name in _DECORATED_MODELS

def unwrap_func(func, level=None):
    """
    func: func to unwrap
    level: unwrap level, if level is None, unwrap to the original func
    """
    if func not in _DECORATED_OUTER_TO_INNER:
        return func
    if level is not None:
        if level > 0:
            level -= 1
        else:
            return func
    return unwrap_func(_DECORATED_OUTER_TO_INNER[func], level)

def set_wrap_func(func, new_func):
    """Register `new_func` as a wrapper around `func`."""
    assert new_func not in _DECORATED_OUTER_TO_INNER
    _DECORATED_OUTER_TO_INNER[new_func] = func

def set_dependencies(dependencies):
    """Set model dependencies; must not already be set."""
    global _DEPENDENCIES
    assert _DEPENDENCIES is None
    _DEPENDENCIES = dependencies

def reset_dependencies():
    """Clear model dependencies."""
    global _DEPENDENCIES
    _DEPENDENCIES = None

def get_dependencies():
    """Return model dependencies (None when unset)."""
    return _DEPENDENCIES

def set_vllm_actors(actors):
    """Store the vLLM actor handles."""
    global _VLLM_ACTORS
    _VLLM_ACTORS = actors

def get_vllm_actors():
    """Return the stored vLLM actor handles (None when unset)."""
    return _VLLM_ACTORS
14 | # ============================================================================== 15 | """"Version compatibility for hook""" 16 | 17 | # pylint: disable=unused-import,wildcard-import 18 | 19 | # megatron.text_generation.* 20 | try: 21 | from megatron.text_generation import generation 22 | from megatron.text_generation.generation import * 23 | from megatron.text_generation.generation import _build_attention_mask_and_position_ids 24 | from megatron.text_generation.generation import generate_tokens_probs_and_return_on_first_stage 25 | except ImportError: 26 | from megatron.inference.text_generation import generation 27 | from megatron.inference.text_generation.generation import * 28 | from megatron.inference.text_generation.generation import _build_attention_mask_and_position_ids 29 | from megatron.inference.text_generation.generation import generate_tokens_probs_and_return_on_first_stage 30 | 31 | # pylint: enable=unused-import,wildcard-import 32 | -------------------------------------------------------------------------------- /chatlearn/utils/megatron_import_memory_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""
Version compatibility utilities for Megatron memory management of gradients
and parameter weights. Based on how Megatron uses buffers to manage memory,
several different flavors are supported.
"""
import os

from enum import Enum, auto
from typing import List

__all__ = ['MegatronVersion', 'get_megatron_version', 'check_megatron_versions']


class MegatronVersion(Enum):
    """Supported Megatron memory-management flavors."""

    V1 = auto()  # `MemoryBuffer` manages gradients
    V2 = auto()  # `GradBuffer` manages gradients
    V3 = auto()  # `ParamAndGradBuffer` manages parameter weights and gradients
    V4 = auto()  # compatibility with a temporary version for Qwen-MoE
    V5 = auto()  # refactored `ParamAndGradBuffer` manages weights and gradients


def get_megatron_version():
    """Detect the installed Megatron's buffer-management flavor."""
    # The temporary Qwen-MoE build is selected via an environment variable.
    if os.environ.get("QWEN_VERSION", '') == 'qwen_moe_v1':
        return MegatronVersion.V4

    # Probe from newest to oldest: the first importable marker symbol wins.
    try:
        # pylint: disable-next=import-outside-toplevel, unused-import
        from megatron.core.distributed.distributed_data_parallel import _ParamAndGradBuffer
    except ImportError:
        pass
    else:
        return MegatronVersion.V5
    try:
        # pylint: disable-next=import-outside-toplevel, unused-import
        from megatron.core.distributed import ParamAndGradBuffer
    except ImportError:
        pass
    else:
        return MegatronVersion.V3
    try:
        # pylint: disable-next=import-outside-toplevel, unused-import
        from megatron.core.distributed import GradBuffer
    except ImportError:
        pass
    else:
        return MegatronVersion.V2
    return MegatronVersion.V1


def check_megatron_versions(targets: List[MegatronVersion]):
    """Assert that the detected Megatron version is one of ``targets``."""
    version = get_megatron_version()
    assert version in targets, f'Different Megatron version {version} from targets: {targets}.'


_version = get_megatron_version()

# pylint: disable=unused-import

if _version in [MegatronVersion.V3, MegatronVersion.V5]:
    from megatron.core.distributed.param_and_grad_buffer import BufferType

    __all__.append('BufferType')

# pylint: enable=unused-import
14 | # ============================================================================== 15 | """"Version compatibility for hook""" 16 | 17 | # pylint: disable=unused-import,wildcard-import 18 | 19 | # megatron.model.transformer.* 20 | try: 21 | from megatron.model import transformer 22 | from megatron.model.transformer import ParallelAttention 23 | from megatron.model.transformer import * 24 | except ImportError: 25 | from megatron.legacy.model import transformer 26 | from megatron.legacy.model.transformer import ParallelAttention 27 | from megatron.legacy.model.transformer import * 28 | 29 | # pylint: enable=unused-import,wildcard-import 30 | -------------------------------------------------------------------------------- /chatlearn/utils/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """version""" 16 | 17 | VERSION = "1.1.0" 18 | -------------------------------------------------------------------------------- /docker/torch/Dockerfile.torch2.3.0: -------------------------------------------------------------------------------- 1 | # docker build -t your_docker_image -f Dockerfile.torch2.3.0 . 
2 | FROM pytorch/pytorch:2.3.0-cuda12.1-cudnn8-devel 3 | 4 | LABEL com.nvidia.volumes.needed="nvidia_driver" 5 | LABEL com.nvidia.cuda.version= 6 | ENV NVIDIA_VISIBLE_DEVICES= \ 7 | NVIDIA_REQUIRE_CUDA="cuda>=11.0" \ 8 | LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib:/usr/local/cuda/lib64 9 | 10 | # install common libs 11 | RUN pip install --no-cache-dir -U \ 12 | ray[default]==2.32.0 \ 13 | transformers==4.42.0 \ 14 | pynvml==11.4.1 \ 15 | deepspeed==0.14.4 \ 16 | vllm==0.5.1 \ 17 | accelerate \ 18 | jsonlines \ 19 | torchtyping \ 20 | tensorboard \ 21 | cupy 22 | 23 | # intall apex 24 | RUN apt-get update && apt-get install git vim -y 25 | WORKDIR /tmp/third_party 26 | RUN git clone https://github.com/NVIDIA/apex 27 | WORKDIR /tmp/third_party/apex 28 | RUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ 29 | RUN rm -rf /tmp/third_party 30 | 31 | # install transformer engine v1.2.1 32 | RUN MAX_JOBS=4 pip install git+https://github.com/NVIDIA/TransformerEngine.git@v1.2.1 33 | 34 | # env 35 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64 \ 36 | CUDA_DEVICE_MAX_CONNECTIONS=1 37 | -------------------------------------------------------------------------------- /docker/torch/Dockerfile.torch2.5.1.vllm066: -------------------------------------------------------------------------------- 1 | # currently support fsdp 2 | FROM nvcr.io/nvidia/pytorch:24.10-py3 3 | 4 | ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/ 5 | ENV PIP_TRUSTED_HOST=mirrors.aliyun.com 6 | 7 | RUN pip install --no-cache-dir -U \ 8 | opencv-python-headless==4.5.4.58 \ 9 | vllm==0.6.6 \ 10 | wandb==0.19.3 \ 11 | ray[default]==2.40.0 \ 12 | transformers==4.51.3 \ 13 | modelscope==1.26.0 \ 14 | datasets==3.6.0 \ 15 | deepspeed==0.14.4 \ 16 | grpcio==1.70.0 \ 17 | setuptools==69.5.1 18 | 19 | RUN pip uninstall -y flash_attn && pip install -U flash_attn==2.4.2 
--no-cache-dir --no-build-isolation 20 | RUN pip uninstall -y apex && pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/apex/torch2.5.1-cuda12x/apex-0.1-cp310-cp310-linux_x86_64.whl --no-cache-dir 21 | RUN pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/transformer_engine/torch2.5.1-cuda12x/transformer_engine-1.13.0%2Be5edd6c-cp310-cp310-linux_x86_64.whl --no-cache-dir -------------------------------------------------------------------------------- /docker/torch/Dockerfile.torch2.6.0.vllm085: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:24.12-py3 2 | RUN unset NCCL_DEBUG 3 | ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/ 4 | ENV PIP_TRUSTED_HOST=mirrors.aliyun.com 5 | 6 | RUN pip install --no-cache-dir -U \ 7 | vllm==0.8.5.post1 \ 8 | wandb==0.19.3 \ 9 | transformers==4.51.3 \ 10 | modelscope==1.26.0 \ 11 | datasets==3.6.0 \ 12 | deepspeed==0.16.7 \ 13 | grpcio==1.70.0 \ 14 | nvidia-modelopt==0.27.0 \ 15 | nvidia-modelopt-core==0.27.0 \ 16 | ray[default]==2.46.0 17 | 18 | RUN pip install --no-cache-dir -U setuptools==69.5.1 19 | 20 | RUN pip uninstall -y flash_attn && pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/flash-attention/torch2.6.0-cu12x/flash_attn-2.4.2-cp312-cp312-linux_x86_64.whl --no-cache-dir 21 | RUN pip uninstall -y apex && pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/apex/torch2.6.0-cuda12x/apex-0.1-cp312-cp312-linux_x86_64.whl --no-cache-dir 22 | RUN pip uninstall -y transformer_engine && pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/transformer_engine/torch2.6.0-cuda12x/transformer_engine-1.13.0%2Be5edd6c-cp312-cp312-linux_x86_64.whl --no-cache-dir -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 
| make.bat 2 | _build 3 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | 3 | # You can set these variables from the command line, and also 4 | # from the environment for the first two. 5 | SPHINXOPTS ?= 6 | SPHINXBUILD ?= sphinx-build 7 | SOURCEDIR = zh 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Setup 2 | 3 | ``` 4 | apt-get install latexmk texlive-xetex fonts-noto fonts-freefont-otf xindy latex-cjk-all 5 | pip install -r requirements.txt 6 | ``` 7 | 8 | # build pdf 9 | 10 | ``` 11 | # cd en for english doc 12 | cd zh 13 | # make latexpdf 14 | sphinx-build -b latex . 
_build/latex 15 | cd _build/latex 16 | 17 | # modify chatlearn.tex for auto wrap of text in the table 18 | # Find the table with `stream\\_data\\_loader\\_type`, replace `\begin{tabulary}{\linewidth}[t]{TTT}` with `\begin{tabularx}{\linewidth}[t]{|l|l|X|}` 19 | # and replace the corresponding `\end` 20 | # save the change, and execute 21 | make all-pdf 22 | ``` 23 | 24 | 25 | # build html 26 | ``` 27 | # cd en for english doc 28 | cd zh 29 | make html 30 | ``` 31 | 32 | -------------------------------------------------------------------------------- /docs/en/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.10" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the "docs/" directory with Sphinx 19 | sphinx: 20 | configuration: docs/en/conf.py 21 | 22 | # Optionally build your docs in additional formats such as PDF and ePub 23 | formats: 24 | - pdf 25 | - epub 26 | 27 | # Optional but recommended, declare the Python requirements required 28 | # to build your documentation 29 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 30 | python: 31 | install: 32 | - requirements: docs/requirements.txt 33 | -------------------------------------------------------------------------------- /docs/en/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | 3 | # You can set these variables from the command line, and also 4 | # from the environment for the first two. 
5 | SPHINXOPTS ?= 6 | SPHINXBUILD ?= sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/en/api/config.rst: -------------------------------------------------------------------------------- 1 | Config 2 | ====== 3 | 4 | .. autoclass:: chatlearn.utils.arguments.RuntimeEnvConfig 5 | :members: 6 | :undoc-members: 7 | 8 | .. autoclass:: chatlearn.utils.arguments.RuntimeConfig 9 | :members: 10 | :undoc-members: 11 | 12 | .. autoclass:: chatlearn.utils.arguments.ModelConfig 13 | :members: 14 | :undoc-members: 15 | 16 | 17 | .. autoclass:: chatlearn.utils.arguments.BatchGenerationConfig 18 | :members: 19 | :undoc-members: 20 | 21 | 22 | .. autoclass:: chatlearn.utils.arguments.LoraConfig 23 | :members: 24 | :undoc-members: 25 | 26 | 27 | -------------------------------------------------------------------------------- /docs/en/api/engine.rst: -------------------------------------------------------------------------------- 1 | Engine 2 | ====== 3 | 4 | .. autoclass:: chatlearn.DPOEngine 5 | :members: 6 | :undoc-members: 7 | 8 | .. autoclass:: chatlearn.OnlineDPOEngine 9 | :members: 10 | :undoc-members: 11 | 12 | .. autoclass:: chatlearn.RLHFEngine 13 | :members: 14 | :undoc-members: 15 | 16 | .. autoclass:: chatlearn.EvalEngine 17 | :members: 18 | :undoc-members: 19 | 20 | .. 
autoclass:: chatlearn.Evaluator 21 | :members: set_dataset,set_post_process_func,eval 22 | :undoc-members: 23 | -------------------------------------------------------------------------------- /docs/en/api/index.rst: -------------------------------------------------------------------------------- 1 | API 2 | ======================= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | engine.rst 8 | module.rst 9 | config.rst 10 | -------------------------------------------------------------------------------- /docs/en/api/module.rst: -------------------------------------------------------------------------------- 1 | RLHF Module 2 | =========== 3 | 4 | .. autoclass:: chatlearn.models.base_module.BaseModule 5 | :members: 6 | :undoc-members: 7 | 8 | .. autoclass:: chatlearn.models.torch_module.TorchModule 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | .. autoclass:: chatlearn.models.megatron_module.MegatronModule 14 | :members: 15 | :undoc-members: 16 | :show-inheritance: 17 | 18 | -------------------------------------------------------------------------------- /docs/en/index.rst: -------------------------------------------------------------------------------- 1 | ChatLearn Documentation 2 | ======================= 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 1 7 | :caption: Introduction 8 | 9 | chatlearn 10 | 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | :caption: Installation 15 | 16 | installation 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | :caption: Tutorial 21 | 22 | tutorial/data 23 | tutorial/run 24 | tutorial/tutorial_llama2 25 | tutorial/tutorial_qwen 26 | tutorial/evaluator 27 | tutorial/continue_train 28 | tutorial/custom_model_flow 29 | tutorial/ems 30 | tutorial/profile 31 | 32 | .. toctree:: 33 | :maxdepth: 1 34 | :caption: Programming 35 | 36 | programming 37 | config_yaml 38 | advanced 39 | 40 | .. toctree:: 41 | :maxdepth: 1 42 | :caption: API Documentation 43 | 44 | api/index 45 | 46 | .. 
toctree:: 47 | :maxdepth: 1 48 | :caption: FAQ 49 | 50 | faq 51 | -------------------------------------------------------------------------------- /docs/en/installation.md: -------------------------------------------------------------------------------- 1 | # Environment and Code Setup 2 | 3 | 1. Docker Image Preparation 4 | 5 | It is recommended to refer to `https://github.com/alibaba/ChatLearn/tree/master/docker/torch/Dockerfile.torch2.3.0` for preparing the docker image. 6 | If you're training on the PAI DLC/DSW environment, we suggest using the pre-built image provided below: 7 | 8 | ```bash 9 | registry.cn-wulanchabu.aliyuncs.com/pai-dlc/pytorch-training:2.4.0-gpu-py3.10-cu12.5-ngc24.06-ubuntu22.04 10 | ``` 11 | 12 | 2. Code Preparation: Users need to download the ChatLearn framework code. 13 | 14 | ``` 15 | # Clone ChatLearn code 16 | git clone https://github.com/alibaba/ChatLearn.git 17 | ``` 18 | 19 | 3. If you need to run the alignment training program based on the Megatron-LM framework, you also need to download the `Megatron-LM` code. 20 | 21 | ``` 22 | # Clone Megatron-LM 23 | git clone https://github.com/NVIDIA/Megatron-LM.git 24 | cd Megatron-LM && git checkout core_r0.8.0 25 | ``` 26 | 27 | > [!NOTE] 28 | > If you are using Megatron-LM version `core_r0.8.0`, you may encounter an issue in converting checkpoints: `ValueError: Default process group has not been initialized, please make sure to call init_process_group`. Please refer to the solution in the [FAQ: Failure when converting checkpoint](faq.md#failure-when-converting-checkpoint). 29 | -------------------------------------------------------------------------------- /docs/en/tutorial/check_sync_parameter.md: -------------------------------------------------------------------------------- 1 | # Debugging Parameter Synchronization 2 | 3 | The ParameterSync module ensures weight consistency between the training (trainer) and serving (inference) components.
This guide helps debug synchronization issues that may arise when using different distributed strategies, such as: 4 | 5 | Trainer Side: Megatron-LM (expert parallel + tensor parallel + pipeline parallel) 6 | Inference Side: vLLM (tensor parallel) 7 | 8 | ## Step-by-Step Debugging Guide 9 | 10 | ### 1. Set Up Environment Variable 11 | 12 | Specify a path to dump parameter snapshots before/after synchronization by setting the DEBUG_SYNC_PARAMETERS_PATH environment variable, and huggingface format checkpoint used in vLLM: 13 | 14 | ``` bash 15 | export DEBUG_SYNC_PARAMETERS_PATH=/path/to/dump_directory 16 | export vllm_load_format=auto 17 | export policy_load=/workspace/hf_ckp/QWen2-Max 18 | ``` 19 | 20 | ### 2. Launch the ChatLearn Job 21 | 22 | Start your training job (e.g., DPO fine-tuning for Llama 2) using [tutorial_llama2.md](/docs/en/tutorial/tutorial_llama2.md): 23 | ```bash 24 | bash scripts/train_online_dpo_llama.sh 25 | ``` 26 | 27 | This will generate two directories: 28 | before_sync_parameter: Parameters before synchronization. 29 | after_sync_parameter: Parameters after synchronization. 30 | Each directory contains subfolders for every TP rank. 31 | 32 | ### 3. Check Dumped Files 33 | Verify the dumped parameter files: 34 | ``` bash 35 | tree /path/to/dump_directory 36 | ``` 37 | 38 | Example output: 39 | ``` text 40 | /workspace/debug_sync_params 41 | ├── before_sync_parameter 42 | │ ├── 0 # Parameters from TP rank 0 43 | │ ├── 1 # Parameters from TP rank 1 44 | │ └── ... 45 | └── after_sync_parameter 46 | ├── 0 47 | ├── 1 48 | └── ... 49 | ``` 50 | 51 | ### 4. Run the Parameter Check Script 52 | Use the check_sync_parameter.py tool to compare parameters before/after synchronization: 53 | ``` bash 54 | python chatlearn/tools/check_sync_parameter.py --root_dir /path/to/dump_directory | tee check.log 55 | ``` 56 | 57 | This script will compare parameter shapes and values 58 | 59 | Generate a log file (check.log) with detailed results. 60 | 61 | ### 5.
Interpret the Log File 62 | 63 | Successful Sync: 64 | ```text 65 | PASS|1|model.layers.5.self_attn.qkv_proj.weight 66 | ``` 67 | 68 | MisMatch Detected with Mean Values: 69 | ```text 70 | DIFF|1|model.layers.3.mlp.shared_expert.gate_up_proj.weight|torch.Size([2816, 2048])|tensor(0.5247)|torch.Size([2816, 2048])|tensor(8.231) 71 | ``` -------------------------------------------------------------------------------- /docs/en/tutorial/continue_train.md: -------------------------------------------------------------------------------- 1 | # Resuming and Fault Tolerance 2 | 3 | The alignment task involves the computation and interaction of multiple models. As the model scale and computational resources increase, occasional exceptions may occur due to the dependent software stack and hardware environment, leading to task interruption. 4 | 5 | To ensure that interrupted tasks can automatically resume their state, ChatLearn provides the resuming function, which in combination with PAI-DLC's AIMaster, can automatically detect errors and resume functionality. 6 | 7 | ## Configuring ChatLearn Resuming 8 | 9 | Resuming an alignment task requires consideration of the following points: 10 | 11 | 1. Recording and restoring data progress: For recording data status, users need to configure `data_checkpoint_path` in the training configuration master file, such as `rlhf.yaml`. If `data_checkpoint_path` is not empty, ChatLearn will record the current data progress and store the data checkpoint during each `save_checkpoint`. 12 | 13 | 2. Restoring training states such as episodes and iterations: When users configure `data_checkpoint_path` and the corresponding data checkpoint exists in the folder, ChatLearn will automatically restore the training state to the latest checkpoint status and set the `resume_training` variable to `True`. 14 | 15 | 3. Loading checkpoints: When `resume_training==True`, the checkpoints for `reference` and `reward` in RLHF remain unchanged. 
However, `ppo_policy` and `ppo_value` need to load the checkpoints stored during training, rather than the original initialized checkpoints. Therefore, special processing needs to be done in the setup phase. 16 | 17 | ```python 18 | if self.resume_training: 19 | self.args.load = get_args().save 20 | self.args.load_iteration = -1 21 | self.args.no_load_optim = False 22 | self.args.no_load_rng = False 23 | self.args.no_load_args = False 24 | self.args.no_load_scheduler = False 25 | self.args.finetune = False 26 | ``` 27 | 28 | For more details, refer to `examples/megatron/scripts/train_rlhf_llama.sh`. 29 | 30 | If a user configures `data_checkpoint_path` in the program but does not want to enable the resuming function, they can also disable this functionality by configuring `enable_resume_training: False`. 31 | 32 | ## Combining with DLC AIMaster to Achieve Fault Tolerance and Automatic Resuming 33 | 34 | DLC provides fault tolerance monitoring based on AIMaster. AIMaster is a task-level component that, when enabled for fault tolerance monitoring, launches an AIMaster instance to run with other task instances, serving the roles of task monitoring, fault judgment, and resource control. 35 | 36 | Users can combine AIMaster's fault tolerance functionality with ChatLearn's resuming functionality to achieve automatic resuming of training tasks. 37 | 38 | The following is an example of fault tolerance monitoring configuration, which includes enabling hang detection and error detection. When the hang exceeds 1 hour or AIMaster detects an error, the task will be automatically restarted, with a maximum number of restarts being 3 times. 39 | 40 | ![aimaster](../../images/fault.png) 41 | 42 | For more fault tolerance configuration, please refer to the DLC [Fault Tolerance Documentation](https://help.aliyun.com/zh/pai/user-guide/fault-tolerance-monitoring-based-on-aimaster?spm=a2c4g.11186623.0.0.12011976WAncyo). 
43 | -------------------------------------------------------------------------------- /docs/en/tutorial/data.md: -------------------------------------------------------------------------------- 1 | # Data 2 | 3 | This document describes the data preparation process for different stages: SFT, Reward, RLHF, DPO, OnlineDPO and GRPO. 4 | 5 | 6 | **The following is a collection of general environment variables used in this tutorial script:** 7 | 8 | | ENV | Explanation | 9 | | --- | --- | 10 | | `CHATLEARN` | The location where the ChatLearn code is cloned [https://github.com/alibaba/ChatLearn.git](https://github.com/alibaba/ChatLearn.git) | 11 | | `DATASET_ROOT` | The root directory for storing the SFT/Reward/RLHF/DPO/OnlineDPO/GRPO training dataset collection. | 12 | 13 | 14 | 15 | ## 1 Prepare SFT Training Data 16 | 17 | Organize the question-response pairs of SFT data into a jsonl file, where each line of the jsonl file represents a SFT data sample in the following Python dictionary format: 18 | 19 | ``` 20 | {'query': question, 'response': reply} 21 | ``` 22 | 23 | Taking the example of Anthropic's helpful&harmless data, use the following code to store it in `$DATASET_ROOT/sft/train.jsonl`. 24 | 25 | ```bash 26 | cd ${CHATLEARN}/examples/megatron/ 27 | DATASET_ROOT=$path_to_dataset_root 28 | python data/prepare_data_sft.py $DATASET_ROOT 29 | ``` 30 | 31 | ## 2 Prepare Reward Training Data 32 | 33 | 1. First, prepare question-different response pairs and organize them into a jsonl file. Each line in the jsonl file represents a Reward model training data sample in the following Python dictionary format: 34 | 35 | ``` 36 | {'query': question, 'response': [reply 1, reply 2, ...], 'score': [score1, score2, ...]} 37 | ``` 38 | 39 | The score value indicates the quality of the corresponding response, with higher scores indicating higher quality and closer to human preference. 40 | 41 | 2. 
Taking the example of Anthropic's helpful&harmless data, use the following code to store it in `$DATASET_ROOT/rm/train.jsonl` and `$DATASET_ROOT/rm/dev.jsonl`. 42 | 43 | ```bash 44 | cd ${CHATLEARN}/examples/megatron/ 45 | DATASET_ROOT=path-to-dataset-root 46 | python data/prepare_data_reward.py $DATASET_ROOT 47 | ``` 48 | 49 | ## 3 Prepare Alignment Training Data 50 | 51 | ChatLearn supports multiple alignments: RLHF, DPO, OnlineDPO, GRPO 52 | 53 | 1. Firstly, prepare a dataset of instructions to be explored and organize it into a JSON file. Each line in the JSON file should represent a prompt in the following format: 54 | 55 | ``` 56 | {"prompt": prompt} 57 | ``` 58 | 59 | 2. Taking Anthropic's helpful & harmless data as an example, use the following code to store the dataset in `$DATASET_ROOT/alignment/train.jsonl` and `$DATASET_ROOT/alignment/dev.jsonl`: 60 | 61 | ```bash 62 | cd ${CHATLEARN}/examples/megatron/ 63 | DATASET_ROOT=path-to-dataset-root 64 | python data/prepare_data_alignment.py $DATASET_ROOT 65 | ``` 66 | ## 4 Prepare Math Training Data 67 | 68 | 1. Firstly, prepare a dataset of math data to be explored and organize it into a JSON file. Each line in the JSON file should represent a prompt in the following format: 69 | 70 | ``` 71 | {"eval_func": "math_rule", "prompt": prompt, 'answer': answer} 72 | ``` 73 | 74 | 2. Taking openai/gsm8k data as an example, use the following code to store the dataset in `$DATASET_ROOT/math/train.jsonl`: 75 | 76 | ```bash 77 | cd ${CHATLEARN}/examples/megatron/ 78 | DATASET_ROOT=path-to-dataset-root 79 | python data/prepare_data_math.py $DATASET_ROOT 80 | ``` -------------------------------------------------------------------------------- /docs/en/tutorial/ems.md: -------------------------------------------------------------------------------- 1 | # Efficient Memory Sharing (EMS) 2 | 3 | ChatLearn provides EMS feature to significantly reduce the GPU memory usage during the alignment training. 
4 | It maximizes the use of limited resources to train models with larger-scale or to improve overall training efficiency by improving the model's parallel strategy and increasing the batch size after GPU memory saved. 5 | 6 | When multiple models in ChatLearn share the same resources for training or inference, enabling the EMS feature allows these models to sequentially share GPU memory: 7 | - After each model is initialized, tensors/buffers that constantly reside in GPU memory (such as weights, gradient buffers, and optimization states) are unloaded to the RAM or freed to release their occupied GPU memory. 8 | - Before training or inference for a specific model, the tensors/buffers are loaded from the RAM or reconstructed, and then training or inference takes place. 9 | - Once the training or inference is complete, the tensors/buffers are again unloaded to the RAM or freed to release their occupied GPU memory. 10 | 11 | By repeating the above process, multiple models sequentially share GPU memory, maximizing the efficiency of GPU memory usage. 12 | 13 | ## Usage 14 | Users can specify whether to enable the EMS feature by configuring the `free_memory` (bool type, default is False) parameter for each model. This can be directly modified in the `rlhf.yaml` for each model. For example, to enable the EMS feature for the policy model: 15 | ```yaml 16 | policy: 17 | model_config_file: old_policy_inference.yaml 18 | ... 
19 | free_memory: ${free_memory_policy:True} 20 | ``` 21 | Alternatively, it can also be configured in the training script using environment variables: 22 | - Policy model: `export free_memory_policy=True` 23 | - Reference model: `export free_memory_reference=True` 24 | - Reward model: `export free_memory_reward=True` 25 | - Value model: `export free_memory_value=True` 26 | - PPO policy model: `export free_memory_ppo_policy=True` 27 | - PPO value model: `export free_memory_ppo_value=True` 28 | 29 | A complete example can be found in the [llama2 configuration](https://github.com/alibaba/ChatLearn/blob/main/examples/megatron/configs/llama2/rlhf.yaml). -------------------------------------------------------------------------------- /docs/en/tutorial/evaluator.md: -------------------------------------------------------------------------------- 1 | # Evaluator 2 | 3 | This document will introduce how to perform model evaluation. Users can use `EvalEngine` to evaluate models independently or configure the evaluator within the training engine to perform evaluations during training. 4 | 5 | ```python 6 | def eval_flow(batch): 7 | p = policy.forward_step(batch) 8 | r = reward.eval_step(p) 9 | r1 = reward2.eval_step(p) 10 | return r, r1 11 | 12 | evaluator = Evaluator(eval_flow) 13 | evaluator.set_dataset(prompts) 14 | results = evaluator.eval() 15 | ``` 16 | 17 | In the above example, we constructed an evaluation flow for three models. Users can customize the evaluation execution flow through the `eval_flow`. 18 | 19 | The result returned by `evaluator.eval` is of type `dict`, where the key is `model_name` and the value is a `list` containing the results of the computations for each batch. 20 | 21 | In the above example, the result returned by `eval` will be `{"reward": [batch0, batch1, batch2], "reward2": [batch0, batch1, batch2]}`. 
22 | -------------------------------------------------------------------------------- /docs/en/tutorial/profile.md: -------------------------------------------------------------------------------- 1 | # Profile 2 | ChatLearn provides two ways to profile performance: 3 | 1. Torch profiler 4 | 2. nsys 5 | Note: For large models, the profile result can be very large. It is recommended to reduce the model size when profiling. 6 | 7 | ## Torch Profiler 8 | Users can enable the Torch profiler by configuring the rlhf setting `profiler_dir: path_to_profile_dir` in the main configuration file of the system. 9 | ```yaml 10 | profiler_dir: path_to_profile_dir 11 | ``` 12 | 13 | ## nsys 14 | Users can enable the nsys profiler by configuring the rlhf setting `nsys: True` in the main configuration file of the system. 15 | ```yaml 16 | runtime: 17 | nsys: True 18 | ``` 19 | When launching the program, nsys startup parameters need to be added before the execution command, as shown in the following example: 20 | ```bash 21 | nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s none --capture-range=cudaProfilerApi --capture-range-end=stop-shutdown --cudabacktrace=true -x true --force-overwrite true -o my_profile \ 22 | python train_rlhf.py XXX 23 | ``` -------------------------------------------------------------------------------- /docs/en/tutorial/run.md: -------------------------------------------------------------------------------- 1 | # Distributed Execution 2 | 3 | This document will provide instructions on how to execute a distributed training task. 4 | 5 | ## PAI DLC Distributed Execution 6 | 7 | [Aliyun PAI DLC](https://www.aliyun.com/activity/bigdata/pai-dlc) [1] can conveniently and efficiently support training for various tasks. 8 | 9 | The screenshots of the PAI-DLC task creation page is shown as follows. 10 | Select the job type as `PyTorch` and paste the command into the `Execution Command` window. 
11 | 12 | ![image.png](../../images/dlc_1.jpg) 13 | 14 | ![image.png](../../images/dlc_2.jpg) 15 | 16 | 17 | 18 | ## Non-PAI-DLC environment 19 | 20 | If you want to submit distributed training in a non-PAI-DLC environment, 21 | the following environment variables need to be configured on each node before executing the script: 22 | 23 | ```bash 24 | export MASTER_ADDR=xxx 25 | export MASTER_PORT=xxx 26 | export WORLD_SIZE=xxx 27 | export GPUS_PER_NODE=8 28 | export RANK=xx 29 | ``` 30 | 31 | ## Reference 32 | 33 | 1. Aliyun Machine Learning PAI-DLC: [https://www.aliyun.com/activity/bigdata/pai-dlc](https://www.aliyun.com/activity/bigdata/pai-dlc) 34 | -------------------------------------------------------------------------------- /docs/en/tutorial/tutorial_grpo_fsdp.md: -------------------------------------------------------------------------------- 1 | # End-to-End GRPO Training Tutorial with FSDP 2 | 3 | This document provides instructions for end-to-end training using the ChatLearn, pytorch FSDP and vLLM framework, and the qwen3 model. 4 | 5 | ## Environment Setup 6 | 1. Docker Image Preparation 7 | 8 | We recommend running the following example in PAI [DSW](https://help.aliyun.com/zh/pai/user-guide/create-and-manage-dsw-instances/)/[DLC](https://help.aliyun.com/zh/pai/user-guide/create-a-training-task?spm=a2c4g.11186623.help-menu-30347.d_3_3_5_5.2dfb1925l3QjwG). You need to use the following image to launch the instance. 9 | ```bash 10 | dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312 11 | ``` 12 | 13 | You can use a VPC address to accelerate image pulling. The image address should be adjusted based on the current region. For example, if you need to launch a DSW instance in Shanghai, you can use the following image `dsw-registry-vpc.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312`. 14 | 15 | 2. 
Code Preparation 16 | 17 | ```bash 18 | git clone https://github.com/alibaba/ChatLearn.git && cd ChatLearn 19 | ``` 20 | 21 | ## Data Preparation 22 | We take [MATH-lighteval](https://www.modelscope.cn/datasets/AI-ModelScope/MATH-lighteval) as example. 23 | ```bash 24 | # download dataset 25 | mkdir -p dataset 26 | modelscope download --dataset AI-ModelScope/MATH-lighteval --local_dir dataset/MATH-lighteval 27 | # preprocess dataset 28 | python examples/fsdp/data/data_preprocess/math_lighteval.py --input_dir dataset/MATH-lighteval --local_dir dataset/MATH-lighteval 29 | # download model weight 30 | modelscope download --model Qwen/Qwen3-8B --local_dir Qwen3-8B 31 | ``` 32 | 33 | ## Training 34 | You can run the following command to start training: 35 | 36 | ```bash 37 | bash examples/fsdp/scripts/train_grpo_qwen3.sh 38 | ``` 39 | 40 | ## Using Wandb 41 | If you want to use Wandb to log the training process, you need to modify the following configuration in [train_grpo_qwen3.sh](../../../examples/fsdp/scripts/train_grpo_qwen3.sh): 42 | 43 | ```bash 44 | export enable_wandb=True 45 | export wandb_project="Your-Wandb-Project-Name" 46 | export WANDB_API_KEY="Your-Wandb-api-key" 47 | ``` -------------------------------------------------------------------------------- /docs/en/tutorial/tutorial_qwen.md: -------------------------------------------------------------------------------- 1 | # End-to-end training tutorial based on the Qwen model 2 | 3 | This document describes DPO training based on the ChatLearn, DeepSpeed framework, and Qwen model.
4 | 5 | **The following is a collection of common environment variables used in this tutorial script:** 6 | | ENV | Meaning | 7 | | --- |-------------------------------------------------------------------------------------------------------------------------------| 8 | | `CHATLEARN` | Location where the ChatLearn code repository is cloned [https://github.com/alibaba/ChatLearn.git](https://github.com/alibaba/ChatLearn.git) | 9 | | `DATASET_ROOT` | Root directory where the training datasets are stored | 10 | 11 | ## Setup: Image, Code, and Data Preparation 12 | 13 | ### Image / Code 14 | 15 | Please refer to [Environment and Code Setup](../installation.md). 16 | 17 | ### Data 18 | The data format required by qwen2 is chatml: 19 | ``` 20 | {"type": "chatml", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Tell me something about large language models."}, {"role": "assistant", "content": "Large language models are a type of language model that is trained on a large corpus of text data. They are capable of generating human-like text and are used in a variety of natural language processing tasks..."}], "source": "unknown"} 21 | ``` 22 | The following script can convert `Dahoas/full-hh-rlhf` to data in chatml format and store it in the file `$DATASET_ROOT/alignment/train.jsonl`: 23 | ```bash 24 | cd ${CHATLEARN}/examples/huggingface/ 25 | DATASET_ROOT=path-to-dataset-root 26 | python data/preprocess_data_chatml.py $DATASET_ROOT 27 | ``` 28 | 29 | ## DPO 30 | Here is an example of DPO training for Qwen2-7B. 31 | In this example, the user needs to set `policy_model_path` to the initialization model checkpoint path, and the Policy model and Reference model will be initialized with this checkpoint. 
32 | ``` 33 | export CHATLEARN=path-to-chatlearn 34 | export DATASET_PATH=$DATASET_ROOT/alignment/train.jsonl 35 | export policy_model_path=path-to-qwen2-ckpt 36 | cd ${CHATLEARN}/examples/huggingface/ 37 | bash scripts/train_dpo_qwen.sh 38 | ``` 39 | -------------------------------------------------------------------------------- /docs/images/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/arch.png -------------------------------------------------------------------------------- /docs/images/class.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/class.png -------------------------------------------------------------------------------- /docs/images/dlc_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/dlc_1.jpg -------------------------------------------------------------------------------- /docs/images/dlc_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/dlc_2.jpg -------------------------------------------------------------------------------- /docs/images/engine.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/engine.jpg -------------------------------------------------------------------------------- /docs/images/engine_class.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/engine_class.png -------------------------------------------------------------------------------- /docs/images/fault.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/fault.png -------------------------------------------------------------------------------- /docs/images/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/logo.jpg -------------------------------------------------------------------------------- /docs/images/perf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/perf.png -------------------------------------------------------------------------------- /docs/images/rlhf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/rlhf.png -------------------------------------------------------------------------------- /docs/images/yaml.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/docs/images/yaml.jpg -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx_rtd_theme 3 | recommonmark 4 | sphinx-markdown-tables 5 | myst-parser 6 | sphinx-markdown-builder 7 | sphinx_markdown_checkbox 8 | 
-------------------------------------------------------------------------------- /docs/zh/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.10" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the "docs/" directory with Sphinx 19 | sphinx: 20 | configuration: docs/zh/conf.py 21 | 22 | # Optionally build your docs in additional formats such as PDF and ePub 23 | formats: 24 | - pdf 25 | - epub 26 | 27 | # Optional but recommended, declare the Python requirements required 28 | # to build your documentation 29 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 30 | python: 31 | install: 32 | - requirements: docs/requirements.txt 33 | -------------------------------------------------------------------------------- /docs/zh/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | 3 | # You can set these variables from the command line, and also 4 | # from the environment for the first two. 5 | SPHINXOPTS ?= 6 | SPHINXBUILD ?= sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/zh/api/config.rst: -------------------------------------------------------------------------------- 1 | Config 2 | ====== 3 | 4 | .. autoclass:: chatlearn.utils.arguments.RuntimeEnvConfig 5 | :members: 6 | :undoc-members: 7 | 8 | .. autoclass:: chatlearn.utils.arguments.RuntimeConfig 9 | :members: 10 | :undoc-members: 11 | 12 | .. autoclass:: chatlearn.utils.arguments.ModelConfig 13 | :members: 14 | :undoc-members: 15 | 16 | 17 | .. autoclass:: chatlearn.utils.arguments.BatchGenerationConfig 18 | :members: 19 | :undoc-members: 20 | 21 | 22 | .. autoclass:: chatlearn.utils.arguments.LoraConfig 23 | :members: 24 | :undoc-members: 25 | 26 | 27 | -------------------------------------------------------------------------------- /docs/zh/api/engine.rst: -------------------------------------------------------------------------------- 1 | Engine 2 | ====== 3 | 4 | .. autoclass:: chatlearn.DPOEngine 5 | :members: 6 | :undoc-members: 7 | 8 | .. autoclass:: chatlearn.OnlineDPOEngine 9 | :members: 10 | :undoc-members: 11 | 12 | .. autoclass:: chatlearn.RLHFEngine 13 | :members: 14 | :undoc-members: 15 | 16 | .. autoclass:: chatlearn.EvalEngine 17 | :members: 18 | :undoc-members: 19 | 20 | .. autoclass:: chatlearn.Evaluator 21 | :members: set_dataset,set_post_process_func,eval 22 | :undoc-members: 23 | -------------------------------------------------------------------------------- /docs/zh/api/index.rst: -------------------------------------------------------------------------------- 1 | API 2 | ======================= 3 | 4 | .. 
toctree:: 5 | :maxdepth: 2 6 | 7 | engine.rst 8 | module.rst 9 | config.rst 10 | -------------------------------------------------------------------------------- /docs/zh/api/module.rst: -------------------------------------------------------------------------------- 1 | RLHF Module 2 | =========== 3 | 4 | .. autoclass:: chatlearn.models.base_module.BaseModule 5 | :members: 6 | :undoc-members: 7 | 8 | .. autoclass:: chatlearn.models.torch_module.TorchModule 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | .. autoclass:: chatlearn.models.megatron_module.MegatronModule 14 | :members: 15 | :undoc-members: 16 | :show-inheritance: 17 | 18 | -------------------------------------------------------------------------------- /docs/zh/config_yaml.md: -------------------------------------------------------------------------------- 1 | # 配置文件 2 | ## 训练配置文件 3 | 4 | 用户需要一个程序主 yaml 配置来设置运行环境、模型配置和 RLHF 训练流程相关的配置。同时,用户也可能需要为每个模型配置单独的模型配置。 5 | 6 | RLHF 的训练配置包括三部分 7 | 8 | 1. runtime_env: 运行环境配置 9 | 2. models: 模型配置。每一个模型都可以单独配置模型参数。通过`model_name`来区分不同的模型。这里`model_name`对应主文件中定义模型时传入的`model_name`。 10 | 3. runtime: 训练配置 11 | 12 | 以下为一个训练配置的示例。具体的配置项含义可以参考 [Config API 文档](api/config.rst). 
13 | 14 | 为了方便配置不同的超参数,我们也支持从环境变量读取参数。格式如下 15 | 16 | ``` 17 | param: ${env_name:default_value} 18 | ``` 19 | `param`为参数名,`env_name`为环境变量名,`default_value`为默认值 (可选)。在以下例子中,如果设置了环境变量`ref_generation_batch_size`, 则会从环境变量中读取赋值给`reference`的`generation_batch_size`,如果没有设置环境变量`ref_generation_batch_size`,则使用默认值 4。 20 | 21 | ```yaml 22 | runtime_env: 23 | platform: DLC 24 | excludes: 25 | - "*pt" 26 | - "logs" 27 | - "tensorboards" 28 | - ".nfs*" 29 | 30 | 31 | models: 32 | policy: 33 | model_config_file: policy_inference.yaml 34 | num_gpu: 8 35 | trainable: False 36 | 37 | reference: 38 | model_config_file: reference.yaml 39 | num_gpu: 8 40 | trainable: False 41 | generation_batch_size: ${ref_generation_batch_size:4} 42 | 43 | reward: 44 | model_config_file: reward_inference.yaml 45 | num_gpu: 8 46 | trainable: False 47 | 48 | value: 49 | model_config_file: old_value_inference.yaml 50 | num_gpu: 8 51 | trainable: False 52 | 53 | ppo_policy: 54 | model_config_file: ppo_policy.yaml 55 | num_gpu: 8 56 | trainable: True 57 | 58 | ppo_value: 59 | model_config_file: ppo_value.yaml 60 | num_gpu: ${num_gpu} 61 | trainable: True 62 | 63 | runtime: 64 | colocation: 65 | - policy,ppo_policy,reward,reference,value,ppo_value 66 | generation_batch_size: ${generation_batch_size:4} 67 | train_micro_batch_size: 2 68 | train_global_batch_size: ${train_global_batch_size:512} 69 | num_episode: 200 70 | sample_per_episode: ${sample_per_episode:1024} 71 | num_training_epoch: 1 72 | save_episode_interval: ${save_episode_interval:50} 73 | data_path: ${data_path} 74 | eval_episode_interval: ${eval_episode_interval:100} 75 | ``` 76 | 77 | 78 | ## 模型配置 YAML 79 | 80 | 本框架支持对每个模型配置单独的配置文件,用于配置不同模型的超参数,并行化策略,checkpoint 初始化等。模型配置文件格式为 yaml 文件。下面是一个简单的模型配置例子。 81 | 82 | ```yaml 83 | num_layers: 6 84 | hidden_size: 768 85 | num_attention_heads: 12 86 | bf16: True 87 | seq_length: 2048 88 | tensor_model_parallel_size: 8 89 | pipeline_model_parallel_size: 2 90 | load: path-to-ckpt 91 | ``` 92 | 93 | 
为了简化不同模型的共享配置,我们拓展了 yaml 的语法,通过 include 的字段来集成 base 配置文件的配置。在下面这个例子中,`policy_inference.yaml`和`ppo_policy.yaml`共享`num_layers`/`hidden_size`等参数,同时,两个模型的配置了各自不同的`pipeline_model_parallel_size`。 94 | 95 | ![yaml](../images/yaml.jpg) 96 | -------------------------------------------------------------------------------- /docs/zh/index.rst: -------------------------------------------------------------------------------- 1 | ChatLearn 使用文档 2 | ======================= 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 1 7 | :caption: 简介 8 | 9 | chatlearn 10 | 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | :caption: 安装 15 | 16 | installation 17 | 18 | 19 | .. toctree:: 20 | :maxdepth: 1 21 | :caption: 使用教程 22 | 23 | tutorial/data 24 | tutorial/run 25 | tutorial/tutorial_llama2 26 | tutorial/tutorial_qwen 27 | tutorial/evaluator 28 | tutorial/continue_train 29 | tutorial/custom_model_flow 30 | tutorial/ems 31 | tutorial/profile 32 | 33 | 34 | .. toctree:: 35 | :maxdepth: 1 36 | :caption: 编程接口 37 | 38 | programming 39 | config_yaml 40 | advanced 41 | 42 | .. toctree:: 43 | :maxdepth: 1 44 | :caption: API 文档 45 | 46 | api/index 47 | 48 | 49 | .. toctree:: 50 | :maxdepth: 1 51 | :caption: 常见问题 52 | 53 | faq -------------------------------------------------------------------------------- /docs/zh/installation.md: -------------------------------------------------------------------------------- 1 | # 环境和代码准备 2 | 3 | 1. 镜像准备 4 | 5 | 可以参考 `https://github.com/alibaba/ChatLearn/tree/master/docker/torch/Dockerfile.torch2.3.0` 准备镜像。 6 | 如果在 PAI DLC/DSW 环境上训练,推荐使用我们准备好的镜像: 7 | 8 | ```bash 9 | registry.cn-wulanchabu.aliyuncs.com/pai-dlc/pytorch-training:2.4.0-gpu-py3.10-cu12.5-ngc24.06-ubuntu22.04 10 | ``` 11 | 12 | 2. 代码准备: 用户需要下载 `ChatLearn` 框架代码。 13 | 14 | ``` 15 | # 下载 ChatLearn 代码 16 | git clone https://github.com/alibaba/ChatLearn.git 17 | ``` 18 | 19 | 3. 
如果您需要运行基于 Megatron-LM 框架的 alignment 训练程序,您也需要下载 `Megatron-LM` 代码。 20 | 21 | ``` 22 | # 下载 Megatron-LM 23 | git clone https://github.com/NVIDIA/Megatron-LM.git 24 | git checkout core_r0.8.0 25 | ``` 26 | 27 | > [!NOTE] 28 | > 若使用 Megatron-LM core_r0.8.0,您可能在转换 checkpoint 时遇到错误:`ValueError: Default process group has not been initialized, please make sure to call init_process_group.`,您可以参考 [FAQ:转换 Checkpoint 失败](faq.md#转换-checkpoint-失败) 中的解决方案。 29 | -------------------------------------------------------------------------------- /docs/zh/tutorial/continue_train.md: -------------------------------------------------------------------------------- 1 | # 续跑和容错 2 | 3 | Alignment 任务涉及到多模型的计算和交互,随着模型规模的增大和计算资源的增加,由于依赖的软件栈和硬件环境都有可能出现偶发异常,会导致任务停止运行。 4 | 为了保障被中断的任务可以恢复状态进行自动续跑,ChatLearn提供了续跑的功能,结合 PAI-DLC 的 AIMaster,可以实现自动错误检测和续跑功能。 5 | 6 | ## 配置 ChatLearn 续跑 7 | 8 | 任务的续跑需要考虑以下几点: 9 | 1. 数据进度的记录和恢复; 对于数据状态的记录,用户需要在训练配置主文件如 `rlhf.yaml` 中配置 `data_checkpoint_path`。 10 | 如果 `data_checkpoint_path` 不为空,则 ChatLearn 会记录当前的数据进度,并在每次 `save_checkpoint` 的同时存储 data checkpoint。 11 | 2. 训练状态比如 episode、iteration 等信息的恢复;当用户配置了 `data_checkpoint_path`,同时文件夹中存在对应的 data checkpoint,ChatLearn 会自动恢复训练状态到当前最新的checkpoint状态。 12 | 并将模型的 `resume_training` 变量设为 `True` 。 13 | 3. 
checkpoint的加载;当 `resume_training==True`, 对于 RLHF 中的几个模型,`reference` 和 `reward` 加载的 checkpoint 保持不变。 14 | `ppo_policy` 和 `ppo_value` 需要加载训练中存储的checkpoint,而不是原始初始化的checkpoint。 因此需要在 `setup` 阶段做特殊处理。 15 | 16 | ```python 17 | if self.resume_training: 18 | self.args.load = get_args().save 19 | self.args.load_iteration = -1 20 | self.args.no_load_optim = False 21 | self.args.no_load_rng = False 22 | self.args.no_load_args = False 23 | self.args.no_load_scheduler = False 24 | self.args.finetune = False 25 | ``` 26 | 27 | 更多详情可以参考 `examples/megatron/scripts/train_rlhf_llama.sh` 。 28 | 29 | 如果用户在程序中配置了 `data_checkpoint_path` ,但是不想打开续跑功能,则也可以通过配置 `enable_resume_training: False` 来关闭此功能。 30 | 31 | ## 和 DLC AIMaster 结合实现容错和自动续跑 32 | 33 | DLC提供了基于AIMaster的容错监控功能。AIMaster是一个任务级别的组件,当任务开启AIMaster的容错监控功能后, 34 | 会拉起一个AIMaster实例和任务其他实例一起运行,起到任务监控、容错判断、资源控制的作用。 35 | 36 | 用户可以通过结合 AIMaster 的容错功能和 ChatLearn 的续跑功能来实现训练任务的自动续跑。 37 | 38 | 以下为容错监控的配置示例,在这个配置中打开了 hang 检测和错误检测, 39 | 当hang 超过 1 个小时或者当 AIMaster 检测到错误,会将任务自动重启,最大重启次数为3次。 40 | 41 | ![aimaster](../../images/fault.png) 42 | 43 | 更多的容错配置请参考 DLC [容错文档](https://help.aliyun.com/zh/pai/user-guide/fault-tolerance-monitoring-based-on-aimaster?spm=a2c4g.11186623.0.0.12011976WAncyo) 。 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /docs/zh/tutorial/data.md: -------------------------------------------------------------------------------- 1 | # 数据准备 2 | 3 | 本文档介绍不同阶段 SFT, Reward,RLHF,DPO, OnlineDPO 和 GRPO 的数据准备流程。 4 | 5 | **以下是这个 Tutorial 脚本中使用的通用环境变量集合:** 6 | 7 | | ENV | 含义 | 8 | | --- | --- | 9 | | `CHATLEARN` | ChatLearn 代码仓库 clone 存放的位置 [https://github.com/alibaba/ChatLearn.git](https://github.com/alibaba/ChatLearn.git) | 10 | | `DATASET_ROOT` | 存放SFT/Reward/RLHF/DPO/OnlineDPO/GRPO训练数据集合的根目录 | 11 | 12 | ## 准备 SFT 训练数据 13 | 14 | 将 SFT 数据的问题 - 回复配对的样本,整理到一个 jsonl 文件中,其中 jsonl 文件中每一行为一条 SFT 数据,形式为如下的 Python 字典格式: 15 | 16 | ``` 17 | {'query': 问题,'response': 
回复} 18 | ``` 19 | 20 | 以 Anthropic 的 helpful&harmless 的数据为例,使用如下代码,会存一个 `$DATASET_ROOT/sft/train.jsonl`. 21 | 22 | ```bash 23 | cd ${CHATLEARN}/examples/megatron/ 24 | DATASET_ROOT=$path_to_dataset_root 25 | python data/prepare_data_sft.py $DATASET_ROOT 26 | ``` 27 | 28 | ## 准备 Reward 训练数据 29 | 30 | 1. 首先准备问题 - 不同回复配对的样本,整理到一个 jsonl 文件中,其中 jsonl 文件中每一行为一条 Reward 模型训练数据,形式为如下的 Python 字典格式: 31 | 32 | ``` 33 | {'query': 问题,'response': [回复 1, 回复 2, .....], 'score': [score1, score2, .....]} 34 | ``` 35 | 36 | 其中 score 的值越高意味着对应回复的质量越高,越贴近人类偏好。 37 | 38 | 2. 以 Anthropic 的 helpful&harmless 的数据为例,使用如下代码,会存一个 `$DATASET_ROOT/rm/train.jsonl` 和 `$DATASET_ROOT/rm/dev.jsonl`. 39 | 40 | ```bash 41 | cd ${CHATLEARN}/examples/megatron/ 42 | DATASET_ROOT=path-to-dataset-root 43 | python data/prepare_data_reward.py $DATASET_ROOT 44 | ``` 45 | 46 | ## 准备 Alignment 训练数据 47 | 48 | ChatLearn中支持多种Alignment的训练模式:RLHF, DPO, OnlineDPO, GRPO 49 | 50 | 其中RLHF/OnlineDPO/GRPO数据格式相同。 51 | 52 | 53 | ### RLHF/OnlineDPO/GRPO 54 | 55 | 1. 首先准备一个需要被探索的指令数据集,整理到一个 jsonl 文件中,其中 jsonl 文件中每一行为一条指令,格式为 56 | 57 | ``` 58 | {"prompt": 问题} 59 | ``` 60 | 61 | 2. 以 Anthropic 的 helpful&harmless 的数据为例,使用如下代码,会存一个`$DATASET_ROOT/alignment/train.jsonl` 和`$DATASET_ROOT/alignment/dev.jsonl`: 62 | 63 | ```bash 64 | cd ${CHATLEARN}/examples/megatron/ 65 | DATASET_ROOT=path-to-dataset-root 66 | python data/prepare_data_alignment.py $DATASET_ROOT 67 | ``` 68 | 69 | ### DPO 70 | 71 | 准备一个需要被探索的指令数据集,整理到一个 jsonl 文件中,其中 jsonl 文件中每一行为一条指令,格式为:prompt+chosen+rejected,例如: 72 | 73 | ``` 74 | {"prompt": 问题, "chosen": 正偏好回答, "rejected": 负偏好回答} 75 | ``` 76 | 77 | 其中prompt字段内容分为两种场景: 78 | 1. 单轮对话:仅包括单轮对话的`问题`; 79 | 2. 
多轮对话:包含前几轮对话的问答及最后一轮的提问。 80 | 81 | 开源Anthropic的helpful&harmless的数据满足DPO训练需求,可参考RLHF章节下载相应训练数据。 82 | 83 | ### Math 84 | 85 | 首先,准备一个要训练的数学数据集,并将其组织成一个 JSON 文件。JSON 文件中的每一行应该表示一个样本,格式如下: 86 | 87 | ``` 88 | {"eval_func": "math_rule", "prompt": prompt, "answer": answer} 89 | ``` 90 | 91 | 以 `openai/gsm8k` 数据为例,使用以下代码将数据集存储在 `$DATASET_ROOT/math/train.jsonl` 中: 92 | 93 | ``` 94 | cd ${CHATLEARN}/examples/megatron/ 95 | DATASET_ROOT=path-to-dataset-root 96 | python data/prepare_data_math.py $DATASET_ROOT 97 | ``` 98 | 99 | -------------------------------------------------------------------------------- /docs/zh/tutorial/ems.md: -------------------------------------------------------------------------------- 1 | # 高效显存复用(EMS) 2 | 3 | ChatLearn 中提供高效显存复用 (Efficient Memory Sharing, EMS) 功能来大幅减少训练过程中的显存占用。 4 | EMS 功能可以充分利用有限资源来训练更大规模的模型,也可以利用节约的显存来调整模型的并行策略或者增大 batch size,从而提升整体的训练效率。 5 | 6 | ChatLearn 中多个模型共享相同的资源进行训练或推理时,打开 EMS 功能,可以让这些模型按序共享使用显存: 7 | - 每个模型初始化完成后,将常驻显存的各类 tensor/buffer(包括 weight, grad buffer, optim states 等)卸载到内存或者直接释放,清空该模型占用的显存; 8 | - 某个模型训练或推理前,先从内存中加载或者重建 tensor/buffer,然后进行训练或推理; 9 | - 训练或推理完成后,将常驻显存的 tensor/buffer 卸载到内存或者直接释放,再次清空该模型占用的显存。 10 | 11 | 重复如上流程,多个模型间按序共享使用显存,最大化显存利用效率。 12 | 13 | ## 功能用法 14 | 用户通过配置每个模型的 `free_memory` (bool 类型, 默认为 False) 参数来指定是否开启 EMS 功能。 15 | 可以直接修改 `rlhf.yaml` 中每个模型的 `free_memory` 配置,例如打开 policy 模型的 EMS 功能: 16 | 17 | ```yaml 18 | policy: 19 | model_config_file: old_policy_inference.yaml 20 | ... 
21 | free_memory: ${free_memory_policy:True} 22 | ``` 23 | 24 | 用户也可以在训练脚本中通过配置环境变量来启动 EMS 功能: 25 | - policy 模型:`export free_memory_policy=True` 26 | - reference 模型:`export free_memory_reference=True` 27 | - reward 模型:`export free_memory_reward=True` 28 | - value 模型:`export free_memory_value=True` 29 | - ppo_policy 模型:`export free_memory_ppo_policy=True` 30 | - ppo_value 模型:`export free_memory_ppo_value=True` 31 | 32 | 完整示例可以参考 [llama2 配置](https://github.com/alibaba/ChatLearn/blob/main/examples/megatron/configs/llama2/rlhf.yaml)。 33 | -------------------------------------------------------------------------------- /docs/zh/tutorial/evaluator.md: -------------------------------------------------------------------------------- 1 | # Evaluator 2 | 3 | 本文档将介绍如何进行模型评估。用户可以使用 `EvalEngine` 单独对模型进行评估,也可以在训练 Engine 里配置 evaluator 在训练的过程中进行评估。 4 | 5 | ```python 6 | def eval_flow(batch): 7 | p = policy.forward_step(batch) 8 | r = reward.eval_step(p) 9 | r1 = reward2.eval_step(p) 10 | return r, r1 11 | evaluator = Evaluator(eval_flow) 12 | evaluator.set_dataset(prompts) 13 | results = evaluator.eval() 14 | ``` 15 | 在上述例子中,我们构建了一个三个模型的评估flow,用户可以自定义 evaluation 的执行 flow。 16 | evaluator.eval 返回的结果是一个 dict 类型,key 是 model_name,value 是一个 list,包含 batch 的计算结果。 17 | 在上述例子中,eval 返回的结果为 {"reward": [batch0, batch1, batch2], "reward2": [batch0, batch1, batch2]} 18 | -------------------------------------------------------------------------------- /docs/zh/tutorial/profile.md: -------------------------------------------------------------------------------- 1 | # Profile 2 | 3 | ChatLearn 提供了两种 Profile 的方式: 4 | 1. torch profiler 5 | 2. 
nsys 6 | 7 | 注意:对于大模型,profile 的结果会非常大,建议在 profile 的时候减小模型尺寸。 8 | 9 | ## Torch Profiler 10 | 11 | 用户可以在系统的主配置文件中配置 rlhf 配置 `profiler_dir: path_to_profile_dir` 来开启 Torch profiler。 12 | 13 | ```yaml 14 | profiler_dir: path_to_profile_dir 15 | ``` 16 | 17 | ## nsys 18 | 19 | 用户可以在系统的主配置文件中配置 rlhf 配置 `nsys: True` 来开启 nsys 的 profiler。 20 | 21 | ```yaml 22 | runtime: 23 | nsys: True 24 | ``` 25 | 26 | 在启动程序的时候,需要在执行命令前加上 nsys 的启动参数,可以参考下述命令 27 | 28 | ```bash 29 | nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s none --capture-range=cudaProfilerApi --capture-range-end=stop-shutdown --cudabacktrace=true -x true --force-overwrite true -o my_profile \ 30 | python train_rlhf.py XXX 31 | ``` 32 | 33 | -------------------------------------------------------------------------------- /docs/zh/tutorial/run.md: -------------------------------------------------------------------------------- 1 | # 分布式执行 2 | 3 | 本文档将介绍如何执行一个分布式训练任务。 4 | 5 | ## PAI DLC 分布式执行 6 | 7 | [阿里云 PAI DLC](https://www.aliyun.com/activity/bigdata/pai-dlc) [1]可以非常便捷高效地支持各种任务的训练。 8 | 9 | 以下为 PAI-DLC 创建任务的页面截图,选择作业类型为 `PyTorch`, 同时将上述命令修改后粘贴到`执行命令`窗口中, 设置节点镜像为 ChatLearn 编译后的镜像地址。在这个例子中,我们申请了2个节点,每个节点配置 8 卡 GPU。 10 | 11 | ![image.png](../../images/dlc_1.jpg) 12 | 13 | ![image.png](../../images/dlc_2.jpg) 14 | 15 | 16 | 17 | ## 其他环境分布式执行 18 | 19 | 如果您需要在非 PAI DLC 环境执行分布式任务,您需要配置以下环境变量。 20 | 21 | ```bash 22 | export MASTER_ADDR=xxx 23 | export MASTER_PORT=xxx 24 | export WORLD_SIZE=xxx 25 | export GPUS_PER_NODE=8 26 | export RANK=xx 27 | ``` 28 | 29 | ## reference 30 | 31 | 1. 
阿里云机器学习 PAI-DLC:[https://www.aliyun.com/activity/bigdata/pai-dlc](https://www.aliyun.com/activity/bigdata/pai-dlc) 32 | -------------------------------------------------------------------------------- /docs/zh/tutorial/tutorial_grpo_fsdp.md: -------------------------------------------------------------------------------- 1 | # 基于 FSDP 的端到端GRPO训练流程 2 | 3 | 本文档提供使用 ChatLearn、PyTorch FSDP 和 vLLM 框架来对Qwen3模型进行GRPO训练的快速开始指南。 4 | 5 | ## 环境配置 6 | 1. Docker镜像准备 7 | 我们建议在PAI [DSW](https://help.aliyun.com/zh/pai/user-guide/create-and-manage-dsw-instances/)/[DLC](https://help.aliyun.com/zh/pai/user-guide/create-a-training-task?spm=a2c4g.11186623.help-menu-30347.d_3_3_5_5.2dfb1925l3QjwG)中运行该示例,你需要填写如下镜像地址来启动实例: 8 | ```bash 9 | dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312 10 | ``` 11 | 12 | 可以使用vpc地址来加速镜像拉取速度,需要根据当前region信息来更改镜像地址。比如,启动在上海的DSW实例,可以使用如下镜像`dsw-registry-vpc.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312`。 13 | 14 | 2. 代码准备 15 | 16 | ```bash 17 | git clone https://github.com/alibaba/ChatLearn.git && cd ChatLearn 18 | ``` 19 | 20 | ## 数据准备 21 | 以[MATH-lighteval](https://www.modelscope.cn/datasets/AI-ModelScope/MATH-lighteval)数据集作为示例. 
22 | ```bash 23 | # 下载数据集 24 | mkdir -p dataset 25 | modelscope download --dataset AI-ModelScope/MATH-lighteval --local_dir dataset/MATH-lighteval 26 | # 数据集预处理 27 | python examples/fsdp/data/data_preprocess/math_lighteval.py --input_dir dataset/MATH-lighteval --local_dir dataset/MATH-lighteval 28 | # 下载模型权重 29 | modelscope download --model Qwen/Qwen3-8B --local_dir Qwen3-8B 30 | ``` 31 | 32 | ## 训练 33 | 运行以下命令开始训练: 34 | 35 | ```bash 36 | bash examples/fsdp/scripts/train_grpo_qwen3.sh 37 | ``` 38 | 39 | ## 使用 Wandb 监控 40 | 如需使用 Wandb 记录训练过程,请修改[train_grpo_qwen3.sh](../../../examples/fsdp/scripts/train_grpo_qwen3.sh)中的配置: 41 | 42 | ```bash 43 | export enable_wandb=True 44 | export wandb_project="Your-Wandb-Project-Name" 45 | export WANDB_API_KEY="Your-Wandb-api-key" 46 | ``` -------------------------------------------------------------------------------- /docs/zh/tutorial/tutorial_grpo_mcore.md: -------------------------------------------------------------------------------- 1 | # 基于 Mcore 的端到端GRPO训练流程 2 | 3 | 本文档提供使用 ChatLearn、Mcore 和 vLLM 框架来对Qwen2.5模型进行GRPO训练的快速开始指南。 4 | 5 | ## 环境配置 6 | 1. Docker镜像准备 7 | 我们建议在PAI [DSW](https://help.aliyun.com/zh/pai/user-guide/create-and-manage-dsw-instances/)/[DLC](https://help.aliyun.com/zh/pai/user-guide/create-a-training-task?spm=a2c4g.11186623.help-menu-30347.d_3_3_5_5.2dfb1925l3QjwG)中运行该示例,你需要填写如下镜像地址来启动实例: 8 | ```bash 9 | dsw-registry.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312 10 | ``` 11 | 12 | 可以使用vpc地址来加速镜像拉取速度,需要根据当前region信息来更改镜像地址。比如,启动在上海的DSW实例,可以使用如下镜像`dsw-registry-vpc.cn-shanghai.cr.aliyuncs.com/pai-training-algorithm/chatlearn:torch2.6.0-vllm0.8.5-ubuntu24.04-cuda12.6-py312`。 13 | 14 | 2. 
代码准备 15 | 16 | ```bash 17 | git clone https://github.com/NVIDIA/Megatron-LM.git 18 | cd Megatron-LM && git checkout 6ba97dd37150a6bfba03d31808674211cf2a4d0d 19 | git clone https://github.com/alibaba/ChatLearn.git && cd ChatLearn 20 | ``` 21 | 22 | ## 数据准备 23 | 以[MATH-lighteval](https://www.modelscope.cn/datasets/AI-ModelScope/MATH-lighteval)数据集作为示例. 24 | ```bash 25 | # 下载数据集 26 | mkdir -p dataset 27 | modelscope download --dataset AI-ModelScope/MATH-lighteval --local_dir dataset/MATH-lighteval 28 | # 数据集预处理 29 | python examples/fsdp/data/data_preprocess/math_lighteval.py --input_dir dataset/MATH-lighteval --local_dir dataset/MATH-lighteval 30 | # 下载模型权重 31 | modelscope download --model Qwen/Qwen2.5-7B-Instruct --local_dir Qwen2.5-7B-Instruct 32 | ``` 33 | 34 | ## 模型转换 35 | 36 | 模型格式转换可以参考 [Pai-Megatron-Patch](https://github.com/alibaba/Pai-Megatron-Patch) 项目提供的转换脚本。 37 | 高性能分布式权重转换可以参考:https://github.com/alibaba/Pai-Megatron-Patch/tree/main/toolkits/distributed_checkpoints_convertor 38 | 39 | 运行`hf2mcore_qwen2.5_convertor.sh`脚本,需要传入的参数列表如下 40 | ``` 41 | MODEL_SIZE=$1 # 模型参数:0.5B/1.5B/3B/7B/14B/32B/72B 42 | SOURCE_CKPT_PATH=$2 # 源路径 43 | TARGET_CKPT_PATH=$3 # 目标路径 44 | TP=$4 # 模型并行度 45 | PP=$5 # 流水并行度 46 | PR=$6 # 转换精度 47 | USE_TE=$7 # 是否使用Transformer Engine建模 48 | mg2hf=$8 # 是否执行mcore2hf转换 49 | HG_CKPT_PATH=$9 # HF的CKPT的路径 50 | ``` 51 | 52 | 例如,使用下述脚本将7B量级的Qwen2.5的Huggingface格式的模型转换到MCore格式 53 | ```bash 54 | git clone --recurse-submodules https://github.com/alibaba/Pai-Megatron-Patch.git 55 | cd ~/Pai-Megatron-Patch/toolkits/model_checkpoints_convertor/qwen 56 | bash hf2mcore_qwen2.5_convertor.sh \ 57 | 7B \ 58 | /mnt/qwen-ckpts/Qwen2.5-7B-Instruct \ 59 | /mnt/qwen-ckpts/Qwen2.5-7B-Instruct-hf-to-mcore-tp4-pp1 \ 60 | 4 \ 61 | 1 \ 62 | bf16 \ 63 | true \ 64 | false 65 | ``` 66 | 67 | ## 训练 68 | 运行以下命令开始训练: 69 | 70 | ```bash 71 | export MEGATRON_PATH="your megatron path" 72 | bash examples/mcore/scripts/train_grpo_qwen2_5.sh 73 | ``` 74 | 75 | ## 使用 Wandb 监控 76 | 
如需使用 Wandb 记录训练过程,请修改[train_grpo_qwen2_5.sh](../../../examples/mcore/scripts/train_grpo_qwen2_5.sh)中的配置: 77 | 78 | ```bash 79 | export enable_wandb=True 80 | export wandb_project="Your-Wandb-Project-Name" 81 | export WANDB_API_KEY="Your-Wandb-api-key" 82 | ``` -------------------------------------------------------------------------------- /docs/zh/tutorial/tutorial_qwen.md: -------------------------------------------------------------------------------- 1 | # 基于 Qwen 模型的端到端训练教程 2 | 3 | 本文档介绍基于 ChatLearn, DeepSpeed 框架和 Qwen 模型进行 DPO 训练。 4 | 5 | **以下是这个 Tutorial 脚本中使用的通用环境变量集合:** 6 | 7 | | ENV | 含义 | 8 | | --- |-------------------------------------------------------------------------------------------------------------------------------| 9 | | `CHATLEARN` | ChatLearn 代码仓库 clone 存放的位置 [https://github.com/alibaba/ChatLearn.git](https://github.com/alibaba/ChatLearn.git) | 10 | | `DATASET_ROOT` | 存放训练数据集合的根目录 | 11 | 12 | 13 | ## Setup: 镜像、代码、数据准备 14 | 15 | ### 镜像和代码 16 | 17 | 请参考 [镜像和代码准备](../installation.md)。 18 | 19 | ### 数据 20 | 21 | qwen2 要求的数据格式为chatml 22 | 23 | ``` 24 | {"type": "chatml", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Tell me something about large language models."}, {"role": "assistant", "content": "Large language models are a type of language model that is trained on a large corpus of text data. They are capable of generating human-like text and are used in a variety of natural language processing tasks..."}], "source": "unknown"} 25 | ``` 26 | 通过以下脚本可以将 `Dahoas/full-hh-rlhf` 转换为 chatml 格式的数据, 并存储 `$DATASET_ROOT/alignment/train.jsonl` 文件. 
27 | 28 | ```bash 29 | cd ${CHATLEARN}/examples/huggingface/ 30 | DATASET_ROOT=path-to-dataset-root 31 | python data/preprocess_data_chatml.py $DATASET_ROOT 32 | ``` 33 | 34 | 35 | ### DPO 36 | 37 | 以下是一个 Qwen2-7B 的 DPO 训练范例。 38 | 在这个例子中,用户需要设置 `policy_model_path` 为 初始化模型 checkpoint 路径,Policy 模型和 Reference 模型将以这个 checkpoint 初始化。 39 | 40 | ``` 41 | export CHATLEARN=path-to-chatlearn 42 | export DATASET_PATH=$DATASET_ROOT/alignment/train.jsonl 43 | export policy_model_path=path-to-qwen2-ckpt 44 | cd ${CHATLEARN}/examples/huggingface/ 45 | bash scripts/train_dpo_qwen.sh 46 | ``` 47 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/examples/__init__.py -------------------------------------------------------------------------------- /examples/fsdp/configs/grpo/base.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | -------------------------------------------------------------------------------- /examples/fsdp/configs/grpo/grpo.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - "tensorboards" 7 | - ".nfs*" 8 | 9 | 10 | models: 11 | policy: 12 | model_config_file: vllm_policy_inference.yaml 13 | num_gpu: ${num_device:1} 14 | trainable: False 15 | batch_generation: 16 | ranking: ${batch_generation_ranking:False} 17 | min_prompt_length: ${batch_generation_min_prompt_length:0} 18 | free_memory: ${free_memory_policy:True} 19 | generation_batch_size: ${vllm_generation_batch_size:256} 20 | 21 | reward: 22 | model_config_file: base.yaml 23 | num_cpu: 2 # must set devices 24 | generation_batch_size: ${vllm_generation_batch_size:256} 25 | 26 | ref_policy: 27 | model_config_file: 
reference.yaml 28 | num_gpu: ${num_device:1} 29 | gpu_per_process: 1 30 | fsdp_size: -1 31 | trainable: False 32 | free_memory: ${free_memory_reference:True} 33 | generation_batch_size: ${ref_generation_batch_size:8} 34 | 35 | 36 | policy_trainer: 37 | model_config_file: policy_trainer.yaml 38 | num_gpu: ${num_device:1} 39 | gpu_per_process: 1 40 | fsdp_size: -1 41 | sp_size: ${sp_size:1} 42 | trainable: True 43 | free_memory: ${free_memory_ppo_policy:True} 44 | generation_batch_size: ${trainer_generation_batch_size:8} 45 | 46 | runtime: 47 | colocation: 48 | - policy,policy_trainer,ref_policy 49 | data_path: ${train_data_path} 50 | eval_data_path: ${eval_data_path} 51 | output_dir: ${output_dir} 52 | exp_name: ${exp_name} 53 | num_episode: ${num_episode:200} 54 | sample_per_episode: ${sample_per_episode:1024} 55 | train_micro_batch_size: ${train_micro_batch_size:1} 56 | train_global_batch_size: ${train_global_batch_size:256} 57 | save_episode_interval: ${save_episode_interval:20} 58 | max_relay_episode: 2 # for enable grpo adv compute 59 | eval_episode_interval: ${eval_episode_interval:5} 60 | log_config_file: log.yaml 61 | data_checkpoint_path: ${data_checkpoint_path} 62 | enable_eval_before_training: ${enable_eval_before_training:False} 63 | -------------------------------------------------------------------------------- /examples/fsdp/configs/grpo/log.yaml: -------------------------------------------------------------------------------- 1 | log_dir: ${log_dir:} 2 | 3 | enable_tensorboard: ${enable_tensorboard:False} 4 | tensorboard_dir: ${tensorboard_dir:<>} 5 | 6 | enable_wandb: ${enable_wandb:False} 7 | wandb_dir: ${wandb_dir} 8 | wandb_project: ${wandb_project} 9 | wandb_id: ${exp_name} 10 | wandb_name: ${exp_name} 11 | wandb_resume: ${WANDB_RESUME:allow} 12 | 13 | # export WANDB_DISABLE_CODE="true" 14 | # export WANDB_IGNORE_GLOBS="*.patch" 15 | -------------------------------------------------------------------------------- 
/examples/fsdp/configs/grpo/policy_trainer.yaml: -------------------------------------------------------------------------------- 1 | includes: 2 | - base.yaml 3 | pretrain_or_model: ${model_path} 4 | learning_rate: 2e-6 5 | grad_clip: 1 6 | gradient_checkpointing: True 7 | # for grpo algorithm 8 | pos_clip_ratio: 0.2 9 | negative_clip_ratio: 0.2 10 | save_hf: True -------------------------------------------------------------------------------- /examples/fsdp/configs/grpo/reference.yaml: -------------------------------------------------------------------------------- 1 | includes: 2 | - base.yaml 3 | pretrain_or_model: ${model_path} -------------------------------------------------------------------------------- /examples/fsdp/configs/grpo/vllm_policy_inference.yaml: -------------------------------------------------------------------------------- 1 | includes: 2 | - base.yaml 3 | 4 | # model partition 5 | tensor_model_parallel_size: ${tensor_model_parallel_size:2} 6 | pipeline_model_parallel_size: 1 7 | num_inference_per_prompt: ${num_inference_per_prompt:8} 8 | seq_length: ${seq_length:1024} 9 | max_new_tokens: ${max_new_tokens:1023} 10 | 11 | # sampling params 12 | temperature: ${policy_temperature:1.0} 13 | top_p: ${policy_top_p:0.9} 14 | top_k: ${policy_top_k:-1} 15 | presence_penalty: ${policy_presence_penalty:0.0} 16 | frequency_penalty: ${policy_frequency_penalty:0.0} 17 | repetition_penalty: ${policy_repetition_penalty:1.0} 18 | 19 | eval_temperature: ${policy_eval_temperature:0.6} 20 | eval_top_k: ${policy_eval_top_k:-1} 21 | eval_top_p: ${policy_eval_top_p:0.9} 22 | eval_presence_penalty: ${policy_eval_presence_penalty:0.0} 23 | eval_frequency_penalty: ${policy_eval_frequency_penalty:0.0} 24 | eval_repetition_penalty: ${policy_eval_repetition_penalty:1.0} 25 | 26 | # dataset 27 | vllm_prompt_key: ${vllm_prompt_key:prompt} 28 | vllm_input_ids_key: ${vllm_input_ids_key:input_ids} 29 | enable_thinking: ${enable_thinking:False} 30 | 31 | # scheduler config 32 
| max_num_batched_tokens: ${max_num_batched_tokens:32768} 33 | max_seq_len_to_capture: ${max_seq_len_to_capture:32768} 34 | enable_stage_resume: ${enable_policy_stage_resume:False} 35 | 36 | # cache config 37 | gpu_memory_utilization: ${gpu_memory_utilization:0.85} 38 | enforce_eager: False 39 | 40 | # tokenizer path 41 | tokenizer: ${model_path} -------------------------------------------------------------------------------- /examples/fsdp/data/data_preprocess/gsm8k.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess the GSM8k dataset to json format 3 | """ 4 | import re 5 | import os 6 | import argparse 7 | 8 | import datasets 9 | 10 | 11 | def extract_solution(solution_str): 12 | solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) 13 | assert solution is not None 14 | final_solution = solution.group(0) 15 | final_solution = final_solution.split('#### ')[1].replace(',', '') 16 | return final_solution 17 | 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--input_dir', default=None) 22 | parser.add_argument('--local_dir', default='~/data/gsm8k') 23 | 24 | args = parser.parse_args() 25 | 26 | data_source = 'openai/gsm8k' 27 | data_dir = 'openai/gsm8k' if args.input_dir is None else args.input_dir 28 | 29 | dataset = datasets.load_dataset(data_dir, 'main') 30 | 31 | train_dataset = dataset['train'] 32 | test_dataset = dataset['test'] 33 | 34 | # instruction_following = "Let's think step by step and output the final answer after \"####\"." 35 | instruction_following = "Let's think step by step and output the final answer within \\boxed{}." 
class VLLMPromptPipeline(Dataset):
    """Tokenize chat-format prompt records into vLLM-ready samples.

    Expected input record format (one dict per sample):
        {
            "data_source": data_source,
            "prompt": [{
                "role": "user",
                "content": question,
            }],
            "ability": "math",
            "reward_model": {
                "style": "rule",
                "ground_truth": solution
            },
            "extra_info": {
                'split': split,
                'index': idx,
                'answer': answer_raw,
                "question": question_raw,
            }
        }

    Each stored sample (self.data element) has the format:
        {"input_ids": prompt_ids, "prompt": prompt,
         "data_source": data_source, "ground_truth": ground_truth}
    """

    def __init__(
        self,
        data_list: List[Dict],
        seq_length: int,
        tokenizer: "AutoTokenizer" = None,
        num_inference_per_prompt: int = 1,
        enable_thinking: bool = False,
    ):
        """Build the dataset from raw prompt records.

        Args:
            data_list: raw records in the format documented on the class.
            seq_length: maximum model sequence length; prompts whose token
                count is >= seq_length are silently dropped so generation
                always has room for new tokens.
            tokenizer: tokenizer providing ``apply_chat_template`` and
                ``encode`` (a HuggingFace AutoTokenizer in practice).
            num_inference_per_prompt: how many copies of each prompt to keep
                (one per rollout).
            enable_thinking: forwarded to ``apply_chat_template`` (Qwen3
                "thinking" mode switch).
        """
        super().__init__()

        self.tokenizer = tokenizer
        self.data = []

        for data_item in data_list:
            prompt = data_item["prompt"]
            data_source = data_item.get("data_source", "")
            ground_truth = data_item['reward_model']['ground_truth']
            # Chat-format prompts (list of role/content dicts) are rendered
            # into a single string with the model's chat template before
            # tokenization; plain-string prompts are used as-is.
            if isinstance(prompt, list):
                prompt = self.tokenizer.apply_chat_template(
                    prompt, tokenize=False, add_generation_prompt=True,
                    enable_thinking=enable_thinking)
            input_ids = self.tokenizer.encode(prompt)
            processed_data = {"input_ids": input_ids, "prompt": prompt,
                              "data_source": data_source, "ground_truth": ground_truth}
            # Keep num_inference_per_prompt copies of every prompt that fits
            # within the sequence length; over-length prompts are dropped.
            if seq_length > len(input_ids):
                self.data.extend([processed_data] * num_inference_per_prompt)

    def __getitem__(self, ix: int) -> Dict:
        return self.data[ix]

    def __len__(self) -> int:
        return len(self.data)

    def collate_fn(self, samples: List[Dict]) -> Dict:
        """Collate a list of sample dicts into one dict of lists (per key)."""
        collate_dict = defaultdict(list)

        # Append each sample's value to the list for its key.
        for sample in samples:
            for key, value in sample.items():
                collate_dict[key].append(value)

        return collate_dict
def calculate_grpo_loss(
    log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    diff_clip_ratio: float = 10,
    pos_clip_ratio: float = 0.2,
    negative_clip_ratio: float = 0.2,
    final_clip_ratio: float = 0.01,
):
    """Per-token GRPO policy-gradient loss with PPO-style ratio clipping.

    Args:
        log_probs: current-policy log-probs, shape [batch, seq].
        old_log_probs: behavior-policy log-probs, same shape.
        advantages: per-sample advantages, shape [batch]; broadcast over seq.
        diff_clip_ratio: max log-ratio before exp() to avoid overflow.
        pos_clip_ratio: upper PPO clip bound offset (ratio <= 1 + pos).
        negative_clip_ratio: lower PPO clip bound offset (ratio >= 1 - neg).
        final_clip_ratio: hard upper bound applied to the loss value itself.

    Returns:
        Contiguous per-token loss tensor, shape [batch, seq].

    Raises:
        ValueError: if any loss element is NaN.
    """
    # Clip the log-ratio before exponentiation so exp() cannot overflow.
    logprobs_diff = torch.clamp(log_probs - old_log_probs, max=diff_clip_ratio)

    ratio = torch.exp(logprobs_diff)
    pg_loss = -advantages.unsqueeze(-1) * ratio
    # Standard PPO surrogate pair: unclipped vs. ratio clipped into
    # [1 - negative_clip_ratio, 1 + pos_clip_ratio]; take the pessimistic max.
    pg_loss_2 = -advantages.unsqueeze(-1) * torch.clamp(ratio, 1 - negative_clip_ratio, 1 + pos_clip_ratio)
    pg_loss_clip = torch.max(pg_loss, pg_loss_2)
    # Final hard cap on the per-token loss value.
    loss = torch.min(pg_loss_clip, torch.full_like(pg_loss, final_clip_ratio))

    # Fail loudly on NaN; a plain `assert` would be stripped under `python -O`.
    if torch.isnan(loss).any():
        raise ValueError("pg loss is nan")

    return loss.contiguous()
class RuleReward(BaseModule):
    """Rule-based reward module: scores generated answers against ground truth."""

    def setup(self):
        """One-time initialization of metric bookkeeping."""
        self.stats = {}
        self._metric_prefix = "rulereward"

    def _forward_step(self, data: Dict) -> Dict:
        """Compute rule-based rewards for one batch.

        Args:
            data: batch dict with equal-length lists under "str_outputs",
                "data_source" and "ground_truth".

        Returns:
            {"rule_rewards": FloatTensor of shape [batch, 1]}.
        """
        # str_prompts_list = data["str_prompts"]
        str_outputs_list = data["str_outputs"]
        data_source_list = data["data_source"]
        ground_truth_list = data["ground_truth"]
        self._logger.info(f"RuleReward _forward_step Num of request: {len(str_outputs_list)}")

        reward_tensor = torch.zeros([len(str_outputs_list), 1], dtype=torch.float32)

        # Score each output with the reward function matching its data source.
        for i, (str_output, data_source, ground_truth) in enumerate(
                zip(str_outputs_list, data_source_list, ground_truth_list)):
            compute_score_fn = self.select_rule_reward_score_fn(data_source)
            reward_tensor[i] = compute_score_fn(str_output, ground_truth)

        res_dict = {"rule_rewards": reward_tensor}
        return res_dict

    def forward_step(self, data: Dict, iteration=0) -> Dict:
        """Training-time entry point: compute rewards and record the batch mean."""
        res_dict = self._forward_step(data)

        # collect stats
        rule_rewards = res_dict["rule_rewards"]
        train_reward_score = rule_rewards.mean().item()
        train_reward_stats = {
            "train_reward_score": train_reward_score,
        }
        self._metric_list.append(train_reward_stats)
        return res_dict

    def eval_forward(self, data: Dict) -> Dict:
        """Evaluation-time entry point; same scoring, no metric logging."""
        return self._forward_step(data)

    def select_rule_reward_score_fn(self, data_source: str):
        """Map a dataset identifier to its rule-based scoring function.

        Raises:
            NotImplementedError: for data sources without a registered scorer.
        """
        if data_source in ['openai/gsm8k', 'DigitalLearningGmbH/MATH-lighteval']:
            return math.compute_score
        # Name the offending source so unsupported datasets are easy to spot.
        raise NotImplementedError(f"no rule reward score fn for data_source: {data_source}")
eval_data_path=${CHATLEARN}/dataset/MATH-lighteval/test.json 11 | export data_checkpoint_path=${output_dir}/data_checkpoint_path/ 12 | log_file=$output_dir/log_${RANK}.log 13 | mkdir -p $output_dir/ 14 | export log_dir=${output_dir} 15 | export wandb_dir=${output_dir} 16 | 17 | cd $CHATLEARN/examples/fsdp/ 18 | source scripts/base_env.sh 19 | 20 | # Env setup 21 | # export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True 22 | export RAY_DEDUP_LOGS=1 23 | export NCCL_NVLS_ENABLE=0 24 | 25 | # log setup 26 | export enable_wandb=False 27 | export wandb_project="grpo-exp" 28 | export WANDB_API_KEY="wandb-api-key" 29 | 30 | #Setup sequence_parallel 31 | export sp_size=1 32 | 33 | #VLLM setup 34 | export VLLM_USE_RAY_SPMD_WORKER=1 35 | export VLLM_USE_RAY_COMPILED_DAG=1 36 | 37 | export tensor_model_parallel_size=2 38 | export policy_temperature=1.0 39 | export policy_top_p=1.0 40 | export policy_top_k=-1 41 | export policy_eval_temperature=0.6 42 | export policy_eval_top_p=0.95 43 | export policy_eval_top_k=20 44 | 45 | export seq_length=2048 46 | export max_new_tokens=2048 47 | export max_seq_len_to_capture=2348 48 | export num_inference_per_prompt=32 49 | export train_global_batch_size=2048 50 | export sample_per_episode=2048 51 | export vllm_generation_batch_size=128 52 | export train_micro_batch_size=16 53 | export gpu_memory_utilization=0.85 54 | 55 | export enable_eval_before_training=True 56 | export num_episode=20 57 | export eval_episode_interval=5 58 | export save_episode_interval=20 59 | 60 | python entry/train_grpo.py -c configs/grpo/grpo.yaml 2>&1 | tee ${log_file}.log ; exit ${PIPESTATUS[0]} 61 | -------------------------------------------------------------------------------- /examples/fsdp/scripts/train_grpo_qwen3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | 4 | # set path 5 | export CHATLEARN=$(pwd) 6 | export model_path="${CHATLEARN}/Qwen3-8B" 7 | export exp_name=qwen3-grpo 8 | 
export output_dir=${CHATLEARN}/output/${exp_name}
export train_data_path=${CHATLEARN}/dataset/MATH-lighteval/train.json
export eval_data_path=${CHATLEARN}/dataset/MATH-lighteval/test.json
export data_checkpoint_path=${output_dir}/data_checkpoint_path/
log_file=$output_dir/log_${RANK}.log
mkdir -p $output_dir/
export log_dir=${output_dir}
export wandb_dir=${output_dir}

cd $CHATLEARN/examples/fsdp/
source scripts/base_env.sh

# Env setup
# export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export RAY_DEDUP_LOGS=1
export NCCL_NVLS_ENABLE=0

# log setup
export enable_wandb=False
export wandb_project="grpo-exp"
export WANDB_API_KEY="wandb-api-key"

#Setup sequence_parallel
export sp_size=1

#VLLM setup
export VLLM_USE_RAY_SPMD_WORKER=1
export VLLM_USE_RAY_COMPILED_DAG=1

export tensor_model_parallel_size=1
export policy_temperature=1.0
export policy_top_p=1.0
export policy_top_k=-1
export policy_eval_temperature=0.6
export policy_eval_top_p=0.95
export policy_eval_top_k=20

export seq_length=2048
export max_new_tokens=2048
export max_seq_len_to_capture=2348
export num_inference_per_prompt=32
export train_global_batch_size=2048
export sample_per_episode=2048
export vllm_generation_batch_size=256
export train_micro_batch_size=8
export gpu_memory_utilization=0.80

export enable_eval_before_training=False
export num_episode=200
export eval_episode_interval=5
export save_episode_interval=400
# for qwen3 where enable_thinking
export enable_thinking=False

# NOTE: log_file already ends in ".log"; appending ".log" again produced
# "log_${RANK}.log.log" — tee directly into ${log_file}.
python entry/train_grpo.py -c configs/grpo/grpo.yaml 2>&1 | tee ${log_file} ; exit ${PIPESTATUS[0]}
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/examples/fsdp/utils/__init__.py -------------------------------------------------------------------------------- /examples/fsdp/utils/rule_reward_score/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/ChatLearn/8797214eeb3d7777a75ec09efd794844cb12dfe2/examples/fsdp/utils/rule_reward_score/__init__.py -------------------------------------------------------------------------------- /examples/mcore/configs/grpo/base.yaml: -------------------------------------------------------------------------------- 1 | includes: 2 | - log.yaml 3 | 4 | seed: 1234 5 | tokenizer_type: ${tokenizer_type:NullTokenizer} 6 | patch_tokenizer_type: ${patch_tokenizer_type:NullTokenizer} 7 | tokenizer_model: ${tokenizer_model:"None"} 8 | vocab_file: ${vocab_file:"None"} 9 | merge_file: ${merge_file:"None"} 10 | vocab_size: ${vocab_size:32000} 11 | extra_vocab_size: ${extra_vocab_size:421} 12 | make_vocab_size_divisible_by: ${make_vocab_size_divisible_by:128} 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /examples/mcore/configs/grpo/grpo_qwen2_5.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - "tensorboards" 7 | - ".nfs*" 8 | 9 | 10 | models: 11 | policy: 12 | model_config_file: vllm_policy_inference.yaml 13 | num_gpu: ${num_device:1} 14 | trainable: False 15 | batch_generation: 16 | ranking: ${batch_generation_ranking:False} 17 | min_prompt_length: ${batch_generation_min_prompt_length:0} 18 | free_memory: ${free_memory_policy:True} 19 | generation_batch_size: 
${vllm_generation_batch_size:8} 20 | 21 | reward: 22 | model_config_file: base.yaml 23 | num_cpu: 2 # must set devices 24 | generation_batch_size: ${vllm_generation_batch_size:8} 25 | 26 | ref_policy: 27 | model_config_file: policy_trainer_qwen2_5.yaml 28 | num_gpu: ${num_device:1} 29 | gpu_per_process: 1 30 | trainable: False 31 | free_memory: ${free_memory_reference:True} 32 | sync_frequency: ${sync_frequency:-1} 33 | generation_batch_size: ${trainer_generation_batch_size:8} 34 | 35 | policy_trainer: 36 | model_config_file: policy_trainer_qwen2_5.yaml 37 | num_gpu: ${num_device:1} 38 | gpu_per_process: 1 39 | trainable: True 40 | free_memory: ${free_memory_ppo_policy:True} 41 | generation_batch_size: ${trainer_generation_batch_size:8} 42 | 43 | runtime: 44 | colocation: 45 | - policy,policy_trainer,ref_policy 46 | data_path: ${train_data_path} 47 | eval_data_path: ${eval_data_path} 48 | output_dir: ${output_dir} 49 | exp_name: ${exp_name} 50 | num_episode: ${num_episode:200} 51 | sample_per_episode: ${sample_per_episode:1024} 52 | train_micro_batch_size: ${train_micro_batch_size:1} 53 | train_global_batch_size: ${train_global_batch_size:256} 54 | save_episode_interval: ${save_episode_interval:20} 55 | max_relay_episode: 2 # for enable grpo adv compute 56 | eval_episode_interval: ${eval_episode_interval:5} 57 | log_config_file: log.yaml 58 | data_checkpoint_path: ${data_checkpoint_path} 59 | -------------------------------------------------------------------------------- /examples/mcore/configs/grpo/grpo_qwen3.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - "tensorboards" 7 | - ".nfs*" 8 | 9 | 10 | models: 11 | policy: 12 | model_config_file: vllm_policy_inference.yaml 13 | num_gpu: ${num_device:1} 14 | trainable: False 15 | batch_generation: 16 | ranking: ${batch_generation_ranking:False} 17 | min_prompt_length: 
${batch_generation_min_prompt_length:0} 18 | free_memory: ${free_memory_policy:True} 19 | generation_batch_size: ${vllm_generation_batch_size:8} 20 | 21 | reward: 22 | model_config_file: base.yaml 23 | num_cpu: 2 # must set devices 24 | generation_batch_size: ${vllm_generation_batch_size:8} 25 | 26 | ref_policy: 27 | model_config_file: policy_trainer_qwen3.yaml 28 | num_gpu: ${num_device:1} 29 | gpu_per_process: 1 30 | trainable: False 31 | free_memory: ${free_memory_reference:True} 32 | sync_frequency: ${sync_frequency:-1} 33 | generation_batch_size: ${trainer_generation_batch_size:8} 34 | 35 | policy_trainer: 36 | model_config_file: policy_trainer_qwen3.yaml 37 | num_gpu: ${num_device:1} 38 | gpu_per_process: 1 39 | trainable: True 40 | free_memory: ${free_memory_ppo_policy:True} 41 | generation_batch_size: ${trainer_generation_batch_size:8} 42 | 43 | runtime: 44 | colocation: 45 | - policy,policy_trainer,ref_policy 46 | data_path: ${train_data_path} 47 | eval_data_path: ${eval_data_path} 48 | output_dir: ${output_dir} 49 | exp_name: ${exp_name} 50 | num_episode: ${num_episode:200} 51 | sample_per_episode: ${sample_per_episode:1024} 52 | train_micro_batch_size: ${train_micro_batch_size:1} 53 | train_global_batch_size: ${train_global_batch_size:256} 54 | save_episode_interval: ${save_episode_interval:20} 55 | max_relay_episode: 2 # for enable grpo adv compute 56 | eval_episode_interval: ${eval_episode_interval:5} 57 | log_config_file: log.yaml 58 | data_checkpoint_path: ${data_checkpoint_path} 59 | -------------------------------------------------------------------------------- /examples/mcore/configs/grpo/log.yaml: -------------------------------------------------------------------------------- 1 | log_dir: ${log_dir:"None"} 2 | exp_name: ${exp_name:test} 3 | tensorboard_dir: ${tensorboard_dir:"None"} 4 | enable_tensorboard: ${enable_tensorboard:False} 5 | 6 | enable_wandb: ${enable_wandb:False} 7 | wandb_dir: ${wandb_dir} 8 | wandb_project: ${wandb_project} 9 | 
wandb_id: ${exp_name} 10 | wandb_name: ${exp_name} 11 | wandb_resume: ${WANDB_RESUME:allow} -------------------------------------------------------------------------------- /examples/mcore/configs/grpo/model_qwen2_5.yaml: -------------------------------------------------------------------------------- 1 | attention_dropout: ${attention_dropout:0.0} 2 | hidden_dropout: ${hidden_dropout:0.0} 3 | num_layers: ${policy_num_layers:28} 4 | hidden_size: ${policy_hidden_size:3584} 5 | num_attention_heads: ${policy_num_attention_heads:28} 6 | ffn_hidden_size: ${policy_ffn_hidden_size:18944} 7 | num_query_groups: ${policy_num_query_groups:4} 8 | seq_length: ${seq_length:2048} 9 | max_position_embeddings: ${max_position_embeddings:131072} 10 | swiglu: True 11 | normalization: ${normalization:RMSNorm} 12 | norm_epsilon: ${RMS_NORM_EPS:1e-6} 13 | use_rotary_position_embeddings: True 14 | position_embedding_type: ${position_embedding_type:rope} 15 | add_qkv_bias: true 16 | add_bias_linear: false 17 | rotary_percent: 1.0 18 | rotary_base: 1000000 19 | rotary_seq_len_interpolation_factor: 1 20 | group_query_attention: True 21 | use_legacy_models: false 22 | untie_embeddings_and_output_weights: True -------------------------------------------------------------------------------- /examples/mcore/configs/grpo/model_qwen3.yaml: -------------------------------------------------------------------------------- 1 | attention_dropout: ${attention_dropout:0.0} 2 | hidden_dropout: ${hidden_dropout:0.0} 3 | num_layers: ${policy_num_layers:28} 4 | hidden_size: ${policy_hidden_size:3584} 5 | num_attention_heads: ${policy_num_attention_heads:28} 6 | ffn_hidden_size: ${policy_ffn_hidden_size:18944} 7 | num_query_groups: ${policy_num_query_groups:4} 8 | seq_length: ${seq_length:2048} 9 | max_position_embeddings: ${max_position_embeddings:131072} 10 | swiglu: True 11 | normalization: ${normalization:RMSNorm} 12 | norm_epsilon: ${RMS_NORM_EPS:1e-6} 13 | use_rotary_position_embeddings: True 14 | 
position_embedding_type: ${position_embedding_type:rope} 15 | add_qkv_bias: false 16 | add_bias_linear: false 17 | qk_layernorm: true 18 | kv_channels: 128 19 | rotary_percent: 1.0 20 | rotary_base: 1000000 21 | rotary_seq_len_interpolation_factor: 1 22 | group_query_attention: True 23 | use_legacy_models: false 24 | untie_embeddings_and_output_weights: True -------------------------------------------------------------------------------- /examples/mcore/configs/grpo/policy_trainer_qwen2_5.yaml: -------------------------------------------------------------------------------- 1 | includes: 2 | - base.yaml 3 | - model_qwen2_5.yaml 4 | 5 | load: ${load} 6 | save: ${save_dir} 7 | save_interval: ${save_interval:10000} 8 | train_iters: ${train_iters:12000} 9 | tensor_model_parallel_size: ${tensor_model_parallel_size:2} 10 | pipeline_model_parallel_size: ${pipeline_model_parallel_size:1} 11 | distributed_backend: nccl 12 | 13 | clip_grad: ${clip_grad:1.0} 14 | log_interval: 1 15 | 16 | bf16: True 17 | use_checkpoint_opt_param_scheduler: False 18 | adam_beta1: 0.9 19 | adam_beta2: 0.95 20 | num_workers: 8 21 | init_method_std: 0.006 22 | 23 | #recompute_method: uniform 24 | #recompute_granularity: full 25 | recompute_granularity: selective 26 | sequence_parallel: True 27 | 28 | no_load_optim: True 29 | no_load_rng: True 30 | no_load_args: True 31 | no_load_scheduler: True 32 | finetune: True 33 | dummy: 0 34 | 35 | 36 | lr_decay_iters: ${lr_decay_iters:12000} 37 | lr_warmup_iters: ${policy_lr_warmup_iters:100} 38 | lr: ${policy_lr:0.00000008} 39 | min_lr: ${policy_min_lr:0.000000008} 40 | lr_decay_style: ${policy_lr_decay_style:linear} 41 | weight_decay: 0.01 42 | lr_freeze_episodes: ${policy_lr_freeze_episodes:0} 43 | 44 | 45 | init_kl_coef: ${init_kl_coef:0.0} 46 | target: 6 47 | horizon: 10000 48 | gamma: 1 49 | lam: 0.95 50 | cliprange: 0.2 51 | diff_clip_ratio: ${diff_clip_ratio:10} 52 | pos_clip_ratio: ${pos_clip_ratio:0.2} 53 | neg_clip_ratio: ${neg_clip_ratio:0.2} 
54 | final_clip_ratio: ${final_clip_ratio:3} 55 | logprob_cliprange: 10 56 | neg_cliprange2: 3 57 | cliprange_value: ${cliprange_value:0.1} 58 | scale_reward: ${scale_reward:null} 59 | clip_onlineness: ${clip_onlineness:0} 60 | cliprange_onlineness: ${cliprange_onlineness:0} 61 | train_to_compare_num_responses: ${train_to_compare_num_responses:2} -------------------------------------------------------------------------------- /examples/mcore/configs/grpo/policy_trainer_qwen3.yaml: -------------------------------------------------------------------------------- 1 | includes: 2 | - base.yaml 3 | - model_qwen3.yaml 4 | 5 | load: ${load} 6 | save: ${save_dir} 7 | save_interval: ${save_interval:10000} 8 | train_iters: ${train_iters:12000} 9 | tensor_model_parallel_size: ${tensor_model_parallel_size:2} 10 | pipeline_model_parallel_size: ${pipeline_model_parallel_size:1} 11 | distributed_backend: nccl 12 | 13 | clip_grad: ${clip_grad:1.0} 14 | log_interval: 1 15 | 16 | bf16: True 17 | use_checkpoint_opt_param_scheduler: False 18 | adam_beta1: 0.9 19 | adam_beta2: 0.95 20 | num_workers: 8 21 | init_method_std: 0.006 22 | 23 | #recompute_method: uniform 24 | #recompute_granularity: full 25 | recompute_granularity: selective 26 | sequence_parallel: True 27 | 28 | no_load_optim: True 29 | no_load_rng: True 30 | no_load_args: True 31 | no_load_scheduler: True 32 | finetune: True 33 | dummy: 0 34 | 35 | 36 | lr_decay_iters: ${lr_decay_iters:12000} 37 | lr_warmup_iters: ${policy_lr_warmup_iters:100} 38 | lr: ${policy_lr:0.00000008} 39 | min_lr: ${policy_min_lr:0.000000008} 40 | lr_decay_style: ${policy_lr_decay_style:linear} 41 | weight_decay: 0.01 42 | lr_freeze_episodes: ${policy_lr_freeze_episodes:0} 43 | 44 | 45 | init_kl_coef: ${init_kl_coef:0.0} 46 | target: 6 47 | horizon: 10000 48 | gamma: 1 49 | lam: 0.95 50 | cliprange: 0.2 51 | diff_clip_ratio: ${diff_clip_ratio:10} 52 | pos_clip_ratio: ${pos_clip_ratio:0.2} 53 | neg_clip_ratio: ${neg_clip_ratio:0.2} 54 | 
final_clip_ratio: ${final_clip_ratio:3} 55 | logprob_cliprange: 10 56 | neg_cliprange2: 3 57 | cliprange_value: ${cliprange_value:0.1} 58 | scale_reward: ${scale_reward:null} 59 | clip_onlineness: ${clip_onlineness:0} 60 | cliprange_onlineness: ${cliprange_onlineness:0} 61 | train_to_compare_num_responses: ${train_to_compare_num_responses:2} -------------------------------------------------------------------------------- /examples/mcore/configs/grpo/vllm_policy_inference.yaml: -------------------------------------------------------------------------------- 1 | includes: 2 | - base.yaml 3 | 4 | # model partition 5 | tensor_model_parallel_size: ${tensor_model_parallel_size:2} 6 | pipeline_model_parallel_size: 1 7 | num_inference_per_prompt: ${num_inference_per_prompt:8} 8 | seq_length: ${seq_length:1024} 9 | max_new_tokens: ${max_new_tokens:1023} 10 | 11 | # sampling params 12 | temperature: ${policy_temperature:1.0} 13 | top_p: ${policy_top_p:0.9} 14 | top_k: ${policy_top_k:-1} 15 | presence_penalty: ${policy_presence_penalty:0.0} 16 | frequency_penalty: ${policy_frequency_penalty:0.0} 17 | repetition_penalty: ${policy_repetition_penalty:1.0} 18 | 19 | eval_temperature: ${policy_eval_temperature:0.6} 20 | eval_top_k: ${policy_eval_top_k:-1} 21 | eval_top_p: ${policy_eval_top_p:0.9} 22 | eval_presence_penalty: ${policy_eval_presence_penalty:0.0} 23 | eval_frequency_penalty: ${policy_eval_frequency_penalty:0.0} 24 | eval_repetition_penalty: ${policy_eval_repetition_penalty:1.0} 25 | 26 | # dataset 27 | vllm_prompt_key: ${vllm_prompt_key:prompt} 28 | vllm_input_ids_key: ${vllm_input_ids_key:input_ids} 29 | enable_thinking: ${enable_thinking:False} 30 | 31 | # scheduler config 32 | max_num_batched_tokens: ${max_num_batched_tokens:32768} 33 | max_seq_len_to_capture: ${max_seq_len_to_capture:32768} 34 | enable_stage_resume: ${enable_policy_stage_resume:False} 35 | 36 | # cache config 37 | gpu_memory_utilization: ${gpu_memory_utilization:0.85} 38 | enforce_eager: False 39 
| 40 | tokenizer: ${tokenizer_path} 41 | 42 | -------------------------------------------------------------------------------- /examples/mcore/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== -------------------------------------------------------------------------------- /examples/mcore/models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def pad_to_max_len(all_tokens_right_padded, max_len, pad_value):
    """Right-pad a [batch, seq] token tensor to `max_len` with `pad_value`.

    Returns the tensor unchanged when its sequence dimension is already
    at least `max_len` (never truncates).
    """
    pad_length = max_len - all_tokens_right_padded.size(1)
    if pad_length <= 0:
        return all_tokens_right_padded
    # Pad the tensor on the right side up to the desired length.
    padded_tensor = torch.nn.functional.pad(
        all_tokens_right_padded, (0, pad_length), mode='constant', value=pad_value)
    return padded_tensor

def generate_loss_mask_position_ids(tokens: torch.Tensor, prompt_token_length: list, response_token_length: list):
    """Build the loss mask and position ids for [prompt|response|pad] rows.

    Args:
        tokens: [batch, seq] right-padded token tensor.
        prompt_token_length: per-row prompt lengths (len == batch).
        response_token_length: per-row response lengths (len == batch).

    Returns:
        (loss_mask, position_ids): int32 loss_mask that is 1 only on the
        response tokens of each row, and position_ids 0..seq-1 broadcast
        to every row.
    """
    # Loss is computed on response tokens only: prompt and padding stay 0.
    loss_mask = torch.zeros_like(tokens, dtype=torch.int32, device=tokens.device)
    for i, (prompt_len, response_len) in enumerate(zip(prompt_token_length, response_token_length)):
        loss_mask[i, prompt_len: prompt_len + response_len] = 1
    _, seq_len = tokens.size()
    position_ids = torch.arange(seq_len, dtype=torch.long, device=tokens.device)
    position_ids = position_ids.unsqueeze(0).expand_as(tokens)

    return loss_mask, position_ids
export num_device=$(($WORLD_SIZE * 8))

# data
export train_data_path="/mnt/data/datasets/MATH-lighteval/train.json"
export eval_data_path="/mnt/data/datasets/MATH-lighteval/test.json"
export patch_tokenizer_type=Qwen2Tokenizer
export extra_vocab_size=421
export tokenizer_path="/mnt/data/qwen-ckpts/Qwen2.5-7B-Instruct"
export load="/mnt/data/qwen-ckpts/Qwen2.5-7B-Instruct-hf-to-mcore-tp4-pp2"

# model
export max_position_embedding=131072
export policy_num_layers=28
export policy_hidden_size=3584
export policy_num_attention_heads=28
export policy_num_query_groups=4
export policy_ffn_hidden_size=18944
export tensor_model_parallel_size=4
export pipeline_model_parallel_size=2

# training
export final_clip_ratio=3
export clip_grad=1.0
export seed=3407
export policy_lr=2e-6
export policy_min_lr=2e-6
export eval_episode_interval=1
export save_interval=100000
# NOTE: save_episode_interval was previously exported twice with the same
# value; the duplicate export has been removed.
export save_episode_interval=10000
export num_episode=200
export sample_per_episode=2048
export train_micro_batch_size=8
export train_global_batch_size=2048
export vllm_generation_batch_size=128
export trainer_generation_batch_size=8
export train_iters=$(( ${num_episode} * ${sample_per_episode} / ${train_global_batch_size} ))
export policy_lr_warmup_iters=0
export lr_decay_iters=160000
export max_num_batched_tokens=65536
export gpu_memory_utilization=0.85

# vllm
export seq_length=2048
export max_new_tokens=2048
export max_seq_len_to_capture=2348
export num_inference_per_prompt=32
export policy_temperature=1.0
export policy_top_p=1.0
export policy_top_k=-1
export policy_eval_temperature=0.6
export policy_eval_top_p=0.95
export policy_eval_top_k=20

# logging and saving
export enable_tensorboard=True
export
enable_wandb=False 78 | export WANDB_API_KEY="wandb-api-key" 79 | export exp_name=qwen2_5_7B_lr${policy_lr}_mbs${train_micro_batch_size}_gbs${train_global_batch_size}_tp${tensor_model_parallel_size}_pp${pipeline_model_parallel_size}_${WORLD_SIZE}nodes 80 | export output_dir=${CHATLEARN}/output/${exp_name} 81 | mkdir -p $output_dir/ 82 | export log_dir=${output_dir}/logs 83 | mkdir -p $log_dir 84 | log_file=$log_dir/${exp_name}_rank${RANK}.log 85 | export tensorboard_dir=${output_dir}/tensorboard 86 | export wandb_dir=${output_dir} 87 | export save_dir=${output_dir} 88 | 89 | cd $CHATLEARN/examples/mcore 90 | python entry/train_grpo.py -c configs/grpo/grpo_qwen2_5.yaml 2>&1 | tee ${log_file} ; exit ${PIPESTATUS[0]} 91 | 92 | -------------------------------------------------------------------------------- /examples/mcore/scripts/train_grpo_qwen3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | ray stop 4 | rm -rf /tmp/ray/* 5 | 6 | # enveriment 7 | export CUDA_DEVICE_MAX_CONNECTIONS=1 8 | export RAY_num_server_call_thread=1 9 | export VLLM_USE_RAY_SPMD_WORKER=1 10 | export VLLM_USE_RAY_COMPILED_DAG=1 11 | export CHATLEARN=$(pwd) 12 | export PYTHONPATH=${MEGATRON_PATH}:${CHATLEARN}:${CHATLEARN}/examples:$PYTHONPATH 13 | export WORLD_SIZE=${WORLD_SIZE:-1} 14 | export RANK=${RANK:-0} 15 | export LOCAL_MASTER_ADDR=${MASTER_ADDR:-localhost} 16 | ports="30000" 17 | for i in $(seq 30001 30050); do 18 | ports="${ports};${i}" 19 | done 20 | export CUSTOM_PORTS=$ports 21 | export num_device=$(($WORLD_SIZE * 8)) 22 | 23 | # data 24 | export train_data_path="/mnt/data/datasets/MATH-lighteval/train.json" 25 | export eval_data_path="/mnt/data/datasets/MATH-lighteval/test.json" 26 | export patch_tokenizer_type=Qwen2Tokenizer 27 | export extra_vocab_size=293 28 | export tokenizer_path="/mnt/data/qwen-ckpts/Qwen3-8B" 29 | export load="/mnt/data/qwen-ckpts/Qwen3-8B-to-mcore/" 30 | 31 | # model 32 | export 
max_position_embedding=40960
export policy_num_layers=36
export policy_hidden_size=4096
export policy_num_attention_heads=32
export policy_num_query_groups=8
export policy_ffn_hidden_size=12288
export tensor_model_parallel_size=4
export pipeline_model_parallel_size=2

# training
export final_clip_ratio=3
export clip_grad=1.0
export seed=3407
export policy_lr=2e-6
export policy_min_lr=2e-6
export eval_episode_interval=1
export save_interval=100000
# NOTE: save_episode_interval was previously exported twice with the same
# value; the duplicate export has been removed.
export save_episode_interval=10000
export num_episode=200
export sample_per_episode=2048
export train_micro_batch_size=8
export train_global_batch_size=2048
export vllm_generation_batch_size=128
export trainer_generation_batch_size=8
export train_iters=$(( ${num_episode} * ${sample_per_episode} / ${train_global_batch_size} ))
export policy_lr_warmup_iters=0
export lr_decay_iters=160000
export max_num_batched_tokens=65536
export gpu_memory_utilization=0.85
# for qwen3 where enable_thinking
export enable_thinking=False

# vllm
export seq_length=2048
export max_new_tokens=2048
export max_seq_len_to_capture=2348
export num_inference_per_prompt=32
export policy_temperature=1.0
export policy_top_p=1.0
export policy_top_k=-1
export policy_eval_temperature=0.6
export policy_eval_top_p=0.95
export policy_eval_top_k=20

# logging and saving
export enable_tensorboard=True
export enable_wandb=False
export WANDB_API_KEY="wandb-api-key"
export exp_name=qwen3_8B_lr${policy_lr}_mbs${train_micro_batch_size}_gbs${train_global_batch_size}_tp${tensor_model_parallel_size}_pp${pipeline_model_parallel_size}_${WORLD_SIZE}nodes
export output_dir=${CHATLEARN}/output/${exp_name}
mkdir -p $output_dir/
export log_dir=${output_dir}/logs
mkdir -p $log_dir
log_file=$log_dir/${exp_name}_rank${RANK}.log 87 | export tensorboard_dir=${output_dir}/tensorboard 88 | export wandb_dir=${output_dir} 89 | export save_dir=${output_dir} 90 | 91 | cd $CHATLEARN/examples/mcore 92 | python entry/train_grpo.py -c configs/grpo/grpo_qwen3.yaml 2>&1 | tee ${log_file} ; exit ${PIPESTATUS[0]} 93 | 94 | -------------------------------------------------------------------------------- /examples/tests/barrier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import argparse
from datetime import timedelta

import torch.distributed as dist

# CLI: an optional --timeout (in minutes) for the process-group handshake.
parser = argparse.ArgumentParser(description="Barrier")
parser.add_argument(
    '--timeout', type=int, default=None,
    help="Timeout in minutes"
)

def main():
    """Join and immediately tear down a NCCL process group.

    Serves as a cross-rank barrier: init_process_group returns only after
    the group is formed, so all participating ranks synchronize here before
    the group is destroyed again.
    """
    args = parser.parse_args()
    # Convert minutes to a timedelta; when unset, None is forwarded as-is
    # (presumably letting torch pick its default timeout — verify against
    # the installed torch version).
    timeout = timedelta(minutes=args.timeout) if args.timeout is not None else None
    dist.init_process_group(backend="nccl", timeout=timeout)
    dist.destroy_process_group()

if __name__ == '__main__':
    main()
import sys
from setuptools import setup, find_packages

# Load VERSION from chatlearn/utils/version.py without importing the
# chatlearn package itself (which may pull in heavy dependencies).
if sys.version_info[0] < 3:
    import imp
    VERSION = imp.load_source('chatlearn.version', 'chatlearn/utils/version.py').VERSION
else:
    # SourceFileLoader.load_module() is deprecated (and removed in
    # Python 3.12); use the documented spec-based loading API instead.
    import importlib.util
    _spec = importlib.util.spec_from_file_location(
        "chatlearn.version", "chatlearn/utils/version.py")
    _version_module = importlib.util.module_from_spec(_spec)
    _spec.loader.exec_module(_version_module)
    VERSION = _version_module.VERSION

setup(
    name='pai-chatlearn',
    version=VERSION,
    python_requires='>3.6.0',
    packages=find_packages(),
    include_package_data=True,
    install_requires=[],
    long_description="PAI ChatLearn",
    author='Alibaba Group',
    license='Apache 2.0',
)
import ray
import ray.util.collective as collective
import torch

import chatlearn


@ray.remote(num_gpus=1)
class Worker:
    """Single-GPU actor used to exercise ray collective send/recv."""

    def __init__(self):
        from chatlearn.launcher.initialize import patch_ray
        patch_ray()
        self.send1 = torch.ones((4,), dtype=torch.bfloat16, device="cuda")
        self.send1 = torch.nn.parameter.Parameter(self.send1)
        self.recv1 = torch.zeros((4,), dtype=torch.bfloat16, device="cuda")
        self.recv1 = torch.nn.parameter.Parameter(self.recv1)
        # Bug fix: compute2()/recv2() referenced self.send2/self.recv2, which
        # were never created and raised AttributeError when those methods were
        # called. Mirror the send1/recv1 buffers so those paths are usable.
        self.send2 = torch.nn.parameter.Parameter(
            torch.ones((4,), dtype=torch.bfloat16, device="cuda"))
        self.recv2 = torch.nn.parameter.Parameter(
            torch.zeros((4,), dtype=torch.bfloat16, device="cuda"))
        self.rank = -1

    def setup(self, world_size, rank):
        # Join the collective group named "8".
        self.rank = rank
        collective.init_collective_group(world_size, rank, "nccl", "8")
        return True

    def compute(self, src, tgt):
        # Rank 0 sends send1*2 to `tgt`; other ranks receive into recv1.
        if self.rank == 0:
            collective.send(self.send1 * 2, tgt, "8")
        else:
            collective.recv(self.recv1, src, "8")
        return self.recv1

    def compute2(self):
        # Multi-GPU variant: rank 0 sends send2*4 to (rank 1, gpu 0).
        if self.rank == 0:
            collective.send_multigpu(self.send2 * 4, 1, 0, "8")
        else:
            collective.recv_multigpu(self.recv1, 0, 1, "8")
        return self.recv1

    def recv(self, src_rank, src_gpu):
        collective.recv_multigpu(self.recv1, src_rank, src_gpu, "8")

    def recv2(self, src_rank, src_gpu):
        collective.recv_multigpu(self.recv2, src_rank, src_gpu, "8")
        return self.recv2

    def destroy(self):
        collective.destroy_collective_group("8")


def test_send_recv():
    """Three workers join one collective group; rank 0 sends to 1 and 2."""
    num_workers = 3
    workers = []
    init_rets = []
    w0 = Worker.remote()
    init_rets.append(w0.setup.remote(num_workers, 0))
    w1 = Worker.remote()
    init_rets.append(w1.setup.remote(num_workers, 1))
    w2 = Worker.remote()
    init_rets.append(w2.setup.remote(num_workers, 2))

    workers = [w0, w1, w2]
    a = ray.get(init_rets)
    print('================== init done', a, flush=True)
    results = [w0.compute.remote(0, 1), w1.compute.remote(0, 1)]
    print(ray.get(results))
    print('send from w0 to w2', flush=True)
    results = [w0.compute.remote(0, 2), w2.compute.remote(0, 2)]
    print(ray.get(results))

    ray.get([w.destroy.remote() for w in workers])

TEST_CASE = [test_send_recv]
/tests/configs/base.yaml: -------------------------------------------------------------------------------- 1 | models: 2 | policy: 3 | model_config_file: model.yaml 4 | num_gpu: 1 5 | gpu_per_process: 1 6 | trainable: False 7 | 8 | reference: 9 | model_config_file: model.yaml 10 | num_gpu: 1 11 | gpu_per_process: 1 12 | trainable: False 13 | 14 | 15 | runtime: 16 | num_rollout_worker: 1 17 | num_iteration: 5000 18 | sample_per_episode: 1000 19 | num_training_epoch: ${num_training_epoch:3} 20 | unknown_args: "test_unknown" 21 | -------------------------------------------------------------------------------- /tests/configs/eval.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - "tensorboards" 7 | - ".nfs*" 8 | 9 | 10 | models: 11 | policy: 12 | num_gpu: 2 13 | gpu_per_process: 1 14 | trainable: False 15 | 16 | reference: 17 | num_gpu: 1 18 | gpu_per_process: 1 19 | trainable: False 20 | 21 | reward: 22 | num_gpu: 2 23 | gpu_per_process: 1 24 | trainable: False 25 | 26 | value: 27 | num_gpu: 1 28 | gpu_per_process: 1 29 | trainable: False 30 | 31 | ppo_policy: 32 | num_gpu: 1 33 | gpu_per_process: 1 34 | trainable: True 35 | lora: 36 | enable_lora: ${enable_lora_policy:False} 37 | lora_dim: 64 38 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear 39 | column_only_qkv: False 40 | lora_dropout: 0.05 41 | 42 | ppo_value: 43 | num_gpu: 1 44 | gpu_per_process: 1 45 | trainable: True 46 | 47 | runtime: 48 | debug: True 49 | generation_batch_size: ${batch_size:4} 50 | train_micro_batch_size: 5 51 | train_global_batch_size: 10 52 | num_episode: 2 53 | sample_per_episode: 16 54 | num_training_epoch: 1 55 | save_episode_interval: 200 56 | eval_data_num_limit: 4 57 | eval_episode_interval: 1 58 | -------------------------------------------------------------------------------- /tests/configs/exp1.yaml: 
-------------------------------------------------------------------------------- 1 | generate_config: 2 | num_beams: 1 3 | num_return_sequences: 1 4 | temperature: 1.0 5 | do_sample: True 6 | early_stopping: True 7 | top_k: 1 8 | top_p: 0.9 9 | repetition_penalty: 1.0 10 | length_penalty: 1.0 11 | min_length: 5 12 | max_length: 4096 13 | no_repeat_ngram_size: 2 14 | eos_token_id: 102 15 | -------------------------------------------------------------------------------- /tests/configs/exp2.yaml: -------------------------------------------------------------------------------- 1 | test: 123 2 | model_config: 3 | attention_probs_dropout_prob: 0.3 4 | generate_config: 5 | eos_token_id: 103 6 | -------------------------------------------------------------------------------- /tests/configs/grpo.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - "tensorboards" 7 | - ".nfs*" 8 | 9 | 10 | models: 11 | policy: 12 | num_gpu: 1 13 | gpu_per_process: 1 14 | trainable: False 15 | 16 | reference: 17 | num_gpu: 1 18 | gpu_per_process: 1 19 | trainable: False 20 | 21 | reward: 22 | num_gpu: 1 23 | gpu_per_process: 1 24 | trainable: False 25 | 26 | ppo_policy: 27 | num_gpu: 1 28 | gpu_per_process: 1 29 | trainable: True 30 | lora: 31 | enable_lora: ${enable_lora_policy:False} 32 | lora_dim: 64 33 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear 34 | column_only_qkv: False 35 | lora_dropout: 0.05 36 | 37 | runtime: 38 | debug: True 39 | generation_batch_size: ${batch_size:4} 40 | train_micro_batch_size: 5 41 | train_global_batch_size: 10 42 | num_episode: 2 43 | sample_per_episode: 16 44 | num_training_epoch: 1 45 | save_episode_interval: 200 46 | -------------------------------------------------------------------------------- /tests/configs/model.yaml: -------------------------------------------------------------------------------- 1 | includes: 2 | - 
exp1.yaml 3 | - exp2.yaml 4 | 5 | model_config: 6 | attention_probs_dropout_prob: 0.1 7 | attention_type: "self" 8 | hidden_act: "gelu" 9 | hidden_dropout_prob: 0.1 10 | hidden_size: 768 11 | initializer_range: 0.02 12 | intermediate_size: 768 13 | layer_norm_eps: 1e-12 14 | layernorm_epsilon: 1e-12 15 | max_position_embeddings: 2048 16 | model_type: "gpt" 17 | num_attention_heads: 12 18 | num_hidden_layers: 12 19 | transformers_version: "4.22.0" 20 | type_vocab_size: 2 21 | vocab_size: 25600 22 | -------------------------------------------------------------------------------- /tests/configs/o1.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - "tensorboards" 7 | - ".nfs*" 8 | - "mypipe" 9 | - "*.pt.trace.json" 10 | - "*.nsys-rep" 11 | 12 | models: 13 | mcts: 14 | model_config_file: model.yaml 15 | num_cpu: 4 16 | cpu_per_process: 1 17 | trainable: False 18 | policy: 19 | model_config_file: model.yaml 20 | num_gpu: 1 21 | gpu_per_process: 1 22 | trainable: False 23 | 24 | reward: 25 | model_config_file: model.yaml 26 | num_gpu: 1 27 | trainable: False 28 | 29 | reward1: 30 | model_config_file: model.yaml 31 | num_cpu: 1 32 | cpu_per_process: 1 33 | trainable: False 34 | 35 | 36 | runtime: 37 | generation_batch_size: ${generation_batch_size:1} 38 | num_episode: ${num_episode:1} 39 | sample_per_episode: ${sample_per_episode:4} 40 | num_training_epoch: 1 41 | save_episode_interval: ${save_episode_interval:1000} 42 | query_key: ${query_key:query} 43 | data_path: ${data_path:/path/to/data} 44 | training_data_num_limit: ${training_data_num_limit:-1} 45 | eval_episode_interval: ${eval_episode_interval:0} 46 | eval_data_num_limit: 20 47 | nsys: False 48 | free_sync_collective_group: ${free_sync_collective_group:False} 49 | param_sync_comm_type: ${param_sync_comm_type:broadcast} 50 | validate_param_sync: ${validate_param_sync:False} 51 | 
-------------------------------------------------------------------------------- /tests/configs/parameter_sync.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - "tensorboards" 7 | - ".nfs*" 8 | 9 | 10 | models: 11 | policy: 12 | num_gpu: 1 13 | gpu_per_process: 1 14 | trainable: False 15 | 16 | ppo_policy: 17 | num_gpu: 1 18 | gpu_per_process: 1 19 | trainable: True 20 | lora: 21 | enable_lora: ${enable_lora_policy:False} 22 | lora_dim: 64 23 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear 24 | column_only_qkv: False 25 | lora_dropout: 0.05 26 | 27 | runtime: 28 | debug: False 29 | generation_batch_size: ${batch_size:4} 30 | train_micro_batch_size: 5 31 | train_global_batch_size: 10 32 | num_episode: 2 33 | sample_per_episode: 16 34 | num_training_epoch: 1 35 | save_episode_interval: 200 36 | -------------------------------------------------------------------------------- /tests/configs/rlhf.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - "tensorboards" 7 | - ".nfs*" 8 | 9 | 10 | models: 11 | policy: 12 | num_gpu: 1 13 | gpu_per_process: 1 14 | trainable: False 15 | 16 | reference: 17 | num_gpu: 1 18 | gpu_per_process: 1 19 | trainable: False 20 | 21 | reward: 22 | num_gpu: 1 23 | gpu_per_process: 1 24 | trainable: False 25 | 26 | value: 27 | num_gpu: 1 28 | gpu_per_process: 1 29 | trainable: False 30 | 31 | ppo_policy: 32 | num_gpu: 1 33 | gpu_per_process: 1 34 | trainable: True 35 | lora: 36 | enable_lora: ${enable_lora_policy:False} 37 | lora_dim: 64 38 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear 39 | column_only_qkv: False 40 | lora_dropout: 0.05 41 | 42 | ppo_value: 43 | num_gpu: 1 44 | gpu_per_process: 1 45 | trainable: True 46 | 47 | runtime: 48 | debug: True 49 | generation_batch_size: 
${batch_size:4} 50 | train_micro_batch_size: 5 51 | train_global_batch_size: 10 52 | num_episode: 2 53 | sample_per_episode: 16 54 | num_training_epoch: 1 55 | save_episode_interval: 200 56 | data_shuffle: False 57 | data_rerank: False 58 | -------------------------------------------------------------------------------- /tests/configs/rlhf2.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - "tensorboards" 7 | - ".nfs*" 8 | 9 | 10 | models: 11 | policy: 12 | num_gpu: 1 13 | gpu_per_process: 1 14 | trainable: False 15 | 16 | reference: 17 | num_gpu: 1 18 | gpu_per_process: 1 19 | trainable: False 20 | 21 | reward: 22 | num_gpu: 1 23 | gpu_per_process: 1 24 | trainable: False 25 | 26 | value: 27 | num_gpu: 1 28 | gpu_per_process: 1 29 | trainable: True 30 | 31 | ppo_policy: 32 | num_gpu: 1 33 | gpu_per_process: 1 34 | trainable: True 35 | lora: 36 | enable_lora: ${enable_lora_policy:False} 37 | lora_dim: 64 38 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear 39 | column_only_qkv: False 40 | lora_dropout: 0.05 41 | 42 | runtime: 43 | debug: True 44 | generation_batch_size: ${batch_size:4} 45 | train_micro_batch_size: 5 46 | train_global_batch_size: 10 47 | num_episode: 2 48 | sample_per_episode: 16 49 | num_training_epoch: 1 50 | save_episode_interval: 200 51 | -------------------------------------------------------------------------------- /tests/configs/rlhf_cpu.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - "tensorboards" 7 | - ".nfs*" 8 | 9 | 10 | models: 11 | policy: 12 | num_gpu: 1 13 | gpu_per_process: 1 14 | trainable: False 15 | 16 | reference: 17 | num_gpu: 1 18 | gpu_per_process: 1 19 | trainable: False 20 | 21 | reward: 22 | num_cpu: 2 23 | cpu_per_process: 1 24 | trainable: False 25 | 26 | value: 27 | 
num_gpu: 1 28 | gpu_per_process: 1 29 | trainable: False 30 | 31 | ppo_policy: 32 | num_gpu: 1 33 | gpu_per_process: 1 34 | trainable: True 35 | lora: 36 | enable_lora: ${enable_lora_policy:False} 37 | lora_dim: 64 38 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear 39 | column_only_qkv: False 40 | lora_dropout: 0.05 41 | 42 | ppo_value: 43 | num_gpu: 1 44 | gpu_per_process: 1 45 | trainable: True 46 | 47 | runtime: 48 | debug: True 49 | generation_batch_size: ${batch_size:4} 50 | train_micro_batch_size: 5 51 | train_global_batch_size: 10 52 | num_episode: 2 53 | sample_per_episode: 16 54 | num_training_epoch: 1 55 | save_episode_interval: 200 56 | -------------------------------------------------------------------------------- /tests/configs/rlhf_eval.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - "tensorboards" 7 | - ".nfs*" 8 | 9 | 10 | models: 11 | policy: 12 | num_gpu: 2 13 | gpu_per_process: 1 14 | trainable: False 15 | 16 | reference: 17 | num_gpu: 1 18 | gpu_per_process: 1 19 | trainable: False 20 | 21 | reward: 22 | num_gpu: 2 23 | gpu_per_process: 1 24 | trainable: False 25 | 26 | value: 27 | num_gpu: 1 28 | gpu_per_process: 1 29 | trainable: False 30 | 31 | ppo_policy: 32 | num_gpu: 1 33 | gpu_per_process: 1 34 | trainable: True 35 | lora: 36 | enable_lora: ${enable_lora_policy:False} 37 | lora_dim: 64 38 | lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear 39 | column_only_qkv: False 40 | lora_dropout: 0.05 41 | 42 | ppo_value: 43 | num_gpu: 1 44 | gpu_per_process: 1 45 | trainable: True 46 | 47 | runtime: 48 | debug: True 49 | generation_batch_size: ${batch_size:4} 50 | train_micro_batch_size: 5 51 | train_global_batch_size: 10 52 | num_episode: 2 53 | sample_per_episode: 16 54 | num_training_epoch: 1 55 | save_episode_interval: 200 56 | eval_data_num_limit: 4 57 | eval_episode_interval: 1 58 | 
-------------------------------------------------------------------------------- /tests/configs/sprl.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - "tensorboards" 7 | - ".nfs*" 8 | - "mypipe" 9 | - "*.pt.trace.json" 10 | - "*.nsys-rep" 11 | 12 | models: 13 | sprl: 14 | model_config_file: model.yaml 15 | num_cpu: 1 16 | cpu_per_process: 1 17 | trainable: False 18 | actor: 19 | model_config_file: model.yaml 20 | num_gpu: 1 21 | gpu_per_process: 1 22 | trainable: False 23 | 24 | critic: 25 | model_config_file: model.yaml 26 | num_gpu: 1 27 | gpu_per_process: 1 28 | trainable: False 29 | 30 | value: 31 | model_config_file: model.yaml 32 | num_gpu: 1 33 | gpu_per_process: 1 34 | trainable: False 35 | 36 | prm: 37 | model_config_file: model.yaml 38 | num_gpu: 1 39 | gpu_per_process: 1 40 | trainable: False 41 | 42 | 43 | runtime: 44 | generation_batch_size: ${generation_batch_size:1} 45 | num_episode: ${num_episode:2} 46 | sample_per_episode: ${sample_per_episode:4} 47 | train_global_batch_size: ${train_global_batch_size:1} 48 | train_micro_batch_size: ${train_micro_batch_size:1} 49 | num_training_epoch: 1 50 | save_episode_interval: ${save_episode_interval:1000} 51 | query_key: ${query_key:query} 52 | data_path: ${data_path:/path/to/data} 53 | training_data_num_limit: ${training_data_num_limit:-1} 54 | eval_episode_interval: ${eval_episode_interval:0} 55 | eval_data_num_limit: 20 56 | nsys: False 57 | free_sync_collective_group: ${free_sync_collective_group:False} 58 | param_sync_comm_type: ${param_sync_comm_type:broadcast} 59 | validate_param_sync: ${validate_param_sync:False} 60 | -------------------------------------------------------------------------------- /tests/configs/test_eval.yaml: -------------------------------------------------------------------------------- 1 | runtime_env: 2 | platform: DLC 3 | excludes: 4 | - "*pt" 5 | - "logs" 6 | - 
import time

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import chatlearn
from chatlearn.utils import future
from chatlearn import RLHFEngine
from chatlearn import TorchModule

from utils import CustomDataset, PolicyModel, ReferenceModel, RewardModel, ValueModel, PPOPolicy, PPOValue


def test_rlhf_replica():
    """Verify GPU placement when the policy model runs with two replicas.

    Requests two replicas for "policy", builds a full RLHF engine, then checks
    the dataset size and that each policy replica sees exactly one distinct
    GPU (GPU 0 and GPU 1 respectively).
    """
    # Request two replicas for the policy model before it is instantiated.
    chatlearn.get_args().models["policy"].num_replica = 2
    policy = PolicyModel("policy")
    reference = ReferenceModel("reference")
    reward = RewardModel("reward")
    value = ValueModel("value")
    ppo_policy = PPOPolicy("ppo_policy")
    ppo_value = PPOValue("ppo_value")

    engine = RLHFEngine(policy, reference, reward, value, ppo_policy, ppo_value)
    #assert policy.num_replica == 2

    data = torch.ones([1024])
    engine.set_dataset([data] * 35)
    engine.setup()
    # NOTE(review): guarded assert — only checked if the replica request took
    # effect; presumably setup() may fall back to one replica. Confirm.
    if policy.num_replica == 2:
        assert reference.num_replica == 1
    data = torch.ones([1024])
    assert len(engine.env._all_datasets[0]) == 35, len(engine.env._all_datasets[0])
    # Replica 0 must be pinned to GPU 0 and replica 1 to GPU 1.
    visible_devices = engine.models[0].replicas[0].get_visible_gpus()
    visible_devices = future.get(visible_devices)
    assert visible_devices == [[0]], visible_devices
    visible_devices = engine.models[0].replicas[1].get_visible_gpus()
    visible_devices = future.get(visible_devices)
    assert visible_devices == [[1]], visible_devices
    engine.stop()

def test_rlhf_replica_2():
    """Same as test_rlhf_replica, but also requests two "value" replicas.

    Only the policy layout is asserted; the value replica count is set but
    not checked here.
    """
    chatlearn.get_args().models["policy"].num_replica = 2
    chatlearn.get_args().models["value"].num_replica = 2
    policy = PolicyModel("policy")
    reference = ReferenceModel("reference")
    reward = RewardModel("reward")
    value = ValueModel("value")
    ppo_policy = PPOPolicy("ppo_policy")
    ppo_value = PPOValue("ppo_value")

    engine = RLHFEngine(policy, reference, reward, value, ppo_policy, ppo_value)
    #assert policy.num_replica == 2

    data = torch.ones([1024])
    engine.set_dataset([data] * 35)
    engine.setup()
    if policy.num_replica == 2:
        assert reference.num_replica == 1
    # NOTE(review): the dataset is set again after setup(); looks redundant
    # with the pre-setup call above — confirm whether this is intentional.
    data = torch.ones([1024])
    engine.set_dataset([data] * 35)
    assert len(engine.env._all_datasets[0]) == 35, len(engine.env._all_datasets[0])
    visible_devices = engine.models[0].replicas[0].get_visible_gpus()
    visible_devices = future.get(visible_devices)
    assert visible_devices == [[0]], visible_devices
    visible_devices = engine.models[0].replicas[1].get_visible_gpus()
    visible_devices = future.get(visible_devices)
    assert visible_devices == [[1]], visible_devices
    engine.stop()


TEST_CASE = [test_rlhf_replica, test_rlhf_replica_2]
test_main.py -t "rlhf.test_rlhf_ckpt" -c "configs/rlhf.yaml"

ray stop --force
--------------------------------------------------------------------------------
/tests/test_main.py:
--------------------------------------------------------------------------------
import sys
import logging
import argparse
import glob
import importlib
from pathlib import Path

from utils import run_test

import chatlearn


def _test_func(path, case_name):
    """Collect TEST_CASE lists from all test_*.py modules under *path* and run
    them (optionally filtered to *case_name*). Returns the runner's exit code."""
    # Traverse all test_.*.py test files in target path
    test_modules = sorted([
        Path(f).stem for f in glob.glob(f"{path}/test_*.py")
        if not Path(f).stem.startswith("__") ])

    test_cases = []
    for module_name in test_modules:
        try:
            module = importlib.import_module(f"{path}.{module_name}")
            test_case = getattr(module, "TEST_CASE")
            test_cases.extend(test_case)
        except (ModuleNotFoundError, AttributeError) as e:
            # "Load failed" — a module without a TEST_CASE attribute is skipped.
            print(f"加载失败: {module_name} ({type(e).__name__})")

    # init chatlearn framework once
    chatlearn.init()

    return run_test(test_cases, case_name)

def _parse_args():
    """Parse the -t/--test_case flag; unknown flags are tolerated.

    NOTE(review): run_tests.sh also passes -c <config>, which is not declared
    here and only survives because parse_known_args() ignores it — presumably
    chatlearn.init() reads the config from sys.argv; confirm.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-t", "--test_case",
        type=str, required=True, help="test_case or test_case.case_name")
    return parser.parse_known_args()

def _run_unit_tests(test_dir):
    """Discover and run unittest cases under *test_dir*; 0 on success, 1 otherwise."""
    import unittest
    import os
    discover_config = {
        "start_dir": test_dir,
        "pattern": "test_*.py",
        "top_level_dir": None
    }

    test_suite = unittest.defaultTestLoader.discover(**discover_config)

    runner = unittest.TextTestRunner(
        verbosity=2,
        failfast=False,
        buffer=False,
        resultclass=None
    )

    test_result = runner.run(test_suite)

    print(f"Total unittest cases: {test_result.testsRun}")
    print(f"Failed unittest cases: {len(test_result.failures)}")
    print(f"Error unittest cases: {len(test_result.errors)}")
    if len(test_result.failures) == 0 and len(test_result.errors) == 0:
        return 0 # UT Passed
    return 1

if __name__ == "__main__":
    args, _ = _parse_args()
    test_case = args.test_case
    case_name = None
    # "-t suite.case" selects a single case within a suite.
    args_split = test_case.split('.')
    if (len(args_split) == 2):
        test_case = args_split[0]
        case_name = args_split[1]

    if test_case == "unittest":
        sys.exit(_run_unit_tests("unittests"))
    sys.exit(_test_func(test_case, case_name))
--------------------------------------------------------------------------------
/tests/unittests/test_flat_tensors.py:
--------------------------------------------------------------------------------
import random
import unittest

import torch

from chatlearn.utils.flat_tensors import FlatTensors, BucketizedFlatTensors


# pylint: disable=missing-class-docstring
class TestFlatTensors(unittest.TestCase):

    @staticmethod
    def almost_same_memory_usage(t1, t2, eps):
        # Relative comparison of two byte counts: |t1/t2 - 1| < eps.
        return abs(t1 / t2 - 1) < eps

    def run_flat_tensors_test_with_constructor(self, constructor):
        """Shared body: build random CUDA tensors, flatten them with
        *constructor*, then check values and memory across offload/onload."""
        seed = 0
        random.seed(seed)
        torch.manual_seed(seed)

        measure1 = torch.cuda.memory_allocated()
        # Randomly generate some tensors.
        n = 4
        n_dims = [random.randint(1, 4) for _ in range(n)]
        shapes = [
            [random.randint(0, 8) for _ in range(dim)]
            for dim in n_dims
        ]

        tensors = [
            torch.rand(size=shape, device='cuda')
            for shape in shapes
        ]
        measure2 = torch.cuda.memory_allocated()
        tensors_usage = measure2 - measure1

        # Clone tensors for comparison. 
38 | cloned = [ 39 | tensor.detach().clone() for tensor in tensors 40 | ] 41 | measure3 = torch.cuda.memory_allocated() 42 | cloned_usage = measure3 - measure2 43 | self.almost_same_memory_usage(cloned_usage, tensors_usage, 1e-3) 44 | 45 | # Check after creating FlatTensors 46 | flat_tensor = constructor(tensors) 47 | for t, t_copied in zip(tensors, cloned): 48 | assert torch.equal(t, t_copied) 49 | 50 | # Check after offloaded. 51 | flat_tensor.copy_to_primary_store() 52 | for t in tensors: 53 | assert t.shape == torch.Size([0]) 54 | 55 | measure4 = torch.cuda.memory_allocated() 56 | offloaded_memory = measure3 - measure4 57 | self.almost_same_memory_usage(offloaded_memory, tensors_usage, 1e-3) 58 | 59 | # Check after onloaded. 60 | flat_tensor.copy_to_gpu_buffer() 61 | measure5 = torch.cuda.memory_allocated() 62 | onloaded = measure5 - measure4 63 | self.almost_same_memory_usage(onloaded, tensors_usage, 1e-3) 64 | 65 | for t, t_copied in zip(tensors, cloned): 66 | assert torch.equal(t, t_copied) 67 | 68 | def test_flat_tensors(self): 69 | self.run_flat_tensors_test_with_constructor( 70 | lambda tensors: FlatTensors(tensors, primary_store_device='cpu') 71 | ) 72 | torch.cuda.synchronize() 73 | 74 | def test_bucketized_flat_tensors(self): 75 | self.run_flat_tensors_test_with_constructor( 76 | lambda tensors: BucketizedFlatTensors( 77 | tensors, primary_store_device='cpu', bucket_size_mb=16 78 | ) 79 | ) 80 | torch.cuda.synchronize() 81 | 82 | 83 | if __name__ == '__main__': 84 | unittest.main() 85 | -------------------------------------------------------------------------------- /tests/unittests/test_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """UnitTest for CollectiveTaskScheduler .""" 16 | 17 | from unittest import TestCase 18 | from chatlearn.synchronizer.scheduler import CollectiveTask, collective_task_scheduler, parallel_execute_collective_tasks 19 | 20 | 21 | class TestCollectiveTaskScheduler(TestCase): 22 | """Test for CollectiveTaskScheduler""" 23 | def test_scheduler(self): 24 | tasks = [ 25 | CollectiveTask([1, 2], "group1"), 26 | CollectiveTask([1, 3], "group1"), 27 | CollectiveTask([2, 1], "group2"), 28 | CollectiveTask([4, 5], "group2"), 29 | ] 30 | generator = collective_task_scheduler(tasks) 31 | paralel_tasks = next(generator) 32 | self.assertEqual([1, 2], paralel_tasks[0].actors) 33 | self.assertEqual([4, 5], paralel_tasks[1].actors) 34 | paralel_tasks = next(generator) 35 | self.assertEqual([1, 3], paralel_tasks[0].actors) 36 | paralel_tasks = next(generator) 37 | self.assertEqual([2, 1], paralel_tasks[0].actors) 38 | 39 | def task_func(task): 40 | if task.actors[0] == 2: 41 | self.assertEqual("group2", task.group) 42 | 43 | parallel_execute_collective_tasks(tasks, task_func) 44 | -------------------------------------------------------------------------------- /tests/unittests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""UT for utils."""

import unittest
import ray
from chatlearn.utils.utils import parse_function_return_num
from chatlearn import get
from chatlearn.utils.utils import split_index


# pylint: disable=missing-class-docstring
class TestDataset(unittest.TestCase):

    def test_function_return(self):
        # parse_function_return_num counts the values a function returns,
        # inferred statically from its return statements.

        def func():
            return 1, 2, 3

        res = parse_function_return_num(func)
        self.assertEqual(res, 3)

        def func2(aaa):
            # Both branches return 2 values, so the count is 2.
            if aaa > 0:
                return 1, 2
            else:
                return 3, 4

        res = parse_function_return_num(func2)
        self.assertEqual(res, 2)

        def func3():
            # A single list is one return value, not three.
            res = [1, 2, 3]
            return res

        res = parse_function_return_num(func3)
        self.assertEqual(res, 1)

        def func4():
            res = [1, 2, 3]
            return res, 1

        res = parse_function_return_num(func4)
        self.assertEqual(res, 2)

    # Deliberately disabled (leading underscore): requires a live ray cluster.
    def _test_get(self):
        ray.init()
        value = ray.put(1)
        # get() should resolve ObjectRefs nested inside tuples/lists/dicts.
        data = (value, {1:1})
        data1 = get(data)
        self.assertEqual(data1, (1, {1:1}))
        data = ([value], {1:1})
        data1 = get(data)
        self.assertEqual(data1, ([1], {1:1}))
        value = ray.put({"a":2})
        data = ([value], {1:1})
        data1 = get(data)
        self.assertEqual(data1, ([{"a":2}], {1:1}))

    def test_split_index(self):
        # 10 items over 3 splits: the remainder goes to the first split(s).
        length = 10
        num_splits = 3
        res = split_index(length,num_splits)
        self.assertEqual(res, [(0, 4), (4, 7), (7, 10)])

# pylint: enable=missing-class-docstring


if __name__ == '__main__':
    unittest.main()