├── .DS_Store
├── LICENSE
├── README.md
├── figures
├── comp_raftpp_grpo_KL.png
├── comp_raftpp_grpo_entropy.png
├── comp_raftpp_grpo_llama_entropy.png
├── comp_raftpp_grpo_llama_kl.png
├── entropy_reinforce_rej.png
├── kl_reinforce_rej.png
└── reward_reinforce_rej.png
├── pyproject.toml
├── requirements.txt
├── scripts
├── data_preprocess
│ ├── gsm8k.py
│ ├── math_dataset.py
│ └── numina_math.py
├── run_grpo.sh
├── run_ppo.sh
├── run_raft.sh
├── run_raftpp.sh
└── run_reinforce_rej.sh
├── setup.py
└── verl
├── __init__.py
├── models
├── README.md
├── __init__.py
├── llama
│ ├── __init__.py
│ └── megatron
│ │ ├── __init__.py
│ │ ├── checkpoint_utils
│ │ ├── __init__.py
│ │ ├── llama_loader.py
│ │ ├── llama_loader_depracated.py
│ │ └── llama_saver.py
│ │ ├── layers
│ │ ├── __init__.py
│ │ ├── parallel_attention.py
│ │ ├── parallel_decoder.py
│ │ ├── parallel_linear.py
│ │ ├── parallel_mlp.py
│ │ └── parallel_rmsnorm.py
│ │ └── modeling_llama_megatron.py
├── mcore
│ ├── __init__.py
│ ├── gpt_model.py
│ ├── loader.py
│ └── saver.py
├── qwen2
│ ├── __init__.py
│ └── megatron
│ │ ├── __init__.py
│ │ ├── checkpoint_utils
│ │ ├── __init__.py
│ │ ├── qwen2_loader.py
│ │ ├── qwen2_loader_depracated.py
│ │ └── qwen2_saver.py
│ │ ├── layers
│ │ ├── __init__.py
│ │ ├── parallel_attention.py
│ │ ├── parallel_decoder.py
│ │ ├── parallel_linear.py
│ │ ├── parallel_mlp.py
│ │ └── parallel_rmsnorm.py
│ │ └── modeling_qwen2_megatron.py
├── registry.py
├── transformers
│ ├── __init__.py
│ ├── llama.py
│ ├── monkey_patch.py
│ ├── qwen2.py
│ └── qwen2_vl.py
└── weight_loader_registry.py
├── protocol.py
├── single_controller
├── __init__.py
├── base
│ ├── __init__.py
│ ├── decorator.py
│ ├── megatron
│ │ ├── __init__.py
│ │ ├── worker.py
│ │ └── worker_group.py
│ ├── register_center
│ │ ├── __init__.py
│ │ └── ray.py
│ ├── worker.py
│ └── worker_group.py
└── ray
│ ├── __init__.py
│ ├── base.py
│ └── megatron.py
├── third_party
├── __init__.py
├── sglang
│ ├── __init__.py
│ └── parallel_state.py
└── vllm
│ ├── __init__.py
│ ├── vllm_v_0_3_1
│ ├── __init__.py
│ ├── arg_utils.py
│ ├── config.py
│ ├── llm.py
│ ├── llm_engine_sp.py
│ ├── model_loader.py
│ ├── model_runner.py
│ ├── parallel_state.py
│ ├── tokenizer.py
│ ├── weight_loaders.py
│ └── worker.py
│ ├── vllm_v_0_4_2
│ ├── __init__.py
│ ├── arg_utils.py
│ ├── config.py
│ ├── dtensor_weight_loaders.py
│ ├── hf_weight_loader.py
│ ├── llm.py
│ ├── llm_engine_sp.py
│ ├── megatron_weight_loaders.py
│ ├── model_loader.py
│ ├── model_runner.py
│ ├── parallel_state.py
│ ├── spmd_gpu_executor.py
│ ├── tokenizer.py
│ └── worker.py
│ ├── vllm_v_0_5_4
│ ├── __init__.py
│ ├── arg_utils.py
│ ├── config.py
│ ├── dtensor_weight_loaders.py
│ ├── hf_weight_loader.py
│ ├── llm.py
│ ├── llm_engine_sp.py
│ ├── megatron_weight_loaders.py
│ ├── model_loader.py
│ ├── model_runner.py
│ ├── parallel_state.py
│ ├── spmd_gpu_executor.py
│ ├── tokenizer.py
│ └── worker.py
│ └── vllm_v_0_6_3
│ ├── __init__.py
│ ├── arg_utils.py
│ ├── config.py
│ ├── dtensor_weight_loaders.py
│ ├── hf_weight_loader.py
│ ├── llm.py
│ ├── llm_engine_sp.py
│ ├── megatron_weight_loaders.py
│ ├── model_loader.py
│ ├── model_runner.py
│ ├── parallel_state.py
│ ├── spmd_gpu_executor.py
│ ├── tokenizer.py
│ └── worker.py
├── trainer
├── __init__.py
├── config
│ ├── evaluation.yaml
│ ├── generation.yaml
│ ├── ppo_megatron_trainer.yaml
│ ├── ppo_trainer.yaml
│ └── sft_trainer.yaml
├── fsdp_sft_trainer.py
├── main_eval.py
├── main_generation.py
├── main_ppo.py
├── ppo
│ ├── __init__.py
│ ├── core_algos.py
│ ├── metric_utils.py
│ └── ray_trainer.py
└── runtime_env.yaml
├── utils
├── __init__.py
├── checkpoint
│ ├── __init__.py
│ ├── checkpoint_manager.py
│ ├── fsdp_checkpoint_manager.py
│ └── megatron_checkpoint_manager.py
├── config.py
├── dataset
│ ├── README.md
│ ├── __init__.py
│ ├── multiturn_sft_dataset.py
│ ├── rl_dataset.py
│ ├── rm_dataset.py
│ └── sft_dataset.py
├── debug
│ ├── __init__.py
│ ├── performance.py
│ └── trajectory_tracker.py
├── distributed.py
├── flops_counter.py
├── fs.py
├── fsdp_utils.py
├── hdfs_io.py
├── import_utils.py
├── logger
│ ├── __init__.py
│ └── aggregate_logger.py
├── logging_utils.py
├── megatron
│ ├── __init__.py
│ ├── memory.py
│ ├── optimizer.py
│ ├── pipeline_parallel.py
│ ├── sequence_parallel.py
│ └── tensor_parallel.py
├── megatron_utils.py
├── memory_buffer.py
├── model.py
├── py_functional.py
├── ray_utils.py
├── rendezvous
│ ├── __init__.py
│ └── ray_backend.py
├── reward_score
│ ├── __init__.py
│ ├── geo3k.py
│ ├── gsm8k.py
│ ├── math.py
│ ├── math_batch.py
│ ├── math_dapo.py
│ ├── math_verify.py
│ ├── prime_code
│ │ ├── __init__.py
│ │ ├── testing_util.py
│ │ └── utils.py
│ └── prime_math
│ │ ├── __init__.py
│ │ ├── grader.py
│ │ └── math_normalize.py
├── seqlen_balancing.py
├── tokenizer.py
├── torch_dtypes.py
├── torch_functional.py
├── tracking.py
└── ulysses.py
├── version
└── version
└── workers
├── __init__.py
├── actor
├── __init__.py
├── base.py
├── dp_actor.py
└── megatron_actor.py
├── critic
├── __init__.py
├── base.py
├── dp_critic.py
└── megatron_critic.py
├── fsdp_workers.py
├── megatron_workers.py
├── reward_manager
├── __init__.py
├── batch.py
├── dapo.py
├── naive.py
└── prime.py
├── reward_model
├── __init__.py
├── base.py
└── megatron
│ ├── __init__.py
│ └── reward_model.py
├── rollout
├── __init__.py
├── base.py
├── hf_rollout.py
├── naive
│ ├── __init__.py
│ └── naive_rollout.py
├── sglang_rollout
│ ├── __init__.py
│ └── sglang_rollout.py
├── tokenizer.py
└── vllm_rollout
│ ├── __init__.py
│ ├── fire_vllm_rollout.py
│ ├── vllm_rollout.py
│ └── vllm_rollout_spmd.py
└── sharding_manager
├── __init__.py
├── base.py
├── fsdp_sglang.py
├── fsdp_ulysses.py
├── fsdp_vllm.py
├── megatron_vllm.py
└── patch
├── __init__.py
└── fsdp_vllm_patch.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RLHFlow/Minimal-RL/e3b49f090ef4d4bd631f8c055e85f1f102e7dc1f/.DS_Store
--------------------------------------------------------------------------------
/figures/comp_raftpp_grpo_KL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RLHFlow/Minimal-RL/e3b49f090ef4d4bd631f8c055e85f1f102e7dc1f/figures/comp_raftpp_grpo_KL.png
--------------------------------------------------------------------------------
/figures/comp_raftpp_grpo_entropy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RLHFlow/Minimal-RL/e3b49f090ef4d4bd631f8c055e85f1f102e7dc1f/figures/comp_raftpp_grpo_entropy.png
--------------------------------------------------------------------------------
/figures/comp_raftpp_grpo_llama_entropy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RLHFlow/Minimal-RL/e3b49f090ef4d4bd631f8c055e85f1f102e7dc1f/figures/comp_raftpp_grpo_llama_entropy.png
--------------------------------------------------------------------------------
/figures/comp_raftpp_grpo_llama_kl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RLHFlow/Minimal-RL/e3b49f090ef4d4bd631f8c055e85f1f102e7dc1f/figures/comp_raftpp_grpo_llama_kl.png
--------------------------------------------------------------------------------
/figures/entropy_reinforce_rej.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RLHFlow/Minimal-RL/e3b49f090ef4d4bd631f8c055e85f1f102e7dc1f/figures/entropy_reinforce_rej.png
--------------------------------------------------------------------------------
/figures/kl_reinforce_rej.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RLHFlow/Minimal-RL/e3b49f090ef4d4bd631f8c055e85f1f102e7dc1f/figures/kl_reinforce_rej.png
--------------------------------------------------------------------------------
/figures/reward_reinforce_rej.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RLHFlow/Minimal-RL/e3b49f090ef4d4bd631f8c055e85f1f102e7dc1f/figures/reward_reinforce_rej.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # requirements.txt records the full set of dependencies for development
2 | accelerate
3 | codetiming
4 | datasets
5 | dill
6 | flash-attn
7 | hydra-core
8 | liger-kernel
9 | numpy
10 | pandas
11 | datasets
12 | peft
13 | pyarrow>=15.0.0
14 | pybind11
15 | pylatexenc
16 | pylint==3.3.6
17 | ray[default]
18 | tensordict<=0.6.2
19 | torchdata
20 | transformers
21 | # vllm==0.6.3.post1
22 | wandb
23 |
--------------------------------------------------------------------------------
/scripts/data_preprocess/gsm8k.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Preprocess the GSM8k dataset to parquet format
16 | """
17 |
18 | import re
19 | import os
20 | import datasets
21 |
22 | from verl.utils.hdfs_io import copy, makedirs
23 | import argparse
24 |
25 |
def extract_solution(solution_str):
    """Extract the final numeric answer from a GSM8K reference solution.

    GSM8K reference solutions contain a marker of the form ``#### <answer>``.
    The captured answer is returned as a string with commas (thousands
    separators) removed.

    Args:
        solution_str: Full reference solution text.

    Returns:
        The answer string, e.g. ``"1234"`` for ``"... #### 1,234"``.

    Raises:
        ValueError: If no ``#### <number>`` marker is present.
    """
    match = re.search(r"#### (\-?[0-9\.\,]+)", solution_str)
    # Raise explicitly instead of using `assert`: asserts are stripped under
    # `python -O`, which would turn a malformed row into an opaque AttributeError.
    if match is None:
        raise ValueError(f"no '#### <answer>' marker found in solution: {solution_str!r}")
    # group(1) is exactly the captured number (the text after '#### ').
    return match.group(1).replace(',', '')
32 |
33 |
if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    cli.add_argument('--local_dir', default='./data/gsm8k')
    cli.add_argument('--hdfs_dir', default=None)
    opts = cli.parse_args()

    data_source = 'openai/gsm8k'

    raw = datasets.load_dataset(data_source, 'main')
    train_dataset, test_dataset = raw['train'], raw['test']

    instruction_following = "Let's think step by step and output the final answer after \"####\"."

    def make_map_fn(split):
        """Return a ``datasets.map`` callback that turns one GSM8K row into a training record.

        The callback consumes the ``question`` and ``answer`` columns, appends the
        instruction to the question, and stores the split name and the row index
        under ``extra_info``.
        """

        def process_fn(example, idx):
            question_raw = example.pop('question')
            answer_raw = example.pop('answer')
            return {
                "data_source": data_source,
                "prompt": [{
                    "role": "user",
                    "content": f"{question_raw} {instruction_following}",
                }],
                "ability": "math",
                "reward_model": {
                    "style": "rule",
                    "ground_truth": extract_solution(answer_raw),
                },
                "extra_info": {
                    'split': split,
                    'index': idx,
                    'answer': answer_raw,
                    "question": question_raw,
                },
            }

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
    test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)

    local_dir, hdfs_dir = opts.local_dir, opts.hdfs_dir

    train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))

    # Mirror the local output directory to HDFS only when a target was given.
    if hdfs_dir is not None:
        makedirs(hdfs_dir)
        copy(src=local_dir, dst=hdfs_dir)
95 |
--------------------------------------------------------------------------------
/scripts/data_preprocess/math_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Preprocess the MATH-lighteval dataset to parquet format
16 | """
17 |
18 | import os
19 | import datasets
20 |
21 | from verl.utils.hdfs_io import copy, makedirs
22 | import argparse
23 |
24 | from verl.utils.reward_score.math import remove_boxed, last_boxed_only_string
25 |
26 |
def extract_solution(solution_str):
    """Return the contents of the last ``\\boxed{...}`` expression in *solution_str*.

    Delegates to ``last_boxed_only_string`` (locate the final boxed span) and
    ``remove_boxed`` (strip the ``\\boxed{}`` wrapper).
    """
    last_box = last_boxed_only_string(solution_str)
    return remove_boxed(last_box)
29 |
30 |
if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    # NOTE(review): despite the ``math500`` default directory name, this script
    # writes the dataset's entire ``test`` split (no subsetting happens below);
    # confirm whether the 500-problem MATH-500 subset was intended.
    cli.add_argument('--local_dir', default='./data/math500')
    cli.add_argument('--hdfs_dir', default=None)
    opts = cli.parse_args()

    # 'lighteval/MATH' is no longer available on Hugging Face, so the mirror
    # repository is used instead.
    data_source = 'DigitalLearningGmbH/MATH-lighteval'
    print(f"Loading the {data_source} dataset from huggingface...", flush=True)
    raw = datasets.load_dataset(data_source, trust_remote_code=True)
    train_dataset, test_dataset = raw['train'], raw['test']

    instruction_following = "Let's think step by step and output the final answer within \\boxed{}."
    system_prompt = "Please reason step by step, and put your final answer within \\boxed{}."

    def make_map_fn(split):
        """Return a ``datasets.map`` callback that turns one MATH row into a training record.

        The callback consumes the ``problem`` and ``solution`` columns. The ground
        truth is the last ``\\boxed{}`` answer of the reference solution. The split
        name and row index are stored under ``extra_info``.
        """

        def process_fn(example, idx):
            prompt_text = f"{example.pop('problem')} {instruction_following}"
            reference = extract_solution(example.pop('solution'))
            return {
                "data_source": data_source,
                "prompt": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt_text},
                ],
                "ability": "math",
                "reward_model": {"style": "rule", "ground_truth": reference},
                "extra_info": {'split': split, 'index': idx},
            }

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
    test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)

    local_dir, hdfs_dir = opts.local_dir, opts.hdfs_dir

    train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))

    # Mirror the local output directory to HDFS only when a target was given.
    if hdfs_dir is not None:
        makedirs(hdfs_dir)
        copy(src=local_dir, dst=hdfs_dir)
