├── .gitattributes
├── LICENSE
├── LLaMA-Factory
├── .env.local
├── .gitattributes
├── LICENSE
├── MANIFEST.in
├── Makefile
├── data
│ ├── dataset_info.json
│ └── tool_sft.json
├── evaluation
│ ├── ceval
│ │ ├── ceval.py
│ │ ├── ceval.zip
│ │ └── mapping.json
│ ├── cmmlu
│ │ ├── cmmlu.py
│ │ ├── cmmlu.zip
│ │ └── mapping.json
│ └── mmlu
│ │ ├── mapping.json
│ │ ├── mmlu.py
│ │ └── mmlu.zip
├── examples
│ ├── README.md
│ ├── qwen_merge.yaml
│ └── qwen_sft.yaml
├── get_raw_sft_data.py
├── lam_data_process.py
├── pyproject.toml
├── requirements.txt
├── scripts
│ ├── api_example
│ │ ├── test_image.py
│ │ └── test_toolcall.py
│ ├── convert_ckpt
│ │ ├── llamafy_baichuan2.py
│ │ ├── llamafy_qwen.py
│ │ └── tiny_llama4.py
│ ├── eval_bleu_rouge.py
│ ├── llama_pro.py
│ ├── loftq_init.py
│ ├── pissa_init.py
│ ├── qwen_omni_merge.py
│ ├── stat_utils
│ │ ├── cal_flops.py
│ │ ├── cal_lr.py
│ │ ├── cal_mfu.py
│ │ ├── cal_ppl.py
│ │ └── length_cdf.py
│ └── vllm_infer.py
├── setup.py
└── src
│ ├── api.py
│ ├── llamafactory.egg-info
│ ├── PKG-INFO
│ ├── SOURCES.txt
│ ├── dependency_links.txt
│ ├── entry_points.txt
│ ├── requires.txt
│ └── top_level.txt
│ ├── llamafactory
│ ├── __init__.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── app.py
│ │ ├── chat.py
│ │ ├── common.py
│ │ └── protocol.py
│ ├── chat
│ │ ├── __init__.py
│ │ ├── base_engine.py
│ │ ├── chat_model.py
│ │ ├── hf_engine.py
│ │ ├── sglang_engine.py
│ │ └── vllm_engine.py
│ ├── cli.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── collator.py
│ │ ├── converter.py
│ │ ├── data_utils.py
│ │ ├── formatter.py
│ │ ├── loader.py
│ │ ├── mm_plugin.py
│ │ ├── parser.py
│ │ ├── processor
│ │ │ ├── __init__.py
│ │ │ ├── feedback.py
│ │ │ ├── pairwise.py
│ │ │ ├── pretrain.py
│ │ │ ├── processor_utils.py
│ │ │ ├── supervised.py
│ │ │ └── unsupervised.py
│ │ ├── template.py
│ │ └── tool_utils.py
│ ├── eval
│ │ ├── __init__.py
│ │ ├── evaluator.py
│ │ └── template.py
│ ├── extras
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── env.py
│ │ ├── logging.py
│ │ ├── misc.py
│ │ ├── packages.py
│ │ └── ploting.py
│ ├── hparams
│ │ ├── __init__.py
│ │ ├── data_args.py
│ │ ├── evaluation_args.py
│ │ ├── finetuning_args.py
│ │ ├── generating_args.py
│ │ ├── model_args.py
│ │ ├── parser.py
│ │ └── training_args.py
│ ├── launcher.py
│ ├── model
│ │ ├── __init__.py
│ │ ├── adapter.py
│ │ ├── loader.py
│ │ ├── model_utils
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── checkpointing.py
│ │ │ ├── embedding.py
│ │ │ ├── kv_cache.py
│ │ │ ├── liger_kernel.py
│ │ │ ├── longlora.py
│ │ │ ├── misc.py
│ │ │ ├── mod.py
│ │ │ ├── moe.py
│ │ │ ├── packing.py
│ │ │ ├── quantization.py
│ │ │ ├── rope.py
│ │ │ ├── unsloth.py
│ │ │ ├── valuehead.py
│ │ │ └── visual.py
│ │ └── patcher.py
│ ├── third_party
│ │ ├── __init__.py
│ │ └── muon
│ │ │ ├── __init__.py
│ │ │ └── muon.py
│ ├── train
│ │ ├── __init__.py
│ │ ├── callbacks.py
│ │ ├── dpo
│ │ │ ├── __init__.py
│ │ │ ├── trainer.py
│ │ │ └── workflow.py
│ │ ├── kto
│ │ │ ├── __init__.py
│ │ │ ├── trainer.py
│ │ │ └── workflow.py
│ │ ├── ppo
│ │ │ ├── __init__.py
│ │ │ ├── ppo_utils.py
│ │ │ ├── trainer.py
│ │ │ └── workflow.py
│ │ ├── pt
│ │ │ ├── __init__.py
│ │ │ ├── trainer.py
│ │ │ └── workflow.py
│ │ ├── rm
│ │ │ ├── __init__.py
│ │ │ ├── metric.py
│ │ │ ├── trainer.py
│ │ │ └── workflow.py
│ │ ├── sft
│ │ │ ├── __init__.py
│ │ │ ├── metric.py
│ │ │ ├── trainer.py
│ │ │ └── workflow.py
│ │ ├── test_utils.py
│ │ ├── trainer_utils.py
│ │ └── tuner.py
│ └── webui
│ │ ├── __init__.py
│ │ ├── chatter.py
│ │ ├── common.py
│ │ ├── components
│ │ ├── __init__.py
│ │ ├── chatbot.py
│ │ ├── data.py
│ │ ├── eval.py
│ │ ├── export.py
│ │ ├── infer.py
│ │ ├── top.py
│ │ └── train.py
│ │ ├── control.py
│ │ ├── css.py
│ │ ├── engine.py
│ │ ├── interface.py
│ │ ├── locales.py
│ │ ├── manager.py
│ │ └── runner.py
│ ├── train.py
│ └── webui.py
├── README.md
├── assets
├── data_composition.png
├── exp_main.png
├── overview.png
├── reward.png
├── scability.png
├── sft_or_rl.png
└── thinking_template.png
├── data_process
├── distill_data.py
├── hammer_utils.py
└── raw_data_process.py
├── eval
└── bfcl_handler.py
├── llamafact_merge.sh
├── paper.pdf
├── qwen_rl.sh
├── qwen_sft.sh
├── verl
├── LICENSE
├── examples
│ ├── agent
│ │ └── qwen.sh
│ └── data_preprocess
│ │ └── verl_toolcall_preprocess.py
├── setup.py
└── verl
│ ├── __init__.py
│ ├── data
│ ├── test.parquet
│ └── train.parquet
│ ├── models
│ ├── README.md
│ ├── __init__.py
│ ├── llama
│ │ ├── __init__.py
│ │ └── megatron
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_utils
│ │ │ ├── __init__.py
│ │ │ ├── llama_loader.py
│ │ │ └── llama_saver.py
│ │ │ ├── layers
│ │ │ ├── __init__.py
│ │ │ ├── parallel_attention.py
│ │ │ ├── parallel_decoder.py
│ │ │ ├── parallel_linear.py
│ │ │ ├── parallel_mlp.py
│ │ │ └── parallel_rmsnorm.py
│ │ │ └── modeling_llama_megatron.py
│ ├── qwen2
│ │ ├── __init__.py
│ │ └── megatron
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_utils
│ │ │ ├── __init__.py
│ │ │ ├── qwen2_loader.py
│ │ │ └── qwen2_saver.py
│ │ │ ├── layers
│ │ │ ├── __init__.py
│ │ │ ├── parallel_attention.py
│ │ │ ├── parallel_decoder.py
│ │ │ ├── parallel_linear.py
│ │ │ ├── parallel_mlp.py
│ │ │ └── parallel_rmsnorm.py
│ │ │ └── modeling_qwen2_megatron.py
│ ├── registry.py
│ ├── transformers
│ │ ├── __init__.py
│ │ ├── llama.py
│ │ ├── monkey_patch.py
│ │ ├── qwen2.py
│ │ └── qwen2_vl.py
│ └── weight_loader_registry.py
│ ├── protocol.py
│ ├── single_controller
│ ├── __init__.py
│ ├── base
│ │ ├── __init__.py
│ │ ├── decorator.py
│ │ ├── megatron
│ │ │ ├── __init__.py
│ │ │ ├── worker.py
│ │ │ └── worker_group.py
│ │ ├── register_center
│ │ │ ├── __init__.py
│ │ │ └── ray.py
│ │ ├── worker.py
│ │ └── worker_group.py
│ └── ray
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── megatron.py
│ ├── third_party
│ ├── __init__.py
│ └── vllm
│ │ ├── __init__.py
│ │ ├── vllm_spmd
│ │ ├── __init__.py
│ │ └── dtensor_weight_loaders.py
│ │ ├── vllm_v_0_3_1
│ │ ├── __init__.py
│ │ ├── arg_utils.py
│ │ ├── config.py
│ │ ├── llm.py
│ │ ├── llm_engine_sp.py
│ │ ├── model_loader.py
│ │ ├── model_runner.py
│ │ ├── parallel_state.py
│ │ ├── tokenizer.py
│ │ ├── weight_loaders.py
│ │ └── worker.py
│ │ ├── vllm_v_0_4_2
│ │ ├── __init__.py
│ │ ├── arg_utils.py
│ │ ├── config.py
│ │ ├── dtensor_weight_loaders.py
│ │ ├── hf_weight_loader.py
│ │ ├── llm.py
│ │ ├── llm_engine_sp.py
│ │ ├── megatron_weight_loaders.py
│ │ ├── model_loader.py
│ │ ├── model_runner.py
│ │ ├── parallel_state.py
│ │ ├── spmd_gpu_executor.py
│ │ ├── tokenizer.py
│ │ └── worker.py
│ │ ├── vllm_v_0_5_4
│ │ ├── __init__.py
│ │ ├── arg_utils.py
│ │ ├── config.py
│ │ ├── dtensor_weight_loaders.py
│ │ ├── hf_weight_loader.py
│ │ ├── llm.py
│ │ ├── llm_engine_sp.py
│ │ ├── megatron_weight_loaders.py
│ │ ├── model_loader.py
│ │ ├── model_runner.py
│ │ ├── parallel_state.py
│ │ ├── spmd_gpu_executor.py
│ │ ├── tokenizer.py
│ │ └── worker.py
│ │ └── vllm_v_0_6_3
│ │ ├── __init__.py
│ │ ├── arg_utils.py
│ │ ├── config.py
│ │ ├── dtensor_weight_loaders.py
│ │ ├── hf_weight_loader.py
│ │ ├── llm.py
│ │ ├── llm_engine_sp.py
│ │ ├── megatron_weight_loaders.py
│ │ ├── model_loader.py
│ │ ├── model_runner.py
│ │ ├── parallel_state.py
│ │ ├── spmd_gpu_executor.py
│ │ ├── tokenizer.py
│ │ └── worker.py
│ ├── trainer
│ ├── __init__.py
│ ├── config
│ │ ├── evaluation.yaml
│ │ ├── generation.yaml
│ │ ├── ppo_megatron_trainer.yaml
│ │ ├── ppo_trainer.yaml
│ │ └── sft_trainer.yaml
│ ├── fsdp_sft_trainer.py
│ ├── main_eval.py
│ ├── main_generation.py
│ ├── main_ppo.py
│ ├── ppo
│ │ ├── __init__.py
│ │ ├── core_algos.py
│ │ └── ray_trainer.py
│ └── runtime_env.yaml
│ ├── utils
│ ├── __init__.py
│ ├── checkpoint
│ │ ├── __init__.py
│ │ ├── checkpoint_manager.py
│ │ └── fsdp_checkpoint_manager.py
│ ├── config.py
│ ├── dataset
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── rl_dataset.py
│ │ ├── rm_dataset.py
│ │ └── sft_dataset.py
│ ├── debug
│ │ ├── __init__.py
│ │ ├── performance.py
│ │ └── trajectory_tracker.py
│ ├── distributed.py
│ ├── flops_counter.py
│ ├── fs.py
│ ├── fsdp_utils.py
│ ├── hdfs_io.py
│ ├── import_utils.py
│ ├── logger
│ │ ├── __init__.py
│ │ └── aggregate_logger.py
│ ├── logging_utils.py
│ ├── megatron
│ │ ├── __init__.py
│ │ ├── memory.py
│ │ ├── optimizer.py
│ │ ├── pipeline_parallel.py
│ │ ├── sequence_parallel.py
│ │ └── tensor_parallel.py
│ ├── megatron_utils.py
│ ├── memory_buffer.py
│ ├── model.py
│ ├── py_functional.py
│ ├── ray_utils.py
│ ├── rendezvous
│ │ ├── __init__.py
│ │ └── ray_backend.py
│ ├── reward_score
│ │ ├── __init__.py
│ │ ├── geo3k.py
│ │ ├── gsm8k.py
│ │ ├── math.py
│ │ ├── prime_code
│ │ │ ├── __init__.py
│ │ │ ├── testing_util.py
│ │ │ └── utils.py
│ │ ├── prime_math
│ │ │ ├── __init__.py
│ │ │ ├── grader.py
│ │ │ └── math_normalize.py
│ │ └── toolcall.py
│ ├── seqlen_balancing.py
│ ├── tokenizer.py
│ ├── tool_utils.py
│ ├── torch_dtypes.py
│ ├── torch_functional.py
│ ├── tracking.py
│ └── ulysses.py
│ ├── version
│ └── version
│ └── workers
│ ├── __init__.py
│ ├── actor
│ ├── __init__.py
│ ├── base.py
│ ├── dp_actor.py
│ └── megatron_actor.py
│ ├── critic
│ ├── __init__.py
│ ├── base.py
│ ├── dp_critic.py
│ └── megatron_critic.py
│ ├── fsdp_workers.py
│ ├── megatron_workers.py
│ ├── reward_manager
│ ├── __init__.py
│ ├── naive.py
│ └── prime.py
│ ├── reward_model
│ ├── __init__.py
│ ├── base.py
│ └── megatron
│ │ ├── __init__.py
│ │ └── reward_model.py
│ ├── rollout
│ ├── __init__.py
│ ├── base.py
│ ├── hf_rollout.py
│ ├── naive
│ │ ├── __init__.py
│ │ └── naive_rollout.py
│ ├── tokenizer.py
│ └── vllm_rollout
│ │ ├── __init__.py
│ │ ├── fire_vllm_rollout.py
│ │ ├── vllm_rollout.py
│ │ └── vllm_rollout_spmd.py
│ └── sharding_manager
│ ├── __init__.py
│ ├── base.py
│ ├── fsdp_ulysses.py
│ ├── fsdp_vllm.py
│ └── megatron_vllm.py
└── verl_convert.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | LLaMA-Factory/data/tool_sft.json filter=lfs diff=lfs merge=lfs -text
2 | verl/verl/data/test.parquet filter=lfs diff=lfs merge=lfs -text
3 | verl/verl/data/train.parquet filter=lfs diff=lfs merge=lfs -text
4 |
--------------------------------------------------------------------------------
/LLaMA-Factory/.env.local:
--------------------------------------------------------------------------------
1 | # Note: .env files are not actually supported; this file is for reference only
2 | # api
3 | API_HOST=
4 | API_PORT=
5 | API_KEY=
6 | API_MODEL_NAME=
7 | API_VERBOSE=
8 | FASTAPI_ROOT_PATH=
9 | MAX_CONCURRENT=
10 | # general
11 | DISABLE_VERSION_CHECK=
12 | FORCE_CHECK_IMPORTS=
13 | ALLOW_EXTRA_ARGS=
14 | LLAMAFACTORY_VERBOSITY=
15 | USE_MODELSCOPE_HUB=
16 | USE_OPENMIND_HUB=
17 | USE_RAY=
18 | RECORD_VRAM=
19 | OPTIM_TORCH=
20 | NPU_JIT_COMPILE=
21 | # torchrun
22 | FORCE_TORCHRUN=
23 | MASTER_ADDR=
24 | MASTER_PORT=
25 | NNODES=
26 | NODE_RANK=
27 | NPROC_PER_NODE=
28 | # wandb
29 | WANDB_DISABLED=
30 | WANDB_PROJECT=
31 | WANDB_API_KEY=
32 | # gradio ui
33 | GRADIO_SHARE=
34 | GRADIO_SERVER_NAME=
35 | GRADIO_SERVER_PORT=
36 | GRADIO_ROOT_PATH=
37 | GRADIO_IPV6=
38 | # setup
39 | ENABLE_SHORT_CONSOLE=
40 | # reserved (do not use)
41 | LLAMABOARD_ENABLED=
42 | LLAMABOARD_WORKDIR=
43 |
--------------------------------------------------------------------------------
/LLaMA-Factory/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/LLaMA-Factory/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE requirements.txt
2 |
--------------------------------------------------------------------------------
/LLaMA-Factory/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: build commit license quality style test
2 |
3 | check_dirs := scripts src tests setup.py
4 |
5 | build:
6 | pip3 install build && python3 -m build
7 |
8 | commit:
9 | pre-commit install
10 | pre-commit run --all-files
11 |
12 | license:
13 | python3 tests/check_license.py $(check_dirs)
14 |
15 | quality:
16 | ruff check $(check_dirs)
17 | ruff format --check $(check_dirs)
18 |
19 | style:
20 | ruff check $(check_dirs) --fix
21 | ruff format $(check_dirs)
22 |
23 | test:
24 | CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
25 |
--------------------------------------------------------------------------------
/LLaMA-Factory/data/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "tool_sft": {
3 | "file_name": "tool_sft.json",
4 | "columns": {
5 | "prompt": "instruction",
6 | "query": "query",
7 | "response": "output",
8 | "system": "system_prompt"
9 | }
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/LLaMA-Factory/data/tool_sft.json:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:836c6592b03ee881ca3ba4934a25e69b14a909567f062e91da3aef6ac146c7fb
3 | size 24144869
4 |
--------------------------------------------------------------------------------
/LLaMA-Factory/evaluation/ceval/ceval.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/evaluation/ceval/ceval.zip
--------------------------------------------------------------------------------
/LLaMA-Factory/evaluation/cmmlu/cmmlu.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/evaluation/cmmlu/cmmlu.zip
--------------------------------------------------------------------------------
/LLaMA-Factory/evaluation/mmlu/mmlu.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/evaluation/mmlu/mmlu.zip
--------------------------------------------------------------------------------
/LLaMA-Factory/examples/qwen_merge.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: path/to/model/Qwen2.5-7B-Instruct
3 | adapter_name_or_path: path/to/adapter/qwen-7b
4 | template: qwen
5 | finetuning_type: lora
6 |
7 | ### export
8 | export_dir: path/to/export
9 | export_size: 2
10 | export_device: cpu
11 | export_legacy_format: false
--------------------------------------------------------------------------------
/LLaMA-Factory/examples/qwen_sft.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: path/to/model/Qwen2.5-7B-Instruct
3 | trust_remote_code: true
4 |
5 | ### method
6 | stage: sft
7 | do_train: true
8 | finetuning_type: lora
9 | lora_rank: 8
10 | lora_target: all
11 |
12 | ### dataset
13 | dataset: tool_sft
14 | template: qwen
15 | cutoff_len: 2048
16 | overwrite_cache: true
17 | preprocessing_num_workers: 16
18 | dataloader_num_workers: 4
19 |
20 | ### output
21 | output_dir: saves/qwen-7b/lora/sft
22 | logging_steps: 10
23 | save_steps: 10
24 | plot_loss: true
25 | overwrite_output_dir: true
26 | save_only_model: false
27 | report_to: wandb # choices: [none, wandb, tensorboard, swanlab, mlflow]
28 |
29 | ### train
30 | per_device_train_batch_size: 1
31 | gradient_accumulation_steps: 8
32 | learning_rate: 1.0e-4
33 | num_train_epochs: 5.0
34 | lr_scheduler_type: cosine
35 | warmup_ratio: 0.1
36 | bf16: true
37 | ddp_timeout: 180000000
38 | resume_from_checkpoint: null
--------------------------------------------------------------------------------
/LLaMA-Factory/get_raw_sft_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 |
18 | import json
19 | import os
20 |
# Phrase in the original system prompt after which the verbose tool-call
# instructions begin; that phrase and everything after it is replaced.
_INSTRUCTION_SPLIT_MARKER = "In each action step, you MUST"

# Replacement tool-call formatting instruction appended to each system prompt.
# NOTE(review): "bug banana" reads like a typo for "buy banana", and the tag
# names before "XML tags" / the values after "name": and "arguments": look
# empty (possibly stripped during export). Kept verbatim because this is
# training-prompt text; confirm against the prompt used elsewhere.
_NEW_INSTRUCTION = '''When you want to perform tool call, in each action step, providing a json object with function names and arguments within XML tags. i.e., [{"name": , "arguments": }, {"name": , "arguments": }, ...]
A complete reply example is: [{"name": "email", "arguments": {"receiver": "Bob", "content": "I will bug banana through walmart"}}, {"name": "walmart", "arguments": {"input": "banana"}}]. Please make sure the type of the arguments is correct.'''


def process_json(input_path, output_path, think_open="<think>", think_close="</think>"):
    """Rewrite an SFT JSON dataset into its "raw" variant without reasoning traces.

    For every entry of the top-level JSON list:

    * ``system_prompt``: if it contains "In each action step, you MUST", that
      phrase and everything after it are replaced by ``_NEW_INSTRUCTION``.
    * ``output``: if it contains both ``think_open`` and ``think_close``, the text
      up to and including the first ``think_close`` is removed, together with
      any newlines that immediately follow it.

    Args:
        input_path: Path to a UTF-8 JSON file containing a list of entry dicts.
        output_path: Destination JSON path; its parent directory is created if needed.
        think_open: Opening marker of the reasoning block.
        think_close: Closing marker of the reasoning block.
    """
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for entry in data:
        if "system_prompt" in entry and _INSTRUCTION_SPLIT_MARKER in entry["system_prompt"]:
            prefix = entry["system_prompt"].partition(_INSTRUCTION_SPLIT_MARKER)[0]
            entry["system_prompt"] = prefix + _NEW_INSTRUCTION

        if "output" in entry:
            output = entry["output"]
            start = output.find(think_open)
            close = output.find(think_close)
            # Test the raw find() results: adding len(think_close) first would
            # turn a missing closing tag (-1) into a bogus positive offset.
            if start != -1 and close != -1:
                entry["output"] = output[close + len(think_close):].lstrip("\n")

    out_dir = os.path.dirname(output_path)
    if out_dir:  # os.makedirs("") raises for bare file names
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)


