├── assets
└── rise_method.jpeg
├── .gitmodules
├── verl_utils
├── run_ckpt_merge.sh
├── reward
│ ├── reward_func_verification.py
│ ├── reward_func.py
│ └── openmathinst_utils.py
├── data
│ ├── generate_splits.py
│ └── generate_splits_deepmath.py
└── model_merger.py
├── LICENSE
├── scripts
└── train
│ ├── start_qwen3b_rise_example.sh
│ └── start_qwen8b-base_rise_example.sh
└── README.md
/assets/rise_method.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyliu-cs/RISE/HEAD/assets/rise_method.jpeg
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "verl"]
2 | path = verl
3 | url = https://github.com/xyliu-cs/verl.git
4 |
--------------------------------------------------------------------------------
/verl_utils/run_ckpt_merge.sh:
--------------------------------------------------------------------------------
# Merge the sharded actor checkpoint saved at global step 96 into a single
# HuggingFace-format model by invoking model_merger.py.
# model_merger.py is referenced by relative path, so run this script from the
# verl_utils directory.
FINETUNE_MODEL_PATH=/path/to/your/model

# Quote the expansion so that a path containing spaces or glob characters is
# passed to --local_dir as one argument instead of being word-split.
python model_merger.py \
    --local_dir "$FINETUNE_MODEL_PATH/global_step_96/actor"
5 |
--------------------------------------------------------------------------------
/verl_utils/reward/reward_func_verification.py:
--------------------------------------------------------------------------------
# Prefer the package-relative import; fall back to the absolute module path
# when this file is loaded without a parent package (e.g. directly by path).
# Only ImportError triggers the fallback so that unrelated failures such as
# KeyboardInterrupt are not swallowed.
try:
    from .openmathinst_utils import extract_answer, math_equal
except ImportError:
    from verl_utils.reward.openmathinst_utils import extract_answer, math_equal
5 |
def ver_reward_func(data_source, solution_str, ground_truth, extra_info) -> float:
    """Score a response against the reference answer.

    Returns:
        -1.0 if ``extract_answer`` yields ``None`` (treated as a formatting
        error), 1.0 if ``math_equal`` judges the extracted answer equal to
        ``ground_truth``, and -0.5 otherwise.

    ``data_source`` and ``extra_info`` are accepted but not used here.
    """
    candidate = extract_answer(solution_str, extract_from_boxed=True)

    # Guard clause: nothing could be extracted from the response.
    if candidate is None:
        return -1.0

    is_correct = math_equal(candidate, ground_truth, check_antlr_version=False)
    return 1.0 if is_correct else -0.5
15 |
--------------------------------------------------------------------------------
/verl_utils/reward/reward_func.py:
--------------------------------------------------------------------------------
# Prefer the package-relative import; fall back to the absolute module path
# when this file is loaded without a parent package (e.g. directly by path).
# Only ImportError triggers the fallback so that unrelated failures such as
# KeyboardInterrupt are not swallowed.
try:
    from .openmathinst_utils import extract_answer, math_equal
except ImportError:
    from verl_utils.reward.openmathinst_utils import extract_answer, math_equal
5 |
def reward_func(data_source, solution_str, ground_truth, extra_info) -> float:
    """Compute the generation reward for one response.

    Returns:
        -1.0 if ``extract_answer`` yields ``None`` (treated as a formatting
        error), 1.0 if ``math_equal`` judges the extracted answer equal to
        ``ground_truth``, and -0.5 otherwise.

    ``data_source`` and ``extra_info`` are accepted but not used here.
    """
    candidate = extract_answer(solution_str, extract_from_boxed=True)

    # Guard clause: nothing could be extracted from the response.
    if candidate is None:
        return -1.0

    is_correct = math_equal(candidate, ground_truth, check_antlr_version=False)
    return 1.0 if is_correct else -0.5
15 |
def ver_reward_func(data_source, solution_str, ground_truth, extra_info) -> float:
    """Compute the verification reward for one response.

    Returns:
        -1.0 if ``extract_answer`` yields ``None`` (treated as a formatting
        error) or if ``solution_str`` is shorter than 800 characters,
        1.0 if ``math_equal`` judges the extracted answer equal to
        ``ground_truth``, and -0.5 otherwise.

    ``data_source`` and ``extra_info`` are accepted but not used here.
    """
    candidate = extract_answer(solution_str, extract_from_boxed=True)

    # Both rejection cases carry the same penalty; the length test is only
    # evaluated when an answer was extracted, as in the chained form.
    if candidate is None or len(solution_str) < 800:
        return -1.0

    is_correct = math_equal(candidate, ground_truth, check_antlr_version=False)
    return 1.0 if is_correct else -0.5
