├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs ├── curate_dataset.md └── dev_note.md ├── pyproject.toml ├── repoqa ├── __init__.py ├── compute_score.py ├── data.py ├── metric.py ├── provider │ ├── __init__.py │ ├── anthropic.py │ ├── base.py │ ├── google.py │ ├── hf.py │ ├── openai.py │ ├── request │ │ ├── __init__.py │ │ ├── anthropic.py │ │ ├── google.py │ │ └── openai.py │ └── vllm.py ├── search_needle_function.py └── utility.py ├── requirements.txt ├── results ├── .gitignore └── README.md ├── scripts ├── cherrypick │ ├── README.md │ └── lists.json ├── curate │ ├── dataset_ensemble_clone.py │ ├── dataset_ensemble_gh_api.py │ ├── dep_analysis │ │ ├── cpp.py │ │ ├── data │ │ │ └── .gitignore │ │ ├── go-analysis │ │ │ └── dependency_analysis.go │ │ ├── go.py │ │ ├── java-analysis │ │ │ ├── dependency-reduced-pom.xml │ │ │ ├── java-lib │ │ │ │ └── java-analysis-1.0-SNAPSHOT.jar │ │ │ ├── pom.xml │ │ │ └── src │ │ │ │ └── main │ │ │ │ └── java │ │ │ │ └── edu │ │ │ │ └── cs │ │ │ │ └── illinois │ │ │ │ └── repoqa │ │ │ │ └── DepAnalyze.java │ │ ├── java.py │ │ ├── python.py │ │ ├── rust.py │ │ └── typescript.py │ ├── function_analysis.py │ ├── github_fetch.py │ ├── merge_annotation.py │ ├── merge_dep.py │ ├── needle_annotation.py │ ├── needle_selection.py │ ├── obfuscate_nl.py │ ├── requirements.txt │ └── utility.py ├── demos │ └── model_request_oai.py ├── dev │ └── license-hdr.txt ├── eval │ └── recompute_all_scores.sh └── misc │ ├── estimate_max_char.py │ └── repo_token_size.py └── setup.cfg /.gitignore: -------------------------------------------------------------------------------- 1 | # Customized 2 | .vscode/ 3 | *.jsonl 4 | repoqa-*.json 5 | *.bak 6 | repoqa/_version.py 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | # nuclear option because steven uses PyCharm. 
168 | .idea/ 169 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | repos: 6 | - repo: https://github.com/pycqa/isort 7 | rev: 5.12.0 8 | hooks: 9 | - id: isort 10 | name: isort (python) 11 | args: ["--profile", "black"] 12 | - repo: https://github.com/psf/black 13 | rev: 22.6.0 14 | hooks: 15 | - id: black 16 | - repo: https://github.com/pre-commit/pre-commit-hooks 17 | rev: v4.3.0 18 | hooks: 19 | - id: check-yaml 20 | - id: end-of-file-fixer 21 | - id: trailing-whitespace 22 | - id: check-added-large-files 23 | args: ["--maxkb=32"] 24 | - id: debug-statements 25 | - repo: https://github.com/Lucas-C/pre-commit-hooks 26 | rev: v1.5.4 27 | hooks: 28 | - id: forbid-tabs 29 | - id: remove-tabs 30 | - id: insert-license 31 | files: \.(sh|yaml|yml|py)$ 32 | args: ["--license-filepath=scripts/dev/license-hdr.txt", "--use-current-year"] 33 | exclude: (?x)^( 34 | repo/.* 35 | )$ 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | ------------------------------------------------------------------------------- 204 | The files under "evalplus/eval/" additionally complies with the MIT License for 205 | being built on OpenAI's HumanEval work. 
206 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RepoQA: Evaluating Long-Context Code Understanding 2 | 3 | [![](https://img.shields.io/badge/arXiv-2406.06025-b31b1b.svg?style=for-the-badge)](https://arxiv.org/abs/2406.06025) 4 | [![](https://img.shields.io/pypi/v/repoqa?style=for-the-badge&labelColor=black)](https://pypi.org/project/repoqa/) 5 | 6 | 🏠 Homepage: https://evalplus.github.io/repoqa.html 7 | 8 | ## 🚀 Installation 9 | 10 | ```bash 11 | # without vLLM (can run openai, anthropic, and huggingface backends) 12 | pip install --upgrade repoqa 13 | # To enable vLLM 14 | pip install --upgrade "repoqa[vllm]" 15 | ``` 16 | 17 |
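After installation, a quick sanity check is to import the package and print its version. This relies only on `repoqa/__init__.py` (included later in this repository), which exposes `__version__` and falls back to `"local-dev"` for source checkouts:

```python
# Minimal post-install check; repoqa/__init__.py defines __version__
# (falling back to "local-dev" when package metadata is unavailable).
import repoqa

print(repoqa.__version__)
```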
**⏬ Install nightly version** 18 |
19 | 20 | ```bash 21 | pip install --upgrade "git+https://github.com/evalplus/repoqa.git" # without vLLM 22 | pip install --upgrade "repoqa[vllm] @ git+https://github.com/evalplus/repoqa@main" # with vLLM 23 | ``` 24 | 25 |
26 |
27 | 28 |
**⏬ Using RepoQA as a local repo?** 29 |
30 | 31 | ```bash 32 | git clone https://github.com/evalplus/repoqa.git 33 | cd repoqa 34 | export PYTHONPATH=$PYTHONPATH:$(pwd) 35 | pip install -r requirements.txt 36 | ``` 37 | 38 |
39 |
40 | 41 | ## 🏁 Search Needle Function (SNF) 42 | 43 | Search Needle Function is the first and foundational RepoQA task; it exercises an LLM's ability in **long-context code understanding and retrieval**. 44 | Its real-life counterpart is precise code search driven by a natural-language function description. 45 | 46 |
**🔎 More dataset details** 47 |
48 | 49 | > [!Note] 50 | > 51 | > SNF includes 500 tests (5 programming languages x 10 repos x 10 needle functions) where an LLM is given: 52 | > 53 | > 1. A large code context sorted by file dependency 54 | > 2. A natural-language description of the needle function that does not reveal keywords such as the function name 55 | > 3. An instruction to retrieve the described function 56 | > 57 | > The evaluator passes a test if, among all candidate functions (parsed with `tree-sitter`), the model's retrieved function 58 | > is syntactically closest to the ground truth and the similarity exceeds a user-defined threshold (0.8 by default). 59 | 60 |
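To make the pass criterion concrete, below is a minimal sketch of the scoring rule, assuming a BLEU-style token similarity as in `repoqa/metric.py` (included later in this repository). The candidate data layout here is simplified for illustration; the real evaluator lives in `repoqa/compute_score.py`.

```python
# Hedged sketch of the SNF pass criterion; the candidate dict is illustrative only.
import re

from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu


def similarity(candidate: str, reference: str) -> float:
    # Whitespace-token BLEU, mirroring compute_function_similarity in repoqa/metric.py.
    cand_tokens = re.split(r"\s+", candidate.strip())
    ref_tokens = re.split(r"\s+", reference.strip())
    chencherry = SmoothingFunction()
    return sentence_bleu([ref_tokens], cand_tokens, smoothing_function=chencherry.method4)


def passes(model_output: str, candidates: dict, ground_truth: str, threshold: float = 0.8) -> bool:
    # The answer counts only if the ground-truth needle is the single best match
    # *and* its similarity clears the threshold.
    best_name, best_score = None, 0.0
    for name, body in candidates.items():
        score = similarity(model_output, body)
        if score > best_score:
            best_name, best_score = name, score
    return best_name == ground_truth and best_score >= threshold
```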
61 |
62 | 63 | You can run the SNF evaluation using various backends: 64 | 65 | ### OpenAI Compatible Servers 66 | 67 | ```bash 68 | repoqa.search_needle_function --model "gpt-4-turbo" --backend openai 69 | # 💡 If you use an OpenAI-API-compatible server such as a vLLM server: 70 | # repoqa.search_needle_function --base-url "http://localhost:[PORT]/v1" \ 71 | # --model "Qwen/CodeQwen1.5-7B-Chat" --backend openai 72 | ``` 73 | 74 | ### Anthropic Compatible Servers 75 | 76 | ```bash 77 | repoqa.search_needle_function --model "claude-3-haiku-20240307" --backend anthropic 78 | ``` 79 | 80 | ### vLLM 81 | 82 | ```bash 83 | repoqa.search_needle_function --model "Qwen/CodeQwen1.5-7B-Chat" --backend vllm 84 | ``` 85 | 86 |
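The hosted backends read their credentials from environment variables; the variable names below come from the provider implementations under `repoqa/provider/` included later in this repository (the OpenAI provider falls back to a dummy key, which is what self-hosted OpenAI-compatible servers expect). A quick check before a long run:

```python
# Check the credential variables used by repoqa/provider/{openai,anthropic,google}.py.
import os

for var in ("OPENAI_API_KEY", "ANTHROPIC_KEY", "GOOGLE_API_KEY"):
    print(f"{var}: {'set' if os.getenv(var) else 'NOT set'}")
```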
**🔎 Context extension for small-ctx models** 87 |
88 | 89 | > There are two ways to extend a model's context window at inference time: 90 | > 91 | > 1. **Direct Extension**: Edit `max_position_embeddings` in the model's `config.json` (e.g., `hub/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/[hash]/config.json`) to something like `22528`. 92 | > 2. **[Dynamic RoPE Scaling](https://blog.eleuther.ai/yarn/#dynamic-scaling)**: 93 | > To extend `Meta-Llama-3-8B-Instruct` from 8k to 32k (4x), edit the `config.json`: 94 | > 95 | > `"rope_scaling": {"type": "dynamic", "factor": 4.0}` 96 | > 97 | > Note: This works for vLLM `<0.4.3` and HuggingFace transformers. RepoQA automatically configures dynamic RoPE for vLLM `>= 0.4.3`. 98 | 99 |
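If you prefer not to hand-edit the cached `config.json`, the same change can be scripted. This is only a sketch: it assumes the model's architecture accepts a `rope_scaling` entry and that you point your backend at the patched copy afterwards.

```python
# Sketch: write a locally patched config with 4x dynamic RoPE scaling.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
config.rope_scaling = {"type": "dynamic", "factor": 4.0}  # assumption: the architecture supports RoPE scaling
config.save_pretrained("./llama3-8b-instruct-32k-config")  # load the model with this patched config directory
```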
100 |
101 | 102 | > [!Note] 103 | > 104 | > Reference evaluation time: 105 | > 106 | > - Llama3-8B-Instruct: 45 minutes on 2xA6000 (PCIe NVLink) 107 | > - Llama3-70B-Instruct: 100 minutes on 4xA100 (PCIe NVLink) 108 | 109 | ### HuggingFace transformers 110 | 111 | ```bash 112 | repoqa.search_needle_function --model "Qwen/CodeQwen1.5-7B-Chat" --backend hf --trust-remote-code 113 | ``` 114 | 115 | > [!Tip] 116 | > 117 | > Installing [flash-attn](https://github.com/Dao-AILab/flash-attention) and 118 | > additionally setting `--attn-implementation "flash_attention_2"` can significantly 119 | > lower the memory requirements. 120 | 121 |
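For reference, the `--attn-implementation` flag maps onto the standard `transformers` loading path (this mirrors `repoqa/provider/hf.py` later in this repository). A minimal sketch, assuming a recent `transformers`, an installed `flash-attn`, and a CUDA GPU:

```python
# What --backend hf --attn-implementation "flash_attention_2" boils down to.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/CodeQwen1.5-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",  # requires flash-attn to be installed
    torch_dtype="auto",
).cuda()
```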
**🔨 Having trouble installing `flash-attn`?** 122 |
123 | 124 | > If you have trouble with `pip install flash-attn --no-build-isolation`, 125 | > you can try to directly use [pre-built wheels](https://github.com/Dao-AILab/flash-attention/releases): 126 | > 127 | > ```shell 128 | > export FLASH_ATTN_VER=2.5.8 # check latest version at https://github.com/Dao-AILab/flash-attention/releases 129 | > export CUDA_VER="cu122" # check available ones at https://github.com/Dao-AILab/flash-attention/releases 130 | > export TORCH_VER=$(python -c "import torch; print('.'.join(torch.__version__.split('.')[:2]))") 131 | > export PY_VER=$(python -c "import platform; print(''.join(platform.python_version().split('.')[:2]))") 132 | > export OS_ARCH=$(python -c "import platform; print(f'{platform.system().lower()}_{platform.machine()}')") 133 | > 134 | > export WHEEL=flash_attn-${FLASH_ATTN_VER}+${CUDA_VER}torch${TORCH_VER}cxx11abiFALSE-cp${PY_VER}-cp${PY_VER}-${OS_ARCH}.whl 135 | > wget https://github.com/Dao-AILab/flash-attention/releases/download/v${FLASH_ATTN_VER}/${WHEEL} 136 | > pip install ${WHEEL} 137 | > ``` 138 | 139 |
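Once a wheel is installed, a quick import check confirms that the build matches your Python/CUDA/PyTorch combination (assuming the package is importable as `flash_attn`):

```python
# Verify the pre-built flash-attn wheel actually loads in this environment.
import flash_attn

print(flash_attn.__version__)
```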
140 |
141 | 142 | ### Google Generative AI API (Gemini) 143 | 144 | ```bash 145 | repoqa.search_needle_function --model "gemini-1.5-pro-latest" --backend google 146 | ``` 147 | 148 | ### CLI Usage 149 | 150 | - **Input**: 151 | - `--model`: Hugging-Face model ID, such as `ise-uiuc/Magicoder-S-DS-6.7B` 152 | - `--backend`: `vllm` (default) or `openai` 153 | - `--base-url`: OpenAI API base URL 154 | - `--code-context-size` (default: 16384): #tokens (by DeepSeekCoder tokenizer) of repository context 155 | - `--caching` (default: True): accelerate subsequent runs by caching preprocessing; `--nocaching` to disable 156 | - `--max-new-tokens` (default: 1024): Maximum #new tokens to generate 157 | - `--system-message` (default: None): system message (note it's not supported by some models) 158 | - `--tensor-parallel-size`: #GPUS for doing tensor parallelism (only for vLLM) 159 | - `--languages` (default: None): List of languages to evaluate (None means all) 160 | - `--result-dir` (default: "results"): Directory to save the model outputs and evaluation results 161 | - `--clean-ctx-comments` (default: "none"): Clean context comments with padding ("positional_padding") or no padding ("no_padding") 162 | - `--eval-ignore-comments` (default: False): During evaluation, ignore groundtruth and model comments 163 | - `--trust-remote-code` (default: False): allow remote code (for HuggingFace transformers and vLLM) 164 | - `--attn-implementation` (default: None): Use "flash_attention_2" if your HF hits OOM 165 | - **Output**: 166 | - `results/ntoken_{code-context-size}/{model}.jsonl`: Model generated outputs 167 | - `results/ntoken_{code-context-size}/{model}-SCORE.json`: Evaluation results 168 | 169 | ### Compute Scores 170 | 171 | By default, the `repoqa.search_needle_function` command will evaluate model outputs and compute scores after text generation. 172 | However, you can also separately compute scores using the following command: 173 | 174 | ```shell 175 | repoqa.compute_score --model-output-path={model-output}.jsonl 176 | ``` 177 | 178 | > [!Tip] 179 | > 180 | > - **Input**: Path to the model generated outputs. 181 | > - **Output**: The evaluation scores would be stored in `{model-output}-SCORES.json` 182 | 183 | ## 📚 Read More 184 | 185 | - [RepoQA Homepage](https://evalplus.github.io/repoqa.html) 186 | - [RepoQA Dataset Curation](docs/curate_dataset.md) 187 | - [RepoQA Development Notes](docs/dev_note.md) 188 | 189 | ## Citation 190 | 191 | ```bibtex 192 | @article{repoqa, 193 | title = {RepoQA: Evaluating Long Context Code Understanding}, 194 | author = {Liu, Jiawei and Tian, Jia Le and Daita, Vijay and Wei, Yuxiang and Ding, Yifeng and Wang, Yuhan Katherine and Yang, Jun and Zhang, Lingming}, 195 | year = {2024}, 196 | journal = {arXiv preprint arXiv:2406.06025}, 197 | } 198 | ``` 199 | -------------------------------------------------------------------------------- /docs/curate_dataset.md: -------------------------------------------------------------------------------- 1 | # RepoQA Dataset Curation 2 | 3 | ## Search Needle Functions 4 | 5 | ### Step 1: Cherry-pick repositories 6 | 7 | See [scripts/cherrypick/README.md](cherrypick/README.md) for more information. 8 | 9 | 10 | > [!Tip] 11 | > 12 | > **Output**: Extend `scripts/cherrypick/lists.json` for a programming language. 
13 | 14 | 15 | ### Step 2: Extract repo content 16 | 17 | ```shell 18 | python scripts/curate/dataset_ensemble_clone.py 19 | ``` 20 | 21 | > [!Tip] 22 | > 23 | > **Output**: `repoqa-{datetime}.json` by adding a `"content"` field (path to content) for each repo. 24 | 25 | 26 | ### Step 3: Dependency analysis 27 | 28 | Check [scripts/curate/dep_analysis](scripts/curate/dep_analysis) for more information. 29 | 30 | ```shell 31 | python scripts/curate/dep_analysis/{language}.py # python 32 | ``` 33 | 34 | > [!Tip] 35 | > 36 | > **Output**: `{language}.json` (e.g., `python.json`) with a list of items of `{"repo": ..., "commit_sha": ..., "dependency": ...}` field where the dependency is a map of path to imported paths. 37 | 38 | > [!Note] 39 | > 40 | > The `{language}.json` should be uploaded as a release. 41 | > 42 | > To fetch the release, go to `scripts/curate/dep_analysis/data` and run `gh release download dependency --pattern "*.json" --clobber`. 43 | 44 | 45 | ### Step 4: Merge step 2 and step 3 46 | 47 | ```shell 48 | python scripts/curate/merge_dep.py --dataset-path repoqa-{datetime}.json 49 | ``` 50 | 51 | > [!Tip] 52 | > 53 | > **Input**: Download dependency files in to `scripts/curate/dep_analysis/data`. 54 | > 55 | > **Output**: Update `repoqa-{datetime}.json` by adding a `"dependency"` field for each repository. 56 | 57 | 58 | ### Step 5: Function collection with TreeSitter 59 | 60 | ```shell 61 | # collect functions (in-place) 62 | python scripts/curate/function_analysis.py --dataset-path repoqa-{datetime}.json 63 | # select needles (in-place) 64 | python scripts/curate/needle_selection.py --dataset-path repoqa-{datetime}.json 65 | ``` 66 | 67 | > [!Tip] 68 | > 69 | > **Output**: `--dataset-path` (in-place) by adding a `"functions"` field (path to a list function information) for each repo. 70 | 71 | 72 | ### Step 6: Annotate each function with description to make a final dataset 73 | 74 | ```shell 75 | python scripts/curate/needle_annotation.py --dataset-path repoqa-{datetime}.json 76 | ``` 77 | 78 | > [!Tip] 79 | > 80 | > You need to set `OPENAI_API_KEY` in the environment variable to run GPT-4. But you can enable `--use-batch-api` to save some costs. 81 | > 82 | > **Output**: `--output-desc-path` is a seperate json file specifying the function annotations with its sources. 83 | 84 | 85 | ### Step 7: Merge needle description to the final dataset 86 | 87 | ```shell 88 | python scripts/curate/merge_annotation.py --dataset-path repoqa-{datetime}.json --annotation-path {output-desc-path}.jsonl 89 | ``` 90 | 91 | > [!Tip] 92 | > 93 | > **Output**: `--dataset-path` (in-place) by adding a `"description"` field for each needle function. 
94 | -------------------------------------------------------------------------------- /docs/dev_note.md: -------------------------------------------------------------------------------- 1 | # RepoQA Development Notes 2 | 3 | ## DEV Structure 4 | 5 | - `repo`: entrypoint for working repositories 6 | - `repoqa`: source code for the RepoQA evaluation library 7 | - `scripts`: scripts for maintaining the repository and other utilities 8 | - `dev`: scripts for CI/CD and repository maintenance 9 | - `curate`: code for dataset curation 10 | - `dep_analysis`: dependency analysis for different programming languages 11 | - `cherrypick`: cherry-picked repositories for evaluation 12 | - `demos`: demos to quickly use some utility functions such as requesting LLMs 13 | 14 | ## Development Beginner Notice 15 | 16 | ### After clone 17 | 18 | ```shell 19 | pip install pre-commit 20 | pre-commit install 21 | pip install -r requirements.txt 22 | pip install -r scripts/curate/requirements.txt 23 | ``` 24 | 25 | 26 | ### Import errors? 27 | 28 | ```shell 29 | # Go to the root path of RepoQA 30 | export PYTHONPATH=$PYTHONPATH:$(pwd) 31 | ``` 32 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.setuptools_scm] 6 | write_to = "repoqa/_version.py" 7 | version_scheme = "release-branch-semver" 8 | local_scheme = "no-local-version" 9 | -------------------------------------------------------------------------------- /repoqa/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | try: 6 | from repoqa._version import __version__, __version_tuple__ 7 | except ImportError: 8 | __version__ = "local-dev" 9 | -------------------------------------------------------------------------------- /repoqa/compute_score.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import itertools 6 | import json 7 | import os 8 | import re 9 | from collections import defaultdict 10 | from datetime import datetime 11 | from enum import Enum 12 | from pathlib import Path 13 | from typing import Dict, List, Tuple, Union 14 | 15 | import numpy as np 16 | import tempdir 17 | from rich.console import Console 18 | from rich.table import Table 19 | from transformers import AutoConfig 20 | from tree_sitter_languages import get_language, get_parser 21 | 22 | from repoqa.data import get_repoqa_data 23 | from repoqa.metric import compute_function_similarity 24 | from repoqa.utility import COMMENT_QUERY, FUNCTION_QUERY, progress 25 | 26 | LANGUAGES = list(FUNCTION_QUERY.keys()) 27 | THRESHOLDS = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] 28 | 29 | 30 | class Result(Enum): 31 | BEST_MATCH = "best_match" 32 | FAIL_MATCH = "fail_match" 33 | 34 | 35 | # unbiased estimator from https://github.com/openai/human-eval 36 | def estimate_pass_at_k( 37 | num_samples: Union[int, List[int], np.ndarray], 38 | num_correct: Union[List[int], np.ndarray], 39 | k: int, 40 | ) -> np.ndarray: 41 | """ 42 | Estimates pass@k of each problem and returns them in an array. 
43 | """ 44 | 45 | def estimator(n: int, c: int, k: int) -> float: 46 | """ 47 | Calculates 1 - comb(n - c, k) / comb(n, k). 48 | """ 49 | if n - c < k: 50 | return 1.0 51 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 52 | 53 | if isinstance(num_samples, int): 54 | num_samples_it = itertools.repeat(num_samples, len(num_correct)) 55 | else: 56 | assert len(num_samples) == len(num_correct) 57 | num_samples_it = iter(num_samples) 58 | 59 | return np.array( 60 | [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)] 61 | ) 62 | 63 | 64 | def remove_comments(source_code: str, lang: str) -> str: 65 | source_bytes = bytes(source_code, "utf8") 66 | parser = get_parser(lang) 67 | tree = parser.parse(source_bytes) 68 | root_node = tree.root_node 69 | 70 | # Remove comments from source code 71 | capture_list = [] 72 | for query_str in COMMENT_QUERY[lang]: 73 | comment_query = get_language(lang).query(query_str) 74 | capture_list += comment_query.captures(root_node) 75 | 76 | capture_list.sort(key=lambda cap: cap[0].start_byte, reverse=True) 77 | 78 | for node, _ in capture_list: 79 | source_bytes = source_bytes[: node.start_byte] + source_bytes[node.end_byte :] 80 | 81 | return source_bytes.decode("utf-8") 82 | 83 | 84 | def sanitize_output(model_output: str, lang: str) -> str: 85 | model_output = model_output.strip() 86 | search_pattern = r"^```(?:\w+)?\s*\n(.*?)(?=^```)```" 87 | code_blocks = re.findall(search_pattern, model_output, re.DOTALL | re.MULTILINE) 88 | 89 | parser = get_parser(lang) 90 | fn_query = get_language(lang).query(FUNCTION_QUERY[lang]) 91 | 92 | # If not code blocks found, simply return model output 93 | if not code_blocks: 94 | return model_output 95 | 96 | processed_blocks = [] 97 | for block in code_blocks: 98 | processed_blocks.append(block) 99 | 100 | # Try to use tree-sitter to parse if possible 101 | try: 102 | block_bytes = bytes(block, "utf8") 103 | tree = parser.parse(block_bytes) 104 | for capture in fn_query.captures(tree.root_node): 105 | node, _ = capture 106 | function_content = block_bytes[node.start_byte : node.end_byte] 107 | return function_content.decode("utf8") 108 | except: 109 | pass 110 | 111 | # no valid functions found by tree-sitter approach return first block 112 | return processed_blocks[0] 113 | 114 | 115 | def print_result_table(model_name, pass_results): 116 | # Printing scores in a table 117 | table = Table(title=f"Scores (%) of {model_name} at different thresholds") 118 | table.add_column("Threshold", justify="center", style="bold magenta") 119 | for threshold in THRESHOLDS: 120 | table.add_column(f"{threshold}", justify="center") 121 | 122 | # Prepare data to determine the maximum values for each threshold 123 | threshold_scores = {threshold: [] for threshold in THRESHOLDS} 124 | for lang_results in pass_results.values(): 125 | for thresh, value in lang_results.items(): 126 | threshold_scores[thresh].append(value["pass@1"]) 127 | 128 | # Calculate the maximum score for each threshold 129 | max_scores = { 130 | threshold: max(scores) for threshold, scores in threshold_scores.items() 131 | } 132 | min_scores = { 133 | threshold: min(scores) for threshold, scores in threshold_scores.items() 134 | } 135 | 136 | # Fill the table rows 137 | for language, lang_results in pass_results.items(): 138 | row = [("⭐" if language == "all" else "") + language] 139 | for threshold, value in lang_results.items(): 140 | score = value["pass@1"] 141 | formatted_score = f"{100 * score:.1f}" 142 | if max_scores[threshold] - score 
< 0.01: 143 | formatted_score = f"[bold green]{formatted_score}[/]" 144 | elif score - min_scores[threshold] < 0.01: 145 | formatted_score = f"[bold red]{formatted_score}[/]" 146 | row.append(formatted_score) 147 | if language == "all": 148 | row = [f"[bold yellow]{r}[/]" for r in row] 149 | table.add_row(*row) 150 | 151 | Console().print(table) 152 | 153 | 154 | def needle_evaluator( 155 | model_output: str, 156 | ground_truth: str, 157 | repo_info: Dict, 158 | lang: str, 159 | ignore_comments: bool, 160 | ) -> Tuple[Result, str, float]: 161 | contents = repo_info["content"] 162 | needles = repo_info["needles"] 163 | 164 | best_target = None 165 | best_similarity = 0 166 | sanitized_output = sanitize_output(model_output, lang) 167 | if ignore_comments: 168 | sanitized_output = remove_comments(sanitized_output, lang) 169 | for needle in needles: 170 | current_path = needle["path"] 171 | current_name = needle["name"] 172 | current_func = "\n".join( 173 | contents[current_path].split("\n")[ 174 | needle["start_line"] : needle["end_line"] 175 | ] 176 | ) 177 | if ignore_comments: 178 | current_func = remove_comments(current_func, lang) 179 | 180 | current_similarity = compute_function_similarity(sanitized_output, current_func) 181 | if current_similarity > best_similarity: 182 | best_similarity = current_similarity 183 | best_target = current_name 184 | 185 | if best_target == ground_truth: 186 | verdict = Result.BEST_MATCH 187 | else: 188 | verdict = Result.FAIL_MATCH 189 | return verdict, best_target, best_similarity 190 | 191 | 192 | def _get_repo(lang_data: Dict, repo_name: str) -> Dict: 193 | for repo in lang_data: 194 | if repo["repo"] == repo_name: 195 | return repo 196 | 197 | 198 | def compute_language_results(evaluation_result: Dict, all_results: Dict) -> None: 199 | for language, lang_results in evaluation_result.items(): 200 | current_result = {} 201 | total = np.array([1 for _ in lang_results]) 202 | 203 | for threshold in THRESHOLDS: 204 | correct_result = [] 205 | for res in lang_results: 206 | bc = 0 207 | if res["is_best_similar"] and res["best_similar_score"] >= threshold: 208 | bc = 1 209 | correct_result.append(bc) 210 | correct_result = np.array(correct_result) 211 | 212 | pass_at_k = { 213 | f"pass@{k}": estimate_pass_at_k(total, correct_result, k).mean() 214 | for k in [1, 10, 100] 215 | if total.min() >= k 216 | } 217 | current_result[threshold] = pass_at_k 218 | all_results[language] = current_result 219 | 220 | 221 | def fetch_hf_context(model_name: str) -> str: 222 | # Retrieved from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1073 223 | possible_keys = [ 224 | # OPT 225 | "max_position_embeddings", 226 | # GPT-2 227 | "n_positions", 228 | # MPT 229 | "max_seq_len", 230 | # ChatGLM2 231 | "seq_length", 232 | # Command-R 233 | "model_max_length", 234 | # Others 235 | "max_sequence_length", 236 | "max_seq_length", 237 | "seq_len", 238 | ] 239 | try: 240 | with tempdir.TempDir() as temp_dir: 241 | config = AutoConfig.from_pretrained( 242 | model_name, 243 | cache_dir=temp_dir, 244 | force_download=True, 245 | trust_remote_code=True, 246 | ).to_dict() 247 | longest_context = 0 248 | for key in possible_keys: 249 | if key in config: 250 | longest_context = max(config[key], longest_context) 251 | if not (longest_context): 252 | return "N/A" 253 | return str(int(longest_context / 1024)) + "k" 254 | except Exception as err: 255 | print(f"fetching failed... 
Reason:\n{err}") 256 | return "N/A" 257 | 258 | 259 | def compute_score( 260 | model_name: str, dataset: Dict, model_output: List[Dict], ignore_comments: bool 261 | ) -> Dict: 262 | evaluation_result = defaultdict(list) 263 | with progress(f"Scoring {model_name}") as pbar: 264 | for result in pbar.track(model_output): 265 | lang = result["language"] 266 | repo_name = result["repo"] 267 | model_outputs = result["output"] 268 | ground_truth = result["name"] 269 | repo_info = _get_repo(dataset[lang], repo_name) 270 | 271 | model_output = model_outputs[0] 272 | verdict, best_target, best_similarity = needle_evaluator( 273 | model_output, ground_truth, repo_info, lang, ignore_comments 274 | ) 275 | 276 | is_best_similar = False 277 | if verdict == Result.BEST_MATCH: 278 | is_best_similar = True 279 | 280 | current_task = { 281 | "repo": repo_name, 282 | "name": ground_truth, 283 | "needle_position": result["position_ratio"], 284 | "is_best_similar": is_best_similar, 285 | "best_similar_score": best_similarity, 286 | "best_target": best_target, 287 | "position": { 288 | "token_start": result["needle_token_start"], 289 | "token_end": result["needle_token_end"], 290 | }, 291 | } 292 | evaluation_result[lang].append(current_task) 293 | 294 | # Calculate pass@k 295 | pass_results = {} 296 | 297 | all_langs = [] 298 | for lang in evaluation_result: 299 | all_langs += evaluation_result[lang] 300 | total = np.array([1 for _ in all_langs]) 301 | 302 | pass_results["all"] = {} 303 | for threshold in THRESHOLDS: 304 | correct_result = [] 305 | for res in all_langs: 306 | bc = 0 307 | if res["is_best_similar"] and res["best_similar_score"] >= threshold: 308 | bc = 1 309 | correct_result.append(bc) 310 | correct_result = np.array(correct_result) 311 | pass_at_k = { 312 | f"pass@{k}": estimate_pass_at_k(total, correct_result, k).mean() 313 | for k in [1, 10, 100] 314 | if total.min() >= k 315 | } 316 | pass_results["all"][threshold] = pass_at_k 317 | 318 | compute_language_results(evaluation_result, pass_results) 319 | print_result_table(model_name, pass_results) 320 | 321 | output_json = {} 322 | model_json = {} 323 | model_json["eval_date"] = str(datetime.now()) 324 | 325 | # hardcode paid models 326 | if "/" in model_name: 327 | if model_name.startswith("bigcode/starcoder2"): 328 | train_context = "16k" 329 | else: 330 | train_context = fetch_hf_context(model_name) 331 | elif model_name.startswith("gpt-4-turbo") or model_name.startswith("gpt-4o-"): 332 | train_context = "128k" 333 | elif model_name.startswith("gpt-3.5-"): 334 | train_context = "16k" 335 | elif model_name.startswith("gemini-1.5-pro") or model_name.startswith( 336 | "gemini-1.5-flash" 337 | ): 338 | train_context = "1000k" 339 | elif model_name.startswith("gemini-1.0-pro"): 340 | train_context = "32k" 341 | elif model_name.startswith("claude-3-"): 342 | train_context = "200k" 343 | else: 344 | train_context = "N/A" 345 | model_json["train_size"] = train_context 346 | model_json["scores"] = pass_results 347 | model_json["results"] = evaluation_result 348 | 349 | output_json[model_name] = model_json 350 | 351 | return output_json 352 | 353 | 354 | def get_model_name(output_path: str) -> str: 355 | file_name = Path(output_path).stem 356 | segments = file_name.split("_") 357 | output_name = "" 358 | for segment in segments: 359 | if segment == "slash": 360 | output_name += "/" 361 | else: 362 | output_name += segment 363 | return output_name 364 | 365 | 366 | def save_json(output_json, result_path) -> None: 367 | if os.path.isfile(result_path): 
368 | decision = "" 369 | while decision.lower() not in ["y", "n"]: 370 | print(f"{result_path} already exists. Press [Y/N] to overwrite or exit...") 371 | decision = input() 372 | 373 | if decision.lower() == "y": 374 | # mv the file to a backup 375 | new_path = result_path + ".bak" 376 | while os.path.isfile(new_path): 377 | new_path += ".bak" 378 | os.rename(result_path, new_path) 379 | print(f"Backup {result_path} to {new_path}") 380 | 381 | if not os.path.isfile(result_path): 382 | with open(result_path, "w") as f: 383 | json.dump(output_json, f) 384 | 385 | 386 | def compute_main( 387 | model_output_path: str, ignore_comments: bool = False, dataset_path: str = None 388 | ): 389 | if dataset_path is None: 390 | dataset = get_repoqa_data() 391 | else: 392 | with open(dataset_path, "r") as dataset_f: 393 | dataset = json.load(dataset_f) 394 | 395 | model_outputs = [] 396 | with open(model_output_path, "r") as output_f: 397 | for line in output_f: 398 | model_outputs.append(json.loads(line)) 399 | 400 | file_base, _ = os.path.splitext(model_output_path) 401 | result_path = file_base + "-SCORES.json" 402 | model_name = get_model_name(model_output_path) 403 | output_json = compute_score(model_name, dataset, model_outputs, ignore_comments) 404 | save_json(output_json, result_path) 405 | 406 | 407 | def main(): 408 | from fire import Fire 409 | 410 | Fire(compute_main) 411 | 412 | 413 | if __name__ == "__main__": 414 | main() 415 | -------------------------------------------------------------------------------- /repoqa/data.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import gzip 6 | import json 7 | import os 8 | 9 | import tempdir 10 | import wget 11 | from appdirs import user_cache_dir 12 | 13 | CACHE_DIR = user_cache_dir("repoqa") 14 | 15 | REPOQA_DATA_OVERRIDE_PATH = os.getenv("REPOQA_DATA_OVERRIDE_PATH", None) 16 | REPOQA_DATA_VERSION = os.getenv("REPOQA_DATA_VERSION", "2024-06-23") 17 | 18 | 19 | def _get_repoqa_data_ready_path() -> str: 20 | if REPOQA_DATA_OVERRIDE_PATH: 21 | assert os.path.exists( 22 | REPOQA_DATA_OVERRIDE_PATH 23 | ), f"File not found: {REPOQA_DATA_OVERRIDE_PATH}" 24 | return REPOQA_DATA_OVERRIDE_PATH 25 | 26 | gzip_url = f"https://github.com/evalplus/repoqa_release/releases/download/{REPOQA_DATA_VERSION}/repoqa-{REPOQA_DATA_VERSION}.json.gz" 27 | cache_path = os.path.join(CACHE_DIR, f"repoqa-{REPOQA_DATA_VERSION}.json") 28 | # Check if human eval file exists in CACHE_DIR 29 | if not os.path.exists(cache_path): 30 | # Install HumanEval dataset and parse as json 31 | print(f"Downloading dataset from {gzip_url}") 32 | with tempdir.TempDir() as tmpdir: 33 | gzip_path = os.path.join(tmpdir, f"data.json.gz") 34 | wget.download(gzip_url, gzip_path) 35 | 36 | with gzip.open(gzip_path, "rb") as f: 37 | repoqa_data = f.read().decode("utf-8") 38 | 39 | # create CACHE_DIR if not exists 40 | os.makedirs(CACHE_DIR, exist_ok=True) 41 | # Write the original human eval file to CACHE_DIR 42 | with open(cache_path, "w") as f: 43 | f.write(repoqa_data) 44 | 45 | return cache_path 46 | 47 | 48 | def get_repoqa_data(): 49 | with open(_get_repoqa_data_ready_path(), "r") as f: 50 | return json.load(f) 51 | -------------------------------------------------------------------------------- /repoqa/metric.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 
3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import re 6 | 7 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu 8 | 9 | 10 | def compute_function_similarity( 11 | candidate_function: str, reference_function: str 12 | ) -> float: 13 | candidate_tokens = [item for item in re.split("\s+", candidate_function.strip())] 14 | 15 | reference_tokens = [item for item in re.split("\s+", reference_function.strip())] 16 | 17 | chencherry = SmoothingFunction() 18 | 19 | return sentence_bleu( 20 | [reference_tokens], candidate_tokens, smoothing_function=chencherry.method4 21 | ) 22 | -------------------------------------------------------------------------------- /repoqa/provider/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from repoqa.provider.base import BaseProvider 6 | -------------------------------------------------------------------------------- /repoqa/provider/anthropic.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import os 6 | from typing import List 7 | 8 | from anthropic import Client 9 | 10 | from repoqa.provider.base import BaseProvider 11 | from repoqa.provider.request.anthropic import make_auto_request 12 | 13 | 14 | class AnthropicProvider(BaseProvider): 15 | def __init__(self, model): 16 | self.model = model 17 | self.client = Client(api_key=os.getenv("ANTHROPIC_KEY")) 18 | 19 | def generate_reply( 20 | self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None 21 | ) -> List[str]: 22 | assert temperature != 0 or n == 1, "n must be 1 when temperature is 0" 23 | replies = [] 24 | for _ in range(n): 25 | reply = make_auto_request( 26 | self.client, 27 | message=question, 28 | model=self.model, 29 | temperature=temperature, 30 | max_tokens=max_tokens, 31 | system_msg=system_msg, 32 | ) 33 | replies.append(reply.content[0].text) 34 | 35 | return replies 36 | -------------------------------------------------------------------------------- /repoqa/provider/base.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from abc import ABC, abstractmethod 6 | from typing import List 7 | 8 | 9 | class BaseProvider(ABC): 10 | @abstractmethod 11 | def generate_reply( 12 | self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None 13 | ) -> List[str]: 14 | ... 
15 | -------------------------------------------------------------------------------- /repoqa/provider/google.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import os 6 | from typing import List 7 | 8 | import google.generativeai as genai 9 | 10 | from repoqa.provider.base import BaseProvider 11 | from repoqa.provider.request.google import make_auto_request 12 | 13 | 14 | class GoogleProvider(BaseProvider): 15 | def __init__(self, model): 16 | genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) 17 | self.model = model 18 | self.client = genai.GenerativeModel(model) 19 | 20 | def generate_reply( 21 | self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None 22 | ) -> List[str]: 23 | assert temperature != 0 or n == 1, "n must be 1 when temperature is 0" 24 | replies = make_auto_request( 25 | self.client, 26 | question, 27 | self.model, 28 | n=n, 29 | max_tokens=max_tokens, 30 | temperature=temperature, 31 | system_msg=system_msg, 32 | ) 33 | 34 | if len(replies.candidates) != n: 35 | print(f"[WARNING] # replies = {len(replies.candidates)} != {n = }") 36 | 37 | ret_texts = [] 38 | for candidate in replies.candidates: 39 | parts = candidate.content.parts 40 | if parts: 41 | ret_texts.append(parts[0].text) 42 | else: 43 | print("Empty response!") 44 | ret_texts.append("") 45 | print(f"{candidate.safety_ratings = }") 46 | 47 | return ret_texts + [""] * (n - len(ret_texts)) 48 | -------------------------------------------------------------------------------- /repoqa/provider/hf.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from typing import List 6 | 7 | import torch 8 | from transformers import AutoModelForCausalLM, AutoTokenizer 9 | 10 | from repoqa.provider.base import BaseProvider 11 | from repoqa.provider.request import construct_message_list, hacky_assistant_stop_seq 12 | 13 | 14 | class HfProvider(BaseProvider): 15 | def __init__(self, model, trust_remote_code=False, attn_implementation=None): 16 | self.tokenizer = AutoTokenizer.from_pretrained( 17 | model, trust_remote_code=trust_remote_code 18 | ) 19 | self.hf_model = AutoModelForCausalLM.from_pretrained( 20 | model, 21 | trust_remote_code=trust_remote_code, 22 | attn_implementation=attn_implementation, 23 | torch_dtype="auto", 24 | ).cuda() 25 | self.stop_seq = [] 26 | if self.tokenizer.chat_template: 27 | self.stop_seq.append(hacky_assistant_stop_seq(self.tokenizer)) 28 | 29 | @torch.inference_mode() 30 | def generate_reply( 31 | self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None 32 | ) -> List[str]: 33 | assert temperature != 0 or n == 1, "n must be 1 when temperature is 0" 34 | 35 | prompt_tokens = self.tokenizer.apply_chat_template( 36 | construct_message_list(question, system_msg), 37 | return_tensors="pt", 38 | add_generation_prompt=True, 39 | ).cuda() 40 | input_length = prompt_tokens.size(-1) 41 | 42 | gen_args = {"do_sample": False} 43 | if temperature > 0: 44 | gen_args["do_sample"] = True 45 | gen_args["temperature"] = temperature 46 | 47 | output_text = self.hf_model.generate( 48 | input_ids=prompt_tokens, 49 | max_new_tokens=max_tokens, 50 | num_return_sequences=n, 51 | pad_token_id=self.tokenizer.eos_token_id, 52 | use_cache=True, 53 | stop_strings=self.stop_seq, 54 | tokenizer=self.tokenizer, 55 | 
**gen_args, 56 | ) 57 | 58 | gen_strs = [ 59 | self.tokenizer.decode( 60 | x[input_length:], 61 | skip_special_tokens=True, 62 | clean_up_tokenization_spaces=False, 63 | ) 64 | for x in output_text 65 | ] 66 | return gen_strs 67 | -------------------------------------------------------------------------------- /repoqa/provider/openai.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import os 6 | from typing import List 7 | 8 | from openai import Client 9 | from transformers import AutoTokenizer 10 | 11 | from repoqa.provider.base import BaseProvider 12 | from repoqa.provider.request import hacky_assistant_stop_seq 13 | from repoqa.provider.request.openai import make_auto_request 14 | 15 | 16 | class OpenAIProvider(BaseProvider): 17 | def __init__(self, model, base_url: str = None): 18 | self.model = model 19 | self.client = Client( 20 | api_key=os.getenv("OPENAI_API_KEY", "none"), base_url=base_url 21 | ) 22 | self.stop_seq = [] 23 | try: 24 | tokenizer = AutoTokenizer.from_pretrained(model) 25 | if tokenizer.chat_template: 26 | self.stop_seq.append(hacky_assistant_stop_seq(tokenizer)) 27 | print("Using stop sequence: ", self.stop_seq) 28 | except: 29 | print("Failed to automatically fetch stop tokens from HuggingFace.") 30 | 31 | def generate_reply( 32 | self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None 33 | ) -> List[str]: 34 | assert temperature != 0 or n == 1, "n must be 1 when temperature is 0" 35 | replies = make_auto_request( 36 | self.client, 37 | message=question, 38 | model=self.model, 39 | temperature=temperature, 40 | n=n, 41 | max_tokens=max_tokens, 42 | system_msg=system_msg, 43 | stop=self.stop_seq, 44 | ) 45 | 46 | return [reply.message.content for reply in replies.choices] 47 | -------------------------------------------------------------------------------- /repoqa/provider/request/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | def construct_message_list(message, system_message=None): 7 | msglist = [{"role": "user", "content": message}] 8 | if system_message: 9 | msglist.insert(0, {"role": "system", "content": system_message}) 10 | return msglist 11 | 12 | 13 | def hacky_assistant_stop_seq(tokenizer) -> str: 14 | _magic_string_ = "&==NowOrNever==&Accelerate!!!==&" 15 | return tokenizer.apply_chat_template( 16 | [ 17 | {"role": "user", "content": ""}, 18 | {"role": "assistant", "content": _magic_string_}, 19 | ], 20 | tokenize=False, 21 | ).split(_magic_string_)[-1] 22 | -------------------------------------------------------------------------------- /repoqa/provider/request/anthropic.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import signal 6 | import time 7 | 8 | import anthropic 9 | from anthropic.types import Message 10 | 11 | from repoqa.provider.request import construct_message_list 12 | 13 | 14 | def make_request( 15 | client: anthropic.Client, 16 | message: str, 17 | model: str, 18 | max_tokens: int = 512, 19 | temperature: float = 1, 20 | system_msg="You are a helpful assistant good at coding.", 21 | **kwargs, 22 | ) -> Message: 23 | return client.messages.create( 24 | model=model, 25 | 
messages=construct_message_list(message, system_message=system_msg), 26 | max_tokens=max_tokens, 27 | temperature=temperature, 28 | **kwargs, 29 | ) 30 | 31 | 32 | def handler(signum, frame): 33 | # swallow signum and frame 34 | raise Exception("end of time") 35 | 36 | 37 | def make_auto_request(client: anthropic.Client, *args, **kwargs) -> Message: 38 | ret = None 39 | while ret is None: 40 | try: 41 | signal.signal(signal.SIGALRM, handler) 42 | signal.alarm(100) 43 | ret = make_request(client, *args, **kwargs) 44 | signal.alarm(0) 45 | except anthropic.RateLimitError: 46 | print("Rate limit exceeded. Waiting...") 47 | signal.alarm(0) 48 | time.sleep(10) 49 | except anthropic.APIConnectionError: 50 | print("API connection error. Waiting...") 51 | signal.alarm(0) 52 | time.sleep(5) 53 | except anthropic.InternalServerError: 54 | print("Internal server error. Waiting...") 55 | signal.alarm(0) 56 | time.sleep(5) 57 | except anthropic.APIError as e: 58 | print("Unknown API error") 59 | print(e) 60 | if ( 61 | e.body["error"]["message"] 62 | == "Output blocked by content filtering policy" 63 | ): 64 | raise Exception("Content filtering policy blocked output") 65 | signal.alarm(0) 66 | except Exception as e: 67 | print("Unknown error. Waiting...") 68 | print(e) 69 | signal.alarm(0) 70 | time.sleep(1) 71 | return ret 72 | -------------------------------------------------------------------------------- /repoqa/provider/request/google.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import signal 6 | import time 7 | 8 | import google.generativeai as genai 9 | from google.api_core.exceptions import GoogleAPICallError, ResourceExhausted 10 | 11 | from repoqa.provider.request import construct_message_list 12 | 13 | 14 | def make_request( 15 | client: genai.GenerativeModel, 16 | message: str, 17 | model: str, 18 | max_tokens: int = 512, 19 | temperature: float = 1, 20 | n: int = 1, 21 | system_msg="You are a helpful assistant good at coding.", 22 | **kwargs, 23 | ) -> genai.types.GenerateContentResponse: 24 | messages = [] 25 | if system_msg: 26 | messages.append({"role": "system", "parts": [system_msg]}) 27 | messages.append({"role": "user", "parts": [message]}) 28 | return client.generate_content( 29 | messages, 30 | generation_config=genai.types.GenerationConfig( 31 | candidate_count=n, max_output_tokens=max_tokens, temperature=temperature 32 | ), 33 | **kwargs, 34 | ) 35 | 36 | 37 | def handler(signum, frame): 38 | # swallow signum and frame 39 | raise Exception("end of time") 40 | 41 | 42 | def make_auto_request(*args, **kwargs) -> genai.types.GenerateContentResponse: 43 | ret = None 44 | while ret is None: 45 | try: 46 | signal.signal(signal.SIGALRM, handler) 47 | signal.alarm(100) 48 | ret = make_request(*args, **kwargs) 49 | signal.alarm(0) 50 | except ResourceExhausted as e: 51 | print("Rate limit exceeded. Waiting...", e.message) 52 | signal.alarm(0) 53 | time.sleep(10) 54 | except GoogleAPICallError as e: 55 | print(e.message) 56 | signal.alarm(0) 57 | time.sleep(1) 58 | except Exception as e: 59 | print("Unknown error. 
Waiting...") 60 | print(e) 61 | signal.alarm(0) 62 | time.sleep(1) 63 | return ret 64 | -------------------------------------------------------------------------------- /repoqa/provider/request/openai.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import signal 6 | import time 7 | 8 | import openai 9 | from openai.types.chat import ChatCompletion 10 | 11 | from repoqa.provider.request import construct_message_list 12 | 13 | 14 | def make_request( 15 | client: openai.Client, 16 | message: str, 17 | model: str, 18 | max_tokens: int = 512, 19 | temperature: float = 1, 20 | n: int = 1, 21 | system_msg="You are a helpful assistant good at coding.", 22 | **kwargs, 23 | ) -> ChatCompletion: 24 | return client.chat.completions.create( 25 | model=model, 26 | messages=construct_message_list(message, system_message=system_msg), 27 | max_tokens=max_tokens, 28 | temperature=temperature, 29 | n=n, 30 | **kwargs, 31 | ) 32 | 33 | 34 | def handler(signum, frame): 35 | # swallow signum and frame 36 | raise Exception("end of time") 37 | 38 | 39 | def make_auto_request(*args, **kwargs) -> ChatCompletion: 40 | ret = None 41 | while ret is None: 42 | try: 43 | signal.signal(signal.SIGALRM, handler) 44 | signal.alarm(100) 45 | ret = make_request(*args, **kwargs) 46 | signal.alarm(0) 47 | except openai.RateLimitError: 48 | print("Rate limit exceeded. Waiting...") 49 | signal.alarm(0) 50 | time.sleep(10) 51 | except openai.APIConnectionError: 52 | print("API connection error. Waiting...") 53 | signal.alarm(0) 54 | time.sleep(5) 55 | except openai.APIError as e: 56 | print(e) 57 | signal.alarm(0) 58 | except Exception as e: 59 | print("Unknown error. 
Waiting...") 60 | print(e) 61 | signal.alarm(0) 62 | time.sleep(1) 63 | return ret 64 | -------------------------------------------------------------------------------- /repoqa/provider/vllm.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from typing import List 6 | 7 | from transformers import AutoTokenizer 8 | from vllm import LLM, SamplingParams 9 | 10 | from repoqa.provider.base import BaseProvider 11 | from repoqa.provider.request import construct_message_list, hacky_assistant_stop_seq 12 | 13 | 14 | class VllmProvider(BaseProvider): 15 | def __init__( 16 | self, model, tensor_parallel_size, max_model_len=None, trust_remote_code=False 17 | ): 18 | self.tokenizer = AutoTokenizer.from_pretrained( 19 | model, trust_remote_code=trust_remote_code 20 | ) 21 | self.llm = LLM( 22 | model=model, 23 | tensor_parallel_size=tensor_parallel_size, 24 | max_model_len=max_model_len, 25 | trust_remote_code=trust_remote_code, 26 | ) 27 | self.stop_seq = [] 28 | if self.tokenizer.chat_template: 29 | self.stop_seq.append(hacky_assistant_stop_seq(self.tokenizer)) 30 | 31 | def generate_reply( 32 | self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None 33 | ) -> List[str]: 34 | assert temperature != 0 or n == 1, "n must be 1 when temperature is 0" 35 | 36 | prompt = self.tokenizer.apply_chat_template( 37 | construct_message_list(question, system_msg), 38 | tokenize=False, 39 | add_generation_prompt=True, 40 | ) 41 | vllm_outputs = self.llm.generate( 42 | [prompt], 43 | SamplingParams( 44 | temperature=temperature, 45 | max_tokens=max_tokens, 46 | stop=self.stop_seq, 47 | ), 48 | use_tqdm=False, 49 | ) 50 | 51 | gen_strs = [x.outputs[0].text for x in vllm_outputs] 52 | return gen_strs 53 | -------------------------------------------------------------------------------- /repoqa/search_needle_function.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | import json 5 | import os 6 | from enum import Enum 7 | from typing import List, Tuple 8 | 9 | from transformers import AutoTokenizer 10 | from tree_sitter_languages import get_language, get_parser 11 | 12 | from repoqa.compute_score import compute_score, save_json 13 | from repoqa.data import CACHE_DIR, get_repoqa_data 14 | from repoqa.utility import COMMENT_QUERY, progress, topological_sort 15 | 16 | COMMENT_PREFIX = { 17 | "python": "#", 18 | "java": "//", 19 | "typescript": "//", 20 | "rust": "//", 21 | "cpp": "//", 22 | "go": "//", 23 | } 24 | 25 | # Model context below: 26 | TEMPLATE = "instruction\ncode_context\ndescription\ninstruction" 27 | 28 | INSTRUCTION = ( 29 | "Based on the function description and code context," 30 | " please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:" 31 | ) 32 | 33 | # Mode to clean context comments 34 | class CleanComment(Enum): 35 | NoClean = "none" 36 | PositionalPadding = "positional_padding" 37 | NoPadding = "no_padding" 38 | 39 | 40 | def _backward_tokenizable_lines(lines, tokenizer, max_tokens): 41 | """Return the text and tokens from bottom to top""" 42 | text = "" 43 | ntokens = 0 44 | is_break = False 45 | for line in reversed(lines): 46 | new_ntokens = len(tokenizer.tokenize(line + "\n")) 47 | if ntokens + new_ntokens > max_tokens: 48 | is_break = True 49 | break 
50 | text = line + "\n" + text 51 | ntokens += new_ntokens 52 | return text, ntokens, is_break 53 | 54 | 55 | def _forward_tokenizable_lines(lines, tokenizer, max_tokens): 56 | """Return the text and tokens from top to bottom""" 57 | text = "" 58 | ntokens = 0 59 | is_break = False 60 | for line in lines: 61 | new_ntokens = len(tokenizer.tokenize(line + "\n")) 62 | if ntokens + new_ntokens > max_tokens: 63 | is_break = True 64 | break 65 | text += line + "\n" 66 | ntokens += new_ntokens 67 | if is_break: 68 | text = text + "...\n" 69 | ntokens += len(tokenizer.tokenize("...\n")) 70 | return text, ntokens, is_break 71 | 72 | 73 | def filter_path_comments(capture, context_paths, source_bytes, comment_prefix): 74 | node, _ = capture 75 | text = source_bytes[node.start_byte : node.end_byte] 76 | for path in context_paths: 77 | if text.decode("utf8") == comment_prefix + " Path: " + path: 78 | return False 79 | return True 80 | 81 | 82 | def clean_segment_comments(language, segment, context_paths): 83 | source_bytes = bytes(segment, "utf8") 84 | parser = get_parser(language) 85 | tree = parser.parse(source_bytes) 86 | root_node = tree.root_node 87 | 88 | # Remove comments from source code 89 | capture_list = [] 90 | for query_str in COMMENT_QUERY[language]: 91 | comment_query = get_language(language).query(query_str) 92 | capture_list += comment_query.captures(root_node) 93 | 94 | # Filter out synethetic comments containing paths info 95 | filtered_capture = list( 96 | filter( 97 | lambda capture: filter_path_comments( 98 | capture, context_paths, source_bytes, COMMENT_PREFIX[language] 99 | ), 100 | capture_list, 101 | ) 102 | ) 103 | 104 | filtered_capture.sort(key=lambda cap: cap[0].start_byte, reverse=True) 105 | 106 | for node, _ in filtered_capture: 107 | source_bytes = source_bytes[: node.start_byte] + source_bytes[node.end_byte :] 108 | 109 | return source_bytes.decode("utf-8") 110 | 111 | 112 | # Clean partial context due to context construction 113 | def clean_partial_file(language, whole_file, partial_lines, path): 114 | path_comment = f"{COMMENT_PREFIX[language]} Path: {path}\n" 115 | source_bytes = bytes(whole_file, "utf8") 116 | parser = get_parser(language) 117 | tree = parser.parse(source_bytes) 118 | root_node = tree.root_node 119 | 120 | # Remove comments from source code 121 | capture_list = [] 122 | for query_str in COMMENT_QUERY[language]: 123 | comment_query = get_language(language).query(query_str) 124 | capture_list += comment_query.captures(root_node) 125 | 126 | capture_list.sort(key=lambda cap: cap[0].start_byte, reverse=True) 127 | 128 | for node, _ in capture_list: 129 | new_line_count = source_bytes[node.start_byte : node.end_byte].count(b"\n") 130 | source_bytes = ( 131 | source_bytes[: node.start_byte] 132 | + b"\n" * new_line_count 133 | + source_bytes[node.end_byte :] 134 | ) 135 | 136 | return ( 137 | path_comment 138 | + "\n".join(source_bytes.decode("utf-8").split("\n")[: partial_lines - 1]) 139 | + "...\n" 140 | ) 141 | 142 | 143 | def clean_context_comments( 144 | language: str, 145 | prefix: str, 146 | needle_code: str, 147 | suffix: str, 148 | tokenizer, 149 | context_paths: str, 150 | top_prefix_file: str, 151 | bot_suffix_file: str, 152 | position_ratio: float, 153 | add_padding: bool, 154 | ): 155 | prefix_orig_size = len(tokenizer.tokenize(prefix)) 156 | needle_orig_size = len(tokenizer.tokenize(needle_code)) 157 | suffix_orig_size = len(tokenizer.tokenize(suffix)) 158 | 159 | # If there is are prefix files, it might get chopped off preventing proper 
parsing 160 | # we fully parse the top prefix file to avoid errors 161 | if top_prefix_file: 162 | second_path = f"{COMMENT_PREFIX[language]} Path: {context_paths[1]}" 163 | prefix_lines = prefix.split("\n") 164 | top_file_lines = 0 165 | lines_after_target = [] 166 | target_found = False 167 | for line in prefix_lines: 168 | if target_found: 169 | lines_after_target.append(line) 170 | elif second_path in line: 171 | target_found = True 172 | lines_after_target.append(line) 173 | else: 174 | top_file_lines += 1 175 | top_file_cleaned = clean_partial_file( 176 | language, top_prefix_file, top_file_lines, context_paths[0] 177 | ) 178 | rest_files_cleaned = clean_segment_comments( 179 | language, "\n".join(lines_after_target), context_paths 180 | ) 181 | prefix_cleaned = top_file_cleaned + rest_files_cleaned 182 | else: 183 | prefix_cleaned = clean_segment_comments(language, prefix, context_paths) 184 | needle_cleaned = needle_code 185 | needle_cleaned = clean_segment_comments(language, needle_code, context_paths) 186 | 187 | # Same for suffix 188 | if bot_suffix_file: 189 | second_path = f"{COMMENT_PREFIX[language]} Path: {context_paths[-1]}" 190 | prefix_lines = prefix.split("\n") 191 | bot_file_lines = 0 192 | lines_before_target = [] 193 | target_found = False 194 | for line in prefix_lines: 195 | if target_found: 196 | lines_before_target.append(line) 197 | elif second_path in line: 198 | target_found = True 199 | lines_before_target.append(line) 200 | else: 201 | bot_file_lines += 1 202 | top_file_cleaned = clean_partial_file( 203 | language, bot_suffix_file, bot_file_lines, context_paths[-1] 204 | ) 205 | rest_files_cleaned = clean_segment_comments( 206 | language, "\n".join(lines_before_target), context_paths 207 | ) 208 | suffix_cleaned = rest_files_cleaned + top_file_cleaned 209 | else: 210 | suffix_cleaned = clean_segment_comments(language, suffix, context_paths) 211 | 212 | if not add_padding: 213 | return prefix_cleaned, needle_cleaned, suffix_cleaned 214 | 215 | # Calculate amount of padding to prefix and suffix to maintain position 216 | prefix_clean_size = len(tokenizer.tokenize(prefix_cleaned)) 217 | needle_clean_size = len(tokenizer.tokenize(needle_cleaned)) 218 | suffix_clean_size = len(tokenizer.tokenize(suffix_cleaned)) 219 | 220 | # Determine how much of needle padding go to prefix & suffix 221 | needle_tokens_removed = needle_orig_size - needle_clean_size 222 | needle_prefix_padding = int(needle_tokens_removed * position_ratio) 223 | needle_suffix_padding = needle_tokens_removed - needle_prefix_padding 224 | 225 | # Add more padding to compensate removal from prefix/suffix portions 226 | needle_prefix_padding = int( 227 | (needle_prefix_padding + prefix_orig_size - prefix_clean_size - 1) 228 | ) 229 | needle_suffix_padding = int( 230 | (needle_suffix_padding + suffix_orig_size - suffix_clean_size - 1) 231 | ) 232 | 233 | prefix_dummy = "" 234 | line = 0 235 | while needle_prefix_padding > 0: 236 | current = f"{COMMENT_PREFIX[language]} Line Number {line}\n" 237 | current_len = len(tokenizer.tokenize(current)) 238 | needle_prefix_padding -= current_len 239 | prefix_dummy += current 240 | line += 1 241 | prefix_cleaned = prefix_dummy + "\n" + prefix_cleaned 242 | 243 | suffix_dummy = "" 244 | while needle_suffix_padding > 0: 245 | current = f"{COMMENT_PREFIX[language]} Line Number {line}\n" 246 | current_len = len(tokenizer.tokenize(current)) 247 | needle_suffix_padding -= current_len 248 | line += 1 249 | suffix_dummy += current 250 | suffix_cleaned = suffix_cleaned + 
suffix_dummy + "\n" 251 | 252 | return prefix_cleaned, needle_cleaned, suffix_cleaned 253 | 254 | 255 | def make_code_context( 256 | needle, 257 | file_content_list: List[Tuple[str, str]], 258 | position_ratio: float, 259 | code_context_size: int, 260 | language: str, 261 | clean_comments: CleanComment = CleanComment.NoClean, 262 | ) -> str: 263 | """ 264 | Slice the file_content_list such that: 265 | 1. The slice contains code_context_size tokens 266 | 2. The positon of the needle is at position_ratio of the slice* 267 | *May not be achievable if the needle is too close to the beginning or end of the file_content_list 268 | *May not be accurate as we will also insert file names at the beginning of each file 269 | *Token sizes might not be 100 accurate but should be close enough 270 | """ 271 | tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf") 272 | 273 | needle_file_idx, needle_file_content = [ 274 | (i, content) 275 | for i, (f, content) in enumerate(file_content_list) 276 | if f == needle["path"] 277 | ][0] 278 | 279 | needle_code = needle_file_content[needle["start_byte"] : needle["end_byte"]] 280 | ntoken_needle = len(tokenizer.tokenize(needle_code)) 281 | 282 | # Used for if cleaning comments option is enabled (paths comments are skipped) 283 | context_paths = [needle["path"]] 284 | top_prefix_file = None 285 | bot_suffix_file = None 286 | 287 | prefix_size = int(code_context_size * position_ratio - ntoken_needle / 2) 288 | suffix_size = code_context_size - ntoken_needle - prefix_size 289 | 290 | # handling prefix of the needle file 291 | code_prefix, ntokens, is_break = _backward_tokenizable_lines( 292 | [COMMENT_PREFIX[language] + " Path: " + needle["path"]] 293 | + needle_file_content[: needle["start_byte"]].split("\n"), 294 | tokenizer, 295 | prefix_size, 296 | ) 297 | prefix_size -= ntokens 298 | 299 | # handling prefix of the previous files 300 | index = needle_file_idx - 1 301 | while not is_break and prefix_size > 0 and index >= 0: 302 | path, content = file_content_list[index] 303 | context_paths.insert(0, path) 304 | top_prefix_file = content 305 | prefix, ntokens, is_break = _forward_tokenizable_lines( 306 | [COMMENT_PREFIX[language] + " Path: " + path] + content.split("\n"), 307 | tokenizer, 308 | prefix_size, 309 | ) 310 | code_prefix = prefix + code_prefix 311 | prefix_size -= ntokens 312 | index -= 1 313 | 314 | # handling suffix of the needle file 315 | code_suffix, ntokens, is_break = _forward_tokenizable_lines( 316 | needle_file_content[needle["end_byte"] :].split("\n"), tokenizer, suffix_size 317 | ) 318 | suffix_size -= ntokens 319 | 320 | # handling suffix of the next files 321 | index = needle_file_idx + 1 322 | while not is_break and suffix_size > 0 and index < len(file_content_list): 323 | path, content = file_content_list[index] 324 | context_paths.append(path) 325 | bot_suffix_file = content 326 | suffix, ntokens, is_break = _forward_tokenizable_lines( 327 | [COMMENT_PREFIX[language] + " Path: " + path] + content.split("\n"), 328 | tokenizer, 329 | suffix_size, 330 | ) 331 | code_suffix += suffix 332 | suffix_size -= ntokens 333 | index += 1 334 | 335 | # Remove the comments in code_prefix, needle_code, code_suffix and 336 | # pad the code_prefix and code_suffix to maintain the position of the needle 337 | if clean_comments != CleanComment.NoClean: 338 | code_prefix, needle_code, code_suffix = clean_context_comments( 339 | language, 340 | code_prefix, 341 | needle_code, 342 | code_suffix, 343 | tokenizer, 344 | context_paths, 
345 | top_prefix_file, 346 | bot_suffix_file, 347 | position_ratio, 348 | clean_comments == CleanComment.PositionalPadding, 349 | ) 350 | 351 | code_context = code_prefix + needle_code + code_suffix 352 | 353 | needle_token_start = len(tokenizer.tokenize(code_prefix)) 354 | needle_token_end = needle_token_start + len(tokenizer.tokenize(needle_code)) 355 | code_context_ntokens = needle_token_end + len(tokenizer.tokenize(code_suffix)) 356 | 357 | return { 358 | "code_context": code_context, 359 | "needle_token_start": needle_token_start, 360 | "needle_token_end": needle_token_end, 361 | "code_context_ntokens": code_context_ntokens, 362 | } 363 | 364 | 365 | def make_task_id(lang, repo, needle_name): 366 | return f"{lang}::{repo}::{needle_name}" 367 | 368 | 369 | def make_cache_id(lang, repo, needle_name, code_context_size, position_ratio): 370 | return f"{lang}::{repo}::{needle_name}::{code_context_size}::{position_ratio}" 371 | 372 | 373 | def evaluate_model( 374 | model: str, 375 | base_url: str = None, 376 | backend: str = None, 377 | tensor_parallel_size: int = 1, 378 | code_context_size: int = 16 * 1024, 379 | max_new_tokens: int = 1024, 380 | result_dir: str = "results", 381 | languages: List[str] = None, 382 | caching: bool = True, # if enabled, will cache the tasks which can be used to resume 383 | system_message: str = None, 384 | dataset_path: str = None, 385 | clean_ctx_comments: str = "none", 386 | eval_ignore_comments: bool = False, # ignore comments during score computation 387 | trust_remote_code: bool = False, 388 | attn_implementation=None, 389 | ): 390 | if backend is None: 391 | if base_url is not None: 392 | backend = "openai" 393 | else: 394 | backend = "vllm" 395 | print(f"Using {backend} as the backend") 396 | assert backend is not None, "Please specify the backend" 397 | 398 | if dataset_path is not None: 399 | with open(dataset_path) as f: 400 | dataset = json.load(f) 401 | else: 402 | dataset = get_repoqa_data() 403 | 404 | # makedir if not exists 405 | os.makedirs(result_dir, exist_ok=True) 406 | context_size_dir = os.path.join(result_dir, f"ntoken_{code_context_size}") 407 | os.makedirs(context_size_dir, exist_ok=True) 408 | model_output_path = os.path.join( 409 | context_size_dir, 410 | f"{model.replace('/', '_slash_')}.jsonl", 411 | ) 412 | 413 | # resume from model_output_file 414 | if os.path.exists(model_output_path): 415 | with open(model_output_path) as f: 416 | model_outputs = [json.loads(line) for line in f] 417 | else: 418 | model_outputs = [] 419 | 420 | if clean_ctx_comments == "positional_padding": 421 | clean_ctx_comments = CleanComment.PositionalPadding 422 | elif clean_ctx_comments == "no_padding": 423 | clean_ctx_comments = CleanComment.NoPadding 424 | else: 425 | clean_ctx_comments = CleanComment.NoClean 426 | 427 | # resume tasks from cache if any 428 | # schema: {"cache_id": .., **task} 429 | extra = "" 430 | if clean_ctx_comments != CleanComment.NoClean: 431 | extra += "_clean_cmt" 432 | cache_file = os.path.join( 433 | CACHE_DIR, f"cache{extra}_ntoken_{code_context_size}_v1.jsonl" 434 | ) 435 | os.makedirs(CACHE_DIR, exist_ok=True) 436 | 437 | cache = {} 438 | if caching: 439 | print("🔥 Caching enabled") 440 | if os.path.exists(cache_file): 441 | with open(cache_file) as f: 442 | cache = [json.loads(line) for line in f] 443 | cache = {c["cache_id"]: c for c in cache} 444 | # remove the cache_id field in c 445 | for c in cache.values(): 446 | c.pop("cache_id") 447 | print(f"Resuming from cache: {cache_file} with {len(cache)} tasks") 448 | 449 
| resumed_task_ids = { 450 | make_task_id(r["language"], r["repo"], r["name"]) for r in model_outputs 451 | } 452 | 453 | # for each task we include 454 | # "repo", "name", "language", "path", 455 | # "template", "position_ratio", "description", "instruction", "code_context" 456 | # "needle_token_start", "needle_token_end", "code_context_ntokens" 457 | tasks = [] 458 | for lang, repos in dataset.items(): 459 | if languages is not None and lang not in languages: 460 | print(f"Skipping {lang} as it is not selected; selected: {languages}") 461 | continue 462 | 463 | print(f"🔥 Preparing code context for {lang}...") 464 | with progress(f"Processing {lang} context") as pbar: 465 | for repo in pbar.track(repos): 466 | # skip if the repo does not have needles 467 | if "needles" not in repo: 468 | pbar.console.print( 469 | f"⚠️ Skipping {repo['repo']} ({lang}) as it does not have `needles` -- do needle analysis first" 470 | ) 471 | continue 472 | 473 | ordered_paths = topological_sort(repo["dependency"]) 474 | file_content_list = [ 475 | (path, repo["content"][path]) for path in ordered_paths 476 | ] 477 | for i, needle in enumerate(repo["needles"]): 478 | task_id = make_task_id(lang, repo["repo"], needle["name"]) 479 | if task_id in resumed_task_ids: 480 | pbar.console.print( 481 | f"Skipping {task_id} as it is already in the results" 482 | ) 483 | continue 484 | 485 | position_ratio = (i + 0.5) / len(repo["needles"]) 486 | cache_id = make_cache_id( 487 | lang, 488 | repo["repo"], 489 | needle["name"], 490 | code_context_size, 491 | position_ratio, 492 | ) 493 | if cache_id in cache: 494 | tasks.append(cache[cache_id]) 495 | continue 496 | 497 | task = { 498 | "repo": repo["repo"], 499 | "name": needle["name"], 500 | "language": lang, 501 | "path": needle["path"], 502 | "position_ratio": position_ratio, 503 | "description": f"\nFunction Description:{needle['description']}\n", 504 | "instruction": INSTRUCTION, 505 | "template": TEMPLATE, 506 | } 507 | code_context_info = make_code_context( 508 | needle, 509 | file_content_list, 510 | position_ratio=position_ratio, 511 | code_context_size=code_context_size, 512 | language=lang, 513 | clean_comments=clean_ctx_comments, 514 | ) 515 | task.update(code_context_info) 516 | tasks.append(task) 517 | 518 | if caching: # cache 519 | with open(cache_file, "a") as f_out: 520 | f_out.write( 521 | json.dumps({"cache_id": cache_id, **task}) + "\n" 522 | ) 523 | # filter finished tasks again (in case a cache is used) 524 | tasks = [ 525 | task 526 | for task in tasks 527 | if make_task_id(task["language"], task["repo"], task["name"]) 528 | not in resumed_task_ids 529 | ] 530 | 531 | if len(tasks) == 0: 532 | print("No tasks to evaluate! 
Exiting...") 533 | return 534 | 535 | if backend == "openai": 536 | from repoqa.provider.openai import OpenAIProvider 537 | 538 | engine = OpenAIProvider(model, base_url=base_url) 539 | elif backend == "vllm": 540 | from repoqa.provider.vllm import VllmProvider 541 | 542 | engine = VllmProvider( 543 | model, 544 | tensor_parallel_size=tensor_parallel_size, 545 | max_model_len=int(code_context_size * 1.5), # Magic number 546 | trust_remote_code=trust_remote_code, 547 | ) 548 | elif backend == "anthropic": 549 | from repoqa.provider.anthropic import AnthropicProvider 550 | 551 | engine = AnthropicProvider(model) 552 | elif backend == "hf": 553 | from repoqa.provider.hf import HfProvider 554 | 555 | engine = HfProvider( 556 | model, 557 | trust_remote_code=trust_remote_code, 558 | attn_implementation=attn_implementation, 559 | ) 560 | elif backend == "google": 561 | from repoqa.provider.google import GoogleProvider 562 | 563 | engine = GoogleProvider(model) 564 | else: 565 | raise ValueError(f"Unknown backend: {backend}") 566 | 567 | if not system_message: 568 | print("🔥 System message is disabled") 569 | system_message = None 570 | 571 | with open(model_output_path, "a") as f_out: 572 | with progress(f"Running {model}") as pbar: 573 | for task in pbar.track(tasks): 574 | actual_position_ratio = ( 575 | task["needle_token_start"] / task["code_context_ntokens"] 576 | ) 577 | pbar.console.print( 578 | f"Searching {task['name']} in {task['repo']} ({task['language']}) -- " 579 | f"position ratio: actual={actual_position_ratio:.2f}, expected={task['position_ratio']}" 580 | ) 581 | prompt = "" 582 | for key in task["template"].split("\n"): 583 | prompt += task[key] 584 | 585 | replies = engine.generate_reply( 586 | prompt, n=1, max_tokens=max_new_tokens, system_msg=system_message 587 | ) 588 | result = {**task, "output": replies} 589 | f_out.write(json.dumps(result) + "\n") 590 | f_out.flush() 591 | model_outputs.append(result) 592 | 593 | file_base, _ = os.path.splitext(model_output_path) 594 | result_path = file_base + "-SCORES.json" 595 | output_json = compute_score( 596 | model, 597 | dataset, 598 | model_outputs, 599 | eval_ignore_comments or clean_ctx_comments != CleanComment.NoClean, 600 | ) 601 | save_json(output_json, result_path) 602 | 603 | 604 | def main(): 605 | from fire import Fire 606 | 607 | Fire(evaluate_model) 608 | 609 | 610 | if __name__ == "__main__": 611 | main() 612 | -------------------------------------------------------------------------------- /repoqa/utility.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from rich.progress import ( 6 | BarColumn, 7 | MofNCompleteColumn, 8 | Progress, 9 | TextColumn, 10 | TimeElapsedColumn, 11 | ) 12 | 13 | FUNCTION_QUERY = { 14 | "python": "(function_definition name: (_)) @fdef", 15 | "java": "(method_declaration name: (_)) @fdef", 16 | "typescript": "(function_declaration name: (_)) @fdef", 17 | "rust": "(function_item name: (_)) @fdef", 18 | "cpp": "(function_definition declarator: (function_declarator declarator: (identifier))) @fdef", 19 | "go": "(function_declaration name: (_)) @fdef", 20 | } 21 | 22 | COMMENT_QUERY = { 23 | "python": [ 24 | "(block (expression_statement (string) @docstring))", 25 | "(comment) @comment", 26 | ], 27 | "java": ["(line_comment) @comment", "(block_comment) @comment"], 28 | "cpp": ["(comment) @comment"], 29 | "rust": ["(line_comment) @comment", 
"(block_comment) @comment"], 30 | "typescript": ["(comment) @comment"], 31 | "go": ["(comment) @comment"], 32 | } 33 | 34 | FUNCTION_NAME_QUERY = { 35 | "python": """ 36 | ((function_definition 37 | name: (identifier) @function_name)) 38 | """, 39 | "java": """ 40 | (method_declaration 41 | name: (identifier) @method_name) 42 | """, 43 | "typescript": """ 44 | (function_declaration 45 | name: (identifier) @function_name) 46 | """, 47 | "rust": """ 48 | (function_item 49 | name: (identifier) @function_name) 50 | """, 51 | "cpp": """ 52 | (function_definition 53 | name: (identifier) @function_name) 54 | """, 55 | } 56 | 57 | 58 | def topological_sort(graph): 59 | # Stack to store the topological order 60 | stack = [] 61 | # Set to keep track of visited nodes 62 | visited = set() 63 | 64 | # Recursive function to process nodes 65 | def dfs(node): 66 | # Mark the current node as visited 67 | visited.add(node) 68 | # Recurse for all the vertices adjacent to this vertex 69 | for neighbour in graph.get(node, []): 70 | if neighbour not in visited: 71 | dfs(neighbour) 72 | # Push current vertex to stack which stores the result 73 | stack.append(node) 74 | 75 | # Call the recursive helper function to store the topological sort starting from all vertices one by one 76 | for node in graph: 77 | if node not in visited: 78 | dfs(node) 79 | 80 | return stack 81 | 82 | 83 | def progress(note: str = "processing"): 84 | return Progress( 85 | TextColumn(f"{note} •" + "[progress.percentage]{task.percentage:>3.0f}%"), 86 | BarColumn(), 87 | MofNCompleteColumn(), 88 | TextColumn("•"), 89 | TimeElapsedColumn(), 90 | ) 91 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tempdir 2 | appdirs 3 | wget 4 | fire 5 | nltk 6 | rich 7 | numpy 8 | tree_sitter<=0.21.3 9 | tree_sitter_languages 10 | transformers 11 | openai 12 | anthropic 13 | google-generativeai 14 | vllm 15 | -------------------------------------------------------------------------------- /results/.gitignore: -------------------------------------------------------------------------------- 1 | *.zip 2 | *.json 3 | -------------------------------------------------------------------------------- /results/README.md: -------------------------------------------------------------------------------- 1 | ## Install GitHub CLI 2 | 3 | Check [GitHub CLI](https://github.com/cli/cli) for installation. 4 | 5 | ## Run evaluations 6 | 7 | ```bash 8 | export PYTHONPATH=$(pwd) 9 | ``` 10 | 11 | ### vLLM 12 | 13 | Most OSS models can be evaluated with vLLM. 14 | The tricky part is to tune the `--tensor-parallel-size` (TP) according to the model size. 15 | 16 | * `--tensor-parallel-size` can be `1`, `2`, `4`, `8`. 17 | * If your `--tensor-parallel-size` is `2` or `4`, run `nvidia-smi topo -m` to check the GPU connectivity. 18 | * `NV4` means the GPUs are connected with NVLink (which is fast). 19 | * Set `CUDA_VISIBLE_DEVICES` to GPUs with good connectivity. 20 | * Models with < 10B may work with TP-1; 10-20B with TP-2; 20-40B with TP-4; > 40B with TP-8. 21 | 22 | ```bash 23 | python repoqa/search_needle_function.py --model "mistralai/Mistral-7B-Instruct-v0.2" --tensor-parallel-size 1 --backend vllm 24 | ``` 25 | 26 | ### OpenAI server 27 | 28 | ```bash 29 | export OPENAI_API_KEY="sk-..." 
# OAI or DeepSeekCoder key 30 | python repoqa/search_needle_function.py --model "gpt-4o-2024-05-13" --backend openai 31 | python repoqa/search_needle_function.py --model "deepseek-coder" --backend openai 32 | ``` 33 | 34 | ### Google Gemini 35 | 36 | ```bash 37 | export GOOGLE_API_KEY="..." 38 | python repoqa/search_needle_function.py --model "gemini-1.5-pro-latest" --backend google 39 | ``` 40 | 41 | ## Pull evaluated results 42 | 43 | ```bash 44 | cd results 45 | gh release download dev-results --pattern "*.zip" --clobber 46 | # unzip all zip files 47 | unzip "*.zip" 48 | ``` 49 | 50 | ## Update model outputs 51 | 52 | ```bash 53 | cd results 54 | # pull results first 55 | for item in "$(pwd)"/*; do 56 | # Check if the item is a directory 57 | if [ -d "$item" ]; then 58 | # Get the base name of the directory 59 | dir_name=$(basename "$item") 60 | zip -FSR "${dir_name}-output.zip" "$dir_name" "*.jsonl" 61 | zip -FSR "${dir_name}-scores.zip" "$dir_name" "*-SCORES.json" 62 | fi 63 | done 64 | gh release upload dev-results ./*.zip --clobber 65 | ``` 66 | -------------------------------------------------------------------------------- /scripts/cherrypick/README.md: -------------------------------------------------------------------------------- 1 | # GitHub Repository Curation 2 | 3 | ## Step 1 4 | 5 | ### PLs 6 | 7 | ``` 8 | python 9 | go 10 | c++ 11 | java 12 | typescript 13 | php 14 | rust 15 | ``` 16 | 17 | ### GitHub searching 18 | 19 | Search at: https://github.com/search? 20 | 21 | ``` 22 | language:{LANGUAGE} stars:>=100 license:mit license:apache-2.0 pushed:>=2024-01-01 size:128..32000 sort:stars 23 | ``` 24 | 25 | ### Find repositories by package rankings 26 | 27 | You can also consult package ranking websites such as https://pypistats.org/top 28 | 29 | ## Step 2 30 | 31 | Select 10 repositories based on the following quality criteria: 32 | 1. Uses unit tests and CI/CD 33 | 2. Is a package 34 | 3. Covers diverse topics (developer tools, machine learning, etc.) 35 | 4. Has enough code (ideally 10+ files) 36 | 5. Is preferably written in a single language 37 | 38 | ## Step 3 39 | 40 | Follow the format of `lists.json` to fill in the selected repositories (a sketch of the entry shape is shown below).
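For reference, each entry in `lists.json` takes the shape sketched below. The field names (`repo`, `commit_sha`, `entrypoint_path`, `topic`) come from the existing file; the values shown are illustrative placeholders, not a real curated repository.

```json
{
  "repo": "example-org/example-project",
  "commit_sha": "<full 40-character commit hash pinning the snapshot>",
  "entrypoint_path": "src/example_project",
  "topic": "developer-tool"
}
```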
41 | -------------------------------------------------------------------------------- /scripts/cherrypick/lists.json: -------------------------------------------------------------------------------- 1 | { 2 | "python": [ 3 | { 4 | "repo": "psf/black", 5 | "commit_sha": "f03ee113c9f3dfeb477f2d4247bfb7de2e5f465c", 6 | "entrypoint_path": "src/black", 7 | "topic": "formatter" 8 | }, 9 | { 10 | "repo": "python-poetry/poetry", 11 | "commit_sha": "21ffd99274d299bd6d2b7b0cdd7ba273fdd941be", 12 | "entrypoint_path": "src/poetry", 13 | "topic": "package-manager" 14 | }, 15 | { 16 | "repo": "locustio/locust", 17 | "commit_sha": "6241e9cf0a5c296da98abfaeb9d66d0ad1db2390", 18 | "entrypoint_path": "locust", 19 | "topic": "testing" 20 | }, 21 | { 22 | "repo": "pyg-team/pytorch_geometric", 23 | "commit_sha": "af0f5f44d4dfa8accd898df3e8523d06b67940b0", 24 | "entrypoint_path": "torch_geometric", 25 | "topic": "deep-learning" 26 | }, 27 | { 28 | "repo": "openai/openai-python", 29 | "commit_sha": "e41abf7b7dbc1e744d167f748e55d4dedfc0dca7", 30 | "entrypoint_path": "src/openai", 31 | "topic": "api" 32 | }, 33 | { 34 | "repo": "mlc-ai/mlc-llm", 35 | "commit_sha": "068d5ea9ca556f2f7a9603537b4f966da12b11f6", 36 | "entrypoint_path": "mlc_llm", 37 | "topic": "deep-learning-deployment" 38 | }, 39 | { 40 | "repo": "reactive-python/reactpy", 41 | "commit_sha": "4307a09dfa75e6c1b10a285e5ae4bdf0323cd018", 42 | "entrypoint_path": "src/py/reactpy/reactpy", 43 | "topic": "frontend" 44 | }, 45 | { 46 | "repo": "marshmallow-code/marshmallow", 47 | "commit_sha": "b8149cec77d16357d11b08f86de3b13e6fe02fa0", 48 | "entrypoint_path": "src/marshmallow", 49 | "topic": "schema" 50 | }, 51 | { 52 | "repo": "ethereum/web3.py", 53 | "commit_sha": "257f464801d11431bf09132fb30bbeeb5937542e", 54 | "entrypoint_path": "web3", 55 | "topic": "web3" 56 | }, 57 | { 58 | "repo": "Ciphey/Ciphey", 59 | "commit_sha": "5dfbe9330070d4aefc997e5d2648155a136a3757", 60 | "entrypoint_path": "ciphey", 61 | "topic": "cryptography" 62 | } 63 | ], 64 | "cpp": [ 65 | { 66 | "repo": "apache/logging-log4cxx", 67 | "commit_sha": "502f5711809e7f48c215164e374a5df62821ed52", 68 | "entrypoint_path": "src/main/cpp", 69 | "topic": "logging" 70 | }, 71 | { 72 | "repo": "skypjack/uvw", 73 | "commit_sha": "ba10b27646035594fc1dd9525dede8573b71a7d7", 74 | "entrypoint_path": "src", 75 | "topic": "network" 76 | }, 77 | { 78 | "repo": "ClickHouse/clickhouse-cpp", 79 | "commit_sha": "0fb483543b313a0979b4dbd130f834352a034ba8", 80 | "entrypoint_path": "clickhouse", 81 | "topic": "database" 82 | }, 83 | { 84 | "repo": "polybar/polybar", 85 | "commit_sha": "11b522c313f7b2b0a10063721ec8b0bf544de6f4", 86 | "entrypoint_path": "src", 87 | "topic": "developer-tool" 88 | }, 89 | { 90 | "repo": "drogonframework/drogon", 91 | "commit_sha": "294035beb96f6d474d501c210a1b403b0b0dedb6", 92 | "entrypoint_path": "lib/src", 93 | "topic": "web-framework" 94 | }, 95 | { 96 | "repo": "sass/node-sass", 97 | "commit_sha": "6081731aac89ce4612fe4839d4c6329539c0d8e1", 98 | "entrypoint_path": "src", 99 | "topic": "css-compiler" 100 | }, 101 | { 102 | "repo": "ml-explore/mlx", 103 | "commit_sha": "91eba8e4856ba9a629408e18a9d0c1e79a22dab0", 104 | "entrypoint_path": "mlx", 105 | "topic": "machine-learning" 106 | }, 107 | { 108 | "repo": "scylladb/seastar", 109 | "commit_sha": "b74a027c819ee5b9bfcee4f32edf8f37017bba12", 110 | "entrypoint_path": "src", 111 | "topic": "network" 112 | }, 113 | { 114 | "repo": "WasmEdge/WasmEdge", 115 | "commit_sha": "a5e6f2d011776a1d997586b6e9af0ec01efa3642", 116 | 
"entrypoint_path": "lib", 117 | "topic": "wasm" 118 | }, 119 | { 120 | "repo": "oatpp/oatpp", 121 | "commit_sha": "17ef2a7f6c8a932498799b2a5ae5aab2869975c7", 122 | "entrypoint_path": "src/oatpp", 123 | "topic": "web-framework" 124 | } 125 | ], 126 | "java": [ 127 | { 128 | "repo": "google/gson", 129 | "commit_sha": "ee61e3f020351f0498dd3ff8f8386b96454a4dfe", 130 | "entrypoint_path": "gson/src/main/java", 131 | "topic": "developer-tool" 132 | }, 133 | { 134 | "repo": "square/retrofit", 135 | "commit_sha": "10014c2bb7bd5fd24cf6b5b680e6917ab9a7767b", 136 | "entrypoint_path": "retrofit/src/main/java", 137 | "topic": "web-client" 138 | }, 139 | { 140 | "repo": "karatelabs/karate", 141 | "commit_sha": "a378ba4f9e5af11eca1794d0bd55c428c8491bd8", 142 | "entrypoint_path": "karate-core/src/main/java", 143 | "topic": "testing" 144 | }, 145 | { 146 | "repo": "Password4j/password4j", 147 | "commit_sha": "ded2c066d53df8dfaa8892a2868e04c7af655775", 148 | "entrypoint_path": "src/main/java", 149 | "topic": "cryptography" 150 | }, 151 | { 152 | "repo": "GoogleContainerTools/jib", 153 | "commit_sha": "2e9561664792f4b13f0db741ca94c996cbece949", 154 | "entrypoint_path": "jib-core/src/main/java", 155 | "topic": "build-tool" 156 | }, 157 | { 158 | "repo": "microsoft/gctoolkit", 159 | "commit_sha": "1476fc128465991f2c531645c780d3ed7c865f2c", 160 | "entrypoint_path": "parser/src/main/java", 161 | "topic": "log-parser" 162 | }, 163 | { 164 | "repo": "palantir/palantir-java-format", 165 | "commit_sha": "b3ef776d8fd7be6e5a53319f7f4a40b7ee0793a4", 166 | "entrypoint_path": "palantir-java-format/src/main/java", 167 | "topic": "formatter" 168 | }, 169 | { 170 | "repo": "microsoft/mssql-jdbc", 171 | "commit_sha": "eae6d7b33571c92029169b33378b2405e8df448a", 172 | "entrypoint_path": "src/main/java", 173 | "topic": "database" 174 | }, 175 | { 176 | "repo": "google/conscrypt", 177 | "commit_sha": "74c3c1ab7fecf9e50097a77dcac6ff2c7afd2ccf", 178 | "entrypoint_path": "common/src/main/java", 179 | "topic": "security" 180 | }, 181 | { 182 | "repo": "apache/flink-ml", 183 | "commit_sha": "f08f2756316c9fbd7fb7b7e39d140bfd3d959b2a", 184 | "entrypoint_path": "flink-ml-core/src/main/java", 185 | "topic": "machine-learning" 186 | } 187 | ], 188 | "typescript": [ 189 | { 190 | "repo": "xenova/transformers.js", 191 | "commit_sha": "642743136efa3a481b2b96d3c2c550085540d844", 192 | "entrypoint_path": "src", 193 | "topic": "machine-learning" 194 | }, 195 | { 196 | "repo": "expressjs/express", 197 | "commit_sha": "815f799310a5627c000d4a5156c1c958e4947b4c", 198 | "entrypoint_path": "lib", 199 | "topic": "web-server" 200 | }, 201 | { 202 | "repo": "axios/axios", 203 | "commit_sha": "751133eb9ed794c6f6634c52f4fe116e33bf5f09", 204 | "entrypoint_path": "lib", 205 | "topic": "web-client" 206 | }, 207 | { 208 | "repo": "date-fns/date-fns", 209 | "commit_sha": "5c1adb5369805ff552737bf8017dbe07f559b0c6", 210 | "entrypoint_path": "src", 211 | "topic": "time-library" 212 | }, 213 | { 214 | "repo": "langchain-ai/langchainjs", 215 | "commit_sha": "1f4a4498b66af5910d02dd4b9d68d99aa1350548", 216 | "entrypoint_path": "langchain/src", 217 | "topic": "machine-learning" 218 | }, 219 | { 220 | "repo": "cheeriojs/cheerio", 221 | "commit_sha": "d0b3c2f6b57cd1f835741175d463963266be0eef", 222 | "entrypoint_path": "src", 223 | "topic": "frontend" 224 | }, 225 | { 226 | "repo": "openai/openai-node", 227 | "commit_sha": "018ac718ccf6a96798ef8f91906b3b652aa50919", 228 | "entrypoint_path": "src", 229 | "topic": "api" 230 | }, 231 | { 232 | "repo": "umami-software/umami", 
233 | "commit_sha": "be7f69fd5d3d711ba650a21013b73c7e17125f52", 234 | "entrypoint_path": "src", 235 | "topic": "analytics" 236 | }, 237 | { 238 | "repo": "mrdoob/three.js", 239 | "commit_sha": "df4fb14945c896fdf628981dde31b31ef7e2e0cc", 240 | "entrypoint_path": "src", 241 | "topic": "3d" 242 | }, 243 | { 244 | "repo": "firebase/firebase-admin-node", 245 | "commit_sha": "67151e620fbb7bdfc2a017e9a35d51eca035d824", 246 | "entrypoint_path": "src", 247 | "topic": "baas" 248 | } 249 | ], 250 | "rust": [ 251 | { 252 | "repo": "rust-bakery/nom", 253 | "commit_sha": "e87c7da9fa2ddb943306369f6c6d9e914256edbc", 254 | "entrypoint_path": "src", 255 | "topic": "parser" 256 | }, 257 | { 258 | "repo": "helix-editor/helix", 259 | "commit_sha": "e69292e5eb7b0f727fefa19cffec910718af31b3", 260 | "entrypoint_path": "helix-core/src", 261 | "topic": "code-editor" 262 | }, 263 | { 264 | "repo": "rust-lang/cargo", 265 | "commit_sha": "5da28587846d6cb2694e5bb1db1b5ca327285bf7", 266 | "entrypoint_path": "src", 267 | "topic": "package-manager" 268 | }, 269 | { 270 | "repo": "cloudflare/pingora", 271 | "commit_sha": "acee67f87020ef41267fe475bcc5cbed44782a06", 272 | "entrypoint_path": "pingora-core/src", 273 | "topic": "networking" 274 | }, 275 | { 276 | "repo": "huggingface/candle", 277 | "commit_sha": "33c9b6655459bd1086574cef9ba8f2e72a8804c8", 278 | "entrypoint_path": "candle-core/src", 279 | "topic": "machine-learning" 280 | }, 281 | { 282 | "repo": "seanmonstar/warp", 283 | "commit_sha": "ce8114b50cadedac089d496235fe9d18596f944e", 284 | "entrypoint_path": "src", 285 | "topic": "web-server" 286 | }, 287 | { 288 | "repo": "serde-rs/serde", 289 | "commit_sha": "6e38afff498d592af4ccac4cb669a86fc789207f", 290 | "entrypoint_path": "serde/src", 291 | "topic": "serialization" 292 | }, 293 | { 294 | "repo": "tokio-rs/tracing", 295 | "commit_sha": "908cc432a5994f6e17c8f36e13c217dc40085704", 296 | "entrypoint_path": "tracing-core/src", 297 | "topic": "application-analysis" 298 | }, 299 | { 300 | "repo": "AFLplusplus/LibAFL", 301 | "commit_sha": "527b892c1ddcc83207faaa71591e7f33a1a806a7", 302 | "entrypoint_path": "libafl/src", 303 | "topic": "fuzzing" 304 | }, 305 | { 306 | "repo": "alacritty/alacritty", 307 | "commit_sha": "d4f2f8577f763df059653dfab733dfe6ddc06913", 308 | "entrypoint_path": "alacritty/src", 309 | "topic": "terminal-emulator" 310 | } 311 | ], 312 | "go": [ 313 | { 314 | "repo": "junegunn/fzf", 315 | "commit_sha": "e352b6887849cb6c3c8ae1d98ed357f94273e90a", 316 | "entrypoint_path": "src", 317 | "topic": "command-line tool" 318 | }, 319 | { 320 | "repo": "fatedier/frp", 321 | "commit_sha": "eaae212d2d1b17360754afd9432c21640f15c832", 322 | "entrypoint_path": "pkg", 323 | "topic": "reverse proxy" 324 | }, 325 | { 326 | "repo": "caddyserver/caddy", 327 | "commit_sha": "4d6370bf92de163a53aec9081c5d5ae6614597a0", 328 | "entrypoint_path": "modules", 329 | "topic": "web server" 330 | }, 331 | { 332 | "repo": "nsqio/nsq", 333 | "commit_sha": "2127c0a1ce50a66d61d78496033b15ef5cb5e250", 334 | "entrypoint_path": "internal", 335 | "topic": "messaging platform" 336 | }, 337 | { 338 | "repo": "projectdiscovery/nuclei", 339 | "commit_sha": "3dfcec0a36dbf5ebc529ce0478076279cb975b71", 340 | "entrypoint_path": "pkg", 341 | "topic": "vulnerability scanner" 342 | }, 343 | { 344 | "repo": "jesseduffield/lazydocker", 345 | "commit_sha": "80af149b023f4d9be68a643412183d389d73368d", 346 | "entrypoint_path": "pkg", 347 | "topic": "docker management" 348 | }, 349 | { 350 | "repo": "zeromicro/go-zero", 351 | "commit_sha": 
"b337ae36e50092488c1899069f860007447fa17c", 352 | "entrypoint_path": "core", 353 | "topic": "cloud-native microservices" 354 | }, 355 | { 356 | "repo": "schollz/croc", 357 | "commit_sha": "d81116382fcb9dddb79b02ed4b0da99e7aecb2ab", 358 | "entrypoint_path": "src", 359 | "topic": "data transfer" 360 | }, 361 | { 362 | "repo": "iawia002/lux", 363 | "commit_sha": "59f517bb7138acd4fffcced4560197e83a592fc2", 364 | "entrypoint_path": "extractors", 365 | "topic": "video download" 366 | }, 367 | { 368 | "repo": "lima-vm/lima", 369 | "commit_sha": "5f6bfc951f5df6f8d21e3e2ccf9693146216ef4d", 370 | "entrypoint_path": "cmd", 371 | "topic": "Linux virtual machines" 372 | } 373 | ] 374 | } 375 | -------------------------------------------------------------------------------- /scripts/curate/dataset_ensemble_clone.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | from datetime import datetime 7 | 8 | import git 9 | import tempdir 10 | from fire import Fire 11 | from tqdm.auto import tqdm 12 | 13 | from scripts.curate.utility import lang2suffix 14 | 15 | 16 | def get_files_to_include(gh_repo, entrypoint, lang_suffix): 17 | files_to_include = [] 18 | for entry in gh_repo.commit().tree.traverse(): 19 | if entry.path.startswith(entrypoint) and any( 20 | [entry.path.endswith(suffix) for suffix in lang_suffix] 21 | ): 22 | files_to_include.append((entry.path, entry.abspath)) 23 | return files_to_include 24 | 25 | 26 | def main( 27 | target_path: str = f"repoqa-{datetime.now().isoformat()}.json", 28 | ): 29 | # read /scripts/cherrypick/lists.json 30 | with open("scripts/cherrypick/lists.json") as f: 31 | lists = json.load(f) 32 | 33 | for lang, repos in lists.items(): 34 | lang_suffix = lang2suffix[lang] 35 | for repo in tqdm(repos): 36 | repo_name = repo["repo"] 37 | commit_sha = repo["commit_sha"] 38 | entrypoint = repo["entrypoint_path"] 39 | 40 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}") 41 | 42 | if repo.get("content"): 43 | print(f"Skipping {repo_name} as it already has content.") 44 | continue 45 | 46 | with tempdir.TempDir() as temp_dir: 47 | gh_repo = git.Repo.clone_from( 48 | f"https://github.com/{repo_name}.git", 49 | temp_dir, 50 | ) 51 | gh_repo.git.checkout(commit_sha) 52 | 53 | files_to_include = get_files_to_include( 54 | gh_repo, entrypoint, lang_suffix 55 | ) 56 | 57 | repo["content"] = {} 58 | for path, abspath in files_to_include: 59 | with open(abspath, "r") as f: 60 | repo["content"][path] = f.read() 61 | with open(target_path, "w") as f_out: 62 | json.dump(lists, f_out) 63 | 64 | 65 | if __name__ == "__main__": 66 | Fire(main) 67 | -------------------------------------------------------------------------------- /scripts/curate/dataset_ensemble_gh_api.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | """ 6 | ! Note: not fully usable. 
You might encounter: 7 | github.GithubException.GithubException: 422 {"message": "Validation Failed", "errors": [{"message": "The listed users and repositories cannot be searched either because the resources do not exist o 8 | r you do not have permission to view them.", "resource": "Search", "field": "q", "code": "invalid"}], "documentation_url": "https://docs.github.com/v3/search/"} 9 | """ 10 | 11 | import json 12 | import os 13 | from datetime import datetime 14 | from typing import TypedDict 15 | 16 | from fire import Fire 17 | from github import Auth, Github 18 | from github.Repository import Repository 19 | from tqdm.auto import tqdm 20 | 21 | from scripts.curate.utility import lang2suffix 22 | 23 | 24 | class GitHubRepoMeta(TypedDict): 25 | repo_name: str 26 | repo_owner: str 27 | commit_sha: str 28 | repo_size: int 29 | 30 | 31 | class GitHubDocument(GitHubRepoMeta): 32 | timestamp: str 33 | path: str 34 | content: str 35 | 36 | 37 | def main( 38 | target_path: str = f"repoqa-{datetime.now().isoformat()}.json", 39 | ): 40 | token = os.getenv("GITHUB_TOKEN") 41 | assert token is not None, "Make a token at https://github.com/settings/tokens" 42 | auth = Auth.Token(token) 43 | 44 | # read /scripts/cherrypick/lists.json 45 | with open("scripts/cherrypick/lists.json") as f: 46 | lists = json.load(f) 47 | 48 | g = Github(auth=auth, per_page=1) 49 | for lang, repos in lists.items(): 50 | lang_suffix = lang2suffix[lang] 51 | for repo in tqdm(repos): 52 | if repo.get("content"): 53 | print(f"Skipping {repo['repo']} as it already has content.") 54 | continue 55 | 56 | repo_name = repo["repo"] 57 | commit_sha = repo["commit_sha"] 58 | entrypoint = repo["entrypoint_path"] 59 | query = f"repo:{repo_name}" 60 | 61 | gh_repos = g.search_repositories(query) 62 | gh_repo: Repository = gh_repos[0] 63 | contents = [ 64 | item 65 | for item in gh_repo.get_contents(entrypoint, ref=commit_sha) 66 | if any([item.path.endswith(suffix) for suffix in lang_suffix]) 67 | ] 68 | 69 | repo["content"] = {} 70 | for item in contents: 71 | if item.encoding != "base64": 72 | continue 73 | file_content = item.decoded_content.decode("utf-8") 74 | repo["content"][item.path] = file_content 75 | 76 | with open(target_path, "w") as f_out: 77 | json.dump(lists, f_out) 78 | 79 | 80 | if __name__ == "__main__": 81 | Fire(main) 82 | -------------------------------------------------------------------------------- /scripts/curate/dep_analysis/cpp.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | import os 7 | from pathlib import Path 8 | from typing import List, Union 9 | 10 | import git 11 | import tempdir 12 | from fire import Fire 13 | from tqdm.auto import tqdm 14 | 15 | from scripts.curate.utility import lang2suffix 16 | 17 | POTENTIAL_INCLUDE_DIR_NAMES = [ 18 | "include", 19 | "src", 20 | "lib", 21 | "libs", 22 | "library", 23 | "libraries", 24 | "inc", 25 | ] 26 | 27 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 28 | 29 | 30 | def get_potential_include_dirs(temp_dir: str) -> List[Path]: 31 | include_dir_list = [Path(temp_dir)] 32 | for potential_include_dir_name in POTENTIAL_INCLUDE_DIR_NAMES: 33 | for dir in Path(temp_dir).rglob(potential_include_dir_name): 34 | include_dir_list.append(dir) 35 | 36 | return include_dir_list 37 | 38 | 39 | def search_include( 40 | entrypoint: str, 41 | temp_dir: str, 42 | cpp_file: Path, 43 | include: str, 44 | 
additional_libs: List[Path] = [], 45 | ) -> Union[Path, None]: 46 | for possible_path in [cpp_file.parent] + additional_libs: 47 | if not (entrypoint in str(possible_path)): 48 | continue 49 | if (possible_path / include).exists(): 50 | return ( 51 | (possible_path / include) 52 | .resolve() 53 | .relative_to(Path(temp_dir)) 54 | .as_posix() 55 | ) 56 | 57 | 58 | def get_dependencies(temp_dir: str, entrypoint: str, cpp_file: Path): 59 | user_defined_includes = [] 60 | with open(cpp_file) as f: 61 | content = f.read() 62 | for line in content.split("\n"): 63 | line = line.strip() 64 | if line.startswith("#include"): 65 | include = line.split()[1] 66 | elif line.split()[:2] == ["#", "include"]: 67 | include = line.split()[2] 68 | else: 69 | continue 70 | 71 | if (include.startswith('"') or include.startswith("<")) and "." in include: 72 | include = include[1:-1] 73 | rela_include = search_include( 74 | entrypoint, 75 | temp_dir, 76 | cpp_file, 77 | include, 78 | additional_libs=get_potential_include_dirs(temp_dir), 79 | ) 80 | if rela_include: 81 | user_defined_includes.append(rela_include) 82 | 83 | return user_defined_includes 84 | 85 | 86 | # dataset_path is the dataset generated by dataset_ensemble_clone.py 87 | def main(): 88 | with open("scripts/cherrypick/lists.json") as f: 89 | lists = json.load(f) 90 | 91 | lang_suffix = lang2suffix["cpp"] 92 | repos = lists["cpp"] 93 | for repo in tqdm(repos): 94 | repo_name = repo["repo"] 95 | commit_sha = repo["commit_sha"] 96 | entrypoint = repo["entrypoint_path"] 97 | 98 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}") 99 | 100 | if repo.get("dependency"): 101 | print(f"Skipping {repo_name} as it already has dependency field.") 102 | continue 103 | 104 | with tempdir.TempDir() as temp_dir: 105 | gh_repo = git.Repo.clone_from( 106 | f"https://github.com/{repo_name}.git", 107 | temp_dir, 108 | ) 109 | gh_repo.git.checkout(commit_sha) 110 | 111 | output_dependency = {} 112 | abs_prefix = os.path.join(temp_dir, entrypoint) 113 | for cpp_ext in lang_suffix: 114 | for cpp_file in Path(abs_prefix).rglob("*" + cpp_ext): 115 | dependencies = get_dependencies(temp_dir, entrypoint, cpp_file) 116 | if len(dependencies) > 0: 117 | if "" in dependencies: 118 | dependencies.remove("") 119 | output_dependency[ 120 | cpp_file.relative_to(Path(temp_dir)).as_posix() 121 | ] = dependencies 122 | else: 123 | output_dependency[ 124 | cpp_file.relative_to(Path(temp_dir)).as_posix() 125 | ] = [] 126 | 127 | repo["dependency"] = output_dependency 128 | 129 | with open(os.path.join(CURRENT_DIR, "data", "cpp.json"), "w") as f_out: 130 | json.dump({"cpp": repos}, f_out) 131 | 132 | 133 | if __name__ == "__main__": 134 | Fire(main) 135 | -------------------------------------------------------------------------------- /scripts/curate/dep_analysis/data/.gitignore: -------------------------------------------------------------------------------- 1 | *.json 2 | -------------------------------------------------------------------------------- /scripts/curate/dep_analysis/go-analysis/dependency_analysis.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | 4 | import ( 5 | "encoding/json" 6 | "go/ast" 7 | "go/parser" 8 | "go/token" 9 | "io/fs" 10 | "os" 11 | "path/filepath" 12 | "strings" 13 | ) 14 | 15 | 16 | type DependencyMap struct { 17 | Name string `json:"name"` 18 | RepoName map[string][]string `json:"repoName"` 19 | } 20 | 21 | 22 | func main() { 23 | if len(os.Args) != 3 { 24 | 
os.Stderr.WriteString("Usage: go run main.go \n") 25 | os.Exit(1) 26 | } 27 | 28 | 29 | rootDir := os.Args[1] // This should be the 'src' directory 30 | mapName := os.Args[2] // Name passed in while calling from terminal 31 | repoMap := make(map[string][]string) 32 | fset := token.NewFileSet() 33 | 34 | 35 | // Walk through the directory and all its subdirectories 36 | filepath.WalkDir(rootDir, func(path string, d fs.DirEntry, err error) error { 37 | if err != nil { 38 | return err 39 | } 40 | 41 | 42 | if !d.IsDir() && strings.HasSuffix(d.Name(), ".go") { 43 | parsedFile, err := parser.ParseFile(fset, path, nil, parser.ParseComments) 44 | if err != nil { 45 | return err 46 | } 47 | 48 | 49 | // Initialize dependencies as an empty slice instead of nil 50 | dependencies := make(map[string]bool) 51 | 52 | 53 | // Analyze AST and find dependencies 54 | ast.Inspect(parsedFile, func(n ast.Node) bool { 55 | switch x := n.(type) { 56 | case *ast.SelectorExpr: 57 | if ident, ok := x.X.(*ast.Ident); ok { 58 | pkg := ident.Name 59 | 60 | 61 | // Check for local files that may correspond to this identifier 62 | filepath.WalkDir(rootDir, func(depPath string, depInfo fs.DirEntry, depErr error) error { 63 | if depErr != nil { 64 | return depErr 65 | } 66 | 67 | 68 | if !depInfo.IsDir() && strings.TrimSuffix(depInfo.Name(), ".go") == pkg { 69 | relPath, err := filepath.Rel(rootDir, strings.TrimPrefix(depPath, rootDir+string(os.PathSeparator))) 70 | if err == nil { 71 | dependencies[relPath] = true 72 | } 73 | } 74 | 75 | 76 | return nil 77 | }) 78 | } 79 | } 80 | 81 | 82 | return true 83 | }) 84 | 85 | 86 | // Convert map keys to a slice 87 | deps := make([]string, 0) // Initialize as an empty slice 88 | for dep := range dependencies { 89 | deps = append(deps, dep) 90 | } 91 | 92 | 93 | fileRelPath, _ := filepath.Rel(rootDir, strings.TrimPrefix(path, rootDir+string(os.PathSeparator))) 94 | repoMap[fileRelPath] = deps 95 | } 96 | 97 | 98 | return nil 99 | }) 100 | 101 | 102 | output := DependencyMap{Name: mapName, RepoName: repoMap} 103 | result, err := json.Marshal(output) // Change back to Marshal for single-line JSON output 104 | if err != nil { 105 | panic(err) 106 | } 107 | 108 | 109 | // Write the result to a file 110 | file, err := os.Create("output.json") 111 | if err != nil { 112 | panic(err) 113 | } 114 | defer file.Close() 115 | 116 | 117 | _, err = file.Write(result) 118 | if err != nil { 119 | panic(err) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /scripts/curate/dep_analysis/go.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | import os 7 | import shutil 8 | import subprocess 9 | from typing import Dict, List 10 | 11 | import git 12 | import tempdir 13 | from fire import Fire 14 | from tqdm.auto import tqdm 15 | 16 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 17 | 18 | 19 | def remove_relative(path: str) -> str: 20 | parts = path.split(os.sep) 21 | filtered_parts = [part for part in parts if part not in (".", "..")] 22 | new_path = os.path.join(*filtered_parts) 23 | return new_path 24 | 25 | 26 | def sanitize_paths(data: Dict[str, List[str]], entrypoint: str) -> Dict[str, List[str]]: 27 | sanitized_data = {} 28 | for file, dependencies in data.items(): 29 | updated_file = os.path.join(entrypoint, remove_relative(file)) 30 | updated_dependencies = [] 31 | for 
dependency in dependencies: 32 | updated_dependency = os.path.join(entrypoint, remove_relative(dependency)) 33 | updated_dependencies.append(updated_dependency) 34 | sanitized_data[updated_file] = updated_dependencies 35 | return sanitized_data 36 | 37 | 38 | def run_dependency_analysis(config_file, go_file): 39 | # Load the JSON configuration 40 | with open(config_file, "r") as file: 41 | data = json.load(file) 42 | 43 | repos = data["go"] 44 | 45 | # Iterate over each repo entry in the JSON configuration 46 | for entry in tqdm(repos): 47 | repo_name = entry["repo"] 48 | commit_sha = entry["commit_sha"] 49 | entrypoint_path = entry["entrypoint_path"] 50 | 51 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}") 52 | 53 | with tempdir.TempDir() as temp_dir: 54 | gh_repo = git.Repo.clone_from( 55 | f"https://github.com/{repo_name}.git", 56 | temp_dir, 57 | ) 58 | gh_repo.git.checkout(commit_sha) 59 | shutil.copy(go_file, temp_dir) 60 | 61 | command_list = ( 62 | f"go build -o dependency_analysis dependency_analysis.go".split() 63 | ) 64 | subprocess.run(command_list, cwd=temp_dir) 65 | 66 | command_list = f"./dependency_analysis {entrypoint_path} {repo_name.split('/')[-1]}".split() 67 | 68 | subprocess.run(command_list, cwd=temp_dir) 69 | output_dir = os.path.join(temp_dir, "output.json") 70 | with open(output_dir, "r") as output_file: 71 | output_data = json.load(output_file) 72 | entry["dependency"] = sanitize_paths( 73 | output_data["repoName"], entrypoint_path 74 | ) 75 | 76 | # Write all output data to a file 77 | with open(os.path.join(CURRENT_DIR, "data", "go.json"), "w") as f_out: 78 | json.dump({"go": repos}, f_out) 79 | 80 | 81 | def main(): 82 | config_file = "scripts/cherrypick/lists.json" 83 | go_file = "scripts/curate/dep_analysis/go-analysis/dependency_analysis.go" 84 | 85 | run_dependency_analysis(config_file, go_file) 86 | 87 | 88 | if __name__ == "__main__": 89 | Fire(main) 90 | -------------------------------------------------------------------------------- /scripts/curate/dep_analysis/java-analysis/dependency-reduced-pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | edu.cs.illinois.repoqa 5 | java-analysis 6 | 1.0-SNAPSHOT 7 | 8 | 9 | 10 | maven-shade-plugin 11 | 3.2.4 12 | 13 | 14 | package 15 | 16 | shade 17 | 18 | 19 | 20 | 21 | edu.cs.illinois.repoqa.DepAnalyze 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | maven-antrun-plugin 30 | 3.0.0 31 | 32 | 33 | package 34 | 35 | run 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 8 49 | 8 50 | 51 | 52 | -------------------------------------------------------------------------------- /scripts/curate/dep_analysis/java-analysis/java-lib/java-analysis-1.0-SNAPSHOT.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evalplus/repoqa/ae876deb1365dbf5a15b0533723c8ed123eee586/scripts/curate/dep_analysis/java-analysis/java-lib/java-analysis-1.0-SNAPSHOT.jar -------------------------------------------------------------------------------- /scripts/curate/dep_analysis/java-analysis/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | edu.cs.illinois.repoqa 8 | java-analysis 9 | 1.0-SNAPSHOT 10 | 11 | 12 | 8 13 | 8 14 | 15 | 16 | 17 | 18 | com.github.javaparser 19 | javaparser-core 20 | 3.25.9 21 | 22 | 23 | com.github.javaparser 24 | javaparser-symbol-solver-core 25 | 3.25.9 26 | 27 | 28 | 29 | 30 | 31 | 32 | 
org.apache.maven.plugins 33 | maven-shade-plugin 34 | 3.2.4 35 | 36 | 37 | package 38 | 39 | shade 40 | 41 | 42 | 43 | 44 | edu.cs.illinois.repoqa.DepAnalyze 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | org.apache.maven.plugins 53 | maven-antrun-plugin 54 | 3.0.0 55 | 56 | 57 | package 58 | 59 | run 60 | 61 | 62 | 63 | 64 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /scripts/curate/dep_analysis/java-analysis/src/main/java/edu/cs/illinois/repoqa/DepAnalyze.java: -------------------------------------------------------------------------------- 1 | package edu.cs.illinois.repoqa; 2 | 3 | import java.io.File; 4 | import java.nio.file.Path; 5 | import java.nio.file.Paths; 6 | import java.util.ArrayList; 7 | import java.util.Collections; 8 | import java.util.HashSet; 9 | import java.util.List; 10 | 11 | import com.github.javaparser.ParserConfiguration; 12 | import com.github.javaparser.ast.CompilationUnit; 13 | import com.github.javaparser.ast.ImportDeclaration; 14 | import com.github.javaparser.ast.expr.NameExpr; 15 | import com.github.javaparser.symbolsolver.JavaSymbolSolver; 16 | import com.github.javaparser.symbolsolver.resolution.typesolvers.CombinedTypeSolver; 17 | import com.github.javaparser.symbolsolver.resolution.typesolvers.ReflectionTypeSolver; 18 | import com.github.javaparser.utils.SourceRoot; 19 | 20 | public class DepAnalyze { 21 | public static String getStartPackage(Path srcRootPath, Path filePath) { 22 | return srcRootPath.relativize(filePath.getParent()).toString().replace(File.separator, "."); 23 | } 24 | 25 | public static void analyze(String repoPath, String entryPoint, String filePath) { 26 | Path srcRootPath = Paths.get(repoPath, entryPoint).toAbsolutePath(); // src/main/java 27 | CombinedTypeSolver combinedTypeSolver = new CombinedTypeSolver(); 28 | combinedTypeSolver.add(new ReflectionTypeSolver()); 29 | JavaSymbolSolver symbolSolver = new JavaSymbolSolver(combinedTypeSolver); 30 | ParserConfiguration parserConfiguration = new ParserConfiguration().setSymbolResolver(symbolSolver).setLanguageLevel(ParserConfiguration.LanguageLevel.JAVA_14); 31 | SourceRoot sourceRoot = new SourceRoot(srcRootPath, parserConfiguration); 32 | CompilationUnit cu = sourceRoot.parse(getStartPackage(srcRootPath, Paths.get(filePath)), new File(filePath).getName()); 33 | 34 | List depPaths = new ArrayList<>(); 35 | depPaths.addAll(getImportDepPaths(cu, srcRootPath)); 36 | depPaths.addAll(getSamePackageDepPaths(cu, Paths.get(filePath).toAbsolutePath())); 37 | 38 | depPaths = new ArrayList<>(new HashSet<>(depPaths)); // remove duplicates 39 | Collections.sort(depPaths); 40 | depPaths.remove(Paths.get(repoPath)); // remove the current file from the dependencies 41 | depPaths.forEach(p -> System.out.println(Paths.get(repoPath).relativize(p))); 42 | } 43 | 44 | private static List getImportDepPaths(CompilationUnit cu, Path srcRootPath) { 45 | List depPaths = new ArrayList<>(); 46 | for (ImportDeclaration importDeclaration : cu.getImports()) { 47 | String importStr = importDeclaration.getNameAsString(); 48 | if (!importDeclaration.isAsterisk()) { 49 | Path depPath = srcRootPath.resolve(importStr.replace(".", File.separator) + ".java"); 50 | if (depPath.toFile().exists()) { 51 | depPaths.add(depPath); 52 | } else { 53 | Path possibleJavaPath = srcRootPath.resolve(importStr.substring(0, importStr.lastIndexOf(".")).replace(".", File.separator) + ".java"); 54 | if (possibleJavaPath.toFile().exists()) { 55 | // this indicates that 
the import is like "import static com.example.Main.main;" 56 | depPaths.add(possibleJavaPath); 57 | } 58 | } 59 | } 60 | else { 61 | Path depDirPath = srcRootPath.resolve(importStr.replace(".", File.separator)); 62 | Path possibleJavaPath = srcRootPath.resolve(importStr.substring(0, importStr.length()).replace(".", File.separator) + ".java"); 63 | if (possibleJavaPath.toFile().exists()) { 64 | // this indicates that the import is like "import com.example.Main.*;" 65 | depPaths.add(possibleJavaPath); 66 | continue; 67 | } 68 | if (depDirPath.toFile().exists()) { 69 | File[] files = depDirPath.toFile().listFiles(); 70 | for (File file : files) { 71 | if (file.isFile() && file.getName().endsWith(".java")) { 72 | depPaths.add(file.toPath()); 73 | } 74 | } 75 | } 76 | } 77 | } 78 | 79 | return depPaths; 80 | } 81 | 82 | private static List getSamePackageDepPaths(CompilationUnit cu, Path filePath) { 83 | List depPaths = new ArrayList<>(); 84 | List siblingClassSimpleNameList = new ArrayList<>(); 85 | for (File siblingFile : filePath.getParent().toFile().listFiles()) { 86 | if (siblingFile.getAbsolutePath().endsWith(".java") && !siblingFile.getAbsolutePath().equals(filePath.toFile().getAbsolutePath())) { 87 | siblingClassSimpleNameList.add(siblingFile.getName().substring(0, siblingFile.getName().length() - 5)); 88 | } 89 | } 90 | 91 | // check all identifiers in the current file 92 | for (NameExpr name : cu.findAll(NameExpr.class)) { 93 | String identifier = name.getNameAsString(); 94 | if (siblingClassSimpleNameList.contains(identifier)) { 95 | Path depPath = filePath.getParent().resolve(identifier + ".java"); 96 | if (depPath.toFile().exists()) { 97 | depPaths.add(depPath); 98 | } 99 | } 100 | } 101 | 102 | return depPaths; 103 | } 104 | 105 | public static void main(String[] args) { 106 | String repoPath = args[0]; 107 | String entryPoint = args[1]; // entry point is relative to the repoPath 108 | String filePath = args[2]; 109 | analyze(repoPath, entryPoint, filePath); 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /scripts/curate/dep_analysis/java.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | import os 7 | import subprocess 8 | from pathlib import Path 9 | 10 | import git 11 | import tempdir 12 | from fire import Fire 13 | from tqdm.auto import tqdm 14 | 15 | TOOL_JAR_PATH = ( 16 | Path(__file__).resolve().parent 17 | / "java-analysis" 18 | / "java-lib" 19 | / "java-analysis-1.0-SNAPSHOT.jar" 20 | ) 21 | 22 | 23 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 24 | 25 | # dataset_path is the dataset generated by dataset_ensemble_clone.py 26 | def main(): 27 | with open("scripts/cherrypick/lists.json") as f: 28 | lists = json.load(f) 29 | 30 | repos = lists["java"] 31 | for repo in tqdm(repos): 32 | repo_name = repo["repo"] 33 | commit_sha = repo["commit_sha"] 34 | entrypoint = repo["entrypoint_path"] 35 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}") 36 | 37 | with tempdir.TempDir() as temp_dir: 38 | gh_repo = git.Repo.clone_from( 39 | f"https://github.com/{repo_name}.git", 40 | temp_dir, 41 | ) 42 | gh_repo.git.checkout(commit_sha) 43 | 44 | output_dependency = {} 45 | abs_prefix = os.path.join(temp_dir, entrypoint) 46 | for java_file in Path(abs_prefix).rglob("*.java"): 47 | command_list = f"java -jar {TOOL_JAR_PATH} {temp_dir} {entrypoint} 
{java_file.absolute()}".split() 48 | # cd `temp_dir`` and capture the output 49 | output = subprocess.check_output(command_list, cwd=temp_dir) 50 | dependencies = output.decode("utf-8").strip().split("\n") 51 | 52 | if dependencies: 53 | if "" in dependencies: 54 | dependencies.remove("") 55 | output_dependency[ 56 | java_file.relative_to(Path(temp_dir)).as_posix() 57 | ] = dependencies 58 | else: 59 | output_dependency[ 60 | java_file.relative_to(Path(temp_dir)).as_posix() 61 | ] = [] 62 | 63 | repo["dependency"] = output_dependency 64 | 65 | with open(os.path.join(CURRENT_DIR, "data", "java.json"), "w") as f_out: 66 | json.dump({"java": repos}, f_out) 67 | 68 | 69 | if __name__ == "__main__": 70 | Fire(main) 71 | -------------------------------------------------------------------------------- /scripts/curate/dep_analysis/python.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | import os 7 | import subprocess 8 | 9 | import git 10 | import tempdir 11 | from fire import Fire 12 | from tqdm.auto import tqdm 13 | 14 | from scripts.curate.dataset_ensemble_clone import get_files_to_include 15 | from scripts.curate.utility import lang2suffix 16 | 17 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 18 | 19 | 20 | # dataset_path is the dataset generated by dataset_ensemble_clone.py 21 | def main(): 22 | with open("scripts/cherrypick/lists.json") as f: 23 | lists = json.load(f) 24 | 25 | lang_suffix = lang2suffix["python"] 26 | repos = lists["python"] 27 | for repo in tqdm(repos): 28 | repo_name = repo["repo"] 29 | commit_sha = repo["commit_sha"] 30 | entrypoint = repo["entrypoint_path"] 31 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}") 32 | 33 | with tempdir.TempDir() as temp_dir: 34 | gh_repo = git.Repo.clone_from( 35 | f"https://github.com/{repo_name}.git", 36 | temp_dir, 37 | ) 38 | gh_repo.git.checkout(commit_sha) 39 | command_list = f"pydeps {entrypoint} --show-deps --no-show".split() 40 | # cd `temp_dir`` and capture the output json 41 | output = subprocess.check_output(command_list, cwd=temp_dir) 42 | dependencies = json.loads(output) 43 | 44 | output_dependency = {} 45 | mod2path = {} 46 | 47 | abs_prefix = os.path.join(temp_dir, entrypoint) 48 | for mod, v in dependencies.items(): 49 | if v["path"] and v["path"].startswith(abs_prefix): 50 | mod2path[mod] = v["path"] 51 | 52 | for mod, v in dependencies.items(): 53 | if v["path"] and v["path"].startswith(abs_prefix): 54 | if "imports" in v: 55 | relative_path = os.path.relpath(v["path"], temp_dir) 56 | output_dependency[relative_path] = [ 57 | os.path.relpath(mod2path[imp], temp_dir) 58 | for imp in v["imports"] 59 | if imp in mod2path 60 | ] 61 | 62 | files_to_include = get_files_to_include(gh_repo, entrypoint, lang_suffix) 63 | for path, _ in files_to_include: 64 | if path not in output_dependency: 65 | output_dependency[path] = [] 66 | 67 | assert output_dependency, "Empty output_dependency" 68 | repo["dependency"] = output_dependency 69 | 70 | with open(os.path.join(CURRENT_DIR, "data", "python.json"), "w") as f_out: 71 | json.dump({"python": repos}, f_out) 72 | 73 | 74 | if __name__ == "__main__": 75 | Fire(main) 76 | -------------------------------------------------------------------------------- /scripts/curate/dep_analysis/rust.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 
(c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | import json 5 | import os 6 | import subprocess 7 | from enum import Enum 8 | from pathlib import Path 9 | from typing import List, Tuple 10 | 11 | import git 12 | import pygraphviz as pgv 13 | import tempdir 14 | from fire import Fire 15 | from pygraphviz import Node 16 | from tqdm.auto import tqdm 17 | 18 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 19 | 20 | 21 | class CrateType(Enum): 22 | LIB = ("lib",) 23 | BIN = "bin" 24 | 25 | 26 | # Given a namespace, find its location in the file directory 27 | def find_path(node: Node, temp_dir: str, entrypoint: str, type: CrateType) -> str: 28 | module_path = node.get_name() 29 | temp_dir = Path(temp_dir) 30 | entrypoint = Path(entrypoint) 31 | paths = module_path.split("::") 32 | 33 | current_path = entrypoint 34 | best_path = None 35 | 36 | for path in paths[1:]: 37 | current_path = current_path / path 38 | if (temp_dir / current_path / "mod.rs").exists(): 39 | best_path = current_path / "mod.rs" 40 | continue 41 | elif (temp_dir / current_path).with_suffix(".rs").exists(): 42 | best_path = current_path.with_suffix(".rs") 43 | break 44 | 45 | if not (best_path): 46 | if type == CrateType.LIB: 47 | return os.path.join(entrypoint, "lib.rs") 48 | elif type == CrateType.BIN: 49 | return os.path.join(entrypoint, "main.rs") 50 | return str(best_path) 51 | 52 | 53 | # Find all buildable packages within the entrypoint 54 | def find_packages(temp_dir: str, entrypoint: str) -> List[Tuple[str, str, CrateType]]: 55 | command_list = f"cargo metadata -q --no-deps --format-version=1".split() 56 | output = subprocess.check_output(command_list, cwd=temp_dir) 57 | decoded_output = output.decode("utf-8") 58 | 59 | result = json.loads(decoded_output) 60 | packages = result["packages"] 61 | results = [] 62 | for package in packages: 63 | for target in package["targets"]: 64 | if not (os.path.join(temp_dir, entrypoint) in target["src_path"]): 65 | continue 66 | if target["kind"][0] == "lib": 67 | results.append((target["name"], target["src_path"], CrateType.LIB)) 68 | elif target["kind"][0] == "bin": 69 | results.append((target["name"], target["src_path"], CrateType.BIN)) 70 | return results 71 | 72 | 73 | def is_cargo_modules_available() -> bool: 74 | try: 75 | command_list = ["cargo-modules", "--help"] 76 | subprocess.run(command_list, capture_output=True) 77 | return True 78 | except (subprocess.CalledProcessError, FileNotFoundError): 79 | return False 80 | 81 | 82 | # dataset_path is the dataset generated by dataset_ensemble_clone.py 83 | def main(): 84 | if not (is_cargo_modules_available()): 85 | print("cargo-modules tool not found, exiting...") 86 | return 87 | 88 | with open("scripts/cherrypick/lists.json") as f: 89 | lists = json.load(f) 90 | 91 | repos = lists["rust"] 92 | for repo in tqdm(repos): 93 | repo_name = repo["repo"] 94 | commit_sha = repo["commit_sha"] 95 | entrypoint = repo["entrypoint_path"] 96 | 97 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}") 98 | 99 | with tempdir.TempDir() as temp_dir: 100 | gh_repo = git.Repo.clone_from( 101 | f"https://github.com/{repo_name}.git", 102 | temp_dir, 103 | ) 104 | gh_repo.git.checkout(commit_sha) 105 | packages = find_packages(temp_dir, entrypoint) 106 | mapping = {} 107 | dependency = {} 108 | edges = [] 109 | for package_name, src_path, crate_type in packages: 110 | if "_" in package_name: 111 | package_name = "-".join(package_name.split("_")) 112 | analysis_param = "" 113 | 114 | # Run 
cargo-modules tool 115 | if crate_type == CrateType.BIN: 116 | analysis_param = f"--bin {package_name}" 117 | else: 118 | analysis_param = f"--lib" 119 | command_list = f"cargo modules dependencies --package {package_name} {analysis_param} --cfg-test --no-sysroot --no-traits --no-types".split() 120 | entrypoint = os.path.dirname(os.path.relpath(src_path, temp_dir)) 121 | 122 | # cd `temp_dir`` and capture the output json 123 | output = subprocess.check_output(command_list, cwd=temp_dir) 124 | decoded_output = output.decode("utf-8") 125 | 126 | # Parse graph and get mapping 127 | graph = pgv.AGraph(string=decoded_output) 128 | for node in graph.nodes(): 129 | node_name = node.get_name() 130 | mapping[node_name] = find_path( 131 | node, temp_dir, entrypoint, crate_type 132 | ) 133 | if mapping[node_name] in dependency: 134 | continue 135 | dependency[mapping[node_name]] = set() 136 | 137 | # Save edges for later 138 | for start_node, end_node in graph.edges(): 139 | edges.append((start_node.get_name(), end_node.get_name())) 140 | 141 | # Parse every edge for dependencies 142 | for start_name, end_name in edges: 143 | if ( 144 | not (start_name in mapping and end_name in mapping) 145 | or mapping[start_name] == mapping[end_name] 146 | ): 147 | continue 148 | dependency[mapping[start_name]].add(mapping[end_name]) 149 | 150 | for key in dependency: 151 | dependency[key] = list(dependency[key]) 152 | 153 | repo["dependency"] = dependency 154 | 155 | with open(os.path.join(CURRENT_DIR, "data", "rust.json"), "w") as f_out: 156 | json.dump({"rust": repos}, f_out) 157 | 158 | 159 | if __name__ == "__main__": 160 | Fire(main) 161 | -------------------------------------------------------------------------------- /scripts/curate/dep_analysis/typescript.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | import os 7 | import re 8 | import subprocess 9 | from pathlib import Path 10 | from typing import Dict, List 11 | 12 | import git 13 | import tempdir 14 | from fire import Fire 15 | from tqdm.auto import tqdm 16 | 17 | from scripts.curate.dataset_ensemble_clone import get_files_to_include 18 | from scripts.curate.utility import lang2suffix 19 | 20 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 21 | 22 | 23 | def traverse_dep( 24 | current_tree: Dict, current_file: str, entrypoint: str, dependencies: Dict 25 | ) -> None: 26 | if current_tree == None: 27 | dependencies[current_file] = set() 28 | return 29 | 30 | results = set( 31 | [ 32 | file 33 | for file in current_tree.keys() 34 | if entrypoint in file and Path(file).suffix in [".ts", ".js"] 35 | ] 36 | ) 37 | dependencies[current_file] = results 38 | 39 | for result in results: 40 | traverse_dep(current_tree[result], result, entrypoint, dependencies) 41 | 42 | 43 | def add_circular(circular_dependencies: Dict, dependencies: Dict): 44 | for cycle in circular_dependencies: 45 | for index, from_file in enumerate(cycle[:-1]): 46 | to_file = cycle[index + 1] 47 | dependencies[from_file].add(to_file) 48 | 49 | 50 | def get_dependencies( 51 | file_path: Path, temp_dir: str, entrypoint: str, dependencies: Dict 52 | ) -> None: 53 | command_list = f"dep-tree tree {file_path} --json".split() 54 | output_string = subprocess.check_output(command_list, cwd=temp_dir) 55 | output_string = output_string.decode("utf-8")[1:-2].split() 56 | json_string = "".join(chr(int(value)) for value in 
output_string) 57 | 58 | json_data = json.loads(json_string) 59 | tree_data = json_data["tree"] 60 | circular_dependencies = json_data["circularDependencies"] 61 | current_file = list(tree_data.keys())[0] 62 | 63 | traverse_dep(tree_data[current_file], current_file, entrypoint, dependencies) 64 | add_circular(circular_dependencies, dependencies) 65 | 66 | 67 | # dataset_path is the dataset generated by dataset_ensemble_clone.py 68 | def main(): 69 | with open("scripts/cherrypick/lists.json") as f: 70 | lists = json.load(f) 71 | 72 | lang_suffix = lang2suffix["typescript"] 73 | repos = lists["typescript"] 74 | for repo in tqdm(repos): 75 | repo_name = repo["repo"] 76 | commit_sha = repo["commit_sha"] 77 | entrypoint = repo["entrypoint_path"] 78 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}") 79 | 80 | dependencies = {} 81 | with tempdir.TempDir() as temp_dir: 82 | gh_repo = git.Repo.clone_from( 83 | f"https://github.com/{repo_name}.git", 84 | temp_dir, 85 | ) 86 | gh_repo.git.checkout(commit_sha) 87 | abs_prefix = Path(os.path.join(temp_dir, entrypoint)) 88 | if "/" in entrypoint: 89 | suffix_path = entrypoint.split("/")[-1] 90 | else: 91 | suffix_path = entrypoint 92 | for file_path in abs_prefix.rglob("*"): 93 | if file_path.is_file() and file_path.suffix in lang_suffix: 94 | current_name = os.path.relpath(str(file_path), temp_dir) 95 | if current_name in dependencies: 96 | continue 97 | get_dependencies(file_path, temp_dir, suffix_path, dependencies) 98 | 99 | for key, value in dependencies.items(): 100 | dependencies[key] = list(value) 101 | 102 | if "/" in entrypoint: 103 | updated_dependencies = {} 104 | append_prefix = "/".join(entrypoint.split("/")[:-1]) + "/" 105 | for key, values in dependencies.items(): 106 | new_key = append_prefix + key 107 | if not ((Path(temp_dir) / Path(new_key)).exists()): 108 | continue 109 | new_values = [] 110 | for value in values: 111 | new_value = append_prefix + value 112 | if (Path(temp_dir) / Path(new_value)).exists(): 113 | new_values.append(new_value) 114 | updated_dependencies[new_key] = new_values 115 | dependencies = updated_dependencies 116 | repo["dependency"] = dependencies 117 | with open(os.path.join(CURRENT_DIR, "data", "typescript.json"), "w") as f_out: 118 | json.dump({"typescript": repos}, f_out) 119 | 120 | 121 | if __name__ == "__main__": 122 | Fire(main) 123 | -------------------------------------------------------------------------------- /scripts/curate/function_analysis.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | 7 | from tqdm import tqdm 8 | from tree_sitter_languages import get_language, get_parser 9 | 10 | from repoqa.utility import COMMENT_QUERY, FUNCTION_QUERY, topological_sort 11 | 12 | _default_name_parser = lambda node: node.child_by_field_name("name").text.decode() 13 | _cpp_name_parser = ( 14 | lambda node: node.child_by_field_name("declarator") 15 | .child_by_field_name("declarator") 16 | .text.decode() 17 | ) 18 | 19 | 20 | def comment_analysis(code: bytes, language: str) -> float: 21 | query_texts = COMMENT_QUERY[language] 22 | parser = get_parser(language) 23 | tree = parser.parse(code) 24 | characters = 0 25 | for query_text in query_texts: 26 | comment_query = get_language(language).query(query_text) 27 | for node, _ in comment_query.captures(tree.root_node): 28 | comment_text = code[node.start_byte : node.end_byte] 29 | 
characters += len(comment_text) 30 | return characters / len(code) 31 | 32 | 33 | # Annotate an incomplete repoqa dataset with function and class information 34 | def main(dataset_path: str, overwrite_analysis: bool = False): 35 | assert dataset_path.endswith(".json"), "Dataset must be a JSON file, check README" 36 | with open(dataset_path, "r") as f: 37 | lists = json.load(f) 38 | 39 | for lang, repos in lists.items(): 40 | assert ( 41 | lang in FUNCTION_QUERY 42 | ), f"Unsupported language: {lang} -- supported: {FUNCTION_QUERY.keys()}" 43 | 44 | fn_query_text = FUNCTION_QUERY[lang] 45 | print(f"🔥 Querying {lang} functions with `{fn_query_text}`...") 46 | 47 | parser = get_parser(lang) 48 | fn_query = get_language(lang).query(fn_query_text) 49 | fn_name_parser = _cpp_name_parser if lang == "cpp" else _default_name_parser 50 | 51 | for repo in tqdm(repos): 52 | # skip if the repo already has function information 53 | if not overwrite_analysis and repo.get("functions"): 54 | continue 55 | 56 | if not repo.get("dependency"): 57 | print( 58 | f"⚠️ Skipping {repo['repo']} ({lang}) as it does not have `dependency` -- do dependency analysis first" 59 | ) 60 | continue 61 | 62 | ordered_paths = topological_sort(repo["dependency"]) 63 | global_byte_idx = 0 64 | global_line_idx = 0 65 | functions = {} # path to a list of functions 66 | for path in ordered_paths: 67 | code = repo["content"][path] 68 | code_bytes = bytes(code, "utf8") 69 | tree = parser.parse(code_bytes) 70 | extracted_functions = [] 71 | for capture in fn_query.captures(tree.root_node): 72 | node, _ = capture 73 | function_content = code_bytes[node.start_byte : node.end_byte] 74 | code_ratio = comment_analysis(function_content, lang) 75 | extracted_functions.append( 76 | { 77 | "name": fn_name_parser(node), 78 | "start_line": node.start_point[0], 79 | "end_line": node.end_point[0] + 1, 80 | "start_byte": node.start_byte, 81 | "end_byte": node.end_byte, 82 | "global_start_line": global_line_idx + node.start_point[0], 83 | "global_end_line": global_line_idx + node.end_point[0] + 1, 84 | "global_start_byte": global_byte_idx + node.start_byte, 85 | "global_end_byte": global_byte_idx + node.end_byte, 86 | "code_ratio": code_ratio, 87 | } 88 | ) 89 | functions[path] = extracted_functions 90 | global_byte_idx += len(code) 91 | global_line_idx += code.count("\n") + 1 92 | repo["functions"] = functions 93 | print( 94 | f"🎉 Found {sum(len(v) for v in functions.values())} functions in {repo['repo']} ({lang})" 95 | ) 96 | 97 | # update the dataset 98 | with open(dataset_path, "w") as f_out: 99 | json.dump(lists, f_out) 100 | 101 | 102 | if __name__ == "__main__": 103 | from fire import Fire 104 | 105 | Fire(main) 106 | -------------------------------------------------------------------------------- /scripts/curate/github_fetch.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | import os 7 | from datetime import datetime 8 | from typing import TypedDict 9 | from zoneinfo import ZoneInfo 10 | 11 | from fire import Fire 12 | from github import Auth, Github 13 | from tqdm.auto import tqdm 14 | 15 | from scripts.curate.utility import lang2suffix 16 | 17 | 18 | class GitHubRepoMeta(TypedDict): 19 | repo_name: str 20 | repo_owner: str 21 | commit_sha: str 22 | repo_size: int 23 | 24 | 25 | class GitHubDocument(GitHubRepoMeta): 26 | timestamp: str 27 | path: str 28 | content: str 29 | 30 | 31 | 
def main( 32 | language: str = "python", 33 | stars: int = 100, 34 | minimal_new_commits: int = 50, 35 | new_commit_since: str = "2024-01-01", 36 | minimal_lang_bytes: int = 1024 * 64, # 64k ideally 37 | ): 38 | token = os.getenv("GITHUB_TOKEN") 39 | assert token is not None, "Make a token at https://github.com/settings/tokens" 40 | auth = Auth.Token(token) 41 | 42 | # See repo selection criteria at https://github.com/evalplus/repoqa/issues/1 43 | query = [] 44 | query.append(f"language:{language}") 45 | query.append(f"stars:>={stars}") 46 | query.append("license:mit license:apache-2.0") 47 | query.append(f"pushed:>={new_commit_since}") 48 | # 128KB to 32MB 49 | query.append("size:128..32000") 50 | query.append("sort:stars") 51 | 52 | # compile query 53 | query = " ".join(query) 54 | print(f"{query=}") 55 | g = Github(auth=auth, per_page=100) 56 | 57 | lang_suffix = lang2suffix[language] 58 | 59 | with open(f"{language}-{datetime.now().isoformat()}.jsonl", "w") as f_out: 60 | repos = g.search_repositories(query) 61 | print("Found ", repos.totalCount, "repositories for", language) 62 | for repo in tqdm(repos, total=repos.totalCount): 63 | # filter at least 100 commits have been made since 2023 Q4 (>= 2023-09-01). 64 | commits = repo.get_commits() 65 | if ( 66 | count_ := sum( 67 | True 68 | for commit in commits 69 | if commit.last_modified_datetime 70 | >= datetime.strptime(new_commit_since, "%Y-%m-%d").replace( 71 | tzinfo=ZoneInfo("UTC") 72 | ) 73 | ) 74 | ) < minimal_new_commits: 75 | print( 76 | f"Skip {repo.html_url} for have less than {minimal_new_commits} after {new_commit_since} (only {count_} commits)" 77 | ) 78 | continue 79 | 80 | # filter repos that is large enough 81 | git_tree = repo.get_git_tree(repo.default_branch, recursive=True) 82 | 83 | tree_iter = list( 84 | filter( 85 | lambda item: item.type == "blob" 86 | and any([item.path.endswith(suffix) for suffix in lang_suffix]), 87 | tqdm(git_tree.tree, leave=False), 88 | ) 89 | ) 90 | 91 | code_file_size = int(sum(item.size for item in tree_iter)) 92 | if code_file_size < minimal_lang_bytes: 93 | print( 94 | f"Skip {repo.html_url} for have less than {minimal_lang_bytes} bytes source file after {new_commit_since} (only {code_file_size} bytes)" 95 | ) 96 | continue 97 | 98 | schema = dict( 99 | repo_name=repo.name, 100 | repo_size=code_file_size, 101 | repo_owner=repo.owner.login, 102 | repo_url=repo.html_url, 103 | commit_sha=git_tree.sha, 104 | last_modified_time=git_tree.last_modified_datetime, 105 | content={}, 106 | ) 107 | for item in tree_iter: 108 | # Fetch the content for each Python file 109 | content = repo.get_contents(item.path) 110 | assert not isinstance(content, list) 111 | if content.encoding != "base64": 112 | continue 113 | file_content = content.decoded_content.decode("utf-8") 114 | schema["content"][item.path] = file_content 115 | f_out.write(json.dumps(schema) + "\n") 116 | f_out.flush() 117 | 118 | 119 | if __name__ == "__main__": 120 | Fire(main) 121 | -------------------------------------------------------------------------------- /scripts/curate/merge_annotation.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | 7 | from fire import Fire 8 | 9 | 10 | def main(dataset_path: str, annotation_path: str): 11 | assert dataset_path.endswith(".json"), "Dataset must be a JSON file, check README" 12 | assert annotation_path.endswith( 13 | ".jsonl" 14 | ), 
"Annotation must be a JSONL file, check README" 15 | 16 | with open(dataset_path) as f: 17 | dataset = json.load(f) 18 | 19 | with open(annotation_path) as f: 20 | annotations = [json.loads(line) for line in f] 21 | 22 | def make_key(repo_name, func_name): 23 | return f"{repo_name}::{func_name}" 24 | 25 | key2annotation = {make_key(a["repo"], a["name"]): a for a in annotations} 26 | 27 | for lang, repos in dataset.items(): 28 | for repo in repos: 29 | if "needles" not in repo: 30 | print( 31 | f"⚠️ Skipping {repo['repo']} ({lang}) as it does not have `needles` -- do needle analysis first" 32 | ) 33 | continue 34 | for needle in repo["needles"]: 35 | needle_name = needle["name"] 36 | key = make_key(repo["repo"], needle_name) 37 | annotation = key2annotation.get(key, None) 38 | if annotation is None: 39 | print( 40 | f"⚠️ Missing annotation for {key} for lang {lang} -- skipping" 41 | ) 42 | continue 43 | needle["description"] = annotation["annotation"] 44 | 45 | with open(dataset_path, "w") as f_out: 46 | json.dump(dataset, f_out) 47 | 48 | 49 | if __name__ == "__main__": 50 | Fire(main) 51 | -------------------------------------------------------------------------------- /scripts/curate/merge_dep.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | import os 7 | 8 | from fire import Fire 9 | 10 | 11 | def main(dataset_path: str): 12 | # iterate json files under scripts/curate/dep_analysis/data 13 | repo2dep = {} 14 | for file in os.listdir(os.path.join("scripts/curate/dep_analysis/data")): 15 | if file.endswith(".json"): 16 | with open(os.path.join("scripts/curate/dep_analysis/data", file)) as f: 17 | data = json.load(f) 18 | 19 | repos = list(data.values())[0] 20 | for repo in repos: 21 | repo2dep[repo["repo"]] = repo["dependency"] 22 | 23 | with open(dataset_path) as f: 24 | dataset = json.load(f) 25 | 26 | for lang, repos in dataset.items(): 27 | for repo in repos: 28 | if repo["repo"] not in repo2dep: 29 | print(f"{lang} -- Repo {repo['repo']} not found in dep analysis data") 30 | continue 31 | repo["dependency"] = repo2dep[repo["repo"]] 32 | print(f"{lang} -- Repo {repo['repo']} has dependency added in the dataset") 33 | 34 | with open(dataset_path, "w") as f_out: 35 | json.dump(dataset, f_out) 36 | 37 | 38 | if __name__ == "__main__": 39 | Fire(main) 40 | -------------------------------------------------------------------------------- /scripts/curate/needle_annotation.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | import os 7 | 8 | import openai 9 | from tqdm import tqdm 10 | 11 | from repoqa.provider.request.openai import make_auto_request 12 | from repoqa.utility import topological_sort 13 | 14 | CAPTURE_HEAD = "" 15 | CAPTURE_TAIL = "" 16 | 17 | 18 | def make_prompt(fn_name: str, code: str): 19 | instruction = f'Can you **briefly** describe the purpose, input, output, and procedure of "{fn_name}"?' 20 | return f"""\ 21 | {instruction} 22 | 23 | ``` 24 | {code} 25 | ``` 26 | 27 | {instruction} 28 | 29 | Please follow format to complete the skeleton below: 30 | 31 | {CAPTURE_HEAD} 32 | 1. **Purpose**: ... 33 | 2. **Input**: ... 34 | 3. **Output**: ... 35 | 4. **Procedure**: ... 36 | {CAPTURE_TAIL} 37 | 38 | {instruction} 39 | 40 | Notes: 41 | 1. 
DO NOT reveal function names ({fn_name}) and variable names 42 | 2. Start with {CAPTURE_HEAD} and end with {CAPTURE_TAIL} 43 | 3. Customize the description to differentiate it from other functions 44 | """ 45 | 46 | 47 | # Annotate an incomplete repoqa dataset with function and class information 48 | def main( 49 | dataset_path: str, 50 | code_prefix_lines: int = 100, 51 | output_desc_path: str = "function_description.jsonl", 52 | use_batch_api: bool = False, 53 | verbose: bool = False, 54 | debug: bool = False, 55 | ): 56 | assert use_batch_api == False, "Batch API is not supported yet." 57 | 58 | assert dataset_path.endswith(".json"), "Dataset must be a JSON file, check README" 59 | with open(dataset_path, "r") as f: 60 | lists = json.load(f) 61 | 62 | # resume from output_desc_path 63 | if output_desc_path.endswith(".jsonl") and os.path.exists(output_desc_path): 64 | with open(output_desc_path, "r") as f: 65 | results = [json.loads(line) for line in f] 66 | else: 67 | # {repo, name, prompt, annotation} 68 | results = [] 69 | 70 | # a set of inference task to run; each item is a tuple of {repo, name, prompt} 71 | tasks = [] 72 | for lang, repos in lists.items(): 73 | print(f"🔥 Collecting unannotated needle functions for {lang}") 74 | for repo in tqdm(repos): 75 | if not repo.get("dependency"): 76 | print( 77 | f"⚠️ Skipping {repo['repo']} ({lang}) as it does not have `dependency` -- do dependency analysis first" 78 | ) 79 | continue 80 | ordered_paths = topological_sort(repo["dependency"]) 81 | repo_lines = [] 82 | for path in ordered_paths: 83 | repo_lines.extend(repo["content"][path].split("\n")) 84 | 85 | def get_code(global_start_line, global_end_line): 86 | return "\n".join( 87 | repo_lines[ 88 | max(0, global_start_line - code_prefix_lines) : global_end_line 89 | ] 90 | ) 91 | 92 | existing_needles = set( 93 | [item["name"] for item in results if item["repo"] == repo["repo"]] 94 | ) 95 | 96 | for needle in repo["needles"]: 97 | fn_name = needle["name"] 98 | if fn_name in existing_needles: 99 | continue 100 | code = get_code(needle["global_start_line"], needle["global_end_line"]) 101 | prompt = make_prompt(fn_name, code) 102 | if verbose: 103 | print(prompt) 104 | print("-" * 80) 105 | tasks.append( 106 | { 107 | "repo": repo["repo"], 108 | "name": fn_name, 109 | "prompt": prompt, 110 | } 111 | ) 112 | 113 | print(f"🔥 {len(tasks)} needle functions to be annotated in total") 114 | client = openai.Client() 115 | with open(output_desc_path, "+a") as f_out: 116 | for task in tqdm(tasks): 117 | print(f"🔥 Annotating {task['name']} in {task['repo']}") 118 | output = make_auto_request( 119 | client, 120 | task["prompt"], 121 | model="gpt-4-turbo", 122 | max_tokens=2048, 123 | temperature=0.2, 124 | n=1, 125 | ) 126 | annotation = output.choices[0].message.content 127 | result = { 128 | "repo": task["repo"], 129 | "name": task["name"], 130 | "prompt": task["prompt"], 131 | "raw_annotation": annotation, 132 | "annotation": annotation.split(CAPTURE_HEAD)[-1].split(CAPTURE_TAIL)[0], 133 | } 134 | json.dump(result, f_out) 135 | f_out.write("\n") 136 | f_out.flush() 137 | 138 | if debug: 139 | print("[PROMPT]", "-" * 80) 140 | print(task["prompt"]) 141 | print("[ANNOTATION]", "-" * 80) 142 | print(annotation) 143 | print("-" * 80) 144 | print("Enter to continue... 
or b to break:") 145 | if input() == "b": 146 | break 147 | 148 | 149 | if __name__ == "__main__": 150 | from fire import Fire 151 | 152 | Fire(main) 153 | -------------------------------------------------------------------------------- /scripts/curate/needle_selection.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | import json 6 | from collections import Counter 7 | from random import sample 8 | 9 | from tqdm import tqdm 10 | 11 | 12 | # Annotate an incomplete repoqa dataset with function and class information 13 | def main( 14 | dataset_path: str, 15 | overwrite_analysis: bool = False, 16 | max_len: int = 2000, 17 | num_bins: int = 64, 18 | max_fn_per_repo: int = 10, 19 | ): 20 | assert ( 21 | num_bins >= max_fn_per_repo 22 | ), "Number of bins must be greater than max functions per repo" 23 | assert dataset_path.endswith(".json"), "Dataset must be a JSON file, check README" 24 | with open(dataset_path, "r") as f: 25 | lists = json.load(f) 26 | 27 | for lang, repos in lists.items(): 28 | print(f"🔥 Selecting needle functions for {lang}") 29 | # FIXME(@ganler): enable more dependency analysis! 30 | for repo in tqdm(repos): 31 | # skip if the repo already has function information 32 | if not overwrite_analysis and repo.get("needles"): 33 | continue 34 | 35 | if not repo.get("functions"): 36 | print( 37 | f"⚠️ Skipping {repo['repo']} ({lang}) as it does not have `functions` field -- do function analysis first" 38 | ) 39 | continue 40 | 41 | repo_size_bytes = sum(len(content) for content in repo["content"].values()) 42 | 43 | selected_bins = set() 44 | bin_size = repo_size_bytes // num_bins 45 | 46 | function_names = Counter() 47 | for funcs in repo["functions"].values(): 48 | for fn in funcs: 49 | function_names.update([fn["name"]]) 50 | # get function names that only appear once 51 | function_names = {k for k, v in function_names.items() if v == 1} 52 | 53 | needle_candidates = [] 54 | for path, funcs in repo["functions"].items(): 55 | for fn in funcs: 56 | # criteria 1: no repeated function names 57 | if fn["name"] not in function_names: 58 | continue 59 | 60 | # criteria 2: length <= max_len 61 | if fn["end_byte"] - fn["start_byte"] > max_len: 62 | continue 63 | 64 | # criteria 3: not in the same bin 65 | bin_idx = fn["global_start_byte"] // bin_size 66 | if bin_idx in selected_bins: 67 | continue 68 | 69 | # criteria 4: TODO -- select those with more code! 
70 | selected_bins |= {bin_idx} 71 | needle_candidates.append((path, fn)) 72 | 73 | len_total_fn = sum(len(v) for v in repo["functions"].values()) 74 | print( 75 | f"🎉 Selected {len(needle_candidates)} needles from {len_total_fn} functions in {repo['repo']} ({lang})" 76 | ) 77 | needles = [] 78 | for path, fn in sample( 79 | needle_candidates, min(max_fn_per_repo, len(needle_candidates)) 80 | ): 81 | needles.append({**fn, "path": path}) 82 | repo["needles"] = needles 83 | 84 | # update the dataset 85 | with open(dataset_path, "w") as f_out: 86 | json.dump(lists, f_out) 87 | 88 | 89 | if __name__ == "__main__": 90 | from fire import Fire 91 | 92 | Fire(main) 93 | -------------------------------------------------------------------------------- /scripts/curate/obfuscate_nl.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # Run with ```python scripts/curate/obfuscate_nl.py repoqa-2024-04-20.json``` 6 | # Will save to repoqa-2024-04-20-obfuscated.json 7 | 8 | import json 9 | import os 10 | import re 11 | 12 | from fire import Fire 13 | from tqdm import tqdm 14 | from tree_sitter_languages import get_language, get_parser 15 | 16 | from repoqa.utility import COMMENT_QUERY, FUNCTION_NAME_QUERY 17 | 18 | 19 | def remove_comments(code, language): 20 | query_texts = COMMENT_QUERY[language] 21 | parser = get_parser(language) 22 | code_bytes = bytes(code, "utf8") 23 | tree = parser.parse(code_bytes) 24 | comment_chunks = [] 25 | for query_text in query_texts: 26 | comment_query = get_language(language).query(query_text) 27 | for node, _ in comment_query.captures(tree.root_node): 28 | comment_chunks.append(node.text.decode("utf-8")) 29 | comment_chunks.sort(key=len, reverse=True) 30 | for chunk in comment_chunks: 31 | chunk_lines = chunk.splitlines() 32 | chunk_lines_len = [len(bytes(line, "utf-8")) for line in chunk_lines] 33 | chunk_lines_empty = [ 34 | (bytes("", "utf-8").ljust(llen, b"\0")).decode("utf-8") 35 | for llen in chunk_lines_len 36 | ] 37 | chunk_empty = "\0".join(chunk_lines_empty) 38 | chunk_empty = chunk_empty[:-1] + "\n" 39 | code = code.replace(chunk, chunk_empty) 40 | return code 41 | 42 | 43 | def rename_functions(code, language, starting_index=0): 44 | func_name_query = get_language(language).query(FUNCTION_NAME_QUERY[language]) 45 | parser = get_parser(language) 46 | print(f"Running rename_functions: {code}, {language}") 47 | code_bytes = bytes(code, "utf8") 48 | tree = parser.parse(code_bytes) 49 | function_names = set() 50 | for capture in func_name_query.captures(tree.root_node): 51 | node, _ = capture 52 | function_names.add(node.text.decode("utf-8")) 53 | function_map = {} 54 | current_index = starting_index 55 | for name in function_names: 56 | function_map[name] = f"function_{starting_index}" 57 | code = code.replace(name, function_map[name]) 58 | current_index += 1 59 | return code, function_map 60 | 61 | 62 | def main(ds_filepath: str): 63 | dataset_file = open(ds_filepath, "r") 64 | dataset = dataset_file.read() 65 | dataset = json.loads(dataset) 66 | dataset_file.close() 67 | 68 | for lang in dataset.keys(): 69 | print(f"🔥 Processing language: {lang}") 70 | for repo_idx in tqdm(range(len(dataset[lang]))): 71 | for filepath in dataset[lang][repo_idx]["content"].keys(): 72 | prev_byte_len = len( 73 | bytes(dataset[lang][repo_idx]["content"][filepath], "utf-8") 74 | ) 75 | dataset[lang][repo_idx]["content"][filepath] = 
remove_comments( 76 | dataset[lang][repo_idx]["content"][filepath], lang 77 | ) 78 | new_byte_len = len( 79 | bytes(dataset[lang][repo_idx]["content"][filepath], "utf-8") 80 | ) 81 | assert prev_byte_len == new_byte_len 82 | 83 | dataset_dir = "/".join(ds_filepath.split("/")[:-1]) 84 | ds_filepath = ds_filepath.split("/")[-1] 85 | ds_fname = ".".join(ds_filepath.split(".")[:-1]) 86 | ds_ext = ds_filepath.split(".")[-1] 87 | 88 | obfs_ds_file = open( 89 | os.path.join(dataset_dir, f"{ds_fname}-obfuscated.{ds_ext}"), "w+" 90 | ) 91 | obfs_ds_file.write(json.dumps(dataset)) 92 | obfs_ds_file.close() 93 | 94 | 95 | if __name__ == "__main__": 96 | Fire(main) 97 | -------------------------------------------------------------------------------- /scripts/curate/requirements.txt: -------------------------------------------------------------------------------- 1 | PyGithub 2 | fire 3 | tqdm 4 | tempdir 5 | GitPython 6 | pydeps 7 | tree_sitter 8 | tree_sitter_languages 9 | pygraphviz 10 | matplotlib 11 | transformers 12 | python-dep-tree 13 | -------------------------------------------------------------------------------- /scripts/curate/utility.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | lang2suffix = { 6 | "python": [".py"], 7 | "go": [".go"], 8 | "cpp": [".cpp", ".hpp", ".cc", ".hh", ".cxx", ".hxx", ".c", ".h"], 9 | "java": [".java"], 10 | "typescript": [".ts", ".js"], 11 | "php": [".php"], 12 | "rust": [".rs"], 13 | } 14 | -------------------------------------------------------------------------------- /scripts/demos/model_request_oai.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | if __name__ == "__main__": 6 | import openai 7 | 8 | client = openai.OpenAI( # Note: if you need UIUC VPN or UIUC network to access the server! 9 | api_key="none", base_url="http://ise-dynamo.cs.illinois.edu:8888/v1" 10 | ) 11 | 12 | task_prefix = "def fibonacci(n):\n" 13 | prompt = f"""This is the fastest implementation for Fibonacci: 14 | ```python 15 | {task_prefix}""" 16 | 17 | # completion 18 | responses = client.completions.create( 19 | model="deepseek-ai/deepseek-coder-6.7b-instruct", 20 | prompt=prompt, 21 | max_tokens=256, 22 | n=3, 23 | stop=["\n```"], 24 | ) 25 | 26 | for c in responses.choices: 27 | print(task_prefix + c.text) 28 | print("=" * 8) 29 | -------------------------------------------------------------------------------- /scripts/dev/license-hdr.txt: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | 3 | SPDX-License-Identifier: Apache-2.0 4 | -------------------------------------------------------------------------------- /scripts/eval/recompute_all_scores.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 4 | # 5 | # SPDX-License-Identifier: Apache-2.0 6 | 7 | set -e 8 | 9 | export PYTHONPATH=$(pwd) 10 | 11 | for path in "$(pwd)"/results/**/*.jsonl; do 12 | # if the file size is greater than 10MB 13 | file_size_mb=$(du -m "$path" | cut -f1) 14 | echo "Size of $path: $file_size_mb MB" 15 | if [ $file_size_mb -lt 10 ]; then 16 | echo "File size is less than 10MB. Skipping..." 
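        # `du -m` reports sizes in whole mebibytes rounded up, so this `continue`
        # skips any model-output file reported as under 10 MB instead of re-scoring it.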
17 | continue 18 | fi 19 | yes | python repoqa/compute_score.py --model-output-path $path 20 | done 21 | -------------------------------------------------------------------------------- /scripts/misc/estimate_max_char.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | def main(model="deepseek-ai/deepseek-coder-6.7b-instruct", max_tokens=8 * 2**10): 7 | import time 8 | 9 | import openai 10 | 11 | client = openai.OpenAI( # Note: if you need UIUC VPN or UIUC network to access the server! 12 | api_key="none", base_url="http://ise-dynamo.cs.illinois.edu:8888/v1" 13 | ) 14 | 15 | prompt = "def " 16 | 17 | tstart = time.time() 18 | responses = client.completions.create( 19 | model=model, prompt=prompt, n=1, max_tokens=max_tokens, temperature=0 20 | ) 21 | print(f"Time taken: {time.time() - tstart:.1f}s") 22 | print("Finish reason:", responses.choices[0].finish_reason) 23 | print( 24 | "Estimated max context length in #chars is: ", 25 | len(prompt) + len(responses.choices[0].text), 26 | ) 27 | 28 | 29 | if __name__ == "__main__": 30 | from fire import Fire 31 | 32 | Fire(main) 33 | -------------------------------------------------------------------------------- /scripts/misc/repo_token_size.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # Estimate the maximum token size for a repository 6 | 7 | import json 8 | 9 | import matplotlib.pyplot as plt 10 | from transformers import AutoTokenizer 11 | 12 | from repoqa.utility import topological_sort 13 | 14 | COMMENT_PREFIX = { 15 | "python": "#", 16 | "java": "//", 17 | "typescript": "//", 18 | "rust": "//", 19 | "cpp": "//", 20 | } 21 | 22 | 23 | def get_max_token_size(dataset_path: str, model: str): 24 | with open(dataset_path, "r") as f: 25 | repo = json.load(f) 26 | 27 | tokenizer = AutoTokenizer.from_pretrained(model) 28 | 29 | min_token_size = 1e9 30 | min_token_size_repo = None 31 | token_sizes = [] 32 | 33 | for lang, repos in repo.items(): 34 | for repo in repos: 35 | if "dependency" not in repo: 36 | print( 37 | f"⚠️ Skipping {repo['repo']} ({lang}) as it does not have `dependency` -- do dependency analysis first" 38 | ) 39 | continue 40 | 41 | ordered_paths = topological_sort(repo["dependency"]) 42 | 43 | bigfile = "" 44 | for path in ordered_paths: 45 | bigfile += ( 46 | COMMENT_PREFIX[lang] 47 | + f" current path: {path}\n" 48 | + repo["content"][path] 49 | ) 50 | 51 | # estimate the maximum token size 52 | token_size = tokenizer(bigfile, return_tensors="pt")["input_ids"].shape[1] 53 | token_sizes.append(token_size) 54 | min_token_size = min(min_token_size, token_size) 55 | if min_token_size == token_size: 56 | min_token_size_repo = repo["repo"] 57 | print(f"[{lang}] {repo['repo']:<32}: {token_size:>20} tokens") 58 | 59 | print(f"Estimated minimum token size: {min_token_size} by {min_token_size_repo}") 60 | 61 | # visualize the distribution 62 | plt.figure(figsize=(8, 4)) 63 | plt.hist(token_sizes, bins=64) 64 | # xtick at every 20k 65 | unit = 100 * 1000 66 | plt.xticks(range(0, max(token_sizes) + 1, unit)) 67 | plt.xlim(0, max(token_sizes) + 1000) 68 | # xtick using "k" 69 | plt.gca().xaxis.set_major_formatter( 70 | plt.FuncFormatter(lambda x, _: f"{x / unit:.0f}") 71 | ) 72 | plt.xlabel("Token size (100k)") 73 | plt.ylabel("Frequency") 74 | 
plt.title("Token size distribution") 75 | # compact layout 76 | plt.tight_layout() 77 | plt.savefig("token_size_distribution.png", dpi=164, bbox_inches="tight") 78 | 79 | 80 | if __name__ == "__main__": 81 | from fire import Fire 82 | 83 | Fire(get_max_token_size) 84 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = repoqa 3 | description = "RepoQA for Evaluating Long-Context Code Understanding" 4 | long_description = file: README.md 5 | long_description_content_type = text/markdown 6 | url = https://github.com/evalplus/repoqa 7 | license = Apache-2.0 8 | license_files = LICENSE 9 | platform = any 10 | classifiers = 11 | Operating System :: OS Independent 12 | Programming Language :: Python :: 3 13 | License :: OSI Approved :: Apache Software License 14 | 15 | [options] 16 | packages = find: 17 | python_requires = >=3.8 18 | dependency_links = 19 | install_requires = 20 | tempdir>=0.7.1 21 | appdirs>=1.4.4 22 | wget>=3.2 23 | fire>=0.6.0 24 | nltk>=3.8.1 25 | rich>=13.5.2 26 | numpy>=1.25.2 27 | tree_sitter<=0.21.3 28 | tree_sitter_languages>=1.10.2 29 | transformers>=4.40.0 30 | openai>=1.23.2 31 | anthropic>=0.25.6 32 | google-generativeai>=0.5.2 33 | 34 | [options.entry_points] 35 | console_scripts = 36 | repoqa.search_needle_function = repoqa.search_needle_function:main 37 | repoqa.compute_score = repoqa.compute_score:main 38 | 39 | [options.extras_require] 40 | vllm = vllm>=0.3.3 41 | --------------------------------------------------------------------------------