├── assets └── rise_method.jpeg ├── .gitmodules ├── verl_utils ├── run_ckpt_merge.sh ├── reward │ ├── reward_func_verification.py │ ├── reward_func.py │ └── openmathinst_utils.py ├── data │ ├── generate_splits.py │ └── generate_splits_deepmath.py └── model_merger.py ├── LICENSE ├── scripts └── train │ ├── start_qwen3b_rise_example.sh │ └── start_qwen8b-base_rise_example.sh └── README.md /assets/rise_method.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyliu-cs/RISE/HEAD/assets/rise_method.jpeg -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "verl"] 2 | path = verl 3 | url = https://github.com/xyliu-cs/verl.git 4 | -------------------------------------------------------------------------------- /verl_utils/run_ckpt_merge.sh: -------------------------------------------------------------------------------- 1 | FINETUNE_MODEL_PATH=/path/to/your/model 2 | 3 | python model_merger.py \ 4 | --local_dir $FINETUNE_MODEL_PATH/global_step_96/actor 5 | -------------------------------------------------------------------------------- /verl_utils/reward/reward_func_verification.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .openmathinst_utils import extract_answer, math_equal 3 | except: 4 | from verl_utils.reward.openmathinst_utils import extract_answer, math_equal 5 | 6 | def ver_reward_func(data_source, solution_str, ground_truth, extra_info) -> float: 7 | extracted_answer = extract_answer(solution_str, extract_from_boxed=True) 8 | if extracted_answer is None: # formatting error 9 | return -1.0 10 | else: 11 | if math_equal(extracted_answer, ground_truth, check_antlr_version=False): 12 | return 1.0 13 | else: 14 | return -0.5 15 | 
-------------------------------------------------------------------------------- /verl_utils/reward/reward_func.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .openmathinst_utils import extract_answer, math_equal 3 | except: 4 | from verl_utils.reward.openmathinst_utils import extract_answer, math_equal 5 | 6 | def reward_func(data_source, solution_str, ground_truth, extra_info) -> float: 7 | extracted_answer = extract_answer(solution_str, extract_from_boxed=True) 8 | if extracted_answer is None: # formatting error 9 | return -1.0 10 | else: 11 | if math_equal(extracted_answer, ground_truth, check_antlr_version=False): 12 | return 1.0 13 | else: 14 | return -0.5 15 | 16 | def ver_reward_func(data_source, solution_str, ground_truth, extra_info) -> float: 17 | extracted_answer = extract_answer(solution_str, extract_from_boxed=True) 18 | if extracted_answer is None: # formatting error 19 | return -1.0 20 | if len(solution_str) < 800: 21 | return -1.0 22 | else: 23 | if math_equal(extracted_answer, ground_truth, check_antlr_version=False): 24 | return 1.0 25 | else: 26 | return -0.5 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Xiaoyuan Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /verl_utils/data/generate_splits.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess the MATH-Hard dataset to parquet format 3 | """ 4 | 5 | import os 6 | from datasets import load_dataset 7 | import argparse 8 | 9 | train_data_path = 'data/train/MATH_Hard.jsonl' 10 | val_data_path = 'data/train/MATH_val.jsonl' 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--local_dir', default='data/train') 15 | args = parser.parse_args() 16 | 17 | train_dataset = load_dataset("json", data_files={"train": train_data_path}, split="train") 18 | val_dataset = load_dataset("json", data_files={"val": val_data_path}, split="val") 19 | 20 | def process_fn_train(example, idx): 21 | data = { 22 | "data_source": train_data_path, 23 | "prompt": example['messages'], 24 | "ability": "math", 25 | "reward_model": { 26 | "style": "rule", 27 | "ground_truth": example['answer'] 28 | }, 29 | "extra_info": { 30 | 'split': 'train', 31 | 'index': idx, 32 | 'answer': example['answer'], 33 | "question": example['problem'], 34 | } 35 | } 36 | return data 37 | 38 | def process_fn_test(example, idx): 39 | data = { 40 | "data_source": val_data_path, 41 | "prompt": example['messages'], 42 | "ability": "math", 43 | "reward_model": { 44 | "style": "rule", 45 | "ground_truth": example['answer'] 46 | }, 47 | "extra_info": { 48 
| 'split': 'test', 49 | 'index': idx, 50 | 'answer': example['answer'], 51 | "question": example['problem'], 52 | } 53 | } 54 | return data 55 | 56 | train_dataset = train_dataset.map(function=process_fn_train, with_indices=True) 57 | test_dataset = val_dataset.map(function=process_fn_test, with_indices=True) 58 | train_dataset.to_parquet(os.path.join(args.local_dir, 'train.parquet')) 59 | test_dataset.to_parquet(os.path.join(args.local_dir, 'test.parquet')) 60 | -------------------------------------------------------------------------------- /scripts/train/start_qwen3b_rise_example.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | set -u 3 | 4 | WANDB_TOKEN=xxx 5 | RUN_NAME=xxx 6 | DATA_DIR=/path/to/your/data 7 | MODEL_DIR=/path/to/your/model 8 | SAVE_DIR=/path/to/your/output 9 | 10 | mkdir -p .checkpoints/$RUN_NAME 11 | mkdir -p $SAVE_DIR 12 | 13 | # set http_proxy if needed 14 | 15 | # ray start --head --num-cpus=8 --dashboard-port=8265 --dashboard-host=0.0.0.0 16 | 17 | sleep 10 18 | 19 | ray job submit --address="http://127.0.0.1:8265" \ 20 | --runtime-env-json='{ 21 | "env_vars": { 22 | "HUGGING_FACE_HUB_TOKEN": "your_huggingface_token", 23 | "LM_HARNESS_CACHE_PATH": "cache", 24 | "VLLM_ATTENTION_BACKEND": "XFORMERS", 25 | "PYTHONUNBUFFERED": "1", 26 | "WANDB_API_KEY": "your_wandb_token" 27 | }, 28 | "working_dir": "your_working_dir", 29 | "pip": ["latex2sympy2", "word2number", "timeout_decorator"] 30 | }' -- PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ 31 | data.train_files=$DATA_DIR/train.parquet \ 32 | data.val_files=$DATA_DIR/test.parquet \ 33 | data.prompt_key=prompt \ 34 | data.train_batch_size=1024 \ 35 | +data.critique_batch_size=128 \ 36 | data.val_batch_size=1024 \ 37 | data.max_prompt_length=6000 \ 38 | data.max_response_length=3000 \ 39 | actor_rollout_ref.model.path=$MODEL_DIR \ 40 | actor_rollout_ref.model.use_remove_padding=True \ 41 | actor_rollout_ref.actor.optim.lr=5e-7 \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \ 43 | actor_rollout_ref.actor.use_dynamic_bsz=True \ 44 | actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ 45 | actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ 46 | actor_rollout_ref.rollout.temperature=1.0 \ 47 | actor_rollout_ref.rollout.n=8 \ 48 | actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ 49 | actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ 50 | actor_rollout_ref.rollout.disable_log_stats=False \ 51 | actor_rollout_ref.rollout.enforce_eager=False \ 52 | actor_rollout_ref.rollout.free_cache_engine=False \ 53 | actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ 54 | actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=48000 \ 55 | actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ 56 | actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=48000 \ 57 | actor_rollout_ref.ref.fsdp_config.param_offload=True \ 58 | critic.optim.lr=9e-6 \ 59 | critic.model.path=$MODEL_DIR \ 60 | critic.model.use_remove_padding=True \ 61 | critic.ppo_max_token_len_per_gpu=24000 \ 62 | critic.forward_max_token_len_per_gpu=48000 \ 63 | reward_model.reward_func_path=verl_utils/reward/reward_func.py \ 64 | algorithm.kl_ctrl.kl_coef=0.01 \ 65 | trainer.project_name=verl \ 66 | trainer.experiment_name=$RUN_NAME \ 67 | trainer.default_local_dir=$SAVE_DIR/$RUN_NAME \ 68 | trainer.logger=['console','wandb'] \ 69 | +trainer.val_before_train=False \ 70 | +trainer.online_critique=True \ 71 | trainer.n_gpus_per_node=8 \ 72 | trainer.nnodes=1 \ 73 | trainer.save_freq=96 \ 74 | trainer.save_rollout=True \ 75 | trainer.test_freq=8 \ 76 | trainer.total_epochs=12 2>&1 | tee -a .checkpoints/$RUN_NAME/train.log 77 | 78 | ray stop 79 | -------------------------------------------------------------------------------- /verl_utils/data/generate_splits_deepmath.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess the deepmath dataset to parquet 
format 3 | """ 4 | 5 | import os 6 | import datasets 7 | import argparse 8 | 9 | train_source = 'xiaoyuanliu/DeepMath-10K' 10 | train_split = 'train' 11 | val_source="xiaoyuanliu/math-gen-critique" 12 | val_split="math_val" 13 | 14 | my_system_prompt = 'Please reason step by step, and put your final answer within \\boxed{}.' 15 | 16 | 17 | def format_messages(question, system_prompt=my_system_prompt): 18 | if system_prompt: 19 | message = [ 20 | {"role": "system", "content": system_prompt}, 21 | {"role": "user", "content": question} 22 | ] 23 | else: 24 | message = [ 25 | {"role": "user", "content": question} 26 | ] 27 | 28 | return message 29 | 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--local_dir', default='~/data/deepmath') 35 | parser.add_argument('--add_message', action='store_true', help='Whether to add message column to the dataset') 36 | args = parser.parse_args() 37 | 38 | train_dataset = datasets.load_dataset(train_source, split=train_split) 39 | val_dataset = datasets.load_dataset(val_source, split=val_split) 40 | 41 | if args.add_message: 42 | train_dataset = train_dataset.map( 43 | lambda x: {'messages': format_messages(x['question'], my_system_prompt)}, 44 | desc='Formatting messages for train dataset' 45 | ) 46 | 47 | def process_fn_train(example, idx): 48 | data = { 49 | "data_source": train_source, 50 | "prompt": example['messages'], 51 | "ability": "math", 52 | "reward_model": { 53 | "style": "rule", 54 | "ground_truth": example['final_answer'] 55 | }, 56 | "extra_info": { 57 | 'split': 'train', 58 | 'index': idx, 59 | 'answer': example['final_answer'], 60 | "question": example['question'], 61 | } 62 | } 63 | return data 64 | 65 | def process_fn_test(example, idx): 66 | data = { 67 | "data_source": val_source, 68 | "prompt": example['messages'], 69 | "ability": "math", 70 | "reward_model": { 71 | "style": "rule", 72 | "ground_truth": example['answer'] 73 | }, 74 | "extra_info": { 75 
| 'split': 'test', 76 | 'index': idx, 77 | 'answer': example['answer'], 78 | "question": example['problem'], 79 | } 80 | } 81 | return data 82 | 83 | 84 | train_dataset = train_dataset.map(function=process_fn_train, with_indices=True) 85 | test_dataset = val_dataset.map(function=process_fn_test, with_indices=True) 86 | # preview the first few entries 87 | print('-'* 50) 88 | print("Train dataset sample:") 89 | print(train_dataset[5]) 90 | print('-'* 50) 91 | print("Test dataset sample:") 92 | print(test_dataset[5]) 93 | train_dataset.to_parquet(os.path.join(args.local_dir, 'train.parquet')) 94 | test_dataset.to_parquet(os.path.join(args.local_dir, 'test.parquet')) 95 | 96 | -------------------------------------------------------------------------------- /scripts/train/start_qwen8b-base_rise_example.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | set -u 3 | 4 | 5 | WANDB_TOKEN=xxxx 6 | RUN_NAME=xxxx 7 | DATA_DIR=/path/to/deepmath_10K 8 | MODEL_DIR=/path/to/Qwen3-8B-Base 9 | SAVE_DIR=/path/to/Qwen3-8B-Base-DeepMath10K-PPO-RISE 10 | 11 | mkdir -p .checkpoints/$RUN_NAME 12 | mkdir -p $SAVE_DIR 13 | 14 | # set http_proxy if needed 15 | ray start --head --num-cpus=16 --dashboard-port=8265 --dashboard-host=0.0.0.0 16 | 17 | sleep 10 18 | 19 | ray job submit --address="http://127.0.0.1:8265" \ 20 | --runtime-env-json='{ 21 | "env_vars": { 22 | "HUGGING_FACE_HUB_TOKEN": "xxxx", 23 | "LM_HARNESS_CACHE_PATH": "cache", 24 | "PYTHONUNBUFFERED": "1", 25 | "WANDB_API_KEY": "xxxx" 26 | }, 27 | "working_dir": "/path/to/your/working_dir", 28 | "pip": ["latex2sympy2", "word2number", "timeout_decorator"], 29 | "excludes": [".git/**"] 30 | }' -- PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ 31 | data.train_files=$DATA_DIR/train.parquet \ 32 | data.val_files=$DATA_DIR/test.parquet \ 33 | data.prompt_key=prompt \ 34 | data.train_batch_size=1024 \ 35 | +data.critique_batch_size=128 \ 36 | data.max_prompt_length=3072 \ 37 | 
data.max_response_length=8192 \ 38 | data.qwen3_thinking=True \ 39 | data.truncation=right \ 40 | actor_rollout_ref.model.path=$MODEL_DIR \ 41 | actor_rollout_ref.model.use_remove_padding=True \ 42 | actor_rollout_ref.actor.optim.lr=5e-7 \ 43 | actor_rollout_ref.actor.ppo_mini_batch_size=256 \ 44 | actor_rollout_ref.actor.use_dynamic_bsz=True \ 45 | actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ 46 | actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ 47 | actor_rollout_ref.rollout.temperature=1.0 \ 48 | actor_rollout_ref.rollout.n=8 \ 49 | actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ 50 | actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ 51 | actor_rollout_ref.rollout.disable_log_stats=False \ 52 | actor_rollout_ref.rollout.enforce_eager=False \ 53 | actor_rollout_ref.rollout.free_cache_engine=False \ 54 | actor_rollout_ref.rollout.max_num_batched_tokens=24000 \ 55 | actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ 56 | actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ 57 | actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ 58 | actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=24000 \ 59 | actor_rollout_ref.ref.fsdp_config.param_offload=True \ 60 | critic.optim.lr=9e-6 \ 61 | critic.model.path=$MODEL_DIR \ 62 | critic.model.use_remove_padding=True \ 63 | critic.ppo_max_token_len_per_gpu=24000 \ 64 | critic.forward_max_token_len_per_gpu=24000 \ 65 | custom_reward_function.path=verl_utils/reward/reward_func.py \ 66 | custom_reward_function.name=reward_func \ 67 | algorithm.kl_ctrl.kl_coef=0.01 \ 68 | trainer.project_name=verl \ 69 | trainer.experiment_name=$RUN_NAME \ 70 | trainer.default_local_dir=$SAVE_DIR/$RUN_NAME \ 71 | trainer.logger=['console','wandb'] \ 72 | trainer.val_before_train=False \ 73 | +trainer.online_critique=True \ 74 | trainer.critique_prompt_idx=0 \ 75 | trainer.n_gpus_per_node=8 \ 76 | trainer.nnodes=1 \ 77 | trainer.save_freq=8 \ 78 | 
trainer.rollout_data_dir=/path/to/your/rollout_data_dir \ 79 | trainer.test_freq=8 \ 80 | trainer.total_epochs=12 2>&1 | tee -a .checkpoints/$RUN_NAME/train.log 81 | 82 | ray stop 83 | -------------------------------------------------------------------------------- /verl_utils/model_merger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import List, Tuple, Dict 16 | import re 17 | import os 18 | import torch 19 | import argparse 20 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification 21 | from concurrent.futures import ThreadPoolExecutor 22 | from torch.distributed._tensor import DTensor, Shard, Placement 23 | 24 | 25 | def merge_by_placement(tensors: List[torch.Tensor], placement: Placement): 26 | if placement.is_replicate(): 27 | return tensors[0] 28 | elif placement.is_partial(): 29 | raise NotImplementedError("Partial placement is not supported yet") 30 | elif placement.is_shard(): 31 | return torch.cat(tensors, dim=placement.dim).contiguous() 32 | else: 33 | raise ValueError(f"Unsupported placement: {placement}") 34 | 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--local_dir', required=True, type = str, help="The path for your saved model") 39 | parser.add_argument("--hf_upload_path", default=False, type = str, help="The path of the huggingface repo to upload") 40 | args = parser.parse_args() 41 | 42 | assert not args.local_dir.endswith("huggingface"), "The local_dir should not end with huggingface" 43 | local_dir = args.local_dir 44 | 45 | # copy rank zero to find the shape of (dp, fsdp) 46 | rank = 0 47 | world_size = 0 48 | for filename in os.listdir(local_dir): 49 | match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename) 50 | if match: 51 | world_size = match.group(1) 52 | break 53 | assert world_size, "No model file with the proper format" 54 | 55 | state_dict = torch.load(os.path.join(local_dir, f'model_world_size_{world_size}_rank_{rank}.pt'), map_location='cpu') 56 | pivot_key = sorted(list(state_dict.keys()))[0] 57 | weight = state_dict[pivot_key] 58 | assert isinstance(weight, torch.distributed._tensor.DTensor) 59 | # get sharding info 60 | device_mesh = weight.device_mesh 61 | mesh = device_mesh.mesh 62 | mesh_dim_names = device_mesh.mesh_dim_names 63 | 64 
| print(f'Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}') 65 | 66 | assert mesh_dim_names in ( 67 | ('fsdp',), 68 | ), f'Unsupported mesh_dim_names {mesh_dim_names}' 69 | 70 | if 'tp' in mesh_dim_names: 71 | # fsdp * tp 72 | total_shards = mesh.shape[-1] * mesh.shape[-2] 73 | mesh_shape = (mesh.shape[-2], mesh.shape[-1]) 74 | else: 75 | # fsdp 76 | total_shards = mesh.shape[-1] 77 | mesh_shape = (mesh.shape[-1],) 78 | 79 | print(f'Processing model shards with {total_shards} {mesh_shape} in total') 80 | 81 | model_state_dict_lst = [] 82 | model_state_dict_lst.append(state_dict) 83 | model_state_dict_lst.extend([""] * (total_shards - 1)) 84 | 85 | def process_one_shard(rank): 86 | model_path = os.path.join(local_dir, f'model_world_size_{world_size}_rank_{rank}.pt') 87 | state_dict = torch.load(model_path, map_location='cpu', weights_only=False) 88 | model_state_dict_lst[rank] = state_dict 89 | return state_dict 90 | 91 | with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: 92 | for rank in range(1, total_shards): 93 | executor.submit(process_one_shard, rank) 94 | state_dict = {} 95 | param_placements: Dict[str, List[Placement]] = {} 96 | keys = set(model_state_dict_lst[0].keys()) 97 | for key in keys: 98 | state_dict[key] = [] 99 | for model_state_dict in model_state_dict_lst: 100 | try: 101 | tensor = model_state_dict.pop(key) 102 | except: 103 | print("-"*30) 104 | print(model_state_dict) 105 | if isinstance(tensor, DTensor): 106 | state_dict[key].append(tensor._local_tensor.bfloat16()) 107 | placements = tuple(tensor.placements) 108 | # replicated placement at dp dimension can be discarded 109 | if mesh_dim_names[0] == 'dp': 110 | placements = placements[1:] 111 | if key not in param_placements: 112 | param_placements[key] = placements 113 | else: 114 | assert param_placements[key] == placements 115 | else: 116 | state_dict[key] = tensor.bfloat16() 117 | 118 | del model_state_dict_lst 119 | 120 | for key in sorted(state_dict): 121 | 
if not isinstance(state_dict[key], list): 122 | print(f"No need to merge key {key}") 123 | continue 124 | # merge shards 125 | placements: Tuple[Shard] = param_placements[key] 126 | if len(mesh_shape) == 1: 127 | # 1-D list, FSDP without TP 128 | assert len(placements) == 1 129 | shards = state_dict[key] 130 | state_dict[key] = merge_by_placement(shards, placements[0]) 131 | else: 132 | # 2-D list, FSDP + TP 133 | raise NotImplementedError("FSDP + TP is not supported yet") 134 | 135 | print('Writing to local disk') 136 | hf_path = os.path.join(local_dir, 'huggingface') 137 | config = AutoConfig.from_pretrained(hf_path) 138 | 139 | if 'ForTokenClassification' in config.architectures[0]: 140 | auto_model = AutoModelForTokenClassification 141 | elif 'ForCausalLM' in config.architectures[0]: 142 | auto_model = AutoModelForCausalLM 143 | else: 144 | raise NotImplementedError(f'Unknown architecture {config["architectures"]}') 145 | 146 | with torch.device('meta'): 147 | model = auto_model.from_config(config, torch_dtype=torch.bfloat16) 148 | model.to_empty(device='cpu') 149 | 150 | print(f'Saving model to {hf_path}') 151 | model.save_pretrained(hf_path, state_dict=state_dict) 152 | del state_dict 153 | del model 154 | if args.hf_upload_path: 155 | # Push to hugging face 156 | from huggingface_hub import HfApi 157 | api = HfApi() 158 | api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True) 159 | api.upload_folder( 160 | folder_path=hf_path, 161 | repo_id=args.hf_upload_path, 162 | repo_type="model" 163 | ) 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # RISE 📈 4 | 5 |
6 | Reinforcing Reasoning with Self-Verification 7 |
8 |
9 | 🔥 An online RL framework that simultaneously trains LLMs in problem-solving and self-verification with verifiable reward signals. 🔥 10 |
11 |
12 | 13 |
14 | 15 | ![Logo](./assets/rise_method.jpeg) 16 | 17 | 18 | ## 🗒️ News 19 | - **July 5, 2025**: We release the training script of `Qwen3` series on RISE based on verl 0.4.0, which achieves strong results. 20 | - **June 12, 2025**: We update the [**RISE source code**](https://github.com/xyliu-cs/verl/tree/verl-v4) to support the latest verl release **v0.4.0**. 21 | - **May 20, 2025**: We release our technical report on [**arXiv**](https://arxiv.org/abs/2505.13445) and the initial version of training code based on [**verl**](https://github.com/volcengine/verl). 22 | 23 | ## 🎯Quick Start (verl v0.4.0) 24 | #### Environment Preparation 25 | ```shell 26 | conda create -y -n qwen3 python=3.12.2 ; conda activate qwen3 27 | pip3 install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 28 | pip3 install omegaconf==2.4.0.dev3 hydra-core==1.4.0.dev1 antlr4-python3-runtime==4.11.0 29 | pip3 install vllm==0.8.5.post1 30 | pip3 install math-verify[antlr4_11_0]==0.7.0 31 | git clone -b verl-v4 https://github.com/xyliu-cs/verl.git verl-v4 32 | pip3 uninstall -y verl ; cd verl-v4 ; pip3 install -e . 
33 | pip3 install flash-attn==2.7.4.post1 --no-build-isolation 34 | pip3 install fire deepspeed tensorboardX prettytable datasets transformers==4.51.3 35 | pip3 install flashinfer-python -i https://flashinfer.ai/whl/cu124/torch2.6/ 36 | pip3 install langdetect==1.0.9 pebble==5.1.0 word2number 37 | ``` 38 | 39 | #### Data Processing 40 | 41 | ```shell 42 | OUTPUT_DATA_DIR=/path/to/your/data/output 43 | # Input datasets are hard-coded in generate_splits_deepmath.py 44 | python3 verl_utils/data/generate_splits_deepmath.py --add_message --local_dir $OUTPUT_DATA_DIR 45 | ``` 46 | 47 | #### Training 48 | 49 | * Start Ray 50 | 51 | ```shell 52 | # Head node (×1) 53 | ray start --head --port=6379 --node-ip-address=$HEAD_ADDR --num-gpus=8 54 | 55 | # Worker nodes (×N) 56 | # Use this only if you are running across multiple machines 57 | ray start --address=$HEAD_ADDR:6379 --node-ip-address=$WORKER_ADDR --num-gpus=8 58 | ``` 59 | 60 | * Launch training on the head node. See `scripts/train` for the complete training scripts. 61 | ```shell 62 | # Example 63 | sh scripts/train/start_qwen8b-base_rise_example.sh 64 | ``` 65 | ‼️ **Key Parameters for RISE Algorithm** 66 | 67 | - `+trainer.online_critique`: Enables (`True`) or disables (`False`) online verification during RL training. 68 | - `+data.critique_batch_size`: Controls the number of verification samples included in each training batch. 69 | - `trainer.critique_prompt_idx`: Index of the verification prompt used during RL training; the prompts can be customized in `verl/utils/critique_templates.py`. Default is 0. 70 | - `data.qwen3_thinking`: Enables (`True`) or disables (`False`) thinking mode for the Qwen3 (instruction-tuned) model. Set it to `True` for base models, as the Qwen3-8B-Base example script does. 71 | - `custom_reward_function.path` / `custom_reward_function.name`: Relative path (from `working_dir`) to the Python file defining the **generation reward** function, and the name of that function (the example script uses `verl_utils/reward/reward_func.py` and `reward_func`). 72 | - `reward_model.ver_reward_func_path`: Path to the **verification reward** function file.
This file should contain a function named "ver_reward_func". Default is `null`, and the generation reward function is used instead. 73 | 74 | 75 | ## 🎯Quick Start (verl v0.2.0) 76 | 77 | #### Environment Preparation 78 | 79 | ```shell 80 | git clone --recurse-submodules https://github.com/xyliu-cs/RISE.git && cd RISE 81 | 82 | conda create -y -n rise python=3.12.2 && conda activate rise 83 | pip3 install ray[default] 84 | pip3 install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124 85 | pip3 install flash-attn==2.7.4.post1 --no-build-isolation 86 | pip3 install omegaconf==2.4.0.dev3 hydra-core==1.4.0.dev1 antlr4-python3-runtime==4.11.0 vllm==0.7.3 87 | pip3 install math-verify[antlr4_11_0]==0.7.0 fire deepspeed tensorboardX prettytable datasets 88 | cd verl 89 | pip3 install -e . 90 | ``` 91 | 92 | #### Data Processing 93 | 94 | ```shell 95 | OUTPUT_DATA_DIR=/path/to/your/data/output 96 | # Input data paths are hard-coded in generate_splits.py 97 | python3 verl_utils/data/generate_splits.py --local_dir $OUTPUT_DATA_DIR 98 | ``` 99 | 100 | 101 | #### Training 102 | 103 | * Start Ray 104 | 105 | ```shell 106 | # Head node (×1) 107 | ray start --head --port=6379 --node-ip-address=$HEAD_ADDR --num-gpus=8 108 | 109 | # Worker nodes (×N) 110 | # Use this only if you are running across multiple machines 111 | ray start --address=$HEAD_ADDR:6379 --node-ip-address=$WORKER_ADDR --num-gpus=8 112 | ``` 113 | 114 | * Launch training on the head node. See `scripts/train` for the complete training scripts. 115 | ```shell 116 | # Example 117 | sh scripts/train/start_qwen3b_rise_example.sh 118 | ``` 119 | ‼️ **Key Parameters for RISE Algorithm** 120 | 121 | - `+trainer.online_critique`: Enables (`True`) or disables (`False`) online verification during RL training. 122 | - `+data.critique_batch_size`: Controls the number of verification samples included in each training batch.
123 | - `reward_model.reward_func_path`: Relative path (from `working_dir`) to the Python file defining the **generation reward** function. The file should contain a function named "reward_func". 124 | - `reward_model.ver_reward_func_path`: Path to the **verification reward** function file. This file should contain a function named "ver_reward_func". Default is `null`, and the generation reward function is used instead. 125 | 126 | 127 | ## 🙏 Acknowledgements 128 | 129 | This work would not have been possible without the following projects: 130 | 131 | - **[verl](https://github.com/volcengine/verl)**: A very fast reinforcement learning framework. 132 | - **[vllm](https://github.com/vllm-project/vllm)**: A high-throughput and memory-efficient inference and serving engine for LLMs. 133 | - **[OpenMathInstruct-2](https://github.com/NVIDIA/NeMo-Skills)**: Model training and evaluation code. 134 | - **[SimpleRL](https://github.com/hkust-nlp/simpleRL-reason)**: RL training recipes for LLM reasoning. 135 | - **[DeepMath-103K](https://github.com/zwhe99/DeepMath)**: A Large-Scale, Challenging, Decontaminated, and Verifiable Mathematical Dataset for Advancing Reasoning. 136 | 137 | 138 | 139 | ## 📚 Citation 140 | ```bibtex 141 | @article{liu2025trustverifyselfverificationapproach, 142 | title={Trust, But Verify: A Self-Verification Approach to Reinforcement Learning with Verifiable Rewards}, 143 | author={Xiaoyuan Liu and Tian Liang and Zhiwei He and Jiahao Xu and Wenxuan Wang and Pinjia He and Zhaopeng Tu and Haitao Mi and Dong Yu}, 144 | year={2025}, 145 | eprint={2505.13445}, 146 | archivePrefix={arXiv}, 147 | primaryClass={cs.AI}, 148 | url={https://arxiv.org/abs/2505.13445}, 149 | } 150 | ``` 151 | -------------------------------------------------------------------------------- /verl_utils/reward/openmathinst_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright (c) Microsoft Corporation. 16 | # 17 | # Permission is hereby granted, free of charge, to any person obtaining a copy 18 | # of this software and associated documentation files (the "Software"), to deal 19 | # in the Software without restriction, including without limitation the rights 20 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 21 | # copies of the Software, and to permit persons to whom the Software is 22 | # furnished to do so, subject to the following conditions: 23 | # 24 | # The above copyright notice and this permission notice shall be included in all 25 | # copies or substantial portions of the Software. 26 | # 27 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 28 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 29 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 30 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 31 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 32 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 33 | # SOFTWARE 34 | 35 | # Copyright (c) 2023 OpenAI 36 | # 37 | # Permission is hereby granted, free of charge, to any person obtaining a copy 38 | # of this software and associated documentation files (the "Software"), to deal 39 | # in the Software without restriction, including without limitation the rights 40 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 41 | # copies of the Software, and to permit persons to whom the Software is 42 | # furnished to do so, subject to the following conditions: 43 | 44 | # The above copyright notice and this permission notice shall be included in all 45 | # copies or substantial portions of the Software. 46 | # 47 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 48 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 49 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 50 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 51 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 52 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 53 | # SOFTWARE. 
# Copyright (c) 2021 Dan Hendrycks
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


"""
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
- https://github.com/openai/prm800k
"""


import contextlib
import re
import signal
from collections import Counter
from importlib.metadata import PackageNotFoundError, version
from math import isclose
from typing import Union


def most_common_element(data):
    """
    Return the most frequent element of ``data``.

    Ties go to the element encountered first, because ``Counter.most_common``
    orders equal counts by first insertion.

    Parameters:
        data (list): The elements to vote over (must be hashable).

    Returns:
        The most common element in the list.

    Raises:
        ValueError: If ``data`` is empty. An explicit raise is used instead of
            ``assert``, which is stripped under ``python -O``.
    """
    if not data:
        raise ValueError("Data is empty")
    return Counter(data).most_common(1)[0][0]


def _check_antlr_version():
    """Check that ``antlr4-python3-runtime`` is installed at exactly version 4.11.0.

    Raises:
        RuntimeError: If the package is missing or its version differs.
    """
    PACKAGE_NAME = 'antlr4-python3-runtime'
    REQUIRED_VERSION = '4.11.0'

    # Keep the try body to the single call that can raise PackageNotFoundError.
    try:
        installed_version = version(PACKAGE_NAME)
    except PackageNotFoundError:
        raise RuntimeError(f"Package {PACKAGE_NAME} not found. Please install antlr4-python3-runtime==4.11.0.") from None

    if installed_version != REQUIRED_VERSION:
        raise RuntimeError(
            f"Package {PACKAGE_NAME} version mismatch: {installed_version} (required: {REQUIRED_VERSION})"
        )


def _fix_fracs(string):
    """Brace single-character ``\\frac`` arguments.

    ``\\frac12`` -> ``\\frac{1}{2}``, ``\\frac1b`` -> ``\\frac{1}{b}`` and
    ``\\frac1{72}`` -> ``\\frac{1}{72}``. Arguments that already start with ``{``
    are left alone. If some ``\\frac`` is followed by fewer than two characters,
    the input is returned unchanged (apart from the space removal below).
    """
    # "\frac 12" -> "\frac12" so the argument scan sees the characters directly.
    while "\\frac " in string:
        string = string.replace("\\frac ", "\\frac")

    substrs = string.split("\\frac")
    new_str = substrs[0]
    for substr in substrs[1:]:
        new_str += "\\frac"
        if substr and substr[0] == "{":
            new_str += substr
            continue
        if len(substr) < 2:
            # Not enough characters for numerator + denominator: give up.
            return string
        a, b, post_substr = substr[0], substr[1], substr[2:]
        if b != "{":
            new_str += "{" + a + "}{" + b + "}" + post_substr
        else:
            # Only the numerator is bare; the denominator is already braced.
            new_str += "{" + a + "}" + b + post_substr
    return new_str


def _str_is_int(x: str) -> bool:
    """Return True if ``x`` parses as a number within 1e-7 of an integer.

    Properly formatted thousands separators are stripped first. Unparseable
    input, ``inf`` and ``nan`` all yield False.
    """
    try:
        value = float(_strip_properly_formatted_commas(x))
        return abs(value - int(round(value))) <= 1e-7
    except (TypeError, ValueError, OverflowError):
        # TypeError: non-string input; ValueError: not a number / nan;
        # OverflowError: round(inf).
        return False
def _str_to_int(x: str) -> int:
    """Convert a numeric string to an ``int``, truncating toward zero.

    All commas are removed, and anything after the first ``_`` (a base
    subscript such as ``101_2``) is dropped before converting.
    """
    x = x.replace(",", "")
    if "_" in x:
        # Due to base: keep only the digits in front of the subscript.
        x = x.split("_")[0]
    return int(float(x))


def _inject_implicit_mixed_number(step: str):
    """
    Automatically make a mixed number evalable
    e.g. 7 3/4 => 7+3/4
    """
    # "digit <spaces> digit" is rewritten as an explicit addition.
    return re.sub(r"([0-9]) +([0-9])", r"\1+\2", step)


def _strip_properly_formatted_commas(expr: str):
    """Remove thousands separators (``1,000`` -> ``1000``) while keeping tuple commas.

    A comma is removed only when it sits between a digit and exactly three
    digits that are followed by a non-digit or the end of the string. The
    substitution is repeated until the string stops changing.
    """
    # Raw string: "\d"/"\D" in a plain literal are invalid escape sequences.
    thousands_sep = re.compile(r"(\d)(,)(\d\d\d)($|\D)")
    while True:
        next_expr = thousands_sep.sub(r"\1\3\4", expr)
        if next_expr == expr:
            return next_expr
        expr = next_expr


def _remove_right_units(expr):
    """Strip unit annotations written with ``\\text{...}`` or ``\\mbox{...}``.

    If splitting on ``\\text{`` yields exactly two pieces and the first is not
    empty (or just ``(``), the first piece is returned. Otherwise every
    ``\\text{...}`` wrapper is unwrapped in place, and a single ``\\mbox{``
    suffix is cut off.
    """
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text" in expr:
        splits = re.split(r"\\text\s*{\s*", expr)
        if len(splits) == 2 and splits[0] not in ("", "("):
            return splits[0]

    if "\\text{" in expr:
        return re.sub(r"\\text{([^}]+)}", r"\1", expr)
    if "\\mbox{" in expr:
        splits = expr.split("\\mbox{")
        if len(splits) == 2:
            return splits[0]
    return expr


def _process_and_or_inside_text(string):
    """Turn ``\\text{or}`` / ``\\text{and}`` separators into commas and collapse doubled commas."""
    string = re.sub(r"\s*\\text{\s*(or|and)\s*}\s*", ",", string)
    string = re.sub(r",\s*,", ",", string)
    return string


def _remove_left_and_right(expr):
    """Remove the right and left latex commands."""
    # Plain substring removal, so the prefix of e.g. "\leftarrow" is removed too.
    return expr.replace("\\left", "").replace("\\right", "")


def _fix_sqrt(string):
    """Brace a bare ``\\sqrt`` argument, e.g. ``\\sqrt2`` -> ``\\sqrt{2}``."""
    return re.sub(r"\\sqrt(\s*\w+)", r"\\sqrt{\1}", string)


def _fix_interval(expr):
    """Fix interval expression: keep the part after the first ``\\in `` (up to the next one)."""
    if "\\in " in expr:
        return expr.split("\\in ")[1].strip()
    return expr


def _inject_implicit_mixed_fraction(step: str):
    """
    Automatically make a mixed number evalable
    e.g. 7 \\frac{3}{4} => 7 + 3/4
    """
    mixed = re.compile(r"(\d+) *\\frac\{(\d+)\}\{(\d+)\}")
    # group(1) is never empty (it is `\d+`), so every match has a whole part.
    return mixed.sub(lambda m: f"{m.group(1)} + {m.group(2)}/{m.group(3)}", step)


def normalize_answer_string(expr: str) -> str:
    """Normalize an answer string so that equivalent answers compare equal.

    The steps are applied in this order:
    - strip ``\\left``/``\\right``;
    - turn ``\\text{or|and}`` into commas, then remove unit text and ``\\in``
      prefixes;
    - unwrap ``\\text``-style commands;
    - drop currency, percent and degree markers and common unit words;
    - expand million/billion/trillion;
    - brace ``\\sqrt``/``\\frac`` arguments;
    - turn mixed numbers into sums and remove all spaces;
    - rewrite near-integers as plain integers.

    Returns ``None`` when ``expr`` is ``None``.
    """
    if expr is None:
        return None

    expr = _remove_left_and_right(expr)
    expr = _process_and_or_inside_text(expr)
    expr = _remove_right_units(expr)
    expr = _fix_interval(expr)
    for surround_str in ["\\\\text", "\\\\mathrm", "\\\\mathcal", "\\\\textbf", "\\\\textit"]:
        expr = expr.replace(surround_str, "")
        # Unwrap an answer that is entirely "\cmd{...}". The group must be named
        # "(?P<text>...)"; a bare "(?P" does not compile and raises re.error.
        pattern = f"^{surround_str}" + r"\{(?P<text>.+?)\}$"
        m = re.search(pattern, expr)
        if m is not None:
            expr = m.group("text")

    expr = expr.replace("\\!", "")
    expr = expr.replace("\\%", "%")
    expr = expr.replace("\\$", "$")
    expr = expr.replace("$", "")
    expr = expr.replace("%", "")
    expr = expr.replace("^{\\circ}", "")

    expr = expr.replace(" or ", " , ")
    expr = expr.replace(" and ", " , ")

    expr = expr.replace("million", "*10^6")
    expr = expr.replace("billion", "*10^9")
    expr = expr.replace("trillion", "*10^12")

    for unit in [
        "degree",
        "cm",
        "centimeter",
        "meter",
        "mile",
        "second",
        "minute",
        "hour",
        "week",
        "month",
        "year",
        "foot",
        "feet",
        "inch",
        "yard",
        "p.m.",
        "PM",
    ]:
        # NOTE: units are used as regex fragments, so the dots in "p.m." match any character.
        expr = re.sub(rf"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)

    if "day" in expr:
        days = [
            "Monday",
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        ]
        # Only strip "day(s)" when the answer is not itself a weekday name.
        if not any(day in expr for day in days):
            expr = re.sub("day(s)?", "", expr)

    expr = re.sub(r"\^ *\\circ", "", expr)

    if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
        expr = expr[1:-1]

    expr = _fix_sqrt(expr)

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
    expr = _fix_fracs(expr)

    # edge case with mixed numbers and negative signs
    expr = re.sub("- *", "-", expr)
    expr = _inject_implicit_mixed_number(expr)
    expr = _inject_implicit_mixed_fraction(expr)
    expr = expr.replace(" ", "")

    if _str_is_int(expr):
        expr = str(_str_to_int(expr))

    return expr


def is_digit(s):
    """Try to read ``s`` as a float.

    Commas are ignored; the LaTeX thousands separator ``{,}`` is removed first
    when present.

    Returns:
        ``(True, value)`` on success, ``(False, None)`` otherwise.
    """
    text = str(s)
    try:
        if "{,}" in text:
            return True, float(text.replace("{,}", ""))
        return True, float(text.replace(",", ""))
    except ValueError:
        return False, None


def normalize(answer) -> str:
    """Strip a leading ``$`` or a trailing ``%`` / ``\\%`` from simple numeric answers.

    Non-string input is returned unchanged.
    """
    # checking if answer is $ and removing $ in that case to compare
    if isinstance(answer, str) and bool(re.match(r'\$\d+(\.\d+)?', answer)):
        return answer[1:]

    # checking if answer is % or \\% and removing % (the "\\%" form is removed first)
    if isinstance(answer, str) and (
        bool(re.match(r'^\d+(\.\d+)?%$', answer)) or bool(re.match(r'^\d+(\.\d+)?\\%$', answer))
    ):
        return answer.replace("\\%", "").replace("%", "")

    return answer


def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    tolerance: float = 1e-4,
    timeout: float = 10.0,
    check_antlr_version: bool = True,
) -> bool:
    """Decide whether ``prediction`` and ``reference`` are mathematically equivalent.

    After normalization, the checks run in this order:
    0. string equality (case- and space-insensitive);
    1. numeric equality, when both parse as floats;
    2. element-wise comparison of intervals, tuples, comma lists, ``Point(...)``
       and ``Matrix(...)`` vs ``pmatrix``;
    3. symbolic equality via sympy.

    Args:
        prediction: Model answer. Non-string values are converted with ``str``.
        reference: Ground-truth answer. Non-string values are converted with ``str``.
        include_percentage: In the numeric check, also accept
            ``reference / 100`` and ``reference * 100``.
        tolerance: Relative tolerance for float comparisons.
        timeout: Seconds allowed for each guarded sympy step. Also passed to
            the recursive element-wise comparisons.
        check_antlr_version: Run ``_check_antlr_version()`` first.

    Returns:
        True if the two answers are judged equivalent.

    Raises:
        RuntimeError: From ``_check_antlr_version`` when enabled and the check fails.
    """
    # Check that the right antlr version is installed.
    if check_antlr_version:
        _check_antlr_version()

    # normalize_answer_string works on str (or None). Coerce ints/floats/sympy
    # objects instead of letting re.sub raise TypeError on them.
    if prediction is not None and not isinstance(prediction, str):
        prediction = str(prediction)
    if reference is not None and not isinstance(reference, str):
        reference = str(reference)

    prediction = normalize(prediction)
    reference = normalize(reference)

    # another round of normalization
    prediction = normalize_answer_string(prediction)
    reference = normalize_answer_string(reference)

    if isinstance(prediction, str) and len(prediction) > 1000:  # handling weird corner-cases
        prediction = prediction[:1000]

    # 0. string comparison
    if isinstance(prediction, str) and isinstance(reference, str):
        if prediction.strip().lower() == reference.strip().lower():
            return True
        if prediction.replace(" ", "") == reference.replace(" ", ""):
            return True

    # 1. numerical equal
    pred_is_num, pred_num = is_digit(prediction)
    ref_is_num, ref_num = is_digit(reference)
    if pred_is_num and ref_is_num:
        candidates = [ref_num / 100, ref_num, ref_num * 100] if include_percentage else [ref_num]
        for item in candidates:
            try:
                if isclose(item, pred_num, rel_tol=tolerance):
                    return True
            except (TypeError, ValueError):
                continue
        return False

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    def _parts_equal(pred_part, ref_part):
        # Recurse with the caller's settings, including the timeout budget.
        return math_equal(
            pred_part,
            ref_part,
            include_percentage,
            tolerance,
            timeout=timeout,
            check_antlr_version=check_antlr_version,
        )

    ## deal with [], (), {}
    prediction = format_intervals(prediction)

    pred_str, ref_str = prediction, reference
    if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (
        prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str == ref_str:
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (
        prediction
        and reference
        and prediction[0] in "(["
        and prediction[-1] in ")]"
        and prediction[0] == reference[0]
        and prediction[-1] == reference[-1]
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all(_parts_equal(p, r) for p, r in zip(pred_parts, ref_parts)):
                return True

    if "," in prediction and "," in reference:
        pred_parts = [item.strip() for item in prediction.split(",")]
        ref_parts = [item.strip() for item in reference.split(",")]

        if len(pred_parts) == len(ref_parts):
            if all(_parts_equal(p, r) for p, r in zip(pred_parts, ref_parts)):
                return True
        else:
            return False

    # if we have point == tuple of values
    # startswith/endswith instead of indexing: reference may be empty here.
    if prediction.startswith("Point") and reference.startswith("(") and reference.endswith(")"):
        pred_parts = prediction[prediction.find("(") + 1 : -1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all(_parts_equal(p, r) for p, r in zip(pred_parts, ref_parts)):
                return True

    # if reference is a matrix
    if reference.startswith("\\begin{pmatrix}") and prediction.startswith("Matrix"):
        from sympy.parsing.sympy_parser import parse_expr

        try:
            # NOTE(review): parse_expr uses eval on model-generated text; only
            # run this grader in a sandboxed environment.
            with time_limit(timeout):
                pred_matrix = parse_expr(prediction)
            ref_matrix_items = reference.split()[1:-1:2]
            if len(pred_matrix) == len(ref_matrix_items):
                if all(_parts_equal(pred, ref) for ref, pred in zip(ref_matrix_items, pred_matrix)):
                    return True
        except Exception:
            pass

    return symbolic_equal(prediction, reference, tolerance, timeout)


def symbolic_equal(a, b, tolerance, timeout=10.0):
    """Compare two expressions with sympy.

    Each side is parsed with ``parse_expr``, falling back to ``parse_latex``.
    Each attempt is limited to ``timeout`` seconds, and an unparseable side
    stays a raw string. The sides are equal if ``simplify(a - b) == 0`` or if
    their numeric values are within ``tolerance`` (relative). Errors and
    timeouts count as "not equal".
    """
    import sympy
    from sympy.parsing.latex import parse_latex
    from sympy.parsing.sympy_parser import parse_expr

    def _parse(s):
        # NOTE(review): parse_expr evaluates its input with eval; inputs here
        # include model-generated text, so sandbox the grader.
        for f in [parse_expr, parse_latex]:
            try:
                with time_limit(timeout):
                    return f(s)
            except Exception:
                pass
        return s

    a = _parse(a)
    b = _parse(b)

    try:
        with time_limit(timeout):
            if sympy.simplify(a - b) == 0:
                return True
    except Exception:
        pass

    try:
        with time_limit(timeout):
            if isclose(sympy.N(a), sympy.N(b), rel_tol=tolerance):
                return True
    except Exception:
        pass
    return False


def extract_answer(string: str, extract_from_boxed: bool = True, extract_regex: str = r"The final answer is (.+)$"):
    """Extract the answer string from the last ``\\boxed{...}`` (or ``\\fbox{...}``), or via regex.

    Args:
        string: Model output to search.
        extract_from_boxed: If False, return group 1 of ``extract_regex`` instead.
        extract_regex: Pattern used when ``extract_from_boxed`` is False.

    Returns:
        The brace-balanced content of the last ``\\boxed{...}``. If there is no
        ``\\boxed``, the last ``\\fbox{...}`` is used. Returns None when nothing
        usable is found (missing command, unbalanced braces, or no ``{``
        directly after the command).
    """
    if not extract_from_boxed:
        match = re.search(extract_regex, string)
        if match:
            return match.group(1)
        return None

    # Prefer the last \boxed; fall back to the last \fbox only if there is no \boxed.
    for command in ("\\boxed", "\\fbox"):
        idx = string.rfind(command)
        if idx >= 0:
            break
    else:
        return None

    # Find the brace that closes the first "{" at or after the command.
    right_brace_idx = None
    depth = 0
    for i in range(idx, len(string)):
        ch = string[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                right_brace_idx = i
                break

    if right_brace_idx is None:
        return None

    retval = string[idx : right_brace_idx + 1]
    left = command + "{"
    if retval.startswith(left) and retval.endswith("}"):
        return retval[len(left) : -1]
    return None


class TimeoutException(Exception):
    pass


@contextlib.contextmanager
def time_limit(seconds: float):
    """Raise ``TimeoutException`` if the ``with`` body runs longer than ``seconds``.

    Uses ``SIGALRM`` with ``signal.setitimer(ITIMER_REAL, ...)``, so it is
    Unix-only and must be entered from the main thread (``signal.signal``
    raises ``ValueError`` elsewhere). The previous ``SIGALRM`` handler is
    restored on exit.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install the handler before arming the timer. If signal.signal fails
    # (e.g. off the main thread), no armed timer is left behind with the
    # default SIGALRM action.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    if previous_handler is None:
        # Handler was not installed from Python; fall back to the default.
        previous_handler = signal.SIG_DFL
    try:
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, previous_handler)


def format_intervals(prediction):
    """Rewrite a sympy ``Interval`` repr as bracket notation.

    ``Interval(a, b)`` -> ``[a, b]``, ``Interval.Ropen(a, b)`` -> ``[a, b)``,
    ``Interval.Lopen(a, b)`` -> ``(a, b]``, ``Interval.open(a, b)`` -> ``(a, b)``.
    Any other input is returned unchanged.
    """
    patterns = (
        (r"^Interval\((.*)\)$", "[{}]"),
        (r"^Interval\.Ropen\((.*)\)$", "[{})"),
        (r"^Interval\.Lopen\((.*)\)$", "({}]"),
        (r"^Interval\.open\((.*)\)$", "({})"),
    )

    for pattern, template in patterns:
        match = re.match(pattern, prediction)
        if match:
            return template.format(match.group(1))

    return prediction


def process_results(
    response: Union[str, list[str]],
    answer: str,
    response_extract_from_boxed: bool = True,
    response_extract_regex: str = r"The final answer is (.+)$",
) -> bool:
    """Grade one response (or majority-vote a list of responses) against ``answer``.

    Raises:
        ValueError: If ``response`` is neither a string nor a list (or the list is empty).
    """
    if isinstance(response, str):
        candidate = extract_answer(response, response_extract_from_boxed, response_extract_regex)
    elif isinstance(response, list):
        candidate = most_common_element(
            [extract_answer(r, response_extract_from_boxed, response_extract_regex) for r in response]
        )
    else:
        raise ValueError(f"Invalid response type: {type(response)}")
    return math_equal(candidate, answer)


def reward_func(data_source, solution_str, ground_truth, extra_info=None) -> float:
    """Score a solution: -1.0 if no boxed answer, 1.0 if correct, -0.5 if wrong.

    ``data_source`` and ``extra_info`` are accepted for reward-function
    signature compatibility and are not used.
    """
    extracted_answer = extract_answer(solution_str, extract_from_boxed=True)
    if extracted_answer is None:  # formatting error
        return -1.0
    if math_equal(extracted_answer, ground_truth, check_antlr_version=False):
        return 1.0
    return -0.5