├── images └── figure-1.png ├── rstar2_agent ├── config │ ├── rstar2_agent_loop.yaml │ ├── tool_config │ │ ├── jupyter_tool_config.yaml │ │ └── python_tool_config.yaml │ └── rstar2_agent_trainer.yaml ├── reward │ ├── __init__.py │ ├── compute_score.py │ └── server.py ├── rollout │ ├── __init__.py │ └── rstar2_agent_loop.py ├── down_sample │ ├── __init__.py │ ├── reject_sampling.py │ ├── utils.py │ └── roc.py ├── tools │ ├── __init__.py │ ├── tool_parser.py │ ├── code_judge_tool.py │ ├── code_judge_utils.py │ └── request_processor.py ├── __init__.py ├── main_rstar2_agent.py └── rstar2_agent_ray_trainer.py ├── .gitmodules ├── pyproject.toml ├── install.sh ├── CODE_OF_CONDUCT.md ├── fused_compute_score ├── __init__.py ├── math_verify.py └── prime_math │ ├── math_normalize.py │ ├── __init__.py │ └── grader.py ├── LICENSE ├── SUPPORT.md ├── data_preprocess ├── math500_rstar2_agent_loop.py ├── aime2024_rstar2_agent_loop.py ├── dapo_rstar2_agent_loop.py └── aime2025_rstar2_agent_loop.py ├── SECURITY.md ├── examples ├── math500_eval.sh ├── run_qwen3-14b_rstar2_agent_weave.sh ├── aime_eval.sh └── chat_with_tool_call.py ├── .gitignore └── README.md /images/figure-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/rStar/HEAD/images/figure-1.png -------------------------------------------------------------------------------- /rstar2_agent/config/rstar2_agent_loop.yaml: -------------------------------------------------------------------------------- 1 | - name: rstar2_agent 2 | _target_: rstar2_agent.rollout.rstar2_agent_loop.RStar2AgentLoop 3 | -------------------------------------------------------------------------------- /rstar2_agent/reward/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | from .server import CodeJudgeRewardManager 5 | -------------------------------------------------------------------------------- /rstar2_agent/rollout/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .rstar2_agent_loop import RStar2AgentLoop 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "code-judge"] 2 | path = code-judge 3 | url = https://github.com/0xWJ/code-judge.git 4 | [submodule "verl"] 5 | path = verl 6 | url = https://github.com/J-shang/verl.git 7 | -------------------------------------------------------------------------------- /rstar2_agent/down_sample/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .reject_sampling import reject_equal_reward 5 | from .roc import resample_of_correct 6 | -------------------------------------------------------------------------------- /rstar2_agent/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .code_judge_tool import CodeJudgeTool, SimJupyterTool, PythonTool 5 | from .tool_parser import RStar2AgentHermesToolParser 6 | -------------------------------------------------------------------------------- /rstar2_agent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | from .rollout.rstar2_agent_loop import RStar2AgentLoop 5 | from .tools import RStar2AgentHermesToolParser 6 | from .reward import CodeJudgeRewardManager 7 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "rstar2_agent" 7 | version = "0.1.0" 8 | description = "rstar2-agent core training recipe" 9 | dependencies = [ 10 | "weave", 11 | "sympy", 12 | "scipy", 13 | "math-verify" 14 | ] 15 | 16 | [tool.setuptools] 17 | packages = ["rstar2_agent", "fused_compute_score"] 18 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") &>/dev/null && pwd -P) 2 | 3 | cd $SCRIPT_DIR 4 | git submodule init 5 | git submodule update 6 | 7 | # install verl 8 | pip install "torch<2.8" 9 | pip install -r verl/requirements_sglang.txt 10 | pip install -e verl 11 | 12 | # install code judge 13 | pip install -r code-judge/requirements.txt 14 | pip install -e code-judge 15 | 16 | # install rstar2_agent 17 | pip install -e . 18 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /fused_compute_score/__init__.py: -------------------------------------------------------------------------------- 1 | from .prime_math import compute_score as prime_compute_score 2 | from .math_verify import compute_score as math_verify_compute_score 3 | 4 | def compute_score(model_output: str, ground_truth: str) -> bool: 5 | try: 6 | prime_score = prime_compute_score(model_output, ground_truth)[0] 7 | if prime_score: 8 | return 1.0 9 | except Exception as e: 10 | prime_score = 0.0 11 | try: 12 | math_verify_score = math_verify_compute_score(model_output, ground_truth) 13 | if math_verify_score: 14 | return 1.0 15 | except Exception as e: 16 | return 0.0 17 | return 0.0 18 | -------------------------------------------------------------------------------- /rstar2_agent/config/tool_config/jupyter_tool_config.yaml: -------------------------------------------------------------------------------- 1 | tools: 2 | - class_name: "rstar2_agent.tools.code_judge_tool.CodeJudgeTool" 3 | config: 4 | type: native 5 | request_processor_concurrency: 8 6 | request_processor_batch_size: 32 7 | request_processor_batch_timeout_seconds: 30 8 | host_addr: "localhost" 9 | host_port: "8088" 10 | tool_schema: 11 | type: "function" 12 | function: 13 | name: "jupyter_code" 14 | description: "Execute python code in a Jupyter notebook cell and return result" 15 | parameters: 16 | type: "object" 17 | properties: 18 | code: 19 | type: "string" 20 | description: "The python code to execute in a single cell" 21 | required: ["code"] 22 | 
-------------------------------------------------------------------------------- /rstar2_agent/reward/compute_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from verl.utils.reward_score.prime_math import compute_score as prime_compute_score 5 | from verl.utils.reward_score.math_verify import compute_score as math_verify_compute_score 6 | 7 | def compute_score(model_output: str, ground_truth: str) -> bool: 8 | try: 9 | prime_score = prime_compute_score(model_output, ground_truth)[0] 10 | if prime_score: 11 | return 1.0 12 | except Exception as e: 13 | prime_score = 0.0 14 | try: 15 | math_verify_score = math_verify_compute_score(model_output, ground_truth) 16 | if math_verify_score: 17 | return 1.0 18 | except Exception as e: 19 | return 0.0 20 | return 0.0 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /rstar2_agent/config/tool_config/python_tool_config.yaml: -------------------------------------------------------------------------------- 1 | tools: 2 | - class_name: "rstar2_agent.tools.code_judge_tool.PythonTool" 3 | config: 4 | type: native 5 | request_processor_concurrency: 8 6 | request_processor_batch_size: 32 7 | request_processor_batch_timeout_seconds: 30 8 | host_addr: "localhost" 9 | host_port: "8088" 10 | tool_schema: 11 | type: "function" 12 | function: 13 | name: "python_code_with_standard_io" 14 | description: "Execute Python code with standard input and capture standard output. This function takes a Python code string and an input string, provides the input string through standard input (stdin) to the code, and captures and returns any output produced through standard output (stdout). If the executed code raises an exception, the error message will be captured and returned instead." 15 | parameters: 16 | type: "object" 17 | properties: 18 | code: 19 | type: "string" 20 | description: "A string containing Python code to be executed. The code can read from standard input using the input() function." 21 | input: 22 | type: "string" 23 | description: "A string that will be provided as standard input to the code when it calls input()." 24 | required: ["code", "input"] 25 | -------------------------------------------------------------------------------- /fused_compute_score/math_verify.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | try: 16 | from math_verify.errors import TimeoutException 17 | from math_verify.metric import math_metric 18 | from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig 19 | except ImportError: 20 | print("To use Math-Verify, please install it first by running `pip install math-verify`.") 21 | 22 | 23 | def compute_score(model_output: str, ground_truth: str, timeout_score: float = 0) -> bool: 24 | verify_func = math_metric( 25 | gold_extraction_target=(LatexExtractionConfig(),), 26 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()), 27 | ) 28 | ret_score = 0.0 29 | 30 | # Wrap the ground truth in \boxed{} format for verification 31 | ground_truth_boxed = "\\boxed{" + ground_truth + "}" 32 | try: 33 | ret_score, _ = verify_func([ground_truth_boxed], [model_output]) 34 | except Exception: 35 | pass 36 | except TimeoutException: 37 | ret_score = timeout_score 38 | 39 | return ret_score 40 | -------------------------------------------------------------------------------- /rstar2_agent/down_sample/reject_sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | import numpy as np 5 | import torch 6 | 7 | from verl.protocol import DataProto 8 | from .utils import filter_by_mask 9 | 10 | 11 | def reject_equal_reward(batch: DataProto, do_sample=True, world_size=None): 12 | # Rejection sampling based on rewards 13 | # Group rewards by uid 14 | uids = batch.non_tensor_batch['uid'] 15 | unique_uids = np.unique(uids) 16 | valid_mask = torch.ones(len(uids), dtype=torch.bool) 17 | solve_equal = 0 18 | solve_equal_zeros = 0 19 | solve_equal_non_all_zeros = 0 20 | 21 | for uid in unique_uids: 22 | uid_mask = uids == uid 23 | # Sum rewards for each sequence 24 | uid_rewards = batch.batch['token_level_scores'][uid_mask].sum(-1) 25 | 26 | if torch.allclose(uid_rewards[0], uid_rewards): 27 | valid_mask[uid_mask] = False 28 | solve_equal += 1 29 | 30 | if torch.allclose(torch.zeros_like(uid_rewards), uid_rewards): 31 | solve_equal_zeros += 1 32 | else: 33 | solve_equal_non_all_zeros += 1 34 | 35 | metrics = {} 36 | metrics['reject_equal_reward/solve_non_equal_total'] = len(unique_uids) - solve_equal 37 | metrics['reject_equal_reward/solve_equal_total'] = solve_equal 38 | metrics['reject_equal_reward/solve_equal_total_ratio'] = solve_equal / len(unique_uids) if len(unique_uids) > 0 else 0 39 | metrics['reject_equal_reward/solve_equal_zeros'] = solve_equal_zeros 40 | metrics['reject_equal_reward/solve_equal_non_all_zeros'] = solve_equal_non_all_zeros 41 | 42 | if do_sample: 43 | if not valid_mask.any(): 44 | return None, metrics 45 | batch = filter_by_mask(batch, valid_mask, world_size) 46 | return batch, metrics 47 | -------------------------------------------------------------------------------- /data_preprocess/math500_rstar2_agent_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | """ 5 | Preprocess the MATH 500 dataset to parquet format 6 | """ 7 | 8 | import argparse 9 | import os 10 | 11 | import datasets 12 | 13 | from verl.utils.hdfs_io import copy, makedirs 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--local_dir", default="~/data/rstar2-agent/math500") 19 | parser.add_argument("--hdfs_dir", default=None) 20 | 21 | args = parser.parse_args() 22 | 23 | data_source = "HuggingFaceH4/MATH-500" 24 | dataset = datasets.load_dataset(data_source, "default") 25 | 26 | train_dataset = dataset["test"] 27 | 28 | # add a row to each data item that represents a unique id 29 | def make_map_fn(split): 30 | def process_fn(example, idx): 31 | question = example.pop("problem") 32 | solution = example.pop("answer") 33 | 34 | data = { 35 | "data_source": f"rstar_{data_source}", 36 | "agent_name": "rstar2_agent", 37 | "prompt": [ 38 | { 39 | "role": "user", 40 | "content": question, 41 | }, 42 | ], 43 | "ability": "math", 44 | "reward_model": {"style": "rule", "ground_truth": solution}, 45 | "extra_info": { 46 | "split": split, 47 | "index": idx, 48 | "answer": solution, 49 | "question": question, 50 | "need_tools_kwargs": False, 51 | "interaction_kwargs": { 52 | "query": question, 53 | "ground_truth": solution, 54 | }, 55 | }, 56 | } 57 | return data 58 | 59 | return process_fn 60 | 61 | train_dataset = train_dataset.map(function=make_map_fn("test"), with_indices=True) 62 | 63 | local_dir = args.local_dir 64 | hdfs_dir = args.hdfs_dir 65 | 66 | train_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) 67 | 68 | if hdfs_dir is not None: 69 | makedirs(hdfs_dir) 70 | copy(src=local_dir, dst=hdfs_dir) 71 | -------------------------------------------------------------------------------- /data_preprocess/aime2024_rstar2_agent_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # Licensed under the MIT license. 3 | 4 | """ 5 | Preprocess the AIME 2024 dataset to parquet format 6 | """ 7 | 8 | import argparse 9 | import os 10 | 11 | import datasets 12 | 13 | from verl.utils.hdfs_io import copy, makedirs 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--local_dir", default="~/data/rstar2-agent/aime2024") 19 | parser.add_argument("--hdfs_dir", default=None) 20 | 21 | args = parser.parse_args() 22 | 23 | data_source = "HuggingFaceH4/aime_2024" 24 | dataset = datasets.load_dataset(data_source, "default") 25 | 26 | train_dataset = dataset["train"] 27 | 28 | # add a row to each data item that represents a unique id 29 | def make_map_fn(split): 30 | def process_fn(example, idx): 31 | question = example.pop("problem") 32 | solution = example.pop("answer") 33 | 34 | data = { 35 | "data_source": f"rstar_{data_source}", 36 | "agent_name": "rstar2_agent", 37 | "prompt": [ 38 | { 39 | "role": "user", 40 | "content": question, 41 | }, 42 | ], 43 | "ability": "math", 44 | "reward_model": {"style": "rule", "ground_truth": solution}, 45 | "extra_info": { 46 | "split": split, 47 | "index": idx, 48 | "answer": solution, 49 | "question": question, 50 | "need_tools_kwargs": False, 51 | "interaction_kwargs": { 52 | "query": question, 53 | "ground_truth": solution, 54 | }, 55 | }, 56 | } 57 | return data 58 | 59 | return process_fn 60 | 61 | train_dataset = train_dataset.map(function=make_map_fn("test"), with_indices=True) 62 | 63 | local_dir = args.local_dir 64 | hdfs_dir = args.hdfs_dir 65 | 66 | train_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) 67 | 68 | if hdfs_dir is not None: 69 | makedirs(hdfs_dir) 70 | copy(src=local_dir, dst=hdfs_dir) 71 | -------------------------------------------------------------------------------- /data_preprocess/dapo_rstar2_agent_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft 
Corporation. 2 | # Licensed under the MIT license. 3 | 4 | """ 5 | Preprocess the DAPO-Math-17k dataset to parquet format 6 | """ 7 | 8 | import argparse 9 | import os 10 | 11 | import datasets 12 | 13 | from verl.utils.hdfs_io import copy, makedirs 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--local_dir", default="~/data/rstar2-agent/dapo-math-17k-en") 19 | parser.add_argument("--hdfs_dir", default=None) 20 | 21 | args = parser.parse_args() 22 | 23 | data_source = "open-r1/DAPO-Math-17k-Processed" 24 | dataset = datasets.load_dataset(data_source, "en") 25 | 26 | train_dataset = dataset["train"] 27 | 28 | # add a row to each data item that represents a unique id 29 | def make_map_fn(split): 30 | def process_fn(example, idx): 31 | question = example.pop("prompt") 32 | solution = example["reward_model"]["ground_truth"] 33 | 34 | data = { 35 | "data_source": f"rstar_{data_source}", 36 | "agent_name": "rstar2_agent", 37 | "prompt": [ 38 | { 39 | "role": "user", 40 | "content": question, 41 | }, 42 | ], 43 | "ability": "math", 44 | "reward_model": {"style": "rule", "ground_truth": solution}, 45 | "extra_info": { 46 | "split": split, 47 | "index": idx, 48 | "answer": solution, 49 | "question": question, 50 | "need_tools_kwargs": False, 51 | "interaction_kwargs": { 52 | "query": question, 53 | "ground_truth": solution, 54 | }, 55 | }, 56 | } 57 | return data 58 | 59 | return process_fn 60 | 61 | train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) 62 | 63 | local_dir = args.local_dir 64 | hdfs_dir = args.hdfs_dir 65 | 66 | train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) 67 | 68 | if hdfs_dir is not None: 69 | makedirs(hdfs_dir) 70 | copy(src=local_dir, dst=hdfs_dir) 71 | -------------------------------------------------------------------------------- /data_preprocess/aime2025_rstar2_agent_loop.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | """ 5 | Preprocess the AIME 2025 dataset to parquet format 6 | """ 7 | 8 | import argparse 9 | import os 10 | 11 | import datasets 12 | 13 | from verl.utils.hdfs_io import copy, makedirs 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--local_dir", default="~/data/rstar2-agent/aime2025") 19 | parser.add_argument("--hdfs_dir", default=None) 20 | 21 | args = parser.parse_args() 22 | 23 | data_source = "opencompass/AIME2025" 24 | dataset1 = datasets.load_dataset(data_source, "AIME2025-I")["test"] 25 | dataset2 = datasets.load_dataset(data_source, "AIME2025-II")["test"] 26 | 27 | train_dataset = datasets.concatenate_datasets([dataset1, dataset2]) 28 | 29 | # add a row to each data item that represents a unique id 30 | def make_map_fn(split): 31 | def process_fn(example, idx): 32 | question = example.pop("question") 33 | solution = example.pop("answer") 34 | 35 | data = { 36 | "data_source": f"rstar_{data_source}", 37 | "agent_name": "rstar2_agent", 38 | "prompt": [ 39 | { 40 | "role": "user", 41 | "content": question, 42 | }, 43 | ], 44 | "ability": "math", 45 | "reward_model": {"style": "rule", "ground_truth": solution}, 46 | "extra_info": { 47 | "split": split, 48 | "index": idx, 49 | "answer": solution, 50 | "question": question, 51 | "need_tools_kwargs": False, 52 | "interaction_kwargs": { 53 | "query": question, 54 | "ground_truth": solution, 55 | }, 56 | }, 57 | } 58 | return data 59 | 60 | return process_fn 61 | 62 | train_dataset = train_dataset.map(function=make_map_fn("test"), with_indices=True) 63 | 64 | local_dir = args.local_dir 65 | hdfs_dir = args.hdfs_dir 66 | 67 | train_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) 68 | 69 | if hdfs_dir is not None: 70 | makedirs(hdfs_dir) 71 | copy(src=local_dir, dst=hdfs_dir) 
72 | -------------------------------------------------------------------------------- /rstar2_agent/tools/tool_parser.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import asyncio 5 | import json 6 | import logging 7 | import os 8 | 9 | import regex as re 10 | 11 | from verl.experimental.agent_loop.tool_parser import ToolParser, FunctionCall 12 | from verl.utils.rollout_trace import rollout_trace_op 13 | 14 | logger = logging.getLogger(__file__) 15 | logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) 16 | 17 | 18 | @ToolParser.register("rstar2_agent_hermes") 19 | class RStar2AgentHermesToolParser(ToolParser): 20 | """Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py""" 21 | 22 | def __init__(self, tokenizer) -> None: 23 | super().__init__(tokenizer) 24 | 25 | self.tool_call_start_token: str = "" 26 | self.tool_call_end_token: str = "" 27 | self.tool_call_regex = re.compile(r"(.*?)", re.DOTALL) 28 | 29 | @rollout_trace_op 30 | async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]: 31 | loop = asyncio.get_running_loop() 32 | text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids) 33 | if self.tool_call_start_token not in text or self.tool_call_end_token not in text: 34 | return text, [] 35 | 36 | matches = self.tool_call_regex.findall(text) 37 | function_calls = [] 38 | for match in matches: 39 | try: 40 | function_call = json.loads(match) 41 | name, arguments = function_call["name"], function_call["arguments"] 42 | function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False))) 43 | except Exception as e: 44 | logger.error(f"Failed to decode tool call: {e}") 45 | ################################### rStar ################################### 46 | # keep the error msg as 
tool response 47 | from verl.tools.schemas import ToolResponse 48 | function_calls.append(ToolResponse(text=f"Failed to decode tool call: {e}")) 49 | ############################################################################# 50 | 51 | # remaing text exclude tool call tokens 52 | content = self.tool_call_regex.sub("", text) 53 | 54 | return content, function_calls 55 | -------------------------------------------------------------------------------- /rstar2_agent/down_sample/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | from verl.protocol import DataProto, DataProtoItem 6 | 7 | 8 | def filter_by_mask(batch: DataProto, mask: torch.Tensor, num_trainer_replicas: int) -> DataProto: 9 | # Filter batch to keep only valid samples 10 | batch = batch[mask] 11 | # Round down to the nearest multiple of world size 12 | max_batch_size = (batch.batch['input_ids'].shape[0] // num_trainer_replicas) * num_trainer_replicas 13 | if not max_batch_size: 14 | # give up, you got everything either all wrong or right. 15 | return None 16 | 17 | size_mask = torch.zeros(batch.batch['input_ids'].shape[0], dtype=torch.bool) 18 | size_mask[:max_batch_size] = True 19 | batch = batch[size_mask] 20 | return batch 21 | 22 | 23 | def decode_prompt_response_str(data: DataProto, tokenizer) -> tuple[list[str], list[str]]: 24 | """ 25 | Decode the prompt and response strings from a DataProto object using the provided tokenizer. 26 | 27 | Args: 28 | data (DataProto): The DataProto object containing the data. 29 | tokenizer: The tokenizer to decode the IDs into strings. 30 | 31 | Returns: 32 | tuple[list[str], list[str]]: A tuple containing two lists: 33 | - List of decoded prompt strings. 34 | - List of decoded response strings. 
35 | """ 36 | prompts = [] 37 | responses = [] 38 | 39 | for item in data: 40 | # Decode prompt IDs 41 | if "prompt_text" in item.non_tensor_batch and item.non_tensor_batch['prompt_text'] is not None: 42 | prompt_str = item.non_tensor_batch['prompt_text'] 43 | else: 44 | prompt_ids = item.batch['prompts'] 45 | valid_prompt_length = item.batch['attention_mask'][:prompt_ids.shape[-1]].sum() 46 | prompt_str = tokenizer.decode(prompt_ids[-valid_prompt_length:], skip_special_tokens=False) 47 | prompts.append(prompt_str) 48 | 49 | # Decode response IDs 50 | if "response_text" in item.non_tensor_batch and item.non_tensor_batch['response_text'] is not None: 51 | response_str = item.non_tensor_batch['response_text'] 52 | else: 53 | response_ids = item.batch['responses'] 54 | valid_response_length = item.batch['attention_mask'][item.batch['prompts'].shape[-1]:].sum() 55 | response_str = tokenizer.decode(response_ids[:valid_response_length], skip_special_tokens=False) 56 | responses.append(response_str) 57 | 58 | return prompts, responses 59 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 
8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 
36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /examples/math500_eval.sh: --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# Validation-only run (trainer.val_only=True) on MATH-500.
# Usage: MODEL_PATH=/path/to/model bash examples/math500_eval.sh [extra hydra overrides...]

set -x
# Without pipefail the `| tee` below would make the script succeed even when python fails.
set -o pipefail

ulimit -n 65535

SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd -P)
PROJECT_DIR="$SCRIPT_DIR/.."
CONFIG_PATH="$PROJECT_DIR/rstar2_agent/config"
PROJECT_NAME="rstar2-agent"
EXPERIMENT_NAME="eval-rstar2-agent-math500"
# This script does not define MODEL_PATH; abort instead of passing an empty model path.
MODEL_PATH="${MODEL_PATH:?MODEL_PATH must be set to the model to evaluate}"

python3 -m rstar2_agent.main_rstar2_agent \
    --config-path="$CONFIG_PATH" \
    --config-name='rstar2_agent_trainer' \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=128 \
    data.max_prompt_length=2048 \
    data.max_response_length=30720 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
    actor_rollout_ref.model.path="$MODEL_PATH" \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=20 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=20480 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.clip_ratio_low=0.2 \
    actor_rollout_ref.actor.clip_ratio_high=0.28 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.n=32 \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.trace.backend=weave \
    actor_rollout_ref.rollout.trace.token2text=True \
    actor_rollout_ref.rollout.agent.num_workers=1 \
    algorithm.use_kl_in_reward=False \
    augmentation.do_down_sampling=True \
    augmentation.down_sampling_config.reject_equal_reward=True \
    augmentation.down_sampling_config.roc_error_ratio=True \
    augmentation.down_sampling_config.roc_answer_format=True \
    augmentation.down_sampling_config.min_zero_reward_trace_num=2 \
    augmentation.down_sampling_config.min_non_zero_reward_trace_num=2 \
    augmentation.down_sampling_config.down_sample_to_n=16 \
    reward_model.reward_manager=code_judge \
    trainer.critic_warmup=0 \
    trainer.logger='["console", "wandb"]' \
    trainer.project_name="$PROJECT_NAME" \
    trainer.experiment_name="$EXPERIMENT_NAME" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_training_steps=200 \
    trainer.val_only=True \
    data.train_files="['$HOME/data/rstar2-agent/dapo-math-17k-en/train.parquet']" \
    data.val_files="['$HOME/data/rstar2-agent/math500/test.parquet']" \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/rstar2_agent/config/tool_config/python_tool_config.yaml" \
    trainer.total_epochs=15 "$@" 2>&1 | tee "$PROJECT_NAME-$EXPERIMENT_NAME.log"
-------------------------------------------------------------------------------- /examples/run_qwen3-14b_rstar2_agent_weave.sh: --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# GRPO training of Qwen3-14B-Base with the python tool, validating on AIME 2024.
# Usage: bash examples/run_qwen3-14b_rstar2_agent_weave.sh [extra hydra overrides...]

set -x
# Without pipefail the `| tee` below would make the script succeed even when python fails.
set -o pipefail

ulimit -n 65535

SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd -P)
PROJECT_DIR="$SCRIPT_DIR/.."
CONFIG_PATH="$PROJECT_DIR/rstar2_agent/config"
PROJECT_NAME="rstar2-agent"
EXPERIMENT_NAME="qwen3-14b-sgl-tool-agent-verify-n32"
MODEL_PATH="$HOME/models/Qwen3-14B-Base"

python3 -m rstar2_agent.main_rstar2_agent \
    --config-path="$CONFIG_PATH" \
    --config-name='rstar2_agent_trainer' \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=128 \
    data.max_prompt_length=2048 \
    data.max_response_length=8192 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
    actor_rollout_ref.model.path="$MODEL_PATH" \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=20 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=20480 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.clip_ratio_low=0.2 \
    actor_rollout_ref.actor.clip_ratio_high=0.28 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.n=32 \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.trace.backend=weave \
    actor_rollout_ref.rollout.trace.token2text=True \
    actor_rollout_ref.rollout.agent.num_workers=1 \
    algorithm.use_kl_in_reward=False \
    augmentation.do_down_sampling=True \
    augmentation.down_sampling_config.reject_equal_reward=True \
    augmentation.down_sampling_config.roc_error_ratio=True \
    augmentation.down_sampling_config.roc_answer_format=True \
    augmentation.down_sampling_config.min_zero_reward_trace_num=2 \
    augmentation.down_sampling_config.min_non_zero_reward_trace_num=2 \
    augmentation.down_sampling_config.down_sample_to_n=16 \
    reward_model.reward_manager=code_judge \
    trainer.critic_warmup=0 \
    trainer.logger='["console", "wandb"]' \
    trainer.project_name="$PROJECT_NAME" \
    trainer.experiment_name="$EXPERIMENT_NAME" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_training_steps=200 \
    data.train_files="['$HOME/data/rstar2-agent/dapo-math-17k-en/train.parquet']" \
    data.val_files="['$HOME/data/rstar2-agent/aime2024/test.parquet']" \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/rstar2_agent/config/tool_config/python_tool_config.yaml" \
    trainer.total_epochs=15 "$@" 2>&1 | tee "$PROJECT_NAME-$EXPERIMENT_NAME.log"
-------------------------------------------------------------------------------- /examples/aime_eval.sh: --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# Validation-only run (trainer.val_only=True) on AIME 2024 + AIME 2025, 16 samples per problem.
# Usage: MODEL_PATH=/path/to/model bash examples/aime_eval.sh [extra hydra overrides...]

set -x
# Without pipefail the `| tee` below would make the script succeed even when python fails.
set -o pipefail

ulimit -n 65535

SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd -P)
PROJECT_DIR="$SCRIPT_DIR/.."
CONFIG_PATH="$PROJECT_DIR/rstar2_agent/config"
PROJECT_NAME="rstar2-agent"
EXPERIMENT_NAME="eval-rstar2-agent-aime"
# This script does not define MODEL_PATH; abort instead of passing an empty model path.
MODEL_PATH="${MODEL_PATH:?MODEL_PATH must be set to the model to evaluate}"

python3 -m rstar2_agent.main_rstar2_agent \
    --config-path="$CONFIG_PATH" \
    --config-name='rstar2_agent_trainer' \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=128 \
    data.max_prompt_length=2048 \
    data.max_response_length=30720 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
    actor_rollout_ref.model.path="$MODEL_PATH" \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=20 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=20480 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.clip_ratio_low=0.2 \
    actor_rollout_ref.actor.clip_ratio_high=0.28 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.n=32 \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.trace.backend=weave \
    actor_rollout_ref.rollout.trace.token2text=True \
    actor_rollout_ref.rollout.agent.num_workers=1 \
    algorithm.use_kl_in_reward=False \
    augmentation.do_down_sampling=True \
    augmentation.down_sampling_config.reject_equal_reward=True \
    augmentation.down_sampling_config.roc_error_ratio=True \
    augmentation.down_sampling_config.roc_answer_format=True \
    augmentation.down_sampling_config.min_zero_reward_trace_num=2 \
    augmentation.down_sampling_config.min_non_zero_reward_trace_num=2 \
    augmentation.down_sampling_config.down_sample_to_n=16 \
    reward_model.reward_manager=code_judge \
    trainer.critic_warmup=0 \
    trainer.logger='["console", "wandb"]' \
    trainer.project_name="$PROJECT_NAME" \
    trainer.experiment_name="$EXPERIMENT_NAME" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_training_steps=200 \
    trainer.val_only=True \
    actor_rollout_ref.rollout.val_kwargs.n=16 \
    data.train_files="['$HOME/data/rstar2-agent/dapo-math-17k-en/train.parquet']" \
    data.val_files="['$HOME/data/rstar2-agent/aime2024/test.parquet', '$HOME/data/rstar2-agent/aime2025/test.parquet']" \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/rstar2_agent/config/tool_config/python_tool_config.yaml" \
    trainer.total_epochs=15 "$@" 2>&1 | tee "$PROJECT_NAME-$EXPERIMENT_NAME.log"
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject
date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | outputs/ 165 | wandb/ 166 | -------------------------------------------------------------------------------- /rstar2_agent/config/rstar2_agent_trainer.yaml: -------------------------------------------------------------------------------- 1 | # the rstar config will override default ppo_trainer.yaml 2 | 3 | hydra: 4 | searchpath: 5 | - file://verl/verl/trainer/config 6 | 7 | defaults: 8 | - ppo_trainer 9 | - _self_ 10 | 11 | data: 12 | return_raw_chat: True 13 | 14 | actor_rollout_ref: 15 | hybrid_engine: True 16 | model: 17 | custom_chat_template: "\n{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'A conversation between User and Assistant. The User asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the User with the answer. The reasoning process is enclosed within and answer is enclosed within tags, respectively, i.e., reasoning process here answer here .' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nA conversation between User and Assistant. The User asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the User with the answer. 
The reasoning process is enclosed within and answer is enclosed within tags, respectively, i.e., reasoning process here answer here .<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + 'You must put your answer inside tags, i.e., answer here . And your final answer will be extracted automatically by the \\\\boxed{} tag.\\nThis is the problem:\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n" 18 | rollout: 19 | name: sglang 20 | multi_turn: 21 | enable: True 22 | max_assistant_turns: 5 23 | tool_config_path: "./config/tool_config/python_tool_config.yaml" 24 | # "rstar2_agent_hermes" additionally takes the toolcall parse error as tool response. 
25 | format: rstar2_agent_hermes 26 | agent: 27 | num_workers: 1 28 | agent_loop_config_path: rstar2_agent/config/rstar2_agent_loop.yaml 29 | 30 | augmentation: 31 | do_down_sampling: False 32 | down_sampling_config: 33 | reject_equal_reward: False 34 | roc_error_ratio: False 35 | roc_answer_format: False 36 | min_zero_reward_trace_num: -1 37 | min_non_zero_reward_trace_num: -1 38 | down_sample_to_n: -1 39 | -------------------------------------------------------------------------------- /rstar2_agent/reward/server.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from collections import defaultdict 5 | from typing import Any 6 | 7 | import aiohttp 8 | import asyncio 9 | import base64 10 | import re 11 | import torch 12 | 13 | from verl import DataProto 14 | from verl.workers.reward_manager import register 15 | from verl.workers.reward_manager.abstract import AbstractRewardManager 16 | 17 | from rstar2_agent.tools.code_judge_utils import run_tool_calls_on_server_async 18 | 19 | verify_math_prefix = """ 20 | from fused_compute_score import compute_score 21 | import base64 22 | solution_str = base64.b64decode("{}".encode()).decode() 23 | ground_truth = base64.b64decode("{}".encode()).decode() 24 | result = compute_score(solution_str, ground_truth) 25 | """ 26 | 27 | verify_math_suffix = """ 28 | print(f"{result}") 29 | """ 30 | 31 | 32 | @register("code_judge") 33 | class CodeJudgeRewardManager(AbstractRewardManager): 34 | """The reward manager.""" 35 | 36 | def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None: 37 | """ 38 | Initialize the CodeJudgeRewardManager instance. 39 | 40 | Note that num_examine, compute_score, reward_fn_key is not used in this implementation. 
41 | """ 42 | self.tokenizer = tokenizer # Store the tokenizer for decoding token IDs 43 | 44 | def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]: 45 | """We will expand this function gradually based on the available datasets""" 46 | 47 | # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn 48 | if "rm_scores" in data.batch.keys(): 49 | if return_dict: 50 | return {"reward_tensor": data.batch["rm_scores"]} 51 | else: 52 | return data.batch["rm_scores"] 53 | 54 | reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) 55 | reward_extra_info = defaultdict(list) 56 | 57 | for i in range(0, len(data), 64): 58 | batch_data = data[i : i + 64] 59 | tool_calls = [] 60 | for j in range(len(batch_data)): 61 | data_item = batch_data[j] # DataProtoItem 62 | 63 | if "response_text" in data_item.non_tensor_batch and data_item.non_tensor_batch["response_text"] is not None: 64 | response_str = data_item.non_tensor_batch["response_text"] 65 | else: 66 | response_ids = data_item.batch["responses"] 67 | prompt_length = data_item.batch["prompts"].shape[-1] 68 | valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() 69 | valid_response_ids = response_ids[:valid_response_length] 70 | response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) 71 | 72 | ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] 73 | tool_calls.append(self.create_tool_call(response_str, ground_truth)) 74 | 75 | results = self.execute_tool_calls(tool_calls) 76 | for j in range(len(results)): 77 | prompt_length = batch_data[j].batch["prompts"].shape[-1] 78 | valid_response_length = batch_data[j].batch["attention_mask"][prompt_length:].sum() 79 | reward_tensor[i + j, valid_response_length - 1] = results[j] 80 | 81 | if return_dict: 82 | return { 83 | "reward_tensor": reward_tensor, 84 | "reward_extra_info": reward_extra_info, 85 | } 86 | 
else: 87 | return reward_tensor 88 | 89 | def create_tool_call(self, solution_str: str, ground_truth: str): 90 | ground_truth = str(ground_truth) 91 | a = base64.b64encode(solution_str.encode()).decode() 92 | b = base64.b64encode(ground_truth.encode()).decode() 93 | code = verify_math_prefix.format(a, b) + verify_math_suffix 94 | return { 95 | "name": "compute_score", 96 | "arguments": { 97 | "code": code 98 | } 99 | } 100 | 101 | def extract_tool_call_result(self, result: str): 102 | if result is None: 103 | return 0.0 104 | match = re.search(r'(.*?)', result) 105 | return float(match.group(1)) if match else 0.0 106 | 107 | def execute_tool_calls(self, tool_calls): 108 | async def run_tool_calls(tool_calls): 109 | tool_connector = aiohttp.TCPConnector(limit=32, force_close=True, enable_cleanup_closed=True) 110 | tool_timeout = aiohttp.ClientTimeout(total=60) 111 | tool_session = aiohttp.ClientSession(connector=tool_connector, timeout=tool_timeout) 112 | responses = await run_tool_calls_on_server_async( 113 | tool_calls=tool_calls, 114 | session=tool_session, 115 | generate_tool_call_code=lambda x: x["arguments"]["code"], 116 | generate_tool_call_input=lambda x: None, 117 | ) 118 | await tool_session.close() 119 | return responses 120 | 121 | results = asyncio.run(run_tool_calls(tool_calls)) 122 | return [self.extract_tool_call_result(result) for result in results] 123 | -------------------------------------------------------------------------------- /rstar2_agent/tools/code_judge_tool.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | from functools import partial 5 | from typing import Any, Optional 6 | from uuid import uuid4 7 | 8 | import aiohttp 9 | import json 10 | 11 | from verl.utils.rollout_trace import rollout_trace_op 12 | from verl.tools.base_tool import BaseTool 13 | from verl.tools.schemas import OpenAIFunctionToolSchema, ToolResponse 14 | 15 | from .request_processor import RequestProcessor 16 | from .code_judge_utils import run_tool_calls_on_server_async, generate_tool_call_code, generate_tool_call_input 17 | 18 | 19 | class CodeJudgeTool(BaseTool): 20 | def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): 21 | super().__init__(config, tool_schema) 22 | self._instance_dict = {} 23 | 24 | host_addr = self.config.get("host_addr", "localhost") 25 | host_port = self.config.get("host_port", "8088") 26 | run_jupyter_tool_calls_on_server_async = partial( 27 | run_tool_calls_on_server_async, 28 | generate_tool_call_code=generate_tool_call_code, 29 | generate_tool_call_input=generate_tool_call_input, 30 | host_addr=host_addr, 31 | host_port=host_port, 32 | ) 33 | request_processor_batch_size = self.config.get("request_processor_batch_size", 1) 34 | request_processor_concurrency = self.config.get("request_processor_concurrency", 1) 35 | request_processor_batch_timeout_seconds = self.config.get("request_processor_batch_timeout_seconds", 30) 36 | tool_connector = aiohttp.TCPConnector(limit=request_processor_concurrency, force_close=True, enable_cleanup_closed=True) 37 | tool_timeout = aiohttp.ClientTimeout(total=60) 38 | tool_session = aiohttp.ClientSession(connector=tool_connector, timeout=tool_timeout) 39 | self.request_processor = RequestProcessor( 40 | batch_size=request_processor_batch_size, 41 | batch_timeout_seconds=request_processor_batch_timeout_seconds, 42 | session=tool_session, 43 | concurrency=request_processor_concurrency, 44 | batch_submit_func=run_jupyter_tool_calls_on_server_async, 45 | ) 46 | 47 | def get_openai_tool_schema(self) -> 
OpenAIFunctionToolSchema: 48 | return self.tool_schema 49 | 50 | async def _start_request_processor(self): 51 | if not self.request_processor._running: 52 | await self.request_processor.start() 53 | 54 | async def calc_reward(self, instance_id: str, **kwargs) -> str: 55 | return self._instance_dict[instance_id]["reward"] 56 | 57 | async def release(self, instance_id: str, **kwargs) -> None: 58 | del self._instance_dict[instance_id] 59 | 60 | 61 | class SimJupyterTool(CodeJudgeTool): 62 | async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]: 63 | if instance_id is None: 64 | instance_id = str(uuid4()) 65 | assert "history_tool_calls" in kwargs, "history_tool_calls must be provided in kwargs" 66 | await self._start_request_processor() 67 | history_tool_calls = [] 68 | for history_tool_call in kwargs["history_tool_calls"]: 69 | if history_tool_call.name == "jupyter_code": 70 | try: 71 | arguments = json.loads(history_tool_call.arguments) 72 | assert len(arguments) == 1 and "code" in arguments 73 | history_tool_calls.append({ 74 | "name": "jupyter_code", 75 | "arguments": { 76 | "code": arguments["code"], 77 | } 78 | }) 79 | except Exception as e: 80 | pass 81 | 82 | self._instance_dict[instance_id] = { 83 | "response": "", 84 | "reward": [], 85 | "history_tool_calls": history_tool_calls, 86 | } 87 | return instance_id, ToolResponse() 88 | 89 | @rollout_trace_op 90 | async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]: 91 | code = parameters.get("code", "") 92 | tool_call = { 93 | "name": "jupyter_code", 94 | "arguments": { 95 | "code": code, 96 | }, 97 | "history_tool_calls": self._instance_dict[instance_id]["history_tool_calls"] 98 | } 99 | result_text = await self.request_processor.send_request(tool_call) 100 | return ToolResponse(text=result_text), 0.0, {} 101 | 102 | 103 | class PythonTool(CodeJudgeTool): 104 | async def create(self, instance_id: 
Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]: 105 | if instance_id is None: 106 | instance_id = str(uuid4()) 107 | await self._start_request_processor() 108 | 109 | self._instance_dict[instance_id] = { 110 | "response": "", 111 | "reward": [], 112 | } 113 | return instance_id, ToolResponse() 114 | 115 | @rollout_trace_op 116 | async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]: 117 | code = parameters.get("code", "") 118 | input = parameters.get("input", "") 119 | tool_call = { 120 | "name": "python_code_with_standard_io", 121 | "arguments": { 122 | "code": code, 123 | "input": input, 124 | }, 125 | } 126 | result_text = await self.request_processor.send_request(tool_call) 127 | return ToolResponse(text=result_text), 0.0, {} 128 | -------------------------------------------------------------------------------- /examples/chat_with_tool_call.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
r"""
# Launch VLLM server

