├── LICENSE
├── PDF
│   └── o1-Coder.pdf
├── README.md
├── assets
│   └── algo.jpeg
└── src
    ├── RL
    │   ├── README.md
    │   ├── rewards
    │   │   ├── __init__.py
    │   │   ├── examples.py
    │   │   ├── prm_utils.py
    │   │   ├── pyext2.py
    │   │   ├── rewards.py
    │   │   └── testing_util.py
    │   └── rewards_test.py
    ├── TestCaseGenerate
    │   ├── README.md
    │   ├── TCG.py
    │   └── wash_code.py
    ├── mcts
    │   ├── README.md
    │   ├── common
    │   │   ├── arguments.py
    │   │   └── utils.py
    │   ├── data
    │   │   └── TACO
    │   │       ├── self_create_data.json
    │   │       ├── test_one.json
    │   │       └── train_easy_10.json
    │   ├── eval_src
    │   │   ├── Evaluator.py
    │   │   ├── checker_utils.py
    │   │   ├── pyext2.py
    │   │   └── testing_util.py
    │   ├── models
    │   │   ├── HuggingFace_API.py
    │   │   ├── IO_System.py
    │   │   ├── OpenAI_API.py
    │   │   └── vLLM_API.py
    │   ├── prompts
    │   │   └── TACO
    │   │       ├── examples.txt
    │   │       └── prompt.json
    │   ├── run_outputs
    │   │   └── data_examples
    │   │       ├── answer_sheets
    │   │       │   ├── Question 0000 - Answer.json
    │   │       │   ├── Question 0000 - Best Solutions.json
    │   │       │   ├── Question 0000 - Complete Solutions.json
    │   │       │   └── Question 0000 - Rollout Solutions.json
    │   │       ├── args.json
    │   │       └── intermediate_result.txt
    │   ├── run_src
    │   │   ├── MCTS_backbone.py
    │   │   ├── MCTS_for_reasoning.py
    │   │   ├── do_generate.py
    │   │   └── rstar_utils.py
    │   └── scripts
    │       ├── api_run_TACO.sh
    │       ├── run_TACO.sh
    │       └── start.md
    └── prm_training
        ├── README.md
        ├── data
        │   └── examples
        │       ├── hard_label_examples.json
        │       └── soft_label_examples.json
        ├── requirements.txt
        ├── run.py
        ├── run.sh
        └── train_prm
            ├── __init__.py
            ├── run_train.py
            └── utils
                ├── accelerator_utils.py
                └── dist_configs
                    ├── ds_zero2_config.json
                    ├── multi_gpu.yaml
                    └── single_gpu.yaml
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 ADaM-LAB.
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/PDF/o1-Coder.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ADaM-BJTU/O1-CODER/f01c769397afaefc89c22c51d048484a79d11c1f/PDF/o1-Coder.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # O1-CODER
2 | [O1-CODER: An O1 Replication for Coding (Paper)](https://arxiv.org/abs/2412.00154)
3 | 
4 | ## Overview
5 | 
6 | **O1-CODER** is an attempt to replicate OpenAI's **O1 model**, focused on coding tasks. The approach combines **Reinforcement Learning (RL)** and **Monte Carlo Tree Search (MCTS)** to enhance the model's **System-2** thinking capabilities, aiming to generate more efficient and logical code.
7 | 
8 | ### Method
9 | 
10 | The core components of **O1-CODER** are:
11 | 
12 | 1. **Test Case Generator (TCG)**: Automatically generates standardized test cases to evaluate the correctness of the generated code.
13 | 2. **Self-Play and Reinforcement Learning**: The model generates reasoning data through self-play, and uses RL and MCTS to iteratively optimize the policy model.
14 | These methods work in an iterative cycle, continuously refining the model to improve systematic reasoning and optimization in coding tasks.
15 | 
16 | 
[Code End]." Here, should contain a working Python code based on the final pseudocode, and it must be enclosed within Python code block syntax.
21 |
22 | ## Notes
23 | 1. Aim to break down the solution into as many detailed, intermediate steps as possible while ensuring logical coherence between steps and avoiding unnecessary redundancy.
24 | 2. The Python code solution should match the input and output requirements as described in the question. This means the solution may use terminal I/O for inputs and outputs, or it may require function parameters and return values. Carefully review the question's description to determine the expected code structure, and ensure there are no input/output format errors.
25 | 3. Gradually refine each functional part of the pseudocode, breaking down complex operations into manageable steps.
26 | 4. Transition to Python code only once all parts of the pseudocode have been fully refined.
27 | 5. Do not generate content unrelated to the answer or any other explanations.
28 |
29 | Now, with the problem description provided below, you need to provide or complete a full, step-by-step solution according to the previous explanations. **If the 'Solution' section is empty, please directly provide a complete, step-by-step solution. If it is not empty, do not repeat or rephrase existing content; simply continue from where it left off to complete the solution.**
30 | ### Description
31 | {question}
32 |
33 | ### Solution
34 | """
35 |
36 |
37 | @dataclass
38 | class StepTokensForLM:
39 | step_tag: str = field(
40 | default=' Rating',
41 | metadata={'help': 'The tag that indicates the end of a step/action'}
42 | )
43 | good_token: str = field(
44 | default=' +',
45 | metadata={'help': 'The token that indicates a positive action'}
46 | )
47 | bad_token: str = field(
48 | default=' -',
49 | metadata={'help': 'The token that indicates a negative action'}
50 | )
51 |
52 | 
53 | class RewardStrategy(Enum):
54 | """
55 | Enum of the supported reward strategies (token logits vs. value head).
56 | """
57 | TOKEN_LOGITS = 'token_logits'
58 | VALUE_HEAD = 'value_head'
59 |
60 |
61 | PromptType: TypeAlias = Union[str, List[str]]
62 | PrefixesType: TypeAlias = List[PromptType]
63 |
64 | @torch.no_grad()
65 | def get_process_rewards(model: AutoModelForCausalLM,
66 | tokenizer: AutoTokenizer,
67 | prompts: PromptType,
68 | completed_processes: PrefixesType,
69 | tokenized_format: Optional[Literal['completion', 'chat_completion']],
70 | reward_strategy: Optional[str] = RewardStrategy.TOKEN_LOGITS.value,
71 | ) -> List[float]:
72 |
73 | if isinstance(prompts, str):
74 | if isinstance(completed_processes[0], list):
75 | raise ValueError("The `completed_prefixes` argument must be a list of strings if `prompts` is a string.")
76 | prompts = [prompts]
77 | completed_processes = [completed_processes]
78 | if isinstance(prompts, list) and isinstance(completed_processes[0], str):
79 | raise ValueError("The `completed_prefixes` argument must be a list of lists if `prompts` is a list.")
80 | if len(prompts) != len(completed_processes):
81 | raise ValueError("The number of prompts must match the number of completed prefixes in order.")
82 |
83 | if reward_strategy == RewardStrategy.TOKEN_LOGITS.value:
84 | tokenized_ids = tokenizer(
85 | [StepTokensForLM.step_tag, StepTokensForLM.good_token, StepTokensForLM.bad_token]
86 | )["input_ids"]
87 | if any(len(tokenized_id) != 1 for tokenized_id in tokenized_ids):
88 | raise ValueError("The tokens `step_tag`, `good_token`, `bad_token` used in class `StepTokensForLM` must be single tokens.")
89 | step_token_id, good_token_id, bad_token_id = [ids[0] for ids in tokenized_ids]
90 |
91 | if tokenized_format == 'completion':
92 | raise NotImplementedError("Token Logits reward strategy is not implemented for completion tokenized format.")
93 | elif tokenized_format == 'chat_completion':
94 | input_texts = []
95 | for prompt, processes in zip(prompts, completed_processes):
96 | step_tag_inserted_process = ''
97 | for process in processes:
98 | step_tag_inserted_process += process + StepTokensForLM.step_tag
99 |
100 | messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": step_tag_inserted_process}]
101 | chat_template = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
102 |
103 | end_token_of_one_turn = '<|im_end|>\n'
104 | chat_template = chat_template.removesuffix(end_token_of_one_turn)
105 | input_texts.append(chat_template)
106 |
107 | model_inputs = tokenizer(input_texts, return_tensors='pt', padding=True).to(model.device)
108 | step_tag_positions = (model_inputs["input_ids"] == step_token_id).nonzero()
109 | model_inputs["attention_mask"][step_tag_positions[:, 0], step_tag_positions[:, 1]] = 0
110 |
111 | outputs = model(**model_inputs)
112 | logits = outputs.logits
113 | logits = logits[..., [good_token_id, bad_token_id]]
114 | scores = logits[step_tag_positions[:, 0], step_tag_positions[:, 1] - 1, :].softmax(dim=-1) # [num_step_tags, 2] probabilities over (' +', ' -')
115 |
116 | return scores.cpu().tolist()
117 | else:
118 | raise ValueError("The `tokenized_format` argument must be either 'completion' or 'chat_completion'.")
119 |
120 |
121 | if __name__ == '__main__':
122 | model_path = '/root/autodl-tmp/models/Qwen2.5-Coder-7B-Instruct-PRM'
123 | model = AutoModelForCausalLM.from_pretrained(
124 | model_path,
125 | torch_dtype=torch.bfloat16,
126 | attn_implementation="flash_attention_2"
127 | ).to('cuda').eval()
128 | tokenizer = AutoTokenizer.from_pretrained(model_path)
129 |
130 | problem = 'Given some positive integers, I wish to print the integers such that all take up the same width by adding a minimum number of leading zeroes. No leading zeroes shall be added to the largest integer.\n\nFor example, given `1, 23, 2, 17, 102`, I wish to print out these numbers as follows:\n\n```python\n001\n023\n002\n017\n102\n```\n\nWrite a function `print_nums(n1, n2, n3, ...)` that takes a variable number of arguments and returns the string to be printed out.'
131 | prompt = [
132 | CODEPRM_PROMPT.format(question=problem),
133 | problem
134 | ]
135 | examples = [
136 | ["Step 1: Defining Function Structures Using pseudocode\nWe start by defining the structure of our solution. We need a function `reorder` that takes two integers `N` and `M`, and returns a numpy array with two sub-arrays. Each sub-array will contain numbers in the specified ranges and will be rotated `M` times.\n\n[Pseudo Start]\n```\nFunction reorder(N, M):\n Calculate half of N\n Create the first sub-array with numbers in the range [0, N/2)\n Create the second sub-array with numbers in the range [N/2, N)\n Rotate the first sub-array M times\n Rotate the second sub-array M times\n Combine the two sub-arrays into a numpy array\n Return the combined numpy array\n[Pseudo End]\n```"],
137 | ["Step 1: Defining Function Structures Using pseudocode\nWe start by defining the structure of our solution. We need a function `reorder` that takes two integers `N` and `M`, and returns a numpy array with two sub-arrays. Each sub-array will contain numbers in the specified ranges and will be rotated `M` times.\n\n[Pseudo Start]\n```\nFunction reorder(N, M):\n Calculate half of N\n Create the first sub-array with numbers in the range [0, N/2)\n Create the second sub-array with numbers in the range [N/2, N)\n Rotate the first sub-array M times\n Rotate the second sub-array M times\n Combine the two sub-arrays into a numpy array\n Return the combined numpy array\n[Pseudo End]\n```"]
138 | ]
139 |
140 | get_process_rewards(model, tokenizer, prompt, examples, 'chat_completion', 'token_logits')
--------------------------------------------------------------------------------
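A minimal usage sketch for `get_process_rewards` from `prm_utils.py` above; the model, tokenizer, problem text, and step strings are placeholders, and the returned per-step probability of `' +'` is what `rewards.py` below consumes as the process reward:

```python
# Sketch only: `model`, `tokenizer`, the problem text, and the step strings are
# placeholders. Each row of `scores` is [p(' +'), p(' -')] for one reasoning step.
scores = get_process_rewards(
    model,
    tokenizer,
    prompts=CODEPRM_PROMPT.format(question="<problem statement>"),
    completed_processes=["Step 1: ...", "Step 2: ..."],
    tokenized_format='chat_completion',
    reward_strategy=RewardStrategy.TOKEN_LOGITS.value,
)
step_rewards = [p_good for p_good, p_bad in scores]  # probability of the good token per step
```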
/src/RL/rewards/pyext2.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2014 Ryan Gonzalez
3 |
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to use,
8 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
9 | Software, and to permit persons to whom the Software is furnished to do so,
10 | subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 | '''
22 |
23 | g_backup = globals().copy()
24 |
25 | __version__ = '0.7'
26 |
27 | __all__ = ['overload', 'RuntimeModule', 'switch', 'tail_recurse', 'copyfunc', 'set_docstring', 'annotate', 'safe_unpack', 'modify_function', 'assign', 'fannotate', 'compare_and_swap', 'is_main', 'call_if_main', 'run_main']
28 |
29 | import sys, inspect, types
30 |
31 | def __targspec(func, specs, attr='__orig_arg__'):
32 | if hasattr(func, '__is_overload__') and func.__is_overload__:
33 | return getattr(func, attr)
34 | return specs(func)
35 |
36 | def set_docstring(doc):
37 | '''A simple decorator to set docstrings.
38 |
39 | :param doc: The docstring to tie to the function.
40 |
41 | Example::
42 |
43 | @set_docstring('This is a docstring')
44 | def myfunc(x):
45 | pass'''
46 | def _wrap(f):
47 | f.__doc__ = doc
48 | return f
49 | return _wrap
50 |
51 | __modify_function_doc = '''
52 | Creates a copy of a function, changing its attributes.
53 |
54 | :param globals: Will be added to the function's globals.
55 |
56 | :param name: The new function name. Set to ``None`` to use the function's original name.
57 |
58 | :param code: The new function code object. Set to ``None`` to use the function's original code object.
59 |
60 | :param defaults: The new function defaults. Set to ``None`` to use the function's original defaults.
61 |
62 | :param closure: The new function closure. Set to ``None`` to use the function's original closure.
63 |
64 | .. warning:: This function can be potentially dangerous.
65 | '''
66 |
67 | def copyfunc(f):
68 | '''Copies a funcion.
69 |
70 | :param f: The function to copy.
71 |
72 | :return: The copied function.
73 |
74 | .. deprecated:: 0.4
75 | Use :func:`modify_function` instead.
76 | '''
77 | return modify_function(f)
78 |
79 | if sys.version_info.major == 3:
80 | @set_docstring(__modify_function_doc)
81 | def modify_function(f, globals={}, name=None, code=None, defaults=None,
82 | closure=None):
83 | if code is None: code = f.__code__
84 | if name is None: name = f.__name__
85 | if defaults is None: defaults = f.__defaults__
86 | if closure is None: closure = f.__closure__
87 | newf = types.FunctionType(code, dict(f.__globals__, **globals), name=name,
88 | argdefs=defaults, closure=closure)
89 | newf.__dict__.update(f.__dict__)
90 | return newf
91 | def argspec(f):
92 | return inspect.getfullargspec(f)
93 | ofullargspec = inspect.getfullargspec
94 | def _fullargspec(func):
95 | return __targspec(func, ofullargspec)
96 | inspect.getfullargspec = _fullargspec
97 | def _exec(m,g): exec(m,g)
98 | else:
99 | @set_docstring(__modify_function_doc)
100 | def modify_function(f, globals={}, name=None, code=None, defaults=None,
101 | closure=None):
102 | if code is None: code = f.func_code
103 | if name is None: name = f.__name__
104 | if defaults is None: defaults = f.func_defaults
105 | if closure is None: closure = f.func_closure
106 | newf = types.FunctionType(code, dict(f.func_globals, **globals), name=name,
107 | argdefs=defaults, closure=closure)
108 | newf.__dict__.update(f.__dict__)
109 | return newf
110 | def argspec(f):
111 | return inspect.getargspec(f)
112 | eval(compile('def _exec(m,g): exec m in g', '', 'exec'))
113 |
114 | def _gettypes(args):
115 | return tuple(map(type, args))
116 |
117 | oargspec = inspect.getargs
118 |
119 | def _argspec(func):
120 | return __targspec(func, oargspec)
121 |
122 | inspect.getargspec = _argspec
123 |
124 | try:
125 | import IPython
126 | except ImportError:
127 | IPython = None
128 | else:
129 | # Replace IPython's argspec
130 | oipyargspec = IPython.core.oinspect.getargspec
131 | def _ipyargspec(func):
132 | return __targspec(func, oipyargspec, '__orig_arg_ipy__')
133 | IPython.core.oinspect.getargspec = _ipyargspec
134 |
135 | class overload(object):
136 | '''Simple function overloading in Python.'''
137 | _items = {}
138 | _types = {}
139 | @classmethod
140 | def argc(self, argc=None):
141 | '''Overloads a function based on the specified argument count.
142 |
143 | :param argc: The argument count. Defaults to ``None``. If ``None`` is given, automatically compute the argument count from the given function.
144 |
145 | .. note::
146 |
147 | Keyword argument counts are NOT checked! In addition, when the argument count is automatically calculated, the keyword argument count is also ignored!
148 |
149 | Example::
150 |
151 | @overload.argc()
152 | def func(a):
153 | print 'Function 1 called'
154 |
155 | @overload.argc()
156 | def func(a, b):
157 | print 'Function 2 called'
158 |
159 | func(1) # Calls first function
160 | func(1, 2) # Calls second function
161 | func() # Raises error
162 | '''
163 | # Python 2 UnboundLocalError fix
164 | argc = {'argc': argc}
165 | def _wrap(f):
166 | def _newf(*args, **kwargs):
167 | if len(args) not in self._items[f.__name__]:
168 | raise TypeError("No overload of function '%s' that takes %d args" % (f.__name__, len(args)))
169 | return self._items[f.__name__][len(args)](*args, **kwargs)
170 | if f.__name__ not in self._items:
171 | self._items[f.__name__] = {}
172 | if argc['argc'] is None:
173 | argc['argc'] = len(argspec(f).args)
174 | self._items[f.__name__][argc['argc']] = f
175 | _newf.__name__ = f.__name__
176 | _newf.__doc__ = f.__doc__
177 | _newf.__is_overload__ = True
178 | _newf.__orig_arg__ = argspec(f)
179 | if IPython:
180 | _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f)
181 | return _newf
182 | return _wrap
183 | @classmethod
184 | def args(self, *argtypes, **kw):
185 | '''Overload a function based on the specified argument types.
186 |
187 | :param argtypes: The argument types. If None is given, get the argument types from the function annotations(Python 3 only)
188 | :param kw: Can only contain 1 argument, `is_cls`. If True, the function is assumed to be part of a class.
189 |
190 | Example::
191 |
192 | @overload.args(str)
193 | def func(s):
194 | print 'Got string'
195 |
196 | @overload.args(int, str)
197 | def func(i, s):
198 | print 'Got int and string'
199 |
200 | @overload.args()
201 | def func(i:int): # A function annotation example
202 | print 'Got int'
203 |
204 | func('s')
205 | func(1)
206 | func(1, 's')
207 | func(True) # Raises error
208 | '''
209 |
210 | # Python 2 UnboundLocalError fix...again!
211 | argtypes = {'args': tuple(argtypes)}
212 | def _wrap(f):
213 | def _newf(*args):
214 | if len(kw) == 0:
215 | cargs = args
216 | elif len(kw) == 1 and 'is_cls' in kw and kw['is_cls']:
217 | cargs = args[1:]
218 | else:
219 | raise ValueError('Invalid keyword args specified')
220 | if _gettypes(cargs) not in self._types[f.__name__]:
221 | raise TypeError("No overload of function '%s' that takes '%s' types and %d arg(s)" % (f.__name__, _gettypes(cargs), len(cargs)))
222 | return self._types[f.__name__][_gettypes(cargs)](*args)
223 | if f.__name__ not in self._types:
224 | self._types[f.__name__] = {}
225 | if len(argtypes['args']) == 1 and argtypes['args'][0] is None:
226 | aspec = argspec(f)
227 | argtypes['args'] = tuple(map(lambda x: x[1], sorted(
228 | aspec.annotations.items(), key=lambda x: aspec.args.index(x[0]))))
229 | self._types[f.__name__][argtypes['args']] = f
230 | _newf.__name__ = f.__name__
231 | _newf.__doc__ = f.__doc__
232 | _newf.__is_overload__ = True
233 | _newf.__orig_arg__ = argspec(f)
234 | if IPython:
235 | _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f)
236 | return _newf
237 | return _wrap
238 |
239 | class _RuntimeModule(object):
240 | 'Create a module object at runtime and insert it into sys.path. If called, same as :py:func:`from_objects`.'
241 | def __call__(self, *args, **kwargs):
242 | return self.from_objects(*args, **kwargs)
243 | @staticmethod
244 | @overload.argc(1)
245 | def from_objects(module_name_for_code_eval, **d):
246 | return _RuntimeModule.from_objects(module_name_for_code_eval, '', **d)
247 | @staticmethod
248 | @overload.argc(2)
249 | def from_objects(module_name_for_code_eval, docstring, **d):
250 | '''Create a module at runtime from `d`.
251 |
252 | :param name: The module name.
253 |
254 | :param docstring: Optional. The module's docstring.
255 |
256 | :param \*\*d: All the keyword args, mapped from name->value.
257 |
258 | Example: ``RuntimeModule.from_objects('name', 'doc', a=1, b=2)``'''
259 | module = types.ModuleType(module_name_for_code_eval, docstring)
260 | module.__dict__.update(d)
261 | module.__file__ = ''
262 | sys.modules[module_name_for_code_eval] = module
263 | return module
264 | @staticmethod
265 | @overload.argc(2)
266 | def from_string(module_name_for_code_eval, s):
267 | return _RuntimeModule.from_string(module_name_for_code_eval, '', s)
268 | @staticmethod
269 | @overload.argc(3)
270 | def from_string(module_name_for_code_eval, docstring, s):
271 | '''Create a module at runtime from `s``.
272 |
273 | :param name: The module name.
274 |
275 | :param docstring: Optional. The module docstring.
276 |
277 | :param s: A string containing the module definition.'''
278 | g = {}
279 | _exec(s, g)
280 | return _RuntimeModule.from_objects(module_name_for_code_eval, docstring, **dict(filter(lambda x: x[0] not in g_backup, g.items())))
281 |
282 | RuntimeModule = _RuntimeModule()
283 |
284 | class CaseObject(object):
285 | 'The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. It can be called with multiple parameters, in which case it checks if its value equals any of the arguments.'
286 | def __init__(self, value):
287 | self.value = value
288 | self.did_match = False
289 | self.did_pass = False
290 | def __call__(self, *args):
291 | if assign('res', not self.did_pass and any([self.value == rhs for rhs in args])):
292 | self.did_match = True
293 | return res
294 | def quit(self):
295 | 'Forces all other calls to return False. Equilavent of a ``break`` statement.'
296 | self.did_pass = True
297 | def default(self):
298 | "Executed if quit wasn't called."
299 | return not self.did_match and not self.did_pass
300 | def __iter__(self):
301 | yield self
302 | def __enter__(self):
303 | return self
304 | def __exit__(self, *args):
305 | pass
306 |
307 | def switch(value):
308 | '''A Python switch statement implementation that is used with a ``with`` statement.
309 |
310 | :param value: The value to "switch".
311 |
312 | ``with`` statement example::
313 |
314 | with switch('x'):
315 | if case(1): print 'Huh?'
316 | if case('x'): print 'It works!!!'
317 |
318 | .. warning:: If you modify a variable named "case" in the same scope that you use the ``with`` statement version, you will get an UnboundLocalError. The soluction is to use ``with switch('x') as case:`` instead of ``with switch('x'):``.'''
319 | res = CaseObject(value)
320 | inspect.stack()[1][0].f_globals['case'] = res
321 | return res
322 |
323 | def tail_recurse(spec=None):
324 | '''Remove tail recursion from a function.
325 |
326 | :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. If ``None,`` tail recursion is always called unless the function returns a value.
327 |
328 | .. note::
329 |
330 | This function has a slight overhead that is noticable when using timeit. Only use it if the function has a possibility of going over the recursion limit.
331 |
332 | .. warning::
333 |
334 | This function will BREAK any code that either uses any recursion other than tail recursion or calls itself multiple times. For example, ``def x(): return x()+1`` will fail.
335 |
336 | Example::
337 |
338 | @tail_recurse()
339 | def add(a, b):
340 | if a == 0: return b
341 | return add(a-1, b+1)
342 |
343 | add(10000000, 1) # Doesn't max the recursion limit.
344 | '''
345 | def _wrap(f):
346 | class TailRecursion(Exception):
347 | def __init__(self, args, kwargs):
348 | self.args = args
349 | self.kwargs = kwargs
350 | def _newf(*args, **kwargs):
351 | if inspect.stack()[1][3] == f.__name__:
352 | if (spec and spec(args)) or not spec:
353 | raise TailRecursion(args, kwargs)
354 | while True:
355 | try:
356 | res = f(*args, **kwargs)
357 | except TailRecursion as ex:
358 | args = ex.args
359 | kwargs = ex.kwargs
360 | continue
361 | else:
362 | return res
363 | _newf.__doc__ = f.__doc__
364 | return _newf
365 | return _wrap
366 |
367 | def annotate(*args, **kwargs):
368 | '''Set function annotations using decorators.
369 |
370 | :param args: This is a list of annotations for the function, in the order of the function's parameters. For example, ``annotate('Annotation 1', 'Annotation 2')`` will set the annotations of parameter 1 of the function to ``Annotation 1``.
371 |
372 | :param kwargs: This is a mapping of argument names to annotations. Note that these are applied *after* the argument list, so any args set that way will be overriden by this mapping. If there is a key named `ret`, that will be the annotation for the function's return value.
373 |
374 | .. deprecated:: 0.5
375 | Use :func:`fannotate` instead.
376 | '''
377 | def _wrap(f):
378 | if not hasattr(f, '__annotations__'):
379 | f.__annotations__ = {}
380 | if 'ret' in kwargs:
381 | f.__annotations__['return'] = kwargs.pop('ret')
382 | f.__annotations__.update(dict(zip(argspec(f).args, args)))
383 | f.__annotations__.update(kwargs)
384 | return f
385 | return _wrap
386 |
387 | def fannotate(*args, **kwargs):
388 | '''Set function annotations using decorators.
389 |
390 | :param \*args: The first positional argument is used for the function's return value; all others are discarded.
391 |
392 | :param \**kwargs: This is a mapping of argument names to annotations.
393 |
394 | Example::
395 |
396 | @fannotate('This for the return value', a='Parameter a', b='Parameter b')
397 | def x(a, b):
398 | pass
399 |
400 | '''
401 | def _wrap(f):
402 | if not hasattr(f, '__annotations__'):
403 | f.__annotations__ = {}
404 | if len(args) >= 1:
405 | f.__annotations__['return'] = args[0]
406 | f.__annotations__.update(kwargs)
407 | return f
408 | return _wrap
409 |
410 | def safe_unpack(seq, ln, fill=None):
411 | '''Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. Empty values will be filled in with `fill`, while any extra values will be cut off.
412 |
413 | :param seq: The sequence to unpack.
414 |
415 | :param ln: The expected length of the sequence.
416 |
417 | :param fill: The value to substitute if the sequence is too small. Defaults to ``None``.
418 |
419 | Example::
420 |
421 | s = 'a:b'
422 | a, b = safe_unpack(s.split(':'), 2)
423 | # a = 'a'
424 | # b = 'b'
425 | s = 'a'
426 | a, b = safe_unpack(s.split(':'), 2)
427 | # a = 'a'
428 | # b = None'''
429 | if len(seq) > ln:
430 | return seq[:ln]
431 | elif len(seq) < ln:
432 | return seq + type(seq)([fill]*(ln-len(seq)))
433 | else:
434 | return seq
435 |
436 | def assign(varname, value):
437 | '''Assign `value` to `varname` and return it. If `varname` is an attribute and the instance name it belongs to is not defined, a NameError is raised.
438 | This can be used to emulate assignment as an expression. For example, this::
439 |
440 | if assign('x', 7): ...
441 |
442 | is equilavent to this C code::
443 |
444 | if (x = 7) ...
445 |
446 | .. warning::
447 |
448 | When assigning an attribute, the instance it belongs to MUST be declared as global prior to the assignment. Otherwise, the assignment will not work.
449 | '''
450 | fd = inspect.stack()[1][0].f_globals
451 | if '.' not in varname:
452 | fd[varname] = value
453 | else:
454 | vsplit = list(map(str.strip, varname.split('.')))
455 | if vsplit[0] not in fd:
456 | raise NameError('Unknown object: %s'%vsplit[0])
457 | base = fd[vsplit[0]]
458 | for x in vsplit[1:-1]:
459 | base = getattr(base, x)
460 | setattr(base, vsplit[-1], value)
461 | return value
462 |
463 | def is_main(frame=1):
464 | "Return if the caller is main. Equilavent to ``__name__ == '__main__'``."
465 | return inspect.stack()[frame][0].f_globals['__name__'] == '__main__'
466 |
467 | def _call_if_main(frame, f, args):
468 | if is_main(frame): return f(*args)
469 |
470 | def call_if_main(f,*args):
471 | "Call the `f` with `args` if the caller's module is main."
472 | return _call_if_main(3,f,args)
473 |
474 | def run_main(f,*args):
475 | "Call `f` with the `args` and terminate the program with its return code if the caller's module is main."
476 | sys.exit(_call_if_main(3,f,args))
477 |
478 | def compare_and_swap(var, compare, new):
479 | "If `var` is equal to `compare`, set it to `new`."
480 | if assign('v', inspect.stack()[1][0].f_globals)[var] == compare:
481 | v[var] = new
482 |
--------------------------------------------------------------------------------
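The docstring examples in `pyext2.py` use Python 2 `print` statements; below is a small Python 3 sketch of two of its helpers, assuming the module is importable as `pyext2`:

```python
# Python 3 sketch; `pyext2` is assumed to be on the import path.
from pyext2 import overload, safe_unpack

@overload.argc()
def greet(name):
    return f"hi {name}"

@overload.argc()
def greet(first, last):
    return f"hi {first} {last}"

print(greet("Ada"))              # dispatches on argument count -> "hi Ada"
print(greet("Ada", "Lovelace"))  # -> "hi Ada Lovelace"

a, b = safe_unpack("a:b".split(":"), 2)  # a = 'a', b = 'b'
c, d = safe_unpack("a".split(":"), 2)    # c = 'a', d = None (missing value filled in)
```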
/src/RL/rewards/rewards.py:
--------------------------------------------------------------------------------
1 | from .testing_util import run_test
2 | import torch
3 | import re
4 | import numpy as np
5 | from transformers import Qwen2ForSequenceClassification, AutoModelForCausalLM
6 | from transformers import AutoTokenizer
7 |
8 | from .prm_utils import get_process_rewards, PromptType, PrefixesType
9 |
10 |
11 | def phi(R_i, r_i, t, alpha_t_func, gamma, m=None):
12 | r"""
13 | Calculate the aggregated reward function \(\phi(R_i, r_i^{1:m})\).
14 |
15 | Parameters:
16 | - R_i: The final reward (scalar).
17 | - r_i: A sequence of intermediate rewards (array of length m).
18 | - t: The current timestep (scalar).
19 | - alpha_t_func: A function for the time-varying factor \(\alpha(t)\), which takes the time step t as input.
20 | - gamma: The discount factor (scalar, in the range [0, 1]).
21 | - m: The number of intermediate rewards (scalar).
22 |
23 | Returns:
24 | - Aggregated reward (scalar).
25 | """
26 | # Calculate the weighted sum of intermediate rewards, considering the discount factor
27 | if m is None:
28 | m = len(r_i)
29 | else:
30 | assert len(r_i) == m, "Number of intermediate rewards must match the given value of m."
31 | weighted_intermediate_rewards = np.sum([gamma**j * r_i[j] for j in range(m)])
32 |
33 | # Get the time-varying factor alpha(t)
34 | alpha_t = alpha_t_func(t)
35 |
36 | # Compute the aggregated reward using the given formula
37 | aggregated_reward = alpha_t * R_i + (1 - alpha_t) * (weighted_intermediate_rewards / m)
38 |
39 | return aggregated_reward
40 |
41 | # Example: Define a linear decay function for alpha(t)
42 | def linear_alpha(t, alpha_max=1.0, alpha_min=0.1, decay_rate=0.01):
43 | r"""
44 | A simple linear decay function for the time-varying factor \(\alpha(t)\).
45 |
46 | Parameters:
47 | - t: The current timestep (scalar).
48 | - alpha_max: The maximum value for \(\alpha(t)\) at t = 0 (default 1.0).
49 | - alpha_min: The minimum value for \(\alpha(t)\) as t increases (default 0.1).
50 | - decay_rate: The rate of decay (default 0.01).
51 |
52 | Returns:
53 | - The time-varying factor \(\alpha(t)\) (scalar).
54 | """
55 | return max(alpha_min, alpha_max - decay_rate * t)
56 |
57 |
58 |
59 | class RewardAggregater():
60 | def __init__(self,model,tokenizer, phi_func=phi, alpha_func=linear_alpha, gamma=0.9, device='cuda'):
61 |
62 | self.device = device
63 | self.phi_func = phi_func
64 | self.alpha_func = alpha_func
65 | self.gamma = gamma
66 |
67 | self.model = model
68 | self.tokenizer = tokenizer
69 |
70 | self.device = torch.device(self.device)
71 | self.model.eval() # Set model to evaluation mode
72 |
73 | self.reward = 0 # Initialize reward
74 |
75 |
76 | def compute_intermediate_rewards(self, prompts: PromptType, intermediate_texts: PrefixesType):
77 | process_rewards = get_process_rewards(
78 | self.model,
79 | self.tokenizer,
80 | prompts=prompts,
81 | completed_processes=intermediate_texts,
82 | tokenized_format='chat_completion',
83 | )
84 | # returns (probability of good, probability of bad)
85 | process_rewards = [good_and_bad_probs[0] for good_and_bad_probs in process_rewards]
86 | return process_rewards
87 |
88 |
89 | def update_reward(self, prompt, intermediate_texts, in_outs, current_timestep, outcome_reward=None):
90 |
91 | # Compute intermediate rewards
92 | intermediate_rewards = self.compute_intermediate_rewards(prompt, intermediate_texts)
93 |
94 | # Number of intermediate rewards
95 | m = len(intermediate_rewards)
96 | if outcome_reward is None:
97 | outcome_reward = self.comupte_outcome_reward(intermediate_texts[-1], in_outs)
98 | # Calculate the aggregated reward
99 | self.reward = self.phi_func(
100 | R_i=outcome_reward,
101 | r_i=intermediate_rewards,
102 | t=current_timestep,
103 | alpha_t_func=self.alpha_func,
104 | gamma=self.gamma,
105 | m=m
106 | )
107 |
108 | return self.reward
109 |
110 | def exctract_runable_code(self, text):
111 | pattern = r"```python(.*?)```"
112 |
113 | try:
114 | code_snippets = re.findall(pattern, text, re.DOTALL)[-1]
115 | except:
116 | code_snippets = None
117 |
118 | return code_snippets
119 |
120 |
121 | def comupte_outcome_reward(self, final_step, in_outs):
122 | code = self.exctract_runable_code(final_step)
123 | if code is None:
124 | return 0
125 | result = run_test(code, in_outs, debug=True)
126 |
127 | score = [1 for r in result if r == True]
128 | score = sum(score)/len(result)
129 | return score
130 |
131 | def reset_reward(self):
132 | """
133 | Reset the aggregated reward to zero.
134 | """
135 | self.reward = 0
136 |
137 | if __name__ == "__main__":
138 | ...
--------------------------------------------------------------------------------
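A small worked example of the reward aggregation performed by `phi` with `linear_alpha` (illustrative numbers only, assuming both are imported from `rewards/rewards.py` above):

```python
# Illustrative numbers only. With alpha(t) = max(0.1, 1.0 - 0.01 * t) and gamma = 0.9:
#   t = 50         -> alpha = 0.5
#   discounted sum =  0.8 + 0.9 * 0.6 + 0.81 * 0.9 = 2.069
#   phi            =  0.5 * 1.0 + 0.5 * (2.069 / 3) ≈ 0.845
reward = phi(R_i=1.0, r_i=[0.8, 0.6, 0.9], t=50, alpha_t_func=linear_alpha, gamma=0.9)
print(round(reward, 3))  # 0.845
```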
/src/RL/rewards_test.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | from rewards.rewards import RewardAggregater
5 | from rewards.prm_utils import CODEPRM_PROMPT
6 | from rewards.examples import rewards_test_examples
7 |
8 |
9 | def test():
10 | model = AutoModelForCausalLM.from_pretrained(
11 | 'path/to/PRM',
12 | torch_dtype=torch.bfloat16,
13 | attn_implementation="flash_attention_2"
14 | ).to('cuda')
15 | tokenizer = AutoTokenizer.from_pretrained('path/to/PRM')
16 | aggregator = RewardAggregater(
17 | model=model,
18 | tokenizer=tokenizer
19 | )
20 |
21 | for test_example in rewards_test_examples:
22 | question, reasoning_steps, test_cases = test_example['question'], test_example['reasoning_steps'], test_example['test_cases']
23 | prompt = CODEPRM_PROMPT.format(question=question)
24 | test_reward = aggregator.update_reward(
25 | prompt,
26 | reasoning_steps,
27 | test_cases,
28 | 1,
29 | )
30 | print(test_reward)
31 |
32 | if __name__ == '__main__':
33 | test()
--------------------------------------------------------------------------------
/src/TestCaseGenerate/README.md:
--------------------------------------------------------------------------------
1 | ### Wash TCG Data
2 |
3 | You can use `wash_code.py` to generate four parts of data, keeping only the problems that use stdin/stdout to pass their test cases.
4 |
5 | ### Training Test Case Generate Model.
6 |
7 | accelerate launch TCG_SFT.py
--------------------------------------------------------------------------------
/src/TestCaseGenerate/TCG.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import transformers
5 | from accelerate import PartialState
6 | from datasets import load_dataset, Dataset
7 | from peft import LoraConfig
8 | from transformers import (
9 | AutoTokenizer,
10 | AutoModelForCausalLM,
11 | BitsAndBytesConfig,
12 | logging,
13 | set_seed,
14 | )
15 | from trl import DPOTrainer, SFTTrainer
16 | import numpy as np
17 | from transformers import AutoModelForCausalLM, AutoTokenizer
18 | import random
19 | from peft import PeftModel, LoraConfig
20 | import torch.distributed as dist
21 | import json
22 | from torch.utils.data import DataLoader, DistributedSampler
23 | import tqdm
24 | import re
25 | import io, sys
26 | import multiprocessing
27 |
28 |
29 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
30 |
31 | def get_args():
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument("--split", type=str, default="train")
34 | parser.add_argument("--dataset_text_field", type=str, default="text")
35 |
36 | parser.add_argument("--max_seq_length", type=int, default=1024 * 4)
37 | parser.add_argument("--max_steps", type=int, default=5000)
38 | parser.add_argument("--micro_batch_size", type=int, default=1)
39 | parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
40 | parser.add_argument("--weight_decay", type=float, default=0.01)
41 | parser.add_argument("--bf16", type=bool, default=True)
42 |
43 | parser.add_argument("--attention_dropout", type=float, default=0.1)
44 | parser.add_argument("--learning_rate", type=float, default=5e-4)
45 | parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
46 | parser.add_argument("--warmup_steps", type=int, default=100)
47 | parser.add_argument("--seed", type=int, default=0)
48 | parser.add_argument("--output_dir", type=str, default="finetune_deepseek1.3_instruct_o1_format_SFT")
49 | parser.add_argument("--num_proc", type=int, default=None)
50 |
51 | parser.add_argument("--model_path", type=str, default="/data/FastSSD/LLM_Models/deepseek-coder-1.3b-instruct/")
52 | return parser.parse_args()
53 |
54 | def build_test_part(A, B):
55 | test_part = """```case
56 | # input:
57 | {}
58 | # output:
59 | {}
60 | ```
61 | """.format(A.strip(), B.strip())
62 | if len(test_part) > 100:
63 | raise ValueError
64 | return test_part
65 |
66 | def build_TACO_SFT(item):
67 | return_item_lst = []
68 | for solve in eval(item['solutions']):
69 | in_out_case = eval(item['input_output'])
70 | test_case = [[x[0], x[1]] for x in zip(in_out_case['inputs'], in_out_case['outputs'])]
71 | test_case = random.choices(test_case, k=min(3, len(test_case)))
72 | test_part = "".join([build_test_part(item[0], item[1]) for item in test_case])
73 | templt = '''### Instruction
74 | Please complete the task in the code part and generate some test case in the test part that can be used to test the quality of the generated code.
75 | ### Problem
76 | {}
77 | ### Code Part
78 | {}
79 | ```python
80 | {}
81 | ```
82 | ### Test Part
83 | [Generate 3 test cases here to validate the code.]
84 | {}
85 | <|EOT|>
86 | '''.format(item['question'].strip(), ", ".join(eval(item['tags'])[:20]), solve.strip(), test_part.strip())
87 | return_item_lst.append(templt)
88 | return return_item_lst
89 |
90 | def main(args):
91 | # config
92 | bnb_config = BitsAndBytesConfig(
93 | load_in_4bit=True,
94 | bnb_4bit_quant_type="nf4",
95 | bnb_4bit_compute_dtype=torch.bfloat16,
96 | )
97 | lora_config = LoraConfig(
98 | r=1,
99 | target_modules=[
100 | "q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj",
101 | ],
102 | task_type="CAUSAL_LM",
103 | )
104 | model = AutoModelForCausalLM.from_pretrained(
105 | args.model_path,
106 | quantization_config=bnb_config,
107 | attention_dropout=args.attention_dropout,
108 | device_map={"": PartialState().process_index},
109 | torch_dtype=torch.bfloat16,
110 | attn_implementation="flash_attention_2",
111 | )
112 | tokenizer = AutoTokenizer.from_pretrained(args.model_path)
113 |
114 | select_data = json.load(open("select_data_train_can_use.json"))
115 | SFT_dataset = []
116 | length = []
117 | too_long_count = 0
118 | for item in tqdm.tqdm(select_data):
119 | try:
120 | if eval(item['solutions']) != []:
121 | item_str_lst = build_TACO_SFT(item)
122 | if len(item_str_lst) >= 30:
123 | item_str_lst = random.choices(item_str_lst, k=30)
124 | for item_str in item_str_lst:
125 | if len(item_str) > 3000:
126 | too_long_count += 1
127 | continue
128 | length.append(len(item_str))
129 | SFT_dataset.append(item_str)
130 | except:
131 | pass
132 | print("Avg Length Len:", np.mean(length), len(SFT_dataset), too_long_count)
133 | SFT_dataset = Dataset.from_dict({"text": SFT_dataset})
134 | SFT_dataset = SFT_dataset.shuffle(seed=42)
135 |
136 | # setup the SFT trainer
137 | trainer = SFTTrainer(
138 | model=model,
139 | train_dataset=SFT_dataset,
140 | max_seq_length=args.max_seq_length,
141 | tokenizer=tokenizer,
142 | args=transformers.TrainingArguments(
143 | per_device_train_batch_size=args.micro_batch_size,
144 | gradient_accumulation_steps=args.gradient_accumulation_steps,
145 | warmup_steps=args.warmup_steps,
146 | max_steps=5000,
147 | learning_rate=args.learning_rate,
148 | lr_scheduler_type=args.lr_scheduler_type,
149 | weight_decay=args.weight_decay,
150 | bf16=args.bf16,
151 | fp16=False,
152 | logging_strategy="steps",
153 | logging_steps=1,
154 | output_dir=args.output_dir,
155 | optim="paged_adamw_8bit",
156 | seed=args.seed,
157 | ),
158 | peft_config=lora_config,
159 | dataset_text_field="text",
160 | )
161 | # launch
162 | print("Training SFT...")
163 | trainer.train()
164 | model.save_pretrained(os.path.join(args.output_dir, "SFT_final_checkpoint/"))
165 | tokenizer.save_pretrained(os.path.join(args.output_dir, "SFT_final_checkpoint/"))
166 | print("SFT Training Done!")
167 |
168 |
169 | # accelerate launch TCG_SFT.py
170 | if __name__ == "__main__":
171 | args = get_args()
172 | set_seed(args.seed)
173 | logging.set_verbosity_error()
174 | main(args)
175 |
--------------------------------------------------------------------------------
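For reference, the test-case block that `build_test_part` in `TCG.py` above assembles for a single input/output pair (toy values; `build_test_part` is assumed to be in scope):

```python
# Toy illustration of the block format used in the SFT "Test Part".
block = build_test_part("1 2 3\n", "6\n")
print(block)
# ```case
# # input:
# 1 2 3
# # output:
# 6
# ```
```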
/src/TestCaseGenerate/wash_code.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import json
3 | import requests
4 | import sys
5 | import multiprocessing
6 | import io
7 | import re
8 | import ray
9 | import ast
10 | import numpy as np
11 | import threading
12 | from datasets import load_dataset, Dataset
13 |
14 | # Initialize Ray
15 | ray.init(ignore_reinit_error=True)
16 |
17 | # Timeout handler: forcibly terminate the worker process
18 | def timeout_handler(process, result_container):
19 | process.terminate() # forcibly terminate the process
20 | result_container["error"] = "Timeout"
21 |
22 | @ray.remote
23 | def execute_code_with_timeout(test_code, input_data, timeout=3):
24 | """
25 | 使用 Ray 实现的远程函数,用于在限定时间内执行代码。
26 | """
27 | manager = multiprocessing.Manager() # 使用 Python 标准库中的 multiprocessing.Manager()
28 | result_container = manager.dict()
29 |
30 | def target_func():
31 | try:
32 | sys.stdin = io.StringIO(input_data)
33 | sys.stdout = io.StringIO()
34 | env_test = {}
35 | exec(test_code, env_test)
36 | result_container["output"] = sys.stdout.getvalue().strip()
37 | result_container["error"] = None
38 | except Exception as e:
39 | result_container["error"] = str(e)
40 |
41 | process = multiprocessing.Process(target=target_func) # spawn a process to execute the code
42 | process.start()
43 |
44 | timer = threading.Timer(timeout, timeout_handler, args=(process, result_container)) # set the timeout timer
45 | timer.start()
46 |
47 | process.join(timeout) # wait for the process to finish
48 | timer.cancel() # cancel the timer (if the process has already finished)
49 |
50 | return result_container.get("output", None), result_container.get("error", None)
51 |
52 | @ray.remote
53 | def worker(case, test_code, timeout):
54 | """
55 | Worker 函数,用于执行单个测试用例。
56 | """
57 | input_data = case['input'].strip()
58 | expected_output = case['output'].strip()
59 | actual_output, error = ray.get(execute_code_with_timeout.remote(test_code, input_data, timeout))
60 |
61 | if error:
62 | return {
63 | 'input': input_data,
64 | 'expected_output': expected_output,
65 | 'actual_output': None,
66 | 'error': error,
67 | 'status': 'failed'
68 | }
69 | if actual_output == expected_output:
70 | return {
71 | 'input': input_data,
72 | 'expected_output': expected_output,
73 | 'actual_output': actual_output,
74 | 'error': None,
75 | 'status': 'passed'
76 | }
77 | else:
78 | return {
79 | 'input': input_data,
80 | 'expected_output': expected_output,
81 | 'actual_output': actual_output,
82 | 'error': None,
83 | 'status': 'failed'
84 | }
85 |
86 | def test_cases_with_limit(output_cases, test_right_code, timeout=3):
87 | """
88 | 使用 Ray 实现的并行化测试用例函数。
89 | """
90 | futures = [worker.remote(case, test_right_code, timeout) for case in output_cases]
91 | results = ray.get(futures)
92 |
93 | passed_case = [res for res in results if res['status'] == 'passed']
94 | unpassed_case = [res for res in results if res['status'] == 'failed']
95 |
96 | return passed_case, unpassed_case
97 |
98 |
99 | if __name__ == "__main__":
100 | can_use_case = []
101 | can_use_case_all_pass = []
102 | pass_rate = []
103 | ds = load_dataset("/data/FastSSD/LLM_Models/TACO")
104 | Pbar = tqdm.tqdm(ds['train'])
105 | for item in Pbar:
106 | if eval(item['solutions']) == []:
107 | continue
108 | try:
109 | express_code = json.loads(item['solutions'])[0]
110 | example_case = json.loads(item['input_output'])
111 | except:
112 | continue
113 |
114 | all_example_case = [{'input': str(x), 'output': str(y)} for x, y in zip(example_case['inputs'], example_case['outputs'])]
115 | passed_case, unpassed_case = test_cases_with_limit(all_example_case, express_code, timeout=10)
116 |
117 | # Compute the pass rate
118 | if (len(passed_case) + len(unpassed_case)) == 0:
119 | current_pass_rate = 0
120 | else:
121 | current_pass_rate = len(passed_case) / (len(passed_case) + len(unpassed_case))
122 | pass_rate.append(current_pass_rate)
123 | if current_pass_rate == 1.0:
124 | can_use_case_all_pass.append(item)
125 | if current_pass_rate > 0.0:
126 | can_use_case.append(item)
127 | Pbar.set_description(f"Avg Pass: {np.mean(pass_rate)}, All Pass: {np.mean(np.array(pass_rate) == 1.)}, Current: {current_pass_rate}")
128 |
129 | json.dump(can_use_case, open("/home/xukaiyuan/Project/TreeSearch_Code/wash_code/select_data_train_can_use.json", 'w'))
130 | json.dump(can_use_case_all_pass, open("/home/xukaiyuan/Project/TreeSearch_Code/wash_code/select_data_train_all_pass.json", 'w'))
131 |
--------------------------------------------------------------------------------
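A toy illustration of `test_cases_with_limit` from `wash_code.py` above (the snippet and cases are made up; Ray is initialized when the module is imported):

```python
# Toy example: a correct stdin/stdout solution and two matching cases.
code = "a, b = map(int, input().split())\nprint(a + b)"
cases = [
    {"input": "1 2", "output": "3"},
    {"input": "5 7", "output": "12"},
]
passed, failed = test_cases_with_limit(cases, code, timeout=3)
print(len(passed), len(failed))  # expected: 2 0
```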
/src/mcts/README.md:
--------------------------------------------------------------------------------
1 | ## How to Use
2 |
3 | ### Data Preparation
4 |
5 | You can find the TACO dataset on [Hugging Face](https://huggingface.co/datasets/BAAI/TACO). Convert the dataset into a JSON file and place it in the `data/TACO` directory.
6 |
7 | ---
8 |
9 | ### Generate Data
10 |
11 | #### Local Model
12 |
13 | Run the `run_TACO.sh` script with the appropriate configurations to use the local model for data generation:
14 |
15 | ```bash
16 | bash run_TACO.sh
17 | ```
18 |
19 | #### Main Arguments
20 |
21 | | Argument | Type | Description |
22 | |-----------------------|--------|-----------------------------------------|
23 | | `--dataset_name` | str | Name of the dataset folder in the `data` directory |
24 | | `--test_json_filename`| str | Name of the JSON file containing the data |
25 | | `--model_ckpt` | str | Path to the model checkpoint |
26 | | `--num_rollouts` | int | Number of MCTS rollouts |
27 | | `--max_depth_allowed` | int | Maximum depth allowed for the MCTS search tree |
28 |
29 | ---
30 |
31 | #### OpenAI API
32 |
33 | First, configure your `api_key` in the `models/OpenAI_API.py` file.
34 |
35 | Then, run the `api_run_TACO.sh` script to use the API for data generation:
36 |
37 | ```bash
38 | bash api_run_TACO.sh
39 | ```
40 |
41 | #### Additional Arguments
42 |
43 | | Argument | Type | Description |
44 | |-----------------------|--------|-----------------------------------------|
45 | | `--api` | str | Default is `vllm`, which calls the local model |
46 | | `--model_ckpt` | str | Specific OpenAI model name |
47 |
48 | ---
49 |
50 | ### Data Example
51 |
52 | A sample dataset can be found in the `run_outputs` folder. Detailed information for each problem is available in the `answer_sheets` folder.
53 |
54 | #### File Information
55 |
56 | | Filename | Description |
57 | |-----------------------------------------|-----------------------------------------|
58 | | `Question XXXX - Answer.json` | Contains the original question information |
59 | | `Question XXXX - Best Solutions.json` | The path with the highest reward in the final step |
60 | | `Question XXXX - Complete Solutions.json` | All complete paths in the MCTS search tree |
61 | | `Question XXXX - Rollout Solutions.json` | Paths generated during each MCTS rollout |
62 | | `args.json` | Parameter configuration information |
63 | | `intermediate_result.txt` | Logs for model calls and intermediate results |
64 |
65 |
66 | ## Acknowledgement
67 | 
68 | This code is adapted from the project available at [https://github.com/zhentingqi/rStar/](https://github.com/zhentingqi/rStar/).
69 |
--------------------------------------------------------------------------------
/src/mcts/common/arguments.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 |
3 | import os, json, torch, math
4 | from argparse import ArgumentParser
5 | from datetime import datetime
6 |
7 |
8 | def get_parser():
9 | parser = ArgumentParser()
10 |
11 | parser.add_argument("--note", type=str, default="debug")
12 |
13 | allowed_apis = ["together", "huggingface", "llama", "vllm", "debug", "OpenAI"]
14 | parser.add_argument(
15 | "--api", type=str, choices=allowed_apis, default="vllm", help=f"API to use: Choose from {allowed_apis}."
16 | )
17 |
18 | parser.add_argument("--seed", type=int, default=42)
19 | parser.add_argument("--verbose", action="store_true")
20 |
21 | #! WandB settings
22 | parser.add_argument("--wandb_mode", type=str, default="disabled", choices=["disabled", "online"])
23 |
24 | #! LLM settings
25 | parser.add_argument("--model_ckpt", required=True)
26 |
27 | parser.add_argument("--model_parallel", action="store_true")
28 | parser.add_argument("--half_precision", action="store_true")
29 |
30 | parser.add_argument("--max_tokens", type=int, default=4096, help="max_tokens")
31 | parser.add_argument("--temperature", type=float, default=0.8, help="temperature")
32 | parser.add_argument("--top_k", type=int, default=40, help="top_k")
33 | parser.add_argument("--top_p", type=float, default=0.95, help="top_p")
34 | parser.add_argument("--num_beams", type=int, default=1, help="num_beams")
35 |
36 | parser.add_argument("--max_num_worker", type=int, default=3, help="maximum number of workers for dataloader")
37 | parser.add_argument("--test_batch_size", type=int, default=1) # batch_size
38 | parser.add_argument("--tensor_parallel_size", type=int, default=1) # tensor_parallel_size
39 |
40 | #! prompt settings
41 | parser.add_argument("--prompts_root", default="prompts")
42 |
43 | #! dataset settings
44 | parser.add_argument("--data_root", default="data")
45 | allowed_dataset_names = ["TACO"]
46 | parser.add_argument(
47 | "--dataset_name",
48 | required=True,
49 | choices=allowed_dataset_names,
50 | help=f"Test dataset name: Choose from {allowed_dataset_names}.",
51 | )
52 | parser.add_argument("--test_json_filename", type=str, default="test_all")
53 | parser.add_argument("--start_idx", type=int, default=0, help="Start index of test questions (inclusive)")
54 | parser.add_argument("--end_idx", type=int, default=math.inf, help="End index of test questions (inclusive))")
55 |
56 | #! outputs settings
57 | parser.add_argument("--run_outputs_root", type=str, default="run_outputs")
58 | parser.add_argument("--eval_outputs_root", type=str, default="eval_outputs")
59 | parser.add_argument("--run_outputs_dir", type=str, default="")
60 |
61 | return parser
62 |
63 |
64 | def post_process_args(args):
65 | # Set up logging
66 | suffix = "---[" + args.note + "]" if args.note is not None else ""
67 | model_name = args.model_ckpt.split("/")[-1]
68 | if args.run_outputs_dir == "":
69 | args.run_outputs_dir = os.path.join(
70 | args.run_outputs_root,
71 | args.dataset_name,
72 | model_name,
73 | f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + suffix,
74 | )
75 | else:
76 | args.run_outputs_dir = os.path.join(
77 | args.run_outputs_root,
78 | args.dataset_name,
79 | model_name,
80 | f"{args.run_outputs_dir}---{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + suffix,
81 | )
82 | os.makedirs(args.run_outputs_dir, exist_ok=True)
83 |
84 |
85 | args.answer_sheets_dir = os.path.join(args.run_outputs_dir, "answer_sheets")
86 | os.makedirs(args.answer_sheets_dir, exist_ok=True)
87 |
88 | # Check GPU
89 | num_gpus = torch.cuda.device_count()
90 | cuda_devices = [torch.cuda.get_device_name(i) for i in range(num_gpus)]
91 | assert len(cuda_devices) > 0, "No GPU available."
92 | args.cuda_0 = cuda_devices[0]
93 | args.cuda_1 = cuda_devices[1] if len(cuda_devices) > 1 else None
94 | args.cuda_2 = cuda_devices[2] if len(cuda_devices) > 2 else None
95 | args.cuda_3 = cuda_devices[3] if len(cuda_devices) > 3 else None
96 |
97 | return args
98 |
99 |
100 | def save_args(args):
101 | # Save args as json
102 | with open(os.path.join(args.run_outputs_dir, "args.json"), "w") as f:
103 | json.dump(vars(args), f, indent=4)
104 |
--------------------------------------------------------------------------------
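A hypothetical sketch of how an entry point such as `run_src/do_generate.py` would consume these helpers (paths are placeholders; `post_process_args` requires at least one visible GPU):

```python
# Placeholder paths; only flags defined in get_parser() are used here.
parser = get_parser()
args = parser.parse_args([
    "--model_ckpt", "/path/to/policy-model",
    "--dataset_name", "TACO",
    "--test_json_filename", "test_one",
])
args = post_process_args(args)  # creates run_outputs/<dataset>/<model>/<timestamp>---[note]
save_args(args)                 # writes args.json into the run directory
print(args.run_outputs_dir)
```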
/src/mcts/common/utils.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 |
3 | import json
4 | import re
5 | import os
6 | import random
7 | import numpy as np
8 | import torch
9 | import multiprocessing
10 | from typing import Tuple
11 | from statistics import mean
12 | from torch.utils.data import Dataset
13 |
14 |
15 | def fix_seeds(seed):
16 | # random
17 | random.seed(seed)
18 | # Numpy
19 | np.random.seed(seed)
20 | # Pytorch
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 | torch.backends.cudnn.deterministic = True
24 | torch.backends.cudnn.benchmark = False
25 |
26 |
27 | def setup_model_parallel() -> Tuple[int, int]:
28 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel
29 |
30 | local_rank = int(os.environ.get("LOCAL_RANK", -1))
31 | world_size = int(os.environ.get("WORLD_SIZE", -1))
32 |
33 | torch.distributed.init_process_group("nccl")
34 | initialize_model_parallel(world_size)
35 | torch.cuda.set_device(local_rank)
36 |
37 | return local_rank, world_size
38 |
39 |
40 | def read_json(file_path):
41 | assert str(file_path).endswith(".json")
42 | with open(file_path, "r", encoding="utf-8") as f:
43 | data = json.load(f)
44 | return data
45 |
46 |
47 | def save_json(js_obj, file_path):
48 | assert str(file_path).endswith(".json")
49 | with open(file_path, "w", encoding="utf-8") as f:
50 | json.dump(js_obj, f, indent=4)
51 |
52 |
53 | def read_txt(file_path):
54 | assert str(file_path).endswith(".txt")
55 | with open(file_path, "r", encoding="utf-8") as f:
56 | data = f.read()
57 | return data
58 |
59 |
--------------------------------------------------------------------------------
/src/mcts/data/TACO/self_create_data.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "question": "In a galaxy far, far away, the Empire has launched a series of attacks using their powerful weapon, the Death Star. To counter this threat, the Rebel Alliance has decided to arm themselves with the legendary Lightsabers of the Jedi. Each Lightsaber has a certain power, and the sum of the powers of the Lightsabers will determine the strength of the Alliance's defense. You, as a strategist, must calculate the sum of the powers of the Lightsabers provided by your Jedi allies to ensure the best possible defense against the Empire's onslaught.\n\nInput:\nThe first line contains an integer T, representing the number of sets of Lightsabers. Each set is described in the following lines. The first number in each set is an integer n, representing the number of Lightsabers in that set, followed by n space-separated integers, each representing the power of a Lightsaber.\n\nOutput:\nFor each set of Lightsabers, calculate and print the sum of their powers. If there is no Lightsaber in a set, print 0 for that set.\n\nConstraints:\n1 ≤ T ≤ 10^6,\n0 ≤ n ≤ 10^6,\n0 ≤ power of each Lightsaber ≤ 10^9.\n\nSAMPLE INPUT\n2\n3 1 2 3\n4 10 20 30 40\n\nSAMPLE OUTPUT\n6\n100",
4 | "solutions": [],
5 | "starter_code": "",
6 | "input_output": {
7 | "inputs": [
8 | "1\n2\n1 2",
9 | "2\n3\n1 2 3\n2\n4 5",
10 | "1\n4\n0 0 0 0",
11 | "3\n2\n10 20\n2\n30 40\n2\n50 60",
12 | "1\n5\n7 8 9 10 11",
13 | "2\n1\n100\n3\n1 1 1",
14 | "1\n2\n-1 1",
15 | "1\n10\n1 2 3 4 5 6 7 8 9 10",
16 | "2\n4\n1 1 1 1\n4\n2 2 2 2",
17 | "1\n3\n0 0 0",
18 | "2\n5\n11 22 33 44 55\n5\n66 77 88 99 100",
19 | "1\n1\n1000000000",
20 | "2\n2\n-5 5\n3\n-10 0 10"
21 | ],
22 | "outputs": [
23 | "3",
24 | "6\n9",
25 | "0",
26 | "30\n70\n110",
27 | "45",
28 | "100\n3",
29 | "0",
30 | "55",
31 | "4\n8",
32 | "0",
33 | "165\n430",
34 | "1000000000",
35 | "0\n0"
36 | ]
37 | },
38 | "difficulty": "EASY",
39 | "raw_tags": "['Algorithms', 'ad-hoc', 'Constructive']",
40 | "name": null,
41 | "source": "codechef",
42 | "tags": "['Constructive algorithms', 'Ad-hoc']",
43 | "skill_types": [],
44 | "url": "https://www.codechef.com/problems/RRJOKE",
45 | "Expected Auxiliary Space": null,
46 | "time_limit": "1 seconds",
47 | "date": "2014-12-18",
48 | "picture_num": 0,
49 | "memory_limit": "50000 bytes",
50 | "Expected Time Complexity": null
51 | }
52 | ]
--------------------------------------------------------------------------------
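For stdin/stdout-style problems such as the record above, `input_output` pairs raw input strings with the exact stdout a correct program should print. A small sketch of iterating over those pairs (the path is relative to `src/mcts`, matching the tree above; illustrative only):

```python
import json

with open("data/TACO/self_create_data.json", "r", encoding="utf-8") as f:
    problems = json.load(f)

for prob in problems:
    io = prob["input_output"]
    for stdin_text, expected in zip(io["inputs"], io["outputs"]):
        # Each entry is a raw stdin string and the stdout a correct program should produce.
        print(repr(stdin_text), "->", repr(expected))
```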
/src/mcts/data/TACO/test_one.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "question": "**Step 1:** Create a function called `encode()` to replace all the lowercase vowels in a given string with numbers according to the following pattern:\n```\na -> 1\ne -> 2\ni -> 3\no -> 4\nu -> 5\n```\n\nFor example, `encode(\"hello\")` would return `\"h2ll4\"`. There is no need to worry about uppercase vowels in this kata.\n\n**Step 2:** Now create a function called `decode()` to turn the numbers back into vowels according to the same pattern shown above.\n\nFor example, `decode(\"h3 th2r2\")` would return `\"hi there\"`.\n\nFor the sake of simplicity, you can assume that any numbers passed into the function will correspond to vowels.",
4 | "solutions": [
5 | "def encode(s, t=str.maketrans('aeiou', '12345')):\n\treturn s.translate(t)\n\ndef decode(s, t=str.maketrans('12345', 'aeiou')):\n\treturn s.translate(t)\n",
6 | "CIPHER = ('aeiou', '12345')\n\ndef encode(st):\n\treturn st.translate(str.maketrans(CIPHER[0], CIPHER[1]))\n\ndef decode(st):\n\treturn st.translate(str.maketrans(CIPHER[1], CIPHER[0]))\n",
7 | "def encode(st):\n\tfor (i, v) in enumerate('aeiou', start=1):\n\t\tst = st.replace(v, str(i))\n\treturn st\n\ndef decode(st):\n\tfor (i, v) in enumerate('aeiou', start=1):\n\t\tst = st.replace(str(i), v)\n\treturn st\n",
8 | "tbl1 = str.maketrans('aeiou', '12345')\ntbl2 = str.maketrans('12345', 'aeiou')\n\ndef encode(st):\n\treturn st.translate(tbl1)\n\ndef decode(st):\n\treturn st.translate(tbl2)\n",
9 | "a = {'a': '1', 'e': '2', 'i': '3', 'o': '4', 'u': '5'}\nb = ('a', 'e', 'i', 'o', 'u')\n\ndef encode(st):\n\treturn ''.join((a[c] if c in a else c for c in st))\n\ndef decode(st):\n\treturn ''.join((b[int(c) - 1] if c.isdigit() else c for c in st))\n",
10 | "def cipher(mode):\n\ttable = str.maketrans(*['aeiou', '12345'][::mode])\n\treturn lambda s: s.translate(table)\n(encode, decode) = (cipher(1), cipher(-1))\n",
11 | "CYPHER = tuple(zip('aeiou', '12345'))\n\ndef munge(st, mapping):\n\treturn ''.join([mapping.get(c, c) for c in st])\n\ndef encode(st):\n\treturn munge(st, {a: b for (a, b) in CYPHER})\n\ndef decode(st):\n\treturn munge(st, {b: a for (a, b) in CYPHER})\n",
12 | "def encode(st):\n\tL = []\n\tA = {'a': '1', 'e': '2', 'i': '3', 'o': '4', 'u': '5'}\n\tfor i in st:\n\t\tif i in A:\n\t\t\tL.append(A[i])\n\t\telse:\n\t\t\tL.append(i)\n\treturn ''.join(L)\n\ndef decode(st):\n\tL = []\n\tA = {'1': 'a', '2': 'e', '3': 'i', '4': 'o', '5': 'u'}\n\tfor i in st:\n\t\tif i in A:\n\t\t\tL.append(A[i])\n\t\telse:\n\t\t\tL.append(i)\n\treturn ''.join(L)\n",
13 | "import re\n\ndef encode(st):\n\tvowel = ' aeiou'\n\treturn re.sub('[aeoui]', lambda x: str(vowel.index(x.group(0))), st)\n\ndef decode(st):\n\tvowel = ' aeiou'\n\treturn re.sub('[1-5]', lambda x: vowel[int(x.group(0))], st)\n",
14 | "a = ['a', 'e', 'i', 'o', 'u']\n\ndef encode(st):\n\treturn ''.join([str(a.index(c) + 1) if c in a else c for c in st])\n\ndef decode(st):\n\treturn ''.join([a[int(c) - 1] if c.isdigit() else c for c in st])\n"
15 | ],
16 | "starter_code": "def encode(st):\n\t",
17 | "input_output": {
18 | "fn_name": "encode",
19 | "inputs": [
20 | [
21 | "hello"
22 | ],
23 | [
24 | "How are you today?"
25 | ],
26 | [
27 | "This is an encoding test."
28 | ]
29 | ],
30 | "outputs": [
31 | [
32 | "h2ll4"
33 | ],
34 | [
35 | "H4w 1r2 y45 t4d1y?"
36 | ],
37 | [
38 | "Th3s 3s 1n 2nc4d3ng t2st."
39 | ]
40 | ]
41 | },
42 | "difficulty": "EASY",
43 | "raw_tags": "['Regular Expressions', 'Strings', 'Fundamentals', 'Arrays']",
44 | "name": null,
45 | "source": "codewars",
46 | "tags": "['String algorithms', 'Fundamentals', 'Data structures']",
47 | "skill_types": "['Data structures']",
48 | "url": "https://www.codewars.com/kata/53697be005f803751e0015aa",
49 | "Expected Auxiliary Space": null,
50 | "time_limit": null,
51 | "date": null,
52 | "picture_num": null,
53 | "memory_limit": null,
54 | "Expected Time Complexity": null
55 | }
56 | ]
--------------------------------------------------------------------------------
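Unlike the stdin/stdout record above, this one sets `input_output.fn_name`, so each test case is a positional-argument list for that function and the expected value is its return value. A rough sketch of that convention (it only mirrors the general idea; the real harness in `eval_src/testing_util.py` handles many more edge cases):

```python
import json

with open("data/TACO/test_one.json", "r", encoding="utf-8") as f:
    prob = json.load(f)[0]

# Load the first reference solution and look up the target function by name.
namespace = {}
exec(prob["solutions"][0], namespace)
fn = namespace[prob["input_output"]["fn_name"]]

for args, expected in zip(prob["input_output"]["inputs"], prob["input_output"]["outputs"]):
    assert fn(*args) == expected[0]   # e.g. encode("hello") == "h2ll4"
```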
/src/mcts/eval_src/Evaluator.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 |
3 | from eval_src.checker_utils import CodeSolutionParser, check_generation_correctness
4 |
5 | import os, json, re
6 | from typing import List, Dict, Tuple
7 | from collections import defaultdict
8 | import random
9 | import copy
10 | from fuzzywuzzy import fuzz, process
11 |
12 | from multiprocessing import Manager, Process
13 | import concurrent.futures
14 |
15 |
16 | class Evaluator:
17 | def __init__(self) -> None:
18 | self.answer_marker = "answer is"
19 | self.parser = CodeSolutionParser()
20 |
21 |
22 | def find_TACO_code(self, completions: List[str], test_case: dict, solution_trace: Dict[int, Dict[str, str]],):
23 | if completions is None or len(completions) == 0:
24 | return None, None, None, None
25 | solution_trace_ = copy.deepcopy(solution_trace)
26 | id2pass_completions = defaultdict(list)
27 | pass_ratio = 0
28 | compile_pass = False
29 |
30 |
31 | for id, c in enumerate(completions):
32 | result = self.parser.process_solution(c)
33 |
34 | generation_code = result["final_code"]
35 |
36 | if "fn_name" in test_case:
37 | if "main_function" in result:
38 | if result["main_function"] is not None:
39 | if "name" in result["main_function"]:
40 | if test_case["fn_name"] != result["main_function"]['name']:
41 | test_case["fn_name"] = result["main_function"]['name']
42 |
43 |
44 |
45 |             if generation_code is None:
46 | pass_ratio = 0
47 | continue
48 |
49 |
50 | correctness_results = check_generation_correctness(test_case, generation_code, debug=False, n_cases=10)
51 | # print(correctness_results)
52 |
53 | if isinstance(correctness_results, list):
54 | if True in correctness_results or False in correctness_results:
55 | compile_pass = True
56 | pass_case_count = correctness_results.count(True)
57 |
58 |                 # compute the fraction of passed test cases
59 | pass_ratio = pass_case_count / len(correctness_results)
60 | else:
61 | pass_ratio = 0
62 |
63 | alpha = 0
64 | if compile_pass:
65 | pass_ratio = alpha * 1 + (1 - alpha) * pass_ratio
66 |
67 | # print(f"*********** {id} : score : {pass_ratio} *********")
68 |
69 |
70 |
71 |
72 | return "", completions[0], pass_ratio, solution_trace_
73 |
74 |
75 |
76 | class TACOEvaluator(Evaluator):
77 | def __init__(self) -> None:
78 | super().__init__()
79 |
80 | def passed(self, references):
81 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
82 | results = []
83 |
84 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
85 | future_results = {executor.submit(self.run_code_with_timeout, references)}
86 | for future in concurrent.futures.as_completed(future_results):
87 | results.append(future.result())
88 | # print(results)
89 | return results[0] == 'passed'
90 |
91 | def run_code_with_timeout(self, code_string, timeout=1):
92 | with Manager() as manager:
93 | result_dict = manager.dict()
94 | process = Process(target=self.exec_code, args=(code_string, result_dict))
95 | process.start()
96 | process.join(timeout=timeout)
97 | if process.is_alive():
98 | process.kill()
99 | return "timeout"
100 | else:
101 | return result_dict['result']
102 |
103 | @staticmethod
104 | def exec_code(code, result_dict):
105 | result_dict['result'] = 'Not executed'
106 | try:
107 | exec_globals = {}
108 | exec(code, exec_globals)
109 | result_dict['result'] = 'passed'
110 | except Exception as e:
111 |
112 | result_dict['result'] = f'Error: {str(e)}'
113 |
114 | def extract_answer_from_gold_solution(self, solution: str):
115 | return None
116 |
117 | def extract_answer_from_model_completion(self, completion: str):
118 | if completion is None:
119 | return None
120 |
121 | assert isinstance(completion, str)
122 |
123 | preds = completion.replace('\\n', '\n')
124 |         code_marker = r"The code is: \[Code Start\]\s*(.*?)\s*\[Code End\]"
125 |         code = re.search(code_marker, preds, re.DOTALL)
126 |
127 | if code:
128 | result = code.group(1)
129 | return str(result.replace('\\r', '').replace('\\n', '\n').replace('\\t', '\t'))
130 | else:
131 |
132 | return None
--------------------------------------------------------------------------------
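A rough usage sketch for the evaluator above, assuming `src/mcts` is the working directory and its dependencies (e.g. `fuzzywuzzy`) are installed. `find_TACO_code` extracts the fenced code from a completion, runs it against the test cases via `check_generation_correctness`, and returns the pass ratio that the search uses as a reward:

```python
# Illustrative sketch; the exact run_test behaviour lives in eval_src/testing_util.py.
from eval_src.Evaluator import TACOEvaluator

evaluator = TACOEvaluator()

# A toy stdin/stdout test case in the same shape as the data/TACO records.
test_case = {"inputs": ["1 2"], "outputs": ["3"]}

completion = (
    "### Step: Generate python code\n"
    "```python\n"
    "a, b = map(int, input().split())\n"
    "print(a + b)\n"
    "```"
)

_, best, pass_ratio, _ = evaluator.find_TACO_code([completion], test_case, solution_trace={})
print(pass_ratio)  # expected to be 1.0 if the toy program passes the case, 0 otherwise
```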
/src/mcts/eval_src/checker_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import ast
3 | from typing import Optional, Dict
4 |
5 | class CodeSolutionParser:
6 | def __init__(self):
7 | self.steps = []
8 | self.final_code = None
9 | self.main_function = None
10 |
11 | def check_final_step(self, text: str) -> bool:
12 | """Check if the last step is code generation."""
13 | if text == "":
14 | return False
15 |
16 | last_step = text.lower()
17 | # Check if the last step mentions code generation
18 | code_indicators = [
19 | "```python"
20 | ]
21 |
22 | return any(indicator in last_step for indicator in code_indicators)
23 |
24 | def extract_code(self, text: str) -> str:
25 | """Extract the Python code from the last step."""
26 | if text == "":
27 | return None
28 |
29 | last_step = text
30 |
31 | # Find code between triple backticks
32 | code_pattern = r'```python(.*?)```'
33 | code_match = re.search(code_pattern, last_step, re.DOTALL)
34 |
35 | if code_match:
36 | code = code_match.group(1).strip()
37 | self.final_code = code
38 | return code
39 | return None
40 |
41 | def extract_outermost_function(self) -> Optional[Dict]:
42 | """Extract the outermost function from the code, including class methods."""
43 | if not self.final_code:
44 | return None
45 |
46 | try:
47 | # Parse the code into an AST
48 | tree = ast.parse(self.final_code)
49 |
50 | # First try to find module-level function
51 | for node in ast.iter_child_nodes(tree):
52 | if isinstance(node, ast.FunctionDef):
53 | return self._extract_function_info(node)
54 |
55 | # If no module-level function found, look for class methods
56 | for node in ast.iter_child_nodes(tree):
57 | if isinstance(node, ast.ClassDef):
58 | # Look for the first method in the class
59 | for class_node in node.body:
60 | if isinstance(class_node, ast.FunctionDef):
61 | # Skip __init__ and other special methods
62 | if not class_node.name.startswith('__'):
63 | function_info = self._extract_function_info(class_node)
64 | function_info['class_name'] = node.name
65 | return function_info
66 |
67 | except SyntaxError:
68 | return None
69 |
70 | return None
71 |
72 | def _extract_function_info(self, node: ast.FunctionDef) -> Dict:
73 | """Helper method to extract information from a function node."""
74 | function_info = {
75 | 'name': node.name,
76 | 'args': [arg.arg for arg in node.args.args],
77 | 'body': ast.unparse(node)
78 | }
79 |
80 | # Add return type annotation if exists
81 | if node.returns:
82 | function_info['return_type'] = ast.unparse(node.returns)
83 |
84 | # Add argument type annotations if exist
85 | arg_types = {}
86 | for arg in node.args.args:
87 | if arg.annotation:
88 | arg_types[arg.arg] = ast.unparse(arg.annotation)
89 | if arg_types:
90 | function_info['arg_types'] = arg_types
91 |
92 | # Add docstring if exists
93 | docstring = ast.get_docstring(node)
94 | if docstring:
95 | function_info['docstring'] = docstring
96 |
97 | return function_info
98 |
99 | def process_solution(self, text: str) -> dict:
100 | """Process the entire solution text and return results."""
101 | has_code_generation = self.check_final_step(text)
102 | code = self.extract_code(text) if has_code_generation else None
103 |
104 | # Extract the outermost function if code exists
105 | main_function = None
106 | if code:
107 | main_function = self.extract_outermost_function()
108 |
109 | return {
110 | 'has_code_generation': has_code_generation,
111 | 'final_code': code,
112 | 'main_function': main_function
113 | }
114 |
115 | import json
116 | import multiprocessing as mp
117 | import concurrent
118 | import numpy as np
119 | from typing import List, Dict, Any, Union
120 | from eval_src.testing_util import run_test
121 |
122 | TIMEOUT = 10
123 |
124 | def check_generation_correctness(
125 | test_cases: Dict[str, Union[str, List]],
126 | generation: str,
127 | timeout: int = TIMEOUT,
128 | debug: bool = False,
129 | n_cases: Optional[int] = None,
130 | ) -> List[bool]:
131 | """
132 | Args:
133 | test_cases (Dict[str, Union[str, List]]): A dictionary containing test cases with inputs and expected outputs.
134 | generation (str): The generated code to be tested.
135 | timeout (int, optional): The maximum time allowed for the test execution. Defaults to TIMEOUT.
136 | debug (bool, optional): If True, prints debug information. Defaults to False.
137 | Returns:
138 |         List[bool]: A list of booleans indicating the correctness of each test case. If running the tests raises an exception, a list of -2s is returned instead.
139 | """
140 |
141 | try:
142 | return run_test(test_cases, generation, debug, n_cases)
143 | except Exception as e:
144 | if debug:
145 | print(f"Error in running test cases: {e}")
146 | in_outs = test_cases
147 | return [-2] * len(in_outs["inputs"])
148 |
149 |
--------------------------------------------------------------------------------
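A small, self-contained sketch of the parser above: `process_solution` only treats a completion as code-bearing if it contains a fenced ```python block, extracts that block, and reports the outermost function found via `ast` (Python 3.9+ is assumed for `ast.unparse`):

```python
# Illustrative sketch; assumes src/mcts is on PYTHONPATH so eval_src imports resolve.
from eval_src.checker_utils import CodeSolutionParser

solution_text = "Step 3: generate the code.\n```python\ndef add(a, b):\n    return a + b\n```"

parser = CodeSolutionParser()
result = parser.process_solution(solution_text)

print(result["has_code_generation"])     # True
print(result["final_code"])              # the extracted function body
print(result["main_function"]["name"])   # 'add'
print(result["main_function"]["args"])   # ['a', 'b']
```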
/src/mcts/eval_src/pyext2.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2014 Ryan Gonzalez
3 |
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to use,
8 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
9 | Software, and to permit persons to whom the Software is furnished to do so,
10 | subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 | '''
22 |
23 | g_backup = globals().copy()
24 |
25 | __version__ = '0.7'
26 |
27 | __all__ = ['overload', 'RuntimeModule', 'switch', 'tail_recurse', 'copyfunc', 'set_docstring', 'annotate', 'safe_unpack', 'modify_function', 'assign', 'fannotate', 'compare_and_swap', 'is_main', 'call_if_main', 'run_main']
28 |
29 | import sys, inspect, types
30 |
31 | def __targspec(func, specs, attr='__orig_arg__'):
32 | if hasattr(func, '__is_overload__') and func.__is_overload__:
33 | return getattr(func, attr)
34 | return specs(func)
35 |
36 | def set_docstring(doc):
37 | '''A simple decorator to set docstrings.
38 |
39 | :param doc: The docstring to tie to the function.
40 |
41 | Example::
42 |
43 | @set_docstring('This is a docstring')
44 | def myfunc(x):
45 | pass'''
46 | def _wrap(f):
47 | f.__doc__ = doc
48 | return f
49 | return _wrap
50 |
51 | __modify_function_doc = '''
52 | Creates a copy of a function, changing its attributes.
53 |
54 | :param globals: Will be added to the function's globals.
55 |
56 | :param name: The new function name. Set to ``None`` to use the function's original name.
57 |
58 | :param code: The new function code object. Set to ``None`` to use the function's original code object.
59 |
60 | :param defaults: The new function defaults. Set to ``None`` to use the function's original defaults.
61 |
62 | :param closure: The new function closure. Set to ``None`` to use the function's original closure.
63 |
64 | .. warning:: This function can be potentially dangerous.
65 | '''
66 |
67 | def copyfunc(f):
68 |     '''Copies a function.
69 |
70 | :param f: The function to copy.
71 |
72 | :return: The copied function.
73 |
74 | .. deprecated:: 0.4
75 | Use :func:`modify_function` instead.
76 | '''
77 | return modify_function(f)
78 |
79 | if sys.version_info.major == 3:
80 | @set_docstring(__modify_function_doc)
81 | def modify_function(f, globals={}, name=None, code=None, defaults=None,
82 | closure=None):
83 | if code is None: code = f.__code__
84 | if name is None: name = f.__name__
85 | if defaults is None: defaults = f.__defaults__
86 | if closure is None: closure = f.__closure__
87 | newf = types.FunctionType(code, dict(f.__globals__, **globals), name=name,
88 | argdefs=defaults, closure=closure)
89 | newf.__dict__.update(f.__dict__)
90 | return newf
91 | def argspec(f):
92 | return inspect.getfullargspec(f)
93 | ofullargspec = inspect.getfullargspec
94 | def _fullargspec(func):
95 | return __targspec(func, ofullargspec)
96 | inspect.getfullargspec = _fullargspec
97 | def _exec(m,g): exec(m,g)
98 | else:
99 | @set_docstring(__modify_function_doc)
100 | def modify_function(f, globals={}, name=None, code=None, defaults=None,
101 | closure=None):
102 | if code is None: code = f.func_code
103 | if name is None: name = f.__name__
104 | if defaults is None: defaults = f.func_defaults
105 | if closure is None: closure = f.func_closure
106 | newf = types.FunctionType(code, dict(f.func_globals, **globals), name=name,
107 | argdefs=defaults, closure=closure)
108 | newf.__dict__.update(f.__dict__)
109 | return newf
110 | def argspec(f):
111 | return inspect.getargspec(f)
112 | eval(compile('def _exec(m,g): exec m in g', '', 'exec'))
113 |
114 | def _gettypes(args):
115 | return tuple(map(type, args))
116 |
117 | oargspec = inspect.getargspec
118 |
119 | def _argspec(func):
120 | return __targspec(func, oargspec)
121 |
122 | inspect.getargspec = _argspec
123 |
124 | try:
125 | import IPython
126 | except ImportError:
127 | IPython = None
128 | else:
129 | # Replace IPython's argspec
130 | oipyargspec = IPython.core.oinspect.getargspec
131 | def _ipyargspec(func):
132 | return __targspec(func, oipyargspec, '__orig_arg_ipy__')
133 | IPython.core.oinspect.getargspec = _ipyargspec
134 |
135 | class overload(object):
136 | '''Simple function overloading in Python.'''
137 | _items = {}
138 | _types = {}
139 | @classmethod
140 | def argc(self, argc=None):
141 | '''Overloads a function based on the specified argument count.
142 |
143 | :param argc: The argument count. Defaults to ``None``. If ``None`` is given, automatically compute the argument count from the given function.
144 |
145 | .. note::
146 |
147 | Keyword argument counts are NOT checked! In addition, when the argument count is automatically calculated, the keyword argument count is also ignored!
148 |
149 | Example::
150 |
151 | @overload.argc()
152 | def func(a):
153 | print 'Function 1 called'
154 |
155 | @overload.argc()
156 | def func(a, b):
157 | print 'Function 2 called'
158 |
159 | func(1) # Calls first function
160 | func(1, 2) # Calls second function
161 | func() # Raises error
162 | '''
163 | # Python 2 UnboundLocalError fix
164 | argc = {'argc': argc}
165 | def _wrap(f):
166 | def _newf(*args, **kwargs):
167 | if len(args) not in self._items[f.__name__]:
168 | raise TypeError("No overload of function '%s' that takes %d args" % (f.__name__, len(args)))
169 | return self._items[f.__name__][len(args)](*args, **kwargs)
170 | if f.__name__ not in self._items:
171 | self._items[f.__name__] = {}
172 | if argc['argc'] is None:
173 | argc['argc'] = len(argspec(f).args)
174 | self._items[f.__name__][argc['argc']] = f
175 | _newf.__name__ = f.__name__
176 | _newf.__doc__ = f.__doc__
177 | _newf.__is_overload__ = True
178 | _newf.__orig_arg__ = argspec(f)
179 | if IPython:
180 | _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f)
181 | return _newf
182 | return _wrap
183 | @classmethod
184 | def args(self, *argtypes, **kw):
185 | '''Overload a function based on the specified argument types.
186 |
187 | :param argtypes: The argument types. If None is given, get the argument types from the function annotations(Python 3 only)
188 | :param kw: Can only contain 1 argument, `is_cls`. If True, the function is assumed to be part of a class.
189 |
190 | Example::
191 |
192 | @overload.args(str)
193 | def func(s):
194 | print 'Got string'
195 |
196 | @overload.args(int, str)
197 | def func(i, s):
198 | print 'Got int and string'
199 |
200 | @overload.args()
201 | def func(i:int): # A function annotation example
202 | print 'Got int'
203 |
204 | func('s')
205 | func(1)
206 | func(1, 's')
207 | func(True) # Raises error
208 | '''
209 |
210 | # Python 2 UnboundLocalError fix...again!
211 | argtypes = {'args': tuple(argtypes)}
212 | def _wrap(f):
213 | def _newf(*args):
214 | if len(kw) == 0:
215 | cargs = args
216 | elif len(kw) == 1 and 'is_cls' in kw and kw['is_cls']:
217 | cargs = args[1:]
218 | else:
219 | raise ValueError('Invalid keyword args specified')
220 | if _gettypes(cargs) not in self._types[f.__name__]:
221 | raise TypeError("No overload of function '%s' that takes '%s' types and %d arg(s)" % (f.__name__, _gettypes(cargs), len(cargs)))
222 | return self._types[f.__name__][_gettypes(cargs)](*args)
223 | if f.__name__ not in self._types:
224 | self._types[f.__name__] = {}
225 | if len(argtypes['args']) == 1 and argtypes['args'][0] is None:
226 | aspec = argspec(f)
227 | argtypes['args'] = tuple(map(lambda x: x[1], sorted(
228 | aspec.annotations.items(), key=lambda x: aspec.args.index(x[0]))))
229 | self._types[f.__name__][argtypes['args']] = f
230 | _newf.__name__ = f.__name__
231 | _newf.__doc__ = f.__doc__
232 | _newf.__is_overload__ = True
233 | _newf.__orig_arg__ = argspec(f)
234 | if IPython:
235 | _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f)
236 | return _newf
237 | return _wrap
238 |
239 | class _RuntimeModule(object):
240 | 'Create a module object at runtime and insert it into sys.path. If called, same as :py:func:`from_objects`.'
241 | def __call__(self, *args, **kwargs):
242 | return self.from_objects(*args, **kwargs)
243 | @staticmethod
244 | @overload.argc(1)
245 | def from_objects(module_name_for_code_eval, **d):
246 | return _RuntimeModule.from_objects(module_name_for_code_eval, '', **d)
247 | @staticmethod
248 | @overload.argc(2)
249 | def from_objects(module_name_for_code_eval, docstring, **d):
250 | '''Create a module at runtime from `d`.
251 |
252 | :param name: The module name.
253 |
254 | :param docstring: Optional. The module's docstring.
255 |
256 | :param \*\*d: All the keyword args, mapped from name->value.
257 |
258 | Example: ``RuntimeModule.from_objects('name', 'doc', a=1, b=2)``'''
259 | module = types.ModuleType(module_name_for_code_eval, docstring)
260 | module.__dict__.update(d)
261 | module.__file__ = ''
262 | sys.modules[module_name_for_code_eval] = module
263 | return module
264 | @staticmethod
265 | @overload.argc(2)
266 | def from_string(module_name_for_code_eval, s):
267 | return _RuntimeModule.from_string(module_name_for_code_eval, '', s)
268 | @staticmethod
269 | @overload.argc(3)
270 | def from_string(module_name_for_code_eval, docstring, s):
271 |         '''Create a module at runtime from `s`.
272 |
273 | :param name: The module name.
274 |
275 | :param docstring: Optional. The module docstring.
276 |
277 | :param s: A string containing the module definition.'''
278 | g = {}
279 | _exec(s, g)
280 | return _RuntimeModule.from_objects(module_name_for_code_eval, docstring, **dict(filter(lambda x: x[0] not in g_backup, g.items())))
281 |
282 | RuntimeModule = _RuntimeModule()
283 |
284 | class CaseObject(object):
285 | 'The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. It can be called with multiple parameters, in which case it checks if its value equals any of the arguments.'
286 | def __init__(self, value):
287 | self.value = value
288 | self.did_match = False
289 | self.did_pass = False
290 | def __call__(self, *args):
291 | if assign('res', not self.did_pass and any([self.value == rhs for rhs in args])):
292 | self.did_match = True
293 | return res
294 | def quit(self):
295 |         'Forces all other calls to return False. Equivalent of a ``break`` statement.'
296 | self.did_pass = True
297 | def default(self):
298 | "Executed if quit wasn't called."
299 | return not self.did_match and not self.did_pass
300 | def __iter__(self):
301 | yield self
302 | def __enter__(self):
303 | return self
304 | def __exit__(self, *args):
305 | pass
306 |
307 | def switch(value):
308 | '''A Python switch statement implementation that is used with a ``with`` statement.
309 |
310 | :param value: The value to "switch".
311 |
312 | ``with`` statement example::
313 |
314 | with switch('x'):
315 | if case(1): print 'Huh?'
316 | if case('x'): print 'It works!!!'
317 |
318 |     .. warning:: If you modify a variable named "case" in the same scope that you use the ``with`` statement version, you will get an UnboundLocalError. The solution is to use ``with switch('x') as case:`` instead of ``with switch('x'):``.'''
319 | res = CaseObject(value)
320 | inspect.stack()[1][0].f_globals['case'] = res
321 | return res
322 |
323 | def tail_recurse(spec=None):
324 | '''Remove tail recursion from a function.
325 |
326 | :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. If ``None,`` tail recursion is always called unless the function returns a value.
327 |
328 | .. note::
329 |
330 |         This function has a slight overhead that is noticeable when using timeit. Only use it if the function has a possibility of going over the recursion limit.
331 |
332 | .. warning::
333 |
334 | This function will BREAK any code that either uses any recursion other than tail recursion or calls itself multiple times. For example, ``def x(): return x()+1`` will fail.
335 |
336 | Example::
337 |
338 | @tail_recurse()
339 | def add(a, b):
340 | if a == 0: return b
341 | return add(a-1, b+1)
342 |
343 | add(10000000, 1) # Doesn't max the recursion limit.
344 | '''
345 | def _wrap(f):
346 | class TailRecursion(Exception):
347 | def __init__(self, args, kwargs):
348 | self.args = args
349 | self.kwargs = kwargs
350 | def _newf(*args, **kwargs):
351 | if inspect.stack()[1][3] == f.__name__:
352 | if (spec and spec(args)) or not spec:
353 | raise TailRecursion(args, kwargs)
354 | while True:
355 | try:
356 | res = f(*args, **kwargs)
357 | except TailRecursion as ex:
358 | args = ex.args
359 | kwargs = ex.kwargs
360 | continue
361 | else:
362 | return res
363 | _newf.__doc__ = f.__doc__
364 | return _newf
365 | return _wrap
366 |
367 | def annotate(*args, **kwargs):
368 | '''Set function annotations using decorators.
369 |
370 | :param args: This is a list of annotations for the function, in the order of the function's parameters. For example, ``annotate('Annotation 1', 'Annotation 2')`` will set the annotations of parameter 1 of the function to ``Annotation 1``.
371 |
372 |     :param kwargs: This is a mapping of argument names to annotations. Note that these are applied *after* the argument list, so any args set that way will be overridden by this mapping. If there is a key named `ret`, that will be the annotation for the function's return value.
373 |
374 | .. deprecated:: 0.5
375 | Use :func:`fannotate` instead.
376 | '''
377 | def _wrap(f):
378 | if not hasattr(f, '__annotations__'):
379 | f.__annotations__ = {}
380 | if 'ret' in kwargs:
381 | f.__annotations__['return'] = kwargs.pop('ret')
382 | f.__annotations__.update(dict(zip(argspec(f).args, args)))
383 | f.__annotations__.update(kwargs)
384 | return f
385 | return _wrap
386 |
387 | def fannotate(*args, **kwargs):
388 | '''Set function annotations using decorators.
389 |
390 | :param \*args: The first positional argument is used for the function's return value; all others are discarded.
391 |
392 | :param \**kwargs: This is a mapping of argument names to annotations.
393 |
394 | Example::
395 |
396 | @fannotate('This for the return value', a='Parameter a', b='Parameter b')
397 | def x(a, b):
398 | pass
399 |
400 | '''
401 | def _wrap(f):
402 | if not hasattr(f, '__annotations__'):
403 | f.__annotations__ = {}
404 | if len(args) >= 1:
405 | f.__annotations__['return'] = args[0]
406 | f.__annotations__.update(kwargs)
407 | return f
408 | return _wrap
409 |
410 | def safe_unpack(seq, ln, fill=None):
411 | '''Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. Empty values will be filled in with `fill`, while any extra values will be cut off.
412 |
413 | :param seq: The sequence to unpack.
414 |
415 | :param ln: The expected length of the sequence.
416 |
417 | :param fill: The value to substitute if the sequence is too small. Defaults to ``None``.
418 |
419 | Example::
420 |
421 | s = 'a:b'
422 | a, b = safe_unpack(s.split(':'), 2)
423 | # a = 'a'
424 | # b = 'b'
425 | s = 'a'
426 | a, b = safe_unpack(s.split(':'), 2)
427 | # a = 'a'
428 | # b = None'''
429 | if len(seq) > ln:
430 | return seq[:ln]
431 | elif len(seq) < ln:
432 | return seq + type(seq)([fill]*(ln-len(seq)))
433 | else:
434 | return seq
435 |
436 | def assign(varname, value):
437 | '''Assign `value` to `varname` and return it. If `varname` is an attribute and the instance name it belongs to is not defined, a NameError is raised.
438 | This can be used to emulate assignment as an expression. For example, this::
439 |
440 | if assign('x', 7): ...
441 |
442 |     is equivalent to this C code::
443 |
444 | if (x = 7) ...
445 |
446 | .. warning::
447 |
448 | When assigning an attribute, the instance it belongs to MUST be declared as global prior to the assignment. Otherwise, the assignment will not work.
449 | '''
450 | fd = inspect.stack()[1][0].f_globals
451 | if '.' not in varname:
452 | fd[varname] = value
453 | else:
454 | vsplit = list(map(str.strip, varname.split('.')))
455 | if vsplit[0] not in fd:
456 | raise NameError('Unknown object: %s'%vsplit[0])
457 | base = fd[vsplit[0]]
458 | for x in vsplit[1:-1]:
459 | base = getattr(base, x)
460 | setattr(base, vsplit[-1], value)
461 | return value
462 |
463 | def is_main(frame=1):
464 | "Return if the caller is main. Equilavent to ``__name__ == '__main__'``."
465 | return inspect.stack()[frame][0].f_globals['__name__'] == '__main__'
466 |
467 | def _call_if_main(frame, f, args):
468 | if is_main(frame): return f(*args)
469 |
470 | def call_if_main(f,*args):
471 | "Call the `f` with `args` if the caller's module is main."
472 | return _call_if_main(3,f,args)
473 |
474 | def run_main(f,*args):
475 | "Call `f` with the `args` and terminate the program with its return code if the caller's module is main."
476 | sys.exit(_call_if_main(3,f,args))
477 |
478 | def compare_and_swap(var, compare, new):
479 | "If `var` is equal to `compare`, set it to `new`."
480 | if assign('v', inspect.stack()[1][0].f_globals)[var] == compare:
481 | v[var] = new
482 |
--------------------------------------------------------------------------------
/src/mcts/models/HuggingFace_API.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 |
3 | import torch
4 | from transformers import (
5 | GenerationConfig,
6 | AutoModelForCausalLM,
7 | AutoTokenizer,
8 | )
9 | from tqdm import tqdm
10 | import torch.nn.functional as F
11 | import numpy as np
12 |
13 |
14 | def load_HF_model(ckpt) -> tuple:
15 | tokenizer = AutoTokenizer.from_pretrained(ckpt)
16 | model = AutoModelForCausalLM.from_pretrained(
17 | ckpt,
18 | load_in_8bit=False,
19 | torch_dtype=torch.float16,
20 | device_map="auto",
21 | trust_remote_code=True,
22 | )
23 | return tokenizer, model
24 |
25 |
26 | def generate_with_HF_model(
27 | tokenizer, model, input=None, temperature=0.8, top_p=0.95, top_k=40, num_beams=1, max_new_tokens=128, **kwargs
28 | ):
29 | try:
30 | inputs = tokenizer(input, return_tensors="pt")
31 | input_ids = inputs["input_ids"].to("cuda")
32 | generation_config = GenerationConfig(
33 | do_sample=True,
34 | temperature=temperature,
35 | top_p=top_p,
36 | top_k=top_k,
37 | num_beams=num_beams,
38 | **kwargs,
39 | )
40 | with torch.no_grad():
41 | generation_output = model.generate(
42 | input_ids=input_ids,
43 | generation_config=generation_config,
44 | return_dict_in_generate=True,
45 | output_scores=True,
46 | max_new_tokens=max_new_tokens,
47 | do_sample=True,
48 | pad_token_id=tokenizer.eos_token_id,
49 | )
50 | s = generation_output.sequences[0]
51 | output = tokenizer.decode(s)
52 |     except Exception as e:
53 |         raise RuntimeError(f"HuggingFace generation failed: {e}") from e
54 |     return output
55 |
--------------------------------------------------------------------------------
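Illustrative usage of the helpers above; the checkpoint is the same one used in the `vLLM_API.py` demo further down, and a CUDA device is assumed because the input ids are moved to `"cuda"`:

```python
# Illustrative sketch; assumes src/mcts is on PYTHONPATH and a GPU is available.
from models.HuggingFace_API import load_HF_model, generate_with_HF_model

ckpt = "deepseek-ai/deepseek-coder-1.3b-instruct"
tokenizer, model = load_HF_model(ckpt)

text = generate_with_HF_model(
    tokenizer, model,
    input="# Write a python function that returns the sum of two numbers\n",
    temperature=0.8, max_new_tokens=128,
)
print(text)
```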
/src/mcts/models/IO_System.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 |
3 | import sys
4 | from transformers import AutoTokenizer, AutoModelForCausalLM
5 | import torch
6 | from run_src.rstar_utils import time_decorator
7 | sys.path.append(".")
8 |
9 | from typing import List, Dict
10 |
11 | try:
12 | from models.vLLM_API import generate_with_vLLM_model
13 | except:
14 | pass
15 |
16 | try:
17 | from models.OpenAI_API import generate_n_with_OpenAI_model
18 | except:
19 | pass
20 |
21 |
22 | class IO_System:
23 | """Input/Output system"""
24 |
25 | def __init__(self, args, tokenizer, model) -> None:
26 | self.api = args.api
27 | if self.api == "together":
28 | assert tokenizer is None and model is None
29 | elif self.api == "gpt3.5-turbo":
30 | assert tokenizer is None and isinstance(model, str)
31 | self.model_ckpt = args.model_ckpt
32 | self.temperature = args.temperature
33 | self.top_k = args.top_k
34 | self.top_p = args.top_p
35 | self.tokenizer = tokenizer
36 | self.model = model
37 |
38 | self.call_counter = 0
39 | self.token_counter = 0
40 |
41 |
42 | @time_decorator
43 | def generate(self, model_input, max_tokens: int, num_return: int, stop_tokens):
44 | if isinstance(model_input, str):
45 | if self.api == "vllm":
46 | vllm_response = generate_with_vLLM_model(
47 | self.model,
48 | input=model_input,
49 | temperature=self.temperature,
50 | top_p=self.top_p,
51 | top_k=self.top_k,
52 | n=num_return,
53 | max_tokens=max_tokens,
54 | stop=stop_tokens,
55 | )
56 | io_output_list = [o.text for o in vllm_response[0].outputs]
57 | self.call_counter += 1
58 | self.token_counter += sum([len(o.token_ids) for o in vllm_response[0].outputs])
59 |
60 | elif self.api == "OpenAI":
61 | gpt_response = generate_n_with_OpenAI_model(
62 | prompt=model_input,
63 | n=num_return,
64 | model_ckpt=self.model,
65 | max_tokens=max_tokens,
66 | max_completion_tokens=max_tokens,
67 | temperature=self.temperature,
68 | top_p=self.top_p,
69 | top_k=self.top_k,
70 | # stop=["\n", "Answer"],
71 | stop = stop_tokens,
72 | )
73 | io_output_list = gpt_response
74 | self.call_counter += num_return
75 | self.token_counter += 0
76 | elif self.api == "debug":
77 | io_output_list = ["Debug: The answer is generated with debug mode, 233." for _ in range(num_return)]
78 | else:
79 | raise NotImplementedError(f"API {self.api} is not implemented.")
80 | elif isinstance(model_input, list):
81 | if self.api == "vllm":
82 | vllm_response = generate_with_vLLM_model(
83 | self.model,
84 | input=model_input,
85 | temperature=self.temperature,
86 | top_p=self.top_p,
87 | top_k=self.top_k,
88 | n=num_return,
89 | max_tokens=max_tokens,
90 | stop=stop_tokens,
91 | )
92 | io_output_list = [
93 | [o.text for o in resp_to_single_input.outputs] for resp_to_single_input in vllm_response
94 | ]
95 | self.call_counter += 1
96 | self.token_counter += sum(
97 | [
98 | sum([len(o.token_ids) for o in resp_to_single_input.outputs])
99 | for resp_to_single_input in vllm_response
100 | ]
101 | )
102 | elif self.api == "gpt3.5-turbo":
103 | io_output_list = []
104 | for input in model_input:
105 | gpt_response = generate_n_with_OpenAI_model(
106 | prompt=input,
107 | n=num_return,
108 | model_ckpt=self.model,
109 | max_tokens=max_tokens,
110 | temperature=self.temperature,
111 | top_p=self.top_p,
112 | top_k=self.top_k,
113 | stop=["\n", "Answer"],
114 | )
115 | io_output_list.append(gpt_response)
116 | self.call_counter += num_return
117 | self.token_counter += 0
118 | elif self.api == "debug":
119 | io_output_list = [
120 | ["Debug: The answer is generated with debug mode, 233." for _ in range(num_return)]
121 | for _ in model_input
122 | ]
123 | else:
124 | raise NotImplementedError(f"API {self.api} is not implemented.")
125 |
126 | return io_output_list
127 |
--------------------------------------------------------------------------------
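A minimal sketch of exercising `IO_System` without loading any model, via its `debug` backend. The argparse namespace is faked with `SimpleNamespace` here (real runs build it in `common/arguments.py`), and it assumes `src/mcts` is the working directory so the `run_src`/`models` imports resolve and `time_decorator` simply forwards the call:

```python
# Illustrative sketch, not part of the repository.
from types import SimpleNamespace
from models.IO_System import IO_System

# Only the fields IO_System reads are filled in; the "debug" api needs no tokenizer or model.
args = SimpleNamespace(api="debug", model_ckpt=None, temperature=0.8, top_k=40, top_p=0.95)
io = IO_System(args, tokenizer=None, model=None)

outputs = io.generate("Solve: 1 + 1 = ?", max_tokens=32, num_return=2, stop_tokens=None)
print(outputs)  # two debug-mode placeholder strings
```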
/src/mcts/models/OpenAI_API.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 |
3 | import os
4 | import time
5 | from tqdm import tqdm
6 | import concurrent.futures
7 | from openai import AzureOpenAI
8 |
9 | client = AzureOpenAI(
10 | api_version="",
11 | azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", ""),
12 | api_key=os.environ.get("AZURE_OPENAI_API_KEY", ""),
13 | )
14 |
15 | max_threads = 32
16 |
17 |
18 | def load_OpenAI_model(model):
19 | return None, model
20 |
21 |
22 | def generate_with_OpenAI_model(
23 | prompt,
24 | model_ckpt="gpt-35-turbo",
25 | max_tokens=2048,
26 | max_completion_tokens=4096,
27 | temperature=0.8,
28 | top_k=40,
29 | top_p=0.95,
30 | stop=["\n"],
31 | ):
32 | messages = [{"role": "user", "content": prompt}]
33 | parameters = {
34 | "model": model_ckpt,
35 | "temperature": temperature,
36 | "max_tokens": max_tokens,
37 | "max_completion_tokens": max_completion_tokens,
38 | "top_p": top_p,
39 | "stop": stop,
40 | "seed": 1,
41 | }
42 |
43 | ans, timeout = "", 5
44 | while not ans:
45 | try:
46 | time.sleep(timeout)
47 | completion = client.chat.completions.create(messages=messages, **parameters)
48 | ans = completion.choices[0].message.content
49 |
50 | except Exception as e:
51 | print(e)
52 | if not ans:
53 | timeout = timeout * 2
54 | if timeout > 120:
55 | timeout = 1
56 | try:
57 | print(f"Will retry after {timeout} seconds ...")
58 | except:
59 | pass
60 | return ans
61 |
62 |
63 | def generate_n_with_OpenAI_model(
64 | prompt,
65 | n=1,
66 | model_ckpt="gpt-35-turbo",
67 | max_tokens=4096,
68 | max_completion_tokens=4096,
69 | temperature=0.8,
70 | top_k=40,
71 | top_p=0.95,
72 | stop=["\n"],
73 | max_threads=16,
74 | disable_tqdm=True,
75 | ):
76 | preds = []
77 | with concurrent.futures.ProcessPoolExecutor(max_workers=max_threads) as executor:
78 | futures = [
79 | executor.submit(generate_with_OpenAI_model, prompt, model_ckpt, max_tokens, max_completion_tokens, temperature, top_k, top_p, stop)
80 | for _ in range(n)
81 | ]
82 | for i, future in tqdm(
83 | enumerate(concurrent.futures.as_completed(futures)),
84 | total=len(futures),
85 | desc="running evaluate",
86 | disable=disable_tqdm,
87 | ):
88 | ans = future.result()
89 | preds.append(ans)
90 | return preds
91 |
--------------------------------------------------------------------------------
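A sketch of calling the Azure-backed helper above. The endpoint and key must be set before the module is imported because the client is created at import time, and the blank `api_version` in the module also needs to be filled in for a real deployment (all values below are placeholders):

```python
# Illustrative sketch with placeholder credentials, not part of the repository.
import os

os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<your-resource>.openai.azure.com/"
os.environ["AZURE_OPENAI_API_KEY"] = "<your-key>"

from models.OpenAI_API import generate_n_with_OpenAI_model

if __name__ == "__main__":  # ProcessPoolExecutor needs a main guard on spawn-based platforms
    answers = generate_n_with_OpenAI_model("Write a one-line Python hello world.", n=2, stop=None)
    print(answers)
```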
/src/mcts/models/vLLM_API.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 |
3 | from vllm import LLM, SamplingParams
4 | from transformers import AutoTokenizer
5 | import numpy as np
6 | import math
7 |
8 |
9 | def load_vLLM_model(model_ckpt, seed, tensor_parallel_size=1, half_precision=False, max_num_seqs=256):
10 | tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
11 |
12 | if half_precision:
13 | llm = LLM(
14 | model=model_ckpt,
15 | dtype="half",
16 | tensor_parallel_size=tensor_parallel_size,
17 | seed=seed,
18 | trust_remote_code=True,
19 | max_num_seqs=max_num_seqs,
20 | swap_space=16,
21 | )
22 | else:
23 | llm = LLM(
24 | model=model_ckpt,
25 | tensor_parallel_size=tensor_parallel_size,
26 | seed=seed,
27 | trust_remote_code=True,
28 | max_num_seqs=max_num_seqs,
29 | swap_space=16,
30 | # max_model_len=14336,
31 | )
32 |
33 | return tokenizer, llm
34 |
35 |
36 | def generate_with_vLLM_model(
37 | model,
38 | input,
39 | temperature=0.8,
40 | top_p=0.95,
41 | top_k=40,
42 | repetition_penalty=1.1,
43 | n=1,
44 | max_tokens=256,
45 | logprobs=1,
46 | stop=[],
47 | ):
48 | sampling_params = SamplingParams(
49 | temperature=temperature,
50 | top_p=top_p,
51 | top_k=top_k,
52 | repetition_penalty=repetition_penalty,
53 | n=n,
54 | logprobs=logprobs,
55 | max_tokens=max_tokens,
56 | stop=stop,
57 | )
58 |
59 | output = model.generate(input, sampling_params, use_tqdm=False)
60 | return output
61 |
62 |
63 | if __name__ == "__main__":
64 | model_ckpt = "deepseek-ai/deepseek-coder-1.3b-instruct"
65 | tokenizer, model = load_vLLM_model(model_ckpt, seed=42, tensor_parallel_size=1, half_precision=False)
66 | input = "write a python function to calculate the sum of two numbers"
67 | output = generate_with_vLLM_model(model, input)
68 | # breakpoint()
69 | print("-------------------------------------")
70 | print(output[0].outputs[0].text)
71 |
--------------------------------------------------------------------------------
/src/mcts/prompts/TACO/examples.txt:
--------------------------------------------------------------------------------
1 | ### Question
2 | In this problem you have to implement an algorithm to defragment your hard disk. The hard disk consists of a sequence of clusters, numbered by integers from 1 to n. The disk has m recorded files, the i-th file occupies clusters with numbers ai, 1, ai, 2, ..., ai, ni. These clusters are not necessarily located consecutively on the disk, but the order in which they are given corresponds to their sequence in the file (cluster ai, 1 contains the first fragment of the i-th file, cluster ai, 2 has the second fragment, etc.). Also the disc must have one or several clusters which are free from files.
3 |
4 | You are permitted to perform operations of copying the contents of cluster number i to cluster number j (i and j must be different). Moreover, if the cluster number j used to keep some information, it is lost forever. Clusters are not cleaned, but after the defragmentation is complete, some of them are simply declared unusable (although they may possibly still contain some fragments of files).
5 |
6 | Your task is to use a sequence of copy operations to ensure that each file occupies a contiguous area of memory. Each file should occupy a consecutive cluster section, the files must follow one after another from the beginning of the hard disk. After defragmentation all free (unused) clusters should be at the end of the hard disk. After defragmenting files can be placed in an arbitrary order. Clusters of each file should go consecutively from first to last. See explanatory examples in the notes.
7 |
8 | Print the sequence of operations leading to the disk defragmentation. Note that you do not have to minimize the number of operations, but it should not exceed 2n.
9 |
10 | Input
11 |
12 | The first line contains two integers n and m (1 ≤ n, m ≤ 200) — the number of clusters and the number of files, correspondingly. Next m lines contain descriptions of the files. The first number in the line is ni (ni ≥ 1), the number of clusters occupied by the i-th file. Then follow ni numbers ai, 1, ai, 2, ..., ai, ni (1 ≤ ai, j ≤ n). It is guaranteed that each cluster number occurs not more than once and that there exists at least one unused cluster. Numbers on each line are separated by spaces.
13 |
14 | Output
15 |
16 | In the first line print a single integer k (0 ≤ k ≤ 2n) — the number of operations needed to defragment the disk. Next k lines should contain the operations' descriptions as "i j" (copy the contents of the cluster number i to the cluster number j).
17 |
18 | Examples
19 |
20 | Input
21 |
22 | 7 2
23 | 2 1 2
24 | 3 3 4 5
25 |
26 |
27 | Output
28 |
29 | 0
30 |
31 |
32 | Input
33 |
34 | 7 2
35 | 2 1 3
36 | 3 2 4 5
37 |
38 |
39 | Output
40 |
41 | 3
42 | 2 6
43 | 3 2
44 | 6 3
45 |
46 | Note
47 |
48 | Let's say that a disk consists of 8 clusters and contains two files. The first file occupies two clusters and the second file occupies three clusters. Let's look at examples of correct and incorrect positions of files after defragmentation.
49 |
50 |
51 |
52 | Example 2: each file must occupy a contiguous area of memory.
53 |
54 | Example 3: the order of files to each other is not important, at first the second file can be written, and then — the first one.
55 |
56 | Example 4: violating the order of file fragments to each other is not allowed.
57 |
58 | Example 5: unused clusters should be located at the end, and in this example the unused clusters are 3, 7, 8.
59 |
60 | ### Solution
61 | Let's think step by step.
62 |
63 | ### Step 1: **Defining High-Level Structure (Action 1)**
64 |
65 | In this first step, we define the high-level structure of the solution. Our objective is to outline how to store the files and clusters, track which clusters are in use, and prepare the basic setup for further refinement.
66 |
67 | #### Goal:
68 | 1. Parse the input to retrieve the number of clusters `n` and files `m`.
69 | 2. Initialize a `memory` array to track which files occupy which clusters (initialized to `-1` for unused clusters).
70 | 3. Store each file's cluster allocation.
71 |
72 | #### Pseudocode:
73 | ```pseudo
74 | Function defragmentation(n, m, files):
75 | Initialize memory array of size n+1 with -1 (to represent unused clusters)
76 | Initialize an empty list to store the files and their respective clusters
77 | For each file in files:
78 | Store its clusters in the list and mark them in the memory array
79 | Return the initialized memory and files list
80 | ```
81 |
82 | ---
83 |
84 | ### Step 2: **Finding Free Clusters (Action 2)**
85 |
86 | In this step, we refine the pseudocode to handle finding the first free cluster on the disk. This is necessary because we need free clusters to move file fragments to contiguous positions.
87 |
88 | #### Goal:
89 | 1. Identify the first available free cluster.
90 | 2. Return the index of this free cluster for use in later operations.
91 |
92 | #### Pseudocode:
93 | ```pseudo
94 | Function find_free_cluster(memory):
95 | For i from 1 to n:
96 | If memory[i] == -1:
97 | Return i // Return first free cluster found
98 | Return -1 // If no free cluster is found
99 | ```
100 |
101 | ---
102 |
103 | ### Step 3: **Defragmenting Files (Action 2)**
104 |
105 | Now, we move to the main part of the problem: defragmenting the files. This involves moving file fragments to available contiguous clusters. We will loop through the files and their respective clusters, identifying fragments that are not in their correct position and moving them accordingly.
106 |
107 | #### Goal:
108 | 1. Check if the current file's clusters are contiguous.
109 | 2. If not, use the free clusters to move fragments into place.
110 | 3. Track the sequence of operations that result in the defragmentation.
111 |
112 | #### Pseudocode:
113 | ```pseudo
114 | Function defragment_files(memory, files):
115 | Initialize result list to store operations
116 | Initialize cnt to 0 // To track the expected position of file fragments
117 | For each file in files:
118 | For each cluster in file:
119 | If cluster is not in the correct position (i != cnt):
120 | Find the free cluster
121 | Append operation (move cluster to the free position)
122 | Update memory to reflect the changes
123 | Update free cluster to next available cluster
124 | Update the file list accordingly
125 | Return the list of operations
126 | ```
127 |
128 | ---
129 |
130 | ### Step 4: **Refining File Movement and Operations (Action 2)**
131 |
132 | In this step, we refine the logic for moving the files and ensuring that no clusters are overwritten during the defragmentation process. We will also ensure that the correct operations are recorded.
133 |
134 | #### Goal:
135 | 1. Handle the case where we need to move fragments from a used cluster to a free cluster.
136 | 2. Ensure that the file list and memory are updated correctly during each operation.
137 |
138 | #### Pseudocode:
139 | ```pseudo
140 | For each file in files:
141 | For each cluster in file:
142 | If cluster is not in the correct place:
143 | Find a free cluster
144 | Perform a copy operation from cluster to free cluster
145 | Record the operation
146 | Update memory and file list
147 | Ensure no overwriting of other files' clusters
148 | Return the list of operations
149 | ```
150 |
151 | ---
152 |
153 | ### Step 5: **Generating Python Code (Action 3)**
154 |
155 | Now that the pseudocode has been fully refined, we proceed to translate it into Python code. This involves implementing the main function that drives the defragmentation process, handling input/output correctly, and printing the sequence of operations.
156 |
157 | #### Goal:
158 | 1. Implement the `defragmentation` function in Python.
159 | 2. Ensure the input is parsed correctly and the output matches the expected format.
160 | 3. Ensure the program handles the movement of clusters and generates the correct sequence of operations.
161 |
162 | #### Python Code:
163 |
164 | ```python
165 | def defragmentation(n, m, files):
166 | # Initialize memory array with -1 (representing unused clusters)
167 | memory = [-1] * (n + 1)
168 | d = []
169 |
170 | # Mark memory with files' cluster allocations
171 | for i in range(m):
172 | file = files[i][1:] # Ignore first value (size of file)
173 | for cluster in file:
174 | memory[cluster] = i # Mark the memory with the file index
175 | d.append(file)
176 |
177 | free = -1
178 | # Find the first free cluster
179 | for i in range(1, n + 1):
180 | if memory[i] == -1:
181 | free = i
182 | break
183 |
184 | res = [] # List to store operations
185 | cnt = 0
186 |
187 | # Start defragmenting the files
188 | for x in range(m):
189 | for i in range(len(d[x])):
190 | cnt += 1
191 | if d[x][i] == cnt:
192 | continue # File fragment is already in the correct position
193 |
194 | y = memory[cnt]
195 | if y == -1:
196 | # Move fragment to the free cluster
197 | res.append((d[x][i], cnt))
198 | memory[d[x][i]] = -1
199 | free = d[x][i]
200 | d[x][i] = cnt
201 | memory[cnt] = x
202 | continue
203 |
204 | # Move fragments to fill gaps
205 | for j in range(len(d[y])):
206 | if d[y][j] == cnt:
207 | res.append((cnt, free))
208 | res.append((d[x][i], cnt))
209 | d[y][j] = free
210 | memory[free] = y
211 | free = d[x][i]
212 | memory[free] = -1
213 | d[x][i] = cnt
214 | memory[cnt] = x
215 |
216 | # Output the result
217 | print(len(res))
218 | for op in res:
219 | print(op[0], op[1])
220 |
221 | n, m = map(int, input().split())
222 | files = [list(map(int, input().split())) for _ in range(m)]
223 | defragmentation(n, m, files)
224 | ```
225 |
226 |
227 |
228 | ### Question
229 | A gene is represented as a string of length $n$ (where $n$ is divisible by $4$), composed of the letters $\mbox{A}$, $\mbox{C}$, $\textbf{T}$, and $\mbox{G}$.
230 | It is considered to be steady if each of the four letters occurs exactly $\frac{n}{4}$ times. For example, $\textbf{GACT}$ and $\textbf{AAGTGCCT}$ are both steady genes.
231 |
232 | Bear Limak is a famous biotechnology scientist who specializes in modifying bear DNA to make it steady. Right now, he is examining a gene represented as a string $\textit{gene}$. It is not necessarily steady. Fortunately, Limak can choose one (maybe empty) substring of $\textit{gene}$ and replace it with any string of the same length.
233 |
234 | Modifying a large substring of bear genes can be dangerous.
235 | Given a string $\textit{gene}$, can you help Limak find the length of the smallest possible substring that he can replace to make $\textit{gene}$ a steady gene?
236 |
237 | Note: A substring of a string $\boldsymbol{\mathrm{~S~}}$ is a subsequence made up of zero or more contiguous characters of $\boldsymbol{\mathrm{~S~}}$.
238 |
239 | As an example, consider $\textit{gene}=\textit{ACTGAAAG}$. The substring $\boldsymbol{AA}$ just before or after $\mbox{G}$ can be replaced with $\textit{CT}$ or $\boldsymbol{TC}$. One selection would create $\textbf{ACTGACTG}$.
240 |
241 | Function Description
242 |
243 | Complete the $\textit{steadyGene}$ function in the editor below. It should return an integer that represents the length of the smallest substring to replace.
244 |
245 | steadyGene has the following parameter:
246 |
247 | gene: a string
248 |
249 | Input Format
250 |
251 | The first line contains an integer $n$ divisible by $4$, that denotes the length of a string $\textit{gene}$.
252 |
253 | The second line contains a string $\textit{gene}$ of length $n$.
254 |
255 | Constraints
256 |
257 | $4\leq n\leq500000$
258 | $n$ is divisible by $4$
259 | $gene[i]\in[CGAT]$
260 |
261 | Subtask
262 |
263 | $4\leq n\leq2000$ in tests worth $30\%$ points.
264 |
265 | Output Format
266 |
267 | Print the length of the minimum length substring that can be replaced to make $\textit{gene}$ stable.
268 |
269 | Sample Input
270 | 8
271 | GAAATAAA
272 |
273 | Sample Output
274 | 5
275 |
276 | Explanation
277 |
278 | One optimal solution is to replace $\textbf{AAATA}$ with $\textbf{TTCCG}$ resulting in $\textbf{GTTCCGAA}$.
279 |
280 | The replaced substring has length $5$.
281 |
282 | ### Solution
283 | Let's think step by step.
284 |
285 | ### Step 1: **Defining High-Level Structure (Action 1)**
286 |
287 | #### Goal:
288 | Define the overall solution framework and how the problem will be approached:
289 | 1. Parse the input to get the string and calculate the target frequency of each nucleotide.
290 | 2. Identify imbalances (if any) in nucleotide counts.
291 | 3. Use a sliding window to find the smallest substring that can balance the nucleotide counts when replaced.
292 |
293 | #### Pseudocode:
294 | ```pseudo
295 | Function steadyGene(gene, n):
296 | target = n / 4 # Each nucleotide should appear n/4 times
297 | Count the frequency of each nucleotide in the gene
298 | If all nucleotide counts are <= target:
299 | Return 0 # Gene is already steady
300 |
301 | Use a sliding window approach:
302 | Initialize start and end of the window
303 | Track the frequency of nucleotides in the current window
304 | Expand and shrink the window to find the minimal length substring that can be replaced
305 |
306 | Return the minimal length of the valid substring
307 | ```
308 |
309 | ---
310 |
311 | ### Step 2: **Calculating Frequencies and Imbalances (Action 2)**
312 |
313 | #### Goal:
314 | 1. Count the frequency of each nucleotide in the gene.
315 | 2. Identify nucleotides that exceed the target frequency and need balancing.
316 |
317 | #### Pseudocode:
318 | ```pseudo
319 | Function calculateImbalances(gene, target):
320 | Initialize a frequency dictionary for nucleotides ('A', 'C', 'T', 'G') with counts as 0
321 | For each character in the gene:
322 | Increment its count in the frequency dictionary
323 |
324 | Identify excess nucleotides:
325 | For each nucleotide in the dictionary:
326 | If its count > target:
327 | Record the excess
328 |
329 | Return the frequency dictionary and excess nucleotides
330 | ```
331 |
332 | ---
333 |
334 | ### Step 3: **Sliding Window Implementation (Action 2)**
335 |
336 | #### Goal:
337 | 1. Use a sliding window to dynamically evaluate substrings.
338 | 2. Track the frequency of nucleotides within the window.
339 | 3. Ensure the substring balances the excess nucleotides.
340 |
341 | #### Pseudocode:
342 | ```pseudo
343 | Function findMinimalSubstring(gene, n, target, full_freq):
344 | Initialize start = 0, end = 0, and current_freq = {}
345 | Initialize min_length = n (start with maximum possible length)
346 |
347 | While end < n:
348 | Expand the window by including gene[end]
349 | Update the current_freq for gene[end]
350 |
351 | While the window satisfies the balancing condition:
352 | Update min_length with the current window size
353 | Shrink the window from the start
354 | Update current_freq for gene[start]
355 | Increment start
356 |
357 | Increment end
358 |
359 | Return min_length
360 | ```
361 |
362 | ---
363 |
364 | ### Step 4: **Satisfaction Check (Action 2)**
365 |
366 | #### Goal:
367 | Ensure that the current substring can balance the nucleotide counts when replaced.
368 |
369 | #### Pseudocode:
370 | ```pseudo
371 | Function satisfiesCondition(current_freq, full_freq, target):
372 | For each nucleotide in ('A', 'C', 'T', 'G'):
373 | If full_freq[nucleotide] > target:
374 | If current_freq[nucleotide] < full_freq[nucleotide] - target:
375 | Return False
376 | Return True
377 | ```
378 |
379 | ---
380 |
381 | ### Step 5: **Implementing Python Code (Action 3)**
382 |
383 | #### Python Code:
384 | The refined pseudocode is translated into Python as follows:
385 |
386 | ```python
387 | def satisfied(freq, full_freq, n, base_pairs):
388 | for bp in base_pairs:
389 | if full_freq[bp] > n / 4:
390 | if freq[bp] < full_freq[bp] - n / 4:
391 | return False
392 | return True
393 |
394 | def steadyGene(n, gene):
395 | base_pairs = ['A', 'C', 'T', 'G']
396 |
397 | # Calculate the frequency of each nucleotide in the gene
398 | full_freq = {bp: 0 for bp in base_pairs}
399 | for nuc in gene:
400 | full_freq[nuc] += 1
401 |
402 | # Check if the gene is already steady
403 | if all(full_freq[bp] <= n / 4 for bp in base_pairs):
404 | return 0 # No replacement needed
405 |
406 | # Sliding window to find the minimal substring
407 | cur_freq = {bp: 0 for bp in base_pairs}
408 | start = 0
409 | min_length = n
410 |
411 | for end in range(n):
412 | cur_freq[gene[end]] += 1
413 |
414 |         while satisfied(cur_freq, full_freq, n, base_pairs):
415 | # Update minimal length
416 | min_length = min(min_length, end - start + 1)
417 | # Shrink the window
418 | cur_freq[gene[start]] -= 1
419 | start += 1
420 |
421 | return min_length
422 |
423 | n = int(input())
424 | gene = input()
425 | print(steadyGene(n, gene))
426 | ```
427 |
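As a quick sanity check of the final implementation, the sample case can be verified directly (a minimal sketch; it assumes the `steadyGene` and `satisfied` definitions above are in scope and skips the stdin handling):

```python
# Sample case from the problem statement: GAAATAAA (n = 8) needs a
# replacement substring of length 5 to become steady.
assert steadyGene(8, "GAAATAAA") == 5

# An already-steady gene requires no replacement at all.
assert steadyGene(4, "ACGT") == 0

print("steadyGene sanity checks passed")
```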
--------------------------------------------------------------------------------
/src/mcts/prompts/TACO/prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "prompt_template": "### Instruction\nPlease refer to the given task description and provide a thought process in the form of step-by-step pseudocode refinement.\n\nA curious user has approached you with a programming question. You should give step-by-step solutions to the user's questions. For each step you can choose one of the following three actions\n\n Defining algorithm Structures Using pseudocode\n**Description:** \nOutline the core functions and overall structure of the solution without getting into implementation details. Define inputs, outputs, and the main tasks each function will perform.\n\n Refine part of the pseudocode\n**Description:** \nAdd more details to the pseudocode, specifying the exact steps, logic, and operations each function will carry out. This prepares the pseudocode for actual coding.\n\n Generate python code from the pseudocode\n**Description:** \nTranslate the refined pseudocode into executable Python code, making sure to handle inputs, outputs, and ensure correctness in the implementation.\n\n**Note:**\n- You can choose one of the three actions for each step.\n- Provide a detailed explanation of the reasoning behind each step.\n- Try to refer to the reference code as much as possible, but you can also modify it if needed (e.g. change variable names, add some comments, etc.).\n\n### Examples\n{examples}\n\n### Question\n{question}\n\n### Solution\nLet's think step by step.\n",
3 | "stop_tokens": [
4 | "Step_End"
5 | ]
6 | }
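For reference, `Generator.generate_ost_step` in `run_src/MCTS_for_reasoning.py` fills this template before each generation call. Below is a minimal sketch of that assembly; it assumes the script is run from `src/mcts` so the relative paths resolve, and the question text is a placeholder:

```python
import json

# Load the prompt template and the few-shot examples shipped with the repo.
with open("prompts/TACO/prompt.json") as f:
    prompt = json.load(f)
with open("prompts/TACO/examples.txt") as f:
    examples = f.read()

question = "<TACO problem statement goes here>"  # placeholder
existing_ost_steps = ""                          # earlier "### Step i:" blocks, if any
next_ost_step_id = 1

# Mirrors the model input built in Generator.generate_ost_step.
io_input = (
    prompt["prompt_template"].format(examples=examples, question=question)
    + existing_ost_steps
    + f"\n### Step {next_ost_step_id}:"
)
print(io_input[:300])
```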
--------------------------------------------------------------------------------
/src/mcts/run_outputs/data_examples/answer_sheets/Question 0000 - Answer.json:
--------------------------------------------------------------------------------
1 | {"id": 0, "problem": "**Step 1:** Create a function called `encode()` to replace all the lowercase vowels in a given string with numbers according to the following pattern:\n```\na -> 1\ne -> 2\ni -> 3\no -> 4\nu -> 5\n```\n\nFor example, `encode(\"hello\")` would return `\"h2ll4\"`. There is no need to worry about uppercase vowels in this kata.\n\n**Step 2:** Now create a function called `decode()` to turn the numbers back into vowels according to the same pattern shown above.\n\nFor example, `decode(\"h3 th2r2\")` would return `\"hi there\"`.\n\nFor the sake of simplicity, you can assume that any numbers passed into the function will correspond to vowels.", "gold_solution": "def encode(s, t=str.maketrans('aeiou', '12345')):\n\treturn s.translate(t)\n\ndef decode(s, t=str.maketrans('12345', 'aeiou')):\n\treturn s.translate(t)\n", "test_case": {"fn_name": "encode", "inputs": [["How are you today?"], ["hello"], ["This is an encoding test."]], "outputs": [["H4w 1r2 y45 t4d1y?"], ["h2ll4"], ["Th3s 3s 1n 2nc4d3ng t2st."]]}, "difficulty": "EASY", "time_taken": "299.88s"}
--------------------------------------------------------------------------------
/src/mcts/run_outputs/data_examples/answer_sheets/Question 0000 - Best Solutions.json:
--------------------------------------------------------------------------------
1 | {
2 | "trace": {
3 | "0": {
4 | "user_question": "**Step 1:** Create a function called `encode()` to replace all the lowercase vowels in a given string with numbers according to the following pattern:\n```\na -> 1\ne -> 2\ni -> 3\no -> 4\nu -> 5\n```\n\nFor example, `encode(\"hello\")` would return `\"h2ll4\"`. There is no need to worry about uppercase vowels in this kata.\n\n**Step 2:** Now create a function called `decode()` to turn the numbers back into vowels according to the same pattern shown above.\n\nFor example, `decode(\"h3 th2r2\")` would return `\"hi there\"`.\n\nFor the sake of simplicity, you can assume that any numbers passed into the function will correspond to vowels.",
5 | "ost_step": {
6 | "1": "### Step 1: **Defining High-Level Structure for `encode()` and `decode()` (Action 1)**\n\nIn this first step, we define the overall structure of both the `encode()` and `decode()` functions. The goal of the `encode()` function is to replace vowels in the string with their respective numbers, while the `decode()` function will reverse this process, converting numbers back into the corresponding vowels.\n\n#### Plan for `encode()`:\n1. Loop through each character in the input string.\n2. For each lowercase vowel (`a`, `e`, `i`, `o`, `u`), replace it with the corresponding number (`a -> 1`, `e -> 2`, etc.).\n3. Keep all other characters (non-vowels) unchanged.\n\n#### Plan for `decode()`:\n1. Loop through each character in the input string.\n2. For each number (`1`, `2`, `3`, `4`, `5`), replace it with the corresponding vowel (`1 -> a`, `2 -> e`, etc.).\n3. Keep all other characters (non-numeric) unchanged.\n\n#### Pseudocode:\n```pseudo\nFunction encode(string):\n Initialize a dictionary for vowel to number mappings\n Initialize an empty string for the result\n For each character in the string:\n If the character is a vowel:\n Append the corresponding number to the result\n Else:\n Append the character as is\n Return the result string\n\nFunction decode(string):\n Initialize a dictionary for number to vowel mappings\n Initialize an empty string for the result\n For each character in the string:\n If the character is a number:\n Append the corresponding vowel to the result\n Else:\n Append the character as is\n Return the result string\n```\n\n",
7 | "2": "### Step 2: **Refine the `encode()` Function (Action 2)**\n\nIn this step, we will refine the pseudocode for the `encode()` function by detailing how to handle each character in the input string and map vowels to numbers. This includes the setup of the dictionary for encoding and the loop to process each character.\n\n#### Pseudocode:\n```pseudo\nFunction encode(string):\n Initialize a dictionary `vowel_to_number` with mappings: {'a': '1', 'e': '2', 'i': '3', 'o': '4', 'u': '5'}\n Initialize an empty string `encoded_string` to store the result\n For each character `char` in the input `string`:\n If `char` is in `vowel_to_number`:\n Append `vowel_to_number[char]` to `encoded_string`\n Else:\n Append `char` to `encoded_string`\n Return `encoded_string`\n```\n\n#### Reasoning:\n- We use a dictionary to efficiently map each vowel to its corresponding number.\n- For each character in the string, we check if it's a vowel by looking it up in the dictionary.\n- If it's a vowel, we append the mapped number to the result string.\n- If it's not a vowel, we append the character as is to maintain the integrity of non-vowel characters.",
8 | "3": "### Step 3: **Refine the `decode()` Function (Action 2)**\n\nIn this step, we refine the pseudocode for the `decode()` function. The goal is to handle the process of replacing numbers with the corresponding vowels. We will establish a dictionary for number-to-vowel mappings and refine the logic for parsing through the string.\n\n#### Pseudocode:\n```pseudo\nFunction decode(string):\n Initialize a dictionary `number_to_vowel` with mappings: {'1': 'a', '2': 'e', '3': 'i', '4': 'o', '5': 'u'}\n Initialize an empty string `decoded_string` to store the result\n For each character `char` in the input `string`:\n If `char` is in `number_to_vowel`:\n Append `number_to_vowel[char]` to `decoded_string`\n Else:\n Append `char` to `decoded_string`\n Return `decoded_string`\n```\n\n#### Reasoning:\n- We create a dictionary `number_to_vowel` to map each number (1-5) back to its respective vowel.\n- For each character in the string, we check if it's a number (1-5) by looking it up in the dictionary.\n- If it's a number, we append the corresponding vowel to the result string.\n- If it's not a number (e.g., spaces or other characters), we simply append it as is.\n\nThis approach ensures that we correctly reverse the encoding process, replacing numbers with vowels while maintaining the integrity of non-numeric characters.\n\n",
9 | "4": "### Step 4: **Generating Python Code for `encode()` and `decode()` (Action 3)**\n\nNow that the pseudocode for both `encode()` and `decode()` functions has been fully refined, we translate it into Python code. This involves implementing the logic for both encoding vowels as numbers and decoding numbers back into vowels.\n\n#### Python Code:\n```python\ndef encode(string):\n # Mapping from vowels to numbers\n vowel_to_number = {'a': '1', 'e': '2', 'i': '3', 'o': '4', 'u': '5'}\n # String to store encoded result\n encoded_string = \"\"\n # Iterate over each character in the input string\n for char in string:\n # Check if the character is a vowel and map it to the corresponding number\n if char in vowel_to_number:\n encoded_string += vowel_to_number[char]\n else:\n # Keep non-vowel characters unchanged\n encoded_string += char\n return encoded_string\n\ndef decode(string):\n # Mapping from numbers to vowels\n number_to_vowel = {'1': 'a', '2': 'e', '3': 'i', '4': 'o', '5': 'u'}\n # String to store decoded result\n decoded_string = \"\"\n # Iterate over each character in the input string\n for char in string:\n # Check if the character is a number and map it to the corresponding vowel\n if char in number_to_vowel:\n decoded_string += number_to_vowel[char]\n else:\n # Keep non-numeric characters unchanged\n decoded_string += char\n return decoded_string\n\n# Example usage:\nprint(encode(\"hello\")) # Output: \"h2ll4\"\nprint(decode(\"h3 th2r2\")) # Output: \"hi there\"\n```\n\n#### Explanation:\n- For the `encode()` function, we create a dictionary `vowel_to_number` to map vowels to numbers. We iterate over each character in the input string, replacing vowels with their corresponding numbers and keeping other characters unchanged.\n- For the `decode()` function, we create a dictionary `number_to_vowel` to map numbers back to vowels. We iterate over each character in the input string, replacing numbers with their corresponding vowels and keeping other characters unchanged.\n- The code is efficient and straightforward, utilizing dictionary lookups to perform the necessary character replacements.\n\nThis completes the translation of the refined pseudocode into Python, providing a functional implementation of the `encode()` and `decode()` operations."
10 | },
11 | "ost_step_value": {
12 | "1": 1.0,
13 | "2": 1.0,
14 | "3": 1.0,
15 | "4": 1.0
16 | }
17 | }
18 | },
19 | "rollout_id": 0,
20 | "value": 1.0
21 | }
--------------------------------------------------------------------------------
/src/mcts/run_outputs/data_examples/args.json:
--------------------------------------------------------------------------------
1 | {
2 | "note": "default",
3 | "api": "OpenAI",
4 | "seed": 42,
5 | "verbose": true,
6 | "wandb_mode": "disabled",
7 | "model_ckpt": "gpt-4o",
8 | "model_parallel": false,
9 | "half_precision": false,
10 | "max_tokens": 4096,
11 | "temperature": 0.8,
12 | "top_k": 40,
13 | "top_p": 0.95,
14 | "num_beams": 1,
15 | "max_num_worker": 3,
16 | "test_batch_size": 1,
17 | "tensor_parallel_size": 1,
18 | "prompts_root": "prompts",
19 | "data_root": "data",
20 | "dataset_name": "TACO",
21 | "test_json_filename": "test_one",
22 | "start_idx": 0,
23 | "end_idx": Infinity,
24 | "run_outputs_root": "run_outputs",
25 | "eval_outputs_root": "eval_outputs",
26 | "run_outputs_dir": "run_outputs/",
27 | "num_rollouts": 12,
28 | "max_depth_allowed": 10,
29 | "mcts_discount_factor": 1.0,
30 | "mcts_exploration_weight": 2.0,
31 | "mcts_weight_scheduler": "const",
32 | "mcts_num_last_votes": 32,
33 | "save_tree": false,
34 | "num_sampling": 3,
35 | "enable_potential_score": false,
36 | "examples_txt_path": "prompts/TACO/examples.txt",
37 | "prompt_config_path": "prompts/TACO/prompt.json",
38 | "answer_sheets_dir": "run_outputs/answer_sheets",
39 | "cuda_0": null,
40 | "cuda_1": null,
41 | "cuda_2": null,
42 | "cuda_3": null
43 | }
--------------------------------------------------------------------------------
/src/mcts/run_outputs/data_examples/intermediate_result.txt:
--------------------------------------------------------------------------------
1 | Total calls: 48, Avg calls: 48.00
2 | Total tokens: 0, Avg tokens: 0.00
3 |
--------------------------------------------------------------------------------
/src/mcts/run_src/MCTS_backbone.py:
--------------------------------------------------------------------------------
1 | """
2 | A minimal implementation of Monte Carlo tree search (MCTS) in Python 3
3 | Luke Harold Miles, July 2019, Public Domain Dedication
4 | See also https://en.wikipedia.org/wiki/Monte_Carlo_tree_search
5 | https://gist.github.com/qpwo/c538c6f73727e254fdc7fab81024f6e1
6 | """
7 |
8 | from abc import ABC, abstractmethod
9 | from collections import defaultdict
10 | from typing import Dict, List
11 | import math, random
12 |
13 |
14 | node_cnt = 0
15 |
16 |
17 | def verbose_print(s: str, verbose: bool):
18 | if verbose:
19 | print(s)
20 |
21 |
22 | class MCTS_Node(ABC):
23 | """
24 | A representation of a single board state.
25 | MCTS works by constructing a tree of these Nodes.
26 | Could be e.g. a chess or checkers board state.
27 | """
28 |
29 | def __init__(self) -> None:
30 | super().__init__()
31 |
32 | global node_cnt
33 | self.id = node_cnt
34 | node_cnt += 1
35 |
36 | self.rollout_id = None
37 |
38 | def set_rollout_id(self, rollout_id: int):
39 | self.rollout_id = rollout_id
40 |
41 | @abstractmethod
42 | def find_children(self, rollout_id: int):
43 | "All possible successors of this board state"
44 | raise NotImplementedError
45 |
46 | @abstractmethod
47 | def is_terminal(self):
48 | "Returns True if the node has no children"
49 | raise NotImplementedError
50 |
51 | @abstractmethod
52 | def calculate_reward(self):
53 | "Assumes `self` is terminal node. 1=win, 0=loss, .5=tie, etc"
54 | raise NotImplementedError
55 |
56 |
57 |
58 | class MCTS_Searcher:
59 | "Monte Carlo tree searcher. First rollout the tree then choose a move."
60 |
61 | def __init__(
62 | self,
63 | exploration_weight: float,
64 | weight_scheduler: str,
65 | num_rollouts: int,
66 | discount: float,
67 | verbose: bool = False,
68 | ):
69 | self.Q: Dict[MCTS_Node, float] = defaultdict(lambda: 0.0) # total reward of each node
70 | self.N: Dict[MCTS_Node, int] = defaultdict(lambda: 0) # total visit count for each node
71 | self.parent2children: Dict[MCTS_Node, List[MCTS_Node]] = dict() # children of each node
72 |
73 | #! explored = expanded + simulated, i.e. has seen terminal at least once, i.e. we can calculate its UCT value, i.e. has Q and N
74 | self.explored_nodes = set()
75 |
76 | self.exploration_weight = exploration_weight
77 | self.weight_scheduler = weight_scheduler
78 | self.num_rollouts = num_rollouts
79 | self.discount = discount
80 |
81 | self.verbose = verbose
82 |
83 | global node_cnt
84 | node_cnt = 0
85 |
86 | def do_rollout(self, root_node: MCTS_Node, rollout_id: int):
87 | "Make the tree one layer better. (Train for one iteration.)"
88 | verbose_print("==> Selecting a node...", self.verbose)
89 | path_1 = self._select(root_node, rollout_id)
90 | leaf = path_1[-1]
91 | verbose_print(f"==> Expanding node {leaf.id}...", self.verbose)
92 | self._expand(leaf, rollout_id)
93 | verbose_print(f"==> Simulating node {leaf.id}...", self.verbose)
94 | path_2 = self._simulate(leaf, rollout_id)
95 | verbose_print(f"==> Backpropagating...", self.verbose)
96 | self._backpropagate(path_1 + path_2)
97 | try:
98 | return path_2[-1]
99 |         except IndexError:
100 | return path_1[-1]
101 |
102 | def _select(self, node: MCTS_Node, rollout_id: int) -> List[MCTS_Node]:
103 | "Find an unexplored descendent of `node`"
104 | path = []
105 | while True:
106 | path.append(node)
107 | # case 1: a node does not have children, then select the node itself
108 | if node not in self.parent2children.keys():
109 | return path
110 |
111 | # case 2: a node has children but not all children have been explored, then randomly select an unexplored child
112 | # unexplored = set(self.parent2children[node]) - self.explored_nodes # `set` introduces randomness
113 | unexplored = [n for n in self.parent2children[node] if n not in self.explored_nodes]
114 | if unexplored:
115 | n = random.choice(unexplored)
116 | path.append(n)
117 | return path
118 |
119 | # case 3: a node has children and all children have been explored, then select one child and go to the next layer
120 | node = self._uct_select(node, rollout_id)
121 |
122 | def _expand(self, node: MCTS_Node, rollout_id: int):
123 | "Update the `children` dict with the children of `node`"
124 | if node in self.explored_nodes:
125 | return # already expanded
126 |
127 | if node.is_terminal():
128 | self.explored_nodes.add(node)
129 | return # terminal node is non-expandable
130 |
131 | self.parent2children[node] = node.find_children(rollout_id)
132 |
133 | def _simulate(self, node: MCTS_Node, rollout_id: int) -> List[MCTS_Node]:
134 | "Returns the reward for a random simulation (to completion) of `node`"
135 | path = []
136 | cur_node = node
137 | while True:
138 | if cur_node.is_terminal():
139 | self.explored_nodes.add(node)
140 | return path
141 |
142 | if cur_node not in self.parent2children.keys():
143 | self.parent2children[cur_node] = cur_node.find_children(rollout_id)
144 |
145 | cur_node = random.choice(self.parent2children[cur_node]) # randomly select a child
146 | path.append(cur_node)
147 |
148 | def _backpropagate(self, path: List[MCTS_Node]):
149 | "Send the reward back up to the ancestors of the leaf"
150 | leaf = path[-1]
151 | reward = leaf.calculate_reward()
152 | for node in reversed(path):
153 | self.Q[node] += reward
154 | self.N[node] += 1
155 | self.explored_nodes.add(node)
156 | from run_src.rstar_utils import Node_Type
157 | parent_node = None
158 | for node in path[:-1]:
159 | if node.node_type == Node_Type.USER_QUESTION:
160 | parent_node = node
161 | continue
162 | if node.node_value is None:
163 | node.node_value = 0
164 | node.node_value += reward
165 | for key, val in parent_node.solution_trace[0]["ost_step_value"].items():
166 | node.solution_trace[0]["ost_step_value"][key] = val
167 | last_key = list(node.solution_trace[0]["ost_step_value"].keys())[-1]
168 | node.solution_trace[0]["ost_step_value"][last_key] = node.node_value
169 | parent_node = node
170 | node = path[-1]
171 | for key, val in parent_node.solution_trace[0]["ost_step_value"].items():
172 | node.solution_trace[0]["ost_step_value"][key] = val
173 |
174 |
175 | def _get_weight(self, rollout_id: int):
176 | # start with exploration weight, end with 0.1 * exploration weight
177 | if self.weight_scheduler == "exp":
178 | return self.exploration_weight * (0.1 ** (rollout_id / self.num_rollouts))
179 | elif self.weight_scheduler == "lin":
180 | return self.exploration_weight * (1 - 0.9 * (rollout_id / self.num_rollouts))
181 | elif self.weight_scheduler == "const":
182 | return self.exploration_weight
183 |
184 | def _uct_select(self, node: MCTS_Node, rollout_id: int):
185 | "Select a child of node, balancing exploration & exploitation"
186 |
187 | # All children of the node should already be expanded
188 | assert all(n in self.explored_nodes for n in self.parent2children[node])
189 |
190 | return max(
191 | self.parent2children[node], key=lambda n: self._compute_uct(parent_node=node, node=n, rollout_id=rollout_id)
192 | )
193 |
194 | def _compute_uct(self, parent_node: MCTS_Node, node: MCTS_Node, rollout_id: int):
195 | "Upper confidence bound for trees"
196 | if parent_node is None: # invalid UCT: the node is the root
197 | return 666
198 | else:
199 | if self.N[node] == 0: # invalid UCT: the node has not been explored yet
200 | return 999
201 | else:
202 | weight = self._get_weight(rollout_id)
203 | return self.Q[node] / self.N[node] + weight * math.sqrt(math.log(self.N[parent_node]) / self.N[node])
204 |
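The score computed in `_compute_uct` is the standard UCT rule, Q(n)/N(n) + c * sqrt(ln N(parent) / N(n)), with the sentinel values 666 for the root and 999 for unvisited children; a minimal standalone illustration with made-up statistics:

```python
import math

def uct(q_child, n_child, n_parent, c=2.0):
    # Same formula as MCTS_Searcher._compute_uct for an already-visited child.
    return q_child / n_child + c * math.sqrt(math.log(n_parent) / n_child)

# Two hypothetical children of a parent that has been visited 10 times:
print(round(uct(q_child=4.0, n_child=8, n_parent=10), 3))  # often visited, high average reward
print(round(uct(q_child=0.5, n_child=1, n_parent=10), 3))  # rarely visited, exploration bonus dominates
```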
--------------------------------------------------------------------------------
/src/mcts/run_src/MCTS_for_reasoning.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 |
3 | import sys
4 |
5 | sys.path.append(".")
6 |
7 | import numpy as np
8 | import os, random, json, math
9 | import wandb
10 | from tqdm import trange
11 | from typing import List, Dict, Tuple
12 | from copy import deepcopy
13 |
14 |
15 | from models.IO_System import IO_System
16 | from common.utils import read_txt, read_json
17 | from eval_src.Evaluator import Evaluator
18 | from MCTS_backbone import MCTS_Searcher, MCTS_Node
19 | from run_src.rstar_utils import (
20 | Node_Type,
21 | reach_terminal_ost_step,
22 | concat_ost_steps,
23 | ost_find_best_solution,
24 | find_solution,
25 | time_decorator,
26 | print_tree_from_root
27 | )
28 |
29 |
30 | def verbose_print(s: str, verbose: bool):
31 | if verbose:
32 | print(s)
33 |
34 |
35 | class Generator:
36 | """Generator generates children nodes"""
37 |
38 | def __init__(self, args, tokenizer, model, evaluator: Evaluator) -> None:
39 | self.io = IO_System(args, tokenizer, model)
40 | self.evaluator = evaluator
41 |
42 | self.num_sampling = args.num_sampling
43 | self.max_tokens = args.max_tokens
44 | self.enable_potential_score = args.enable_potential_score
45 |
46 | self.mcts_num_last_votes = args.mcts_num_last_votes
47 |
48 |
49 | self.examples = read_txt(args.examples_txt_path)
50 | self.prompt = read_json(args.prompt_config_path)
51 |
52 |
53 | def _get_pass_code(self, io_output_list: List[str], user_question: str) -> Tuple[str, float]:
54 | assert len(io_output_list) > 0
55 |
56 | if len(io_output_list) == 1:
57 | most_confident_answer_full_completion = io_output_list[0]
58 | confidence = 1
59 | else:
60 | _, passed_full_completion, _, confidence = self.evaluator.find_pass_code(io_output_list, user_question)
61 | assert confidence >= 0
62 |
63 | return passed_full_completion, confidence
64 |
65 | @time_decorator
66 | def _get_TACO_code(self, io_output_list: List[str], test_case: dict, solution_trace: Dict[int, Dict[str, str]],) -> Tuple[str, float]:
67 | assert len(io_output_list) > 0
68 |
69 |
70 | _, passed_full_completion, confidence, solution_trace_ = self.evaluator.find_TACO_code(io_output_list, test_case, solution_trace)
71 | assert confidence >= 0
72 |
73 | return passed_full_completion, confidence, solution_trace_
74 |
75 | def generate_ost_step(
76 | self,
77 | user_question: str,
78 | test_case: dict,
79 | solution_trace: Dict[int, Dict[str, str]],
80 | paraphrased: bool,
81 | ):
82 | ost_step_list = []
83 | existing_ost_steps, next_ost_step_id = concat_ost_steps(solution_trace)
84 | io_input = (
85 | self.prompt["prompt_template"].format(
86 | examples=self.examples,
87 | question=user_question,
88 | )
89 | + existing_ost_steps
90 | + f"\n### Step {next_ost_step_id}:"
91 | )
92 |
93 | io_output_list = self.io.generate(
94 | model_input=io_input, max_tokens=8192, num_return=self.num_sampling, stop_tokens=[""]
95 | )
96 | ost_step_list = [io_output.strip() for io_output in io_output_list]
97 |
98 | last_ost_step = []
99 | value_list = []
100 | completion_confidence_list = []
101 | reach_last_step_flag = False
102 | # have_terminal_ost_step = False
103 | for ost_step in ost_step_list:
104 | if reach_terminal_ost_step(ost_step):
105 | reach_last_step_flag = True
106 | passed_full_completion, confidence, solution_trace_with_last_step = self._get_TACO_code([ost_step], test_case, solution_trace)
107 | completion_confidence_list.append((passed_full_completion, confidence))
108 | else:
109 | last_ost_step.append(ost_step)
110 | value_list.append(None)
111 |
112 | if reach_last_step_flag == True:
113 | last_ost_step.clear()
114 | value_list.clear()
115 | completion_confidence_list.sort(key=lambda x: x[1], reverse=True)
116 | best_passing_completion, highest_confidence = completion_confidence_list[0]
117 |
118 | return [best_passing_completion], [highest_confidence], [None]
119 | else:
120 | potential_answers_list: List[List[str]] = []
121 | # print(value_list)
122 | if value_list.count(None) != 0 and value_list.count(None) != len(value_list):
123 | for idx, value in enumerate(value_list):
124 | if value is not None:
125 | number = value
126 | corresponding_step = last_ost_step[idx]
127 | break
128 | value_list = [number]
129 | last_ost_step = [corresponding_step]
130 | potential_answers_list = [None] * len(value_list)
131 | return last_ost_step, value_list, potential_answers_list
132 |
133 |
134 | class Reasoning_MCTS_Node(MCTS_Node):
135 | def __init__(
136 | self,
137 | parent: "Reasoning_MCTS_Node",
138 | depth: int,
139 | node_type: Node_Type,
140 | verbose: bool = False,
141 | # --- For instantiating root node ---
142 | node_value: float = None,
143 | generator: Generator = None,
144 | user_question: str = None,
145 | max_depth_allowed: int = None,
146 | difficulty: str = None,
147 | # -------------------------------------------
148 | # --- For instantiating OST_STEP node ---
149 | ost_step: str = None,
150 | # ---------------------------------------
151 | # --- For node selection (not in sanity checks yet) ---
152 | enable_potential_score: bool = None,
153 | potential_answers: List[str] = None,
154 | test_case: dict = None,
155 | ) -> None:
156 | """params:
157 | subquestion: the node is proposing a new subquestion
158 | subanswer: the answer corresponding to the new subquestion the node proposed
159 | re_subanswer: the node is proposing a new subanswer to the parent's subquestion
160 | """
161 | super().__init__()
162 |
163 | #! sanity checks
164 | try:
165 | assert depth is not None
166 | assert node_type is not None
167 | if node_value is not None:
168 | print(node_value)
169 | assert node_value >= 0, breakpoint()
170 |
171 | if node_type is Node_Type.USER_QUESTION:
172 | assert depth == 0
173 | assert all(
174 | attr is None
175 | for attr in [
176 | parent,
177 | node_value,
178 | ost_step,
179 | ]
180 | )
181 | assert all(
182 | attr is not None
183 | for attr in [generator, user_question, difficulty, max_depth_allowed]
184 | )
185 | elif node_type is Node_Type.ONE_STEP:
186 | assert depth > 0
187 | assert all(
188 | attr is None
189 | for attr in [
190 | generator,
191 | user_question,
192 | difficulty,
193 | max_depth_allowed,
194 | ]
195 | )
196 | assert all(attr is not None for attr in [parent, ost_step])
197 | except AssertionError:
198 | print(f"Instantiating node with type {node_type} failed!")
199 | breakpoint()
200 | exit()
201 |
202 | #! attributes
203 | self.parent = parent # if parent is None, then the node is the root
204 | self.children: List["Reasoning_MCTS_Node"] = []
205 | self.depth = depth
206 | self.node_type = node_type
207 | self.node_value = node_value
208 | self.ost_step = ost_step
209 | self.test_case = test_case
210 |
211 | if parent is None: # root
212 | self.verbose = verbose
213 | self.user_question = user_question
214 | self.difficulty = difficulty
215 | self.generator = generator
216 | self.max_depth_allowed = max_depth_allowed
217 | self.enable_potential_score = enable_potential_score
218 | self.test_case = test_case
219 | else: # inherit from parent
220 | self.verbose = parent.verbose
221 | self.user_question = parent.user_question
222 | self.difficulty = parent.difficulty
223 | self.generator = parent.generator
224 | self.max_depth_allowed = parent.max_depth_allowed
225 | self.enable_potential_score = parent.enable_potential_score
226 | self.test_case = parent.test_case
227 |
228 | #! keep track of paraphrasing
229 | if node_type is Node_Type.USER_QUESTION:
230 | self.paraphrased = False
231 | else:
232 | assert parent is not None
233 | self.paraphrased = parent.paraphrased
234 |
235 |
236 | #! record number of one-step thought steps till now
237 | if parent is None: # root
238 | self.ost_step_counter = 0
239 | else:
240 | if node_type is Node_Type.ONE_STEP:
241 | self.ost_step_counter = parent.ost_step_counter + 1
242 | else:
243 | self.ost_step_counter = parent.ost_step_counter
244 |
245 | #! record solution trace from root to the current node. key: subquestion id
246 | if parent is None: # root
247 | assert self.node_type is Node_Type.USER_QUESTION
248 | self.solution_trace: Dict[int, Dict[str, str]] = {0: {"user_question": user_question, "ost_step": {}, "ost_step_value": {}}}
249 | else:
250 | assert self.node_type is not Node_Type.USER_QUESTION
251 | self.solution_trace = deepcopy(parent.solution_trace)
252 |
253 | if node_type is Node_Type.ONE_STEP:
254 | assert "ost_step" in self.solution_trace[0].keys()
255 | self.solution_trace[0]["ost_step"][self.ost_step_counter] = ost_step
256 | self.solution_trace[0]["ost_step_value"][self.ost_step_counter] = node_value
257 |
258 | #! potential_score for intermediate nodes (only used for node selection)
259 | if self.enable_potential_score:
260 | self.potential_answers = potential_answers
261 | self.potential_score = 0
262 | if parent is None: # root
263 | assert self.node_type is Node_Type.USER_QUESTION
264 | self.potential_answers_history = {}
265 | else:
266 | assert self.node_type is not Node_Type.USER_QUESTION
267 | self.potential_answers_history = deepcopy(parent.potential_answers_history)
268 | self.potential_answers_history[self.depth] = potential_answers
269 |
270 | def __str__(self) -> str:
271 | type2str = {
272 | Node_Type.USER_QUESTION: "U",
273 | Node_Type.ONE_STEP: "TS",
274 | }
275 | return f"{type2str[self.node_type]}-{self.id}"
276 |
277 | def _create_children(self):
278 |
279 | def do_action_generate_ost_step():
280 | verbose_print(f"---- Generating one-step thought steps for node {self.id}...", self.verbose)
281 |
282 | #! ACTION: generate one-step thought step
283 | ost_step_list, value_list, potential_answers_list = self.generator.generate_ost_step(
284 | user_question=self.user_question,
285 | test_case = self.test_case,
286 | solution_trace=self.solution_trace,
287 | paraphrased=self.paraphrased,
288 | )
289 | for ost_step, value, potential_answers in zip(ost_step_list, value_list, potential_answers_list):
290 | self.children.append(
291 | Reasoning_MCTS_Node(
292 | parent=self,
293 | depth=self.depth + 1,
294 | node_type=Node_Type.ONE_STEP,
295 | node_value=value,
296 | ost_step=ost_step,
297 | potential_answers=deepcopy(potential_answers),
298 | )
299 | )
300 |
301 | #! create children
302 | if self.node_type is Node_Type.USER_QUESTION:
303 | # generate one-step thought steps
304 | do_action_generate_ost_step()
305 |
306 |
307 | elif self.node_type is Node_Type.ONE_STEP:
308 |
309 | do_action_generate_ost_step()
310 |
311 | assert self.children
312 | return self.children
313 |
314 | def is_valid_leaf_node(self):
315 |
316 | return (self.node_type is Node_Type.ONE_STEP and reach_terminal_ost_step(self.ost_step))
317 |
318 | def is_valid_solution_node(self):
319 |
320 | return (self.node_type is Node_Type.ONE_STEP and reach_terminal_ost_step(self.ost_step))
321 |
322 | def set_potential_score(self, score: float):
323 | self.potential_score = score
324 |
325 | def find_children(self, rollout_id: int):
326 | self.children = self.children or self._create_children()
327 | for child in self.children:
328 | child.set_rollout_id(rollout_id)
329 | assert self.children
330 | return self.children
331 |
332 | def is_terminal(self):
333 | return self.depth >= self.max_depth_allowed or self.is_valid_leaf_node()
334 |
335 | def calculate_reward(self):
336 | if self.is_valid_leaf_node():
337 | assert self.node_value is not None, breakpoint()
338 | return self.node_value
339 | else:
340 | return 0
341 |
342 |
343 | def search_for_answers(args, user_question: str, question_id: int, difficulty: str, generator: Generator, test_case: dict):
344 | verbose_print(
345 | f"********************* Searching for answers to question {question_id} ********************* ", args.verbose
346 | )
347 |
348 | #! build an MCTS searcher
349 | mcts_searcher = MCTS_Searcher(
350 | exploration_weight=args.mcts_exploration_weight,
351 | weight_scheduler=args.mcts_weight_scheduler,
352 | num_rollouts=args.num_rollouts,
353 | discount=args.mcts_discount_factor,
354 | verbose=args.verbose,
355 | )
356 |
357 | #! build the MCTS tree
358 | root_node = Reasoning_MCTS_Node(
359 | parent=None,
360 | depth=0,
361 | node_type=Node_Type.USER_QUESTION,
362 | verbose=args.verbose,
363 | generator=generator,
364 | user_question=user_question,
365 | difficulty=difficulty,
366 | max_depth_allowed=args.max_depth_allowed,
367 | enable_potential_score=args.enable_potential_score,
368 | test_case=test_case,
369 | )
370 |
371 | model_solutions = []
372 | model_all_solutions = []
373 | model_rollout_nodes = []
374 | for i in (pbar := trange(args.num_rollouts, disable=True, position=0)):
375 | rollout_node = mcts_searcher.do_rollout(root_node, i)
376 | model_rollout_nodes.append(rollout_node)
377 | jss = {"trace": rollout_node.solution_trace, "rollout_id": rollout_node.rollout_id, "value": rollout_node.node_value}
378 |
379 | with open(os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - rollout Solutions.json"), "a") as f:
380 | json.dump(jss, f)
381 | f.write(',')
382 |
383 | # print_tree_from_root(mcts_searcher, args.num_rollouts - 1, root_node)
384 |
385 | ost_best_node, ost_all_solution_nodes, TREE = ost_find_best_solution(root_node, generator.evaluator)
386 |
387 | complete_road = []
388 |
389 | for solution_node in ost_all_solution_nodes:
390 | complete_road_json = find_solution(root_node, solution_node, mcts_searcher)
391 | complete_road.append(complete_road_json)
392 |
393 |
394 | bestv = -1
395 | ost_best_node = None
396 | for rollout_node in model_rollout_nodes:
397 | if rollout_node.node_value is not None:
398 | if rollout_node.node_value > bestv:
399 | bestv = rollout_node.node_value
400 | ost_best_node = rollout_node
401 |
402 |
403 | with open(os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - Complete Solutions.json"), "w", encoding="utf-8") as f:
404 | json.dump(complete_road, f, ensure_ascii=False, indent=4)
405 | #! record final traces
406 | js = [{"trace": node.solution_trace, "rollout_id": node.rollout_id, "parent_id": node.parent.id, "value": node.node_value} for node in ost_all_solution_nodes]
407 | with open(os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - Final Solutions.json"), "w") as f:
408 | json.dump(js, f)
409 |
410 | js2 = [{"trace": node.solution_trace, "rollout_id": i, "value": node.node_value} for i, node in enumerate(model_rollout_nodes)]
411 | with open(os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - Rollout Solutions.json"), "w") as f:
412 | json.dump(js2, f)
413 |
414 | if ost_best_node is not None:
415 | js3 = {"trace": ost_best_node.solution_trace, "rollout_id": ost_best_node.rollout_id, "value": ost_best_node.node_value}
416 |
417 | with open(os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - Best Solutions.json"), "w") as f:
418 | json.dump(js3, f)
419 |
420 |
421 |
422 | if args.enable_potential_score:
423 | js = [node.potential_answers_history for node in ost_all_solution_nodes]
424 | with open(os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - Potentials.json"), "w") as f:
425 | json.dump(js, f)
426 |
427 |
428 | return model_solutions, i, model_all_solutions
429 |
--------------------------------------------------------------------------------
/src/mcts/run_src/do_generate.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 |
3 | import sys
4 |
5 |
6 | print("args: ", sys.argv)
7 |
8 | import os, json, time
9 | from tqdm import tqdm
10 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
11 |
12 | sys.path.append(".")
13 |
14 | from common.utils import fix_seeds, setup_model_parallel, read_json
15 | from common.arguments import get_parser, post_process_args, save_args
16 | from MCTS_for_reasoning import Generator, search_for_answers
17 | from eval_src.Evaluator import *
18 |
19 |
20 | def main(args):
21 | fix_seeds(args.seed)
22 | if args.model_parallel:
23 | args.local_rank, args.world_size = setup_model_parallel()
24 | else:
25 | args.local_rank, args.world_size = 0, 1
26 |
27 | test_file = os.path.join(args.data_root, args.dataset_name, args.test_json_filename + ".json")
28 | assert os.path.exists(test_file), f"Test file {test_file} does not exist."
29 | data_item_list = read_json(test_file)
30 |
31 | evaluator = eval(f"{args.dataset_name}Evaluator()")
32 |
33 | tokenizer, model = None, None
34 | if args.api == "huggingface":
35 | from models.HuggingFace_API import load_HF_model
36 |
37 | tokenizer, model = load_HF_model(args.model_ckpt)
38 | elif args.api == "vllm":
39 | from models.vLLM_API import load_vLLM_model
40 |
41 | tokenizer, model = load_vLLM_model(args.model_ckpt, args.seed, args.tensor_parallel_size, args.half_precision)
42 | elif args.api == "OpenAI":
43 | from models.OpenAI_API import load_OpenAI_model
44 |
45 | tokenizer, model = load_OpenAI_model(args.model_ckpt)
46 | generator = Generator(args, tokenizer, model, evaluator)
47 |
48 | num_tested = 0
49 | start_time = time.time()
50 |
51 | for i, data_item in enumerate(
52 | (pbar := tqdm(data_item_list, disable=args.local_rank > 0 or args.verbose, position=1))
53 | ):
54 | if i < args.start_idx or i >= args.end_idx:
55 | continue
56 | st_time = time.time()
57 | problem_id, problem, test_case, difficulty = i, data_item["question"],data_item["input_output"], data_item["difficulty"]
58 | gt_solution = data_item["solutions"][0] if len(data_item["solutions"]) > 0 else None
59 |
60 | js = {
61 | "id": problem_id,
62 | "problem": problem,
63 | "gold_solution": gt_solution,
64 | "test_case": test_case,
65 | "difficulty": difficulty,
66 | }
67 |
68 |
69 | model_solutions, stopping_id, model_all_solutions = search_for_answers(
70 | args=args, user_question=problem, question_id=i, difficulty=difficulty, generator=generator, test_case = test_case
71 | )
72 |
73 | num_tested += 1
74 |
75 |
76 |
77 | with open(os.path.join(args.run_outputs_dir, "intermediate_result.txt"), "w") as f:
78 | f.write(
79 | f"Total calls: {generator.io.call_counter}, Avg calls: {generator.io.call_counter/(num_tested):.2f}\n"
80 | )
81 | f.write(
82 | f"Total tokens: {generator.io.token_counter}, Avg tokens: {generator.io.token_counter/(num_tested):.2f}\n"
83 | )
84 | ed_time = time.time()
85 | js["time_taken"] = f"{ed_time-st_time:.2f}s"
86 |
87 | with open(os.path.join(args.answer_sheets_dir, f"Question {i:04d} - Answer.json"), "w") as f:
88 | json.dump(js, f)
89 | print(f"==> Time taken for this question: {ed_time-st_time:.2f}s")
90 |
91 | end_time = time.time()
92 |
93 | print(f"==> Total calls: {generator.io.call_counter}, Avg calls: {generator.io.call_counter/(num_tested):.2f}")
94 | print(f"==> Total tokens: {generator.io.token_counter}, Avg tokens: {generator.io.token_counter/(num_tested):.2f}")
95 | print(f"==> Total time: {end_time-start_time:.2f}s, Avg time: {(end_time-start_time)/(num_tested):.2f}s")
96 |
97 | with open(os.path.join(args.run_outputs_dir, "final_result.txt"), "w") as f:
98 | f.write(f"Total calls: {generator.io.call_counter}, Avg calls: {generator.io.call_counter/(num_tested):.2f}\n")
99 | f.write(
100 | f"Total tokens: {generator.io.token_counter}, Avg tokens: {generator.io.token_counter/(num_tested):.2f}\n"
101 | )
102 | f.write(f"Total time: {end_time-start_time:.2f}s, Avg time: {(end_time-start_time)/(num_tested):.2f}s\n")
103 |
104 |
105 | if __name__ == "__main__":
106 | #! -------------------------------- Arguments --------------------------------
107 | parser = get_parser()
108 |
109 | parser.add_argument("--num_rollouts", type=int, default=15)
110 | parser.add_argument("--max_depth_allowed", type=int, default=5)
111 |
112 | # MCTS
113 | parser.add_argument("--mcts_discount_factor", type=float, default=1.0)
114 | parser.add_argument("--mcts_exploration_weight", type=float, default=2.0)
115 | parser.add_argument("--mcts_weight_scheduler", choices=["exp", "lin", "const"], default="const")
116 | parser.add_argument("--mcts_num_last_votes", type=int, default=None)
117 | parser.add_argument("--save_tree", action="store_true")
118 |
119 | # Action1: Propose an one-step thought.
120 | parser.add_argument("--num_sampling", type=int, default=3)
121 |
122 |
123 | #! -------------------------- Used for selecting answer --------------------------
124 | parser.add_argument("--enable_potential_score", action="store_true")
125 |
126 | #! -------------------------------------------------------------------------------
127 |
128 | args = parser.parse_args()
129 |
130 | if args.mcts_num_last_votes is None:
131 | args.mcts_num_last_votes = 32
132 |
133 | #! ----------------------------------------------------------------------------
134 |
135 | prompts_dir = os.path.join(args.prompts_root, args.dataset_name)
136 |
137 |
138 | args.examples_txt_path = os.path.join(prompts_dir, "examples.txt")
139 | args.prompt_config_path = os.path.join(prompts_dir, "prompt.json")
140 |
141 |
142 | args = post_process_args(args)
143 | print(args)
144 | save_args(args)
145 | main(args)
146 |
--------------------------------------------------------------------------------
/src/mcts/run_src/rstar_utils.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 |
3 | from enum import Enum, unique
4 | import re
5 | import math
6 | from typing import Dict, Tuple
7 | from colorama import Fore, Style
8 | import math
9 |
10 |
11 | @unique
12 | class Node_Type(Enum):
13 | USER_QUESTION = "USER_QUESTION"
14 | ONE_STEP = "ONE_STEP"
15 |
16 | import time
17 |
18 | def time_decorator(func):
19 | def wrapper(*args, **kwargs):
20 | start_time = time.time()
21 | result = func(*args, **kwargs)
22 | end_time = time.time()
23 | print(f"{func.__name__} runs: {end_time - start_time:.6f} 秒")
24 | return result
25 | return wrapper
26 |
27 |
28 | def reach_terminal_ost_step(ost_step: str):
29 | assert ost_step is not None
30 | last_step = ost_step.lower()
31 |
32 | code_indicators = [
33 | # "",
34 | "```python"
35 | ]
36 |
37 | return any(indicator in last_step for indicator in code_indicators)
38 |
39 |
40 | def print_tree_from_root(mcts_searcher, rollout_id, root_node, chosen_node=None, file=None):
41 | color_print = False if file else True
42 |
43 | def my_print(text):
44 | if file:
45 | file.write(text + "\n")
46 | else:
47 | print(text)
48 |
49 | def print_tree(parent_node, node, file, rollout_id):
50 | to_print = ""
51 |
52 | num_indent = 4
53 | dash = "-" * num_indent * node.depth
54 | space = " " * num_indent * node.depth
55 |
56 | attributes = f"Q: {round(mcts_searcher.Q[node], 2)}" + "; " + f"N: {mcts_searcher.N[node]}" + "; "
57 | attributes += f"V: {round(node.node_value, 2)}" if node.node_value is not None else "V: None"
58 |
59 | uct_value = "UCT: " + str(
60 | round(mcts_searcher._compute_uct(parent_node=parent_node, node=node, rollout_id=rollout_id), 2)
61 | )
62 | attributes += "; " + uct_value
63 |
64 | solution_marker = "(T) " if node.is_valid_solution_node() else ""
65 |
66 | node_info = "[" + solution_marker + node.__str__() + ": " + attributes + "]"
67 | if chosen_node and node == chosen_node:
68 | node_info = "[" + node_info + "]"
69 | node_info += " "
70 |
71 | if color_print and node.is_valid_solution_node():
72 | node_details = Fore.RED + Style.BRIGHT + node_info + Fore.RESET + Style.RESET_ALL
73 | else:
74 | node_details = node_info
75 |
76 | if node.node_type is Node_Type.USER_QUESTION:
77 | node_details += f"User: {node.user_question}" + "\n" + space + " " * len(node_info)
78 | elif node.node_type is Node_Type.ONE_STEP:
79 | node_details += f"OST: {node.ost_step}"
80 |
81 | to_print += dash + node_details
82 |
83 | my_print(to_print)
84 |
85 | for child in node.children:
86 | print_tree(node, child, file, rollout_id)
87 |
88 | if node.depth == 0:
89 | my_print("\n" + "=" * 50 + "\n")
90 |
91 | print_tree(parent_node=None, node=root_node, file=file, rollout_id=rollout_id)
92 |
93 |
94 | def concat_ost_steps(solution_trace: Dict[int, Dict[str, str]]) -> Tuple[str, int]:
95 | """Return: concatenated one-step thought steps, next one-step thought step id"""
96 | last_tuple = list(solution_trace.items())[-1]
97 | last_tuple_id, last_tuple_recording = last_tuple[0], last_tuple[1]
98 | assert "ost_step" in last_tuple_recording.keys()
99 | if len(last_tuple_recording["ost_step"]) > 0:
100 | solution_trace_str = ""
101 | for step_id, step_text in last_tuple_recording["ost_step"].items():
102 | solution_trace_str += f"\n### Step {step_id}: " + step_text + "\n\n"
103 | return solution_trace_str, step_id + 1
104 | else:
105 | # no one-step thought step yet
106 | return "", 1
107 |
108 |
109 | def find_valid_solution_nodes(root_node):
110 | valid_solution_nodes = []
111 | TREE = {}
112 |
113 | def recursion(node):
114 | if root_node.depth in TREE:
115 | TREE[root_node.depth].append(root_node)
116 | else:
117 | TREE[root_node.depth] = [root_node]
118 |
119 | if node.is_valid_solution_node():
120 | valid_solution_nodes.append(node)
121 | return
122 |
123 | if not node.children: #! no children
124 | return
125 |
126 | for child in node.children:
127 | recursion(child)
128 |
129 | recursion(root_node)
130 |
131 | return valid_solution_nodes, TREE
132 |
133 |
134 | def find_best_solution(root_node, evaluator, enable_potential_score=False):
135 | # todo: what strategy do we use to select best node?
136 | """The function finds the best solution from the solution nodes in the MCTS tree.
137 | Return: top answer, top solution, confidence of the top answer, the corresponding node of the answer, all solution nodes
138 | """
139 | solution_nodes = find_valid_solution_nodes(root_node)
140 |
141 | if len(solution_nodes) == 0:
142 | return None, None
143 |
144 | def extract_solution_from_node(node):
145 | if node.node_type is Node_Type.SUBQUESTION:
146 | return node.subanswer
147 | elif node.node_type is Node_Type.DIRECT_ANSWER:
148 | return node.direct_answer
149 | else:
150 | return None
151 |
152 | solutions = [extract_solution_from_node(node) for node in solution_nodes]
153 |
154 | def calculate_potential_score_for_solution_node(node):
155 | model_answer = evaluator.extract_answer_from_model_completion(extract_solution_from_node(node))
156 | potential_answers_history = node.potential_answers_history # {depth -> [potential answers]}
157 | assert potential_answers_history[node.depth] is None
158 |
159 | potential_score = 1
160 | for depth, depth_potential_answers in potential_answers_history.items():
161 | if depth < node.depth:
162 | depth_score = sum(
163 | evaluator.check_answers_equiv(dpa, model_answer) for dpa in depth_potential_answers
164 | ) / len(depth_potential_answers)
165 | potential_score *= depth_score
166 |
167 | node.set_potential_score(potential_score)
168 | return potential_score
169 |
170 | prior_weights = (
171 | [calculate_potential_score_for_solution_node(node) for node in solution_nodes]
172 | if enable_potential_score
173 | else None
174 | )
175 | top_answer, top_completion, top_completion_id, top_confidence = evaluator.find_most_confident_answer(
176 | solutions, prior_weights
177 | )
178 | return top_answer, top_completion, top_confidence, solution_nodes[top_completion_id], solution_nodes
179 |
180 |
181 | def ost_find_best_solution(
182 | root_node,
183 | evaluator,
184 | ):
185 | solution_nodes, TREE = find_valid_solution_nodes(root_node)
186 |
187 | bestv = -1
188 | best_node = None
189 | for solution_node in solution_nodes:
190 | if solution_node.node_value > bestv:
191 | bestv = solution_node.node_value
192 | best_node = solution_node
193 |
194 | return best_node, solution_nodes, TREE
195 |
196 | def find_solution(root_node, solution_node, mcts_searcher):
197 | """
198 | Recursively traces back from the given solution node to the root node,
199 | calculating the value of each node along the path.
200 |
201 | Parameters:
202 | solution_node (Node): The current solution node to start the backtrace from.
203 | mcts_searcher (MCTS): The MCTS searcher object used to access node visit counts and values.
204 |
205 | Returns:
206 | dict: A dictionary representing the complete solution, containing the node id,
207 | OST step, step value, and edge information for each node in the path.
208 | """
209 |     complete_solution = {}
210 |
211 | def reback(node):
212 | """
213 | Recursively backtracks from the current node to the root node,
214 | calculating the value for each node and updating the solution.
215 |
216 | Parameters:
217 | node (Node): The current node being processed in the backtrack.
218 | """
219 | if node.node_value is not None and mcts_searcher.N[node] != 0:
220 | value = node.node_value / mcts_searcher.N[node]
221 | else:
222 | value = 0
223 | if node.node_type is Node_Type.ONE_STEP:
224 |             complete_solution[node.depth] = {
225 | "node_id": node.id,
226 | "ost_step": node.ost_step,
227 | "step_value": value,
228 | "edges": (node.parent.id, node.id) # source_node_id -> target_node_id
229 | }
230 | else:
231 |             complete_solution[node.depth] = {
232 | "node_id": node.id,
233 | "question": root_node.user_question,
234 | }
235 | if node.node_type is Node_Type.USER_QUESTION:
236 | return
237 |
238 | reback(node.parent)
239 |
240 | reback(solution_node)
241 |     return complete_solution
242 |
243 |
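Two helpers here drive the step bookkeeping: `reach_terminal_ost_step` flags a step as terminal once its text contains the marker that opens a Python code block, and `concat_ost_steps` rebuilds the prompt suffix from a solution trace. A minimal illustration (run from `src/mcts` with `colorama` installed so the module imports cleanly):

```python
# Run from src/mcts so that `run_src` is importable.
from run_src.rstar_utils import concat_ost_steps, reach_terminal_ost_step

trace = {0: {"user_question": "dummy question",
             "ost_step": {1: "Outline the solution structure in pseudocode."},
             "ost_step_value": {1: None}}}

steps_str, next_id = concat_ost_steps(trace)
print(next_id)                                                 # 2: id of the next step to generate
print(reach_terminal_ost_step("More pseudocode refinement."))  # False: no code fence yet

# Build a step that contains a python code fence without hard-coding backticks.
code_step = "#### Python Code:\n" + "`" * 3 + "python\nprint('done')\n" + "`" * 3
print(reach_terminal_ost_step(code_step))                      # True: terminal code step
```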
--------------------------------------------------------------------------------
/src/mcts/scripts/api_run_TACO.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python run_src/do_generate.py \
2 | --dataset_name TACO \
3 | --test_json_filename test_one \
4 | --api OpenAI \
5 | --model_ckpt gpt-4o-mini \
6 | --note default \
7 | --num_rollouts 12 \
8 | --verbose \
9 | --max_depth_allowed 10
--------------------------------------------------------------------------------
/src/mcts/scripts/run_TACO.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=3 python run_src/do_generate.py \
2 | --dataset_name TACO \
3 | --test_json_filename test_one \
4 | --model_ckpt \
5 | --note default \
6 | --num_rollouts 3 \
7 | --verbose \
8 | --max_depth_allowed 8
--------------------------------------------------------------------------------
/src/mcts/scripts/start.md:
--------------------------------------------------------------------------------
1 | bash scripts/api_run_TACO.sh
2 | bash scripts/run_TACO.sh
--------------------------------------------------------------------------------
/src/prm_training/README.md:
--------------------------------------------------------------------------------
1 | ## How to use
2 | To initialize the environment, you need to install the required packages. You can do this by running the following command:
3 |
4 | ```bash
5 | pip install -r requirements.txt
6 | ```
7 |
8 | This will install all the necessary dependencies listed in the `requirements.txt` file.
9 |
10 | ### Data
11 | We provide data examples for PRM training in the `data` folder. The reward labels are available in both hard and soft estimation forms.
12 |
13 | For the hard estimation labels, you can refer to `data/examples/hard_label_examples.json` for processing, while the corresponding soft label forms are provided in `data/examples/soft_label_examples.json`.
14 |
15 | ### Train
16 | #### Basic Usage
17 | Make sure your working directory is the `prm_training` folder, then run the following script:
18 | ```bash
19 | bash run.sh
20 | ```
21 | #### Main Arguments
22 | | Argument | Type | Description |
23 | |---------------|--------|-----------------------------------|
24 | | `--config_file` | str | accelerate config file path |
25 | | `--model_name_or_path` | str | your model path |
26 | | `--data_path` | str | data for training |
27 | | `--use_soft_label` | bool | Whether to use soft labels during training, default is false |
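These flags are parsed into the dataclasses defined in `run.py` via `HfArgumentParser`; a minimal sketch of that mapping (paths are placeholders, and it assumes the working directory is `prm_training` with the requirements installed):

```python
from transformers import HfArgumentParser, Seq2SeqTrainingArguments

from run import ModelArguments, DataTrainingArguments  # dataclasses defined in run.py

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses([
    "--model_name_or_path", "/path/to/your/model",              # placeholder
    "--data_path", "./data/examples/soft_label_examples.json",
    "--use_soft_label", "True",
    "--output_dir", "./train_prm/outputs",
])
print(model_args.model_name_or_path, data_args.use_soft_label)
```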
--------------------------------------------------------------------------------
/src/prm_training/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.1.1
2 | datasets==3.1.0
3 | numpy==2.1.3
4 | peft==0.13.2
5 | scikit_learn==1.5.2
6 | torch==2.4.1
7 | transformers==4.46.2
8 | deepspeed==0.15.3
9 | wandb
--------------------------------------------------------------------------------
/src/prm_training/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from dataclasses import dataclass, field
4 | from typing import Optional
5 | from transformers import Seq2SeqTrainingArguments, HfArgumentParser
6 | from accelerate import PartialState
7 |
8 | from train_prm.run_train import run_exp
9 |
10 | DIST_STATE = PartialState()
11 |
12 | @DIST_STATE.on_local_main_process
13 | def print_rank_0(msg):
14 | print(msg)
15 |
16 | @dataclass
17 | class ModelArguments:
18 | """
19 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
20 | """
21 | model_name_or_path: str = field(
22 | metadata={"help": "Path to pretrained model", "required": True}
23 | )
24 |
25 | @dataclass
26 | class DataTrainingArguments:
27 | """
28 | Arguments pertaining to what data we are going to input our model for training and eval.
29 | """
30 | data_path: str = field(
31 | metadata={"help": "Path to dataset", "required": True}
32 | )
33 | use_soft_label: bool = field(
34 | default=False,
35 | metadata={"help": "Whether to use soft labels for prm training"}
36 | )
37 | server: str = field(
38 | default="1",
39 | metadata={"help": "Server configuration"}
40 | )
41 |
42 | def main():
43 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
44 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
45 | print_rank_0('*'*30+f'\nModel arguments:\n{model_args}\nData arguments:\n{data_args}\nTraining arguments:\n{training_args}\n'+'*'*30)
46 | run_exp(model_args, data_args, training_args)
47 |
48 | if __name__ == '__main__':
49 | main()
--------------------------------------------------------------------------------
/src/prm_training/run.sh:
--------------------------------------------------------------------------------
1 | accelerate launch --config_file "./train_prm/utils/dist_configs/multi_gpu.yaml" \
2 | ./run.py \
3 | --model_name_or_path "your model path" \
4 | --data_path "./data/examples/soft_label_examples.json" \
5 | --use_soft_label \
6 | --output_dir "./train_prm/outputs" \
7 | --overwrite_output_dir \
8 | --per_device_train_batch_size 1 \
9 | --per_device_eval_batch_size 4 \
10 | --gradient_accumulation_steps 4 \
11 | --num_train_epochs 3 \
12 | --learning_rate 1e-4 \
13 | --lr_scheduler_type "cosine" \
14 | --save_strategy "steps" \
15 | --eval_strategy "steps" \
16 | --save_steps 200 \
17 | --eval_steps 100 \
18 | --save_total_limit 2 \
19 | --weight_decay 0.01 \
20 | --logging_steps 10 \
21 | --log_level "info" \
22 | --bf16 \
23 | --report_to "wandb" \
24 |     --ddp_find_unused_parameters False
--------------------------------------------------------------------------------
/src/prm_training/train_prm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ADaM-BJTU/O1-CODER/f01c769397afaefc89c22c51d048484a79d11c1f/src/prm_training/train_prm/__init__.py
--------------------------------------------------------------------------------
/src/prm_training/train_prm/run_train.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import argparse
4 | import torch.nn as nn
5 | import os
6 | import random
7 | import numpy as np
8 | from dataclasses import dataclass
9 | from typing import Tuple, List, Dict, Union, Optional, Any
10 | from peft import PeftModel
11 | from peft import get_peft_model, LoraConfig, TaskType
12 | from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, Seq2SeqTrainingArguments
13 | from transformers import DataCollatorWithPadding, DataCollatorForSeq2Seq
14 | from transformers.modeling_outputs import CausalLMOutputWithPast
15 | from sklearn.metrics import roc_auc_score, log_loss, accuracy_score
16 | from torch.nn import BCEWithLogitsLoss
17 | from datasets import load_dataset
18 | from accelerate import PartialState
19 |
20 | DIST_STATE = PartialState()
21 |
22 | @DIST_STATE.on_local_main_process
23 | def print_rank_0(msg):
24 | print(msg)
25 |
26 | def print_rank(msg: str):
27 | local_rank = int(os.environ.get('LOCAL_RANK', -1))
28 | print(f'[LOCAL_RANK {local_rank}]:\n{msg}')
29 |
30 |
31 | def setup_model_and_tokenizer(model_path: str) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
32 | tokenizer = AutoTokenizer.from_pretrained(
33 | model_path,
34 | add_eos_token=False,
35 | )
36 | tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
37 |
38 | model = AutoModelForCausalLM.from_pretrained(
39 | model_path,
40 | torch_dtype=torch.bfloat16,
41 | attn_implementation="flash_attention_2",
42 | ).to('cuda')
43 |
44 | lora_config = LoraConfig(
45 | task_type=TaskType.CAUSAL_LM, # LoRA for causal language modeling task
46 | r=8, # Rank of LoRA
47 | lora_alpha=32, # Alpha scaling factor for LoRA
48 | lora_dropout=0.1, # Dropout rate for LoRA layers
49 | target_modules=["q_proj", "v_proj"], # Apply LoRA to specific layers
50 | )
51 |
52 | model = get_peft_model(model, lora_config)
53 | if DIST_STATE.is_local_main_process:
54 | model.print_trainable_parameters()
55 | return model, tokenizer
56 |
57 |
58 | # Adapted, with corresponding modifications, from `openr/prm/code/finetune_qwen.py`
59 | class DatasetProcessor:
60 | def __init__(self, args, tokenizer: AutoTokenizer):
61 | self.args = args
62 | self.tokenizer = tokenizer
63 | self.good_token = '+'
64 | self.bad_token = '-'
65 |         self.step_tag = ' Rating'  # placeholder token marking where a step score is read
66 |         self.candidate_token_ids = self.tokenizer.encode(f" {self.good_token} {self.bad_token}")  # token ids for the good/bad markers
67 |         self.step_tag_id = self.tokenizer.encode(f"{self.step_tag}")[-1]  # id of the step-tag token
68 |
69 | def print_example(self, example):
70 | print_rank_0('*' * 20 + ' Example View ' + '*' * 20)
71 | print_rank_0('Tokenized Data:\n' + '=' * 30 + '\n'
72 | f'Input_ids: {example["input_ids"]}\nAttention_mask: {example["attention_mask"]}\nLabels: {example["labels"]}')
73 | ignore_index = -100 if not self.args.use_soft_label else -100.0
74 | valid_labels = [label for label in example["labels"] if label != ignore_index]
75 | if not self.args.use_soft_label:
76 | print_rank_0('Decoded Data:\n' + '=' * 30 + '\n'
77 | f'Input: {self.tokenizer.decode(example["input_ids"])}\nLabels: {self.tokenizer.decode(valid_labels)}')
78 | else:
79 | print_rank_0('Decoded Data:\n' + '=' * 30 + '\n'
80 | f'Input: {self.tokenizer.decode(example["input_ids"])}\nLabels: {valid_labels}')
81 |
82 | def preprocess_example(self, example):
83 | messages = [
84 | {"role": "user", "content": example["prompt"]},
85 | {"role": "assistant", "content": example["response"]},
86 | ]
87 | end_token_of_one_turn = '<|im_end|>\n'
88 | input_w_template = self.tokenizer.apply_chat_template(messages, tokenize=False)
89 | input_w_template = input_w_template.removesuffix(end_token_of_one_turn)
90 | input_w_template += self.step_tag
91 |
92 | tokenized_inputs = self.tokenizer(input_w_template, padding=True)
93 |
94 | indices = [i for i, x in enumerate(tokenized_inputs['input_ids']) if x == self.step_tag_id]
95 |
96 | if len(indices) != len(example['label']):
97 | example['label'] = example['label'][:len(indices)]
98 |
99 | assert len(indices) == len(example['label'])
100 |
101 | length = len(tokenized_inputs['input_ids'])
102 |
103 | if not self.args.use_soft_label:
104 | tokenized_inputs['labels'] = [-100] * length
105 | for i, idx in enumerate(indices):
106 | if example['label'][i] in ['positive', 1]:
107 | tokenized_inputs['labels'][idx] = self.candidate_token_ids[0]
108 | elif example['label'][i] in ['negative', 0]:
109 | tokenized_inputs['labels'][idx] = self.candidate_token_ids[1]
110 | else:
111 | raise ValueError('Invalid label value')
112 |                 tokenized_inputs['attention_mask'][idx] = 0  # mask out the step-tag position in the attention mask
113 | else: # use soft labels
114 | tokenized_inputs['labels'] = [-100.0] * length
115 | for i, idx in enumerate(indices):
116 | tokenized_inputs['labels'][idx] = example['label'][i]
117 |                 tokenized_inputs['attention_mask'][idx] = 0  # mask out the step-tag position, as in the hard-label path
118 |
119 | assert len(tokenized_inputs["input_ids"]) == len(tokenized_inputs["labels"]) == len(tokenized_inputs["attention_mask"])
120 | return tokenized_inputs
121 |
122 | def prepare_datasets(self, training_args: Seq2SeqTrainingArguments, test_size=0.2, seed=42):
123 | dataset = load_dataset('json', data_files=self.args.data_path, split='train')
124 | dataset = dataset.filter(lambda x: x["prompt"])
125 |
126 | splits = dataset.train_test_split(
127 | test_size=test_size,
128 | seed=seed,
129 | shuffle=True
130 | )
131 |
132 | with training_args.main_process_first(desc="Tokenizing datasets"):
133 | tokenized_datasets = {
134 | split: splits[split].map(
135 | self.preprocess_example,
136 | remove_columns=splits[split].column_names,
137 | )
138 | for split in splits
139 | }
140 |
141 | print_rank_0(f"Training set size: {len(tokenized_datasets['train'])}")
142 | print_rank_0(f"Test set size: {len(tokenized_datasets['test'])}")
143 |
144 | ridx = random.randint(0, len(tokenized_datasets["train"]) - 1)
145 | self.print_example(tokenized_datasets["train"][ridx])
146 |
147 | return tokenized_datasets
148 |
149 | # Define a custom metric function (e.g., accuracy for binary classification)
150 | def preprocess_logits_for_metrics(self, logits, labels):
151 |
152 | labels_index = torch.argwhere(torch.bitwise_or(
153 | labels == self.candidate_token_ids[0],
154 | labels == self.candidate_token_ids[1]
155 | ))
156 | gold = torch.where(
157 | labels[labels_index[:, 0], labels_index[:, 1]] == self.candidate_token_ids[1],
158 | 0, 1
159 | )
160 |         labels_index[:, 1] -= 1  # causal shift: logits at position t-1 predict the token at position t
161 | logits = logits[labels_index[:, 0], labels_index[:, 1]][:, [
162 | self.candidate_token_ids[1],
163 | self.candidate_token_ids[0]
164 | ]]
165 | prob = torch.softmax(logits, dim=-1)
166 | return prob[:, 1], gold
167 |
168 | def preprocess_logits_for_soft_label_metrics(self, logits, labels):
169 |
170 | labels_index = labels.ne(-100.0).nonzero()
171 | positive_labels = labels[labels_index[:, 0], labels_index[:, 1]]
172 | negative_labels = 1 - positive_labels
173 | gold = torch.stack([positive_labels, negative_labels], dim=-1).argmax(dim=-1)
174 |         labels_index[:, 1] -= 1  # same causal shift as in the hard-label metrics
175 | logits = logits[labels_index[:, 0], labels_index[:, 1]][:, self.candidate_token_ids]
176 | return logits, gold
177 |
178 | def compute_metrics(self, eval_pred):
179 |
180 | if not self.args.use_soft_label:
181 | pre, labels = eval_pred
182 | auc = roc_auc_score(pre[1], pre[0])
183 | ll = log_loss(pre[1], pre[0])
184 | acc = accuracy_score(pre[1], pre[0] > 0.5)
185 | result = {
186 | 'auc': auc,
187 | 'll': ll,
188 | 'acc': acc,
189 | }
190 | else:
191 | predictions, labels = eval_pred
192 | acc = accuracy_score(predictions[0].argmax(axis=-1), predictions[1])
193 | result = {
194 | 'acc': acc,
195 | }
196 | print_rank_0(result)
197 | return result
198 |
199 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase
200 | from transformers.utils import PaddingStrategy
201 | from transformers.data.data_collator import pad_without_fast_tokenizer_warning
202 |
203 | @dataclass
204 | class DataCollatorForSeq2SeqWithSoftLabels:
205 | """
206 | Data collator that will dynamically pad the inputs received, as well as the labels.
207 | This version supports soft labels (float values) in the label tensors.
208 |
209 | Args:
210 | tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
211 | The tokenizer used for encoding the data.
212 | model ([`PreTrainedModel`], *optional*):
213 | The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
214 | prepare the *decoder_input_ids*
215 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
216 | Select a strategy to pad the returned sequences
217 | max_length (`int`, *optional*):
218 | Maximum length of the returned list and optionally padding length.
219 | pad_to_multiple_of (`int`, *optional*):
220 | If set will pad the sequence to a multiple of the provided value.
221 |         label_pad_token_id (`float`, *optional*, defaults to -100.0):
222 |             The value to use when padding the labels. A float so that soft (float) labels are padded consistently.
223 | return_tensors (`str`, *optional*, defaults to `"pt"`):
224 | The type of Tensor to return. Allowable values are "np", "pt" and "tf".
225 | """
226 |
227 | tokenizer: PreTrainedTokenizerBase
228 | model: Optional[Any] = None
229 | padding: Union[bool, str, PaddingStrategy] = True
230 | max_length: Optional[int] = None
231 | pad_to_multiple_of: Optional[int] = None
232 | label_pad_token_id: float = -100.0 # Changed to float
233 | return_tensors: str = "pt"
234 |
235 | def __call__(self, features, return_tensors=None):
236 | if return_tensors is None:
237 | return_tensors = self.return_tensors
238 |
239 | label_name = "label" if "label" in features[0].keys() else "labels"
240 | labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
241 |
242 | # Handle None labels
243 | if labels is not None and all(label is None for label in labels):
244 | labels = None
245 |
246 | non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
247 |
248 | # Process inputs without labels
249 | batch = pad_without_fast_tokenizer_warning(
250 | self.tokenizer,
251 | non_labels_features,
252 | padding=self.padding,
253 | max_length=self.max_length,
254 | pad_to_multiple_of=self.pad_to_multiple_of,
255 | return_tensors=return_tensors,
256 | )
257 |
258 | # Process labels if they exist
259 | no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
260 | if labels is not None:
261 | if no_padding:
262 | if isinstance(features[0][label_name], list):
263 | batch["labels"] = list(labels)
264 | else:
265 | batch["labels"] = [np.concatenate([label, []]) for label in labels]
266 | else:
267 | max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
268 | max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
269 |
270 | if self.pad_to_multiple_of is not None:
271 | max_label_length = (
272 | (max_label_length + self.pad_to_multiple_of - 1)
273 | // self.pad_to_multiple_of
274 | * self.pad_to_multiple_of
275 | )
276 |
277 | padding_side = self.tokenizer.padding_side
278 |
279 | if isinstance(features[0][label_name], list):
280 | for idx, label in enumerate(labels[0]):
281 | if type(label) != type(self.label_pad_token_id):
282 | raise ValueError(
283 |                             f'The {idx}-th label is of type {type(label)} while label_pad_token_id is of type {type(self.label_pad_token_id)}; '
284 |                                 'make sure they are of the same type'
285 | )
286 |
287 | batch["labels"] = [
288 | label + [self.label_pad_token_id] * (max_label_length - len(label))
289 | if padding_side == "right"
290 | else [self.label_pad_token_id] * (max_label_length - len(label)) + label
291 | for label in labels
292 | ]
293 | else:
294 | # Convert to float16 for soft labels
295 | batch["labels"] = [
296 | np.concatenate(
297 | [
298 | label.astype(np.float16),
299 | np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.float16),
300 | ]
301 | )
302 | if padding_side == "right"
303 | else np.concatenate(
304 | [
305 | np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.float16),
306 | label.astype(np.float16),
307 | ]
308 | )
309 | for label in labels
310 | ]
311 |
312 | # Convert to appropriate tensor type
313 | if batch.get("labels", None) is not None:
314 | if return_tensors == "pt":
315 | import torch
316 | batch["labels"] = torch.tensor(batch["labels"], dtype=torch.float16) # Changed to float16
317 | else:
318 | raise NotImplementedError(f"return_tensors='{return_tensors}' not supported yet.")
319 | else:
320 | batch["labels"] = None
321 |
322 | if (
323 | labels is not None
324 | and self.model is not None
325 | and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
326 | ):
327 | decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
328 | batch["decoder_input_ids"] = decoder_input_ids
329 |
330 | return batch
331 |
332 |
333 | class PRMTrainerForTokenPrediction(Trainer):
334 | def __init__(self, prm_use_tokens_cfg: Dict[str, Union[int, List[int]]], **kwargs):
335 | super().__init__(**kwargs)
336 | self.prm_use_tokens_cfg = prm_use_tokens_cfg
337 | self.loss_func = nn.functional.cross_entropy
338 | # self.model_accepts_loss_kwargs = False
339 |
340 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
341 | labels = inputs.pop("labels")
342 | if self.model_accepts_loss_kwargs:
343 | loss_kwargs = {}
344 | if num_items_in_batch is not None:
345 | loss_kwargs["num_items_in_batch"] = num_items_in_batch
346 | inputs = {**inputs, **loss_kwargs}
347 |
348 | outputs = model(**inputs)
349 | if self.args.past_index >= 0:
350 | self._past = outputs[self.args.past_index]
351 |
352 | placeholder_token_id = self.prm_use_tokens_cfg["placeholder_token_id"]
353 | candidate_token_ids_for_prediction = self.prm_use_tokens_cfg["candidate_token_ids_for_prediction"]
354 |
355 | logits = outputs.logits
356 | # new_labels = torch.zeros_like(logits).to(outputs.logits.dtype)
357 | # positive_labels = labels.to(logits.dtype)
358 | # negative_labels = 1 - positive_labels
359 | # new_labels[..., candidate_token_ids_for_prediction] = torch.stack([negative_labels, positive_labels], dim=-1)
360 |
361 | # reference from https://github.com/OpenRLHF/OpenRLHF/blob/460477d628751bfaa95297af2763f2fd729ecd20/openrlhf/models/loss.py#L259
362 | placeholder_positions = (inputs["input_ids"] == placeholder_token_id).nonzero()
363 | shift_placeholder_positions = placeholder_positions.clone()
364 |         shift_placeholder_positions[:, -1] -= 1  # causal shift: logits at t-1 predict the placeholder token at t
365 | logits = logits[shift_placeholder_positions[:, 0], shift_placeholder_positions[:, 1], :]
366 | labels = labels[placeholder_positions[:, 0], placeholder_positions[:, 1]]
367 | if len(candidate_token_ids_for_prediction) != 2:
368 | raise ValueError("The number of candidate tokens for prediction must be 2.")
369 | logits = logits[..., candidate_token_ids_for_prediction]
370 | positive_labels = labels.to(logits.dtype)
371 | negative_labels = 1 - positive_labels
372 | labels = torch.stack([positive_labels, negative_labels], dim=-1)
373 | reduction = 'sum' if num_items_in_batch is not None else 'mean'
374 | loss = self.loss_func(logits, labels, reduction=reduction)
375 | if reduction == 'sum':
376 | loss /= num_items_in_batch
377 |
378 | if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
379 | loss *= self.accelerator.num_processes
380 |
381 | return (loss, outputs) if return_outputs else loss
382 |
383 |
384 | def run_exp(model_args, data_args, training_args):
385 |     print_rank_0('loading model and tokenizer...')
386 | model, tokenizer = setup_model_and_tokenizer(model_args.model_name_or_path)
387 |
388 | processor = DatasetProcessor(data_args, tokenizer)
389 | print_rank_0('start data processing...')
390 | tokenized_datasets = processor.prepare_datasets(training_args)
391 | if not data_args.use_soft_label:
392 | data_collator = DataCollatorForSeq2Seq(tokenizer)
393 | else:
394 | data_collator = DataCollatorForSeq2SeqWithSoftLabels(tokenizer)
395 |
396 | world_size = DIST_STATE.num_processes
397 | per_device_total_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
398 | TOTAL_BATCH_SIZE = per_device_total_batch_size * world_size
399 | print_rank_0(f"Total batch size: {TOTAL_BATCH_SIZE}")
400 |
401 | fp = f'bs_{TOTAL_BATCH_SIZE}_g_{training_args.gradient_accumulation_steps}_lr_{training_args.learning_rate}_ep_{training_args.num_train_epochs}'
402 | training_args.output_dir = os.path.join(training_args.output_dir, fp)
403 | training_args.logging_dir = os.path.join(training_args.output_dir, 'logs')
404 |
405 | prm_use_tokens_cfg = {
406 | "placeholder_token_id": processor.step_tag_id,
407 | "candidate_token_ids_for_prediction": processor.candidate_token_ids,
408 | }
409 | if data_args.use_soft_label:
410 | trainer = PRMTrainerForTokenPrediction(
411 | prm_use_tokens_cfg=prm_use_tokens_cfg,
412 | model=model,
413 | args=training_args,
414 | train_dataset=tokenized_datasets["train"],
415 | eval_dataset=tokenized_datasets["test"], # Replace with a validation set if available
416 | data_collator=data_collator,
417 | tokenizer=tokenizer,
418 | preprocess_logits_for_metrics=processor.preprocess_logits_for_soft_label_metrics,
419 | compute_metrics=processor.compute_metrics,
420 | )
421 | else:
422 | trainer = Trainer(
423 | model=model,
424 | args=training_args,
425 | train_dataset=tokenized_datasets["train"],
426 | eval_dataset=tokenized_datasets["test"], # Replace with a validation set if available
427 | data_collator=data_collator,
428 | tokenizer=tokenizer,
429 | preprocess_logits_for_metrics=processor.preprocess_logits_for_metrics,
430 | compute_metrics=processor.compute_metrics,
431 | )
432 |
433 | trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
434 |
435 | trainer.save_state()
436 | trainer.save_model(output_dir=training_args.output_dir)
437 |
438 | if __name__ == '__main__':
439 | ...
--------------------------------------------------------------------------------
/src/prm_training/train_prm/utils/accelerator_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from accelerate import Accelerator
3 | from typing import Optional
4 |
5 | class AcceleratorManager:
6 | _instance: Optional[Accelerator] = None
7 |
8 | @classmethod
9 | def initialize(cls, **kwargs):
10 | if cls._instance is None:
11 | cls._instance = Accelerator(**kwargs)
12 |
13 | if cls._instance.is_main_process:
14 | print(f"Distributed training setup:")
15 | print(f"- Number of processes: {cls._instance.num_processes}")
16 | print(f"- Mixed precision: {cls._instance.mixed_precision}")
17 | print(f"- Gradient accumulation steps: {cls._instance.gradient_accumulation_steps}")
18 |
19 | @classmethod
20 | def get_accelerator(cls) -> Accelerator:
21 | if cls._instance is None:
22 | raise RuntimeError(
23 | "Accelerator not initialized. Call AcceleratorManager.initialize() first."
24 | )
25 | return cls._instance
26 |
27 | @classmethod
28 | def is_initialized(cls) -> bool:
29 | return cls._instance is not None
30 |
31 |
32 | def get_accelerator() -> Accelerator:
33 | return AcceleratorManager.get_accelerator()
34 |
35 | def is_main_process() -> bool:
36 | return get_accelerator().is_main_process
37 |
38 | def get_local_rank() -> int:
39 | return get_accelerator().local_process_index
40 |
41 | def get_world_size() -> int:
42 | return get_accelerator().num_processes
43 |
44 | def synchronize():
45 | get_accelerator().wait_for_everyone()
46 |
47 | def print_rank_0(msg):
48 | if is_main_process():
49 | print(msg)
--------------------------------------------------------------------------------
/src/prm_training/train_prm/utils/dist_configs/ds_zero2_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 2,
20 | "allgather_partitions": true,
21 | "allgather_bucket_size": 5e8,
22 | "overlap_comm": true,
23 | "reduce_scatter": true,
24 | "reduce_bucket_size": 5e8,
25 | "contiguous_gradients": true,
26 | "round_robin_gradients": true
27 | }
28 | }
--------------------------------------------------------------------------------
/src/prm_training/train_prm/utils/dist_configs/multi_gpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | # distributed_type: MULTI_GPU
4 | deepspeed_config:
5 | deepspeed_config_file: ./train_prm/utils/dist_configs/ds_zero2_config.json
6 | distributed_type: DEEPSPEED
7 | downcast_bf16: 'no'
8 | gpu_ids: all
9 | machine_rank: 0
10 | # main_training_function: main
11 | # mixed_precision: 'bf16'
12 | num_machines: 1
13 | num_processes: 2
14 | rdzv_backend: static
15 | same_network: true
16 | tpu_env: []
17 | tpu_use_cluster: false
18 | tpu_use_sudo: false
19 | use_cpu: false
--------------------------------------------------------------------------------
/src/prm_training/train_prm/utils/dist_configs/single_gpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | # main_training_function: main
8 | mixed_precision: 'bf16'
9 | num_machines: 1
10 | num_processes: 1
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
--------------------------------------------------------------------------------