├── LICENSE ├── PDF └── o1-Coder.pdf ├── README.md ├── assets └── algo.jpeg └── src ├── RL ├── README.md ├── rewards │ ├── __init__.py │ ├── examples.py │ ├── prm_utils.py │ ├── pyext2.py │ ├── rewards.py │ └── testing_util.py └── rewards_test.py ├── TestCaseGenerate ├── README.md ├── TCG.py └── wash_code.py ├── mcts ├── README.md ├── common │ ├── arguments.py │ └── utils.py ├── data │ └── TACO │ │ ├── self_create_data.json │ │ ├── test_one.json │ │ └── train_easy_10.json ├── eval_src │ ├── Evaluator.py │ ├── checker_utils.py │ ├── pyext2.py │ └── testing_util.py ├── models │ ├── HuggingFace_API.py │ ├── IO_System.py │ ├── OpenAI_API.py │ └── vLLM_API.py ├── prompts │ └── TACO │ │ ├── examples.txt │ │ └── prompt.json ├── run_outputs │ └── data_examples │ │ ├── answer_sheets │ │ ├── Question 0000 - Answer.json │ │ ├── Question 0000 - Best Solutions.json │ │ ├── Question 0000 - Complete Solutions.json │ │ └── Question 0000 - Rollout Solutions.json │ │ ├── args.json │ │ └── intermediate_result.txt ├── run_src │ ├── MCTS_backbone.py │ ├── MCTS_for_reasoning.py │ ├── do_generate.py │ └── rstar_utils.py └── scripts │ ├── api_run_TACO.sh │ ├── run_TACO.sh │ └── start.md └── prm_training ├── README.md ├── data └── examples │ ├── hard_label_examples.json │ └── soft_label_examples.json ├── requirements.txt ├── run.py ├── run.sh └── train_prm ├── __init__.py ├── run_train.py └── utils ├── accelerator_utils.py └── dist_configs ├── ds_zero2_config.json ├── multi_gpu.yaml └── single_gpu.yaml /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ADaM-LAB. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PDF/o1-Coder.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ADaM-BJTU/O1-CODER/f01c769397afaefc89c22c51d048484a79d11c1f/PDF/o1-Coder.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # O1-CODER 2 | [O1-CODER: An O1 Replication for Coding (Paper)](https://arxiv.org/abs/2412.00154) 3 | 4 | ## Overview 5 | 6 | **O1-CODER** is an attempt to replicate OpenAI's **O1 model**, focused on coding tasks. 
The approach combines **Reinforcement Learning (RL)** and **Monte Carlo Tree Search (MCTS)** to enhance the model’s **System-2** thinking capabilities, aiming to generate more efficient and logical code.
7 | 
8 | ### Method
9 | 
10 | The core components of **O1-CODER** are:
11 | 
12 | 1. **Test Case Generator (TCG)**: Automatically generates standardized test cases to evaluate the correctness of the generated code.
13 | 2. **Self-Play and Reinforcement Learning**: The model generates reasoning data through self-play and uses RL and MCTS to iteratively optimize the policy model.
14 | These components work in an iterative cycle, continuously refining the model to improve systematic reasoning and code optimization; the figure and the conceptual sketch below illustrate this loop.
15 | 
16 | 
17 | <img src="./assets/algo.jpeg" alt="O1-CODER algorithm overview">
18 | 
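19 | 
20 | Conceptually, the self-play cycle alternates between generating test cases, synthesizing reasoning data with MCTS, and updating the policy. The toy sketch below is illustrative only: every function is a simplified stand-in rather than an actual API from `src/`, and the reward is a random placeholder instead of real code execution.
21 | 
22 | ```python
23 | import random
24 | 
25 | def generate_test_cases(problem):
26 |     # Stand-in for the Test Case Generator (TCG).
27 |     return [(problem["input"], problem["output"])]
28 | 
29 | def mcts_generate_solutions(problem, n=4):
30 |     # Stand-in for MCTS-based reasoning-data synthesis.
31 |     return [f"candidate_solution_{i}" for i in range(n)]
32 | 
33 | def outcome_reward(solution, tests):
34 |     # Stand-in for running generated code against the tests (fraction passed).
35 |     return random.random()
36 | 
37 | def rl_update(policy, trajectories):
38 |     # Stand-in for the policy-optimization step.
39 |     return policy
40 | 
41 | policy = "policy_v0"
42 | problems = [{"input": "4\n3 6 18 12", "output": "6"}]
43 | 
44 | for round_idx in range(2):  # iterative self-play rounds
45 |     trajectories = []
46 |     for problem in problems:
47 |         tests = generate_test_cases(problem)
48 |         for solution in mcts_generate_solutions(problem):
49 |             trajectories.append((solution, outcome_reward(solution, tests)))
50 |     policy = rl_update(policy, trajectories)
51 | ```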
52 | 
53 | ## News
54 | 
55 | ### Latest Updates
56 | #### - 2024-12-10
57 | - Updated the Reward Aggregator.
58 | 
59 | #### - 2024-12-07
60 | - Updated the training code for the process reward model and the Test Case Generator.
61 | - Updated the MCTS-based data synthesis code for O1-CODER.
62 | 
63 | #### - 2024-12-01
64 | - Updated the technical report for O1-CODER.
65 | 
66 | ---
67 | 
68 | ### Planned Updates
69 | 
70 | TODO: Reinforcement learning code, curated datasets, and derived models.
71 | 
72 | TODO: **Reinforcement Fine-Tuning (RFT) version of O1-Coder**: Due to the characteristics of the test case generator, O1-Coder can generate diverse process supervision data with only a small amount of ground-truth code. Therefore, in the RFT version, we will skip the use of CoT data for initializing the policy model.
73 | 
74 | ---
75 | 
76 | ## License
77 | 
78 | This work is released under the MIT License. See the [LICENSE](./LICENSE) file for more details. By using this code or associated materials, you agree to comply with the terms outlined in the license.
79 | 
80 | 
81 | ## Citation
82 | 
83 | If you use **O1-CODER** or parts of this work in your research or applications, please cite the following paper:
84 | ```
85 | @misc{zhang2024o1codero1replicationcoding,
86 |       title={O1-Coder: An O1 Replication for Coding}, 
87 |       author={Yuxiang Zhang and Shangxi Wu and Yuqi Yang and Jiangming Shu and Jinlin Xiao and Chao Kong and Jitao Sang},
88 |       year={2024},
89 |       eprint={2412.00154},
90 |       archivePrefix={arXiv},
91 |       primaryClass={cs.SE},
92 |       url={https://arxiv.org/abs/2412.00154},
93 | }
94 | ```
--------------------------------------------------------------------------------
/assets/algo.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ADaM-BJTU/O1-CODER/f01c769397afaefc89c22c51d048484a79d11c1f/assets/algo.jpeg
--------------------------------------------------------------------------------
/src/RL/README.md:
--------------------------------------------------------------------------------
1 | # Reward Aggregator
2 | 
3 | ## Introduction
4 | 
5 | This project provides an offline reward update method, suitable for reward annotation tasks in reinforcement learning, especially for reward updates in methods like Iterative DPO.
6 | 
7 | The core component of this project is `RewardAggregater`, which supports flexible configuration of the following parameters:
8 | 
9 | - `phi` function (for reward aggregation)
10 | - `alpha` function (time decay factor)
11 | - `gamma` (discount factor)
12 | 
13 | This tool can compute **intermediate rewards** and **outcome rewards** based on different model outputs.
14 | 
15 | ## Usage
16 | 
17 | The `rewards_test.py` file provides a complete usage example. It demonstrates how to initialize the `RewardAggregater`, load test examples, and compute rewards.
18 | 
19 | Simply run the `rewards_test.py` script:
20 | 
21 | ```bash
22 | python rewards_test.py
23 | ```
24 | 
25 | This will compute and print the rewards for the predefined test cases provided in `examples.py`.
26 | 
27 | 
28 | 
29 | ## Features
30 | 
31 | - **Reward Aggregation**: The `phi` function aggregates the final and intermediate rewards, computing a smoother cumulative reward.
32 | - **Time Decay**: Supports an adjustable linear decay factor `alpha(t)`, allowing the reward weight to be adjusted according to the time step.
33 | - **Offline Update**: Supports offline reward calculation and updates, enabling batch processing of reward signals without requiring online training. 34 | - **Flexible Configuration**: The `RewardAggregater` class allows users to customize reward computation methods, reward decay functions, discount factors, and other parameters. 35 | 36 | ## Planned Updates 37 | 38 | - **Reinforcement Learning Training Code** 39 | - **Online Reward Update**: Plan to add support for online reward updates in future versions. 40 | -------------------------------------------------------------------------------- /src/RL/rewards/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ADaM-BJTU/O1-CODER/f01c769397afaefc89c22c51d048484a79d11c1f/src/RL/rewards/__init__.py -------------------------------------------------------------------------------- /src/RL/rewards/examples.py: -------------------------------------------------------------------------------- 1 | rewards_test_examples = [ 2 | { 3 | "question": "You are given a sequence of $n$ integers, $a_0,a_1,\\ldots,a_{n-1}$. Find and print any integer $\\boldsymbol{x}$ such that $\\boldsymbol{x}$ is divisor of every $a_i$ except for exactly one element.\n\nInput Format\n\nThe first line contains an integer, $n$, denoting the length of the sequence. \n\nThe second line contains $n$ positive space-separated integers describing $a_0,a_1,\\ldots,a_{n-1}$.\n\nConstraints\n\n$1\\leq n\\leq10^5$\n$1\\leq a_i\\leq10^{18}$\nIt is guaranteed that a solution exists.\n\nOutput Format\n\nPrint any positive integer denoting $\\boldsymbol{x}$ such that $\\boldsymbol{x}$ is a divisor of exactly $n-1$ of the sequence's elements. $\\boldsymbol{x}$ must be between $\\mbox{1}$ and $2\\cdot10^{18}$\n\nSample Input 0\n4\n3 6 18 12\n\nSample Output 0\n6\n\nExplanation 0\n\nWe are given the array $[3,6,18,12]$. There are two possible answers:\n\n$x=6$ is a divisor of $\\boldsymbol{6}$, $12$, and $\\mbox{18}$ but not a divisor of $3$.\n$x=2$ is a divisor of $\\boldsymbol{6}$, $12$, and $\\mbox{18}$ but not a divisor of $3$.\n\nThus, we can print either $\\boldsymbol{6}$ or $2$ as our answer.", 4 | "reasoning_steps": [ 5 | "Step 1: Defining Function Structures Using pseudocode\nWe start by defining the main function `find_divisor` which will take the input and compute the integer \\( x \\) that meets the criteria. We also define a helper function to check if a given number is a divisor of all elements in the array except one.\n\n[Pseudo Start]\n```\nFunction is_divisor_except_one(arr, x):\n Initialize count to 0\n For each element a in arr:\n If a modulo x is 0, increment count\n If count is equal to length of arr - 1, return True\n Return False\n\nFunction find_divisor(n, arr):\n For each potential divisor x in arr:\n If is_divisor_except_one(arr, x) is True, return x\n Return -1 # This line should never be reached as per the problem statement\n```\n[Pseudo End]", 6 | "Step 2: Refine part of the pseudocode\nNotice that the most efficient way to find a number that is a divisor of all elements except one is to consider the elements of the array themselves, as they are potential candidates for \\( x \\). 
We'll iterate through the array and for each candidate, check if it divides \\( n-1 \\) elements.\n\n[Pseudo Start]\n```\nFunction find_divisor(n, arr):\n For each element x in arr:\n Initialize count to 0\n For each element a in arr:\n If a modulo x is 0, increment count\n If count is equal to n - 1, return x\n Return -1 # This line should never be reached as per the problem statement\n```\n[Pseudo End]", 7 | "Step 3: Generate python code from the pseudocode\nWe can now convert the refined pseudocode into working Python code.\n\nThe code is:\n```python\ndef find_divisor(n, arr):\n for x in arr:\n count = 0\n for a in arr:\n if a % x == 0:\n count += 1\n if count == n - 1:\n return x\n return -1 # This line should never be reached as per the problem statement\n\n# Read input\nn = int(input())\narr = list(map(int, input().split()))\n\n# Find and print the result\nresult = find_divisor(n, arr)\nprint(result)\n```\n\n### Optimization Considerations\nThe above code works correctly but is not optimal for large arrays, as it has a time complexity of \\( O(n^2) \\) due to the nested loop. Instead, we can optimize by leveraging the greatest common divisor (GCD) properties.", 8 | "Step 4: Refine part of the pseudocode\nA more optimal approach involves computing the GCD of all elements in the array and then checking each unique element in the array to see if it diverts from the GCD pattern for exactly one element.\n\n[Pseudo Start]\n```\nFunction gcd(a, b):\n While b is not 0:\n temp = b\n b = a % b\n a = temp\n Return a\n\nFunction gcd_of_list(lst):\n current_gcd = lst[0]\n For each element in lst:\n current_gcd = gcd(current_gcd, element)\n Return current_gcd\n\nFunction find_divisor(n, arr):\n full_gcd = gcd_of_list(arr)\n For each element x in arr:\n If x is not equal to full_gcd:\n Initialize count to 0\n For each element a in arr:\n If a modulo x is 0, increment count\n If count is equal to n - 1, return x\n Return full_gcd # If all elements are the same, return the gcd itself\n```\n[Pseudo End]", 9 | "Step 5: Generate python code from the pseudocode\nWe can now convert the refined pseudocode into working Python code with the optimized approach.\n\nThe code is:\n```python\ndef gcd(a, b):\n while b:\n a, b = b, a % b\n return a\n\ndef gcd_of_list(lst):\n current_gcd = lst[0]\n for element in lst:\n current_gcd = gcd(current_gcd, element)\n return current_gcd\n\ndef find_divisor(n, arr):\n full_gcd = gcd_of_list(arr)\n for x in arr:\n if x != full_gcd:\n count = sum(1 for a in arr if a % x == 0)\n if count == n - 1:\n return x\n return full_gcd # If all elements are the same, return the gcd itself\n\n# Read input\nn = int(input())\narr = list(map(int, input().split()))\n\n# Find and print the result\nresult = find_divisor(n, arr)\nprint(result)\n```\n\nThis optimized solution works efficiently even for large input sizes and ensures that we find the correct divisor in a more optimal manner." 
10 |         ],
11 |         "test_cases": {
12 |             "inputs": [
13 |                 "4\n3 6 18 12\n"
14 |             ],
15 |             "outputs": [
16 |                 "6\n"
17 |             ]
18 |         }
19 |     }
20 | ]
--------------------------------------------------------------------------------
/src/RL/rewards/prm_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import warnings
3 | from dataclasses import dataclass, field
4 | from enum import Enum
5 | from transformers import AutoModelForCausalLM, AutoTokenizer
6 | from typing import List, Optional, Union, Dict, Literal, TypeAlias
7 | 
8 | 
9 | CODEPRM_PROMPT = """Please refer to the given task description and provide a thought process in the form of step-by-step pseudocode refinement.
10 | 
11 | A curious user has approached you with a programming question. You should give step-by-step solutions to the user's questions. For each step you can choose one of the following three actions:
12 | Defining Function Structures Using pseudocode
13 | Refine part of the pseudocode
14 | Generate python code from the pseudocode
15 | 
16 | ## Structure Guidelines:
17 | 1. Please note that the pseudocode should be detailed and provide a step-by-step solution. Each step should logically build upon the previous one, moving from high-level structure to detailed implementation.
18 | 2. Each step should follow the format: "Step x: <action>" (where x is the step number and <action> is one of the specified actions).
19 | 3. The pseudocode should be presented in the format: "[Pseudo Start]<pseudocode>[Pseudo End]".
20 | 4. At the final step, provide the complete Python code in this format: "The code is: [Code Start]<code>[Code End]." Here, <code> should contain working Python code based on the final pseudocode, and it must be enclosed within Python code block syntax.
21 | 
22 | ## Notes
23 | 1. Aim to break down the solution into as many detailed, intermediate steps as possible while ensuring logical coherence between steps and avoiding unnecessary redundancy.
24 | 2. The Python code solution should match the input and output requirements as described in the question. This means the solution may use terminal I/O for inputs and outputs, or it may require function parameters and return values. Carefully review the question's description to determine the expected code structure, and ensure there are no input/output format errors.
25 | 3. Gradually refine each functional part of the pseudocode, breaking down complex operations into manageable steps.
26 | 4. Transition to Python code only once all parts of the pseudocode have been fully refined.
27 | 5. Do not generate content unrelated to the answer or any other explanations.
28 | 
29 | Now, with the problem description provided below, you need to provide or complete a full, step-by-step solution according to the previous explanations. **If the 'Solution' section is empty, please directly provide a complete, step-by-step solution.
If it is not empty, do not repeat or rephrase existing content; simply continue from where it left off to complete the solution.**
30 | ### Description
31 | {question}
32 | 
33 | ### Solution
34 | """
35 | 
36 | 
37 | @dataclass
38 | class StepTokensForLM:
39 |     step_tag: str = field(
40 |         default=' Rating',
41 |         metadata={'help': 'The tag that indicates the end of a step/action'}
42 |     )
43 |     good_token: str = field(
44 |         default=' +',
45 |         metadata={'help': 'The token that indicates a positive action'}
46 |     )
47 |     bad_token: str = field(
48 |         default=' -',
49 |         metadata={'help': 'The token that indicates a negative action'}
50 |     )
51 | 
52 | 
53 | class RewardStrategy(Enum):
54 |     """
55 |     Enum class for the reward computation strategy.
56 |     """
57 |     TOKEN_LOGITS = 'token_logits'
58 |     VALUE_HEAD = 'value_head'
59 | 
60 | 
61 | PromptType: TypeAlias = Union[str, List[str]]
62 | PrefixesType: TypeAlias = List[PromptType]
63 | 
64 | @torch.no_grad()
65 | def get_process_rewards(model: AutoModelForCausalLM,
66 |                         tokenizer: AutoTokenizer,
67 |                         prompts: PromptType,
68 |                         completed_processes: PrefixesType,
69 |                         tokenized_format: Optional[Literal['completion', 'chat_completion']],
70 |                         reward_strategy: Optional[RewardStrategy] = RewardStrategy.TOKEN_LOGITS.value,
71 |                         ) -> List[float]:
72 | 
73 |     if isinstance(prompts, str):
74 |         if isinstance(completed_processes[0], list):
75 |             raise ValueError("The `completed_processes` argument must be a list of strings if `prompts` is a string.")
76 |         prompts = [prompts]
77 |         completed_processes = [completed_processes]
78 |     if isinstance(prompts, list) and isinstance(completed_processes[0], str):
79 |         raise ValueError("The `completed_processes` argument must be a list of lists if `prompts` is a list.")
80 |     if len(prompts) != len(completed_processes):
81 |         raise ValueError("The number of prompts must match the number of completed processes in order.")
82 | 
83 |     if reward_strategy == RewardStrategy.TOKEN_LOGITS.value:
84 |         tokenized_ids = tokenizer(
85 |             [StepTokensForLM.step_tag, StepTokensForLM.good_token, StepTokensForLM.bad_token]
86 |         )["input_ids"]
87 |         if any(len(tokenized_id) != 1 for tokenized_id in tokenized_ids):
88 |             raise ValueError("The tokens `step_tag`, `good_token`, `bad_token` used in class `StepTokensForLM` must be single tokens.")
89 |         step_token_id, good_token_id, bad_token_id = [ids[0] for ids in tokenized_ids]
90 | 
91 |         if tokenized_format == 'completion':
92 |             raise NotImplementedError("Token Logits reward strategy is not implemented for completion tokenized format.")
93 |         elif tokenized_format == 'chat_completion':
94 |             input_texts = []
95 |             for prompt, processes in zip(prompts, completed_processes):
96 |                 step_tag_inserted_process = ''
97 |                 for process in processes:
98 |                     step_tag_inserted_process += process + StepTokensForLM.step_tag
99 | 
100 |                 messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": step_tag_inserted_process}]
101 |                 chat_template = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
102 | 
103 |                 end_token_of_one_turn = '<|im_end|>\n'
104 |                 chat_template = chat_template.removesuffix(end_token_of_one_turn)
105 |                 input_texts.append(chat_template)
106 | 
107 |             model_inputs = tokenizer(input_texts, return_tensors='pt', padding=True).to(model.device)
108 |             step_tag_positions = (model_inputs["input_ids"] == step_token_id).nonzero()
109 |             model_inputs["attention_mask"][step_tag_positions[:, 0], step_tag_positions[:, 1]] = 0
110 | 
111 |             outputs = model(**model_inputs)
112 | 
logits = outputs.logits 113 | logits = logits[..., [good_token_id, bad_token_id]] 114 | scores = logits[step_tag_positions[:, 0], step_tag_positions[:, 1] - 1, :].softmax(dim=-1) # [bs, 2(' +', ' -')] 115 | 116 | return scores.cpu().tolist() 117 | else: 118 | raise ValueError("The `tokenized_format` argument must be either 'completion' or 'chat_completion'.") 119 | 120 | 121 | if __name__ == '__main__': 122 | model_path = '/root/autodl-tmp/models/Qwen2.5-Coder-7B-Instruct-PRM' 123 | model = AutoModelForCausalLM.from_pretrained( 124 | model_path, 125 | torch_dtype=torch.bfloat16, 126 | attn_implementation="flash_attention_2" 127 | ).to('cuda').eval() 128 | tokenizer = AutoTokenizer.from_pretrained(model_path) 129 | 130 | problem = 'Given some positive integers, I wish to print the integers such that all take up the same width by adding a minimum number of leading zeroes. No leading zeroes shall be added to the largest integer.\n\nFor example, given `1, 23, 2, 17, 102`, I wish to print out these numbers as follows:\n\n```python\n001\n023\n002\n017\n102\n```\n\nWrite a function `print_nums(n1, n2, n3, ...)` that takes a variable number of arguments and returns the string to be printed out.' 131 | prompt = [ 132 | CODEPRM_PROMPT.format(question=problem), 133 | problem 134 | ] 135 | examples = [ 136 | ["Step 1: Defining Function Structures Using pseudocode\nWe start by defining the structure of our solution. We need a function `reorder` that takes two integers `N` and `M`, and returns a numpy array with two sub-arrays. Each sub-array will contain numbers in the specified ranges and will be rotated `M` times.\n\n[Pseudo Start]\n```\nFunction reorder(N, M):\n Calculate half of N\n Create the first sub-array with numbers in the range [0, N/2)\n Create the second sub-array with numbers in the range [N/2, N)\n Rotate the first sub-array M times\n Rotate the second sub-array M times\n Combine the two sub-arrays into a numpy array\n Return the combined numpy array\n[Pseudo End]\n```"], 137 | ["Step 1: Defining Function Structures Using pseudocode\nWe start by defining the structure of our solution. We need a function `reorder` that takes two integers `N` and `M`, and returns a numpy array with two sub-arrays. 
Each sub-array will contain numbers in the specified ranges and will be rotated `M` times.\n\n[Pseudo Start]\n```\nFunction reorder(N, M):\n Calculate half of N\n Create the first sub-array with numbers in the range [0, N/2)\n Create the second sub-array with numbers in the range [N/2, N)\n Rotate the first sub-array M times\n Rotate the second sub-array M times\n Combine the two sub-arrays into a numpy array\n Return the combined numpy array\n[Pseudo End]\n```"] 138 | ] 139 | 140 | get_process_rewards(model, tokenizer, prompt, examples, 'chat_completion', 'token_logits') -------------------------------------------------------------------------------- /src/RL/rewards/pyext2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2014 Ryan Gonzalez 3 | 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to use, 8 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 9 | Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | ''' 22 | 23 | g_backup = globals().copy() 24 | 25 | __version__ = '0.7' 26 | 27 | __all__ = ['overload', 'RuntimeModule', 'switch', 'tail_recurse', 'copyfunc', 'set_docstring', 'annotate', 'safe_unpack', 'modify_function', 'assign', 'fannotate', 'compare_and_swap', 'is_main', 'call_if_main', 'run_main'] 28 | 29 | import sys, inspect, types 30 | 31 | def __targspec(func, specs, attr='__orig_arg__'): 32 | if hasattr(func, '__is_overload__') and func.__is_overload__: 33 | return getattr(func, attr) 34 | return specs(func) 35 | 36 | def set_docstring(doc): 37 | '''A simple decorator to set docstrings. 38 | 39 | :param doc: The docstring to tie to the function. 40 | 41 | Example:: 42 | 43 | @set_docstring('This is a docstring') 44 | def myfunc(x): 45 | pass''' 46 | def _wrap(f): 47 | f.__doc__ = doc 48 | return f 49 | return _wrap 50 | 51 | __modify_function_doc = ''' 52 | Creates a copy of a function, changing its attributes. 53 | 54 | :param globals: Will be added to the function's globals. 55 | 56 | :param name: The new function name. Set to ``None`` to use the function's original name. 57 | 58 | :param code: The new function code object. Set to ``None`` to use the function's original code object. 59 | 60 | :param defaults: The new function defaults. Set to ``None`` to use the function's original defaults. 61 | 62 | :param closure: The new function closure. Set to ``None`` to use the function's original closure. 63 | 64 | .. warning:: This function can be potentially dangerous. 65 | ''' 66 | 67 | def copyfunc(f): 68 | '''Copies a funcion. 69 | 70 | :param f: The function to copy. 
71 | 72 | :return: The copied function. 73 | 74 | .. deprecated:: 0.4 75 | Use :func:`modify_function` instead. 76 | ''' 77 | return modify_function(f) 78 | 79 | if sys.version_info.major == 3: 80 | @set_docstring(__modify_function_doc) 81 | def modify_function(f, globals={}, name=None, code=None, defaults=None, 82 | closure=None): 83 | if code is None: code = f.__code__ 84 | if name is None: name = f.__name__ 85 | if defaults is None: defaults = f.__defaults__ 86 | if closure is None: closure = f.__closure__ 87 | newf = types.FunctionType(code, dict(f.__globals__, **globals), name=name, 88 | argdefs=defaults, closure=closure) 89 | newf.__dict__.update(f.__dict__) 90 | return newf 91 | def argspec(f): 92 | return inspect.getfullargspec(f) 93 | ofullargspec = inspect.getfullargspec 94 | def _fullargspec(func): 95 | return __targspec(func, ofullargspec) 96 | inspect.getfullargspec = _fullargspec 97 | def _exec(m,g): exec(m,g) 98 | else: 99 | @set_docstring(__modify_function_doc) 100 | def modify_function(f, globals={}, name=None, code=None, defaults=None, 101 | closure=None): 102 | if code is None: code = f.func_code 103 | if name is None: name = f.__name__ 104 | if defaults is None: defaults = f.func_defaults 105 | if closure is None: closure = f.func_closure 106 | newf = types.FunctionType(code, dict(f.func_globals, **globals), name=name, 107 | argdefs=defaults, closure=closure) 108 | newf.__dict__.update(f.__dict__) 109 | return newf 110 | def argspec(f): 111 | return inspect.getargspec(f) 112 | eval(compile('def _exec(m,g): exec m in g', '', 'exec')) 113 | 114 | def _gettypes(args): 115 | return tuple(map(type, args)) 116 | 117 | oargspec = inspect.getargs 118 | 119 | def _argspec(func): 120 | return __targspec(func, oargspec) 121 | 122 | inspect.getargspec = _argspec 123 | 124 | try: 125 | import IPython 126 | except ImportError: 127 | IPython = None 128 | else: 129 | # Replace IPython's argspec 130 | oipyargspec = IPython.core.oinspect.getargspec 131 | def _ipyargspec(func): 132 | return __targspec(func, oipyargspec, '__orig_arg_ipy__') 133 | IPython.core.oinspect.getargspec = _ipyargspec 134 | 135 | class overload(object): 136 | '''Simple function overloading in Python.''' 137 | _items = {} 138 | _types = {} 139 | @classmethod 140 | def argc(self, argc=None): 141 | '''Overloads a function based on the specified argument count. 142 | 143 | :param argc: The argument count. Defaults to ``None``. If ``None`` is given, automatically compute the argument count from the given function. 144 | 145 | .. note:: 146 | 147 | Keyword argument counts are NOT checked! In addition, when the argument count is automatically calculated, the keyword argument count is also ignored! 
148 | 149 | Example:: 150 | 151 | @overload.argc() 152 | def func(a): 153 | print 'Function 1 called' 154 | 155 | @overload.argc() 156 | def func(a, b): 157 | print 'Function 2 called' 158 | 159 | func(1) # Calls first function 160 | func(1, 2) # Calls second function 161 | func() # Raises error 162 | ''' 163 | # Python 2 UnboundLocalError fix 164 | argc = {'argc': argc} 165 | def _wrap(f): 166 | def _newf(*args, **kwargs): 167 | if len(args) not in self._items[f.__name__]: 168 | raise TypeError("No overload of function '%s' that takes %d args" % (f.__name__, len(args))) 169 | return self._items[f.__name__][len(args)](*args, **kwargs) 170 | if f.__name__ not in self._items: 171 | self._items[f.__name__] = {} 172 | if argc['argc'] is None: 173 | argc['argc'] = len(argspec(f).args) 174 | self._items[f.__name__][argc['argc']] = f 175 | _newf.__name__ = f.__name__ 176 | _newf.__doc__ = f.__doc__ 177 | _newf.__is_overload__ = True 178 | _newf.__orig_arg__ = argspec(f) 179 | if IPython: 180 | _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f) 181 | return _newf 182 | return _wrap 183 | @classmethod 184 | def args(self, *argtypes, **kw): 185 | '''Overload a function based on the specified argument types. 186 | 187 | :param argtypes: The argument types. If None is given, get the argument types from the function annotations(Python 3 only) 188 | :param kw: Can only contain 1 argument, `is_cls`. If True, the function is assumed to be part of a class. 189 | 190 | Example:: 191 | 192 | @overload.args(str) 193 | def func(s): 194 | print 'Got string' 195 | 196 | @overload.args(int, str) 197 | def func(i, s): 198 | print 'Got int and string' 199 | 200 | @overload.args() 201 | def func(i:int): # A function annotation example 202 | print 'Got int' 203 | 204 | func('s') 205 | func(1) 206 | func(1, 's') 207 | func(True) # Raises error 208 | ''' 209 | 210 | # Python 2 UnboundLocalError fix...again! 211 | argtypes = {'args': tuple(argtypes)} 212 | def _wrap(f): 213 | def _newf(*args): 214 | if len(kw) == 0: 215 | cargs = args 216 | elif len(kw) == 1 and 'is_cls' in kw and kw['is_cls']: 217 | cargs = args[1:] 218 | else: 219 | raise ValueError('Invalid keyword args specified') 220 | if _gettypes(cargs) not in self._types[f.__name__]: 221 | raise TypeError("No overload of function '%s' that takes '%s' types and %d arg(s)" % (f.__name__, _gettypes(cargs), len(cargs))) 222 | return self._types[f.__name__][_gettypes(cargs)](*args) 223 | if f.__name__ not in self._types: 224 | self._types[f.__name__] = {} 225 | if len(argtypes['args']) == 1 and argtypes['args'][0] is None: 226 | aspec = argspec(f) 227 | argtypes['args'] = tuple(map(lambda x: x[1], sorted( 228 | aspec.annotations.items(), key=lambda x: aspec.args.index(x[0])))) 229 | self._types[f.__name__][argtypes['args']] = f 230 | _newf.__name__ = f.__name__ 231 | _newf.__doc__ = f.__doc__ 232 | _newf.__is_overload__ = True 233 | _newf.__orig_arg__ = argspec(f) 234 | if IPython: 235 | _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f) 236 | return _newf 237 | return _wrap 238 | 239 | class _RuntimeModule(object): 240 | 'Create a module object at runtime and insert it into sys.path. If called, same as :py:func:`from_objects`.' 
241 | def __call__(self, *args, **kwargs): 242 | return self.from_objects(*args, **kwargs) 243 | @staticmethod 244 | @overload.argc(1) 245 | def from_objects(module_name_for_code_eval, **d): 246 | return _RuntimeModule.from_objects(module_name_for_code_eval, '', **d) 247 | @staticmethod 248 | @overload.argc(2) 249 | def from_objects(module_name_for_code_eval, docstring, **d): 250 | '''Create a module at runtime from `d`. 251 | 252 | :param name: The module name. 253 | 254 | :param docstring: Optional. The module's docstring. 255 | 256 | :param \*\*d: All the keyword args, mapped from name->value. 257 | 258 | Example: ``RuntimeModule.from_objects('name', 'doc', a=1, b=2)``''' 259 | module = types.ModuleType(module_name_for_code_eval, docstring) 260 | module.__dict__.update(d) 261 | module.__file__ = '' 262 | sys.modules[module_name_for_code_eval] = module 263 | return module 264 | @staticmethod 265 | @overload.argc(2) 266 | def from_string(module_name_for_code_eval, s): 267 | return _RuntimeModule.from_string(module_name_for_code_eval, '', s) 268 | @staticmethod 269 | @overload.argc(3) 270 | def from_string(module_name_for_code_eval, docstring, s): 271 | '''Create a module at runtime from `s``. 272 | 273 | :param name: The module name. 274 | 275 | :param docstring: Optional. The module docstring. 276 | 277 | :param s: A string containing the module definition.''' 278 | g = {} 279 | _exec(s, g) 280 | return _RuntimeModule.from_objects(module_name_for_code_eval, docstring, **dict(filter(lambda x: x[0] not in g_backup, g.items()))) 281 | 282 | RuntimeModule = _RuntimeModule() 283 | 284 | class CaseObject(object): 285 | 'The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. It can be called with multiple parameters, in which case it checks if its value equals any of the arguments.' 286 | def __init__(self, value): 287 | self.value = value 288 | self.did_match = False 289 | self.did_pass = False 290 | def __call__(self, *args): 291 | if assign('res', not self.did_pass and any([self.value == rhs for rhs in args])): 292 | self.did_match = True 293 | return res 294 | def quit(self): 295 | 'Forces all other calls to return False. Equilavent of a ``break`` statement.' 296 | self.did_pass = True 297 | def default(self): 298 | "Executed if quit wasn't called." 299 | return not self.did_match and not self.did_pass 300 | def __iter__(self): 301 | yield self 302 | def __enter__(self): 303 | return self 304 | def __exit__(self, *args): 305 | pass 306 | 307 | def switch(value): 308 | '''A Python switch statement implementation that is used with a ``with`` statement. 309 | 310 | :param value: The value to "switch". 311 | 312 | ``with`` statement example:: 313 | 314 | with switch('x'): 315 | if case(1): print 'Huh?' 316 | if case('x'): print 'It works!!!' 317 | 318 | .. warning:: If you modify a variable named "case" in the same scope that you use the ``with`` statement version, you will get an UnboundLocalError. The soluction is to use ``with switch('x') as case:`` instead of ``with switch('x'):``.''' 319 | res = CaseObject(value) 320 | inspect.stack()[1][0].f_globals['case'] = res 321 | return res 322 | 323 | def tail_recurse(spec=None): 324 | '''Remove tail recursion from a function. 325 | 326 | :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. If ``None,`` tail recursion is always called unless the function returns a value. 327 | 328 | .. 
note:: 329 | 330 | This function has a slight overhead that is noticable when using timeit. Only use it if the function has a possibility of going over the recursion limit. 331 | 332 | .. warning:: 333 | 334 | This function will BREAK any code that either uses any recursion other than tail recursion or calls itself multiple times. For example, ``def x(): return x()+1`` will fail. 335 | 336 | Example:: 337 | 338 | @tail_recurse() 339 | def add(a, b): 340 | if a == 0: return b 341 | return add(a-1, b+1) 342 | 343 | add(10000000, 1) # Doesn't max the recursion limit. 344 | ''' 345 | def _wrap(f): 346 | class TailRecursion(Exception): 347 | def __init__(self, args, kwargs): 348 | self.args = args 349 | self.kwargs = kwargs 350 | def _newf(*args, **kwargs): 351 | if inspect.stack()[1][3] == f.__name__: 352 | if (spec and spec(args)) or not spec: 353 | raise TailRecursion(args, kwargs) 354 | while True: 355 | try: 356 | res = f(*args, **kwargs) 357 | except TailRecursion as ex: 358 | args = ex.args 359 | kwargs = ex.kwargs 360 | continue 361 | else: 362 | return res 363 | _newf.__doc__ = f.__doc__ 364 | return _newf 365 | return _wrap 366 | 367 | def annotate(*args, **kwargs): 368 | '''Set function annotations using decorators. 369 | 370 | :param args: This is a list of annotations for the function, in the order of the function's parameters. For example, ``annotate('Annotation 1', 'Annotation 2')`` will set the annotations of parameter 1 of the function to ``Annotation 1``. 371 | 372 | :param kwargs: This is a mapping of argument names to annotations. Note that these are applied *after* the argument list, so any args set that way will be overriden by this mapping. If there is a key named `ret`, that will be the annotation for the function's return value. 373 | 374 | .. deprecated:: 0.5 375 | Use :func:`fannotate` instead. 376 | ''' 377 | def _wrap(f): 378 | if not hasattr(f, '__annotations__'): 379 | f.__annotations__ = {} 380 | if 'ret' in kwargs: 381 | f.__annotations__['return'] = kwargs.pop('ret') 382 | f.__annotations__.update(dict(zip(argspec(f).args, args))) 383 | f.__annotations__.update(kwargs) 384 | return f 385 | return _wrap 386 | 387 | def fannotate(*args, **kwargs): 388 | '''Set function annotations using decorators. 389 | 390 | :param \*args: The first positional argument is used for the function's return value; all others are discarded. 391 | 392 | :param \**kwargs: This is a mapping of argument names to annotations. 393 | 394 | Example:: 395 | 396 | @fannotate('This for the return value', a='Parameter a', b='Parameter b') 397 | def x(a, b): 398 | pass 399 | 400 | ''' 401 | def _wrap(f): 402 | if not hasattr(f, '__annotations__'): 403 | f.__annotations__ = {} 404 | if len(args) >= 1: 405 | f.__annotations__['return'] = args[0] 406 | f.__annotations__.update(kwargs) 407 | return f 408 | return _wrap 409 | 410 | def safe_unpack(seq, ln, fill=None): 411 | '''Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. Empty values will be filled in with `fill`, while any extra values will be cut off. 412 | 413 | :param seq: The sequence to unpack. 414 | 415 | :param ln: The expected length of the sequence. 416 | 417 | :param fill: The value to substitute if the sequence is too small. Defaults to ``None``. 
418 | 419 | Example:: 420 | 421 | s = 'a:b' 422 | a, b = safe_unpack(s.split(':'), 2) 423 | # a = 'a' 424 | # b = 'b' 425 | s = 'a' 426 | a, b = safe_unpack(s.split(':'), 2) 427 | # a = 'a' 428 | # b = None''' 429 | if len(seq) > ln: 430 | return seq[:ln] 431 | elif len(seq) < ln: 432 | return seq + type(seq)([fill]*(ln-len(seq))) 433 | else: 434 | return seq 435 | 436 | def assign(varname, value): 437 | '''Assign `value` to `varname` and return it. If `varname` is an attribute and the instance name it belongs to is not defined, a NameError is raised. 438 | This can be used to emulate assignment as an expression. For example, this:: 439 | 440 | if assign('x', 7): ... 441 | 442 | is equilavent to this C code:: 443 | 444 | if (x = 7) ... 445 | 446 | .. warning:: 447 | 448 | When assigning an attribute, the instance it belongs to MUST be declared as global prior to the assignment. Otherwise, the assignment will not work. 449 | ''' 450 | fd = inspect.stack()[1][0].f_globals 451 | if '.' not in varname: 452 | fd[varname] = value 453 | else: 454 | vsplit = list(map(str.strip, varname.split('.'))) 455 | if vsplit[0] not in fd: 456 | raise NameError('Unknown object: %s'%vsplit[0]) 457 | base = fd[vsplit[0]] 458 | for x in vsplit[1:-1]: 459 | base = getattr(base, x) 460 | setattr(base, vsplit[-1], value) 461 | return value 462 | 463 | def is_main(frame=1): 464 | "Return if the caller is main. Equilavent to ``__name__ == '__main__'``." 465 | return inspect.stack()[frame][0].f_globals['__name__'] == '__main__' 466 | 467 | def _call_if_main(frame, f, args): 468 | if is_main(frame): return f(*args) 469 | 470 | def call_if_main(f,*args): 471 | "Call the `f` with `args` if the caller's module is main." 472 | return _call_if_main(3,f,args) 473 | 474 | def run_main(f,*args): 475 | "Call `f` with the `args` and terminate the program with its return code if the caller's module is main." 476 | sys.exit(_call_if_main(3,f,args)) 477 | 478 | def compare_and_swap(var, compare, new): 479 | "If `var` is equal to `compare`, set it to `new`." 480 | if assign('v', inspect.stack()[1][0].f_globals)[var] == compare: 481 | v[var] = new 482 | -------------------------------------------------------------------------------- /src/RL/rewards/rewards.py: -------------------------------------------------------------------------------- 1 | from .testing_util import run_test 2 | import torch 3 | import re 4 | import numpy as np 5 | from transformers import Qwen2ForSequenceClassification, AutoModelForCausalLM 6 | from transformers import AutoTokenizer 7 | 8 | from .prm_utils import get_process_rewards, PromptType, PrefixesType 9 | 10 | 11 | def phi(R_i, r_i, t, alpha_t_func, gamma, m=None): 12 | r""" 13 | Calculate the aggregated reward function \(\phi(R_i, r_i^{1:m})\). 14 | 15 | Parameters: 16 | - R_i: The final reward (scalar). 17 | - r_i: A sequence of intermediate rewards (array of length m). 18 | - t: The current timestep (scalar). 19 | - alpha_t_func: A function for the time-varying factor \(\alpha(t)\), which takes the time step t as input. 20 | - gamma: The discount factor (scalar, in the range [0, 1]). 21 | - m: The number of intermediate rewards (scalar). 22 | 23 | Returns: 24 | - Aggregated reward (scalar). 25 | """ 26 | # Calculate the weighted sum of intermediate rewards, considering the discount factor 27 | if m is None: 28 | m = len(r_i) 29 | else: 30 | assert len(r_i) == m, "Number of intermediate rewards must match the given value of m." 
31 |     weighted_intermediate_rewards = np.sum([gamma**j * r_i[j] for j in range(m)])
32 | 
33 |     # Get the time-varying factor alpha(t)
34 |     alpha_t = alpha_t_func(t)
35 | 
36 |     # Compute the aggregated reward using the given formula
37 |     aggregated_reward = alpha_t * R_i + (1 - alpha_t) * (weighted_intermediate_rewards / m)
38 | 
39 |     return aggregated_reward
40 | 
41 | # Example: Define a linear decay function for alpha(t)
42 | def linear_alpha(t, alpha_max=1.0, alpha_min=0.1, decay_rate=0.01):
43 |     r"""
44 |     A simple linear decay function for the time-varying factor \(\alpha(t)\).
45 | 
46 |     Parameters:
47 |     - t: The current timestep (scalar).
48 |     - alpha_max: The maximum value for \(\alpha(t)\) at t = 0 (default 1.0).
49 |     - alpha_min: The minimum value for \(\alpha(t)\) as t increases (default 0.1).
50 |     - decay_rate: The rate of decay (default 0.01).
51 | 
52 |     Returns:
53 |     - The time-varying factor \(\alpha(t)\) (scalar).
54 |     """
55 |     return max(alpha_min, alpha_max - decay_rate * t)
56 | 
57 | 
58 | 
59 | class RewardAggregater():
60 |     def __init__(self, model, tokenizer, phi_func=phi, alpha_func=linear_alpha, gamma=0.9, device='cuda'):
61 | 
62 |         self.device = device
63 |         self.phi_func = phi_func
64 |         self.alpha_func = alpha_func
65 |         self.gamma = gamma
66 | 
67 |         self.model = model
68 |         self.tokenizer = tokenizer
69 | 
70 |         self.device = torch.device(self.device)
71 |         self.model.eval()  # Set model to evaluation mode
72 | 
73 |         self.reward = 0  # Initialize reward
74 | 
75 | 
76 |     def compute_intermediate_rewards(self, prompts: PromptType, intermediate_texts: PrefixesType):
77 |         process_rewards = get_process_rewards(
78 |             self.model,
79 |             self.tokenizer,
80 |             prompts=prompts,
81 |             completed_processes=intermediate_texts,
82 |             tokenized_format='chat_completion',
83 |         )
84 |         # return (probability of good, probability of bad)
85 |         process_rewards = [good_and_bad_probs[0] for good_and_bad_probs in process_rewards]
86 |         return process_rewards
87 | 
88 | 
89 |     def update_reward(self, prompt, intermediate_texts, in_outs, current_timestep, outcome_reward=None):
90 | 
91 |         # Compute intermediate rewards
92 |         intermediate_rewards = self.compute_intermediate_rewards(prompt, intermediate_texts)
93 | 
94 |         # Number of intermediate rewards
95 |         m = len(intermediate_rewards)
96 |         if outcome_reward is None:
97 |             outcome_reward = self.compute_outcome_reward(intermediate_texts[-1], in_outs)
98 |         # Calculate the aggregated reward
99 |         self.reward = self.phi_func(
100 |             R_i=outcome_reward,
101 |             r_i=intermediate_rewards,
102 |             t=current_timestep,
103 |             alpha_t_func=self.alpha_func,
104 |             gamma=self.gamma,
105 |             m=m
106 |         )
107 | 
108 |         return self.reward
109 | 
110 |     def extract_runnable_code(self, text):
111 |         pattern = r"```python(.*?)```"
112 | 
113 |         try:
114 |             code_snippets = re.findall(pattern, text, re.DOTALL)[-1]
115 |         except IndexError:  # no fenced Python block found
116 |             code_snippets = None
117 | 
118 |         return code_snippets
119 | 
120 | 
121 |     def compute_outcome_reward(self, final_step, in_outs):
122 |         code = self.extract_runnable_code(final_step)
123 |         if code is None:
124 |             return 0
125 |         result = run_test(code, in_outs, debug=True)
126 | 
127 |         score = [1 for r in result if r == True]
128 |         score = sum(score)/len(result)
129 |         return score
130 | 
131 |     def reset_reward(self):
132 |         """
133 |         Reset the aggregated reward to zero.
134 |         """
135 |         self.reward = 0
136 | 
137 | if __name__ == "__main__":
138 |     ...
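139 |     # Illustrative sanity check of the aggregation math (added example; no model
140 |     # is required, and the reward values below are arbitrary placeholders).
141 |     demo_reward = phi(R_i=1.0, r_i=[0.9, 0.8, 0.7], t=5, alpha_t_func=linear_alpha, gamma=0.9)
142 |     print(f"Aggregated reward for the toy example: {demo_reward:.4f}")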
--------------------------------------------------------------------------------
/src/RL/rewards_test.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | from rewards.rewards import RewardAggregater
5 | from rewards.prm_utils import CODEPRM_PROMPT
6 | from rewards.examples import rewards_test_examples
7 | 
8 | 
9 | def test():
10 |     model = AutoModelForCausalLM.from_pretrained(
11 |         'path/to/PRM',
12 |         torch_dtype=torch.bfloat16,
13 |         attn_implementation="flash_attention_2"
14 |     ).to('cuda')
15 |     tokenizer = AutoTokenizer.from_pretrained('path/to/PRM')
16 |     aggregator = RewardAggregater(
17 |         model=model,
18 |         tokenizer=tokenizer
19 |     )
20 | 
21 |     for test_example in rewards_test_examples:
22 |         question, reasoning_steps, test_cases = test_example['question'], test_example['reasoning_steps'], test_example['test_cases']
23 |         prompt = CODEPRM_PROMPT.format(question=question)
24 |         test_reward = aggregator.update_reward(
25 |             prompt,
26 |             reasoning_steps,
27 |             test_cases,
28 |             1,
29 |         )
30 |         print(test_reward)
31 | 
32 | if __name__ == '__main__':
33 |     test()
--------------------------------------------------------------------------------
/src/TestCaseGenerate/README.md:
--------------------------------------------------------------------------------
1 | ### Wash TCG Data
2 | 
3 | You can use `wash_code.py` to generate four parts of data, which only use stdin/stdout to pass the test cases.
4 | 
5 | ### Training the Test Case Generator Model
6 | 
7 | accelerate launch TCG.py
--------------------------------------------------------------------------------
/src/TestCaseGenerate/TCG.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import transformers
5 | from accelerate import PartialState
6 | from datasets import load_dataset, Dataset
7 | from peft import LoraConfig
8 | from transformers import (
9 |     AutoTokenizer,
10 |     AutoModelForCausalLM,
11 |     BitsAndBytesConfig,
12 |     logging,
13 |     set_seed,
14 | )
15 | from trl import DPOTrainer, SFTTrainer
16 | import numpy as np
17 | from transformers import AutoModelForCausalLM, AutoTokenizer
18 | import random
19 | from peft import PeftModel, LoraConfig
20 | import torch.distributed as dist
21 | import json
22 | from torch.utils.data import DataLoader, DistributedSampler
23 | import tqdm
24 | import re
25 | import io, sys
26 | import multiprocessing
27 | 
28 | 
29 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
30 | 
31 | def get_args():
32 |     parser = argparse.ArgumentParser()
33 |     parser.add_argument("--split", type=str, default="train")
34 |     parser.add_argument("--dataset_text_field", type=str, default="text")
35 | 
36 |     parser.add_argument("--max_seq_length", type=int, default=1024 * 4)
37 |     parser.add_argument("--max_steps", type=int, default=5000)
38 |     parser.add_argument("--micro_batch_size", type=int, default=1)
39 |     parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
40 |     parser.add_argument("--weight_decay", type=float, default=0.01)
41 |     parser.add_argument("--bf16", type=bool, default=True)
42 | 
43 |     parser.add_argument("--attention_dropout", type=float, default=0.1)
44 |     parser.add_argument("--learning_rate", type=float, default=5e-4)
45 |     parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
46 |     parser.add_argument("--warmup_steps", type=int, default=100)
47 |     parser.add_argument("--seed", type=int,
default=0) 48 | parser.add_argument("--output_dir", type=str, default="finetune_deepseek1.3_instruct_o1_format_SFT") 49 | parser.add_argument("--num_proc", type=int, default=None) 50 | 51 | parser.add_argument("--model_path", type=str, default="/data/FastSSD/LLM_Models/deepseek-coder-1.3b-instruct/") 52 | return parser.parse_args() 53 | 54 | def build_test_part(A, B): 55 | test_part = """```case 56 | # input: 57 | {} 58 | # output: 59 | {} 60 | ``` 61 | """.format(A.strip(), B.strip()) 62 | if len(test_part) > 100: 63 | raise ValueError 64 | return test_part 65 | 66 | def build_TACO_SFT(item): 67 | return_item_lst = [] 68 | for solve in eval(item['solutions']): 69 | in_out_case = eval(item['input_output']) 70 | test_case = [[x[0], x[1]] for x in zip(in_out_case['inputs'], in_out_case['outputs'])] 71 | test_case = random.choices(test_case, k=min(3, len(test_case))) 72 | test_part = "".join([build_test_part(item[0], item[1]) for item in test_case]) 73 | templt = '''### Instruction 74 | Please complete the task in the code part and generate some test case in the test part that can be used to test the quality of the generated code. 75 | ### Problem 76 | {} 77 | ### Code Part 78 | {} 79 | ```python 80 | {} 81 | ``` 82 | ### Test Part 83 | [Generate 3 test cases here to validate the code.] 84 | {} 85 | <|EOT|> 86 | '''.format(item['question'].strip(), ", ".join(eval(item['tags'])[:20]), solve.strip(), test_part.strip()) 87 | return_item_lst.append(templt) 88 | return return_item_lst 89 | 90 | def main(args): 91 | # config 92 | bnb_config = BitsAndBytesConfig( 93 | load_in_4bit=True, 94 | bnb_4bit_quant_type="nf4", 95 | bnb_4bit_compute_dtype=torch.bfloat16, 96 | ) 97 | lora_config = LoraConfig( 98 | r=1, 99 | target_modules=[ 100 | "q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", 101 | ], 102 | task_type="CAUSAL_LM", 103 | ) 104 | model = AutoModelForCausalLM.from_pretrained( 105 | args.model_path, 106 | quantization_config=bnb_config, 107 | attention_dropout=args.attention_dropout, 108 | device_map={"": PartialState().process_index}, 109 | torch_dtype=torch.bfloat16, 110 | attn_implementation="flash_attention_2", 111 | ) 112 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 113 | 114 | select_data = json.load(open("select_data_train_can_use.json")) 115 | SFT_dataset = [] 116 | length = [] 117 | too_long_count = 0 118 | for item in tqdm.tqdm(select_data): 119 | try: 120 | if eval(item['solutions']) != []: 121 | item_str_lst = build_TACO_SFT(item) 122 | if len(item_str_lst) >= 30: 123 | item_str_lst = random.choices(item_str_lst, k=30) 124 | for item_str in item_str_lst: 125 | if len(item_str) > 3000: 126 | too_long_count += 1 127 | continue 128 | length.append(len(item_str)) 129 | SFT_dataset.append(item_str) 130 | except: 131 | pass 132 | print("Avg Length Len:", np.mean(length), len(SFT_dataset), too_long_count) 133 | SFT_dataset = Dataset.from_dict({"text": SFT_dataset}) 134 | SFT_dataset = SFT_dataset.shuffle(seed=42) 135 | 136 | # setup the SFT trainer 137 | trainer = SFTTrainer( 138 | model=model, 139 | train_dataset=SFT_dataset, 140 | max_seq_length=args.max_seq_length, 141 | tokenizer=tokenizer, 142 | args=transformers.TrainingArguments( 143 | per_device_train_batch_size=args.micro_batch_size, 144 | gradient_accumulation_steps=args.gradient_accumulation_steps, 145 | warmup_steps=args.warmup_steps, 146 | max_steps=5000, 147 | learning_rate=args.learning_rate, 148 | lr_scheduler_type=args.lr_scheduler_type, 149 | weight_decay=args.weight_decay, 150 | 
bf16=args.bf16,
151 |             fp16=False,
152 |             logging_strategy="steps",
153 |             logging_steps=1,
154 |             output_dir=args.output_dir,
155 |             optim="paged_adamw_8bit",
156 |             seed=args.seed,
157 |         ),
158 |         peft_config=lora_config,
159 |         dataset_text_field="text",
160 |     )
161 |     # launch
162 |     print("Training SFT...")
163 |     trainer.train()
164 |     model.save_pretrained(os.path.join(args.output_dir, "SFT_final_checkpoint/"))
165 |     tokenizer.save_pretrained(os.path.join(args.output_dir, "SFT_final_checkpoint/"))
166 |     print("SFT Training Done!")
167 | 
168 | 
169 | # accelerate launch TCG.py
170 | if __name__ == "__main__":
171 |     args = get_args()
172 |     set_seed(args.seed)
173 |     logging.set_verbosity_error()
174 |     main(args)
175 | 
--------------------------------------------------------------------------------
/src/TestCaseGenerate/wash_code.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import json
3 | import requests
4 | import sys
5 | import multiprocessing
6 | import io
7 | import re
8 | import ray
9 | import ast
10 | import numpy as np
11 | import threading
12 | from datasets import load_dataset, Dataset
13 | 
14 | # Initialize Ray
15 | ray.init(ignore_reinit_error=True)
16 | 
17 | # Timeout handler
18 | def timeout_handler(process, result_container):
19 |     process.terminate()  # Forcefully terminate the process
20 |     result_container["error"] = "Timeout"
21 | 
22 | @ray.remote
23 | def execute_code_with_timeout(test_code, input_data, timeout=3):
24 |     """
25 |     Ray remote function that executes code within a time limit.
26 |     """
27 |     manager = multiprocessing.Manager()  # Manager from the Python standard library
28 |     result_container = manager.dict()
29 | 
30 |     def target_func():
31 |         try:
32 |             sys.stdin = io.StringIO(input_data)
33 |             sys.stdout = io.StringIO()
34 |             env_test = {}
35 |             exec(test_code, env_test)
36 |             result_container["output"] = sys.stdout.getvalue().strip()
37 |             result_container["error"] = None
38 |         except Exception as e:
39 |             result_container["error"] = str(e)
40 | 
41 |     process = multiprocessing.Process(target=target_func)  # Spawn a process to execute the code
42 |     process.start()
43 | 
44 |     timer = threading.Timer(timeout, timeout_handler, args=(process, result_container))  # Start the timeout timer
45 |     timer.start()
46 | 
47 |     process.join(timeout)  # Wait for the process to finish
48 |     timer.cancel()  # Cancel the timer (if the process has already finished)
49 | 
50 |     return result_container.get("output", None), result_container.get("error", None)
51 | 
52 | @ray.remote
53 | def worker(case, test_code, timeout):
54 |     """
55 |     Worker function that runs a single test case.
56 |     """
57 |     input_data = case['input'].strip()
58 |     expected_output = case['output'].strip()
59 |     actual_output, error = ray.get(execute_code_with_timeout.remote(test_code, input_data, timeout))
60 | 
61 |     if error:
62 |         return {
63 |             'input': input_data,
64 |             'expected_output': expected_output,
65 |             'actual_output': None,
66 |             'error': error,
67 |             'status': 'failed'
68 |         }
69 |     if actual_output == expected_output:
70 |         return {
71 |             'input': input_data,
72 |             'expected_output': expected_output,
73 |             'actual_output': actual_output,
74 |             'error': None,
75 |             'status': 'passed'
76 |         }
77 |     else:
78 |         return {
79 |             'input': input_data,
80 |             'expected_output': expected_output,
81 |             'actual_output': actual_output,
82 |             'error': None,
83 |             'status': 'failed'
84 |         }
85 | 
86 | def test_cases_with_limit(output_cases, test_right_code, timeout=3):
87 |     """
88 |     Parallelized test-case execution implemented with Ray.
89 |     """
90 |     futures = [worker.remote(case, test_right_code, timeout) for case in output_cases]
91 |     results = ray.get(futures)
92 | 
93 |     passed_case = [res for res in results if res['status'] == 'passed']
94 |     unpassed_case = [res for res in results if res['status'] == 'failed']
95 | 
96 |     return passed_case, unpassed_case
97 | 
98 | 
99 | if __name__ == "__main__":
100 |     can_use_case = []
101 |     can_use_case_all_pass = []
102 |     pass_rate = []
103 |     ds = load_dataset("/data/FastSSD/LLM_Models/TACO")
104 |     Pbar = tqdm.tqdm(ds['train'])
105 |     for item in Pbar:
106 |         if eval(item['solutions']) == []:
107 |             continue
108 |         try:
109 |             express_code = json.loads(item['solutions'])[0]
110 |             example_case = json.loads(item['input_output'])
111 |         except:
112 |             continue
113 | 
114 |         all_example_case = [{'input': str(x), 'output': str(y)} for x, y in zip(example_case['inputs'], example_case['outputs'])]
115 |         passed_case, unpassed_case = test_cases_with_limit(all_example_case, express_code, timeout=10)
116 | 
117 |         # Compute the pass rate
118 |         if (len(passed_case) + len(unpassed_case)) == 0:
119 |             current_pass_rate = 0
120 |         else:
121 |             current_pass_rate = len(passed_case) / (len(passed_case) + len(unpassed_case))
122 |         pass_rate.append(current_pass_rate)
123 |         if current_pass_rate == 1.0:
124 |             can_use_case_all_pass.append(item)
125 |         if current_pass_rate > 0.0:
126 |             can_use_case.append(item)
127 |         Pbar.set_description(f"Avg Pass: {np.mean(pass_rate)}, All Pass: {np.mean(np.array(pass_rate) == 1.)}, Current: {current_pass_rate}")
128 | 
129 |     json.dump(can_use_case, open("/home/xukaiyuan/Project/TreeSearch_Code/wash_code/select_data_train_can_use.json", 'w'))
130 |     json.dump(can_use_case_all_pass, open("/home/xukaiyuan/Project/TreeSearch_Code/wash_code/select_data_train_all_pass.json", 'w'))
131 | 
--------------------------------------------------------------------------------
/src/mcts/README.md:
--------------------------------------------------------------------------------
1 | ## How to Use
2 | 
3 | ### Data Preparation
4 | 
5 | You can find the TACO dataset on [Hugging Face](https://huggingface.co/datasets/BAAI/TACO). Convert the dataset into a JSON file and place it in the `data/TACO` directory.
6 | 
7 | ---
8 | 
9 | ### Generate Data
10 | 
11 | #### Local Model
12 | 
13 | Run the `run_TACO.sh` script with the appropriate configurations to use the local model for data generation:
14 | 
15 | ```bash
16 | bash run_TACO.sh
17 | ```
18 | 
19 | #### Main Arguments
20 | 
21 | | Argument | Type | Description |
22 | |-----------------------|--------|-----------------------------------------|
23 | | `--dataset_name` | str | Name of the dataset folder in the `data` directory |
24 | | `--test_json_filename`| str | Name of the JSON file containing the data |
25 | | `--model_ckpt` | str | Path to the model checkpoint |
26 | | `--num_rollouts` | int | Number of MCTS rollouts |
27 | | `--max_depth_allowed` | int | Maximum depth allowed for the MCTS search tree |
28 | 
29 | ---
30 | 
31 | #### OpenAI API
32 | 
33 | First, configure your `api_key` in the `models/OpenAI_API.py` file.
34 | 
35 | Then, run the `api_run_TACO.sh` script to use the API for data generation:
36 | 
37 | ```bash
38 | bash api_run_TACO.sh
39 | ```
40 | 
41 | #### Additional Arguments
42 | 
43 | | Argument | Type | Description |
44 | |-----------------------|--------|-----------------------------------------|
45 | | `--api` | str | Default is `vllm`, which calls the local model |
46 | | `--model_ckpt` | str | Specific OpenAI model name |
47 | 
48 | ---
49 | 
50 | ### Data Example
51 | 
52 | A sample dataset can be found in the `run_outputs` folder. Detailed information for each problem is available in the `answer_sheets` folder.
53 | 
54 | #### File Information
55 | 
56 | | Filename | Description |
57 | |-----------------------------------------|-----------------------------------------|
58 | | `Question XXXX - Answer.json` | Contains the original question information |
59 | | `Question XXXX - Best Solutions.json` | The path with the highest reward in the final step |
60 | | `Question XXXX - Complete Solutions.json` | All complete paths in the MCTS search tree |
61 | | `Question XXXX - Rollout Solutions.json` | Paths generated during each MCTS rollout |
62 | | `args.json` | Parameter configuration information |
63 | | `intermediate_result.txt` | Logs for model calls and intermediate results |
64 | 
65 | 
66 | ## Acknowledgments
67 | 
68 | This code is derived from the project available at [https://github.com/zhentingqi/rStar/](https://github.com/zhentingqi/rStar/) and has been modified for this repository.
69 | 
--------------------------------------------------------------------------------
/src/mcts/common/arguments.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 | 
3 | import os, json, torch, math
4 | from argparse import ArgumentParser
5 | from datetime import datetime
6 | 
7 | 
8 | def get_parser():
9 | parser = ArgumentParser()
10 | 
11 | parser.add_argument("--note", type=str, default="debug")
12 | 
13 | allowed_apis = ["together", "huggingface", "llama", "vllm", "debug", "OpenAI"]
14 | parser.add_argument(
15 | "--api", type=str, choices=allowed_apis, default="vllm", help=f"API to use: Choose from {allowed_apis}."
16 | )
17 | 
18 | parser.add_argument("--seed", type=int, default=42)
19 | parser.add_argument("--verbose", action="store_true")
20 | 
21 | #! WandB settings
22 | parser.add_argument("--wandb_mode", type=str, default="disabled", choices=["disabled", "online"])
23 | 
24 | #! LLM settings
25 | parser.add_argument("--model_ckpt", required=True)
26 | 
27 | parser.add_argument("--model_parallel", action="store_true")
28 | parser.add_argument("--half_precision", action="store_true")
29 | 
30 | parser.add_argument("--max_tokens", type=int, default=4096, help="max_tokens")
31 | parser.add_argument("--temperature", type=float, default=0.8, help="temperature")
32 | parser.add_argument("--top_k", type=int, default=40, help="top_k")
33 | parser.add_argument("--top_p", type=float, default=0.95, help="top_p")
34 | parser.add_argument("--num_beams", type=int, default=1, help="num_beams")
35 | 
36 | parser.add_argument("--max_num_worker", type=int, default=3, help="maximum number of workers for dataloader")
37 | parser.add_argument("--test_batch_size", type=int, default=1)  # batch_size
38 | parser.add_argument("--tensor_parallel_size", type=int, default=1)  # tensor_parallel_size
39 | 
40 | #! prompt settings
41 | parser.add_argument("--prompts_root", default="prompts")
42 | 
43 | #! dataset settings
44 | parser.add_argument("--data_root", default="data")
45 | allowed_dataset_names = ["TACO"]
46 | parser.add_argument(
47 | "--dataset_name",
48 | required=True,
49 | choices=allowed_dataset_names,
50 | help=f"Test dataset name: Choose from {allowed_dataset_names}.",
51 | )
52 | parser.add_argument("--test_json_filename", type=str, default="test_all")
53 | parser.add_argument("--start_idx", type=int, default=0, help="Start index of test questions (inclusive)")
54 | parser.add_argument("--end_idx", type=int, default=math.inf, help="End index of test questions (inclusive)")
55 | 
56 | #!
outputs settings 57 | parser.add_argument("--run_outputs_root", type=str, default="run_outputs") 58 | parser.add_argument("--eval_outputs_root", type=str, default="eval_outputs") 59 | parser.add_argument("--run_outputs_dir", type=str, default="") 60 | 61 | return parser 62 | 63 | 64 | def post_process_args(args): 65 | # Set up logging 66 | suffix = "---[" + args.note + "]" if args.note is not None else "" 67 | model_name = args.model_ckpt.split("/")[-1] 68 | if args.run_outputs_dir == "": 69 | args.run_outputs_dir = os.path.join( 70 | args.run_outputs_root, 71 | args.dataset_name, 72 | model_name, 73 | f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + suffix, 74 | ) 75 | else: 76 | args.run_outputs_dir = os.path.join( 77 | args.run_outputs_root, 78 | args.dataset_name, 79 | model_name, 80 | f"{args.run_outputs_dir}---{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + suffix, 81 | ) 82 | os.makedirs(args.run_outputs_dir, exist_ok=True) 83 | 84 | 85 | args.answer_sheets_dir = os.path.join(args.run_outputs_dir, "answer_sheets") 86 | os.makedirs(args.answer_sheets_dir, exist_ok=True) 87 | 88 | # Check GPU 89 | num_gpus = torch.cuda.device_count() 90 | cuda_devices = [torch.cuda.get_device_name(i) for i in range(num_gpus)] 91 | assert len(cuda_devices) > 0, "No GPU available." 92 | args.cuda_0 = cuda_devices[0] 93 | args.cuda_1 = cuda_devices[1] if len(cuda_devices) > 1 else None 94 | args.cuda_2 = cuda_devices[2] if len(cuda_devices) > 2 else None 95 | args.cuda_3 = cuda_devices[3] if len(cuda_devices) > 3 else None 96 | 97 | return args 98 | 99 | 100 | def save_args(args): 101 | # Save args as json 102 | with open(os.path.join(args.run_outputs_dir, "args.json"), "w") as f: 103 | json.dump(vars(args), f, indent=4) 104 | -------------------------------------------------------------------------------- /src/mcts/common/utils.py: -------------------------------------------------------------------------------- 1 | # Licensed under the MIT license. 
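# Shared helpers for the MCTS pipeline: deterministic seeding (fix_seeds),
# optional model-parallel initialization (setup_model_parallel), and small
# JSON/TXT read-write wrappers.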
2 | 3 | import json 4 | import re 5 | import os 6 | import random 7 | import numpy as np 8 | import torch 9 | import multiprocessing 10 | from typing import Tuple 11 | from statistics import mean 12 | from torch.utils.data import Dataset 13 | 14 | 15 | def fix_seeds(seed): 16 | # random 17 | random.seed(seed) 18 | # Numpy 19 | np.random.seed(seed) 20 | # Pytorch 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | torch.backends.cudnn.deterministic = True 24 | torch.backends.cudnn.benchmark = False 25 | 26 | 27 | def setup_model_parallel() -> Tuple[int, int]: 28 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel 29 | 30 | local_rank = int(os.environ.get("LOCAL_RANK", -1)) 31 | world_size = int(os.environ.get("WORLD_SIZE", -1)) 32 | 33 | torch.distributed.init_process_group("nccl") 34 | initialize_model_parallel(world_size) 35 | torch.cuda.set_device(local_rank) 36 | 37 | return local_rank, world_size 38 | 39 | 40 | def read_json(file_path): 41 | assert str(file_path).endswith(".json") 42 | with open(file_path, "r", encoding="utf-8") as f: 43 | data = json.load(f) 44 | return data 45 | 46 | 47 | def save_json(js_obj, file_path): 48 | assert str(file_path).endswith(".json") 49 | with open(file_path, "w", encoding="utf-8") as f: 50 | json.dump(js_obj, f, indent=4) 51 | 52 | 53 | def read_txt(file_path): 54 | assert str(file_path).endswith(".txt") 55 | with open(file_path, "r", encoding="utf-8") as f: 56 | data = f.read() 57 | return data 58 | 59 | -------------------------------------------------------------------------------- /src/mcts/data/TACO/self_create_data.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "question": "In a galaxy far, far away, the Empire has launched a series of attacks using their powerful weapon, the Death Star. To counter this threat, the Rebel Alliance has decided to arm themselves with the legendary Lightsabers of the Jedi. Each Lightsaber has a certain power, and the sum of the powers of the Lightsabers will determine the strength of the Alliance's defense. You, as a strategist, must calculate the sum of the powers of the Lightsabers provided by your Jedi allies to ensure the best possible defense against the Empire's onslaught.\n\nInput:\nThe first line contains an integer T, representing the number of sets of Lightsabers. Each set is described in the following lines. The first number in each set is an integer n, representing the number of Lightsabers in that set, followed by n space-separated integers, each representing the power of a Lightsaber.\n\nOutput:\nFor each set of Lightsabers, calculate and print the sum of their powers. 
If there is no Lightsaber in a set, print 0 for that set.\n\nConstraints:\n1 ≤ T ≤ 10^6,\n0 ≤ n ≤ 10^6,\n0 ≤ power of each Lightsaber ≤ 10^9.\n\nSAMPLE INPUT\n2\n3 1 2 3\n4 10 20 30 40\n\nSAMPLE OUTPUT\n6\n100", 4 | "solutions": [], 5 | "starter_code": "", 6 | "input_output": { 7 | "inputs": [ 8 | "1\n2\n1 2", 9 | "2\n3\n1 2 3\n2\n4 5", 10 | "1\n4\n0 0 0 0", 11 | "3\n2\n10 20\n2\n30 40\n2\n50 60", 12 | "1\n5\n7 8 9 10 11", 13 | "2\n1\n100\n3\n1 1 1", 14 | "1\n2\n-1 1", 15 | "1\n10\n1 2 3 4 5 6 7 8 9 10", 16 | "2\n4\n1 1 1 1\n4\n2 2 2 2", 17 | "1\n3\n0 0 0", 18 | "2\n5\n11 22 33 44 55\n5\n66 77 88 99 100", 19 | "1\n1\n1000000000", 20 | "2\n2\n-5 5\n3\n-10 0 10" 21 | ], 22 | "outputs": [ 23 | "3", 24 | "6\n9", 25 | "0", 26 | "30\n70\n110", 27 | "45", 28 | "100\n3", 29 | "0", 30 | "55", 31 | "4\n8", 32 | "0", 33 | "165\n430", 34 | "1000000000", 35 | "0\n0" 36 | ] 37 | }, 38 | "difficulty": "EASY", 39 | "raw_tags": "['Algorithms', 'ad-hoc', 'Constructive']", 40 | "name": null, 41 | "source": "codechef", 42 | "tags": "['Constructive algorithms', 'Ad-hoc']", 43 | "skill_types": [], 44 | "url": "https://www.codechef.com/problems/RRJOKE", 45 | "Expected Auxiliary Space": null, 46 | "time_limit": "1 seconds", 47 | "date": "2014-12-18", 48 | "picture_num": 0, 49 | "memory_limit": "50000 bytes", 50 | "Expected Time Complexity": null 51 | } 52 | ] -------------------------------------------------------------------------------- /src/mcts/data/TACO/test_one.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "question": "**Step 1:** Create a function called `encode()` to replace all the lowercase vowels in a given string with numbers according to the following pattern:\n```\na -> 1\ne -> 2\ni -> 3\no -> 4\nu -> 5\n```\n\nFor example, `encode(\"hello\")` would return `\"h2ll4\"`. 
There is no need to worry about uppercase vowels in this kata.\n\n**Step 2:** Now create a function called `decode()` to turn the numbers back into vowels according to the same pattern shown above.\n\nFor example, `decode(\"h3 th2r2\")` would return `\"hi there\"`.\n\nFor the sake of simplicity, you can assume that any numbers passed into the function will correspond to vowels.", 4 | "solutions": [ 5 | "def encode(s, t=str.maketrans('aeiou', '12345')):\n\treturn s.translate(t)\n\ndef decode(s, t=str.maketrans('12345', 'aeiou')):\n\treturn s.translate(t)\n", 6 | "CIPHER = ('aeiou', '12345')\n\ndef encode(st):\n\treturn st.translate(str.maketrans(CIPHER[0], CIPHER[1]))\n\ndef decode(st):\n\treturn st.translate(str.maketrans(CIPHER[1], CIPHER[0]))\n", 7 | "def encode(st):\n\tfor (i, v) in enumerate('aeiou', start=1):\n\t\tst = st.replace(v, str(i))\n\treturn st\n\ndef decode(st):\n\tfor (i, v) in enumerate('aeiou', start=1):\n\t\tst = st.replace(str(i), v)\n\treturn st\n", 8 | "tbl1 = str.maketrans('aeiou', '12345')\ntbl2 = str.maketrans('12345', 'aeiou')\n\ndef encode(st):\n\treturn st.translate(tbl1)\n\ndef decode(st):\n\treturn st.translate(tbl2)\n", 9 | "a = {'a': '1', 'e': '2', 'i': '3', 'o': '4', 'u': '5'}\nb = ('a', 'e', 'i', 'o', 'u')\n\ndef encode(st):\n\treturn ''.join((a[c] if c in a else c for c in st))\n\ndef decode(st):\n\treturn ''.join((b[int(c) - 1] if c.isdigit() else c for c in st))\n", 10 | "def cipher(mode):\n\ttable = str.maketrans(*['aeiou', '12345'][::mode])\n\treturn lambda s: s.translate(table)\n(encode, decode) = (cipher(1), cipher(-1))\n", 11 | "CYPHER = tuple(zip('aeiou', '12345'))\n\ndef munge(st, mapping):\n\treturn ''.join([mapping.get(c, c) for c in st])\n\ndef encode(st):\n\treturn munge(st, {a: b for (a, b) in CYPHER})\n\ndef decode(st):\n\treturn munge(st, {b: a for (a, b) in CYPHER})\n", 12 | "def encode(st):\n\tL = []\n\tA = {'a': '1', 'e': '2', 'i': '3', 'o': '4', 'u': '5'}\n\tfor i in st:\n\t\tif i in A:\n\t\t\tL.append(A[i])\n\t\telse:\n\t\t\tL.append(i)\n\treturn ''.join(L)\n\ndef decode(st):\n\tL = []\n\tA = {'1': 'a', '2': 'e', '3': 'i', '4': 'o', '5': 'u'}\n\tfor i in st:\n\t\tif i in A:\n\t\t\tL.append(A[i])\n\t\telse:\n\t\t\tL.append(i)\n\treturn ''.join(L)\n", 13 | "import re\n\ndef encode(st):\n\tvowel = ' aeiou'\n\treturn re.sub('[aeoui]', lambda x: str(vowel.index(x.group(0))), st)\n\ndef decode(st):\n\tvowel = ' aeiou'\n\treturn re.sub('[1-5]', lambda x: vowel[int(x.group(0))], st)\n", 14 | "a = ['a', 'e', 'i', 'o', 'u']\n\ndef encode(st):\n\treturn ''.join([str(a.index(c) + 1) if c in a else c for c in st])\n\ndef decode(st):\n\treturn ''.join([a[int(c) - 1] if c.isdigit() else c for c in st])\n" 15 | ], 16 | "starter_code": "def encode(st):\n\t", 17 | "input_output": { 18 | "fn_name": "encode", 19 | "inputs": [ 20 | [ 21 | "hello" 22 | ], 23 | [ 24 | "How are you today?" 25 | ], 26 | [ 27 | "This is an encoding test." 28 | ] 29 | ], 30 | "outputs": [ 31 | [ 32 | "h2ll4" 33 | ], 34 | [ 35 | "H4w 1r2 y45 t4d1y?" 36 | ], 37 | [ 38 | "Th3s 3s 1n 2nc4d3ng t2st." 
39 | ]
40 | ]
41 | },
42 | "difficulty": "EASY",
43 | "raw_tags": "['Regular Expressions', 'Strings', 'Fundamentals', 'Arrays']",
44 | "name": null,
45 | "source": "codewars",
46 | "tags": "['String algorithms', 'Fundamentals', 'Data structures']",
47 | "skill_types": "['Data structures']",
48 | "url": "https://www.codewars.com/kata/53697be005f803751e0015aa",
49 | "Expected Auxiliary Space": null,
50 | "time_limit": null,
51 | "date": null,
52 | "picture_num": null,
53 | "memory_limit": null,
54 | "Expected Time Complexity": null
55 | }
56 | ]
--------------------------------------------------------------------------------
/src/mcts/eval_src/Evaluator.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 | 
3 | from eval_src.checker_utils import CodeSolutionParser, check_generation_correctness
4 | 
5 | import os, json, re
6 | from typing import List, Dict, Tuple
7 | from collections import defaultdict
8 | import random
9 | import copy
10 | from fuzzywuzzy import fuzz, process
11 | 
12 | from multiprocessing import Manager, Process
13 | import concurrent.futures
14 | 
15 | 
16 | class Evaluator:
17 | def __init__(self) -> None:
18 | self.answer_marker = "answer is"
19 | self.parser = CodeSolutionParser()
20 | 
21 | 
22 | def find_TACO_code(self, completions: List[str], test_case: dict, solution_trace: Dict[int, Dict[str, str]],):
23 | if completions is None or len(completions) == 0:
24 | return None, None, None, None
25 | solution_trace_ = copy.deepcopy(solution_trace)
26 | id2pass_completions = defaultdict(list)
27 | pass_ratio = 0
28 | compile_pass = False
29 | 
30 | 
31 | for id, c in enumerate(completions):
32 | result = self.parser.process_solution(c)
33 | 
34 | generation_code = result["final_code"]
35 | 
36 | # Keep the expected entry point in sync with the name of the parsed
37 | # main function, if the completion defines one.
38 | main_function = result.get("main_function")
39 | if "fn_name" in test_case and main_function and "name" in main_function:
40 | if test_case["fn_name"] != main_function["name"]:
41 | test_case["fn_name"] = main_function["name"]
42 | 
43 | 
44 | 
45 | if generation_code is None:
46 | pass_ratio = 0
47 | continue
48 | 
49 | 
50 | correctness_results = check_generation_correctness(test_case, generation_code, debug=False, n_cases=10)
51 | # print(correctness_results)
52 | 
53 | if isinstance(correctness_results, list):
54 | if True in correctness_results or False in correctness_results:
55 | compile_pass = True
56 | pass_case_count = correctness_results.count(True)
57 | 
58 | # compute the fraction of test cases that pass
59 | pass_ratio = pass_case_count / len(correctness_results)
60 | else:
61 | pass_ratio = 0
62 | 
63 | alpha = 0  # weight of a flat compile-success bonus; 0 keeps the raw pass ratio
64 | if compile_pass:
65 | pass_ratio = alpha * 1 + (1 - alpha) * pass_ratio
66 | 
67 | # print(f"*********** {id} : score : {pass_ratio} *********")
68 | 
69 | 
70 | 
71 | 
72 | return "", completions[0], pass_ratio, solution_trace_
73 | 
74 | 
75 | 
76 | class TACOEvaluator(Evaluator):
77 | def __init__(self) -> None:
78 | super().__init__()
79 | 
80 | def passed(self, references):
81 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
82 | results = []
83 | 
84 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
85 | future_results = {executor.submit(self.run_code_with_timeout, references)}
86 | for future in concurrent.futures.as_completed(future_results):
87 | results.append(future.result())
88 | # print(results)
89 | return results[0] == 'passed'
90 | 
91 | def run_code_with_timeout(self, code_string, timeout=1):
92 | with Manager() as manager:
93 | result_dict = manager.dict()
94 | 
process = Process(target=self.exec_code, args=(code_string, result_dict)) 95 | process.start() 96 | process.join(timeout=timeout) 97 | if process.is_alive(): 98 | process.kill() 99 | return "timeout" 100 | else: 101 | return result_dict['result'] 102 | 103 | @staticmethod 104 | def exec_code(code, result_dict): 105 | result_dict['result'] = 'Not executed' 106 | try: 107 | exec_globals = {} 108 | exec(code, exec_globals) 109 | result_dict['result'] = 'passed' 110 | except Exception as e: 111 | 112 | result_dict['result'] = f'Error: {str(e)}' 113 | 114 | def extract_answer_from_gold_solution(self, solution: str): 115 | return None 116 | 117 | def extract_answer_from_model_completion(self, completion: str): 118 | if completion is None: 119 | return None 120 | 121 | assert isinstance(completion, str) 122 | 123 | preds = completion.replace('\\n', '\n') 124 | code_maker = "The code is: \[Code Start\]\s*(.*?)\s*\[Code End\]" 125 | code = re.search(code_maker, preds, re.DOTALL) 126 | 127 | if code: 128 | result = code.group(1) 129 | return str(result.replace('\\r', '').replace('\\n', '\n').replace('\\t', '\t')) 130 | else: 131 | 132 | return None -------------------------------------------------------------------------------- /src/mcts/eval_src/checker_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import ast 3 | from typing import Optional, Dict 4 | 5 | class CodeSolutionParser: 6 | def __init__(self): 7 | self.steps = [] 8 | self.final_code = None 9 | self.main_function = None 10 | 11 | def check_final_step(self, text: str) -> bool: 12 | """Check if the last step is code generation.""" 13 | if text == "": 14 | return False 15 | 16 | last_step = text.lower() 17 | # Check if the last step mentions code generation 18 | code_indicators = [ 19 | "```python" 20 | ] 21 | 22 | return any(indicator in last_step for indicator in code_indicators) 23 | 24 | def extract_code(self, text: str) -> str: 25 | """Extract the Python code from the last step.""" 26 | if text == "": 27 | return None 28 | 29 | last_step = text 30 | 31 | # Find code between triple backticks 32 | code_pattern = r'```python(.*?)```' 33 | code_match = re.search(code_pattern, last_step, re.DOTALL) 34 | 35 | if code_match: 36 | code = code_match.group(1).strip() 37 | self.final_code = code 38 | return code 39 | return None 40 | 41 | def extract_outermost_function(self) -> Optional[Dict]: 42 | """Extract the outermost function from the code, including class methods.""" 43 | if not self.final_code: 44 | return None 45 | 46 | try: 47 | # Parse the code into an AST 48 | tree = ast.parse(self.final_code) 49 | 50 | # First try to find module-level function 51 | for node in ast.iter_child_nodes(tree): 52 | if isinstance(node, ast.FunctionDef): 53 | return self._extract_function_info(node) 54 | 55 | # If no module-level function found, look for class methods 56 | for node in ast.iter_child_nodes(tree): 57 | if isinstance(node, ast.ClassDef): 58 | # Look for the first method in the class 59 | for class_node in node.body: 60 | if isinstance(class_node, ast.FunctionDef): 61 | # Skip __init__ and other special methods 62 | if not class_node.name.startswith('__'): 63 | function_info = self._extract_function_info(class_node) 64 | function_info['class_name'] = node.name 65 | return function_info 66 | 67 | except SyntaxError: 68 | return None 69 | 70 | return None 71 | 72 | def _extract_function_info(self, node: ast.FunctionDef) -> Dict: 73 | """Helper method to extract information from a 
function node.""" 74 | function_info = { 75 | 'name': node.name, 76 | 'args': [arg.arg for arg in node.args.args], 77 | 'body': ast.unparse(node) 78 | } 79 | 80 | # Add return type annotation if exists 81 | if node.returns: 82 | function_info['return_type'] = ast.unparse(node.returns) 83 | 84 | # Add argument type annotations if exist 85 | arg_types = {} 86 | for arg in node.args.args: 87 | if arg.annotation: 88 | arg_types[arg.arg] = ast.unparse(arg.annotation) 89 | if arg_types: 90 | function_info['arg_types'] = arg_types 91 | 92 | # Add docstring if exists 93 | docstring = ast.get_docstring(node) 94 | if docstring: 95 | function_info['docstring'] = docstring 96 | 97 | return function_info 98 | 99 | def process_solution(self, text: str) -> dict: 100 | """Process the entire solution text and return results.""" 101 | has_code_generation = self.check_final_step(text) 102 | code = self.extract_code(text) if has_code_generation else None 103 | 104 | # Extract the outermost function if code exists 105 | main_function = None 106 | if code: 107 | main_function = self.extract_outermost_function() 108 | 109 | return { 110 | 'has_code_generation': has_code_generation, 111 | 'final_code': code, 112 | 'main_function': main_function 113 | } 114 | 115 | import json 116 | import multiprocessing as mp 117 | import concurrent 118 | import numpy as np 119 | from typing import List, Dict, Any, Union 120 | from eval_src.testing_util import run_test 121 | 122 | TIMEOUT = 10 123 | 124 | def check_generation_correctness( 125 | test_cases: Dict[str, Union[str, List]], 126 | generation: str, 127 | timeout: int = TIMEOUT, 128 | debug: bool = False, 129 | n_cases: Optional[int] = None, 130 | ) -> List[bool]: 131 | """ 132 | Args: 133 | test_cases (Dict[str, Union[str, List]]): A dictionary containing test cases with inputs and expected outputs. 134 | generation (str): The generated code to be tested. 135 | timeout (int, optional): The maximum time allowed for the test execution. Defaults to TIMEOUT. 136 | debug (bool, optional): If True, prints debug information. Defaults to False. 137 | Returns: 138 | List[bool]: A list of booleans indicating the correctness of each test case. If a timeout occurs, returns a list of -1s. 139 | """ 140 | 141 | try: 142 | return run_test(test_cases, generation, debug, n_cases) 143 | except Exception as e: 144 | if debug: 145 | print(f"Error in running test cases: {e}") 146 | in_outs = test_cases 147 | return [-2] * len(in_outs["inputs"]) 148 | 149 | -------------------------------------------------------------------------------- /src/mcts/eval_src/pyext2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2014 Ryan Gonzalez 3 | 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to use, 8 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 9 | Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | ''' 22 | 23 | g_backup = globals().copy() 24 | 25 | __version__ = '0.7' 26 | 27 | __all__ = ['overload', 'RuntimeModule', 'switch', 'tail_recurse', 'copyfunc', 'set_docstring', 'annotate', 'safe_unpack', 'modify_function', 'assign', 'fannotate', 'compare_and_swap', 'is_main', 'call_if_main', 'run_main'] 28 | 29 | import sys, inspect, types 30 | 31 | def __targspec(func, specs, attr='__orig_arg__'): 32 | if hasattr(func, '__is_overload__') and func.__is_overload__: 33 | return getattr(func, attr) 34 | return specs(func) 35 | 36 | def set_docstring(doc): 37 | '''A simple decorator to set docstrings. 38 | 39 | :param doc: The docstring to tie to the function. 40 | 41 | Example:: 42 | 43 | @set_docstring('This is a docstring') 44 | def myfunc(x): 45 | pass''' 46 | def _wrap(f): 47 | f.__doc__ = doc 48 | return f 49 | return _wrap 50 | 51 | __modify_function_doc = ''' 52 | Creates a copy of a function, changing its attributes. 53 | 54 | :param globals: Will be added to the function's globals. 55 | 56 | :param name: The new function name. Set to ``None`` to use the function's original name. 57 | 58 | :param code: The new function code object. Set to ``None`` to use the function's original code object. 59 | 60 | :param defaults: The new function defaults. Set to ``None`` to use the function's original defaults. 61 | 62 | :param closure: The new function closure. Set to ``None`` to use the function's original closure. 63 | 64 | .. warning:: This function can be potentially dangerous. 65 | ''' 66 | 67 | def copyfunc(f): 68 | '''Copies a funcion. 69 | 70 | :param f: The function to copy. 71 | 72 | :return: The copied function. 73 | 74 | .. deprecated:: 0.4 75 | Use :func:`modify_function` instead. 
76 | ''' 77 | return modify_function(f) 78 | 79 | if sys.version_info.major == 3: 80 | @set_docstring(__modify_function_doc) 81 | def modify_function(f, globals={}, name=None, code=None, defaults=None, 82 | closure=None): 83 | if code is None: code = f.__code__ 84 | if name is None: name = f.__name__ 85 | if defaults is None: defaults = f.__defaults__ 86 | if closure is None: closure = f.__closure__ 87 | newf = types.FunctionType(code, dict(f.__globals__, **globals), name=name, 88 | argdefs=defaults, closure=closure) 89 | newf.__dict__.update(f.__dict__) 90 | return newf 91 | def argspec(f): 92 | return inspect.getfullargspec(f) 93 | ofullargspec = inspect.getfullargspec 94 | def _fullargspec(func): 95 | return __targspec(func, ofullargspec) 96 | inspect.getfullargspec = _fullargspec 97 | def _exec(m,g): exec(m,g) 98 | else: 99 | @set_docstring(__modify_function_doc) 100 | def modify_function(f, globals={}, name=None, code=None, defaults=None, 101 | closure=None): 102 | if code is None: code = f.func_code 103 | if name is None: name = f.__name__ 104 | if defaults is None: defaults = f.func_defaults 105 | if closure is None: closure = f.func_closure 106 | newf = types.FunctionType(code, dict(f.func_globals, **globals), name=name, 107 | argdefs=defaults, closure=closure) 108 | newf.__dict__.update(f.__dict__) 109 | return newf 110 | def argspec(f): 111 | return inspect.getargspec(f) 112 | eval(compile('def _exec(m,g): exec m in g', '', 'exec')) 113 | 114 | def _gettypes(args): 115 | return tuple(map(type, args)) 116 | 117 | oargspec = inspect.getargspec 118 | 119 | def _argspec(func): 120 | return __targspec(func, oargspec) 121 | 122 | inspect.getargspec = _argspec 123 | 124 | try: 125 | import IPython 126 | except ImportError: 127 | IPython = None 128 | else: 129 | # Replace IPython's argspec 130 | oipyargspec = IPython.core.oinspect.getargspec 131 | def _ipyargspec(func): 132 | return __targspec(func, oipyargspec, '__orig_arg_ipy__') 133 | IPython.core.oinspect.getargspec = _ipyargspec 134 | 135 | class overload(object): 136 | '''Simple function overloading in Python.''' 137 | _items = {} 138 | _types = {} 139 | @classmethod 140 | def argc(self, argc=None): 141 | '''Overloads a function based on the specified argument count. 142 | 143 | :param argc: The argument count. Defaults to ``None``. If ``None`` is given, automatically compute the argument count from the given function. 144 | 145 | .. note:: 146 | 147 | Keyword argument counts are NOT checked! In addition, when the argument count is automatically calculated, the keyword argument count is also ignored! 
148 | 149 | Example:: 150 | 151 | @overload.argc() 152 | def func(a): 153 | print 'Function 1 called' 154 | 155 | @overload.argc() 156 | def func(a, b): 157 | print 'Function 2 called' 158 | 159 | func(1) # Calls first function 160 | func(1, 2) # Calls second function 161 | func() # Raises error 162 | ''' 163 | # Python 2 UnboundLocalError fix 164 | argc = {'argc': argc} 165 | def _wrap(f): 166 | def _newf(*args, **kwargs): 167 | if len(args) not in self._items[f.__name__]: 168 | raise TypeError("No overload of function '%s' that takes %d args" % (f.__name__, len(args))) 169 | return self._items[f.__name__][len(args)](*args, **kwargs) 170 | if f.__name__ not in self._items: 171 | self._items[f.__name__] = {} 172 | if argc['argc'] is None: 173 | argc['argc'] = len(argspec(f).args) 174 | self._items[f.__name__][argc['argc']] = f 175 | _newf.__name__ = f.__name__ 176 | _newf.__doc__ = f.__doc__ 177 | _newf.__is_overload__ = True 178 | _newf.__orig_arg__ = argspec(f) 179 | if IPython: 180 | _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f) 181 | return _newf 182 | return _wrap 183 | @classmethod 184 | def args(self, *argtypes, **kw): 185 | '''Overload a function based on the specified argument types. 186 | 187 | :param argtypes: The argument types. If None is given, get the argument types from the function annotations(Python 3 only) 188 | :param kw: Can only contain 1 argument, `is_cls`. If True, the function is assumed to be part of a class. 189 | 190 | Example:: 191 | 192 | @overload.args(str) 193 | def func(s): 194 | print 'Got string' 195 | 196 | @overload.args(int, str) 197 | def func(i, s): 198 | print 'Got int and string' 199 | 200 | @overload.args() 201 | def func(i:int): # A function annotation example 202 | print 'Got int' 203 | 204 | func('s') 205 | func(1) 206 | func(1, 's') 207 | func(True) # Raises error 208 | ''' 209 | 210 | # Python 2 UnboundLocalError fix...again! 211 | argtypes = {'args': tuple(argtypes)} 212 | def _wrap(f): 213 | def _newf(*args): 214 | if len(kw) == 0: 215 | cargs = args 216 | elif len(kw) == 1 and 'is_cls' in kw and kw['is_cls']: 217 | cargs = args[1:] 218 | else: 219 | raise ValueError('Invalid keyword args specified') 220 | if _gettypes(cargs) not in self._types[f.__name__]: 221 | raise TypeError("No overload of function '%s' that takes '%s' types and %d arg(s)" % (f.__name__, _gettypes(cargs), len(cargs))) 222 | return self._types[f.__name__][_gettypes(cargs)](*args) 223 | if f.__name__ not in self._types: 224 | self._types[f.__name__] = {} 225 | if len(argtypes['args']) == 1 and argtypes['args'][0] is None: 226 | aspec = argspec(f) 227 | argtypes['args'] = tuple(map(lambda x: x[1], sorted( 228 | aspec.annotations.items(), key=lambda x: aspec.args.index(x[0])))) 229 | self._types[f.__name__][argtypes['args']] = f 230 | _newf.__name__ = f.__name__ 231 | _newf.__doc__ = f.__doc__ 232 | _newf.__is_overload__ = True 233 | _newf.__orig_arg__ = argspec(f) 234 | if IPython: 235 | _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f) 236 | return _newf 237 | return _wrap 238 | 239 | class _RuntimeModule(object): 240 | 'Create a module object at runtime and insert it into sys.path. If called, same as :py:func:`from_objects`.' 
241 | def __call__(self, *args, **kwargs): 242 | return self.from_objects(*args, **kwargs) 243 | @staticmethod 244 | @overload.argc(1) 245 | def from_objects(module_name_for_code_eval, **d): 246 | return _RuntimeModule.from_objects(module_name_for_code_eval, '', **d) 247 | @staticmethod 248 | @overload.argc(2) 249 | def from_objects(module_name_for_code_eval, docstring, **d): 250 | '''Create a module at runtime from `d`. 251 | 252 | :param name: The module name. 253 | 254 | :param docstring: Optional. The module's docstring. 255 | 256 | :param \*\*d: All the keyword args, mapped from name->value. 257 | 258 | Example: ``RuntimeModule.from_objects('name', 'doc', a=1, b=2)``''' 259 | module = types.ModuleType(module_name_for_code_eval, docstring) 260 | module.__dict__.update(d) 261 | module.__file__ = '' 262 | sys.modules[module_name_for_code_eval] = module 263 | return module 264 | @staticmethod 265 | @overload.argc(2) 266 | def from_string(module_name_for_code_eval, s): 267 | return _RuntimeModule.from_string(module_name_for_code_eval, '', s) 268 | @staticmethod 269 | @overload.argc(3) 270 | def from_string(module_name_for_code_eval, docstring, s): 271 | '''Create a module at runtime from `s``. 272 | 273 | :param name: The module name. 274 | 275 | :param docstring: Optional. The module docstring. 276 | 277 | :param s: A string containing the module definition.''' 278 | g = {} 279 | _exec(s, g) 280 | return _RuntimeModule.from_objects(module_name_for_code_eval, docstring, **dict(filter(lambda x: x[0] not in g_backup, g.items()))) 281 | 282 | RuntimeModule = _RuntimeModule() 283 | 284 | class CaseObject(object): 285 | 'The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. It can be called with multiple parameters, in which case it checks if its value equals any of the arguments.' 286 | def __init__(self, value): 287 | self.value = value 288 | self.did_match = False 289 | self.did_pass = False 290 | def __call__(self, *args): 291 | if assign('res', not self.did_pass and any([self.value == rhs for rhs in args])): 292 | self.did_match = True 293 | return res 294 | def quit(self): 295 | 'Forces all other calls to return False. Equilavent of a ``break`` statement.' 296 | self.did_pass = True 297 | def default(self): 298 | "Executed if quit wasn't called." 299 | return not self.did_match and not self.did_pass 300 | def __iter__(self): 301 | yield self 302 | def __enter__(self): 303 | return self 304 | def __exit__(self, *args): 305 | pass 306 | 307 | def switch(value): 308 | '''A Python switch statement implementation that is used with a ``with`` statement. 309 | 310 | :param value: The value to "switch". 311 | 312 | ``with`` statement example:: 313 | 314 | with switch('x'): 315 | if case(1): print 'Huh?' 316 | if case('x'): print 'It works!!!' 317 | 318 | .. warning:: If you modify a variable named "case" in the same scope that you use the ``with`` statement version, you will get an UnboundLocalError. The soluction is to use ``with switch('x') as case:`` instead of ``with switch('x'):``.''' 319 | res = CaseObject(value) 320 | inspect.stack()[1][0].f_globals['case'] = res 321 | return res 322 | 323 | def tail_recurse(spec=None): 324 | '''Remove tail recursion from a function. 325 | 326 | :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. If ``None,`` tail recursion is always called unless the function returns a value. 327 | 328 | .. 
note:: 329 | 330 | This function has a slight overhead that is noticable when using timeit. Only use it if the function has a possibility of going over the recursion limit. 331 | 332 | .. warning:: 333 | 334 | This function will BREAK any code that either uses any recursion other than tail recursion or calls itself multiple times. For example, ``def x(): return x()+1`` will fail. 335 | 336 | Example:: 337 | 338 | @tail_recurse() 339 | def add(a, b): 340 | if a == 0: return b 341 | return add(a-1, b+1) 342 | 343 | add(10000000, 1) # Doesn't max the recursion limit. 344 | ''' 345 | def _wrap(f): 346 | class TailRecursion(Exception): 347 | def __init__(self, args, kwargs): 348 | self.args = args 349 | self.kwargs = kwargs 350 | def _newf(*args, **kwargs): 351 | if inspect.stack()[1][3] == f.__name__: 352 | if (spec and spec(args)) or not spec: 353 | raise TailRecursion(args, kwargs) 354 | while True: 355 | try: 356 | res = f(*args, **kwargs) 357 | except TailRecursion as ex: 358 | args = ex.args 359 | kwargs = ex.kwargs 360 | continue 361 | else: 362 | return res 363 | _newf.__doc__ = f.__doc__ 364 | return _newf 365 | return _wrap 366 | 367 | def annotate(*args, **kwargs): 368 | '''Set function annotations using decorators. 369 | 370 | :param args: This is a list of annotations for the function, in the order of the function's parameters. For example, ``annotate('Annotation 1', 'Annotation 2')`` will set the annotations of parameter 1 of the function to ``Annotation 1``. 371 | 372 | :param kwargs: This is a mapping of argument names to annotations. Note that these are applied *after* the argument list, so any args set that way will be overriden by this mapping. If there is a key named `ret`, that will be the annotation for the function's return value. 373 | 374 | .. deprecated:: 0.5 375 | Use :func:`fannotate` instead. 376 | ''' 377 | def _wrap(f): 378 | if not hasattr(f, '__annotations__'): 379 | f.__annotations__ = {} 380 | if 'ret' in kwargs: 381 | f.__annotations__['return'] = kwargs.pop('ret') 382 | f.__annotations__.update(dict(zip(argspec(f).args, args))) 383 | f.__annotations__.update(kwargs) 384 | return f 385 | return _wrap 386 | 387 | def fannotate(*args, **kwargs): 388 | '''Set function annotations using decorators. 389 | 390 | :param \*args: The first positional argument is used for the function's return value; all others are discarded. 391 | 392 | :param \**kwargs: This is a mapping of argument names to annotations. 393 | 394 | Example:: 395 | 396 | @fannotate('This for the return value', a='Parameter a', b='Parameter b') 397 | def x(a, b): 398 | pass 399 | 400 | ''' 401 | def _wrap(f): 402 | if not hasattr(f, '__annotations__'): 403 | f.__annotations__ = {} 404 | if len(args) >= 1: 405 | f.__annotations__['return'] = args[0] 406 | f.__annotations__.update(kwargs) 407 | return f 408 | return _wrap 409 | 410 | def safe_unpack(seq, ln, fill=None): 411 | '''Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. Empty values will be filled in with `fill`, while any extra values will be cut off. 412 | 413 | :param seq: The sequence to unpack. 414 | 415 | :param ln: The expected length of the sequence. 416 | 417 | :param fill: The value to substitute if the sequence is too small. Defaults to ``None``. 
418 | 419 | Example:: 420 | 421 | s = 'a:b' 422 | a, b = safe_unpack(s.split(':'), 2) 423 | # a = 'a' 424 | # b = 'b' 425 | s = 'a' 426 | a, b = safe_unpack(s.split(':'), 2) 427 | # a = 'a' 428 | # b = None''' 429 | if len(seq) > ln: 430 | return seq[:ln] 431 | elif len(seq) < ln: 432 | return seq + type(seq)([fill]*(ln-len(seq))) 433 | else: 434 | return seq 435 | 436 | def assign(varname, value): 437 | '''Assign `value` to `varname` and return it. If `varname` is an attribute and the instance name it belongs to is not defined, a NameError is raised. 438 | This can be used to emulate assignment as an expression. For example, this:: 439 | 440 | if assign('x', 7): ... 441 | 442 | is equilavent to this C code:: 443 | 444 | if (x = 7) ... 445 | 446 | .. warning:: 447 | 448 | When assigning an attribute, the instance it belongs to MUST be declared as global prior to the assignment. Otherwise, the assignment will not work. 449 | ''' 450 | fd = inspect.stack()[1][0].f_globals 451 | if '.' not in varname: 452 | fd[varname] = value 453 | else: 454 | vsplit = list(map(str.strip, varname.split('.'))) 455 | if vsplit[0] not in fd: 456 | raise NameError('Unknown object: %s'%vsplit[0]) 457 | base = fd[vsplit[0]] 458 | for x in vsplit[1:-1]: 459 | base = getattr(base, x) 460 | setattr(base, vsplit[-1], value) 461 | return value 462 | 463 | def is_main(frame=1): 464 | "Return if the caller is main. Equilavent to ``__name__ == '__main__'``." 465 | return inspect.stack()[frame][0].f_globals['__name__'] == '__main__' 466 | 467 | def _call_if_main(frame, f, args): 468 | if is_main(frame): return f(*args) 469 | 470 | def call_if_main(f,*args): 471 | "Call the `f` with `args` if the caller's module is main." 472 | return _call_if_main(3,f,args) 473 | 474 | def run_main(f,*args): 475 | "Call `f` with the `args` and terminate the program with its return code if the caller's module is main." 476 | sys.exit(_call_if_main(3,f,args)) 477 | 478 | def compare_and_swap(var, compare, new): 479 | "If `var` is equal to `compare`, set it to `new`." 480 | if assign('v', inspect.stack()[1][0].f_globals)[var] == compare: 481 | v[var] = new 482 | -------------------------------------------------------------------------------- /src/mcts/models/HuggingFace_API.py: -------------------------------------------------------------------------------- 1 | # Licensed under the MIT license. 
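# Thin wrappers around Hugging Face Transformers: load_HF_model() loads a
# causal LM checkpoint in fp16 with device_map="auto", and
# generate_with_HF_model() samples a single completion with the given
# decoding parameters.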
2 | 
3 | import torch
4 | from transformers import (
5 | GenerationConfig,
6 | AutoModelForCausalLM,
7 | AutoTokenizer,
8 | )
9 | from tqdm import tqdm
10 | import torch.nn.functional as F
11 | import numpy as np
12 | 
13 | 
14 | def load_HF_model(ckpt) -> tuple:
15 | tokenizer = AutoTokenizer.from_pretrained(ckpt)
16 | model = AutoModelForCausalLM.from_pretrained(
17 | ckpt,
18 | load_in_8bit=False,
19 | torch_dtype=torch.float16,
20 | device_map="auto",
21 | trust_remote_code=True,
22 | )
23 | return tokenizer, model
24 | 
25 | 
26 | def generate_with_HF_model(
27 | tokenizer, model, input=None, temperature=0.8, top_p=0.95, top_k=40, num_beams=1, max_new_tokens=128, **kwargs
28 | ):
29 | try:
30 | inputs = tokenizer(input, return_tensors="pt")
31 | input_ids = inputs["input_ids"].to("cuda")
32 | generation_config = GenerationConfig(
33 | do_sample=True,
34 | temperature=temperature,
35 | top_p=top_p,
36 | top_k=top_k,
37 | num_beams=num_beams,
38 | **kwargs,
39 | )
40 | with torch.no_grad():
41 | generation_output = model.generate(
42 | input_ids=input_ids,
43 | generation_config=generation_config,
44 | return_dict_in_generate=True,
45 | output_scores=True,
46 | max_new_tokens=max_new_tokens,
47 | do_sample=True,
48 | pad_token_id=tokenizer.eos_token_id,
49 | )
50 | s = generation_output.sequences[0]
51 | output = tokenizer.decode(s)
52 | except Exception as e:
53 | breakpoint()  # debugging hook left in place
54 | raise  # re-raise so we never fall through and return `output` unbound
55 | return output
56 | 
--------------------------------------------------------------------------------
/src/mcts/models/IO_System.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 | 
3 | import sys
4 | from transformers import AutoTokenizer, AutoModelForCausalLM
5 | import torch
6 | from run_src.rstar_utils import time_decorator
7 | sys.path.append(".")
8 | 
9 | from typing import List, Dict
10 | 
11 | try:
12 | from models.vLLM_API import generate_with_vLLM_model
13 | except ImportError:
14 | pass
15 | 
16 | try:
17 | from models.OpenAI_API import generate_n_with_OpenAI_model
18 | except ImportError:
19 | pass
20 | 
21 | 
22 | class IO_System:
23 | """Input/Output system"""
24 | 
25 | def __init__(self, args, tokenizer, model) -> None:
26 | self.api = args.api
27 | if self.api == "together":
28 | assert tokenizer is None and model is None
29 | elif self.api == "OpenAI":
30 | assert tokenizer is None and isinstance(model, str)
31 | self.model_ckpt = args.model_ckpt
32 | self.temperature = args.temperature
33 | self.top_k = args.top_k
34 | self.top_p = args.top_p
35 | self.tokenizer = tokenizer
36 | self.model = model
37 | 
38 | self.call_counter = 0
39 | self.token_counter = 0
40 | 
41 | 
42 | @time_decorator
43 | def generate(self, model_input, max_tokens: int, num_return: int, stop_tokens):
44 | if isinstance(model_input, str):
45 | if self.api == "vllm":
46 | vllm_response = generate_with_vLLM_model(
47 | self.model,
48 | input=model_input,
49 | temperature=self.temperature,
50 | top_p=self.top_p,
51 | top_k=self.top_k,
52 | n=num_return,
53 | max_tokens=max_tokens,
54 | stop=stop_tokens,
55 | )
56 | io_output_list = [o.text for o in vllm_response[0].outputs]
57 | self.call_counter += 1
58 | self.token_counter += sum([len(o.token_ids) for o in vllm_response[0].outputs])
59 | 
60 | elif self.api == "OpenAI":
61 | gpt_response = generate_n_with_OpenAI_model(
62 | prompt=model_input,
63 | n=num_return,
64 | model_ckpt=self.model,
65 | max_tokens=max_tokens,
66 | max_completion_tokens=max_tokens,
67 | temperature=self.temperature,
68 | top_p=self.top_p,
69 | top_k=self.top_k,
70 | # stop=["\n", "Answer"],
71 | stop=stop_tokens,
72 | )
73 | io_output_list = gpt_response
74 | self.call_counter += num_return
75 | self.token_counter += 0
76 | elif self.api == "debug":
77 | io_output_list = ["Debug: The answer is generated with debug mode, 233." for _ in range(num_return)]
78 | else:
79 | raise NotImplementedError(f"API {self.api} is not implemented.")
80 | elif isinstance(model_input, list):
81 | if self.api == "vllm":
82 | vllm_response = generate_with_vLLM_model(
83 | self.model,
84 | input=model_input,
85 | temperature=self.temperature,
86 | top_p=self.top_p,
87 | top_k=self.top_k,
88 | n=num_return,
89 | max_tokens=max_tokens,
90 | stop=stop_tokens,
91 | )
92 | io_output_list = [
93 | [o.text for o in resp_to_single_input.outputs] for resp_to_single_input in vllm_response
94 | ]
95 | self.call_counter += 1
96 | self.token_counter += sum(
97 | [
98 | sum([len(o.token_ids) for o in resp_to_single_input.outputs])
99 | for resp_to_single_input in vllm_response
100 | ]
101 | )
102 | elif self.api == "OpenAI":
103 | io_output_list = []
104 | for input in model_input:
105 | gpt_response = generate_n_with_OpenAI_model(
106 | prompt=input,
107 | n=num_return,
108 | model_ckpt=self.model,
109 | max_tokens=max_tokens,
110 | temperature=self.temperature,
111 | top_p=self.top_p,
112 | top_k=self.top_k,
113 | stop=["\n", "Answer"],
114 | )
115 | io_output_list.append(gpt_response)
116 | self.call_counter += num_return
117 | self.token_counter += 0
118 | elif self.api == "debug":
119 | io_output_list = [
120 | ["Debug: The answer is generated with debug mode, 233." for _ in range(num_return)]
121 | for _ in model_input
122 | ]
123 | else:
124 | raise NotImplementedError(f"API {self.api} is not implemented.")
125 | 
126 | return io_output_list
127 | 
--------------------------------------------------------------------------------
/src/mcts/models/OpenAI_API.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
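# Wrappers around the (Azure) OpenAI chat-completions endpoint:
# generate_with_OpenAI_model() issues a single request with retry and backoff,
# and generate_n_with_OpenAI_model() fans n requests out across a process pool.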
2 | 
3 | import os
4 | import time
5 | from tqdm import tqdm
6 | import concurrent.futures
7 | from openai import AzureOpenAI
8 | 
9 | client = AzureOpenAI(
10 | api_version="",
11 | azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", ""),
12 | api_key=os.environ.get("AZURE_OPENAI_API_KEY", ""),
13 | )
14 | 
15 | max_threads = 32
16 | 
17 | 
18 | def load_OpenAI_model(model):
19 | return None, model
20 | 
21 | 
22 | def generate_with_OpenAI_model(
23 | prompt,
24 | model_ckpt="gpt-35-turbo",
25 | max_tokens=2048,
26 | max_completion_tokens=4096,
27 | temperature=0.8,
28 | top_k=40,
29 | top_p=0.95,
30 | stop=["\n"],
31 | ):
32 | messages = [{"role": "user", "content": prompt}]
33 | parameters = {
34 | "model": model_ckpt,
35 | "temperature": temperature,
36 | "max_tokens": max_tokens,
37 | "max_completion_tokens": max_completion_tokens,
38 | "top_p": top_p,
39 | "stop": stop,
40 | "seed": 1,
41 | }
42 | 
43 | ans, timeout = "", 5
44 | while not ans:
45 | try:
46 | time.sleep(timeout)
47 | completion = client.chat.completions.create(messages=messages, **parameters)
48 | ans = completion.choices[0].message.content
49 | 
50 | except Exception as e:
51 | print(e)
52 | if not ans:
53 | timeout = timeout * 2  # exponential backoff between retries
54 | if timeout > 120:
55 | timeout = 1  # reset the backoff once it grows past two minutes
56 | try:
57 | print(f"Will retry after {timeout} seconds ...")
58 | except Exception:
59 | pass
60 | return ans
61 | 
62 | 
63 | def generate_n_with_OpenAI_model(
64 | prompt,
65 | n=1,
66 | model_ckpt="gpt-35-turbo",
67 | max_tokens=4096,
68 | max_completion_tokens=4096,
69 | temperature=0.8,
70 | top_k=40,
71 | top_p=0.95,
72 | stop=["\n"],
73 | max_threads=16,
74 | disable_tqdm=True,
75 | ):
76 | preds = []
77 | with concurrent.futures.ProcessPoolExecutor(max_workers=max_threads) as executor:
78 | futures = [
79 | executor.submit(generate_with_OpenAI_model, prompt, model_ckpt, max_tokens, max_completion_tokens, temperature, top_k, top_p, stop)
80 | for _ in range(n)
81 | ]
82 | for i, future in tqdm(
83 | enumerate(concurrent.futures.as_completed(futures)),
84 | total=len(futures),
85 | desc="running evaluate",
86 | disable=disable_tqdm,
87 | ):
88 | ans = future.result()
89 | preds.append(ans)
90 | return preds
91 | 
--------------------------------------------------------------------------------
/src/mcts/models/vLLM_API.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
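# Helpers for offline inference with vLLM: load_vLLM_model() builds the
# tokenizer and LLM engine (optionally in half precision), and
# generate_with_vLLM_model() draws n samples per prompt via SamplingParams.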
2 | 3 | from vllm import LLM, SamplingParams 4 | from transformers import AutoTokenizer 5 | import numpy as np 6 | import math 7 | 8 | 9 | def load_vLLM_model(model_ckpt, seed, tensor_parallel_size=1, half_precision=False, max_num_seqs=256): 10 | tokenizer = AutoTokenizer.from_pretrained(model_ckpt) 11 | 12 | if half_precision: 13 | llm = LLM( 14 | model=model_ckpt, 15 | dtype="half", 16 | tensor_parallel_size=tensor_parallel_size, 17 | seed=seed, 18 | trust_remote_code=True, 19 | max_num_seqs=max_num_seqs, 20 | swap_space=16, 21 | ) 22 | else: 23 | llm = LLM( 24 | model=model_ckpt, 25 | tensor_parallel_size=tensor_parallel_size, 26 | seed=seed, 27 | trust_remote_code=True, 28 | max_num_seqs=max_num_seqs, 29 | swap_space=16, 30 | # max_model_len=14336, 31 | ) 32 | 33 | return tokenizer, llm 34 | 35 | 36 | def generate_with_vLLM_model( 37 | model, 38 | input, 39 | temperature=0.8, 40 | top_p=0.95, 41 | top_k=40, 42 | repetition_penalty=1.1, 43 | n=1, 44 | max_tokens=256, 45 | logprobs=1, 46 | stop=[], 47 | ): 48 | sampling_params = SamplingParams( 49 | temperature=temperature, 50 | top_p=top_p, 51 | top_k=top_k, 52 | repetition_penalty=repetition_penalty, 53 | n=n, 54 | logprobs=logprobs, 55 | max_tokens=max_tokens, 56 | stop=stop, 57 | ) 58 | 59 | output = model.generate(input, sampling_params, use_tqdm=False) 60 | return output 61 | 62 | 63 | if __name__ == "__main__": 64 | model_ckpt = "deepseek-ai/deepseek-coder-1.3b-instruct" 65 | tokenizer, model = load_vLLM_model(model_ckpt, seed=42, tensor_parallel_size=1, half_precision=False) 66 | input = "write a python function to calculate the sum of two numbers" 67 | output = generate_with_vLLM_model(model, input) 68 | # breakpoint() 69 | print("-------------------------------------") 70 | print(output[0].outputs[0].text) 71 | -------------------------------------------------------------------------------- /src/mcts/prompts/TACO/examples.txt: -------------------------------------------------------------------------------- 1 | ### Question 2 | In this problem you have to implement an algorithm to defragment your hard disk. The hard disk consists of a sequence of clusters, numbered by integers from 1 to n. The disk has m recorded files, the i-th file occupies clusters with numbers ai, 1, ai, 2, ..., ai, ni. These clusters are not necessarily located consecutively on the disk, but the order in which they are given corresponds to their sequence in the file (cluster ai, 1 contains the first fragment of the i-th file, cluster ai, 2 has the second fragment, etc.). Also the disc must have one or several clusters which are free from files. 3 | 4 | You are permitted to perform operations of copying the contents of cluster number i to cluster number j (i and j must be different). Moreover, if the cluster number j used to keep some information, it is lost forever. Clusters are not cleaned, but after the defragmentation is complete, some of them are simply declared unusable (although they may possibly still contain some fragments of files). 5 | 6 | Your task is to use a sequence of copy operations to ensure that each file occupies a contiguous area of memory. Each file should occupy a consecutive cluster section, the files must follow one after another from the beginning of the hard disk. After defragmentation all free (unused) clusters should be at the end of the hard disk. After defragmenting files can be placed in an arbitrary order. Clusters of each file should go consecutively from first to last. See explanatory examples in the notes. 
7 | 8 | Print the sequence of operations leading to the disk defragmentation. Note that you do not have to minimize the number of operations, but it should not exceed 2n. 9 | 10 | Input 11 | 12 | The first line contains two integers n and m (1 ≤ n, m ≤ 200) — the number of clusters and the number of files, correspondingly. Next m lines contain descriptions of the files. The first number in the line is ni (ni ≥ 1), the number of clusters occupied by the i-th file. Then follow ni numbers ai, 1, ai, 2, ..., ai, ni (1 ≤ ai, j ≤ n). It is guaranteed that each cluster number occurs not more than once and , that is, there exists at least one unused cluster. Numbers on each line are separated by spaces. 13 | 14 | Output 15 | 16 | In the first line print a single integer k (0 ≤ k ≤ 2n) — the number of operations needed to defragment the disk. Next k lines should contain the operations' descriptions as "i j" (copy the contents of the cluster number i to the cluster number j). 17 | 18 | Examples 19 | 20 | Input 21 | 22 | 7 2 23 | 2 1 2 24 | 3 3 4 5 25 | 26 | 27 | Output 28 | 29 | 0 30 | 31 | 32 | Input 33 | 34 | 7 2 35 | 2 1 3 36 | 3 2 4 5 37 | 38 | 39 | Output 40 | 41 | 3 42 | 2 6 43 | 3 2 44 | 6 3 45 | 46 | Note 47 | 48 | Let's say that a disk consists of 8 clusters and contains two files. The first file occupies two clusters and the second file occupies three clusters. Let's look at examples of correct and incorrect positions of files after defragmentation. 49 | 50 | 51 | 52 | Example 2: each file must occupy a contiguous area of memory. 53 | 54 | Example 3: the order of files to each other is not important, at first the second file can be written, and then — the first one. 55 | 56 | Example 4: violating the order of file fragments to each other is not allowed. 57 | 58 | Example 5: unused clusters should be located at the end, and in this example the unused clusters are 3, 7, 8. 59 | 60 | ### Solution 61 | Let's think step by step. 62 | 63 | ### Step 1: **Defining High-Level Structure (Action 1)** 64 | 65 | In this first step, we define the high-level structure of the solution. Our objective is to outline how to store the files and clusters, track which clusters are in use, and prepare the basic setup for further refinement. 66 | 67 | #### Goal: 68 | 1. Parse the input to retrieve the number of clusters `n` and files `m`. 69 | 2. Initialize a `memory` array to track which files occupy which clusters (initialized to `-1` for unused clusters). 70 | 3. Store each file's cluster allocation. 71 | 72 | #### Pseudocode: 73 | ```pseudo 74 | Function defragmentation(n, m, files): 75 | Initialize memory array of size n+1 with -1 (to represent unused clusters) 76 | Initialize an empty list to store the files and their respective clusters 77 | For each file in files: 78 | Store its clusters in the list and mark them in the memory array 79 | Return the initialized memory and files list 80 | ``` 81 | 82 | --- 83 | 84 | ### Step 2: **Finding Free Clusters (Action 2)** 85 | 86 | In this step, we refine the pseudocode to handle finding the first free cluster on the disk. This is necessary because we need free clusters to move file fragments to contiguous positions. 87 | 88 | #### Goal: 89 | 1. Identify the first available free cluster. 90 | 2. Return the index of this free cluster for use in later operations. 
91 | 92 | #### Pseudocode: 93 | ```pseudo 94 | Function find_free_cluster(memory): 95 | For i from 1 to n: 96 | If memory[i] == -1: 97 | Return i // Return first free cluster found 98 | Return -1 // If no free cluster is found 99 | ``` 100 | 101 | --- 102 | 103 | ### Step 3: **Defragmenting Files (Action 2)** 104 | 105 | Now, we move to the main part of the problem: defragmenting the files. This involves moving file fragments to available contiguous clusters. We will loop through the files and their respective clusters, identifying fragments that are not in their correct position and moving them accordingly. 106 | 107 | #### Goal: 108 | 1. Check if the current file's clusters are contiguous. 109 | 2. If not, use the free clusters to move fragments into place. 110 | 3. Track the sequence of operations that result in the defragmentation. 111 | 112 | #### Pseudocode: 113 | ```pseudo 114 | Function defragment_files(memory, files): 115 | Initialize result list to store operations 116 | Initialize cnt to 0 // To track the expected position of file fragments 117 | For each file in files: 118 | For each cluster in file: 119 | If cluster is not in the correct position (i != cnt): 120 | Find the free cluster 121 | Append operation (move cluster to the free position) 122 | Update memory to reflect the changes 123 | Update free cluster to next available cluster 124 | Update the file list accordingly 125 | Return the list of operations 126 | ``` 127 | 128 | --- 129 | 130 | ### Step 4: **Refining File Movement and Operations (Action 2)** 131 | 132 | In this step, we refine the logic for moving the files and ensuring that no clusters are overwritten during the defragmentation process. We will also ensure that the correct operations are recorded. 133 | 134 | #### Goal: 135 | 1. Handle the case where we need to move fragments from a used cluster to a free cluster. 136 | 2. Ensure that the file list and memory are updated correctly during each operation. 137 | 138 | #### Pseudocode: 139 | ```pseudo 140 | For each file in files: 141 | For each cluster in file: 142 | If cluster is not in the correct place: 143 | Find a free cluster 144 | Perform a copy operation from cluster to free cluster 145 | Record the operation 146 | Update memory and file list 147 | Ensure no overwriting of other files' clusters 148 | Return the list of operations 149 | ``` 150 | 151 | --- 152 | 153 | ### Step 5: **Generating Python Code (Action 3)** 154 | 155 | Now that the pseudocode has been fully refined, we proceed to translate it into Python code. This involves implementing the main function that drives the defragmentation process, handling input/output correctly, and printing the sequence of operations. 156 | 157 | #### Goal: 158 | 1. Implement the `defragmentation` function in Python. 159 | 2. Ensure the input is parsed correctly and the output matches the expected format. 160 | 3. Ensure the program handles the movement of clusters and generates the correct sequence of operations. 
161 | 162 | #### Python Code: 163 | 164 | ```python 165 | def defragmentation(n, m, files): 166 | # Initialize memory array with -1 (representing unused clusters) 167 | memory = [-1] * (n + 1) 168 | d = [] 169 | 170 | # Mark memory with files' cluster allocations 171 | for i in range(m): 172 | file = files[i][1:] # Ignore first value (size of file) 173 | for cluster in file: 174 | memory[cluster] = i # Mark the memory with the file index 175 | d.append(file) 176 | 177 | free = -1 178 | # Find the first free cluster 179 | for i in range(1, n + 1): 180 | if memory[i] == -1: 181 | free = i 182 | break 183 | 184 | res = [] # List to store operations 185 | cnt = 0 186 | 187 | # Start defragmenting the files 188 | for x in range(m): 189 | for i in range(len(d[x])): 190 | cnt += 1 191 | if d[x][i] == cnt: 192 | continue # File fragment is already in the correct position 193 | 194 | y = memory[cnt] 195 | if y == -1: 196 | # Target position cnt is free; move the fragment directly into it 197 | res.append((d[x][i], cnt)) 198 | memory[d[x][i]] = -1 199 | free = d[x][i] 200 | d[x][i] = cnt 201 | memory[cnt] = x 202 | continue 203 | 204 | # Position cnt is occupied by a fragment of file y; evacuate it to the free cluster first 205 | for j in range(len(d[y])): 206 | if d[y][j] == cnt: 207 | res.append((cnt, free)) 208 | res.append((d[x][i], cnt)) 209 | d[y][j] = free 210 | memory[free] = y 211 | free = d[x][i] 212 | memory[free] = -1 213 | d[x][i] = cnt 214 | memory[cnt] = x 215 | 216 | # Output the result 217 | print(len(res)) 218 | for op in res: 219 | print(op[0], op[1]) 220 | 221 | n, m = map(int, input().split()) 222 | files = [list(map(int, input().split())) for _ in range(m)] 223 | defragmentation(n, m, files) 224 | ``` 225 | 226 | 227 | 228 | ### Question 229 | A gene is represented as a string of length $n$ (where $n$ is divisible by $4$), composed of the letters $\mbox{A}$, $\mbox{C}$, $\textbf{T}$, and $\mbox{G}$. 230 | It is considered to be steady if each of the four letters occurs exactly $\frac{n}{4}$ times. For example, $\textbf{GACT}$ and $\textbf{AAGTGCCT}$ are both steady genes. 231 | 232 | Bear Limak is a famous biotechnology scientist who specializes in modifying bear DNA to make it steady. Right now, he is examining a gene represented as a string $\textit{gene}$. It is not necessarily steady. Fortunately, Limak can choose one (maybe empty) substring of $\textit{gene}$ and replace it with any string of the same length. 233 | 234 | Modifying a large substring of bear genes can be dangerous. 235 | Given a string $\textit{gene}$, can you help Limak find the length of the smallest possible substring that he can replace to make $\textit{gene}$ a steady gene? 236 | 237 | Note: A substring of a string $\boldsymbol{\mathrm{~S~}}$ is a subsequence made up of zero or more contiguous characters of $\boldsymbol{\mathrm{~S~}}$. 238 | 239 | As an example, consider $\textit{gene}=\textit{ACTGAAAG}$. The substring $\boldsymbol{AA}$ just before or after $\mbox{G}$ can be replaced with $\textit{CT}$ or $\boldsymbol{TC}$. One selection would create $\textbf{ACTGACTG}$. 240 | 241 | Function Description 242 | 243 | Complete the $\textit{steadyGene}$ function in the editor below. It should return an integer that represents the length of the smallest substring to replace. 244 | 245 | steadyGene has the following parameter: 246 | 247 | gene: a string 248 | 249 | Input Format 250 | 251 | The first line contains an integer $n$, divisible by $4$, that denotes the length of the string $\textit{gene}$. 252 | 253 | The second line contains a string $\textit{gene}$ of length $n$.
254 | 255 | Constraints 256 | 257 | $4\leq n\leq500000$ 258 | $n$ is divisible by $4$ 259 | $gene[i]\in[CGAT]$ 260 | 261 | Subtask 262 | 263 | $4\leq n\leq2000$ in tests worth $30\%$ of the points. 264 | 265 | Output Format 266 | 267 | Print the length of the minimum length substring that can be replaced to make $\textit{gene}$ steady. 268 | 269 | Sample Input 270 | 8 271 | GAAATAAA 272 | 273 | Sample Output 274 | 5 275 | 276 | Explanation 277 | 278 | One optimal solution is to replace $\textbf{AAATA}$ with $\textbf{TTCCG}$ resulting in $\textbf{GTTCCGAA}$. 279 | 280 | The replaced substring has length $5$. 281 | 282 | ### Solution 283 | Let's think step by step. 284 | 285 | ### Step 1: **Defining High-Level Structure (Action 1)** 286 | 287 | #### Goal: 288 | Define the overall solution framework and how the problem will be approached: 289 | 1. Parse the input to get the string and calculate the target frequency of each nucleotide. 290 | 2. Identify imbalances (if any) in nucleotide counts. 291 | 3. Use a sliding window to find the smallest substring that can balance the nucleotide counts when replaced. 292 | 293 | #### Pseudocode: 294 | ```pseudo 295 | Function steadyGene(gene, n): 296 | target = n / 4 # Each nucleotide should appear n/4 times 297 | Count the frequency of each nucleotide in the gene 298 | If all nucleotide counts are <= target: 299 | Return 0 # Gene is already steady 300 | 301 | Use a sliding window approach: 302 | Initialize start and end of the window 303 | Track the frequency of nucleotides in the current window 304 | Expand and shrink the window to find the minimal length substring that can be replaced 305 | 306 | Return the minimal length of the valid substring 307 | ``` 308 | 309 | --- 310 | 311 | ### Step 2: **Calculating Frequencies and Imbalances (Action 2)** 312 | 313 | #### Goal: 314 | 1. Count the frequency of each nucleotide in the gene. 315 | 2. Identify nucleotides that exceed the target frequency and need balancing. 316 | 317 | #### Pseudocode: 318 | ```pseudo 319 | Function calculateImbalances(gene, target): 320 | Initialize a frequency dictionary for nucleotides ('A', 'C', 'T', 'G') with counts as 0 321 | For each character in the gene: 322 | Increment its count in the frequency dictionary 323 | 324 | Identify excess nucleotides: 325 | For each nucleotide in the dictionary: 326 | If its count > target: 327 | Record the excess 328 | 329 | Return the frequency dictionary and excess nucleotides 330 | ``` 331 | 332 | --- 333 | 334 | ### Step 3: **Sliding Window Implementation (Action 2)** 335 | 336 | #### Goal: 337 | 1. Use a sliding window to dynamically evaluate substrings. 338 | 2. Track the frequency of nucleotides within the window. 339 | 3. Ensure the substring balances the excess nucleotides (a worked example on the sample input follows below).
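To make the balancing condition concrete before refining it, consider the sample input: for $\textit{gene}=\textbf{GAAATAAA}$ with $n = 8$, the target is $n/4 = 2$ and the full counts are A: 6, G: 1, T: 1, C: 0. Only A exceeds the target, by $6 - 2 = 4$, so a window is replaceable exactly when it contains at least four A's; the shortest such window (for example $\textbf{AAATA}$) has length $5$, which matches the expected output.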
340 | 341 | #### Pseudocode: 342 | ```pseudo 343 | Function findMinimalSubstring(gene, n, target, full_freq): 344 | Initialize start = 0, end = 0, and current_freq = {} 345 | Initialize min_length = n (start with maximum possible length) 346 | 347 | While end < n: 348 | Expand the window by including gene[end] 349 | Update the current_freq for gene[end] 350 | 351 | While the window satisfies the balancing condition: 352 | Update min_length with the current window size 353 | Shrink the window from the start 354 | Update current_freq for gene[start] 355 | Increment start 356 | 357 | Increment end 358 | 359 | Return min_length 360 | ``` 361 | 362 | --- 363 | 364 | ### Step 4: **Satisfaction Check (Action 2)** 365 | 366 | #### Goal: 367 | Ensure that the current substring can balance the nucleotide counts when replaced. 368 | 369 | #### Pseudocode: 370 | ```pseudo 371 | Function satisfiesCondition(current_freq, full_freq, target): 372 | For each nucleotide in ('A', 'C', 'T', 'G'): 373 | If full_freq[nucleotide] > target: 374 | If current_freq[nucleotide] < full_freq[nucleotide] - target: 375 | Return False 376 | Return True 377 | ``` 378 | 379 | --- 380 | 381 | ### Step 5: **Implementing Python Code (Action 3)** 382 | 383 | #### Python Code: 384 | The refined pseudocode is translated into Python as follows: 385 | 386 | ```python 387 | def satisfied(freq, full_freq, target, base_pairs): 388 |     # The window is replaceable iff it contains at least the excess of every over-represented nucleotide 389 |     for bp in base_pairs: 390 |         if full_freq[bp] > target: 391 |             if freq[bp] < full_freq[bp] - target: 392 |                 return False 393 |     return True 394 | 395 | def steadyGene(n, gene): 396 |     base_pairs = ['A', 'C', 'T', 'G'] 397 |     target = n // 4  # n is guaranteed to be divisible by 4 398 | 399 |     # Calculate the frequency of each nucleotide in the gene 400 |     full_freq = {bp: 0 for bp in base_pairs} 401 |     for nuc in gene: 402 |         full_freq[nuc] += 1 403 | 404 |     # Check if the gene is already steady 405 |     if all(full_freq[bp] <= target for bp in base_pairs): 406 |         return 0  # No replacement needed 407 | 408 |     # Sliding window to find the minimal substring 409 |     cur_freq = {bp: 0 for bp in base_pairs} 410 |     start = 0 411 |     min_length = n 412 | 413 |     for end in range(n): 414 |         cur_freq[gene[end]] += 1 415 | 416 |         while satisfied(cur_freq, full_freq, target, base_pairs): 417 |             # Update minimal length 418 |             min_length = min(min_length, end - start + 1) 419 |             # Shrink the window 420 |             cur_freq[gene[start]] -= 1 421 |             start += 1 422 | 423 |     return min_length 424 | 425 | n = int(input()) 426 | gene = input() 427 | print(steadyGene(n, gene)) 428 | ``` 429 | -------------------------------------------------------------------------------- /src/mcts/prompts/TACO/prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompt_template": "### Instruction\nPlease refer to the given task description and provide a thought process in the form of step-by-step pseudocode refinement.\n\nA curious user has approached you with a programming question. You should give step-by-step solutions to the user's questions. For each step you can choose one of the following three actions\n\n Defining algorithm Structures Using pseudocode\n**Description:** \nOutline the core functions and overall structure of the solution without getting into implementation details. Define inputs, outputs, and the main tasks each function will perform.\n\n Refine part of the pseudocode\n**Description:** \nAdd more details to the pseudocode, specifying the exact steps, logic, and operations each function will carry out.
This prepares the pseudocode for actual coding.\n\n Generate python code from the pseudocode\n**Description:** \nTranslate the refined pseudocode into executable Python code, making sure to handle inputs, outputs, and ensure correctness in the implementation.\n\n**Note:**\n- You can choose one of the three actions for each step.\n- Provide a detailed explanation of the reasoning behind each step.\n- Try to refer to the reference code as much as possible, but you can also modify it if needed (e.g. change variable names, add some comments, etc.).\n\n### Examples\n{examples}\n\n### Question\n{question}\n\n### Solution\nLet's think step by step.\n", 3 | "stop_tokens": [ 4 | "Step_End" 5 | ] 6 | } -------------------------------------------------------------------------------- /src/mcts/run_outputs/data_examples/answer_sheets/Question 0000 - Answer.json: -------------------------------------------------------------------------------- 1 | {"id": 0, "problem": "**Step 1:** Create a function called `encode()` to replace all the lowercase vowels in a given string with numbers according to the following pattern:\n```\na -> 1\ne -> 2\ni -> 3\no -> 4\nu -> 5\n```\n\nFor example, `encode(\"hello\")` would return `\"h2ll4\"`. There is no need to worry about uppercase vowels in this kata.\n\n**Step 2:** Now create a function called `decode()` to turn the numbers back into vowels according to the same pattern shown above.\n\nFor example, `decode(\"h3 th2r2\")` would return `\"hi there\"`.\n\nFor the sake of simplicity, you can assume that any numbers passed into the function will correspond to vowels.", "gold_solution": "def encode(s, t=str.maketrans('aeiou', '12345')):\n\treturn s.translate(t)\n\ndef decode(s, t=str.maketrans('12345', 'aeiou')):\n\treturn s.translate(t)\n", "test_case": {"fn_name": "encode", "inputs": [["How are you today?"], ["hello"], ["This is an encoding test."]], "outputs": [["H4w 1r2 y45 t4d1y?"], ["h2ll4"], ["Th3s 3s 1n 2nc4d3ng t2st."]]}, "difficulty": "EASY", "time_taken": "299.88s"} -------------------------------------------------------------------------------- /src/mcts/run_outputs/data_examples/answer_sheets/Question 0000 - Best Solutions.json: -------------------------------------------------------------------------------- 1 | { 2 | "trace": { 3 | "0": { 4 | "user_question": "**Step 1:** Create a function called `encode()` to replace all the lowercase vowels in a given string with numbers according to the following pattern:\n```\na -> 1\ne -> 2\ni -> 3\no -> 4\nu -> 5\n```\n\nFor example, `encode(\"hello\")` would return `\"h2ll4\"`. There is no need to worry about uppercase vowels in this kata.\n\n**Step 2:** Now create a function called `decode()` to turn the numbers back into vowels according to the same pattern shown above.\n\nFor example, `decode(\"h3 th2r2\")` would return `\"hi there\"`.\n\nFor the sake of simplicity, you can assume that any numbers passed into the function will correspond to vowels.", 5 | "ost_step": { 6 | "1": "### Step 1: **Defining High-Level Structure for `encode()` and `decode()` (Action 1)**\n\nIn this first step, we define the overall structure of both the `encode()` and `decode()` functions. The goal of the `encode()` function is to replace vowels in the string with their respective numbers, while the `decode()` function will reverse this process, converting numbers back into the corresponding vowels.\n\n#### Plan for `encode()`:\n1. Loop through each character in the input string.\n2. 
For each lowercase vowel (`a`, `e`, `i`, `o`, `u`), replace it with the corresponding number (`a -> 1`, `e -> 2`, etc.).\n3. Keep all other characters (non-vowels) unchanged.\n\n#### Plan for `decode()`:\n1. Loop through each character in the input string.\n2. For each number (`1`, `2`, `3`, `4`, `5`), replace it with the corresponding vowel (`1 -> a`, `2 -> e`, etc.).\n3. Keep all other characters (non-numeric) unchanged.\n\n#### Pseudocode:\n```pseudo\nFunction encode(string):\n Initialize a dictionary for vowel to number mappings\n Initialize an empty string for the result\n For each character in the string:\n If the character is a vowel:\n Append the corresponding number to the result\n Else:\n Append the character as is\n Return the result string\n\nFunction decode(string):\n Initialize a dictionary for number to vowel mappings\n Initialize an empty string for the result\n For each character in the string:\n If the character is a number:\n Append the corresponding vowel to the result\n Else:\n Append the character as is\n Return the result string\n```\n\n", 7 | "2": "### Step 2: **Refine the `encode()` Function (Action 2)**\n\nIn this step, we will refine the pseudocode for the `encode()` function by detailing how to handle each character in the input string and map vowels to numbers. This includes the setup of the dictionary for encoding and the loop to process each character.\n\n#### Pseudocode:\n```pseudo\nFunction encode(string):\n Initialize a dictionary `vowel_to_number` with mappings: {'a': '1', 'e': '2', 'i': '3', 'o': '4', 'u': '5'}\n Initialize an empty string `encoded_string` to store the result\n For each character `char` in the input `string`:\n If `char` is in `vowel_to_number`:\n Append `vowel_to_number[char]` to `encoded_string`\n Else:\n Append `char` to `encoded_string`\n Return `encoded_string`\n```\n\n#### Reasoning:\n- We use a dictionary to efficiently map each vowel to its corresponding number.\n- For each character in the string, we check if it's a vowel by looking it up in the dictionary.\n- If it's a vowel, we append the mapped number to the result string.\n- If it's not a vowel, we append the character as is to maintain the integrity of non-vowel characters.", 8 | "3": "### Step 3: **Refine the `decode()` Function (Action 2)**\n\nIn this step, we refine the pseudocode for the `decode()` function. The goal is to handle the process of replacing numbers with the corresponding vowels. 
We will establish a dictionary for number-to-vowel mappings and refine the logic for parsing through the string.\n\n#### Pseudocode:\n```pseudo\nFunction decode(string):\n Initialize a dictionary `number_to_vowel` with mappings: {'1': 'a', '2': 'e', '3': 'i', '4': 'o', '5': 'u'}\n Initialize an empty string `decoded_string` to store the result\n For each character `char` in the input `string`:\n If `char` is in `number_to_vowel`:\n Append `number_to_vowel[char]` to `decoded_string`\n Else:\n Append `char` to `decoded_string`\n Return `decoded_string`\n```\n\n#### Reasoning:\n- We create a dictionary `number_to_vowel` to map each number (1-5) back to its respective vowel.\n- For each character in the string, we check if it's a number (1-5) by looking it up in the dictionary.\n- If it's a number, we append the corresponding vowel to the result string.\n- If it's not a number (e.g., spaces or other characters), we simply append it as is.\n\nThis approach ensures that we correctly reverse the encoding process, replacing numbers with vowels while maintaining the integrity of non-numeric characters.\n\n", 9 | "4": "### Step 4: **Generating Python Code for `encode()` and `decode()` (Action 3)**\n\nNow that the pseudocode for both `encode()` and `decode()` functions has been fully refined, we translate it into Python code. This involves implementing the logic for both encoding vowels as numbers and decoding numbers back into vowels.\n\n#### Python Code:\n```python\ndef encode(string):\n # Mapping from vowels to numbers\n vowel_to_number = {'a': '1', 'e': '2', 'i': '3', 'o': '4', 'u': '5'}\n # String to store encoded result\n encoded_string = \"\"\n # Iterate over each character in the input string\n for char in string:\n # Check if the character is a vowel and map it to the corresponding number\n if char in vowel_to_number:\n encoded_string += vowel_to_number[char]\n else:\n # Keep non-vowel characters unchanged\n encoded_string += char\n return encoded_string\n\ndef decode(string):\n # Mapping from numbers to vowels\n number_to_vowel = {'1': 'a', '2': 'e', '3': 'i', '4': 'o', '5': 'u'}\n # String to store decoded result\n decoded_string = \"\"\n # Iterate over each character in the input string\n for char in string:\n # Check if the character is a number and map it to the corresponding vowel\n if char in number_to_vowel:\n decoded_string += number_to_vowel[char]\n else:\n # Keep non-numeric characters unchanged\n decoded_string += char\n return decoded_string\n\n# Example usage:\nprint(encode(\"hello\")) # Output: \"h2ll4\"\nprint(decode(\"h3 th2r2\")) # Output: \"hi there\"\n```\n\n#### Explanation:\n- For the `encode()` function, we create a dictionary `vowel_to_number` to map vowels to numbers. We iterate over each character in the input string, replacing vowels with their corresponding numbers and keeping other characters unchanged.\n- For the `decode()` function, we create a dictionary `number_to_vowel` to map numbers back to vowels. We iterate over each character in the input string, replacing numbers with their corresponding vowels and keeping other characters unchanged.\n- The code is efficient and straightforward, utilizing dictionary lookups to perform the necessary character replacements.\n\nThis completes the translation of the refined pseudocode into Python, providing a functional implementation of the `encode()` and `decode()` operations." 
10 | }, 11 | "ost_step_value": { 12 | "1": 1.0, 13 | "2": 1.0, 14 | "3": 1.0, 15 | "4": 1.0 16 | } 17 | } 18 | }, 19 | "rollout_id": 0, 20 | "value": 1.0 21 | } -------------------------------------------------------------------------------- /src/mcts/run_outputs/data_examples/args.json: -------------------------------------------------------------------------------- 1 | { 2 | "note": "default", 3 | "api": "OpenAI", 4 | "seed": 42, 5 | "verbose": true, 6 | "wandb_mode": "disabled", 7 | "model_ckpt": "gpt-4o", 8 | "model_parallel": false, 9 | "half_precision": false, 10 | "max_tokens": 4096, 11 | "temperature": 0.8, 12 | "top_k": 40, 13 | "top_p": 0.95, 14 | "num_beams": 1, 15 | "max_num_worker": 3, 16 | "test_batch_size": 1, 17 | "tensor_parallel_size": 1, 18 | "prompts_root": "prompts", 19 | "data_root": "data", 20 | "dataset_name": "TACO", 21 | "test_json_filename": "test_one", 22 | "start_idx": 0, 23 | "end_idx": Infinity, 24 | "run_outputs_root": "run_outputs", 25 | "eval_outputs_root": "eval_outputs", 26 | "run_outputs_dir": "run_outputs/", 27 | "num_rollouts": 12, 28 | "max_depth_allowed": 10, 29 | "mcts_discount_factor": 1.0, 30 | "mcts_exploration_weight": 2.0, 31 | "mcts_weight_scheduler": "const", 32 | "mcts_num_last_votes": 32, 33 | "save_tree": false, 34 | "num_sampling": 3, 35 | "enable_potential_score": false, 36 | "examples_txt_path": "prompts/TACO/examples.txt", 37 | "prompt_config_path": "prompts/TACO/prompt.json", 38 | "answer_sheets_dir": "run_outputs/answer_sheets", 39 | "cuda_0": null, 40 | "cuda_1": null, 41 | "cuda_2": null, 42 | "cuda_3": null 43 | } -------------------------------------------------------------------------------- /src/mcts/run_outputs/data_examples/intermediate_result.txt: -------------------------------------------------------------------------------- 1 | Total calls: 48, Avg calls: 48.00 2 | Total tokens: 0, Avg tokens: 0.00 3 | -------------------------------------------------------------------------------- /src/mcts/run_src/MCTS_backbone.py: -------------------------------------------------------------------------------- 1 | """ 2 | A minimal implementation of Monte Carlo tree search (MCTS) in Python 3 3 | Luke Harold Miles, July 2019, Public Domain Dedication 4 | See also https://en.wikipedia.org/wiki/Monte_Carlo_tree_search 5 | https://gist.github.com/qpwo/c538c6f73727e254fdc7fab81024f6e1 6 | """ 7 | 8 | from abc import ABC, abstractmethod 9 | from collections import defaultdict 10 | from typing import Dict, List 11 | import math, random 12 | 13 | 14 | node_cnt = 0 15 | 16 | 17 | def verbose_print(s: str, verbose: bool): 18 | if verbose: 19 | print(s) 20 | 21 | 22 | class MCTS_Node(ABC): 23 | """ 24 | A representation of a single board state. 25 | MCTS works by constructing a tree of these Nodes. 26 | Could be e.g. a chess or checkers board state. 27 | """ 28 | 29 | def __init__(self) -> None: 30 | super().__init__() 31 | 32 | global node_cnt 33 | self.id = node_cnt 34 | node_cnt += 1 35 | 36 | self.rollout_id = None 37 | 38 | def set_rollout_id(self, rollout_id: int): 39 | self.rollout_id = rollout_id 40 | 41 | @abstractmethod 42 | def find_children(self, rollout_id: int): 43 | "All possible successors of this board state" 44 | raise NotImplementedError 45 | 46 | @abstractmethod 47 | def is_terminal(self): 48 | "Returns True if the node has no children" 49 | raise NotImplementedError 50 | 51 | @abstractmethod 52 | def calculate_reward(self): 53 | "Assumes `self` is terminal node. 
1=win, 0=loss, .5=tie, etc" 54 | raise NotImplementedError 55 | 56 | 57 | 58 | class MCTS_Searcher: 59 | "Monte Carlo tree searcher. First roll out the tree, then choose a move." 60 | 61 | def __init__( 62 | self, 63 | exploration_weight: float, 64 | weight_scheduler: str, 65 | num_rollouts: int, 66 | discount: float, 67 | verbose: bool = False, 68 | ): 69 | self.Q: Dict[MCTS_Node, float] = defaultdict(lambda: 0.0) # total reward of each node 70 | self.N: Dict[MCTS_Node, int] = defaultdict(lambda: 0) # total visit count for each node 71 | self.parent2children: Dict[MCTS_Node, List[MCTS_Node]] = dict() # children of each node 72 | 73 | #! explored = expanded + simulated, i.e. has seen terminal at least once, i.e. we can calculate its UCT value, i.e. has Q and N 74 | self.explored_nodes = set() 75 | 76 | self.exploration_weight = exploration_weight 77 | self.weight_scheduler = weight_scheduler 78 | self.num_rollouts = num_rollouts 79 | self.discount = discount 80 | 81 | self.verbose = verbose 82 | 83 | global node_cnt 84 | node_cnt = 0 85 | 86 | def do_rollout(self, root_node: MCTS_Node, rollout_id: int): 87 | "Make the tree one layer better. (Train for one iteration.)" 88 | verbose_print("==> Selecting a node...", self.verbose) 89 | path_1 = self._select(root_node, rollout_id) 90 | leaf = path_1[-1] 91 | verbose_print(f"==> Expanding node {leaf.id}...", self.verbose) 92 | self._expand(leaf, rollout_id) 93 | verbose_print(f"==> Simulating node {leaf.id}...", self.verbose) 94 | path_2 = self._simulate(leaf, rollout_id) 95 | verbose_print(f"==> Backpropagating...", self.verbose) 96 | self._backpropagate(path_1 + path_2) 97 | try: 98 | return path_2[-1] 99 | except IndexError: # path_2 is empty when the selected leaf was already terminal 100 | return path_1[-1] 101 | 102 | def _select(self, node: MCTS_Node, rollout_id: int) -> List[MCTS_Node]: 103 | "Find an unexplored descendant of `node`" 104 | path = [] 105 | while True: 106 | path.append(node) 107 | # case 1: a node does not have children, then select the node itself 108 | if node not in self.parent2children.keys(): 109 | return path 110 | 111 | # case 2: a node has children but not all children have been explored, then randomly select an unexplored child 112 | # unexplored = set(self.parent2children[node]) - self.explored_nodes # `set` introduces randomness 113 | unexplored = [n for n in self.parent2children[node] if n not in self.explored_nodes] 114 | if unexplored: 115 | n = random.choice(unexplored) 116 | path.append(n) 117 | return path 118 | 119 | # case 3: a node has children and all children have been explored, then select one child and go to the next layer 120 | node = self._uct_select(node, rollout_id) 121 | 122 | def _expand(self, node: MCTS_Node, rollout_id: int): 123 | "Update the `children` dict with the children of `node`" 124 | if node in self.explored_nodes: 125 | return # already expanded 126 | 127 | if node.is_terminal(): 128 | self.explored_nodes.add(node) 129 | return # terminal node is non-expandable 130 | 131 | self.parent2children[node] = node.find_children(rollout_id) 132 | 133 | def _simulate(self, node: MCTS_Node, rollout_id: int) -> List[MCTS_Node]: 134 | "Returns the reward for a random simulation (to completion) of `node`" 135 | path = [] 136 | cur_node = node 137 | while True: 138 | if cur_node.is_terminal(): 139 | self.explored_nodes.add(node) 140 | return path 141 | 142 | if cur_node not in self.parent2children.keys(): 143 | self.parent2children[cur_node] = cur_node.find_children(rollout_id) 144 | 145 | cur_node = random.choice(self.parent2children[cur_node]) # randomly select a
child 146 | path.append(cur_node) 147 | 148 | def _backpropagate(self, path: List[MCTS_Node]): 149 | "Send the reward back up to the ancestors of the leaf" 150 | leaf = path[-1] 151 | reward = leaf.calculate_reward() 152 | for node in reversed(path): 153 | self.Q[node] += reward 154 | self.N[node] += 1 155 | self.explored_nodes.add(node) 156 | from run_src.rstar_utils import Node_Type 157 | parent_node = None 158 | for node in path[:-1]: 159 | if node.node_type == Node_Type.USER_QUESTION: 160 | parent_node = node 161 | continue 162 | if node.node_value is None: 163 | node.node_value = 0 164 | node.node_value += reward 165 | for key, val in parent_node.solution_trace[0]["ost_step_value"].items(): 166 | node.solution_trace[0]["ost_step_value"][key] = val 167 | last_key = list(node.solution_trace[0]["ost_step_value"].keys())[-1] 168 | node.solution_trace[0]["ost_step_value"][last_key] = node.node_value 169 | parent_node = node 170 | node = path[-1] 171 | for key, val in parent_node.solution_trace[0]["ost_step_value"].items(): 172 | node.solution_trace[0]["ost_step_value"][key] = val 173 | 174 | 175 | def _get_weight(self, rollout_id: int): 176 | # start with exploration weight, end with 0.1 * exploration weight 177 | if self.weight_scheduler == "exp": 178 | return self.exploration_weight * (0.1 ** (rollout_id / self.num_rollouts)) 179 | elif self.weight_scheduler == "lin": 180 | return self.exploration_weight * (1 - 0.9 * (rollout_id / self.num_rollouts)) 181 | elif self.weight_scheduler == "const": 182 | return self.exploration_weight 183 | 184 | def _uct_select(self, node: MCTS_Node, rollout_id: int): 185 | "Select a child of node, balancing exploration & exploitation" 186 | 187 | # All children of the node should already be expanded 188 | assert all(n in self.explored_nodes for n in self.parent2children[node]) 189 | 190 | return max( 191 | self.parent2children[node], key=lambda n: self._compute_uct(parent_node=node, node=n, rollout_id=rollout_id) 192 | ) 193 | 194 | def _compute_uct(self, parent_node: MCTS_Node, node: MCTS_Node, rollout_id: int): 195 | "Upper confidence bound for trees" 196 | if parent_node is None: # invalid UCT: the node is the root 197 | return 666 198 | else: 199 | if self.N[node] == 0: # invalid UCT: the node has not been explored yet 200 | return 999 201 | else: 202 | weight = self._get_weight(rollout_id) 203 | return self.Q[node] / self.N[node] + weight * math.sqrt(math.log(self.N[parent_node]) / self.N[node]) 204 | -------------------------------------------------------------------------------- /src/mcts/run_src/MCTS_for_reasoning.py: -------------------------------------------------------------------------------- 1 | # Licensed under the MIT license. 
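# This module implements the reasoning-side MCTS on top of MCTS_backbone:
# `Generator` wraps the LLM IO system and samples one-step-thought (OST)
# children for a node; `Reasoning_MCTS_Node` carries the solution trace from
# the root question down to the current step; `search_for_answers` drives the
# rollouts and writes the resulting traces to the answer-sheet JSON files.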
2 | 3 | import sys 4 | 5 | sys.path.append(".") 6 | 7 | import numpy as np 8 | import os, random, json, math 9 | import wandb 10 | from tqdm import trange 11 | from typing import List, Dict, Tuple 12 | from copy import deepcopy 13 | 14 | 15 | from models.IO_System import IO_System 16 | from common.utils import read_txt, read_json 17 | from eval_src.Evaluator import Evaluator 18 | from MCTS_backbone import MCTS_Searcher, MCTS_Node 19 | from run_src.rstar_utils import ( 20 | Node_Type, 21 | reach_terminal_ost_step, 22 | concat_ost_steps, 23 | ost_find_best_solution, 24 | find_solution, 25 | time_decorator, 26 | print_tree_from_root 27 | ) 28 | 29 | 30 | def verbose_print(s: str, verbose: bool): 31 | if verbose: 32 | print(s) 33 | 34 | 35 | class Generator: 36 | """Generator generates children nodes""" 37 | 38 | def __init__(self, args, tokenizer, model, evaluator: Evaluator) -> None: 39 | self.io = IO_System(args, tokenizer, model) 40 | self.evaluator = evaluator 41 | 42 | self.num_sampling = args.num_sampling 43 | self.max_tokens = args.max_tokens 44 | self.enable_potential_score = args.enable_potential_score 45 | 46 | self.mcts_num_last_votes = args.mcts_num_last_votes 47 | 48 | 49 | self.examples = read_txt(args.examples_txt_path) 50 | self.prompt = read_json(args.prompt_config_path) 51 | 52 | 53 | def _get_pass_code(self, io_output_list: List[str], user_question: str) -> Tuple[str, float]: 54 | assert len(io_output_list) > 0 55 | 56 | if len(io_output_list) == 1: 57 | most_confident_answer_full_completion = io_output_list[0] 58 | confidence = 1 59 | else: 60 | _, passed_full_completion, _, confidence = self.evaluator.find_pass_code(io_output_list, user_question) 61 | assert confidence >= 0 62 | 63 | return passed_full_completion, confidence 64 | 65 | @time_decorator 66 | def _get_TACO_code(self, io_output_list: List[str], test_case: dict, solution_trace: Dict[int, Dict[str, str]],) -> Tuple[str, float]: 67 | assert len(io_output_list) > 0 68 | 69 | 70 | _, passed_full_completion, confidence, solution_trace_ = self.evaluator.find_TACO_code(io_output_list, test_case, solution_trace) 71 | assert confidence >= 0 72 | 73 | return passed_full_completion, confidence, solution_trace_ 74 | 75 | def generate_ost_step( 76 | self, 77 | user_question: str, 78 | test_case: dict, 79 | solution_trace: Dict[int, Dict[str, str]], 80 | paraphrased: bool, 81 | ): 82 | ost_step_list = [] 83 | existing_ost_steps, next_ost_step_id = concat_ost_steps(solution_trace) 84 | io_input = ( 85 | self.prompt["prompt_template"].format( 86 | examples=self.examples, 87 | question=user_question, 88 | ) 89 | + existing_ost_steps 90 | + f"\n### Step {next_ost_step_id}:" 91 | ) 92 | 93 | io_output_list = self.io.generate( 94 | model_input=io_input, max_tokens=8192, num_return=self.num_sampling, stop_tokens=[""] 95 | ) 96 | ost_step_list = [io_output.strip() for io_output in io_output_list] 97 | 98 | last_ost_step = [] 99 | value_list = [] 100 | completion_confidence_list = [] 101 | reach_last_step_flag = False 102 | # have_terminal_ost_step = False 103 | for ost_step in ost_step_list: 104 | if reach_terminal_ost_step(ost_step): 105 | reach_last_step_flag = True 106 | passed_full_completion, confidence, solution_trace_with_last_step = self._get_TACO_code([ost_step], test_case, solution_trace) 107 | completion_confidence_list.append((passed_full_completion, confidence)) 108 | else: 109 | last_ost_step.append(ost_step) 110 | value_list.append(None) 111 | 112 | if reach_last_step_flag == True: 113 | last_ost_step.clear() 114 | 
value_list.clear() 115 | completion_confidence_list.sort(key=lambda x: x[1], reverse=True) 116 | best_passing_completion, highest_confidence = completion_confidence_list[0] 117 | 118 | return [best_passing_completion], [highest_confidence], [None] 119 | else: 120 | potential_answers_list: List[List[str]] = [] 121 | # print(value_list) 122 | if value_list.count(None) != 0 and value_list.count(None) != len(value_list): 123 | for idx, value in enumerate(value_list): 124 | if value is not None: 125 | number = value 126 | corresponding_step = last_ost_step[idx] 127 | break 128 | value_list = [number] 129 | last_ost_step = [corresponding_step] 130 | potential_answers_list = [None] * len(value_list) 131 | return last_ost_step, value_list, potential_answers_list 132 | 133 | 134 | class Reasoning_MCTS_Node(MCTS_Node): 135 | def __init__( 136 | self, 137 | parent: "Reasoning_MCTS_Node", 138 | depth: int, 139 | node_type: Node_Type, 140 | verbose: bool = False, 141 | # --- For instantiating root node --- 142 | node_value: float = None, 143 | generator: Generator = None, 144 | user_question: str = None, 145 | max_depth_allowed: int = None, 146 | difficulty: str = None, 147 | # ------------------------------------------- 148 | # --- For instantiating OST_STEP node --- 149 | ost_step: str = None, 150 | # --------------------------------------- 151 | # --- For node selection (not in sanity checks yet) --- 152 | enable_potential_score: bool = None, 153 | potential_answers: List[str] = None, 154 | test_case: dict = None, 155 | ) -> None: 156 | """params: 157 | subquestion: the node is proposing a new subquestion 158 | subanswer: the answer corresponding to the new subquestion the node proposed 159 | re_subanswer: the node is proposing a new subanswer to the parent's subquestion 160 | """ 161 | super().__init__() 162 | 163 | #! sanity checks 164 | try: 165 | assert depth is not None 166 | assert node_type is not None 167 | if node_value is not None: 168 | print(node_value) 169 | assert node_value >= 0, breakpoint() 170 | 171 | if node_type is Node_Type.USER_QUESTION: 172 | assert depth == 0 173 | assert all( 174 | attr is None 175 | for attr in [ 176 | parent, 177 | node_value, 178 | ost_step, 179 | ] 180 | ) 181 | assert all( 182 | attr is not None 183 | for attr in [generator, user_question, difficulty, max_depth_allowed] 184 | ) 185 | elif node_type is Node_Type.ONE_STEP: 186 | assert depth > 0 187 | assert all( 188 | attr is None 189 | for attr in [ 190 | generator, 191 | user_question, 192 | difficulty, 193 | max_depth_allowed, 194 | ] 195 | ) 196 | assert all(attr is not None for attr in [parent, ost_step]) 197 | except AssertionError: 198 | print(f"Instantiating node with type {node_type} failed!") 199 | breakpoint() 200 | exit() 201 | 202 | #! 
attributes 203 | self.parent = parent # if parent is None, then the node is the root 204 | self.children: List["Reasoning_MCTS_Node"] = [] 205 | self.depth = depth 206 | self.node_type = node_type 207 | self.node_value = node_value 208 | self.ost_step = ost_step 209 | self.test_case = test_case 210 | 211 | if parent is None: # root 212 | self.verbose = verbose 213 | self.user_question = user_question 214 | self.difficulty = difficulty 215 | self.generator = generator 216 | self.max_depth_allowed = max_depth_allowed 217 | self.enable_potential_score = enable_potential_score 218 | self.test_case = test_case 219 | else: # inherit from parent 220 | self.verbose = parent.verbose 221 | self.user_question = parent.user_question 222 | self.difficulty = parent.difficulty 223 | self.generator = parent.generator 224 | self.max_depth_allowed = parent.max_depth_allowed 225 | self.enable_potential_score = parent.enable_potential_score 226 | self.test_case = parent.test_case 227 | 228 | #! keep track of paraphrasing 229 | if node_type is Node_Type.USER_QUESTION: 230 | self.paraphrased = False 231 | else: 232 | assert parent is not None 233 | self.paraphrased = parent.paraphrased 234 | 235 | 236 | #! record number of one-step thought steps till now 237 | if parent is None: # root 238 | self.ost_step_counter = 0 239 | else: 240 | if node_type is Node_Type.ONE_STEP: 241 | self.ost_step_counter = parent.ost_step_counter + 1 242 | else: 243 | self.ost_step_counter = parent.ost_step_counter 244 | 245 | #! record solution trace from root to the current node. key: subquestion id 246 | if parent is None: # root 247 | assert self.node_type is Node_Type.USER_QUESTION 248 | self.solution_trace: Dict[int, Dict[str, str]] = {0: {"user_question": user_question, "ost_step": {}, "ost_step_value": {}}} 249 | else: 250 | assert self.node_type is not Node_Type.USER_QUESTION 251 | self.solution_trace = deepcopy(parent.solution_trace) 252 | 253 | if node_type is Node_Type.ONE_STEP: 254 | assert "ost_step" in self.solution_trace[0].keys() 255 | self.solution_trace[0]["ost_step"][self.ost_step_counter] = ost_step 256 | self.solution_trace[0]["ost_step_value"][self.ost_step_counter] = node_value 257 | 258 | #! potential_score for intermediate nodes (only used for node selection) 259 | if self.enable_potential_score: 260 | self.potential_answers = potential_answers 261 | self.potential_score = 0 262 | if parent is None: # root 263 | assert self.node_type is Node_Type.USER_QUESTION 264 | self.potential_answers_history = {} 265 | else: 266 | assert self.node_type is not Node_Type.USER_QUESTION 267 | self.potential_answers_history = deepcopy(parent.potential_answers_history) 268 | self.potential_answers_history[self.depth] = potential_answers 269 | 270 | def __str__(self) -> str: 271 | type2str = { 272 | Node_Type.USER_QUESTION: "U", 273 | Node_Type.ONE_STEP: "TS", 274 | } 275 | return f"{type2str[self.node_type]}-{self.id}" 276 | 277 | def _create_children(self): 278 | 279 | def do_action_generate_ost_step(): 280 | verbose_print(f"---- Generating one-step thought steps for node {self.id}...", self.verbose) 281 | 282 | #! 
ACTION: generate one-step thought step 283 | ost_step_list, value_list, potential_answers_list = self.generator.generate_ost_step( 284 | user_question=self.user_question, 285 | test_case = self.test_case, 286 | solution_trace=self.solution_trace, 287 | paraphrased=self.paraphrased, 288 | ) 289 | for ost_step, value, potential_answers in zip(ost_step_list, value_list, potential_answers_list): 290 | self.children.append( 291 | Reasoning_MCTS_Node( 292 | parent=self, 293 | depth=self.depth + 1, 294 | node_type=Node_Type.ONE_STEP, 295 | node_value=value, 296 | ost_step=ost_step, 297 | potential_answers=deepcopy(potential_answers), 298 | ) 299 | ) 300 | 301 | #! create children 302 | if self.node_type is Node_Type.USER_QUESTION: 303 | # generate one-step thought steps 304 | do_action_generate_ost_step() 305 | 306 | 307 | elif self.node_type is Node_Type.ONE_STEP: 308 | 309 | do_action_generate_ost_step() 310 | 311 | assert self.children 312 | return self.children 313 | 314 | def is_valid_leaf_node(self): 315 | 316 | return (self.node_type is Node_Type.ONE_STEP and reach_terminal_ost_step(self.ost_step)) 317 | 318 | def is_valid_solution_node(self): 319 | 320 | return (self.node_type is Node_Type.ONE_STEP and reach_terminal_ost_step(self.ost_step)) 321 | 322 | def set_potential_score(self, score: float): 323 | self.potential_score = score 324 | 325 | def find_children(self, rollout_id: int): 326 | self.children = self.children or self._create_children() 327 | for child in self.children: 328 | child.set_rollout_id(rollout_id) 329 | assert self.children 330 | return self.children 331 | 332 | def is_terminal(self): 333 | return self.depth >= self.max_depth_allowed or self.is_valid_leaf_node() 334 | 335 | def calculate_reward(self): 336 | if self.is_valid_leaf_node(): 337 | assert self.node_value is not None, breakpoint() 338 | return self.node_value 339 | else: 340 | return 0 341 | 342 | 343 | def search_for_answers(args, user_question: str, question_id: int, difficulty: str, generator: Generator, test_case: dict): 344 | verbose_print( 345 | f"********************* Searching for answers to question {question_id} ********************* ", args.verbose 346 | ) 347 | 348 | #! build an MCTS searcher 349 | mcts_searcher = MCTS_Searcher( 350 | exploration_weight=args.mcts_exploration_weight, 351 | weight_scheduler=args.mcts_weight_scheduler, 352 | num_rollouts=args.num_rollouts, 353 | discount=args.mcts_discount_factor, 354 | verbose=args.verbose, 355 | ) 356 | 357 | #! 
build the MCTS tree 358 | root_node = Reasoning_MCTS_Node( 359 | parent=None, 360 | depth=0, 361 | node_type=Node_Type.USER_QUESTION, 362 | verbose=args.verbose, 363 | generator=generator, 364 | user_question=user_question, 365 | difficulty=difficulty, 366 | max_depth_allowed=args.max_depth_allowed, 367 | enable_potential_score=args.enable_potential_score, 368 | test_case=test_case, 369 | ) 370 | 371 | model_solutions = [] 372 | model_all_solutions = [] 373 | model_rollout_nodes = [] 374 | for i in (pbar := trange(args.num_rollouts, disable=True, position=0)): 375 | rollout_node = mcts_searcher.do_rollout(root_node, i) 376 | model_rollout_nodes.append(rollout_node) 377 | jss = {"trace": rollout_node.solution_trace, "rollout_id": rollout_node.rollout_id, "value": rollout_node.node_value} 378 | 379 | with open(os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - rollout Solutions.json"), "a") as f: 380 | json.dump(jss, f) 381 | f.write(',') 382 | 383 | # print_tree_from_root(mcts_searcher, args.num_rollouts - 1, root_node) 384 | 385 | ost_best_node, ost_all_solution_nodes, TREE = ost_find_best_solution(root_node, generator.evaluator) 386 | 387 | complete_road = [] 388 | 389 | for solution_node in ost_all_solution_nodes: 390 | complete_road_json = find_solution(root_node, solution_node, mcts_searcher) 391 | complete_road.append(complete_road_json) 392 | 393 | 394 | bestv = -1 395 | ost_best_node = None 396 | for rollout_node in model_rollout_nodes: 397 | if rollout_node.node_value is not None: 398 | if rollout_node.node_value > bestv: 399 | bestv = rollout_node.node_value 400 | ost_best_node = rollout_node 401 | 402 | 403 | with open(os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - Complete Solutions.json"), "w", encoding="utf-8") as f: 404 | json.dump(complete_road, f, ensure_ascii=False, indent=4) 405 | #! record final traces 406 | js = [{"trace": node.solution_trace, "rollout_id": node.rollout_id, "parent_id": node.parent.id, "value": node.node_value} for node in ost_all_solution_nodes] 407 | with open(os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - Final Solutions.json"), "w") as f: 408 | json.dump(js, f) 409 | 410 | js2 = [{"trace": node.solution_trace, "rollout_id": i, "value": node.node_value} for i, node in enumerate(model_rollout_nodes)] 411 | with open(os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - Rollout Solutions.json"), "w") as f: 412 | json.dump(js2, f) 413 | 414 | if ost_best_node is not None: 415 | js3 = {"trace": ost_best_node.solution_trace, "rollout_id": ost_best_node.rollout_id, "value": ost_best_node.node_value} 416 | 417 | with open(os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - Best Solutions.json"), "w") as f: 418 | json.dump(js3, f) 419 | 420 | 421 | 422 | if args.enable_potential_score: 423 | js = [node.potential_answers_history for node in ost_all_solution_nodes] 424 | with open(os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - Potentials.json"), "w") as f: 425 | json.dump(js, f) 426 | 427 | 428 | return model_solutions, i, model_all_solutions 429 | -------------------------------------------------------------------------------- /src/mcts/run_src/do_generate.py: -------------------------------------------------------------------------------- 1 | # Licensed under the MIT license. 
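# Entry point for MCTS-based reasoning-data generation: loads the dataset
# JSON, builds the dataset evaluator and the chosen backend (HuggingFace,
# vLLM, or OpenAI), runs `search_for_answers` for each problem, and records
# per-question timing plus aggregate call/token statistics under the
# run-output directories.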
2 | 3 | import sys 4 | 5 | 6 | print("args: ", sys.argv) 7 | 8 | import os, json, time 9 | from tqdm import tqdm 10 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 11 | 12 | sys.path.append(".") 13 | 14 | from common.utils import fix_seeds, setup_model_parallel, read_json 15 | from common.arguments import get_parser, post_process_args, save_args 16 | from MCTS_for_reasoning import Generator, search_for_answers 17 | from eval_src.Evaluator import * 18 | 19 | 20 | def main(args): 21 | fix_seeds(args.seed) 22 | if args.model_parallel: 23 | args.local_rank, args.world_size = setup_model_parallel() 24 | else: 25 | args.local_rank, args.world_size = 0, 1 26 | 27 | test_file = os.path.join(args.data_root, args.dataset_name, args.test_json_filename + ".json") 28 | assert os.path.exists(test_file), f"Test file {test_file} does not exist." 29 | data_item_list = read_json(test_file) 30 | 31 | evaluator = eval(f"{args.dataset_name}Evaluator()") 32 | 33 | tokenizer, model = None, None 34 | if args.api == "huggingface": 35 | from models.HuggingFace_API import load_HF_model 36 | 37 | tokenizer, model = load_HF_model(args.model_ckpt) 38 | elif args.api == "vllm": 39 | from models.vLLM_API import load_vLLM_model 40 | 41 | tokenizer, model = load_vLLM_model(args.model_ckpt, args.seed, args.tensor_parallel_size, args.half_precision) 42 | elif args.api == "OpenAI": 43 | from models.OpenAI_API import load_OpenAI_model 44 | 45 | tokenizer, model = load_OpenAI_model(args.model_ckpt) 46 | generator = Generator(args, tokenizer, model, evaluator) 47 | 48 | num_tested = 0 49 | start_time = time.time() 50 | 51 | for i, data_item in enumerate( 52 | (pbar := tqdm(data_item_list, disable=args.local_rank > 0 or args.verbose, position=1)) 53 | ): 54 | if i < args.start_idx or i >= args.end_idx: 55 | continue 56 | st_time = time.time() 57 | problem_id, problem, test_case, difficulty = i, data_item["question"],data_item["input_output"], data_item["difficulty"] 58 | gt_solution = data_item["solutions"][0] if len(data_item["solutions"]) > 0 else None 59 | 60 | js = { 61 | "id": problem_id, 62 | "problem": problem, 63 | "gold_solution": gt_solution, 64 | "test_case": test_case, 65 | "difficulty": difficulty, 66 | } 67 | 68 | 69 | model_solutions, stopping_id, model_all_solutions = search_for_answers( 70 | args=args, user_question=problem, question_id=i, difficulty=difficulty, generator=generator, test_case = test_case 71 | ) 72 | 73 | num_tested += 1 74 | 75 | 76 | 77 | with open(os.path.join(args.run_outputs_dir, "intermediate_result.txt"), "w") as f: 78 | f.write( 79 | f"Total calls: {generator.io.call_counter}, Avg calls: {generator.io.call_counter/(num_tested):.2f}\n" 80 | ) 81 | f.write( 82 | f"Total tokens: {generator.io.token_counter}, Avg tokens: {generator.io.token_counter/(num_tested):.2f}\n" 83 | ) 84 | ed_time = time.time() 85 | js["time_taken"] = f"{ed_time-st_time:.2f}s" 86 | 87 | with open(os.path.join(args.answer_sheets_dir, f"Question {i:04d} - Answer.json"), "w") as f: 88 | json.dump(js, f) 89 | print(f"==> Time taken for this question: {ed_time-st_time:.2f}s") 90 | 91 | end_time = time.time() 92 | 93 | print(f"==> Total calls: {generator.io.call_counter}, Avg calls: {generator.io.call_counter/(num_tested):.2f}") 94 | print(f"==> Total tokens: {generator.io.token_counter}, Avg tokens: {generator.io.token_counter/(num_tested):.2f}") 95 | print(f"==> Total time: {end_time-start_time:.2f}s, Avg time: {(end_time-start_time)/(num_tested):.2f}s") 96 | 97 | with open(os.path.join(args.run_outputs_dir, 
"final_result.txt"), "w") as f: 98 | f.write(f"Total calls: {generator.io.call_counter}, Avg calls: {generator.io.call_counter/(num_tested):.2f}\n") 99 | f.write( 100 | f"Total tokens: {generator.io.token_counter}, Avg tokens: {generator.io.token_counter/(num_tested):.2f}\n" 101 | ) 102 | f.write(f"Total time: {end_time-start_time:.2f}s, Avg time: {(end_time-start_time)/(num_tested):.2f}s\n") 103 | 104 | 105 | if __name__ == "__main__": 106 | #! -------------------------------- Arguments -------------------------------- 107 | parser = get_parser() 108 | 109 | parser.add_argument("--num_rollouts", type=int, default=15) 110 | parser.add_argument("--max_depth_allowed", type=int, default=5) 111 | 112 | # MCTS 113 | parser.add_argument("--mcts_discount_factor", type=float, default=1.0) 114 | parser.add_argument("--mcts_exploration_weight", type=float, default=2.0) 115 | parser.add_argument("--mcts_weight_scheduler", choices=["exp", "lin", "const"], default="const") 116 | parser.add_argument("--mcts_num_last_votes", type=int, default=None) 117 | parser.add_argument("--save_tree", action="store_true") 118 | 119 | # Action1: Propose an one-step thought. 120 | parser.add_argument("--num_sampling", type=int, default=3) 121 | 122 | 123 | #! -------------------------- Used for selecting answer -------------------------- 124 | parser.add_argument("--enable_potential_score", action="store_true") 125 | 126 | #! ------------------------------------------------------------------------------- 127 | 128 | args = parser.parse_args() 129 | 130 | if args.mcts_num_last_votes is None: 131 | args.mcts_num_last_votes = 32 132 | 133 | #! ---------------------------------------------------------------------------- 134 | 135 | prompts_dir = os.path.join(args.prompts_root, args.dataset_name) 136 | 137 | 138 | args.examples_txt_path = os.path.join(prompts_dir, "examples.txt") 139 | args.prompt_config_path = os.path.join(prompts_dir, "prompt.json") 140 | 141 | 142 | args = post_process_args(args) 143 | print(args) 144 | save_args(args) 145 | main(args) 146 | -------------------------------------------------------------------------------- /src/mcts/run_src/rstar_utils.py: -------------------------------------------------------------------------------- 1 | # Licensed under the MIT license. 
2 | 3 | from enum import Enum, unique 4 | import re 5 | import math 6 | from typing import Dict, Tuple 7 | from colorama import Fore, Style 8 | 9 | 10 | 11 | @unique 12 | class Node_Type(Enum): 13 | USER_QUESTION = "USER_QUESTION" 14 | ONE_STEP = "ONE_STEP" 15 | 16 | import time 17 | 18 | def time_decorator(func): 19 | def wrapper(*args, **kwargs): 20 | start_time = time.time() 21 | result = func(*args, **kwargs) 22 | end_time = time.time() 23 | print(f"{func.__name__} runs: {end_time - start_time:.6f} s") 24 | return result 25 | return wrapper 26 | 27 | 28 | def reach_terminal_ost_step(ost_step: str): 29 | assert ost_step is not None 30 | last_step = ost_step.lower() 31 | 32 | code_indicators = [ 33 | # "", 34 | "```python" 35 | ] 36 | 37 | return any(indicator in last_step for indicator in code_indicators) 38 | 39 | 40 | def print_tree_from_root(mcts_searcher, rollout_id, root_node, chosen_node=None, file=None): 41 | color_print = False if file else True 42 | 43 | def my_print(text): 44 | if file: 45 | file.write(text + "\n") 46 | else: 47 | print(text) 48 | 49 | def print_tree(parent_node, node, file, rollout_id): 50 | to_print = "" 51 | 52 | num_indent = 4 53 | dash = "-" * num_indent * node.depth 54 | space = " " * num_indent * node.depth 55 | 56 | attributes = f"Q: {round(mcts_searcher.Q[node], 2)}" + "; " + f"N: {mcts_searcher.N[node]}" + "; " 57 | attributes += f"V: {round(node.node_value, 2)}" if node.node_value is not None else "V: None" 58 | 59 | uct_value = "UCT: " + str( 60 | round(mcts_searcher._compute_uct(parent_node=parent_node, node=node, rollout_id=rollout_id), 2) 61 | ) 62 | attributes += "; " + uct_value 63 | 64 | solution_marker = "(T) " if node.is_valid_solution_node() else "" 65 | 66 | node_info = "[" + solution_marker + node.__str__() + ": " + attributes + "]" 67 | if chosen_node and node == chosen_node: 68 | node_info = "[" + node_info + "]" 69 | node_info += " " 70 | 71 | if color_print and node.is_valid_solution_node(): 72 | node_details = Fore.RED + Style.BRIGHT + node_info + Fore.RESET + Style.RESET_ALL 73 | else: 74 | node_details = node_info 75 | 76 | if node.node_type is Node_Type.USER_QUESTION: 77 | node_details += f"User: {node.user_question}" + "\n" + space + " " * len(node_info) 78 | elif node.node_type is Node_Type.ONE_STEP: 79 | node_details += f"OST: {node.ost_step}" 80 | 81 | to_print += dash + node_details 82 | 83 | my_print(to_print) 84 | 85 | for child in node.children: 86 | print_tree(node, child, file, rollout_id) 87 | 88 | if node.depth == 0: 89 | my_print("\n" + "=" * 50 + "\n") 90 | 91 | print_tree(parent_node=None, node=root_node, file=file, rollout_id=rollout_id) 92 | 93 | 94 | def concat_ost_steps(solution_trace: Dict[int, Dict[str, str]]) -> Tuple[str, int]: 95 | """Return: concatenated one-step thought steps, next one-step thought step id""" 96 | last_tuple = list(solution_trace.items())[-1] 97 | last_tuple_id, last_tuple_recording = last_tuple[0], last_tuple[1] 98 | assert "ost_step" in last_tuple_recording.keys() 99 | if len(last_tuple_recording["ost_step"]) > 0: 100 | solution_trace_str = "" 101 | for step_id, step_text in last_tuple_recording["ost_step"].items(): 102 | solution_trace_str += f"\n### Step {step_id}: " + step_text + "\n\n" 103 | return solution_trace_str, step_id + 1 104 | else: 105 | # no one-step thought step yet 106 | return "", 1 107 | 108 | 109 | def find_valid_solution_nodes(root_node): 110 | valid_solution_nodes = [] 111 | TREE = {} 112 | 113 | def recursion(node): 114 | if node.depth in TREE:
115 | TREE[node.depth].append(node) 116 | else: 117 | TREE[node.depth] = [node] 118 | 119 | if node.is_valid_solution_node(): 120 | valid_solution_nodes.append(node) 121 | return 122 | 123 | if not node.children: #! no children 124 | return 125 | 126 | for child in node.children: 127 | recursion(child) 128 | 129 | recursion(root_node) 130 | 131 | return valid_solution_nodes, TREE 132 | 133 | 134 | def find_best_solution(root_node, evaluator, enable_potential_score=False): 135 | # todo: what strategy do we use to select best node? 136 | """The function finds the best solution from the solution nodes in the MCTS tree. 137 | Return: top answer, top solution, confidence of the top answer, the corresponding node of the answer, all solution nodes 138 | """ 139 | solution_nodes, _ = find_valid_solution_nodes(root_node) # the depth map is not needed here 140 | 141 | if len(solution_nodes) == 0: 142 | return None, None 143 | 144 | def extract_solution_from_node(node): # legacy rStar helper; SUBQUESTION/DIRECT_ANSWER are not in this file's Node_Type 145 | if node.node_type is Node_Type.SUBQUESTION: 146 | return node.subanswer 147 | elif node.node_type is Node_Type.DIRECT_ANSWER: 148 | return node.direct_answer 149 | else: 150 | return None 151 | 152 | solutions = [extract_solution_from_node(node) for node in solution_nodes] 153 | 154 | def calculate_potential_score_for_solution_node(node): 155 | model_answer = evaluator.extract_answer_from_model_completion(extract_solution_from_node(node)) 156 | potential_answers_history = node.potential_answers_history # {depth -> [potential answers]} 157 | assert potential_answers_history[node.depth] is None 158 | 159 | potential_score = 1 160 | for depth, depth_potential_answers in potential_answers_history.items(): 161 | if depth < node.depth: 162 | depth_score = sum( 163 | evaluator.check_answers_equiv(dpa, model_answer) for dpa in depth_potential_answers 164 | ) / len(depth_potential_answers) 165 | potential_score *= depth_score 166 | 167 | node.set_potential_score(potential_score) 168 | return potential_score 169 | 170 | prior_weights = ( 171 | [calculate_potential_score_for_solution_node(node) for node in solution_nodes] 172 | if enable_potential_score 173 | else None 174 | ) 175 | top_answer, top_completion, top_completion_id, top_confidence = evaluator.find_most_confident_answer( 176 | solutions, prior_weights 177 | ) 178 | return top_answer, top_completion, top_confidence, solution_nodes[top_completion_id], solution_nodes 179 | 180 | 181 | def ost_find_best_solution( 182 | root_node, 183 | evaluator, 184 | ): 185 | solution_nodes, TREE = find_valid_solution_nodes(root_node) 186 | 187 | bestv = -1 188 | best_node = None 189 | for solution_node in solution_nodes: 190 | if solution_node.node_value > bestv: 191 | bestv = solution_node.node_value 192 | best_node = solution_node 193 | 194 | return best_node, solution_nodes, TREE 195 | 196 | def find_solution(root_node, solution_node, mcts_searcher): 197 | """ 198 | Recursively traces back from the given solution node to the root node, 199 | calculating the value of each node along the path. 200 | 201 | Parameters: 202 | solution_node (Node): The current solution node to start the backtrace from. 203 | mcts_searcher (MCTS): The MCTS searcher object used to access node visit counts and values. 204 | 205 | Returns: 206 | dict: A dictionary representing the complete solution, containing the node id, 207 | OST step, step value, and edge information for each node in the path.
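Example of the returned mapping (illustrative values only, matching the fields populated below): { 0: {"node_id": 0, "question": "<user question>"}, 1: {"node_id": 3, "ost_step": "### Step 1: ...", "step_value": 0.5, "edges": (0, 3)} }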
208 | """ 209 | comlete_solution = {} 210 | 211 | def reback(node): 212 | """ 213 | Recursively backtracks from the current node to the root node, 214 | calculating the value for each node and updating the solution. 215 | 216 | Parameters: 217 | node (Node): The current node being processed in the backtrack. 218 | """ 219 | if node.node_value is not None and mcts_searcher.N[node] != 0: 220 | value = node.node_value / mcts_searcher.N[node] 221 | else: 222 | value = 0 223 | if node.node_type is Node_Type.ONE_STEP: 224 | comlete_solution[node.depth] = { 225 | "node_id": node.id, 226 | "ost_step": node.ost_step, 227 | "step_value": value, 228 | "edges": (node.parent.id, node.id) # source_node_id -> target_node_id 229 | } 230 | else: 231 | comlete_solution[node.depth] = { 232 | "node_id": node.id, 233 | "question": root_node.user_question, 234 | } 235 | if node.node_type is Node_Type.USER_QUESTION: 236 | return 237 | 238 | reback(node.parent) 239 | 240 | reback(solution_node) 241 | return comlete_solution 242 | 243 | -------------------------------------------------------------------------------- /src/mcts/scripts/api_run_TACO.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python run_src/do_generate.py \ 2 | --dataset_name TACO \ 3 | --test_json_filename test_one \ 4 | --api OpenAI \ 5 | --model_ckpt gpt-4o-mini \ 6 | --note default \ 7 | --num_rollouts 12 \ 8 | --verbose \ 9 | --max_depth_allowed 10 -------------------------------------------------------------------------------- /src/mcts/scripts/run_TACO.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=3 python run_src/do_generate.py \ 2 | --dataset_name TACO \ 3 | --test_json_filename test_one \ 4 | --model_ckpt \ 5 | --note default \ 6 | --num_rollouts 3 \ 7 | --verbose \ 8 | --max_depth_allowed 8 -------------------------------------------------------------------------------- /src/mcts/scripts/start.md: -------------------------------------------------------------------------------- 1 | bash scripts/api_run_TACO.sh 2 | bash scripts/run_TACO.sh -------------------------------------------------------------------------------- /src/prm_training/README.md: -------------------------------------------------------------------------------- 1 | ## How to use 2 | To initialize the environment, you need to install the required packages. You can do this by running the following command: 3 | 4 | ```bash 5 | pip install -r requirements.txt 6 | ``` 7 | 8 | This will install all the necessary dependencies listed in the `requirements.txt` file. 9 | 10 | ### Data 11 | We provide data examples for PRM training in the `data` folder. The reward labels are available in both hard and soft estimation forms. 12 | 13 | For the hard estimation labels, you can refer to `data/examples/hard_label_examples.json` for processing, while the corresponding soft label forms are provided in `data/examples/soft_label_examples.json`. 
14 | 15 | ### Train 16 | #### Basic Usage 17 | Make sure your working directory is the `prm_training` folder, then run the following script: 18 | ```bash 19 | bash run.sh 20 | ``` 21 | #### Main Arguments 22 | | Argument | Type | Description | 23 | |---------------|--------|-----------------------------------| 24 | | `--config_file` | str | Path to the accelerate config file | 25 | | `--model_name_or_path` | str | Path to your model | 26 | | `--data_path` | str | Path to the training data | 27 | | `--use_soft_label` | bool | Whether to use soft labels during training (default: false) | -------------------------------------------------------------------------------- /src/prm_training/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.1.1 2 | datasets==3.1.0 3 | numpy==2.1.3 4 | peft==0.13.2 5 | scikit_learn==1.5.2 6 | torch==2.4.1 7 | transformers==4.46.2 8 | deepspeed==0.15.3 9 | wandb -------------------------------------------------------------------------------- /src/prm_training/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from dataclasses import dataclass, field 4 | from typing import Optional 5 | from transformers import Seq2SeqTrainingArguments, HfArgumentParser 6 | from accelerate import PartialState 7 | 8 | from train_prm.run_train import run_exp 9 | 10 | DIST_STATE = PartialState() 11 | 12 | @DIST_STATE.on_local_main_process 13 | def print_rank_0(msg): 14 | print(msg) 15 | 16 | @dataclass 17 | class ModelArguments: 18 | """ 19 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune. 20 | """ 21 | model_name_or_path: str = field( 22 | metadata={"help": "Path to pretrained model", "required": True} 23 | ) 24 | 25 | @dataclass 26 | class DataTrainingArguments: 27 | """ 28 | Arguments pertaining to what data we are going to input our model for training and eval. 
29 | """ 30 | data_path: str = field( 31 | metadata={"help": "Path to dataset", "required": True} 32 | ) 33 | use_soft_label: bool = field( 34 | default=False, 35 | metadata={"help": "Whether to use soft labels for prm training"} 36 | ) 37 | server: str = field( 38 | default="1", 39 | metadata={"help": "Server configuration"} 40 | ) 41 | 42 | def main(): 43 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 44 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 45 | print_rank_0('*'*30+f'\nModel arguments:\n{model_args}\nData arguments:\n{data_args}\nTraining arguments:\n{training_args}\n'+'*'*30) 46 | run_exp(model_args, data_args, training_args) 47 | 48 | if __name__ == '__main__': 49 | main() -------------------------------------------------------------------------------- /src/prm_training/run.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file "./train_prm/utils/dist_configs/multi_gpu.yaml" \ 2 | ./run.py \ 3 | --model_name_or_path "your model path" \ 4 | --data_path "./data/examples/soft_label_examples.json" \ 5 | --use_soft_label \ 6 | --output_dir "./train_prm/outputs" \ 7 | --overwrite_output_dir \ 8 | --per_device_train_batch_size 1 \ 9 | --per_device_eval_batch_size 4 \ 10 | --gradient_accumulation_steps 4 \ 11 | --num_train_epochs 3 \ 12 | --learning_rate 1e-4 \ 13 | --lr_scheduler_type "cosine" \ 14 | --save_strategy "steps" \ 15 | --eval_strategy "steps" \ 16 | --save_steps 200 \ 17 | --eval_steps 100 \ 18 | --save_total_limit 2 \ 19 | --weight_decay 0.01 \ 20 | --logging_steps 10 \ 21 | --log_level "info" \ 22 | --bf16 \ 23 | --report_to "wandb" \ 24 | --ddp_find_unused_parameters False \ -------------------------------------------------------------------------------- /src/prm_training/train_prm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ADaM-BJTU/O1-CODER/f01c769397afaefc89c22c51d048484a79d11c1f/src/prm_training/train_prm/__init__.py -------------------------------------------------------------------------------- /src/prm_training/train_prm/run_train.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import argparse 4 | import torch.nn as nn 5 | import os 6 | import random 7 | import numpy as np 8 | from dataclasses import dataclass 9 | from typing import Tuple, List, Dict, Union, Optional, Any 10 | from peft import PeftModel 11 | from peft import get_peft_model, LoraConfig, TaskType 12 | from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, Seq2SeqTrainingArguments 13 | from transformers import DataCollatorWithPadding, DataCollatorForSeq2Seq 14 | from transformers.modeling_outputs import CausalLMOutputWithPast 15 | from sklearn.metrics import roc_auc_score, log_loss, accuracy_score 16 | from torch.nn import BCEWithLogitsLoss 17 | from datasets import load_dataset 18 | from accelerate import PartialState 19 | 20 | DIST_STATE = PartialState() 21 | 22 | @DIST_STATE.on_local_main_process 23 | def print_rank_0(msg): 24 | print(msg) 25 | 26 | def print_rank(msg: str): 27 | local_rank = int(os.environ.get('LOCAL_RANK', -1)) 28 | print(f'[LOCAL_RANK {local_rank}]:\n{msg}') 29 | 30 | 31 | def setup_model_and_tokenizer(model_path: str) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: 32 | tokenizer = AutoTokenizer.from_pretrained( 33 | model_path, 34 | add_eos_token=False, 35 | ) 36 | 
tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token 37 | 38 | model = AutoModelForCausalLM.from_pretrained( 39 | model_path, 40 | torch_dtype=torch.bfloat16, 41 | attn_implementation="flash_attention_2", 42 | ).to('cuda') 43 | 44 | lora_config = LoraConfig( 45 | task_type=TaskType.CAUSAL_LM, # LoRA for causal language modeling task 46 | r=8, # Rank of LoRA 47 | lora_alpha=32, # Alpha scaling factor for LoRA 48 | lora_dropout=0.1, # Dropout rate for LoRA layers 49 | target_modules=["q_proj", "v_proj"], # Apply LoRA to specific layers 50 | ) 51 | 52 | model = get_peft_model(model, lora_config) 53 | if DIST_STATE.is_local_main_process: 54 | model.print_trainable_parameters() 55 | return model, tokenizer 56 | 57 | 58 | # Adapted, with corresponding modifications, from `openr/prm/code/finetune_qwen.py` 59 | class DatasetProcessor: 60 | def __init__(self, args, tokenizer: AutoTokenizer): 61 | self.args = args 62 | self.tokenizer = tokenizer 63 | self.good_token = '+' 64 | self.bad_token = '-' 65 | self.step_tag = ' Rating' 66 | self.candidate_token_ids = self.tokenizer.encode(f" {self.good_token} {self.bad_token}") 67 | self.step_tag_id = self.tokenizer.encode(f"{self.step_tag}")[-1] 68 | 69 | def print_example(self, example): 70 | print_rank_0('*' * 20 + ' Example View ' + '*' * 20) 71 | print_rank_0('Tokenized Data:\n' + '=' * 30 + '\n' 72 | f'Input_ids: {example["input_ids"]}\nAttention_mask: {example["attention_mask"]}\nLabels: {example["labels"]}') 73 | ignore_index = -100 if not self.args.use_soft_label else -100.0 74 | valid_labels = [label for label in example["labels"] if label != ignore_index] 75 | if not self.args.use_soft_label: 76 | print_rank_0('Decoded Data:\n' + '=' * 30 + '\n' 77 | f'Input: {self.tokenizer.decode(example["input_ids"])}\nLabels: {self.tokenizer.decode(valid_labels)}') 78 | else: 79 | print_rank_0('Decoded Data:\n' + '=' * 30 + '\n' 80 | f'Input: {self.tokenizer.decode(example["input_ids"])}\nLabels: {valid_labels}') 81 | 82 | def preprocess_example(self, example): 83 | messages = [ 84 | {"role": "user", "content": example["prompt"]}, 85 | {"role": "assistant", "content": example["response"]}, 86 | ] 87 | end_token_of_one_turn = '<|im_end|>\n' 88 | input_w_template = self.tokenizer.apply_chat_template(messages, tokenize=False) 89 | input_w_template = input_w_template.removesuffix(end_token_of_one_turn) 90 | input_w_template += self.step_tag 91 | 92 | tokenized_inputs = self.tokenizer(input_w_template, padding=True) 93 | 94 | indices = [i for i, x in enumerate(tokenized_inputs['input_ids']) if x == self.step_tag_id] 95 | 96 | if len(indices) != len(example['label']): 97 | example['label'] = example['label'][:len(indices)] 98 | 99 | assert len(indices) == len(example['label']) 100 | 101 | length = len(tokenized_inputs['input_ids']) 102 | 103 | if not self.args.use_soft_label: 104 | tokenized_inputs['labels'] = [-100] * length 105 | for i, idx in enumerate(indices): 106 | if example['label'][i] in ['positive', 1]: 107 | tokenized_inputs['labels'][idx] = self.candidate_token_ids[0] 108 | elif example['label'][i] in ['negative', 0]: 109 | tokenized_inputs['labels'][idx] = self.candidate_token_ids[1] 110 | else: 111 | raise ValueError('Invalid label value') 112 | tokenized_inputs['attention_mask'][idx] = 0 113 | else: # use soft labels 114 | tokenized_inputs['labels'] = [-100.0] * length 115 | for i, idx in enumerate(indices): 116 | tokenized_inputs['labels'][idx] = example['label'][i] 117 | tokenized_inputs['attention_mask'][idx] = 0 118 | 119 | 
assert len(tokenized_inputs["input_ids"]) == len(tokenized_inputs["labels"]) == len(tokenized_inputs["attention_mask"]) 120 | return tokenized_inputs 121 | 122 | def prepare_datasets(self, training_args: Seq2SeqTrainingArguments, test_size=0.2, seed=42): 123 | dataset = load_dataset('json', data_files=self.args.data_path, split='train') 124 | dataset = dataset.filter(lambda x: x["prompt"]) 125 | 126 | splits = dataset.train_test_split( 127 | test_size=test_size, 128 | seed=seed, 129 | shuffle=True 130 | ) 131 | 132 | with training_args.main_process_first(desc="Tokenizing datasets"): 133 | tokenized_datasets = { 134 | split: splits[split].map( 135 | self.preprocess_example, 136 | remove_columns=splits[split].column_names, 137 | ) 138 | for split in splits 139 | } 140 | 141 | print_rank_0(f"Training set size: {len(tokenized_datasets['train'])}") 142 | print_rank_0(f"Test set size: {len(tokenized_datasets['test'])}") 143 | 144 | ridx = random.randint(0, len(tokenized_datasets["train"]) - 1) 145 | self.print_example(tokenized_datasets["train"][ridx]) 146 | 147 | return tokenized_datasets 148 | 149 | # Define a custom metric function (e.g., accuracy for binary classification) 150 | def preprocess_logits_for_metrics(self, logits, labels): 151 | 152 | labels_index = torch.argwhere(torch.bitwise_or( 153 | labels == self.candidate_token_ids[0], 154 | labels == self.candidate_token_ids[1] 155 | )) 156 | gold = torch.where( 157 | labels[labels_index[:, 0], labels_index[:, 1]] == self.candidate_token_ids[1], 158 | 0, 1 159 | ) 160 | labels_index[:, 1] -= 1 161 | logits = logits[labels_index[:, 0], labels_index[:, 1]][:, [ 162 | self.candidate_token_ids[1], 163 | self.candidate_token_ids[0] 164 | ]] 165 | prob = torch.softmax(logits, dim=-1) 166 | return prob[:, 1], gold 167 | 168 | def preprocess_logits_for_soft_label_metrics(self, logits, labels): 169 | 170 | labels_index = labels.ne(-100.0).nonzero() 171 | positive_labels = labels[labels_index[:, 0], labels_index[:, 1]] 172 | negative_labels = 1 - positive_labels 173 | gold = torch.stack([positive_labels, negative_labels], dim=-1).argmax(dim=-1) 174 | labels_index[:, 1] -= 1 175 | logits = logits[labels_index[:, 0], labels_index[:, 1]][:, self.candidate_token_ids] 176 | return logits, gold 177 | 178 | def compute_metrics(self, eval_pred): 179 | 180 | if not self.args.use_soft_label: 181 | pre, labels = eval_pred 182 | auc = roc_auc_score(pre[1], pre[0]) 183 | ll = log_loss(pre[1], pre[0]) 184 | acc = accuracy_score(pre[1], pre[0] > 0.5) 185 | result = { 186 | 'auc': auc, 187 | 'll': ll, 188 | 'acc': acc, 189 | } 190 | else: 191 | predictions, labels = eval_pred 192 | acc = accuracy_score(predictions[0].argmax(axis=-1), predictions[1]) 193 | result = { 194 | 'acc': acc, 195 | } 196 | print_rank_0(result) 197 | return result 198 | 199 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 200 | from transformers.utils import PaddingStrategy 201 | from transformers.data.data_collator import pad_without_fast_tokenizer_warning 202 | 203 | @dataclass 204 | class DataCollatorForSeq2SeqWithSoftLabels: 205 | """ 206 | Data collator that will dynamically pad the inputs received, as well as the labels. 207 | This version supports soft labels (float values) in the label tensors. 208 | 209 | Args: 210 | tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): 211 | The tokenizer used for encoding the data. 212 | model ([`PreTrainedModel`], *optional*): 213 | The model that is being trained. 
If set and has the *prepare_decoder_input_ids_from_labels*, use it to 214 | prepare the *decoder_input_ids* 215 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): 216 | Select a strategy to pad the returned sequences 217 | max_length (`int`, *optional*): 218 | Maximum length of the returned list and optionally padding length. 219 | pad_to_multiple_of (`int`, *optional*): 220 | If set will pad the sequence to a multiple of the provided value. 221 | label_pad_token_id (`float`, *optional*, defaults to -100.0): 222 | The value to use when padding the labels; kept as a float (-100.0 by default) so that soft-label tensors remain floating point. 223 | return_tensors (`str`, *optional*, defaults to `"pt"`): 224 | The type of Tensor to return. Allowable values are "np", "pt" and "tf". 225 | """ 226 | 227 | tokenizer: PreTrainedTokenizerBase 228 | model: Optional[Any] = None 229 | padding: Union[bool, str, PaddingStrategy] = True 230 | max_length: Optional[int] = None 231 | pad_to_multiple_of: Optional[int] = None 232 | label_pad_token_id: float = -100.0 # Changed to float 233 | return_tensors: str = "pt" 234 | 235 | def __call__(self, features, return_tensors=None): 236 | if return_tensors is None: 237 | return_tensors = self.return_tensors 238 | 239 | label_name = "label" if "label" in features[0].keys() else "labels" 240 | labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None 241 | 242 | # Handle None labels 243 | if labels is not None and all(label is None for label in labels): 244 | labels = None 245 | 246 | non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features] 247 | 248 | # Process inputs without labels 249 | batch = pad_without_fast_tokenizer_warning( 250 | self.tokenizer, 251 | non_labels_features, 252 | padding=self.padding, 253 | max_length=self.max_length, 254 | pad_to_multiple_of=self.pad_to_multiple_of, 255 | return_tensors=return_tensors, 256 | ) 257 | 258 | # Process labels if they exist 259 | no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD 260 | if labels is not None: 261 | if no_padding: 262 | if isinstance(features[0][label_name], list): 263 | batch["labels"] = list(labels) 264 | else: 265 | batch["labels"] = [np.concatenate([label, []]) for label in labels] 266 | else: 267 | max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None 268 | max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length 269 | 270 | if self.pad_to_multiple_of is not None: 271 | max_label_length = ( 272 | (max_label_length + self.pad_to_multiple_of - 1) 273 | // self.pad_to_multiple_of 274 | * self.pad_to_multiple_of 275 | ) 276 | 277 | padding_side = self.tokenizer.padding_side 278 | 279 | if isinstance(features[0][label_name], list): 280 | for idx, label in enumerate(labels[0]): 281 | if type(label) != type(self.label_pad_token_id): 282 | raise ValueError( 283 | f'The {idx}-th label is of type {type(label)} while the label_pad_token_id is of type {type(self.label_pad_token_id)}, ' 284 | 'you should make sure that they are of the same type' 285 | ) 286 | 287 | batch["labels"] = [ 288 | label + [self.label_pad_token_id] * (max_label_length - len(label)) 289 | if padding_side == "right" 290 | else [self.label_pad_token_id] * (max_label_length - len(label)) + label 291 | for label in labels 292 | ] 293 | else: 294 | # Convert to float16 for soft labels 295 | batch["labels"] = [ 296 | np.concatenate( 297 | [ 298 | label.astype(np.float16), 299
| np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.float16), 300 | ] 301 | ) 302 | if padding_side == "right" 303 | else np.concatenate( 304 | [ 305 | np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.float16), 306 | label.astype(np.float16), 307 | ] 308 | ) 309 | for label in labels 310 | ] 311 | 312 | # Convert to appropriate tensor type 313 | if batch.get("labels", None) is not None: 314 | if return_tensors == "pt": 315 | import torch 316 | batch["labels"] = torch.tensor(batch["labels"], dtype=torch.float16) # Changed to float16 317 | else: 318 | raise NotImplementedError(f"return_tensors='{return_tensors}' not supported yet.") 319 | else: 320 | batch["labels"] = None 321 | 322 | if ( 323 | labels is not None 324 | and self.model is not None 325 | and hasattr(self.model, "prepare_decoder_input_ids_from_labels") 326 | ): 327 | decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"]) 328 | batch["decoder_input_ids"] = decoder_input_ids 329 | 330 | return batch 331 | 332 | 333 | class PRMTrainerForTokenPrediction(Trainer): 334 | def __init__(self, prm_use_tokens_cfg: Dict[str, Union[int, List[int]]], **kwargs): 335 | super().__init__(**kwargs) 336 | self.prm_use_tokens_cfg = prm_use_tokens_cfg 337 | self.loss_func = nn.functional.cross_entropy 338 | # self.model_accepts_loss_kwargs = False 339 | 340 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 341 | labels = inputs.pop("labels") 342 | if self.model_accepts_loss_kwargs: 343 | loss_kwargs = {} 344 | if num_items_in_batch is not None: 345 | loss_kwargs["num_items_in_batch"] = num_items_in_batch 346 | inputs = {**inputs, **loss_kwargs} 347 | 348 | outputs = model(**inputs) 349 | if self.args.past_index >= 0: 350 | self._past = outputs[self.args.past_index] 351 | 352 | placeholder_token_id = self.prm_use_tokens_cfg["placeholder_token_id"] 353 | candidate_token_ids_for_prediction = self.prm_use_tokens_cfg["candidate_token_ids_for_prediction"] 354 | 355 | logits = outputs.logits 356 | # new_labels = torch.zeros_like(logits).to(outputs.logits.dtype) 357 | # positive_labels = labels.to(logits.dtype) 358 | # negative_labels = 1 - positive_labels 359 | # new_labels[..., candidate_token_ids_for_prediction] = torch.stack([negative_labels, positive_labels], dim=-1) 360 | 361 | # reference from https://github.com/OpenRLHF/OpenRLHF/blob/460477d628751bfaa95297af2763f2fd729ecd20/openrlhf/models/loss.py#L259 362 | placeholder_positions = (inputs["input_ids"] == placeholder_token_id).nonzero() 363 | shift_placeholder_positions = placeholder_positions.clone() 364 | shift_placeholder_positions[:, -1] -= 1 365 | logits = logits[shift_placeholder_positions[:, 0], shift_placeholder_positions[:, 1], :] 366 | labels = labels[placeholder_positions[:, 0], placeholder_positions[:, 1]] 367 | if len(candidate_token_ids_for_prediction) != 2: 368 | raise ValueError("The number of candidate tokens for prediction must be 2.") 369 | logits = logits[..., candidate_token_ids_for_prediction] 370 | positive_labels = labels.to(logits.dtype) 371 | negative_labels = 1 - positive_labels 372 | labels = torch.stack([positive_labels, negative_labels], dim=-1) 373 | reduction = 'sum' if num_items_in_batch is not None else 'mean' 374 | loss = self.loss_func(logits, labels, reduction=reduction) 375 | if reduction == 'sum': 376 | loss /= num_items_in_batch 377 | 378 | if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: 379 | 
loss *= self.accelerator.num_processes 380 | 381 | return (loss, outputs) if return_outputs else loss 382 | 383 | 384 | def run_exp(model_args, data_args, training_args): 385 | print_rank_0('loading model and tokenizer...') 386 | model, tokenizer = setup_model_and_tokenizer(model_args.model_name_or_path) 387 | 388 | processor = DatasetProcessor(data_args, tokenizer) 389 | print_rank_0('starting data processing...') 390 | tokenized_datasets = processor.prepare_datasets(training_args) 391 | if not data_args.use_soft_label: 392 | data_collator = DataCollatorForSeq2Seq(tokenizer) 393 | else: 394 | data_collator = DataCollatorForSeq2SeqWithSoftLabels(tokenizer) 395 | 396 | world_size = DIST_STATE.num_processes 397 | per_device_total_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps 398 | TOTAL_BATCH_SIZE = per_device_total_batch_size * world_size 399 | print_rank_0(f"Total batch size: {TOTAL_BATCH_SIZE}") 400 | 401 | fp = f'bs_{TOTAL_BATCH_SIZE}_g_{training_args.gradient_accumulation_steps}_lr_{training_args.learning_rate}_ep_{training_args.num_train_epochs}' 402 | training_args.output_dir = os.path.join(training_args.output_dir, fp) 403 | training_args.logging_dir = os.path.join(training_args.output_dir, 'logs') 404 | 405 | prm_use_tokens_cfg = { 406 | "placeholder_token_id": processor.step_tag_id, 407 | "candidate_token_ids_for_prediction": processor.candidate_token_ids, 408 | } 409 | if data_args.use_soft_label: 410 | trainer = PRMTrainerForTokenPrediction( 411 | prm_use_tokens_cfg=prm_use_tokens_cfg, 412 | model=model, 413 | args=training_args, 414 | train_dataset=tokenized_datasets["train"], 415 | eval_dataset=tokenized_datasets["test"], # Replace with a validation set if available 416 | data_collator=data_collator, 417 | tokenizer=tokenizer, 418 | preprocess_logits_for_metrics=processor.preprocess_logits_for_soft_label_metrics, 419 | compute_metrics=processor.compute_metrics, 420 | ) 421 | else: 422 | trainer = Trainer( 423 | model=model, 424 | args=training_args, 425 | train_dataset=tokenized_datasets["train"], 426 | eval_dataset=tokenized_datasets["test"], # Replace with a validation set if available 427 | data_collator=data_collator, 428 | tokenizer=tokenizer, 429 | preprocess_logits_for_metrics=processor.preprocess_logits_for_metrics, 430 | compute_metrics=processor.compute_metrics, 431 | ) 432 | 433 | trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 434 | 435 | trainer.save_state() 436 | trainer.save_model(output_dir=training_args.output_dir) 437 | 438 | if __name__ == '__main__': 439 | ... 
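As a usage sketch (not part of the repository), the trained PRM can plausibly be queried step-by-step the same way `DatasetProcessor` builds its training inputs: append the ` Rating` step tag, read the logits one position before the tag (a causal LM predicts token t from position t-1, which is why `compute_loss` shifts the placeholder positions), and compare the `+`/`-` candidate tokens. The checkpoint path and the example step below are placeholders:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Placeholder: swap in your fine-tuned PRM checkpoint.
model_path = "path/to/your/prm_checkpoint"
tokenizer = AutoTokenizer.from_pretrained(model_path, add_eos_token=False)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval()

# Mirrors DatasetProcessor; assumes " + -" encodes so that the first/second ids are '+'/'-'.
candidate_ids = tokenizer.encode(" + -")
good_id, bad_id = candidate_ids[0], candidate_ids[1]
step_tag_id = tokenizer.encode(" Rating")[-1]

messages = [
    {"role": "user", "content": "Write a function that reverses a string."},
    {"role": "assistant", "content": "### Step 1: Use slicing with a negative stride."},
]
text = tokenizer.apply_chat_template(messages, tokenize=False)
text = text.removesuffix("<|im_end|>\n") + " Rating"  # append the step tag, as in preprocess_example

with torch.no_grad():
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    logits = model(input_ids).logits[0]

# Read the prediction for the step tag from the position just before it.
tag_pos = (input_ids[0] == step_tag_id).nonzero()[-1].item() - 1
probs = torch.softmax(logits[tag_pos, [good_id, bad_id]], dim=-1)
print(f"P(step is good) = {probs[0].item():.3f}")
```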
-------------------------------------------------------------------------------- /src/prm_training/train_prm/utils/accelerator_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from accelerate import Accelerator 3 | from typing import Optional 4 | 5 | class AcceleratorManager: 6 | _instance: Optional[Accelerator] = None 7 | 8 | @classmethod 9 | def initialize(cls, **kwargs): 10 | if cls._instance is None: 11 | cls._instance = Accelerator(**kwargs) 12 | 13 | if cls._instance.is_main_process: 14 | print(f"Distributed training setup:") 15 | print(f"- Number of processes: {cls._instance.num_processes}") 16 | print(f"- Mixed precision: {cls._instance.mixed_precision}") 17 | print(f"- Gradient accumulation steps: {cls._instance.gradient_accumulation_steps}") 18 | 19 | @classmethod 20 | def get_accelerator(cls) -> Accelerator: 21 | if cls._instance is None: 22 | raise RuntimeError( 23 | "Accelerator not initialized. Call AcceleratorManager.initialize() first." 24 | ) 25 | return cls._instance 26 | 27 | @classmethod 28 | def is_initialized(cls) -> bool: 29 | return cls._instance is not None 30 | 31 | 32 | def get_accelerator() -> Accelerator: 33 | return AcceleratorManager.get_accelerator() 34 | 35 | def is_main_process() -> bool: 36 | return get_accelerator().is_main_process 37 | 38 | def get_local_rank() -> int: 39 | return get_accelerator().local_process_index 40 | 41 | def get_world_size() -> int: 42 | return get_accelerator().num_processes 43 | 44 | def synchronize(): 45 | get_accelerator().wait_for_everyone() 46 | 47 | def print_rank_0(msg): 48 | if is_main_process(): 49 | print(msg) -------------------------------------------------------------------------------- /src/prm_training/train_prm/utils/dist_configs/ds_zero2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 2, 20 | "allgather_partitions": true, 21 | "allgather_bucket_size": 5e8, 22 | "overlap_comm": true, 23 | "reduce_scatter": true, 24 | "reduce_bucket_size": 5e8, 25 | "contiguous_gradients": true, 26 | "round_robin_gradients": true 27 | } 28 | } -------------------------------------------------------------------------------- /src/prm_training/train_prm/utils/dist_configs/multi_gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | # distributed_type: MULTI_GPU 4 | deepspeed_config: 5 | deepspeed_config_file: ./train_prm/utils/dist_configs/ds_zero2_config.json 6 | distributed_type: DEEPSPEED 7 | downcast_bf16: 'no' 8 | gpu_ids: all 9 | machine_rank: 0 10 | # main_training_function: main 11 | # mixed_precision: 'bf16' 12 | num_machines: 1 13 | num_processes: 2 14 | rdzv_backend: static 15 | same_network: true 16 | tpu_env: [] 17 | tpu_use_cluster: false 18 | tpu_use_sudo: false 19 | use_cpu: false -------------------------------------------------------------------------------- /src/prm_training/train_prm/utils/dist_configs/single_gpu.yaml: 
-------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | # main_training_function: main 8 | mixed_precision: 'bf16' 9 | num_machines: 1 10 | num_processes: 1 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false --------------------------------------------------------------------------------
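Both YAML files above are `accelerate` launch configs: `multi_gpu.yaml` routes training through the DeepSpeed ZeRO-2 JSON with two processes, while `single_gpu.yaml` runs a single process. As a sketch of a single-GPU launch (the flags are taken from `run.sh`, with only the config file swapped; paths are placeholders):

```bash
accelerate launch --config_file "./train_prm/utils/dist_configs/single_gpu.yaml" \
    ./run.py \
    --model_name_or_path "your model path" \
    --data_path "./data/examples/hard_label_examples.json" \
    --output_dir "./train_prm/outputs" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --num_train_epochs 3 \
    --learning_rate 1e-4 \
    --bf16
```

Omitting `--use_soft_label` selects hard-label training, since the flag defaults to false in `run.py`.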