├── LICENSE
├── README.md
├── assets
│   └── abs.png
└── src
    ├── ToT
    │   ├── base.py
    │   ├── bfs.py
    │   ├── dfs.py
    │   └── task.py
    ├── configs
    │   ├── dpo.yaml
    │   └── sft.yaml
    ├── infer.py
    ├── infer.sh
    ├── judge.py
    ├── judge.sh
    ├── models
    │   └── get_response.py
    ├── process_data.py
    ├── tasks
    │   ├── prompts.py
    │   └── science.py
    └── tree_search.py

/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SPaR 2 | ## Self-Play with Tree-Search Refinement to Improve Instruction-Following in Large Language Models 3 | 4 |

5 | 🤗 <a href="https://huggingface.co/datasets/CCCCCC/SPaR">Data</a> • 📃 <a href="https://arxiv.org/abs/2412.11605">Paper</a>
6 | 

7 | 
8 | SPaR focuses on creating interference-free preference pairs for effective self-improvement. The figure below shows an example of the interfering factors (*story content*) that arise in multiple independently sampled responses (Left); refined response pairs exclude these factors, highlight the key difference (*ending sentence*), and lead to improved performance for the iteratively trained LLaMA3-8B-Instruct (Right).
9 | 
10 | 
11 | <img src="assets/abs.png" alt="SPaR">
12 | 
13 | 14 |
15 | 
16 | ## Table of Contents
17 | - [Data](#data)
18 | - [Quick Start](#quick-start)
19 |     - [Data Construction](#data-construction)
20 |     - [Model Training](#model-training)
21 | - [Citation](#citation)
22 | 
23 | ## Data
24 | 
25 | ### SPaR dataset
26 | The SPaR dataset can be found on [Hugging Face](https://huggingface.co/datasets/CCCCCC/SPaR).
27 | 
28 | We provide a high-quality SFT dataset for instruction-following tasks and the data for iterative self-training.
29 | 
30 | 
31 | ## Quick Start
32 | Throughout the code, we have added `# TODO` comments to indicate the places that need modification before running. Please update the relevant parts as noted before executing each file.
33 | 
34 | ### Data Construction
35 | To construct the iterative training data yourself, run the following commands:
36 | ```bash
37 | cd src
38 | 
39 | bash infer.sh
40 | 
41 | python process_data.py
42 | 
43 | bash judge.sh
44 | 
45 | python process_data.py
46 | 
47 | vllm serve <model_path> --port 8000  # TODO: serve the refine/critique model used by tree_search.py
48 | 
49 | python tree_search.py
50 | 
51 | python process_data.py
52 | ```
53 | 
54 | ### Model Training
55 | If you want to train your own model,
56 | please run the following command:
57 | ```bash
58 | cd src
59 | 
60 | # dpo
61 | llamafactory-cli train configs/dpo.yaml
62 | 
63 | # sft
64 | llamafactory-cli train configs/sft.yaml
65 | 
66 | ```
67 | 
68 | 
69 | ## Acknowledgement
70 | - Training code: [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)
71 | - Tree-search implementation: [Rest-MCTS*](https://github.com/THUDM/ReST-MCTS)
72 | 
73 | ## Citation
74 | ```
75 | @misc{cheng2024sparselfplaytreesearchrefinement,
76 |       title={SPaR: Self-Play with Tree-Search Refinement to Improve Instruction-Following in Large Language Models},
77 |       author={Jiale Cheng and Xiao Liu and Cunxiang Wang and Xiaotao Gu and Yida Lu and Dan Zhang and Yuxiao Dong and Jie Tang and Hongning Wang and Minlie Huang},
78 |       year={2024},
79 |       eprint={2412.11605},
80 |       archivePrefix={arXiv},
81 |       primaryClass={cs.CL},
82 |       url={https://arxiv.org/abs/2412.11605},
83 | }
84 | ```
--------------------------------------------------------------------------------
/assets/abs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/SPaR/4c1010addd771c580dc98288f3d63960364c5a5e/assets/abs.png
--------------------------------------------------------------------------------
/src/ToT/base.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | from tasks.prompts import *
4 | 
5 | 
6 | class Node(object):
7 |     def __init__(self, prompt: str, response: str, critique: str, value=0, parent=None, depth=0, refine_cot=""):
8 |         # self.pcd = pcd  # current step
9 |         self.children = []
10 |         self.V = value
11 |         self.parent = parent
12 |         # self.y = ''  # all steps so far
13 |         self.depth = depth
14 |         self.visit_sequence = 0
15 |         self.final_ans_flag = 0
16 |         self.prompt = prompt
17 |         self.response = response
18 |         self.critique = critique
19 |         self.refine_cot = refine_cot
20 | 
21 | 
22 |     def append_children(self, new_res: str, refine_cot: str):
23 |         node = Node(self.prompt, new_res, '', 0, self, self.depth + 1, refine_cot)
24 |         # node.update_y_from_parent()
25 |         self.children.append(node)
26 |         return self, node
27 | 
28 |     def update_y_from_parent(self):
29 |         if self.parent is None:
30 |             self.y = self.pcd
31 |         else:
32 |             self.y = self.parent.y + self.pcd
33 | 
34 |     def update_value(self, value):
35 |         self.V = value
36 | 
37 |     def update_critique(self, critique):
38 |         self.critique = critique
39 | 
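    # Note: V holds the critique pass rate assigned via update_value (computed
    # in get_step_value, ToT/task.py); the traversal helpers below walk the
    # subtree to pick out the highest-value refinement.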
40 |     def getRefinement(self):
41 |         if not self.children:
42 |             return self, self.V
43 |         max_V = self.V
44 |         max_node = self
45 |         for child in self.children:
46 |             subNode, subValue = child.getBestV()
47 |             if subValue > max_V:
48 |                 max_V = subValue
49 |                 max_node = subNode
50 |         return max_node, max_V
51 | 
52 |     def getCritiqueRFT(self, end_gate=0.5):
53 |         if not self.children:
54 |             return [], []
55 |         negative = []
56 |         positive = []
57 |         for child in self.children:
58 |             negative_critique, positive_critique = child.getCritiqueRFT()
59 |             if child.V > end_gate and child.V < 1.0:
60 |                 positive_critique.append({'prompt': judge_template.format(child.prompt, child.response), 'response': child.critique})
61 |             elif child.V < end_gate and child.V > 0.0:
62 |                 negative_critique.append({'prompt': judge_template.format(child.prompt, child.response), 'response': child.critique})
63 | 
64 |             negative.extend(negative_critique)
65 |             positive.extend(positive_critique)
66 | 
67 |         return negative, positive
68 | 
69 | 
70 |     def getBestV(self):  # get the highest-value node in the subtree
71 |         if not self.children:
72 |             return self, self.V
73 |         max_V = self.V
74 |         max_node = self
75 |         for child in self.children:
76 |             subNode, subValue = child.getBestV()
77 |             if subValue >= max_V:
78 |                 max_V = subValue
79 |                 max_node = subNode
80 |         return max_node, max_V
81 | 
82 |     def get_multiply_value(self):
83 |         if self.depth == 0:
84 |             return 0
85 |         multi_value = self.V
86 |         cur_node = self.parent
87 |         while cur_node.depth > 0:
88 |             multi_value = multi_value * cur_node.V
89 |             cur_node = cur_node.parent
90 |         return multi_value
91 | 
92 | 
93 | class SolutionStep(object):
94 |     def __init__(self, x, stp, all_steps, score, step_num):
95 |         self.x = x
96 |         self.stp = stp
97 |         self.all_steps = all_steps
98 |         self.score = score
99 |         self.step_num = step_num
100 | 
101 | 
102 | def rand_select(data_list: list, probs: list):  # sample an index with probability proportional to probs
103 |     assert len(data_list) == len(probs), "lengths do not match!"
104 |     probs_norm = []
105 |     sum_prob = sum(probs)
106 |     for i in probs:
107 |         probs_norm.append(i / sum_prob)
108 |     intervals = []
109 |     count = 0
110 |     for i in probs_norm:
111 |         count = count + i
112 |         intervals.append(count)
113 |     # assert count == 1, "probs error!"
114 |     intervals[len(intervals) - 1] = 1
115 |     index = 0
116 |     rand_prob = random.random()
117 |     while rand_prob >= intervals[index]:
118 |         index = index + 1
119 |     return index, data_list[index]
--------------------------------------------------------------------------------
/src/ToT/bfs.py:
--------------------------------------------------------------------------------
1 | from ToT.base import Node, rand_select
2 | 
3 | 
4 | def BFS(tot_task):
5 |     root = Node(tot_task.prompt, tot_task.response, tot_task.critique, tot_task.root_v)
6 |     cur_nodes = [root]
7 |     for depth in range(tot_task.max_depth):
8 |         candidates = []
9 |         for node in cur_nodes:
10 | 
11 |             for i in range(tot_task.branch):
12 |                 # print("depth: ", node.depth, "branch: ", i)
13 |                 refine_res = ''
14 |                 cnt = 3
15 |                 while not refine_res and cnt:
16 |                     # get refinement
17 |                     cnt -= 1
18 |                     refine_res, refine_cot = tot_task.get_next_step(node.prompt, node.response, node.critique)
19 |                     # print(refine_res)
20 | 
21 |                 if not refine_res:
22 |                     continue
23 | 
24 |                 node, child = node.append_children(refine_res, refine_cot)
25 | 
26 |                 # get judgement, value is the pass rate
27 |                 critique, value = tot_task.get_step_value(child.prompt, child.response)
28 | 
29 |                 if not critique:
30 |                     node.children.pop()
31 |                     del child
32 |                     continue
33 | 
34 |                 child.update_value(value)
35 |                 child.update_critique(critique)
36 | 
37 |                 child.visit_sequence = tot_task.node_count
38 |                 tot_task.update_count()
39 |                 candidates.append(child)
40 | 
41 |                 tot_task.budget -= 1
42 | 
43 |                 if tot_task.budget == 0:
44 |                     break
45 | 
46 |             if tot_task.budget == 0:
47 |                 break
48 | 
49 |         if not candidates:
50 |             break
51 | 
52 | 
53 |         ranked_candidates = sorted(candidates, key=lambda item: item.V, reverse=True)
54 |         if ranked_candidates[0].V > tot_task.end_gate:
55 |             print('Final solution found!\n')
56 |             ranked_candidates[0].final_ans_flag = 1
57 |             return ranked_candidates[0].response, root, ranked_candidates[0]
58 | 
59 |         if tot_task.budget == 0:
60 |             break
61 | 
62 |         if tot_task.select_method == 'greedy':
63 |             cur_nodes = ranked_candidates[:min(tot_task.select_branch, tot_task.branch, len(ranked_candidates))]
64 | 
65 |         else:
66 |             idx_list = []
67 |             cur_nodes = []
68 |             for j in range(min(tot_task.select_branch, tot_task.branch)):
69 |                 idx, node = rand_select(ranked_candidates, [item.V for item in ranked_candidates])
70 |                 if idx not in idx_list:
71 |                     idx_list.append(idx)
72 |                     cur_nodes.append(node)
73 |             cur_nodes = sorted(cur_nodes, key=lambda item: item.V, reverse=True)
74 | 
75 |     print('No solution met the required value; returning the highest-value solution instead.\n')
76 |     max_node, max_V = root.getBestV()
77 |     max_node.final_ans_flag = 1
78 |     return max_node.response, root, max_node
--------------------------------------------------------------------------------
/src/ToT/dfs.py:
--------------------------------------------------------------------------------
1 | from ToT.base import Node, rand_select
2 | 
3 | 
4 | def DFS_sub(tot_task, node):
5 |     if node.depth >= tot_task.max_depth:
6 |         print('Maximum depth limit reached!\n')
7 |         return "", node, None
8 | 
9 |     candidates = []
10 | 
11 |     for i in range(tot_task.branch):
12 |         refine_res = ''
13 |         cnt = 3
14 |         # print("depth: ", node.depth, "branch: ", i)
15 |         while not refine_res and cnt:
16 |             # get refinement
17 |             # print("cnt: ", cnt)
18 |             cnt -= 1
19 |             refine_res, refine_cot = tot_task.get_next_step(node.prompt, node.response, node.critique)
20 |             # print(2222)
21 |         if not refine_res:
22 |             continue
23 |         # print(refine_res)
24 |         node, child = node.append_children(refine_res, refine_cot)
25 | 
26 |         # get judgement, value is the pass rate
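        # get_step_value majority-votes the sampled critiques and returns the
        # majority-side critique together with the fraction judged "Yes"
        # (see get_value in models/get_response.py)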
27 |         # print(3333)
28 |         critique, value = tot_task.get_step_value(child.prompt, child.response)
29 |         # print(4444)
30 | 
31 |         if not critique:
32 |             node.children.pop()
33 |             del child
34 |             continue
35 | 
36 |         child.update_value(value)
37 |         child.update_critique(critique)
38 | 
39 |         child.visit_sequence = tot_task.node_count
40 |         tot_task.update_count()
41 |         candidates.append(child)
42 | 
43 |         tot_task.budget -= 1
44 | 
45 |         if tot_task.budget <= 0:
46 |             break
47 | 
48 |     if not candidates:
49 |         print('No suitable next step found!\n')
50 |         return "", node, None
51 |     ranked_candidates = sorted(candidates, key=lambda item: item.V, reverse=True)
52 |     if ranked_candidates[0].V > tot_task.end_gate:
53 |         ranked_candidates[0].final_ans_flag = 1
54 |         return ranked_candidates[0].response, node, ranked_candidates[0]
55 | 
56 |     if tot_task.budget <= 0:
57 |         return "", node, None
58 | 
59 |     # keep descending
60 |     if tot_task.select_method == 'greedy':
61 |         selected = ranked_candidates[:min(tot_task.select_branch, tot_task.branch, len(ranked_candidates))]
62 | 
63 |     else:
64 |         idx_list = []
65 |         selected = []
66 |         for j in range(min(tot_task.select_branch, tot_task.branch)):
67 |             idx, node = rand_select(ranked_candidates, [item.V for item in ranked_candidates])
68 |             if idx not in idx_list:
69 |                 idx_list.append(idx)
70 |                 selected.append(node)
71 |         selected = sorted(selected, key=lambda item: item.V, reverse=True)
72 | 
73 |     for child in selected:
74 |         solution, child, final_node = DFS_sub(tot_task, child)
75 |         if solution:
76 |             return solution, node, final_node
77 | 
78 |     return "", node, None
79 | 
80 | 
81 | def DFS(tot_task):
82 |     root = Node(tot_task.prompt, tot_task.response, tot_task.critique, tot_task.root_v)
83 |     solution, root, final_node = DFS_sub(tot_task, root)
84 |     if solution:
85 |         print(f'Final solution found!\nSolution:{solution}\n')
86 |         return solution, root, final_node
87 |     else:
88 |         max_node, max_V = root.getBestV()
89 |         max_node.final_ans_flag = 1
90 |         print(f'No solution met the required value; returning the highest-value solution instead.\nSolution:{max_node.response}\n')
91 |         return max_node.response, root, max_node
--------------------------------------------------------------------------------
/src/ToT/task.py:
--------------------------------------------------------------------------------
1 | import random
2 | from tasks.science import SearchTask
3 | from ToT.base import Node
4 | from models.get_response import *
5 | from ToT.bfs import BFS
6 | from ToT.dfs import DFS
7 | from openai import OpenAI
8 | 
9 | 
10 | 
11 | class ToT_Task(SearchTask):
12 |     def __init__(self, prompt, response, critique, root_value, propose_method='llama',
13 |                  algorithm='dfs', branch=3, select_branch=2, budget=1000,
14 |                  refine_model_path='', critique_model_path='',
15 |                  max_refine_tokens=3072, max_critique_tokens=2048,
16 |                  max_depth=3, end_gate=0.5, select_method='greedy',
17 |                  temperature=0.7, top_p=1.0,
18 |                  openai_api_key="EMPTY", openai_api_base="http://localhost:8000/v1",
19 |                  do_sample=True, sample_critique_num=5, use_case_prompt=False, evaluate='', multiply_value=False, lang='en', answer=None, verify_method=None):
20 |         super().__init__(prompt, response, critique)
21 | 
22 | 
23 |         self.client = OpenAI(
24 |             api_key=openai_api_key,
25 |             base_url=openai_api_base,
26 |         )
27 | 
28 |         self.budget = budget
29 |         self.root_v = root_value
30 |         self.propose_method = propose_method
31 |         self.mode = 'tot'
32 |         self.refine_model_path = refine_model_path
33 |         self.critique_model_path = critique_model_path
34 |         self.max_refine_tokens = max_refine_tokens
35 |         self.max_critique_tokens = max_critique_tokens
36 |         self.temperature = temperature
37 |         self.top_p = top_p
38 |         # self.max_tokens = max_tokens
39 |         self.do_sample = do_sample
40 |         self.sample_critique_num = sample_critique_num
41 |         # self.max_new_tokens = max_new_tokens
42 |         self.algorithm = algorithm
43 |         self.branch = branch
44 |         self.select_branch = select_branch
45 |         self.max_depth = max_depth
46 |         self.use_case_prompt = use_case_prompt
47 |         self.evaluate = evaluate
48 |         self.select_method = select_method
49 |         self.end_gate = end_gate
50 |         self.node_count = 1
51 |         self.multiply_value = multiply_value
52 |         self.lang = lang
53 |         self.answer = answer
54 |         self.verify_method = verify_method
55 | 
56 |     def update_count(self):
57 |         self.node_count += 1
58 | 
59 |     def clear_cache(self):
60 |         self.value_cache = {}
61 |         self.critique_cache = {}
62 |         self.node_count = 1
63 | 
64 |     def get_next_step(self, prompt, response, critique):
65 | 
66 |         messages = self.build_refine_message(prompt, response, critique)
67 | 
68 |         response, refine_cot = get_refine(messages, self.propose_method, self.refine_model_path, self.temperature, self.top_p, self.do_sample, self.max_refine_tokens, self.client)
69 | 
70 |         if not response:
71 |             print('Failed to get a refined response!\n')
72 |             return "", ""
73 | 
74 |         return response, refine_cot
75 | 
76 | 
77 |     def get_step_value(self, prompt, response):
78 |         if response in self.value_cache.keys():
79 |             print("Cache hit!\n")
80 |             return self.critique_cache[response], self.value_cache[response]
81 |             # return '', self.value_cache[response]
82 | 
83 |         messages = self.build_judge_message(prompt, response)
84 |         critique, value = get_value(messages, self.propose_method, self.critique_model_path, 0.7, 1.0, self.do_sample, self.max_critique_tokens, self.sample_critique_num, self.client)
85 | 
86 |         if not critique:
87 |             print('Failed to get a critique!\n')
88 |             return '', 0
89 | 
90 |         print(f'Got value: {value}\n')
91 |         self.value_cache.update({response: value})
92 |         self.critique_cache.update({response: critique})
93 | 
94 |         return critique, value
95 | 
96 | 
97 |     def run(self):
98 |         self.clear_cache()
99 |         if self.algorithm == 'dfs':
100 |             solution, root, final_node = DFS(self)
101 |         elif self.algorithm == 'bfs':
102 |             solution, root, final_node = BFS(self)
103 |         else:
104 |             print('Unsupported algorithm!\n')
105 |             return {}
106 | 
107 |         return solution, root, final_node
--------------------------------------------------------------------------------
/src/configs/dpo.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path:
3 | 
4 | ### method
5 | stage: dpo
6 | do_train: true
7 | finetuning_type: full
8 | pref_beta: 0.1
9 | pref_ftx: 0.1
10 | 
11 | ### ddp
12 | ddp_timeout: 180000000
13 | deepspeed: examples/deepspeed/ds_z3_config.json
14 | 
15 | ### dataset
16 | dataset: dpo
17 | template: llama3
18 | cutoff_len: 2048
19 | # max_samples: 1024
20 | overwrite_cache: true
21 | preprocessing_num_workers: 12
22 | 
23 | ### output
24 | output_dir: saves/dpo
25 | logging_steps: 1
26 | save_steps: 20
27 | save_only_model: true
28 | plot_loss: true
29 | overwrite_output_dir: true
30 | 
31 | ### train
32 | per_device_train_batch_size: 2
33 | gradient_accumulation_steps: 2
34 | learning_rate: 2.0e-7
35 | num_train_epochs: 1.0
36 | lr_scheduler_type: linear
37 | # lr_scheduler_type: constant_with_warmup
38 | warmup_ratio: 0.1
39 | bf16: true
40 | 
41 | ### eval
42 | # val_size: 0.1
43 | # per_device_eval_batch_size: 1
44 | # evaluation_strategy: steps
45 | # eval_steps: 500
46 | 
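# note: pref_beta is the DPO beta (KL-penalty strength) and pref_ftx weights an
# auxiliary SFT loss mixed into the DPO objective; both are LLaMA-Factory options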
--------------------------------------------------------------------------------
/src/configs/sft.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path:
3 | 
4 | ### method
5 | stage: sft
6 | do_train: true
7 | finetuning_type: full
8 | deepspeed: examples/deepspeed/ds_z3_config.json
9 | 
10 | ### dataset
11 | 
12 | # dataset: judge
13 | dataset: sft
14 | template: llama3
15 | cutoff_len: 4096
16 | # max_samples: 1000
17 | overwrite_cache: true
18 | preprocessing_num_workers: 16
19 | 
20 | ### output
21 | output_dir: saves/sft
22 | logging_steps: 1
23 | save_strategy: "epoch"
24 | # save_steps: 500
25 | save_only_model: true
26 | plot_loss: true
27 | overwrite_output_dir: true
28 | 
29 | ### train
30 | per_device_train_batch_size: 2
31 | gradient_accumulation_steps: 4
32 | # learning_rate: 1.0e-6
33 | learning_rate: 2.0e-6
34 | num_train_epochs: 5.0
35 | lr_scheduler_type: constant_with_warmup
36 | # lr_scheduler_type: cosine
37 | warmup_ratio: 0.1
38 | bf16: true
39 | ddp_timeout: 180000000
40 | 
41 | ### eval
42 | # val_size: 0.1
43 | # per_device_eval_batch_size: 1
44 | # eval_strategy: steps
45 | # eval_steps: 500
--------------------------------------------------------------------------------
/src/infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('--begin', type=int)
5 | parser.add_argument('--end', type=int)
6 | parser.add_argument('--gpu', type=int)
7 | parser.add_argument('--output_path', type=str)
8 | args = parser.parse_args()
9 | 
10 | import os
11 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
12 | 
13 | import pandas as pd
14 | from vllm import LLM, SamplingParams
15 | import json
16 | from transformers import AutoTokenizer
17 | 
18 | 
19 | # TODO: change data path, there should exist the key "prompt"
20 | with open('', encoding='utf-8') as f:
21 |     data = json.load(f)[args.begin: args.end]
22 | 
23 | tmp = [{'messages': [{'role': 'user', 'content': i['prompt']}]} for i in data]
24 | 
25 | 
26 | # Create an LLM.
27 | # TODO
28 | model_path = 'Your-PATH-Here'
29 | 
30 | llm = LLM(model=model_path, trust_remote_code=True)
31 | 
32 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
33 | tokenizer.padding_side = "left"
34 | prompts = []
35 | 
36 | for i in tmp:
37 |     try:
38 |         prompts.append(tokenizer.apply_chat_template(i['messages'], add_generation_prompt=True, tokenize=True))
39 |     except Exception as e:
40 |         print(e)
41 |         continue
42 | print("data numbers: ", len(prompts))
43 | print(tokenizer.decode(prompts[0]))
44 | 
45 | # stop_token_ids = [151329, 151336, 151338]
46 | # Create a sampling params object.
47 | sampling_params = SamplingParams(
48 |     temperature=0.9,
49 |     top_p=0.9,
50 |     max_tokens=2048,
51 |     # repetition_penalty=1.05,
52 |     n=5,
53 |     # stop_token_ids=stop_token_ids
54 |     stop=['<|eot_id|>', '</s>']
55 | )
56 | 
57 | # from IPython import embed; embed()
58 | 
59 | outputs = llm.generate(prompt_token_ids=prompts, sampling_params=sampling_params)
60 | # Print the outputs.
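# Each vLLM RequestOutput carries the n=5 sampled completions; all five are
# kept under the "output" key so process_data.py can split them into
# per-response records for judging.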
61 | res = []
62 | for output, i in zip(outputs, data):
63 |     i['output'] = [output.outputs[j].text.strip() for j in range(5)]
64 |     i['generator'] = 'llama3-8b-instruct'
65 |     # i['generator'] = 'mistral-7b-instruct'
66 |     res.append(i)
67 | 
68 | with open(args.output_path, 'w', encoding='utf-8') as f:
69 |     json.dump(res, f, indent=4, ensure_ascii=False)
70 | 
--------------------------------------------------------------------------------
/src/infer.sh:
--------------------------------------------------------------------------------
1 | gpus=(0 1 2 3 4 5 6 7)
2 | batch=8 # num gpus
3 | num=10000 # samples
4 | 
5 | len=$((num / batch + 1))
6 | echo $len
7 | 
8 | l=0
9 | r=$len
10 | b=()
11 | e=()
12 | for i in `seq 1 $batch`
13 | do
14 |     b+=($l)
15 |     e+=($r)
16 |     l=$((l+len))
17 |     r=$((r+len))
18 | done
19 | echo ${b[@]}
20 | echo ${e[@]}
21 | 
22 | for i in `seq 0 $((batch-1))`
23 | do
24 | (
25 |     python infer.py --begin ${b[$i]} \
26 |                     --end ${e[$i]} \
27 |                     --gpu ${gpus[$i]} \
28 |                     --output_path /vllm_output_$i.json
29 |     echo $i
30 | )&
31 | done
32 | wait
33 | echo "all workers finished"
34 | 
--------------------------------------------------------------------------------
/src/judge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('--begin', type=int)
5 | parser.add_argument('--end', type=int)
6 | parser.add_argument('--gpu', type=int)
7 | parser.add_argument('--output_path', type=str)
8 | args = parser.parse_args()
9 | 
10 | import os
11 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
12 | 
13 | import pandas as pd
14 | from vllm import LLM, SamplingParams
15 | import json
16 | from transformers import AutoTokenizer
17 | 
18 | 
19 | # TODO: change data file
20 | with open('', encoding='utf-8') as f:
21 |     data = json.load(f)[args.begin: args.end]
22 | 
23 | 
24 | def build_judge_template(prompt, response):
25 |     return """Please act as an expert in evaluating the capabilities of instruction-following. In the instruction-following task, the Output needs to honestly/precisely/closely follow the given Prompt.
26 | Your task is to carefully judge whether the Output honestly/precisely/closely follows the given Prompt. If there are any constraints in the Prompt that are not satisfied by the Output, please list all the constraints that are not satisfied.
27 | 
28 | Prompt: “{}”
29 | 
30 | Output: “{}”
31 | 
32 | Please carefully judge if each constraint is perfectly satisfied and give a final judgement on whether the Output accurately follows the Prompt in the following format:
33 | Step-by-step verification: xxx
34 | Final Judgement (if the Output accurately follows the Prompt): (Yes or No)""".format(prompt, response)
35 | 
36 | 
37 | # TODO: check keys
38 | tmp = [{'messages': [{'role': 'user', 'content': build_judge_template(i['prompt'], i['response'])}]} for i in data]
39 | 
40 | # TODO
41 | model_path = 'Your-PATH-Here'
42 | 
43 | 
44 | llm = LLM(model=model_path, trust_remote_code=True)
45 | 
46 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
47 | tokenizer.padding_side = "left"
48 | prompts = []
49 | 
50 | for i in tmp:
51 |     try:
52 |         prompts.append(tokenizer.apply_chat_template(i['messages'], add_generation_prompt=True, tokenize=True))
53 |     except Exception as e:
54 |         print(e)
55 |         continue
56 | print("data numbers: ", len(prompts))
57 | print(tokenizer.decode(prompts[0]))
58 | 
59 | # stop_token_ids = [151329, 151336, 151338]
60 | # Create a sampling params object.
61 | sampling_params = SamplingParams(
62 |     temperature=0.7,
63 |     top_p=1.0,
64 |     max_tokens=2048,
65 |     # repetition_penalty=1.05,
66 |     n=5,
67 |     # n=1,
68 |     # stop_token_ids=stop_token_ids
69 |     stop=['<|eot_id|>', '</s>']
70 | )
71 | 
72 | 
73 | outputs = llm.generate(prompt_token_ids=prompts, sampling_params=sampling_params)
74 | 
75 | 
76 | res = []
77 | for output, i in zip(outputs, data):
78 |     i['critique'] = [output.outputs[j].text.strip() for j in range(5)]
79 | 
80 |     i['generator'] = 'llama3-8b-instruct'
81 |     res.append(i)
82 | 
83 | with open(args.output_path, 'w', encoding='utf-8') as f:
84 |     json.dump(res, f, indent=4, ensure_ascii=False)
85 | 
--------------------------------------------------------------------------------
/src/judge.sh:
--------------------------------------------------------------------------------
1 | gpus=(0 1 2 3 4 5 6 7)
2 | batch=8 # num gpus
3 | num=50000 # samples
4 | 
5 | len=$((num / batch + 1))
6 | echo $len
7 | 
8 | l=0
9 | r=$len
10 | b=()
11 | e=()
12 | for i in `seq 1 $batch`
13 | do
14 |     b+=($l)
15 |     e+=($r)
16 |     l=$((l+len))
17 |     r=$((r+len))
18 | done
19 | echo ${b[@]}
20 | echo ${e[@]}
21 | 
22 | for i in `seq 0 $((batch-1))`
23 | do
24 | (
25 |     python judge.py --begin ${b[$i]} \
26 |                     --end ${e[$i]} \
27 |                     --gpu ${gpus[$i]} \
28 |                     --output_path /vllm_output_$i.json
29 |     echo $i
30 | )&
31 | done
32 | wait
33 | echo "all workers finished"
34 | 
--------------------------------------------------------------------------------
/src/models/get_response.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 | import random
3 | 
4 | # Set OpenAI's API key and API base to use vLLM's API server.
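# (A client is instead constructed per ToT_Task in ToT/task.py with a
# configurable base_url, then passed into get_refine/get_value below.)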
5 | # openai_api_key = "EMPTY"
6 | # openai_api_base = "http://localhost:8000/v1"
7 | 
8 | # client = OpenAI(
9 | #     api_key=openai_api_key,
10 | #     base_url=openai_api_base,
11 | # )
12 | 
13 | 
14 | def unwrap_refine_cot(text):
15 |     return text.split('[[start]]')[1].split('[[end]]')[0].strip()
16 | 
17 | 
18 | def get_acc(resp):
19 |     if 'Final Judgement' not in resp:
20 |         return -1
21 |     resp = resp.split('Final Judgement')[1].strip()
22 |     if not resp.count('Yes') and not resp.count('No'):
23 |         return -1
24 |     if resp.count("Yes"):
25 |         if resp.find("No") != -1 and resp.find("No") < resp.find("Yes"):
26 |             return 0
27 |         else:
28 |             return 1
29 |     elif resp.count("No"):
30 |         if resp.find("Yes") != -1 and resp.find("Yes") < resp.find("No"):
31 |             return 1
32 |         else:
33 |             return 0
34 | 
35 | # given prompt, generate proposal under instruction, unwrap is required
36 | def get_refine(message, method='llama', model_path='', temperature=0.7, top_p=1.0, do_sample=True, max_tokens=2048, client=None, n=1):
37 |     response = []
38 |     refine_cot = []
39 |     cnt = 3
40 |     if method == 'llama' or method == 'mistral':
41 |         while not response and cnt:
42 |             try:
43 |                 # print("cnt: ", cnt)
44 |                 cnt -= 1
45 |                 seed = random.randint(1, 100000000)
46 |                 # print(temperature, top_p, seed)
47 |                 # print(client.base_url)
48 |                 chat_response = client.chat.completions.create(
49 |                     model=model_path,
50 |                     messages=message,
51 |                     temperature=temperature,
52 |                     seed=seed,
53 |                     top_p=top_p,
54 |                     max_tokens=max_tokens,
55 |                     n=n
56 |                 )
57 |                 # print("refine get!")
58 |                 for i in range(n):
59 |                     response.append(unwrap_refine_cot(chat_response.choices[i].message.content))
60 |                     refine_cot.append(chat_response.choices[i].message.content)
61 |             except Exception as e:
62 |                 print(e)
63 |                 if len(response) <= n//2:
64 |                     response = []
65 |                     refine_cot = []
66 |                     continue
67 |         if n > 1:
68 |             if not response:
69 |                 print(f'Failed to get a <{method}> response!\n')
70 |                 return [], []
71 |             return response, refine_cot
72 |         else:
73 |             if not response:
74 |                 print(f'Failed to get a <{method}> response!\n')
75 |                 return "", ""
76 |             return response[0], refine_cot[0]
77 | 
78 |     else:
79 |         print('This response method is not supported yet!\n')
80 |         assert False
81 | 
82 | 
83 | # given prompt + answer, find its value
84 | # if you use api, unwrap is required. if you use a local value model, the value is obtained directly
85 | def get_value(message, method='llama', model_path='', temperature=0.7, top_p=1.0, do_sample=True, max_tokens=2048, n=1, client=None):
86 | 
87 |     assert (do_sample and n > 1 and temperature > 0.0) or (not do_sample and n == 1 and temperature == 0.0)
88 | 
89 |     critique = []
90 |     acc = []
91 |     cnt = 3
92 |     if method == 'llama' or method == 'mistral':
93 |         while not critique and cnt:
94 |             try:
95 |                 # print("cnt: ", cnt)
96 |                 cnt -= 1
97 |                 chat_response = client.chat.completions.create(
98 |                     model=model_path,
99 |                     messages=message,
100 |                     temperature=temperature,
101 |                     top_p=top_p,
102 |                     max_tokens=max_tokens,
103 |                     n=n
104 |                 )
105 |                 # print("value get!")
106 |                 for i in range(n):
107 |                     tmp_res = chat_response.choices[i].message.content
108 |                     critique.append(tmp_res)
109 |                     acc.append(get_acc(tmp_res))
110 |                 if acc.count(0) == acc.count(1):
111 |                     critique = []
112 |                     acc = []
113 |                     continue
114 |             except Exception as e:
115 |                 critique = []
116 |                 acc = []
117 |                 continue
118 | 
119 |         if not critique:
120 |             print(f'Failed to get a <{method}> value!\n')
121 |             return "", 0
122 | 
123 |         if acc.count(0) > acc.count(1):
124 |             label = 0
125 |         else:
126 |             label = 1
127 |         value = acc.count(1) / (acc.count(0) + acc.count(1))
128 |         for i in range(n):
129 |             if acc[i] == label:
130 |                 response = critique[i]
131 |                 break
132 | 
133 |         return response, value
134 | 
135 |     else:
136 |         print('This response method is not supported yet!\n')
137 |         assert False
--------------------------------------------------------------------------------
/src/process_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import re
4 | 
5 | 
6 | def process_gen_res(input_path, output_path):
7 |     d = []
8 |     for i in range(8):
9 |         with open(f'{input_path}/vllm_output_{i}.json', encoding='utf-8') as f:
10 |             d.extend(json.load(f))
11 | 
12 |     num = 0
13 |     res = []
14 |     for i in d:
15 |         for j in range(5):
16 |             tmp = i.copy()
17 |             tmp['id'] = '{}_{}'.format(num, j)
18 |             tmp['response'] = i['output'][j]
19 |             res.append(tmp)
20 |         num += 1
21 | 
22 |     print(res[:6])
23 |     print(len(res))
24 |     with open(output_path, 'w', encoding='utf-8') as f:
25 |         json.dump(res, f, indent=4, ensure_ascii=False)
26 | 
27 | 
28 | def get_acc(resp):
29 |     if 'Final Judgement' not in resp:
30 |         return -1
31 |     resp = resp.split('Final Judgement')[1].strip()
32 |     if not resp.count('Yes') and not resp.count('No'):
33 |         return -1
34 |     if resp.count("Yes"):
35 |         if resp.find("No") != -1 and resp.find("No") < resp.find("Yes"):
36 |             return 0
37 |         else:
38 |             return 1
39 |     elif resp.count("No"):
40 |         if resp.find("Yes") != -1 and resp.find("Yes") < resp.find("No"):
41 |             return 1
42 |         else:
43 |             return 0
44 | 
45 | 
46 | 
47 | def process_score_res_voting(input_path, output_path):
48 |     l = []
49 |     for i in range(8):
50 |         with open(f'{input_path}/vllm_output_{i}.json', encoding='utf-8') as f:
51 |             l.extend(json.load(f))
52 | 
53 | 
54 |     d = []
55 |     for i in l:
56 |         tmp = []
57 |         for j in i['critique']:
58 |             tmp.append(get_acc(j))
59 | 
60 |         if tmp.count(0) == tmp.count(1):
61 |             continue
62 |         elif tmp.count(0) > tmp.count(1):
63 |             i['acc'] = 0
64 |         else:
65 |             i['acc'] = 1
66 | 
67 |         for j in range(len(tmp)):
68 |             if tmp[j] == i['acc']:
69 |                 i['critique'] = i['critique'][j]
70 |                 break
71 | 
72 |         d.append(i)
73 | 
74 |     res = {}
75 |     for i in d:
76 |         prompt_id = int(i['id'].split('_')[0])
77 |         res_id = int(i['id'].split('_')[1])
78 |         if prompt_id not in res:
79 |             res[prompt_id] = {
80 |                 'original_prompt': i['original_prompt'],
81 |                 'prompt': i['prompt'],
82 |                 'output': [],
83 |                 'acc': [],
84 |                 'critique': [],
85 |                 'prompt_id': prompt_id
86 |             }
87 |         res[prompt_id]['output'].append(i['output'][res_id])
88 |         res[prompt_id]['acc'].append(i['acc'])
89 |         res[prompt_id]['critique'].append(i['critique'])
90 | 
91 |     rft_data = []
92 |     for i in res:
93 |         if min(res[i]['acc']) == 0:
94 |             rft_data.append(res[i])
95 | 
96 |     bad_data = []
97 |     for i in rft_data:
98 |         tmp = []
99 |         for j in range(len(i['acc'])):
100 |             if i['acc'][j] == 0:
101 |                 tmp.append(j)
102 |         random.shuffle(tmp)
103 |         for j in tmp[:1]:
104 |             bad_data.append({
105 |                 'prompt': i['prompt'],
106 |                 'response': i['output'][j],
107 |                 'critique': i['critique'][j],
108 |                 'acc': 0,
109 |                 'prompt_id': i['prompt_id']
110 |             })
111 |     print(len(bad_data))
112 |     with open(output_path, 'w', encoding='utf-8') as f:
113 |         json.dump(bad_data, f, indent=4, ensure_ascii=False)
114 | 
115 | 
116 | def process_dpo_data(input_path, output_path):
117 | 
118 |     with open(input_path, encoding='utf-8') as f:
119 |         l = f.readlines()
120 | 
121 |     d = []
122 |     for i in l:
123 |         i = json.loads(i)
124 |         if i['final_node']['value'] > 0.5:
125 |             d.append(i)
126 | 
127 |     data = []
128 |     for i in d:
129 |         data.append({
130 |             'messages': [
131 |                 {
132 |                     "role": "user",
133 |                     "content": i['prompt']
134 |                 }
135 |             ],
136 |             "chosen": {
137 |                 "role": "assistant",
138 |                 "content": i['final_node']['response']
139 |             },
140 |             "rejected": {
141 |                 "role": "assistant",
142 |                 "content": i['response']
143 |             }
144 |         })
145 |     with open(output_path, 'w', encoding='utf-8') as f:
146 |         json.dump(data, f, indent=4, ensure_ascii=False)
147 | 
148 | judge_template = """Please act as an expert in evaluating the capabilities of instruction-following. In the instruction-following task, the Output needs to honestly/precisely/closely follow the given Prompt.
149 | Your task is to carefully judge whether the Output honestly/precisely/closely follows the given Prompt. If there are any constraints in the Prompt that are not satisfied by the Output, please list all the constraints that are not satisfied.
150 | 
151 | Prompt: “{}”
152 | 
153 | Output: “{}”
154 | 
155 | Please carefully judge if each constraint is perfectly satisfied and give a final judgement on whether the Output accurately follows the Prompt in the following format:
156 | Step-by-step verification: xxx
157 | Final Judgement (if the Output accurately follows the Prompt): (Yes or No)"""
158 | 
159 | def process_rft_data(input_tree_search_path, input_judge_path, output_path):
160 | 
161 |     with open(input_tree_search_path, encoding='utf-8') as f:
162 |         l = f.readlines()
163 |     refine_data = []
164 |     for i in l:
165 |         i = json.loads(i)
166 |         if i['final_node']['value'] > 0.5:
167 |             refine_data.append({
168 |                 'messages': [
169 |                     {"role": "user", "content": judge_template.format(i['prompt'], i['final_node']['parent_res'])},
170 |                     {"role": "assistant", "content": i['final_node']['parent_critique']},
171 |                     {"role": "user", "content": """Based on your judgement, refine the Output to make sure it honestly/precisely/closely follows the given Prompt.
172 | 
173 | 
174 | Please carefully refine the Output to meet all the constraints in the Prompt.
175 | 176 | Please format like this: 177 | Reflection on how to refine the Output: xxx 178 | Final Refined Output: [[start]] xxx [[end]]"""}, 179 | { 180 | "role": "assistant", 181 | "content": i['final_node']['refine_cot'] 182 | } 183 | ] 184 | }) 185 | pos = [] 186 | neg = [] 187 | for i in l: 188 | i = json.loads(i) 189 | if len(i['critique_rft']['positive']): 190 | pos.append({ 191 | 'messages': [ 192 | {"role": "user", "content": i['critique_rft']['positive'][0]['prompt']}, 193 | {"role": "assistant", "content": i['critique_rft']['positive'][0]['response']} 194 | ] 195 | }) 196 | if len(i['critique_rft']['negative']): 197 | neg.append({ 198 | 'messages': [ 199 | {"role": "user", "content": i['critique_rft']['negative'][0]['prompt']}, 200 | {"role": "assistant", "content": i['critique_rft']['negative'][0]['response']} 201 | ] 202 | }) 203 | l = [] 204 | for i in range(8): 205 | with open(f'{input_judge_path}/vllm_output_{i}.json', encoding='utf-8') as f: 206 | l.extend(json.load(f)) 207 | 208 | d_0 = [] 209 | d_1 = [] 210 | for i in l: 211 | 212 | tmp = [] 213 | for j in i['critique']: 214 | tmp.append(get_acc(j)) 215 | 216 | if tmp.count(0) == tmp.count(1): 217 | continue 218 | elif tmp.count(0) > tmp.count(1): 219 | i['acc'] = 0 220 | else: 221 | i['acc'] = 1 222 | 223 | good_res = None 224 | bad_res = None 225 | for j in range(len(tmp)): 226 | if good_res and bad_res: 227 | break 228 | if tmp[j] == i['acc']: 229 | good_res = i['critique'][j] 230 | elif tmp[j] == 1-i['acc']: 231 | bad_res = i['critique'][j] 232 | if not (good_res and bad_res): 233 | continue 234 | i['good_res'] = good_res 235 | i['bad_res'] = bad_res 236 | i['judge_prompt'] = judge_template.format(i['prompt'], i['response']) 237 | if i['acc'] == 0: 238 | d_0.append(i) 239 | else: 240 | d_1.append(i) 241 | random.shuffle(pos) 242 | random.shuffle(neg) 243 | random.shuffle(d_0) 244 | random.shuffle(d_1) 245 | random.shuffle(refine_data) 246 | 247 | sft_data = [] 248 | p = [] 249 | for i in refine_data[:4000]: 250 | if i['messages'][0]['content'] not in p: 251 | p.append(i['messages'][0]['content']) 252 | sft_data.append(i) 253 | for i in pos[:1500]: 254 | if i['messages'][0]['content'] not in p: 255 | p.append(i['messages'][0]['content']) 256 | sft_data.append(i) 257 | for i in neg[:1000]: 258 | if i['messages'][0]['content'] not in p: 259 | p.append(i['messages'][0]['content']) 260 | sft_data.append(i) 261 | for i in d_0[:500]: 262 | if i['judge_prompt'] not in p: 263 | p.append(i['judge_prompt']) 264 | sft_data.append({ 265 | 'messages': [ 266 | {"role": "user", "content": i['judge_prompt']}, 267 | {"role": "assistant", "content": i['good_res']} 268 | ] 269 | }) 270 | for i in d_1[:2500]: 271 | if i['judge_prompt'] not in p: 272 | p.append(i['judge_prompt']) 273 | sft_data.append({ 274 | 'messages': [ 275 | {"role": "user", "content": i['judge_prompt']}, 276 | {"role": "assistant", "content": i['good_res']} 277 | ] 278 | }) 279 | 280 | random.shuffle(sft_data) 281 | with open(output_path, 'w', encoding='utf-8') as f: 282 | json.dump(sft_data, f, indent=4, ensure_ascii=False) 283 | 284 | 285 | if __name__ == '__main__': 286 | input_path = '' 287 | output_path = '' 288 | process_gen_res(input_path, output_path) 289 | 290 | 291 | input_path = '' 292 | output_path = '' 293 | process_score_res_voting(input_path, output_path) 294 | 295 | 296 | input_path = '' 297 | output_path = '' 298 | process_dpo_data(input_path, output_path) 299 | 300 | 301 | input_tree_search_path = '' 302 | input_judge_path = '' 303 | output_path = 
''
304 |     process_rft_data(input_tree_search_path, input_judge_path, output_path)
--------------------------------------------------------------------------------
/src/tasks/prompts.py:
--------------------------------------------------------------------------------
1 | 
2 | judge_template = """Please act as an expert in evaluating the capabilities of instruction-following. In the instruction-following task, the Output needs to honestly/precisely/closely follow the given Prompt.
3 | Your task is to carefully judge whether the Output honestly/precisely/closely follows the given Prompt. If there are any constraints in the Prompt that are not satisfied by the Output, please list all the constraints that are not satisfied.
4 | 
5 | Prompt: “{}”
6 | 
7 | Output: “{}”
8 | 
9 | Please carefully judge if each constraint is perfectly satisfied and give a final judgement on whether the Output accurately follows the Prompt in the following format:
10 | Step-by-step verification: xxx
11 | Final Judgement (if the Output accurately follows the Prompt): (Yes or No)"""
--------------------------------------------------------------------------------
/src/tasks/science.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | from tasks.prompts import *
4 | 
5 | 
6 | # mode: 'cot', 'tot', 'mcts'
7 | # method: 'glm', 'gpt', 'local'
8 | class SearchTask(object):
9 |     def __init__(self, prompt, response, critique):
10 |         super().__init__()
11 |         self.prompt = prompt
12 |         self.response = response
13 |         self.critique = critique
14 |         self.value_cache = {}
15 |         self.critique_cache = {}
16 | 
17 |     def clear_cache(self):
18 |         self.value_cache = {}
19 |         self.critique_cache = {}
20 | 
21 | 
22 |     @staticmethod
23 |     def build_refine_message(prompt, response, critique):
24 |         return [
25 |             {"role": "user", "content": judge_template.format(prompt, response)},
26 |             {"role": "assistant", "content": critique},
27 |             {"role": "user", "content": """Based on your judgement, refine the Output to make sure it honestly/precisely/closely follows the given Prompt.
28 | 
29 | 
30 | Please carefully refine the Output to meet all the constraints in the Prompt.
31 | 
32 | Please format like this:
33 | Reflection on how to refine the Output: xxx
34 | Final Refined Output: [[start]] xxx [[end]]"""}
35 |         ]
36 | 
37 |     @staticmethod
38 |     def build_judge_message(prompt, response):
39 |         return [
40 |             {"role": "user", "content": judge_template.format(prompt, response)}
41 |         ]
42 | 
--------------------------------------------------------------------------------
/src/tree_search.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import multiprocessing
3 | from multiprocessing import Manager
4 | import json
5 | from tqdm import tqdm
6 | import os
7 | import time
8 | # import pandas as pd
9 | import random
10 | from ToT.task import *
11 | 
12 | 
13 | 
14 | 
15 | def chat_gpt(messages, counter, error_count):
16 |     responses = []
17 |     for i, m in enumerate(messages):
18 |         try:
19 |             # TODO: change model path, it should be same as the name deployed with vllm
20 |             task = ToT_Task(m['prompt'], m['response'], m['critique'], root_value=0, propose_method='llama', algorithm='bfs', refine_model_path='', critique_model_path='', openai_api_base=m['openai_api_base'])
21 |             solution, root, final_node = task.run()
22 |             m['final_node'] = {
23 |                 'response': final_node.response,
24 |                 'refine_cot': final_node.refine_cot,
25 |                 'parent_res': final_node.parent.response,
26 |                 'parent_critique': final_node.parent.critique,
27 |                 'value': final_node.V,
28 |             }
29 |             negative_critique, positive_critique = root.getCritiqueRFT()
30 |             if len(negative_critique):
31 |                 negative_critique = random.sample(negative_critique, 1)
32 |             if len(positive_critique):
33 |                 positive_critique = random.sample(positive_critique, 1)
34 |             m['critique_rft'] = {
35 |                 'positive': positive_critique,
36 |                 'negative': negative_critique
37 |             }
38 |             # save the response to file
39 |             with open(output_file, 'a', encoding='utf-8') as f:
40 |                 print(json.dumps(m, ensure_ascii=False), file=f)
41 | 
42 |             responses.append(0)
43 | 
44 |             # Increment and print the counter
45 |             counter.value += 1
46 |         except Exception as e:
47 |             error_count.value += 1
48 |             print(e)
49 |         print('running time:{} finished number:{} skipped number:{}'.format(time.time()-s_time, counter.value, error_count.value), end='\r')
50 | 
51 |     return responses
52 | 
53 | 
54 | def multi_process_chat_gpt(messages_list, num_processes):
55 |     # split messages_list into num_processes sublists
56 |     sublists = [messages_list[i::num_processes] for i in range(num_processes)]
57 | 
58 |     # Create a shared counter
59 |     manager = Manager()
60 |     counter = manager.Value('i', 0)
61 |     error_count = manager.Value('j', 0)
62 | 
63 |     with multiprocessing.Pool() as pool:
64 |         all_responses = pool.starmap(chat_gpt, [(sublist, counter, error_count) for sublist in sublists])
65 |     # merge all responses into a single list
66 |     return [item for sublist in all_responses for item in sublist]
67 | 
68 | 
69 | def get_messages_list():
70 | 
71 |     evaluated = []
72 |     with open(output_file, encoding='utf-8') as f:
73 |         lines = f.readlines()
74 |         for i in lines:
75 |             evaluated.append(json.loads(i)['prompt_id'])
76 | 
77 |     with open(input_file, encoding='utf-8') as f:
78 |         d = json.load(f)
79 | 
80 | 
81 |     messages_list = []
82 | 
83 |     num = 0
84 |     for i in d[:]:
85 |         if i['prompt_id'] in evaluated:
86 |             continue
87 |         i['openai_api_base'] = f"http://localhost:800{num}/v1"
88 |         # num += 1
89 |         # num %= 8
90 |         messages_list.append(i)
91 |     return messages_list
92 | 
93 | 
94 | 
95 | if __name__ == '__main__':
96 | 
97 | 
98 |     # TODO: change file path
99 |     input_file = ''
100 |     output_file = ''
101 | 
102 |     if not
os.path.exists(output_file): 103 | x = open(output_file, 'w') 104 | x.close() 105 | messages_list = get_messages_list() 106 | 107 | print("total num: ", len(messages_list)) 108 | s_time = time.time() 109 | responses = multi_process_chat_gpt(messages_list, num_processes=32) --------------------------------------------------------------------------------
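For reference, a minimal sketch of deploying the OpenAI-compatible endpoint that `tree_search.py` expects (the script reads `openai_api_base` values such as `http://localhost:8000/v1`, and the `refine_model_path`/`critique_model_path` passed to `ToT_Task` must match the served model name); the model path, port, and GPU index here are placeholders:

```bash
# one endpoint per port referenced in tree_search.py (default: localhost:8000)
CUDA_VISIBLE_DEVICES=0 vllm serve /path/to/refine_or_critique_model --port 8000
```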