99 |
--------------------------------------------------------------------------------
/scripts/data_preprocess/numina_math.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess the NuminaMath dataset to parquet format
3 | """
4 |
5 | import os
6 | import datasets
7 | from transformers import AutoTokenizer
8 |
9 | from verl.utils.hdfs_io import copy, makedirs
10 | import argparse
11 |
12 | from verl.utils.reward_score.math import remove_boxed, last_boxed_only_string
13 |
14 |
def extract_solution(solution_str):
    """Return the contents of the final ``\\boxed{...}`` in *solution_str*.

    This script's main block does not call it; the NuminaMath ``answer``
    column is used directly as the ground truth there.
    """
    boxed_span = last_boxed_only_string(solution_str)
    return remove_boxed(boxed_span)
17 |
18 |
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='./data/numina_math')
    parser.add_argument('--hdfs_dir', default=None)
    parser.add_argument('--train_start', type=int, default=0)
    parser.add_argument('--train_end', type=int, default=10000000)
    parser.add_argument('--model_name_or_path', type=str, default='Qwen/Qwen2.5-Math-1.5B')
    # Rows whose problem statement tokenizes to more than this many tokens are
    # dropped. The default of 700 is the threshold that was previously hard-coded.
    parser.add_argument('--max_problem_tokens', type=int, default=700)

    args = parser.parse_args()

    data_source = 'ScaleML-RLHF/numina_math'
    print(f"Loading the {data_source} dataset from huggingface...", flush=True)
    dataset = datasets.load_dataset(data_source, trust_remote_code=True)
    # Only used to measure problem length for the filter below.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)

    train_dataset = dataset['train']
    # Clamp the requested end to the dataset size, then keep rows [train_start, train_end).
    args.train_end = min(args.train_end, len(train_dataset))
    if args.train_end > 0:
        train_dataset = train_dataset.select(range(args.train_start, args.train_end))

    instruction_following = "Let's think step by step and output the final answer within \\boxed{}."
    system_prompt = "Please reason step by step, and put your final answer within \\boxed{}."

    def make_map_fn(split):
        """Return a ``datasets.map`` callback that turns one NuminaMath row into a training record.

        The callback consumes the ``problem`` column and uses the row's ``answer``
        column verbatim as the rule-based ground truth. The split name and row
        index are stored under ``extra_info``.
        """

        def process_fn(example, idx):
            question = example.pop('problem') + ' ' + instruction_following

            reward_model = {
                "style": "rule",
                "ground_truth": example['answer']
            }

            return {
                # Records are tagged 'numina_math' rather than with a MATH data source.
                # NOTE(review): presumably this tag selects the reward function downstream — confirm.
                "data_source": 'numina_math',
                "prompt": [
                    {
                        "role": "system",
                        "content": system_prompt
                    },
                    {
                        "role": "user",
                        "content": question
                    }
                ],
                "ability": "math",
                "reward_model": reward_model,
                "extra_info": {
                    'split': split,
                    'index': idx
                }
            }

        return process_fn

    def fits_token_budget(example):
        """Keep a row only if its raw problem statement is within ``--max_problem_tokens`` tokens."""
        # NOTE(review): the system prompt and appended instruction are not counted,
        # so the final prompt is somewhat longer than this budget.
        return len(tokenizer.encode(example['problem'])) <= args.max_problem_tokens

    train_dataset = train_dataset.filter(fits_token_budget)

    print(f"Train dataset size: {len(train_dataset)}")

    train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
    # Show one processed record as a sanity check. Indexing would raise on an
    # empty dataset, which happens if the filter removed every row.
    if len(train_dataset) > 0:
        print(train_dataset[0])
    local_dir = args.local_dir
    hdfs_dir = args.hdfs_dir
    train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))

    # Mirror the local output directory to HDFS only when a target was given.
    if hdfs_dir is not None:
        makedirs(hdfs_dir)
        copy(src=local_dir, dst=hdfs_dir)
99 |
100 |
--------------------------------------------------------------------------------
/scripts/run_grpo.sh:
--------------------------------------------------------------------------------
set -x
# Without pipefail the script's exit status is that of `tee`, which hides a
# failing training run from callers.
set -o pipefail

export VLLM_ATTENTION_BACKEND=XFORMERS
data=numina_math
project_name=em-raft
algorithm=grpo
model=Qwen2.5-Math-1.5B
model_name_or_path=Qwen/$model
n=4
experiment_name=${model}-${algorithm}-${data}-n${n}
# GPUs to use; joined with commas into CUDA_VISIBLE_DEVICES, and the count becomes n_gpus_per_node.
GPUS=(0 1 2 3 4 5 6 7)
my_world_size=${#GPUS[@]}

math_train_path=./data/$data/train.parquet
math_test_path=./data/math500/test.parquet

# Hydra list literals of parquet paths.
train_files="['$math_train_path']"
test_files="['$math_test_path']"

mkdir -p logs/${project_name}

CUDA_VISIBLE_DEVICES=$(IFS=,; echo "${GPUS[*]}") python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=$algorithm \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=1024 \
    data.max_prompt_length=1024 \
    data.max_response_length=3072 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=$model_name_or_path \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.n=$n \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger="['console','wandb']" \
    trainer.project_name=${project_name} \
    trainer.experiment_name=${experiment_name} \
    trainer.n_gpus_per_node=$my_world_size \
    trainer.val_before_train=True \
    trainer.nnodes=1 \
    trainer.save_freq=5 \
    trainer.default_local_dir=checkpoints/${project_name}/${experiment_name} \
    trainer.test_freq=5 \
    trainer.total_epochs=1 2>&1 | tee logs/${project_name}/${experiment_name}.log
--------------------------------------------------------------------------------
/scripts/run_ppo.sh:
--------------------------------------------------------------------------------
set -x
# Without pipefail the script's exit status is that of `tee`, which hides a
# failing training run from callers.
set -o pipefail
export VLLM_ATTENTION_BACKEND=XFORMERS

data=numina_math
project_name=em-raft
algorithm=ppo
model=Qwen2.5-Math-1.5B
model_name_or_path=Qwen/$model
policy_loss=plusplus
n=1
experiment_name=${model}-${algorithm}-${policy_loss}-${data}-n${n}
# GPUs to use; joined with commas into CUDA_VISIBLE_DEVICES, and the count becomes n_gpus_per_node.
GPUS=(0 1 2 3 4 5 6 7)
my_world_size=${#GPUS[@]}
train_path=./data/$data/train.parquet
test_path=./data/math500/test.parquet

# Hydra list literals of parquet paths.
train_files="['$train_path']"
test_files="['$test_path']"

mkdir -p logs/${project_name}

# Each rollout/ref override appears once. The previous version repeated six of
# them with identical values.
CUDA_VISIBLE_DEVICES=$(IFS=,; echo "${GPUS[*]}") python3 -m verl.trainer.main_ppo \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=1024 \
    data.max_prompt_length=1024 \
    data.max_response_length=3072 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=$model_name_or_path \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.actor.policy_loss=$policy_loss \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.n=$n \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path=$model_name_or_path \
    critic.model.enable_gradient_checkpointing=True \
    critic.model.fsdp_config.param_offload=True \
    critic.model.fsdp_config.optimizer_offload=True \
    critic.ppo_micro_batch_size_per_gpu=4 \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger="['console','wandb']" \
    trainer.project_name=${project_name} \
    trainer.experiment_name=${experiment_name} \
    trainer.nnodes=1 \
    trainer.n_gpus_per_node=$my_world_size \
    trainer.default_local_dir=checkpoints/${project_name}/${experiment_name} \
    trainer.val_before_train=True \
    trainer.save_freq=5 \
    trainer.test_freq=5 \
    trainer.total_epochs=1 2>&1 | tee -a logs/${project_name}/${experiment_name}.log
71 |
--------------------------------------------------------------------------------
/scripts/run_raft.sh:
--------------------------------------------------------------------------------
set -x
# Without pipefail the script's exit status is that of `tee`, which hides a
# failing training run from callers.
set -o pipefail

export VLLM_ATTENTION_BACKEND=XFORMERS

data=numina_math
project_name=raft++
algorithm=raft
model=Qwen2.5-Math-1.5B
model_name_or_path=Qwen/$model
policy_loss=vanilla # vanilla, plusplus (importance sample + clipping)
n=4
experiment_name=${model}-${algorithm}-${policy_loss}-${data}-n${n}
# GPUs to use; joined with commas into CUDA_VISIBLE_DEVICES, and the count becomes n_gpus_per_node.
GPUS=(0 1 2 3 4 5 6 7)
my_world_size=${#GPUS[@]}

math_train_path=./data/$data/train.parquet
math_test_path=./data/math500/test.parquet

# Hydra list literals of parquet paths.
train_files="['$math_train_path']"
test_files="['$math_test_path']"

mkdir -p logs/${project_name}

CUDA_VISIBLE_DEVICES=$(IFS=,; echo "${GPUS[*]}") python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=$algorithm \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=1024 \
    data.max_prompt_length=1024 \
    data.max_response_length=3072 \
    data.filter_overlong_prompts=True \
    actor_rollout_ref.model.path=$model_name_or_path \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.policy_loss=$policy_loss \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=$n \
    actor_rollout_ref.rollout.max_num_batched_tokens=8192 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger="['console','wandb']" \
    trainer.project_name=${project_name} \
    trainer.experiment_name=${experiment_name} \
    trainer.n_gpus_per_node=$my_world_size \
    trainer.val_before_train=True \
    trainer.nnodes=1 \
    trainer.save_freq=5 \
    trainer.default_local_dir=checkpoints/${project_name}/${experiment_name} \
    trainer.test_freq=5 \
    trainer.total_epochs=1 2>&1 | tee -a logs/${project_name}/${experiment_name}.log
62 |
63 |
--------------------------------------------------------------------------------
/scripts/run_raftpp.sh:
--------------------------------------------------------------------------------
set -x
# Without pipefail the script's exit status is that of `tee`, which hides a
# failing training run from callers.
set -o pipefail

export VLLM_ATTENTION_BACKEND=XFORMERS

data=numina_math
project_name=raft++
algorithm=raft
model=Qwen2.5-Math-1.5B
model_name_or_path=Qwen/$model
policy_loss=plusplus # vanilla, plusplus (importance sample + clipping)
n=4
experiment_name=${model}-${algorithm}-${policy_loss}-${data}-n${n}
# GPUs to use; joined with commas into CUDA_VISIBLE_DEVICES, and the count becomes n_gpus_per_node.
GPUS=(0 1 2 3 4 5 6 7)
my_world_size=${#GPUS[@]}

math_train_path=./data/$data/train.parquet
math_test_path=./data/math500/test.parquet

# Hydra list literals of parquet paths.
train_files="['$math_train_path']"
test_files="['$math_test_path']"

mkdir -p logs/${project_name}

CUDA_VISIBLE_DEVICES=$(IFS=,; echo "${GPUS[*]}") python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=$algorithm \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=1024 \
    data.max_prompt_length=1024 \
    data.max_response_length=3072 \
    data.filter_overlong_prompts=True \
    actor_rollout_ref.model.path=$model_name_or_path \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.policy_loss=$policy_loss \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=$n \
    actor_rollout_ref.rollout.max_num_batched_tokens=8192 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger="['console','wandb']" \
    trainer.project_name=${project_name} \
    trainer.experiment_name=${experiment_name} \
    trainer.n_gpus_per_node=$my_world_size \
    trainer.val_before_train=True \
    trainer.nnodes=1 \
    trainer.save_freq=5 \
    trainer.default_local_dir=checkpoints/${project_name}/${experiment_name} \
    trainer.test_freq=5 \
    trainer.total_epochs=1 2>&1 | tee -a logs/${project_name}/${experiment_name}.log
62 |
63 |
--------------------------------------------------------------------------------
/scripts/run_reinforce_rej.sh:
--------------------------------------------------------------------------------
set -x
# Without pipefail the script's exit status is that of `tee`, which hides a
# failing training run from callers.
set -o pipefail

export VLLM_ATTENTION_BACKEND=XFORMERS

data=numina_math
project_name=reinforce_rej
algorithm=reinforce_rej
model=Qwen2.5-Math-1.5B
model_name_or_path=Qwen/$model
policy_loss=plusplus # vanilla, plusplus (importance sample + clipping)
n=4
experiment_name=${model}-${algorithm}-${policy_loss}-${data}-n${n}
# GPUs to use; joined with commas into CUDA_VISIBLE_DEVICES, and the count becomes n_gpus_per_node.
GPUS=(0 1 2 3 4 5 6 7)
my_world_size=${#GPUS[@]}

math_train_path=./data/$data/train.parquet
math_test_path=./data/math500/test.parquet

# Hydra list literals of parquet paths.
train_files="['$math_train_path']"
test_files="['$math_test_path']"

mkdir -p logs/${project_name}

CUDA_VISIBLE_DEVICES=$(IFS=,; echo "${GPUS[*]}") python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=$algorithm \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=1024 \
    data.max_prompt_length=1024 \
    data.max_response_length=3072 \
    data.filter_overlong_prompts=True \
    actor_rollout_ref.model.path=$model_name_or_path \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.policy_loss=$policy_loss \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=$n \
    actor_rollout_ref.rollout.max_num_batched_tokens=8192 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger="['console','wandb']" \
    trainer.project_name=${project_name} \
    trainer.experiment_name=${experiment_name} \
    trainer.n_gpus_per_node=$my_world_size \
    trainer.val_before_train=True \
    trainer.nnodes=1 \
    trainer.save_freq=5 \
    trainer.default_local_dir=checkpoints/${project_name}/${experiment_name} \
    trainer.test_freq=5 \
    trainer.total_epochs=1 2>&1 | tee -a logs/${project_name}/${experiment_name}.log
62 |
63 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
# setup.py is the fallback installation script when pyproject.toml does not work
import os
from pathlib import Path

from setuptools import setup, find_packages

# Directory containing this setup.py. Data files below are resolved against it
# (not the current working directory) so installation works from any cwd.
this_directory = Path(os.path.abspath(__file__)).parent

# The package version lives in a plain-text file inside the package.
with open(this_directory / 'verl' / 'version' / 'version', encoding='utf-8') as f:
    __version__ = f.read().strip()

# Core runtime dependencies (each listed exactly once).
install_requires = [
    'accelerate',
    'codetiming',
    'datasets',
    'dill',
    'hydra-core',
    'numpy',
    'pandas',
    'peft',
    'pyarrow>=15.0.0',
    'pybind11',
    'pylatexenc',
    'ray[default]>=2.10',
    'tensordict<=0.6.2',
    'torchdata',
    'transformers',
    'wandb',
]

TEST_REQUIRES = ['pytest', 'yapf', 'py-spy']
PRIME_REQUIRES = ['pyext']
GEO_REQUIRES = ['mathruler']
GPU_REQUIRES = ['liger-kernel', 'flash-attn']
MATH_REQUIRES = ['math-verify']  # Add math-verify as an optional dependency
VLLM_REQUIRES = ['tensordict<=0.6.2', 'vllm<=0.8.2']
SGLANG_REQUIRES = [
    'tensordict<=0.6.2',
    'sglang[all]==0.4.4.post4',
    'torch-memory-saver>=0.0.5',
]

extras_require = {
    'test': TEST_REQUIRES,
    'prime': PRIME_REQUIRES,
    'geo': GEO_REQUIRES,
    'gpu': GPU_REQUIRES,
    'math': MATH_REQUIRES,
    'vllm': VLLM_REQUIRES,
    'sglang': SGLANG_REQUIRES,
}

# Read with an explicit encoding so the result does not depend on the locale.
long_description = (this_directory / "README.md").read_text(encoding='utf-8')

setup(
    name='verl',
    version=__version__,
    package_dir={'': '.'},
    packages=find_packages(where='.'),
    url='https://github.com/volcengine/verl',
    license='Apache 2.0',
    author='Bytedance - Seed - MLSys',
    author_email='zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk',
    description='verl: Volcano Engine Reinforcement Learning for LLM',
    install_requires=install_requires,
    extras_require=extras_require,
    package_data={
        '': ['version/*'],
        'verl': ['trainer/config/*.yaml'],
    },
    include_package_data=True,
    long_description=long_description,
    long_description_content_type='text/markdown',
)
--------------------------------------------------------------------------------
/verl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
import os

# Directory of this package; `version/version` holds the version string.
version_folder = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(version_folder, 'version/version'), encoding='utf-8') as f:
    __version__ = f.read().strip()

from .protocol import DataProto

from .utils.logging_utils import set_basic_config
import logging

# Configure package-wide logging at WARNING level on import.
set_basic_config(level=logging.WARNING)

from . import single_controller

__all__ = ['DataProto', "__version__"]

if os.getenv('VERL_USE_MODELSCOPE', 'False').lower() == 'true':
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded, so import it explicitly before using find_spec.
    import importlib.util
    if importlib.util.find_spec("modelscope") is None:
        raise ImportError('You are using the modelscope hub, please install modelscope by `pip install modelscope -U`')
    # Patch hub to download models from modelscope to speed up.
    from modelscope.utils.hf_util import patch_hub
    patch_hub()
40 |
--------------------------------------------------------------------------------
/verl/models/README.md:
--------------------------------------------------------------------------------
1 | # Models
2 | Common model zoos such as huggingface/transformers struggle when using PyTorch native model parallelism. Following the design principle of vLLM, we keep simple, parallelizable, and highly optimized model implementations with packed inputs in verl.
3 | ## Adding a New Huggingface Model
4 | ### Step 1: Copy the model file from HF to verl
5 | - Add a new file under verl/models/hf
6 | - Copy ONLY the model file from huggingface/transformers/models to verl/models/hf
7 |
8 | ### Step 2: Modify the model file to use packed inputs
9 | - Remove all the code related to inference (kv cache)
10 | - Modify the inputs to include only
11 | - input_ids (total_nnz,)
12 | - cu_seqlens (batch_size + 1,)
13 | - max_seqlen_in_batch: int
14 | - Note that this requires using flash attention with a causal mask.
15 |
16 | ### Step 2.5: Add tests
17 | - Add a test to compare this version and the huggingface version
18 | - Follow the existing test infrastructure and add tests to tests/models/hf
19 |
20 | ### Step 3: Add a function to apply tensor parallelism
21 | - Please follow
22 | - https://pytorch.org/docs/stable/distributed.tensor.parallel.html
23 | - https://pytorch.org/tutorials/intermediate/TP_tutorial.html
24 | - General comments
25 | - Tensor Parallelism in native PyTorch is NOT auto-parallelism. It works by specifying, via configs, how model parameters are sharded and how inputs/outputs are resharded. These configs are then registered as hooks that perform the input/output resharding before/after the model forward.
26 |
27 | ### Step 4: Add a function to apply data parallelism
28 | - Please use FSDP2 APIs
29 | - See demo here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413
30 |
31 | ### Step 5: Add a function to apply pipeline parallelism
32 | - Available starting with PyTorch 2.4
33 | - Currently only available as an alpha feature in nightly builds
34 | - Check torchtitan for more details
35 |
36 |
--------------------------------------------------------------------------------
/verl/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/models/llama/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/models/llama/megatron/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .modeling_llama_megatron import (
16 | # original model with megatron
17 | ParallelLlamaModel,
18 | ParallelLlamaForCausalLM,
19 | # rmpad with megatron
20 | ParallelLlamaForCausalLMRmPad,
21 | ParallelLlamaForValueRmPad,
22 | # rmpad with megatron and pipeline parallelism
23 | ParallelLlamaForCausalLMRmPadPP,
24 | ParallelLlamaForValueRmPadPP)
25 |
--------------------------------------------------------------------------------
/verl/models/llama/megatron/checkpoint_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/models/llama/megatron/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .parallel_attention import ParallelLlamaAttention
16 | from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad
17 | from .parallel_mlp import ParallelLlamaMLP
18 | from .parallel_rmsnorm import ParallelLlamaRMSNorm
19 |
--------------------------------------------------------------------------------
/verl/models/llama/megatron/layers/parallel_linear.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py
15 |
16 | from typing import Optional, Tuple
17 |
18 | from megatron.core import tensor_parallel
19 |
20 |
class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
    """Column-parallel linear layer computing the fused Q/K/V projection.

    The fused output width is ``(num_heads + 2 * num_key_value_heads) * head_dim``.
    The query width and the per-projection key/value width are also kept on the
    instance (``q_output_size`` / ``kv_output_size``) for callers that need to
    split the fused output.
    """

    def __init__(self,
                 input_size,
                 num_heads,
                 num_key_value_heads,
                 head_dim,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        q_width = num_heads * head_dim
        kv_width = num_key_value_heads * head_dim

        # Remember the sizes/flags on the instance before delegating to the base constructor.
        self.input_size = input_size
        self.q_output_size = q_width
        self.kv_output_size = kv_width
        self.head_dim = head_dim
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        # One query block plus two key/value-sized blocks.
        fused_width = q_width + 2 * kv_width

        super().__init__(
            input_size=input_size,
            output_size=fused_width,
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            **kwargs,
        )
50 |
51 |
class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
    """Column-parallel linear layer whose output merges two projections.

    The full output width is ``gate_ouput_size + up_output_size``. Presumably the
    two parts are laid out as [gate | up] along the output dimension (the MLP
    that uses this layer splits it that way) — confirm against the weight loader.

    Note: the keyword ``gate_ouput_size`` is misspelled, but callers pass it by
    name, so it is kept for compatibility.
    """

    def __init__(self,
                 input_size,
                 gate_ouput_size,
                 up_output_size,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        merged_width = gate_ouput_size + up_output_size

        # Remember the sizes/flags on the instance before delegating to the base constructor.
        self.input_size = input_size
        self.output_size = merged_width
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        super().__init__(
            input_size=input_size,
            output_size=merged_width,
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            **kwargs,
        )
75 |
76 |
77 | import torch
78 |
79 |
class LinearForLastLayer(torch.nn.Linear):
    """Plain (replicated) linear layer used as the model's final projection.

    ``forward`` returns a ``(logits, None)`` pair with the logits cast to float32.
    When ``config.sequence_parallel`` is set, the weight is tagged with a
    ``sequence_parallel`` attribute and the logits are gathered across the
    sequence-parallel region before being returned.
    """

    def __init__(
        self,
        input_size,
        output_size,
        *,
        config,
        bias=True,
    ):
        super().__init__(in_features=input_size, out_features=output_size, bias=bias)
        self.sequence_parallel = config.sequence_parallel
        if self.sequence_parallel:
            # Tag the weight; presumably consumed by Megatron's gradient handling
            # for sequence-parallel parameters — confirm against the optimizer setup.
            self.weight.sequence_parallel = True

    def forward(
        self,
        input_,
        weight=None,
        runtime_gather_output=None,
    ):
        # `weight` and `runtime_gather_output` are accepted but unused; presumably
        # kept so the call signature matches Megatron's ColumnParallelLinear.forward.
        projected = super().forward(input_).float()
        if not self.sequence_parallel:
            return projected, None
        gathered = tensor_parallel.gather_from_sequence_parallel_region(
            projected, tensor_parallel_output_grad=False)
        return gathered, None
106 |
--------------------------------------------------------------------------------
/verl/models/llama/megatron/layers/parallel_mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 |
21 | from megatron.core import parallel_state as mpu
22 | from megatron.core import tensor_parallel
23 | from megatron.core import ModelParallelConfig
24 | from torch import nn
25 | from transformers.activations import ACT2FN
26 | from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear
27 |
28 | from verl.utils.megatron import tensor_parallel as tp_utils
29 |
30 |
class ParallelLlamaMLP(nn.Module):
    """Tensor-parallel gated MLP computing ``down_proj(act_fn(gate) * up)``.

    ``gate_up_proj`` is a column-parallel layer (output not gathered) producing
    the gate and up projections fused along the last dimension. ``down_proj`` is
    a row-parallel layer consuming that still-sharded activation.
    """

    def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        col_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()

        if megatron_config is not None:
            assert col_kwargs.get('config', False), 'must have ModelParallelConfig'
            assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
            tp_utils.update_kwargs_with_config(col_kwargs, megatron_config)

        tp_world = mpu.get_tensor_model_parallel_world_size()

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            gate_ouput_size=self.intermediate_size,
            up_output_size=self.intermediate_size,
            bias=False,
            gather_output=False,
            skip_bias_add=False,
            **col_kwargs,
        )
        # Local (per-rank) width of each of the two fused halves; forward splits on it.
        self.gate_size = self.intermediate_size // tp_world

        self.down_proj = tensor_parallel.RowParallelLinear(
            input_size=self.intermediate_size,
            output_size=self.hidden_size,
            bias=False,
            input_is_parallel=True,
            skip_bias_add=False,
            **row_kwargs,
        )

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # The parallel linear layers return tuples; element 0 is the output tensor.
        fused = self.gate_up_proj(x)[0]
        gate, up = fused.split(self.gate_size, dim=-1)
        activated = self.act_fn(gate) * up
        return self.down_proj(activated)[0]
75 |
--------------------------------------------------------------------------------
/verl/models/llama/megatron/layers/parallel_rmsnorm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numbers
16 | import torch
17 | from megatron.core import ModelParallelConfig
18 | from torch import nn
19 | from transformers import LlamaConfig
20 |
21 | from apex.normalization.fused_layer_norm import fused_rms_norm_affine
22 | from verl.utils.megatron import sequence_parallel as sp_utils
23 |
24 |
class ParallelLlamaRMSNorm(nn.Module):

    def __init__(self, config: "LlamaConfig", megatron_config: "ModelParallelConfig"):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm

        Args:
            config: model config; reads ``hidden_size`` (an int, or a sequence of
                ints giving the full normalized shape) and ``rms_norm_eps``.
            megatron_config: reads ``sequence_parallel``; when truthy the weight is
                marked as a sequence-parallel parameter.
        """
        super().__init__()
        hidden_size = config.hidden_size
        if isinstance(hidden_size, numbers.Integral):
            normalized_shape = (hidden_size,)
        else:
            # A non-integral hidden_size previously left `normalized_shape`
            # unbound (UnboundLocalError); treat it as an explicit shape instead.
            normalized_shape = tuple(hidden_size)
        self.normalized_shape = torch.Size(normalized_shape)
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.variance_epsilon = config.rms_norm_eps

        if megatron_config.sequence_parallel:
            sp_utils.mark_parameter_as_sequence_parallel(self.weight)

    def forward(self, hidden_states):
        # Apex fused RMSNorm with learned scale; memory_efficient trades recompute for memory.
        return fused_rms_norm_affine(input=hidden_states,
                                     weight=self.weight,
                                     normalized_shape=self.normalized_shape,
                                     eps=self.variance_epsilon,
                                     memory_efficient=True)