27 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Xiaoyuan Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/verl_utils/data/generate_splits.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess the MATH-Hard dataset to parquet format
3 | """
4 |
5 | import os
6 | from datasets import load_dataset
7 | import argparse
8 |
9 | train_data_path = 'data/train/MATH_Hard.jsonl'
10 | val_data_path = 'data/train/MATH_val.jsonl'
11 |
if __name__ == '__main__':
    # Convert the JSONL train/validation files named by the module-level
    # paths into train.parquet / test.parquet under --local_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='data/train')
    args = parser.parse_args()

    # Expand "~" and create the destination up front so the parquet writes at
    # the end do not depend on the directory already existing.
    output_dir = os.path.expanduser(args.local_dir)
    os.makedirs(output_dir, exist_ok=True)

    train_dataset = load_dataset("json", data_files={"train": train_data_path}, split="train")
    val_dataset = load_dataset("json", data_files={"val": val_data_path}, split="val")

    def make_process_fn(data_source, split):
        """Build the per-row mapper; the two splits differ only in these two values."""

        def process_fn(example, idx):
            # Reads the 'messages', 'answer' and 'problem' fields of each row.
            return {
                "data_source": data_source,
                "prompt": example['messages'],
                "ability": "math",
                "reward_model": {
                    "style": "rule",
                    "ground_truth": example['answer']
                },
                "extra_info": {
                    'split': split,
                    'index': idx,
                    'answer': example['answer'],
                    "question": example['problem'],
                }
            }

        return process_fn

    train_dataset = train_dataset.map(function=make_process_fn(train_data_path, 'train'), with_indices=True)
    test_dataset = val_dataset.map(function=make_process_fn(val_data_path, 'test'), with_indices=True)
    train_dataset.to_parquet(os.path.join(output_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(output_dir, 'test.parquet'))
60 |
--------------------------------------------------------------------------------
/scripts/train/start_qwen3b_rise_example.sh:
--------------------------------------------------------------------------------
# Example launcher: submits a verl PPO training job with online critique
# enabled to a Ray cluster whose dashboard listens on 127.0.0.1:8265.
# Replace the placeholder values below before running.
set -e
set -u

# NOTE: WANDB_TOKEN is not referenced anywhere below; the W&B key the job
# receives is the "WANDB_API_KEY" entry of the runtime-env JSON.
WANDB_TOKEN=xxx
RUN_NAME=xxx
DATA_DIR=/path/to/your/data
MODEL_DIR=/path/to/your/model
SAVE_DIR=/path/to/your/output

mkdir -p .checkpoints/$RUN_NAME
mkdir -p $SAVE_DIR

# set http_proxy if needed

# ray start --head --num-cpus=8 --dashboard-port=8265 --dashboard-host=0.0.0.0

sleep 10

# The runtime-env value must be strict JSON: no trailing comma is allowed
# after the last entry of "env_vars".
# All output of the submission (stdout and stderr) is appended to
# .checkpoints/$RUN_NAME/train.log.
ray job submit --address="http://127.0.0.1:8265" \
    --runtime-env-json='{
        "env_vars": {
            "HUGGING_FACE_HUB_TOKEN": "your_huggingface_token",
            "LM_HARNESS_CACHE_PATH": "cache",
            "VLLM_ATTENTION_BACKEND": "XFORMERS",
            "PYTHONUNBUFFERED": "1",
            "WANDB_API_KEY": "your_wandb_token"
        },
        "working_dir": "your_working_dir",
        "pip": ["latex2sympy2", "word2number", "timeout_decorator"]
    }' -- PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
    data.train_files=$DATA_DIR/train.parquet \
    data.val_files=$DATA_DIR/test.parquet \
    data.prompt_key=prompt \
    data.train_batch_size=1024 \
    +data.critique_batch_size=128 \
    data.val_batch_size=1024 \
    data.max_prompt_length=6000 \
    data.max_response_length=3000 \
    actor_rollout_ref.model.path=$MODEL_DIR \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.optim.lr=5e-7 \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.temperature=1.0 \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.disable_log_stats=False \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=48000 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=48000 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    critic.optim.lr=9e-6 \
    critic.model.path=$MODEL_DIR \
    critic.model.use_remove_padding=True \
    critic.ppo_max_token_len_per_gpu=24000 \
    critic.forward_max_token_len_per_gpu=48000 \
    reward_model.reward_func_path=verl_utils/reward/reward_func.py \
    algorithm.kl_ctrl.kl_coef=0.01 \
    trainer.project_name=verl \
    trainer.experiment_name=$RUN_NAME \
    trainer.default_local_dir=$SAVE_DIR/$RUN_NAME \
    trainer.logger=['console','wandb'] \
    +trainer.val_before_train=False \
    +trainer.online_critique=True \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=96 \
    trainer.save_rollout=True \
    trainer.test_freq=8 \
    trainer.total_epochs=12 2>&1 | tee -a .checkpoints/$RUN_NAME/train.log

ray stop
79 |
--------------------------------------------------------------------------------
/verl_utils/data/generate_splits_deepmath.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess the deepmath dataset to parquet format
3 | """
4 |
5 | import os
6 | import datasets
7 | import argparse
8 |
9 | train_source = 'xiaoyuanliu/DeepMath-10K'
10 | train_split = 'train'
11 | val_source="xiaoyuanliu/math-gen-critique"
12 | val_split="math_val"
13 |
my_system_prompt = 'Please reason step by step, and put your final answer within \\boxed{}.'


def format_messages(question, system_prompt=my_system_prompt):
    """Wrap a question in a chat-style message list.

    Args:
        question: Text placed in the single ``user`` message.
        system_prompt: Text for a leading ``system`` message. Any falsy value
            (``None``, ``""``) omits the system message entirely.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts, system message
        first when present, followed by the user message.
    """
    conversation = [{"role": "user", "content": question}]
    if system_prompt:
        conversation.insert(0, {"role": "system", "content": system_prompt})
    return conversation
