├── .DS_Store ├── CoT └── task.py ├── MCTS ├── args.md ├── base.py ├── mcts.py └── task.py ├── PRM ├── train_VM_chatglm.py ├── train_VM_chatglm_deepspeed.py └── train_VM_mistral.py ├── README.md ├── ToT ├── base.py ├── bfs.py ├── dfs.py └── task.py ├── assets ├── comparison.png ├── overall.png ├── results.png ├── searches.png └── vm_results.png ├── data ├── math │ └── math_500.json ├── scibench │ ├── atkins_standardized.json │ ├── calculus_standardized.json │ ├── chemmc_standardized.json │ ├── class_standardized.json │ ├── diff_standardized.json │ ├── fund_standardized.json │ ├── matter_standardized.json │ ├── quan_standardized.json │ ├── stat_standardized.json │ └── thermo_standardized.json ├── scibench_100 │ └── 100.json └── scieval │ ├── scieval_part1.json │ ├── scieval_part2.json │ ├── scieval_part3.json │ └── scieval_part4.json ├── eval_vm.py ├── evaluate.py ├── figures ├── MATH2_completion_self_train.pdf ├── data │ └── ablation_math2_self_training.xlsx └── plot_math_self_training.py ├── models ├── get_response.py ├── inference_models.py ├── model.py └── value_models.py ├── requirements_mistral.txt ├── requirements_sciglm.txt ├── self_train ├── .DS_Store ├── config │ ├── deep3_new_config.yaml │ ├── deepspeed_zero1.yaml │ ├── deepspeed_zero2.yaml │ ├── deepspeed_zero3.json │ ├── deepspeed_zero3.yaml │ ├── default_config.yaml │ └── yaml_to_json.py ├── generation │ ├── generate_both_samples_GSM.py │ ├── generate_both_samples_MATH.py │ └── generate_both_samples_TheoremQA.py ├── self_train_dpo.py └── vm_critic │ ├── filter_policy_examples_by_value.py │ ├── manual_self_critic.py │ ├── manual_vm_critic.py │ └── vm_critic_for_extracted_samples.py ├── tasks ├── prompts.py └── science.py └── utils ├── aggregate_value_samples.py ├── answer_extractor.py ├── extract_both_samples.py ├── format_dpo.py ├── json_operator.py ├── orm_score.py ├── result_evaluator.py ├── self_consistency.py ├── solution_summary_extractor.py ├── verify_MATH.py ├── verify_answer.py ├── verify_llm.py ├── visualize.py └── weighted_consistency.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/ReST-MCTS/2d5f488c3d6e24f99d50a9860e818383b1bb5883/.DS_Store -------------------------------------------------------------------------------- /CoT/task.py: -------------------------------------------------------------------------------- 1 | import re 2 | from tasks.science import SearchTask 3 | from models.get_response import * 4 | from utils.verify_MATH import exact_match_score 5 | from utils.solution_summary_extractor import extract_summary_from_solution 6 | 7 | 8 | class CoT_Task(SearchTask): 9 | def __init__(self, data, propose_method='glm', value_method='glm', temperature=0.7, max_tokens=2048, seed=170, 10 | max_length=2048, truncation=True, 11 | do_sample=True, max_new_tokens=1024, evaluate='', summary=False, lang='zh', answer=None, 12 | verify_method='string', do_self_critic=False): 13 | super().__init__(data, propose_method, value_method) 14 | self.mode = 'cot' 15 | self.temperature = temperature 16 | self.max_tokens = max_tokens 17 | self.seed = seed 18 | self.max_length = max_length 19 | self.truncation = truncation 20 | self.do_sample = do_sample 21 | self.max_new_tokens = max_new_tokens 22 | self.evaluate = evaluate 23 | self.summary = summary 24 | self.lang = lang 25 | self.answer = answer 26 | self.verify_method = verify_method 27 | self.do_self_critic = do_self_critic 28 | 29 | def get_summary(self, solution: str): 30 | if self.lang == 
'zh': 31 | if not self.summary: 32 | if "综上所述," in solution: 33 | summ = solution.split("综上所述,")[-1] 34 | return "综上所述," + summ 35 | elif '。' in solution: 36 | summ = solution.split("。")[-2] 37 | return "综上所述," + summ + '。' 38 | else: 39 | return '' 40 | else: 41 | if self.evaluate == 'scibench': 42 | prompt = self.evaluate_summary_prompt_wrap(self.question, solution) 43 | elif self.evaluate == 'scieval': 44 | prompt = self.general_evaluate_summary_prompt_wrap(self.question, solution) 45 | else: 46 | prompt = self.summary_prompt_wrap(self.question, solution) 47 | 48 | response = get_proposal(prompt, self.propose_method, self.temperature, self.max_tokens, self.seed, 49 | self.max_length, 50 | self.truncation, self.do_sample, 128) 51 | 52 | if not response: 53 | print('Get summary fail!\n') 54 | return '' 55 | p = '' 56 | for _ in response: 57 | p = p + _ + '\n' 58 | p = p.strip() 59 | 60 | if self.evaluate: 61 | if len(p) < 1: 62 | print('Get summary too short!\n') 63 | return '' 64 | 65 | if '综上所述,最终答案是:' not in p: 66 | summ = '综上所述,最终答案是:' + p 67 | print(f'Get summary:{summ}\n') 68 | return summ 69 | else: 70 | summ = '综上所述,最终答案是:' + p.split('综上所述,最终答案是:')[-1] 71 | print(f'Get summary:{summ}\n') 72 | return summ 73 | 74 | else: 75 | if len(p) < 1: 76 | print('Get summary too short!\n') 77 | return '' 78 | 79 | if '综上所述,' not in p: 80 | summ = '综上所述,' + p 81 | print(f'Get summary:{summ}\n') 82 | return summ 83 | else: 84 | summ = '综上所述,' + p.split('综上所述,')[-1] 85 | print(f'Get summary:{summ}\n') 86 | return summ 87 | else: 88 | if "Summary:" in solution: 89 | summ = solution.split("Summary:")[-1].strip() 90 | else: 91 | summ = '' 92 | return summ 93 | 94 | def get_MATH_summary(self, solution): 95 | prompt = self.MATH_summary_prompt_wrap(self.question, solution) 96 | response = get_proposal(prompt, self.propose_method, self.temperature, self.max_tokens, self.seed, 97 | self.max_length, 98 | self.truncation, self.do_sample, 128) 99 | if not response: 100 | print('Get summary fail!\n') 101 | return '' 102 | p = '' 103 | for _ in response: 104 | p = p + _ + '\n' 105 | p = p.strip() 106 | 107 | print(f'Get summary:{p}\n') 108 | return p 109 | 110 | def get_self_critic(self, solution): 111 | critic_prompt = self.self_critic_prompt_wrap(self.question, solution) 112 | output_score = get_proposal(critic_prompt, self.propose_method, self.temperature, self.max_tokens, self.seed, 113 | self.max_length, self.truncation, self.do_sample, 128) 114 | score_strs = '' 115 | for out in output_score: 116 | score_strs = score_strs + out + '\n' 117 | 118 | pattern = r'[0-9]+\.?[0-9]*' 119 | match = re.findall(pattern, score_strs) 120 | if not match: 121 | return None 122 | else: 123 | s = min(float(match[-1]), 1.0) 124 | s = max(s, 0) 125 | return s 126 | 127 | def run(self): 128 | self.clear_cache() 129 | if self.evaluate == 'math' or self.verify_method == 'string': 130 | prompt = self.cot_prompt_wrap(self.question, self.lang, True) 131 | else: 132 | prompt = self.cot_prompt_wrap(self.question, self.lang) 133 | out = get_proposal(prompt, self.propose_method, temperature=self.temperature, 134 | max_tokens=self.max_tokens, 135 | seed=self.seed, max_length=self.max_length, truncation=self.truncation, 136 | do_sample=self.do_sample, max_new_tokens=self.max_new_tokens) 137 | solution = '' 138 | for _ in out: 139 | solution = solution + _ + '\n' 140 | solution = solution.strip() 141 | print(f'Get answers:{solution}\n') 142 | 143 | if self.evaluate == 'math' or self.verify_method == 'string': 144 | cnt = 5 145 | 
summary = '' 146 | while cnt and not summary: 147 | summary = self.get_MATH_summary(solution) 148 | cnt -= 1 149 | 150 | if not summary: 151 | summary = extract_summary_from_solution(solution) 152 | 153 | result = exact_match_score(summary, self.answer) 154 | output = {'content': self.question, 'solution': solution, 'summary': summary, 'accurate': result, 155 | 'real_answer': self.answer} 156 | 157 | else: 158 | cnt = 5 159 | summary = '' 160 | while cnt: 161 | summary = self.get_summary(solution) 162 | if summary: 163 | break 164 | else: 165 | cnt -= 1 166 | 167 | output = {'content': self.question, 'solution': solution, 'summary': summary} 168 | 169 | if self.do_self_critic: 170 | score = None 171 | cnt = 3 172 | while score is None and cnt: 173 | score = self.get_self_critic(solution) 174 | cnt -= 1 175 | if score is None: 176 | score = 0 177 | output.update({'self_critic': score}) 178 | 179 | return output 180 | -------------------------------------------------------------------------------- /MCTS/args.md: -------------------------------------------------------------------------------- 1 | # Parameter explanation 2 | Here we introduce the main arguments of the `MCTS*` algorithm. 3 | 4 | 1. temperature: Sampling temperature, which controls the randomness of generated responses. 5 | 6 | 2. time_limit: The upper limit of the search time (in ms) set in the MCTS framework. 7 | 8 | 3. iteration_limit: The maximum number of search rounds for exploration. 9 | 10 | 4. roll_policy: The strategy for Monte Carlo simulation in the MCTS framework, either random or greedy. 11 | 12 | 5. exploration_constant: The constant in the UCT formula that balances exploration and exploitation. 13 | 14 | 6. roll_forward_steps: The number of forward steps in the simulation process. 15 | 16 | 7. end_gate: The value threshold a node must reach for the search to be considered finished. 17 | 18 | 8. branch: The number of branches for node expansion. 19 | 20 | 9. roll_branch: The number of branches to sample at each simulation step. 21 | 22 | 10. inf: The base value of an unvisited node. 23 | 24 | 11. alpha: The weight given to the simulation result when updating a node's value. 25 | 26 | 12. visualize: Whether the results are visualized in a tree diagram. 27 | 28 | 13. use_case_prompt: Whether to use example-output (case) prompts to assist generation. 29 | 30 | 14. use_reflection: Whether to use the reflection mechanism. 31 | 32 | 15. low: The lower bound of the node value. 33 | 34 | 16. high: The upper bound of the node value.
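
To make the interaction between several of these arguments concrete, here is a small, self-contained sketch (not code from this repository) that mirrors the UCT-style selection in `getBestChild` and the simulation-value update in `executeRound` from `MCTS/mcts.py`. The class and function names and the numeric defaults (`exploration_constant=0.4`, `inf=1.0`, `alpha=0.5`) are illustrative only.

```python
import math
import random


# Simplified stand-in for treeNode (MCTS/base.py); only the fields used during selection.
class Node:
    def __init__(self, value=0.0, visits=0):
        self.V = value          # value estimate, kept within [low, high]
        self.numVisits = visits
        self.children = []


def uct_select(parent, exploration_constant=0.4, inf=1.0):
    # Mirrors getBestChild: unvisited children get V + inf so they are tried first;
    # visited children are ranked by V plus the UCT exploration bonus.
    def score(child):
        if child.numVisits == 0:
            return child.V + inf
        return child.V + exploration_constant * math.sqrt(
            2 * math.log(parent.numVisits) / child.numVisits)

    best = max(score(c) for c in parent.children)
    return random.choice([c for c in parent.children if score(c) == best])


def blend_value(old_value, simulated_value, alpha=0.5):
    # Mirrors the update in executeRound: roll_node.V = roll_node.V * (1 - alpha) + best_V * alpha
    return old_value * (1 - alpha) + simulated_value * alpha


parent = Node(visits=3)
parent.children = [Node(0.3, 2), Node(0.2, 1), Node(0.0, 0)]
print(uct_select(parent).V)   # 0.0 -- the unvisited child wins through the `inf` bonus
print(blend_value(0.5, 0.9))  # 0.7
```

In general, a larger `exploration_constant` (or `inf`) shifts selection toward less-visited branches, while a larger `alpha` makes node values track simulation results more aggressively.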
-------------------------------------------------------------------------------- /MCTS/base.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | 4 | 5 | class treeNode(object): 6 | def __init__(self, pcd, parent=None, depth=0): 7 | self.pcd = pcd # str 8 | self.y = '' # str 9 | self.parent = parent # treeNode 10 | self.numVisits = 0 # int 11 | self.V = 0 # float 12 | self.children = {} # dict{str:treeNode} 13 | self.depth = depth # int 14 | self.isFullyExpanded = False # expanded 15 | self.visit_sequence = 0 16 | self.final_ans_flag = 0 17 | self.reflection = '' 18 | self.isTerminal = False # value acceptable 19 | self.on_final_route = False 20 | self.min_steps_to_correct = 1024 21 | self.summary = '' 22 | self.he = 0 # hard estimation 23 | self.se = 0 # soft estimation 24 | 25 | def __str__(self): 26 | s = ["numVisits: %d" % self.numVisits, f'V:{self.V}', "possibleActions: %s" % (self.children.keys())] 27 | return "%s: {%s}" % (self.__class__.__name__, ', '.join(s)) 28 | 29 | def append_children(self, new_pcd: str): 30 | node = treeNode(new_pcd, self, self.depth + 1) 31 | node.update_y_from_parent() 32 | self.children.update({new_pcd: node}) 33 | return self 34 | 35 | def update_y_from_parent(self): 36 | if self.parent is None: 37 | self.y = self.pcd 38 | else: 39 | self.y = self.parent.y + self.pcd 40 | 41 | def update_value(self, value): 42 | self.V = value 43 | 44 | def update_reflection(self, reflection): 45 | self.reflection = reflection 46 | 47 | def getBestV(self): # Gets the subtree maximum value node 48 | if not self.isFullyExpanded: 49 | return self, self.V 50 | max_V = self.V 51 | max_node = self 52 | for child in self.children.values(): 53 | subNode, subValue = child.getBestV() 54 | if subValue >= max_V: 55 | max_V = subValue 56 | max_node = subNode 57 | return max_node, max_V 58 | 59 | def trace_route(self): # trace route from terminal node to root 60 | cur_node = self 61 | while cur_node is not None: 62 | cur_node.on_final_route = True 63 | cur_node = cur_node.parent 64 | 65 | def get_new_value_samples(self): # get value samples from search tree (start from terminal node) 66 | if self.depth == 0: 67 | return [] 68 | step_value = 1.0 / self.depth 69 | new_samples = [] 70 | cur_node = self.parent 71 | while cur_node is not None: 72 | for child in cur_node.children.values(): 73 | if child.on_final_route: 74 | child_value = step_value * child.depth 75 | new_item = {'steps': child.y, 'value': child_value} 76 | new_samples.append(new_item) 77 | else: 78 | child_value = max(step_value * (cur_node.depth - 1), 0) 79 | new_item = {'steps': child.y, 'value': child_value} 80 | new_samples.append(new_item) 81 | cur_node = cur_node.parent 82 | return new_samples 83 | 84 | def get_all_end_root_nodes_vm(self, end_gate): 85 | end_nodes = [] 86 | if self.isFullyExpanded: 87 | for child in self.children.values(): 88 | end_nodes.extend(child.get_all_end_root_nodes_vm(end_gate)) 89 | return end_nodes 90 | else: 91 | if self.V >= end_gate or self.reflection == '': 92 | return [self] 93 | else: 94 | return [] 95 | 96 | def get_all_end_root_nodes_prm(self): 97 | end_nodes = [] 98 | if self.isFullyExpanded: 99 | for child in self.children.values(): 100 | end_nodes.extend(child.get_all_end_root_nodes_prm()) 101 | return end_nodes 102 | else: 103 | if self.reflection == '': 104 | return [self] 105 | else: 106 | return [] 107 | 108 | def get_all_value_samples_vm(self): 109 | full_value_samples = [] 110 | if self.depth == 0: 111 
| self.V = 0 112 | else: 113 | if self.he == 0: 114 | r = -1 115 | else: 116 | r = 1 117 | self.V = max(0, (1 - self.parent.V) * r / self.min_steps_to_correct + self.parent.V) 118 | full_value_samples.append({'steps': self.y, 'value': self.V}) 119 | if self.isFullyExpanded: 120 | for child in self.children.values(): 121 | if child.min_steps_to_correct < 1024: 122 | sub_samples = child.get_all_value_samples_vm() 123 | full_value_samples.extend(sub_samples) 124 | return full_value_samples 125 | 126 | def get_full_value_samples_vm(self, end_leaf_nodes): 127 | for leaf in end_leaf_nodes: 128 | if leaf.min_steps_to_correct > 1: 129 | continue 130 | else: 131 | leaf.he = 1 132 | cur_node = leaf.parent 133 | while cur_node is not None: 134 | cur_node.min_steps_to_correct = min( 135 | [n.min_steps_to_correct for n in cur_node.children.values()]) + 1 136 | cur_node.he = 1 137 | cur_node = cur_node.parent 138 | for leaf in end_leaf_nodes: 139 | if leaf.min_steps_to_correct > 1: 140 | cur_node = leaf.parent 141 | while cur_node is not None and cur_node.min_steps_to_correct == 1024: 142 | cur_node = cur_node.parent 143 | if cur_node is None: 144 | continue 145 | else: 146 | m = cur_node.min_steps_to_correct 147 | cur_node = leaf 148 | while cur_node.min_steps_to_correct == 1024: 149 | cur_node.min_steps_to_correct = m 150 | cur_node = cur_node.parent 151 | else: 152 | continue 153 | value_samples = self.get_all_value_samples_vm() 154 | return value_samples 155 | 156 | def get_all_value_samples_prm(self): 157 | full_value_samples = [] 158 | if self.on_final_route: 159 | full_value_samples.append({'steps': self.y, 'value': self.he}) 160 | if self.isFullyExpanded: 161 | for child in self.children.values(): 162 | if child.on_final_route: 163 | sub_samples = child.get_all_value_samples_prm() 164 | full_value_samples.extend(sub_samples) 165 | return full_value_samples 166 | else: 167 | return [] 168 | 169 | def get_full_value_samples_prm(self, end_leaf_nodes): 170 | for leaf in end_leaf_nodes: 171 | cur_node = leaf.parent 172 | while cur_node is not None: 173 | cur_node.on_final_route = True 174 | cur_node = cur_node.parent 175 | for leaf in end_leaf_nodes: 176 | cur_node = leaf.parent 177 | while cur_node is not None: 178 | cur_node.he = max([n.he for n in cur_node.children.values() if n.on_final_route]) 179 | cur_node = cur_node.parent 180 | value_samples = self.get_all_value_samples_prm() 181 | return value_samples 182 | -------------------------------------------------------------------------------- /MCTS/mcts.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import random 4 | import numpy 5 | from functools import partial 6 | import copy 7 | from MCTS.base import treeNode 8 | 9 | 10 | def get_next_steps_roll(y: str, step_n: int, mcts_task): 11 | next_steps = [] 12 | for i in range(mcts_task.roll_branch): 13 | proposal = '' 14 | cnt = 3 15 | while not proposal and cnt: 16 | proposal = mcts_task.get_next_step(y, step_n) 17 | cnt -= 1 18 | if not proposal: 19 | continue 20 | next_steps.append(proposal) 21 | return next_steps 22 | 23 | 24 | def get_next_steps_expand(node: treeNode, mcts_task): 25 | next_steps = [] 26 | reflection = node.reflection 27 | for i in range(mcts_task.branch): 28 | proposal = '' 29 | cnt = 3 30 | while not proposal and cnt: 31 | if mcts_task.use_reflection == 'common': 32 | proposal = mcts_task.get_next_step_use_reflection(node.y, node.depth + 1, reflection) 33 | else: 34 | proposal = 
mcts_task.get_next_step(node.y, node.depth + 1) 35 | cnt -= 1 36 | if not proposal: 37 | continue 38 | next_steps.append(proposal) 39 | return next_steps 40 | 41 | 42 | def randomPolicy(node: treeNode, mcts_task): 43 | max_V = mcts_task.low 44 | strs = node.y 45 | cur_step = node.depth + 1 46 | if mcts_task.use_reflection == 'common': 47 | reflection = mcts_task.get_reflection(strs, cur_step) 48 | else: 49 | reflection = mcts_task.get_simple_reflection(strs, cur_step) 50 | node.update_reflection(reflection) 51 | if reflection == '': 52 | print('This step has been resolved and does not require simulation.\n') 53 | return node.V 54 | for i in range(mcts_task.roll_forward_steps): 55 | next_steps = get_next_steps_roll(strs, cur_step, mcts_task) 56 | if not next_steps: 57 | break 58 | action = random.choice(next_steps) # str 59 | strs = strs + action 60 | cur_step += 1 61 | value = mcts_task.get_step_value(strs) 62 | if value > max_V: 63 | max_V = value 64 | if mcts_task.use_reflection == 'common': 65 | cur_ref = mcts_task.get_reflection(strs, cur_step) 66 | else: 67 | cur_ref = mcts_task.get_simple_reflection(strs, cur_step) 68 | if cur_ref == '': 69 | break 70 | return max_V 71 | 72 | 73 | def greedyPolicy(node: treeNode, mcts_task): 74 | max_V = mcts_task.low 75 | strs = node.y 76 | cur_step = node.depth + 1 77 | if mcts_task.use_reflection == 'common': 78 | reflection = mcts_task.get_reflection(strs, cur_step) 79 | else: 80 | reflection = mcts_task.get_simple_reflection(strs, cur_step) 81 | node.update_reflection(reflection) 82 | if reflection == '': 83 | print('This step has been resolved and does not require simulation.\n') 84 | return node.V 85 | for i in range(mcts_task.roll_forward_steps): 86 | actions = get_next_steps_roll(strs, cur_step, mcts_task) # str_list 87 | if not actions: 88 | break 89 | new_ys = [strs + action for action in actions] 90 | cur_step += 1 91 | values = [mcts_task.get_step_value(new_y) for new_y in new_ys] 92 | idx = numpy.argmax(values) 93 | strs = new_ys[idx] 94 | value = values[idx] 95 | if value > max_V: 96 | max_V = value 97 | if mcts_task.use_reflection == 'common': 98 | cur_ref = mcts_task.get_reflection(strs, cur_step) 99 | else: 100 | cur_ref = mcts_task.get_simple_reflection(strs, cur_step) 101 | if cur_ref == '': 102 | break 103 | return max_V 104 | 105 | 106 | def MCTS_search(mcts_task): 107 | root = treeNode('') 108 | 109 | if mcts_task.limit_type == 'time': 110 | timeLimit = time.time() + mcts_task.time_limit / 1000 111 | time_start = time.time() 112 | while time.time() < timeLimit: 113 | print(f'<开始新搜索轮次,目前总时间:{time.time() - time_start}>\n') 114 | flag, node, root = executeRound(root, mcts_task) 115 | if flag: 116 | print('已找到解决方案!\n') 117 | return root, node, time.time() - time_start 118 | else: 119 | for i in range(mcts_task.iteration_limit): 120 | print(f'<开始新搜索轮次,目前已完成轮次数:{i}>\n') 121 | flag, node, root = executeRound(root, mcts_task) 122 | if flag: 123 | print('已找到解决方案!\n') 124 | return root, node, i + 1 125 | return root, None, None 126 | 127 | 128 | def executeRound(root, mcts_task): 129 | # execute a selection-expansion-simulation-backpropagation round 130 | 131 | print('-' * 40) 132 | print('选择节点阶段\n') 133 | flag, node = selectNode(root, mcts_task) 134 | if flag: 135 | if mcts_task.sample_value != 'full': 136 | return True, node, root 137 | else: 138 | node.reflection = '' 139 | 140 | print('-' * 40) 141 | print('扩充阶段\n') 142 | if node.reflection == '': 143 | print('跳过此阶段。\n') 144 | else: 145 | node = expand(node, mcts_task) 146 | 147 | if 
mcts_task.reward_model_type == 'vm': 148 | print('-' * 40) 149 | print('模拟搜索阶段\n') 150 | if node.reflection == '': 151 | print('跳过此阶段。\n') 152 | else: 153 | roll_node = getBestChild(node, mcts_task) 154 | best_V = greedyPolicy(roll_node, mcts_task) if mcts_task.roll_policy == 'greedy' else randomPolicy(roll_node, 155 | mcts_task) 156 | roll_node.V = roll_node.V * (1 - mcts_task.alpha) + best_V * mcts_task.alpha 157 | roll_node.numVisits += 1 158 | 159 | print('-' * 40) 160 | print('反向传播阶段\n') 161 | back_propagate(node) 162 | return False, node, root 163 | 164 | 165 | def isTerminal(node, mcts_task): 166 | if mcts_task.reward_model_type == 'vm': 167 | return node.V >= mcts_task.end_gate 168 | else: 169 | return False 170 | 171 | 172 | def selectNode(node, mcts_task): 173 | while node.isFullyExpanded: 174 | node = getBestChild(node, mcts_task) 175 | if isTerminal(node, mcts_task): 176 | node.final_ans_flag = 1 177 | return True, node 178 | else: 179 | return False, node 180 | 181 | 182 | def expand(node: treeNode, mcts_task): 183 | if not node.reflection: 184 | if mcts_task.use_reflection == 'common': 185 | reflection = mcts_task.get_reflection(node.y, node.depth + 1) 186 | else: # simple 187 | reflection = mcts_task.get_simple_reflection(node.y, node.depth + 1) 188 | node.update_reflection(reflection) 189 | if node.reflection == '': 190 | return node 191 | actions = get_next_steps_expand(node, mcts_task) 192 | if not actions: 193 | node.update_reflection('') 194 | return node 195 | 196 | for action in actions: 197 | if action not in node.children.keys(): 198 | node.append_children(action) 199 | child = node.children[action] 200 | value = mcts_task.get_step_value(child.y) 201 | child.update_value(value) 202 | if mcts_task.sample_value == 'full': 203 | if mcts_task.use_reflection == 'common': 204 | child.update_reflection(mcts_task.get_reflection(child.y, child.depth + 1)) 205 | else: 206 | child.update_reflection(mcts_task.get_simple_reflection(child.y, child.depth + 1)) 207 | child.visit_sequence = mcts_task.node_count 208 | mcts_task.update_count() 209 | node.isFullyExpanded = True 210 | return node 211 | 212 | 213 | def back_propagate(node): 214 | while node is not None: 215 | node.numVisits += 1 216 | if node.isFullyExpanded: 217 | child_Vs = [child.V * child.numVisits for child in node.children.values()] 218 | total_num_visits = sum([child.numVisits for child in node.children.values()]) 219 | if total_num_visits > 0: 220 | node.V = sum(child_Vs) / total_num_visits 221 | node = node.parent 222 | 223 | 224 | def getBestChild(node, mcts_task): 225 | bestValue = mcts_task.low 226 | bestNodes = [] 227 | for child in node.children.values(): 228 | nodeValue = child.V + mcts_task.exploration_constant * math.sqrt( 229 | 2 * math.log(node.numVisits) / child.numVisits) if child.numVisits > 0 else child.V + mcts_task.INF 230 | if nodeValue > bestValue: 231 | bestValue = nodeValue 232 | bestNodes = [child] 233 | elif nodeValue == bestValue: 234 | bestNodes.append(child) 235 | return random.choice(bestNodes) 236 | 237 | 238 | def MCTS(mcts_task): 239 | root, node, finish = MCTS_search(mcts_task) 240 | 241 | if mcts_task.sample_value == 'full': 242 | print('采样完成。\n') 243 | return None, -1, root 244 | else: 245 | if mcts_task.reward_model_type == 'vm': 246 | if finish is not None: 247 | print(f'已找到最终解!\nSolution:{node.y}\n') 248 | return node, finish, root 249 | 250 | else: 251 | best_node, best_V = root.getBestV() 252 | print(f'在规定时间/轮次内未找到满足要求价值的解答,采用最高价值价值解答代替。\nSolution:{best_node.y}\n') 253 | 
return best_node, -1, root 254 | else: 255 | print('尚未支持解答选择,采样结束。\n') 256 | return None, -1, root 257 | -------------------------------------------------------------------------------- /PRM/train_VM_chatglm.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['CUDA_VISIBLE_DEVICES']='1' 3 | import json 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | from transformers import AutoModel, AutoTokenizer, AdamW 9 | from sklearn.metrics import accuracy_score 10 | from torch.utils.data import DataLoader, Dataset 11 | import pandas as pd 12 | 13 | max_length = 1024 14 | 15 | # Load the pre-trained ChatGLM3-6b model and tokenizer 16 | tokenizer = AutoTokenizer.from_pretrained("/workspace/ckpt/chatglm3-6b", trust_remote_code=True) 17 | base_model = AutoModel.from_pretrained("/workspace/ckpt/chatglm3-6b", 18 | trust_remote_code=True).bfloat16().cuda() 19 | 20 | # Custom Dataset class 21 | class MyDataset(Dataset): 22 | def __init__(self, data_js, tokenizer): 23 | self.data_js = data_js 24 | self.tokenizer = tokenizer 25 | 26 | def __len__(self): 27 | return len(self.data_js) 28 | 29 | def __getitem__(self, idx): 30 | prompt_answer = self.data_js[idx]['prompt_answer'] 31 | label = self.data_js[idx]['label'] 32 | 33 | encoded_pair = self.tokenizer.encode_plus( 34 | prompt_answer, 35 | padding='max_length', 36 | max_length=max_length, # Set the max length 37 | truncation=True, 38 | return_tensors='pt', # Return PyTorch Tensor format 39 | ) 40 | 41 | return { 42 | 'input_ids': encoded_pair['input_ids'].squeeze(), 43 | 'attention_mask': encoded_pair['attention_mask'].squeeze(), 44 | 'label': label 45 | } 46 | 47 | 48 | class ChatGLM_VM(nn.Module): 49 | def __init__(self, base, vocab_size, num_classes=1): 50 | super(ChatGLM_VM, self).__init__() 51 | self.base_model = base 52 | self.LN = nn.Linear(vocab_size, num_classes, dtype=torch.bfloat16) 53 | 54 | def forward(self, input_ids, attention_mask): 55 | outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask).logits[:, -1] 56 | value_outputs = self.LN(outputs) 57 | return value_outputs.squeeze(dim=1) 58 | 59 | 60 | # Load training set, validation set, and test set data 61 | train_js = 'data/train_en.json' 62 | test_js = 'data/test_en.json' 63 | val_js = 'data/valid_en.json' 64 | 65 | 66 | def read_json(source): 67 | json_list = [] 68 | with open(source, 'r', encoding='utf-8') as f: 69 | for line in f: 70 | json_list.append(json.loads(line)) 71 | return json_list 72 | 73 | 74 | train_json = read_json(train_js) # This section uses a CSV file as an example to describe how to load data 75 | val_json = read_json(val_js) 76 | test_json = read_json(test_js) 77 | 78 | # Create a custom dataset 79 | train_dataset = MyDataset(train_json, tokenizer) 80 | val_dataset = MyDataset(val_json, tokenizer) 81 | test_dataset = MyDataset(test_json, tokenizer) 82 | 83 | # Create data loaders 84 | batch_size = 3 # Set batch size 85 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 86 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size) 87 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size) 88 | 89 | # Set device and model 90 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 91 | print(device, '\n') 92 | vocab_size = base_model.config.padded_vocab_size 93 | print(vocab_size) 94 | VM = ChatGLM_VM(base_model, vocab_size, 1) 95 | 96 | VM.to(device) 97 | 98 | # Define loss 
function and optimizer 99 | criterion = nn.MSELoss() 100 | optimizer = AdamW(VM.parameters(), lr=2e-5) 101 | num_epochs = 2 # or 3 102 | # Training and validation loop 103 | best_val_loss = 10000000 104 | train_losses = [] 105 | val_losses = [] 106 | for epoch in range(num_epochs): 107 | print(f"{epoch}/{num_epochs} training") 108 | # Training 109 | VM.train() 110 | train_loss = 0.0 111 | for batch in tqdm(train_dataloader): 112 | input_ids = batch['input_ids'].to(device) 113 | attention_mask = batch['attention_mask'].to(device) 114 | labels = batch['label'].bfloat16().to(device) 115 | 116 | optimizer.zero_grad() 117 | outputs = VM(input_ids=input_ids, attention_mask=attention_mask) 118 | loss = criterion(outputs, labels) 119 | loss.backward() 120 | optimizer.step() 121 | 122 | train_loss += loss.item() 123 | 124 | avg_train_loss = train_loss / len(train_dataloader) 125 | train_losses.append(avg_train_loss) 126 | 127 | # Validation 128 | VM.eval() 129 | val_loss = 0.0 130 | val_labels = [] 131 | with torch.no_grad(): 132 | for batch in tqdm(val_dataloader): 133 | input_ids = batch['input_ids'].to(device) 134 | attention_mask = batch['attention_mask'].to(device) 135 | labels = batch['label'].bfloat16().to(device) 136 | outputs = VM(input_ids=input_ids, attention_mask=attention_mask) 137 | loss = criterion(outputs, labels) 138 | val_loss += loss.item() 139 | val_labels.extend(labels.tolist()) 140 | 141 | avg_val_loss = val_loss / len(val_dataloader) 142 | val_losses.append(avg_val_loss) 143 | 144 | print(f"Epoch [{epoch + 1}/{num_epochs}]") 145 | print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} ") 146 | 147 | # Save best model 148 | if avg_val_loss < best_val_loss: 149 | best_val_loss = avg_val_loss 150 | torch.save(VM.state_dict(), "records/Chatglm/VM_best_checkpoint.pt") 151 | 152 | print("Training complete!") 153 | 154 | # import matplotlib.pyplot as plt 155 | # 156 | # epochs = range(1, num_epochs + 1) 157 | # plt.plot(epochs, train_losses, 'r', label='Training Loss') 158 | # plt.plot(epochs, val_losses, 'b', label='Validation Loss') 159 | # plt.xlabel('Epochs') 160 | # plt.ylabel('Loss') 161 | # plt.legend() 162 | # plt.savefig('VM_sensitive.png') 163 | # 164 | # plt.show() 165 | 166 | # Load the best model for inference 167 | best_model = ChatGLM_VM(base_model, vocab_size, 1) 168 | best_model.load_state_dict(torch.load("records/Chatglm/VM_best_checkpoint.pt")) 169 | best_model.to(device) 170 | best_model.eval() 171 | 172 | # Perform inference 173 | test_preds = [] 174 | test_labels = [] 175 | with torch.no_grad(): 176 | for batch in tqdm(test_dataloader): 177 | input_ids = batch['input_ids'].to(device) 178 | attention_mask = batch['attention_mask'].to(device) 179 | labels = batch['label'].bfloat16().to(device) 180 | outputs = best_model(input_ids=input_ids, attention_mask=attention_mask) 181 | test_preds.extend(outputs.tolist()) 182 | test_labels.extend(labels.tolist()) 183 | print("Inference results:") 184 | for i in range(len(test_preds)): 185 | print(f"Sample {i + 1}: Predicted score {test_preds[i]}, Actual score {test_labels[i]}, Truncated score {min(max(test_preds[i],0),1)}") 186 | 187 | cnt = 0 188 | for i in range(len(test_preds)): 189 | if abs(min(max(test_preds[i],0),1) - test_labels[i]) <= 0.1: 190 | cnt += 1 191 | test_acc = cnt / len(test_preds) 192 | print(f"Test accuracy: {test_acc:.4f}") 193 | ``` -------------------------------------------------------------------------------- /PRM/train_VM_chatglm_deepspeed.py: 
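Before the DeepSpeed variant of this trainer (the next file), here is a brief sketch of how the best checkpoint saved above might be loaded and used to score a partial solution at inference time. This is an illustrative helper, not code from the repository: `score_solution` is a made-up name, and it assumes `tokenizer`, `base_model`, `vocab_size`, `device` and `ChatGLM_VM` are set up as in `PRM/train_VM_chatglm.py` above. The prompt layout follows the `'Problem: ... \nSolution:\n ...'` convention used by the English branch of `get_step_value` in `ToT/task.py`, and the final clamp matches the truncated score computed in the test loop.

```python
import torch

def score_solution(question: str, partial_solution: str, model, tokenizer, device,
                   max_length: int = 1024) -> float:
    # Same "Problem ... Solution ..." layout that get_step_value (ToT/task.py) feeds
    # to a locally hosted value model.
    prompt_answer = 'Problem: ' + question + '\nSolution:\n' + partial_solution
    encoded = tokenizer.encode_plus(prompt_answer, padding='max_length',
                                    max_length=max_length, truncation=True,
                                    return_tensors='pt')
    with torch.no_grad():
        pred = model(input_ids=encoded['input_ids'].to(device),
                     attention_mask=encoded['attention_mask'].to(device))
    # The regression head is unbounded, so clamp to [0, 1] as in the test loop above.
    return min(max(pred.item(), 0.0), 1.0)

# Example usage with the checkpoint written by the training loop above:
# vm = ChatGLM_VM(base_model, vocab_size, 1)
# vm.load_state_dict(torch.load("records/Chatglm/VM_best_checkpoint.pt"))
# vm.to(device).eval()
# print(score_solution("Evaluate the series S = 1/2 + 1/4 + ...", "Step 1: ...", vm, tokenizer, device))
```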
-------------------------------------------------------------------------------- 1 | # import debugpy; debugpy.connect(('100.98.26.69', 5690)) 2 | import argparse 3 | import os 4 | os.environ['CUDA_VISIBLE_DEVICES']='0,1,2,3' 5 | import json 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from tqdm import tqdm 10 | from transformers import AutoModel, AutoTokenizer, AdamW 11 | from sklearn.metrics import accuracy_score 12 | from torch.utils.data import DataLoader, Dataset 13 | import pandas as pd 14 | import time 15 | import deepspeed 16 | import torch.distributed as dist 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--debug', action='store_true', help='Enable debug mode') 20 | parser.add_argument('--local_rank', type=int, default=-1, help='Local rank for distributed training') 21 | args = parser.parse_args() 22 | 23 | if 'LOCAL_RANK' in os.environ: 24 | args.local_rank = int(os.environ['LOCAL_RANK']) 25 | print('>>>args.local_rank=', args.local_rank) 26 | 27 | 28 | deepspeed.init_distributed() 29 | 30 | max_length = 1024 31 | 32 | # Load the pre-trained ChatGLM3-6b model and tokenizer 33 | tokenizer = AutoTokenizer.from_pretrained("/data/llms/chatglm3-6b", trust_remote_code=True) 34 | base_model = AutoModel.from_pretrained("/data/llms/chatglm3-6b", 35 | trust_remote_code=True).bfloat16() 36 | 37 | # Custom Dataset class 38 | class MyDataset(Dataset): 39 | def __init__(self, data_js, tokenizer): 40 | self.data_js = data_js 41 | self.tokenizer = tokenizer 42 | 43 | def __len__(self): 44 | return len(self.data_js) 45 | 46 | def __getitem__(self, idx): 47 | prompt_answer = self.data_js[idx]['prompt_answer'] 48 | label = self.data_js[idx]['label'] 49 | 50 | encoded_pair = self.tokenizer.encode_plus( 51 | prompt_answer, 52 | padding='max_length', 53 | max_length=max_length, # Set the max length 54 | truncation=True, 55 | return_tensors='pt', # Return PyTorch Tensor format 56 | ) 57 | 58 | return { 59 | 'input_ids': encoded_pair['input_ids'].squeeze(), 60 | 'attention_mask': encoded_pair['attention_mask'].squeeze(), 61 | 'label': label 62 | } 63 | 64 | 65 | class ChatGLM_VM(nn.Module): 66 | def __init__(self, base, vocab_size, num_classes=1): 67 | super(ChatGLM_VM, self).__init__() 68 | self.base_model = base 69 | self.LN = nn.Linear(vocab_size, num_classes, dtype=torch.bfloat16) 70 | 71 | def forward(self, input_ids, attention_mask): 72 | outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask).logits[:, -1] 73 | value_outputs = self.LN(outputs) 74 | return value_outputs.squeeze(dim=1) 75 | 76 | 77 | # Load training set, validation set, and test set data 78 | train_js = '/data/ReST-MCTS-PRM-0th/train_en.json' 79 | test_js = '/data/ReST-MCTS-PRM-0th/test_en.json' 80 | val_js = '/data/ReST-MCTS-PRM-0th/valid_en.json' 81 | 82 | 83 | def read_json(source): 84 | json_list = [] 85 | with open(source, 'r', encoding='utf-8') as f: 86 | for line in f: 87 | json_list.append(json.loads(line)) 88 | return json_list 89 | 90 | 91 | train_json = read_json(train_js) # This section uses a CSV file as an example to describe how to load data 92 | val_json = read_json(val_js) 93 | test_json = read_json(test_js) 94 | 95 | if args.debug: 96 | print(">>>Debug mode: Using only toy training/val/test samples") 97 | train_json = train_json[:2] 98 | val_json = val_json[:2] 99 | test_json = test_json[:2] 100 | 101 | # Create a custom dataset 102 | train_dataset = MyDataset(train_json, tokenizer) 103 | val_dataset = MyDataset(val_json, tokenizer) 104 | 
test_dataset = MyDataset(test_json, tokenizer) 105 | 106 | # Create data loaders 107 | batch_size = 2 # 3 # Set batch size 108 | gradient_accumulation_steps = 4 109 | 110 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 111 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler) 112 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size) 113 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size) 114 | 115 | # Set device and model 116 | device = torch.device("cuda", args.local_rank) 117 | print(device, '\n') 118 | vocab_size = base_model.config.padded_vocab_size 119 | print(vocab_size) 120 | VM = ChatGLM_VM(base_model, vocab_size, 1) 121 | 122 | 123 | ds_config = { 124 | "train_micro_batch_size_per_gpu": batch_size, 125 | "gradient_accumulation_steps": gradient_accumulation_steps, 126 | "train_batch_size": batch_size * torch.cuda.device_count() * gradient_accumulation_steps, 127 | 128 | "bf16": { 129 | "enabled": True 130 | }, 131 | 132 | "zero_optimization": { 133 | "stage": 1, 134 | "allgather_partitions": True, 135 | "reduce_scatter": True 136 | }, 137 | # "zero_optimization": { 138 | # "stage": 0, 139 | # }, 140 | "zero_allow_untested_optimizer": True, 141 | 142 | "optimizer": { 143 | "type": "AdamW", 144 | "params": { 145 | "lr": 2e-5 146 | } 147 | } 148 | } 149 | 150 | 151 | 152 | 153 | model_engine, optimizer, _, _ = deepspeed.initialize( 154 | args=args, 155 | model=VM, 156 | model_parameters=VM.parameters(), 157 | config=ds_config 158 | ) 159 | 160 | # Define loss function 161 | criterion = nn.MSELoss() 162 | num_epochs = 3 # or 3 163 | # Training and validation loop 164 | best_val_loss = float('inf') 165 | train_losses = [] 166 | val_losses = [] 167 | 168 | train_start_time = time.time() 169 | for epoch in range(num_epochs): 170 | if args.local_rank == 0: 171 | print(f"{epoch}/{num_epochs} training") 172 | # Training 173 | model_engine.train() 174 | train_loss = 0.0 175 | train_sampler.set_epoch(epoch) 176 | 177 | for batch in tqdm(train_dataloader, disable=args.local_rank != 0): 178 | input_ids = batch['input_ids'].to(device) 179 | attention_mask = batch['attention_mask'].to(device) 180 | labels = batch['label'].bfloat16().to(device) 181 | 182 | outputs = model_engine(input_ids=input_ids, attention_mask=attention_mask) 183 | loss = criterion(outputs, labels) 184 | 185 | model_engine.backward(loss) 186 | model_engine.step() 187 | 188 | train_loss += loss.item() 189 | 190 | avg_train_loss = train_loss / len(train_dataloader) 191 | train_losses.append(avg_train_loss) 192 | 193 | # Validation 194 | if args.local_rank == 0: 195 | model_engine.eval() 196 | val_loss = 0.0 197 | val_labels = [] 198 | with torch.no_grad(): 199 | for batch in tqdm(val_dataloader): 200 | input_ids = batch['input_ids'].to(device) 201 | attention_mask = batch['attention_mask'].to(device) 202 | labels = batch['label'].bfloat16().to(device) 203 | outputs = model_engine(input_ids=input_ids, attention_mask=attention_mask) 204 | loss = criterion(outputs, labels) 205 | val_loss += loss.item() 206 | val_labels.extend(labels.tolist()) 207 | 208 | avg_val_loss = val_loss / len(val_dataloader) 209 | val_losses.append(avg_val_loss) 210 | 211 | print(f"Epoch [{epoch + 1}/{num_epochs}]") 212 | print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} ") 213 | 214 | # Save best model 215 | if avg_val_loss < best_val_loss: 216 | print(">>>Save best model...") 217 | best_val_loss = avg_val_loss 218 | 219 | # 
model_engine.save_checkpoint("/data/records/Chatglm", tag="VM_best_checkpoint") 220 | model_engine.save_16bit_model( 221 | "/data/records/Chatglm", 222 | "VM_best_checkpoint_0117.pt" 223 | ) 224 | 225 | if args.local_rank == 0: 226 | train_end_time = time.time() 227 | print("PRM Training complete!") 228 | print(f"PRM Training time: {train_end_time - train_start_time:.2f} seconds") 229 | 230 | -------------------------------------------------------------------------------- /PRM/train_VM_mistral.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 4 | import json 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from tqdm import tqdm 9 | from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW 10 | from sklearn.metrics import accuracy_score 11 | from torch.utils.data import DataLoader, Dataset 12 | import pandas as pd 13 | 14 | max_length = 900 15 | 16 | # Load the pre-trained Mistral-7b model and tokenizer 17 | tokenizer = AutoTokenizer.from_pretrained("/workspace/ckpt/MetaMath-Mistral-7B", trust_remote_code=True) 18 | base_model = AutoModelForCausalLM.from_pretrained("/workspace/ckpt/MetaMath-Mistral-7B", trust_remote_code=True, torch_dtype=torch.bfloat16).cuda() 19 | 20 | # Custom Dataset class 21 | class MyDataset(Dataset): 22 | def __init__(self, data_js, tokenizer): 23 | self.data_js = data_js 24 | self.tokenizer = tokenizer 25 | 26 | def __len__(self): 27 | return len(self.data_js) 28 | 29 | def __getitem__(self, idx): 30 | prompt_answer = self.data_js[idx]['prompt_answer'] 31 | label = self.data_js[idx]['label'] 32 | 33 | encoded_pair = self.tokenizer.encode_plus( 34 | prompt_answer, 35 | padding='max_length', 36 | max_length=max_length, # Set the max length 37 | truncation=True, 38 | return_tensors='pt', # Return PyTorch Tensor format 39 | ) 40 | 41 | return { 42 | 'input_ids': encoded_pair['input_ids'].squeeze(), 43 | 'attention_mask': encoded_pair['attention_mask'].squeeze(), 44 | 'label': label 45 | } 46 | 47 | 48 | class Mistral_VM(nn.Module): 49 | def __init__(self, base, vocab_size=32000): 50 | super(Mistral_VM, self).__init__() 51 | self.base_model = base 52 | self.LN = nn.Linear(vocab_size, 1) 53 | 54 | def forward(self, input_ids, attention_mask): 55 | outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask).logits[:, -1, :] 56 | value_outputs = self.LN(outputs) 57 | return value_outputs.squeeze(dim=1) 58 | 59 | 60 | # Load training set, validation set, and test set data 61 | train_js = 'data/train_en.json' 62 | test_js = 'data/test_en.json' 63 | val_js = 'data/valid_en.json' 64 | 65 | 66 | def read_json(source): 67 | json_list = [] 68 | with open(source, 'r', encoding='utf-8') as f: 69 | for line in f: 70 | json_list.append(json.loads(line)) 71 | return json_list 72 | 73 | 74 | train_json = read_json(train_js) # This section uses a CSV file as an example to describe how to load data 75 | val_json = read_json(val_js) 76 | test_json = read_json(test_js) 77 | 78 | # Create a custom dataset 79 | train_dataset = MyDataset(train_json, tokenizer) 80 | val_dataset = MyDataset(val_json, tokenizer) 81 | test_dataset = MyDataset(test_json, tokenizer) 82 | 83 | # Create data loaders 84 | batch_size = 3 # Set batch size 85 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 86 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size) 87 | test_dataloader = DataLoader(test_dataset, 
batch_size=batch_size) 88 | 89 | # Set device and model 90 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 91 | print(device, '\n') 92 | vocab_size = base_model.config.vocab_size 93 | print(vocab_size) 94 | VM = Mistral_VM(base_model, vocab_size) 95 | VM.to(device) 96 | 97 | # Define loss function and optimizer 98 | criterion = nn.MSELoss() 99 | optimizer = AdamW(VM.parameters(), lr=3e-6) 100 | num_epochs = 2 # 2 or 3 101 | # Training and validation loop 102 | best_val_loss = 10000000 103 | train_losses = [] 104 | val_losses = [] 105 | for epoch in range(num_epochs): 106 | print(f"{epoch}/{num_epochs} training") 107 | # Training 108 | VM.train() 109 | train_loss = 0.0 110 | for batch in tqdm(train_dataloader): 111 | input_ids = batch['input_ids'].to(device) 112 | attention_mask = batch['attention_mask'].to(device) 113 | labels = batch['label'].to(dtype=torch.float32).to(device) 114 | 115 | optimizer.zero_grad() 116 | outputs = VM(input_ids=input_ids, attention_mask=attention_mask) 117 | loss = criterion(outputs, labels) 118 | loss.backward() 119 | optimizer.step() 120 | 121 | train_loss += loss.item() 122 | 123 | avg_train_loss = train_loss / len(train_dataloader) 124 | train_losses.append(avg_train_loss) 125 | 126 | # Validation 127 | VM.eval() 128 | val_loss = 0.0 129 | val_labels = [] 130 | with torch.no_grad(): 131 | for batch in tqdm(val_dataloader): 132 | input_ids = batch['input_ids'].to(device) 133 | attention_mask = batch['attention_mask'].to(device) 134 | labels = batch['label'].to(dtype=torch.float32).to(device) 135 | outputs = VM(input_ids=input_ids, attention_mask=attention_mask) 136 | loss = criterion(outputs, labels) 137 | val_loss += loss.item() 138 | val_labels.extend(labels.tolist()) 139 | 140 | avg_val_loss = val_loss / len(val_dataloader) 141 | val_losses.append(avg_val_loss) 142 | 143 | print(f"Epoch [{epoch + 1}/{num_epochs}]") 144 | print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} ") 145 | 146 | # Save best model 147 | if avg_val_loss < best_val_loss: 148 | best_val_loss = avg_val_loss 149 | torch.save(VM.state_dict(), "records/Mistral/VM_best_checkpoint.pt") 150 | 151 | print("Training complete!") 152 | 153 | # Load the best model for inference 154 | best_model = Mistral_VM(base_model, vocab_size) 155 | best_model.load_state_dict(torch.load("records/Mistral/VM_best_checkpoint.pt")) 156 | best_model.to(device) 157 | best_model.eval() 158 | 159 | # Perform inference 160 | test_preds = [] 161 | test_labels = [] 162 | with torch.no_grad(): 163 | for batch in tqdm(test_dataloader): 164 | input_ids = batch['input_ids'].to(device) 165 | attention_mask = batch['attention_mask'].to(device) 166 | labels = batch['label'].to(dtype=torch.float32).to(device) 167 | outputs = best_model(input_ids=input_ids, attention_mask=attention_mask) 168 | test_preds.extend(outputs.tolist()) 169 | test_labels.extend(labels.tolist()) 170 | print("Inference results:") 171 | for i in range(len(test_preds)): 172 | print(f"Sample {i + 1}: Predicted score {test_preds[i]}, Actual score {test_labels[i]}, Truncated score {min(max(test_preds[i], 0), 1)}") 173 | 174 | cnt = 0 175 | for i in range(len(test_preds)): 176 | if abs(min(max(test_preds[i], 0), 1) - test_labels[i]) <= 0.1: 177 | cnt += 1 178 | test_acc = cnt / len(test_preds) 179 | print(f"Test accuracy: {test_acc:.4f}") 180 | ``` -------------------------------------------------------------------------------- /ToT/base.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | 5 | class Node(object): 6 | def __init__(self, pcd: str, parent=None, depth=0): 7 | self.pcd = pcd # current step 8 | self.children = [] 9 | self.V = 0 10 | self.parent = parent 11 | self.y = '' # overall steps 12 | self.depth = depth 13 | self.visit_sequence = 0 14 | self.final_ans_flag = 0 15 | 16 | def append_children(self, new_pcd: str): 17 | node = Node(new_pcd, self, self.depth + 1) 18 | node.update_y_from_parent() 19 | self.children.append(node) 20 | return self, node 21 | 22 | def update_y_from_parent(self): 23 | if self.parent is None: 24 | self.y = self.pcd 25 | else: 26 | self.y = self.parent.y + self.pcd 27 | 28 | def update_value(self, value): 29 | self.V = value 30 | 31 | def getBestV(self): # Gets the subtree maximum value node 32 | if not self.children: 33 | return self, self.V 34 | max_V = self.V 35 | max_node = self 36 | for child in self.children: 37 | subNode, subValue = child.getBestV() 38 | if subValue >= max_V: 39 | max_V = subValue 40 | max_node = subNode 41 | return max_node, max_V 42 | 43 | def get_multiply_value(self): 44 | if self.depth == 0: 45 | return 0 46 | multi_value = self.V 47 | cur_node = self.parent 48 | while cur_node.depth > 0: 49 | multi_value = multi_value * cur_node.V 50 | cur_node = cur_node.parent 51 | return multi_value 52 | 53 | 54 | class SolutionStep(object): 55 | def __init__(self, x, stp, all_steps, score, step_num): 56 | self.x = x 57 | self.stp = stp 58 | self.all_steps = all_steps 59 | self.score = score 60 | self.step_num = step_num 61 | 62 | 63 | def rand_select(data_list: list, probs: list): # Sampling by probability 64 | assert len(data_list) == len(probs), "length do not match!" 65 | probs_norm = [] 66 | sum_prob = sum(probs) 67 | for i in probs: 68 | probs_norm.append(i / sum_prob) 69 | intervals = [] 70 | count = 0 71 | for i in probs_norm: 72 | count = count + i 73 | intervals.append(count) 74 | # assert count == 1, "probs error!" 
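    # `intervals` now holds the cumulative sums of the normalized probabilities, e.g.
    # probs [1, 3] -> probs_norm [0.25, 0.75] -> intervals [0.25, 1.0]. Floating-point
    # rounding can leave the final entry slightly below 1, so it is pinned to 1 below;
    # a uniform draw in [0, 1) is then mapped to an index by walking the intervals.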
75 | intervals[len(intervals) - 1] = 1 76 | index = 0 77 | rand_prob = random.random() 78 | while rand_prob >= intervals[index]: 79 | index = index + 1 80 | return index, data_list[index] 81 | -------------------------------------------------------------------------------- /ToT/bfs.py: -------------------------------------------------------------------------------- 1 | from ToT.base import Node, rand_select 2 | 3 | 4 | def BFS(tot_task): 5 | root = Node('') 6 | cur_nodes = [root] 7 | for depth in range(tot_task.max_depth): 8 | candidates = [] 9 | for node in cur_nodes: 10 | for i in range(tot_task.branch): 11 | new_pcd = '' 12 | cnt = 3 13 | while not new_pcd and cnt: 14 | new_pcd = tot_task.get_next_step(node.y, node.depth + 1) 15 | cnt -= 1 16 | if not new_pcd: 17 | continue 18 | 19 | node, child = node.append_children(new_pcd) 20 | value = tot_task.get_step_value(child.y) 21 | child.update_value(value) 22 | child.visit_sequence = tot_task.node_count 23 | tot_task.update_count() 24 | candidates.append(child) 25 | 26 | if not candidates: 27 | break 28 | ranked_candidates = sorted(candidates, key=lambda item: item.V, reverse=True) 29 | if ranked_candidates[0].V >= tot_task.end_gate: 30 | print('The final solution has been found!\n') 31 | ranked_candidates[0].final_ans_flag = 1 32 | return ranked_candidates[0].y, root, ranked_candidates[0] 33 | 34 | if tot_task.select_method == 'greedy': 35 | cur_nodes = ranked_candidates[:min(tot_task.select_branch, tot_task.branch, len(ranked_candidates))] 36 | 37 | else: 38 | idx_list = [] 39 | cur_nodes = [] 40 | for j in range(min(tot_task.select_branch, tot_task.branch)): 41 | idx, node = rand_select(ranked_candidates, [item.V for item in ranked_candidates]) 42 | if idx not in idx_list: 43 | idx_list.append(idx) 44 | cur_nodes.append(node) 45 | cur_nodes = sorted(cur_nodes, key=lambda item: item.V, reverse=True) 46 | 47 | print('If no solution satisfying the required value is found, the highest value value solution is used instead.\n') 48 | max_node, max_V = root.getBestV() 49 | max_node.final_ans_flag = 1 50 | return max_node.y, root, max_node 51 | -------------------------------------------------------------------------------- /ToT/dfs.py: -------------------------------------------------------------------------------- 1 | from ToT.base import Node, rand_select 2 | 3 | 4 | def DFS_sub(tot_task, node): 5 | if node.depth >= tot_task.max_depth: 6 | print('Maximum depth limit reached!\n') 7 | return "", node, None 8 | 9 | candidates = [] 10 | for i in range(tot_task.branch): 11 | new_pcd = '' 12 | cnt = 3 13 | while not new_pcd and cnt: 14 | new_pcd = tot_task.get_next_step(node.y, node.depth + 1) 15 | cnt -= 1 16 | if not new_pcd: 17 | continue 18 | 19 | node, child = node.append_children(new_pcd) 20 | value = tot_task.get_step_value(child.y) 21 | child.update_value(value) 22 | child.visit_sequence = tot_task.node_count 23 | tot_task.update_count() 24 | candidates.append(child) 25 | 26 | if not candidates: 27 | print('No suitable next step was found!\n') 28 | return "", node, None 29 | ranked_candidates = sorted(candidates, key=lambda item: item.V, reverse=True) 30 | if ranked_candidates[0].V >= tot_task.end_gate: 31 | ranked_candidates[0].final_ans_flag = 1 32 | return ranked_candidates[0].y, node, ranked_candidates[0] 33 | 34 | # Further probe 35 | if tot_task.select_method == 'greedy': 36 | selected = ranked_candidates[:min(tot_task.select_branch, tot_task.branch, len(ranked_candidates))] 37 | 38 | else: 39 | idx_list = [] 40 | selected = [] 41 | 
for j in range(min(tot_task.select_branch, tot_task.branch)): 42 | idx, node = rand_select(ranked_candidates, [item.V for item in ranked_candidates]) 43 | if idx not in idx_list: 44 | idx_list.append(idx) 45 | selected.append(node) 46 | selected = sorted(selected, key=lambda item: item.V, reverse=True) 47 | 48 | for child in selected: 49 | solution, child, final_node = DFS_sub(tot_task, child) 50 | if solution: 51 | return solution, node, final_node 52 | 53 | return "", node, None 54 | 55 | 56 | def DFS(tot_task): 57 | root = Node('') 58 | solution, root, final_node = DFS_sub(tot_task, root) 59 | if solution: 60 | print(f'The final solution has been found!\nSolution:{solution}\n') 61 | return solution, root, final_node 62 | else: 63 | max_node, max_V = root.getBestV() 64 | max_node.final_ans_flag = 1 65 | print(f'If no solution satisfying the required value is found, the highest value value solution is used instead.\nSolution:{max_node.y}\n') 66 | return max_node.y, root, max_node 67 | -------------------------------------------------------------------------------- /ToT/task.py: -------------------------------------------------------------------------------- 1 | import random 2 | from tasks.science import SearchTask 3 | from ToT.base import Node 4 | from models.get_response import * 5 | from ToT.bfs import BFS 6 | from ToT.dfs import DFS 7 | from utils.solution_summary_extractor import extract_summary_from_solution 8 | from utils.verify_MATH import exact_match_score 9 | 10 | 11 | class ToT_Task(SearchTask): 12 | def __init__(self, data, propose_method='glm', value_method='glm', algorithm='dfs', branch=3, select_branch=2, 13 | max_depth=8, end_gate=0.9, select_method='greedy', 14 | temperature=0.7, max_tokens=2048, 15 | seed=170, max_length=2048, truncation=True, 16 | do_sample=True, max_new_tokens=256, use_case_prompt=False, low=0, high=1, evaluate='', multiply_value=False, lang='zh', answer=None, verify_method='string'): 17 | super().__init__(data, propose_method, value_method) 18 | assert 0 <= low < high, "Inappropriate value range!" 
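        # `low` and `high` bound the scores returned by the value model (see get_step_value);
        # the assert above rejects an empty or inverted range. `end_gate` is compared against
        # these scores in BFS/DFS: a candidate whose value reaches it ends the search early.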
19 | self.mode = 'tot' 20 | self.temperature = temperature 21 | self.max_tokens = max_tokens 22 | self.seed = seed 23 | self.max_length = max_length 24 | self.truncation = truncation 25 | self.do_sample = do_sample 26 | self.max_new_tokens = max_new_tokens 27 | self.algorithm = algorithm 28 | self.branch = branch 29 | self.select_branch = select_branch 30 | self.max_depth = max_depth 31 | self.use_case_prompt = use_case_prompt 32 | self.low = low 33 | self.high = high 34 | self.evaluate = evaluate 35 | self.select_method = select_method 36 | self.end_gate = end_gate 37 | self.node_count = 1 38 | self.multiply_value = multiply_value 39 | self.lang = lang 40 | self.answer = answer 41 | self.verify_method = verify_method 42 | 43 | def update_count(self): 44 | self.node_count += 1 45 | 46 | def clear_cache(self): 47 | self.value_cache = {} 48 | self.node_count = 1 49 | 50 | def get_next_step(self, y, step_n): 51 | if self.use_case_prompt: 52 | prompt = self.single_propose_prompt_wrap(self.question, y, step_n) 53 | else: 54 | prompt = self.zero_single_propose_wrap(self.question, y, step_n, self.lang) 55 | 56 | response = get_proposal(prompt, self.propose_method, self.temperature, self.max_tokens, self.seed, 57 | self.max_length, 58 | self.truncation, self.do_sample, self.max_new_tokens) 59 | if not response: 60 | print('Failed to get next step!\n') 61 | return '' 62 | 63 | if len(response) > 5: 64 | response = response[:5] 65 | 66 | p = '' 67 | for _ in response: 68 | p = p + _ + ' ' 69 | p = p.strip() 70 | 71 | if self.lang == 'zh': 72 | if '下一步:' in p: 73 | stp = p.split('下一步:')[1].strip() 74 | if len(stp) < 2: 75 | print('The output step is too short!\n') 76 | return '' 77 | if stp in y: 78 | print('Output step repeat!\n') 79 | return '' 80 | 81 | revised_ = '步骤' + str(step_n) + ':' + stp 82 | print(f'New steps after standardization:{revised_}\n') 83 | return revised_ + '\n' 84 | 85 | elif '步骤' in p and ':' in p: 86 | pre_len = len(p.split(':')[0]) 87 | p_ = p[pre_len:] 88 | p_ = p_.split('步骤')[0].strip() 89 | if len(p_) < 3: 90 | print('The output step is too short!\n') 91 | return '' 92 | if p_[1:] in y: 93 | print('Output step repeat!\n') 94 | return '' 95 | 96 | revised_ = '步骤' + str(step_n) + p_ 97 | print(f'New steps after standardization:{revised_}\n') 98 | return revised_ + '\n' 99 | 100 | else: 101 | print('Incorrect output format!\n') 102 | return '' 103 | else: 104 | if "Next step:" in p: 105 | stp = p.split('Next step:')[1].strip() 106 | if len(stp) < 2: 107 | print('The output step is too short!\n') 108 | return '' 109 | if stp in y: 110 | print('Output step repeat!\n') 111 | return '' 112 | 113 | revised_ = 'Step ' + str(step_n) + ': ' + stp 114 | print(f'New steps after standardization:{revised_}\n') 115 | return revised_ + '\n' 116 | 117 | elif "Step" in p and ":" in p: 118 | pre_len = len(p.split(':')[0]) 119 | p_ = p[pre_len:] 120 | p_ = p_.split('Step')[0].strip() 121 | if len(p_) < 4: 122 | print('The output step is too short!\n') 123 | return '' 124 | p_ = p_[1:].strip() 125 | if p_ in y: 126 | print('Output step repeat!\n') 127 | return '' 128 | 129 | revised_ = 'Step ' + str(step_n) + ': ' + p_ 130 | print(f'New steps after standardization:{revised_}\n') 131 | return revised_ + '\n' 132 | 133 | else: 134 | p_ = p.strip() 135 | if len(p_) < 3: 136 | print('The output step is too short!\n') 137 | return '' 138 | if p_ in y: 139 | print('Output step repeat!\n') 140 | return '' 141 | 142 | revised_ = 'Step ' + str(step_n) + ': ' + p_ 143 | print(f'New steps after 
standardization:{revised_}\n') 144 | return revised_ + '\n' 145 | 146 | def get_step_value(self, y): 147 | if y in self.value_cache.keys(): 148 | return self.value_cache[y] 149 | 150 | if self.value_method == 'local': 151 | if self.lang == 'zh': 152 | prompt_answer = '问题:' + self.question + '\n步骤:\n' + '【答案】' + y 153 | else: 154 | prompt_answer = 'Problem: ' + self.question + '\nSolution:\n' + y 155 | 156 | value = get_value(prompt_answer, self.value_method, self.temperature, self.max_tokens, self.seed, 157 | self.max_length, self.low, self.high) 158 | print(f'Get a score:{value}\n') 159 | self.value_cache.update({y: value}) 160 | return value 161 | 162 | else: 163 | prompt = self.value_prompt_wrap(self.question, y) 164 | response = get_value(prompt, self.value_method, self.temperature, self.max_tokens, self.seed, 165 | self.max_length, self.low, self.high) 166 | value = self.value_outputs_unwrap(response, self.low, self.high) 167 | print(f'Get a score:{value}\n') 168 | self.value_cache.update({y: value}) 169 | return value 170 | 171 | def get_summary(self, y): 172 | if self.lang == 'zh': 173 | if self.evaluate == 'scibench': 174 | prompt = self.evaluate_summary_prompt_wrap(self.question, y) 175 | elif self.evaluate == 'scieval': 176 | prompt = self.general_evaluate_summary_prompt_wrap(self.question, y) 177 | else: 178 | prompt = self.summary_prompt_wrap(self.question, y) 179 | 180 | response = get_proposal(prompt, self.propose_method, self.temperature, self.max_tokens, self.seed, 181 | self.max_length, 182 | self.truncation, self.do_sample, 128) 183 | 184 | if not response: 185 | print('Failed to get a summary!\n') 186 | return '' 187 | p = '' 188 | for _ in response: 189 | p = p + _ + ' ' 190 | p = p.strip() 191 | 192 | if self.evaluate: 193 | if len(p) < 1: 194 | print('Get the summary too short!\n') 195 | return '' 196 | 197 | if '综上所述,最终答案是:' not in p: 198 | summ = '综上所述,最终答案是:' + p 199 | print(f'Get summary:{summ}\n') 200 | return summ 201 | else: 202 | summ = '综上所述,最终答案是:' + p.split('综上所述,最终答案是:')[-1] 203 | print(f'Get summary:{summ}\n') 204 | return summ 205 | 206 | else: 207 | if len(p) < 1: 208 | print('Get the summary too short!\n') 209 | return '' 210 | 211 | if '综上所述,' not in p: 212 | summ = '综上所述,' + p 213 | print(f'Get summary:{summ}\n') 214 | return summ 215 | else: 216 | summ = '综上所述,' + p.split('综上所述,')[-1] 217 | print(f'Get summary:{summ}\n') 218 | return summ 219 | 220 | else: 221 | prompt = self.MATH_summary_prompt_wrap(self.question, y) 222 | response = get_proposal(prompt, self.propose_method, self.temperature, self.max_tokens, self.seed, 223 | self.max_length, 224 | self.truncation, self.do_sample, 128) 225 | if not response: 226 | print('Failed to get a summary!\n') 227 | return '' 228 | p = '' 229 | for _ in response: 230 | p = p + _ + ' ' 231 | summ = p.strip() 232 | 233 | print(f'Get summary:{summ}\n') 234 | return summ 235 | 236 | def run(self): 237 | self.clear_cache() 238 | if self.algorithm == 'dfs': 239 | solution, root, final_node = DFS(self) 240 | elif self.algorithm == 'bfs': 241 | solution, root, final_node = BFS(self) 242 | else: 243 | print('Unsupported algorithm!\n') 244 | return {} 245 | 246 | cnt = 5 247 | summary = '' 248 | while cnt: 249 | summary = self.get_summary(solution) 250 | if summary: 251 | break 252 | else: 253 | cnt -= 1 254 | if not summary and self.lang == 'en': 255 | summary = extract_summary_from_solution(solution) 256 | 257 | if self.evaluate == 'math' or self.verify_method == 'string': 258 | result = exact_match_score(summary, 
self.answer) 259 | final_answer = {'content': self.question, 'solution': solution, 'summary': summary, 'accurate': result, 'real_answer': self.answer} 260 | else: 261 | final_answer = {'content': self.question, 'solution': solution, 'summary': summary} 262 | 263 | if self.multiply_value: 264 | multiply_v = final_node.get_multiply_value() 265 | final_answer.update({'multiply_value': multiply_v}) 266 | 267 | return final_answer, root 268 | -------------------------------------------------------------------------------- /assets/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/ReST-MCTS/2d5f488c3d6e24f99d50a9860e818383b1bb5883/assets/comparison.png -------------------------------------------------------------------------------- /assets/overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/ReST-MCTS/2d5f488c3d6e24f99d50a9860e818383b1bb5883/assets/overall.png -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/ReST-MCTS/2d5f488c3d6e24f99d50a9860e818383b1bb5883/assets/results.png -------------------------------------------------------------------------------- /assets/searches.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/ReST-MCTS/2d5f488c3d6e24f99d50a9860e818383b1bb5883/assets/searches.png -------------------------------------------------------------------------------- /assets/vm_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/ReST-MCTS/2d5f488c3d6e24f99d50a9860e818383b1bb5883/assets/vm_results.png -------------------------------------------------------------------------------- /data/scibench/chemmc_standardized.json: -------------------------------------------------------------------------------- 1 | {"content": "Calculate the de Broglie wavelength for (a) an electron with a kinetic energy of $100 \\mathrm{eV}$ The unit of the answer should be nm .", "answer": 0.123} 2 | {"content": "The threshold wavelength for potassium metal is $564 \\mathrm{~nm}$. What is its work function? \r\n The unit of the answer should be $10^{-19} \\mathrm{~J}$.", "answer": 3.52} 3 | {"content": "Evaluate the series\r\n$$\r\nS=\\sum_{n=0}^{\\infty} \\frac{1}{3^n}\r\n$$ The unit of the answer should be .", "answer": 1.5} 4 | {"content": "Evaluate the series\r\n$$\r\nS=\\sum_{n=1}^{\\infty} \\frac{(-1)^{n+1}}{2^n}\r\n$$ The unit of the answer should be .", "answer": 1.0} 5 | {"content": "The relationship introduced in Problem $1-48$ has been interpreted to mean that a particle of mass $m\\left(E=m c^2\\right)$ can materialize from nothing provided that it returns to nothing within a time $\\Delta t \\leq h / m c^2$. Particles that last for time $\\Delta t$ or more are called real particles; particles that last less than time $\\Delta t$ are called virtual particles. The mass of the charged pion, a subatomic particle, is $2.5 \\times 10^{-28} \\mathrm{~kg}$. What is the minimum lifetime if the pion is to be considered a real particle? The unit of the answer should be $10^{-23} \\mathrm{~s}$ .", "answer": 2.9} 6 | {"content": "A household lightbulb is a blackbody radiator. 
Many lightbulbs use tungsten filaments that are heated by an electric current. What temperature is needed so that $\\lambda_{\\max }=550 \\mathrm{~nm}$ ? The unit of the answer should be $\\mathrm{~K}$\r\n.", "answer": 5300.0} 7 | {"content": "Evaluate the series\r\n$$\r\nS=\\frac{1}{2}+\\frac{1}{4}+\\frac{1}{8}+\\frac{1}{16}+\\cdots\r\n$$\r\n The unit of the answer should be .", "answer": 1.0} 8 | {"content": "Through what potential must a proton initially at rest fall so that its de Broglie wavelength is $1.0 \\times 10^{-10} \\mathrm{~m}$ ? The unit of the answer should be V .", "answer": 0.082} 9 | {"content": "Example 5-3 shows that a Maclaurin expansion of a Morse potential leads to\r\n$$\r\nV(x)=D \\beta^2 x^2+\\cdots\r\n$$\r\nGiven that $D=7.31 \\times 10^{-19} \\mathrm{~J} \\cdot$ molecule ${ }^{-1}$ and $\\beta=1.81 \\times 10^{10} \\mathrm{~m}^{-1}$ for $\\mathrm{HCl}$, calculate the force constant of $\\mathrm{HCl}$. The unit of the answer should be $\\mathrm{~N} \\cdot \\mathrm{m}^{-1}$ .", "answer": 479.0} 10 | {"content": "A line in the Lyman series of hydrogen has a wavelength of $1.03 \\times 10^{-7} \\mathrm{~m}$. Find the original energy level of the electron. The unit of the answer should be .", "answer": 3.0} 11 | {"content": "A helium-neon laser (used in supermarket scanners) emits light at $632.8 \\mathrm{~nm}$. Calculate the frequency of this light. The unit of the answer should be $10^{14} \\mathrm{~Hz}$ .", "answer": 4.738} 12 | {"content": "What is the uncertainty of the momentum of an electron if we know its position is somewhere in a $10 \\mathrm{pm}$ interval? The unit of the answer should be $10^{-23} \\mathrm{~kg} \\cdot \\mathrm{m} \\cdot \\mathrm{s}^{-1}$.", "answer": 6.6} 13 | {"content": "Using the Bohr theory, calculate the ionization energy (in electron volts and in $\\mathrm{kJ} \\cdot \\mathrm{mol}^{-1}$ ) of singly ionized helium. The unit of the answer should be $\\mathrm{eV}$ .", "answer": 54.394} 14 | {"content": "When an excited nucleus decays, it emits a $\\gamma$ ray. The lifetime of an excited state of a nucleus is of the order of $10^{-12} \\mathrm{~s}$. What is the uncertainty in the energy of the $\\gamma$ ray produced? The unit of the answer should be $10^{-22} \\mathrm{~J}$ .", "answer": 7.0} 15 | {"content": "Calculate the wavelength and the energy of a photon associated with the series limit of the Lyman series. The unit of the answer should be nm .", "answer": 91.17} 16 | {"content": "Another application of the relationship given in Problem $1-48$ has to do with the excitedstate energies and lifetimes of atoms and molecules. If we know that the lifetime of an excited state is $10^{-9} \\mathrm{~s}$, then what is the uncertainty in the energy of this state?\r\n The unit of the answer should be $10^{-25} \\mathrm{~J}$.", "answer": 7.0} 17 | {"content": "One of the most powerful modern techniques for studying structure is neutron diffraction. This technique involves generating a collimated beam of neutrons at a particular temperature from a high-energy neutron source and is accomplished at several accelerator facilities around the world. If the speed of a neutron is given by $v_{\\mathrm{n}}=\\left(3 k_{\\mathrm{B}} T / m\\right)^{1 / 2}$, where $m$ is the mass of a neutron, then what temperature is needed so that the neutrons have a de Broglie wavelength of $50 \\mathrm{pm}$ ? 
The unit of the answer should be $\\mathrm{K}$ .", "answer": 2500.0} 18 | {"content": "The temperature of the fireball in a thermonuclear explosion can reach temperatures of approximately $10^7 \\mathrm{~K}$. What value of $\\lambda_{\\max }$ does this correspond to? The unit of the answer should be $10^{-10} \\mathrm{~m}$\r\n.", "answer": 3.0} 19 | {"content": "Show that l'Hôpital's rule amounts to forming a Taylor expansion of both the numerator and the denominator. Evaluate the limit\r\n$$\r\n\\lim _{x \\rightarrow 0} \\frac{\\ln (1+x)-x}{x^2}\r\n$$\r\nboth ways. The unit of the answer should be .", "answer": -0.5} 20 | {"content": "Evaluate the series\r\n$$\r\nS=\\sum_{n=1}^{\\infty} \\frac{(-1)^{n+1}}{2^n}\r\n$$ The unit of the answer should be .", "answer": 0.3333333} 21 | {"content": "Calculate the percentage difference between $\\ln (1+x)$ and $x$ for $x=0.0050$ The unit of the answer should be %.", "answer": 0.249} 22 | {"content": "Calculate the reduced mass of a nitrogen molecule in which both nitrogen atoms have an atomic mass of 14.00. The unit of the answer should be .", "answer": 7.0} 23 | {"content": "Two narrow slits are illuminated with red light of wavelength $694.3 \\mathrm{~nm}$ from a laser, producing a set of evenly placed bright bands on a screen located $3.00 \\mathrm{~m}$ beyond the slits. If the distance between the bands is $1.50 \\mathrm{~cm}$, then what is the distance between the slits?\r\n The unit of the answer should be mm .", "answer": 0.139} 24 | {"content": "Calculate the energy and wavelength associated with an $\\alpha$ particle that has fallen through a potential difference of $4.0 \\mathrm{~V}$. Take the mass of an $\\alpha$ particle to be $6.64 \\times 10^{-27} \\mathrm{~kg}$. The unit of the answer should be $\r\n\\text { 1-41. } 1.3 \\times 10^{-18} \\mathrm{~J} / \\alpha \\text {-particle, }\r\n$.", "answer": 1.3} 25 | {"content": "Calculate the number of photons in a $2.00 \\mathrm{~mJ}$ light pulse at (a) $1.06 \\mu \\mathrm{m}$\r\n The unit of the answer should be $10^{16}$ photons.", "answer": 1.07} 26 | {"content": "The force constant of ${ }^{35} \\mathrm{Cl}^{35} \\mathrm{Cl}$ is $319 \\mathrm{~N} \\cdot \\mathrm{m}^{-1}$. Calculate the fundamental vibrational frequency The unit of the answer should be $\\mathrm{~cm}^{-1}$.", "answer": 556.0} 27 | {"content": "$$\r\n\\text {Calculate the energy of a photon for a wavelength of } 100 \\mathrm{pm} \\text { (about one atomic diameter). }\r\n$$\r\n The unit of the answer should be $10^{-15} \\mathrm{~J}$.", "answer": 2.0} 28 | {"content": "A proton and a negatively charged $\\mu$ meson (called a muon) can form a short-lived species called a mesonic atom. The charge of a muon is the same as that on an electron and the mass of a muon is $207 m_{\\mathrm{e}}$. Assume that the Bohr theory can be applied to such a mesonic atom and calculate the ground-state energy, the radius of the first Bohr orbit, and the energy and frequency associated with the $n=1$ to $n=2$ transition in a mesonic atom. The unit of the answer should be $10^{-28} \\mathrm{~kg}$.", "answer": 1.69} 29 | {"content": "$$\r\n\\beta=2 \\pi c \\tilde{\\omega}_{\\mathrm{obs}}\\left(\\frac{\\mu}{2 D}\\right)^{1 / 2}\r\n$$\r\nGiven that $\\tilde{\\omega}_{\\mathrm{obs}}=2886 \\mathrm{~cm}^{-1}$ and $D=440.2 \\mathrm{~kJ} \\cdot \\mathrm{mol}^{-1}$ for $\\mathrm{H}^{35} \\mathrm{Cl}$, calculate $\\beta$. 
The unit of the answer should be $10^{10} \\mathrm{~m}^{-1}$.", "answer": 1.81} 30 | {"content": "Two narrow slits separated by $0.10 \\mathrm{~mm}$ are illuminated by light of wavelength $600 \\mathrm{~nm}$. What is the angular position of the first maximum in the interference pattern? If a detector is located $2.00 \\mathrm{~m}$ beyond the slits, what is the distance between the central maximum and the first maximum? The unit of the answer should be mm.", "answer": 12.0} 31 | {"content": "$$\r\n\\text { If we locate an electron to within } 20 \\mathrm{pm} \\text {, then what is the uncertainty in its speed? }\r\n$$ The unit of the answer should be $10^7 \\mathrm{~m} \\cdot \\mathrm{s}^{-1}$ .", "answer": 3.7} 32 | {"content": "The mean temperature of the earth's surface is $288 \\mathrm{~K}$. What is the maximum wavelength of the earth's blackbody radiation? The unit of the answer should be 10^{-5} \\mathrm{~m}.", "answer": 1.01} 33 | {"content": "The power output of a laser is measured in units of watts (W), where one watt is equal to one joule per second. $\\left(1 \\mathrm{~W}=1 \\mathrm{~J} \\cdot \\mathrm{s}^{-1}\\right.$.) What is the number of photons emitted per second by a $1.00 \\mathrm{~mW}$ nitrogen laser? The wavelength emitted by a nitrogen laser is $337 \\mathrm{~nm}$. The unit of the answer should be $\r\n10^{15} \\text { photon } \\cdot \\mathrm{s}^{-1}\r\n$.", "answer": 1.7} 34 | {"content": " Sirius, one of the hottest known stars, has approximately a blackbody spectrum with $\\lambda_{\\max }=260 \\mathrm{~nm}$. Estimate the surface temperature of Sirius.\r\n The unit of the answer should be $\\mathrm{~K}$\r\n.", "answer": 11000.0} 35 | {"content": "A ground-state hydrogen atom absorbs a photon of light that has a wavelength of $97.2 \\mathrm{~nm}$. It then gives off a photon that has a wavelength of $486 \\mathrm{~nm}$. What is the final state of the hydrogen atom? The unit of the answer should be .", "answer": 2.0} 36 | {"content": "It turns out that the solution of the Schrödinger equation for the Morse potential can be expressed as\r\n$$\r\nG(v)=\\tilde{\\omega}_{\\mathrm{e}}\\left(v+\\frac{1}{2}\\right)-\\tilde{\\omega}_{\\mathrm{e}} \\tilde{x}_{\\mathrm{e}}\\left(v+\\frac{1}{2}\\right)^2\r\n$$\r\nChapter 5 / The Harmonic Oscillator and Vibrational Spectroscopy\r\nwhere\r\n$$\r\n\\tilde{x}_{\\mathrm{e}}=\\frac{h c \\tilde{\\omega}_{\\mathrm{e}}}{4 D}\r\n$$\r\nGiven that $\\tilde{\\omega}_{\\mathrm{e}}=2886 \\mathrm{~cm}^{-1}$ and $D=440.2 \\mathrm{~kJ} \\cdot \\mathrm{mol}^{-1}$ for $\\mathrm{H}^{35} \\mathrm{Cl}$, calculate $\\tilde{x}_{\\mathrm{e}}$ and $\\tilde{\\omega}_{\\mathrm{e}} \\tilde{x}_{\\mathrm{e}}$. The unit of the answer should be .", "answer": 0.01961} 37 | {"content": " In the infrared spectrum of $\\mathrm{H}^{127} \\mathrm{I}$, there is an intense line at $2309 \\mathrm{~cm}^{-1}$. Calculate the force constant of $\\mathrm{H}^{127} \\mathrm{I}$. 
The unit of the answer should be $ \\mathrm{~N} \\cdot \\mathrm{m}^{-1}$.", "answer": 313.0} 38 | {"content": "Calculate the percentage difference between $e^x$ and $1+x$ for $x=0.0050$ The unit of the answer should be $10^{-3} \\%$.", "answer": 1.25} 39 | {"content": " Calculate (a) the wavelength and kinetic energy of an electron in a beam of electrons accelerated by a voltage increment of $100 \\mathrm{~V}$ The unit of the answer should be $10^{-17} \\mathrm{~J} \\cdot$ electron ${ }^{-1}$.", "answer": 1.602} 40 | -------------------------------------------------------------------------------- /data/scibench/quan_standardized.json: -------------------------------------------------------------------------------- 1 | {"content": "Use the $D_0$ value of $\\mathrm{H}_2(4.478 \\mathrm{eV})$ and the $D_0$ value of $\\mathrm{H}_2^{+}(2.651 \\mathrm{eV})$ to calculate the first ionization energy of $\\mathrm{H}_2$ (that is, the energy needed to remove an electron from $\\mathrm{H}_2$ ). The unit of the answer should be $\\mathrm{eV}$.", "answer": 15.425} 2 | {"content": "Calculate the energy of one mole of UV photons of wavelength $300 \\mathrm{~nm}$ and compare it with a typical single-bond energy of $400 \\mathrm{~kJ} / \\mathrm{mol}$. The unit of the answer should be $\\mathrm{~kJ} / \\mathrm{mol}$.", "answer": 399.0} 3 | {"content": "Calculate the magnitude of the spin angular momentum of a proton. Give a numerical answer. The unit of the answer should be $10^{-35} \\mathrm{~J} \\mathrm{~s}$.", "answer": 9.13} 4 | {"content": "The ${ }^7 \\mathrm{Li}^1 \\mathrm{H}$ ground electronic state has $D_0=2.4287 \\mathrm{eV}, \\nu_e / c=1405.65 \\mathrm{~cm}^{-1}$, and $\\nu_e x_e / c=23.20 \\mathrm{~cm}^{-1}$, where $c$ is the speed of light. (These last two quantities are usually designated $\\omega_e$ and $\\omega_e x_e$ in the literature.) Calculate $D_e$ for ${ }^7 \\mathrm{Li}^1 \\mathrm{H}$. The unit of the answer should be $\\mathrm{eV}$.", "answer": 2.5151} 5 | {"content": "The positron has charge $+e$ and mass equal to the electron mass. Calculate in electronvolts the ground-state energy of positronium-an \"atom\" that consists of a positron and an electron. The unit of the answer should be $\\mathrm{eV}$.", "answer": -6.8} 6 | {"content": "What is the value of the angular-momentum quantum number $l$ for a $t$ orbital? The unit of the answer should be .", "answer": 14.0} 7 | {"content": "How many states belong to the carbon configurations $1 s^2 2 s^2 2 p^2$? The unit of the answer should be .", "answer": 15.0} 8 | {"content": "Calculate the energy needed to compress three carbon-carbon single bonds and stretch three carbon-carbon double bonds to the benzene bond length $1.397 Å$. Assume a harmonicoscillator potential-energy function for bond stretching and compression. Typical carboncarbon single- and double-bond lengths are 1.53 and $1.335 Å$; typical stretching force constants for carbon-carbon single and double bonds are 500 and $950 \\mathrm{~N} / \\mathrm{m}$. The unit of the answer should be $\\mathrm{kcal} / \\mathrm{mol}$.", "answer": 27.0} 9 | {"content": "When a particle of mass $9.1 \\times 10^{-28} \\mathrm{~g}$ in a certain one-dimensional box goes from the $n=5$ level to the $n=2$ level, it emits a photon of frequency $6.0 \\times 10^{14} \\mathrm{~s}^{-1}$. Find the length of the box. 
The unit of the answer should be $\\mathrm{~nm}$.", "answer": 1.8} 10 | {"content": "Use the normalized Numerov-method harmonic-oscillator wave functions found by going from -5 to 5 in steps of 0.1 to estimate the probability of being in the classically forbidden region for the $v=0$ state. The unit of the answer should be .", "answer": 0.16} 11 | {"content": "Calculate the de Broglie wavelength of an electron moving at 1/137th the speed of light. (At this speed, the relativistic correction to the mass is negligible.) The unit of the answer should be $\\mathrm{~nm}$.", "answer": 0.332} 12 | {"content": "Calculate the angle that the spin vector $S$ makes with the $z$ axis for an electron with spin function $\\alpha$. The unit of the answer should be $^{\\circ}$.", "answer": 54.7} 13 | {"content": "The AM1 valence electronic energies of the atoms $\\mathrm{H}$ and $\\mathrm{O}$ are $-11.396 \\mathrm{eV}$ and $-316.100 \\mathrm{eV}$, respectively. For $\\mathrm{H}_2 \\mathrm{O}$ at its AM1-calculated equilibrium geometry, the AM1 valence electronic energy (core-core repulsion omitted) is $-493.358 \\mathrm{eV}$ and the AM1 core-core repulsion energy is $144.796 \\mathrm{eV}$. For $\\mathrm{H}(g)$ and $\\mathrm{O}(g), \\Delta H_{f, 298}^{\\circ}$ values are 52.102 and $59.559 \\mathrm{kcal} / \\mathrm{mol}$, respectively. Find the AM1 prediction of $\\Delta H_{f, 298}^{\\circ}$ of $\\mathrm{H}_2 \\mathrm{O}(g)$. The unit of the answer should be $\\mathrm{kcal} / \\mathrm{mol}$.", "answer": -59.24} 14 | {"content": "Given that $D_e=4.75 \\mathrm{eV}$ and $R_e=0.741 Å$ for the ground electronic state of $\\mathrm{H}_2$, find $U\\left(R_e\\right)$ for this state. The unit of the answer should be $\\mathrm{eV}$.", "answer": -31.95} 15 | {"content": "For $\\mathrm{NaCl}, R_e=2.36 Å$. The ionization energy of $\\mathrm{Na}$ is $5.14 \\mathrm{eV}$, and the electron affinity of $\\mathrm{Cl}$ is $3.61 \\mathrm{eV}$. Use the simple model of $\\mathrm{NaCl}$ as a pair of spherical ions in contact to estimate $D_e$. [One debye (D) is $3.33564 \\times 10^{-30} \\mathrm{C} \\mathrm{m}$.] The unit of the answer should be $\\mathrm{eV}$.", "answer": 4.56} 16 | {"content": "Find the number of CSFs in a full CI calculation of $\\mathrm{CH}_2 \\mathrm{SiHF}$ using a 6-31G** basis set. The unit of the answer should be $10^{28} $.", "answer": 1.86} 17 | {"content": "Calculate the ratio of the electrical and gravitational forces between a proton and an electron. The unit of the answer should be $10^{39}$.", "answer": 2.0} 18 | {"content": "A one-particle, one-dimensional system has the state function\r\n$$\r\n\\Psi=(\\sin a t)\\left(2 / \\pi c^2\\right)^{1 / 4} e^{-x^2 / c^2}+(\\cos a t)\\left(32 / \\pi c^6\\right)^{1 / 4} x e^{-x^2 / c^2}\r\n$$\r\nwhere $a$ is a constant and $c=2.000 Å$. If the particle's position is measured at $t=0$, estimate the probability that the result will lie between $2.000 Å$ and $2.001 Å$. The unit of the answer should be .", "answer": 0.000216} 19 | {"content": "A one-particle, one-dimensional system has the state function\r\n$$\r\n\\Psi=(\\sin a t)\\left(2 / \\pi c^2\\right)^{1 / 4} e^{-x^2 / c^2}+(\\cos a t)\\left(32 / \\pi c^6\\right)^{1 / 4} x e^{-x^2 / c^2}\r\n$$\r\nwhere $a$ is a constant and $c=2.000 Å$. If the particle's position is measured at $t=0$, estimate the probability that the result will lie between $2.000 Å$ and $2.001 Å$. 
The unit of the answer should be .", "answer": 0.000216} 20 | {"content": "The $J=2$ to 3 rotational transition in a certain diatomic molecule occurs at $126.4 \\mathrm{GHz}$, where $1 \\mathrm{GHz} \\equiv 10^9 \\mathrm{~Hz}$. Find the frequency of the $J=5$ to 6 absorption in this molecule. The unit of the answer should be $\\mathrm{GHz}$.", "answer": 252.8} 21 | {"content": "Assume that the charge of the proton is distributed uniformly throughout the volume of a sphere of radius $10^{-13} \\mathrm{~cm}$. Use perturbation theory to estimate the shift in the ground-state hydrogen-atom energy due to the finite proton size. The potential energy experienced by the electron when it has penetrated the nucleus and is at distance $r$ from the nuclear center is $-e Q / 4 \\pi \\varepsilon_0 r$, where $Q$ is the amount of proton charge within the sphere of radius $r$. The evaluation of the integral is simplified by noting that the exponential factor in $\\psi$ is essentially equal to 1 within the nucleus.\r\n The unit of the answer should be $10^{-8} \\mathrm{eV}$.", "answer": 1.2} 22 | {"content": "An electron in a three-dimensional rectangular box with dimensions of $5.00 Å, 3.00 Å$, and $6.00 Å$ makes a radiative transition from the lowest-lying excited state to the ground state. Calculate the frequency of the photon emitted. The unit of the answer should be $10^{14} \\mathrm{~s}^{-1}$.", "answer": 7.58} 23 | {"content": "Do $\\mathrm{HF} / 6-31 \\mathrm{G}^*$ geometry optimizations on one conformers of $\\mathrm{HCOOH}$ with $\\mathrm{OCOH}$ dihedral angle of $0^{\\circ}$. Calculate the dipole moment. The unit of the answer should be $\\mathrm{D}$.", "answer": 1.41} 24 | {"content": "Frozen-core $\\mathrm{SCF} / \\mathrm{DZP}$ and CI-SD/DZP calculations on $\\mathrm{H}_2 \\mathrm{O}$ at its equilibrium geometry gave energies of -76.040542 and -76.243772 hartrees. Application of the Davidson correction brought the energy to -76.254549 hartrees. Find the coefficient of $\\Phi_0$ in the normalized CI-SD wave function. The unit of the answer should be .", "answer": 0.9731} 25 | {"content": "Let $w$ be the variable defined as the number of heads that show when two coins are tossed simultaneously. Find $\\langle w\\rangle$. The unit of the answer should be .", "answer": 1.0} 26 | {"content": "Calculate the force on an alpha particle passing a gold atomic nucleus at a distance of $0.00300 Å$. The unit of the answer should be $\\mathrm{~N}$.", "answer": 0.405} 27 | {"content": "When an electron in a certain excited energy level in a one-dimensional box of length $2.00 Å$ makes a transition to the ground state, a photon of wavelength $8.79 \\mathrm{~nm}$ is emitted. Find the quantum number of the initial state. The unit of the answer should be .", "answer": 4.0} 28 | {"content": "For a macroscopic object of mass $1.0 \\mathrm{~g}$ moving with speed $1.0 \\mathrm{~cm} / \\mathrm{s}$ in a one-dimensional box of length $1.0 \\mathrm{~cm}$, find the quantum number $n$. The unit of the answer should be $10^{26}$.", "answer": 3.0} 29 | {"content": "For the $\\mathrm{H}_2$ ground electronic state, $D_0=4.4781 \\mathrm{eV}$. 
Find $\\Delta H_0^{\\circ}$ for $\\mathrm{H}_2(g) \\rightarrow 2 \\mathrm{H}(g)$ in $\\mathrm{kJ} / \\mathrm{mol}$ The unit of the answer should be $\\mathrm{~kJ} / \\mathrm{mol}$.", "answer": 432.07} 30 | {"content": "The contribution of molecular vibrations to the molar internal energy $U_{\\mathrm{m}}$ of a gas of nonlinear $N$-atom molecules is (zero-point vibrational energy not included) $U_{\\mathrm{m}, \\mathrm{vib}}=R \\sum_{s=1}^{3 N-6} \\theta_s /\\left(e^{\\theta_s / T}-1\\right)$, where $\\theta_s \\equiv h \\nu_s / k$ and $\\nu_s$ is the vibrational frequency of normal mode $s$. Calculate the contribution to $U_{\\mathrm{m}, \\text { vib }}$ at $25^{\\circ} \\mathrm{C}$ of a normal mode with wavenumber $\\widetilde{v} \\equiv v_s / c$ of $900 \\mathrm{~cm}^{-1}$. The unit of the answer should be $\\mathrm{kJ} / \\mathrm{mol}$.", "answer": 0.14} 31 | {"content": "Calculate the magnitude of the spin magnetic moment of an electron. The unit of the answer should be $10^{-23} \\mathrm{~J} / \\mathrm{T}$.", "answer": 1.61} 32 | {"content": "A particle is subject to the potential energy $V=a x^4+b y^4+c z^4$. If its ground-state energy is $10 \\mathrm{eV}$, calculate $\\langle V\\rangle$ for the ground state. The unit of the answer should be $\\mathrm{eV}$.", "answer": 3.333333333} 33 | {"content": "For an electron in a certain rectangular well with a depth of $20.0 \\mathrm{eV}$, the lowest energy level lies $3.00 \\mathrm{eV}$ above the bottom of the well. Find the width of this well. Hint: Use $\\tan \\theta=\\sin \\theta / \\cos \\theta$ The unit of the answer should be $\\mathrm{~nm}$.", "answer": 0.264} 34 | {"content": "Calculate the uncertainty $\\Delta L_z$ for the hydrogen-atom stationary state: $2 p_z$. The unit of the answer should be .", "answer": 0.0} 35 | -------------------------------------------------------------------------------- /eval_vm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | import pathlib 5 | import argparse 6 | from utils.json_operator import * 7 | from utils.verify_MATH import extract_answer 8 | 9 | 10 | def parse_args(): 11 | base_args = argparse.ArgumentParser() 12 | base_args.add_argument('--task_name', type=str, default='gsm_8k') 13 | base_args.add_argument('--file', type=str, default='gsm8k_all') # json 14 | base_args.add_argument('--propose_method', type=str, choices=['gpt', 'glm', 'llama', 'mistral', 'local'], 15 | default='mistral') 16 | base_args.add_argument('--generate_num', type=int, default=256) 17 | base_args.add_argument('--evaluate_method', type=str, choices=['best', 'weighted'], default='best') 18 | arguments = base_args.parse_args() 19 | return arguments 20 | 21 | 22 | def eval_vm(arguments): 23 | base_dir = os.getcwd() 24 | out_file = f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/cot/{arguments.propose_method}_local_vm_critic_all.json' 25 | datas = read_json(out_file) 26 | idx = 0 27 | corr_num = 0 28 | total_num = 0 29 | while idx < len(datas): 30 | total_num += 1 31 | cur_datas = datas[idx:idx + arguments.generate_num] 32 | idx += arguments.generate_num 33 | if arguments.evaluate_method == 'best': 34 | sorted_cur_datas = sorted(cur_datas, key=lambda x: x['vm_critic'], reverse=True) 35 | i = 0 36 | while not sorted_cur_datas[i]['summary'] and i < len(sorted_cur_datas) - 1: 37 | i += 1 38 | selected_data = sorted_cur_datas[i] 39 | if selected_data['accurate']: 40 | corr_num += 1 41 | 42 | elif arguments.evaluate_method == 
'weighted': 43 | all_answers = {} # {answer: [idx, summ, value]} 44 | for i, data in enumerate(cur_datas): 45 | summ = data['summary'] 46 | if not summ: 47 | continue 48 | 49 | extracted_answer = extract_answer(summ) 50 | if extracted_answer in all_answers.keys(): 51 | all_answers[extracted_answer][2] += data['vm_critic'] 52 | else: 53 | all_answers[extracted_answer] = [i, summ, data['vm_critic']] 54 | 55 | if not all_answers: 56 | continue 57 | best_answer = max(all_answers.values(), key=lambda x: x[2]) 58 | best_id = best_answer[0] 59 | if cur_datas[best_id]['accurate']: 60 | corr_num += 1 61 | 62 | else: 63 | print('evaluate_method not implemented') 64 | raise NotImplementedError 65 | 66 | print(f'Test accuracy:{corr_num / total_num}') 67 | print(f'Total number of samples tested:{total_num}') 68 | print(f'Test the correct number of samples:{corr_num}') 69 | 70 | 71 | if __name__ == '__main__': 72 | args = parse_args() 73 | eval_vm(args) 74 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from CoT.task import CoT_Task 4 | from ToT.task import ToT_Task 5 | from MCTS.task import MCTS_Task 6 | import argparse 7 | from utils.visualize import visualize 8 | from utils.json_operator import * 9 | from utils.verify_answer import * 10 | from utils.self_consistency import get_consistency_output_scibench 11 | 12 | 13 | def run(arguments): 14 | print('-'*30, 'Begin testing', '-'*30, '\n') 15 | file = f'data/{arguments.task_name}/{arguments.file}.json' 16 | try: 17 | data_list = read_json(file) 18 | data_len = len(data_list) 19 | except Exception as e: 20 | print(f'File must be standardized json!\nError type:{e}\n') 21 | return 22 | assert data_len > 0, "Data list is empty!\n" 23 | assert 'content' in data_list[0].keys() and 'answer' in data_list[0].keys(), "Key error, Make sure json object contain correct keys!\n" 24 | 25 | output_list = [] 26 | correct_count = 0 27 | for i in range(data_len): 28 | # solve 29 | print(f'Begin to solve the problem {i+1}...\n') 30 | data = data_list[i]['content'] 31 | answer = data_list[i]['answer'] 32 | if arguments.mode == 'cot': 33 | Task = CoT_Task(data, arguments.propose_method, arguments.value_method, arguments.temperature, evaluate=arguments.evaluate) 34 | if arguments.consistency: 35 | outputs = [] 36 | for cnt in range(3): 37 | output = Task.run() 38 | outputs.append(output) 39 | output = get_consistency_output_scibench(outputs) 40 | else: 41 | output = Task.run() 42 | 43 | elif arguments.mode == 'tot': 44 | Task = ToT_Task(data, arguments.propose_method, arguments.value_method, arguments.algorithm, 45 | arguments.branch, arguments.select_branch, arguments.max_depth, arguments.end_gate, 46 | arguments.select_method, arguments.temperature, use_case_prompt=arguments.use_case_prompt, 47 | low=arguments.low, high=arguments.high, evaluate=arguments.evaluate) 48 | output, root = Task.run() 49 | if arguments.visualize: 50 | visualize(root, Task, arguments.task_name, arguments.file, i + 1) 51 | else: 52 | Task = MCTS_Task(data, arguments.propose_method, arguments.value_method, arguments.branch, arguments.end_gate, 53 | arguments.roll_policy, arguments.roll_branch, arguments.roll_forward_steps, arguments.time_limit, 54 | arguments.iteration_limit, arguments.exploration_constant, arguments.alpha, arguments.inf, 55 | arguments.temperature, use_case_prompt=arguments.use_case_prompt, 
use_reflection=arguments.use_reflection, 56 | low=arguments.low, high=arguments.high, evaluate=arguments.evaluate) 57 | output, root = Task.run() 58 | if arguments.visualize: 59 | visualize(root, Task, arguments.task_name, arguments.file, i + 1) 60 | 61 | # evaluate metrics 62 | if arguments.evaluate: 63 | result = verify_float(answer, output['summary']) 64 | output.update({'answer': answer, 'accurate': result}) 65 | if result: 66 | print(f'The answer of problem {i+1} is correct.\n') 67 | correct_count += 1 68 | else: 69 | print(f'The answer of problem {i+1} is wrong.\n') 70 | print(f'The solution to problem {i+1} is complete.\n') 71 | 72 | # output 73 | base_dir = os.getcwd() 74 | output_dir = pathlib.Path(f'{base_dir}/outputs/{arguments.task_name}/{arguments.file}/{Task.mode}') 75 | output_file = f'{base_dir}/outputs/{arguments.task_name}/{arguments.file}/{Task.mode}/{Task.propose_method}_{Task.value_method}.json' 76 | output_list.append(output) 77 | pathlib.Path.mkdir(output_dir, exist_ok=True, parents=True) 78 | dump_json(output_file, output_list) 79 | 80 | print('_' * 60) 81 | # accuracy 82 | if args.evaluate: 83 | print(f'Test accuracy:{correct_count / data_len}\n') 84 | print(f'Correct number of problems:{correct_count}\nTotal number of questions:{data_len}\n') 85 | print('_' * 60) 86 | 87 | 88 | def parse_args(): 89 | base_args = argparse.ArgumentParser() 90 | base_args.add_argument('--task_name', type=str, default='scibench') 91 | base_args.add_argument('--file', type=str, default='thermo_standardized') # json 92 | base_args.add_argument('--propose_method', type=str, choices=['gpt', 'glm', 'llama', 'local'], default='glm') 93 | base_args.add_argument('--value_method', type=str, choices=['gpt', 'glm', 'local'], default='local') 94 | base_args.add_argument('--mode', type=str, choices=['cot', 'tot', 'mcts'], default='tot') 95 | base_args.add_argument('--temperature', type=float, default=0.7) 96 | base_args.add_argument('--time_limit', type=int, default=None) 97 | base_args.add_argument('--iteration_limit', type=int, default=100) 98 | base_args.add_argument('--roll_policy', type=str, choices=['random', 'greedy'], default='greedy') 99 | base_args.add_argument('--exploration_constant', type=float, default=0.4) 100 | base_args.add_argument('--roll_forward_steps', type=int, default=2) 101 | base_args.add_argument('--end_gate', type=float, default=0.9) # End threshold 102 | base_args.add_argument('--branch', type=int, default=3) 103 | base_args.add_argument('--roll_branch', type=int, default=1) 104 | base_args.add_argument('--inf', type=float, default=0.8) 105 | base_args.add_argument('--evaluate', type=str, default='scibench') # Whether to evaluate (empty means no evaluation) 106 | base_args.add_argument('--alpha', type=float, default=0.5) 107 | base_args.add_argument('--visualize', type=bool, default=False) # visualization 108 | base_args.add_argument('--use_case_prompt', type=bool, default=False) # Use sample prompts 109 | base_args.add_argument('--use_reflection', type=str, choices=['simple', 'common'], default='simple') # Use reflective mode 110 | base_args.add_argument('--low', type=float, default=0) 111 | base_args.add_argument('--high', type=float, default=1) 112 | base_args.add_argument('--algorithm', type=str, choices=['dfs', 'bfs'], default='dfs') 113 | base_args.add_argument('--select_branch', type=int, default=2) 114 | base_args.add_argument('--max_depth', type=int, default=8) 115 | base_args.add_argument('--select_method', type=str, choices=['greedy', 'sample'], 
default='greedy') 116 | base_args.add_argument('--consistency', type=bool, default=True) 117 | 118 | arguments = base_args.parse_args() 119 | return arguments 120 | 121 | 122 | if __name__ == '__main__': 123 | args = parse_args() 124 | run(args) 125 | -------------------------------------------------------------------------------- /figures/MATH2_completion_self_train.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/ReST-MCTS/2d5f488c3d6e24f99d50a9860e818383b1bb5883/figures/MATH2_completion_self_train.pdf -------------------------------------------------------------------------------- /figures/data/ablation_math2_self_training.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/ReST-MCTS/2d5f488c3d6e24f99d50a9860e818383b1bb5883/figures/data/ablation_math2_self_training.xlsx -------------------------------------------------------------------------------- /figures/plot_math_self_training.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | from matplotlib import rcParams 5 | 6 | # sns.set_theme(style="darkgrid") 7 | all_data = [] 8 | source = './data/ablation_math2_self_training.xlsx' 9 | data = pd.read_excel(source, sheet_name=0) 10 | # print(data) 11 | 12 | y_fontsize = 27 13 | y_title_fontsize = 27 14 | legend_fontsize = 15 15 | config = { 16 | 'font.family': 'Times New Roman', 17 | 'font.size': 28, 18 | } 19 | rcParams.update(config) 20 | 21 | # Plot the responses for different events and regions 22 | # plt.rc('font', family='Times New Roman', size=24) 23 | plt.figure(figsize=(11, 8)) 24 | # sns.lineplot(x="completion_per_question(k)", y="acc", hue="Method", err_style="bars", style="Method", marker="*", linewidth=3, data=data) 25 | sns.lineplot(x="completion_per_question(k)", y="acc", data=data, err_style="bars", hue="Method", style="Method", markers=True, linewidth=4, dashes=False, errorbar=('ci', 50), markersize=15) 26 | plt.xlabel("Completion Tokens (Average Per Question)", fontsize=y_title_fontsize) 27 | # plt.xticks(fontsize=18) 28 | # plt.xticks([0, 2, 4, 6, 8], ['0', '10,000', '20,000', '30,000', '40,000'], fontsize=y_fontsize) 29 | # plt.xticks([0, 2, 4, 6, 8], ['0', '10', '20', '30', '40'], fontsize=y_fontsize) 30 | plt.xticks([0, 10000, 20000, 30000, 40000], ['0', '10,000', '20,000', '30,000', '40,000'], fontsize=y_fontsize) 31 | plt.yticks(fontsize=y_fontsize) 32 | plt.ylabel("Accuracy", fontsize=y_title_fontsize) 33 | # plt.legend(fontsize=legend_fontsize) 34 | handles, labels = plt.gca().get_legend_handles_labels() 35 | order = [4,1,3,0,2] 36 | plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order], fontsize=legend_fontsize) 37 | plt.savefig('MATH2_completion_self_train.pdf') 38 | # plt.show() 39 | 40 | -------------------------------------------------------------------------------- /models/get_response.py: -------------------------------------------------------------------------------- 1 | from models.model import * 2 | 3 | 4 | # given prompt, generate proposal under instruction, unwrap is required 5 | def get_proposal(prompt, method='glm', temperature=0.7, max_tokens=2048, seed=170, max_length=2048, truncation=True, 6 | do_sample=True, max_new_tokens=1024): 7 | response = [] 8 | cnt = 2 9 | if method == 'glm': 10 | while not response and cnt: 11 | response = glm(prompt, BASE_MODEL_GLM, 
temperature=temperature, max_tokens=max_tokens, seed=seed) 12 | cnt -= 1 13 | if not response: 14 | print(f'obtain<{method}>response fail!\n') 15 | return [] 16 | return response 17 | 18 | elif method == 'gpt': 19 | while not response and cnt: 20 | response = gpt(prompt, model=BASE_MODEL_GPT, temperature=temperature, max_tokens=max_tokens) 21 | cnt -= 1 22 | if not response: 23 | print(f'obtain<{method}>response fail!\n') 24 | return [] 25 | return response 26 | 27 | elif method == 'llama' or method == 'mistral' or method == 'local': 28 | while not response and cnt: 29 | response = local_inference_model(prompt, max_length=max_length, truncation=truncation, do_sample=do_sample, 30 | max_new_tokens=max_new_tokens, temperature=temperature) 31 | cnt -= 1 32 | if not response: 33 | print(f'obtain<{method}>response fail!\n') 34 | return [] 35 | return response 36 | 37 | else: 38 | print('This method of getting responses is not yet supported!\n') 39 | return [] 40 | 41 | 42 | # given prompt + answer, find its value 43 | # if you use api, unwrap is required. if you use local value model, the value is directly obtained 44 | def get_value(prompt_answer, method='glm', temperature=0.7, max_tokens=1000, seed=170, max_length=2048, low=0, high=1): 45 | response = [] 46 | cnt = 2 47 | if method == 'glm': 48 | while not response and cnt: 49 | response = glm(prompt_answer, BASE_MODEL_GLM, temperature=temperature, max_tokens=max_tokens, seed=seed) 50 | cnt -= 1 51 | if not response: 52 | print(f'obtain<{method}>score fail!\n') 53 | return [] 54 | return response 55 | 56 | elif method == 'gpt': 57 | while not response and cnt: 58 | response = gpt(prompt_answer, model=BASE_MODEL_GPT, temperature=temperature, max_tokens=max_tokens) 59 | cnt -= 1 60 | if not response: 61 | print(f'obtain<{method}>score fail!\n') 62 | return [] 63 | return response 64 | 65 | elif method == 'local': 66 | value = low 67 | while cnt: 68 | try: 69 | value = local_value_model(prompt_answer, max_length=max_length, low=low, high=high) 70 | break 71 | except Exception as e: 72 | print(f'obtain<{method}>score fail!\nError:{e}\n') 73 | cnt -= 1 74 | return value 75 | 76 | else: 77 | print('This method of getting scores is not yet supported!\n') 78 | return [] 79 | -------------------------------------------------------------------------------- /models/inference_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM 5 | 6 | 7 | # get model and tokenizer 8 | def get_inference_model(model_dir): 9 | inference_tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) 10 | inference_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).half().cuda() 11 | inference_model.eval() 12 | return inference_tokenizer, inference_model 13 | 14 | 15 | # get llama model and tokenizer 16 | def get_inference_model_llama(model_dir): 17 | inference_model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16) 18 | inference_tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) 19 | device = "cuda" 20 | inference_model.to(device) 21 | return inference_tokenizer, inference_model 22 | 23 | 24 | # get mistral model and tokenizer 25 | def get_inference_model_mistral(model_dir): 26 | inference_model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16) 
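    # NOTE: the model is loaded in bfloat16 and placed on a single CUDA device below;
    # the matching tokenizer is loaded next (pad_token handling is left commented out).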
27 | inference_tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) 28 | # inference_tokenizer.pad_token = inference_tokenizer.eos_token 29 | device = "cuda" 30 | inference_model.to(device) 31 | return inference_tokenizer, inference_model 32 | 33 | 34 | # get glm model response 35 | def get_local_response(query, model, tokenizer, max_length=2048, truncation=True, do_sample=False, max_new_tokens=1024, temperature=0.7): 36 | cnt = 2 37 | all_response = '' 38 | while cnt: 39 | try: 40 | inputs = tokenizer([query], return_tensors="pt", truncation=truncation, max_length=max_length).to('cuda') 41 | output_ = model.generate(**inputs, do_sample=do_sample, max_new_tokens=max_new_tokens, temperature=temperature) 42 | output = output_.tolist()[0][len(inputs["input_ids"][0]):] 43 | response = tokenizer.decode(output) 44 | 45 | print(f'obtain response:{response}\n') 46 | all_response = response 47 | break 48 | except Exception as e: 49 | print(f'Error:{e}, obtain response again...\n') 50 | cnt -= 1 51 | if not cnt: 52 | return [] 53 | split_response = all_response.strip().split('\n') 54 | return split_response 55 | 56 | 57 | # get llama model response 58 | def get_local_response_llama(query, model, tokenizer, max_length=2048, truncation=True, max_new_tokens=1024, temperature=0.7, do_sample=False): 59 | cnt = 2 60 | all_response = '' 61 | # messages = [{"role": "user", "content": query}] 62 | # data = tokenizer.apply_chat_template(messages, return_tensors="pt").cuda() 63 | terminators = [ 64 | tokenizer.eos_token_id, 65 | tokenizer.convert_tokens_to_ids("<|eot_id|>") 66 | ] 67 | message = '<|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'.format(query=query) 68 | data = tokenizer.encode_plus(message, max_length=max_length, truncation=truncation, return_tensors='pt') 69 | input_ids = data['input_ids'].to('cuda') 70 | attention_mask = data['attention_mask'].to('cuda') 71 | while cnt: 72 | try: 73 | # query = "Human: " + query + "Assistant: " 74 | # input_ids = tokenizer([query], return_tensors="pt", add_special_tokens=False).input_ids.to('cuda') 75 | output = model.generate(input_ids, attention_mask=attention_mask, do_sample=do_sample, max_new_tokens=max_new_tokens, temperature=temperature, eos_token_id=terminators, pad_token_id=tokenizer.eos_token_id) 76 | ori_string = tokenizer.decode(output[0], skip_special_tokens=False) 77 | processed_string = ori_string.split('<|end_header_id|>')[2].strip().split('<|eot_id|>')[0].strip() 78 | response = processed_string.split('<|end_of_text|>')[0].strip() 79 | 80 | # print(f'获得回复:{response}\n') 81 | all_response = response 82 | break 83 | except Exception as e: 84 | print(f'Error:{e}, obtain response again...\n') 85 | cnt -= 1 86 | if not cnt: 87 | return [] 88 | # split_response = all_response.split("Assistant:")[-1].strip().split('\n') 89 | split_response = all_response.split('\n') 90 | return split_response 91 | 92 | 93 | # get mistral model response 94 | def get_local_response_mistral(query, model, tokenizer, max_length=1024, truncation=True, max_new_tokens=1024, temperature=0.7, do_sample=False): 95 | cnt = 2 96 | all_response = '' 97 | # messages = [{"role": "user", "content": query}] 98 | # data = tokenizer.apply_chat_template(messages, max_length=max_length, truncation=truncation, return_tensors="pt").cuda() 99 | message = '[INST]' + query + '[/INST]' 100 | data = tokenizer.encode_plus(message, max_length=max_length, truncation=truncation, return_tensors='pt') 101 | 
input_ids = data['input_ids'].to('cuda') 102 | attention_mask = data['attention_mask'].to('cuda') 103 | while cnt: 104 | try: 105 | output = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=do_sample, temperature=temperature, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id) 106 | ori_string = tokenizer.decode(output[0]) 107 | processed_string = ori_string.split('[/INST]')[1].strip() 108 | response = processed_string.split('')[0].strip() 109 | 110 | print(f'obtain response:{response}\n') 111 | all_response = response 112 | break 113 | except Exception as e: 114 | print(f'Error:{e}, obtain response again...\n') 115 | cnt -= 1 116 | if not cnt: 117 | return [] 118 | all_response = all_response.split('The answer is:')[0].strip() # intermediate steps should not always include a final answer 119 | ans_count = all_response.split('####') 120 | if len(ans_count) >= 2: 121 | all_response = ans_count[0] + 'Therefore, the answer is:' + ans_count[1] 122 | all_response = all_response.replace('[SOL]', '').replace('[ANS]', '').replace('[/ANS]', '').replace('[INST]', '').replace('[/INST]', '').replace('[ANSW]', '').replace('[/ANSW]', '') # remove unique answer mark for mistral 123 | split_response = all_response.split('\n') 124 | return split_response 125 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import openai 3 | import backoff 4 | import requests 5 | import json 6 | from models.inference_models import get_local_response, get_inference_model, get_inference_model_llama, get_local_response_llama, get_inference_model_mistral, get_local_response_mistral 7 | from models.value_models import get_local_value, get_value_model, get_value_model_prm, get_value_model_mistral, get_value_model_prm_mistral 8 | from transformers import AutoModel, AutoTokenizer 9 | 10 | # openai api settings 11 | API_KEY = 'sk-**' 12 | API_BASE = 'base' 13 | BASE_MODEL_GPT = "gpt-3.5-turbo" 14 | 15 | # GLM api settings 16 | URL = "https://api.chatglm.cn/v1/chat/completions" 17 | ID = "**" 18 | AUTH = '**' 19 | CONTENT_TYPE = 'application/json; charset=utf-8' 20 | BASE_MODEL_GLM = 'GLM4' 21 | 22 | # local model settings 23 | # if you want to use local models, set these two directories 24 | # INFERENCE_MODEL_DIR = "/workspace/ckpt/Meta-Llama-3-8B-Instruct" 25 | INFERENCE_MODEL_DIR = None 26 | LOCAL_INFERENCE_TYPES = ['glm', 'llama', 'mistral'] 27 | LOCAL_INFERENCE_IDX = 0 28 | 29 | # VALUE_BASE_MODEL_DIR = "/workspace/ckpt/MetaMath-Mistral-7B" 30 | VALUE_BASE_MODEL_DIR = None 31 | # VALUE_MODEL_STATE_DICT = "/Path/to/PRM/records/Mistral/VM_best_checkpoint.pt" 32 | VALUE_MODEL_STATE_DICT = None 33 | LOCAL_VALUE_TYPES = ['glm', 'mistral'] 34 | LOCAL_VALUE_IDX = 0 35 | USE_PRM = False 36 | 37 | INFERENCE_LOCAL = False 38 | VALUE_LOCAL = False 39 | 40 | # implement the inference model 41 | if INFERENCE_MODEL_DIR is not None: 42 | INFERENCE_LOCAL = True 43 | inference_type = LOCAL_INFERENCE_TYPES[LOCAL_INFERENCE_IDX] 44 | if inference_type == 'glm': 45 | inference_tokenizer, inference_model = get_inference_model(INFERENCE_MODEL_DIR) 46 | elif inference_type == 'llama': 47 | inference_tokenizer, inference_model = get_inference_model_llama(INFERENCE_MODEL_DIR) 48 | else: 49 | inference_tokenizer, inference_model = get_inference_model_mistral(INFERENCE_MODEL_DIR) 50 | 51 | # implement the value model (reward model) 52 | 
if VALUE_BASE_MODEL_DIR is not None: 53 | VALUE_LOCAL = True 54 | value_type = LOCAL_VALUE_TYPES[LOCAL_VALUE_IDX] 55 | if USE_PRM: 56 | if value_type == 'glm': 57 | value_tokenizer, value_model = get_value_model_prm(VALUE_BASE_MODEL_DIR, VALUE_MODEL_STATE_DICT) 58 | else: 59 | value_tokenizer, value_model = get_value_model_prm_mistral(VALUE_BASE_MODEL_DIR, VALUE_MODEL_STATE_DICT) 60 | else: 61 | if value_type == 'glm': 62 | value_tokenizer, value_model = get_value_model(VALUE_BASE_MODEL_DIR, VALUE_MODEL_STATE_DICT) 63 | else: 64 | value_tokenizer, value_model = get_value_model_mistral(VALUE_BASE_MODEL_DIR, VALUE_MODEL_STATE_DICT) 65 | 66 | completion_tokens = prompt_tokens = 0 67 | api_key = API_KEY 68 | if api_key != "": 69 | openai.api_key = api_key 70 | print(f'api_key:{api_key}\n') 71 | else: 72 | print("Warning: OPENAI_API_KEY is not set") 73 | 74 | api_base = API_BASE 75 | if api_base != "": 76 | print("Warning: OPENAI_API_BASE is set to {}".format(api_base)) 77 | openai.api_base = api_base 78 | 79 | 80 | @backoff.on_exception(backoff.expo, openai.error.OpenAIError) 81 | def completions_with_backoff(**kwargs): 82 | return openai.ChatCompletion.create(**kwargs) 83 | 84 | 85 | def gpt(prompt, model=BASE_MODEL_GPT, temperature=0.7, max_tokens=1000, n=1, stop=None) -> list: 86 | messages = [{"role": "user", "content": prompt}] 87 | out = [] 88 | cnt = 5 89 | while cnt: 90 | try: 91 | out = chatgpt(messages, model=model, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)[ 92 | 0].split('\n') 93 | break 94 | except Exception as e: 95 | print(f"Error occurred when getting gpt reply!\nError type:{e}\n") 96 | cnt -= 1 97 | return out 98 | 99 | 100 | def chatgpt(messages, model=BASE_MODEL_GPT, temperature=0.7, max_tokens=1000, n=1, stop=None) -> list: 101 | global completion_tokens, prompt_tokens 102 | outputs = [] 103 | while n > 0: 104 | cnt = min(n, 20) 105 | n -= cnt 106 | res = completions_with_backoff(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, 107 | n=cnt, stop=stop) 108 | # print(f'得到GPT回复:{res}\n\n') 109 | outputs.extend([choice["message"]["content"] for choice in res["choices"]]) 110 | # log completion tokens 111 | completion_tokens += res["usage"]["completion_tokens"] 112 | prompt_tokens += res["usage"]["prompt_tokens"] 113 | return outputs 114 | 115 | 116 | def gpt_usage(backend=BASE_MODEL_GPT): 117 | global completion_tokens, prompt_tokens 118 | if backend == "gpt-4": 119 | cost = completion_tokens / 1000 * 0.06 + prompt_tokens / 1000 * 0.03 120 | elif backend == "gpt-3.5-turbo": 121 | cost = completion_tokens / 1000 * 0.002 + prompt_tokens / 1000 * 0.0015 122 | else: 123 | cost = -1 124 | return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost} 125 | 126 | 127 | def extract_data(text): 128 | lines = text.split('\n') 129 | extracted_data = [] 130 | should_extract = False 131 | 132 | for line in lines: 133 | if line.startswith('event: finish'): 134 | should_extract = True 135 | elif should_extract and line.startswith('data: '): # and "left: " in line: 136 | if len(line[6:]) > 0: 137 | extracted_data.append(line[6:]) # Remove 'data: ' prefix, remain '\n' 138 | return extracted_data 139 | 140 | 141 | def glm(prompt, model=BASE_MODEL_GLM, temperature=0.7, max_tokens=1000, seed=170) -> list: 142 | return get_glm_reply(prompt, model, temperature=temperature, max_tokens=max_tokens, seed=seed) 143 | 144 | 145 | def get_glm_reply(query, model, temperature=0.7, max_tokens=1000, seed=175): 146 | if model == 
'ChatGLM2': 147 | url = URL 148 | payload = { 149 | "id": ID, 150 | "prompt": query, 151 | "seed": seed, 152 | "max_tokens": str(max_tokens), 153 | "temperature": temperature, 154 | } 155 | headers = { 156 | 'Authorization': AUTH, 157 | 'Content-Type': CONTENT_TYPE 158 | } 159 | 160 | tol = 3 161 | response = None 162 | while tol: 163 | try: 164 | response = requests.post(url, headers=headers, data=json.dumps(payload)) 165 | break 166 | except Exception as e: 167 | print(f'Error occurred when getting proxy reply!\nError type:{e}\nRetrying...\n') 168 | tol -= 1 169 | 170 | if response is None: 171 | print('Error occurred when getting proxy reply!\n') 172 | return [] 173 | 174 | reply = response.content.decode('utf-8') 175 | replies = extract_data(reply) 176 | return replies 177 | 178 | elif model == 'GLM4': 179 | url = URL 180 | payload = { 181 | 'model': "glm4-alltools-130b-awq", 182 | "messages": [{"role": "user", "content": query}], 183 | "temperature": temperature, 184 | "top_p": 0.7, 185 | "stream": False, 186 | "max_tokens": max_tokens 187 | } 188 | headers = { 189 | 'Authorization': AUTH, 190 | 'Content-Type': CONTENT_TYPE 191 | } 192 | 193 | tol = 3 194 | response = None 195 | while tol: 196 | try: 197 | response = requests.post(url, headers=headers, data=json.dumps(payload)) 198 | break 199 | except Exception as e: 200 | print(f'Error occurred when getting proxy reply!\nError type:{e}\nRetrying...\n') 201 | tol -= 1 202 | 203 | if response is None: 204 | print('Error occurred when getting proxy reply!\n') 205 | return [] 206 | 207 | reply = response.content.decode('utf-8') 208 | # print('reply:', reply) 209 | try: 210 | content = reply.split("\"content\":\"")[1].split("\",\"role\":\"assistant\"")[0] 211 | except Exception as e: 212 | print(f'Error occurred when decoding reply!\nError type:{e}\n') 213 | return [] 214 | return content.split('\n') 215 | 216 | elif model == 'GLM3': 217 | url = URL 218 | payload = { 219 | 'model': "chatglm3-32b-v0.8", 220 | "messages": [{"role": "user", "content": query}], 221 | "temperature": temperature, 222 | "top_p": 0.7, 223 | "stream": False, 224 | "max_tokens": max_tokens 225 | } 226 | headers = { 227 | 'Authorization': AUTH, 228 | 'Content-Type': CONTENT_TYPE 229 | } 230 | 231 | tol = 3 232 | response = None 233 | while tol: 234 | try: 235 | response = requests.post(url, headers=headers, data=json.dumps(payload)) 236 | break 237 | except Exception as e: 238 | print(f'Error occurred when getting proxy reply!\nError type:{e}\nRetrying...\n') 239 | tol -= 1 240 | 241 | if response is None: 242 | print('Error occurred when getting proxy reply!\n') 243 | return [] 244 | 245 | reply = response.content.decode('utf-8') 246 | # print('reply:', reply) 247 | try: 248 | content = reply.split("\"content\":\"")[1].split("\",\"role\":\"assistant\"")[0] 249 | except Exception as e: 250 | print(f'Error occurred when decoding reply!\nError type:{e}\n') 251 | return [] 252 | return content.split('\n') 253 | 254 | else: 255 | print('Unsupported glm model!\n') 256 | return [] 257 | 258 | 259 | def local_inference_model(query, max_length=2048, truncation=True, do_sample=False, max_new_tokens=1024, 260 | temperature=0.7): 261 | assert INFERENCE_LOCAL, "Inference model not implemented!\n" 262 | if inference_type == 'glm': 263 | return get_local_response(query, inference_model, inference_tokenizer, max_length=max_length, 264 | truncation=truncation, 265 | do_sample=do_sample, max_new_tokens=max_new_tokens, temperature=temperature) 266 | elif inference_type == 'llama': 
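        # max_length and truncation are not forwarded in this branch, so the llama path
        # falls back to get_local_response_llama's defaults (max_length=2048, truncation=True).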
267 | return get_local_response_llama(query, inference_model, inference_tokenizer, max_new_tokens=max_new_tokens, 268 | temperature=temperature, do_sample=do_sample) 269 | else: 270 | return get_local_response_mistral(query, inference_model, inference_tokenizer, max_new_tokens=max_new_tokens, 271 | temperature=temperature, do_sample=do_sample) 272 | 273 | 274 | def local_value_model(prompt_answer, max_length=2048, low=0, high=1): 275 | assert VALUE_LOCAL, "Value model not implemented!\n" 276 | return get_local_value(prompt_answer, value_model, value_tokenizer, max_length=max_length, low=low, high=high) 277 | -------------------------------------------------------------------------------- /models/value_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '6' 4 | import torch 5 | import torch.nn as nn 6 | from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM 7 | 8 | 9 | # define your value model class 10 | class ChatGLM_VM(nn.Module): 11 | def __init__(self, base, vocab_size, num_classes=1): 12 | super(ChatGLM_VM, self).__init__() 13 | self.base_model = base 14 | self.LN = nn.Linear(vocab_size, num_classes, dtype=torch.bfloat16) 15 | 16 | def forward(self, input_ids, attention_mask): 17 | outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask).logits[:, -1] 18 | value_outputs = self.LN(outputs) 19 | return value_outputs.squeeze(dim=1) 20 | 21 | 22 | class Mistral_VM(nn.Module): 23 | def __init__(self, base, vocab_size=32000): 24 | super(Mistral_VM, self).__init__() 25 | self.base_model = base 26 | self.LN = nn.Linear(vocab_size, 1) 27 | 28 | def forward(self, input_ids, attention_mask): 29 | outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask).logits[:, -1, :] 30 | value_outputs = self.LN(outputs) 31 | return value_outputs.squeeze(dim=1) 32 | 33 | 34 | class ChatGLM_PRM(nn.Module): 35 | def __init__(self, base): 36 | super(ChatGLM_PRM, self).__init__() 37 | self.base_model = base 38 | 39 | def forward(self, input_ids, attention_mask): 40 | outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask).logits 41 | probs = torch.softmax(outputs, dim=-1) 42 | output = probs[:, -1, 7081] # n*1 tensor, 7081 is the index of token 'True' 43 | return output 44 | 45 | 46 | class Mistral_PRM(nn.Module): 47 | def __init__(self, base): 48 | super(Mistral_PRM, self).__init__() 49 | self.base_model = base 50 | 51 | def forward(self, input_ids, attention_mask): 52 | outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask).logits 53 | probs = torch.softmax(outputs, dim=-1) 54 | output = probs[:, -1, 7081] # n*1 tensor, 7081 is the index of token 'True' 55 | return output 56 | 57 | 58 | # get value model 59 | def get_value_model(base_model_dir, state_dict_file): 60 | value_tokenizer = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True) 61 | value_base_model = AutoModel.from_pretrained(base_model_dir, trust_remote_code=True).bfloat16().cuda() 62 | if state_dict_file is None: 63 | return value_tokenizer, value_base_model 64 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 65 | print("device is set to: ", device, '\n') 66 | vocab_size = value_base_model.config.padded_vocab_size 67 | VM = ChatGLM_VM(value_base_model, vocab_size, 1) 68 | VM.load_state_dict(torch.load(state_dict_file)) 69 | VM.to(device) 70 | VM.eval() 71 | return value_tokenizer, VM 72 | 73 | 74 | def 
get_value_model_mistral(base_model_dir, state_dict_file): 75 | value_tokenizer = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True) 76 | # value_tokenizer.pad_token = value_tokenizer.eos_token 77 | value_base_model = AutoModelForCausalLM.from_pretrained(base_model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16) 78 | if state_dict_file is None: 79 | return value_tokenizer, value_base_model 80 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 81 | print("device is set to: ", device, '\n') 82 | vocab_size = value_base_model.config.vocab_size 83 | VM = Mistral_VM(value_base_model, vocab_size) 84 | VM.load_state_dict(torch.load(state_dict_file)) 85 | VM.to(device) 86 | VM.eval() 87 | return value_tokenizer, VM 88 | 89 | 90 | # get prm 91 | def get_value_model_prm(base_model_dir, state_dict_file): 92 | prm_tokenizer = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True) 93 | prm_base_model = AutoModel.from_pretrained(base_model_dir, trust_remote_code=True).bfloat16().cuda() 94 | if state_dict_file is None: 95 | return prm_tokenizer, prm_base_model 96 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 97 | print("device is set to: ", device, '\n') 98 | prm = ChatGLM_PRM(prm_base_model) 99 | prm.load_state_dict(torch.load(state_dict_file)) 100 | prm.to(device) 101 | prm.eval() 102 | return prm_tokenizer, prm 103 | 104 | 105 | def get_value_model_prm_mistral(base_model_dir, state_dict_file): 106 | prm_tokenizer = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True) 107 | # prm_tokenizer.pad_token = prm_tokenizer.eos_token 108 | prm_base_model = AutoModelForCausalLM.from_pretrained(base_model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16) 109 | if state_dict_file is None: 110 | return prm_tokenizer, prm_base_model 111 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 112 | print("device is set to: ", device, '\n') 113 | prm = Mistral_PRM(prm_base_model) 114 | prm.load_state_dict(torch.load(state_dict_file)) 115 | prm.to(device) 116 | prm.eval() 117 | return prm_tokenizer, prm 118 | 119 | 120 | # local value model: str->digit in [low, high] 121 | def get_local_value(prompt_answer, model, tokenizer, max_length=2048, low=0, high=1): 122 | encoded_pair = tokenizer.encode_plus( 123 | prompt_answer, 124 | padding='max_length', 125 | max_length=max_length, # Set the max length 126 | truncation=True, 127 | return_tensors='pt', # Return PyTorch Tensor format 128 | ) 129 | input_ids = encoded_pair['input_ids'].to('cuda') 130 | # print(input_ids) 131 | attention_mask = encoded_pair['attention_mask'].to('cuda') 132 | value = model(input_ids, attention_mask).item() 133 | value = min(high, max(value, low)) 134 | return value 135 | -------------------------------------------------------------------------------- /requirements_mistral.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/ReST-MCTS/2d5f488c3d6e24f99d50a9860e818383b1bb5883/requirements_mistral.txt -------------------------------------------------------------------------------- /requirements_sciglm.txt: -------------------------------------------------------------------------------- 1 | --index-url https://pypi.tuna.tsinghua.edu.cn/simple 2 | protobuf==3.20.0 3 | transformers==4.30.2 4 | cpm_kernels 5 | torch==2.1.0 6 | gradio==3.48.0 7 | mdtex2html==1.2.0 8 | sentencepiece==0.1.99 9 | accelerate==0.23.0 10 | sse-starlette==1.6.5 11 | 
streamlit>=1.24.0 12 | rouge_chinese 13 | jieba==0.42.1 14 | datasets==2.13.0 15 | nltk==3.8.1 16 | deepspeed==0.11.1 17 | wandb 18 | sympy==1.12 -------------------------------------------------------------------------------- /self_train/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/ReST-MCTS/2d5f488c3d6e24f99d50a9860e818383b1bb5883/self_train/.DS_Store -------------------------------------------------------------------------------- /self_train/config/deep3_new_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: true 3 | deepspeed_config: 4 | gradient_accumulation_steps: 2 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: true 9 | zero3_save_16bit_model: true 10 | zero_stage: 3 11 | distributed_type: DEEPSPEED 12 | downcast_bf16: 'no' 13 | dynamo_config: 14 | dynamo_backend: CUDAGRAPHS 15 | dynamo_mode: default 16 | dynamo_use_dynamic: true 17 | dynamo_use_fullgraph: true 18 | enable_cpu_affinity: false 19 | machine_rank: 0 20 | main_training_function: main 21 | mixed_precision: bf16 22 | num_machines: 1 23 | num_processes: 8 24 | rdzv_backend: static 25 | same_network: true 26 | tpu_env: [] 27 | tpu_use_cluster: false 28 | tpu_use_sudo: false 29 | use_cpu: false 30 | -------------------------------------------------------------------------------- /self_train/config/deepspeed_zero1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | gradient_accumulation_steps: 1 6 | zero3_init_flag: false 7 | zero_stage: 1 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | machine_rank: 0 11 | main_training_function: main 12 | mixed_precision: 'bf16' 13 | num_machines: 1 14 | num_processes: 8 15 | rdzv_backend: static 16 | same_network: true 17 | tpu_env: [] 18 | tpu_use_cluster: false 19 | tpu_use_sudo: false 20 | use_cpu: false 21 | -------------------------------------------------------------------------------- /self_train/config/deepspeed_zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | gradient_accumulation_steps: 1 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: 'bf16' 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /self_train/config/deepspeed_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "compute_environment": "LOCAL_MACHINE", 3 | "debug": false, 4 | "deepspeed_config": { 5 | "deepspeed_multinode_launcher": "standard", 6 | "gradient_accumulation_steps": 2, 7 | "offload_optimizer_device": "none", 8 | "offload_param_device": "none", 9 | "zero3_init_flag": true, 10 | "zero3_save_16bit_model": true, 11 | "zero_stage": 3 12 | }, 13 | "distributed_type": 
"DEEPSPEED", 14 | "downcast_bf16": "yes", 15 | "machine_rank": 0, 16 | "main_training_function": "main", 17 | "mixed_precision": "bf16", 18 | "num_machines": 1, 19 | "num_processes": 8, 20 | "rdzv_backend": "static", 21 | "same_network": true, 22 | "tpu_env": [], 23 | "tpu_use_cluster": false, 24 | "tpu_use_sudo": false, 25 | "use_cpu": false 26 | } -------------------------------------------------------------------------------- /self_train/config/deepspeed_zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | gradient_accumulation_steps: 1 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: true 9 | zero3_save_16bit_model: true 10 | zero_stage: 3 11 | distributed_type: DEEPSPEED 12 | downcast_bf16: 'no' 13 | machine_rank: 0 14 | main_training_function: main 15 | mixed_precision: 'bf16' 16 | num_machines: 1 17 | num_processes: 8 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_env: [] 21 | tpu_use_cluster: false 22 | tpu_use_sudo: false 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /self_train/config/default_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: true 3 | deepspeed_config: 4 | gradient_accumulation_steps: 2 5 | gradient_clipping: 1.0 6 | zero3_init_flag: false 7 | zero_stage: 1 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | dynamo_config: 11 | dynamo_backend: CUDAGRAPHS 12 | enable_cpu_affinity: false 13 | machine_rank: 0 14 | main_training_function: main 15 | mixed_precision: bf16 16 | num_machines: 1 17 | num_processes: 8 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_env: [] 21 | tpu_use_cluster: false 22 | tpu_use_sudo: false 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /self_train/config/yaml_to_json.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import json 3 | 4 | 5 | # yaml文件内容转换成json格式 6 | def yaml_to_json(yamlPath): 7 | with open(yamlPath, encoding="utf-8") as f: 8 | datas = yaml.load(f, Loader=yaml.FullLoader) # 将文件的内容转换为字典形式 9 | jsonDatas = json.dumps(datas, indent=5) # 将字典的内容转换为json格式的字符串 10 | return jsonDatas 11 | 12 | 13 | if __name__ == "__main__": 14 | yamlPath = 'deepspeed_zero3.yaml' 15 | with open(yamlPath.replace('.yaml', '.json'), 'w', encoding='utf-8') as f: 16 | datas = yaml_to_json(yamlPath) 17 | f.write(datas) 18 | -------------------------------------------------------------------------------- /self_train/generation/generate_both_samples_GSM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from CoT.task import CoT_Task 4 | from ToT.task import ToT_Task 5 | from MCTS.task import MCTS_Task 6 | import argparse 7 | from utils.visualize import visualize 8 | from utils.json_operator import * 9 | 10 | 11 | def run(arguments): 12 | print('-'*30, '开始生成', '-'*30, '\n') 13 | base_dir = os.getcwd() 14 | file = f'{base_dir}/data/{arguments.task_name}/{arguments.file}.json' 15 | print(f'Reading data from {file}...\n') 16 | try: 17 | data_list = read_json(file) 18 | data_len = len(data_list) 19 | except Exception as e: 20 | print(f'File must be standardized json!\nError type:{e}\n') 21 | return 22 
| assert data_len > 0, "Data list is empty!\n" 23 | assert 'content' in data_list[0].keys() and 'answer' in data_list[0].keys(), "Key error, Make sure json object contain correct keys!\n" 24 | 25 | output_list = [] 26 | process = 0 27 | if arguments.partial: 28 | base_dir = os.getcwd() 29 | output_file = f'{base_dir}/generation/{arguments.reward_type}/{arguments.task_name}/{arguments.file}/{arguments.mode}/{arguments.propose_method}_{arguments.value_method}_start{arguments.start_num}_end{arguments.end_num}.json' 30 | output_list = read_json(output_file) 31 | process = len(output_list) / arguments.generate_num 32 | 33 | for i in range(data_len): 34 | if i < arguments.start_num + process or i >= arguments.end_num: 35 | continue 36 | # solve 37 | print(f'开始解答第{i+1}题...\n') 38 | data = data_list[i]['content'] 39 | answer = data_list[i]['answer'] 40 | outputs = [] 41 | 42 | if arguments.mode == 'mcts': 43 | Task = MCTS_Task(data, arguments.propose_method, arguments.value_method, arguments.branch, arguments.end_gate, 44 | arguments.roll_policy, arguments.roll_branch, arguments.roll_forward_steps, arguments.time_limit, 45 | arguments.iteration_limit, arguments.exploration_constant, arguments.alpha, arguments.inf, 46 | arguments.temperature, use_case_prompt=arguments.use_case_prompt, use_reflection=arguments.use_reflection, 47 | low=arguments.low, high=arguments.high, evaluate=arguments.evaluate, sample_value='full', answer=answer, verify_method='string', lang=arguments.lang) 48 | for cnt in range(arguments.generate_num): 49 | output, root = Task.run() 50 | corr_policy_sample_num = sum([sample['correct'] for sample in output['policy_samples']]) 51 | total_policy_sample_num = len(output['policy_samples']) 52 | output.update({'corr_policy_sample_num': corr_policy_sample_num, 'total_policy_sample_num': total_policy_sample_num}) 53 | outputs.append(output) 54 | 55 | elif arguments.mode == 'cot': 56 | Task = CoT_Task(data, arguments.propose_method, arguments.value_method, arguments.temperature, evaluate=arguments.evaluate, lang=arguments.lang, answer=answer, verify_method='string', do_self_critic=arguments.do_self_critic) 57 | for cnt in range(arguments.generate_num): 58 | output = Task.run() 59 | outputs.append(output) 60 | 61 | else: 62 | print("Unsupported sample generation mode!\n") 63 | return 64 | 65 | for output in outputs: 66 | output_list.append(output) 67 | 68 | print(f'第{i+1}题解答结束。\n') 69 | 70 | # output 71 | base_dir = os.getcwd() 72 | output_dir = pathlib.Path(f'{base_dir}/generation/{arguments.reward_type}/{arguments.task_name}/{arguments.file}/{Task.mode}') 73 | output_file = f'{base_dir}/generation/{arguments.reward_type}/{arguments.task_name}/{arguments.file}/{Task.mode}/{Task.propose_method}_{Task.value_method}_start{arguments.start_num}_end{arguments.end_num}.json' 74 | 75 | pathlib.Path.mkdir(output_dir, exist_ok=True, parents=True) 76 | dump_json(output_file, output_list) 77 | print('_' * 60) 78 | 79 | 80 | def parse_args(): 81 | base_args = argparse.ArgumentParser() 82 | base_args.add_argument('--task_name', type=str, default='gsm_8k') 83 | base_args.add_argument('--file', type=str, default='gsm8k_self_train_1') # json 84 | base_args.add_argument('--reward_type', type=str, choices=['vm', 'prm'], default='vm') 85 | base_args.add_argument('--propose_method', type=str, choices=['gpt', 'glm', 'llama', 'mistral', 'local'], default='local') 86 | base_args.add_argument('--value_method', type=str, choices=['gpt', 'glm', 'local'], default='local') 87 | base_args.add_argument('--mode', 
type=str, choices=['cot', 'tot', 'mcts'], default='mcts') 88 | base_args.add_argument('--temperature', type=float, default=0.7) 89 | base_args.add_argument('--time_limit', type=int, default=None) 90 | base_args.add_argument('--iteration_limit', type=int, default=40) 91 | base_args.add_argument('--roll_policy', type=str, choices=['random', 'greedy'], default='greedy') 92 | base_args.add_argument('--exploration_constant', type=float, default=0.5) 93 | base_args.add_argument('--roll_forward_steps', type=int, default=1) 94 | base_args.add_argument('--end_gate', type=float, default=0.8) 95 | base_args.add_argument('--branch', type=int, default=3) 96 | base_args.add_argument('--roll_branch', type=int, default=1) 97 | base_args.add_argument('--inf', type=float, default=0.9) 98 | base_args.add_argument('--evaluate', type=str, default='') 99 | base_args.add_argument('--alpha', type=float, default=0.5) 100 | base_args.add_argument('--visualize', type=bool, default=False) 101 | base_args.add_argument('--use_case_prompt', type=bool, default=False) 102 | base_args.add_argument('--use_reflection', type=str, choices=['simple', 'common'], default='simple') 103 | base_args.add_argument('--low', type=float, default=0) 104 | base_args.add_argument('--high', type=float, default=1) 105 | base_args.add_argument('--algorithm', type=str, choices=['dfs', 'bfs'], default='dfs') 106 | base_args.add_argument('--select_branch', type=int, default=2) 107 | base_args.add_argument('--max_depth', type=int, default=8) 108 | base_args.add_argument('--select_method', type=str, choices=['greedy', 'sample'], default='greedy') 109 | base_args.add_argument('--generate_num', type=int, default=1) 110 | base_args.add_argument('--start_num', type=int, default=0) 111 | base_args.add_argument('--end_num', type=int, default=165) 112 | base_args.add_argument('--partial', type=bool, default=False) 113 | base_args.add_argument('--lang', type=str, choices=['zh', 'en'], default='en') 114 | base_args.add_argument('--do_self_critic', type=bool, default=True) # for CoT 115 | 116 | arguments = base_args.parse_args() 117 | return arguments 118 | 119 | 120 | if __name__ == '__main__': 121 | args = parse_args() 122 | run(args) 123 | -------------------------------------------------------------------------------- /self_train/generation/generate_both_samples_MATH.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from CoT.task import CoT_Task 4 | from ToT.task import ToT_Task 5 | from MCTS.task import MCTS_Task 6 | import argparse 7 | from utils.visualize import visualize 8 | from utils.json_operator import * 9 | 10 | 11 | def run(arguments): 12 | print('-'*30, '开始生成', '-'*30, '\n') 13 | base_dir = os.getcwd() 14 | file = f'{base_dir}/data/{arguments.task_name}/{arguments.file}.json' 15 | print(f'Reading data from {file}...\n') 16 | try: 17 | data_list = read_json(file) 18 | data_len = len(data_list) 19 | except Exception as e: 20 | print(f'File must be standardized json!\nError type:{e}\n') 21 | return 22 | assert data_len > 0, "Data list is empty!\n" 23 | assert 'content' in data_list[0].keys() and 'answer' in data_list[0].keys(), "Key error, Make sure json object contain correct keys!\n" 24 | 25 | output_list = [] 26 | process = 0 27 | if arguments.partial: 28 | base_dir = os.getcwd() 29 | output_file = 
f'{base_dir}/generation/{arguments.reward_type}/{arguments.task_name}/{arguments.file}/{arguments.mode}/{arguments.propose_method}_{arguments.value_method}_start{arguments.start_num}_end{arguments.end_num}.json' 30 | output_list = read_json(output_file) 31 | process = len(output_list) / arguments.generate_num 32 | 33 | for i in range(data_len): 34 | if i < arguments.start_num + process or i >= arguments.end_num: 35 | continue 36 | # solve 37 | print(f'开始解答第{i+1}题...\n') 38 | data = data_list[i]['content'] 39 | answer = data_list[i]['answer'] 40 | outputs = [] 41 | 42 | if arguments.mode == 'mcts': 43 | Task = MCTS_Task(data, arguments.propose_method, arguments.value_method, arguments.branch, arguments.end_gate, 44 | arguments.roll_policy, arguments.roll_branch, arguments.roll_forward_steps, arguments.time_limit, 45 | arguments.iteration_limit, arguments.exploration_constant, arguments.alpha, arguments.inf, 46 | arguments.temperature, use_case_prompt=arguments.use_case_prompt, use_reflection=arguments.use_reflection, 47 | low=arguments.low, high=arguments.high, evaluate=arguments.evaluate, sample_value='full', answer=answer, verify_method='string', lang=arguments.lang) 48 | for cnt in range(arguments.generate_num): 49 | output, root = Task.run() 50 | corr_policy_sample_num = sum([sample['correct'] for sample in output['policy_samples']]) 51 | total_policy_sample_num = len(output['policy_samples']) 52 | output.update({'corr_policy_sample_num': corr_policy_sample_num, 'total_policy_sample_num': total_policy_sample_num}) 53 | outputs.append(output) 54 | 55 | elif arguments.mode == 'cot': 56 | Task = CoT_Task(data, arguments.propose_method, arguments.value_method, arguments.temperature, evaluate=arguments.evaluate, lang=arguments.lang, answer=answer, verify_method='string', do_self_critic=arguments.do_self_critic) 57 | for cnt in range(arguments.generate_num): 58 | output = Task.run() 59 | outputs.append(output) 60 | 61 | else: 62 | print("Unsupported sample generation mode!\n") 63 | return 64 | 65 | for output in outputs: 66 | output_list.append(output) 67 | 68 | print(f'第{i+1}题解答结束。\n') 69 | 70 | # output 71 | base_dir = os.getcwd() 72 | output_dir = pathlib.Path(f'{base_dir}/generation/{arguments.reward_type}/{arguments.task_name}/{arguments.file}/{Task.mode}') 73 | output_file = f'{base_dir}/generation/{arguments.reward_type}/{arguments.task_name}/{arguments.file}/{Task.mode}/{Task.propose_method}_{Task.value_method}_start{arguments.start_num}_end{arguments.end_num}.json' 74 | 75 | pathlib.Path.mkdir(output_dir, exist_ok=True, parents=True) 76 | dump_json(output_file, output_list) 77 | print('_' * 60) 78 | 79 | 80 | def parse_args(): 81 | base_args = argparse.ArgumentParser() 82 | base_args.add_argument('--task_name', type=str, default='math') 83 | base_args.add_argument('--file', type=str, default='math_self_train_1') # json 84 | base_args.add_argument('--reward_type', type=str, choices=['vm', 'prm'], default='vm') 85 | base_args.add_argument('--propose_method', type=str, choices=['gpt', 'glm', 'llama', 'mistral', 'local'], default='local') 86 | base_args.add_argument('--value_method', type=str, choices=['gpt', 'glm', 'local'], default='local') 87 | base_args.add_argument('--mode', type=str, choices=['cot', 'tot', 'mcts'], default='mcts') 88 | base_args.add_argument('--temperature', type=float, default=0.7) 89 | base_args.add_argument('--time_limit', type=int, default=None) 90 | base_args.add_argument('--iteration_limit', type=int, default=50) 91 | base_args.add_argument('--roll_policy', 
type=str, choices=['random', 'greedy'], default='greedy') 92 | base_args.add_argument('--exploration_constant', type=float, default=0.5) 93 | base_args.add_argument('--roll_forward_steps', type=int, default=1) 94 | base_args.add_argument('--end_gate', type=float, default=0.8) # termination threshold 95 | base_args.add_argument('--branch', type=int, default=3) 96 | base_args.add_argument('--roll_branch', type=int, default=1) 97 | base_args.add_argument('--inf', type=float, default=0.9) 98 | base_args.add_argument('--evaluate', type=str, default='') # evaluation benchmark (empty means no evaluation) 99 | base_args.add_argument('--alpha', type=float, default=0.5) 100 | base_args.add_argument('--visualize', type=bool, default=False) # visualization 101 | base_args.add_argument('--use_case_prompt', type=bool, default=False) # use case prompts 102 | base_args.add_argument('--use_reflection', type=str, choices=['simple', 'common'], default='simple') # reflection mode 103 | base_args.add_argument('--low', type=float, default=0) 104 | base_args.add_argument('--high', type=float, default=1) 105 | base_args.add_argument('--algorithm', type=str, choices=['dfs', 'bfs'], default='dfs') 106 | base_args.add_argument('--select_branch', type=int, default=2) 107 | base_args.add_argument('--max_depth', type=int, default=8) 108 | base_args.add_argument('--select_method', type=str, choices=['greedy', 'sample'], default='greedy') 109 | base_args.add_argument('--generate_num', type=int, default=1) 110 | base_args.add_argument('--start_num', type=int, default=0) 111 | base_args.add_argument('--end_num', type=int, default=625) 112 | base_args.add_argument('--partial', type=bool, default=False) # whether to resume a previous partial run 113 | base_args.add_argument('--lang', type=str, choices=['zh', 'en'], default='en') 114 | base_args.add_argument('--do_self_critic', type=bool, default=True) # for CoT 115 | 116 | arguments = base_args.parse_args() 117 | return arguments 118 | 119 | 120 | if __name__ == '__main__': 121 | args = parse_args() 122 | run(args) 123 | -------------------------------------------------------------------------------- /self_train/generation/generate_both_samples_TheoremQA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from CoT.task import CoT_Task 4 | from ToT.task import ToT_Task 5 | from MCTS.task import MCTS_Task 6 | import argparse 7 | from utils.visualize import visualize 8 | from utils.json_operator import * 9 | 10 | 11 | def run(arguments): 12 | print('-'*30, '开始生成', '-'*30, '\n') 13 | base_dir = os.getcwd() 14 | file = f'{base_dir}/data/{arguments.task_name}/{arguments.file}.json' 15 | print(f'Reading data from {file}...\n') 16 | try: 17 | data_list = read_json(file) 18 | data_len = len(data_list) 19 | except Exception as e: 20 | print(f'File must be standardized json!\nError type:{e}\n') 21 | return 22 | assert data_len > 0, "Data list is empty!\n" 23 | assert 'content' in data_list[0].keys() and 'answer' in data_list[0].keys(), "Key error, Make sure json object contain correct keys!\n" 24 | 25 | output_list = [] 26 | process = 0 27 | if arguments.partial: 28 | base_dir = os.getcwd() 29 | output_file = f'{base_dir}/generation/{arguments.reward_type}/{arguments.task_name}/{arguments.file}/{arguments.mode}/{arguments.propose_method}_{arguments.value_method}_start{arguments.start_num}_end{arguments.end_num}.json' 30 | output_list = read_json(output_file) 31 | process = len(output_list) / arguments.generate_num 32 | 33 | for i in range(data_len): 34 | if i < arguments.start_num + process or i >= arguments.end_num: 35 |
continue 36 | # solve 37 | print(f'开始解答第{i+1}题...\n') 38 | data = data_list[i]['content'] 39 | answer = data_list[i]['answer'] 40 | outputs = [] 41 | 42 | if arguments.mode == 'mcts': 43 | Task = MCTS_Task(data, arguments.propose_method, arguments.value_method, arguments.branch, arguments.end_gate, 44 | arguments.roll_policy, arguments.roll_branch, arguments.roll_forward_steps, arguments.time_limit, 45 | arguments.iteration_limit, arguments.exploration_constant, arguments.alpha, arguments.inf, 46 | arguments.temperature, use_case_prompt=arguments.use_case_prompt, use_reflection=arguments.use_reflection, 47 | low=arguments.low, high=arguments.high, evaluate=arguments.evaluate, sample_value='full', answer=answer, verify_method='string', lang=arguments.lang) 48 | for cnt in range(arguments.generate_num): 49 | output, root = Task.run() 50 | corr_policy_sample_num = sum([sample['correct'] for sample in output['policy_samples']]) 51 | total_policy_sample_num = len(output['policy_samples']) 52 | output.update({'corr_policy_sample_num': corr_policy_sample_num, 'total_policy_sample_num': total_policy_sample_num}) 53 | outputs.append(output) 54 | 55 | elif arguments.mode == 'cot': 56 | Task = CoT_Task(data, arguments.propose_method, arguments.value_method, arguments.temperature, evaluate=arguments.evaluate, lang=arguments.lang, answer=answer, verify_method='string', do_self_critic=arguments.do_self_critic) 57 | for cnt in range(arguments.generate_num): 58 | output = Task.run() 59 | outputs.append(output) 60 | 61 | else: 62 | print("Unsupported sample generation mode!\n") 63 | return 64 | 65 | for output in outputs: 66 | output_list.append(output) 67 | 68 | print(f'第{i+1}题解答结束。\n') 69 | 70 | # output 71 | base_dir = os.getcwd() 72 | output_dir = pathlib.Path(f'{base_dir}/generation/{arguments.reward_type}/{arguments.task_name}/{arguments.file}/{Task.mode}') 73 | output_file = f'{base_dir}/generation/{arguments.reward_type}/{arguments.task_name}/{arguments.file}/{Task.mode}/{Task.propose_method}_{Task.value_method}_start{arguments.start_num}_end{arguments.end_num}.json' 74 | 75 | pathlib.Path.mkdir(output_dir, exist_ok=True, parents=True) 76 | dump_json(output_file, output_list) 77 | print('_' * 60) 78 | 79 | 80 | def parse_args(): 81 | base_args = argparse.ArgumentParser() 82 | base_args.add_argument('--task_name', type=str, default='theoremQA') 83 | base_args.add_argument('--file', type=str, default='theoremQA_self_train_1') # json 84 | base_args.add_argument('--reward_type', type=str, choices=['vm', 'prm'], default='vm') 85 | base_args.add_argument('--propose_method', type=str, choices=['gpt', 'glm', 'llama', 'mistral', 'local'], default='local') 86 | base_args.add_argument('--value_method', type=str, choices=['gpt', 'glm', 'local'], default='local') 87 | base_args.add_argument('--mode', type=str, choices=['cot', 'tot', 'mcts'], default='mcts') 88 | base_args.add_argument('--temperature', type=float, default=0.7) 89 | base_args.add_argument('--time_limit', type=int, default=None) 90 | base_args.add_argument('--iteration_limit', type=int, default=50) 91 | base_args.add_argument('--roll_policy', type=str, choices=['random', 'greedy'], default='greedy') 92 | base_args.add_argument('--exploration_constant', type=float, default=0.5) 93 | base_args.add_argument('--roll_forward_steps', type=int, default=1) 94 | base_args.add_argument('--end_gate', type=float, default=0.8) # termination threshold 95 | base_args.add_argument('--branch', type=int, default=3) 96 | base_args.add_argument('--roll_branch', type=int, default=1)
97 | base_args.add_argument('--inf', type=float, default=0.9) 98 | base_args.add_argument('--evaluate', type=str, default='') # evaluation benchmark (empty means no evaluation) 99 | base_args.add_argument('--alpha', type=float, default=0.5) 100 | base_args.add_argument('--visualize', type=bool, default=False) # visualization 101 | base_args.add_argument('--use_case_prompt', type=bool, default=False) # use case prompts 102 | base_args.add_argument('--use_reflection', type=str, choices=['simple', 'common'], default='simple') # reflection mode 103 | base_args.add_argument('--low', type=float, default=0) 104 | base_args.add_argument('--high', type=float, default=1) 105 | base_args.add_argument('--algorithm', type=str, choices=['dfs', 'bfs'], default='dfs') 106 | base_args.add_argument('--select_branch', type=int, default=2) 107 | base_args.add_argument('--max_depth', type=int, default=8) 108 | base_args.add_argument('--select_method', type=str, choices=['greedy', 'sample'], default='greedy') 109 | base_args.add_argument('--generate_num', type=int, default=1) 110 | base_args.add_argument('--start_num', type=int, default=0) 111 | base_args.add_argument('--end_num', type=int, default=70) 112 | base_args.add_argument('--partial', type=bool, default=False) # whether to resume a previous partial run 113 | base_args.add_argument('--lang', type=str, choices=['zh', 'en'], default='en') 114 | base_args.add_argument('--do_self_critic', type=bool, default=True) # for CoT 115 | 116 | arguments = base_args.parse_args() 117 | return arguments 118 | 119 | 120 | if __name__ == '__main__': 121 | args = parse_args() 122 | run(args) 123 | -------------------------------------------------------------------------------- /self_train/self_train_dpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import nullcontext 3 | from trl.trainer.utils import DPODataCollatorWithPadding 4 | from utils.json_operator import * 5 | from trl import DPOTrainer, TrlParser, ModelConfig 6 | from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer 7 | import torch 8 | from datasets import Dataset 9 | from trl.commands.cli_utils import DpoScriptArguments 10 | from trl.trainer import ppo_config 11 | # import wandb 12 | # wandb.login(key='0') 13 | 14 | # model 15 | model_dir = "/workspace/ckpt/MetaMath-Mistral-7B" 16 | model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16) 17 | model_ref = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16) 18 | tokenizer = AutoTokenizer.from_pretrained(model_dir) 19 | model_config = model.config 20 | 21 | # dataset can be downloaded from https://github.com/THUDM/ReST-MCTS#policy-data
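# --- Illustrative note (added commentary, not part of the original source) ---
# The JSON file read below is expected to hold a single dict of three
# equal-length lists keyed 'prompt', 'chosen' and 'rejected', as enforced by
# the assert a few lines down. A minimal, hypothetical sketch of that
# structure (the question and the two solutions are invented placeholders):
# [
#   {
#     "prompt":   ["Janet picks 3 apples and buys 2 more. How many does she have?"],
#     "chosen":   ["Step 1: 3 + 2 = 5.\nThe answer is 5"],
#     "rejected": ["Step 1: 3 - 2 = 1.\nThe answer is 1"]
#   }
# ]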
22 | d_path = "/your/path/to/ReST-MCTS_Llama3-8b-Instruct_Self-Rewarding-DPO_1st.json" 23 | data_dict = read_json(d_path)[0] 24 | d_len = len(data_dict['prompt']) 25 | assert d_len == len(data_dict['chosen']) and d_len == len(data_dict['rejected']) 26 | print("data_len:", d_len) 27 | if 'llama' in d_path: 28 | chat_format = '<|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' 29 | ans_format = '{solution}' 30 | elif 'mistral' in d_path: 31 | chat_format = '[INST]{query}[/INST]' 32 | ans_format = '{solution}' 33 | else: 34 | raise NotImplementedError 35 | 36 | 37 | def preprocess(row): 38 | processed_prompt = chat_format.format(query=row['prompt']) 39 | processed_chosen = ans_format.format(solution=row['chosen']) 40 | processed_rejected = ans_format.format(solution=row['rejected']) 41 | processed_example = { 42 | "prompt": processed_prompt, 43 | "chosen": processed_chosen, 44 | "rejected": processed_rejected, 45 | } 46 | return processed_example 47 | 48 | 49 | dataset = Dataset.from_dict(data_dict) 50 | dataset = dataset.map(preprocess, batched=False) 51 | dataset = dataset.train_test_split(test_size=0.05) 52 | train_dataset = dataset['train'] 53 | test_dataset = dataset['test'] 54 | 55 | 56 | if __name__ == "__main__": 57 | ################ 58 | # Training Args 59 | ################ 60 | args = TrainingArguments( 61 | output_dir="", 62 | overwrite_output_dir=True, 63 | do_train=True, 64 | do_eval=True, 65 | evaluation_strategy="epoch", 66 | save_strategy="epoch", 67 | save_total_limit=2, 68 | num_train_epochs=2, 69 | learning_rate=3e-6, 70 | per_device_train_batch_size=1, 71 | optim="adamw_torch", 72 | bf16_full_eval=True, 73 | bf16=True, 74 | gradient_accumulation_steps=2, 75 | per_gpu_eval_batch_size=1, 76 | remove_unused_columns=False, 77 | # deepspeed="config/deepspeed_zero3.json" 78 | ) 79 | ################ 80 | # Tokenizer 81 | ################ 82 | if tokenizer.pad_token is None: 83 | tokenizer.pad_token = tokenizer.eos_token 84 | if tokenizer.chat_template is None: 85 | tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}" 86 | 87 | ################ 88 | # Training 89 | ################ 90 | pad_id = 128001 if 'llama' in d_path else 32000 91 | collator = DPODataCollatorWithPadding(pad_token_id=pad_id, is_encoder_decoder=model_config.is_encoder_decoder) 92 | trainer = DPOTrainer( 93 | model, 94 | model_ref, 95 | args=args, 96 | data_collator=collator, 97 | dataset_num_proc=8, 98 | max_length=1024, 99 | max_prompt_length=256, 100 | max_target_length=1024, 101 | train_dataset=train_dataset, 102 | eval_dataset=test_dataset, 103 | tokenizer=tokenizer, 104 | truncation_mode='keep_end', 105 | beta=0.1, 106 | ) 107 | 108 | # print('train num: ', len(trainer.train_dataset)) 109 | trainer.train() 110 | trainer.save_model(args.output_dir) 111 | -------------------------------------------------------------------------------- /self_train/vm_critic/filter_policy_examples_by_value.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import pathlib 4 | import argparse 5 | from utils.json_operator import * 6 | 7 | 8 | def parse_args(): 9 | base_args = argparse.ArgumentParser() 10 | base_args.add_argument('--task_name', type=str, default='gsm_8k') 11 | base_args.add_argument('--file', type=str, default='gsm8k_all') # json 12 | base_args.add_argument('--propose_method', type=list[str], 
default=['mistral', 'llama', 'local']) 13 | base_args.add_argument('--value_threshold', type=float, default=0.2) # value low gate 14 | base_args.add_argument('--len_threshold', type=int, default=50) # str len 15 | 16 | arguments = base_args.parse_args() 17 | return arguments 18 | 19 | 20 | def do_filter_policy_examples_by_value(arguments): 21 | source_dir = f"extracted_samples/{arguments.task_name}/{arguments.file}/mcts/policy_samples" 22 | for file in os.listdir(source_dir): 23 | if file.split('_')[0] in arguments.propose_method and 'vm_critic' in file: 24 | source_file = os.path.join(source_dir, file) 25 | datas = read_json(source_file) 26 | selected_datas = [] 27 | for data in datas: 28 | if data['vm_critic'] >= arguments.value_threshold and len(data['summary']) >= arguments.len_threshold: 29 | selected_datas.append(data) 30 | backend = file.split('_')[0] 31 | out_file = f'{source_dir}/{backend}_local_vm_filtered.json' 32 | dump_json(out_file, selected_datas) 33 | 34 | 35 | if __name__ == '__main__': 36 | args = parse_args() 37 | do_filter_policy_examples_by_value(args) 38 | -------------------------------------------------------------------------------- /self_train/vm_critic/manual_self_critic.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | import pathlib 5 | import argparse 6 | from utils.json_operator import * 7 | from CoT.task import * 8 | from tqdm import tqdm 9 | 10 | 11 | def parse_args(): 12 | base_args = argparse.ArgumentParser() 13 | base_args.add_argument('--task_name', type=str, default='gsm_8k') 14 | base_args.add_argument('--file', type=str, default='gsm8k_self_train_1') # json 15 | base_args.add_argument('--propose_method', type=str, choices=['gpt', 'glm', 'llama', 'mistral', 'local'], 16 | default='local') 17 | 18 | arguments = base_args.parse_args() 19 | return arguments 20 | 21 | 22 | def do_self_critic(arguments): 23 | base_dir = os.getcwd() 24 | policy_output_file = f'{base_dir}/extracted_samples/{arguments.task_name}/{arguments.file}/cot/policy_samples/{arguments.propose_method}_local.json' 25 | datas = read_json(policy_output_file) 26 | out_file = f'{base_dir}/extracted_samples/{arguments.task_name}/{arguments.file}/cot/policy_samples/{arguments.propose_method}_local_critic.json' 27 | new_datas = [] 28 | for data in tqdm(datas): 29 | question = data['content'] 30 | solution = data['summary'] 31 | task = CoT_Task(question, propose_method=arguments.propose_method, value_method='local', do_self_critic=True) 32 | 33 | score = None 34 | cnt = 3 35 | while score is None and cnt: 36 | score = task.get_self_critic(solution) 37 | cnt -= 1 38 | if score is None: 39 | score = 0 40 | data.update({'self_critic': score}) 41 | new_datas.append(data) 42 | dump_json(out_file, new_datas) 43 | 44 | 45 | if __name__ == '__main__': 46 | args = parse_args() 47 | do_self_critic(args) 48 | -------------------------------------------------------------------------------- /self_train/vm_critic/manual_vm_critic.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['CUDA_VISIBLE_DEVICES'] = '6' 3 | import json 4 | import re 5 | import pathlib 6 | import argparse 7 | from utils.json_operator import * 8 | from MCTS.task import * 9 | from tqdm import tqdm 10 | 11 | 12 | def parse_args(): 13 | base_args = argparse.ArgumentParser() 14 | base_args.add_argument('--task_name', type=str, default='gsm_8k') 15 | base_args.add_argument('--file', type=str, 
default='gsm8k_all') # json 16 | base_args.add_argument('--propose_method', type=str, choices=['gpt', 'glm', 'llama', 'mistral', 'local'], 17 | default='mistral') 18 | base_args.add_argument('--start_num', type=int, default=225) # Starting sequence number (not absolute sequence number) 19 | base_args.add_argument('--end_num', type=int, default=450) 20 | base_args.add_argument('--generate_num', type=int, default=256) 21 | base_args.add_argument('--do_aggregate', type=bool, default=False) # aggregate results 22 | 23 | arguments = base_args.parse_args() 24 | return arguments 25 | 26 | 27 | def do_vm_critic(arguments): 28 | base_dir = os.getcwd() 29 | policy_output_file = f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/cot/{arguments.propose_method}_local_all.json' 30 | datas = read_json(policy_output_file) 31 | assert len(datas) % arguments.generate_num == 0, 'length not match!\n' 32 | out_file = f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/cot/{arguments.propose_method}_local_vm_critic_{arguments.start_num}_{arguments.end_num}.json' 33 | done_datas = read_json(out_file) 34 | done_idx = len(done_datas) 35 | new_datas = done_datas 36 | idx = 0 37 | for data in tqdm(datas): 38 | if idx < arguments.start_num * arguments.generate_num + done_idx or idx >= arguments.end_num * arguments.generate_num: 39 | idx += 1 40 | continue 41 | question = data['content'] 42 | solution = data['solution'] + '\n' + data['summary'] 43 | task = MCTS_Task(question, propose_method=arguments.propose_method, value_method='local', lang='en') 44 | 45 | score = None 46 | cnt = 3 47 | while score is None and cnt: 48 | score = task.get_step_value(solution) 49 | cnt -= 1 50 | if score is None: 51 | score = 0 52 | data.update({'vm_critic': score}) 53 | new_datas.append(data) 54 | dump_json(out_file, new_datas) 55 | idx += 1 56 | 57 | 58 | def aggregate_vm_critic(arguments): 59 | base_dir = os.getcwd() 60 | out_dir = f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/cot' 61 | pattern = f'{arguments.propose_method}_local_vm_critic' 62 | all_outputs = [] 63 | for file in os.listdir(out_dir): 64 | if pattern in file and 'all' not in file: 65 | all_outputs.extend(read_json(f'{out_dir}/{file}')) 66 | dump_json(f'{out_dir}/{pattern}_all.json', all_outputs) 67 | 68 | 69 | if __name__ == '__main__': 70 | args = parse_args() 71 | if args.do_aggregate: 72 | aggregate_vm_critic(args) 73 | else: 74 | do_vm_critic(args) 75 | -------------------------------------------------------------------------------- /self_train/vm_critic/vm_critic_for_extracted_samples.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import re 4 | import pathlib 5 | import argparse 6 | from utils.json_operator import * 7 | from MCTS.task import * 8 | from tqdm import tqdm 9 | 10 | 11 | def parse_args(): 12 | base_args = argparse.ArgumentParser() 13 | base_args.add_argument('--task_name', type=str, default='gsm_8k') 14 | base_args.add_argument('--file', type=str, default='gsm8k_all') # json 15 | base_args.add_argument('--propose_method', type=str, choices=['llama', 'mistral', 'local'], 16 | default='mistral') 17 | 18 | arguments = base_args.parse_args() 19 | return arguments 20 | 21 | 22 | def do_vm_critic_mc(arguments): 23 | base_dir = os.getcwd() 24 | policy_output_file = f'{base_dir}/extracted_samples/{arguments.task_name}/{arguments.file}/mcts/policy_samples/{arguments.propose_method}_local.json' 25 | datas = read_json(policy_output_file) 26 | 
out_file = f'{base_dir}/extracted_samples/{arguments.task_name}/{arguments.file}/mcts/policy_samples/{arguments.propose_method}_local_vm_critic.json' 27 | done_datas = read_json(out_file) 28 | done_idx = len(done_datas) 29 | new_datas = done_datas 30 | idx = 0 31 | for data in tqdm(datas): 32 | if idx < done_idx: 33 | idx += 1 34 | continue 35 | question = data['content'] 36 | solution = data['summary'] 37 | task = MCTS_Task(question, propose_method=arguments.propose_method, value_method='local', lang='en') 38 | 39 | score = None 40 | cnt = 3 41 | while score is None and cnt: 42 | score = task.get_step_value(solution) 43 | cnt -= 1 44 | if score is None: 45 | score = 0 46 | data.update({'vm_critic': score}) 47 | new_datas.append(data) 48 | dump_json(out_file, new_datas) 49 | idx += 1 50 | 51 | 52 | if __name__ == '__main__': 53 | args = parse_args() 54 | do_vm_critic_mc(args) 55 | -------------------------------------------------------------------------------- /tasks/science.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | from tasks.prompts import * 4 | 5 | 6 | # data: question: str 7 | # mode: 'cot', 'tot', 'mcts' 8 | # method: 'glm', 'gpt', 'local' 9 | class SearchTask(object): 10 | def __init__(self, data, propose_method='glm', value_method='glm'): 11 | super().__init__() 12 | self.question = data 13 | self.propose_method = propose_method 14 | self.value_method = value_method 15 | self.value_cache = {} 16 | 17 | def clear_cache(self): 18 | self.value_cache = {} 19 | 20 | @staticmethod 21 | def summary_prompt_wrap(x: str, y: str = '') -> str: 22 | print('\n', '==============================', 'summary', '==============================', '\n') 23 | print('summary_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤的综述为:\n') 24 | prompt = summary_prompt + x + '\n已有步骤:\n' + y + '\n输出:' 25 | return prompt 26 | 27 | @staticmethod 28 | def MATH_summary_prompt_wrap(x: str, y: str = '') -> str: 29 | print('\n', '==============================', 'summary', '==============================', '\n') 30 | print('summary_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤的综述为:\n') 31 | prompt = MATH_summary_prompt + x + '\nSolution: ' + y + '\nExtracted answer:' 32 | return prompt 33 | 34 | @staticmethod 35 | def evaluate_summary_prompt_wrap(x: str, y: str = '') -> str: 36 | print('\n', '==============================', 'summary', '==============================', '\n') 37 | print('summary_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤的综述为:\n') 38 | prompt = evaluate_summary_prompt + x + '\n已有步骤:\n' + y + '\n输出:' 39 | return prompt 40 | 41 | @staticmethod 42 | def general_evaluate_summary_prompt_wrap(x: str, y: str = '') -> str: 43 | print('\n', '==============================', 'summary', '==============================', '\n') 44 | print('summary_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤的综述为:\n') 45 | prompt = general_evaluate_summary_prompt + x + '\n已有步骤:\n' + y + '\n输出:' 46 | return prompt 47 | 48 | @staticmethod 49 | def single_propose_prompt_wrap(x: str, y: str = '', step: int = 0) -> str: 50 | print('\n', '==============================', 'proposal', '==============================', '\nstep: ', step) 51 | print('propose_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤,可能的当前步骤解法是:\n') 52 | prompt = single_proposal_prompt + x + '\n已有步骤:\n' + y + '\n输出:' 53 | return prompt 54 | 55 | @staticmethod 56 | def zero_single_propose_wrap(x: str, y: str = '', step: int = 0, lang: str = 'zh') -> str: 57 | print('\n', '==============================', 'proposal', 
'==============================', '\nstep: ', step) 58 | print('propose_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤,可能的当前步骤解法是:\n') 59 | if lang == 'zh': 60 | if not y: 61 | y = '无\n' 62 | prompt = zero_single_proposal_prompt + x + '\n已有步骤:\n' + y + '\n输出:' 63 | else: 64 | if not y: 65 | y = 'None\n' 66 | prompt = zero_single_proposal_prompt_en + x + '\nExisting Steps:\n' + y + '\nOutput:' 67 | return prompt 68 | 69 | @staticmethod 70 | def zero_single_propose_wrap_mistral(x: str, y: str = '', step: int = 0) -> str: 71 | print('\n', '==============================', 'proposal', '==============================', '\nstep: ', step) 72 | print('propose_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤,可能的当前步骤解法是:\n') 73 | if not y: 74 | y = 'None\n' 75 | prompt = zero_single_proposal_prompt_mistral + x + '\nExisting Steps:\n' + y + '\nOutput:' 76 | return prompt 77 | 78 | @staticmethod 79 | def zero_single_propose_wrap_gpt(x: str, y: str = '', step: int = 0, lang: str = 'zh') -> str: 80 | print('\n', '==============================', 'proposal', '==============================', '\nstep: ', step) 81 | print('propose_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤,可能的当前步骤解法是:\n') 82 | if lang == 'zh': 83 | if not y: 84 | y = '无\n' 85 | prompt = zero_single_proposal_prompt_gpt + x + '\n已有步骤:\n' + y + '\n输出:' 86 | else: 87 | if not y: 88 | y = 'None\n' 89 | prompt = zero_single_proposal_prompt_gpt_en + x + '\nExisting Steps:\n' + y + '\nOutput:' 90 | return prompt 91 | 92 | @staticmethod 93 | def zero_single_propose_wrap_use_reflection(x: str, y: str = '', step: int = 0, ref: str = '', lang: str = 'zh') -> str: 94 | print('\n', '==============================', 'proposal', '==============================', '\nstep: ', step) 95 | print('propose_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤,可能的当前步骤解法是:\n') 96 | if lang == 'zh': 97 | if not y: 98 | y = '无\n' 99 | if not ref: 100 | ref = '无\n' 101 | prompt = zero_single_proposal_prompt_use_reflection + x + '\n已有步骤:\n' + y + '\n意见:' + ref + '\n输出:' 102 | else: 103 | if not y: 104 | y = 'None\n' 105 | if not ref: 106 | ref = 'None\n' 107 | prompt = zero_single_proposal_prompt_use_reflection_en + x + '\nExisting Steps:\n' + y + '\nAnalysis: ' + ref + '\nOutput:' 108 | return prompt 109 | 110 | @staticmethod 111 | def zero_single_propose_wrap_use_reflection_gpt(x: str, y: str = '', step: int = 0, ref: str = '', lang: str = 'zh') -> str: 112 | print('\n', '==============================', 'proposal', '==============================', '\nstep: ', step) 113 | print('propose_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤,可能的当前步骤解法是:\n') 114 | if lang == 'zh': 115 | if not y: 116 | y = '无\n' 117 | if not ref: 118 | ref = '无\n' 119 | prompt = zero_single_proposal_prompt_use_reflection_gpt + x + '\n已有步骤:\n' + y + '\n意见:' + ref + '\n' 120 | else: 121 | if not y: 122 | y = 'None\n' 123 | if not ref: 124 | ref = 'None\n' 125 | prompt = zero_single_proposal_prompt_use_reflection_gpt_en + x + '\nExisting Steps:\n' + y + '\nAnalysis: ' + ref + '\nOutput:' 126 | return prompt 127 | 128 | @staticmethod 129 | def single_reflection_wrap(x: str, y: str = '', step: int = 0, lang: str = 'zh') -> str: 130 | print('\n', '==============================', 'reflection', '==============================', '\nstep: ', step) 131 | print('propose_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤给出的意见:\n') 132 | if lang == 'zh': 133 | if not y: 134 | y = '无\n' 135 | prompt = single_reflection_prompt + x + '\n已有步骤:\n' + y + '\n输出:' # glm style 136 | else: 137 | if not y: 138 | y = 'None\n' 139 | prompt = 
single_reflection_prompt_en + x + '\nExisting Steps:\n' + y + '\nOutput:' 140 | return prompt 141 | 142 | @staticmethod 143 | def single_reflection_wrap_gpt(x: str, y: str = '', step: int = 0) -> str: 144 | print('\n', '==============================', 'reflection', '==============================', '\nstep: ', step) 145 | print('propose_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤给出的意见:\n') 146 | if not y: 147 | y = '无\n' 148 | prompt = single_reflection_prompt_gpt + x + '\n已有步骤:\n' + y # gpt style 149 | return prompt 150 | 151 | @staticmethod 152 | def single_reflection_wrap_llama(x: str, y: str = '', step: int = 0) -> str: 153 | print('\n', '==============================', 'reflection', '==============================', '\nstep: ', step) 154 | print('propose_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤给出的意见:\n') 155 | if not y: 156 | y = '无\n' 157 | prompt = single_reflection_prompt_llama + x + '\n已有步骤:\n' + y + '\n空\n请你给出意见,不要解答问题,你给出的意见应该完全基于给定的步骤。' # llama style 158 | return prompt 159 | 160 | @staticmethod 161 | def single_reflection_wrap_simple(x: str, y: str = '', step: int = 0, lang: str = 'zh') -> str: 162 | print('\n', '==============================', 'reflection', '==============================', '\nstep: ', step) 163 | print('propose_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤给出的意见:\n') 164 | if lang == 'zh': 165 | if not y: 166 | y = '无\n' 167 | prompt = single_reflection_prompt_simple + x + '\n已有步骤:\n' + y + '\n输出:' # simple style 168 | else: 169 | if not y: 170 | y = 'None\n' 171 | prompt = single_reflection_prompt_simple_en + x + '\nExisting Steps:\n' + y + '\nOutput:' 172 | return prompt 173 | 174 | @staticmethod 175 | def single_reflection_wrap_simple_mistral(x: str, y: str = '', step: int = 0) -> str: 176 | print('\n', '==============================', 'reflection', '==============================', '\nstep: ', step) 177 | print('propose_prompt: \n', x + '\n已有步骤:\n' + y + '基于以上步骤给出的意见:\n') 178 | if not y: 179 | y = 'None\n' 180 | prompt = single_reflection_prompt_simple_mistral + x + '\nExisting Steps:\n' + y + '\nOutput:' 181 | return prompt 182 | 183 | @staticmethod 184 | def value_prompt_wrap(x: str, y: str) -> str: 185 | print('\n', '==============================', 'critic', '==============================', '\n') 186 | value_prompt = critic_simplified + x + '\n已有步骤:\n' + y.strip() + '\n输出:' 187 | return value_prompt 188 | 189 | @staticmethod 190 | def self_critic_prompt_wrap(x: str, y: str) -> str: 191 | print('\n', '==============================', 'self-critic', '==============================', '\n') 192 | if not y: 193 | y = 'None\n' 194 | critic_prompt = self_critic_prompt + x + '\nSolution:\n' + y + '\nScore:' 195 | return critic_prompt 196 | 197 | @staticmethod 198 | def cot_prompt_wrap(x: str, lang: str = 'zh', use_math: bool = False) -> str: 199 | print('\n', '==============================', 'proposal', '==============================', '\n') 200 | if not use_math: 201 | if lang == 'zh': 202 | prompt = cot_prompt + x + "\n解答过程:" 203 | else: 204 | prompt = cot_prompt_en + x + "\nSolution:" 205 | else: 206 | prompt = MATH_cot_prompt.format(query=x) 207 | print('propose_prompt: \n', prompt, '\n') 208 | return prompt 209 | 210 | @staticmethod 211 | def value_outputs_unwrap(value_outputs: list, low=0.0, high=1.0) -> float: 212 | out_value = low 213 | all_out = '' 214 | for _ in value_outputs: 215 | all_out = all_out + _ 216 | if '分数' not in all_out: 217 | print('分数输出不合法!\n') 218 | return out_value 219 | stp = all_out.split('分数')[-1].strip() 220 | try: 221 | 
match = re.findall(r'-?[0-9]+\.?[0-9]*', stp)[-1] 222 | out_value = float(match) 223 | out_value = min(max(low, out_value), high) 224 | except Exception as e: 225 | print(f'分数输出有误!错误类型:{e}\n') 226 | return low 227 | return out_value 228 | -------------------------------------------------------------------------------- /utils/aggregate_value_samples.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pathlib 4 | import json 5 | import pandas as pd 6 | from utils.json_operator import * 7 | 8 | 9 | def aggregate(arguments): 10 | base_dir = os.getcwd() 11 | source_dir = f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/{arguments.mode}' 12 | if arguments.best_k > 1: 13 | out_file = f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/{arguments.mode}/{arguments.propose_method}_{arguments.value_method}_best@{arguments.best_k}_all.json' 14 | else: 15 | out_file = f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/{arguments.mode}/{arguments.propose_method}_{arguments.value_method}_all.json' 16 | 17 | cur_outputs = read_json(out_file) 18 | all_outputs = cur_outputs 19 | if cur_outputs: 20 | if arguments.best_k > 1: 21 | df = pd.read_csv( 22 | f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/{arguments.mode}/{arguments.propose_method}_{arguments.value_method}_best@{arguments.best_k}_all_process.csv') 23 | process_dict = {col: df[col].iloc[0] for col in df.columns} 24 | assert sum([value for value in process_dict.values()]) == len( 25 | cur_outputs), 'process_dict length not match cur_outputs length!\n' 26 | else: 27 | df = pd.read_csv( 28 | f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/{arguments.mode}/{arguments.propose_method}_{arguments.value_method}_all_process.csv') 29 | process_dict = {col: df[col].iloc[0] for col in df.columns} 30 | assert sum([value for value in process_dict.values()]) == len( 31 | cur_outputs), 'process_dict length not match cur_outputs length!\n' 32 | for file in os.listdir(source_dir): 33 | if arguments.best_k > 1: 34 | if f'{arguments.propose_method}_{arguments.value_method}' in file and f'best@{arguments.best_k}' in file and 'all' not in file: 35 | this_output = read_json(f'{source_dir}/{file}') 36 | this_new_output = this_output[process_dict[file]:] 37 | all_outputs.extend(this_new_output) 38 | process_dict[file] = len(this_output) 39 | else: 40 | if f'{arguments.propose_method}_{arguments.value_method}' in file and 'best@' not in file and 'all' not in file: 41 | this_output = read_json(f'{source_dir}/{file}') 42 | this_new_output = this_output[process_dict[file]:] 43 | all_outputs.extend(this_new_output) 44 | process_dict[file] = len(this_output) 45 | new_df = pd.DataFrame({file: [process_dict[file]] for file in process_dict.keys()}) 46 | if arguments.best_k > 1: 47 | new_df.to_csv( 48 | f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/{arguments.mode}/{arguments.propose_method}_{arguments.value_method}_best@{arguments.best_k}_all_process.csv', 49 | index=False) 50 | else: 51 | new_df.to_csv( 52 | f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/{arguments.mode}/{arguments.propose_method}_{arguments.value_method}_all_process.csv', 53 | index=False) 54 | else: 55 | process_dict = {} 56 | for file in os.listdir(source_dir): 57 | if arguments.best_k > 1: 58 | if f'{arguments.propose_method}_{arguments.value_method}' in file and f'best@{arguments.best_k}' in file and 'all' not in 
file: 59 | this_output = read_json(f'{source_dir}/{file}') 60 | all_outputs.extend(this_output) 61 | process_dict.update({file: [len(this_output)]}) 62 | else: 63 | if f'{arguments.propose_method}_{arguments.value_method}' in file and 'best@' not in file and 'all' not in file: 64 | this_output = read_json(f'{source_dir}/{file}') 65 | all_outputs.extend(this_output) 66 | process_dict.update({file: [len(this_output)]}) 67 | new_df = pd.DataFrame(process_dict) 68 | if arguments.best_k > 1: 69 | new_df.to_csv( 70 | f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/{arguments.mode}/{arguments.propose_method}_{arguments.value_method}_best@{arguments.best_k}_all_process.csv', 71 | index=False) 72 | else: 73 | new_df.to_csv( 74 | f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/{arguments.mode}/{arguments.propose_method}_{arguments.value_method}_all_process.csv', 75 | index=False) 76 | 77 | dump_json(out_file, all_outputs) 78 | 79 | 80 | def parse_args(): 81 | base_args = argparse.ArgumentParser() 82 | base_args.add_argument('--task_name', type=str, default='gsm_8k') 83 | base_args.add_argument('--file', type=str, default='gsm8k_self_train_1') # json 84 | base_args.add_argument('--propose_method', type=str, choices=['gpt', 'glm', 'llama', 'mistral', 'local'], default='local') 85 | base_args.add_argument('--value_method', type=str, choices=['gpt', 'glm', 'local'], default='local') 86 | base_args.add_argument('--mode', type=str, choices=['cot', 'tot', 'mcts'], default='mcts') 87 | base_args.add_argument('--algorithm', type=str, choices=['dfs', 'bfs'], default='dfs') 88 | base_args.add_argument('--generate_num', type=int, default=1) 89 | base_args.add_argument('--best_k', type=int, default=1) # best@k 90 | 91 | arguments = base_args.parse_args() 92 | return arguments 93 | 94 | 95 | if __name__ == '__main__': 96 | args = parse_args() 97 | aggregate(args) 98 | -------------------------------------------------------------------------------- /utils/answer_extractor.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | choices = ['A', 'B', 'C', 'D', 'a', 'b', 'c', 'd'] 4 | 5 | 6 | def extract(answer, q_type): 7 | if '\\boxed{' in answer: 8 | trunc_ans = answer.split('\\boxed{')[-1] 9 | extracted_ans = trunc_ans.split('}')[0].strip().replace(' ', '').replace(',', '') 10 | flag = 1 11 | if q_type == 'MCQ': 12 | if len(extracted_ans) == 1 and extracted_ans in choices: 13 | flag = 1 14 | extracted_ans = extracted_ans.upper() 15 | else: 16 | flag = 0 17 | elif q_type == 'MCQ(multiple)': 18 | for let in extracted_ans: 19 | if let not in choices: 20 | flag = 0 21 | break 22 | extracted_ans = extracted_ans.upper() 23 | else: 24 | try: 25 | float_ans = float(extracted_ans) 26 | except Exception as e: 27 | flag = 0 28 | if flag == 1: 29 | return extracted_ans 30 | else: 31 | return 'None' 32 | else: 33 | answer = answer.strip().upper().replace(' ', '').replace(',', '').replace('AND', '').replace(':', '') 34 | print(f'Processed strings:{answer}\n') 35 | match1 = re.findall(r'[\[,\{,\(][A-D]+[\],\},\)]', answer) 36 | match2 = re.findall(r'[\[,\{,\(]-?[0-9]+\.?[0-9]*[\],\},\)]', answer) 37 | match3 = re.findall(r'ANSWERIS-?[0-9]+\.?[0-9]*', answer) 38 | match4 = re.findall(r'ANSWERIS[A-D]{1,4}', answer) 39 | match5 = re.findall(r'ANSWER-?[0-9]+\.?[0-9]*', answer) 40 | match6 = re.findall(r'ANSWER[A-D]{1,4}', answer) 41 | match7 = re.findall( 42 | 
r'[\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)]', 43 | answer) 44 | match8 = re.findall(r'[\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)]', answer) 45 | match9 = re.findall(r'[\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)]', answer) 46 | match10 = re.findall(r'ANSWERIS[\[,\{,\(]-?[0-9]+\.?[0-9]*[\],\},\)]', answer) 47 | match11 = re.findall(r'ANSWER[\[,\{,\(]-?[0-9]+\.?[0-9]*[\],\},\)]', answer) 48 | match12 = re.findall(r'ANSWERIS[\[,\{,\(][A-D]{1,4}[\],\},\)]', answer) 49 | match13 = re.findall(r'ANSWER[\[,\{,\(][A-D]{1,4}[\],\},\)]', answer) 50 | match14 = re.findall( 51 | r'ANSWERIS[\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)]', 52 | answer) 53 | match15 = re.findall(r'ANSWERIS[\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)]', 54 | answer) 55 | match16 = re.findall(r'ANSWERIS[\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)]', answer) 56 | match17 = re.findall( 57 | r'ANSWER[\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)]', 58 | answer) 59 | match18 = re.findall(r'ANSWER[\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)]', 60 | answer) 61 | match19 = re.findall(r'ANSWER[\[,\{,\(][A-D][\],\},\)][\[,\{,\(][A-D][\],\},\)]', answer) 62 | 63 | if match14: 64 | print('Answer matching type 14\n') 65 | ans = match14[-1] 66 | ans = ans[8:] 67 | final_ans = '' 68 | for let in ans: 69 | if let in choices: 70 | final_ans = final_ans + let 71 | if 'MCQ' in q_type: 72 | return final_ans 73 | 74 | if match15: 75 | print('Answer matching type 15\n') 76 | ans = match15[-1] 77 | ans = ans[8:] 78 | final_ans = '' 79 | for let in ans: 80 | if let in choices: 81 | final_ans = final_ans + let 82 | if 'MCQ' in q_type: 83 | return final_ans 84 | 85 | if match16: 86 | print('Answer matching type 16\n') 87 | ans = match16[-1] 88 | ans = ans[8:] 89 | final_ans = '' 90 | for let in ans: 91 | if let in choices: 92 | final_ans = final_ans + let 93 | if 'MCQ' in q_type: 94 | return final_ans 95 | 96 | if match12: 97 | print('Answer matching type 12\n') 98 | ans = match12[-1] 99 | ans = ans[8:] 100 | final_ans = '' 101 | for let in ans: 102 | if let in choices: 103 | final_ans = final_ans + let 104 | if 'MCQ' in q_type: 105 | return final_ans 106 | 107 | if match17: 108 | print('Answer matching type 17\n') 109 | ans = match17[-1] 110 | ans = ans[6:] 111 | final_ans = '' 112 | for let in ans: 113 | if let in choices: 114 | final_ans = final_ans + let 115 | if 'MCQ' in q_type: 116 | return final_ans 117 | 118 | if match18: 119 | print('Answer matching type 18\n') 120 | ans = match18[-1] 121 | ans = ans[6:] 122 | final_ans = '' 123 | for let in ans: 124 | if let in choices: 125 | final_ans = final_ans + let 126 | if 'MCQ' in q_type: 127 | return final_ans 128 | 129 | if match19: 130 | print('Answer matching type 19\n') 131 | ans = match19[-1] 132 | ans = ans[6:] 133 | final_ans = '' 134 | for let in ans: 135 | if let in choices: 136 | final_ans = final_ans + let 137 | if 'MCQ' in q_type: 138 | return final_ans 139 | 140 | if match13: 141 | print('Answer matching type 13\n') 142 | ans = match13[-1] 143 | ans = ans[6:] 144 | final_ans = '' 145 | for let in ans: 146 | if let in choices: 147 | final_ans = final_ans + let 148 | if 'MCQ' in q_type: 149 | return final_ans 150 | 151 | if match10: 152 | print('Answer matching type 10\n') 153 | ans = match10[-1] 154 | ans = ans[9:] 155 | 
ans = ans[:-1] 156 | if 'MCQ' not in q_type: 157 | try: 158 | float_ans = float(ans) 159 | return ans 160 | except Exception as e: 161 | print('Matching error!\n') 162 | 163 | if match11: 164 | print('Answer matching type 11\n') 165 | ans = match11[-1] 166 | ans = ans[7:] 167 | ans = ans[:-1] 168 | if 'MCQ' not in q_type: 169 | try: 170 | float_ans = float(ans) 171 | return ans 172 | except Exception as e: 173 | print('Matching error!\n') 174 | 175 | if match3: 176 | print('Answer matching type 3\n') 177 | ans = match3[-1] 178 | ans = ans[8:] 179 | if 'MCQ' not in q_type: 180 | try: 181 | float_ans = float(ans) 182 | return ans 183 | except Exception as e: 184 | print('Matching error!\n') 185 | 186 | if match4: 187 | print('Answer matching type 4\n') 188 | ans = match4[-1] 189 | ans = ans[8:] 190 | if 'MCQ' in q_type: 191 | return ans 192 | 193 | if match5: 194 | print('Answer matching type 5\n') 195 | ans = match5[-1] 196 | ans = ans[6:] 197 | if 'MCQ' not in q_type: 198 | try: 199 | float_ans = float(ans) 200 | return ans 201 | except Exception as e: 202 | print('Matching error!\n') 203 | 204 | if match6: 205 | print('Answer matching type 6\n') 206 | ans = match6[-1] 207 | ans = ans[6:] 208 | if 'MCQ' in q_type: 209 | return ans 210 | 211 | if match7: 212 | print('Answer matching type 7\n') 213 | ans = match7[-1] 214 | final_ans = '' 215 | for let in ans: 216 | if let in choices: 217 | final_ans = final_ans + let 218 | if 'MCQ' in q_type: 219 | return final_ans 220 | 221 | if match8: 222 | print('Answer matching type 8\n') 223 | ans = match8[-1] 224 | final_ans = '' 225 | for let in ans: 226 | if let in choices: 227 | final_ans = final_ans + let 228 | if 'MCQ' in q_type: 229 | return final_ans 230 | 231 | if match9: 232 | print('Answer matching type 9\n') 233 | ans = match9[-1] 234 | final_ans = '' 235 | for let in ans: 236 | if let in choices: 237 | final_ans = final_ans + let 238 | if 'MCQ' in q_type: 239 | return final_ans 240 | 241 | if match1: 242 | print('Answer matching type 1\n') 243 | ans = match1[-1] 244 | ans = ans[1:] 245 | ans = ans[:-1] 246 | if 'MCQ' in q_type: 247 | return ans 248 | 249 | if match2: 250 | print('Answer matching type 2\n') 251 | ans = match2[-1] 252 | ans = ans[1:] 253 | ans = ans[:-1] 254 | if 'MCQ' not in q_type: 255 | try: 256 | float_ans = float(ans) 257 | return ans 258 | except Exception as e: 259 | print('Matching error!\n') 260 | print('answer invalid!\n') 261 | return 'None' 262 | -------------------------------------------------------------------------------- /utils/extract_both_samples.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | import pathlib 5 | import argparse 6 | from utils.json_operator import * 7 | 8 | 9 | def extract_both_samples(arguments): 10 | base_dir = os.getcwd() 11 | out_file = f'{base_dir}/generation/vm/{arguments.task_name}/{arguments.file}/{arguments.mode}/{arguments.propose_method}_{arguments.value_method}_all.json' 12 | all_outputs = read_json(out_file) 13 | 14 | if arguments.mode == 'mcts': 15 | policy_output_dir = pathlib.Path(f'{base_dir}/extracted_samples/{arguments.task_name}/{arguments.file}/{arguments.mode}/policy_samples') 16 | pathlib.Path.mkdir(policy_output_dir, exist_ok=True, parents=True) 17 | policy_output_file = f'{base_dir}/extracted_samples/{arguments.task_name}/{arguments.file}/{arguments.mode}/policy_samples/{arguments.propose_method}_{arguments.value_method}.json' 18 | new_policy_outputs = [] 19 | for output in 
all_outputs: 20 | cur_outputs = [] 21 | content = output['content'] 22 | real_answer = output['real_answer'] 23 | for cur_output in output['policy_samples']: 24 | new_sample = {'content': content, 'summary': cur_output['solution'] + cur_output['summary'] if 'final answer is' not in cur_output['solution'] else cur_output['solution'], 'label': cur_output['correct'], 'real_answer': real_answer} 25 | cur_outputs.append(new_sample) 26 | new_policy_outputs.extend(cur_outputs) 27 | dump_json(policy_output_file, new_policy_outputs) 28 | 29 | value_output_dir = pathlib.Path(f'{base_dir}/extracted_samples/{arguments.task_name}/{arguments.file}/{arguments.mode}/value_samples') 30 | pathlib.Path.mkdir(value_output_dir, exist_ok=True, parents=True) 31 | value_output_file = f'{base_dir}/extracted_samples/{arguments.task_name}/{arguments.file}/{arguments.mode}/value_samples/{arguments.propose_method}_{arguments.value_method}.json' 32 | new_value_outputs = [] 33 | for output in all_outputs: 34 | cur_outputs = [] 35 | content = output['content'] 36 | for cur_output in output['value_samples']: 37 | new_sample = {'prompt_answer': 'Problem:' + content + '\nSolution:\n' + cur_output['steps'], 'label': cur_output['value']} 38 | cur_outputs.append(new_sample) 39 | new_value_outputs.extend(cur_outputs) 40 | dump_json(value_output_file, new_value_outputs) 41 | 42 | elif arguments.mode == 'cot': 43 | policy_output_dir = pathlib.Path(f'{base_dir}/extracted_samples/{arguments.task_name}/{arguments.file}/{arguments.mode}/policy_samples') 44 | pathlib.Path.mkdir(policy_output_dir, exist_ok=True, parents=True) 45 | policy_output_file = f'{base_dir}/extracted_samples/{arguments.task_name}/{arguments.file}/{arguments.mode}/policy_samples/{arguments.propose_method}_{arguments.value_method}.json' 46 | new_policy_outputs = [] 47 | for output in all_outputs: 48 | content = output['content'] 49 | real_answer = output['real_answer'] 50 | new_sample = {'content': content, 'summary': output['solution'] + output['summary'] if 'final answer is' not in output['solution'] else output['solution'], 'label': output['accurate'], 'real_answer': real_answer} 51 | if arguments.do_self_critic: 52 | new_sample.update({'self_critic': output['self_critic']}) 53 | new_policy_outputs.append(new_sample) 54 | dump_json(policy_output_file, new_policy_outputs) 55 | 56 | else: 57 | print("Unsupported sample extraction mode!\n") 58 | return 59 | 60 | 61 | def parse_args(): 62 | base_args = argparse.ArgumentParser() 63 | base_args.add_argument('--task_name', type=str, default='gsm_8k') 64 | base_args.add_argument('--file', type=str, default='gsm8k_self_train_1') # json 65 | base_args.add_argument('--propose_method', type=str, choices=['gpt', 'glm', 'llama', 'mistral', 'local'], default='local') 66 | base_args.add_argument('--value_method', type=str, choices=['gpt', 'glm', 'local'], default='local') 67 | base_args.add_argument('--mode', type=str, choices=['cot', 'mcts'], default='mcts') 68 | base_args.add_argument('--generate_num', type=int, default=1) 69 | base_args.add_argument('--do_self_critic', type=bool, default=False) # for CoT 70 | 71 | arguments = base_args.parse_args() 72 | return arguments 73 | 74 | 75 | if __name__ == '__main__': 76 | args = parse_args() 77 | extract_both_samples(args) 78 | -------------------------------------------------------------------------------- /utils/format_dpo.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | round_n = 2 4 | base = 'mistral' 5 | 6 | 
dpo_dataset_dict = dict() 7 | dpo_dataset_dict["prompt"] = list() 8 | dpo_dataset_dict["chosen"] = list() 9 | dpo_dataset_dict["rejected"] = list() 10 | 11 | cnt = 0 12 | with open(f'extracted_samples/self_train_{round_n}/cot/dpo-{base}_local.json', 13 | 'rt') as f: 14 | for line in f.readlines(): 15 | cont = json.loads(line) 16 | cnt += 1 17 | # print(cont, type(cont)) 18 | dpo_dataset_dict["prompt"].append(cont['prompt']) 19 | dpo_dataset_dict["chosen"].append(cont['response_chosen']) 20 | dpo_dataset_dict["rejected"].append(cont['response_rejected']) 21 | 22 | print(cnt) 23 | print(len(dpo_dataset_dict["prompt"])) 24 | print(len(dpo_dataset_dict["chosen"])) 25 | print(len(dpo_dataset_dict["rejected"])) 26 | out_file = f"extracted_samples/self_train_{round_n}/cot/{base}_local_dpo.json" 27 | 28 | with open(out_file, "w") as f: 29 | json.dump(dpo_dataset_dict, f) 30 | print("Loading the file is complete...") 31 | 32 | with open(out_file, 'r') as load_f: 33 | load_dict = json.load(load_f) 34 | print(len(load_dict['prompt'])) 35 | print(len(load_dict['chosen'])) 36 | print(len(load_dict['rejected'])) 37 | -------------------------------------------------------------------------------- /utils/json_operator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path 3 | 4 | 5 | def read_json(source): 6 | json_list = [] 7 | if not os.path.exists(source): 8 | return json_list 9 | with open(source, 'r', encoding='utf-8') as f: 10 | for line in f: 11 | json_list.append(json.loads(line)) 12 | return json_list 13 | 14 | 15 | def dump_json(source, datas): 16 | with open(source, 'w', encoding='utf-8') as f: 17 | for item in datas: 18 | json.dump(item, f, ensure_ascii=False) 19 | f.write('\n') 20 | -------------------------------------------------------------------------------- /utils/orm_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 4 | 5 | import json 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from transformers import AutoModel, AutoTokenizer 10 | from torch.utils.data import DataLoader, Dataset 11 | import pandas as pd 12 | 13 | model_dir = '/workspace/ckpt/chatglm3-6b-base' 14 | tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) 15 | base_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).half().cuda() 16 | max_length = 1024 17 | key_token1 = 'True' 18 | key_token1 = tokenizer.encode(key_token1)[-1] 19 | key_token2 = 'False' 20 | key_token2 = tokenizer.encode(key_token2)[-1] 21 | 22 | 23 | class ChatGLM_Filter(nn.Module): 24 | def __init__(self, base): 25 | super(ChatGLM_Filter, self).__init__() 26 | self.base_model = base 27 | 28 | def forward(self, input_ids, attention_mask): 29 | outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask).logits[:, -1] 30 | return outputs 31 | 32 | 33 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 34 | filter_model = ChatGLM_Filter(base_model) 35 | filter_model.load_state_dict(torch.load("/workspace/ckpt/rm_best_checkpoint_3.pt")) 36 | filter_model.to(device) 37 | filter_model.eval() 38 | 39 | 40 | def get_orm_score(question, answer): 41 | with torch.no_grad(): 42 | encoded_pair = tokenizer.encode_plus( 43 | question + '[answer]' + answer, 44 | padding='max_length', 45 | max_length=max_length, # Set the max length 46 | truncation=True, 47 | return_tensors='pt', # Return PyTorch Tensor 
format 48 | ) 49 | input_ids = encoded_pair['input_ids'].cuda() 50 | attention_mask = encoded_pair['attention_mask'].cuda() 51 | outputs = filter_model(input_ids, attention_mask) 52 | outputs = torch.softmax(outputs, dim=1) 53 | 54 | outputs_1 = outputs[0, key_token1].item() 55 | outputs_2 = outputs[0, key_token2].item() 56 | score = outputs_1 - outputs_2 57 | return score 58 | 59 | 60 | def get_orm_scores(outputs): 61 | scores = [] 62 | for output in outputs: 63 | question = output['content'] 64 | answer = output['solution'] 65 | score = get_orm_score(question, answer) 66 | print(f"Get an ORM score :{score}\n") 67 | scores.append(score) 68 | return scores 69 | 70 | 71 | def get_best_solution_orm(outputs): 72 | scores = get_orm_scores(outputs) 73 | best_idx = np.argmax(scores) 74 | best_output = outputs[best_idx] 75 | return best_output 76 | -------------------------------------------------------------------------------- /utils/result_evaluator.py: -------------------------------------------------------------------------------- 1 | from utils.json_operator import * 2 | from CoT.task import * 3 | 4 | 5 | def evaluate_result(source_dir: str, result_file_pattern: str): 6 | result_files = [] 7 | for file in os.listdir(source_dir): 8 | if result_file_pattern in file: 9 | result_files.append(file) 10 | 11 | all_results = [] 12 | for result_file in result_files: 13 | result_file_path = os.path.join(source_dir, result_file) 14 | result = read_json(result_file_path) 15 | all_results.extend(result) 16 | 17 | total_num = len(all_results) 18 | simulate_corr_count = 0 19 | for single_result in all_results: 20 | corr_num = single_result['correct_num'] 21 | sample_num = single_result['sample_num'] 22 | acc = corr_num / sample_num 23 | simulate_corr_count += acc 24 | 25 | simulate_acc = simulate_corr_count / total_num 26 | print(f'simulate_acc: {simulate_acc}, corr_expectation: {simulate_corr_count}, total_num: {total_num}') 27 | 28 | 29 | def reEvaluate_result(source_dir: str, result_file_pattern: str, multisample=False): 30 | result_files = [] 31 | for file in os.listdir(source_dir): 32 | if result_file_pattern in file: 33 | result_files.append(file) 34 | 35 | all_results = [] 36 | for result_file in result_files: 37 | result_file_path = os.path.join(source_dir, result_file) 38 | result = read_json(result_file_path) 39 | all_results.extend(result) 40 | 41 | total_num = len(all_results) 42 | if multisample: 43 | simulate_corr_count = 0 44 | for results in all_results: 45 | single_corr_num = 0 46 | answer = results['real_answer'] 47 | question = results['content'] 48 | for single_result in results['samples']: 49 | if single_result['accurate']: 50 | single_corr_num += 1 51 | else: 52 | if exact_match_score(single_result['summary'], answer): 53 | single_corr_num += 1 54 | single_result['accurate'] = True 55 | else: 56 | solution = single_result['solution'].strip() 57 | Task = CoT_Task(question, propose_method='local', value_method='local', evaluate='math', 58 | lang='en', answer=answer) 59 | cnt = 10 60 | summary = '' 61 | while not solution and cnt: 62 | out = Task.run() 63 | solution = out['solution'] 64 | summary = out['summary'] 65 | cnt -= 1 66 | single_result['solution'] = solution 67 | 68 | if '####' in solution: 69 | summary = 'The final answer is ' + solution.split('####')[-1].strip() 70 | elif 'The final answer is' in solution: 71 | summary = 'The final answer is ' + solution.split('The final answer is')[-1].strip() 72 | elif 'The answer is' in solution: 73 | summary = 'The final answer is ' + 
solution.split('The answer is')[-1].strip() 74 | else: 75 | cnt = 10 76 | while cnt and not summary: 77 | summary = Task.get_MATH_summary(solution) 78 | cnt -= 1 79 | 80 | result = exact_match_score(summary, answer) 81 | if result: 82 | single_corr_num += 1 83 | single_result['accurate'] = True 84 | single_result['summary'] = summary 85 | 86 | sample_num = results['sample_num'] 87 | results['correct_num'] = single_corr_num 88 | acc = single_corr_num / sample_num 89 | simulate_corr_count += acc 90 | 91 | simulate_acc = simulate_corr_count / total_num 92 | print(f'simulate_acc: {simulate_acc}, corr_expectation: {simulate_corr_count}, total_num: {total_num}') 93 | return all_results 94 | 95 | else: 96 | simulate_corr_count = 0 97 | for results in all_results: 98 | answer = results['real_answer'] 99 | question = results['content'] 100 | solution = results['solution'] 101 | if results['accurate']: 102 | simulate_corr_count += 1 103 | else: 104 | if exact_match_score(results['summary'], answer): 105 | simulate_corr_count += 1 106 | results['accurate'] = True 107 | else: 108 | Task = CoT_Task(question, propose_method='local', value_method='local', evaluate='math', 109 | lang='en', answer=answer) 110 | cnt = 10 111 | summary = '' 112 | while not solution and cnt: 113 | out = Task.run() 114 | solution = out['solution'] 115 | summary = out['summary'] 116 | cnt -= 1 117 | results['solution'] = solution 118 | 119 | if '####' in solution: 120 | summary = 'The final answer is ' + solution.split('####')[-1].strip() 121 | elif 'The final answer is' in solution: 122 | summary = 'The final answer is ' + solution.split('The final answer is')[-1].strip() 123 | elif 'The answer is' in solution: 124 | summary = 'The final answer is ' + solution.split('The answer is')[-1].strip() 125 | else: 126 | cnt = 10 127 | while cnt and not summary: 128 | summary = Task.get_MATH_summary(solution) 129 | cnt -= 1 130 | 131 | result = exact_match_score(summary, answer) 132 | if result: 133 | simulate_corr_count += 1 134 | results['accurate'] = True 135 | results['summary'] = summary 136 | 137 | simulate_acc = simulate_corr_count / total_num 138 | print(f'simulate_acc: {simulate_acc}, corr_expectation: {simulate_corr_count}, total_num: {total_num}') 139 | return all_results 140 | -------------------------------------------------------------------------------- /utils/self_consistency.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def get_consistency_output_scibench(outputs): 5 | output_count = {} 6 | for output in outputs: 7 | summ = output['summary'].strip() 8 | try: 9 | match = re.findall(r'[^^{.\-0123456789]-?[0-9]+\.?[0-9]*[^^}.0123456789]', summ)[-1][1:][:-1] 10 | model_ans = float(match) 11 | 12 | except Exception as e: 13 | try: 14 | match = re.findall(r'-?[0-9]+\.?[0-9]*', summ)[-1] 15 | model_ans = float(match) 16 | except Exception as e: 17 | print(f'Extract the answer error! 
Error type:{e}\n') 18 | continue 19 | 20 | if model_ans not in output_count.keys(): 21 | output_count.update({model_ans: [1, output]}) 22 | else: 23 | output_count[model_ans][0] += 1 24 | 25 | if not output_count: 26 | return outputs[0] 27 | 28 | most_cons_count = 0 29 | most_cons_output = {} 30 | for ans, info in output_count.items(): 31 | if info[0] > most_cons_count: 32 | most_cons_count = info[0] 33 | most_cons_output = info[1] 34 | return most_cons_output 35 | 36 | 37 | def get_consistency_output_scieval(outputs, q_type): 38 | output_count = {} 39 | for output in outputs: 40 | summ = output['summary'].strip() 41 | if q_type == "multiple-choice": 42 | try: 43 | model_ans = re.findall(r'[A-E]', summ)[0] 44 | except Exception as e: 45 | print(f"Extract the answer error! Error type:{e}\n") 46 | continue 47 | elif q_type == "judge": 48 | model_ans = summ 49 | elif q_type == "filling": 50 | model_ans = summ 51 | else: 52 | break 53 | 54 | if model_ans not in output_count.keys(): 55 | output_count.update({model_ans: [1, output]}) 56 | else: 57 | output_count[model_ans][0] += 1 58 | 59 | if not output_count: 60 | return outputs[0] 61 | 62 | most_cons_count = 0 63 | most_cons_output = {} 64 | for ans, info in output_count.items(): 65 | if info[0] > most_cons_count: 66 | most_cons_count = info[0] 67 | most_cons_output = info[1] 68 | return most_cons_output 69 | -------------------------------------------------------------------------------- /utils/solution_summary_extractor.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def extract_summary_from_solution(solution: str): 5 | pattern = r"\\boxed\{(.*)\}" 6 | match = re.findall(pattern, solution) 7 | if match: 8 | summary = 'The final answer is ' + match[-1] 9 | elif '####' in solution: 10 | extracted = solution.split('####')[-1].strip() 11 | if len(extracted) > 1: 12 | if extracted[-1] == '.': 13 | extracted = extracted[:-1].strip() 14 | if len(extracted) > 1: 15 | if extracted[0] == ':': 16 | extracted = extracted[1:].strip() 17 | summary = 'The final answer is ' + extracted 18 | elif 'The final answer is' in solution: 19 | extracted = solution.split('The final answer is')[-1].strip() 20 | if len(extracted) > 1: 21 | if extracted[-1] == '.': 22 | extracted = extracted[:-1].strip() 23 | if len(extracted) > 1: 24 | if extracted[0] == ':': 25 | extracted = extracted[1:].strip() 26 | summary = 'The final answer is ' + extracted 27 | elif 'The answer is' in solution: 28 | extracted = solution.split('The answer is')[-1].strip() 29 | if len(extracted) > 1: 30 | if extracted[-1] == '.': 31 | extracted = extracted[:-1].strip() 32 | if len(extracted) > 1: 33 | if extracted[0] == ':': 34 | extracted = extracted[1:].strip() 35 | summary = 'The final answer is ' + extracted 36 | elif 'final answer is' in solution: 37 | extracted = solution.split('final answer is')[-1].strip() 38 | if len(extracted) > 1: 39 | if extracted[-1] == '.': 40 | extracted = extracted[:-1].strip() 41 | if len(extracted) > 1: 42 | if extracted[0] == ':': 43 | extracted = extracted[1:].strip() 44 | summary = 'The final answer is ' + extracted 45 | elif 'answer is' in solution: 46 | extracted = solution.split('answer is')[-1].strip() 47 | if len(extracted) > 1: 48 | if extracted[-1] == '.': 49 | extracted = extracted[:-1].strip() 50 | if len(extracted) > 1: 51 | if extracted[0] == ':': 52 | extracted = extracted[1:].strip() 53 | summary = 'The final answer is ' + extracted 54 | else: 55 | summary = '' 56 | print('Extracted 
summary: ', summary, '\n') 57 | return summary 58 | -------------------------------------------------------------------------------- /utils/verify_answer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import re 3 | 4 | 5 | # only support float answer verification 6 | def verify_float(answer: float, output: str): 7 | if not output: 8 | print(f'The output is empty and cannot match the answer!\n') 9 | return False 10 | 11 | if '综上所述,' in output: 12 | spl_ans = output.split('综上所述,')[-1] 13 | spl_ans = spl_ans.strip() 14 | else: 15 | spl_ans = output.strip() 16 | 17 | try: 18 | match = re.findall(r'[^^{.\-0123456789]-?[0-9]+\.?[0-9]*[^^}.0123456789]', spl_ans)[-1][1:][:-1] 19 | model_ans = float(match) 20 | 21 | # standard (adjustable) 22 | if abs(answer) >= 1: 23 | result = math.isclose(model_ans, answer, abs_tol=0.1) 24 | else: 25 | result = math.isclose(model_ans, answer, rel_tol=0.1) 26 | 27 | print(f'The ans of model is:{model_ans}, while the ground truth is {answer}.\n') 28 | return result 29 | 30 | except Exception as e: 31 | try: 32 | match = re.findall(r'-?[0-9]+\.?[0-9]*', spl_ans)[-1] 33 | model_ans = float(match) 34 | 35 | # standard (adjustable) 36 | if abs(answer) >= 1: 37 | result = math.isclose(model_ans, answer, abs_tol=0.1) 38 | else: 39 | result = math.isclose(model_ans, answer, rel_tol=0.1) 40 | 41 | print(f'The ans of model is:{model_ans}, while the ground truth is {answer}.\n') 42 | return result 43 | except Exception as e: 44 | print(f'Result not matched, error type:{e}\n') 45 | print(f'The ans of model is:{spl_ans}, while the ground truth is {answer}.\n') 46 | return False 47 | 48 | 49 | # only support choice answer verification 50 | def verify_choice(answer: str, output: str): 51 | if not output: 52 | print(f'The output is empty and cannot match the answer!\n') 53 | return False 54 | 55 | check_list = ['A', 'B', 'C', 'D', 'E'] 56 | 57 | if '综上所述,最终答案是:' in output: 58 | spl_ans = output.split('综上所述,最终答案是:')[-1] 59 | spl_ans = spl_ans.strip() 60 | elif '综上所述,' in output: 61 | spl_ans = output.split('综上所述,')[-1] 62 | spl_ans = spl_ans.strip() 63 | else: 64 | spl_ans = output.strip() 65 | 66 | # standard (adjustable) 67 | for choice in check_list: 68 | if choice in answer and choice not in spl_ans: 69 | print(f'The ans of model is:{spl_ans}, while the ground truth is {answer}.\n') 70 | return False 71 | if choice not in answer and choice in spl_ans: 72 | print(f'The ans of model is:{spl_ans}, while the ground truth is {answer}.\n') 73 | return False 74 | 75 | print(f'The ans of model is:{spl_ans}, while the ground truth is {answer}.\n') 76 | return True 77 | 78 | 79 | # for scieval 80 | def verify_scieval(answer, output, q_type): 81 | print(f'The ans of model is:"{output}", while the ground truth is {answer}.\n') 82 | if q_type == "multiple-choice": 83 | try: 84 | match = re.findall(r'[A-E]', output)[0] 85 | except Exception as e: 86 | print(f"Result not matched, error type:{e}\n") 87 | return False 88 | if answer.lower() == match.lower(): 89 | return True 90 | elif q_type == "judge": 91 | if answer.lower() in output.lower(): 92 | return True 93 | elif q_type == "filling": 94 | if answer.lower() in output.lower(): 95 | return True 96 | else: 97 | print('Type error!\n') 98 | return False 99 | return False 100 | -------------------------------------------------------------------------------- /utils/verify_llm.py: -------------------------------------------------------------------------------- 1 | from 
models.get_response import * 2 | 3 | 4 | def llm_verify(ans, real_ans, judge_model='gpt-4-1106-preview'): 5 | prompt = '下面将输入两段文字,第一段文字为某道理科题目的一个解答或答案(不一定正确),第二段是这道题目的标准答案。请判断第一段解答得到的答案与标准答案在数学意义上是否一致,并根据判断直接输出‘0’或’1‘,不需要输出任何别的信息。如果答案一致,请输出‘1’;否则,只要答案不匹配,或者第一个文段中没有明确指出答案也没有输出latex表达式,请输出‘0’;如果第一段解答与标准答案之间关系模糊,请输出‘0’。\n' 6 | qry = prompt + '文段1:' + ans + '\n' + '文段2:' + real_ans + '\n输出:' 7 | lbl = '' 8 | cnt = 5 9 | while lbl == '' and cnt: 10 | out = '' 11 | try: 12 | chat_comp = openai.ChatCompletion.create(model=judge_model, messages=[{"role": "user", "content": qry}]) 13 | out = chat_comp.choices[0].message.content[0] 14 | except Exception as e: 15 | print(f'Error:{e}\n') 16 | if out == '0' or out == '1': 17 | lbl = out 18 | else: 19 | cnt -= 1 20 | if not cnt: 21 | return 0 22 | return int(lbl) 23 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | from graphviz import Digraph 2 | 3 | colors = ['Yellow', 'Gold', 'Orange', 'Orangered', 'Red', 'Crimson', 'Darkred'] 4 | 5 | 6 | def split_str(strs): 7 | sent_str = strs.split("。") 8 | all_strs = '' 9 | for sent in sent_str: 10 | piece_str = sent.split(",") 11 | for piece in piece_str: 12 | all_strs = all_strs + piece + '\n' 13 | return all_strs 14 | 15 | 16 | def visualize(root, task, task_name, file_name, file_suffix): 17 | fname = f'graphs/{task_name}/{file_name}/{task.mode}/{task.propose_method}_{task.value_method}/{file_suffix}' 18 | g = Digraph("G", filename=fname, format='png', strict=False) 19 | str1 = 'Question: ' + split_str(task.question) + '\nAccess sequence: ' + str(root.visit_sequence) + '\nValue: ' + str( 20 | root.V) + '\nflag: ' + str(root.final_ans_flag) 21 | g.node(str(root.visit_sequence), str1, color=colors[root.visit_sequence % len(colors)]) 22 | sub_plot(g, root, task) 23 | g.node_attr['shape'] = 'tab' 24 | g.node_attr['fontname'] = 'Microsoft YaHei' 25 | g.graph_attr['size'] = '960,640' 26 | g.render(view=False) 27 | 28 | 29 | def sub_plot(graph, root, task): 30 | if task.mode == 'mcts': 31 | for child in root.children.values(): 32 | trans_str = split_str(child.pcd) 33 | str2 = trans_str + '\nAccess sequence: ' + str(child.visit_sequence) + '\nValue: ' + str(child.V) + '\nflag: ' + str(child.final_ans_flag) 34 | graph.node(str(child.visit_sequence), str2, color=colors[child.visit_sequence % len(colors)]) 35 | graph.edge(str(root.visit_sequence), str(child.visit_sequence), str(child.visit_sequence - 1)) 36 | sub_plot(graph, child, task) 37 | else: 38 | for child in root.children: 39 | trans_str = split_str(child.pcd) 40 | str2 = trans_str + '\nAccess sequence: ' + str(child.visit_sequence) + '\nValue: ' + str(child.V) + '\nflag: ' + str( 41 | child.final_ans_flag) 42 | graph.node(str(child.visit_sequence), str2, color=colors[child.visit_sequence % len(colors)]) 43 | graph.edge(str(root.visit_sequence), str(child.visit_sequence), str(child.visit_sequence - 1)) 44 | sub_plot(graph, child, task) 45 | -------------------------------------------------------------------------------- /utils/weighted_consistency.py: -------------------------------------------------------------------------------- 1 | from utils.orm_score import get_orm_scores 2 | import re 3 | 4 | 5 | def get_weighted_consistency_output_scibench(outputs): 6 | scores = get_orm_scores(outputs) 7 | output_count = {} 8 | for i in range(len(outputs)): 9 | output = outputs[i] 10 | score = scores[i] 11 | summ = 
output['summary'].strip() 12 | try: 13 | match = re.findall(r'[^^{.\-0123456789]-?[0-9]+\.?[0-9]*[^^}.0123456789]', summ)[-1][1:][:-1] 14 | model_ans = float(match) 15 | 16 | except Exception as e: 17 | try: 18 | match = re.findall(r'-?[0-9]+\.?[0-9]*', summ)[-1] 19 | model_ans = float(match) 20 | except Exception as e: 21 | print(f'Extract the answer error! Error type:{e}\n') 22 | continue 23 | 24 | if model_ans not in output_count.keys(): 25 | output_count.update({model_ans: [score, output]}) 26 | else: 27 | output_count[model_ans][0] += score 28 | 29 | if not output_count: 30 | return outputs[0] 31 | 32 | most_cons_score = 0 33 | most_cons_output = {} 34 | for ans, info in output_count.items(): 35 | if info[0] > most_cons_score: 36 | most_cons_score = info[0] 37 | most_cons_output = info[1] 38 | return most_cons_output 39 | --------------------------------------------------------------------------------
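A minimal usage sketch (not part of the repository) showing how several of the utilities above can be chained for answer checking: read_json and dump_json from utils/json_operator.py handle the line-delimited JSON records, extract_summary_from_solution pulls a "The final answer is ..." sentence out of a raw solution, and verify_float compares the extracted number with a numeric ground truth. The file paths below are hypothetical placeholders; the record fields ('solution', 'real_answer', 'accurate') mirror those used in utils/result_evaluator.py.

from utils.json_operator import read_json, dump_json
from utils.solution_summary_extractor import extract_summary_from_solution
from utils.verify_answer import verify_float

# Hypothetical input file: one JSON object per line with a 'solution' string and a numeric 'real_answer'.
records = read_json('generation/vm/scibench/example/mcts/local_local_all.json')
for rec in records:
    # Pull the final-answer sentence out of the raw solution text (may be '' if nothing matches).
    summary = extract_summary_from_solution(rec.get('solution', ''))
    # verify_float returns False for an empty summary, otherwise compares with a 10% / 0.1 tolerance.
    rec['accurate'] = verify_float(float(rec['real_answer']), summary)
dump_json('generation/vm/scibench/example/mcts/local_local_checked.json', records)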