--------------------------------------------------------------------------------
/verl/models/mcore/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates
2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from .gpt_model import gptmodel_forward
--------------------------------------------------------------------------------
/verl/models/qwen2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/models/qwen2/megatron/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .modeling_qwen2_megatron import (
16 | # original model with megatron
17 | ParallelQwen2Model,
18 | ParallelQwen2ForCausalLM,
19 | # rmpad with megatron
20 | ParallelQwen2ForCausalLMRmPad,
21 | ParallelQwen2ForValueRmPad,
22 | # rmpad with megatron and pipeline parallelism
23 | ParallelQwen2ForCausalLMRmPadPP,
24 | ParallelQwen2ForValueRmPadPP)
25 |
--------------------------------------------------------------------------------
/verl/models/qwen2/megatron/checkpoint_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/models/qwen2/megatron/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .parallel_attention import ParallelQwen2Attention
16 | from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad
17 | from .parallel_mlp import ParallelQwen2MLP
18 | from .parallel_rmsnorm import ParallelQwen2RMSNorm
19 |
--------------------------------------------------------------------------------
/verl/models/qwen2/megatron/layers/parallel_linear.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py
15 |
16 | from typing import Optional, Tuple
17 |
18 | from megatron.core import tensor_parallel
19 |
20 |
class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
    """Column-parallel linear layer computing the fused Q/K/V projection.

    The fused output width is ``(num_heads + 2 * num_key_value_heads) * head_dim``.
    The query width and the per-projection key/value width are also kept on the
    instance (``q_output_size`` / ``kv_output_size``) for callers that need to
    split the fused output.
    """

    def __init__(self,
                 input_size,
                 num_heads,
                 num_key_value_heads,
                 head_dim,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        q_width = num_heads * head_dim
        kv_width = num_key_value_heads * head_dim

        # Remember the sizes/flags on the instance before delegating to the base constructor.
        self.input_size = input_size
        self.q_output_size = q_width
        self.kv_output_size = kv_width
        self.head_dim = head_dim
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        # One query block plus two key/value-sized blocks.
        fused_width = q_width + 2 * kv_width

        super().__init__(
            input_size=input_size,
            output_size=fused_width,
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            **kwargs,
        )
50 |
51 |
class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
    """Column-parallel linear layer whose output merges two projections.

    The full output width is ``gate_ouput_size + up_output_size``. Presumably the
    two parts are laid out as [gate | up] along the output dimension (the MLP
    that uses this layer splits it that way) — confirm against the weight loader.

    Note: the keyword ``gate_ouput_size`` is misspelled, but callers pass it by
    name, so it is kept for compatibility.
    """

    def __init__(self,
                 input_size,
                 gate_ouput_size,
                 up_output_size,
                 *,
                 bias=True,
                 gather_output=True,
                 skip_bias_add=False,
                 **kwargs):
        merged_width = gate_ouput_size + up_output_size

        # Remember the sizes/flags on the instance before delegating to the base constructor.
        self.input_size = input_size
        self.output_size = merged_width
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        super().__init__(
            input_size=input_size,
            output_size=merged_width,
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            **kwargs,
        )
75 |
--------------------------------------------------------------------------------
/verl/models/qwen2/megatron/layers/parallel_mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 |
21 | from megatron.core import parallel_state as mpu
22 | from megatron.core import tensor_parallel
23 | from megatron.core import ModelParallelConfig
24 | from torch import nn
25 | from transformers.activations import ACT2FN
26 | from verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear
27 |
28 | from verl.utils.megatron import tensor_parallel as tp_utils
29 |
30 |
class ParallelQwen2MLP(nn.Module):
    """Tensor-parallel gated MLP computing ``down_proj(act_fn(gate) * up)``.

    ``gate_up_proj`` is a column-parallel layer (output not gathered) producing
    the gate and up projections fused along the last dimension. ``down_proj`` is
    a row-parallel layer consuming that still-sharded activation.
    """

    def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        col_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()

        if megatron_config is not None:
            assert col_kwargs.get('config', False), 'must have ModelParallelConfig'
            assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
            tp_utils.update_kwargs_with_config(col_kwargs, megatron_config)

        tp_world = mpu.get_tensor_model_parallel_world_size()

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            gate_ouput_size=self.intermediate_size,
            up_output_size=self.intermediate_size,
            bias=False,
            gather_output=False,
            skip_bias_add=False,
            **col_kwargs,
        )
        # Local (per-rank) width of each of the two fused halves; forward splits on it.
        self.gate_size = self.intermediate_size // tp_world

        self.down_proj = tensor_parallel.RowParallelLinear(
            input_size=self.intermediate_size,
            output_size=self.hidden_size,
            bias=False,
            input_is_parallel=True,
            skip_bias_add=False,
            **row_kwargs,
        )

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # The parallel linear layers return tuples; element 0 is the output tensor.
        fused = self.gate_up_proj(x)[0]
        gate, up = fused.split(self.gate_size, dim=-1)
        activated = self.act_fn(gate) * up
        return self.down_proj(activated)[0]
75 |
--------------------------------------------------------------------------------
/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numbers
16 | import torch
17 | from megatron.core import ModelParallelConfig
18 | from torch import nn
19 | from transformers import Qwen2Config
20 |
21 | from apex.normalization.fused_layer_norm import fused_rms_norm_affine
22 | from verl.utils.megatron import sequence_parallel as sp_utils
23 |
24 |
class ParallelQwen2RMSNorm(nn.Module):

    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm.

        Normalizes over the trailing ``config.hidden_size`` dimension(s) using apex's
        ``fused_rms_norm_affine`` with a learnable per-element scale ``weight`` (initialized
        to ones). There is no bias parameter.

        Args:
            config: Qwen2 config. Reads ``hidden_size`` (an int, or a sequence of ints for a
                multi-dimensional normalized shape) and ``rms_norm_eps``.
            megatron_config: Megatron parallel config. When ``sequence_parallel`` is set,
                ``weight`` is marked as a sequence-parallel parameter.
        """
        super().__init__()
        if isinstance(config.hidden_size, numbers.Integral):
            normalized_shape = (config.hidden_size,)
        else:
            # Without this branch a non-int hidden_size left `normalized_shape` unbound
            # (UnboundLocalError). Accept a sequence of ints, like torch/apex norm layers do.
            normalized_shape = tuple(config.hidden_size)
        self.normalized_shape = torch.Size(normalized_shape)
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.variance_epsilon = config.rms_norm_eps

        if megatron_config.sequence_parallel:
            # Presumably so the weight's gradient is reduced across the TP group under
            # sequence parallelism — see sp_utils.mark_parameter_as_sequence_parallel.
            sp_utils.mark_parameter_as_sequence_parallel(self.weight)

    def forward(self, hidden_states):
        # `memory_efficient=True` is passed straight to apex; it presumably trades some
        # recomputation for lower activation memory in backward — confirm against apex docs.
        return fused_rms_norm_affine(input=hidden_states,
                                     weight=self.weight,
                                     normalized_shape=self.normalized_shape,
                                     eps=self.variance_epsilon,
                                     memory_efficient=True)
--------------------------------------------------------------------------------
/verl/models/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import importlib
16 | from typing import List, Optional, Type
17 |
18 | import torch.nn as nn
19 |
# Megatron model classes registered per HuggingFace architecture name.
# Layout: architecture -> (module name under verl.models,
#                          (actor/ref class, critic/rm value class, non-PP rmpad class)).
# ModelRegistry.load_model_cls picks index 0 when value=False and index 1 when value=True.
# NOTE(review): there is no verl/models/mistral package in this tree, so loading the Mistral
# entry presumably fails at import time — confirm before relying on it.
_MODELS = {
    "LlamaForCausalLM": (
        "llama",
        ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad"),
    ),
    "Qwen2ForCausalLM": (
        "qwen2",
        ("ParallelQwen2ForCausalLMRmPadPP", "ParallelQwen2ForValueRmPadPP", "ParallelQwen2ForCausalLMRmPad"),
    ),
    "MistralForCausalLM": (
        "mistral",
        ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", "ParallelMistralForCausalLMRmPad"),
    ),
}
30 |
31 |
# Resolves HuggingFace architecture names to verl's Megatron model classes.
class ModelRegistry:

    @staticmethod
    def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]:
        """Return the Megatron model class registered for ``model_arch``.

        Args:
            model_arch: HuggingFace architecture name, e.g. ``"Qwen2ForCausalLM"``.
            value: False selects the actor/ref class; True selects the critic/rm (value) class.

        Returns:
            The class, or None if ``model_arch`` is not registered or the modeling module
            has no attribute with the registered class name.

        Raises:
            ModuleNotFoundError: if the registered modeling module cannot be imported.
        """
        entry = _MODELS.get(model_arch)
        if entry is None:
            return None

        module_name, candidate_names = entry
        cls_name = candidate_names[1] if value else candidate_names[0]  # critic/rm vs actor/ref

        module_path = f"verl.models.{module_name}.megatron.modeling_{module_name}_megatron"
        modeling_module = importlib.import_module(module_path)
        return getattr(modeling_module, cls_name, None)

    @staticmethod
    def get_supported_archs() -> List[str]:
        """Return every registered architecture name, in registration order."""
        return list(_MODELS)
54 |
--------------------------------------------------------------------------------
/verl/models/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/models/weight_loader_registry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
def get_weight_loader(arch: str):
    """Return the function that loads a HuggingFace state dict into a Megatron model.

    Args:
        arch: HuggingFace architecture name, e.g. ``'Qwen2ForCausalLM'``.

    Returns:
        ``verl.models.mcore.loader.load_state_dict_to_megatron_gptmodel``. Every supported
        architecture currently shares this mcore GPTModel loader.

    Raises:
        ValueError: if ``arch`` is not supported.
    """
    supported_archs = ('LlamaForCausalLM', 'Qwen2ForCausalLM')
    # Validate before importing, so an unsupported arch fails fast with a clear error
    # instead of first paying for (or failing on) the Megatron imports.
    if arch not in supported_archs:
        raise ValueError(f"Model architectures {arch} loader are not supported for now. "
                         f"Supported architectures: {list(supported_archs)}")
    from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel
    return load_state_dict_to_megatron_gptmodel
29 |
30 |
def get_weight_saver(arch: str):
    """Return the function that merges a Megatron checkpoint back into a HuggingFace state dict.

    Args:
        arch: HuggingFace architecture name, e.g. ``'LlamaForCausalLM'``.

    Returns:
        ``verl.models.mcore.saver.merge_megatron_ckpt_gptmodel``. Every supported
        architecture currently shares this mcore GPTModel saver.

    Raises:
        ValueError: if ``arch`` is not supported.
    """
    supported_archs = ('LlamaForCausalLM', 'Qwen2ForCausalLM')
    # Validate before importing, so an unsupported arch fails fast with a clear error.
    if arch not in supported_archs:
        raise ValueError(f"Model architectures {arch} saver are not supported for now. "
                         f"Supported architectures: {list(supported_archs)}")
    from verl.models.mcore.saver import merge_megatron_ckpt_gptmodel
    return merge_megatron_ckpt_gptmodel
43 |
--------------------------------------------------------------------------------
/verl/single_controller/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
# Directory containing this package (verl/single_controller).
version_folder = os.path.dirname(os.path.abspath(__file__))

# Note(haibin.lin): single_controller.__version__ is deprecated
# The version string is read from <package parent>/version/version, i.e. verl/version/version.
with open(os.path.join(version_folder, os.pardir, 'version', 'version'), encoding='utf-8') as f:
    __version__ = f.read().strip()

from . import base
from .base import *

# Re-export exactly the public API of the `base` subpackage.
__all__ = base.__all__
--------------------------------------------------------------------------------
/verl/single_controller/base/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .worker import Worker
16 | from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool
17 |
# Public API of verl.single_controller.base (re-exported as-is by verl.single_controller).
__all__ = ["Worker", "WorkerGroup", "ClassWithInitArgs", "ResourcePool"]
--------------------------------------------------------------------------------
/verl/single_controller/base/megatron/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/single_controller/base/megatron/worker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo
16 |
17 |
class MegatronWorker(Worker):
    """Worker that can report its Megatron-core parallel layout (group sizes and ranks)."""

    def __init__(self, cuda_visible_devices=None) -> None:
        super().__init__(cuda_visible_devices)

    def get_megatron_global_info(self):
        """Return a DistGlobalInfo holding this process's TP/DP/PP/CP world sizes."""
        # Imported lazily, presumably so this module can be imported without Megatron.
        from megatron.core import parallel_state as mpu
        return DistGlobalInfo(
            tp_size=mpu.get_tensor_model_parallel_world_size(),
            dp_size=mpu.get_data_parallel_world_size(),
            pp_size=mpu.get_pipeline_model_parallel_world_size(),
            cp_size=mpu.get_context_parallel_world_size(),
        )

    def get_megatron_rank_info(self):
        """Return a DistRankInfo holding this process's TP/DP/PP/CP ranks."""
        from megatron.core import parallel_state as mpu
        return DistRankInfo(
            tp_rank=mpu.get_tensor_model_parallel_rank(),
            dp_rank=mpu.get_data_parallel_rank(),
            pp_rank=mpu.get_pipeline_model_parallel_rank(),
            cp_rank=mpu.get_context_parallel_rank(),
        )
--------------------------------------------------------------------------------
/verl/single_controller/base/megatron/worker_group.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Dict
16 |
17 | from .worker import DistRankInfo, DistGlobalInfo
18 | from verl.single_controller.base import ResourcePool, WorkerGroup
19 |
20 |
class MegatronWorkerGroup(WorkerGroup):
    """WorkerGroup that caches the Megatron parallel layout reported by its workers.

    ``_megatron_rank_info`` (per-worker rank info, indexed by worker rank) and
    ``_megatron_global_info`` (TP/DP/PP/CP sizes) start as ``None``. They must be filled in
    after construction, e.g. by querying the workers. The accessors below assert that this
    has happened.
    """

    def __init__(self, resource_pool: ResourcePool, **kwargs):
        super().__init__(resource_pool=resource_pool, **kwargs)
        self._megatron_rank_info = None
        self._megatron_global_info: DistGlobalInfo = None

    def init_megatron(self, default_megatron_kwargs: Dict = None):
        """Initialize Megatron on the workers; concrete subclasses must override this."""
        raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten")

    def get_megatron_rank_info(self, rank: int) -> DistRankInfo:
        """Return the cached Megatron rank info of worker ``rank`` (``0 <= rank < world_size``)."""
        assert 0 <= rank < self.world_size, f'rank must be from [0, world_size), Got {rank}'
        # Fail with a clear message instead of a TypeError from indexing None.
        assert self._megatron_rank_info is not None, "MegatronWorkerGroup._megatron_rank_info must be initialized"
        return self._megatron_rank_info[rank]

    def _require_global_info(self) -> DistGlobalInfo:
        """Return the cached global info, asserting that it has been initialized."""
        assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
        return self._megatron_global_info

    @property
    def tp_size(self):
        """Tensor-parallel world size."""
        return self._require_global_info().tp_size

    @property
    def dp_size(self):
        """Data-parallel world size."""
        return self._require_global_info().dp_size

    @property
    def pp_size(self):
        """Pipeline-parallel world size."""
        return self._require_global_info().pp_size

    @property
    def cp_size(self):
        """Context-parallel world size."""
        return self._require_global_info().cp_size

    def get_megatron_global_info(self):
        """Return the cached global info; ``None`` if it has not been initialized yet."""
        return self._megatron_global_info
57 |
--------------------------------------------------------------------------------
/verl/single_controller/base/register_center/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/single_controller/base/register_center/ray.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import ray
16 |
17 |
@ray.remote
class WorkerGroupRegisterCenter:
    """Ray actor that holds a worker group's rank-zero info and returns it on request."""

    def __init__(self, rank_zero_info):
        self.rank_zero_info = rank_zero_info

    def get_rank_zero_info(self):
        """Return the rank-zero info object this actor was created with."""
        info = self.rank_zero_info
        return info
26 |
27 |
def create_worker_group_register_center(name, info):
    """Start a WorkerGroupRegisterCenter actor named ``name`` holding ``info``; return its handle."""
    named_actor_cls = WorkerGroupRegisterCenter.options(name=name)
    return named_actor_cls.remote(info)
30 |
--------------------------------------------------------------------------------
/verl/single_controller/ray/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
--------------------------------------------------------------------------------
/verl/single_controller/ray/megatron.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Dict, Optional
16 |
17 | import ray
18 |
19 | from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs
20 | from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo
21 | from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
22 |
23 |
# NOTE(sgm): for open-source megatron-core
class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    """Ray worker group for open-source megatron-core workers.

    Once the workers exist, every worker is asked for its Megatron rank info and rank 0 for
    the global parallel sizes. Both results are cached on the group so the dispatcher can use
    them to dispatch data.
    """

    def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs):
        super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs)

        per_worker_ranks = self.execute_all_sync(method_name='get_megatron_rank_info')
        self._megatron_rank_info: DistRankInfo = per_worker_ranks

        global_info_ref = self.execute_rank_zero_async(method_name='get_megatron_global_info')
        self._megatron_global_info: DistGlobalInfo = ray.get(global_info_ref)
36 |
37 |
class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    """Ray worker group that sets up Megatron on its workers during construction.

    After the base initialization, ``init_megatron`` runs on every worker unless the group was
    attached to detached workers. The group then caches each worker's Megatron rank info and
    rank 0's global parallel sizes so the dispatcher can use them to dispatch data.
    """

    def __init__(self,
                 resource_pool: RayResourcePool,
                 ray_cls_with_init: RayClassWithInitArgs,
                 default_megatron_kwargs: Dict = None,
                 **kwargs):
        super().__init__(resource_pool=resource_pool,
                         ray_cls_with_init=ray_cls_with_init,
                         default_megatron_kwargs=default_megatron_kwargs,
                         **kwargs)
        self.init_megatron(default_megatron_kwargs=default_megatron_kwargs)

        per_worker_ranks = self.execute_all_sync(method_name='get_megatron_rank_info')
        self._megatron_rank_info: DistRankInfo = per_worker_ranks

        global_info_ref = self.execute_rank_zero_async(method_name='get_megatron_global_info')
        self._megatron_global_info: DistGlobalInfo = ray.get(global_info_ref)

    def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None):
        """Run ``init_megatron`` on every worker, but only for a group created from scratch."""
        if self._is_init_with_detached_workers:
            return
        self.execute_all_sync(method_name='init_megatron', default_megatron_kwargs=default_megatron_kwargs)
63 |
--------------------------------------------------------------------------------
/verl/third_party/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/third_party/sglang/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-2024 SGLang Team
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
15 | #
16 | # Licensed under the Apache License, Version 2.0 (the "License");
17 | # you may not use this file except in compliance with the License.
18 | # You may obtain a copy of the License at
19 | #
20 | # http://www.apache.org/licenses/LICENSE-2.0
21 | #
22 | # Unless required by applicable law or agreed to in writing, software
23 | # distributed under the License is distributed on an "AS IS" BASIS,
24 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 | # See the License for the specific language governing permissions and
26 | # limitations under the License.
--------------------------------------------------------------------------------
/verl/third_party/vllm/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from importlib.metadata import version, PackageNotFoundError
16 | from packaging import version as vs
17 | from verl.utils.import_utils import is_sglang_available
18 |
19 |
def get_version(pkg):
    """Return the installed distribution version string of ``pkg``, or None if it is not installed."""
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        installed = None
    return installed
25 |
26 |
package_name = 'vllm'
package_version = get_version(package_name)  # None when vllm is not installed
vllm_version = None

# Older vLLM releases are driven through verl's own per-version modules (vllm_v_X_Y_Z);
# 0.7.0+ is used directly from the upstream package.
if package_version == '0.3.1':
    vllm_version = '0.3.1'
    from .vllm_v_0_3_1.llm import LLM
    from .vllm_v_0_3_1.llm import LLMEngine
    from .vllm_v_0_3_1 import parallel_state
