50 | m = re.search(r"^\\text\{(?P<text>.+?)\}$", answer)
51 | if m is not None:
52 | answer = m.group("text").strip()
53 | return _strip_string(answer)
54 | except: # noqa: E722
55 | return answer
56 |
57 |
58 | def _fix_fracs(string):  # brace bare \frac arguments, e.g. \frac12 -> \frac{1}{2}
59 | substrs = string.split("\\frac")
60 | new_str = substrs[0]
61 | if len(substrs) > 1:
62 | substrs = substrs[1:]
63 | for substr in substrs:
64 | new_str += "\\frac"
65 | if substr[0] == "{":
66 | new_str += substr
67 | else:
68 | try:
69 | assert len(substr) >= 2
70 | except: # noqa: E722
71 | return string
72 | a = substr[0]
73 | b = substr[1]
74 | if b != "{":
75 | if len(substr) > 2:
76 | post_substr = substr[2:]
77 | new_str += "{" + a + "}{" + b + "}" + post_substr
78 | else:
79 | new_str += "{" + a + "}{" + b + "}"
80 | else:
81 | if len(substr) > 2:
82 | post_substr = substr[2:]
83 | new_str += "{" + a + "}" + b + post_substr
84 | else:
85 | new_str += "{" + a + "}" + b
86 | string = new_str
87 | return string
88 |
89 |
90 | def _fix_a_slash_b(string):  # rewrite an integer ratio a/b as \frac{a}{b}
91 | if len(string.split("/")) != 2:
92 | return string
93 | a = string.split("/")[0]
94 | b = string.split("/")[1]
95 | try:
96 | a = int(a)
97 | b = int(b)
98 | assert string == "{}/{}".format(a, b)
99 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
100 | return new_string
101 | except: # noqa: E722
102 | return string
103 |
104 |
105 | def _remove_right_units(string):
106 | # "\\text{ " only ever occurs (at least in the val set) when describing units
107 | if "\\text{ " in string:
108 | splits = string.split("\\text{ ")
109 | assert len(splits) == 2
110 | return splits[0]
111 | else:
112 | return string
113 |
114 |
115 | def _fix_sqrt(string):  # brace bare \sqrt arguments, e.g. \sqrt3 -> \sqrt{3}
116 | if "\\sqrt" not in string:
117 | return string
118 | splits = string.split("\\sqrt")
119 | new_string = splits[0]
120 | for split in splits[1:]:
121 | if split[0] != "{":
122 | a = split[0]
123 | new_substr = "\\sqrt{" + a + "}" + split[1:]
124 | else:
125 | new_substr = "\\sqrt" + split
126 | new_string += new_substr
127 | return new_string
128 |
129 |
130 | def _strip_string(string):
131 | # linebreaks
132 | string = string.replace("\n", "")
133 |
134 | # remove inverse spaces
135 | string = string.replace("\\!", "")
136 |
137 | # replace \\ with \
138 | string = string.replace("\\\\", "\\")
139 |
140 | # replace tfrac and dfrac with frac
141 | string = string.replace("tfrac", "frac")
142 | string = string.replace("dfrac", "frac")
143 |
144 | # remove \left and \right
145 | string = string.replace("\\left", "")
146 | string = string.replace("\\right", "")
147 |
148 | # Remove circ (degrees)
149 | string = string.replace("^{\\circ}", "")
150 | string = string.replace("^\\circ", "")
151 |
152 | # remove dollar signs
153 | string = string.replace("\\$", "")
154 |
155 | # remove units (on the right)
156 | string = _remove_right_units(string)
157 |
158 | # remove percentage
159 | string = string.replace("\\%", "")
160 | string = string.replace("\\%", "")
161 |
162 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
163 | string = string.replace(" .", " 0.")
164 | string = string.replace("{.", "{0.")
165 | # if empty, return empty string
166 | if len(string) == 0:
167 | return string
168 | if string[0] == ".":
169 | string = "0" + string
170 |
171 | # to consider: get rid of e.g. "k = " or "q = " at beginning
172 | if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2:
173 | string = string.split("=")[1]
174 |
175 | # fix sqrt3 --> sqrt{3}
176 | string = _fix_sqrt(string)
177 |
178 | # remove spaces
179 | string = string.replace(" ", "")
180 |
181 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
182 | # Also does a/b --> \\frac{a}{b}
183 | string = _fix_fracs(string)
184 |
185 | # manually change 0.5 --> \frac{1}{2}
186 | if string == "0.5":
187 | string = "\\frac{1}{2}"
188 |
189 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
190 | string = _fix_a_slash_b(string)
191 |
192 | return string
193 |
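A few input/output examples for the normalization helpers above (a usage sketch; it assumes `_strip_string` from this module is in scope):

```python
# Behavior of _strip_string, traced through the helpers above.
assert _strip_string("\\tfrac{1}{2}") == "\\frac{1}{2}"  # tfrac/dfrac -> frac
assert _strip_string("\\frac12") == "\\frac{1}{2}"       # _fix_fracs braces bare digits
assert _strip_string("\\sqrt3") == "\\sqrt{3}"           # _fix_sqrt braces bare arguments
assert _strip_string("3/4") == "\\frac{3}{4}"            # _fix_a_slash_b rewrites a/b
assert _strip_string("10\\%") == "10"                    # percent signs stripped
assert _strip_string("0.5") == "\\frac{1}{2}"            # special-cased above
```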
--------------------------------------------------------------------------------
/rstar2_agent/down_sample/roc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import re
5 | import numpy as np
6 | import torch
7 | from pprint import pprint
8 | from typing import List
9 | from transformers import PreTrainedTokenizerFast
10 |
11 | from verl.protocol import DataProto
12 | from .utils import filter_by_mask, decode_prompt_response_str
13 |
14 |
15 | def resample_of_correct(batch: DataProto, tokenizer: PreTrainedTokenizerFast, config: dict, do_sample=True, world_size=None):
16 | roc_error_ratio = config["roc_error_ratio"]
17 | roc_answer_format = config["roc_answer_format"]
18 | min_zero_reward_trace_num = config["min_zero_reward_trace_num"]
19 | min_non_zero_reward_trace_num = config["min_non_zero_reward_trace_num"]
20 | down_sample_to_n = config["down_sample_to_n"]
21 | assert min_zero_reward_trace_num + min_non_zero_reward_trace_num <= down_sample_to_n, \
22 | f"Invalid down sampling configuration: {min_zero_reward_trace_num=}, {min_non_zero_reward_trace_num=}, {down_sample_to_n=}"
23 |
24 | _, response_text = decode_prompt_response_str(batch, tokenizer)
25 | penalty_weights = np.zeros(len(response_text))
26 | metrics = {}
27 |
28 | # calculate error ratio penalty weight
29 | _penalty_weights, _metrics = calc_error_ratio_penalty_weights(response_text)
30 | metrics.update(_metrics)
31 | if roc_error_ratio:
32 | penalty_weights += _penalty_weights
33 |
34 | # calculate format penalty weight
35 | _penalty_weights, _metrics = calc_format_penalty_weights(response_text)
36 | metrics.update(_metrics)
37 | if roc_answer_format:
38 | penalty_weights += _penalty_weights
39 |
40 | # sample by penalty weights
41 | if do_sample and down_sample_to_n > 0:
42 | uids = batch.non_tensor_batch['uid']
43 | unique_uids = np.unique(uids)
44 | valid_mask = torch.zeros(len(uids), dtype=torch.bool)
45 |
46 | for uid in unique_uids:
47 | indices = np.where(uids == uid)[0]
48 | if len(indices) < down_sample_to_n:
49 | continue # Not enough samples for this uid, skip
50 | if len(indices) == down_sample_to_n:
51 | valid_mask[indices] = True
52 | continue
53 | uid_mask = uids == uid
54 | uid_rewards = batch.batch['token_level_scores'][uid_mask].sum(-1)
55 |
56 | zero_reward_pairs = [(indice, penalty_weight) for indice, uid_reward, penalty_weight in zip(indices, uid_rewards, penalty_weights[uid_mask]) if uid_reward <= 0]
57 | non_zero_reward_pairs = [(indice, penalty_weight) for indice, uid_reward, penalty_weight in zip(indices, uid_rewards, penalty_weights[uid_mask]) if uid_reward > 0]
58 | non_zero_reward_pairs.sort(key=lambda x: x[1])
59 | zero_reward_trace_num = round(len(zero_reward_pairs) * down_sample_to_n / len(indices))
60 | non_zero_reward_trace_num = round(len(non_zero_reward_pairs) * down_sample_to_n / len(indices))
61 | if zero_reward_trace_num < min_zero_reward_trace_num and non_zero_reward_trace_num < min_non_zero_reward_trace_num:
62 | pprint(f"Total trace number before down sampling: {len(indices)}, smaller than {min_zero_reward_trace_num=} + {min_non_zero_reward_trace_num=}")
63 | valid_mask[indices] = True
64 | else:
65 | if zero_reward_trace_num <= min(min_zero_reward_trace_num, len(zero_reward_pairs)):
66 | zero_reward_trace_num = min(min_zero_reward_trace_num, len(zero_reward_pairs))
67 | non_zero_reward_trace_num = down_sample_to_n - zero_reward_trace_num
68 | if non_zero_reward_trace_num <= min(min_non_zero_reward_trace_num, len(non_zero_reward_pairs)):
69 | non_zero_reward_trace_num = min(min_non_zero_reward_trace_num, len(non_zero_reward_pairs))
70 | zero_reward_trace_num = down_sample_to_n - non_zero_reward_trace_num
71 | choices = [non_zero_reward_pair[0] for non_zero_reward_pair in non_zero_reward_pairs[:non_zero_reward_trace_num]] \
72 | + [zero_reward_pair[0] for zero_reward_pair in zero_reward_pairs[:zero_reward_trace_num]]
73 | assert len(choices) == down_sample_to_n, f"{down_sample_to_n=} != {len(choices)}"
74 | valid_mask[choices] = True
75 |
76 | batch = filter_by_mask(batch, valid_mask, world_size)
77 | return batch, metrics
78 |
79 |
80 | def calc_error_ratio_penalty_weights(response_text: List[str]):
81 | def error_ratio(text, pattern=r'<tool_response>.*?</tool_response>'):
82 | matches = re.findall(pattern, text, re.DOTALL)
83 | error_count = len([match for match in matches if 'error' in match.lower()])
84 | if len(matches) == 0:
85 | return 0.5, 0, 0
86 | else:
87 | return error_count / len(matches), error_count, len(matches)
88 |
89 | penalty_weights = []
90 | total_error_count, total_res_count = 0, 0
91 |
92 | for text in response_text:
93 | penalty_weight, error_count, res_count = error_ratio(text)
94 | penalty_weights.append(penalty_weight)
95 | total_error_count += error_count
96 | total_res_count += res_count
97 | metrics = {
98 | 'roc_error_ratio/global_err_ratio': total_error_count / total_res_count if total_res_count > 0 else 0,
99 | 'roc_error_ratio/penalty_weight': np.mean(penalty_weights) if penalty_weights else 0,
100 | }
101 | return np.array(penalty_weights), metrics
102 |
103 |
104 | def calc_format_penalty_weights(response_text: List[str]):
105 | def answer_tag_repetition(text: str, answer_tags=["<answer>", "</answer>"], answer_pattern=r'<answer>.*?</answer>', turn_pattern=r'<\|im_start\|>assistant.*?<\|im_end\|>'):
106 | if any(ans_tag not in text for ans_tag in answer_tags):
107 | return 1.0, 0
108 |
109 | answer_tags_count = [text.count(ans_tag) for ans_tag in answer_tags]
110 | closed_ans_tag_count = len(re.findall(answer_pattern, text, re.DOTALL))
111 | if any(ans_tag_count!=closed_ans_tag_count for ans_tag_count in answer_tags_count):
112 | return 1.0, closed_ans_tag_count
113 |
114 | matches = re.findall(turn_pattern, text, re.DOTALL)
115 | num_turns = len(matches)
116 | if num_turns == 0:
117 | return 1.0, closed_ans_tag_count
118 |
119 | penalty_weight = min((closed_ans_tag_count - 1) / num_turns, 1.0)
120 | return penalty_weight, closed_ans_tag_count
121 |
122 | penalty_weights = []
123 | total_ans_count, zero_ans_count, one_ans_count, gt_one_ans_count = 0, 0, 0, 0
124 | for text in response_text:
125 | penalty_weight, ans_tag_count = answer_tag_repetition(text)
126 | penalty_weights.append(penalty_weight)
127 | total_ans_count += ans_tag_count
128 | zero_ans_count += (1 if ans_tag_count == 0 else 0)
129 | one_ans_count += (1 if ans_tag_count == 1 else 0)
130 | gt_one_ans_count += (1 if ans_tag_count > 1 else 0)
131 |
132 | metrics = {
133 | 'roc_answer_format/answer_per_rollout_mean': total_ans_count / len(response_text),
134 | 'roc_answer_format/zero_answer_count': zero_ans_count,
135 | 'roc_answer_format/one_answer_count': one_ans_count,
136 | 'roc_answer_format/gt_one_answer_count': gt_one_ans_count,
137 | }
138 | return np.array(penalty_weights), metrics
139 |
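A worked example of the allocation logic in `resample_of_correct` (numbers are illustrative):

```python
# With 32 rollouts for a prompt (24 zero-reward, 8 positive) and down_sample_to_n=16,
# the proportional split is 12 zero-reward and 4 positive traces; the min_*_trace_num
# floors only kick in when one side's share would fall below its minimum.
down_sample_to_n = 16
zero_pairs, non_zero_pairs = 24, 8
total = zero_pairs + non_zero_pairs
zero_n = round(zero_pairs * down_sample_to_n / total)          # -> 12
non_zero_n = round(non_zero_pairs * down_sample_to_n / total)  # -> 4
min_zero, min_non_zero = 2, 2
if non_zero_n <= min(min_non_zero, non_zero_pairs):            # floor the positives
    non_zero_n = min(min_non_zero, non_zero_pairs)
    zero_n = down_sample_to_n - non_zero_n
assert (zero_n, non_zero_n) == (12, 4)
```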
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # rStar2-Agent
4 |
5 |
6 |
7 | 📃 [Paper](https://huggingface.co/papers/2508.20722)
8 |
9 |
10 | Repo for "[rStar2-Agent: Agentic Reasoning Technical Report](https://huggingface.co/papers/2508.20722)".
11 |
12 | Authors: Ning Shang\*, Yifei Liu\*, Yi Zhu\*, Li Lyna Zhang\*†, Weijiang Xu, Xinyu Guan, Buze Zhang, Bingcheng Dong, Xudong Zhou, Bowen Zhang, Ying Xin, Ziming Miao, Scarlett Li, Fan Yang, Mao Yang†
13 |
14 |
15 |
16 |
17 | Figure 1: rStar2-Agent-14B reaches frontier-level math reasoning in just 510 RL training steps.
18 |
19 |
20 | ## News
21 |
22 | - **[07/15/2025]** Our rStar-Coder [paper](https://arxiv.org/abs/2505.21297) and [dataset](https://huggingface.co/datasets/microsoft/rStar-Coder) are released. We introduce a large-scale, verified dataset of 418K competition-level code problems with **test cases** of varying difficulty, enabling small LLMs (1.5B-14B) to achieve frontier-level code reasoning performance.
23 | - **[02/10/2025]** We are hiring interns! If you are interested in improving LLM reasoning, please send your CV to lzhani@microsoft.com.
24 | - **[01/21/2025]** rStar-Math code has been open-sourced.
25 | - **[01/09/2025]** rStar-Math paper is released: https://huggingface.co/papers/2501.04519.
26 |
27 | Note: Our prior work [Mutual Reasoning Makes Smaller LLMs Stronger Problem-Solvers](https://huggingface.co/papers/2408.06195) is open-sourced on the [rStar-mutualreasoning](https://github.com/microsoft/rStar/tree/rStar-mutualreasoning) branch.
28 |
29 | Note: Our prior work [rStar-Math: Small LLMs Can Master Math Reasoning with Self-Evolved Deep Thinking](https://huggingface.co/papers/2501.04519) is open-sourced on the [rStar-math](https://github.com/microsoft/rStar/tree/rStar-math) branch.
30 |
31 | ## Contents
32 | - [Introduction](#Introduction)
33 | - [Try rStar2-Agent with Tool Calling](#Try-rStar2-Agent-with-Tool-Calling)
34 | - [Evaluation](#Evaluation)
35 | - [rStar2-Agent RL Training](#rStar2-Agent-RL-Training)
36 | - [Citation](#Citation)
37 |
38 | ## Introduction
39 | We introduce rStar2-Agent, a 14B math reasoning model that thinks smarter rather than merely longer, achieving performance comparable to 671B DeepSeek-R1 through pure agentic reinforcement learning. The model plans, reasons, and autonomously uses coding tools to efficiently explore, verify, and reflect for more complex problem-solving. This capability relies on three key innovations: (i) GRPO-RoC, an effective agentic reinforcement learning algorithm with a novel Resample-on-Correct rollout strategy that optimizes coding tool usage and enables shorter, smarter reasoning by selectively retaining higher-quality positive trajectories while preserving all failure cases; (ii) a scalable and efficient RL infrastructure that supports high-throughput tool call execution and mitigates the high costs of agentic RL rollout, enabling efficient training on limited GPU resources (64 MI300X GPUs); (iii) an agent training recipe that starts with non-reasoning SFT and proceeds through multi-stage RL with concise maximum response lengths per stage and increasing dataset difficulty. As a result, rStar2-Agent boosts a pre-trained 14B model to state-of-the-art levels in only 510 RL steps within one week, achieving 80.6% and 69.8% average pass@1 on AIME24 and AIME25, surpassing DeepSeek-R1 (671B) with shorter responses. Beyond mathematics, rStar2-Agent-14B also demonstrates strong generalization to alignment, scientific reasoning, and agentic tool-use tasks.
40 |
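A minimal sketch of the Resample-on-Correct idea (illustrative pseudocode only, not the released training code; see `rstar2_agent/down_sample/roc.py` for the actual implementation):

```python
# Illustrative RoC sketch: oversample rollouts per prompt, keep failures untouched,
# and keep only the lowest-penalty (fewest tool errors, cleanest format) successes.
def resample_on_correct(rollouts, keep_n):
    # rollouts: list of (reward, penalty) pairs -- a simplified stand-in
    failures = [r for r in rollouts if r[0] <= 0]
    successes = sorted((r for r in rollouts if r[0] > 0), key=lambda r: r[1])
    n_pos = round(len(successes) * keep_n / len(rollouts))
    return successes[:n_pos] + failures[:keep_n - n_pos]

# 8 rollouts, half correct, keep 4 -> the 2 cleanest successes + 2 failures
kept = resample_on_correct([(1, 0.1), (1, 0.9), (1, 0.5), (1, 0.0),
                            (0, 0.0), (0, 0.2), (0, 0.3), (0, 0.4)], keep_n=4)
assert len(kept) == 4
```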
41 | ## Try rStar2-Agent with Tool Calling
42 |
43 | ### Installation
44 |
45 | #### Option 1: Manual Installation
46 |
47 | ```bash
48 | # Initialize and update submodules
49 | git submodule init
50 | git submodule update
51 |
52 | # install verl
53 | pip install "torch<2.8"
54 | pip install -r verl/requirements_sglang.txt
55 | pip install -e verl
56 |
57 | # install code judge
58 | pip install -r code-judge/requirements.txt
59 | pip install -e code-judge
60 |
61 | # install rstar2_agent
62 | pip install -e .
63 | ```
64 |
65 | #### Option 2: Automated Installation
66 |
67 | ```bash
68 | bash install.sh
69 | ```
70 |
71 | ### Code Judge Server Setup
72 |
73 | > ⚠️ **Security Warning**: Code Judge executes arbitrary code. Always deploy in an isolated environment (preferably Docker) and never expose to external networks.
74 |
75 | rStar2-Agent uses Code Judge as a tool-call server to execute model-generated Python code.
76 |
77 | #### 1. Start Redis Server
78 |
79 | ```bash
80 | sudo apt-get update -y && sudo apt-get install redis -y
81 | redis-server --daemonize yes --protected-mode no --bind 0.0.0.0
82 | ```
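To confirm Redis is reachable before starting the Code Judge server and workers (`redis-cli` is installed alongside the `redis` package):

```bash
redis-cli -h $MASTER_ADDR ping   # expect: PONG
```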
83 |
84 | #### 2. Launch Code Judge Server
85 |
86 | ```bash
87 | # Start the main server (master node only)
88 | # Environment variables can be configured as per: https://github.com/0xWJ/code-judge/blob/main/app/config.py
89 | # Replace $WORKSPACE and $MASTER_ADDR with your actual paths
90 |
91 | tmux new-session -d -s server \
92 | 'cd $WORKSPACE/code-judge && \
93 | MAX_EXECUTION_TIME=4 \
94 | REDIS_URI="redis://$MASTER_ADDR:6379" \
95 | RUN_WORKERS=0 \
96 | uvicorn app.main:app --host 0.0.0.0 --port 8088 --workers 16 \
97 | 2>&1 | tee server.log'
98 | ```
99 |
100 | #### 3. Start Code Judge Workers
101 |
102 | ```bash
103 | # Launch workers (can be deployed on multiple nodes for increased parallelism)
104 | # Adjust MAX_WORKERS based on your CPU count per node
105 |
106 | tmux new-session -d -s worker \
107 | 'cd $WORKSPACE/code-judge && \
108 | MAX_EXECUTION_TIME=4 \
109 | REDIS_URI="redis://$MASTER_ADDR:6379" \
110 | MAX_WORKERS=64 \
111 | python run_workers.py \
112 | 2>&1 | tee worker.log'
113 | ```
114 |
115 | ### Launch the vLLM Server
116 |
117 | First, start the vLLM server:
118 |
119 | ```bash
120 | vllm serve /path/to/your/model \
121 | --host 0.0.0.0 \
122 | --port 8000 \
123 | --enable-auto-tool-choice \
124 | --tool-call-parser hermes
125 | ```
126 |
127 | Replace `/path/to/your/model` with the actual path to your downloaded model.
128 |
129 | ### Verify Server Status
130 |
131 | Check if the server is running properly:
132 |
133 | ```bash
134 | curl http://localhost:8000/v1/models
135 | ```
136 |
137 | ### Run Interactive Chat with Tool Calling
138 |
139 | Use the provided script to interact with your model:
140 |
141 | ```bash
142 | python examples/chat_with_tool_call.py \
143 | --model /path/to/your/model \
144 | --prompt "Solve the system of equations: 2x + 3y = 7, x - y = 1" \
145 | --max_tokens 8192
146 | ```
147 |
148 | ### Script Options
149 |
150 | The `examples/chat_with_tool_call.py` script supports the following arguments:
151 |
152 | - `--model`: Path to your model
153 | - `--prompt`: Input prompt for the model
154 | - `--max_tokens`: Maximum number of tokens to generate
155 |
156 | ## Evaluation
157 |
158 | ### Environment Setup
159 |
160 | Please see [Installation](#Installation) and [Code Judge Server Setup](#Code-Judge-Server-Setup).
161 |
162 | ### Run Evaluation Script
163 |
164 | We evaluate on the following mathematical reasoning benchmarks:
165 |
166 | - **AIME 2024/2025 (American Invitational Mathematics Examination)**: High-school level competition mathematics
167 | - **MATH500**: A subset of the MATH dataset containing 500 challenging problems
168 |
169 | ```bash
170 | MODEL_PATH=/path/to/your/model bash examples/aime_eval.sh
171 | MODEL_PATH=/path/to/your/model bash examples/math500_eval.sh
172 | ```
173 |
174 | ## rStar2-Agent RL Training
175 |
176 | A comprehensive reinforcement learning training framework for the rStar2-Agent, built on [Verl](https://github.com/volcengine/verl) and [Code Judge](https://github.com/0xWJ/code-judge). This framework enables training models after instruction-following supervised fine-tuning (SFT).
177 |
178 | ### Environment Setup
179 |
180 | Please see [Installation](#Installation) and [Code Judge Server Setup](#Code-Judge-Server-Setup).
181 |
182 | ### Data Preparation
183 |
184 | This example uses:
185 | - **Training Dataset**: DAPO-17k (English subset)
186 | - **Test Dataset**: AIME24
187 |
188 | ```bash
189 | # Process AIME 2024 dataset
190 | python data_preprocess/aime2024_rstar2_agent_loop.py
191 |
192 | # Process DAPO dataset
193 | python data_preprocess/dapo_rstar2_agent_loop.py
194 | ```
195 |
196 | ### Model Setup
197 |
198 | Download the base model (Qwen3-14B-Base):
199 |
200 | ```bash
201 | huggingface-cli download Qwen/Qwen3-14B-Base --local-dir $HOME/models/Qwen3-14B-Base
202 | ```
203 |
204 | > **Note**: The base model requires instruction-following SFT before RL training for optimal performance.
205 |
206 | ### Training
207 |
208 | #### Basic Training
209 |
210 | Run the training script (for 8x A100/H100 GPUs):
211 |
212 | ```bash
213 | bash examples/run_qwen3-14b_rstar2_agent_weave.sh
214 | ```
215 |
216 | > Adjust configuration parameters based on your hardware environment.
217 |
218 | ### Configuration
219 |
220 | #### Data Augmentation Settings
221 |
222 | The framework supports various sampling strategies to improve training efficiency:
223 |
224 | ```bash
225 | # Global Settings
226 | augmentation.do_down_sampling=True # Enable down sampling
227 | augmentation.down_sampling_config.down_sample_to_n=16 # Target number of traces per data point
228 |
229 | # Sampling Strategies
230 | augmentation.down_sampling_config.reject_equal_reward=True # Enable reject sampling for equal rewards
231 | augmentation.down_sampling_config.roc_error_ratio=True # Resample correct traces by tool call error ratio
232 | augmentation.down_sampling_config.roc_answer_format=True # Resample correct traces by answer format
233 |
234 | # Minimum Trace Requirements
235 | augmentation.down_sampling_config.min_zero_reward_trace_num=2 # Minimum negative traces to retain
236 | augmentation.down_sampling_config.min_non_zero_reward_trace_num=2 # Minimum positive traces to retain
237 | ```
238 |
239 | ### Important Note
240 |
241 | rStar2-Agent was originally trained with VERL v0.2 and our custom multi-turn tool-calling training framework. The training framework released here has been migrated to VERL v0.5 to ensure compatibility with the latest community standards. While the released framework has not yet been used to train a complete model, we have verified that the first 50 training steps show minimal differences between the original and migrated frameworks, preserving the core behavior of our proven training approach.
242 |
243 | Although our original framework includes additional advanced features, such as a rollout-request load-balancing scheduler, we chose to migrate to the latest VERL version to maintain community compatibility and make customization easier for users. This lets you benefit from ongoing VERL improvements and integrate easily with the latest open-source developments. We are also considering migrating the remaining features to the current version in the future.
244 |
245 | If you encounter any issues during usage or need assistance with the training framework, please contact us.
246 |
247 | ### Troubleshooting
248 |
249 | #### Common Issues
250 |
251 | 1. **Redis Connection Errors**: Ensure Redis is running and accessible at the specified address
252 | 2. **GPU Memory Issues**: Adjust batch sizes and model parameters for your hardware
253 | 3. **Code Judge Timeouts**: Increase `MAX_EXECUTION_TIME` for complex computations
254 | 4. **Worker Scaling**: Adjust `MAX_WORKERS` based on available CPU cores
255 |
256 | #### Log Locations
257 |
258 | - Server logs: `server.log` in the code-judge directory
259 | - Worker logs: `worker.log` in the code-judge directory
260 | - Training logs: Check your training script output directory
261 |
262 | ---
263 |
264 |
265 | ## Citation
266 | If you find this repo useful for your research, please consider citing our paper:
267 | ```
268 | @misc{shang2025rstar2agentagenticreasoningtechnical,
269 | title={rStar2-Agent: Agentic Reasoning Technical Report},
270 | author={Ning Shang and Yifei Liu and Yi Zhu and Li Lyna Zhang and Weijiang Xu and Xinyu Guan and Buze Zhang and Bingcheng Dong and Xudong Zhou and Bowen Zhang and Ying Xin and Ziming Miao and Scarlett Li and Fan Yang and Mao Yang},
271 | year={2025},
272 | eprint={2508.20722},
273 | archivePrefix={arXiv},
274 | primaryClass={cs.CL},
275 | url={https://arxiv.org/abs/2508.20722},
276 | }
277 | ```
278 |
--------------------------------------------------------------------------------
/rstar2_agent/rollout/rstar2_agent_loop.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import asyncio
5 | import copy
6 | import json
7 | import logging
8 | import os
9 | from typing import Any
10 | from uuid import uuid4
11 |
12 | from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register
13 | from verl.experimental.agent_loop.tool_agent_loop import ToolAgentLoop
14 | from verl.experimental.agent_loop.tool_parser import FunctionCall
15 | from verl.tools.schemas import ToolResponse
16 | from verl.utils.profiler import simple_timer
17 | from verl.utils.rollout_trace import rollout_trace_op
18 |
19 | logger = logging.getLogger(__file__)
20 | logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
21 |
22 |
23 | @register("rstar2_agent")
24 | class RStar2AgentLoop(ToolAgentLoop):
25 | @rollout_trace_op
26 | async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
27 | messages = list(kwargs["raw_prompt"])
28 | image_data = copy.deepcopy(kwargs.get("multi_modal_data", {}).get("image", None))
29 | metrics = {}
30 | request_id = uuid4().hex
31 | if self.processor is not None:
32 | raw_prompt = await self.loop.run_in_executor(
33 | None,
34 | lambda: self.processor.apply_chat_template(
35 | messages,
36 | tools=self.tool_schemas,
37 | add_generation_prompt=True,
38 | tokenize=False,
39 | **self.apply_chat_template_kwargs,
40 | ),
41 | )
42 | model_inputs = self.processor(text=[raw_prompt], images=image_data, return_tensors="pt")
43 | prompt_ids = model_inputs.pop("input_ids").squeeze(0).tolist()
44 | else:
45 | prompt_ids = await self.loop.run_in_executor(
46 | None,
47 | lambda: self.tokenizer.apply_chat_template(
48 | messages,
49 | tools=self.tool_schemas,
50 | add_generation_prompt=True,
51 | tokenize=True,
52 | **self.apply_chat_template_kwargs,
53 | ),
54 | )
55 | response_mask = []
56 | tools_kwargs = kwargs.get("tools_kwargs", {})
57 | ################################### rStar ###################################
58 | history_tool_calls = [] # Keep track of all tool calls made during the conversation
59 | # budget = len(prompt_ids) + self.response_length
60 | #############################################################################
61 |
62 | user_turns, assistant_turns = 0, 0
63 | while True:
64 | with simple_timer("generate_sequences", metrics):
65 | ################################### rStar ###################################
66 | sampling_params["max_new_tokens"] = self.response_length - len(response_mask)
67 | #############################################################################
68 | response_ids = await self.server_manager.generate(
69 | request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params, image_data=image_data
70 | )
71 | prompt_ids += response_ids
72 | response_mask += [1] * len(response_ids)
73 | assistant_turns += 1
74 |
75 | # reach max response length
76 | if len(response_mask) >= self.response_length:
77 | # self.server_manager._release_request(request_id, budget)
78 | break
79 |
80 | # reach max assistant turns
81 | if self.max_assistant_turns and assistant_turns >= self.max_assistant_turns:
82 | # self.server_manager._release_request(request_id, budget)
83 | break
84 |
85 | # reach max user turns
86 | if self.max_user_turns and user_turns >= self.max_user_turns:
87 | # self.server_manager._release_request(request_id, budget)
88 | break
89 |
90 | # no tool calls
91 | _, tool_calls = await self.tool_parser.extract_tool_calls(response_ids)
92 | if not tool_calls:
93 | # self.server_manager._release_request(request_id, budget)
94 | break
95 |
96 | ################################### rStar ###################################
97 | tool_calls = tool_calls[: self.max_parallel_calls]
98 | total_tool_responses, filtered_tool_calls, pending_pos = [], [], []
99 | for i, tool_call in enumerate(tool_calls):
100 | if isinstance(tool_call, ToolResponse):
101 | total_tool_responses.append(tool_call)
102 | else:
103 | total_tool_responses.append(None)
104 | pending_pos.append(i)
105 | filtered_tool_calls.append(tool_call)
106 | tool_calls = filtered_tool_calls
107 | #############################################################################
108 | # call tools
109 | tasks = []
110 | for tool_call in tool_calls[: self.max_parallel_calls]:
111 | ################################### rStar ###################################
112 | tools_kwargs_copy = dict(tools_kwargs) # Copy to avoid modifying original
113 | tools_kwargs_copy["history_tool_calls"] = list(history_tool_calls) # Pass history tool calls
114 | tasks.append(self._call_tool(tool_call, tools_kwargs_copy))
115 | history_tool_calls.append(tool_call)
116 | #############################################################################
117 | with simple_timer("tool_calls", metrics):
118 | tool_responses = await asyncio.gather(*tasks)
119 | ################################### rStar ###################################
120 | assert len(pending_pos[: self.max_parallel_calls]) == len(tool_responses)
121 | for i, tool_response in zip(pending_pos[: self.max_parallel_calls], tool_responses):
122 | total_tool_responses[i] = tool_response
123 | tool_responses = total_tool_responses
124 | #############################################################################
125 | if any(isinstance(item, Exception) for item in tool_responses):
126 | # self.server_manager._release_request(request_id, budget)
127 | break
128 |
129 | # Extract messages and update multi_modal_data
130 | tool_messages = []
131 | new_images_this_turn = []
132 | for tool_response in tool_responses:
133 | # Create message from tool response
134 | if tool_response.image or tool_response.video:
135 | # Multi-modal content with structured format
136 | content = []
137 | if tool_response.image:
138 | content.append({"type": "image"})
139 | if tool_response.video:
140 | content.append({"type": "video"})
141 | if tool_response.text:
142 | content.append({"type": "text", "text": tool_response.text})
143 | message = {"role": "tool", "content": content}
144 | else:
145 | # Text-only content
146 | message = {"role": "tool", "content": tool_response.text or ""}
147 |
148 | tool_messages.append(message)
149 |
150 | # Handle image data
151 | if tool_response.image:
152 | if image_data is None:
153 | image_data = []
154 | elif not isinstance(image_data, list):
155 | image_data = [image_data]
156 |
157 | # Add new image data
158 | if isinstance(tool_response.image, list):
159 | image_data.extend(tool_response.image)
160 | new_images_this_turn.extend(tool_response.image)
161 | else:
162 | image_data.append(tool_response.image)
163 | new_images_this_turn.append(tool_response.image)
164 |
165 | # Handle video data
166 | if tool_response.video:
167 | # Currently not supported, raise informative error
168 | logger.warning("Multimedia type 'video' is not currently supported. Only 'image' is supported.")
169 | raise NotImplementedError(
170 | "Multimedia type 'video' is not currently supported. Only 'image' is supported."
171 | )
172 |
173 | # append tool_response_ids
174 | if self.processor is not None:
175 | raw_tool_response = await self.loop.run_in_executor(
176 | None,
177 | lambda messages=tool_messages: self.processor.apply_chat_template(
178 | messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
179 | ),
180 | )
181 | # Use only the new images from this turn for processing tool responses
182 | current_images = new_images_this_turn if new_images_this_turn else None
183 | model_inputs = self.processor(text=[raw_tool_response], images=current_images, return_tensors="pt")
184 | tool_response_ids = model_inputs.pop("input_ids").squeeze(0).tolist()
185 | else:
186 | tool_response_ids = await self.loop.run_in_executor(
187 | None,
188 | lambda messages=tool_messages: self.tokenizer.apply_chat_template(
189 | messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs
190 | ),
191 | )
192 | tool_response_ids = tool_response_ids[len(self.system_prompt) :]
193 |
194 | # NOTE: last turn should not be user turn, or the EOS token reward
195 | # can't be propagated to previous token in GAE.
196 | if len(response_mask) + len(tool_response_ids) >= self.response_length:
197 | # self.server_manager._release_request(request_id, budget)
198 | break
199 |
200 | prompt_ids += tool_response_ids
201 | response_mask += [0] * len(tool_response_ids)
202 | user_turns += 1
203 |
204 | response_ids = prompt_ids[-len(response_mask) :]
205 | prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]
206 |
207 | multi_modal_data = {"image": image_data} if image_data is not None else {}
208 |
209 | output = AgentLoopOutput(
210 | prompt_ids=prompt_ids,
211 | response_ids=response_ids[: self.response_length],
212 | response_mask=response_mask[: self.response_length],
213 | multi_modal_data=multi_modal_data,
214 | num_turns=user_turns + assistant_turns + 1,
215 | metrics=metrics,
216 | )
217 | return output
218 |
219 | async def _call_tool(self, tool_call: FunctionCall, tools_kwargs: dict[str, Any]) -> ToolResponse:
220 | """Call tool and return tool response."""
221 | tool, instance_id = None, None
222 | try:
223 | # TODO: append malformed tool_call to the prompt: invalid function name or arguments
224 | tool_name = tool_call.name
225 | tool_args = json.loads(tool_call.arguments)
226 | tool = self.tools[tool_name]
227 | kwargs = tools_kwargs.get(tool_name, {})
228 | ################################### rStar ###################################
229 | instance_id, _ = await tool.create(
230 | create_kwargs=kwargs.get("create_kwargs", {}),
231 | history_tool_calls=tools_kwargs.get("history_tool_calls", []),
232 | )
233 | #############################################################################
234 | tool_execution_response, _, _ = await tool.execute(instance_id, tool_args)
235 | except Exception as e:
236 | logger.warning(f"Error when executing tool: {e}")
237 | return ToolResponse(
238 | text=f"Error when executing tool: {e}",
239 | )
240 | finally:
241 | if tool and instance_id:
242 | await tool.release(instance_id)
243 |
244 | tool_response_text = tool_execution_response.text
245 | if tool_response_text and len(tool_response_text) > self.max_tool_response_length:
246 | if self.tool_response_truncate_side == "left":
247 | tool_response_text = tool_response_text[: self.max_tool_response_length] + "...(truncated)"
248 | elif self.tool_response_truncate_side == "right":
249 | tool_response_text = "(truncated)..." + tool_response_text[-self.max_tool_response_length :]
250 | else:
251 | length = self.max_tool_response_length // 2
252 | tool_response_text = tool_response_text[:length] + "...(truncated)..." + tool_response_text[-length:]
253 |
254 | # Create ToolResponse from tool execution result
255 | tool_response_kwargs = {"text": tool_response_text}
256 |
257 | # Add multimedia data if present
258 | for attr_name in ["image", "video"]:
259 | if hasattr(tool_execution_response, attr_name):
260 | attr_value = getattr(tool_execution_response, attr_name)
261 | if attr_value is not None:
262 | tool_response_kwargs[attr_name] = attr_value
263 |
264 | return ToolResponse(**tool_response_kwargs)
265 |
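A self-contained sketch of the token budgeting used in the loop above: assistant tokens are appended to `response_mask` as 1s (trained on), tool-response tokens as 0s (masked from the loss), and each turn's `max_new_tokens` is whatever remains of `response_length` (the counts below are illustrative):

```python
response_length = 8192
response_mask = []
response_mask += [1] * 1500  # assistant turn: model-generated tokens
response_mask += [0] * 300   # tool response: appended to the context but masked out
max_new_tokens = response_length - len(response_mask)  # budget for the next turn
assert max_new_tokens == 6392
```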
--------------------------------------------------------------------------------
/fused_compute_score/prime_math/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 PRIME team and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Answer checker API that uses sympy to simplify expressions and check for equality.
16 |
17 | Call grade_answer(given_answer: str, ground_truth: str).
18 |
19 | FROM: https://github.com/openai/prm800k/blob/main/prm800k/grading/grader.py
20 | """
21 |
22 | import contextlib
23 | import math
24 | import re
25 |
26 | import sympy
27 | from pylatexenc import latex2text
28 | from sympy.parsing import sympy_parser
29 |
30 | from . import math_normalize
31 | from .grader import math_equal, timeout_limit
32 |
33 | # import math_normalize
34 | # from grader import math_equal
35 |
36 | # sympy might hang -- we don't care about trying to be lenient in these cases
37 | BAD_SUBSTRINGS = ["^{", "^("]
38 | BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"]
39 | TUPLE_CHARS = "()[]"
40 |
41 |
42 | def _sympy_parse(expr: str):
43 | """Parses an expression with sympy."""
44 | py_expr = expr.replace("^", "**")
45 | return sympy_parser.parse_expr(
46 | py_expr,
47 | transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)),
48 | )
49 |
50 |
51 | def _parse_latex(expr: str) -> str:
52 | """Attempts to parse latex to an expression sympy can read."""
53 | expr = expr.replace("\\tfrac", "\\frac")
54 | expr = expr.replace("\\dfrac", "\\frac")
55 | expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers.
56 | expr = latex2text.LatexNodes2Text().latex_to_text(expr)
57 |
58 | # Replace the specific characters that this parser uses.
59 | expr = expr.replace("√", "sqrt")
60 | expr = expr.replace("π", "pi")
61 | expr = expr.replace("∞", "inf")
62 | expr = expr.replace("∪", "U")
63 | expr = expr.replace("·", "*")
64 | expr = expr.replace("×", "*")
65 |
66 | return expr.strip()
67 |
68 |
69 | def _is_float(num: str) -> bool:
70 | try:
71 | float(num)
72 | return True
73 | except ValueError:
74 | return False
75 |
76 |
77 | def _is_int(x: float) -> bool:
78 | try:
79 | return abs(x - int(round(x))) <= 1e-7
80 | except Exception:
81 | return False
82 |
83 |
84 | def _is_frac(expr: str) -> bool:
85 | return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr))
86 |
87 |
88 | def _str_is_int(x: str) -> bool:
89 | try:
90 | x = _strip_properly_formatted_commas(x)
91 | x = float(x)
92 | return abs(x - int(round(x))) <= 1e-7
93 | except Exception:
94 | return False
95 |
96 |
97 | def _str_to_int(x: str) -> int:
98 | x = x.replace(",", "")
99 | x = float(x)
100 | return int(x)
101 |
102 |
103 | def _inject_implicit_mixed_number(step: str):
104 | """
105 | Automatically make a mixed number evalable
106 | e.g. 7 3/4 => 7+3/4
107 | """
108 | p1 = re.compile("([0-9]) +([0-9])")
109 | step = p1.sub("\\1+\\2", step)  # implicit mixed numbers, e.g. "7 3/4" -> "7+3/4"
110 | return step
111 |
112 |
113 | def _strip_properly_formatted_commas(expr: str):
114 | # We want to be careful because we don't want to strip tuple commas
115 | p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)")
116 | while True:
117 | next_expr = p1.sub("\\1\\3\\4", expr)
118 | if next_expr == expr:
119 | break
120 | expr = next_expr
121 | return next_expr
122 |
123 |
124 | def _normalize(expr: str) -> str:
125 | """Normalize answer expressions."""
126 | if expr is None:
127 | return None
128 |
129 | # Remove enclosing `\text{}`.
130 | m = re.search(r"^\\text\{(?P<text>.+?)\}$", expr)
131 | if m is not None:
132 | expr = m.group("text")
133 |
134 | expr = expr.replace("\\%", "%")
135 | expr = expr.replace("\\$", "$")
136 | expr = expr.replace("$", "")
137 | expr = expr.replace("%", "")
138 | expr = expr.replace(" or ", " , ")
139 | expr = expr.replace(" and ", " , ")
140 |
141 | expr = expr.replace("million", "*10^6")
142 | expr = expr.replace("billion", "*10^9")
143 | expr = expr.replace("trillion", "*10^12")
144 |
145 | for unit in [
146 | "degree",
147 | "cm",
148 | "centimeter",
149 | "meter",
150 | "mile",
151 | "second",
152 | "minute",
153 | "hour",
154 | "day",
155 | "week",
156 | "month",
157 | "year",
158 | "foot",
159 | "feet",
160 | "inch",
161 | "yard",
162 | "liter",
163 | ]:
164 | expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
165 | expr = re.sub("\^ *\\\\circ", "", expr)
166 |
167 | if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
168 | expr = expr[1:-1]
169 |
170 | expr = re.sub(",\\\\! *", "", expr)
171 | if _is_float(expr) and _is_int(float(expr)):
172 | expr = str(int(round(float(expr))))
173 | if "\\" in expr:
174 | with contextlib.suppress(Exception):
175 | expr = _parse_latex(expr)
176 |
177 | # edge case with mixed numbers and negative signs
178 | expr = re.sub("- *", "-", expr)
179 |
180 | expr = _inject_implicit_mixed_number(expr)
181 |
182 | # don't be case sensitive for text answers
183 | expr = expr.lower()
184 |
185 | if _str_is_int(expr):
186 | expr = str(_str_to_int(expr))
187 |
188 | return expr
189 |
190 |
191 | def count_unknown_letters_in_expr(expr: str):
192 | expr = expr.replace("sqrt", "")
193 | expr = expr.replace("frac", "")
194 | letters_in_expr = set([x for x in expr if x.isalpha()])
195 | return len(letters_in_expr)
196 |
197 |
198 | def should_allow_eval(expr: str):
199 | # we don't want to try parsing unknown text or functions of more than two variables
200 | if count_unknown_letters_in_expr(expr) > 2:
201 | return False
202 |
203 | for bad_string in BAD_SUBSTRINGS:
204 | if bad_string in expr:
205 | return False
206 |
207 | return all(re.search(bad_regex, expr) is None for bad_regex in BAD_REGEXES)
208 |
209 |
210 | @timeout_limit(seconds=10)
211 | def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
212 | are_equal = False
213 | try:
214 | expr = f"({ground_truth_normalized})-({given_normalized})"
215 | if should_allow_eval(expr):
216 | sympy_diff = _sympy_parse(expr)
217 | simplified = sympy.simplify(sympy_diff)
218 | if simplified == 0:
219 | are_equal = True
220 | except Exception:
221 | pass
222 | return are_equal
223 |
224 |
225 | def split_tuple(expr: str):
226 | """
227 | Split the elements in a tuple/interval, while handling well-formatted commas in large numbers
228 | """
229 | expr = _strip_properly_formatted_commas(expr)
230 | if len(expr) == 0:
231 | return []
232 | if (
233 | len(expr) > 2
234 | and expr[0] in TUPLE_CHARS
235 | and expr[-1] in TUPLE_CHARS
236 | and all([ch not in expr[1:-1] for ch in TUPLE_CHARS])
237 | ):
238 | elems = [elem.strip() for elem in expr[1:-1].split(",")]
239 | else:
240 | elems = [expr]
241 | return elems
242 |
243 |
244 | def grade_answer(given_answer: str, ground_truth: str) -> bool:
245 | """
246 | The answer will be considered correct if:
247 | (a) it normalizes to the same string as the ground truth answer
248 | OR
249 | (b) sympy can simplify the difference between the expressions to 0
250 | """
251 | if given_answer is None:
252 | return False
253 |
254 | ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth)
255 | given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer)
256 |
257 | # be at least as lenient as mathd
258 | if ground_truth_normalized_mathd == given_answer_normalized_mathd:
259 | return True
260 |
261 | ground_truth_normalized = _normalize(ground_truth)
262 | given_normalized = _normalize(given_answer)
263 |
264 | if ground_truth_normalized is None:
265 | return False
266 |
267 | if ground_truth_normalized == given_normalized:
268 | return True
269 |
270 | if len(given_normalized) == 0:
271 | return False
272 |
273 | ground_truth_elems = split_tuple(ground_truth_normalized)
274 | given_elems = split_tuple(given_normalized)
275 |
276 | if (
277 | len(ground_truth_elems) > 1
278 | and (ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1])
279 | or len(ground_truth_elems) != len(given_elems)
280 | ):
281 | is_correct = False
282 | else:
283 | for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True):
284 | if _is_frac(ground_truth_elem) and _is_frac(given_elem):
285 | # if fractions aren't reduced, then shouldn't be marked as correct
286 | # so, we don't want to allow sympy.simplify in this case
287 | is_correct = ground_truth_elem == given_elem
288 | elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
289 | # if the ground truth answer is an integer, we require the given answer to be a strict match
290 | # (no sympy.simplify)
291 | is_correct = False
292 | else:
293 | try:
294 | is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
295 | except Exception as e:
296 | # if there's an error, we'll just say it's not correct
297 | is_correct = False
298 | print(f"Error: {e} from are_equal_under_sympy, {ground_truth_elem}, {given_elem}")
299 | if not is_correct:
300 | break
301 |
302 | return is_correct
303 |
304 |
305 | def remove_boxed(s):
306 | left = "\\boxed{"
307 | try:
308 | assert s[: len(left)] == left
309 | assert s[-1] == "}"
310 | return s[len(left) : -1]
311 | except Exception:
312 | return None
313 |
314 |
315 | def _last_boxed_only_string(string):
316 | idx = string.rfind("\\boxed")
317 | if idx < 0:
318 | idx = string.rfind("\\fbox")
319 | if idx < 0:
320 | return None
321 |
322 | i = idx
323 | left_brace_idx = None
324 | right_brace_idx = None
325 | num_left_braces_open = 0
326 | while i < len(string):
327 | if string[i] == "{":
328 | num_left_braces_open += 1
329 | if left_brace_idx is None:
330 | left_brace_idx = i
331 | elif string[i] == "}":
332 | num_left_braces_open -= 1
333 | if num_left_braces_open == 0:
334 | right_brace_idx = i
335 | break
336 |
337 | i += 1
338 |
339 | if left_brace_idx is None or right_brace_idx is None:
340 | return None
341 |
342 | return string[left_brace_idx + 1 : right_brace_idx].strip()
343 |
344 |
345 | def match_answer(response):
346 | is_matched = False
347 | for ans_marker in ["answer:", "answer is", "answers are"]:
348 | ans_idx = response.lower().rfind(ans_marker)
349 | if ans_idx != -1:
350 | is_matched = True
351 | response = response[ans_idx + len(ans_marker) :].strip()
352 | if response.endswith("\n"):
353 | response = response[:-2]
354 |
355 | for ans_marker in ["is answer", "is the answer", "are answers", "are the answers"]:
356 | ans_idx = response.lower().rfind(ans_marker)
357 | if ans_idx != -1:
358 | is_matched = True
359 | response = response[:ans_idx].strip()
360 | if response.endswith("\n"):
361 | response = response[:-2]
362 |
363 | # Find boxed
364 | ans_boxed = _last_boxed_only_string(response)
365 | if ans_boxed:
366 | is_matched = True
367 | response = ans_boxed
368 |
369 | if ". " in response:
370 | dot_idx = response.lower().rfind(". ")
371 | if dot_idx != -1:
372 | response = response[:dot_idx].strip()
373 |
374 | for ans_marker in ["be ", "is ", "are ", "=", ": ", "get ", "be\n", "is\n", "are\n", ":\n", "get\n"]:
375 | ans_idx = response.lower().rfind(ans_marker)
376 | if ans_idx != -1:
377 | is_matched = True
378 | response = response[ans_idx + len(ans_marker) :].strip()
379 | if response.endswith("\n"):
380 | response = response[:-2]
381 |
382 | is_matched = is_matched if any([c.isdigit() for c in response]) else False # answer must have a digit
383 | # Grade
384 | return is_matched, response
385 |
386 |
387 | def compute_score(model_output: str, ground_truth: str) -> tuple[bool, bool, str]:
388 | model_output = str(model_output)
389 | ground_truth = str(ground_truth)
390 |
391 | is_matched, extracted_model_output = match_answer(model_output)
392 | format_correctness = "Step 2:" in model_output and "\\box" in model_output
393 |
394 | # grade simple algebra questions. if succeeded, return; otherwise, proceed to more complex grading
395 | if grade_answer(extracted_model_output, ground_truth):
396 | return True, True, extracted_model_output
397 |
398 | try:
399 | if "\pi" in extracted_model_output or "\pi" in ground_truth:
400 | equivs = []
401 | for pi in [math.pi, 3.14]:
402 | equivs.append(math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi))
403 | is_correct = any(equivs)
404 | else:
405 | is_correct = math_equal(extracted_model_output, ground_truth, timeout=True)
406 | except Exception:
407 | is_correct = False
408 |
409 | return is_correct, format_correctness, extracted_model_output
410 |
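A usage sketch for the grading API above (the assertions follow from tracing the code; the import path matches this repo's layout):

```python
from fused_compute_score.prime_math import grade_answer, compute_score

assert grade_answer("\\frac{1}{2}", "0.5")  # both normalize to \frac{1}{2}
assert grade_answer("1,000", "1000")        # well-formatted commas are stripped
ok, fmt_ok, extracted = compute_score("The answer is \\boxed{42}.", "42")
assert ok and extracted == "42"
```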
--------------------------------------------------------------------------------
/rstar2_agent/main_rstar2_agent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | """
5 | The difference between this file and verl/trainer/main_ppo.py is the use of RStar2AgentRayTrainer instead of RayPPOTrainer.
6 | """
7 |
8 | import os
9 | import socket
10 |
11 | import hydra
12 | import ray
13 | from omegaconf import OmegaConf
14 |
15 | from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
16 | from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
17 | from verl.trainer.ppo.reward import load_reward_manager
18 | from verl.utils.device import is_cuda_available
19 |
20 | from .rstar2_agent_ray_trainer import RStar2AgentRayTrainer
21 |
22 |
23 | @hydra.main(config_path="config", config_name="rstar2_agent_trainer", version_base=None)
24 | def main(config):
25 | """Main entry point for PPO training with Hydra configuration management.
26 |
27 | Args:
28 | config: Hydra configuration object containing training parameters.
29 | """
30 | run_ppo(config)
31 |
32 |
33 | # Define a function to run the PPO-like training process
34 | def run_ppo(config) -> None:
35 | """Initialize Ray cluster and run distributed PPO training process.
36 |
37 | Args:
38 | config: Training configuration object containing all necessary parameters
39 | for distributed PPO training including Ray initialization settings,
40 | model paths, and training hyperparameters.
41 | """
42 | # Check if Ray is not initialized
43 | if not ray.is_initialized():
44 | # Initialize Ray with a local cluster configuration
45 | # Set environment variables in the runtime environment to control tokenizer parallelism,
46 | # NCCL debug level, VLLM logging level, and allow runtime LoRA updating
47 | # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration
48 | ray.init(
49 | runtime_env=get_ppo_ray_runtime_env(),
50 | num_cpus=config.ray_init.num_cpus,
51 | )
52 |
53 | # Create a remote instance of the TaskRunner class, and
54 | # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
55 | if (
56 | is_cuda_available
57 | and config.global_profiler.tool == "nsys"
58 | and config.global_profiler.get("steps") is not None
59 | and len(config.global_profiler.get("steps", [])) > 0
60 | ):
61 | from verl.utils.import_utils import is_nvtx_available
62 |
63 | assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'"
64 | nsight_options = OmegaConf.to_container(
65 | config.global_profiler.global_tool_config.nsys.controller_nsight_options
66 | )
67 | runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
68 | else:
69 | runner = TaskRunner.remote()
70 | ray.get(runner.run.remote(config))
71 |
72 | # [Optional] get the path of the timeline trace file from the configuration, default to None
73 | # This file is used for performance analysis
74 | timeline_json_file = config.ray_init.get("timeline_json_file", None)
75 | if timeline_json_file:
76 | ray.timeline(filename=timeline_json_file)
77 |
78 |
79 | @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
80 | class TaskRunner:
81 | """Ray remote class for executing distributed PPO training tasks.
82 |
83 | This class encapsulates the main training logic and runs as a Ray remote actor
84 | to enable distributed execution across multiple nodes and GPUs.
85 |
86 | Attributes:
87 | role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes
88 | mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation
89 | """
90 |
91 | def __init__(self):
92 | self.role_worker_mapping = {}
93 | self.mapping = {}
94 |
95 | def add_actor_rollout_worker(self, config):
96 | """Add actor rollout worker based on the actor strategy."""
97 | from verl.single_controller.ray import RayWorkerGroup
98 |
99 | if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
100 | from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
101 |
102 | actor_rollout_cls = (
103 | AsyncActorRolloutRefWorker
104 | if config.actor_rollout_ref.rollout.mode == "async"
105 | else ActorRolloutRefWorker
106 | )
107 | ray_worker_group_cls = RayWorkerGroup
108 |
109 | elif config.actor_rollout_ref.actor.strategy == "megatron":
110 | from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
111 |
112 | actor_rollout_cls = (
113 | AsyncActorRolloutRefWorker
114 | if config.actor_rollout_ref.rollout.mode == "async"
115 | else ActorRolloutRefWorker
116 | )
117 | ray_worker_group_cls = RayWorkerGroup
118 |
119 | else:
120 | raise NotImplementedError
121 |
122 | from verl.trainer.ppo.ray_trainer import Role
123 |
124 | self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls)
125 |
126 | return actor_rollout_cls, ray_worker_group_cls
127 |
128 | def add_critic_worker(self, config):
129 | """Add critic worker to role mapping."""
130 | if config.critic.strategy in {"fsdp", "fsdp2"}:
131 | use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
132 | if use_legacy_worker_impl in ["auto", "enable"]:
133 | from verl.workers.fsdp_workers import CriticWorker
134 | elif use_legacy_worker_impl == "disable":
135 | from verl.workers.roles import CriticWorker
136 |
137 | print("Using new worker implementation")
138 | else:
139 | raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
140 |
141 | elif config.critic.strategy == "megatron":
142 | from verl.workers.megatron_workers import CriticWorker
143 |
144 | else:
145 | raise NotImplementedError
146 |
147 | from verl.trainer.ppo.ray_trainer import Role
148 |
149 | self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)
150 |
151 | def init_resource_pool_mgr(self, config):
152 | """Initialize resource pool manager."""
153 | from verl.trainer.ppo.ray_trainer import Role
154 |
155 | global_pool_id = "global_pool"
156 | resource_pool_spec = {
157 | global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
158 | }
159 | self.mapping[Role.ActorRollout] = global_pool_id
160 | self.mapping[Role.Critic] = global_pool_id
161 | from verl.trainer.ppo.ray_trainer import ResourcePoolManager
162 |
163 | resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping)
164 | return resource_pool_manager
165 |
166 | def add_reward_model_worker(self, config):
167 | """Add reward model worker if enabled."""
168 | from verl.trainer.ppo.ray_trainer import Role
169 |
170 | if config.reward_model.enable:
171 | if config.reward_model.strategy in {"fsdp", "fsdp2"}:
172 | from verl.workers.fsdp_workers import RewardModelWorker
173 | elif config.reward_model.strategy == "megatron":
174 | from verl.workers.megatron_workers import RewardModelWorker
175 | else:
176 | raise NotImplementedError
177 | self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
178 | self.mapping[Role.RewardModel] = "global_pool"
179 |
180 | def add_ref_policy_worker(self, config, ref_policy_cls):
181 | """Add reference policy worker if KL loss or KL reward is used."""
182 | from verl.trainer.ppo.ray_trainer import Role
183 |
184 | if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
185 | self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls)
186 | self.mapping[Role.RefPolicy] = "global_pool"
187 |
188 | def run(self, config):
189 | """Execute the main PPO training workflow.
190 |
191 | This method sets up the distributed training environment, initializes
192 | workers, datasets, and reward functions, then starts the training process.
193 |
194 | Args:
195 | config: Training configuration object containing all parameters needed
196 | for setting up and running the PPO training process.
197 | """
198 | # Print the initial configuration. `resolve=True` will evaluate symbolic values.
199 | from pprint import pprint
200 |
201 | from omegaconf import OmegaConf
202 |
203 | from verl.utils.fs import copy_to_local
204 |
205 | print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
206 | pprint(OmegaConf.to_container(config, resolve=True))
207 | OmegaConf.resolve(config)
208 |
209 | # Download the checkpoint from HDFS to the local machine.
210 | # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
211 | local_path = copy_to_local(
212 | config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
213 | )
214 |
215 | # Instantiate the tokenizer and processor.
216 | from verl.utils import hf_processor, hf_tokenizer
217 |
218 | trust_remote_code = config.data.get("trust_remote_code", False)
219 | tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
220 | # Used for multimodal LLM, could be None
221 | processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
222 |
223 | actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)
224 | self.add_critic_worker(config)
225 |
226 |         # We adopt a multi-source reward function here:
227 |         # - for a rule-based RM, we directly call a reward-scoring function
228 |         # - for a model-based RM, we call a model
229 |         # - for code-related prompts, we send them to a sandbox if test cases are available
230 |         # Finally, we combine all the rewards together.
231 |         # The reward type depends on the tag of the data.
232 | self.add_reward_model_worker(config)
233 |
234 | # Add a reference policy worker if KL loss or KL reward is used.
235 | self.add_ref_policy_worker(config, actor_rollout_cls)
236 |
237 | ################################### rStar ###################################
238 | # support data.filter_overlong_prompts
239 | if config.actor_rollout_ref.model.get("custom_chat_template", None) is not None:
240 | if processor is not None:
241 | processor.chat_template = config.actor_rollout_ref.model.custom_chat_template
242 | tokenizer.chat_template = config.actor_rollout_ref.model.custom_chat_template
243 |
244 | tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path
245 | tool_list = []
246 | if tool_config_path is not None:
247 |             from verl.tools.utils.tool_registry import ToolType, OpenAIFunctionToolSchema
248 | tools_config = OmegaConf.load(tool_config_path)
249 | for tool_config in tools_config.tools:
250 | tool_type = ToolType(tool_config.config.type)
251 | assert tool_type is ToolType.NATIVE
252 | if tool_config.get("tool_schema", None) is None:
253 | tool_schema = None
254 | else:
255 | tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True)
256 | tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict).model_dump(exclude_unset=True, exclude_none=True)
257 | tool_list.append(tool_schema)
258 | #############################################################################
259 |
260 | # Load the reward manager for training and validation.
261 | reward_fn = load_reward_manager(
262 | config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
263 | )
264 | val_reward_fn = load_reward_manager(
265 | config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
266 | )
267 |
268 | resource_pool_manager = self.init_resource_pool_mgr(config)
269 |
270 | from verl.utils.dataset.rl_dataset import collate_fn
271 |
272 | # Create training and validation datasets.
273 | train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True)
274 | val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False)
275 | ################################### rStar ###################################
276 | train_dataset.dataframe = train_dataset.maybe_filter_out_long_prompts(train_dataset.dataframe, tools=tool_list)
277 | val_dataset.dataframe = val_dataset.maybe_filter_out_long_prompts(val_dataset.dataframe, tools=tool_list)
278 | #############################################################################
279 | train_sampler = create_rl_sampler(config.data, train_dataset)
280 |
281 | # Initialize the rstar2 agent PPO trainer.
282 | trainer = RStar2AgentRayTrainer(
283 | config=config,
284 | tokenizer=tokenizer,
285 | processor=processor,
286 | role_worker_mapping=self.role_worker_mapping,
287 | resource_pool_manager=resource_pool_manager,
288 | ray_worker_group_cls=ray_worker_group_cls,
289 | reward_fn=reward_fn,
290 | val_reward_fn=val_reward_fn,
291 | train_dataset=train_dataset,
292 | val_dataset=val_dataset,
293 | collate_fn=collate_fn,
294 | train_sampler=train_sampler,
295 | )
296 | # Initialize the workers of the trainer.
297 | trainer.init_workers()
298 | # Start the training process.
299 | trainer.fit()
300 |
301 |
302 | if __name__ == "__main__":
303 | main()
304 |
--------------------------------------------------------------------------------
/rstar2_agent/tools/code_judge_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import json
5 | import aiohttp
6 | import asyncio
7 | import traceback
8 | import os
9 | import datetime
10 |
11 | from typing import Dict, List, Literal, Callable, Optional
12 |
13 | # Global variable to store the path for failed submissions
14 | _failed_submissions_path = os.path.expanduser("~")
15 |
16 |
17 | def set_failed_submissions_path(path: str):
18 | """
19 | Set the path where failed submissions will be saved.
20 |
21 | Args:
22 | path: The directory path to save failed submissions
23 | """
24 | global _failed_submissions_path
25 | _failed_submissions_path = os.path.expanduser(path)
26 | # Create directory if it doesn't exist
27 | os.makedirs(_failed_submissions_path, exist_ok=True)
28 | print(f"Failed submissions will be saved to: {_failed_submissions_path}")
29 |
30 |
31 | def get_failed_submissions_path() -> str:
32 | """
33 | Get the current path where failed submissions will be saved.
34 |
35 | Returns:
36 | The current path for saving failed submissions
37 | """
38 | return _failed_submissions_path
39 |
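# Illustrative usage of the two path helpers above; the directory name is an
# arbitrary example, not a project convention:
#
#     set_failed_submissions_path("~/rstar2_failed_submissions")
#     assert get_failed_submissions_path().endswith("rstar2_failed_submissions")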
40 |
41 | async def call_long_batch(
42 | url: str,
43 | submissions: List[Dict],
44 | session: aiohttp.ClientSession,
45 | max_retries: int = 4,
46 | backoff_factor: float = 0.5):
47 |
48 | sub_num = len(submissions)
49 | results = [None] * sub_num
50 | sub_ids = list(range(sub_num))
51 | attempt_count = 0
52 | while submissions and attempt_count < max_retries:
53 | attempt_count += 1
54 | try:
55 | data = {
56 | "type": "batch",
57 | "submissions": submissions
58 | }
59 | queue_timeouts = []
60 | async with session.post(url, json=data) as response:
61 | response.raise_for_status()
62 | response_json = await response.json()
63 | for sub_id, result in zip(sub_ids, response_json['results']):
64 | if result['reason'] != 'queue_timeout':
65 | results[sub_id] = result
66 | else:
67 | queue_timeouts.append((sub_id, submissions[sub_id]))
68 | submissions = [sub for _, sub in queue_timeouts]
69 | sub_ids = [sub_id for sub_id, _ in queue_timeouts]
70 | except aiohttp.ClientResponseError as e:
71 | print(f"Attempt {attempt_count}: Server responded with {e.status}")
72 | except (aiohttp.ClientError, asyncio.TimeoutError) as e:
73 | print(f"Attempt {attempt_count}: Caught {type(e).__name__}: {repr(e)}")
74 |         except Exception as e:
75 |             print(f"call_long_batch Error: {e}")
76 |             traceback.print_exc()
77 |         finally:
78 |             # Back off only when there is something left to retry.
79 |             if submissions:
80 |                 await asyncio.sleep(backoff_factor * (2 ** (attempt_count - 1)))
79 |
80 | # Save failed submissions to file if any remain after max retries
81 | if submissions:
82 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
83 | failed_file = os.path.join(_failed_submissions_path, f"failed_submissions_{timestamp}.json")
84 |
85 | failed_data = {
86 | "timestamp": timestamp,
87 | "url": url,
88 | "max_retries": max_retries,
89 | "failed_submissions": []
90 | }
91 |
92 | for sub_id, submission in zip(sub_ids, submissions):
93 | failed_data["failed_submissions"].append({
94 | "original_index": sub_id,
95 | "submission": submission
96 | })
97 |
98 | try:
99 | with open(failed_file, 'w', encoding='utf-8') as f:
100 | json.dump(failed_data, f, indent=2, ensure_ascii=False)
101 | print(f"Saved {len(submissions)} failed submissions to: {failed_file}")
102 | except Exception as e:
103 | print(f"Failed to save failed submissions: {e}")
104 |
105 | return results
106 |
107 |
108 | async def run_tool_calls_on_server_async(
109 | tool_calls: List,
110 | session: aiohttp.ClientSession,
111 | language: Literal["python", "cpp"] = "python",
112 | max_retries: int = 4,
113 | backoff_factor: float = 0.5,
114 |         generate_tool_call_code: Optional[Callable] = None,
115 |         generate_tool_call_input: Optional[Callable] = None,
116 | host_addr: str = "localhost",
117 | host_port: str = "8088"):
118 | submissions = []
119 | for tool_call in tool_calls:
120 | submissions.append({
121 | "type": language,
122 | "solution": generate_tool_call_code(tool_call),
123 | "input": generate_tool_call_input(tool_call),
124 | })
125 |
126 | url = f"http://{host_addr}:{host_port}/run/long-batch"
127 | results = await call_long_batch(url, submissions, session, max_retries, backoff_factor)
128 |
129 |     # Raise an error if any tool call still has no result after max retries.
130 |     failed_indices = [i for i, result in enumerate(results) if result is None]
131 |     if failed_indices:
132 |         raise RuntimeError(f"run_tool_calls_on_server_async failed for {len(failed_indices)} tool calls after {max_retries} attempts.")
134 |
135 | for i in range(len(results)):
136 | if results[i]['run_success'] and results[i]['success']:
137 | output_parts = []
138 | output_parts.append('Tool call success')
139 | if results[i]["stdout"]:
140 | output_parts.append(f'stdout: {results[i]["stdout"]}')
141 | if results[i]["stderr"]:
142 | output_parts.append(f'stderr: {results[i]["stderr"]}')
143 | output_parts.append(f'execution time: {results[i]["cost"]:.2f}s')
144 | results[i] = '\n'.join(output_parts)
145 | else:
146 | output_parts = []
147 | output_parts.append('Tool call failure')
148 | output_parts.append(f'reason: {results[i]["reason"]}')
149 | if results[i]["stdout"]:
150 | output_parts.append(f'stdout: {results[i]["stdout"]}')
151 | if results[i]["stderr"]:
152 | output_parts.append(f'stderr: {results[i]["stderr"]}')
153 | output_parts.append(f'execution time: {results[i]["cost"]:.2f}s')
154 | results[i] = '\n'.join(output_parts)
155 |
156 | return results
157 |
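# Minimal sketch of driving run_tool_calls_on_server_async with the generator
# helpers defined later in this file. The payload and the default
# localhost:8088 endpoint are illustrative assumptions:
#
#     async def demo():
#         calls = [{"name": "python_code_with_standard_io",
#                   "arguments": {"code": "print(input())", "input": "hi"}}]
#         async with aiohttp.ClientSession() as session:
#             outputs = await run_tool_calls_on_server_async(
#                 calls, session,
#                 generate_tool_call_code=generate_tool_call_code,
#                 generate_tool_call_input=generate_tool_call_input)
#         print(outputs[0])  # "Tool call success\nstdout: hi\n..."
#
#     asyncio.run(demo())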
158 |
159 | ### Generate tool call code
160 |
161 | code_template_setup = '''
162 | import os
163 | import base64
164 | import sys
165 | import ast
166 | import traceback
167 | from typing import Optional, Any
168 | import linecache
169 | from types import CodeType
170 | from contextlib import redirect_stdout, redirect_stderr
171 | from io import StringIO
172 |
173 | class CodeExecutionError(Exception):
174 | """Custom exception for code execution errors with line information"""
175 | def __init__(self, original_error: Exception, code: str, line_offset: int = 0):
176 | self.original_error = original_error
177 | self.code = code
178 | self.line_offset = line_offset
179 |
180 | # Get error line number
181 | if hasattr(original_error, 'lineno'):
182 | self.lineno = original_error.lineno
183 | else:
184 | tb = getattr(original_error, '__traceback__', None)
185 | if tb:
186 | while tb.tb_next:
187 | tb = tb.tb_next
188 | self.lineno = tb.tb_lineno
189 | else:
190 | self.lineno = -1
191 |
192 | # Adjust line number for code segment
193 | if self.lineno != -1:
194 | self.lineno += line_offset
195 |
196 | # Format error message
197 | error_type = type(original_error).__name__
198 | error_msg = str(original_error)
199 |
200 | if self.lineno != -1:
201 | # Get the problematic line
202 | lines = code.splitlines()
203 | if 0 <= self.lineno - 1 < len(lines):
204 | error_line = lines[self.lineno - 1]
205 | # Create error message with line information
206 | super().__init__(f"{error_type} at line {self.lineno}: {error_msg}\\n {error_line}")
207 | return
208 |
209 | super().__init__(f"{error_type}: {error_msg}")
210 |
211 | class PersistentExecutor:
212 | def __init__(self):
213 | self.exec_globals = {
214 | '__name__': '__main__',
215 | '__file__': '',
216 | '__builtins__': __builtins__
217 | }
218 |
219 | def split_code(self, code: str) -> tuple[str, Optional[str]]:
220 | """
221 | Intelligently split code into main body and last expression
222 |
223 | Args:
224 | code: The source code string
225 |
226 | Returns:
227 | tuple[str, Optional[str]]: (main code body, last expression if exists)
228 | """
229 | try:
230 | # Parse code into AST
231 | tree = ast.parse(code)
232 | if not tree.body:
233 | return code, None
234 |
235 |         # Check if the last statement is an expression statement (ast.Expr)
236 | last_node = tree.body[-1]
237 | if isinstance(last_node, ast.Expr):
238 | # Get the line range of the last expression
239 | last_expr_start = last_node.lineno
240 | last_expr_end = last_node.end_lineno if hasattr(last_node, 'end_lineno') else last_node.lineno
241 |
242 | # Split the code
243 | lines = code.splitlines()
244 | main_code = '\\n'.join(lines[:last_expr_start-1])
245 | last_expr = '\\n'.join(lines[last_expr_start-1:last_expr_end])
246 | return main_code, last_expr
247 | except SyntaxError as e:
248 | raise CodeExecutionError(e, code)
249 | return code, None
250 |
251 | def execute_code(self, code: str, replay_history_code: bool) -> None:
252 | """
253 | Execute code while maintaining persistent environment state.
254 | If the last line is an expression, its value will be printed to stdout.
255 |
256 | Args:
257 | code: The source code string to execute
258 | replay_history_code: If True, suppress stdout and stderr output
259 | """
260 | try:
261 | # Split code intelligently
262 | main_code, last_expr = self.split_code(code)
263 |
264 | # Set up output redirection if replay_history_code is True
265 | if replay_history_code:
266 | stdout_capture = StringIO()
267 | stderr_capture = StringIO()
268 | stdout_context = redirect_stdout(stdout_capture)
269 | stderr_context = redirect_stderr(stderr_capture)
270 | else:
271 | stdout_context = redirect_stdout(sys.stdout)
272 | stderr_context = redirect_stderr(sys.stderr)
273 |
274 | # Execute main code body
275 | if main_code:
276 | try:
277 | # Compile code to get better error line numbers
278 | compiled_code = compile(main_code, '', 'exec')
279 | with stdout_context, stderr_context:
280 | exec(compiled_code, self.exec_globals)
281 | except Exception as e:
282 | raise CodeExecutionError(e, main_code)
283 |
284 | # If there's a last expression, try to evaluate and print it
285 | if last_expr:
286 | try:
287 | # Compile expression to get better error line numbers
288 | compiled_expr = compile(last_expr, '', 'eval')
289 | with stdout_context, stderr_context:
290 | last_value = eval(compiled_expr, self.exec_globals)
291 |
292 | # Only print the result if not in replay mode
293 | if last_value is not None and not replay_history_code:
294 | print(repr(last_value), file=sys.stdout)
295 | except Exception as e:
296 | # Try executing as statement if evaluation fails
297 | try:
298 | compiled_stmt = compile(last_expr, '', 'exec')
299 | with stdout_context, stderr_context:
300 | exec(compiled_stmt, self.exec_globals)
301 | except Exception as e:
302 | # Calculate line offset for the last expression
303 | line_offset = len(main_code.splitlines()) if main_code else 0
304 | raise CodeExecutionError(e, last_expr, line_offset)
305 |
306 | except Exception as e:
307 | if replay_history_code:
308 | return
309 | if isinstance(e, CodeExecutionError):
310 | print(str(e), file=sys.stderr)
311 | else:
312 | traceback.print_exc(file=sys.stderr)
313 | os._exit(1)
314 | return
315 |
316 | persistent_executor = PersistentExecutor()
317 | '''
318 |
319 | code_template_exec = '''
320 | code_to_execute = base64.b64decode("{}".encode()).decode()
321 | persistent_executor.execute_code(code_to_execute, replay_history_code={})
322 | '''
323 |
324 | def combine_code_template(code_to_execute: str, history_code_to_execute: Optional[List[str]] = None) -> str:
325 | history_code_to_execute = history_code_to_execute or []
326 | final_code = code_template_setup
327 | for history_code in history_code_to_execute:
328 | final_code += code_template_exec.format(history_code, "True")
329 | final_code += code_template_exec.format(code_to_execute, "False")
330 | return final_code
331 |
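# combine_code_template re-executes every prior jupyter_code cell with output
# suppressed (replay_history_code=True) so the new cell runs in the accumulated
# interpreter state. The generated program is laid out as:
#
#     <code_template_setup>                                      # executor defs
#     persistent_executor.execute_code(<cell 1>, replay_history_code=True)
#     persistent_executor.execute_code(<cell 2>, replay_history_code=True)
#     persistent_executor.execute_code(<new cell>, replay_history_code=False)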
332 |
333 | def generate_tool_call_code(tool_call: Dict) -> str:
334 | import base64
335 |
336 | def jupyter_code_gencode(json_format_data: Dict) -> str:
337 | code_to_execute = base64.b64encode(json_format_data["arguments"]["code"].encode()).decode()
338 | history_code_to_execute = [
339 | base64.b64encode(tool_call_json["arguments"]["code"].encode()).decode()
340 | for tool_call_json in json_format_data.get("history_tool_calls", []) if tool_call_json["name"] == "jupyter_code"
341 | ]
342 | return combine_code_template(code_to_execute, history_code_to_execute)
343 |
344 | def python_code_with_standard_io_gencode(json_format_data: Dict) -> str:
345 | code_to_execute = base64.b64encode(json_format_data["arguments"]["code"].encode()).decode()
346 | return combine_code_template(code_to_execute)
347 |
348 | if tool_call["name"] == "jupyter_code":
349 | return jupyter_code_gencode(tool_call)
350 | elif tool_call["name"] == "python_code_with_standard_io":
351 | return python_code_with_standard_io_gencode(tool_call)
352 | else:
353 | raise ValueError(f"Unsupported tool call name: {tool_call['name']}")
354 |
355 |
356 | def generate_tool_call_input(tool_call: Dict) -> Optional[str]:
357 | if tool_call["name"] == "jupyter_code":
358 | return None
359 | elif tool_call["name"] == "python_code_with_standard_io":
360 | return tool_call["arguments"]["input"]
361 | else:
362 | raise ValueError(f"Unsupported tool call name: {tool_call['name']}")
363 |
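# For reference, the two tool-call shapes accepted by the dispatchers above:
#
#     {"name": "jupyter_code",
#      "arguments": {"code": "..."},
#      "history_tool_calls": [...]}    # stdin is None; history cells replayed
#
#     {"name": "python_code_with_standard_io",
#      "arguments": {"code": "...", "input": "..."}}    # stdin from "input"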
--------------------------------------------------------------------------------
/rstar2_agent/tools/request_processor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import asyncio
5 | import time
6 | import uuid
7 | import collections
8 | import aiohttp
9 | import traceback
10 | from typing import List, Dict, Any, Callable, Awaitable
11 |
12 | # Define the expected signature for the batch submission function
13 | # It should be an async callable that takes:
14 | # 1. A list of original request payloads (List[Any])
15 | # 2. The aiohttp.ClientSession instance
16 | # It should return:
17 | # 1. A list of results (List[Any]) where the order *strictly* matches the input payloads order.
18 | BatchSubmitFunc = Callable[[List[Any], aiohttp.ClientSession], Awaitable[List[Any]]]
19 |
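# A minimal conforming batch_submit_func, shown only to illustrate the required
# shape (the echo behavior is a placeholder, not a real backend):
#
#     async def echo_batch(payloads: List[Any], session: aiohttp.ClientSession) -> List[Any]:
#         # Must preserve order: result i corresponds to payloads[i].
#         return [{"echo": p} for p in payloads]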
20 |
21 | class RequestProcessor:
22 | """
23 | Manages batch submission concurrently using an injected batch submission function.
24 | Requests are buffered and processed by concurrent sender workers.
25 | """
26 | def __init__(self, batch_size: int, batch_timeout_seconds: float, session: aiohttp.ClientSession, concurrency: int, batch_submit_func: BatchSubmitFunc):
27 | """
28 | Initializes the Request Processor with concurrent sending and a generic submission function.
29 | Must be called within an event loop context.
30 |
31 | Args:
32 | batch_size: Maximum items per batch.
33 | batch_timeout_seconds: Timeout for gathering items into a batch.
34 | session: The aiohttp.ClientSession to pass to the submission function.
36 | concurrency: Maximum number of concurrent batches being sent to B.
37 | batch_submit_func: The async function to call for sending a batch.
38 | Must match the BatchSubmitFunc signature.
39 | """
40 | if batch_size <= 0 or concurrency <= 0:
41 | raise ValueError("batch_size and concurrency must be positive")
42 | if batch_timeout_seconds <= 0:
43 | print("Warning: batch_timeout_seconds <= 0, batching will be strictly based on batch_size or queue availability.")
44 |
45 | self._batch_size = batch_size
46 | self._batch_timeout_seconds = batch_timeout_seconds
47 | self._session = session
48 | self._concurrency = concurrency
49 | self._batch_submit_func = batch_submit_func # Store the injected function
50 |
51 | self._submission_queue = asyncio.Queue()
52 | self._pending_requests: Dict[str, Dict[str, Any]] = {} # {request_id: {"future": Future, "payload": payload}}
53 |
54 | self._semaphore = asyncio.Semaphore(concurrency)
55 |
56 | self._sender_workers: List[asyncio.Task] = []
57 | self._running = False
58 |
59 | # --- Statistics ---
60 | self._reset_stats_internal() # Initialize stats
61 | # --- End Statistics ---
62 |
63 | print(f"[{time.monotonic():.4f}] RequestProcessor initialized with concurrency={self._concurrency} (async thread).")
64 |
65 | def _reset_stats_internal(self):
66 | """Helper to initialize or reset statistics."""
67 | self._stats = {
68 | "total_batch_submission_duration_seconds": 0.0,
69 | "num_batches_submitted": 0,
70 | "num_successful_batches": 0,
71 | "num_failed_batches": 0,
72 | "total_items_processed_in_batches": 0,
73 | "actual_batch_sizes": [] # Stores the size of each batch
74 | }
75 |
76 | async def send_request(self, request_payload: Any, timeout: float = None):
77 | """
78 | Adds a single request to the buffer and waits for its result.
79 | This call is awaitable and provides the synchronous-like pattern.
80 | """
81 | if not self._running:
82 | raise RuntimeError("RequestProcessor is not running. Call .start() first.")
83 |
84 | request_id = str(uuid.uuid4())
85 |
86 | future = asyncio.get_running_loop().create_future()
87 | self._pending_requests[request_id] = {
88 | "future": future,
89 | "payload": request_payload
90 | }
91 |
92 | await self._submission_queue.put(request_id)
93 |
94 | try:
95 | result = await asyncio.wait_for(future, timeout=timeout)
96 | return result
97 | except asyncio.TimeoutError:
98 | if request_id in self._pending_requests:
99 | del self._pending_requests[request_id]
100 | print(f"[{time.monotonic():.4f}] Request {request_id[:6]}... timed out waiting for result.")
101 | raise
102 | except Exception as e:
103 | if request_id in self._pending_requests:
104 | del self._pending_requests[request_id]
105 | print(f"[{time.monotonic():.4f}] Request {request_id[:6]}... encountered error while waiting: {e}")
106 | raise
107 |
108 | async def send_requests(self, request_payloads: List[Any], timeout: float = None) -> List[Any]:
109 | """
110 | Submits multiple request payloads concurrently and waits for all their results.
111 | Returns results or exceptions in the same order as input payloads.
112 | Uses send_request internally and gathers futures.
113 | """
114 | if not self._running:
115 | raise RuntimeError("RequestProcessor is not running. Call .start() first.")
116 |
117 | if not request_payloads:
118 | return []
119 |
120 | # print(f"[{time.monotonic():.4f}] RequestProcessor.send_requests submitting {len(request_payloads)} individual requests concurrently...")
121 |
122 | tasks = []
123 | for payload in request_payloads:
124 | tasks.append(asyncio.create_task(self.send_request(payload, timeout=timeout)))
125 |
126 | results = await asyncio.gather(*tasks, return_exceptions=True)
127 |
128 | # print(f"[{time.monotonic():.4f}] RequestProcessor.send_requests received all results.")
129 |
130 | return results
131 |
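    # Typical lifecycle, sketched with the hypothetical `echo_batch` submit
    # function illustrated above the class:
    #
    #     async def demo():
    #         async with aiohttp.ClientSession() as session:
    #             proc = RequestProcessor(batch_size=8, batch_timeout_seconds=0.2,
    #                                     session=session, concurrency=2,
    #                                     batch_submit_func=echo_batch)
    #             await proc.start()
    #             results = await proc.send_requests([1, 2, 3], timeout=10.0)
    #             await proc.stop()
    #             proc.print_stats()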
132 | async def start(self):
133 | """
134 | Starts the concurrent sender worker tasks. Must be called within loop.
135 | """
136 | if self._running:
137 | print(f"[{time.monotonic():.4f}] RequestProcessor is already running.")
138 | return
139 | self._running = True
140 | self._sender_workers = [asyncio.create_task(self._sender_worker()) for _ in range(self._concurrency)]
141 | print(f"[{time.monotonic():.4f}] RequestProcessor started {self._concurrency} sender workers.")
142 |
143 | async def stop(self):
144 | if not self._running:
145 | print(f"[{time.monotonic():.4f}] RequestProcessor is not running.")
146 | return
147 |
148 | print(f"[{time.monotonic():.4f}] Stopping RequestProcessor. Signaling workers...")
149 | self._running = False
150 |
151 | await self._submission_queue.join()
152 | print(f"[{time.monotonic():.4f}] Submission queue joined. All buffered items processed by sender workers.")
153 |
154 | # Wait for sender workers to finish their current batch and exit their loops
155 | for worker in self._sender_workers:
156 | worker.cancel()
157 | try:
158 | await asyncio.gather(*self._sender_workers, return_exceptions=True)
159 | except asyncio.CancelledError:
160 | print(f"[{time.monotonic():.4f}] Sender workers cancelled as expected.")
161 | except Exception as e:
162 | print(f"[{time.monotonic():.4f}] Error during sender workers shutdown: {e}")
163 |
164 | print(f"[{time.monotonic():.4f}] All sender workers stopped.")
165 |
166 | print(f"[{time.monotonic():.4f}] Waiting for {len(self._pending_requests)} pending results...")
167 |         # gather() accepts futures directly; asyncio.create_task() requires a
168 |         # coroutine and would raise a TypeError here.
169 |         wait_tasks = [req_info["future"]
170 |                       for req_info in self._pending_requests.values()
171 |                       if not req_info["future"].done()]
170 |
171 | if wait_tasks:
172 | stop_results_timeout = self._batch_timeout_seconds * 5
173 | print(f"[{time.monotonic():.4f}] Waiting for remaining results with timeout {stop_results_timeout:.2f}s...")
174 | try:
175 | await asyncio.wait_for(asyncio.gather(*wait_tasks, return_exceptions=True), timeout=stop_results_timeout)
176 | print(f"[{time.monotonic():.4f}] All pending results awaited or timed out during stop.")
177 | except asyncio.TimeoutError:
178 | print(f"[{time.monotonic():.4f}] Warning: Timeout waiting for all pending results during stop.")
179 |
180 | else:
181 | print(f"[{time.monotonic():.4f}] No pending results to await.")
182 |
183 | if self._pending_requests:
184 | print(f"[{time.monotonic():.4f}] Warning: Stopping with {len(self._pending_requests)} requests still pending (futures not completed/timed out)!")
185 |
186 | print(f"[{time.monotonic():.4f}] RequestProcessor stopped.")
187 |
188 | async def _sender_worker(self):
189 | """
190 | A single worker coroutine that continuously tries to send batches.
191 | Controls its own access to concurrent sending via the semaphore.
192 | Runs within the async thread's event loop.
193 | """
194 | print(f"[{time.monotonic():.4f}] Sender worker started.")
195 | batch_gathering_timeout = self._batch_timeout_seconds
196 |
197 | try:
198 | while self._running or not self._submission_queue.empty():
199 | batch_item_ids = []
200 |
201 | try:
202 | first_item_id = await asyncio.wait_for(self._submission_queue.get(), timeout=batch_gathering_timeout)
203 | self._submission_queue.task_done()
204 | batch_item_ids.append(first_item_id)
205 |
206 | while len(batch_item_ids) < self._batch_size:
207 | try:
208 | next_item_id = self._submission_queue.get_nowait()
209 | self._submission_queue.task_done()
210 | batch_item_ids.append(next_item_id)
211 | except asyncio.QueueEmpty:
212 | break
213 |
214 |                 except asyncio.TimeoutError:
215 |                     # The timeout can only fire on the first get(), before any item
216 |                     # has been gathered, so simply loop and poll again.
217 |                     continue
219 |
220 | except Exception as e:
221 | print(f"[{time.monotonic():.4f}] Worker encountered error getting items: {e}")
222 | await asyncio.sleep(1.0)
223 | continue
224 |
225 |                 if batch_item_ids:
226 |                     # Acquire a semaphore permit before starting the potentially long-running batch submission
227 |                     async with self._semaphore:
228 |                         # Perform the actual batch sending using the injected function
229 |                         await self._perform_send_batch(batch_item_ids)
232 |
233 | except asyncio.CancelledError:
234 | print(f"[{time.monotonic():.4f}] Sender worker received cancellation signal.")
235 | except Exception as e:
236 | print(f"[{time.monotonic():.4f}] Sender worker encountered major error: {e}")
237 |
238 | print(f"[{time.monotonic():.4f}] Sender worker finished.")
239 |
240 |
241 | async def _perform_send_batch(self, batch_item_ids: List[str]):
242 | """
243 | Internal method to execute the batch submission using the injected function and process results.
244 | Assumes this method is called within the context of an acquired semaphore permit.
245 | Runs within the async thread's event loop.
246 | """
247 | batch_info = [] # [{"request_id": id, "payload": payload}]
248 | payloads_for_server = [] # List of just payloads to pass to the injected function
249 |
250 | # Ensure the original requests are still pending before forming the batch data
251 | valid_item_ids_for_batch = [req_id for req_id in batch_item_ids if req_id in self._pending_requests]
252 |
253 | if not valid_item_ids_for_batch:
254 | # print(f"[{time.monotonic():.4f}] Batch contains no valid pending items after worker picked them up.")
255 | return # Nothing valid to send
256 |
257 | # Build the payload list for the injected function using only valid IDs
258 | for req_id in valid_item_ids_for_batch:
259 | req_info = self._pending_requests[req_id] # Should exist based on valid_item_ids_for_batch
260 | batch_info.append({"request_id": req_id, "payload": req_info["payload"]})
261 | payloads_for_server.append(req_info["payload"])
262 |
263 | # --- CALL THE INJECTED BATCH SUBMISSION FUNCTION ---
264 | # print(f"[{time.monotonic():.4f}] Submitting batch of {len(payloads_for_server)} items using injected function...")
265 | self._stats["num_batches_submitted"] += 1
266 | self._stats["actual_batch_sizes"].append(len(payloads_for_server))
267 |
268 | start_time = time.monotonic()
269 | try:
270 | # Call the function provided during initialization
271 | # It must return results in the same order as input payloads_for_server
272 | results_list = await self._batch_submit_func(payloads_for_server, self._session)
273 |
274 | submission_duration = time.monotonic() - start_time
275 | self._stats["total_batch_submission_duration_seconds"] += submission_duration
276 | self._stats["num_successful_batches"] += 1
277 | self._stats["total_items_processed_in_batches"] += len(payloads_for_server)
278 |
279 |
280 | # Process the results returned by the injected function
281 | # The order of results_list is assumed to match the order of payloads_for_server
282 | if len(results_list) != len(batch_info):
283 | print(f"[{time.monotonic():.4f}] Warning: Injected function returned {len(results_list)} results, but batch had {len(batch_info)} items. Cannot reliably match results.")
284 | match_count = min(len(results_list), len(batch_info))
285 | else:
286 | match_count = len(batch_info)
287 |
288 | for i in range(match_count):
289 | req_id = batch_info[i]["request_id"] # Get the original ID
290 | result = results_list[i] # Get the corresponding result
291 |
292 | if req_id in self._pending_requests:
293 | req_info = self._pending_requests[req_id]
294 | future = req_info["future"]
295 | if not future.done():
296 | future.set_result(result)
297 | del self._pending_requests[req_id]
298 |                 else:
299 |                     print(f"[{time.monotonic():.4f}] Warning: Received result for unknown or already completed request ID {req_id[:6]}... Result: {result}")
303 |
304 | except Exception as e:
305 | submission_duration = time.monotonic() - start_time
306 | self._stats["total_batch_submission_duration_seconds"] += submission_duration # Still record time even on failure
307 | self._stats["num_failed_batches"] += 1
308 | # print error stack trace for debugging
309 | traceback.print_exc()
310 | print(f"[{time.monotonic():.4f}] Error calling or processing results from injected function for batch: {e}")
311 |             # Handle failure of the injected function: propagate the exception to
312 |             # every pending future in this batch so callers fail fast instead of
313 |             # waiting for the timeout/stop cleanup.
314 | for req_id in valid_item_ids_for_batch:
315 | if req_id in self._pending_requests:
316 | req_info = self._pending_requests[req_id]
317 | future = req_info["future"]
318 | if not future.done():
319 | future.set_exception(e)
320 | del self._pending_requests[req_id]
321 |
322 | def get_stats(self) -> Dict[str, Any]:
323 | """Returns the collected performance statistics."""
324 | stats_copy = self._stats.copy()
325 |         # Note: the duration total also includes time spent on failed batches,
326 |         # so this average is approximate when failures occurred.
327 |         if stats_copy["num_successful_batches"] > 0:
328 |             stats_copy["avg_successful_batch_submission_duration_seconds"] = \
329 |                 self._stats["total_batch_submission_duration_seconds"] / stats_copy["num_successful_batches"]
330 |         else:
331 |             stats_copy["avg_successful_batch_submission_duration_seconds"] = 0
331 |
332 | if stats_copy["num_batches_submitted"] > 0: # Calculate overall average if any batch was submitted
333 | stats_copy["avg_overall_batch_submission_duration_seconds"] = \
334 | self._stats["total_batch_submission_duration_seconds"] / stats_copy["num_batches_submitted"]
335 | else:
336 | stats_copy["avg_overall_batch_submission_duration_seconds"] = 0
337 |
338 | if self._stats["actual_batch_sizes"]:
339 | stats_copy["avg_actual_batch_size"] = sum(self._stats["actual_batch_sizes"]) / len(self._stats["actual_batch_sizes"])
340 | else:
341 | stats_copy["avg_actual_batch_size"] = 0
342 | return stats_copy
343 |
344 | def print_stats(self):
345 | """Prints the collected performance statistics."""
346 | stats_to_print = self.get_stats()
347 | print(f"[{time.monotonic():.4f}] --- RequestProcessor Statistics ---")
348 | for key, value in stats_to_print.items():
349 | if key == "actual_batch_sizes":
350 | if value: # Check if the list of batch sizes is not empty
351 | batch_size_counts = collections.Counter(value)
352 | # Format as a list of (batch_size, count) tuples, sorted by batch_size
353 | formatted_batch_sizes = sorted(batch_size_counts.items())
354 | print(f" {key}: {formatted_batch_sizes}")
355 | else:
356 | print(f" {key}: []") # Print an empty list if no batches were processed
357 | elif isinstance(value, float):
358 | print(f" {key}: {value:.4f}")
359 | else:
360 | # Handles other data types including other lists (if any)
361 | print(f" {key}: {value}")
362 | print(f"[{time.monotonic():.4f}] --- End Statistics ---")
363 |
364 | def reset_stats(self):
365 | """Resets all collected performance statistics to their initial values."""
366 | print(f"[{time.monotonic():.4f}] Resetting RequestProcessor statistics.")
367 | self._reset_stats_internal()
368 |
--------------------------------------------------------------------------------
/fused_compute_score/prime_math/grader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright (c) Microsoft Corporation.
16 | #
17 | # Permission is hereby granted, free of charge, to any person obtaining a copy
18 | # of this software and associated documentation files (the "Software"), to deal
19 | # in the Software without restriction, including without limitation the rights
20 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
21 | # copies of the Software, and to permit persons to whom the Software is
22 | # furnished to do so, subject to the following conditions:
23 | #
24 | # The above copyright notice and this permission notice shall be included in all
25 | # copies or substantial portions of the Software.
26 | #
27 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
28 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
29 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
30 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
31 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
33 | # SOFTWARE.
34 |
35 | # Copyright (c) 2023 OpenAI
36 | #
37 | # Permission is hereby granted, free of charge, to any person obtaining a copy
38 | # of this software and associated documentation files (the "Software"), to deal
39 | # in the Software without restriction, including without limitation the rights
40 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
41 | # copies of the Software, and to permit persons to whom the Software is
42 | # furnished to do so, subject to the following conditions:
43 |
44 | # The above copyright notice and this permission notice shall be included in all
45 | # copies or substantial portions of the Software.
46 | #
47 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
48 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
49 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
50 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
51 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
52 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
53 | # SOFTWARE.
54 |
55 | # Copyright (c) 2021 Dan Hendrycks
56 | #
57 | # Permission is hereby granted, free of charge, to any person obtaining a copy
58 | # of this software and associated documentation files (the "Software"), to deal
59 | # in the Software without restriction, including without limitation the rights
60 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
61 | # copies of the Software, and to permit persons to whom the Software is
62 | # furnished to do so, subject to the following conditions:
63 | #
64 | # The above copyright notice and this permission notice shall be included in all
65 | # copies or substantial portions of the Software.
66 | #
67 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
68 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
69 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
70 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
71 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
72 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
73 | # SOFTWARE.
74 |
75 | # Copyright 2024 PRIME team and/or its affiliates
76 | #
77 | # Licensed under the Apache License, Version 2.0 (the "License");
78 | # you may not use this file except in compliance with the License.
79 | # You may obtain a copy of the License at
80 | #
81 | # http://www.apache.org/licenses/LICENSE-2.0
82 | #
83 | # Unless required by applicable law or agreed to in writing, software
84 | # distributed under the License is distributed on an "AS IS" BASIS,
85 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
86 | # See the License for the specific language governing permissions and
87 | # limitations under the License.
88 | """
89 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
90 | - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
91 | - https://github.com/microsoft/ProphetNet/tree/master/CRITIC
92 | - https://github.com/openai/prm800k
93 | """
94 |
95 | import contextlib
96 | import math
97 | import re
98 | from math import isclose
99 |
100 | # sympy related
101 | from sympy import N, simplify
102 | from sympy.parsing.latex import parse_latex
103 | from sympy.parsing.sympy_parser import parse_expr
104 |
105 |
106 | def is_digit(s):
107 | try:
108 | if "{,}" in str(s):
109 | num = float(str(s).replace("{,}", ""))
110 | return True, num
111 |
112 | num = float(str(s).replace(",", ""))
113 | return True, num
114 | except ValueError:
115 | return False, None
116 |
117 |
118 | def normalize(answer, pi):
119 | # checking if answer is $ and removing $ in that case to compare
120 | if isinstance(answer, str) and bool(re.match(r"\$\d+(\.\d+)?", answer)):
121 | return answer[1:]
122 |
123 | # checking if answer is % or \\% and removing %
124 | if isinstance(answer, str) and (
125 | bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer))
126 | ):
127 | return answer.replace("\\%", "").replace("%", "")
128 |
129 | # handle base
130 | answer = handle_base(answer)
131 |
132 | # handle pi
133 | answer = handle_pi(answer, pi)
134 |
135 | return answer
136 |
137 |
138 | def handle_base(x):
139 | if isinstance(x, str) and "_" in x:
140 | # Due to base
141 | x = x.split("_")[0]
142 | x = float(x)
143 | return int(x)
144 | return x
145 |
146 |
147 | def handle_pi(string, pi):
148 |     if isinstance(string, str) and "\\pi" in string:
149 |         # Find the first occurrence of "\pi"
150 |         idx = string.find("\\pi")
151 | 
152 |         # Iterate over the string and find all occurrences of "\pi" with a valid previous character
153 |         while idx != -1:
154 |             if idx > 0 and string[idx - 1].isdigit():
155 |                 # Replace "\pi" with "*{pi}" if the previous character is a digit
156 |                 string = string[:idx] + f"*{pi}" + string[idx + 3 :]
157 |             else:
158 |                 # Replace "\pi" with "1*{pi}" if the previous character is not a digit
159 |                 string = string[:idx] + f"1*{pi}" + string[idx + 3 :]
160 | 
161 |             # Find the next occurrence of "\pi"
162 |             idx = string.find("\\pi", idx + 1)
163 |
164 | # Evaluate the expression using eval() function
165 | with contextlib.suppress(Exception):
166 | string = eval(string)
167 |
168 | return string
169 |
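# Examples of the substitution above (with the default pi=math.pi):
#
#     handle_pi("2\\pi", math.pi)   # -> "2*3.14159..." -> eval -> 6.283...
#     handle_pi("\\pi/2", math.pi)  # -> "1*3.14159.../2" -> eval -> 1.570...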
170 |
171 | def math_equal(
172 | prediction: bool | float | str,
173 | reference: float | str,
174 | include_percentage: bool = True,
175 | tolerance: float = 1e-4,
176 | timeout: float = 10.0,
177 | pi: float = math.pi,
178 | ) -> bool:
179 | """
180 | Exact match of math if and only if:
181 | 1. numerical equal: both can convert to float and are equal
182 | 2. symbolic equal: both can convert to sympy expression and are equal
183 | """
184 |
185 | prediction = normalize(prediction, pi)
186 | reference = normalize(reference, pi)
187 |
188 | if isinstance(prediction, str) and len(prediction) > 1000: # handling weird corner-cases
189 | prediction = prediction[:1000]
190 |
191 | # 0. string comparison
192 | if isinstance(prediction, str) and isinstance(reference, str):
193 | if prediction.strip().lower() == reference.strip().lower():
194 | return True
195 | if prediction.replace(" ", "") == reference.replace(" ", ""):
196 | return True
197 |
198 | try: # 1. numerical equal
199 | if is_digit(prediction)[0] and is_digit(reference)[0]:
200 | prediction = is_digit(prediction)[1]
201 | reference = is_digit(reference)[1]
202 | # number questions
203 | gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]
204 | for item in gt_result:
205 | try:
206 | if isclose(item, prediction, rel_tol=tolerance):
207 | return True
208 | except Exception:
209 | continue
210 | return False
211 | except Exception:
212 | pass
213 |
214 | if not prediction and prediction not in [0, False]:
215 | return False
216 |
217 | # 2. symbolic equal
218 | reference = str(reference).strip()
219 | prediction = str(prediction).strip()
220 |
221 | ## deal with [], (), {}
222 | prediction = format_intervals(prediction)
223 |
224 | pred_str, ref_str = prediction, reference
225 | if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (
226 | prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")
227 | ):
228 | pred_str = pred_str.strip("[]()")
229 | ref_str = ref_str.strip("[]()")
230 | for s in ["{", "}", "(", ")"]:
231 | ref_str = ref_str.replace(s, "")
232 | pred_str = pred_str.replace(s, "")
233 | if pred_str == ref_str:
234 | return True
235 |
236 | ## [a, b] vs. [c, d], return a==c and b==d
237 | if (
238 | prediction
239 | and reference
240 | and prediction[0] in "(["
241 | and prediction[-1] in ")]"
242 | and prediction[0] == reference[0]
243 | and prediction[-1] == reference[-1]
244 | ):
245 | pred_parts = prediction[1:-1].split(",")
246 | ref_parts = reference[1:-1].split(",")
247 | if len(pred_parts) == len(ref_parts) and all(
248 | [
249 | math_equal(pred_pt, ref_pt, include_percentage, tolerance)
250 | for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)
251 | ]
252 | ):
253 | return True
254 |
255 | if "," in prediction and "," in reference:
256 | pred_parts = [item.strip() for item in prediction.split(",")]
257 | ref_parts = [item.strip() for item in reference.split(",")]
258 |
259 | if len(pred_parts) == len(ref_parts):
260 | return bool(
261 | all(
262 | [
263 | math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance)
264 | for i in range(len(pred_parts))
265 | ]
266 | )
267 | )
268 |
269 | # if we have point == tuple of values
270 | if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")":
271 | pred_parts = prediction[prediction.find("(") + 1 : -1].split(",")
272 | ref_parts = reference[1:-1].split(",")
273 | if len(pred_parts) == len(ref_parts) and all(
274 | [
275 | math_equal(pred_pt, ref_pt, include_percentage, tolerance)
276 | for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=False)
277 | ]
278 | ):
279 | return True
280 |
281 | # if reference is a matrix
282 |     if "\\begin{pmatrix}" in reference and prediction.startswith("Matrix"):
283 | try:
284 | pred_matrix = parse_expr(prediction)
285 | ref_matrix_items = reference.split()[1:-1:2]
286 | if len(pred_matrix) == len(ref_matrix_items) and all(
287 | [
288 | math_equal(pred, ref, include_percentage, tolerance)
289 | for ref, pred in zip(ref_matrix_items, pred_matrix, strict=False)
290 | ]
291 | ):
292 | return True
293 | except Exception:
294 | pass
295 |     elif "\\begin{pmatrix}" in reference and prediction.startswith("[") and prediction.endswith("]"):
296 |         try:
297 |             pred_matrix = eval(prediction)
298 |             if isinstance(pred_matrix, list):
299 |                 # Strip the pmatrix delimiters, then split rows on "\\" and cells on "&".
300 |                 ref_matrix_items = (
301 |                     reference.strip()
302 |                     .removeprefix("\\begin{pmatrix}")
303 |                     .removesuffix("\\end{pmatrix}")
304 |                 )
305 |                 ref_matrix_items = [row for row in ref_matrix_items.split("\\\\") if row.strip()]
306 |                 ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items]
307 |                 if len(pred_matrix) == len(ref_matrix_items) and all(
308 |                     [
309 |                         math_equal(pred, ref, include_percentage, tolerance)
310 |                         for ref, pred in zip(ref_matrix_items, pred_matrix, strict=False)
311 |                     ]
312 |                 ):
313 |                     return True
314 |         except Exception:
315 |             pass
317 |
318 | return symbolic_equal(prediction, reference, tolerance, timeout)
319 |
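# A few comparisons the cascade above accepts (string match first, then numeric
# match with percentage handling, then the symbolic fallback):
#
#     math_equal("50", "0.5")       # numeric: include_percentage tries 100x
#     math_equal("0.5", "1/2")      # symbolic: simplify(1/2 - 0.5) == 0
#     math_equal("(1, 2)", "(1,2)") # whitespace-insensitive string match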
320 |
321 | def symbolic_equal(a, b, tolerance, timeout=10.0):
322 |     def _parse(s):
323 |         # timeout_limit is a decorator factory, not a context manager, so wrap
324 |         # each parser before calling it.
325 |         for f in [parse_expr, parse_latex]:
326 |             try:
327 |                 return timeout_limit(seconds=timeout)(f)(s)
328 |             except TimeoutError:
329 |                 print(f"Parsing timed out for {s}")
330 |                 continue
331 |             except Exception:
332 |                 continue
333 |         return s
334 | 
335 |     a = _parse(a)
336 |     b = _parse(b)
337 | 
338 |     try:
339 |         if timeout_limit(seconds=timeout)(simplify)(a - b) == 0:
340 |             return True
341 |     except TimeoutError:
342 |         print(f"Simplification timed out for {a} - {b}")
343 |     except Exception:
344 |         pass
345 | 
346 |     try:
347 |         if isclose(timeout_limit(seconds=timeout)(N)(a), timeout_limit(seconds=timeout)(N)(b), rel_tol=tolerance):
348 |             return True
349 |     except TimeoutError:
350 |         print(f"Numerical evaluation timed out for {a}, {b}")
351 |     except Exception:
352 |         pass
353 |     return False
357 |
358 |
359 | def format_intervals(prediction):
360 | patterns = {
361 | "Interval(": r"^Interval\((.*)\)$",
362 | "Interval.Ropen(": r"^Interval\.Ropen\((.*)\)$",
363 | "Interval.Lopen(": r"^Interval\.Lopen\((.*)\)$",
364 | "Interval.open(": r"^Interval\.open\((.*)\)$",
365 | }
366 |
367 | for key, pattern in patterns.items():
368 | match = re.match(pattern, prediction)
369 | if match:
370 | inner_content = match.group(1)
371 |
372 |             if key == "Interval(":  # Interval(a, b) == [a, b]
373 |                 return f"[{inner_content}]"
374 |             elif key == "Interval.Ropen(":  # Interval.Ropen(a, b) == [a, b)
375 |                 return f"[{inner_content})"
376 |             elif key == "Interval.Lopen(":  # Interval.Lopen(a, b) == (a, b]
377 |                 return f"({inner_content}]"
378 |             elif key == "Interval.open(":  # Interval.open(a, b) == (a, b)
379 |                 return f"({inner_content})"
380 |
381 | return prediction
382 |
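# Examples of the rewrites above:
#
#     format_intervals("Interval.open(0, 1)")   # -> "(0, 1)"
#     format_intervals("Interval.Ropen(0, 1)")  # -> "[0, 1)"
#     format_intervals("[0, 1]")                # returned unchanged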
383 |
384 | import os
385 | import signal
386 | import queue
387 | import multiprocessing
388 | from functools import wraps
389 | from typing import Callable, Any
390 |
391 |
392 | def _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, args: tuple, kwargs: dict[str, Any]):
393 | """
394 | Internal wrapper function executed in the child process.
395 | Calls the original target function and puts the result or exception into the queue.
396 | """
397 | try:
398 | result = target_func(*args, **kwargs)
399 | mp_queue.put((True, result)) # Indicate success and put result
400 | except Exception as e:
401 | # Ensure the exception is pickleable for the queue
402 | try:
403 | import pickle
404 |
405 | pickle.dumps(e) # Test if the exception is pickleable
406 | mp_queue.put((False, e)) # Indicate failure and put exception
407 | except (pickle.PicklingError, TypeError):
408 | # Fallback if the original exception cannot be pickled
409 | mp_queue.put((False, RuntimeError(f"Original exception type {type(e).__name__} not pickleable: {e}")))
410 |
411 |
412 | def timeout_limit(seconds: float, use_signals: bool = False):
413 | """
414 | Decorator to add a timeout to a function.
415 |
416 | Args:
417 | seconds: The timeout duration in seconds.
418 | use_signals: (Deprecated) This is deprecated because signals only work reliably in the main thread
419 | and can cause issues in multiprocessing or multithreading contexts.
420 | Defaults to False, which uses the more robust multiprocessing approach.
421 |
422 | Returns:
423 | A decorated function with timeout.
424 |
425 | Raises:
426 | TimeoutError: If the function execution exceeds the specified time.
427 | RuntimeError: If the child process exits with an error (multiprocessing mode).
428 | NotImplementedError: If the OS is not POSIX (signals are only supported on POSIX).
429 | """
430 |
431 | def decorator(func):
432 | if use_signals:
433 | if os.name != "posix":
434 | raise NotImplementedError(f"Unsupported OS: {os.name}")
435 | # Issue deprecation warning if use_signals is explicitly True
436 | print(
437 |                 "WARN: The 'use_signals=True' option in the timeout decorator is deprecated. "
438 |                 "Signals are unreliable outside the main thread. "
439 |                 "Please use the default multiprocessing-based timeout (use_signals=False)."
440 | )
441 |
442 | @wraps(func)
443 | def wrapper_signal(*args, **kwargs):
444 | def handler(signum, frame):
446 | raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds (signal)!")
447 |
448 | old_handler = signal.getsignal(signal.SIGALRM)
449 | signal.signal(signal.SIGALRM, handler)
450 | # Use setitimer for float seconds support, alarm only supports integers
451 | signal.setitimer(signal.ITIMER_REAL, seconds)
452 |
453 | try:
454 | result = func(*args, **kwargs)
455 | finally:
456 | # Reset timer and handler
457 | signal.setitimer(signal.ITIMER_REAL, 0)
458 | signal.signal(signal.SIGALRM, old_handler)
459 | return result
460 |
461 | return wrapper_signal
462 | else:
463 | # --- Multiprocessing based timeout (existing logic) ---
464 | @wraps(func)
465 | def wrapper_mp(*args, **kwargs):
466 | q = multiprocessing.Queue(maxsize=1)
467 | process = multiprocessing.Process(target=_mp_target_wrapper, args=(func, q, args, kwargs))
468 | process.start()
469 | process.join(timeout=seconds)
470 |
471 | if process.is_alive():
472 | process.terminate()
473 | process.join(timeout=0.5) # Give it a moment to terminate
474 | if process.is_alive():
475 | print(f"Warning: Process {process.pid} did not terminate gracefully after timeout.")
477 | raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds (multiprocessing)!")
478 |
479 | try:
480 | success, result_or_exc = q.get(timeout=0.1) # Small timeout for queue read
481 | if success:
482 | return result_or_exc
483 | else:
484 | raise result_or_exc # Reraise exception from child
485 | except queue.Empty as err:
486 | exitcode = process.exitcode
487 | if exitcode is not None and exitcode != 0:
488 | raise RuntimeError(
489 | f"Child process exited with error (exitcode: {exitcode}) before returning result."
490 | ) from err
491 | else:
492 | # Should have timed out if queue is empty after join unless process died unexpectedly
494 | raise TimeoutError(
495 | f"Operation timed out or process finished unexpectedly without result "
496 | f"(exitcode: {exitcode})."
497 | ) from err
498 | finally:
499 | q.close()
500 | q.join_thread()
501 |
502 | return wrapper_mp
503 |
504 | return decorator
505 |
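# Illustrative use of the decorator (the wrapped function is an example, not
# part of this module):
#
#     @timeout_limit(seconds=2.0)
#     def slow_simplify(expr_str):
#         return simplify(parse_expr(expr_str))
#
#     slow_simplify("x + x")  # raises TimeoutError if it runs longer than 2s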
--------------------------------------------------------------------------------
/rstar2_agent/rstar2_agent_ray_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import uuid
5 | from copy import deepcopy
6 | from pprint import pprint
7 |
8 | import numpy as np
9 | import ray
10 | import torch
11 | from tqdm import tqdm
12 |
13 | from verl import DataProto
14 | from verl.experimental.dataset.sampler import AbstractCurriculumSampler
15 | from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
16 | from verl.trainer.ppo.metric_utils import (
17 | compute_data_metrics,
18 | compute_throughout_metrics,
19 | compute_timing_metrics,
20 | )
21 | from verl.trainer.ppo.ray_trainer import (
22 | RayPPOTrainer,
23 | apply_kl_penalty,
24 | compute_advantage,
25 | compute_response_mask,
26 | )
27 | from verl.trainer.ppo.reward import compute_reward, compute_reward_async
28 | from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
29 | from verl.utils.debug import marked_timer
30 | from verl.utils.metric import reduce_metrics
31 | from verl.utils.rollout_skip import RolloutSkip
32 |
33 | from .down_sample import reject_equal_reward, resample_of_correct
34 |
35 |
36 | class RStar2AgentRayTrainer(RayPPOTrainer):
37 |     def _down_sample_batch(self, batch: DataProto) -> tuple[DataProto | None, dict]:
38 | do_down_sampling = self.config.augmentation.do_down_sampling
39 | down_sampling_config = self.config.augmentation.down_sampling_config
40 | world_size = self.actor_rollout_wg.world_size
41 |         metrics = {"down_sampling/before_sampling_trace_num": len(batch)}
42 |
43 | def check_batch_is_empty(batch: DataProto, down_sampling_stage: str):
44 | if batch is None or len(batch) == 0:
45 | print(f"Batch is empty after {down_sampling_stage}, skipping the training step.")
46 | return True
47 | return False
48 |
49 |         # reject rollout traces of a prompt when all of its traces received equal rewards
50 | do_reject_equal_reward = down_sampling_config.get("reject_equal_reward", False) and do_down_sampling
51 | batch, _metrics = reject_equal_reward(batch, do_reject_equal_reward, world_size)
52 | metrics.update(_metrics)
53 | if check_batch_is_empty(batch, "reject_equal_reward"):
54 | return None, metrics
55 |
56 | # weighted sampling
57 | config = {
58 | "roc_error_ratio": down_sampling_config.get("roc_error_ratio", False) and do_down_sampling,
59 | "roc_answer_format": down_sampling_config.get("roc_answer_format", False) and do_down_sampling,
60 | "min_zero_reward_trace_num": down_sampling_config.get("min_zero_reward_trace_num", -1),
61 | "min_non_zero_reward_trace_num": down_sampling_config.get("min_non_zero_reward_trace_num", -1),
62 | "down_sample_to_n": down_sampling_config.get("down_sample_to_n", -1),
63 | }
64 | batch, _metrics = resample_of_correct(batch, self.tokenizer, config, do_down_sampling, world_size=world_size)
65 | metrics.update(_metrics)
66 | if check_batch_is_empty(batch, "fused_weighted_sampling"):
67 | return None, metrics
68 |
69 | metrics["down_sampling/after_sampling_trace_num"] = len(batch)
70 | return batch, metrics
71 |
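    # The down-sampling knobs read above live under `augmentation` in the
    # trainer config; a sketch with illustrative values:
    #
    #     augmentation:
    #       do_down_sampling: true
    #       down_sampling_config:
    #         reject_equal_reward: true
    #         roc_error_ratio: true
    #         roc_answer_format: true
    #         min_zero_reward_trace_num: 1
    #         min_non_zero_reward_trace_num: 1
    #         down_sample_to_n: 16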
72 | def fit(self):
73 | """
74 | The training loop of PPO.
75 | The driver process only need to call the compute functions of the worker group through RPC
76 | to construct the PPO dataflow.
77 | The light-weight advantage computation is done on the driver process.
78 |
79 | Most logic is same with RayPPOTrainer, mainly add down sample related.
80 | """
81 | from omegaconf import OmegaConf
82 |
83 | from verl.utils.tracking import Tracking
84 |
85 | logger = Tracking(
86 | project_name=self.config.trainer.project_name,
87 | experiment_name=self.config.trainer.experiment_name,
88 | default_backend=self.config.trainer.logger,
89 | config=OmegaConf.to_container(self.config, resolve=True),
90 | )
91 |
92 | self.global_steps = 0
93 |
94 | # load checkpoint before doing anything
95 | self._load_checkpoint()
96 |
97 | # perform validation before training
98 | # currently, we only support validation using the reward_function.
99 | if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
100 | val_metrics = self._validate()
101 | assert val_metrics, f"{val_metrics=}"
102 | pprint(f"Initial validation metrics: {val_metrics}")
103 | logger.log(data=val_metrics, step=self.global_steps)
104 | if self.config.trainer.get("val_only", False):
105 | return
106 |
107 | if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
108 | rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
109 | rollout_skip.wrap_generate_sequences()
110 |
111 | # add tqdm
112 | progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
113 |
114 | # we start from step 1
115 | self.global_steps += 1
116 | last_val_metrics = None
117 | self.max_steps_duration = 0
118 |
119 | prev_step_profile = False
120 | curr_step_profile = (
121 | self.global_steps in self.config.global_profiler.steps
122 | if self.config.global_profiler.steps is not None
123 | else False
124 | )
125 | next_step_profile = False
126 |
127 | for epoch in range(self.config.trainer.total_epochs):
128 | for batch_dict in self.train_dataloader:
129 | metrics = {}
130 | timing_raw = {}
131 |
132 | with marked_timer("start_profile", timing_raw):
133 | self._start_profiling(
134 | not prev_step_profile and curr_step_profile
135 | if self.config.global_profiler.profile_continuous_steps
136 | else curr_step_profile
137 | )
138 |
139 | batch: DataProto = DataProto.from_single_dict(batch_dict)
140 |
141 | # add uid to batch
142 | batch.non_tensor_batch["uid"] = np.array(
143 | [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
144 | )
145 |
146 | # pop those keys for generation
147 | batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
148 | non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
149 | if "multi_modal_data" in batch.non_tensor_batch:
150 | non_tensor_batch_keys_to_pop.append("multi_modal_data")
151 | if "raw_prompt" in batch.non_tensor_batch:
152 | non_tensor_batch_keys_to_pop.append("raw_prompt")
153 | if "tools_kwargs" in batch.non_tensor_batch:
154 | non_tensor_batch_keys_to_pop.append("tools_kwargs")
155 | if "interaction_kwargs" in batch.non_tensor_batch:
156 | non_tensor_batch_keys_to_pop.append("interaction_kwargs")
157 | if "index" in batch.non_tensor_batch:
158 | non_tensor_batch_keys_to_pop.append("index")
159 | if "agent_name" in batch.non_tensor_batch:
160 | non_tensor_batch_keys_to_pop.append("agent_name")
161 |
162 | gen_batch = batch.pop(
163 | batch_keys=batch_keys_to_pop,
164 | non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
165 | )
166 |
167 | # pass global_steps to trace
168 | gen_batch.meta_info["global_steps"] = self.global_steps
169 | gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
170 |
171 | is_last_step = self.global_steps >= self.total_training_steps
172 |
173 | with marked_timer("step", timing_raw):
174 | # generate a batch
175 | with marked_timer("gen", timing_raw, color="red"):
176 | if not self.async_rollout_mode:
177 | gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
178 | else:
179 | gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
180 | timing_raw.update(gen_batch_output.meta_info["timing"])
181 | gen_batch_output.meta_info.pop("timing", None)
182 |
183 | if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
184 | if self.reward_fn is None:
185 | raise ValueError("A reward_fn is required for REMAX advantage estimation.")
186 |
187 | with marked_timer("gen_max", timing_raw, color="purple"):
188 | gen_baseline_batch = deepcopy(gen_batch)
189 | gen_baseline_batch.meta_info["do_sample"] = False
190 | if not self.async_rollout_mode:
191 | gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
192 | else:
193 | gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
194 | batch = batch.union(gen_baseline_output)
195 | reward_baseline_tensor = self.reward_fn(batch)
196 | reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
197 |
198 | batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
199 |
200 | batch.batch["reward_baselines"] = reward_baseline_tensor
201 |
202 | del gen_baseline_batch, gen_baseline_output
203 |
204 | # repeat to align with repeated responses in rollout
205 | batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
206 | batch = batch.union(gen_batch_output)
207 |
208 | if "response_mask" not in batch.batch.keys():
209 | batch.batch["response_mask"] = compute_response_mask(batch)
210 |
211 | # compute global_valid tokens
212 | batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
213 |
214 | with marked_timer("reward", timing_raw, color="yellow"):
215 | # compute reward model score
216 | if self.use_rm:
217 | reward_tensor = self.rm_wg.compute_rm_score(batch)
218 | batch = batch.union(reward_tensor)
219 |
220 | if self.config.reward_model.launch_reward_fn_async:
221 | future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)
222 | else:
223 | reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
224 | batch.batch["token_level_scores"] = reward_tensor
225 | if reward_extra_infos_dict:
226 | batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
227 | reward_extra_infos_dict_keys = list(reward_extra_infos_dict.keys())
228 |
229 | ################################### rStar ###################################
230 | # launch_reward_fn_async still needs a refactor to support down sampling;
231 | # for now, combining launch_reward_fn_async with down sampling is simply forbidden.
232 | with marked_timer("down_sample", timing_raw, color="yellow"):
233 | assert not (self.config.reward_model.launch_reward_fn_async and self.config.augmentation.do_down_sampling), \
234 | "down sampling cannot combine with async reward function for now"
235 | batch, down_sampling_metrics = self._down_sample_batch(batch)
236 | metrics.update(down_sampling_metrics)
237 | if batch is None:
238 | continue
239 | #############################################################################
240 |
241 | ################################### rStar ###################################
242 | # The DP balance logic is moved to after down sampling.
243 |
244 | # Balance the number of valid tokens across DP ranks.
245 | # NOTE: This usually changes the order of data in the `batch`,
246 | # which won't affect the advantage calculation (since it's based on uid),
247 | # but might affect the loss calculation (due to the change of mini-batching).
248 | # TODO: Decouple the DP balancing and mini-batching.
249 | if self.config.trainer.balance_batch:
250 | self._balance_batch(batch, metrics=metrics)
251 | #############################################################################
252 |
253 | # recompute old_log_probs
254 | with marked_timer("old_log_prob", timing_raw, color="blue"):
255 | old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
256 | entropys = old_log_prob.batch["entropys"]
257 | response_masks = batch.batch["response_mask"]
258 | loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
259 | entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
260 | old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
261 | metrics.update(old_log_prob_metrics)
262 | old_log_prob.batch.pop("entropys")
263 | batch = batch.union(old_log_prob)
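
# Hedged aside (defined here, never called): with the common
# loss_agg_mode="token-mean", agg_loss above reduces the per-token entropies
# to a masked mean over response tokens. Shapes are [batch, response_len];
# the clamp guarding an all-zero mask is this sketch's own assumption.
def _token_mean_entropy_sketch(entropys, response_mask):
    return (entropys * response_mask).sum() / response_mask.sum().clamp(min=1)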
264 |
265 | if "rollout_log_probs" in batch.batch.keys():
266 | # TODO: we may want to add diff of probs too.
267 | from verl.utils.debug.metrics import calculate_debug_metrics
268 |
269 | metrics.update(calculate_debug_metrics(batch))
270 |
271 | if self.use_reference_policy:
272 | # compute reference log_prob
273 | with marked_timer("ref", timing_raw, color="olive"):
274 | if not self.ref_in_actor:
275 | ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
276 | else:
277 | ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
278 | batch = batch.union(ref_log_prob)
279 |
280 | # compute values
281 | if self.use_critic:
282 | with marked_timer("values", timing_raw, color="cyan"):
283 | values = self.critic_wg.compute_values(batch)
284 | batch = batch.union(values)
285 |
286 | with marked_timer("adv", timing_raw, color="brown"):
287 | # we combine with rule-based rm
288 | ################################### rStar ###################################
289 | # Because down sampling cannot be combined with config.reward_model.launch_reward_fn_async,
290 | # the reward-setting logic is refactored here and reward_extra_infos_dict is rebuilt from the (possibly down-sampled) batch.
291 | if self.config.reward_model.launch_reward_fn_async:
292 | reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
293 | batch.batch["token_level_scores"] = reward_tensor
294 | if reward_extra_infos_dict:
295 | batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
296 | reward_extra_infos_dict_keys = list(reward_extra_infos_dict.keys())
297 | reward_extra_infos_dict = {key: batch.non_tensor_batch[key].tolist() for key in reward_extra_infos_dict_keys}
298 | ################################################################################
299 |
300 | # compute rewards. apply_kl_penalty if available
301 | if self.config.algorithm.use_kl_in_reward:
302 | batch, kl_metrics = apply_kl_penalty(
303 | batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
304 | )
305 | metrics.update(kl_metrics)
306 | else:
307 | batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
308 |
309 | # compute advantages, executed on the driver process
310 |
311 | norm_adv_by_std_in_grpo = self.config.algorithm.get(
312 | "norm_adv_by_std_in_grpo", True
313 | ) # GRPO adv normalization factor
314 |
315 | batch = compute_advantage(
316 | batch,
317 | adv_estimator=self.config.algorithm.adv_estimator,
318 | gamma=self.config.algorithm.gamma,
319 | lam=self.config.algorithm.lam,
320 | num_repeat=self.config.actor_rollout_ref.rollout.n,
321 | norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
322 | config=self.config.algorithm,
323 | )
324 |
325 | # update critic
326 | if self.use_critic:
327 | with marked_timer("update_critic", timing_raw, color="pink"):
328 | critic_output = self.critic_wg.update_critic(batch)
329 | critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
330 | metrics.update(critic_output_metrics)
331 |
332 | # implement critic warmup
333 | if self.config.trainer.critic_warmup <= self.global_steps:
334 | # update actor
335 | with marked_timer("update_actor", timing_raw, color="red"):
336 | batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
337 | actor_output = self.actor_rollout_wg.update_actor(batch)
338 | actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
339 | metrics.update(actor_output_metrics)
340 |
341 | # Log rollout generations if enabled
342 | rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
343 | if rollout_data_dir:
344 | with marked_timer("dump_rollout_generations", timing_raw, color="green"):
345 | inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
346 | outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
347 | scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
348 | sample_gts = [
349 | item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None)
350 | for item in batch
351 | ]
352 |
353 | if "request_id" in batch.non_tensor_batch:
354 | reward_extra_infos_dict.setdefault(
355 | "request_id",
356 | batch.non_tensor_batch["request_id"].tolist(),
357 | )
358 |
359 | self._dump_generations(
360 | inputs=inputs,
361 | outputs=outputs,
362 | gts=sample_gts,
363 | scores=scores,
364 | reward_extra_infos_dict=reward_extra_infos_dict,
365 | dump_path=rollout_data_dir,
366 | )
367 |
368 | # validate
369 | if (
370 | self.val_reward_fn is not None
371 | and self.config.trainer.test_freq > 0
372 | and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
373 | ):
374 | with marked_timer("testing", timing_raw, color="green"):
375 | val_metrics: dict = self._validate()
376 | if is_last_step:
377 | last_val_metrics = val_metrics
378 | metrics.update(val_metrics)
379 |
380 | # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
381 | esi_close_to_expiration = should_save_ckpt_esi(
382 | max_steps_duration=self.max_steps_duration,
383 | redundant_time=self.config.trainer.esi_redundant_time,
384 | )
385 | # Check if the conditions for saving a checkpoint are met.
386 | # The conditions include a mandatory condition (1) and
387 | # one of the following optional conditions (2/3/4):
388 | # 1. The save frequency is set to a positive value.
389 | # 2. It's the last training step.
390 | # 3. The current step number is a multiple of the save frequency.
391 | # 4. The ESI (Elastic Server Instance)/training plan is close to expiration.
392 | if self.config.trainer.save_freq > 0 and (
393 | is_last_step
394 | or self.global_steps % self.config.trainer.save_freq == 0
395 | or esi_close_to_expiration
396 | ):
397 | if esi_close_to_expiration:
398 | print("Force saving checkpoint: ESI instance expiration approaching.")
399 | with marked_timer("save_checkpoint", timing_raw, color="green"):
400 | self._save_checkpoint()
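
# The save condition above, restated as a pure predicate (a readability
# sketch, never called): mandatory condition (1) plus any of (2)/(3)/(4).
def _should_save_ckpt_sketch(save_freq, step, is_last_step, esi_close_to_expiration):
    return save_freq > 0 and (is_last_step or step % save_freq == 0 or esi_close_to_expiration)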
401 |
402 | with marked_timer("stop_profile", timing_raw):
403 | next_step_profile = (
404 | self.global_steps + 1 in self.config.global_profiler.steps
405 | if self.config.global_profiler.steps is not None
406 | else False
407 | )
408 | self._stop_profiling(
409 | curr_step_profile and not next_step_profile
410 | if self.config.global_profiler.profile_continuous_steps
411 | else curr_step_profile
412 | )
413 | prev_step_profile = curr_step_profile
414 | curr_step_profile = next_step_profile
415 |
416 | steps_duration = timing_raw["step"]
417 | self.max_steps_duration = max(self.max_steps_duration, steps_duration)
418 |
419 | # training metrics
420 | metrics.update(
421 | {
422 | "training/global_step": self.global_steps,
423 | "training/epoch": epoch,
424 | }
425 | )
426 | # collect metrics
427 | metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
428 | metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
429 | # TODO: implement actual tflops and theoretical tflops
430 | n_gpus = self.resource_pool_manager.get_n_gpus()
431 | metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
432 |
433 | # this is experimental and may be changed/removed in the future in favor of a general-purpose one
434 | if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
435 | self.train_dataloader.sampler.update(batch=batch)
436 |
437 | # TODO: make a canonical logger that supports various backends
438 | logger.log(data=metrics, step=self.global_steps)
439 |
440 | progress_bar.update(1)
441 | self.global_steps += 1
442 |
443 | if is_last_step:
444 | pprint(f"Final validation metrics: {last_val_metrics}")
445 | progress_bar.close()
446 | return
447 |
448 | # this is experimental and may be changed/removed in the future
449 | # in favor of a general-purpose data buffer pool
450 | if hasattr(self.train_dataset, "on_batch_end"):
451 | # The dataset may be changed after each training batch
452 | self.train_dataset.on_batch_end(batch=batch)
453 |
--------------------------------------------------------------------------------