if __name__ == "__main__":
    process_json("path/to/tool_sft.json", "path/to/data/raw_tool_sft.json")
44 |
--------------------------------------------------------------------------------
/LLaMA-Factory/lam_data_process.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 |
17 |
18 | import json
19 |
def truncate_after_first_tool_call_end(s, marker="</tool_call>"):
    """Cut ``s`` immediately after the first occurrence of ``marker``.

    The marker itself is kept; everything after it is dropped. If ``marker``
    does not occur in ``s``, or is empty, ``s`` is returned unchanged.

    An empty marker must be rejected explicitly: ``s.find("")`` is 0, which
    would slice every string down to ``""``.
    """
    if not marker:
        return s
    index = s.find(marker)
    if index == -1:
        return s
    return s[: index + len(marker)]
26 |
def truncate_before_expert_prompt(s):
    """Drop any preamble preceding the expert function-composer prompt.

    Returns the suffix of ``s`` that starts at the first occurrence of
    "You are an expert in composing functions", or ``s`` itself when the
    phrase is absent.
    """
    anchor = "You are an expert in composing functions"
    pos = s.find(anchor)
    return s if pos == -1 else s[pos:]
33 |
# Default file locations (placeholders; point these at real data before running).
input_file = "path/to/distilled_data/sft_data.json"
output_file = "path/to/data/tool_sft.json"


def main(input_path=input_file, output_path=output_file):
    """Convert distilled SFT records into the Alpaca-style layout used for training.

    Each input record yields a dict with:

    * ``system_prompt``: trimmed to start at the expert-composer prompt.
    * ``query``: copied as-is.
    * ``output``: cut after the first closing tool-call tag.
    * ``instruction``: always an empty string.

    Missing fields default to "".

    Args:
        input_path: UTF-8 JSON file holding a list of distilled records.
        output_path: Destination JSON file; its parent directory is created if needed.
    """
    import os

    with open(input_path, "r", encoding="utf-8") as f:
        input_data = json.load(f)

    # Transform the data to match the Alpaca format
    alpaca_data = [
        {
            "system_prompt": truncate_before_expert_prompt(item.get("system_prompt", "")),
            "query": item.get("query", ""),
            "output": truncate_after_first_tool_call_end(item.get("output", "")),
            "instruction": "",
        }
        for item in input_data
    ]

    out_dir = os.path.dirname(output_path)
    if out_dir:  # os.makedirs("") raises for bare file names
        os.makedirs(out_dir, exist_ok=True)

    # Save the transformed data to the output file
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(alpaca_data, f, ensure_ascii=False, indent=2)

    print(f"Data has been successfully transformed and saved to {output_path}")


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/LLaMA-Factory/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "llamafactory"
7 | dynamic = [
8 | "version",
9 | "dependencies",
10 | "optional-dependencies",
11 | "requires-python",
12 | "scripts",
13 | "authors",
14 | "description",
15 | "readme",
16 | "license",
17 | "keywords",
18 | "classifiers"
19 | ]
20 |
21 | [tool.ruff]
22 | target-version = "py39"
23 | line-length = 119
24 | indent-width = 4
25 |
26 | [tool.ruff.lint]
27 | ignore = [
28 | "C408", # collection
29 | "C901", # complex
30 | "E501", # line too long
31 | "E731", # lambda function
32 | "E741", # ambiguous var name
33 | "D100", # no doc public module
34 | "D101", # no doc public class
35 | "D102", # no doc public method
36 | "D103", # no doc public function
37 | "D104", # no doc public package
38 | "D105", # no doc magic method
39 | "D107", # no doc __init__
40 | ]
41 | extend-select = [
42 | "C", # complexity
43 | "E", # error
44 | "F", # pyflakes
45 | "I", # isort
46 | "W", # warning
47 | "UP", # pyupgrade
48 | "D", # pydocstyle
49 | "PT009", # pytest assert
50 | "RUF022", # sort __all__
51 | ]
52 |
53 | [tool.ruff.lint.isort]
54 | lines-after-imports = 2
55 | known-first-party = ["llamafactory"]
56 | known-third-party = [
57 | "accelerate",
58 | "datasets",
59 | "gradio",
60 | "numpy",
61 | "peft",
62 | "torch",
63 | "transformers",
64 | "trl",
65 | ]
66 |
67 | [tool.ruff.lint.pydocstyle]
68 | convention = "google"
69 |
70 | [tool.ruff.format]
71 | quote-style = "double"
72 | indent-style = "space"
73 | docstring-code-format = true
74 | skip-magic-trailing-comma = false
75 | line-ending = "auto"
76 |
77 | [tool.uv]
78 | conflicts = [
79 | [
80 | { extra = "torch-npu" },
81 | { extra = "aqlm" },
82 | ],
83 | [
84 | { extra = "torch-npu" },
85 | { extra = "liger-kernel" },
86 | ],
87 | [
88 | { extra = "torch-npu" },
89 | { extra = "vllm" },
90 | ],
91 | [
92 | { extra = "sglang" },
93 | { extra = "minicpm_v" },
94 | ],
95 | ]
96 |
--------------------------------------------------------------------------------
/LLaMA-Factory/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
2 | datasets>=2.16.0,<=3.5.0
3 | accelerate>=0.34.0,<=1.6.0
4 | peft>=0.14.0,<=0.15.1
5 | trl>=0.8.6,<=0.9.6
6 | tokenizers>=0.19.0,<=0.21.1
7 | gradio>=4.38.0,<=5.25.0
8 | scipy
9 | einops
10 | sentencepiece
11 | tiktoken
12 | protobuf
13 | uvicorn
14 | fastapi
15 | sse-starlette
16 | matplotlib>=3.7.0
17 | fire
18 | omegaconf
19 | packaging
20 | pyyaml
21 | numpy<2.0.0
22 | pydantic<=2.10.6
23 | pandas>=2.0.0
24 | av
25 | librosa
26 | tyro<0.9.0
27 |
--------------------------------------------------------------------------------
/LLaMA-Factory/scripts/api_example/test_image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | from openai import OpenAI
18 | from transformers.utils.versions import require_version
19 |
20 |
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")


def main():
    """Run a two-turn image conversation against a local OpenAI-compatible API.

    The API key comes from ``API_KEY`` (default "0") and the port from
    ``API_PORT`` (default 8000). Each turn sends a user message made of a text
    prompt plus an image URL to the model named "test". The assistant reply is
    appended to the running history and printed.
    """
    key = os.getenv("API_KEY", "0")
    port = os.getenv("API_PORT", 8000)
    client = OpenAI(api_key=f"{key}", base_url=f"http://localhost:{port}/v1")

    turns = [
        # Example reply: "The image shows a pyramid of colored blocks with numbers on them. ..."
        ("Output the color and number of each box.", "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"),
        # Example reply: "The image shows a cluster of forget-me-not flowers. ..."
        ("What kind of flower is this?", "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"),
    ]

    history = []
    for round_no, (prompt, image_url) in enumerate(turns, start=1):
        history.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        )
        reply = client.chat.completions.create(messages=history, model="test").choices[0].message
        history.append(reply)
        print(f"Round {round_no}:", reply.content)


if __name__ == "__main__":
    main()
66 |
--------------------------------------------------------------------------------
/LLaMA-Factory/scripts/convert_ckpt/tiny_llama4.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from transformers import Llama4Config, Llama4ForConditionalGeneration, Llama4TextConfig, Llama4VisionConfig
16 |
17 |
def _tiny_vision_config():
    """Return a 4-layer Llama4 vision config (hidden 1408, 336px images)."""
    return Llama4VisionConfig(
        hidden_size=1408,
        image_size=336,
        intermediate_size=5632,
        num_attention_heads=16,
        num_hidden_layers=4,
        vision_output_dim=4096,
    )


def _tiny_text_config():
    """Return a 4-layer Llama4 text config (hidden 512, 8 heads, 2 KV heads, 2 experts)."""
    return Llama4TextConfig(
        hidden_size=512,
        intermediate_size=1024,
        intermediate_size_mlp=1024,
        num_hidden_layers=4,
        num_attention_heads=8,
        num_key_value_heads=2,
        head_dim=512 // 8,
        num_local_experts=2,
    )


if __name__ == "__main__":
    # Build a randomly-initialised tiny Llama4 model and save it to ./tiny-llama4.
    config = Llama4Config(vision_config=_tiny_vision_config(), text_config=_tiny_text_config())
    model = Llama4ForConditionalGeneration._from_config(config)
    model.save_pretrained("tiny-llama4")
40 |
--------------------------------------------------------------------------------
/LLaMA-Factory/scripts/eval_bleu_rouge.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 | import logging
17 | import time
18 |
19 | import fire
20 | from datasets import load_dataset
21 |
22 |
23 | try:
24 | import jieba # type: ignore
25 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu # type: ignore
26 | from rouge_chinese import Rouge # type: ignore
27 |
28 | jieba.setLogLevel(logging.CRITICAL)
29 | jieba.initialize()
30 | except ImportError:
31 | print("Please install llamafactory with `pip install -e .[metrics]`.")
32 | raise
33 |
34 |
def compute_metrics(sample):
    """Score one prediction/label pair with ROUGE-1/2/L F1 and BLEU-4.

    ROUGE is computed on jieba word segments joined by spaces. BLEU is computed
    on raw characters using NLTK smoothing method 3. All scores are scaled to
    0-100 and rounded to 4 decimals.

    Args:
        sample: Mapping with string fields ``"predict"`` and ``"label"``.

    Returns:
        Dict with the ROUGE keys produced by ``Rouge.get_scores`` (or the zero
        fallback ``rouge-1``/``rouge-2``/``rouge-l``), followed by ``bleu-4``.
    """
    predict, label = sample["predict"], sample["label"]
    hyp_text = " ".join(jieba.cut(predict))
    ref_text = " ".join(jieba.cut(label))

    bleu = sentence_bleu(
        [list(label)],
        list(predict),
        smoothing_function=SmoothingFunction().method3,
    )

    if hyp_text.split() and ref_text.split():
        rouge_scores = Rouge().get_scores(hyp_text, ref_text)[0]
    else:
        # Empty token sequences are not passed to Rouge; report zero F1 instead.
        rouge_scores = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}

    metrics = {name: round(score["f"] * 100, 4) for name, score in rouge_scores.items()}
    metrics["bleu-4"] = round(bleu * 100, 4)
    return metrics
59 |
60 |
def main(filename: str):
    """Average per-sample BLEU/ROUGE scores over a JSON predictions file.

    Each record in ``filename`` is scored with ``compute_metrics`` (8 worker
    processes). The per-metric means are printed in sorted metric-name order
    and written to ``predictions_score.json`` in the current directory.
    """
    t0 = time.time()
    records = load_dataset("json", data_files=filename, split="train")
    per_sample = records.map(compute_metrics, num_proc=8, remove_columns=records.column_names).to_dict()

    averages = {}
    for metric in sorted(per_sample):
        values = per_sample[metric]
        mean = sum(values) / len(values)
        print(f"{metric}: {mean:.4f}")
        averages[metric] = mean

    with open("predictions_score.json", "w", encoding="utf-8") as out:
        json.dump(averages, out, indent=4)

    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.3f}s.\nScore file saved to predictions_score.json")


if __name__ == "__main__":
    fire.Fire(main)
80 |
--------------------------------------------------------------------------------
/LLaMA-Factory/scripts/stat_utils/cal_flops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Microsoft Corporation and the LlamaFactory team.
2 | #
3 | # This code is inspired by the Microsoft's DeepSpeed library.
4 | # https://www.deepspeed.ai/tutorials/flops-profiler/
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import fire
19 | import torch
20 | from deepspeed.accelerator import get_accelerator # type: ignore
21 | from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
22 |
23 | from llamafactory.chat import ChatModel
24 |
25 |
def calculate_flops(
    model_name_or_path: str,
    batch_size: int = 1,
    seq_length: int = 512,
    flash_attn: str = "auto",
):
    r"""Profile FLOPs, MACs and parameter count of a pre-trained model.

    The model is loaded on accelerator device 0 and run once on an all-ones
    ``(batch_size, seq_length)`` token batch, with the same tensor as labels.
    The DeepSpeed profile and the summary numbers are printed.

    Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
    """
    with get_accelerator().device(0):
        chat_model = ChatModel({"model_name_or_path": model_name_or_path, "template": "empty", "flash_attn": flash_attn})
        model = chat_model.engine.model
        dummy_ids = torch.ones((batch_size, seq_length), dtype=torch.long, device=model.device)
        flops, macs, params = get_model_profile(
            model,
            kwargs={"input_ids": dummy_ids, "labels": dummy_ids.clone()},
            print_profile=True,
            detailed=True,
        )
        for label, value in (("FLOPs:", flops), ("MACs:", macs), ("Params:", params)):
            print(label, value)


if __name__ == "__main__":
    fire.Fire(calculate_flops)