elif package_version == '0.4.2':
    vllm_version = '0.4.2'
    from .vllm_v_0_4_2.llm import LLM
    from .vllm_v_0_4_2.llm import LLMEngine
    from .vllm_v_0_4_2 import parallel_state
elif package_version == '0.5.4':
    vllm_version = '0.5.4'
    from .vllm_v_0_5_4.llm import LLM
    from .vllm_v_0_5_4.llm import LLMEngine
    from .vllm_v_0_5_4 import parallel_state
elif package_version == '0.6.3':
    vllm_version = '0.6.3'
    from .vllm_v_0_6_3.llm import LLM
    from .vllm_v_0_6_3.llm import LLMEngine
    from .vllm_v_0_6_3 import parallel_state
elif package_version == '0.6.3+rocm624':
    vllm_version = '0.6.3'
    from .vllm_v_0_6_3.llm import LLM
    from .vllm_v_0_6_3.llm import LLMEngine
    from .vllm_v_0_6_3 import parallel_state
elif package_version is not None and vs.parse(package_version) >= vs.parse('0.7.0'):
    # The None guard matters: vs.parse(None) raises TypeError, which would crash the import
    # instead of falling through to the SGLang check below when vllm is absent.
    # vLLM supports SPMD inference from 0.6.6.post2 on
    # (see https://github.com/vllm-project/vllm/pull/12071); verl uses it directly for 0.7.0+.
    from vllm import LLM
    from vllm.distributed import parallel_state
else:
    # Reached when vllm is missing or its version is unsupported. That is only acceptable
    # when SGLang is available as the rollout backend instead.
    if not is_sglang_available():
        raise ValueError(
            f'vllm version {package_version} not supported and SGLang also not Found. Currently supported vllm versions are 0.3.1, 0.4.2, 0.5.4, 0.6.3 and 0.7.0+'
        )
67 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_3_1/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)
19 |
20 | from vllm.lora.request import LoRARequest
21 | from vllm.utils import make_async, LRUCache
22 | from vllm.transformers_utils.tokenizers import *
23 |
24 |
class TokenizerGroup:
    """A group of tokenizers that can be used for LoRA adapters."""

    def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int]):
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = tokenizer
        # Per-LoRA tokenizer cache keyed by lora_int_id; only allocated when LoRA is enabled.
        if enable_lora:
            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
        else:
            self.lora_tokenizers = None

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Tokenize ``prompt`` with the tokenizer for ``lora_request`` (``request_id`` is unused)."""
        tokenizer = self.get_lora_tokenizer(lora_request)
        return tokenizer.encode(prompt)

    async def encode_async(self,
                           prompt: str,
                           request_id: Optional[str] = None,
                           lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Async variant of :meth:`encode`."""
        tokenizer = await self.get_lora_tokenizer_async(lora_request)
        return tokenizer.encode(prompt)

    def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request``, falling back to the base tokenizer."""
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id not in self.lora_tokenizers:
            # TODO(sgm): the lora tokenizer is also passed, but may be different
            tokenizer = self.tokenizer
            # tokenizer = (get_lora_tokenizer(
            #     lora_request, **self.tokenizer_config) or self.tokenizer)
            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
            return tokenizer
        else:
            return self.lora_tokenizers.get(lora_request.lora_int_id)

    async def get_lora_tokenizer_async(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Async counterpart of :meth:`get_lora_tokenizer`, awaited by :meth:`encode_async`.

        Without this method, every ``encode_async`` call raised AttributeError. The lookup only
        touches in-memory state, so it delegates to the synchronous version.
        """
        return self.get_lora_tokenizer(lora_request)

    # FIXME(sgm): for simplicity, we assign the special token here
    @property
    def pad_token_id(self):
        return self.tokenizer.pad_token_id

    @property
    def eos_token_id(self):
        return self.tokenizer.eos_token_id
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_3_1/weight_loaders.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
15 |
16 | from typing import Dict
17 | import torch
18 | import torch.nn as nn
19 |
20 |
21 | # NOTE(shengguangming): replace the origin weight loader function in the class
def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Parallel Linear weight loader that shares storage instead of copying.

    Installed in place of a layer's own weight-loader hook (``self`` is unused).
    ``param`` is re-pointed at ``loaded_weight``'s data, so both tensors must
    agree in shape and dtype.
    """
    expected, received = param.size(), loaded_weight.size()
    assert expected == received, (
        f'the parameter size is not align with the loaded weight size, '
        f'param size: {expected}, loaded_weight size: {received}')
    same_dtype = param.data.dtype == loaded_weight.data.dtype
    assert same_dtype, "if we want to shared weights, the data type should also be the same"
    # Rebind rather than copy_: the parameter now aliases the loaded tensor's memory.
    param.data = loaded_weight.data
30 |
31 |
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Default weight loader: make ``param`` share ``loaded_weight``'s storage.

    Shapes and dtypes must match exactly; no copy or cast is performed.
    """
    source = loaded_weight.data
    assert param.size() == loaded_weight.size()
    assert param.data.dtype == source.dtype, "if we want to shared weights, the data type should also be the same"
    param.data = source
38 |
39 |
def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load GPT-2 actor weights into ``vllm_model`` in place and return it.

    Args:
        actor_weights: mapping of parameter name -> tensor.
        vllm_model: the vLLM model whose parameters are (re)bound.

    Returns:
        ``vllm_model`` itself (previously ``None`` despite the ``nn.Module``
        annotation).

    Raises:
        KeyError: if a (prefixed) name has no matching model parameter.
    """
    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    conv1d_weight_names = ("c_attn", "c_proj", "c_fc")
    for name, loaded_weight in actor_weights.items():
        if "lm_head.weight" in name:
            # GPT-2 ties the weights of the embedding layer and the final
            # linear layer.
            continue
        if ".attn.bias" in name or ".attn.masked_bias" in name:
            # Skip attention mask.
            # NOTE: "c_attn.bias" should not be skipped.
            continue
        if not name.startswith("transformer."):
            name = "transformer." + name
        param = params_dict[name]
        # The HF's GPT-2 implementation uses Conv1D instead of Linear, so its
        # weights must be transposed. Transpose exactly once even if a name
        # happened to match several Conv1D names.
        # Note(zhuohan): the logic below might break quantized models.
        # TODO: check megatron
        if name.endswith(".weight") and any(conv in name for conv in conv1d_weight_names):
            loaded_weight = loaded_weight.t()
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
    return vllm_model
66 |
67 |
def llama_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load Llama actor weights into ``vllm_model`` in place and return it.

    Names carrying the ``'0.module.module.'`` prefix have it stripped first;
    ``rotary_emb.inv_freq`` entries are skipped.

    Returns:
        ``vllm_model`` itself (previously ``None`` despite the ``nn.Module``
        annotation).

    Raises:
        KeyError: if a name has no matching model parameter.
    """
    # NOTE(shengguangming): the megatron llama may have this prefix
    prefix = '0.module.module.'
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        if "rotary_emb.inv_freq" in name:
            continue
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
    return vllm_model
81 |
82 |
def mistral_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load Mistral actor weights into ``vllm_model`` in place and return it.

    Names carrying the ``'0.module.module.'`` prefix have it stripped first;
    ``rotary_emb.inv_freq`` entries are skipped.

    Returns:
        ``vllm_model`` itself (previously ``None`` despite the ``nn.Module``
        annotation).

    Raises:
        KeyError: if a name has no matching model parameter.
    """
    # TODO: need to implement a general way to deal with prefix
    prefix = '0.module.module.'
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        if "rotary_emb.inv_freq" in name:
            continue
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
    return vllm_model
96 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_4_2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_4_2/hf_weight_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
15 |
16 | from typing import Dict, Union, Optional, Iterable, Tuple
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from vllm.model_executor.model_loader.utils import set_default_torch_dtype
22 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23 |
24 |
def update_hf_weight_loader():
    """Monkey-patch vLLM's ``GemmaForCausalLM`` to load weights via ``gemma_load_weights``."""
    from vllm.model_executor.models.gemma import GemmaForCausalLM as _GemmaCls
    setattr(_GemmaCls, "load_weights", gemma_load_weights)
28 |
29 |
def gemma_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Replacement ``load_weights`` for vLLM's Gemma model.

    Checkpoint q/k/v and gate/up projections are routed into the fused
    ``qkv_proj`` / ``gate_up_proj`` parameters together with their shard id.
    Every other tensor is loaded directly, except that ``lm_head.weight`` is
    skipped (tied with the embedding) and bias tensors with no matching
    parameter are skipped (GPTQ). Norm weights are offset by +1.0 out of
    place, so the caller's (actor's) tensors are not modified.

    Raises:
        RuntimeError: if any model parameter was not loaded.
    """
    fused_shards = [
        # (fused vLLM param, checkpoint shard name, shard id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    own_params = dict(self.named_parameters())
    filled = set()
    for weight_name, tensor in weights:
        for fused_name, shard_name, shard_id in fused_shards:
            if shard_name not in weight_name:
                continue
            weight_name = weight_name.replace(shard_name, fused_name)
            # Skip loading extra bias for GPTQ models.
            # NOTE(review): this `continue` resumes the shard scan with the
            # already-renamed name instead of skipping the tensor outright;
            # the tensor then ends up skipped by the bias check below.
            if weight_name.endswith(".bias") and weight_name not in own_params:
                continue
            target = own_params[weight_name]
            target.weight_loader(target, tensor, shard_id)
            break
        else:
            # Not a fused shard: load (or skip) the tensor under its own name.
            if "lm_head.weight" in weight_name:
                continue
            if weight_name.endswith(".bias") and weight_name not in own_params:
                continue
            target = own_params[weight_name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            if "norm.weight" in weight_name:
                # GemmaRMSNorm multiplies by (1 + weight) rather than weight;
                # the addition is out of place to keep the actor weights intact.
                loader(target, tensor + 1.0)
            else:
                loader(target, tensor)
        filled.add(weight_name)
    unloaded = own_params.keys() - filled
    if unloaded:
        raise RuntimeError("Some weights are not initialized from checkpoints: "
                           f"{unloaded}")
77 |
78 |
def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
    """Load HF-format ``actor_weights`` into ``vllm_model`` and move it to CUDA.

    Weights are loaded through the model's own ``load_weights`` under the
    model's current parameter dtype. Afterwards each submodule gets its
    post-load processing hooks run (``quant_method`` and/or
    ``process_weights_after_loading``).
    """
    assert isinstance(actor_weights, Dict)
    model_dtype = next(vllm_model.parameters()).dtype
    with set_default_torch_dtype(model_dtype):  # TODO
        vllm_model.load_weights(actor_weights.items())
    for submodule in vllm_model.modules():
        quant_method = getattr(submodule, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(submodule)
        # FIXME: Remove this after Mixtral is updated to use quant_method.
        if hasattr(submodule, "process_weights_after_loading"):
            submodule.process_weights_after_loading()
    # nn.Module.cuda() moves parameters in place and returns the same module.
    vllm_model.cuda()
92 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_4_2/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)
19 |
20 | from vllm.lora.request import LoRARequest
21 | from vllm.utils import make_async, LRUCache
22 | from vllm.transformers_utils.tokenizers import *
23 |
24 |
class TokenizerGroup:
    """A group of tokenizers that can be used for LoRA adapters.

    verl variant: the tokenizer object is passed in directly, and every LoRA
    request is currently served by that same base tokenizer.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int]):
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = tokenizer
        # Per-LoRA-id tokenizer cache with capacity ``max_num_seqs``; only
        # allocated when LoRA is enabled.
        self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None

    def ping(self) -> bool:
        """Check if the tokenizer group is alive."""
        return True

    def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]:
        """Get the maximum input length for the LoRA request.

        The same limit is returned for every request.
        """
        return self.max_input_length

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Tokenize ``prompt`` with the tokenizer selected for ``lora_request``.

        ``request_id`` is accepted for interface compatibility and is unused.
        """
        tokenizer = self.get_lora_tokenizer(lora_request)
        return tokenizer.encode(prompt)

    async def encode_async(self,
                           prompt: str,
                           request_id: Optional[str] = None,
                           lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Async variant of :meth:`encode`."""
        tokenizer = await self.get_lora_tokenizer_async(lora_request)
        return tokenizer.encode(prompt)

    def get_lora_tokenizer(self, lora_request: Optional[LoRARequest] = None) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request``.

        With no request, or when LoRA is disabled, the base tokenizer is
        returned. Otherwise the per-LoRA cache is consulted, and on a miss the
        base tokenizer is cached under the request's ``lora_int_id``.
        """
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        lora_id = lora_request.lora_int_id
        if lora_id in self.lora_tokenizers:
            return self.lora_tokenizers.get(lora_id)
        # TODO(sgm): the lora tokenizer is also passed, but may be different;
        # for now the base tokenizer is reused for every adapter.
        tokenizer = self.tokenizer
        self.lora_tokenizers.put(lora_id, tokenizer)
        return tokenizer

    async def get_lora_tokenizer_async(self, lora_request: Optional[LoRARequest] = None) -> "PreTrainedTokenizer":
        """Async counterpart of :meth:`get_lora_tokenizer`.

        ``encode_async`` awaits this method, so it must exist on the class.
        Lookup here never blocks, so it simply delegates to the sync version.
        """
        return self.get_lora_tokenizer(lora_request)

    # FIXME(sgm): for simplicity, we assign the special token here
    @property
    def pad_token_id(self):
        return self.tokenizer.pad_token_id

    @property
    def eos_token_id(self):
        return self.tokenizer.eos_token_id
78 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_5_4/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
15 |
16 | from typing import Dict, Union, Optional, Iterable, Tuple
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from vllm.model_executor.model_loader.utils import set_default_torch_dtype
22 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23 |
24 |
def update_hf_weight_loader():
    """No-op hook for this vLLM version: only prints a notice and patches nothing."""
    notice = 'no hf weight loader need to be updated'
    print(notice)
28 |
29 |
def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
    """Load HF-format ``actor_weights`` into ``vllm_model`` and move it to CUDA.

    When the model config ties word embeddings, any ``lm_head.weight`` entry is
    skipped. It is filtered lazily instead of being deleted, so the caller's
    ``actor_weights`` dict is no longer mutated. After loading, each
    submodule's post-load hooks (``quant_method`` and/or
    ``process_weights_after_loading``) are run.
    """
    assert isinstance(actor_weights, Dict)
    with set_default_torch_dtype(next(vllm_model.parameters()).dtype):  # TODO
        skip_lm_head = bool(vllm_model.config.tie_word_embeddings)
        weights = ((name, tensor)
                   for name, tensor in actor_weights.items()
                   if not (skip_lm_head and name == "lm_head.weight"))
        vllm_model.load_weights(weights)
    for _, module in vllm_model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)
        # FIXME: Remove this after Mixtral is updated
        # to use quant_method.
        if hasattr(module, "process_weights_after_loading"):
            module.process_weights_after_loading()
    # nn.Module.cuda() moves parameters in place; the return value is the same module.
    vllm_model.cuda()
45 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)
19 |
20 | from vllm.lora.request import LoRARequest
21 | from vllm.utils import make_async, LRUCache
22 | from vllm.transformers_utils.tokenizers import *
23 |
24 |
class TokenizerGroup:
    """A group of tokenizers that can be used for LoRA adapters.

    verl variant: the tokenizer object is passed in directly, and every LoRA
    request is currently served by that same base tokenizer.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int]):
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = tokenizer
        # Per-LoRA-id tokenizer cache with capacity ``max_num_seqs``; only
        # allocated when LoRA is enabled.
        self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None

    def ping(self) -> bool:
        """Check if the tokenizer group is alive."""
        return True

    def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]:
        """Get the maximum input length for the LoRA request.

        The same limit is returned for every request.
        """
        return self.max_input_length

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Tokenize ``prompt`` with the tokenizer selected for ``lora_request``.

        ``request_id`` is accepted for interface compatibility and is unused.
        """
        tokenizer = self.get_lora_tokenizer(lora_request)
        return tokenizer.encode(prompt)

    async def encode_async(self,
                           prompt: str,
                           request_id: Optional[str] = None,
                           lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Async variant of :meth:`encode`."""
        tokenizer = await self.get_lora_tokenizer_async(lora_request)
        return tokenizer.encode(prompt)

    def get_lora_tokenizer(self, lora_request: Optional[LoRARequest] = None) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request``.

        With no request, or when LoRA is disabled, the base tokenizer is
        returned. Otherwise the per-LoRA cache is consulted, and on a miss the
        base tokenizer is cached under the request's ``lora_int_id``.
        """
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        lora_id = lora_request.lora_int_id
        if lora_id in self.lora_tokenizers:
            return self.lora_tokenizers.get(lora_id)
        # TODO(sgm): the lora tokenizer is also passed, but may be different;
        # for now the base tokenizer is reused for every adapter.
        tokenizer = self.tokenizer
        self.lora_tokenizers.put(lora_id, tokenizer)
        return tokenizer

    async def get_lora_tokenizer_async(self, lora_request: Optional[LoRARequest] = None) -> "PreTrainedTokenizer":
        """Async counterpart of :meth:`get_lora_tokenizer`.

        ``encode_async`` awaits this method, so it must exist on the class.
        Lookup here never blocks, so it simply delegates to the sync version.
        """
        return self.get_lora_tokenizer(lora_request)

    # FIXME(sgm): for simplicity, we assign the special token here
    @property
    def pad_token_id(self):
        return self.tokenizer.pad_token_id

    @property
    def eos_token_id(self):
        return self.tokenizer.eos_token_id
78 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_6_3/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
15 |
16 | import os
17 | from dataclasses import dataclass
18 |
19 | from transformers import PretrainedConfig
20 | from vllm.config import EngineConfig
21 | from vllm.engine.arg_utils import EngineArgs
22 |
23 | from .config import LoadConfig, ModelConfig
24 |
25 |
@dataclass
class EngineArgs(EngineArgs):
    """vLLM ``EngineArgs`` extended for verl.

    ``model_hf_config`` carries an in-memory HF config that
    ``create_model_config`` forwards to verl's ``ModelConfig``. The engine's
    parallel world size is taken from the ``WORLD_SIZE`` environment variable.
    """

    model_hf_config: PretrainedConfig = None  # for verl

    def __post_init__(self):
        # Deliberately a no-op: overrides the base-class __post_init__ so that
        # none of its processing runs.
        pass

    def create_model_config(self) -> ModelConfig:
        """Build verl's ``ModelConfig`` from ``model_hf_config`` plus these args."""
        return ModelConfig(
            hf_config=self.model_hf_config,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            use_async_output_proc=not self.disable_async_output_proc,
            override_neuron_config=self.override_neuron_config,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
        )

    def create_load_config(self) -> LoadConfig:
        """Build verl's ``LoadConfig`` from the load-related args."""
        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
        )

    def create_engine_config(self) -> EngineConfig:
        """Build the base engine config, then override its world size.

        Raises:
            AssertionError: if ``WORLD_SIZE`` is unset (or is ``-1``).
            ValueError: if ``WORLD_SIZE`` is not an integer.
        """
        engine_config = super().create_engine_config()

        # NOTE[VERL]: Use the world_size set by torchrun
        world_size = int(os.getenv("WORLD_SIZE", "-1"))
        if world_size == -1:
            # Explicit raise instead of `assert` so the check survives `python -O`;
            # the exception type is kept as AssertionError for compatibility.
            raise AssertionError("The world_size is set to -1, not initialized by TORCHRUN")
        engine_config.parallel_config.world_size = world_size

        return engine_config
79 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_6_3/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py
15 |
16 | import enum
17 | import json
18 | from dataclasses import dataclass, field
19 | from typing import TYPE_CHECKING, List, Optional, Union
20 |
21 | from transformers import PretrainedConfig
22 |
23 | # Add for verl
24 | from vllm.config import ModelConfig
25 | from vllm.logger import init_logger
26 | from vllm.utils import is_hip
27 |
28 | if TYPE_CHECKING:
29 | from vllm.model_executor.model_loader.loader import BaseModelLoader
30 |
31 | logger = init_logger(__name__)
32 |
33 |
class LoadFormat(str, enum.Enum):
    """Weight-loading formats accepted by ``LoadConfig.load_format``.

    The ``str`` mixin makes each member compare equal to its plain string value.
    """

    # Real weight sources.
    AUTO = "auto"
    MEGATRON = "megatron"
    HF = "hf"
    DTENSOR = "dtensor"
    # "dummy_" counterparts of the formats above; their meaning is defined by
    # the model loaders that consume them, not here.
    DUMMY_HF = "dummy_hf"
    DUMMY_MEGATRON = "dummy_megatron"
    DUMMY_DTENSOR = "dummy_dtensor"