29 |
30 |
31 |
if __name__ == '__main__':
    # Download the train/validation datasets named by the module-level
    # constants and write them as train.parquet / test.parquet under
    # --local_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='~/data/deepmath')
    parser.add_argument('--add_message', action='store_true', help='Whether to add message column to the dataset')
    args = parser.parse_args()

    # The default starts with "~": expand it explicitly instead of relying on
    # the writer to do so, and create the directory so the writes at the end
    # do not fail after all the processing has been done.
    output_dir = os.path.expanduser(args.local_dir)
    os.makedirs(output_dir, exist_ok=True)

    train_dataset = datasets.load_dataset(train_source, split=train_split)
    val_dataset = datasets.load_dataset(val_source, split=val_split)

    if args.add_message:
        # Only the train split gets a generated 'messages' column here.
        train_dataset = train_dataset.map(
            lambda x: {'messages': format_messages(x['question'], my_system_prompt)},
            desc='Formatting messages for train dataset'
        )

    def process_fn_train(example, idx):
        # NOTE(review): reads example['messages']; presumably the train rows
        # only have it when --add_message was passed - confirm against the
        # dataset schema.
        data = {
            "data_source": train_source,
            "prompt": example['messages'],
            "ability": "math",
            "reward_model": {
                "style": "rule",
                "ground_truth": example['final_answer']
            },
            "extra_info": {
                'split': 'train',
                'index': idx,
                'answer': example['final_answer'],
                "question": example['question'],
            }
        }
        return data

    def process_fn_test(example, idx):
        # The validation rows use 'answer' / 'problem' rather than the
        # 'final_answer' / 'question' fields read from the train rows.
        data = {
            "data_source": val_source,
            "prompt": example['messages'],
            "ability": "math",
            "reward_model": {
                "style": "rule",
                "ground_truth": example['answer']
            },
            "extra_info": {
                'split': 'test',
                'index': idx,
                'answer': example['answer'],
                "question": example['problem'],
            }
        }
        return data

    train_dataset = train_dataset.map(function=process_fn_train, with_indices=True)
    test_dataset = val_dataset.map(function=process_fn_test, with_indices=True)
    # preview the first few entries
    print('-'* 50)
    print("Train dataset sample:")
    print(train_dataset[5])
    print('-'* 50)
    print("Test dataset sample:")
    print(test_dataset[5])
    train_dataset.to_parquet(os.path.join(output_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(output_dir, 'test.parquet'))
95 |
96 |
--------------------------------------------------------------------------------
/scripts/train/start_qwen8b-base_rise_example.sh:
--------------------------------------------------------------------------------
1 | set -e
2 | set -u
3 |
4 |
5 | WANDB_TOKEN=xxxx
6 | RUN_NAME=xxxx
7 | DATA_DIR=/path/to/deepmath_10K
8 | MODEL_DIR=/path/to/Qwen3-8B-Base
9 | SAVE_DIR=/path/to/Qwen3-8B-Base-DeepMath10K-PPO-RISE
10 |
11 | mkdir -p .checkpoints/$RUN_NAME
12 | mkdir -p $SAVE_DIR
13 |
14 | # set http_proxy if needed
15 | ray start --head --num-cpus=16 --dashboard-port=8265 --dashboard-host=0.0.0.0
16 |
17 | sleep 10
18 |
19 | ray job submit --address="http://127.0.0.1:8265" \
20 | --runtime-env-json='{
21 | "env_vars": {
22 | "HUGGING_FACE_HUB_TOKEN": "xxxx",
23 | "LM_HARNESS_CACHE_PATH": "cache",
24 | "PYTHONUNBUFFERED": "1",
25 | "WANDB_API_KEY": "xxxx"
26 | },
27 | "working_dir": "/path/to/your/working_dir",
28 | "pip": ["latex2sympy2", "word2number", "timeout_decorator"],
29 | "excludes": [".git/**"]
30 | }' -- PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
31 | data.train_files=$DATA_DIR/train.parquet \
32 | data.val_files=$DATA_DIR/test.parquet \
33 | data.prompt_key=prompt \
34 | data.train_batch_size=1024 \
35 | +data.critique_batch_size=128 \
36 | data.max_prompt_length=3072 \
37 | data.max_response_length=8192 \
38 | data.qwen3_thinking=True \
39 | data.truncation=right \
40 | actor_rollout_ref.model.path=$MODEL_DIR \
41 | actor_rollout_ref.model.use_remove_padding=True \
42 | actor_rollout_ref.actor.optim.lr=5e-7 \
43 | actor_rollout_ref.actor.ppo_mini_batch_size=256 \
44 | actor_rollout_ref.actor.use_dynamic_bsz=True \
45 | actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
46 | actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
47 | actor_rollout_ref.rollout.temperature=1.0 \
48 | actor_rollout_ref.rollout.n=8 \
49 | actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
50 | actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
51 | actor_rollout_ref.rollout.disable_log_stats=False \
52 | actor_rollout_ref.rollout.enforce_eager=False \
53 | actor_rollout_ref.rollout.free_cache_engine=False \
54 | actor_rollout_ref.rollout.max_num_batched_tokens=24000 \
55 | actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
56 | actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \
57 | actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
58 | actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=24000 \
59 | actor_rollout_ref.ref.fsdp_config.param_offload=True \
60 | critic.optim.lr=9e-6 \
61 | critic.model.path=$MODEL_DIR \
62 | critic.model.use_remove_padding=True \
63 | critic.ppo_max_token_len_per_gpu=24000 \
64 | critic.forward_max_token_len_per_gpu=24000 \
65 | custom_reward_function.path=verl_utils/reward/reward_func.py \
66 | custom_reward_function.name=reward_func \
67 | algorithm.kl_ctrl.kl_coef=0.01 \
68 | trainer.project_name=verl \
69 | trainer.experiment_name=$RUN_NAME \
70 | trainer.default_local_dir=$SAVE_DIR/$RUN_NAME \
71 | trainer.logger=['console','wandb'] \
72 | trainer.val_before_train=False \
73 | +trainer.online_critique=True \
74 | trainer.critique_prompt_idx=0 \
75 | trainer.n_gpus_per_node=8 \
76 | trainer.nnodes=1 \
77 | trainer.save_freq=8 \
78 | trainer.rollout_data_dir=/path/to/your/rollout_data_dir \
79 | trainer.test_freq=8 \
80 | trainer.total_epochs=12 2>&1 | tee -a .checkpoints/$RUN_NAME/train.log
81 |
82 | ray stop
83 |
--------------------------------------------------------------------------------
/verl_utils/model_merger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import List, Tuple, Dict
16 | import re
17 | import os
18 | import torch
19 | import argparse
20 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification
21 | from concurrent.futures import ThreadPoolExecutor
22 | from torch.distributed._tensor import DTensor, Shard, Placement
23 |
24 |
def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
    """Combine per-rank local tensors according to one placement.

    Args:
        tensors: Local tensors, one per shard, in rank order.
        placement: The placement describing how the tensors relate.

    Returns:
        The first tensor when the placement is replicated, or the tensors
        concatenated along ``placement.dim`` (made contiguous) when sharded.

    Raises:
        NotImplementedError: For a partial placement.
        ValueError: For any other placement kind.
    """
    # The checks run in the order replicate -> partial -> shard.
    if placement.is_replicate():
        return tensors[0]

    if placement.is_partial():
        raise NotImplementedError("Partial placement is not supported yet")

    if placement.is_shard():
        merged = torch.cat(tensors, dim=placement.dim)
        return merged.contiguous()

    raise ValueError(f"Unsupported placement: {placement}")