vllm serve /path/to/your/model \
    --host 0.0.0.0 \
    --port 8000 \
    --enable-auto-tool-choice \
    --tool-call-parser hermes

# Check if the server is running well

curl http://localhost:8000/v1/models
"""
import argparse
import asyncio
import json
from pathlib import Path

import aiohttp
import requests
import yaml
from transformers import AutoTokenizer, PreTrainedTokenizer
from verl.tools.schemas import ToolResponse

from rstar2_agent.tools.code_judge_utils import (
    run_tool_calls_on_server_async,
    generate_tool_call_code,
    generate_tool_call_input,
)
from rstar2_agent.tools.tool_parser import (
    RStar2AgentHermesToolParser,
)


async def run_tool_calls(tool_calls):
    """Execute tool calls on the Code Judge server.

    Args:
        tool_calls: list of ``{"name": str, "arguments": dict}`` entries.

    Returns:
        Whatever ``run_tool_calls_on_server_async`` returns. The caller zips it with
        the pending call positions, so presumably one response per call, in order.
    """
    tool_connector = aiohttp.TCPConnector(limit=32, force_close=True, enable_cleanup_closed=True)
    tool_timeout = aiohttp.ClientTimeout(total=60)
    # ``async with`` closes the session even when a tool call raises; the previous
    # explicit ``close()`` was skipped on error and leaked the connector.
    async with aiohttp.ClientSession(connector=tool_connector, timeout=tool_timeout) as tool_session:
        return await run_tool_calls_on_server_async(
            tool_calls=tool_calls,
            session=tool_session,
            generate_tool_call_code=generate_tool_call_code,
            generate_tool_call_input=generate_tool_call_input,
        )


def _generate(url, model, input_ids, max_tokens):
    """Send one completion request (token-id prompt) to the vLLM server and return the text.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
    """
    payload = {
        "model": model,
        "prompt": input_ids,
        "temperature": 1.0,
        "max_tokens": max_tokens,
        "skip_special_tokens": False,
        "include_stop_str_in_output": True,
    }
    response = requests.post(url, json=payload)
    # Fail with the HTTP error instead of an opaque KeyError on "choices".
    response.raise_for_status()
    return response.json()["choices"][0]["text"]


def _extract_tool_calls(tool_parser, tokenizer, text):
    """Re-tokenize ``text`` and return the tool calls the parser finds in it."""
    _, tool_calls = asyncio.run(tool_parser.extract_tool_calls(responses_ids=tokenizer.encode(text)))
    return tool_calls


def _execute_tool_calls(tool_calls):
    """Return one tool-response entry per parsed tool call, in the same order.

    Entries the parser already returned as ``ToolResponse`` objects contribute their
    ``text`` directly. All other calls are sent to the Code Judge server in one batch,
    and their responses are written back at their original positions.
    """
    tool_responses, pending_calls, pending_pos = [], [], []
    for i, tool_call in enumerate(tool_calls):
        if isinstance(tool_call, ToolResponse):
            tool_responses.append(tool_call.text)
        else:
            tool_responses.append(None)
            pending_pos.append(i)
            pending_calls.append(tool_call)

    if pending_calls:
        requests_payload = [
            {"name": tool_call.name, "arguments": json.loads(tool_call.arguments)}
            for tool_call in pending_calls
        ]
        results = asyncio.run(run_tool_calls(requests_payload))
        for i, tool_response in zip(pending_pos, results):
            tool_responses[i] = tool_response
    return tool_responses


def main():
    """Chat with the served model, executing its Python tool calls until it stops or the budget runs out."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model", default="/path/to/your/model")
    arg_parser.add_argument("--prompt", default="Solve the system of equations: 2x + 3y = 7, x - y = 1")
    # Total token budget for everything appended after the initial prompt
    # (model output, tool responses and chat-template tokens).
    arg_parser.add_argument("--max_tokens", type=int, default=8192)
    arg_parser.add_argument("--url", default="http://localhost:8000/v1/completions")
    args = arg_parser.parse_args()

    project_dir = Path(__file__).parent.parent
    python_tool_yaml_path = project_dir / "rstar2_agent/config/tool_config/python_tool_config.yaml"
    with python_tool_yaml_path.open() as file:
        python_tool_schema = yaml.safe_load(file)["tools"][0]["tool_schema"]

    tools = [python_tool_schema]
    budget = int(args.max_tokens)
    prompt = f"You must put your answer inside <answer> </answer> tags, i.e., <answer> answer here </answer>. And your final answer will be extracted automatically by the \\boxed{{}} tag.\nThis is the problem:\n{args.prompt}"

    messages = [{"role": "user", "content": prompt}]
    tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    tool_parser = RStar2AgentHermesToolParser(tokenizer)

    print(tokenizer.apply_chat_template(messages, tools=tools, tokenize=False, add_generation_prompt=True))

    input_ids = tokenizer.apply_chat_template(messages, tools=tools, tokenize=True, add_generation_prompt=True)
    prompt_len = len(input_ids)

    response = _generate(args.url, args.model, input_ids, budget)
    print(response)
    tool_calls = _extract_tool_calls(tool_parser, tokenizer, response)

    while tool_calls:
        tool_responses = _execute_tool_calls(tool_calls)

        # Drop the end-of-turn marker the server echoed back (include_stop_str_in_output);
        # presumably the chat template adds its own when rendering the assistant turn.
        if response.endswith("<|im_end|>"):
            response = response[: -len("<|im_end|>")]
        messages.append({"role": "assistant", "content": response})

        prefix_text = tokenizer.apply_chat_template(messages, tools=tools, tokenize=False, add_generation_prompt=False)
        for tool_response in tool_responses:
            messages.append({"role": "tool", "content": tool_response})
        entire_text = tokenizer.apply_chat_template(messages, tools=tools, tokenize=False, add_generation_prompt=True)
        # Show only what was appended after the assistant turn (tool turns + generation prompt).
        print(entire_text[len(prefix_text):])

        # Next turn: whatever remains of the budget after everything appended so far.
        input_ids = tokenizer.apply_chat_template(messages, tools=tools, tokenize=True, add_generation_prompt=True)
        used_tokens = len(input_ids) - prompt_len
        if used_tokens >= budget:
            print(f"[Generation end: reached the maximum generation token number {used_tokens}]")
            break

        response = _generate(args.url, args.model, input_ids, budget - used_tokens)
        print(response)
        tool_calls = _extract_tool_calls(tool_parser, tokenizer, response)