50 |
--------------------------------------------------------------------------------
/LLaMA-Factory/scripts/stat_utils/length_cdf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from collections import defaultdict
16 |
17 | import fire
18 | from tqdm import tqdm
19 |
20 | from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
21 | from llamafactory.hparams import get_train_args
22 | from llamafactory.model import load_tokenizer
23 |
24 |
def length_cdf(
    model_name_or_path: str,
    dataset: str = "alpaca_en_demo",
    dataset_dir: str = "data",
    template: str = "default",
    interval: int = 1000,
):
    r"""Calculate the distribution of the input lengths in the dataset.

    The SFT training set is tokenized, and the ``input_ids`` lengths are grouped into
    buckets of width ``interval``. For each bucket, in ascending order, the cumulative
    number and percentage of samples shorter than the bucket's upper bound is printed.

    Usage: export CUDA_VISIBLE_DEVICES=0
    python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default

    Raises:
        ValueError: If ``interval`` is not positive. This is checked before the tokenizer
            and dataset are loaded, so a bad value fails fast instead of after a full
            (possibly long) preprocessing pass.
    """
    # A zero bucket width would only surface as a ZeroDivisionError after tokenizing the
    # whole dataset; a negative one would silently produce meaningless buckets.
    if interval <= 0:
        raise ValueError(f"`interval` must be a positive integer, got {interval}.")

    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=1_000_000,  # very large cutoff, presumably so samples are not truncated
            preprocessing_num_workers=16,
            output_dir="dummy_dir",
            overwrite_cache=True,
            do_train=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    # Keep the `template` argument (a template name) distinct from the resolved Template object.
    template_obj = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
    trainset = get_dataset(template_obj, model_args, data_args, training_args, "sft", **tokenizer_module)[
        "train_dataset"
    ]
    total_num = len(trainset)

    length_dict = defaultdict(int)
    for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"):
        # Bucket key is the lower bound of the bucket the length falls into.
        length_dict[len(sample) // interval * interval] += 1

    count_accu = 0
    for length, count in sorted(length_dict.items()):
        count_accu += count
        # Derive the percentage from the running count rather than summing per-bucket
        # floats, so rounding error cannot accumulate across buckets.
        prob_accu = count_accu / total_num * 100
        print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")


if __name__ == "__main__":
    fire.Fire(length_cdf)
70 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/api.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import uvicorn
18 |
19 | from llamafactory.api.app import create_app
20 | from llamafactory.chat import ChatModel
21 |
22 |
def main():
    r"""Launch the LlamaFactory API server.

    A ``ChatModel`` is built and wrapped by ``create_app``, and the resulting app is
    served with uvicorn. The bind address and port are read from the ``API_HOST``
    (default ``0.0.0.0``) and ``API_PORT`` (default ``8000``) environment variables.
    """
    application = create_app(ChatModel())
    host = os.environ.get("API_HOST", "0.0.0.0")
    port = int(os.environ.get("API_PORT", "8000"))
    print(f"Visit http://localhost:{port}/docs for API document.")
    uvicorn.run(application, host=host, port=port)


if __name__ == "__main__":
    main()
34 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory.egg-info/entry_points.txt:
--------------------------------------------------------------------------------
1 | [console_scripts]
2 | llamafactory-cli = llamafactory.cli:main
3 | lmf = llamafactory.cli:main
4 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | transformers!=4.46.*,!=4.47.*,!=4.48.0,<=4.51.3,>=4.45.0
2 | datasets<=3.5.0,>=2.16.0
3 | accelerate<=1.6.0,>=0.34.0
4 | peft<=0.15.1,>=0.14.0
5 | trl<=0.9.6,>=0.8.6
6 | tokenizers<=0.21.1,>=0.19.0
7 | gradio<=5.25.0,>=4.38.0
8 | scipy
9 | einops
10 | sentencepiece
11 | tiktoken
12 | protobuf
13 | uvicorn
14 | fastapi
15 | sse-starlette
16 | matplotlib>=3.7.0
17 | fire
18 | omegaconf
19 | packaging
20 | pyyaml
21 | numpy<2.0.0
22 | pydantic<=2.10.6
23 | pandas>=2.0.0
24 | av
25 | librosa
26 | tyro<0.9.0
27 |
28 | [adam-mini]
29 | adam-mini
30 |
31 | [apollo]
32 | apollo-torch
33 |
34 | [aqlm]
35 | aqlm[gpu]>=1.1.0
36 |
37 | [awq]
38 | autoawq
39 |
40 | [badam]
41 | badam>=1.2.1
42 |
43 | [bitsandbytes]
44 | bitsandbytes>=0.39.0
45 |
46 | [deepspeed]
47 | deepspeed<=0.16.5,>=0.10.0
48 |
49 | [dev]
50 | pre-commit
51 | ruff
52 | pytest
53 | build
54 |
55 | [eetq]
56 | eetq
57 |
58 | [galore]
59 | galore-torch
60 |
61 | [gptq]
62 | optimum>=1.17.0
63 | auto-gptq>=0.5.0
64 |
65 | [hqq]
66 | hqq
67 |
68 | [liger-kernel]
69 | liger-kernel>=0.5.5
70 |
71 | [metrics]
72 | nltk
73 | jieba
74 | rouge-chinese
75 |
76 | [minicpm_v]
77 | soundfile
78 | torchvision
79 | torchaudio
80 | vector_quantize_pytorch
81 | vocos
82 | msgpack
83 | referencing
84 | jsonschema_specifications
85 | transformers==4.48.3
86 |
87 | [modelscope]
88 | modelscope
89 |
90 | [openmind]
91 | openmind
92 |
93 | [qwen]
94 | transformers_stream_generator
95 |
96 | [sglang]
97 | sglang[srt]>=0.4.5
98 | transformers==4.51.1
99 |
100 | [swanlab]
101 | swanlab
102 |
103 | [torch]
104 | torch>=1.13.1
105 |
106 | [torch-npu]
107 | torch==2.4.0
108 | torch-npu==2.4.0.post2
109 | decorator
110 |
111 | [vllm]
112 | vllm<=0.8.4,>=0.4.3
113 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | llamafactory
2 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Efficient fine-tuning of large language models.
16 |
17 | Level:
18 | api, webui > chat, eval, train > data, model > hparams > extras
19 |
20 | Disable version checking: DISABLE_VERSION_CHECK=1
21 | Enable VRAM recording: RECORD_VRAM=1
22 | Force using torchrun: FORCE_TORCHRUN=1
23 | Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
24 | Use modelscope: USE_MODELSCOPE_HUB=1
25 | Use openmind: USE_OPENMIND_HUB=1
26 | """
27 |
28 | from .extras.env import VERSION
29 |
30 |
31 | __version__ = VERSION
32 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/api/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/api/__init__.py
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/api/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 | from typing import TYPE_CHECKING, Any
17 |
18 |
19 | if TYPE_CHECKING:
20 | from pydantic import BaseModel
21 |
22 |
def dictify(data: "BaseModel") -> dict[str, Any]:
    r"""Convert a pydantic model into a plain dict containing only explicitly-set fields.

    Works with both pydantic v2 (``model_dump``) and pydantic v1 (``dict``).
    """
    # Feature-detect the v2 API instead of wrapping the call in `except AttributeError`:
    # an AttributeError raised *inside* serialization (e.g. by a property) must propagate,
    # not be mistaken for "this is pydantic v1" and silently retried with the v1 API.
    if hasattr(data, "model_dump"):  # pydantic v2
        return data.model_dump(exclude_unset=True)

    return data.dict(exclude_unset=True)  # pydantic v1


def jsonify(data: "BaseModel") -> str:
    r"""Serialize a pydantic model to a JSON string.

    Only explicitly-set fields are included, and non-ASCII characters are emitted as-is
    (``ensure_ascii=False``). Works with both pydantic v2 and v1.
    """
    if hasattr(data, "model_dump"):  # pydantic v2
        # mode="json" converts values such as datetime/UUID/Enum to JSON-compatible types,
        # matching v1's `.json()` path; a python-mode dump could make json.dumps raise.
        return json.dumps(data.model_dump(mode="json", exclude_unset=True), ensure_ascii=False)

    return data.json(exclude_unset=True, ensure_ascii=False)  # pydantic v1
35 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/chat/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base_engine import BaseEngine
16 | from .chat_model import ChatModel
17 |
18 |
19 | __all__ = ["BaseEngine", "ChatModel"]
20 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .collator import (
16 | KTODataCollatorWithPadding,
17 | MultiModalDataCollatorForSeq2Seq,
18 | PairwiseDataCollatorWithPadding,
19 | SFTDataCollatorWith4DAttentionMask,
20 | )
21 | from .data_utils import Role, split_dataset
22 | from .loader import get_dataset
23 | from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
24 |
25 |
26 | __all__ = [
27 | "TEMPLATES",
28 | "KTODataCollatorWithPadding",
29 | "MultiModalDataCollatorForSeq2Seq",
30 | "PairwiseDataCollatorWithPadding",
31 | "Role",
32 | "SFTDataCollatorWith4DAttentionMask",
33 | "Template",
34 | "get_dataset",
35 | "get_template_and_fix_tokenizer",
36 | "split_dataset",
37 | ]
38 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/data/processor/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .feedback import FeedbackDatasetProcessor
16 | from .pairwise import PairwiseDatasetProcessor
17 | from .pretrain import PretrainDatasetProcessor
18 | from .processor_utils import DatasetProcessor
19 | from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor
20 | from .unsupervised import UnsupervisedDatasetProcessor
21 |
22 |
23 | __all__ = [
24 | "DatasetProcessor",
25 | "FeedbackDatasetProcessor",
26 | "PackedSupervisedDatasetProcessor",
27 | "PairwiseDatasetProcessor",
28 | "PretrainDatasetProcessor",
29 | "SupervisedDatasetProcessor",
30 | "UnsupervisedDatasetProcessor",
31 | ]
32 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/data/processor/pretrain.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
2 | #
3 | # This code is inspired by the HuggingFace's transformers library.
4 | # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | from dataclasses import dataclass
19 | from itertools import chain
20 | from typing import Any
21 |
22 | from .processor_utils import DatasetProcessor
23 |
24 |
@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
    r"""Turn raw text examples into token sequences for pre-training."""

    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
        r"""Tokenize a batch of pre-training texts.

        Each example's text is the ``content`` of the first message in ``_prompt``, with
        an EOS token appended.

        - Without packing: every text is tokenized on its own (prefixed with BOS when the
          tokenizer has ``add_bos_token`` set) and truncated to ``cutoff_len``.
        - With packing: all token sequences are concatenated and cut into contiguous
          blocks of exactly ``cutoff_len`` tokens. The trailing remainder shorter than
          one block is dropped. When ``add_bos_token`` is set, the first token of every
          block is overwritten with the BOS id.
        """
        # The llama3 template terminates documents with "<|end_of_text|>" instead of the
        # tokenizer's own eos_token.
        eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
        text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
        add_bos = getattr(self.tokenizer, "add_bos_token", False)

        if not self.data_args.packing:
            if add_bos:
                text_examples = [self.tokenizer.bos_token + example for example in text_examples]

            return self.tokenizer(
                text_examples, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len
            )

        # Packing: build grouped texts with format `X1 X2 X3 ...` and split into fixed-size blocks.
        tokenized_examples = self.tokenizer(text_examples, add_special_tokens=False)
        concatenated_examples = {k: list(chain.from_iterable(v)) for k, v in tokenized_examples.items()}
        block_size = self.data_args.cutoff_len
        total_length = len(next(iter(concatenated_examples.values())))
        total_length = (total_length // block_size) * block_size  # drop the incomplete last block
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        if add_bos:
            for block in result["input_ids"]:
                block[0] = self.tokenizer.bos_token_id

        return result

    def print_data_example(self, example: dict[str, list[int]]) -> None:
        r"""Print the token ids of one processed example and their decoded text (special tokens kept)."""
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
58 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/eval/__init__.py
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/eval/template.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass
16 |
17 | from ..data import Role
18 | from ..extras.constants import CHOICES
19 |
20 |
@dataclass
class EvalTemplate:
    r"""Prompt template that renders multiple-choice evaluation examples as chat messages."""

    system: str  # prepended to the first message; formatted with `subject`
    choice: str  # one line per option; formatted with `choice` (option key) and `content`
    answer: str  # appended after the options to prompt for the answer

    def _parse_example(self, example: dict[str, str]) -> tuple[str, str]:
        r"""Parse eval example.

        input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
        output: a tuple of (prompt, response). The prompt is the question, followed by
        one formatted line per option key from CHOICES present in the example, then the
        answer cue. The response is the example's "answer".
        """
        candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
        return "".join([example["question"], *candidates, self.answer]), example["answer"]

    def format_example(
        self, target_data: dict[str, str], support_set: list[dict[str, str]], subject_name: str
    ) -> list[dict[str, str]]:
        r"""Convert dataset examples to messages.

        Each few-shot example in ``support_set`` becomes a user/assistant pair, followed by
        a final pair for ``target_data`` whose assistant content is its gold answer. The
        system text, formatted with ``subject_name``, is prepended to the first message.
        """
        messages = []
        for example in [*support_set, target_data]:
            prompt, response = self._parse_example(example)
            messages.append({"role": Role.USER.value, "content": prompt})
            messages.append({"role": Role.ASSISTANT.value, "content": response})

        # There is always at least the target pair, so messages[0] exists.
        messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
        return messages
51 |
52 |
# Registry of evaluation templates keyed by name, populated through `_register_eval_template`.
eval_templates: dict[str, "EvalTemplate"] = {}


def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
    r"""Create an ``EvalTemplate`` and store it under ``name``, replacing any existing entry."""
    eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)


def get_eval_template(name: str) -> "EvalTemplate":
    r"""Return the registered evaluation template called ``name``.

    Raises:
        ValueError: If no template with that name has been registered.
    """
    eval_template = eval_templates.get(name)
    if eval_template is None:
        # An explicit exception rather than `assert`, so the check survives `python -O`.
        raise ValueError(f"Template {name} does not exist.")

    return eval_template
64 |
65 |
# English multiple-choice template: "<system><question>\nA. ...\nB. ...\nAnswer:".
_register_eval_template(
    name="en",
    system="The following are multiple choice questions (with answers) about {subject}.\n\n",
    choice="\n{choice}. {content}",
    answer="\nAnswer:",
)


# Chinese multiple-choice template. The system prompt reads roughly: "The following are
# single-choice questions from Chinese exams about {subject}; please select the correct
# answer." The answer cue "答案:" means "Answer:".
_register_eval_template(
    name="zh",
    system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
    choice="\n{choice}. {content}",
    answer="\n答案:",
)
80 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/extras/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/extras/__init__.py
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/extras/env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
2 | #
3 | # This code is inspired by the HuggingFace's transformers library.
4 | # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import platform
19 |
20 | import accelerate
21 | import datasets
22 | import peft
23 | import torch
24 | import transformers
25 | import trl
26 | from transformers.utils import is_torch_cuda_available, is_torch_npu_available
27 |
28 |
VERSION = "0.9.3.dev0"


def print_env() -> None:
    r"""Print a bulleted summary of the runtime environment.

    The summary always lists:
    - the llamafactory version, platform and Python version;
    - the versions of PyTorch, Transformers, Datasets, Accelerate, PEFT and TRL.

    It adds:
    - GPU details when CUDA is available, and NPU details when an NPU is available;
    - the versions of the optional packages DeepSpeed, bitsandbytes and vLLM, when they
      import successfully;
    - the output of ``git rev-parse HEAD``, when that command succeeds.
    """
    info = {
        "`llamafactory` version": VERSION,
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
    }
    core_libraries = (
        ("PyTorch version", torch),
        ("Transformers version", transformers),
        ("Datasets version", datasets),
        ("Accelerate version", accelerate),
        ("PEFT version", peft),
        ("TRL version", trl),
    )
    for label, library in core_libraries:
        info[label] = library.__version__

    if is_torch_cuda_available():
        info["PyTorch version"] += " (GPU)"
        info.update(
            {
                "GPU type": torch.cuda.get_device_name(),
                "GPU number": torch.cuda.device_count(),
                "GPU memory": f"{torch.cuda.mem_get_info()[1] / (1024**3):.2f}GB",
            }
        )

    if is_torch_npu_available():
        info["PyTorch version"] += " (NPU)"
        info.update({"NPU type": torch.npu.get_device_name(), "CANN version": torch.version.cann})

    optional_packages = (
        ("deepspeed", "DeepSpeed version"),
        ("bitsandbytes", "Bitsandbytes version"),
        ("vllm", "vLLM version"),
    )
    for package_name, label in optional_packages:
        try:
            # Best-effort: a missing or broken optional package is simply not reported.
            info[label] = __import__(package_name).__version__
        except Exception:
            pass

    try:
        import subprocess

        completed = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True)
        info["Git commit"] = completed.stdout.strip()
    except Exception:
        pass  # not inside a git checkout, or git is unavailable

    lines = [f"- {key}: {value}" for key, value in info.items()]
    print("\n" + "\n".join(lines) + "\n")
87 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/extras/packages.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
2 | #
3 | # This code is inspired by the HuggingFace's transformers library.
4 | # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import importlib.metadata
19 | import importlib.util
20 | from functools import lru_cache
21 | from typing import TYPE_CHECKING
22 |
23 | from packaging import version
24 |
25 |
26 | if TYPE_CHECKING:
27 | from packaging.version import Version
28 |
29 |
def _is_package_available(name: str) -> bool:
    r"""Check whether ``name`` resolves to an importable module, without importing it."""
    spec = importlib.util.find_spec(name)
    return spec is not None


def _get_package_version(name: str) -> "Version":
    r"""Return the parsed installed version of ``name``, or ``0.0.0`` if it cannot be determined."""
    try:
        parsed = version.parse(importlib.metadata.version(name))
    except Exception:  # best-effort: missing distribution or unparsable version string
        parsed = version.parse("0.0.0")
    return parsed


def is_pyav_available() -> bool:
    r"""Whether PyAV (module ``av``) is importable."""
    return _is_package_available("av")


def is_librosa_available() -> bool:
    r"""Whether ``librosa`` is importable."""
    return _is_package_available("librosa")


def is_fastapi_available() -> bool:
    r"""Whether ``fastapi`` is importable."""
    return _is_package_available("fastapi")


def is_galore_available() -> bool:
    r"""Whether ``galore_torch`` is importable."""
    return _is_package_available("galore_torch")


def is_apollo_available() -> bool:
    r"""Whether ``apollo_torch`` is importable."""
    return _is_package_available("apollo_torch")


def is_gradio_available() -> bool:
    r"""Whether ``gradio`` is importable."""
    return _is_package_available("gradio")


def is_matplotlib_available() -> bool:
    r"""Whether ``matplotlib`` is importable."""
    return _is_package_available("matplotlib")


def is_pillow_available() -> bool:
    r"""Whether Pillow (module ``PIL``) is importable."""
    return _is_package_available("PIL")


def is_ray_available() -> bool:
    r"""Whether ``ray`` is importable."""
    return _is_package_available("ray")


def is_requests_available() -> bool:
    r"""Whether ``requests`` is importable."""
    return _is_package_available("requests")


def is_rouge_available() -> bool:
    r"""Whether ``rouge_chinese`` is importable."""
    return _is_package_available("rouge_chinese")


def is_starlette_available() -> bool:
    r"""Whether ``sse_starlette`` is importable."""
    return _is_package_available("sse_starlette")


@lru_cache
def is_transformers_version_greater_than(content: str) -> bool:
    r"""Whether the installed transformers version is at least ``content``.

    Despite the name, the comparison is inclusive (``>=``). Results are memoized per ``content``.
    """
    installed = _get_package_version("transformers")
    return installed >= version.parse(content)


def is_uvicorn_available() -> bool:
    r"""Whether ``uvicorn`` is importable."""
    return _is_package_available("uvicorn")


def is_vllm_available() -> bool:
    r"""Whether ``vllm`` is importable."""
    return _is_package_available("vllm")


def is_sglang_available() -> bool:
    r"""Whether ``sglang`` is importable."""
    return _is_package_available("sglang")
104 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/hparams/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .data_args import DataArguments
16 | from .evaluation_args import EvaluationArguments
17 | from .finetuning_args import FinetuningArguments
18 | from .generating_args import GeneratingArguments
19 | from .model_args import ModelArguments
20 | from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args
21 | from .training_args import RayArguments, TrainingArguments
22 |
23 |
24 | __all__ = [
25 | "DataArguments",
26 | "EvaluationArguments",
27 | "FinetuningArguments",
28 | "GeneratingArguments",
29 | "ModelArguments",
30 | "RayArguments",
31 | "TrainingArguments",
32 | "get_eval_args",
33 | "get_infer_args",
34 | "get_ray_args",
35 | "get_train_args",
36 | "read_args",
37 | ]
38 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/hparams/evaluation_args.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | from dataclasses import dataclass, field
17 | from typing import Literal, Optional
18 |
19 | from datasets import DownloadMode
20 |
21 |
@dataclass
class EvaluationArguments:
    r"""Arguments that configure an evaluation run (task, data location, batching, prompting, output)."""

    # Required: which evaluation task to run.
    task: str = field(metadata={"help": "Name of the evaluation task."})
    task_dir: str = field(
        default="evaluation", metadata={"help": "Path to the folder containing the evaluation datasets."}
    )
    batch_size: int = field(default=4, metadata={"help": "The batch size per GPU for evaluation."})
    seed: int = field(default=42, metadata={"help": "Random seed to be used with data loaders."})
    # Selects the prompt language.
    lang: Literal["en", "zh"] = field(default="en", metadata={"help": "Language used at evaluation."})
    n_shot: int = field(default=5, metadata={"help": "Number of examplars for few-shot learning."})
    save_dir: Optional[str] = field(default=None, metadata={"help": "Path to save the evaluation results."})
    download_mode: DownloadMode = field(
        default=DownloadMode.REUSE_DATASET_IF_EXISTS,
        metadata={"help": "Download mode used for the evaluation datasets."},
    )

    def __post_init__(self):
        # Reject a `save_dir` that already exists on disk.
        save_dir_taken = self.save_dir is not None and os.path.exists(self.save_dir)
        if save_dir_taken:
            raise ValueError("`save_dir` already exists, use another one.")
61 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/launcher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from llamafactory.train.tuner import run_exp # use absolute import
16 |
17 |
def launch():
    r"""Entry point: delegate directly to ``run_exp`` with no arguments."""
    run_exp()


if __name__ == "__main__":
    # Allow running this module as a script in addition to the console entry point.
    launch()
24 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .loader import load_config, load_model, load_tokenizer
16 | from .model_utils.misc import find_all_linear_modules
17 | from .model_utils.quantization import QuantizationMethod
18 | from .model_utils.valuehead import load_valuehead_params
19 |
20 |
21 | __all__ = [
22 | "QuantizationMethod",
23 | "find_all_linear_modules",
24 | "load_config",
25 | "load_model",
26 | "load_tokenizer",
27 | "load_valuehead_params",
28 | ]
29 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/model/model_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/model/model_utils/__init__.py
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/model/model_utils/embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | from contextlib import nullcontext
17 | from typing import TYPE_CHECKING
18 |
19 | import torch
20 | from transformers.integrations import is_deepspeed_zero3_enabled
21 |
22 | from ...extras import logging
23 |
24 |
25 | if TYPE_CHECKING:
26 | from transformers import PreTrainedModel, PreTrainedTokenizer
27 |
28 |
29 | logger = logging.get_logger(__name__)
30 |
31 |
32 | def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
33 | embedding_dim = embed_weight.size(1)
34 | avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
35 | noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
36 | noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
37 | embed_weight[-num_new_tokens:] = avg_weight + noise_weight
38 |
39 |
def _embedding_gather_context(model: "PreTrainedModel"):
    r"""Build a context that gives full access to the embedding weights.

    Under DeepSpeed ZeRO-3 this is a ``deepspeed.zero.GatheredParameters`` over the input
    embedding weight (plus the output embedding weight when it exists and is not tied),
    with ``modifier_rank=0``. Otherwise it is a no-op ``nullcontext``.

    The context captures the weight parameters that exist when it is built, so it has to be
    rebuilt after ``resize_token_embeddings``, which may install new weight parameters.
    """
    if not is_deepspeed_zero3_enabled():
        return nullcontext()

    import deepspeed  # type: ignore

    params = [model.get_input_embeddings().weight]
    if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
        params.append(model.get_output_embeddings().weight)

    return deepspeed.zero.GatheredParameters(params, modifier_rank=0)


def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
    r"""Resize token embeddings to fit the tokenizer and initialize the new rows.

    Nothing happens when ``len(tokenizer)`` does not exceed the current input embedding size.
    Otherwise the embeddings are resized to ``len(tokenizer)`` with ``pad_to_multiple_of=64``,
    and the new rows of both the input and output embeddings are set to the mean of the old
    rows plus noise.

    Raises:
        ValueError: if the model is quantized, or its output embedding is not ``torch.nn.Linear``.
    """
    with _embedding_gather_context(model):
        current_embedding_size = model.get_input_embeddings().weight.size(0)

    if len(tokenizer) <= current_embedding_size:
        return

    if getattr(model, "quantization_method", None):
        raise ValueError("Cannot resize embedding layers of a quantized model.")

    if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
        raise ValueError("Current model does not support resizing embedding layers.")

    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
    # Build a fresh context: one created before the resize may still reference the old weights.
    with _embedding_gather_context(model):
        new_embedding_size = model.get_input_embeddings().weight.size(0)
        num_new_tokens = new_embedding_size - current_embedding_size
        _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
        _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)

    logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
71 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/model/model_utils/kv_cache.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING
16 |
17 | from ...extras import logging
18 |
19 |
20 | logger = logging.get_logger(__name__)
21 |
22 |
23 | if TYPE_CHECKING:
24 | from transformers import PretrainedConfig
25 |
26 | from ...hparams import ModelArguments
27 |
28 |
def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    r"""Set ``use_cache`` on the config (and on ``config.text_config`` if present).

    During training the KV cache is always turned off; otherwise ``model_args.use_cache``
    decides. A single info message describing the outcome is logged.
    """
    use_cache = False if is_trainable else model_args.use_cache

    targets = [config]
    if hasattr(config, "text_config"):
        # Composite (e.g. multimodal) configs carry a nested text config that needs the same flag.
        targets.append(config.text_config)

    for target in targets:
        setattr(target, "use_cache", use_cache)

    if is_trainable:
        message = "KV cache is disabled during training."
    elif use_cache:
        message = "KV cache is enabled for faster generation."
    else:
        message = "KV cache is disabled."

    logger.info_rank0(message)
45 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/model/model_utils/mod.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING
16 |
17 | from ...extras.constants import MOD_SUPPORTED_MODELS
18 |
19 |
20 | if TYPE_CHECKING:
21 | from transformers import PretrainedConfig, PreTrainedModel
22 |
23 | from ...hparams import ModelArguments
24 |
25 |
def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
    r"""Load a Mixture-of-Depths causal LM from the ``MoD`` package.

    All keyword arguments are forwarded unchanged to ``AutoMoDModelForCausalLM.from_pretrained``.
    """
    # Imported inside the function rather than at module level.
    from MoD import AutoMoDModelForCausalLM

    model = AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
    return model
30 |
31 |
def convert_pretrained_model_to_mod(
    model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
    r"""Convert a Hugging Face model into a Mixture-of-Depths model.

    The converted model is cast to ``model_args.compute_dtype`` before being returned.

    Raises:
        ValueError: if ``config.model_type`` is not listed in ``MOD_SUPPORTED_MODELS``.
    """
    from MoD import apply_mod_to_hf

    model_type = getattr(config, "model_type", None)
    if model_type not in MOD_SUPPORTED_MODELS:
        raise ValueError("Current model is not supported by mixture-of-depth.")

    return apply_mod_to_hf(model).to(model_args.compute_dtype)
43 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/model/model_utils/rope.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 LMSYS and the LlamaFactory team.
2 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
3 | #
4 | # This code is inspired by the LMSYS's FastChat library.
5 | # https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | import math
20 | from typing import TYPE_CHECKING
21 |
22 | from ...extras import logging
23 | from ...extras.constants import RopeScaling
24 |
25 |
26 | if TYPE_CHECKING:
27 | from transformers import PretrainedConfig
28 |
29 | from ...hparams import ModelArguments
30 |
31 |
32 | logger = logging.get_logger(__name__)
33 |
34 |
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    r"""Write a RoPE scaling configuration into ``config.rope_scaling``.

    Does nothing when no scaling is requested or when the config has no ``rope_scaling``
    attribute. If ``model_max_length`` is given and exceeds ``max_position_embeddings``,
    the latter is enlarged and the factor is ``ceil(new / old)``. If the target length is not
    larger, scaling is skipped. Without a target length, the factor defaults to 2.0.
    """
    scaling = model_args.rope_scaling
    if scaling is None:
        return

    if not hasattr(config, "rope_scaling"):
        logger.warning_rank0("Current model does not support RoPE scaling.")
        return

    rope_kwargs = {"rope_type": getattr(scaling, "value", scaling)}  # unwrap enum members
    target_length = model_args.model_max_length
    if target_length is None:
        rope_kwargs["factor"] = 2.0
    else:
        if is_trainable and scaling == RopeScaling.DYNAMIC:
            logger.warning_rank0(
                "Dynamic NTK scaling may not work well with fine-tuning. "
                "See: https://github.com/huggingface/transformers/pull/24653"
            )

        original_length = getattr(config, "max_position_embeddings", None)
        if not original_length or target_length <= original_length:
            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
            return

        logger.info_rank0(f"Enlarge max model length from {original_length} to {target_length}.")
        setattr(config, "max_position_embeddings", target_length)
        rope_kwargs["factor"] = float(math.ceil(target_length / original_length))

        # Both dynamic and llama3 scaling need to know the pre-extension context length.
        if scaling in (RopeScaling.DYNAMIC, RopeScaling.LLAMA3):
            rope_kwargs["original_max_position_embeddings"] = original_length
        if scaling == RopeScaling.LLAMA3:
            rope_kwargs["low_freq_factor"] = 1.0
            rope_kwargs["high_freq_factor"] = 4.0

    setattr(config, "rope_scaling", rope_kwargs)
    logger.info_rank0(
        f"Using {rope_kwargs['rope_type']} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
    )
72 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/model/model_utils/valuehead.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
from typing import TYPE_CHECKING, Optional
16 |
17 | import torch
18 | from transformers.utils import cached_file
19 |
20 | from ...extras import logging
21 | from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
22 |
23 |
24 | if TYPE_CHECKING:
25 | from transformers import PreTrainedModel
26 |
27 | from ...hparams import ModelArguments
28 |
29 |
30 | logger = logging.get_logger(__name__)
31 |
32 |
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Optional[dict[str, torch.Tensor]]:
    r"""Load value head parameters from Hugging Face Hub or local disk.

    The safetensors file (``V_HEAD_SAFE_WEIGHTS_NAME``) is tried first, then the PyTorch
    file (``V_HEAD_WEIGHTS_NAME``).

    Args:
        path_or_repo_id: local directory or Hub repo id to look in.
        model_args: provides ``cache_dir`` and ``hf_hub_token`` for the file lookup.

    Returns:
        A dict with keys `v_head.summary.weight` and `v_head.summary.bias`, or ``None`` when
        neither file could be loaded (the failures are logged, not raised).
    """
    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
    errors: list[str] = []

    try:
        from safetensors import safe_open

        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
        with safe_open(vhead_file, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    except Exception as err:  # best effort: fall back to the PyTorch checkpoint below
        errors.append(str(err))

    try:
        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
        # The file may be downloaded from the Hub: only allow tensors/plain containers, not arbitrary pickles.
        return torch.load(vhead_file, map_location="cpu", weights_only=True)
    except Exception as err:
        errors.append(str(err))

    # Report both failures so a missing safetensors file does not hide a corrupt .bin (or vice versa).
    err_text = "; ".join(errors)
    logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
    logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.")
    return None
59 |
60 |
def prepare_valuehead_model(model: "PreTrainedModel") -> None:
    r"""Alias the output layer as ``lm_head`` for architectures that name it differently.

    For ``llava``, ``chatglm`` and ``internlm2`` models, ``model.lm_head`` is pointed at the
    existing output layer and ``_keys_to_ignore_on_save`` is set to ``["lm_head.weight"]``.
    Any other model type is left untouched.
    """
    # Getters are lambdas so that architecture-specific attributes are only touched on a match.
    output_layer_getters = (
        ("llava", lambda m: m.language_model.get_output_embeddings()),
        ("chatglm", lambda m: m.transformer.output_layer),
        ("internlm2", lambda m: m.output),
    )
    for model_type, get_output_layer in output_layer_getters:
        if getattr(model.config, "model_type", None) == model_type:
            setattr(model, "lm_head", get_output_layer(model))
            setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
73 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/third_party/__init__.py
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/third_party/muon/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .muon import Muon
16 |
17 |
18 | __all__ = ["Muon"]
19 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/train/__init__.py
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/train/dpo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .workflow import run_dpo
16 |
17 |
18 | __all__ = ["run_dpo"]
19 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/train/kto/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .workflow import run_kto
16 |
17 |
18 | __all__ = ["run_kto"]
19 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/train/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .workflow import run_ppo
16 |
17 |
18 | __all__ = ["run_ppo"]
19 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/train/pt/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .workflow import run_pt
16 |
17 |
18 | __all__ = ["run_pt"]
19 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/train/pt/trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from types import MethodType
16 | from typing import TYPE_CHECKING, Optional
17 |
18 | import torch
19 | from transformers import Trainer
20 | from typing_extensions import override
21 |
22 | from ...extras.packages import is_transformers_version_greater_than
23 | from ..callbacks import SaveProcessorCallback
24 | from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
25 |
26 |
27 | if TYPE_CHECKING:
28 | from transformers import ProcessorMixin
29 |
30 | from ...hparams import FinetuningArguments
31 |
32 |
class CustomTrainer(Trainer):
    r"""Inherit Trainer for custom optimizer.

    On top of the base ``Trainer`` this class:
      * builds the optimizer/scheduler through LlamaFactory's ``create_custom_*`` helpers;
      * registers ``SaveProcessorCallback`` when a ``processor`` is given;
      * optionally enables BAdam (``finetuning_args.use_badam``);
      * optionally uses a sequential sampler (``finetuning_args.disable_shuffling``).
    """

    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
        # Newer transformers expect the tokenizer as `processing_class`. Only rename when the caller
        # actually passed `tokenizer`, so passing `processing_class` directly keeps working.
        if is_transformers_version_greater_than("4.46") and "tokenizer" in kwargs:
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            # Replace the accelerator's gradient clipping with BAdam's implementation.
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            # NOTE(review): presumably returns None when no custom optimizer applies, leaving the
            # base class to build the default one — confirm in trainer_utils.
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        # NOTE(review): the return value is discarded, so this presumably acts through side effects.
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        # Newer transformers pass the training dataset to this hook while older ones pass nothing,
        # so accept both and forward exactly what we received to the base implementation.
        if self.finetuning_args.disable_shuffling:
            dataset = args[0] if args else kwargs.get("train_dataset")
            if dataset is None:
                dataset = self.train_dataset
            return torch.utils.data.SequentialSampler(dataset)

        return super()._get_train_sampler(*args, **kwargs)

    @override
    def compute_loss(self, model, inputs, *args, **kwargs):
        # Plain pass-through to the base implementation (kept as an explicit override hook).
        return super().compute_loss(model, inputs, *args, **kwargs)
82 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/train/rm/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .workflow import run_rm
16 |
17 |
18 | __all__ = ["run_rm"]
19 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/train/rm/metric.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass
16 | from typing import TYPE_CHECKING, Optional
17 |
18 | import numpy as np
19 |
20 | from ...extras.misc import numpify
21 |
22 |
23 | if TYPE_CHECKING:
24 | from transformers import EvalPrediction
25 |
26 |
@dataclass
class ComputeAccuracy:
    r"""Compute reward accuracy and support `batch_eval_metrics`.

    A pair counts as correct when the chosen score is strictly greater than the rejected one.
    Results accumulate across calls; when ``compute_result`` is true the mean of every
    collected metric is returned and the accumulator is reset.
    """

    def _dump(self) -> Optional[dict[str, float]]:
        # Summarize what has been collected so far (None before the first reset),
        # then start a fresh accumulator.
        summary = (
            {name: float(np.mean(values)) for name, values in self.score_dict.items()}
            if hasattr(self, "score_dict")
            else None
        )
        self.score_dict = {"accuracy": []}
        return summary

    def __post_init__(self):
        # Create the empty accumulator; the (None) summary is discarded.
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
        chosen_scores = numpify(eval_preds.predictions[0])
        rejected_scores = numpify(eval_preds.predictions[1])
        accuracy = self.score_dict["accuracy"]
        if chosen_scores.shape:
            for idx, chosen in enumerate(chosen_scores):
                accuracy.append(chosen > rejected_scores[idx])
        else:  # 0-d scores: a single comparison
            accuracy.append(chosen_scores > rejected_scores)

        return self._dump() if compute_result else None
52 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/train/sft/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .workflow import run_sft
16 |
17 |
18 | __all__ = ["run_sft"]
19 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/webui/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/webui/__init__.py
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/webui/components/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .chatbot import create_chat_box
16 | from .eval import create_eval_tab
17 | from .export import create_export_tab
18 | from .infer import create_infer_tab
19 | from .top import create_top
20 | from .train import create_train_tab
21 |
22 |
23 | __all__ = [
24 | "create_chat_box",
25 | "create_eval_tab",
26 | "create_export_tab",
27 | "create_infer_tab",
28 | "create_top",
29 | "create_train_tab",
30 | ]
31 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/webui/components/infer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING
16 |
17 | from ...extras.packages import is_gradio_available
18 | from ..common import is_multimodal
19 | from .chatbot import create_chat_box
20 |
21 |
22 | if is_gradio_available():
23 | import gradio as gr
24 |
25 |
26 | if TYPE_CHECKING:
27 | from gradio.components import Component
28 |
29 | from ..engine import Engine
30 |
31 |
def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
    r"""Build the inference tab of the web UI.

    Creates backend/dtype dropdowns, load/unload buttons and an info box, embeds a chat box
    that is initially hidden, and wires the load/unload/model-change events.

    Returns:
        Mapping from element name to Gradio component (this tab's elements plus the chat box's).
    """
    inputs = engine.manager.get_base_elems()

    with gr.Row():
        backend_dropdown = gr.Dropdown(choices=["huggingface", "vllm", "sglang"], value="huggingface")
        dtype_dropdown = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")

    with gr.Row():
        load_button = gr.Button()
        unload_button = gr.Button()

    status_box = gr.Textbox(show_label=False, interactive=False)

    inputs.update({backend_dropdown, dtype_dropdown})
    components = {
        "infer_backend": backend_dropdown,
        "infer_dtype": dtype_dropdown,
        "load_btn": load_button,
        "unload_btn": unload_button,
        "info_box": status_box,
    }

    chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
    components.update(chat_elems)
    chat_box = chat_elems["chat_box"]

    # Loading: report status, then show the chat box if the model is now loaded.
    load_event = load_button.click(engine.chatter.load_model, inputs, [status_box])
    load_event.then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])

    # Unloading: report status, clear the conversation, then update chat box visibility.
    unload_event = unload_button.click(engine.chatter.unload_model, inputs, [status_box])
    cleared_event = unload_event.then(lambda: ([], []), outputs=[chatbot, messages])
    cleared_event.then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])

    # Show the multimodal inputs only for multimodal models.
    model_name_elem = engine.manager.get_elem_by_id("top.model_name")
    model_name_elem.change(
        lambda model_name: gr.Column(visible=is_multimodal(model_name)),
        [model_name_elem],
        [chat_elems["mm_box"]],
    )

    return components
75 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/webui/css.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
# Custom stylesheet for the Web UI, kept as a raw string. It styles the duplicate button,
# the "thinking" summary/container elements (with dark-mode variants under `.dark`), and a
# fixed-position modal box with a theme-dependent border.
# NOTE(review): in `.modal-box`, `top: 50%; left: 50%` combined with
# `transform: translate(-50%, -50%)` centers the box both horizontally and vertically;
# the inline CSS comment says "center horizontally" only. It is left unchanged here
# because it is part of the runtime string.
CSS = r"""
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}

