├── .gitignore
├── LICENSE
├── README.md
├── benchmark
│   ├── README.md
│   └── apeval.json
├── data
│   ├── README.md
│   ├── commit.py
│   ├── submit.py
│   └── submit_process.py
├── eval
│   ├── README.md
│   ├── eval_apeval.py
│   ├── eval_humaneval.py
│   ├── eval_mbpp.py
│   ├── extract_results.py
│   ├── search_and_replace.py
│   └── utils.py
├── gen
│   ├── __init__.py
│   ├── genaiprogrammer.py
│   ├── gencht.py
│   ├── geninst.py
│   ├── genjudge.py
│   ├── llmeval.py
│   ├── llmgen.py
│   ├── openai.py
│   ├── template
│   │   ├── __init__.py
│   │   ├── aiprogrammer_template.py
│   │   ├── gencht_template.py
│   │   ├── geninst_template.py
│   │   └── genjudge_template.py
│   └── utils.py
├── generic
│   ├── __init__.py
│   ├── special_tokens.py
│   └── utils.py
├── model_map.json
├── pictures
│   ├── APEval.png
│   ├── CursorWeb.gif
│   ├── EvalPlus_CanItEdit_OctoPack.png
│   ├── conversation.png
│   ├── cursorcore.png
│   └── discord.png
├── requirements.txt
├── src
│   ├── README.md
│   ├── aiprogrammer.py
│   ├── data_collection.py
│   ├── merge_data.py
│   ├── post_collection.py
│   └── utils.py
├── tests
│   └── __init__.py
└── train
    ├── README.md
    ├── ds_config.json
    ├── prepare_conversation.py
    ├── prepare_data.py
    └── training.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *tmp*
3 | **/*.log
4 | .vscode/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
[Apache License, Version 2.0, January 2004 (http://www.apache.org/licenses/): full standard license text]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CursorCore: Assist Programming through Aligning Anything
2 |
3 |
4 | [📄arXiv] |
5 | [🤗HF Paper] |
6 | [🤖Models] |
7 | [🛠️Code] |
8 | [Web] |
9 | [Discord]
10 |
34 |
35 |
54 | - `start_line` and `end_line` specify the line range in the current code to be replaced by the code block that follows them.
55 |
56 | If an error occurs during processing, the function prints the exception and returns the original `current` code.
57 | """
58 | try:
59 | output = output.split(NEXT_START)[-1].split(NEXT_END)[0].strip()
60 | pattern = r"(\d+),(\d+)\n```(.*?)\n([\s\S]*?)\n```"
61 | lc = re.findall(pattern, output)
62 | current_lines = current.split("\n")
63 | for start_line, end_line, _, code in lc[::-1]:
64 | start_line = int(start_line)
65 | end_line = int(end_line)
66 | current_lines = current_lines[:start_line] + code.split("\n") + current_lines[end_line:]
67 | return "\n".join(current_lines)
68 | except Exception as e:
69 | print(e)
70 | return current
71 |
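# A minimal usage sketch for the location-based post-processor above (assumed here to be
# named `postprocess_output_lc`, matching `prepare_input_for_lc` below). It assumes
# NEXT_START and NEXT_END hold the "<|next_start|>"/"<|next_end|>" markers used in the
# prompts in this module; with the 1-indexed "start,end" header, the block replaces line 2.
def _demo_postprocess_output_lc():
    current = "a = 1\nb = 2\nc = 3"
    output = "<|next_start|>1,2\n```python\nb = 20\n```<|next_end|>"
    assert postprocess_output_lc(current, output) == "a = 1\nb = 20\nc = 3"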
72 | def postprocess_output_sr(current, output):
73 | """
74 | Post-processes the output string by extracting and applying search-and-replace operations.
75 |
76 | Args:
77 | current (str): The current string content to be modified.
78 | output (str): The output string containing the search-and-replace instructions.
79 |
80 | Returns:
81 | str: The modified string after applying the search-and-replace operations.
82 |
83 | The function performs the following steps:
84 | 1. Extracts the relevant portion of the output string between NEXT_START and NEXT_END markers.
85 | 2. Finds all code blocks within the extracted portion using a regular expression pattern.
86 | 3. For each code block, splits it into 'before' and 'after' parts using the SEARCH_AND_REPLACE marker.
87 | 4. Finds the best match for the 'before' part in the current string.
88 | 5. Replaces the matched portion in the current string with the 'after' part.
89 | 6. Returns the modified string.
90 |
91 | If an exception occurs during processing, the function prints the exception and returns the original current string.
92 | """
93 | try:
94 | output = output.split(NEXT_START)[-1].split(NEXT_END)[0].strip()
95 | pattern = r"```(.*?)\n([\s\S]*?)\n```"
96 | sr = re.findall(pattern, output)
97 | current_lines = current.split("\n")
98 | for _, before_and_after in sr[::-1]:
99 | before, after = before_and_after.split("\n" + SEARCH_AND_REPLACE + "\n")
100 | match_before = find_best_match(before, current)
101 | current_lines = current_lines[:match_before.start] + after.split("\n") + current_lines[match_before.end:]
102 | return "\n".join(current_lines)
103 | except Exception as e:
104 | print(e)
105 | return current
106 |
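# A minimal usage sketch for postprocess_output_sr, assuming SEARCH_AND_REPLACE and
# find_best_match come from the earlier (truncated) part of this module. The text above
# the marker is located in `current` and swapped for the text below it; the expected
# result is given as a comment rather than asserted, since the marker value is assumed.
def _demo_postprocess_output_sr():
    current = "a = 1\nb = 2\nc = a + b"
    output = (
        "<|next_start|>```python\nb = 2\n"
        + SEARCH_AND_REPLACE
        + "\nb = 20\n```<|next_end|>"
    )
    return postprocess_output_sr(current, output)  # expected: "a = 1\nb = 20\nc = a + b"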
107 | def postprocess_output_base(current, output):
108 | """
109 | Processes the given output string to extract content between specific markers.
110 |
111 | This function splits the output string using the NEXT_START and NEXT_END markers,
112 | then uses a regular expression to find content enclosed in triple backticks (```)
113 | and returns the first match. If an error occurs during processing, the function
114 | returns the current string.
115 |
116 | Args:
117 | current (str): The current string to return in case of an error.
118 | output (str): The output string to be processed.
119 |
120 | Returns:
121 | str: The extracted content between the markers, or the current string if an error occurs.
122 | """
123 | try:
124 | output = output.split(NEXT_START)[-1].split(NEXT_END)[0]
125 | pattern = r"```(.*?)\n([\s\S]*?)\n```"
126 | wf = re.findall(pattern, output)
127 | return wf[0][1]
128 | except Exception as e:
129 | print(e)
130 | return current
131 |
132 | def postprocess_output_instruct(current, output):
133 | """
134 | Post-processes the given output string by extracting the content between
135 | specific markers and returning the first code block found.
136 |
137 | Args:
138 | current (str): The current string to return in case of an error.
139 | output (str): The output string to be processed.
140 |
141 | Returns:
142 | str: The extracted code block from the output string. If an error occurs,
143 | the function returns the 'current' string.
144 | """
145 | try:
146 | output = output.split(NEXT_START)[-1].split(NEXT_END)[0]
147 | pattern = r"```(.*?)\n([\s\S]*?)\n```"
148 | wf = re.findall(pattern, output)
149 | return wf[0][1]
150 | except Exception as e:
151 | print(e)
152 | return current
153 |
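# A minimal usage sketch: both helpers above return the body of the first fenced code
# block found between the next markers. The literal markers below assume NEXT_START and
# NEXT_END are "<|next_start|>"/"<|next_end|>", as in the one-shot prompts further down.
def _demo_postprocess_output_base_and_instruct():
    output = "<|next_start|>```python\nprint('hi')\n```<|next_end|>"
    assert postprocess_output_base("fallback", output) == "print('hi')"
    assert postprocess_output_instruct("fallback", output) == "print('hi')"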
154 | def prepare_input_for_instruct(sample):
155 | """
156 | Prepares a conversation input for an instruction-following model based on the provided sample.
157 |
158 | Args:
159 | sample (dict): A dictionary containing the following keys:
160 | - "history" (list): A list of dictionaries, each containing:
161 | - "code" (str): The code from a previous programming process.
162 | - "lang" (str): The programming language of the code.
163 | - "current" (dict): A dictionary containing:
164 | - "code" (str): The current code to be modified.
165 | - "lang" (str): The programming language of the current code.
166 | - "user" (str): The user's instruction for modifying the current code.
167 |
168 | Returns:
169 | list: A list of dictionaries representing the conversation, where each dictionary contains:
170 | - "role" (str): The role in the conversation, either "user" or "assistant".
171 | - "content" (str): The content of the message.
172 | """
173 | conversation = []
174 | one_shot = [{"role": "user", "content": "Read the following messages during programming and return the modified code in this format:\n\n<|next_start|>{modified code}<|next_end|>\n\nProgramming process 1:\n```python\na = 1\nb = 2\nc = a + b\n```\n\nCurrent code:\n```python\ni = 1\nb = 2\nc = a + b\n```\n\nUser instruction:\nPlease change variable names."}, {"role": "assistant", "content": "<|next_start|>```python\ni = 1\nj = 2\nk = i + j\n```<|next_end|>"}]
175 | conversation += one_shot
176 | prompt = ""
177 | prompt += "Read the following messages during programming and return the modified code in this format:\n\n<|next_start|>{modified code}<|next_end|>"
178 | prompt += "\n\n"
179 | if sample["history"]:
180 | for i, h in enumerate(sample["history"]):
181 | prompt += f"Programming process {i + 1}:\n"
182 | prompt += decorate_code(h["code"], lang=h["lang"]) + "\n\n"
183 | prompt += "Current code:\n"
184 | prompt += decorate_code(sample["current"]["code"], lang=sample["current"]["lang"]) + "\n\n"
185 | if sample["user"]:
186 | prompt += "User instruction:\n"
187 | prompt += sample["user"] + "\n\n"
188 | conversation.append({"role": "user", "content": prompt.strip()})
189 | return conversation
190 |
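# A minimal usage sketch for prepare_input_for_instruct. `decorate_code` is defined in
# the earlier (truncated) part of this module and is assumed to wrap code in a fenced
# block tagged with its language; the sample keys follow the docstring above.
def _demo_prepare_input_for_instruct():
    sample = {
        "history": [{"code": "a = 1", "lang": "python"}],
        "current": {"code": "a = 1\nb = 2", "lang": "python"},
        "user": "Add a variable c that sums a and b.",
    }
    conversation = prepare_input_for_instruct(sample)
    # One-shot user/assistant pair, then the new user turn built from `sample`.
    return conversation  # roles: ["user", "assistant", "user"]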
191 | def prepare_input_for_base(sample):
192 | """
193 | Prepares a formatted input string for a base model by combining a prompt, one-shot example,
194 | and the provided sample data including history, current code, and user instructions.
195 |
196 | Args:
197 | sample (dict): A dictionary containing the following keys:
198 | - "history" (list): A list of dictionaries, each containing:
199 | - "code" (str): The code snippet from the history.
200 | - "lang" (str): The programming language of the code snippet.
201 | - "current" (dict): A dictionary containing:
202 | - "code" (str): The current code snippet.
203 | - "lang" (str): The programming language of the current code snippet.
204 | - "user" (str): The user instruction.
205 |
206 | Returns:
207 | str: A formatted string that includes the prompt, one-shot example, and the sample data.
208 | """
209 | prompt = "Read the following messages during programming and return the modified code in this format:\n\n<|next_start|>{modified code}<|next_end|>\n\n"
210 | one_shot = "<|messages_start|>Programming process 1:\n```python\na = 1\nb = 2\nc = a + b\n```\n\nCurrent code:\n```python\ni = 1\nb = 2\nc = a + b\n```\n\nUser instruction:\nPlease change variable names.<|messages_end|>\n\n<|next_start|>```python\ni = 1\nj = 2\nk = i + j\n```<|next_end|>\n\n"
211 | prompt += one_shot
212 | prompt += "Read the following messages during programming and return the modified code in this format:\n\n<|next_start|>{modified code}<|next_end|>\n\n<|messages_start|>"
213 | if sample["history"]:
214 | for i, h in enumerate(sample["history"]):
215 | prompt += f"Programming process {i + 1}:\n"
216 | prompt += decorate_code(h["code"], lang=h["lang"]) + "\n\n"
217 | prompt += "Current code:\n"
218 | prompt += decorate_code(sample["current"]["code"], lang=sample["current"]["lang"]) + "\n\n"
219 | if sample["user"]:
220 | prompt += "User instruction:\n"
221 | prompt += sample["user"] + "\n\n"
222 | prompt = prompt.strip() + "<|messages_end|>\n\n"
223 | return prompt
224 |
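# A minimal usage sketch for prepare_input_for_base: the base-model variant returns one
# flat prompt string (ending with the "<|messages_end|>" marker) rather than a chat list.
def _demo_prepare_input_for_base():
    sample = {
        "history": [],
        "current": {"code": "a = 1\nb = 2", "lang": "python"},
        "user": "Add a variable c that sums a and b.",
    }
    return prepare_input_for_base(sample)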
225 | def prepare_input_for_wf(sample):
226 | """
227 | Prepares the input data for the whole-file (WF) output format by formatting the conversation history.
228 |
229 | Args:
230 | sample (dict): A dictionary containing the following keys:
231 | - "history" (list): A list of dictionaries representing past messages. Each dictionary contains:
232 | - "code" (str): The code snippet.
233 | - "lang" (str): The programming language of the code snippet.
234 | - "current" (dict): A dictionary representing the current message with the same structure as the history messages.
235 | - "user" (str): The user's input message.
236 |
237 | Returns:
238 | list: A list of dictionaries representing the conversation. Each dictionary contains:
239 | - "role" (str): The role of the message, either "history", "current", or "user".
240 | - "content" (str): The formatted content of the message.
241 | """
242 | conversation = []
243 | if sample["history"]:
244 | history_current = sample["history"] + [sample["current"]]
245 | for m1, m2 in zip(history_current[:-1], history_current[1:]):
246 | message = decorate_code(m1["code"], m1["lang"])
247 | conversation.append({"role": "history", "content": message})
248 | conversation.append({"role": "current", "content": decorate_code(sample["current"]["code"], sample["current"]["lang"])})
249 | if sample["user"]:
250 | conversation.append({"role": "user", "content": sample["user"]})
251 | return conversation
252 |
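# A minimal usage sketch for prepare_input_for_wf: each history snapshot is passed
# verbatim, so this sample yields messages with roles ["history", "current", "user"].
def _demo_prepare_input_for_wf():
    sample = {
        "history": [{"code": "a = 1", "lang": "python"}],
        "current": {"code": "a = 1\nb = 2", "lang": "python"},
        "user": "Add a variable c that sums a and b.",
    }
    return prepare_input_for_wf(sample)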
253 | def prepare_input_for_lc(sample):
254 | """
255 | Prepares the input data for the locations-and-changes (LC) output format by processing the sample's history, current code, and user input.
256 |
257 | Args:
258 | sample (dict): A dictionary containing the following keys:
259 | - "history" (list): A list of dictionaries representing the history of code changes.
260 | - "current" (dict): A dictionary representing the current code state with keys:
261 | - "code" (str): The current code.
262 | - "lang" (str): The programming language of the current code.
263 | - "user" (str): The user's input or query.
264 |
265 | Returns:
266 | list: A list of dictionaries representing the conversation, where each dictionary has:
267 | - "role" (str): The role in the conversation, either "history", "current", or "user".
268 | - "content" (str): The content associated with the role, such as code changes or user input.
269 | """
270 | conversation = []
271 | if sample["history"]:
272 | history_current = sample["history"] + [sample["current"]]
273 | for m1, m2 in zip(history_current[:-1], history_current[1:]):
274 | changes_lines = extract_changes_lines(m2["code"], m1["code"])
275 | locations_changes = generate_locations_changes(m2["code"], m1["code"], m1["lang"], changes_lines)
276 | conversation.append({"role": "history", "content": locations_changes})
277 | conversation.append({"role": "current", "content": decorate_code(sample["current"]["code"], sample["current"]["lang"], use_line_num=True)})
278 | if sample["user"]:
279 | conversation.append({"role": "user", "content": sample["user"]})
280 | return conversation
281 |
282 | def prepare_input_for_sr(sample):
283 | """
284 | Prepares the input for a search and replace (SR) task based on the provided sample.
285 |
286 | Args:
287 | sample (dict): A dictionary containing the following keys:
288 | - "history" (list): A list of dictionaries representing the history of code changes.
289 | - "current" (dict): A dictionary representing the current code state with keys:
290 | - "code" (str): The current code.
291 | - "lang" (str): The programming language of the current code.
292 | - "user" (str, optional): A string representing the user's input or query.
293 |
294 | Returns:
295 | list: A list of dictionaries representing the conversation for the SR task. Each dictionary contains:
296 | - "role" (str): The role in the conversation, either "history", "current", or "user".
297 | - "content" (str): The content associated with the role, such as code changes or user input.
298 | """
299 | conversation = []
300 | if sample["history"]:
301 | history_current = sample["history"] + [sample["current"]]
302 | for m1, m2 in zip(history_current[:-1], history_current[1:]):
303 | changes_lines = extract_changes_lines(m2["code"], m1["code"], unique=True, merge_changes=True)
304 | changes_lines = [(new, old) for old, new in changes_lines]
305 | search_and_replace = generate_search_and_replace(m1["code"], m2["code"], m1["lang"], changes_lines)
306 | conversation.append({"role": "history", "content": search_and_replace})
307 | conversation.append({"role": "current", "content": decorate_code(sample["current"]["code"], sample["current"]["lang"])})
308 | if sample["user"]:
309 | conversation.append({"role": "user", "content": sample["user"]})
310 | return conversation
311 |
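# A minimal usage sketch covering the lc and sr variants above: instead of passing each
# history snapshot verbatim, they compress it into a diff-like representation via
# extract_changes_lines / generate_locations_changes / generate_search_and_replace,
# which are defined in the earlier (truncated) part of this module.
def _demo_prepare_input_for_sr():
    sample = {
        "history": [{"code": "a = 1", "lang": "python"}],
        "current": {"code": "a = 1\nb = 2", "lang": "python"},
        "user": "Add a variable c that sums a and b.",
    }
    # Each history turn carries only the edit that produced the next snapshot; the
    # current code and user instruction are appended as in prepare_input_for_wf.
    return prepare_input_for_sr(sample)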
--------------------------------------------------------------------------------
/gen/__init__.py:
--------------------------------------------------------------------------------
1 | from .genaiprogrammer import AIProgrammer
2 | from .genjudge import GenJudgement
3 | from .geninst import GenInstruction
4 | from .gencht import GenChat
5 | from .llmeval import GenEvaluation
6 |
--------------------------------------------------------------------------------
/gen/genaiprogrammer.py:
--------------------------------------------------------------------------------
1 | import random
2 | from .llmgen import LLMGen
3 | from .template.aiprogrammer_template import NOVICE_AIPROGRAMMER_SYSTEM, ORDINARY_AIPROGRAMMER_SYSTEM, EXPERT_AIPROGRAMMER_SYSTEM, AIPROGRAMMER_PROMPT_INPUT, AIPROGRAMMER_PROMPT_OUTPUT, NOVICE_AIPROGRAMMER_FEWSHOT, ORDINARY_AIPROGRAMMER_FEWSHOT, EXPERT_AIPROGRAMMER_FEWSHOT
4 | from .utils import extract_code_blocks
5 |
6 | class AIProgrammer(LLMGen):
7 | def __init__(self, backend="openai", model_map={}, max_try=5, **kwargs) -> None:
8 | super().__init__(backend, model_map, max_try, **kwargs)
9 |
10 | def create_prompt(self, text_map):
11 | """
12 | Generates a prompt for an AI programmer based on a random skill level.
13 |
14 | The function randomly selects a skill level from "NOVICE", "ORDINARY", or "EXPERT".
15 | Based on the selected skill level, it sets the appropriate system message and few-shot examples.
16 | It then constructs a list of messages that includes the system message, alternating user and assistant
17 | messages from the few-shot examples, and a final user message based on the provided text_map.
18 |
19 | Args:
20 | text_map (dict): A dictionary containing the input text to be formatted into the final user message.
21 |
22 | Returns:
23 | list: A list of dictionaries representing the prompt messages for the AI programmer.
24 | """
25 | dice = random.choice(["NOVICE", "ORDINARY", "EXPERT"])
26 | if dice == "NOVICE":
27 | AIPROGRAMMER_SYSTEM = NOVICE_AIPROGRAMMER_SYSTEM
28 | AIPROGRAMMER_FEWSHOT = NOVICE_AIPROGRAMMER_FEWSHOT
29 | elif dice == "ORDINARY":
30 | AIPROGRAMMER_SYSTEM = ORDINARY_AIPROGRAMMER_SYSTEM
31 | AIPROGRAMMER_FEWSHOT = ORDINARY_AIPROGRAMMER_FEWSHOT
32 | else:
33 | AIPROGRAMMER_SYSTEM = EXPERT_AIPROGRAMMER_SYSTEM
34 | AIPROGRAMMER_FEWSHOT = EXPERT_AIPROGRAMMER_FEWSHOT
35 | out = [
36 | {"role": "system", "content": AIPROGRAMMER_SYSTEM},
37 | ] + [
38 | {
39 | "role": "user",
40 | "content": AIPROGRAMMER_PROMPT_INPUT.format_map(shot)
41 | } if i % 2 == 0 else {
42 | "role": "assistant",
43 | "content": AIPROGRAMMER_PROMPT_OUTPUT.format_map(shot)
44 | } for i, shot in enumerate(AIPROGRAMMER_FEWSHOT)
45 | ] + [
46 | {"role": "user", "content": AIPROGRAMMER_PROMPT_INPUT.format_map(text_map)}
47 | ]
48 | return out
49 |
50 | def reject_response(self, text_map, response):
51 | """
52 | Determines whether a given response should be rejected based on various criteria.
53 |
54 | Args:
55 | text_map (dict): A dictionary containing the original input text with a key "content".
56 | response (str): The response text to be evaluated.
57 |
58 | Returns:
59 | bool: True if the response should be rejected, False otherwise.
60 |
61 | The function checks the following conditions to decide if the response should be rejected:
62 | - The response length is less than 20 characters.
63 | - The response contains the word "sorry".
64 | - The response does not contain any code blocks or the last code block is not the same as the input code.
65 | - The response contains repeated code blocks.
66 | - The response contains a summary of the process and repeats the last block.
67 | - The response contains segments of code where the sum of the lengths of all but the last block is less than or equal to the length of the last block plus 2 lines.
68 | - The response contains any consecutive identical code blocks.
69 | """
70 | if len(response) < 20:
71 | return True
72 | if "sorry" in response.lower():
73 | return True
74 | try:
75 | blocks = extract_code_blocks(response)
76 | # if the last code block is not the same as the input code
77 | if len(blocks) == 0 or blocks[-1].strip() != text_map["content"].strip():
78 | return True
79 | # sometimes llm will summarize the process and repeat the last block
80 | if len(blocks) >= 2 and blocks[-2] == blocks[-1]:
81 | blocks = blocks[:-1]
82 | block_lengths = [len(block.split("\n")) for block in blocks]
83 | # filter out responses where the intermediate steps are only small fragments of the final code
84 | if len(block_lengths) >= 4 and sum(block_lengths[:-1]) <= block_lengths[-1] + 2:
85 | return True
86 | # filter out responses where consecutive blocks in the process are identical
87 | elif any(blocks[i] == blocks[i + 1] for i in range(len(blocks) - 1)):
88 | return True
89 | except:
90 | return True
91 | return False
92 |
93 | def post_process(self, text_map, response):
94 | """
95 | Post-processes the response by extracting code blocks and removing any repeated blocks.
96 |
97 | Args:
98 | text_map (dict): A mapping of text elements.
99 | response (str): The response string containing code blocks.
100 |
101 | Returns:
102 | list: A list of extracted code blocks with any repeated blocks removed.
103 | """
104 | blocks = extract_code_blocks(response)
105 | # sometimes llm will summarize the process and repeat the last block
106 | if len(blocks) >= 2 and blocks[-2] == blocks[-1]:
107 | blocks = blocks[:-1]
108 | return blocks
109 |
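# A minimal usage sketch, assuming a hypothetical OpenAI-compatible endpoint; the model
# name, base URL and key below are placeholders. The text_map must provide whatever keys
# AIPROGRAMMER_PROMPT_INPUT expects; "content" is also the key checked by
# reject_response above.
def _demo_aiprogrammer():
    model_map = {"my-model": {"base": "http://localhost:8000/v1", "api": "EMPTY"}}
    programmer = AIProgrammer(backend="openai", model_map=model_map, max_try=5)
    results = programmer.gen([{"content": "def add(a, b):\n    return a + b"}])
    # Each item is {"input": text_map, "output": [code blocks]} or output=None when all
    # attempts were rejected.
    return results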
--------------------------------------------------------------------------------
/gen/gencht.py:
--------------------------------------------------------------------------------
1 | from .llmgen import LLMGen
2 | from .template.gencht_template import GENCHAT_SYSTEM, GENCHAT_PROMPT_INPUT, GENCHAT_PROMPT_OUTPUT, GENCHAT_FEWSHOT, GENCHAT_RECORD_TYPE
3 |
4 |
5 | class GenChat(LLMGen):
6 | def __init__(self, backend="openai", model_map={}, max_try=5, num_proc=512, **kwargs) -> None:
7 | super().__init__(backend, model_map, max_try, num_proc, **kwargs)
8 |
9 | def __exit__(self, exc_type, exc_val, exc_tb):
10 | return super().__exit__(exc_type, exc_val, exc_tb)
11 |
12 | def create_prompt(self, text_map):
13 | """
14 | Generates a prompt for a conversational AI model based on provided text mappings.
15 |
16 | Args:
17 | text_map (dict): A dictionary containing the following keys:
18 | - "record" (list): A list of dictionaries, each representing a record with a "type" key.
19 | - "change" (str): A string describing the change.
20 |
21 | Returns:
22 | list: A list of dictionaries, each representing a message in the conversation. Each dictionary contains:
23 | - "role" (str): The role of the message sender, either "system", "user", or "assistant".
24 | - "content" (str): The content of the message.
25 | """
26 | out = [
27 | {"role": "system", "content": GENCHAT_SYSTEM},
28 | ] + [
29 | {"role": "user", "content": GENCHAT_PROMPT_INPUT.format_map(
30 | {
31 | "record": "\n".join(GENCHAT_RECORD_TYPE[r["type"]].format_map(r) for r in shot["record"]),
32 | "change": shot["change"],
33 | }
34 | )} if i % 2 == 0 else {"role": "assistant", "content": GENCHAT_PROMPT_OUTPUT.format_map(shot)}
35 | for i, shot in enumerate(GENCHAT_FEWSHOT)
36 | ] + [
37 | {
38 | "role": "user",
39 | "content": GENCHAT_PROMPT_INPUT.format_map(
40 | {
41 | "record": "\n".join(GENCHAT_RECORD_TYPE[r["type"]].format_map(r) for r in text_map["record"]),
42 | "change": text_map["change"],
43 | }
44 | ),
45 | },
46 | ]
47 | return out
48 |
49 | def reject_response(self, text_map, response):
50 | """
51 | Determines whether a given response should be rejected based on its format.
52 |
53 | Args:
54 | text_map (dict): A dictionary containing text mappings (not used in the current implementation).
55 | response (str): The response string to be evaluated.
56 |
57 | Returns:
58 | bool: True if the response should be rejected (i.e., it does not start with "**chat:**"), False otherwise.
59 | """
60 | if not response.strip().startswith("**chat:**"):
61 | return True
62 | return False
63 |
64 | def post_process(self, text_map, response):
65 | """
66 | Post-processes the response by stripping whitespace and extracting the relevant part.
67 |
68 | Args:
69 | text_map (dict): A dictionary containing text mappings (not used in this function).
70 | response (str): The response string to be processed.
71 |
72 | Returns:
73 | str: The processed response, which is the part after the last occurrence of "**chat:**".
74 | """
75 | return response.strip().split("**chat:**")[-1].strip()
76 |
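# A minimal usage sketch, assuming a hypothetical OpenAI-compatible endpoint. The
# text_map keys mirror GENCHAT_RECORD_TYPE / GENCHAT_PROMPT_INPUT in
# gen/template/gencht_template.py; successful outputs are the text after "**chat:**".
def _demo_genchat():
    model_map = {"my-model": {"base": "http://localhost:8000/v1", "api": "EMPTY"}}
    generator = GenChat(backend="openai", model_map=model_map)
    text_map = {
        "record": [{"type": "current", "current": "```python\n1 a = 1\n```"}],
        "change": "```diff\n@@ -1 +1,2 @@\n a = 1\n+b = 2\n```",
    }
    return generator.gen([text_map])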
--------------------------------------------------------------------------------
/gen/geninst.py:
--------------------------------------------------------------------------------
1 | from .llmgen import LLMGen
2 | from .template.geninst_template import GENINST_SYSTEM, GENINST_PROMPT_INPUT, GENINST_PROMPT_OUTPUT, GENINST_FEWSHOT, GENINST_RECORD_TYPE
3 |
4 |
5 | class GenInstruction(LLMGen):
6 | def __init__(self, backend="openai", model_map={}, max_try=5, num_proc=512, **kwargs) -> None:
7 | super().__init__(backend, model_map, max_try, num_proc, **kwargs)
8 |
9 | def __exit__(self, exc_type, exc_val, exc_tb):
10 | return super().__exit__(exc_type, exc_val, exc_tb)
11 |
12 | def create_prompt(self, text_map):
13 | """
14 | Generates a prompt for the language model based on the provided text map and predefined few-shot examples.
15 |
16 | Args:
17 | text_map (dict): A dictionary containing the 'record' and 'change' keys.
18 | 'record' is a list of dictionaries, each with a 'type' key.
19 | 'change' is a string describing the change.
20 |
21 | Returns:
22 | list: A list of dictionaries representing the prompt, where each dictionary has 'role' and 'content' keys.
23 | """
24 | out = [
25 | {"role": "system", "content": GENINST_SYSTEM},
26 | ] + [
27 | {"role": "user", "content": GENINST_PROMPT_INPUT.format_map(
28 | {
29 | "record": "\n".join(GENINST_RECORD_TYPE[r["type"]].format_map(r) for r in shot["record"]),
30 | "change": shot["change"],
31 | }
32 | )} if i % 2 == 0 else {"role": "assistant", "content": GENINST_PROMPT_OUTPUT.format_map(shot)}
33 | for i, shot in enumerate(GENINST_FEWSHOT)
34 | ] + [
35 | {
36 | "role": "user",
37 | "content": GENINST_PROMPT_INPUT.format_map(
38 | {
39 | "record": "\n".join(GENINST_RECORD_TYPE[r["type"]].format_map(r) for r in text_map["record"]),
40 | "change": text_map["change"],
41 | }
42 | ),
43 | },
44 | ]
45 | return out
46 |
47 | def reject_response(self, text_map, response):
48 | """
49 | Determines whether a given response should be rejected based on specific criteria.
50 |
51 | Args:
52 | text_map (dict): A dictionary containing text mappings (not used in the current implementation).
53 | response (str): The response string to be evaluated.
54 |
55 | Returns:
56 | bool: True if the response should be rejected, False otherwise.
57 |
58 | The response is rejected if:
59 | - It does not start with "**instruction:**" after stripping leading and trailing whitespace.
60 | - It contains the substring "\nNote:".
61 | - It contains the phrase "no change" (case insensitive).
62 | """
63 | if not response.strip().startswith("**instruction:**") or "\nNote:" in response or "no change" in response.lower():
64 | return True
65 | return False
66 |
67 | def post_process(self, text_map, response):
68 | """
69 | Post-processes the response by stripping leading and trailing whitespace and
70 | splitting the text based on the delimiter "**instruction:**". It returns the
71 | last segment of the split response.
72 |
73 | Args:
74 | text_map (dict): A dictionary mapping text segments (not used in this method).
75 | response (str): The response string to be post-processed.
76 |
77 | Returns:
78 | str: The processed response string.
79 | """
80 | return response.strip().split("**instruction:**")[-1].strip()
81 |
--------------------------------------------------------------------------------
/gen/genjudge.py:
--------------------------------------------------------------------------------
1 | from .llmgen import LLMGen
2 | from .template.genjudge_template import GENJUDGE_SYSTEM, GENJUDGE_RECORD_TYPE, GENJUDGE_PROMPT_INPUT, GENJUDGE_PROMPT_INPUT_RECORD, GENJUDGE_PROMPT_INPUT_CHANGE, GENJUDGE_PROMPT_OUTPUT, GENJUDGE_FEWSHOT
3 |
4 | class GenJudgement(LLMGen):
5 | def __init__(self, backend="openai", model_map={}, max_try=5, num_proc=512, **kwargs) -> None:
6 | super().__init__(backend, model_map, max_try, num_proc, **kwargs)
7 |
8 | def __exit__(self, exc_type, exc_val, exc_tb):
9 | return super().__exit__(exc_type, exc_val, exc_tb)
10 |
11 | def create_prompt(self, text_map):
12 | """
13 | Generates a prompt for the GENJUDGE model based on the provided text map.
14 |
15 | The prompt consists of a series of messages formatted for the model, including
16 | system instructions, few-shot examples, and the user input based on the given text map.
17 |
18 | Args:
19 | text_map (dict): A dictionary containing the 'record' and 'change' keys.
20 | 'record' is a list of dictionaries, each representing a record with a 'type' key.
21 | 'change' is a list of changes to be included in the prompt.
22 |
23 | Returns:
24 | list: A list of dictionaries, each representing a message in the prompt.
25 | The messages alternate between 'user' and 'assistant' roles, with the final message being the user input.
26 | """
27 | out = [
28 | {"role": "system", "content": GENJUDGE_SYSTEM},
29 | ] + [
30 | {"role": "user", "content": GENJUDGE_PROMPT_INPUT.format_map(
31 | {
32 | "record": GENJUDGE_PROMPT_INPUT_RECORD.format_map({"record": "\n".join(GENJUDGE_RECORD_TYPE[r["type"]].format_map(r) for r in shot["record"])}),
33 | "change": "\n".join(GENJUDGE_PROMPT_INPUT_CHANGE.format_map({"num": i+1, "change": c}) for i, c in enumerate(shot["change"])),
34 | }
35 | )} if i % 2 == 0 else {"role": "assistant", "content": GENJUDGE_PROMPT_OUTPUT.format_map(shot)}
36 | for i, shot in enumerate(GENJUDGE_FEWSHOT)
37 | ] + [
38 | {
39 | "role": "user",
40 | "content": GENJUDGE_PROMPT_INPUT.format_map(
41 | {
42 | "record": GENJUDGE_PROMPT_INPUT_RECORD.format_map({"record": "\n".join(GENJUDGE_RECORD_TYPE[r["type"]].format_map(r) for r in text_map["record"])}),
43 | "change": "\n".join(GENJUDGE_PROMPT_INPUT_CHANGE.format_map({"num": i+1, "change": c}) for i, c in enumerate(text_map["change"])),
44 | }
45 | ),
46 | },
47 | ]
48 | return out
49 |
50 | def reject_response(self, text_map, response):
51 | """
52 | Determines whether a given response should be rejected based on specific criteria.
53 |
54 | Args:
55 | text_map (dict): A dictionary containing the text data, specifically with a key "change".
56 | response (str): The response string to be evaluated.
57 |
58 | Returns:
59 | bool: True if the response should be rejected, False otherwise.
60 |
61 | Criteria for rejection:
62 | - The response length is less than 20 characters.
63 | - The response contains the word "sorry" (case insensitive).
64 | - The number of segments in the response after splitting by "Analysis of change" does not match the number of changes in text_map["change"].
65 | - Any segment in the response does not contain "**Decision:**" or does not have "True" or "False" following "**Decision:**".
66 | """
67 | if len(response) < 20:
68 | return True
69 | if "sorry" in response.lower():
70 | return True
71 | each_judge = response.split("Analysis of change")[1:]
72 | if len(each_judge) != len(text_map["change"]):
73 | return True
74 | for judge in each_judge:
75 | if "**Decision:**" not in judge or ("True" not in judge.split("**Decision:**")[-1] and "False" not in judge.split("**Decision:**")[-1]):
76 | return True
77 | return False
78 |
79 | def post_process(self, text_map, response):
80 | """
81 | Processes the response text to extract decision outcomes.
82 |
83 | Args:
84 | text_map (dict): A mapping of text elements (not used in the function).
85 | response (str): The response string containing multiple judge analyses.
86 |
87 | Returns:
88 | list: A list of boolean values representing the decisions extracted from the response.
89 | Each boolean corresponds to a decision where True indicates a positive decision
90 | and False indicates a negative decision.
91 |
92 | Raises:
93 | AssertionError: If a decision is not found in any of the judge analyses.
94 | """
95 | each_judge = response.split("Analysis of change")[1:]
96 | out = []
97 | for judge in each_judge:
98 | if "True" in judge.split("**Decision:**")[-1]:
99 | out.append(True)
100 | elif "False" in judge.split("**Decision:**")[-1]:
101 | out.append(False)
102 | else:
103 | assert False, f"Decision not found in {judge}"
104 | return out
105 |
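# A minimal sketch of the parsing contract enforced by reject_response/post_process
# above, using the "test" backend (no API calls) so it can run standalone.
def _demo_genjudge_post_process():
    judge = GenJudgement(backend="test")
    text_map = {"change": ["change one", "change two"]}
    response = (
        "Analysis of change 1:\nLooks consistent with the record.\n**Decision:** True\n\n"
        "Analysis of change 2:\nContradicts the final commit.\n**Decision:** False"
    )
    assert judge.reject_response(text_map, response) is False
    assert judge.post_process(text_map, response) == [True, False]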
--------------------------------------------------------------------------------
/gen/llmeval.py:
--------------------------------------------------------------------------------
1 | from .llmgen import LLMGen
2 |
3 |
4 | class GenEvaluation(LLMGen):
5 | def __init__(self, backend="openai", model_map={}, max_try=1, num_proc=512, **kwargs) -> None:
6 | super().__init__(backend, model_map, max_try, num_proc, **kwargs)
7 |
8 | def __exit__(self, exc_type, exc_val, exc_tb):
9 | return super().__exit__(exc_type, exc_val, exc_tb)
10 |
11 | def create_prompt(self, text_map):
12 | return text_map["conversation"]
13 |
14 | def reject_response(self, text_map, response):
15 | return False
16 |
17 | def post_process(self, text_map, response):
18 | return response
19 |
--------------------------------------------------------------------------------
/gen/llmgen.py:
--------------------------------------------------------------------------------
1 | import random
2 | import concurrent.futures
3 |
4 | class LLMGen:
5 | def __init__(self, backend="openai", model_map={}, max_try=5, num_proc=512, **kwargs) -> None:
6 | self.model_name = list(model_map.keys())
7 | self.max_try = max_try
8 | if backend == "openai":
9 | from .openai import OpenAICompletion
10 | self.backend = OpenAICompletion(model_map)
11 | elif backend == "test":
12 | self.backend = None
13 | else:
14 | raise ValueError(
15 | f"backend {backend} is currently not supported"
16 | )
17 | self.kwargs = kwargs
18 | self.executor = concurrent.futures.ThreadPoolExecutor(num_proc)
19 |
20 | def __exit__(self, exc_type, exc_val, exc_tb):
21 | self.executor.shutdown()
22 |
23 | def create_prompt(self, text_map):
24 | pass
25 |
26 | def reject_response(self, text_map, response):
27 | pass
28 |
29 | def post_process(self, text_map, response):
30 | pass
31 |
32 | def gen(
33 | self, text_maps, api_type="chat_completion"
34 | ):
35 | def process_text_map(text_map):
36 | """
37 | Processes a given text map by generating a response using a specified model.
38 |
39 | Args:
40 | text_map (dict): The input text map to be processed.
41 |
42 | Returns:
43 | dict: A dictionary containing the input text map and the processed output.
44 | If the response is rejected after the maximum number of tries, the output will be None.
45 |
46 | Raises:
47 | ValueError: If the specified api_type is not supported.
48 | """
49 | success = False
50 | try_count = 0
51 | answer = ""
52 | while try_count < self.max_try and not success:
53 | if type(self.model_name) == list:
54 | model_name = random.choice(self.model_name)
55 | else:
56 | model_name = self.model_name
57 | input_text = self.create_prompt(text_map)
58 | if api_type == "chat_completion":
59 | answer = self.backend.chat_completion(input_text, model_name, **self.kwargs)
60 | elif api_type == "completion":
61 | answer = self.backend.completion(input_text, model_name, **self.kwargs)
62 | else:
63 | raise ValueError(
64 | f"api_type {api_type} is currently not supported"
65 | )
66 | if self.reject_response(text_map, answer):
67 | try_count += 1
68 | continue
69 | success = True
70 | if self.reject_response(text_map, answer):
71 | return {"input": text_map, "output": None}
72 | else:
73 | return {"input": text_map, "output": self.post_process(text_map, answer)}
74 |
75 | gen = []
76 | results = self.executor.map(process_text_map, text_maps)
77 | for result in results:
78 | gen.append(result)
79 | return gen
80 |
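# A minimal subclassing sketch: concrete generators override the three hooks below and
# inherit gen(), which handles model selection, retries and thread-pool fan-out.
class _EchoGen(LLMGen):
    def create_prompt(self, text_map):
        return [{"role": "user", "content": text_map["prompt"]}]

    def reject_response(self, text_map, response):
        return not response  # retry on empty answers

    def post_process(self, text_map, response):
        return response.strip()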
--------------------------------------------------------------------------------
/gen/openai.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 |
3 | #TODO: Support batch api calls
4 | class OpenAICompletion:
5 | def __init__(
6 | self,
7 | model_map,
8 | ) -> None:
9 | self.client = {
10 | model: OpenAI(
11 | base_url=model_map[model]["base"],
12 | api_key=model_map[model]["api"],
13 | )
14 | for model in model_map
15 | }
16 |
17 | def chat_completion(
18 | self,
19 | prompt,
20 | model_name="deepseek-chat",
21 | temperature=0.2,
22 | max_tokens=3072,
23 | top_p=0.95,
24 | frequency_penalty=0,
25 | presence_penalty=0,
26 | stop=["<|eot_id|>", "<|im_end|>", "</s>", "<|EOT|>", "<|endoftext|>", "<|eos|>"], # default stop tokens
27 | timeout=200,
28 | extra_body={},
29 | ):
30 | response = ""
31 | client = self.client[model_name]
32 | try:
33 | response = client.chat.completions.create(
34 | model=model_name,
35 | messages=prompt,
36 | temperature=temperature,
37 | max_tokens=max_tokens,
38 | top_p=top_p,
39 | frequency_penalty=frequency_penalty,
40 | presence_penalty=presence_penalty,
41 | stop=stop,
42 | timeout=timeout,
43 | extra_body=extra_body,
44 | )
45 | except Exception as e:
46 | print("Exception", e)
47 | if response != "":
48 | if response.choices[0].finish_reason == "length":
49 | return ""
50 | return response.choices[0].message.content
51 | return ""
52 |
53 | def completion(
54 | self,
55 | prompt,
56 | model_name="deepseek-chat",
57 | temperature=0.2,
58 | max_tokens=3072,
59 | top_p=0.95,
60 | frequency_penalty=0,
61 | presence_penalty=0,
62 | stop=["</s>", "<|endoftext|>", "<|eos_token|>"], # default stop tokens
63 | timeout=200,
64 | extra_body={},
65 | ):
66 | response = ""
67 | client = self.client[model_name]
68 | try:
69 | response = client.completions.create(
70 | model=model_name,
71 | prompt=prompt,
72 | temperature=temperature,
73 | max_tokens=max_tokens,
74 | top_p=top_p,
75 | frequency_penalty=frequency_penalty,
76 | presence_penalty=presence_penalty,
77 | stop=stop,
78 | timeout=timeout,
79 | extra_body=extra_body,
80 | )
81 | except Exception as e:
82 | print("Exception", e)
83 | if response != "":
84 | return response.choices[0].text
85 | return ""
86 |
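# A minimal usage sketch; the endpoint, key and model name are placeholders. model_map
# maps each model name to the "base" URL and "api" key consumed by __init__ above.
def _demo_openai_completion():
    model_map = {"my-model": {"base": "http://localhost:8000/v1", "api": "EMPTY"}}
    backend = OpenAICompletion(model_map)
    messages = [{"role": "user", "content": "Say hello."}]
    return backend.chat_completion(messages, model_name="my-model", max_tokens=64)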
--------------------------------------------------------------------------------
/gen/template/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/gen/template/__init__.py
--------------------------------------------------------------------------------
/gen/template/gencht_template.py:
--------------------------------------------------------------------------------
1 | GENCHAT_SYSTEM = """You are a programming assistant. The following content includes information related to your programming assistance, which may contain the record of the programming process, the current code, the user instruction, and your predicted modifications. Please provide the chat conversation for making the prediction. This may include analyzing the past programming process, speculating on the user's intent, and explaining the planning and ideas for modifying the code. Return your chat conversation in the following format:
2 | ```
3 | **chat:**
4 | {chat}
5 | ```"""
6 |
7 | GENCHAT_RECORD_TYPE = {
8 | "current": """Current code:
9 | {current}
10 | """,
11 | "history": """Revised code changes:
12 | {history}
13 | """,
14 | "user": """User instruction:
15 | {user}
16 | """,
17 | }
18 |
19 | GENCHAT_PROMPT_INPUT = """{record}
20 | Predicted modifications:
21 | {change}"""
22 |
23 | GENCHAT_PROMPT_OUTPUT = """**chat:**
24 | {chat}"""
25 |
26 | GENCHAT_FEWSHOT = [
27 | {
28 | "record": [
29 | {
30 | "type": "history",
31 | "history": """```diff\n@@ -14,3 +14,30 @@\n if (row == n) {\n vector<string> board = generateBoard(queens, n);\n solutions.push_back(board);\n+ } else {\n+ for (int i = 0; i < n; i++) {\n+ if (columns.find(i) != columns.end()) {\n+ continue;\n+ }\n+ int diagonal1 = row - i;\n+ if (diagonals1.find(diagonal1) != diagonals1.end()) {\n+ continue;\n+ }\n+ int diagonal2 = row + i;\n+ if (diagonals2.find(diagonal2) != diagonals2.end()) {\n+ continue;\n+ }\n+ queens[row] = i;\n+ columns.insert(i);\n+ diagonals1.insert(diagonal1);\n+ diagonals2.insert(diagonal2);\n+ backtrack(solutions, queens, n, row + 1, columns, diagonals1, diagonals2);\n+ queens[row] = -1;\n+ columns.erase(i);\n+ diagonals1.erase(diagonal1);\n+ diagonals2.erase(diagonal2);\n+ }\n+ }\n+ }\n+\n+ vector<string> generateBoard(vector<int> &queens, int n)\n```"""
32 | },
33 | {
34 | "type": "history",
35 | "history": """```diff\n@@ -3,41 +3,3 @@\n vector<vector<string>> solveNQueens(int n) {\n auto solutions = vector<vector<string>>();\n auto queens = vector<int>(n, -1);\n- auto columns = unordered_set<int>();\n- auto diagonals1 = unordered_set<int>();\n- auto diagonals2 = unordered_set<int>();\n- backtrack(solutions, queens, n, 0, columns, diagonals1, diagonals2);\n- return solutions;\n- }\n-\n- void backtrack(vector<vector<string>> &solutions, vector<int> &queens, int n, int row, unordered_set<int> &columns, unordered_set<int> &diagonals1, unordered_set<int> &diagonals2) {\n- if (row == n) {\n- vector<string> board = generateBoard(queens, n);\n- solutions.push_back(board);\n- } else {\n- for (int i = 0; i < n; i++) {\n- if (columns.find(i) != columns.end()) {\n- continue;\n- }\n- int diagonal1 = row - i;\n- if (diagonals1.find(diagonal1) != diagonals1.end()) {\n- continue;\n- }\n- int diagonal2 = row + i;\n- if (diagonals2.find(diagonal2) != diagonals2.end()) {\n- continue;\n- }\n- queens[row] = i;\n- columns.insert(i);\n- diagonals1.insert(diagonal1);\n- diagonals2.insert(diagonal2);\n- backtrack(solutions, queens, n, row + 1, columns, diagonals1, diagonals2);\n- queens[row] = -1;\n- columns.erase(i);\n- diagonals1.erase(diagonal1);\n- diagonals2.erase(diagonal2);\n- }\n- }\n- }\n-\n- vector<string> generateBoard(vector<int> &queens, int n)\n```"""
36 | },
37 | {
38 | "type": "history",
39 | "history": """```diff\n@@ -3,3 +3,17 @@\n vector<vector<string>> solveNQueens(int n) {\n auto solutions = vector<vector<string>>();\n auto queens = vector<int>(n, -1);\n+ solve(solutions, queens, n, 0, 0, 0, 0);\n+ return solutions;\n+ }\n+\n+ vector<string> generateBoard(vector<int> &queens, int n) {\n+ auto board = vector<string>();\n+ for (int i = 0; i < n; i++) {\n+ string row = string(n, '.');\n+ row[queens[i]] = 'Q';\n+ board.push_back(row);\n+ }\n+ return board;\n+ }\n+};\n```"""
40 | },
41 | {
42 | "type": "current",
43 | "current": """```cpp\n1 class Solution {\n2 public:\n3 vector<vector<string>> solveNQueens(int n) {\n4 auto solutions = vector<vector<string>>();\n5 auto queens = vector<int>(n, -1);\n6 solve(solutions, queens, n, 0, 0, 0, 0);\n7 return solutions;\n8 }\n9 \n10 vector<string> generateBoard(vector<int> &queens, int n) {\n11 auto board = vector<string>();\n12 for (int i = 0; i < n; i++) {\n13 string row = string(n, '.');\n14 row[queens[i]] = 'Q';\n15 board.push_back(row);\n16 }\n17 return board;\n18 }\n19 };\n```"""
44 | }
45 | ],
46 | "change": """```diff\n@@ -5,6 +5,23 @@\n auto queens = vector<int>(n, -1);\n solve(solutions, queens, n, 0, 0, 0, 0);\n return solutions;\n+ }\n+\n+ void solve(vector<vector<string>> &solutions, vector<int> &queens, int n, int row, int columns, int diagonals1, int diagonals2) {\n+ if (row == n) {\n+ auto board = generateBoard(queens, n);\n+ solutions.push_back(board);\n+ } else {\n+ int availablePositions = ((1 << n) - 1) & (~(columns | diagonals1 | diagonals2));\n+ while (availablePositions != 0) {\n+ int position = availablePositions & (-availablePositions);\n+ availablePositions = availablePositions & (availablePositions - 1);\n+ int column = __builtin_ctz(position);\n+ queens[row] = column;\n+ solve(solutions, queens, n, row + 1, columns | position, (diagonals1 | position) >> 1, (diagonals2 | position) << 1);\n+ queens[row] = -1;\n+ }\n+ }\n }\n \n vector<string> generateBoard(vector<int> &queens, int n) {\n```"""
47 | },
48 | {
49 | "chat": """The revised code transitions from using sets to bitwise operations for solving the N-Queens problem, aiming to optimize the algorithm by reducing dynamic memory allocations and improving speed. Initially, the original code utilized sets (`unordered_set`) to track columns and diagonals. The user likely seeks to enhance efficiency, prompting the replacement of these sets with integer bit masks. To achieve this, the `solve` function is adjusted to use bitwise operations for conflict checks, calculating `availablePositions` using bitwise AND and NOT operations to identify valid queen placements. The `__builtin_ctz` function finds the least significant bit set to 1, corresponding to a valid column, and the recursive `solve` call updates bit masks for columns and diagonals after placing a queen. This maintains the backtracking logic while leveraging bitwise operations for performance. The `generateBoard` function remains unchanged, as it simply formats the result. This approach ensures efficient, streamlined code for the N-Queens problem."""
50 | },
51 | {
52 | "record": [
53 | {
54 | "type": "current",
55 | "current": """```\n1 const express = require('express');\n2 const app = express();\n```"""
56 | },
57 | {
58 | "type": "user",
59 | "user": """Add two new routes to the Express.js application:\n\n1. `/upper`: This route should accept a query parameter `string` and return the uppercase version of the string. If no string is provided, it should return a 400 error with a message "Please provide a string".\n2. `/lower`: This route should accept a query parameter `string` and return the lowercase version of the string. If no string is provided, it should return a 400 error with a message "Please provide a string".\n\nImplement these routes by modifying the existing code."""
60 | }
61 | ],
62 | "change": """```diff\n@@ -1,2 +1,22 @@\n const express = require('express');\n const app = express();\n+\n+app.get('/upper', (req, res) => {\n+ let string = req.query.string;\n+ if (!string) {\n+ res.status(400).send('Please provide a string');\n+ return;\n+ }\n+ res.send(string.toUpperCase());\n+});\n+\n+app.get('/lower', (req, res) => {\n+ let string = req.query.string;\n+ if (!string) {\n+ res.status(400).send('Please provide a string');\n+ return;\n+ }\n+ res.send(string.toLowerCase());\n+});\n+\n+app.listen(3000, () => console.log('Server started'));\n```"""
63 | },
64 | {
65 | "chat": """We should add two new routes to the existing Express.js application. These routes should handle the conversion of a query parameter `string` to either uppercase or lowercase, depending on the route accessed. If the query parameter is not provided, the routes should return a 400 error with an appropriate message.\n\nTo implement this, we will:\n1. Define a route `/upper` that checks for the presence of the query parameter `string`. If it exists, we will convert it to uppercase and return the result. If it does not exist, we will return a 400 error with the message "Please provide a string".\n2. Define a route `/lower` with similar logic, but it will convert the `string` to lowercase.\n\nFinally, we will ensure that the server is listening on a specified port, in this case, port 3000."""
66 | }
67 | ]
68 |
--------------------------------------------------------------------------------
/gen/template/geninst_template.py:
--------------------------------------------------------------------------------
1 | GENINST_SYSTEM = """You are a programming assistant. The following content includes information related to your programming assistance, which may contain the record of the programming process, the current code, the git commit after all changes, relevant details about the problem, and your predicted modifications. Please generate an instruction for you to make the corresponding modifications, ensuring it resembles instructions typically given by a human programmer. The instruction may be detailed or concise and may or may not specify the location of the modification. Return the generated instruction in the following format:
2 | ```
3 | **instruction:**
4 | {instruction}
5 | ```"""
6 |
7 | GENINST_RECORD_TYPE = {
8 | "current": """Current code:
9 | {current}
10 | """,
11 | "history": """Revised code changes:
12 | {history}
13 | """,
14 | "git": """Git commit message after all changes:
15 | {git}
16 | """,
17 | "problem": """Relevant details about the problem:
18 | {problem}
19 | """,
20 | }
21 |
22 | GENINST_PROMPT_INPUT = """{record}
23 | Changes in predictions:
24 | {change}"""
25 |
26 | GENINST_PROMPT_OUTPUT = """**instruction:**
27 | {instruction}"""
28 |
29 | GENINST_FEWSHOT = [
30 | {
31 | "record": [
32 | {
33 | "type": "current",
34 | "current": """```python\n1 import dedupe\n2 \n3 def detect_fraud(transactions):\n4 deduper = dedupe.Dedupe()\n5 deduper.sample(transactions)\n6 deduper.train()\n7 \n8 clusters = deduper.match(transactions, threshold=0.5)\n9 \n10 return clusters\n11 \n12 transactions = [\n13 {'id': 1, 'amount': 100.0},\n14 {'id': 2, 'amount': 200.0},\n15 {'id': 3, 'amount': 150.0},\n16 ]\n17 \n18 fraud_clusters = detect_fraud(transactions)\n19 print(fraud_clusters)\n```"""
35 | }
36 | ],
37 | "change": """```diff\n@@ -15,5 +15,10 @@\n {'id': 3, 'amount': 150.0},\n ]\n \n+# Replace '_sample' with 'sample'\n+deduper = dedupe.Dedupe()\n+deduper.sample = deduper._sample\n+del deduper._sample\n+\n fraud_clusters = detect_fraud(transactions)\n print(fraud_clusters)\n```"""
38 | },
39 | {
40 | "instruction": """Replace the internal '_sample' method with 'sample' in the dedupe object, approximately at line 17."""
41 | },
42 | {
43 | "record": [
44 | {
45 | "type": "history",
46 | "history": """```diff\n@@ -3,6 +3,10 @@\n def create_cnn_model(in_channels, config):\n layers = []\n conv2d = nn.Conv2d(in_channels, config, kernel_size=3, padding=1)\n- layers += [conv2d, nn.ReLU(inplace=True)]\n+ if batch_norm:\n+ layers += [conv2d, nn.BatchNorm2d(config)]\n+ else:\n+ layers += [conv2d]\n+ layers += [nn.ReLU(inplace=True)]\n model = nn.Sequential(*layers)\n return model\n```"""
47 | },
48 | {
49 | "type": "history",
50 | "history": """```diff\n@@ -1,6 +1,6 @@\n import torch.nn as nn\n \n-def create_cnn_model(in_channels, config):\n+def create_cnn_model(in_channels, config, batch_norm=False):\n layers = []\n conv2d = nn.Conv2d(in_channels, config, kernel_size=3, padding=1)\n if batch_norm:\n```"""
51 | },
52 | {
53 | "type": "current",
54 | "current": """```\n1 import torch.nn as nn\n2 \n3 def create_cnn_model(in_channels, config, batch_norm=False):\n4 layers = []\n5 conv2d = nn.Conv2d(in_channels, config, kernel_size=3, padding=1)\n6 if batch_norm:\n7 layers += [conv2d, nn.BatchNorm2d(config)]\n8 else:\n9 layers += [conv2d]\n10 layers += [nn.ReLU(inplace=True)]\n11 model = nn.Sequential(*layers)\n12 return model\n```"""
55 | }
56 | ],
57 | "change": """```diff\n@@ -1,12 +1,11 @@\n import torch.nn as nn\n \n-def create_cnn_model(in_channels, config, batch_norm=False):\n+def create_cnn_model(in_channels, config, batch_norm):\n layers = []\n conv2d = nn.Conv2d(in_channels, config, kernel_size=3, padding=1)\n if batch_norm:\n- layers += [conv2d, nn.BatchNorm2d(config)]\n+ layers += [conv2d, nn.BatchNorm2d(config), nn.ReLU(inplace=True)]\n else:\n- layers += [conv2d]\n- layers += [nn.ReLU(inplace=True)]\n+ layers += [conv2d, nn.ReLU(inplace=True)]\n model = nn.Sequential(*layers)\n return model\n```"""
58 | },
59 | {
60 | "instruction": """Update the create_cnn_model function to ensure that the ReLU activation function is added immediately after the BatchNorm layer if batch_norm is enabled. Adjust the function signature to remove the default value for the batch_norm parameter. The updated code should handle the addition of the ReLU layer conditionally based on the batch_norm parameter."""
61 | },
62 | {
63 | "record": [
64 | {
65 | "type": "current",
66 | "current": """```ruby\n1 # frozen_string_literal: true\n2 module Extensions::DeferredWorkflowStatePersistence::Workflow; end\n3 module Extensions::DeferredWorkflowStatePersistence::Workflow::Adapter; end\n4 module Extensions::DeferredWorkflowStatePersistence::Workflow::Adapter::DeferredActiveRecord\n5 extend ActiveSupport::Concern\n6 included do\n7 include Workflow::Adapter::ActiveRecord\n8 include InstanceMethods\n9 end\n10 \n11 module InstanceMethods\n12 def persist_workflow_state(new_value)\n13 write_attribute(self.class.workflow_column, new_value)\n14 true\n15 end\n16 end\n17 end\n18 \n```"""
67 | },
68 | {
69 | "type": "git",
70 | "git": "Include WorkflowActiverecord in the state persistence extension."
71 | }
72 | ],
73 | "change": """```diff\n@@ -1,10 +1,12 @@\n # frozen_string_literal: true\n+require 'workflow_activerecord'\n+\n module Extensions::DeferredWorkflowStatePersistence::Workflow; end\n module Extensions::DeferredWorkflowStatePersistence::Workflow::Adapter; end\n module Extensions::DeferredWorkflowStatePersistence::Workflow::Adapter::DeferredActiveRecord\n extend ActiveSupport::Concern\n included do\n- include Workflow::Adapter::ActiveRecord\n+ include WorkflowActiverecord::Adapter::ActiveRecord\n include InstanceMethods\n end\n \n```"""
74 | },
75 | {
76 | "instruction": """At the beginning of the file, add the statement `require 'workflow_activerecord'`; On line 7, change `include Workflow::Adapter::ActiveRecord` to `include WorkflowActiverecord::Adapter::ActiveRecord`; Ensure the final code reflects the necessary changes for including `WorkflowActiverecord` in the state persistence extension."""
77 | },
78 | {
79 | "record": [
80 | {
81 | "type": "history",
82 | "history": """```diff\n@@ -1,4 +1,5 @@\n import java.util.*;\n+import java.io.*;\n \n class Main\n {\n@@ -15,14 +16,18 @@\n }\n }\n \n- int n;\n- Scanner sc = new Scanner(System.in);\n- n = sc.nextInt();\n+ int ans = 0;\n+ BufferedReader bf = new BufferedReader(new InputStreamReader(System.in));\n \n- int ans = 0;\n- for(Map.Entry e : map.entrySet()){\n- Integer v = map.get(n-e.getKey());\n- ans += (v==null ? 0 : v) * e.getValue();\n+ try{\n+ int n = Integer.parseInt(bf.readLine());\n+ for(Map.Entry e : map.entrySet()){\n+ Integer v = map.get(n-e.getKey());\n+ ans += (v==null ? 0 : v) * e.getValue();\n+ }\n+ }\n+ catch(IOException ex){\n+ System.out.println(ex);\n }\n \n System.out.println(ans);\n```"""
83 | },
84 | {
85 | "type": "current",
86 | "current": """```java\n1 import java.util.*;\n2 import java.io.*;\n3 \n4 class Main\n5 {\n6 public static void main(String[] args)\n7 {\n8 HashMap map = new HashMap();\n9 \n10 for(int i=1; i<10; ++i){\n11 for(int j=1; j<10; ++j){\n12 if(!map.containsKey(i+j))\n13 map.put(i+j, 1);\n14 else\n15 map.put(i+j, map.get(i+j)+1);\n16 }\n17 }\n18 \n19 int ans = 0;\n20 BufferedReader bf = new BufferedReader(new InputStreamReader(System.in));\n21 \n22 try{\n23 int n = Integer.parseInt(bf.readLine());\n24 for(Map.Entry e : map.entrySet()){\n25 Integer v = map.get(n-e.getKey());\n26 ans += (v==null ? 0 : v) * e.getValue();\n27 }\n28 }\n29 catch(IOException ex){\n30 System.out.println(ex);\n31 }\n32 \n33 System.out.println(ans);\n34 }\n35 }\n```"""
87 | },
88 | {
89 | "type": "problem",
90 | "problem": """Problem Description:\nWrite a program which reads an integer n and identifies the number of combinations of a, b, c and d (0 ≤ a, b, c, d ≤ 9) which meet the following equality:a + b + c + d = n\nFor example, for n = 35, we have 4 different combinations of (a, b, c, d): (8, 9, 9, 9), (9, 8, 9, 9), (9, 9, 8, 9), and (9, 9, 9, 8).\n\nInput Description:\nThe input consists of several datasets. Each dataset consists of n (1 ≤ n ≤ 50) in a line. The number of datasets is less than or equal to 50.\n\nOutput Description:\nPrint the number of combination in a line.\n\nSample Input 1:\n35\n1\n\nSample Output 1:\n4\n4"""
91 | }
92 | ],
93 | "change": """```diff\n@@ -7,8 +7,8 @@\n {\n HashMap map = new HashMap();\n \n- for(int i=1; i<10; ++i){\n- for(int j=1; j<10; ++j){\n+ for(int i=0; i<10; ++i){\n+ for(int j=0; j<10; ++j){\n if(!map.containsKey(i+j))\n map.put(i+j, 1);\n else\n@@ -16,20 +16,22 @@\n }\n }\n \n- int ans = 0;\n BufferedReader bf = new BufferedReader(new InputStreamReader(System.in));\n+ String str;\n \n try{\n- int n = Integer.parseInt(bf.readLine());\n- for(Map.Entry e : map.entrySet()){\n- Integer v = map.get(n-e.getKey());\n- ans += (v==null ? 0 : v) * e.getValue();\n+ while((str = bf.readLine()) != null){\n+ int ans = 0;\n+ int n = Integer.parseInt(str);\n+ for(Map.Entry e : map.entrySet()){\n+ Integer v = map.get(n-e.getKey());\n+ ans += (v==null ? 0 : v) * e.getValue();\n+ }\n+ System.out.println(ans);\n }\n }\n catch(IOException ex){\n System.out.println(ex);\n }\n-\n- System.out.println(ans);\n }\n }\n```"""
94 | },
95 | {
96 | "instruction": """Update the code to handle multiple datasets and calculate the number of combinations of four digits (0 ≤ a, b, c, d ≤ 9) that sum up to a given number n for each dataset. Initialize the map to store the counts of all possible sums of pairs of digits (a+b). Modify the code to read multiple inputs until EOF and for each input, calculate the number of valid combinations where the sum of two pairs equals n."""
97 | },
98 | ]
99 |
--------------------------------------------------------------------------------
/gen/template/genjudge_template.py:
--------------------------------------------------------------------------------
1 | GENJUDGE_SYSTEM = """You are tasked with assisting a programmer by maintaining a record of the programming process, including potential future changes. Your role is to discern which changes the programmer desires you to propose proactively. These should align with their actual intentions and be helpful. To determine which changes align with a programmer's intentions, consider the following principles:
2 |
3 | 1. **Understand the Context**: Assess the overall goal of the programming project. Ensure that any proposed change aligns with the project's objectives and the programmer's current focus.
4 |
5 | 2. **Maintain Clear Communication**: Before proposing changes, ensure that your suggestions are clear and concise. This helps the programmer quickly understand the potential impact of each change.
6 |
7 | 3. **Prioritize Stability**: Avoid proposing changes that could introduce instability or significant complexity unless there is a clear benefit. Stability is often more valued than optimization in the early stages of development.
8 |
9 | 4. **Respect the Programmer's Preferences**: Pay attention to the programmer's coding style and preferences. Propose changes that enhance their style rather than contradict it.
10 |
11 | 5. **Incremental Improvements**: Suggest changes that offer incremental improvements rather than drastic overhauls, unless specifically requested. This approach is less disruptive and easier for the programmer to integrate.
12 |
13 | 6. **Consider Long-Term Maintenance**: Propose changes that improve code maintainability and readability. This includes refactoring for clarity, reducing redundancy, and enhancing documentation.
14 |
15 | 7. **Balance Proactivity and Reactivity**: Be proactive in suggesting improvements that are likely to be universally beneficial (e.g., bug fixes, performance enhancements). However, be reactive, not proactive, in areas where the programmer's specific intentions are unclear or where personal preference plays a significant role.
16 |
17 | For each potential change, return `True` if suggesting this change would be beneficial to the programmer, return `False` if the change does not align with the programmer's intentions or if they do not want you to predict this change. Give your decision after analyzing each change. Provide your response in the following format:
18 |
19 | ```
20 | **Analysis of change 1:**
21 |
22 | Your analysis here.
23 |
24 | **Decision:** `True` or `False`
25 |
26 | **Analysis of change 2:**
27 |
28 | Your analysis here.
29 |
30 | **Decision:** `True` or `False`
31 |
32 | ...
33 | ```"""
34 |
35 | GENJUDGE_RECORD_TYPE = {
36 | "current": """Current code:
37 | {current}
38 | """,
39 | "history": """Revised code changes:
40 | {history}
41 | """
42 | }
43 |
44 | GENJUDGE_PROMPT_INPUT = """{record}
45 | {change}"""
46 |
47 | GENJUDGE_PROMPT_INPUT_RECORD = """**record:**
48 | {record}"""
49 |
50 | GENJUDGE_PROMPT_INPUT_CHANGE = """**change {num}:**
51 | {change}"""
52 |
53 | GENJUDGE_PROMPT_OUTPUT = """{judgement}"""
54 |
55 | GENJUDGE_FEWSHOT = [
56 | {
57 | "record": [
58 | {
59 | "type": "history",
60 | "history": """```diff\n@@ -1 +1,5 @@\n+/**\n+ * magical invsqrt function from Quake III code\n+ * see: http://www.codemaestro.com/reviews/9\n+ */\n \n```"""
61 | },
62 | {
63 | "type": "current",
64 | "current": """```c\n1 /**\n2 * magical invsqrt function from Quake III code\n3 * see: http://www.codemaestro.com/reviews/9\n4 */\n5 \n```"""
65 | }
66 | ],
67 | "change": [
68 | """```diff\n@@ -3,3 +3,19 @@\n * see: http://www.codemaestro.com/reviews/9\n */\n \n+float InvSqrt(float x)\n+{\n+\tfloat xhalf = 0.5f*x;\n+\tint i = *(int*)&x;\n+\ti = 0x5f3759df - (i>>1);\n+\tx = *(float*)&i;\n+\tx = x*(1.5f-xhalf*x*x);\n+\treturn x;\n+}\n+\n+int main(void) {\n+\tint result = InvSqrt(0.00056);\n+\tprintf("Result: %d (should be 42)", result);\n+\treturn result != 42;\n+}\n+\n```"""
69 | ]
70 | },
71 | {
72 | "judgement": """**Analysis of change 1:**\n\nThe change introduces the `InvSqrt` function implementation and a basic test case within the `main` function. This change is appropriate and beneficial. The primary goal is to implement and test the `InvSqrt` function, a well-known algorithm from the Quake III code. Adding the function and a basic test case aligns directly with this goal. The implementation is clear and concise, following common C coding practices. The function's logic is straightforward and well-documented, making it easy to understand. Overall, the proposed change is a logical next step in developing the `InvSqrt` function and ensuring it works correctly.\n\n**Decision:** `True`"""
73 | },
74 | {
75 | "record": [
76 | {
77 | "type": "history",
78 | "history": """```diff\n@@ -1 +1,21 @@\n \n+package com.google.gwtjsonrpc.client;\n+\n+public class VoidResult_JsonSerializer extends JsonSerializer {\n+ public static final VoidResult_JsonSerializer INSTANCE =\n+ new VoidResult_JsonSerializer();\n+\n+ private VoidResult_JsonSerializer() {\n+ }\n+\n+ @Override\n+ public void printJson(final StringBuffer sb, final VoidResult o) {\n+ sb.append("{}");\n+ }\n+\n+ @Override\n+ public VoidResult fromJson(final Object o) {\n+ return VoidResult.INSTANCE;\n+ }\n+}\n+\n```"""
79 | },
80 | {
81 | "type": "current",
82 | "current": """```java\n1 \n2 \n3 package com.google.gwtjsonrpc.client;\n4 \n5 public class VoidResult_JsonSerializer extends JsonSerializer {\n6 public static final VoidResult_JsonSerializer INSTANCE =\n7 new VoidResult_JsonSerializer();\n8 \n9 private VoidResult_JsonSerializer() {\n10 }\n11 \n12 @Override\n13 public void printJson(final StringBuffer sb, final VoidResult o) {\n14 sb.append("{}");\n15 }\n16 \n17 @Override\n18 public VoidResult fromJson(final Object o) {\n19 return VoidResult.INSTANCE;\n20 }\n21 }\n22 \n```"""
83 | }
84 | ],
85 | "change": [
86 | """```diff@@ -1,3 +1,16 @@\n+// Copyright 2008 Google Inc.\n+//\n+// Licensed under the Apache License, Version 2.0 (the "License");\n+// you may not use this file except in compliance with the License.\n+// You may obtain a copy of the License at\n+//\n+// http://www.apache.org/licenses/LICENSE-2.0\n+//\n+// Unless required by applicable law or agreed to in writing, software\n+// distributed under the License is distributed on an "AS IS" BASIS,\n+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n+// See the License for the specific language governing permissions and\n+// limitations under the License.\n \n package com.google.gwtjsonrpc.client;\n \n```"""
87 | ]
88 | },
89 | {
90 | "judgement": """**Analysis of change 1:**\n\nWithout enough information to confirm the use of the Apache License, predicting this change may not align with the programmer's intentions or the project's actual licensing requirements.\n\n**Decision:** `False`"""
91 | }
92 | ]
93 |
--------------------------------------------------------------------------------
/gen/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | def extract_code_blocks(markdown):
4 | """
5 | Extracts code blocks from a given markdown string.
6 |
7 | Args:
8 | markdown (str): The markdown string to extract code blocks from.
9 |
10 | Returns:
11 | list: A list of code blocks extracted from the markdown string.
12 | """
13 | # Define a regular expression to match code blocks. re.DOTALL allows '.' to match all characters including newlines.
14 | code_block_pattern = re.compile(r'```(.*?)\n(.*?)```', re.DOTALL)
15 |
16 | # Store the extracted code blocks
17 | code_blocks = []
18 |
19 | # Use finditer to iterate over all matches of code blocks
20 | for block in code_block_pattern.finditer(markdown):
21 | # Get the content of the code block
22 | code_content = block.group(2)
23 |
24 | # Remove any additional indents (when the code block is in lists or other indented structures)
25 | lines = code_content.split('\n')
26 | code_blocks.append('\n'.join(line[len(lines[-1]):] for line in lines[:-1]))
27 |
28 | return code_blocks
29 |
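
A minimal usage sketch for the helper above (illustrative only; this example is not part of the original file):

```python
from gen.utils import extract_code_blocks

markdown = "Some text\n```python\nprint('hi')\n```\n"
print(extract_code_blocks(markdown))  # -> ["print('hi')"]
```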
--------------------------------------------------------------------------------
/generic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/generic/__init__.py
--------------------------------------------------------------------------------
/generic/special_tokens.py:
--------------------------------------------------------------------------------
1 | IM_START = "<|im_start|>"
2 | IM_END = "<|im_end|>"
3 | NEXT_START = "<|next_start|>"
4 | NEXT_END = "<|next_end|>"
5 | TARGET_START = "<|target_start|>"
6 | TARGET_END = "<|target_end|>"
7 | TARGET = "<|target|>"
8 | SEARCH_AND_REPLACE = "<|search_and_replace|>"
9 | SPECIAL_WORDS=[IM_START, IM_END, NEXT_START, NEXT_END, TARGET_START, TARGET_END, TARGET, SEARCH_AND_REPLACE]
10 |
--------------------------------------------------------------------------------
/generic/utils.py:
--------------------------------------------------------------------------------
1 | import difflib
2 | from .special_tokens import *
3 |
4 | def data_args(parser):
5 | """
6 | Adds data arguments to the given argument parser.
7 |
8 | Args:
9 | parser (argparse.ArgumentParser): The argument parser to add the arguments to.
10 |
11 | Returns:
12 | argparse.ArgumentParser: The updated argument parser.
13 | """
14 | parser.add_argument("--model_map", type=str, help="Model name, base and port")
15 | parser.add_argument("--input_path", type=str, help="Input Path")
16 | parser.add_argument("--output_path", type=str, help="Output Path")
17 | parser.add_argument("--num_proc", type=int, default=512, help="Number of processes")
18 | return parser
19 |
20 | def openai_args(parser):
21 | """
22 | Adds OpenAI arguments to the given argument parser.
23 |
24 | Args:
25 | parser (argparse.ArgumentParser): The argument parser to add the arguments to.
26 |
27 | Returns:
28 | argparse.ArgumentParser: The updated argument parser.
29 | """
30 | parser.add_argument("--max_tokens", type=int, default=3072, help="Max tokens")
31 | parser.add_argument("--temperature", type=float, default=0.2, help="Temperature")
32 | parser.add_argument("--top_p", type=float, default=0.95, help="Temperature")
33 | parser.add_argument("--frequency_penalty", type=float, default=0, help="Temperature")
34 | parser.add_argument("--presence_penalty", type=float, default=0, help="Temperature")
35 | return parser
36 |
37 | def get_openai_kwargs(args):
38 | """
39 | Get OpenAI keyword arguments from the given arguments.
40 |
41 | Args:
42 | args (argparse.Namespace): The parsed arguments.
43 |
44 | Returns:
45 | dict: The OpenAI keyword arguments.
46 | """
47 | openai_kwargs = {
48 | "max_tokens": args.max_tokens,
49 | "temperature": args.temperature,
50 | "top_p": args.top_p,
51 | "frequency_penalty": args.frequency_penalty,
52 | "presence_penalty": args.presence_penalty,
53 | }
54 | return openai_kwargs
55 |
56 | def decorate_code(code, lang="", use_line_num=False, start_line=None, end_line=None):
57 | """
58 | Decorates the given code with markdown syntax.
59 |
60 | Args:
61 | code (str): The code to be decorated.
62 | lang (str, optional): The language identifier for syntax highlighting. Defaults to "".
63 | use_line_num (bool, optional): Whether to include line numbers. Defaults to False.
64 | start_line (int, optional): The starting line number. Defaults to None.
65 | end_line (int, optional): The ending line number. Defaults to None.
66 |
67 | Returns:
68 | str: The decorated code.
69 |
70 | """
71 | decorate = "```"
72 | if lang:
73 | decorate += lang.lower()
74 | decorate += "\n"
75 | code_lines = code.split("\n")
76 | if start_line is None:
77 | start_line = 0
78 | if end_line is None:
79 | end_line = len(code_lines)
80 | if start_line != 0:
81 | decorate += "...\n"
82 | if use_line_num:
83 | decorate += "\n".join([f"{i + 1 + start_line} {line}" for i, line in enumerate(code_lines[start_line: end_line])])
84 | else:
85 | decorate += "\n".join(code_lines[start_line: end_line])
86 | if end_line != len(code_lines):
87 | decorate += "\n..."
88 | decorate += "\n```"
89 | return decorate
90 |
91 | def generate_locations_changes(code1, code2, lang, changes_lines):
92 | """
93 | Generates a string representing the changes between two versions of code with specified line ranges.
94 |
95 | Args:
96 | code1 (str): The original version of the code.
97 | code2 (str): The modified version of the code.
98 | lang (str): The programming language of the code (used for syntax highlighting).
99 | changes_lines (list of tuples): A list of tuples where each tuple contains two pairs of integers.
100 | Each tuple represents the line ranges in the format ((old_start, old_end), (new_start, new_end)).
101 |
102 | Returns:
103 | str: A formatted string that shows the changes between the two code versions with syntax highlighting.
104 | """
105 | code1_lines = code1.split('\n')
106 | code2_lines = code2.split('\n')
107 | locations_changes = []
108 | for c in changes_lines:
109 | (old_start, old_end), (new_start, new_end) = c
110 | next_code = '\n'.join(code2_lines[new_start: new_end])
111 | locations_changes.append(f"{old_start},{old_end}\n```{lang}\n{next_code}\n```")
112 | return "\n".join(locations_changes)
113 |
114 | def generate_search_and_replace(code1, code2, lang, changes_lines, sep_token=SEARCH_AND_REPLACE):
115 | """
116 | Generates a search and replace string for code changes between two versions of code.
117 |
118 | Args:
119 | code1 (str): The original version of the code.
120 | code2 (str): The modified version of the code.
121 | lang (str): The programming language of the code (used for syntax highlighting).
122 | changes_lines (list of tuples): A list of tuples where each tuple contains two pairs of integers.
123 | Each pair represents the start and end line numbers of the changes in the original and modified code respectively.
124 | sep_token (str, optional): The token used to separate the search and replace sections in the output. Defaults to SEARCH_AND_REPLACE.
125 |
126 | Returns:
127 | str: A formatted string containing the search and replace sections for the specified changes,
128 | with syntax highlighting for the specified programming language.
129 | """
130 | code1_lines = code1.split('\n')
131 | code2_lines = code2.split('\n')
132 | search_and_replace = []
133 | for i in range(len(changes_lines)):
134 | (old_start, old_end), (new_start, new_end) = changes_lines[i]
135 | search = '\n'.join(code1_lines[old_start: old_end])
136 | replace = '\n'.join(code2_lines[new_start: new_end])
137 | search_and_replace.append(f"```{lang}\n{search}\n{sep_token}\n{replace}\n```")
138 | return "\n".join(search_and_replace)
139 |
140 | def extract_changes_lines(code1, code2, fromfile='', tofile='', unique=False, merge_changes=False, merge_threshold=2):
141 | """
142 | Extracts the lines that have changed between two versions of code.
143 | Args:
144 | code1 (str): The original version of the code.
145 | code2 (str): The modified version of the code.
146 | fromfile (str, optional): The name of the original file. Defaults to ''.
147 | tofile (str, optional): The name of the modified file. Defaults to ''.
148 | unique (bool, optional): If True, ensures that the changes are unique. Defaults to False.
149 | merge_changes (bool, optional): If True, merges neighboring changes. Defaults to False.
150 | merge_threshold (int, optional): The threshold for merging neighboring changes. Defaults to 2.
151 | Returns:
152 | list: A list of tuples, where each tuple contains two tuples representing the start and end lines of changes
153 | in the original and modified code respectively.
154 | """
155 | # Split the code into lines
156 | lines1 = code1.split('\n')
157 | lines2 = code2.split('\n')
158 |
159 | # Create a unified diff
160 | diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile, n=0, lineterm=''))[2:]
161 |
162 | # Process the diff to extract changes and their contexts from both files
163 | changes_lines = []
164 | current_old_start_line = None
165 | current_new_start_line = None
166 | current_old_end_line = 0
167 | current_new_end_line = 0
168 |
169 | for line in diff:
170 | if line.startswith('@@'):
171 | # Start of a new change block, flush the previous block if it exists
172 | if current_old_start_line or current_new_start_line:
173 | if current_old_end_line > current_old_start_line:
174 | current_old_start_line -= 1
175 | current_old_end_line -= 1
176 | if current_new_end_line > current_new_start_line:
177 | current_new_start_line -= 1
178 | current_new_end_line -= 1
179 | changes_lines.append(((current_old_start_line, current_old_end_line),
180 | (current_new_start_line, current_new_end_line)))
181 |
182 | # Extract line numbers from the hunk information
183 | parts = line.split()
184 | old_info, new_info = parts[1], parts[2]
185 | current_old_start_line = int(old_info.split(',')[0][1:])
186 | current_new_start_line = int(new_info.split(',')[0][1:])
187 | current_old_end_line = current_old_start_line
188 | current_new_end_line = current_new_start_line
189 | elif line.startswith('-'):
190 | # Line removed from the old file
191 | current_old_end_line += 1
192 | elif line.startswith('+'):
193 | # Line added to the new file
194 | current_new_end_line += 1
195 | elif line.startswith(' '):
196 | assert False, "No unchanged lines should appear in a unified diff generated with zero context lines"
197 |
198 | # Append the last block if it exists
199 | if current_old_start_line or current_new_start_line:
200 | if current_old_end_line > current_old_start_line:
201 | current_old_start_line -= 1
202 | current_old_end_line -= 1
203 | if current_new_end_line > current_new_start_line:
204 | current_new_start_line -= 1
205 | current_new_end_line -= 1
206 | changes_lines.append(((current_old_start_line, current_old_end_line),
207 | (current_new_start_line, current_new_end_line)))
208 | changes_lines = filter_changes(changes_lines, lines1, lines2)
209 | if unique:
210 | changes_lines = unique_changes(changes_lines, lines1)
211 | if merge_changes:
212 | changes_lines = merge_neighbor_change(changes_lines, merge_threshold)
213 | return changes_lines
214 |
215 | def filter_changes(changes_lines, lines1, lines2):
216 | """
217 | Filters out changes that are only due to differences in indentation or whitespace.
218 |
219 | Args:
220 | changes_lines (list of tuples): A list of tuples where each tuple contains two tuples.
221 | The first tuple represents the start and end indices of the change in `lines1`.
222 | The second tuple represents the start and end indices of the change in `lines2`.
223 | lines1 (list of str): The original list of lines.
224 | lines2 (list of str): The modified list of lines.
225 |
226 | Returns:
227 | list of tuples: A list of tuples containing the changes that are not purely due to differences in indentation or whitespace.
228 | """
229 | filtered_changes_lines = []
230 | if len(changes_lines) == 0:
231 | return filtered_changes_lines
232 | for old_change, new_change in changes_lines:
233 | old_start, old_end = old_change
234 | new_start, new_end = new_change
235 | indent1_max = max(len(line) - len(line.lstrip()) for line in lines1)
236 | indent2_max = max(len(line) - len(line.lstrip()) for line in lines2)
237 | before = " ".join(" ".join(lines1[old_start: old_end]).split())
238 | after = " ".join(" ".join(lines2[new_start: new_end]).split())
239 | if before != after or indent1_max != indent2_max:
240 | filtered_changes_lines.append((old_change, new_change))
241 | return filtered_changes_lines
242 |
243 | def unique_changes(changes_lines, lines1):
244 | """
245 | Identifies unique changes in a list of changes and expands them based on the context of the original lines.
246 |
247 | Args:
248 | changes_lines (list of tuples): A list of tuples where each tuple contains two sub-tuples.
249 | The first sub-tuple represents the old change with start and end indices,
250 | and the second sub-tuple represents the new change with start and end indices.
251 | lines1 (list): A list of lines from the original content.
252 |
253 | Returns:
254 | list of tuples: A list of tuples where each tuple contains two sub-tuples.
255 | The first sub-tuple represents the expanded old change with updated start and end indices,
256 | and the second sub-tuple represents the expanded new change with updated start and end indices.
257 |
258 | Raises:
259 | AssertionError: If the left or right expansion values are negative, indicating an invalid expansion.
260 | """
261 | unique_changes_lines = []
262 | if len(changes_lines) == 0:
263 | return unique_changes_lines
264 | for old_change, new_change in changes_lines:
265 | left_expansion, right_expansion = find_unique_sublist(lines1, old_change[0], old_change[1])
266 | assert left_expansion >= 0 and right_expansion >= 0, "Invalid expansion"
267 | unique_changes_lines.append(((old_change[0] - left_expansion, old_change[1] + right_expansion), (new_change[0] - left_expansion, new_change[1] + right_expansion)))
268 | return unique_changes_lines
269 |
270 | # TODO: accelerate
271 | def find_unique_sublist(b, a1, a2):
272 | """
273 | Finds the smallest extension of the sublist `b[a1:a2]` that is unique within the list `b`.
274 |
275 | Args:
276 | b (list): The list in which to find the unique sublist.
277 | a1 (int): The starting index of the sublist.
278 | a2 (int): The ending index of the sublist.
279 |
280 | Returns:
281 | tuple: A tuple (x, y) where `x` is the minimum number of elements to extend the sublist
282 | at the beginning, and `y` is the minimum number of elements to extend the sublist
283 | at the end to make it unique within `b`. If no unique sublist is found, returns (-1, -1).
284 | """
285 | if not "\n".join(b).strip():
286 | return 0, 0
287 |
288 | def is_unique_sublist(b, sublist):
289 | """
290 | Check if a sublist appears exactly once in a list.
291 |
292 | Args:
293 | b (list): The main list in which to search for the sublist.
294 | sublist (list): The sublist to search for within the main list.
295 |
296 | Returns:
297 | bool: True if the sublist appears exactly once in the main list, False otherwise.
298 |
299 | Raises:
300 | AssertionError: If the sublist appears more than once in the main list.
301 | """
302 | if not "\n".join(sublist).strip():
303 | return False
304 | count = 0
305 | for i in range(len(b) - len(sublist) + 1):
306 | if b[i:i + len(sublist)] == sublist:
307 | count += 1
308 | if count > 1:
309 | return False
310 | assert count == 1, "Invalid count"
311 | return count == 1
312 |
313 | for extend_length in range(len(b) - a2 + a1 + 1):
314 | for window in range(extend_length + 1):
315 | current_sublist = b[a1 - extend_length + window:a2 + window]
316 | if is_unique_sublist(b, current_sublist):
317 | return min(a1, extend_length - window), min(len(b) - a2, window)
318 | return -1, -1
319 |
320 | def merge_neighbor_change(changes_lines, merge_threshold=2):
321 | """
322 | Merges neighboring changes in a list of change line ranges based on a specified threshold.
323 |
324 | Args:
325 | changes_lines (list of tuples): A list of tuples where each tuple contains two tuples representing
326 | the old and new change line ranges respectively.
327 | Example: [((old_start, old_end), (new_start, new_end)), ...]
328 | merge_threshold (int, optional): The maximum allowed distance between neighboring changes to be merged.
329 | Defaults to 2.
330 |
331 | Returns:
332 | list of tuples: A list of merged change line ranges.
333 | """
334 | merged_changes_lines = []
335 | if len(changes_lines) == 0:
336 | return merged_changes_lines
337 | for i, (old_change, new_change) in enumerate(changes_lines):
338 | if i == 0:
339 | merged_changes_lines.append((old_change, new_change))
340 | else:
341 | last_old_change, last_new_change = merged_changes_lines[-1]
342 | if old_change[0] - last_old_change[1] <= merge_threshold and new_change[0] - last_new_change[1] <= merge_threshold:
343 | merged_changes_lines[-1] = (min(last_old_change[0], old_change[0]), max(last_old_change[1], old_change[1])), (min(last_new_change[0], new_change[0]), max(last_new_change[1], new_change[1]))
344 | else:
345 | merged_changes_lines.append((old_change, new_change))
346 | return merged_changes_lines
347 |
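
A minimal end-to-end sketch of the diff helpers above (illustrative only; not part of the original file). It extracts the changed line ranges between two code versions and renders them in the search-and-replace format:

```python
from generic.utils import extract_changes_lines, generate_search_and_replace

code1 = "def add(a, b):\n    return a + b\n"
code2 = "def add(a, b):\n    # add two numbers\n    return a + b\n"

# Unique, merged change ranges between the two versions
changes = extract_changes_lines(code1, code2, unique=True, merge_changes=True)
print(generate_search_and_replace(code1, code2, "python", changes))
```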
--------------------------------------------------------------------------------
/model_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "TechxGenus/CursorCore-Yi-1.5B": {
3 | "base": "http://127.0.0.1:10086/v1",
4 | "api": "sk-xxx"
5 | }
6 | }
--------------------------------------------------------------------------------
/pictures/APEval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/pictures/APEval.png
--------------------------------------------------------------------------------
/pictures/CursorWeb.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/pictures/CursorWeb.gif
--------------------------------------------------------------------------------
/pictures/EvalPlus_CanItEdit_OctoPack.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/pictures/EvalPlus_CanItEdit_OctoPack.png
--------------------------------------------------------------------------------
/pictures/conversation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/pictures/conversation.png
--------------------------------------------------------------------------------
/pictures/cursorcore.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/pictures/cursorcore.png
--------------------------------------------------------------------------------
/pictures/discord.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/pictures/discord.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # data
2 | tqdm
3 | pandas
4 | datasets
5 | rouge_score
6 |
7 | # gen
8 | sglang[all]
9 | openai
10 |
11 | # src
12 | Levenshtein
13 |
14 | # train
15 | numba
16 | numpy
17 | ninja
18 | torch
19 | packaging
20 | transformers
21 | liger-kernel
22 | tensorboard
23 | accelerate
24 | deepspeed<=0.14.5
25 |
26 | # eval
27 | vllm
28 | evaluate
29 | evalplus
30 |
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | # Programming-Instruct
2 |
3 | This folder contains the code for our data collection pipeline.
4 |
5 | ## Inference service for LLMs
6 |
7 | We generate data by prompting different LLMs. The OpenAI API has become the de facto standard interface, and almost all inference frameworks provide OpenAI-compatible servers, so we uniformly use the OpenAI interface for generation.
8 |
9 | Example script to deploy `deepseek-coder-6.7b-instruct` using `sglang`:
10 |
11 | ```bash
12 | python -m sglang.launch_server --model-path deepseek-ai/deepseek-coder-6.7b-instruct --port 10086
13 | ```
14 |
15 | Example script to deploy `deepseek-coder-6.7b-instruct` using `vllm`:
16 |
17 | ```bash
18 | python -m vllm.entrypoints.openai.api_server --port 10086 --model deepseek-ai/deepseek-coder-6.7b-instruct --enable-prefix-caching
19 | ```
20 |
21 | We define the model inference service parameters in `model_map.json`. An example configuration is as follows:
22 |
23 | ```json
24 | {
25 | "deepseek-ai/deepseek-coder-6.7b-instruct": {
26 | "base": "http://127.0.0.1:10086/v1",
27 | "api": "sk-xxx"
28 | }
29 | }
30 | ```
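
Each entry maps a model name to an OpenAI-compatible base URL and API key. As a rough illustration (not part of the pipeline code, and assuming the official `openai` Python client, version >= 1.0), a server configured this way can be queried like so:

```python
# Minimal sketch under the above assumptions: load model_map.json and send a
# single test request to the first configured model.
import json
from openai import OpenAI

with open("model_map.json") as f:
    model_map = json.load(f)

model_name, cfg = next(iter(model_map.items()))  # pick any configured model
client = OpenAI(base_url=cfg["base"], api_key=cfg["api"])

response = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": "Write a hello world program in Python."}],
    max_tokens=128,
)
print(response.choices[0].message.content)
```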
31 |
32 | ## AI programmer
33 |
34 | For each code snippet, we use LLMs to generate the corresponding coding history. The input file is a JSON list of code snippets, for example:
35 |
36 | ```json
37 | [
38 | "int i...",
39 | "import json...",
40 | "func main...",
41 | ...
42 | ]
43 | ```
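
If such a file does not exist yet, one possible way to build it (purely illustrative; the `raw_code` directory and the `*.py` filter below are hypothetical) is to collect source files from a local directory:

```python
# Minimal sketch, assuming a hypothetical raw_code/ directory of source files.
import json
from pathlib import Path

snippets = [path.read_text(encoding="utf-8") for path in Path("raw_code").rglob("*.py")]

with open("data/code_snippets.json", "w") as f:
    json.dump(snippets, f, indent=4)
```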
44 |
45 | The command to run the code is:
46 |
47 | ```bash
48 | python src/aiprogrammer.py --model_map model_map.json --input_path data/code_snippets.json --output_path data/aiprogrammer.json
49 | ```
50 |
51 | ## Data collection
52 |
53 | Run the following scripts to generate data from various data sources:
54 |
55 | ```bash
56 | # AIprogrammer
57 | python src/data_collection.py --model_map model_map.json --data_type aiprogrammer --input_path data/aiprogrammer.json --output_path data/aiprogrammer_end.json
58 |
59 | # Git Commit
60 | python src/data_collection.py --model_map model_map.json --data_type commit --input_path data/commit.json --output_path data/commit_end.json --limit_one_block_prob 1.0
61 |
62 | # Online Submit
63 | python src/data_collection.py --model_map model_map.json --data_type submit --input_path data/submit_process.json --output_path data/submit_end.json
64 | ```
65 |
66 | ## Synthetic target area
67 |
68 | Programmers often specify the parts requiring changes, typically in one of two ways: either by clicking with the cursor to indicate a general area or by selecting a specific text range with defined start and end points.
69 |
70 | We synthesize the target modification area with a random algorithm:
71 |
72 | ```bash
73 | python src/post_collection.py --input_path data/tmp.json --output_path data/tmp_area.json
74 | ```
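
Based on `src/post_collection.py`, each processed sample gains a `loc` field that is either a single character offset (a cursor), a `[start, end]` character range, or `null` when no location is attached. A small sketch for inspecting the output produced by the command above:

```python
# Minimal sketch: print the synthesized target area of the first few samples.
import json

with open("data/tmp_area.json") as f:
    data = json.load(f)

for sample in data[:5]:
    loc = sample["loc"]
    if loc is None:
        print("no target area")
    elif isinstance(loc, int):
        print(f"cursor at character offset {loc}")
    else:  # a [start, end] range, serialized from the (left, right) tuple
        start, end = loc
        print(f"selected range: characters {start}..{end}")
```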
75 |
--------------------------------------------------------------------------------
/src/aiprogrammer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | current_path = os.path.abspath(__file__)
4 | parent_directory = os.path.dirname(os.path.dirname(current_path))
5 | sys.path.append(parent_directory)
6 |
7 | import json
8 | import argparse
9 | from gen import AIProgrammer
10 | from generic.utils import data_args, openai_args, get_openai_kwargs, decorate_code
11 |
12 | parser = argparse.ArgumentParser()
13 | parser = data_args(parser)
14 | parser = openai_args(parser)
15 | args = parser.parse_args()
16 | openai_kwargs = get_openai_kwargs(args)
17 |
18 | with open(args.model_map, 'r') as f:
19 | model_map = json.load(f)
20 |
21 | with open(args.input_path, 'r') as f:
22 | inputs = json.load(f)
23 |
24 | inputs = [{'code': decorate_code(i), "content": i} for i in inputs]
25 |
26 | Gen = AIProgrammer(
27 | model_map=model_map, num_proc=args.num_proc, **openai_kwargs
28 | )
29 |
30 | results = Gen.gen(inputs)
31 | results = [result for result in results if result["output"]]
32 |
33 | with open(args.output_path, 'w') as f:
34 | json.dump(results, f, indent=4)
35 |
--------------------------------------------------------------------------------
/src/data_collection.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | current_path = os.path.abspath(__file__)
4 | parent_directory = os.path.dirname(os.path.dirname(current_path))
5 | sys.path.append(parent_directory)
6 |
7 | import json
8 | import random
9 | import argparse
10 | import concurrent.futures
11 | from tqdm import tqdm
12 | from gen import GenJudgement, GenInstruction, GenChat
13 | from generic.utils import data_args, openai_args, get_openai_kwargs, decorate_code
14 | from utils import generate_diff, random_edit_series, generate_diff_blocks, apply_selected_blocks
15 |
16 | def generate_record_and_change(history, current, final):
17 | """
18 | Generates a record of changes between historical code segments, the current code segment,
19 | and the final code segment.
20 |
21 | Args:
22 | history (list): A list of dictionaries representing historical code segments. Each dictionary
23 | should have a 'type' key with the value 'code' and a 'code' key with the code
24 | segment as its value.
25 | current (dict): A dictionary representing the current code segment. It should have a 'type' key
26 | with the value 'code', a 'code' key with the code segment, and a 'lang' key
27 | indicating the programming language.
28 | final (str): The final code segment as a string.
29 |
30 | Returns:
31 | tuple: A tuple containing:
32 | - blocks (list): A list of blocks representing the differences between the current code
33 | and the final code.
34 | - current_record (list): A list of dictionaries representing the historical changes and
35 | the current code segment.
36 | - future_change (list): A list of decorated code segments representing the future changes
37 | needed to reach the final code.
38 | """
39 | current_record = []
40 | future_change = []
41 |
42 | history_current = history + [current]
43 | for previous, now in zip(history_current[:-1], history_current[1:]):
44 | if previous['type'] == 'code' and now['type'] == 'code':
45 | diff = generate_diff(previous['code'], now['code'])
46 | current_record.append({"type": "history", "history": decorate_code(diff, "diff")})
47 | else:
48 | raise ValueError("History should only contain code segments now.")
49 |
50 | current_record.append({"type": "current", "current": decorate_code(current['code'], current['lang'], use_line_num=True)})
51 | blocks = []
52 | if current['code'] != final:
53 | blocks = generate_diff_blocks(current['code'], final)
54 | for i in range(len(blocks)):
55 | code = apply_selected_blocks(current['code'], blocks, [i])
56 | future_change.append(decorate_code(generate_diff(current['code'], code), "diff"))
57 | return blocks, current_record, future_change
58 |
59 | parser = argparse.ArgumentParser()
60 | parser = data_args(parser)
61 | parser = openai_args(parser)
62 | parser.add_argument("--alpha", type=float, default=1.0, help="Sampling parameter for history")
63 | parser.add_argument("--data_type", type=str, help="Data Type (aiprogrammer, commit, submit)")
64 | parser.add_argument("--max_per_sample", type=int, default=1, help="Maximum number of samples per input")
65 | parser.add_argument("--use_history_prob", type=float, default=0.6, help="Probability of using history")
66 | parser.add_argument("--use_user_prob", type=float, default=0.5, help="Probability of using user")
67 | parser.add_argument("--limit_one_block_prob", type=float, default=0.5, help="Probability of limiting one block during History")
68 | args = parser.parse_args()
69 | openai_kwargs = get_openai_kwargs(args)
70 |
71 | with open(args.model_map, 'r') as f:
72 | model_map = json.load(f)
73 |
74 | with open(args.input_path, 'r') as f:
75 | inputs = json.load(f)
76 |
77 | random.shuffle(inputs)
78 |
79 | GenJudge = GenJudgement(
80 | model_map=model_map, num_proc=args.num_proc, **openai_kwargs
81 | )
82 |
83 | GenInst = GenInstruction(
84 | model_map=model_map, num_proc=args.num_proc, **openai_kwargs
85 | )
86 |
87 | GenCht = GenChat(
88 | model_map=model_map, num_proc=args.num_proc, **openai_kwargs
89 | )
90 |
91 | def process(sample):
92 | """
93 | Processes a given sample based on the specified data type and generates a series of code transformations.
94 |
95 | Args:
96 | sample (dict): A dictionary containing the sample data. The structure of the sample depends on the data type:
97 | - For 'aiprogrammer': Must contain 'output'.
98 | - For 'commit': Must contain 'code1', 'code2' and 'git'.
99 | - For 'submit': Must contain 'submissions', 'language', and 'problems_contexts'.
100 |
101 | Returns:
102 | list: A list of dictionaries, each representing a step in the code transformation process. Each dictionary contains:
103 | - 'history': The history of code transformations up to the current step.
104 | - 'current': The current code block being processed.
105 | - 'user': The user input or instruction for the current step.
106 | - 'chat': The chat or commentary generated for the current step.
107 | - 'next': The next code block after applying the transformation.
108 |
109 | Raises:
110 | ValueError: If the data type specified in args.data_type is not supported.
111 | """
112 | if args.data_type == 'aiprogrammer':
113 | sample['lang'] = ""
114 | code_series = []
115 | for code1, code2 in zip(sample['output'][:-1], sample['output'][1:]):
116 | if random.random() < args.limit_one_block_prob:
117 | code_series += random_edit_series(code1, code2)[:-1]
118 | else:
119 | code_series += [code1]
120 | final = sample['output'][-1]
121 | elif args.data_type == 'commit':
122 | if random.random() < args.limit_one_block_prob:
123 | code_series = random_edit_series(sample['code1'], sample['code2'])[:-1]
124 | else:
125 | code_series = [sample['code1']]
126 | final = sample['code2']
127 | elif args.data_type == 'submit':
128 | sample['lang'] = sample['language'].lower()
129 | code_series = []
130 | for code1, code2 in zip(sample['submissions'][:-1], sample['submissions'][1:]):
131 | if random.random() < args.limit_one_block_prob:
132 | code_series += random_edit_series(code1, code2)[:-1]
133 | else:
134 | code_series += [code1]
135 | final = sample['submissions'][-1]
136 | else:
137 | raise ValueError("Data Type {} is not supported.".format(args.data_type))
138 | histories = [{"type": "code", 'lang': sample['lang'], "code": code} for code in code_series]
139 | candidates = list(range(len(histories)))
140 | result = []
141 | candidate_num = 0
142 | while len(candidates) > 0 and candidate_num < args.max_per_sample:
143 | weights = [args.alpha ** c + 1e-6 for c in candidates]
144 | candidate = random.choices(candidates, weights=weights, k=1)[0]
145 | candidates = [c for c in candidates if c != candidate]
146 | current = histories[candidate]
147 | use_history = random.random() < args.use_history_prob
148 | use_user = random.random() < args.use_user_prob
149 | if candidate == 0 and current['code'].strip() == "" and not use_user:
150 | continue
151 | if use_history:
152 | history = histories[:candidate]
153 | if len(history) == 0:
154 | continue
155 | else:
156 | history = []
157 | blocks, current_record, future_change = generate_record_and_change(history, current, final)
158 | if len(future_change) == 0:
159 | continue
160 | if use_user:
161 | if args.data_type == 'aiprogrammer':
162 | current_record_with_auxiliary = current_record
163 | elif args.data_type == 'commit':
164 | current_record_with_auxiliary = current_record + [{"type": "git", "git": sample['git']}]
165 | elif args.data_type == 'submit':
166 | current_record_with_auxiliary = current_record + [{"type": "problem", "problem": sample['problems_contexts']}]
167 | if candidate == 0 and args.data_type == 'commit':
168 | user = sample['git']
169 | elif candidate == 0 and args.data_type == 'submit':
170 | user = "Please generate a correct {} program for the following problem:\n\n{}".format(sample['lang'], sample['problems_contexts'])
171 | else:
172 | inst = GenInst.gen([{"record": current_record_with_auxiliary, "change": decorate_code(generate_diff(current['code'], final), "diff")}])
173 | if inst[0]["output"] is None:
174 | continue
175 | user = inst[0]["output"]
176 | next_ = {"type": "code", "lang": current['lang'], "code": final}
177 | current_record += [{"type": "user", "user": user}]
178 | else:
179 | judge = GenJudge.gen([{"record": current_record, "change": future_change}])
180 | if judge[0]["output"] is None:
181 | continue
182 | selected_indices = [index for index, value in enumerate(judge[0]["output"]) if value]
183 | user = ""
184 | next_ = {"type": "code", "lang": current['lang'], "code": apply_selected_blocks(current['code'], blocks, selected_indices)}
185 | next_change = decorate_code(generate_diff(current['code'], next_["code"]), "diff")
186 | chat = GenCht.gen([{"record": current_record, "change": next_change}])
187 | if chat[0]["output"] is None:
188 | chat = ""
189 | else:
190 | chat = chat[0]["output"]
191 | result.append({"history": history, "current": current, "user": user, "chat": chat, "next": next_})
192 | candidate_num += 1
193 | return result
194 |
195 | with concurrent.futures.ThreadPoolExecutor(args.num_proc) as executor:
196 | futures = [executor.submit(process, item) for item in inputs]
197 | results = []
198 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
199 | results.append(future.result())
200 |
201 | results = [r for result in results for r in result]
202 | with open(args.output_path, 'w') as f:
203 | json.dump(list(results), f, indent=4)
204 |
--------------------------------------------------------------------------------
/src/merge_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 |
4 | def merge_json_files(output_file, *input_files):
5 | """
6 | Merge multiple JSON files into a single file.
7 |
8 | Args:
9 | output_file (str): The path to the output file where the merged data will be written.
10 | *input_files (str): Variable number of input file paths to be merged.
11 |
12 | Raises:
13 | FileNotFoundError: If any of the input files are not found.
14 | JSONDecodeError: If any of the input files cannot be parsed as JSON.
15 |
16 | Returns:
17 | None
18 | """
19 | merged_data = []
20 | for file in input_files:
21 | try:
22 | with open(file, 'r') as f:
23 | data = json.load(f)
24 | if isinstance(data, list):
25 | merged_data.extend(data)
26 | else:
27 | print(f"Warning: {file} does not contain a list. Skipping.")
28 | except FileNotFoundError as e:
29 | print(f"Error reading {file}: {e}")
30 | except json.JSONDecodeError as e:
31 | print(f"Error parsing {file} as JSON: {e}")
32 | try:
33 | with open(output_file, 'w') as f:
34 | json.dump(merged_data, f, indent=4)
35 | print(f"Merged data written to {output_file}")
36 | except Exception as e:
37 | print(f"Error writing to {output_file}: {e}")
38 |
39 | parser = argparse.ArgumentParser(description='Merge multiple JSON files into one.')
40 | parser.add_argument('output', type=str, help='The output file to write the merged data to.')
41 | parser.add_argument('inputs', nargs='+', type=str, help='The input JSON files to merge.')
42 |
43 | args = parser.parse_args()
44 | merge_json_files(args.output, *args.inputs)
45 |
--------------------------------------------------------------------------------
/src/post_collection.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import argparse
4 | import Levenshtein
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument("--input_path", type=str, help="Input Path")
8 | parser.add_argument("--output_path", type=str, help="Output Path")
9 | parser.add_argument("--use_loc_prob", type=float, default=0.75, help="Probability of using LOC")
10 | parser.add_argument("--loc_cursor_prob", type=float, default=0.5, help="Probability of using cursor, using range otherwise")
11 | parser.add_argument("--loc_noisy_prob", type=float, default=0.5, help="Probability of using noisy range")
12 | parser.add_argument("--loc_noisy_length", type=int, default=32, help="Length of noisy range")
13 | parser.add_argument("--loc_all_noisy_prob", type=float, default=0.05, help="Probability of using noisy range for all")
14 | args = parser.parse_args()
15 |
16 | def get_changes(s1, s2):
17 | """
18 | Compute the changes needed to transform string s1 into string s2 using Levenshtein distance.
19 |
20 | Args:
21 | s1 (str): The original string.
22 | s2 (str): The target string.
23 |
24 | Returns:
25 | list of tuple: A list of tuples where each tuple contains the start and end indices of the changes in s1.
26 | Each tuple represents a range of indices in s1 that need to be changed to match s2.
27 | """
28 | edit_ops = Levenshtein.opcodes(s1, s2)
29 | changes = []
30 | for op, i1, i2, _, _ in edit_ops:
31 | if op != 'equal':
32 | if changes and changes[-1][1] == i1:
33 | changes[-1] = (changes[-1][0], i2)
34 | else:
35 | changes.append((i1, i2))
36 | return changes
37 |
38 | def sample_from_boundaries(left, right, max_length, noisy_length=0):
39 | """
40 | Samples a start and end point within specified boundaries, with optional noise.
41 |
42 | Args:
43 | left (int): The left boundary for sampling the start point.
44 | right (int): The right boundary for sampling the end point.
45 | max_length (int): The maximum allowable length for the end point.
46 | noisy_length (int, optional): The amount of noise to add to the boundaries. Defaults to 0.
47 |
48 | Returns:
49 | tuple: A tuple containing the sampled start and end points (start, end).
50 | """
51 | start = random.choice(list(range(max(0, left - noisy_length), left + 1)))
52 | end = random.choice(list(range(right, min(right + noisy_length, max_length) + 1)))
53 | return start, end
54 |
55 | def sample_from_ranges(ranges):
56 | """
57 | Selects a random number from a set of ranges.
58 |
59 | Args:
60 | ranges (list of tuple): A list of tuples where each tuple contains two integers (start, end)
61 | representing the inclusive range from start to end.
62 |
63 | Returns:
64 | int: A randomly selected number from the combined ranges.
65 |
66 | Raises:
67 | ValueError: If the ranges list is empty or if no numbers can be generated from the given ranges.
68 | """
69 | numbers = set()
70 | for start, end in ranges:
71 | numbers.update(range(start, end + 1))
72 | return random.choice(list(numbers))
73 |
74 | def post_current(current, next_):
75 | """
76 | Determines the cursor position or range based on the changes between the current and next states.
77 |
78 | Args:
79 | current (str): The current state.
80 | next_ (str): The next state.
81 |
82 | Returns:
83 | Union[int, Tuple[int, int], None]:
84 | - An integer representing a single cursor position.
85 | - A tuple of two integers representing a range (left, right).
86 | - None if no cursor position or range is determined.
87 | """
88 | if random.random() < args.use_loc_prob:
89 | changes = get_changes(current, next_)
90 | if not changes:
91 | changes = [(0, len(current))]
92 | if random.random() < args.loc_cursor_prob:
93 | if random.random() < args.loc_all_noisy_prob:
94 | cursor = sample_from_ranges([(0, len(current))])
95 | else:
96 | if random.random() < args.loc_noisy_prob:
97 | cursor = sample_from_ranges([random.choice(changes)])
98 | else:
99 | cursor = random.choice(random.choice(changes))
100 | return cursor
101 | else:
102 | if random.random() < args.loc_all_noisy_prob:
103 | left, right = sample_from_ranges([(0, len(current))]), sample_from_ranges([(0, len(current))])
104 | if left > right:
105 | left, right = right, left
106 | else:
107 | if random.random() < args.loc_noisy_prob:
108 | noisy_length = args.loc_noisy_length
109 | else:
110 | noisy_length = 0
111 | left, right = sample_from_boundaries(changes[0][0], changes[-1][1], len(current), noisy_length)
112 | return left, right
113 | else:
114 | return None
115 |
116 | with open(args.input_path, "r") as f:
117 | data = json.load(f)
118 |
119 | for sample in data:
120 | loc = post_current(sample["current"]["code"], sample["next"]["code"])
121 | # TODO: Unify variable names
122 | sample["loc"] = loc
123 |
124 | with open(args.output_path, "w") as f:
125 | json.dump(data, f, indent=4)
126 |
--------------------------------------------------------------------------------
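A minimal sketch (with assumed inputs) of how `get_changes` in src/post_collection.py produces the spans that `post_current` samples cursor positions and ranges from:

```python
# Illustrative only: assumes the functions defined in src/post_collection.py are in scope.
current_code = "def add(a, b):\n    return a - b\n"
next_code = "def add(a, b):\n    return a + b\n"

changes = get_changes(current_code, next_code)
print(changes)  # e.g. [(28, 29)]: the span of the '-' that becomes '+'

# post_current(current_code, next_code) would then return either a single cursor
# index sampled in or around that span, a (left, right) range, or None,
# depending on the probabilities passed on the command line.
```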
/src/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import difflib
3 | import Levenshtein
4 |
5 | def generate_diff(code1, code2, fromfile='', tofile='', n=3):
6 | """
7 | Generate a unified diff between two code strings.
8 |
9 | Args:
10 | code1 (str): The first code string.
11 | code2 (str): The second code string.
12 | fromfile (str, optional): The name of the first file. Defaults to ''.
13 | tofile (str, optional): The name of the second file. Defaults to ''.
14 |         n (int, optional): The number of context lines to include in the diff. Defaults to 3.
15 | Returns:
16 | str: The unified diff as a string.
17 | """
18 | # Split the code strings into lines
19 | lines1 = code1.split('\n')
20 | lines2 = code2.split('\n')
21 |
22 | # Generate unified diff
23 | diff = difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile, n=n, lineterm='')
24 | output = ""
25 |
26 | for line in diff:
27 | output += line + "\n"
28 |
29 | if fromfile == '' and tofile == '':
30 | output = "\n".join(output.split("\n")[2:])
31 |
32 | return output[:-1]
33 |
34 | def if_continuous_modify(code1, code2, code3):
35 | """
36 | Check if code3 is a continuous modification of code1 and code2.
37 |
38 | Args:
39 | code1 (str): The first code string.
40 | code2 (str): The second code string.
41 | code3 (str): The third code string.
42 |
43 | Returns:
44 | bool: True if code3 is a continuous modification of code1 and code2, False otherwise.
45 | """
46 | # Calculate the Levenshtein distance between code1 and code2
47 | dist1 = Levenshtein.distance(code1, code2)
48 | # Calculate the Levenshtein distance between code2 and code3
49 | dist2 = Levenshtein.distance(code2, code3)
50 | # Calculate the Levenshtein distance between code1 and code3
51 | dist3 = Levenshtein.distance(code1, code3)
52 |
53 | # Check if code3 is a continuous modification of code1 and code2
54 | if dist3 == dist1 + dist2:
55 | return True
56 | else:
57 | return False
58 |
59 | def blockwise_if_continuous_modify(code1, code2, code3):
60 | """
61 |     Check if code3 is a continuous modification of code1 and code2 at the block level (the edit distances add up and the number of diff blocks stays the same).
62 |
63 | Args:
64 | code1 (str): The first code string.
65 | code2 (str): The second code string.
66 | code3 (str): The third code string.
67 |
68 | Returns:
69 | bool: True if code3 is a continuous modification of code1 and code2, False otherwise.
70 | """
71 | # Calculate the Levenshtein distance between code1 and code2
72 | dist1 = Levenshtein.distance(code1, code2)
73 | # Calculate the Levenshtein distance between code2 and code3
74 | dist2 = Levenshtein.distance(code2, code3)
75 | # Calculate the Levenshtein distance between code1 and code3
76 | dist3 = Levenshtein.distance(code1, code3)
77 |
78 | past_diff_blocks = generate_diff_blocks(code1, code2)
79 | new_diff_blocks = generate_diff_blocks(code1, code3)
80 |
81 | # Check if code3 is a continuous modification of code1 and code2
82 | if dist3 == dist1 + dist2 and len(past_diff_blocks) == len(new_diff_blocks):
83 | return True
84 | else:
85 | return False
86 |
87 | def generate_diff_blocks(original, modified):
88 | """
89 | Generate diff blocks between two strings.
90 |
91 | Args:
92 | original (str): The original string.
93 | modified (str): The modified string.
94 |
95 | Returns:
96 | list: A list of tuples, where each tuple contains a block of modified lines and the line number in the original string where the block starts.
97 | """
98 |     # Use difflib's Differ to find line-level differences
99 | differ = difflib.Differ()
100 | diff = list(differ.compare(original.split('\n'), modified.split('\n')))
101 |
102 | # store all modified blocks
103 | blocks = []
104 | current_block = []
105 |
106 | # track the current line number
107 | orig_line_no = 0
108 | block_line_len = 0
109 |
110 |     # Traverse the diff results and group consecutive changed lines into blocks
111 | for line in diff:
112 | if line.startswith(' '):
113 | # If the current block has content and an unmodified line is encountered, save the current block and reset
114 | if current_block:
115 | blocks.append((current_block, orig_line_no - block_line_len))
116 | current_block = []
117 | block_line_len = 0
118 | orig_line_no += 1
119 | elif line.startswith('- '):
120 | current_block.append(line)
121 | orig_line_no += 1
122 | block_line_len += 1
123 | else:
124 | current_block.append(line)
125 |
126 | # Make sure the last chunk is added
127 | if current_block:
128 | blocks.append((current_block, orig_line_no - block_line_len))
129 |
130 | return blocks
131 |
132 | def apply_selected_blocks(original, blocks, selected_indices):
133 | """
134 | Apply selected blocks to the original code.
135 |
136 | Args:
137 | original (str): The original code as a string.
138 | blocks (list): A list of code blocks.
139 | selected_indices (list): A list of indices representing the selected blocks.
140 |
141 | Returns:
142 | str: The modified code after applying the selected blocks.
143 | """
144 | # Split the original code by lines
145 | original_lines = original.split('\n')
146 |
147 | # Adjust code based on selected block
148 | offset = 0
149 | for index in selected_indices:
150 | block, start_line = blocks[index]
151 | # Iterate through each row in the block
152 | delete_offset = 0
153 | for line in block:
154 | if line.startswith('- '):
155 | del original_lines[start_line + offset]
156 | delete_offset -= 1
157 | elif line.startswith('+ '):
158 | original_lines.insert(start_line + offset, line[2:])
159 | offset += 1
160 | offset += delete_offset
161 | delete_offset = 0
162 |
163 | return '\n'.join(original_lines)
164 |
165 | # TODO: accelerate the process
166 | def random_edit_series(code1, code2):
167 | """
168 | Generates a series of randomly edited code versions between two given code snippets.
169 |
170 | Args:
171 | code1 (str): The original code snippet.
172 | code2 (str): The modified code snippet.
173 |
174 | Returns:
175 | list: A list of code snippets representing the series of randomly edited versions
176 | between the original and modified code snippets.
177 | """
178 | blocks = generate_diff_blocks(code1, code2)
179 | random_block_order = list(range(len(blocks)))
180 | # random.seed(42)
181 | random.shuffle(random_block_order)
182 | code_series = []
183 | for i in range(len(blocks) + 1):
184 | code_series.append(apply_selected_blocks(code1, blocks, sorted(random_block_order[:i])))
185 | return code_series
186 |
187 | def sequential_edit_series(code1, code2):
188 | """
189 | Generate a series of code snippets by sequentially applying selected diff blocks.
190 |
191 | Args:
192 | code1 (str): The original code.
193 | code2 (str): The modified code.
194 |
195 | Returns:
196 | list: A list of code snippets, each representing the result of applying a selected diff block.
197 | """
198 | blocks = generate_diff_blocks(code1, code2)
199 | code_series = []
200 | for i in range(len(blocks)):
201 | code_series.append(apply_selected_blocks(code1, blocks, [i]))
202 | return code_series
203 |
--------------------------------------------------------------------------------
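A minimal sketch (with assumed inputs) of how the diff-block helpers in src/utils.py compose:

```python
# Illustrative only: assumes the functions defined in src/utils.py are in scope.
code1 = "a = 1\nb = 2\nc = 3\n"
code2 = "a = 1\nb = 20\nc = 3\nd = 4\n"

blocks = generate_diff_blocks(code1, code2)
print(len(blocks))                                # e.g. 2: the edit to `b` and the appended `d = 4`
print(apply_selected_blocks(code1, blocks, [0]))  # only the `b = 20` change applied
print(sequential_edit_series(code1, code2))       # one intermediate version per block
```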
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/tests/__init__.py
--------------------------------------------------------------------------------
/train/README.md:
--------------------------------------------------------------------------------
1 | # Training
2 |
3 | This directory contains the programs used to train the CursorCore series. Other training frameworks that support custom chat templates should also work.
4 |
5 | ## Prepare conversation
6 |
7 | Run the following program to inject the conversation template and special tokens into the model and tokenizer:
8 |
9 | ```bash
10 | python train/prepare_conversation.py --model_name_or_path deepseek-ai/deepseek-coder-1.3b-base --save_path train/formatted-deepseek-coder-1.3b-base
11 | ```
12 |
13 | ## Prepare data
14 |
15 | Run the following program to format training data:
16 |
17 | ```bash
18 | # WF Format
19 | python train/prepare_data.py --input_path data/data.json --output_path data/train_data.json
20 |
21 | # LC Format
22 | python train/prepare_data.py --input_path data/data.json --output_path data/train_data_lc.json --format_type lc
23 |
24 | # SR Format
25 | python train/prepare_data.py --input_path data/data.json --output_path data/train_data_sr.json --format_type sr
26 | ```
27 |
28 | ## Training script
29 |
30 | Our training script works with common distributed launchers such as DeepSpeed and Accelerate. Here is an example script for training a model with DeepSpeed:
31 |
32 | ```bash
33 | MODEL=$1
34 | DATA=$2
35 | OUTPUT=$3
36 | mkdir -p $OUTPUT
37 |
38 | deepspeed --master_port='10086' train/training.py \
39 | --model_name_or_path $MODEL \
40 | --data_path $DATA \
41 | --model_max_length 16384 \
42 | --batch_max_length 50000 \
43 | --learning_rate 5e-5 \
44 | --weight_decay 0 \
45 | --num_train_epochs 2 \
46 | --gradient_accumulation_steps 1 \
47 | --gradient_checkpointing \
48 | --lr_scheduler_type cosine \
49 | --warmup_steps 15 \
50 | --seed 10086 \
51 | --output_dir $OUTPUT \
52 | --bf16 True \
53 | --evaluation_strategy "no" \
54 | --save_strategy "epoch" \
55 | --load_best_model_at_end False \
56 | --save_total_limit 1000 \
57 | --logging_steps 20 \
58 | --tf32 True \
59 | --optim adafactor \
60 | --use_liger_kernel True \
61 | --deepspeed train/ds_config.json
62 | ```
63 |
--------------------------------------------------------------------------------
/train/ds_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 3,
4 | "overlap_comm": true,
5 | "contiguous_gradients": true,
6 | "sub_group_size": 0,
7 | "reduce_bucket_size": "auto",
8 | "stage3_prefetch_bucket_size": "auto",
9 | "stage3_param_persistence_threshold": "auto",
10 | "stage3_max_live_parameters": 0,
11 | "stage3_max_reuse_distance": 0,
12 | "stage3_gather_16bit_weights_on_model_save": true
13 | },
14 | "bf16": {
15 | "enabled": "auto"
16 | },
17 | "gradient_accumulation_steps": "auto",
18 | "gradient_clipping": "auto",
19 | "train_batch_size": "auto",
20 | "train_micro_batch_size_per_gpu": "auto",
21 | "wall_clock_breakdown": false
22 | }
23 |
--------------------------------------------------------------------------------
/train/prepare_conversation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | current_path = os.path.abspath(__file__)
4 | parent_directory = os.path.dirname(os.path.dirname(current_path))
5 | sys.path.append(parent_directory)
6 |
7 | import torch
8 | import argparse
9 | import transformers
10 | from typing import Dict
11 | from transformers import AutoModelForCausalLM, AutoTokenizer
12 | from generic.special_tokens import *
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--model_name_or_path", type=str, help="The base model checkpoint to prepare with the conversation template and special tokens.")
16 | parser.add_argument("--save_path", type=str, help="The path to save the prepared model and tokenizer.")
17 | args = parser.parse_args()
18 |
19 | def smart_tokenizer_and_embedding_resize(
20 | special_tokens_dict: Dict,
21 | tokenizer: transformers.PreTrainedTokenizer,
22 | model: transformers.PreTrainedModel,
23 | ):
24 | """Resize tokenizer and embedding."""
25 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
26 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
27 |
28 | if num_new_tokens > 0:
29 | input_embeddings = model.get_input_embeddings().weight.data
30 | output_embeddings = model.get_output_embeddings().weight.data
31 |
32 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
33 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
34 |
35 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
36 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
37 |
38 | tokenizer = AutoTokenizer.from_pretrained(
39 | args.model_name_or_path,
40 | trust_remote_code=True,
41 | )
42 |
43 | model = AutoModelForCausalLM.from_pretrained(
44 | args.model_name_or_path,
45 | trust_remote_code=True,
46 | torch_dtype=torch.bfloat16,
47 | )
48 |
49 | tokenizer.pad_token = tokenizer.eos_token
50 | tokenizer.add_special_tokens({"eos_token": IM_END})
51 | smart_tokenizer_and_embedding_resize(
52 | special_tokens_dict=dict(additional_special_tokens=SPECIAL_WORDS),
53 | tokenizer=tokenizer,
54 | model=model,
55 | )
56 |
57 | tokenizer.chat_template = [
58 | {
59 | "name": "default",
60 | "template": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful programming assistant.<|im_end|>\n' }}{% endif %}{% if not loop.first %}{{ '\n' }}{% endif %}{{ '<|im_start|>' + message['role'] }}{% if 'name' in message %}{{ ' name=' + message['name'] }}{% endif %}{{ '\n' + message['content'] + '<|im_end|>' }}{% if loop.last and add_generation_prompt %}{{ '\n<|im_start|>assistant\n' }}{% endif %}{% endfor %}"
61 | },
62 | {
63 | "name": "assistant-conversation",
64 | "template": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful programming assistant.<|im_end|>\n' }}{% endif %}{% if not loop.first %}{{ '\n' }}{% endif %}{{ '<|im_start|>' + message['role'] }}{% if 'name' in message %}{{ ' name=' + message['name'] }}{% endif %}{{ '\n' + message['content'] + '<|im_end|>' }}{% if loop.last and add_generation_prompt %}{{ '\n<|im_start|>assistant\n<|next_start|>' }}{% endif %}{% endfor %}"
65 | },
66 | {
67 | "name": "prefix_response",
68 | "template": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful programming assistant.<|im_end|>\n' }}{% endif %}{% if not loop.first %}{{ '\n' }}{% endif %}{{ '<|im_start|>' + message['role'] }}{% if 'name' in message %}{{ ' name=' + message['name'] }}{% endif %}{{ '\n' + message['content'] }}{% if not loop.last %}{{ '<|im_end|>' }}{% endif %}{% endfor %}"
69 | },
70 | ]
71 |
72 | model.config.eos_token_id = tokenizer.eos_token_id
73 | model.generation_config.eos_token_id = tokenizer.eos_token_id
74 |
75 | tokenizer.save_pretrained(args.save_path)
76 | model.save_pretrained(args.save_path)
77 |
--------------------------------------------------------------------------------
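A minimal sketch (with an assumed save path and messages) of using the chat templates installed by train/prepare_conversation.py; it assumes a transformers version that supports named chat templates:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("train/formatted-deepseek-coder-1.3b-base")
messages = [
    {"role": "current", "content": "print('hello')"},
    {"role": "user", "content": "Add a docstring."},
]
prompt = tokenizer.apply_chat_template(
    messages,
    chat_template="assistant-conversation",
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)  # ends with "<|im_start|>assistant\n<|next_start|>"
```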
/train/prepare_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | current_path = os.path.abspath(__file__)
4 | parent_directory = os.path.dirname(os.path.dirname(current_path))
5 | sys.path.append(parent_directory)
6 |
7 | import json
8 | import argparse
9 | from generic.special_tokens import *
10 | from generic.utils import decorate_code, extract_changes_lines, generate_locations_changes, generate_search_and_replace
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--input_path", type=str, help="Input Path")
14 | parser.add_argument("--output_path", type=str, help="Output Path")
15 | parser.add_argument("--format_type", type=str, default="wf", help="Format Type for data")
16 | args = parser.parse_args()
17 |
18 | def format_data(data):
19 | """
20 | Formats the given data based on the specified format type in args.
21 |
22 | Args:
23 | data (list): A list of dictionaries containing the data to be formatted. Each dictionary represents a sample and contains the following keys:
24 | - "history" (list): A list of dictionaries representing the history of messages. Each dictionary contains:
25 | - "code" (str): The code snippet.
26 | - "lang" (str): The programming language of the code snippet.
27 | - "current" (dict): A dictionary representing the current message with keys:
28 | - "code" (str): The current code snippet.
29 | - "lang" (str): The programming language of the current code snippet.
30 | - "loc" (int or tuple): The location(s) in the current code snippet where the target(s) should be inserted.
31 | - "user" (str): The user's message.
32 | - "next" (dict): A dictionary representing the next message with keys:
33 | - "code" (str): The next code snippet.
34 | - "lang" (str): The programming language of the next code snippet.
35 | - "chat" (str): Additional chat message from the assistant.
36 |
37 | Returns:
38 | list: A list of formatted conversations. Each conversation is a list of dictionaries with the following keys:
39 | - "role" (str): The role of the message ("history", "current", "user", "assistant").
40 | - "content" (str): The content of the message, which may include decorated code or location changes.
41 |
42 | Raises:
43 | ValueError: If the format type specified in args is invalid.
44 | """
45 | formatted_data = []
46 | if args.format_type == "wf":
47 | for sample in data:
48 | conversation = []
49 | if sample["history"]:
50 | for message in sample["history"]:
51 | conversation.append({"role": "history", "content": decorate_code(message["code"], message["lang"])})
52 | current_loc = sample["current"]["code"]
53 | if sample["loc"]:
54 | if type(sample["loc"]) == int:
55 | current_loc = current_loc[:sample["loc"]] + TARGET + current_loc[sample["loc"]:]
56 | else:
57 | start, end = sample["loc"]
58 | current_loc = current_loc[:end] + TARGET_END + current_loc[end:]
59 | current_loc = current_loc[:start] + TARGET_START + current_loc[start:]
60 | conversation.append({"role": "current", "content": decorate_code(current_loc, sample["current"]["lang"])})
61 | if sample["user"]:
62 | conversation.append({"role": "user", "content": sample["user"]})
63 | assistant = ""
64 | assistant += NEXT_START + decorate_code(sample["next"]["code"], sample["next"]["lang"]) + NEXT_END
65 | if sample["chat"]:
66 | assistant += "\n" + sample["chat"]
67 | conversation.append({"role": "assistant", "content": assistant})
68 | formatted_data.append(conversation)
69 | elif args.format_type == "lc":
70 | for sample in data:
71 | conversation = []
72 | if sample["history"]:
73 | history_current = sample["history"] + [sample["current"]]
74 | for m1, m2 in zip(history_current[:-1], history_current[1:]):
75 | # In H, we record the modified position of the subsequent code snippet and the modified content of the previous code snippet
76 |                 # In A, we record the modified position of the previous code snippet and the modified content of the subsequent code snippet
77 | changes_lines = extract_changes_lines(m2["code"], m1["code"])
78 | locations_changes = generate_locations_changes(m2["code"], m1["code"], m1["lang"], changes_lines)
79 | conversation.append({"role": "history", "content": locations_changes})
80 | current_loc = sample["current"]["code"]
81 | if sample["loc"]:
82 | if type(sample["loc"]) == int:
83 | current_loc = current_loc[:sample["loc"]] + TARGET + current_loc[sample["loc"]:]
84 | else:
85 | start, end = sample["loc"]
86 | current_loc = current_loc[:end] + TARGET_END + current_loc[end:]
87 | current_loc = current_loc[:start] + TARGET_START + current_loc[start:]
88 | # we add line numbers for assistant to understand the location
89 | conversation.append({"role": "current", "content": decorate_code(current_loc, sample["current"]["lang"], use_line_num=True)})
90 | if sample["user"]:
91 | conversation.append({"role": "user", "content": sample["user"]})
92 | assistant = ""
93 | changes_lines = extract_changes_lines(sample["current"]["code"], sample["next"]["code"])
94 | locations_changes = generate_locations_changes(sample["current"]["code"], sample["next"]["code"], sample["next"]["lang"], changes_lines)
95 | assistant += NEXT_START + locations_changes + NEXT_END
96 | if sample["chat"]:
97 | assistant += "\n" + sample["chat"]
98 | conversation.append({"role": "assistant", "content": assistant})
99 | formatted_data.append(conversation)
100 | elif args.format_type == "sr":
101 | for sample in data:
102 | conversation = []
103 | if sample["history"]:
104 | history_current = sample["history"] + [sample["current"]]
105 | for m1, m2 in zip(history_current[:-1], history_current[1:]):
106 | # We ensure that the searched content matches exactly and there are no duplicate paragraphs
107 | changes_lines = extract_changes_lines(m2["code"], m1["code"], unique=True, merge_changes=True)
108 | changes_lines = [(new, old) for old, new in changes_lines]
109 | search_and_replace = generate_search_and_replace(m1["code"], m2["code"], m1["lang"], changes_lines)
110 | conversation.append({"role": "history", "content": search_and_replace})
111 | current_loc = sample["current"]["code"]
112 | if sample["loc"]:
113 | if type(sample["loc"]) == int:
114 | current_loc = current_loc[:sample["loc"]] + TARGET + current_loc[sample["loc"]:]
115 | else:
116 | start, end = sample["loc"]
117 | current_loc = current_loc[:end] + TARGET_END + current_loc[end:]
118 | current_loc = current_loc[:start] + TARGET_START + current_loc[start:]
119 | conversation.append({"role": "current", "content": decorate_code(current_loc, sample["current"]["lang"])})
120 | if sample["user"]:
121 | conversation.append({"role": "user", "content": sample["user"]})
122 | assistant = ""
123 | changes_lines = extract_changes_lines(sample["current"]["code"], sample["next"]["code"], unique=True, merge_changes=True)
124 | search_and_replace = generate_search_and_replace(sample["current"]["code"], sample["next"]["code"], sample["next"]["lang"], changes_lines)
125 | assistant += NEXT_START + search_and_replace + NEXT_END
126 | if sample["chat"]:
127 | assistant += "\n" + sample["chat"]
128 | conversation.append({"role": "assistant", "content": assistant})
129 | formatted_data.append(conversation)
130 | else:
131 | raise ValueError(f"Invalid format type: {args.format_type}")
132 | return formatted_data
133 |
134 | with open(args.input_path, "r") as f:
135 | input_data = json.load(f)
136 |
137 | output_data = format_data(input_data)
138 |
139 | with open(args.output_path, "w") as f:
140 | json.dump(output_data, f, indent=4)
141 |
--------------------------------------------------------------------------------
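A minimal sketch (with assumed field values) of one input sample consumed by train/prepare_data.py, following the schema described in `format_data`:

```python
sample = {
    "history": [{"code": "print('hi')", "lang": "python"}],
    "current": {"code": "print('hi')\nprint('bye')", "lang": "python"},
    "loc": 12,   # int: single cursor; (start, end): selected range; None: no location hint
    "user": "Also say goodbye politely.",
    "next": {"code": "print('hi')\nprint('goodbye, see you soon')", "lang": "python"},
    "chat": "Made the farewell more polite.",
}
# With --format_type wf this becomes a conversation of history / current / user /
# assistant messages, where the assistant content is the next code wrapped between
# the NEXT_START / NEXT_END special tokens, followed by the chat text.
```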
/train/training.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
3 | import copy
4 | import json
5 | import numba
6 | import torch
7 | import numpy as np
8 | import transformers
9 | from dataclasses import dataclass, field
10 | from typing import Optional, Dict, List
11 | from concurrent.futures import ThreadPoolExecutor
12 | from torch.utils.data import Dataset, DataLoader, Sampler
13 | from transformers import Trainer, AutoModelForCausalLM, AutoTokenizer
14 | from transformers import DefaultDataCollator, default_data_collator
15 | from transformers.trainer_utils import seed_worker
16 | from transformers.trainer_pt_utils import LabelSmoother
17 | IGNORE_INDEX = LabelSmoother.ignore_index
18 |
19 | @dataclass
20 | class ModelArguments:
21 | model_name_or_path: Optional[str] = field(default="deepseek-ai/DeepSeek-Coder-V2-Lite-Base")
22 |
23 | @dataclass
24 | class DataArguments:
25 | data_path: str = field(default=None, metadata={"help": "Path to the training data."})
26 | num_proc: int = field(default=8, metadata={"help": "Number of processes to use for data preprocessing."})
27 |
28 | @dataclass
29 | class TrainingArguments(transformers.TrainingArguments):
30 | model_max_length: int = field(default=10000, metadata={"help": "Maximum sequence length."})
31 | batch_max_length: int = field(default=25000, metadata={"help": "Maximum batch length."})
32 |
33 | def preprocess(
34 | list_data_dict: List,
35 | tokenizer: transformers.PreTrainedTokenizer,
36 | num_proc: int,
37 | ) -> Dict:
38 | """Preprocess the data by tokenizing."""
39 | examples = [tokenizer.apply_chat_template(l, tokenize=False) for l in list_data_dict]
40 | sources = [tokenizer.apply_chat_template(l[:-1], tokenize=False, add_generation_prompt=True) for l in list_data_dict]
41 |
42 |     def tokenize_text(text):
43 |         """Tokenize a single string."""
44 | return tokenizer(
45 | text,
46 | return_tensors="pt",
47 | padding="longest",
48 | max_length=tokenizer.model_max_length,
49 | truncation=True,
50 | )
51 |
52 | with ThreadPoolExecutor(max_workers=num_proc) as executor:
53 | examples_tokenized = list(executor.map(tokenize_text, examples))
54 | sources_tokenized = list(executor.map(tokenize_text, sources))
55 |
56 | input_ids = [tokenized.input_ids[0].tolist() for tokenized in examples_tokenized]
57 | labels = copy.deepcopy(input_ids)
58 | source_lens = [len(tokenized.input_ids[0]) for tokenized in sources_tokenized]
59 | for label, source_len in zip(labels, source_lens):
60 | label[:source_len] = [IGNORE_INDEX] * source_len
61 | return dict(input_ids=input_ids, labels=labels)
62 |
63 | @numba.njit
64 | def ffd(lengths, batch_max_length):
65 | """
66 | First-Fit Decreasing (FFD) algorithm for bin packing.
67 |
68 | This function sorts the input lengths in decreasing order and then attempts to pack them into bins
69 | such that the sum of lengths in each bin does not exceed the specified batch_max_length. It returns
70 | the indices of the original lengths in each bin.
71 |
72 | Args:
73 | lengths (list or array-like): A list or array of lengths to be packed into bins.
74 | batch_max_length (int): The maximum allowable length for each bin.
75 |
76 | Returns:
77 | list of lists: A list where each sublist contains the indices of the original lengths that have
78 | been packed into the corresponding bin.
79 | """
80 | lengths = np.array(lengths)
81 | indices = np.argsort(lengths)[::-1]
82 | lengths = lengths[indices]
83 | bins = []
84 | bins_result = []
85 | for lengths_id, size in enumerate(lengths):
86 | add_new = True
87 | for idx in range(len(bins)):
88 | if bins[idx] >= size:
89 | bins[idx] -= size
90 | bins_result[idx].append(indices[lengths_id])
91 | add_new = False
92 | break
93 | if add_new:
94 | bins.append(batch_max_length - size)
95 | bins_result.append([indices[lengths_id]])
96 | return bins_result
97 |
98 | class PackSampler(Sampler):
99 | def __init__(self, batch_max_length: int, lengths: List[int], seed: int = 0):
100 | batches = ffd(lengths, batch_max_length)
101 | indices = np.random.default_rng(seed=seed).permutation(len(batches))
102 | self.batches = [batches[idx] for idx in indices]
103 |
104 | def __iter__(self):
105 | return iter(self.batches)
106 |
107 | def __len__(self):
108 | return len(self.batches)
109 |
110 | class SupervisedDataset(Dataset):
111 | """Dataset for supervised fine-tuning."""
112 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, num_proc: int):
113 | super(SupervisedDataset, self).__init__()
114 | with open(data_path, "r") as json_file:
115 | list_data_dict = json.load(json_file)
116 | data_dict = preprocess(list_data_dict, tokenizer, num_proc=num_proc)
117 | self.input_ids = data_dict["input_ids"]
118 | self.labels = data_dict["labels"]
119 |
120 | def __len__(self):
121 | return len(self.input_ids)
122 |
123 |     def __getitem__(self, index) -> List[Dict[str, List[int]]]:
124 | return [dict(input_ids=self.input_ids[i], labels=self.labels[i]) for i in index]
125 |
126 | @dataclass
127 | class DataCollatorWithFlattening(DefaultDataCollator):
128 | """Collate examples for supervised fine-tuning."""
129 | def __init__(self, *args, return_position_ids=True, **kwargs):
130 | super().__init__(*args, **kwargs)
131 | self.return_position_ids = return_position_ids
132 |
133 | def __call__(self, features, return_tensors=None):
134 | features = [item for feature in features for item in feature]
135 | if return_tensors is None:
136 | return_tensors = self.return_tensors
137 | is_labels_provided = "labels" in features[0]
138 | ret = {"input_ids": [], "labels": []}
139 | if self.return_position_ids:
140 | ret.update({"position_ids": []})
141 | for idx in range(0, len(features)):
142 | ret["input_ids"] += features[idx]["input_ids"]
143 | if is_labels_provided:
144 | ret["labels"] += [IGNORE_INDEX] + features[idx]["labels"][1:]
145 | else:
146 | ret["labels"] += [IGNORE_INDEX] + features[idx]["input_ids"][1:]
147 | if self.return_position_ids:
148 | ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
149 | return default_data_collator([ret], return_tensors)
150 |
151 | class CustomTrainer(Trainer):
152 | def __init__(self, sampler, *args, **kwargs):
153 | super().__init__(*args, **kwargs)
154 | self.sampler = sampler
155 |
156 | def get_train_dataloader(self) -> DataLoader:
157 | """Returns the training [`~torch.utils.data.DataLoader`]."""
158 | dataloader_params = {
159 | "collate_fn": self.data_collator,
160 | "num_workers": self.args.dataloader_num_workers,
161 | "pin_memory": self.args.dataloader_pin_memory,
162 | "persistent_workers": self.args.dataloader_persistent_workers,
163 | "sampler": self.sampler,
164 | "drop_last": self.args.dataloader_drop_last,
165 | "worker_init_fn": seed_worker,
166 | "prefetch_factor": self.args.dataloader_prefetch_factor
167 | }
168 | return self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
169 |
170 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
171 | """Make dataset and collator for supervised fine-tuning."""
172 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, num_proc=data_args.num_proc)
173 | data_collator = DataCollatorWithFlattening()
174 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
175 |
176 | def train():
177 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
178 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
179 | tokenizer = AutoTokenizer.from_pretrained(
180 | model_args.model_name_or_path,
181 | trust_remote_code=True,
182 | model_max_length=training_args.model_max_length,
183 | use_fast=True,
184 | )
185 | model = AutoModelForCausalLM.from_pretrained(
186 | model_args.model_name_or_path,
187 | trust_remote_code=True,
188 | torch_dtype=torch.bfloat16,
189 | attn_implementation="flash_attention_2",
190 | )
191 | model.gradient_checkpointing_enable()
192 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
193 | lengths = [len(input_id) for input_id in data_module["train_dataset"].input_ids]
194 | sampler = PackSampler(batch_max_length=training_args.batch_max_length, lengths=lengths, seed=training_args.seed)
195 | trainer = CustomTrainer(sampler=sampler, model=model, tokenizer=tokenizer, args=training_args, **data_module)
196 | model.config.use_cache = False
197 | # model.model.forward = torch.compile(model.model.forward)
198 | trainer.train()
199 | trainer.save_model(training_args.output_dir)
200 |
201 | if __name__ == "__main__":
202 | train()
203 |
--------------------------------------------------------------------------------
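A minimal sketch (with assumed lengths) of the first-fit-decreasing packing that `PackSampler` in train/training.py uses to build batches:

```python
# Illustrative only: assumes ffd from train/training.py is in scope.
lengths = [5, 3, 8, 2]
batches = ffd(lengths, 10)
# Sorted descending (8, 5, 3, 2): bin 0 holds 8 and 2, bin 1 holds 5 and 3.
print(batches)  # e.g. [[2, 3], [0, 1]] (indices into `lengths`)
# PackSampler then shuffles these bins and yields each bin as one packed batch,
# which DataCollatorWithFlattening flattens into a single sequence with position_ids.
```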