41 | DUMMY_DTENSOR = "dummy_dtensor"
42 |
43 |
class ModelConfig(ModelConfig):
    """vLLM ``ModelConfig`` constructed around an already-loaded HF config.

    The HF config's ``_name_or_path`` is passed to the base constructor as both
    the model and the tokenizer identifier. Afterwards ``hf_config`` is
    replaced by the object supplied here.
    """

    def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None:
        name_or_path = hf_config._name_or_path
        super().__init__(*args, model=name_or_path, tokenizer=name_or_path, **kwargs)
        self.hf_config = hf_config
49 |
50 |
@dataclass
class LoadConfig:
    """
    download_dir: Directory to download and load the weights, default to the
        default cache directory of huggingface.
    load_format: The format of the model weights to load. A string is
        lower-cased and converted to ``LoadFormat``, so it must be one of
        "auto", "megatron", "hf", "dtensor", "dummy_hf", "dummy_megatron" or
        "dummy_dtensor". A non-string value (e.g. a model loader) is kept as is.
    model_loader_extra_config: Extra options for the model loader, given as a
        dict or a JSON string. It is normalized to a dict; ``None`` or an empty
        value becomes ``{}``.
    ignore_patterns: The list of patterns to ignore when loading the model.
        Default to "original/**/*" to avoid repeated loading of llama's
        checkpoints.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        extra_config = self.model_loader_extra_config or {}
        if isinstance(extra_config, str):
            extra_config = json.loads(extra_config)
        # Store the normalized value in every case. Previously a None/empty
        # value was left as-is because only the JSON branch assigned back.
        self.model_loader_extra_config = extra_config
        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns)
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Coerce a string ``load_format`` to ``LoadFormat`` and check ROCm support.

        Raises:
            ValueError: if the string is not a ``LoadFormat`` value, or the
                format is unsupported on ROCm.
        """
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            # Compare and report enum *values* (e.g. "auto"); the previous code
            # iterated member names (e.g. "AUTO"), which never match the
            # lower-cased format strings.
            rocm_supported_load_format = [
                fmt.value for fmt in LoadFormat if fmt.value not in rocm_not_supported_load_format
            ]
            raise ValueError(f"load format '{load_format}' is not supported in ROCm. "
                             f"Supported load formats are "
                             f"{rocm_supported_load_format}")
106 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
15 |
16 | from typing import Dict
17 |
18 | import torch.nn as nn
19 | from vllm.model_executor.model_loader.utils import set_default_torch_dtype
20 |
21 |
def update_hf_weight_loader():
    """No-op hook for this vLLM version: only prints a notice and patches nothing."""
    notice = "no hf weight loader need to be updated"
    print(notice)
25 |
26 |
def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
    """Load HF-format ``actor_weights`` into ``vllm_model`` and move it to CUDA.

    When the model config ties word embeddings, any ``lm_head.weight`` entry is
    skipped. It is filtered lazily instead of being deleted, so the caller's
    ``actor_weights`` dict is no longer mutated. After loading, each
    submodule's post-load hooks (``quant_method`` and/or
    ``process_weights_after_loading``) are run.
    """
    assert isinstance(actor_weights, Dict)
    with set_default_torch_dtype(next(vllm_model.parameters()).dtype):  # TODO
        skip_lm_head = bool(vllm_model.config.tie_word_embeddings)
        weights = ((name, tensor)
                   for name, tensor in actor_weights.items()
                   if not (skip_lm_head and name == "lm_head.weight"))
        vllm_model.load_weights(weights)
    for _, module in vllm_model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)
        # FIXME: Remove this after Mixtral is updated
        # to use quant_method.
        if hasattr(module, "process_weights_after_loading"):
            module.process_weights_after_loading()
    # nn.Module.cuda() moves parameters in place; the return value is the same module.
    vllm_model.cuda()
42 |
--------------------------------------------------------------------------------
/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
15 |
16 | from typing import Optional
17 |
18 | from transformers import PreTrainedTokenizer
19 | from vllm.transformers_utils.tokenizer_group import TokenizerGroup
20 | from vllm.utils import LRUCache
21 |
22 |
class TokenizerGroup(TokenizerGroup):
    """A group of tokenizers that can be used for LoRA adapters.

    verl subclass: takes an already-constructed tokenizer instead of building
    one, and does not call the base-class ``__init__``.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int]):
        # NOTE(review): super().__init__ is bypassed, so only the four
        # attributes below are set; confirm that inherited LoRA code paths do
        # not rely on other base-class attributes.
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = tokenizer
        if enable_lora:
            self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs)
        else:
            self.lora_tokenizers = None

    # FIXME(sgm): for simplicity, we assign the special token here
    @property
    def pad_token_id(self):
        """Padding token id of the wrapped tokenizer."""
        wrapped = self.tokenizer
        return wrapped.pad_token_id

    @property
    def eos_token_id(self):
        """End-of-sequence token id of the wrapped tokenizer."""
        wrapped = self.tokenizer
        return wrapped.eos_token_id
41 |
--------------------------------------------------------------------------------
/verl/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/trainer/config/evaluation.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | path: /tmp/math_Qwen2-7B-Instruct.parquet
3 | prompt_key: prompt
4 | response_key: responses
5 | data_source_key: data_source
6 | reward_model_key: reward_model
7 |
8 | custom_reward_function:
9 | path: null
10 | name: compute_score
11 |
--------------------------------------------------------------------------------
/verl/trainer/config/generation.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | nnodes: 1
3 | n_gpus_per_node: 8
4 |
5 | data:
6 | path: ~/data/rlhf/math/test.parquet
7 | prompt_key: prompt
8 | n_samples: 5
9 | output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet
10 | batch_size: 128
11 |
12 | model:
13 | path: ~/models/Qwen2-7B-Instruct
14 | external_lib: null
15 | rollout:
16 | name: vllm
17 | temperature: 1.0
18 | top_k: 50 # 0 for hf rollout, -1 for vllm rollout
19 | top_p: 0.7
20 | prompt_length: 1536
21 | response_length: 512
22 | # for vllm rollout
23 | dtype: bfloat16 # should align with FSDP
24 | gpu_memory_utilization: 0.5
25 | ignore_eos: False
26 | enforce_eager: True
27 | free_cache_engine: True
28 | load_format: dummy_dtensor
29 | tensor_model_parallel_size: 1
30 | max_num_batched_tokens: 8192
31 | max_model_len: null
32 | max_num_seqs: 1024
33 | log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
34 | log_prob_micro_batch_size_per_gpu: 8
35 | # for fire vllm rollout
36 | use_fire_sampling: False # enable FIRE https://arxiv.org/abs/2410.21236
37 | # for hf rollout
38 | do_sample: True
39 | disable_log_stats: True
40 | enable_chunked_prefill: True
41 | n: 1
42 | actor:
43 | strategy: fsdp # This is for backward-compatibility
44 | ulysses_sequence_parallel_size: 1 # sp size
45 | fsdp_config:
46 | fsdp_size: -1
--------------------------------------------------------------------------------
/verl/trainer/config/sft_trainer.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_batch_size: 256
3 | micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
4 | micro_batch_size_per_gpu: 4 # this is also val batch size
5 | train_files: ~/data/gsm8k/train.parquet
6 | val_files: ~/data/gsm8k/test.parquet
7 | # Single-turn settings
8 | prompt_key: question
9 | response_key: answer
10 | # Multi-turn settings
11 | multiturn:
12 | enable: false # Set to true to use multi-turn dataset
13 | messages_key: messages # Key for messages list in multi-turn mode
14 | max_length: 1024
15 | truncation: error
16 | balance_dp_token: False
17 | chat_template: null
18 | custom_cls:
19 | path: null
20 | name: null
21 | model:
22 | partial_pretrain: ~/models/gemma-1.1-7b-it
23 | fsdp_config:
24 | wrap_policy:
25 | min_num_params: 0
26 | cpu_offload: False
27 | offload_params: False
28 | external_lib: null
29 | enable_gradient_checkpointing: False
30 | trust_remote_code: False
31 | lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32)
32 | lora_alpha: 16 # LoRA scaling factor
33 | target_modules: all-linear # Target modules for LoRA adaptation
34 | use_liger: False
35 | optim:
36 | lr: 1e-5
37 | betas: [0.9, 0.95]
38 | weight_decay: 0.01
39 | warmup_steps_ratio: 0.1
40 | clip_grad: 1.0
41 | ulysses_sequence_parallel_size: 1
42 | use_remove_padding: False
43 | trainer:
44 | default_local_dir: /tmp/sft_model
45 | default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here
46 | resume_path: null
47 | project_name: gsm8k-sft
48 | experiment_name: test
49 | total_epochs: 4
50 | total_training_steps: null
51 | logger: ['console']
52 | seed: 1
53 |
54 |
--------------------------------------------------------------------------------
/verl/trainer/main_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Offline evaluation of a generated file using a custom reward function and the ground truth.
16 | The input is a parquet file in which each row holds the N generated responses for one prompt, its data source, and a reward-model entry carrying the ground truth.
17 |
18 | """
19 |
20 | import hydra
21 | from verl.utils.fs import copy_to_local
22 | import pandas as pd
23 | import numpy as np
24 | from tqdm import tqdm
25 | from collections import defaultdict
26 | import ray
27 |
28 |
def get_custom_reward_fn(config):
    """Load a user-defined reward function described by ``config.custom_reward_function``.

    The section may provide ``path`` (a Python source file), ``name`` (the function to
    load from that file) and ``reward_kwargs`` (extra keyword arguments forwarded on
    every call of the returned function).

    Returns:
        A callable wrapping the loaded function, or ``None`` when no ``path`` is configured.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        RuntimeError: the file could not be imported (the original error is chained).
        AttributeError: ``name`` is empty or not defined in the file.
    """
    import importlib.util
    import os
    import sys

    reward_fn_config = config.get("custom_reward_function") or {}
    file_path = reward_fn_config.get("path")
    if not file_path:
        return None

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Reward function file '{file_path}' not found.")

    spec = importlib.util.spec_from_file_location("custom_module", file_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None for paths it cannot associate with a loader.
        raise RuntimeError(f"Error loading module from '{file_path}': not an importable Python file")
    module = importlib.util.module_from_spec(spec)
    try:
        # NOTE(review): presumably registered so objects defined in the file can be resolved
        # by module name (e.g. when pickled for Ray tasks) — confirm.
        sys.modules["custom_module"] = module
        spec.loader.exec_module(module)
    except Exception as e:
        raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e

    function_name = reward_fn_config.get("name")
    # hasattr() raises TypeError for a None name, so check for emptiness first.
    if not function_name or not hasattr(module, function_name):
        raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")

    print(f"using customized reward function '{function_name}' from '{file_path}'")
    raw_fn = getattr(module, function_name)

    # `or {}` also tolerates an explicit `reward_kwargs: null` in the config.
    reward_kwargs = dict(reward_fn_config.get("reward_kwargs") or {})

    def wrapped_fn(*args, **kwargs):
        return raw_fn(*args, **kwargs, **reward_kwargs)

    return wrapped_fn
60 |
61 |
@ray.remote
def process_item(reward_fn, data_source, response_lst, reward_data):
    """Ray task: score every response of one row against its ground truth.

    Returns:
        ``(data_source, mean score over response_lst)``.
    """
    truth = reward_data['ground_truth']
    per_response_scores = [reward_fn(data_source, response, truth) for response in response_lst]
    mean_score = np.mean(per_response_scores)
    return data_source, mean_score
67 |
68 |
@hydra.main(config_path='config', config_name='evaluation', version_base=None)
def main(config):
    """Score a generated parquet file with the custom reward function and print per-source means.

    Each row must provide a list of responses (``config.data.response_key``), a data-source
    label (``config.data.data_source_key``) and a reward-model dict carrying ``ground_truth``
    (``config.data.reward_model_key``). Rows are scored in parallel as Ray tasks. The metric
    printed for each data source is the mean, over its rows, of each row's mean score.

    Raises:
        ValueError: no custom reward function is configured.
    """
    local_path = copy_to_local(config.data.path)
    dataset = pd.read_parquet(local_path)
    # Take the column values positionally instead of label-indexing the Series with
    # range(total), which breaks if the parquet file stored a non-default index.
    responses = dataset[config.data.response_key].tolist()
    data_sources = dataset[config.data.data_source_key].tolist()
    reward_model_data = dataset[config.data.reward_model_key].tolist()

    total = len(dataset)

    compute_score = get_custom_reward_fn(config)
    if compute_score is None:
        # Otherwise every Ray task would fail with "'NoneType' object is not callable".
        raise ValueError("No reward function configured: set custom_reward_function.path "
                         "(and custom_reward_function.name) in the evaluation config.")

    # Initialize Ray
    if not ray.is_initialized():
        ray.init()

    # evaluate test_score based on data source
    data_source_reward = defaultdict(list)

    # Create one remote task per row
    remote_tasks = [
        process_item.remote(compute_score, data_source, response_lst, reward_data)
        for data_source, response_lst, reward_data in zip(data_sources, responses, reward_model_data)
    ]

    # Process results as they come in
    with tqdm(total=total) as pbar:
        while remote_tasks:
            # ray.wait returns (ready, not_ready); by default it returns one ready ref per call.
            done_ids, remote_tasks = ray.wait(remote_tasks)
            for result_id in done_ids:
                data_source, score = ray.get(result_id)
                data_source_reward[data_source].append(score)
                pbar.update(1)

    metric_dict = {}
    for data_source, rewards in data_source_reward.items():
        metric_dict[f'test_score/{data_source}'] = np.mean(rewards)

    print(metric_dict)


if __name__ == '__main__':
    main()
112 |
--------------------------------------------------------------------------------
/verl/trainer/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/trainer/runtime_env.yaml:
--------------------------------------------------------------------------------
1 | working_dir: ./
2 | excludes: ["/.git/"]
3 | env_vars:
4 | TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
5 | VLLM_ATTENTION_BACKEND: "XFORMERS"
--------------------------------------------------------------------------------
/verl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import tokenizer
16 | from .tokenizer import hf_tokenizer, hf_processor
17 |
18 | __all__ = tokenizer.__all__
--------------------------------------------------------------------------------
/verl/utils/checkpoint/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
--------------------------------------------------------------------------------
/verl/utils/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Dict
16 |
17 | from omegaconf import DictConfig
18 |
19 |
def update_dict_with_config(dictionary: Dict, config: DictConfig):
    """Overwrite, in place, every key of ``dictionary`` that ``config`` also has as an attribute.

    Keys missing from ``config`` keep their value; no new keys are added.
    """
    for name in dictionary.keys():
        if not hasattr(config, name):
            continue
        dictionary[name] = getattr(config, name)
24 |
--------------------------------------------------------------------------------
/verl/utils/dataset/README.md:
--------------------------------------------------------------------------------
1 | # Dataset Format
2 | ## RLHF dataset
3 | We combine all the data sources into a single parquet file. We directly organize the prompt into the chat format so that multi-turn chats can be easily incorporated. In the prompt, we may add instruction-following text to guide the model to output the answers in a particular format so that we can extract the answers.
4 |
5 | Math problems
6 | ```json
7 | {
8 | "data_source": "openai/gsm8k",
9 | "prompt": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \"####\""}],
10 | "ability": "math",
11 | "reward_model": {
12 | "style": "rule",
13 | "ground_truth": ["72"]
14 | }
15 | }
16 | ```
17 |
--------------------------------------------------------------------------------
/verl/utils/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .rl_dataset import RLHFDataset
16 | from .rm_dataset import RMDataset
17 | from .sft_dataset import SFTDataset
18 |
--------------------------------------------------------------------------------
/verl/utils/debug/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .performance import log_gpu_memory_usage
--------------------------------------------------------------------------------
/verl/utils/debug/performance.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.distributed as dist
17 | import logging
18 |
19 |
def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0):
    """Report this process's CUDA memory allocated/reserved, in GiB, prefixed by ``head``.

    Nothing is reported unless torch.distributed is uninitialized, ``rank`` is None, or
    this process's global rank equals ``rank``. The line is printed to stdout when
    ``logger`` is None, otherwise it is logged at ``level``.
    """
    should_report = (not dist.is_initialized()) or rank is None or dist.get_rank() == rank
    if not should_report:
        return

    gib = 1024**3
    allocated_gb = torch.cuda.memory_allocated() / gib
    reserved_gb = torch.cuda.memory_reserved() / gib
    message = f'{head}, memory allocated (GB): {allocated_gb}, memory reserved (GB): {reserved_gb}'

    if logger is not None:
        logger.log(msg=message, level=level)
    else:
        print(message)
31 |
--------------------------------------------------------------------------------
/verl/utils/debug/trajectory_tracker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Trajectory tracker can be inserted into code to save the intermediate results.
16 | The results will be dump to hdfs for offline comparison.
17 | Each process will have a client that first move all the tensors to CPU
18 | """
19 |
20 | from verl.utils.hdfs_io import makedirs, copy
21 | import torch
22 | import os
23 | import ray
24 | import io
25 | import tempfile
26 |
27 | from collections import deque
28 |
# Ray-task wrapper around `copy`; not referenced elsewhere in this module.
remote_copy = ray.remote(copy)


@ray.remote
def save_to_hdfs(data: io.BytesIO, name, hdfs_dir, verbose):
    """Ray task: stage ``data`` as ``<name>.pth`` in a temporary directory, then copy it to ``hdfs_dir``.

    The copy is best effort: a failure is printed, not raised.
    """
    with tempfile.TemporaryDirectory() as staging_dir:
        staged_path = os.path.join(staging_dir, name + '.pth')
        with open(staged_path, 'wb') as fh:
            fh.write(data.getbuffer())

        # upload to hdfs while the staged file still exists
        if verbose:
            print(f'Saving {staged_path} to {hdfs_dir}')
        try:
            copy(staged_path, hdfs_dir)
        except Exception as err:
            print(err)
47 |
48 |
@ray.remote
class TrajectoryTracker:
    """Ray actor that launches HDFS uploads and lets callers wait for all of them."""

    def __init__(self, hdfs_dir, verbose) -> None:
        self.hdfs_dir = hdfs_dir
        makedirs(hdfs_dir)
        self.verbose = verbose
        # Object refs of launched save_to_hdfs tasks, oldest first.
        self.handle = deque()

    def dump(self, data: io.BytesIO, name):
        """Launch an upload of ``data`` as ``<name>.pth`` without waiting for it to finish."""
        pending = save_to_hdfs.remote(data, name, self.hdfs_dir, self.verbose)
        self.handle.append(pending)

    def wait_for_hdfs(self):
        """Block until every upload launched so far has completed, in launch order."""
        while self.handle:
            ray.get(self.handle.popleft())
67 |
68 |
def dump_data(data, name):
    """Serialize ``data`` with ``torch.save`` and hand it to the global tracker actor under ``name``.

    Does nothing unless the environment variable VERL_ENABLE_TRACKER equals '1'.
    """
    if os.getenv('VERL_ENABLE_TRACKER', '0') != '1':
        return
    payload = io.BytesIO()
    torch.save(data, payload)
    ray.get(get_trajectory_tracker().dump.remote(payload, name))
77 |
78 |
def get_trajectory_tracker():
    """Return a handle to the detached ``global_tracker`` actor, creating it if it does not exist.

    Reads VERL_TRACKER_HDFS_DIR (required) and VERL_TRACKER_VERBOSE ('1' enables verbose output).
    """
    hdfs_dir = os.getenv('VERL_TRACKER_HDFS_DIR', default=None)
    verbose = os.getenv('VERL_TRACKER_VERBOSE', default='0') == '1'
    assert hdfs_dir is not None
    actor_options = dict(name="global_tracker", get_if_exists=True, lifetime="detached")
    return TrajectoryTracker.options(**actor_options).remote(hdfs_dir, verbose)
86 |
87 |
if __name__ == '__main__':
    # Manual smoke test: ten Ray tasks each dump a random tensor, then wait for all uploads.
    os.environ['VERL_ENABLE_TRACKER'] = '1'
    # NOTE(review): '~' is passed through unexpanded here — confirm makedirs/copy handle it.
    os.environ['VERL_TRACKER_HDFS_DIR'] = '~/debug/test'

    @ray.remote
    def process(idx):
        data = {'obs': torch.randn(10, 20)}
        dump_data(data, f'process_{idx}_obs')

    ray.init()

    output_lst = [process.remote(i) for i in range(10)]
    out = ray.get(output_lst)

    tracker = get_trajectory_tracker()
    ray.get(tracker.wait_for_hdfs.remote())
109 |
--------------------------------------------------------------------------------
/verl/utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities for distributed training."""
15 | import os
16 |
17 |
def initialize_global_process_group(timeout_second=36000):
    """Initialize the default NCCL process group from launcher-provided environment variables.

    Reads LOCAL_RANK, RANK and WORLD_SIZE and binds this process to CUDA device LOCAL_RANK.
    It then initializes the default process group (default ``env://`` rendezvous) unless
    one is already initialized, so calling this twice is harmless.

    Args:
        timeout_second: timeout for process-group operations, in seconds.

    Returns:
        Tuple ``(local_rank, rank, world_size)``.

    Raises:
        KeyError: a required environment variable is missing.
    """
    import torch.distributed
    from datetime import timedelta

    # Read the environment first so a missing variable fails fast, before any NCCL setup.
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Bind this rank's GPU before creating the NCCL group so its communicators use that
    # device. The original post-init `is_initialized()` check was always true.
    torch.cuda.set_device(local_rank)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second))
    return local_rank, rank, world_size