.thinking-summary {
padding: 8px !important;
}

.thinking-summary span {
border-radius: 4px !important;
padding: 4px !important;
cursor: pointer !important;
font-size: 14px !important;
background: rgb(245, 245, 245) !important;
}

.dark .thinking-summary span {
background: rgb(73, 73, 73) !important;
}

.thinking-container {
border-left: 2px solid #a6a6a6 !important;
padding-left: 10px !important;
margin: 4px 0 !important;
}

.thinking-container p {
color: #a6a6a6 !important;
}

.modal-box {
position: fixed !important;
top: 50%;
left: 50%;
transform: translate(-50%, -50%); /* center horizontally */
max-width: 1000px;
max-height: 750px;
overflow-y: auto;
background-color: var(--input-background-fill);
flex-wrap: nowrap !important;
border: 2px solid black !important;
z-index: 1000;
padding: 10px;
}

.dark .modal-box {
border: 2px solid white !important;
}
"""
68 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/llamafactory/webui/manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from collections.abc import Generator
16 | from typing import TYPE_CHECKING
17 |
18 |
19 | if TYPE_CHECKING:
20 | from gradio.components import Component
21 |
22 |
class Manager:
    r"""A class to manage all the gradio components in Web UI.

    Every component is registered under an id of the form ``"<tab_name>.<elem_name>"``
    (for example ``"top.lang"``). The manager keeps both the id -> component and the
    component -> id mappings, so components must be hashable.
    """

    def __init__(self) -> None:
        # Forward and reverse lookup tables; dict order follows first registration of each id.
        self._id_to_elem: dict[str, "Component"] = {}
        self._elem_to_id: dict["Component", str] = {}

    def add_elems(self, tab_name: str, elem_dict: dict[str, "Component"]) -> None:
        r"""Register each component of ``elem_dict`` under the ``tab_name`` prefix.

        Registering an id or component again overwrites its previous mapping.
        """
        for short_name, component in elem_dict.items():
            qualified_id = f"{tab_name}.{short_name}"
            self._id_to_elem[qualified_id] = component
            self._elem_to_id[component] = qualified_id

    def get_elem_list(self) -> list["Component"]:
        r"""Return all registered components, ordered by their ids' first registration."""
        return [*self._id_to_elem.values()]

    def get_elem_iter(self) -> Generator[tuple[str, "Component"], None, None]:
        r"""Yield ``(elem_name, component)`` pairs, where ``elem_name`` omits the tab prefix."""
        for qualified_id, component in self._id_to_elem.items():
            # Keep only the part after the last dot, e.g. "top.lang" -> "lang".
            yield qualified_id.rpartition(".")[2], component

    def get_elem_by_id(self, elem_id: str) -> "Component":
        r"""Get element by id.

        Example: top.lang, train.dataset

        Raises:
            KeyError: If no component is registered under ``elem_id``.
        """
        return self._id_to_elem[elem_id]

    def get_id_by_elem(self, elem: "Component") -> str:
        r"""Get the id a component was registered under.

        Raises:
            KeyError: If ``elem`` was never registered.
        """
        return self._elem_to_id[elem]

    def get_base_elems(self) -> set["Component"]:
        r"""Get the base elements that are commonly used.

        Raises:
            KeyError: If any of the expected ``top.*`` components is not registered.
        """
        base_ids = (
            "top.lang",
            "top.model_name",
            "top.model_path",
            "top.finetuning_type",
            "top.checkpoint_path",
            "top.quantization_bit",
            "top.quantization_method",
            "top.template",
            "top.rope_scaling",
            "top.booster",
        )
        return {self._id_to_elem[base_id] for base_id in base_ids}
71 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from llamafactory.train.tuner import run_exp
16 |
17 |
def main():
    """Run a LLaMA-Factory experiment via ``run_exp``."""
    run_exp()


def _mp_fn(index):
    """Per-process entry point for xla_spawn (TPUs).

    ``index`` is accepted to match the spawner's call signature but is not used.
    """
    del index  # unused; run_exp() takes no arguments here
    run_exp()


if __name__ == "__main__":
    main()
29 |
--------------------------------------------------------------------------------
/LLaMA-Factory/src/webui.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 the LlamaFactory team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | from llamafactory.extras.misc import fix_proxy, is_env_enabled
18 | from llamafactory.webui.interface import create_ui
19 |
20 |
def main():
    """Launch the LLaMA-Factory Web UI.

    Environment variables:
        GRADIO_IPV6: if enabled, the default bind address becomes "[::]" (otherwise
            "0.0.0.0"), and it is forwarded to ``fix_proxy`` as ``ipv6_enabled``.
        GRADIO_SHARE: if enabled, passed as ``share=True`` to ``launch``.
        GRADIO_SERVER_NAME: explicit bind address; overrides the default above.
    """
    use_ipv6 = is_env_enabled("GRADIO_IPV6")
    share_link = is_env_enabled("GRADIO_SHARE")
    default_host = "[::]" if use_ipv6 else "0.0.0.0"
    host = os.getenv("GRADIO_SERVER_NAME", default_host)
    print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
    fix_proxy(ipv6_enabled=use_ipv6)
    app = create_ui().queue()
    app.launch(share=share_link, server_name=host, inbrowser=True)


if __name__ == "__main__":
    main()
32 |
--------------------------------------------------------------------------------
/assets/data_composition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/assets/data_composition.png
--------------------------------------------------------------------------------
/assets/exp_main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/assets/exp_main.png
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/assets/overview.png
--------------------------------------------------------------------------------
/assets/reward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/assets/reward.png
--------------------------------------------------------------------------------
/assets/scability.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/assets/scability.png
--------------------------------------------------------------------------------
/assets/sft_or_rl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/assets/sft_or_rl.png
--------------------------------------------------------------------------------
/assets/thinking_template.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/assets/thinking_template.png
--------------------------------------------------------------------------------
/llamafact_merge.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Export (merge) the fine-tuned Qwen model with LLaMA-Factory using examples/qwen_merge.yaml.
# The config path previously pointed at examples/qwen.yaml, which does not exist in the repo.

# Abort on the first failure, e.g. if the LLaMA-Factory directory is missing.
set -euo pipefail

cd LLaMA-Factory
llamafactory-cli export examples/qwen_merge.yaml
--------------------------------------------------------------------------------
/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/paper.pdf
--------------------------------------------------------------------------------
/qwen_rl.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# Launcher for RL training with verl: exports the settings read by
# verl/examples/agent/qwen.sh and then runs that script.
# The shebang must be the very first line of the file to take effect, so it precedes the license.

set -euo pipefail

# Cluster, model and data.
export N_GPUS=8
export BASE_MODEL="path/to/model/Qwen2.5-7B-Instruct"
export DATA_DIR="verl/verl/data"
export ROLLOUT_TP_SIZE=2
export VLLM_ATTENTION_BACKEND="XFORMERS"

# Rollout memory share, batch size, prompt length, and experiment/log naming.
export GPU_UT=0.6
export BA_SIZE=1024
export MAX_PROMPT_LEN=4096
export PRO_NAME="qwen"
export EXPERIMENT_NAME="qwen"
export LOG_DIR="path/to/logs/qwen.txt"

