├── .env ├── .gitattribute ├── .gitignore ├── .pre-commit-config.yaml ├── DevGuide.md ├── LICENSE ├── README.md ├── data ├── code │ ├── bad_code_1.py │ ├── bad_code_2.py │ ├── code_with_slash.py │ ├── dummy │ │ ├── __init__.py │ │ ├── dummy_1.py │ │ └── dummy_2.py │ ├── env_code_1.py │ ├── env_code_2.py │ └── good_code_1.py ├── coeditor_link.pkl ├── ex_repo │ ├── env_code_1.py │ └── env_code_2.py ├── repos_split.pkl └── useful_repos.pkl ├── install-all.bash ├── notebooks ├── analyze_data.ipynb ├── code_completion_format.ipynb ├── code_completion_inspect.ipynb ├── download_data.ipynb ├── multi_round_inspect.ipynb ├── profile_analysis.ipynb ├── profile_model.ipynb ├── run_api.ipynb └── run_open_ai_model.ipynb ├── poetry.lock ├── pyproject.toml ├── requirements.txt ├── scripts ├── code_completion_eval.py ├── install-deps-linux.sh ├── multi_round_eval.py ├── prepare_data.py ├── single_round_eval.py ├── start_server.py └── train_model.py ├── src └── coeditor │ ├── __init__.py │ ├── _utils.py │ ├── c3problem.py │ ├── change.py │ ├── common.py │ ├── dataset.py │ ├── encoding.py │ ├── experiments │ ├── code_completion.py │ ├── in_coder.py │ ├── openai_gpt.py │ ├── santa_coder.py │ └── star_coder.py │ ├── git.py │ ├── model.py │ ├── scoped_changes.py │ ├── service.py │ └── tk_array.py └── tests ├── __init__.py ├── test_analysis.py ├── test_edits.py ├── test_scoped_change.py └── testcases ├── defs.py ├── example.py └── usages.py /.env: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrVPlusOne/Coeditor/4d645e1293f6d5a9a93dd893b0b608dc54153a6a/.env -------------------------------------------------------------------------------- /.gitattribute: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-vendored 2 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | notebooks/scratch.ipynb 2 | config 3 | *.egg-info 4 | code_output 5 | temp 6 | wandb 7 | checkpoints 8 | caches 9 | lightning_logs/ 10 | output/ 11 | coeditor_logs/ 12 | *.coeditor_logs 13 | 14 | build/ 15 | __pycache__ 16 | *.py[cod] 17 | *~ 18 | /build 19 | /env*/ 20 | docs/build/ 21 | docs/source/_build 22 | mypyc/doc/_build 23 | *.iml 24 | /out/ 25 | .venv 26 | venv/ 27 | mypy_temp 28 | .mypy_cache/ 29 | .incremental_checker_cache.json 30 | .cache 31 | dmypy.json 32 | .dmypy.json 33 | 34 | # Packages 35 | *.egg 36 | *.egg-info 37 | *.eggs 38 | 39 | # IDEs 40 | .idea 41 | .vscode 42 | 43 | # vim temporary files 44 | .*.sw? 45 | *.sw? 46 | 47 | # Operating Systems 48 | .DS_Store 49 | 50 | # Coverage Files 51 | htmlcov 52 | .coverage* 53 | 54 | # pytest cache 55 | .pytest_cache/ 56 | 57 | # virtualenv 58 | .Python 59 | bin/ 60 | lib/ 61 | include/ 62 | .python-version 63 | pyvenv.cfg 64 | 65 | .tox 66 | pip-wheel-metadata 67 | 68 | 69 | test_capi 70 | *.o 71 | *.a 72 | test_capi 73 | /.mypyc-flake8-cache.json 74 | /mypyc/lib-rt/build/ 75 | /mypyc/lib-rt/*.so 76 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: 'tests/testcases' 2 | 3 | default_language_version: 4 | python: python3.11 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v3.2.0 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: end-of-file-fixer 12 | - id: check-yaml 13 | - id: check-added-large-files 14 | 15 | - repo: https://github.com/pycqa/isort 16 | rev: 5.11.4 17 | hooks: 18 | - id: isort 19 | args: ["--profile", "black", "--filter-files"] 20 | 21 | - repo: https://github.com/psf/black 22 | rev: 22.12.0 23 | hooks: 24 | - id: black 25 | -------------------------------------------------------------------------------- 
/DevGuide.md: -------------------------------------------------------------------------------- 1 | ## Tooling 2 | - Formatter: We use `black` for formatting with the default options. 3 | - Type Checker: We use Pylance for type checking. It's the built-in type checker shipped with the VSCode Python plugin and can be enabled by setting `Python > Analysis > Type Checking Mode` to `basic`. 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Coeditor: Leveraging Repo-level Diffs for Code Auto-editing 2 | 3 | Coeditor is a transformer model that auto-edits your code based on your recent code changes. This repo includes the server code for the [Coeditor VSCode extension](https://marketplace.visualstudio.com/items?itemName=JiayiWei.vscode-coeditor) and the scripts for data processing, model training, and evaluation. The ideas behind Coeditor are presented in the ICLR Spotlight paper, [Coeditor: Leveraging Repo-level Diffs for Code Auto-editing](https://openreview.net/forum?id=ALVwQjZRS8), by Jiayi Wei, Greg Durrett, and Isil Dillig. 4 | 5 | Watch the [Coeditor demo](https://www.youtube.com/watch?v=hjZE__jslzs) on Youtube. 6 | 7 | ## Installation 8 | 9 | ### Method 1: with Poetry (recommended) 10 | 11 | This project uses [poetry](https://python-poetry.org) to manage the package dependencies. 
Poetry records all dependencies in the `pyproject.toml` file and manages the (project-specific) virtual environment for you. 12 | 13 | You can install poetry via the following command: 14 | 15 | ```bash 16 | curl -sSL https://install.python-poetry.org | python3 - 17 | poetry completions bash >> ~/.bash_completion 18 | ``` 19 | 20 | To install all dependencies required by Coeditor, make sure you have python 3.11 installed, then, run the following at the project root: 21 | 22 | ```bash 23 | poetry install 24 | ``` 25 | 26 | You can then spawn a shell within the project's virtual environment via `poetry shell`. 27 | 28 | ### Method 2: using requirements.txt 29 | 30 | Alternatively, you can also install all dependencies using the exported [`requirements.txt`](requirements.txt) file. 31 | 32 | ```bash 33 | pip3 install -r requirements.txt 34 | ``` 35 | 36 | ## Usages 37 | 38 | **Note**: All scripts below should be run within the poetry shell (or the virtual environment in which you installed all the dependencies). 39 | 40 | ### Use the VSCode extension server✨ 41 | 42 | Run [`python scripts/start_server.py`](scripts/start_server.py) to start the Coeditor VSCode extension server. This will download the pre-trained Coeditor model from Huggingface (if not already) and start the extension service at port 5042. 43 | 44 | ### Run Coeditor inside a notebook 45 | - As an alternative to using the VSCode extension, you can directly run Coeditor inside [this notebook](notebooks/run_api.ipynb) by specifying a target file and line number. 46 | 47 | ### Run unit tests 48 | 49 | You can run all unit tests via `poetry run pytest` (or just `pytest` if you run inside the poetry shell). 50 | 51 | ### Download the PyCommits dataset 52 | 53 | 1. (Optional) Configure the directories. 
Create the file `config/coeditor.json` and use the following template to specify where you want to store the dataset and the trained models: 54 | 55 | ```json 56 | { 57 | "datasets_root": "/path/to/datasets/directory", 58 | "models_root": "/path/to/models/directory" 59 | } 60 | ``` 61 | 62 | 2. Run the cells in [notebooks/download_data.ipynb](notebooks/download_data.ipynb) to clone the repos from GitHub. Note that we use the GitHub search API to search for repos with permissive licenses, so the results may change over time even though the query remains the same. 63 | 64 | 3. (Optional) Run [scripts/prepare_data.py](scripts/prepare_data.py) to preprocess the repos into the PyCommits format introduced in the paper. You can safely skip this step since it will automatically be run when you train a new model (and with the corresponding encoder parameters). 65 | 66 | ### Train a new model 67 | 68 | Use the [scripts/train_model.py](scripts/train_model.py) script to train a new model from scratch. By default, this script trains a model under our default settings, but you can uncomment the corresponding function calls at the bottom of the script to train a model following one of the ablation settings in the paper. 69 | 70 | **Note**: Only training with a single GPU is tested. You can set the GPU to use via the `CUDA_VISIBLE_DEVICES` environment variable. 71 | 72 | ### Evaluate pre-trained models 73 | 74 | - **Comparison with Code Completion Approaches**: Run [scripts/code_completion_eval.py](scripts/code_completion_eval.py) to obtain the results reported in section 4.1 of the paper. 75 | - **Multi-round editing**: Run [scripts/multi_round_eval.py](scripts/multi_round_eval.py) to obtain the results reported in section 4.2 of the paper. 76 | - **Ablation Studies**: Run [scripts/single_round_eval.py](scripts/single_round_eval.py) to obtain the results reported in section 4.3 of the paper. 
77 | 78 | 79 | ## Citation 80 | Please cite our paper as: 81 | ``` 82 | @inproceedings{ 83 | wei2024coeditor, 84 | title={Coeditor: Leveraging Repo-level Diffs for Code Auto-editing}, 85 | author={Jiayi Wei and Greg Durrett and Isil Dillig}, 86 | booktitle={The Twelfth International Conference on Learning Representations}, 87 | year={2024}, 88 | url={https://openreview.net/forum?id=ALVwQjZRS8} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /data/code/bad_code_1.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | # A recursive fibonacci function 5 | def fib(n: str) -> list[int]: 6 | if n == 0: 7 | return 0 8 | elif n == 1: 9 | return 1 10 | else: 11 | return fib(n - 1) + fib(n - 2) 12 | 13 | 14 | def t_add(x: str, y: str) -> int: 15 | r = x + y 16 | return r 17 | 18 | 19 | x: int = fib(3) 20 | bad_y: str = 1 21 | -------------------------------------------------------------------------------- /data/code/bad_code_2.py: -------------------------------------------------------------------------------- 1 | from bad_code_1 import fib 2 | 3 | i: int = 4 4 | fib(i) 5 | -------------------------------------------------------------------------------- /data/code/code_with_slash.py: -------------------------------------------------------------------------------- 1 | class SlashClass: 2 | def __init__(self, check_interval: int, folder: Path, /) -> None: 3 | self._autolocked: Dict[Path, int] = {} 4 | self._lockers: Dict[Path, "DirectEdit"] = {} 5 | self._to_lock: Items = [] 6 | -------------------------------------------------------------------------------- /data/code/dummy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrVPlusOne/Coeditor/4d645e1293f6d5a9a93dd893b0b608dc54153a6a/data/code/dummy/__init__.py 
-------------------------------------------------------------------------------- /data/code/dummy/dummy_1.py: -------------------------------------------------------------------------------- 1 | def f_int(x: int) -> int: 2 | return x 3 | -------------------------------------------------------------------------------- /data/code/dummy/dummy_2.py: -------------------------------------------------------------------------------- 1 | from dummy.dummy_1 import f_int 2 | 3 | s: str = f_int(2) 4 | -------------------------------------------------------------------------------- /data/code/env_code_1.py: -------------------------------------------------------------------------------- 1 | # Env example 1: no existing annotations 2 | 3 | 4 | def fib(n): 5 | if n == 0: 6 | return 0 7 | elif n == 1: 8 | return 1 9 | else: 10 | return fib(n - 1) + fib(n - 2) 11 | 12 | 13 | def foo(bar): 14 | return fib(bar) 15 | 16 | 17 | def int_add(a, b): 18 | return a + b + "c" 19 | 20 | 21 | def int_tripple_add(a, b, c): 22 | return a + b + c 23 | -------------------------------------------------------------------------------- /data/code/env_code_2.py: -------------------------------------------------------------------------------- 1 | # Env example 2: some existing annotations 2 | 3 | from typing import * 4 | 5 | 6 | def fib(n: int): 7 | if n == 0: 8 | return 0 9 | elif n == 1: 10 | return 1 11 | else: 12 | return fib(n - 1) + fib(n - 2) 13 | 14 | 15 | def foo(bar: int): 16 | return fib(bar) 17 | 18 | 19 | class Bar: 20 | z: str = "hello" 21 | w: str 22 | 23 | def __init__(self, x: int): 24 | self.x: int = x 25 | self.y: Optional[int] = None 26 | self.reset(self.z) 27 | 28 | def reset(self, w0): 29 | self.w = w0 30 | 31 | def foo(self, z: str) -> int: 32 | return self.x + len(z) 33 | 34 | 35 | bar: Bar = Bar(3) 36 | -------------------------------------------------------------------------------- /data/code/good_code_1.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any # [added by SPOT] 3 | from typing import Optional 4 | 5 | print(math.sin(4)) 6 | 7 | x_str: str = "x" 8 | y: Any = 1 9 | z_str: str = x_str + y 10 | 11 | 12 | class Foo: 13 | def __init__(self, x: int): 14 | self.x: int = x 15 | self.y: Optional[int] = None 16 | self.z = "hello" 17 | 18 | def foo(self, z: str) -> int: 19 | return self.x + len(z) 20 | -------------------------------------------------------------------------------- /data/coeditor_link.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrVPlusOne/Coeditor/4d645e1293f6d5a9a93dd893b0b608dc54153a6a/data/coeditor_link.pkl -------------------------------------------------------------------------------- /data/ex_repo/env_code_1.py: -------------------------------------------------------------------------------- 1 | # Env example 1: no existing annotations 2 | 3 | good = 5 4 | 5 | 6 | def fib(n): 7 | if n == 0: 8 | return 0 9 | elif n == 1: 10 | return 1 11 | else: 12 | return fib(n - 1) + fib(n - 2) 13 | 14 | 15 | class Wrapper: 16 | x_elem: int 17 | 18 | @staticmethod 19 | def foo(bar): 20 | return fib(bar) 21 | 22 | def inc(self): 23 | self.x_elem += 1 24 | 25 | 26 | def int_add(a, b): 27 | return a + b + "c" 28 | 29 | 30 | def int_tripple_add(a, b, c): 31 | return a + b + c 32 | -------------------------------------------------------------------------------- /data/ex_repo/env_code_2.py: -------------------------------------------------------------------------------- 1 | # Env example 2: some existing annotations 2 | 3 | from typing import * 4 | 5 | 6 | def fib(n: int): 7 | if n == 0: 8 | return 0 9 | elif n == 1: 10 | return 1 11 | else: 12 | return fib(n - 1) + fib(n - 2) 13 | 14 | 15 | def foo(bar: int): 16 | return fib(bar) 17 | 18 | 19 | class Bar: 20 | z: str = "hello" 21 | w: str 22 | 23 | def __init__(self, 
x: int): 24 | self.x: int = x 25 | self.y: Optional[int] = None 26 | self.reset(self.z) 27 | 28 | def reset(self, w0): 29 | self.w = w0 30 | 31 | def foo(self, z: str) -> int: 32 | return self.x + len(z) 33 | 34 | 35 | bar: Bar = Bar(3) 36 | -------------------------------------------------------------------------------- /data/repos_split.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrVPlusOne/Coeditor/4d645e1293f6d5a9a93dd893b0b608dc54153a6a/data/repos_split.pkl -------------------------------------------------------------------------------- /data/useful_repos.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrVPlusOne/Coeditor/4d645e1293f6d5a9a93dd893b0b608dc54153a6a/data/useful_repos.pkl -------------------------------------------------------------------------------- /install-all.bash: -------------------------------------------------------------------------------- 1 | # Assuming on ubuntu, run commands to install all dependencies 2 | 3 | curl -sSL https://install.python-poetry.org | python3 - 4 | poetry completions bash >> ~/.bash_completion 5 | 6 | sudo add-apt-repository ppa:deadsnakes/ppa 7 | sudo apt install python3.11 8 | 9 | poetry install -------------------------------------------------------------------------------- /notebooks/download_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "from coeditor.common import *\n", 13 | "import os\n", 14 | "\n", 15 | "import requests\n", 16 | "import shutil\n", 17 | "import random\n", 18 | "\n", 19 | "os.chdir(proj_root())" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 25, 25 | "metadata": {}, 26 | 
"outputs": [], 27 | "source": [ 28 | "import requests\n", 29 | "import dateparser\n", 30 | "from coeditor.git import GitRepo\n", 31 | "import warnings\n", 32 | "import time\n", 33 | "\n", 34 | "\n", 35 | "def request_page(page: int, license: str, n_items: int = 10):\n", 36 | " if Path(\"config/github_token.txt\").exists():\n", 37 | " token = Path(\"config/github_token.txt\").read_text().strip()\n", 38 | " headers = {\n", 39 | " \"Authorization\": f\"Bearer {token}\"\n", 40 | " }\n", 41 | " else:\n", 42 | " headers = None\n", 43 | " return requests.get(\n", 44 | " f\"https://api.github.com/search/repositories?q=NOT+interview+NOT+reference+NOT+course+NOT+cheatsheet+created%3A>2018-01-01+stars%3A>100+size%3A<20000+language%3APython+license%3A{license}&sort=stars&order=desc&per_page={n_items}&page={page}\",\n", 45 | " headers=headers,\n", 46 | " ).json()\n", 47 | "\n", 48 | "\n", 49 | "def fetch_python_repos(license2counts: dict[str, int]):\n", 50 | " n_repos = sum(license2counts.values())\n", 51 | " repos = dict[str, GitRepo]()\n", 52 | " with tqdm(total=n_repos) as pbar:\n", 53 | " for license, n_repos in license2counts.items():\n", 54 | " for i in range(1, n_repos // 100 + 1):\n", 55 | " page = request_page(i, n_items=100, license=license)\n", 56 | " if (msg := page.get(\"message\", \"\")) and msg.startswith(\n", 57 | " \"API rate limit exceeded\"\n", 58 | " ):\n", 59 | " print(\"API rate limit exceeded, now wait for 1 min\")\n", 60 | " time.sleep(60)\n", 61 | " continue\n", 62 | " if not page.get(\"items\"):\n", 63 | " print(\"Fetching page failed:\")\n", 64 | " print(page)\n", 65 | " break\n", 66 | " for item in page[\"items\"]:\n", 67 | " r = GitRepo.from_github_item(item)\n", 68 | " if not r.archived:\n", 69 | " if r.authorname() in repos:\n", 70 | " print(f\"[warning] {r.authorname()} already in repos\")\n", 71 | " repos[r.authorname()] = r\n", 72 | " pbar.update(len(page[\"items\"]))\n", 73 | " return [repos[k] for k in list(repos)]\n" 74 | ] 75 | }, 76 | { 77 
| "cell_type": "code", 78 | "execution_count": 10, 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "data": { 83 | "text/plain": [ 84 | "{'mit': 7386, 'apache-2.0': 2809, 'bsd-3-clause': 523, 'bsd-2-clause': 149}" 85 | ] 86 | }, 87 | "execution_count": 10, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "{\n", 94 | " l: int(request_page(0, l, n_items=1)[\"total_count\"])\n", 95 | " for l in [\"mit\", \"apache-2.0\", \"bsd-3-clause\", \"bsd-2-clause\"]\n", 96 | "}\n" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 23, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stderr", 106 | "output_type": "stream", 107 | "text": [ 108 | "100%|██████████| 2500/2500 [01:17<00:00, 32.10it/s]" 109 | ] 110 | }, 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "Repos: 2445\n" 116 | ] 117 | }, 118 | { 119 | "name": "stderr", 120 | "output_type": "stream", 121 | "text": [ 122 | "\n" 123 | ] 124 | } 125 | ], 126 | "source": [ 127 | "license2counts = {\n", 128 | " \"mit\": 1000,\n", 129 | " \"apache-2.0\": 1000,\n", 130 | " \"bsd-3-clause\": 500,\n", 131 | "}\n", 132 | "\n", 133 | "all_repos = fetch_python_repos(license2counts)\n", 134 | "print(\"Repos:\", len(all_repos))" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 26, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "name": "stderr", 144 | "output_type": "stream", 145 | "text": [ 146 | "downloading repos: 100%|██████████| 2445/2445 [22:13<00:00, 1.83it/s]" 147 | ] 148 | }, 149 | { 150 | "name": "stdout", 151 | "output_type": "stream", 152 | "text": [ 153 | "Successfully downloaded: 2444\n" 154 | ] 155 | }, 156 | { 157 | "name": "stderr", 158 | "output_type": "stream", 159 | "text": [ 160 | "\n" 161 | ] 162 | } 163 | ], 164 | "source": [ 165 | "dataset_name = \"perm2K\" # permissive licensed 2K repos\n", 166 | "repos_dir = get_dataset_dir(dataset_name)\n", 167 | 
"(repos_dir / \"downloading\").mkdir(exist_ok=True, parents=True)\n", 168 | "(repos_dir / \"downloaded\").mkdir(exist_ok=True, parents=True)\n", 169 | "\n", 170 | "downloaded = pmap(\n", 171 | " GitRepo.download,\n", 172 | " all_repos,\n", 173 | " key_args={\"repos_dir\": repos_dir, \"full_history\": True},\n", 174 | " desc=\"downloading repos\",\n", 175 | " max_workers=4,\n", 176 | " chunksize=1,\n", 177 | ")\n", 178 | "\n", 179 | "print(\"Successfully downloaded:\", sum(downloaded))\n", 180 | "downloaded_repos = [r for r, d in zip(all_repos, downloaded) if d]\n" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 27, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "name": "stdout", 190 | "output_type": "stream", 191 | "text": [ 192 | "Successfully downloaded: 2444\n" 193 | ] 194 | } 195 | ], 196 | "source": [ 197 | "print(\"Successfully downloaded:\", len(downloaded_repos))" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 32, 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "name": "stderr", 207 | "output_type": "stream", 208 | "text": [ 209 | "100%|██████████| 2444/2444 [00:31<00:00, 77.33it/s]" 210 | ] 211 | }, 212 | { 213 | "name": "stdout", 214 | "output_type": "stream", 215 | "text": [ 216 | "After filtering by commits: 1664\n" 217 | ] 218 | }, 219 | { 220 | "name": "stderr", 221 | "output_type": "stream", 222 | "text": [ 223 | "\n" 224 | ] 225 | } 226 | ], 227 | "source": [ 228 | "# now filter out repos with less than 50 commits\n", 229 | "filtered_repos = [r for r in tqdm(downloaded_repos) if r.count_commits(repos_dir) >= 50]\n", 230 | "print(\"After filtering by commits:\", len(filtered_repos))" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 38, 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "name": "stdout", 240 | "output_type": "stream", 241 | "text": [ 242 | "Totoal duplicates: 15\n", 243 | "After filtering duplicates: 1650\n" 244 | ] 245 | } 246 | 
], 247 | "source": [ 248 | "from coeditor.dataset import get_repo_signature\n", 249 | "\n", 250 | "repo_paths = [repos_dir / \"downloaded\" / r.authorname() for r in filtered_repos]\n", 251 | "sigs = pmap(get_repo_signature, repo_paths, desc=\"getting repo signatures\", chunksize=1)\n", 252 | "sig_groups = groupby(enumerate(sigs), lambda x: x[1])\n", 253 | "\n", 254 | "duplicates = set[str]()\n", 255 | "for sig, group in sig_groups.items():\n", 256 | " if len(group) > 1:\n", 257 | " print(f\"{len(group)} repos have the same signature {sig}:\")\n", 258 | " for i, _ in group: # i indexes filtered_repos, the list the signatures were computed from\n", 259 | " print(f\" {filtered_repos[i].authorname()}\")\n", 260 | " for i, _ in group[1:]: # keep the first repo of each group, drop the rest\n", 261 | " duplicates.add(filtered_repos[i].authorname())\n", 262 | "\n", 263 | "print(\"Totoal duplicates:\", len(duplicates))\n", 264 | "filtered_repos = [r for r in filtered_repos if r.authorname() not in duplicates]\n", 265 | "print(\"After filtering duplicates:\", len(filtered_repos))\n" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 35, 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "name": "stdout", 275 | "output_type": "stream", 276 | "text": [ 277 | "n_test=50, n_valid=50, n_train=1550\n" 278 | ] 279 | } 280 | ], 281 | "source": [ 282 | "n_test = 50\n", 283 | "n_valid = 50\n", 284 | "n_train = len(filtered_repos) - n_test - n_valid\n", 285 | "print(f\"n_test={n_test}, n_valid={n_valid}, n_train={n_train}\")\n", 286 | "\n", 287 | "random.seed(42)\n", 288 | "filtered_repos.sort(key=lambda r: r.authorname())\n", 289 | "random.shuffle(filtered_repos)\n", 290 | "\n", 291 | "split = {\n", 292 | " \"test\": filtered_repos[:n_test],\n", 293 | " \"valid\": filtered_repos[n_test : n_test + n_valid],\n", 294 | " \"train\": filtered_repos[n_test + n_valid :][:n_train],\n", 295 | "}\n", 296 | "\n", 297 | "pickle_dump(repos_dir / \"repos_split.pkl\", split)\n" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 36, 303 | "metadata": {}, 304 | 
"outputs": [ 305 | { 306 | "name": "stderr", 307 | "output_type": "stream", 308 | "text": [ 309 | "moving test: 100%|██████████| 50/50 [00:00<00:00, 670.37it/s]\n", 310 | "moving valid: 100%|██████████| 50/50 [00:00<00:00, 716.50it/s]\n", 311 | "moving train: 100%|██████████| 1550/1550 [00:02<00:00, 686.52it/s]\n" 312 | ] 313 | } 314 | ], 315 | "source": [ 316 | "# move downloaded repos to their split group\n", 317 | "for group, rs in split.items():\n", 318 | " for repo in tqdm(rs, desc=f\"moving {group}\"):\n", 319 | " dest = repos_dir / \"repos\" / group\n", 320 | " dest.mkdir(exist_ok=True, parents=True)\n", 321 | " shutil.move(repos_dir / \"downloaded\" / repo.authorname(), dest)\n" 322 | ] 323 | } 324 | ], 325 | "metadata": { 326 | "kernelspec": { 327 | "display_name": "Python 3.10.4 ('.venv': pipenv)", 328 | "language": "python", 329 | "name": "python3" 330 | }, 331 | "language_info": { 332 | "codemirror_mode": { 333 | "name": "ipython", 334 | "version": 3 335 | }, 336 | "file_extension": ".py", 337 | "mimetype": "text/x-python", 338 | "name": "python", 339 | "nbconvert_exporter": "python", 340 | "pygments_lexer": "ipython3", 341 | "version": "3.11.0" 342 | }, 343 | "orig_nbformat": 4, 344 | "vscode": { 345 | "interpreter": { 346 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8" 347 | } 348 | } 349 | }, 350 | "nbformat": 4, 351 | "nbformat_minor": 2 352 | } 353 | -------------------------------------------------------------------------------- /notebooks/multi_round_inspect.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "import os\n", 13 | "\n", 14 | "from coeditor.c3problem import C3ProblemGenerator, C3ProblemTokenizer\n", 15 | "from coeditor.common import *\n", 16 | "from coeditor.dataset 
import make_or_load_dataset\n", 17 | "from coeditor.model import (\n", 18 | " DecodingArgs,\n", 19 | " MultiRoundEvaluator,\n", 20 | " MultiRoundStrategy,\n", 21 | " RetrievalEditorModel,\n", 22 | ")\n", 23 | "\n", 24 | "os.chdir(proj_root())" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stderr", 34 | "output_type": "stream", 35 | "text": [ 36 | "pmap: _process_commits: 100%|██████████| 50/50 [00:05<00:00, 9.31repo/s]\n" 37 | ] 38 | }, 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "Dataset total size (n=1649): 5150.76 MB\n", 44 | "22516\n", 45 | "len(subset)=1000\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "dataset_name = \"perm2k\"\n", 51 | "model_name = get_coeditor_model_path()\n", 52 | "model_device = \"cuda\"\n", 53 | "N_test = 1000\n", 54 | "\n", 55 | "testset = make_or_load_dataset(\n", 56 | " dataset_name,\n", 57 | " C3ProblemGenerator(),\n", 58 | " splits=(\"test\",),\n", 59 | " time_limit_per_commit=40,\n", 60 | ")[\"test\"]\n", 61 | "\n", 62 | "print(f\"{len(testset)}\")\n", 63 | "subset = random_subset(testset, N_test, rng=42)\n", 64 | "print(f\"{len(subset)=}\")" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "import numpy as np\n", 74 | "\n", 75 | "target_file = (\n", 76 | " proj_root() / f\"output/multi_round_eval/{model_name}/most_uncertain-{N_test}.pkl\"\n", 77 | ")\n", 78 | "metric_stats = pickle_load(target_file)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "len(sample_ids)=33\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "sample_ids = [i for i, m in enumerate(metric_stats) if 50 <= m[\"keystrokes\"].total_edit_gain <= 100 and m[\"diff-lines\"].rounds == 1]\n", 96 | 
"print(f\"{len(sample_ids)=}\")" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "tokenizer = C3ProblemTokenizer.for_eval()\n", 106 | "dec_args = DecodingArgs(do_sample=False, num_beams=1)\n", 107 | "model = RetrievalEditorModel.load(model_name)\n", 108 | "model.to(model_device)\n", 109 | "None" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 6, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "path: db.basedb/BaseDB._insert\n", 122 | "n_references: 5\n", 123 | "total_reference_tks: 1125\n", 124 | "project: qiandao-today~qiandao\n", 125 | "commit: [8d0f0e04df Cursor操作结束后自动关闭; 统一DB连接操作]\n", 126 | "{'input_tks': 234, 'output_tks': 44, 'n_references': 5, 'changed_reference_tks': 957, 'unchanged_reference_tks': 168, 'total_reference_tks': 1125}\n", 127 | "--------------------------------------------------------------------------------\n", 128 | "round:\n", 129 | "1\n", 130 | "========Ground Truth========\n", 131 | "<13>: lastrowid = dbcur.lastrowid\n", 132 | " dbcur.close()\n", 133 | " return lastrowid\n", 134 | " return dbcur.lastrowid\n", 135 | "\n", 136 | "========Main Code========\n", 137 | " # module: db.basedb\n", 138 | " \n", 139 | " \n", 140 | " class BaseDB(object):\n", 141 | " \n", 142 | " def _insert(self, tablename=None, **values):\n", 143 | " <0> tablename = self.escape(tablename or self.__tablename__)\n", 144 | " <1> if values:\n", 145 | " <2> _keys = \", \".join((self.escape(k) for k in values.keys()))\n", 146 | " <3> _values = \", \".join([self.placeholder, ] * len(values))\n", 147 | " <4> sql_query = \"INSERT INTO %s (%s) VALUES (%s)\" % (tablename, _keys, _values)\n", 148 | " <5> else:\n", 149 | " <6> sql_query = \"INSERT INTO %s DEFAULT VALUES\" % tablename\n", 150 | " <7> logger.debug(\"\", sql_query)\n", 151 | " <8> \n", 152 | " <9> if values:\n", 
153 | "<10> dbcur = self._execute(sql_query, list(values.values()))\n", 154 | "<11> else:\n", 155 | "<12> dbcur = self._execute(sql_query)\n", 156 | "<13> return dbcur.lastrowid\n", 157 | "<14> \n", 158 | " \n", 159 | "===========unchanged ref 0===========\n", 160 | " at: db.basedb\n", 161 | " logger = logging.getLogger('qiandao.basedb')\n", 162 | " \n", 163 | " at: db.basedb.BaseDB\n", 164 | " placeholder = \"%s\" # mysql\n", 165 | " \n", 166 | " escape(string)\n", 167 | " \n", 168 | " _execute(sql_query, values=[])\n", 169 | " _execute(self, sql_query, values=[])\n", 170 | " \n", 171 | " at: db.basedb.BaseDB._replace\n", 172 | " tablename = self.escape(tablename or self.__tablename__)\n", 173 | " \n", 174 | " at: logging.Logger\n", 175 | " debug(msg: Any, *args: Any, exc_info: _ExcInfoType=..., stack_info: bool=..., extra: Optional[Dict[str, Any]]=..., **kwargs: Any) -> None\n", 176 | " \n", 177 | " \n", 178 | "===========changed ref 0===========\n", 179 | " # module: db.basedb\n", 180 | " \n", 181 | " \n", 182 | " class BaseDB(object):\n", 183 | " + # placeholder = '?' 
# sqlite3\n", 184 | " + \n", 185 | " + def __init__(self, host=config.mysql.host, port=config.mysql.port,\n", 186 | " + database=config.mysql.database, user=config.mysql.user, passwd=config.mysql.passwd, auth_plugin=config.mysql.auth_plugin):\n", 187 | " + import mysql.connector\n", 188 | " + self.conn = mysql.connector.connect(user=user, password=passwd, host=host, port=port,\n", 189 | " + database=database, auth_plugin=auth_plugin, autocommit=True)\n", 190 | " + \n", 191 | "===========changed ref 1===========\n", 192 | " # module: db.basedb\n", 193 | " \n", 194 | " \n", 195 | " class BaseDB(object):\n", 196 | " \n", 197 | " def _replace(self, tablename=None, **values):\n", 198 | " tablename = self.escape(tablename or self.__tablename__)\n", 199 | " if values:\n", 200 | " _keys = \", \".join(self.escape(k) for k in values.keys())\n", 201 | " _values = \", \".join([self.placeholder, ] * len(values))\n", 202 | " sql_query = \"REPLACE INTO %s (%s) VALUES (%s)\" % (tablename, _keys, _values)\n", 203 | " else:\n", 204 | " sql_query = \"REPLACE INTO %s DEFAULT VALUES\" % tablename\n", 205 | " logger.debug(\"\", sql_query)\n", 206 | " \n", 207 | " if values:\n", 208 | " dbcur = self._execute(sql_query, list(values.values()))\n", 209 | " else:\n", 210 | " dbcur = self._execute(sql_query)\n", 211 | " + lastrowid = dbcur.lastrowid\n", 212 | " + dbcur.close()\n", 213 | " + return lastrowid\n", 214 | " - return dbcur.lastrowid\n", 215 | " \n", 216 | "===========changed ref 2===========\n", 217 | " # module: db.basedb\n", 218 | " \n", 219 | " \n", 220 | " class BaseDB(object):\n", 221 | " \n", 222 | " def _select(self, tablename=None, what=\"*\", where=\"\", where_values=[], offset=0, limit=None):\n", 223 | " tablename = self.escape(tablename or self.__tablename__)\n", 224 | " if isinstance(what, list) or isinstance(what, tuple) or what is None:\n", 225 | " what = ','.join(self.escape(f) for f in what) if what else '*'\n", 226 | " \n", 227 | " sql_query = \"SELECT %s FROM %s\" 
% (what, tablename)\n", 228 | " if where: sql_query += \" WHERE %s\" % where\n", 229 | " if limit: sql_query += \" LIMIT %d, %d\" % (offset, limit)\n", 230 | " logger.debug(\"\", sql_query)\n", 231 | " \n", 232 | " + dbcur = self._execute(sql_query, where_values)\n", 233 | " - for row in self._execute(sql_query, where_values):\n", 234 | " + for row in dbcur:\n", 235 | " yield [tostr(x) for x in row]\n", 236 | " + dbcur.close()\n", 237 | " \n", 238 | "===========changed ref 3===========\n", 239 | " # module: db.basedb\n", 240 | " \n", 241 | " \n", 242 | " class BaseDB(object):\n", 243 | " def _select2dic(self, tablename=None, what=\"*\", where=\"\", where_values=[], offset=0, limit=None):\n", 244 | " tablename = self.escape(tablename or self.__tablename__)\n", 245 | " if isinstance(what, list) or isinstance(what, tuple) or what is None:\n", 246 | " what = ','.join(self.escape(f) for f in what) if what else '*'\n", 247 | " \n", 248 | " sql_query = \"SELECT %s FROM %s\" % (what, tablename)\n", 249 | " if where: sql_query += \" WHERE %s\" % where\n", 250 | " if limit: sql_query += \" LIMIT %d, %d\" % (offset, limit)\n", 251 | " logger.debug(\"\", sql_query)\n", 252 | " \n", 253 | " dbcur = self._execute(sql_query, where_values)\n", 254 | " fields = [f[0] for f in dbcur.description]\n", 255 | " \n", 256 | " rtv = []\n", 257 | " for row in dbcur:\n", 258 | " rtv.append(dict(zip(fields, [tostr(x) for x in row])))\n", 259 | " #yield dict(zip(fields, [tostr(x) for x in row]))\n", 260 | " + \n", 261 | " + dbcur.close()\n", 262 | " return rtv\n", 263 | " \n", 264 | "========Predicted Changes========\n", 265 | "<13>: lastrowid = dbcur.lastrowid\n", 266 | " dbcur.close()\n", 267 | " return lastrowid\n", 268 | " return dbcur.lastrowid\n", 269 | "\n", 270 | "\u001b[32m========Accepted changes========\u001b[0m\n", 271 | "TkDelta(\n", 272 | " 13: (' lastrowid = dbcur.lastrowid', ' dbcur.close()', ' return lastrowid', '')\n", 273 | ")\n", 274 | "========Accepted gains========\n", 
275 | "keystrokes: 50\n", 276 | "diff-lines: 4\n", 277 | "levenshtein: 47\n", 278 | "========Remaining changes========\n", 279 | "TkDelta(\n", 280 | "\n", 281 | ")\n" 282 | ] 283 | }, 284 | { 285 | "data": { 286 | "text/plain": [ 287 | "{'keystrokes': MultiRoundEditStats(label_edit_gain=69, first_edit_gain=69, total_edit_gain=69, rounds=1),\n", 288 | " 'diff-lines': MultiRoundEditStats(label_edit_gain=4, first_edit_gain=4, total_edit_gain=4, rounds=1),\n", 289 | " 'levenshtein': MultiRoundEditStats(label_edit_gain=55, first_edit_gain=55, total_edit_gain=55, rounds=1)}" 290 | ] 291 | }, 292 | "execution_count": 6, 293 | "metadata": {}, 294 | "output_type": "execute_result" 295 | } 296 | ], 297 | "source": [ 298 | "# target_hash=\"30eb3afbe173b75dc5ec44348cef49fe3eac2421\"\n", 299 | "# ex_id = [i for i in sample_ids if subset[i].src_info[\"commit\"].hash == target_hash][0]\n", 300 | "ex_id = sample_ids[5]\n", 301 | "ex = subset[ex_id]\n", 302 | "evaluator = MultiRoundEvaluator(model, tokenizer, dec_args, strategy=\"pick_first\")\n", 303 | "\n", 304 | "evaluator.multi_round_edit_gain(ex, print_steps=True, skip_ctx=False)" 305 | ] 306 | } 307 | ], 308 | "metadata": { 309 | "kernelspec": { 310 | "display_name": ".venv", 311 | "language": "python", 312 | "name": "python3" 313 | }, 314 | "language_info": { 315 | "codemirror_mode": { 316 | "name": "ipython", 317 | "version": 3 318 | }, 319 | "file_extension": ".py", 320 | "mimetype": "text/x-python", 321 | "name": "python", 322 | "nbconvert_exporter": "python", 323 | "pygments_lexer": "ipython3", 324 | "version": "3.11.0" 325 | }, 326 | "orig_nbformat": 4 327 | }, 328 | "nbformat": 4, 329 | "nbformat_minor": 2 330 | } 331 | -------------------------------------------------------------------------------- /notebooks/profile_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | 
"outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "%load_ext snakeviz\n", 12 | "%load_ext line_profiler\n", 13 | "\n", 14 | "# turn off autoreload so that we can use the old model \n", 15 | "# when editing the current project\n", 16 | "\n", 17 | "from coeditor.common import *\n", 18 | "import os\n", 19 | "\n", 20 | "os.chdir(proj_root())" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "from coeditor.model import RetrievalEditorModel, AttentionMode, BatchArgs, DecodingArgs\n", 30 | "from coeditor.dataset import load_datasets\n", 31 | "import torch\n", 32 | "import copy" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 5, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "model = RetrievalEditorModel.from_code_t5(\"base\")\n", 42 | "model.to(\"cuda\")\n", 43 | "None" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 6, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stderr", 53 | "output_type": "stream", 54 | "text": [ 55 | "test run: 100%|██████████| 30/30 [00:04<00:00, 7.06it/s]\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "model.profile_run(repeats=30)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 5, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "model_path = get_model_dir(True) / \"coeditor-large-request-stub-v2\"\n", 70 | "model = RetrievalEditorModel.load(model_path)\n", 71 | "model.to(\"cuda\")\n", 72 | "model.attention_mode = AttentionMode.bidirectional\n", 73 | "\n", 74 | "batch_args = copy.deepcopy(BatchArgs())" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 7, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "\u001b[34mnum batches: 7,\u001b[0m \u001b[34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': 
'16.0'}\u001b[0m\n", 87 | "\u001b[34mnum batches: 7,\u001b[0m \u001b[34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': '16.0'}\u001b[0m\n" 88 | ] 89 | }, 90 | { 91 | "name": "stderr", 92 | "output_type": "stream", 93 | "text": [ 94 | "Training Epoch 0: 100%|██████████| 7/7 [00:10<00:00, 1.57s/it]\n" 95 | ] 96 | }, 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "\u001b[34mnum batches: 7,\u001b[0m \u001b[34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': '16.0'}\u001b[0m\n", 102 | "\u001b[34mnum batches: 7,\u001b[0m \u001b[34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': '16.0'}\u001b[0m\n" 103 | ] 104 | }, 105 | { 106 | "name": "stderr", 107 | "output_type": "stream", 108 | "text": [ 109 | "Training Epoch 0: 100%|██████████| 7/7 [00:10<00:00, 1.57s/it]" 110 | ] 111 | }, 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "11 s ± 3.62 ms per loop (mean ± std. dev. 
of 2 runs, 1 loop each)\n" 117 | ] 118 | }, 119 | { 120 | "name": "stderr", 121 | "output_type": "stream", 122 | "text": [ 123 | "\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "# query_ref_layer not batched\n", 129 | "%timeit -n 1 -r 2 model.run_on_edits(test_edits, batch_args)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 7, 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "\u001b[34mnum batches: 7,\u001b[0m \u001b[34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': '16.0'}\u001b[0m\n", 142 | "\u001b[34mnum batches: 7,\u001b[0m \u001b[34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': '16.0'}\u001b[0m\n" 143 | ] 144 | }, 145 | { 146 | "name": "stderr", 147 | "output_type": "stream", 148 | "text": [ 149 | "Training Epoch 0: 100%|██████████| 7/7 [00:10<00:00, 1.56s/it]\n" 150 | ] 151 | }, 152 | { 153 | "name": "stdout", 154 | "output_type": "stream", 155 | "text": [ 156 | "\u001b[34mnum batches: 7,\u001b[0m \u001b[34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': '16.0'}\u001b[0m\n", 157 | "\u001b[34mnum batches: 7,\u001b[0m \u001b[34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': '16.0'}\u001b[0m\n" 158 | ] 159 | }, 160 | { 161 | "name": "stderr", 162 | "output_type": "stream", 163 | "text": [ 164 | "Training Epoch 0: 100%|██████████| 7/7 [00:10<00:00, 1.55s/it]" 165 | ] 166 | }, 167 | { 168 | "name": "stdout", 169 | "output_type": "stream", 170 | "text": [ 171 | "10.9 s ± 6.59 ms per loop (mean ± std. dev. 
of 2 runs, 1 loop each)\n" 172 | ] 173 | }, 174 | { 175 | "name": "stderr", 176 | "output_type": "stream", 177 | "text": [ 178 | "\n" 179 | ] 180 | } 181 | ], 182 | "source": [ 183 | "%timeit -n 1 -r 2 model.run_on_edits(test_edits, batch_args)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 21, 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "name": "stdout", 193 | "output_type": "stream", 194 | "text": [ 195 | "\u001b[34mnum batches: 7,\u001b[0m \u001b[34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '9.0', 'max': '16.0'}\u001b[0m\n", 196 | "\u001b[34mnum batches: 7,\u001b[0m \u001b[34mbatch stats: {'mean': '14.3', 'median': '16.0', 'min': '4.0', 'max': '16.0'}\u001b[0m\n" 197 | ] 198 | }, 199 | { 200 | "name": "stderr", 201 | "output_type": "stream", 202 | "text": [ 203 | "Training Epoch 0: 100%|██████████| 7/7 [00:10<00:00, 1.56s/it]\n" 204 | ] 205 | }, 206 | { 207 | "name": "stdout", 208 | "output_type": "stream", 209 | "text": [ 210 | " \n", 211 | "*** Profile stats marshalled to file '/tmp/tmpjlbfo3f1'.\n", 212 | "Opening SnakeViz in a new tab...\n", 213 | "snakeviz web server started on 127.0.0.1:8080; enter Ctrl-C to exit\n", 214 | "http://127.0.0.1:8080/snakeviz/%2Ftmp%2Ftmpjlbfo3f1\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "%snakeviz -t model.run_on_edits(test_edits, batch_args)" 220 | ] 221 | } 222 | ], 223 | "metadata": { 224 | "kernelspec": { 225 | "display_name": ".venv", 226 | "language": "python", 227 | "name": "python3" 228 | }, 229 | "language_info": { 230 | "codemirror_mode": { 231 | "name": "ipython", 232 | "version": 3 233 | }, 234 | "file_extension": ".py", 235 | "mimetype": "text/x-python", 236 | "name": "python", 237 | "nbconvert_exporter": "python", 238 | "pygments_lexer": "ipython3", 239 | "version": "3.11.0" 240 | }, 241 | "orig_nbformat": 4, 242 | "vscode": { 243 | "interpreter": { 244 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8" 245 | } 246 | } 247 
| }, 248 | "nbformat": 4, 249 | "nbformat_minor": 2 250 | } 251 | -------------------------------------------------------------------------------- /notebooks/run_api.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "You can use this notebook to run the edit suggestion service on any project \n", 9 | "by specifying a file and line number." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | "\n", 21 | "from coeditor.common import *\n", 22 | "import os\n", 23 | "from coeditor.model import RetrievalEditorModel, AttentionMode, DecodingArgs, EditCostModel\n", 24 | "from coeditor.service import EditPredictionService, ChangeDetector\n", 25 | "from coeditor.c3problem import C3GeneratorCache, C3Problem\n", 26 | "\n", 27 | "os.chdir(proj_root())" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# NOTE: replace below with a target project for which you want to run the service.\n", 37 | "# Currently, we are using the Coeditor project itself as the target.\n", 38 | "target_dir = proj_root()\n", 39 | "\n", 40 | "model_path = get_coeditor_model_path()\n", 41 | "model = RetrievalEditorModel.load(model_path)\n", 42 | "model.to(\"cuda:0\")\n", 43 | "detector = ChangeDetector(target_dir)\n", 44 | "service = EditPredictionService(\n", 45 | " detector,\n", 46 | " model,\n", 47 | ")\n" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 6, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "Target module 'scripts.train_model' has not changed.\n", 60 | "Target span has not changed. 
Creating a trivial change.\n" 61 | ] 62 | }, 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | "Analyzing train_model.py: 100%|██████████| 282/282 [00:00<00:00, 4317.00it/s]\n" 68 | ] 69 | }, 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "Writing logs to: .coeditor_logs\n", 75 | "Target file: scripts/train_model.py\n", 76 | "Edit range: (76, 0) - (106, 0)\n", 77 | "Target lines: 76--105\n", 78 | "\t--------------- Suggestion 0 (score: 1) ---------------\n", 79 | "\t encoder,\n", 80 | "\t remake_problems=recreate_data,\n", 81 | "\t )\n", 82 | " \n", 83 | "\t # limit the number of examples for faster testing\n", 84 | "\t datasets[\"valid\"] = random_subset(eval_probs[\"valid\"], 10000, rng=42)\n", 85 | "\t datasets[\"test\"] = random_subset(eval_probs[\"test\"], 10000, rng=42)\n", 86 | " \n", 87 | "\t config_dict: dict[str, Any] = {\n", 88 | "\t \"description\": description,\n", 89 | "\t \"edit_tokenizer\": encoder.edit_tokenizer.get_args(),\n", 90 | "\t \"batch_args\": batch_args,\n", 91 | "\t \"train_args\": train_args,\n", 92 | "\t \"dec_args\": dec_args,\n", 93 | "\t }\n", 94 | " \n", 95 | "\t project = \"Coeditor\" if not quicktest else \"Coeditor-quicktest\"\n", 96 | "\t if eval_only:\n", 97 | "\t project = \"eval-\" + project\n", 98 | "\t- wandb.init(dir=\"..\", project=project, name=model_name, config=config_dict)\n", 99 | "\t+ wandb.init(dir=\"..\", project=project, name=model_name, config=get_config_dict())\n", 100 | " \n", 101 | "\t if quicktest:\n", 102 | "\t print(\"Using fewer data for quick test.\")\n", 103 | "\t n_quick_exs = 20\n", 104 | "\t datasets = C3ProblemDataset(\n", 105 | "\t train=datasets[\"train\"][:n_quick_exs],\n", 106 | "\t valid=datasets[\"valid\"][:n_quick_exs],\n", 107 | "\t test=datasets[\"test\"][:n_quick_exs],\n", 108 | "\t )\n", 109 | " \n", 110 | " \n", 111 | "\n" 112 | ] 113 | }, 114 | { 115 | "data": { 116 | "text/html": [ 117 | "
\n", 118 | "\n", 131 | "\n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | "
namecountavg_timetotal_time
5run model11.5274781.527478
3model.generate11.5196971.519697
2decoder.forward640.0156661.002632
0get c3 problem10.2011540.201154
1tokenize c3 problem10.0330770.033077
4assemble changes10.0042260.004226
\n", 186 | "
" 187 | ], 188 | "text/plain": [ 189 | " name count avg_time total_time\n", 190 | "5 run model 1 1.527478 1.527478\n", 191 | "3 model.generate 1 1.519697 1.519697\n", 192 | "2 decoder.forward 64 0.015666 1.002632\n", 193 | "0 get c3 problem 1 0.201154 0.201154\n", 194 | "1 tokenize c3 problem 1 0.033077 0.033077\n", 195 | "4 assemble changes 1 0.004226 0.004226" 196 | ] 197 | }, 198 | "execution_count": 6, 199 | "metadata": {}, 200 | "output_type": "execute_result" 201 | } 202 | ], 203 | "source": [ 204 | "# NOTE: specify the target file (relative to target project) and line number below\n", 205 | "\n", 206 | "target_file = \"scripts/train_model.py\"\n", 207 | "target_line = 91\n", 208 | "\n", 209 | "service.tlogger.clear()\n", 210 | "response = service.suggest_edit(to_rel_path(target_file), target_line)\n", 211 | "print(response)\n", 212 | "service.tlogger.as_dataframe()" 213 | ] 214 | } 215 | ], 216 | "metadata": { 217 | "kernelspec": { 218 | "display_name": ".venv", 219 | "language": "python", 220 | "name": "python3" 221 | }, 222 | "language_info": { 223 | "codemirror_mode": { 224 | "name": "ipython", 225 | "version": 3 226 | }, 227 | "file_extension": ".py", 228 | "mimetype": "text/x-python", 229 | "name": "python", 230 | "nbconvert_exporter": "python", 231 | "pygments_lexer": "ipython3", 232 | "version": "3.11.0" 233 | }, 234 | "orig_nbformat": 4, 235 | "vscode": { 236 | "interpreter": { 237 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8" 238 | } 239 | } 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 2 243 | } 244 | -------------------------------------------------------------------------------- /notebooks/run_open_ai_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | 
"cell_type": "code", 15 | "execution_count": 5, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import os\n", 20 | "from coeditor.common import proj_root\n", 21 | "os.environ[\"OPENAI_API_KEY\"] = (proj_root().parent / \"openai_api_key.txt\").read_text().strip()" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 6, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "You are a programming expert tasked to fill in a missing line for a given Python code \n", 34 | "snippet. The snippet may habe been truncated from both ends, and the missing line is \n", 35 | "indicated by a special token ``. \n", 36 | "You should output the missing line (along with any leading whitespaces) and \n", 37 | "nothing more. For example, if the input is\n", 38 | "```\n", 39 | "def fib(n):\n", 40 | " if n < 2:\n", 41 | "\n", 42 | " else:\n", 43 | " return fib(n-1) + fib(\n", 44 | "```\n", 45 | "Your output should be \" return 1\" (without the quotes) and nothing more.\n", 46 | "\n", 47 | "Now fill in the code snippet below:\n", 48 | "```\n", 49 | "def factorial(n):\n", 50 | " if n == 0: \n", 51 | "\n", 52 | " else:\n", 53 | " return n * factorial(n-1)\n", 54 | "\n", 55 | "\n", 56 | "```\n", 57 | "Your output:\n", 58 | "\n", 59 | "\n", 60 | "--------------------------------------------------------------------------------\n", 61 | "End of Prompt\n", 62 | "--------------------------------------------------------------------------------\n" 63 | ] 64 | }, 65 | { 66 | "data": { 67 | "text/plain": [ 68 | "' return 1'" 69 | ] 70 | }, 71 | "execution_count": 6, 72 | "metadata": {}, 73 | "output_type": "execute_result" 74 | } 75 | ], 76 | "source": [ 77 | "from coeditor.experiments.openai_gpt import OpenAIGptWrapper\n", 78 | "\n", 79 | "prompt = \"\"\"\\\n", 80 | "def factorial(n):\n", 81 | " if n == 0: \n", 82 | "\n", 83 | " else:\n", 84 | " return n * factorial(n-1)\n", 85 | "\n", 86 | "\"\"\"\n", 
87 | "prefix, suffix = prompt.split(\"\")\n", 88 | "gpt = OpenAIGptWrapper(use_fim=True, use_nl_prompt=True, print_prompt=True)\n", 89 | "gpt.infill(prefix, suffix, 100)\n" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 8, 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "def factorial(n):\n", 102 | " if n == 0: \n", 103 | "\n", 104 | "\n", 105 | "--------------------------------------------------------------------------------\n", 106 | "End of Prompt\n", 107 | "--------------------------------------------------------------------------------\n" 108 | ] 109 | }, 110 | { 111 | "data": { 112 | "text/plain": [ 113 | "' return 1'" 114 | ] 115 | }, 116 | "execution_count": 8, 117 | "metadata": {}, 118 | "output_type": "execute_result" 119 | } 120 | ], 121 | "source": [ 122 | "gpt = OpenAIGptWrapper(use_fim=False, print_prompt=True)\n", 123 | "prompt = \"\"\"\\\n", 124 | "def factorial(n):\n", 125 | " if n == 0: \n", 126 | "\n", 127 | "\"\"\"\n", 128 | "gpt.infill_lm(prompt, 100)" 129 | ] 130 | } 131 | ], 132 | "metadata": { 133 | "kernelspec": { 134 | "display_name": ".venv", 135 | "language": "python", 136 | "name": "python3" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.11.0" 149 | }, 150 | "orig_nbformat": 4 151 | }, 152 | "nbformat": 4, 153 | "nbformat_minor": 2 154 | } 155 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "coeditor" 3 | version = "0.3.0" 4 | description = "Coeditor: AI assisted code editing" 5 | authors = ["Jiayi Wei "] 6 | license = "bsd-3-clause" 7 
| readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "3.11.*" 11 | tqdm = "4.65.*" 12 | dateparser = "1.1.*" 13 | pyrsistent = "0.19.*" 14 | pandas = "^1.4" 15 | torch = "2.*" 16 | datasets = "2.8.*" 17 | wandb = "0.13.*" 18 | colored = "1.4.*" 19 | termcolor = "1.0.*" 20 | prettytable = "3.4.*" 21 | nltk = "3.8.*" 22 | jsonrpcserver = "5.0.*" 23 | jedi = "~0.18.2" 24 | parso = "0.8.*" 25 | cachetools = "5.3.*" 26 | editdistance = "~0.6.2" 27 | transformers = "~4.31.0" 28 | openai = "^0.27.8" 29 | tiktoken = "^0.4.0" 30 | tenacity = "^8.2.2" 31 | 32 | 33 | [tool.poetry.group.dev.dependencies] 34 | pytest = "^7.2.2" 35 | black = "^23.3.0" 36 | snakeviz = "^2.1.1" 37 | line-profiler = "^4.0.3" 38 | matplotlib = "^3.7.1" 39 | ipykernel = "^6.22.0" 40 | 41 | [build-system] 42 | requires = ["poetry-core"] 43 | build-backend = "poetry.core.masonry.api" 44 | 45 | 46 | [tool.pylint] 47 | disable = "invalid-name, wildcard-import, unused-wildcard-import, unused-import, redefined-outer-name" 48 | 49 | 50 | [tool.pylint.'MESSAGES CONTROL'] 51 | # Torch has generated members that confuse pylint. 
52 | # Disabling these messages for torch 53 | generated-members="torch.*" 54 | max-locals=100 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.8.4 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 2 | aiosignal==1.3.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 3 | appdirs==1.4.4 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 4 | appnope==0.1.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" and platform_system == "Darwin" or python_full_version >= "3.11.0" and python_full_version < "3.12.0" and sys_platform == "darwin" 5 | asttokens==2.2.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 6 | async-timeout==4.0.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 7 | attrs==22.2.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 8 | backcall==0.2.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 9 | black==23.3.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 10 | cachetools==5.3.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 11 | certifi==2022.12.7 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 12 | cffi==1.15.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" and implementation_name == "pypy" 13 | charset-normalizer==3.1.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 14 | click==8.1.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 15 | cmake==3.26.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 16 | colorama==0.4.6 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" and platform_system == "Windows" or 
python_full_version >= "3.11.0" and python_full_version < "3.12.0" and sys_platform == "win32" 17 | colored==1.4.4 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 18 | comm==0.1.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 19 | contourpy==1.0.7 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 20 | cycler==0.11.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 21 | datasets==2.8.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 22 | dateparser==1.1.8 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 23 | debugpy==1.6.7 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 24 | decorator==5.1.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 25 | dill==0.3.6 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 26 | docker-pycreds==0.4.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 27 | editdistance==0.6.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 28 | executing==1.2.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 29 | filelock==3.11.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 30 | fonttools==4.39.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 31 | frozenlist==1.3.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 32 | fsspec[http]==2023.3.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 33 | gitdb==4.0.10 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 34 | gitpython==3.1.31 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 35 | huggingface-hub==0.13.4 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 36 | idna==3.4 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 37 | iniconfig==2.0.0 ; python_full_version >= "3.11.0" and python_full_version < 
"3.12.0" 38 | ipykernel==6.22.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 39 | ipython==8.12.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 40 | jedi==0.18.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 41 | jinja2==3.1.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 42 | joblib==1.2.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 43 | jsonrpcserver==5.0.9 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 44 | jsonschema==4.17.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 45 | jupyter-client==8.1.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 46 | jupyter-core==5.3.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 47 | kiwisolver==1.4.4 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 48 | line-profiler==4.0.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 49 | lit==16.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 50 | markupsafe==2.1.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 51 | matplotlib-inline==0.1.6 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 52 | matplotlib==3.7.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 53 | mpmath==1.3.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 54 | multidict==6.0.4 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 55 | multiprocess==0.70.14 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 56 | mypy-extensions==1.0.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 57 | nest-asyncio==1.5.6 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 58 | networkx==3.1 ; python_full_version >= "3.11.0" and python_full_version < 
"3.12.0" 59 | nltk==3.8.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 60 | numpy==1.24.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 61 | nvidia-cublas-cu11==11.10.3.66 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 62 | nvidia-cuda-cupti-cu11==11.7.101 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 63 | nvidia-cuda-nvrtc-cu11==11.7.99 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 64 | nvidia-cuda-runtime-cu11==11.7.99 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 65 | nvidia-cudnn-cu11==8.5.0.96 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 66 | nvidia-cufft-cu11==10.9.0.58 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 67 | nvidia-curand-cu11==10.2.10.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 68 | nvidia-cusolver-cu11==11.4.0.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 69 | nvidia-cusparse-cu11==11.7.4.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 70 | nvidia-nccl-cu11==2.14.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 71 | nvidia-nvtx-cu11==11.7.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" 
and python_full_version < "3.12.0" 72 | oslash==0.6.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 73 | packaging==23.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 74 | pandas==1.4.4 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 75 | parso==0.8.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 76 | pathspec==0.11.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 77 | pathtools==0.1.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 78 | pexpect==4.8.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" and sys_platform != "win32" 79 | pickleshare==0.7.5 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 80 | pillow==9.5.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 81 | platformdirs==3.2.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 82 | pluggy==1.0.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 83 | prettytable==3.4.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 84 | prompt-toolkit==3.0.38 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 85 | protobuf==4.22.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 86 | psutil==5.9.4 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 87 | ptyprocess==0.7.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" and sys_platform != "win32" 88 | pure-eval==0.2.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 89 | pyarrow==11.0.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 90 | pycparser==2.21 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" and implementation_name == "pypy" 91 | pygments==2.14.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 92 | pyparsing==3.0.9 ; python_full_version >= "3.11.0" and 
python_full_version < "3.12.0" 93 | pyrsistent==0.19.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 94 | pytest==7.2.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 95 | python-dateutil==2.8.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 96 | pytz-deprecation-shim==0.1.0.post0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 97 | pytz==2023.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 98 | pywin32==306 ; sys_platform == "win32" and platform_python_implementation != "PyPy" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 99 | pyyaml==6.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 100 | pyzmq==25.0.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 101 | regex==2023.3.23 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 102 | requests==2.28.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 103 | responses==0.18.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 104 | sentry-sdk==1.19.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 105 | setproctitle==1.3.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 106 | setuptools==67.6.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 107 | six==1.16.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 108 | smmap==5.0.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 109 | snakeviz==2.1.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 110 | stack-data==0.6.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 111 | sympy==1.11.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 112 | termcolor==1.0.1 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 113 | tokenizers==0.12.1 ; python_full_version >= "3.11.0" 
and python_full_version < "3.12.0" 114 | torch==2.0.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 115 | tornado==6.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 116 | tqdm==4.65.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 117 | traitlets==5.9.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 118 | transformers==4.27.4 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 119 | triton==2.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 120 | typing-extensions==4.5.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 121 | tzdata==2023.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 122 | tzlocal==4.3 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 123 | urllib3==1.26.15 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 124 | wandb==0.13.11 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 125 | wcwidth==0.2.6 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 126 | wheel==0.40.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.11.0" and python_full_version < "3.12.0" 127 | xxhash==3.2.0 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 128 | yarl==1.8.2 ; python_full_version >= "3.11.0" and python_full_version < "3.12.0" 129 | -------------------------------------------------------------------------------- /scripts/code_completion_eval.py: -------------------------------------------------------------------------------- 1 | """This script compares the performance of Coeditor against code completion models 2 | on FIM problems extracted from code changes.""" 3 | 4 | import os 5 | import shutil 6 | 7 | import torch 8 | from numpy import mean 9 | 10 | from coeditor.c3problem import ( 11 | 
C3ProblemGenerator, 12 | C3ProblemTokenizer, 13 | C3ToCodeCompletion, 14 | CompletionKind, 15 | ) 16 | from coeditor.common import * 17 | from coeditor.dataset import make_or_load_dataset 18 | from coeditor.encoding import inline_output_tokens, tokens_to_change 19 | from coeditor.experiments.code_completion import ( 20 | C3CompletionGenerator, 21 | FIMModel, 22 | infill_with_coeditor, 23 | ) 24 | from coeditor.experiments.in_coder import InCoderWrapper 25 | from coeditor.experiments.openai_gpt import OpenAIGptWrapper 26 | from coeditor.experiments.santa_coder import SantaCoderWrapper 27 | from coeditor.experiments.star_coder import StarCoderWrapper 28 | from coeditor.model import RetrievalEditorModel 29 | 30 | os.chdir(proj_root()) 31 | 32 | dataset_name = "perm2k" 33 | device = "cuda" 34 | N_test = 5000 35 | use_additions = True 36 | use_modifications = True 37 | 38 | # first, load the test data in FIM format 39 | fim_probs = make_or_load_dataset( 40 | dataset_name, 41 | C3CompletionGenerator( 42 | max_ctx_tks=1024 * 8, 43 | use_additions=use_additions, 44 | use_modifications=use_modifications, 45 | ), 46 | splits=("test",), 47 | time_limit_per_commit=20, 48 | remake_problems=False, 49 | )["test"] 50 | print(f"{len(fim_probs) = }") 51 | 52 | # and in C3 format 53 | c3_probs = make_or_load_dataset( 54 | dataset_name, 55 | C3ProblemGenerator(), 56 | splits=("test",), 57 | time_limit_per_commit=40, 58 | )["test"] 59 | transform = C3ToCodeCompletion( 60 | use_additions=use_additions, use_modifications=use_modifications 61 | ) 62 | c3_probs = join_list(transform.transform(p) for p in c3_probs) 63 | print(f"{len(c3_probs) = }") 64 | 65 | common_ids = set(p.uid() for p in fim_probs) & set(p.uid() for p in c3_probs) 66 | print(f"{len(common_ids) = }") 67 | fim_probs = [p for p in fim_probs if p.uid() in common_ids] 68 | fim_probs.sort(key=lambda p: p.uid()) 69 | c3_probs = [p for p in c3_probs if p.uid() in common_ids] 70 | c3_probs.sort(key=lambda p: p.uid()) 71 | 72 | # 
def get_accs(results: dict[CompletionKind, list[bool]]) -> dict[str, float]:
    """Get the accuracy of the model on additions, modifications, and all problems.

    Args:
        results: per-problem correctness flags keyed by completion kind; the
            `"add"` and `"mod"` entries are both read.

    Returns:
        The fraction of correct predictions under the keys `"add"`, `"mod"`,
        and `"all"` (both groups pooled). A group without any sample is
        reported as `nan`.
    """

    def accuracy(flags: list[bool]) -> float:
        # A group can be empty (e.g. the evaluation loop stopped before
        # recording a sample of that kind). Return nan directly so that numpy
        # does not emit a "Mean of empty slice" RuntimeWarning.
        return float(mean(flags)) if flags else float("nan")

    added, modified = results["add"], results["mod"]
    return {
        "add": accuracy(added),
        "mod": accuracy(modified),
        "all": accuracy(added + modified),
    }


def eval_coeditor():
    """Evaluate the Coeditor model on `c3_probs` and record its accuracies.

    The model is loaded from `get_coeditor_model_path()`, converted with
    `.half()`, and moved to the module-level `device`. Every problem is
    tokenized, infilled, and the predicted code is compared with the labelled
    code via `code_equal`. The accuracies are stored in
    `accuracies["Coeditor-base"]` and printed; for problems whose index is in
    `sample_ids`, the rendered prediction is written under `sample_dir`.
    The model is moved back to the CPU before returning.
    """
    coeditor = RetrievalEditorModel.load(get_coeditor_model_path())
    coeditor.half()
    # Use the shared `device` setting (as `eval_fim_models` does) instead of a
    # hard-coded device name, so the script-level setting applies to all models.
    coeditor.to(device)
    tknizer = C3ProblemTokenizer.for_eval()
    coeditor_results: dict[CompletionKind, list[bool]] = {"add": [], "mod": []}
    for i, prob in tqdm(
        list(enumerate(c3_probs)), smoothing=0, desc="Evaluating Coeditor"
    ):
        tk_prob = tknizer.tokenize_problem(prob)
        output = infill_with_coeditor(coeditor, tk_prob)
        # Inline both the predicted and the reference output tokens into the
        # main tokens and compare the resulting "after" code.
        pred_code = tokens_to_change(
            inline_output_tokens(tk_prob.main_tks, output)
        ).after
        label_code = tokens_to_change(
            inline_output_tokens(tk_prob.main_tks, tk_prob.output_tks)
        ).after
        correct = code_equal(pred_code, label_code)
        if "add" in prob.transformations:
            kind = "add"
        else:
            assert "mod" in prob.transformations
            kind = "mod"
        coeditor_results[kind].append(correct)

        if i in sample_ids:
            ex_dir = sample_dir / f"ex{i}"
            ex_dir.mkdir(parents=True, exist_ok=True)
            (ex_dir / "Coeditor-base.txt").write_text(
                tk_prob.show(output), encoding="utf-8"
            )

    accuracies["Coeditor-base"] = get_accs(coeditor_results)
    print("Coeditor-base accuracy:")
    pretty_print_dict(accuracies["Coeditor-base"])
    coeditor.to("cpu")


def eval_fim_models(model_list: dict[str, Callable[[], FIMModel | OpenAIGptWrapper]]):
    """Evaluate each fill-in-the-middle model on `fim_probs`.

    Args:
        model_list: maps a display name to a zero-argument factory that
            builds the model to evaluate.

    For every entry the model is built, asked to infill each problem, and the
    first predicted line is compared with the labelled middle line via
    `code_equal`. The accuracies are stored in `accuracies[name]` and printed;
    for problems whose index is in `sample_ids`, a text report is written
    under `sample_dir`. If `infill` raises, the error is printed and the
    evaluation of that model stops with the results gathered so far.
    """
    for name, model_f in model_list.items():
        with run_long_task(f"Evaluating {name}"):
            model = model_f()
            if isinstance(model, FIMModel):
                model.model.to(device)
            all_probs = list(enumerate(fim_probs))
            if "gpt" in name:
                # Models whose name contains "gpt" only see the first 1000
                # problems (presumably to bound API usage -- TODO confirm).
                all_probs = all_probs[:1000]

            results: dict[CompletionKind, list[bool]] = {"add": [], "mod": []}
            for i, prob in tqdm(all_probs, smoothing=0, desc=f"Evaluating {name}"):
                left_ctx = "\n".join(prob.left_ctx) + "\n"
                right_ctx = "\n" + "\n".join(prob.right_ctx)
                with torch.no_grad():
                    try:
                        pred = model.infill(left_ctx, right_ctx, max_output=128)
                    except Exception as e:
                        import traceback

                        traceback.print_exc()
                        print(f"Errored on problem {i}: {e}")
                        print("Exiting the evaluation eariler with partial results.")
                        break
                if pred:
                    pred = pred.split("\n")[0]  # only keep the first predicted line
                # Compare the prediction and the label together with the
                # adjacent context line on each side (when present).
                left_part = prob.left_ctx[-1] + "\n" if prob.left_ctx else ""
                right_part = "\n" + prob.right_ctx[0] if prob.right_ctx else ""
                pred_code = left_part + pred + right_part
                label_code = left_part + prob.middle + right_part
                correct = code_equal(pred_code, label_code)
                results[prob.kind].append(correct)
                if i in sample_ids:
                    ex_dir = sample_dir / f"ex{i}"
                    ex_dir.mkdir(parents=True, exist_ok=True)
                    pred_str = (
                        f"prediction:\n{pred}\n{SEP}\nlabel:\n{prob.middle}\n"
                        f"{SEP}\nleft context:\n{left_ctx}\n{SEP}\n"
                        f"right context:\n{right_ctx}"
                    )
                    (ex_dir / f"{name}.txt").write_text(pred_str, encoding="utf-8")

            accuracies[name] = acc = get_accs(results)
            print(f"{name} accuracy:")
            pretty_print_dict(acc)
            if isinstance(model, FIMModel):
                model.model.to("cpu")
/bin/bash 2 | # install poetry 3 | curl -sSL https://install.python-poetry.org | python3 - 4 | poetry completions bash >> ~/.bash_completion 5 | poetry install 6 | -------------------------------------------------------------------------------- /scripts/multi_round_eval.py: -------------------------------------------------------------------------------- 1 | """Evaluate Coeditor's performance in a multi-round editing setting.""" 2 | 3 | import os 4 | 5 | from coeditor.c3problem import C3ProblemGenerator, C3ProblemTokenizer 6 | from coeditor.common import * 7 | from coeditor.dataset import make_or_load_dataset 8 | from coeditor.model import ( 9 | DecodingArgs, 10 | MultiRoundEvaluator, 11 | MultiRoundStrategy, 12 | RetrievalEditorModel, 13 | ) 14 | 15 | os.chdir(proj_root()) 16 | 17 | dataset_name = "perm2k" 18 | N_test = 5000 # number of test examples to evaluate 19 | # NOTE: You can change the `model_name`` below to a `Path` to load a local model. 20 | model_name = get_coeditor_model_path() 21 | model_device = "cuda" 22 | 23 | # %% 24 | testset = make_or_load_dataset( 25 | dataset_name, 26 | C3ProblemGenerator(), 27 | splits=("test",), 28 | time_limit_per_commit=40, 29 | )["test"] 30 | 31 | print(f"{len(testset)}") 32 | subset = random_subset(testset, N_test, rng=42) 33 | print(f"{len(subset)=}") 34 | 35 | 36 | # %% 37 | tokenizer = C3ProblemTokenizer.for_eval() 38 | dec_args = DecodingArgs(do_sample=False, num_beams=1) 39 | model = RetrievalEditorModel.load(model_name) 40 | model.to(model_device) 41 | 42 | strategies: list[MultiRoundStrategy] = ["pick_first", "most_uncertain"] 43 | for strategy in strategies: 44 | evaluator = MultiRoundEvaluator(model, tokenizer, dec_args, strategy=strategy) 45 | metric_stats = [ 46 | evaluator.multi_round_edit_gain(ex, print_steps=False) 47 | for ex in tqdm(subset, smoothing=0.0) 48 | ] 49 | 50 | print("=" * 100) 51 | print("Prompting strategy:", strategy) 52 | target_file = ( 53 | proj_root() / 
f"output/multi_round_eval/{model_name}/{strategy}-{N_test}.pkl" 54 | ) 55 | pickle_dump(target_file, metric_stats) 56 | for cm in evaluator.cost_models: 57 | cm_name = cm.name 58 | print(SEP) 59 | print("Cost model:", cm_name) 60 | stats = [s[cm_name] for s in metric_stats] 61 | 62 | keys = ["label_edit_gain", "first_edit_gain", "total_edit_gain", "rounds"] 63 | mean_stats = {k: scalar_stats([getattr(s, k) for s in stats]) for k in keys} 64 | pretty_print_dict(mean_stats) 65 | 66 | print(f"For all edits (n={len(stats)}):") 67 | label_sum = sum(s.label_edit_gain for s in stats) 68 | single_sum = sum(s.first_edit_gain for s in stats) 69 | multi_sum = sum(s.total_edit_gain for s in stats) 70 | print(f"Single-round Gain ratio: {single_sum / label_sum:.2%}") 71 | print(f"Multi-round Gain ratio: {multi_sum / label_sum:.2%}") 72 | -------------------------------------------------------------------------------- /scripts/prepare_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script preprocesses the repos into the PyCommits format introduced in the paper. 3 | You can safely skip this step since it will automatically be run when you 4 | train a new model (and with the corresponding encoder parameters). 5 | 6 | The raw repos will be loaded from `get_dataset_dir(dataset_name) / "repos"`, and the 7 | processed results will be saved to `get_dataset_dir(dataset_name) / "processed"` 8 | and `get_dataset_dir(dataset_name) / "transformed"`. 
9 | """ 10 | 11 | from coeditor._utils import run_long_task 12 | from coeditor.c3problem import C3ProblemChangeInlining, C3ProblemGenerator 13 | from coeditor.common import * 14 | from coeditor.dataset import * 15 | 16 | if __name__ == "__main__": 17 | os.chdir(proj_root()) 18 | 19 | dataset_name = "perm2k" 20 | encoder = C3CombinedEncoder( 21 | problem_tranform=C3ProblemChangeInlining( 22 | max_inline_ratio=0.6, allow_empty_problems=True 23 | ), 24 | ) 25 | with run_long_task( 26 | f"Preparing dataset {dataset_name} with encoder {encoder.change_processor}" 27 | ): 28 | problems = make_or_load_dataset( 29 | dataset_name, 30 | encoder.change_processor, 31 | ("valid", "test", "train"), 32 | remake_problems=False, 33 | ) 34 | 35 | transformed = make_or_load_transformed_dataset( 36 | dataset_name, 37 | problems, 38 | encoder, 39 | ) 40 | 41 | tokenizer = C3ProblemTokenizer() 42 | for name, probs in transformed.items(): 43 | probs = cast(Sequence[C3Problem], probs) 44 | print("=" * 40, name, "=" * 40) 45 | stats = tokenizer.compute_stats(probs) 46 | pretty_print_dict(stats) 47 | -------------------------------------------------------------------------------- /scripts/single_round_eval.py: -------------------------------------------------------------------------------- 1 | """Evaluate Coeditor's exact match performance in a single-round editing setting. 2 | 3 | This script generates the results for the ablation studies. 
4 | """ 5 | 6 | import os 7 | 8 | import numpy as np 9 | 10 | from coeditor.c3problem import C3ProblemGenerator, C3ProblemTokenizer 11 | from coeditor.common import * 12 | from coeditor.dataset import make_or_load_dataset 13 | from coeditor.model import BatchArgs, C3DataLoader, DecodingArgs, RetrievalEditorModel 14 | 15 | os.chdir(proj_root()) 16 | 17 | dataset_name = "perm2k" 18 | model_device = "cuda" 19 | 20 | model_names = { 21 | "No Ablation": "coeditor-perm2k-c3-multi-v1.7.3", 22 | "No Diffs": "coeditor-perm2k-c3-multi-no_change-v1.7.3", 23 | "No Defs": "coeditor-perm2k-c3-multi-no_defs-v1.7.2", 24 | "Small Context": "coeditor-perm2k-c3-multi-2048-v1.7.2", 25 | } 26 | 27 | 28 | # we load the older dataset format since the models above were trained on it. 29 | testset = pickle_load( 30 | get_dataset_dir("perm2k") / "processed" / "valid-C3ProblemGenerator(VERSION=2.9)" 31 | ) 32 | 33 | # # uncomment below to load with the newest dataset format 34 | # testset = make_or_load_dataset( 35 | # dataset_name, 36 | # C3ProblemGenerator(), 37 | # splits=("valid",), 38 | # time_limit_per_commit=40, 39 | # )["valid"] 40 | 41 | # testset = random_subset(testset, 50, rng=42) 42 | print(f"{len(testset)=}") 43 | 44 | accs = dict[str, dict]() 45 | results = dict[str, list[bool]]() 46 | for name, full_name in model_names.items(): 47 | if "checkpoint" in full_name: 48 | model_path = get_model_dir(False) / full_name 49 | else: 50 | model_path = get_model_dir() / full_name 51 | model = RetrievalEditorModel.load(model_path) 52 | model.to(model_device) 53 | 54 | out_dir = get_model_dir() / full_name / "exact_match_samples" 55 | eval_tkn = C3ProblemTokenizer.for_eval() 56 | if name == "Small Context": 57 | eval_tkn.max_ref_tks_sum = 2048 58 | eval_batch_args = BatchArgs.eval_default() 59 | 60 | with timed_action(f"Evaluating {name}"): 61 | test_loader = C3DataLoader( 62 | testset, None, eval_tkn, eval_batch_args, shuffle=False, desc="evaluating" 63 | ) 64 | correctness = 
model.eval_on_data( 65 | testset, 66 | test_loader, 67 | DecodingArgs(), 68 | out_dir, 69 | probs_to_save=300, 70 | ) 71 | results[name] = correctness 72 | exact_acc = float(np.mean(correctness)) 73 | lb, ub = bootstrap_sample(list(map(float, correctness))) 74 | print("Exact-match accuracy:", exact_acc) 75 | print(f"95% CI: [{lb:.4f}, {ub:.4f}]") 76 | cprint("blue", "Exact-match samples saved to:", out_dir) 77 | accs[name] = {"mean": exact_acc, "lb": lb, "ub": ub} 78 | 79 | model.to("cpu") 80 | del model 81 | 82 | pretty_print_dict(accs) 83 | pickle_dump(Path("output/single_round_eval-accs.pkl"), accs) 84 | pickle_dump(Path("output/single_round_eval-results.pkl"), results) 85 | 86 | baseline_perf = results["No Ablation"] 87 | for name in ["No Diffs", "No Defs", "Small Context"]: 88 | this_perf = results[name] 89 | pvalue = bootstrap_compare(this_perf, baseline_perf) 90 | print(f"(vs. No Ablation) {name} p-value: {pvalue:.4f}") 91 | -------------------------------------------------------------------------------- /scripts/start_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run this script to start the Coeditor VSCode extension server. 3 | This will download the pre-trained Coeditor model from Huggingface (if not already) 4 | and start the extension service at port 5042. 
def start_server(device, port: int, print_stats: bool = True):
    """Load the Coeditor model and serve edit suggestions over JSON-RPC.

    Three RPC methods are registered with ``jsonrpcserver``:

    - ``initialize(project)``: create (once) the prediction service for a
      project directory.
    - ``submit_problem(time, project, file, lines, writeLogs)``: run the first
      step of the two-step suggestion and remember the deferred second step
      for the project.
    - ``get_result(time, project)``: run the deferred second step and return
      its JSON form.

    ``time`` is used as a request id: a request is answered with ``"Skipped"``
    when a task with a larger id has already been recorded for the project.

    Args:
        device: Device the model is moved to (passed to ``model.to``).
        port: Port on ``localhost`` to serve on.
        print_stats: Whether ``get_result`` prints the service's runtime stats.
    """
    model_path = get_coeditor_model_path()
    model = RetrievalEditorModel.load(model_path)
    model.to(device)
    print(f"Model '{model_path}' loaded on device:", device)
    dec_args = DecodingArgs(do_sample=False, num_beams=4)

    # One service and at most one pending task per resolved project directory.
    services = dict[Path, EditPredictionService]()
    tasks = dict[Path, LazyVal[ServiceResponse]]()

    def handle_error(f):
        "Turn any exception raised by an RPC handler into a JSON-RPC error."

        @wraps(f)
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except Exception as e:
                traceback.print_exception(e)
                return Error(code=1, message=repr(e))

        return wrapper

    def get_service(target_dir: Path) -> EditPredictionService:
        """Return the service for `target_dir`, creating it on first use.

        This helper is deliberately not wrapped by `handle_error`, so a
        failure while creating the service propagates to the calling handler
        and is reported with its real cause.
        """
        service = services.get(target_dir)
        if service is None:
            with timed_action(f"Create service for project: {target_dir}"):
                detector = ChangeDetector(target_dir)
                service = EditPredictionService(
                    detector,
                    model,
                    dec_args=dec_args,
                )
                services[target_dir] = service
        return service

    @method
    @handle_error
    def initialize(project: str):
        get_service(Path(project).resolve())
        return Success("OK")

    @method
    @handle_error
    def submit_problem(
        time: int, project: str, file: str, lines: Sequence[int] | int, writeLogs: bool
    ):
        target_dir = Path(project).resolve()
        service = get_service(target_dir)

        print(f"Suggesting edit for lines {lines} in {file}")
        path = Path(file)
        if path.is_absolute():
            path = path.relative_to(target_dir)
        path = to_rel_path(path)

        service.tlogger.clear()
        log_dir = service.project / ".coeditor_logs" if writeLogs else None
        region, f = service._suggest_edit_two_steps(path, lines, log_dir)
        # A request with a larger id has already been recorded: drop this one.
        if target_dir in tasks and tasks[target_dir].id > time:
            return Success("Skipped")
        tasks[target_dir] = LazyVal(f, time)
        return Success(region.target_lines)

    @method
    @handle_error
    def get_result(time: int, project: str):
        target_dir = Path(project).resolve()
        cont = tasks[target_dir]
        if cont.id > time:
            return Success("Skipped")
        # The second step runs here, on first access of the lazy value.
        response = cont.get()
        service = services[target_dir]
        if print_stats:
            print("Runtime stats:")
            display(service.tlogger.as_dataframe())

        return Success(response.to_json())

    print(f"Starting suggestion server at localhost:{port}")
    serve("localhost", port)
def train_model(
    model_name: str,
    dataset_name: str,
    description: str,
    encoder: C3CombinedEncoder = C3CombinedEncoder(),
    batch_args=BatchArgs.train_default(),
    eval_batch_args=BatchArgs.eval_default(),
    train_args=TrainingArgs(),
    fixed_ref_tks_sum: int | None = None,
    recreate_data: bool = False,
    multi_stage_training: bool = True,
    resumed_from: Path | None = None,
    model_size: Literal["small", "base", "large"] = "base",
    eval_only: bool = False,
    quicktest: bool = False,
):
    """Train (unless `eval_only`) and then evaluate a model on the test split.

    NOTE: the default values of `encoder`, `batch_args`, `eval_batch_args` and
    `train_args` are created once, when this function is defined, and are
    shared by every call that does not override them.

    Args:
        model_name: Name of the wandb run and of the model/output directories;
            prefixed with ``"quicktest-"`` when `quicktest` is set.
        dataset_name: Dataset passed to `make_or_load_dataset`.
        description: Free-form text stored in the wandb config.
        encoder: Supplies the change processor, the problem transform and the
            edit tokenizer.
        batch_args: Batch settings for the training loaders.
        eval_batch_args: Batch settings for the validation and test loaders.
        train_args: Passed to `model.train_on_data`.
        fixed_ref_tks_sum: If given, overrides `max_ref_tks_sum` of the eval
            tokenizer and of every multi-stage training tokenizer.
        recreate_data: Rebuild the cached problems instead of loading them.
        multi_stage_training: Train in two stages whose reference-token budget
            is the tokenizer's budget divided by 4 and then by 2.
        resumed_from: Load the model from this path instead of starting from
            CodeT5 of size `model_size`.
        model_size: CodeT5 size used when `resumed_from` is None.
        eval_only: Skip the save-directory check and all training.
        quicktest: Keep only the first 20 problems of each split.

    Returns:
        The trained (or loaded) model.
    """
    dec_args = DecodingArgs()
    if quicktest:
        model_name = "quicktest-" + model_name

    if not eval_only:
        check_save_dir(model_name)

    # problems will be transformed and saved for valid and test but not train.
    datasets = make_or_load_dataset(
        dataset_name,
        encoder.change_processor,
        remake_problems=recreate_data,
        splits=("valid", "test", "train"),
    )

    with timed_action("Making or loading transformed C3 problems for eval"):
        # it's important to cache these due to randomness in the transformations
        eval_probs = make_or_load_transformed_dataset(
            dataset_name,
            datasets,
            encoder,
            remake_problems=recreate_data,
        )

    # limit the number of examples for faster testing
    datasets["valid"] = random_subset(eval_probs["valid"], 10000, rng=42)
    datasets["test"] = random_subset(eval_probs["test"], 10000, rng=42)

    config_dict: dict[str, Any] = {
        "description": description,
        "edit_tokenizer": encoder.edit_tokenizer.get_args(),
        "batch_args": batch_args,
        "train_args": train_args,
        "dec_args": dec_args,
    }

    project = "Coeditor" if not quicktest else "Coeditor-quicktest"
    if eval_only:
        project = "eval-" + project
    wandb.init(dir="..", project=project, name=model_name, config=config_dict)

    if quicktest:
        print("Using fewer data for quick test.")
        n_quick_exs = 20
        datasets = C3ProblemDataset(
            train=datasets["train"][:n_quick_exs],
            valid=datasets["valid"][:n_quick_exs],
            test=datasets["test"][:n_quick_exs],
        )

    if resumed_from is None:
        model = RetrievalEditorModel.from_code_t5(model_size)
    else:
        model = RetrievalEditorModel.load(resumed_from)

    if os.getenv("CUDA_VISIBLE_DEVICES") is None:
        warnings.warn(
            "CUDA_VISIBLE_DEVICES not set, using 0. Note that "
            "the Huggingface Trainer will use all visible GPUs for training."
        )
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    train_tkn = encoder.edit_tokenizer
    # Evaluation works on a private copy of the tokenizer with larger limits.
    eval_tkn = copy.deepcopy(train_tkn)
    eval_tkn.max_query_tks = 1024
    eval_tkn.max_output_tks *= 2
    eval_tkn.max_ref_tks_sum *= 2
    if fixed_ref_tks_sum is not None:
        eval_tkn.max_ref_tks_sum = fixed_ref_tks_sum

    valid_loader = C3DataLoader(
        datasets["valid"], None, eval_tkn, eval_batch_args, shuffle=False, desc="eval"
    )

    def train_stage(stage_tkn) -> None:
        "Train on the problems whose relevant changes fit `stage_tkn`'s budget."
        max_ref_tks = stage_tkn.max_ref_tks_sum
        stage_probs = [
            x
            for x in datasets["train"]
            if sum(c.change_size() for c in x.relevant_changes) < max_ref_tks
        ]
        desc = f"training (ctx={max_ref_tks})"
        stage_loader = C3DataLoader(
            stage_probs,
            encoder.problem_tranform,
            stage_tkn,
            batch_args,
            shuffle=True,
            desc=desc,
        )
        with timed_action(desc):
            model.train_on_data(model_name, stage_loader, valid_loader, train_args)

    if not eval_only and multi_stage_training:
        # gradually increase the ctx size during training
        for scale in (4, 2):
            s_tkn = copy.copy(train_tkn)
            s_tkn.max_ref_tks_sum //= scale
            if fixed_ref_tks_sum is not None:
                s_tkn.max_ref_tks_sum = fixed_ref_tks_sum
            train_stage(s_tkn)
    elif not eval_only:
        # Filter by the configured tokenizer's budget (not the class-level
        # default) so the filter agrees with the tokenizer actually used.
        train_stage(train_tkn)

    model.to("cuda")
    test_loader = C3DataLoader(
        datasets["test"], None, eval_tkn, eval_batch_args, shuffle=False, desc="test"
    )
    print(f"{len(test_loader)}")
    print(f"{len(test_loader.all_probs)}")
    with timed_action("Loss Evaluation"):
        eval_result = model.eval_loss_on_loader(test_loader)
        eval_dict = {f"test/{k}": v.average() for k, v in eval_result.items()}
        wandb.log(eval_dict)

    with timed_action("Accuracy Evaluation"):
        out_dir = get_model_dir() / model_name / "exact_match_samples"
        correctness = model.eval_on_data(
            datasets["test"],
            test_loader,
            dec_args,
            out_dir,
            probs_to_save=300,
        )
        exact_acc = float(np.mean(correctness))
        print("Exact-match accuracy:", exact_acc)
        wandb.log({"test/exact-acc": exact_acc})
        cprint("blue", "Exact-match samples saved to:", out_dir)

    return model
def eval_code_completion():
    "Evaluate (without training) a previously saved model on code-completion problems."
    completion_encoder = C3CombinedEncoder(
        problem_tranform=C3ToCodeCompletion(),
    )
    checkpoint = get_model_dir(True) / "coeditor-xl-c3-dropout-v1.6-resumed"
    train_model(
        model_name="coeditor-xl-c3-completion-v1.6-resumed",
        dataset_name="tiny",
        description="",
        encoder=completion_encoder,
        resumed_from=checkpoint,
        eval_only=True,
    )
if __name__ == "__main__":
    # Switch the working directory to the project root before training.
    os.chdir(proj_root())

    # Default run: train a model under the default settings.
    with run_long_task("train default model"):
        train_new_model()

    # The ablation runs below are disabled on purpose; uncomment the one you
    # want to train.

    # with run_long_task("train ablation: short context"):
    #     ablation_short_context()

    # with run_long_task("train ablation: no signatures"):
    #     ablation_no_signatures()

    # with run_long_task("train ablation: no changes"):
    #     ablation_no_changes()
@dataclass(frozen=True)
class Modified(_ChangeBase[E1]):
    "A change that replaces the value `before` with the value `after`."

    before: E1
    after: E1
    # Used for optimization. When True, the element is known to be unchanged.
    # When False, `before` may still be equal to `after`.
    unchanged: bool = False

    def map(self, f: Callable[[E1], T2]) -> "Modified[T2]":
        "Apply `f` to both sides; `f` is called only once for an unchanged value."
        if self.unchanged:
            return Modified.from_unchanged(f(self.before))
        else:
            return Modified(f(self.before), f(self.after))

    def inverse(self) -> "Modified[E1]":
        "Swap `before` and `after`, keeping the `unchanged` flag."
        # Swapping the two sides cannot turn an unchanged value into a changed
        # one, so the flag is carried over instead of being reset to False.
        return Modified(self.after, self.before, unchanged=self.unchanged)

    @property
    def earlier(self) -> E1:
        return self.before

    @property
    def later(self) -> E1:
        return self.after

    @property
    def changed(self) -> bool:
        return not self.unchanged

    @staticmethod
    def as_char():
        return "M"

    @staticmethod
    def from_unchanged(v: T1) -> "Modified[T1]":
        "Wrap a value that did not change (`before` and `after` are both `v`)."
        return Modified(v, v, unchanged=True)

    def __repr__(self):
        if self.before == self.after:
            return f"Modified(before=after={repr(self.before)})"
        else:
            return f"Modified(before={repr(self.before)}, after={repr(self.after)})"
def get_named_changes(
    old_map: Mapping[T1, T2], new_map: Mapping[T1, T2]
) -> Mapping[T1, Change[T2]]:
    """Compute the changes between two maps of named elements.

    A name only in `old_map` becomes `Deleted`, a name only in `new_map`
    becomes `Added`, and a name in both becomes `Modified` (whether or not the
    two values are equal).

    The result lists deletions first, then additions, then modifications. Within
    each group the names keep the iteration order of the input mappings, so
    the output order is deterministic.
    """
    changes: dict[T1, Change[T2]] = {}
    for name, old in old_map.items():
        if name not in new_map:
            changes[name] = Deleted(old)
    for name, new in new_map.items():
        if name not in old_map:
            changes[name] = Added(new)
    for name, old in old_map.items():
        if name in new_map:
            changes[name] = Modified(old, new_map[name])
    return changes
= c.before.path 210 | # pre_usages = { 211 | # u.used: u.callsite 212 | # for u in pre_analysis.user2used.get(path, []) 213 | # if u.callsite 214 | # } 215 | # pos_usages = { 216 | # u.used: u.callsite 217 | # for u in post_analysis.user2used.get(path, []) 218 | # if u.callsite 219 | # } 220 | # call_changes = list[tuple[ProjectPath, Modified[cst.Call]]]() 221 | # for k in changed_apis & pre_usages.keys() & pos_usages.keys(): 222 | # call_before = normalize_code_by_ast(show_expr(pre_usages[k])) 223 | # call_after = normalize_code_by_ast(show_expr(pos_usages[k])) 224 | # if call_before != call_after: 225 | # call_changes.append((k, Modified(pre_usages[k], pos_usages[k]))) 226 | # refactorings[path] = call_changes 227 | # return refactorings 228 | -------------------------------------------------------------------------------- /src/coeditor/dataset.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import tempfile 3 | import traceback 4 | 5 | from coeditor import scoped_changes 6 | from coeditor._utils import pretty_print_dict, scalar_stats 7 | 8 | from .c3problem import ( 9 | C3Problem, 10 | C3ProblemGenerator, 11 | C3ProblemSimpleSplit, 12 | C3ProblemTokenizer, 13 | C3ProblemTransform, 14 | JediUsageAnalyzer, 15 | fix_jedi_cache, 16 | ) 17 | from .change import Added 18 | from .common import * 19 | from .encoding import TEdit 20 | from .git import CommitInfo, get_commit_history 21 | from .scoped_changes import ProjectChangeProcessor, TProb, edits_from_commit_history 22 | 23 | 24 | @dataclass 25 | class TokenizedEditDataset(Generic[TEdit]): 26 | _edits: list[TEdit] 27 | 28 | def __repr__(self) -> str: 29 | n_edits = len(self.all_edits()) 30 | return f"TokenizedEditDataset(n_edits={n_edits})" 31 | 32 | def subset_edits(self, n_edits: int) -> "TokenizedEditDataset": 33 | return TokenizedEditDataset.from_edits(self.all_edits()[:n_edits]) 34 | 35 | def overall_stats(self) -> dict: 36 | all_edits = self.all_edits() 
def _process_commits(
    root: Path,
    workdir: Path,
    is_training: bool,
    max_history_per_repo: int,
    change_processor: ProjectChangeProcessor[C3Problem],
    cache: PickleCache,
    time_limit_per_commit: float = 10.0,
) -> _ProcessingResult:
    """Extract (or load from `cache`) the problems of one project.

    Any `Exception` raised while reading the commit history or processing the
    commits is reported with a warning and a short traceback, and the project
    then contributes no problems. `KeyboardInterrupt` is not an `Exception`
    subclass, so it still propagates.

    Args:
        root: Root directory of the project's repository.
        workdir: Scratch directory for this call.
        is_training: Forwarded to `change_processor.set_training`.
        max_history_per_repo: Number of commits kept from the end of the list
            returned by `get_commit_history`.
        change_processor: Processor that turns project changes into problems.
        cache: Cache holding the per-project results.
        time_limit_per_commit: The total time limit for the project is this
            value times (number of commits + 10).

    Returns:
        The extracted problems together with the processor and timing stats.
    """
    # use process-specific parso cache
    fix_jedi_cache(workdir)
    scoped_changes._tlogger.clear()
    change_processor.clear_stats()
    change_processor.set_training(is_training)
    key = f"{root.name}({max_history_per_repo}, {is_training=})"
    try:
        commits = []
        if not cache.contains(key):
            # keep the oldest commits
            commits = get_commit_history(root)[-max_history_per_repo:]
        # cannot return here since the subprocess may be killed after returning
        edits = cache.cached(
            key,
            lambda: edits_from_commit_history(
                root,
                commits,
                tempdir=workdir / "code" / root.name,
                change_processor=change_processor,
                silent=True,
                time_limit=time_limit_per_commit * (len(commits) + 10),
            ),
        )
    except Exception as e:
        warnings.warn(f"Failed to process project: {root}\nError: {e}")
        # a negative limit prints only the last 6 stack entries
        traceback.print_exception(e, limit=-6)
        edits = []
    stats = dict()
    change_processor.append_stats(stats)
    rec_add_dict_to(stats, {"tlogger": scoped_changes._tlogger.times})
    return _ProcessingResult(edits, stats)
def datasets_from_repo_splits(
    cache: PickleCache,
    repos_root: Path,
    change_processor: ProjectChangeProcessor[TProb],
    splits: Sequence[str] = ("test", "valid", "train"),
    max_history_per_repo: int = 1000,
    time_limit_per_commit: float = 10.0,
    workers: int = DefaultWorkers,
) -> dict[str, Sequence[TProb]]:
    """Build the problems of every requested split found under `repos_root`.

    Each split is expected at `repos_root / split`, with one sub-directory per
    project. A missing split directory is skipped with a warning. Only the
    projects of the ``"train"`` split are processed in training mode.

    Returns:
        A dict from split name to the concatenated problems of its projects.
        Splits whose directory does not exist are absent from the dict.
    """
    projects = dict[str, list[Path]]()
    split_is_training = dict[str, list[bool]]()
    for split in splits:
        split_dir = repos_root / split
        if not split_dir.exists():
            warnings.warn(f"Split {split} not found at {repos_root / split}.")
            continue
        # Only directories are projects; stray files in the split are ignored.
        ps = [p for p in split_dir.iterdir() if p.is_dir()]
        projects[split] = ps
        training = split == "train"
        split_is_training[split] = [training] * len(ps)
        if not ps:
            warnings.warn(f"No projects found in {split} split")

    dataset = dataset_from_projects(
        cache,
        join_list(projects.values()),
        change_processor=change_processor,
        repo_training=join_list(split_is_training.values()),
        time_limit_per_commit=time_limit_per_commit,
        max_history_per_repo=max_history_per_repo,
        workers=workers,
    )
    return {k: join_list(dataset[r] for r in repos) for k, repos in projects.items()}
def make_or_load_transformed_dataset(
    dataset_name: str,
    dataset: C3ProblemDataset | None,
    encoder: C3CombinedEncoder,
    remake_problems: bool = False,
    workers: int = DefaultWorkers,
) -> dict[str, Sequence[C3Problem]]:
    """Return the transformed "valid" and "test" problems, using a cache.

    The result is cached under the dataset's "transformed" directory, keyed by
    the change-processor and problem-transform configurations. `dataset` is
    only read when the transformed problems have to be computed, in which
    case it must not be None.
    """

    def _transform_eval_splits() -> dict[str, Sequence[C3Problem]]:
        source = not_none(dataset)
        transformed = dict[str, Sequence[C3Problem]]()
        for split_name in ("valid", "test"):
            per_problem = pmap(
                encoder.problem_tranform.transform,
                source[split_name],
                desc=f"transform({split_name})",
                chunksize=1000,
                max_workers=workers,
            )
            # each input problem maps to a list of problems; flatten them
            transformed[split_name] = join_list(per_problem)
        return transformed

    processor_cfg = repr_modified_args(encoder.change_processor)
    transform_cfg = repr_modified_args(encoder.problem_tranform)
    cache_key = f"eval-{processor_cfg}-{transform_cfg}"
    cache = PickleCache(get_dataset_dir(dataset_name) / "transformed")
    return cache.cached(cache_key, _transform_eval_splits, remake=remake_problems)
@dataclass(frozen=True)
class FIMProblem:
    "A fill-in-the-middle problem."
    # Context pieces to the left of the part to fill in.
    left_ctx: Sequence[str]
    # Context pieces to the right of the part to fill in.
    right_ctx: Sequence[str]
    middle: str
    src_info: SrcInfo
    max_ctx_tks: int
    kind: CompletionKind
    path: ProjectPath

    def uid(self) -> tuple[ProjectPath, str]:
        "Identify the problem by its path and the hash of its source commit."
        commit = not_none(self.src_info["commit"])
        return self.path, commit.hash

    def get_contexts(
        self,
        tokenizer: PreTrainedTokenizerBase,
        tks_limit: int = 2040,
        max_output: int = 256,
    ) -> tuple[str, str]:
        """Return the left and right contexts as text, truncated to fit in
        `tks_limit - max_output` tokens.
        This is mostly for visualization as the FIM model handles context
        truncation internally (in a more efficient way)."""
        left_text = "\n".join(self.left_ctx) + "\n"
        # NOTE(review): the right pieces are joined with "" while the left ones
        # are joined with "\n" -- confirm that this asymmetry is intended.
        right_text = "\n" + "".join(self.right_ctx)
        left_ids: TokenSeq = tokenizer.encode(left_text, add_special_tokens=False)
        right_ids: TokenSeq = tokenizer.encode(right_text, add_special_tokens=False)

        ctx_budget = tks_limit - max_output
        left_ids, right_ids = truncate_sections(
            ctx_budget,
            (left_ids, TruncateAt.Left),
            (right_ids, TruncateAt.Right),
            add_bos=False,
        )
        return (
            tokenizer.decode(left_ids, clean_up_tokenization_spaces=False),
            tokenizer.decode(right_ids, clean_up_tokenization_spaces=False),
        )
    def process_change(
        self,
        pchange: JProjectChange,
        pre_analysis: None,
        post_analysis: Sequence[ModuleName],
    ) -> Sequence[FIMProblem]:
        """Turn the changed spans of one project change into `FIMProblem`s.

        ## Arguments
        - `pchange`: the project change; its `changed` mapping is looked up by
        module name and its `project_name`/`commit_info` become the `src_info`
        of every generated problem.
        - `pre_analysis`: not used by this method.
        - `post_analysis`: the module names to visit; modules missing from
        `pchange.changed` are skipped.

        ## Returns
        One problem per span that passes `should_mk_problem`, is not a
        code-equal (trivial) change, and for which the sampler extracts a
        completion.
        """
        probs = list[FIMProblem]()
        src_info: SrcInfo = {
            "project": pchange.project_name,
            "commit": pchange.commit_info,
        }
        for m in post_analysis:
            if (mchange := pchange.changed.get(m)) is None:
                continue
            all_spans = list(mchange.changed)
            all_spans.sort(key=lambda s: s.line_range)
            # `_get_old_new_spans` emits two strings per span (header, then code),
            # so span `i` lives at indices `2 * i` and `2 * i + 1` of both lists.
            old_spans, new_spans = self._get_old_new_spans(all_spans)
            for i, span in enumerate(all_spans):
                if not self.should_mk_problem(
                    span,
                    func_only=not self.is_training,
                    max_chars=self.generator.max_span_chars,
                    max_lines=self.generator.max_span_lines,
                ) or code_equal(span.change.earlier, span.change.later):
                    # only keep non-trivial modifications
                    continue
                origin, delta = change_tks_to_original_delta(
                    change_to_tokens(span.change)
                )
                sampled = self._sampler.extract_completion(origin, delta)
                if sampled is None:
                    continue
                new_origin, new_delta, kind = sampled
                left, middle, right = self._split_change(new_origin, new_delta)
                above_spans = [left] if left else []
                # Walk backwards over the new versions of everything up to and
                # including this span's header, stopping before the accumulated
                # string length reaches 6 * max_ctx_tks characters.
                above_sum = len(left)
                for s in reversed(new_spans[: 2 * i + 1]):
                    if above_sum + len(s) >= self.max_ctx_tks * 6:
                        break
                    above_sum += len(s)
                    above_spans.append(s)
                # collected nearest-first, so flip back into source order
                above_spans.reverse()
                below_spans = [right] if right else []
                below_sum = len(right)
                # the old versions of everything after this span's code
                for s in old_spans[2 * i + 2 :]:
                    # stop before the accumulated string length reaches
                    # 6 * max_ctx_tks characters
                    if below_sum + len(s) >= self.max_ctx_tks * 6:
                        break
                    below_sum += len(s)
                    below_spans.append(s)
                probs.append(
                    FIMProblem(
                        above_spans,
                        below_spans,
                        middle,
                        src_info,
                        self.max_ctx_tks,
                        kind,
                        path=span.scope.later.path,
                    )
                )
        return probs
def _change_line_to_result(line: TokenSeq) -> TokenSeq | None:
    """Map one line of a tokenized change to the line it leaves behind.

    - an empty line stays an empty line;
    - a line starting with `Add_id` yields its content without the marker;
    - a line starting with `Del_id` yields `None`;
    - any other line is returned as is.
    """
    if not line:
        return []
    marker = line[0]
    if marker == Add_id:
        return line[1:]
    return None if marker == Del_id else line
def infill_with_coeditor(
    coeditor: RetrievalEditorModel, tk_prob: TkC3Problem, max_length: int = 128
) -> TokenSeq:
    """Run the Coeditor model on one (infilling version) C3 problem.

    The problem is packed into a batch of one and decoded without sampling,
    using a prefix constraint built from `_infill_prefix`. The token ids of
    the first output sequence are returned as a plain list.
    """
    device = coeditor.device
    packed = C3DataLoader.pack_batch([tk_prob])
    # the prefix is always an addition
    constraint = RetrievalEditorModel._prefix_constraint([_infill_prefix])
    generated = coeditor.generate(
        input_ids=torch.LongTensor(packed["input_ids"]).to(device),
        references=packed["references"],
        query_ref_list=packed["query_ref_list"],
        do_sample=False,
        max_length=max_length,
        prefix_allowed_tokens_fn=constraint,
    )
    assert isinstance(generated, torch.Tensor)
    return generated[0].tolist()
def make_sentinel(i) -> str:
    """Return the mask sentinel numbered `i`.

    A sentinel signals (1) a location to insert an infill and (2) the start of
    the infill generation.
    """
    sentinel = f"<|mask:{i}|>"
    return sentinel
87 | def infill_multi( 88 | self, 89 | parts: Sequence[str], 90 | max_to_generate: int = 128, 91 | temperature: float = 0.2, 92 | extra_sentinel: bool = True, 93 | max_retries: int = 1, 94 | VERBOSE: bool = False, 95 | ): 96 | retries_attempted = 0 97 | done = False 98 | prompt = text = "" 99 | infills = [] 100 | 101 | while (not done) and (retries_attempted < max_retries): 102 | retries_attempted += 1 103 | 104 | if VERBOSE: 105 | print(f"retry {retries_attempted}") 106 | 107 | ## (1) build the prompt 108 | if len(parts) == 1: 109 | prompt = parts[0] 110 | else: 111 | prompt = "" 112 | # encode parts separated by sentinel 113 | for sentinel_ix, part in enumerate(parts): 114 | prompt += part 115 | if extra_sentinel or (sentinel_ix < len(parts) - 1): 116 | prompt += make_sentinel(sentinel_ix) 117 | 118 | infills = list[str]() 119 | 120 | done = True 121 | 122 | ## (2) generate infills 123 | for sentinel_ix, part in enumerate(parts[:-1]): 124 | prompt += make_sentinel(sentinel_ix) 125 | # TODO: this is inefficient as it requires re-encoding prefixes repeatedly 126 | completion = self.generate(prompt, max_to_generate, temperature) 127 | completion = completion[len(prompt) :] 128 | if EOM not in completion: 129 | if VERBOSE: 130 | print(f"warning: {EOM} not found") 131 | completion += EOM 132 | done = False 133 | completion = completion[: completion.index(EOM) + len(EOM)] 134 | infilled = completion[: -len(EOM)] 135 | infills.append(infilled) 136 | prompt += completion 137 | 138 | return infills 139 | 140 | def generate( 141 | self, input: str, max_to_generate: int = 128, temperature: float = 0.2 142 | ): 143 | """ 144 | Do standard left-to-right completion of the prefix `input` by sampling from the model 145 | """ 146 | tkn = self.tokenizer 147 | device = self.model.device 148 | input_ids: Tensor = tkn(input, return_tensors="pt").input_ids.to(device) 149 | max_length = max_to_generate + input_ids.flatten().size(0) 150 | if max_length > 2048: 151 | print( 152 | 
"warning: max_length {} is greater than the context window {}".format( 153 | max_length, 2048 154 | ) 155 | ) 156 | output = self.model.generate( 157 | input_ids=input_ids, 158 | do_sample=True, 159 | top_p=0.95, 160 | temperature=temperature, 161 | max_length=max_length, 162 | ) 163 | assert isinstance(output, Tensor) 164 | # pass clean_up_tokenization_spaces=False to avoid removing spaces before punctuation, e.g. "from ." -> "from." 165 | detok_hypo_str = tkn.decode( 166 | output.flatten(), clean_up_tokenization_spaces=False 167 | ) 168 | if detok_hypo_str.startswith(BOS): 169 | detok_hypo_str = detok_hypo_str[len(BOS) :] 170 | return detok_hypo_str 171 | 172 | @staticmethod 173 | def from_pretrained( 174 | model_name: str = "facebook/incoder-1B", half_precision: bool = True 175 | ): 176 | model = AutoModelForCausalLM.from_pretrained(model_name) 177 | tokenizer = AutoTokenizer.from_pretrained(model_name) 178 | assert isinstance(model, InCoderModelType) 179 | assert isinstance(tokenizer, InCoderTokenizerType) 180 | if half_precision: 181 | model = model.half() 182 | return InCoderWrapper(model, tokenizer) 183 | -------------------------------------------------------------------------------- /src/coeditor/experiments/openai_gpt.py: -------------------------------------------------------------------------------- 1 | from time import sleep 2 | 3 | import openai 4 | import tenacity 5 | import tiktoken 6 | 7 | from coeditor.common import * 8 | from coeditor.encoding import TruncateAt, truncate_sections 9 | 10 | 11 | @dataclass 12 | class OpenAIGptWrapper: 13 | model: str = "gpt-3.5-turbo" 14 | tks_limit: int = 4096 15 | use_fim: bool = True 16 | use_nl_prompt: bool = False 17 | print_prompt: bool = False 18 | 19 | def __post_init__(self): 20 | api_key = openai.api_key = os.getenv("OPENAI_API_KEY") 21 | if api_key is None: 22 | raise RuntimeError("OPENAI_API_KEY env variable not set.") 23 | self.tokenizer = tiktoken.encoding_for_model(self.model) 24 | 25 | def 
infill(self, left: str, right: str, max_output: int) -> str: 26 | if self.use_fim: 27 | return self.infill_fim(left, right, max_output) 28 | else: 29 | return self.infill_lm(left, max_output) 30 | 31 | def infill_lm(self, left: str, max_output: int) -> str: 32 | """Infill code using left-to-right language modeling.""" 33 | left_tks = self.tokenizer.encode(left, disallowed_special=()) 34 | left_tks = truncate_sections( 35 | self.tks_limit - max_output - 400, 36 | (left_tks, TruncateAt.Left), 37 | add_bos=False, 38 | )[0] 39 | left_str = self.tokenizer.decode(left_tks) 40 | prompt = left_str 41 | self._print_prompt(prompt) 42 | return self._get_result(prompt, role="assistant", max_output=max_output) 43 | 44 | def infill_fim(self, left: str, right: str, max_output: int) -> str: 45 | """Infill code using FIM prompting.""" 46 | 47 | left_tks = self.tokenizer.encode(left, disallowed_special=()) 48 | right_tks = self.tokenizer.encode(right, disallowed_special=()) 49 | left_tks, right_tks = truncate_sections( 50 | self.tks_limit - max_output - 400, 51 | (left_tks, TruncateAt.Left), 52 | (right_tks, TruncateAt.Right), 53 | add_bos=False, 54 | ) 55 | left_str = self.tokenizer.decode(left_tks) 56 | right_str = self.tokenizer.decode(right_tks) 57 | 58 | if self.use_nl_prompt: 59 | prompt = f"""\ 60 | You are a programming expert tasked to fill in a missing line for a given Python code 61 | snippet. The snippet may have been truncated from both ends, and the missing line is 62 | indicated by a special token ``. 63 | You should output the missing line (along with any leading whitespaces) and 64 | nothing more. For example, if the input is 65 | ``` 66 | def fib(n): 67 | if n < 2: 68 | 69 | else: 70 | return fib(n-1) + fib( 71 | ``` 72 | Your output should be " return 1" (without the quotes) and nothing more. 
# SECURITY: credentials must never be hard-coded in source control. The value is
# read from the environment instead (empty when unset); the literal that used to
# be committed here should be treated as leaked and rotated.
my_password = os.environ.get("COEDITOR_PASSWORD", "")
@dataclass
class SantaCoderWrapper(FIMModel):
    """Fill-in-the-middle wrapper around a SantaCoder model and its tokenizer."""

    model: SantaCoderModelType
    tokenizer: SantaCoderTokenizerType
    # token budget shared by the prompt and the generated output
    tks_limit: int = 2048

    def __post_init__(self):
        """Resolve the ids of the special tokens used to build FIM prompts."""
        added = self.tokenizer.get_added_vocab()
        self.endoftext = self.tokenizer.encode("<|endoftext|>")[0]
        # Each sentinel needs its own key: looking all three up under the same
        # (empty) key can never give three distinct token ids. SantaCoder
        # spells its FIM sentinels with hyphens.
        self.fim_prefix = added["<fim-prefix>"]
        self.fim_middle = added["<fim-middle>"]
        self.fim_suffix = added["<fim-suffix>"]

    def infill(self, left: str, right: str, max_output: int) -> str:
        """Generate the text between `left` and `right`.

        Both contexts are tokenized and truncated to fit
        `tks_limit - max_output - 4` tokens, then laid out as
        `prefix-sentinel, left, suffix-sentinel, right, middle-sentinel`.
        Decoding is done without sampling. Only the newly generated tokens are
        decoded, with a trailing end-of-text token dropped.
        """
        tkn = self.tokenizer
        device = self.model.device
        left_tks: TokenSeq = tkn.encode(left, add_special_tokens=False)
        right_tks: TokenSeq = tkn.encode(right, add_special_tokens=False)
        left_tks, right_tks = truncate_sections(
            self.tks_limit - max_output - 4,
            (left_tks, TruncateAt.Left),
            (right_tks, TruncateAt.Right),
            add_bos=False,
        )

        input_ids = join_list(
            [
                [self.fim_prefix],
                left_tks,
                [self.fim_suffix],
                right_tks,
                [self.fim_middle],
            ]
        )
        total_length = len(input_ids) + max_output
        if total_length > self.tks_limit:
            warnings.warn(
                f"Total length {total_length} exceeds the limit of {self.tks_limit}."
            )
        input_ids = torch.tensor([input_ids], device=device)
        output = self.model.generate(
            input_ids=input_ids,
            attention_mask=None,
            do_sample=False,
            max_length=total_length,
            eos_token_id=self.endoftext,
            pad_token_id=self.endoftext,
        )
        assert isinstance(output, torch.Tensor)
        output_ids = output[0].tolist()
        # keep only what was generated after the prompt
        output_ids = output_ids[input_ids.size(1) :]
        if output_ids[-1] == self.endoftext:
            output_ids = output_ids[:-1]
        completion: str = tkn.decode(output_ids, clean_up_tokenization_spaces=False)

        return completion

    @staticmethod
    def from_pretrained(model_name: str = "bigcode/santacoder"):
        """Load the model and tokenizer named `model_name` and wrap them.

        Both are loaded with `trust_remote_code=True`, i.e. code shipped with
        the checkpoint is executed; only use checkpoints you trust.
        """
        model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        assert isinstance(model, SantaCoderModelType)
        assert isinstance(tokenizer, SantaCoderTokenizerType)
        return SantaCoderWrapper(model, tokenizer)
= vocab["<|endoftext|>"] 23 | self.fim_prefix = vocab[""] 24 | self.fim_middle = vocab[""] 25 | self.fim_suffix = vocab[""] 26 | 27 | def infill(self, left: str, right: str, max_output: int) -> str: 28 | tkn = self.tokenizer 29 | device = self.model.device 30 | left_tks: TokenSeq = tkn.encode(left, add_special_tokens=False) 31 | right_tks: TokenSeq = tkn.encode(right, add_special_tokens=False) 32 | left_tks, right_tks = truncate_sections( 33 | self.tks_limit - max_output - 4, 34 | (left_tks, TruncateAt.Left), 35 | (right_tks, TruncateAt.Right), 36 | add_bos=False, 37 | ) 38 | 39 | input_ids = join_list( 40 | [ 41 | [self.fim_prefix], 42 | left_tks, 43 | [self.fim_suffix], 44 | right_tks, 45 | [self.fim_middle], 46 | ] 47 | ) 48 | total_length = len(input_ids) + max_output 49 | if total_length > self.tks_limit: 50 | warnings.warn( 51 | f"Total length {total_length} exceeds the limit of {self.tks_limit}." 52 | ) 53 | input_ids = torch.tensor([input_ids], device=device) 54 | output = self.model.generate( 55 | input_ids=input_ids, 56 | attention_mask=None, 57 | do_sample=False, 58 | max_length=total_length, 59 | eos_token_id=self.endoftext, 60 | pad_token_id=self.endoftext, 61 | ) 62 | assert isinstance(output, torch.Tensor) 63 | output_ids = output[0].tolist() 64 | output_ids = output_ids[input_ids.size(1) :] 65 | if output_ids[-1] == self.endoftext: 66 | output_ids = output_ids[:-1] 67 | completion: str = tkn.decode(output_ids, clean_up_tokenization_spaces=False) 68 | 69 | return completion 70 | 71 | @staticmethod 72 | def from_pretrained( 73 | model_name: str = "bigcode/starcoderbase-7b", half_precision: bool = True 74 | ): 75 | model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) 76 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 77 | assert isinstance(model, SantaCoderModelType) 78 | assert isinstance(tokenizer, SantaCoderTokenizerType) 79 | if half_precision: 80 | model = model.half() 81 | return 
StarCoderWrapper(model, tokenizer) 82 | -------------------------------------------------------------------------------- /src/coeditor/git.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from datetime import datetime 3 | 4 | import dateparser 5 | 6 | from coeditor.common import * 7 | 8 | 9 | @dataclass(frozen=True) 10 | class CommitInfo: 11 | hash: str 12 | parents: tuple[str, ...] 13 | msg: str 14 | 15 | def summary(self) -> str: 16 | return f"[{self.hash[:10]} {short_str(self.msg)}]" 17 | 18 | 19 | def get_commit_history( 20 | project_dir: Path, 21 | max_history: int | None = None, 22 | commit_id: str = "HEAD", 23 | ) -> list[CommitInfo]: 24 | """Get the commit history of the project, start from the given `commit_id`, 25 | going backward in time. 26 | When a merge commit is encountered, the second parent (the branch that's 27 | being merged in) is used as the history. 28 | """ 29 | commit_id = run_command( 30 | ["git", "rev-parse", commit_id], 31 | cwd=project_dir, 32 | ).strip() 33 | history = [] 34 | for _ in range(max_history if max_history else 100000): 35 | lines = run_command( 36 | ["git", "cat-file", "-p", commit_id], 37 | cwd=project_dir, 38 | ).splitlines() 39 | parents = [] 40 | for line in lines[1:]: 41 | if line.startswith("parent "): 42 | parents.append(line.split(" ")[1]) 43 | else: 44 | break 45 | commit_msg = run_command( 46 | ["git", "show", commit_id, "-s", "--format=%s"], 47 | cwd=project_dir, 48 | ).strip() 49 | history.append(CommitInfo(commit_id, tuple(parents), commit_msg)) 50 | if not parents: 51 | break 52 | commit_id = parents[-1] 53 | return history 54 | 55 | 56 | def file_content_from_commit( 57 | project_dir: Path, 58 | commit: str, 59 | path: str, 60 | ) -> str: 61 | return run_command( 62 | ["git", "show", f"{commit}:{path}"], 63 | cwd=project_dir, 64 | ) 65 | 66 | 67 | @dataclass 68 | class GitRepo: 69 | author: str 70 | name: str 71 | url: str 72 | stars: int 73 | forks: 
def download(
    self, repos_dir: Path, full_history: bool = True, timeout=None
) -> bool:
    """Clone this repository into ``repos_dir / "downloaded"``.

    The clone is first created under ``repos_dir / "downloading"`` and is
    moved into ``repos_dir / "downloaded"`` only once the clone directory
    exists. Both directories are created if missing.

    Args:
        repos_dir: the root directory holding the two staging directories.
        full_history: when False, a shallow clone (``--depth=1``) is made.
        timeout: forwarded to `subprocess.run`; `subprocess.TimeoutExpired`
            propagates to the caller.

    Returns:
        True if the repository was cloned and moved, False if the clone
        directory was not created.
    """
    staging_dir = repos_dir / "downloading"
    final_dir = repos_dir / "downloaded"
    staging_dir.mkdir(parents=True, exist_ok=True)
    # `shutil.move` only moves *into* the destination if it is an existing
    # directory; otherwise the clone would be renamed to "downloaded" itself.
    final_dir.mkdir(parents=True, exist_ok=True)

    # Keep the optional flag in a list: unpacking a plain string with `*`
    # would splice its individual characters into the command line.
    depth_args = [] if full_history else ["--depth=1"]
    subprocess.run(
        ["git", "clone", *depth_args, self.url, self.authorname()],
        cwd=staging_dir,
        timeout=timeout,
        capture_output=True,
    )

    cloned = staging_dir / self.authorname()
    if not cloned.is_dir():
        # git clone failed. Possibly caused by invalid url?
        return False
    shutil.move(cloned, final_dir)
    return True
@staticmethod
def join(segs: Iterable["TkArray"], sep: int | None) -> "TkArray":
    """Concatenate `segs` lazily, optionally placing `sep` between them.

    Args:
        segs: the arrays to concatenate; any iterable, including a one-shot
            generator.
        sep: a token id inserted between consecutive segments, or `None`
            for plain concatenation.

    Returns:
        A `_JoinedTkArray` whose length counts the segment tokens plus the
        separators that `_JoinedTkArray.tolist` inserts.
    """
    # Materialize first: `segs` may be a generator, which can only be
    # iterated once, and it is needed both for storage and for the length.
    parts = tuple(segs)
    length = sum(len(seg) for seg in parts)
    if sep is not None and parts:
        # One separator between each pair of neighbouring segments.
        length += len(parts) - 1
    return _JoinedTkArray(parts, sep, length)
@dataclass(frozen=True)
class _TruncatedTkArray(TkArray):
    "A lazy view of another `TkArray` that reports `length` tokens after truncation."

    # The array being truncated; it is only materialized inside `tolist`.
    original: TkArray
    # Passed through to `truncate_section` to select the truncation side.
    direction: TruncateAt.Value
    # The length reported by `__len__` and requested from `truncate_section`.
    length: int

    def __len__(self) -> int:
        return self.length

    def tolist(self) -> TokenSeq:
        full_tks = self.original.tolist()
        # The `tolist` implementations of the sibling classes build a new
        # list on each call, so truncating it in place leaves `original`
        # untouched.
        return truncate_section(full_tks, self.direction, self.length, inplace=True)
def assert_has_usages(defs: Collection[PyDefinition], *full_names: str):
    """Check that every name in `full_names` is the full name of some
    definition in `defs`.

    Raises:
        AssertionError: for the first expected name that is missing.
    """
    found = [d.full_name for d in defs]
    for name in full_names:
        if PyFullName(name) in found:
            continue
        raise AssertionError(f"{name} not in {found}")
50 | assert_has_usages( 51 | analysis.line2usages[22], 52 | "defs.ChangeScope.tree", 53 | "defs.ChangeScope", # include parent usage as well 54 | "defs.ScopeTree", 55 | ) 56 | 57 | assert_has_usages( 58 | analysis.line2usages[23], 59 | "defs.ChangeScope.spans", 60 | "defs.ChangeScope", 61 | "typing.Sequence", 62 | ) 63 | 64 | assert_has_usages( 65 | analysis.line2usages[24], 66 | "typing.Mapping", 67 | "coeditor.common.ProjectPath", 68 | ) 69 | 70 | assert_has_usages( 71 | analysis.line2usages[28], 72 | "defs.ChangeScope.spans", 73 | ) 74 | 75 | assert_has_usages( 76 | analysis.line2usages[31], 77 | "coeditor.common.ProjectPath", 78 | "defs.ScopeTree", 79 | # "defs.ChangeScope", # couldn't handle string annotations for now 80 | ) 81 | 82 | assert_has_usages( 83 | analysis.line2usages[40], 84 | "parso.tree.BaseNode.__init__.children", 85 | ) 86 | 87 | assert_has_usages( 88 | analysis.line2usages[42], 89 | "parso.python.tree.PythonNode", 90 | "parso.python.tree.Scope.get_suite", 91 | # "parso.python.tree.BaseNode.children", 92 | ) 93 | 94 | 95 | # @pytest.mark.xfail(reason="Due to jedi bug") 96 | # def test_dataclass_signature(): 97 | # s = jedi.Script( 98 | # dedent( 99 | # """\ 100 | # from dataclasses import dataclass 101 | # @dataclass 102 | # class Foo: 103 | # bar: int 104 | # """ 105 | # ) 106 | # ) 107 | 108 | # defs = s.goto(3, 8) # go to Foo directly 109 | # assert len(defs) == 1 110 | # n = defs[0] 111 | # assert n._get_docstring_signature() == "Foo(bar: int)" 112 | 113 | # defs = s.goto(4, 6) # first go to bar 114 | # print(f"{len(defs)=}") 115 | # n = defs[0].parent() # then go to parent 116 | # assert n._get_docstring_signature() == "Foo(bar: int)" 117 | 118 | 119 | def test_anlayzing_usages(): 120 | analyzer = JediUsageAnalyzer() 121 | project = jedi.Project(path=testcase_root, added_sys_path=[proj_root() / "src"]) 122 | script = jedi.Script(path=testcase_root / "usages.py", project=project) 123 | analysis = analyzer.get_line_usages(script, range(0, 
63), silent=True) 124 | 125 | if analyzer.error_counts: 126 | raise RuntimeError(f"Errors found: {analyzer.error_counts}") 127 | 128 | assert_has_usages( 129 | analysis.line2usages[11], 130 | "usages.JModule.tree", 131 | "parso.python.tree.Module", 132 | ) 133 | 134 | assert_has_usages( 135 | analysis.line2usages[13], 136 | "usages.JModule._to_scope", 137 | "defs.ChangeScope", 138 | ) 139 | 140 | assert_has_usages( 141 | analysis.line2usages[14], 142 | "usages.JModule.mname", 143 | "usages.JModule.tree", 144 | "defs.ChangeScope", 145 | "defs.ChangeScope.from_tree", 146 | "coeditor.common.ProjectPath", 147 | ) 148 | 149 | assert_has_usages( 150 | analysis.line2usages[19], 151 | "usages.JModule.iter_imports", 152 | ) 153 | 154 | assert_has_usages( 155 | analysis.line2usages[21], 156 | # "parso.python.tree.ImportFrom.get_from_names", 157 | ) 158 | 159 | assert_has_usages( 160 | analysis.line2usages[34], 161 | "coeditor._utils.as_any", 162 | ) 163 | -------------------------------------------------------------------------------- /tests/test_edits.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from coeditor.change import * 4 | from coeditor.encoding import * 5 | from coeditor.encoding import _BaseTokenizer, _Tokenizer 6 | 7 | 8 | def get_rng(): 9 | return random.Random(42) 10 | 11 | 12 | def get_before(change: Change[str]) -> str: 13 | if isinstance(change, Modified): 14 | return change.before 15 | elif isinstance(change, Added): 16 | return "" 17 | elif isinstance(change, Deleted): 18 | return change.before 19 | else: 20 | raise ValueError(f"Unknown change type: {change}") 21 | 22 | 23 | def get_after(change: Change[str]) -> str: 24 | if isinstance(change, Modified): 25 | return change.after 26 | elif isinstance(change, Added): 27 | return change.after 28 | elif isinstance(change, Deleted): 29 | return "" 30 | else: 31 | raise ValueError(f"Unknown change type: {change}") 32 | 33 | 34 | def 
def test_get_lines():
    """Check `tk_get_lines` against slicing the result of `tk_splitlines`."""

    def check(tks: TokenSeq, a, b):
        # Reference result: split into lines, slice, and join back together.
        expect = join_list(tk_splitlines(tks)[a:b], Newline_id)
        assert_eq(
            tk_get_lines(tks, a, b),
            expect,
            lambda: f"input={decode_tokens(tks)}, {a=}, {b=}",
        )

    three_lines = [3, Newline_id, 4, Newline_id, 5]
    cases = [
        ([], 0, 0),
        ([], 0, 1),
        ([3], 0, 1),
        ([3, Newline_id], 0, 1),
        ([3, Newline_id, 4], 0, 1),
        (three_lines, 0, 2),
        (three_lines, 0, 5),
        (three_lines, 1, 5),
        (three_lines, 2, 5),
        (three_lines, 3, 5),
        ([Newline_id] * 5, 2, 3),
        ([Newline_id] * 5, 2, 4),
    ]
    for tks, start, end in cases:
        check(tks, start, end)
Modified("", "\n"), 96 | "add a new line at end": Modified("a", "a\n"), 97 | "added": Added("a\nb\nc\n"), 98 | "deleted": Deleted("a\nb\nc\n"), 99 | "no change": Modified( 100 | dedent( 101 | """\ 102 | def f1(): 103 | x = 1 104 | """ 105 | ), 106 | dedent( 107 | """\ 108 | def f1(): 109 | x = 1 110 | """ 111 | ), 112 | ), 113 | "unchanged=True": Modified.from_unchanged( 114 | dedent( 115 | """\ 116 | def f1(): 117 | x = 1 118 | """ 119 | ), 120 | ), 121 | # this test case cannot pass for some reason. Tokenizer bug? 122 | # "leading_whitespace": Modified.from_unchanged(" ..."), 123 | "replace last": Modified( 124 | dedent( 125 | """\ 126 | def f1(): 127 | x = 1""" 128 | ), 129 | dedent( 130 | """\ 131 | def f1(): 132 | x = 2 133 | return x * 2""" 134 | ), 135 | ), 136 | "no special tokens": Modified( 137 | dedent( 138 | """\ 139 | def f1(): 140 | x = 1 141 | y = 2 142 | z = x + y 143 | return z 144 | 145 | def f2(): 146 | f1()""" 147 | ), 148 | dedent( 149 | """\ 150 | # new comment 151 | def f_new(): 152 | x = 1 153 | if x > 0: 154 | y = 2 * x 155 | y *= 2 156 | z = x + y 157 | return z 158 | 159 | def f2(): 160 | f1() 161 | return f_new() + a 162 | 163 | new_var = 0 164 | """ 165 | ), 166 | ), 167 | "with special tokens": Modified( 168 | dedent( 169 | """\ 170 | def f1(): 171 | x = "" 172 | y = "\tx" 173 | return x + y 174 | 175 | """ 176 | ), 177 | dedent( 178 | """\ 179 | # new comment 1 180 | # new comment 2 181 | def f1(): 182 | if newcond: 183 | x = "" 184 | new_var = 5 185 | y = "" 186 | return x + new_var + y 187 | """ 188 | ), 189 | ), 190 | "super long": Modified( 191 | "\n".join(f"x = {i}" for i in range(0, 200)), 192 | "\n".join(f"x = {2* (i // 2)}" for i in range(0, 200)), 193 | ), 194 | "strings with newlines": Modified( 195 | dedent( 196 | """\ 197 | If `True`, wraps the environments in an `AsyncVectorEnv` (which uses \n 198 | `multiprocessing` to run the environments in parallel) \n 199 | """ 200 | ), 201 | dedent( 202 | """\ 203 | If `True`, wraps 
the environments in an `AsyncVectorEnv` (which uses \n 204 | `multiprocessing` to run the environments in parallel) \n 205 | Added a line here. \n 206 | and here. 207 | """ 208 | ), 209 | ), 210 | } 211 | 212 | def test_str_encodings(self): 213 | for name, c in self.cases.items(): 214 | try: 215 | line_diffs = change_to_line_diffs(c) 216 | print("line_diffs\n------\n" + "\n".join(line_diffs)) 217 | before, delta = line_diffs_to_original_delta(line_diffs) 218 | print("before:") 219 | print(before) 220 | print("delta:", delta) 221 | assert_str_equal(before, get_before(c), name) 222 | after = delta.apply_to_input(before) 223 | assert_str_equal(after, get_after(c), name) 224 | except Exception: 225 | print_err(f"Failed for case: {name}") 226 | raise 227 | 228 | def test_tk_encodings(self): 229 | for name, c in self.cases.items(): 230 | print("=" * 40, name, "=" * 40) 231 | c_tokens = change_to_tokens(c) 232 | print_sections( 233 | ("c_tokens", decode_tokens(c_tokens)), 234 | ) 235 | c_rec = tokens_to_change(c_tokens) 236 | assert_change_eq( 237 | c_rec, c, "change_to_tokens |> tokens_to_change = identity: " + name 238 | ) 239 | 240 | in_seq, out_seq = change_to_input_output(c) 241 | print_sections( 242 | ("in_seq", decode_tokens(in_seq)), 243 | ("out_seq", decode_tokens(out_seq)), 244 | ) 245 | 246 | assert_tks_eq( 247 | in_seq, 248 | code_to_input(encode_lines_join(get_before(c))), 249 | "change_to_input_output mathese code_to_input: " + name, 250 | ) 251 | 252 | if len(splitlines(get_before(c))) < N_Extra_Ids: 253 | inlined = inline_output_tokens(in_seq, out_seq) 254 | assert_tks_eq( 255 | inlined, change_to_tokens(c), "inline_output_tokens: " + name 256 | ) 257 | c_rec2 = tokens_to_change(inlined) 258 | assert_change_eq(c_rec2, c, "tokens_to_change(inlined): " + name) 259 | 260 | def test_str_tk_conversion(self): 261 | for name, c in self.cases.items(): 262 | line_diffs = change_to_line_diffs(c) 263 | print("line_diffs\n------\n" + "\n".join(line_diffs)) 264 | 
before, delta = line_diffs_to_original_delta(line_diffs) 265 | print("delta:", delta) 266 | 267 | tk_delta = delta.to_tk_delta() 268 | tk_before = encode_lines_join(before) 269 | tk_after = tk_delta.apply_to_input(tk_before) 270 | if tk_after != encode_lines_join(get_after(c)): 271 | print("after diff:\n") 272 | print(show_string_diff(get_after(c), decode_tokens(tk_after))) 273 | 274 | c_tokens = tk_delta.apply_to_change(tk_before) 275 | if c_tokens != change_to_tokens(c): 276 | print("c_tokens diff:\n") 277 | print( 278 | show_string_diff( 279 | decode_tokens(c_tokens), decode_tokens(change_to_tokens(c)) 280 | ) 281 | ) 282 | 283 | origin1, tk_delta1 = change_tks_to_original_delta(c_tokens) 284 | if origin1 != tk_before: 285 | print("origin diff:\n") 286 | print( 287 | show_string_diff(decode_tokens(origin1), decode_tokens(tk_before)) 288 | ) 289 | 290 | assert tk_delta1.apply_to_input(origin1) == tk_after 291 | 292 | def test_apply_to_change(self): 293 | for name, c in self.cases.items(): 294 | before, delta = StrDelta.from_change(c) 295 | tk_delta = delta.to_tk_delta() 296 | tk_before = encode_lines_join(before) 297 | tk_change = tk_delta.apply_to_change(tk_before) 298 | expect = change_to_tokens(c) 299 | if tk_change != expect: 300 | print_sections( 301 | ("expect", decode_tokens(expect)), 302 | ("tk_change", decode_tokens(tk_change)), 303 | ) 304 | raise AssertionError(f"apply_to_change failed: {name}") 305 | 306 | def test_random_subset(self): 307 | rng = get_rng() 308 | 309 | def is_sorted(xs): 310 | return list(xs) == list(sorted(xs)) 311 | 312 | xs = range(50) 313 | assert is_sorted(xs) 314 | for _ in range(100): 315 | ys = random_subset(xs, 20, rng) 316 | assert is_sorted(ys) 317 | 318 | x_map = {i: i + 1 for i in range(50)} 319 | assert is_sorted(x_map) 320 | for _ in range(100): 321 | y_map = random_subset(x_map, 20, rng) 322 | assert is_sorted(y_map) 323 | 324 | def test_delta_decomposition(self): 325 | rng = get_rng() 326 | 327 | for name, c in 
self.cases.items(): 328 | original, delta = TkDelta.from_change_tks(change_to_tokens(c)) 329 | assert_tks_eq(original, encode_lines_join(get_before(c)), name) 330 | expect = delta.apply_to_input(original) 331 | assert_tks_eq(expect, encode_lines_join(get_after(c)), name) 332 | keys = tuple(delta.keys()) 333 | for _ in range(100): 334 | n_keys = int(len(keys) * rng.random()) 335 | sub_keys = random_subset(keys, n_keys) 336 | delta1, delta2 = delta.decompose_for_input(sub_keys) 337 | step1 = delta1.apply_to_input(original) 338 | step2 = delta2.apply_to_input(step1) 339 | try: 340 | assert_tks_eq(step2, expect, name) 341 | except: 342 | print_sections( 343 | ("change", decode_tokens(change_to_tokens(c))), 344 | ("delta", str(delta)), 345 | ("sub_keys", str(sub_keys)), 346 | ("original", decode_tokens(original)), 347 | ("delta1", str(delta1)), 348 | ("step1", decode_tokens(step1)), 349 | ("delta2", str(delta2)), 350 | ("step2", decode_tokens(step2)), 351 | ("expect", decode_tokens(expect)), 352 | ) 353 | raise 354 | 355 | def test_get_new_target_lines(self): 356 | rng = get_rng() 357 | 358 | for name, c in self.cases.items(): 359 | original, delta = TkDelta.from_change_tks(change_to_tokens(c)) 360 | n_origin_lines = len(tk_splitlines(original)) 361 | edit_lines = range(n_origin_lines + 1) 362 | keys = tuple(delta.keys()) 363 | for _ in range(100): 364 | n_keys = int(len(keys) * rng.random()) 365 | sub_keys = random_subset(keys, n_keys) 366 | sub_keys.sort() 367 | delta1, delta2 = delta.decompose_for_change(sub_keys) 368 | new_edit_lines = delta1.get_new_line_ids(edit_lines) 369 | new_edit_set = set(new_edit_lines) 370 | for l in delta2.changed_lines(): 371 | if l not in new_edit_set and l != n_origin_lines: 372 | print_err(f"{edit_lines=}") 373 | print_err("original", SEP) 374 | print_err(add_line_numbers(decode_tokens(original), start=0)) 375 | print_err(SEP) 376 | print_err(f"{delta=}") 377 | print_err(f"{sub_keys=}") 378 | print_err(f"{delta1=}") 379 | 
print_err("step1", SEP) 380 | step1 = delta1.apply_to_change(original) 381 | print_err(add_line_numbers(decode_tokens(step1), start=0)) 382 | print_err(SEP) 383 | print_err(f"{new_edit_lines=}") 384 | print_err(f"{delta2=}") 385 | raise AssertionError(f"{l=} not in {new_edit_lines=}") 386 | 387 | 388 | def test_edit_lines_transform(): 389 | ex_code = dedent( 390 | """\ 391 | a 392 | b 393 | c 394 | d 395 | e 396 | """ 397 | ) 398 | ex_delta = StrDelta( 399 | { 400 | 1: ("+1",), 401 | 2: ("+2",), 402 | 3: ("-",), 403 | 4: ("+d1", "+d2", "+d3"), 404 | } 405 | ) 406 | after_expect = dedent( 407 | """\ 408 | a 409 | +1 410 | b 411 | +2 412 | c 413 | -d 414 | +d1 415 | +d2 416 | +d3 417 | e 418 | """ 419 | ) 420 | 421 | tk_delta = ex_delta.to_tk_delta() 422 | all_lines = range(6) 423 | new_target_lines = tk_delta.get_new_line_ids(all_lines) 424 | expect = (0, 1, 2, 3, 4, 6, 7, 8, 9, 10) 425 | assert_eq(new_target_lines, expect) 426 | 427 | later_lines = range(3, 6) 428 | new_target_lines = tk_delta.get_new_line_ids(later_lines) 429 | # only the last 5 lines should be edited 430 | expect = (6, 7, 8, 9, 10) 431 | assert_eq(new_target_lines, expect) 432 | 433 | 434 | def test_code_normalization(): 435 | def check_code_equal(code1: str, code2: str): 436 | if not code_equal(code1, code2): 437 | e = AssertionError(f"code_equal failed.") 438 | diff = show_string_diff( 439 | normalize_code_by_ast(code1), normalize_code_by_ast(code2) 440 | ) 441 | e.add_note("Diff in normalized code:\n" + diff) 442 | raise e 443 | 444 | ex_code = dedent( 445 | """\ 446 | def f1(x, y): 447 | return f1(x + 1, y - 1) 448 | """ 449 | ) 450 | ex_code_compact = dedent( 451 | """\ 452 | def f1(x,y): 453 | return f1(x+1,y-1) 454 | """ 455 | ) 456 | check_code_equal(ex_code, ex_code_compact) 457 | ex_code_lose = dedent( 458 | """\ 459 | 460 | def f1(x,y): 461 | 462 | return f1( 463 | x+1, 464 | y-1 465 | ) 466 | """ 467 | ) 468 | check_code_equal(ex_code, ex_code_lose) 469 | 470 | ex_code_keyword1 = 
"f(x, y=y, z=z)" 471 | ex_code_keyword2 = "f(x, z=z, y=y)" 472 | check_code_equal(ex_code_keyword1, ex_code_keyword2) 473 | 474 | ex_code_keyword3 = "f(x, y=y, z=z, **kwargs)" 475 | 476 | with pytest.raises(AssertionError): 477 | check_code_equal(ex_code_keyword1, ex_code_keyword3) 478 | 479 | 480 | def test_extra_ids(): 481 | all_extra_ids = _Tokenizer.additional_special_tokens_ids 482 | 483 | for x in all_extra_ids: 484 | assert is_extra_id(x) 485 | n = extra_id_to_number(x) 486 | assert get_extra_id(n) == x 487 | 488 | 489 | def test_edit_distance(): 490 | jump_cost = 4 491 | cases = [ 492 | ("empty strings", ("", ""), 0), 493 | ("identical strings", ("abc", "abc"), 0), 494 | ("add to empty", ("", "abc"), 3 + jump_cost), 495 | ("delete all", ("abc", ""), 3 + jump_cost), 496 | ("add to end", ("abc", "abcd"), 1 + jump_cost), 497 | ("add in the middle", ("abc", "aabc"), 1 + jump_cost), 498 | ("replace in the middle", ("abc", "axc"), 2 + jump_cost), 499 | ("consective edits", ("abc", "axdf"), 2 * 2 + 1 + jump_cost), 500 | ("nonconsective inserts (close)", ("abc", "xaxbc"), 3 + jump_cost), 501 | ("nonconsective inserts (far)", ("abcdefg", "axbcdefxg"), 2 + jump_cost * 2), 502 | ("many inserts", ("abcdefg", "xaxbxcxdxefg"), 5 + 4 + jump_cost), 503 | ("many replaces (sep)", ("abcdefg", "xbxdxfx"), 4 * 2 + 3 + jump_cost), 504 | ("many replaces (continuous)", ("abcdefg", "axxxxfg"), 4 * 2 + jump_cost), 505 | ("delete single", ("abcde", "acde"), 1 + jump_cost), 506 | ("delete all", ("a" * 100, ""), 2 * jump_cost + 2), 507 | ( 508 | "delete middle", 509 | ("a" * 30 + "b" * 20 + "c" * 30, "a" * 30 + "c" * 30), 510 | 2 * jump_cost + 2, 511 | ), 512 | ] 513 | for name, (x, y), expect in cases: 514 | assert keystroke_cost(x, y, jump_cost) == expect, f"Failed for case: {name}" 515 | -------------------------------------------------------------------------------- /tests/test_scoped_change.py: -------------------------------------------------------------------------------- 1 | 
from textwrap import indent 2 | 3 | import pytest 4 | 5 | from coeditor.encoding import _BaseTokenizer 6 | from coeditor.scoped_changes import * 7 | 8 | 9 | def test_change_scope(): 10 | code1 = dedent( 11 | """\ 12 | import os 13 | 14 | x = 1 15 | y = x + 1 16 | 17 | def f1(): 18 | global x 19 | x *= 5 20 | return x 21 | 22 | if __name__ == "__main__": 23 | print(f1() + x) 24 | 25 | @annotated 26 | def f2(): 27 | return 1 28 | 29 | @dataclass 30 | class A: 31 | attr1: int 32 | 33 | @staticmethod 34 | def method1(): 35 | return 1 36 | 37 | class B: 38 | inner_attr1: int 39 | """ 40 | ) 41 | mod_tree = code_to_module(code1) 42 | scope = ChangeScope.from_tree(ProjectPath("code1", ""), mod_tree) 43 | global_spans = [ 44 | dedent( 45 | """\ 46 | x = 1 47 | y = x + 1 48 | """ 49 | ), 50 | dedent( 51 | """\ 52 | if __name__ == "__main__": 53 | print(f1() + x) 54 | """ 55 | ), 56 | ] 57 | try: 58 | for i, code in enumerate(global_spans): 59 | assert_str_equal(scope.spans[i].code, code) 60 | except Exception: 61 | print_err(f"{scope.spans=}") 62 | raise 63 | 64 | f1_expect = dedent( 65 | """\ 66 | global x 67 | x *= 5 68 | return x 69 | """ 70 | ) 71 | f1_code = scope.subscopes["f1"].spans_code 72 | assert_str_equal(f1_code, indent(f1_expect, " " * 4)) 73 | 74 | f2_expect = dedent( 75 | """\ 76 | @annotated 77 | def f2(): 78 | return 1 79 | """ 80 | ) 81 | f2_code = scope.subscopes["f2"].all_code 82 | assert_str_equal(f2_code, f2_expect) 83 | 84 | attr1_expect = dedent( 85 | """\ 86 | attr1: int 87 | """ 88 | ) 89 | attr1_code = scope.subscopes["A"].spans_code 90 | assert_str_equal(attr1_code, indent(attr1_expect, " " * 4)) 91 | 92 | method1_expect = dedent( 93 | """\ 94 | @staticmethod 95 | def method1(): 96 | return 1 97 | """ 98 | ) 99 | method1_code = scope.subscopes["A"].subscopes["method1"].all_code 100 | assert_str_equal(method1_code, indent(method1_expect, " " * 4)) 101 | 102 | inner_attr1_expect = dedent( 103 | """\ 104 | class B: 105 | inner_attr1: int 106 | """ 
107 | ) 108 | inner_class_code = scope.subscopes["A"].subscopes["B"].all_code 109 | assert_str_equal(inner_class_code, indent(inner_attr1_expect, " " * 4)) 110 | 111 | 112 | class TestChangedSpan: 113 | code1 = dedent( 114 | """\ 115 | import os 116 | 117 | x = 1 118 | y = x + 1 119 | 120 | def f1(): 121 | global x 122 | x *= 5 123 | return x 124 | 125 | if __name__ == "__main__": 126 | print(f1() + x) 127 | """ 128 | ) 129 | scope1 = ChangeScope.from_tree(ProjectPath("code1", ""), code_to_module(code1)) 130 | 131 | @staticmethod 132 | def check_changed_spans( 133 | changed_spans: Sequence[ChangedSpan], *expects: tuple[type, int] 134 | ): 135 | print(f"{changed_spans=}\nchanges={[cs.change for cs in changed_spans]}") 136 | assert_eq( 137 | len(changed_spans), 138 | len(expects), 139 | ) 140 | for i, (change_type, n) in enumerate(expects): 141 | span = changed_spans[i] 142 | assert_eq(type(span.change), change_type) 143 | nl_change = span.change.map(count_lines) 144 | line_change = nl_change.later - nl_change.earlier 145 | assert_eq(line_change, n, lambda: f"{i=}, {span.change=}") 146 | 147 | def test_same_size_update(self): 148 | code2 = dedent( 149 | """\ 150 | import os 151 | 152 | x = 1 153 | y = x + 2 154 | 155 | def f1(): 156 | global x 157 | x *= 5 158 | return x + 1 159 | 160 | if __name__ == "__main__": 161 | print(f1() + x + 1) 162 | """ 163 | ) 164 | 165 | scope2 = ChangeScope.from_tree(ProjectPath("code1", ""), code_to_module(code2)) 166 | self.check_changed_spans( 167 | get_changed_spans(Modified(self.scope1, scope2)), 168 | (Modified, 0), 169 | (Modified, 0), 170 | (Modified, 0), 171 | ) 172 | 173 | def test_jmodule_change(self): 174 | code2 = dedent( 175 | """\ 176 | import os 177 | 178 | x = 1 179 | y = x + 2 180 | 181 | def f1(): 182 | global x 183 | x *= 5 184 | return x + 1 185 | 186 | if __name__ == "__main__": 187 | print(f1() + x + 1) 188 | """ 189 | ) 190 | 191 | mod1 = JModule("code1", code_to_module(self.code1)) 192 | mod2 = JModule("code1", 
code_to_module(code2)) 193 | mc = JModuleChange.from_modules(Modified(mod1, mod2)) 194 | assert len(mc.changed) == 3 195 | 196 | def test_diff_size_update(self): 197 | code2 = dedent( 198 | """\ 199 | import os 200 | 201 | x = 1 202 | y = x + 1 203 | z += 1 204 | 205 | def f1(): 206 | global x 207 | x *= 5 208 | return x 209 | 210 | print(f1() + x) 211 | """ 212 | ) 213 | scope2 = ChangeScope.from_tree(ProjectPath("code1", ""), code_to_module(code2)) 214 | self.check_changed_spans( 215 | get_changed_spans(Modified(self.scope1, scope2)), 216 | (Modified, 1), 217 | (Modified, -1), 218 | ) 219 | 220 | def test_fun_deletion(self): 221 | code2 = dedent( 222 | """\ 223 | import os 224 | 225 | x = 2 226 | 227 | if __doc__ == "__main__": 228 | print(f1() + x) 229 | print("doc") 230 | """ 231 | ) 232 | scope2 = ChangeScope.from_tree(ProjectPath("code1", ""), code_to_module(code2)) 233 | self.check_changed_spans( 234 | get_changed_spans(Modified(self.scope1, scope2)), 235 | (Modified, -1), 236 | (Deleted, 0), 237 | (Modified, 1), 238 | ) 239 | 240 | def test_fun_addition(self): 241 | code2 = dedent( 242 | """\ 243 | import os 244 | 245 | x = 1 246 | @wrapped 247 | def new_f(): 248 | pass 249 | y = x + 1 250 | 251 | def f1(): 252 | global x 253 | x *= 5 254 | return x 255 | 256 | if __name__ == "__main__": 257 | print(f1() + x) 258 | """ 259 | ) 260 | scope2 = ChangeScope.from_tree(ProjectPath("code1", ""), code_to_module(code2)) 261 | self.check_changed_spans( 262 | get_changed_spans(Modified(self.scope1, scope2)), 263 | (Added, 0), 264 | ) 265 | 266 | def test_class_addition(self): 267 | code1 = dedent( 268 | """\ 269 | import os 270 | 271 | x = 1 272 | y = x + 1 273 | 274 | if __name__ == "__main__": 275 | print(f1() + x) 276 | """ 277 | ) 278 | 279 | code2 = dedent( 280 | """\ 281 | import os 282 | 283 | x = 1 284 | y = x + 1 285 | 286 | @dataclass 287 | class Foo(): 288 | "new class" 289 | x: int = 1 290 | y: int = 2 291 | 292 | if __name__ == "__main__": 293 | 
print(f1() + x) 294 | """ 295 | ) 296 | scope1 = ChangeScope.from_tree(ProjectPath("code1", ""), code_to_module(code1)) 297 | scope2 = ChangeScope.from_tree(ProjectPath("code1", ""), code_to_module(code2)) 298 | self.check_changed_spans( 299 | get_changed_spans(Modified(scope1, scope2)), 300 | (Added, 0), 301 | ) 302 | 303 | def test_statement_move(self): 304 | code2 = dedent( 305 | """\ 306 | import os 307 | 308 | x = 1 309 | 310 | def f1(): 311 | global x 312 | x *= 5 313 | return x 314 | 315 | y = x + 1 316 | if __name__ == "__main__": 317 | print(f1() + x) 318 | """ 319 | ) 320 | scope2 = ChangeScope.from_tree(ProjectPath("code1", ""), code_to_module(code2)) 321 | self.check_changed_spans( 322 | get_changed_spans(Modified(self.scope1, scope2)), 323 | ) 324 | 325 | def test_comments_change(self): 326 | # have to update code as well for the changes to count 327 | code2 = dedent( 328 | """\ 329 | import os 330 | 331 | x = 1 332 | # belongs to f1 333 | 334 | def f1(): 335 | "added doc string" 336 | global x 337 | x *= 5 338 | return x + 1 339 | 340 | # belongs to main 341 | if __name__ == "__main__": 342 | print(f1() + x + 1) # belongs to print 343 | """ 344 | ) 345 | scope2 = ChangeScope.from_tree(ProjectPath("code1", ""), code_to_module(code2)) 346 | self.check_changed_spans( 347 | get_changed_spans(Modified(self.scope1, scope2)), 348 | (Modified, -1), 349 | (Modified, 1), 350 | (Modified, 1), 351 | ) 352 | -------------------------------------------------------------------------------- /tests/testcases/defs.py: -------------------------------------------------------------------------------- 1 | # these code are for testing only 2 | 3 | from functools import cached_property 4 | from coeditor.common import * 5 | from parso.python import tree as ptree 6 | from coeditor.change import Change 7 | 8 | # from coeditor. 
# import ProjectPath

# parso node types that can be the `tree` of a ChangeScope
ScopeTree = ptree.Function | ptree.Class | ptree.Module
# distinct static type over `str`
ChangedSpan = NewType("ChangedSpan", str)


@dataclass
class ChangeScope:
    """
    A change scope is a python module, non-hidden function, or a non-hidden class.
    - functions and classes that are inside a parent function are considered hidden.
    """

    path: ProjectPath
    tree: ScopeTree
    # elements are expected to expose `.code` (read by `spans_code`)
    spans: Sequence
    # nested ChangeScope values keyed by ProjectPath
    subscopes: Mapping[ProjectPath, Self]

    @cached_property
    def spans_code(self) -> str:
        """Return the `code` of every span, joined with newlines."""
        return "\n".join(s.code for s in self.spans)

    @staticmethod
    def from_tree(path: ProjectPath, tree: ScopeTree) -> "ChangeScope":
        """Stub: sets up an empty scope for `tree`, then always raises NotImplementedError."""
        spans = []
        subscopes = dict()
        scope = ChangeScope(path, tree, spans, subscopes)
        assert isinstance(tree, ScopeTree)
        is_func = isinstance(tree, ptree.Function)

        current_stmts = []
        # a Module exposes its children directly; other trees go through get_suite()
        content = (
            tree.children
            if isinstance(tree, ptree.Module)
            else cast(ptree.PythonNode, tree.get_suite()).children
        )
        # nothing above is returned: the locals exist only as fixture content
        raise NotImplementedError


# ---- /tests/testcases/example.py ----
from dataclasses import dataclass


def foo(x, y):
    # calls `bar()` on both arguments, then `go` on the first result
    return x.bar().go(y.bar())


@dataclass
class Foo:
    value: float


    def bar(self):
        # wraps this object's value in a Weight
        return Weight(self.value)


@dataclass
class Weight:
    value: float

    def go(self, other):
        # sum of the two `value` fields
        return self.value + other.value


# module-level call: executed whenever this file is run or imported
foo(Foo(1), Foo(2))


# ---- /tests/testcases/usages.py ----
from coeditor.change import Added, Deleted, Modified
from coeditor.common import ModuleName
from .defs import *
from typing import *


@dataclass
class JModule:
    "A light wrapper around a jedi module."
    mname: ModuleName
    tree: ptree.Module

    def _to_scope(self) -> ChangeScope:
        """Build this module's ChangeScope from `tree`, using `ProjectPath(mname, "")` as its path."""
        return ChangeScope.from_tree(ProjectPath(self.mname, ""), self.tree)

    @cached_property
    def imported_names(self):
        """Collect the `ptree.Name` nodes of every import statement yielded by `iter_imports`."""
        names = set[ptree.Name]()
        for stmt in self.iter_imports(self.tree):
            if isinstance(stmt, ptree.ImportFrom):
                # `from ... import ...`: take the names returned by get_from_names()
                for n in stmt.get_from_names():
                    assert isinstance(n, ptree.Name)
                    names.add(n)
            elif isinstance(stmt, ptree.ImportName):
                # `import ...`: take the names returned by get_defined_names()
                for n in stmt.get_defined_names():
                    assert isinstance(n, ptree.Name)
                    names.add(n)
        return names

    def iter_imports(self, tree):
        """Stub: always raises NotImplementedError."""
        raise NotImplementedError


# placeholder: presumably `None` typed as Any via `as_any` -- TODO confirm in coeditor.common
get_modified_spans = as_any(None)


def get_named_changes(*args):
    """Stub: always raises NotImplementedError."""
    raise NotImplementedError


def recurse(scope_change: Change[ChangeScope], parent_changes) -> Iterable[ChangedSpan]:
    """Yield changed spans for `scope_change` and, recursively, for its sub-scopes.

    `parent_changes` is extended with `scope_change` before anything is yielded,
    so every yielded item and every recursive call sees the full chain.
    """
    parent_changes = (*parent_changes, scope_change)
    match scope_change:
        case Modified(old_scope, new_scope):
            # compute statement differences
            yield from get_modified_spans(old_scope, new_scope, parent_changes)
            # then descend into whatever changes get_named_changes reports for the sub-scopes
            for sub_change in get_named_changes(
                old_scope.subscopes, new_scope.subscopes
            ).values():
                yield from recurse(sub_change, parent_changes)
        case Added(scope) | Deleted(scope):
            # one item per span; `new_value` presumably re-wraps the payload in the
            # same change kind -- TODO confirm in coeditor.change
            for span in scope.spans:
                code_change = scope_change.new_value(span.code)
                # NOTE(review): defs.py declares ChangedSpan as a NewType over str,
                # which takes one argument; confirm this 3-argument call is never run.
                yield ChangedSpan(
                    code_change,
                    parent_changes,
                    span.line_range,
                )
            # every sub-scope is handled as the same kind of change
            for s in scope.subscopes.values():
                s_change = scope_change.new_value(s)
                yield from recurse(s_change, parent_changes)