if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /fused_compute_score/prime_math/math_normalize.py:
# --------------------------------------------------------------------------------
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) 2021 Dan Hendrycks
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This logic is largely copied from the Hendrycks' MATH release (math_equivalence).

From: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py
"""

import re
from typing import Optional

# Matches an answer wrapped entirely in ``\text{...}``. A raw string avoids the
# invalid ``\{`` escape, and the named group must be spelled ``(?P<text>...)``:
# a bare ``(?P...)`` does not compile.
_TEXT_WRAPPER_RE = re.compile(r"^\\text\{(?P<text>.+?)\}$")


def normalize_answer(answer: Optional[str]) -> Optional[str]:
    """Normalize a LaTeX answer string so equivalent answers compare equal.

    Strips surrounding whitespace and an enclosing ``\\text{...}`` wrapper, then
    applies ``_strip_string``. Normalization is best-effort: if any step raises,
    the stripped (and possibly unwrapped) answer is returned as is.

    Args:
        answer: the raw answer, or ``None``.

    Returns:
        The normalized answer, or ``None`` when ``answer`` is ``None``.
    """
    if answer is None:
        return None
    answer = answer.strip()
    try:
        # Remove enclosing `\text{}`.
        m = _TEXT_WRAPPER_RE.search(answer)
        if m is not None:
            answer = m.group("text").strip()
        return _strip_string(answer)
    except Exception:
        # Deliberate catch-all: malformed LaTeX falls back to the un-normalized answer.
        return answer


def _fix_fracs(string):
    """Brace shorthand fractions: ``\\frac12`` -> ``\\frac{1}{2}``, ``\\frac1{72}`` -> ``\\frac{1}{72}``.

    Returns the input unchanged when an unbraced ``\\frac`` is followed by fewer
    than two characters.
    """
    substrs = string.split("\\frac")
    new_str = substrs[0]
    for substr in substrs[1:]:
        new_str += "\\frac"
        # startswith() instead of substr[0]: a trailing bare \frac must not raise IndexError.
        if substr.startswith("{"):
            new_str += substr
            continue
        if len(substr) < 2:
            return string
        a, b = substr[0], substr[1]
        post_substr = substr[2:]
        if b != "{":
            new_str += "{" + a + "}{" + b + "}" + post_substr
        else:
            new_str += "{" + a + "}" + b + post_substr
    return new_str


def _fix_a_slash_b(string):
    """Rewrite a plain integer ratio ``a/b`` as ``\\frac{a}{b}``; return anything else unchanged."""
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        a = int(parts[0])
        b = int(parts[1])
    except ValueError:
        return string
    # Only rewrite canonically written integers (e.g. "01/2" stays as is). This is an
    # explicit check rather than an assert so it also holds under ``python -O``.
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"


def _remove_right_units(string):
    """Drop a trailing unit written as ``\\text{ ...}`` (note the space after the brace).

    Raises:
        ValueError: if ``\\text{ `` occurs more than once; ``normalize_answer`` then
            falls back to the un-normalized answer.
    """
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " not in string:
        return string
    splits = string.split("\\text{ ")
    if len(splits) != 2:
        raise ValueError(f"ambiguous unit suffix in {string!r}")
    return splits[0]


def _fix_sqrt(string):
    """Brace single-character square roots: ``\\sqrt3`` -> ``\\sqrt{3}``."""
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if split and split[0] != "{":
            new_substr = "\\sqrt{" + split[0] + "}" + split[1:]
        else:
            # Already braced, or a bare trailing \sqrt (which used to raise IndexError).
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string


def _strip_string(string):
    """Apply the MATH-style LaTeX normalization rules to ``string``."""
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage (the former second replace of "\%" was the same string)
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    parts = string.split("=")
    if len(parts) == 2 and len(parts[0]) <= 2:
        string = parts[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string

# --------------------------------------------------------------------------------
# /rstar2_agent/down_sample/roc.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import re
import numpy as np
import torch
from pprint import pprint
from typing import List
from transformers import PreTrainedTokenizerFast

from verl.protocol import DataProto
from .utils import filter_by_mask, decode_prompt_response_str


def resample_of_correct(batch: DataProto, tokenizer: PreTrainedTokenizerFast, config: dict, do_sample=True, world_size=None):
    """Down-sample rollouts per prompt group with the Resample-on-Correct (RoC) strategy.

    Every rollout's decoded response gets a penalty weight. The weight is the tool-call
    error ratio (when ``config["roc_error_ratio"]``) plus the answer-format penalty
    (when ``config["roc_answer_format"]``). Rollouts are grouped by
    ``batch.non_tensor_batch["uid"]``. When sampling is enabled, each group is handled
    as follows:

    * fewer than ``down_sample_to_n`` rollouts: the group is dropped;
    * exactly ``down_sample_to_n``: the group is kept whole;
    * more: positive-reward rollouts are kept lowest-penalty first, and non-positive-reward
      rollouts in batch order. The counts are proportional to each side's share of the
      group and raised to the ``min_*_trace_num`` floors.

    Args:
        batch: rollout batch. ``batch.batch["token_level_scores"]`` is summed over its
            last dimension to get each rollout's reward.
        tokenizer: used to decode the responses for penalty scoring.
        config: down-sampling settings (the keys read below).
        do_sample: when False, only metrics are computed and ``batch`` is returned unfiltered.
        world_size: forwarded to ``filter_by_mask``.

    Returns:
        ``(batch, metrics)``: the (possibly filtered) batch and the penalty metrics. Both
        penalties' metrics are reported even when the penalty itself is disabled.
    """
    roc_error_ratio = config["roc_error_ratio"]
    roc_answer_format = config["roc_answer_format"]
    min_zero_reward_trace_num = config["min_zero_reward_trace_num"]
    min_non_zero_reward_trace_num = config["min_non_zero_reward_trace_num"]
    down_sample_to_n = config["down_sample_to_n"]
    assert min_zero_reward_trace_num + min_non_zero_reward_trace_num <= down_sample_to_n, \
        f"Invalid down sampling configuration: {min_zero_reward_trace_num=}, {min_non_zero_reward_trace_num=}, {down_sample_to_n=}"

    _, response_text = decode_prompt_response_str(batch, tokenizer)
    penalty_weights = np.zeros(len(response_text))
    metrics = {}

    # calculate error ratio penalty weight (always computed for metrics)
    error_weights, error_metrics = calc_error_ratio_penalty_weights(response_text)
    metrics.update(error_metrics)
    if roc_error_ratio:
        penalty_weights += error_weights

    # calculate format penalty weight (always computed for metrics)
    format_weights, format_metrics = calc_format_penalty_weights(response_text)
    metrics.update(format_metrics)
    if roc_answer_format:
        penalty_weights += format_weights

    # sample by penalty weights
    if do_sample and down_sample_to_n > 0:
        uids = batch.non_tensor_batch['uid']
        valid_mask = torch.zeros(len(uids), dtype=torch.bool)

        for uid in np.unique(uids):
            indices = np.where(uids == uid)[0]
            if len(indices) < down_sample_to_n:
                continue  # Not enough samples for this uid: the whole group is dropped
            if len(indices) == down_sample_to_n:
                valid_mask[indices] = True
                continue
            uid_mask = uids == uid
            uid_rewards = batch.batch['token_level_scores'][uid_mask].sum(-1)
            uid_penalties = penalty_weights[uid_mask]

            zero_reward_pairs = [(indice, penalty_weight) for indice, uid_reward, penalty_weight in zip(indices, uid_rewards, uid_penalties) if uid_reward <= 0]
            non_zero_reward_pairs = [(indice, penalty_weight) for indice, uid_reward, penalty_weight in zip(indices, uid_rewards, uid_penalties) if uid_reward > 0]
            # Keep the cleanest correct traces: lowest penalty weight first.
            non_zero_reward_pairs.sort(key=lambda x: x[1])

            # Proportional split. The positive count is derived from the negative one so the
            # two always sum to down_sample_to_n. Rounding both shares independently overshoots
            # when both end in exactly .5 and round() (half-to-even) rounds both up, e.g.
            # 3 + 3 traces down-sampled to 3 gave 2 + 2 and tripped the assertion below.
            zero_reward_trace_num = round(len(zero_reward_pairs) * down_sample_to_n / len(indices))
            non_zero_reward_trace_num = down_sample_to_n - zero_reward_trace_num
            if zero_reward_trace_num < min_zero_reward_trace_num and non_zero_reward_trace_num < min_non_zero_reward_trace_num:
                pprint(f"Total trace number before down sampling: {len(indices)}, smaller than {min_zero_reward_trace_num=} + {min_non_zero_reward_trace_num=}")
                valid_mask[indices] = True
            else:
                # Raise each side to its floor (capped by what is available), giving the rest to the other side.
                if zero_reward_trace_num <= min(min_zero_reward_trace_num, len(zero_reward_pairs)):
                    zero_reward_trace_num = min(min_zero_reward_trace_num, len(zero_reward_pairs))
                    non_zero_reward_trace_num = down_sample_to_n - zero_reward_trace_num
                if non_zero_reward_trace_num <= min(min_non_zero_reward_trace_num, len(non_zero_reward_pairs)):
                    non_zero_reward_trace_num = min(min_non_zero_reward_trace_num, len(non_zero_reward_pairs))
                    zero_reward_trace_num = down_sample_to_n - non_zero_reward_trace_num
                choices = [pair[0] for pair in non_zero_reward_pairs[:non_zero_reward_trace_num]] \
                    + [pair[0] for pair in zero_reward_pairs[:zero_reward_trace_num]]
                assert len(choices) == down_sample_to_n, f"{down_sample_to_n=} != {len(choices)}"
                valid_mask[choices] = True

        batch = filter_by_mask(batch, valid_mask, world_size)
    return batch, metrics


def calc_error_ratio_penalty_weights(response_text: List[str]):
    """Penalize each response by the fraction of its tool responses that mention an error.

    A response's weight is (#``<tool_response>`` blocks containing "error",
    case-insensitive) / (#``<tool_response>`` blocks). It is 0.5 when the response
    has no tool response.

    Returns:
        ``(weights, metrics)``: a float array with one weight per response, plus the
        global error ratio and the mean weight.
    """
    def error_ratio(text, pattern=r'<tool_response>.*?</tool_response>'):
        matches = re.findall(pattern, text, re.DOTALL)
        if not matches:
            return 0.5, 0, 0
        error_count = sum('error' in match.lower() for match in matches)
        return error_count / len(matches), error_count, len(matches)

    penalty_weights = []
    total_error_count, total_res_count = 0, 0

    for text in response_text:
        penalty_weight, error_count, res_count = error_ratio(text)
        penalty_weights.append(penalty_weight)
        total_error_count += error_count
        total_res_count += res_count
    metrics = {
        'roc_error_ratio/global_err_ratio': total_error_count / total_res_count if total_res_count > 0 else 0,
        'roc_error_ratio/penalty_weight': np.mean(penalty_weights) if penalty_weights else 0,
    }
    return np.array(penalty_weights), metrics


def calc_format_penalty_weights(response_text: List[str]):
    """Penalize responses whose ``<answer>...</answer>`` usage is missing, unbalanced or repeated.

    A response's weight is 1.0 in any of these cases:
    * either tag is absent;
    * either tag's count differs from the number of closed ``<answer>...</answer>`` pairs;
    * no assistant turn matches the turn pattern.

    Otherwise the weight is ``min((pairs - 1) / assistant_turns, 1.0)``.

    Returns:
        ``(weights, metrics)``: a float array with one weight per response, plus answer-count statistics.
    """
    def answer_tag_repetition(text: str, answer_tags=("<answer>", "</answer>"), answer_pattern=r'<answer>.*?</answer>', turn_pattern=r'<\|im_start\|>assistant.*?<\|im_end\|>'):
        if any(ans_tag not in text for ans_tag in answer_tags):
            return 1.0, 0

        closed_ans_tag_count = len(re.findall(answer_pattern, text, re.DOTALL))
        if any(text.count(ans_tag) != closed_ans_tag_count for ans_tag in answer_tags):
            return 1.0, closed_ans_tag_count

        num_turns = len(re.findall(turn_pattern, text, re.DOTALL))
        if num_turns == 0:
            return 1.0, closed_ans_tag_count

        penalty_weight = min((closed_ans_tag_count - 1) / num_turns, 1.0)
        return penalty_weight, closed_ans_tag_count

    penalty_weights = []
    total_ans_count, zero_ans_count, one_ans_count, gt_one_ans_count = 0, 0, 0, 0
    for text in response_text:
        penalty_weight, ans_tag_count = answer_tag_repetition(text)
        penalty_weights.append(penalty_weight)
        total_ans_count += ans_tag_count
        zero_ans_count += (1 if ans_tag_count == 0 else 0)
        one_ans_count += (1 if ans_tag_count == 1 else 0)
        gt_one_ans_count += (1 if ans_tag_count > 1 else 0)

    metrics = {
        # Guarded so an empty batch does not raise ZeroDivisionError.
        'roc_answer_format/answer_per_rollout_mean': total_ans_count / len(response_text) if response_text else 0,
        'roc_answer_format/zero_answer_count': zero_ans_count,
        'roc_answer_format/one_answer_count': one_ans_count,
        'roc_answer_format/gt_one_answer_count': gt_one_ans_count,
    }
    return np.array(penalty_weights), metrics

# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------

2 |
3 | rStar2-Agent 4 |

5 | 6 |

7 | 📃 [Paper] 8 |

9 | 10 | Repo for "[rStar2-Agent: Agentic Reasoning Technical Report](https://huggingface.co/papers/2508.20722)". 11 | 12 | Authors: Ning Shang\*, Yifei Liu\*, Yi Zhu\*, Li Lyna Zhang\*†, Weijiang Xu, Xinyu Guan, Buze Zhang, Bingcheng Dong, Xudong Zhou, Bowen Zhang, Ying Xin, Ziming Miao, Scarlett Li, Fan Yang, Mao Yang† 13 | 14 |

15 | 16 |
17 | Figure 1: rStar2-Agent-14B reaches frontier-level math reasoning in just 510 RL training steps 18 |

19 | 20 | ## News 21 | 22 | - **[07/15/2025]** Our rStar-Coder [paper](https://arxiv.org/abs/2505.21297) and [dataset](https://huggingface.co/datasets/microsoft/rStar-Coder) are released. We introduce a large-scale, verified dataset of 418K competition-level code problems with **test cases** of varying difficulty, enabling small LLMs (1.5B-14B) to achieve frontier-level code reasoning performance. 23 | - **[02/10/2025]** We are hiring interns! If you are interested in improving LLM reasoning, please send your CV to lzhani@microsoft.com. 24 | - **[01/21/2025]** rStar-Math code has been open-sourced. 25 | - **[01/09/2025]** rStar-Math paper is released: https://huggingface.co/papers/2501.04519. 26 | 27 | Note: Our prior work [Mutual Reasoning Makes Smaller LLMs Stronger Problem-Solvers](https://huggingface.co/papers/2408.06195) is open-sourced on the [rStar-mutualreasoning](https://github.com/microsoft/rStar/tree/rStar-mutualreasoning) branch. 28 | 29 | Note: Our prior work [rStar-Math: Small LLMs Can Master Math Reasoning with Self-Evolved Deep Thinking](https://huggingface.co/papers/2501.04519) is open-sourced on the [rStar-math](https://github.com/microsoft/rStar/tree/rStar-math) branch. 30 | 31 | ## Contents 32 | - [Introduction](#Introduction) 33 | - [Try rStar2-Agent with Tool Calling](#Try-rStar2-Agent-with-Tool-Calling) 34 | - [Evaluation](#Evaluation) 35 | - [rStar2-Agent RL Training](#rStar2-Agent-RL-Training) 36 | - [Citation](#Citation) 37 | 38 | ## Introduction 39 | We introduce rStar2-Agent, a 14B math reasoning model that thinks smarter rather than merely longer, achieving performance comparable to 671B DeepSeek-R1 through pure agentic reinforcement learning. The model plans, reasons, and autonomously uses coding tools to efficiently explore, verify, and reflect for more complex problem-solving.
This capability relies on three key innovations: (i) GRPO-RoC, an effective agentic reinforcement learning algorithm with a novel Resample-on-Correct rollout strategy that optimizes coding tool usage and enables shorter, smarter reasoning by selectively retaining higher-quality positive trajectories while preserving all failure cases; (ii) a scalable and efficient RL infrastructure that supports high-throughput tool call execution and mitigates the high costs of agentic RL rollout, enabling efficient training on limited GPU resources (64 MI300X GPUs); (iii) an agent training recipe that starts with non-reasoning SFT and proceeds through multi-stage RL with concise maximum response lengths per stage and increasing dataset difficulty. To this end, rStar2-Agent boosts a pre-trained 14B model to state-of-the-art levels in only 510 RL steps within one week, achieving 80.6% and 69.8% average pass@1 on AIME24 and AIME25, surpassing DeepSeek-R1 (671B) with shorter responses. Beyond mathematics, rStar2-Agent-14B also demonstrates strong generalization to alignment, scientific reasoning, and agentic tool-use tasks. 40 | 41 | ## Try rStar2-Agent with Tool Calling 42 | 43 | ### Installation 44 | 45 | #### Option 1: Manual Installation 46 | 47 | ```bash 48 | # Initialize and update submodules 49 | git submodule init 50 | git submodule update 51 | 52 | # install verl 53 | pip install "torch<2.8" 54 | pip install -r verl/requirements_sglang.txt 55 | pip install -e verl 56 | 57 | # install code judge 58 | pip install -r code-judge/requirements.txt 59 | pip install -e code-judge 60 | 61 | # install rstar2_agent 62 | pip install -e . 63 | ``` 64 | 65 | #### Option 2: Automated Installation 66 | 67 | ```bash 68 | bash install.sh 69 | ``` 70 | 71 | ### Code Judge Server Setup 72 | 73 | > ⚠️ **Security Warning**: Code Judge executes arbitrary code. Always deploy in an isolated environment (preferably Docker) and never expose to external networks. 
74 | 75 | The rStar2-Agent uses Code Judge as a tool call server to execute model-generated Python code. 76 | 77 | #### 1. Start Redis Server 78 | 79 | ```bash 80 | sudo apt-get update -y && sudo apt-get install redis -y 81 | redis-server --daemonize yes --protected-mode no --bind 0.0.0.0 82 | ``` 83 | 84 | #### 2. Launch Code Judge Server 85 | 86 | ```bash 87 | # Start the main server (master node only) 88 | # Environment variables can be configured as per: https://github.com/0xWJ/code-judge/blob/main/app/config.py 89 | # Replace $WORKSPACE and $MASTER_ADDR with your actual paths 90 | 91 | tmux new-session -d -s server \ 92 | 'cd $WORKSPACE/code-judge && \ 93 | MAX_EXECUTION_TIME=4 \ 94 | REDIS_URI="redis://$MASTER_ADDR:6379" \ 95 | RUN_WORKERS=0 \ 96 | uvicorn app.main:app --host 0.0.0.0 --port 8088 --workers 16 \ 97 | 2>&1 | tee server.log' 98 | ``` 99 | 100 | #### 3. Start Code Judge Workers 101 | 102 | ```bash 103 | # Launch workers (can be deployed on multiple nodes for increased parallelism) 104 | # Adjust MAX_WORKERS based on your CPU count per node 105 | 106 | tmux new-session -d -s worker \ 107 | 'cd $WORKSPACE/code-judge && \ 108 | MAX_EXECUTION_TIME=4 \ 109 | REDIS_URI="redis://$MASTER_ADDR:6379" \ 110 | MAX_WORKERS=64 \ 111 | python run_workers.py \ 112 | 2>&1 | tee worker.log' 113 | ``` 114 | 115 | ### Launch the VLLM Server 116 | 117 | First, start the VLLM server: 118 | 119 | ```bash 120 | vllm serve /path/to/your/model \ 121 | --host 0.0.0.0 \ 122 | --port 8000 \ 123 | --enable-auto-tool-choice \ 124 | --tool-call-parser hermes 125 | ``` 126 | 127 | Replace `/path/to/your/model` with the actual path to your downloaded model. 
128 | 129 | ### Verify Server Status 130 | 131 | Check if the server is running properly: 132 | 133 | ```bash 134 | curl http://localhost:8000/v1/models 135 | ``` 136 | 137 | ### Run Interactive Chat with Tool Calling 138 | 139 | Use the provided script to interact with your model: 140 | 141 | ```bash 142 | python examples/chat_with_tool_call.py \ 143 | --model /path/to/your/model \ 144 | --prompt "Solve the system of equations: 2x + 3y = 7, x - y = 1" \ 145 | --max_tokens 8192 146 | ``` 147 | 148 | ### Script Options 149 | 150 | The `examples/chat_with_tool_call.py` script supports the following arguments: 151 | 152 | - `--model`: Path to your model 153 | - `--prompt`: Input prompt for the model 154 | - `--max_tokens`: Total token budget for the whole conversation after the prompt (model output plus tool responses) 155 | 156 | ## Evaluation 157 | 158 | ### Environment Setup 159 | 160 | See [Installation](#Installation) and [Code Judge Server Setup](#Code-Judge-Server-Setup). 161 | 162 | ### Run Evaluation Script 163 | 164 | We evaluate on the following mathematical reasoning benchmarks: 165 | 166 | - **AIME 2024/2025 (American Invitational Mathematics Examination)**: High-school level competition mathematics 167 | - **MATH500**: A subset of the MATH dataset containing 500 challenging problems 168 | 169 | ```bash 170 | MODEL_PATH=/path/to/your/model bash examples/aime_eval.sh 171 | MODEL_PATH=/path/to/your/model bash examples/math500_eval.sh 172 | ``` 173 | 174 | ## rStar2-Agent RL Training 175 | 176 | A comprehensive reinforcement learning training framework for rStar2-Agent, built on [Verl](https://github.com/volcengine/verl) and [Code Judge](https://github.com/0xWJ/code-judge). This framework enables training models after instruction-following supervised fine-tuning (SFT). 177 | 178 | ### Environment Setup 179 | 180 | See [Installation](#Installation) and [Code Judge Server Setup](#Code-Judge-Server-Setup).
181 | 182 | ### Data Preparation 183 | 184 | This example uses: 185 | - **Training Dataset**: DAPO-17k (English subset) 186 | - **Test Dataset**: AIME24 187 | 188 | ```bash 189 | # Process AIME 2024 dataset 190 | python data_preprocess/aime2024_rstar2_agent_loop.py 191 | 192 | # Process DAPO dataset 193 | python data_preprocess/dapo_rstar2_agent_loop.py 194 | ``` 195 | 196 | ### Model Setup 197 | 198 | Download the base model (Qwen3-14B-Base): 199 | 200 | ```bash 201 | huggingface-cli download Qwen/Qwen3-14B-Base --local-dir $HOME/models/Qwen3-14B-Base 202 | ``` 203 | 204 | > **Note**: The base model requires instruction-following SFT before RL training for optimal performance. 205 | 206 | ### Training 207 | 208 | #### Basic Training 209 | 210 | Run the training script (for 8x A100/H100 GPUs): 211 | 212 | ```bash 213 | bash examples/run_qwen3-14b_rstar2_agent_weave.sh 214 | ``` 215 | 216 | > Adjust configuration parameters based on your hardware environment. 217 | 218 | ### Configuration 219 | 220 | #### Data Augmentation Settings 221 | 222 | The framework supports various sampling strategies to improve training efficiency: 223 | 224 | ```bash 225 | # Global Settings 226 | augmentation.do_down_sampling=True # Enable down sampling 227 | augmentation.down_sampling_config.down_sample_to_n=16 # Target number of traces per data point 228 | 229 | # Sampling Strategies 230 | augmentation.down_sampling_config.reject_equal_reward=True # Enable reject sampling for equal rewards 231 | augmentation.down_sampling_config.roc_error_ratio=True # Resample correct traces by tool call error ratio 232 | augmentation.down_sampling_config.roc_answer_format=True # Resample correct traces by answer format 233 | 234 | # Minimum Trace Requirements 235 | augmentation.down_sampling_config.min_zero_reward_trace_num=2 # Minimum negative traces to retain 236 | augmentation.down_sampling_config.min_non_zero_reward_trace_num=2 # Minimum positive traces to retain 237 | ``` 238 | 239 | ### Important 
Note 240 | 241 | rStar2-Agent was originally trained with VERL v0.2 and our custom multi-turn tool-calling training framework. The current training framework released here has been migrated to VERL v0.5 to ensure compatibility with the latest community standards. While this released framework has not yet been used to train a complete model, we have verified that the first 50 training steps show minimal differences between our original and migrated frameworks, maintaining the core functionality of our proven training approach. 242 | 243 | Although our original framework includes additional advanced features such as a rollout request load-balancing scheduler, we chose to migrate to the latest VERL version to maintain community compatibility and facilitate easier customization by users. This approach ensures you can benefit from ongoing VERL improvements and easily integrate with the latest open-source developments. We are also considering migrating all features to the current version in the future. 244 | 245 | If you encounter any issues during usage or need assistance with the training framework, please contact us. 246 | 247 | ### Troubleshooting 248 | 249 | #### Common Issues 250 | 251 | 1. **Redis Connection Errors**: Ensure Redis is running and accessible at the specified address 252 | 2. **GPU Memory Issues**: Adjust batch sizes and model parameters for your hardware 253 | 3. **Code Judge Timeouts**: Increase `MAX_EXECUTION_TIME` for complex computations 254 | 4.
**Worker Scaling**: Adjust `MAX_WORKERS` based on available CPU cores 255 | 256 | #### Log Locations 257 | 258 | - Server logs: `server.log` in the code-judge directory 259 | - Worker logs: `worker.log` in the code-judge directory 260 | - Training logs: Check your training script output directory 261 | 262 | --- 263 | 264 | 265 | ## Citation 266 | If you find this repo useful for your research, please consider citing the paper 267 | ``` 268 | @misc{shang2025rstar2agentagenticreasoningtechnical, 269 | title={rStar2-Agent: Agentic Reasoning Technical Report}, 270 | author={Ning Shang and Yifei Liu and Yi Zhu and Li Lyna Zhang and Weijiang Xu and Xinyu Guan and Buze Zhang and Bingcheng Dong and Xudong Zhou and Bowen Zhang and Ying Xin and Ziming Miao and Scarlett Li and Fan Yang and Mao Yang}, 271 | year={2025}, 272 | eprint={2508.20722}, 273 | archivePrefix={arXiv}, 274 | primaryClass={cs.CL}, 275 | url={https://arxiv.org/abs/2508.20722}, 276 | } 277 | ``` 278 | -------------------------------------------------------------------------------- /rstar2_agent/rollout/rstar2_agent_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | import asyncio 5 | import copy 6 | import json 7 | import logging 8 | import os 9 | from typing import Any 10 | from uuid import uuid4 11 | 12 | from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register 13 | from verl.experimental.agent_loop.tool_agent_loop import ToolAgentLoop 14 | from verl.experimental.agent_loop.tool_parser import FunctionCall 15 | from verl.tools.schemas import ToolResponse 16 | from verl.utils.profiler import simple_timer 17 | from verl.utils.rollout_trace import rollout_trace_op 18 | 19 | logger = logging.getLogger(__file__) 20 | logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) 21 | 22 | 23 | @register("rstar2_agent") 24 | class RStar2AgentLoop(ToolAgentLoop): 25 | @rollout_trace_op 26 | async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: 27 | messages = list(kwargs["raw_prompt"]) 28 | image_data = copy.deepcopy(kwargs.get("multi_modal_data", {}).get("image", None)) 29 | metrics = {} 30 | request_id = uuid4().hex 31 | if self.processor is not None: 32 | raw_prompt = await self.loop.run_in_executor( 33 | None, 34 | lambda: self.processor.apply_chat_template( 35 | messages, 36 | tools=self.tool_schemas, 37 | add_generation_prompt=True, 38 | tokenize=False, 39 | **self.apply_chat_template_kwargs, 40 | ), 41 | ) 42 | model_inputs = self.processor(text=[raw_prompt], images=image_data, return_tensors="pt") 43 | prompt_ids = model_inputs.pop("input_ids").squeeze(0).tolist() 44 | else: 45 | prompt_ids = await self.loop.run_in_executor( 46 | None, 47 | lambda: self.tokenizer.apply_chat_template( 48 | messages, 49 | tools=self.tool_schemas, 50 | add_generation_prompt=True, 51 | tokenize=True, 52 | **self.apply_chat_template_kwargs, 53 | ), 54 | ) 55 | response_mask = [] 56 | tools_kwargs = kwargs.get("tools_kwargs", {}) 57 | ################################### rStar ################################### 58 | history_tool_calls = [] # Keep track of all tool calls made during the 
conversation 59 | # budget = len(prompt_ids) + self.response_length 60 | ############################################################################# 61 | 62 | user_turns, assistant_turns = 0, 0 63 | while True: 64 | with simple_timer("generate_sequences", metrics): 65 | ################################### rStar ################################### 66 | sampling_params["max_new_tokens"] = self.response_length - len(response_mask) 67 | ############################################################################# 68 | response_ids = await self.server_manager.generate( 69 | request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params, image_data=image_data 70 | ) 71 | prompt_ids += response_ids 72 | response_mask += [1] * len(response_ids) 73 | assistant_turns += 1 74 | 75 | # reach max response length 76 | if len(response_mask) >= self.response_length: 77 | # self.server_manager._release_request(request_id, budget) 78 | break 79 | 80 | # reach max assistant turns 81 | if self.max_assistant_turns and assistant_turns >= self.max_assistant_turns: 82 | # self.server_manager._release_request(request_id, budget) 83 | break 84 | 85 | # reach max user turns 86 | if self.max_user_turns and user_turns >= self.max_user_turns: 87 | # self.server_manager._release_request(request_id, budget) 88 | break 89 | 90 | # no tool calls 91 | _, tool_calls = await self.tool_parser.extract_tool_calls(response_ids) 92 | if not tool_calls: 93 | # self.server_manager._release_request(request_id, budget) 94 | break 95 | 96 | ################################### rStar ################################### 97 | tool_calls = tool_calls[: self.max_parallel_calls] 98 | total_tool_responses, filtered_tool_calls, pending_pos = [], [], [] 99 | for i, tool_call in enumerate(tool_calls): 100 | if isinstance(tool_call, ToolResponse): 101 | total_tool_responses.append(tool_call) 102 | else: 103 | total_tool_responses.append(None) 104 | pending_pos.append(i) 105 | 
filtered_tool_calls.append(tool_call) 106 | tool_calls = filtered_tool_calls 107 | ############################################################################# 108 | # call tools 109 | tasks = [] 110 | for tool_call in tool_calls[: self.max_parallel_calls]: 111 | ################################### rStar ################################### 112 | tools_kwargs_copy = dict(tools_kwargs) # Copy to avoid modifying original 113 | tools_kwargs_copy["history_tool_calls"] = list(history_tool_calls) # Pass history tool calls 114 | tasks.append(self._call_tool(tool_call, tools_kwargs_copy)) 115 | history_tool_calls.append(tool_call) 116 | ############################################################################# 117 | with simple_timer("tool_calls", metrics): 118 | tool_responses = await asyncio.gather(*tasks) 119 | ################################### rStar ################################### 120 | assert len(pending_pos[: self.max_parallel_calls]) == len(tool_responses) 121 | for i, tool_response in zip(pending_pos[: self.max_parallel_calls], tool_responses): 122 | total_tool_responses[i] = tool_response 123 | tool_responses = total_tool_responses 124 | ############################################################################# 125 | if any(isinstance(item, Exception) for item in tool_responses): 126 | # self.server_manager._release_request(request_id, budget) 127 | break 128 | 129 | # Extract messages and update multi_modal_data 130 | tool_messages = [] 131 | new_images_this_turn = [] 132 | for tool_response in tool_responses: 133 | # Create message from tool response 134 | if tool_response.image or tool_response.video: 135 | # Multi-modal content with structured format 136 | content = [] 137 | if tool_response.image: 138 | content.append({"type": "image"}) 139 | if tool_response.video: 140 | content.append({"type": "video"}) 141 | if tool_response.text: 142 | content.append({"type": "text", "text": tool_response.text}) 143 | message = {"role": "tool", "content": 
content} 144 | else: 145 | # Text-only content 146 | message = {"role": "tool", "content": tool_response.text or ""} 147 | 148 | tool_messages.append(message) 149 | 150 | # Handle image data 151 | if tool_response.image: 152 | if image_data is None: 153 | image_data = [] 154 | elif not isinstance(image_data, list): 155 | image_data = [image_data] 156 | 157 | # Add new image data 158 | if isinstance(tool_response.image, list): 159 | image_data.extend(tool_response.image) 160 | new_images_this_turn.extend(tool_response.image) 161 | else: 162 | image_data.append(tool_response.image) 163 | new_images_this_turn.append(tool_response.image) 164 | 165 | # Handle video data 166 | if tool_response.video: 167 | # Currently not supported, raise informative error 168 | logger.warning("Multimedia type 'video' is not currently supported. Only 'image' is supported.") 169 | raise NotImplementedError( 170 | "Multimedia type 'video' is not currently supported. Only 'image' is supported." 171 | ) 172 | 173 | # append tool_response_ids 174 | if self.processor is not None: 175 | raw_tool_response = await self.loop.run_in_executor( 176 | None, 177 | lambda messages=tool_messages: self.processor.apply_chat_template( 178 | messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs 179 | ), 180 | ) 181 | # Use only the new images from this turn for processing tool responses 182 | current_images = new_images_this_turn if new_images_this_turn else None 183 | model_inputs = self.processor(text=[raw_tool_response], images=current_images, return_tensors="pt") 184 | tool_response_ids = model_inputs.pop("input_ids").squeeze(0).tolist() 185 | else: 186 | tool_response_ids = await self.loop.run_in_executor( 187 | None, 188 | lambda messages=tool_messages: self.tokenizer.apply_chat_template( 189 | messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs 190 | ), 191 | ) 192 | tool_response_ids = tool_response_ids[len(self.system_prompt) :] 
193 | 194 | # NOTE: last turn should not be user turn, or the EOS token reward 195 | # can't be propagated to previous token in GAE. 196 | if len(response_mask) + len(tool_response_ids) >= self.response_length: 197 | # self.server_manager._release_request(request_id, budget) 198 | break 199 | 200 | prompt_ids += tool_response_ids 201 | response_mask += [0] * len(tool_response_ids) 202 | user_turns += 1 203 | 204 | response_ids = prompt_ids[-len(response_mask) :] 205 | prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)] 206 | 207 | multi_modal_data = {"image": image_data} if image_data is not None else {} 208 | 209 | output = AgentLoopOutput( 210 | prompt_ids=prompt_ids, 211 | response_ids=response_ids[: self.response_length], 212 | response_mask=response_mask[: self.response_length], 213 | multi_modal_data=multi_modal_data, 214 | num_turns=user_turns + assistant_turns + 1, 215 | metrics=metrics, 216 | ) 217 | return output 218 | 219 | async def _call_tool(self, tool_call: FunctionCall, tools_kwargs: dict[str, Any]) -> ToolResponse: 220 | """Call tool and return tool response.""" 221 | tool, instance_id = None, None 222 | try: 223 | # TODO: append malformed tool_call to the prompt: invalid function name or arguments 224 | tool_name = tool_call.name 225 | tool_args = json.loads(tool_call.arguments) 226 | tool = self.tools[tool_name] 227 | kwargs = tools_kwargs.get(tool_name, {}) 228 | ################################### rStar ################################### 229 | instance_id, _ = await tool.create( 230 | create_kwargs=kwargs.get("create_kwargs", {}), 231 | history_tool_calls=tools_kwargs.get("history_tool_calls", []), 232 | ) 233 | ############################################################################# 234 | tool_execution_response, _, _ = await tool.execute(instance_id, tool_args) 235 | except Exception as e: 236 | logger.warning(f"Error when executing tool: {e}") 237 | return ToolResponse( 238 | text=f"Error when executing tool: {e}", 239 | ) 
240 | finally: 241 | if tool and instance_id: 242 | await tool.release(instance_id) 243 | 244 | tool_response_text = tool_execution_response.text 245 | if tool_response_text and len(tool_response_text) > self.max_tool_response_length: 246 | if self.tool_response_truncate_side == "left": 247 | tool_response_text = tool_response_text[: self.max_tool_response_length] + "...(truncated)" 248 | elif self.tool_response_truncate_side == "right": 249 | tool_response_text = "(truncated)..." + tool_response_text[-self.max_tool_response_length :] 250 | else: 251 | length = self.max_tool_response_length // 2 252 | tool_response_text = tool_response_text[:length] + "...(truncated)..." + tool_response_text[-length:] 253 | 254 | # Create ToolResponse from tool execution result 255 | tool_response_kwargs = {"text": tool_response_text} 256 | 257 | # Add multimedia data if present 258 | for attr_name in ["image", "video"]: 259 | if hasattr(tool_execution_response, attr_name): 260 | attr_value = getattr(tool_execution_response, attr_name) 261 | if attr_value is not None: 262 | tool_response_kwargs[attr_name] = attr_value 263 | 264 | return ToolResponse(**tool_response_kwargs) 265 | -------------------------------------------------------------------------------- /fused_compute_score/prime_math/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 PRIME team and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Answer checker API that uses sympy to simplify expressions and check for equality. 16 | 17 | Call grade_answer(given_answer: str, ground_truth: str). 18 | 19 | FROM: https://github.com/openai/prm800k/blob/main/prm800k/grading/grader.py 20 | """ 21 | 22 | import contextlib 23 | import math 24 | import re 25 | 26 | import sympy 27 | from pylatexenc import latex2text 28 | from sympy.parsing import sympy_parser 29 | 30 | from . import math_normalize 31 | from .grader import math_equal, timeout_limit 32 | 33 | # import math_normalize 34 | # from grader import math_equal 35 | 36 | # sympy might hang -- we don't care about trying to be lenient in these cases 37 | BAD_SUBSTRINGS = ["^{", "^("] 38 | BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"] 39 | TUPLE_CHARS = "()[]" 40 | 41 | 42 | def _sympy_parse(expr: str): 43 | """Parses an expression with sympy.""" 44 | py_expr = expr.replace("^", "**") 45 | return sympy_parser.parse_expr( 46 | py_expr, 47 | transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)), 48 | ) 49 | 50 | 51 | def _parse_latex(expr: str) -> str: 52 | """Attempts to parse latex to an expression sympy can read.""" 53 | expr = expr.replace("\\tfrac", "\\frac") 54 | expr = expr.replace("\\dfrac", "\\frac") 55 | expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. 56 | expr = latex2text.LatexNodes2Text().latex_to_text(expr) 57 | 58 | # Replace the specific characters that this parser uses. 
59 | expr = expr.replace("√", "sqrt") 60 | expr = expr.replace("π", "pi") 61 | expr = expr.replace("∞", "inf") 62 | expr = expr.replace("∪", "U") 63 | expr = expr.replace("·", "*") 64 | expr = expr.replace("×", "*") 65 | 66 | return expr.strip() 67 | 68 | 69 | def _is_float(num: str) -> bool: 70 | try: 71 | float(num) 72 | return True 73 | except ValueError: 74 | return False 75 | 76 | 77 | def _is_int(x: float) -> bool: 78 | try: 79 | return abs(x - int(round(x))) <= 1e-7 80 | except Exception: 81 | return False 82 | 83 | 84 | def _is_frac(expr: str) -> bool: 85 | return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) 86 | 87 | 88 | def _str_is_int(x: str) -> bool: 89 | try: 90 | x = _strip_properly_formatted_commas(x) 91 | x = float(x) 92 | return abs(x - int(round(x))) <= 1e-7 93 | except Exception: 94 | return False 95 | 96 | 97 | def _str_to_int(x: str) -> bool: 98 | x = x.replace(",", "") 99 | x = float(x) 100 | return int(x) 101 | 102 | 103 | def _inject_implicit_mixed_number(step: str): 104 | """ 105 | Automatically make a mixed number evalable 106 | e.g. 7 3/4 => 7+3/4 107 | """ 108 | p1 = re.compile("([0-9]) +([0-9])") 109 | step = p1.sub("\\1+\\2", step) ## implicit mults 110 | return step 111 | 112 | 113 | def _strip_properly_formatted_commas(expr: str): 114 | # We want to be careful because we don't want to strip tuple commas 115 | p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)") 116 | while True: 117 | next_expr = p1.sub("\\1\\3\\4", expr) 118 | if next_expr == expr: 119 | break 120 | expr = next_expr 121 | return next_expr 122 | 123 | 124 | def _normalize(expr: str) -> str: 125 | """Normalize answer expressions.""" 126 | if expr is None: 127 | return None 128 | 129 | # Remove enclosing `\text{}`. 
130 | m = re.search(r"^\\text\{(?P.+?)\}$", expr) 131 | if m is not None: 132 | expr = m.group("text") 133 | 134 | expr = expr.replace("\\%", "%") 135 | expr = expr.replace("\\$", "$") 136 | expr = expr.replace("$", "") 137 | expr = expr.replace("%", "") 138 | expr = expr.replace(" or ", " , ") 139 | expr = expr.replace(" and ", " , ") 140 | 141 | expr = expr.replace("million", "*10^6") 142 | expr = expr.replace("billion", "*10^9") 143 | expr = expr.replace("trillion", "*10^12") 144 | 145 | for unit in [ 146 | "degree", 147 | "cm", 148 | "centimeter", 149 | "meter", 150 | "mile", 151 | "second", 152 | "minute", 153 | "hour", 154 | "day", 155 | "week", 156 | "month", 157 | "year", 158 | "foot", 159 | "feet", 160 | "inch", 161 | "yard", 162 | "liter", 163 | ]: 164 | expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) 165 | expr = re.sub("\^ *\\\\circ", "", expr) 166 | 167 | if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": 168 | expr = expr[1:-1] 169 | 170 | expr = re.sub(",\\\\! 
*", "", expr) 171 | if _is_float(expr) and _is_int(float(expr)): 172 | expr = str(int(round(float(expr)))) 173 | if "\\" in expr: 174 | with contextlib.suppress(Exception): 175 | expr = _parse_latex(expr) 176 | 177 | # edge case with mixed numbers and negative signs 178 | expr = re.sub("- *", "-", expr) 179 | 180 | expr = _inject_implicit_mixed_number(expr) 181 | 182 | # don't be case sensitive for text answers 183 | expr = expr.lower() 184 | 185 | if _str_is_int(expr): 186 | expr = str(_str_to_int(expr)) 187 | 188 | return expr 189 | 190 | 191 | def count_unknown_letters_in_expr(expr: str): 192 | expr = expr.replace("sqrt", "") 193 | expr = expr.replace("frac", "") 194 | letters_in_expr = set([x for x in expr if x.isalpha()]) 195 | return len(letters_in_expr) 196 | 197 | 198 | def should_allow_eval(expr: str): 199 | # we don't want to try parsing unknown text or functions of more than two variables 200 | if count_unknown_letters_in_expr(expr) > 2: 201 | return False 202 | 203 | for bad_string in BAD_SUBSTRINGS: 204 | if bad_string in expr: 205 | return False 206 | 207 | return all(re.search(bad_regex, expr) is None for bad_regex in BAD_REGEXES) 208 | 209 | 210 | @timeout_limit(seconds=10) 211 | def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): 212 | are_equal = False 213 | try: 214 | expr = f"({ground_truth_normalized})-({given_normalized})" 215 | if should_allow_eval(expr): 216 | sympy_diff = _sympy_parse(expr) 217 | simplified = sympy.simplify(sympy_diff) 218 | if simplified == 0: 219 | are_equal = True 220 | except Exception: 221 | pass 222 | return are_equal 223 | 224 | 225 | def split_tuple(expr: str): 226 | """ 227 | Split the elements in a tuple/interval, while handling well-formatted commas in large numbers 228 | """ 229 | expr = _strip_properly_formatted_commas(expr) 230 | if len(expr) == 0: 231 | return [] 232 | if ( 233 | len(expr) > 2 234 | and expr[0] in TUPLE_CHARS 235 | and expr[-1] in TUPLE_CHARS 236 | and all([ch not 
in expr[1:-1] for ch in TUPLE_CHARS]) 237 | ): 238 | elems = [elem.strip() for elem in expr[1:-1].split(",")] 239 | else: 240 | elems = [expr] 241 | return elems 242 | 243 | 244 | def grade_answer(given_answer: str, ground_truth: str) -> bool: 245 | """ 246 | The answer will be considered correct if: 247 | (a) it normalizes to the same string as the ground truth answer 248 | OR 249 | (b) sympy can simplify the difference between the expressions to 0 250 | """ 251 | if given_answer is None: 252 | return False 253 | 254 | ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth) 255 | given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer) 256 | 257 | # be at least as lenient as mathd 258 | if ground_truth_normalized_mathd == given_answer_normalized_mathd: 259 | return True 260 | 261 | ground_truth_normalized = _normalize(ground_truth) 262 | given_normalized = _normalize(given_answer) 263 | 264 | if ground_truth_normalized is None: 265 | return False 266 | 267 | if ground_truth_normalized == given_normalized: 268 | return True 269 | 270 | if len(given_normalized) == 0: 271 | return False 272 | 273 | ground_truth_elems = split_tuple(ground_truth_normalized) 274 | given_elems = split_tuple(given_normalized) 275 | 276 | if ( 277 | len(ground_truth_elems) > 1 278 | and (ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1]) 279 | or len(ground_truth_elems) != len(given_elems) 280 | ): 281 | is_correct = False 282 | else: 283 | for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True): 284 | if _is_frac(ground_truth_elem) and _is_frac(given_elem): 285 | # if fractions aren't reduced, then shouldn't be marked as correct 286 | # so, we don't want to allow sympy.simplify in this case 287 | is_correct = ground_truth_elem == given_elem 288 | elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): 289 | # if the ground truth answer is an integer, we 
require the given answer to be a strict match 290 | # (no sympy.simplify) 291 | is_correct = False 292 | else: 293 | try: 294 | is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) 295 | except Exception as e: 296 | # if there's an error, we'll just say it's not correct 297 | is_correct = False 298 | print(f"Error: {e} from are_equal_under_sympy, {ground_truth_elem}, {given_elem}") 299 | if not is_correct: 300 | break 301 | 302 | return is_correct 303 | 304 | 305 | def remove_boxed(s): 306 | left = "\\boxed{" 307 | try: 308 | assert s[: len(left)] == left 309 | assert s[-1] == "}" 310 | return s[len(left) : -1] 311 | except Exception: 312 | return None 313 | 314 | 315 | def _last_boxed_only_string(string): 316 | idx = string.rfind("\\boxed") 317 | if idx < 0: 318 | idx = string.rfind("\\fbox") 319 | if idx < 0: 320 | return None 321 | 322 | i = idx 323 | left_brace_idx = None 324 | right_brace_idx = None 325 | num_left_braces_open = 0 326 | while i < len(string): 327 | if string[i] == "{": 328 | num_left_braces_open += 1 329 | if left_brace_idx is None: 330 | left_brace_idx = i 331 | elif string[i] == "}": 332 | num_left_braces_open -= 1 333 | if num_left_braces_open == 0: 334 | right_brace_idx = i 335 | break 336 | 337 | i += 1 338 | 339 | if left_brace_idx is None or right_brace_idx is None: 340 | return None 341 | 342 | return string[left_brace_idx + 1 : right_brace_idx].strip() 343 | 344 | 345 | def match_answer(response): 346 | is_matched = False 347 | for ans_marker in ["answer:", "answer is", "answers are"]: 348 | ans_idx = response.lower().rfind(ans_marker) 349 | if ans_idx != -1: 350 | is_matched = True 351 | response = response[ans_idx + len(ans_marker) :].strip() 352 | if response.endswith("\n"): 353 | response = response[:-2] 354 | 355 | for ans_marker in ["is answer", "is the answer", "are answers", "are the answers"]: 356 | ans_idx = response.lower().rfind(ans_marker) 357 | if ans_idx != -1: 358 | is_matched = True 359 | response = 
response[:ans_idx].strip() 360 | if response.endswith("\n"): 361 | response = response[:-2] 362 | 363 | # Find boxed 364 | ans_boxed = _last_boxed_only_string(response) 365 | if ans_boxed: 366 | is_matched = True 367 | response = ans_boxed 368 | 369 | if ". " in response: 370 | dot_idx = response.lower().rfind(". ") 371 | if dot_idx != -1: 372 | response = response[:dot_idx].strip() 373 | 374 | for ans_marker in ["be ", "is ", "are ", "=", ": ", "get ", "be\n", "is\n", "are\n", ":\n", "get\n"]: 375 | ans_idx = response.lower().rfind(ans_marker) 376 | if ans_idx != -1: 377 | is_matched = True 378 | response = response[ans_idx + len(ans_marker) :].strip() 379 | if response.endswith("\n"): 380 | response = response[:-2] 381 | 382 | is_matched = is_matched if any([c.isdigit() for c in response]) else False # answer must have a digit 383 | # Grade 384 | return is_matched, response 385 | 386 | 387 | def compute_score(model_output: str, ground_truth: str) -> bool: 388 | model_output = str(model_output) 389 | ground_truth = str(ground_truth) 390 | 391 | is_matched, extracted_model_output = match_answer(model_output) 392 | format_correctness = "Step 2:" in model_output and "\\box" in model_output 393 | 394 | # grade simple algebra questions. 
if succeeded, return; otherwise, proceed to more complex grading 395 | if grade_answer(extracted_model_output, ground_truth): 396 | return True, True, extracted_model_output 397 | 398 | try: 399 | if "\pi" in extracted_model_output or "\pi" in ground_truth: 400 | equivs = [] 401 | for pi in [math.pi, 3.14]: 402 | equivs.append(math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi)) 403 | is_correct = any(equivs) 404 | else: 405 | is_correct = math_equal(extracted_model_output, ground_truth, timeout=True) 406 | except Exception: 407 | is_correct = False 408 | 409 | return is_correct, format_correctness, extracted_model_output 410 | -------------------------------------------------------------------------------- /rstar2_agent/main_rstar2_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | """ 5 | The different of this file and the verl/trainer/main_ppo.py is the usage of RStar2AgentRayTrainer instead of RayPPOTrainer. 6 | """ 7 | 8 | import os 9 | import socket 10 | 11 | import hydra 12 | import ray 13 | from omegaconf import OmegaConf 14 | 15 | from verl.trainer.constants_ppo import get_ppo_ray_runtime_env 16 | from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler 17 | from verl.trainer.ppo.reward import load_reward_manager 18 | from verl.utils.device import is_cuda_available 19 | 20 | from .rstar2_agent_ray_trainer import RStar2AgentRayTrainer 21 | 22 | 23 | @hydra.main(config_path="config", config_name="rstar2_agent_trainer", version_base=None) 24 | def main(config): 25 | """Main entry point for PPO training with Hydra configuration management. 26 | 27 | Args: 28 | config_dict: Hydra configuration dictionary containing training parameters. 
29 | """ 30 | run_ppo(config) 31 | 32 | 33 | # Define a function to run the PPO-like training process 34 | def run_ppo(config) -> None: 35 | """Initialize Ray cluster and run distributed PPO training process. 36 | 37 | Args: 38 | config: Training configuration object containing all necessary parameters 39 | for distributed PPO training including Ray initialization settings, 40 | model paths, and training hyperparameters. 41 | """ 42 | # Check if Ray is not initialized 43 | if not ray.is_initialized(): 44 | # Initialize Ray with a local cluster configuration 45 | # Set environment variables in the runtime environment to control tokenizer parallelism, 46 | # NCCL debug level, VLLM logging level, and allow runtime LoRA updating 47 | # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration 48 | ray.init( 49 | runtime_env=get_ppo_ray_runtime_env(), 50 | num_cpus=config.ray_init.num_cpus, 51 | ) 52 | 53 | # Create a remote instance of the TaskRunner class, and 54 | # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete 55 | if ( 56 | is_cuda_available 57 | and config.global_profiler.tool == "nsys" 58 | and config.global_profiler.get("steps") is not None 59 | and len(config.global_profiler.get("steps", [])) > 0 60 | ): 61 | from verl.utils.import_utils import is_nvtx_available 62 | 63 | assert is_nvtx_available(), "nvtx is not available in CUDA platform. 
Please 'pip3 install nvtx'" 64 | nsight_options = OmegaConf.to_container( 65 | config.global_profiler.global_tool_config.nsys.controller_nsight_options 66 | ) 67 | runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() 68 | else: 69 | runner = TaskRunner.remote() 70 | ray.get(runner.run.remote(config)) 71 | 72 | # [Optional] get the path of the timeline trace file from the configuration, default to None 73 | # This file is used for performance analysis 74 | timeline_json_file = config.ray_init.get("timeline_json_file", None) 75 | if timeline_json_file: 76 | ray.timeline(filename=timeline_json_file) 77 | 78 | 79 | @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head 80 | class TaskRunner: 81 | """Ray remote class for executing distributed PPO training tasks. 82 | 83 | This class encapsulates the main training logic and runs as a Ray remote actor 84 | to enable distributed execution across multiple nodes and GPUs. 85 | 86 | Attributes: 87 | role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes 88 | mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation 89 | """ 90 | 91 | def __init__(self): 92 | self.role_worker_mapping = {} 93 | self.mapping = {} 94 | 95 | def add_actor_rollout_worker(self, config): 96 | """Add actor rollout worker based on the actor strategy.""" 97 | from verl.single_controller.ray import RayWorkerGroup 98 | 99 | if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: 100 | from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker 101 | 102 | actor_rollout_cls = ( 103 | AsyncActorRolloutRefWorker 104 | if config.actor_rollout_ref.rollout.mode == "async" 105 | else ActorRolloutRefWorker 106 | ) 107 | ray_worker_group_cls = RayWorkerGroup 108 | 109 | elif config.actor_rollout_ref.actor.strategy == "megatron": 110 | from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker 111 | 112 | 
actor_rollout_cls = ( 113 | AsyncActorRolloutRefWorker 114 | if config.actor_rollout_ref.rollout.mode == "async" 115 | else ActorRolloutRefWorker 116 | ) 117 | ray_worker_group_cls = RayWorkerGroup 118 | 119 | else: 120 | raise NotImplementedError 121 | 122 | from verl.trainer.ppo.ray_trainer import Role 123 | 124 | self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls) 125 | 126 | return actor_rollout_cls, ray_worker_group_cls 127 | 128 | def add_critic_worker(self, config): 129 | """Add critic worker to role mapping.""" 130 | if config.critic.strategy in {"fsdp", "fsdp2"}: 131 | use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") 132 | if use_legacy_worker_impl in ["auto", "enable"]: 133 | from verl.workers.fsdp_workers import CriticWorker 134 | elif use_legacy_worker_impl == "disable": 135 | from verl.workers.roles import CriticWorker 136 | 137 | print("Using new worker implementation") 138 | else: 139 | raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") 140 | 141 | elif config.critic.strategy == "megatron": 142 | from verl.workers.megatron_workers import CriticWorker 143 | 144 | else: 145 | raise NotImplementedError 146 | 147 | from verl.trainer.ppo.ray_trainer import Role 148 | 149 | self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker) 150 | 151 | def init_resource_pool_mgr(self, config): 152 | """Initialize resource pool manager.""" 153 | from verl.trainer.ppo.ray_trainer import Role 154 | 155 | global_pool_id = "global_pool" 156 | resource_pool_spec = { 157 | global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, 158 | } 159 | self.mapping[Role.ActorRollout] = global_pool_id 160 | self.mapping[Role.Critic] = global_pool_id 161 | from verl.trainer.ppo.ray_trainer import ResourcePoolManager 162 | 163 | resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping) 164 | return resource_pool_manager 165 | 166 | def 
add_reward_model_worker(self, config): 167 | """Add reward model worker if enabled.""" 168 | from verl.trainer.ppo.ray_trainer import Role 169 | 170 | if config.reward_model.enable: 171 | if config.reward_model.strategy in {"fsdp", "fsdp2"}: 172 | from verl.workers.fsdp_workers import RewardModelWorker 173 | elif config.reward_model.strategy == "megatron": 174 | from verl.workers.megatron_workers import RewardModelWorker 175 | else: 176 | raise NotImplementedError 177 | self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) 178 | self.mapping[Role.RewardModel] = "global_pool" 179 | 180 | def add_ref_policy_worker(self, config, ref_policy_cls): 181 | """Add reference policy worker if KL loss or KL reward is used.""" 182 | from verl.trainer.ppo.ray_trainer import Role 183 | 184 | if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: 185 | self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls) 186 | self.mapping[Role.RefPolicy] = "global_pool" 187 | 188 | def run(self, config): 189 | """Execute the main PPO training workflow. 190 | 191 | This method sets up the distributed training environment, initializes 192 | workers, datasets, and reward functions, then starts the training process. 193 | 194 | Args: 195 | config: Training configuration object containing all parameters needed 196 | for setting up and running the PPO training process. 197 | """ 198 | # Print the initial configuration. `resolve=True` will evaluate symbolic values. 199 | from pprint import pprint 200 | 201 | from omegaconf import OmegaConf 202 | 203 | from verl.utils.fs import copy_to_local 204 | 205 | print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") 206 | pprint(OmegaConf.to_container(config, resolve=True)) 207 | OmegaConf.resolve(config) 208 | 209 | # Download the checkpoint from HDFS to the local machine. 
210 | # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on 211 | local_path = copy_to_local( 212 | config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) 213 | ) 214 | 215 | # Instantiate the tokenizer and processor. 216 | from verl.utils import hf_processor, hf_tokenizer 217 | 218 | trust_remote_code = config.data.get("trust_remote_code", False) 219 | tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) 220 | # Used for multimodal LLM, could be None 221 | processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) 222 | 223 | actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) 224 | self.add_critic_worker(config) 225 | 226 | # We should adopt a multi-source reward function here: 227 | # - for rule-based rm, we directly call a reward score 228 | # - for model-based rm, we call a model 229 | # - for code related prompt, we send to a sandbox if there are test cases 230 | # finally, we combine all the rewards together 231 | # The reward type depends on the tag of the data 232 | self.add_reward_model_worker(config) 233 | 234 | # Add a reference policy worker if KL loss or KL reward is used. 
235 | self.add_ref_policy_worker(config, actor_rollout_cls) 236 | 237 | ################################### rStar ################################### 238 | # support data.filter_overlong_prompts 239 | if config.actor_rollout_ref.model.get("custom_chat_template", None) is not None: 240 | if processor is not None: 241 | processor.chat_template = config.actor_rollout_ref.model.custom_chat_template 242 | tokenizer.chat_template = config.actor_rollout_ref.model.custom_chat_template 243 | 244 | tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path 245 | tool_list = [] 246 | if tool_config_path is not None: 247 | from verl.tools.utils.tool_registry import ToolType, get_tool_class, OpenAIFunctionToolSchema 248 | tools_config = OmegaConf.load(tool_config_path) 249 | for tool_config in tools_config.tools: 250 | tool_type = ToolType(tool_config.config.type) 251 | assert tool_type is ToolType.NATIVE 252 | if tool_config.get("tool_schema", None) is None: 253 | tool_schema = None 254 | else: 255 | tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True) 256 | tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict).model_dump(exclude_unset=True, exclude_none=True) 257 | tool_list.append(tool_schema) 258 | ############################################################################# 259 | 260 | # Load the reward manager for training and validation. 261 | reward_fn = load_reward_manager( 262 | config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) 263 | ) 264 | val_reward_fn = load_reward_manager( 265 | config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) 266 | ) 267 | 268 | resource_pool_manager = self.init_resource_pool_mgr(config) 269 | 270 | from verl.utils.dataset.rl_dataset import collate_fn 271 | 272 | # Create training and validation datasets. 
273 | train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True) 274 | val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False) 275 | ################################### rStar ################################### 276 | train_dataset.dataframe = train_dataset.maybe_filter_out_long_prompts(train_dataset.dataframe, tools=tool_list) 277 | val_dataset.dataframe = val_dataset.maybe_filter_out_long_prompts(val_dataset.dataframe, tools=tool_list) 278 | ############################################################################# 279 | train_sampler = create_rl_sampler(config.data, train_dataset) 280 | 281 | # Initialize the rstar2 agent PPO trainer. 282 | trainer = RStar2AgentRayTrainer( 283 | config=config, 284 | tokenizer=tokenizer, 285 | processor=processor, 286 | role_worker_mapping=self.role_worker_mapping, 287 | resource_pool_manager=resource_pool_manager, 288 | ray_worker_group_cls=ray_worker_group_cls, 289 | reward_fn=reward_fn, 290 | val_reward_fn=val_reward_fn, 291 | train_dataset=train_dataset, 292 | val_dataset=val_dataset, 293 | collate_fn=collate_fn, 294 | train_sampler=train_sampler, 295 | ) 296 | # Initialize the workers of the trainer. 297 | trainer.init_workers() 298 | # Start the training process. 299 | trainer.fit() 300 | 301 | 302 | if __name__ == "__main__": 303 | main() 304 | -------------------------------------------------------------------------------- /rstar2_agent/tools/code_judge_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
import json
import aiohttp
import asyncio
import traceback
import os
import datetime

from typing import Dict, List, Literal, Callable, Optional

# Directory where call_long_batch dumps submissions that still have no result after
# all retries. Defaults to the user's home directory; change it with
# set_failed_submissions_path().
_failed_submissions_path = os.path.expanduser("~")


def set_failed_submissions_path(path: str):
    """
    Set the path where failed submissions will be saved.

    ``~`` in *path* is expanded and the directory is created if it does not exist.

    Args:
        path: The directory path to save failed submissions
    """
    global _failed_submissions_path
    _failed_submissions_path = os.path.expanduser(path)
    # Create directory if it doesn't exist
    os.makedirs(_failed_submissions_path, exist_ok=True)
    print(f"Failed submissions will be saved to: {_failed_submissions_path}")


def get_failed_submissions_path() -> str:
    """
    Get the current path where failed submissions will be saved.

    Returns:
        The current path for saving failed submissions
    """
    return _failed_submissions_path


async def call_long_batch(
        url: str,
        submissions: List[Dict],
        session: aiohttp.ClientSession,
        max_retries: int = 4,
        backoff_factor: float = 0.5):
    """
    POST submissions to the judge server as one batch, retrying unfinished ones.

    Each attempt sends ``{"type": "batch", "submissions": [...]}`` containing every
    submission that still lacks a result. A result whose ``reason`` is
    ``'queue_timeout'`` (or a submission the server did not answer at all) is retried
    on the next attempt; transport/HTTP errors retry the whole pending set.

    Args:
        url: Batch endpoint of the judge server.
        submissions: Submission payloads. The caller's list is not modified.
        session: aiohttp session used for the POST requests.
        max_retries: Maximum number of POST attempts.
        backoff_factor: After failed attempt ``k`` (1-based) the function sleeps
            ``backoff_factor * 2 ** (k - 1)`` seconds, but only if another attempt follows.

    Returns:
        One entry per input submission, in input order: the server's result dict, or
        ``None`` if none was obtained within ``max_retries`` attempts. Submissions left
        without a result are also written to a JSON file in the failed-submissions
        directory.
    """
    results = [None] * len(submissions)
    # Submissions still waiting for a usable result, and their indices in the
    # caller's list. The two lists are always kept in lockstep.
    pending = list(submissions)
    pending_ids = list(range(len(submissions)))
    attempt_count = 0
    while pending and attempt_count < max_retries:
        attempt_count += 1
        try:
            data = {
                "type": "batch",
                "submissions": pending
            }
            async with session.post(url, json=data) as response:
                response.raise_for_status()
                response_json = await response.json()
            returned = response_json['results']
            retry_ids = []
            retry_subs = []
            # Results are positional with respect to *this* request, so pair them with
            # the current pending lists rather than indexing by original position.
            for sub_id, submission, result in zip(pending_ids, pending, returned):
                if result['reason'] != 'queue_timeout':
                    results[sub_id] = result
                else:
                    retry_ids.append(sub_id)
                    retry_subs.append(submission)
            # Submissions the server did not answer at all are retried as well.
            retry_ids.extend(pending_ids[len(returned):])
            retry_subs.extend(pending[len(returned):])
            pending_ids, pending = retry_ids, retry_subs
        except aiohttp.ClientResponseError as e:
            print(f"Attempt {attempt_count}: Server responded with {e.status}")
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            print(f"Attempt {attempt_count}: Caught {type(e).__name__}: {repr(e)}")
        except Exception as e:
            print(f"run_tool_calls_on_server_async Error: {e}")
            traceback.print_exc()
        # Back off only when another attempt is actually going to be made.
        if pending and attempt_count < max_retries:
            await asyncio.sleep(backoff_factor * (2 ** (attempt_count - 1)))

    # Save failed submissions to file if any remain after max retries
    if pending:
        now = datetime.datetime.now()
        timestamp = now.strftime("%Y%m%d_%H%M%S")
        # Microseconds in the file name keep batches that fail within the same second
        # (e.g. concurrent workers during a server outage) from overwriting each other.
        failed_file = os.path.join(
            _failed_submissions_path,
            f"failed_submissions_{now.strftime('%Y%m%d_%H%M%S_%f')}.json")

        failed_data = {
            "timestamp": timestamp,
            "url": url,
            "max_retries": max_retries,
            "failed_submissions": [
                {"original_index": sub_id, "submission": submission}
                for sub_id, submission in zip(pending_ids, pending)
            ],
        }

        try:
            with open(failed_file, 'w', encoding='utf-8') as f:
                json.dump(failed_data, f, indent=2, ensure_ascii=False)
            print(f"Saved {len(pending)} failed submissions to: {failed_file}")
        except Exception as e:
            print(f"Failed to save failed submissions: {e}")

    return results


def _format_tool_call_result(result: Dict) -> str:
    """
    Render one judge-server result dict as the text returned for a tool call.

    A run counts as successful when both ``run_success`` and ``success`` are truthy;
    otherwise the text starts with ``Tool call failure`` plus the server's ``reason``.
    Non-empty stdout/stderr follow, then the execution time (``cost``) in seconds.
    """
    if result['run_success'] and result['success']:
        output_parts = ['Tool call success']
    else:
        output_parts = ['Tool call failure', f'reason: {result["reason"]}']
    if result["stdout"]:
        output_parts.append(f'stdout: {result["stdout"]}')
    if result["stderr"]:
        output_parts.append(f'stderr: {result["stderr"]}')
    output_parts.append(f'execution time: {result["cost"]:.2f}s')
    return '\n'.join(output_parts)


async def run_tool_calls_on_server_async(
        tool_calls: List,
        session: aiohttp.ClientSession,
        language: Literal["python", "cpp"] = "python",
        max_retries: int = 4,
        backoff_factor: float = 0.5,
        generate_tool_call_code: Optional[Callable] = None,
        generate_tool_call_input: Optional[Callable] = None,
        host_addr: str = "localhost",
        host_port: str = "8088"):
    """
    Execute tool calls on the code-judge server and return their textual outputs.

    Args:
        tool_calls: Tool calls to execute; each becomes one submission.
        session: aiohttp session used for the HTTP requests.
        language: Value sent as each submission's ``type``.
        max_retries: Forwarded to call_long_batch.
        backoff_factor: Forwarded to call_long_batch.
        generate_tool_call_code: Maps a tool call to the source code to run. Must be
            provided despite the ``None`` default (it is called unconditionally).
        generate_tool_call_input: Maps a tool call to its stdin input (may return None).
            Must be provided despite the ``None`` default.
        host_addr: Judge server host.
        host_port: Judge server port; requests go to
            ``http://{host_addr}:{host_port}/run/long-batch``.

    Returns:
        One formatted output string per tool call, in input order.

    Raises:
        RuntimeError: if any tool call still has no result after all retries.
    """
    submissions = [
        {
            "type": language,
            "solution": generate_tool_call_code(tool_call),
            "input": generate_tool_call_input(tool_call),
        }
        for tool_call in tool_calls
    ]

    url = f"http://{host_addr}:{host_port}/run/long-batch"
    results = await call_long_batch(url, submissions, session, max_retries, backoff_factor)

    failed_count = sum(1 for result in results if result is None)
    if failed_count:
        # throw an error if any tool call failed after max retries
        raise RuntimeError(f"run_tool_calls_on_server_async failed for {failed_count} tool calls after {max_retries} attempts.")

    return [_format_tool_call_result(result) for result in results]


### Generate tool call code

# Python source prepended to every submission. It defines a PersistentExecutor whose
# execute_code() runs a snippet in one shared globals dict, prints the repr() of a
# trailing expression (unless it is None), and on error prints the message to stderr
# and ends the process with os._exit(1). History snippets are replayed with
# replay_history_code=True, which captures their output and ignores their errors.
# NOTE(review): if evaluating the trailing expression raises, execute_code re-runs it
# as a statement, so side effects of that expression can happen twice.
code_template_setup = '''
import os
import base64
import sys
import ast
import traceback
from typing import Optional, Any
import linecache
from types import CodeType
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO

class CodeExecutionError(Exception):
    """Custom exception for code execution errors with line information"""
    def __init__(self, original_error: Exception, code: str, line_offset: int = 0):
        self.original_error = original_error
        self.code = code
        self.line_offset = line_offset

        # Get error line number
        if hasattr(original_error, 'lineno'):
            self.lineno = original_error.lineno
        else:
            tb = getattr(original_error, '__traceback__', None)
            if tb:
                while tb.tb_next:
                    tb = tb.tb_next
                self.lineno = tb.tb_lineno
            else:
                self.lineno = -1

        # Adjust line number for code segment
        if self.lineno != -1:
            self.lineno += line_offset

        # Format error message
        error_type = type(original_error).__name__
        error_msg = str(original_error)

        if self.lineno != -1:
            # Get the problematic line
            lines = code.splitlines()
            if 0 <= self.lineno - 1 < len(lines):
                error_line = lines[self.lineno - 1]
                # Create error message with line information
                super().__init__(f"{error_type} at line {self.lineno}: {error_msg}\\n {error_line}")
                return

        super().__init__(f"{error_type}: {error_msg}")

class PersistentExecutor:
    def __init__(self):
        self.exec_globals = {
            '__name__': '__main__',
            '__file__': '',
            '__builtins__': __builtins__
        }

    def split_code(self, code: str) -> tuple[str, Optional[str]]:
        """
        Intelligently split code into main body and last expression

        Args:
            code: The source code string

        Returns:
            tuple[str, Optional[str]]: (main code body, last expression if exists)
        """
        try:
            # Parse code into AST
            tree = ast.parse(code)
            if not tree.body:
                return code, None

            # Check if the last node is a pure expression (not a call)
            last_node = tree.body[-1]
            if isinstance(last_node, ast.Expr):
                # Get the line range of the last expression
                last_expr_start = last_node.lineno
                last_expr_end = last_node.end_lineno if hasattr(last_node, 'end_lineno') else last_node.lineno

                # Split the code
                lines = code.splitlines()
                main_code = '\\n'.join(lines[:last_expr_start-1])
                last_expr = '\\n'.join(lines[last_expr_start-1:last_expr_end])
                return main_code, last_expr
        except SyntaxError as e:
            raise CodeExecutionError(e, code)
        return code, None

    def execute_code(self, code: str, replay_history_code: bool) -> None:
        """
        Execute code while maintaining persistent environment state.
        If the last line is an expression, its value will be printed to stdout.

        Args:
            code: The source code string to execute
            replay_history_code: If True, suppress stdout and stderr output
        """
        try:
            # Split code intelligently
            main_code, last_expr = self.split_code(code)

            # Set up output redirection if replay_history_code is True
            if replay_history_code:
                stdout_capture = StringIO()
                stderr_capture = StringIO()
                stdout_context = redirect_stdout(stdout_capture)
                stderr_context = redirect_stderr(stderr_capture)
            else:
                stdout_context = redirect_stdout(sys.stdout)
                stderr_context = redirect_stderr(sys.stderr)

            # Execute main code body
            if main_code:
                try:
                    # Compile code to get better error line numbers
                    compiled_code = compile(main_code, '', 'exec')
                    with stdout_context, stderr_context:
                        exec(compiled_code, self.exec_globals)
                except Exception as e:
                    raise CodeExecutionError(e, main_code)

            # If there's a last expression, try to evaluate and print it
            if last_expr:
                try:
                    # Compile expression to get better error line numbers
                    compiled_expr = compile(last_expr, '', 'eval')
                    with stdout_context, stderr_context:
                        last_value = eval(compiled_expr, self.exec_globals)

                    # Only print the result if not in replay mode
                    if last_value is not None and not replay_history_code:
                        print(repr(last_value), file=sys.stdout)
                except Exception as e:
                    # Try executing as statement if evaluation fails
                    try:
                        compiled_stmt = compile(last_expr, '', 'exec')
                        with stdout_context, stderr_context:
                            exec(compiled_stmt, self.exec_globals)
                    except Exception as e:
                        # Calculate line offset for the last expression
                        line_offset = len(main_code.splitlines()) if main_code else 0
                        raise CodeExecutionError(e, last_expr, line_offset)

        except Exception as e:
            if replay_history_code:
                return
            if isinstance(e, CodeExecutionError):
                print(str(e), file=sys.stderr)
            else:
                traceback.print_exc(file=sys.stderr)
            os._exit(1)
        return

persistent_executor = PersistentExecutor()
'''

# One execution step appended after code_template_setup: the first placeholder is the
# base64-encoded snippet, the second is "True"/"False" for replay_history_code.
code_template_exec = '''
code_to_execute = base64.b64decode("{}".encode()).decode()
persistent_executor.execute_code(code_to_execute, replay_history_code={})
'''


def combine_code_template(code_to_execute: str, history_code_to_execute: Optional[List[str]] = None) -> str:
    """
    Build the full script sent to the judge server.

    Args:
        code_to_execute: Base64-encoded snippet whose output should be reported.
        history_code_to_execute: Base64-encoded earlier snippets, replayed first and in
            order with ``replay_history_code=True`` to rebuild interpreter state.

    Returns:
        ``code_template_setup`` followed by one ``code_template_exec`` step per snippet.
    """
    history_code_to_execute = history_code_to_execute or []
    final_code = code_template_setup
    for history_code in history_code_to_execute:
        final_code += code_template_exec.format(history_code, "True")
    final_code += code_template_exec.format(code_to_execute, "False")
    return final_code


def generate_tool_call_code(tool_call: Dict) -> str:
    """
    Turn a parsed tool call into the Python script the judge server should run.

    ``jupyter_code`` calls replay the code of earlier ``jupyter_code`` calls listed in
    ``history_tool_calls`` before running the new code; ``python_code_with_standard_io``
    calls run only their own code.

    Raises:
        ValueError: for any other tool name.
    """
    import base64

    def jupyter_code_gencode(json_format_data: Dict) -> str:
        code_to_execute = base64.b64encode(json_format_data["arguments"]["code"].encode()).decode()
        history_code_to_execute = [
            base64.b64encode(tool_call_json["arguments"]["code"].encode()).decode()
            for tool_call_json in json_format_data.get("history_tool_calls", []) if tool_call_json["name"] == "jupyter_code"
        ]
        return combine_code_template(code_to_execute, history_code_to_execute)

    def python_code_with_standard_io_gencode(json_format_data: Dict) -> str:
        code_to_execute = base64.b64encode(json_format_data["arguments"]["code"].encode()).decode()
        return combine_code_template(code_to_execute)

    if tool_call["name"] == "jupyter_code":
        return jupyter_code_gencode(tool_call)
    elif tool_call["name"] == "python_code_with_standard_io":
        return python_code_with_standard_io_gencode(tool_call)
    else:
        raise ValueError(f"Unsupported tool call name: {tool_call['name']}")


def generate_tool_call_input(tool_call: Dict) -> Optional[str]:
    """
    Return the stdin input for a tool call: ``None`` for ``jupyter_code``, the
    ``input`` argument for ``python_code_with_standard_io``.

    Raises:
        ValueError: for any other tool name.
    """
    if tool_call["name"] == "jupyter_code":
        return None
    elif tool_call["name"] == "python_code_with_standard_io":
        return tool_call["arguments"]["input"]
    else:
        raise ValueError(f"Unsupported tool call name: {tool_call['name']}")

# --------------------------------------------------------------------------------
# /rstar2_agent/tools/request_processor.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import time
import uuid
import collections
import aiohttp
import traceback
from typing import List, Dict, Any, Callable, Awaitable

# Define the expected signature for the batch submission function
# It should be an async callable that takes:
# 1. A list of original request payloads (List[Any])
# 2. The aiohttp.ClientSession instance
# It should return:
# 1. A list of results (List[Any]) where the order *strictly* matches the input payloads order.
BatchSubmitFunc = Callable[[List[Any], aiohttp.ClientSession], Awaitable[List[Any]]]


class RequestProcessor:
    """
    Manages batch submission concurrently using an injected batch submission function.

    Callers submit single payloads with ``send_request``/``send_requests``. Payloads are
    buffered in a queue, grouped into batches of at most ``batch_size`` by
    ``concurrency`` sender workers, and passed to ``batch_submit_func``; each caller
    receives the result at its payload's position in the batch.
    """

    def __init__(self, batch_size: int, batch_timeout_seconds: float, session: aiohttp.ClientSession, concurrency: int, batch_submit_func: BatchSubmitFunc):
        """
        Initializes the Request Processor with concurrent sending and a generic submission function.
        Must be called within an event loop context.

        Args:
            batch_size: Maximum items per batch.
            batch_timeout_seconds: How long an idle sender worker waits for the first item
                of a new batch before re-checking whether the processor is still running.
                Once a first item arrives, the batch is filled only with items that are
                already queued. ``stop()`` allows busy workers five times this long to
                finish their current batch.
            session: The aiohttp.ClientSession to pass to the submission function.
            concurrency: Number of sender workers, i.e. the maximum number of batches
                being submitted at the same time.
            batch_submit_func: The async function to call for sending a batch.
                Must match the BatchSubmitFunc signature.

        Raises:
            ValueError: if batch_size or concurrency is not positive.
        """
        if batch_size <= 0 or concurrency <= 0:
            raise ValueError("batch_size and concurrency must be positive")
        if batch_timeout_seconds <= 0:
            print("Warning: batch_timeout_seconds <= 0, batching will be strictly based on batch_size or queue availability.")

        self._batch_size = batch_size
        self._batch_timeout_seconds = batch_timeout_seconds
        self._session = session
        self._concurrency = concurrency
        self._batch_submit_func = batch_submit_func  # Store the injected function

        self._submission_queue = asyncio.Queue()
        self._pending_requests: Dict[str, Dict[str, Any]] = {}  # {request_id: {"future": Future, "payload": payload}}

        self._semaphore = asyncio.Semaphore(concurrency)

        self._sender_workers: List[asyncio.Task] = []
        # Sender tasks currently submitting a batch; stop() waits for these instead of
        # cancelling them straight away.
        self._busy_workers = set()
        self._running = False

        # --- Statistics ---
        self._reset_stats_internal()  # Initialize stats
        # --- End Statistics ---

        print(f"[{time.monotonic():.4f}] RequestProcessor initialized with concurrency={self._concurrency} (async thread).")

    def _reset_stats_internal(self):
        """Helper to initialize or reset statistics."""
        self._stats = {
            "total_batch_submission_duration_seconds": 0.0,
            "num_batches_submitted": 0,
            "num_successful_batches": 0,
            "num_failed_batches": 0,
            "total_items_processed_in_batches": 0,
            "actual_batch_sizes": []  # Stores the size of each batch
        }
        # Submission time of successful batches only; kept outside _stats so the set
        # of reported keys is unchanged. Feeds the per-successful-batch average.
        self._successful_batch_duration_seconds = 0.0

    async def send_request(self, request_payload: Any, timeout: float = None):
        """
        Adds a single request to the buffer and waits for its result.
        This call is awaitable and provides the synchronous-like pattern.

        Args:
            request_payload: Payload handed (within a batch) to batch_submit_func.
            timeout: Seconds to wait for the result; ``None`` waits indefinitely.

        Returns:
            The result produced for this payload by batch_submit_func.

        Raises:
            RuntimeError: if the processor is not running, or it stopped before the
                request completed.
            asyncio.TimeoutError: if ``timeout`` elapses first.
            Exception: whatever batch_submit_func raised for this request's batch.
        """
        if not self._running:
            raise RuntimeError("RequestProcessor is not running. Call .start() first.")

        request_id = str(uuid.uuid4())

        future = asyncio.get_running_loop().create_future()
        self._pending_requests[request_id] = {
            "future": future,
            "payload": request_payload
        }

        try:
            await self._submission_queue.put(request_id)
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            print(f"[{time.monotonic():.4f}] Request {request_id[:6]}... timed out waiting for result.")
            raise
        except Exception as e:
            print(f"[{time.monotonic():.4f}] Request {request_id[:6]}... encountered error while waiting: {e}")
            raise
        finally:
            # Also runs on cancellation (CancelledError is not an Exception on Python
            # 3.8+), so an abandoned request never lingers in _pending_requests.
            self._pending_requests.pop(request_id, None)

    async def send_requests(self, request_payloads: List[Any], timeout: float = None) -> List[Any]:
        """
        Submits multiple request payloads concurrently and waits for all their results.

        Returns:
            One entry per payload, in input order: the result, or the exception raised
            for that request (gathered with ``return_exceptions=True``).

        Raises:
            RuntimeError: if the processor is not running.
        """
        if not self._running:
            raise RuntimeError("RequestProcessor is not running. Call .start() first.")

        if not request_payloads:
            return []

        tasks = [asyncio.create_task(self.send_request(payload, timeout=timeout)) for payload in request_payloads]
        return await asyncio.gather(*tasks, return_exceptions=True)

    async def start(self):
        """
        Starts the concurrent sender worker tasks. Must be called within loop.
        """
        if self._running:
            print(f"[{time.monotonic():.4f}] RequestProcessor is already running.")
            return
        self._running = True
        self._sender_workers = [asyncio.create_task(self._sender_worker()) for _ in range(self._concurrency)]
        print(f"[{time.monotonic():.4f}] RequestProcessor started {self._concurrency} sender workers.")

    async def stop(self):
        """
        Stop accepting requests, drain the queue, and shut the sender workers down.

        Idle workers are cancelled once the queue is drained. Workers that are in the
        middle of a batch get ``5 * batch_timeout_seconds`` (at least 0) to finish it
        before being cancelled. Requests still unresolved afterwards are failed with
        ``RuntimeError`` so their callers do not wait forever.
        """
        if not self._running:
            print(f"[{time.monotonic():.4f}] RequestProcessor is not running.")
            return

        print(f"[{time.monotonic():.4f}] Stopping RequestProcessor. Signaling workers...")
        self._running = False

        # Workers call task_done() as soon as they dequeue an item, so this returns once
        # every buffered request has been picked up (not necessarily answered yet).
        await self._submission_queue.join()
        print(f"[{time.monotonic():.4f}] Submission queue joined. All buffered items picked up by sender workers.")

        live_workers = [worker for worker in self._sender_workers if not worker.done()]
        busy_workers = [worker for worker in live_workers if worker in self._busy_workers]
        # Idle workers are only waiting for new items, which can no longer arrive.
        for worker in live_workers:
            if worker not in self._busy_workers:
                worker.cancel()

        if busy_workers:
            grace_seconds = max(self._batch_timeout_seconds * 5, 0)
            print(f"[{time.monotonic():.4f}] Waiting up to {grace_seconds:.2f}s for {len(busy_workers)} sender workers to finish their current batch...")
            _, still_running = await asyncio.wait(busy_workers, timeout=grace_seconds)
            if still_running:
                print(f"[{time.monotonic():.4f}] Warning: Cancelling {len(still_running)} sender workers with batches still in flight.")
            for worker in still_running:
                worker.cancel()

        # return_exceptions=True: a worker that failed or was cancelled must not abort stop().
        await asyncio.gather(*self._sender_workers, return_exceptions=True)
        self._sender_workers = []
        print(f"[{time.monotonic():.4f}] All sender workers stopped.")

        if self._pending_requests:
            print(f"[{time.monotonic():.4f}] Warning: Stopping with {len(self._pending_requests)} requests still pending; failing them.")
            self._fail_requests(
                list(self._pending_requests),
                RuntimeError("RequestProcessor stopped before the request completed."),
            )

        print(f"[{time.monotonic():.4f}] RequestProcessor stopped.")

    async def _sender_worker(self):
        """
        Sender loop: repeatedly take a batch of request ids off the queue and submit it.

        The worker waits up to ``batch_timeout_seconds`` for a first id, then adds ids
        that are already queued (up to ``batch_size``) without waiting further. It keeps
        going while the processor is running or the queue still holds items. While it
        submits a batch it is listed in ``_busy_workers``.
        """
        print(f"[{time.monotonic():.4f}] Sender worker started.")
        batch_gathering_timeout = self._batch_timeout_seconds
        this_task = asyncio.current_task()

        try:
            while self._running or not self._submission_queue.empty():
                try:
                    first_item_id = await asyncio.wait_for(self._submission_queue.get(), timeout=batch_gathering_timeout)
                except asyncio.TimeoutError:
                    # Idle: nothing arrived in time; loop to re-check the running flag.
                    continue
                except Exception as e:
                    print(f"[{time.monotonic():.4f}] Worker encountered error getting items: {e}")
                    await asyncio.sleep(1.0)
                    continue

                # No await happens between the get() above and registering as busy below,
                # so stop() cannot observe this worker as idle while it holds items.
                self._submission_queue.task_done()
                batch_item_ids = [first_item_id]
                while len(batch_item_ids) < self._batch_size:
                    try:
                        next_item_id = self._submission_queue.get_nowait()
                    except asyncio.QueueEmpty:
                        break
                    self._submission_queue.task_done()
                    batch_item_ids.append(next_item_id)

                self._busy_workers.add(this_task)
                try:
                    # Acquire semaphore permit before starting the potentially long-running batch submission
                    async with self._semaphore:
                        await self._perform_send_batch(batch_item_ids)
                finally:
                    self._busy_workers.discard(this_task)

        except asyncio.CancelledError:
            print(f"[{time.monotonic():.4f}] Sender worker received cancellation signal.")
        except Exception as e:
            print(f"[{time.monotonic():.4f}] Sender worker encountered major error: {e}")

        print(f"[{time.monotonic():.4f}] Sender worker finished.")

    async def _perform_send_batch(self, batch_item_ids: List[str]):
        """
        Submit one batch through the injected function and resolve its requests.

        Ids that are no longer pending (their caller timed out or went away while
        queued) are skipped. On success each request's future gets the result at its
        position; requests the function returned no result for, and every request of a
        failed submission, get an exception so no caller is left waiting.
        Assumes this method is called within the context of an acquired semaphore permit.
        """
        # Snapshot the still-pending requests, preserving queue order.
        batch_info = [
            {"request_id": req_id, "payload": self._pending_requests[req_id]["payload"]}
            for req_id in batch_item_ids
            if req_id in self._pending_requests
        ]
        if not batch_info:
            return  # Nothing valid to send
        payloads_for_server = [item["payload"] for item in batch_info]

        self._stats["num_batches_submitted"] += 1
        self._stats["actual_batch_sizes"].append(len(payloads_for_server))

        start_time = time.monotonic()
        try:
            # It must return results in the same order as payloads_for_server
            results_list = list(await self._batch_submit_func(payloads_for_server, self._session))
        except Exception as e:
            submission_duration = time.monotonic() - start_time
            self._stats["total_batch_submission_duration_seconds"] += submission_duration  # Still record time even on failure
            self._stats["num_failed_batches"] += 1
            # print error stack trace for debugging
            traceback.print_exc()
            print(f"[{time.monotonic():.4f}] Error calling or processing results from injected function for batch: {e}")
            # To avoid silently failing, propagate the exception to every caller of this batch
            self._fail_requests([item["request_id"] for item in batch_info], e)
            return

        submission_duration = time.monotonic() - start_time
        self._stats["total_batch_submission_duration_seconds"] += submission_duration
        self._successful_batch_duration_seconds += submission_duration
        self._stats["num_successful_batches"] += 1
        self._stats["total_items_processed_in_batches"] += len(payloads_for_server)

        if len(results_list) != len(batch_info):
            print(f"[{time.monotonic():.4f}] Warning: Injected function returned {len(results_list)} results, but batch had {len(batch_info)} items. Cannot reliably match results.")

        # zip stops at the shorter list: surplus results are ignored, missing ones are
        # failed below.
        for item, result in zip(batch_info, results_list):
            req_id = item["request_id"]
            req_info = self._pending_requests.pop(req_id, None)
            if req_info is None:
                print(f"[{time.monotonic():.4f}] Warning: Received result for unknown or already completed request ID {req_id[:6]}... Result: {result}")
            elif not req_info["future"].done():
                req_info["future"].set_result(result)

        unanswered = [item["request_id"] for item in batch_info[len(results_list):]]
        if unanswered:
            self._fail_requests(unanswered, RuntimeError(
                f"Batch submission function returned {len(results_list)} results for a batch of {len(batch_info)} requests."))

    def _fail_requests(self, request_ids: List[str], error: BaseException):
        """Fail the still-pending futures of *request_ids* with *error* and forget them."""
        for req_id in request_ids:
            req_info = self._pending_requests.pop(req_id, None)
            if req_info is not None and not req_info["future"].done():
                req_info["future"].set_exception(error)

    def get_stats(self) -> Dict[str, Any]:
        """
        Return a snapshot of the collected statistics plus derived averages.

        Adds ``avg_successful_batch_submission_duration_seconds`` (time spent in
        successful submissions / successful batches),
        ``avg_overall_batch_submission_duration_seconds`` (all submission time / batches
        submitted) and ``avg_actual_batch_size``; each is 0 when its denominator is 0.
        """
        stats_copy = self._stats.copy()
        # Copy the list as well so callers cannot mutate the live statistics.
        batch_sizes = list(self._stats["actual_batch_sizes"])
        stats_copy["actual_batch_sizes"] = batch_sizes

        num_successful = self._stats["num_successful_batches"]
        num_submitted = self._stats["num_batches_submitted"]
        total_duration = self._stats["total_batch_submission_duration_seconds"]

        stats_copy["avg_successful_batch_submission_duration_seconds"] = (
            self._successful_batch_duration_seconds / num_successful if num_successful > 0 else 0)
        stats_copy["avg_overall_batch_submission_duration_seconds"] = (
            total_duration / num_submitted if num_submitted > 0 else 0)
        stats_copy["avg_actual_batch_size"] = sum(batch_sizes) / len(batch_sizes) if batch_sizes else 0
        return stats_copy

    def print_stats(self):
        """Prints the collected performance statistics."""
        stats_to_print = self.get_stats()
        print(f"[{time.monotonic():.4f}] --- RequestProcessor Statistics ---")
        for key, value in stats_to_print.items():
            if key == "actual_batch_sizes":
                if value:  # Check if the list of batch sizes is not empty
                    batch_size_counts = collections.Counter(value)
                    # Format as a list of (batch_size, count) tuples, sorted by batch_size
                    formatted_batch_sizes = sorted(batch_size_counts.items())
                    print(f"  {key}: {formatted_batch_sizes}")
                else:
                    print(f"  {key}: []")  # Print an empty list if no batches were processed
            elif isinstance(value, float):
                print(f"  {key}: {value:.4f}")
            else:
                # Handles other data types including other lists (if any)
                print(f"  {key}: {value}")
        print(f"[{time.monotonic():.4f}] --- End Statistics ---")

    def reset_stats(self):
        """Resets all collected performance statistics to their initial values."""
        print(f"[{time.monotonic():.4f}] Resetting RequestProcessor statistics.")
        self._reset_stats_internal()

# --------------------------------------------------------------------------------
# /fused_compute_score/prime_math/grader.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Microsoft Corporation.
16 | # 17 | # Permission is hereby granted, free of charge, to any person obtaining a copy 18 | # of this software and associated documentation files (the "Software"), to deal 19 | # in the Software without restriction, including without limitation the rights 20 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 21 | # copies of the Software, and to permit persons to whom the Software is 22 | # furnished to do so, subject to the following conditions: 23 | # 24 | # The above copyright notice and this permission notice shall be included in all 25 | # copies or substantial portions of the Software. 26 | # 27 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 28 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 29 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 30 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 31 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 32 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 33 | # SOFTWARE 34 | 35 | # Copyright (c) 2023 OpenAI 36 | # 37 | # Permission is hereby granted, free of charge, to any person obtaining a copy 38 | # of this software and associated documentation files (the "Software"), to deal 39 | # in the Software without restriction, including without limitation the rights 40 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 41 | # copies of the Software, and to permit persons to whom the Software is 42 | # furnished to do so, subject to the following conditions: 43 | 44 | # The above copyright notice and this permission notice shall be included in all 45 | # copies or substantial portions of the Software. 
46 | # 47 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 48 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 49 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 50 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 51 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 52 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 53 | # SOFTWARE. 54 | 55 | # Copyright (c) 2021 Dan Hendrycks 56 | # 57 | # Permission is hereby granted, free of charge, to any person obtaining a copy 58 | # of this software and associated documentation files (the "Software"), to deal 59 | # in the Software without restriction, including without limitation the rights 60 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 61 | # copies of the Software, and to permit persons to whom the Software is 62 | # furnished to do so, subject to the following conditions: 63 | # 64 | # The above copyright notice and this permission notice shall be included in all 65 | # copies or substantial portions of the Software. 66 | # 67 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 68 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 69 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 70 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 71 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 72 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 73 | # SOFTWARE. 74 | 75 | # Copyright 2024 PRIME team and/or its affiliates 76 | # 77 | # Licensed under the Apache License, Version 2.0 (the "License"); 78 | # you may not use this file except in compliance with the License. 
79 | # You may obtain a copy of the License at 80 | # 81 | # http://www.apache.org/licenses/LICENSE-2.0 82 | # 83 | # Unless required by applicable law or agreed to in writing, software 84 | # distributed under the License is distributed on an "AS IS" BASIS, 85 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 86 | # See the License for the specific language governing permissions and 87 | # limitations under the License. 88 | """ 89 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from: 90 | - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py 91 | - https://github.com/microsoft/ProphetNet/tree/master/CRITIC 92 | - https://github.com/openai/prm800k 93 | """ 94 | 95 | import contextlib 96 | import math 97 | import re 98 | from math import isclose 99 | 100 | # sympy related 101 | from sympy import N, simplify 102 | from sympy.parsing.latex import parse_latex 103 | from sympy.parsing.sympy_parser import parse_expr 104 | 105 | 106 | def is_digit(s): 107 | try: 108 | if "{,}" in str(s): 109 | num = float(str(s).replace("{,}", "")) 110 | return True, num 111 | 112 | num = float(str(s).replace(",", "")) 113 | return True, num 114 | except ValueError: 115 | return False, None 116 | 117 | 118 | def normalize(answer, pi) -> str: 119 | # checking if answer is $ and removing $ in that case to compare 120 | if isinstance(answer, str) and bool(re.match(r"\$\d+(\.\d+)?", answer)): 121 | return answer[1:] 122 | 123 | # checking if answer is % or \\% and removing % 124 | if isinstance(answer, str) and ( 125 | bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer)) 126 | ): 127 | return answer.replace("\\%", "").replace("%", "") 128 | 129 | # handle base 130 | answer = handle_base(answer) 131 | 132 | # handle pi 133 | answer = handle_pi(answer, pi) 134 | 135 | return answer 136 | 137 | 138 | def handle_base(x) -> str: 139 | if isinstance(x, str) and "_" in x: 140 | 
# Due to base 141 | x = x.split("_")[0] 142 | x = float(x) 143 | return int(x) 144 | return x 145 | 146 | 147 | def handle_pi(string, pi): 148 | if isinstance(string, str) and "\pi" in string: 149 | # Find the first occurrence of "\pi" 150 | idx = string.find("\pi") 151 | 152 | # Iterate over the string and find all occurrences of "\pi" with a valid previous character 153 | while idx != -1: 154 | if idx > 0 and string[idx - 1].isdigit(): 155 | # Replace "\pi" with "*math.pi" if the previous character is a digit 156 | string = string[:idx] + f"*{pi}" + string[idx + 3 :] 157 | else: 158 | # Replace "\pi" with "1*math.pi" if the previous character is not a digit 159 | string = string[:idx] + f"1*{pi}" + string[idx + 3 :] 160 | 161 | # Find the next occurrence of "\pi" 162 | idx = string.find("\pi", idx + 1) 163 | 164 | # Evaluate the expression using eval() function 165 | with contextlib.suppress(Exception): 166 | string = eval(string) 167 | 168 | return string 169 | 170 | 171 | def math_equal( 172 | prediction: bool | float | str, 173 | reference: float | str, 174 | include_percentage: bool = True, 175 | tolerance: float = 1e-4, 176 | timeout: float = 10.0, 177 | pi: float = math.pi, 178 | ) -> bool: 179 | """ 180 | Exact match of math if and only if: 181 | 1. numerical equal: both can convert to float and are equal 182 | 2. symbolic equal: both can convert to sympy expression and are equal 183 | """ 184 | 185 | prediction = normalize(prediction, pi) 186 | reference = normalize(reference, pi) 187 | 188 | if isinstance(prediction, str) and len(prediction) > 1000: # handling weird corner-cases 189 | prediction = prediction[:1000] 190 | 191 | # 0. string comparison 192 | if isinstance(prediction, str) and isinstance(reference, str): 193 | if prediction.strip().lower() == reference.strip().lower(): 194 | return True 195 | if prediction.replace(" ", "") == reference.replace(" ", ""): 196 | return True 197 | 198 | try: # 1. 
numerical equal 199 | if is_digit(prediction)[0] and is_digit(reference)[0]: 200 | prediction = is_digit(prediction)[1] 201 | reference = is_digit(reference)[1] 202 | # number questions 203 | gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] 204 | for item in gt_result: 205 | try: 206 | if isclose(item, prediction, rel_tol=tolerance): 207 | return True 208 | except Exception: 209 | continue 210 | return False 211 | except Exception: 212 | pass 213 | 214 | if not prediction and prediction not in [0, False]: 215 | return False 216 | 217 | # 2. symbolic equal 218 | reference = str(reference).strip() 219 | prediction = str(prediction).strip() 220 | 221 | ## deal with [], (), {} 222 | prediction = format_intervals(prediction) 223 | 224 | pred_str, ref_str = prediction, reference 225 | if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or ( 226 | prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[") 227 | ): 228 | pred_str = pred_str.strip("[]()") 229 | ref_str = ref_str.strip("[]()") 230 | for s in ["{", "}", "(", ")"]: 231 | ref_str = ref_str.replace(s, "") 232 | pred_str = pred_str.replace(s, "") 233 | if pred_str == ref_str: 234 | return True 235 | 236 | ## [a, b] vs. 
[c, d], return a==c and b==d 237 | if ( 238 | prediction 239 | and reference 240 | and prediction[0] in "([" 241 | and prediction[-1] in ")]" 242 | and prediction[0] == reference[0] 243 | and prediction[-1] == reference[-1] 244 | ): 245 | pred_parts = prediction[1:-1].split(",") 246 | ref_parts = reference[1:-1].split(",") 247 | if len(pred_parts) == len(ref_parts) and all( 248 | [ 249 | math_equal(pred_pt, ref_pt, include_percentage, tolerance) 250 | for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True) 251 | ] 252 | ): 253 | return True 254 | 255 | if "," in prediction and "," in reference: 256 | pred_parts = [item.strip() for item in prediction.split(",")] 257 | ref_parts = [item.strip() for item in reference.split(",")] 258 | 259 | if len(pred_parts) == len(ref_parts): 260 | return bool( 261 | all( 262 | [ 263 | math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) 264 | for i in range(len(pred_parts)) 265 | ] 266 | ) 267 | ) 268 | 269 | # if we have point == tuple of values 270 | if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")": 271 | pred_parts = prediction[prediction.find("(") + 1 : -1].split(",") 272 | ref_parts = reference[1:-1].split(",") 273 | if len(pred_parts) == len(ref_parts) and all( 274 | [ 275 | math_equal(pred_pt, ref_pt, include_percentage, tolerance) 276 | for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=False) 277 | ] 278 | ): 279 | return True 280 | 281 | # if reference is a matrix 282 | if "\begin{pmatrix}" in reference and prediction.startswith("Matrix"): 283 | try: 284 | pred_matrix = parse_expr(prediction) 285 | ref_matrix_items = reference.split()[1:-1:2] 286 | if len(pred_matrix) == len(ref_matrix_items) and all( 287 | [ 288 | math_equal(pred, ref, include_percentage, tolerance) 289 | for ref, pred in zip(ref_matrix_items, pred_matrix, strict=False) 290 | ] 291 | ): 292 | return True 293 | except Exception: 294 | pass 295 | elif "\begin{pmatrix}" in reference and 
prediction.startswith("[") and prediction.endswith("]"): 296 | if isinstance(eval(prediction), list): 297 | try: 298 | pred_matrix = eval(prediction) 299 | # ref_matrix_items = reference.split()[1:-1:2] 300 | ref_matrix_items = ( 301 | reference.lstrip("\\begin{pmatrix}") # noqa: B005 302 | .lstrip("\begin{pmatrix}") 303 | .rstrip("\\end{pmatrix}") 304 | .rstrip("\end{pmatrix}") 305 | ) # noqa: B005 306 | ref_matrix_items = ref_matrix_items.split("\\") 307 | ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items] 308 | if len(pred_matrix) == len(ref_matrix_items) and all( 309 | [ 310 | math_equal(pred, ref, include_percentage, tolerance) 311 | for ref, pred in zip(ref_matrix_items, pred_matrix, strict=False) 312 | ] 313 | ): 314 | return True 315 | except Exception: 316 | pass 317 | 318 | return symbolic_equal(prediction, reference, tolerance, timeout) 319 | 320 | 321 | def symbolic_equal(a, b, tolerance, timeout=10.0): 322 | def _parse(s): 323 | for f in [parse_expr, parse_latex]: 324 | try: 325 | with timeout_limit(seconds=timeout): 326 | return f(s) 327 | except TimeoutError: 328 | print(f"Parsing timed out for {s}") 329 | continue 330 | except Exception: 331 | continue 332 | return s 333 | 334 | a = _parse(a) 335 | b = _parse(b) 336 | 337 | try: 338 | with timeout_limit(seconds=timeout): 339 | if simplify(a - b) == 0: 340 | return True 341 | except TimeoutError: 342 | print(f"Simplification timed out for {a} - {b}") 343 | pass 344 | except Exception: 345 | pass 346 | 347 | try: 348 | with timeout_limit(seconds=timeout): 349 | if isclose(N(a), N(b), rel_tol=tolerance): 350 | return True 351 | except TimeoutError: 352 | print(f"Numerical evaluation timed out for {a}, {b}") 353 | pass 354 | except Exception: 355 | pass 356 | return False 357 | 358 | 359 | def format_intervals(prediction): 360 | patterns = { 361 | "Interval(": r"^Interval\((.*)\)$", 362 | "Interval.Ropen(": r"^Interval\.Ropen\((.*)\)$", 363 | "Interval.Lopen(": 
r"^Interval\.Lopen\((.*)\)$", 364 | "Interval.open(": r"^Interval\.open\((.*)\)$", 365 | } 366 | 367 | for key, pattern in patterns.items(): 368 | match = re.match(pattern, prediction) 369 | if match: 370 | inner_content = match.group(1) 371 | 372 | if key == "Interval(": # Intarval(a, b) == [a, b] 373 | return f"[{inner_content}]" 374 | elif key == "Interval.Ropen(": # Intarval.Ropen(a, b) == [a, b) 375 | return f"[{inner_content})" 376 | elif key == "Interval.Lopen(": # Intarval.Lopen(a, b) == (a, b] 377 | return f"({inner_content}]" 378 | elif key == "Interval.open(": # Intarval.open(a, b) == (a, b) 379 | return f"({inner_content})" 380 | 381 | return prediction 382 | 383 | 384 | import os 385 | import signal 386 | import queue 387 | import multiprocessing 388 | from functools import wraps 389 | from typing import Callable, Any 390 | 391 | 392 | def _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, args: tuple, kwargs: dict[str, Any]): 393 | """ 394 | Internal wrapper function executed in the child process. 395 | Calls the original target function and puts the result or exception into the queue. 396 | """ 397 | try: 398 | result = target_func(*args, **kwargs) 399 | mp_queue.put((True, result)) # Indicate success and put result 400 | except Exception as e: 401 | # Ensure the exception is pickleable for the queue 402 | try: 403 | import pickle 404 | 405 | pickle.dumps(e) # Test if the exception is pickleable 406 | mp_queue.put((False, e)) # Indicate failure and put exception 407 | except (pickle.PicklingError, TypeError): 408 | # Fallback if the original exception cannot be pickled 409 | mp_queue.put((False, RuntimeError(f"Original exception type {type(e).__name__} not pickleable: {e}"))) 410 | 411 | 412 | def timeout_limit(seconds: float, use_signals: bool = False): 413 | """ 414 | Decorator to add a timeout to a function. 415 | 416 | Args: 417 | seconds: The timeout duration in seconds. 
418 | use_signals: (Deprecated) This is deprecated because signals only work reliably in the main thread 419 | and can cause issues in multiprocessing or multithreading contexts. 420 | Defaults to False, which uses the more robust multiprocessing approach. 421 | 422 | Returns: 423 | A decorated function with timeout. 424 | 425 | Raises: 426 | TimeoutError: If the function execution exceeds the specified time. 427 | RuntimeError: If the child process exits with an error (multiprocessing mode). 428 | NotImplementedError: If the OS is not POSIX (signals are only supported on POSIX). 429 | """ 430 | 431 | def decorator(func): 432 | if use_signals: 433 | if os.name != "posix": 434 | raise NotImplementedError(f"Unsupported OS: {os.name}") 435 | # Issue deprecation warning if use_signals is explicitly True 436 | print( 437 | "WARN: The 'use_signals=True' option in the timeout decorator is deprecated. \ 438 | Signals are unreliable outside the main thread. \ 439 | Please use the default multiprocessing-based timeout (use_signals=False)." 
440 | ) 441 | 442 | @wraps(func) 443 | def wrapper_signal(*args, **kwargs): 444 | def handler(signum, frame): 445 | # Update function name in error message if needed (optional but good practice) 446 | raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds (signal)!") 447 | 448 | old_handler = signal.getsignal(signal.SIGALRM) 449 | signal.signal(signal.SIGALRM, handler) 450 | # Use setitimer for float seconds support, alarm only supports integers 451 | signal.setitimer(signal.ITIMER_REAL, seconds) 452 | 453 | try: 454 | result = func(*args, **kwargs) 455 | finally: 456 | # Reset timer and handler 457 | signal.setitimer(signal.ITIMER_REAL, 0) 458 | signal.signal(signal.SIGALRM, old_handler) 459 | return result 460 | 461 | return wrapper_signal 462 | else: 463 | # --- Multiprocessing based timeout (existing logic) --- 464 | @wraps(func) 465 | def wrapper_mp(*args, **kwargs): 466 | q = multiprocessing.Queue(maxsize=1) 467 | process = multiprocessing.Process(target=_mp_target_wrapper, args=(func, q, args, kwargs)) 468 | process.start() 469 | process.join(timeout=seconds) 470 | 471 | if process.is_alive(): 472 | process.terminate() 473 | process.join(timeout=0.5) # Give it a moment to terminate 474 | if process.is_alive(): 475 | print(f"Warning: Process {process.pid} did not terminate gracefully after timeout.") 476 | # Update function name in error message if needed (optional but good practice) 477 | raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds (multiprocessing)!") 478 | 479 | try: 480 | success, result_or_exc = q.get(timeout=0.1) # Small timeout for queue read 481 | if success: 482 | return result_or_exc 483 | else: 484 | raise result_or_exc # Reraise exception from child 485 | except queue.Empty as err: 486 | exitcode = process.exitcode 487 | if exitcode is not None and exitcode != 0: 488 | raise RuntimeError( 489 | f"Child process exited with error (exitcode: {exitcode}) before returning result." 
490 | ) from err 491 | else: 492 | # Should have timed out if queue is empty after join unless process died unexpectedly 493 | # Update function name in error message if needed (optional but good practice) 494 | raise TimeoutError( 495 | f"Operation timed out or process finished unexpectedly without result " 496 | f"(exitcode: {exitcode})." 497 | ) from err 498 | finally: 499 | q.close() 500 | q.join_thread() 501 | 502 | return wrapper_mp 503 | 504 | return decorator 505 | -------------------------------------------------------------------------------- /rstar2_agent/rstar2_agent_ray_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import uuid 5 | from copy import deepcopy 6 | from pprint import pprint 7 | 8 | import numpy as np 9 | import ray 10 | import torch 11 | from tqdm import tqdm 12 | 13 | from verl import DataProto 14 | from verl.experimental.dataset.sampler import AbstractCurriculumSampler 15 | from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss 16 | from verl.trainer.ppo.metric_utils import ( 17 | compute_data_metrics, 18 | compute_throughout_metrics, 19 | compute_timing_metrics, 20 | ) 21 | from verl.trainer.ppo.ray_trainer import ( 22 | RayPPOTrainer, 23 | apply_kl_penalty, 24 | compute_advantage, 25 | compute_response_mask, 26 | ) 27 | from verl.trainer.ppo.reward import compute_reward, compute_reward_async 28 | from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi 29 | from verl.utils.debug import marked_timer 30 | from verl.utils.metric import reduce_metrics 31 | from verl.utils.rollout_skip import RolloutSkip 32 | 33 | from .down_sample import reject_equal_reward, resample_of_correct 34 | 35 | 36 | class RStar2AgentRayTrainer(RayPPOTrainer): 37 | def _down_sample_batch(self, batch: DataProto) -> DataProto: 38 | do_down_sampling = self.config.augmentation.do_down_sampling 39 | 
down_sampling_config = self.config.augmentation.down_sampling_config 40 | world_size = self.actor_rollout_wg.world_size 41 | metrics = {"down_sampling/before_sampling_trace_num": len(batch),} 42 | 43 | def check_batch_is_empty(batch: DataProto, down_sampling_stage: str): 44 | if batch is None or len(batch) == 0: 45 | print(f"Batch is empty after {down_sampling_stage}, skipping the training step.") 46 | return True 47 | return False 48 | 49 | # reject rollout trace of the same prompt with equal rewards 50 | do_reject_equal_reward = down_sampling_config.get("reject_equal_reward", False) and do_down_sampling 51 | batch, _metrics = reject_equal_reward(batch, do_reject_equal_reward, world_size) 52 | metrics.update(_metrics) 53 | if check_batch_is_empty(batch, "reject_equal_reward"): 54 | return None, metrics 55 | 56 | # weighted sampling 57 | config = { 58 | "roc_error_ratio": down_sampling_config.get("roc_error_ratio", False) and do_down_sampling, 59 | "roc_answer_format": down_sampling_config.get("roc_answer_format", False) and do_down_sampling, 60 | "min_zero_reward_trace_num": down_sampling_config.get("min_zero_reward_trace_num", -1), 61 | "min_non_zero_reward_trace_num": down_sampling_config.get("min_non_zero_reward_trace_num", -1), 62 | "down_sample_to_n": down_sampling_config.get("down_sample_to_n", -1), 63 | } 64 | batch, _metrics = resample_of_correct(batch, self.tokenizer, config, do_down_sampling, world_size=world_size) 65 | metrics.update(_metrics) 66 | if check_batch_is_empty(batch, "fused_weighted_sampling"): 67 | return None, metrics 68 | 69 | metrics["down_sampling/after_sampling_trace_num"] = len(batch) 70 | return batch, metrics 71 | 72 | def fit(self): 73 | """ 74 | The training loop of PPO. 75 | The driver process only need to call the compute functions of the worker group through RPC 76 | to construct the PPO dataflow. 77 | The light-weight advantage computation is done on the driver process. 
78 | 79 | Most logic is same with RayPPOTrainer, mainly add down sample related. 80 | """ 81 | from omegaconf import OmegaConf 82 | 83 | from verl.utils.tracking import Tracking 84 | 85 | logger = Tracking( 86 | project_name=self.config.trainer.project_name, 87 | experiment_name=self.config.trainer.experiment_name, 88 | default_backend=self.config.trainer.logger, 89 | config=OmegaConf.to_container(self.config, resolve=True), 90 | ) 91 | 92 | self.global_steps = 0 93 | 94 | # load checkpoint before doing anything 95 | self._load_checkpoint() 96 | 97 | # perform validation before training 98 | # currently, we only support validation using the reward_function. 99 | if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): 100 | val_metrics = self._validate() 101 | assert val_metrics, f"{val_metrics=}" 102 | pprint(f"Initial validation metrics: {val_metrics}") 103 | logger.log(data=val_metrics, step=self.global_steps) 104 | if self.config.trainer.get("val_only", False): 105 | return 106 | 107 | if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): 108 | rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) 109 | rollout_skip.wrap_generate_sequences() 110 | 111 | # add tqdm 112 | progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") 113 | 114 | # we start from step 1 115 | self.global_steps += 1 116 | last_val_metrics = None 117 | self.max_steps_duration = 0 118 | 119 | prev_step_profile = False 120 | curr_step_profile = ( 121 | self.global_steps in self.config.global_profiler.steps 122 | if self.config.global_profiler.steps is not None 123 | else False 124 | ) 125 | next_step_profile = False 126 | 127 | for epoch in range(self.config.trainer.total_epochs): 128 | for batch_dict in self.train_dataloader: 129 | metrics = {} 130 | timing_raw = {} 131 | 132 | with marked_timer("start_profile", timing_raw): 133 | self._start_profiling( 134 | not prev_step_profile and 
curr_step_profile 135 | if self.config.global_profiler.profile_continuous_steps 136 | else curr_step_profile 137 | ) 138 | 139 | batch: DataProto = DataProto.from_single_dict(batch_dict) 140 | 141 | # add uid to batch 142 | batch.non_tensor_batch["uid"] = np.array( 143 | [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object 144 | ) 145 | 146 | # pop those keys for generation 147 | batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] 148 | non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] 149 | if "multi_modal_data" in batch.non_tensor_batch: 150 | non_tensor_batch_keys_to_pop.append("multi_modal_data") 151 | if "raw_prompt" in batch.non_tensor_batch: 152 | non_tensor_batch_keys_to_pop.append("raw_prompt") 153 | if "tools_kwargs" in batch.non_tensor_batch: 154 | non_tensor_batch_keys_to_pop.append("tools_kwargs") 155 | if "interaction_kwargs" in batch.non_tensor_batch: 156 | non_tensor_batch_keys_to_pop.append("interaction_kwargs") 157 | if "index" in batch.non_tensor_batch: 158 | non_tensor_batch_keys_to_pop.append("index") 159 | if "agent_name" in batch.non_tensor_batch: 160 | non_tensor_batch_keys_to_pop.append("agent_name") 161 | 162 | gen_batch = batch.pop( 163 | batch_keys=batch_keys_to_pop, 164 | non_tensor_batch_keys=non_tensor_batch_keys_to_pop, 165 | ) 166 | 167 | # pass global_steps to trace 168 | gen_batch.meta_info["global_steps"] = self.global_steps 169 | gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) 170 | 171 | is_last_step = self.global_steps >= self.total_training_steps 172 | 173 | with marked_timer("step", timing_raw): 174 | # generate a batch 175 | with marked_timer("gen", timing_raw, color="red"): 176 | if not self.async_rollout_mode: 177 | gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) 178 | else: 179 | gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) 180 | timing_raw.update(gen_batch_output.meta_info["timing"]) 
181 | gen_batch_output.meta_info.pop("timing", None) 182 | 183 | if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: 184 | if self.reward_fn is None: 185 | raise ValueError("A reward_fn is required for REMAX advantage estimation.") 186 | 187 | with marked_timer("gen_max", timing_raw, color="purple"): 188 | gen_baseline_batch = deepcopy(gen_batch) 189 | gen_baseline_batch.meta_info["do_sample"] = False 190 | if not self.async_rollout_mode: 191 | gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) 192 | else: 193 | gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) 194 | batch = batch.union(gen_baseline_output) 195 | reward_baseline_tensor = self.reward_fn(batch) 196 | reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) 197 | 198 | batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) 199 | 200 | batch.batch["reward_baselines"] = reward_baseline_tensor 201 | 202 | del gen_baseline_batch, gen_baseline_output 203 | 204 | # repeat to align with repeated responses in rollout 205 | batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) 206 | batch = batch.union(gen_batch_output) 207 | 208 | if "response_mask" not in batch.batch.keys(): 209 | batch.batch["response_mask"] = compute_response_mask(batch) 210 | 211 | # compute global_valid tokens 212 | batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() 213 | 214 | with marked_timer("reward", timing_raw, color="yellow"): 215 | # compute reward model score 216 | if self.use_rm: 217 | reward_tensor = self.rm_wg.compute_rm_score(batch) 218 | batch = batch.union(reward_tensor) 219 | 220 | if self.config.reward_model.launch_reward_fn_async: 221 | future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn) 222 | else: 223 | reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) 224 | 
batch.batch["token_level_scores"] = reward_tensor 225 | if reward_extra_infos_dict: 226 | batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) 227 | reward_extra_infos_dict_keys = list(reward_extra_infos_dict.keys()) 228 | 229 | ################################### rStar ################################### 230 | # Need to refactor the launch_reward_fn_async to support down sampling, 231 | # only forbid combine launch_reward_fn_async and down sampling for now. 232 | with marked_timer("down_sample", timing_raw, color="yellow"): 233 | assert not (self.config.reward_model.launch_reward_fn_async and self.config.augmentation.do_down_sampling), \ 234 | "down sampling cannot combine with async reward function for now" 235 | batch, down_sampling_metrics = self._down_sample_batch(batch) 236 | metrics.update(down_sampling_metrics) 237 | if batch is None: 238 | continue 239 | ############################################################################# 240 | 241 | ################################### rStar ################################### 242 | # Move the balance logic after down sampling 243 | 244 | # Balance the number of valid tokens across DP ranks. 245 | # NOTE: This usually changes the order of data in the `batch`, 246 | # which won't affect the advantage calculation (since it's based on uid), 247 | # but might affect the loss calculation (due to the change of mini-batching). 248 | # TODO: Decouple the DP balancing and mini-batching. 
249 | if self.config.trainer.balance_batch: 250 | self._balance_batch(batch, metrics=metrics) 251 | ############################################################################# 252 | 253 | # recompute old_log_probs 254 | with marked_timer("old_log_prob", timing_raw, color="blue"): 255 | old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) 256 | entropys = old_log_prob.batch["entropys"] 257 | response_masks = batch.batch["response_mask"] 258 | loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode 259 | entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) 260 | old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} 261 | metrics.update(old_log_prob_metrics) 262 | old_log_prob.batch.pop("entropys") 263 | batch = batch.union(old_log_prob) 264 | 265 | if "rollout_log_probs" in batch.batch.keys(): 266 | # TODO: we may want to add diff of probs too. 267 | from verl.utils.debug.metrics import calculate_debug_metrics 268 | 269 | metrics.update(calculate_debug_metrics(batch)) 270 | 271 | if self.use_reference_policy: 272 | # compute reference log_prob 273 | with marked_timer("ref", timing_raw, color="olive"): 274 | if not self.ref_in_actor: 275 | ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) 276 | else: 277 | ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) 278 | batch = batch.union(ref_log_prob) 279 | 280 | # compute values 281 | if self.use_critic: 282 | with marked_timer("values", timing_raw, color="cyan"): 283 | values = self.critic_wg.compute_values(batch) 284 | batch = batch.union(values) 285 | 286 | with marked_timer("adv", timing_raw, color="brown"): 287 | # we combine with rule-based rm 288 | ################################### rStar ################################### 289 | # Because down sampling cannot combine with config.reward_model.launch_reward_fn_async, 290 | # here refactor the reward setting logic, and recreate the reward_extra_infos_dict. 
291 | if self.config.reward_model.launch_reward_fn_async: 292 | reward_tensor, reward_extra_infos_dict = ray.get(future_reward) 293 | batch.batch["token_level_scores"] = reward_tensor 294 | if reward_extra_infos_dict: 295 | batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) 296 | reward_extra_infos_dict_keys = list(reward_extra_infos_dict.keys()) 297 | reward_extra_infos_dict = {key: batch.non_tensor_batch[key].tolist() for key in reward_extra_infos_dict_keys} 298 | ################################################################################ 299 | 300 | # compute rewards. apply_kl_penalty if available 301 | if self.config.algorithm.use_kl_in_reward: 302 | batch, kl_metrics = apply_kl_penalty( 303 | batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty 304 | ) 305 | metrics.update(kl_metrics) 306 | else: 307 | batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] 308 | 309 | # compute advantages, executed on the driver process 310 | 311 | norm_adv_by_std_in_grpo = self.config.algorithm.get( 312 | "norm_adv_by_std_in_grpo", True 313 | ) # GRPO adv normalization factor 314 | 315 | batch = compute_advantage( 316 | batch, 317 | adv_estimator=self.config.algorithm.adv_estimator, 318 | gamma=self.config.algorithm.gamma, 319 | lam=self.config.algorithm.lam, 320 | num_repeat=self.config.actor_rollout_ref.rollout.n, 321 | norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, 322 | config=self.config.algorithm, 323 | ) 324 | 325 | # update critic 326 | if self.use_critic: 327 | with marked_timer("update_critic", timing_raw, color="pink"): 328 | critic_output = self.critic_wg.update_critic(batch) 329 | critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) 330 | metrics.update(critic_output_metrics) 331 | 332 | # implement critic warmup 333 | if self.config.trainer.critic_warmup <= self.global_steps: 334 | # update actor 335 | with marked_timer("update_actor", timing_raw, 
color="red"): 336 | batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable 337 | actor_output = self.actor_rollout_wg.update_actor(batch) 338 | actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) 339 | metrics.update(actor_output_metrics) 340 | 341 | # Log rollout generations if enabled 342 | rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) 343 | if rollout_data_dir: 344 | with marked_timer("dump_rollout_generations", timing_raw, color="green"): 345 | inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) 346 | outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) 347 | scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() 348 | sample_gts = [ 349 | item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) 350 | for item in batch 351 | ] 352 | 353 | if "request_id" in batch.non_tensor_batch: 354 | reward_extra_infos_dict.setdefault( 355 | "request_id", 356 | batch.non_tensor_batch["request_id"].tolist(), 357 | ) 358 | 359 | self._dump_generations( 360 | inputs=inputs, 361 | outputs=outputs, 362 | gts=sample_gts, 363 | scores=scores, 364 | reward_extra_infos_dict=reward_extra_infos_dict, 365 | dump_path=rollout_data_dir, 366 | ) 367 | 368 | # validate 369 | if ( 370 | self.val_reward_fn is not None 371 | and self.config.trainer.test_freq > 0 372 | and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) 373 | ): 374 | with marked_timer("testing", timing_raw, color="green"): 375 | val_metrics: dict = self._validate() 376 | if is_last_step: 377 | last_val_metrics = val_metrics 378 | metrics.update(val_metrics) 379 | 380 | # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. 
381 | esi_close_to_expiration = should_save_ckpt_esi( 382 | max_steps_duration=self.max_steps_duration, 383 | redundant_time=self.config.trainer.esi_redundant_time, 384 | ) 385 | # Check if the conditions for saving a checkpoint are met. 386 | # The conditions include a mandatory condition (1) and 387 | # one of the following optional conditions (2/3/4): 388 | # 1. The save frequency is set to a positive value. 389 | # 2. It's the last training step. 390 | # 3. The current step number is a multiple of the save frequency. 391 | # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. 392 | if self.config.trainer.save_freq > 0 and ( 393 | is_last_step 394 | or self.global_steps % self.config.trainer.save_freq == 0 395 | or esi_close_to_expiration 396 | ): 397 | if esi_close_to_expiration: 398 | print("Force saving checkpoint: ESI instance expiration approaching.") 399 | with marked_timer("save_checkpoint", timing_raw, color="green"): 400 | self._save_checkpoint() 401 | 402 | with marked_timer("stop_profile", timing_raw): 403 | next_step_profile = ( 404 | self.global_steps + 1 in self.config.global_profiler.steps 405 | if self.config.global_profiler.steps is not None 406 | else False 407 | ) 408 | self._stop_profiling( 409 | curr_step_profile and not next_step_profile 410 | if self.config.global_profiler.profile_continuous_steps 411 | else curr_step_profile 412 | ) 413 | prev_step_profile = curr_step_profile 414 | curr_step_profile = next_step_profile 415 | 416 | steps_duration = timing_raw["step"] 417 | self.max_steps_duration = max(self.max_steps_duration, steps_duration) 418 | 419 | # training metrics 420 | metrics.update( 421 | { 422 | "training/global_step": self.global_steps, 423 | "training/epoch": epoch, 424 | } 425 | ) 426 | # collect metrics 427 | metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) 428 | metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) 429 | # TODO: implement actual tflpo 
and theoretical tflpo 430 | n_gpus = self.resource_pool_manager.get_n_gpus() 431 | metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) 432 | 433 | # this is experimental and may be changed/removed in the future in favor of a general-purpose one 434 | if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): 435 | self.train_dataloader.sampler.update(batch=batch) 436 | 437 | # TODO: make a canonical logger that supports various backend 438 | logger.log(data=metrics, step=self.global_steps) 439 | 440 | progress_bar.update(1) 441 | self.global_steps += 1 442 | 443 | if is_last_step: 444 | pprint(f"Final validation metrics: {last_val_metrics}") 445 | progress_bar.close() 446 | return 447 | 448 | # this is experimental and may be changed/removed in the future 449 | # in favor of a general-purpose data buffer pool 450 | if hasattr(self.train_dataset, "on_batch_end"): 451 | # The dataset may be changed after each training batch 452 | self.train_dataset.on_batch_end(batch=batch) 453 | --------------------------------------------------------------------------------