# Optimisation hyper-parameters.
export LR=1e-6
export ENTROPY=0
export MAX_RES=8192
export TEMPERATURE=0.7
export EPOCH=7
export KL_COE=0.001

bash verl/examples/agent/qwen.sh
--------------------------------------------------------------------------------
/qwen_sft.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Supervised fine-tuning of Qwen2.5-7B-Instruct with LLaMA-Factory (examples/qwen_sft.yaml).
# The variables below override the corresponding keys of the YAML config.

# Fail fast, and make the pipeline's exit status reflect training failures rather than tee's.
set -euo pipefail

export PROJ_NAME="qwen_sft"
export OUTPUT_DIR="path/to/output/dir"
export MODEL_PATH="path/to/model/Qwen2.5-7B-Instruct"
# Relative paths are resolved inside LLaMA-Factory/, because the log is written after the cd below.
export LOG_DIR="logs/$PROJ_NAME.txt"

export LR=2.0e-5
export EPOCH=20
export BATCH_SIZE=4
export G_ACC=8

cd LLaMA-Factory

# tee does not create missing parent directories; without this the log file is not written.
mkdir -p "$(dirname "${LOG_DIR}")"

FORCE_TORCHRUN=1 llamafactory-cli train examples/qwen_sft.yaml \
    learning_rate="$LR" \
    num_train_epochs="$EPOCH" \
    per_device_train_batch_size="$BATCH_SIZE" \
    gradient_accumulation_steps="$G_ACC" \
    output_dir="$OUTPUT_DIR" \
    run_name="$PROJ_NAME" \
    model_name_or_path="$MODEL_PATH" 2>&1 | tee -a "${LOG_DIR}"
--------------------------------------------------------------------------------
/verl/examples/agent/qwen.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# GRPO training through verl's main_ppo entry point. All tunable settings are read from
# environment variables (for example as exported by qwen_rl.sh).

# Report a failing trainer as the script's exit status instead of tee's.
set -o pipefail

# Fail early with a clear message instead of passing empty Hydra overrides.
required_vars=(N_GPUS BASE_MODEL DATA_DIR ROLLOUT_TP_SIZE GPU_UT BA_SIZE MAX_PROMPT_LEN
               PRO_NAME EXPERIMENT_NAME LOG_DIR LR ENTROPY MAX_RES TEMPERATURE EPOCH KL_COE)
for var in "${required_vars[@]}"; do
    if [[ -z "${!var:-}" ]]; then
        echo "error: environment variable ${var} must be set" >&2
        exit 1
    fi
done

# tee does not create missing parent directories.
mkdir -p "$(dirname "${LOG_DIR}")"

# The logger list is quoted so the shell cannot treat the brackets as a glob pattern.
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef="$KL_COE" \
    actor_rollout_ref.actor.entropy_coeff="$ENTROPY" \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.optim.lr="$LR" \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30720 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.model.path="$BASE_MODEL" \
    actor_rollout_ref.rollout.tensor_model_parallel_size="$ROLLOUT_TP_SIZE" \
    actor_rollout_ref.rollout.gpu_memory_utilization="$GPU_UT" \
    actor_rollout_ref.rollout.temperature="$TEMPERATURE" \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.model.use_remove_padding=True \
    data.train_files="$DATA_DIR/train.parquet" \
    data.val_files="$DATA_DIR/test.parquet" \
    data.train_batch_size="$BA_SIZE" \
    data.val_batch_size=1312 \
    data.max_prompt_length="$MAX_PROMPT_LEN" \
    data.max_response_length="$MAX_RES" \
    algorithm.kl_ctrl.kl_coef="$KL_COE" \
    trainer.critic_warmup=0 \
    "trainer.logger=['wandb']" \
    +trainer.val_before_train=False \
    +actor_rollout_ref.actor.fsdp_config.grad_offload=False \
    trainer.default_hdfs_dir=null \
    trainer.n_gpus_per_node="$N_GPUS" \
    trainer.nnodes=1 \
    trainer.save_freq=10 \
    trainer.test_freq=10 \
    trainer.project_name="$PRO_NAME" \
    trainer.experiment_name="$EXPERIMENT_NAME" \
    trainer.total_epochs="$EPOCH" 2>&1 | tee -a "${LOG_DIR}"
--------------------------------------------------------------------------------
/verl/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # setup.py is the fallback installation script when pyproject.toml does not work
16 | from setuptools import setup, find_packages
17 | import os
18 |
# Directory that contains this setup.py; the version string is stored in verl/version/version.
# (The original wrapped abspath() in a single-argument os.path.join, which is a no-op.)
version_folder = os.path.dirname(os.path.abspath(__file__))

# Decode explicitly so the result does not depend on the build machine's locale.
with open(os.path.join(version_folder, 'verl/version/version'), encoding='utf-8') as f:
    __version__ = f.read().strip()

install_requires = [
    'accelerate',
    'codetiming',
    'datasets',
    'dill',
    'hydra-core',
    'numpy',
    'pandas',
    'peft',
    'pyarrow>=15.0.0',
    'pybind11',
    'pylatexenc',
    'ray>=2.10',
    'tensordict<0.6',
    'torchdata',
    'transformers',
    'vllm<=0.6.3',
    'wandb',
]

# Optional dependency groups, installable as e.g. `pip install verl[gpu]`.
TEST_REQUIRES = ['pytest', 'yapf', 'py-spy']
PRIME_REQUIRES = ['pyext']
GEO_REQUIRES = ['mathruler']
GPU_REQUIRES = ['liger-kernel', 'flash-attn']

extras_require = {
    'test': TEST_REQUIRES,
    'prime': PRIME_REQUIRES,
    'geo': GEO_REQUIRES,
    'gpu': GPU_REQUIRES,
}

from pathlib import Path
this_directory = Path(__file__).parent  # NOTE(review): currently unused
# A short inline description is used rather than the README contents.
long_description = "Volcano Engine Reinforcement Learning for LLM"

setup(
    name='verl',
    version=__version__,
    package_dir={'': '.'},
    packages=find_packages(where='.'),
    url='https://github.com/volcengine/verl',
    license='Apache 2.0',
    author='Bytedance - Seed - MLSys',
    author_email='zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk',
    description='verl: Volcano Engine Reinforcement Learning for LLM',
    install_requires=install_requires,
    extras_require=extras_require,
    # Ship the version file and the trainer's YAML configs inside the package.
    package_data={'': ['version/*'],
                  'verl': ['trainer/config/*.yaml'],},
    include_package_data=True,
    long_description=long_description,
    long_description_content_type='text/markdown'
)
--------------------------------------------------------------------------------
/verl/verl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
# Package directory; the version string is stored in version/version next to this file.
# (The original wrapped abspath() in a single-argument os.path.join, which is a no-op.)
version_folder = os.path.dirname(os.path.abspath(__file__))

# Decode explicitly so importing verl does not depend on the process locale.
with open(os.path.join(version_folder, 'version/version'), encoding='utf-8') as f:
    __version__ = f.read().strip()

from .protocol import DataProto

from .utils.logging_utils import set_basic_config
import logging

# Configure logging before importing single_controller below.
set_basic_config(level=logging.WARNING)

from . import single_controller

__all__ = ['DataProto', "__version__"]
--------------------------------------------------------------------------------
/verl/verl/data/test.parquet:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:674fbfa58ee11be7034a505323cdfa49b9863d3dd2c0e31131a08a09709869c9
3 | size 9338664
4 |
--------------------------------------------------------------------------------
/verl/verl/data/train.parquet:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:cb080d5ae5f3b84423d9bbfd1d37923b8f59282fe7778345c459cf1371a25bac
3 | size 83495252
4 |
--------------------------------------------------------------------------------
/verl/verl/models/README.md:
--------------------------------------------------------------------------------
1 | # Models
2 | Common model zoos such as huggingface/transformers struggle when used with PyTorch-native model parallelism. Following the design principle of vLLM, verl keeps simple, parallelizable, and highly optimized model implementations that operate on packed inputs.
3 | ## Adding a New Huggingface Model
4 | ### Step 1: Copy the model file from HF to verl
5 | - Add a new file under verl/models/hf
6 | - Copy ONLY the model file from huggingface/transformers/models to verl/models/hf
7 |
8 | ### Step 2: Modify the model file to use packed inputs
9 | - Remove all the code related to inference (kv cache)
10 | - Modify the inputs to include only
11 | - input_ids (total_nnz,)
12 | - cu_seqlens (batch_size + 1,): cumulative sequence lengths, starting at 0 and ending at total_nnz
13 | - max_seqlen_in_batch: int
14 | - Note that this requires using flash attention with causal mask.
15 |
16 | ### Step 2.5: Add tests
17 | - Add a test to compare this version and the huggingface version
18 | - Follow the existing test infrastructure and add tests to tests/models/hf
19 |
20 | ### Step 3: Add a function to apply tensor parallelism
21 | - Please follow
22 | - https://pytorch.org/docs/stable/distributed.tensor.parallel.html
23 | - https://pytorch.org/tutorials/intermediate/TP_tutorial.html
24 | - General comments
25 | - Tensor Parallelism in native PyTorch is NOT auto-parallelism. It works by specifying, via configs, how model parameters and module inputs/outputs are resharded. These configs are then registered as hooks that perform input/output resharding before/after the model forward.
26 |
27 | ### Step 4: Add a function to apply data parallelism
28 | - Please use FSDP2 APIs
29 | - See demo here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413
30 |
31 | ### Step 5: Add a function to apply pipeline parallelism
32 | - Available starting with PyTorch 2.4
33 | - Currently only available as an alpha feature in nightly builds
34 | - Check torchtitan for more details
35 |
36 |
--------------------------------------------------------------------------------
/verl/verl/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/models/llama/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/models/llama/megatron/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .modeling_llama_megatron import (
16 | # original model with megatron
17 | ParallelLlamaModel,
18 | ParallelLlamaForCausalLM,
19 | # rmpad with megatron
20 | ParallelLlamaForCausalLMRmPad,
21 | ParallelLlamaForValueRmPad,
22 | # rmpad with megatron and pipeline parallelism
23 | ParallelLlamaForCausalLMRmPadPP,
24 | ParallelLlamaForValueRmPadPP)
25 |
--------------------------------------------------------------------------------
/verl/verl/models/llama/megatron/checkpoint_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/models/llama/megatron/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .parallel_attention import ParallelLlamaAttention
16 | from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad
17 | from .parallel_mlp import ParallelLlamaMLP
18 | from .parallel_rmsnorm import ParallelLlamaRMSNorm
19 |
--------------------------------------------------------------------------------
/verl/verl/models/llama/megatron/layers/parallel_linear.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py
15 |
16 | from typing import Optional, Tuple
17 |
18 | from megatron.core import tensor_parallel
19 |
20 |
class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
    """Megatron column-parallel linear layer producing fused query/key/value projections.

    The (unsharded) output width is ``(num_heads + 2 * num_key_value_heads) * head_dim``:
    one query block of ``num_heads * head_dim`` followed by key and value blocks of
    ``num_key_value_heads * head_dim`` each.
    """

    def __init__(self,
                 input_size,
                 num_heads,
                 num_key_value_heads,
                 head_dim,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        # Record the projection sizes before the base class is initialised.
        self.input_size = input_size
        self.q_output_size = num_heads * head_dim
        self.kv_output_size = num_key_value_heads * head_dim
        self.head_dim = head_dim
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        fused_output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim

        super().__init__(input_size=self.input_size,
                         output_size=fused_output_size,
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         **kwargs)
50 |
51 |
class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
    """Megatron column-parallel linear layer whose output concatenates two projections.

    The output width is ``gate_ouput_size + up_output_size``. Judging by the names, it
    presumably fuses an MLP's gate and up projections; confirm at the call sites.
    The misspelled ``gate_ouput_size`` parameter name is kept for caller compatibility.
    """

    def __init__(self,
                 input_size,
                 gate_ouput_size,
                 up_output_size,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        merged_output_size = gate_ouput_size + up_output_size

        # Record the sizes and flags before the base class is initialised.
        self.input_size = input_size
        self.output_size = merged_output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        super().__init__(input_size=self.input_size,
                         output_size=self.output_size,
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         **kwargs)
75 |
--------------------------------------------------------------------------------
/verl/verl/models/llama/megatron/layers/parallel_rmsnorm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numbers
16 | import torch
17 | from megatron.core import ModelParallelConfig
18 | from torch import nn
19 | from transformers import LlamaConfig
20 |
21 | from apex.normalization.fused_layer_norm import fused_rms_norm_affine
22 | from verl.utils.megatron import sequence_parallel as sp_utils
23 |
24 |
class ParallelLlamaRMSNorm(nn.Module):
    """RMSNorm for the Megatron Llama model, computed with apex's fused kernel.

    LlamaRMSNorm is equivalent to T5LayerNorm.
    """

    def __init__(self, config: "LlamaConfig", megatron_config: "ModelParallelConfig"):
        """
        Args:
            config: Model config. ``hidden_size`` gives the normalized shape (an int or a
                sequence of ints), and ``rms_norm_eps`` gives the epsilon.
            megatron_config: Parallel config. When ``sequence_parallel`` is truthy, the
                weight is marked as a sequence-parallel parameter.
        """
        super().__init__()
        normalized_shape = config.hidden_size
        # Wrap a bare int into a 1-tuple and pass sequences through unchanged. A separately
        # named variable assigned only in the int branch would be unbound for non-int shapes.
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        # The learned per-element scale is initialised to ones.
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.variance_epsilon = config.rms_norm_eps

        if megatron_config.sequence_parallel:
            sp_utils.mark_parameter_as_sequence_parallel(self.weight)

    def forward(self, hidden_states):
        """Apply fused RMS normalization over ``normalized_shape``, scaled by ``weight``."""
        return fused_rms_norm_affine(input=hidden_states,
                                     weight=self.weight,
                                     normalized_shape=self.normalized_shape,
                                     eps=self.variance_epsilon,
                                     memory_efficient=True)
--------------------------------------------------------------------------------
/verl/verl/models/qwen2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/models/qwen2/megatron/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .modeling_qwen2_megatron import (
16 | # original model with megatron
17 | ParallelQwen2Model,
18 | ParallelQwen2ForCausalLM,
19 | # rmpad with megatron
20 | ParallelQwen2ForCausalLMRmPad,
21 | ParallelQwen2ForValueRmPad,
22 | # rmpad with megatron and pipeline parallelism
23 | ParallelQwen2ForCausalLMRmPadPP,
24 | ParallelQwen2ForValueRmPadPP)
25 |
--------------------------------------------------------------------------------
/verl/verl/models/qwen2/megatron/checkpoint_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/models/qwen2/megatron/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .parallel_attention import ParallelQwen2Attention
16 | from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad
17 | from .parallel_mlp import ParallelQwen2MLP
18 | from .parallel_rmsnorm import ParallelQwen2RMSNorm
19 |
--------------------------------------------------------------------------------
/verl/verl/models/qwen2/megatron/layers/parallel_linear.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py
15 |
16 | from typing import Optional, Tuple
17 |
18 | from megatron.core import tensor_parallel
19 |
20 |
class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
    """Column-parallel linear layer whose output packs Q, K and V projections.

    The total output width is ``(num_heads + 2 * num_key_value_heads) * head_dim``,
    i.e. one query block of ``num_heads * head_dim`` plus two blocks of
    ``num_key_value_heads * head_dim``.
    """

    def __init__(self,
                 input_size,
                 num_heads,
                 num_key_value_heads,
                 head_dim,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        # Record the constructor arguments and the per-projection widths.
        self.input_size = input_size
        self.q_output_size = num_heads * head_dim
        self.kv_output_size = num_key_value_heads * head_dim
        self.head_dim = head_dim
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        packed_width = (num_heads + 2 * num_key_value_heads) * head_dim
        super().__init__(
            input_size=input_size,
            output_size=packed_width,
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            **kwargs,
        )
50 |
51 |
class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
    """Column-parallel linear layer whose output concatenates two projections.

    The output width is ``gate_ouput_size + up_output_size``. The parameter name
    ``gate_ouput_size`` (sic) is part of the public keyword interface and is kept
    as-is for compatibility with existing callers.
    """

    def __init__(self,
                 input_size,
                 gate_ouput_size,
                 up_output_size,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        combined_width = gate_ouput_size + up_output_size

        # Record the constructor arguments before delegating to the base class.
        self.input_size = input_size
        self.output_size = combined_width
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        super().__init__(
            input_size=input_size,
            output_size=combined_width,
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            **kwargs,
        )
75 |
--------------------------------------------------------------------------------
/verl/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numbers
16 | import torch
17 | from megatron.core import ModelParallelConfig
18 | from torch import nn
19 | from transformers import Qwen2Config
20 |
21 | from apex.normalization.fused_layer_norm import fused_rms_norm_affine
22 | from verl.utils.megatron import sequence_parallel as sp_utils
23 |
24 |
class ParallelQwen2RMSNorm(nn.Module):
    """Qwen2 RMSNorm implemented with apex's ``fused_rms_norm_affine``.

    Qwen2RMSNorm is equivalent to T5LayerNorm. Only a learned ``weight`` (no bias)
    is passed to the fused kernel.
    """

    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):
        """
        Args:
            config: Qwen2 config; ``hidden_size`` gives the normalized shape and
                ``rms_norm_eps`` the epsilon used by the kernel.
            megatron_config: Megatron parallel config; when ``sequence_parallel`` is
                enabled the weight is marked as a sequence-parallel parameter.
        """
        super().__init__()
        normalized_shape = config.hidden_size
        # Accept either a single int or an iterable of dims. Previously a non-int
        # ``hidden_size`` left ``normalized_shape`` unbound (UnboundLocalError).
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.variance_epsilon = config.rms_norm_eps

        if megatron_config.sequence_parallel:
            sp_utils.mark_parameter_as_sequence_parallel(self.weight)

    def forward(self, hidden_states):
        """Apply the fused RMS norm (with ``self.weight``) to ``hidden_states``.

        Presumably ``hidden_states`` ends in ``self.normalized_shape`` — TODO confirm
        against callers.
        """
        return fused_rms_norm_affine(input=hidden_states,
                                     weight=self.weight,
                                     normalized_shape=self.normalized_shape,
                                     eps=self.variance_epsilon,
                                     memory_efficient=True)
--------------------------------------------------------------------------------
/verl/verl/models/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import importlib
16 | from typing import List, Optional, Type
17 |
18 | import torch.nn as nn
19 |
# Supported models using HF Rmpad
# TODO(sgm): HF may support more than listed here; add them after testing.
_MODELS_SUPPORT_RMPAD = {'llama', 'mistral', 'gemma', 'qwen2', 'qwen2_vl', 'qwen2_5_vl'}


def check_model_support_rmpad(model_type: str):
    """Validate that ``model_type`` supports remove-padding (RMPad).

    For ``qwen2_vl`` / ``qwen2_5_vl`` this also monkey-patches the ``forward`` of
    the HF flash-attention classes with verl's ``ulysses_flash_attn_forward``.

    Raises:
        ValueError: if ``model_type`` is not in ``_MODELS_SUPPORT_RMPAD``.
    """
    assert isinstance(model_type, str)
    if model_type not in _MODELS_SUPPORT_RMPAD:
        # Sorted so the message is deterministic; the ". " separator was
        # previously missing ("...}.Please set").
        raise ValueError(f"Model architecture {model_type} is not supported for now. "
                         f"RMPad supported architectures: {sorted(_MODELS_SUPPORT_RMPAD)}. "
                         f"Please set `use_remove_padding=False` in the model config.")

    if model_type in ("qwen2_vl", "qwen2_5_vl"):  # patch remove padding for qwen2vl mrope
        from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward
        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2

        Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward
        Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward
        print("Qwen2vl patch applied!")
40 |
41 |
# Supported models in Megatron-LM
# Architecture -> (module, (actor/ref class, critic/rm class, rmpad class)).
_MODELS = {
    "LlamaForCausalLM": (
        "llama",
        ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad"),
    ),
    "Qwen2ForCausalLM": (
        "qwen2",
        ("ParallelQwen2ForCausalLMRmPadPP", "ParallelQwen2ForValueRmPadPP", "ParallelQwen2ForCausalLMRmPad"),
    ),
    "MistralForCausalLM": (
        "mistral",
        ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", "ParallelMistralForCausalLMRmPad"),
    ),
}


# return model class
class ModelRegistry:
    """Look up verl's Megatron model classes by HF architecture name."""

    @staticmethod
    def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]:
        """Return the Megatron class for ``model_arch``, or ``None`` if unknown.

        ``value`` selects the critic/reward-model class (index 1) instead of the
        actor/ref class (index 0). Returns ``None`` as well when the imported
        module lacks the expected class name.
        """
        if model_arch not in _MODELS:
            return None

        module_name, class_names = _MODELS[model_arch]
        chosen_name = class_names[1] if value else class_names[0]

        module_path = f"verl.models.{module_name}.megatron.modeling_{module_name}_megatron"
        module = importlib.import_module(module_path)
        return getattr(module, chosen_name, None)

    @staticmethod
    def get_supported_archs() -> List[str]:
        """Return the registered architecture names in registration order."""
        return [*_MODELS]
75 |
--------------------------------------------------------------------------------
/verl/verl/models/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/models/weight_loader_registry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
def get_weight_loader(arch: str):
    """Return the registered Megatron state-dict loader for ``arch``.

    Raises:
        ValueError: if no loader is registered for ``arch``.
    """
    from verl.models.llama.megatron.checkpoint_utils.llama_loader import load_state_dict_to_megatron_llama
    from verl.models.qwen2.megatron.checkpoint_utils.qwen2_loader import load_state_dict_to_megatron_qwen2

    loaders = {
        'LlamaForCausalLM': load_state_dict_to_megatron_llama,
        'Qwen2ForCausalLM': load_state_dict_to_megatron_qwen2,
    }

    if arch not in loaders:
        raise ValueError(f"Model architectures {arch} are not supported for now. "
                         f"Supported architectures: {loaders.keys()}")
    return loaders[arch]
28 |
--------------------------------------------------------------------------------
/verl/verl/single_controller/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
# Directory containing this package (the original wrapped this in a no-op
# single-argument os.path.join).
version_folder = os.path.dirname(os.path.abspath(__file__))

# Note(haibin.lin): single_controller.__version__ is deprecated
# Reads ``<parent of this package>/version/version``; explicit encoding avoids
# depending on the platform's locale default.
with open(os.path.join(version_folder, os.pardir, 'version', 'version'), encoding='utf-8') as f:
    __version__ = f.read().strip()

from . import base
from .base import *

__all__ = base.__all__
--------------------------------------------------------------------------------
/verl/verl/single_controller/base/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .worker import Worker
16 | from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool
17 |
18 | __all__ = ['Worker', 'WorkerGroup', 'ClassWithInitArgs', 'ResourcePool']
--------------------------------------------------------------------------------
/verl/verl/single_controller/base/megatron/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/single_controller/base/megatron/worker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo
16 |
17 |
class MegatronWorker(Worker):
    """Worker that reports its Megatron-Core parallel layout on request."""

    def __init__(self, cuda_visible_devices=None) -> None:
        super().__init__(cuda_visible_devices)

    def get_megatron_global_info(self):
        """Return tensor/data/pipeline parallel world sizes as a ``DistGlobalInfo``."""
        from megatron.core import parallel_state as mpu
        return DistGlobalInfo(
            tp_size=mpu.get_tensor_model_parallel_world_size(),
            dp_size=mpu.get_data_parallel_world_size(),
            pp_size=mpu.get_pipeline_model_parallel_world_size(),
        )

    def get_megatron_rank_info(self):
        """Return this worker's tensor/data/pipeline parallel ranks as a ``DistRankInfo``."""
        from megatron.core import parallel_state as mpu
        return DistRankInfo(
            tp_rank=mpu.get_tensor_model_parallel_rank(),
            dp_rank=mpu.get_data_parallel_rank(),
            pp_rank=mpu.get_pipeline_model_parallel_rank(),
        )
--------------------------------------------------------------------------------
/verl/verl/single_controller/base/megatron/worker_group.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Dict
16 |
17 | from .worker import DistRankInfo, DistGlobalInfo
18 | from verl.single_controller.base import ResourcePool, WorkerGroup
19 |
20 |
class MegatronWorkerGroup(WorkerGroup):
    """WorkerGroup that caches per-rank and global Megatron parallel info.

    ``_megatron_rank_info`` (indexed by rank) and ``_megatron_global_info`` start
    as ``None`` and must be populated (e.g. by a subclass) before the accessors
    below are used; each accessor asserts this.
    """

    def __init__(self, resource_pool: ResourcePool, **kwargs):
        super().__init__(resource_pool=resource_pool, **kwargs)
        self._megatron_rank_info = None
        self._megatron_global_info: DistGlobalInfo = None

    def init_megatron(self, default_megatron_kwargs: Dict = None):
        raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten")

    def get_megatron_rank_info(self, rank: int) -> DistRankInfo:
        """Return the cached ``DistRankInfo`` of worker ``rank``."""
        assert 0 <= rank < self.world_size, f'rank must be from [0, world_size), Got {rank}'
        # Without this check an uninitialized group fails with an opaque
        # "'NoneType' object is not subscriptable" TypeError.
        assert self._megatron_rank_info is not None, "MegatronWorkerGroup._megatron_rank_info must be initialized"
        return self._megatron_rank_info[rank]

    @property
    def tp_size(self):
        assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
        return self._megatron_global_info.tp_size

    @property
    def dp_size(self):
        assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
        return self._megatron_global_info.dp_size

    @property
    def pp_size(self):
        assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
        return self._megatron_global_info.pp_size

    def get_megatron_global_info(self):
        """Return the cached ``DistGlobalInfo`` (may be ``None`` if not yet set)."""
        return self._megatron_global_info
52 |
--------------------------------------------------------------------------------
/verl/verl/single_controller/base/register_center/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/single_controller/base/register_center/ray.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import ray
16 |
17 |
@ray.remote
class WorkerGroupRegisterCenter:
    """Ray actor that stores one ``rank_zero_info`` object and returns it on request."""

    def __init__(self, rank_zero_info):
        # Kept verbatim; no copying or validation.
        self.rank_zero_info = rank_zero_info

    def get_rank_zero_info(self):
        """Return the object supplied at construction."""
        stored = self.rank_zero_info
        return stored
26 |
27 |
def create_worker_group_register_center(name, info):
    """Start a ``WorkerGroupRegisterCenter`` actor named ``name`` holding ``info``; return its handle."""
    named_actor_cls = WorkerGroupRegisterCenter.options(name=name)
    return named_actor_cls.remote(info)
30 |
--------------------------------------------------------------------------------
/verl/verl/single_controller/ray/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
--------------------------------------------------------------------------------
/verl/verl/single_controller/ray/megatron.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Dict, Optional
16 |
17 | import ray
18 |
19 | from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs
20 | from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo
21 | from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
22 |
23 |
# NOTE(sgm): for open-source megatron-core
class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    """Ray worker group that caches each worker's Megatron rank info.

    After the base initialization, every worker is queried for its rank info and
    rank 0 for the global info, so that the dispatcher can use them to route data.
    """

    def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs):
        super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs)

        per_rank_info = self.execute_all_sync(method_name='get_megatron_rank_info')
        self._megatron_rank_info: DistRankInfo = per_rank_info

        global_info_ref = self.execute_rank_zero_async(method_name='get_megatron_global_info')
        self._megatron_global_info: DistGlobalInfo = ray.get(global_info_ref)
