├── .gitignore ├── LICENSE ├── README.md ├── benchmark ├── README.md └── apeval.json ├── data ├── README.md ├── commit.py ├── submit.py └── submit_process.py ├── eval ├── README.md ├── eval_apeval.py ├── eval_humaneval.py ├── eval_mbpp.py ├── extract_results.py ├── search_and_replace.py └── utils.py ├── gen ├── __init__.py ├── genaiprogrammer.py ├── gencht.py ├── geninst.py ├── genjudge.py ├── llmeval.py ├── llmgen.py ├── openai.py ├── template │ ├── __init__.py │ ├── aiprogrammer_template.py │ ├── gencht_template.py │ ├── geninst_template.py │ └── genjudge_template.py └── utils.py ├── generic ├── __init__.py ├── special_tokens.py └── utils.py ├── model_map.json ├── pictures ├── APEval.png ├── CursorWeb.gif ├── EvalPlus_CanItEdit_OctoPack.png ├── conversation.png ├── cursorcore.png └── discord.png ├── requirements.txt ├── src ├── README.md ├── aiprogrammer.py ├── data_collection.py ├── merge_data.py ├── post_collection.py └── utils.py ├── tests └── __init__.py └── train ├── README.md ├── ds_config.json ├── prepare_conversation.py ├── prepare_data.py └── training.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *tmp* 3 | **/*.log 4 | .vscode/ 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CursorCore: Assist Programming through Aligning Anything 2 | 3 |

4 | [📄arXiv] | 5 | [🤗HF Paper] | 6 | [🤖Models] | 7 | [🛠️Code] | 8 | [Web] | 9 | [Discord] 10 |

11 | 12 |
13 | 14 | - [CursorCore: Assist Programming through Aligning Anything](#cursorcore-assist-programming-through-aligning-anything) 15 | - [Introduction](#introduction) 16 | - [Structure](#structure) 17 | - [Models](#models) 18 | - [Usage](#usage) 19 | - [1) Normal chat](#1-normal-chat) 20 | - [2) Assistant-Conversation](#2-assistant-conversation) 21 | - [3) Web Demo](#3-web-demo) 22 | - [Future Work](#future-work) 23 | - [Citation](#citation) 24 | - [Acknowledgements](#acknowledgements) 25 | - [Contribution](#contribution) 26 | 27 |
28 | 29 | ## Introduction 30 | 31 | CursorCore is a series of open-source models designed for AI-assisted programming. It aims to support features such as automated editing and inline chat, replicating the core abilities of closed-source AI-assisted programming tools like Cursor. This is achieved by aligning data generated through Programming-Instruct. Please read [our paper](http://arxiv.org/abs/2410.07002) to learn more. 32 | 33 |

34 | ![conversation](pictures/conversation.png) 35 |

36 | 37 | ![CursorWeb](https://github.com/TechxGenus/CursorCore/blob/main/pictures/CursorWeb.gif) 38 | 39 | ## Structure 40 | 41 | - `./benchmark` contains the APEval benchmark 42 | - `./data` contains code to preprocess datasets 43 | - `./eval` contains code to evaluate models 44 | - `./gen` contains code to prompt LLMs for generation 45 | - `./generic` common functions, tools and special tokens 46 | - `./src` contains code about Programming-Instruct 47 | - `./train` contains code for training CursorCore 48 | 49 | Please ensure all dependencies are installed using the following command: 50 | 51 | ```bash 52 | pip install -r requirements.txt 53 | ``` 54 | 55 | We also use [flash-attention](https://github.com/Dao-AILab/flash-attention) for efficient training and [flashinfer](https://github.com/flashinfer-ai/flashinfer) to accelerate inference. See the documents for them to learn how to install. 56 | 57 | ## Models 58 | 59 | Our models have been open-sourced on Hugging Face. You can access our models here: [CursorCore-Series](https://huggingface.co/collections/TechxGenus/cursorcore-series-6706618c38598468866b60e2"). We also provide pre-quantized weights for GPTQ and AWQ here: [CursorCore-Quantization](https://huggingface.co/collections/TechxGenus/cursorcore-quantization-67066431f29f252494ee8cf3) 60 | 61 | We use the manually written benchmark APEval to assess the model's ability to assist programming. We also utilize [EvalPlus](https://github.com/evalplus/evalplus), [CanItEdit](https://github.com/nuprl/CanItEdit) and [OctoPack](https://github.com/bigcode-project/octopack) to evaluate the model's performance in Python program generation, instructional code editing, and automated program repair. Since we use a custom conversation template, its generation method differs significantly from both instruct models and base models. Please refer to [our paper](http://arxiv.org/abs/2410.07002) for more details. 62 | 63 | Evaluation results on APEval: 64 | 65 | APEval 66 | 67 | Evaluation results on EvalPlus, CanItEdit and OctoPack: 68 | 69 | EvalPlus_CanItEdit_OctoPack 70 | 71 | ## Usage 72 | 73 | Here are some examples of how to use our model: 74 | 75 | ### 1) Normal chat 76 | 77 | Script: 78 | 79 | ````python 80 | import torch 81 | from transformers import AutoTokenizer, AutoModelForCausalLM 82 | 83 | tokenizer = AutoTokenizer.from_pretrained("TechxGenus/CursorCore-Yi-9B") 84 | model = AutoModelForCausalLM.from_pretrained( 85 | "TechxGenus/CursorCore-Yi-9B", 86 | torch_dtype=torch.bfloat16, 87 | device_map="auto" 88 | ) 89 | 90 | messages = [ 91 | {"role": "user", "content": "Hi!"}, 92 | ] 93 | prompt = tokenizer.apply_chat_template( 94 | messages, 95 | tokenize=False, 96 | add_generation_prompt=True 97 | ) 98 | 99 | inputs = tokenizer.encode(prompt, return_tensors="pt") 100 | outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=512) 101 | print(tokenizer.decode(outputs[0])) 102 | ```` 103 | 104 | Output: 105 | 106 | ````txt 107 | <|im_start|>system 108 | You are a helpful programming assistant.<|im_end|> 109 | <|im_start|>user 110 | Hi!<|im_end|> 111 | <|im_start|>assistant 112 | Hello! I'm an AI language model and I can help you with any programming questions you might have. What specific problem or task are you trying to solve?<|im_end|> 113 | ```` 114 | 115 | ### 2) Assistant-Conversation 116 | 117 | In our work, we introduce a new framework of AI-assisted programming task. 
It is designed for aligning anything during programming process, used for the implementation of features like Tab and Inline Chat. 118 | 119 | Script 1: 120 | 121 | ````python 122 | import torch 123 | from transformers import AutoTokenizer, AutoModelForCausalLM 124 | from eval.utils import prepare_input_for_wf 125 | 126 | tokenizer = AutoTokenizer.from_pretrained("TechxGenus/CursorCore-Yi-9B") 127 | model = AutoModelForCausalLM.from_pretrained( 128 | "TechxGenus/CursorCore-Yi-9B", 129 | torch_dtype=torch.bfloat16, 130 | device_map="auto" 131 | ) 132 | sample = { 133 | "history": [ 134 | { 135 | "type": "code", 136 | "lang": "python", 137 | "code": """def quick_sort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)""" 138 | } 139 | ], 140 | "current": { 141 | "type": "code", 142 | "lang": "python", 143 | "code": """def quick_sort(array):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)""" 144 | }, 145 | "user": "" 146 | } 147 | 148 | prompt = tokenizer.apply_chat_template( 149 | prepare_input_for_wf(sample), 150 | tokenize=False, 151 | chat_template="assistant-conversation", 152 | add_generation_prompt=True 153 | ) 154 | 155 | inputs = tokenizer.encode(prompt, return_tensors="pt") 156 | outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=512, do_sample=False) 157 | print(tokenizer.decode(outputs[0])) 158 | ```` 159 | 160 | Output 1: 161 | 162 | ````txt 163 | <|im_start|>system 164 | You are a helpful programming assistant.<|im_end|> 165 | <|im_start|>history 166 | ```python 167 | def quick_sort(arr): 168 | if len(arr) <= 1: 169 | return arr 170 | pivot = arr[len(arr) // 2] 171 | left = [x for x in arr if x < pivot] 172 | middle = [x for x in arr if x == pivot] 173 | right = [x for x in arr if x > pivot] 174 | return quick_sort(left) + middle + quick_sort(right) 175 | ```<|im_end|> 176 | <|im_start|>current 177 | ```python 178 | def quick_sort(array): 179 | if len(arr) <= 1: 180 | return arr 181 | pivot = arr[len(arr) // 2] 182 | left = [x for x in arr if x < pivot] 183 | middle = [x for x in arr if x == pivot] 184 | right = [x for x in arr if x > pivot] 185 | return quick_sort(left) + middle + quick_sort(right) 186 | ```<|im_end|> 187 | <|im_start|>assistant 188 | <|next_start|>```python 189 | def quick_sort(array): 190 | if len(array) <= 1: 191 | return array 192 | pivot = array[len(array) // 2] 193 | left = [x for x in array if x < pivot] 194 | middle = [x for x in array if x == pivot] 195 | right = [x for x in array if x > pivot] 196 | return quick_sort(left) + middle + quick_sort(right) 197 | ```<|next_end|> 198 | The user has provided a revised code change that corrects the parameter name from `arr` to `array` in the `quick_sort` function. This change ensures consistency in the function definition and avoids potential confusion or errors. 199 | 200 | To implement this, we will: 201 | 1. Update the parameter name in the function definition from `arr` to `array`. 202 | 2. Ensure that all references to `arr` within the function are updated to `array`. 
203 | 204 | This will make the function definition and internal references consistent, improving code readability and maintainability.<|im_end|> 205 | ```` 206 | 207 | Script 2: 208 | 209 | ````python 210 | import torch 211 | from transformers import AutoTokenizer, AutoModelForCausalLM 212 | from eval.utils import prepare_input_for_wf 213 | 214 | tokenizer = AutoTokenizer.from_pretrained("TechxGenus/CursorCore-Yi-9B") 215 | model = AutoModelForCausalLM.from_pretrained( 216 | "TechxGenus/CursorCore-Yi-9B", 217 | torch_dtype=torch.bfloat16, 218 | device_map="auto" 219 | ) 220 | sample = { 221 | "history": [], 222 | "current": { 223 | "type": "code", 224 | "lang": "python", 225 | "code": """def quick_sort(array):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)""" 226 | }, 227 | "user": "Add Docstring." 228 | } 229 | 230 | prompt = tokenizer.apply_chat_template( 231 | prepare_input_for_wf(sample), 232 | tokenize=False, 233 | chat_template="assistant-conversation", 234 | add_generation_prompt=True 235 | ) 236 | 237 | inputs = tokenizer.encode(prompt, return_tensors="pt") 238 | outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=512, do_sample=False) 239 | print(tokenizer.decode(outputs[0])) 240 | ```` 241 | 242 | Output 2: 243 | 244 | ````txt 245 | <|im_start|>system 246 | You are a helpful programming assistant.<|im_end|> 247 | <|im_start|>current 248 | ```python 249 | def quick_sort(array): 250 | if len(arr) <= 1: 251 | return arr 252 | pivot = arr[len(arr) // 2] 253 | left = [x for x in arr if x < pivot] 254 | middle = [x for x in arr if x == pivot] 255 | right = [x for x in arr if x > pivot] 256 | return quick_sort(left) + middle + quick_sort(right) 257 | ```<|im_end|> 258 | <|im_start|>user 259 | Add Docstring.<|im_end|> 260 | <|im_start|>assistant 261 | <|next_start|>```python 262 | def quick_sort(array): 263 | """ 264 | This is an implementation of the quick sort algorithm. 265 | """ 266 | if len(arr) <= 1: 267 | return arr 268 | pivot = arr[len(arr) // 2] 269 | left = [x for x in arr if x < pivot] 270 | middle = [x for x in arr if x == pivot] 271 | right = [x for x in arr if x > pivot] 272 | return quick_sort(left) + middle + quick_sort(right) 273 | ```<|next_end|><|im_end|> 274 | ```` 275 | 276 | For models in Locate-and-Change (LC) and Search-and-Replace (SR) formats, the output examples are as follows: 277 | 278 | Script for LC: 279 | 280 | ````python 281 | import torch 282 | from transformers import AutoTokenizer, AutoModelForCausalLM 283 | from eval.utils import prepare_input_for_lc 284 | 285 | tokenizer = AutoTokenizer.from_pretrained("TechxGenus/CursorCore-Yi-1.5B-LC") 286 | model = AutoModelForCausalLM.from_pretrained( 287 | "TechxGenus/CursorCore-Yi-1.5B-LC", 288 | torch_dtype=torch.bfloat16, 289 | device_map="auto" 290 | ) 291 | sample = { 292 | "history": [], 293 | "current": { 294 | "type": "code", 295 | "lang": "python", 296 | "code": """def quick_sort(array):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)""" 297 | }, 298 | "user": "Add Docstring." 
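    # Note: "history" holds earlier snapshots of the code, "current" is the version
    # being edited now, and "user" is the optional natural-language instruction.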
299 | } 300 | 301 | prompt = tokenizer.apply_chat_template( 302 | prepare_input_for_lc(sample), 303 | tokenize=False, 304 | chat_template="assistant-conversation", 305 | add_generation_prompt=True 306 | ) 307 | 308 | inputs = tokenizer.encode(prompt, return_tensors="pt") 309 | outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=512, do_sample=False) 310 | print(tokenizer.decode(outputs[0])) 311 | ```` 312 | 313 | Output for LC: 314 | 315 | ````txt 316 | <|im_start|>system 317 | You are a helpful programming assistant.<|im_end|> 318 | <|im_start|>current 319 | ```python 320 | 1 def quick_sort(array): 321 | 2 if len(arr) <= 1: 322 | 3 return arr 323 | 4 pivot = arr[len(arr) // 2] 324 | 5 left = [x for x in arr if x < pivot] 325 | 6 middle = [x for x in arr if x == pivot] 326 | 7 right = [x for x in arr if x > pivot] 327 | 8 return quick_sort(left) + middle + quick_sort(right) 328 | ```<|im_end|> 329 | <|im_start|>user 330 | Add Docstring.<|im_end|> 331 | <|im_start|>assistant 332 | <|next_start|>1,1 333 | ``` 334 | '''This function will sort an array using quick sort algorithm''' 335 | ```<|next_end|> 336 | To enhance the readability and maintainability of the code, we should add a docstring to the `quick_sort` function. A docstring provides a clear description of what the function does, which is particularly useful for other developers who might use or modify this code in the future. 337 | 338 | The docstring will be added immediately after the function definition, explaining that the function uses the quick sort algorithm to sort an array. This will make the code more self-explanatory and easier to understand. 339 | 340 | Here's the plan: 341 | 1. Add a docstring at the beginning of the `quick_sort` function. 342 | 2. Ensure the docstring is clear and concise, describing the purpose of the function. 343 | 344 | This modification will improve the code's documentation without altering its functionality.<|im_end|> 345 | ```` 346 | 347 | Script for SR: 348 | 349 | ````python 350 | import torch 351 | from transformers import AutoTokenizer, AutoModelForCausalLM 352 | from eval.utils import prepare_input_for_sr 353 | 354 | tokenizer = AutoTokenizer.from_pretrained("TechxGenus/CursorCore-Yi-1.5B-SR") 355 | model = AutoModelForCausalLM.from_pretrained( 356 | "TechxGenus/CursorCore-Yi-1.5B-SR", 357 | torch_dtype=torch.bfloat16, 358 | device_map="auto" 359 | ) 360 | sample = { 361 | "history": [], 362 | "current": { 363 | "type": "code", 364 | "lang": "python", 365 | "code": """def quick_sort(array):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)""" 366 | }, 367 | "user": "Add Docstring." 
368 | } 369 | 370 | prompt = tokenizer.apply_chat_template( 371 | prepare_input_for_sr(sample), 372 | tokenize=False, 373 | chat_template="assistant-conversation", 374 | add_generation_prompt=True 375 | ) 376 | 377 | inputs = tokenizer.encode(prompt, return_tensors="pt") 378 | outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=512, do_sample=False) 379 | print(tokenizer.decode(outputs[0])) 380 | ```` 381 | 382 | Output for SR: 383 | 384 | ````txt 385 | <|im_start|>system 386 | You are a helpful programming assistant.<|im_end|> 387 | <|im_start|>current 388 | ```python 389 | def quick_sort(array): 390 | if len(arr) <= 1: 391 | return arr 392 | pivot = arr[len(arr) // 2] 393 | left = [x for x in arr if x < pivot] 394 | middle = [x for x in arr if x == pivot] 395 | right = [x for x in arr if x > pivot] 396 | return quick_sort(left) + middle + quick_sort(right) 397 | ```<|im_end|> 398 | <|im_start|>user 399 | Add Docstring.<|im_end|> 400 | <|im_start|>assistant 401 | <|next_start|>```python 402 | def quick_sort(array): 403 | <|search_and_replace|> 404 | def quick_sort(array): 405 | """ 406 | This function implements quick sort algorithm 407 | """ 408 | ```<|next_end|><|im_end|> 409 | ```` 410 | 411 | ### 3) Web Demo 412 | 413 | We create a web demo for CursorCore. Please visit [CursorWeb](https://github.com/TechxGenus/CursorWeb) for more details. 414 | 415 | ## Future Work 416 | 417 | CursorCore is still in a very early stage, and lots of work is needed to achieve a better user experience. For example: 418 | 419 | - Repository-level editing support 420 | - Better and faster editing formats 421 | - Better user interface and presentation 422 | - ... 423 | 424 | ## Citation 425 | 426 | ```bibtex 427 | @article{jiang2024cursorcore, 428 | title = {CursorCore: Assist Programming through Aligning Anything}, 429 | author = {Hao Jiang and Qi Liu and Rui Li and Shengyu Ye and Shijin Wang}, 430 | year = {2024}, 431 | journal = {arXiv preprint arXiv: 2410.07002} 432 | } 433 | ``` 434 | 435 | ## Acknowledgements 436 | 437 | The open-source community has been of great help to us, and we reference numerous projects and applications. They include but are not limited to: 438 | 439 | [Deepseek-Coder](https://github.com/deepseek-ai/DeepSeek-Coder), [Yi-Coder](https://github.com/01-ai/Yi-Coder), [Qwen-Coder](https://github.com/QwenLM/Qwen2.5-Coder), [Self-Instruct](https://github.com/yizhongw/self-instruct), [Evol-Instruct](https://github.com/theblackcat102/evol-dataset), [OSS-Instruct](https://github.com/ise-uiuc/magicoder), [EvalPlus](https://github.com/evalplus/evalplus), [CanItEdit](https://github.com/nuprl/CanItEdit), [OctoPack](https://github.com/bigcode-project/octopack), [Aider](https://github.com/Aider-AI/aider), [Continue](https://github.com/continuedev/continue), [Cursor](https://github.com/getcursor/cursor), ... 440 | 441 | ## Contribution 442 | 443 | Contributions are welcome! If you find any bugs or have suggestions for improvements, please open an issue or submit a pull request. 444 | -------------------------------------------------------------------------------- /benchmark/README.md: -------------------------------------------------------------------------------- 1 | # APEval: Assist Programming Eval 2 | 3 | This benchmark aims to assess how models use various types of information to assist programming. It is extended by the [HumanEval](https://github.com/openai/human-eval) benchmark. 
The benchmark is structured as a JSON file where each entry corresponds to a specific task. Below is a detailed description of the fields present in the benchmark. 4 | 5 | ## File Structure 6 | 7 | The dataset is in JSON format with each entry structured as follows: 8 | 9 | - **task_id**: A unique identifier for each task, corresponding to the original HumanEval benchmark. 10 | - **history**: An array of historical code snippets and their related metadata. Each snippet in the history represents a different version of code. 11 | - **type**: The type of content, typically `code`. 12 | - **lang**: The programming language used in the code, typically `python`. 13 | - **code**: The historical code snippets from different moments. 14 | - **current**: The current code and its related metadata: 15 | - **type**: The type of content, typically `code`. 16 | - **lang**: The programming language used in the code, typically `python`. 17 | - **code**: The current version of the code for the task. 18 | - **user**: An instruction or reflection provided by the user regarding the task. 19 | - **area**: Extra metadata, indicates the location of the cursor or the selected code area. 20 | 21 | ## Example Case 22 | 23 | ```json 24 | { 25 | "task_id": "HumanEval/0", 26 | "history": [ 27 | { 28 | "type": "code", 29 | "lang": "python", 30 | "code": "def has_close_elements(n, t):\n for i in range(prm)" 31 | }, 32 | ... 33 | ], 34 | "current": { 35 | "type": "code", 36 | "lang": "python", 37 | "code": "def has_close_elements(n, t):\n for i in range(len(n - 1)):\n for j in range(i + 1, len(n)):\n if n[i] - n[j] < t or n[j] - n[i] < t:" 38 | }, 39 | "user": "", 40 | "area": 151 41 | } 42 | ``` 43 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data preprocess 2 | 3 | This folder contains preprocessing programs for `Git Commit` and `Online Submit` data. 4 | 5 | ## Download relevant data 6 | 7 | `Online Submit` data needs to be downloaded manually, please refer to [Codenet](https://github.com/IBM/Project_CodeNet) to download. 8 | 9 | We clean and translate some of Codenet's question data, which can be found in [Codenet_Context](https://huggingface.co/datasets/TechxGenus/CodeNet_Context). 10 | 11 | ## Preprocess 12 | 13 | Run the following example script to preprocess the data: 14 | 15 | ```bash 16 | python data/commit.py --languages ruby python javascript shell php java c# c swift typescript c++ go scala rust r --output_path data/commit.json 17 | python data/submit.py --dataset_path --output_path submit.json 18 | python data/submit_process.py --context_path data/CodeNet_Context.json --dataset_path data/submit.json --identical_path /derived/duplicates/identical_problem_clusters --output_path data/submit_process.json 19 | ``` 20 | 21 | View the parameters of the scripts can further specify the required language, number of programs, strictness of deduplication, etc. 
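For example, the following hypothetical invocations use only flags already defined in `data/commit.py` and `data/submit_process.py` to restrict the collected languages and adjust the filtering thresholds (`<codenet_path>` is a placeholder for your local CodeNet download):

```bash
# Collect Git Commit data for two languages only, capped at 5000 samples per language
python data/commit.py --languages python java --max_per_lang 5000 --output_path data/commit.json

# Post-process Online Submit data with adjusted similarity thresholds
python data/submit_process.py --context_path data/CodeNet_Context.json --dataset_path data/submit.json \
    --identical_path <codenet_path>/derived/duplicates/identical_problem_clusters \
    --output_path data/submit_process.json \
    --max_per_problem 20 --internal_similarity_threshold 0.8 --external_similarity_threshold 0.3
```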
22 | -------------------------------------------------------------------------------- /data/commit.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import argparse 3 | import json 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--languages", help="Supported languages", type=str, nargs="+") 7 | parser.add_argument("--max_per_lang", help="Max samples per language", default=30000, type=int) 8 | parser.add_argument("--output_path", help="Output Path", type=str) 9 | args = parser.parse_args() 10 | 11 | commit = [] 12 | count_per_lang = {} 13 | 14 | def process_sample(sample, count_per_lang, args, commit): 15 | """ 16 | Processes a sample and updates the commit list based on language constraints. 17 | 18 | Args: 19 | sample (dict): A dictionary containing sample data with keys "lang", "old_contents", "new_contents", and "subject". 20 | count_per_lang (dict): A dictionary tracking the count of samples per language. 21 | args (Namespace): An object containing arguments, specifically: 22 | - languages (list): A list of languages to include. 23 | - max_per_lang (int): The maximum number of samples per language. 24 | commit (list): A list to which the processed sample will be appended if it meets the criteria. 25 | 26 | Returns: 27 | None 28 | """ 29 | lang = sample["lang"].lower() 30 | if lang not in args.languages: 31 | return 32 | if lang not in count_per_lang: 33 | count_per_lang[lang] = 1 34 | else: 35 | count_per_lang[lang] += 1 36 | if count_per_lang[lang] > args.max_per_lang: 37 | return 38 | commit.append({ 39 | "code1": sample["old_contents"], 40 | "code2": sample["new_contents"], 41 | "lang": lang, 42 | "git": sample["subject"] 43 | }) 44 | 45 | # bigcode/commitpack 46 | # bigcode/commitpackft 47 | # nuprl/EditPackFT 48 | ds = load_dataset("nuprl/EditPackFT-Multi", split="train") 49 | for sample in iter(ds): 50 | process_sample(sample, count_per_lang, args, commit) 51 | 52 | with open(args.output_path, "w") as f: 53 | json.dump(commit, f, indent=4, sort_keys=True) 54 | -------------------------------------------------------------------------------- /data/submit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pandas as pd 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--dataset_path", type=str, help="Path to the Codenet dataset") 8 | parser.add_argument("--output_path", type=str, help="Output Path") 9 | args = parser.parse_args() 10 | 11 | metadata_path = os.path.join(args.dataset_path, 'metadata') 12 | data_path = os.path.join(args.dataset_path, 'data') 13 | submit = {} 14 | 15 | # Traverse the files in the directory 16 | for file in sorted(os.listdir(metadata_path)): 17 | file_path = os.path.join(metadata_path, file) 18 | problem_id = file.split('.')[0] 19 | if problem_id == "problem_list": 20 | continue 21 | submit[problem_id] = {} 22 | # Load the data 23 | df = pd.read_csv(file_path) 24 | 25 | # Filter users who use a single language 26 | grouped_df = df.groupby('user_id') 27 | language_counts = grouped_df['language'].nunique().reset_index(name='language_count') 28 | one_lang_users = language_counts[language_counts['language_count'] == 1] 29 | if len(one_lang_users) == 0: 30 | continue 31 | 32 | # Filter users who have at least one "Accepted" status 33 | status_list = grouped_df['status'].unique().reset_index(name='status_list') 34 | accepted_users = 
status_list[status_list['status_list'].apply(lambda x: 'Accepted' in x)] 35 | if len(accepted_users) == 0: 36 | continue 37 | 38 | # Combine the first two conditions to find the user IDs that meet the criteria 39 | valid_users = one_lang_users[one_lang_users['user_id'].isin(accepted_users['user_id'])] 40 | if len(valid_users) == 0: 41 | continue 42 | 43 | # Filter the submissions of these users 44 | filtered_submissions = df[df['user_id'].isin(valid_users['user_id'])] 45 | 46 | # Sort by user ID and date 47 | sorted_submissions = filtered_submissions.sort_values(by=['user_id', 'date']) 48 | 49 | # For each user, find the first submission with "Accepted" status and keep this submission and all previous submissions 50 | def filter_submissions(sub_df): 51 | # Find the position of the first record with "Accepted" status 52 | accepted_index = sub_df[sub_df['status'] == 'Accepted'].index.min() 53 | if pd.notna(accepted_index): 54 | # Get the relative position in the sub dataframe 55 | relative_index = sub_df.index.get_loc(accepted_index) 56 | # Return the first "Accepted" submission and all previous submissions 57 | return sub_df.iloc[:relative_index + 1] 58 | 59 | # Apply the filtering logic 60 | final_submissions = sorted_submissions.groupby('user_id').apply(filter_submissions).reset_index(drop=True) 61 | 62 | for user_id, sub_df in final_submissions.groupby('user_id'): 63 | # Iterate over each sub dataframe 64 | submit[problem_id][user_id] = {} 65 | submit[problem_id][user_id]["submissions"] = [] 66 | language = sub_df['language'].iloc[0] 67 | submit[problem_id][user_id]['language'] = language 68 | submissions = sub_df['submission_id'].tolist() 69 | filename_ext = sub_df['filename_ext'].tolist() 70 | for submission, ext in zip(submissions, filename_ext): 71 | submission_path = os.path.join(data_path, problem_id, language, f'{submission}.{ext}') 72 | with open(submission_path, 'r') as f: 73 | code = f.read() 74 | submit[problem_id][user_id]["submissions"].append(code) 75 | 76 | with open(args.output_path, 'w') as f: 77 | json.dump(submit, f, indent=4, sort_keys=True) 78 | -------------------------------------------------------------------------------- /data/submit_process.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tqdm 3 | import argparse 4 | from multiprocessing import Pool 5 | from rouge_score import rouge_scorer 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--context_path", type=str, help="Path to the Codenet Context dataset") 9 | parser.add_argument("--dataset_path", type=str, help="Path to the Codenet Submit dataset") 10 | parser.add_argument("--identical_path", type=str, help="Path to the Codenet Identical dataset") 11 | parser.add_argument("--output_path", type=str, help="Output Path") 12 | parser.add_argument("--num_proc", type=int, default=32, help="Number of processes") 13 | parser.add_argument("--max_per_problem", type=int, default=50, help="Maximum number of submissions per problem") 14 | parser.add_argument("--max_per_lang_problem", type=int, default=10, help="Maximum number of submissions per language per problem") 15 | parser.add_argument("--internal_similarity_threshold", type=float, default=0.7, help="Internal similarity threshold") 16 | parser.add_argument("--external_similarity_threshold", type=float, default=0.4, help="External similarity threshold") 17 | args = parser.parse_args() 18 | 19 | def format_context(context): 20 | """ 21 | Formats the given context into a text description. 
22 | 23 | Args: 24 | context (dict): The context containing information about the problem. 25 | 26 | Returns: 27 | str: The formatted text description. 28 | 29 | """ 30 | constraints = context['constraints'] 31 | input_description = context['input_description'] 32 | output_description = context['output_description'] 33 | problem_description = context['problem_description'] 34 | sample_inputs = context['sample_inputs'] 35 | sample_outputs = context['sample_outputs'] 36 | 37 | # Constructing the text description 38 | if problem_description: 39 | text = "Problem Description:\n" 40 | text += problem_description + "\n\n" 41 | 42 | if constraints: 43 | text += "Constraints:\n" 44 | text += constraints + "\n\n" 45 | 46 | if input_description: 47 | text += "Input Description:\n" 48 | text += input_description + "\n\n" 49 | 50 | if output_description: 51 | text += "Output Description:\n" 52 | text += output_description + "\n\n" 53 | 54 | if sample_inputs: 55 | for i, sample_input, sample_output in zip(range(len(sample_inputs)), sample_inputs, sample_outputs): 56 | text += f"Sample Input {i + 1}:\n" 57 | text += sample_input + "\n\n" 58 | text += f"Sample Output {i + 1}:\n" 59 | text += sample_output + "\n\n" 60 | 61 | return text.strip() 62 | 63 | def calculate_rouge_score(fs_tokens): 64 | """ 65 | Calculates the Rouge score between two sets of tokens. 66 | 67 | Args: 68 | fs_tokens (tuple): A tuple containing two sets of tokens. 69 | 70 | Returns: 71 | float: The Rouge score between the two sets of tokens. 72 | """ 73 | first_tokens, second_tokens = fs_tokens 74 | return rouge_scorer._score_lcs(first_tokens, second_tokens) 75 | 76 | def filter_problem_submissions(problems_submissions, problems_contexts, remove_problem_ids): 77 | """ 78 | Filters problem submissions based on various criteria including internal and external similarity thresholds. 79 | 80 | Args: 81 | problems_submissions (dict): A dictionary where keys are problem IDs and values are dictionaries of user submissions. 82 | problems_contexts (dict): A dictionary where keys are problem IDs and values are the context for each problem. 83 | remove_problem_ids (set): A set of problem IDs to be removed from consideration. 84 | 85 | Returns: 86 | list: A list of dictionaries containing filtered submissions with their respective languages and contexts. 87 | 88 | The function performs the following steps: 89 | 1. Initializes a multiprocessing pool and a ROUGE scorer. 90 | 2. Defines a nested function `filter_submissions` to filter individual submissions based on several criteria: 91 | - Single submission or single line submission. 92 | - Presence of quadruple newlines. 93 | - Duplicate submissions. 94 | - Non-ASCII characters. 95 | - Internal similarity threshold using ROUGE scores. 96 | 3. Iterates over each problem ID in `problems_submissions`: 97 | - Skips problem IDs in `remove_problem_ids`. 98 | - Limits the number of problems and languages per problem based on predefined thresholds. 99 | - Filters submissions using `filter_submissions`. 100 | - Checks external similarity threshold using ROUGE scores. 101 | - Appends valid submissions to the `submit` list. 102 | """ 103 | pool = Pool(args.num_proc) 104 | scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False) 105 | submit = [] 106 | 107 | def filter_submissions(submissions): 108 | """ 109 | Filters a list of submissions based on several criteria. 110 | 111 | Args: 112 | submissions (list of str): A list of submission strings. 
113 | 114 | Returns: 115 | bool: True if the submissions should be filtered out, False otherwise. 116 | 117 | Criteria for filtering: 118 | - If there is only one submission. 119 | - If the last submission contains only one line. 120 | - If the last submission contains four consecutive newline characters. 121 | - If there are any duplicate submissions. 122 | - If any submission contains non-ASCII characters. 123 | - If the ROUGE score between the final submission and any previous submission is below a specified threshold. 124 | """ 125 | if len(submissions) == 1 or len(submissions[-1].split("\n")) == 1 or "\n\n\n\n" in submissions[-1]: 126 | return True 127 | for s1, s2 in zip(submissions[:-1], submissions[1:]): 128 | if s1 == s2: 129 | return True 130 | for submission in submissions: 131 | if any(not char.isascii() for char in submission): 132 | return True 133 | final_submission_tokens = scorer._tokenizer.tokenize(" ".join(submissions[-1].split())) 134 | pre_submissions_tokens = [scorer._tokenizer.tokenize(" ".join(submission.split())) for submission in submissions[:-1]] 135 | pre_final = zip(pre_submissions_tokens, [final_submission_tokens] * len(submissions[:-1])) 136 | rouge_scores = pool.map(calculate_rouge_score, pre_final) 137 | rouge_scores = [score.fmeasure for score in rouge_scores] 138 | if len(rouge_scores) == 0 or max(rouge_scores) < args.internal_similarity_threshold: 139 | return True 140 | return False 141 | 142 | for problem_id in tqdm.tqdm(problems_submissions): 143 | if problem_id in remove_problem_ids: 144 | continue 145 | num_problem = 0 146 | num_lang_problem = {} 147 | all_final_submissions_tokens = [] 148 | for user_id in problems_submissions[problem_id]: 149 | if num_problem >= args.max_per_problem: 150 | break 151 | lang = problems_submissions[problem_id][user_id]['language'] 152 | if num_lang_problem.get(lang, 0) >= args.max_per_lang_problem: 153 | continue 154 | submissions = problems_submissions[problem_id][user_id]["submissions"] 155 | if not filter_submissions(submissions): 156 | final_submission_tokens = scorer._tokenizer.tokenize(" ".join(submissions[-1].split())) 157 | all_final = zip(all_final_submissions_tokens, [final_submission_tokens] * len(all_final_submissions_tokens)) 158 | rouge_scores = pool.map(calculate_rouge_score, all_final) 159 | rouge_scores = [score.fmeasure for score in rouge_scores] 160 | if len(rouge_scores) == 0 or max(rouge_scores) < args.external_similarity_threshold: 161 | all_final_submissions_tokens.append(final_submission_tokens) 162 | num_problem += 1 163 | num_lang_problem[lang] = num_lang_problem.get(lang, 0) + 1 164 | submit.append({ 165 | "language": lang, 166 | "submissions": submissions, 167 | "problems_contexts": format_context(problems_contexts[problem_id]) if problem_id in problems_contexts else "" 168 | }) 169 | return submit 170 | 171 | with open(args.context_path, 'r') as f: 172 | codenet_context = json.load(f) 173 | 174 | with open(args.dataset_path, 'r') as f: 175 | codenet_submit = json.load(f) 176 | 177 | with open(args.identical_path, 'r') as f: 178 | identical_problem_clusters = f.read().split("\n")[:-1] 179 | 180 | remove_problem_ids = set() 181 | for cluster in identical_problem_clusters: 182 | problem_ids = cluster.split(",")[1:] 183 | remove_problem_ids.update(problem_ids) 184 | 185 | submit = filter_problem_submissions(codenet_submit, codenet_context, remove_problem_ids) 186 | 187 | with open(args.output_path, 'w') as f: 188 | json.dump(submit, f, indent=4, sort_keys=True) 189 | 
-------------------------------------------------------------------------------- /eval/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | 3 | This folder contains scripts for evaluating models. 4 | 5 | ## Inference service for LLMs 6 | 7 | Similar to data collection, we uniformly use the OpenAI interface to generate. 8 | 9 | **Note**: We leverages extra parameters specific to `vllm`'s OpenAI-compatible server for handling custom chat templates and special tokens for our models. Other OpenAI-compatible inference services may not be directly applicable. 10 | 11 | Example script to deploy `CursorCore-Yi-1.5B` using `vllm`: 12 | 13 | ```bash 14 | python -m vllm.entrypoints.openai.api_server --port 10086 --model TechxGenus/CursorCore-Yi-1.5B 15 | ``` 16 | 17 | We define the model inference service parameters in `model_map.json`. An example configuration is as follows: 18 | 19 | ```json 20 | { 21 | "TechxGenus/CursorCore-Yi-1.5B": { 22 | "base": "http://127.0.0.1:10086/v1", 23 | "api": "sk-xxx" 24 | } 25 | } 26 | ``` 27 | 28 | ## Run APEval evaluation 29 | 30 | Run the following program to generate predicted code: 31 | 32 | ```bash 33 | # WF Format (Default) 34 | python eval/eval_apeval.py --model_map model_map.json --input_path benchmark/apeval.json --output_path eval/generations.jsonl --temperature 0.0 --use_wf 35 | 36 | # LC Format 37 | python eval/eval_apeval.py --model_map model_map.json --input_path benchmark/apeval.json --output_path eval/generations.jsonl --temperature 0.0 --use_lc 38 | 39 | # SR Format 40 | python eval/eval_apeval.py --model_map model_map.json --input_path benchmark/apeval.json --output_path eval/generations.jsonl --temperature 0.0 --use_sr 41 | 42 | # Instruct Models 43 | python eval/eval_apeval.py --model_map model_map.json --input_path benchmark/apeval.json --output_path eval/generations.jsonl --temperature 0.0 --use_instruct 44 | 45 | # Base Models 46 | python eval/eval_apeval.py --model_map model_map.json --input_path benchmark/apeval.json --output_path eval/generations.jsonl --temperature 0.0 --use_base 47 | ``` 48 | 49 | Run the following script to execute programs: 50 | 51 | ```bash 52 | evalplus.evaluate --dataset humaneval --samples eval/generations.jsonl 53 | ``` 54 | 55 | Run the following script to get evaluation results for each type: 56 | 57 | ```bash 58 | python eval/extract_results.py --dataset_path benchmark/apeval.json --result_path eval/generations_eval_results.json 59 | ``` 60 | 61 | ## Run HumanEval/MBPP evaluation 62 | 63 | Run the following program to generate predicted code: 64 | 65 | ```bash 66 | # Tab 67 | python eval/eval_humaneval.py --model_map model_map.json --input_path evalplus/humanevalplus --output_path eval/generations.jsonl --temperature 0.0 --use_tab 68 | python eval/eval_mbpp.py --model_map model_map.json --input_path evalplus/mbppplus --output_path eval/generations.jsonl --temperature 0.0 --use_tab 69 | 70 | # Inline 71 | python eval/eval_humaneval.py --model_map model_map.json --input_path evalplus/humanevalplus --output_path eval/generations.jsonl --temperature 0.0 --use_inline 72 | python eval/eval_mbpp.py --model_map model_map.json --input_path evalplus/mbppplus --output_path eval/generations.jsonl --temperature 0.0 --use_inline 73 | 74 | # Chat 75 | python eval/eval_humaneval.py --model_map model_map.json --input_path evalplus/humanevalplus --output_path eval/generations.jsonl --temperature 0.0 --use_chat 76 | python eval/eval_mbpp.py --model_map model_map.json 
--input_path evalplus/mbppplus --output_path eval/generations.jsonl --temperature 0.0 --use_chat 77 | ``` 78 | 79 | Run the following script to execute programs: 80 | 81 | ```bash 82 | evalplus.evaluate --dataset humaneval --samples eval/generations.jsonl 83 | evalplus.evaluate --dataset mbpp --samples eval/generations.jsonl 84 | ``` 85 | -------------------------------------------------------------------------------- /eval/eval_apeval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | current_path = os.path.abspath(__file__) 4 | parent_directory = os.path.dirname(os.path.dirname(current_path)) 5 | sys.path.append(parent_directory) 6 | 7 | import json 8 | import argparse 9 | from gen import GenEvaluation 10 | from generic.special_tokens import * 11 | from generic.utils import data_args, openai_args, get_openai_kwargs 12 | from utils import prepare_input_for_wf, prepare_input_for_lc, prepare_input_for_sr, prepare_input_for_base, prepare_input_for_instruct, postprocess_output_wf, postprocess_output_lc, postprocess_output_sr, postprocess_output_instruct, postprocess_output_base 13 | 14 | parser = argparse.ArgumentParser() 15 | parser = data_args(parser) 16 | parser = openai_args(parser) 17 | parser.add_argument("--use_target_area", action="store_true", help="Whether to use target area") 18 | parser.add_argument("--sliding_window", type=int, default=-1, help="Sliding window size") 19 | parser.add_argument("--use_wf", action="store_true", help="Whether to use Whole File model") 20 | parser.add_argument("--use_lc", action="store_true", help="Whether to use Locate and Change model") 21 | parser.add_argument("--use_sr", action="store_true", help="Whether to use Search and Replace model") 22 | parser.add_argument("--use_instruct", action="store_true", help="Whether to use instruct model") 23 | parser.add_argument("--use_base", action="store_true", help="Whether to use base model") 24 | args = parser.parse_args() 25 | openai_kwargs = get_openai_kwargs(args) 26 | 27 | with open(args.input_path, 'r') as f: 28 | dataset = json.load(f) 29 | 30 | conversations = [] 31 | currents = [] 32 | task_ids = [] 33 | for sample in dataset: 34 | task_ids.append(sample["task_id"]) 35 | currents.append(sample["current"]["code"]) 36 | if args.use_target_area: 37 | if "area" in sample: 38 | if type(sample["area"]) == int: 39 | sample["current"]["code"] = sample["current"]["code"][:sample["area"]] + TARGET + sample["current"]["code"][sample["area"]:] 40 | elif type(sample["area"]) == list: 41 | start, end = sample["area"] 42 | sample["current"]["code"] = sample["current"]["code"][:end] + TARGET_END + sample["current"]["code"][end:] 43 | sample["current"]["code"] = sample["current"]["code"][:start] + TARGET_START + sample["current"]["code"][start:] 44 | else: 45 | raise ValueError("Invalid area type: {}".format(type(sample["area"]))) 46 | if args.sliding_window != -1: 47 | assert args.sliding_window > 0, "Sliding window size must be greater than 0" 48 | if len(sample["history"]) >= args.sliding_window: 49 | sample["history"] = sample["history"][-args.sliding_window:] 50 | if args.use_wf: 51 | conversations.append({"conversation": prepare_input_for_wf(sample)}) 52 | elif args.use_lc: 53 | conversations.append({"conversation": prepare_input_for_lc(sample)}) 54 | elif args.use_sr: 55 | conversations.append({"conversation": prepare_input_for_sr(sample)}) 56 | elif args.use_instruct: 57 | conversations.append({"conversation": 
prepare_input_for_instruct(sample)}) 58 | elif args.use_base: 59 | conversations.append({"conversation": prepare_input_for_base(sample)}) 60 | else: 61 | raise ValueError("Invalid model type: {}".format(args.model_type)) 62 | 63 | answers = [] 64 | 65 | with open(args.model_map, 'r') as f: 66 | model_map = json.load(f) 67 | 68 | if args.use_wf or args.use_lc or args.use_sr: 69 | openai_kwargs["extra_body"] = {"skip_special_tokens": False, 'chat_template': 'assistant-conversation'} 70 | openai_kwargs["stop"] = [NEXT_END] 71 | 72 | Gen = GenEvaluation( 73 | model_map=model_map, num_proc=args.num_proc, **openai_kwargs 74 | ) 75 | 76 | if not args.use_base: 77 | results = Gen.gen(conversations) 78 | else: 79 | results = Gen.gen(conversations, api_type="completion") 80 | 81 | if args.use_wf: 82 | output_data = [{"task_id": task_id, "solution": postprocess_output_wf(current, answer["output"])} for task_id, current, answer in zip(task_ids, currents, results)] 83 | elif args.use_lc: 84 | output_data = [{"task_id": task_id, "solution": postprocess_output_lc(current, answer["output"])} for task_id, current, answer in zip(task_ids, currents, results)] 85 | elif args.use_sr: 86 | output_data = [{"task_id": task_id, "solution": postprocess_output_sr(current, answer["output"])} for task_id, current, answer in zip(task_ids, currents, results)] 87 | elif args.use_instruct: 88 | output_data = [{"task_id": task_id, "solution": postprocess_output_instruct(current, answer["output"])} for task_id, current, answer in zip(task_ids, currents, results)] 89 | elif args.use_base: 90 | output_data = [{"task_id": task_id, "solution": postprocess_output_base(current, answer["output"])} for task_id, current, answer in zip(task_ids, currents, results)] 91 | else: 92 | raise ValueError("Invalid model type: {}".format(args.model_type)) 93 | 94 | with open(args.output_path, 'w') as f: 95 | for item in output_data: 96 | json_line = json.dumps(item) 97 | f.write(json_line + '\n') 98 | -------------------------------------------------------------------------------- /eval/eval_humaneval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | current_path = os.path.abspath(__file__) 4 | parent_directory = os.path.dirname(os.path.dirname(current_path)) 5 | sys.path.append(parent_directory) 6 | 7 | import json 8 | import argparse 9 | from gen import GenEvaluation 10 | from datasets import load_dataset 11 | from generic.special_tokens import * 12 | from generic.utils import data_args, openai_args, get_openai_kwargs, decorate_code 13 | from utils import postprocess_output_wf 14 | 15 | parser = argparse.ArgumentParser() 16 | parser = data_args(parser) 17 | parser = openai_args(parser) 18 | parser.add_argument("--use_tab", action="store_true", help="Whether to use automated editing") 19 | parser.add_argument("--use_inline", action="store_true", help="Whether to use inline chat") 20 | parser.add_argument("--use_chat", action="store_true", help="Whether to use chat") 21 | args = parser.parse_args() 22 | openai_kwargs = get_openai_kwargs(args) 23 | 24 | dataset = load_dataset(args.input_path) 25 | conversations = [] 26 | prompts = [] 27 | task_ids = [] 28 | for sample in dataset["test"]: 29 | task_ids.append(sample["task_id"]) 30 | prompt = sample['prompt'].strip() 31 | prompts.append(prompt) 32 | if args.use_tab: 33 | conversations.append({"conversation": [{"role": "current", "content": decorate_code(prompt, lang="python")}]}) 34 | elif args.use_inline: 35 | 
conversations.append({"conversation": [{"role": "current", "content": decorate_code(prompt, lang="python")}, {"role": "user", "content": "Please complete the function."}]}) 36 | elif args.use_chat: 37 | conversations.append({"conversation": [{"role": "user", "content": f"Please continue to complete the function. You are not allowed to modify the given code and do the completion only. Please return all completed function in a codeblock. Here is the given code to do completion:\n```python\n{prompt}\n```"}]}) 38 | else: 39 | raise ValueError("Invalid model type: {}".format(args.model_type)) 40 | 41 | answers = [] 42 | 43 | with open(args.model_map, 'r') as f: 44 | model_map = json.load(f) 45 | 46 | if args.use_tab or args.use_inline: 47 | openai_kwargs["extra_body"] = {"skip_special_tokens": False, 'chat_template': 'assistant-conversation'} 48 | openai_kwargs["stop"] = [NEXT_END] 49 | 50 | Gen = GenEvaluation( 51 | model_map=model_map, num_proc=args.num_proc, **openai_kwargs 52 | ) 53 | 54 | results = Gen.gen(conversations) 55 | 56 | if args.use_tab: 57 | output_data = [{"task_id": task_id, "solution": postprocess_output_wf(current, answer["output"])} for task_id, current, answer in zip(task_ids, prompts, results)] 58 | elif args.use_inline: 59 | output_data = [{"task_id": task_id, "solution": postprocess_output_wf(current, answer["output"])} for task_id, current, answer in zip(task_ids, prompts, results)] 60 | elif args.use_chat: 61 | output_data = [{"task_id": task_id, "solution": answer["output"]} for task_id, answer in zip(task_ids, results)] 62 | else: 63 | raise ValueError("Invalid model type: {}".format(args.model_type)) 64 | 65 | with open(args.output_path, 'w') as f: 66 | for item in output_data: 67 | json_line = json.dumps(item) 68 | f.write(json_line + '\n') 69 | -------------------------------------------------------------------------------- /eval/eval_mbpp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | current_path = os.path.abspath(__file__) 4 | parent_directory = os.path.dirname(os.path.dirname(current_path)) 5 | sys.path.append(parent_directory) 6 | 7 | import json 8 | import argparse 9 | from gen import GenEvaluation 10 | from datasets import load_dataset 11 | from generic.special_tokens import * 12 | from generic.utils import data_args, openai_args, get_openai_kwargs, decorate_code 13 | from utils import postprocess_output_wf 14 | 15 | parser = argparse.ArgumentParser() 16 | parser = data_args(parser) 17 | parser = openai_args(parser) 18 | parser.add_argument("--use_tab", action="store_true", help="Whether to use automated editing") 19 | parser.add_argument("--use_inline", action="store_true", help="Whether to use inline chat") 20 | parser.add_argument("--use_chat", action="store_true", help="Whether to use chat") 21 | args = parser.parse_args() 22 | openai_kwargs = get_openai_kwargs(args) 23 | 24 | dataset = load_dataset(args.input_path) 25 | conversations = [] 26 | prompts = [] 27 | task_ids = [] 28 | for sample in dataset["test"]: 29 | task_ids.append(sample["task_id"]) 30 | prompt = f'"""\n{sample["prompt"]}\n{sample["test_list"][0]}\n"""\n' 31 | prompts.append(prompt) 32 | if args.use_tab: 33 | conversations.append({"conversation": [{"role": "current", "content": decorate_code(prompt, lang="python")}]}) 34 | elif args.use_inline: 35 | conversations.append({"conversation": [{"role": "current", "content": decorate_code(prompt, lang="python")}, {"role": "user", "content": "Please complete the 
function."}]}) 36 | elif args.use_chat: 37 | conversations.append({"conversation": [{"role": "user", "content": f"Please continue to complete the function. You are not allowed to modify the given code and do the completion only. Please return all completed function in a codeblock. Here is the given code to do completion:\n```python\n{prompt}\n```"}]}) 38 | else: 39 | raise ValueError("Invalid model type: {}".format(args.model_type)) 40 | 41 | answers = [] 42 | 43 | with open(args.model_map, 'r') as f: 44 | model_map = json.load(f) 45 | 46 | if args.use_tab or args.use_inline: 47 | openai_kwargs["extra_body"] = {"skip_special_tokens": False, 'chat_template': 'assistant-conversation'} 48 | openai_kwargs["stop"] = [NEXT_END] 49 | 50 | Gen = GenEvaluation( 51 | model_map=model_map, num_proc=args.num_proc, **openai_kwargs 52 | ) 53 | 54 | results = Gen.gen(conversations) 55 | 56 | if args.use_tab: 57 | output_data = [{"task_id": f"Mbpp/{task_id}", "solution": postprocess_output_wf(current, answer["output"])} for task_id, current, answer in zip(task_ids, prompts, results)] 58 | elif args.use_inline: 59 | output_data = [{"task_id": f"Mbpp/{task_id}", "solution": postprocess_output_wf(current, answer["output"])} for task_id, current, answer in zip(task_ids, prompts, results)] 60 | elif args.use_chat: 61 | output_data = [{"task_id": f"Mbpp/{task_id}", "solution": answer["output"]} for task_id, answer in zip(task_ids, results)] 62 | else: 63 | raise ValueError("Invalid model type: {}".format(args.model_type)) 64 | 65 | with open(args.output_path, 'w') as f: 66 | for item in output_data: 67 | json_line = json.dumps(item) 68 | f.write(json_line + '\n') 69 | -------------------------------------------------------------------------------- /eval/extract_results.py: -------------------------------------------------------------------------------- 1 | #TODO: Robust pass@k implementation, only evaluation of greedy decoding now 2 | import json 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--dataset_path", type=str, help="Path to the dataset file") 7 | parser.add_argument("--result_path", type=str, help="Path to the result file") 8 | args = parser.parse_args() 9 | 10 | with open(args.dataset_path, "r") as f: 11 | dataset = json.load(f) 12 | 13 | with open(args.result_path, "r") as f: 14 | result = json.load(f)["eval"] 15 | 16 | # ------------------------------------------------- 17 | # 1. Count how many samples are in each category 18 | # ------------------------------------------------- 19 | total_nh_nu = 0 # No history, no user 20 | total_h_nu = 0 # History, no user 21 | total_nh_u = 0 # No history, user 22 | total_h_u = 0 # History, user 23 | 24 | for sample in dataset: 25 | if not sample["history"] and not sample["user"]: 26 | total_nh_nu += 1 27 | elif sample["history"] and not sample["user"]: 28 | total_h_nu += 1 29 | elif not sample["history"] and sample["user"]: 30 | total_nh_u += 1 31 | else: 32 | total_h_u += 1 33 | 34 | total_all = len(dataset) 35 | 36 | # ------------------------------------------------- 37 | # 2.
Initialize counters for passes in each category 38 | # ------------------------------------------------- 39 | c_base = 0 # base pass: no history, no user 40 | h_c_base = 0 # base pass: history, no user 41 | c_u_base = 0 # base pass: no history, user 42 | h_c_u_base = 0 # base pass: history, user 43 | 44 | c_extra = 0 # plus pass: no history, no user 45 | h_c_extra = 0 # plus pass: history, no user 46 | c_u_extra = 0 # plus pass: no history, user 47 | h_c_u_extra = 0 # plus pass: history, user 48 | 49 | # ------------------------------------------------- 50 | # 3. Fill in counters by checking pass status 51 | # ------------------------------------------------- 52 | for sample in dataset: 53 | base_status = result[sample["task_id"]][0]["base_status"] 54 | plus_status = result[sample["task_id"]][0]["plus_status"] 55 | 56 | # Base pass checks 57 | if base_status == "pass": 58 | if not sample["history"] and not sample["user"]: 59 | c_base += 1 60 | elif sample["history"] and not sample["user"]: 61 | h_c_base += 1 62 | elif not sample["history"] and sample["user"]: 63 | c_u_base += 1 64 | else: 65 | h_c_u_base += 1 66 | 67 | # Plus pass checks 68 | if plus_status == "pass": 69 | if not sample["history"] and not sample["user"]: 70 | c_extra += 1 71 | elif sample["history"] and not sample["user"]: 72 | h_c_extra += 1 73 | elif not sample["history"] and sample["user"]: 74 | c_u_extra += 1 75 | else: 76 | h_c_u_extra += 1 77 | 78 | 79 | # ------------------------------------------------- 80 | # 4. Helper function to safely handle divisions 81 | # ------------------------------------------------- 82 | def ratio_str(numerator, denominator): 83 | """Return 'count/denominator (xx.x%)', handling zero denominator.""" 84 | if denominator == 0: 85 | return f"{numerator}/0 (N/A)" 86 | else: 87 | return f"{numerator}/{denominator} ({numerator / denominator:.1%})" 88 | 89 | 90 | # ------------------------------------------------- 91 | # 5. Print results for Base 92 | # ------------------------------------------------- 93 | print("Base Status:") 94 | print(f" No History, No User: {ratio_str(c_base, total_nh_nu)}") 95 | print(f" History, No User: {ratio_str(h_c_base, total_h_nu)}") 96 | print(f" No History, User: {ratio_str(c_u_base, total_nh_u)}") 97 | print(f" History, User: {ratio_str(h_c_u_base, total_h_u)}") 98 | print( 99 | f" Total: {ratio_str(c_base + h_c_base + c_u_base + h_c_u_base, total_all)}" 100 | ) 101 | 102 | # ------------------------------------------------- 103 | # 6. 
Print results for Plus 104 | # ------------------------------------------------- 105 | print("\nPlus Status:") 106 | print(f" No History, No User: {ratio_str(c_extra, total_nh_nu)}") 107 | print(f" History, No User: {ratio_str(h_c_extra, total_h_nu)}") 108 | print(f" No History, User: {ratio_str(c_u_extra, total_nh_u)}") 109 | print(f" History, User: {ratio_str(h_c_u_extra, total_h_u)}") 110 | print( 111 | f" Total: {ratio_str(c_extra + h_c_extra + c_u_extra + h_c_u_extra, total_all)}" 112 | ) 113 | -------------------------------------------------------------------------------- /eval/search_and_replace.py: -------------------------------------------------------------------------------- 1 | # Ref: https://github.com/sweepai/sweep/blob/main/sweepai/utils/search_and_replace.py 2 | 3 | import re 4 | from dataclasses import dataclass 5 | from functools import lru_cache 6 | 7 | from rapidfuzz import fuzz 8 | from tqdm import tqdm 9 | 10 | 11 | @lru_cache() 12 | def score_line(str1: str, str2: str) -> float: 13 | if str1 == str2: 14 | return 100 15 | 16 | if str1.lstrip() == str2.lstrip(): 17 | whitespace_ratio = abs(len(str1) - len(str2)) / (len(str1) + len(str2)) 18 | score = 90 - whitespace_ratio * 10 19 | return max(score, 0) 20 | 21 | if str1.strip() == str2.strip(): 22 | whitespace_ratio = abs(len(str1) - len(str2)) / (len(str1) + len(str2)) 23 | score = 80 - whitespace_ratio * 10 24 | return max(score, 0) 25 | 26 | levenshtein_ratio = fuzz.ratio(str1, str2) 27 | 28 | score = 85 * (levenshtein_ratio / 100) 29 | return max(score, 0) 30 | 31 | 32 | def match_without_whitespace(str1: str, str2: str) -> bool: 33 | return str1.strip() == str2.strip() 34 | 35 | 36 | def line_cost(line: str) -> float: 37 | if line.strip() == "": 38 | return 50 39 | if line.strip().startswith("#") or line.strip().startswith("//"): 40 | return 50 + len(line) / (len(line) + 1) * 30 41 | return len(line) / (len(line) + 1) * 100 42 | 43 | 44 | def score_multiline(query: list[str], target: list[str]) -> float: 45 | # TODO: add weighting on first and last lines 46 | 47 | q, t = 0, 0 # indices for query and target 48 | scores: list[tuple[float, float]] = [] 49 | skipped_comments = 0 50 | 51 | def get_weight(q: int) -> float: 52 | # Prefers lines at beginning and end of query 53 | # Sequence: 1, 2/3, 1/2, 2/5... 
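        # Illustrative walk-through (assuming a query long enough that the minimum
        # below is taken from the left end): index is the distance from q to the
        # nearer end of the query, so the weights are 100/(0/2+1)=100,
        # 100/(1/2+1)=66.7, 100/(2/2+1)=50, 100/(3/2+1)=40, ... i.e. proportional
        # to the 1, 2/3, 1/2, 2/5 sequence noted above.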
54 | index = min(q, len(query) - q) 55 | return 100 / (index / 2 + 1) 56 | 57 | while q < len(query) and t < len(target): 58 | q_line = query[q] 59 | t_line = target[t] 60 | weight = get_weight(q) 61 | 62 | if match_without_whitespace(q_line, t_line): 63 | # Case 1: lines match 64 | scores.append((score_line(q_line, t_line), weight)) 65 | q += 1 66 | t += 1 67 | elif q_line.strip().startswith("...") or q_line.strip().endswith("..."): 68 | # Case 3: ellipsis wildcard 69 | t += 1 70 | if q + 1 == len(query): 71 | scores.append((100 - (len(target) - t), weight)) 72 | q += 1 73 | t = len(target) 74 | break 75 | max_score = 0 76 | # Radix optimization 77 | indices = [ 78 | t + i 79 | for i, line in enumerate(target[t:]) 80 | if match_without_whitespace(line, query[q + 1]) 81 | ] 82 | if not indices: 83 | indices = range(t, len(target)) 84 | for i in indices: 85 | score, weight = score_multiline(query[q + 1 :], target[i:]), ( 86 | 100 - (i - t) / len(target) * 10 87 | ) 88 | new_scores = scores + [(score, weight)] 89 | total_score = sum( 90 | [value * weight for value, weight in new_scores] 91 | ) / sum([weight for _, weight in new_scores]) 92 | max_score = max(max_score, total_score) 93 | return max_score 94 | elif ( 95 | t_line.strip() == "" 96 | or t_line.strip().startswith("#") 97 | or t_line.strip().startswith("//") 98 | or t_line.strip().startswith("print") 99 | or t_line.strip().startswith("logger") 100 | or t_line.strip().startswith("console.") 101 | ): 102 | # Case 2: skipped comment 103 | skipped_comments += 1 104 | t += 1 105 | scores.append((90, weight)) 106 | else: 107 | break 108 | 109 | if q < len(query): 110 | scores.extend( 111 | (100 - line_cost(line), get_weight(index)) 112 | for index, line in enumerate(query[q:]) 113 | ) 114 | if t < len(target): 115 | scores.extend( 116 | (100 - line_cost(line), 100) for index, line in enumerate(target[t:]) 117 | ) 118 | 119 | final_score = ( 120 | sum([value * weight for value, weight in scores]) 121 | / sum([weight for _, weight in scores]) 122 | if scores 123 | else 0 124 | ) 125 | final_score *= 1 - 0.05 * skipped_comments 126 | 127 | return final_score 128 | 129 | 130 | @dataclass 131 | class Match: 132 | start: int 133 | end: int 134 | score: float 135 | indent: str = "" 136 | 137 | def __gt__(self, other): 138 | return self.score > other.score 139 | 140 | 141 | def get_indent_type(content: str): 142 | two_spaces = len(re.findall(r"\n {2}[^ ]", content)) 143 | four_spaces = len(re.findall(r"\n {4}[^ ]", content)) 144 | 145 | return " " if two_spaces > four_spaces else " " 146 | 147 | 148 | def get_max_indent(content: str, indent_type: str): 149 | return max(len(line) - len(line.lstrip()) for line in content.split("\n")) // len( 150 | indent_type 151 | ) 152 | 153 | # Bug: raise "not enough values to unpack (expected 2, got 1)" for some inputs 154 | def find_best_match(query: str, code_file: str): 155 | best_match = Match(-1, -1, 0) 156 | 157 | code_file_lines = code_file.split("\n") 158 | query_lines = query.split("\n") 159 | if len(query_lines) > 0 and query_lines[-1].strip() == "...": 160 | query_lines = query_lines[:-1] 161 | if len(query_lines) > 0 and query_lines[0].strip() == "...": 162 | query_lines = query_lines[1:] 163 | indent = get_indent_type(code_file) 164 | max_indents = get_max_indent(code_file, indent) 165 | 166 | top_matches = [] 167 | 168 | if len(query_lines) == 1: 169 | for i, line in enumerate(code_file_lines): 170 | score = score_line(line, query_lines[0]) 171 | if score > best_match.score: 172 | best_match = 
Match(i, i + 1, score) 173 | return best_match 174 | 175 | truncate = min(40, len(code_file_lines) // 5) 176 | if truncate < 1: 177 | truncate = len(code_file_lines) 178 | 179 | indent_array = [i for i in range(0, max(min(max_indents + 1, 20), 1))] 180 | if max_indents > 3: 181 | indent_array = [3, 2, 4, 0, 1] + list(range(5, max_indents + 1)) 182 | for num_indents in indent_array: 183 | indented_query_lines = [indent * num_indents + line for line in query_lines] 184 | 185 | start_pairs = [ 186 | (i, score_line(line, indented_query_lines[0])) 187 | for i, line in enumerate(code_file_lines) 188 | ] 189 | start_pairs.sort(key=lambda x: x[1], reverse=True) 190 | start_pairs = start_pairs[:truncate] 191 | start_indices = [i for i, _ in start_pairs] 192 | 193 | for i in tqdm( 194 | start_indices, 195 | position=0, 196 | desc=f"Indent {num_indents}/{max_indents}", 197 | leave=False, 198 | ): 199 | end_pairs = [ 200 | (j, score_line(line, indented_query_lines[-1])) 201 | for j, line in enumerate(code_file_lines[i:], start=i) 202 | ] 203 | end_pairs.sort(key=lambda x: x[1], reverse=True) 204 | end_pairs = end_pairs[:truncate] 205 | end_indices = [j for j, _ in end_pairs] 206 | 207 | for j in tqdm( 208 | end_indices, position=1, leave=False, desc=f"Starting line {i}" 209 | ): 210 | candidate = code_file_lines[i : j + 1] 211 | raw_score = score_multiline(indented_query_lines, candidate) 212 | 213 | score = raw_score * (1 - num_indents * 0.01) 214 | current_match = Match(i, j + 1, score, indent * num_indents) 215 | 216 | if raw_score >= 99.99: # early exit, 99.99 for floating point error 217 | return current_match 218 | 219 | top_matches.append(current_match) 220 | 221 | if score > best_match.score: 222 | best_match = current_match 223 | 224 | unique_top_matches: list[Match] = [] 225 | unique_spans = set() 226 | for top_match in sorted(top_matches, reverse=True): 227 | if (top_match.start, top_match.end) not in unique_spans: 228 | unique_top_matches.append(top_match) 229 | unique_spans.add((top_match.start, top_match.end)) 230 | 231 | # Todo: on_comment file comments able to modify multiple files 232 | return unique_top_matches[0] if unique_top_matches else Match(-1, -1, 0) 233 | 234 | 235 | def split_ellipses(query: str) -> list[str]: 236 | queries = [] 237 | current_query = "" 238 | for line in query.split("\n"): 239 | if line.strip() == "...": 240 | queries.append(current_query.strip("\n")) 241 | current_query = "" 242 | else: 243 | current_query += line + "\n" 244 | queries.append(current_query.strip("\n")) 245 | return queries 246 | 247 | 248 | def match_indent(generated: str, original: str) -> str: 249 | indent_type = "\t" if "\t" in original[:5] else " " 250 | generated_indents = len(generated) - len(generated.lstrip()) 251 | target_indents = len(original) - len(original.lstrip()) 252 | diff_indents = target_indents - generated_indents 253 | if diff_indents > 0: 254 | generated = indent_type * diff_indents + generated.replace( 255 | "\n", "\n" + indent_type * diff_indents 256 | ) 257 | return generated 258 | -------------------------------------------------------------------------------- /eval/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | current_path = os.path.abspath(__file__) 4 | parent_directory = os.path.dirname(os.path.dirname(current_path)) 5 | sys.path.append(parent_directory) 6 | 7 | import re 8 | import difflib 9 | import Levenshtein 10 | from generic.special_tokens import * 11 | from .search_and_replace 
import find_best_match 12 | from generic.utils import decorate_code, extract_changes_lines, generate_locations_changes, generate_search_and_replace 13 | 14 | def postprocess_output_wf(current, output): 15 | """ 16 | Processes the given output string to extract a specific section of text 17 | between defined markers and returns the first match found. 18 | 19 | Args: 20 | current (str): The current string to return in case of an error. 21 | output (str): The output string to be processed. 22 | 23 | Returns: 24 | str: The extracted section of text if found, otherwise returns the 25 | current string. 26 | 27 | Raises: 28 | Exception: If an error occurs during processing, the exception is 29 | caught and the current string is returned. 30 | """ 31 | try: 32 | output = output.split(NEXT_START)[-1].split(NEXT_END)[0] 33 | pattern = r"```(.*?)\n([\s\S]*?)\n```" 34 | wf = re.findall(pattern, output) 35 | return wf[0][1] 36 | except Exception as e: 37 | print(e) 38 | return current 39 | 40 | def postprocess_output_lc(current, output): 41 | """ 42 | Post-processes the output by extracting and applying code modifications to the current code. 43 | 44 | Args: 45 | current (str): The current code as a string. 46 | output (str): The output containing the modifications. 47 | 48 | Returns: 49 | str: The updated code after applying the modifications. 50 | 51 | The function expects the `output` to contain code modifications in a specific format: 52 | - The modifications are enclosed between `NEXT_START` and `NEXT_END` markers. 53 | - Each modification block follows the pattern: `start_line,end_line\n```\n\n```` 54 | - `start_line` and `end_line` specify the line range in the current code to be replaced by ``. 55 | 56 | If an error occurs during processing, the function prints the exception and returns the original `current` code. 57 | """ 58 | try: 59 | output = output.split(NEXT_START)[-1].split(NEXT_END)[0].strip() 60 | pattern = r"(\d+),(\d+)\n```(.*?)\n([\s\S]*?)\n```" 61 | lc = re.findall(pattern, output) 62 | current_lines = current.split("\n") 63 | for start_line, end_line, _, code in lc[::-1]: 64 | start_line = int(start_line) 65 | end_line = int(end_line) 66 | current_lines = current_lines[:start_line] + code.split("\n") + current_lines[end_line:] 67 | return "\n".join(current_lines) 68 | except Exception as e: 69 | print(e) 70 | return current 71 | 72 | def postprocess_output_sr(current, output): 73 | """ 74 | Post-processes the output string by extracting and applying search-and-replace operations. 75 | 76 | Args: 77 | current (str): The current string content to be modified. 78 | output (str): The output string containing the search-and-replace instructions. 79 | 80 | Returns: 81 | str: The modified string after applying the search-and-replace operations. 82 | 83 | The function performs the following steps: 84 | 1. Extracts the relevant portion of the output string between NEXT_START and NEXT_END markers. 85 | 2. Finds all code blocks within the extracted portion using a regular expression pattern. 86 | 3. For each code block, splits it into 'before' and 'after' parts using the SEARCH_AND_REPLACE marker. 87 | 4. Finds the best match for the 'before' part in the current string. 88 | 5. Replaces the matched portion in the current string with the 'after' part. 89 | 6. Returns the modified string. 90 | 91 | If an exception occurs during processing, the function prints the exception and returns the original current string. 
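    Illustrative sketch of the expected output layout (NEXT_START, NEXT_END and
    SEARCH_AND_REPLACE stand for the special tokens imported from
    generic.special_tokens; their literal values are not reproduced here):

        <NEXT_START>
        ```python
        lines copied from the current code to search for
        <SEARCH_AND_REPLACE>
        replacement lines
        ```
        <NEXT_END>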
92 | """ 93 | try: 94 | output = output.split(NEXT_START)[-1].split(NEXT_END)[0].strip() 95 | pattern = r"```(.*?)\n([\s\S]*?)\n```" 96 | sr = re.findall(pattern, output) 97 | current_lines = current.split("\n") 98 | for _, before_and_after in sr[::-1]: 99 | before, after = before_and_after.split("\n" + SEARCH_AND_REPLACE + "\n") 100 | match_before = find_best_match(before, current) 101 | current_lines = current_lines[:match_before.start] + after.split("\n") + current_lines[match_before.end:] 102 | return "\n".join(current_lines) 103 | except Exception as e: 104 | print(e) 105 | return current 106 | 107 | def postprocess_output_base(current, output): 108 | """ 109 | Processes the given output string to extract content between specific markers. 110 | 111 | This function splits the output string using the NEXT_START and NEXT_END markers, 112 | then uses a regular expression to find content enclosed in triple backticks (```) 113 | and returns the first match. If an error occurs during processing, the function 114 | returns the current string. 115 | 116 | Args: 117 | current (str): The current string to return in case of an error. 118 | output (str): The output string to be processed. 119 | 120 | Returns: 121 | str: The extracted content between the markers, or the current string if an error occurs. 122 | """ 123 | try: 124 | output = output.split(NEXT_START)[-1].split(NEXT_END)[0] 125 | pattern = r"```(.*?)\n([\s\S]*?)\n```" 126 | wf = re.findall(pattern, output) 127 | return wf[0][1] 128 | except Exception as e: 129 | print(e) 130 | return current 131 | 132 | def postprocess_output_instruct(current, output): 133 | """ 134 | Post-processes the given output string by extracting the content between 135 | specific markers and returning the first code block found. 136 | 137 | Args: 138 | current (str): The current string to return in case of an error. 139 | output (str): The output string to be processed. 140 | 141 | Returns: 142 | str: The extracted code block from the output string. If an error occurs, 143 | the function returns the 'current' string. 144 | """ 145 | try: 146 | output = output.split(NEXT_START)[-1].split(NEXT_END)[0] 147 | pattern = r"```(.*?)\n([\s\S]*?)\n```" 148 | wf = re.findall(pattern, output) 149 | return wf[0][1] 150 | except Exception as e: 151 | print(e) 152 | return current 153 | 154 | def prepare_input_for_instruct(sample): 155 | """ 156 | Prepares a conversation input for an instruction-following model based on the provided sample. 157 | 158 | Args: 159 | sample (dict): A dictionary containing the following keys: 160 | - "history" (list): A list of dictionaries, each containing: 161 | - "code" (str): The code from a previous programming process. 162 | - "lang" (str): The programming language of the code. 163 | - "current" (dict): A dictionary containing: 164 | - "code" (str): The current code to be modified. 165 | - "lang" (str): The programming language of the current code. 166 | - "user" (str): The user's instruction for modifying the current code. 167 | 168 | Returns: 169 | list: A list of dictionaries representing the conversation, where each dictionary contains: 170 | - "role" (str): The role in the conversation, either "user" or "assistant". 171 | - "content" (str): The content of the message. 
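    Illustrative shape of the returned conversation (content abridged; the exact
    wording is constructed by the function body):

        [
            {"role": "user", "content": "<one-shot example input>"},
            {"role": "assistant", "content": "<one-shot example answer>"},
            {"role": "user", "content": "Read the following messages ... Current code: ... User instruction: ..."},
        ]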
172 | """ 173 | conversation = [] 174 | one_shot = [{"role": "user", "content": "Read the following messages during programming and return the modified code in this format:\n\n<|next_start|>{modified code}<|next_end|>\n\nProgramming process 1:\n```python\na = 1\nb = 2\nc = a + b\n```\n\nCurrent code:\n```python\ni = 1\nb = 2\nc = a + b\n```\n\nUser instruction:\nPlease change variable names."}, {"role": "assistant", "content": "<|next_start|>```python\ni = 1\nj = 2\nk = i + j\n```<|next_end|>"}] 175 | conversation += one_shot 176 | prompt = "" 177 | prompt += "Read the following messages during programming and return the modified code in this format:\n\n<|next_start|>{modified code}<|next_end|>" 178 | prompt += "\n\n" 179 | if sample["history"]: 180 | for i, h in enumerate(sample["history"]): 181 | prompt += f"Programming process {i + 1}:\n" 182 | prompt += decorate_code(h["code"], lang=h["lang"]) + "\n\n" 183 | prompt += "Current code:\n" 184 | prompt += decorate_code(sample["current"]["code"], lang=sample["current"]["lang"]) + "\n\n" 185 | if sample["user"]: 186 | prompt += "User instruction:\n" 187 | prompt += sample["user"] + "\n\n" 188 | conversation.append({"role": "user", "content": prompt.strip()}) 189 | return conversation 190 | 191 | def prepare_input_for_base(sample): 192 | """ 193 | Prepares a formatted input string for a base model by combining a prompt, one-shot example, 194 | and the provided sample data including history, current code, and user instructions. 195 | 196 | Args: 197 | sample (dict): A dictionary containing the following keys: 198 | - "history" (list): A list of dictionaries, each containing: 199 | - "code" (str): The code snippet from the history. 200 | - "lang" (str): The programming language of the code snippet. 201 | - "current" (dict): A dictionary containing: 202 | - "code" (str): The current code snippet. 203 | - "lang" (str): The programming language of the current code snippet. 204 | - "user" (str): The user instruction. 205 | 206 | Returns: 207 | str: A formatted string that includes the prompt, one-shot example, and the sample data. 208 | """ 209 | prompt = "Read the following messages during programming and return the modified code in this format:\n\n<|next_start|>{modified code}<|next_end|>\n\n" 210 | one_shot = "<|messages_start|>Programming process 1:\n```python\na = 1\nb = 2\nc = a + b\n```\n\nCurrent code:\n```python\ni = 1\nb = 2\nc = a + b\n```\n\nUser instruction:\nPlease change variable names.<|messages_end|>\n\n<|next_start|>```python\ni = 1\nj = 2\nk = i + j\n```<|next_end|>\n\n" 211 | prompt += one_shot 212 | prompt += "Read the following messages during programming and return the modified code in this format:\n\n<|next_start|>{modified code}<|next_end|>\n\n<|messages_start|>" 213 | if sample["history"]: 214 | for i, h in enumerate(sample["history"]): 215 | prompt += f"Programming process {i + 1}:\n" 216 | prompt += decorate_code(h["code"], lang=h["lang"]) + "\n\n" 217 | prompt += "Current code:\n" 218 | prompt += decorate_code(sample["current"]["code"], lang=sample["current"]["lang"]) + "\n\n" 219 | if sample["user"]: 220 | prompt += "User instruction:\n" 221 | prompt += sample["user"] + "\n\n" 222 | prompt = prompt.strip() + "<|messages_end|>\n\n" 223 | return prompt 224 | 225 | def prepare_input_for_wf(sample): 226 | """ 227 | Prepares the input data for workflow processing by formatting the conversation history. 
228 | 229 | Args: 230 | sample (dict): A dictionary containing the following keys: 231 | - "history" (list): A list of dictionaries representing past messages. Each dictionary contains: 232 | - "code" (str): The code snippet. 233 | - "lang" (str): The programming language of the code snippet. 234 | - "current" (dict): A dictionary representing the current message with the same structure as the history messages. 235 | - "user" (str): The user's input message. 236 | 237 | Returns: 238 | list: A list of dictionaries representing the conversation. Each dictionary contains: 239 | - "role" (str): The role of the message, either "history", "current", or "user". 240 | - "content" (str): The formatted content of the message. 241 | """ 242 | conversation = [] 243 | if sample["history"]: 244 | history_current = sample["history"] + [sample["current"]] 245 | for m1, m2 in zip(history_current[:-1], history_current[1:]): 246 | message = decorate_code(m1["code"], m1["lang"]) 247 | conversation.append({"role": "history", "content": message}) 248 | conversation.append({"role": "current", "content": decorate_code(sample["current"]["code"], sample["current"]["lang"])}) 249 | if sample["user"]: 250 | conversation.append({"role": "user", "content": sample["user"]}) 251 | return conversation 252 | 253 | def prepare_input_for_lc(sample): 254 | """ 255 | Prepares the input data for language model conversation (LC) by processing the sample's history, current code, and user input. 256 | 257 | Args: 258 | sample (dict): A dictionary containing the following keys: 259 | - "history" (list): A list of dictionaries representing the history of code changes. 260 | - "current" (dict): A dictionary representing the current code state with keys: 261 | - "code" (str): The current code. 262 | - "lang" (str): The programming language of the current code. 263 | - "user" (str): The user's input or query. 264 | 265 | Returns: 266 | list: A list of dictionaries representing the conversation, where each dictionary has: 267 | - "role" (str): The role in the conversation, either "history", "current", or "user". 268 | - "content" (str): The content associated with the role, such as code changes or user input. 269 | """ 270 | conversation = [] 271 | if sample["history"]: 272 | history_current = sample["history"] + [sample["current"]] 273 | for m1, m2 in zip(history_current[:-1], history_current[1:]): 274 | changes_lines = extract_changes_lines(m2["code"], m1["code"]) 275 | locations_changes = generate_locations_changes(m2["code"], m1["code"], m1["lang"], changes_lines) 276 | conversation.append({"role": "history", "content": locations_changes}) 277 | conversation.append({"role": "current", "content": decorate_code(sample["current"]["code"], sample["current"]["lang"], use_line_num=True)}) 278 | if sample["user"]: 279 | conversation.append({"role": "user", "content": sample["user"]}) 280 | return conversation 281 | 282 | def prepare_input_for_sr(sample): 283 | """ 284 | Prepares the input for a search and replace (SR) task based on the provided sample. 285 | 286 | Args: 287 | sample (dict): A dictionary containing the following keys: 288 | - "history" (list): A list of dictionaries representing the history of code changes. 289 | - "current" (dict): A dictionary representing the current code state with keys: 290 | - "code" (str): The current code. 291 | - "lang" (str): The programming language of the current code. 292 | - "user" (str, optional): A string representing the user's input or query. 
293 | 294 | Returns: 295 | list: A list of dictionaries representing the conversation for the SR task. Each dictionary contains: 296 | - "role" (str): The role in the conversation, either "history", "current", or "user". 297 | - "content" (str): The content associated with the role, such as code changes or user input. 298 | """ 299 | conversation = [] 300 | if sample["history"]: 301 | history_current = sample["history"] + [sample["current"]] 302 | for m1, m2 in zip(history_current[:-1], history_current[1:]): 303 | changes_lines = extract_changes_lines(m2["code"], m1["code"], unique=True, merge_changes=True) 304 | changes_lines = [(new, old) for old, new in changes_lines] 305 | search_and_replace = generate_search_and_replace(m1["code"], m2["code"], m1["lang"], changes_lines) 306 | conversation.append({"role": "history", "content": search_and_replace}) 307 | conversation.append({"role": "current", "content": decorate_code(sample["current"]["code"], sample["current"]["lang"])}) 308 | if sample["user"]: 309 | conversation.append({"role": "user", "content": sample["user"]}) 310 | return conversation 311 | -------------------------------------------------------------------------------- /gen/__init__.py: -------------------------------------------------------------------------------- 1 | from .genaiprogrammer import AIProgrammer 2 | from .genjudge import GenJudgement 3 | from .geninst import GenInstruction 4 | from .gencht import GenChat 5 | from .llmeval import GenEvaluation 6 | -------------------------------------------------------------------------------- /gen/genaiprogrammer.py: -------------------------------------------------------------------------------- 1 | import random 2 | from .llmgen import LLMGen 3 | from .template.aiprogrammer_template import NOVICE_AIPROGRAMMER_SYSTEM, ORDINARY_AIPROGRAMMER_SYSTEM, EXPERT_AIPROGRAMMER_SYSTEM, AIPROGRAMMER_PROMPT_INPUT, AIPROGRAMMER_PROMPT_OUTPUT, NOVICE_AIPROGRAMMER_FEWSHOT, ORDINARY_AIPROGRAMMER_FEWSHOT, EXPERT_AIPROGRAMMER_FEWSHOT 4 | from .utils import extract_code_blocks 5 | 6 | class AIProgrammer(LLMGen): 7 | def __init__(self, backend="openai", model_map={}, max_try=5, **kwargs) -> None: 8 | super().__init__(backend, model_map, max_try, **kwargs) 9 | 10 | def create_prompt(self, text_map): 11 | """ 12 | Generates a prompt for an AI programmer based on a random skill level. 13 | 14 | The function randomly selects a skill level from "NOVICE", "ORDINARY", or "EXPERT". 15 | Based on the selected skill level, it sets the appropriate system message and few-shot examples. 16 | It then constructs a list of messages that includes the system message, alternating user and assistant 17 | messages from the few-shot examples, and a final user message based on the provided text_map. 18 | 19 | Args: 20 | text_map (dict): A dictionary containing the input text to be formatted into the final user message. 21 | 22 | Returns: 23 | list: A list of dictionaries representing the prompt messages for the AI programmer. 
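        Illustrative shape of the returned prompt (content abridged):

            [
                {"role": "system", "content": "<skill-level system message>"},
                {"role": "user", "content": "<few-shot input 1>"},
                {"role": "assistant", "content": "<few-shot output 1>"},
                ...
                {"role": "user", "content": "<text_map formatted as the final input>"},
            ]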
24 | """ 25 | dice = random.choice(["NOVICE", "ORDINARY", "EXPERT"]) 26 | if dice == "NOVICE": 27 | AIPROGRAMMER_SYSTEM = NOVICE_AIPROGRAMMER_SYSTEM 28 | AIPROGRAMMER_FEWSHOT = NOVICE_AIPROGRAMMER_FEWSHOT 29 | elif dice == "ORDINARY": 30 | AIPROGRAMMER_SYSTEM = ORDINARY_AIPROGRAMMER_SYSTEM 31 | AIPROGRAMMER_FEWSHOT = ORDINARY_AIPROGRAMMER_FEWSHOT 32 | else: 33 | AIPROGRAMMER_SYSTEM = EXPERT_AIPROGRAMMER_SYSTEM 34 | AIPROGRAMMER_FEWSHOT = EXPERT_AIPROGRAMMER_FEWSHOT 35 | out = [ 36 | {"role": "system", "content": AIPROGRAMMER_SYSTEM}, 37 | ] + [ 38 | { 39 | "role": "user", 40 | "content": AIPROGRAMMER_PROMPT_INPUT.format_map(shot) 41 | } if i % 2 == 0 else { 42 | "role": "assistant", 43 | "content": AIPROGRAMMER_PROMPT_OUTPUT.format_map(shot) 44 | } for i, shot in enumerate(AIPROGRAMMER_FEWSHOT) 45 | ] + [ 46 | {"role": "user", "content": AIPROGRAMMER_PROMPT_INPUT.format_map(text_map)} 47 | ] 48 | return out 49 | 50 | def reject_response(self, text_map, response): 51 | """ 52 | Determines whether a given response should be rejected based on various criteria. 53 | 54 | Args: 55 | text_map (dict): A dictionary containing the original input text with a key "content". 56 | response (str): The response text to be evaluated. 57 | 58 | Returns: 59 | bool: True if the response should be rejected, False otherwise. 60 | 61 | The function checks the following conditions to decide if the response should be rejected: 62 | - The response length is less than 20 characters. 63 | - The response contains the word "sorry". 64 | - The response does not contain any code blocks or the last code block is not the same as the input code. 65 | - The response contains repeated code blocks. 66 | - The response contains a summary of the process and repeats the last block. 67 | - The response contains segments of code where the sum of the lengths of all but the last block is less than or equal to the length of the last block plus 2 lines. 68 | - The response contains any consecutive identical code blocks. 69 | """ 70 | if len(response) < 20: 71 | return True 72 | if "sorry" in response.lower(): 73 | return True 74 | try: 75 | blocks = extract_code_blocks(response) 76 | # if the last code block is not the same as the input code 77 | if len(blocks) == 0 or blocks[-1].strip() != text_map["content"].strip(): 78 | return True 79 | # sometimes llm will summarize the process and repeat the last block 80 | if len(blocks) >= 2 and blocks[-2] == blocks[-1]: 81 | blocks = blocks[:-1] 82 | block_lengths = [len(block.split("\n")) for block in blocks] 83 | # filter each step is just a segment of the whole code 84 | if len(block_lengths) >= 4 and sum(block_lengths[:-1]) <= block_lengths[-1] + 2: 85 | return True 86 | # filter if any block is the same during the process 87 | elif any(blocks[i] == blocks[i + 1] for i in range(len(blocks) - 1)): 88 | return True 89 | except: 90 | return True 91 | return False 92 | 93 | def post_process(self, text_map, response): 94 | """ 95 | Post-processes the response by extracting code blocks and removing any repeated blocks. 96 | 97 | Args: 98 | text_map (dict): A mapping of text elements. 99 | response (str): The response string containing code blocks. 100 | 101 | Returns: 102 | list: A list of extracted code blocks with any repeated blocks removed. 
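        Illustrative example: if the response contains the code blocks
        [step1, step2, final, final] (the model repeated the final code as a
        summary), the returned list is [step1, step2, final].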
103 | """ 104 | blocks = extract_code_blocks(response) 105 | # sometimes llm will summarize the process and repeat the last block 106 | if len(blocks) >= 2 and blocks[-2] == blocks[-1]: 107 | blocks = blocks[:-1] 108 | return blocks 109 | -------------------------------------------------------------------------------- /gen/gencht.py: -------------------------------------------------------------------------------- 1 | from .llmgen import LLMGen 2 | from .template.gencht_template import GENCHAT_SYSTEM, GENCHAT_PROMPT_INPUT, GENCHAT_PROMPT_OUTPUT, GENCHAT_FEWSHOT, GENCHAT_RECORD_TYPE 3 | 4 | 5 | class GenChat(LLMGen): 6 | def __init__(self, backend="openai", model_map={}, max_try=5, num_proc=512, **kwargs) -> None: 7 | super().__init__(backend, model_map, max_try, num_proc, **kwargs) 8 | 9 | def __exit__(self, exc_type, exc_val, exc_tb): 10 | return super().__exit__(exc_type, exc_val, exc_tb) 11 | 12 | def create_prompt(self, text_map): 13 | """ 14 | Generates a prompt for a conversational AI model based on provided text mappings. 15 | 16 | Args: 17 | text_map (dict): A dictionary containing the following keys: 18 | - "record" (list): A list of dictionaries, each representing a record with a "type" key. 19 | - "change" (str): A string describing the change. 20 | 21 | Returns: 22 | list: A list of dictionaries, each representing a message in the conversation. Each dictionary contains: 23 | - "role" (str): The role of the message sender, either "system", "user", or "assistant". 24 | - "content" (str): The content of the message. 25 | """ 26 | out = [ 27 | {"role": "system", "content": GENCHAT_SYSTEM}, 28 | ] + [ 29 | {"role": "user", "content": GENCHAT_PROMPT_INPUT.format_map( 30 | { 31 | "record": "\n".join(GENCHAT_RECORD_TYPE[r["type"]].format_map(r) for r in shot["record"]), 32 | "change": shot["change"], 33 | } 34 | )} if i % 2 == 0 else {"role": "assistant", "content": GENCHAT_PROMPT_OUTPUT.format_map(shot)} 35 | for i, shot in enumerate(GENCHAT_FEWSHOT) 36 | ] + [ 37 | { 38 | "role": "user", 39 | "content": GENCHAT_PROMPT_INPUT.format_map( 40 | { 41 | "record": "\n".join(GENCHAT_RECORD_TYPE[r["type"]].format_map(r) for r in text_map["record"]), 42 | "change": text_map["change"], 43 | } 44 | ), 45 | }, 46 | ] 47 | return out 48 | 49 | def reject_response(self, text_map, response): 50 | """ 51 | Determines whether a given response should be rejected based on its format. 52 | 53 | Args: 54 | text_map (dict): A dictionary containing text mappings (not used in the current implementation). 55 | response (str): The response string to be evaluated. 56 | 57 | Returns: 58 | bool: True if the response should be rejected (i.e., it does not start with "**chat:**"), False otherwise. 59 | """ 60 | if not response.strip().startswith("**chat:**"): 61 | return True 62 | return False 63 | 64 | def post_process(self, text_map, response): 65 | """ 66 | Post-processes the response by stripping whitespace and extracting the relevant part. 67 | 68 | Args: 69 | text_map (dict): A dictionary containing text mappings (not used in this function). 70 | response (str): The response string to be processed. 71 | 72 | Returns: 73 | str: The processed response, which is the part after the last occurrence of "**chat:**". 
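        Illustrative example: a response of "**chat:** We should rename the
        variables for clarity." is returned as "We should rename the variables
        for clarity.".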
74 | """ 75 | return response.strip().split("**chat:**")[-1].strip() 76 | -------------------------------------------------------------------------------- /gen/geninst.py: -------------------------------------------------------------------------------- 1 | from .llmgen import LLMGen 2 | from .template.geninst_template import GENINST_SYSTEM, GENINST_PROMPT_INPUT, GENINST_PROMPT_OUTPUT, GENINST_FEWSHOT, GENINST_RECORD_TYPE 3 | 4 | 5 | class GenInstruction(LLMGen): 6 | def __init__(self, backend="openai", model_map={}, max_try=5, num_proc=512, **kwargs) -> None: 7 | super().__init__(backend, model_map, max_try, num_proc, **kwargs) 8 | 9 | def __exit__(self, exc_type, exc_val, exc_tb): 10 | return super().__exit__(exc_type, exc_val, exc_tb) 11 | 12 | def create_prompt(self, text_map): 13 | """ 14 | Generates a prompt for the language model based on the provided text map and predefined few-shot examples. 15 | 16 | Args: 17 | text_map (dict): A dictionary containing the 'record' and 'change' keys. 18 | 'record' is a list of dictionaries, each with a 'type' key. 19 | 'change' is a string describing the change. 20 | 21 | Returns: 22 | list: A list of dictionaries representing the prompt, where each dictionary has 'role' and 'content' keys. 23 | """ 24 | out = [ 25 | {"role": "system", "content": GENINST_SYSTEM}, 26 | ] + [ 27 | {"role": "user", "content": GENINST_PROMPT_INPUT.format_map( 28 | { 29 | "record": "\n".join(GENINST_RECORD_TYPE[r["type"]].format_map(r) for r in shot["record"]), 30 | "change": shot["change"], 31 | } 32 | )} if i % 2 == 0 else {"role": "assistant", "content": GENINST_PROMPT_OUTPUT.format_map(shot)} 33 | for i, shot in enumerate(GENINST_FEWSHOT) 34 | ] + [ 35 | { 36 | "role": "user", 37 | "content": GENINST_PROMPT_INPUT.format_map( 38 | { 39 | "record": "\n".join(GENINST_RECORD_TYPE[r["type"]].format_map(r) for r in text_map["record"]), 40 | "change": text_map["change"], 41 | } 42 | ), 43 | }, 44 | ] 45 | return out 46 | 47 | def reject_response(self, text_map, response): 48 | """ 49 | Determines whether a given response should be rejected based on specific criteria. 50 | 51 | Args: 52 | text_map (dict): A dictionary containing text mappings (not used in the current implementation). 53 | response (str): The response string to be evaluated. 54 | 55 | Returns: 56 | bool: True if the response should be rejected, False otherwise. 57 | 58 | The response is rejected if: 59 | - It does not start with "**instruction:**" after stripping leading and trailing whitespace. 60 | - It contains the substring "\nNote:". 61 | - It contains the phrase "no change" (case insensitive). 62 | """ 63 | if not response.strip().startswith("**instruction:**") or "\nNote:" in response or "no change" in response.lower(): 64 | return True 65 | return False 66 | 67 | def post_process(self, text_map, response): 68 | """ 69 | Post-processes the response by stripping leading and trailing whitespace and 70 | splitting the text based on the delimiter "**instruction:**". It returns the 71 | last segment of the split response. 72 | 73 | Args: 74 | text_map (dict): A dictionary mapping text segments (not used in this method). 75 | response (str): The response string to be post-processed. 76 | 77 | Returns: 78 | str: The processed response string. 
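        Illustrative example: a response of "**instruction:** Rename the loop
        variable so its purpose is clearer." is returned as "Rename the loop
        variable so its purpose is clearer.".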
79 | """ 80 | return response.strip().split("**instruction:**")[-1].strip() 81 | -------------------------------------------------------------------------------- /gen/genjudge.py: -------------------------------------------------------------------------------- 1 | from .llmgen import LLMGen 2 | from .template.genjudge_template import GENJUDGE_SYSTEM, GENJUDGE_RECORD_TYPE, GENJUDGE_PROMPT_INPUT, GENJUDGE_PROMPT_INPUT_RECORD, GENJUDGE_PROMPT_INPUT_CHANGE, GENJUDGE_PROMPT_OUTPUT, GENJUDGE_FEWSHOT 3 | 4 | class GenJudgement(LLMGen): 5 | def __init__(self, backend="openai", model_map={}, max_try=5, num_proc=512, **kwargs) -> None: 6 | super().__init__(backend, model_map, max_try, num_proc, **kwargs) 7 | 8 | def __exit__(self, exc_type, exc_val, exc_tb): 9 | return super().__exit__(exc_type, exc_val, exc_tb) 10 | 11 | def create_prompt(self, text_map): 12 | """ 13 | Generates a prompt for the GENJUDGE model based on the provided text map. 14 | 15 | The prompt consists of a series of messages formatted for the model, including 16 | system instructions, few-shot examples, and the user input based on the given text map. 17 | 18 | Args: 19 | text_map (dict): A dictionary containing the 'record' and 'change' keys. 20 | 'record' is a list of dictionaries, each representing a record with a 'type' key. 21 | 'change' is a list of changes to be included in the prompt. 22 | 23 | Returns: 24 | list: A list of dictionaries, each representing a message in the prompt. 25 | The messages alternate between 'user' and 'assistant' roles, with the final message being the user input. 26 | """ 27 | out = [ 28 | {"role": "system", "content": GENJUDGE_SYSTEM}, 29 | ] + [ 30 | {"role": "user", "content": GENJUDGE_PROMPT_INPUT.format_map( 31 | { 32 | "record": GENJUDGE_PROMPT_INPUT_RECORD.format_map({"record": "\n".join(GENJUDGE_RECORD_TYPE[r["type"]].format_map(r) for r in shot["record"])}), 33 | "change": "\n".join(GENJUDGE_PROMPT_INPUT_CHANGE.format_map({"num": i+1, "change": c}) for i, c in enumerate(shot["change"])), 34 | } 35 | )} if i % 2 == 0 else {"role": "assistant", "content": GENJUDGE_PROMPT_OUTPUT.format_map(shot)} 36 | for i, shot in enumerate(GENJUDGE_FEWSHOT) 37 | ] + [ 38 | { 39 | "role": "user", 40 | "content": GENJUDGE_PROMPT_INPUT.format_map( 41 | { 42 | "record": GENJUDGE_PROMPT_INPUT_RECORD.format_map({"record": "\n".join(GENJUDGE_RECORD_TYPE[r["type"]].format_map(r) for r in text_map["record"])}), 43 | "change": "\n".join(GENJUDGE_PROMPT_INPUT_CHANGE.format_map({"num": i+1, "change": c}) for i, c in enumerate(text_map["change"])), 44 | } 45 | ), 46 | }, 47 | ] 48 | return out 49 | 50 | def reject_response(self, text_map, response): 51 | """ 52 | Determines whether a given response should be rejected based on specific criteria. 53 | 54 | Args: 55 | text_map (dict): A dictionary containing the text data, specifically with a key "change". 56 | response (str): The response string to be evaluated. 57 | 58 | Returns: 59 | bool: True if the response should be rejected, False otherwise. 60 | 61 | Criteria for rejection: 62 | - The response length is less than 20 characters. 63 | - The response contains the word "sorry" (case insensitive). 64 | - The number of segments in the response after splitting by "Analysis of change" does not match the number of changes in text_map["change"]. 65 | - Any segment in the response does not contain "**Decision:**" or does not have "True" or "False" following "**Decision:**". 
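        Illustrative example of a response that passes these checks when
        text_map["change"] holds two changes:

            Analysis of change 1: ... **Decision:** True
            Analysis of change 2: ... **Decision:** False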
66 | """ 67 | if len(response) < 20: 68 | return True 69 | if "sorry" in response.lower(): 70 | return True 71 | each_judge = response.split("Analysis of change")[1:] 72 | if len(each_judge) != len(text_map["change"]): 73 | return True 74 | for judge in each_judge: 75 | if "**Decision:**" not in judge or ("True" not in judge.split("**Decision:**")[-1] and "False" not in judge.split("**Decision:**")[-1]): 76 | return True 77 | return False 78 | 79 | def post_process(self, text_map, response): 80 | """ 81 | Processes the response text to extract decision outcomes. 82 | 83 | Args: 84 | text_map (dict): A mapping of text elements (not used in the function). 85 | response (str): The response string containing multiple judge analyses. 86 | 87 | Returns: 88 | list: A list of boolean values representing the decisions extracted from the response. 89 | Each boolean corresponds to a decision where True indicates a positive decision 90 | and False indicates a negative decision. 91 | 92 | Raises: 93 | AssertionError: If a decision is not found in any of the judge analyses. 94 | """ 95 | each_judge = response.split("Analysis of change")[1:] 96 | out = [] 97 | for judge in each_judge: 98 | if "True" in judge.split("**Decision:**")[-1]: 99 | out.append(True) 100 | elif "False" in judge.split("**Decision:**")[-1]: 101 | out.append(False) 102 | else: 103 | assert False, f"Decision not found in {judge}" 104 | return out 105 | -------------------------------------------------------------------------------- /gen/llmeval.py: -------------------------------------------------------------------------------- 1 | from .llmgen import LLMGen 2 | 3 | 4 | class GenEvaluation(LLMGen): 5 | def __init__(self, backend="openai", model_map={}, max_try=1, num_proc=512, **kwargs) -> None: 6 | super().__init__(backend, model_map, max_try, num_proc, **kwargs) 7 | 8 | def __exit__(self, exc_type, exc_val, exc_tb): 9 | return super().__exit__(exc_type, exc_val, exc_tb) 10 | 11 | def create_prompt(self, text_map): 12 | return text_map["conversation"] 13 | 14 | def reject_response(self, text_map, response): 15 | return False 16 | 17 | def post_process(self, text_map, response): 18 | return response 19 | -------------------------------------------------------------------------------- /gen/llmgen.py: -------------------------------------------------------------------------------- 1 | import random 2 | import concurrent.futures 3 | 4 | class LLMGen: 5 | def __init__(self, backend="openai", model_map={}, max_try=5, num_proc=512, **kwargs) -> None: 6 | self.model_name = list(model_map.keys()) 7 | self.max_try = max_try 8 | if backend == "openai": 9 | from .openai import OpenAICompletion 10 | self.backend = OpenAICompletion(model_map) 11 | elif backend == "test": 12 | self.backend = None 13 | else: 14 | raise ValueError( 15 | f"backend {backend} is currently not supported" 16 | ) 17 | self.kwargs = kwargs 18 | self.executor = concurrent.futures.ThreadPoolExecutor(num_proc) 19 | 20 | def __exit__(self, exc_type, exc_val, exc_tb): 21 | self.executor.shutdown() 22 | 23 | def create_prompt(self, text_map, model_name): 24 | pass 25 | 26 | def reject_response(self, text_map, response): 27 | pass 28 | 29 | def post_process(self, text_map, response): 30 | pass 31 | 32 | def gen( 33 | self, text_maps, api_type="chat_completion" 34 | ): 35 | def process_text_map(text_map): 36 | """ 37 | Processes a given text map by generating a response using a specified model. 38 | 39 | Args: 40 | text_map (dict): The input text map to be processed. 
41 | 42 | Returns: 43 | dict: A dictionary containing the input text map and the processed output. 44 | If the response is rejected after the maximum number of tries, the output will be None. 45 | 46 | Raises: 47 | ValueError: If the specified api_type is not supported. 48 | """ 49 | success = False 50 | try_count = 0 51 | answer = "" 52 | while try_count < self.max_try and not success: 53 | if type(self.model_name) == list: 54 | model_name = random.choice(self.model_name) 55 | else: 56 | model_name = self.model_name 57 | input_text = self.create_prompt(text_map) 58 | if api_type == "chat_completion": 59 | answer = self.backend.chat_completion(input_text, model_name, **self.kwargs) 60 | elif api_type == "completion": 61 | answer = self.backend.completion(input_text, model_name, **self.kwargs) 62 | else: 63 | raise ValueError( 64 | f"api_type {api_type} is currently not supported" 65 | ) 66 | if self.reject_response(text_map, answer): 67 | try_count += 1 68 | continue 69 | success = True 70 | if self.reject_response(text_map, answer): 71 | return {"input": text_map, "output": None} 72 | else: 73 | return {"input": text_map, "output": self.post_process(text_map, answer)} 74 | 75 | gen = [] 76 | results = self.executor.map(process_text_map, text_maps) 77 | for result in results: 78 | gen.append(result) 79 | return gen 80 | -------------------------------------------------------------------------------- /gen/openai.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | #TODO: Support batch api calls 4 | class OpenAICompletion: 5 | def __init__( 6 | self, 7 | model_map, 8 | ) -> None: 9 | self.client = { 10 | model: OpenAI( 11 | base_url=model_map[model]["base"], 12 | api_key=model_map[model]["api"], 13 | ) 14 | for model in model_map 15 | } 16 | 17 | def chat_completion( 18 | self, 19 | prompt, 20 | model_name="deepseek-chat", 21 | temperature=0.2, 22 | max_tokens=3072, 23 | top_p=0.95, 24 | frequency_penalty=0, 25 | presence_penalty=0, 26 | stop=["<|eot_id|>", "<|im_end|>", "", "<|EOT|>", "<|endoftext|>", "<|eos|>"], # default stop tokens 27 | timeout=200, 28 | extra_body={}, 29 | ): 30 | response = "" 31 | client = self.client[model_name] 32 | try: 33 | response = client.chat.completions.create( 34 | model=model_name, 35 | messages=prompt, 36 | temperature=temperature, 37 | max_tokens=max_tokens, 38 | top_p=top_p, 39 | frequency_penalty=frequency_penalty, 40 | presence_penalty=presence_penalty, 41 | stop=stop, 42 | timeout=timeout, 43 | extra_body=extra_body, 44 | ) 45 | except Exception as e: 46 | print("Exception", e) 47 | if response != "": 48 | if response.choices[0].finish_reason == "length": 49 | return "" 50 | return response.choices[0].message.content 51 | return "" 52 | 53 | def completion( 54 | self, 55 | prompt, 56 | model_name="deepseek-chat", 57 | temperature=0.2, 58 | max_tokens=3072, 59 | top_p=0.95, 60 | frequency_penalty=0, 61 | presence_penalty=0, 62 | stop=["", "<|endoftext|>", "<|eos_token|>"], # default stop tokens 63 | timeout=200, 64 | extra_body={}, 65 | ): 66 | response = "" 67 | client = self.client[model_name] 68 | try: 69 | response = client.completions.create( 70 | model=model_name, 71 | prompt=prompt, 72 | temperature=temperature, 73 | max_tokens=max_tokens, 74 | top_p=top_p, 75 | frequency_penalty=frequency_penalty, 76 | presence_penalty=presence_penalty, 77 | stop=stop, 78 | timeout=timeout, 79 | extra_body=extra_body, 80 | ) 81 | except Exception as e: 82 | print("Exception", e) 83 | if 
response != "": 84 | return response.choices[0].text 85 | return "" 86 | -------------------------------------------------------------------------------- /gen/template/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/gen/template/__init__.py -------------------------------------------------------------------------------- /gen/template/gencht_template.py: -------------------------------------------------------------------------------- 1 | GENCHAT_SYSTEM = """You are a programming assistant. The following content includes information related to your programming assistance, which may contain the record of the programming process, the current code, the user instruction, and your predicted modifications. Please provide the chat conversation for making the prediction. This may include analyzing the past programming process, speculating on the user's intent, and explaining the planning and ideas for modifying the code. Return your chat conversation in the following format: 2 | ``` 3 | **chat:** 4 | {chat} 5 | ```""" 6 | 7 | GENCHAT_RECORD_TYPE = { 8 | "current": """Current code: 9 | {current} 10 | """, 11 | "history": """Revised code changes: 12 | {history} 13 | """, 14 | "user": """User instruction: 15 | {user} 16 | """, 17 | } 18 | 19 | GENCHAT_PROMPT_INPUT = """{record} 20 | Predicted modifications: 21 | {change}""" 22 | 23 | GENCHAT_PROMPT_OUTPUT = """**chat:** 24 | {chat}""" 25 | 26 | GENCHAT_FEWSHOT = [ 27 | { 28 | "record": [ 29 | { 30 | "type": "history", 31 | "history": """```diff\n@@ -14,3 +14,30 @@\n if (row == n) {\n vector board = generateBoard(queens, n);\n solutions.push_back(board);\n+ } else {\n+ for (int i = 0; i < n; i++) {\n+ if (columns.find(i) != columns.end()) {\n+ continue;\n+ }\n+ int diagonal1 = row - i;\n+ if (diagonals1.find(diagonal1) != diagonals1.end()) {\n+ continue;\n+ }\n+ int diagonal2 = row + i;\n+ if (diagonals2.find(diagonal2) != diagonals2.end()) {\n+ continue;\n+ }\n+ queens[row] = i;\n+ columns.insert(i);\n+ diagonals1.insert(diagonal1);\n+ diagonals2.insert(diagonal2);\n+ backtrack(solutions, queens, n, row + 1, columns, diagonals1, diagonals2);\n+ queens[row] = -1;\n+ columns.erase(i);\n+ diagonals1.erase(diagonal1);\n+ diagonals2.erase(diagonal2);\n+ }\n+ }\n+ }\n+\n+ vector generateBoard(vector &queens, int n)\n```""" 32 | }, 33 | { 34 | "type": "history", 35 | "history": """```diff\n@@ -3,41 +3,3 @@\n vector> solveNQueens(int n) {\n auto solutions = vector>();\n auto queens = vector(n, -1);\n- auto columns = unordered_set();\n- auto diagonals1 = unordered_set();\n- auto diagonals2 = unordered_set();\n- backtrack(solutions, queens, n, 0, columns, diagonals1, diagonals2);\n- return solutions;\n- }\n-\n- void backtrack(vector> &solutions, vector &queens, int n, int row, unordered_set &columns, unordered_set &diagonals1, unordered_set &diagonals2) {\n- if (row == n) {\n- vector board = generateBoard(queens, n);\n- solutions.push_back(board);\n- } else {\n- for (int i = 0; i < n; i++) {\n- if (columns.find(i) != columns.end()) {\n- continue;\n- }\n- int diagonal1 = row - i;\n- if (diagonals1.find(diagonal1) != diagonals1.end()) {\n- continue;\n- }\n- int diagonal2 = row + i;\n- if (diagonals2.find(diagonal2) != diagonals2.end()) {\n- continue;\n- }\n- queens[row] = i;\n- columns.insert(i);\n- diagonals1.insert(diagonal1);\n- diagonals2.insert(diagonal2);\n- backtrack(solutions, queens, n, row + 1, columns, 
diagonals1, diagonals2);\n- queens[row] = -1;\n- columns.erase(i);\n- diagonals1.erase(diagonal1);\n- diagonals2.erase(diagonal2);\n- }\n- }\n- }\n-\n- vector generateBoard(vector &queens, int n)\n```""" 36 | }, 37 | { 38 | "type": "history", 39 | "history": """```diff\n@@ -3,3 +3,17 @@\n vector> solveNQueens(int n) {\n auto solutions = vector>();\n auto queens = vector(n, -1);\n+ solve(solutions, queens, n, 0, 0, 0, 0);\n+ return solutions;\n+ }\n+\n+ vector generateBoard(vector &queens, int n) {\n+ auto board = vector();\n+ for (int i = 0; i < n; i++) {\n+ string row = string(n, '.');\n+ row[queens[i]] = 'Q';\n+ board.push_back(row);\n+ }\n+ return board;\n+ }\n+};\n```""" 40 | }, 41 | { 42 | "type": "current", 43 | "current": """```cpp\n1 class Solution {\n2 public:\n3 vector> solveNQueens(int n) {\n4 auto solutions = vector>();\n5 auto queens = vector(n, -1);\n6 solve(solutions, queens, n, 0, 0, 0, 0);\n7 return solutions;\n8 }\n9 \n10 vector generateBoard(vector &queens, int n) {\n11 auto board = vector();\n12 for (int i = 0; i < n; i++) {\n13 string row = string(n, '.');\n14 row[queens[i]] = 'Q';\n15 board.push_back(row);\n16 }\n17 return board;\n18 }\n19 };\n```""" 44 | } 45 | ], 46 | "change": """```diff\n@@ -5,6 +5,23 @@\n auto queens = vector(n, -1);\n solve(solutions, queens, n, 0, 0, 0, 0);\n return solutions;\n+ }\n+\n+ void solve(vector> &solutions, vector &queens, int n, int row, int columns, int diagonals1, int diagonals2) {\n+ if (row == n) {\n+ auto board = generateBoard(queens, n);\n+ solutions.push_back(board);\n+ } else {\n+ int availablePositions = ((1 << n) - 1) & (~(columns | diagonals1 | diagonals2));\n+ while (availablePositions != 0) {\n+ int position = availablePositions & (-availablePositions);\n+ availablePositions = availablePositions & (availablePositions - 1);\n+ int column = __builtin_ctz(position);\n+ queens[row] = column;\n+ solve(solutions, queens, n, row + 1, columns | position, (diagonals1 | position) >> 1, (diagonals2 | position) << 1);\n+ queens[row] = -1;\n+ }\n+ }\n }\n \n vector generateBoard(vector &queens, int n) {\n```""" 47 | }, 48 | { 49 | "chat": """The revised code transitions from using sets to bitwise operations for solving the N-Queens problem, aiming to optimize the algorithm by reducing dynamic memory allocations and improving speed. Initially, the original code utilized sets (`unordered_set`) to track columns and diagonals. The user likely seeks to enhance efficiency, prompting the replacement of these sets with integer bit masks. To achieve this, the `solve` function is adjusted to use bitwise operations for conflict checks, calculating `availablePositions` using bitwise AND and NOT operations to identify valid queen placements. The `__builtin_ctz` function finds the least significant bit set to 1, corresponding to a valid column, and the recursive `solve` call updates bit masks for columns and diagonals after placing a queen. This maintains the backtracking logic while leveraging bitwise operations for performance. The `generateBoard` function remains unchanged, as it simply formats the result. This approach ensures efficient, streamlined code for the N-Queens problem.""" 50 | }, 51 | { 52 | "record": [ 53 | { 54 | "type": "current", 55 | "current": """```\n1 const express = require('express');\n2 const app = express();\n```""" 56 | }, 57 | { 58 | "type": "user", 59 | "user": """Add two new routes to the Express.js application:\n\n1. 
`/upper`: This route should accept a query parameter `string` and return the uppercase version of the string. If no string is provided, it should return a 400 error with a message "Please provide a string".\n2. `/lower`: This route should accept a query parameter `string` and return the lowercase version of the string. If no string is provided, it should return a 400 error with a message "Please provide a string".\n\nImplement these routes by modifying the existing code.""" 60 | } 61 | ], 62 | "change": """```diff\n@@ -1,2 +1,22 @@\n const express = require('express');\n const app = express();\n+\n+app.get('/upper', (req, res) => {\n+ let string = req.query.string;\n+ if (!string) {\n+ res.status(400).send('Please provide a string');\n+ return;\n+ }\n+ res.send(string.toUpperCase());\n+});\n+\n+app.get('/lower', (req, res) => {\n+ let string = req.query.string;\n+ if (!string) {\n+ res.status(400).send('Please provide a string');\n+ return;\n+ }\n+ res.send(string.toLowerCase());\n+});\n+\n+app.listen(3000, () => console.log('Server started'));\n```""" 63 | }, 64 | { 65 | "chat": """We should add two new routes to the existing Express.js application. These routes should handle the conversion of a query parameter `string` to either uppercase or lowercase, depending on the route accessed. If the query parameter is not provided, the routes should return a 400 error with an appropriate message.\n\nTo implement this, we will:\n1. Define a route `/upper` that checks for the presence of the query parameter `string`. If it exists, we will convert it to uppercase and return the result. If it does not exist, we will return a 400 error with the message "Please provide a string".\n2. Define a route `/lower` with similar logic, but it will convert the `string` to lowercase.\n\nFinally, we will ensure that the server is listening on a specified port, in this case, port 3000.""" 66 | } 67 | ] 68 | -------------------------------------------------------------------------------- /gen/template/geninst_template.py: -------------------------------------------------------------------------------- 1 | GENINST_SYSTEM = """You are a programming assistant. The following content includes information related to your programming assistance, which may contain the record of the programming process, the current code, the git commit after all changes, relevant details about the problem, and your predicted modifications. Please generate an instruction for you to make the corresponding modifications, ensuring it resembles instructions typically given by a human programmer. The instruction may be detailed or concise and may or may not specify the location of the modification. 
Return the generated instruction in the following format: 2 | ``` 3 | **instruction:** 4 | {instruction} 5 | ```""" 6 | 7 | GENINST_RECORD_TYPE = { 8 | "current": """Current code: 9 | {current} 10 | """, 11 | "history": """Revised code changes: 12 | {history} 13 | """, 14 | "git": """Git commit message after all changes: 15 | {git} 16 | """, 17 | "problem": """Relevant details about the problem: 18 | {problem} 19 | """, 20 | } 21 | 22 | GENINST_PROMPT_INPUT = """{record} 23 | Changes in predictions: 24 | {change}""" 25 | 26 | GENINST_PROMPT_OUTPUT = """**instruction:** 27 | {instruction}""" 28 | 29 | GENINST_FEWSHOT = [ 30 | { 31 | "record": [ 32 | { 33 | "type": "current", 34 | "current": """```python\n1 import dedupe\n2 \n3 def detect_fraud(transactions):\n4 deduper = dedupe.Dedupe()\n5 deduper.sample(transactions)\n6 deduper.train()\n7 \n8 clusters = deduper.match(transactions, threshold=0.5)\n9 \n10 return clusters\n11 \n12 transactions = [\n13 {'id': 1, 'amount': 100.0},\n14 {'id': 2, 'amount': 200.0},\n15 {'id': 3, 'amount': 150.0},\n16 ]\n17 \n18 fraud_clusters = detect_fraud(transactions)\n19 print(fraud_clusters)\n```""" 35 | } 36 | ], 37 | "change": """```diff\n@@ -15,5 +15,10 @@\n {'id': 3, 'amount': 150.0},\n ]\n \n+# Replace '_sample' with 'sample'\n+deduper = dedupe.Dedupe()\n+deduper.sample = deduper._sample\n+del deduper._sample\n+\n fraud_clusters = detect_fraud(transactions)\n print(fraud_clusters)\n```""" 38 | }, 39 | { 40 | "instruction": """Replace the internal '_sample' method with 'sample' in the dedupe object, approximately at line 17.""" 41 | }, 42 | { 43 | "record": [ 44 | { 45 | "type": "history", 46 | "history": """```diff\n@@ -3,6 +3,10 @@\n def create_cnn_model(in_channels, config):\n layers = []\n conv2d = nn.Conv2d(in_channels, config, kernel_size=3, padding=1)\n- layers += [conv2d, nn.ReLU(inplace=True)]\n+ if batch_norm:\n+ layers += [conv2d, nn.BatchNorm2d(config)]\n+ else:\n+ layers += [conv2d]\n+ layers += [nn.ReLU(inplace=True)]\n model = nn.Sequential(*layers)\n return model\n```""" 47 | }, 48 | { 49 | "type": "history", 50 | "history": """```diff\n@@ -1,6 +1,6 @@\n import torch.nn as nn\n \n-def create_cnn_model(in_channels, config):\n+def create_cnn_model(in_channels, config, batch_norm=False):\n layers = []\n conv2d = nn.Conv2d(in_channels, config, kernel_size=3, padding=1)\n if batch_norm:\n```""" 51 | }, 52 | { 53 | "type": "current", 54 | "current": """```\n1 import torch.nn as nn\n2 \n3 def create_cnn_model(in_channels, config, batch_norm=False):\n4 layers = []\n5 conv2d = nn.Conv2d(in_channels, config, kernel_size=3, padding=1)\n6 if batch_norm:\n7 layers += [conv2d, nn.BatchNorm2d(config)]\n8 else:\n9 layers += [conv2d]\n10 layers += [nn.ReLU(inplace=True)]\n11 model = nn.Sequential(*layers)\n12 return model\n```""" 55 | } 56 | ], 57 | "change": """```diff\n@@ -1,12 +1,11 @@\n import torch.nn as nn\n \n-def create_cnn_model(in_channels, config, batch_norm=False):\n+def create_cnn_model(in_channels, config, batch_norm):\n layers = []\n conv2d = nn.Conv2d(in_channels, config, kernel_size=3, padding=1)\n if batch_norm:\n- layers += [conv2d, nn.BatchNorm2d(config)]\n+ layers += [conv2d, nn.BatchNorm2d(config), nn.ReLU(inplace=True)]\n else:\n- layers += [conv2d]\n- layers += [nn.ReLU(inplace=True)]\n+ layers += [conv2d, nn.ReLU(inplace=True)]\n model = nn.Sequential(*layers)\n return model\n```""" 58 | }, 59 | { 60 | "instruction": """Update the create_cnn_model function to ensure that the ReLU activation function is added immediately after the 
BatchNorm layer if batch_norm is enabled. Adjust the function signature to remove the default value for the batch_norm parameter. The updated code should handle the addition of the ReLU layer conditionally based on the batch_norm parameter.""" 61 | }, 62 | { 63 | "record": [ 64 | { 65 | "type": "current", 66 | "current": """```ruby\n1 # frozen_string_literal: true\n2 module Extensions::DeferredWorkflowStatePersistence::Workflow; end\n3 module Extensions::DeferredWorkflowStatePersistence::Workflow::Adapter; end\n4 module Extensions::DeferredWorkflowStatePersistence::Workflow::Adapter::DeferredActiveRecord\n5 extend ActiveSupport::Concern\n6 included do\n7 include Workflow::Adapter::ActiveRecord\n8 include InstanceMethods\n9 end\n10 \n11 module InstanceMethods\n12 def persist_workflow_state(new_value)\n13 write_attribute(self.class.workflow_column, new_value)\n14 true\n15 end\n16 end\n17 end\n18 \n```""" 67 | }, 68 | { 69 | "type": "git", 70 | "git": "Include WorkflowActiverecord in the state persistence extension." 71 | } 72 | ], 73 | "change": """```diff\n@@ -1,10 +1,12 @@\n # frozen_string_literal: true\n+require 'workflow_activerecord'\n+\n module Extensions::DeferredWorkflowStatePersistence::Workflow; end\n module Extensions::DeferredWorkflowStatePersistence::Workflow::Adapter; end\n module Extensions::DeferredWorkflowStatePersistence::Workflow::Adapter::DeferredActiveRecord\n extend ActiveSupport::Concern\n included do\n- include Workflow::Adapter::ActiveRecord\n+ include WorkflowActiverecord::Adapter::ActiveRecord\n include InstanceMethods\n end\n \n```""" 74 | }, 75 | { 76 | "instruction": """At the beginning of the file, add the statement `require 'workflow_activerecord'`; On line 7, change `include Workflow::Adapter::ActiveRecord` to `include WorkflowActiverecord::Adapter::ActiveRecord`; Ensure the final code reflects the necessary changes for including `WorkflowActiverecord` in the state persistence extension.""" 77 | }, 78 | { 79 | "record": [ 80 | { 81 | "type": "history", 82 | "history": """```diff\n@@ -1,4 +1,5 @@\n import java.util.*;\n+import java.io.*;\n \n class Main\n {\n@@ -15,14 +16,18 @@\n }\n }\n \n- int n;\n- Scanner sc = new Scanner(System.in);\n- n = sc.nextInt();\n+ int ans = 0;\n+ BufferedReader bf = new BufferedReader(new InputStreamReader(System.in));\n \n- int ans = 0;\n- for(Map.Entry e : map.entrySet()){\n- Integer v = map.get(n-e.getKey());\n- ans += (v==null ? 0 : v) * e.getValue();\n+ try{\n+ int n = Integer.parseInt(bf.readLine());\n+ for(Map.Entry e : map.entrySet()){\n+ Integer v = map.get(n-e.getKey());\n+ ans += (v==null ? 0 : v) * e.getValue();\n+ }\n+ }\n+ catch(IOException ex){\n+ System.out.println(ex);\n }\n \n System.out.println(ans);\n```""" 83 | }, 84 | { 85 | "type": "current", 86 | "current": """```java\n1 import java.util.*;\n2 import java.io.*;\n3 \n4 class Main\n5 {\n6 public static void main(String[] args)\n7 {\n8 HashMap map = new HashMap();\n9 \n10 for(int i=1; i<10; ++i){\n11 for(int j=1; j<10; ++j){\n12 if(!map.containsKey(i+j))\n13 map.put(i+j, 1);\n14 else\n15 map.put(i+j, map.get(i+j)+1);\n16 }\n17 }\n18 \n19 int ans = 0;\n20 BufferedReader bf = new BufferedReader(new InputStreamReader(System.in));\n21 \n22 try{\n23 int n = Integer.parseInt(bf.readLine());\n24 for(Map.Entry e : map.entrySet()){\n25 Integer v = map.get(n-e.getKey());\n26 ans += (v==null ? 
0 : v) * e.getValue();\n27 }\n28 }\n29 catch(IOException ex){\n30 System.out.println(ex);\n31 }\n32 \n33 System.out.println(ans);\n34 }\n35 }\n```""" 87 | }, 88 | { 89 | "type": "problem", 90 | "problem": """Problem Description:\nWrite a program which reads an integer n and identifies the number of combinations of a, b, c and d (0 ≤ a, b, c, d ≤ 9) which meet the following equality:a + b + c + d = n\nFor example, for n = 35, we have 4 different combinations of (a, b, c, d): (8, 9, 9, 9), (9, 8, 9, 9), (9, 9, 8, 9), and (9, 9, 9, 8).\n\nInput Description:\nThe input consists of several datasets. Each dataset consists of n (1 ≤ n ≤ 50) in a line. The number of datasets is less than or equal to 50.\n\nOutput Description:\nPrint the number of combination in a line.\n\nSample Input 1:\n35\n1\n\nSample Output 1:\n4\n4""" 91 | } 92 | ], 93 | "change": """```diff\n@@ -7,8 +7,8 @@\n {\n HashMap map = new HashMap();\n \n- for(int i=1; i<10; ++i){\n- for(int j=1; j<10; ++j){\n+ for(int i=0; i<10; ++i){\n+ for(int j=0; j<10; ++j){\n if(!map.containsKey(i+j))\n map.put(i+j, 1);\n else\n@@ -16,20 +16,22 @@\n }\n }\n \n- int ans = 0;\n BufferedReader bf = new BufferedReader(new InputStreamReader(System.in));\n+ String str;\n \n try{\n- int n = Integer.parseInt(bf.readLine());\n- for(Map.Entry e : map.entrySet()){\n- Integer v = map.get(n-e.getKey());\n- ans += (v==null ? 0 : v) * e.getValue();\n+ while((str = bf.readLine()) != null){\n+ int ans = 0;\n+ int n = Integer.parseInt(str);\n+ for(Map.Entry e : map.entrySet()){\n+ Integer v = map.get(n-e.getKey());\n+ ans += (v==null ? 0 : v) * e.getValue();\n+ }\n+ System.out.println(ans);\n }\n }\n catch(IOException ex){\n System.out.println(ex);\n }\n-\n- System.out.println(ans);\n }\n }\n```""" 94 | }, 95 | { 96 | "instruction": """Update the code to handle multiple datasets and calculate the number of combinations of four digits (0 ≤ a, b, c, d ≤ 9) that sum up to a given number n for each dataset. Initialize the map to store the counts of all possible sums of pairs of digits (a+b). Modify the code to read multiple inputs until EOF and for each input, calculate the number of valid combinations where the sum of two pairs equals n.""" 97 | }, 98 | ] 99 | -------------------------------------------------------------------------------- /gen/template/genjudge_template.py: -------------------------------------------------------------------------------- 1 | GENJUDGE_SYSTEM = """You are tasked with assisting a programmer by maintaining a record of the programming process, including potential future changes. Your role is to discern which changes the programmer desires you to propose proactively. These should align with their actual intentions and be helpful. To determine which changes align with a programmer's intentions, consider the following principles: 2 | 3 | 1. **Understand the Context**: Assess the overall goal of the programming project. Ensure that any proposed change aligns with the project's objectives and the programmer's current focus. 4 | 5 | 2. **Maintain Clear Communication**: Before proposing changes, ensure that your suggestions are clear and concise. This helps the programmer quickly understand the potential impact of each change. 6 | 7 | 3. **Prioritize Stability**: Avoid proposing changes that could introduce instability or significant complexity unless there is a clear benefit. Stability is often more valued than optimization in the early stages of development. 8 | 9 | 4. 
**Respect the Programmer's Preferences**: Pay attention to the programmer's coding style and preferences. Propose changes that enhance their style rather than contradict it. 10 | 11 | 5. **Incremental Improvements**: Suggest changes that offer incremental improvements rather than drastic overhauls, unless specifically requested. This approach is less disruptive and easier for the programmer to integrate. 12 | 13 | 6. **Consider Long-Term Maintenance**: Propose changes that improve code maintainability and readability. This includes refactoring for clarity, reducing redundancy, and enhancing documentation. 14 | 15 | 7. **Balance Proactivity and Reactivity**: Be proactive in suggesting improvements that are likely to be universally beneficial (e.g., bug fixes, performance enhancements). However, be reactive, not proactive, in areas where the programmer's specific intentions are unclear or where personal preference plays a significant role. 16 | 17 | For each potential change, return `True` if suggesting this change would be beneficial to the programmer, return `False` if the change does not align with the programmer's intentions or if they do not want you to predict this change. Give your decision after analyzing each change. Provide your response in the following format: 18 | 19 | ``` 20 | **Analysis of change 1:** 21 | 22 | Your analysis here. 23 | 24 | **Decision:** `True` or `False` 25 | 26 | **Analysis of change 2:** 27 | 28 | Your analysis here. 29 | 30 | **Decision:** `True` or `False` 31 | 32 | ... 33 | ```""" 34 | 35 | GENJUDGE_RECORD_TYPE = { 36 | "current": """Current code: 37 | {current} 38 | """, 39 | "history": """Revised code changes: 40 | {history} 41 | """ 42 | } 43 | 44 | GENJUDGE_PROMPT_INPUT = """{record} 45 | {change}""" 46 | 47 | GENJUDGE_PROMPT_INPUT_RECORD = """**record:** 48 | {record}""" 49 | 50 | GENJUDGE_PROMPT_INPUT_CHANGE = """**change {num}:** 51 | {change}""" 52 | 53 | GENJUDGE_PROMPT_OUTPUT = """{judgement}""" 54 | 55 | GENJUDGE_FEWSHOT = [ 56 | { 57 | "record": [ 58 | { 59 | "type": "history", 60 | "history": """```diff\n@@ -1 +1,5 @@\n+/**\n+ * magical invsqrt function from Quake III code\n+ * see: http://www.codemaestro.com/reviews/9\n+ */\n \n```""" 61 | }, 62 | { 63 | "type": "current", 64 | "current": """```c\n1 /**\n2 * magical invsqrt function from Quake III code\n3 * see: http://www.codemaestro.com/reviews/9\n4 */\n5 \n```""" 65 | } 66 | ], 67 | "change": [ 68 | """```diff\n@@ -3,3 +3,19 @@\n * see: http://www.codemaestro.com/reviews/9\n */\n \n+float InvSqrt(float x)\n+{\n+\tfloat xhalf = 0.5f*x;\n+\tint i = *(int*)&x;\n+\ti = 0x5f3759df - (i>>1);\n+\tx = *(float*)&i;\n+\tx = x*(1.5f-xhalf*x*x);\n+\treturn x;\n+}\n+\n+int main(void) {\n+\tint result = InvSqrt(0.00056);\n+\tprintf("Result: %d (should be 42)", result);\n+\treturn result != 42;\n+}\n+\n```""" 69 | ] 70 | }, 71 | { 72 | "judgement": """**Analysis of change 1:**\n\nThe change introduces the `InvSqrt` function implementation and a basic test case within the `main` function. This change is appropriate and beneficial. The primary goal is to implement and test the `InvSqrt` function, a well-known algorithm from the Quake III code. Adding the function and a basic test case aligns directly with this goal. The implementation is clear and concise, following common C coding practices. The function's logic is straightforward and well-documented, making it easy to understand. 
Overall, the proposed change is a logical next step in developing the `InvSqrt` function and ensuring it works correctly.\n\n**Decision:** `True`""" 73 | }, 74 | { 75 | "record": [ 76 | { 77 | "type": "history", 78 | "history": """```diff\n@@ -1 +1,21 @@\n \n+package com.google.gwtjsonrpc.client;\n+\n+public class VoidResult_JsonSerializer extends JsonSerializer {\n+ public static final VoidResult_JsonSerializer INSTANCE =\n+ new VoidResult_JsonSerializer();\n+\n+ private VoidResult_JsonSerializer() {\n+ }\n+\n+ @Override\n+ public void printJson(final StringBuffer sb, final VoidResult o) {\n+ sb.append("{}");\n+ }\n+\n+ @Override\n+ public VoidResult fromJson(final Object o) {\n+ return VoidResult.INSTANCE;\n+ }\n+}\n+\n```""" 79 | }, 80 | { 81 | "type": "current", 82 | "current": """```java\n1 \n2 \n3 package com.google.gwtjsonrpc.client;\n4 \n5 public class VoidResult_JsonSerializer extends JsonSerializer {\n6 public static final VoidResult_JsonSerializer INSTANCE =\n7 new VoidResult_JsonSerializer();\n8 \n9 private VoidResult_JsonSerializer() {\n10 }\n11 \n12 @Override\n13 public void printJson(final StringBuffer sb, final VoidResult o) {\n14 sb.append("{}");\n15 }\n16 \n17 @Override\n18 public VoidResult fromJson(final Object o) {\n19 return VoidResult.INSTANCE;\n20 }\n21 }\n22 \n```""" 83 | } 84 | ], 85 | "change": [ 86 | """```diff@@ -1,3 +1,16 @@\n+// Copyright 2008 Google Inc.\n+//\n+// Licensed under the Apache License, Version 2.0 (the "License");\n+// you may not use this file except in compliance with the License.\n+// You may obtain a copy of the License at\n+//\n+// http://www.apache.org/licenses/LICENSE-2.0\n+//\n+// Unless required by applicable law or agreed to in writing, software\n+// distributed under the License is distributed on an "AS IS" BASIS,\n+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n+// See the License for the specific language governing permissions and\n+// limitations under the License.\n \n package com.google.gwtjsonrpc.client;\n \n```""" 87 | ] 88 | }, 89 | { 90 | "judgement": """**Analysis of change 1:**\n\nWithout enough information to confirm the use of the Apache License, predicting this change may not align with the programmer's intentions or the project's actual licensing requirements.\n\n**Decision:** `False`""" 91 | } 92 | ] 93 | -------------------------------------------------------------------------------- /gen/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def extract_code_blocks(markdown): 4 | """ 5 | Extracts code blocks from a given markdown string. 6 | 7 | Args: 8 | markdown (str): The markdown string to extract code blocks from. 9 | 10 | Returns: 11 | list: A list of code blocks extracted from the markdown string. 12 | """ 13 | # Define a regular expression to match code blocks. The (?s) allows '.' to match all characters including newlines. 
14 | code_block_pattern = re.compile(r'```(.*?)\n(.*?)```', re.DOTALL) 15 | 16 | # Store the extracted code blocks 17 | code_blocks = [] 18 | 19 | # Use finditer to iterate over all matches of code blocks 20 | for block in code_block_pattern.finditer(markdown): 21 | # Get the content of the code block 22 | code_content = block.group(2) 23 | 24 | # Remove any additional indents (when the code block is in lists or other indented structures) 25 | lines = code_content.split('\n') 26 | code_blocks.append('\n'.join(line[len(lines[-1]):] for line in lines[:-1])) 27 | 28 | return code_blocks 29 | -------------------------------------------------------------------------------- /generic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/generic/__init__.py -------------------------------------------------------------------------------- /generic/special_tokens.py: -------------------------------------------------------------------------------- 1 | IM_START = "<|im_start|>" 2 | IM_END = "<|im_end|>" 3 | NEXT_START = "<|next_start|>" 4 | NEXT_END = "<|next_end|>" 5 | TARGET_START = "<|target_start|>" 6 | TARGET_END = "<|target_end|>" 7 | TARGET = "<|target|>" 8 | SEARCH_AND_REPLACE = "<|search_and_replace|>" 9 | SPECIAL_WORDS=[IM_START, IM_END, NEXT_START, NEXT_END, TARGET_START, TARGET_END, TARGET, SEARCH_AND_REPLACE] 10 | -------------------------------------------------------------------------------- /generic/utils.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | from .special_tokens import * 3 | 4 | def data_args(parser): 5 | """ 6 | Adds data arguments to the given argument parser. 7 | 8 | Args: 9 | parser (argparse.ArgumentParser): The argument parser to add the arguments to. 10 | 11 | Returns: 12 | argparse.ArgumentParser: The updated argument parser. 13 | """ 14 | parser.add_argument("--model_map", type=str, help="Model name, base and port") 15 | parser.add_argument("--input_path", type=str, help="Input Path") 16 | parser.add_argument("--output_path", type=str, help="Output Path") 17 | parser.add_argument("--num_proc", type=int, default=512, help="Number of processes") 18 | return parser 19 | 20 | def openai_args(parser): 21 | """ 22 | Adds OpenAI arguments to the given argument parser. 23 | 24 | Args: 25 | parser (argparse.ArgumentParser): The argument parser to add the arguments to. 26 | 27 | Returns: 28 | argparse.ArgumentParser: The updated argument parser. 29 | """ 30 | parser.add_argument("--max_tokens", type=int, default=3072, help="Max tokens") 31 | parser.add_argument("--temperature", type=float, default=0.2, help="Temperature") 32 | parser.add_argument("--top_p", type=float, default=0.95, help="Temperature") 33 | parser.add_argument("--frequency_penalty", type=float, default=0, help="Temperature") 34 | parser.add_argument("--presence_penalty", type=float, default=0, help="Temperature") 35 | return parser 36 | 37 | def get_openai_kwargs(args): 38 | """ 39 | Get OpenAI keyword arguments from the given arguments. 40 | 41 | Args: 42 | args (argparse.Namespace): The parsed arguments. 43 | 44 | Returns: 45 | dict: The OpenAI keyword arguments. 
46 | """ 47 | openai_kwargs = { 48 | "max_tokens": args.max_tokens, 49 | "temperature": args.temperature, 50 | "top_p": args.top_p, 51 | "frequency_penalty": args.frequency_penalty, 52 | "presence_penalty": args.presence_penalty, 53 | } 54 | return openai_kwargs 55 | 56 | def decorate_code(code, lang="", use_line_num=False, start_line=None, end_line=None): 57 | """ 58 | Decorates the given code with markdown syntax. 59 | 60 | Args: 61 | code (str): The code to be decorated. 62 | lang (str, optional): The language identifier for syntax highlighting. Defaults to "". 63 | use_line_num (bool, optional): Whether to include line numbers. Defaults to False. 64 | start_line (int, optional): The starting line number. Defaults to None. 65 | end_line (int, optional): The ending line number. Defaults to None. 66 | 67 | Returns: 68 | str: The decorated code. 69 | 70 | """ 71 | decorate = "```" 72 | if lang: 73 | decorate += lang.lower() 74 | decorate += "\n" 75 | code_lines = code.split("\n") 76 | if start_line is None: 77 | start_line = 0 78 | if end_line is None: 79 | end_line = len(code_lines) 80 | if start_line != 0: 81 | decorate += "...\n" 82 | if use_line_num: 83 | decorate += "\n".join([f"{i + 1 + start_line} {line}" for i, line in enumerate(code_lines[start_line: end_line])]) 84 | else: 85 | decorate += "\n".join(code_lines[start_line: end_line]) 86 | if end_line != len(code_lines): 87 | decorate += "\n..." 88 | decorate += "\n```" 89 | return decorate 90 | 91 | def generate_locations_changes(code1, code2, lang, changes_lines): 92 | """ 93 | Generates a string representing the changes between two versions of code with specified line ranges. 94 | 95 | Args: 96 | code1 (str): The original version of the code. 97 | code2 (str): The modified version of the code. 98 | lang (str): The programming language of the code (used for syntax highlighting). 99 | changes_lines (list of tuples): A list of tuples where each tuple contains two pairs of integers. 100 | Each tuple represents the line ranges in the format ((old_start, old_end), (new_start, new_end)). 101 | 102 | Returns: 103 | str: A formatted string that shows the changes between the two code versions with syntax highlighting. 104 | """ 105 | code1_lines = code1.split('\n') 106 | code2_lines = code2.split('\n') 107 | locations_changes = [] 108 | for c in changes_lines: 109 | (old_start, old_end), (new_start, new_end) = c 110 | next_code = '\n'.join(code2_lines[new_start: new_end]) 111 | locations_changes.append(f"{old_start},{old_end}\n```{lang}\n{next_code}\n```") 112 | return "\n".join(locations_changes) 113 | 114 | def generate_search_and_replace(code1, code2, lang, changes_lines, sep_token=SEARCH_AND_REPLACE): 115 | """ 116 | Generates a search and replace string for code changes between two versions of code. 117 | 118 | Args: 119 | code1 (str): The original version of the code. 120 | code2 (str): The modified version of the code. 121 | lang (str): The programming language of the code (used for syntax highlighting). 122 | changes_lines (list of tuples): A list of tuples where each tuple contains two pairs of integers. 123 | Each pair represents the start and end line numbers of the changes in the original and modified code respectively. 124 | sep_token (str, optional): The token used to separate the search and replace sections in the output. Defaults to SEARCH_AND_REPLACE. 
125 | 126 | Returns: 127 | str: A formatted string containing the search and replace sections for the specified changes, 128 | with syntax highlighting for the specified programming language. 129 | """ 130 | code1_lines = code1.split('\n') 131 | code2_lines = code2.split('\n') 132 | search_and_replace = [] 133 | for i in range(len(changes_lines)): 134 | (old_start, old_end), (new_start, new_end) = changes_lines[i] 135 | search = '\n'.join(code1_lines[old_start: old_end]) 136 | replace = '\n'.join(code2_lines[new_start: new_end]) 137 | search_and_replace.append(f"```{lang}\n{search}\n{sep_token}\n{replace}\n```") 138 | return "\n".join(search_and_replace) 139 | 140 | def extract_changes_lines(code1, code2, fromfile='', tofile='', unique=False, merge_changes=False, merge_threshold=2): 141 | """ 142 | Extracts the lines that have changed between two versions of code. 143 | Args: 144 | code1 (str): The original version of the code. 145 | code2 (str): The modified version of the code. 146 | fromfile (str, optional): The name of the original file. Defaults to ''. 147 | tofile (str, optional): The name of the modified file. Defaults to ''. 148 | unique (bool, optional): If True, ensures that the changes are unique. Defaults to False. 149 | merge_changes (bool, optional): If True, merges neighboring changes. Defaults to False. 150 | merge_threshold (int, optional): The threshold for merging neighboring changes. Defaults to 2. 151 | Returns: 152 | list: A list of tuples, where each tuple contains two tuples representing the start and end lines of changes 153 | in the original and modified code respectively. 154 | """ 155 | # Split the code into lines 156 | lines1 = code1.split('\n') 157 | lines2 = code2.split('\n') 158 | 159 | # Create a unified diff 160 | diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile, n=0, lineterm=''))[2:] 161 | 162 | # Process the diff to extract changes and their contexts from both files 163 | changes_lines = [] 164 | current_old_start_line = None 165 | current_new_start_line = None 166 | current_old_end_line = 0 167 | current_new_end_line = 0 168 | 169 | for line in diff: 170 | if line.startswith('@@'): 171 | # Start of a new change block, flush the previous block if it exists 172 | if current_old_start_line or current_new_start_line: 173 | if current_old_end_line > current_old_start_line: 174 | current_old_start_line -= 1 175 | current_old_end_line -= 1 176 | if current_new_end_line > current_new_start_line: 177 | current_new_start_line -= 1 178 | current_new_end_line -= 1 179 | changes_lines.append(((current_old_start_line, current_old_end_line), 180 | (current_new_start_line, current_new_end_line))) 181 | 182 | # Extract line numbers from the hunk information 183 | parts = line.split() 184 | old_info, new_info = parts[1], parts[2] 185 | current_old_start_line = int(old_info.split(',')[0][1:]) 186 | current_new_start_line = int(new_info.split(',')[0][1:]) 187 | current_old_end_line = current_old_start_line 188 | current_new_end_line = current_new_start_line 189 | elif line.startswith('-'): 190 | # Line removed from the old file 191 | current_old_end_line += 1 192 | elif line.startswith('+'): 193 | # Line added to the new file 194 | current_new_end_line += 1 195 | elif line.startswith(' '): 196 | assert False, "No lines should be unchanged in a unified diff with context lines number 0" 197 | 198 | # Append the last block if it exists 199 | if current_old_start_line or current_new_start_line: 200 | if current_old_end_line > 
current_old_start_line: 201 | current_old_start_line -= 1 202 | current_old_end_line -= 1 203 | if current_new_end_line > current_new_start_line: 204 | current_new_start_line -= 1 205 | current_new_end_line -= 1 206 | changes_lines.append(((current_old_start_line, current_old_end_line), 207 | (current_new_start_line, current_new_end_line))) 208 | changes_lines = filter_changes(changes_lines, lines1, lines2) 209 | if unique: 210 | changes_lines = unique_changes(changes_lines, lines1) 211 | if merge_changes: 212 | changes_lines = merge_neighbor_change(changes_lines, merge_threshold) 213 | return changes_lines 214 | 215 | def filter_changes(changes_lines, lines1, lines2): 216 | """ 217 | Filters out changes that are only due to differences in indentation or whitespace. 218 | 219 | Args: 220 | changes_lines (list of tuples): A list of tuples where each tuple contains two tuples. 221 | The first tuple represents the start and end indices of the change in `lines1`. 222 | The second tuple represents the start and end indices of the change in `lines2`. 223 | lines1 (list of str): The original list of lines. 224 | lines2 (list of str): The modified list of lines. 225 | 226 | Returns: 227 | list of tuples: A list of tuples containing the changes that are not purely due to differences in indentation or whitespace. 228 | """ 229 | filtered_changes_lines = [] 230 | if len(changes_lines) == 0: 231 | return filtered_changes_lines 232 | for old_change, new_change in changes_lines: 233 | old_start, old_end = old_change 234 | new_start, new_end = new_change 235 | indent1_max = max(len(line) - len(line.lstrip()) for line in lines1) 236 | indent2_max = max(len(line) - len(line.lstrip()) for line in lines2) 237 | before = " ".join(" ".join(lines1[old_start: old_end]).split()) 238 | after = " ".join(" ".join(lines2[new_start: new_end]).split()) 239 | if before != after or indent1_max != indent2_max: 240 | filtered_changes_lines.append((old_change, new_change)) 241 | return filtered_changes_lines 242 | 243 | def unique_changes(changes_lines, lines1): 244 | """ 245 | Identifies unique changes in a list of changes and expands them based on the context of the original lines. 246 | 247 | Args: 248 | changes_lines (list of tuples): A list of tuples where each tuple contains two sub-tuples. 249 | The first sub-tuple represents the old change with start and end indices, 250 | and the second sub-tuple represents the new change with start and end indices. 251 | lines1 (list): A list of lines from the original content. 252 | 253 | Returns: 254 | list of tuples: A list of tuples where each tuple contains two sub-tuples. 255 | The first sub-tuple represents the expanded old change with updated start and end indices, 256 | and the second sub-tuple represents the expanded new change with updated start and end indices. 257 | 258 | Raises: 259 | AssertionError: If the left or right expansion values are negative, indicating an invalid expansion. 
260 | """ 261 | unique_changes_lines = [] 262 | if len(changes_lines) == 0: 263 | return unique_changes_lines 264 | for old_change, new_change in changes_lines: 265 | left_expansion, right_expansion = find_unique_sublist(lines1, old_change[0], old_change[1]) 266 | assert left_expansion >= 0 and right_expansion >= 0, "Invalid expansion" 267 | unique_changes_lines.append(((old_change[0] - left_expansion, old_change[1] + right_expansion), (new_change[0] - left_expansion, new_change[1] + right_expansion))) 268 | return unique_changes_lines 269 | 270 | # TODO: accelerate 271 | def find_unique_sublist(b, a1, a2): 272 | """ 273 | Finds the smallest extension of the sublist `b[a1:a2]` that is unique within the list `b`. 274 | 275 | Args: 276 | b (list): The list in which to find the unique sublist. 277 | a1 (int): The starting index of the sublist. 278 | a2 (int): The ending index of the sublist. 279 | 280 | Returns: 281 | tuple: A tuple (x, y) where `x` is the minimum number of elements to extend the sublist 282 | at the beginning, and `y` is the minimum number of elements to extend the sublist 283 | at the end to make it unique within `b`. If no unique sublist is found, returns (-1, -1). 284 | """ 285 | if not "\n".join(b).strip(): 286 | return 0, 0 287 | 288 | def is_unique_sublist(b, sublist): 289 | """ 290 | Check if a sublist appears exactly once in a list. 291 | 292 | Args: 293 | b (list): The main list in which to search for the sublist. 294 | sublist (list): The sublist to search for within the main list. 295 | 296 | Returns: 297 | bool: True if the sublist appears exactly once in the main list, False otherwise. 298 | 299 | Raises: 300 | AssertionError: If the sublist appears more than once in the main list. 301 | """ 302 | if not "\n".join(sublist).strip(): 303 | return False 304 | count = 0 305 | for i in range(len(b) - len(sublist) + 1): 306 | if b[i:i + len(sublist)] == sublist: 307 | count += 1 308 | if count > 1: 309 | return False 310 | assert count == 1, "Invalid count" 311 | return count == 1 312 | 313 | for extend_length in range(len(b) - a2 + a1 + 1): 314 | for window in range(extend_length + 1): 315 | current_sublist = b[a1 - extend_length + window:a2 + window] 316 | if is_unique_sublist(b, current_sublist): 317 | return min(a1, extend_length - window), min(len(b) - a2, window) 318 | return -1, -1 319 | 320 | def merge_neighbor_change(changes_lines, merge_threshold=2): 321 | """ 322 | Merges neighboring changes in a list of change line ranges based on a specified threshold. 323 | 324 | Args: 325 | changes_lines (list of tuples): A list of tuples where each tuple contains two tuples representing 326 | the old and new change line ranges respectively. 327 | Example: [((old_start, old_end), (new_start, new_end)), ...] 328 | merge_threshold (int, optional): The maximum allowed distance between neighboring changes to be merged. 329 | Defaults to 2. 330 | 331 | Returns: 332 | list of tuples: A list of merged change line ranges. 
333 | """ 334 | merged_changes_lines = [] 335 | if len(changes_lines) == 0: 336 | return merged_changes_lines 337 | for i, (old_change, new_change) in enumerate(changes_lines): 338 | if i == 0: 339 | merged_changes_lines.append((old_change, new_change)) 340 | else: 341 | last_old_change, last_new_change = merged_changes_lines[-1] 342 | if old_change[0] - last_old_change[1] <= merge_threshold and new_change[0] - last_new_change[1] <= merge_threshold: 343 | merged_changes_lines[-1] = (min(last_old_change[0], old_change[0]), max(last_old_change[1], old_change[1])), (min(last_new_change[0], new_change[0]), max(last_new_change[1], new_change[1])) 344 | else: 345 | merged_changes_lines.append((old_change, new_change)) 346 | return merged_changes_lines 347 | -------------------------------------------------------------------------------- /model_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "TechxGenus/CursorCore-Yi-1.5B": { 3 | "base": "http://127.0.0.1:10086/v1", 4 | "api": "sk-xxx" 5 | } 6 | } -------------------------------------------------------------------------------- /pictures/APEval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/pictures/APEval.png -------------------------------------------------------------------------------- /pictures/CursorWeb.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/pictures/CursorWeb.gif -------------------------------------------------------------------------------- /pictures/EvalPlus_CanItEdit_OctoPack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/pictures/EvalPlus_CanItEdit_OctoPack.png -------------------------------------------------------------------------------- /pictures/conversation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/pictures/conversation.png -------------------------------------------------------------------------------- /pictures/cursorcore.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/pictures/cursorcore.png -------------------------------------------------------------------------------- /pictures/discord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/pictures/discord.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # data 2 | tqdm 3 | pandas 4 | datasets 5 | rouge_score 6 | 7 | # gen 8 | sglang[all] 9 | openai 10 | 11 | # src 12 | Levenshtein 13 | 14 | # train 15 | numba 16 | numpy 17 | ninja 18 | torch 19 | packaging 20 | transformers 21 | liger-kernel 22 | tensorboard 23 | accelerate 24 | deepspeed<=0.14.5 25 | 26 | # eval 27 | vllm 28 | evaluate 29 | evalplus 30 | 
-------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | # Programming-Instruct 2 | 3 | This folder is the code for our data collection pipeline. 4 | 5 | ## Inference service for LLMs 6 | 7 | We generate data by prompting different LLMs. Currently, the OpenAI interface has been standardized, and almost all Inference frameworks support OpenAI compatible servers. Therefore, we uniformly use the OpenAI interface to generate. 8 | 9 | Example script to deploy `deepseek-coder-6.7b-instruct` using `sglang`: 10 | 11 | ```bash 12 | python -m sglang.launch_server --model-path deepseek-ai/deepseek-coder-6.7b-instruct --port 10086 13 | ``` 14 | 15 | Example script to deploy `deepseek-coder-6.7b-instruct` using `vllm`: 16 | 17 | ```bash 18 | python -m vllm.entrypoints.openai.api_server --port 10086 --model deepseek-ai/deepseek-coder-6.7b-instruct --enable-prefix-caching 19 | ``` 20 | 21 | We define the model inference service parameters in `model_map.json`. An example configuration is as follows: 22 | 23 | ```json 24 | { 25 | "deepseek-ai/deepseek-coder-6.7b-instruct": { 26 | "base": "http://127.0.0.1:10086/v1", 27 | "api": "sk-xxx" 28 | } 29 | } 30 | ``` 31 | 32 | ## AI programmer 33 | 34 | For each code snippet, we use LLMs to generate the corresponding coding history. Its input file is a list of code snippets. Examples are as follows: 35 | 36 | ```json 37 | [ 38 | "int i...", 39 | "import json...", 40 | "func main...", 41 | ... 42 | ] 43 | ``` 44 | 45 | The command to run the code is: 46 | 47 | ```bash 48 | python src/aiprogrammer.py --model_map model_map.json --input_path data/code_snippets.json --output_path data/aiprogrammer.json 49 | ``` 50 | 51 | ## Data collection 52 | 53 | Run the following scripts to generate data from various data sources: 54 | 55 | ```bash 56 | # AIprogrammer 57 | python src/data_collection.py --model_map model_map.json --data_type aiprogrammer --input_path data/aiprogrammer.json --output_path data/aiprogrammer_end.json 58 | 59 | # Git Commit 60 | python src/data_collection.py --model_map model_map.json --data_type commit --input_path data/commit.json --output_path data/commit_end.json --limit_one_block_prob 1.0 61 | 62 | # Online Submit 63 | python src/data_collection.py --model_map model_map.json --data_type submit --input_path data/submit_process.json --output_path data/submit_end.json 64 | ``` 65 | 66 | ## Synthetic target area 67 | 68 | Programmers often specify the parts requiring changes, typically in one of two ways: either by clicking with the cursor to indicate a general area or by selecting a specific text range with defined start and end points. 
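
For illustration, this target area ends up in the `loc` field that `src/post_collection.py` attaches to each sample: a single character offset for a cursor position, a `[start, end]` pair of character offsets for a selected range, or `null` when no location is attached. A minimal sketch of just that field (values are hypothetical, all other sample fields omitted):

```json
[
  { "loc": 42 },
  { "loc": [10, 57] },
  { "loc": null }
]
```
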
69 | 70 | We synthesize the target modification area with a random algorithm: 71 | 72 | ```bash 73 | python src/post_collection.py --input_path data/tmp.json --output_path data/tmp_area.json 74 | ``` 75 | -------------------------------------------------------------------------------- /src/aiprogrammer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | current_path = os.path.abspath(__file__) 4 | parent_directory = os.path.dirname(os.path.dirname(current_path)) 5 | sys.path.append(parent_directory) 6 | 7 | import json 8 | import argparse 9 | from gen import AIProgrammer 10 | from generic.utils import data_args, openai_args, get_openai_kwargs, decorate_code 11 | 12 | parser = argparse.ArgumentParser() 13 | parser = data_args(parser) 14 | parser = openai_args(parser) 15 | args = parser.parse_args() 16 | openai_kwargs = get_openai_kwargs(args) 17 | 18 | with open(args.model_map, 'r') as f: 19 | model_map = json.load(f) 20 | 21 | with open(args.input_path, 'r') as f: 22 | inputs = json.load(f) 23 | 24 | inputs = [{'code': decorate_code(i), "content": i} for i in inputs] 25 | 26 | Gen = AIProgrammer( 27 | model_map=model_map, num_proc=args.num_proc, **openai_kwargs 28 | ) 29 | 30 | results = Gen.gen(inputs) 31 | results = [result for result in results if result["output"]] 32 | 33 | with open(args.output_path, 'w') as f: 34 | json.dump(results, f, indent=4) 35 | -------------------------------------------------------------------------------- /src/data_collection.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | current_path = os.path.abspath(__file__) 4 | parent_directory = os.path.dirname(os.path.dirname(current_path)) 5 | sys.path.append(parent_directory) 6 | 7 | import json 8 | import random 9 | import argparse 10 | import concurrent.futures 11 | from tqdm import tqdm 12 | from gen import GenJudgement, GenInstruction, GenChat 13 | from generic.utils import data_args, openai_args, get_openai_kwargs, decorate_code 14 | from utils import generate_diff, random_edit_series, generate_diff_blocks, apply_selected_blocks 15 | 16 | def generate_record_and_change(history, current, final): 17 | """ 18 | Generates a record of changes between historical code segments, the current code segment, 19 | and the final code segment. 20 | 21 | Args: 22 | history (list): A list of dictionaries representing historical code segments. Each dictionary 23 | should have a 'type' key with the value 'code' and a 'code' key with the code 24 | segment as its value. 25 | current (dict): A dictionary representing the current code segment. It should have a 'type' key 26 | with the value 'code', a 'code' key with the code segment, and a 'lang' key 27 | indicating the programming language. 28 | final (str): The final code segment as a string. 29 | 30 | Returns: 31 | tuple: A tuple containing: 32 | - blocks (list): A list of blocks representing the differences between the current code 33 | and the final code. 34 | - current_record (list): A list of dictionaries representing the historical changes and 35 | the current code segment. 36 | - future_change (list): A list of decorated code segments representing the future changes 37 | needed to reach the final code. 
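Example (illustrative, hypothetical values): with history == [{"type": "code", "lang": "python", "code": "a = 1"}], current == {"type": "code", "lang": "python", "code": "a = 1\nb = 2"} and final == "a = 1\nb = 3", current_record would hold one "history" entry (the diff between the two recorded states) followed by one "current" entry (the line-numbered current code), and future_change would hold one decorated diff per changed block between current and final.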
38 | """ 39 | current_record = [] 40 | future_change = [] 41 | 42 | history_current = history + [current] 43 | for previous, now in zip(history_current[:-1], history_current[1:]): 44 | if previous['type'] == 'code' and now['type'] == 'code': 45 | diff = generate_diff(previous['code'], now['code']) 46 | current_record.append({"type": "history", "history": decorate_code(diff, "diff")}) 47 | else: 48 | raise ValueError("History should only contain code segments now.") 49 | 50 | current_record.append({"type": "current", "current": decorate_code(current['code'], current['lang'], use_line_num=True)}) 51 | blocks = [] 52 | if current['code'] != final: 53 | blocks = generate_diff_blocks(current['code'], final) 54 | for i in range(len(blocks)): 55 | code = apply_selected_blocks(current['code'], blocks, [i]) 56 | future_change.append(decorate_code(generate_diff(current['code'], code), "diff")) 57 | return blocks, current_record, future_change 58 | 59 | parser = argparse.ArgumentParser() 60 | parser = data_args(parser) 61 | parser = openai_args(parser) 62 | parser.add_argument("--alpha", type=float, default=1.0, help="Sampling parameter for history") 63 | parser.add_argument("--data_type", type=str, help="Data Type (aiprogrammer, commit, submit)") 64 | parser.add_argument("--max_per_sample", type=int, default=1, help="Maximum number of samples per input") 65 | parser.add_argument("--use_history_prob", type=float, default=0.6, help="Probability of using history") 66 | parser.add_argument("--use_user_prob", type=float, default=0.5, help="Probability of using user") 67 | parser.add_argument("--limit_one_block_prob", type=float, default=0.5, help="Probability of limiting one block during History") 68 | args = parser.parse_args() 69 | openai_kwargs = get_openai_kwargs(args) 70 | 71 | with open(args.model_map, 'r') as f: 72 | model_map = json.load(f) 73 | 74 | with open(args.input_path, 'r') as f: 75 | inputs = json.load(f) 76 | 77 | random.shuffle(inputs) 78 | 79 | GenJudge = GenJudgement( 80 | model_map=model_map, num_proc=args.num_proc, **openai_kwargs 81 | ) 82 | 83 | GenInst = GenInstruction( 84 | model_map=model_map, num_proc=args.num_proc, **openai_kwargs 85 | ) 86 | 87 | GenCht = GenChat( 88 | model_map=model_map, num_proc=args.num_proc, **openai_kwargs 89 | ) 90 | 91 | def process(sample): 92 | """ 93 | Processes a given sample based on the specified data type and generates a series of code transformations. 94 | 95 | Args: 96 | sample (dict): A dictionary containing the sample data. The structure of the sample depends on the data type: 97 | - For 'aiprogrammer': Must contain 'output'. 98 | - For 'commit': Must contain 'code1' and 'code2'. 99 | - For 'submit': Must contain 'submissions', 'language', and 'problems_contexts'. 100 | 101 | Returns: 102 | list: A list of dictionaries, each representing a step in the code transformation process. Each dictionary contains: 103 | - 'history': The history of code transformations up to the current step. 104 | - 'current': The current code block being processed. 105 | - 'user': The user input or instruction for the current step. 106 | - 'chat': The chat or commentary generated for the current step. 107 | - 'next': The next code block after applying the transformation. 108 | 109 | Raises: 110 | ValueError: If the data type specified in args.data_type is not supported. 
111 | """ 112 | if args.data_type == 'aiprogrammer': 113 | sample['lang'] = "" 114 | code_series = [] 115 | for code1, code2 in zip(sample['output'][:-1], sample['output'][1:]): 116 | if random.random() < args.limit_one_block_prob: 117 | code_series += random_edit_series(code1, code2)[:-1] 118 | else: 119 | code_series += [code1] 120 | final = sample['output'][-1] 121 | elif args.data_type == 'commit': 122 | if random.random() < args.limit_one_block_prob: 123 | code_series = random_edit_series(sample['code1'], sample['code2'])[:-1] 124 | else: 125 | code_series = [sample['code1']] 126 | final = sample['code2'] 127 | elif args.data_type == 'submit': 128 | sample['lang'] = sample['language'].lower() 129 | code_series = [] 130 | for code1, code2 in zip(sample['submissions'][:-1], sample['submissions'][1:]): 131 | if random.random() < args.limit_one_block_prob: 132 | code_series += random_edit_series(code1, code2)[:-1] 133 | else: 134 | code_series += [code1] 135 | final = sample['submissions'][-1] 136 | else: 137 | raise ValueError("Data Type {} is not supported.".format(args.data_type)) 138 | histories = [{"type": "code", 'lang': sample['lang'], "code": code} for code in code_series] 139 | candidates = list(range(len(histories))) 140 | result = [] 141 | candidate_num = 0 142 | while len(candidates) > 0 and candidate_num < args.max_per_sample: 143 | weights = [args.alpha ** c + 1e-6 for c in candidates] 144 | candidate = random.choices(candidates, weights=weights, k=1)[0] 145 | candidates = [c for c in candidates if c != candidate] 146 | current = histories[candidate] 147 | use_history = random.random() < args.use_history_prob 148 | use_user = random.random() < args.use_user_prob 149 | if candidate == 0 and current['code'].strip() == "" and not use_user: 150 | continue 151 | if use_history: 152 | history = histories[:candidate] 153 | if len(history) == 0: 154 | continue 155 | else: 156 | history = [] 157 | blocks, current_record, future_change = generate_record_and_change(history, current, final) 158 | if len(future_change) == 0: 159 | continue 160 | if use_user: 161 | if args.data_type == 'aiprogrammer': 162 | current_record_with_auxiliary = current_record 163 | elif args.data_type == 'commit': 164 | current_record_with_auxiliary = current_record + [{"type": "git", "git": sample['git']}] 165 | elif args.data_type == 'submit': 166 | current_record_with_auxiliary = current_record + [{"type": "problem", "problem": sample['problems_contexts']}] 167 | if candidate == 0 and args.data_type == 'commit': 168 | user = sample['git'] 169 | elif candidate == 0 and args.data_type == 'submit': 170 | user = "Please generate a correct {} program for the following problem:\n\n{}".format(sample['lang'], sample['problems_contexts']) 171 | else: 172 | inst = GenInst.gen([{"record": current_record_with_auxiliary, "change": decorate_code(generate_diff(current['code'], final), "diff")}]) 173 | if inst[0]["output"] is None: 174 | continue 175 | user = inst[0]["output"] 176 | next_ = {"type": "code", "lang": current['lang'], "code": final} 177 | current_record += [{"type": "user", "user": user}] 178 | else: 179 | judge = GenJudge.gen([{"record": current_record, "change": future_change}]) 180 | if judge[0]["output"] is None: 181 | continue 182 | selected_indices = [index for index, value in enumerate(judge[0]["output"]) if value] 183 | user = "" 184 | next_ = {"type": "code", "lang": current['lang'], "code": apply_selected_blocks(current['code'], blocks, selected_indices)} 185 | next_change = 
decorate_code(generate_diff(current['code'], next_["code"]), "diff") 186 | chat = GenCht.gen([{"record": current_record, "change": next_change}]) 187 | if chat[0]["output"] is None: 188 | chat = "" 189 | else: 190 | chat = chat[0]["output"] 191 | result.append({"history": history, "current": current, "user": user, "chat": chat, "next": next_}) 192 | candidate_num += 1 193 | return result 194 | 195 | with concurrent.futures.ThreadPoolExecutor(args.num_proc) as executor: 196 | futures = [executor.submit(process, item) for item in inputs] 197 | results = [] 198 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 199 | results.append(future.result()) 200 | 201 | results = [r for result in results for r in result] 202 | with open(args.output_path, 'w') as f: 203 | json.dump(list(results), f, indent=4) 204 | -------------------------------------------------------------------------------- /src/merge_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | def merge_json_files(output_file, *input_files): 5 | """ 6 | Merge multiple JSON files into a single file. 7 | 8 | Args: 9 | output_file (str): The path to the output file where the merged data will be written. 10 | *input_files (str): Variable number of input file paths to be merged. 11 | 12 | Raises: 13 | FileNotFoundError: If any of the input files are not found. 14 | JSONDecodeError: If any of the input files cannot be parsed as JSON. 15 | 16 | Returns: 17 | None 18 | """ 19 | merged_data = [] 20 | for file in input_files: 21 | try: 22 | with open(file, 'r') as f: 23 | data = json.load(f) 24 | if isinstance(data, list): 25 | merged_data.extend(data) 26 | else: 27 | print(f"Warning: {file} does not contain a list. 
Skipping.") 28 | except FileNotFoundError as e: 29 | print(f"Error reading {file}: {e}") 30 | except json.JSONDecodeError as e: 31 | print(f"Error parsing {file} as JSON: {e}") 32 | try: 33 | with open(output_file, 'w') as f: 34 | json.dump(merged_data, f, indent=4) 35 | print(f"Merged data written to {output_file}") 36 | except Exception as e: 37 | print(f"Error writing to {output_file}: {e}") 38 | 39 | parser = argparse.ArgumentParser(description='Merge multiple JSON files into one.') 40 | parser.add_argument('output', type=str, help='The output file to write the merged data to.') 41 | parser.add_argument('inputs', nargs='+', type=str, help='The input JSON files to merge.') 42 | 43 | args = parser.parse_args() 44 | merge_json_files(args.output, *args.inputs) 45 | -------------------------------------------------------------------------------- /src/post_collection.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import argparse 4 | import Levenshtein 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--input_path", type=str, help="Input Path") 8 | parser.add_argument("--output_path", type=str, help="Output Path") 9 | parser.add_argument("--use_loc_prob", type=float, default=0.75, help="Probability of using LOC") 10 | parser.add_argument("--loc_cursor_prob", type=float, default=0.5, help="Probability of using cursor, using range otherwise") 11 | parser.add_argument("--loc_noisy_prob", type=float, default=0.5, help="Probability of using noisy range") 12 | parser.add_argument("--loc_noisy_length", type=int, default=32, help="Length of noisy range") 13 | parser.add_argument("--loc_all_noisy_prob", type=float, default=0.05, help="Probability of using noisy range for all") 14 | args = parser.parse_args() 15 | 16 | def get_changes(s1, s2): 17 | """ 18 | Compute the changes needed to transform string s1 into string s2 using Levenshtein distance. 19 | 20 | Args: 21 | s1 (str): The original string. 22 | s2 (str): The target string. 23 | 24 | Returns: 25 | list of tuple: A list of tuples where each tuple contains the start and end indices of the changes in s1. 26 | Each tuple represents a range of indices in s1 that need to be changed to match s2. 27 | """ 28 | edit_ops = Levenshtein.opcodes(s1, s2) 29 | changes = [] 30 | for op, i1, i2, _, _ in edit_ops: 31 | if op != 'equal': 32 | if changes and changes[-1][1] == i1: 33 | changes[-1] = (changes[-1][0], i2) 34 | else: 35 | changes.append((i1, i2)) 36 | return changes 37 | 38 | def sample_from_boundaries(left, right, max_length, noisy_length=0): 39 | """ 40 | Samples a start and end point within specified boundaries, with optional noise. 41 | 42 | Args: 43 | left (int): The left boundary for sampling the start point. 44 | right (int): The right boundary for sampling the end point. 45 | max_length (int): The maximum allowable length for the end point. 46 | noisy_length (int, optional): The amount of noise to add to the boundaries. Defaults to 0. 47 | 48 | Returns: 49 | tuple: A tuple containing the sampled start and end points (start, end). 50 | """ 51 | start = random.choice(list(range(max(0, left - noisy_length), left + 1))) 52 | end = random.choice(list(range(right, min(right + noisy_length, max_length) + 1))) 53 | return start, end 54 | 55 | def sample_from_ranges(ranges): 56 | """ 57 | Selects a random number from a set of ranges. 
58 | 59 | Args: 60 | ranges (list of tuple): A list of tuples where each tuple contains two integers (start, end) 61 | representing the inclusive range from start to end. 62 | 63 | Returns: 64 | int: A randomly selected number from the combined ranges. 65 | 66 | Raises: 67 | IndexError: If the ranges list is empty or no numbers can be generated from the given ranges. 68 | """ 69 | numbers = set() 70 | for start, end in ranges: 71 | numbers.update(range(start, end + 1)) 72 | return random.choice(list(numbers)) 73 | 74 | def post_current(current, next_): 75 | """ 76 | Determines the cursor position or range based on the changes between the current and next states. 77 | 78 | Args: 79 | current (str): The current state. 80 | next_ (str): The next state. 81 | 82 | Returns: 83 | Union[int, Tuple[int, int], None]: 84 | - An integer representing a single cursor position. 85 | - A tuple of two integers representing a range (left, right). 86 | - None if no cursor position or range is determined. 87 | """ 88 | if random.random() < args.use_loc_prob: 89 | changes = get_changes(current, next_) 90 | if not changes: 91 | changes = [(0, len(current))] 92 | if random.random() < args.loc_cursor_prob: 93 | if random.random() < args.loc_all_noisy_prob: 94 | cursor = sample_from_ranges([(0, len(current))]) 95 | else: 96 | if random.random() < args.loc_noisy_prob: 97 | cursor = sample_from_ranges([random.choice(changes)]) 98 | else: 99 | cursor = random.choice(random.choice(changes)) 100 | return cursor 101 | else: 102 | if random.random() < args.loc_all_noisy_prob: 103 | left, right = sample_from_ranges([(0, len(current))]), sample_from_ranges([(0, len(current))]) 104 | if left > right: 105 | left, right = right, left 106 | else: 107 | if random.random() < args.loc_noisy_prob: 108 | noisy_length = args.loc_noisy_length 109 | else: 110 | noisy_length = 0 111 | left, right = sample_from_boundaries(changes[0][0], changes[-1][1], len(current), noisy_length) 112 | return left, right 113 | else: 114 | return None 115 | 116 | with open(args.input_path, "r") as f: 117 | data = json.load(f) 118 | 119 | for sample in data: 120 | loc = post_current(sample["current"]["code"], sample["next"]["code"]) 121 | # TODO: Unify variable names 122 | sample["loc"] = loc 123 | 124 | with open(args.output_path, "w") as f: 125 | json.dump(data, f, indent=4) 126 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import difflib 3 | import Levenshtein 4 | 5 | def generate_diff(code1, code2, fromfile='', tofile='', n=3): 6 | """ 7 | Generate a unified diff between two code strings. 8 | 9 | Args: 10 | code1 (str): The first code string. 11 | code2 (str): The second code string. 12 | fromfile (str, optional): The name of the first file. Defaults to ''. 13 | tofile (str, optional): The name of the second file. Defaults to ''. 14 | n (int, optional): The number of context lines in the unified diff. Defaults to 3. 15 | Returns: 16 | str: The unified diff as a string. 
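Example (illustrative): generate_diff("a\nb", "a\nc") returns "@@ -1,2 +1,2 @@\n a\n-b\n+c"; the "---"/"+++" header lines are stripped when fromfile and tofile are left empty.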
17 | """ 18 | # Split the code strings into lines 19 | lines1 = code1.split('\n') 20 | lines2 = code2.split('\n') 21 | 22 | # Generate unified diff 23 | diff = difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile, n=n, lineterm='') 24 | output = "" 25 | 26 | for line in diff: 27 | output += line + "\n" 28 | 29 | if fromfile == '' and tofile == '': 30 | output = "\n".join(output.split("\n")[2:]) 31 | 32 | return output[:-1] 33 | 34 | def if_continuous_modify(code1, code2, code3): 35 | """ 36 | Check if code3 is a continuous modification of code1 and code2. 37 | 38 | Args: 39 | code1 (str): The first code string. 40 | code2 (str): The second code string. 41 | code3 (str): The third code string. 42 | 43 | Returns: 44 | bool: True if code3 is a continuous modification of code1 and code2, False otherwise. 45 | """ 46 | # Calculate the Levenshtein distance between code1 and code2 47 | dist1 = Levenshtein.distance(code1, code2) 48 | # Calculate the Levenshtein distance between code2 and code3 49 | dist2 = Levenshtein.distance(code2, code3) 50 | # Calculate the Levenshtein distance between code1 and code3 51 | dist3 = Levenshtein.distance(code1, code3) 52 | 53 | # Check if code3 is a continuous modification of code1 and code2 54 | if dist3 == dist1 + dist2: 55 | return True 56 | else: 57 | return False 58 | 59 | def blockwise_if_continuous_modify(code1, code2, code3): 60 | """ 61 | Check if code3 is a continuous modification of code1 and code2. 62 | 63 | Args: 64 | code1 (str): The first code string. 65 | code2 (str): The second code string. 66 | code3 (str): The third code string. 67 | 68 | Returns: 69 | bool: True if code3 is a continuous modification of code1 and code2, False otherwise. 70 | """ 71 | # Calculate the Levenshtein distance between code1 and code2 72 | dist1 = Levenshtein.distance(code1, code2) 73 | # Calculate the Levenshtein distance between code2 and code3 74 | dist2 = Levenshtein.distance(code2, code3) 75 | # Calculate the Levenshtein distance between code1 and code3 76 | dist3 = Levenshtein.distance(code1, code3) 77 | 78 | past_diff_blocks = generate_diff_blocks(code1, code2) 79 | new_diff_blocks = generate_diff_blocks(code1, code3) 80 | 81 | # Check if code3 is a continuous modification of code1 and code2 82 | if dist3 == dist1 + dist2 and len(past_diff_blocks) == len(new_diff_blocks): 83 | return True 84 | else: 85 | return False 86 | 87 | def generate_diff_blocks(original, modified): 88 | """ 89 | Generate diff blocks between two strings. 90 | 91 | Args: 92 | original (str): The original string. 93 | modified (str): The modified string. 94 | 95 | Returns: 96 | list: A list of tuples, where each tuple contains a block of modified lines and the line number in the original string where the block starts. 
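Example (illustrative): generate_diff_blocks("a\nb\nc", "x\nb\ny") returns [(['- a', '+ x'], 0), (['- c', '+ y'], 2)], i.e. two independent blocks starting at lines 0 and 2 of the original.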
97 | """ 98 | # Use difflib's ndiff to find differences 99 | differ = difflib.Differ() 100 | diff = list(differ.compare(original.split('\n'), modified.split('\n'))) 101 | 102 | # store all modified blocks 103 | blocks = [] 104 | current_block = [] 105 | 106 | # track the current line number 107 | orig_line_no = 0 108 | block_line_len = 0 109 | 110 | # Traverse the diff results into chunks 111 | for line in diff: 112 | if line.startswith(' '): 113 | # If the current block has content and an unmodified line is encountered, save the current block and reset 114 | if current_block: 115 | blocks.append((current_block, orig_line_no - block_line_len)) 116 | current_block = [] 117 | block_line_len = 0 118 | orig_line_no += 1 119 | elif line.startswith('- '): 120 | current_block.append(line) 121 | orig_line_no += 1 122 | block_line_len += 1 123 | else: 124 | current_block.append(line) 125 | 126 | # Make sure the last chunk is added 127 | if current_block: 128 | blocks.append((current_block, orig_line_no - block_line_len)) 129 | 130 | return blocks 131 | 132 | def apply_selected_blocks(original, blocks, selected_indices): 133 | """ 134 | Apply selected blocks to the original code. 135 | 136 | Args: 137 | original (str): The original code as a string. 138 | blocks (list): A list of code blocks. 139 | selected_indices (list): A list of indices representing the selected blocks. 140 | 141 | Returns: 142 | str: The modified code after applying the selected blocks. 143 | """ 144 | # Split the original code by lines 145 | original_lines = original.split('\n') 146 | 147 | # Adjust code based on selected block 148 | offset = 0 149 | for index in selected_indices: 150 | block, start_line = blocks[index] 151 | # Iterate through each row in the block 152 | delete_offset = 0 153 | for line in block: 154 | if line.startswith('- '): 155 | del original_lines[start_line + offset] 156 | delete_offset -= 1 157 | elif line.startswith('+ '): 158 | original_lines.insert(start_line + offset, line[2:]) 159 | offset += 1 160 | offset += delete_offset 161 | delete_offset = 0 162 | 163 | return '\n'.join(original_lines) 164 | 165 | # TODO: accelerate the process 166 | def random_edit_series(code1, code2): 167 | """ 168 | Generates a series of randomly edited code versions between two given code snippets. 169 | 170 | Args: 171 | code1 (str): The original code snippet. 172 | code2 (str): The modified code snippet. 173 | 174 | Returns: 175 | list: A list of code snippets representing the series of randomly edited versions 176 | between the original and modified code snippets. 177 | """ 178 | blocks = generate_diff_blocks(code1, code2) 179 | random_block_order = list(range(len(blocks))) 180 | # random.seed(42) 181 | random.shuffle(random_block_order) 182 | code_series = [] 183 | for i in range(len(blocks) + 1): 184 | code_series.append(apply_selected_blocks(code1, blocks, sorted(random_block_order[:i]))) 185 | return code_series 186 | 187 | def sequential_edit_series(code1, code2): 188 | """ 189 | Generate a series of code snippets by sequentially applying selected diff blocks. 190 | 191 | Args: 192 | code1 (str): The original code. 193 | code2 (str): The modified code. 194 | 195 | Returns: 196 | list: A list of code snippets, each representing the result of applying a selected diff block. 
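Example (illustrative): sequential_edit_series("a\nb\nc", "x\nb\ny") returns ["x\nb\nc", "a\nb\ny"], the original code with each diff block applied on its own.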
197 | """ 198 | blocks = generate_diff_blocks(code1, code2) 199 | code_series = [] 200 | for i in range(len(blocks)): 201 | code_series.append(apply_selected_blocks(code1, blocks, [i])) 202 | return code_series 203 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechxGenus/CursorCore/b1db04c746318f8d742bd29e0850e99e322498e7/tests/__init__.py -------------------------------------------------------------------------------- /train/README.md: -------------------------------------------------------------------------------- 1 | # Training 2 | 3 | This directory contains the programs used to train the CursorCore series. Other training frameworks that support custom chat templates should also work. 4 | 5 | ## Prepare conversation 6 | 7 | Run the following program to inject the conversation template into the model: 8 | 9 | ```bash 10 | python train/prepare_conversation.py --model_name_or_path deepseek-ai/deepseek-coder-1.3b-base --save_path train/formatted-deepseek-coder-1.3b-base 11 | ``` 12 | 13 | ## Prepare data 14 | 15 | Run the following program to format training data: 16 | 17 | ```bash 18 | # WF Format 19 | python train/prepare_data.py --input_path data/data.json --output_path data/train_data.json 20 | 21 | # LC Format 22 | python train/prepare_data.py --input_path data/data.json --output_path data/train_data_lc.json --format_type lc 23 | 24 | # SR Format 25 | python train/prepare_data.py --input_path data/data.json --output_path data/train_data_sr.json --format_type sr 26 | ``` 27 | 28 | ## Training script 29 | 30 | Our script can be run with common distributed launchers such as deepspeed and accelerate. 
Here is an example script for training models using deepspeed: 31 | 32 | ```bash 33 | MODEL=$1 34 | DATA=$2 35 | OUTPUT=$3 36 | mkdir -p $OUTPUT 37 | 38 | deepspeed --master_port='10086' train/training.py \ 39 | --model_name_or_path $MODEL \ 40 | --data_path $DATA \ 41 | --model_max_length 16384 \ 42 | --batch_max_length 50000 \ 43 | --learning_rate 5e-5 \ 44 | --weight_decay 0 \ 45 | --num_train_epochs 2 \ 46 | --gradient_accumulation_steps 1 \ 47 | --gradient_checkpointing \ 48 | --lr_scheduler_type cosine \ 49 | --warmup_steps 15 \ 50 | --seed 10086 \ 51 | --output_dir $OUTPUT \ 52 | --bf16 True \ 53 | --evaluation_strategy "no" \ 54 | --save_strategy "epoch" \ 55 | --load_best_model_at_end False \ 56 | --save_total_limit 1000 \ 57 | --logging_steps 20 \ 58 | --tf32 True \ 59 | --optim adafactor \ 60 | --use_liger_kernel True \ 61 | --deepspeed train/ds_config.json 62 | ``` 63 | -------------------------------------------------------------------------------- /train/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "overlap_comm": true, 5 | "contiguous_gradients": true, 6 | "sub_group_size": 0, 7 | "reduce_bucket_size": "auto", 8 | "stage3_prefetch_bucket_size": "auto", 9 | "stage3_param_persistence_threshold": "auto", 10 | "stage3_max_live_parameters": 0, 11 | "stage3_max_reuse_distance": 0, 12 | "stage3_gather_16bit_weights_on_model_save": true 13 | }, 14 | "bf16": { 15 | "enabled": "auto" 16 | }, 17 | "gradient_accumulation_steps": "auto", 18 | "gradient_clipping": "auto", 19 | "train_batch_size": "auto", 20 | "train_micro_batch_size_per_gpu": "auto", 21 | "wall_clock_breakdown": false 22 | } 23 | -------------------------------------------------------------------------------- /train/prepare_conversation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | current_path = os.path.abspath(__file__) 4 | parent_directory = os.path.dirname(os.path.dirname(current_path)) 5 | sys.path.append(parent_directory) 6 | 7 | import torch 8 | import argparse 9 | import transformers 10 | from typing import Dict 11 | from transformers import AutoModelForCausalLM, AutoTokenizer 12 | from generic.special_tokens import * 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--model_name_or_path", type=str, help="The base model checkpoint to inject the conversation template into.") 16 | parser.add_argument("--save_path", type=str, help="The path to save the formatted model.") 17 | args = parser.parse_args() 18 | 19 | def smart_tokenizer_and_embedding_resize( 20 | special_tokens_dict: Dict, 21 | tokenizer: transformers.PreTrainedTokenizer, 22 | model: transformers.PreTrainedModel, 23 | ): 24 | """Resize tokenizer and embedding.""" 25 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 26 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) 27 | 28 | if num_new_tokens > 0: 29 | input_embeddings = model.get_input_embeddings().weight.data 30 | output_embeddings = model.get_output_embeddings().weight.data 31 | 32 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 33 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 34 | 35 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 36 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 37 | 38 | tokenizer = AutoTokenizer.from_pretrained( 39 | args.model_name_or_path, 40 | 
trust_remote_code=True, 41 | ) 42 | 43 | model = AutoModelForCausalLM.from_pretrained( 44 | args.model_name_or_path, 45 | trust_remote_code=True, 46 | torch_dtype=torch.bfloat16, 47 | ) 48 | 49 | tokenizer.pad_token = tokenizer.eos_token 50 | tokenizer.add_special_tokens({"eos_token": IM_END}) 51 | smart_tokenizer_and_embedding_resize( 52 | special_tokens_dict=dict(additional_special_tokens=SPECIAL_WORDS), 53 | tokenizer=tokenizer, 54 | model=model, 55 | ) 56 | 57 | tokenizer.chat_template = [ 58 | { 59 | "name": "default", 60 | "template": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful programming assistant.<|im_end|>\n' }}{% endif %}{% if not loop.first %}{{ '\n' }}{% endif %}{{ '<|im_start|>' + message['role'] }}{% if 'name' in message %}{{ ' name=' + message['name'] }}{% endif %}{{ '\n' + message['content'] + '<|im_end|>' }}{% if loop.last and add_generation_prompt %}{{ '\n<|im_start|>assistant\n' }}{% endif %}{% endfor %}" 61 | }, 62 | { 63 | "name": "assistant-conversation", 64 | "template": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful programming assistant.<|im_end|>\n' }}{% endif %}{% if not loop.first %}{{ '\n' }}{% endif %}{{ '<|im_start|>' + message['role'] }}{% if 'name' in message %}{{ ' name=' + message['name'] }}{% endif %}{{ '\n' + message['content'] + '<|im_end|>' }}{% if loop.last and add_generation_prompt %}{{ '\n<|im_start|>assistant\n<|next_start|>' }}{% endif %}{% endfor %}" 65 | }, 66 | { 67 | "name": "prefix_response", 68 | "template": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful programming assistant.<|im_end|>\n' }}{% endif %}{% if not loop.first %}{{ '\n' }}{% endif %}{{ '<|im_start|>' + message['role'] }}{% if 'name' in message %}{{ ' name=' + message['name'] }}{% endif %}{{ '\n' + message['content'] }}{% if not loop.last %}{{ '<|im_end|>' }}{% endif %}{% endfor %}" 69 | }, 70 | ] 71 | 72 | model.config.eos_token_id = tokenizer.eos_token_id 73 | model.generation_config.eos_token_id = tokenizer.eos_token_id 74 | 75 | tokenizer.save_pretrained(args.save_path) 76 | model.save_pretrained(args.save_path) 77 | -------------------------------------------------------------------------------- /train/prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | current_path = os.path.abspath(__file__) 4 | parent_directory = os.path.dirname(os.path.dirname(current_path)) 5 | sys.path.append(parent_directory) 6 | 7 | import json 8 | import argparse 9 | from generic.special_tokens import * 10 | from generic.utils import decorate_code, extract_changes_lines, generate_locations_changes, generate_search_and_replace 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--input_path", type=str, help="Input Path") 14 | parser.add_argument("--output_path", type=str, help="Output Path") 15 | parser.add_argument("--format_type", type=str, default="wf", help="Format Type for data") 16 | args = parser.parse_args() 17 | 18 | def format_data(data): 19 | """ 20 | Formats the given data based on the specified format type in args. 21 | 22 | Args: 23 | data (list): A list of dictionaries containing the data to be formatted. Each dictionary represents a sample and contains the following keys: 24 | - "history" (list): A list of dictionaries representing the history of messages. 
Each dictionary contains: 25 | - "code" (str): The code snippet. 26 | - "lang" (str): The programming language of the code snippet. 27 | - "current" (dict): A dictionary representing the current message with keys: 28 | - "code" (str): The current code snippet. 29 | - "lang" (str): The programming language of the current code snippet. 30 | - "loc" (int or tuple): The location(s) in the current code snippet where the target(s) should be inserted. 31 | - "user" (str): The user's message. 32 | - "next" (dict): A dictionary representing the next message with keys: 33 | - "code" (str): The next code snippet. 34 | - "lang" (str): The programming language of the next code snippet. 35 | - "chat" (str): Additional chat message from the assistant. 36 | 37 | Returns: 38 | list: A list of formatted conversations. Each conversation is a list of dictionaries with the following keys: 39 | - "role" (str): The role of the message ("history", "current", "user", "assistant"). 40 | - "content" (str): The content of the message, which may include decorated code or location changes. 41 | 42 | Raises: 43 | ValueError: If the format type specified in args is invalid. 44 | """ 45 | formatted_data = [] 46 | if args.format_type == "wf": 47 | for sample in data: 48 | conversation = [] 49 | if sample["history"]: 50 | for message in sample["history"]: 51 | conversation.append({"role": "history", "content": decorate_code(message["code"], message["lang"])}) 52 | current_loc = sample["current"]["code"] 53 | if sample["loc"]: 54 | if type(sample["loc"]) == int: 55 | current_loc = current_loc[:sample["loc"]] + TARGET + current_loc[sample["loc"]:] 56 | else: 57 | start, end = sample["loc"] 58 | current_loc = current_loc[:end] + TARGET_END + current_loc[end:] 59 | current_loc = current_loc[:start] + TARGET_START + current_loc[start:] 60 | conversation.append({"role": "current", "content": decorate_code(current_loc, sample["current"]["lang"])}) 61 | if sample["user"]: 62 | conversation.append({"role": "user", "content": sample["user"]}) 63 | assistant = "" 64 | assistant += NEXT_START + decorate_code(sample["next"]["code"], sample["next"]["lang"]) + NEXT_END 65 | if sample["chat"]: 66 | assistant += "\n" + sample["chat"] 67 | conversation.append({"role": "assistant", "content": assistant}) 68 | formatted_data.append(conversation) 69 | elif args.format_type == "lc": 70 | for sample in data: 71 | conversation = [] 72 | if sample["history"]: 73 | history_current = sample["history"] + [sample["current"]] 74 | for m1, m2 in zip(history_current[:-1], history_current[1:]): 75 | # In H, we record the modified position of the subsequent code snippet and the modified content of the previous code snippet 76 | # in A, we record the modified position of the previous code snippet and the modified content of the subsequent code snippet 77 | changes_lines = extract_changes_lines(m2["code"], m1["code"]) 78 | locations_changes = generate_locations_changes(m2["code"], m1["code"], m1["lang"], changes_lines) 79 | conversation.append({"role": "history", "content": locations_changes}) 80 | current_loc = sample["current"]["code"] 81 | if sample["loc"]: 82 | if type(sample["loc"]) == int: 83 | current_loc = current_loc[:sample["loc"]] + TARGET + current_loc[sample["loc"]:] 84 | else: 85 | start, end = sample["loc"] 86 | current_loc = current_loc[:end] + TARGET_END + current_loc[end:] 87 | current_loc = current_loc[:start] + TARGET_START + current_loc[start:] 88 | # we add line numbers for assistant to understand the location 89 | 
conversation.append({"role": "current", "content": decorate_code(current_loc, sample["current"]["lang"], use_line_num=True)}) 90 | if sample["user"]: 91 | conversation.append({"role": "user", "content": sample["user"]}) 92 | assistant = "" 93 | changes_lines = extract_changes_lines(sample["current"]["code"], sample["next"]["code"]) 94 | locations_changes = generate_locations_changes(sample["current"]["code"], sample["next"]["code"], sample["next"]["lang"], changes_lines) 95 | assistant += NEXT_START + locations_changes + NEXT_END 96 | if sample["chat"]: 97 | assistant += "\n" + sample["chat"] 98 | conversation.append({"role": "assistant", "content": assistant}) 99 | formatted_data.append(conversation) 100 | elif args.format_type == "sr": 101 | for sample in data: 102 | conversation = [] 103 | if sample["history"]: 104 | history_current = sample["history"] + [sample["current"]] 105 | for m1, m2 in zip(history_current[:-1], history_current[1:]): 106 | # We ensure that the searched content matches exactly and there are no duplicate paragraphs 107 | changes_lines = extract_changes_lines(m2["code"], m1["code"], unique=True, merge_changes=True) 108 | changes_lines = [(new, old) for old, new in changes_lines] 109 | search_and_replace = generate_search_and_replace(m1["code"], m2["code"], m1["lang"], changes_lines) 110 | conversation.append({"role": "history", "content": search_and_replace}) 111 | current_loc = sample["current"]["code"] 112 | if sample["loc"]: 113 | if type(sample["loc"]) == int: 114 | current_loc = current_loc[:sample["loc"]] + TARGET + current_loc[sample["loc"]:] 115 | else: 116 | start, end = sample["loc"] 117 | current_loc = current_loc[:end] + TARGET_END + current_loc[end:] 118 | current_loc = current_loc[:start] + TARGET_START + current_loc[start:] 119 | conversation.append({"role": "current", "content": decorate_code(current_loc, sample["current"]["lang"])}) 120 | if sample["user"]: 121 | conversation.append({"role": "user", "content": sample["user"]}) 122 | assistant = "" 123 | changes_lines = extract_changes_lines(sample["current"]["code"], sample["next"]["code"], unique=True, merge_changes=True) 124 | search_and_replace = generate_search_and_replace(sample["current"]["code"], sample["next"]["code"], sample["next"]["lang"], changes_lines) 125 | assistant += NEXT_START + search_and_replace + NEXT_END 126 | if sample["chat"]: 127 | assistant += "\n" + sample["chat"] 128 | conversation.append({"role": "assistant", "content": assistant}) 129 | formatted_data.append(conversation) 130 | else: 131 | raise ValueError(f"Invalid format type: {args.format_type}") 132 | return formatted_data 133 | 134 | with open(args.input_path, "r") as f: 135 | input_data = json.load(f) 136 | 137 | output_data = format_data(input_data) 138 | 139 | with open(args.output_path, "w") as f: 140 | json.dump(output_data, f, indent=4) 141 | -------------------------------------------------------------------------------- /train/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 3 | import copy 4 | import json 5 | import numba 6 | import torch 7 | import numpy as np 8 | import transformers 9 | from dataclasses import dataclass, field 10 | from typing import Optional, Dict, List 11 | from concurrent.futures import ThreadPoolExecutor 12 | from torch.utils.data import Dataset, DataLoader, Sampler 13 | from transformers import Trainer, AutoModelForCausalLM, AutoTokenizer 14 | from transformers import 
DefaultDataCollator, default_data_collator 15 | from transformers.trainer_utils import seed_worker 16 | from transformers.trainer_pt_utils import LabelSmoother 17 | IGNORE_INDEX = LabelSmoother.ignore_index 18 | 19 | @dataclass 20 | class ModelArguments: 21 | model_name_or_path: Optional[str] = field(default="deepseek-ai/DeepSeek-Coder-V2-Lite-Base") 22 | 23 | @dataclass 24 | class DataArguments: 25 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 26 | num_proc: int = field(default=8, metadata={"help": "Number of processes to use for data preprocessing."}) 27 | 28 | @dataclass 29 | class TrainingArguments(transformers.TrainingArguments): 30 | model_max_length: int = field(default=10000, metadata={"help": "Maximum sequence length."}) 31 | batch_max_length: int = field(default=25000, metadata={"help": "Maximum batch length."}) 32 | 33 | def preprocess( 34 | list_data_dict: List, 35 | tokenizer: transformers.PreTrainedTokenizer, 36 | num_proc: int, 37 | ) -> Dict: 38 | """Preprocess the data by tokenizing.""" 39 | examples = [tokenizer.apply_chat_template(l, tokenize=False) for l in list_data_dict] 40 | sources = [tokenizer.apply_chat_template(l[:-1], tokenize=False, add_generation_prompt=True) for l in list_data_dict] 41 | 42 | """Tokenize a list of strings.""" 43 | def tokenize_text(text): 44 | return tokenizer( 45 | text, 46 | return_tensors="pt", 47 | padding="longest", 48 | max_length=tokenizer.model_max_length, 49 | truncation=True, 50 | ) 51 | 52 | with ThreadPoolExecutor(max_workers=num_proc) as executor: 53 | examples_tokenized = list(executor.map(tokenize_text, examples)) 54 | sources_tokenized = list(executor.map(tokenize_text, sources)) 55 | 56 | input_ids = [tokenized.input_ids[0].tolist() for tokenized in examples_tokenized] 57 | labels = copy.deepcopy(input_ids) 58 | source_lens = [len(tokenized.input_ids[0]) for tokenized in sources_tokenized] 59 | for label, source_len in zip(labels, source_lens): 60 | label[:source_len] = [IGNORE_INDEX] * source_len 61 | return dict(input_ids=input_ids, labels=labels) 62 | 63 | @numba.njit 64 | def ffd(lengths, batch_max_length): 65 | """ 66 | First-Fit Decreasing (FFD) algorithm for bin packing. 67 | 68 | This function sorts the input lengths in decreasing order and then attempts to pack them into bins 69 | such that the sum of lengths in each bin does not exceed the specified batch_max_length. It returns 70 | the indices of the original lengths in each bin. 71 | 72 | Args: 73 | lengths (list or array-like): A list or array of lengths to be packed into bins. 74 | batch_max_length (int): The maximum allowable length for each bin. 75 | 76 | Returns: 77 | list of lists: A list where each sublist contains the indices of the original lengths that have 78 | been packed into the corresponding bin. 
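Example (illustrative): ffd([4, 2, 7, 5], 8) returns [[2], [3, 1], [0]], since item 2 (length 7) fills one bin, items 3 and 1 (lengths 5 and 2) share a bin, and item 0 (length 4) takes its own bin.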
79 | """ 80 | lengths = np.array(lengths) 81 | indices = np.argsort(lengths)[::-1] 82 | lengths = lengths[indices] 83 | bins = [] 84 | bins_result = [] 85 | for lengths_id, size in enumerate(lengths): 86 | add_new = True 87 | for idx in range(len(bins)): 88 | if bins[idx] >= size: 89 | bins[idx] -= size 90 | bins_result[idx].append(indices[lengths_id]) 91 | add_new = False 92 | break 93 | if add_new: 94 | bins.append(batch_max_length - size) 95 | bins_result.append([indices[lengths_id]]) 96 | return bins_result 97 | 98 | class PackSampler(Sampler): 99 | def __init__(self, batch_max_length: int, lengths: List[int], seed: int = 0): 100 | batches = ffd(lengths, batch_max_length) 101 | indices = np.random.default_rng(seed=seed).permutation(len(batches)) 102 | self.batches = [batches[idx] for idx in indices] 103 | 104 | def __iter__(self): 105 | return iter(self.batches) 106 | 107 | def __len__(self): 108 | return len(self.batches) 109 | 110 | class SupervisedDataset(Dataset): 111 | """Dataset for supervised fine-tuning.""" 112 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, num_proc: int): 113 | super(SupervisedDataset, self).__init__() 114 | with open(data_path, "r") as json_file: 115 | list_data_dict = json.load(json_file) 116 | data_dict = preprocess(list_data_dict, tokenizer, num_proc=num_proc) 117 | self.input_ids = data_dict["input_ids"] 118 | self.labels = data_dict["labels"] 119 | 120 | def __len__(self): 121 | return len(self.input_ids) 122 | 123 | def __getitem__(self, index) -> Dict[str, List[int]]: 124 | return [dict(input_ids=self.input_ids[i], labels=self.labels[i]) for i in index] 125 | 126 | @dataclass 127 | class DataCollatorWithFlattening(DefaultDataCollator): 128 | """Collate examples for supervised fine-tuning.""" 129 | def __init__(self, *args, return_position_ids=True, **kwargs): 130 | super().__init__(*args, **kwargs) 131 | self.return_position_ids = return_position_ids 132 | 133 | def __call__(self, features, return_tensors=None): 134 | features = [item for feature in features for item in feature] 135 | if return_tensors is None: 136 | return_tensors = self.return_tensors 137 | is_labels_provided = "labels" in features[0] 138 | ret = {"input_ids": [], "labels": []} 139 | if self.return_position_ids: 140 | ret.update({"position_ids": []}) 141 | for idx in range(0, len(features)): 142 | ret["input_ids"] += features[idx]["input_ids"] 143 | if is_labels_provided: 144 | ret["labels"] += [IGNORE_INDEX] + features[idx]["labels"][1:] 145 | else: 146 | ret["labels"] += [IGNORE_INDEX] + features[idx]["input_ids"][1:] 147 | if self.return_position_ids: 148 | ret["position_ids"] += list(range(len(features[idx]["input_ids"]))) 149 | return default_data_collator([ret], return_tensors) 150 | 151 | class CustomTrainer(Trainer): 152 | def __init__(self, sampler, *args, **kwargs): 153 | super().__init__(*args, **kwargs) 154 | self.sampler = sampler 155 | 156 | def get_train_dataloader(self) -> DataLoader: 157 | """Returns the training [`~torch.utils.data.DataLoader`].""" 158 | dataloader_params = { 159 | "collate_fn": self.data_collator, 160 | "num_workers": self.args.dataloader_num_workers, 161 | "pin_memory": self.args.dataloader_pin_memory, 162 | "persistent_workers": self.args.dataloader_persistent_workers, 163 | "sampler": self.sampler, 164 | "drop_last": self.args.dataloader_drop_last, 165 | "worker_init_fn": seed_worker, 166 | "prefetch_factor": self.args.dataloader_prefetch_factor 167 | } 168 | return 
self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) 169 | 170 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: 171 | """Make dataset and collator for supervised fine-tuning.""" 172 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, num_proc=data_args.num_proc) 173 | data_collator = DataCollatorWithFlattening() 174 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) 175 | 176 | def train(): 177 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 178 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 179 | tokenizer = AutoTokenizer.from_pretrained( 180 | model_args.model_name_or_path, 181 | trust_remote_code=True, 182 | model_max_length=training_args.model_max_length, 183 | use_fast=True, 184 | ) 185 | model = AutoModelForCausalLM.from_pretrained( 186 | model_args.model_name_or_path, 187 | trust_remote_code=True, 188 | torch_dtype=torch.bfloat16, 189 | attn_implementation="flash_attention_2", 190 | ) 191 | model.gradient_checkpointing_enable() 192 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) 193 | lengths = [len(input_id) for input_id in data_module["train_dataset"].input_ids] 194 | sampler = PackSampler(batch_max_length=training_args.batch_max_length, lengths=lengths, seed=training_args.seed) 195 | trainer = CustomTrainer(sampler=sampler, model=model, tokenizer=tokenizer, args=training_args, **data_module) 196 | model.config.use_cache = False 197 | # model.model.forward = torch.compile(model.model.forward) 198 | trainer.train() 199 | trainer.save_model(training_args.output_dir) 200 | 201 | if __name__ == "__main__": 202 | train() 203 | --------------------------------------------------------------------------------