29 |
--------------------------------------------------------------------------------
/verl/utils/fs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # -*- coding: utf-8 -*-
17 | """File-system agnostic IO APIs"""
18 | import os
19 | import tempfile
20 | import hashlib
21 |
22 | try:
23 | from hdfs_io import copy, makedirs, exists # for internal use only
24 | except ImportError:
25 | from .hdfs_io import copy, makedirs, exists
26 |
27 | __all__ = ["copy", "exists", "makedirs"]
28 |
29 | _HDFS_PREFIX = "hdfs://"
30 |
31 |
def is_non_local(path):
    """Return True when ``path`` is an HDFS path (starts with the ``hdfs://`` prefix)."""
    prefix = _HDFS_PREFIX
    return path.startswith(prefix)
34 |
35 |
def md5_encode(path: str) -> str:
    """Return the hex MD5 digest of ``path`` encoded as UTF-8 (used to derive stable cache names)."""
    digest = hashlib.md5(path.encode())
    return digest.hexdigest()
38 |
39 |
def get_local_temp_path(hdfs_path: str, cache_dir: str) -> str:
    """Return the local path ``<cache_dir>/<md5(hdfs_path)>/<basename(hdfs_path)>``.

    The intermediate directory is named after the MD5 hex digest of the full
    ``hdfs_path``. Remote files that share a basename therefore map to different
    local paths. That directory is created if missing; the file itself is not.

    Args:
        hdfs_path: remote path; its basename becomes the local file name.
        cache_dir: local root directory under which the per-source directory is created.

    Returns:
        The local destination path.
    """
    # MD5 (not base64) of the full remote path gives a fixed-length, filesystem-safe directory name.
    temp_dir = os.path.join(cache_dir, md5_encode(hdfs_path))
    os.makedirs(temp_dir, exist_ok=True)
    return os.path.join(temp_dir, os.path.basename(hdfs_path))
56 |
57 |
def copy_to_local(src: str, cache_dir=None, filelock='.file.lock', verbose=False) -> str:
    """Copy ``src`` from HDFS to local disk if it is an HDFS path; otherwise return it unchanged.

    Downloads are cached. Each HDFS ``src`` maps to a fixed local path under ``cache_dir``
    (inside a directory named after the MD5 of ``src``). The copy is skipped if that path
    already exists, so a changed remote file is not re-downloaded.

    Args:
        src (str): an HDFS path (``hdfs://...``) or a local path; must not end with '/'.
        cache_dir: local root for downloads; defaults to the system temp directory.
        filelock: currently ignored; the lock file name is derived from ``src``.
        verbose: print the copy being performed.

    Returns:
        A local path to the copied file (or ``src`` itself if it was already local).
    """
    return copy_local_path_from_hdfs(src, cache_dir, filelock, verbose)
70 |
71 |
def copy_local_path_from_hdfs(src: str, cache_dir=None, filelock='.file.lock', verbose=False) -> str:
    """Deprecated. Please use copy_to_local instead.

    Args:
        src: an HDFS path (``hdfs://...``) or a local path; must not end with '/'.
        cache_dir: local root for downloads; defaults to ``tempfile.gettempdir()``.
        filelock: currently ignored. The lock file is always ``<md5(src)>.lock`` inside
            ``cache_dir``, so copies of different sources never share a lock.
        verbose: print the copy being performed.

    Returns:
        The local path of the (possibly cached) copy, or ``src`` itself if it is local.
    """
    # endswith() also accepts an empty string, where src[-1] would raise IndexError.
    assert not src.endswith('/'), f'Make sure the last char in src is not / because it will cause error. Got {src}'

    if not is_non_local(src):
        return src

    # Imported lazily so local paths do not require the third-party `filelock` package.
    from filelock import FileLock

    # download from hdfs to local
    if cache_dir is None:
        # get a temp folder
        cache_dir = tempfile.gettempdir()
    os.makedirs(cache_dir, exist_ok=True)
    local_path = get_local_temp_path(src, cache_dir)
    # Per-source lock: concurrent callers for the same src wait here, then reuse the existing file.
    lock_file = os.path.join(cache_dir, md5_encode(src) + '.lock')
    with FileLock(lock_file=lock_file):
        # NOTE(review): a copy that fails midway may leave a partial local_path that is
        # reused on later calls — confirm whether `copy` writes atomically.
        if not os.path.exists(local_path):
            if verbose:
                print(f'Copy from {src} to {local_path}')
            copy(src, local_path)
    return local_path
97 |
--------------------------------------------------------------------------------
/verl/utils/import_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Utilities to check if packages are available.
16 | We assume package availability won't change during runtime.
17 | """
18 |
19 | from functools import cache
20 | from typing import List, Optional
21 |
22 |
@cache
def is_megatron_core_available():
    """Return True if ``megatron.core.parallel_state`` imports successfully (cached after first call)."""
    try:
        from megatron.core import parallel_state as mpu  # noqa: F401 (import is only a probe)
    except ImportError:
        return False
    return True
30 |
31 |
@cache
def is_vllm_available():
    """Return True if ``vllm`` imports successfully (cached after first call)."""
    try:
        import vllm  # noqa: F401 (import is only a probe)
    except ImportError:
        return False
    return True
39 |
40 |
@cache
def is_sglang_available():
    """Return True if ``sglang`` imports successfully (cached after first call)."""
    try:
        import sglang  # noqa: F401 (import is only a probe)
    except ImportError:
        return False
    return True
48 |
49 |
def import_external_libs(external_libs=None):
    """Import one or more modules by dotted name.

    Args:
        external_libs: None (no-op), a single module name, or any iterable of module
            names (list, tuple, OmegaConf ListConfig, ...).
    """
    if external_libs is None:
        return
    # Treat only a bare string as a single name. The previous `isinstance(x, typing.List)`
    # wrapped tuples and config list types into a one-element list, so import_module
    # received a non-string.
    if isinstance(external_libs, str):
        external_libs = [external_libs]
    import importlib
    for external_lib in external_libs:
        importlib.import_module(external_lib)
58 |
59 |
def load_extern_type(file_path: Optional[str], type_name: Optional[str]):
    """Load an external type (any attribute) named ``type_name`` from the Python file ``file_path``.

    The file is executed as a fresh module named ``custom_module``; it is not
    registered in ``sys.modules``.

    Returns:
        The loaded object, or None when ``file_path`` is empty or None.

    Raises:
        FileNotFoundError: ``file_path`` does not exist.
        RuntimeError: the file could not be imported (the original error is chained).
        AttributeError: ``type_name`` is empty or not defined in the file.
    """
    import importlib.util
    import os

    if not file_path:
        return None

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Custom type file '{file_path}' not found.")

    spec = importlib.util.spec_from_file_location("custom_module", file_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None for paths it cannot associate with a loader.
        raise RuntimeError(f"Error loading module from '{file_path}': not an importable Python file")
    module = importlib.util.module_from_spec(spec)
    try:
        spec.loader.exec_module(module)
    except Exception as e:
        # Chain the original error so its traceback is preserved.
        raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e

    # hasattr() raises TypeError for a None name, so check for emptiness first.
    if not type_name or not hasattr(module, type_name):
        raise AttributeError(f"Custom type '{type_name}' not found in '{file_path}'.")

    return getattr(module, type_name)
--------------------------------------------------------------------------------
/verl/utils/logger/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/utils/logger/aggregate_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | A Ray logger will receive logging info from different processes.
16 | """
17 | import numbers
18 | from typing import Dict
19 |
20 |
def concat_dict_to_str(dict: Dict, step):
    """Format as ``'step:<step> - key:val - ...'``, keeping only numeric values (3 decimals).

    Non-numeric values are skipped; the mapping's iteration order is preserved.
    """
    numeric_parts = (
        f'{key}:{value:.3f}' for key, value in dict.items() if isinstance(value, numbers.Number)
    )
    return ' - '.join([f'step:{step}', *numeric_parts])
28 |
29 |
class LocalLogger:
    """Minimal console logger. Only ``print_to_console`` is used; the other constructor arguments are ignored."""

    def __init__(self, remote_logger=None, enable_wandb=False, print_to_console=False):
        self.print_to_console = print_to_console
        if not print_to_console:
            return
        print('Using LocalLogger is deprecated. The constructor API will change ')

    def flush(self):
        """No-op: ``log`` already flushes stdout on every call."""
        return None

    def log(self, data, step):
        """Print the numeric entries of ``data`` for ``step`` when console printing is enabled."""
        if not self.print_to_console:
            return
        print(concat_dict_to_str(data, step=step), flush=True)
--------------------------------------------------------------------------------
/verl/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | import os
17 | import torch
18 |
19 |
def set_basic_config(level):
    """Configure the root logger's format (``LEVEL:time:message``) and level.

    Intended to be called when verl is imported. Like ``logging.basicConfig``, it has
    no effect if the root logger already has handlers.
    """
    fmt = '%(levelname)s:%(asctime)s:%(message)s'
    logging.basicConfig(level=level, format=fmt)
25 |
26 |
def log_to_file(string):
    """Print ``string``; if a ``logs`` directory exists in the CWD, also append it to ``logs/log_<rank>``.

    ``rank`` is the global torch.distributed rank, or 0 when no process group is
    initialized (``torch.distributed.get_rank`` would raise in that case).
    """
    print(string)
    if not os.path.isdir('logs'):
        return
    dist = torch.distributed
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    # Explicit UTF-8 so non-ASCII log lines do not depend on the platform's default encoding.
    with open(os.path.join('logs', f'log_{rank}'), 'a+', encoding='utf-8') as f:
        f.write(string + '\n')