36 |
37 |
class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    """Ray worker group that initializes Megatron on its workers and caches their rank info.

    After the base initialization it calls ``init_megatron``, then queries every
    worker for its rank info and rank 0 for the global info, so that the
    dispatcher can use them to route data.
    """

    def __init__(self,
                 resource_pool: RayResourcePool,
                 ray_cls_with_init: RayClassWithInitArgs,
                 default_megatron_kwargs: Dict = None,
                 **kwargs):
        super().__init__(
            resource_pool=resource_pool,
            ray_cls_with_init=ray_cls_with_init,
            default_megatron_kwargs=default_megatron_kwargs,
            **kwargs,
        )
        self.init_megatron(default_megatron_kwargs=default_megatron_kwargs)

        per_rank_info = self.execute_all_sync(method_name='get_megatron_rank_info')
        self._megatron_rank_info: DistRankInfo = per_rank_info

        global_info_ref = self.execute_rank_zero_async(method_name='get_megatron_global_info')
        self._megatron_global_info: DistGlobalInfo = ray.get(global_info_ref)

    def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None):
        """Run ``init_megatron`` on every worker, unless the group wraps detached workers."""
        # Only a group created from scratch initializes Megatron; a group attached
        # to detached workers skips this step.
        if self._is_init_with_detached_workers:
            return
        self.execute_all_sync(method_name='init_megatron', default_megatron_kwargs=default_megatron_kwargs)
63 |
--------------------------------------------------------------------------------
/verl/verl/third_party/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/third_party/vllm/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from importlib.metadata import version, PackageNotFoundError
16 | from packaging import version as vs
17 |
18 |
def get_version(pkg):
    """Return the installed version string of distribution ``pkg``, or ``None`` if absent."""
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        installed = None
    return installed
24 |
25 |
# Select the verl vLLM integration matching the installed vLLM version.
package_name = 'vllm'
package_version = get_version(package_name)
vllm_version = None

if package_version is None:
    # vLLM is not installed. Fail explicitly here: otherwise vs.parse(None) below
    # raises an opaque TypeError.
    raise ValueError(
        f'{package_name} is not installed. Currently supported versions are 0.3.1, 0.4.2, 0.5.4, 0.6.3 and 0.7.0+'
    )
elif package_version == '0.3.1':
    vllm_version = '0.3.1'
    from .vllm_v_0_3_1.llm import LLM
    from .vllm_v_0_3_1.llm import LLMEngine
    from .vllm_v_0_3_1 import parallel_state
elif package_version == '0.4.2':
    vllm_version = '0.4.2'
    from .vllm_v_0_4_2.llm import LLM
    from .vllm_v_0_4_2.llm import LLMEngine
    from .vllm_v_0_4_2 import parallel_state
elif package_version == '0.5.4':
    vllm_version = '0.5.4'
    from .vllm_v_0_5_4.llm import LLM
    from .vllm_v_0_5_4.llm import LLMEngine
    from .vllm_v_0_5_4 import parallel_state
elif package_version == '0.6.3':
    vllm_version = '0.6.3'
    from .vllm_v_0_6_3.llm import LLM
    from .vllm_v_0_6_3.llm import LLMEngine
    from .vllm_v_0_6_3 import parallel_state
elif vs.parse(package_version) >= vs.parse('0.6.6.post2.dev252+g8027a724'):
    # From 0.6.6.post2 on, vllm supports SPMD inference
    # See https://github.com/vllm-project/vllm/pull/12071
    # NOTE: vllm_version stays None and LLMEngine is not exported on this path.

    from vllm import LLM
    from vllm.distributed import parallel_state
    from .vllm_spmd.dtensor_weight_loaders import load_dtensor_weights
else:
    raise ValueError(
        f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4, 0.6.3 and 0.7.0+'
    )
61 |
--------------------------------------------------------------------------------
/verl/verl/third_party/vllm/vllm_spmd/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/third_party/vllm/vllm_v_0_3_1/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)
19 |
20 | from vllm.lora.request import LoRARequest
21 | from vllm.utils import make_async, LRUCache
22 | from vllm.transformers_utils.tokenizers import *
23 |
24 |
class TokenizerGroup:
    """A group of tokenizers that can be used for LoRA adapters.

    Wraps an already-constructed tokenizer. When LoRA is enabled, the tokenizer
    used for each adapter is cached in an ``LRUCache`` keyed by
    ``lora_request.lora_int_id``. Currently every adapter reuses the base
    tokenizer (see the TODO in ``get_lora_tokenizer``).
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int]):
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = tokenizer
        if enable_lora:
            # Per-adapter tokenizer cache, sized by max_num_seqs.
            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
        else:
            self.lora_tokenizers = None

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Tokenize ``prompt`` with the tokenizer selected for ``lora_request``.

        ``request_id`` is accepted for interface compatibility and is not used.
        """
        tokenizer = self.get_lora_tokenizer(lora_request)
        return tokenizer.encode(prompt)

    async def encode_async(self,
                           prompt: str,
                           request_id: Optional[str] = None,
                           lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Async variant of :meth:`encode`."""
        tokenizer = await self.get_lora_tokenizer_async(lora_request)
        return tokenizer.encode(prompt)

    def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Return the tokenizer to use for ``lora_request``.

        The base tokenizer is returned when no LoRA request is given or LoRA is
        disabled. Otherwise the adapter's entry is looked up in (and on first use
        inserted into) the LRU cache; that entry is currently the base tokenizer.
        """
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id not in self.lora_tokenizers:
            # TODO(sgm): the lora tokenizer is also passed, but may be different
            tokenizer = self.tokenizer
            # tokenizer = (get_lora_tokenizer(
            #     lora_request, **self.tokenizer_config) or self.tokenizer)
            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
            return tokenizer
        else:
            return self.lora_tokenizers.get(lora_request.lora_int_id)

    async def get_lora_tokenizer_async(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Async variant of :meth:`get_lora_tokenizer`; awaited by :meth:`encode_async`.

        The lookup is a plain in-memory cache access, so this simply delegates to
        the synchronous implementation.
        """
        return self.get_lora_tokenizer(lora_request)

    # FIXME(sgm): for simplicity, we assign the special token here
    @property
    def pad_token_id(self):
        """Pad token id of the wrapped base tokenizer."""
        return self.tokenizer.pad_token_id

    @property
    def eos_token_id(self):
        """EOS token id of the wrapped base tokenizer."""
        return self.tokenizer.eos_token_id
73 |
--------------------------------------------------------------------------------
/verl/verl/third_party/vllm/vllm_v_0_4_2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/third_party/vllm/vllm_v_0_5_4/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
15 |
16 | from typing import Dict, Union, Optional, Iterable, Tuple
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from vllm.model_executor.model_loader.utils import set_default_torch_dtype
22 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23 |
24 |
def update_hf_weight_loader():
    """No-op hook: print a notice that no HF weight loader needs patching.

    Presumably kept so every vLLM version directory exposes the same hook;
    it returns None.
    """
    print('no hf weight loader need to be updated')
28 |
29 |
def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
    """Load a name -> tensor weight dict into ``vllm_model`` in place.

    Args:
        actor_weights: Mapping from parameter name to weight. NOTE: mutated --
            ``"lm_head.weight"`` is deleted when the model config sets
            ``tie_word_embeddings``.
        vllm_model: vLLM model exposing ``config``, ``load_weights`` and
            ``named_modules``; it is moved to CUDA in place at the end.
    """
    assert isinstance(actor_weights, Dict)
    # Weights are materialized with the dtype of the model's first parameter.
    with set_default_torch_dtype(next(vllm_model.parameters()).dtype):  # TODO
        # Presumably a tied lm_head shares the embedding weight, so the separate
        # entry is dropped before loading -- confirm against the model classes.
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights:
            del actor_weights["lm_head.weight"]
        vllm_model.load_weights(actor_weights.items())
    for _, module in vllm_model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)
        # FIXME: Remove this after Mixtral is updated
        # to use quant_method.
        if hasattr(module, "process_weights_after_loading"):
            module.process_weights_after_loading()
    # nn.Module.cuda() moves parameters in place; rebinding the local name
    # (which is never read again) had no effect.
    vllm_model.cuda()
45 |
--------------------------------------------------------------------------------
/verl/verl/third_party/vllm/vllm_v_0_6_3/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
15 |
16 | from typing import Dict
17 |
18 | import torch.nn as nn
19 | from vllm.model_executor.model_loader.utils import set_default_torch_dtype
20 |
21 |
def update_hf_weight_loader():
    """No-op hook: print a notice that no HF weight loader needs patching.

    Presumably kept so every vLLM version directory exposes the same hook;
    it returns None.
    """
    print("no hf weight loader need to be updated")
25 |
26 |
def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
    """Load a name -> tensor weight dict into ``vllm_model`` in place.

    Args:
        actor_weights: Mapping from parameter name to weight. NOTE: mutated --
            ``"lm_head.weight"`` is deleted when the model config sets
            ``tie_word_embeddings``.
        vllm_model: vLLM model exposing ``config``, ``load_weights`` and
            ``named_modules``; it is moved to CUDA in place at the end.
    """
    assert isinstance(actor_weights, Dict)
    # Weights are materialized with the dtype of the model's first parameter.
    with set_default_torch_dtype(next(vllm_model.parameters()).dtype):  # TODO
        # Presumably a tied lm_head shares the embedding weight, so the separate
        # entry is dropped before loading -- confirm against the model classes.
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights:
            del actor_weights["lm_head.weight"]
        vllm_model.load_weights(actor_weights.items())
    for _, module in vllm_model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)
        # FIXME: Remove this after Mixtral is updated
        # to use quant_method.
        if hasattr(module, "process_weights_after_loading"):
            module.process_weights_after_loading()
    # nn.Module.cuda() moves parameters in place; rebinding the local name
    # (which is never read again) had no effect.
    vllm_model.cuda()
42 |
--------------------------------------------------------------------------------
/verl/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
15 |
16 | from typing import Optional
17 |
18 | from transformers import PreTrainedTokenizer
19 | from vllm.transformers_utils.tokenizer_group import TokenizerGroup
20 | from vllm.utils import LRUCache
21 |
22 |
class TokenizerGroup(TokenizerGroup):
    """LoRA-aware tokenizer group built around an already-constructed tokenizer.

    Subclasses the imported vLLM ``TokenizerGroup`` (same name) but does not call
    its constructor; the tokenizer object is supplied directly by the caller.
    NOTE(review): presumably the parent constructor would load its own tokenizer
    -- confirm before adding a ``super().__init__`` call.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int]):
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = tokenizer
        # Per-adapter tokenizer cache; only allocated when LoRA is enabled.
        cache = None
        if enable_lora:
            cache = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs)
        self.lora_tokenizers = cache

    # FIXME(sgm): for simplicity, we assign the special token here
    @property
    def pad_token_id(self):
        """Pad token id of the wrapped base tokenizer."""
        tok = self.tokenizer
        return tok.pad_token_id

    @property
    def eos_token_id(self):
        """EOS token id of the wrapped base tokenizer."""
        tok = self.tokenizer
        return tok.eos_token_id
