├── LICENSE
├── README.md
├── assets
└── abs.png
└── src
├── ToT
├── base.py
├── bfs.py
├── dfs.py
└── task.py
├── configs
├── dpo.yaml
└── sft.yaml
├── infer.py
├── infer.sh
├── judge.py
├── judge.sh
├── models
└── get_response.py
├── process_data.py
├── tasks
├── prompts.py
└── science.py
└── tree_search.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SPaR
2 | ## Self-Play with Tree-Search Refinement to Improve Instruction-Following in Large Language Models
3 |
4 |
5 | 🤗 [Data](https://huggingface.co/datasets/CCCCCC/SPaR) • 📃 [Paper](https://arxiv.org/abs/2412.11605)
6 |
7 |
8 | SPaR focuses on creating interference-free preference pairs for effective self-improvement. An example of the interfering factors (*story content*) in multiple independently sampled responses (Left). Refined response pairs exclude these factors, highlight the key difference (*ending sentence*), and lead to improved performance on iteratively trained LLaMA3-8B-Instruct (Right).
9 |
10 |
11 |

12 |
13 |
14 |
15 |
16 | ## Table of Contents
17 | - [Data](#data)
18 | - [Quick Start](#quick-start)
19 | - [Data Construction](#data-construction)
20 | - [Model Training](#model-training)
21 | - [Citation](#citation)
22 |
23 | ## Data
24 |
25 | ### SPaR dataset
26 | SPaR Dataset can be found on [Hugging Face](https://huggingface.co/datasets/CCCCCC/SPaR).
27 |
28 | We provide a high-quality SFT dataset for instruction-following tasks and the data for iterative self-training.
29 |
30 |
31 | ## Quick Start
32 | Throughout the code, we have added `#TODO` comments to mark the places that need modification before running. Please update them before executing each file.
33 |
34 | ### Data Construction
35 | To construct the iterative training data yourself, run the following commands:
36 | ```bash
37 | cd src
38 |
39 | bash infer.sh
40 |
41 | python process_data.py
42 |
43 | bash judge.sh
44 |
45 | python process_data.py
46 |
47 | vllm serve <model-path>  # serve the refine/critique model (see the example below)
48 |
49 | python tree_search.py
50 |
51 | python process_data.py
52 | ```
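The `vllm serve` step launches the OpenAI-compatible server that `tree_search.py` queries (`ToT_Task` defaults to `http://localhost:8000/v1`). A minimal sketch, using a placeholder model path:

```bash
# Placeholder path: the served model name must match the refine_model_path /
# critique_model_path passed to ToT_Task in tree_search.py.
vllm serve Your-PATH-Here --port 8000
```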
53 |
54 | ### Model Training
55 | If you want to train your own model, run the following commands:
57 | ```bash
58 | cd src
59 |
60 | # dpo
61 | llamafactory-cli train configs/dpo.yaml
62 |
63 | # sft
64 | llamafactory-cli train configs/sft.yaml
65 |
66 | ```
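Note that LLaMA-Factory resolves the `dpo` and `sft` dataset names in these configs through its `data/dataset_info.json` registry, so register the files produced by `process_data.py` there before training.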
67 |
68 |
69 | ## Acknowledgement
70 | - Training code: [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)
71 | - Tree-search implementation: [Rest-MCTS*](https://github.com/THUDM/ReST-MCTS)
72 |
73 | ## Citation
74 | ```
75 | @misc{cheng2024sparselfplaytreesearchrefinement,
76 | title={SPaR: Self-Play with Tree-Search Refinement to Improve Instruction-Following in Large Language Models},
77 | author={Jiale Cheng and Xiao Liu and Cunxiang Wang and Xiaotao Gu and Yida Lu and Dan Zhang and Yuxiao Dong and Jie Tang and Hongning Wang and Minlie Huang},
78 | year={2024},
79 | eprint={2412.11605},
80 | archivePrefix={arXiv},
81 | primaryClass={cs.CL},
82 | url={https://arxiv.org/abs/2412.11605},
83 | }
84 | ```
--------------------------------------------------------------------------------
/assets/abs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/SPaR/4c1010addd771c580dc98288f3d63960364c5a5e/assets/abs.png
--------------------------------------------------------------------------------
/src/ToT/base.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | from tasks.prompts import *
4 |
5 |
6 | class Node(object):
7 | def __init__(self, prompt: str, response: str, critique: str, value=0, parent=None, depth=0, refine_cot=""):
8 |         # self.pcd = pcd  # current step (legacy, unused)
9 | self.children = []
10 | self.V = value
11 | self.parent = parent
12 |         # self.y = ''  # all steps so far (legacy, unused)
13 | self.depth = depth
14 | self.visit_sequence = 0
15 | self.final_ans_flag = 0
16 | self.prompt = prompt
17 | self.response = response
18 | self.critique = critique
19 | self.refine_cot = refine_cot
20 |
21 |
22 | def append_children(self, new_res: str, refine_cot: str):
23 | node = Node(self.prompt, new_res, '', 0, self, self.depth + 1, refine_cot)
24 | # node.update_y_from_parent()
25 | self.children.append(node)
26 | return self, node
27 |
28 |     def update_y_from_parent(self):  # unused: depends on self.pcd, which is commented out above
29 | if self.parent is None:
30 | self.y = self.pcd
31 | else:
32 | self.y = self.parent.y + self.pcd
33 |
34 | def update_value(self, value):
35 | self.V = value
36 |
37 | def update_critique(self, critique):
38 | self.critique = critique
39 |
40 | def getRefinement(self):
41 | if not self.children:
42 | return self, self.V
43 | max_V = self.V
44 | max_node = self
45 | for child in self.children:
46 | subNode, subValue = child.getBestV()
47 | if subValue > max_V:
48 | max_V = subValue
49 | max_node = subNode
50 | return max_node, max_V
51 |
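    # Recursively collect (judge-prompt, critique) training pairs for critique RFT:
    # a child's critique counts as positive when end_gate < V < 1.0 and as
    # negative when 0.0 < V < end_gate.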
52 | def getCritiqueRFT(self, end_gate=0.5):
53 | if not self.children:
54 | return [], []
55 | negative = []
56 | positive = []
57 | for child in self.children:
58 | negative_critique, positive_critique = child.getCritiqueRFT()
59 | if child.V > end_gate and child.V < 1.0:
60 | positive_critique.append({'prompt': judge_template.format(child.prompt, child.response), 'response': child.critique})
61 | elif child.V < end_gate and child.V > 0.0:
62 | negative_critique.append({'prompt': judge_template.format(child.prompt, child.response), 'response': child.critique})
63 |
64 | negative.extend(negative_critique)
65 | positive.extend(positive_critique)
66 |
67 | return negative, positive
68 |
69 |
70 |     def getBestV(self):  # return the highest-value node in this subtree
71 | if not self.children:
72 | return self, self.V
73 | max_V = self.V
74 | max_node = self
75 | for child in self.children:
76 | subNode, subValue = child.getBestV()
77 | if subValue >= max_V:
78 | max_V = subValue
79 | max_node = subNode
80 | return max_node, max_V
81 |
82 | def get_multiply_value(self):
83 | if self.depth == 0:
84 | return 0
85 | multi_value = self.V
86 | cur_node = self.parent
87 | while cur_node.depth > 0:
88 | multi_value = multi_value * cur_node.V
89 | cur_node = cur_node.parent
90 | return multi_value
91 |
92 |
93 | class SolutionStep(object):
94 | def __init__(self, x, stp, all_steps, score, step_num):
95 | self.x = x
96 | self.stp = stp
97 | self.all_steps = all_steps
98 | self.score = score
99 | self.step_num = step_num
100 |
101 |
102 | def rand_select(data_list: list, probs: list):  # sample an index with probability proportional to probs
103 | assert len(data_list) == len(probs), "length do not match!"
104 | probs_norm = []
105 | sum_prob = sum(probs)
106 | for i in probs:
107 | probs_norm.append(i / sum_prob)
108 | intervals = []
109 | count = 0
110 | for i in probs_norm:
111 | count = count + i
112 | intervals.append(count)
113 | # assert count == 1, "probs error!"
114 | intervals[len(intervals) - 1] = 1
115 | index = 0
116 | rand_prob = random.random()
117 | while rand_prob >= intervals[index]:
118 | index = index + 1
119 | return index, data_list[index]
--------------------------------------------------------------------------------
/src/ToT/bfs.py:
--------------------------------------------------------------------------------
1 | from ToT.base import Node, rand_select
2 |
3 |
4 | def BFS(tot_task):
5 | root = Node(tot_task.prompt, tot_task.response, tot_task.critique, tot_task.root_v)
6 | cur_nodes = [root]
7 | for depth in range(tot_task.max_depth):
8 | candidates = []
9 | for node in cur_nodes:
10 |
11 | for i in range(tot_task.branch):
12 | # print("depth: ", node.depth, "branch: ", i)
13 | refine_res = ''
14 | cnt = 3
15 | while not refine_res and cnt:
16 | # get refinement
17 | cnt -= 1
18 | refine_res, refine_cot = tot_task.get_next_step(node.prompt, node.response, node.critique)
19 | # print(refine_res)
20 |
21 | if not refine_res:
22 | continue
23 |
24 | node, child = node.append_children(refine_res, refine_cot)
25 |
26 | # get judgement, value is the pass rate
27 | critique, value = tot_task.get_step_value(child.prompt, child.response)
28 |
29 | if not critique:
30 | node.children.pop()
31 | del child
32 | continue
33 |
34 | child.update_value(value)
35 | child.update_critique(critique)
36 |
37 | child.visit_sequence = tot_task.node_count
38 | tot_task.update_count()
39 | candidates.append(child)
40 |
41 | tot_task.budget -= 1
42 |
43 | if tot_task.budget == 0:
44 | break
45 |
46 | if tot_task.budget == 0:
47 | break
48 |
49 | if not candidates:
50 | break
51 |
52 |
53 | ranked_candidates = sorted(candidates, key=lambda item: item.V, reverse=True)
54 | if ranked_candidates[0].V > tot_task.end_gate:
55 |             print('Final solution found!\n')
56 | ranked_candidates[0].final_ans_flag = 1
57 | return ranked_candidates[0].response, root, ranked_candidates[0]
58 |
59 | if tot_task.budget == 0:
60 | break
61 |
62 | if tot_task.select_method == 'greedy':
63 | cur_nodes = ranked_candidates[:min(tot_task.select_branch, tot_task.branch, len(ranked_candidates))]
64 |
65 | else:
66 | idx_list = []
67 | cur_nodes = []
68 | for j in range(min(tot_task.select_branch, tot_task.branch)):
69 |                 idx, cand = rand_select(ranked_candidates, [item.V for item in ranked_candidates])
70 |                 if idx not in idx_list:
71 |                     idx_list.append(idx)
72 |                     cur_nodes.append(cand)
73 | cur_nodes = sorted(cur_nodes, key=lambda item: item.V, reverse=True)
74 |
75 |     print('No answer met the required value; using the highest-value answer instead.\n')
76 | max_node, max_V = root.getBestV()
77 | max_node.final_ans_flag = 1
78 | return max_node.response, root, max_node
--------------------------------------------------------------------------------
/src/ToT/dfs.py:
--------------------------------------------------------------------------------
1 | from ToT.base import Node, rand_select
2 |
3 |
4 | def DFS_sub(tot_task, node):
5 | if node.depth >= tot_task.max_depth:
6 |         print('Maximum depth reached!\n')
7 | return "", node, None
8 |
9 | candidates = []
10 |
11 | for i in range(tot_task.branch):
12 | refine_res = ''
13 | cnt = 3
14 | # print("depth: ", node.depth, "branch: ", i)
15 | while not refine_res and cnt:
16 | # get refinement
17 | # print("cnt: ", cnt)
18 | cnt -= 1
19 | refine_res, refine_cot = tot_task.get_next_step(node.prompt, node.response, node.critique)
20 | # print(2222)
21 | if not refine_res:
22 | continue
23 | # print(refine_res)
24 | node, child = node.append_children(refine_res, refine_cot)
25 |
26 | # get judgement, value is the pass rate
27 | # print(3333)
28 | critique, value = tot_task.get_step_value(child.prompt, child.response)
29 | # print(4444)
30 |
31 | if not critique:
32 | node.children.pop()
33 | del child
34 | continue
35 |
36 | child.update_value(value)
37 | child.update_critique(critique)
38 |
39 | child.visit_sequence = tot_task.node_count
40 | tot_task.update_count()
41 | candidates.append(child)
42 |
43 | tot_task.budget -= 1
44 |
45 | if tot_task.budget <= 0:
46 | break
47 |
48 | if not candidates:
49 |         print('No valid next step found!\n')
50 | return "", node, None
51 | ranked_candidates = sorted(candidates, key=lambda item: item.V, reverse=True)
52 | if ranked_candidates[0].V > tot_task.end_gate:
53 | ranked_candidates[0].final_ans_flag = 1
54 | return ranked_candidates[0].response, node, ranked_candidates[0]
55 |
56 | if tot_task.budget <= 0:
57 | return "", node, None
58 |
59 |     # descend into the selected candidates
60 | if tot_task.select_method == 'greedy':
61 | selected = ranked_candidates[:min(tot_task.select_branch, tot_task.branch, len(ranked_candidates))]
62 |
63 | else:
64 | idx_list = []
65 | selected = []
66 | for j in range(min(tot_task.select_branch, tot_task.branch)):
67 |             idx, cand = rand_select(ranked_candidates, [item.V for item in ranked_candidates])  # use a fresh name so the parent `node` is not shadowed
68 |             if idx not in idx_list:
69 |                 idx_list.append(idx)
70 |                 selected.append(cand)
71 | selected = sorted(selected, key=lambda item: item.V, reverse=True)
72 |
73 | for child in selected:
74 | solution, child, final_node = DFS_sub(tot_task, child)
75 | if solution:
76 | return solution, node, final_node
77 |
78 | return "", node, None
79 |
80 |
81 | def DFS(tot_task):
82 | root = Node(tot_task.prompt, tot_task.response, tot_task.critique, tot_task.root_v)
83 | solution, root, final_node = DFS_sub(tot_task, root)
84 | if solution:
85 |         print(f'Final solution found!\nSolution: {solution}\n')
86 | return solution, root, final_node
87 | else:
88 | max_node, max_V = root.getBestV()
89 | max_node.final_ans_flag = 1
90 |         print(f'No answer met the required value; using the highest-value answer instead.\nSolution: {max_node.response}\n')
91 | return max_node.response, root, max_node
--------------------------------------------------------------------------------
/src/ToT/task.py:
--------------------------------------------------------------------------------
1 | import random
2 | from tasks.science import SearchTask
3 | from ToT.base import Node
4 | from models.get_response import *
5 | from ToT.bfs import BFS
6 | from ToT.dfs import DFS
7 | from openai import OpenAI
8 |
9 |
10 |
11 | class ToT_Task(SearchTask):
12 | def __init__(self, prompt, response, critique, root_value, propose_method='llama',
13 | algorithm='dfs', branch=3, select_branch=2, budget=1000,
14 | refine_model_path='', critique_model_path='',
15 | max_refine_tokens=3072, max_critique_tokens=2048,
16 | max_depth=3, end_gate=0.5, select_method='greedy',
17 | temperature=0.7, top_p=1.0,
18 | openai_api_key = "EMPTY", openai_api_base = "http://localhost:8000/v1",
19 | do_sample=True, sample_critique_num=5, use_case_prompt=False, evaluate='', multiply_value=False, lang='en', answer=None, verify_method=None):
20 | super().__init__(prompt, response, critique)
21 |
22 |
23 | self.client = OpenAI(
24 | api_key=openai_api_key,
25 | base_url=openai_api_base,
26 | )
27 |
28 | self.budget = budget
29 | self.root_v = root_value
30 | self.propose_method = propose_method
31 | self.mode = 'tot'
32 | self.refine_model_path = refine_model_path
33 | self.critique_model_path = critique_model_path
34 | self.max_refine_tokens = max_refine_tokens
35 | self.max_critique_tokens = max_critique_tokens
36 | self.temperature = temperature
37 | self.top_p = top_p
38 | # self.max_tokens = max_tokens
39 | self.do_sample = do_sample
40 | self.sample_critique_num = sample_critique_num
41 | # self.max_new_tokens = max_new_tokens
42 | self.algorithm = algorithm
43 | self.branch = branch
44 | self.select_branch = select_branch
45 | self.max_depth = max_depth
46 | self.use_case_prompt = use_case_prompt
47 | self.evaluate = evaluate
48 | self.select_method = select_method
49 | self.end_gate = end_gate
50 | self.node_count = 1
51 | self.multiply_value = multiply_value
52 | self.lang = lang
53 | self.answer = answer
54 | self.verify_method = verify_method
55 |
56 | def update_count(self):
57 | self.node_count += 1
58 |
59 | def clear_cache(self):
60 | self.value_cache = {}
61 | self.critique_cache = {}
62 | self.node_count = 1
63 |
64 | def get_next_step(self, prompt, response, critique):
65 |
66 | messages = self.build_refine_message(prompt, response, critique)
67 |
68 | response, refine_cot = get_refine(messages, self.propose_method, self.refine_model_path, self.temperature, self.top_p, self.do_sample, self.max_refine_tokens, self.client)
69 |
70 | if not response:
71 |             print('Failed to obtain a refined response!\n')
72 | return "", ""
73 |
74 | return response, refine_cot
75 |
76 |
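    # Judge (prompt, response): the value is the pass rate over sample_critique_num
    # sampled critiques and the returned critique matches the majority label
    # (see models/get_response.get_value); results are cached per response.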
77 | def get_step_value(self, prompt, response):
78 | if response in self.value_cache.keys():
79 |             print("Cache hit!\n")
80 | return self.critique_cache[response], self.value_cache[response]
81 | # return '', self.value_cache[response]
82 |
83 | messages = self.build_judge_message(prompt, response)
84 | critique, value = get_value(messages, self.propose_method, self.critique_model_path, 0.7, 1.0, self.do_sample, self.max_critique_tokens, self.sample_critique_num, self.client)
85 |
86 | if not critique:
87 |             print('Failed to obtain a critique!\n')
88 | return '', 0
89 |
90 |         print(f'Obtained value: {value}\n')
91 | self.value_cache.update({response: value})
92 | self.critique_cache.update({response: critique})
93 |
94 | return critique, value
95 |
96 |
97 | def run(self):
98 | self.clear_cache()
99 | if self.algorithm == 'dfs':
100 | solution, root, final_node = DFS(self)
101 | elif self.algorithm == 'bfs':
102 | solution, root, final_node = BFS(self)
103 | else:
104 | print('Unsupported algorithm!\n')
105 | return {}
106 |
107 | return solution, root, final_node
--------------------------------------------------------------------------------
/src/configs/dpo.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path:  # TODO: set the base model path
3 |
4 | ### method
5 | stage: dpo
6 | do_train: true
7 | finetuning_type: full
8 | pref_beta: 0.1
9 | pref_ftx: 0.1
10 |
11 | ### ddp
12 | ddp_timeout: 180000000
13 | deepspeed: examples/deepspeed/ds_z3_config.json
14 |
15 | ### dataset
16 | dataset: dpo
17 | template: llama3
18 | cutoff_len: 2048
19 | # max_samples: 1024
20 | overwrite_cache: true
21 | preprocessing_num_workers: 12
22 |
23 | ### output
24 | output_dir: saves/dpo
25 | logging_steps: 1
26 | save_steps: 20
27 | save_only_model: true
28 | plot_loss: true
29 | overwrite_output_dir: true
30 |
31 | ### train
32 | per_device_train_batch_size: 2
33 | gradient_accumulation_steps: 2
34 | learning_rate: 2.0e-7
35 | num_train_epochs: 1.0
36 | lr_scheduler_type: linear
37 | # lr_scheduler_type: constant_with_warmup
38 | warmup_ratio: 0.1
39 | bf16: true
40 |
41 | ### eval
42 | # val_size: 0.1
43 | # per_device_eval_batch_size: 1
44 | # evaluation_strategy: steps
45 | # eval_steps: 500
46 |
--------------------------------------------------------------------------------
/src/configs/sft.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path:  # TODO: set the base model path
3 |
4 | ### method
5 | stage: sft
6 | do_train: true
7 | finetuning_type: full
8 | deepspeed: examples/deepspeed/ds_z3_config.json
9 |
10 | ### dataset
11 |
12 | # dataset: judge
13 | dataset: sft
14 | template: llama3
15 | cutoff_len: 4096
16 | # max_samples: 1000
17 | overwrite_cache: true
18 | preprocessing_num_workers: 16
19 |
20 | ### output
21 | output_dir: saves/sft
22 | logging_steps: 1
23 | save_strategy: "epoch"
24 | # save_steps: 500
25 | save_only_model: true
26 | plot_loss: true
27 | overwrite_output_dir: true
28 |
29 | ### train
30 | per_device_train_batch_size: 2
31 | gradient_accumulation_steps: 4
32 | # learning_rate: 1.0e-6
33 | learning_rate: 2.0e-6
34 | num_train_epochs: 5.0
35 | lr_scheduler_type: constant_with_warmup
36 | # lr_scheduler_type: cosine
37 | warmup_ratio: 0.1
38 | bf16: true
39 | ddp_timeout: 180000000
40 |
41 | ### eval
42 | # val_size: 0.1
43 | # per_device_eval_batch_size: 1
44 | # eval_strategy: steps
45 | # eval_steps: 500
--------------------------------------------------------------------------------
/src/infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('--begin', type=int)
5 | parser.add_argument('--end', type=int)
6 | parser.add_argument('--gpu', type=int)
7 | parser.add_argument('--output_path', type=str)
8 | args = parser.parse_args()
9 |
10 | import os
11 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
12 |
13 | import pandas as pd
14 | from vllm import LLM, SamplingParams
15 | import json
16 | from transformers import AutoTokenizer
17 |
18 |
19 | # TODO: change data path, there should exist the key "prompt"
20 | with open('', encoding='utf-8') as f:
21 | data = json.load(f)[args.begin: args.end]
22 |
23 | tmp = [{'messages': [{'role': 'user', 'content': i['prompt']}]} for i in data]
24 |
25 |
26 | # Create an LLM.
27 | # TODO
28 | model_path = 'Your-PATH-Here'
29 |
30 | llm = LLM(model=model_path, trust_remote_code=True)
31 |
32 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
33 | tokenizer.padding_side = "left"
34 | prompts = []
35 |
36 | for i in tmp:
37 | try:
38 | prompts.append(tokenizer.apply_chat_template(i['messages'], add_generation_prompt=True, tokenize=True))
39 | except Exception as e:
40 | print(e)
41 | continue
42 | print("data numbers: ", len(prompts))
43 | print(tokenizer.decode(prompts[0]))
44 |
45 | # stop_token_ids = [151329, 151336, 151338]
46 | # Create a sampling params object.
47 | sampling_params = SamplingParams(
48 | temperature=0.9,
49 | top_p=0.9,
50 | max_tokens=2048,
51 | # repetition_penalty=1.05,
52 | n=5,
53 | # stop_token_ids=stop_token_ids
54 |     stop=['<|eot_id|>']  # stop at the Llama 3 end-of-turn token
55 | )
56 |
57 | # from IPython import embed; embed()
58 |
59 | outputs = llm.generate(prompt_token_ids=prompts, sampling_params=sampling_params)
60 | # Print the outputs.
61 | res = []
62 | for output, i in zip(outputs, data):
63 |     i['output'] = [o.text.strip() for o in output.outputs]  # one entry per sampled completion (n=5)
64 | i['generator'] = 'llama3-8b-instruct'
65 |     # i['generator'] = 'mistral-7b-instruct'
66 | res.append(i)
67 |
68 | with open(args.output_path, 'w', encoding='utf-8') as f:
69 | json.dump(res, f, indent=4, ensure_ascii=False)
70 |
--------------------------------------------------------------------------------
/src/infer.sh:
--------------------------------------------------------------------------------
1 | gpus=(0 1 2 3 4 5 6 7)
2 | batch=8 # num gpus
3 | num=10000 # samples
4 |
5 | len=$((num / batch + 1))
6 | echo $len
7 |
8 | l=0
9 | r=$len
10 | b=()
11 | e=()
12 | for i in `seq 1 $batch`
13 | do
14 | b+=($l)
15 | e+=($r)
16 | l=$((l+len))
17 | r=$((r+len))
18 | done
19 | echo ${b[@]}
20 | echo ${e[@]}
21 |
22 | for i in `seq 0 $((batch-1))`
23 | do
24 | (
25 | python infer.py --begin ${b[$i]} \
26 | --end ${e[$i]} \
27 | --gpu ${gpus[$i]} \
28 |                     --output_path /vllm_output_$i.json  # TODO: set the output directory
29 | echo $i
30 | )&
31 | done
32 | wait
33 | echo "all workers finished"
34 |
--------------------------------------------------------------------------------
/src/judge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('--begin', type=int)
5 | parser.add_argument('--end', type=int)
6 | parser.add_argument('--gpu', type=int)
7 | parser.add_argument('--output_path', type=str)
8 | args = parser.parse_args()
9 |
10 | import os
11 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
12 |
13 | import pandas as pd
14 | from vllm import LLM, SamplingParams
15 | import json
16 | from transformers import AutoTokenizer
17 |
18 |
19 | # TODO: change data file
20 | with open('', encoding='utf-8') as f:
21 | data = json.load(f)[args.begin: args.end]
22 |
23 |
24 | def build_judge_template(prompt, response):
25 |     return """Please act as an expert in evaluating the capabilities of instruction-following. In the instruction-following task, the Output needs to honestly/precisely/closely follow the given Prompt.
26 | Your task is to carefully judge whether the Output honestly/precisely/closely follows the given Prompt. If there are any constraints in the Prompt that are not satisfied by the Output, please list all the constraints that are not satisfied.
27 |
28 | Prompt: “{}”
29 |
30 | Output: “{}”
31 |
32 | Please carefully judge if each constraint is perfectly satisfied and give a final judgement whether the Output accurately follows the Prompt in the following format:
33 | Step-by-step verification: xxx
34 | Final Judgement (if the Output accurately follows the Prompt): (Yes or No)""".format(prompt, response)
35 |
36 |
37 | # TODO: check keys
38 | tmp = [{'messages': [{'role': 'user', 'content': build_judge_template(i['prompt'], i['response'])}]} for i in data]
39 |
40 | # TODO
41 | model_path = 'Your-PATH-Here'
42 |
43 |
44 | llm = LLM(model=model_path, trust_remote_code=True)
45 |
46 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
47 | tokenizer.padding_side = "left"
48 | prompts = []
49 |
50 | for i in tmp:
51 | try:
52 | prompts.append(tokenizer.apply_chat_template(i['messages'], add_generation_prompt=True, tokenize=True))
53 | except Exception as e:
54 | print(e)
55 | continue
56 | print("data numbers: ", len(prompts))
57 | print(tokenizer.decode(prompts[0]))
58 |
59 | # stop_token_ids = [151329, 151336, 151338]
60 | # Create a sampling params object.
61 | sampling_params = SamplingParams(
62 | temperature=0.7,
63 | top_p=1.0,
64 | max_tokens=2048,
65 | # repetition_penalty=1.05,
66 | n=5,
67 | # n=1,
68 | # stop_token_ids=stop_token_ids
69 |     stop=['<|eot_id|>']  # stop at the Llama 3 end-of-turn token
70 | )
71 |
72 |
73 | outputs = llm.generate(prompt_token_ids=prompts, sampling_params=sampling_params)
74 |
75 |
76 | res = []
77 | for output, i in zip(outputs, data):
78 |     i['critique'] = [o.text.strip() for o in output.outputs]  # one entry per sampled critique (n=5)
79 |
80 | i['generator'] = 'llama3-8b-instruct'
81 | res.append(i)
82 |
83 | with open(args.output_path, 'w', encoding='utf-8') as f:
84 | json.dump(res, f, indent=4, ensure_ascii=False)
85 |
--------------------------------------------------------------------------------
/src/judge.sh:
--------------------------------------------------------------------------------
1 | gpus=(0 1 2 3 4 5 6 7)
2 | batch=8 # num gpus
3 | num=50000 # samples
4 |
5 | len=$((num / batch + 1))
6 | echo $len
7 |
8 | l=0
9 | r=$len
10 | b=()
11 | e=()
12 | for i in `seq 1 $batch`
13 | do
14 | b+=($l)
15 | e+=($r)
16 | l=$((l+len))
17 | r=$((r+len))
18 | done
19 | echo ${b[@]}
20 | echo ${e[@]}
21 |
22 | for i in `seq 0 $((batch-1))`
23 | do
24 | (
25 | python judge.py --begin ${b[$i]} \
26 | --end ${e[$i]} \
27 | --gpu ${gpus[$i]} \
28 |                     --output_path /vllm_output_$i.json  # TODO: set the output directory
29 | echo $i
30 | )&
31 | done
32 | wait
33 | echo "all workers finished"
34 |
--------------------------------------------------------------------------------
/src/models/get_response.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 | import random
3 |
4 | # Set OpenAI's API key and API base to use vLLM's API server.
5 | # openai_api_key = "EMPTY"
6 | # openai_api_base = "http://localhost:8000/v1"
7 |
8 | # client = OpenAI(
9 | # api_key=openai_api_key,
10 | # base_url=openai_api_base,
11 | # )
12 |
13 |
14 | def unwrap_refine_cot(text):
15 | return text.split('[[start]]')[1].split('[[end]]')[0].strip()
16 |
17 |
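# Parse a critique into a binary judgement: 1 if the "Final Judgement" section
# says Yes, 0 if it says No (the earlier answer wins when both appear), and -1
# if no judgement can be parsed.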
18 | def get_acc(resp):
19 | if 'Final Judgement' not in resp:
20 | return -1
21 | resp = resp.split('Final Judgement')[1].strip()
22 | if not resp.count('Yes') and not resp.count('No'):
23 | return -1
24 | if resp.count("Yes"):
25 | if resp.find("No") != -1 and resp.find("No") < resp.find("Yes"):
26 | return 0
27 | else:
28 | return 1
29 | elif resp.count("No"):
30 | if resp.find("Yes") != -1 and resp.find("Yes") < resp.find("No"):
31 | return 1
32 | else:
33 | return 0
34 |
35 | # given prompt, generate proposal under instruction, unwrap is required
36 | def get_refine(message, method='llama', model_path='', temperature=0.7, top_p=1.0, do_sample=True, max_tokens=2048, client=None, n=1):
37 | response = []
38 | refine_cot = []
39 | cnt = 3
40 | if method == 'llama' or method == 'mistral':
41 | while not response and cnt:
42 | try:
43 | # print("cnt: ", cnt)
44 | cnt -= 1
45 | seed = random.randint(1, 100000000)
46 | # print(temperature, top_p, seed)
47 | # print(client.base_url)
48 | chat_response = client.chat.completions.create(
49 | model=model_path,
50 | messages=message,
51 | temperature=temperature,
52 | seed=seed,
53 | top_p=top_p,
54 | max_tokens=max_tokens,
55 | n=n
56 | )
57 | # print("refine get!")
58 | for i in range(n):
59 | response.append(unwrap_refine_cot(chat_response.choices[i].message.content))
60 | refine_cot.append(chat_response.choices[i].message.content)
61 | except Exception as e:
62 | print(e)
63 | if len(response) <= n//2:
64 | response = []
65 | refine_cot = []
66 | continue
67 | if n > 1:
68 | if not response:
69 |             print(f'Failed to get a <{method}> response!\n')
70 | return [], []
71 | return response, refine_cot
72 | else:
73 | if not response:
74 |             print(f'Failed to get a <{method}> response!\n')
75 | return "", ""
76 | return response[0], refine_cot[0]
77 |
78 | else:
79 |         print('This proposal method is not supported yet!\n')
80 | assert False
81 |
82 |
83 | # given prompt + answer, find its value
84 | # if you use api, unwrap is required. if you use local value model, the value is directly obtained
85 | def get_value(message, method='llama', model_path='', temperature=0.7, top_p=1.0, do_sample=True, max_tokens=2048, n=1, client=None):
86 |
87 | assert (do_sample and n > 1 and temperature > 0.0) or (not do_sample and n == 1 and temperature == 0.0)
88 |
89 | critique = []
90 | acc = []
91 | cnt = 3
92 | if method == 'llama' or method == 'mistral':
93 | while not critique and cnt:
94 | try:
95 | # print("cnt: ", cnt)
96 | cnt -= 1
97 | chat_response = client.chat.completions.create(
98 | model=model_path,
99 | messages=message,
100 | temperature=temperature,
101 | top_p=top_p,
102 | max_tokens=max_tokens,
103 | n=n
104 | )
105 | # print("value get!")
106 | for i in range(n):
107 | tmp_res = chat_response.choices[i].message.content
108 | critique.append(tmp_res)
109 | acc.append(get_acc(tmp_res))
110 | if acc.count(0) == acc.count(1):
111 | critique = []
112 | acc = []
113 | continue
114 | except Exception as e:
115 | critique = []
116 | acc = []
117 | continue
118 |
119 | if not critique:
120 |         print(f'Failed to get a <{method}> value!\n')
121 | return "", 0
122 |
123 | if acc.count(0) > acc.count(1):
124 | label = 0
125 | else:
126 | label = 1
127 | value = acc.count(1) / (acc.count(0) + acc.count(1))
128 | for i in range(n):
129 | if acc[i] == label:
130 | response = critique[i]
131 | break
132 |
133 | return response, value
134 |
135 | else:
136 |         print('This proposal method is not supported yet!\n')
137 | assert False
--------------------------------------------------------------------------------
/src/process_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import re
4 |
5 |
6 | def process_gen_res(input_path, output_path):
7 | d = []
8 | for i in range(8):
9 | with open(f'{input_path}/vllm_output_{i}.json', encoding='utf-8') as f:
10 | d.extend(json.load(f))
11 |
12 | num = 0
13 | res = []
14 | for i in d:
15 | for j in range(5):
16 | tmp = i.copy()
17 | tmp['id'] = '{}_{}'.format(num, j)
18 | tmp['response'] = i['output'][j]
19 | res.append(tmp)
20 | num += 1
21 |
22 | print(res[:6])
23 | print(len(res))
24 | with open(output_path, 'w', encoding='utf-8') as f:
25 | json.dump(res, f, indent=4, ensure_ascii=False)
26 |
27 |
28 | def get_acc(resp):
29 | if 'Final Judgement' not in resp:
30 | return -1
31 | resp = resp.split('Final Judgement')[1].strip()
32 | if not resp.count('Yes') and not resp.count('No'):
33 | return -1
34 | if resp.count("Yes"):
35 | if resp.find("No") != -1 and resp.find("No") < resp.find("Yes"):
36 | return 0
37 | else:
38 | return 1
39 | elif resp.count("No"):
40 | if resp.find("Yes") != -1 and resp.find("Yes") < resp.find("No"):
41 | return 1
42 | else:
43 | return 0
44 |
45 |
46 |
47 | def process_score_res_voting(input_path, output_path):
48 | l = []
49 | for i in range(8):
50 | with open(f'{input_path}/vllm_output_{i}.json', encoding='utf-8') as f:
51 | l.extend(json.load(f))
52 |
53 |
54 | d = []
55 | for i in l:
56 | tmp = []
57 | for j in i['critique']:
58 | tmp.append(get_acc(j))
59 |
60 | if tmp.count(0) == tmp.count(1):
61 | continue
62 | elif tmp.count(0) > tmp.count(1):
63 | i['acc'] = 0
64 | else:
65 | i['acc'] = 1
66 |
67 | for j in range(len(tmp)):
68 | if tmp[j] == i['acc']:
69 | i['critique'] = i['critique'][j]
70 | break
71 |
72 | d.append(i)
73 |
74 | res = {}
75 | for i in d:
76 | prompt_id = int(i['id'].split('_')[0])
77 | res_id = int(i['id'].split('_')[1])
78 | if prompt_id not in res:
79 | res[prompt_id] = {
80 | 'original_prompt': i['original_prompt'],
81 | 'prompt': i['prompt'],
82 | 'output': [],
83 | 'acc': [],
84 | 'critique': [],
85 | 'prompt_id': prompt_id
86 | }
87 | res[prompt_id]['output'].append(i['output'][res_id])
88 | res[prompt_id]['acc'].append(i['acc'])
89 | res[prompt_id]['critique'].append(i['critique'])
90 |
91 | rft_data = []
92 | for i in res:
93 | if min(res[i]['acc']) == 0:
94 | rft_data.append(res[i])
95 |
96 | bad_data = []
97 | for i in rft_data:
98 | tmp = []
99 | for j in range(len(i['acc'])):
100 | if i['acc'][j] == 0:
101 | tmp.append(j)
102 | random.shuffle(tmp)
103 | for j in tmp[:1]:
104 | bad_data.append({
105 | 'prompt': i['prompt'],
106 | 'response': i['output'][j],
107 | 'critique': i['critique'][j],
108 | 'acc': 0,
109 | 'prompt_id': i['prompt_id']
110 | })
111 | print(len(bad_data))
112 | with open(output_path, 'w', encoding='utf-8') as f:
113 | json.dump(bad_data, f, indent=4, ensure_ascii=False)
114 |
115 |
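# Build DPO pairs from the tree-search output (JSONL): the refined response
# (final_node with value > 0.5) becomes "chosen" and the original response
# becomes "rejected".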
116 | def process_dpo_data(input_path, output_path):
117 |
118 | with open(input_path, encoding='utf-8') as f:
119 | l = f.readlines()
120 |
121 | d = []
122 | for i in l:
123 | i = json.loads(i)
124 | if i['final_node']['value'] > 0.5:
125 | d.append(i)
126 |
127 | data = []
128 | for i in d:
129 | data.append({
130 | 'messages': [
131 | {
132 | "role": "user",
133 | "content": i['prompt']
134 | }
135 | ],
136 | "chosen": {
137 | "role": "assistant",
138 | "content": i['final_node']['response']
139 | },
140 | "rejected": {
141 | "role": "assistant",
142 | "content": i['response']
143 | }
144 | })
145 | with open(output_path, 'w', encoding='utf-8') as f:
146 | json.dump(data, f, indent=4, ensure_ascii=False)
147 |
148 | judge_template = """Please act as an expert in evaluating the capabilities of instruction-following. In the instruction-following task, the Output needs to honestly/precisely/closely follow the given Prompt.
149 | Your task is to carefully judge whether the Output honestly/precisely/closely follows the given Prompt. If there are any constraints in the Prompt that are not satisfied by the Output, please list all the constraints that are not satisfied.
150 |
151 | Prompt: “{}”
152 |
153 | Output: “{}”
154 |
155 | Please carefully judge if each constraint is perfectly satisfied and give a final judgement whether the Output accurately follows the Prompt in the following format:
156 | Step-by-step verification: xxx
157 | Final Judgement (if the Output accurately follows the Prompt): (Yes or No)"""
158 |
159 | def process_rft_data(input_tree_search_path, input_judge_path, output_path):
160 |
161 | with open(input_tree_search_path, encoding='utf-8') as f:
162 | l = f.readlines()
163 | refine_data = []
164 | for i in l:
165 | i = json.loads(i)
166 | if i['final_node']['value'] > 0.5:
167 | refine_data.append({
168 | 'messages': [
169 | {"role": "user", "content": judge_template.format(i['prompt'], i['final_node']['parent_res'])},
170 | {"role": "assistant", "content": i['final_node']['parent_critique']},
171 |                     {"role": "user", "content": """Based on your judgement, refine the Output to make sure it honestly/precisely/closely follows the given Prompt.
172 |
173 |
174 | Please carefully refine the Output to meet all the constraints in the Prompt.
175 |
176 | Please format like this:
177 | Reflection on how to refine the Output: xxx
178 | Final Refined Output: [[start]] xxx [[end]]"""},
179 | {
180 | "role": "assistant",
181 | "content": i['final_node']['refine_cot']
182 | }
183 | ]
184 | })
185 | pos = []
186 | neg = []
187 | for i in l:
188 | i = json.loads(i)
189 | if len(i['critique_rft']['positive']):
190 | pos.append({
191 | 'messages': [
192 | {"role": "user", "content": i['critique_rft']['positive'][0]['prompt']},
193 | {"role": "assistant", "content": i['critique_rft']['positive'][0]['response']}
194 | ]
195 | })
196 | if len(i['critique_rft']['negative']):
197 | neg.append({
198 | 'messages': [
199 | {"role": "user", "content": i['critique_rft']['negative'][0]['prompt']},
200 | {"role": "assistant", "content": i['critique_rft']['negative'][0]['response']}
201 | ]
202 | })
203 | l = []
204 | for i in range(8):
205 | with open(f'{input_judge_path}/vllm_output_{i}.json', encoding='utf-8') as f:
206 | l.extend(json.load(f))
207 |
208 | d_0 = []
209 | d_1 = []
210 | for i in l:
211 |
212 | tmp = []
213 | for j in i['critique']:
214 | tmp.append(get_acc(j))
215 |
216 | if tmp.count(0) == tmp.count(1):
217 | continue
218 | elif tmp.count(0) > tmp.count(1):
219 | i['acc'] = 0
220 | else:
221 | i['acc'] = 1
222 |
223 | good_res = None
224 | bad_res = None
225 | for j in range(len(tmp)):
226 | if good_res and bad_res:
227 | break
228 | if tmp[j] == i['acc']:
229 | good_res = i['critique'][j]
230 | elif tmp[j] == 1-i['acc']:
231 | bad_res = i['critique'][j]
232 | if not (good_res and bad_res):
233 | continue
234 | i['good_res'] = good_res
235 | i['bad_res'] = bad_res
236 | i['judge_prompt'] = judge_template.format(i['prompt'], i['response'])
237 | if i['acc'] == 0:
238 | d_0.append(i)
239 | else:
240 | d_1.append(i)
241 | random.shuffle(pos)
242 | random.shuffle(neg)
243 | random.shuffle(d_0)
244 | random.shuffle(d_1)
245 | random.shuffle(refine_data)
246 |
247 | sft_data = []
248 | p = []
249 | for i in refine_data[:4000]:
250 | if i['messages'][0]['content'] not in p:
251 | p.append(i['messages'][0]['content'])
252 | sft_data.append(i)
253 | for i in pos[:1500]:
254 | if i['messages'][0]['content'] not in p:
255 | p.append(i['messages'][0]['content'])
256 | sft_data.append(i)
257 | for i in neg[:1000]:
258 | if i['messages'][0]['content'] not in p:
259 | p.append(i['messages'][0]['content'])
260 | sft_data.append(i)
261 | for i in d_0[:500]:
262 | if i['judge_prompt'] not in p:
263 | p.append(i['judge_prompt'])
264 | sft_data.append({
265 | 'messages': [
266 | {"role": "user", "content": i['judge_prompt']},
267 | {"role": "assistant", "content": i['good_res']}
268 | ]
269 | })
270 | for i in d_1[:2500]:
271 | if i['judge_prompt'] not in p:
272 | p.append(i['judge_prompt'])
273 | sft_data.append({
274 | 'messages': [
275 | {"role": "user", "content": i['judge_prompt']},
276 | {"role": "assistant", "content": i['good_res']}
277 | ]
278 | })
279 |
280 | random.shuffle(sft_data)
281 | with open(output_path, 'w', encoding='utf-8') as f:
282 | json.dump(sft_data, f, indent=4, ensure_ascii=False)
283 |
284 |
285 | if __name__ == '__main__':
286 | input_path = ''
287 | output_path = ''
288 | process_gen_res(input_path, output_path)
289 |
290 |
291 | input_path = ''
292 | output_path = ''
293 | process_score_res_voting(input_path, output_path)
294 |
295 |
296 | input_path = ''
297 | output_path = ''
298 | process_dpo_data(input_path, output_path)
299 |
300 |
301 | input_tree_search_path = ''
302 | input_judge_path = ''
303 | output_path = ''
304 | process_rft_data(input_tree_search_path, input_judge_path, output_path)
--------------------------------------------------------------------------------
/src/tasks/prompts.py:
--------------------------------------------------------------------------------
1 |
2 | judge_template = """Please act as an expert in evaluating the capabilities of instruction-following. In the instruction-following task, the Output needs to honestly/precisely/closely follow the given Prompt.
3 | Your task is to carefully judge whether the Output honestly/precisely/closely follows the given Prompt. If there are any constraints in the Prompt that are not satisfied by the Output, please list all the constraints that are not satisfied.
4 |
5 | Prompt: “{}”
6 |
7 | Output: “{}”
8 |
9 | Please carefully judge if each constraint is perfectly satisfied and give a final judgement whether the Output accurately follows the Prompt in the following format:
10 | Step-by-step verification: xxx
11 | Final Judgement (if the Output accurately follows the Prompt): (Yes or No)"""
--------------------------------------------------------------------------------
/src/tasks/science.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | from tasks.prompts import *
4 |
5 |
6 | # mode: 'cot', 'tot', 'mcts'
7 | # method: 'glm', 'gpt', 'local'
8 | class SearchTask(object):
9 | def __init__(self, prompt, response, critique):
10 | super().__init__()
11 | self.prompt = prompt
12 | self.response = response
13 | self.critique = critique
14 | self.value_cache = {}
15 | self.critique_cache = {}
16 |
17 | def clear_cache(self):
18 | self.value_cache = {}
19 | self.critique_cache = {}
20 |
21 |
22 | @staticmethod
23 | def build_refine_message(prompt, response, critique):
24 | return [
25 | {"role": "user", "content": judge_template.format(prompt, response)},
26 | {"role": "assistant", "content": critique},
27 |             {"role": "user", "content": """Based on your judgement, refine the Output to make sure it honestly/precisely/closely follows the given Prompt.
28 |
29 |
30 | Please carefully refine the Output to meet all the constraints in the Prompt.
31 |
32 | Please format like this:
33 | Reflection on how to refine the Output: xxx
34 | Final Refined Output: [[start]] xxx [[end]]"""}
35 | ]
36 |
37 | @staticmethod
38 | def build_judge_message(prompt, response):
39 | return [
40 | {"role": "user", "content": judge_template.format(prompt, response)}
41 | ]
42 |
--------------------------------------------------------------------------------
/src/tree_search.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import multiprocessing
3 | from multiprocessing import Manager
4 | import json
5 | from tqdm import tqdm
6 | import os
7 | import time
8 | # import pandas as pd
9 | import random
10 | from ToT.task import *
11 |
12 |
13 |
14 |
15 | def chat_gpt(messages, counter, error_count):
16 | responses = []
17 | for i, m in enumerate(messages):
18 | try:
19 | # TODO: change model path, it should be same as the name deployed with vllm
20 | task = ToT_Task(m['prompt'], m['response'], m['critique'], root_value=0, propose_method='llama', algorithm='bfs', refine_model_path='', critique_model_path='', openai_api_base=m['openai_api_base'])
21 | solution, root, final_node = task.run()
22 | m['final_node'] = {
23 | 'response': final_node.response,
24 | 'refine_cot': final_node.refine_cot,
25 | 'parent_res': final_node.parent.response,
26 | 'parent_critique': final_node.parent.critique,
27 | 'value': final_node.V,
28 | }
29 | negative_critique, positive_critique = root.getCritiqueRFT()
30 | if len(negative_critique):
31 | negative_critique = random.sample(negative_critique, 1)
32 | if len(positive_critique):
33 | positive_critique = random.sample(positive_critique, 1)
34 | m['critique_rft'] = {
35 | 'positive': positive_critique,
36 | 'negative': negative_critique
37 | }
38 |             # append the result to the output file
39 | with open(output_file, 'a', encoding='utf-8') as f:
40 | print(json.dumps(m, ensure_ascii=False), file=f)
41 |
42 | responses.append(0)
43 |
44 | # Increment and print the counter
45 | counter.value += 1
46 | except Exception as e:
47 | error_count.value += 1
48 | print(e)
49 | print('running time:{} finished number:{} skipped number:{}'.format(time.time()-s_time, counter.value, error_count.value), end='\r')
50 |
51 | return responses
52 |
53 |
54 | def multi_process_chat_gpt(messages_list, num_processes):
55 |     # split messages_list into num_processes interleaved sublists
56 | sublists = [messages_list[i::num_processes] for i in range(num_processes)]
57 |
58 | # Create a shared counter
59 | manager = Manager()
60 | counter = manager.Value('i', 0)
61 | error_count = manager.Value('j', 0)
62 |
63 | with multiprocessing.Pool() as pool:
64 | all_responses = pool.starmap(chat_gpt, [(sublist, counter, error_count) for sublist in sublists])
65 |     # flatten all responses into a single list
66 | return [item for sublist in all_responses for item in sublist]
67 |
68 |
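# Build the worklist, skipping prompts whose prompt_id already appears in the
# output file so that interrupted runs can resume.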
69 | def get_messages_list():
70 |
71 | evaluated = []
72 | with open(output_file, encoding='utf-8') as f:
73 | lines = f.readlines()
74 | for i in lines:
75 | evaluated.append(json.loads(i)['prompt_id'])
76 |
77 | with open(input_file, encoding='utf-8') as f:
78 | d = json.load(f)
79 |
80 |
81 | messages_list = []
82 |
83 | num = 0
84 | for i in d[:]:
85 | if i['prompt_id'] in evaluated:
86 | continue
87 | i['openai_api_base'] = f"http://localhost:800{num}/v1"
88 | # num += 1
89 | # num %= 8
90 | messages_list.append(i)
91 | return messages_list
92 |
93 |
94 |
95 | if __name__ == '__main__':
96 |
97 |
98 | # TODO: change file path
99 | input_file = ''
100 | output_file = ''
101 |
102 | if not os.path.exists(output_file):
103 | x = open(output_file, 'w')
104 | x.close()
105 | messages_list = get_messages_list()
106 |
107 | print("total num: ", len(messages_list))
108 | s_time = time.time()
109 | responses = multi_process_chat_gpt(messages_list, num_processes=32)
--------------------------------------------------------------------------------