├── .gitignore
├── LICENSE
├── README.md
├── assets
    ├── eval_bar.png
    └── overview.png
├── data
    ├── download_dataset.sh
    └── prepare_musique_recall.py
├── scripts
    ├── evaluation
    │   ├── eval_config.yaml
    │   ├── llm_judge.py
    │   ├── run_eval.py
    │   └── utils.py
    ├── inference
    │   └── re_call_use_case.py
    ├── serving
    │   ├── retriever_config.yaml
    │   ├── retriever_serving.py
    │   └── sandbox.py
    └── train
    │   ├── get_domain_ip.py
    │   ├── get_host_ip.py
    │   ├── train.sh
    │   └── train_multi_node.sh
├── setup.py
└── src
    ├── flashrag
        ├── __init__.py
        ├── config
        │   ├── __init__.py
        │   ├── basic_config.yaml
        │   └── config.py
        ├── dataset
        │   ├── __init__.py
        │   ├── dataset.py
        │   └── utils.py
        ├── evaluator
        │   ├── __init__.py
        │   ├── _bleu.py
        │   ├── evaluator.py
        │   ├── metrics.py
        │   └── utils.py
        ├── generator
        │   ├── __init__.py
        │   ├── fid.py
        │   ├── generator.py
        │   ├── multimodal_generator.py
        │   ├── openai_generator.py
        │   ├── stop_word_criteria.py
        │   └── utils.py
        ├── judger
        │   ├── __init__.py
        │   └── judger.py
        ├── pipeline
        │   ├── __init__.py
        │   ├── active_pipeline.py
        │   ├── branching_pipeline.py
        │   ├── mm_pipeline.py
        │   ├── pipeline.py
        │   └── replug_utils.py
        ├── prompt
        │   ├── __init__.py
        │   ├── base_prompt.py
        │   ├── mm_prompt.py
        │   ├── selfask_examplars.py
        │   └── trace_examplars.py
        ├── refiner
        │   ├── __init__.py
        │   ├── kg_refiner.py
        │   ├── llmlingua_compressor.py
        │   ├── refiner.py
        │   └── selective_context_compressor.py
        ├── retriever
        │   ├── __init__.py
        │   ├── __main__.py
        │   ├── encoder.py
        │   ├── index_builder.py
        │   ├── reranker.py
        │   ├── retriever.py
        │   └── utils.py
        ├── utils
        │   ├── __init__.py
        │   ├── constants.py
        │   ├── pred_parse.py
        │   └── utils.py
        └── version.py
    ├── re_call
        ├── __init__.py
        └── inference
        │   ├── __init__.py
        │   └── re_call.py
    ├── verl
        ├── __init__.py
        ├── models
        │   ├── README.md
        │   ├── __init__.py
        │   ├── llama
        │   │   ├── __init__.py
        │   │   └── megatron
        │   │   │   ├── __init__.py
        │   │   │   ├── checkpoint_utils
        │   │   │   │   ├── __init__.py
        │   │   │   │   ├── llama_loader.py
        │   │   │   │   ├── llama_loader_depracated.py
        │   │   │   │   └── llama_saver.py
        │   │   │   ├── layers
        │   │   │   │   ├── __init__.py
        │   │   │   │   ├── parallel_attention.py
        │   │   │   │   ├── parallel_decoder.py
        │   │   │   │   ├── parallel_linear.py
        │   │   │   │   ├── parallel_mlp.py
        │   │   │   │   └── parallel_rmsnorm.py
        │   │   │   └── modeling_llama_megatron.py
        │   ├── mcore
        │   │   ├── __init__.py
        │   │   ├── gpt_model.py
        │   │   ├── loader.py
        │   │   ├── readme.md
        │   │   └── saver.py
        │   ├── qwen2
        │   │   ├── __init__.py
        │   │   └── megatron
        │   │   │   ├── __init__.py
        │   │   │   ├── checkpoint_utils
        │   │   │   │   ├── __init__.py
        │   │   │   │   ├── qwen2_loader.py
        │   │   │   │   ├── qwen2_loader_depracated.py
        │   │   │   │   └── qwen2_saver.py
        │   │   │   ├── layers
        │   │   │   │   ├── __init__.py
        │   │   │   │   ├── parallel_attention.py
        │   │   │   │   ├── parallel_decoder.py
        │   │   │   │   ├── parallel_linear.py
        │   │   │   │   ├── parallel_mlp.py
        │   │   │   │   └── parallel_rmsnorm.py
        │   │   │   └── modeling_qwen2_megatron.py
        │   ├── registry.py
        │   ├── transformers
        │   │   ├── __init__.py
        │   │   ├── llama.py
        │   │   ├── monkey_patch.py
        │   │   ├── qwen2.py
        │   │   └── qwen2_vl.py
        │   └── weight_loader_registry.py
        ├── protocol.py
        ├── single_controller
        │   ├── __init__.py
        │   ├── base
        │   │   ├── __init__.py
        │   │   ├── decorator.py
        │   │   ├── megatron
        │   │   │   ├── __init__.py
        │   │   │   ├── worker.py
        │   │   │   └── worker_group.py
        │   │   ├── register_center
        │   │   │   ├── __init__.py
        │   │   │   └── ray.py
        │   │   ├── worker.py
        │   │   └── worker_group.py
        │   └── ray
        │   │   ├── __init__.py
        │   │   ├── base.py
        │   │   └── megatron.py
        ├── third_party
        │   ├── __init__.py
        │   ├── sglang
        │   │   ├── __init__.py
        │   │   └── parallel_state.py
        │   └── vllm
        │   │   ├── __init__.py
        │   │   ├── vllm_v_0_3_1
        │   │   │   ├── __init__.py
        │   │   │   ├── arg_utils.py
        │   │   │   ├── config.py
        │   │   │   ├── llm.py
        │   │   │   ├── llm_engine_sp.py
        │   │   │   ├── model_loader.py
        │   │   │   ├── model_runner.py
        │   │   │   ├── parallel_state.py
        │   │   │   ├── tokenizer.py
        │   │   │   ├── weight_loaders.py
        │   │   │   └── worker.py
        │   │   ├── vllm_v_0_4_2
        │   │   │   ├── __init__.py
        │   │   │   ├── arg_utils.py
        │   │   │   ├── config.py
        │   │   │   ├── dtensor_weight_loaders.py
        │   │   │   ├── hf_weight_loader.py
        │   │   │   ├── llm.py
        │   │   │   ├── llm_engine_sp.py
        │   │   │   ├── megatron_weight_loaders.py
        │   │   │   ├── model_loader.py
        │   │   │   ├── model_runner.py
        │   │   │   ├── parallel_state.py
        │   │   │   ├── spmd_gpu_executor.py
        │   │   │   ├── tokenizer.py
        │   │   │   └── worker.py
        │   │   ├── vllm_v_0_5_4
        │   │   │   ├── __init__.py
        │   │   │   ├── arg_utils.py
        │   │   │   ├── config.py
        │   │   │   ├── dtensor_weight_loaders.py
        │   │   │   ├── hf_weight_loader.py
        │   │   │   ├── llm.py
        │   │   │   ├── llm_engine_sp.py
        │   │   │   ├── megatron_weight_loaders.py
        │   │   │   ├── model_loader.py
        │   │   │   ├── model_runner.py
        │   │   │   ├── parallel_state.py
        │   │   │   ├── spmd_gpu_executor.py
        │   │   │   ├── tokenizer.py
        │   │   │   └── worker.py
        │   │   └── vllm_v_0_6_3
        │   │   │   ├── __init__.py
        │   │   │   ├── arg_utils.py
        │   │   │   ├── config.py
        │   │   │   ├── dtensor_weight_loaders.py
        │   │   │   ├── hf_weight_loader.py
        │   │   │   ├── llm.py
        │   │   │   ├── llm_engine_sp.py
        │   │   │   ├── megatron_weight_loaders.py
        │   │   │   ├── model_loader.py
        │   │   │   ├── model_runner.py
        │   │   │   ├── parallel_state.py
        │   │   │   ├── spmd_gpu_executor.py
        │   │   │   ├── tokenizer.py
        │   │   │   └── worker.py
        ├── trainer
        │   ├── __init__.py
        │   ├── config
        │   │   ├── evaluation.yaml
        │   │   ├── generation.yaml
        │   │   ├── ppo_megatron_trainer.yaml
        │   │   ├── ppo_trainer.yaml
        │   │   └── sft_trainer.yaml
        │   ├── fsdp_sft_trainer.py
        │   ├── main_eval.py
        │   ├── main_generation.py
        │   ├── main_ppo.py
        │   ├── ppo
        │   │   ├── __init__.py
        │   │   ├── core_algos.py
        │   │   ├── metric_utils.py
        │   │   └── ray_trainer.py
        │   └── runtime_env.yaml
        ├── utils
        │   ├── __init__.py
        │   ├── checkpoint
        │   │   ├── __init__.py
        │   │   ├── checkpoint_manager.py
        │   │   ├── fsdp_checkpoint_manager.py
        │   │   └── megatron_checkpoint_manager.py
        │   ├── config.py
        │   ├── dataset
        │   │   ├── README.md
        │   │   ├── __init__.py
        │   │   ├── multiturn_sft_dataset.py
        │   │   ├── rl_dataset.py
        │   │   ├── rm_dataset.py
        │   │   ├── sft_dataset.py
        │   │   └── template.py
        │   ├── debug
        │   │   ├── __init__.py
        │   │   ├── performance.py
        │   │   └── trajectory_tracker.py
        │   ├── distributed.py
        │   ├── flops_counter.py
        │   ├── fs.py
        │   ├── fsdp_utils.py
        │   ├── hdfs_io.py
        │   ├── import_utils.py
        │   ├── logger
        │   │   ├── __init__.py
        │   │   └── aggregate_logger.py
        │   ├── logging_utils.py
        │   ├── megatron
        │   │   ├── __init__.py
        │   │   ├── memory.py
        │   │   ├── optimizer.py
        │   │   ├── pipeline_parallel.py
        │   │   ├── sequence_parallel.py
        │   │   └── tensor_parallel.py
        │   ├── megatron_utils.py
        │   ├── memory_buffer.py
        │   ├── model.py
        │   ├── py_functional.py
        │   ├── ray_utils.py
        │   ├── rendezvous
        │   │   ├── __init__.py
        │   │   └── ray_backend.py
        │   ├── reward_score
        │   │   ├── __init__.py
        │   │   ├── geo3k.py
        │   │   ├── gsm8k.py
        │   │   ├── math.py
        │   │   ├── math_batch.py
        │   │   ├── math_dapo.py
        │   │   ├── math_verify.py
        │   │   ├── prime_code
        │   │   │   ├── __init__.py
        │   │   │   ├── testing_util.py
        │   │   │   └── utils.py
        │   │   ├── prime_math
        │   │   │   ├── __init__.py
        │   │   │   ├── grader.py
        │   │   │   └── math_normalize.py
        │   │   └── re_call.py
        │   ├── seqlen_balancing.py
        │   ├── tokenizer.py
        │   ├── torch_dtypes.py
        │   ├── torch_functional.py
        │   ├── tracking.py
        │   └── ulysses.py
        ├── version
        │   └── version
        └── workers
        │   ├── __init__.py
        │   ├── actor
        │   │   ├── __init__.py
        │   │   ├── base.py
        │   │   ├── dp_actor.py
        │   │   └── megatron_actor.py
        │   ├── critic
        │   │   ├── __init__.py
        │   │   ├── base.py
        │   │   ├── dp_critic.py
        │   │   └── megatron_critic.py
        │   ├── fsdp_workers.py
        │   ├── megatron_workers.py
        │   ├── reward_manager
        │   │   ├── __init__.py
        │   │   ├── batch.py
        │   │   ├── dapo.py
        │   │   ├── naive.py
        │   │   ├── prime.py
        │   │   └── re_call.py
        │   ├── reward_model
        │   │   ├── __init__.py
        │   │   ├── base.py
        │   │   └── megatron
        │   │   │   ├── __init__.py
        │   │   │   └── reward_model.py
        │   ├── rollout
        │   │   ├── __init__.py
        │   │   ├── base.py
        │   │   ├── hf_rollout.py
        │   │   ├── naive
        │   │   │   ├── __init__.py
        │   │   │   └── naive_rollout.py
        │   │   ├── sglang_rollout
        │   │   │   ├── __init__.py
        │   │   │   └── sglang_rollout.py
        │   │   ├── tokenizer.py
        │   │   └── vllm_rollout
        │   │   │   ├── __init__.py
        │   │   │   ├── fire_vllm_rollout.py
        │   │   │   ├── vllm_rollout.py
        │   │   │   └── vllm_rollout_spmd.py
        │   └── sharding_manager
        │   │   ├── __init__.py
        │   │   ├── base.py
        │   │   ├── fsdp_sglang.py
        │   │   ├── fsdp_ulysses.py
        │   │   ├── fsdp_vllm.py
        │   │   ├── megatron_vllm.py
        │   │   └── patch
        │   │   │   ├── __init__.py
        │   │   │   └── fsdp_vllm_patch.py
    └── version
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# vscode
.vscode

# Mac
.DS_Store

notebooks/
output/
pyproject.toml
uv.lock
.env
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Agent-RL

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/assets/eval_bar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Agent-RL/ReCall/aaf16b31c83702cdd3c7f7a0ad431efce7904b63/assets/eval_bar.png

--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Agent-RL/ReCall/aaf16b31c83702cdd3c7f7a0ad431efce7904b63/assets/overview.png

--------------------------------------------------------------------------------
/data/download_dataset.sh:
--------------------------------------------------------------------------------
mkdir -p hotpotqa
wget https://huggingface.co/datasets/RUC-NLPIR/FlashRAG_datasets/resolve/main/hotpotqa/train.jsonl -O hotpotqa/train.jsonl
wget https://huggingface.co/datasets/RUC-NLPIR/FlashRAG_datasets/resolve/main/hotpotqa/dev.jsonl -O hotpotqa/dev.jsonl

mkdir -p 2wikimultihopqa
wget https://huggingface.co/datasets/RUC-NLPIR/FlashRAG_datasets/resolve/main/2wikimultihopqa/train.jsonl -O 2wikimultihopqa/train.jsonl
wget https://huggingface.co/datasets/RUC-NLPIR/FlashRAG_datasets/resolve/main/2wikimultihopqa/dev.jsonl -O 2wikimultihopqa/dev.jsonl

mkdir -p musique
wget https://huggingface.co/datasets/RUC-NLPIR/FlashRAG_datasets/resolve/main/musique/train.jsonl -O musique/train.jsonl
wget https://huggingface.co/datasets/RUC-NLPIR/FlashRAG_datasets/resolve/main/musique/dev.jsonl -O musique/dev.jsonl

mkdir -p bamboogle
wget https://huggingface.co/datasets/RUC-NLPIR/FlashRAG_datasets/resolve/main/bamboogle/test.jsonl -O bamboogle/test.jsonl

--------------------------------------------------------------------------------
/data/prepare_musique_recall.py:
--------------------------------------------------------------------------------
import os
import json
import datasets
import jsonlines
import argparse
import random
random.seed(42)

wikipedia_search_env = """import requests

def wikipedia_search(query: str, top_n: int = 5):
    # NOTE: "/search" is a bare path; the retriever serving URL must be prepended here
    url = "/search"

    if query == '':
        return 'invalid query'

    data = {'query': query, 'top_n': top_n}
    response = requests.post(url, json=data)
    retrieval_text = ''
    for line in response.json():
        retrieval_text += f"{line['contents']}\\n\\n"
    retrieval_text = retrieval_text.strip()

    return retrieval_text"""

wikipedia_search_schemas = [{
    "type": "function",
    "function": {
        "name": "wikipedia_search",
        "description": "Search Wikipedia for a given query.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Query to search for."
                },
                "top_n": {
                    "type": "integer",
                    "description": "Number of results to return. The default value is 5.",
                    "default": 5
                }
            },
            "required": ["query"]
        }
    }
}
]
wikipedia_search_schemas = json.dumps(wikipedia_search_schemas, indent=4)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input_dir',
        help='directory of the input data, e.g. the musique dir; see data/download_dataset.sh for downloading the data'
    )
    parser.add_argument(
        '--output_dir',
        help='directory for the output parquet data'
    )

    args = parser.parse_args()

    train_data_path = os.path.join(args.input_dir, 'train.jsonl')
    lines = []
    with jsonlines.open(train_data_path) as reader:
        for line in reader:
            lines.append(line)
    train_data = []
    for line in lines:
        train_data.append({
            "data_source": "musique_re_call",
            "question": line['question'],
            "ability": "re_call",
            "reward_model": {
                "style": "rule",
                "ground_truth": line['golden_answers']
            },
            "extra_info": {
                "id": line['id'],
                "env": wikipedia_search_env,
                "func_schemas": wikipedia_search_schemas
            }
        })

    dev_data_path = os.path.join(args.input_dir, 'dev.jsonl')
    lines = []
    with jsonlines.open(dev_data_path) as reader:
        for line in reader:
            lines.append(line)
    dev_data = []
    random.shuffle(lines)
    for line in lines[:100]:
        dev_data.append({
            "data_source": "musique_re_call",
            "question": line['question'],
            "ability": "re_call",
            "reward_model": {
                "style": "rule",
                "ground_truth": line['golden_answers']
            },
            "extra_info": {
                "id": line['id'],
                "env": wikipedia_search_env,
                "func_schemas": wikipedia_search_schemas
            }
        })

    train_dataset = datasets.Dataset.from_list(train_data)
    test_dataset = datasets.Dataset.from_list(dev_data)

    train_dataset.to_parquet(os.path.join(args.output_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(args.output_dir, 'test.parquet'))
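As a quick sanity check of the prepared data, the parquet can be read back with pandas, the same access pattern scripts/inference/re_call_use_case.py uses later. A minimal sketch; the output path is illustrative:

```python
import json
import pandas as pd

df = pd.read_parquet("musique_re_call/train.parquet")  # hypothetical --output_dir
row = df.iloc[0]
print(row["question"])
print(row["reward_model"]["ground_truth"])
print(row["extra_info"]["env"][:60])  # the wikipedia_search source that the sandbox will exec
print(json.loads(row["extra_info"]["func_schemas"])[0]["function"]["name"])  # 'wikipedia_search'
```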
--------------------------------------------------------------------------------
/scripts/evaluation/eval_config.yaml:
--------------------------------------------------------------------------------
# ------------------------------------------------Environment Settings------------------------------------------------#
# Directory paths for data and outputs
data_dir: "your-data-dir"
save_dir: "your-save-dir"

# Seed for reproducibility
seed: 2024

# Whether to save intermediate data
save_intermediate_data: True
save_note: 'experiment'

# -------------------------------------------------Retrieval Settings------------------------------------------------#
# If the remote url is set, a remote retriever is used and the following settings are ignored
use_remote_retriever: True
remote_retriever_url: "your-remote-retriever-url"

instruction: ~ # instruction for the retrieval model
retrieval_topk: 5 # number of retrieved documents
retrieval_batch_size: 256 # batch size for retrieval
retrieval_use_fp16: True # whether to use fp16 for the retrieval model
retrieval_query_max_length: 128 # max length of the query
save_retrieval_cache: False # whether to save the retrieval cache
use_retrieval_cache: False # whether to use the retrieval cache
retrieval_cache_path: ~ # path to the retrieval cache
retrieval_pooling_method: ~ # set automatically if not provided

# -------------------------------------------------Generator Settings------------------------------------------------#
framework: sgl_remote # inference framework of the LLM, supporting: 'hf', 'vllm', 'fschat'
sgl_remote_url: "your-sgl-remote-url"
sandbox_url: "your-sandbox-url"
generator_model: "the-model-local-path" # name or path of the generator model, for loading the tokenizer
generator_max_input_len: 8192 # max length of the input
generation_params:
  do_sample: False
  max_tokens: 8192

# -------------------------------------------------Evaluation Settings------------------------------------------------#
# Metrics to evaluate the result
metrics: ['em', 'f1', 'acc', 'precision', 'recall']
# Specific settings for metrics, used within certain metrics
metric_setting:
  retrieval_recall_topk: 5
save_metric_score: True # whether to save the metric score into a txt file
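For reference, this file parses to a plain dictionary; a rough sketch with PyYAML is below (FlashRAG's own Config class in src/flashrag/config is the canonical loader):

```python
import yaml

with open("scripts/evaluation/eval_config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["metrics"])                          # ['em', 'f1', 'acc', 'precision', 'recall']
print(cfg["generation_params"]["max_tokens"])  # 8192
print(cfg["retrieval_pooling_method"])         # None ('~' parses to null)
```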
--------------------------------------------------------------------------------
/scripts/evaluation/utils.py:
--------------------------------------------------------------------------------
import time
import json
import logging
import jsonlines
from tqdm import tqdm
from typing import Union
from functools import wraps
from concurrent.futures import ThreadPoolExecutor, as_completed

def execute(func, input_list_or_num_samples: Union[list, int], output_path: str, max_workers: int, logger: logging.Logger):
    out = open(output_path, 'a')
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        if isinstance(input_list_or_num_samples, list):
            futures = [executor.submit(func, item) for item in input_list_or_num_samples]
        else:
            futures = [executor.submit(func) for _ in range(input_list_or_num_samples)]
        for future in tqdm(as_completed(futures), total=len(futures)):
            try:
                res: dict = future.result()
            except Exception as e:
                logger.info(f"[error] {e}")
                continue
            if res:
                out.write(json.dumps(res, ensure_ascii=False) + '\n')
                out.flush()
    out.close()

def retry(max: int = 10, sleep: int = 1, logger: logging.Logger = None):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for i in range(max):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if logger:
                        logger.info(f"[retry] attempt {i + 1}/{max} failed")
                    if i == max - 1:
                        raise Exception("Error: {}. {} still failing after {} attempts".format(e, func.__name__, max))
                    elif sleep:
                        time.sleep(sleep)
        return wrapper
    return decorator

def init_logger(log_path: str, log_name: str):
    logger = logging.getLogger(log_name)
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(log_path)
    file_handler.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    file_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    return logger
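A hedged usage sketch of the two helpers above: wrap a flaky call with @retry, then fan it out over a list of inputs with execute(). The judge_one function and the file names are illustrative, not part of the repo:

```python
logger = init_logger("eval.log", "eval")

@retry(max=3, sleep=2, logger=logger)
def judge_one(item: dict) -> dict:
    # e.g. call an LLM judge here; any exception triggers a retry
    return {"id": item["id"], "score": 1.0}

items = [{"id": i} for i in range(8)]
execute(judge_one, items, output_path="judged.jsonl", max_workers=4, logger=logger)
```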
--------------------------------------------------------------------------------
/scripts/inference/re_call_use_case.py:
--------------------------------------------------------------------------------
import pandas as pd
from re_call import ReCall

model_url = ""
sandbox_url = ""
data_path = ""

# load some data
test_lines = []
test_data = pd.read_parquet(data_path)
for _, row in test_data.iterrows():
    test_lines.append({
        'question': row['question'],
        'answer': row['reward_model']['ground_truth'],
        'env': row['extra_info']['env'],
        'func_schemas': row['extra_info']['func_schemas'],
    })

# initialize the re_call model
re_call = ReCall(model_url, sandbox_url)

# run the re_call model
response = re_call.run(test_lines[1]['env'], test_lines[1]['func_schemas'], test_lines[1]['question'])
print(response)

--------------------------------------------------------------------------------
/scripts/serving/retriever_config.yaml:
--------------------------------------------------------------------------------
# ------------------------------------------------Environment Settings------------------------------------------------#
gpu_id: "0,1,2,3,4,5,6,7"

# -------------------------------------------------Retrieval Settings------------------------------------------------#
# If a model name is given, its path will be looked up in the global paths
retrieval_method: "/path/to/retrieval/model" # name or path of the retrieval model
index_path: "/path/to/index" # path to the indexed file
faiss_gpu: True # whether to use a gpu to hold the index
corpus_path: "/path/to/corpus" # path to the corpus in '.jsonl' format that stores the documents
--------------------------------------------------------------------------------
/scripts/serving/sandbox.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI
import contextlib
import io
from typing import Dict
from pydantic import BaseModel
from argparse import ArgumentParser

app = FastAPI()

class CodeRequest(BaseModel):
    env: str
    call: str
    timeout: int = 5

@app.post("/execute")
async def execute_code(request: CodeRequest) -> Dict:
    output = io.StringIO()
    result = None
    error = None

    print("-" * 30)
    print(request.env)
    print(request.call)
    print("-" * 30)

    try:
        with contextlib.redirect_stdout(output):
            # run the tool environment first, then evaluate the call against it
            exec_env = {}
            exec(compile(request.env, '', 'exec'), exec_env)
            exec(compile(f"response = {request.call}", '', 'exec'), exec_env)
            result = exec_env.get('response')
    except Exception as e:
        error = str(e)

    return {
        "output": output.getvalue(),
        "result": result,
        "error": error
    }

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--port", type=int, default=80)
    args = parser.parse_args()

    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=args.port)
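A client-side sketch of the sandbox contract above, assuming the server was started locally with --port 8000. The env string defines a function; call is the expression the sandbox evaluates against it:

```python
import requests

payload = {
    "env": "def add(a, b):\n    return a + b",
    "call": "add(1, 2)",
}
resp = requests.post("http://localhost:8000/execute", json=payload).json()
print(resp["result"])  # 3
print(resp["error"])   # None on success
```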
--------------------------------------------------------------------------------
/scripts/train/get_domain_ip.py:
--------------------------------------------------------------------------------
import sys
import socket
import time

domain = sys.argv[1]
while True:
    try:
        ip_address = socket.gethostbyname(domain)
        print(ip_address)
        break
    except socket.error as e:
        sys.stderr.write("Error: %s\n" % e)
        time.sleep(5)

--------------------------------------------------------------------------------
/scripts/train/get_host_ip.py:
--------------------------------------------------------------------------------
import socket
import subprocess

def get_host_ip():
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        s.connect(("8.8.8.8", 80))
        return s.getsockname()[0].strip()
    except Exception:
        result = subprocess.run("hostname -I", capture_output=True, text=True, shell=True).stdout
        return result.strip().split()[0]

if __name__ == "__main__":
    print(get_host_ip())

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os
from pathlib import Path
from setuptools import setup, find_packages

this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()

version_folder = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(version_folder, 'src/version')) as f:
    __version__ = f.read().strip()

install_requires = [
    # verl
    'accelerate',
    'codetiming',
    'datasets',
    'dill',
    'hydra-core',
    'numpy',
    'pandas',
    'peft',
    'pyarrow>=15.0.0',
    'pybind11',
    'pylatexenc',
    'ray[default]>=2.10',
    'tensordict<=0.6.2',
    'torchdata',
    'transformers',
    'vllm==0.8.4',
    'wandb',

    # flashrag
    'datasets',
    'base58',
    'nltk',
    'numpy',
    'langid',
    'openai',
    'peft',
    'PyYAML',
    'rank_bm25',
    'rouge',
    'spacy',
    'tiktoken',
    'torch',
    'tqdm',
    'transformers>=4.40.0',
    'bm25s[core]==0.2.0',
    'fschat',
    'streamlit',
    'chonkie>=0.4.0',
    'gradio>=5.0.0',
    'rouge-chinese',
    'jieba',

    # others
    'sglang',
    'jsonlines',
]

setup(
    name='re-call',
    version=__version__,
    package_dir={'': 'src'},
    packages=find_packages(where='src'),
    url='https://github.com/Agent-RL/ReCall',
    license='MIT License',
    author='Baichuan Inc.',
    author_email='chenmingyang@baichuan-inc.com',
    description='ReCall: Learning to Reason with Tool Call for LLMs via Reinforcement Learning',
    install_requires=install_requires,
    package_data={'': ['**/*.yaml']},
    include_package_data=True,
    long_description=long_description,
    long_description_content_type='text/markdown'
)
--------------------------------------------------------------------------------
/src/flashrag/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Agent-RL/ReCall/aaf16b31c83702cdd3c7f7a0ad431efce7904b63/src/flashrag/__init__.py

--------------------------------------------------------------------------------
/src/flashrag/config/__init__.py:
--------------------------------------------------------------------------------
from flashrag.config.config import Config

--------------------------------------------------------------------------------
/src/flashrag/dataset/__init__.py:
--------------------------------------------------------------------------------
from flashrag.dataset.dataset import *

--------------------------------------------------------------------------------
/src/flashrag/dataset/utils.py:
--------------------------------------------------------------------------------
from typing import Any
import numpy as np
from flashrag.dataset import Dataset


def convert_numpy(data: Any) -> Any:
    if isinstance(data, dict):
        return {key: convert_numpy(value) for key, value in data.items()}
    elif isinstance(data, list):
        return [convert_numpy(element) for element in data]
    elif isinstance(data, np.ndarray):
        return data.tolist()
    elif isinstance(data, (np.integer,)):
        return int(data)
    elif isinstance(data, (np.floating,)):
        return float(data)
    elif isinstance(data, (np.bool_)):
        return bool(data)
    elif isinstance(data, (np.str_)):
        return str(data)
    else:
        return data


def filter_dataset(dataset: Dataset, filter_func=None):
    if filter_func is None:
        return dataset
    # build a new list instead of calling remove() on the list being iterated,
    # which would silently skip elements
    data = [item for item in dataset.data if filter_func(item)]
    return Dataset(config=dataset.config, data=data)


def split_dataset(dataset: Dataset, split_symbol: list):
    assert len(split_symbol) == len(dataset)

    data = dataset.data
    data_split = {symbol: [] for symbol in set(split_symbol)}
    for symbol in set(split_symbol):
        symbol_data = [x for x, x_symbol in zip(data, split_symbol) if x_symbol == symbol]
        data_split[symbol] = Dataset(config=dataset.config, data=symbol_data)

    return data_split


def merge_dataset(dataset_split: dict, split_symbol: list):
    assert len(split_symbol) == sum([len(data) for data in dataset_split.values()])
    dataset_split_iter = {symbol: iter(dataset.data) for symbol, dataset in dataset_split.items()}

    final_data = []
    for item_symbol in split_symbol:
        final_data.append(next(dataset_split_iter[item_symbol]))
    final_dataset = Dataset(config=list(dataset_split.values())[0].config, data=final_data)

    return final_dataset


def get_batch_dataset(dataset: Dataset, batch_size=16):
    data = dataset.data
    for idx in range(0, len(data), batch_size):
        batched_data = data[idx: idx + batch_size]
        batch_dataset = Dataset(config=dataset.config, data=batched_data)
        yield batch_dataset


def merge_batch_dataset(dataset_list: list):
    dataset = dataset_list[0]
    total_data = []
    for batch_dataset in dataset_list:
        total_data.extend(batch_dataset.data)
    dataset = Dataset(config=dataset.config, data=total_data)
    return dataset


def remove_images(data: Any) -> Any:
    from PIL import Image
    if isinstance(data, dict):
        return {key: remove_images(value)
                for key, value in data.items()
                if not isinstance(value, Image.Image)}
    elif isinstance(data, list):
        return [remove_images(element)
                for element in data
                if not isinstance(element, Image.Image)]
    elif isinstance(data, tuple):
        return tuple(remove_images(element)
                     for element in data
                     if not isinstance(element, Image.Image))
    elif isinstance(data, set):
        return {remove_images(element)
                for element in data
                if not isinstance(element, Image.Image)}
    else:
        return data


def clean_prompt_image(input):
    try:
        for message in input:
            if isinstance(message.get("content"), list):
                message["content"] = [item for item in message["content"] if item.get("type") != "image"]
        return input
    except Exception:
        return input
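A small sketch of the split/merge round trip above: partition items by a per-item symbol, process each group, then restore the original order. Dataset construction is simplified here (items shown as plain dicts; the config/data kwargs follow the signatures used in this module):

```python
from flashrag.dataset import Dataset
from flashrag.dataset.utils import split_dataset, merge_dataset

ds = Dataset(config=None, data=[{"q": "a"}, {"q": "b"}, {"q": "c"}])
symbols = ["short", "long", "short"]

groups = split_dataset(ds, symbols)      # {'short': Dataset(2 items), 'long': Dataset(1 item)}
merged = merge_dataset(groups, symbols)  # items back in the original order
assert [x["q"] for x in merged.data] == ["a", "b", "c"]
```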
--------------------------------------------------------------------------------
/src/flashrag/evaluator/__init__.py:
--------------------------------------------------------------------------------
from flashrag.evaluator.evaluator import *
from flashrag.evaluator.metrics import *

--------------------------------------------------------------------------------
/src/flashrag/evaluator/evaluator.py:
--------------------------------------------------------------------------------
import os
from flashrag.evaluator.metrics import BaseMetric


class Evaluator:
    """Evaluator is used to summarize the results of all metrics."""

    def __init__(self, config):
        self.config = config
        self.save_dir = config["save_dir"]

        self.save_metric_flag = config["save_metric_score"]
        self.save_data_flag = config["save_intermediate_data"]
        self.metrics = [metric.lower() for metric in self.config["metrics"]]

        self.avaliable_metrics = self._collect_metrics()

        self.metric_class = {}
        for metric in self.metrics:
            if metric in self.avaliable_metrics:
                self.metric_class[metric] = self.avaliable_metrics[metric](self.config)
            else:
                print(f"{metric} has not been implemented!")
                raise NotImplementedError

    def _collect_metrics(self):
        """Collect all classes based on ```BaseMetric```."""

        def find_descendants(base_class, subclasses=None):
            if subclasses is None:
                subclasses = set()

            direct_subclasses = base_class.__subclasses__()
            for subclass in direct_subclasses:
                if subclass not in subclasses:
                    subclasses.add(subclass)
                    find_descendants(subclass, subclasses)
            return subclasses

        avaliable_metrics = {}
        for cls in find_descendants(BaseMetric):
            metric_name = cls.metric_name
            avaliable_metrics[metric_name] = cls
        return avaliable_metrics

    def evaluate(self, data):
        """Calculate all metric indicators and summarize them."""

        result_dict = {}
        for metric in self.metrics:
            try:
                metric_result, metric_scores = self.metric_class[metric].calculate_metric(data)
                result_dict.update(metric_result)

                for metric_score, item in zip(metric_scores, data):
                    item.update_evaluation_score(metric, metric_score)
            except Exception as e:
                print(f"Error in {metric}: {e}")
                continue

        if self.save_metric_flag:
            self.save_metric_score(result_dict)

        if self.save_data_flag:
            self.save_data(data)

        return result_dict

    def save_metric_score(self, result_dict, file_name="metric_score.txt"):
        save_path = os.path.join(self.save_dir, file_name)
        with open(save_path, "w", encoding="utf-8") as f:
            for k, v in result_dict.items():
                f.write(f"{k}: {v}\n")

    def save_data(self, data, file_name="intermediate_data.json"):
        """Save the evaluated data, including the raw data and the score of each data
        sample on each metric."""

        save_path = os.path.join(self.save_dir, file_name)

        data.save(save_path)
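Because _collect_metrics walks every subclass of BaseMetric, a new metric only needs to be defined (and imported) to become selectable by name in the config. A hedged sketch; the exact BaseMetric interface is inferred from how Evaluator calls it above (constructor takes config, calculate_metric returns a summary dict plus per-item scores):

```python
from flashrag.evaluator.metrics import BaseMetric

class AnswerLengthMetric(BaseMetric):
    metric_name = "answer_length"

    def __init__(self, config):
        super().__init__(config)

    def calculate_metric(self, data):
        scores = [len(str(item.pred).split()) for item in data]
        result = {"answer_length": sum(scores) / max(len(scores), 1)}
        return result, scores  # (summary dict, per-item scores)
```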
--------------------------------------------------------------------------------
/src/flashrag/evaluator/utils.py:
--------------------------------------------------------------------------------
import re
import string


def normalize_answer(s):
    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

--------------------------------------------------------------------------------
/src/flashrag/generator/__init__.py:
--------------------------------------------------------------------------------
from flashrag.generator.generator import *
from flashrag.generator.multimodal_generator import *
from flashrag.generator.openai_generator import *
from flashrag.generator.utils import *

--------------------------------------------------------------------------------
/src/flashrag/judger/__init__.py:
--------------------------------------------------------------------------------
from flashrag.judger.judger import *

--------------------------------------------------------------------------------
/src/flashrag/pipeline/__init__.py:
--------------------------------------------------------------------------------
from flashrag.pipeline.mm_pipeline import *
from flashrag.pipeline.pipeline import *
from flashrag.pipeline.branching_pipeline import REPLUGPipeline, SuRePipeline
from flashrag.pipeline.active_pipeline import IterativePipeline, SelfRAGPipeline, FLAREPipeline, SelfAskPipeline, IRCOTPipeline, RQRAGPipeline, ReCallPipeline

--------------------------------------------------------------------------------
/src/flashrag/pipeline/mm_pipeline.py:
--------------------------------------------------------------------------------
from flashrag.evaluator import Evaluator
from flashrag.utils import get_retriever, get_generator


class BasicMultiModalPipeline:
    """Base object of all multimodal pipelines. A pipeline includes the overall process of RAG.
    If you want to implement a pipeline, you should inherit this class.
    """

    def __init__(self, config, prompt_template=None):
        from flashrag.prompt import MMPromptTemplate
        self.config = config
        self.device = config["device"]
        self.retriever = None
        self.evaluator = Evaluator(config)
        if prompt_template is None:
            prompt_template = MMPromptTemplate(config)
        self.prompt_template = prompt_template

    def run(self, dataset, pred_process_func=None):
        """The overall inference process of a RAG framework."""
        pass

    def evaluate(self, dataset, do_eval=True, pred_process_func=None):
        """The evaluation process after finishing the overall generation."""

        if pred_process_func is not None:
            dataset = pred_process_func(dataset)

        if do_eval:
            # evaluate & save result
            eval_result = self.evaluator.evaluate(dataset)
            print(eval_result)

        # return the (possibly post-processed) dataset so callers can chain it
        return dataset


class MMSequentialPipeline(BasicMultiModalPipeline):
    PERFORM_MODALITY_DICT = {
        'text': ['text'],
        'image': ['image']
    }

    def __init__(self, config, prompt_template=None, retriever=None, generator=None):
        super().__init__(config, prompt_template)
        self.generator = get_generator(config) if generator is None else generator
        self.retriever = get_retriever(config) if retriever is None else retriever

    def naive_run(self, dataset, do_eval=True, pred_process_func=None):
        input_prompts = [
            self.prompt_template.get_string(item) for item in dataset
        ]

        dataset.update_output("prompt", input_prompts)

        pred_answer_list = self.generator.generate(input_prompts)
        dataset.update_output("pred", pred_answer_list)

        dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_func=pred_process_func)

        return dataset

    def run(self, dataset, do_eval=True, perform_modality_dict=PERFORM_MODALITY_DICT, pred_process_func=None):
        if None not in dataset.question:
            text_query_list = dataset.question
        else:
            text_query_list = dataset.text
        image_query_list = dataset.image

        # perform retrieval
        retrieval_result = []
        for modal in perform_modality_dict.get('text', []):
            retrieval_result.append(
                self.retriever.batch_search(text_query_list, target_modal=modal)
            )
        for modal in perform_modality_dict.get('image', []):
            retrieval_result.append(
                self.retriever.batch_search(image_query_list, target_modal=modal)
            )
        retrieval_result = [sum(group, []) for group in zip(*retrieval_result)]

        dataset.update_output("retrieval_result", retrieval_result)

        input_prompts = [
            self.prompt_template.get_string(item) for item in dataset
        ]

        dataset.update_output("prompt", input_prompts)

        pred_answer_list = self.generator.generate(input_prompts)
        dataset.update_output("pred", pred_answer_list)

        dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_func=pred_process_func)

        return dataset
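A hedged end-to-end sketch of driving MMSequentialPipeline. Config is imported from flashrag.config as shown above; get_dataset is assumed to follow the same flashrag.utils convention as get_retriever/get_generator, and the yaml path and split name are illustrative:

```python
from flashrag.config import Config
from flashrag.utils import get_dataset
from flashrag.pipeline import MMSequentialPipeline

config = Config("scripts/evaluation/eval_config.yaml")
dataset = get_dataset(config)["dev"]

pipeline = MMSequentialPipeline(config)
result_dataset = pipeline.run(dataset, do_eval=True)
```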
--------------------------------------------------------------------------------
/src/flashrag/prompt/__init__.py:
--------------------------------------------------------------------------------
from flashrag.prompt.base_prompt import *
from flashrag.prompt.mm_prompt import *

--------------------------------------------------------------------------------
/src/flashrag/prompt/mm_prompt.py:
--------------------------------------------------------------------------------
import re


class MMPromptTemplate:
    BASE_USER_PROMPT = '{reference}\nBased on the above examples, answer the following question. Only give me the final choices.\nQuestion: {question}\nAnswer: '

    def __init__(self, config, system_prompt=None, user_prompt=None):
        self.config = config
        self.system_prompt = system_prompt
        self.user_prompt = user_prompt if user_prompt is not None else self.BASE_USER_PROMPT

    def get_string(self, item):
        question = item.question if item.question is not None else item.text
        question_image = item.image
        try:
            retrieval_result = item.retrieval_result
        except Exception:
            retrieval_result = []

        messages = []
        if self.system_prompt is not None:
            messages.append({"role": "system", "content": self.system_prompt})
        reference_str = ""
        content_list = []
        # named ref_item to avoid shadowing the `item` argument
        for idx, ref_item in enumerate(retrieval_result):
            # ref_item is multimodal data or raw text
            if 'image' not in ref_item:
                # raw text item
                reference_str += f'Example {idx+1}: {ref_item["contents"]}\n'
            else:
                content_list.append({'type': 'image', 'image': ref_item['image']})
                reference_str += f'Example {idx+1}: {ref_item["text"]}\n'
        content_list.append({'type': 'image', 'image': question_image})
        content_list.append({'type': 'text', 'text': self.user_prompt.format(question=question, reference=reference_str)})
        messages.append({"role": "user", "content": content_list})
        return messages


class GAOKAOMMPromptTemplate(MMPromptTemplate):
    BASE_USER_PROMPT = "请你做一道{subject}选择题\n请你结合文字和图片一步一步思考,并将思考过程写在【解析】和<eoe>之间。{instruction}\n例如:{example}\n请你严格按照上述格式作答。\n你可以参考一些知识: {reference}。题目如下:{question}"
    INSTRUCTION_DICT = {
        'single_choice': '你将从A,B,C,D等选项中选出正确的答案,并写在【答案】和<eoa>之间。',
        'multiple_choice': '你将从A,B,C,D等选项中选出所有符合题意的答案,并写在【答案】和<eoa>之间。'
    }
    EXAMPLE_DICT = {
        'single_choice': '【答案】: A <eoa>\n完整的题目回答的格式如下:\n【解析】 ... <eoe>\n【答案】 ... <eoa>',
        'multiple_choice': '【答案】 AB <eoa>\n完整的题目回答的格式如下:\n【解析】 ... <eoe>\n【答案】... <eoa>'
    }

    def __init__(self, config, system_prompt=None, user_prompt=None):
        self.config = config
        self.system_prompt = system_prompt
        if user_prompt is None:
            self.user_prompt = self.BASE_USER_PROMPT
        else:
            self.user_prompt = user_prompt

    def get_string(self, item):
        question = item.question if item.question is not None else item.text
        question_image = item.image
        question_type = item.question_type
        subject = item.subject

        instruction = self.INSTRUCTION_DICT[question_type]
        example = self.EXAMPLE_DICT[question_type]

        messages = []
        if self.system_prompt is not None:
            messages.append({"role": "system", "content": self.system_prompt})
        content_list = []
        if '{reference}' not in self.user_prompt:
            user_prompt = self.user_prompt.format(question=question, instruction=instruction, example=example, subject=subject)
        else:
            retrieval_result = item.retrieval_result
            reference_str = ""
            for idx, ref_item in enumerate(retrieval_result):
                # ref_item is multimodal data or raw text
                if 'image' not in ref_item:
                    # raw text item
                    reference_str += f'参考内容 {idx+1}: {ref_item["contents"]}\n'
                else:
                    content_list.append({'type': 'image', 'image': ref_item['image']})
                    reference_str += f'参考内容 {idx+1}: {ref_item["text"]}, 标准答案: {ref_item["golden_answers"][0]}\n'
            user_prompt = self.user_prompt.format(question=question, reference=reference_str, instruction=instruction, example=example, subject=subject)

        content_list.append({'type': 'image', 'image': question_image})
        content_list.append({'type': 'text', 'text': user_prompt})
        messages.append({"role": "user", "content": content_list})
        return messages
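A short sketch of what MMPromptTemplate.get_string produces: a chat-style message list whose user turn interleaves retrieved images with a text prompt. The item object is faked minimally for illustration:

```python
from types import SimpleNamespace

item = SimpleNamespace(
    question="Which animal is shown?",
    text=None,
    image="<PIL.Image for the question>",  # a real PIL image in practice
    retrieval_result=[{"contents": "Example doc about cats."}],
)

template = MMPromptTemplate(config={}, system_prompt="You are a helpful assistant.")
messages = template.get_string(item)
# messages[0] -> {'role': 'system', 'content': 'You are a helpful assistant.'}
# messages[1]['content'] -> [{'type': 'image', ...}, {'type': 'text', ...}]
```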
--------------------------------------------------------------------------------
/src/flashrag/refiner/__init__.py:
--------------------------------------------------------------------------------
from flashrag.refiner.refiner import *
from flashrag.refiner.kg_refiner import *

--------------------------------------------------------------------------------
/src/flashrag/retriever/__init__.py:
--------------------------------------------------------------------------------
from flashrag.retriever.retriever import *
from flashrag.retriever.reranker import *
from flashrag.retriever.utils import *

--------------------------------------------------------------------------------
/src/flashrag/retriever/__main__.py:
--------------------------------------------------------------------------------
from . import index_builder

if __name__ == "__main__":
    index_builder.main()
--------------------------------------------------------------------------------
/src/flashrag/utils/__init__.py:
--------------------------------------------------------------------------------
from flashrag.utils.utils import *
from flashrag.utils.pred_parse import *

--------------------------------------------------------------------------------
/src/flashrag/utils/constants.py:
--------------------------------------------------------------------------------
OPENAI_MODEL_DICT = {
    # chat
    "gpt-4o": "o200k_base",
    "gpt-4": "cl100k_base",
    "gpt-3.5-turbo": "cl100k_base",
    "gpt-3.5": "cl100k_base",  # Common shorthand
    "gpt-35-turbo": "cl100k_base",  # Azure deployment name
    # base
    "davinci-002": "cl100k_base",
    "babbage-002": "cl100k_base",
    # embeddings
    "text-embedding-ada-002": "cl100k_base",
    "text-embedding-3-small": "cl100k_base",
    "text-embedding-3-large": "cl100k_base",
    # DEPRECATED MODELS
    # text (DEPRECATED)
    "text-davinci-003": "p50k_base",
    "text-davinci-002": "p50k_base",
    "text-davinci-001": "r50k_base",
    "text-curie-001": "r50k_base",
    "text-babbage-001": "r50k_base",
    "text-ada-001": "r50k_base",
    "davinci": "r50k_base",
    "curie": "r50k_base",
    "babbage": "r50k_base",
    "ada": "r50k_base",
    # code (DEPRECATED)
    "code-davinci-002": "p50k_base",
    "code-davinci-001": "p50k_base",
    "code-cushman-002": "p50k_base",
    "code-cushman-001": "p50k_base",
    "davinci-codex": "p50k_base",
    "cushman-codex": "p50k_base",
    # edit (DEPRECATED)
    "text-davinci-edit-001": "p50k_edit",
    "code-davinci-edit-001": "p50k_edit",
    # old embeddings (DEPRECATED)
    "text-similarity-davinci-001": "r50k_base",
    "text-similarity-curie-001": "r50k_base",
    "text-similarity-babbage-001": "r50k_base",
    "text-similarity-ada-001": "r50k_base",
    "text-search-davinci-doc-001": "r50k_base",
    "text-search-curie-doc-001": "r50k_base",
    "text-search-babbage-doc-001": "r50k_base",
    "text-search-ada-doc-001": "r50k_base",
    "code-search-babbage-code-001": "r50k_base",
    "code-search-ada-code-001": "r50k_base",
    # open source
    "gpt2": "gpt2",
    "gpt-2": "gpt2",  # Maintains consistency with gpt-4
}

--------------------------------------------------------------------------------
/src/flashrag/utils/pred_parse.py:
--------------------------------------------------------------------------------
import re

def selfask_pred_parse(dataset):
    """Parse the prediction results of the self-ask format."""
    FINAL_ANSWER_PREFIX = "So the final answer is: "

    for item in dataset:
        pred = item.pred
        lines = pred.split("\n")
        answer = ""
        for line in lines:
            if FINAL_ANSWER_PREFIX in line:
                answer = line.split(FINAL_ANSWER_PREFIX)[1].strip()
                break
        item.update_output('raw_pred', pred)
        item.update_output('pred', answer)

    return dataset


def ircot_pred_parse(dataset):
    FINAL_ANSWER_PREFIX = "So the answer is:"
    for item in dataset:
        pred = item.pred
        if FINAL_ANSWER_PREFIX in pred:
            answer = pred.split(FINAL_ANSWER_PREFIX)[1].strip()
        else:
            answer = pred
        item.update_output('raw_pred', pred)
        item.update_output('pred', answer)
    return dataset


def basic_pred_parse(dataset):
    for item in dataset:
        pred = item.pred
        item.update_output('raw_pred', pred)
        item.update_output('pred', pred.split("\n")[0].strip())
    return dataset


def gaokaomm_pred_parse(dataset):
    """
    Extract the choice answer from the model output.

    Expected format of model_output:
    'single_choice': the answer should be the last capital letter of the model_output, e.g. "...【答案】 A "
    'multiple_choice': "...【答案】 ABD ", or the choice answers written at the end of the model_output, e.g. "... ACD"
    """

    for item in dataset:
        model_output = item.pred
        question_type = item.question_type

        if question_type == 'single_choice':
            model_answer = ""
            temp = re.findall(r'[A-D]', model_output[::-1])
            if len(temp) != 0:
                model_answer = temp[0]

        elif question_type == 'multiple_choice':
            model_answer = []
            answer = ''
            content = re.sub(r'\s+', '', model_output)
            answer_index = content.find('【答案】')
            if answer_index >= 0:  # was `> 0`, which missed a marker at position 0
                temp = content[answer_index:]
                if len(re.findall(r'[A-D]', temp)) > 0:
                    for t in re.findall(r'[A-D]', temp):
                        answer += t
            else:
                temp = content[-10:]
                if len(re.findall(r'[A-D]', temp)) > 0:
                    for t in re.findall(r'[A-D]', temp):
                        answer += t
            if len(answer) != 0:
                model_answer.append(answer)
            model_answer = "".join(model_answer)

        item.update_output('raw_pred', model_output)
        item.update_output('pred', model_answer)

    return dataset
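To make the self-ask parsing rule concrete, here is a tiny illustration on a synthetic prediction string; the Dataset/item plumbing is omitted, so this just mirrors the string logic of selfask_pred_parse:

```python
pred = "Follow up: Who wrote it?\nIntermediate answer: Orwell.\nSo the final answer is: George Orwell"
line = next(l for l in pred.split("\n") if "So the final answer is: " in l)
answer = line.split("So the final answer is: ")[1].strip()
assert answer == "George Orwell"
```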
--------------------------------------------------------------------------------
/src/flashrag/version.py:
--------------------------------------------------------------------------------
__version__ = "0.1.4dev0"

--------------------------------------------------------------------------------
/src/re_call/__init__.py:
--------------------------------------------------------------------------------
from .inference.re_call import ReCall

__all__ = ["ReCall"]

--------------------------------------------------------------------------------
/src/re_call/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Agent-RL/ReCall/aaf16b31c83702cdd3c7f7a0ad431efce7904b63/src/re_call/inference/__init__.py

--------------------------------------------------------------------------------
/src/verl/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

version_folder = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(version_folder, 'version/version')) as f:
    __version__ = f.read().strip()

from .protocol import DataProto

from .utils.logging_utils import set_basic_config
import logging

set_basic_config(level=logging.WARNING)

from . import single_controller

__all__ = ['DataProto', "__version__"]

if os.getenv('VERL_USE_MODELSCOPE', 'False').lower() == 'true':
    import importlib
    if importlib.util.find_spec("modelscope") is None:
        raise ImportError('You are using the modelscope hub, please install modelscope by `pip install modelscope -U`')
    # Patch the hub to download models from modelscope to speed things up.
    from modelscope.utils.hf_util import patch_hub
    patch_hub()
--------------------------------------------------------------------------------
/src/verl/models/README.md:
--------------------------------------------------------------------------------
# Models

Common model zoos such as huggingface/transformers struggle when using PyTorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly-optimized model implementation with packed inputs in verl.

## Adding a New Huggingface Model

### Step 1: Copy the model file from HF to verl
- Add a new file under verl/models/hf
- Copy ONLY the model file from huggingface/transformers/models to verl/models/hf

### Step 2: Modify the model file to use packed inputs
- Remove all the code related to inference (kv cache)
- Modify the inputs to include only
  - input_ids (total_nnz,)
  - cu_seqlens (total_nnz + 1,)
  - max_seqlen_in_batch: int
- Note that this requires using flash attention with a causal mask.

### Step 2.5: Add tests
- Add a test to compare this version and the huggingface version
- Follow the existing test infrastructure and add tests to tests/models/hf

### Step 3: Add a function to apply tensor parallelism
- Please follow
  - https://pytorch.org/docs/stable/distributed.tensor.parallel.html
  - https://pytorch.org/tutorials/intermediate/TP_tutorial.html
- General comments
  - Tensor parallelism in native PyTorch is NOT auto-parallelism. The way it works is to specify how model parameters and input/output reshard using configs. These configs are then registered as hooks to perform input/output resharding before/after the model forward.

### Step 4: Add a function to apply data parallelism
- Please use FSDP2 APIs
- See the demo here: https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413

### Step 5: Add a function to apply pipeline parallelism
- Comes in PyTorch 2.4
- Currently only in alpha in the nightly version
- Check torchtitan for more details
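A minimal, torch-only sketch of the packed-input convention Step 2 describes, assuming right-padded batches: concatenate the non-pad tokens of every sequence into one 1-D tensor and record cumulative boundaries in cu_seqlens, the layout flash attention's varlen kernels consume:

```python
import torch

def pack_inputs(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    seqlens = attention_mask.sum(dim=1, dtype=torch.int32)   # (batch,)
    packed = input_ids[attention_mask.bool()]                # (total_nnz,)
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)            # (total_nnz + 1,)
    return packed, cu_seqlens, int(seqlens.max())

ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
packed, cu_seqlens, max_seqlen_in_batch = pack_inputs(ids, mask)
# packed -> tensor([5, 6, 7, 8, 9]); cu_seqlens -> tensor([0, 3, 5]); max -> 3
```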
--------------------------------------------------------------------------------
/src/verl/models/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

--------------------------------------------------------------------------------
/src/verl/models/llama/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

--------------------------------------------------------------------------------
/src/verl/models/llama/megatron/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling_llama_megatron import (
    # original model with megatron
    ParallelLlamaModel,
    ParallelLlamaForCausalLM,
    # rmpad with megatron
    ParallelLlamaForCausalLMRmPad,
    ParallelLlamaForValueRmPad,
    # rmpad with megatron and pipeline parallelism
    ParallelLlamaForCausalLMRmPadPP,
    ParallelLlamaForValueRmPadPP)

--------------------------------------------------------------------------------
/src/verl/models/llama/megatron/checkpoint_utils/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

--------------------------------------------------------------------------------
/src/verl/models/llama/megatron/layers/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .parallel_attention import ParallelLlamaAttention 16 | from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad 17 | from .parallel_mlp import ParallelLlamaMLP 18 | from .parallel_rmsnorm import ParallelLlamaRMSNorm 19 | -------------------------------------------------------------------------------- /src/verl/models/llama/megatron/layers/parallel_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py 15 | 16 | from typing import Optional, Tuple 17 | 18 | from megatron.core import tensor_parallel 19 | 20 | 21 | class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): 22 | 23 | def __init__(self, 24 | input_size, 25 | num_heads, 26 | num_key_value_heads, 27 | head_dim, 28 | *, 29 | bias=True, 30 | gather_output=True, 31 | skip_bias_add=False, 32 | **kwargs): 33 | # Keep input parameters, and already restrict the head numbers 34 | self.input_size = input_size 35 | self.q_output_size = num_heads * head_dim 36 | self.kv_output_size = num_key_value_heads * head_dim 37 | self.head_dim = head_dim 38 | self.gather_output = gather_output 39 | self.skip_bias_add = skip_bias_add 40 | 41 | input_size = self.input_size 42 | output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim 43 | 44 | super().__init__(input_size=input_size, 45 | output_size=output_size, 46 | bias=bias, 47 | gather_output=gather_output, 48 | skip_bias_add=skip_bias_add, 49 | **kwargs) 50 | 51 | 52 | class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): 53 | 54 | def __init__(self, 55 | input_size, 56 | gate_ouput_size, 57 | up_output_size, 58 | *, 59 | bias=True, 60 | gather_output=True, 61 | skip_bias_add=False, 62 | **kwargs): 63 | # Keep input parameters, and already restrict the head numbers 64 | self.input_size = input_size 65 | self.output_size = gate_ouput_size + up_output_size 66 | self.gather_output = gather_output 67 | self.skip_bias_add = skip_bias_add 68 | 69 | super().__init__(input_size=self.input_size, 70 | output_size=self.output_size, 71 | bias=bias, 72 | gather_output=gather_output, 73 | skip_bias_add=skip_bias_add, 74 | **kwargs) 75 | 76 | 77 | import torch 78 | 79 | 80 | class LinearForLastLayer(torch.nn.Linear): 81 | 82 | def __init__( 83 | self, 84 | input_size, 85 | output_size, 86 | 
*, 87 | config, 88 | bias=True, 89 | ): 90 | super().__init__(in_features=input_size, out_features=output_size, bias=bias) 91 | self.sequence_parallel = config.sequence_parallel 92 | if self.sequence_parallel: 93 | setattr(self.weight, 'sequence_parallel', True) 94 | 95 | def forward( 96 | self, 97 | input_, 98 | weight=None, 99 | runtime_gather_output=None, 100 | ): 101 | logits = super().forward(input_) 102 | logits = logits.float() 103 | if self.sequence_parallel: 104 | logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) 105 | return logits, None 106 | -------------------------------------------------------------------------------- /src/verl/models/llama/megatron/layers/parallel_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
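As a sanity check on the fused QKV projection defined in parallel_linear.py above, a bit of shape arithmetic (an annotation with illustrative numbers, not tied to any particular checkpoint): with grouped-query attention, the column-parallel weight fuses the query, key, and value projections into a single matmul, and tensor parallelism shards its output dimension.

```python
# Hypothetical GQA configuration, for illustration only.
num_heads, num_key_value_heads, head_dim, tp_size = 32, 8, 128, 4

q_output_size = num_heads * head_dim                                   # 4096
kv_output_size = num_key_value_heads * head_dim                        # 1024 each for K and V
fused_output_size = (num_heads + 2 * num_key_value_heads) * head_dim   # 6144: one matmul for Q, K, V
columns_per_rank = fused_output_size // tp_size                        # 1536 output columns per TP rank
```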
20 | 21 | from megatron.core import parallel_state as mpu 22 | from megatron.core import tensor_parallel 23 | from megatron.core import ModelParallelConfig 24 | from torch import nn 25 | from transformers.activations import ACT2FN 26 | from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear 27 | 28 | from verl.utils.megatron import tensor_parallel as tp_utils 29 | 30 | 31 | class ParallelLlamaMLP(nn.Module): 32 | 33 | def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: 34 | super().__init__() 35 | self.config = config 36 | self.hidden_size = config.hidden_size 37 | self.intermediate_size = config.intermediate_size 38 | # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] 39 | 40 | column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() 41 | row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() 42 | 43 | if megatron_config is not None: 44 | assert column_kwargs.get('config', False), 'must have ModelParallelConfig' 45 | assert row_kwargs.get('config', False), 'must have ModelParallelConfig' 46 | tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) 47 | tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) 48 | 49 | tp_size = mpu.get_tensor_model_parallel_world_size() 50 | 51 | self.gate_up_proj = MergedColumnParallelLinear( 52 | input_size=self.hidden_size, 53 | gate_ouput_size=self.intermediate_size, 54 | up_output_size=self.intermediate_size, 55 | bias=False, 56 | gather_output=False, 57 | skip_bias_add=False, 58 | **column_kwargs, 59 | ) 60 | self.gate_size = self.intermediate_size // tp_size 61 | 62 | self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size, 63 | output_size=self.hidden_size, 64 | bias=False, 65 | input_is_parallel=True, 66 | skip_bias_add=False, 67 | **row_kwargs) 68 | 69 | self.act_fn = ACT2FN[config.hidden_act] 70 | 71 | def forward(self, x): 72 | gate_up = self.gate_up_proj(x)[0] 73 | gate, up = gate_up.split(self.gate_size, dim=-1) 74 | return self.down_proj(self.act_fn(gate) * up)[0] 75 | -------------------------------------------------------------------------------- /src/verl/models/llama/megatron/layers/parallel_rmsnorm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
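The forward pass of ParallelLlamaMLP above hinges on splitting the fused gate_up projection: each tensor-parallel rank holds 2 * intermediate_size / tp_size output columns, and `split(self.gate_size, dim=-1)` recovers the per-rank gate and up halves before the SwiGLU-style combination. A dependency-free mock of that per-rank computation (an annotation assuming tp_size = 2 and a silu activation; Megatron itself is not used here):

```python
import torch
import torch.nn.functional as F

hidden_size, intermediate_size, tp_size = 16, 32, 2
gate_size = intermediate_size // tp_size              # per-rank gate/up width: 16

x = torch.randn(4, hidden_size)
w_gate_up = torch.randn(hidden_size, 2 * gate_size)   # this rank's fused column-parallel shard
gate, up = (x @ w_gate_up).split(gate_size, dim=-1)
partial = F.silu(gate) * up                           # down_proj plus an all-reduce would follow
assert partial.shape == (4, gate_size)
```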
14 | 15 | import numbers 16 | import torch 17 | from megatron.core import ModelParallelConfig 18 | from torch import nn 19 | from transformers import LlamaConfig 20 | 21 | from apex.normalization.fused_layer_norm import fused_rms_norm_affine 22 | from verl.utils.megatron import sequence_parallel as sp_utils 23 | 24 | 25 | class ParallelLlamaRMSNorm(nn.Module): 26 | 27 | def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): 28 | """ 29 | LlamaRMSNorm is equivalent to T5LayerNorm 30 | """ 31 | super().__init__() 32 | if isinstance(config.hidden_size, numbers.Integral): 33 | normalized_shape = (config.hidden_size,) 34 | self.normalized_shape = torch.Size(normalized_shape) 35 | self.weight = nn.Parameter(torch.ones(self.normalized_shape)) 36 | self.variance_epsilon = config.rms_norm_eps 37 | 38 | if megatron_config.sequence_parallel: 39 | sp_utils.mark_parameter_as_sequence_parallel(self.weight) 40 | 41 | def forward(self, hidden_states): 42 | return fused_rms_norm_affine(input=hidden_states, 43 | weight=self.weight, 44 | normalized_shape=self.normalized_shape, 45 | eps=self.variance_epsilon, 46 | memory_efficient=True) -------------------------------------------------------------------------------- /src/verl/models/mcore/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Bytedance Ltd. and/or its affiliates 2 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .gpt_model import gptmodel_forward -------------------------------------------------------------------------------- /src/verl/models/qwen2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /src/verl/models/qwen2/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
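For readers without apex installed, an unfused reference of what `fused_rms_norm_affine` computes in ParallelLlamaRMSNorm above (an annotation following the standard RMSNorm definition; it mirrors, but is not, the apex kernel):

```python
import torch

def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then apply the affine scale.
    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
    normed = hidden_states.float() * torch.rsqrt(variance + eps)
    return weight * normed.to(hidden_states.dtype)

x = torch.randn(2, 5, 8, dtype=torch.float16)
weight = torch.ones(8, dtype=torch.float16)
out = rms_norm_reference(x, weight, eps=1e-6)   # same shape and dtype as x
```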
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .modeling_qwen2_megatron import ( 16 | # original model with megatron 17 | ParallelQwen2Model, 18 | ParallelQwen2ForCausalLM, 19 | # rmpad with megatron 20 | ParallelQwen2ForCausalLMRmPad, 21 | ParallelQwen2ForValueRmPad, 22 | # rmpad with megatron and pipeline parallelism 23 | ParallelQwen2ForCausalLMRmPadPP, 24 | ParallelQwen2ForValueRmPadPP) 25 | -------------------------------------------------------------------------------- /src/verl/models/qwen2/megatron/checkpoint_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /src/verl/models/qwen2/megatron/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .parallel_attention import ParallelQwen2Attention 16 | from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad 17 | from .parallel_mlp import ParallelQwen2MLP 18 | from .parallel_rmsnorm import ParallelQwen2RMSNorm 19 | -------------------------------------------------------------------------------- /src/verl/models/qwen2/megatron/layers/parallel_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py 15 | 16 | from typing import Optional, Tuple 17 | 18 | from megatron.core import tensor_parallel 19 | 20 | 21 | class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): 22 | 23 | def __init__(self, 24 | input_size, 25 | num_heads, 26 | num_key_value_heads, 27 | head_dim, 28 | *, 29 | bias=True, 30 | gather_output=True, 31 | skip_bias_add=False, 32 | **kwargs): 33 | # Keep input parameters, and already restrict the head numbers 34 | self.input_size = input_size 35 | self.q_output_size = num_heads * head_dim 36 | self.kv_output_size = num_key_value_heads * head_dim 37 | self.head_dim = head_dim 38 | self.gather_output = gather_output 39 | self.skip_bias_add = skip_bias_add 40 | 41 | input_size = self.input_size 42 | output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim 43 | 44 | super().__init__(input_size=input_size, 45 | output_size=output_size, 46 | bias=bias, 47 | gather_output=gather_output, 48 | skip_bias_add=skip_bias_add, 49 | **kwargs) 50 | 51 | 52 | class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): 53 | 54 | def __init__(self, 55 | input_size, 56 | gate_ouput_size, 57 | up_output_size, 58 | *, 59 | bias=True, 60 | gather_output=True, 61 | skip_bias_add=False, 62 | **kwargs): 63 | # Keep input parameters, and already restrict the head numbers 64 | self.input_size = input_size 65 | self.output_size = gate_ouput_size + up_output_size 66 | self.gather_output = gather_output 67 | self.skip_bias_add = skip_bias_add 68 | 69 | super().__init__(input_size=self.input_size, 70 | output_size=self.output_size, 71 | bias=bias, 72 | gather_output=gather_output, 73 | skip_bias_add=skip_bias_add, 74 | **kwargs) 75 | -------------------------------------------------------------------------------- /src/verl/models/qwen2/megatron/layers/parallel_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
20 | 21 | from megatron.core import parallel_state as mpu 22 | from megatron.core import tensor_parallel 23 | from megatron.core import ModelParallelConfig 24 | from torch import nn 25 | from transformers.activations import ACT2FN 26 | from verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear 27 | 28 | from verl.utils.megatron import tensor_parallel as tp_utils 29 | 30 | 31 | class ParallelQwen2MLP(nn.Module): 32 | 33 | def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: 34 | super().__init__() 35 | self.config = config 36 | self.hidden_size = config.hidden_size 37 | self.intermediate_size = config.intermediate_size 38 | # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] 39 | 40 | column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() 41 | row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() 42 | 43 | if megatron_config is not None: 44 | assert column_kwargs.get('config', False), 'must have ModelParallelConfig' 45 | assert row_kwargs.get('config', False), 'must have ModelParallelConfig' 46 | tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) 47 | tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) 48 | 49 | tp_size = mpu.get_tensor_model_parallel_world_size() 50 | 51 | self.gate_up_proj = MergedColumnParallelLinear( 52 | input_size=self.hidden_size, 53 | gate_ouput_size=self.intermediate_size, 54 | up_output_size=self.intermediate_size, 55 | bias=False, 56 | gather_output=False, 57 | skip_bias_add=False, 58 | **column_kwargs, 59 | ) 60 | self.gate_size = self.intermediate_size // tp_size 61 | 62 | self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size, 63 | output_size=self.hidden_size, 64 | bias=False, 65 | input_is_parallel=True, 66 | skip_bias_add=False, 67 | **row_kwargs) 68 | 69 | self.act_fn = ACT2FN[config.hidden_act] 70 | 71 | def forward(self, x): 72 | gate_up = self.gate_up_proj(x)[0] 73 | gate, up = gate_up.split(self.gate_size, dim=-1) 74 | return self.down_proj(self.act_fn(gate) * up)[0] 75 | -------------------------------------------------------------------------------- /src/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import numbers 16 | import torch 17 | from megatron.core import ModelParallelConfig 18 | from torch import nn 19 | from transformers import Qwen2Config 20 | 21 | from apex.normalization.fused_layer_norm import fused_rms_norm_affine 22 | from verl.utils.megatron import sequence_parallel as sp_utils 23 | 24 | 25 | class ParallelQwen2RMSNorm(nn.Module): 26 | 27 | def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): 28 | """ 29 | Qwen2RMSNorm is equivalent to T5LayerNorm 30 | """ 31 | super().__init__() 32 | if isinstance(config.hidden_size, numbers.Integral): 33 | normalized_shape = (config.hidden_size,) 34 | self.normalized_shape = torch.Size(normalized_shape) 35 | self.weight = nn.Parameter(torch.ones(self.normalized_shape)) 36 | self.variance_epsilon = config.rms_norm_eps 37 | 38 | if megatron_config.sequence_parallel: 39 | sp_utils.mark_parameter_as_sequence_parallel(self.weight) 40 | 41 | def forward(self, hidden_states): 42 | return fused_rms_norm_affine(input=hidden_states, 43 | weight=self.weight, 44 | normalized_shape=self.normalized_shape, 45 | eps=self.variance_epsilon, 46 | memory_efficient=True) -------------------------------------------------------------------------------- /src/verl/models/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import importlib 16 | from typing import List, Optional, Type 17 | 18 | import torch.nn as nn 19 | 20 | # Supported models in Megatron-LM 21 | # Architecture -> (module, class). 
22 | _MODELS = { 23 | "LlamaForCausalLM": 24 | ("llama", ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad")), 25 | "Qwen2ForCausalLM": 26 | ("qwen2", ("ParallelQwen2ForCausalLMRmPadPP", "ParallelQwen2ForValueRmPadPP", "ParallelQwen2ForCausalLMRmPad")), 27 | "MistralForCausalLM": ("mistral", ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", 28 | "ParallelMistralForCausalLMRmPad")) 29 | } 30 | 31 | 32 | # return model class 33 | class ModelRegistry: 34 | 35 | @staticmethod 36 | def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]: 37 | if model_arch not in _MODELS: 38 | return None 39 | 40 | megatron = "megatron" 41 | 42 | module_name, model_cls_name = _MODELS[model_arch] 43 | if not value: # actor/ref 44 | model_cls_name = model_cls_name[0] 45 | elif value: # critic/rm 46 | model_cls_name = model_cls_name[1] 47 | 48 | module = importlib.import_module(f"verl.models.{module_name}.{megatron}.modeling_{module_name}_megatron") 49 | return getattr(module, model_cls_name, None) 50 | 51 | @staticmethod 52 | def get_supported_archs() -> List[str]: 53 | return list(_MODELS.keys()) 54 | -------------------------------------------------------------------------------- /src/verl/models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /src/verl/models/weight_loader_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
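A short usage sketch for the ModelRegistry above (an annotation; it assumes megatron-core and this package are importable, since `load_model_cls` imports the Megatron modeling module lazily). `value=False` selects the actor/ref CausalLM class, `value=True` the critic/RM value-head class:

```python
from verl.models.registry import ModelRegistry

print(ModelRegistry.get_supported_archs())
# ['LlamaForCausalLM', 'Qwen2ForCausalLM', 'MistralForCausalLM']

actor_cls = ModelRegistry.load_model_cls("LlamaForCausalLM", value=False)
critic_cls = ModelRegistry.load_model_cls("Qwen2ForCausalLM", value=True)

# Unknown architectures return None rather than raising.
assert ModelRegistry.load_model_cls("FalconForCausalLM") is None
```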
14 | 15 | 16 | def get_weight_loader(arch: str): 17 | from verl.models.llama.megatron.checkpoint_utils.llama_loader import load_state_dict_to_megatron_llama 18 | from verl.models.qwen2.megatron.checkpoint_utils.qwen2_loader import load_state_dict_to_megatron_qwen2 19 | from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel 20 | _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = { 21 | 'LlamaForCausalLM': load_state_dict_to_megatron_gptmodel, 22 | 'Qwen2ForCausalLM': load_state_dict_to_megatron_gptmodel, 23 | } 24 | 25 | if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY: 26 | return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch] 27 | raise ValueError(f"Model architecture {arch} is not supported by the weight loader for now. " 28 | f"Supported architectures: {list(_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys())}") 29 | 30 | 31 | def get_weight_saver(arch: str): 32 | from verl.models.llama.megatron.checkpoint_utils.llama_saver import merge_megatron_ckpt_llama 33 | from verl.models.qwen2.megatron.checkpoint_utils.qwen2_saver import merge_megatron_ckpt_qwen2 34 | from verl.models.mcore.saver import merge_megatron_ckpt_gptmodel 35 | _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = { 36 | 'LlamaForCausalLM': merge_megatron_ckpt_gptmodel, 37 | 'Qwen2ForCausalLM': merge_megatron_ckpt_gptmodel, 38 | } 39 | if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY: 40 | return _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY[arch] 41 | raise ValueError(f"Model architecture {arch} is not supported by the weight saver for now. " 42 | f"Supported architectures: {list(_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys())}") 43 | -------------------------------------------------------------------------------- /src/verl/single_controller/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | version_folder = os.path.dirname(os.path.abspath(__file__)) 18 | 19 | # Note(haibin.lin): single_controller.__version__ is deprecated 20 | with open(os.path.join(version_folder, os.pardir, 'version/version')) as f: 21 | __version__ = f.read().strip() 22 | 23 | from . import base 24 | from .base import * 25 | 26 | __all__ = base.__all__ -------------------------------------------------------------------------------- /src/verl/single_controller/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
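And a matching sketch for the weight loader/saver registry above (an annotation; calling these functions imports megatron-core lazily, so it must be installed). Both supported architectures currently resolve to the shared GPTModel helpers, and unsupported ones raise ValueError:

```python
from verl.models.weight_loader_registry import get_weight_loader, get_weight_saver

loader = get_weight_loader("Qwen2ForCausalLM")   # load_state_dict_to_megatron_gptmodel
saver = get_weight_saver("LlamaForCausalLM")     # merge_megatron_ckpt_gptmodel

try:
    get_weight_loader("MistralForCausalLM")      # not in the registry
except ValueError as err:
    print(err)
```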
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .worker import Worker 16 | from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool 17 | 18 | __all__ = ['Worker', 'WorkerGroup', 'ClassWithInitArgs', 'ResourcePool'] -------------------------------------------------------------------------------- /src/verl/single_controller/base/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /src/verl/single_controller/base/megatron/worker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo 16 | 17 | 18 | class MegatronWorker(Worker): 19 | 20 | def __init__(self, cuda_visible_devices=None) -> None: 21 | super().__init__(cuda_visible_devices) 22 | 23 | def get_megatron_global_info(self): 24 | from megatron.core import parallel_state as mpu 25 | tp_size = mpu.get_tensor_model_parallel_world_size() 26 | dp_size = mpu.get_data_parallel_world_size() 27 | pp_size = mpu.get_pipeline_model_parallel_world_size() 28 | cp_size = mpu.get_context_parallel_world_size() 29 | info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size, cp_size=cp_size) 30 | return info 31 | 32 | def get_megatron_rank_info(self): 33 | from megatron.core import parallel_state as mpu 34 | tp_rank = mpu.get_tensor_model_parallel_rank() 35 | dp_rank = mpu.get_data_parallel_rank() 36 | pp_rank = mpu.get_pipeline_model_parallel_rank() 37 | cp_rank = mpu.get_context_parallel_rank() 38 | info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank, cp_rank=cp_rank) 39 | return info -------------------------------------------------------------------------------- /src/verl/single_controller/base/megatron/worker_group.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Dict 16 | 17 | from .worker import DistRankInfo, DistGlobalInfo 18 | from verl.single_controller.base import ResourcePool, WorkerGroup 19 | 20 | 21 | class MegatronWorkerGroup(WorkerGroup): 22 | 23 | def __init__(self, resource_pool: ResourcePool, **kwargs): 24 | super().__init__(resource_pool=resource_pool, **kwargs) 25 | self._megatron_rank_info = None 26 | self._megatron_global_info: DistGlobalInfo = None 27 | 28 | def init_megatron(self, default_megatron_kwargs: Dict = None): 29 | raise NotImplementedError(f"MegatronWorkerGroup.init_megatron should be overwritten") 30 | 31 | def get_megatron_rank_info(self, rank: int) -> DistRankInfo: 32 | assert 0 <= rank < self.world_size, f'rank must be from [0, world_size), Got {rank}' 33 | return self._megatron_rank_info[rank] 34 | 35 | @property 36 | def tp_size(self): 37 | assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" 38 | return self._megatron_global_info.tp_size 39 | 40 | @property 41 | def dp_size(self): 42 | assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" 43 | return self._megatron_global_info.dp_size 44 | 45 | @property 46 | def pp_size(self): 47 | assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" 48 | return self._megatron_global_info.pp_size 49 | 50 | @property 51 | def cp_size(self): 52 | assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" 53 | return self._megatron_global_info.cp_size 54 | 55 | def get_megatron_global_info(self): 56 | return self._megatron_global_info 57 | -------------------------------------------------------------------------------- /src/verl/single_controller/base/register_center/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /src/verl/single_controller/base/register_center/ray.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import ray 16 | 17 | 18 | @ray.remote 19 | class WorkerGroupRegisterCenter: 20 | 21 | def __init__(self, rank_zero_info): 22 | self.rank_zero_info = rank_zero_info 23 | 24 | def get_rank_zero_info(self): 25 | return self.rank_zero_info 26 | 27 | 28 | def create_worker_group_register_center(name, info): 29 | return WorkerGroupRegisterCenter.options(name=name).remote(info) 30 | -------------------------------------------------------------------------------- /src/verl/single_controller/ray/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls -------------------------------------------------------------------------------- /src/verl/single_controller/ray/megatron.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Dict, Optional 16 | 17 | import ray 18 | 19 | from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs 20 | from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo 21 | from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup 22 | 23 | 24 | # NOTE(sgm): for open-source megatron-core 25 | class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): 26 | """ 27 | MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup 28 | so that the dispatcher can use it to dispatch data. 
29 | """ 30 | 31 | def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs): 32 | super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs) 33 | self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info') 34 | self._megatron_global_info: DistGlobalInfo = ray.get( 35 | self.execute_rank_zero_async(method_name='get_megatron_global_info')) 36 | 37 | 38 | class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): 39 | """ 40 | MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup 41 | so that the dispatcher can use it to dispatch data. 42 | """ 43 | 44 | def __init__(self, 45 | resource_pool: RayResourcePool, 46 | ray_cls_with_init: RayClassWithInitArgs, 47 | default_megatron_kwargs: Dict = None, 48 | **kwargs): 49 | super().__init__(resource_pool=resource_pool, 50 | ray_cls_with_init=ray_cls_with_init, 51 | default_megatron_kwargs=default_megatron_kwargs, 52 | **kwargs) 53 | self.init_megatron(default_megatron_kwargs=default_megatron_kwargs) 54 | self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info') 55 | self._megatron_global_info: DistGlobalInfo = ray.get( 56 | self.execute_rank_zero_async(method_name='get_megatron_global_info')) 57 | 58 | def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None): 59 | # after super, we will call init of each worker 60 | if not self._is_init_with_detached_workers: 61 | # only init_megatron if the WorkerGroup is created from scratch 62 | self.execute_all_sync(method_name='init_megatron', default_megatron_kwargs=default_megatron_kwargs) 63 | -------------------------------------------------------------------------------- /src/verl/third_party/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /src/verl/third_party/sglang/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 SGLang Team 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 15 | # 16 | # Licensed under the Apache License, Version 2.0 (the "License"); 17 | # you may not use this file except in compliance with the License. 18 | # You may obtain a copy of the License at 19 | # 20 | # http://www.apache.org/licenses/LICENSE-2.0 21 | # 22 | # Unless required by applicable law or agreed to in writing, software 23 | # distributed under the License is distributed on an "AS IS" BASIS, 24 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 25 | # See the License for the specific language governing permissions and 26 | # limitations under the License. -------------------------------------------------------------------------------- /src/verl/third_party/vllm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from importlib.metadata import version, PackageNotFoundError 16 | from packaging import version as vs 17 | from verl.utils.import_utils import is_sglang_available 18 | 19 | 20 | def get_version(pkg): 21 | try: 22 | return version(pkg) 23 | except PackageNotFoundError: 24 | return None 25 | 26 | 27 | package_name = 'vllm' 28 | package_version = get_version(package_name) 29 | vllm_version = None 30 | 31 | if package_version == '0.3.1': 32 | vllm_version = '0.3.1' 33 | from .vllm_v_0_3_1.llm import LLM 34 | from .vllm_v_0_3_1.llm import LLMEngine 35 | from .vllm_v_0_3_1 import parallel_state 36 | elif package_version == '0.4.2': 37 | vllm_version = '0.4.2' 38 | from .vllm_v_0_4_2.llm import LLM 39 | from .vllm_v_0_4_2.llm import LLMEngine 40 | from .vllm_v_0_4_2 import parallel_state 41 | elif package_version == '0.5.4': 42 | vllm_version = '0.5.4' 43 | from .vllm_v_0_5_4.llm import LLM 44 | from .vllm_v_0_5_4.llm import LLMEngine 45 | from .vllm_v_0_5_4 import parallel_state 46 | elif package_version == '0.6.3': 47 | vllm_version = '0.6.3' 48 | from .vllm_v_0_6_3.llm import LLM 49 | from .vllm_v_0_6_3.llm import LLMEngine 50 | from .vllm_v_0_6_3 import parallel_state 51 | elif package_version == '0.6.3+rocm624': 52 | vllm_version = '0.6.3' 53 | from .vllm_v_0_6_3.llm import LLM 54 | from .vllm_v_0_6_3.llm import LLMEngine 55 | from .vllm_v_0_6_3 import parallel_state 56 | elif package_version is not None and vs.parse(package_version) >= vs.parse('0.7.0'): 57 | # From 0.6.6.post2 on, vllm supports SPMD inference 58 | # See https://github.com/vllm-project/vllm/pull/12071 59 | 60 | from vllm import LLM 61 | from vllm.distributed import parallel_state 62 | else: 63 | if not is_sglang_available(): 64 | raise ValueError( 65 | f'vllm version {package_version} is not supported and SGLang was not found. 
Currently supported vllm versions are 0.3.1, 0.4.2, 0.5.4, 0.6.3 and 0.7.0+.' 66 | ) 67 | -------------------------------------------------------------------------------- /src/verl/third_party/vllm/vllm_v_0_3_1/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /src/verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
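Downstream code can branch on the `vllm_version` symbol that the dispatch above exports; a small sketch (an annotation — only `vllm_version` itself comes from the module, everything else here is illustrative):

```python
from verl.third_party.vllm import vllm_version

if vllm_version is None:
    # vllm_version stays None on the >= 0.7.0 path (and on the SGLang fallback),
    # where upstream vLLM is used as-is.
    print("using upstream vLLM with native SPMD support")
else:
    print(f"using verl's patched engine for vLLM {vllm_version}")
```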
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) 19 | 20 | from vllm.lora.request import LoRARequest 21 | from vllm.utils import make_async, LRUCache 22 | from vllm.transformers_utils.tokenizers import * 23 | 24 | 25 | class TokenizerGroup: 26 | """A group of tokenizers that can be used for LoRA adapters.""" 27 | 28 | def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, 29 | max_input_length: Optional[int]): 30 | self.enable_lora = enable_lora 31 | self.max_input_length = max_input_length 32 | self.tokenizer = tokenizer 33 | if enable_lora: 34 | self.lora_tokenizers = LRUCache(capacity=max_num_seqs) 35 | else: 36 | self.lora_tokenizers = None 37 | 38 | def encode(self, 39 | prompt: str, 40 | request_id: Optional[str] = None, 41 | lora_request: Optional[LoRARequest] = None) -> List[int]: 42 | tokenizer = self.get_lora_tokenizer(lora_request) 43 | return tokenizer.encode(prompt) 44 | 45 | async def encode_async(self, 46 | prompt: str, 47 | request_id: Optional[str] = None, 48 | lora_request: Optional[LoRARequest] = None) -> List[int]: 49 | tokenizer = await self.get_lora_tokenizer_async(lora_request) 50 | return tokenizer.encode(prompt) 51 | 52 | def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": 53 | if not lora_request or not self.enable_lora: 54 | return self.tokenizer 55 | if lora_request.lora_int_id not in self.lora_tokenizers: 56 | # TODO(sgm): the lora tokenizer is also passed, but may be different 57 | tokenizer = self.tokenizer 58 | # tokenizer = (get_lora_tokenizer( 59 | # lora_request, **self.tokenizer_config) or self.tokenizer) 60 | self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) 61 | return tokenizer 62 | else: 63 | return self.lora_tokenizers.get(lora_request.lora_int_id) 64 | 65 | # FIXME(sgm): for simplicity, we assign the special token here 66 | @property 67 | def pad_token_id(self): 68 | return self.tokenizer.pad_token_id 69 | 70 | @property 71 | def eos_token_id(self): 72 | return self.tokenizer.eos_token_id 73 | -------------------------------------------------------------------------------- /src/verl/third_party/vllm/vllm_v_0_3_1/weight_loaders.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
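A minimal usage sketch for the TokenizerGroup above (an annotation; it assumes transformers is installed and vllm 0.3.1 is importable, since the module imports LoRARequest at load time; the "gpt2" model name is illustrative):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
group = TokenizerGroup(tok, enable_lora=False, max_num_seqs=8, max_input_length=None)

prompt = "with LoRA disabled, every request shares one tokenizer"
assert group.encode(prompt) == tok.encode(prompt)
print(group.eos_token_id)   # forwarded from the underlying tokenizer
```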
14 | # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models 15 | 16 | from typing import Dict 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | # NOTE(shengguangming): replace the origin weight loader function in the class 22 | def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: 23 | """Parallel Linear weight loader.""" 24 | assert param.size() == loaded_weight.size( 25 | ), 'the parameter size does not align with the loaded weight size, param size: {}, loaded_weight size: {}'.format( 26 | param.size(), loaded_weight.size()) 27 | assert param.data.dtype == loaded_weight.data.dtype, "if we want to share weights, the data type should also be the same" 28 | 29 | param.data = loaded_weight.data 30 | 31 | 32 | def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: 33 | """Default weight loader.""" 34 | assert param.size() == loaded_weight.size() 35 | assert param.data.dtype == loaded_weight.data.dtype, "if we want to share weights, the data type should also be the same" 36 | 37 | param.data = loaded_weight.data 38 | 39 | 40 | def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None: 41 | params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) 42 | for name, loaded_weight in actor_weights.items(): 43 | if "lm_head.weight" in name: 44 | # GPT-2 ties the weights of the embedding layer and the final 45 | # linear layer. 46 | continue 47 | if ".attn.bias" in name or ".attn.masked_bias" in name: 48 | # Skip attention mask. 49 | # NOTE: "c_attn.bias" should not be skipped. 50 | continue 51 | if not name.startswith("transformer."): 52 | name = "transformer." + name 53 | param = params_dict[name] 54 | # The HF's GPT-2 implementation uses Conv1D instead of Linear. 55 | # Because of this, we need to transpose the weights. 56 | # Note(zhuohan): the logic below might break quantized models. 57 | for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: 58 | if conv1d_weight_name not in name: 59 | continue 60 | if not name.endswith(".weight"): 61 | continue 62 | # TODO: check megatron 63 | loaded_weight = loaded_weight.t() 64 | weight_loader = getattr(param, "weight_loader", default_weight_loader) 65 | weight_loader(param, loaded_weight) 66 | 67 | 68 | def llama_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None: 69 | # NOTE(shengguangming): the megatron llama may have this prefix 70 | prefix = '0.module.module.' 71 | params_dict = dict(vllm_model.named_parameters()) 72 | for name, loaded_weight in actor_weights.items(): 73 | if name[:len(prefix)] == prefix: 74 | name = name[len(prefix):] 75 | if "rotary_emb.inv_freq" in name: 76 | continue 77 | else: 78 | param = params_dict[name] 79 | weight_loader = getattr(param, "weight_loader", default_weight_loader) 80 | weight_loader(param, loaded_weight) 81 | 82 | 83 | def mistral_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None: 84 | # TODO: need to implement a general way to deal with prefix 85 | prefix = '0.module.module.'
86 | params_dict = dict(vllm_model.named_parameters()) 87 | for name, loaded_weight in actor_weights.items(): 88 | if name[:len(prefix)] == prefix: 89 | name = name[len(prefix):] 90 | if "rotary_emb.inv_freq" in name: 91 | continue 92 | else: 93 | param = params_dict[name] 94 | weight_loader = getattr(param, "weight_loader", default_weight_loader) 95 | weight_loader(param, loaded_weight) 96 | -------------------------------------------------------------------------------- /src/verl/third_party/vllm/vllm_v_0_4_2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /src/verl/third_party/vllm/vllm_v_0_4_2/hf_weight_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models 15 | 16 | from typing import Dict, Union, Optional, Iterable, Tuple 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from vllm.model_executor.model_loader.utils import set_default_torch_dtype 22 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader 23 | 24 | 25 | def update_hf_weight_loader(): 26 | from vllm.model_executor.models.gemma import GemmaForCausalLM 27 | GemmaForCausalLM.load_weights = gemma_load_weights 28 | 29 | 30 | def gemma_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): 31 | stacked_params_mapping = [ 32 | # (param_name, shard_name, shard_id) 33 | ("qkv_proj", "q_proj", "q"), 34 | ("qkv_proj", "k_proj", "k"), 35 | ("qkv_proj", "v_proj", "v"), 36 | ("gate_up_proj", "gate_proj", 0), 37 | ("gate_up_proj", "up_proj", 1), 38 | ] 39 | params_dict = dict(self.named_parameters()) 40 | loaded_params = set() 41 | for name, loaded_weight in weights: 42 | for (param_name, shard_name, shard_id) in stacked_params_mapping: 43 | if shard_name not in name: 44 | continue 45 | name = name.replace(shard_name, param_name) 46 | # Skip loading extra bias for GPTQ models. 
47 |             if name.endswith(".bias") and name not in params_dict:
48 |                 continue
49 |             param = params_dict[name]
50 |             weight_loader = param.weight_loader
51 |             weight_loader(param, loaded_weight, shard_id)
52 |             break
53 |         else:
54 |             # lm_head is not used in vllm as it is tied with embed_tokens.
55 |             # To prevent errors, skip loading lm_head.weight.
56 |             if "lm_head.weight" in name:
57 |                 continue
58 |             # Skip loading extra bias for GPTQ models.
59 |             if name.endswith(".bias") and name not in params_dict:
60 |                 continue
61 |             # GemmaRMSNorm is different from Llama's in that it multiplies
62 |             # (1 + weight) to the output, instead of just weight.
63 |             if "norm.weight" in name:
64 |                 norm_weight = loaded_weight + 1.0  # avoid modifying the actor weights in place
65 |                 param = params_dict[name]
66 |                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
67 |                 weight_loader(param, norm_weight)
68 |             else:
69 |                 param = params_dict[name]
70 |                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
71 |                 weight_loader(param, loaded_weight)
72 |         loaded_params.add(name)
73 |     unloaded_params = params_dict.keys() - loaded_params
74 |     if unloaded_params:
75 |         raise RuntimeError("Some weights are not initialized from checkpoints: "
76 |                            f"{unloaded_params}")
77 | 
78 | 
79 | def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
80 |     assert isinstance(actor_weights, Dict)
81 |     with set_default_torch_dtype(next(vllm_model.parameters()).dtype):  # TODO
82 |         vllm_model.load_weights(actor_weights.items())
83 |         for _, module in vllm_model.named_modules():
84 |             quant_method = getattr(module, "quant_method", None)
85 |             if quant_method is not None:
86 |                 quant_method.process_weights_after_loading(module)
87 |             # FIXME: Remove this after Mixtral is updated
88 |             # to use quant_method.
89 |             if hasattr(module, "process_weights_after_loading"):
90 |                 module.process_weights_after_loading()
91 |     vllm_model = vllm_model.cuda()
92 | 
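93 | # Usage sketch (comment added for clarity; `actor_module` and `rollout_model` are
94 | # hypothetical names): during weight sync one would call
95 | #   load_hf_weights(dict(actor_module.state_dict()), rollout_model)
96 | # so that each vLLM parameter's registered `weight_loader` handles the sharded layout.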
--------------------------------------------------------------------------------
/src/verl/third_party/vllm/vllm_v_0_4_2/tokenizer.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | # Copyright 2023 The vLLM team.
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
15 | 
16 | from typing import List, Optional, Tuple, Union
17 | 
18 | from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)
19 | 
20 | from vllm.lora.request import LoRARequest
21 | from vllm.utils import make_async, LRUCache
22 | from vllm.transformers_utils.tokenizers import *
23 | 
24 | 
25 | class TokenizerGroup:
26 |     """A group of tokenizers that can be used for LoRA adapters."""
27 | 
28 |     def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
29 |                  max_input_length: Optional[int]):
30 |         self.enable_lora = enable_lora
31 |         self.max_input_length = max_input_length
32 |         self.tokenizer = tokenizer
33 |         self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None
34 | 
35 |     def ping(self) -> bool:
36 |         """Check if the tokenizer group is alive."""
37 |         return True
38 | 
39 |     def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]:
40 |         """Get the maximum input length for the LoRA request."""
41 |         return self.max_input_length
42 | 
43 |     def encode(self,
44 |                prompt: str,
45 |                request_id: Optional[str] = None,
46 |                lora_request: Optional[LoRARequest] = None) -> List[int]:
47 |         tokenizer = self.get_lora_tokenizer(lora_request)
48 |         return tokenizer.encode(prompt)
49 | 
50 |     async def encode_async(self,
51 |                            prompt: str,
52 |                            request_id: Optional[str] = None,
53 |                            lora_request: Optional[LoRARequest] = None) -> List[int]:
54 |         tokenizer = await self.get_lora_tokenizer_async(lora_request)
55 |         return tokenizer.encode(prompt)
56 | 
57 |     def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
58 |         if not lora_request or not self.enable_lora:
59 |             return self.tokenizer
60 |         if lora_request.lora_int_id not in self.lora_tokenizers:
61 |             # TODO(sgm): the lora tokenizer is also passed, but may be different
62 |             tokenizer = self.tokenizer
63 |             # tokenizer = (get_lora_tokenizer(
64 |             #     lora_request, **self.tokenizer_config) or self.tokenizer)
65 |             self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
66 |             return tokenizer
67 |         else:
68 |             return self.lora_tokenizers.get(lora_request.lora_int_id)
69 | 
70 |     # NOTE: `encode_async` above calls `get_lora_tokenizer_async`, which was missing
71 |     # here; wrap the sync lookup with vLLM's `make_async` (mirrors upstream vLLM).
72 |     get_lora_tokenizer_async = make_async(get_lora_tokenizer)
73 | 
74 |     # FIXME(sgm): for simplicity, we assign the special token here
75 |     @property
76 |     def pad_token_id(self):
77 |         return self.tokenizer.pad_token_id
78 | 
79 |     @property
80 |     def eos_token_id(self):
81 |         return self.tokenizer.eos_token_id
82 | 
--------------------------------------------------------------------------------
/src/verl/third_party/vllm/vllm_v_0_5_4/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
--------------------------------------------------------------------------------
/src/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | # Copyright 2023 The vLLM team.
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
15 | 
16 | from typing import Dict, Union, Optional, Iterable, Tuple
17 | 
18 | import torch
19 | import torch.nn as nn
20 | 
21 | from vllm.model_executor.model_loader.utils import set_default_torch_dtype
22 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23 | 
24 | 
25 | def update_hf_weight_loader():
26 |     print('no hf weight loader needs to be updated')
27 |     return
28 | 
29 | 
30 | def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
31 |     assert isinstance(actor_weights, Dict)
32 |     with set_default_torch_dtype(next(vllm_model.parameters()).dtype):  # TODO
33 |         if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys():
34 |             del actor_weights["lm_head.weight"]
35 |         vllm_model.load_weights(actor_weights.items())
36 |         for _, module in vllm_model.named_modules():
37 |             quant_method = getattr(module, "quant_method", None)
38 |             if quant_method is not None:
39 |                 quant_method.process_weights_after_loading(module)
40 |             # FIXME: Remove this after Mixtral is updated
41 |             # to use quant_method.
42 |             if hasattr(module, "process_weights_after_loading"):
43 |                 module.process_weights_after_loading()
44 |     vllm_model = vllm_model.cuda()
45 | 
--------------------------------------------------------------------------------
/src/verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | # Copyright 2023 The vLLM team.
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
15 | 
16 | from typing import List, Optional, Tuple, Union
17 | 
18 | from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)
19 | 
20 | from vllm.lora.request import LoRARequest
21 | from vllm.utils import make_async, LRUCache
22 | from vllm.transformers_utils.tokenizers import *
23 | 
24 | 
25 | class TokenizerGroup:
26 |     """A group of tokenizers that can be used for LoRA adapters."""
27 | 
28 |     def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
29 |                  max_input_length: Optional[int]):
30 |         self.enable_lora = enable_lora
31 |         self.max_input_length = max_input_length
32 |         self.tokenizer = tokenizer
33 |         self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None
34 | 
35 |     def ping(self) -> bool:
36 |         """Check if the tokenizer group is alive."""
37 |         return True
38 | 
39 |     def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]:
40 |         """Get the maximum input length for the LoRA request."""
41 |         return self.max_input_length
42 | 
43 |     def encode(self,
44 |                prompt: str,
45 |                request_id: Optional[str] = None,
46 |                lora_request: Optional[LoRARequest] = None) -> List[int]:
47 |         tokenizer = self.get_lora_tokenizer(lora_request)
48 |         return tokenizer.encode(prompt)
49 | 
50 |     async def encode_async(self,
51 |                            prompt: str,
52 |                            request_id: Optional[str] = None,
53 |                            lora_request: Optional[LoRARequest] = None) -> List[int]:
54 |         tokenizer = await self.get_lora_tokenizer_async(lora_request)
55 |         return tokenizer.encode(prompt)
56 | 
57 |     def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
58 |         if not lora_request or not self.enable_lora:
59 |             return self.tokenizer
60 |         if lora_request.lora_int_id not in self.lora_tokenizers:
61 |             # TODO(sgm): the lora tokenizer is also passed, but may be different
62 |             tokenizer = self.tokenizer
63 |             # tokenizer = (get_lora_tokenizer(
64 |             #     lora_request, **self.tokenizer_config) or self.tokenizer)
65 |             self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
66 |             return tokenizer
67 |         else:
68 |             return self.lora_tokenizers.get(lora_request.lora_int_id)
69 | 
70 |     # NOTE: `encode_async` above calls `get_lora_tokenizer_async`, which was missing
71 |     # here; wrap the sync lookup with vLLM's `make_async` (mirrors upstream vLLM).
72 |     get_lora_tokenizer_async = make_async(get_lora_tokenizer)
73 | 
74 |     # FIXME(sgm): for simplicity, we assign the special token here
75 |     @property
76 |     def pad_token_id(self):
77 |         return self.tokenizer.pad_token_id
78 | 
79 |     @property
80 |     def eos_token_id(self):
81 |         return self.tokenizer.eos_token_id
82 | 
--------------------------------------------------------------------------------
/src/verl/third_party/vllm/vllm_v_0_6_3/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | -------------------------------------------------------------------------------- /src/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py 15 | 16 | import os 17 | from dataclasses import dataclass 18 | 19 | from transformers import PretrainedConfig 20 | from vllm.config import EngineConfig 21 | from vllm.engine.arg_utils import EngineArgs 22 | 23 | from .config import LoadConfig, ModelConfig 24 | 25 | 26 | @dataclass 27 | class EngineArgs(EngineArgs): 28 | model_hf_config: PretrainedConfig = None # for verl 29 | 30 | def __post_init__(self): 31 | pass 32 | 33 | def create_model_config(self) -> ModelConfig: 34 | return ModelConfig( 35 | hf_config=self.model_hf_config, 36 | tokenizer_mode=self.tokenizer_mode, 37 | trust_remote_code=self.trust_remote_code, 38 | dtype=self.dtype, 39 | seed=self.seed, 40 | revision=self.revision, 41 | code_revision=self.code_revision, 42 | rope_scaling=self.rope_scaling, 43 | rope_theta=self.rope_theta, 44 | tokenizer_revision=self.tokenizer_revision, 45 | max_model_len=self.max_model_len, 46 | quantization=self.quantization, 47 | quantization_param_path=self.quantization_param_path, 48 | enforce_eager=self.enforce_eager, 49 | max_context_len_to_capture=self.max_context_len_to_capture, 50 | max_seq_len_to_capture=self.max_seq_len_to_capture, 51 | max_logprobs=self.max_logprobs, 52 | disable_sliding_window=self.disable_sliding_window, 53 | skip_tokenizer_init=self.skip_tokenizer_init, 54 | served_model_name=self.served_model_name, 55 | limit_mm_per_prompt=self.limit_mm_per_prompt, 56 | use_async_output_proc=not self.disable_async_output_proc, 57 | override_neuron_config=self.override_neuron_config, 58 | config_format=self.config_format, 59 | mm_processor_kwargs=self.mm_processor_kwargs, 60 | ) 61 | 62 | def create_load_config(self) -> LoadConfig: 63 | return LoadConfig( 64 | load_format=self.load_format, 65 | download_dir=self.download_dir, 66 | model_loader_extra_config=self.model_loader_extra_config, 67 | ignore_patterns=self.ignore_patterns, 68 | ) 69 | 70 | def create_engine_config(self) -> EngineConfig: 71 | engine_config = super().create_engine_config() 72 | 73 | # NOTE[VERL]: Use the world_size set by torchrun 74 | world_size = int(os.getenv("WORLD_SIZE", "-1")) 75 | assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" 76 | engine_config.parallel_config.world_size = world_size 77 | 78 | return engine_config 79 | -------------------------------------------------------------------------------- /src/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates
 2 | # Copyright 2023 The vLLM team.
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
15 | 
16 | from typing import Dict
17 | 
18 | import torch.nn as nn
19 | from vllm.model_executor.model_loader.utils import set_default_torch_dtype
20 | 
21 | 
22 | def update_hf_weight_loader():
23 |     print("no hf weight loader needs to be updated")
24 |     return
25 | 
26 | 
27 | def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
28 |     assert isinstance(actor_weights, Dict)
29 |     with set_default_torch_dtype(next(vllm_model.parameters()).dtype):  # TODO
30 |         if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys():
31 |             del actor_weights["lm_head.weight"]
32 |         vllm_model.load_weights(actor_weights.items())
33 |         for _, module in vllm_model.named_modules():
34 |             quant_method = getattr(module, "quant_method", None)
35 |             if quant_method is not None:
36 |                 quant_method.process_weights_after_loading(module)
37 |             # FIXME: Remove this after Mixtral is updated
38 |             # to use quant_method.
39 |             if hasattr(module, "process_weights_after_loading"):
40 |                 module.process_weights_after_loading()
41 |     vllm_model = vllm_model.cuda()
42 | 
--------------------------------------------------------------------------------
/src/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | # Copyright 2023 The vLLM team.
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py 15 | 16 | from typing import Optional 17 | 18 | from transformers import PreTrainedTokenizer 19 | from vllm.transformers_utils.tokenizer_group import TokenizerGroup 20 | from vllm.utils import LRUCache 21 | 22 | 23 | class TokenizerGroup(TokenizerGroup): 24 | """A group of tokenizers that can be used for LoRA adapters.""" 25 | 26 | def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, 27 | max_input_length: Optional[int]): 28 | self.enable_lora = enable_lora 29 | self.max_input_length = max_input_length 30 | self.tokenizer = tokenizer 31 | self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None 32 | 33 | # FIXME(sgm): for simplicity, we assign the special token here 34 | @property 35 | def pad_token_id(self): 36 | return self.tokenizer.pad_token_id 37 | 38 | @property 39 | def eos_token_id(self): 40 | return self.tokenizer.eos_token_id 41 | -------------------------------------------------------------------------------- /src/verl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /src/verl/trainer/config/evaluation.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | path: /tmp/math_Qwen2-7B-Instruct.parquet 3 | prompt_key: prompt 4 | response_key: responses 5 | data_source_key: data_source 6 | reward_model_key: reward_model 7 | 8 | custom_reward_function: 9 | path: null 10 | name: compute_score 11 | -------------------------------------------------------------------------------- /src/verl/trainer/config/generation.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | nnodes: 1 3 | n_gpus_per_node: 8 4 | 5 | data: 6 | path: ~/data/rlhf/math/test.parquet 7 | prompt_key: prompt 8 | n_samples: 5 9 | output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet 10 | batch_size: 128 11 | 12 | model: 13 | path: ~/models/Qwen2-7B-Instruct 14 | external_lib: null 15 | rollout: 16 | name: vllm 17 | temperature: 1.0 18 | top_k: 50 # 0 for hf rollout, -1 for vllm rollout 19 | top_p: 0.7 20 | prompt_length: 1536 21 | response_length: 512 22 | # for vllm rollout 23 | dtype: bfloat16 # should align with FSDP 24 | gpu_memory_utilization: 0.5 25 | ignore_eos: False 26 | enforce_eager: True 27 | free_cache_engine: True 28 | load_format: dummy_dtensor 29 | tensor_model_parallel_size: 1 30 | max_num_batched_tokens: 8192 31 | max_model_len: null 32 | max_num_seqs: 1024 33 | log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu 34 | log_prob_micro_batch_size_per_gpu: 8 35 | # for fire vllm rollout 36 | use_fire_sampling: False # enable FIRE https://arxiv.org/abs/2410.21236 37 | # for hf rollout 38 | do_sample: True 39 | disable_log_stats: True 40 | enable_chunked_prefill: True 41 | n: 1 42 | actor: 43 | strategy: fsdp # This is for backward-compatibility 44 | ulysses_sequence_parallel_size: 1 # sp size 45 | fsdp_config: 46 | fsdp_size: -1 -------------------------------------------------------------------------------- /src/verl/trainer/config/sft_trainer.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_batch_size: 256 3 | micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu 4 | micro_batch_size_per_gpu: 4 # this is also val batch size 5 | train_files: ~/data/gsm8k/train.parquet 6 | val_files: ~/data/gsm8k/test.parquet 7 | # Single-turn settings 8 | prompt_key: question 9 | response_key: answer 10 | prompt_dict_keys: ['question'] 11 | response_dict_keys: ['answer'] 12 | # Multi-turn settings 13 | multiturn: 14 | enable: false # Set to true to use multi-turn dataset 15 | messages_key: messages # Key for messages list in multi-turn mode 16 | max_length: 1024 17 | truncation: error 18 | balance_dp_token: False 19 | chat_template: null 20 | custom_cls: 21 | path: null 22 | name: null 23 | model: 24 | partial_pretrain: ~/models/gemma-1.1-7b-it 25 | fsdp_config: 26 | wrap_policy: 27 | min_num_params: 0 28 | cpu_offload: False 29 | offload_params: False 30 | external_lib: null 31 | enable_gradient_checkpointing: False 32 | trust_remote_code: False 33 | lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) 34 | lora_alpha: 16 # LoRA scaling factor 35 | target_modules: all-linear # Target modules for LoRA adaptation 36 | use_liger: False 37 | optim: 38 | lr: 1e-5 39 | betas: [0.9, 0.95] 40 | weight_decay: 0.01 41 | warmup_steps_ratio: 0.1 42 | clip_grad: 1.0 43 | ulysses_sequence_parallel_size: 1 
44 | use_remove_padding: False
45 | trainer:
46 |   default_local_dir: /tmp/sft_model
47 |   default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here
48 |   resume_path: null
49 |   project_name: gsm8k-sft
50 |   experiment_name: test
51 |   total_epochs: 4
52 |   total_training_steps: null
53 |   logger: ['console']
54 |   seed: 1
55 | 
56 | 
--------------------------------------------------------------------------------
/src/verl/trainer/main_eval.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Offline evaluation of the performance of a generated file using a reward model and a ground-truth verifier.
16 | The input is a parquet file that contains N generated sequences and, optionally, the ground truth.
17 | 
18 | """
19 | 
20 | import hydra
21 | from verl.utils.fs import copy_to_local
22 | import pandas as pd
23 | import numpy as np
24 | from tqdm import tqdm
25 | from collections import defaultdict
26 | import ray
27 | 
28 | 
29 | def get_custom_reward_fn(config):
30 |     import importlib.util, os, sys
31 |     reward_fn_config = config.get("custom_reward_function") or {}
32 |     file_path = reward_fn_config.get("path")
33 |     if not file_path:
34 |         return None
35 | 
36 |     if not os.path.exists(file_path):
37 |         raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
38 | 
39 |     spec = importlib.util.spec_from_file_location("custom_module", file_path)
40 |     module = importlib.util.module_from_spec(spec)
41 |     try:
42 |         sys.modules["custom_module"] = module
43 |         spec.loader.exec_module(module)
44 |     except Exception as e:
45 |         raise RuntimeError(f"Error loading module from '{file_path}': {e}")
46 | 
47 |     function_name = reward_fn_config.get("name")
48 |     if not hasattr(module, function_name):
49 |         raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")
50 | 
51 |     print(f"using customized reward function '{function_name}' from '{file_path}'")
52 |     raw_fn = getattr(module, function_name)
53 | 
54 |     reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {}))
55 | 
56 |     def wrapped_fn(*args, **kwargs):
57 |         return raw_fn(*args, **kwargs, **reward_kwargs)
58 | 
59 |     return wrapped_fn
60 | 
61 | 
62 | @ray.remote
63 | def process_item(reward_fn, data_source, response_lst, reward_data):
64 |     ground_truth = reward_data['ground_truth']
65 |     score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst]
66 |     return data_source, np.mean(score_lst)
67 | 
68 | 
69 | @hydra.main(config_path='config', config_name='evaluation', version_base=None)
70 | def main(config):
71 |     local_path = copy_to_local(config.data.path)
72 |     dataset = pd.read_parquet(local_path)
73 |     prompts = dataset[config.data.prompt_key]
74 |     responses = dataset[config.data.response_key]
75 |     data_sources = dataset[config.data.data_source_key]
76 |     reward_model_data = dataset[config.data.reward_model_key]
77
| 78 | total = len(dataset) 79 | 80 | # Initialize Ray 81 | if not ray.is_initialized(): 82 | ray.init() 83 | 84 | # evaluate test_score based on data source 85 | data_source_reward = defaultdict(list) 86 | compute_score = get_custom_reward_fn(config) 87 | 88 | # Create remote tasks 89 | remote_tasks = [ 90 | process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i]) for i in range(total) 91 | ] 92 | 93 | # Process results as they come in 94 | with tqdm(total=total) as pbar: 95 | while len(remote_tasks) > 0: 96 | # Use ray.wait to get completed tasks 97 | done_ids, remote_tasks = ray.wait(remote_tasks) 98 | for result_id in done_ids: 99 | data_source, score = ray.get(result_id) 100 | data_source_reward[data_source].append(score) 101 | pbar.update(1) 102 | 103 | metric_dict = {} 104 | for data_source, rewards in data_source_reward.items(): 105 | metric_dict[f'test_score/{data_source}'] = np.mean(rewards) 106 | 107 | print(metric_dict) 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /src/verl/trainer/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /src/verl/trainer/runtime_env.yaml: -------------------------------------------------------------------------------- 1 | working_dir: ./ 2 | excludes: ["/.git/"] 3 | env_vars: 4 | TORCH_NCCL_AVOID_RECORD_STREAMS: "1" 5 | VLLM_ATTENTION_BACKEND: "XFORMERS" -------------------------------------------------------------------------------- /src/verl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import tokenizer 16 | from .tokenizer import hf_tokenizer, hf_processor 17 | 18 | __all__ = tokenizer.__all__ -------------------------------------------------------------------------------- /src/verl/utils/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
--------------------------------------------------------------------------------
/src/verl/utils/config.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from typing import Dict
16 | 
17 | from omegaconf import DictConfig
18 | 
19 | 
20 | def update_dict_with_config(dictionary: Dict, config: DictConfig):
21 |     for key in dictionary:
22 |         if hasattr(config, key):
23 |             dictionary[key] = getattr(config, key)
24 | 
--------------------------------------------------------------------------------
/src/verl/utils/dataset/README.md:
--------------------------------------------------------------------------------
 1 | # Dataset Format
 2 | ## RLHF dataset
 3 | We combine all the data sources into a single parquet file. We directly organize the prompts in the chat format so that multi-turn chats can be easily incorporated. In the prompt, we may add instruction-following text to guide the model to output the answers in a particular format so that we can extract them.
 4 | 
 5 | Math problems
 6 | ```json
 7 | {
 8 |     "data_source": "openai/gsm8k",
 9 |     "prompt": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \"####\""}],
10 |     "ability": "math",
11 |     "reward_model": {
12 |         "style": "rule",
13 |         "ground_truth": ["72"]
14 |     },
15 | }
16 | ```
17 | 
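18 | As a minimal sketch (an added illustration, not one of the repo's data scripts; it assumes pandas with pyarrow installed and a hypothetical question/answer), a file in this format can be written like this:
19 | ```python
20 | import pandas as pd
21 | 
22 | # One illustrative GSM8K-style record following the schema above.
23 | row = {
24 |     "data_source": "openai/gsm8k",
25 |     "prompt": [{"role": "user", "content": "What is 2 + 3? Let's think step by step and output the final answer after \"####\""}],
26 |     "ability": "math",
27 |     "reward_model": {"style": "rule", "ground_truth": ["5"]},
28 | }
29 | pd.DataFrame([row]).to_parquet("train.parquet")
30 | ```
31 | 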
--------------------------------------------------------------------------------
/src/verl/utils/dataset/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from .rl_dataset import RLHFDataset
16 | from .rm_dataset import RMDataset
17 | from .sft_dataset import SFTDataset
18 | 
--------------------------------------------------------------------------------
/src/verl/utils/dataset/template.py:
--------------------------------------------------------------------------------
 1 | re_call_template_sys = """In this environment you have access to a set of tools you can use to assist with the user query. \
 2 | You may perform multiple rounds of function calls. \
 3 | In each round, you can call one or more functions.
 4 | 
 5 | Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```
 6 | 
 7 | In your response, you need to first think about the reasoning process in the mind and then conduct function calling to get the information or perform the actions if needed. \
 8 | The reasoning process and function calling are enclosed within <think> </think> and <tool_call> </tool_call> tags. \
 9 | The results of the function calls will be given back to you after execution, \
10 | and you can continue to call functions until you get the final answer for the user's question. \
11 | Finally, if you have got the answer, enclose it within \\boxed{{}} with latex format and do not continue to call functions, \
12 | i.e., <think> Based on the response from the function call, I get the weather information. </think> The weather in Beijing on 2025-04-01 is \\[ \\boxed{{20C}} \\].
13 | 
14 | For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
15 | <tool_call>
16 | {{"name": <function-name>, "arguments": <args-json-object>}}
17 | </tool_call>
18 | """
19 | 
20 | prompt_template_dict = {}
21 | prompt_template_dict['re_call_template_sys'] = re_call_template_sys
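22 | 
23 | # Usage sketch (comment added for clarity; `func_schemas_json` is a hypothetical
24 | # variable holding the JSON-serialized tool schemas):
25 | #   sys_prompt = prompt_template_dict['re_call_template_sys'].format(func_schemas=func_schemas_json)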
--------------------------------------------------------------------------------
/src/verl/utils/debug/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from .performance import log_gpu_memory_usage
--------------------------------------------------------------------------------
/src/verl/utils/debug/performance.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import torch
16 | import torch.distributed as dist
17 | import logging
18 | 
19 | 
20 | def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0):
21 |     if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):
22 |         memory_allocated = torch.cuda.memory_allocated() / 1024**3
23 |         memory_reserved = torch.cuda.memory_reserved() / 1024**3
24 | 
25 |         message = f'{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}'
26 | 
27 |         if logger is None:
28 |             print(message)
29 |         else:
30 |             logger.log(msg=message, level=level)
31 | 
--------------------------------------------------------------------------------
/src/verl/utils/debug/trajectory_tracker.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The trajectory tracker can be inserted into code to save intermediate results.
16 | The results will be dumped to HDFS for offline comparison.
17 | Each process will have a client that first moves all the tensors to CPU.
18 | """
19 | 
20 | from verl.utils.hdfs_io import makedirs, copy
21 | import torch
22 | import os
23 | import ray
24 | import io
25 | import tempfile
26 | 
27 | from collections import deque
28 | 
29 | remote_copy = ray.remote(copy)
30 | 
31 | 
32 | @ray.remote
33 | def save_to_hdfs(data: io.BytesIO, name, hdfs_dir, verbose):
34 |     filename = name + '.pth'
35 |     with tempfile.TemporaryDirectory() as tmpdirname:
36 |         local_filepath = os.path.join(tmpdirname, filename)
37 |         with open(local_filepath, 'wb') as f:
38 |             f.write(data.getbuffer())
39 |         # upload to hdfs
40 | 
41 |         if verbose:
42 |             print(f'Saving {local_filepath} to {hdfs_dir}')
43 |         try:
44 |             copy(local_filepath, hdfs_dir)
45 |         except Exception as e:
46 |             print(e)
47 | 
48 | 
49 | @ray.remote
50 | class TrajectoryTracker():
51 | 
52 |     def __init__(self, hdfs_dir, verbose) -> None:
53 |         self.hdfs_dir = hdfs_dir
54 |         makedirs(hdfs_dir)
55 |         self.verbose = verbose
56 | 
57 |         self.handle = deque()
58 | 
59 |     def dump(self, data: io.BytesIO, name):
60 |         # get a temp file and write to it
61 |         self.handle.append(save_to_hdfs.remote(data, name, self.hdfs_dir, self.verbose))
62 | 
63 |     def wait_for_hdfs(self):
64 |         while len(self.handle) != 0:
65 |             future = self.handle.popleft()
66 |             ray.get(future)
67 | 
68 | 
69 | def dump_data(data, name):
70 |     enable = os.getenv('VERL_ENABLE_TRACKER', '0') == '1'
71 |     if not enable:
72 |         return
73 |     buffer = io.BytesIO()
74 |     torch.save(data, buffer)
75 |     tracker = get_trajectory_tracker()
76 |     ray.get(tracker.dump.remote(buffer, name))
77 | 
78 | 
79 | def get_trajectory_tracker():
80 |     hdfs_dir = os.getenv('VERL_TRACKER_HDFS_DIR', default=None)
81 |     verbose = os.getenv('VERL_TRACKER_VERBOSE', default='0') == '1'
82 |     assert hdfs_dir is not None
83 |     tracker = TrajectoryTracker.options(name="global_tracker", get_if_exists=True,
84
| lifetime="detached").remote(hdfs_dir, verbose) 85 | return tracker 86 | 87 | 88 | if __name__ == '__main__': 89 | # testing 90 | os.environ['VERL_ENABLE_TRACKER'] = '1' 91 | os.environ['VERL_TRACKER_HDFS_DIR'] = '~/debug/test' 92 | 93 | @ray.remote 94 | def process(iter): 95 | data = {'obs': torch.randn(10, 20)} 96 | dump_data(data, f'process_{iter}_obs') 97 | 98 | ray.init() 99 | 100 | output_lst = [] 101 | 102 | for i in range(10): 103 | output_lst.append(process.remote(i)) 104 | 105 | out = ray.get(output_lst) 106 | 107 | tracker = get_trajectory_tracker() 108 | ray.get(tracker.wait_for_hdfs.remote()) 109 | -------------------------------------------------------------------------------- /src/verl/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utilities for distributed training.""" 15 | import os 16 | 17 | 18 | def initialize_global_process_group(timeout_second=36000): 19 | import torch.distributed 20 | from datetime import timedelta 21 | torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second)) 22 | local_rank = int(os.environ["LOCAL_RANK"]) 23 | rank = int(os.environ["RANK"]) 24 | world_size = int(os.environ["WORLD_SIZE"]) 25 | 26 | if torch.distributed.is_initialized(): 27 | torch.cuda.set_device(local_rank) 28 | return local_rank, rank, world_size 29 | -------------------------------------------------------------------------------- /src/verl/utils/fs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 
16 | # -*- coding: utf-8 -*-
17 | """File-system agnostic IO APIs"""
18 | import os
19 | import tempfile
20 | import hashlib
21 | 
22 | try:
23 |     from hdfs_io import copy, makedirs, exists  # for internal use only
24 | except ImportError:
25 |     from .hdfs_io import copy, makedirs, exists
26 | 
27 | __all__ = ["copy", "exists", "makedirs"]
28 | 
29 | _HDFS_PREFIX = "hdfs://"
30 | 
31 | 
32 | def is_non_local(path):
33 |     return path.startswith(_HDFS_PREFIX)
34 | 
35 | 
36 | def md5_encode(path: str) -> str:
37 |     return hashlib.md5(path.encode()).hexdigest()
38 | 
39 | 
40 | def get_local_temp_path(hdfs_path: str, cache_dir: str) -> str:
41 |     """Return a local temp path that joins cache_dir and the basename of hdfs_path.
42 | 
43 |     Args:
44 |         hdfs_path: the HDFS path of the file
45 |         cache_dir: the local cache directory
46 | 
47 |     Returns:
48 |         the local temp path
49 |     """
50 |     # make an md5 encoding of hdfs_path to avoid directory conflicts
51 |     encoded_hdfs_path = md5_encode(hdfs_path)
52 |     temp_dir = os.path.join(cache_dir, encoded_hdfs_path)
53 |     os.makedirs(temp_dir, exist_ok=True)
54 |     dst = os.path.join(temp_dir, os.path.basename(hdfs_path))
55 |     return dst
56 | 
57 | 
58 | def copy_to_local(src: str, cache_dir=None, filelock='.file.lock', verbose=False) -> str:
59 |     """Copy src from HDFS to local if src is on HDFS, or directly return src.
60 |     If cache_dir is None, we will use the default cache dir of the system. Note that this may cause conflicts if
61 |     the src name is the same between calls.
62 | 
63 |     Args:
64 |         src (str): an HDFS path or a local path
65 | 
66 |     Returns:
67 |         a local path of the copied file
68 |     """
69 |     return copy_local_path_from_hdfs(src, cache_dir, filelock, verbose)
70 | 
71 | 
72 | def copy_local_path_from_hdfs(src: str, cache_dir=None, filelock='.file.lock', verbose=False) -> str:
73 |     """Deprecated. Please use copy_to_local instead."""
74 |     from filelock import FileLock
75 | 
76 |     assert src[-1] != '/', f'Make sure the last char in src is not "/", because it will cause an error. Got {src}'
77 | 
78 |     if is_non_local(src):
79 |         # download from hdfs to local
80 |         if cache_dir is None:
81 |             # get a temp folder
82 |             cache_dir = tempfile.gettempdir()
83 |         os.makedirs(cache_dir, exist_ok=True)
84 |         assert os.path.exists(cache_dir)
85 |         local_path = get_local_temp_path(src, cache_dir)
86 |         # get a specific lock
87 |         filelock = md5_encode(src) + '.lock'
88 |         lock_file = os.path.join(cache_dir, filelock)
89 |         with FileLock(lock_file=lock_file):
90 |             if not os.path.exists(local_path):
91 |                 if verbose:
92 |                     print(f'Copy from {src} to {local_path}')
93 |                 copy(src, local_path)
94 |         return local_path
95 |     else:
96 |         return src
97 | 
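98 | # Usage sketch (comment added for clarity; the paths are hypothetical):
99 | #   copy_to_local('hdfs://ns/user/x/ckpt.bin')  # downloaded once, cached under an md5-keyed temp dir
100 | #   copy_to_local('/local/ckpt.bin')            # local paths are returned unchanged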
--------------------------------------------------------------------------------
/src/verl/utils/import_utils.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Utilities to check if packages are available.
16 | We assume package availability won't change during runtime.
17 | """
18 | 
19 | from functools import cache
20 | from typing import List, Optional
21 | import importlib.util  # import the submodule explicitly so importlib.util.find_spec is available
22 | 
23 | 
24 | @cache
25 | def is_megatron_core_available():
26 |     try:
27 |         mcore_spec = importlib.util.find_spec('megatron.core')
28 |     except ModuleNotFoundError:
29 |         mcore_spec = None
30 |     return mcore_spec is not None
31 | 
32 | 
33 | @cache
34 | def is_vllm_available():
35 |     try:
36 |         vllm_spec = importlib.util.find_spec('vllm')
37 |     except ModuleNotFoundError:
38 |         vllm_spec = None
39 |     return vllm_spec is not None
40 | 
41 | 
42 | @cache
43 | def is_sglang_available():
44 |     try:
45 |         sglang_spec = importlib.util.find_spec('sglang')
46 |     except ModuleNotFoundError:
47 |         sglang_spec = None
48 |     return sglang_spec is not None
49 | 
50 | 
51 | def import_external_libs(external_libs=None):
52 |     if external_libs is None:
53 |         return
54 |     if not isinstance(external_libs, List):
55 |         external_libs = [external_libs]
56 |     import importlib
57 |     for external_lib in external_libs:
58 |         importlib.import_module(external_lib)
59 | 
60 | 
61 | def load_extern_type(file_path: Optional[str], type_name: Optional[str]):
62 |     """Load an external data type based on the file path and type name."""
63 |     import importlib.util, os
64 | 
65 |     if not file_path:
66 |         return None
67 | 
68 |     if not os.path.exists(file_path):
69 |         raise FileNotFoundError(f"Custom type file '{file_path}' not found.")
70 | 
71 |     spec = importlib.util.spec_from_file_location("custom_module", file_path)
72 |     module = importlib.util.module_from_spec(spec)
73 |     try:
74 |         spec.loader.exec_module(module)
75 |     except Exception as e:
76 |         raise RuntimeError(f"Error loading module from '{file_path}': {e}")
77 | 
78 |     if not hasattr(module, type_name):
79 |         raise AttributeError(f"Custom type '{type_name}' not found in '{file_path}'.")
80 | 
81 |     return getattr(module, type_name)
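82 | 
83 | # Usage sketch (comment added for clarity; the path and type name are hypothetical):
84 | #   MyDataset = load_extern_type('/path/to/custom_dataset.py', 'MyDataset')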
--------------------------------------------------------------------------------
/src/verl/utils/logger/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
--------------------------------------------------------------------------------
/src/verl/utils/logger/aggregate_logger.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | A Ray logger will receive logging info from different processes.
16 | """
17 | import numbers
18 | from typing import Dict
19 | 
20 | 
21 | def concat_dict_to_str(data: Dict, step):
22 |     output = [f'step:{step}']
23 |     for k, v in data.items():
24 |         if isinstance(v, numbers.Number):
25 |             output.append(f'{k}:{v:.3f}')
26 |     output_str = ' - '.join(output)
27 |     return output_str
28 | 
29 | 
30 | class LocalLogger:
31 | 
32 |     def __init__(self, remote_logger=None, enable_wandb=False, print_to_console=False):
33 |         self.print_to_console = print_to_console
34 |         if print_to_console:
35 |             print('Using LocalLogger is deprecated. The constructor API will change.')
36 | 
37 |     def flush(self):
38 |         pass
39 | 
40 |     def log(self, data, step):
41 |         if self.print_to_console:
42 |             print(concat_dict_to_str(data, step=step), flush=True)
--------------------------------------------------------------------------------
/src/verl/utils/logging_utils.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import logging
16 | import os
17 | import torch
18 | 
19 | 
20 | def set_basic_config(level):
21 |     """
22 |     This function sets the global logging format and level. It will be called when verl is imported.
23 |     """
24 |     logging.basicConfig(format='%(levelname)s:%(asctime)s:%(message)s', level=level)
25 | 
26 | 
27 | def log_to_file(string):
28 |     print(string)
29 |     if os.path.isdir('logs'):
30 |         with open(f'logs/log_{torch.distributed.get_rank()}', 'a+') as f:
31 |             f.write(string + '\n')
32 | 
--------------------------------------------------------------------------------
/src/verl/utils/megatron/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
--------------------------------------------------------------------------------
/src/verl/utils/megatron/memory.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | 18 | class MemoryBuffer: 19 | 20 | def __init__(self, numel, numel_padded, dtype): 21 | self.numel = numel 22 | self.numel_padded = numel_padded 23 | self.dtype = dtype 24 | self.data = torch.zeros(self.numel_padded, 25 | dtype=self.dtype, 26 | device=torch.cuda.current_device(), 27 | requires_grad=False) 28 | 29 | def zero(self): 30 | """Reset the buffer to zero.""" 31 | self.data.zero_() 32 | 33 | def get(self, shape, start_index): 34 | """Return a tensor with the input `shape` as a view into the 35 | 1-D data starting at `start_index`.""" 36 | end_index = start_index + shape.numel() 37 | assert end_index <= self.numel, \ 38 | 'requested tensor is out of the buffer range.' 39 | buffer_tensor = self.data[start_index:end_index] 40 | buffer_tensor = buffer_tensor.view(shape) 41 | return buffer_tensor 42 | -------------------------------------------------------------------------------- /src/verl/utils/megatron/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import importlib 17 | from packaging.version import Version 18 | 19 | from apex.optimizers import FusedAdam as Adam 20 | from apex.optimizers import FusedSGD as SGD 21 | 22 | from megatron.core.optimizer import OptimizerConfig 23 | 24 | from megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native 25 | 26 | 27 | def get_megatron_optimizer( 28 | model, 29 | config: OptimizerConfig, 30 | no_weight_decay_cond=None, 31 | scale_lr_cond=None, 32 | lr_mult=1.0, 33 | check_for_nan_in_loss_and_grad=False, 34 | overlap_param_gather=False # add for verl 35 | ): 36 | # Base optimizer. 37 | return get_megatron_optimizer_native(config=config, 38 | model_chunks=model, 39 | no_weight_decay_cond=no_weight_decay_cond, 40 | scale_lr_cond=scale_lr_cond, 41 | lr_mult=lr_mult) 42 | -------------------------------------------------------------------------------- /src/verl/utils/megatron/pipeline_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | from megatron.core import parallel_state as mpu 18 | 19 | from .sequence_parallel import pad_to_sequence_parallel 20 | 21 | 22 | def compute_transformers_input_shapes(batches, meta_info): 23 | from flash_attn.bert_padding import unpad_input # flash 2 is a must for Megatron 24 | # pre-compute input shapes for each micro-batch at each pp stage 25 | input_shapes = [] 26 | for model_inputs in batches: 27 | input_ids = model_inputs['input_ids'] 28 | attention_mask = model_inputs['attention_mask'] 29 | input_ids_rmpad = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)[0] # (total_nnz, 1) 30 | if meta_info['sequence_parallel']: 31 | input_ids_rmpad = pad_to_sequence_parallel(input_ids_rmpad) 32 | # compute shapes for model_inputs 33 | input_shapes.append( 34 | torch.Size([ 35 | input_ids_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size(), 1, meta_info['hidden_size'] 36 | ])) 37 | else: 38 | # compute shapes for model_inputs 39 | input_shapes.append(torch.Size([input_ids_rmpad.shape[0], 1, meta_info['hidden_size']])) 40 | return input_shapes 41 | 42 | 43 | def make_batch_generator(batches, vpp_size): 44 | if vpp_size > 1: 45 | # has vpp 46 | batch_generator = [batches] * vpp_size # number of vpp chunks 47 | batch_generator = [iter(b) for b in batch_generator] 48 | else: 49 | # no vpp 50 | batch_generator = iter(batches) 51 | return batch_generator 52 | -------------------------------------------------------------------------------- /src/verl/utils/megatron/sequence_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from megatron.core import parallel_state as mpu 19 | 20 | 21 | def mark_parameter_as_sequence_parallel(parameter): 22 | setattr(parameter, 'sequence_parallel', True) 23 | 24 | 25 | def is_sequence_parallel_param(param): 26 | return hasattr(param, 'sequence_parallel') and param.sequence_parallel 27 | 28 | 29 | def pad_to_sequence_parallel(unpad_tokens: torch.Tensor): 30 | """pad the tokens such that the total length is a multiple of sp world size 31 | 32 | Args: 33 | unpad_tokens: (total_nnz, ...). 
Tokens after removing padding 34 | 35 | Returns: 36 | torch.Tensor: the tokens, padded so that the total length is a multiple of the sp world size 37 | """ 38 | total_nnz = unpad_tokens.shape[0] 39 | sp_world_size = mpu.get_tensor_model_parallel_world_size() 40 | 41 | if total_nnz % sp_world_size == 0: 42 | pad_size = 0 43 | else: 44 | pad_size = sp_world_size - total_nnz % sp_world_size 45 | 46 | if pad_size > 0: 47 | if unpad_tokens.ndim == 1: 48 | unpad_tokens = F.pad(unpad_tokens, (0, pad_size)) 49 | elif unpad_tokens.ndim == 2: 50 | unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size)) 51 | else: 52 | raise NotImplementedError(f'Padding dim {unpad_tokens.ndim} is not supported') 53 | 54 | return unpad_tokens 55 | -------------------------------------------------------------------------------- /src/verl/utils/py_functional.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Contains small Python utility functions 16 | """ 17 | 18 | from typing import Dict 19 | from types import SimpleNamespace 20 | 21 | 22 | def union_two_dict(dict1: Dict, dict2: Dict): 23 | """Union two dicts. Raises an error if the two dicts hold different values for the same key. 24 | 25 | Args: 26 | dict1: 27 | dict2: 28 | 29 | Returns: 30 | 31 | """ 32 | for key, val in dict2.items(): 33 | if key in dict1: 34 | assert dict2[key] == dict1[key], \ 35 | f'{key} in dict1 and dict2 are not the same' 36 | dict1[key] = val 37 | 38 | return dict1 39 | 40 | 41 | def append_to_dict(data: Dict, new_data: Dict): 42 | for key, val in new_data.items(): 43 | if key not in data: 44 | data[key] = [] 45 | data[key].append(val) 46 | 47 | 48 | class NestedNamespace(SimpleNamespace): 49 | 50 | def __init__(self, dictionary, **kwargs): 51 | super().__init__(**kwargs) 52 | for key, value in dictionary.items(): 53 | if isinstance(value, dict): 54 | self.__setattr__(key, NestedNamespace(value)) 55 | else: 56 | self.__setattr__(key, value) 57 | -------------------------------------------------------------------------------- /src/verl/utils/ray_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
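# [Editor's note] Hedged doctest-style sketch for the py_functional helpers above (not part of this file):
#
# >>> union_two_dict({'a': 1}, {'b': 2})
# {'a': 1, 'b': 2}
# >>> d = {}; append_to_dict(d, {'loss': 0.5}); append_to_dict(d, {'loss': 0.4}); d
# {'loss': [0.5, 0.4]}
# >>> ns = NestedNamespace({'model': {'lr': 0.1}})
# >>> ns.model.lr
# 0.1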
14 | """ 15 | Contains commonly used utilities for ray 16 | """ 17 | 18 | import ray 19 | 20 | import concurrent.futures 21 | 22 | 23 | def parallel_put(data_list, max_workers=None): 24 | 25 | def put_data(index, data): 26 | return index, ray.put(data) 27 | 28 | if max_workers is None: 29 | max_workers = min(len(data_list), 16) 30 | 31 | with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: 32 | data_list_f = [executor.submit(put_data, i, data) for i, data in enumerate(data_list)] 33 | res_lst = [] 34 | for future in concurrent.futures.as_completed(data_list_f): 35 | res_lst.append(future.result()) 36 | 37 | # reorder based on index 38 | output = [None for _ in range(len(data_list))] 39 | for res in res_lst: 40 | index, data_ref = res 41 | output[index] = data_ref 42 | 43 | return output 44 | -------------------------------------------------------------------------------- /src/verl/utils/rendezvous/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /src/verl/utils/rendezvous/ray_backend.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import logging 16 | import time 17 | 18 | from cupy.cuda.nccl import NcclCommunicator, get_unique_id 19 | 20 | import ray 21 | from ray.util import list_named_actors 22 | 23 | 24 | @ray.remote 25 | class NCCLIDStore: 26 | 27 | def __init__(self, nccl_id): 28 | self._nccl_id = nccl_id 29 | 30 | def get(self): 31 | return self._nccl_id 32 | 33 | 34 | def get_nccl_id_store_by_name(name): 35 | all_actors = list_named_actors(all_namespaces=True) 36 | matched_actors = [actor for actor in all_actors if actor.get("name", None) == name] 37 | if len(matched_actors) == 1: 38 | actor = matched_actors[0] 39 | return ray.get_actor(**actor) 40 | elif len(matched_actors) > 1: 41 | logging.warning(f"multiple actors with same name found: {matched_actors}") 42 | elif len(matched_actors) == 0: 43 | logging.info(f"failed to get any actor named {name}") 44 | return None 45 | 46 | 47 | def create_nccl_communicator_in_ray(rank: int, 48 | world_size: int, 49 | group_name: str, 50 | max_retries: int = 100, 51 | interval_s: int = 5): 52 | if rank == 0: 53 | nccl_id = get_unique_id() 54 | nccl_id_store = NCCLIDStore.options(name=group_name).remote(nccl_id) 55 | 56 | assert ray.get(nccl_id_store.get.remote()) == nccl_id 57 | communicator = NcclCommunicator( 58 | ndev=world_size, 59 | commId=nccl_id, 60 | rank=0, 61 | ) 62 | return communicator 63 | else: 64 | for i in range(max_retries): 65 | nccl_id_store = get_nccl_id_store_by_name(group_name) 66 | if nccl_id_store is not None: 67 | logging.info(f"nccl_id_store {group_name} got") 68 | nccl_id = ray.get(nccl_id_store.get.remote()) 69 | logging.info(f"nccl id for {group_name} got: {nccl_id}") 70 | communicator = NcclCommunicator( 71 | ndev=world_size, 72 | commId=nccl_id, 73 | rank=rank, 74 | ) 75 | return communicator 76 | logging.info(f"failed to get nccl_id for {i+1} time, sleep for {interval_s} seconds") 77 | time.sleep(interval_s) 78 | -------------------------------------------------------------------------------- /src/verl/utils/reward_score/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # from . import gsm8k, math, prime_math, prime_code 15 | 16 | 17 | def _default_compute_score(data_source, tokenizer, solution_str, ground_truth, extra_info=None): 18 | if data_source == 'openai/gsm8k': 19 | from . import gsm8k 20 | res = gsm8k.compute_score(solution_str, ground_truth) 21 | elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval']: 22 | from . import math 23 | res = math.compute_score(solution_str, ground_truth) 24 | # [Optional] Math-Verify Integration 25 | # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify). 26 | # Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`. 
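# [Editor's note] Hedged usage sketch for the NCCL rendezvous utilities above (not part of this file): rank 0 publishes the NCCL unique id through a named `NCCLIDStore` actor, and every other rank polls `get_nccl_id_store_by_name` until the actor appears, e.g.
#
#     comm = create_nccl_communicator_in_ray(rank=rank, world_size=4, group_name='my_group')
#
# where `my_group` and `rank` are illustrative names; all ranks end up holding communicators built from the same unique id.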
27 | # To use it, override the `compute_score` function with the following implementation: 28 | 29 | # from . import math_verify 30 | # res = math_verify.compute_score(solution_str, ground_truth) 31 | elif data_source == 'math_dapo' or data_source.startswith("aime"): 32 | from . import math_dapo 33 | res = math_dapo.compute_score(solution_str, ground_truth) 34 | elif data_source in [ 35 | 'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12', 36 | 'numina_olympiads' 37 | ]: 38 | from . import prime_math 39 | res = prime_math.compute_score(solution_str, ground_truth) 40 | elif data_source in ['codecontests', 'apps', 'codeforces', 'taco']: 41 | from . import prime_code 42 | res = prime_code.compute_score(solution_str, ground_truth, continuous=True) 43 | elif data_source in ['hiyouga/geometry3k']: 44 | from . import geo3k 45 | res = geo3k.compute_score(solution_str, ground_truth) 46 | elif 're_call' in data_source: 47 | from . import re_call 48 | res = re_call.compute_score_with_format(tokenizer, solution_str, ground_truth) 49 | else: 50 | raise NotImplementedError(f"Reward function is not implemented for {data_source=}") 51 | 52 | if isinstance(res, dict): 53 | return res 54 | elif isinstance(res, tuple): 55 | return res 56 | elif isinstance(res, (int, float, bool)): 57 | return float(res) 58 | else: 59 | return float(res[0]) 60 | -------------------------------------------------------------------------------- /src/verl/utils/reward_score/geo3k.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | from mathruler.grader import extract_boxed_content, grade_answer 17 | 18 | 19 | def format_reward(predict_str: str) -> float: 20 | pattern = re.compile(r'<think>.*</think>.*\\boxed\{.*\}.*', re.DOTALL) 21 | match_result = re.fullmatch(pattern, predict_str) 22 | return 1.0 if match_result else 0.0 23 | 24 | 25 | def acc_reward(predict_str: str, ground_truth: str) -> float: 26 | answer = extract_boxed_content(predict_str) 27 | return 1.0 if grade_answer(answer, ground_truth) else 0.0 28 | 29 | 30 | def compute_score(predict_str: str, ground_truth: str) -> float: 31 | return 0.9 * acc_reward(predict_str, ground_truth) + 0.1 * format_reward(predict_str) 32 | -------------------------------------------------------------------------------- /src/verl/utils/reward_score/gsm8k.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
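# [Editor's note] Hedged doctest-style sketch for geo3k's weighted score above (not part of this file; assumes the `<think>` tags restored in `format_reward`):
#
# >>> s = '<think>area = 6 * 4 / 2</think> The answer is \\boxed{12}'
# >>> compute_score(s, '12')   # 0.9 * accuracy + 0.1 * format
# 1.0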
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | 17 | 18 | def extract_solution(solution_str, method='strict'): 19 | assert method in ['strict', 'flexible'] 20 | 21 | if method == 'strict': 22 | # this also tests the formatting of the model 23 | solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) 24 | if solution is None: 25 | final_answer = None 26 | else: 27 | final_answer = solution.group(0) 28 | final_answer = final_answer.split('#### ')[1].replace(',', '').replace('$', '') 29 | elif method == 'flexible': 30 | answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str) 31 | final_answer = None 32 | if len(answer) == 0: 33 | # no reward if there is no answer 34 | pass 35 | else: 36 | invalid_str = ['', '.'] 37 | # find the last number that is not '.' 38 | for final_answer in reversed(answer): 39 | if final_answer not in invalid_str: 40 | break 41 | return final_answer 42 | 43 | 44 | def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1.): 45 | """The scoring function for GSM8k. 46 | 47 | Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024. 48 | 49 | Args: 50 | solution_str: the solution text 51 | ground_truth: the ground truth 52 | method: the method to extract the solution, choices are 'strict' and 'flexible' 53 | format_score: the score for the format 54 | score: the score for the correct answer 55 | """ 56 | answer = extract_solution(solution_str=solution_str, method=method) 57 | if answer is None: 58 | return 0 59 | else: 60 | if answer == ground_truth: 61 | return score 62 | else: 63 | return format_score -------------------------------------------------------------------------------- /src/verl/utils/reward_score/math_batch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Individual Contributor: Mert Unsal 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .math import compute_score 16 | 17 | 18 | def compute_score_batched(data_sources, solution_strs, ground_truths, extra_infos): 19 | """ 20 | This is a demonstration of what a batched reward function should look like.
21 | Typically, you want to use a batched reward to speed up the process with parallelization 22 | """ 23 | return [ 24 | compute_score(solution_str, ground_truth) for solution_str, ground_truth in zip(solution_strs, ground_truths) 25 | ] 26 | -------------------------------------------------------------------------------- /src/verl/utils/reward_score/math_verify.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | try: 16 | from math_verify.metric import math_metric 17 | from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig 18 | from math_verify.errors import TimeoutException 19 | except ImportError: 20 | print("To use Math-Verify, please install it first by running `pip install math-verify`.") 21 | 22 | 23 | def compute_score(model_output: str, ground_truth: str, timeout_score: float = 0) -> float: 24 | verify_func = math_metric( 25 | gold_extraction_target=(LatexExtractionConfig(),), 26 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()), 27 | ) 28 | ret_score = 0. 29 | 30 | # Wrap the ground truth in \boxed{} format for verification 31 | ground_truth_boxed = "\\boxed{" + ground_truth + "}" 32 | try: 33 | ret_score, _ = verify_func([ground_truth_boxed], [model_output]) 34 | except TimeoutException: 35 | ret_score = timeout_score 36 | except Exception: 37 | pass 38 | 39 | return ret_score 40 | -------------------------------------------------------------------------------- /src/verl/utils/reward_score/prime_code/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 PRIME team and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .utils import check_correctness as apps_check_correctness 16 | import json 17 | import re 18 | import traceback 19 | 20 | 21 | def compute_score(completion, test_cases, continuous=False): 22 | # try to get code solution from completion. if the completion is pure code, this will not take effect. 23 | solution = completion.split('```python')[-1].split('```')[0] 24 | try: 25 | try: 26 | if not isinstance(test_cases, dict): 27 | test_cases = json.loads(test_cases) 28 | except Exception as e: 29 | print(f"Error:{e}") 30 | 31 | # Complete check on all in-out pairs first.
If there is no failure, per-sample test can be skipped. 32 | try: 33 | res, metadata = apps_check_correctness(in_outs=test_cases, generation=solution, timeout=5, debug=False) 34 | metadata = dict(enumerate(metadata))[0] 35 | success = all(map(lambda x: x == True, res)) 36 | if success: 37 | return success, metadata 38 | except Exception as e: 39 | pass 40 | 41 | test_cases_list = [] 42 | inputs = test_cases["inputs"] 43 | outputs = test_cases["outputs"] 44 | for i in range(len(inputs)): 45 | test_cases_list.append({"inputs": [inputs[i]], "outputs": [outputs[i]]}) 46 | 47 | if continuous: 48 | # per sample test: if continuous score is needed, test first 10 samples regardless of failures 49 | # do not test all samples cuz some problems have enormous test cases 50 | metadata_list = [] 51 | res_list = [] 52 | for test_case_id, test_case in enumerate(test_cases_list): 53 | res, metadata = apps_check_correctness(in_outs=test_case, generation=solution, timeout=5, debug=False) 54 | try: 55 | metadata = dict(enumerate(metadata))[0] # metadata can be empty occasionally 56 | except Exception as e: 57 | metadata = {} 58 | metadata["test_case"] = {} 59 | metadata["test_case"]["input"] = str(test_case["inputs"][0]) 60 | metadata["test_case"]["output"] = str(test_case["outputs"][0]) 61 | metadata["test_case"]["res"] = str(res) 62 | metadata_list.append(metadata) 63 | res_list.extend(res) 64 | 65 | if test_case_id >= 9: 66 | break 67 | res_count = len(res_list) if len(res_list) > 0 else 1 68 | success = sum(map(lambda x: x == True, res_list)) / res_count 69 | except Exception as e: 70 | traceback.print_exc(10) 71 | success = False 72 | metadata_list = None 73 | return success, metadata_list 74 | -------------------------------------------------------------------------------- /src/verl/utils/reward_score/prime_code/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 PRIME team and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Borrowed from: https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py 16 | 17 | import multiprocessing 18 | from typing import Dict, Optional 19 | from datasets import load_dataset 20 | from .testing_util import run_test 21 | import traceback 22 | import os, sys 23 | 24 | 25 | def _temp_run(sample, generation, debug, result, metadata_list, timeout): 26 | with open(os.devnull, 'w') as devnull: 27 | sys.stdout = devnull 28 | sys.stderr = devnull 29 | try: 30 | res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout) 31 | result.append(res) 32 | metadata_list.append(metadata) 33 | except Exception as e: 34 | # print(e) # some tracebacks are extremely long. 
35 | traceback.print_exc(10) 36 | result.append([-1 for i in range(len(sample['inputs']))]) 37 | metadata_list.append({}) 38 | 39 | 40 | def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True): 41 | """Check correctness of code generation with a global timeout. 42 | The global timeout is to catch some extreme/rare cases not handled by the timeouts 43 | inside `run_test`""" 44 | 45 | manager = multiprocessing.Manager() 46 | result = manager.list() 47 | metadata_list = manager.list() 48 | p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout)) 49 | p.start() 50 | p.join(timeout=timeout + 1) 51 | if p.is_alive(): 52 | p.kill() 53 | # p.terminate() 54 | if not result: 55 | # consider that all tests failed 56 | result = [[-1 for i in range(len(in_outs["inputs"]))]] 57 | if debug: 58 | print("global timeout") 59 | return result[0], metadata_list 60 | -------------------------------------------------------------------------------- /src/verl/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for tokenization.""" 15 | import warnings 16 | 17 | __all__ = ['hf_tokenizer', 'hf_processor'] 18 | 19 | 20 | def set_pad_token_id(tokenizer): 21 | """Set pad_token_id to eos_token_id if it is None. 22 | 23 | Args: 24 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to be set. 25 | 26 | """ 27 | if tokenizer.pad_token_id is None: 28 | tokenizer.pad_token_id = tokenizer.eos_token_id 29 | warnings.warn(f'tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}') 30 | if tokenizer.pad_token is None: 31 | tokenizer.pad_token = tokenizer.eos_token 32 | warnings.warn(f'tokenizer.pad_token is None. Now set to {tokenizer.eos_token}') 33 | 34 | 35 | def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs): 36 | """Create a huggingface pretrained tokenizer which correctly handles eos and pad tokens. 37 | 38 | Args: 39 | 40 | name_or_path (str): The name or path of the tokenizer. 41 | correct_pad_token (bool): Whether to correct the pad token id. 42 | correct_gemma2 (bool): Whether to correct the gemma2 tokenizer. 43 | 44 | Returns: 45 | 46 | transformers.PreTrainedTokenizer: The pretrained tokenizer. 47 | 48 | """ 49 | from transformers import AutoTokenizer 50 | if correct_gemma2 and isinstance(name_or_path, str) and 'gemma-2-2b-it' in name_or_path: 51 | # the EOS token in gemma2 is ambiguous, which may worsen RL performance. 52 | # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a 53 | warnings.warn('Found gemma-2-2b-it tokenizer.
Set eos_token and eos_token_id to <end_of_turn> and 107.') 54 | kwargs['eos_token'] = '<end_of_turn>' 55 | kwargs['eos_token_id'] = 107 56 | tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs) 57 | if correct_pad_token: 58 | set_pad_token_id(tokenizer) 59 | return tokenizer 60 | 61 | 62 | def hf_processor(name_or_path, **kwargs): 63 | """Create a huggingface processor to process multimodal data. 64 | 65 | Args: 66 | name_or_path (str): The name of the processor. 67 | 68 | Returns: 69 | transformers.ProcessorMixin: The pretrained processor. 70 | """ 71 | from transformers import AutoProcessor 72 | try: 73 | processor = AutoProcessor.from_pretrained(name_or_path, **kwargs) 74 | except Exception: 75 | processor = None 76 | # Avoid loading the tokenizer, see: 77 | # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344 78 | if processor is not None and "Processor" not in processor.__class__.__name__: 79 | processor = None 80 | return processor 81 | -------------------------------------------------------------------------------- /src/verl/utils/torch_dtypes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Adapted from Cruise. 16 | """ 17 | 18 | import torch 19 | 20 | from typing import Union 21 | 22 | HALF_LIST = [16, "16", "fp16", "float16", torch.float16] 23 | FLOAT_LIST = [32, "32", "fp32", "float32", torch.float32] 24 | BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16] 25 | 26 | 27 | class PrecisionType(object): 28 | """Type of precision used.
29 | 30 | >>> PrecisionType.HALF == "16" 31 | True 32 | >>> PrecisionType.HALF in (16, "16") 33 | True 34 | """ 35 | 36 | HALF = "16" 37 | FLOAT = "32" 38 | FULL = "64" 39 | BFLOAT = "bf16" 40 | MIXED = "mixed" 41 | 42 | @staticmethod 43 | def supported_type(precision: Union[str, int]) -> bool: 44 | return str(precision) in PrecisionType.supported_types() 45 | 46 | @staticmethod 47 | def supported_types() -> list[str]: 48 | return [PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.FULL, PrecisionType.BFLOAT, PrecisionType.MIXED] 49 | 50 | @staticmethod 51 | def is_fp16(precision): 52 | return precision in HALF_LIST 53 | 54 | @staticmethod 55 | def is_fp32(precision): 56 | return precision in FLOAT_LIST 57 | 58 | @staticmethod 59 | def is_bf16(precision): 60 | return precision in BFLOAT_LIST 61 | 62 | @staticmethod 63 | def to_dtype(precision): 64 | if precision in HALF_LIST: 65 | return torch.float16 66 | elif precision in FLOAT_LIST: 67 | return torch.float32 68 | elif precision in BFLOAT_LIST: 69 | return torch.bfloat16 70 | else: 71 | raise RuntimeError(f"unexpected precision: {precision}") 72 | 73 | @staticmethod 74 | def to_str(precision): 75 | if precision == torch.float16: 76 | return 'fp16' 77 | elif precision == torch.float32: 78 | return 'fp32' 79 | elif precision == torch.bfloat16: 80 | return 'bf16' 81 | else: 82 | raise RuntimeError(f"unexpected precision: {precision}") 83 | -------------------------------------------------------------------------------- /src/verl/version/version: -------------------------------------------------------------------------------- 1 | 0.2.0.dev 2 | -------------------------------------------------------------------------------- /src/verl/workers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /src/verl/workers/actor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
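# [Editor's note] Hedged doctest-style sketch for PrecisionType above (not part of this file):
#
# >>> import torch
# >>> PrecisionType.to_dtype('bf16')
# torch.bfloat16
# >>> PrecisionType.to_str(torch.float16)
# 'fp16'
# >>> PrecisionType.is_fp32(32)
# True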
14 | 15 | from .base import BasePPOActor 16 | from .dp_actor import DataParallelPPOActor 17 | 18 | __all__ = ["BasePPOActor", "DataParallelPPOActor"] 19 | -------------------------------------------------------------------------------- /src/verl/workers/actor/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The base class for Actor 16 | """ 17 | from abc import ABC, abstractmethod 18 | from typing import Iterable, Dict 19 | 20 | from verl import DataProto 21 | import torch 22 | 23 | __all__ = ['BasePPOActor'] 24 | 25 | 26 | class BasePPOActor(ABC): 27 | 28 | def __init__(self, config): 29 | """The base class for PPO actor 30 | 31 | Args: 32 | config (DictConfig): a config passed to the PPOActor. We expect the type to be 33 | DictConfig (https://omegaconf.readthedocs.io/), but it can be any namedtuple in general. 34 | """ 35 | super().__init__() 36 | self.config = config 37 | 38 | @abstractmethod 39 | def compute_log_prob(self, data: DataProto) -> torch.Tensor: 40 | """Compute logits given a batch of data. 41 | 42 | Args: 43 | data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```, 44 | ```attention_mask``` and ```position_ids```. 45 | 46 | Returns: 47 | DataProto: a DataProto containing the key ```log_probs``` 48 | 49 | 50 | """ 51 | pass 52 | 53 | @abstractmethod 54 | def update_policy(self, data: DataProto) -> Dict: 55 | """Update the policy with an iterator of DataProto 56 | 57 | Args: 58 | data (DataProto): an iterator over the DataProto that returns by 59 | ```make_minibatch_iterator``` 60 | 61 | Returns: 62 | Dict: a dictionary contains anything. Typically, it contains the statistics during updating the model 63 | such as ```loss```, ```grad_norm```, etc,. 64 | 65 | """ 66 | pass 67 | -------------------------------------------------------------------------------- /src/verl/workers/critic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
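# [Editor's note] Hedged sketch (not part of this file): the minimal surface a concrete actor must implement against the BasePPOActor interface above; the bodies are illustrative placeholders, not the real DataParallelPPOActor logic.
#
#     class MyActor(BasePPOActor):
#
#         def compute_log_prob(self, data):
#             # data is a DataProto carrying input_ids, attention_mask, position_ids
#             ...
#
#         def update_policy(self, data):
#             # one PPO update over the minibatch iterator; return training stats
#             return {'loss': 0.0, 'grad_norm': 0.0}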
14 | 15 | from .base import BasePPOCritic 16 | from .dp_critic import DataParallelPPOCritic 17 | 18 | __all__ = ["BasePPOCritic", "DataParallelPPOCritic"] 19 | -------------------------------------------------------------------------------- /src/verl/workers/critic/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Base class for a critic 16 | """ 17 | from abc import ABC, abstractmethod 18 | 19 | import torch 20 | 21 | from verl import DataProto 22 | 23 | __all__ = ['BasePPOCritic'] 24 | 25 | 26 | class BasePPOCritic(ABC): 27 | 28 | def __init__(self, config): 29 | super().__init__() 30 | self.config = config 31 | 32 | @abstractmethod 33 | def compute_values(self, data: DataProto) -> torch.Tensor: 34 | """Compute values""" 35 | pass 36 | 37 | @abstractmethod 38 | def update_critic(self, data: DataProto): 39 | """Update the critic""" 40 | pass 41 | -------------------------------------------------------------------------------- /src/verl/workers/reward_manager/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 PRIME team and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .naive import NaiveRewardManager 16 | from .prime import PrimeRewardManager 17 | from .batch import BatchRewardManager 18 | from .dapo import DAPORewardManager 19 | from .re_call import ReCallRewardManagerWithSave -------------------------------------------------------------------------------- /src/verl/workers/reward_model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
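# [Editor's note] Hedged sketch (not part of this file): a concrete critic mirrors the BasePPOCritic interface above with value estimation plus its own update step; placeholder bodies only.
#
#     class MyCritic(BasePPOCritic):
#
#         def compute_values(self, data):
#             # data is a DataProto; return a torch.Tensor of value estimates
#             ...
#
#         def update_critic(self, data):
#             # one optimization step over the batch; return training stats
#             ...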
14 | 15 | from .base import BasePPORewardModel 16 | -------------------------------------------------------------------------------- /src/verl/workers/reward_model/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The base class for reward model 16 | """ 17 | 18 | from abc import ABC, abstractmethod 19 | 20 | from verl import DataProto 21 | 22 | 23 | class BasePPORewardModel(ABC): 24 | 25 | def __init__(self, config): 26 | self.config = config 27 | 28 | @abstractmethod 29 | def compute_reward(self, data: DataProto) -> DataProto: 30 | """Compute the reward given input_ids. The reward model should output a tensor with shape 31 | [batch_size, sequence_length], and the value at the [EOS] position should be gathered. 32 | 33 | Args: 34 | data: must contain keys "input_ids", "attention_mask" and "position_ids". 35 | - input_ids: [batch_size, sequence_length] 36 | - attention_mask: [batch_size, sequence_length] 37 | - position_ids: [batch_size, sequence_length] 38 | 39 | Returns: a DataProto containing "reward". Only the [EOS] position carries the reward; 40 | other positions should have zero reward. Note that this may change in the future if we use 41 | dense rewards, so we leave the interface general. 42 | - reward: [batch_size, sequence_length]. 43 | 44 | """ 45 | pass 46 | -------------------------------------------------------------------------------- /src/verl/workers/reward_model/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .reward_model import MegatronRewardModel 16 | -------------------------------------------------------------------------------- /src/verl/workers/rollout/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
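# [Editor's note] Hedged sketch of the compute_reward contract above (not part of this file): the returned tensor is zero everywhere except each sequence's [EOS] position.
#
# >>> import torch
# >>> reward = torch.zeros(1, 5)   # batch_size=1, sequence_length=5
# >>> reward[0, 3] = 0.7           # scalar score gathered at the EOS index (here 3)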
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .base import BaseRollout 16 | from .naive import NaiveRollout 17 | from .hf_rollout import HFRollout 18 | 19 | __all__ = ["BaseRollout", "NaiveRollout", "HFRollout"] 20 | -------------------------------------------------------------------------------- /src/verl/workers/rollout/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | from typing import Iterable, Union 17 | 18 | from verl import DataProto 19 | 20 | __all__ = ['BaseRollout'] 21 | 22 | 23 | class BaseRollout(ABC): 24 | 25 | def __init__(self): 26 | """ 27 | 28 | Note: 29 | prompts are passed directly to `generate_sequences` as a DataProto, 30 | so this base initializer takes no arguments. 31 | """ 32 | super().__init__() 33 | 34 | @abstractmethod 35 | def generate_sequences(self, prompts: DataProto) -> DataProto: 36 | """Generate sequences""" 37 | pass 38 | -------------------------------------------------------------------------------- /src/verl/workers/rollout/naive/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .naive_rollout import NaiveRollout 16 | -------------------------------------------------------------------------------- /src/verl/workers/rollout/sglang_rollout/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .sglang_rollout import SGLangRollout 16 | -------------------------------------------------------------------------------- /src/verl/workers/rollout/vllm_rollout/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from importlib.metadata import version, PackageNotFoundError 16 | from packaging.version import Version 17 | 18 | ### 19 | # [SUPPORT AMD:] 20 | import torch 21 | ### 22 | 23 | 24 | def get_version(pkg): 25 | try: 26 | return version(pkg) 27 | except PackageNotFoundError: 28 | return None 29 | 30 | 31 | package_name = 'vllm' 32 | package_version = get_version(package_name) 33 | 34 | ### 35 | # [SUPPORT AMD:] 36 | if "AMD" in torch.cuda.get_device_name(): 37 | import re 38 | package_version = version(package_name) 39 | package_version = re.match(r'(\d+\.\d+\.?\d*)', package_version).group(1) 40 | else: 41 | package_version = get_version(package_name) 42 | ### 43 | 44 | # compare as versions rather than strings, so that e.g. '0.10.0' sorts after '0.6.3' 45 | if Version(package_version) <= Version('0.6.3'): 46 | vllm_mode = 'customized' 47 | from .vllm_rollout import vLLMRollout 48 | from .fire_vllm_rollout import FIREvLLMRollout 49 | else: 50 | vllm_mode = 'spmd' 51 | from .vllm_rollout_spmd import vLLMRollout, vLLMRolloutWithTool 52 | -------------------------------------------------------------------------------- /src/verl/workers/sharding_manager/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
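# [Editor's note] Hedged doctest-style sketch (not part of this file) of the pitfall behind the version gate above: plain string comparison misorders multi-digit version components, while packaging.version does not.
#
# >>> '0.10.0' <= '0.6.3'   # lexicographic comparison: wrong for versions
# True
# >>> from packaging.version import Version
# >>> Version('0.10.0') <= Version('0.6.3')
# False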
14 | 15 | from verl.utils.import_utils import ( 16 | is_vllm_available, 17 | is_sglang_available, 18 | is_megatron_core_available, 19 | ) 20 | 21 | from .base import BaseShardingManager 22 | from .fsdp_ulysses import FSDPUlyssesShardingManager 23 | 24 | if is_megatron_core_available() and is_vllm_available(): 25 | from .megatron_vllm import AllGatherPPModel, MegatronVLLMShardingManager 26 | else: 27 | AllGatherPPModel = None 28 | MegatronVLLMShardingManager = None 29 | 30 | if is_vllm_available(): 31 | from .fsdp_vllm import FSDPVLLMShardingManager 32 | else: 33 | FSDPVLLMShardingManager = None 34 | 35 | # NOTE(linjunrong): Due to recent fp8 support in SGLang, importing any symbol related to SGLang's model_runner now checks CUDA device capability. 36 | # However, under veRL's setup the main process of ray cannot find any CUDA device, which would potentially lead to: 37 | # "RuntimeError: No CUDA GPUs are available". 38 | # For this reason, sharding_manager.__init__ should not import SGLangShardingManager, and users need to import it via its absolute path. 39 | # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 40 | # if is_sglang_available(): 41 | #     from .fsdp.fsdp_sglang import FSDPSGLangShardingManager 42 | # else: 43 | #     FSDPSGLangShardingManager = None 44 | -------------------------------------------------------------------------------- /src/verl/workers/sharding_manager/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Sharding manager to implement HybridEngine 16 | """ 17 | 18 | from verl import DataProto 19 | 20 | 21 | class BaseShardingManager: 22 | 23 | def __enter__(self): 24 | pass 25 | 26 | def __exit__(self, exc_type, exc_value, traceback): 27 | pass 28 | 29 | def preprocess_data(self, data: DataProto) -> DataProto: 30 | return data 31 | 32 | def postprocess_data(self, data: DataProto) -> DataProto: 33 | return data 34 | -------------------------------------------------------------------------------- /src/verl/workers/sharding_manager/fsdp_ulysses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
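# [Editor's note] Hedged usage sketch for BaseShardingManager above (not part of this file): sharding managers are meant to wrap generation as context managers, with data reshaped on the way in and out, e.g.
#
#     with sharding_manager:                            # __enter__ reshards weights
#         batch = sharding_manager.preprocess_data(batch)
#         out = rollout.generate_sequences(batch)
#         out = sharding_manager.postprocess_data(out)
#
# `sharding_manager`, `rollout`, and `batch` are illustrative names.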
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Contains a resharding manager that binds weights from FSDP zero3 to XPerfGPT 16 | """ 17 | from .base import BaseShardingManager 18 | 19 | from torch.distributed.device_mesh import DeviceMesh 20 | 21 | from verl.utils.torch_functional import allgather_dict_tensors 22 | from verl.protocol import all_gather_data_proto 23 | from verl.utils.ulysses import set_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_group 24 | import numpy as np 25 | 26 | import torch 27 | import torch.distributed 28 | 29 | from verl import DataProto 30 | 31 | 32 | class FSDPUlyssesShardingManager(BaseShardingManager): 33 | """ 34 | Sharding manager to support data resharding when using FSDP + Ulysses 35 | """ 36 | 37 | def __init__(self, device_mesh: DeviceMesh): 38 | super().__init__() 39 | self.device_mesh = device_mesh 40 | self.seed_offset = 12345 41 | 42 | def __enter__(self): 43 | if self.device_mesh is not None: 44 | # We have a global SP group 45 | # so we have to change to use model-specific sp group 46 | self.prev_sp_group = get_ulysses_sequence_parallel_group() 47 | set_ulysses_sequence_parallel_group(self.device_mesh['sp'].get_group()) 48 | # TODO: check how to set seed for each model 49 | 50 | def __exit__(self, exc_type, exc_value, traceback): 51 | # restore random states 52 | if self.device_mesh is not None: 53 | # revert to previous sp group 54 | set_ulysses_sequence_parallel_group(self.prev_sp_group) 55 | # TODO: check how to set seed for each model 56 | 57 | def preprocess_data(self, data: DataProto) -> DataProto: 58 | """ 59 | AllGather data from sp region 60 | This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE 61 | In Ulysses, we need to make sure the same data is used across a SP group 62 | """ 63 | if self.device_mesh is not None: 64 | sp_size = self.device_mesh['sp'].size() 65 | group = self.device_mesh['sp'].get_group() 66 | 67 | all_gather_data_proto(data=data, process_group=group) 68 | return data 69 | 70 | def postprocess_data(self, data: DataProto) -> DataProto: 71 | """ 72 | Split the data to follow FSDP partition 73 | """ 74 | if self.device_mesh is not None: 75 | sp_size = self.device_mesh['sp'].size() 76 | sp_rank = self.device_mesh['sp'].get_local_rank() 77 | data = data.chunk(chunks=sp_size)[sp_rank] 78 | return data -------------------------------------------------------------------------------- /src/verl/workers/sharding_manager/patch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
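# [Editor's note] Hedged worked example (not part of this file) of the Ulysses reshape round trip above: with an sp group of size 2 where each rank holds 8 rows, preprocess_data all-gathers so both ranks see all 16 rows, and postprocess_data chunks the result back so rank 0 keeps rows 0-7 and rank 1 keeps rows 8-15.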
14 | 15 | from .fsdp_vllm_patch import patched_ds_v3_load_weights 16 | -------------------------------------------------------------------------------- /src/version: -------------------------------------------------------------------------------- 1 | 0.2.0 2 | --------------------------------------------------------------------------------