├── .gitattributes ├── LICENSE ├── LLaMA-Factory ├── .env.local ├── .gitattributes ├── LICENSE ├── MANIFEST.in ├── Makefile ├── data │ ├── dataset_info.json │ └── tool_sft.json ├── evaluation │ ├── ceval │ │ ├── ceval.py │ │ ├── ceval.zip │ │ └── mapping.json │ ├── cmmlu │ │ ├── cmmlu.py │ │ ├── cmmlu.zip │ │ └── mapping.json │ └── mmlu │ │ ├── mapping.json │ │ ├── mmlu.py │ │ └── mmlu.zip ├── examples │ ├── README.md │ ├── qwen_merge.yaml │ └── qwen_sft.yaml ├── get_raw_sft_data.py ├── lam_data_process.py ├── pyproject.toml ├── requirements.txt ├── scripts │ ├── api_example │ │ ├── test_image.py │ │ └── test_toolcall.py │ ├── convert_ckpt │ │ ├── llamafy_baichuan2.py │ │ ├── llamafy_qwen.py │ │ └── tiny_llama4.py │ ├── eval_bleu_rouge.py │ ├── llama_pro.py │ ├── loftq_init.py │ ├── pissa_init.py │ ├── qwen_omni_merge.py │ ├── stat_utils │ │ ├── cal_flops.py │ │ ├── cal_lr.py │ │ ├── cal_mfu.py │ │ ├── cal_ppl.py │ │ └── length_cdf.py │ └── vllm_infer.py ├── setup.py └── src │ ├── api.py │ ├── llamafactory.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ ├── entry_points.txt │ ├── requires.txt │ └── top_level.txt │ ├── llamafactory │ ├── __init__.py │ ├── api │ │ ├── __init__.py │ │ ├── app.py │ │ ├── chat.py │ │ ├── common.py │ │ └── protocol.py │ ├── chat │ │ ├── __init__.py │ │ ├── base_engine.py │ │ ├── chat_model.py │ │ ├── hf_engine.py │ │ ├── sglang_engine.py │ │ └── vllm_engine.py │ ├── cli.py │ ├── data │ │ ├── __init__.py │ │ ├── collator.py │ │ ├── converter.py │ │ ├── data_utils.py │ │ ├── formatter.py │ │ ├── loader.py │ │ ├── mm_plugin.py │ │ ├── parser.py │ │ ├── processor │ │ │ ├── __init__.py │ │ │ ├── feedback.py │ │ │ ├── pairwise.py │ │ │ ├── pretrain.py │ │ │ ├── processor_utils.py │ │ │ ├── supervised.py │ │ │ └── unsupervised.py │ │ ├── template.py │ │ └── tool_utils.py │ ├── eval │ │ ├── __init__.py │ │ ├── evaluator.py │ │ └── template.py │ ├── extras │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── env.py │ │ ├── 
logging.py │ │ ├── misc.py │ │ ├── packages.py │ │ └── ploting.py │ ├── hparams │ │ ├── __init__.py │ │ ├── data_args.py │ │ ├── evaluation_args.py │ │ ├── finetuning_args.py │ │ ├── generating_args.py │ │ ├── model_args.py │ │ ├── parser.py │ │ └── training_args.py │ ├── launcher.py │ ├── model │ │ ├── __init__.py │ │ ├── adapter.py │ │ ├── loader.py │ │ ├── model_utils │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── checkpointing.py │ │ │ ├── embedding.py │ │ │ ├── kv_cache.py │ │ │ ├── liger_kernel.py │ │ │ ├── longlora.py │ │ │ ├── misc.py │ │ │ ├── mod.py │ │ │ ├── moe.py │ │ │ ├── packing.py │ │ │ ├── quantization.py │ │ │ ├── rope.py │ │ │ ├── unsloth.py │ │ │ ├── valuehead.py │ │ │ └── visual.py │ │ └── patcher.py │ ├── third_party │ │ ├── __init__.py │ │ └── muon │ │ │ ├── __init__.py │ │ │ └── muon.py │ ├── train │ │ ├── __init__.py │ │ ├── callbacks.py │ │ ├── dpo │ │ │ ├── __init__.py │ │ │ ├── trainer.py │ │ │ └── workflow.py │ │ ├── kto │ │ │ ├── __init__.py │ │ │ ├── trainer.py │ │ │ └── workflow.py │ │ ├── ppo │ │ │ ├── __init__.py │ │ │ ├── ppo_utils.py │ │ │ ├── trainer.py │ │ │ └── workflow.py │ │ ├── pt │ │ │ ├── __init__.py │ │ │ ├── trainer.py │ │ │ └── workflow.py │ │ ├── rm │ │ │ ├── __init__.py │ │ │ ├── metric.py │ │ │ ├── trainer.py │ │ │ └── workflow.py │ │ ├── sft │ │ │ ├── __init__.py │ │ │ ├── metric.py │ │ │ ├── trainer.py │ │ │ └── workflow.py │ │ ├── test_utils.py │ │ ├── trainer_utils.py │ │ └── tuner.py │ └── webui │ │ ├── __init__.py │ │ ├── chatter.py │ │ ├── common.py │ │ ├── components │ │ ├── __init__.py │ │ ├── chatbot.py │ │ ├── data.py │ │ ├── eval.py │ │ ├── export.py │ │ ├── infer.py │ │ ├── top.py │ │ └── train.py │ │ ├── control.py │ │ ├── css.py │ │ ├── engine.py │ │ ├── interface.py │ │ ├── locales.py │ │ ├── manager.py │ │ └── runner.py │ ├── train.py │ └── webui.py ├── README.md ├── assets ├── data_composition.png ├── exp_main.png ├── overview.png ├── reward.png ├── scability.png ├── sft_or_rl.png └── 
thinking_template.png ├── data_process ├── distill_data.py ├── hammer_utils.py └── raw_data_process.py ├── eval └── bfcl_handler.py ├── llamafact_merge.sh ├── paper.pdf ├── qwen_rl.sh ├── qwen_sft.sh ├── verl ├── LICENSE ├── examples │ ├── agent │ │ └── qwen.sh │ └── data_preprocess │ │ └── verl_toolcall_preprocess.py ├── setup.py └── verl │ ├── __init__.py │ ├── data │ ├── test.parquet │ └── train.parquet │ ├── models │ ├── README.md │ ├── __init__.py │ ├── llama │ │ ├── __init__.py │ │ └── megatron │ │ │ ├── __init__.py │ │ │ ├── checkpoint_utils │ │ │ ├── __init__.py │ │ │ ├── llama_loader.py │ │ │ └── llama_saver.py │ │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── parallel_attention.py │ │ │ ├── parallel_decoder.py │ │ │ ├── parallel_linear.py │ │ │ ├── parallel_mlp.py │ │ │ └── parallel_rmsnorm.py │ │ │ └── modeling_llama_megatron.py │ ├── qwen2 │ │ ├── __init__.py │ │ └── megatron │ │ │ ├── __init__.py │ │ │ ├── checkpoint_utils │ │ │ ├── __init__.py │ │ │ ├── qwen2_loader.py │ │ │ └── qwen2_saver.py │ │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── parallel_attention.py │ │ │ ├── parallel_decoder.py │ │ │ ├── parallel_linear.py │ │ │ ├── parallel_mlp.py │ │ │ └── parallel_rmsnorm.py │ │ │ └── modeling_qwen2_megatron.py │ ├── registry.py │ ├── transformers │ │ ├── __init__.py │ │ ├── llama.py │ │ ├── monkey_patch.py │ │ ├── qwen2.py │ │ └── qwen2_vl.py │ └── weight_loader_registry.py │ ├── protocol.py │ ├── single_controller │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ ├── decorator.py │ │ ├── megatron │ │ │ ├── __init__.py │ │ │ ├── worker.py │ │ │ └── worker_group.py │ │ ├── register_center │ │ │ ├── __init__.py │ │ │ └── ray.py │ │ ├── worker.py │ │ └── worker_group.py │ └── ray │ │ ├── __init__.py │ │ ├── base.py │ │ └── megatron.py │ ├── third_party │ ├── __init__.py │ └── vllm │ │ ├── __init__.py │ │ ├── vllm_spmd │ │ ├── __init__.py │ │ └── dtensor_weight_loaders.py │ │ ├── vllm_v_0_3_1 │ │ ├── __init__.py │ │ ├── arg_utils.py │ │ ├── config.py 
│ │ ├── llm.py │ │ ├── llm_engine_sp.py │ │ ├── model_loader.py │ │ ├── model_runner.py │ │ ├── parallel_state.py │ │ ├── tokenizer.py │ │ ├── weight_loaders.py │ │ └── worker.py │ │ ├── vllm_v_0_4_2 │ │ ├── __init__.py │ │ ├── arg_utils.py │ │ ├── config.py │ │ ├── dtensor_weight_loaders.py │ │ ├── hf_weight_loader.py │ │ ├── llm.py │ │ ├── llm_engine_sp.py │ │ ├── megatron_weight_loaders.py │ │ ├── model_loader.py │ │ ├── model_runner.py │ │ ├── parallel_state.py │ │ ├── spmd_gpu_executor.py │ │ ├── tokenizer.py │ │ └── worker.py │ │ ├── vllm_v_0_5_4 │ │ ├── __init__.py │ │ ├── arg_utils.py │ │ ├── config.py │ │ ├── dtensor_weight_loaders.py │ │ ├── hf_weight_loader.py │ │ ├── llm.py │ │ ├── llm_engine_sp.py │ │ ├── megatron_weight_loaders.py │ │ ├── model_loader.py │ │ ├── model_runner.py │ │ ├── parallel_state.py │ │ ├── spmd_gpu_executor.py │ │ ├── tokenizer.py │ │ └── worker.py │ │ └── vllm_v_0_6_3 │ │ ├── __init__.py │ │ ├── arg_utils.py │ │ ├── config.py │ │ ├── dtensor_weight_loaders.py │ │ ├── hf_weight_loader.py │ │ ├── llm.py │ │ ├── llm_engine_sp.py │ │ ├── megatron_weight_loaders.py │ │ ├── model_loader.py │ │ ├── model_runner.py │ │ ├── parallel_state.py │ │ ├── spmd_gpu_executor.py │ │ ├── tokenizer.py │ │ └── worker.py │ ├── trainer │ ├── __init__.py │ ├── config │ │ ├── evaluation.yaml │ │ ├── generation.yaml │ │ ├── ppo_megatron_trainer.yaml │ │ ├── ppo_trainer.yaml │ │ └── sft_trainer.yaml │ ├── fsdp_sft_trainer.py │ ├── main_eval.py │ ├── main_generation.py │ ├── main_ppo.py │ ├── ppo │ │ ├── __init__.py │ │ ├── core_algos.py │ │ └── ray_trainer.py │ └── runtime_env.yaml │ ├── utils │ ├── __init__.py │ ├── checkpoint │ │ ├── __init__.py │ │ ├── checkpoint_manager.py │ │ └── fsdp_checkpoint_manager.py │ ├── config.py │ ├── dataset │ │ ├── README.md │ │ ├── __init__.py │ │ ├── rl_dataset.py │ │ ├── rm_dataset.py │ │ └── sft_dataset.py │ ├── debug │ │ ├── __init__.py │ │ ├── performance.py │ │ └── trajectory_tracker.py │ ├── distributed.py │ ├── 
flops_counter.py │ ├── fs.py │ ├── fsdp_utils.py │ ├── hdfs_io.py │ ├── import_utils.py │ ├── logger │ │ ├── __init__.py │ │ └── aggregate_logger.py │ ├── logging_utils.py │ ├── megatron │ │ ├── __init__.py │ │ ├── memory.py │ │ ├── optimizer.py │ │ ├── pipeline_parallel.py │ │ ├── sequence_parallel.py │ │ └── tensor_parallel.py │ ├── megatron_utils.py │ ├── memory_buffer.py │ ├── model.py │ ├── py_functional.py │ ├── ray_utils.py │ ├── rendezvous │ │ ├── __init__.py │ │ └── ray_backend.py │ ├── reward_score │ │ ├── __init__.py │ │ ├── geo3k.py │ │ ├── gsm8k.py │ │ ├── math.py │ │ ├── prime_code │ │ │ ├── __init__.py │ │ │ ├── testing_util.py │ │ │ └── utils.py │ │ ├── prime_math │ │ │ ├── __init__.py │ │ │ ├── grader.py │ │ │ └── math_normalize.py │ │ └── toolcall.py │ ├── seqlen_balancing.py │ ├── tokenizer.py │ ├── tool_utils.py │ ├── torch_dtypes.py │ ├── torch_functional.py │ ├── tracking.py │ └── ulysses.py │ ├── version │ └── version │ └── workers │ ├── __init__.py │ ├── actor │ ├── __init__.py │ ├── base.py │ ├── dp_actor.py │ └── megatron_actor.py │ ├── critic │ ├── __init__.py │ ├── base.py │ ├── dp_critic.py │ └── megatron_critic.py │ ├── fsdp_workers.py │ ├── megatron_workers.py │ ├── reward_manager │ ├── __init__.py │ ├── naive.py │ └── prime.py │ ├── reward_model │ ├── __init__.py │ ├── base.py │ └── megatron │ │ ├── __init__.py │ │ └── reward_model.py │ ├── rollout │ ├── __init__.py │ ├── base.py │ ├── hf_rollout.py │ ├── naive │ │ ├── __init__.py │ │ └── naive_rollout.py │ ├── tokenizer.py │ └── vllm_rollout │ │ ├── __init__.py │ │ ├── fire_vllm_rollout.py │ │ ├── vllm_rollout.py │ │ └── vllm_rollout_spmd.py │ └── sharding_manager │ ├── __init__.py │ ├── base.py │ ├── fsdp_ulysses.py │ ├── fsdp_vllm.py │ └── megatron_vllm.py └── verl_convert.py /.gitattributes: -------------------------------------------------------------------------------- 1 | LLaMA-Factory/data/tool_sft.json filter=lfs diff=lfs merge=lfs -text 2 | verl/verl/data/test.parquet 
filter=lfs diff=lfs merge=lfs -text 3 | verl/verl/data/train.parquet filter=lfs diff=lfs merge=lfs -text 4 | -------------------------------------------------------------------------------- /LLaMA-Factory/.env.local: -------------------------------------------------------------------------------- 1 | # Note: actually we do not support .env, just for reference 2 | # api 3 | API_HOST= 4 | API_PORT= 5 | API_KEY= 6 | API_MODEL_NAME= 7 | API_VERBOSE= 8 | FASTAPI_ROOT_PATH= 9 | MAX_CONCURRENT= 10 | # general 11 | DISABLE_VERSION_CHECK= 12 | FORCE_CHECK_IMPORTS= 13 | ALLOW_EXTRA_ARGS= 14 | LLAMAFACTORY_VERBOSITY= 15 | USE_MODELSCOPE_HUB= 16 | USE_OPENMIND_HUB= 17 | USE_RAY= 18 | RECORD_VRAM= 19 | OPTIM_TORCH= 20 | NPU_JIT_COMPILE= 21 | # torchrun 22 | FORCE_TORCHRUN= 23 | MASTER_ADDR= 24 | MASTER_PORT= 25 | NNODES= 26 | NODE_RANK= 27 | NPROC_PER_NODE= 28 | # wandb 29 | WANDB_DISABLED= 30 | WANDB_PROJECT= 31 | WANDB_API_KEY= 32 | # gradio ui 33 | GRADIO_SHARE= 34 | GRADIO_SERVER_NAME= 35 | GRADIO_SERVER_PORT= 36 | GRADIO_ROOT_PATH= 37 | GRADIO_IPV6= 38 | # setup 39 | ENABLE_SHORT_CONSOLE= 40 | # reserved (do not use) 41 | LLAMABOARD_ENABLED= 42 | LLAMABOARD_WORKDIR= 43 | -------------------------------------------------------------------------------- /LLaMA-Factory/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /LLaMA-Factory/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE requirements.txt 2 | -------------------------------------------------------------------------------- /LLaMA-Factory/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build commit license quality style test 2 | 3 | check_dirs := scripts src tests setup.py 4 | 5 | build: 6 | pip3 
install build && python3 -m build 7 | 8 | commit: 9 | pre-commit install 10 | pre-commit run --all-files 11 | 12 | license: 13 | python3 tests/check_license.py $(check_dirs) 14 | 15 | quality: 16 | ruff check $(check_dirs) 17 | ruff format --check $(check_dirs) 18 | 19 | style: 20 | ruff check $(check_dirs) --fix 21 | ruff format $(check_dirs) 22 | 23 | test: 24 | CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/ 25 | -------------------------------------------------------------------------------- /LLaMA-Factory/data/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "tool_sft": { 3 | "file_name": "tool_sft.json", 4 | "columns": { 5 | "prompt": "instruction", 6 | "query": "query", 7 | "response": "output", 8 | "system": "system_prompt" 9 | } 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /LLaMA-Factory/data/tool_sft.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:836c6592b03ee881ca3ba4934a25e69b14a909567f062e91da3aef6ac146c7fb 3 | size 24144869 4 | -------------------------------------------------------------------------------- /LLaMA-Factory/evaluation/ceval/ceval.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/evaluation/ceval/ceval.zip -------------------------------------------------------------------------------- /LLaMA-Factory/evaluation/cmmlu/cmmlu.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/evaluation/cmmlu/cmmlu.zip -------------------------------------------------------------------------------- /LLaMA-Factory/evaluation/mmlu/mmlu.zip: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/evaluation/mmlu/mmlu.zip -------------------------------------------------------------------------------- /LLaMA-Factory/examples/qwen_merge.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: path/to/model/Qwen2.5-7B-Instruct 3 | adapter_name_or_path: path/to/adapter/qwen-7b 4 | template: qwen 5 | finetuning_type: lora 6 | 7 | ### export 8 | export_dir: path/to/export 9 | export_size: 2 10 | export_device: cpu 11 | export_legacy_format: false -------------------------------------------------------------------------------- /LLaMA-Factory/examples/qwen_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: path/to/model/Qwen2.5-7B-Instruct 3 | trust_remote_code: true 4 | 5 | ### method 6 | stage: sft 7 | do_train: true 8 | finetuning_type: lora 9 | lora_rank: 8 10 | lora_target: all 11 | 12 | ### dataset 13 | dataset: tool_sft 14 | template: qwen 15 | cutoff_len: 2048 16 | overwrite_cache: true 17 | preprocessing_num_workers: 16 18 | dataloader_num_workers: 4 19 | 20 | ### output 21 | output_dir: saves/qwen-7b/lora/sft 22 | logging_steps: 10 23 | save_steps: 10 24 | plot_loss: true 25 | overwrite_output_dir: true 26 | save_only_model: false 27 | report_to: wandb # choices: [none, wandb, tensorboard, swanlab, mlflow] 28 | 29 | ### train 30 | per_device_train_batch_size: 1 31 | gradient_accumulation_steps: 8 32 | learning_rate: 1.0e-4 33 | num_train_epochs: 5.0 34 | lr_scheduler_type: cosine 35 | warmup_ratio: 0.1 36 | bf16: true 37 | ddp_timeout: 180000000 38 | resume_from_checkpoint: null -------------------------------------------------------------------------------- /LLaMA-Factory/get_raw_sft_data.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | 18 | import json 19 | import os 20 | 21 | def process_json(input_path, output_path): 22 | with open(input_path, 'r') as f: 23 | data = json.load(f) 24 | 25 | new_instruction = '''When you want to perform tool call, in each action step, providing a json object with function names and arguments within XML tags. i.e., [{"name": , "arguments": }, {"name": , "arguments": }, ...] 26 | A complete reply example is: [{"name": "email", "arguments": {"receiver": "Bob", "content": "I will bug banana through walmart"}}, {"name": "walmart", "arguments": {"input": "banana"}}]. 
Please make sure the type of the arguments is correct.''' 27 | 28 | for entry in data: 29 | if 'system_prompt' in entry and "In each action step, you MUST" in entry['system_prompt']: 30 | parts = entry['system_prompt'].split("In each action step, you MUST", 1) 31 | entry['system_prompt'] = parts[0] + new_instruction 32 | 33 | if 'output' in entry: 34 | think_start = entry['output'].find("") 35 | think_end = entry['output'].find("") + len("") 36 | if think_start != -1 and think_end != -1: 37 | entry['output'] = entry['output'][think_end:].lstrip("\n") 38 | 39 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 40 | with open(output_path, 'w') as f: 41 | json.dump(data, f, indent=2, ensure_ascii=False) 42 | 43 | process_json("path/to/tool_sft.json", "path/to/data/raw_tool_sft.json") 44 | -------------------------------------------------------------------------------- /LLaMA-Factory/lam_data_process.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | 18 | import json 19 | 20 | def truncate_after_first_tool_call_end(s): 21 | marker = "" 22 | index = s.find(marker) 23 | if index != -1: 24 | return s[:index + len(marker)] 25 | return s 26 | 27 | def truncate_before_expert_prompt(s): 28 | marker = "You are an expert in composing functions" 29 | index = s.find(marker) 30 | if index != -1: 31 | return s[index:] 32 | return s 33 | 34 | # Define file paths 35 | input_file = "path/to/distilled_data/sft_data.json" 36 | output_file = "path/to/data/tool_sft.json" 37 | 38 | # Load the input JSON data 39 | with open(input_file, "r", encoding="utf-8") as f: 40 | input_data = json.load(f) 41 | 42 | # Transform the data to match the Alpaca format 43 | alpaca_data = [] 44 | for item in input_data: 45 | transformed_item = { 46 | "system_prompt": truncate_before_expert_prompt(item.get("system_prompt", "")), 47 | "query": item.get("query", ""), 48 | "output": truncate_after_first_tool_call_end(item.get("output", "")), 49 | "instruction": "" 50 | } 51 | alpaca_data.append(transformed_item) 52 | 53 | # Save the transformed data to the output file 54 | with open(output_file, "w", encoding="utf-8") as f: 55 | json.dump(alpaca_data, f, ensure_ascii=False, indent=2) 56 | 57 | print(f"Data has been successfully transformed and saved to {output_file}") -------------------------------------------------------------------------------- /LLaMA-Factory/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llamafactory" 7 | dynamic = [ 8 | "version", 9 | "dependencies", 10 | "optional-dependencies", 11 | "requires-python", 12 | "scripts", 13 | "authors", 14 | "description", 15 | "readme", 16 | "license", 17 | "keywords", 18 | "classifiers" 19 | ] 20 | 21 | [tool.ruff] 22 | target-version = "py39" 23 | line-length = 119 
24 | indent-width = 4 25 | 26 | [tool.ruff.lint] 27 | ignore = [ 28 | "C408", # collection 29 | "C901", # complex 30 | "E501", # line too long 31 | "E731", # lambda function 32 | "E741", # ambiguous var name 33 | "D100", # no doc public module 34 | "D101", # no doc public class 35 | "D102", # no doc public method 36 | "D103", # no doc public function 37 | "D104", # no doc public package 38 | "D105", # no doc magic method 39 | "D107", # no doc __init__ 40 | ] 41 | extend-select = [ 42 | "C", # complexity 43 | "E", # error 44 | "F", # pyflakes 45 | "I", # isort 46 | "W", # warning 47 | "UP", # pyupgrade 48 | "D", # pydocstyle 49 | "PT009", # pytest assert 50 | "RUF022", # sort __all__ 51 | ] 52 | 53 | [tool.ruff.lint.isort] 54 | lines-after-imports = 2 55 | known-first-party = ["llamafactory"] 56 | known-third-party = [ 57 | "accelerate", 58 | "datasets", 59 | "gradio", 60 | "numpy", 61 | "peft", 62 | "torch", 63 | "transformers", 64 | "trl", 65 | ] 66 | 67 | [tool.ruff.lint.pydocstyle] 68 | convention = "google" 69 | 70 | [tool.ruff.format] 71 | quote-style = "double" 72 | indent-style = "space" 73 | docstring-code-format = true 74 | skip-magic-trailing-comma = false 75 | line-ending = "auto" 76 | 77 | [tool.uv] 78 | conflicts = [ 79 | [ 80 | { extra = "torch-npu" }, 81 | { extra = "aqlm" }, 82 | ], 83 | [ 84 | { extra = "torch-npu" }, 85 | { extra = "liger-kernel" }, 86 | ], 87 | [ 88 | { extra = "torch-npu" }, 89 | { extra = "vllm" }, 90 | ], 91 | [ 92 | { extra = "sglang" }, 93 | { extra = "minicpm_v" }, 94 | ], 95 | ] 96 | -------------------------------------------------------------------------------- /LLaMA-Factory/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0 2 | datasets>=2.16.0,<=3.5.0 3 | accelerate>=0.34.0,<=1.6.0 4 | peft>=0.14.0,<=0.15.1 5 | trl>=0.8.6,<=0.9.6 6 | tokenizers>=0.19.0,<=0.21.1 7 | gradio>=4.38.0,<=5.25.0 8 | scipy 9 | einops 
10 | sentencepiece 11 | tiktoken 12 | protobuf 13 | uvicorn 14 | fastapi 15 | sse-starlette 16 | matplotlib>=3.7.0 17 | fire 18 | omegaconf 19 | packaging 20 | pyyaml 21 | numpy<2.0.0 22 | pydantic<=2.10.6 23 | pandas>=2.0.0 24 | av 25 | librosa 26 | tyro<0.9.0 27 | -------------------------------------------------------------------------------- /LLaMA-Factory/scripts/api_example/test_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | 17 | from openai import OpenAI 18 | from transformers.utils.versions import require_version 19 | 20 | 21 | require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0") 22 | 23 | 24 | def main(): 25 | client = OpenAI( 26 | api_key="{}".format(os.getenv("API_KEY", "0")), 27 | base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)), 28 | ) 29 | messages = [] 30 | messages.append( 31 | { 32 | "role": "user", 33 | "content": [ 34 | {"type": "text", "text": "Output the color and number of each box."}, 35 | { 36 | "type": "image_url", 37 | "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"}, 38 | }, 39 | ], 40 | } 41 | ) 42 | result = client.chat.completions.create(messages=messages, model="test") 43 | messages.append(result.choices[0].message) 44 | print("Round 1:", result.choices[0].message.content) 45 | # The image shows a pyramid of colored blocks with numbers on them. Here are the colors and numbers of ... 46 | messages.append( 47 | { 48 | "role": "user", 49 | "content": [ 50 | {"type": "text", "text": "What kind of flower is this?"}, 51 | { 52 | "type": "image_url", 53 | "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"}, 54 | }, 55 | ], 56 | } 57 | ) 58 | result = client.chat.completions.create(messages=messages, model="test") 59 | messages.append(result.choices[0].message) 60 | print("Round 2:", result.choices[0].message.content) 61 | # The image shows a cluster of forget-me-not flowers. Forget-me-nots are small ... 62 | 63 | 64 | if __name__ == "__main__": 65 | main() 66 | -------------------------------------------------------------------------------- /LLaMA-Factory/scripts/convert_ckpt/tiny_llama4.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from transformers import Llama4Config, Llama4ForConditionalGeneration, Llama4TextConfig, Llama4VisionConfig 16 | 17 | 18 | if __name__ == "__main__": 19 | vision_config = Llama4VisionConfig( 20 | hidden_size=1408, 21 | image_size=336, 22 | intermediate_size=5632, 23 | num_attention_heads=16, 24 | num_hidden_layers=4, 25 | vision_output_dim=4096, 26 | ) 27 | text_config = Llama4TextConfig( 28 | hidden_size=512, 29 | intermediate_size=1024, 30 | intermediate_size_mlp=1024, 31 | num_hidden_layers=4, 32 | num_attention_heads=8, 33 | num_key_value_heads=2, 34 | head_dim=512 // 8, 35 | num_local_experts=2, 36 | ) 37 | config = Llama4Config(vision_config=vision_config, text_config=text_config) 38 | model = Llama4ForConditionalGeneration._from_config(config) 39 | model.save_pretrained("tiny-llama4") 40 | -------------------------------------------------------------------------------- /LLaMA-Factory/scripts/eval_bleu_rouge.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import logging 17 | import time 18 | 19 | import fire 20 | from datasets import load_dataset 21 | 22 | 23 | try: 24 | import jieba # type: ignore 25 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu # type: ignore 26 | from rouge_chinese import Rouge # type: ignore 27 | 28 | jieba.setLogLevel(logging.CRITICAL) 29 | jieba.initialize() 30 | except ImportError: 31 | print("Please install llamafactory with `pip install -e .[metrics]`.") 32 | raise 33 | 34 | 35 | def compute_metrics(sample): 36 | hypothesis = list(jieba.cut(sample["predict"])) 37 | reference = list(jieba.cut(sample["label"])) 38 | 39 | bleu_score = sentence_bleu( 40 | [list(sample["label"])], 41 | list(sample["predict"]), 42 | smoothing_function=SmoothingFunction().method3, 43 | ) 44 | 45 | if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: 46 | result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} 47 | else: 48 | rouge = Rouge() 49 | scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) 50 | result = scores[0] 51 | 52 | metric_result = {} 53 | for k, v in result.items(): 54 | metric_result[k] = round(v["f"] * 100, 4) 55 | 56 | metric_result["bleu-4"] = round(bleu_score * 100, 4) 57 | 58 | return metric_result 59 | 60 | 61 | def main(filename: str): 62 | start_time = time.time() 63 | dataset = load_dataset("json", data_files=filename, split="train") 64 | dataset = dataset.map(compute_metrics, num_proc=8, remove_columns=dataset.column_names) 65 
| score_dict = dataset.to_dict() 66 | 67 | average_score = {} 68 | for task, scores in sorted(score_dict.items(), key=lambda x: x[0]): 69 | print(f"{task}: {sum(scores) / len(scores):.4f}") 70 | average_score[task] = sum(scores) / len(scores) 71 | 72 | with open("predictions_score.json", "w", encoding="utf-8") as f: 73 | json.dump(average_score, f, indent=4) 74 | 75 | print(f"\nDone in {time.time() - start_time:.3f}s.\nScore file saved to predictions_score.json") 76 | 77 | 78 | if __name__ == "__main__": 79 | fire.Fire(main) 80 | -------------------------------------------------------------------------------- /LLaMA-Factory/scripts/stat_utils/cal_flops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Microsoft Corporation and the LlamaFactory team. 2 | # 3 | # This code is inspired by the Microsoft's DeepSpeed library. 4 | # https://www.deepspeed.ai/tutorials/flops-profiler/ 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import fire 19 | import torch 20 | from deepspeed.accelerator import get_accelerator # type: ignore 21 | from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore 22 | 23 | from llamafactory.chat import ChatModel 24 | 25 | 26 | def calculate_flops( 27 | model_name_or_path: str, 28 | batch_size: int = 1, 29 | seq_length: int = 512, 30 | flash_attn: str = "auto", 31 | ): 32 | r"""Calculate the flops of pre-trained models. 
33 | 34 | Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512 35 | """ 36 | with get_accelerator().device(0): 37 | chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn)) 38 | fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device) 39 | input_dict = {"input_ids": fake_input, "labels": fake_input.clone()} 40 | flops, macs, params = get_model_profile( 41 | chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True 42 | ) 43 | print("FLOPs:", flops) 44 | print("MACs:", macs) 45 | print("Params:", params) 46 | 47 | 48 | if __name__ == "__main__": 49 | fire.Fire(calculate_flops) 50 | -------------------------------------------------------------------------------- /LLaMA-Factory/scripts/stat_utils/length_cdf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from collections import defaultdict

import fire
from tqdm import tqdm

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer


def length_cdf(
    model_name_or_path: str,
    dataset: str = "alpaca_en_demo",
    dataset_dir: str = "data",
    template: str = "default",
    interval: int = 1000,
):
    r"""Calculate the distribution of the input lengths in the dataset.

    Usage: export CUDA_VISIBLE_DEVICES=0
    python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default

    The dataset is loaded through the SFT pipeline, the length of every
    ``input_ids`` entry is bucketed into ``[k * interval, (k + 1) * interval)``
    and one cumulative line (sample count and percentage) is printed per
    non-empty bucket, in ascending order.

    Args:
        model_name_or_path: Model whose tokenizer is used to tokenize the samples.
        dataset: Name of the dataset to analyse.
        dataset_dir: Directory that contains the dataset definition.
        template: Name of the chat template applied before tokenization.
        interval: Width of a length bucket, in tokens. Must be positive.

    Raises:
        ValueError: If ``interval`` is not positive.
    """
    # Fail fast: a zero interval would otherwise raise ZeroDivisionError only
    # after the dataset has been preprocessed, and a negative interval yields
    # meaningless bucket boundaries.
    if interval <= 0:
        raise ValueError(f"`interval` must be a positive integer, got {interval}.")

    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            # Very large cutoff, presumably so that samples are not truncated before being measured.
            cutoff_len=1_000_000,
            preprocessing_num_workers=16,
            output_dir="dummy_dir",
            overwrite_cache=True,
            do_train=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    # Use a separate name instead of rebinding the `template` string argument.
    template_obj = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
    trainset = get_dataset(template_obj, model_args, data_args, training_args, "sft", **tokenizer_module)[
        "train_dataset"
    ]
    total_num = len(trainset)

    # Map the lower bound of each bucket to the number of samples that fall into it.
    length_dict = defaultdict(int)
    for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"):
        length_dict[len(sample) // interval * interval] += 1

    count_accu = 0
    for length, count in sorted(length_dict.items()):
        count_accu += count
        # Derive the percentage from the integer running count so that floating-point
        # error does not accumulate and the last line reports exactly 100.00%.
        prob_accu = count_accu / total_num * 100
        print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")


if __name__ == "__main__":
    fire.Fire(length_cdf)
-------------------------------------------------------------------------------- /LLaMA-Factory/src/api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import uvicorn 18 | 19 | from llamafactory.api.app import create_app 20 | from llamafactory.chat import ChatModel 21 | 22 | 23 | def main(): 24 | chat_model = ChatModel() 25 | app = create_app(chat_model) 26 | api_host = os.getenv("API_HOST", "0.0.0.0") 27 | api_port = int(os.getenv("API_PORT", "8000")) 28 | print(f"Visit http://localhost:{api_port}/docs for API document.") 29 | uvicorn.run(app, host=api_host, port=api_port) 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory.egg-info/entry_points.txt: -------------------------------------------------------------------------------- 1 | [console_scripts] 2 | llamafactory-cli = llamafactory.cli:main 3 | lmf = llamafactory.cli:main 4 | -------------------------------------------------------------------------------- 
/LLaMA-Factory/src/llamafactory.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | transformers!=4.46.*,!=4.47.*,!=4.48.0,<=4.51.3,>=4.45.0 2 | datasets<=3.5.0,>=2.16.0 3 | accelerate<=1.6.0,>=0.34.0 4 | peft<=0.15.1,>=0.14.0 5 | trl<=0.9.6,>=0.8.6 6 | tokenizers<=0.21.1,>=0.19.0 7 | gradio<=5.25.0,>=4.38.0 8 | scipy 9 | einops 10 | sentencepiece 11 | tiktoken 12 | protobuf 13 | uvicorn 14 | fastapi 15 | sse-starlette 16 | matplotlib>=3.7.0 17 | fire 18 | omegaconf 19 | packaging 20 | pyyaml 21 | numpy<2.0.0 22 | pydantic<=2.10.6 23 | pandas>=2.0.0 24 | av 25 | librosa 26 | tyro<0.9.0 27 | 28 | [adam-mini] 29 | adam-mini 30 | 31 | [apollo] 32 | apollo-torch 33 | 34 | [aqlm] 35 | aqlm[gpu]>=1.1.0 36 | 37 | [awq] 38 | autoawq 39 | 40 | [badam] 41 | badam>=1.2.1 42 | 43 | [bitsandbytes] 44 | bitsandbytes>=0.39.0 45 | 46 | [deepspeed] 47 | deepspeed<=0.16.5,>=0.10.0 48 | 49 | [dev] 50 | pre-commit 51 | ruff 52 | pytest 53 | build 54 | 55 | [eetq] 56 | eetq 57 | 58 | [galore] 59 | galore-torch 60 | 61 | [gptq] 62 | optimum>=1.17.0 63 | auto-gptq>=0.5.0 64 | 65 | [hqq] 66 | hqq 67 | 68 | [liger-kernel] 69 | liger-kernel>=0.5.5 70 | 71 | [metrics] 72 | nltk 73 | jieba 74 | rouge-chinese 75 | 76 | [minicpm_v] 77 | soundfile 78 | torchvision 79 | torchaudio 80 | vector_quantize_pytorch 81 | vocos 82 | msgpack 83 | referencing 84 | jsonschema_specifications 85 | transformers==4.48.3 86 | 87 | [modelscope] 88 | modelscope 89 | 90 | [openmind] 91 | openmind 92 | 93 | [qwen] 94 | transformers_stream_generator 95 | 96 | [sglang] 97 | sglang[srt]>=0.4.5 98 | transformers==4.51.1 99 | 100 | [swanlab] 101 | swanlab 102 | 103 | [torch] 104 | torch>=1.13.1 105 | 106 | [torch-npu] 107 | torch==2.4.0 108 | torch-npu==2.4.0.post2 109 | decorator 110 | 111 | [vllm] 112 | vllm<=0.8.4,>=0.4.3 113 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory.egg-info/top_level.txt: 
-------------------------------------------------------------------------------- 1 | llamafactory 2 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Efficient fine-tuning of large language models. 
16 | 17 | Level: 18 | api, webui > chat, eval, train > data, model > hparams > extras 19 | 20 | Disable version checking: DISABLE_VERSION_CHECK=1 21 | Enable VRAM recording: RECORD_VRAM=1 22 | Force using torchrun: FORCE_TORCHRUN=1 23 | Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN 24 | Use modelscope: USE_MODELSCOPE_HUB=1 25 | Use openmind: USE_OPENMIND_HUB=1 26 | """ 27 | 28 | from .extras.env import VERSION 29 | 30 | 31 | __version__ = VERSION 32 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/api/__init__.py -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/api/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import json 16 | from typing import TYPE_CHECKING, Any 17 | 18 | 19 | if TYPE_CHECKING: 20 | from pydantic import BaseModel 21 | 22 | 23 | def dictify(data: "BaseModel") -> dict[str, Any]: 24 | try: # pydantic v2 25 | return data.model_dump(exclude_unset=True) 26 | except AttributeError: # pydantic v1 27 | return data.dict(exclude_unset=True) 28 | 29 | 30 | def jsonify(data: "BaseModel") -> str: 31 | try: # pydantic v2 32 | return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) 33 | except AttributeError: # pydantic v1 34 | return data.json(exclude_unset=True, ensure_ascii=False) 35 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/chat/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .base_engine import BaseEngine 16 | from .chat_model import ChatModel 17 | 18 | 19 | __all__ = ["BaseEngine", "ChatModel"] 20 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .collator import ( 16 | KTODataCollatorWithPadding, 17 | MultiModalDataCollatorForSeq2Seq, 18 | PairwiseDataCollatorWithPadding, 19 | SFTDataCollatorWith4DAttentionMask, 20 | ) 21 | from .data_utils import Role, split_dataset 22 | from .loader import get_dataset 23 | from .template import TEMPLATES, Template, get_template_and_fix_tokenizer 24 | 25 | 26 | __all__ = [ 27 | "TEMPLATES", 28 | "KTODataCollatorWithPadding", 29 | "MultiModalDataCollatorForSeq2Seq", 30 | "PairwiseDataCollatorWithPadding", 31 | "Role", 32 | "SFTDataCollatorWith4DAttentionMask", 33 | "Template", 34 | "get_dataset", 35 | "get_template_and_fix_tokenizer", 36 | "split_dataset", 37 | ] 38 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/data/processor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .feedback import FeedbackDatasetProcessor 16 | from .pairwise import PairwiseDatasetProcessor 17 | from .pretrain import PretrainDatasetProcessor 18 | from .processor_utils import DatasetProcessor 19 | from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor 20 | from .unsupervised import UnsupervisedDatasetProcessor 21 | 22 | 23 | __all__ = [ 24 | "DatasetProcessor", 25 | "FeedbackDatasetProcessor", 26 | "PackedSupervisedDatasetProcessor", 27 | "PairwiseDatasetProcessor", 28 | "PretrainDatasetProcessor", 29 | "SupervisedDatasetProcessor", 30 | "UnsupervisedDatasetProcessor", 31 | ] 32 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/data/processor/pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's transformers library. 4 | # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | from dataclasses import dataclass 19 | from itertools import chain 20 | from typing import Any 21 | 22 | from .processor_utils import DatasetProcessor 23 | 24 | 25 | @dataclass 26 | class PretrainDatasetProcessor(DatasetProcessor): 27 | def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: 28 | # build grouped texts with format `X1 X2 X3 ...` if packing is enabled 29 | eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token 30 | text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]] 31 | 32 | if not self.data_args.packing: 33 | if getattr(self.tokenizer, "add_bos_token", False): 34 | text_examples = [self.tokenizer.bos_token + example for example in text_examples] 35 | 36 | result = self.tokenizer( 37 | text_examples, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len 38 | ) 39 | else: 40 | tokenized_examples = self.tokenizer(text_examples, add_special_tokens=False) 41 | concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} 42 | total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) 43 | block_size = self.data_args.cutoff_len 44 | total_length = (total_length // block_size) * block_size 45 | result = { 46 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 47 | for k, t in concatenated_examples.items() 48 | } 49 | if getattr(self.tokenizer, "add_bos_token", False): 50 | for i in 
range(len(result["input_ids"])): 51 | result["input_ids"][i][0] = self.tokenizer.bos_token_id 52 | 53 | return result 54 | 55 | def print_data_example(self, example: dict[str, list[int]]) -> None: 56 | print("input_ids:\n{}".format(example["input_ids"])) 57 | print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) 58 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/eval/__init__.py -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/eval/template.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | 17 | from ..data import Role 18 | from ..extras.constants import CHOICES 19 | 20 | 21 | @dataclass 22 | class EvalTemplate: 23 | system: str 24 | choice: str 25 | answer: str 26 | 27 | def _parse_example(self, example: dict[str, str]) -> tuple[str, str]: 28 | r"""Parse eval example. 
29 | 30 | input: a dict with keys {"question", "A", "B", "C", "D", "answer"} 31 | output: a tuple of (prompt, response). 32 | """ 33 | candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example] 34 | return "".join([example["question"]] + candidates + [self.answer]), example["answer"] 35 | 36 | def format_example( 37 | self, target_data: dict[str, str], support_set: list[dict[str, str]], subject_name: str 38 | ) -> list[dict[str, str]]: 39 | r"""Convert dataset examples to messages.""" 40 | messages = [] 41 | for k in range(len(support_set)): 42 | prompt, response = self._parse_example(support_set[k]) 43 | messages.append({"role": Role.USER.value, "content": prompt}) 44 | messages.append({"role": Role.ASSISTANT.value, "content": response}) 45 | 46 | prompt, response = self._parse_example(target_data) 47 | messages.append({"role": Role.USER.value, "content": prompt}) 48 | messages.append({"role": Role.ASSISTANT.value, "content": response}) 49 | messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"] 50 | return messages 51 | 52 | 53 | eval_templates: dict[str, "EvalTemplate"] = {} 54 | 55 | 56 | def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None: 57 | eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer) 58 | 59 | 60 | def get_eval_template(name: str) -> "EvalTemplate": 61 | eval_template = eval_templates.get(name, None) 62 | assert eval_template is not None, f"Template {name} does not exist." 63 | return eval_template 64 | 65 | 66 | _register_eval_template( 67 | name="en", 68 | system="The following are multiple choice questions (with answers) about {subject}.\n\n", 69 | choice="\n{choice}. {content}", 70 | answer="\nAnswer:", 71 | ) 72 | 73 | 74 | _register_eval_template( 75 | name="zh", 76 | system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", 77 | choice="\n{choice}. 
{content}", 78 | answer="\n答案:", 79 | ) 80 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/extras/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/extras/__init__.py -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/extras/env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's transformers library. 4 | # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
import platform

import accelerate
import datasets
import peft
import torch
import transformers
import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available


VERSION = "0.9.3.dev0"


def print_env() -> None:
    r"""Print a bullet list that describes the current runtime environment.

    The report always contains the llamafactory, platform, Python, PyTorch,
    Transformers, Datasets, Accelerate, PEFT and TRL versions. GPU and NPU
    details are appended when the matching backend is reported as available.
    The DeepSpeed, bitsandbytes and vLLM versions and the current git commit
    are probed on a best-effort basis: an entry is simply omitted when its
    probe raises.
    """
    report = {
        "`llamafactory` version": VERSION,
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "PyTorch version": torch.__version__,
        "Transformers version": transformers.__version__,
        "Datasets version": datasets.__version__,
        "Accelerate version": accelerate.__version__,
        "PEFT version": peft.__version__,
        "TRL version": trl.__version__,
    }

    if is_torch_cuda_available():
        report["PyTorch version"] += " (GPU)"
        report.update(
            {
                "GPU type": torch.cuda.get_device_name(),
                "GPU number": torch.cuda.device_count(),
                # Index 1 of mem_get_info() is divided by 1024**3 to report GB.
                "GPU memory": f"{torch.cuda.mem_get_info()[1] / (1024**3):.2f}GB",
            }
        )

    if is_torch_npu_available():
        report["PyTorch version"] += " (NPU)"
        report.update(
            {
                "NPU type": torch.npu.get_device_name(),
                "CANN version": torch.version.cann,
            }
        )

    # Optional packages: any failure while importing or reading the version
    # drops the entry instead of aborting the report.
    try:
        import deepspeed  # type: ignore

        report["DeepSpeed version"] = deepspeed.__version__
    except Exception:
        pass

    try:
        import bitsandbytes  # type: ignore

        report["Bitsandbytes version"] = bitsandbytes.__version__
    except Exception:
        pass

    try:
        import vllm

        report["vLLM version"] = vllm.__version__
    except Exception:
        pass

    # The commit hash is only reported when `git rev-parse HEAD` succeeds.
    try:
        import subprocess

        report["Git commit"] = subprocess.run(
            ["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True
        ).stdout.strip()
    except Exception:
        pass

    bullet_lines = "\n".join(f"- {name}: {detail}" for name, detail in report.items())
    print("\n" + bullet_lines + "\n")
-------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/extras/packages.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 HuggingFace Inc. and the LlamaFactory team. 2 | # 3 | # This code is inspired by the HuggingFace's transformers library. 4 | # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import importlib.metadata 19 | import importlib.util 20 | from functools import lru_cache 21 | from typing import TYPE_CHECKING 22 | 23 | from packaging import version 24 | 25 | 26 | if TYPE_CHECKING: 27 | from packaging.version import Version 28 | 29 | 30 | def _is_package_available(name: str) -> bool: 31 | return importlib.util.find_spec(name) is not None 32 | 33 | 34 | def _get_package_version(name: str) -> "Version": 35 | try: 36 | return version.parse(importlib.metadata.version(name)) 37 | except Exception: 38 | return version.parse("0.0.0") 39 | 40 | 41 | def is_pyav_available(): 42 | return _is_package_available("av") 43 | 44 | 45 | def is_librosa_available(): 46 | return _is_package_available("librosa") 47 | 48 | 49 | def is_fastapi_available(): 50 | return _is_package_available("fastapi") 51 | 52 | 53 | def is_galore_available(): 54 | return _is_package_available("galore_torch") 55 | 56 | 57 | def is_apollo_available(): 58 | return _is_package_available("apollo_torch") 59 | 60 | 61 | def is_gradio_available(): 62 | return _is_package_available("gradio") 63 | 64 | 65 | def is_matplotlib_available(): 66 | return _is_package_available("matplotlib") 67 | 68 | 69 | def is_pillow_available(): 70 | return _is_package_available("PIL") 71 | 72 | 73 | def is_ray_available(): 74 | return _is_package_available("ray") 75 | 76 | 77 | def is_requests_available(): 78 | return _is_package_available("requests") 79 | 80 | 81 | def is_rouge_available(): 82 | return _is_package_available("rouge_chinese") 83 | 84 | 85 | def is_starlette_available(): 86 | return _is_package_available("sse_starlette") 87 | 88 | 89 | @lru_cache 90 | def is_transformers_version_greater_than(content: str): 91 | return _get_package_version("transformers") >= version.parse(content) 92 | 93 | 94 | def is_uvicorn_available(): 95 | return _is_package_available("uvicorn") 96 | 97 | 98 | def is_vllm_available(): 99 | return _is_package_available("vllm") 100 | 101 | 102 | def 
is_sglang_available(): 103 | return _is_package_available("sglang") 104 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/hparams/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .data_args import DataArguments 16 | from .evaluation_args import EvaluationArguments 17 | from .finetuning_args import FinetuningArguments 18 | from .generating_args import GeneratingArguments 19 | from .model_args import ModelArguments 20 | from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args 21 | from .training_args import RayArguments, TrainingArguments 22 | 23 | 24 | __all__ = [ 25 | "DataArguments", 26 | "EvaluationArguments", 27 | "FinetuningArguments", 28 | "GeneratingArguments", 29 | "ModelArguments", 30 | "RayArguments", 31 | "TrainingArguments", 32 | "get_eval_args", 33 | "get_infer_args", 34 | "get_ray_args", 35 | "get_train_args", 36 | "read_args", 37 | ] 38 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/hparams/evaluation_args.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from dataclasses import dataclass, field 17 | from typing import Literal, Optional 18 | 19 | from datasets import DownloadMode 20 | 21 | 22 | @dataclass 23 | class EvaluationArguments: 24 | r"""Arguments pertaining to specify the evaluation parameters.""" 25 | 26 | task: str = field( 27 | metadata={"help": "Name of the evaluation task."}, 28 | ) 29 | task_dir: str = field( 30 | default="evaluation", 31 | metadata={"help": "Path to the folder containing the evaluation datasets."}, 32 | ) 33 | batch_size: int = field( 34 | default=4, 35 | metadata={"help": "The batch size per GPU for evaluation."}, 36 | ) 37 | seed: int = field( 38 | default=42, 39 | metadata={"help": "Random seed to be used with data loaders."}, 40 | ) 41 | lang: Literal["en", "zh"] = field( 42 | default="en", 43 | metadata={"help": "Language used at evaluation."}, 44 | ) 45 | n_shot: int = field( 46 | default=5, 47 | metadata={"help": "Number of examplars for few-shot learning."}, 48 | ) 49 | save_dir: Optional[str] = field( 50 | default=None, 51 | metadata={"help": "Path to save the evaluation results."}, 52 | ) 53 | download_mode: DownloadMode = field( 54 | default=DownloadMode.REUSE_DATASET_IF_EXISTS, 55 | metadata={"help": "Download mode used for the evaluation datasets."}, 56 | ) 57 | 58 | def __post_init__(self): 59 | if self.save_dir is not None and os.path.exists(self.save_dir): 60 | 
raise ValueError("`save_dir` already exists, use another one.") 61 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from llamafactory.train.tuner import run_exp # use absolute import 16 | 17 | 18 | def launch(): 19 | run_exp() 20 | 21 | 22 | if __name__ == "__main__": 23 | launch() 24 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .loader import load_config, load_model, load_tokenizer 16 | from .model_utils.misc import find_all_linear_modules 17 | from .model_utils.quantization import QuantizationMethod 18 | from .model_utils.valuehead import load_valuehead_params 19 | 20 | 21 | __all__ = [ 22 | "QuantizationMethod", 23 | "find_all_linear_modules", 24 | "load_config", 25 | "load_model", 26 | "load_tokenizer", 27 | "load_valuehead_params", 28 | ] 29 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/model/model_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/model/model_utils/__init__.py -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/model/model_utils/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
    r"""Overwrite the last ``num_new_tokens`` rows with the mean of the other rows plus noise.

    The tensor is modified in place. The noise is sampled from a zero-mean normal
    distribution with standard deviation ``1 / sqrt(embedding_dim)``, where
    ``embedding_dim`` is the size of the second dimension of ``embed_weight``.
    """
    hidden_size = embed_weight.size(1)
    mean_row = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
    noise = torch.empty_like(embed_weight[-num_new_tokens:]).normal_(mean=0, std=(1.0 / math.sqrt(hidden_size)))
    embed_weight[-num_new_tokens:] = mean_row + noise


def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
    r"""Grow the token embeddings when the tokenizer has more entries than the embedding matrix.

    Nothing happens when ``len(tokenizer)`` does not exceed the current number of
    input-embedding rows. Otherwise the embeddings are resized (padded to a multiple
    of 64) and the added rows of the input and output embeddings are initialised
    with ``_noisy_mean_initialization``.

    Raises:
        ValueError: if the model has a truthy ``quantization_method`` attribute, or
            if its output embedding is not a ``torch.nn.Linear``.
    """
    context_maybe_zero3 = nullcontext()
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore

        params = [model.get_input_embeddings().weight]
        output_embeddings = model.get_output_embeddings()
        if output_embeddings is not None and not model.config.tie_word_embeddings:
            params.append(output_embeddings.weight)

        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)

    with context_maybe_zero3:
        current_embedding_size = model.get_input_embeddings().weight.size(0)

    if len(tokenizer) <= current_embedding_size:
        return  # the embedding matrix already has a row for every token

    if getattr(model, "quantization_method", None):
        raise ValueError("Cannot resize embedding layers of a quantized model.")

    if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
        raise ValueError("Current model does not support resizing embedding layers.")

    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
    with context_maybe_zero3:
        new_embedding_size = model.get_input_embeddings().weight.size(0)
        num_new_tokens = new_embedding_size - current_embedding_size
        # input embeddings first, then output embeddings
        for get_embeddings in (model.get_input_embeddings, model.get_output_embeddings):
            _noisy_mean_initialization(get_embeddings().weight.data, num_new_tokens)

    logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    r"""Set ``use_cache`` on ``config`` and, when present, on ``config.text_config``.

    When ``is_trainable`` is true the cache is always switched off; otherwise the
    value of ``model_args.use_cache`` is applied. The outcome is logged.
    """
    use_cache = False if is_trainable else model_args.use_cache

    config.use_cache = use_cache
    if hasattr(config, "text_config"):
        config.text_config.use_cache = use_cache

    if is_trainable:
        logger.info_rank0("KV cache is disabled during training.")
    elif use_cache:
        logger.info_rank0("KV cache is enabled for faster generation.")
    else:
        logger.info_rank0("KV cache is disabled.")
def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
    r"""Load a model with ``AutoMoDModelForCausalLM.from_pretrained``.

    All keyword arguments are forwarded unchanged. ``MoD`` is imported lazily, so
    the package is only needed when this function is called.
    """
    from MoD import AutoMoDModelForCausalLM

    mod_model = AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
    return mod_model


def convert_pretrained_model_to_mod(
    model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
    r"""Apply ``apply_mod_to_hf`` to ``model`` and cast the result to ``model_args.compute_dtype``.

    Raises:
        ValueError: if ``config.model_type`` is not listed in ``MOD_SUPPORTED_MODELS``.
    """
    from MoD import apply_mod_to_hf

    model_type = getattr(config, "model_type", None)
    if model_type not in MOD_SUPPORTED_MODELS:
        raise ValueError("Current model is not supported by mixture-of-depth.")

    return apply_mod_to_hf(model).to(model_args.compute_dtype)
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    r"""Write a ``rope_scaling`` dict onto ``config`` according to ``model_args``.

    Does nothing when ``model_args.rope_scaling`` is ``None`` or when ``config`` has
    no ``rope_scaling`` attribute. Without ``model_args.model_max_length`` the factor
    is fixed at 2.0. With it, ``config.max_position_embeddings`` is raised to the
    requested length and the factor is the rounded-up ratio between the requested
    and the current length; if the requested length does not exceed the current
    one (or the current one is missing), scaling is skipped with a warning.
    """
    requested_scaling = model_args.rope_scaling
    if requested_scaling is None:
        return

    if not hasattr(config, "rope_scaling"):
        logger.warning_rank0("Current model does not support RoPE scaling.")
        return

    # enum members carry the strategy name in `.value`; other values are used as-is
    rope_kwargs = {"rope_type": getattr(requested_scaling, "value", requested_scaling)}
    target_length = model_args.model_max_length
    if target_length is None:
        rope_kwargs["factor"] = 2.0
    else:
        is_dynamic = requested_scaling == RopeScaling.DYNAMIC
        if is_trainable and is_dynamic:
            logger.warning_rank0(
                "Dynamic NTK scaling may not work well with fine-tuning. "
                "See: https://github.com/huggingface/transformers/pull/24653"
            )

        original_length = getattr(config, "max_position_embeddings", None)
        if (not original_length) or target_length <= original_length:
            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
            return

        logger.info_rank0(f"Enlarge max model length from {original_length} to {target_length}.")
        config.max_position_embeddings = target_length
        rope_kwargs["factor"] = float(math.ceil(target_length / original_length))
        if is_dynamic:
            rope_kwargs["original_max_position_embeddings"] = original_length
        elif requested_scaling == RopeScaling.LLAMA3:
            rope_kwargs.update(
                original_max_position_embeddings=original_length,
                low_freq_factor=1.0,
                high_freq_factor=4.0,
            )

    config.rope_scaling = rope_kwargs
    logger.info_rank0(
        f"Using {rope_kwargs['rope_type']} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
    )
def load_valuehead_params(
    path_or_repo_id: str, model_args: "ModelArguments"
) -> "dict[str, torch.Tensor] | None":
    r"""Load value head parameters from Hugging Face Hub or local disk.

    The safetensors file is tried first; if that fails for any reason (including a
    missing ``safetensors`` package), the PyTorch checkpoint is tried next.

    Args:
        path_or_repo_id: local directory or Hub repo id to resolve the weight file from.
        model_args: supplies ``cache_dir`` and ``hf_hub_token`` for the file lookup.

    Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`,
    or ``None`` when neither weight file could be loaded.
    """
    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
    err_text = ""

    try:
        from safetensors import safe_open

        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
        with safe_open(vhead_file, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    except Exception as err:  # best effort: fall through to the PyTorch checkpoint
        err_text = str(err)

    try:
        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
        # The file may come from an arbitrary path or Hub repo, so restrict
        # unpickling to tensors and plain containers instead of arbitrary objects.
        return torch.load(vhead_file, map_location="cpu", weights_only=True)
    except Exception as err:  # best effort: a missing value head is not fatal here
        err_text = str(err)

    logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
    logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.")
    return None


def prepare_valuehead_model(model: "PreTrainedModel") -> None:
    r"""Attach an ``lm_head`` attribute to models whose output layer lives under another name.

    For ``llava``, ``chatglm`` and ``internlm2`` model types the output layer is
    assigned to ``model.lm_head`` and ``_keys_to_ignore_on_save`` is set to
    ``["lm_head.weight"]``. Other model types are left untouched.
    """
    if getattr(model.config, "model_type", None) == "llava":
        setattr(model, "lm_head", model.language_model.get_output_embeddings())
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

    if getattr(model.config, "model_type", None) == "chatglm":
        setattr(model, "lm_head", model.transformer.output_layer)
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

    if getattr(model.config, "model_type", None) == "internlm2":
        setattr(model, "lm_head", model.output)
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
14 | 15 | from .muon import Muon 16 | 17 | 18 | __all__ = ["Muon"] 19 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/train/__init__.py -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/train/dpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_dpo 16 | 17 | 18 | __all__ = ["run_dpo"] 19 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/train/kto/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_kto 16 | 17 | 18 | __all__ = ["run_kto"] 19 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/train/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_ppo 16 | 17 | 18 | __all__ = ["run_ppo"] 19 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/train/pt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_pt 16 | 17 | 18 | __all__ = ["run_pt"] 19 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/train/pt/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class CustomTrainer(Trainer):
    r"""Inherit Trainer for custom optimizer."""

    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
        r"""Build the trainer and register the optional callbacks.

        Args:
            finetuning_args: fine-tuning options; ``use_badam`` and
                ``disable_shuffling`` are read by this class.
            processor: when not ``None``, a ``SaveProcessorCallback`` is added for it.
            **kwargs: forwarded to ``Trainer.__init__``; on transformers newer than
                4.46 the ``tokenizer`` entry is passed as ``processing_class``.
        """
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        r"""Install the custom optimizer if none is set yet, then defer to the base class."""
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        r"""Call ``create_custom_scheduler`` first, then defer to the base class."""
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        r"""Return a sequential sampler when shuffling is disabled, else the base sampler.

        Positional and keyword arguments are accepted and forwarded so the override
        keeps working when the base-class hook is invoked with parameters (for
        example a dataset) as well as when it is invoked with none.
        """
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)

    @override
    def compute_loss(self, model, inputs, *args, **kwargs):
        r"""Delegate to ``Trainer.compute_loss`` unchanged."""
        return super().compute_loss(model, inputs, *args, **kwargs)
@dataclass
class ComputeAccuracy:
    r"""Accumulate pairwise reward accuracy; usable with `batch_eval_metrics`."""

    def _dump(self) -> Optional[dict[str, float]]:
        r"""Return the mean of each accumulated score list and reset the accumulator.

        Returns ``None`` when no accumulator exists yet, which is the case for the
        call made from ``__post_init__``.
        """
        averaged = (
            {name: float(np.mean(values)) for name, values in self.score_dict.items()}
            if hasattr(self, "score_dict")
            else None
        )
        self.score_dict = {"accuracy": []}
        return averaged

    def __post_init__(self):
        r"""Create the empty accumulator."""
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
        r"""Record whether each chosen score beats its rejected score.

        ``eval_preds.predictions[0]`` holds the chosen scores and
        ``eval_preds.predictions[1]`` the rejected scores. When ``compute_result``
        is true the averaged metrics are returned and the accumulator is reset;
        otherwise ``None`` is returned.
        """
        chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
        hits = self.score_dict["accuracy"]
        if chosen_scores.shape:
            for idx, chosen in enumerate(chosen_scores):
                hits.append(chosen > rejected_scores[idx])
        else:
            # zero-dimensional input: a single comparison
            hits.append(chosen_scores > rejected_scores)

        if compute_result:
            return self._dump()

        return None
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .workflow import run_sft 16 | 17 | 18 | __all__ = ["run_sft"] 19 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/webui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Tool-N1/f9b48492c7894dcdfeb0431def8bda24249a4e9c/LLaMA-Factory/src/llamafactory/webui/__init__.py -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/webui/components/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .chatbot import create_chat_box 16 | from .eval import create_eval_tab 17 | from .export import create_export_tab 18 | from .infer import create_infer_tab 19 | from .top import create_top 20 | from .train import create_train_tab 21 | 22 | 23 | __all__ = [ 24 | "create_chat_box", 25 | "create_eval_tab", 26 | "create_export_tab", 27 | "create_infer_tab", 28 | "create_top", 29 | "create_train_tab", 30 | ] 31 | -------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/webui/components/infer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
    r"""Build the inference tab and wire its events.

    Creates the backend and dtype dropdowns, the load/unload buttons, an info
    textbox and a chat box that starts hidden. Returns a dict mapping element
    names to the created components, including those returned by
    ``create_chat_box``.
    """
    input_elems = engine.manager.get_base_elems()

    with gr.Row():
        infer_backend = gr.Dropdown(choices=["huggingface", "vllm", "sglang"], value="huggingface")
        infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")

    with gr.Row():
        load_btn = gr.Button()
        unload_btn = gr.Button()

    info_box = gr.Textbox(show_label=False, interactive=False)

    input_elems.update({infer_backend, infer_dtype})
    elem_dict = {
        "infer_backend": infer_backend,
        "infer_dtype": infer_dtype,
        "load_btn": load_btn,
        "unload_btn": unload_btn,
        "info_box": info_box,
    }

    chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
    elem_dict.update(chat_elems)
    chat_box = chat_elems["chat_box"]

    # after loading, show the chat box only if the chatter reports a loaded model
    load_event = load_btn.click(engine.chatter.load_model, input_elems, [info_box])
    load_event.then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])

    # after unloading, clear the conversation and then refresh the chat box visibility
    unload_event = unload_btn.click(engine.chatter.unload_model, input_elems, [info_box])
    cleared_event = unload_event.then(lambda: ([], []), outputs=[chatbot, messages])
    cleared_event.then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])

    model_name_elem = engine.manager.get_elem_by_id("top.model_name")
    model_name_elem.change(
        lambda model_name: gr.Column(visible=is_multimodal(model_name)),
        [model_name_elem],
        [chat_elems["mm_box"]],
    )

    return elem_dict
-------------------------------------------------------------------------------- /LLaMA-Factory/src/llamafactory/webui/css.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | CSS = r""" 16 | .duplicate-button { 17 | margin: auto !important; 18 | color: white !important; 19 | background: black !important; 20 | border-radius: 100vh !important; 21 | } 22 | 23 | .thinking-summary { 24 | padding: 8px !important; 25 | } 26 | 27 | .thinking-summary span { 28 | border-radius: 4px !important; 29 | padding: 4px !important; 30 | cursor: pointer !important; 31 | font-size: 14px !important; 32 | background: rgb(245, 245, 245) !important; 33 | } 34 | 35 | .dark .thinking-summary span { 36 | background: rgb(73, 73, 73) !important; 37 | } 38 | 39 | .thinking-container { 40 | border-left: 2px solid #a6a6a6 !important; 41 | padding-left: 10px !important; 42 | margin: 4px 0 !important; 43 | } 44 | 45 | .thinking-container p { 46 | color: #a6a6a6 !important; 47 | } 48 | 49 | .modal-box { 50 | position: fixed !important; 51 | top: 50%; 52 | left: 50%; 53 | transform: translate(-50%, -50%); /* center horizontally */ 54 | max-width: 1000px; 55 | max-height: 750px; 56 | overflow-y: auto; 57 | background-color: var(--input-background-fill); 58 | flex-wrap: nowrap !important; 59 | border: 2px solid black !important; 60 | 
class Manager:
    r"""Registry of the gradio components used in the Web UI, addressable by id."""

    def __init__(self) -> None:
        r"""Start with empty id-to-element and element-to-id mappings."""
        self._id_to_elem: "dict[str, Component]" = {}
        self._elem_to_id: "dict[Component, str]" = {}

    def add_elems(self, tab_name: str, elem_dict: dict[str, "Component"]) -> None:
        r"""Register every element of ``elem_dict`` under the id ``{tab_name}.{elem_name}``."""
        for short_name, component in elem_dict.items():
            full_id = f"{tab_name}.{short_name}"
            self._elem_to_id[component] = full_id
            self._id_to_elem[full_id] = component

    def get_elem_list(self) -> list["Component"]:
        r"""Return all registered elements as a list, in registration order."""
        return [*self._id_to_elem.values()]

    def get_elem_iter(self) -> Generator[tuple[str, "Component"], None, None]:
        r"""Yield ``(name, element)`` pairs, where name is the part of the id after the last dot."""
        for full_id, component in self._id_to_elem.items():
            yield full_id.rpartition(".")[2], component

    def get_elem_by_id(self, elem_id: str) -> "Component":
        r"""Look up an element by its id.

        Example: top.lang, train.dataset
        """
        return self._id_to_elem[elem_id]

    def get_id_by_elem(self, elem: "Component") -> str:
        r"""Look up the id an element was registered under."""
        return self._elem_to_id[elem]

    def get_base_elems(self) -> set["Component"]:
        r"""Return the set of commonly used elements of the ``top`` tab."""
        base_names = (
            "lang",
            "model_name",
            "model_path",
            "finetuning_type",
            "checkpoint_path",
            "quantization_bit",
            "quantization_method",
            "template",
            "rope_scaling",
            "booster",
        )
        return {self._id_to_elem[f"top.{name}"] for name in base_names}
def main():
    """Run a training experiment by delegating to ``run_exp``."""
    run_exp()


def _mp_fn(index):
    """Entry point for xla_spawn (TPUs); ``index`` is accepted but not used here."""
    # For xla_spawn (TPUs)
    run_exp()
def main():
    """Launch the Web UI using settings taken from environment variables.

    ``GRADIO_IPV6`` and ``GRADIO_SHARE`` are read through ``is_env_enabled``.
    ``GRADIO_SERVER_NAME`` overrides the bind address, which otherwise falls
    back to ``"[::]"`` when IPv6 is enabled and ``"0.0.0.0"`` when it is not.
    """
    ipv6_on = is_env_enabled("GRADIO_IPV6")
    share_on = is_env_enabled("GRADIO_SHARE")

    # Pick the fallback bind address before consulting the environment.
    if ipv6_on:
        fallback_host = "[::]"
    else:
        fallback_host = "0.0.0.0"
    host = os.getenv("GRADIO_SERVER_NAME", fallback_host)

    print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
    fix_proxy(ipv6_enabled=ipv6_on)

    queued_ui = create_ui().queue()
    queued_ui.launch(share=share_on, server_name=host, inbrowser=True)
# Export the model described by the merge config using llamafactory-cli.

# Stop if the checkout is missing so the export never runs from the wrong directory.
cd LLaMA-Factory || exit 1

# LLaMA-Factory/examples contains qwen_merge.yaml and qwen_sft.yaml; there is no
# examples/qwen.yaml, so the export has to point at the merge config.
llamafactory-cli export examples/qwen_merge.yaml
#!/usr/bin/env bash

# Supervised fine-tuning launcher: exports the run settings, then starts
# `llamafactory-cli train` from inside the LLaMA-Factory checkout.

# Make the pipeline report the trainer's exit status. Without this the script
# returns tee's status, so a crashed run still looks successful to the caller.
set -o pipefail

export PROJ_NAME="qwen_sft"
export OUTPUT_DIR="path/to/output/dir"
export MODEL_PATH="path/to/model/Qwen2.5-7B-Instruct"
export LOG_DIR="logs/$PROJ_NAME.txt"

export LR=2.0e-5
export EPOCH=20
export BATCH_SIZE=4
export G_ACC=8

# Stop if the checkout is missing rather than launching from the wrong directory.
cd LLaMA-Factory || exit 1

# LOG_DIR is relative, so it resolves inside LLaMA-Factory. tee does not create
# missing parent directories, so make sure the log folder exists first.
mkdir -p "$(dirname "${LOG_DIR}")"

# Expansions are quoted so paths containing spaces stay a single argument.
FORCE_TORCHRUN=1 llamafactory-cli train examples/qwen_sft.yaml \
    learning_rate="${LR}" \
    num_train_epochs="${EPOCH}" \
    per_device_train_batch_size="${BATCH_SIZE}" \
    gradient_accumulation_steps="${G_ACC}" \
    output_dir="${OUTPUT_DIR}" \
    run_name="${PROJ_NAME}" \
    model_name_or_path="${MODEL_PATH}" 2>&1 | tee -a "${LOG_DIR}"
# GRPO training launcher for verl.trainer.main_ppo.
# All tunables come from environment variables exported by the caller
# (qwen_rl.sh invokes this file with `bash`, so bash features are used here).

# Make the pipeline report the trainer's exit status. Without this the script
# returns tee's status, so a crashed run still looks successful to the caller.
set -o pipefail

# Fail fast with a clear message instead of passing empty Hydra overrides.
for required_var in KL_COE ENTROPY LR BASE_MODEL ROLLOUT_TP_SIZE GPU_UT TEMPERATURE \
    DATA_DIR BA_SIZE MAX_PROMPT_LEN MAX_RES N_GPUS PRO_NAME EXPERIMENT_NAME EPOCH; do
    if [ -z "${!required_var}" ]; then
        echo "error: environment variable ${required_var} must be set" >&2
        exit 1
    fi
done

# tee does not create missing parent directories for the log file.
if [ -n "${LOG_DIR}" ]; then
    mkdir -p "$(dirname "${LOG_DIR}")"
fi

# Expansions are quoted so values containing spaces stay a single argument.
# The logger override is single-quoted so the shell cannot treat [...] as a glob;
# the program receives the same text as before: trainer.logger=[wandb].
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef="${KL_COE}" \
    actor_rollout_ref.actor.entropy_coeff="${ENTROPY}" \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.optim.lr="${LR}" \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30720 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.model.path="${BASE_MODEL}" \
    actor_rollout_ref.rollout.tensor_model_parallel_size="${ROLLOUT_TP_SIZE}" \
    actor_rollout_ref.rollout.gpu_memory_utilization="${GPU_UT}" \
    actor_rollout_ref.rollout.temperature="${TEMPERATURE}" \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.model.use_remove_padding=True \
    data.train_files="${DATA_DIR}/train.parquet" \
    data.val_files="${DATA_DIR}/test.parquet" \
    data.train_batch_size="${BA_SIZE}" \
    data.val_batch_size=1312 \
    data.max_prompt_length="${MAX_PROMPT_LEN}" \
    data.max_response_length="${MAX_RES}" \
    algorithm.kl_ctrl.kl_coef="${KL_COE}" \
    trainer.critic_warmup=0 \
    'trainer.logger=[wandb]' \
    +trainer.val_before_train=False \
    +actor_rollout_ref.actor.fsdp_config.grad_offload=False \
    trainer.default_hdfs_dir=null \
    trainer.n_gpus_per_node="${N_GPUS}" \
    trainer.nnodes=1 \
    trainer.save_freq=10 \
    trainer.test_freq=10 \
    trainer.project_name="${PRO_NAME}" \
    trainer.experiment_name="${EXPERIMENT_NAME}" \
    trainer.total_epochs="${EPOCH}" 2>&1 | tee -a "${LOG_DIR}"
# setup.py is the fallback installation script when pyproject.toml does not work
import os

from setuptools import find_packages, setup

# Resolve the version file relative to this script so the build does not
# depend on the current working directory.
version_folder = os.path.dirname(os.path.abspath(__file__))

# Read with an explicit encoding so the result does not vary with the locale.
with open(os.path.join(version_folder, 'verl/version/version'), encoding='utf-8') as f:
    __version__ = f.read().strip()

# Runtime dependencies installed with the package.
install_requires = [
    'accelerate',
    'codetiming',
    'datasets',
    'dill',
    'hydra-core',
    'numpy',
    'pandas',
    'peft',
    'pyarrow>=15.0.0',
    'pybind11',
    'pylatexenc',
    'ray>=2.10',
    'tensordict<0.6',
    'torchdata',
    'transformers',
    'vllm<=0.6.3',
    'wandb',
]

# Optional dependency groups, exposed below through ``extras_require``.
TEST_REQUIRES = ['pytest', 'yapf', 'py-spy']
PRIME_REQUIRES = ['pyext']
GEO_REQUIRES = ['mathruler']
GPU_REQUIRES = ['liger-kernel', 'flash-attn']

extras_require = {
    'test': TEST_REQUIRES,
    'prime': PRIME_REQUIRES,
    'geo': GEO_REQUIRES,
    'gpu': GPU_REQUIRES,
}

long_description = "Volcano Engine Reinforcement Learning for LLM"

setup(
    name='verl',
    version=__version__,
    package_dir={'': '.'},
    packages=find_packages(where='.'),
    url='https://github.com/volcengine/verl',
    license='Apache 2.0',
    author='Bytedance - Seed - MLSys',
    author_email='zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk',
    description='verl: Volcano Engine Reinforcement Learning for LLM',
    install_requires=install_requires,
    extras_require=extras_require,
    package_data={
        '': ['version/*'],
        'verl': ['trainer/config/*.yaml'],
    },
    include_package_data=True,
    long_description=long_description,
    long_description_content_type='text/markdown',
)
-------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) 18 | 19 | with open(os.path.join(version_folder, 'version/version')) as f: 20 | __version__ = f.read().strip() 21 | 22 | from .protocol import DataProto 23 | 24 | from .utils.logging_utils import set_basic_config 25 | import logging 26 | 27 | set_basic_config(level=logging.WARNING) 28 | 29 | from . 
import single_controller 30 | 31 | __all__ = ['DataProto', "__version__"] -------------------------------------------------------------------------------- /verl/verl/data/test.parquet: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:674fbfa58ee11be7034a505323cdfa49b9863d3dd2c0e31131a08a09709869c9 3 | size 9338664 4 | -------------------------------------------------------------------------------- /verl/verl/data/train.parquet: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:cb080d5ae5f3b84423d9bbfd1d37923b8f59282fe7778345c459cf1371a25bac 3 | size 83495252 4 | -------------------------------------------------------------------------------- /verl/verl/models/README.md: -------------------------------------------------------------------------------- 1 | # Models 2 | Common modelzoo such as huggingface/transformers stuggles when using Pytorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly-optimized with packed inputs in verl. 3 | ## Adding a New Huggingface Model 4 | ### Step 1: Copy the model file from HF to verl 5 | - Add a new file under verl/models/hf 6 | - Copy ONLY the model file from huggingface/transformers/models to verl/models/hf 7 | 8 | ### Step 2: Modify the model file to use packed inputs 9 | - Remove all the code related to inference (kv cache) 10 | - Modify the inputs to include only 11 | - input_ids (total_nnz,) 12 | - cu_seqlens (total_nnz + 1,) 13 | - max_seqlen_in_batch: int 14 | - Note that this requires using flash attention with causal mask. 
15 | 16 | ### Step 2.5: Add tests 17 | - Add a test to compare this version and the huggingface version 18 | - Following the infrastructure and add tests to tests/models/hf 19 | 20 | ### Step 3: Add a function to apply tensor parallelism 21 | - Please follow 22 | - https://pytorch.org/docs/stable/distributed.tensor.parallel.html 23 | - https://pytorch.org/tutorials/intermediate/TP_tutorial.html 24 | - General comments 25 | - Tensor Parallelism in native Pytorch is NOT auto-parallelism. The way it works is to specify how model parameters and input/output reshards using configs. These configs are then registered as hooks to perform input/output resharding before/after model forward. 26 | 27 | ### Step 4: Add a function to apply data parallelism 28 | - Please use FSDP2 APIs 29 | - See demo here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413 30 | 31 | ### Step 5: Add a function to apply pipeline parallelism 32 | - Comes in Pytorch 2.4 33 | - Currently only in alpha in nightly version 34 | - Check torchtitan for more details 35 | 36 | -------------------------------------------------------------------------------- /verl/verl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /verl/verl/models/llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/verl/models/llama/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .modeling_llama_megatron import ( 16 | # original model with megatron 17 | ParallelLlamaModel, 18 | ParallelLlamaForCausalLM, 19 | # rmpad with megatron 20 | ParallelLlamaForCausalLMRmPad, 21 | ParallelLlamaForValueRmPad, 22 | # rmpad with megatron and pipeline parallelism 23 | ParallelLlamaForCausalLMRmPadPP, 24 | ParallelLlamaForValueRmPadPP) 25 | -------------------------------------------------------------------------------- /verl/verl/models/llama/megatron/checkpoint_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/verl/models/llama/megatron/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .parallel_attention import ParallelLlamaAttention 16 | from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad 17 | from .parallel_mlp import ParallelLlamaMLP 18 | from .parallel_rmsnorm import ParallelLlamaRMSNorm 19 | -------------------------------------------------------------------------------- /verl/verl/models/llama/megatron/layers/parallel_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
    """Column-parallel linear layer whose output packs the query, key and value projections."""

    def __init__(self,
                 input_size,
                 num_heads,
                 num_key_value_heads,
                 head_dim,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        # One query block per attention head plus one key and one value block
        # per key/value head.
        total_head_blocks = num_heads + 2 * num_key_value_heads

        # Record the constructor inputs and per-projection widths on the instance
        # before delegating to the parent constructor.
        self.input_size = input_size
        self.head_dim = head_dim
        self.q_output_size = num_heads * head_dim
        self.kv_output_size = num_key_value_heads * head_dim
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        super().__init__(input_size=input_size,
                         output_size=total_head_blocks * head_dim,
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         **kwargs)


class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
    """Column-parallel linear layer whose output packs the gate and up projections."""

    def __init__(self,
                 input_size,
                 gate_ouput_size,
                 up_output_size,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        # The fused output is the gate projection followed by the up projection.
        merged_output_size = gate_ouput_size + up_output_size

        # Record the constructor inputs on the instance before delegating to the
        # parent constructor.
        self.input_size = input_size
        self.output_size = merged_output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        super().__init__(input_size=input_size,
                         output_size=merged_output_size,
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         **kwargs)
class ParallelLlamaRMSNorm(nn.Module):
    """RMS normalization layer that delegates the computation to ``fused_rms_norm_affine``."""

    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm

        Args:
            config: supplies ``hidden_size`` (the normalized shape) and ``rms_norm_eps``.
            megatron_config: only its ``sequence_parallel`` flag is read here.
        """
        super().__init__()
        # Wrap a single integer into a 1-tuple and otherwise use the value as the
        # shape, so ``normalized_shape`` is always bound. Binding it only inside the
        # ``isinstance`` branch raised UnboundLocalError for non-integer sizes.
        normalized_shape = config.hidden_size
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.variance_epsilon = config.rms_norm_eps

        if megatron_config.sequence_parallel:
            sp_utils.mark_parameter_as_sequence_parallel(self.weight)

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` using this layer's weight, normalized shape and epsilon."""
        return fused_rms_norm_affine(input=hidden_states,
                                     weight=self.weight,
                                     normalized_shape=self.normalized_shape,
                                     eps=self.variance_epsilon,
                                     memory_efficient=True)
/verl/verl/models/qwen2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/verl/models/qwen2/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .modeling_qwen2_megatron import ( 16 | # original model with megatron 17 | ParallelQwen2Model, 18 | ParallelQwen2ForCausalLM, 19 | # rmpad with megatron 20 | ParallelQwen2ForCausalLMRmPad, 21 | ParallelQwen2ForValueRmPad, 22 | # rmpad with megatron and pipeline parallelism 23 | ParallelQwen2ForCausalLMRmPadPP, 24 | ParallelQwen2ForValueRmPadPP) 25 | -------------------------------------------------------------------------------- /verl/verl/models/qwen2/megatron/checkpoint_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/verl/models/qwen2/megatron/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .parallel_attention import ParallelQwen2Attention 16 | from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad 17 | from .parallel_mlp import ParallelQwen2MLP 18 | from .parallel_rmsnorm import ParallelQwen2RMSNorm 19 | -------------------------------------------------------------------------------- /verl/verl/models/qwen2/megatron/layers/parallel_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
    """Column-parallel linear layer whose output packs the Q, K and V projections.

    The fused output width is ``(num_heads + 2 * num_key_value_heads) * head_dim``.
    """

    def __init__(self,
                 input_size,
                 num_heads,
                 num_key_value_heads,
                 head_dim,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        # Remember the construction parameters on the instance before the
        # parent constructor runs (same ordering as the original code).
        self.input_size = input_size
        self.q_output_size = num_heads * head_dim
        self.kv_output_size = num_key_value_heads * head_dim
        self.head_dim = head_dim
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        # One query projection plus one key and one value projection.
        fused_output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim

        super().__init__(input_size=self.input_size,
                         output_size=fused_output_size,
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         **kwargs)


class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
    """Column-parallel linear layer whose output packs the gate and up projections.

    The fused output width is ``gate_ouput_size + up_output_size``.
    """

    def __init__(self,
                 input_size,
                 gate_ouput_size,
                 up_output_size,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        # NOTE: ``gate_ouput_size`` is misspelt, but it is part of the public
        # keyword interface, so the name is kept as-is.
        merged_output_size = gate_ouput_size + up_output_size

        # Remember the construction parameters before the parent constructor runs.
        self.input_size = input_size
        self.output_size = merged_output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        super().__init__(input_size=self.input_size,
                         output_size=self.output_size,
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         **kwargs)
class ParallelQwen2RMSNorm(nn.Module):
    """RMS normalisation layer with a learnable weight, backed by apex's fused kernel.

    The forward pass delegates to ``fused_rms_norm_affine`` using the weight,
    normalised shape and epsilon stored at construction time.
    """

    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm

        Args:
            config: model config; ``hidden_size`` gives the normalised shape and
                ``rms_norm_eps`` the epsilon.
            megatron_config: parallel config; when ``sequence_parallel`` is set the
                weight is marked as a sequence-parallel parameter.
        """
        super().__init__()
        hidden_size = config.hidden_size
        if isinstance(hidden_size, numbers.Integral):
            normalized_shape = (hidden_size,)
        else:
            # Previously this path left ``normalized_shape`` unbound and crashed
            # with UnboundLocalError; accept an already shape-like value instead.
            normalized_shape = tuple(hidden_size)
        self.normalized_shape = torch.Size(normalized_shape)
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.variance_epsilon = config.rms_norm_eps

        if megatron_config.sequence_parallel:
            sp_utils.mark_parameter_as_sequence_parallel(self.weight)

    def forward(self, hidden_states):
        """Apply fused RMS normalisation with the learnable weight to ``hidden_states``."""
        return fused_rms_norm_affine(input=hidden_states,
                                     weight=self.weight,
                                     normalized_shape=self.normalized_shape,
                                     eps=self.variance_epsilon,
                                     memory_efficient=True)
# Supported models using HF Rmpad
# TODO(sgm): HF may support more than listed here, we should add more after testing
_MODELS_SUPPORT_RMPAD = {'llama', 'mistral', 'gemma', 'qwen2', 'qwen2_vl', 'qwen2_5_vl'}


def check_model_support_rmpad(model_type: str):
    """Check that ``model_type`` is supported for remove-padding (RMPad).

    For ``qwen2_vl`` / ``qwen2_5_vl`` the flash-attention ``forward`` of both
    Qwen2-VL attention classes is additionally replaced with
    ``ulysses_flash_attn_forward``.

    Args:
        model_type: the model type string to check.

    Raises:
        ValueError: if ``model_type`` is not in ``_MODELS_SUPPORT_RMPAD``.
    """
    assert isinstance(model_type, str)
    if model_type not in _MODELS_SUPPORT_RMPAD:
        # The sentences are separate f-strings joined by implicit concatenation,
        # so each one needs its own trailing space.
        raise ValueError(f"Model architecture {model_type} is not supported for now. "
                         f"RMPad supported architectures: {_MODELS_SUPPORT_RMPAD}. "
                         f"Please set `use_remove_padding=False` in the model config.")

    if model_type in ("qwen2_vl", "qwen2_5_vl"):  # patch remove padding for qwen2vl mrope
        # Imported lazily so the other model types do not need these modules.
        from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward
        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2

        Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward
        Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward
        print("Qwen2vl patch applied!")


# Supported models in Megatron-LM
# Architecture -> (module, (actor/ref class, critic/rm class, rmpad class)).
_MODELS = {
    "LlamaForCausalLM":
        ("llama", ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad")),
    "Qwen2ForCausalLM":
        ("qwen2", ("ParallelQwen2ForCausalLMRmPadPP", "ParallelQwen2ForValueRmPadPP", "ParallelQwen2ForCausalLMRmPad")),
    "MistralForCausalLM": ("mistral", ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP",
                                       "ParallelMistralForCausalLMRmPad")),
}


# return model class
class ModelRegistry:
    """Lookup of Megatron model classes by Hugging Face architecture name."""

    @staticmethod
    def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]:
        """Return the Megatron model class registered for ``model_arch``.

        Args:
            model_arch: architecture name, a key of ``_MODELS``.
            value: falsy selects the first registered class (actor/ref),
                truthy selects the second (critic/rm).

        Returns:
            The class object, or ``None`` when the architecture is unknown or
            the imported module has no attribute with the selected name.
        """
        if model_arch not in _MODELS:
            return None

        module_name, model_cls_names = _MODELS[model_arch]
        # Index 0: actor/ref model, index 1: critic/rm model.
        model_cls_name = model_cls_names[1] if value else model_cls_names[0]

        module = importlib.import_module(f"verl.models.{module_name}.megatron.modeling_{module_name}_megatron")
        return getattr(module, model_cls_name, None)

    @staticmethod
    def get_supported_archs() -> List[str]:
        """Return the architecture names registered in ``_MODELS``."""
        return list(_MODELS)
-------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/verl/models/weight_loader_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def get_weight_loader(arch: str):
    """Return the Megatron state-dict loader registered for architecture ``arch``.

    Args:
        arch: architecture name, e.g. ``'LlamaForCausalLM'``.

    Returns:
        The loader callable registered for ``arch``.

    Raises:
        ValueError: if no loader is registered for ``arch``.
    """
    # The loaders are imported inside the function, not at module import time.
    from verl.models.llama.megatron.checkpoint_utils.llama_loader import load_state_dict_to_megatron_llama
    from verl.models.qwen2.megatron.checkpoint_utils.qwen2_loader import load_state_dict_to_megatron_qwen2

    loaders = {
        'LlamaForCausalLM': load_state_dict_to_megatron_llama,
        'Qwen2ForCausalLM': load_state_dict_to_megatron_qwen2,
    }

    # Guard clause: reject unknown architectures first.
    if arch not in loaders:
        raise ValueError(f"Model architectures {arch} are not supported for now. "
                         f"Supported architectures: {loaders.keys()}")
    return loaders[arch]
import base 24 | from .base import * 25 | 26 | __all__ = base.__all__ -------------------------------------------------------------------------------- /verl/verl/single_controller/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .worker import Worker 16 | from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool 17 | 18 | __all__ = ['Worker', 'WorkerGroup', 'ClassWithInitArgs', 'ResourcePool'] -------------------------------------------------------------------------------- /verl/verl/single_controller/base/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class MegatronWorker(Worker):
    """Worker that can report its Megatron parallel sizes and ranks."""

    def __init__(self, cuda_visible_devices=None) -> None:
        super().__init__(cuda_visible_devices)

    def get_megatron_global_info(self):
        """Return a ``DistGlobalInfo`` with the tensor/data/pipeline parallel world sizes."""
        # megatron is imported inside the method, not at module import time.
        from megatron.core import parallel_state as mpu
        return DistGlobalInfo(tp_size=mpu.get_tensor_model_parallel_world_size(),
                              dp_size=mpu.get_data_parallel_world_size(),
                              pp_size=mpu.get_pipeline_model_parallel_world_size())

    def get_megatron_rank_info(self):
        """Return a ``DistRankInfo`` with this process's tensor/data/pipeline parallel ranks."""
        from megatron.core import parallel_state as mpu
        return DistRankInfo(tp_rank=mpu.get_tensor_model_parallel_rank(),
                            dp_rank=mpu.get_data_parallel_rank(),
                            pp_rank=mpu.get_pipeline_model_parallel_rank())
class MegatronWorkerGroup(WorkerGroup):
    """Worker group that exposes cached Megatron rank and size information.

    This class only declares the two caches; subclasses are expected to
    implement ``init_megatron`` and to fill the caches.
    """

    def __init__(self, resource_pool: ResourcePool, **kwargs):
        super().__init__(resource_pool=resource_pool, **kwargs)
        # Both caches start out empty.
        self._megatron_rank_info = None
        self._megatron_global_info: DistGlobalInfo = None

    def init_megatron(self, default_megatron_kwargs: Dict = None):
        """Hook for subclasses; the base implementation always raises."""
        raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten")

    def get_megatron_rank_info(self, rank: int) -> DistRankInfo:
        """Return the cached rank info of worker ``rank`` (``0 <= rank < world_size``)."""
        assert 0 <= rank < self.world_size, f'rank must be from [0, world_size), Got {rank}'
        return self._megatron_rank_info[rank]

    def _checked_global_info(self):
        """Return the cached global info, asserting that it has been initialised."""
        info = self._megatron_global_info
        assert info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
        return info

    @property
    def tp_size(self):
        """Tensor-parallel size taken from the cached global info."""
        return self._checked_global_info().tp_size

    @property
    def dp_size(self):
        """Data-parallel size taken from the cached global info."""
        return self._checked_global_info().dp_size

    @property
    def pp_size(self):
        """Pipeline-parallel size taken from the cached global info."""
        return self._checked_global_info().pp_size

    def get_megatron_global_info(self):
        """Return the cached global info without checking that it is initialised."""
        return self._megatron_global_info
@ray.remote
class WorkerGroupRegisterCenter:
    """Ray actor that stores the rank-zero info it was created with and hands it back."""

    def __init__(self, rank_zero_info):
        self.rank_zero_info = rank_zero_info

    def get_rank_zero_info(self):
        """Return the stored rank-zero info."""
        return self.rank_zero_info


def create_worker_group_register_center(name, info):
    """Start a ``WorkerGroupRegisterCenter`` actor named ``name`` holding ``info``.

    Returns:
        The handle produced by the actor's ``remote`` call.
    """
    named_actor_cls = WorkerGroupRegisterCenter.options(name=name)
    return named_actor_cls.remote(info)
# NOTE(sgm): for open-source megatron-core
class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    """
    MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup
    so that the dispatcher can use it to dispatch data.
    """

    def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs):
        super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs)
        # Rank info is gathered from every worker, global info from rank zero only.
        self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info')
        global_info_ref = self.execute_rank_zero_async(method_name='get_megatron_global_info')
        self._megatron_global_info: DistGlobalInfo = ray.get(global_info_ref)


class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    """
    MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup
    so that the dispatcher can use it to dispatch data.
    """

    def __init__(self,
                 resource_pool: RayResourcePool,
                 ray_cls_with_init: RayClassWithInitArgs,
                 default_megatron_kwargs: Dict = None,
                 **kwargs):
        super().__init__(resource_pool=resource_pool,
                         ray_cls_with_init=ray_cls_with_init,
                         default_megatron_kwargs=default_megatron_kwargs,
                         **kwargs)
        # Workers must have megatron initialised before their info is queried.
        self.init_megatron(default_megatron_kwargs=default_megatron_kwargs)
        self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info')
        global_info_ref = self.execute_rank_zero_async(method_name='get_megatron_global_info')
        self._megatron_global_info: DistGlobalInfo = ray.get(global_info_ref)

    def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None):
        """Run ``init_megatron`` on every worker unless the group wraps detached workers."""
        # after super, we will call init of each worker
        if self._is_init_with_detached_workers:
            return
        # only init_megatron if the WorkerGroup is created from scratch
        self.execute_all_sync(method_name='init_megatron', default_megatron_kwargs=default_megatron_kwargs)
def get_version(pkg):
    """Return the installed version string of ``pkg``, or ``None`` if it is not installed."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


package_name = 'vllm'
package_version = get_version(package_name)
vllm_version = None

# Select the vendored integration that matches the installed vllm version.
if package_version == '0.3.1':
    vllm_version = '0.3.1'
    from .vllm_v_0_3_1.llm import LLM
    from .vllm_v_0_3_1.llm import LLMEngine
    from .vllm_v_0_3_1 import parallel_state
elif package_version == '0.4.2':
    vllm_version = '0.4.2'
    from .vllm_v_0_4_2.llm import LLM
    from .vllm_v_0_4_2.llm import LLMEngine
    from .vllm_v_0_4_2 import parallel_state
elif package_version == '0.5.4':
    vllm_version = '0.5.4'
    from .vllm_v_0_5_4.llm import LLM
    from .vllm_v_0_5_4.llm import LLMEngine
    from .vllm_v_0_5_4 import parallel_state
elif package_version == '0.6.3':
    vllm_version = '0.6.3'
    from .vllm_v_0_6_3.llm import LLM
    from .vllm_v_0_6_3.llm import LLMEngine
    from .vllm_v_0_6_3 import parallel_state
# ``package_version`` is None when vllm is not installed. ``vs.parse(None)``
# would raise TypeError, so guard it and fall through to the ValueError below.
elif package_version is not None and vs.parse(package_version) >= vs.parse('0.6.6.post2.dev252+g8027a724'):
    # From 0.6.6.post2 on, vllm supports SPMD inference
    # See https://github.com/vllm-project/vllm/pull/12071

    from vllm import LLM
    from vllm.distributed import parallel_state
    from .vllm_spmd.dtensor_weight_loaders import load_dtensor_weights
else:
    raise ValueError(
        f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4, 0.6.3 and 0.7.0+'
    )
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class TokenizerGroup:
    """A group of tokenizers that can be used for LoRA adapters."""

    def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int]):
        """Store the base tokenizer and, when LoRA is enabled, create the per-adapter cache.

        Args:
            tokenizer: the base tokenizer, used for every request.
            enable_lora: whether a per-LoRA tokenizer cache is kept.
            max_num_seqs: capacity of the per-LoRA cache.
            max_input_length: stored on the instance; not used by this class.
        """
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = tokenizer
        self.lora_tokenizers = LRUCache(capacity=max_num_seqs) if enable_lora else None

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Encode ``prompt`` with the tokenizer selected for ``lora_request``."""
        tokenizer = self.get_lora_tokenizer(lora_request)
        return tokenizer.encode(prompt)

    async def encode_async(self,
                           prompt: str,
                           request_id: Optional[str] = None,
                           lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Async counterpart of :meth:`encode`."""
        tokenizer = await self.get_lora_tokenizer_async(lora_request)
        return tokenizer.encode(prompt)

    def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request``, caching it by ``lora_int_id``.

        Without a request, or with LoRA disabled, the base tokenizer is returned.
        """
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id not in self.lora_tokenizers:
            # TODO(sgm): the lora tokenizer is also passed, but may be different
            tokenizer = self.tokenizer
            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
            return tokenizer
        return self.lora_tokenizers.get(lora_request.lora_int_id)

    async def get_lora_tokenizer_async(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Async counterpart of :meth:`get_lora_tokenizer`.

        ``encode_async`` awaits this method, but the class never defined it, so
        every ``encode_async`` call failed with AttributeError. The synchronous
        lookup does no I/O here, so it is simply delegated to.
        """
        return self.get_lora_tokenizer(lora_request)

    # FIXME(sgm): for simplicity, we assign the special token here
    @property
    def pad_token_id(self):
        """Pad token id of the base tokenizer."""
        return self.tokenizer.pad_token_id

    @property
    def eos_token_id(self):
        """End-of-sequence token id of the base tokenizer."""
        return self.tokenizer.eos_token_id
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def update_hf_weight_loader():
    """Announce that this vLLM version needs no HF weight-loader patching."""
    print('no hf weight loader need to be updated')


def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
    """Load a name -> tensor mapping into ``vllm_model`` and move the model to CUDA.

    Args:
        actor_weights: Mapping from parameter name to weight. If the model
            config ties word embeddings, the ``"lm_head.weight"`` entry is
            removed from this mapping in place before loading.
        vllm_model: Model exposing ``load_weights``, ``config`` and the usual
            ``nn.Module`` API.
    """
    assert isinstance(actor_weights, Dict)

    model_dtype = next(vllm_model.parameters()).dtype
    with set_default_torch_dtype(model_dtype):  # TODO
        tied_embeddings = vllm_model.config.tie_word_embeddings
        if tied_embeddings and "lm_head.weight" in actor_weights:
            # With tied embeddings the separate lm_head entry is dropped
            # rather than handed to load_weights.
            actor_weights.pop("lm_head.weight")
        vllm_model.load_weights(actor_weights.items())

    for _name, submodule in vllm_model.named_modules():
        quantizer = getattr(submodule, "quant_method", None)
        if quantizer is not None:
            quantizer.process_weights_after_loading(submodule)
        # FIXME: Remove this after Mixtral is updated
        # to use quant_method.
        if hasattr(submodule, "process_weights_after_loading"):
            submodule.process_weights_after_loading()

    # nn.Module.cuda() moves parameters in place; the return value is the same module.
    vllm_model.cuda()
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def update_hf_weight_loader():
    """Announce that this vLLM version needs no HF weight-loader patching."""
    print("no hf weight loader need to be updated")


def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
    """Load a name -> tensor mapping into ``vllm_model`` and move the model to CUDA.

    Args:
        actor_weights: Mapping from parameter name to weight. If the model
            config ties word embeddings, the ``"lm_head.weight"`` entry is
            removed from this mapping in place before loading.
        vllm_model: Model exposing ``load_weights``, ``config`` and the usual
            ``nn.Module`` API.
    """
    assert isinstance(actor_weights, Dict)

    model_dtype = next(vllm_model.parameters()).dtype
    with set_default_torch_dtype(model_dtype):  # TODO
        tied_embeddings = vllm_model.config.tie_word_embeddings
        if tied_embeddings and "lm_head.weight" in actor_weights:
            # With tied embeddings the separate lm_head entry is dropped
            # rather than handed to load_weights.
            actor_weights.pop("lm_head.weight")
        vllm_model.load_weights(actor_weights.items())

    for _name, submodule in vllm_model.named_modules():
        quantizer = getattr(submodule, "quant_method", None)
        if quantizer is not None:
            quantizer.process_weights_after_loading(submodule)
        # FIXME: Remove this after Mixtral is updated
        # to use quant_method.
        if hasattr(submodule, "process_weights_after_loading"):
            submodule.process_weights_after_loading()

    # nn.Module.cuda() moves parameters in place; the return value is the same module.
    vllm_model.cuda()
class TokenizerGroup(TokenizerGroup):
    """A group of tokenizers that can be used for LoRA adapters.

    Subclass of the imported ``TokenizerGroup`` (which it shadows by name)
    whose constructor takes an already-built tokenizer; the base-class
    ``__init__`` is not invoked.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int]):
        """Record the tokenizer and settings; create the LoRA cache only if enabled."""
        self.tokenizer = tokenizer
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        if enable_lora:
            # One cache slot per concurrently schedulable sequence.
            self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs)
        else:
            self.lora_tokenizers = None

    # FIXME(sgm): for simplicity, we assign the special token here
    @property
    def pad_token_id(self):
        """Padding token id of the wrapped tokenizer."""
        return self.tokenizer.pad_token_id

    @property
    def eos_token_id(self):
        """End-of-sequence token id of the wrapped tokenizer."""
        return self.tokenizer.eos_token_id
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/verl/trainer/config/evaluation.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | path: /tmp/math_Qwen2-7B-Instruct.parquet 3 | prompt_key: prompt 4 | response_key: responses 5 | data_source_key: data_source 6 | reward_model_key: reward_model -------------------------------------------------------------------------------- /verl/verl/trainer/config/generation.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | nnodes: 1 3 | n_gpus_per_node: 8 4 | 5 | data: 6 | path: ~/data/rlhf/math/test.parquet 7 | prompt_key: prompt 8 | n_samples: 5 9 | output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet 10 | batch_size: 128 11 | 12 | model: 13 | path: ~/models/Qwen2-7B-Instruct 14 | external_lib: null 15 | rollout: 16 | name: vllm 17 | temperature: 1.0 18 | top_k: 50 # 0 for hf rollout, -1 for vllm rollout 19 | top_p: 0.7 20 | prompt_length: 1536 21 | response_length: 512 22 | # for vllm rollout 23 | dtype: bfloat16 # should align with FSDP 24 | gpu_memory_utilization: 0.5 25 | ignore_eos: False 26 | enforce_eager: True 27 | free_cache_engine: True 28 | load_format: dummy_dtensor 29 | tensor_model_parallel_size: 1 30 | max_num_batched_tokens: 8192 31 | max_model_len: null 32 | max_num_seqs: 1024 33 | log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu 34 | log_prob_micro_batch_size_per_gpu: 8 35 | # for fire vllm rollout 36 | use_fire_sampling: False # enable FIRE https://arxiv.org/abs/2410.21236 37 | # for hf rollout 38 | do_sample: True 39 | disable_log_stats: True 40 | enable_chunked_prefill: True 41 | n: 1 42 | actor: 43 | strategy: fsdp # This is for backward-compatibility 44 | ppo_mini_batch_size: 256 45 | ppo_micro_batch_size: 
null # will be deprecated, use ppo_micro_batch_size_per_gpu 46 | ppo_micro_batch_size_per_gpu: null 47 | use_dynamic_bsz: False 48 | ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} 49 | grad_clip: 1.0 50 | clip_ratio: 0.2 51 | entropy_coeff: 0.001 52 | use_kl_loss: False # True for GRPO 53 | kl_loss_coef: 0.001 # for grpo 54 | kl_loss_type: low_var_kl # for grpo 55 | ppo_epochs: 1 56 | shuffle: False 57 | ulysses_sequence_parallel_size: 1 # sp size 58 | optim: 59 | lr: 1e-6 60 | lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime 61 | min_lr_ratio: null # only useful for warmup with cosine 62 | warmup_style: constant # select from constant/cosine 63 | total_training_steps: -1 # must be overridden by program 64 | fsdp_config: 65 | wrap_policy: 66 | min_num_params: 0 67 | param_offload: False 68 | optimizer_offload: False 69 | fsdp_size: -1 -------------------------------------------------------------------------------- /verl/verl/trainer/config/sft_trainer.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_batch_size: 256 3 | micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu 4 | micro_batch_size_per_gpu: 4 # this is also val batch size 5 | train_files: ~/data/gsm8k/train.parquet 6 | val_files: ~/data/gsm8k/test.parquet 7 | prompt_key: question 8 | response_key: answer 9 | max_length: 1024 10 | truncation: error 11 | balance_dp_token: False 12 | chat_template: null 13 | model: 14 | partial_pretrain: ~/models/gemma-1.1-7b-it 15 | fsdp_config: 16 | wrap_policy: 17 | min_num_params: 0 18 | cpu_offload: False 19 | offload_params: False 20 | external_lib: null 21 | enable_gradient_checkpointing: False 22 | trust_remote_code: False 23 | lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) 24 | lora_alpha: 16 # LoRA scaling factor 25 | target_modules: all-linear # Target modules for LoRA adaptation 26 | use_liger: 
False 27 | optim: 28 | lr: 1e-5 29 | betas: [0.9, 0.95] 30 | weight_decay: 0.01 31 | warmup_steps_ratio: 0.1 32 | clip_grad: 1.0 33 | ulysses_sequence_parallel_size: 1 34 | use_remove_padding: False 35 | trainer: 36 | default_local_dir: /tmp/sft_model 37 | default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here 38 | resume_path: null 39 | project_name: gsm8k-sft 40 | experiment_name: test 41 | total_epochs: 4 42 | total_training_steps: null 43 | logger: ['console'] 44 | seed: 1 45 | 46 | -------------------------------------------------------------------------------- /verl/verl/trainer/main_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Offline evaluate the performance of a generated file using reward model and ground truth verifier. 16 | The input is a parquet file that contains N generated sequences and (optional) the ground truth. 
def select_reward_fn(data_source):
    """Return the scoring function registered for ``data_source``.

    Raises:
        NotImplementedError: If no scorer is available for the data source.
    """
    if data_source == 'lighteval/MATH':
        return math.compute_score
    raise NotImplementedError(f'No reward function is implemented for data source {data_source!r}')


@hydra.main(config_path='config', config_name='evaluation', version_base=None)
def main(config):
    """Score every generated response and print the fraction of prompts solved.

    A prompt counts as solved when the best score among its responses equals 1.
    Column names and the parquet path are read from ``config.data``.

    Raises:
        ValueError: If the dataset contains no rows.
    """
    local_path = copy_to_local(config.data.path)
    dataset = pd.read_parquet(local_path)
    responses = dataset[config.data.response_key]
    data_sources = dataset[config.data.data_source_key]
    reward_model_data = dataset[config.data.reward_model_key]

    total = len(dataset)
    if total == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(f'No rows found in {config.data.path}; nothing to evaluate')

    passes = 0
    # Iterate positionally: label-based ``series[i]`` breaks when the parquet
    # file carries a non-default index.
    for response_lst, data_source, reward_data in zip(responses, data_sources, reward_model_data):
        # select reward score based on data_source
        reward_fn = select_reward_fn(data_source)
        ground_truth = reward_data['ground_truth']
        score_lst = [reward_fn(r, ground_truth) for r in response_lst]

        # A prompt with no responses cannot pass; np.max would raise on an empty list.
        if score_lst and np.max(score_lst) == 1:
            passes += 1

    print(f'pass@5: {passes / total}')


if __name__ == '__main__':
    main()
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/verl/trainer/runtime_env.yaml: -------------------------------------------------------------------------------- 1 | working_dir: ./ 2 | excludes: ["/.git/"] 3 | env_vars: 4 | TORCH_NCCL_AVOID_RECORD_STREAMS: "1" 5 | VLLM_ATTENTION_BACKEND: "XFORMERS" -------------------------------------------------------------------------------- /verl/verl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import tokenizer 16 | from .tokenizer import hf_tokenizer, hf_processor 17 | 18 | __all__ = tokenizer.__all__ -------------------------------------------------------------------------------- /verl/verl/utils/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. -------------------------------------------------------------------------------- /verl/verl/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def update_dict_with_config(dictionary: Dict, config: DictConfig):
    """Overwrite entries of ``dictionary`` with same-named attributes of ``config``.

    Only keys already present in ``dictionary`` are considered; a key is left
    untouched when ``config`` has no attribute of that name. The dictionary is
    modified in place and nothing is returned.
    """
    for name in dictionary:
        if not hasattr(config, name):
            continue
        dictionary[name] = getattr(config, name)
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .rl_dataset import RLHFDataset 16 | from .rm_dataset import RMDataset 17 | from .sft_dataset import SFTDataset 18 | -------------------------------------------------------------------------------- /verl/verl/utils/debug/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .performance import log_gpu_memory_usage -------------------------------------------------------------------------------- /verl/verl/utils/debug/performance.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0):
    """Report CUDA memory allocated and reserved, in GiB, prefixed by ``head``.

    The report is emitted when torch.distributed is not initialized, when
    ``rank`` is ``None``, or when the calling process has rank ``rank``;
    otherwise the call does nothing.

    Args:
        head: Text placed at the start of the message.
        logger: Destination logger; when ``None`` the message is printed.
        level: Logging level used with ``logger``.
        rank: The only rank that reports, or ``None`` to report on every rank.
    """
    # Skip only when another specific rank is responsible for reporting.
    if dist.is_initialized() and rank is not None and dist.get_rank() != rank:
        return

    bytes_per_gib = 1024**3
    memory_allocated = torch.cuda.memory_allocated() / bytes_per_gib
    memory_reserved = torch.cuda.memory_reserved() / bytes_per_gib

    message = f'{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}'

    if logger is not None:
        logger.log(msg=message, level=level)
        return
    print(message)
def initialize_global_process_group(timeout_second=36000):
    """Initialize the NCCL process group and bind this process to its local GPU.

    Args:
        timeout_second: Timeout, in seconds, passed to ``init_process_group``.

    Returns:
        ``(local_rank, rank, world_size)`` read from the ``LOCAL_RANK``,
        ``RANK`` and ``WORLD_SIZE`` environment variables.
    """
    from datetime import timedelta

    import torch
    import torch.distributed as dist

    group_timeout = timedelta(seconds=timeout_second)
    dist.init_process_group('nccl', timeout=group_timeout)

    # The environment variables are read after the group is initialized.
    local_rank, rank, world_size = (int(os.environ[name]) for name in ("LOCAL_RANK", "RANK", "WORLD_SIZE"))

    if dist.is_initialized():
        torch.cuda.set_device(local_rank)
    return local_rank, rank, world_size
@cache
def is_megatron_core_available():
    """Return True if ``megatron.core.parallel_state`` can be imported.

    The result is cached, so the import is attempted at most once.
    """
    try:
        from megatron.core import parallel_state as mpu  # noqa: F401 (import probe only)
        return True
    except ImportError:
        return False


@cache
def is_vllm_available():
    """Return True if ``vllm`` can be imported.

    The result is cached, so the import is attempted at most once.
    """
    try:
        import vllm  # noqa: F401 (import probe only)
        return True
    except ImportError:
        return False


def import_external_libs(external_libs=None):
    """Import one or more modules by name for their import-time side effects.

    Args:
        external_libs: ``None`` (do nothing), a single module name, or an
            iterable of module names (list, tuple, config list, ...).

    Raises:
        ImportError: If any named module cannot be imported.
    """
    if external_libs is None:
        return
    # Only a bare string denotes a single module. Any other iterable is a
    # collection of names; wrapping it whole would pass a non-string to
    # import_module.
    if isinstance(external_libs, str):
        external_libs = [external_libs]
    import importlib
    for external_lib in external_libs:
        importlib.import_module(external_lib)
def concat_dict_to_str(dict: Dict, step):
    """Render ``step`` and the numeric entries of ``dict`` as one ``' - '``-joined line.

    Values that are not ``numbers.Number`` instances are skipped; numeric
    values are formatted with three decimal places.
    """
    parts = [f'step:{step}']
    parts.extend(f'{key}:{value:.3f}' for key, value in dict.items() if isinstance(value, numbers.Number))
    return ' - '.join(parts)


class LocalLogger:
    """Logger that optionally prints metric dictionaries to the console."""

    def __init__(self, remote_logger=None, enable_wandb=False, print_to_console=False):
        """Store the console flag; ``remote_logger`` and ``enable_wandb`` are not used."""
        self.print_to_console = print_to_console
        if print_to_console:
            print('Using LocalLogger is deprecated. The constructor API will change ')

    def flush(self):
        """No-op; nothing is buffered."""

    def log(self, data, step):
        """Print the numeric entries of ``data`` for ``step`` when console output is enabled."""
        if not self.print_to_console:
            return
        print(concat_dict_to_str(data, step=step), flush=True)
def set_basic_config(level):
    """
    Configure the root logger's format and level through ``logging.basicConfig``.
    It will be called when import verl
    """
    log_format = '%(levelname)s:%(asctime)s:%(message)s'
    logging.basicConfig(level=level, format=log_format)
class MemoryBuffer:
    """A flat, zero-initialized CUDA tensor from which shaped views are handed out."""

    def __init__(self, numel, numel_padded, dtype):
        """Allocate ``numel_padded`` elements of ``dtype`` on the current CUDA device.

        ``numel`` is the usable size against which :meth:`get` checks bounds.
        """
        self.numel = numel
        self.numel_padded = numel_padded
        self.dtype = dtype
        device = torch.cuda.current_device()
        self.data = torch.zeros(numel_padded, dtype=dtype, device=device, requires_grad=False)

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a view of ``shape`` into the 1-D data beginning at ``start_index``.

        The view shares storage with the buffer. The requested range must end
        within the first ``numel`` elements.
        """
        stop_index = start_index + shape.numel()
        assert stop_index <= self.numel, \
            'requested tensor is out of the buffer range.'
        return self.data[start_index:stop_index].view(shape)
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from packaging.version import Version

from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD

from megatron.core.optimizer import OptimizerConfig

from megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native


def get_megatron_optimizer(
        model,
        config: OptimizerConfig,
        no_weight_decay_cond=None,
        scale_lr_cond=None,
        lr_mult=1.0,
        check_for_nan_in_loss_and_grad=False,
        overlap_param_gather=False  # add for verl
):
    """Build the optimizer by delegating to Megatron's native factory.

    ``model`` is forwarded as ``model_chunks``; ``config``,
    ``no_weight_decay_cond``, ``scale_lr_cond`` and ``lr_mult`` are forwarded
    under their own names.

    ``check_for_nan_in_loss_and_grad`` and ``overlap_param_gather`` are
    accepted but not used by this wrapper.

    Returns:
        Whatever ``megatron.core.optimizer.get_megatron_optimizer`` returns.
    """
    native_kwargs = dict(
        config=config,
        model_chunks=model,
        no_weight_decay_cond=no_weight_decay_cond,
        scale_lr_cond=scale_lr_cond,
        lr_mult=lr_mult,
    )
    return get_megatron_optimizer_native(**native_kwargs)


# --------------------------------------------------------------------------------
# /verl/verl/utils/megatron/pipeline_parallel.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from megatron.core import parallel_state as mpu

from .sequence_parallel import pad_to_sequence_parallel


def compute_transformers_input_shapes(batches, meta_info):
    """Pre-compute one input shape per micro-batch for each pipeline stage.

    Args:
        batches: iterable of mappings holding ``'input_ids'`` and ``'attention_mask'``.
        meta_info: mapping holding ``'sequence_parallel'`` and ``'hidden_size'``.

    Returns:
        A list of ``torch.Size([tokens, 1, hidden_size])``, one per micro-batch.
        ``tokens`` is the unpadded token count; with sequence parallelism it is
        first padded and then divided by the tensor-model-parallel world size.
    """
    from flash_attn.bert_padding import unpad_input  # flash 2 is a must for Megatron

    input_shapes = []
    for model_inputs in batches:
        input_ids = model_inputs['input_ids']
        attention_mask = model_inputs['attention_mask']
        rmpad_ids = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)[0]  # (total_nnz, 1)
        if meta_info['sequence_parallel']:
            rmpad_ids = pad_to_sequence_parallel(rmpad_ids)
            num_tokens = rmpad_ids.shape[0] // mpu.get_tensor_model_parallel_world_size()
        else:
            num_tokens = rmpad_ids.shape[0]
        input_shapes.append(torch.Size([num_tokens, 1, meta_info['hidden_size']]))
    return input_shapes


def make_batch_generator(batches, vpp_size):
    """Wrap ``batches`` in iterator(s) for the pipeline schedule.

    Returns a list of ``vpp_size`` independent iterators over the same
    ``batches`` object when ``vpp_size > 1`` (one per vpp chunk), otherwise a
    single iterator.
    """
    if vpp_size > 1:
        # has vpp: one iterator per chunk
        return [iter(batches) for _ in range(vpp_size)]
    # no vpp
    return iter(batches)


# --------------------------------------------------------------------------------
# /verl/verl/utils/megatron/sequence_parallel.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from megatron.core import parallel_state as mpu


def mark_parameter_as_sequence_parallel(parameter):
    """Tag ``parameter`` by setting its ``sequence_parallel`` attribute to True."""
    setattr(parameter, 'sequence_parallel', True)


def is_sequence_parallel_param(param):
    """Return a truthy value iff ``param`` has a truthy ``sequence_parallel`` attribute."""
    return hasattr(param, 'sequence_parallel') and param.sequence_parallel


def pad_to_sequence_parallel(unpad_tokens: torch.Tensor):
    """Pad the tokens so that the total length is a multiple of the sp world size.

    The size used is ``mpu.get_tensor_model_parallel_world_size()``. Padding is
    appended along dimension 0 with ``F.pad``'s default fill value.

    Args:
        unpad_tokens: (total_nnz, ...). Tokens after removing padding. Only
            1-D and 2-D tensors are supported.

    Returns:
        The input tensor unchanged when no padding is needed, otherwise a
        padded tensor whose first dimension is a multiple of the world size.

    Raises:
        NotImplementedError: padding is required and the tensor is neither 1-D nor 2-D.
    """
    total_nnz = unpad_tokens.shape[0]
    sp_world_size = mpu.get_tensor_model_parallel_world_size()

    # Distance to the next multiple of sp_world_size (0 when already aligned).
    pad_size = (-total_nnz) % sp_world_size

    if pad_size > 0:
        if unpad_tokens.ndim == 1:
            unpad_tokens = F.pad(unpad_tokens, (0, pad_size))
        elif unpad_tokens.ndim == 2:
            # F.pad lists the last dimension first: (last_lo, last_hi, first_lo, first_hi).
            unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))
        else:
            # `ndim` is an attribute; calling it raised TypeError and hid this error.
            raise NotImplementedError(f'Padding dim {unpad_tokens.ndim} is not supported')

    return unpad_tokens


# --------------------------------------------------------------------------------
# /verl/verl/utils/py_functional.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contain small python utility functions
"""

from typing import Dict
from types import SimpleNamespace


def union_two_dict(dict1: Dict, dict2: Dict):
    """Merge ``dict2`` into ``dict1`` in place and return ``dict1``.

    A key present in both dicts must map to equal values (compared with
    ``==``); otherwise the assertion fails.

    Args:
        dict1: destination dict; mutated.
        dict2: dict whose items are copied into ``dict1``.

    Returns:
        ``dict1`` after the merge.
    """
    for key, val in dict2.items():
        if key in dict1:
            assert dict2[key] == dict1[key], \
                f'{key} in meta_dict1 and meta_dict2 are not the same object'
        dict1[key] = val

    return dict1


def append_to_dict(data: Dict, new_data: Dict):
    """Append each value of ``new_data`` to the list stored under the same key in ``data``.

    Missing keys are created with an empty list first. ``data`` is mutated in
    place; nothing is returned.
    """
    for key, val in new_data.items():
        if key not in data:
            data[key] = []
        data[key].append(val)


class NestedNamespace(SimpleNamespace):
    """``SimpleNamespace`` built from a dict, converting nested dicts recursively.

    Keys of ``dictionary`` become attributes. Values that are ``dict``
    instances become ``NestedNamespace`` objects; other values are stored
    unchanged. Extra ``kwargs`` are passed to ``SimpleNamespace``.
    """

    def __init__(self, dictionary, **kwargs):
        super().__init__(**kwargs)
        for key, value in dictionary.items():
            if isinstance(value, dict):
                self.__setattr__(key, NestedNamespace(value))
            else:
                self.__setattr__(key, value)


# --------------------------------------------------------------------------------
# /verl/verl/utils/ray_utils.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains commonly used utilities for ray
"""

import ray

import concurrent.futures


def parallel_put(data_list, max_workers=None):
    """Put every item of ``data_list`` into the Ray object store using a thread pool.

    Args:
        data_list: sized sequence of objects to store.
        max_workers: thread-pool size. Defaults to ``min(len(data_list), 16)``.

    Returns:
        A list of the values returned by ``ray.put``, in the same order as
        ``data_list``. An empty input yields an empty list.
    """
    # An empty input would make the default pool size 0, which
    # ThreadPoolExecutor rejects with ValueError; there is nothing to do anyway.
    if len(data_list) == 0:
        return []

    if max_workers is None:
        max_workers = min(len(data_list), 16)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Executor.map yields results in input order, so no re-sorting is needed.
        return list(executor.map(ray.put, data_list))


# --------------------------------------------------------------------------------
# /verl/verl/utils/rendezvous/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# --------------------------------------------------------------------------------
# /verl/verl/utils/rendezvous/ray_backend.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd.
# and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time

from cupy.cuda.nccl import NcclCommunicator, get_unique_id

import ray
from ray.util import list_named_actors


@ray.remote
class NCCLIDStore:
    """Ray actor that holds one NCCL id and hands it out on request."""

    def __init__(self, nccl_id):
        self._nccl_id = nccl_id

    def get(self):
        """Return the NCCL id given at construction."""
        return self._nccl_id


def get_nccl_id_store_by_name(name):
    """Look up the named actor called ``name`` across all namespaces.

    Returns:
        The actor handle when exactly one named actor matches; ``None`` when
        there is no match (logged at INFO) or more than one (logged at WARNING).
    """
    named_actors = list_named_actors(all_namespaces=True)
    matches = [entry for entry in named_actors if entry.get("name", None) == name]
    if len(matches) == 0:
        logging.info(f"failed to get any actor named {name}")
        return None
    if len(matches) > 1:
        logging.warning(f"multiple actors with same name found: {matches}")
        return None
    return ray.get_actor(**matches[0])


def create_nccl_communicator_in_ray(rank: int,
                                    world_size: int,
                                    group_name: str,
                                    max_retries: int = 100,
                                    interval_s: int = 5):
    """Create an ``NcclCommunicator`` for ``rank``, sharing the NCCL id through a named Ray actor.

    Rank 0 generates the id, publishes it in an ``NCCLIDStore`` actor named
    ``group_name``, and builds its communicator. Other ranks poll for that
    actor up to ``max_retries`` times, sleeping ``interval_s`` seconds after
    each miss.

    Returns:
        The communicator, or ``None`` for a non-zero rank when the store is
        not found within ``max_retries`` attempts.
    """
    if rank == 0:
        nccl_id = get_unique_id()
        nccl_id_store = NCCLIDStore.options(name=group_name).remote(nccl_id)

        # Round-trip through the actor to make sure the id is stored before use.
        assert ray.get(nccl_id_store.get.remote()) == nccl_id
        return NcclCommunicator(
            ndev=world_size,
            commId=nccl_id,
            rank=0,
        )

    for attempt in range(1, max_retries + 1):
        nccl_id_store = get_nccl_id_store_by_name(group_name)
        if nccl_id_store is None:
            logging.info(f"failed to get nccl_id for {attempt} time, sleep for {interval_s} seconds")
            time.sleep(interval_s)
            continue
        logging.info(f"nccl_id_store {group_name} got")
        nccl_id = ray.get(nccl_id_store.get.remote())
        logging.info(f"nccl id for {group_name} got: {nccl_id}")
        return NcclCommunicator(
            ndev=world_size,
            commId=nccl_id,
            rank=rank,
        )
    return None


# --------------------------------------------------------------------------------
# /verl/verl/utils/reward_score/geo3k.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from mathruler.grader import extract_boxed_content, grade_answer


def compute_score(predict_str: str, ground_truth: str) -> float:
    """Return 1.0 when the boxed answer in ``predict_str`` grades as correct, else 0.0.

    The answer is taken with ``extract_boxed_content`` and compared to
    ``ground_truth`` with ``grade_answer``.
    """
    answer = extract_boxed_content(predict_str)
    is_correct = grade_answer(answer, ground_truth)
    return 1.0 if is_correct else 0.0


# --------------------------------------------------------------------------------
# /verl/verl/utils/reward_score/gsm8k.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd.
# and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re


def extract_solution(solution_str, method='strict'):
    """Extract the final answer string from a solution text.

    Args:
        solution_str: the solution text.
        method: ``'strict'`` requires the ``#### <number>`` marker and returns
            the number with commas removed. ``'flexible'`` returns the last
            number-like token in the text that is not ``''`` or ``'.'``,
            without any normalisation.

    Returns:
        The extracted answer as a string, or ``None`` when nothing valid is found.
    """
    assert method in ['strict', 'flexible']

    if method == 'strict':
        # this also tests the formatting of the model
        match = re.search(r"#### (\-?[0-9\.\,]+)", solution_str)
        if match is None:
            return None
        return match.group(1).replace(',', '')

    # method == 'flexible'
    invalid_str = ['', '.']
    # Scan from the end. Return None rather than the last invalid token when
    # every candidate is invalid (e.g. the text only contains a full stop).
    for candidate in reversed(re.findall(r"(\-?[0-9\.\,]+)", solution_str)):
        if candidate not in invalid_str:
            return candidate
    return None


def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1.):
    """The scoring function for GSM8k.

    Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.

    Args:
        solution_str: the solution text
        ground_truth: the ground truth
        method: the method to extract the solution, choices are 'strict' and 'flexible'
        format_score: the score for the format
        score: the score for the correct answer

    Returns:
        ``0`` when no answer can be extracted, ``score`` when the extracted
        answer equals ``ground_truth``, otherwise ``format_score``.
    """
    answer = extract_solution(solution_str=solution_str, method=method)
    if answer is None:
        return 0
    if answer == ground_truth:
        return score
    return format_score


# --------------------------------------------------------------------------------
# /verl/verl/utils/reward_score/prime_code/utils.py:
# --------------------------------------------------------------------------------
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Borrowed from: https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py

import multiprocessing
from typing import Dict, Optional
from datasets import load_dataset
from .testing_util import run_test
import traceback
import os, sys


def _temp_run(sample, generation, debug, result, metadata_list, timeout):
    """Run ``run_test`` with stdout/stderr sent to the null device and record the outcome.

    On success the result and metadata returned by ``run_test`` are appended
    to ``result`` and ``metadata_list``. On any exception a list of ``-1``
    (one per entry in ``sample['inputs']``) and an empty dict are appended.
    """
    with open(os.devnull, 'w') as devnull:
        sys.stdout = devnull
        sys.stderr = devnull
        try:
            res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
            result.append(res)
            metadata_list.append(metadata)
        except Exception:
            # print(e) # some tracebacks are extremely long.
            traceback.print_exc(10)
            num_cases = len(sample['inputs'])
            result.append([-1] * num_cases)
            metadata_list.append({})


def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
    """Check correctness of code generation with a global timeout.

    The global timeout is to catch some extreme/rare cases not handled by the
    timeouts inside `run_test`. The test runs in a separate process that is
    killed if still alive after ``timeout + 1`` seconds.

    Returns:
        ``(results, metadata_list)``: the first recorded result list (all
        ``-1`` when the worker recorded nothing) and the manager-backed
        metadata list.
    """
    manager = multiprocessing.Manager()
    result = manager.list()
    metadata_list = manager.list()
    worker_args = (in_outs, generation, debug, result, metadata_list, timeout)
    worker = multiprocessing.Process(target=_temp_run, args=worker_args)
    worker.start()
    worker.join(timeout=timeout + 1)
    if worker.is_alive():
        worker.kill()
        # worker.terminate()
    if not result:
        # consider that all tests failed
        result = [[-1] * len(in_outs["inputs"])]
        if debug:
            print("global timeout")
    return result[0], metadata_list


# --------------------------------------------------------------------------------
# /verl/verl/utils/torch_dtypes.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd.
# and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from Cruise.
"""

import torch

from typing import Union

HALF_LIST = [16, "16", "fp16", "float16", torch.float16]
FLOAT_LIST = [32, "32", "fp32", "float32", torch.float32]
BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16]


class PrecisionType(object):
    """Type of precision used.

    The class attributes are plain strings, not enum members.

    >>> PrecisionType.supported_type(16)
    True
    >>> PrecisionType.HALF in ("16", "bf16")
    True
    """

    HALF = "16"
    FLOAT = "32"
    FULL = "64"
    BFLOAT = "bf16"
    MIXED = "mixed"

    @staticmethod
    def supported_type(precision: Union[str, int]) -> bool:
        """Return True when ``str(precision)`` is one of the declared precision values."""
        # The class is not an Enum and cannot be iterated; compare against the
        # explicit value list. str() lets the int form (e.g. 16) match "16".
        return str(precision) in PrecisionType.supported_types()

    @staticmethod
    def supported_types() -> list[str]:
        """Return the declared precision values in declaration order."""
        return [
            PrecisionType.HALF,
            PrecisionType.FLOAT,
            PrecisionType.FULL,
            PrecisionType.BFLOAT,
            PrecisionType.MIXED,
        ]

    @staticmethod
    def is_fp16(precision):
        """Return True when ``precision`` is listed in ``HALF_LIST``."""
        return precision in HALF_LIST

    @staticmethod
    def is_fp32(precision):
        """Return True when ``precision`` is listed in ``FLOAT_LIST``."""
        return precision in FLOAT_LIST

    @staticmethod
    def is_bf16(precision):
        """Return True when ``precision`` is listed in ``BFLOAT_LIST``."""
        return precision in BFLOAT_LIST

    @staticmethod
    def to_dtype(precision):
        """Map a precision spec (int, str or torch dtype) to a torch dtype.

        Raises:
            RuntimeError: ``precision`` is in none of the three alias lists.
        """
        if precision in HALF_LIST:
            return torch.float16
        elif precision in FLOAT_LIST:
            return torch.float32
        elif precision in BFLOAT_LIST:
            return torch.bfloat16
        else:
            raise RuntimeError(f"unexpected precision: {precision}")

    @staticmethod
    def to_str(precision):
        """Map a torch dtype to ``'fp16'``, ``'fp32'`` or ``'bf16'``.

        Raises:
            RuntimeError: ``precision`` is not one of the three supported dtypes.
        """
        if precision == torch.float16:
            return 'fp16'
        elif precision == torch.float32:
            return 'fp32'
        elif precision == torch.bfloat16:
            return 'bf16'
        else:
            raise RuntimeError(f"unexpected precision: {precision}")


# --------------------------------------------------------------------------------
# /verl/verl/version/version:
# --------------------------------------------------------------------------------
# 0.2.0.dev

# --------------------------------------------------------------------------------
# /verl/verl/workers/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# --------------------------------------------------------------------------------
# /verl/verl/workers/actor/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import BasePPOActor
from .dp_actor import DataParallelPPOActor

__all__ = ["BasePPOActor", "DataParallelPPOActor"]

# --------------------------------------------------------------------------------
# /verl/verl/workers/actor/base.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base class for Actor
"""
from abc import ABC, abstractmethod
from typing import Iterable, Dict

from verl import DataProto
import torch

__all__ = ['BasePPOActor']


class BasePPOActor(ABC):
    """Abstract interface of a PPO actor; stores the config it is given."""

    def __init__(self, config):
        """Store ``config`` on the instance.

        Args:
            config (DictConfig): a config passed to the PPOActor. We expect the type to be
                DictConfig (https://omegaconf.readthedocs.io/), but it can be any namedtuple in general.
        """
        super().__init__()
        self.config = config

    @abstractmethod
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute log-probabilities for a batch of data.

        Args:
            data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```,
                ```attention_mask``` and ```position_ids```.

        Returns:
            Annotated as ``torch.Tensor``. The original docstring instead
            described a DataProto containing the key ```log_probs```.
            NOTE(review): confirm which one concrete subclasses return.
        """

    @abstractmethod
    def update_policy(self, data: DataProto) -> Dict:
        """Update the policy from the given data.

        Args:
            data (DataProto): an iterator over the DataProto that returns by
                ```make_minibatch_iterator```

        Returns:
            Dict: a dictionary contains anything. Typically, it contains the statistics during updating the model
            such as ```loss```, ```grad_norm```, etc.
        """


# --------------------------------------------------------------------------------
# /verl/verl/workers/critic/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BasePPOCritic
from .dp_critic import DataParallelPPOCritic

__all__ = ["BasePPOCritic", "DataParallelPPOCritic"]

# --------------------------------------------------------------------------------
# /verl/verl/workers/critic/base.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for a critic
"""
from abc import ABC, abstractmethod

import torch

from verl import DataProto

__all__ = ['BasePPOCritic']


class BasePPOCritic(ABC):
    """Abstract interface of a PPO critic; stores the config it is given."""

    def __init__(self, config):
        """Store ``config`` on the instance."""
        super().__init__()
        self.config = config

    @abstractmethod
    def compute_values(self, data: DataProto) -> torch.Tensor:
        """Compute values for ``data``; annotated to return a ``torch.Tensor``."""

    @abstractmethod
    def update_critic(self, data: DataProto):
        """Update the critic using ``data``."""


# --------------------------------------------------------------------------------
# /verl/verl/workers/reward_manager/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .naive import NaiveRewardManager 16 | from .prime import PrimeRewardManager -------------------------------------------------------------------------------- /verl/verl/workers/reward_model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .base import BasePPORewardModel 16 | -------------------------------------------------------------------------------- /verl/verl/workers/reward_model/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base class for reward model
"""

from abc import ABC, abstractmethod

from verl import DataProto


class BasePPORewardModel(ABC):
    """Abstract interface of a PPO reward model; stores the config it is given."""

    def __init__(self, config):
        self.config = config

    @abstractmethod
    def compute_reward(self, data: DataProto) -> DataProto:
        """Compute the reward for the given input_ids.

        The transformers should output a tensor with shape
        [batch_size, sequence_length], and the value at [EOS] mask should be gathered.

        Args:
            data: must contain keys "input_ids", "attention_mask" and "position_ids".
                - input_ids: [batch_size, sequence_length]
                - attention_mask: [batch_size, sequence_length]
                - position_ids: [batch_size, sequence_length]

        Returns: a data pass protocol containing "reward". Only the [EOS] position contains the reward.
            Other position should have zero reward. Note that this may change in the future if we use
            dense reward. So, we leave the interface for general case.
            - reward: [batch_size, sequence_length].
        """


# --------------------------------------------------------------------------------
# /verl/verl/workers/reward_model/megatron/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .reward_model import MegatronRewardModel 16 | -------------------------------------------------------------------------------- /verl/verl/workers/rollout/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .base import BaseRollout 16 | from .naive import NaiveRollout 17 | from .hf_rollout import HFRollout 18 | 19 | __all__ = ["BaseRollout", "NaiveRollout", "HFRollout"] 20 | -------------------------------------------------------------------------------- /verl/verl/workers/rollout/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import BaseRollout
from .naive import NaiveRollout
from .hf_rollout import HFRollout

__all__ = ["BaseRollout", "NaiveRollout", "HFRollout"]

# --------------------------------------------------------------------------------
# /verl/verl/workers/rollout/base.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Iterable, Union

from verl import DataProto

__all__ = ['BaseRollout']


class BaseRollout(ABC):
    """Abstract base class for rollout workers that generate sequences from prompts."""

    def __init__(self):
        """Initialize the rollout.

        Takes no arguments; it only runs the parent initializer. (The earlier
        docstring described a ``dataloader`` argument that this signature
        never accepted.)
        """
        super().__init__()

    @abstractmethod
    def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Generate sequences for the given prompts.

        Args:
            prompts: the batch of prompts to generate from.

        Returns:
            A ``DataProto`` holding the generation result; its exact contents
            are defined by the concrete subclass.
        """
        pass


# --------------------------------------------------------------------------------
# /verl/verl/workers/rollout/naive/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .naive_rollout import NaiveRollout

# --------------------------------------------------------------------------------
# /verl/verl/workers/rollout/vllm_rollout/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from importlib.metadata import version, PackageNotFoundError

# Last vLLM release that is driven through the customized (non-SPMD) rollout.
_LEGACY_VLLM_MAX_RELEASE = (0, 6, 3)


def get_version(pkg):
    """Return the installed version string of *pkg*, or None if it is not installed."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


def _is_legacy_vllm(version_string):
    """Return True when *version_string* denotes a release <= 0.6.3.

    Release segments are compared numerically instead of as text, so that
    e.g. '0.10.0' is correctly treated as newer than '0.6.3' (a plain string
    comparison ranks it older because '1' < '6').
    """
    match = re.match(r'\d+(?:\.\d+)*', version_string)
    if match is None:
        # Unrecognised scheme: fall back to the historical textual comparison.
        return version_string <= '0.6.3'

    release = [int(part) for part in match.group().split('.')]
    # Trailing zeros carry no ordering information ('0.6.3.0' == '0.6.3').
    while release and release[-1] == 0:
        release.pop()
    release = tuple(release)

    if release != _LEGACY_VLLM_MAX_RELEASE:
        return release < _LEGACY_VLLM_MAX_RELEASE

    # Same release number: post-releases and local builds ('0.6.3.post1',
    # '0.6.3+cu118') sort after 0.6.3, pre-releases sort before it.
    suffix = version_string[match.end():]
    return not ('post' in suffix or suffix.startswith('+'))


package_name = 'vllm'
package_version = get_version(package_name)

if package_version is None:
    # Previously this surfaced as an opaque TypeError from `None <= '0.6.3'`.
    raise ImportError("The 'vllm' package is required for vLLM rollout but is not installed.")

if _is_legacy_vllm(package_version):
    vllm_mode = 'customized'
    from .vllm_rollout import vLLMRollout
    from .fire_vllm_rollout import FIREvLLMRollout
else:
    vllm_mode = 'spmd'
    from .vllm_rollout_spmd import vLLMRollout

# --------------------------------------------------------------------------------
# /verl/verl/workers/sharding_manager/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from verl.utils.import_utils import is_vllm_available, is_megatron_core_available

from .base import BaseShardingManager
from .fsdp_ulysses import FSDPUlyssesShardingManager

# The Megatron/vLLM manager is imported only when both backends are reported
# available; otherwise the names are exported as None so that importing them
# from this package still succeeds.
if is_megatron_core_available() and is_vllm_available():
    from .megatron_vllm import AllGatherPPModel, MegatronVLLMShardingManager
else:
    AllGatherPPModel = None
    MegatronVLLMShardingManager = None

# Same placeholder convention for the FSDP/vLLM manager.
if is_vllm_available():
    from .fsdp_vllm import FSDPVLLMShardingManager
else:
    FSDPVLLMShardingManager = None

# --------------------------------------------------------------------------------
# /verl/verl/workers/sharding_manager/base.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sharding manager to implement HybridEngine
"""

from verl import DataProto


class BaseShardingManager:
    """No-op sharding manager: a context manager plus identity data hooks."""

    def __enter__(self):
        """Enter the context; does nothing and yields None."""
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the context; does nothing and does not suppress exceptions."""
        return None

    def preprocess_data(self, data: DataProto) -> DataProto:
        """Return *data* unchanged."""
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Return *data* unchanged."""
        return data


# --------------------------------------------------------------------------------