41 |
--------------------------------------------------------------------------------
/verl/verl/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/trainer/config/evaluation.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | path: /tmp/math_Qwen2-7B-Instruct.parquet
3 | prompt_key: prompt
4 | response_key: responses
5 | data_source_key: data_source
6 | reward_model_key: reward_model
--------------------------------------------------------------------------------
/verl/verl/trainer/config/generation.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | nnodes: 1
3 | n_gpus_per_node: 8
4 |
5 | data:
6 | path: ~/data/rlhf/math/test.parquet
7 | prompt_key: prompt
8 | n_samples: 5
9 | output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet
10 | batch_size: 128
11 |
12 | model:
13 | path: ~/models/Qwen2-7B-Instruct
14 | external_lib: null
15 | rollout:
16 | name: vllm
17 | temperature: 1.0
18 | top_k: 50 # 0 for hf rollout, -1 for vllm rollout
19 | top_p: 0.7
20 | prompt_length: 1536
21 | response_length: 512
22 | # for vllm rollout
23 | dtype: bfloat16 # should align with FSDP
24 | gpu_memory_utilization: 0.5
25 | ignore_eos: False
26 | enforce_eager: True
27 | free_cache_engine: True
28 | load_format: dummy_dtensor
29 | tensor_model_parallel_size: 1
30 | max_num_batched_tokens: 8192
31 | max_model_len: null
32 | max_num_seqs: 1024
33 | log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
34 | log_prob_micro_batch_size_per_gpu: 8
35 | # for fire vllm rollout
36 | use_fire_sampling: False # enable FIRE https://arxiv.org/abs/2410.21236
37 | # for hf rollout
38 | do_sample: True
39 | disable_log_stats: True
40 | enable_chunked_prefill: True
41 | n: 1
42 | actor:
43 | strategy: fsdp # This is for backward-compatibility
44 | ppo_mini_batch_size: 256
45 | ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
46 | ppo_micro_batch_size_per_gpu: null
47 | use_dynamic_bsz: False
48 | ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
49 | grad_clip: 1.0
50 | clip_ratio: 0.2
51 | entropy_coeff: 0.001
52 | use_kl_loss: False # True for GRPO
53 | kl_loss_coef: 0.001 # for grpo
54 | kl_loss_type: low_var_kl # for grpo
55 | ppo_epochs: 1
56 | shuffle: False
57 | ulysses_sequence_parallel_size: 1 # sp size
58 | optim:
59 | lr: 1e-6
60 | lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
61 | min_lr_ratio: null # only useful for warmup with cosine
62 | warmup_style: constant # select from constant/cosine
63 | total_training_steps: -1 # must be overridden by the program
64 | fsdp_config:
65 | wrap_policy:
66 | min_num_params: 0
67 | param_offload: False
68 | optimizer_offload: False
69 | fsdp_size: -1
--------------------------------------------------------------------------------
/verl/verl/trainer/config/sft_trainer.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_batch_size: 256
3 | micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
4 | micro_batch_size_per_gpu: 4 # this is also val batch size
5 | train_files: ~/data/gsm8k/train.parquet
6 | val_files: ~/data/gsm8k/test.parquet
7 | prompt_key: question
8 | response_key: answer
9 | max_length: 1024
10 | truncation: error
11 | balance_dp_token: False
12 | chat_template: null
13 | model:
14 | partial_pretrain: ~/models/gemma-1.1-7b-it
15 | fsdp_config:
16 | wrap_policy:
17 | min_num_params: 0
18 | cpu_offload: False
19 | offload_params: False
20 | external_lib: null
21 | enable_gradient_checkpointing: False
22 | trust_remote_code: False
23 | lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32)
24 | lora_alpha: 16 # LoRA scaling factor
25 | target_modules: all-linear # Target modules for LoRA adaptation
26 | use_liger: False
27 | optim:
28 | lr: 1e-5
29 | betas: [0.9, 0.95]
30 | weight_decay: 0.01
31 | warmup_steps_ratio: 0.1
32 | clip_grad: 1.0
33 | ulysses_sequence_parallel_size: 1
34 | use_remove_padding: False
35 | trainer:
36 | default_local_dir: /tmp/sft_model
37 | default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here
38 | resume_path: null
39 | project_name: gsm8k-sft
40 | experiment_name: test
41 | total_epochs: 4
42 | total_training_steps: null
43 | logger: ['console']
44 | seed: 1
45 |
46 |
--------------------------------------------------------------------------------
/verl/verl/trainer/main_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Offline evaluate the performance of a generated file using reward model and ground truth verifier.
16 | The input is a parquet file that contains N generated sequences and (optional) the ground truth.
17 |
18 | """
19 |
20 | import hydra
21 | from verl.utils.fs import copy_to_local
22 | from verl.utils.reward_score import math, gsm8k
23 | import pandas as pd
24 | import numpy as np
25 |
26 |
def select_reward_fn(data_source):
    """Return the rule-based scoring function for ``data_source``.

    The returned callable is invoked as ``fn(response, ground_truth)``.

    Raises:
        NotImplementedError: if no scorer is registered for ``data_source``.
    """
    if data_source == 'lighteval/MATH':
        return math.compute_score
    if data_source == 'openai/gsm8k':
        # gsm8k is already imported by this module; this is the source name used
        # in the dataset README.
        return gsm8k.compute_score
    raise NotImplementedError(f'No reward function is registered for data source {data_source!r}')
32 |
33 |
@hydra.main(config_path='config', config_name='evaluation', version_base=None)
def main(config):
    """Print the pass@k rate of a parquet file of generated responses.

    Each row provides, under the column names in ``config.data``:
    a list of responses, a data source (used to pick the scorer via
    ``select_reward_fn``), and a reward-model dict holding ``'ground_truth'``.
    A row passes when at least one of its responses scores exactly 1.
    """
    local_path = copy_to_local(config.data.path)
    dataset = pd.read_parquet(local_path)
    # Convert to plain lists and iterate positionally: label-based ``series[i]``
    # breaks when the parquet file was written with a non-default index.
    responses = dataset[config.data.response_key].tolist()
    data_sources = dataset[config.data.data_source_key].tolist()
    reward_model_data = dataset[config.data.reward_model_key].tolist()

    total = len(dataset)
    if total == 0:
        print('No samples found in the dataset; nothing to evaluate.')
        return

    passes = 0
    sample_counts = set()
    for response_lst, data_source, reward_data in zip(responses, data_sources, reward_model_data):
        # select reward score based on data_source
        reward_fn = select_reward_fn(data_source)
        ground_truth = reward_data['ground_truth']
        sample_counts.add(len(response_lst))
        score_lst = [reward_fn(r, ground_truth) for r in response_lst]
        # Guard against rows with no responses (np.max on an empty list raises).
        if score_lst and np.max(score_lst) == 1:
            passes += 1

    # Label with the actual number of samples per prompt instead of a hard-coded 5.
    k = sample_counts.pop() if len(sample_counts) == 1 else 'k'
    print(f'pass@{k}: {passes / total}')


if __name__ == '__main__':
    main()
70 |
--------------------------------------------------------------------------------
/verl/verl/trainer/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/trainer/runtime_env.yaml:
--------------------------------------------------------------------------------
1 | working_dir: ./
2 | excludes: ["/.git/"]
3 | env_vars:
4 | TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
5 | VLLM_ATTENTION_BACKEND: "XFORMERS"
--------------------------------------------------------------------------------
/verl/verl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import tokenizer
16 | from .tokenizer import hf_tokenizer, hf_processor
17 |
# Re-export the tokenizer submodule's public API as this package's public API.
__all__ = tokenizer.__all__
--------------------------------------------------------------------------------
/verl/verl/utils/checkpoint/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
--------------------------------------------------------------------------------
/verl/verl/utils/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Dict
16 |
17 | from omegaconf import DictConfig
18 |
19 |
def update_dict_with_config(dictionary: Dict, config: DictConfig):
    """Overwrite, in place, each key of ``dictionary`` that ``config`` also defines.

    Keys missing from ``config`` keep their current values. Attributes of
    ``config`` that are not already keys of ``dictionary`` are ignored.
    """
    for name in dictionary:
        if not hasattr(config, name):
            continue
        dictionary[name] = getattr(config, name)
24 |
--------------------------------------------------------------------------------
/verl/verl/utils/dataset/README.md:
--------------------------------------------------------------------------------
1 | # Dataset Format
2 | ## RLHF dataset
3 | We combine all the data sources into a single parquet file. We directly organize the prompt into the chat format so that multi-turn chats can be easily incorporated. In the prompt, we may add instruction-following text to guide the model to output the answers in a particular format so that we can extract the answers.
4 |
5 | Math problems
6 | ```json
7 | {
8 | "data_source": "openai/gsm8k",
9 | "prompt": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \"####\""}],
10 | "ability": "math",
11 | "reward_model": {
12 | "style": "rule",
13 | "ground_truth": ["72"]
14 | }
15 | }
16 | ```
17 |
--------------------------------------------------------------------------------
/verl/verl/utils/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .rl_dataset import RLHFDataset
16 | from .rm_dataset import RMDataset
17 | from .sft_dataset import SFTDataset
18 |
--------------------------------------------------------------------------------
/verl/verl/utils/debug/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .performance import log_gpu_memory_usage
--------------------------------------------------------------------------------
/verl/verl/utils/debug/performance.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.distributed as dist
17 | import logging
18 |
19 |
def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0):
    """Report current CUDA allocated/reserved memory in GB, prefixed by ``head``.

    Only the process whose distributed rank equals ``rank`` reports. Every
    process reports when torch.distributed is not initialized or ``rank`` is
    None. Output goes to ``logger`` at ``level`` when a logger is given,
    otherwise to stdout.
    """
    reporting = not dist.is_initialized() or rank is None or dist.get_rank() == rank
    if not reporting:
        return

    gib = 1024**3
    allocated_gb = torch.cuda.memory_allocated() / gib
    reserved_gb = torch.cuda.memory_reserved() / gib
    text = f'{head}, memory allocated (GB): {allocated_gb}, memory reserved (GB): {reserved_gb}'

    if logger is not None:
        logger.log(msg=text, level=level)
    else:
        print(text)
31 |
--------------------------------------------------------------------------------
/verl/verl/utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities for distributed training."""
15 | import os
16 |
17 |
def initialize_global_process_group(timeout_second=36000):
    """Initialize the default NCCL process group from torchrun-style env vars.

    Reads LOCAL_RANK, RANK and WORLD_SIZE from the environment, binds this
    process to CUDA device ``LOCAL_RANK``, then creates the process group.

    Args:
        timeout_second: Timeout, in seconds, for operations on the process group.

    Returns:
        Tuple ``(local_rank, rank, world_size)``.

    Raises:
        KeyError: if any of the required environment variables is missing.
    """
    import torch.distributed
    from datetime import timedelta

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Select this rank's GPU *before* creating the NCCL group, so that NCCL work
    # (e.g. a later barrier) uses it instead of guessing cuda:0 on every rank.
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second))
    return local_rank, rank, world_size
29 |
--------------------------------------------------------------------------------
/verl/verl/utils/import_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Utilities to check if packages are available.
16 | We assume package availability won't change during runtime.
17 | """
18 |
19 | from functools import cache
20 | from typing import List
21 |
22 |
@cache
def is_megatron_core_available():
    """Return True if ``megatron.core.parallel_state`` can be imported.

    The answer is memoized by ``functools.cache``, so the import is attempted once.
    """
    try:
        from megatron.core import parallel_state  # noqa: F401
    except ImportError:
        return False
    else:
        return True
30 |
31 |
@cache
def is_vllm_available():
    """Return True if the ``vllm`` package can be imported (memoized after the first call)."""
    try:
        import vllm  # noqa: F401
    except ImportError:
        return False
    else:
        return True
39 |
40 |
def import_external_libs(external_libs=None):
    """Import one or more modules by name.

    The imported modules are not returned, so this is presumably used for the
    side effects of importing them.

    Args:
        external_libs: ``None`` (do nothing), a single module name, or any iterable
            of module names (list, tuple, or another sequence type that is not a
            ``list`` subclass).
    """
    if external_libs is None:
        return
    # Only a bare string is treated as a single name. Checking for ``list`` instead
    # would hand a tuple (or other non-list sequence) to ``import_module`` whole,
    # which fails.
    if isinstance(external_libs, str):
        external_libs = [external_libs]
    import importlib
    for external_lib in external_libs:
        importlib.import_module(external_lib)
49 |
--------------------------------------------------------------------------------
/verl/verl/utils/logger/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/utils/logger/aggregate_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | A Ray logger will receive logging info from different processes.
16 | """
17 | import numbers
18 | from typing import Dict
19 |
20 |
def concat_dict_to_str(dict: Dict, step):
    """Format the numeric entries of a metrics mapping as a single log line.

    Produces ``'step:<step> - key:value - ...'`` with each value rendered to three
    decimal places; entries whose value is not a ``numbers.Number`` are skipped.
    The parameter keeps its original name (``dict``) for keyword-argument callers.
    """
    parts = [f'step:{step}']
    parts.extend(f'{key}:{value:.3f}' for key, value in dict.items() if isinstance(value, numbers.Number))
    return ' - '.join(parts)
28 |
29 |
class LocalLogger:
    """Console-only logger that formats metric dicts with ``concat_dict_to_str``.

    ``remote_logger`` and ``enable_wandb`` are accepted for signature compatibility
    but are not used.
    """

    def __init__(self, remote_logger=None, enable_wandb=False, print_to_console=False):
        self.print_to_console = print_to_console
        if self.print_to_console:
            print('Using LocalLogger is deprecated. The constructor API will change ')

    def flush(self):
        """No-op: ``log`` already flushes stdout on every call."""
        return None

    def log(self, data, step):
        """Print ``data`` for ``step`` when console output is enabled; otherwise do nothing."""
        if not self.print_to_console:
            return
        line = concat_dict_to_str(data, step=step)
        print(line, flush=True)
--------------------------------------------------------------------------------
/verl/verl/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 |
17 |
def set_basic_config(level):
    """Configure the root logger's format and level via ``logging.basicConfig``.

    Intended to be called when verl is imported. ``basicConfig`` only takes effect
    if the root logger has no handlers yet.
    """
    log_format = '%(levelname)s:%(asctime)s:%(message)s'
    logging.basicConfig(level=level, format=log_format)
23 |
--------------------------------------------------------------------------------
/verl/verl/utils/megatron/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/utils/megatron/memory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 |
class MemoryBuffer:
    """A flat, zero-initialised 1-D tensor on the current CUDA device.

    Views of arbitrary shape can be carved out of it with ``get``. The backing
    tensor holds ``numel_padded`` elements, but only the first ``numel`` are handed
    out by ``get``.
    """

    def __init__(self, numel, numel_padded, dtype):
        self.numel = numel
        self.numel_padded = numel_padded
        self.dtype = dtype
        device = torch.cuda.current_device()
        self.data = torch.zeros(numel_padded, dtype=dtype, device=device, requires_grad=False)

    def zero(self):
        """Reset the buffer to zero in place (existing views observe the change)."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a view of ``shape`` into the 1-D data starting at ``start_index``.

        ``shape`` must provide ``.numel()`` (e.g. ``torch.Size``). Raises AssertionError
        if the view would extend past ``numel``; the padded tail is never returned.
        """
        stop = start_index + shape.numel()
        assert stop <= self.numel, 'requested tensor is out of the buffer range.'
        return self.data[start_index:stop].view(shape)
42 |
--------------------------------------------------------------------------------
/verl/verl/utils/megatron/optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import importlib
17 | from packaging.version import Version
18 |
19 | from apex.optimizers import FusedAdam as Adam
20 | from apex.optimizers import FusedSGD as SGD
21 |
22 | from megatron.core.optimizer import OptimizerConfig
23 |
24 | from megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native
25 |
26 |
def get_megatron_optimizer(
        model,
        config: OptimizerConfig,
        no_weight_decay_cond=None,
        scale_lr_cond=None,
        lr_mult=1.0,
        check_for_nan_in_loss_and_grad=False,
        overlap_param_gather=False  # add for verl
):
    """Build the base Megatron-core optimizer for ``model``.

    Thin adapter over ``megatron.core.optimizer.get_megatron_optimizer``: ``model`` is
    forwarded as ``model_chunks``, and the weight-decay / lr-scaling options are passed
    through unchanged.

    NOTE: ``check_for_nan_in_loss_and_grad`` and ``overlap_param_gather`` are accepted
    for call-site compatibility but are NOT forwarded to the native builder.
    """
    native_kwargs = dict(
        config=config,
        model_chunks=model,
        no_weight_decay_cond=no_weight_decay_cond,
        scale_lr_cond=scale_lr_cond,
        lr_mult=lr_mult,
    )
    return get_megatron_optimizer_native(**native_kwargs)
42 |
--------------------------------------------------------------------------------
/verl/verl/utils/megatron/pipeline_parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | from megatron.core import parallel_state as mpu
18 |
19 | from .sequence_parallel import pad_to_sequence_parallel
20 |
21 |
def compute_transformers_input_shapes(batches, meta_info):
    """Pre-compute the per-micro-batch activation shape seen at each pipeline stage.

    For each micro-batch, padding is removed from ``input_ids`` using
    ``attention_mask``. The resulting token count (padded and divided across the
    tensor-parallel group when ``meta_info['sequence_parallel']`` is set) becomes the
    first dim of ``torch.Size([seq, 1, meta_info['hidden_size']])``.
    """
    from flash_attn.bert_padding import unpad_input  # flash 2 is a must for Megatron

    shapes = []
    for micro_batch in batches:
        tokens = micro_batch['input_ids']
        mask = micro_batch['attention_mask']
        # (total_nnz, 1): only the non-padding tokens remain
        tokens_rmpad = unpad_input(tokens.unsqueeze(dim=-1), mask)[0]
        if meta_info['sequence_parallel']:
            tokens_rmpad = pad_to_sequence_parallel(tokens_rmpad)
            seq_len = tokens_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size()
        else:
            seq_len = tokens_rmpad.shape[0]
        shapes.append(torch.Size([seq_len, 1, meta_info['hidden_size']]))
    return shapes
41 |
42 |
def make_batch_generator(batches, vpp_size):
    """Wrap ``batches`` in iterator(s) for the pipeline schedule.

    With virtual pipeline parallelism (``vpp_size > 1``), returns a list of
    ``vpp_size`` iterators, each from its own ``iter(batches)`` call. They are
    independent when ``batches`` is re-iterable, e.g. a list. Otherwise returns a
    single iterator.
    """
    if vpp_size <= 1:
        # no vpp
        return iter(batches)
    # has vpp: one iterator per virtual pipeline chunk
    return [iter(batches) for _ in range(vpp_size)]
52 |
--------------------------------------------------------------------------------
/verl/verl/utils/megatron/sequence_parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | import torch.nn.functional as F
18 | from megatron.core import parallel_state as mpu
19 |
20 |
def mark_parameter_as_sequence_parallel(parameter):
    """Tag ``parameter`` with ``sequence_parallel = True`` (checked by ``is_sequence_parallel_param``)."""
    parameter.sequence_parallel = True
23 |
24 |
def is_sequence_parallel_param(param):
    """Return the ``sequence_parallel`` attribute of ``param``, or False when it is absent."""
    return getattr(param, 'sequence_parallel', False)
27 |
28 |
def pad_to_sequence_parallel(unpad_tokens: torch.Tensor, sp_world_size=None):
    """Pad the tokens so that the total length (dim 0) is a multiple of the SP world size.

    By default the sequence-parallel world size is the tensor-model-parallel world
    size reported by Megatron's ``parallel_state``.

    Args:
        unpad_tokens: (total_nnz, ...). Tokens after removing padding. Any rank >= 1 is
            accepted; only dim 0 is padded, with zeros on the right.
        sp_world_size: optional explicit world size. When ``None`` (default) it is
            queried via ``mpu.get_tensor_model_parallel_world_size()``.

    Returns:
        torch.Tensor: ``unpad_tokens`` itself if no padding is needed, otherwise a new
        tensor whose dim 0 is rounded up to the next multiple of the world size.
    """
    if sp_world_size is None:
        sp_world_size = mpu.get_tensor_model_parallel_world_size()

    total_nnz = unpad_tokens.shape[0]
    # Python's modulo of a negative number is non-negative: 0 when already aligned.
    pad_size = -total_nnz % sp_world_size

    if pad_size > 0:
        # F.pad lists (left, right) pairs starting from the LAST dim, so leave every
        # trailing dim untouched and pad only the right side of dim 0. (The previous
        # ndim-specific branches called ``ndim()`` on an int in their error path.)
        pad = (0, 0) * (unpad_tokens.ndim - 1) + (0, pad_size)
        unpad_tokens = F.pad(unpad_tokens, pad)

    return unpad_tokens
55 |
--------------------------------------------------------------------------------
/verl/verl/utils/py_functional.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Contain small python utility functions
16 | """
17 |
18 | from typing import Dict
19 | from types import SimpleNamespace
20 |
21 |
def union_two_dict(dict1: Dict, dict2: Dict):
    """Merge ``dict2`` into ``dict1`` in place and return ``dict1``.

    Keys present in both dicts must have equal values (compared with ``==``, despite
    the assertion message mentioning "same object"); otherwise AssertionError is raised.
    """
    for key in dict2:
        value = dict2[key]
        if key in dict1:
            assert value == dict1[key], f'{key} in meta_dict1 and meta_dict2 are not the same object'
        dict1[key] = value
    return dict1
39 |
40 |
def append_to_dict(data: Dict, new_data: Dict):
    """Append each value of ``new_data`` to the list stored under the same key in ``data``.

    Missing keys get a fresh list. ``data`` is mutated in place; nothing is returned.
    """
    for key, val in new_data.items():
        data.setdefault(key, []).append(val)
46 |
47 |
class NestedNamespace(SimpleNamespace):
    """A SimpleNamespace built from a dict, converting nested dicts recursively.

    ``kwargs`` are applied first; entries of ``dictionary`` are set afterwards and so
    win on name clashes.
    """

    def __init__(self, dictionary, **kwargs):
        super().__init__(**kwargs)
        for name, item in dictionary.items():
            # Nested plain dicts become NestedNamespace so their keys are attribute-accessible.
            converted = NestedNamespace(item) if isinstance(item, dict) else item
            setattr(self, name, converted)
57 |
--------------------------------------------------------------------------------
/verl/verl/utils/ray_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Contains commonly used utilities for ray
16 | """
17 |
18 | import ray
19 |
20 | import concurrent.futures
21 |
22 |
def parallel_put(data_list, max_workers=None):
    """Put every item of ``data_list`` into the Ray object store using a thread pool.

    Args:
        data_list: sized sequence of objects to ``ray.put``.
        max_workers: thread-pool size; defaults to ``min(len(data_list), 16)``.

    Returns:
        list: object refs in the same order as ``data_list`` (``[]`` for empty input).
    """
    if len(data_list) == 0:
        # The default below would be 0 workers, which ThreadPoolExecutor rejects.
        return []

    if max_workers is None:
        max_workers = min(len(data_list), 16)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Executor.map yields results in input order, so no manual re-indexing is needed.
        return list(executor.map(ray.put, data_list))
44 |
--------------------------------------------------------------------------------
/verl/verl/utils/rendezvous/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/utils/rendezvous/ray_backend.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | import time
17 |
18 | from cupy.cuda.nccl import NcclCommunicator, get_unique_id
19 |
20 | import ray
21 | from ray.util import list_named_actors
22 |
23 |
@ray.remote
class NCCLIDStore:
    """Ray actor holding an NCCL unique id, so other ranks can fetch it by actor name."""

    def __init__(self, nccl_id):
        # Stored verbatim; ``get`` hands back exactly this object.
        self._nccl_id = nccl_id

    def get(self):
        """Return the stored NCCL unique id."""
        stored_id = self._nccl_id
        return stored_id
