├── construct_dataset
│   ├── README.md
│   ├── construct_prompt.py
│   ├── train_eval_split.py
│   └── rm_dataset_construction_phase1.py
├── self-training
│   ├── envs.py
│   ├── tree.py
│   ├── llm_query.py
│   ├── construct_prompt.py
│   ├── choose_best_reward_model.py
│   ├── assign_node_rm.py
│   ├── choose_best_policy_model.py
│   ├── grading
│   │   ├── math_normalize.py
│   │   └── grader.py
│   ├── sft_rm.py
│   ├── assign_node_without_rm.py
│   ├── generate_mcts.py
│   ├── construct_prm_train_data.py
│   ├── best_of_n_gsm8k.py
│   ├── construct_hrm_train_data.py
│   ├── best_of_n_math500.py
│   └── sft_policy_model.py
├── accelerate_config
│   ├── 1gpu.yaml
│   ├── 2gpus.yaml
│   ├── 3gpus.yaml
│   ├── 4gpus.yaml
│   ├── 6gpus.yaml
│   └── 8gpus.yaml
├── dataset
│   └── README.md
├── LICENSE.txt
├── deepspeed_config
│   ├── RM_config_stage2.json
│   ├── rm_config_stage2_py.json
│   ├── RM_config.json
│   └── policy_model_72b.json
├── README.md
└── sft_rw_manual_annotation
    ├── llm_query.py
    ├── grading
    │   ├── math_normalize.py
    │   └── grader.py
    ├── choose_best_model.py
    ├── sft_rw.py
    └── best_of_n.py

--------------------------------------------------------------------------------
/construct_dataset/README.md:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/self-training/envs.py:
--------------------------------------------------------------------------------
1 | MAX_HEIGHT = 6
2 | NUMBER_OF_CHILDREN = 5
--------------------------------------------------------------------------------
/accelerate_config/1gpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: true
3 | distributed_type: DEEPSPEED
4 | enable_cpu_affinity: false
5 | machine_rank: 0
6 | main_training_function: main
7 | num_machines: 1
8 | num_processes: 1
9 | rdzv_backend: static
10 | same_network: true
11 | tpu_env: []
12 | tpu_use_cluster: false
13 | tpu_use_sudo: false
14 | use_cpu: false
15 | 
--------------------------------------------------------------------------------
/accelerate_config/2gpus.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: true
3 | distributed_type: DEEPSPEED
4 | enable_cpu_affinity: false
5 | machine_rank: 0
6 | main_training_function: main
7 | num_machines: 1
8 | num_processes: 2
9 | rdzv_backend: static
10 | same_network: true
11 | tpu_env: []
12 | tpu_use_cluster: false
13 | tpu_use_sudo: false
14 | use_cpu: false
15 | 
--------------------------------------------------------------------------------
/accelerate_config/3gpus.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: true
3 | distributed_type: DEEPSPEED
4 | enable_cpu_affinity: false
5 | machine_rank: 0
6 | main_training_function: main
7 | num_machines: 1
8 | num_processes: 3
9 | rdzv_backend: static
10 | same_network: true
11 | tpu_env: []
12 | tpu_use_cluster: false
13 | tpu_use_sudo: false
14 | use_cpu: false
15 | 
--------------------------------------------------------------------------------
/accelerate_config/4gpus.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: true
3 | distributed_type: DEEPSPEED
4 | enable_cpu_affinity: false
5 | machine_rank: 0
6 | main_training_function: main
7 | num_machines: 1
8 | num_processes: 4
9 | rdzv_backend: static
10 | same_network: true
11 | tpu_env: []
12 | tpu_use_cluster: false
13 | tpu_use_sudo: false
14 | use_cpu: false
15 | 
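Note: the `accelerate_config` YAMLs (1–8 GPUs) are identical except for `num_processes`; they only select DeepSpeed as the distributed backend, while the concrete DeepSpeed JSON is chosen inside each training script via `TrainingArguments(deepspeed=...)`. A typical launch (the 8-GPU form appears verbatim in the comments at the end of `self-training/sft_rm.py`; the 4-GPU variant shown here is an assumed example) is `accelerate launch --config_file accelerate_config/4gpus.yaml --gpu_ids 0,1,2,3 self-training/sft_rm.py --task prm --idx 1`.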
--------------------------------------------------------------------------------
/accelerate_config/6gpus.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: true
3 | distributed_type: DEEPSPEED
4 | enable_cpu_affinity: false
5 | machine_rank: 0
6 | main_training_function: main
7 | num_machines: 1
8 | num_processes: 6
9 | rdzv_backend: static
10 | same_network: true
11 | tpu_env: []
12 | tpu_use_cluster: false
13 | tpu_use_sudo: false
14 | use_cpu: false
15 | 
--------------------------------------------------------------------------------
/accelerate_config/8gpus.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: true
3 | distributed_type: DEEPSPEED
4 | enable_cpu_affinity: false
5 | machine_rank: 0
6 | main_training_function: main
7 | num_machines: 1
8 | num_processes: 8
9 | rdzv_backend: static
10 | same_network: true
11 | tpu_env: []
12 | tpu_use_cluster: false
13 | tpu_use_sudo: false
14 | use_cpu: false
15 | 
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
1 | Download the PRM800K dataset, create a folder named `prm_dataset` under `dataset/`, and place the four PRM800K `.jsonl` files into that folder.
2 | 
3 | In this project, we only use the phase 1 part, which contains the manual annotations.
4 | 
5 | The `construct_dataset` folder is used to construct the training datasets from this manual-annotation data; the process creates a new folder named `phase1` that contains the ORM, PRM, and HRM training data.
6 | 
7 | In the auto-annotation process, we only use the questions and ground-truth answers from PRM800K; the self-training module generates the reasoning process and labels it autonomously.
8 | 
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
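For reference, the layout that `dataset/README.md` above describes, and that scripts such as `construct_dataset/train_eval_split.py` (which reads `dataset/prm_dataset/phase1_train.jsonl`) expect, is sketched below. Only `phase1_train.jsonl` is referenced directly in the code shown in this dump; the other file names are assumptions based on the public PRM800K release.

dataset/
├── README.md
└── prm_dataset/
    ├── phase1_train.jsonl
    ├── phase1_test.jsonl
    ├── phase2_train.jsonl
    └── phase2_test.jsonl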
-------------------------------------------------------------------------------- /deepspeed_config/RM_config_stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "initial_scale_power": 12, 5 | "loss_scale_window": 1000, 6 | "hysteresis": 2, 7 | "min_loss_scale": 1 8 | }, 9 | "optimizer": { 10 | "type": "AdamW", 11 | "params": { 12 | "lr": "auto", 13 | "betas": "auto", 14 | "eps": "auto", 15 | "weight_decay": "auto" 16 | } 17 | }, 18 | "scheduler": { 19 | "type": "WarmupCosineLR", 20 | "params": { 21 | "warmup_num_steps": "auto", 22 | "total_num_steps": "auto", 23 | "warmup_type": "linear" 24 | } 25 | }, 26 | "zero_optimization": { 27 | "stage": 2, 28 | "offload_optimizer": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "allgather_partitions": true, 33 | "allgather_bucket_size": 2e8, 34 | "overlap_comm": true, 35 | "reduce_scatter": true, 36 | "reduce_bucket_size": 2e8, 37 | "contiguous_gradients": true, 38 | "round_robin_gradients": true 39 | }, 40 | "gradient_accumulation_steps": "auto", 41 | "gradient_clipping": "auto", 42 | "train_batch_size": "auto" 43 | } -------------------------------------------------------------------------------- /deepspeed_config/rm_config_stage2_py.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "initial_scale_power": 12, 5 | "loss_scale_window": 1000, 6 | "hysteresis": 2, 7 | "min_loss_scale": 1 8 | }, 9 | "optimizer": { 10 | "type": "AdamW", 11 | "params": { 12 | "lr": "auto", 13 | "betas": "auto", 14 | "eps": "auto", 15 | "weight_decay": "auto" 16 | } 17 | }, 18 | "scheduler": { 19 | "type": "WarmupCosineLR", 20 | "params": { 21 | "warmup_num_steps": "auto", 22 | "total_num_steps": "auto", 23 | "warmup_type": "linear" 24 | } 25 | }, 26 | "zero_optimization": { 27 | "stage": 3, 28 | "offload_optimizer": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "allgather_partitions": true, 33 | "allgather_bucket_size": 2e8, 34 | "overlap_comm": true, 35 | "reduce_scatter": true, 36 | "reduce_bucket_size": 2e8, 37 | "contiguous_gradients": true, 38 | "round_robin_gradients": true 39 | }, 40 | "gradient_accumulation_steps": "auto", 41 | "gradient_clipping": "auto", 42 | "train_batch_size": "auto" 43 | } 44 | -------------------------------------------------------------------------------- /self-training/tree.py: -------------------------------------------------------------------------------- 1 | """ 2 | The data structure for MCTS 3 | """ 4 | from grading.grader import grade_answer 5 | from envs import * 6 | 7 | 8 | class Node: 9 | def __init__(self, parent=None, height=0, question="", number_of_children=NUMBER_OF_CHILDREN, score=-1): 10 | self.previous_answer = [] 11 | self.parent = parent 12 | self.number_of_children = number_of_children 13 | self.children = [] 14 | self.height = height 15 | self.question = question 16 | self.contain_answer = False 17 | self.answer = None 18 | self.score = score 19 | self.is_correct = False 20 | self.should_stop = False 21 | self.is_leaf_node = False 22 | 23 | def add_children(self, children_list): 24 | self.children = children_list 25 | if len(self.children) != self.number_of_children: 26 | print(f"question: {self.question} children doesn't equal number of children: {self.number_of_children}") 27 | 28 | def have_the_answer(self, answer, ground_truth): 29 | self.answer = answer 30 | self.contain_answer = True 31 | self.is_correct = 
grade_answer(answer, ground_truth) 32 | 33 | def set_previous_answer(self, previous_answer): 34 | self.previous_answer = previous_answer[:] 35 | assert len(self.previous_answer) == self.height - 1 36 | -------------------------------------------------------------------------------- /deepspeed_config/RM_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "optimizer": { 11 | "type": "AdamW", 12 | "params": { 13 | "lr": "auto", 14 | "betas": "auto", 15 | "eps": "auto", 16 | "weight_decay": "auto" 17 | } 18 | }, 19 | "scheduler": { 20 | "type": "WarmupLR", 21 | "params": { 22 | "warmup_min_lr": "auto", 23 | "warmup_max_lr": "auto", 24 | "warmup_num_steps": 20 25 | } 26 | }, 27 | "zero_optimization": { 28 | "stage": 3, 29 | "offload_optimizer": { 30 | "device": "cpu", 31 | "pin_memory": true 32 | }, 33 | "offload_param": { 34 | "device": "cpu", 35 | "pin_memory": true 36 | }, 37 | "overlap_comm": true, 38 | "contiguous_gradients": true, 39 | "sub_group_size": 1e9, 40 | "reduce_bucket_size": "auto", 41 | "stage3_prefetch_bucket_size": "auto", 42 | "stage3_param_persistence_threshold": "auto", 43 | "stage3_max_live_parameters": 1e9, 44 | "stage3_max_reuse_distance": 1e9, 45 | "stage3_gather_16bit_weights_on_model_save": true 46 | }, 47 | 48 | "gradient_accumulation_steps": "auto", 49 | "gradient_clipping": "auto", 50 | "train_batch_size": "auto" 51 | } -------------------------------------------------------------------------------- /deepspeed_config/policy_model_72b.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": true, 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "optimizer": { 11 | "type": "AdamW", 12 | "params": { 13 | "lr": "auto", 14 | "weight_decay": "auto" 15 | } 16 | }, 17 | "scheduler": { 18 | "type": "WarmupDecayLR", 19 | "params": { 20 | "warmup_min_lr": "auto", 21 | "warmup_max_lr": "auto", 22 | "warmup_num_steps": "auto", 23 | "total_num_steps": "auto" 24 | } 25 | }, 26 | "zero_optimization": { 27 | "stage": 3, 28 | "offload_optimizer": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "offload_param": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "overlap_comm": true, 37 | "contiguous_gradients": true, 38 | "reduce_bucket_size": "auto", 39 | "stage3_prefetch_bucket_size": "auto", 40 | "stage3_param_persistence_threshold": "auto", 41 | "sub_group_size": 1e6, 42 | "stage3_max_live_parameters": 5e7, 43 | "stage3_max_reuse_distance": 5e7, 44 | "stage3_gather_16bit_weights_on_model_save": true 45 | }, 46 | "gradient_accumulation_steps": "auto", 47 | "gradient_clipping": "auto", 48 | "steps_per_print": 2000, 49 | "train_batch_size": "auto", 50 | "train_micro_batch_size_per_gpu": "auto", 51 | "wall_clock_breakdown": false 52 | } -------------------------------------------------------------------------------- /self-training/llm_query.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | from envs import * 3 | 4 | 5 | def parallel_query(host, port, model_name, prompt, api_key="", repetition_penalty=1.0, n=NUMBER_OF_CHILDREN): 6 | client = OpenAI(base_url=f"http://{host}:{port}/v1", 7 | api_key=api_key, ) 8 | 9 | 
    completion = client.completions.create(
10 |         model=model_name,
11 |         prompt=prompt,
12 |         temperature=0.1,
13 |         stop=['# END!', '# Step 2', "# Step 3", "# Step 4", "# Step 5"],
14 |         max_tokens=1024,
15 |         extra_body={"include_stop_str_in_output": True, "repetition_penalty": repetition_penalty},
16 |         n=n
17 |     )
18 |     return [choice.text for choice in completion.choices]
19 | 
20 | 
21 | def sequential_query(host, port, model_name, prompt, api_key="", repetition_penalty=1.0):
22 |     client = OpenAI(base_url=f"http://{host}:{port}/v1",
23 |                     api_key=api_key, )
24 | 
25 |     completion = client.completions.create(
26 |         model=model_name,
27 |         prompt=prompt,
28 |         temperature=0.1,
29 |         stop=['# END!', '# Step 2', "# Step 3", "# Step 4", "# Step 5"],
30 |         max_tokens=1024,
31 |         extra_body={"include_stop_str_in_output": True, "repetition_penalty": repetition_penalty},
32 | 
33 |     )
34 |     return completion.choices[0].text
35 | 
36 | 
37 | def one_step_query(host, port, model_name, prompt, api_key="", repetition_penalty=1.0):
38 |     client = OpenAI(base_url=f"http://{host}:{port}/v1",
39 |                     api_key=api_key, )
40 | 
41 |     completion = client.completions.create(
42 |         model=model_name,
43 |         prompt=prompt,
44 |         temperature=0.1,
45 |         stop=['# END!'],
46 |         max_tokens=3600,
47 |         extra_body={"include_stop_str_in_output": True, "repetition_penalty": repetition_penalty},
48 | 
49 |     )
50 |     return completion.choices[0].text
51 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is the official code for the paper [**Towards Hierarchical Multi-Step Reward Models for Enhanced Reasoning in Large Language Models**](https://arxiv.org/abs/2503.13551).
2 | 
3 | This project has two parts that illustrate the paper:
4 | 
5 | ### 1. Training with Manual Annotations and Best-of-N Evaluation
6 | 
7 | - Manual annotations are utilized to train three reward models:
8 |   - **ORM** (Output Reward Model)
9 |   - **PRM** (Process Reward Model)
10 |   - **HRM** (Hierarchical Reward Model)
11 | - The policy model is evaluated using the **best-of-N** strategy based on these reward models.
12 | - Relevant code:
13 |   - The `construct_dataset` folder is used to generate training data for ORM, PRM, and HRM.
14 |   - The `sft_rw_manual_annotation` folder contains code for fine-tuning the reward models and evaluating the policy model.
15 | 
16 | ### 2. Self-Supervised Training with MCTS and HNC
17 | 
18 | - **MCTS** (Monte Carlo Tree Search) and **HNC** (Hierarchical Node Compression, as described in the paper) are implemented to automatically generate and label PRM and HRM training data.
19 | - The training data automatically generated from **PRM800K** questions is used to train the PRM and HRM models.
20 | - Evaluation is conducted on the following datasets:
21 |   - **PRM800K**
22 |   - **MATH500** (Cross-domain)
23 |   - **GSM8K** (Cross-domain)
24 | - A policy model trained with SFT is included, which incorporates **KL divergence** from a reference model to enhance reasoning ability.
25 | - Relevant code is located in the `self-training` folder.
26 | 
27 | If you find our work useful, please consider citing it in your research. 
28 | 29 | ``` 30 | @article{wang2025towards, 31 | title={Towards Hierarchical Multi-Step Reward Models for Enhanced Reasoning in Large Language Models}, 32 | author={Wang, Teng and Jiang, Zhangyi and He, Zhenqi and Yang, Wenhan and Zheng, Yanan and Li, Zeyu and He, Zifan and Tong, Shenyang and Gong, Hailei}, 33 | journal={arXiv preprint arXiv:2503.13551}, 34 | year={2025} 35 | } 36 | ``` 37 | 38 | -------------------------------------------------------------------------------- /sft_rw_manual_annotation/llm_query.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | 4 | 5 | def orm_query(host, port, model_name, prompt, api_key="", repetition_penalty=1.0): 6 | client = OpenAI(base_url=f"http://{host}:{port}/v1", 7 | api_key=api_key) 8 | 9 | completion = client.completions.create( 10 | model=model_name, 11 | prompt=prompt, 12 | temperature=0.1, 13 | max_tokens=2048, 14 | stop=['# END!'], 15 | 16 | extra_body={"repetition_penalty": repetition_penalty, "include_stop_str_in_output": False} 17 | ) 18 | 19 | return completion.choices[0].text 20 | 21 | 22 | def orm_parallel_query(host, port, model_name, prompt, api_key="", repetition_penalty=1.0, n=1): 23 | client = OpenAI(base_url=f"http://{host}:{port}/v1", 24 | api_key=api_key) 25 | 26 | completion = client.completions.create( 27 | model=model_name, 28 | prompt=prompt, 29 | temperature=0.1, 30 | max_tokens=2048, 31 | stop=['# END!'], 32 | 33 | extra_body={"repetition_penalty": repetition_penalty, "include_stop_str_in_output": False}, 34 | n=n 35 | 36 | ) 37 | 38 | return [choice.text for choice in completion.choices] 39 | 40 | 41 | def prm_query(host, port, model_name, prompt, api_key="", repetition_penalty=1.0): 42 | client = OpenAI(base_url=f"http://{host}:{port}/v1", 43 | api_key=api_key, ) 44 | 45 | completion = client.completions.create( 46 | model=model_name, 47 | prompt=prompt, 48 | temperature=0.1, 49 | stop=['\]\n\n', '\)\n\n', '# END!'], 50 | max_tokens=1024, 51 | extra_body={"include_stop_str_in_output": True, "repetition_penalty": repetition_penalty} 52 | 53 | ) 54 | 55 | return completion.choices[0].text 56 | 57 | 58 | def prm_parallel_query(host, port, model_name, prompt, api_key="", repetition_penalty=1.0, n=1): 59 | client = OpenAI(base_url=f"http://{host}:{port}/v1", 60 | api_key=api_key, ) 61 | 62 | completion = client.completions.create( 63 | model=model_name, 64 | prompt=prompt, 65 | temperature=0.1, 66 | stop=['\]\n\n', '\)\n\n', '# END!'], 67 | max_tokens=1024, 68 | extra_body={"include_stop_str_in_output": True, "repetition_penalty": repetition_penalty}, 69 | n=n 70 | 71 | ) 72 | return [choice.text for choice in completion.choices] 73 | 74 | -------------------------------------------------------------------------------- /self-training/construct_prompt.py: -------------------------------------------------------------------------------- 1 | def construct_policy_model_prompt_for_PRM_HRM(question: str, previous_reasoning: str = None) -> str: 2 | if not previous_reasoning: 3 | template = f"""You are an expert of Math and need to solve the following question and return the answer. 4 | 5 | Question: 6 | {question} 7 | 8 | 9 | Let's analyze this step by step. 10 | 11 | Begin each step with '# Step X' to clearly indicate the entire reasoning step. 12 | After you finish thinking, you need to output the answer again! 13 | The answer should start with '# Answer', followed by two line breaks and the final response. 
14 | Just provide the answer value without any descriptive text at the end. 15 | And the answer ends with '# END!' 16 | Below is a correct example of the expected output format: 17 | ----------------- 18 | Question: 1+2+3 = ? 19 | 20 | # Step 1 21 | solve 1 + 2 = 3, 22 | 23 | # Step 2 24 | Then, 3 + 3 = 6. 25 | 26 | # Answer 27 | 28 | 6 29 | # END! 30 | ----------------- 31 | """ 32 | else: 33 | template = f"""You are an expert of Math and need to solve the following question and return the answer. 34 | 35 | Question: 36 | {question} 37 | 38 | 39 | Let's analyze this step by step. 40 | 41 | Begin each step with '# Step X' to clearly indicate the entire reasoning step. 42 | After you finish thinking, you need to output the answer again! 43 | The answer should start with '# Answer', followed by two line breaks and the final response. 44 | Just provide the answer value without any descriptive text at the end. 45 | And the answer ends with '# END!' 46 | Below is a correct example of the expected output format: 47 | ----------------- 48 | Question: 1+2+3 = ? 49 | 50 | # Step 1 51 | solve 1 + 2 = 3, 52 | 53 | # Step 2 54 | Then, 3 + 3 = 6. 55 | 56 | # Answer 57 | 58 | 6 59 | # END! 60 | ----------------- 61 | {previous_reasoning} 62 | """ 63 | return template 64 | 65 | def construct_PRM_HRM_prompt_v2(question, previous_steps, current_step): 66 | if previous_steps: 67 | template = f"""Question: 68 | {question} 69 | 70 | Let's break it down step by step! 71 | 72 | Previous reasoning: 73 | {previous_steps} 74 | 75 | Now, let's focus on the current step: 76 | {current_step}""" 77 | 78 | 79 | else: 80 | template = f"""Question: 81 | {question} 82 | 83 | Let's break it down step by step! 84 | 85 | Now, let's focus on the current step: 86 | {current_step} 87 | """ 88 | # print("----\nPRM_REWARD_MODEL_PROMPT:") 89 | # print(template) 90 | # print("-----finish prm prompt") 91 | return template 92 | -------------------------------------------------------------------------------- /construct_dataset/construct_prompt.py: -------------------------------------------------------------------------------- 1 | def construct_policy_model_prompt_for_ORM(question): 2 | # This is the policy model's prompt for ORM task 3 | template = f"""You are an expert of Math and need to solve the following question and return the answer. 4 | 5 | Question: 6 | {question} 7 | 8 | Let's analyze this step by step. 9 | 10 | After you finish thinking, you need to output the answer again! 11 | The answer should start with '# Answer', followed by two line breaks and the final response. 12 | Just provide the answer value without any descriptive text at the end. 13 | And the answer ends with '# END!' 14 | Below is a correct example of the expected output format: 15 | ----------------- 16 | Question: 1+2+3 = ? 17 | 18 | Firstly, solve 1 + 2 = 3, 19 | Then, 3 + 3 = 6. 20 | 21 | # Answer 22 | 23 | 6 24 | # END! 25 | ----------------- 26 | """ 27 | 28 | return template 29 | 30 | 31 | def construct_policy_model_prompt_for_PRM_HRM(question: str, previous_reasoning: str = None) -> str: 32 | # This is the policy model's prompt for PRM and HRM task 33 | if not previous_reasoning: 34 | template = f"""You are an expert of Math and need to solve the following question and return the answer. 35 | 36 | Question: 37 | {question} 38 | 39 | 40 | Let's analyze this step by step. 41 | 42 | After you finish thinking, you need to output the answer again! 43 | The answer should start with '# Answer', followed by two line breaks and the final response. 
44 | Just provide the answer value without any descriptive text at the end. 45 | And the answer ends with '# END!' 46 | Below is a correct example of the expected output format: 47 | ----------------- 48 | Question: 1+2+3 = ? 49 | 50 | Firstly, solve 1 + 2 = 3, 51 | Then, 3 + 3 = 6. 52 | 53 | # Answer 54 | 55 | 6 56 | # END! 57 | ----------------- 58 | """ 59 | else: 60 | template = f"""You are an expert of Math and need to solve the following question and return the answer. 61 | 62 | Question: 63 | {question} 64 | 65 | 66 | Let's analyze this step by step. 67 | 68 | After you finish thinking, you need to output the answer again! 69 | The answer should start with '# Answer', followed by two line breaks and the final response. 70 | Just provide the answer value without any descriptive text at the end. 71 | And the answer ends with '# END!' 72 | Below is a correct example of the expected output format: 73 | ----------------- 74 | Question: 1+2+3 = ? 75 | 76 | Firstly, solve 1 + 2 = 3, 77 | Then, 3 + 3 = 6. 78 | 79 | # Answer 80 | 81 | 6 82 | # END! 83 | ----------------- 84 | {previous_reasoning} 85 | """ 86 | return template 87 | 88 | 89 | def construct_ORM_prompt(question, answer): 90 | # ORM prompt (RM) 91 | template = f"""Question is as follows: 92 | {question} 93 | 94 | The answer is as follows: 95 | {answer} 96 | """ 97 | return template 98 | 99 | 100 | def construct_PRM_HRM_prompt_v2(question, previous_steps, current_step): 101 | # HRM and PRM prompt (RM) when calculating the score 102 | if previous_steps: 103 | template = f"""Question: 104 | {question} 105 | 106 | Let's break it down step by step! 107 | 108 | Previous reasoning: 109 | {previous_steps} 110 | 111 | Now, let's focus on the current step: 112 | {current_step}""" 113 | 114 | 115 | else: 116 | template = f"""Question: 117 | {question} 118 | 119 | Let's break it down step by step! 120 | 121 | Now, let's focus on the current step: 122 | {current_step} 123 | """ 124 | return template 125 | 126 | 127 | def construct_PRM_HRM_prompt(question, answer): 128 | # PRM and HRM prompt when construct training dataset for PRM800K dataset 129 | placeholder = " /qwerdf12344567" 130 | len_placeholder = len(placeholder) 131 | answer = answer[len_placeholder:] 132 | answer_slices = answer.split(placeholder) 133 | if len(answer_slices) == 1: 134 | current_step = answer_slices[0] 135 | previous_steps = None 136 | else: 137 | current_step = answer_slices[-1] 138 | previous_steps = "\n\n".join(answer_slices[:-1]) 139 | if previous_steps: 140 | template = f"""Question: 141 | {question} 142 | 143 | Let's break it down step by step! 144 | 145 | Previous reasoning: 146 | {previous_steps} 147 | 148 | Now, let's focus on the current step: 149 | {current_step}""" 150 | 151 | 152 | else: 153 | template = f"""Question: 154 | {question} 155 | 156 | Let's break it down step by step! 
157 | 158 | Now, let's focus on the current step: 159 | {current_step} 160 | """ 161 | return template 162 | -------------------------------------------------------------------------------- /self-training/choose_best_reward_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments 4 | from datasets import load_dataset 5 | import torch 6 | 7 | 8 | def load_data_for_task(task: str = 'orm', version="v1"): 9 | if task in ["orm", "prm", "hrm"]: 10 | eval_data_path = f"dataset/self_training_{version}_scoring/{task}_test.jsonl" 11 | else: 12 | raise ValueError("The task should be either 'orm', 'prm', or 'hrm'.") 13 | 14 | print(f"eval data path: {eval_data_path}") 15 | 16 | eval_dataset = load_dataset('json', data_files=eval_data_path) 17 | eval_dataset = eval_dataset['train'] 18 | 19 | return eval_dataset 20 | 21 | 22 | def load_model(model_path): 23 | model = AutoModelForSequenceClassification.from_pretrained(model_path, 24 | num_labels=1, 25 | attn_implementation="flash_attention_2", 26 | torch_dtype=torch.float16, 27 | ) 28 | tokenizer = AutoTokenizer.from_pretrained(model_path) 29 | 30 | model.config.pad_token_id = model.config.eos_token_id 31 | return model, tokenizer 32 | 33 | 34 | def preprocess_function(examples, tokenizer): 35 | inputs = examples["input"] 36 | labels = examples["label"] 37 | model_inputs = tokenizer(inputs, truncation=True, padding="max_length", max_length=4096) 38 | model_inputs["labels"] = labels 39 | return model_inputs 40 | 41 | 42 | def load_all_model_paths(base_path): 43 | paths = [] 44 | for folder_name in os.listdir(base_path): 45 | temp_folder_path = os.path.join(base_path, folder_name) 46 | for model_filename in os.listdir(temp_folder_path): 47 | current_path = os.path.join(temp_folder_path, model_filename) 48 | paths.append(current_path) 49 | return paths 50 | 51 | 52 | def main(task, version): 53 | _, tokenizer = load_model("models/Qwen2.5-1.5B-math") 54 | 55 | save_best_path = f"sf_best_model/{task}/{version}" 56 | os.makedirs(save_best_path, exist_ok=True) 57 | 58 | # Load data 59 | eval_dataset = load_data_for_task(task) 60 | 61 | # Load all models 62 | base_path = f"outputs/self_training_{version}/{task}/Qwen2.5-1.5B-math" 63 | 64 | all_model_paths = load_all_model_paths(base_path) 65 | print(f"all_model_paths: {all_model_paths}") 66 | 67 | # best_f1 = -1 68 | best_eval_loss = 1000 69 | best_model_path = None 70 | 71 | # Tokenize data 72 | preprocess_with_tokenizer = lambda examples: preprocess_function(examples, tokenizer) 73 | tokenized_eval_dataset = eval_dataset.map(preprocess_with_tokenizer, batched=True) 74 | 75 | for model_path in all_model_paths: 76 | print(f"Evaluating model: {model_path}") 77 | 78 | model, tokenizer = load_model(model_path) 79 | 80 | temp_model_path = model_path.replace(task, "temp" + task) 81 | training_args = TrainingArguments( 82 | output_dir=f"{temp_model_path}", 83 | per_device_eval_batch_size=32, 84 | per_device_train_batch_size=1, 85 | logging_dir=f"./logs/best/{task}", 86 | deepspeed="deepspeed_config/rm_config_stage2_py.json", 87 | fp16=True, 88 | ) 89 | 90 | trainer = Trainer( 91 | model=model, 92 | args=training_args, 93 | eval_dataset=tokenized_eval_dataset, 94 | tokenizer=tokenizer, 95 | ) 96 | 97 | metrics = trainer.evaluate() 98 | eval_loss = metrics["eval_loss"] 99 | 100 | print(f"Model: {model_path}, Loss: {eval_loss}") 101 | 102 | if 
eval_loss < best_eval_loss: 103 | best_eval_loss = eval_loss 104 | best_model_path = model_path 105 | 106 | if best_model_path: 107 | print(f"Best model found: {best_model_path} with F1: {best_eval_loss}") 108 | # model = AutoModelForSequenceClassification.from_pretrained(best_model_path) 109 | # tokenizer = AutoTokenizer.from_pretrained(best_model_path) 110 | # model.save_pretrained(save_best_path) 111 | # tokenizer.save_pretrained(save_best_path) 112 | else: 113 | print("No models evaluated.") 114 | 115 | 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser() 118 | parser.add_argument("--task", type=str, default="orm") 119 | parser.add_argument("--version", type=str, default="v1") 120 | args = parser.parse_args() 121 | 122 | task = args.task 123 | version = args.version 124 | main(task, version) 125 | -------------------------------------------------------------------------------- /self-training/assign_node_rm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Goal: assign every node in MCTS a score by using RM(PRM or HRM), but the score might be affected by reward hacking. 3 | So we don't use it and instead use the MC-score. 4 | """ 5 | import json 6 | 7 | import tqdm 8 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 9 | import os 10 | from construct_prm_train_data import load_all_pickle_file_paths 11 | from construct_prompt import construct_PRM_HRM_prompt_v2, construct_policy_model_prompt_for_PRM_HRM 12 | import torch 13 | import argparse 14 | import pickle 15 | 16 | 17 | def load_rm(path): 18 | model = AutoModelForSequenceClassification.from_pretrained(path, 19 | num_labels=1, 20 | attn_implementation="flash_attention_2", 21 | torch_dtype=torch.float16, 22 | ) 23 | tokenizer = AutoTokenizer.from_pretrained(path) 24 | model.config.pad_token_id = model.config.eos_token_id 25 | device = torch.device("cuda") 26 | model.to(device) 27 | return model, tokenizer 28 | 29 | 30 | def traverse_tree(path): 31 | prompts = [] 32 | with open(path, 'rb') as f: 33 | data = pickle.load(f) 34 | 35 | queue = [data] 36 | while queue: 37 | temp_length = len(queue) 38 | for _ in range(temp_length): 39 | node = queue.pop(0) 40 | if queue: 41 | rm_prompt = construct_PRM_HRM_prompt_v2(node.question, "".join(node.previous_answer[:-1]), 42 | node.previous_answer[-1]) 43 | policy_model_prompt = construct_policy_model_prompt_for_PRM_HRM(node.question, 44 | "".join(node.previous_answer)) 45 | prompts.append( 46 | json.dumps({"policy_model_prompt": policy_model_prompt, "reward_model_prompt": rm_prompt})) 47 | 48 | for child in node.children: 49 | queue.append(child) 50 | return set(prompts) 51 | 52 | 53 | def calculate_PRM_HRM_scores(model, tokenizer, prompts, N_batch=16): 54 | device = "cuda" 55 | current_step_score_pairs = [] 56 | 57 | prompts = [json.loads(prompt) for prompt in prompts] 58 | N = len(prompts) 59 | for i in range(0, N, N_batch): 60 | temp_prompts = prompts[i:i + N_batch] 61 | 62 | rw_prompts = [item['reward_model_prompt'] for item in temp_prompts] 63 | 64 | inputs = tokenizer(rw_prompts, return_tensors="pt", max_length=4096, padding=True, truncation=True) 65 | inputs = {key: value.to(device) for key, value in inputs.items()} 66 | 67 | with torch.no_grad(): 68 | outputs = model(**inputs) 69 | logits = outputs.logits 70 | 71 | positive_scores = logits[:, 0].tolist() 72 | 73 | batch_step_score_pairs = [{"score": score, "prompt": prompt['policy_model_prompt']} for score, prompt in 74 | zip(positive_scores, 
temp_prompts)] 75 | current_step_score_pairs.extend(batch_step_score_pairs) 76 | 77 | current_step_score_pairs.sort(reverse=True, key=lambda x: x['score']) 78 | return current_step_score_pairs 79 | 80 | 81 | if __name__ == '__main__': 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--task', type=str, default="prm") 84 | parser.add_argument('--version', type=str, default="v1") 85 | parser.add_argument('--pkl_path', type=str, default="dataset/self_training") 86 | parser.add_argument('--N_batch', type=int, default=16) 87 | 88 | args = parser.parse_args() 89 | task = args.task 90 | version = args.version 91 | pkl_path = args.pkl_path 92 | N_batch = args.N_batch 93 | 94 | final_pairs = [] 95 | 96 | rm_path = f"sf_best_model/{task}/{version}" 97 | data_path = f"dataset/self_training_{version}_policy/" 98 | os.makedirs(data_path, exist_ok=True) 99 | 100 | model, tokenizer = load_rm(rm_path) 101 | 102 | pkl_paths = load_all_pickle_file_paths(pkl_path) 103 | 104 | for pkl_path in tqdm.tqdm(pkl_paths, total=len(pkl_paths)): 105 | prompts = traverse_tree(pkl_path) 106 | current_pair = calculate_PRM_HRM_scores(model, tokenizer, prompts, N_batch) 107 | final_pairs.extend(current_pair) 108 | with open(os.path.join(data_path, f"{task}.jsonl"), "w") as f: 109 | for pair in final_pairs: 110 | f.write(json.dumps(pair) + "\n") 111 | # CUDA_VISIBLE_DEVICES=0 nohup python self-training/assign_node_rm.py > logs/prm_v1_.log & 112 | # CUDA_VISIBLE_DEVICES=0 nohup python self-training/assign_node_rm.py --task hrm --version v1 > logs/hrm_v1_.log & 113 | -------------------------------------------------------------------------------- /self-training/choose_best_policy_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments 4 | from datasets import load_dataset 5 | import torch 6 | 7 | 8 | def load_data_for_task(task: str = 'orm', version="v1"): 9 | if task in ["orm", "prm", "hrm"]: 10 | eval_data_path = f"dataset/self_training_{version}_policy/can_use_{task}_eval.jsonl" 11 | else: 12 | raise ValueError("The task should be either 'orm', 'prm', or 'hrm'.") 13 | 14 | print(f"eval data path: {eval_data_path}") 15 | 16 | eval_dataset = load_dataset('json', data_files=eval_data_path) 17 | eval_dataset = eval_dataset['train'] 18 | 19 | return eval_dataset 20 | 21 | 22 | def load_model(model_path): 23 | model = AutoModelForCausalLM.from_pretrained(model_path, 24 | attn_implementation="flash_attention_2", 25 | torch_dtype=torch.float16, 26 | ) 27 | tokenizer = AutoTokenizer.from_pretrained(model_path) 28 | 29 | model.config.pad_token_id = model.config.eos_token_id 30 | return model, tokenizer 31 | 32 | 33 | def preprocess_function(examples, tokenizer): 34 | inputs = examples["input"] 35 | 36 | # Tokenize the inputs 37 | model_inputs = tokenizer(inputs, truncation=True, padding="max_length", max_length=4096) 38 | # Shift labels by one token (for causal LM) 39 | labels = model_inputs["input_ids"].copy() 40 | labels = labels[1:] + [tokenizer.pad_token_id] # Shift by one and pad the last token 41 | 42 | model_inputs["labels"] = labels 43 | 44 | return model_inputs 45 | 46 | 47 | def load_all_model_paths(base_path): 48 | paths = [] 49 | for folder_name in os.listdir(base_path): 50 | temp_folder_path = os.path.join(base_path, folder_name) 51 | for model_filename in os.listdir(temp_folder_path): 52 | current_path = os.path.join(temp_folder_path, 
model_filename) 53 | paths.append(current_path) 54 | return paths 55 | 56 | 57 | def main(task, version): 58 | _, tokenizer = load_model("models/Qwen2.5-Math-7B-Instruct") 59 | 60 | save_best_path = f"sf_best_policy_model/{task}/{version}" 61 | os.makedirs(save_best_path, exist_ok=True) 62 | 63 | # Load data 64 | eval_dataset = load_data_for_task(task) 65 | 66 | # Load all models 67 | base_path = f"outputs/self_training_{version}_policy_model/{task}/Qwen2.5-Math-7B-Instruct" 68 | 69 | all_model_paths = load_all_model_paths(base_path) 70 | print(f"all_model_paths: {all_model_paths}") 71 | 72 | # best_f1 = -1 73 | best_eval_loss = 1000 74 | best_model_path = None 75 | 76 | # Tokenize data 77 | preprocess_with_tokenizer = lambda examples: preprocess_function(examples, tokenizer) 78 | tokenized_eval_dataset = eval_dataset.map(preprocess_with_tokenizer, batched=True) 79 | 80 | for model_path in all_model_paths: 81 | print(f"Evaluating model: {model_path}") 82 | 83 | model, tokenizer = load_model(model_path) 84 | 85 | temp_model_path = model_path.replace(task, "temp" + task) 86 | training_args = TrainingArguments( 87 | output_dir=f"{temp_model_path}", 88 | per_device_eval_batch_size=16, 89 | per_device_train_batch_size=1, 90 | logging_dir=f"./logs/best/{task}", 91 | deepspeed="deepspeed_config/policy_model_72b.json", 92 | fp16=True, 93 | ) 94 | 95 | trainer = Trainer( 96 | model=model, 97 | args=training_args, 98 | eval_dataset=tokenized_eval_dataset, 99 | tokenizer=tokenizer, 100 | ) 101 | 102 | metrics = trainer.evaluate() 103 | eval_loss = metrics["eval_loss"] 104 | 105 | print(f"Model: {model_path}, Loss: {eval_loss}") 106 | 107 | if eval_loss < best_eval_loss: 108 | best_eval_loss = eval_loss 109 | best_model_path = model_path 110 | 111 | if best_model_path: 112 | print(f"Best model found: {best_model_path} with F1: {best_eval_loss}") 113 | # model = AutoModelForSequenceClassification.from_pretrained(best_model_path) 114 | # tokenizer = AutoTokenizer.from_pretrained(best_model_path) 115 | # model.save_pretrained(save_best_path) 116 | # tokenizer.save_pretrained(save_best_path) 117 | else: 118 | print("No models evaluated.") 119 | 120 | 121 | if __name__ == "__main__": 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument("--task", type=str, default="orm") 124 | parser.add_argument("--version", type=str, default="v1") 125 | args = parser.parse_args() 126 | 127 | task = args.task 128 | version = args.version 129 | main(task, version) 130 | -------------------------------------------------------------------------------- /construct_dataset/train_eval_split.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from rm_dataset_construction_phase1 import phase1_ORM_dataset, phase1_PRM_dataset, phase1_HRM_dataset, \ 3 | split_correct_and_wrong_step_phase1, parse_jsonl 4 | from rm_dataset_construction_phase2 import phase2_ORM_dataset, phase2_PRM_dataset, phase2_HRM_dataset, \ 5 | split_correct_and_wrong_step_phase2 6 | from construct_prompt import construct_ORM_prompt, construct_PRM_HRM_prompt 7 | import json 8 | import random 9 | import os 10 | from transformers import AutoTokenizer 11 | 12 | RATIO = 0.8 13 | 14 | 15 | def train_eval_split_rm(a_list): 16 | random.shuffle(a_list) 17 | length = len(a_list) 18 | train_idx = int(length * RATIO) 19 | return a_list[:train_idx], a_list[train_idx:] 20 | 21 | 22 | def write_jsonl(a_list, path): 23 | tokenizer = AutoTokenizer.from_pretrained("models/Qwen2.5-1.5B-math") 24 | 
seen = set() 25 | with open(path, 'w') as f: 26 | for json_ in a_list: 27 | input_ = json_['input'] 28 | length = len(tokenizer(input_, truncation=False)['input_ids']) 29 | if length > 4096 or input_ in seen: 30 | continue 31 | seen.add(input_) 32 | f.write(json.dumps(json_) + "\n") 33 | 34 | 35 | def orm_formatter(a_list): 36 | datas = [] 37 | for a_dict in a_list: 38 | question = a_dict['question'] 39 | for correct_answer in a_dict['correct']: 40 | prompt = construct_ORM_prompt(question, correct_answer) 41 | datas.append({'input': prompt, 'label': 1}) 42 | for incorrect_answer in a_dict['incorrect']: 43 | prompt = construct_ORM_prompt(question, incorrect_answer) 44 | datas.append({'input': prompt, 'label': 0}) 45 | return datas 46 | 47 | 48 | 49 | def prm_and_hrm_formatter(a_list): 50 | datas = [] 51 | for a_dict in a_list: 52 | question = a_dict['question'] 53 | for correct_answer in a_dict['correct']: 54 | prompt = construct_PRM_HRM_prompt(question, correct_answer) 55 | datas.append({'input': prompt, 'label': 1}) 56 | for incorrect_answer in a_dict['incorrect']: 57 | prompt = construct_PRM_HRM_prompt(question, incorrect_answer) 58 | datas.append({'input': prompt, 'label': 0}) 59 | return datas 60 | 61 | 62 | def construct_phase1_train_eval(origin_data_location='dataset/prm_dataset/phase1_train.jsonl', 63 | base_path='dataset/phase1/'): 64 | dicts = parse_jsonl(origin_data_location) 65 | questions, positive_list, negative_list, neutral_list, chosen_completion_list, status_list = split_correct_and_wrong_step_phase1( 66 | dicts) 67 | 68 | # ORM 69 | orm_list = phase1_ORM_dataset(questions, positive_list, negative_list, neutral_list, chosen_completion_list) 70 | orm_train_list, orm_eval_list = train_eval_split_rm(orm_list) 71 | orm_train_data = orm_formatter(orm_train_list) 72 | orm_eval_data = orm_formatter(orm_eval_list) 73 | 74 | orm_path = os.path.join(base_path, 'orm') 75 | os.makedirs(orm_path, exist_ok=True) 76 | 77 | orm_train_path = os.path.join(orm_path, 'train.jsonl') 78 | orm_eval_path = os.path.join(orm_path, 'eval.jsonl') 79 | 80 | write_jsonl(orm_train_data, orm_train_path) 81 | write_jsonl(orm_eval_data, orm_eval_path) 82 | 83 | # PRM 84 | print("PRM: ") 85 | prm_list = phase1_PRM_dataset(questions, positive_list, negative_list, neutral_list, chosen_completion_list) 86 | prm_train_list, prm_eval_list = train_eval_split_rm(prm_list) 87 | prm_train_data = prm_and_hrm_formatter(prm_train_list) 88 | prm_eval_data = prm_and_hrm_formatter(prm_eval_list) 89 | 90 | prm_path = os.path.join(base_path, 'prm') 91 | os.makedirs(prm_path, exist_ok=True) 92 | 93 | prm_train_path = os.path.join(prm_path, 'train.jsonl') 94 | prm_eval_path = os.path.join(prm_path, 'eval.jsonl') 95 | 96 | write_jsonl(prm_train_data, prm_train_path) 97 | write_jsonl(prm_eval_data, prm_eval_path) 98 | 99 | # HRM 100 | print("HRM: ") 101 | hrm_list = phase1_HRM_dataset(questions, positive_list, negative_list, neutral_list, chosen_completion_list) 102 | hrm_train_list, hrm_eval_list = train_eval_split_rm(hrm_list) 103 | hrm_train_data = prm_and_hrm_formatter(hrm_train_list) 104 | hrm_eval_data = prm_and_hrm_formatter(hrm_eval_list) 105 | 106 | hrm_path = os.path.join(base_path, 'hrm') 107 | os.makedirs(hrm_path, exist_ok=True) 108 | 109 | hrm_train_path = os.path.join(hrm_path, 'train.jsonl') 110 | hrm_eval_path = os.path.join(hrm_path, 'eval.jsonl') 111 | 112 | write_jsonl(hrm_train_data, hrm_train_path) 113 | write_jsonl(hrm_eval_data, hrm_eval_path) 114 | 115 | 116 | 117 | if __name__ == '__main__': 118 | 
construct_phase1_train_eval('dataset/prm_dataset/phase1_train.jsonl', 119 | base_path='dataset/phase1/') 120 | -------------------------------------------------------------------------------- /self-training/grading/math_normalize.py: -------------------------------------------------------------------------------- 1 | """ 2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence). 3 | """ 4 | import re 5 | from typing import Optional 6 | 7 | 8 | def normalize_answer(answer: Optional[str]) -> Optional[str]: 9 | if answer is None: 10 | return None 11 | answer = answer.strip() 12 | try: 13 | # Remove enclosing `\text{}`. 14 | m = re.search("^\\\\text\{(?P.+?)\}$", answer) 15 | if m is not None: 16 | answer = m.group("text").strip() 17 | return _strip_string(answer) 18 | except: 19 | return answer 20 | 21 | 22 | def _fix_fracs(string): 23 | substrs = string.split("\\frac") 24 | new_str = substrs[0] 25 | if len(substrs) > 1: 26 | substrs = substrs[1:] 27 | for substr in substrs: 28 | new_str += "\\frac" 29 | if substr[0] == "{": 30 | new_str += substr 31 | else: 32 | try: 33 | assert len(substr) >= 2 34 | except: 35 | return string 36 | a = substr[0] 37 | b = substr[1] 38 | if b != "{": 39 | if len(substr) > 2: 40 | post_substr = substr[2:] 41 | new_str += "{" + a + "}{" + b + "}" + post_substr 42 | else: 43 | new_str += "{" + a + "}{" + b + "}" 44 | else: 45 | if len(substr) > 2: 46 | post_substr = substr[2:] 47 | new_str += "{" + a + "}" + b + post_substr 48 | else: 49 | new_str += "{" + a + "}" + b 50 | string = new_str 51 | return string 52 | 53 | 54 | def _fix_a_slash_b(string): 55 | if len(string.split("/")) != 2: 56 | return string 57 | a = string.split("/")[0] 58 | b = string.split("/")[1] 59 | try: 60 | a = int(a) 61 | b = int(b) 62 | assert string == "{}/{}".format(a, b) 63 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 64 | return new_string 65 | except: 66 | return string 67 | 68 | 69 | def _remove_right_units(string): 70 | # "\\text{ " only ever occurs (at least in the val set) when describing units 71 | if "\\text{ " in string: 72 | splits = string.split("\\text{ ") 73 | assert len(splits) == 2 74 | return splits[0] 75 | else: 76 | return string 77 | 78 | 79 | def _fix_sqrt(string): 80 | if "\\sqrt" not in string: 81 | return string 82 | splits = string.split("\\sqrt") 83 | new_string = splits[0] 84 | for split in splits[1:]: 85 | if split[0] != "{": 86 | a = split[0] 87 | new_substr = "\\sqrt{" + a + "}" + split[1:] 88 | else: 89 | new_substr = "\\sqrt" + split 90 | new_string += new_substr 91 | return new_string 92 | 93 | 94 | def _strip_string(string): 95 | # linebreaks 96 | string = string.replace("\n", "") 97 | # print(string) 98 | 99 | # remove inverse spaces 100 | string = string.replace("\\!", "") 101 | # print(string) 102 | 103 | # replace \\ with \ 104 | string = string.replace("\\\\", "\\") 105 | # print(string) 106 | 107 | # replace tfrac and dfrac with frac 108 | string = string.replace("tfrac", "frac") 109 | string = string.replace("dfrac", "frac") 110 | # print(string) 111 | 112 | # remove \left and \right 113 | string = string.replace("\\left", "") 114 | string = string.replace("\\right", "") 115 | # print(string) 116 | 117 | # Remove circ (degrees) 118 | string = string.replace("^{\\circ}", "") 119 | string = string.replace("^\\circ", "") 120 | 121 | # remove dollar signs 122 | string = string.replace("\\$", "") 123 | 124 | # remove units (on the right) 125 | string = _remove_right_units(string) 126 | 127 | # remove percentage 128 
| string = string.replace("\\%", "") 129 | string = string.replace("\%", "") 130 | 131 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 132 | string = string.replace(" .", " 0.") 133 | string = string.replace("{.", "{0.") 134 | # if empty, return empty string 135 | if len(string) == 0: 136 | return string 137 | if string[0] == ".": 138 | string = "0" + string 139 | 140 | # to consider: get rid of e.g. "k = " or "q = " at beginning 141 | if len(string.split("=")) == 2: 142 | if len(string.split("=")[0]) <= 2: 143 | string = string.split("=")[1] 144 | 145 | # fix sqrt3 --> sqrt{3} 146 | string = _fix_sqrt(string) 147 | 148 | # remove spaces 149 | string = string.replace(" ", "") 150 | 151 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 152 | string = _fix_fracs(string) 153 | 154 | # manually change 0.5 --> \frac{1}{2} 155 | if string == "0.5": 156 | string = "\\frac{1}{2}" 157 | 158 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 159 | string = _fix_a_slash_b(string) 160 | 161 | return string 162 | -------------------------------------------------------------------------------- /sft_rw_manual_annotation/grading/math_normalize.py: -------------------------------------------------------------------------------- 1 | """ 2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence). 3 | """ 4 | import re 5 | from typing import Optional 6 | 7 | 8 | def normalize_answer(answer: Optional[str]) -> Optional[str]: 9 | if answer is None: 10 | return None 11 | answer = answer.strip() 12 | try: 13 | # Remove enclosing `\text{}`. 
14 | m = re.search("^\\\\text\{(?P.+?)\}$", answer) 15 | if m is not None: 16 | answer = m.group("text").strip() 17 | return _strip_string(answer) 18 | except: 19 | return answer 20 | 21 | 22 | def _fix_fracs(string): 23 | substrs = string.split("\\frac") 24 | new_str = substrs[0] 25 | if len(substrs) > 1: 26 | substrs = substrs[1:] 27 | for substr in substrs: 28 | new_str += "\\frac" 29 | if substr[0] == "{": 30 | new_str += substr 31 | else: 32 | try: 33 | assert len(substr) >= 2 34 | except: 35 | return string 36 | a = substr[0] 37 | b = substr[1] 38 | if b != "{": 39 | if len(substr) > 2: 40 | post_substr = substr[2:] 41 | new_str += "{" + a + "}{" + b + "}" + post_substr 42 | else: 43 | new_str += "{" + a + "}{" + b + "}" 44 | else: 45 | if len(substr) > 2: 46 | post_substr = substr[2:] 47 | new_str += "{" + a + "}" + b + post_substr 48 | else: 49 | new_str += "{" + a + "}" + b 50 | string = new_str 51 | return string 52 | 53 | 54 | def _fix_a_slash_b(string): 55 | if len(string.split("/")) != 2: 56 | return string 57 | a = string.split("/")[0] 58 | b = string.split("/")[1] 59 | try: 60 | a = int(a) 61 | b = int(b) 62 | assert string == "{}/{}".format(a, b) 63 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 64 | return new_string 65 | except: 66 | return string 67 | 68 | 69 | def _remove_right_units(string): 70 | # "\\text{ " only ever occurs (at least in the val set) when describing units 71 | if "\\text{ " in string: 72 | splits = string.split("\\text{ ") 73 | assert len(splits) == 2 74 | return splits[0] 75 | else: 76 | return string 77 | 78 | 79 | def _fix_sqrt(string): 80 | if "\\sqrt" not in string: 81 | return string 82 | splits = string.split("\\sqrt") 83 | new_string = splits[0] 84 | for split in splits[1:]: 85 | if split[0] != "{": 86 | a = split[0] 87 | new_substr = "\\sqrt{" + a + "}" + split[1:] 88 | else: 89 | new_substr = "\\sqrt" + split 90 | new_string += new_substr 91 | return new_string 92 | 93 | 94 | def _strip_string(string): 95 | # linebreaks 96 | string = string.replace("\n", "") 97 | # print(string) 98 | 99 | # remove inverse spaces 100 | string = string.replace("\\!", "") 101 | # print(string) 102 | 103 | # replace \\ with \ 104 | string = string.replace("\\\\", "\\") 105 | # print(string) 106 | 107 | # replace tfrac and dfrac with frac 108 | string = string.replace("tfrac", "frac") 109 | string = string.replace("dfrac", "frac") 110 | # print(string) 111 | 112 | # remove \left and \right 113 | string = string.replace("\\left", "") 114 | string = string.replace("\\right", "") 115 | # print(string) 116 | 117 | # Remove circ (degrees) 118 | string = string.replace("^{\\circ}", "") 119 | string = string.replace("^\\circ", "") 120 | 121 | # remove dollar signs 122 | string = string.replace("\\$", "") 123 | 124 | # remove units (on the right) 125 | string = _remove_right_units(string) 126 | 127 | # remove percentage 128 | string = string.replace("\\%", "") 129 | string = string.replace("\%", "") 130 | 131 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 132 | string = string.replace(" .", " 0.") 133 | string = string.replace("{.", "{0.") 134 | # if empty, return empty string 135 | if len(string) == 0: 136 | return string 137 | if string[0] == ".": 138 | string = "0" + string 139 | 140 | # to consider: get rid of e.g. 
"k = " or "q = " at beginning 141 | if len(string.split("=")) == 2: 142 | if len(string.split("=")[0]) <= 2: 143 | string = string.split("=")[1] 144 | 145 | # fix sqrt3 --> sqrt{3} 146 | string = _fix_sqrt(string) 147 | 148 | # remove spaces 149 | string = string.replace(" ", "") 150 | 151 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 152 | string = _fix_fracs(string) 153 | 154 | # manually change 0.5 --> \frac{1}{2} 155 | if string == "0.5": 156 | string = "\\frac{1}{2}" 157 | 158 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 159 | string = _fix_a_slash_b(string) 160 | 161 | return string 162 | -------------------------------------------------------------------------------- /self-training/sft_rm.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 3 | from transformers import Trainer, TrainingArguments 4 | import os 5 | from sklearn.metrics import mean_squared_error, mean_absolute_error 6 | from datetime import datetime 7 | import argparse 8 | from functools import partial 9 | import torch 10 | 11 | 12 | def compute_metrics(p): 13 | predictions, labels = p 14 | predictions = predictions.squeeze(-1) # Remove extra dimensions 15 | mse = mean_squared_error(labels, predictions) 16 | mae = mean_absolute_error(labels, predictions) 17 | return {"mse": mse, "mae": mae} 18 | 19 | 20 | def choose_task(data_path, task: str = 'orm'): 21 | train_data_path = os.path.join(data_path, f'{task}_train.jsonl') 22 | eval_data_path = os.path.join(data_path, f'{task}_test.jsonl') 23 | 24 | print(f"train data path: {train_data_path}") 25 | print(f"eval data path: {eval_data_path}") 26 | 27 | train_dataset = load_dataset('json', data_files=train_data_path) 28 | eval_dataset = load_dataset('json', data_files=eval_data_path) 29 | 30 | train_dataset = train_dataset['train'].shuffle() 31 | eval_dataset = eval_dataset['train'] 32 | 33 | # train_dataset = train_dataset.select([0, 1, 2, 3, 4, 5, 6, 7]) 34 | # eval_dataset = eval_dataset.select([0, 1, 2, 3, 4, 5, 6, 7, ]) 35 | 36 | return train_dataset, eval_dataset 37 | 38 | 39 | def train(model_path, train_dataset, eval_dataset, idx=1, task='orm', 40 | time=datetime.now().strftime("%Y-%m-%d_%H-%M-%S")): 41 | model = AutoModelForSequenceClassification.from_pretrained(model_path, 42 | num_labels=1, 43 | attn_implementation="flash_attention_2", 44 | torch_dtype=torch.float16, 45 | ) 46 | tokenizer = AutoTokenizer.from_pretrained(model_path) 47 | 48 | model.config.pad_token_id = model.config.eos_token_id 49 | model_name = model_path.split('/')[-1] 50 | output_dir = f"outputs/self_training_v{idx}/{task}/{model_name}/{time}" 51 | 52 | EVAL_STEP = 50 53 | 54 | training_args = TrainingArguments( 55 | 56 | output_dir=output_dir, 57 | learning_rate=1e-5, 58 | max_grad_norm=0.01, 59 | # warmup_steps=20, 60 | warmup_steps=10, 61 | per_device_train_batch_size=5, 62 | per_device_eval_batch_size=64, 63 | num_train_epochs=20, 64 | eval_strategy="steps", 65 | eval_steps=EVAL_STEP, 66 | save_steps=EVAL_STEP, 67 | save_strategy="steps", 68 | logging_steps=EVAL_STEP, 69 | logging_dir=f"./self-train_logs_v{idx}/{task}/{model_name}/{time}", 70 | 71 | load_best_model_at_end=True, 72 | fp16=True, 73 | # tf32=True, 74 | gradient_accumulation_steps=16, 75 | 76 | # greater_is_better=False, 77 | # 
save_total_limit=3, 78 | # metric_for_best_model="eval_loss", 79 | 80 | greater_is_better=False, 81 | save_total_limit=1, 82 | metric_for_best_model="eval_loss", 83 | 84 | deepspeed="deepspeed_config/RM_config_stage2.json", 85 | report_to="tensorboard", 86 | ) 87 | 88 | trainer = Trainer( 89 | model=model, 90 | args=training_args, 91 | 92 | train_dataset=train_dataset, 93 | eval_dataset=eval_dataset, 94 | tokenizer=tokenizer, 95 | compute_metrics=compute_metrics 96 | ) 97 | 98 | trainer.train() 99 | return trainer 100 | 101 | 102 | def preprocess_function(examples, tokenizer): 103 | inputs = examples["input"] 104 | labels = examples["label"] 105 | model_inputs = tokenizer(inputs, truncation=True, padding="max_length", max_length=4096) 106 | model_inputs["labels"] = labels 107 | return model_inputs 108 | 109 | 110 | if __name__ == '__main__': 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument("--data_path", type=str, default="dataset/self_training_v1_scoring") 113 | parser.add_argument("--model_path", type=str, default="models/Qwen2.5-1.5B-math") 114 | parser.add_argument("--task", type=str, default='orm') 115 | parser.add_argument("--idx", type=int, default=1) 116 | 117 | args = parser.parse_args() 118 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, legacy=False) 119 | 120 | train_dataset, eval_dataset = choose_task(args.data_path, args.task) 121 | 122 | preprocess_with_tokenizer = partial(preprocess_function, tokenizer=tokenizer) 123 | 124 | train_dataset = train_dataset.map(preprocess_with_tokenizer, batched=True) 125 | eval_dataset = eval_dataset.map(preprocess_with_tokenizer, batched=True) 126 | 127 | train(model_path=args.model_path, train_dataset=train_dataset, idx=args.idx, eval_dataset=eval_dataset, 128 | task=args.task) 129 | # nohup accelerate launch --config_file accelerate_config/8gpus.yaml --gpu_ids 0,1,2,3,4,5,6,7 self-training/sft_rm.py --task prm --idx 1 > logs/self_training_prm.log & 130 | # nohup accelerate launch --config_file accelerate_config/6gpus.yaml --gpu_ids 0,1,2,3,4,5 self-training/sft_rm.py --task prm --idx 1 > logs/self_training_prm.log & 131 | # nohup accelerate launch --config_file accelerate_config/8gpus.yaml --gpu_ids 0,1,2,3,4,5,6,7 self-training/sft_rm.py --task hrm --idx 1 > logs/self_training_hrm_v1.log & -------------------------------------------------------------------------------- /sft_rw_manual_annotation/choose_best_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments 4 | from datasets import load_dataset 5 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def load_data_for_task(task: str = 'orm'): 11 | if task in ["orm", "prm", "hrm"]: 12 | data_path = f"dataset/phase1/{task}" 13 | else: 14 | raise ValueError("The task should be either 'orm', 'prm', or 'hrm'.") 15 | eval_data_path = os.path.join(data_path, 'eval.jsonl') 16 | 17 | print(f"eval data path: {eval_data_path}") 18 | 19 | eval_dataset = load_dataset('json', data_files=eval_data_path) 20 | eval_dataset = eval_dataset['train'] 21 | 22 | return eval_dataset 23 | 24 | 25 | def load_model(model_path): 26 | model = AutoModelForSequenceClassification.from_pretrained(model_path, 27 | num_labels=2, 28 | attn_implementation="flash_attention_2", 29 | torch_dtype=torch.float16, 30 | ) 31 | tokenizer = 
AutoTokenizer.from_pretrained(model_path) 32 | 33 | model.config.pad_token_id = model.config.eos_token_id 34 | return model, tokenizer 35 | 36 | 37 | def preprocess_function(examples, tokenizer): 38 | inputs = examples["input"] 39 | labels = examples["label"] 40 | model_inputs = tokenizer(inputs, truncation=True, padding="max_length", max_length=4096) 41 | model_inputs["labels"] = labels 42 | return model_inputs 43 | 44 | 45 | def compute_metrics(eval_pred): 46 | logits, labels = eval_pred 47 | predictions = np.argmax(logits, axis=-1) 48 | 49 | accuracy = accuracy_score(labels, predictions) 50 | precision = precision_score(labels, predictions, average='weighted') 51 | recall = recall_score(labels, predictions, average='weighted') 52 | f1 = f1_score(labels, predictions, average='weighted') 53 | 54 | return { 55 | "accuracy": accuracy, 56 | "precision": precision, 57 | "recall": recall, 58 | "f1": f1 59 | } 60 | 61 | 62 | def load_all_model_paths(base_path): 63 | paths = [] 64 | for folder_name in os.listdir(base_path): 65 | temp_folder_path = os.path.join(base_path, folder_name) 66 | for model_filename in os.listdir(temp_folder_path): 67 | current_path = os.path.join(temp_folder_path, model_filename) 68 | paths.append(current_path) 69 | return paths 70 | 71 | 72 | def main(task): 73 | _, tokenizer = load_model("models/Qwen2.5-1.5B-math") 74 | 75 | save_best_path = f"best_model/{task}" 76 | os.makedirs(save_best_path, exist_ok=True) 77 | 78 | # Load data 79 | eval_dataset = load_data_for_task(task) 80 | 81 | # Load all models 82 | base_path = f"outputs/{task}/Qwen2.5-1.5B-math" 83 | all_model_paths = load_all_model_paths(base_path) 84 | print(f"all_model_paths: {all_model_paths}") 85 | 86 | # best_f1 = -1 87 | best_eval_loss = 1000 88 | best_model_path = None 89 | 90 | # Tokenize data 91 | preprocess_with_tokenizer = lambda examples: preprocess_function(examples, tokenizer) 92 | tokenized_eval_dataset = eval_dataset.map(preprocess_with_tokenizer, batched=True) 93 | 94 | for model_path in all_model_paths: 95 | print(f"Evaluating model: {model_path}") 96 | 97 | model, tokenizer = load_model(model_path) 98 | 99 | temp_model_path = model_path.replace(task, "temp" + task) 100 | training_args = TrainingArguments( 101 | output_dir=f"{temp_model_path}", 102 | per_device_eval_batch_size=64, 103 | per_device_train_batch_size=1, 104 | logging_dir=f"./logs/best/{task}", 105 | deepspeed="deepspeed_config/rm_config_stage2_py.json", 106 | fp16=True, 107 | ) 108 | 109 | trainer = Trainer( 110 | model=model, 111 | args=training_args, 112 | eval_dataset=tokenized_eval_dataset, 113 | tokenizer=tokenizer, 114 | compute_metrics=compute_metrics 115 | ) 116 | 117 | metrics = trainer.evaluate() 118 | eval_f1 = metrics["eval_f1"] 119 | eval_loss = metrics["eval_loss"] 120 | eval_accuracy = metrics["eval_accuracy"] 121 | eval_precision = metrics["eval_precision"] 122 | eval_recall = metrics["eval_recall"] 123 | print( 124 | f"Model: {model_path}, Loss: {eval_loss}, F1: {eval_f1}, Precision: {eval_precision}, Recall: {eval_recall}, Accuracy: {eval_accuracy}") 125 | 126 | if eval_loss < best_eval_loss: 127 | best_eval_loss = eval_loss 128 | best_model_path = model_path 129 | 130 | if best_model_path: 131 | print(f"Best model found: {best_model_path} with F1: {best_eval_loss}") 132 | # model = AutoModelForSequenceClassification.from_pretrained(best_model_path) 133 | # tokenizer = AutoTokenizer.from_pretrained(best_model_path) 134 | # model.save_pretrained(save_best_path) 135 | # tokenizer.save_pretrained(save_best_path) 
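        # NOTE: checkpoints are ranked by lowest eval_loss; the winner is only printed here.
        # To also export it to save_best_path, re-enable the commented-out
        # from_pretrained / save_pretrained lines above.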
136 | else: 137 | print("No models evaluated.") 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument("--task", type=str, default="orm") 143 | args = parser.parse_args() 144 | 145 | task = args.task 146 | main(task) 147 | -------------------------------------------------------------------------------- /sft_rw_manual_annotation/sft_rw.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 3 | from transformers import Trainer, TrainingArguments 4 | import os 5 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score 6 | from datetime import datetime 7 | import numpy as np 8 | import argparse 9 | from functools import partial 10 | import torch 11 | 12 | 13 | 14 | def compute_metrics(eval_pred): 15 | logits, labels = eval_pred 16 | predictions = np.argmax(logits, axis=-1) 17 | 18 | accuracy = accuracy_score(labels, predictions) 19 | precision = precision_score(labels, predictions, average='weighted') 20 | recall = recall_score(labels, predictions, average='weighted') 21 | f1 = f1_score(labels, predictions, average='weighted') 22 | 23 | return { 24 | "accuracy": accuracy, 25 | "precision": precision, 26 | "recall": recall, 27 | "f1": f1 28 | } 29 | 30 | 31 | def choose_task(task: str = 'orm'): 32 | if task == "orm" or task == "prm" or task == "hrm" or task == "impunity_hrm": 33 | data_path = f"dataset/phase1/{task}" 34 | else: 35 | raise ValueError("The task should be either 'orm' or 'prm' or 'hrm'.") 36 | train_data_path = os.path.join(data_path, 'train.jsonl') 37 | eval_data_path = os.path.join(data_path, 'eval.jsonl') 38 | 39 | print(f"train data path: {train_data_path}") 40 | print(f"eval data path: {eval_data_path}") 41 | 42 | train_dataset = load_dataset('json', data_files=train_data_path) 43 | eval_dataset = load_dataset('json', data_files=eval_data_path) 44 | 45 | train_dataset = train_dataset['train'].shuffle() 46 | eval_dataset = eval_dataset['train'] 47 | 48 | # train_dataset = train_dataset.select([0, 1, 2, 3, 4, 5, 6, 7]) 49 | # eval_dataset = eval_dataset.select([0, 1, 2, 3, 4, 5, 6, 7, ]) 50 | 51 | return train_dataset, eval_dataset 52 | 53 | 54 | def train(model_path, train_dataset, eval_dataset, task='orm', time=datetime.now().strftime("%Y-%m-%d_%H-%M-%S")): 55 | model = AutoModelForSequenceClassification.from_pretrained(model_path, 56 | num_labels=2, 57 | attn_implementation="flash_attention_2", 58 | torch_dtype=torch.float16, 59 | ) 60 | tokenizer = AutoTokenizer.from_pretrained(model_path) 61 | 62 | model.config.pad_token_id = model.config.eos_token_id 63 | model_name = model_path.split('/')[-1] 64 | output_dir = f"outputs/{task}/{model_name}/{time}" 65 | 66 | EVAL_STEP = 20 67 | 68 | training_args = TrainingArguments( 69 | 70 | output_dir=output_dir, 71 | learning_rate=1e-5, 72 | max_grad_norm=0.01, 73 | # warmup_steps=20, 74 | warmup_steps=10, 75 | per_device_train_batch_size=5, 76 | per_device_eval_batch_size=64, 77 | num_train_epochs=40, 78 | eval_strategy="steps", 79 | eval_steps=EVAL_STEP, 80 | save_steps=EVAL_STEP, 81 | save_strategy="steps", 82 | logging_steps=EVAL_STEP, 83 | logging_dir=f"./logs/{task}/{model_name}/{time}", 84 | 85 | load_best_model_at_end=True, 86 | fp16=True, 87 | gradient_accumulation_steps=32, 88 | 89 | 90 | greater_is_better=False, 91 | save_total_limit=1, 92 | metric_for_best_model="eval_loss", 93 | 94 | 
deepspeed="deepspeed_config/RM_config_stage2.json", 95 | report_to="tensorboard", 96 | ) 97 | 98 | trainer = Trainer( 99 | model=model, 100 | args=training_args, 101 | 102 | train_dataset=train_dataset, 103 | eval_dataset=eval_dataset, 104 | tokenizer=tokenizer, 105 | compute_metrics=compute_metrics 106 | ) 107 | 108 | trainer.train() 109 | return trainer 110 | 111 | 112 | def preprocess_function(examples, tokenizer): 113 | inputs = examples["input"] 114 | labels = examples["label"] 115 | model_inputs = tokenizer(inputs, truncation=True, padding="max_length", max_length=4096) 116 | model_inputs["labels"] = labels 117 | return model_inputs 118 | 119 | 120 | if __name__ == '__main__': 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("--model_path", type=str, default="models/Qwen2.5-1.5B-math") 123 | parser.add_argument("--task", type=str, default='orm') 124 | 125 | args = parser.parse_args() 126 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, legacy=False) 127 | 128 | train_dataset, eval_dataset = choose_task(args.task) 129 | 130 | preprocess_with_tokenizer = partial(preprocess_function, tokenizer=tokenizer) 131 | 132 | train_dataset = train_dataset.map(preprocess_with_tokenizer, batched=True) 133 | eval_dataset = eval_dataset.map(preprocess_with_tokenizer, batched=True) 134 | 135 | train(model_path=args.model_path, train_dataset= train_dataset, eval_dataset=eval_dataset, task=args.task) 136 | # nohup accelerate launch --config_file accelerate_config/1gpu.yaml --gpu_ids 0 fine_tune/train.py --task orm > logs/new_orm.log & 137 | # nohup accelerate launch --config_file accelerate_config/3gpus.yaml --gpu_ids 2,3,4 --main_process_port 29501 fine_tune/train.py --task prm > logs/new_prm.log & 138 | # nohup accelerate launch --config_file accelerate_config/4gpus.yaml --gpu_ids 1,5,6,7 --main_process_port 29502 fine_tune/train.py --task hrm > logs/new_hrm.log & 139 | # nohup accelerate launch --config_file accelerate_config/3gpus.yaml --gpu_ids 0,1,2 fine_tune/train.py --task prm > logs/new_prm_v2.log & 140 | -------------------------------------------------------------------------------- /self-training/assign_node_without_rm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Goal: Traverse the MCTS tree, filter out nodes with high Monte Carlo (MC) scores, and save their logits into a file. 3 | These logits will later be used in the SFT policy model to compute the KL divergence against the reference model. 
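Illustrative consumer-side sketch (editorial addition; the real training code is the
project's policy-model SFT script, and the names kl_to_teacher / student_logits below
are hypothetical): each record written by write_jsonl pairs a prompt with the path of
its teacher logits (saved as bfloat16), so a distillation-style KL term can be computed
roughly as follows, assuming the student re-tokenizes the same prompt with the same
tokenizer so sequence lengths line up:

    import torch
    import torch.nn.functional as F

    def kl_to_teacher(student_logits, logits_path):
        # student_logits: [batch, seq_len, vocab] logits of the model being fine-tuned
        # logits_path: the "logits_path" field of a record produced by write_jsonl
        teacher_logits = torch.load(logits_path).float()   # stored as bfloat16 below
        return F.kl_div(
            F.log_softmax(student_logits, dim=-1),   # student log-probs
            F.log_softmax(teacher_logits, dim=-1),   # teacher log-probs
            reduction="batchmean",
            log_target=True,   # target passed as log-probs: KL(teacher || student)
        )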
4 | """ 5 | import argparse 6 | import json 7 | import os 8 | import torch 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from tqdm import tqdm 11 | 12 | 13 | def load_model(model_path): 14 | teacher_model = AutoModelForCausalLM.from_pretrained( 15 | model_path, 16 | attn_implementation="flash_attention_2", 17 | torch_dtype=torch.float16 18 | ) 19 | teacher_model.to("cuda") 20 | tokenizer = AutoTokenizer.from_pretrained(model_path) 21 | return teacher_model, tokenizer 22 | 23 | 24 | 25 | 26 | def read_jsonl(path, threshold): 27 | prompts = [] 28 | with open(path, 'r') as f: 29 | for line in f: 30 | dic = json.loads(line) 31 | score = dic.get('label', 0) 32 | if "Python" in dic['input']: 33 | continue 34 | if score >= threshold: 35 | prompts.append(dic['input']) 36 | return sorted(list(set(prompts)), reverse=True) 37 | 38 | 39 | 40 | def write_jsonl(jsonl_path, prompts, version, task, model, tokenizer, start_idx, end_idx, train_or_eval): 41 | tensor_dir = f"dataset/temp_tensor/{version}_{task}_{start_idx}_{end_idx}_{train_or_eval}" 42 | device = "cuda" 43 | os.makedirs(tensor_dir, exist_ok=True) 44 | with open(jsonl_path, 'a') as f: 45 | for idx, prompt in tqdm(enumerate(prompts), total=len(prompts)): 46 | tensor_file = os.path.join(tensor_dir, f"logits_{idx}.pt") 47 | inputs = tokenizer([prompt], return_tensors="pt", max_length=4096, padding=True, truncation=True) 48 | inputs = {key: value.to(device) for key, value in inputs.items()} 49 | 50 | with torch.no_grad(): 51 | outputs = model(**inputs) 52 | logits = outputs.logits 53 | 54 | logits = logits.type(torch.bfloat16) 55 | torch.save(logits, tensor_file) 56 | torch.cuda.empty_cache() 57 | 58 | record = {"input": prompt, "logits_path": tensor_file} 59 | f.write(json.dumps(record) + "\n") 60 | 61 | 62 | if __name__ == '__main__': 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('--version', type=str, default="v1") 65 | parser.add_argument("--task", type=str, default="hrm") 66 | parser.add_argument("--threshold", type=float, default=0.95) 67 | parser.add_argument("--slots", type=int, default=8) 68 | parser.add_argument("--idx", type=int, default=0) 69 | 70 | args = parser.parse_args() 71 | 72 | version = args.version 73 | task = args.task 74 | threshold = float(args.threshold) 75 | slots = int(args.slots) 76 | idx = int(args.idx) 77 | 78 | save_train_path = f"dataset/self_training_{version}_policy_without_rm/can_use_{task}_train_with_logits.jsonl" 79 | save_test_path = f"dataset/self_training_{version}_policy_without_rm/can_use_{task}_eval_with_logits.jsonl" 80 | 81 | original_data_train_path = f"dataset/self_training_{version}_scoring/{task}_train.jsonl" 82 | original_data_eval_path = f"dataset/self_training_{version}_scoring/{task}_test.jsonl" 83 | 84 | train_prompts = read_jsonl(original_data_train_path, threshold) 85 | length_train_prompts = len(train_prompts) 86 | train_start_idx = int(length_train_prompts / slots * idx) 87 | train_end_idx = int(length_train_prompts / slots * (idx + 1)) 88 | train_prompts = train_prompts[train_start_idx:train_end_idx] 89 | 90 | eval_prompts = read_jsonl(original_data_eval_path, threshold) 91 | length_eval_prompts = len(eval_prompts) 92 | eval_start_idx = int(length_eval_prompts / slots * idx) 93 | eval_end_idx = int(length_eval_prompts / slots * (idx + 1)) 94 | eval_prompts = eval_prompts[eval_start_idx:eval_end_idx] 95 | 96 | print("train prompts:", len(train_prompts), "eval prompts:", len(eval_prompts)) 97 | 98 | model, tokenizer = 
load_model("models/Qwen2.5-Math-7B-Instruct") 99 | 100 | write_jsonl(save_train_path, train_prompts, version, task, model, tokenizer, start_idx=train_start_idx, 101 | end_idx=train_end_idx, train_or_eval="train") 102 | write_jsonl(save_test_path, eval_prompts, version, task, model, tokenizer, start_idx=eval_start_idx, 103 | end_idx=eval_end_idx, train_or_eval="eval") 104 | # CUDA_VISIBLE_DEVICES=0 nohup python self-training/assign_node_without_rm.py --version v1 --task hrm --slots 8 --idx 0 > logs/logits_hrm_v1_0.log & 105 | # CUDA_VISIBLE_DEVICES=1 nohup python self-training/assign_node_without_rm.py --version v1 --task hrm --slots 8 --idx 1 > logs/logits_hrm_v1_1.log & 106 | # CUDA_VISIBLE_DEVICES=2 nohup python self-training/assign_node_without_rm.py --version v1 --task hrm --slots 8 --idx 2 > logs/logits_hrm_v1_2.log & 107 | # CUDA_VISIBLE_DEVICES=3 nohup python self-training/assign_node_without_rm.py --version v1 --task hrm --slots 8 --idx 3 > logs/logits_hrm_v1_3.log & 108 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/assign_node_without_rm.py --version v1 --task hrm --slots 8 --idx 4 > logs/logits_hrm_v1_4.log & 109 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/assign_node_without_rm.py --version v1 --task hrm --slots 8 --idx 5 > logs/logits_hrm_v1_5.log & 110 | # CUDA_VISIBLE_DEVICES=6 nohup python self-training/assign_node_without_rm.py --version v1 --task hrm --slots 8 --idx 6 > logs/logits_hrm_v1_6.log & 111 | # CUDA_VISIBLE_DEVICES=7 nohup python self-training/assign_node_without_rm.py --version v1 --task hrm --slots 8 --idx 7 > logs/logits_hrm_v1_7.log & 112 | # _ 113 | # CUDA_VISIBLE_DEVICES=0 nohup python self-training/assign_node_without_rm.py --version v1 --task prm --slots 6 --idx 0 > logs/logits_prm_v1_0.log & 114 | # CUDA_VISIBLE_DEVICES=1 nohup python self-training/assign_node_without_rm.py --version v1 --task prm --slots 6 --idx 1 > logs/logits_prm_v1_1.log & 115 | # CUDA_VISIBLE_DEVICES=2 nohup python self-training/assign_node_without_rm.py --version v1 --task prm --slots 6 --idx 2 > logs/logits_prm_v1_2.log & 116 | # CUDA_VISIBLE_DEVICES=3 nohup python self-training/assign_node_without_rm.py --version v1 --task prm --slots 6 --idx 3 > logs/logits_prm_v1_3.log & 117 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/assign_node_without_rm.py --version v1 --task prm --slots 6 --idx 4 > logs/logits_prm_v1_4.log & 118 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/assign_node_without_rm.py --version v1 --task prm --slots 6 --idx 5 > logs/logits_prm_v1_5.log & 119 | -------------------------------------------------------------------------------- /self-training/generate_mcts.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tree import Node 4 | from llm_query import parallel_query, sequential_query 5 | from construct_prompt import construct_policy_model_prompt_for_PRM_HRM 6 | from tqdm import tqdm 7 | import argparse 8 | import json 9 | import pickle 10 | from envs import * 11 | 12 | 13 | def load_train_data(data_path='dataset/prm_dataset/phase1_train.jsonl'): 14 | question_answer_pairs = [] 15 | with open(data_path) as f: 16 | for line in f.readlines(): 17 | dic = json.loads(line) 18 | if dic['label']['finish_reason'] != "solution": 19 | continue 20 | 21 | question = dic['question']['problem'] 22 | ground_truth = dic['question']['ground_truth_answer'] 23 | question_answer_pairs.append((question, ground_truth)) 24 | 25 | return question_answer_pairs 26 | 27 | 28 | def 
mcts_process_for_single_question(question, ground_truth, host, port, model_name, api_key, idx, version, task): 29 | root = Node(question=question) 30 | queue = [root] 31 | LLM_query_times = 0 32 | for height in range(1, MAX_HEIGHT): 33 | current_length = len(queue) 34 | 35 | for _ in range(current_length): 36 | current_node = queue.pop(0) 37 | 38 | temp_children_list = [] 39 | previous_answer = current_node.previous_answer[:] 40 | 41 | if not previous_answer: 42 | prompt = construct_policy_model_prompt_for_PRM_HRM(question, None) 43 | else: 44 | prompt = construct_policy_model_prompt_for_PRM_HRM(question, "\n".join(previous_answer)) 45 | try: 46 | LLM_query_times += 1 47 | intermediate_steps = parallel_query(host, port, model_name, prompt, api_key) 48 | # print("成功并行运算") 49 | except Exception as e: 50 | print(e) 51 | print("一个一个算") 52 | intermediate_steps = [] 53 | for cnt_child in range(NUMBER_OF_CHILDREN): 54 | try: 55 | intermediate_step = sequential_query(host, port, model_name, prompt, api_key) 56 | except Exception as e: 57 | print(e) 58 | intermediate_step = "IT SHOULD STOP!" 59 | intermediate_steps.append(intermediate_step) 60 | 61 | for intermediate_step in intermediate_steps: 62 | node = Node(question=question, parent=current_node, height=height) 63 | previous_answer = current_node.previous_answer[:] 64 | node.set_previous_answer(previous_answer) 65 | 66 | if intermediate_step == "IT SHOULD STOP!": 67 | node.should_stop = True 68 | 69 | node.previous_answer.append(intermediate_step) 70 | if whether_contain_answer(intermediate_step): 71 | answer = extract_answer(intermediate_step) 72 | node.have_the_answer(answer, ground_truth) 73 | elif node.should_stop: 74 | pass 75 | else: 76 | queue.append(node) 77 | 78 | temp_children_list.append(node) 79 | # print("------") 80 | # print("\n".join(node.previous_answer)) 81 | # for cnt_child in range(NUMBER_OF_CHILDREN): 82 | # 83 | # node = Node(question=question, parent=current_node, height=height) 84 | # 85 | # previous_answer = current_node.previous_answer[:] 86 | # node.set_previous_answer(previous_answer) 87 | # # print("Current height is ", height) 88 | # if not previous_answer: 89 | # prompt = construct_policy_model_prompt_for_PRM_HRM(question, None) 90 | # else: 91 | # prompt = construct_policy_model_prompt_for_PRM_HRM(question, "\n".join(previous_answer)) 92 | # # print(prompt) 93 | # # print("----prompt end-----") 94 | # try: 95 | # intermediate_step = sequential_query(host, port, model_name, prompt, api_key) 96 | # except Exception as e: 97 | # print(e) 98 | # node.should_stop = True 99 | # intermediate_step = "IT SHOULD STOP!" 
100 | # node.previous_answer.append(intermediate_step) 101 | # # print(intermediate_step) 102 | # if whether_contain_answer(intermediate_step): 103 | # answer = extract_answer(intermediate_step) 104 | # node.have_the_answer(answer, ground_truth) 105 | # elif node.should_stop: 106 | # pass 107 | # else: 108 | # queue.append(node) 109 | # 110 | # # print("----------------") 111 | # # 112 | # # print("\n".join(node.previous_answer)) 113 | # # print("----------------\n\n") 114 | # temp_children_list.append(node) 115 | current_node.add_children(temp_children_list) 116 | 117 | if not queue: 118 | break 119 | print(f"For question {idx}, it accesses {LLM_query_times} times.") 120 | os.makedirs(f"dataset/self_training_{version}_{task}", exist_ok=True) 121 | with open(f"dataset/self_training_{version}_{task}/{idx}.pkl", 'wb') as f: 122 | pickle.dump(root, f) 123 | 124 | 125 | def mcts_process_for_special_range(start_idx, end_idx, host, port, model_name, api_key, 126 | data_path='dataset/prm_dataset/phase1_train.jsonl', version='v1', task='prm'): 127 | question_answer_pairs = load_train_data(data_path=data_path) 128 | for idx in tqdm(range(start_idx, end_idx), total=end_idx - start_idx): 129 | question, ground_truth = question_answer_pairs[idx] 130 | mcts_process_for_single_question(question, ground_truth, host, port, model_name, api_key, idx, version, task) 131 | 132 | 133 | def extract_answer(text: str, placeholder="# Answer", end_placeholder='# END!'): 134 | text = text.lower() 135 | left_idx = text.rindex(placeholder.lower()) 136 | length = len(placeholder) 137 | try: 138 | right_idx = text.rindex(end_placeholder.lower()) 139 | except: 140 | right_idx = -1 141 | return text[left_idx + length:right_idx].strip() 142 | 143 | 144 | def whether_contain_answer(text): 145 | return "# Answer" in text 146 | 147 | 148 | if __name__ == '__main__': 149 | # 一共808个数据 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument("--host", type=str, default="127.0.0.1") 152 | parser.add_argument("--port", type=str, default="10086") 153 | parser.add_argument("--model_name", type=str, default="policy_model") 154 | parser.add_argument("--api_key", type=str, default="xxxx") 155 | parser.add_argument("--train_data_path", type=str, default='dataset/prm_dataset/phase1_train.jsonl') 156 | parser.add_argument("--start_idx", type=str, default="0") 157 | parser.add_argument("--end_idx", type=str, default="1") 158 | parser.add_argument("--version", type=str, default="v1") 159 | parser.add_argument("--task", type=str, default="prm") 160 | args = parser.parse_args() 161 | 162 | host = args.host 163 | port = args.port 164 | model_name = args.model_name 165 | api_key = args.api_key 166 | train_data_path = args.train_data_path 167 | start_idx = int(args.start_idx) 168 | end_idx = int(args.end_idx) 169 | version = args.version 170 | task = args.task 171 | 172 | mcts_process_for_special_range(start_idx, end_idx, host, port, model_name, api_key, train_data_path, version, task) 173 | # nohup python self-training/generate_mcts.py --start_idx 0 --end_idx 400 > logs/self_train_0_400.log & 174 | # nohup python self-training/generate_mcts.py --port 10087 --model_name policy_model_v2 --start_idx 400 --end_idx 809 > logs/self_train_400_809.log & 175 | 176 | # nohup python self-training/generate_mcts.py --port 10088 --model_name hrm_v1_policy_first --version v2 --task hrm --start_idx 0 --end_idx 200 > logs/self_train_hrm_0_200.log & 177 | # nohup python self-training/generate_mcts.py --port 10088 --model_name hrm_v1_policy_first 
--version v2 --task hrm --start_idx 200 --end_idx 400 > logs/self_train_hrm_200_400.log & 178 | # nohup python self-training/generate_mcts.py --port 10089 --model_name hrm_v1_policy_two --version v2 --task hrm --start_idx 400 --end_idx 600 > logs/self_train_hrm_400_600.log & 179 | # nohup python self-training/generate_mcts.py --port 10089 --model_name hrm_v1_policy_two --version v2 --task hrm --start_idx 600 --end_idx 809 > logs/self_train_hrm_600_809.log & 180 | 181 | -------------------------------------------------------------------------------- /self-training/construct_prm_train_data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import random 4 | 5 | from tree import Node 6 | from envs import * 7 | import matplotlib.pyplot as plt 8 | from collections import defaultdict 9 | from construct_prompt import construct_PRM_HRM_prompt_v2 10 | import argparse 11 | from tqdm import tqdm 12 | import json 13 | 14 | import re 15 | def ensure_step_spacing(text): 16 | """ 17 | 确保 '# Step X' 后面有空格或换行,如果没有则添加一个空格 18 | """ 19 | pattern = r"(# Step \d)(?=[^\s])" # 匹配 "# Step X" 且后面不是空格或换行 20 | corrected_text = re.sub(pattern, r"\1 " + "\n", text) # 在匹配的后面加空格 21 | return corrected_text 22 | 23 | 24 | def replace_string(a_list): 25 | height_2_placeholder = {1: '# Step 2', 2: "# Step 3", 3: "# Step 4", 4: "# Step 5", 5: '# END!'} 26 | height = len(a_list) 27 | # if height == 6: 28 | # return a_list 29 | if height != 5: 30 | a_list[-1] = a_list[-1].replace(height_2_placeholder[height], "\n") 31 | if height > 1: 32 | a_list[-2] = a_list[-2].replace(height_2_placeholder[height - 1], "\n") 33 | a_list[-1] = height_2_placeholder[height - 1] + "\n\n" + a_list[- 1] 34 | return a_list 35 | 36 | 37 | def load_all_pickle_file_paths(base_path="dataset/self_training"): 38 | answers = [] 39 | for filename in os.listdir(base_path): 40 | answers.append(os.path.join(base_path, filename)) 41 | return answers 42 | 43 | 44 | def is_leaf_node(node: Node): 45 | if not node.children: 46 | return True 47 | return False 48 | 49 | 50 | def assign_score_for_every_nodes(path): 51 | with open(path, 'rb') as f: 52 | data = pickle.load(f) 53 | LEAF_NODE = 0 54 | 55 | def dfs(node: Node): 56 | if node.score == -1: 57 | node.score = 0 58 | 59 | if is_leaf_node(node): 60 | nonlocal LEAF_NODE 61 | LEAF_NODE += 1 62 | if node.is_correct: 63 | node.score = 1 64 | return 1 65 | else: 66 | node.score = 0 67 | return 0 68 | else: 69 | for child in node.children: 70 | cnt = dfs(child) 71 | node.score += cnt 72 | # if node.score!=0: 73 | # print(node.score) 74 | return node.score 75 | 76 | def dfs_v2(node: Node): 77 | if hasattr(node, 'total'): 78 | pass 79 | else: 80 | node.total = 0 81 | 82 | if is_leaf_node(node): 83 | node.total = 1 84 | return 1 85 | else: 86 | for child in node.children: 87 | cnt = dfs_v2(child) 88 | node.total += cnt 89 | return node.total 90 | 91 | dfs(data) 92 | dfs_v2(data) 93 | queue = [data] 94 | CORRECT_TOTAL_NUMBERS = 0 95 | TOTAL_NUMBERS = 0 96 | 97 | score_height_2_node = defaultdict(lambda: defaultdict(list)) 98 | while queue: 99 | temp_length = len(queue) 100 | for _ in range(temp_length): 101 | node = queue.pop(0) 102 | node.score = node.score / node.total 103 | if node.height != 0: 104 | score_height_2_node[node.height][node.score].append(node) 105 | 106 | if node.score == 1: 107 | CORRECT_TOTAL_NUMBERS += 1 108 | TOTAL_NUMBERS += 1 109 | for child in node.children: 110 | queue.append(child) 111 | # print(score_height_2_node.keys()) 112 | 
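    # At this point every node.score is a Monte Carlo estimate of the step's quality:
    # (# leaf descendants that reached the correct answer) / (# leaf descendants).
    # Example: a step whose subtree ends in 3 completed solutions, 2 of them correct,
    # gets score 2/3. The next line then pins the root's score to 1.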
113 | data.score = 1 114 | # print(LEAF_NODE) 115 | # print(TOTAL_NUMBERS) 116 | # 117 | # print(CORRECT_TOTAL_NUMBERS) 118 | # from collections import Counter 119 | # print(Counter(scores)) 120 | # plot_distribution(scores) 121 | return score_height_2_node 122 | 123 | 124 | def write_new_data(root, idx): 125 | with open(f"dataset/self_training_v1_scoring/{idx}.pkl", 'wb') as f: 126 | pickle.dump(root, f) 127 | 128 | 129 | def plot_distribution(data): 130 | plt.hist(data, bins=10, edgecolor='black') # Adjust the bins as needed 131 | 132 | plt.xlabel('Value') 133 | plt.ylabel('Frequency') 134 | plt.title('Distribution of Numbers') 135 | 136 | plt.savefig("score_distribution.png") 137 | 138 | 139 | def sort_node(score_height_2_node): 140 | answers = [] 141 | for height in score_height_2_node: 142 | different_score_dict = score_height_2_node[height] 143 | if len(different_score_dict) > 1: 144 | min_length = float("inf") 145 | for score, list_ in different_score_dict.items(): 146 | min_length = min(min_length, len(list_)) 147 | for score, list_ in different_score_dict.items(): 148 | nodes = random.sample(list_, max(min_length, int(len(list_) * 0.1))) 149 | for node in nodes: 150 | if node.previous_answer: 151 | # origin_answer = node.previous_answer[:] 152 | # origin_length = len(node.previous_answer) 153 | 154 | node.previous_answer = replace_string(node.previous_answer) 155 | 156 | prompt = construct_PRM_HRM_prompt_v2(node.question, "".join(node.previous_answer[:-1]), 157 | node.previous_answer[-1]) 158 | if "# Step 1" not in prompt: 159 | continue 160 | 161 | # prompt = construct_policy_model_prompt_for_PRM_HRM(node.question, "".join(node.previous_answer)) 162 | # answers.append(json.dumps( 163 | # {"input": prompt, "label": score, "reasoning_answer": origin_answer, 164 | # "length": len(node.previous_answer), "origin_length": origin_length})) 165 | prompt = ensure_step_spacing(prompt) 166 | 167 | answers.append(json.dumps({"input": prompt, "label": score})) 168 | 169 | else: 170 | for score, list_ in different_score_dict.items(): 171 | k = min(5, len(list_)) 172 | nodes = random.sample(list_, k) 173 | for node in nodes: 174 | if node.previous_answer: 175 | # origin_answer = node.previous_answer[:] 176 | # origin_length = len(node.previous_answer) 177 | 178 | node.previous_answer = replace_string(node.previous_answer) 179 | 180 | prompt = construct_PRM_HRM_prompt_v2(node.question, "".join(node.previous_answer[:-1]), 181 | node.previous_answer[-1]) 182 | if "# Step 1" not in prompt: 183 | continue 184 | # answers.append(json.dumps( 185 | # {"input": prompt, "label": score, "reasoning_answer": origin_answer, 186 | # "length": len(node.previous_answer), "origin_length": origin_length})) 187 | prompt = ensure_step_spacing(prompt) 188 | 189 | answers.append(json.dumps({"input": prompt, "label": score})) 190 | return set(answers) 191 | 192 | 193 | if __name__ == '__main__': 194 | parser = argparse.ArgumentParser() 195 | parser.add_argument("--base_path", type=str, default="dataset/self_training") 196 | parser.add_argument("--save_path", type=str, default="dataset/self_training_v1_scoring") 197 | 198 | args = parser.parse_args() 199 | base_path = args.base_path 200 | save_path = args.save_path 201 | os.makedirs(save_path, exist_ok=True) 202 | 203 | paths = load_all_pickle_file_paths() 204 | cnt = 0 205 | scores = [] 206 | train_idx = int(0.8 * len(paths)) 207 | train_paths = paths[:train_idx] 208 | test_paths = paths[train_idx:] 209 | 210 | with open(os.path.join(save_path, "prm_train.jsonl"), 
'w') as f: 211 | for path in tqdm(train_paths, total=len(train_paths)): 212 | score_height_2_node = assign_score_for_every_nodes(path) 213 | answers = sort_node(score_height_2_node) 214 | for answer in answers: 215 | scores.append(json.loads(answer)['label']) 216 | # json_string = json.dumps(answer) 217 | f.write(answer + "\n") 218 | cnt += 1 219 | 220 | with open(os.path.join(save_path, "prm_test.jsonl"), 'w') as f: 221 | for path in tqdm(test_paths, total=len(test_paths)): 222 | score_height_2_node = assign_score_for_every_nodes(path) 223 | answers = sort_node(score_height_2_node) 224 | for answer in answers: 225 | scores.append(json.loads(answer)['label']) 226 | # json_string = json.dumps(answer) 227 | f.write(answer + "\n") 228 | cnt += 1 229 | print(cnt) 230 | plot_distribution(scores) 231 | # from collections import Counter 232 | # print(Counter(scores)) 233 | -------------------------------------------------------------------------------- /self-training/grading/grader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use OpenAI open source parser to mark the grade. 3 | 4 | Answer checker API that uses sympy to simplify expressions and check for equality. 5 | 6 | Call grade_answer(given_answer: str, ground_truth: str). 7 | """ 8 | import re 9 | import sympy 10 | from pylatexenc import latex2text 11 | from sympy.parsing import sympy_parser 12 | 13 | from sft_rw_manual_annotation.grading import math_normalize 14 | 15 | # sympy might hang -- we don't care about trying to be lenient in these cases 16 | BAD_SUBSTRINGS = ["^{", "^("] 17 | BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] 18 | TUPLE_CHARS = "()[]" 19 | 20 | 21 | def _sympy_parse(expr: str): 22 | """Parses an expression with sympy.""" 23 | py_expr = expr.replace("^", "**") 24 | return sympy_parser.parse_expr( 25 | py_expr, 26 | transformations=( 27 | sympy_parser.standard_transformations 28 | + (sympy_parser.implicit_multiplication_application,) 29 | ), 30 | ) 31 | 32 | 33 | def _parse_latex(expr: str) -> str: 34 | """Attempts to parse latex to an expression sympy can read.""" 35 | expr = expr.replace("\\tfrac", "\\frac") 36 | expr = expr.replace("\\dfrac", "\\frac") 37 | expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. 38 | expr = latex2text.LatexNodes2Text().latex_to_text(expr) 39 | 40 | # Replace the specific characters that this parser uses. 41 | expr = expr.replace("√", "sqrt") 42 | expr = expr.replace("π", "pi") 43 | expr = expr.replace("∞", "inf") 44 | expr = expr.replace("∪", "U") 45 | expr = expr.replace("·", "*") 46 | expr = expr.replace("×", "*") 47 | 48 | return expr.strip() 49 | 50 | 51 | def _is_float(num: str) -> bool: 52 | try: 53 | float(num) 54 | return True 55 | except ValueError: 56 | return False 57 | 58 | 59 | def _is_int(x: float) -> bool: 60 | try: 61 | return abs(x - int(round(x))) <= 1e-7 62 | except: 63 | return False 64 | 65 | 66 | def _is_frac(expr: str) -> bool: 67 | return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) 68 | 69 | 70 | def _str_is_int(x: str) -> bool: 71 | try: 72 | x = _strip_properly_formatted_commas(x) 73 | x = float(x) 74 | return abs(x - int(round(x))) <= 1e-7 75 | except: 76 | return False 77 | 78 | 79 | def _str_to_int(x: str) -> bool: 80 | x = x.replace(",", "") 81 | x = float(x) 82 | return int(x) 83 | 84 | 85 | def _inject_implicit_mixed_number(step: str): 86 | """ 87 | Automatically make a mixed number evalable 88 | e.g. 
7 3/4 => 7+3/4 89 | """ 90 | p1 = re.compile("([0-9]) +([0-9])") 91 | step = p1.sub("\\1+\\2", step) ## implicit mults 92 | return step 93 | 94 | 95 | def _strip_properly_formatted_commas(expr: str): 96 | # We want to be careful because we don't want to strip tuple commas 97 | p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") 98 | while True: 99 | next_expr = p1.sub("\\1\\3\\4", expr) 100 | if next_expr == expr: 101 | break 102 | expr = next_expr 103 | return next_expr 104 | 105 | 106 | def _normalize(expr: str) -> str: 107 | """Normalize answer expressions.""" 108 | if expr is None: 109 | return None 110 | 111 | # Remove enclosing `\text{}`. 112 | m = re.search("^\\\\text\{(?P.+?)\}$", expr) 113 | if m is not None: 114 | expr = m.group("text") 115 | 116 | expr = expr.replace("\\%", "%") 117 | expr = expr.replace("\\$", "$") 118 | expr = expr.replace("$", "") 119 | expr = expr.replace("%", "") 120 | expr = expr.replace(" or ", " , ") 121 | expr = expr.replace(" and ", " , ") 122 | 123 | expr = expr.replace("million", "*10^6") 124 | expr = expr.replace("billion", "*10^9") 125 | expr = expr.replace("trillion", "*10^12") 126 | 127 | for unit in [ 128 | "degree", 129 | "cm", 130 | "centimeter", 131 | "meter", 132 | "mile", 133 | "second", 134 | "minute", 135 | "hour", 136 | "day", 137 | "week", 138 | "month", 139 | "year", 140 | "foot", 141 | "feet", 142 | "inch", 143 | "yard", 144 | ]: 145 | expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) 146 | expr = re.sub(f"\^ *\\\\circ", "", expr) 147 | 148 | if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": 149 | expr = expr[1:-1] 150 | 151 | expr = re.sub(",\\\\! *", "", expr) 152 | if _is_float(expr) and _is_int(float(expr)): 153 | expr = str(int(round(float(expr)))) 154 | if "\\" in expr: 155 | try: 156 | expr = _parse_latex(expr) 157 | except: 158 | pass 159 | 160 | # edge case with mixed numbers and negative signs 161 | expr = re.sub("- *", "-", expr) 162 | 163 | expr = _inject_implicit_mixed_number(expr) 164 | expr = expr.replace(" ", "") 165 | 166 | # if we somehow still have latex braces here, just drop them 167 | expr = expr.replace("{", "") 168 | expr = expr.replace("}", "") 169 | 170 | # don't be case sensitive for text answers 171 | expr = expr.lower() 172 | 173 | if _str_is_int(expr): 174 | expr = str(_str_to_int(expr)) 175 | 176 | return expr 177 | 178 | 179 | def count_unknown_letters_in_expr(expr: str): 180 | expr = expr.replace("sqrt", "") 181 | expr = expr.replace("frac", "") 182 | letters_in_expr = set([x for x in expr if x.isalpha()]) 183 | return len(letters_in_expr) 184 | 185 | 186 | def should_allow_eval(expr: str): 187 | # we don't want to try parsing unknown text or functions of more than two variables 188 | if count_unknown_letters_in_expr(expr) > 2: 189 | return False 190 | 191 | for bad_string in BAD_SUBSTRINGS: 192 | if bad_string in expr: 193 | return False 194 | 195 | for bad_regex in BAD_REGEXES: 196 | if re.search(bad_regex, expr) is not None: 197 | return False 198 | 199 | return True 200 | 201 | 202 | def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): 203 | are_equal = False 204 | try: 205 | expr = f"({ground_truth_normalized})-({given_normalized})" 206 | if should_allow_eval(expr): 207 | sympy_diff = _sympy_parse(expr) 208 | simplified = sympy.simplify(sympy_diff) 209 | if simplified == 0: 210 | are_equal = True 211 | except: 212 | pass 213 | return are_equal 214 | 215 | 216 | def split_tuple(expr: str): 217 | """ 218 | Split the elements in a tuple/interval, while handling 
well-formatted commas in large numbers 219 | """ 220 | expr = _strip_properly_formatted_commas(expr) 221 | if len(expr) == 0: 222 | return [] 223 | if ( 224 | len(expr) > 2 225 | and expr[0] in TUPLE_CHARS 226 | and expr[-1] in TUPLE_CHARS 227 | and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) 228 | ): 229 | elems = [elem.strip() for elem in expr[1:-1].split(",")] 230 | else: 231 | elems = [expr] 232 | return elems 233 | 234 | 235 | def grade_answer(given_answer: str, ground_truth: str) -> bool: 236 | """ 237 | The answer will be considered correct if: 238 | (a) it normalizes to the same string as the ground truth answer 239 | OR 240 | (b) sympy can simplify the difference between the expressions to 0 241 | """ 242 | if given_answer is None: 243 | return False 244 | 245 | ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth) 246 | given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer) 247 | 248 | # be at least as lenient as mathd 249 | if ground_truth_normalized_mathd == given_answer_normalized_mathd: 250 | return True 251 | 252 | ground_truth_normalized = _normalize(ground_truth) 253 | given_normalized = _normalize(given_answer) 254 | 255 | if ground_truth_normalized is None: 256 | return False 257 | 258 | if ground_truth_normalized == given_normalized: 259 | return True 260 | 261 | if len(given_normalized) == 0: 262 | return False 263 | 264 | ground_truth_elems = split_tuple(ground_truth_normalized) 265 | given_elems = split_tuple(given_normalized) 266 | 267 | if len(ground_truth_elems) > 1 and ( 268 | ground_truth_normalized[0] != given_normalized[0] 269 | or ground_truth_normalized[-1] != given_normalized[-1] 270 | ): 271 | is_correct = False 272 | elif len(ground_truth_elems) != len(given_elems): 273 | is_correct = False 274 | else: 275 | for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems): 276 | if _is_frac(ground_truth_elem) and _is_frac(given_elem): 277 | # if fractions aren't reduced, then shouldn't be marked as correct 278 | # so, we don't want to allow sympy.simplify in this case 279 | is_correct = ground_truth_elem == given_elem 280 | elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): 281 | # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) 282 | is_correct = False 283 | else: 284 | is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) 285 | if not is_correct: 286 | break 287 | 288 | return is_correct 289 | 290 | -------------------------------------------------------------------------------- /sft_rw_manual_annotation/grading/grader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use OpenAI open source parser to mark the grade. 3 | 4 | Answer checker API that uses sympy to simplify expressions and check for equality. 5 | 6 | Call grade_answer(given_answer: str, ground_truth: str). 
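The check is deliberately lenient: both strings are first normalized with
math_normalize.normalize_answer and compared; if that fails they are normalized again
with _normalize below (LaTeX fractions, units, percent/dollar signs, degree marks and
well-formatted commas are stripped) and compared as strings; only then is sympy asked
whether their difference simplifies to 0. Tuples/intervals are compared element-wise,
and two cases are rejected without calling sympy: fraction pairs that are not literally
identical (e.g. 2/4 vs 1/2) and elements that disagree on whether they are integers.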
7 | """ 8 | import re 9 | import sympy 10 | from pylatexenc import latex2text 11 | from sympy.parsing import sympy_parser 12 | 13 | from sft_rw_manual_annotation.grading import math_normalize 14 | 15 | # sympy might hang -- we don't care about trying to be lenient in these cases 16 | BAD_SUBSTRINGS = ["^{", "^("] 17 | BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] 18 | TUPLE_CHARS = "()[]" 19 | 20 | 21 | def _sympy_parse(expr: str): 22 | """Parses an expression with sympy.""" 23 | py_expr = expr.replace("^", "**") 24 | return sympy_parser.parse_expr( 25 | py_expr, 26 | transformations=( 27 | sympy_parser.standard_transformations 28 | + (sympy_parser.implicit_multiplication_application,) 29 | ), 30 | ) 31 | 32 | 33 | def _parse_latex(expr: str) -> str: 34 | """Attempts to parse latex to an expression sympy can read.""" 35 | expr = expr.replace("\\tfrac", "\\frac") 36 | expr = expr.replace("\\dfrac", "\\frac") 37 | expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. 38 | expr = latex2text.LatexNodes2Text().latex_to_text(expr) 39 | 40 | # Replace the specific characters that this parser uses. 41 | expr = expr.replace("√", "sqrt") 42 | expr = expr.replace("π", "pi") 43 | expr = expr.replace("∞", "inf") 44 | expr = expr.replace("∪", "U") 45 | expr = expr.replace("·", "*") 46 | expr = expr.replace("×", "*") 47 | 48 | return expr.strip() 49 | 50 | 51 | def _is_float(num: str) -> bool: 52 | try: 53 | float(num) 54 | return True 55 | except ValueError: 56 | return False 57 | 58 | 59 | def _is_int(x: float) -> bool: 60 | try: 61 | return abs(x - int(round(x))) <= 1e-7 62 | except: 63 | return False 64 | 65 | 66 | def _is_frac(expr: str) -> bool: 67 | return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) 68 | 69 | 70 | def _str_is_int(x: str) -> bool: 71 | try: 72 | x = _strip_properly_formatted_commas(x) 73 | x = float(x) 74 | return abs(x - int(round(x))) <= 1e-7 75 | except: 76 | return False 77 | 78 | 79 | def _str_to_int(x: str) -> bool: 80 | x = x.replace(",", "") 81 | x = float(x) 82 | return int(x) 83 | 84 | 85 | def _inject_implicit_mixed_number(step: str): 86 | """ 87 | Automatically make a mixed number evalable 88 | e.g. 7 3/4 => 7+3/4 89 | """ 90 | p1 = re.compile("([0-9]) +([0-9])") 91 | step = p1.sub("\\1+\\2", step) ## implicit mults 92 | return step 93 | 94 | 95 | def _strip_properly_formatted_commas(expr: str): 96 | # We want to be careful because we don't want to strip tuple commas 97 | p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") 98 | while True: 99 | next_expr = p1.sub("\\1\\3\\4", expr) 100 | if next_expr == expr: 101 | break 102 | expr = next_expr 103 | return next_expr 104 | 105 | 106 | def _normalize(expr: str) -> str: 107 | """Normalize answer expressions.""" 108 | if expr is None: 109 | return None 110 | 111 | # Remove enclosing `\text{}`. 
112 | m = re.search("^\\\\text\{(?P.+?)\}$", expr) 113 | if m is not None: 114 | expr = m.group("text") 115 | 116 | expr = expr.replace("\\%", "%") 117 | expr = expr.replace("\\$", "$") 118 | expr = expr.replace("$", "") 119 | expr = expr.replace("%", "") 120 | expr = expr.replace(" or ", " , ") 121 | expr = expr.replace(" and ", " , ") 122 | 123 | expr = expr.replace("million", "*10^6") 124 | expr = expr.replace("billion", "*10^9") 125 | expr = expr.replace("trillion", "*10^12") 126 | 127 | for unit in [ 128 | "degree", 129 | "cm", 130 | "centimeter", 131 | "meter", 132 | "mile", 133 | "second", 134 | "minute", 135 | "hour", 136 | "day", 137 | "week", 138 | "month", 139 | "year", 140 | "foot", 141 | "feet", 142 | "inch", 143 | "yard", 144 | ]: 145 | expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) 146 | expr = re.sub(f"\^ *\\\\circ", "", expr) 147 | 148 | if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": 149 | expr = expr[1:-1] 150 | 151 | expr = re.sub(",\\\\! *", "", expr) 152 | if _is_float(expr) and _is_int(float(expr)): 153 | expr = str(int(round(float(expr)))) 154 | if "\\" in expr: 155 | try: 156 | expr = _parse_latex(expr) 157 | except: 158 | pass 159 | 160 | # edge case with mixed numbers and negative signs 161 | expr = re.sub("- *", "-", expr) 162 | 163 | expr = _inject_implicit_mixed_number(expr) 164 | expr = expr.replace(" ", "") 165 | 166 | # if we somehow still have latex braces here, just drop them 167 | expr = expr.replace("{", "") 168 | expr = expr.replace("}", "") 169 | 170 | # don't be case sensitive for text answers 171 | expr = expr.lower() 172 | 173 | if _str_is_int(expr): 174 | expr = str(_str_to_int(expr)) 175 | 176 | return expr 177 | 178 | 179 | def count_unknown_letters_in_expr(expr: str): 180 | expr = expr.replace("sqrt", "") 181 | expr = expr.replace("frac", "") 182 | letters_in_expr = set([x for x in expr if x.isalpha()]) 183 | return len(letters_in_expr) 184 | 185 | 186 | def should_allow_eval(expr: str): 187 | # we don't want to try parsing unknown text or functions of more than two variables 188 | if count_unknown_letters_in_expr(expr) > 2: 189 | return False 190 | 191 | for bad_string in BAD_SUBSTRINGS: 192 | if bad_string in expr: 193 | return False 194 | 195 | for bad_regex in BAD_REGEXES: 196 | if re.search(bad_regex, expr) is not None: 197 | return False 198 | 199 | return True 200 | 201 | 202 | def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): 203 | are_equal = False 204 | try: 205 | expr = f"({ground_truth_normalized})-({given_normalized})" 206 | if should_allow_eval(expr): 207 | sympy_diff = _sympy_parse(expr) 208 | simplified = sympy.simplify(sympy_diff) 209 | if simplified == 0: 210 | are_equal = True 211 | except: 212 | pass 213 | return are_equal 214 | 215 | 216 | def split_tuple(expr: str): 217 | """ 218 | Split the elements in a tuple/interval, while handling well-formatted commas in large numbers 219 | """ 220 | expr = _strip_properly_formatted_commas(expr) 221 | if len(expr) == 0: 222 | return [] 223 | if ( 224 | len(expr) > 2 225 | and expr[0] in TUPLE_CHARS 226 | and expr[-1] in TUPLE_CHARS 227 | and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) 228 | ): 229 | elems = [elem.strip() for elem in expr[1:-1].split(",")] 230 | else: 231 | elems = [expr] 232 | return elems 233 | 234 | 235 | def grade_answer(given_answer: str, ground_truth: str) -> bool: 236 | """ 237 | The answer will be considered correct if: 238 | (a) it normalizes to the same string as the ground truth answer 239 | OR 240 | 
(b) sympy can simplify the difference between the expressions to 0 241 | """ 242 | if given_answer is None: 243 | return False 244 | 245 | ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth) 246 | given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer) 247 | 248 | # be at least as lenient as mathd 249 | if ground_truth_normalized_mathd == given_answer_normalized_mathd: 250 | return True 251 | 252 | ground_truth_normalized = _normalize(ground_truth) 253 | given_normalized = _normalize(given_answer) 254 | 255 | if ground_truth_normalized is None: 256 | return False 257 | 258 | if ground_truth_normalized == given_normalized: 259 | return True 260 | 261 | if len(given_normalized) == 0: 262 | return False 263 | 264 | ground_truth_elems = split_tuple(ground_truth_normalized) 265 | given_elems = split_tuple(given_normalized) 266 | 267 | if len(ground_truth_elems) > 1 and ( 268 | ground_truth_normalized[0] != given_normalized[0] 269 | or ground_truth_normalized[-1] != given_normalized[-1] 270 | ): 271 | is_correct = False 272 | elif len(ground_truth_elems) != len(given_elems): 273 | is_correct = False 274 | else: 275 | for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems): 276 | if _is_frac(ground_truth_elem) and _is_frac(given_elem): 277 | # if fractions aren't reduced, then shouldn't be marked as correct 278 | # so, we don't want to allow sympy.simplify in this case 279 | is_correct = ground_truth_elem == given_elem 280 | elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): 281 | # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) 282 | is_correct = False 283 | else: 284 | is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) 285 | if not is_correct: 286 | break 287 | 288 | return is_correct 289 | 290 | -------------------------------------------------------------------------------- /self-training/best_of_n_gsm8k.py: -------------------------------------------------------------------------------- 1 | """ 2 | PRM and HRM are trained from auto-labeled reasoning process in the PRM800K dataset. And evaluation in gsm8k dataset. 
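Flow of prm_hrm_best_of_n below: at each step the policy model (queried through
llm_query.parallel_query, falling back to sequential_query) proposes N candidate next
steps, calculate_PRM_HRM_scores ranks them with the reward model, the top-scoring step
is appended to the partial solution, and generation stops once a step contains
"# Answer" or MAX_HEIGHT steps have been produced; the extracted answer is then checked
against the ground truth with grading.grader.grade_answer and per-question results are
appended under stats/{task}_{version}_{model_name}/.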
3 | """ 4 | import argparse 5 | import json 6 | import os 7 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 8 | import torch 9 | from llm_query import sequential_query, parallel_query 10 | from construct_prompt import construct_policy_model_prompt_for_PRM_HRM, construct_PRM_HRM_prompt_v2 11 | from grading.grader import grade_answer 12 | from tqdm import tqdm 13 | import gc 14 | 15 | MAX_HEIGHT = 8 16 | 17 | 18 | def load_rw_model(model_path): 19 | model = AutoModelForSequenceClassification.from_pretrained(model_path, 20 | num_labels=1, 21 | attn_implementation="flash_attention_2", 22 | torch_dtype=torch.float16, 23 | ) 24 | tokenizer = AutoTokenizer.from_pretrained(model_path) 25 | 26 | model.config.pad_token_id = model.config.eos_token_id 27 | device = torch.device("cuda") 28 | model.to(device) 29 | return model, tokenizer 30 | 31 | 32 | def calculate_PRM_HRM_scores(rm_model, tokenizer, question, N, current_steps, previous_steps="", N_batch=4): 33 | assert len(current_steps) == N, "The length of current steps must be equal to N" 34 | 35 | device = "cuda" 36 | current_step_score_pairs = [] 37 | 38 | for i in range(0, N, N_batch): 39 | batch_steps = current_steps[i:i + N_batch] 40 | prompts = [construct_PRM_HRM_prompt_v2(question, previous_steps, current_step) for current_step in batch_steps] 41 | 42 | inputs = tokenizer(prompts, return_tensors="pt", max_length=4096, padding=True, truncation=True) 43 | inputs = {key: value.to(device) for key, value in inputs.items()} 44 | 45 | with torch.no_grad(): 46 | outputs = rm_model(**inputs) 47 | logits = outputs.logits 48 | 49 | positive_scores = logits[:, 0].tolist() 50 | 51 | batch_step_score_pairs = [(score, step) for score, step in zip(positive_scores, batch_steps)] 52 | current_step_score_pairs.extend(batch_step_score_pairs) 53 | 54 | current_step_score_pairs.sort(reverse=True) 55 | return current_step_score_pairs 56 | 57 | 58 | def load_test_data(data_path='dataset/gsm-8k/evaluation.jsonl'): 59 | question_answer_pairs = [] 60 | with open(data_path) as f: 61 | for line in f.readlines(): 62 | dic = json.loads(line) 63 | question = dic['question'] 64 | ground_truth = dic['answer'] 65 | question_answer_pairs.append((question, ground_truth)) 66 | 67 | return question_answer_pairs 68 | 69 | 70 | def extract_answer(text: str, placeholder="# Answer", end_placeholder='# END!'): 71 | text = text.lower() 72 | left_idx = text.rindex(placeholder.lower()) 73 | length = len(placeholder) 74 | try: 75 | right_idx = text.rindex(end_placeholder.lower()) 76 | except: 77 | right_idx = -1 78 | return text[left_idx + length:right_idx].strip() 79 | 80 | 81 | def prm_hrm_best_of_n(model_path, host, port, model_name, api_key="", N=2, 82 | test_data_path='dataset/gsm-8k/evaluation.jsonl', task="prm", 83 | version='v1'): 84 | def clear_memory(): 85 | gc.collect() 86 | torch.cuda.empty_cache() 87 | 88 | def whether_contain_answer(text): 89 | return "# Answer" in text 90 | 91 | base_path = f"stats/{task}_{version}_{model_name}/" 92 | os.makedirs(base_path, exist_ok=True) 93 | 94 | saving_filename = os.path.join(base_path, f"{version}_{task}_{N}.txt") 95 | 96 | test_data_pairs = load_test_data(test_data_path) 97 | correct_cnt = 0 98 | total_cnt = len(test_data_pairs) 99 | 100 | final_answers = [] 101 | answer_situations = [] 102 | ground_truths = [] 103 | 104 | prm_model, prm_tokenizer = load_rw_model(model_path) 105 | 106 | for question, ground_truth in tqdm(test_data_pairs, total=len(test_data_pairs)): 107 | clear_memory() 108 | 109 | 
previous_steps = "" 110 | ground_truths.append(ground_truth) 111 | found_answer = False 112 | for height in range(MAX_HEIGHT): 113 | candidates = [] 114 | policy_model_prompt = construct_policy_model_prompt_for_PRM_HRM(question, previous_steps) 115 | print(f"---------\nheight{height}: policy model prompt") 116 | print(policy_model_prompt) 117 | print(f"---------\nPrompt finish!") 118 | try: 119 | intermediate_steps = parallel_query(host, port, model_name, policy_model_prompt, api_key, n=N) 120 | except Exception as e: 121 | intermediate_steps = [] 122 | for _ in range(N): 123 | try: 124 | intermediate_step = sequential_query(host, port, model_name, policy_model_prompt, api_key) 125 | except Exception as ee: 126 | print(ee) 127 | continue 128 | intermediate_steps.append(intermediate_step) 129 | candidates.extend(intermediate_steps) 130 | 131 | if not candidates: 132 | break 133 | 134 | sorted_current_step_score_pairs = calculate_PRM_HRM_scores(prm_model, prm_tokenizer, question, N, 135 | candidates, 136 | previous_steps) 137 | _, current_step = sorted_current_step_score_pairs[0] 138 | print(f"Current step: {current_step}") 139 | previous_steps = previous_steps + current_step 140 | # print(f"Updated previous steps height{height}: ", previous_steps) 141 | 142 | print("---------") 143 | 144 | if whether_contain_answer(current_step): 145 | found_answer = True 146 | break 147 | print("-------最终的推理过程----------") 148 | print(previous_steps) 149 | if found_answer: 150 | answer = extract_answer(previous_steps) 151 | else: 152 | answer = "cannot find a correct answer." 153 | final_answers.append(answer) 154 | true_or_false = grade_answer(answer, ground_truth) 155 | answer_situations.append(true_or_false) 156 | if true_or_false: 157 | correct_cnt += 1 158 | print(f"find answer in {question}") 159 | 160 | with open(saving_filename, "a") as f: 161 | f.write(f"{answer}\t{ground_truth}\t{true_or_false}\n") 162 | 163 | print(f"N = {N}, the correct rate is {correct_cnt / total_cnt}") 164 | 165 | 166 | if __name__ == '__main__': 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument("--host", type=str, default="127.0.0.1") 169 | parser.add_argument("--port", type=str, default="10086") 170 | parser.add_argument("--task", type=str, default="prm") 171 | parser.add_argument("--model_name", type=str, default="policy_model") 172 | parser.add_argument("--api_key", type=str, default="xxxx") 173 | parser.add_argument("--N", type=int, default=2) 174 | parser.add_argument("--test_data_path", type=str, 175 | default='dataset/gsm-8k/evaluation.jsonl') 176 | parser.add_argument("--version", type=str, default="v1") 177 | 178 | args = parser.parse_args() 179 | 180 | task = args.task 181 | host = args.host 182 | port = args.port 183 | model_name = args.model_name 184 | api_key = args.api_key 185 | version = args.version 186 | 187 | N = int(args.N) 188 | test_data_path = args.test_data_path 189 | model_path = f"sf_best_model/{task}/{version}" 190 | 191 | if task == "prm": 192 | prm_hrm_best_of_n(model_path, host, port, model_name, api_key, N, test_data_path, task, version) 193 | elif task == "hrm": 194 | prm_hrm_best_of_n(model_path, host, port, model_name, api_key, N, test_data_path, task, version) 195 | 196 | 197 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 1 --task hrm --model_name qwen7b > logs/gsm8k_hrm_1_v1_qwen7b.log & 198 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 2 --task hrm --model_name qwen7b > 
logs/gsm8k_hrm_2_v1_qwen7b.log & 199 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 4 --task hrm --model_name qwen7b > logs/gsm8k_hrm_4_v1_qwen7b.log & 200 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 8 --task hrm --model_name qwen7b > logs/gsm8k_hrm_8_v1_qwen7b.log & 201 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 16 --task hrm --model_name qwen7b > logs/gsm8k_hrm_16_v1_qwen7b.log & 202 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 24 --task hrm --model_name qwen7b > logs/gsm8k_hrm_24_v1_qwen7b.log & 203 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 32 --task hrm --model_name qwen7b > logs/gsm8k_hrm_32_v1_qwen7b.log & 204 | # CUDA_VISIBLE_DEVICES=0 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 64 --port 10087 --task hrm --model_name qwen7b_math_instruct_v2 > logs/gsm8k_hrm_64_v1_qwen7b.log & 205 | # CUDA_VISIBLE_DEVICES=1 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 128 --port 10087 --task hrm --model_name qwen7b_math_instruct_v2 > logs/gsm8k_hrm_128_v1_qwen7b.log & 206 | # CUDA_VISIBLE_DEVICES=2 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 256 --port 10087 --task hrm --model_name qwen7b_math_instruct_v2 > logs/gsm8k_hrm_256_v1_qwen7b.log & 207 | # CUDA_VISIBLE_DEVICES=2 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 512 --port 10086 --task hrm --model_name qwen7b > logs/gsm8k_hrm_512_v1_qwen7b.log & 208 | 209 | 210 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 1 --task prm --model_name qwen7b > logs/gsm8k_prm_1_v1_qwen7b.log & 211 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 2 --task prm --model_name qwen7b > logs/gsm8k_prm_2_v1_qwen7b.log & 212 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 4 --task prm --model_name qwen7b > logs/gsm8k_prm_4_v1_qwen7b.log & 213 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 8 --task prm --model_name qwen7b > logs/gsm8k_prm_8_v1_qwen7b.log & 214 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 16 --task prm --model_name qwen7b > logs/gsm8k_prm_16_v1_qwen7b.log & 215 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 24 --task prm --model_name qwen7b > logs/gsm8k_prm_24_v1_qwen7b.log & 216 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 32 --task prm --model_name qwen7b > logs/gsm8k_prm_32_v1_qwen7b.log & 217 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 64 --task prm --model_name qwen7b > logs/gsm8k_prm_64_v1_qwen7b.log & 218 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 128 --task prm --model_name qwen7b > logs/gsm8k_prm_128_v1_qwen7b.log & 219 | # CUDA_VISIBLE_DEVICES=6 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 256 --task prm --model_name qwen7b > logs/gsm8k_prm_256_v1_qwen7b.log & 220 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_gsm8k.py --version v1 --N 512 --port 10086 --task prm --model_name qwen7b > logs/gsm8k_prm_512_v1_qwen7b.log & 221 | 222 | 223 | -------------------------------------------------------------------------------- 
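Editorial note, not a repository file: a minimal sketch of the greedy step-selection loop that best_of_n_gsm8k.py implements above. At each height the policy model proposes N candidate next steps, the PRM/HRM scores each (question, previous steps, candidate) prompt, and the top-scoring candidate is appended until a step contains the "# Answer" marker. `propose_steps` and `score_step` are hypothetical stand-ins for the script's policy query (parallel_query/sequential_query) and reward-model scoring (calculate_PRM_HRM_scores).

def greedy_best_of_n(question, propose_steps, score_step, n=4, max_height=8):
    # propose_steps(question, previous_steps, n) -> list of n candidate next steps
    # score_step(question, previous_steps, step) -> scalar reward-model score
    previous_steps = ""
    for _ in range(max_height):
        candidates = propose_steps(question, previous_steps, n)
        if not candidates:
            break
        best = max(candidates, key=lambda step: score_step(question, previous_steps, step))
        previous_steps += best
        if "# Answer" in best:  # same stopping criterion as the script
            break
    return previous_steps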
/self-training/construct_hrm_train_data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | 4 | from tree import Node 5 | from envs import * 6 | import matplotlib.pyplot as plt 7 | from collections import defaultdict 8 | from construct_prompt import construct_PRM_HRM_prompt_v2 9 | import argparse 10 | from tqdm import tqdm 11 | import json 12 | import random 13 | import re 14 | 15 | 16 | def ensure_step_spacing(text): 17 | """ 18 | Ensure that '# Step X' is followed by a space or newline; if not, add a space. 19 | """ 20 | pattern = r"(# Step \d)(?=[^\s])" # match "# Step X" that is not followed by a space or newline 21 | corrected_text = re.sub(pattern, r"\1 " + "\n", text) # add a space and a newline after the match 22 | return corrected_text 23 | 24 | 25 | # Test case 26 | test_text = """# Step 1do something 27 | # Step 2 28 | # Step 3next step 29 | # Step 4final step""" 30 | 31 | corrected_text = ensure_step_spacing(test_text) 32 | print(corrected_text) 33 | 34 | 35 | def replace_string(a_list): 36 | height_2_placeholder = {1: '# Step 2', 2: "# Step 3", 3: "# Step 4", 4: "# Step 5", 5: '# END!'} 37 | height = len(a_list) 38 | # if height == 6: 39 | # return a_list 40 | if height != 5: 41 | a_list[-1] = a_list[-1].replace(height_2_placeholder[height], "\n").replace(height_2_placeholder[height + 1], 42 | "\n") 43 | if height > 2: 44 | a_list[-3] = a_list[-3].replace(height_2_placeholder[height - 2], "\n") 45 | a_list[-2] = height_2_placeholder[height - 2] + "\n" + a_list[-2] 46 | if height > 1: 47 | a_list[-2] = a_list[-2].replace(height_2_placeholder[height - 1], "\n") 48 | a_list[-1] = height_2_placeholder[height - 1] + "\n" + a_list[- 1] 49 | 50 | return a_list 51 | 52 | 53 | def load_all_pickle_file_paths(base_path="dataset/self_training"): 54 | answers = [] 55 | for filename in os.listdir(base_path): 56 | answers.append(os.path.join(base_path, filename)) 57 | return answers 58 | 59 | 60 | def is_leaf_node(node: Node): 61 | if not node.children: 62 | return True 63 | return False 64 | 65 | 66 | def split_node(root: Node): 67 | queue = [root] 68 | while queue: 69 | temp_length = len(queue) 70 | for _ in range(temp_length): 71 | node = queue.pop(0) 72 | 73 | children = [] 74 | have_changed_one_child = False 75 | for idx, child in enumerate(node.children): 76 | 77 | child.visited = False 78 | if child.height >= 2 and child.parent.visited is False and have_changed_one_child is False and len( 79 | node.children) >= 4: 80 | if random.random() < 0.5: 81 | child.visited = True 82 | 83 | have_changed_one_child = True 84 | 85 | previous_answers = child.previous_answer[:] 86 | assert len(previous_answers) >= 2 87 | last_one = previous_answers.pop() 88 | last_two = previous_answers.pop() 89 | previous_answers.append(last_two + last_one) 90 | 91 | child.height = child.height - 1 92 | child.previous_answer = previous_answers 93 | 94 | child.parent = node.parent 95 | node.parent.children.append(child) 96 | if child.visited is False: 97 | children.append(child) 98 | 99 | queue.append(child) 100 | node.children = children 101 | return root 102 | 103 | 104 | def assign_score_for_every_nodes(path, split=False): 105 | with open(path, 'rb') as f: 106 | data = pickle.load(f) 107 | if split: 108 | data = split_node(data) 109 | LEAF_NODE = 0 110 | 111 | def dfs(node: Node): 112 | if node.score == -1: 113 | node.score = 0 114 | 115 | if is_leaf_node(node): 116 | nonlocal LEAF_NODE 117 | LEAF_NODE += 1 118 | if node.is_correct: 119 | node.score = 1 120 | return 1 121 | else: 122 | node.score = 0 123 | return 0 124 | else: 125 | for child in node.children: 
126 | cnt = dfs(child) 127 | node.score += cnt 128 | # if node.score!=0: 129 | # print(node.score) 130 | return node.score 131 | 132 | def dfs_v2(node: Node): 133 | if hasattr(node, 'total'): 134 | pass 135 | else: 136 | node.total = 0 137 | 138 | if is_leaf_node(node): 139 | node.total = 1 140 | return 1 141 | else: 142 | for child in node.children: 143 | cnt = dfs_v2(child) 144 | node.total += cnt 145 | return node.total 146 | 147 | dfs(data) 148 | dfs_v2(data) 149 | queue = [data] 150 | CORRECT_TOTAL_NUMBERS = 0 151 | TOTAL_NUMBERS = 0 152 | 153 | score_height_2_node = defaultdict(lambda: defaultdict(list)) 154 | while queue: 155 | temp_length = len(queue) 156 | for _ in range(temp_length): 157 | node = queue.pop(0) 158 | node.score = node.score / node.total 159 | if node.height != 0: 160 | score_height_2_node[node.height][node.score].append(node) 161 | 162 | if node.score == 1: 163 | CORRECT_TOTAL_NUMBERS += 1 164 | TOTAL_NUMBERS += 1 165 | for child in node.children: 166 | queue.append(child) 167 | # print(score_height_2_node.keys()) 168 | 169 | data.score = 1 170 | # print(LEAF_NODE) 171 | # print(TOTAL_NUMBERS) 172 | # 173 | # print(CORRECT_TOTAL_NUMBERS) 174 | # from collections import Counter 175 | # print(Counter(scores)) 176 | # plot_distribution(scores) 177 | return score_height_2_node 178 | 179 | 180 | def plot_distribution(data): 181 | plt.hist(data, bins=10, edgecolor='black') # Adjust the bins as needed 182 | 183 | plt.xlabel('Value') 184 | plt.ylabel('Frequency') 185 | plt.title('Distribution of Numbers') 186 | 187 | plt.savefig("score_distribution.png") 188 | 189 | 190 | def sort_node(score_height_2_node, merge_without_splitting=True, split=False): 191 | answers = [] 192 | for height in score_height_2_node: 193 | different_score_dict = score_height_2_node[height] 194 | if len(different_score_dict) > 1: 195 | min_length = float("inf") 196 | for score, list_ in different_score_dict.items(): 197 | min_length = min(min_length, len(list_)) 198 | for score, list_ in different_score_dict.items(): 199 | nodes = random.sample(list_, max(min_length, int(len(list_) * 0.1))) 200 | for node in nodes: 201 | if node.previous_answer: 202 | node.previous_answer = replace_string(node.previous_answer) 203 | 204 | prompt = construct_PRM_HRM_prompt_v2(node.question, "".join(node.previous_answer[:-1]), 205 | node.previous_answer[-1]) 206 | if "# Step 1" not in prompt: 207 | continue 208 | prompt = ensure_step_spacing(prompt) 209 | answers.append(json.dumps({"input": prompt, "label": score})) 210 | 211 | if len(node.previous_answer) >= 2 and merge_without_splitting: 212 | prompt = construct_PRM_HRM_prompt_v2(node.question, "".join(node.previous_answer[:-2]), 213 | node.previous_answer[-2] + node.previous_answer[-1]) 214 | prompt = ensure_step_spacing(prompt) 215 | 216 | answers.append(json.dumps({"input": prompt, "label": score})) 217 | 218 | 219 | else: 220 | for score, list_ in different_score_dict.items(): 221 | k = min(5, len(list_)) 222 | nodes = random.sample(list_, k) 223 | for node in nodes: 224 | if node.previous_answer: 225 | 226 | node.previous_answer = replace_string(node.previous_answer) 227 | 228 | prompt = construct_PRM_HRM_prompt_v2(node.question, "".join(node.previous_answer[:-1]), 229 | node.previous_answer[-1]) 230 | 231 | if "# Step 1" not in prompt: 232 | continue 233 | prompt = ensure_step_spacing(prompt) 234 | 235 | answers.append(json.dumps({"input": prompt, "label": score})) 236 | 237 | if len(node.previous_answer) >= 2 and merge_without_splitting: 238 | prompt = 
construct_PRM_HRM_prompt_v2(node.question, "".join(node.previous_answer[:-2]), 239 | node.previous_answer[-2] + node.previous_answer[-1]) 240 | prompt = ensure_step_spacing(prompt) 241 | 242 | answers.append(json.dumps({"input": prompt, "label": score})) 243 | 244 | return set(answers) 245 | 246 | 247 | if __name__ == '__main__': 248 | parser = argparse.ArgumentParser() 249 | parser.add_argument("--base_path", type=str, default="dataset/self_training") 250 | parser.add_argument("--save_path", type=str, default="dataset/self_training_v1_scoring") 251 | 252 | args = parser.parse_args() 253 | base_path = args.base_path 254 | save_path = args.save_path 255 | os.makedirs(save_path, exist_ok=True) 256 | 257 | paths = load_all_pickle_file_paths() 258 | cnt = 0 259 | scores = [] 260 | train_idx = int(0.8 * len(paths)) 261 | train_paths = paths[:train_idx] 262 | test_paths = paths[train_idx:] 263 | 264 | with open(os.path.join(save_path, "hrm_train.jsonl"), 'w') as f: 265 | for path in tqdm(train_paths, total=len(train_paths)): 266 | score_height_2_node = assign_score_for_every_nodes(path, split=False) 267 | answers = sort_node(score_height_2_node, merge_without_splitting=True) 268 | # for answer in answers: 269 | # scores.append(json.loads(answer)['label']) 270 | # f.write(answer + "\n") 271 | # cnt += 1 272 | 273 | score_height_2_node = assign_score_for_every_nodes(path, split=True) 274 | temp_answers = sort_node(score_height_2_node, merge_without_splitting=True) 275 | 276 | answers = answers.union(temp_answers) 277 | for answer in answers: 278 | scores.append(json.loads(answer)['label']) 279 | f.write(answer + "\n") 280 | cnt += 1 281 | 282 | with open(os.path.join(save_path, "hrm_test.jsonl"), 'w') as f: 283 | for path in tqdm(test_paths, total=len(test_paths)): 284 | score_height_2_node = assign_score_for_every_nodes(path, split=False) 285 | answers = sort_node(score_height_2_node, merge_without_splitting=True, split=False) 286 | # for answer in answers: 287 | # scores.append(json.loads(answer)['label']) 288 | # f.write(answer + "\n") 289 | # cnt += 1 290 | 291 | score_height_2_node = assign_score_for_every_nodes(path, split=True) 292 | temp_answers = sort_node(score_height_2_node, merge_without_splitting=False, split=True) 293 | 294 | answers = answers.union(temp_answers) 295 | for answer in answers: 296 | scores.append(json.loads(answer)['label']) 297 | f.write(answer + "\n") 298 | cnt += 1 299 | 300 | print(cnt) 301 | plot_distribution(scores) 302 | # from collections import Counter 303 | # print(Counter(scores)) 304 | -------------------------------------------------------------------------------- /self-training/best_of_n_math500.py: -------------------------------------------------------------------------------- 1 | """ 2 | PRM and HRM are trained on auto-labeled reasoning processes from the PRM800K dataset and evaluated on the MATH500 dataset.
3 | """ 4 | import argparse 5 | import json 6 | import os 7 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 8 | import torch 9 | from llm_query import sequential_query, parallel_query 10 | from construct_prompt import construct_policy_model_prompt_for_PRM_HRM, construct_PRM_HRM_prompt_v2 11 | from grading.grader import grade_answer 12 | from tqdm import tqdm 13 | import gc 14 | 15 | MAX_HEIGHT = 8 16 | 17 | 18 | def load_rw_model(model_path): 19 | model = AutoModelForSequenceClassification.from_pretrained(model_path, 20 | num_labels=1, 21 | attn_implementation="flash_attention_2", 22 | torch_dtype=torch.float16, 23 | ) 24 | tokenizer = AutoTokenizer.from_pretrained(model_path) 25 | 26 | model.config.pad_token_id = model.config.eos_token_id 27 | device = torch.device("cuda") 28 | model.to(device) 29 | return model, tokenizer 30 | 31 | 32 | def load_test_data(data_path='dataset/math500/test.jsonl'): 33 | question_answer_pairs = [] 34 | with open(data_path) as f: 35 | for line in f.readlines(): 36 | dic = json.loads(line) 37 | 38 | question = dic['problem'] 39 | ground_truth = dic['answer'] 40 | question_answer_pairs.append((question, ground_truth)) 41 | 42 | return question_answer_pairs 43 | 44 | 45 | def calculate_PRM_HRM_scores(rm_model, tokenizer, question, N, current_steps, previous_steps="", N_batch=4): 46 | assert len(current_steps) == N, "The length of current steps must be equal to N" 47 | 48 | device = "cuda" 49 | current_step_score_pairs = [] 50 | 51 | for i in range(0, N, N_batch): 52 | batch_steps = current_steps[i:i + N_batch] 53 | prompts = [construct_PRM_HRM_prompt_v2(question, previous_steps, current_step) for current_step in batch_steps] 54 | 55 | inputs = tokenizer(prompts, return_tensors="pt", max_length=4096, padding=True, truncation=True) 56 | inputs = {key: value.to(device) for key, value in inputs.items()} 57 | 58 | with torch.no_grad(): 59 | outputs = rm_model(**inputs) 60 | logits = outputs.logits 61 | 62 | positive_scores = logits[:, 0].tolist() 63 | 64 | batch_step_score_pairs = [(score, step) for score, step in zip(positive_scores, batch_steps)] 65 | current_step_score_pairs.extend(batch_step_score_pairs) 66 | 67 | current_step_score_pairs.sort(reverse=True) 68 | return current_step_score_pairs 69 | 70 | 71 | def extract_answer(text: str, placeholder="# Answer", end_placeholder='# END!'): 72 | text = text.lower() 73 | left_idx = text.rindex(placeholder.lower()) 74 | length = len(placeholder) 75 | try: 76 | right_idx = text.rindex(end_placeholder.lower()) 77 | except: 78 | right_idx = -1 79 | return text[left_idx + length:right_idx].strip() 80 | 81 | 82 | def prm_hrm_best_of_n(model_path, host, port, model_name, api_key="", N=2, 83 | test_data_path='dataset/math500/test.jsonl', task="prm", 84 | version='v1'): 85 | def clear_memory(): 86 | gc.collect() 87 | torch.cuda.empty_cache() 88 | 89 | def whether_contain_answer(text): 90 | return "# Answer" in text 91 | 92 | base_path = f"stats/{task}_{version}_{model_name}/" 93 | os.makedirs(base_path, exist_ok=True) 94 | 95 | saving_filename = os.path.join(base_path, f"{version}_{task}_{N}.txt") 96 | # try: 97 | # with open(saving_filename, "r") as f: 98 | # saving_length = len(f.readlines()) 99 | # print(f"已经存了{saving_length}个例子了") 100 | # except: 101 | # saving_length = 0 102 | # print(f"已经存了{saving_length}个例子了") 103 | 104 | test_data_pairs = load_test_data(test_data_path) 105 | correct_cnt = 0 106 | total_cnt = len(test_data_pairs) 107 | 108 | final_answers = [] 109 | answer_situations = [] 
110 | ground_truths = [] 111 | 112 | prm_model, prm_tokenizer = load_rw_model(model_path) 113 | 114 | for question, ground_truth in tqdm(test_data_pairs, total=len(test_data_pairs)): 115 | clear_memory() 116 | 117 | previous_steps = "" 118 | ground_truths.append(ground_truth) 119 | found_answer = False 120 | for height in range(MAX_HEIGHT): 121 | candidates = [] 122 | policy_model_prompt = construct_policy_model_prompt_for_PRM_HRM(question, previous_steps) 123 | print(f"---------\nheight{height}: policy model prompt") 124 | print(policy_model_prompt) 125 | print(f"---------\nPrompt finish!") 126 | try: 127 | intermediate_steps = parallel_query(host, port, model_name, policy_model_prompt, api_key, n=N) 128 | except Exception as e: 129 | intermediate_steps = [] 130 | for _ in range(N): 131 | try: 132 | intermediate_step = sequential_query(host, port, model_name, policy_model_prompt, api_key) 133 | except Exception as ee: 134 | print(ee) 135 | continue 136 | intermediate_steps.append(intermediate_step) 137 | candidates.extend(intermediate_steps) 138 | 139 | if not candidates: 140 | break 141 | 142 | sorted_current_step_score_pairs = calculate_PRM_HRM_scores(prm_model, prm_tokenizer, question, N, 143 | candidates, 144 | previous_steps) 145 | _, current_step = sorted_current_step_score_pairs[0] 146 | print(f"Current step: {current_step}") 147 | previous_steps = previous_steps + current_step 148 | # print(f"Updated previous steps height{height}: ", previous_steps) 149 | 150 | print("---------") 151 | 152 | if whether_contain_answer(current_step): 153 | found_answer = True 154 | break 155 | print("-------Final reasoning process----------") 156 | print(previous_steps) 157 | if found_answer: 158 | answer = extract_answer(previous_steps) 159 | else: 160 | answer = "cannot find a correct answer."
161 | final_answers.append(answer) 162 | true_or_false = grade_answer(answer, ground_truth) 163 | answer_situations.append(true_or_false) 164 | if true_or_false: 165 | correct_cnt += 1 166 | print(f"find answer in {question}") 167 | 168 | with open(saving_filename, "a") as f: 169 | f.write(f"{answer}\t{ground_truth}\t{true_or_false}\n") 170 | 171 | print(f"N = {N}, the correct rate is {correct_cnt / total_cnt}") 172 | 173 | 174 | if __name__ == '__main__': 175 | parser = argparse.ArgumentParser() 176 | parser.add_argument("--host", type=str, default="127.0.0.1") 177 | parser.add_argument("--port", type=str, default="10086") 178 | parser.add_argument("--task", type=str, default="prm") 179 | parser.add_argument("--model_name", type=str, default="policy_model") 180 | parser.add_argument("--api_key", type=str, default="xxxx") 181 | parser.add_argument("--N", type=int, default=2) 182 | parser.add_argument("--test_data_path", type=str, 183 | default='dataset/math500/test.jsonl') 184 | parser.add_argument("--version", type=str, default="v1") 185 | 186 | args = parser.parse_args() 187 | 188 | task = args.task 189 | host = args.host 190 | port = args.port 191 | model_name = args.model_name 192 | api_key = args.api_key 193 | version = args.version 194 | 195 | N = int(args.N) 196 | test_data_path = args.test_data_path 197 | model_path = f"sf_best_model/{task}/{version}" 198 | 199 | if task == "prm": 200 | prm_hrm_best_of_n(model_path, host, port, model_name, api_key, N, test_data_path, task, version) 201 | elif task == "hrm": 202 | prm_hrm_best_of_n(model_path, host, port, model_name, api_key, N, test_data_path, task, version) 203 | 204 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_math500.py --version v1 --N 1 --task hrm --model_name qwen7b > logs/math500_hrm_1_v1_qwen7b.log & 205 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_math500.py --version v1 --N 2 --task hrm --model_name qwen7b > logs/math500_hrm_2_v1_qwen7b.log & 206 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_math500.py --version v1 --N 4 --task hrm --model_name qwen7b > logs/math500_hrm_4_v1_qwen7b.log & 207 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_math500.py --version v1 --N 8 --task hrm --model_name qwen7b > logs/math500_hrm_8_v1_qwen7b.log & 208 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_math500.py --version v1 --N 16 --task hrm --model_name qwen7b > logs/math500_hrm_16_v1_qwen7b.log & 209 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_math500.py --version v1 --N 24 --task hrm --model_name qwen7b > logs/math500_hrm_24_v1_qwen7b.log & 210 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_math500.py --version v1 --N 32 --task hrm --model_name qwen7b > logs/math500_hrm_32_v1_qwen7b.log & 211 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_math500.py --version v1 --N 64 --task hrm --model_name qwen7b > logs/math500_hrm_64_v1_qwen7b.log & 212 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_math500.py --version v1 --N 128 --task hrm --model_name qwen7b > logs/math500_hrm_128_v1_qwen7b.log & 213 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_math500.py --version v1 --N 256 --port 10087 --task hrm --model_name qwen7b > logs/math500_hrm_256_v1_qwen7b.log & 214 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_math500.py --version v1 --N 512 --port 10086 --task hrm --model_name qwen7b > logs/math500_hrm_512_v1_qwen7b.log & 215 | 216 | # 
CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_math500.py --version v1 --N 1 --task prm --model_name qwen7b > logs/math500_prm_1_v1_qwen7b.log & 217 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_math500.py --version v1 --N 2 --task prm --model_name qwen7b > logs/math500_prm_2_v1_qwen7b.log & 218 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_math500.py --version v1 --N 4 --task prm --model_name qwen7b > logs/math500_prm_4_v1_qwen7b.log & 219 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_math500.py --version v1 --N 8 --task prm --model_name qwen7b > logs/math500_prm_8_v1_qwen7b.log & 220 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_math500.py --version v1 --N 16 --task prm --model_name qwen7b > logs/math500_prm_16_v1_qwen7b.log & 221 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_math500.py --version v1 --N 24 --task prm --model_name qwen7b > logs/math500_prm_24_v1_qwen7b.log & 222 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_math500.py --version v1 --N 32 --task prm --model_name qwen7b > logs/math500_prm_32_v1_qwen7b.log & 223 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_math500.py --version v1 --N 64 --task prm --model_name qwen7b > logs/math500_prm_64_v1_qwen7b.log & 224 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_math500.py --version v1 --N 128 --task prm --model_name qwen7b > logs/math500_prm_128_v1_qwen7b.log & 225 | # CUDA_VISIBLE_DEVICES=4 nohup python self-training/best_of_n_math500.py --version v1 --N 256 --port 10086 --task prm --model_name qwen7b > logs/math500_prm_256_v1_qwen7b.log & 226 | # CUDA_VISIBLE_DEVICES=5 nohup python self-training/best_of_n_math500.py --version v1 --N 512 --port 10087 --task prm --model_name qwen7b > logs/math500_prm_512_v1_qwen7b.log & 227 | -------------------------------------------------------------------------------- /construct_dataset/rm_dataset_construction_phase1.py: -------------------------------------------------------------------------------- 1 | import json 2 | from helper import correct_wrong_comparison 3 | 4 | 5 | def parse_jsonl(jsonl_file="dataset/prm_dataset/phase1_train.jsonl"): 6 | dicts = [] 7 | with open(jsonl_file) as f: 8 | for line in f.readlines(): 9 | temp_json = json.loads(line) 10 | dicts.append(temp_json) 11 | return dicts 12 | 13 | 14 | def split_correct_and_wrong_step_phase1(dicts): 15 | questions = [] 16 | positive_list, negative_list, neutral_list = [], [], [] 17 | chosen_completion_list = [] 18 | status_list = [] 19 | 20 | for dict_ in dicts: 21 | 22 | current_status = dict_['label']['finish_reason'] 23 | 24 | if current_status == "bad_problem" or current_status == "give_up": 25 | continue 26 | 27 | question = dict_['question']['problem'] 28 | 29 | questions.append(question) 30 | 31 | steps = dict_['label']['steps'] 32 | status_list.append(current_status) 33 | 34 | positives = dict() 35 | negatives = dict() 36 | neutrals = dict() 37 | 38 | chosen_completions = [] 39 | 40 | for idx, step in enumerate(steps): 41 | completions = step['completions'] 42 | 43 | chosen_completion = step['chosen_completion'] 44 | 45 | positives[idx] = [] 46 | negatives[idx] = [] 47 | neutrals[idx] = [] 48 | 49 | for completion in completions: 50 | if completion['rating'] == 1: 51 | positives[idx].append(completion['text']) 52 | elif completion['rating'] == 0: 53 | neutrals[idx].append(completion['text']) 54 | elif completion['rating'] == -1: 55 | 
negatives[idx].append(completion['text']) 56 | if step['chosen_completion'] is not None: 57 | chosen_completions.append( 58 | [completions[chosen_completion]['text'], completions[chosen_completion]['rating']]) 59 | else: 60 | chosen_completions.append([step['human_completion']['text'], 1]) 61 | 62 | positive_list.append(positives) 63 | negative_list.append(negatives) 64 | neutral_list.append(neutrals) 65 | chosen_completion_list.append(chosen_completions) 66 | assert len(questions) == len(positive_list) == len(negative_list) == len(neutral_list) == len( 67 | chosen_completion_list) == len(status_list) 68 | print(len(questions)) 69 | return questions, positive_list, negative_list, neutral_list, chosen_completion_list, status_list 70 | 71 | 72 | def phase1_ORM_dataset(questions, positive_list, negative_list, neutral_list, chosen_completion_list): 73 | json_list = [] 74 | for question, positive, negative, neutral, chosen_completions in zip(questions, positive_list, negative_list, 75 | neutral_list, chosen_completion_list): 76 | json_obj = phase1_ORM_dataset_for_single_question(question, positive, negative, neutral, chosen_completions) 77 | json_list.append(json_obj) 78 | return json_list 79 | 80 | 81 | 82 | 83 | def phase1_ORM_dataset_for_single_question(question, positives, negatives, neutrals, chosen_completions): 84 | def traverse_(a_dict): 85 | res = [] 86 | length = len(a_dict) 87 | for idx in range(length): 88 | item = a_dict[idx] 89 | if len(item) > 0: 90 | res.append(idx) 91 | return res 92 | 93 | correct_reasoning_s = [] 94 | incorrect_reasoning_s = [] 95 | total_steps = len(positives) 96 | 97 | step_2_text = [str(0)] * total_steps 98 | 99 | for idx, (text, _) in enumerate(chosen_completions): 100 | step_2_text[idx] = text 101 | 102 | trajectory = "\n\n".join(step_2_text) 103 | correct_reasoning_s.append(trajectory) 104 | 105 | positive_idxes = traverse_(positives) 106 | negative_idxes = traverse_(negatives) 107 | 108 | try: 109 | last_positive = positive_idxes[-1] 110 | except: 111 | raise Exception("it should have some elements.") 112 | 113 | if last_positive == total_steps - 1: 114 | previous_slot = step_2_text[last_positive] 115 | for candidate in positives[total_steps - 1]: 116 | step_2_text[last_positive] = candidate 117 | correct_reasoning_s.append("\n\n".join(step_2_text)) 118 | step_2_text[last_positive] = previous_slot 119 | else: 120 | raise Exception("it should have the answer.") 121 | 122 | can_use_idxes = [] 123 | if len(negative_idxes) > 0: 124 | negative_length = len(negative_idxes) 125 | if negative_idxes[-1] == total_steps - 1: 126 | 127 | can_use_idxes.append(negative_idxes[-1]) 128 | 129 | for i in range(negative_length - 1, 0, -1): 130 | if negative_idxes[i] - negative_idxes[i - 1] != 1: 131 | break 132 | else: 133 | can_use_idxes.append(negative_idxes[i - 1]) 134 | 135 | if can_use_idxes: 136 | last_idx = can_use_idxes[0] 137 | counterpart_step2text = step_2_text[:last_idx + 1] 138 | queue = [counterpart_step2text] 139 | total_queue = [] 140 | 141 | for idx in can_use_idxes[:]: 142 | current_queue_length = len(queue) 143 | negative_pool = negatives[idx] 144 | for j in range(current_queue_length): 145 | step2text_from_queue = queue.pop(0) 146 | 147 | for candidate in negative_pool: 148 | temp_step2text = step2text_from_queue[:] 149 | temp_step2text[idx] = candidate 150 | queue.append(temp_step2text) 151 | total_queue.append(temp_step2text) 152 | 153 | if len(queue) > len(correct_reasoning_s) * 4: 154 | # print(question) 155 | break 156 | 157 | for item in queue: 
158 | incorrect_reasoning_s.append("\n\n".join(item)) 159 | 160 | correct_reasoning_s = list(set(correct_reasoning_s)) 161 | incorrect_reasoning_s = list(set(incorrect_reasoning_s)) 162 | return {'question': question, "correct": correct_reasoning_s, "incorrect": incorrect_reasoning_s} 163 | 164 | 165 | 166 | def phase1_PRM_dataset(questions, positive_list, negative_list, neutral_list, chosen_completion_list): 167 | json_list = [] 168 | for question, positive, negative, neutral, chosen_completions in zip(questions, positive_list, negative_list, 169 | neutral_list, chosen_completion_list): 170 | json_obj = phase1_PRM_dataset_for_single_question(question, positive, negative, neutral, chosen_completions) 171 | json_list.append(json_obj) 172 | return json_list 173 | 174 | 175 | def phase1_PRM_dataset_for_single_question(question, positives, negatives, neutrals, chosen_completions): 176 | correct_process_reasoning_s = [] 177 | incorrect_process_reasoning_s = [] 178 | chosen_trajectory = [""] 179 | 180 | total_steps = len(positives) 181 | 182 | for i in range(total_steps): 183 | positive_pool = positives[i] 184 | negative_pool = negatives[i] 185 | neutral_pool = neutrals[i] 186 | 187 | for candidate in positive_pool: 188 | correct_process_reasoning_s.append(chosen_trajectory[-1] + r" /qwerdf12344567" + candidate) 189 | for candidate in negative_pool: 190 | incorrect_process_reasoning_s.append(chosen_trajectory[-1] + r" /qwerdf12344567" + candidate) 191 | # for candidate in neutral_pool: 192 | # incorrect_process_reasoning_s.append(chosen_trajectory[-1] + r" /qwerdf12344567" + candidate) 193 | # correct_process_reasoning_s.append(chosen_trajectory[-1] + r" /qwerdf12344567" + candidate) 194 | # continue 195 | if chosen_completions[i]: 196 | chosen_trajectory.append(chosen_trajectory[-1] + r" /qwerdf12344567" + chosen_completions[i][0]) 197 | else: 198 | assert i == total_steps - 1 199 | return { 200 | 'question': question, 201 | 'correct': correct_process_reasoning_s, 202 | 'incorrect': incorrect_process_reasoning_s 203 | } 204 | 205 | 206 | def phase1_HRM_dataset(questions, positive_list, negative_list, neutral_list, chosen_completion_list): 207 | json_list = [] 208 | for question, positive, negative, neutral, chosen_completions in zip(questions, positive_list, negative_list, 209 | neutral_list, chosen_completion_list): 210 | json_obj = phase1_HRM_dataset_for_single_question(question, positive, negative, neutral, chosen_completions) 211 | json_list.append(json_obj) 212 | return json_list 213 | 214 | 215 | def phase1_HRM_dataset_for_single_question(question, positives, negatives, neutrals, 216 | chosen_completions): 217 | correct_process_reasoning_s = [] 218 | incorrect_process_reasoning_s = [] 219 | chosen_trajectory = [""] 220 | previous_rating_completions = [] 221 | total_steps = len(positives) 222 | 223 | for i in range(total_steps): 224 | positive_pool = positives[i] 225 | negative_pool = negatives[i] 226 | neutral_pool = neutrals[i] 227 | 228 | for candidate in positive_pool: 229 | correct_process_reasoning_s.append(chosen_trajectory[-1] + r" /qwerdf12344567" + candidate) 230 | 231 | if previous_rating_completions: 232 | correct_process_reasoning_s.append(chosen_trajectory[-1] + "\n\n" + candidate) 233 | 234 | for candidate in negative_pool: 235 | incorrect_process_reasoning_s.append(chosen_trajectory[-1] + r" /qwerdf12344567" + candidate) 236 | 237 | if previous_rating_completions: 238 | incorrect_process_reasoning_s.append(chosen_trajectory[-1] + "\n\n" + candidate) 239 | 240 | for 
candidate in neutral_pool: 241 | continue 242 | incorrect_process_reasoning_s.append(chosen_trajectory[-1] + r" /qwerdf12344567" + candidate) 243 | 244 | if previous_rating_completions: 245 | if previous_rating_completions[-1] == 0: 246 | incorrect_process_reasoning_s.append(chosen_trajectory[-1] + " " + candidate) 247 | elif previous_rating_completions[-1] == 1: 248 | continue 249 | correct_process_reasoning_s.append(chosen_trajectory[-1] + " " + candidate) 250 | else: 251 | raise Exception("something fucked up in HRM phase 1") 252 | 253 | if chosen_completions[i]: 254 | chosen_trajectory.append(chosen_trajectory[-1] + r" /qwerdf12344567" + chosen_completions[i][0]) 255 | previous_rating_completions.append(chosen_completions[i][1]) 256 | else: 257 | assert i == total_steps - 1 258 | return { 259 | 'question': question, 260 | 'correct': correct_process_reasoning_s, 261 | 'incorrect': incorrect_process_reasoning_s 262 | } 263 | 264 | 265 | if __name__ == '__main__': 266 | dicts = parse_jsonl('dataset/prm_dataset/phase1_train.jsonl') 267 | questions, positive_list, negative_list, neutral_list, chosen_completion_list, status_list = split_correct_and_wrong_step_phase1( 268 | dicts) 269 | # 270 | orm_list = phase1_ORM_dataset(questions, positive_list, negative_list, neutral_list, chosen_completion_list) 271 | # print(orm_list[1]) 272 | # print(len(orm_list)) 273 | # print("\n\n\n\n\n") 274 | correct_wrong_comparison(orm_list, "orm phase1") 275 | 276 | prm_list = phase1_PRM_dataset(questions, positive_list, negative_list, neutral_list, chosen_completion_list) 277 | # print(prm_list[0]) 278 | # print(len(prm_list)) 279 | # print("\n\n\n\n\n\n") 280 | correct_wrong_comparison(prm_list, "prm phase1") 281 | 282 | hrm_list = phase1_HRM_dataset(questions, positive_list, negative_list, neutral_list, chosen_completion_list) 283 | # print(hrm_list[2]) 284 | # print(len(hrm_list)) 285 | correct_wrong_comparison(hrm_list, "hrm phase1") 286 | -------------------------------------------------------------------------------- /sft_rw_manual_annotation/best_of_n.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import json 4 | import os 5 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 6 | import torch 7 | from llm_query import orm_query, prm_query, orm_parallel_query, prm_parallel_query 8 | from construct_dataset.construct_prompt import construct_ORM_prompt, construct_PRM_HRM_prompt_v2, \ 9 | construct_policy_model_prompt_for_ORM, construct_policy_model_prompt_for_PRM_HRM 10 | import torch.nn.functional as F 11 | from grading.grader import grade_answer 12 | from tqdm import tqdm 13 | import gc 14 | import matplotlib.pyplot as plt 15 | 16 | MAX_HEIGHT = 100 17 | 18 | 19 | def load_rw_model(model_path): 20 | model = AutoModelForSequenceClassification.from_pretrained(model_path, 21 | num_labels=2, 22 | attn_implementation="flash_attention_2", 23 | torch_dtype=torch.float16, 24 | ) 25 | tokenizer = AutoTokenizer.from_pretrained(model_path) 26 | 27 | model.config.pad_token_id = model.config.eos_token_id 28 | device = torch.device("cuda") 29 | model.to(device) 30 | return model, tokenizer 31 | 32 | 33 | def load_test_data(data_path='dataset/prm_dataset/phase_test.jsonl'): 34 | question_answer_pairs = [] 35 | with open(data_path) as f: 36 | for line in f.readlines(): 37 | dic = json.loads(line) 38 | if dic['label']['finish_reason'] != "solution": 39 | continue 40 | 41 | question = dic['question']['problem'] 42 | ground_truth = 
dic['question']['ground_truth_answer'] 43 | question_answer_pairs.append((question, ground_truth)) 44 | 45 | return question_answer_pairs[:1] 46 | 47 | 48 | def clear_memory(): 49 | gc.collect() 50 | torch.cuda.empty_cache() 51 | 52 | 53 | def calculate_ORM_scores(model, tokenizer, question, answers, N, N_batch=4): 54 | assert len(answers) == N, "The length of answers must be equal to N" 55 | clear_memory() 56 | device = "cuda" 57 | 58 | answer_score_pairs = [] 59 | 60 | for i in range(0, N, N_batch): 61 | batch_answers = answers[i:i + N_batch] 62 | prompts = [construct_ORM_prompt(question, answer) for answer in batch_answers] 63 | 64 | inputs = tokenizer(prompts, return_tensors="pt", max_length=4096, padding="max_length", truncation=True) 65 | inputs = {key: value.to(device) for key, value in inputs.items()} 66 | 67 | with torch.no_grad(): 68 | outputs = model(**inputs) 69 | 70 | logits = outputs.logits 71 | probs = F.softmax(logits, dim=-1) 72 | 73 | positive_scores = probs[:, 1].tolist() 74 | batch_answer_score_pairs = [(score, answer) for score, answer in zip(positive_scores, batch_answers)] 75 | answer_score_pairs.extend(batch_answer_score_pairs) 76 | print(f"length of answer pairs: {len(answer_score_pairs)}") 77 | answer_score_pairs.sort(reverse=True) 78 | return answer_score_pairs 79 | 80 | 81 | def calculate_PRM_HRM_scores(model, tokenizer, question, N, current_steps, previous_steps="", N_batch=2): 82 | assert len(current_steps) == N, "The length of current steps must be equal to N" 83 | 84 | device = "cuda" 85 | current_step_score_pairs = [] 86 | 87 | for i in range(0, N, N_batch): 88 | batch_steps = current_steps[i:i + N_batch] 89 | prompts = [construct_PRM_HRM_prompt_v2(question, previous_steps, current_step) for current_step in batch_steps] 90 | 91 | inputs = tokenizer(prompts, return_tensors="pt", max_length=4096, padding="max_length", truncation=True) 92 | inputs = {key: value.to(device) for key, value in inputs.items()} 93 | 94 | with torch.no_grad(): 95 | outputs = model(**inputs) 96 | logits = outputs.logits 97 | probs = F.softmax(logits, dim=-1) 98 | 99 | positive_scores = probs[:, 1].tolist() 100 | batch_step_score_pairs = [(score, step) for score, step in zip(positive_scores, batch_steps)] 101 | current_step_score_pairs.extend(batch_step_score_pairs) 102 | 103 | del inputs, outputs, logits, probs 104 | torch.cuda.empty_cache() 105 | 106 | current_step_score_pairs.sort(reverse=True) 107 | return current_step_score_pairs 108 | 109 | 110 | def extract_answer(text: str, placeholder="# Answer", end_placeholder='# END!'): 111 | text = text.lower() 112 | left_idx = text.rindex(placeholder.lower()) 113 | length = len(placeholder) 114 | try: 115 | right_idx = text.rindex(end_placeholder.lower()) 116 | except: 117 | right_idx = -1 118 | return text[left_idx + length:right_idx].strip() 119 | 120 | 121 | def orm_best_of_n(model_path, host, port, model_name, api_key="", N=2, 122 | test_data_path='dataset/prm_dataset/phase_test.jsonl', repetition_penalty=1.0): 123 | test_data_pairs = load_test_data(test_data_path) 124 | correct_cnt = 0 125 | total_cnt = len(test_data_pairs) 126 | 127 | final_answers = [] 128 | answer_situations = [] 129 | ground_truths = [] 130 | 131 | orm_model, orm_tokenizer = load_rw_model(model_path) 132 | 133 | for idx, (question, ground_truth) in enumerate(tqdm(test_data_pairs)): 134 | ground_truths.append(ground_truth) 135 | policy_model_prompt = construct_policy_model_prompt_for_ORM(question) 136 | 137 | print(policy_model_prompt) 138 | try: 139 | candidates 
= orm_parallel_query(host, port, model_name, policy_model_prompt, api_key, n=N) 140 | except Exception as e: 141 | candidates = [] 142 | 143 | for _ in range(N): 144 | answer = orm_query(host, port, model_name, policy_model_prompt, api_key) 145 | print(f"for question {idx}, the {_} try for current answer: {answer}\n\n----") 146 | candidates.append(answer) 147 | for candidate in candidates: 148 | print(f"------{candidate}------\n\n") 149 | sorted_answer_score_pairs = calculate_ORM_scores(orm_model, orm_tokenizer, question, candidates, N) 150 | 151 | _, answer = sorted_answer_score_pairs[0] 152 | 153 | try: 154 | final_answer = extract_answer(answer) 155 | except Exception as e: 156 | print(e) 157 | final_answer = "cannot find a correct answer." 158 | 159 | final_answers.append(final_answer) 160 | 161 | true_or_false = grade_answer(final_answer, ground_truth) 162 | answer_situations.append(true_or_false) 163 | 164 | if true_or_false: 165 | correct_cnt += 1 166 | print(f"find right answer for question {idx}") 167 | 168 | base_path = f"stats/orm/" 169 | os.makedirs(base_path, exist_ok=True) 170 | 171 | with open(os.path.join(base_path, f"{model_name}_N_{N}_repetition_penalty_{repetition_penalty}.txt"), "w") as f: 172 | for final_answer, ground_truth, answer_situation in zip(final_answers, ground_truths, answer_situations): 173 | f.write(f"{final_answer}\t{ground_truth}\t{answer_situation}\n") 174 | print(f"N = {N}, the correct rate is {correct_cnt / total_cnt}") 175 | 176 | 177 | def prm_hrm_best_of_n(model_path, host, port, model_name, api_key="", N=2, 178 | test_data_path='dataset/prm_dataset/phase_test.jsonl', task="prm", repetition_penalty=1.0): 179 | def clear_memory(): 180 | gc.collect() 181 | torch.cuda.empty_cache() 182 | 183 | def whether_contain_answer(text): 184 | return "# Answer" in text 185 | 186 | base_path = f"stats/{task}/" 187 | os.makedirs(base_path, exist_ok=True) 188 | 189 | saving_filename = os.path.join(base_path, f"{model_name}_N_{N}_repetition_penalty_{repetition_penalty}_{task}.txt") 190 | 191 | test_data_pairs = load_test_data(test_data_path) 192 | correct_cnt = 0 193 | total_cnt = len(test_data_pairs) 194 | 195 | final_answers = [] 196 | answer_situations = [] 197 | ground_truths = [] 198 | 199 | prm_model, prm_tokenizer = load_rw_model(model_path) 200 | 201 | for question, ground_truth in tqdm(test_data_pairs): 202 | clear_memory() 203 | 204 | previous_steps = "" 205 | ground_truths.append(ground_truth) 206 | found_answer = False 207 | for height in range(MAX_HEIGHT): 208 | policy_model_prompt = construct_policy_model_prompt_for_PRM_HRM(question, previous_steps) 209 | print(f"---------\nheight{height}: policy model prompt") 210 | print(policy_model_prompt) 211 | print(f"---------\nPrompt finish!") 212 | 213 | try: 214 | candidates = prm_parallel_query(host, port, model_name, policy_model_prompt, api_key, n=N) 215 | except Exception as e: 216 | candidates = [] 217 | 218 | for _ in range(N): 219 | try: 220 | intermediate_step = prm_query(host, port, model_name, policy_model_prompt, api_key) 221 | # intermediate_step = intermediate_step.replace("\n\n", " ").replace("\n", " ") 222 | candidates.append(intermediate_step) 223 | except Exception as e: 224 | print(e) 225 | break 226 | if not candidates: 227 | break 228 | 229 | sorted_current_step_score_pairs = calculate_PRM_HRM_scores(prm_model, prm_tokenizer, question, N, 230 | candidates, 231 | previous_steps) 232 | _, current_step = sorted_current_step_score_pairs[0] 233 | 234 | # current_step = current_step.strip() 235 
| 236 | print(f"current step height{height}: ", current_step) 237 | 238 | previous_steps = previous_steps + current_step 239 | print(f"Updated previous steps height{height}: ", previous_steps) 240 | 241 | print("---------") 242 | 243 | if whether_contain_answer(current_step): 244 | found_answer = True 245 | break 246 | if found_answer: 247 | answer = extract_answer(previous_steps) 248 | else: 249 | answer = "cannot find a correct answer." 250 | final_answers.append(answer) 251 | true_or_false = grade_answer(answer, ground_truth) 252 | answer_situations.append(true_or_false) 253 | if true_or_false: 254 | correct_cnt += 1 255 | print(f"find answer in {question}") 256 | 257 | with open(saving_filename, "a") as f: 258 | f.write(f"{answer}\t{ground_truth}\t{true_or_false}\n") 259 | 260 | print(f"N = {N}, the correct rate is {correct_cnt / total_cnt}") 261 | # base_path = f"stats/{task}/" 262 | # os.makedirs(base_path, exist_ok=True) 263 | # 264 | # with open(os.path.join(base_path, f"N_{N}_repetition_penalty_{repetition_penalty}_new_prm.txt"), "w") as f: 265 | # for final_answer, ground_truth, answer_situation in zip(final_answers, ground_truths, answer_situations): 266 | # f.write(f"{final_answer}\t{ground_truth}\t{answer_situation}\n") 267 | # print(f"N = {N}, the correct rate is {correct_cnt / total_cnt}") 268 | 269 | 270 | if __name__ == '__main__': 271 | parser = argparse.ArgumentParser() 272 | parser.add_argument("--host", type=str, default="127.0.0.1") 273 | parser.add_argument("--port", type=str, default="10086") 274 | parser.add_argument("--task", type=str, default="orm") 275 | parser.add_argument("--model_name", type=str, default="policy_model") 276 | parser.add_argument("--api_key", type=str, default="xxxx") 277 | parser.add_argument("--N", type=int, default=2) 278 | parser.add_argument("--test_data_path", type=str, default='dataset/prm_dataset/phase1_test.jsonl') 279 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 280 | 281 | args = parser.parse_args() 282 | model_path = f"best_model/{args.task}" 283 | task = args.task 284 | host = args.host 285 | port = args.port 286 | model_name = args.model_name 287 | api_key = args.api_key 288 | repetition_penalty = args.repetition_penalty 289 | 290 | N = int(args.N) 291 | test_data_path = args.test_data_path 292 | 293 | if task == "orm": 294 | orm_best_of_n(model_path, host, port, model_name, api_key, N, test_data_path, repetition_penalty) 295 | elif task == "prm": 296 | prm_hrm_best_of_n(model_path, host, port, model_name, api_key, N, test_data_path, task, repetition_penalty) 297 | elif task == "hrm": 298 | prm_hrm_best_of_n(model_path, host, port, model_name, api_key, N, test_data_path, task, repetition_penalty) 299 | -------------------------------------------------------------------------------- /self-training/sft_policy_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Introduce log KL divergence to sft policy model, detailed analysis is shown in the paper. 
3 | """ 4 | from datasets import load_dataset 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | from transformers import Trainer, TrainingArguments 7 | import os 8 | from datetime import datetime 9 | import argparse 10 | from functools import partial 11 | import torch 12 | import torch.nn.functional as F 13 | from accelerate import Accelerator 14 | from torch.utils.data import DataLoader 15 | 16 | accelerator = Accelerator() 17 | device = accelerator.device 18 | 19 | 20 | class SimpleDataCollator: 21 | def __init__(self): 22 | pass 23 | 24 | def __call__(self, features): 25 | batch = { 26 | "input_ids": torch.stack( 27 | [torch.tensor(f["input_ids"]) if not isinstance(f["input_ids"], torch.Tensor) else f["input_ids"] for f 28 | in features]), 29 | "attention_mask": torch.stack([torch.tensor(f["attention_mask"]) if not isinstance(f["attention_mask"], 30 | torch.Tensor) else f[ 31 | "attention_mask"] for f in features]), 32 | "labels": torch.stack( 33 | [torch.tensor(f["labels"]) if not isinstance(f["labels"], torch.Tensor) else f["labels"] for f in 34 | features]), 35 | "logits_path": [f["logits_path"] for f in features], 36 | } 37 | return batch 38 | 39 | 40 | def choose_task(data_path, task: str = 'orm'): 41 | train_data_path = os.path.join(data_path, f'can_use_{task}_train_with_logits.jsonl') 42 | eval_data_path = os.path.join(data_path, f'can_use_{task}_eval_with_logits.jsonl') 43 | 44 | print(f"train data path: {train_data_path}") 45 | print(f"eval data path: {eval_data_path}") 46 | 47 | train_dataset = load_dataset('json', data_files=train_data_path) 48 | eval_dataset = load_dataset('json', data_files=eval_data_path) 49 | 50 | train_dataset = train_dataset['train'].shuffle() 51 | eval_dataset = eval_dataset['train'].shuffle() 52 | 53 | eval_dataset = eval_dataset.select(list(range(int(len(eval_dataset) * 0.4)))) 54 | # train_dataset = train_dataset.select([0, 1, 2, 3, 4, 5, 6, 7]) 55 | # eval_dataset = eval_dataset.select([0, 1, 2, 3, 4, 5, 6, 7, ]) 56 | return train_dataset, eval_dataset 57 | 58 | 59 | def pad_or_truncate(logits, target_length=4096, padding_value=-1e9): 60 | # Make sure the input logits have a shape we can handle 61 | if len(logits.size()) == 4: # e.g. [1, 1, seq_len, 152064] 62 | logits = logits.squeeze(0) # -> [1, seq_len, 152064] 63 | if len(logits.size()) == 3: # e.g. [1, seq_len, 152064] 64 | logits = logits.squeeze(0) # -> [seq_len, 152064] 65 | elif len(logits.size()) != 2: # any other unexpected shape 66 | raise ValueError(f"Unexpected logits shape: {logits.size()}") 67 | 68 | seq_len = logits.size(0) 69 | if seq_len < target_length: 70 | # pad to the target length 71 | padding = (0, 0, 0, target_length - seq_len) # (left, right, top, bottom) 72 | logits = F.pad(logits, padding, value=padding_value) 73 | elif seq_len > target_length: 74 | # truncate to the target length 75 | logits = logits[:target_length, :] 76 | return logits.unsqueeze(0) # restore to [1, target_length, 152064] 77 | 78 | 79 | def pad_or_truncate_for_multiple_batch(logits, target_length=4096, padding_value=-1e9): 80 | # [batch size, seq len, vocab size] 81 | seq_len = logits.size(1) 82 | if seq_len < target_length: 83 | # pad to the target length 84 | padding = (0, 0, 0, target_length - seq_len) # (left, right, top, bottom) 85 | logits = F.pad(logits, padding, value=padding_value) 86 | elif seq_len > target_length: 87 | # truncate to the target length 88 | logits = logits[:, :target_length, :] 89 | return logits 90 | 91 | 92 | def kl_divergence_loss_for_batch_one(student_logits, teacher_logits): 93 | # print("Original teacher_logits shape:", teacher_logits.size()) 94 | # print("Original student_logits 
shape:", student_logits.size()) 95 | 96 | # Handle the teacher_logits dimensions (from 4D to 3D) 97 | if len(teacher_logits.size()) == 4: 98 | teacher_logits = teacher_logits.squeeze(1) # [1, 1, seq_len, 152064] -> [1, seq_len, 152064] 99 | elif len(teacher_logits.size()) != 3: 100 | raise ValueError(f"Unexpected teacher_logits shape: {teacher_logits.size()}") 101 | # Get the sequence lengths 102 | student_seq_len = student_logits.size(1) # should be 4096 103 | teacher_seq_len = teacher_logits.size(1) # e.g. 1599 without padding 104 | 105 | target_length = teacher_seq_len 106 | 107 | if student_seq_len < target_length: 108 | raise ValueError(f"The size of student seq length is wrong! The student seq length should be {student_seq_len}") 109 | elif student_seq_len > target_length: 110 | student_logits = student_logits[:, :target_length, :] 111 | 112 | kl_loss = F.kl_div(F.log_softmax(student_logits, dim=-1), F.softmax(teacher_logits, dim=-1), reduction='batchmean') 113 | return kl_loss 114 | 115 | 116 | def kl_divergence_loss_for_batch_more(student_logits, teacher_logits): 117 | # bs = student_logits.size(0) 118 | kl_loss = F.kl_div(F.log_softmax(student_logits, dim=-1), F.softmax(teacher_logits, dim=-1), 119 | reduction='batchmean') 120 | return kl_loss 121 | 122 | 123 | def compute_loss(model, inputs, return_outputs=False): 124 | input_ids = inputs["input_ids"] 125 | attention_mask = inputs["attention_mask"] 126 | labels = inputs["labels"] 127 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 128 | teacher_model_logits_paths = inputs['logits_path'] 129 | 130 | teacher_logits_list = [] 131 | seq_lengths = [] 132 | for path in teacher_model_logits_paths: 133 | logits = torch.load(path, weights_only=True, map_location=device).to(device) 134 | seq_len = logits.size(1) if len(logits.size()) == 3 else logits.size(0) 135 | seq_lengths.append(seq_len) 136 | 137 | max_length = max(seq_lengths) 138 | 139 | for path in teacher_model_logits_paths: 140 | logits = torch.load(path, weights_only=True, map_location=device).to(device) 141 | adjusted_logits = pad_or_truncate(logits, target_length=max_length) 142 | teacher_logits_list.append(adjusted_logits) 143 | 144 | teacher_logits = torch.stack(teacher_logits_list) 145 | logits = outputs.logits 146 | loss = outputs.loss 147 | path = f"plots/data/{args.task}_{args.kl}" 148 | os.makedirs(path, exist_ok=True) 149 | 150 | if len(teacher_model_logits_paths) == 1: 151 | # print("training logit size()", logits.size()) 152 | # print("training teacher size()", teacher_logits.size()) 153 | kl_loss = kl_divergence_loss_for_batch_one(logits, teacher_logits) 154 | kl_loss_log = torch.log(1 + kl_loss) 155 | # kl_loss_log = kl_loss 156 | 157 | with open(os.path.join(path, "train.txt"), "a") as f: 158 | f.write(f"{kl_loss_log}/{loss}\n") 159 | else: 160 | # print(f"seq_lengths = {seq_lengths}") 161 | # print("Teacher logits size before squeeze:", teacher_logits.size()) 162 | teacher_logits = teacher_logits.squeeze(1) 163 | # print("Teacher logits size after squeeze", teacher_logits.size()) 164 | # print(f"max_length={max_length}") 165 | # print(f"logits size is {logits.size()}") 166 | logits = pad_or_truncate_for_multiple_batch(logits, max_length) 167 | logits = logits.squeeze(0) 168 | # print(f"logits size after padding is {logits.size()}") 169 | kl_loss = kl_divergence_loss_for_batch_more(logits, teacher_logits) 170 | # print(f"KL loss is {kl_loss}") 171 | kl_loss_log = torch.log(1 + kl_loss) 172 | # kl_loss_log = kl_loss 173 | with open(os.path.join(path, "eval.txt"), "a") as f: 
174 | f.write(f"{kl_loss_log}/{loss}\n") 175 | 176 | total_loss = loss + kl_weight * kl_loss_log 177 | # print(f"Loss: {loss}, kl loss log: {kl_loss_log}, total loss: {total_loss}") 178 | return (total_loss, outputs) if return_outputs else total_loss 179 | 180 | 181 | def compute_loss_kl_0(model, inputs, return_outputs=False): 182 | input_ids = inputs["input_ids"] 183 | attention_mask = inputs["attention_mask"] 184 | labels = inputs["labels"] 185 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 186 | 187 | loss = outputs.loss 188 | path = f"plots/data/{args.task}_{args.kl}" 189 | os.makedirs(path, exist_ok=True) 190 | 191 | with open(os.path.join(path, "eval.txt"), "a") as f: 192 | f.write(f"{loss}\n") 193 | 194 | total_loss = loss 195 | 196 | return (total_loss, outputs) if return_outputs else total_loss 197 | 198 | 199 | def preprocess_function(examples, tokenizer): 200 | inputs = examples["input"] 201 | logits_path = examples['logits_path'] 202 | model_inputs = tokenizer(inputs, truncation=True, padding="max_length", max_length=4096, return_tensors="pt") 203 | input_ids = model_inputs["input_ids"].squeeze(0) 204 | labels = input_ids.clone() 205 | labels = torch.cat([labels[1:], torch.tensor([tokenizer.pad_token_id], dtype=labels.dtype)]) 206 | model_inputs["input_ids"] = input_ids 207 | model_inputs["labels"] = labels 208 | model_inputs["attention_mask"] = model_inputs["attention_mask"].squeeze(0) 209 | model_inputs["logits_path"] = logits_path 210 | return model_inputs 211 | 212 | 213 | class CustomTrainer(Trainer): 214 | def get_train_dataloader(self): 215 | return DataLoader( 216 | self.train_dataset, 217 | batch_size=self.args.per_device_train_batch_size, 218 | collate_fn=self.data_collator, 219 | shuffle=True, 220 | drop_last=self.args.dataloader_drop_last, 221 | num_workers=self.args.dataloader_num_workers, 222 | pin_memory=self.args.dataloader_pin_memory, 223 | ) 224 | 225 | def get_eval_dataloader(self, eval_dataset=None): 226 | eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset 227 | return DataLoader( 228 | eval_dataset, 229 | batch_size=self.args.per_device_eval_batch_size, 230 | collate_fn=self.data_collator, 231 | shuffle=False, 232 | num_workers=self.args.dataloader_num_workers, 233 | pin_memory=self.args.dataloader_pin_memory, 234 | ) 235 | 236 | def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): 237 | # Move the inputs to the same device as the model 238 | inputs = self._prepare_inputs(inputs) 239 | with torch.no_grad(): 240 | loss = self.compute_loss(model, inputs) 241 | return (loss, None, None) 242 | 243 | 244 | def train(model_path, train_dataset, eval_dataset, idx=1, task='prm', 245 | time=datetime.now().strftime("%Y-%m-%d_%H-%M-%S")): 246 | model = AutoModelForCausalLM.from_pretrained(model_path, 247 | attn_implementation="flash_attention_2", 248 | torch_dtype=torch.float16, 249 | ) 250 | tokenizer = AutoTokenizer.from_pretrained(model_path) 251 | 252 | model.config.pad_token_id = model.config.eos_token_id 253 | model_name = model_path.split('/')[-1] 254 | output_dir = f"outputs/self_training_v{idx}_policy_model/{task}/{model_name}/KL_{kl_weight}_{time}" 255 | 256 | EVAL_STEP = 50 257 | 258 | training_args = TrainingArguments( 259 | 260 | output_dir=output_dir, 261 | learning_rate=2e-6, 262 | # max_grad_norm=0.1, 263 | # warmup_steps=20, 264 | warmup_steps=50, 265 | per_device_train_batch_size=1, 266 | per_device_eval_batch_size=4, 267 | num_train_epochs=1, 268 | eval_strategy="steps", 269 | 
269 |         eval_steps=EVAL_STEP,
270 |         save_steps=EVAL_STEP,
271 |         save_strategy="steps",
272 |         logging_steps=EVAL_STEP,
273 |         logging_dir=f"./self-train_logs_v{idx}/{task}/{model_name}/{time}",
274 | 
275 |         load_best_model_at_end=True,
276 |         fp16=True,
277 |         gradient_accumulation_steps=16,
278 | 
279 |         greater_is_better=False,
280 |         save_total_limit=5,
281 |         metric_for_best_model="eval_loss",
282 | 
283 |         deepspeed="deepspeed_config/policy_model_72b.json",
284 |         report_to="tensorboard",
285 | 
286 |         # gradient_checkpointing=True
287 |     )
288 | 
289 |     data_collator = SimpleDataCollator()
290 | 
291 |     trainer = CustomTrainer(
292 |         model=model,
293 |         args=training_args,
294 |         train_dataset=train_dataset,
295 |         eval_dataset=eval_dataset,
296 |         tokenizer=tokenizer,
297 |         data_collator=data_collator
298 |     )
299 |     if kl_weight > 0.0001:
300 |         trainer.compute_loss = compute_loss
301 |     else:
302 |         trainer.compute_loss = compute_loss_kl_0
303 | 
304 |     trainer.train()
305 |     return trainer
306 | 
307 | 
308 | if __name__ == '__main__':
309 |     parser = argparse.ArgumentParser()
310 |     parser.add_argument("--data_path", type=str, default="dataset/self_training_v1_policy_without_rm")
311 |     parser.add_argument("--model_path", type=str,
312 |                         default="models/Qwen2.5-Math-7B-Instruct")
313 | 
314 |     parser.add_argument("--task", type=str, default='prm')
315 |     parser.add_argument("--idx", type=int, default=1)
316 |     parser.add_argument("--kl", type=float, default=0.1)
317 | 
318 |     args = parser.parse_args()
319 |     kl_weight = args.kl
320 |     tokenizer = AutoTokenizer.from_pretrained(args.model_path, legacy=False)
321 | 
322 |     train_dataset, eval_dataset = choose_task(args.data_path, args.task)
323 |     # print("Train dataset sample 0:", train_dataset[0])
324 |     # print("Keys in sample 0:", train_dataset[0].keys())
325 |     preprocess_with_tokenizer = partial(preprocess_function, tokenizer=tokenizer)
326 | 
327 |     train_dataset = train_dataset.map(preprocess_with_tokenizer, batched=False, keep_in_memory=True)
328 |     eval_dataset = eval_dataset.map(preprocess_with_tokenizer, batched=False, keep_in_memory=True)
329 | 
330 |     train(model_path=args.model_path, train_dataset=train_dataset, idx=args.idx, eval_dataset=eval_dataset,
331 |           task=args.task)
332 | # nohup accelerate launch --config_file accelerate_config/8gpus.yaml --gpu_ids 0,1,2,3,4,5,6,7 self-training/sft_policy_model.py --task hrm --idx 1 > logs/self_training_v1_policy_model_hrm.log &
333 | # nohup accelerate launch --config_file accelerate_config/6gpus.yaml --gpu_ids 0,1,2,3,4,5 self-training/sft_policy_model.py --task prm --idx 1 > logs/self_training_v1_policy_model_prm.log &
334 | # nohup accelerate launch --config_file accelerate_config/6gpus.yaml --gpu_ids 0,1,2,3,4,5 self-training/sft_policy_model.py --task hrm --idx 1 > logs/self_training_v1_policy_model_hrm_kl_version.log &
335 | # nohup accelerate launch --config_file accelerate_config/4gpus.yaml --main_process_port 29501 --gpu_ids 0,1,2,3 self-training/sft_policy_model.py --task hrm --idx 1 --kl 10 > logs/self_training_v1_policy_model_hrm_kl_10.log &
336 | # nohup accelerate launch --config_file accelerate_config/4gpus.yaml --main_process_port 29502 --gpu_ids 4,5,6,7 self-training/sft_policy_model.py --task hrm --idx 1 --kl 0.5 > logs/self_training_v1_policy_model_hrm_kl_05.log &
337 | # nohup accelerate launch --config_file accelerate_config/8gpus.yaml --main_process_port 29502 --gpu_ids 0,1,2,3,4,5,6,7 self-training/sft_policy_model.py --task hrm --idx 1 --kl 5 > logs/self_training_v1_policy_model_hrm_kl_5.log &
338 | # nohup accelerate launch --config_file accelerate_config/8gpus.yaml --main_process_port 29500 --gpu_ids 0,1,2,3,4,5,6,7 self-training/sft_policy_model.py --task hrm --idx 1 --kl 1 > logs/self_training_v1_policy_model_hrm_kl_1.log &
339 | # nohup accelerate launch --config_file accelerate_config/6gpus.yaml --main_process_port 29501 --gpu_ids 0,1,2,3,4,5 self-training/sft_policy_model.py --task prm --idx 1 --kl 1 > logs/self_training_v1_policy_model_prm_kl_1.log &
340 | 
341 | # nohup accelerate launch --config_file accelerate_config/6gpus.yaml --main_process_port 29500 --gpu_ids 0,1,2,3,4,5 self-training/sft_policy_model.py --task hrm --idx 1 --kl 0.5 > logs/self_training_v1_policy_model_hrm_kl_05.log &
342 | # nohup accelerate launch --config_file accelerate_config/6gpus.yaml --main_process_port 29501 --gpu_ids 0,1,2,3,4,5 self-training/sft_policy_model.py --task prm --idx 1 --kl 0.5 > logs/self_training_v1_policy_model_prm_kl_05.log &
343 | # nohup accelerate launch --config_file accelerate_config/8gpus.yaml --main_process_port 29500 --gpu_ids 0,1,2,3,4,5,6,7 self-training/sft_policy_model.py --task hrm --idx 1 --kl 0 > logs/self_training_v1_policy_model_prm_kl_0.log &
344 | 
345 | 
346 | # nohup accelerate launch --config_file accelerate_config/8gpus.yaml --main_process_port 29500 --gpu_ids 0,1,2,3,4,5,6,7 self-training/sft_policy_model.py --task hrm --idx 1 --kl 0.001 > logs/self_training_v1_policy_model_hrm_kl_05_8gpus.log &
347 | 
348 | 
349 | # nohup accelerate launch --config_file accelerate_config/4gpus.yaml --main_process_port 29501 --gpu_ids 0,1,2,3 self-training/sft_policy_model.py --task hrm --idx 1 --kl 10 > logs/self_training_v1_policy_model_hrm_kl_10.log &
350 | # nohup accelerate launch --config_file accelerate_config/4gpus.yaml --main_process_port 29502 --gpu_ids 4,5,6,7 self-training/sft_policy_model.py --task hrm --idx 1 --kl 0.5 > logs/self_training_v1_policy_model_hrm_kl_05.log &
351 | # nohup accelerate launch --config_file accelerate_config/4gpus.yaml --main_process_port 29503 --gpu_ids 1,2,3,4 self-training/sft_policy_model.py --task hrm --idx 1 --kl 0.001 > logs/self_training_v1_policy_model_hrm_kl_001.log &
352 | 
--------------------------------------------------------------------------------
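Note (illustration only, not a file in this repository): compute_loss() above reloads pre-computed teacher logits from the paths stored in each example's 'logits_path' field. Below is a minimal sketch of how such per-example logit files could be produced; the teacher checkpoint path, the output directory, and the dump_teacher_logits helper are assumptions made for illustration, and the repository's own dumping script may differ.

    import os
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    teacher_path = "models/Qwen2.5-Math-72B-Instruct"  # hypothetical teacher checkpoint
    out_dir = "dataset/teacher_logits"                 # hypothetical cache directory
    os.makedirs(out_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(teacher_path)
    teacher = AutoModelForCausalLM.from_pretrained(teacher_path, torch_dtype=torch.float16, device_map="auto")
    teacher.eval()

    def dump_teacher_logits(example_id, text):
        # Tokenize without padding so the saved tensor keeps the true sequence length;
        # kl_divergence_loss_for_batch_one() truncates the student logits to match it.
        enc = tokenizer(text, return_tensors="pt").to(teacher.device)
        with torch.no_grad():
            logits = teacher(**enc).logits  # [1, seq_len, vocab_size]
        path = os.path.join(out_dir, f"{example_id}.pt")
        torch.save(logits.cpu(), path)     # later reloaded with torch.load(..., weights_only=True)
        return path                        # store this as the example's 'logits_path' field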