32 |
--------------------------------------------------------------------------------
/verl/utils/megatron/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/utils/megatron/memory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 |
class MemoryBuffer:
    """Flat, preallocated tensor on the current CUDA device that hands out shaped views.

    Args:
        numel: number of leading elements that :meth:`get` is allowed to hand out.
        numel_padded: number of elements actually allocated (presumably ``>= numel``).
        dtype: element dtype of the allocation.
    """

    def __init__(self, numel, numel_padded, dtype):
        self.numel = numel
        self.numel_padded = numel_padded
        self.dtype = dtype
        self.data = torch.zeros(numel_padded,
                                dtype=dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

    def zero(self):
        """Reset the buffer to zero."""
        # in-place fill over the whole allocation, padding included
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a view of ``shape`` into the flat buffer, starting at ``start_index``.

        The returned tensor shares storage with ``self.data``. The request must end
        within the first ``numel`` elements.
        """
        stop = start_index + shape.numel()
        assert stop <= self.numel, \
            'requested tensor is out of the buffer range.'
        return self.data[start_index:stop].view(shape)
42 |
--------------------------------------------------------------------------------
/verl/utils/megatron/optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import importlib
17 | from packaging.version import Version
18 |
19 | from apex.optimizers import FusedAdam as Adam
20 | from apex.optimizers import FusedSGD as SGD
21 |
22 | from megatron.core.optimizer import OptimizerConfig
23 |
24 | from megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native
25 |
26 |
def get_megatron_optimizer(
        model,
        config: OptimizerConfig,
        no_weight_decay_cond=None,
        scale_lr_cond=None,
        lr_mult=1.0,
        check_for_nan_in_loss_and_grad=False,
        overlap_param_gather=False  # add for verl
):
    """Thin verl wrapper around Megatron-Core's native ``get_megatron_optimizer``.

    ``model`` is forwarded as ``model_chunks``. ``check_for_nan_in_loss_and_grad`` and
    ``overlap_param_gather`` are accepted for call-site compatibility but are NOT
    forwarded to the native builder.
    """
    # Base optimizer.
    native_kwargs = dict(
        config=config,
        model_chunks=model,
        no_weight_decay_cond=no_weight_decay_cond,
        scale_lr_cond=scale_lr_cond,
        lr_mult=lr_mult,
    )
    return get_megatron_optimizer_native(**native_kwargs)
42 |
--------------------------------------------------------------------------------
/verl/utils/megatron/pipeline_parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | from megatron.core import parallel_state as mpu
18 |
19 | from .sequence_parallel import pad_to_sequence_parallel
20 |
21 |
def compute_transformers_input_shapes(batches, meta_info):
    """Pre-compute the input shape of each micro-batch at each pp stage.

    Every micro-batch's ``input_ids`` are unpadded with its ``attention_mask`` via
    flash-attn's ``unpad_input``, giving ``total_nnz`` tokens. The resulting shape is
    ``(total_nnz, 1, hidden_size)``. When ``meta_info['sequence_parallel']`` is set,
    the tokens are first padded to a multiple of the tensor-model-parallel world size
    and the count is divided by that size.

    Returns:
        list of ``torch.Size``, one per micro-batch, in ``batches`` order.
    """
    from flash_attn.bert_padding import unpad_input  # flash 2 is a must for Megatron
    shapes = []
    for micro_batch in batches:
        ids = micro_batch['input_ids']
        mask = micro_batch['attention_mask']
        rmpad_ids = unpad_input(ids.unsqueeze(dim=-1), mask)[0]  # (total_nnz, 1)
        if meta_info['sequence_parallel']:
            rmpad_ids = pad_to_sequence_parallel(rmpad_ids)
            num_tokens = rmpad_ids.shape[0] // mpu.get_tensor_model_parallel_world_size()
        else:
            num_tokens = rmpad_ids.shape[0]
        shapes.append(torch.Size([num_tokens, 1, meta_info['hidden_size']]))
    return shapes
41 |
42 |
def make_batch_generator(batches, vpp_size):
    """Wrap ``batches`` in iterator(s) for the pipeline schedule.

    Args:
        batches: iterable of micro-batches.
        vpp_size: number of virtual pipeline chunks.

    Returns:
        A single iterator over ``batches`` when ``vpp_size <= 1``; otherwise a list of
        ``vpp_size`` iterators, each created by ``iter(batches)`` (one per vpp chunk).
    """
    if vpp_size <= 1:
        # no vpp
        return iter(batches)
    # has vpp: one iterator per virtual pipeline chunk
    return [iter(batches) for _ in range(vpp_size)]
52 |
--------------------------------------------------------------------------------
/verl/utils/megatron/sequence_parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | import torch.nn.functional as F
18 | from megatron.core import parallel_state as mpu
19 |
20 |
def mark_parameter_as_sequence_parallel(parameter):
    """Tag ``parameter`` with ``sequence_parallel = True`` (read by ``is_sequence_parallel_param``)."""
    parameter.sequence_parallel = True
23 |
24 |
def is_sequence_parallel_param(param):
    """Return the ``sequence_parallel`` tag of ``param``, or ``False`` if it has none."""
    return getattr(param, 'sequence_parallel', False)
27 |
28 |
def pad_to_sequence_parallel(unpad_tokens: torch.Tensor, sp_world_size=None):
    """Pad the tokens such that the total length is a multiple of the sp world size.

    The sequence-parallel world size used here is the tensor-model-parallel world size
    (``mpu.get_tensor_model_parallel_world_size()``) unless ``sp_world_size`` is given.

    Args:
        unpad_tokens: (total_nnz, ...). Tokens after removing padding. Only 1-D and
            2-D tensors can be padded.
        sp_world_size: optional explicit world size; defaults to the
            tensor-model-parallel world size.

    Returns:
        ``unpad_tokens`` itself when no padding is needed, otherwise a new tensor
        zero-padded at the end of dim 0.

    Raises:
        NotImplementedError: if padding is needed and the tensor has more than 2 dims.
    """
    total_nnz = unpad_tokens.shape[0]
    if sp_world_size is None:
        sp_world_size = mpu.get_tensor_model_parallel_world_size()

    # Rows needed to reach the next multiple of sp_world_size (0 when already aligned).
    pad_size = -total_nnz % sp_world_size

    if pad_size > 0:
        if unpad_tokens.ndim == 1:
            unpad_tokens = F.pad(unpad_tokens, (0, pad_size))
        elif unpad_tokens.ndim == 2:
            unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))
        else:
            # ``ndim`` is an int attribute, not a method; calling it would raise
            # TypeError and hide this error.
            raise NotImplementedError(f'Padding dim {unpad_tokens.ndim} is not supported')

    return unpad_tokens
55 |
--------------------------------------------------------------------------------
/verl/utils/py_functional.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Contain small python utility functions
16 | """
17 |
18 | from typing import Dict
19 | from types import SimpleNamespace
20 |
21 |
def union_two_dict(dict1: Dict, dict2: Dict):
    """Merge ``dict2`` into ``dict1`` in place and return ``dict1``.

    Keys present in both dicts must have equal values (compared with ``==``, not
    identity); otherwise an ``AssertionError`` is raised.

    Args:
        dict1: destination dict, mutated in place.
        dict2: source dict.

    Returns:
        ``dict1`` after the merge.
    """
    for key in dict2:
        incoming = dict2[key]
        if key in dict1:
            assert incoming == dict1[key], \
                f'{key} in meta_dict1 and meta_dict2 are not the same object'
        dict1[key] = incoming

    return dict1
39 |
40 |
def append_to_dict(data: Dict, new_data: Dict):
    """Append each value of ``new_data`` to the list stored under the same key in ``data``.

    Missing keys get a fresh list first. ``data`` is modified in place.
    """
    for key, val in new_data.items():
        data.setdefault(key, []).append(val)
46 |
47 |
class NestedNamespace(SimpleNamespace):
    """``SimpleNamespace`` built from a dict, turning nested dicts into nested namespaces.

    ``kwargs`` are applied first, then every key of ``dictionary`` becomes an
    attribute (so a dictionary key overrides a keyword of the same name). Values
    that are ``dict`` instances are wrapped recursively.
    """

    def __init__(self, dictionary, **kwargs):
        super().__init__(**kwargs)
        for name, item in dictionary.items():
            wrapped = NestedNamespace(item) if isinstance(item, dict) else item
            self.__setattr__(name, wrapped)
57 |
--------------------------------------------------------------------------------
/verl/utils/ray_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Contains commonly used utilities for ray
16 | """
17 |
18 | import ray
19 |
20 | import concurrent.futures
21 |
22 |
def parallel_put(data_list, max_workers=None):
    """``ray.put`` every item of ``data_list`` concurrently using a thread pool.

    Args:
        data_list: sized sequence of objects to place in the Ray object store.
        max_workers: thread-pool size; defaults to ``min(len(data_list), 16)``.

    Returns:
        list of object refs in the same order as ``data_list`` (``[]`` for empty input).
    """
    # Nothing to put. This also avoids ThreadPoolExecutor(max_workers=0), which raises ValueError.
    if len(data_list) == 0:
        return []

    if max_workers is None:
        max_workers = min(len(data_list), 16)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Executor.map yields results in input order, so no manual reindexing is needed.
        return list(executor.map(ray.put, data_list))
44 |
--------------------------------------------------------------------------------
/verl/utils/rendezvous/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/utils/rendezvous/ray_backend.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | import time
17 |
18 | from cupy.cuda.nccl import NcclCommunicator, get_unique_id
19 |
20 | import ray
21 | from ray.util import list_named_actors
22 |
23 |
@ray.remote
class NCCLIDStore:
    """Ray actor holding one NCCL unique id so peer ranks can fetch it by actor name."""

    def __init__(self, nccl_id):
        # id supplied by the creating rank; handed back unchanged by get()
        self._stored_id = nccl_id

    def get(self):
        """Return the NCCL unique id this actor was created with."""
        return self._stored_id
32 |
33 |
def get_nccl_id_store_by_name(name):
    """Look up a named Ray actor called ``name`` across all namespaces.

    Returns:
        The actor handle when exactly one match exists. Returns ``None`` when there
        are no matches (logged at INFO) or several (logged as a warning).
    """
    candidates = [
        entry for entry in list_named_actors(all_namespaces=True)
        if entry.get("name", None) == name
    ]
    if len(candidates) == 1:
        # the listing entry is unpacked as keyword args (presumably name/namespace)
        return ray.get_actor(**candidates[0])
    if candidates:
        logging.warning(f"multiple actors with same name found: {candidates}")
    else:
        logging.info(f"failed to get any actor named {name}")
    return None
45 |
46 |
def create_nccl_communicator_in_ray(rank: int,
                                    world_size: int,
                                    group_name: str,
                                    max_retries: int = 100,
                                    interval_s: int = 5):
    """Create a cupy ``NcclCommunicator``, sharing the NCCL unique id through a named Ray actor.

    Rank 0 generates the id and publishes it in an ``NCCLIDStore`` actor named
    ``group_name``. Every other rank polls for that actor up to ``max_retries`` times,
    sleeping ``interval_s`` seconds after each miss.

    Returns:
        The communicator for ``rank``, or ``None`` if a non-zero rank never finds the
        store within ``max_retries`` attempts.
    """
    if rank != 0:
        for attempt in range(max_retries):
            store = get_nccl_id_store_by_name(group_name)
            if store is None:
                logging.info(f"failed to get nccl_id for {attempt+1} time, sleep for {interval_s} seconds")
                time.sleep(interval_s)
                continue
            logging.info(f"nccl_id_store {group_name} got")
            nccl_id = ray.get(store.get.remote())
            logging.info(f"nccl id for {group_name} got: {nccl_id}")
            return NcclCommunicator(ndev=world_size, commId=nccl_id, rank=rank)
        return None

    nccl_id = get_unique_id()
    store = NCCLIDStore.options(name=group_name).remote(nccl_id)
    # sanity check: the actor hands back exactly the id we published
    assert ray.get(store.get.remote()) == nccl_id
    return NcclCommunicator(ndev=world_size, commId=nccl_id, rank=0)
78 |
--------------------------------------------------------------------------------
/verl/utils/reward_score/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # from . import gsm8k, math, prime_math, prime_code
15 |
16 |
17 | def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None):
18 | if data_source == 'openai/gsm8k':
19 | from . import gsm8k
20 | res = gsm8k.compute_score(solution_str, ground_truth)
21 | elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval', "numina_math"]:
22 | #from . import math
23 | #res = math.compute_score(solution_str, ground_truth)
24 | # [Optional] Math-Verify Integration
25 | # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify).
26 | # Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`.
27 | # To use it, override the `compute_score` function with the following implementation:
28 |
29 | from . import math_verify
30 | res = math_verify.compute_score(solution_str, ground_truth)
31 | elif data_source == 'math_dapo' or data_source.startswith("aime"):
32 | from . import math_dapo
33 | res = math_dapo.compute_score(solution_str, ground_truth)
34 | elif data_source in [
35 | 'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12',
36 | 'numina_olympiads'
37 | ]:
38 | from . import prime_math
39 | res = prime_math.compute_score(solution_str, ground_truth)
40 | elif data_source in ['codecontests', 'apps', 'codeforces', 'taco']:
41 | from . import prime_code
42 | res = prime_code.compute_score(solution_str, ground_truth, continuous=True)
43 | elif data_source in ['hiyouga/geometry3k']:
44 | from . import geo3k
45 | res = geo3k.compute_score(solution_str, ground_truth)
46 | else:
47 | raise NotImplementedError(f"Reward function is not implemented for {data_source=}")
48 |
49 | if isinstance(res, dict):
50 | return res
51 | elif isinstance(res, (int, float, bool)):
52 | return float(res)
53 | else:
54 | return float(res[0])
55 |
--------------------------------------------------------------------------------
/verl/utils/reward_score/geo3k.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | from mathruler.grader import extract_boxed_content, grade_answer
17 |
18 |
def format_reward(predict_str: str) -> float:
    """Return 1.0 if ``predict_str`` follows the expected answer format, else 0.0.

    The whole string must match: a ``<think>...</think>`` reasoning section, then
    (anywhere after it) a ``\\boxed{...}`` answer. A bare boxed answer without the
    think section scores 0.0.
    """
    # Both the think section and a boxed answer are required; matching spans newlines.
    pattern = re.compile(r'<think>.*</think>.*\\boxed\{.*\}.*', re.DOTALL)
    match_result = re.fullmatch(pattern, predict_str)
    return 1.0 if match_result else 0.0
23 |
24 |
def acc_reward(predict_str: str, ground_truth: str) -> float:
    """Return 1.0 if the boxed answer in ``predict_str`` grades equal to ``ground_truth``, else 0.0."""
    boxed = extract_boxed_content(predict_str)
    return float(bool(grade_answer(boxed, ground_truth)))
28 |
29 |
def compute_score(predict_str: str, ground_truth: str) -> float:
    """Weighted geo3k reward: 0.9 * answer accuracy + 0.1 * format compliance."""
    accuracy = acc_reward(predict_str, ground_truth)
    formatting = format_reward(predict_str)
    return 0.9 * accuracy + 0.1 * formatting
32 |
--------------------------------------------------------------------------------
/verl/utils/reward_score/gsm8k.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 |
17 |
def extract_solution(solution_str, method='strict'):
    """Extract the final numeric answer from a GSM8K-style completion.

    Args:
        solution_str: the model completion.
        method: ``'strict'`` only accepts an answer written after ``'#### '`` (this also
            checks the model's output format) and strips ``,`` and ``$`` from it.
            ``'flexible'`` takes the last number-like token that contains a digit.

    Returns:
        The answer string, or ``None`` if no answer can be found.
    """
    assert method in ['strict', 'flexible']

    if method == 'strict':
        # this also tests the formatting of the model
        solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
        if solution is None:
            return None
        return solution.group(1).replace(',', '').replace('$', '')

    # flexible: scan number-like tokens from the end. Tokens made only of '.', ',' or '-'
    # (e.g. a sentence-ending period) are not numbers. If nothing qualifies, there is no
    # answer, so return None rather than one of those tokens.
    answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
    for candidate in reversed(answer):
        if any(ch.isdigit() for ch in candidate):
            return candidate
    return None
42 |
43 |
def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1.):
    """The scoring function for GSM8k.

    Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.

    Args:
        solution_str: the solution text
        ground_truth: the ground truth
        method: the method to extract the solution, choices are 'strict' and 'flexible'
        format_score: the score for the format
        score: the score for the correct answer

    Returns:
        ``0`` when no answer can be extracted, ``score`` when the extracted answer
        equals ``ground_truth``, and ``format_score`` otherwise.
    """
    extracted = extract_solution(solution_str=solution_str, method=method)
    if extracted is None:
        return 0
    return score if extracted == ground_truth else format_score
--------------------------------------------------------------------------------
/verl/utils/reward_score/math_batch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Individual Contributor: Mert Unsal
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .math import compute_score
16 |
17 |
def compute_score_batched(data_sources, solution_strs, ground_truths, extra_infos):
    """
    This is a demonstration of how the batched reward function should look like.
    Typically, you want to use batched reward to speed up the process with parallelization

    ``data_sources`` and ``extra_infos`` are accepted for interface compatibility but
    unused. Solutions and ground truths are paired positionally (stopping at the
    shorter sequence).
    """
    return list(map(compute_score, solution_strs, ground_truths))
26 |
--------------------------------------------------------------------------------
/verl/utils/reward_score/math_verify.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | try:
16 | from math_verify.metric import math_metric
17 | from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig
18 | from math_verify.errors import TimeoutException
19 | except ImportError:
20 | print("To use Math-Verify, please install it first by running `pip install math-verify`.")
21 |
22 |
def compute_score(model_output: str, ground_truth: str, timeout_score: float = 0) -> float:
    """Score ``model_output`` against ``ground_truth`` using Math-Verify.

    The ground truth is wrapped in ``\\boxed{...}`` and parsed as LaTeX. The prediction
    is parsed as either a plain expression or LaTeX.

    Args:
        model_output: the model's completion.
        ground_truth: reference answer, without ``\\boxed``.
        timeout_score: score to return when Math-Verify raises ``TimeoutException``.

    Returns:
        The score produced by Math-Verify, ``timeout_score`` on timeout, or ``0.0``
        if verification raised any other exception.
    """
    verify_func = math_metric(
        gold_extraction_target=(LatexExtractionConfig(),),
        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    )
    ret_score = 0.

    # Wrap the ground truth in \boxed{} format for verification
    ground_truth_boxed = "\\boxed{" + ground_truth + "}"
    try:
        ret_score, _ = verify_func([ground_truth_boxed], [model_output])
    # except clauses are tried in order: the specific TimeoutException must come before
    # the generic handler, or (if it subclasses Exception) it is never reached and
    # timeout_score is silently ignored.
    except TimeoutException:
        ret_score = timeout_score
    except Exception:
        # best effort: any other verification failure scores 0
        pass

    return ret_score
40 |
--------------------------------------------------------------------------------
/verl/utils/reward_score/prime_code/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 PRIME team and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .utils import check_correctness as apps_check_correctness
16 | import json
17 | import re
18 | import traceback
19 |
20 |
def compute_score(completion, test_cases, continuous=False):
    """Score a code completion against APPS-style ``{"inputs": [...], "outputs": [...]}`` tests.

    The program is the text after the last ```python fence in ``completion`` (the
    whole completion if it is plain code). All test cases are first run together; if
    every result is ``True``, ``(True, metadata_of_that_run)`` is returned.

    Args:
        completion: model output containing the program.
        test_cases: dict of test cases, or a JSON string encoding one.
        continuous: if True and the combined run did not fully pass, run the first 10
            test cases one by one and return the fraction of passing results.

    Returns:
        ``(success, metadata)``. Outside the all-pass case, ``success`` is the pass
        fraction (continuous mode) or ``False``. ``metadata`` is the list of
        per-case metadata dicts (continuous mode) or ``None``. Any unexpected error
        yields ``(False, None)``.
    """
    # try to get code solution from completion. if the completion is pure code, this will not take effect.
    solution = completion.split('```python')[-1].split('```')[0]
    # Defaults for every path that does not take the all-pass early return. Without them,
    # continuous=False with a failing or crashing combined run reached the final return
    # with `metadata_list` (and possibly `success`) unbound, raising UnboundLocalError.
    success = False
    metadata_list = None
    try:
        try:
            if not isinstance(test_cases, dict):
                test_cases = json.loads(test_cases)
        except Exception as e:
            print(f"Error:{e}")

        # Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped.
        try:
            res, metadata = apps_check_correctness(in_outs=test_cases, generation=solution, timeout=5, debug=False)
            metadata = dict(enumerate(metadata))[0]
            # results may be non-bool (e.g. -1 on error), so compare to True explicitly
            success = all(map(lambda x: x == True, res))  # noqa: E712
            if success:
                return success, metadata
        except Exception:
            pass

        if continuous:
            inputs = test_cases["inputs"]
            outputs = test_cases["outputs"]
            test_cases_list = [{"inputs": [inputs[i]], "outputs": [outputs[i]]} for i in range(len(inputs))]

            # per sample test: if continuous score is needed, test first 10 samples regardless of failures
            # do not test all samples cuz some problems have enormous test cases
            metadata_list = []
            res_list = []
            for test_case_id, test_case in enumerate(test_cases_list):
                res, metadata = apps_check_correctness(in_outs=test_case, generation=solution, timeout=5, debug=False)
                try:
                    metadata = dict(enumerate(metadata))[0]  # metadata can be empty occasionally
                except Exception:
                    metadata = {}
                metadata["test_case"] = {
                    "input": str(test_case["inputs"][0]),
                    "output": str(test_case["outputs"][0]),
                    "res": str(res),
                }
                metadata_list.append(metadata)
                res_list.extend(res)

                if test_case_id >= 9:
                    break
            res_count = len(res_list) if len(res_list) > 0 else 1
            success = sum(map(lambda x: x == True, res_list)) / res_count  # noqa: E712
    except Exception:
        traceback.print_exc(10)
        success = False
        metadata_list = None
    return success, metadata_list
74 |
--------------------------------------------------------------------------------
/verl/utils/reward_score/prime_code/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 PRIME team and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Borrowed from: https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py
16 |
17 | import multiprocessing
18 | from typing import Dict, Optional
19 | from datasets import load_dataset
20 | from .testing_util import run_test
21 | import traceback
22 | import os, sys
23 |
24 |
def _temp_run(sample, generation, debug, result, metadata_list, timeout):
    """Worker-process target used by ``check_correctness``: run ``generation`` on ``sample``.

    stdout and stderr are pointed at ``os.devnull`` and never restored in this
    process, presumably to keep the program under test from flooding the parent's
    output. The outcome is appended to the shared ``result`` / ``metadata_list``.
    On any exception, a row of ``-1`` (one per input) and an empty metadata dict
    are appended instead.
    """
    with open(os.devnull, 'w') as sink:
        sys.stdout = sink
        sys.stderr = sink
        try:
            outcome, meta = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
            result.append(outcome)
            metadata_list.append(meta)
        except Exception:
            # print(e) # some tracebacks are extremely long.
            traceback.print_exc(10)
            result.append([-1] * len(sample['inputs']))
            metadata_list.append({})
38 |
39 |
def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
    """Check correctness of code generation with a global timeout.
    The global timeout is to catch some extreme/rare cases not handled by the timeouts
    inside `run_test`

    Returns:
        ``(results, metadata_list)``: the per-input results of the run (all ``-1``
        if the worker produced nothing, e.g. it was killed) and a plain list of the
        metadata objects the worker reported.
    """
    # The context manager shuts the manager's server process down deterministically
    # instead of leaving it to garbage collection.
    with multiprocessing.Manager() as manager:
        result = manager.list()
        metadata_list = manager.list()
        p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout))
        p.start()
        p.join(timeout=timeout + 1)
        if p.is_alive():
            p.kill()
            # reap the killed child so it does not linger as a zombie process
            p.join()
        # copy out of the shared proxies before the manager shuts down
        result = list(result)
        metadata_list = list(metadata_list)
    if not result:
        # consider that all tests failed
        result = [[-1 for i in range(len(in_outs["inputs"]))]]
        if debug:
            print(f"global timeout")
    return result[0], metadata_list
60 |
--------------------------------------------------------------------------------
/verl/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for tokenization."""
15 | import warnings
16 |
17 | __all__ = ['hf_tokenizer', 'hf_processor']
18 |
19 |
def set_pad_token_id(tokenizer):
    """Fall back to the eos token for padding when the tokenizer defines none.

    ``pad_token_id`` and then ``pad_token`` are each checked independently. A ``None``
    value is replaced by its eos counterpart, with a warning.

    Args:
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to patch in place.
    """
    for pad_attr, eos_attr in (('pad_token_id', 'eos_token_id'), ('pad_token', 'eos_token')):
        if getattr(tokenizer, pad_attr) is None:
            setattr(tokenizer, pad_attr, getattr(tokenizer, eos_attr))
            warnings.warn(f'tokenizer.{pad_attr} is None. Now set to {getattr(tokenizer, eos_attr)}')
33 |
34 |
def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs):
    """Create a HuggingFace pretrained tokenizer that correctly handles EOS and pad tokens.

    Args:
        name_or_path (str): Name or local path of the tokenizer, forwarded to
            ``AutoTokenizer.from_pretrained``.
        correct_pad_token (bool): If True, fall back to the EOS token for a missing
            pad token / pad token id (via ``set_pad_token_id``).
        correct_gemma2 (bool): If True and ``name_or_path`` contains ``'gemma-2-2b-it'``,
            force the EOS token to ``<end_of_turn>`` with id 107.
        **kwargs: Extra keyword arguments forwarded to ``AutoTokenizer.from_pretrained``.

    Returns:
        transformers.PreTrainedTokenizer: The pretrained tokenizer.
    """
    from transformers import AutoTokenizer
    if correct_gemma2 and isinstance(name_or_path, str) and 'gemma-2-2b-it' in name_or_path:
        # The EOS token in gemma2 is ambiguous, which may worsen RL performance.
        # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a
        # Token id 107 is <end_of_turn>; the token string must name that token (an empty
        # string would not refer to any real token).
        warnings.warn('Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to <end_of_turn> and 107.')
        kwargs['eos_token'] = '<end_of_turn>'
        kwargs['eos_token_id'] = 107
    tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
    if correct_pad_token:
        set_pad_token_id(tokenizer)
    return tokenizer
60 |
61 |
def hf_processor(name_or_path, **kwargs):
    """Load a HuggingFace processor for multimodal data, if the model provides one.

    Args:
        name_or_path (str): Name or local path passed to ``AutoProcessor.from_pretrained``.
        **kwargs: Forwarded to ``AutoProcessor.from_pretrained``.

    Returns:
        transformers.ProcessorMixin or None: The processor, or ``None`` when loading fails
        or when the loaded object's class name does not contain ``"Processor"``.
    """
    from transformers import AutoProcessor
    try:
        candidate = AutoProcessor.from_pretrained(name_or_path, **kwargs)
    except Exception:
        # Best effort: any loading failure is treated as "this model has no processor".
        return None
    # AutoProcessor may return something other than a processor (e.g. a tokenizer);
    # only genuine processors are kept. See:
    # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344
    looks_like_processor = candidate is not None and "Processor" in candidate.__class__.__name__
    return candidate if looks_like_processor else None
81 |
--------------------------------------------------------------------------------
/verl/utils/torch_dtypes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Adapted from Cruise.
16 | """
17 |
18 | import torch
19 |
20 | from typing import Union
21 |
# Accepted spellings for each precision; checked with ``in`` by PrecisionType helpers.
HALF_LIST = [
    16,
    "16",
    "fp16",
    "float16",
    torch.float16,
]
FLOAT_LIST = [
    32,
    "32",
    "fp32",
    "float32",
    torch.float32,
]
BFLOAT_LIST = [
    "bf16",
    "bfloat16",
    torch.bfloat16,
]
25 |
26 |
class PrecisionType(object):
    """Canonical precision names plus helpers to map aliases to and from torch dtypes.

    The class attributes are plain strings (this is not an ``enum.Enum``). The
    ``is_*`` and ``to_dtype`` helpers accept any alias listed in the module-level
    ``HALF_LIST`` / ``FLOAT_LIST`` / ``BFLOAT_LIST``.

    >>> PrecisionType.HALF == "16"
    True
    >>> PrecisionType.is_fp16(16)
    True
    """

    HALF = "16"
    FLOAT = "32"
    FULL = "64"
    BFLOAT = "bf16"
    MIXED = "mixed"

    # Canonical values in declaration order. PrecisionType is a plain class and cannot
    # be iterated directly, so the helpers below use this tuple instead.
    _ALL_VALUES = (HALF, FLOAT, FULL, BFLOAT, MIXED)

    @staticmethod
    def supported_type(precision: Union[str, int]) -> bool:
        """Return True if ``str(precision)`` equals one of the canonical values.

        Comparing the string form lets integer inputs such as ``16`` match ``"16"``.
        """
        return str(precision) in PrecisionType._ALL_VALUES

    @staticmethod
    def supported_types() -> list[str]:
        """Return the canonical precision values as a new list."""
        return list(PrecisionType._ALL_VALUES)

    @staticmethod
    def is_fp16(precision):
        """Return True if ``precision`` is one of the fp16 aliases in ``HALF_LIST``."""
        return precision in HALF_LIST

    @staticmethod
    def is_fp32(precision):
        """Return True if ``precision`` is one of the fp32 aliases in ``FLOAT_LIST``."""
        return precision in FLOAT_LIST

    @staticmethod
    def is_bf16(precision):
        """Return True if ``precision`` is one of the bf16 aliases in ``BFLOAT_LIST``."""
        return precision in BFLOAT_LIST

    @staticmethod
    def to_dtype(precision):
        """Map an fp16 / fp32 / bf16 alias to its ``torch.dtype``.

        Raises:
            RuntimeError: If ``precision`` is not in any of the alias lists
                (``FULL`` and ``MIXED`` are not mapped).
        """
        if precision in HALF_LIST:
            return torch.float16
        elif precision in FLOAT_LIST:
            return torch.float32
        elif precision in BFLOAT_LIST:
            return torch.bfloat16
        else:
            raise RuntimeError(f"unexpected precision: {precision}")

    @staticmethod
    def to_str(precision):
        """Map ``torch.float16`` / ``torch.float32`` / ``torch.bfloat16`` to 'fp16' / 'fp32' / 'bf16'.

        Raises:
            RuntimeError: For any other value.
        """
        if precision == torch.float16:
            return 'fp16'
        elif precision == torch.float32:
            return 'fp32'
        elif precision == torch.bfloat16:
            return 'bf16'
        else:
            raise RuntimeError(f"unexpected precision: {precision}")
83 |
--------------------------------------------------------------------------------
/verl/version/version:
--------------------------------------------------------------------------------
1 | 0.2.0.dev
2 |
--------------------------------------------------------------------------------
/verl/workers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/verl/workers/actor/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import BasePPOActor
16 | from .dp_actor import DataParallelPPOActor
17 |
# Public names re-exported by this package.
__all__ = [
    "BasePPOActor",
    "DataParallelPPOActor",
]
19 |
--------------------------------------------------------------------------------
/verl/workers/actor/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The base class for Actor
16 | """
17 | from abc import ABC, abstractmethod
18 | from typing import Iterable, Dict
19 |
20 | from verl import DataProto
21 | import torch
22 |
__all__ = ['BasePPOActor']


class BasePPOActor(ABC):
    """Abstract interface for the actor (policy) in PPO-style training."""

    def __init__(self, config):
        """Store the actor configuration.

        Args:
            config (DictConfig): A config passed to the PPOActor. We expect the type to be
                DictConfig (https://omegaconf.readthedocs.io/), but it can be any namedtuple
                in general.
        """
        super().__init__()
        self.config = config

    @abstractmethod
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute log-probabilities for a batch of data.

        Args:
            data (DataProto): A batch of data represented by DataProto. It must contain the
                keys ``input_ids``, ``attention_mask`` and ``position_ids``.

        Returns:
            torch.Tensor: Per the return annotation. NOTE(review): an earlier version of this
            docstring described a DataProto containing the key ``log_probs``; confirm the
            actual return type against the concrete implementations.
        """
        pass

    @abstractmethod
    def update_policy(self, data: DataProto) -> Dict:
        """Update the policy using a batch of data.

        Args:
            data (DataProto): The training data, annotated as a DataProto.
                NOTE(review): an earlier version of this docstring called it an iterator
                returned by ``make_minibatch_iterator``; confirm against implementations.

        Returns:
            Dict: A dictionary with arbitrary content. Typically it contains statistics
            collected while updating the model, such as ``loss`` or ``grad_norm``.
        """
        pass
67 |
--------------------------------------------------------------------------------
/verl/workers/critic/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import BasePPOCritic
16 | from .dp_critic import DataParallelPPOCritic
17 |
# Public names re-exported by this package.
__all__ = [
    "BasePPOCritic",
    "DataParallelPPOCritic",
]
19 |
--------------------------------------------------------------------------------
/verl/workers/critic/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Base class for a critic
16 | """
17 | from abc import ABC, abstractmethod
18 |
19 | import torch
20 |
21 | from verl import DataProto
22 |
__all__ = ['BasePPOCritic']


class BasePPOCritic(ABC):
    """Abstract interface for the critic (value model) in PPO-style training."""

    def __init__(self, config):
        """Keep the critic configuration available as ``self.config``."""
        super().__init__()
        self.config = config

    @abstractmethod
    def compute_values(self, data: DataProto) -> torch.Tensor:
        """Compute values for a batch of data; implemented by subclasses."""
        ...

    @abstractmethod
    def update_critic(self, data: DataProto):
        """Run a critic update on a batch of data; implemented by subclasses."""
        ...
41 |
--------------------------------------------------------------------------------
/verl/workers/reward_manager/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 PRIME team and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .naive import NaiveRewardManager
16 | from .prime import PrimeRewardManager
17 | from .batch import BatchRewardManager
18 | from .dapo import DAPORewardManager
19 |
--------------------------------------------------------------------------------
/verl/workers/reward_manager/batch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Individual Contributor: Mert Unsal
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | from verl import DataProto
17 | from collections import defaultdict
18 |
19 |
class BatchRewardManager:
    """Reward manager that scores an entire batch with a single ``compute_score`` call.

    ``compute_score`` is called with the keyword arguments ``data_sources``,
    ``solution_strs``, ``ground_truths`` and ``extra_infos`` (one entry per sample), plus
    any ``reward_kwargs`` given to the constructor. Each returned score is either a
    number or a dict holding at least a ``"score"`` key.
    """

    def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key='data_source', **reward_kwargs):
        self.tokenizer = tokenizer
        # Maximum number of samples printed per data source on each call.
        self.num_examine = num_examine
        self.compute_score = compute_score
        # Key in ``non_tensor_batch`` that identifies each sample's data source.
        self.reward_fn_key = reward_fn_key
        # Extra keyword arguments forwarded verbatim to ``compute_score``.
        self.reward_kwargs = reward_kwargs

    @staticmethod
    def _valid_response_lengths(data):
        """Sum the attention mask over the response region of each sample.

        The response region is every mask column after the prompt width.
        """
        prompt_width = data.batch['prompts'].shape[-1]
        return data.batch['attention_mask'][:, prompt_width:].sum(dim=-1)

    def verify(self, data):
        """Decode the valid part of every response and score the whole batch.

        Returns:
            Whatever ``compute_score`` returns (indexed per sample by ``__call__``).
        """
        response_ids = data.batch['responses']
        lengths = self._valid_response_lengths(data)
        solutions = [
            self.tokenizer.decode(response_ids[idx][:lengths[idx]], skip_special_tokens=True)
            for idx in range(len(data))
        ]

        ground_truths = [item.non_tensor_batch['reward_model'].get('ground_truth', None) for item in data]
        sources = data.non_tensor_batch[self.reward_fn_key]
        extra_infos = data.non_tensor_batch.get('extra_info', [None] * len(data))

        return self.compute_score(
            data_sources=sources,
            solution_strs=solutions,
            ground_truths=ground_truths,
            extra_infos=extra_infos,
            **self.reward_kwargs,
        )

    def __call__(self, data: DataProto, return_dict=False):
        """Build a token-level reward tensor for ``data``.

        If ``data.batch`` already holds ``rm_scores``, they are returned unchanged.
        Otherwise each sample's score is written at its last valid response position of a
        float32 tensor shaped like ``data.batch['responses']``; every other entry is zero.
        As a side effect, the per-sample scalar rewards are stored in ``data.batch['acc']``.

        Args:
            data (DataProto): Batch with ``prompts``, ``responses`` and ``attention_mask``.
            return_dict (bool): If True, return a dict with ``reward_tensor`` (plus
                ``reward_extra_info`` when scores are computed here).

        Returns:
            torch.Tensor or dict: See ``return_dict``.
        """
        if 'rm_scores' in data.batch.keys():
            rm_scores = data.batch['rm_scores']
            return {"reward_tensor": rm_scores} if return_dict else rm_scores

        reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)
        reward_extra_info = defaultdict(list)
        prompt_ids = data.batch['prompts']
        lengths = self._valid_response_lengths(data)
        sources = data.non_tensor_batch[self.reward_fn_key]

        scores = self.verify(data)
        rewards = []
        printed_per_source = defaultdict(int)

        for idx in range(len(data)):
            length = lengths[idx].item()
            score = scores[idx]

            if isinstance(score, dict):
                reward = score["score"]
                for key, value in score.items():
                    reward_extra_info[key].append(value)
            else:
                reward = score

            rewards.append(reward)
            reward_tensor[idx, length - 1] = reward

            source = sources[idx]
            if printed_per_source[source] < self.num_examine:
                response_text = self.tokenizer.decode(data.batch['responses'][idx][:length],
                                                      skip_special_tokens=True)
                prompt_text = self.tokenizer.decode(data.batch['prompts'][idx], skip_special_tokens=True)
                ground_truth = data[idx].non_tensor_batch['reward_model'].get('ground_truth', None)
                print("[prompt]", prompt_text)
                print("[response]", response_text)
                print("[ground_truth]", ground_truth)
                print("[score]", scores[idx])
                printed_per_source[source] += 1

        data.batch['acc'] = torch.tensor(rewards, dtype=torch.float32, device=prompt_ids.device)

        if return_dict:
            return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info}
        return reward_tensor
108 |
--------------------------------------------------------------------------------
/verl/workers/reward_manager/naive.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from verl import DataProto
16 | from verl.utils.reward_score import _default_compute_score
17 | import torch
18 | from collections import defaultdict
19 |
20 |
class NaiveRewardManager:
    """Reward manager that scores samples one at a time.

    Each valid response is decoded and passed to ``compute_score`` together with the
    sample's data source, ground truth and optional extra info.
    """

    def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key='data_source') -> None:
        self.tokenizer = tokenizer
        # Maximum number of samples printed per data source on each call
        # (the counter is reset at the start of every ``__call__``).
        self.num_examine = num_examine
        self.compute_score = compute_score or _default_compute_score
        # Key in ``non_tensor_batch`` that identifies each sample's data source.
        self.reward_fn_key = reward_fn_key

    def __call__(self, data: DataProto, return_dict=False):
        """Build a token-level reward tensor for ``data``.

        If ``data.batch`` already holds ``rm_scores``, they are returned unchanged.
        Otherwise every sample is scored with ``compute_score`` and the score is written at
        the sample's last valid response position of a float32 tensor shaped like
        ``data.batch['responses']``; all other entries are zero.

        Args:
            data (DataProto): Batch with ``prompts``, ``responses`` and ``attention_mask``
                tensors and a ``reward_model`` entry (holding ``ground_truth``) in
                ``non_tensor_batch``.
            return_dict (bool): If True, return a dict with ``reward_tensor`` (plus
                ``reward_extra_info`` when scores are computed here).

        Returns:
            torch.Tensor or dict: See ``return_dict``.
        """
        # Rewards already produced by a reward model take precedence.
        if 'rm_scores' in data.batch.keys():
            if return_dict:
                return {"reward_tensor": data.batch['rm_scores']}
            else:
                return data.batch['rm_scores']

        reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)
        reward_extra_info = defaultdict(list)

        already_print_data_sources = {}

        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch['prompts']
            prompt_length = prompt_ids.shape[-1]
            attention_mask = data_item.batch['attention_mask']

            # The response part of the mask starts right after the (padded) prompt width.
            response_ids = data_item.batch['responses']
            valid_response_length = attention_mask[prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]
            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)

            ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']
            data_source = data_item.non_tensor_batch[self.reward_fn_key]
            extra_info = data_item.non_tensor_batch.get('extra_info', None)

            score = self.compute_score(
                data_source=data_source,
                solution_str=response_str,
                ground_truth=ground_truth,
                extra_info=extra_info,
            )

            if isinstance(score, dict):
                reward = score["score"]
                # Store every returned field, including the original reward.
                for key, value in score.items():
                    reward_extra_info[key].append(value)
            else:
                reward = score

            # Only the last valid response token receives the score.
            reward_tensor[i, valid_response_length - 1] = reward

            printed = already_print_data_sources.get(data_source, 0)
            if printed < self.num_examine:
                already_print_data_sources[data_source] = printed + 1
                # The prompt text is only needed for logging, so decode it here instead of
                # for every sample. Slicing from the end assumes left-padded prompts — TODO confirm.
                valid_prompt_length = attention_mask[:prompt_length].sum()
                valid_prompt_ids = prompt_ids[-valid_prompt_length:]
                prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)

                print("[prompt]", prompt_str)
                print("[response]", response_str)
                print("[ground_truth]", ground_truth)
                if isinstance(score, dict):
                    for key, value in score.items():
                        print(f"[{key}]", value)
                else:
                    print("[score]", score)

        if return_dict:
            return {
                "reward_tensor": reward_tensor,
                "reward_extra_info": reward_extra_info,
            }
        else:
            return reward_tensor
108 |
--------------------------------------------------------------------------------
/verl/workers/reward_model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import BasePPORewardModel
16 |
--------------------------------------------------------------------------------
/verl/workers/reward_model/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The base class for reward model
16 | """
17 |
18 | from abc import ABC, abstractmethod
19 |
20 | from verl import DataProto
21 |
22 |
class BasePPORewardModel(ABC):
    """Abstract interface for a reward model used during PPO training."""

    def __init__(self, config):
        """Keep the reward-model configuration available as ``self.config``."""
        self.config = config

    @abstractmethod
    def compute_reward(self, data: DataProto) -> DataProto:
        """Compute rewards for a batch of sequences.

        Implementations are expected to produce a ``[batch_size, sequence_length]`` output
        and gather the value at the [EOS] position.

        Args:
            data: Must contain the keys "input_ids", "attention_mask" and "position_ids",
                each of shape ``[batch_size, sequence_length]``.

        Returns:
            A DataProto containing "reward" of shape ``[batch_size, sequence_length]``.
            Only the [EOS] position holds the reward; all other positions are zero. This
            may change in the future if dense rewards are used, which is why the interface
            is kept general.
        """
        ...
46 |
--------------------------------------------------------------------------------
/verl/workers/reward_model/megatron/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .reward_model import MegatronRewardModel
16 |
--------------------------------------------------------------------------------
/verl/workers/rollout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .base import BaseRollout
16 | from .naive import NaiveRollout
17 | from .hf_rollout import HFRollout
18 |
# Public names re-exported by this package.
__all__ = [
    "BaseRollout",
    "NaiveRollout",
    "HFRollout",
]
20 |
--------------------------------------------------------------------------------
/verl/workers/rollout/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from abc import ABC, abstractmethod
16 | from typing import Iterable, Union
17 |
18 | from verl import DataProto
19 |
__all__ = ['BaseRollout']


class BaseRollout(ABC):
    """Abstract interface for rollout (sequence generation) engines."""

    def __init__(self):
        """Initialize the base class. Takes no arguments."""
        super().__init__()

    @abstractmethod
    def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Generate sequences for a batch of prompts.

        Args:
            prompts (DataProto): The batch of prompts to generate from.

        Returns:
            DataProto: The generated sequences.
        """
        pass
38 |
--------------------------------------------------------------------------------
/verl/workers/rollout/naive/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .naive_rollout import NaiveRollout
16 |
--------------------------------------------------------------------------------
/verl/workers/rollout/sglang_rollout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 |
14 | from .sglang_rollout import SGLangRollout
15 |
--------------------------------------------------------------------------------
/verl/workers/rollout/vllm_rollout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from importlib.metadata import version, PackageNotFoundError
16 |
17 | ###
18 | # [SUPPORT AMD:]
19 | import torch
20 | ###
21 |
22 |
def get_version(pkg):
    """Return the installed version string of distribution ``pkg``, or None if it is absent."""
    installed = None
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        pass
    return installed
28 |
29 |
import re

package_name = 'vllm'
package_version = get_version(package_name)
if package_version is None:
    # Fail with a clear message instead of a TypeError from the version check below.
    raise ImportError(f"'{package_name}' is not installed; it is required by verl.workers.rollout.vllm_rollout.")

###
# [SUPPORT AMD:]
# On AMD/ROCm, keep only the leading numeric release of the version string
# (e.g. '0.6.3' out of '0.6.3<suffix>'). ``torch.version.hip`` identifies a ROCm build
# of torch without touching a device; the device-name check only runs when a GPU is
# visible because ``torch.cuda.get_device_name()`` can raise in a GPU-less process.
_on_amd = getattr(torch.version, 'hip', None) is not None or (torch.cuda.is_available() and
                                                              "AMD" in torch.cuda.get_device_name())
if _on_amd:
    _release_match = re.match(r'(\d+\.\d+\.?\d*)', package_version)
    if _release_match is not None:
        package_version = _release_match.group(1)
###


def _is_legacy_vllm(ver):
    """Return True if vLLM version ``ver`` should use the customized (<= 0.6.3) integration.

    Release components are compared numerically, so '0.10.0' counts as newer than
    '0.6.3' (a plain string comparison gets this wrong). A version with any suffix after
    the release '0.6.3' (e.g. '0.6.3.post1') is treated as newer than '0.6.3'.
    """
    match = re.match(r'(\d+(?:\.\d+)*)(.*)$', ver)
    if match is None:
        # Unparseable version string: fall back to a plain string comparison.
        return ver <= '0.6.3'
    release = tuple(int(part) for part in match.group(1).split('.'))
    return release < (0, 6, 3) or (release == (0, 6, 3) and not match.group(2))


if _is_legacy_vllm(package_version):
    vllm_mode = 'customized'
    from .vllm_rollout import vLLMRollout
    from .fire_vllm_rollout import FIREvLLMRollout
else:
    vllm_mode = 'spmd'
    from .vllm_rollout_spmd import vLLMRollout
51 |
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from verl.utils.import_utils import (
16 | is_vllm_available,
17 | is_sglang_available,
18 | is_megatron_core_available,
19 | )
20 |
21 | from .base import BaseShardingManager
22 | from .fsdp_ulysses import FSDPUlyssesShardingManager
23 |
# Megatron+vLLM sharding is importable only when both backends are installed; otherwise
# the names are still exported, as None placeholders.
if is_megatron_core_available() and is_vllm_available():
    from .megatron_vllm import AllGatherPPModel, MegatronVLLMShardingManager
else:
    AllGatherPPModel = None
    MegatronVLLMShardingManager = None

# FSDP+vLLM sharding needs only vLLM; export a None placeholder when it is missing.
if is_vllm_available():
    from .fsdp_vllm import FSDPVLLMShardingManager
else:
    FSDPVLLMShardingManager = None
38 |
39 | # NOTE(linjunrong): Due to recent fp8 support in SGLang, importing any symbol related to SGLang's model_runner now checks CUDA device capability.
40 | # However, in veRL's setup the main Ray process cannot see any CUDA device, which can lead to:
41 | # "RuntimeError: No CUDA GPUs are available".
42 | # For this reason, sharding_manager.__init__ must not import FSDPSGLangShardingManager; users need to import it via its absolute module path.
43 | # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
44 | # if is_sglang_available():
45 | # from .fsdp.fsdp_sglang import FSDPSGLangShardingManager
46 | # else:
47 | # FSDPSGLangShardingManager = None
48 |
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Sharding manager to implement HybridEngine
16 | """
17 |
18 | from verl import DataProto
19 |
20 |
class BaseShardingManager:
    """No-op sharding manager.

    It works as a context manager whose enter/exit steps do nothing. Its
    ``preprocess_data`` / ``postprocess_data`` hooks hand their input back
    unchanged. Concrete managers subclass this and override what they need.
    """

    def __enter__(self):
        """Enter the context; the base class has nothing to set up."""
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the context.

        Returns None (falsy), so any exception raised inside the ``with``
        body propagates to the caller.
        """
        return None

    def preprocess_data(self, data: DataProto) -> DataProto:
        """Return ``data`` unchanged; subclasses may override."""
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Return ``data`` unchanged; subclasses may override."""
        return data
34 |
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/fsdp_ulysses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Contains a sharding manager that reshards data across the Ulysses sequence-parallel (SP) group when training with FSDP
16 | """
17 | from .base import BaseShardingManager
18 |
19 | from torch.distributed.device_mesh import DeviceMesh
20 |
21 | from verl.utils.torch_functional import allgather_dict_tensors
22 | from verl.protocol import all_gather_data_proto
23 | from verl.utils.ulysses import set_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_group
24 | import numpy as np
25 |
26 | import torch
27 | import torch.distributed
28 |
29 | from verl import DataProto
30 |
31 |
class FSDPUlyssesShardingManager(BaseShardingManager):
    """
    Sharding manager to support data resharding when using FSDP + Ulysses
    sequence parallelism (SP).

    Used as a context manager, it installs the SP group taken from
    ``device_mesh['sp']`` as the current Ulysses SP group and restores the
    previous group on exit. ``preprocess_data`` all-gathers data across that
    SP group, and ``postprocess_data`` splits it back into per-rank chunks.
    When ``device_mesh`` is None, every method is a no-op.
    """

    def __init__(self, device_mesh: DeviceMesh):
        """
        Args:
            device_mesh: mesh providing an ``'sp'`` dimension, or None to
                disable all SP-related behaviour.
        """
        super().__init__()
        self.device_mesh = device_mesh
        # NOTE(review): not read anywhere in this class; presumably reserved
        # for per-model seeding (see the TODOs below). Kept for compatibility.
        self.seed_offset = 12345

    def __enter__(self):
        """Install this model's SP group, remembering the previous one."""
        if self.device_mesh is not None:
            # A global SP group may already be set, so switch to the
            # model-specific SP group for the duration of the context.
            self.prev_sp_group = get_ulysses_sequence_parallel_group()
            set_ulysses_sequence_parallel_group(self.device_mesh['sp'].get_group())
            # TODO: check how to set seed for each model

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the SP group saved by ``__enter__``.

        Returns None, so exceptions raised inside the ``with`` body propagate.
        """
        if self.device_mesh is not None:
            set_ulysses_sequence_parallel_group(self.prev_sp_group)
            # TODO: check how to set seed for each model

    def preprocess_data(self, data: DataProto) -> DataProto:
        """
        AllGather data from the SP region.

        The data is first sharded along the FSDP dimension (DP_COMPUTE
        dispatch), but Ulysses needs the same data on every rank of an SP
        group, so the shards are all-gathered across that group.
        """
        if self.device_mesh is not None:
            group = self.device_mesh['sp'].get_group()
            # The return value is ignored and `data` is returned below, so
            # all_gather_data_proto presumably updates `data` in place -- confirm.
            all_gather_data_proto(data=data, process_group=group)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """
        Split the data back to follow the FSDP partition.

        The data is cut into ``sp_size`` chunks, and this rank keeps the chunk
        at its local SP rank, undoing the all-gather in ``preprocess_data``.
        """
        if self.device_mesh is not None:
            sp_size = self.device_mesh['sp'].size()
            sp_rank = self.device_mesh['sp'].get_local_rank()
            data = data.chunk(chunks=sp_size)[sp_rank]
        return data
--------------------------------------------------------------------------------
/verl/workers/sharding_manager/patch/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .fsdp_vllm_patch import patched_ds_v3_load_weights
16 |
--------------------------------------------------------------------------------