34 |
35 |
if __name__ == '__main__':
    # Merge the per-rank checkpoint files
    # (model_world_size_<N>_rank_<r>.pt) found in --local_dir into one
    # HuggingFace-format model written to <local_dir>/huggingface, and
    # optionally upload that folder to the HuggingFace Hub.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', required=True, type = str, help="The path for your saved model")
    parser.add_argument("--hf_upload_path", default=False, type = str, help="The path of the huggingface repo to upload")
    args = parser.parse_args()

    assert not args.local_dir.endswith("huggingface"), "The local_dir should not end with huggingface"
    local_dir = args.local_dir

    # copy rank zero to find the shape of (dp, fsdp)
    rank = 0
    world_size = 0
    for filename in os.listdir(local_dir):
        match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
        if match:
            # Kept as the matched string; it is only used to rebuild file names.
            world_size = match.group(1)
            break
    assert world_size, "No model file with the proper format"

    # SECURITY: weights_only=False unpickles arbitrary objects, so only run
    # this on checkpoints you trust. It is passed here to match the per-shard
    # load below; the shards are expected to hold DTensor objects.
    state_dict = torch.load(
        os.path.join(local_dir, f'model_world_size_{world_size}_rank_{rank}.pt'),
        map_location='cpu',
        weights_only=False,
    )
    pivot_key = sorted(list(state_dict.keys()))[0]
    weight = state_dict[pivot_key]
    assert isinstance(weight, torch.distributed._tensor.DTensor)
    # get sharding info
    device_mesh = weight.device_mesh
    mesh = device_mesh.mesh
    mesh_dim_names = device_mesh.mesh_dim_names

    print(f'Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}')

    assert mesh_dim_names in (
        ('fsdp',),
    ), f'Unsupported mesh_dim_names {mesh_dim_names}'

    if 'tp' in mesh_dim_names:
        # fsdp * tp (not reachable while only ('fsdp',) passes the assert above)
        total_shards = mesh.shape[-1] * mesh.shape[-2]
        mesh_shape = (mesh.shape[-2], mesh.shape[-1])
    else:
        # fsdp
        total_shards = mesh.shape[-1]
        mesh_shape = (mesh.shape[-1],)

    print(f'Processing model shards with {total_shards} {mesh_shape} in total')

    # Slot 0 is the rank-0 shard loaded above; the rest are filled by the pool.
    model_state_dict_lst = [state_dict] + [None] * (total_shards - 1)

    def process_one_shard(rank):
        """Load one rank's checkpoint file into its slot of the shard list."""
        model_path = os.path.join(local_dir, f'model_world_size_{world_size}_rank_{rank}.pt')
        state_dict = torch.load(model_path, map_location='cpu', weights_only=False)
        model_state_dict_lst[rank] = state_dict
        return state_dict

    # os.cpu_count() may return None, which min() cannot compare with an int.
    with ThreadPoolExecutor(max_workers=min(32, os.cpu_count() or 1)) as executor:
        futures = [executor.submit(process_one_shard, shard_rank) for shard_rank in range(1, total_shards)]
        for future in futures:
            # result() re-raises any exception from the worker, so a failed
            # load stops the merge instead of leaving an empty slot behind.
            future.result()

    state_dict = {}
    param_placements: Dict[str, List[Placement]] = {}
    keys = set(model_state_dict_lst[0].keys())
    for key in keys:
        state_dict[key] = []
        for shard_rank, model_state_dict in enumerate(model_state_dict_lst):
            if key not in model_state_dict:
                # Fail loudly: continuing would reuse the previous shard's
                # tensor (or hit an undefined name) and corrupt the merge.
                raise KeyError(f"Parameter {key} is missing from the shard of rank {shard_rank}")
            tensor = model_state_dict.pop(key)
            if isinstance(tensor, DTensor):
                state_dict[key].append(tensor._local_tensor.bfloat16())
                placements = tuple(tensor.placements)
                # replicated placement at dp dimension can be discarded
                if mesh_dim_names[0] == 'dp':
                    placements = placements[1:]
                if key not in param_placements:
                    param_placements[key] = placements
                else:
                    assert param_placements[key] == placements
            else:
                # Plain tensors are stored directly, replacing the list.
                state_dict[key] = tensor.bfloat16()

    del model_state_dict_lst

    for key in sorted(state_dict):
        if not isinstance(state_dict[key], list):
            print(f"No need to merge key {key}")
            continue
        # merge shards
        placements: Tuple[Shard] = param_placements[key]
        if len(mesh_shape) == 1:
            # 1-D list, FSDP without TP
            assert len(placements) == 1
            shards = state_dict[key]
            state_dict[key] = merge_by_placement(shards, placements[0])
        else:
            # 2-D list, FSDP + TP
            raise NotImplementedError("FSDP + TP is not supported yet")

    print('Writing to local disk')
    hf_path = os.path.join(local_dir, 'huggingface')
    config = AutoConfig.from_pretrained(hf_path)

    if 'ForTokenClassification' in config.architectures[0]:
        auto_model = AutoModelForTokenClassification
    elif 'ForCausalLM' in config.architectures[0]:
        auto_model = AutoModelForCausalLM
    else:
        # Attribute access, as in the checks above; subscripting the config
        # here would raise a TypeError and hide this message.
        raise NotImplementedError(f'Unknown architecture {config.architectures}')

    # Build the model skeleton on the meta device, then allocate
    # uninitialised CPU storage; the weights come from state_dict on save.
    with torch.device('meta'):
        model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
    model.to_empty(device='cpu')

    print(f'Saving model to {hf_path}')
    model.save_pretrained(hf_path, state_dict=state_dict)
    del state_dict
    del model
    if args.hf_upload_path:
        # Push to hugging face
        from huggingface_hub import HfApi
        api = HfApi()
        api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True)
        api.upload_folder(
            folder_path=hf_path,
            repo_id=args.hf_upload_path,
            repo_type="model"
        )
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # RISE 📈
4 |
5 |
6 | Reinforcing Reasoning with Self-Verification
7 |
8 |
9 | 🔥 An online RL framework that simultaneously trains LLMs in problem-solving and self-verification with verifiable reward signals. 🔥
10 |
11 |
12 |
13 |
14 |
15 | 
16 |
17 |
18 | ## 🗒️ News
19 | - **July 5, 2025**: We release the training script for the `Qwen3` series with RISE, based on verl 0.4.0, which achieves strong results.
20 | - **June 12, 2025**: We update the [**RISE source code**](https://github.com/xyliu-cs/verl/tree/verl-v4) to support the latest verl release **v0.4.0**.
21 | - **May 20, 2025**: We release our technical report on [**arXiv**](https://arxiv.org/abs/2505.13445) and the initial version of training code based on [**verl**](https://github.com/volcengine/verl).
22 |
23 | ## 🎯Quick Start (verl v0.4.0)
24 | #### Environment Preparation
25 | ```shell
26 | conda create -y -n qwen3 python=3.12.2 ; conda activate qwen3
27 | pip3 install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
28 | pip3 install omegaconf==2.4.0.dev3 hydra-core==1.4.0.dev1 antlr4-python3-runtime==4.11.0
29 | pip3 install vllm==0.8.5.post1
30 | pip3 install math-verify[antlr4_11_0]==0.7.0
31 | git clone -b verl-v4 https://github.com/xyliu-cs/verl.git verl-v4
32 | pip3 uninstall -y verl ; cd verl-v4 ; pip3 install -e .
33 | pip3 install flash-attn==2.7.4.post1 --no-build-isolation
34 | pip3 install fire deepspeed tensorboardX prettytable datasets transformers==4.51.3
35 | pip3 install flashinfer-python -i https://flashinfer.ai/whl/cu124/torch2.6/
36 | pip3 install langdetect==1.0.9 pebble==5.1.0 word2number
37 | ```
38 |
39 | #### Data Processing
40 |
41 | ```shell
42 | OUTPUT_DATA_DIR=/path/to/your/data/output
43 | # Input data sources are hard-coded in generate_splits_deepmath.py
44 | python3 verl_utils/data/generate_splits_deepmath.py --add_message --local_dir $OUTPUT_DATA_DIR
45 | ```
46 |
47 | #### Training
48 |
49 | * Start Ray
50 |
51 | ```shell
52 | # Head node (×1)
53 | ray start --head --port=6379 --node-ip-address=$HEAD_ADDR --num-gpus=8
54 |
55 | # Worker nodes (xN)
56 | # Use this only if you are running across multiple machines
57 | ray start --address=$HEAD_ADDR:6379 --node-ip-address=$WORKER_ADDR --num-gpus=8
58 | ```
59 |
60 | * Launch training at head node. See `scripts/train` for the complete training scripts.
61 | ```shell
62 | # Example
63 | sh scripts/train/start_qwen8b-base_rise_example.sh
64 | ```
65 | ‼️ **Key Parameters for RISE Algorithm**
66 |
67 | - `+trainer.online_critique`: Enables (`True`) or disables (`False`) online verification during the RL training.
68 | - `+data.critique_batch_size`: Controls the number of verification samples included in each training batch.
69 | - `trainer.critique_prompt_idx`: Index of the verification prompt used for RL training; the prompts can be customized in `verl/utils/critique_templates.py`. Default is 0.
70 | - `data.qwen3_thinking`: Enables (`True`) or disables (`False`) thinking mode for the Qwen3 (instruction-tuned) model. Set `True` for the base models.
71 | - `reward_model.reward_func_path`: Relative path (from `working_dir`) to the Python file defining the **generation reward** function. The file should contain a function named "reward_func".
72 | - `reward_model.ver_reward_func_path`: Path to the **verification reward** function file. This file should contain a function named "ver_reward_func". Default is `null`, and the generation reward function is used instead.
73 |
74 |
75 | ## 🎯Quick Start (verl v0.2.0)
76 |
77 | #### Environment Preparation
78 |
79 | ```shell
80 | git clone --recurse-submodules https://github.com/xyliu-cs/RISE.git && cd RISE
81 |
82 | conda create -y -n rise python=3.12.2 && conda activate rise
83 | pip3 install ray[default]
84 | pip3 install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
85 | pip3 install flash-attn==2.7.4.post1 --no-build-isolation
86 | pip3 install omegaconf==2.4.0.dev3 hydra-core==1.4.0.dev1 antlr4-python3-runtime==4.11.0 vllm==0.7.3
87 | pip3 install math-verify[antlr4_11_0]==0.7.0 fire deepspeed tensorboardX prettytable datasets
88 | cd verl
89 | pip3 install -e .
90 | ```
91 |
92 | #### Data Processing
93 |
94 | ```shell
95 | OUTPUT_DATA_DIR=/path/to/your/data/output
96 | # Input data paths are hard-coded in generate_splits.py
97 | python3 verl_utils/data/generate_splits.py --local_dir $OUTPUT_DATA_DIR
98 | ```
99 |
100 |
101 | #### Training
102 |
103 | * Start Ray
104 |
105 | ```shell
106 | # Head node (×1)
107 | ray start --head --port=6379 --node-ip-address=$HEAD_ADDR --num-gpus=8
108 |
109 | # Worker nodes (xN)
110 | # Use this only if you are running across multiple machines
111 | ray start --address=$HEAD_ADDR:6379 --node-ip-address=$WORKER_ADDR --num-gpus=8
112 | ```
113 |
114 | * Launch training at head node. See `scripts/train` for the complete training scripts.
115 | ```shell
116 | # Example
117 | sh scripts/train/start_qwen3b_rise_example.sh
118 | ```
119 | ‼️ **Key Parameters for RISE Algorithm**
120 |
121 | - `+trainer.online_critique`: Enables (`True`) or disables (`False`) online verification during the RL training.
122 | - `+data.critique_batch_size`: Controls the number of verification samples included in each training batch.
123 | - `reward_model.reward_func_path`: Relative path (from `working_dir`) to the Python file defining the **generation reward** function. The file should contain a function named "reward_func".
124 | - `reward_model.ver_reward_func_path`: Path to the **verification reward** function file. This file should contain a function named "ver_reward_func". Default is `null`, and the generation reward function is used instead.
125 |
126 |
127 | ## 🙏 Acknowledgements
128 |
129 | This work would not have been possible without the following projects:
130 |
131 | - **[verl](https://github.com/volcengine/verl)**: A very fast reinforcement learning framework.
132 | - **[vllm](https://github.com/vllm-project/vllm)**: A high-throughput and memory-efficient inference and serving engine for LLMs.
133 | - **[OpenMathInstruct-2](https://github.com/NVIDIA/NeMo-Skills)**: Model training and evaluation code.
134 | - **[SimpleRL](https://github.com/hkust-nlp/simpleRL-reason)**: RL training recipes for LLM reasoning.
135 | - **[DeepMath-103K](https://github.com/zwhe99/DeepMath)**: A Large-Scale, Challenging, Decontaminated, and Verifiable Mathematical Dataset for Advancing Reasoning.
136 |
137 |
138 |
139 | ## 📚 Citation
140 | ```bibtex
141 | @article{liu2025trustverifyselfverificationapproach,
142 | title={Trust, But Verify: A Self-Verification Approach to Reinforcement Learning with Verifiable Rewards},
143 | author={Xiaoyuan Liu and Tian Liang and Zhiwei He and Jiahao Xu and Wenxuan Wang and Pinjia He and Zhaopeng Tu and Haitao Mi and Dong Yu},
144 | year={2025},
145 | eprint={2505.13445},
146 | archivePrefix={arXiv},
147 | primaryClass={cs.AI},
148 | url={https://arxiv.org/abs/2505.13445},
149 | }
150 | ```
151 |
--------------------------------------------------------------------------------
/verl_utils/reward/openmathinst_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright (c) Microsoft Corporation.
16 | #
17 | # Permission is hereby granted, free of charge, to any person obtaining a copy
18 | # of this software and associated documentation files (the "Software"), to deal
19 | # in the Software without restriction, including without limitation the rights
20 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
21 | # copies of the Software, and to permit persons to whom the Software is
22 | # furnished to do so, subject to the following conditions:
23 | #
24 | # The above copyright notice and this permission notice shall be included in all
25 | # copies or substantial portions of the Software.
26 | #
27 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
28 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
29 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
30 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
31 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
33 | # SOFTWARE
34 |
35 | # Copyright (c) 2023 OpenAI
36 | #
37 | # Permission is hereby granted, free of charge, to any person obtaining a copy
38 | # of this software and associated documentation files (the "Software"), to deal
39 | # in the Software without restriction, including without limitation the rights
40 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
41 | # copies of the Software, and to permit persons to whom the Software is
42 | # furnished to do so, subject to the following conditions:
43 |
44 | # The above copyright notice and this permission notice shall be included in all
45 | # copies or substantial portions of the Software.
46 | #
47 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
48 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
49 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
50 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
51 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
52 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
53 | # SOFTWARE.
54 |
55 | # Copyright (c) 2021 Dan Hendrycks
56 | #
57 | # Permission is hereby granted, free of charge, to any person obtaining a copy
58 | # of this software and associated documentation files (the "Software"), to deal
59 | # in the Software without restriction, including without limitation the rights
60 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
61 | # copies of the Software, and to permit persons to whom the Software is
62 | # furnished to do so, subject to the following conditions:
63 | #
64 | # The above copyright notice and this permission notice shall be included in all
65 | # copies or substantial portions of the Software.
66 | #
67 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
68 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
69 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
70 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
71 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
72 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
73 | # SOFTWARE.
74 |
75 |
76 | """
77 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
78 | - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
79 | - https://github.com/microsoft/ProphetNet/tree/master/CRITIC
80 | - https://github.com/openai/prm800k
81 | """
82 |
83 |
84 | import contextlib
85 | import re
86 | import signal
87 | from importlib.metadata import PackageNotFoundError, version
88 | from math import isclose
89 | from typing import Union
90 | from collections import Counter
def most_common_element(data):
    """
    Return the element that occurs most often in a list.

    Parameters:
        data (list): Non-empty list of hashable elements.

    Returns:
        The element with the highest count. When several elements share the
        highest count, the one encountered first in ``data`` is returned.

    Raises:
        AssertionError: If ``data`` is empty.
    """
    assert data and len(data) > 0, "Data is empty"

    frequencies = Counter(data)
    [(winner, _occurrences)] = frequencies.most_common(1)
    return winner
106 |
107 | def _check_antlr_version():
108 | "Function for checking the antlr package version."
109 | # Check antlr version
110 | PACKAGE_NAME = 'antlr4-python3-runtime'
111 | REQUIRED_VERSION = '4.11.0'
112 |
113 | try:
114 | installed_version = version(PACKAGE_NAME)
115 | if installed_version != REQUIRED_VERSION:
116 | raise RuntimeError(
117 | f"Package {PACKAGE_NAME} version mismatch: {installed_version} (required: {REQUIRED_VERSION})"
118 | )
119 | except PackageNotFoundError:
120 | raise RuntimeError(f"Package {PACKAGE_NAME} not found. Please install antlr4-python3-runtime==4.11.0.")
121 |
122 |
123 | def _fix_fracs(string):
124 | # replacing all extra spaces
125 | while "\\frac " in string:
126 | string = string.replace("\\frac ", "\\frac")
127 | substrs = string.split("\\frac")
128 | new_str = substrs[0]
129 | if len(substrs) > 1:
130 | substrs = substrs[1:]
131 | for substr in substrs:
132 | new_str += "\\frac"
133 | if len(substr) > 0 and substr[0] == "{":
134 | new_str += substr
135 | else:
136 | try:
137 | assert len(substr) >= 2
138 | except:
139 | return string
140 | a = substr[0]
141 | b = substr[1]
142 | if b != "{":
143 | if len(substr) > 2:
144 | post_substr = substr[2:]
145 | new_str += "{" + a + "}{" + b + "}" + post_substr
146 | else:
147 | new_str += "{" + a + "}{" + b + "}"
148 | else:
149 | if len(substr) > 2:
150 | post_substr = substr[2:]
151 | new_str += "{" + a + "}" + b + post_substr
152 | else:
153 | new_str += "{" + a + "}" + b
154 | string = new_str
155 | return string
156 |
157 |
def _str_is_int(x: str) -> bool:
    """Return True when ``x`` parses to a number within 1e-7 of an integer.

    Thousands separators are removed with
    ``_strip_properly_formatted_commas`` before parsing. Any failure while
    stripping or converting is treated as "not an integer".

    Args:
        x: Candidate numeric string.

    Returns:
        True if ``x`` is (numerically) an integer, False otherwise.
    """
    try:
        x = _strip_properly_formatted_commas(x)
        x = float(x)
        return abs(x - int(round(x))) <= 1e-7
    except Exception:
        # ``Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        return False
165 |
166 |
167 | def _str_to_int(x: str) -> bool:
168 | x = x.replace(",", "")
169 | if "_" in x:
170 | # Due to base
171 | x = x.split("_")[0]
172 | x = float(x)
173 | return int(x)
174 |
175 |
176 | def _inject_implicit_mixed_number(step: str):
177 | """
178 | Automatically make a mixed number evalable
179 | e.g. 7 3/4 => 7+3/4
180 | """
181 | p1 = re.compile("([0-9]) +([0-9])")
182 | step = p1.sub("\\1+\\2", step) ## implicit mults
183 | return step
184 |
185 |
186 | def _strip_properly_formatted_commas(expr: str):
187 | # We want to be careful because we don't want to strip tuple commas
188 | p1 = re.compile("(\d)(,)(\d\d\d)($|\D)")
189 | while True:
190 | next_expr = p1.sub("\\1\\3\\4", expr)
191 | if next_expr == expr:
192 | break
193 | expr = next_expr
194 | return next_expr
195 |
196 |
197 | def _remove_right_units(expr):
198 | # "\\text{ " only ever occurs (at least in the val set) when describing units
199 | if "\\text" in expr:
200 | try:
201 | splits = re.split(r"\\text\s*{\s*", expr)
202 | # print(splits)
203 | assert len(splits) == 2 and splits[0] not in ("", "(")
204 | return splits[0]
205 | except AssertionError:
206 | pass
207 |
208 | if "\\text{" in expr:
209 | return re.sub(r"\\text{([^}]+)}", r"\1", expr)
210 | elif "\\mbox{" in expr:
211 | splits = expr.split("\\mbox{")
212 | if len(splits) == 2:
213 | return splits[0]
214 | else:
215 | return expr
216 | else:
217 | return expr
218 |
219 |
220 | def _process_and_or_inside_text(string):
221 | string = re.sub(r"\s*\\text{\s*(or|and)\s*}\s*", ",", string)
222 | string = re.sub(r",\s*,", ",", string)
223 | return string
224 |
225 |
226 | def _remove_left_and_right(expr):
227 | """Remove the right and left latex commands."""
228 | expr = re.sub(r"\\left", "", expr)
229 | expr = re.sub(r"\\right", "", expr)
230 | return expr
231 |
232 |
233 | def _fix_sqrt(string):
234 | _string = re.sub(r"\\sqrt(\s*\w+)", r"\\sqrt{\1}", string)
235 | return _string
236 |
237 |
238 | def _fix_interval(expr):
239 | """Fix interval expression."""
240 | if "\\in " in expr:
241 | return expr.split("\\in ")[1].strip()
242 |
243 | return expr
244 |
245 |
246 | def _inject_implicit_mixed_fraction(step: str):
247 | """
248 | Automatically make a mixed number evalable
249 | e.g. 7 \\frac{3}{4} => 7+3/4
250 | """
251 | p1 = re.compile(r"(\d+) *\\frac{(\d+)}{(\d+)}")
252 |
253 | def replacer(match):
254 | whole_part = match.group(1)
255 | numerator = match.group(2)
256 | denominator = match.group(3)
257 |
258 | if whole_part:
259 | return f"{whole_part} + {numerator}/{denominator}"
260 | else:
261 | return f"{numerator}/{denominator}"
262 |
263 | step = p1.sub(replacer, step)
264 | return step
265 |
266 |
def normalize_answer_string(expr: str) -> str:
    """Normalize answer expressions.

    Applies a fixed sequence of textual rewrites (LaTeX clean-up, removal of
    units, currency/percent signs and degree marks, fraction and square-root
    brace fixes, mixed-number expansion, whitespace removal) so that two
    answers can be compared as strings.

    Args:
        expr: Raw answer string, or None.

    Returns:
        The normalized string, or None when ``expr`` is None. A result that
        is numerically an integer is returned in canonical ``str(int)`` form.
    """
    if expr is None:
        return None

    expr = _remove_left_and_right(expr)
    expr = _process_and_or_inside_text(expr)
    expr = _remove_right_units(expr)
    expr = _fix_interval(expr)

    # Remove enclosing `\text{}` (and similar font commands).
    # NOTE: each entry holds two literal backslashes. That is the escaped form
    # needed inside the regex below; the plain `replace` therefore only
    # removes occurrences written with a doubled backslash.
    for surround_str in ["\\\\text", "\\\\mathrm", "\\\\mathcal", "\\\\textbf", "\\\\textit"]:
        expr = expr.replace(surround_str, "")
        # The named group must be called "text" because it is read back below.
        pattern = f"^{surround_str}" + r"\{(?P<text>.+?)\}$"
        m = re.search(pattern, expr)
        if m is not None:
            expr = m.group("text")

    expr = expr.replace("\\!", "")
    expr = expr.replace("\\%", "%")
    expr = expr.replace("\\$", "$")
    expr = expr.replace("$", "")
    expr = expr.replace("%", "")
    expr = expr.replace("^{\\circ}", "")

    expr = expr.replace(" or ", " , ")
    expr = expr.replace(" and ", " , ")

    expr = expr.replace("million", "*10^6")
    expr = expr.replace("billion", "*10^9")
    expr = expr.replace("trillion", "*10^12")

    # Unit names are used as regex fragments (so "." in "p.m." matches any
    # character) and are removed wherever they occur, including inside words.
    for unit in [
        "degree",
        "cm",
        "centimeter",
        "meter",
        "mile",
        "second",
        "minute",
        "hour",
        "week",
        "month",
        "year",
        "foot",
        "feet",
        "inch",
        "yard",
        "p.m.",
        "PM",
    ]:
        expr = re.sub(rf"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)

    if "day" in expr:
        days = [
            "Monday",
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        ]
        # Keep "day" when it is part of a weekday name that is the answer.
        weekday_expressed = any(day in expr for day in days)

        if not weekday_expressed:
            expr = re.sub(r"day(s)?", "", expr)

    expr = re.sub(r"\^ *\\circ", "", expr)

    if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
        expr = expr[1:-1]

    expr = _fix_sqrt(expr)

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    expr = _fix_fracs(expr)

    # edge case with mixed numbers and negative signs
    expr = re.sub("- *", "-", expr)
    expr = _inject_implicit_mixed_number(expr)
    expr = _inject_implicit_mixed_fraction(expr)
    expr = expr.replace(" ", "")

    if _str_is_int(expr):
        expr = str(_str_to_int(expr))

    return expr
359 |
360 |
def is_digit(s):
    """Try to read ``s`` as a float, ignoring thousands separators.

    If the text contains the LaTeX separator ``{,}``, only that separator is
    removed; otherwise plain commas are removed.

    Returns:
        ``(True, value)`` when the cleaned text parses as a float,
        ``(False, None)`` otherwise.
    """
    try:
        text = str(s)
        separator = "{,}" if "{,}" in text else ","
        return True, float(text.replace(separator, ""))
    except ValueError:
        return False, None
371 |
372 |
def normalize(answer) -> str:
    """Strip a leading dollar sign or a trailing percent sign from a numeric string.

    Non-string inputs, and strings matching neither form, are returned as is.
    """
    if not isinstance(answer, str):
        return answer

    # "$12" / "$12.50..." -> drop the leading "$" so amounts compare as numbers.
    if re.match(r'\$\d+(\.\d+)?', answer):
        return answer[1:]

    # "12%" or "12\%" (the whole string) -> drop the percent sign.
    is_plain_percent = re.match(r'^\d+(\.\d+)?%$', answer) is not None
    is_escaped_percent = re.match(r'^\d+(\.\d+)?\\%$', answer) is not None
    if is_plain_percent or is_escaped_percent:
        return answer.replace("\\%", "").replace("%", "")

    return answer
385 |
386 |
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    tolerance: float = 1e-4,
    timeout: float = 10.0,
    check_antlr_version: bool = True
) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    Args:
        prediction: Predicted answer.
        reference: Ground-truth answer.
        include_percentage: When True, a numeric prediction also matches the
            reference divided or multiplied by 100.
        tolerance: Relative tolerance used for numeric comparisons.
        timeout: Time limit in seconds handed to ``symbolic_equal``; it is
            also forwarded to the element-wise comparisons below.
        check_antlr_version: When True, verify the installed antlr runtime
            version before comparing.

    Returns:
        True if the two answers are considered equal, False otherwise.

    Raises:
        RuntimeError: From the antlr version check, when enabled and failing.
    """

    # Check that the right antlr version is installed.
    if check_antlr_version:
        _check_antlr_version()

    from sympy.parsing.sympy_parser import parse_expr

    def _all_pairs_equal(pairs):
        # Element-wise comparison with the caller's settings, including
        # `timeout`. A list (not a generator) keeps every pair evaluated.
        return all(
            [
                math_equal(
                    first,
                    second,
                    include_percentage,
                    tolerance,
                    timeout=timeout,
                    check_antlr_version=check_antlr_version,
                )
                for first, second in pairs
            ]
        )

    prediction = normalize(prediction)
    reference = normalize(reference)

    # another round of normalization
    prediction = normalize_answer_string(prediction)
    reference = normalize_answer_string(reference)

    if isinstance(prediction, str) and len(prediction) > 1000:  # handling weird corner-cases
        prediction = prediction[:1000]

    # 0. string comparison
    if isinstance(prediction, str) and isinstance(reference, str):
        if prediction.strip().lower() == reference.strip().lower():
            return True
        if prediction.replace(" ", "") == reference.replace(" ", ""):
            return True

    try:  # 1. numerical equal
        if is_digit(prediction)[0] and is_digit(reference)[0]:
            prediction = is_digit(prediction)[1]
            reference = is_digit(reference)[1]
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if isclose(item, prediction, rel_tol=tolerance):
                        return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## deal with [], (), {}
    prediction = format_intervals(prediction)

    pred_str, ref_str = prediction, reference
    if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (
        prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str == ref_str:
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (
        prediction
        and reference
        and prediction[0] in "(["
        and prediction[-1] in ")]"
        and prediction[0] == reference[0]
        and prediction[-1] == reference[-1]
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if _all_pairs_equal(zip(pred_parts, ref_parts)):
                return True

    if "," in prediction and "," in reference:
        pred_parts = [item.strip() for item in prediction.split(",")]
        ref_parts = [item.strip() for item in reference.split(",")]

        if len(pred_parts) == len(ref_parts):
            if _all_pairs_equal(zip(pred_parts, ref_parts)):
                return True
            else:
                return False

    # if we have point == tuple of values
    # startswith/endswith instead of indexing: an empty reference must not
    # raise IndexError here.
    if prediction.startswith("Point") and reference.startswith("(") and reference.endswith(")"):
        pred_parts = prediction[prediction.find("(") + 1 : -1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if _all_pairs_equal(zip(pred_parts, ref_parts)):
                return True

    # if reference is a matrix
    if reference.startswith("\\begin{pmatrix}") and prediction.startswith("Matrix"):
        try:
            pred_matrix = parse_expr(prediction)
            ref_matrix_items = reference.split()[1:-1:2]
            if len(pred_matrix) == len(ref_matrix_items):
                # Argument order (reference item first) is kept as in the
                # original comparison.
                if _all_pairs_equal(zip(ref_matrix_items, pred_matrix)):
                    return True
        except Exception:
            pass

    return symbolic_equal(prediction, reference, tolerance, timeout)
530 |
531 |
def symbolic_equal(a, b, tolerance, timeout=10.0):
    """Compare two expressions symbolically, then numerically, using sympy.

    Each input is parsed with ``parse_expr`` and, if that fails, with
    ``parse_latex``; an input that neither parser accepts is kept unchanged.

    Args:
        a: First expression.
        b: Second expression.
        tolerance: Relative tolerance for the numeric fallback comparison.
        timeout: Time limit in seconds applied to each parse attempt and to
            each of the two comparisons.

    Returns:
        True if ``simplify(a - b) == 0`` or the numeric values are close;
        False otherwise, including when a step raises or times out.
    """
    import sympy
    from sympy.parsing.latex import parse_latex
    from sympy.parsing.sympy_parser import parse_expr

    def _parse(text):
        # NOTE(review): sympy's parse_expr evaluates its input; the strings
        # handled here may come from model output - confirm this is acceptable.
        for parser in (parse_expr, parse_latex):
            try:
                with time_limit(timeout):
                    return parser(text)
            except Exception:
                pass
        return text

    def _holds(check):
        # Run one comparison under the time limit; any failure counts as False.
        try:
            with time_limit(timeout):
                return bool(check())
        except Exception:
            return False

    lhs = _parse(a)
    rhs = _parse(b)

    if _holds(lambda: sympy.simplify(lhs - rhs) == 0):
        return True
    return _holds(lambda: isclose(sympy.N(lhs), sympy.N(rhs), rel_tol=tolerance))
563 |
564 |
def extract_answer(string: str, extract_from_boxed: bool = True, extract_regex: str = r"The final answer is (.+)$"):
    """Extract Answer String from \\boxed expression or based on regex

    Args:
        string: Text to search.
        extract_from_boxed: When True, return the contents of the last
            ``\\boxed{...}`` in ``string``; when False, use ``extract_regex``.
        extract_regex: Pattern whose first group is the answer (regex mode).

    Returns:
        The extracted answer, or None when nothing suitable is found: no
        regex match, no ``\\boxed``, unbalanced braces, or ``\\boxed`` not
        directly followed by ``{``.
    """
    if not extract_from_boxed:
        match = re.search(extract_regex, string)
        return match.group(1) if match else None

    if "\\boxed" not in string:
        return None

    start = string.rfind("\\boxed")

    # Scan forward from the last \boxed for the brace that closes the first
    # brace group opened after it.
    depth = 0
    end = None
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        if char == "}":
            depth -= 1
            if depth == 0:
                end = pos
                break

    if end is None:
        return None

    boxed = string[start : end + 1]
    prefix = "\\boxed{"
    # Explicit check (not an assert) so that e.g. "\boxedx{1}" is still
    # rejected when Python runs with -O.
    if not boxed.startswith(prefix):
        return None
    return boxed[len(prefix) : -1]
610 |
611 |
class TimeoutException(Exception):
    """Raised by the ``time_limit`` signal handler when the time budget runs out."""
614 |
615 |
@contextlib.contextmanager
def time_limit(seconds: float):
    """Context manager that raises ``TimeoutException`` if the body runs too long.

    Implemented with ``SIGALRM`` and ``signal.setitimer``. On exit the timer
    is cancelled and the previously installed ``SIGALRM`` handler is put back.

    Args:
        seconds: Time budget for the ``with`` body.

    Raises:
        TimeoutException: Inside the ``with`` body, when the timer fires.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install the handler BEFORE arming the timer so the alarm can never be
    # delivered while the previous handler is still active.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        # None means the old handler was not installed from Python and
        # cannot be re-installed through signal.signal().
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
627 |
628 |
def format_intervals(prediction):
    """Rewrite a sympy-style ``Interval...(a, b)`` string in bracket notation.

    Interval(a, b)       -> [a, b]
    Interval.Ropen(a, b) -> [a, b)
    Interval.Lopen(a, b) -> (a, b]
    Interval.open(a, b)  -> (a, b)

    Any other string is returned unchanged.
    """
    # (pattern for the whole string, opening bracket, closing bracket)
    bracket_styles = (
        (r"^Interval\((.*)\)$", "[", "]"),
        (r"^Interval\.Ropen\((.*)\)$", "[", ")"),
        (r"^Interval\.Lopen\((.*)\)$", "(", "]"),
        (r"^Interval\.open\((.*)\)$", "(", ")"),
    )

    for regex, opening, closing in bracket_styles:
        found = re.match(regex, prediction)
        if found:
            return f"{opening}{found.group(1)}{closing}"

    return prediction
652 |
def process_results(
    response: Union[str, list[str]],
    answer: str,
    response_extract_from_boxed: bool = True,
    response_extract_regex: str = r"The final answer is (.+)$",
) -> bool:
    """Grade one response, or the majority answer of several, against ``answer``.

    Args:
        response: A single response string, or a list of responses. For a
            list, the most common extracted answer is the one that is graded.
        answer: Ground-truth answer.
        response_extract_from_boxed: Passed to ``extract_answer``.
        response_extract_regex: Passed to ``extract_answer``.

    Returns:
        The result of ``math_equal`` for the selected answer.

    Raises:
        ValueError: If ``response`` is neither a string nor a list.
    """
    if isinstance(response, str):
        candidate = extract_answer(response, response_extract_from_boxed, response_extract_regex)
    elif isinstance(response, list):
        extracted = [
            extract_answer(single, response_extract_from_boxed, response_extract_regex)
            for single in response
        ]
        candidate = most_common_element(extracted)
    else:
        raise ValueError(f"Invalid response type: {type(response)}")

    return math_equal(candidate, answer)
676 |
def reward_func(data_source, solution_str, ground_truth) -> float:
    """Score a solution by its boxed final answer.

    Args:
        data_source: Not used by this function.
        solution_str: Solution text expected to contain a boxed answer.
        ground_truth: Reference answer.

    Returns:
        -1.0 when no boxed answer can be extracted, 1.0 when the extracted
        answer equals ``ground_truth`` according to ``math_equal``, and
        -0.5 otherwise.
    """
    boxed_answer = extract_answer(solution_str, extract_from_boxed=True)
    if boxed_answer is None:
        return -1.0

    is_correct = math_equal(boxed_answer, ground_truth, check_antlr_version=False)
    return 1.0 if is_correct else -0.5
686 |
--------------------------------------------------------------------------------