32 |
33 |
def get_nccl_id_store_by_name(name):
    """Look up the named Ray actor ``name`` across all namespaces.

    Returns the actor handle if exactly one match exists. Otherwise it logs (a warning
    for duplicates, info when none is found) and returns None.
    """
    candidates = [
        entry for entry in list_named_actors(all_namespaces=True) if entry.get("name", None) == name
    ]
    if len(candidates) == 1:
        return ray.get_actor(**candidates[0])
    if len(candidates) > 1:
        logging.warning(f"multiple actors with same name found: {candidates}")
    else:
        logging.info(f"failed to get any actor named {name}")
    return None
45 |
46 |
def create_nccl_communicator_in_ray(rank: int,
                                    world_size: int,
                                    group_name: str,
                                    max_retries: int = 100,
                                    interval_s: int = 5):
    """Create a cupy NCCL communicator, exchanging the NCCL unique id through Ray.

    Rank 0 generates the unique id and publishes it in an ``NCCLIDStore`` actor named
    ``group_name``. Every other rank polls for that actor up to ``max_retries`` times,
    sleeping ``interval_s`` seconds between attempts, and then joins with the same id.

    Returns:
        NcclCommunicator: the communicator for ``rank``.

    Raises:
        TimeoutError: if a non-zero rank cannot find the id store within ``max_retries``
            attempts.
    """
    if rank == 0:
        nccl_id = get_unique_id()
        nccl_id_store = NCCLIDStore.options(name=group_name).remote(nccl_id)
        # ray.get blocks until the actor answers; also verify it returns the same id.
        assert ray.get(nccl_id_store.get.remote()) == nccl_id
        return NcclCommunicator(
            ndev=world_size,
            commId=nccl_id,
            rank=0,
        )

    for i in range(max_retries):
        nccl_id_store = get_nccl_id_store_by_name(group_name)
        if nccl_id_store is not None:
            logging.info(f"nccl_id_store {group_name} got")
            nccl_id = ray.get(nccl_id_store.get.remote())
            logging.info(f"nccl id for {group_name} got: {nccl_id}")
            return NcclCommunicator(
                ndev=world_size,
                commId=nccl_id,
                rank=rank,
            )
        logging.info(f"failed to get nccl_id for {i+1} time, sleep for {interval_s} seconds")
        time.sleep(interval_s)

    # Fail loudly instead of returning None, which would only surface later as a
    # confusing AttributeError at the first collective call.
    raise TimeoutError(f"failed to get nccl_id for group {group_name} after {max_retries} retries "
                       f"(rank {rank})")
78 |
--------------------------------------------------------------------------------
/verl/verl/utils/reward_score/geo3k.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from mathruler.grader import extract_boxed_content, grade_answer
16 |
17 |
def compute_score(predict_str: str, ground_truth: str) -> float:
    """Return 1.0 if the answer extracted from ``predict_str`` grades as correct, else 0.0.

    Extraction (``extract_boxed_content``) and comparison (``grade_answer``) are
    delegated to ``mathruler.grader``.
    """
    is_correct = grade_answer(extract_boxed_content(predict_str), ground_truth)
    return 1.0 if is_correct else 0.0
24 |
--------------------------------------------------------------------------------
/verl/verl/utils/reward_score/gsm8k.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 |
17 |
def extract_solution(solution_str, method='strict'):
    """Extract the final numeric answer from a GSM8K-style solution string.

    Args:
        solution_str: the model output text.
        method: ``'strict'`` requires the answer to follow a ``#### `` marker (this also
            tests the model's output format). ``'flexible'`` takes the last number-like
            token anywhere in the text.

    Returns:
        The answer string, or None if no answer is found. In strict mode commas are
        removed from the number; in flexible mode the token is returned as matched.
    """
    assert method in ['strict', 'flexible']

    if method == 'strict':
        # this also tests the formatting of the model
        solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
        if solution is None:
            return None
        return solution.group(1).replace(',', '')

    # flexible: find the last number-like token that is not just '.'
    invalid_str = ['', '.']
    for candidate in reversed(re.findall("(\\-?[0-9\\.\\,]+)", solution_str)):
        if candidate not in invalid_str:
            return candidate
    # no reward if there is no (valid) answer; previously an all-invalid match list
    # leaked the last invalid token (e.g. '.') as the answer
    return None
42 |
43 |
def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1.):
    """The scoring function for GSM8k.

    Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.

    Args:
        solution_str: the solution text
        ground_truth: the ground truth
        method: the method to extract the solution, choices are 'strict' and 'flexible'
        format_score: score returned when an answer is extracted but does not match
        score: score returned when the extracted answer equals ``ground_truth``

    Returns:
        ``0`` when no answer can be extracted, otherwise ``score`` or ``format_score``.
    """
    answer = extract_solution(solution_str=solution_str, method=method)
    if answer is None:
        return 0
    return score if answer == ground_truth else format_score
--------------------------------------------------------------------------------
/verl/verl/utils/reward_score/prime_code/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 PRIME team and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Borrowed from: https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py
16 |
17 | import multiprocessing
18 | from typing import Dict, Optional
19 | from datasets import load_dataset
20 | from .testing_util import run_test
21 | import traceback
22 | import os, sys
23 |
24 |
def _temp_run(sample, generation, debug, result, metadata_list, timeout):
    """Child-process target used by ``check_correctness``.

    Silences stdout/stderr, runs ``run_test`` for ``generation`` against ``sample`` and
    appends the per-test results and metadata to the shared ``result`` /
    ``metadata_list``. On any exception, every test in ``sample['inputs']`` is recorded
    as -1 and an empty metadata dict is appended.
    """
    with open(os.devnull, 'w') as sink:
        # Discard everything the tested program (and any traceback) writes.
        sys.stdout = sink
        sys.stderr = sink
        try:
            outcome, meta = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
            result.append(outcome)
            metadata_list.append(meta)
        except Exception:
            # Traceback capped at 10 frames; it goes to the silenced stderr.
            traceback.print_exc(10)
            result.append([-1] * len(sample['inputs']))
            metadata_list.append({})
38 |
39 |
def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
    """Check correctness of code generation with a global timeout.

    The global timeout is to catch some extreme/rare cases not handled by the timeouts
    inside `run_test`. The test runs in a child process; if it has not finished after
    ``timeout + 1`` seconds it is killed.

    Returns:
        tuple: ``(results, metadata_list)``. ``results`` is the per-test result list
        from the child, or ``[-1] * len(in_outs["inputs"])`` if the child produced
        nothing (timeout or crash). ``metadata_list`` is a plain list copied out of the
        manager.
    """
    # The Manager runs a server process; the ``with`` block guarantees it is shut down
    # instead of lingering until garbage collection.
    with multiprocessing.Manager() as manager:
        shared_result = manager.list()
        shared_metadata = manager.list()
        p = multiprocessing.Process(target=_temp_run,
                                    args=(in_outs, generation, debug, shared_result, shared_metadata, timeout))
        p.start()
        p.join(timeout=timeout + 1)
        if p.is_alive():
            p.kill()
            p.join()  # reap the killed child
        # Copy out of the proxies before the manager shuts down.
        result = list(shared_result)
        metadata_list = list(shared_metadata)

    if not result:
        # consider that all tests failed
        result = [[-1 for _ in range(len(in_outs["inputs"]))]]
        if debug:
            print("global timeout")
    return result[0], metadata_list
60 |
--------------------------------------------------------------------------------
/verl/verl/utils/torch_dtypes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Adapted from Cruise.
16 | """
17 |
18 | import torch
19 |
20 | from typing import Union
21 |
# Accepted spellings (ints, strings and torch dtypes) for each precision; the
# PrecisionType is_*/to_dtype helpers test membership in these lists.
HALF_LIST = [16, "16", "fp16", "float16", torch.float16]
FLOAT_LIST = [32, "32", "fp32", "float32", torch.float32]
BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16]
25 |
26 |
class PrecisionType(object):
    """Type of precision used.

    The class attributes are the canonical string spelling of each precision. This is
    a plain class (not an ``Enum``), so the attributes are ordinary strings:

    >>> PrecisionType.HALF == "16"
    True
    >>> PrecisionType.HALF in (16, "16")
    True
    >>> PrecisionType.supported_type("bf16")
    True
    """

    HALF = "16"
    FLOAT = "32"
    FULL = "64"
    BFLOAT = "bf16"
    MIXED = "mixed"

    @staticmethod
    def supported_type(precision: Union[str, int]) -> bool:
        """Return True if ``precision`` equals (``==``) one of the canonical strings.

        No alias handling is done here (e.g. the int ``16`` is not accepted); use the
        ``is_*`` / ``to_dtype`` helpers for alias-tolerant checks.
        """
        return any(x == precision for x in PrecisionType.supported_types())

    @staticmethod
    def supported_types() -> list[str]:
        """Return the canonical precision strings in declaration order."""
        # The class cannot be iterated like an Enum (``for x in PrecisionType`` raises
        # TypeError), so list the members explicitly.
        return [
            PrecisionType.HALF,
            PrecisionType.FLOAT,
            PrecisionType.FULL,
            PrecisionType.BFLOAT,
            PrecisionType.MIXED,
        ]

    @staticmethod
    def is_fp16(precision):
        """Return True if ``precision`` is any accepted spelling of float16."""
        return precision in HALF_LIST

    @staticmethod
    def is_fp32(precision):
        """Return True if ``precision`` is any accepted spelling of float32."""
        return precision in FLOAT_LIST

    @staticmethod
    def is_bf16(precision):
        """Return True if ``precision`` is any accepted spelling of bfloat16."""
        return precision in BFLOAT_LIST

    @staticmethod
    def to_dtype(precision):
        """Map an accepted precision spelling to its torch dtype; raise RuntimeError otherwise."""
        if precision in HALF_LIST:
            return torch.float16
        elif precision in FLOAT_LIST:
            return torch.float32
        elif precision in BFLOAT_LIST:
            return torch.bfloat16
        else:
            raise RuntimeError(f"unexpected precision: {precision}")

    @staticmethod
    def to_str(precision):
        """Map a torch dtype to its short name ('fp16', 'fp32', 'bf16'); raise RuntimeError otherwise."""
        if precision == torch.float16:
            return 'fp16'
        elif precision == torch.float32:
            return 'fp32'
        elif precision == torch.bfloat16:
            return 'bf16'
        else:
            raise RuntimeError(f"unexpected precision: {precision}")
83 |
--------------------------------------------------------------------------------
/verl/verl/version/version:
--------------------------------------------------------------------------------
1 | 0.2.0.dev
2 |
--------------------------------------------------------------------------------
/verl/verl/workers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/verl/workers/actor/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import BasePPOActor
16 | from .dp_actor import DataParallelPPOActor
17 |
# Public API of the actor package: the abstract base and the data-parallel actor.
__all__ = ["BasePPOActor", "DataParallelPPOActor"]
19 |
--------------------------------------------------------------------------------
/verl/verl/workers/actor/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The base class for Actor
16 | """
17 | from abc import ABC, abstractmethod
18 | from typing import Iterable, Dict
19 |
20 | from verl import DataProto
21 | import torch
22 |
23 | __all__ = ['BasePPOActor']
24 |
25 |
class BasePPOActor(ABC):
    """Abstract interface for a PPO actor.

    Concrete subclasses must implement :meth:`compute_log_prob` and :meth:`update_policy`.
    """

    def __init__(self, config):
        """The base class for PPO actor

        Args:
            config (DictConfig): a config passed to the PPOActor. We expect the type to be
                DictConfig (https://omegaconf.readthedocs.io/), but it can be any namedtuple in general.
        """
        super().__init__()
        # Kept as-is; this base class neither validates nor reads it.
        self.config = config

    @abstractmethod
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute log-probabilities for a batch of data.

        Args:
            data (DataProto): a batch of data represented by DataProto. It must contain the keys
                ``input_ids``, ``attention_mask`` and ``position_ids``.

        Returns:
            torch.Tensor: the computed log-probabilities. NOTE(review): the exact shape
            (e.g. per token, restricted to the response) is implementation-defined —
            confirm against the concrete subclass.
        """
        pass

    @abstractmethod
    def update_policy(self, data: DataProto) -> Dict:
        """Update the policy using a batch of data.

        Args:
            data (DataProto): the batch to train on. NOTE(review): this is annotated as a
                single DataProto; whether it is further split into mini-batches is up to
                the implementing subclass.

        Returns:
            Dict: statistics collected while updating the model. Typically it contains
            entries such as ``loss`` and ``grad_norm``.
        """
        pass
67 |
--------------------------------------------------------------------------------
/verl/verl/workers/critic/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import BasePPOCritic
16 | from .dp_critic import DataParallelPPOCritic
17 |
18 | __all__ = ["BasePPOCritic", "DataParallelPPOCritic"]
19 |
--------------------------------------------------------------------------------
/verl/verl/workers/critic/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Base class for a critic
16 | """
17 | from abc import ABC, abstractmethod
18 |
19 | import torch
20 |
21 | from verl import DataProto
22 |
23 | __all__ = ['BasePPOCritic']
24 |
25 |
class BasePPOCritic(ABC):
    """Abstract interface for a PPO critic.

    Subclasses supply value estimation through :meth:`compute_values` and training
    through :meth:`update_critic`.
    """

    def __init__(self, config):
        """Store ``config`` on the instance.

        Args:
            config: critic configuration; kept verbatim and not inspected here.
        """
        super().__init__()
        self.config = config

    @abstractmethod
    def compute_values(self, data: DataProto) -> torch.Tensor:
        """Return the values computed for ``data``."""

    @abstractmethod
    def update_critic(self, data: DataProto):
        """Update the critic using ``data``."""
41 |
--------------------------------------------------------------------------------
/verl/verl/workers/reward_manager/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 PRIME team and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .naive import NaiveRewardManager
16 | from .prime import PrimeRewardManager
--------------------------------------------------------------------------------
/verl/verl/workers/reward_model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import BasePPORewardModel
16 |
--------------------------------------------------------------------------------
/verl/verl/workers/reward_model/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The base class for reward model
16 | """
17 |
18 | from abc import ABC, abstractmethod
19 |
20 | from verl import DataProto
21 |
22 |
class BasePPORewardModel(ABC):
    """Abstract interface for a reward model used in PPO training."""

    def __init__(self, config):
        """Store the reward-model configuration.

        Args:
            config: configuration object; kept verbatim and not inspected here.
        """
        # Cooperative init so that any other base class in a subclass's MRO is
        # initialized too (matches the actor/critic base classes).
        super().__init__()
        self.config = config

    @abstractmethod
    def compute_reward(self, data: DataProto) -> DataProto:
        """Compute rewards for a batch of sequences given their ``input_ids``.

        The underlying model is expected to output a tensor of shape
        [batch_size, sequence_length], from which the value at the [EOS] position
        is gathered.

        Args:
            data: must contain the keys "input_ids", "attention_mask" and "position_ids".
                - input_ids: [batch_size, sequence_length]
                - attention_mask: [batch_size, sequence_length]
                - position_ids: [batch_size, sequence_length]

        Returns:
            A DataProto containing the key "reward" of shape [batch_size, sequence_length].
            Only the [EOS] position carries the reward; all other positions are zero.
            This may change if dense rewards are adopted, so the interface is kept general.
        """
        pass
46 |
--------------------------------------------------------------------------------
/verl/verl/workers/reward_model/megatron/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .reward_model import MegatronRewardModel
16 |
--------------------------------------------------------------------------------
/verl/verl/workers/rollout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import BaseRollout
16 | from .naive import NaiveRollout
17 | from .hf_rollout import HFRollout
18 |
19 | __all__ = ["BaseRollout", "NaiveRollout", "HFRollout"]
20 |
--------------------------------------------------------------------------------
/verl/verl/workers/rollout/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from abc import ABC, abstractmethod
16 | from typing import Iterable, Union
17 |
18 | from verl import DataProto
19 |
20 | __all__ = ['BaseRollout']
21 |
22 |
class BaseRollout(ABC):
    """Abstract interface for a rollout that generates sequences from prompts."""

    def __init__(self):
        """Initialize the rollout.

        This base class takes no arguments and stores no state of its own.
        """
        super().__init__()

    @abstractmethod
    def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Generate sequences for ``prompts``.

        Args:
            prompts (DataProto): the batch of prompts to generate from.

        Returns:
            DataProto: the generated sequences. NOTE(review): the exact keys returned are
            implementation-defined — confirm against the concrete subclass.
        """
        pass
38 |
--------------------------------------------------------------------------------
/verl/verl/workers/rollout/naive/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .naive_rollout import NaiveRollout
16 |
--------------------------------------------------------------------------------
/verl/verl/workers/rollout/vllm_rollout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from importlib.metadata import version, PackageNotFoundError
16 |
17 |
def get_version(pkg):
    """Return the installed version string of distribution ``pkg``, or None if it is absent."""
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        installed = None
    return installed
23 |
24 |
def _release_tuple(ver):
    """Return the leading numeric release components of version string ``ver``.

    Parsing stops at the first dot-separated piece that is not purely numeric, keeping
    that piece's leading digits: ``'0.6.3.post1'`` -> ``(0, 6, 3)``,
    ``'0.6.3+cu118'`` -> ``(0, 6, 3)``, ``'0.10.1'`` -> ``(0, 10, 1)``.
    """
    import re

    release = []
    for piece in ver.split('.'):
        match = re.match(r'\d+', piece)
        if match is None:
            break
        release.append(int(match.group()))
        if match.end() != len(piece):
            # Piece carries a suffix (e.g. '3rc1', '3+cu118'); nothing numeric follows.
            break
    return tuple(release)


def _version_le(ver, bound):
    """Return True if version ``ver`` is not newer than ``bound``.

    Numeric release components are compared as integers (zero-padded to equal length),
    so ``'0.10.0'`` correctly ranks above ``'0.6.3'``; a plain string comparison gets
    this wrong. When the numeric releases are equal, the raw strings are compared,
    which keeps the previous behavior for suffixed builds such as ``'0.6.3.post1'``.
    """
    lhs, rhs = _release_tuple(ver), _release_tuple(bound)
    width = max(len(lhs), len(rhs))
    lhs += (0,) * (width - len(lhs))
    rhs += (0,) * (width - len(rhs))
    if lhs != rhs:
        return lhs < rhs
    return ver <= bound


package_name = 'vllm'
package_version = get_version(package_name)

if package_version is None:
    # get_version returns None when the package is missing; comparing None with a str
    # would raise an opaque TypeError, so fail with an explicit ImportError instead.
    raise ImportError(f"Package '{package_name}' is not installed; it is required by the vLLM rollout.")

# vLLM <= 0.6.3 uses the customized rollout; newer releases use the SPMD rollout.
if _version_le(package_version, '0.6.3'):
    vllm_mode = 'customized'
    from .vllm_rollout import vLLMRollout
    from .fire_vllm_rollout import FIREvLLMRollout
else:
    vllm_mode = 'spmd'
    from .vllm_rollout_spmd import vLLMRollout
35 |
--------------------------------------------------------------------------------
/verl/verl/workers/sharding_manager/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from verl.utils.import_utils import is_vllm_available, is_megatron_core_available
16 |
17 | from .base import BaseShardingManager
18 | from .fsdp_ulysses import FSDPUlyssesShardingManager
19 |
# Optional sharding managers: each name is bound to None when the backend(s) it
# depends on are unavailable, so importers can always reference the name.
if is_megatron_core_available() and is_vllm_available():
    from .megatron_vllm import AllGatherPPModel, MegatronVLLMShardingManager
else:
    AllGatherPPModel = None
    MegatronVLLMShardingManager = None

if is_vllm_available():
    from .fsdp_vllm import FSDPVLLMShardingManager
else:
    FSDPVLLMShardingManager = None
34 |
--------------------------------------------------------------------------------
/verl/verl/workers/sharding_manager/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Sharding manager to implement HybridEngine
16 | """
17 |
18 | from verl import DataProto
19 |
20 |
class BaseShardingManager:
    """Default, no-op sharding manager.

    Usable as a context manager whose enter/exit do nothing, and whose data hooks
    return the very object they receive.
    """

    def __enter__(self):
        # No setup; a ``with ... as x`` target is therefore bound to None, not ``self``.
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        # A None (falsy) result lets any exception raised in the ``with`` body propagate.
        return None

    def preprocess_data(self, data: DataProto) -> DataProto:
        """Identity hook: return ``data`` unchanged."""
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Identity hook: return ``data`` unchanged."""
        return data
34 |
--------------------------------------------------------------------------------