├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── docs
│   ├── curate_dataset.md
│   └── dev_note.md
├── pyproject.toml
├── repoqa
│   ├── __init__.py
│   ├── compute_score.py
│   ├── data.py
│   ├── metric.py
│   ├── provider
│   │   ├── __init__.py
│   │   ├── anthropic.py
│   │   ├── base.py
│   │   ├── google.py
│   │   ├── hf.py
│   │   ├── openai.py
│   │   ├── request
│   │   │   ├── __init__.py
│   │   │   ├── anthropic.py
│   │   │   ├── google.py
│   │   │   └── openai.py
│   │   └── vllm.py
│   ├── search_needle_function.py
│   └── utility.py
├── requirements.txt
├── results
│   ├── .gitignore
│   └── README.md
├── scripts
│   ├── cherrypick
│   │   ├── README.md
│   │   └── lists.json
│   ├── curate
│   │   ├── dataset_ensemble_clone.py
│   │   ├── dataset_ensemble_gh_api.py
│   │   ├── dep_analysis
│   │   │   ├── cpp.py
│   │   │   ├── data
│   │   │   │   └── .gitignore
│   │   │   ├── go-analysis
│   │   │   │   └── dependency_analysis.go
│   │   │   ├── go.py
│   │   │   ├── java-analysis
│   │   │   │   ├── dependency-reduced-pom.xml
│   │   │   │   ├── java-lib
│   │   │   │   │   └── java-analysis-1.0-SNAPSHOT.jar
│   │   │   │   ├── pom.xml
│   │   │   │   └── src
│   │   │   │       └── main
│   │   │   │           └── java
│   │   │   │               └── edu
│   │   │   │                   └── cs
│   │   │   │                       └── illinois
│   │   │   │                           └── repoqa
│   │   │   │                               └── DepAnalyze.java
│   │   │   ├── java.py
│   │   │   ├── python.py
│   │   │   ├── rust.py
│   │   │   └── typescript.py
│   │   ├── function_analysis.py
│   │   ├── github_fetch.py
│   │   ├── merge_annotation.py
│   │   ├── merge_dep.py
│   │   ├── needle_annotation.py
│   │   ├── needle_selection.py
│   │   ├── obfuscate_nl.py
│   │   ├── requirements.txt
│   │   └── utility.py
│   ├── demos
│   │   └── model_request_oai.py
│   ├── dev
│   │   └── license-hdr.txt
│   ├── eval
│   │   └── recompute_all_scores.sh
│   └── misc
│       ├── estimate_max_char.py
│       └── repo_token_size.py
└── setup.cfg
/.gitignore:
--------------------------------------------------------------------------------
1 | # Customized
2 | .vscode/
3 | *.jsonl
4 | repoqa-*.json
5 | *.bak
6 | repoqa/_version.py
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | .pybuilder/
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 |
92 | # pyenv
93 | # For a library or package, you might want to ignore these files since the code is
94 | # intended to run in multiple environments; otherwise, check them in:
95 | # .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109 | #poetry.lock
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | #pdm.lock
114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115 | # in version control.
116 | # https://pdm.fming.dev/#use-with-ide
117 | .pdm.toml
118 |
119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120 | __pypackages__/
121 |
122 | # Celery stuff
123 | celerybeat-schedule
124 | celerybeat.pid
125 |
126 | # SageMath parsed files
127 | *.sage.py
128 |
129 | # Environments
130 | .env
131 | .venv
132 | env/
133 | venv/
134 | ENV/
135 | env.bak/
136 | venv.bak/
137 |
138 | # Spyder project settings
139 | .spyderproject
140 | .spyproject
141 |
142 | # Rope project settings
143 | .ropeproject
144 |
145 | # mkdocs documentation
146 | /site
147 |
148 | # mypy
149 | .mypy_cache/
150 | .dmypy.json
151 | dmypy.json
152 |
153 | # Pyre type checker
154 | .pyre/
155 |
156 | # pytype static type analyzer
157 | .pytype/
158 |
159 | # Cython debug symbols
160 | cython_debug/
161 |
162 | # PyCharm
163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165 | # and can be added to the global gitignore or merged into this file. For a more nuclear
166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167 | # nuclear option because steven uses PyCharm.
168 | .idea/
169 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | repos:
6 | - repo: https://github.com/pycqa/isort
7 | rev: 5.12.0
8 | hooks:
9 | - id: isort
10 | name: isort (python)
11 | args: ["--profile", "black"]
12 | - repo: https://github.com/psf/black
13 | rev: 22.6.0
14 | hooks:
15 | - id: black
16 | - repo: https://github.com/pre-commit/pre-commit-hooks
17 | rev: v4.3.0
18 | hooks:
19 | - id: check-yaml
20 | - id: end-of-file-fixer
21 | - id: trailing-whitespace
22 | - id: check-added-large-files
23 | args: ["--maxkb=32"]
24 | - id: debug-statements
25 | - repo: https://github.com/Lucas-C/pre-commit-hooks
26 | rev: v1.5.4
27 | hooks:
28 | - id: forbid-tabs
29 | - id: remove-tabs
30 | - id: insert-license
31 | files: \.(sh|yaml|yml|py)$
32 | args: ["--license-filepath=scripts/dev/license-hdr.txt", "--use-current-year"]
33 | exclude: (?x)^(
34 | repo/.*
35 | )$
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
203 | -------------------------------------------------------------------------------
204 | The files under "evalplus/eval/" additionally comply with the MIT License as they are
205 | built on OpenAI's HumanEval work.
206 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RepoQA: Evaluating Long-Context Code Understanding
2 |
3 | [](https://arxiv.org/abs/2406.06025)
4 | [](https://pypi.org/project/repoqa/)
5 |
6 | 🏠 Homepage: https://evalplus.github.io/repoqa.html
7 |
8 | ## 🚀 Installation
9 |
10 | ```bash
11 | # without vLLM (can run openai, anthropic, and huggingface backends)
12 | pip install --upgrade repoqa
13 | # To enable vLLM
14 | pip install --upgrade "repoqa[vllm]"
15 | ```
16 |
17 | ⏬ Install nightly version :: click to expand ::
18 |
19 |
20 | ```bash
21 | pip install --upgrade "git+https://github.com/evalplus/repoqa.git" # without vLLM
22 | pip install --upgrade "repoqa[vllm] @ git+https://github.com/evalplus/repoqa@main" # with vLLM
23 | ```
24 |
25 |
26 |
27 |
28 | ⏬ Using RepoQA as a local repo? :: click to expand ::
29 |
30 |
31 | ```bash
32 | git clone https://github.com/evalplus/repoqa.git
33 | cd repoqa
34 | export PYTHONPATH=$PYTHONPATH:$(pwd)
35 | pip install -r requirements.txt
36 | ```
37 |
38 |
39 |
40 |
41 | ## 🏁 Search Needle Function (SNF)
42 |
43 | Search Needle Function is the first and foundational RepoQA task, which evaluates LLMs' ability at **long-context code understanding and retrieval**.
44 | Its real-life counterpart is precise code search driven by a natural-language function description.
45 |
46 | 🔎 More dataset details :: click to expand ::
47 |
48 |
49 | > [!Note]
50 | >
51 | > SNF includes 500 tests (5 programming languages x 10 repos x 10 needle functions) where an LLM is given:
52 | >
53 | > 1. A large code context sorted by file dependency
54 | > 2. A natural-language description of the needle function that does not reveal keywords such as the function name
55 | > 3. An instruction to retrieve the described function
56 | >
57 | > The evaluator passes a test if the retrieved function is syntactically closer to the ground truth than any other
58 | > function in the repository (parsed by `treesitter`) and the similarity is at least a user-defined threshold (0.8 by default); see the sketch below.
59 |
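For intuition, the pass rule roughly reduces to the sketch below (simplified from `repoqa/compute_score.py`; the `candidates` mapping stands in for the functions parsed from the repository):

```python
# Minimal sketch of the SNF pass rule, simplified from repoqa/compute_score.py.
# `candidates` maps each parsed function name in the repo to its source code.
from repoqa.metric import compute_function_similarity


def snf_pass(model_output: str, ground_truth: str, candidates: dict, threshold: float = 0.8) -> bool:
    best_name, best_similarity = None, 0.0
    for name, source in candidates.items():
        similarity = compute_function_similarity(model_output, source)
        if similarity > best_similarity:
            best_name, best_similarity = name, similarity
    # Pass only if the ground-truth needle is the closest match AND its
    # similarity clears the user-defined threshold.
    return best_name == ground_truth and best_similarity >= threshold
```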
60 |
61 |
62 |
63 | You can run the SNF evaluation using various backends:
64 |
65 | ### OpenAI Compatible Servers
66 |
67 | ```bash
68 | repoqa.search_needle_function --model "gpt-4-turbo" --backend openai
69 | # 💡 If you use an OpenAI-API-compatible server (such as a vLLM server):
70 | # repoqa.search_needle_function --base-url "http://localhost:[PORT]/v1" \
71 | # --model "Qwen/CodeQwen1.5-7B-Chat" --backend openai
72 | ```
73 |
74 | ### Anthropic Compatible Servers
75 |
76 | ```bash
77 | repoqa.search_needle_function --model "claude-3-haiku-20240307" --backend anthropic
78 | ```
79 |
80 | ### vLLM
81 |
82 | ```bash
83 | repoqa.search_needle_function --model "Qwen/CodeQwen1.5-7B-Chat" --backend vllm
84 | ```
85 |
86 | 🔎 Context extension for small-ctx models :: click to expand ::
87 |
88 |
89 | > There are two ways to unlock a model's context at inference time:
90 | >
91 | > 1. **Direct Extension**: Edit `max_position_embeddings` of the model's `config.json` (e.g., `hub/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/[hash]/config.json`) to something like `22528`.
92 | > 2. **[Dynamic RoPE Scaling](https://blog.eleuther.ai/yarn/#dynamic-scaling)**:
93 | > To extend `Meta-Llama-3-8B-Instruct` from 8k to 32k (4x), edit the `config.json`:
94 | >
95 | > `"rope_scaling": {"type": "dynamic", "factor": 4.0}`
96 | >
97 | > Note: This works for vLLM `<0.4.3` and HuggingFace transformers; RepoQA automatically configures dynamic RoPE for vLLM `>= 0.4.3`. A scripted version of this edit is sketched below.
98 |
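If you prefer scripting the `config.json` edit, here is a minimal sketch (the snapshot path is a placeholder; adapt it to your local HuggingFace cache):

```python
# Sketch: patch a local HuggingFace snapshot's config.json to enable dynamic
# RoPE scaling (4x extension: 8k -> 32k). The path below is a placeholder.
import json

config_path = "hub/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/[hash]/config.json"

with open(config_path) as f:
    config = json.load(f)

config["rope_scaling"] = {"type": "dynamic", "factor": 4.0}

with open(config_path, "w") as f:
    json.dump(config, f, indent=2)
```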
99 |
100 |
101 |
102 | > [!Note]
103 | >
104 | > Reference evaluation time:
105 | >
106 | > - Llama3-8B-Instruct: 45 minutes on 2xA6000 (PCIe NVLink)
107 | > - Llama3-70B-Instruct: 100 minutes on 4xA100 (PCIe NVLink)
108 |
109 | ### HuggingFace transformers
110 |
111 | ```bash
112 | repoqa.search_needle_function --model "Qwen/CodeQwen1.5-7B-Chat" --backend hf --trust-remote-code
113 | ```
114 |
115 | > [!Tip]
116 | >
117 | > Installing [flash-attn](https://github.com/Dao-AILab/flash-attention) and
118 | > additionally setting `--attn-implementation "flash_attention_2"` can significantly
119 | > lower the memory requirement.
120 |
121 | 🔨 Having trouble installing `flash-attn`? :: click to expand ::
122 |
123 |
124 | > If you have trouble with `pip install flash-attn --no-build-isolation`,
125 | > you can try to directly use [pre-built wheels](https://github.com/Dao-AILab/flash-attention/releases):
126 | >
127 | > ```shell
128 | > export FLASH_ATTN_VER=2.5.8 # check latest version at https://github.com/Dao-AILab/flash-attention/releases
129 | > export CUDA_VER="cu122" # check available ones at https://github.com/Dao-AILab/flash-attention/releases
130 | > export TORCH_VER=$(python -c "import torch; print('.'.join(torch.__version__.split('.')[:2]))")
131 | > export PY_VER=$(python -c "import platform; print(''.join(platform.python_version().split('.')[:2]))")
132 | > export OS_ARCH=$(python -c "import platform; print(f'{platform.system().lower()}_{platform.machine()}')")
133 | >
134 | > export WHEEL=flash_attn-${FLASH_ATTN_VER}+${CUDA_VER}torch${TORCH_VER}cxx11abiFALSE-cp${PY_VER}-cp${PY_VER}-${OS_ARCH}.whl
135 | > wget https://github.com/Dao-AILab/flash-attention/releases/download/v${FLASH_ATTN_VER}/${WHEEL}
136 | > pip install ${WHEEL}
137 | > ```
138 |
139 |
140 |
141 |
142 | ### Google Generative AI API (Gemini)
143 |
144 | ```bash
145 | repoqa.search_needle_function --model "gemini-1.5-pro-latest" --backend google
146 | ```
147 |
148 | ### CLI Usage
149 |
150 | - **Input**:
151 | - `--model`: Hugging Face model ID, such as `ise-uiuc/Magicoder-S-DS-6.7B`
152 | - `--backend`: `vllm` (default) or `openai`
153 | - `--base-url`: OpenAI API base URL
154 | - `--code-context-size` (default: 16384): Number of tokens (counted with the DeepSeekCoder tokenizer) in the repository context
155 | - `--caching` (default: True): accelerate subsequent runs by caching preprocessing; `--nocaching` to disable
156 | - `--max-new-tokens` (default: 1024): Maximum number of new tokens to generate
157 | - `--system-message` (default: None): System message (note that some models do not support it)
158 | - `--tensor-parallel-size`: Number of GPUs for tensor parallelism (vLLM only)
159 | - `--languages` (default: None): List of languages to evaluate (None means all)
160 | - `--result-dir` (default: "results"): Directory to save the model outputs and evaluation results
161 | - `--clean-ctx-comments` (default: "none"): Clean context comments with padding ("positional_padding") or no padding ("no_padding")
162 | - `--eval-ignore-comments` (default: False): During evaluation, ignore comments in both the ground-truth function and the model output
163 | - `--trust-remote-code` (default: False): allow remote code (for HuggingFace transformers and vLLM)
164 | - `--attn-implementation` (default: None): Use "flash_attention_2" if your HF hits OOM
165 | - **Output**:
166 | - `results/ntoken_{code-context-size}/{model}.jsonl`: Model-generated outputs, one JSON object per line (see the loader sketch below)
167 | - `results/ntoken_{code-context-size}/{model}-SCORES.json`: Evaluation results
168 |
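To inspect the generated `.jsonl` yourself, a minimal loader (mirroring how `repoqa.compute_score` reads it; the path is illustrative):

```python
# Minimal loader for the model-output .jsonl: one JSON object per line,
# read the same way repoqa/compute_score.py does. The path is illustrative.
import json

outputs = []
with open("results/ntoken_16384/your-model.jsonl") as f:
    for line in f:
        outputs.append(json.loads(line))

print(f"Loaded {len(outputs)} task outputs")
```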
169 | ### Compute Scores
170 |
171 | By default, the `repoqa.search_needle_function` command will evaluate model outputs and compute scores after text generation.
172 | However, you can also separately compute scores using the following command:
173 |
174 | ```shell
175 | repoqa.compute_score --model-output-path={model-output}.jsonl
176 | ```
177 |
178 | > [!Tip]
179 | >
180 | > - **Input**: Path to the model-generated outputs.
181 | > - **Output**: The evaluation scores will be stored in `{model-output}-SCORES.json`. The same scorer can also be called from Python, as sketched below.
182 |
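A minimal sketch of calling the scorer from Python (based on the `compute_main` entrypoint in `repoqa/compute_score.py`; the output path is illustrative):

```python
# Sketch: call the scoring entrypoint that the repoqa.compute_score CLI wraps.
from repoqa.compute_score import compute_main

compute_main(
    model_output_path="results/ntoken_16384/your-model.jsonl",  # illustrative path
    ignore_comments=False,  # mirrors --eval-ignore-comments
)
```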
183 | ## 📚 Read More
184 |
185 | - [RepoQA Homepage](https://evalplus.github.io/repoqa.html)
186 | - [RepoQA Dataset Curation](docs/curate_dataset.md)
187 | - [RepoQA Development Notes](docs/dev_note.md)
188 |
189 | ## Citation
190 |
191 | ```bibtex
192 | @article{repoqa,
193 | title = {RepoQA: Evaluating Long Context Code Understanding},
194 | author = {Liu, Jiawei and Tian, Jia Le and Daita, Vijay and Wei, Yuxiang and Ding, Yifeng and Wang, Yuhan Katherine and Yang, Jun and Zhang, Lingming},
195 | year = {2024},
196 | journal = {arXiv preprint arXiv:2406.06025},
197 | }
198 | ```
199 |
--------------------------------------------------------------------------------
/docs/curate_dataset.md:
--------------------------------------------------------------------------------
1 | # RepoQA Dataset Curation
2 |
3 | ## Search Needle Functions
4 |
5 | ### Step 1: Cherry-pick repositories
6 |
7 | See [scripts/cherrypick/README.md](../scripts/cherrypick/README.md) for more information.
8 |
9 |
10 | > [!Tip]
11 | >
12 | > **Output**: Extend `scripts/cherrypick/lists.json` for a programming language.
13 |
14 |
15 | ### Step 2: Extract repo content
16 |
17 | ```shell
18 | python scripts/curate/dataset_ensemble_clone.py
19 | ```
20 |
21 | > [!Tip]
22 | >
23 | > **Output**: `repoqa-{datetime}.json`, which adds a `"content"` field (a map from file path to file content) for each repo.
24 |
25 |
26 | ### Step 3: Dependency analysis
27 |
28 | Check [scripts/curate/dep_analysis](../scripts/curate/dep_analysis) for more information.
29 |
30 | ```shell
31 | python scripts/curate/dep_analysis/{language}.py # python
32 | ```
33 |
34 | > [!Tip]
35 | >
36 | > **Output**: `{language}.json` (e.g., `python.json`) with a list of items of the form `{"repo": ..., "commit_sha": ..., "dependency": ...}`, where `dependency` maps each file path to the paths it imports (illustrated below).
37 |
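For illustration, one entry of `python.json` could look like the following (repo name, commit SHA, and paths are hypothetical):

```python
# Hypothetical entry of python.json; every value below is made up.
example_entry = {
    "repo": "example-org/example-repo",
    "commit_sha": "0123456789abcdef0123456789abcdef01234567",
    "dependency": {
        # each file path maps to the repo-internal paths it imports
        "pkg/api.py": ["pkg/utils.py", "pkg/models.py"],
        "pkg/models.py": [],
    },
}
```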
38 | > [!Note]
39 | >
40 | > The `{language}.json` should be uploaded as a release.
41 | >
42 | > To fetch the release, go to `scripts/curate/dep_analysis/data` and run `gh release download dependency --pattern "*.json" --clobber`.
43 |
44 |
45 | ### Step 4: Merge step 2 and step 3
46 |
47 | ```shell
48 | python scripts/curate/merge_dep.py --dataset-path repoqa-{datetime}.json
49 | ```
50 |
51 | > [!Tip]
52 | >
53 | > **Input**: Download the dependency files into `scripts/curate/dep_analysis/data`.
54 | >
55 | > **Output**: Update `repoqa-{datetime}.json` by adding a `"dependency"` field for each repository.
56 |
57 |
58 | ### Step 5: Function collection with TreeSitter
59 |
60 | ```shell
61 | # collect functions (in-place)
62 | python scripts/curate/function_analysis.py --dataset-path repoqa-{datetime}.json
63 | # select needles (in-place)
64 | python scripts/curate/needle_selection.py --dataset-path repoqa-{datetime}.json
65 | ```
66 |
67 | > [!Tip]
68 | >
69 | > **Output**: `--dataset-path` (in-place) by adding a `"functions"` field (a map from file path to a list of function information) for each repo; see the tree-sitter sketch below.
70 |
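For reference, function collection relies on tree-sitter; below is a minimal parsing sketch that mirrors the queries used in `repoqa/compute_score.py` (the source snippet is illustrative, and it assumes `python` is among the configured languages):

```python
# Sketch: extract function nodes with tree-sitter, mirroring the queries used
# in repoqa/compute_score.py. The source snippet below is illustrative.
from tree_sitter_languages import get_language, get_parser

from repoqa.utility import FUNCTION_QUERY

lang = "python"
source = b"def add(a, b):\n    return a + b\n"

parser = get_parser(lang)
tree = parser.parse(source)
query = get_language(lang).query(FUNCTION_QUERY[lang])

for node, _ in query.captures(tree.root_node):
    start_line, end_line = node.start_point[0], node.end_point[0] + 1
    print(start_line, end_line, source[node.start_byte : node.end_byte].decode())
```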
71 |
72 | ### Step 6: Annotate each function with description to make a final dataset
73 |
74 | ```shell
75 | python scripts/curate/needle_annotation.py --dataset-path repoqa-{datetime}.json
76 | ```
77 |
78 | > [!Tip]
79 | >
80 | > You need to set the `OPENAI_API_KEY` environment variable to run GPT-4. You can also enable `--use-batch-api` to save some cost.
81 | >
82 | > **Output**: `--output-desc-path` is a separate JSON file specifying the function annotations along with their sources.
83 |
84 |
85 | ### Step 7: Merge needle description to the final dataset
86 |
87 | ```shell
88 | python scripts/curate/merge_annotation.py --dataset-path repoqa-{datetime}.json --annotation-path {output-desc-path}.jsonl
89 | ```
90 |
91 | > [!Tip]
92 | >
93 | > **Output**: `--dataset-path` (in-place) by adding a `"description"` field for each needle function.
94 |
--------------------------------------------------------------------------------
/docs/dev_note.md:
--------------------------------------------------------------------------------
1 | # RepoQA Development Notes
2 |
3 | ## DEV Structure
4 |
5 | - `repo`: entrypoint for working repositories
6 | - `repoqa`: source code for the RepoQA evaluation library
7 | - `scripts`: scripts for maintaining the repository and other utilities
8 | - `dev`: scripts for CI/CD and repository maintenance
9 | - `curate`: code for dataset curation
10 | - `dep_analysis`: dependency analysis for different programming languages
11 | - `cherrypick`: cherry-picked repositories for evaluation
12 | - `demos`: demos to quickly use some utility functions such as requesting LLMs
13 |
14 | ## Development Beginner Notice
15 |
16 | ### After clone
17 |
18 | ```shell
19 | pip install pre-commit
20 | pre-commit install
21 | pip install -r requirements.txt
22 | pip install -r scripts/curate/requirements.txt
23 | ```
24 |
25 |
26 | ### Import errors?
27 |
28 | ```shell
29 | # Go to the root path of RepoQA
30 | export PYTHONPATH=$PYTHONPATH:$(pwd)
31 | ```
32 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.setuptools_scm]
6 | write_to = "repoqa/_version.py"
7 | version_scheme = "release-branch-semver"
8 | local_scheme = "no-local-version"
9 |
--------------------------------------------------------------------------------
/repoqa/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | try:
6 | from repoqa._version import __version__, __version_tuple__
7 | except ImportError:
8 | __version__ = "local-dev"
9 |
--------------------------------------------------------------------------------
/repoqa/compute_score.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import itertools
6 | import json
7 | import os
8 | import re
9 | from collections import defaultdict
10 | from datetime import datetime
11 | from enum import Enum
12 | from pathlib import Path
13 | from typing import Dict, List, Tuple, Union
14 |
15 | import numpy as np
16 | import tempdir
17 | from rich.console import Console
18 | from rich.table import Table
19 | from transformers import AutoConfig
20 | from tree_sitter_languages import get_language, get_parser
21 |
22 | from repoqa.data import get_repoqa_data
23 | from repoqa.metric import compute_function_similarity
24 | from repoqa.utility import COMMENT_QUERY, FUNCTION_QUERY, progress
25 |
26 | LANGUAGES = list(FUNCTION_QUERY.keys())
27 | THRESHOLDS = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
28 |
29 |
30 | class Result(Enum):
31 | BEST_MATCH = "best_match"
32 | FAIL_MATCH = "fail_match"
33 |
34 |
35 | # unbiased estimator from https://github.com/openai/human-eval
36 | def estimate_pass_at_k(
37 | num_samples: Union[int, List[int], np.ndarray],
38 | num_correct: Union[List[int], np.ndarray],
39 | k: int,
40 | ) -> np.ndarray:
41 | """
42 | Estimates pass@k of each problem and returns them in an array.
43 | """
44 |
45 | def estimator(n: int, c: int, k: int) -> float:
46 | """
47 | Calculates 1 - comb(n - c, k) / comb(n, k).
48 | """
49 | if n - c < k:
50 | return 1.0
51 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
52 |
53 | if isinstance(num_samples, int):
54 | num_samples_it = itertools.repeat(num_samples, len(num_correct))
55 | else:
56 | assert len(num_samples) == len(num_correct)
57 | num_samples_it = iter(num_samples)
58 |
59 | return np.array(
60 | [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
61 | )
62 |
63 |
64 | def remove_comments(source_code: str, lang: str) -> str:
65 | source_bytes = bytes(source_code, "utf8")
66 | parser = get_parser(lang)
67 | tree = parser.parse(source_bytes)
68 | root_node = tree.root_node
69 |
70 | # Remove comments from source code
71 | capture_list = []
72 | for query_str in COMMENT_QUERY[lang]:
73 | comment_query = get_language(lang).query(query_str)
74 | capture_list += comment_query.captures(root_node)
75 |
76 | capture_list.sort(key=lambda cap: cap[0].start_byte, reverse=True)
77 |
78 | for node, _ in capture_list:
79 | source_bytes = source_bytes[: node.start_byte] + source_bytes[node.end_byte :]
80 |
81 | return source_bytes.decode("utf-8")
82 |
83 |
84 | def sanitize_output(model_output: str, lang: str) -> str:
85 | model_output = model_output.strip()
86 | search_pattern = r"^```(?:\w+)?\s*\n(.*?)(?=^```)```"
87 | code_blocks = re.findall(search_pattern, model_output, re.DOTALL | re.MULTILINE)
88 |
89 | parser = get_parser(lang)
90 | fn_query = get_language(lang).query(FUNCTION_QUERY[lang])
91 |
92 |     # If no code blocks are found, simply return the model output
93 | if not code_blocks:
94 | return model_output
95 |
96 | processed_blocks = []
97 | for block in code_blocks:
98 | processed_blocks.append(block)
99 |
100 | # Try to use tree-sitter to parse if possible
101 | try:
102 | block_bytes = bytes(block, "utf8")
103 | tree = parser.parse(block_bytes)
104 | for capture in fn_query.captures(tree.root_node):
105 | node, _ = capture
106 | function_content = block_bytes[node.start_byte : node.end_byte]
107 | return function_content.decode("utf8")
108 | except:
109 | pass
110 |
111 |     # No valid function found by the tree-sitter approach; return the first block
112 | return processed_blocks[0]
113 |
114 |
115 | def print_result_table(model_name, pass_results):
116 | # Printing scores in a table
117 | table = Table(title=f"Scores (%) of {model_name} at different thresholds")
118 | table.add_column("Threshold", justify="center", style="bold magenta")
119 | for threshold in THRESHOLDS:
120 | table.add_column(f"{threshold}", justify="center")
121 |
122 | # Prepare data to determine the maximum values for each threshold
123 | threshold_scores = {threshold: [] for threshold in THRESHOLDS}
124 | for lang_results in pass_results.values():
125 | for thresh, value in lang_results.items():
126 | threshold_scores[thresh].append(value["pass@1"])
127 |
128 | # Calculate the maximum score for each threshold
129 | max_scores = {
130 | threshold: max(scores) for threshold, scores in threshold_scores.items()
131 | }
132 | min_scores = {
133 | threshold: min(scores) for threshold, scores in threshold_scores.items()
134 | }
135 |
136 | # Fill the table rows
137 | for language, lang_results in pass_results.items():
138 | row = [("⭐" if language == "all" else "") + language]
139 | for threshold, value in lang_results.items():
140 | score = value["pass@1"]
141 | formatted_score = f"{100 * score:.1f}"
142 | if max_scores[threshold] - score < 0.01:
143 | formatted_score = f"[bold green]{formatted_score}[/]"
144 | elif score - min_scores[threshold] < 0.01:
145 | formatted_score = f"[bold red]{formatted_score}[/]"
146 | row.append(formatted_score)
147 | if language == "all":
148 | row = [f"[bold yellow]{r}[/]" for r in row]
149 | table.add_row(*row)
150 |
151 | Console().print(table)
152 |
153 |
154 | def needle_evaluator(
155 | model_output: str,
156 | ground_truth: str,
157 | repo_info: Dict,
158 | lang: str,
159 | ignore_comments: bool,
160 | ) -> Tuple[Result, str, float]:
161 | contents = repo_info["content"]
162 | needles = repo_info["needles"]
163 |
164 | best_target = None
165 | best_similarity = 0
166 | sanitized_output = sanitize_output(model_output, lang)
167 | if ignore_comments:
168 | sanitized_output = remove_comments(sanitized_output, lang)
169 | for needle in needles:
170 | current_path = needle["path"]
171 | current_name = needle["name"]
172 | current_func = "\n".join(
173 | contents[current_path].split("\n")[
174 | needle["start_line"] : needle["end_line"]
175 | ]
176 | )
177 | if ignore_comments:
178 | current_func = remove_comments(current_func, lang)
179 |
180 | current_similarity = compute_function_similarity(sanitized_output, current_func)
181 | if current_similarity > best_similarity:
182 | best_similarity = current_similarity
183 | best_target = current_name
184 |
185 | if best_target == ground_truth:
186 | verdict = Result.BEST_MATCH
187 | else:
188 | verdict = Result.FAIL_MATCH
189 | return verdict, best_target, best_similarity
190 |
191 |
192 | def _get_repo(lang_data: Dict, repo_name: str) -> Dict:
193 | for repo in lang_data:
194 | if repo["repo"] == repo_name:
195 | return repo
196 |
197 |
198 | def compute_language_results(evaluation_result: Dict, all_results: Dict) -> None:
199 | for language, lang_results in evaluation_result.items():
200 | current_result = {}
201 | total = np.array([1 for _ in lang_results])
202 |
203 | for threshold in THRESHOLDS:
204 | correct_result = []
205 | for res in lang_results:
206 | bc = 0
207 | if res["is_best_similar"] and res["best_similar_score"] >= threshold:
208 | bc = 1
209 | correct_result.append(bc)
210 | correct_result = np.array(correct_result)
211 |
212 | pass_at_k = {
213 | f"pass@{k}": estimate_pass_at_k(total, correct_result, k).mean()
214 | for k in [1, 10, 100]
215 | if total.min() >= k
216 | }
217 | current_result[threshold] = pass_at_k
218 | all_results[language] = current_result
219 |
220 |
221 | def fetch_hf_context(model_name: str) -> str:
222 | # Retrieved from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1073
223 | possible_keys = [
224 | # OPT
225 | "max_position_embeddings",
226 | # GPT-2
227 | "n_positions",
228 | # MPT
229 | "max_seq_len",
230 | # ChatGLM2
231 | "seq_length",
232 | # Command-R
233 | "model_max_length",
234 | # Others
235 | "max_sequence_length",
236 | "max_seq_length",
237 | "seq_len",
238 | ]
239 | try:
240 | with tempdir.TempDir() as temp_dir:
241 | config = AutoConfig.from_pretrained(
242 | model_name,
243 | cache_dir=temp_dir,
244 | force_download=True,
245 | trust_remote_code=True,
246 | ).to_dict()
247 | longest_context = 0
248 | for key in possible_keys:
249 | if key in config:
250 | longest_context = max(config[key], longest_context)
251 | if not (longest_context):
252 | return "N/A"
253 | return str(int(longest_context / 1024)) + "k"
254 | except Exception as err:
255 | print(f"fetching failed... Reason:\n{err}")
256 | return "N/A"
257 |
258 |
259 | def compute_score(
260 | model_name: str, dataset: Dict, model_output: List[Dict], ignore_comments: bool
261 | ) -> Dict:
262 | evaluation_result = defaultdict(list)
263 | with progress(f"Scoring {model_name}") as pbar:
264 | for result in pbar.track(model_output):
265 | lang = result["language"]
266 | repo_name = result["repo"]
267 | model_outputs = result["output"]
268 | ground_truth = result["name"]
269 | repo_info = _get_repo(dataset[lang], repo_name)
270 |
271 | model_output = model_outputs[0]
272 | verdict, best_target, best_similarity = needle_evaluator(
273 | model_output, ground_truth, repo_info, lang, ignore_comments
274 | )
275 |
276 | is_best_similar = False
277 | if verdict == Result.BEST_MATCH:
278 | is_best_similar = True
279 |
280 | current_task = {
281 | "repo": repo_name,
282 | "name": ground_truth,
283 | "needle_position": result["position_ratio"],
284 | "is_best_similar": is_best_similar,
285 | "best_similar_score": best_similarity,
286 | "best_target": best_target,
287 | "position": {
288 | "token_start": result["needle_token_start"],
289 | "token_end": result["needle_token_end"],
290 | },
291 | }
292 | evaluation_result[lang].append(current_task)
293 |
294 | # Calculate pass@k
295 | pass_results = {}
296 |
297 | all_langs = []
298 | for lang in evaluation_result:
299 | all_langs += evaluation_result[lang]
300 | total = np.array([1 for _ in all_langs])
301 |
302 | pass_results["all"] = {}
303 | for threshold in THRESHOLDS:
304 | correct_result = []
305 | for res in all_langs:
306 | bc = 0
307 | if res["is_best_similar"] and res["best_similar_score"] >= threshold:
308 | bc = 1
309 | correct_result.append(bc)
310 | correct_result = np.array(correct_result)
311 | pass_at_k = {
312 | f"pass@{k}": estimate_pass_at_k(total, correct_result, k).mean()
313 | for k in [1, 10, 100]
314 | if total.min() >= k
315 | }
316 | pass_results["all"][threshold] = pass_at_k
317 |
318 | compute_language_results(evaluation_result, pass_results)
319 | print_result_table(model_name, pass_results)
320 |
321 | output_json = {}
322 | model_json = {}
323 | model_json["eval_date"] = str(datetime.now())
324 |
325 | # hardcode paid models
326 | if "/" in model_name:
327 | if model_name.startswith("bigcode/starcoder2"):
328 | train_context = "16k"
329 | else:
330 | train_context = fetch_hf_context(model_name)
331 | elif model_name.startswith("gpt-4-turbo") or model_name.startswith("gpt-4o-"):
332 | train_context = "128k"
333 | elif model_name.startswith("gpt-3.5-"):
334 | train_context = "16k"
335 | elif model_name.startswith("gemini-1.5-pro") or model_name.startswith(
336 | "gemini-1.5-flash"
337 | ):
338 | train_context = "1000k"
339 | elif model_name.startswith("gemini-1.0-pro"):
340 | train_context = "32k"
341 | elif model_name.startswith("claude-3-"):
342 | train_context = "200k"
343 | else:
344 | train_context = "N/A"
345 | model_json["train_size"] = train_context
346 | model_json["scores"] = pass_results
347 | model_json["results"] = evaluation_result
348 |
349 | output_json[model_name] = model_json
350 |
351 | return output_json
352 |
353 |
354 | def get_model_name(output_path: str) -> str:
355 | file_name = Path(output_path).stem
356 | segments = file_name.split("_")
357 | output_name = ""
358 | for segment in segments:
359 | if segment == "slash":
360 | output_name += "/"
361 | else:
362 | output_name += segment
363 | return output_name
364 |
365 |
366 | def save_json(output_json, result_path) -> None:
367 | if os.path.isfile(result_path):
368 | decision = ""
369 | while decision.lower() not in ["y", "n"]:
370 | print(f"{result_path} already exists. Press [Y/N] to overwrite or exit...")
371 | decision = input()
372 |
373 | if decision.lower() == "y":
374 | # mv the file to a backup
375 | new_path = result_path + ".bak"
376 | while os.path.isfile(new_path):
377 | new_path += ".bak"
378 | os.rename(result_path, new_path)
379 | print(f"Backup {result_path} to {new_path}")
380 |
381 | if not os.path.isfile(result_path):
382 | with open(result_path, "w") as f:
383 | json.dump(output_json, f)
384 |
385 |
386 | def compute_main(
387 | model_output_path: str, ignore_comments: bool = False, dataset_path: str = None
388 | ):
389 | if dataset_path is None:
390 | dataset = get_repoqa_data()
391 | else:
392 | with open(dataset_path, "r") as dataset_f:
393 | dataset = json.load(dataset_f)
394 |
395 | model_outputs = []
396 | with open(model_output_path, "r") as output_f:
397 | for line in output_f:
398 | model_outputs.append(json.loads(line))
399 |
400 | file_base, _ = os.path.splitext(model_output_path)
401 | result_path = file_base + "-SCORES.json"
402 | model_name = get_model_name(model_output_path)
403 | output_json = compute_score(model_name, dataset, model_outputs, ignore_comments)
404 | save_json(output_json, result_path)
405 |
406 |
407 | def main():
408 | from fire import Fire
409 |
410 | Fire(compute_main)
411 |
412 |
413 | if __name__ == "__main__":
414 | main()
415 |
--------------------------------------------------------------------------------
/repoqa/data.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import gzip
6 | import json
7 | import os
8 |
9 | import tempdir
10 | import wget
11 | from appdirs import user_cache_dir
12 |
13 | CACHE_DIR = user_cache_dir("repoqa")
14 |
15 | REPOQA_DATA_OVERRIDE_PATH = os.getenv("REPOQA_DATA_OVERRIDE_PATH", None)
16 | REPOQA_DATA_VERSION = os.getenv("REPOQA_DATA_VERSION", "2024-06-23")
17 |
18 |
19 | def _get_repoqa_data_ready_path() -> str:
20 | if REPOQA_DATA_OVERRIDE_PATH:
21 | assert os.path.exists(
22 | REPOQA_DATA_OVERRIDE_PATH
23 | ), f"File not found: {REPOQA_DATA_OVERRIDE_PATH}"
24 | return REPOQA_DATA_OVERRIDE_PATH
25 |
26 | gzip_url = f"https://github.com/evalplus/repoqa_release/releases/download/{REPOQA_DATA_VERSION}/repoqa-{REPOQA_DATA_VERSION}.json.gz"
27 | cache_path = os.path.join(CACHE_DIR, f"repoqa-{REPOQA_DATA_VERSION}.json")
28 |     # Check if the RepoQA dataset file exists in CACHE_DIR
29 |     if not os.path.exists(cache_path):
30 |         # Download the RepoQA dataset and cache it as JSON
31 | print(f"Downloading dataset from {gzip_url}")
32 | with tempdir.TempDir() as tmpdir:
33 | gzip_path = os.path.join(tmpdir, f"data.json.gz")
34 | wget.download(gzip_url, gzip_path)
35 |
36 | with gzip.open(gzip_path, "rb") as f:
37 | repoqa_data = f.read().decode("utf-8")
38 |
39 | # create CACHE_DIR if not exists
40 | os.makedirs(CACHE_DIR, exist_ok=True)
41 |             # Write the decompressed dataset file to CACHE_DIR
42 | with open(cache_path, "w") as f:
43 | f.write(repoqa_data)
44 |
45 | return cache_path
46 |
47 |
48 | def get_repoqa_data():
49 | with open(_get_repoqa_data_ready_path(), "r") as f:
50 | return json.load(f)
51 |
--------------------------------------------------------------------------------
/repoqa/metric.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import re
6 |
7 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
8 |
9 |
10 | def compute_function_similarity(
11 | candidate_function: str, reference_function: str
12 | ) -> float:
13 |     candidate_tokens = [item for item in re.split(r"\s+", candidate_function.strip())]
14 |
15 |     reference_tokens = [item for item in re.split(r"\s+", reference_function.strip())]
16 |
17 | chencherry = SmoothingFunction()
18 |
19 | return sentence_bleu(
20 | [reference_tokens], candidate_tokens, smoothing_function=chencherry.method4
21 | )
22 |
--------------------------------------------------------------------------------
/repoqa/provider/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from repoqa.provider.base import BaseProvider
6 |
--------------------------------------------------------------------------------
/repoqa/provider/anthropic.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import os
6 | from typing import List
7 |
8 | from anthropic import Client
9 |
10 | from repoqa.provider.base import BaseProvider
11 | from repoqa.provider.request.anthropic import make_auto_request
12 |
13 |
14 | class AnthropicProvider(BaseProvider):
15 | def __init__(self, model):
16 | self.model = model
17 | self.client = Client(api_key=os.getenv("ANTHROPIC_KEY"))
18 |
19 | def generate_reply(
20 | self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
21 | ) -> List[str]:
22 | assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
23 | replies = []
24 | for _ in range(n):
25 | reply = make_auto_request(
26 | self.client,
27 | message=question,
28 | model=self.model,
29 | temperature=temperature,
30 | max_tokens=max_tokens,
31 | system_msg=system_msg,
32 | )
33 | replies.append(reply.content[0].text)
34 |
35 | return replies
36 |
--------------------------------------------------------------------------------
/repoqa/provider/base.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from abc import ABC, abstractmethod
6 | from typing import List
7 |
8 |
9 | class BaseProvider(ABC):
10 | @abstractmethod
11 | def generate_reply(
12 | self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
13 | ) -> List[str]:
14 | ...
15 |
--------------------------------------------------------------------------------
/repoqa/provider/google.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import os
6 | from typing import List
7 |
8 | import google.generativeai as genai
9 |
10 | from repoqa.provider.base import BaseProvider
11 | from repoqa.provider.request.google import make_auto_request
12 |
13 |
14 | class GoogleProvider(BaseProvider):
15 | def __init__(self, model):
16 | genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
17 | self.model = model
18 | self.client = genai.GenerativeModel(model)
19 |
20 | def generate_reply(
21 | self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
22 | ) -> List[str]:
23 | assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
24 | replies = make_auto_request(
25 | self.client,
26 | question,
27 | self.model,
28 | n=n,
29 | max_tokens=max_tokens,
30 | temperature=temperature,
31 | system_msg=system_msg,
32 | )
33 |
34 | if len(replies.candidates) != n:
35 | print(f"[WARNING] # replies = {len(replies.candidates)} != {n = }")
36 |
37 | ret_texts = []
38 | for candidate in replies.candidates:
39 | parts = candidate.content.parts
40 | if parts:
41 | ret_texts.append(parts[0].text)
42 | else:
43 | print("Empty response!")
44 | ret_texts.append("")
45 | print(f"{candidate.safety_ratings = }")
46 |
47 | return ret_texts + [""] * (n - len(ret_texts))
48 |
--------------------------------------------------------------------------------
/repoqa/provider/hf.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import List
6 |
7 | import torch
8 | from transformers import AutoModelForCausalLM, AutoTokenizer
9 |
10 | from repoqa.provider.base import BaseProvider
11 | from repoqa.provider.request import construct_message_list, hacky_assistant_stop_seq
12 |
13 |
14 | class HfProvider(BaseProvider):
15 | def __init__(self, model, trust_remote_code=False, attn_implementation=None):
16 | self.tokenizer = AutoTokenizer.from_pretrained(
17 | model, trust_remote_code=trust_remote_code
18 | )
19 | self.hf_model = AutoModelForCausalLM.from_pretrained(
20 | model,
21 | trust_remote_code=trust_remote_code,
22 | attn_implementation=attn_implementation,
23 | torch_dtype="auto",
24 | ).cuda()
25 | self.stop_seq = []
26 | if self.tokenizer.chat_template:
27 | self.stop_seq.append(hacky_assistant_stop_seq(self.tokenizer))
28 |
29 | @torch.inference_mode()
30 | def generate_reply(
31 | self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
32 | ) -> List[str]:
33 | assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
34 |
35 | prompt_tokens = self.tokenizer.apply_chat_template(
36 | construct_message_list(question, system_msg),
37 | return_tensors="pt",
38 | add_generation_prompt=True,
39 | ).cuda()
40 | input_length = prompt_tokens.size(-1)
41 |
42 | gen_args = {"do_sample": False}
43 | if temperature > 0:
44 | gen_args["do_sample"] = True
45 | gen_args["temperature"] = temperature
46 |
47 | output_text = self.hf_model.generate(
48 | input_ids=prompt_tokens,
49 | max_new_tokens=max_tokens,
50 | num_return_sequences=n,
51 | pad_token_id=self.tokenizer.eos_token_id,
52 | use_cache=True,
53 | stop_strings=self.stop_seq,
54 | tokenizer=self.tokenizer,
55 | **gen_args,
56 | )
57 |
58 | gen_strs = [
59 | self.tokenizer.decode(
60 | x[input_length:],
61 | skip_special_tokens=True,
62 | clean_up_tokenization_spaces=False,
63 | )
64 | for x in output_text
65 | ]
66 | return gen_strs
67 |
--------------------------------------------------------------------------------
/repoqa/provider/openai.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import os
6 | from typing import List
7 |
8 | from openai import Client
9 | from transformers import AutoTokenizer
10 |
11 | from repoqa.provider.base import BaseProvider
12 | from repoqa.provider.request import hacky_assistant_stop_seq
13 | from repoqa.provider.request.openai import make_auto_request
14 |
15 |
16 | class OpenAIProvider(BaseProvider):
17 | def __init__(self, model, base_url: str = None):
18 | self.model = model
19 | self.client = Client(
20 | api_key=os.getenv("OPENAI_API_KEY", "none"), base_url=base_url
21 | )
22 | self.stop_seq = []
23 | try:
24 | tokenizer = AutoTokenizer.from_pretrained(model)
25 | if tokenizer.chat_template:
26 | self.stop_seq.append(hacky_assistant_stop_seq(tokenizer))
27 | print("Using stop sequence: ", self.stop_seq)
28 | except:
29 | print("Failed to automatically fetch stop tokens from HuggingFace.")
30 |
31 | def generate_reply(
32 | self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
33 | ) -> List[str]:
34 | assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
35 | replies = make_auto_request(
36 | self.client,
37 | message=question,
38 | model=self.model,
39 | temperature=temperature,
40 | n=n,
41 | max_tokens=max_tokens,
42 | system_msg=system_msg,
43 | stop=self.stop_seq,
44 | )
45 |
46 | return [reply.message.content for reply in replies.choices]
47 |
--------------------------------------------------------------------------------
/repoqa/provider/request/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 |
6 | def construct_message_list(message, system_message=None):
7 | msglist = [{"role": "user", "content": message}]
8 | if system_message:
9 | msglist.insert(0, {"role": "system", "content": system_message})
10 | return msglist
11 |
12 |
13 | def hacky_assistant_stop_seq(tokenizer) -> str:
14 | _magic_string_ = "&==NowOrNever==&Accelerate!!!==&"
15 | return tokenizer.apply_chat_template(
16 | [
17 | {"role": "user", "content": ""},
18 | {"role": "assistant", "content": _magic_string_},
19 | ],
20 | tokenize=False,
21 | ).split(_magic_string_)[-1]
22 |
--------------------------------------------------------------------------------
/repoqa/provider/request/anthropic.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import signal
6 | import time
7 |
8 | import anthropic
9 | from anthropic.types import Message
10 |
11 | from repoqa.provider.request import construct_message_list
12 |
13 |
14 | def make_request(
15 | client: anthropic.Client,
16 | message: str,
17 | model: str,
18 | max_tokens: int = 512,
19 | temperature: float = 1,
20 | system_msg="You are a helpful assistant good at coding.",
21 | **kwargs,
22 | ) -> Message:
23 | return client.messages.create(
24 | model=model,
25 | messages=construct_message_list(message, system_message=system_msg),
26 | max_tokens=max_tokens,
27 | temperature=temperature,
28 | **kwargs,
29 | )
30 |
31 |
32 | def handler(signum, frame):
33 | # swallow signum and frame
34 | raise Exception("end of time")
35 |
36 |
37 | def make_auto_request(client: anthropic.Client, *args, **kwargs) -> Message:
38 | ret = None
39 | while ret is None:
40 | try:
41 | signal.signal(signal.SIGALRM, handler)
42 | signal.alarm(100)
43 | ret = make_request(client, *args, **kwargs)
44 | signal.alarm(0)
45 | except anthropic.RateLimitError:
46 | print("Rate limit exceeded. Waiting...")
47 | signal.alarm(0)
48 | time.sleep(10)
49 | except anthropic.APIConnectionError:
50 | print("API connection error. Waiting...")
51 | signal.alarm(0)
52 | time.sleep(5)
53 | except anthropic.InternalServerError:
54 | print("Internal server error. Waiting...")
55 | signal.alarm(0)
56 | time.sleep(5)
57 | except anthropic.APIError as e:
58 | print("Unknown API error")
59 | print(e)
60 | if (
61 | e.body["error"]["message"]
62 | == "Output blocked by content filtering policy"
63 | ):
64 | raise Exception("Content filtering policy blocked output")
65 | signal.alarm(0)
66 | except Exception as e:
67 | print("Unknown error. Waiting...")
68 | print(e)
69 | signal.alarm(0)
70 | time.sleep(1)
71 | return ret
72 |
--------------------------------------------------------------------------------
/repoqa/provider/request/google.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import signal
6 | import time
7 |
8 | import google.generativeai as genai
9 | from google.api_core.exceptions import GoogleAPICallError, ResourceExhausted
10 |
11 | from repoqa.provider.request import construct_message_list
12 |
13 |
14 | def make_request(
15 | client: genai.GenerativeModel,
16 | message: str,
17 | model: str,
18 | max_tokens: int = 512,
19 | temperature: float = 1,
20 | n: int = 1,
21 | system_msg="You are a helpful assistant good at coding.",
22 | **kwargs,
23 | ) -> genai.types.GenerateContentResponse:
24 | messages = []
25 | if system_msg:
26 | messages.append({"role": "system", "parts": [system_msg]})
27 | messages.append({"role": "user", "parts": [message]})
28 | return client.generate_content(
29 | messages,
30 | generation_config=genai.types.GenerationConfig(
31 | candidate_count=n, max_output_tokens=max_tokens, temperature=temperature
32 | ),
33 | **kwargs,
34 | )
35 |
36 |
37 | def handler(signum, frame):
38 | # swallow signum and frame
39 | raise Exception("end of time")
40 |
41 |
42 | def make_auto_request(*args, **kwargs) -> genai.types.GenerateContentResponse:
43 | ret = None
44 | while ret is None:
45 | try:
46 | signal.signal(signal.SIGALRM, handler)
47 | signal.alarm(100)
48 | ret = make_request(*args, **kwargs)
49 | signal.alarm(0)
50 | except ResourceExhausted as e:
51 | print("Rate limit exceeded. Waiting...", e.message)
52 | signal.alarm(0)
53 | time.sleep(10)
54 | except GoogleAPICallError as e:
55 | print(e.message)
56 | signal.alarm(0)
57 | time.sleep(1)
58 | except Exception as e:
59 | print("Unknown error. Waiting...")
60 | print(e)
61 | signal.alarm(0)
62 | time.sleep(1)
63 | return ret
64 |
--------------------------------------------------------------------------------
/repoqa/provider/request/openai.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import signal
6 | import time
7 |
8 | import openai
9 | from openai.types.chat import ChatCompletion
10 |
11 | from repoqa.provider.request import construct_message_list
12 |
13 |
14 | def make_request(
15 | client: openai.Client,
16 | message: str,
17 | model: str,
18 | max_tokens: int = 512,
19 | temperature: float = 1,
20 | n: int = 1,
21 | system_msg="You are a helpful assistant good at coding.",
22 | **kwargs,
23 | ) -> ChatCompletion:
24 | return client.chat.completions.create(
25 | model=model,
26 | messages=construct_message_list(message, system_message=system_msg),
27 | max_tokens=max_tokens,
28 | temperature=temperature,
29 | n=n,
30 | **kwargs,
31 | )
32 |
33 |
34 | def handler(signum, frame):
35 | # swallow signum and frame
36 | raise Exception("end of time")
37 |
38 |
39 | def make_auto_request(*args, **kwargs) -> ChatCompletion:
40 | ret = None
41 | while ret is None:
42 | try:
43 | signal.signal(signal.SIGALRM, handler)
44 | signal.alarm(100)
45 | ret = make_request(*args, **kwargs)
46 | signal.alarm(0)
47 | except openai.RateLimitError:
48 | print("Rate limit exceeded. Waiting...")
49 | signal.alarm(0)
50 | time.sleep(10)
51 | except openai.APIConnectionError:
52 | print("API connection error. Waiting...")
53 | signal.alarm(0)
54 | time.sleep(5)
55 | except openai.APIError as e:
56 | print(e)
57 | signal.alarm(0)
58 | except Exception as e:
59 | print("Unknown error. Waiting...")
60 | print(e)
61 | signal.alarm(0)
62 | time.sleep(1)
63 | return ret
64 |
--------------------------------------------------------------------------------
/repoqa/provider/vllm.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from typing import List
6 |
7 | from transformers import AutoTokenizer
8 | from vllm import LLM, SamplingParams
9 |
10 | from repoqa.provider.base import BaseProvider
11 | from repoqa.provider.request import construct_message_list, hacky_assistant_stop_seq
12 |
13 |
14 | class VllmProvider(BaseProvider):
15 | def __init__(
16 | self, model, tensor_parallel_size, max_model_len=None, trust_remote_code=False
17 | ):
18 | self.tokenizer = AutoTokenizer.from_pretrained(
19 | model, trust_remote_code=trust_remote_code
20 | )
21 | self.llm = LLM(
22 | model=model,
23 | tensor_parallel_size=tensor_parallel_size,
24 | max_model_len=max_model_len,
25 | trust_remote_code=trust_remote_code,
26 | )
27 | self.stop_seq = []
28 | if self.tokenizer.chat_template:
29 | self.stop_seq.append(hacky_assistant_stop_seq(self.tokenizer))
30 |
31 | def generate_reply(
32 | self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
33 | ) -> List[str]:
34 | assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
35 |
36 | prompt = self.tokenizer.apply_chat_template(
37 | construct_message_list(question, system_msg),
38 | tokenize=False,
39 | add_generation_prompt=True,
40 | )
41 | vllm_outputs = self.llm.generate(
42 | [prompt],
43 | SamplingParams(
44 | temperature=temperature,
45 | max_tokens=max_tokens,
46 | stop=self.stop_seq,
47 | ),
48 | use_tqdm=False,
49 | )
50 |
51 | gen_strs = [x.outputs[0].text for x in vllm_outputs]
52 | return gen_strs
53 |
--------------------------------------------------------------------------------
/repoqa/search_needle_function.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | import json
5 | import os
6 | from enum import Enum
7 | from typing import List, Tuple
8 |
9 | from transformers import AutoTokenizer
10 | from tree_sitter_languages import get_language, get_parser
11 |
12 | from repoqa.compute_score import compute_score, save_json
13 | from repoqa.data import CACHE_DIR, get_repoqa_data
14 | from repoqa.utility import COMMENT_QUERY, progress, topological_sort
15 |
16 | COMMENT_PREFIX = {
17 | "python": "#",
18 | "java": "//",
19 | "typescript": "//",
20 | "rust": "//",
21 | "cpp": "//",
22 | "go": "//",
23 | }
24 |
25 | # Model context below:
26 | TEMPLATE = "instruction\ncode_context\ndescription\ninstruction"
27 |
28 | INSTRUCTION = (
29 | "Based on the function description and code context,"
30 | " please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:"
31 | )
32 |
33 | # Mode to clean context comments
34 | class CleanComment(Enum):
35 | NoClean = "none"
36 | PositionalPadding = "positional_padding"
37 | NoPadding = "no_padding"
38 |
39 |
40 | def _backward_tokenizable_lines(lines, tokenizer, max_tokens):
41 | """Return the text and tokens from bottom to top"""
42 | text = ""
43 | ntokens = 0
44 | is_break = False
45 | for line in reversed(lines):
46 | new_ntokens = len(tokenizer.tokenize(line + "\n"))
47 | if ntokens + new_ntokens > max_tokens:
48 | is_break = True
49 | break
50 | text = line + "\n" + text
51 | ntokens += new_ntokens
52 | return text, ntokens, is_break
53 |
54 |
55 | def _forward_tokenizable_lines(lines, tokenizer, max_tokens):
56 | """Return the text and tokens from top to bottom"""
57 | text = ""
58 | ntokens = 0
59 | is_break = False
60 | for line in lines:
61 | new_ntokens = len(tokenizer.tokenize(line + "\n"))
62 | if ntokens + new_ntokens > max_tokens:
63 | is_break = True
64 | break
65 | text += line + "\n"
66 | ntokens += new_ntokens
67 | if is_break:
68 | text = text + "...\n"
69 | ntokens += len(tokenizer.tokenize("...\n"))
70 | return text, ntokens, is_break
71 |
72 |
73 | def filter_path_comments(capture, context_paths, source_bytes, comment_prefix):
74 | node, _ = capture
75 | text = source_bytes[node.start_byte : node.end_byte]
76 | for path in context_paths:
77 | if text.decode("utf8") == comment_prefix + " Path: " + path:
78 | return False
79 | return True
80 |
81 |
82 | def clean_segment_comments(language, segment, context_paths):
83 | source_bytes = bytes(segment, "utf8")
84 | parser = get_parser(language)
85 | tree = parser.parse(source_bytes)
86 | root_node = tree.root_node
87 |
88 | # Remove comments from source code
89 | capture_list = []
90 | for query_str in COMMENT_QUERY[language]:
91 | comment_query = get_language(language).query(query_str)
92 | capture_list += comment_query.captures(root_node)
93 |
94 |     # Filter out synthetic comments containing path info
95 | filtered_capture = list(
96 | filter(
97 | lambda capture: filter_path_comments(
98 | capture, context_paths, source_bytes, COMMENT_PREFIX[language]
99 | ),
100 | capture_list,
101 | )
102 | )
103 |
104 | filtered_capture.sort(key=lambda cap: cap[0].start_byte, reverse=True)
105 |
106 | for node, _ in filtered_capture:
107 | source_bytes = source_bytes[: node.start_byte] + source_bytes[node.end_byte :]
108 |
109 | return source_bytes.decode("utf-8")
110 |
111 |
112 | # Clean partial context due to context construction
113 | def clean_partial_file(language, whole_file, partial_lines, path):
114 | path_comment = f"{COMMENT_PREFIX[language]} Path: {path}\n"
115 | source_bytes = bytes(whole_file, "utf8")
116 | parser = get_parser(language)
117 | tree = parser.parse(source_bytes)
118 | root_node = tree.root_node
119 |
120 | # Remove comments from source code
121 | capture_list = []
122 | for query_str in COMMENT_QUERY[language]:
123 | comment_query = get_language(language).query(query_str)
124 | capture_list += comment_query.captures(root_node)
125 |
126 | capture_list.sort(key=lambda cap: cap[0].start_byte, reverse=True)
127 |
128 | for node, _ in capture_list:
129 | new_line_count = source_bytes[node.start_byte : node.end_byte].count(b"\n")
130 | source_bytes = (
131 | source_bytes[: node.start_byte]
132 | + b"\n" * new_line_count
133 | + source_bytes[node.end_byte :]
134 | )
135 |
136 | return (
137 | path_comment
138 | + "\n".join(source_bytes.decode("utf-8").split("\n")[: partial_lines - 1])
139 | + "...\n"
140 | )
141 |
142 |
143 | def clean_context_comments(
144 | language: str,
145 | prefix: str,
146 | needle_code: str,
147 | suffix: str,
148 | tokenizer,
149 |     context_paths: List[str],
150 | top_prefix_file: str,
151 | bot_suffix_file: str,
152 | position_ratio: float,
153 | add_padding: bool,
154 | ):
155 | prefix_orig_size = len(tokenizer.tokenize(prefix))
156 | needle_orig_size = len(tokenizer.tokenize(needle_code))
157 | suffix_orig_size = len(tokenizer.tokenize(suffix))
158 |
159 |     # If there are prefix files, the top one might get chopped off, preventing proper parsing,
160 |     # so we fully parse the top prefix file to avoid errors
161 | if top_prefix_file:
162 | second_path = f"{COMMENT_PREFIX[language]} Path: {context_paths[1]}"
163 | prefix_lines = prefix.split("\n")
164 | top_file_lines = 0
165 | lines_after_target = []
166 | target_found = False
167 | for line in prefix_lines:
168 | if target_found:
169 | lines_after_target.append(line)
170 | elif second_path in line:
171 | target_found = True
172 | lines_after_target.append(line)
173 | else:
174 | top_file_lines += 1
175 | top_file_cleaned = clean_partial_file(
176 | language, top_prefix_file, top_file_lines, context_paths[0]
177 | )
178 | rest_files_cleaned = clean_segment_comments(
179 | language, "\n".join(lines_after_target), context_paths
180 | )
181 | prefix_cleaned = top_file_cleaned + rest_files_cleaned
182 | else:
183 | prefix_cleaned = clean_segment_comments(language, prefix, context_paths)
184 | needle_cleaned = needle_code
185 | needle_cleaned = clean_segment_comments(language, needle_code, context_paths)
186 |
187 | # Same for suffix
188 | if bot_suffix_file:
189 | second_path = f"{COMMENT_PREFIX[language]} Path: {context_paths[-1]}"
190 |         suffix_lines = suffix.split("\n")
191 |         bot_file_lines = 0
192 |         lines_before_target = []
193 |         target_found = False
194 |         for line in suffix_lines:
195 |             if target_found:
196 |                 bot_file_lines += 1
197 |             elif second_path in line:
198 |                 target_found = True
199 |                 bot_file_lines += 1
200 |             else:
201 |                 lines_before_target.append(line)
202 |         bot_file_cleaned = clean_partial_file(
203 |             language, bot_suffix_file, bot_file_lines, context_paths[-1]
204 |         )
205 |         rest_files_cleaned = clean_segment_comments(
206 |             language, "\n".join(lines_before_target), context_paths
207 |         )
208 |         suffix_cleaned = rest_files_cleaned + bot_file_cleaned
209 | else:
210 | suffix_cleaned = clean_segment_comments(language, suffix, context_paths)
211 |
212 | if not add_padding:
213 | return prefix_cleaned, needle_cleaned, suffix_cleaned
214 |
215 |     # Calculate the amount of padding needed in the prefix and suffix to maintain the needle position
216 | prefix_clean_size = len(tokenizer.tokenize(prefix_cleaned))
217 | needle_clean_size = len(tokenizer.tokenize(needle_cleaned))
218 | suffix_clean_size = len(tokenizer.tokenize(suffix_cleaned))
219 |
220 |     # Determine how much of the needle padding goes to the prefix & suffix
221 | needle_tokens_removed = needle_orig_size - needle_clean_size
222 | needle_prefix_padding = int(needle_tokens_removed * position_ratio)
223 | needle_suffix_padding = needle_tokens_removed - needle_prefix_padding
224 |
225 |     # Add more padding to compensate for comment removal from the prefix/suffix portions
226 | needle_prefix_padding = int(
227 | (needle_prefix_padding + prefix_orig_size - prefix_clean_size - 1)
228 | )
229 | needle_suffix_padding = int(
230 | (needle_suffix_padding + suffix_orig_size - suffix_clean_size - 1)
231 | )
232 |
233 | prefix_dummy = ""
234 | line = 0
235 | while needle_prefix_padding > 0:
236 | current = f"{COMMENT_PREFIX[language]} Line Number {line}\n"
237 | current_len = len(tokenizer.tokenize(current))
238 | needle_prefix_padding -= current_len
239 | prefix_dummy += current
240 | line += 1
241 | prefix_cleaned = prefix_dummy + "\n" + prefix_cleaned
242 |
243 | suffix_dummy = ""
244 | while needle_suffix_padding > 0:
245 | current = f"{COMMENT_PREFIX[language]} Line Number {line}\n"
246 | current_len = len(tokenizer.tokenize(current))
247 | needle_suffix_padding -= current_len
248 | line += 1
249 | suffix_dummy += current
250 | suffix_cleaned = suffix_cleaned + suffix_dummy + "\n"
251 |
252 | return prefix_cleaned, needle_cleaned, suffix_cleaned
253 |
254 |
255 | def make_code_context(
256 | needle,
257 | file_content_list: List[Tuple[str, str]],
258 | position_ratio: float,
259 | code_context_size: int,
260 | language: str,
261 | clean_comments: CleanComment = CleanComment.NoClean,
262 | ) -> dict:
263 | """
264 | Slice the file_content_list such that:
265 | 1. The slice contains code_context_size tokens
266 |     2. The position of the needle is at position_ratio of the slice*
267 | *May not be achievable if the needle is too close to the beginning or end of the file_content_list
268 | *May not be accurate as we will also insert file names at the beginning of each file
269 |     *Token sizes might not be 100% accurate but should be close enough
270 | """
271 | tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
272 |
273 | needle_file_idx, needle_file_content = [
274 | (i, content)
275 | for i, (f, content) in enumerate(file_content_list)
276 | if f == needle["path"]
277 | ][0]
278 |
279 | needle_code = needle_file_content[needle["start_byte"] : needle["end_byte"]]
280 | ntoken_needle = len(tokenizer.tokenize(needle_code))
281 |
282 |     # Used when the comment-cleaning option is enabled (path comments are skipped)
283 | context_paths = [needle["path"]]
284 | top_prefix_file = None
285 | bot_suffix_file = None
286 |
287 | prefix_size = int(code_context_size * position_ratio - ntoken_needle / 2)
288 | suffix_size = code_context_size - ntoken_needle - prefix_size
289 |
290 | # handling prefix of the needle file
291 | code_prefix, ntokens, is_break = _backward_tokenizable_lines(
292 | [COMMENT_PREFIX[language] + " Path: " + needle["path"]]
293 | + needle_file_content[: needle["start_byte"]].split("\n"),
294 | tokenizer,
295 | prefix_size,
296 | )
297 | prefix_size -= ntokens
298 |
299 | # handling prefix of the previous files
300 | index = needle_file_idx - 1
301 | while not is_break and prefix_size > 0 and index >= 0:
302 | path, content = file_content_list[index]
303 | context_paths.insert(0, path)
304 | top_prefix_file = content
305 | prefix, ntokens, is_break = _forward_tokenizable_lines(
306 | [COMMENT_PREFIX[language] + " Path: " + path] + content.split("\n"),
307 | tokenizer,
308 | prefix_size,
309 | )
310 | code_prefix = prefix + code_prefix
311 | prefix_size -= ntokens
312 | index -= 1
313 |
314 | # handling suffix of the needle file
315 | code_suffix, ntokens, is_break = _forward_tokenizable_lines(
316 | needle_file_content[needle["end_byte"] :].split("\n"), tokenizer, suffix_size
317 | )
318 | suffix_size -= ntokens
319 |
320 | # handling suffix of the next files
321 | index = needle_file_idx + 1
322 | while not is_break and suffix_size > 0 and index < len(file_content_list):
323 | path, content = file_content_list[index]
324 | context_paths.append(path)
325 | bot_suffix_file = content
326 | suffix, ntokens, is_break = _forward_tokenizable_lines(
327 | [COMMENT_PREFIX[language] + " Path: " + path] + content.split("\n"),
328 | tokenizer,
329 | suffix_size,
330 | )
331 | code_suffix += suffix
332 | suffix_size -= ntokens
333 | index += 1
334 |
335 | # Remove the comments in code_prefix, needle_code, code_suffix and
336 | # pad the code_prefix and code_suffix to maintain the position of the needle
337 | if clean_comments != CleanComment.NoClean:
338 | code_prefix, needle_code, code_suffix = clean_context_comments(
339 | language,
340 | code_prefix,
341 | needle_code,
342 | code_suffix,
343 | tokenizer,
344 | context_paths,
345 | top_prefix_file,
346 | bot_suffix_file,
347 | position_ratio,
348 | clean_comments == CleanComment.PositionalPadding,
349 | )
350 |
351 | code_context = code_prefix + needle_code + code_suffix
352 |
353 | needle_token_start = len(tokenizer.tokenize(code_prefix))
354 | needle_token_end = needle_token_start + len(tokenizer.tokenize(needle_code))
355 | code_context_ntokens = needle_token_end + len(tokenizer.tokenize(code_suffix))
356 |
357 | return {
358 | "code_context": code_context,
359 | "needle_token_start": needle_token_start,
360 | "needle_token_end": needle_token_end,
361 | "code_context_ntokens": code_context_ntokens,
362 | }
363 |
364 |
365 | def make_task_id(lang, repo, needle_name):
366 | return f"{lang}::{repo}::{needle_name}"
367 |
368 |
369 | def make_cache_id(lang, repo, needle_name, code_context_size, position_ratio):
370 | return f"{lang}::{repo}::{needle_name}::{code_context_size}::{position_ratio}"
371 |
372 |
373 | def evaluate_model(
374 | model: str,
375 | base_url: str = None,
376 | backend: str = None,
377 | tensor_parallel_size: int = 1,
378 | code_context_size: int = 16 * 1024,
379 | max_new_tokens: int = 1024,
380 | result_dir: str = "results",
381 | languages: List[str] = None,
382 | caching: bool = True, # if enabled, will cache the tasks which can be used to resume
383 | system_message: str = None,
384 | dataset_path: str = None,
385 | clean_ctx_comments: str = "none",
386 | eval_ignore_comments: bool = False, # ignore comments during score computation
387 | trust_remote_code: bool = False,
388 | attn_implementation=None,
389 | ):
390 | if backend is None:
391 | if base_url is not None:
392 | backend = "openai"
393 | else:
394 | backend = "vllm"
395 | print(f"Using {backend} as the backend")
396 | assert backend is not None, "Please specify the backend"
397 |
398 | if dataset_path is not None:
399 | with open(dataset_path) as f:
400 | dataset = json.load(f)
401 | else:
402 | dataset = get_repoqa_data()
403 |
404 | # makedir if not exists
405 | os.makedirs(result_dir, exist_ok=True)
406 | context_size_dir = os.path.join(result_dir, f"ntoken_{code_context_size}")
407 | os.makedirs(context_size_dir, exist_ok=True)
408 | model_output_path = os.path.join(
409 | context_size_dir,
410 | f"{model.replace('/', '_slash_')}.jsonl",
411 | )
412 |
413 | # resume from model_output_file
414 | if os.path.exists(model_output_path):
415 | with open(model_output_path) as f:
416 | model_outputs = [json.loads(line) for line in f]
417 | else:
418 | model_outputs = []
419 |
420 | if clean_ctx_comments == "positional_padding":
421 | clean_ctx_comments = CleanComment.PositionalPadding
422 | elif clean_ctx_comments == "no_padding":
423 | clean_ctx_comments = CleanComment.NoPadding
424 | else:
425 | clean_ctx_comments = CleanComment.NoClean
426 |
427 | # resume tasks from cache if any
428 | # schema: {"cache_id": .., **task}
429 | extra = ""
430 | if clean_ctx_comments != CleanComment.NoClean:
431 | extra += "_clean_cmt"
432 | cache_file = os.path.join(
433 | CACHE_DIR, f"cache{extra}_ntoken_{code_context_size}_v1.jsonl"
434 | )
435 | os.makedirs(CACHE_DIR, exist_ok=True)
436 |
437 | cache = {}
438 | if caching:
439 | print("🔥 Caching enabled")
440 | if os.path.exists(cache_file):
441 | with open(cache_file) as f:
442 | cache = [json.loads(line) for line in f]
443 | cache = {c["cache_id"]: c for c in cache}
444 | # remove the cache_id field in c
445 | for c in cache.values():
446 | c.pop("cache_id")
447 | print(f"Resuming from cache: {cache_file} with {len(cache)} tasks")
448 |
449 | resumed_task_ids = {
450 | make_task_id(r["language"], r["repo"], r["name"]) for r in model_outputs
451 | }
452 |
453 | # for each task we include
454 | # "repo", "name", "language", "path",
455 | # "template", "position_ratio", "description", "instruction", "code_context"
456 | # "needle_token_start", "needle_token_end", "code_context_ntokens"
457 | tasks = []
458 | for lang, repos in dataset.items():
459 | if languages is not None and lang not in languages:
460 | print(f"Skipping {lang} as it is not selected; selected: {languages}")
461 | continue
462 |
463 | print(f"🔥 Preparing code context for {lang}...")
464 | with progress(f"Processing {lang} context") as pbar:
465 | for repo in pbar.track(repos):
466 | # skip if the repo does not have needles
467 | if "needles" not in repo:
468 | pbar.console.print(
469 | f"⚠️ Skipping {repo['repo']} ({lang}) as it does not have `needles` -- do needle analysis first"
470 | )
471 | continue
472 |
473 | ordered_paths = topological_sort(repo["dependency"])
474 | file_content_list = [
475 | (path, repo["content"][path]) for path in ordered_paths
476 | ]
477 | for i, needle in enumerate(repo["needles"]):
478 | task_id = make_task_id(lang, repo["repo"], needle["name"])
479 | if task_id in resumed_task_ids:
480 | pbar.console.print(
481 | f"Skipping {task_id} as it is already in the results"
482 | )
483 | continue
484 |
485 | position_ratio = (i + 0.5) / len(repo["needles"])
486 | cache_id = make_cache_id(
487 | lang,
488 | repo["repo"],
489 | needle["name"],
490 | code_context_size,
491 | position_ratio,
492 | )
493 | if cache_id in cache:
494 | tasks.append(cache[cache_id])
495 | continue
496 |
497 | task = {
498 | "repo": repo["repo"],
499 | "name": needle["name"],
500 | "language": lang,
501 | "path": needle["path"],
502 | "position_ratio": position_ratio,
503 | "description": f"\nFunction Description:{needle['description']}\n",
504 | "instruction": INSTRUCTION,
505 | "template": TEMPLATE,
506 | }
507 | code_context_info = make_code_context(
508 | needle,
509 | file_content_list,
510 | position_ratio=position_ratio,
511 | code_context_size=code_context_size,
512 | language=lang,
513 | clean_comments=clean_ctx_comments,
514 | )
515 | task.update(code_context_info)
516 | tasks.append(task)
517 |
518 | if caching: # cache
519 | with open(cache_file, "a") as f_out:
520 | f_out.write(
521 | json.dumps({"cache_id": cache_id, **task}) + "\n"
522 | )
523 | # filter finished tasks again (in case a cache is used)
524 | tasks = [
525 | task
526 | for task in tasks
527 | if make_task_id(task["language"], task["repo"], task["name"])
528 | not in resumed_task_ids
529 | ]
530 |
531 | if len(tasks) == 0:
532 | print("No tasks to evaluate! Exiting...")
533 | return
534 |
535 | if backend == "openai":
536 | from repoqa.provider.openai import OpenAIProvider
537 |
538 | engine = OpenAIProvider(model, base_url=base_url)
539 | elif backend == "vllm":
540 | from repoqa.provider.vllm import VllmProvider
541 |
542 | engine = VllmProvider(
543 | model,
544 | tensor_parallel_size=tensor_parallel_size,
545 | max_model_len=int(code_context_size * 1.5), # Magic number
546 | trust_remote_code=trust_remote_code,
547 | )
548 | elif backend == "anthropic":
549 | from repoqa.provider.anthropic import AnthropicProvider
550 |
551 | engine = AnthropicProvider(model)
552 | elif backend == "hf":
553 | from repoqa.provider.hf import HfProvider
554 |
555 | engine = HfProvider(
556 | model,
557 | trust_remote_code=trust_remote_code,
558 | attn_implementation=attn_implementation,
559 | )
560 | elif backend == "google":
561 | from repoqa.provider.google import GoogleProvider
562 |
563 | engine = GoogleProvider(model)
564 | else:
565 | raise ValueError(f"Unknown backend: {backend}")
566 |
567 | if not system_message:
568 | print("🔥 System message is disabled")
569 | system_message = None
570 |
571 | with open(model_output_path, "a") as f_out:
572 | with progress(f"Running {model}") as pbar:
573 | for task in pbar.track(tasks):
574 | actual_position_ratio = (
575 | task["needle_token_start"] / task["code_context_ntokens"]
576 | )
577 | pbar.console.print(
578 | f"Searching {task['name']} in {task['repo']} ({task['language']}) -- "
579 | f"position ratio: actual={actual_position_ratio:.2f}, expected={task['position_ratio']}"
580 | )
581 | prompt = ""
582 | for key in task["template"].split("\n"):
583 | prompt += task[key]
584 |
585 | replies = engine.generate_reply(
586 | prompt, n=1, max_tokens=max_new_tokens, system_msg=system_message
587 | )
588 | result = {**task, "output": replies}
589 | f_out.write(json.dumps(result) + "\n")
590 | f_out.flush()
591 | model_outputs.append(result)
592 |
593 | file_base, _ = os.path.splitext(model_output_path)
594 | result_path = file_base + "-SCORES.json"
595 | output_json = compute_score(
596 | model,
597 | dataset,
598 | model_outputs,
599 | eval_ignore_comments or clean_ctx_comments != CleanComment.NoClean,
600 | )
601 | save_json(output_json, result_path)
602 |
603 |
604 | def main():
605 | from fire import Fire
606 |
607 | Fire(evaluate_model)
608 |
609 |
610 | if __name__ == "__main__":
611 | main()
612 |
--------------------------------------------------------------------------------
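For clarity on how `TEMPLATE` is consumed: `evaluate_model` builds each prompt by concatenating the task fields named in the template, in order. A minimal sketch with a hypothetical `task` dict standing in for a real task entry:

```python
# Sketch of the prompt assembly performed in evaluate_model.
# The task dict is a hypothetical stand-in for a real task entry.
TEMPLATE = "instruction\ncode_context\ndescription\ninstruction"

task = {
    "template": TEMPLATE,
    "instruction": "Based on the function description and code context, please retrieve and repeat the exact described function ...",
    "code_context": "# Path: src/example.py\ndef needle():\n    return 42\n",
    "description": "\nFunction Description: Returns the answer.\n",
}

prompt = ""
for key in task["template"].split("\n"):
    prompt += task[key]  # instruction + code_context + description + instruction
print(prompt)
```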
/repoqa/utility.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from rich.progress import (
6 | BarColumn,
7 | MofNCompleteColumn,
8 | Progress,
9 | TextColumn,
10 | TimeElapsedColumn,
11 | )
12 |
13 | FUNCTION_QUERY = {
14 | "python": "(function_definition name: (_)) @fdef",
15 | "java": "(method_declaration name: (_)) @fdef",
16 | "typescript": "(function_declaration name: (_)) @fdef",
17 | "rust": "(function_item name: (_)) @fdef",
18 | "cpp": "(function_definition declarator: (function_declarator declarator: (identifier))) @fdef",
19 | "go": "(function_declaration name: (_)) @fdef",
20 | }
21 |
22 | COMMENT_QUERY = {
23 | "python": [
24 | "(block (expression_statement (string) @docstring))",
25 | "(comment) @comment",
26 | ],
27 | "java": ["(line_comment) @comment", "(block_comment) @comment"],
28 | "cpp": ["(comment) @comment"],
29 | "rust": ["(line_comment) @comment", "(block_comment) @comment"],
30 | "typescript": ["(comment) @comment"],
31 | "go": ["(comment) @comment"],
32 | }
33 |
34 | FUNCTION_NAME_QUERY = {
35 | "python": """
36 | ((function_definition
37 | name: (identifier) @function_name))
38 | """,
39 | "java": """
40 | (method_declaration
41 | name: (identifier) @method_name)
42 | """,
43 | "typescript": """
44 | (function_declaration
45 | name: (identifier) @function_name)
46 | """,
47 | "rust": """
48 | (function_item
49 | name: (identifier) @function_name)
50 | """,
51 | "cpp": """
52 | (function_definition
53 | name: (identifier) @function_name)
54 | """,
55 | }
56 |
57 |
58 | def topological_sort(graph):
59 | # Stack to store the topological order
60 | stack = []
61 | # Set to keep track of visited nodes
62 | visited = set()
63 |
64 | # Recursive function to process nodes
65 | def dfs(node):
66 | # Mark the current node as visited
67 | visited.add(node)
68 | # Recurse for all the vertices adjacent to this vertex
69 | for neighbour in graph.get(node, []):
70 | if neighbour not in visited:
71 | dfs(neighbour)
72 | # Push current vertex to stack which stores the result
73 | stack.append(node)
74 |
75 | # Call the recursive helper function to store the topological sort starting from all vertices one by one
76 | for node in graph:
77 | if node not in visited:
78 | dfs(node)
79 |
80 | return stack
81 |
82 |
83 | def progress(note: str = "processing"):
84 | return Progress(
85 | TextColumn(f"{note} •" + "[progress.percentage]{task.percentage:>3.0f}%"),
86 | BarColumn(),
87 | MofNCompleteColumn(),
88 | TextColumn("•"),
89 | TimeElapsedColumn(),
90 | )
91 |
--------------------------------------------------------------------------------
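As a small illustration of `topological_sort` above, which orders files so that dependencies come before their dependents (the graph is made up for the example and mirrors the shape of a repo's `dependency` field):

```python
# Illustration: a made-up dependency graph shaped like repo["dependency"],
# i.e. {file: [files it depends on]}.
from repoqa.utility import topological_sort

graph = {
    "pkg/a.py": ["pkg/b.py"],
    "pkg/b.py": ["pkg/c.py"],
    "pkg/c.py": [],
}
print(topological_sort(graph))  # ['pkg/c.py', 'pkg/b.py', 'pkg/a.py']
```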
/requirements.txt:
--------------------------------------------------------------------------------
1 | tempdir
2 | appdirs
3 | wget
4 | fire
5 | nltk
6 | rich
7 | numpy
8 | tree_sitter<=0.21.3
9 | tree_sitter_languages
10 | transformers
11 | openai
12 | anthropic
13 | google-generativeai
14 | vllm
15 |
--------------------------------------------------------------------------------
/results/.gitignore:
--------------------------------------------------------------------------------
1 | *.zip
2 | *.json
3 |
--------------------------------------------------------------------------------
/results/README.md:
--------------------------------------------------------------------------------
1 | ## Install GitHub CLI
2 |
3 | See [GitHub CLI](https://github.com/cli/cli) for installation instructions.
4 |
5 | ## Run evaluations
6 |
7 | ```bash
8 | export PYTHONPATH=$(pwd)
9 | ```
10 |
11 | ### vLLM
12 |
13 | Most OSS models can be evaluated with vLLM.
14 | The tricky part is to tune the `--tensor-parallel-size` (TP) according to the model size.
15 |
16 | * `--tensor-parallel-size` can be `1`, `2`, `4`, `8`.
17 | * If your `--tensor-parallel-size` is `2` or `4`, run `nvidia-smi topo -m` to check the GPU connectivity.
18 | * `NV4` means the GPUs are connected with NVLink (which is fast).
19 | * Set `CUDA_VISIBLE_DEVICES` to GPUs with good connectivity.
20 | * Models with < 10B parameters may work with TP-1; 10-20B with TP-2; 20-40B with TP-4; > 40B with TP-8.
21 |
22 | ```bash
23 | python repoqa/search_needle_function.py --model "mistralai/Mistral-7B-Instruct-v0.2" --tensor-parallel-size 1 --backend vllm
24 | ```
25 |
26 | ### OpenAI server
27 |
28 | ```bash
29 | export OPENAI_API_KEY="sk-..." # OAI or DeepSeekCoder key
30 | python repoqa/search_needle_function.py --model "gpt-4o-2024-05-13" --backend openai
31 | python repoqa/search_needle_function.py --model "deepseek-coder" --backend openai
32 | ```
33 |
34 | ### Google Gemini
35 |
36 | ```bash
37 | export GOOGLE_API_KEY="..."
38 | python repoqa/search_needle_function.py --model "gemini-1.5-pro-latest" --backend google
39 | ```
40 |
41 | ## Pull evaluated results
42 |
43 | ```bash
44 | cd results
45 | gh release download dev-results --pattern "*.zip" --clobber
46 | # unzip all zip files
47 | unzip "*.zip"
48 | ```
49 |
50 | ## Update model outputs
51 |
52 | ```bash
53 | cd results
54 | # pull results first
55 | for item in "$(pwd)"/*; do
56 | # Check if the item is a directory
57 | if [ -d "$item" ]; then
58 | # Get the base name of the directory
59 | dir_name=$(basename "$item")
60 | zip -FSR "${dir_name}-output.zip" "$dir_name" "*.jsonl"
61 | zip -FSR "${dir_name}-scores.zip" "$dir_name" "*-SCORES.json"
62 | fi
63 | done
64 | gh release upload dev-results ./*.zip --clobber
65 | ```
66 |
--------------------------------------------------------------------------------
/scripts/cherrypick/README.md:
--------------------------------------------------------------------------------
1 | # GitHub Repository Curation
2 |
3 | ## Step 1
4 |
5 | ### PLs
6 |
7 | ```
8 | python
9 | go
10 | c++
11 | java
12 | typescript
13 | php
14 | rust
15 | ```
16 |
17 | ### GitHub searching
18 |
19 | Search at: https://github.com/search?
20 |
21 | ```
22 | language:{LANGUAGE} stars:>=100 license:mit license:apache-2.0 pushed:>=2024-01-01 size:128..32000 sort:stars
23 | ```
24 |
25 | ### Find repositories by package rankings
26 |
27 | You can also look at package-ranking websites such as https://pypistats.org/top
28 |
29 | ## Step 2
30 |
31 | Select 10 repositories based on the following quality criteria:
32 | 1. Uses unit tests and CI/CD
33 | 2. Is a package
34 | 3. Covers diverse topics (developer tools, machine learning, etc.)
35 | 4. Has enough code (hopefully 10+ files)
36 | 5. Is preferably written in a single language
37 |
38 | ## Step 3
39 |
40 | Follow the format of `lists.json` to fill in the repositories; a small validation sketch follows this file.
41 |
--------------------------------------------------------------------------------
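As a quick sanity check for Step 3, here is a minimal sketch (a hypothetical helper, not part of the repo) that verifies every entry in `lists.json` carries the fields the curation scripts read:

```python
# Hypothetical helper: validate scripts/cherrypick/lists.json entries.
# Checks the fields read by scripts/curate/dataset_ensemble_clone.py and friends.
import json

with open("scripts/cherrypick/lists.json") as f:
    lists = json.load(f)

for lang, repos in lists.items():
    for repo in repos:
        for key in ("repo", "commit_sha", "entrypoint_path", "topic"):
            assert key in repo, f"{lang}: entry {repo.get('repo', '?')} is missing `{key}`"
print("all entries look good")
```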
/scripts/cherrypick/lists.json:
--------------------------------------------------------------------------------
1 | {
2 | "python": [
3 | {
4 | "repo": "psf/black",
5 | "commit_sha": "f03ee113c9f3dfeb477f2d4247bfb7de2e5f465c",
6 | "entrypoint_path": "src/black",
7 | "topic": "formatter"
8 | },
9 | {
10 | "repo": "python-poetry/poetry",
11 | "commit_sha": "21ffd99274d299bd6d2b7b0cdd7ba273fdd941be",
12 | "entrypoint_path": "src/poetry",
13 | "topic": "package-manager"
14 | },
15 | {
16 | "repo": "locustio/locust",
17 | "commit_sha": "6241e9cf0a5c296da98abfaeb9d66d0ad1db2390",
18 | "entrypoint_path": "locust",
19 | "topic": "testing"
20 | },
21 | {
22 | "repo": "pyg-team/pytorch_geometric",
23 | "commit_sha": "af0f5f44d4dfa8accd898df3e8523d06b67940b0",
24 | "entrypoint_path": "torch_geometric",
25 | "topic": "deep-learning"
26 | },
27 | {
28 | "repo": "openai/openai-python",
29 | "commit_sha": "e41abf7b7dbc1e744d167f748e55d4dedfc0dca7",
30 | "entrypoint_path": "src/openai",
31 | "topic": "api"
32 | },
33 | {
34 | "repo": "mlc-ai/mlc-llm",
35 | "commit_sha": "068d5ea9ca556f2f7a9603537b4f966da12b11f6",
36 | "entrypoint_path": "mlc_llm",
37 | "topic": "deep-learning-deployment"
38 | },
39 | {
40 | "repo": "reactive-python/reactpy",
41 | "commit_sha": "4307a09dfa75e6c1b10a285e5ae4bdf0323cd018",
42 | "entrypoint_path": "src/py/reactpy/reactpy",
43 | "topic": "frontend"
44 | },
45 | {
46 | "repo": "marshmallow-code/marshmallow",
47 | "commit_sha": "b8149cec77d16357d11b08f86de3b13e6fe02fa0",
48 | "entrypoint_path": "src/marshmallow",
49 | "topic": "schema"
50 | },
51 | {
52 | "repo": "ethereum/web3.py",
53 | "commit_sha": "257f464801d11431bf09132fb30bbeeb5937542e",
54 | "entrypoint_path": "web3",
55 | "topic": "web3"
56 | },
57 | {
58 | "repo": "Ciphey/Ciphey",
59 | "commit_sha": "5dfbe9330070d4aefc997e5d2648155a136a3757",
60 | "entrypoint_path": "ciphey",
61 | "topic": "cryptography"
62 | }
63 | ],
64 | "cpp": [
65 | {
66 | "repo": "apache/logging-log4cxx",
67 | "commit_sha": "502f5711809e7f48c215164e374a5df62821ed52",
68 | "entrypoint_path": "src/main/cpp",
69 | "topic": "logging"
70 | },
71 | {
72 | "repo": "skypjack/uvw",
73 | "commit_sha": "ba10b27646035594fc1dd9525dede8573b71a7d7",
74 | "entrypoint_path": "src",
75 | "topic": "network"
76 | },
77 | {
78 | "repo": "ClickHouse/clickhouse-cpp",
79 | "commit_sha": "0fb483543b313a0979b4dbd130f834352a034ba8",
80 | "entrypoint_path": "clickhouse",
81 | "topic": "database"
82 | },
83 | {
84 | "repo": "polybar/polybar",
85 | "commit_sha": "11b522c313f7b2b0a10063721ec8b0bf544de6f4",
86 | "entrypoint_path": "src",
87 | "topic": "developer-tool"
88 | },
89 | {
90 | "repo": "drogonframework/drogon",
91 | "commit_sha": "294035beb96f6d474d501c210a1b403b0b0dedb6",
92 | "entrypoint_path": "lib/src",
93 | "topic": "web-framework"
94 | },
95 | {
96 | "repo": "sass/node-sass",
97 | "commit_sha": "6081731aac89ce4612fe4839d4c6329539c0d8e1",
98 | "entrypoint_path": "src",
99 | "topic": "css-compiler"
100 | },
101 | {
102 | "repo": "ml-explore/mlx",
103 | "commit_sha": "91eba8e4856ba9a629408e18a9d0c1e79a22dab0",
104 | "entrypoint_path": "mlx",
105 | "topic": "machine-learning"
106 | },
107 | {
108 | "repo": "scylladb/seastar",
109 | "commit_sha": "b74a027c819ee5b9bfcee4f32edf8f37017bba12",
110 | "entrypoint_path": "src",
111 | "topic": "network"
112 | },
113 | {
114 | "repo": "WasmEdge/WasmEdge",
115 | "commit_sha": "a5e6f2d011776a1d997586b6e9af0ec01efa3642",
116 | "entrypoint_path": "lib",
117 | "topic": "wasm"
118 | },
119 | {
120 | "repo": "oatpp/oatpp",
121 | "commit_sha": "17ef2a7f6c8a932498799b2a5ae5aab2869975c7",
122 | "entrypoint_path": "src/oatpp",
123 | "topic": "web-framework"
124 | }
125 | ],
126 | "java": [
127 | {
128 | "repo": "google/gson",
129 | "commit_sha": "ee61e3f020351f0498dd3ff8f8386b96454a4dfe",
130 | "entrypoint_path": "gson/src/main/java",
131 | "topic": "developer-tool"
132 | },
133 | {
134 | "repo": "square/retrofit",
135 | "commit_sha": "10014c2bb7bd5fd24cf6b5b680e6917ab9a7767b",
136 | "entrypoint_path": "retrofit/src/main/java",
137 | "topic": "web-client"
138 | },
139 | {
140 | "repo": "karatelabs/karate",
141 | "commit_sha": "a378ba4f9e5af11eca1794d0bd55c428c8491bd8",
142 | "entrypoint_path": "karate-core/src/main/java",
143 | "topic": "testing"
144 | },
145 | {
146 | "repo": "Password4j/password4j",
147 | "commit_sha": "ded2c066d53df8dfaa8892a2868e04c7af655775",
148 | "entrypoint_path": "src/main/java",
149 | "topic": "cryptography"
150 | },
151 | {
152 | "repo": "GoogleContainerTools/jib",
153 | "commit_sha": "2e9561664792f4b13f0db741ca94c996cbece949",
154 | "entrypoint_path": "jib-core/src/main/java",
155 | "topic": "build-tool"
156 | },
157 | {
158 | "repo": "microsoft/gctoolkit",
159 | "commit_sha": "1476fc128465991f2c531645c780d3ed7c865f2c",
160 | "entrypoint_path": "parser/src/main/java",
161 | "topic": "log-parser"
162 | },
163 | {
164 | "repo": "palantir/palantir-java-format",
165 | "commit_sha": "b3ef776d8fd7be6e5a53319f7f4a40b7ee0793a4",
166 | "entrypoint_path": "palantir-java-format/src/main/java",
167 | "topic": "formatter"
168 | },
169 | {
170 | "repo": "microsoft/mssql-jdbc",
171 | "commit_sha": "eae6d7b33571c92029169b33378b2405e8df448a",
172 | "entrypoint_path": "src/main/java",
173 | "topic": "database"
174 | },
175 | {
176 | "repo": "google/conscrypt",
177 | "commit_sha": "74c3c1ab7fecf9e50097a77dcac6ff2c7afd2ccf",
178 | "entrypoint_path": "common/src/main/java",
179 | "topic": "security"
180 | },
181 | {
182 | "repo": "apache/flink-ml",
183 | "commit_sha": "f08f2756316c9fbd7fb7b7e39d140bfd3d959b2a",
184 | "entrypoint_path": "flink-ml-core/src/main/java",
185 | "topic": "machine-learning"
186 | }
187 | ],
188 | "typescript": [
189 | {
190 | "repo": "xenova/transformers.js",
191 | "commit_sha": "642743136efa3a481b2b96d3c2c550085540d844",
192 | "entrypoint_path": "src",
193 | "topic": "machine-learning"
194 | },
195 | {
196 | "repo": "expressjs/express",
197 | "commit_sha": "815f799310a5627c000d4a5156c1c958e4947b4c",
198 | "entrypoint_path": "lib",
199 | "topic": "web-server"
200 | },
201 | {
202 | "repo": "axios/axios",
203 | "commit_sha": "751133eb9ed794c6f6634c52f4fe116e33bf5f09",
204 | "entrypoint_path": "lib",
205 | "topic": "web-client"
206 | },
207 | {
208 | "repo": "date-fns/date-fns",
209 | "commit_sha": "5c1adb5369805ff552737bf8017dbe07f559b0c6",
210 | "entrypoint_path": "src",
211 | "topic": "time-library"
212 | },
213 | {
214 | "repo": "langchain-ai/langchainjs",
215 | "commit_sha": "1f4a4498b66af5910d02dd4b9d68d99aa1350548",
216 | "entrypoint_path": "langchain/src",
217 | "topic": "machine-learning"
218 | },
219 | {
220 | "repo": "cheeriojs/cheerio",
221 | "commit_sha": "d0b3c2f6b57cd1f835741175d463963266be0eef",
222 | "entrypoint_path": "src",
223 | "topic": "frontend"
224 | },
225 | {
226 | "repo": "openai/openai-node",
227 | "commit_sha": "018ac718ccf6a96798ef8f91906b3b652aa50919",
228 | "entrypoint_path": "src",
229 | "topic": "api"
230 | },
231 | {
232 | "repo": "umami-software/umami",
233 | "commit_sha": "be7f69fd5d3d711ba650a21013b73c7e17125f52",
234 | "entrypoint_path": "src",
235 | "topic": "analytics"
236 | },
237 | {
238 | "repo": "mrdoob/three.js",
239 | "commit_sha": "df4fb14945c896fdf628981dde31b31ef7e2e0cc",
240 | "entrypoint_path": "src",
241 | "topic": "3d"
242 | },
243 | {
244 | "repo": "firebase/firebase-admin-node",
245 | "commit_sha": "67151e620fbb7bdfc2a017e9a35d51eca035d824",
246 | "entrypoint_path": "src",
247 | "topic": "baas"
248 | }
249 | ],
250 | "rust": [
251 | {
252 | "repo": "rust-bakery/nom",
253 | "commit_sha": "e87c7da9fa2ddb943306369f6c6d9e914256edbc",
254 | "entrypoint_path": "src",
255 | "topic": "parser"
256 | },
257 | {
258 | "repo": "helix-editor/helix",
259 | "commit_sha": "e69292e5eb7b0f727fefa19cffec910718af31b3",
260 | "entrypoint_path": "helix-core/src",
261 | "topic": "code-editor"
262 | },
263 | {
264 | "repo": "rust-lang/cargo",
265 | "commit_sha": "5da28587846d6cb2694e5bb1db1b5ca327285bf7",
266 | "entrypoint_path": "src",
267 | "topic": "package-manager"
268 | },
269 | {
270 | "repo": "cloudflare/pingora",
271 | "commit_sha": "acee67f87020ef41267fe475bcc5cbed44782a06",
272 | "entrypoint_path": "pingora-core/src",
273 | "topic": "networking"
274 | },
275 | {
276 | "repo": "huggingface/candle",
277 | "commit_sha": "33c9b6655459bd1086574cef9ba8f2e72a8804c8",
278 | "entrypoint_path": "candle-core/src",
279 | "topic": "machine-learning"
280 | },
281 | {
282 | "repo": "seanmonstar/warp",
283 | "commit_sha": "ce8114b50cadedac089d496235fe9d18596f944e",
284 | "entrypoint_path": "src",
285 | "topic": "web-server"
286 | },
287 | {
288 | "repo": "serde-rs/serde",
289 | "commit_sha": "6e38afff498d592af4ccac4cb669a86fc789207f",
290 | "entrypoint_path": "serde/src",
291 | "topic": "serialization"
292 | },
293 | {
294 | "repo": "tokio-rs/tracing",
295 | "commit_sha": "908cc432a5994f6e17c8f36e13c217dc40085704",
296 | "entrypoint_path": "tracing-core/src",
297 | "topic": "application-analysis"
298 | },
299 | {
300 | "repo": "AFLplusplus/LibAFL",
301 | "commit_sha": "527b892c1ddcc83207faaa71591e7f33a1a806a7",
302 | "entrypoint_path": "libafl/src",
303 | "topic": "fuzzing"
304 | },
305 | {
306 | "repo": "alacritty/alacritty",
307 | "commit_sha": "d4f2f8577f763df059653dfab733dfe6ddc06913",
308 | "entrypoint_path": "alacritty/src",
309 | "topic": "terminal-emulator"
310 | }
311 | ],
312 | "go": [
313 | {
314 | "repo": "junegunn/fzf",
315 | "commit_sha": "e352b6887849cb6c3c8ae1d98ed357f94273e90a",
316 | "entrypoint_path": "src",
317 | "topic": "command-line tool"
318 | },
319 | {
320 | "repo": "fatedier/frp",
321 | "commit_sha": "eaae212d2d1b17360754afd9432c21640f15c832",
322 | "entrypoint_path": "pkg",
323 | "topic": "reverse proxy"
324 | },
325 | {
326 | "repo": "caddyserver/caddy",
327 | "commit_sha": "4d6370bf92de163a53aec9081c5d5ae6614597a0",
328 | "entrypoint_path": "modules",
329 | "topic": "web server"
330 | },
331 | {
332 | "repo": "nsqio/nsq",
333 | "commit_sha": "2127c0a1ce50a66d61d78496033b15ef5cb5e250",
334 | "entrypoint_path": "internal",
335 | "topic": "messaging platform"
336 | },
337 | {
338 | "repo": "projectdiscovery/nuclei",
339 | "commit_sha": "3dfcec0a36dbf5ebc529ce0478076279cb975b71",
340 | "entrypoint_path": "pkg",
341 | "topic": "vulnerability scanner"
342 | },
343 | {
344 | "repo": "jesseduffield/lazydocker",
345 | "commit_sha": "80af149b023f4d9be68a643412183d389d73368d",
346 | "entrypoint_path": "pkg",
347 | "topic": "docker management"
348 | },
349 | {
350 | "repo": "zeromicro/go-zero",
351 | "commit_sha": "b337ae36e50092488c1899069f860007447fa17c",
352 | "entrypoint_path": "core",
353 | "topic": "cloud-native microservices"
354 | },
355 | {
356 | "repo": "schollz/croc",
357 | "commit_sha": "d81116382fcb9dddb79b02ed4b0da99e7aecb2ab",
358 | "entrypoint_path": "src",
359 | "topic": "data transfer"
360 | },
361 | {
362 | "repo": "iawia002/lux",
363 | "commit_sha": "59f517bb7138acd4fffcced4560197e83a592fc2",
364 | "entrypoint_path": "extractors",
365 | "topic": "video download"
366 | },
367 | {
368 | "repo": "lima-vm/lima",
369 | "commit_sha": "5f6bfc951f5df6f8d21e3e2ccf9693146216ef4d",
370 | "entrypoint_path": "cmd",
371 | "topic": "Linux virtual machines"
372 | }
373 | ]
374 | }
375 |
--------------------------------------------------------------------------------
/scripts/curate/dataset_ensemble_clone.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 | from datetime import datetime
7 |
8 | import git
9 | import tempdir
10 | from fire import Fire
11 | from tqdm.auto import tqdm
12 |
13 | from scripts.curate.utility import lang2suffix
14 |
15 |
16 | def get_files_to_include(gh_repo, entrypoint, lang_suffix):
17 | files_to_include = []
18 | for entry in gh_repo.commit().tree.traverse():
19 | if entry.path.startswith(entrypoint) and any(
20 | [entry.path.endswith(suffix) for suffix in lang_suffix]
21 | ):
22 | files_to_include.append((entry.path, entry.abspath))
23 | return files_to_include
24 |
25 |
26 | def main(
27 | target_path: str = f"repoqa-{datetime.now().isoformat()}.json",
28 | ):
29 | # read /scripts/cherrypick/lists.json
30 | with open("scripts/cherrypick/lists.json") as f:
31 | lists = json.load(f)
32 |
33 | for lang, repos in lists.items():
34 | lang_suffix = lang2suffix[lang]
35 | for repo in tqdm(repos):
36 | repo_name = repo["repo"]
37 | commit_sha = repo["commit_sha"]
38 | entrypoint = repo["entrypoint_path"]
39 |
40 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}")
41 |
42 | if repo.get("content"):
43 | print(f"Skipping {repo_name} as it already has content.")
44 | continue
45 |
46 | with tempdir.TempDir() as temp_dir:
47 | gh_repo = git.Repo.clone_from(
48 | f"https://github.com/{repo_name}.git",
49 | temp_dir,
50 | )
51 | gh_repo.git.checkout(commit_sha)
52 |
53 | files_to_include = get_files_to_include(
54 | gh_repo, entrypoint, lang_suffix
55 | )
56 |
57 | repo["content"] = {}
58 | for path, abspath in files_to_include:
59 | with open(abspath, "r") as f:
60 | repo["content"][path] = f.read()
61 | with open(target_path, "w") as f_out:
62 | json.dump(lists, f_out)
63 |
64 |
65 | if __name__ == "__main__":
66 | Fire(main)
67 |
--------------------------------------------------------------------------------
/scripts/curate/dataset_ensemble_gh_api.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | """
6 | ! Note: not fully usable. You might encounter:
7 | github.GithubException.GithubException: 422 {"message": "Validation Failed", "errors": [{"message": "The listed users and repositories cannot be searched either because the resources do not exist
8 | or you do not have permission to view them.", "resource": "Search", "field": "q", "code": "invalid"}], "documentation_url": "https://docs.github.com/v3/search/"}
9 | """
10 |
11 | import json
12 | import os
13 | from datetime import datetime
14 | from typing import TypedDict
15 |
16 | from fire import Fire
17 | from github import Auth, Github
18 | from github.Repository import Repository
19 | from tqdm.auto import tqdm
20 |
21 | from scripts.curate.utility import lang2suffix
22 |
23 |
24 | class GitHubRepoMeta(TypedDict):
25 | repo_name: str
26 | repo_owner: str
27 | commit_sha: str
28 | repo_size: int
29 |
30 |
31 | class GitHubDocument(GitHubRepoMeta):
32 | timestamp: str
33 | path: str
34 | content: str
35 |
36 |
37 | def main(
38 | target_path: str = f"repoqa-{datetime.now().isoformat()}.json",
39 | ):
40 | token = os.getenv("GITHUB_TOKEN")
41 | assert token is not None, "Make a token at https://github.com/settings/tokens"
42 | auth = Auth.Token(token)
43 |
44 | # read /scripts/cherrypick/lists.json
45 | with open("scripts/cherrypick/lists.json") as f:
46 | lists = json.load(f)
47 |
48 | g = Github(auth=auth, per_page=1)
49 | for lang, repos in lists.items():
50 | lang_suffix = lang2suffix[lang]
51 | for repo in tqdm(repos):
52 | if repo.get("content"):
53 | print(f"Skipping {repo['repo']} as it already has content.")
54 | continue
55 |
56 | repo_name = repo["repo"]
57 | commit_sha = repo["commit_sha"]
58 | entrypoint = repo["entrypoint_path"]
59 | query = f"repo:{repo_name}"
60 |
61 | gh_repos = g.search_repositories(query)
62 | gh_repo: Repository = gh_repos[0]
63 | contents = [
64 | item
65 | for item in gh_repo.get_contents(entrypoint, ref=commit_sha)
66 | if any([item.path.endswith(suffix) for suffix in lang_suffix])
67 | ]
68 |
69 | repo["content"] = {}
70 | for item in contents:
71 | if item.encoding != "base64":
72 | continue
73 | file_content = item.decoded_content.decode("utf-8")
74 | repo["content"][item.path] = file_content
75 |
76 | with open(target_path, "w") as f_out:
77 | json.dump(lists, f_out)
78 |
79 |
80 | if __name__ == "__main__":
81 | Fire(main)
82 |
--------------------------------------------------------------------------------
/scripts/curate/dep_analysis/cpp.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 | import os
7 | from pathlib import Path
8 | from typing import List, Union
9 |
10 | import git
11 | import tempdir
12 | from fire import Fire
13 | from tqdm.auto import tqdm
14 |
15 | from scripts.curate.utility import lang2suffix
16 |
17 | POTENTIAL_INCLUDE_DIR_NAMES = [
18 | "include",
19 | "src",
20 | "lib",
21 | "libs",
22 | "library",
23 | "libraries",
24 | "inc",
25 | ]
26 |
27 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
28 |
29 |
30 | def get_potential_include_dirs(temp_dir: str) -> List[Path]:
31 | include_dir_list = [Path(temp_dir)]
32 | for potential_include_dir_name in POTENTIAL_INCLUDE_DIR_NAMES:
33 | for dir in Path(temp_dir).rglob(potential_include_dir_name):
34 | include_dir_list.append(dir)
35 |
36 | return include_dir_list
37 |
38 |
39 | def search_include(
40 | entrypoint: str,
41 | temp_dir: str,
42 | cpp_file: Path,
43 | include: str,
44 | additional_libs: List[Path] = [],
45 | ) -> Union[Path, None]:
46 | for possible_path in [cpp_file.parent] + additional_libs:
47 | if not (entrypoint in str(possible_path)):
48 | continue
49 | if (possible_path / include).exists():
50 | return (
51 | (possible_path / include)
52 | .resolve()
53 | .relative_to(Path(temp_dir))
54 | .as_posix()
55 | )
56 |
57 |
58 | def get_dependencies(temp_dir: str, entrypoint: str, cpp_file: Path):
59 | user_defined_includes = []
60 | with open(cpp_file) as f:
61 | content = f.read()
62 | for line in content.split("\n"):
63 | line = line.strip()
64 | if line.startswith("#include"):
65 | include = line.split()[1]
66 | elif line.split()[:2] == ["#", "include"]:
67 | include = line.split()[2]
68 | else:
69 | continue
70 |
71 | if (include.startswith('"') or include.startswith("<")) and "." in include:
72 | include = include[1:-1]
73 | rela_include = search_include(
74 | entrypoint,
75 | temp_dir,
76 | cpp_file,
77 | include,
78 | additional_libs=get_potential_include_dirs(temp_dir),
79 | )
80 | if rela_include:
81 | user_defined_includes.append(rela_include)
82 |
83 | return user_defined_includes
84 |
85 |
86 | # The repository lists are read from scripts/cherrypick/lists.json (cf. dataset_ensemble_clone.py)
87 | def main():
88 | with open("scripts/cherrypick/lists.json") as f:
89 | lists = json.load(f)
90 |
91 | lang_suffix = lang2suffix["cpp"]
92 | repos = lists["cpp"]
93 | for repo in tqdm(repos):
94 | repo_name = repo["repo"]
95 | commit_sha = repo["commit_sha"]
96 | entrypoint = repo["entrypoint_path"]
97 |
98 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}")
99 |
100 | if repo.get("dependency"):
101 | print(f"Skipping {repo_name} as it already has dependency field.")
102 | continue
103 |
104 | with tempdir.TempDir() as temp_dir:
105 | gh_repo = git.Repo.clone_from(
106 | f"https://github.com/{repo_name}.git",
107 | temp_dir,
108 | )
109 | gh_repo.git.checkout(commit_sha)
110 |
111 | output_dependency = {}
112 | abs_prefix = os.path.join(temp_dir, entrypoint)
113 | for cpp_ext in lang_suffix:
114 | for cpp_file in Path(abs_prefix).rglob("*" + cpp_ext):
115 | dependencies = get_dependencies(temp_dir, entrypoint, cpp_file)
116 | if len(dependencies) > 0:
117 | if "" in dependencies:
118 | dependencies.remove("")
119 | output_dependency[
120 | cpp_file.relative_to(Path(temp_dir)).as_posix()
121 | ] = dependencies
122 | else:
123 | output_dependency[
124 | cpp_file.relative_to(Path(temp_dir)).as_posix()
125 | ] = []
126 |
127 | repo["dependency"] = output_dependency
128 |
129 | with open(os.path.join(CURRENT_DIR, "data", "cpp.json"), "w") as f_out:
130 | json.dump({"cpp": repos}, f_out)
131 |
132 |
133 | if __name__ == "__main__":
134 | Fire(main)
135 |
--------------------------------------------------------------------------------
/scripts/curate/dep_analysis/data/.gitignore:
--------------------------------------------------------------------------------
1 | *.json
2 |
--------------------------------------------------------------------------------
/scripts/curate/dep_analysis/go-analysis/dependency_analysis.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 |
4 | import (
5 | "encoding/json"
6 | "go/ast"
7 | "go/parser"
8 | "go/token"
9 | "io/fs"
10 | "os"
11 | "path/filepath"
12 | "strings"
13 | )
14 |
15 |
16 | type DependencyMap struct {
17 | Name string `json:"name"`
18 | RepoName map[string][]string `json:"repoName"`
19 | }
20 |
21 |
22 | func main() {
23 | if len(os.Args) != 3 {
24 | os.Stderr.WriteString("Usage: go run main.go <rootDir> <mapName>\n")
25 | os.Exit(1)
26 | }
27 |
28 |
29 | rootDir := os.Args[1] // This should be the 'src' directory
30 | mapName := os.Args[2] // Name passed in while calling from terminal
31 | repoMap := make(map[string][]string)
32 | fset := token.NewFileSet()
33 |
34 |
35 | // Walk through the directory and all its subdirectories
36 | filepath.WalkDir(rootDir, func(path string, d fs.DirEntry, err error) error {
37 | if err != nil {
38 | return err
39 | }
40 |
41 |
42 | if !d.IsDir() && strings.HasSuffix(d.Name(), ".go") {
43 | parsedFile, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
44 | if err != nil {
45 | return err
46 | }
47 |
48 |
49 | // Initialize dependencies as an empty slice instead of nil
50 | dependencies := make(map[string]bool)
51 |
52 |
53 | // Analyze AST and find dependencies
54 | ast.Inspect(parsedFile, func(n ast.Node) bool {
55 | switch x := n.(type) {
56 | case *ast.SelectorExpr:
57 | if ident, ok := x.X.(*ast.Ident); ok {
58 | pkg := ident.Name
59 |
60 |
61 | // Check for local files that may correspond to this identifier
62 | filepath.WalkDir(rootDir, func(depPath string, depInfo fs.DirEntry, depErr error) error {
63 | if depErr != nil {
64 | return depErr
65 | }
66 |
67 |
68 | if !depInfo.IsDir() && strings.TrimSuffix(depInfo.Name(), ".go") == pkg {
69 | relPath, err := filepath.Rel(rootDir, strings.TrimPrefix(depPath, rootDir+string(os.PathSeparator)))
70 | if err == nil {
71 | dependencies[relPath] = true
72 | }
73 | }
74 |
75 |
76 | return nil
77 | })
78 | }
79 | }
80 |
81 |
82 | return true
83 | })
84 |
85 |
86 | // Convert map keys to a slice
87 | deps := make([]string, 0) // Initialize as an empty slice
88 | for dep := range dependencies {
89 | deps = append(deps, dep)
90 | }
91 |
92 |
93 | fileRelPath, _ := filepath.Rel(rootDir, strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
94 | repoMap[fileRelPath] = deps
95 | }
96 |
97 |
98 | return nil
99 | })
100 |
101 |
102 | output := DependencyMap{Name: mapName, RepoName: repoMap}
103 | result, err := json.Marshal(output) // Change back to Marshal for single-line JSON output
104 | if err != nil {
105 | panic(err)
106 | }
107 |
108 |
109 | // Write the result to a file
110 | file, err := os.Create("output.json")
111 | if err != nil {
112 | panic(err)
113 | }
114 | defer file.Close()
115 |
116 |
117 | _, err = file.Write(result)
118 | if err != nil {
119 | panic(err)
120 | }
121 | }
122 |
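123 | // Usage sketch (assumptions: built and invoked from the root of a cloned repository,
124 | // the way scripts/curate/dep_analysis/go.py drives it; results land in ./output.json):
125 | //   go build -o dependency_analysis dependency_analysis.go
126 | //   ./dependency_analysis <entrypoint-dir> <repo-name>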
--------------------------------------------------------------------------------
/scripts/curate/dep_analysis/go.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 | import os
7 | import shutil
8 | import subprocess
9 | from typing import Dict, List
10 |
11 | import git
12 | import tempdir
13 | from fire import Fire
14 | from tqdm.auto import tqdm
15 |
16 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
17 |
18 |
19 | def remove_relative(path: str) -> str:
20 | parts = path.split(os.sep)
21 | filtered_parts = [part for part in parts if part not in (".", "..")]
22 | new_path = os.path.join(*filtered_parts)
23 | return new_path
24 |
25 |
26 | def sanitize_paths(data: Dict[str, List[str]], entrypoint: str) -> Dict[str, List[str]]:
27 | sanitized_data = {}
28 | for file, dependencies in data.items():
29 | updated_file = os.path.join(entrypoint, remove_relative(file))
30 | updated_dependencies = []
31 | for dependency in dependencies:
32 | updated_dependency = os.path.join(entrypoint, remove_relative(dependency))
33 | updated_dependencies.append(updated_dependency)
34 | sanitized_data[updated_file] = updated_dependencies
35 | return sanitized_data
36 |
37 |
38 | def run_dependency_analysis(config_file, go_file):
39 | # Load the JSON configuration
40 | with open(config_file, "r") as file:
41 | data = json.load(file)
42 |
43 | repos = data["go"]
44 |
45 | # Iterate over each repo entry in the JSON configuration
46 | for entry in tqdm(repos):
47 | repo_name = entry["repo"]
48 | commit_sha = entry["commit_sha"]
49 | entrypoint_path = entry["entrypoint_path"]
50 |
51 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}")
52 |
53 | with tempdir.TempDir() as temp_dir:
54 | gh_repo = git.Repo.clone_from(
55 | f"https://github.com/{repo_name}.git",
56 | temp_dir,
57 | )
58 | gh_repo.git.checkout(commit_sha)
59 | shutil.copy(go_file, temp_dir)
60 |
61 | command_list = (
62 | f"go build -o dependency_analysis dependency_analysis.go".split()
63 | )
64 | subprocess.run(command_list, cwd=temp_dir)
65 |
66 | command_list = f"./dependency_analysis {entrypoint_path} {repo_name.split('/')[-1]}".split()
67 |
68 | subprocess.run(command_list, cwd=temp_dir)
69 | output_dir = os.path.join(temp_dir, "output.json")
70 | with open(output_dir, "r") as output_file:
71 | output_data = json.load(output_file)
72 | entry["dependency"] = sanitize_paths(
73 | output_data["repoName"], entrypoint_path
74 | )
75 |
76 | # Write all output data to a file
77 | with open(os.path.join(CURRENT_DIR, "data", "go.json"), "w") as f_out:
78 | json.dump({"go": repos}, f_out)
79 |
80 |
81 | def main():
82 | config_file = "scripts/cherrypick/lists.json"
83 | go_file = "scripts/curate/dep_analysis/go-analysis/dependency_analysis.go"
84 |
85 | run_dependency_analysis(config_file, go_file)
86 |
87 |
88 | if __name__ == "__main__":
89 | Fire(main)
90 |
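91 | # Usage sketch (assumes the Go toolchain is installed and the script is run from the
92 | # repository root so the hard-coded paths in main() resolve):
93 | #   python scripts/curate/dep_analysis/go.py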
--------------------------------------------------------------------------------
/scripts/curate/dep_analysis/java-analysis/dependency-reduced-pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
3 | <modelVersion>4.0.0</modelVersion>
4 | <groupId>edu.cs.illinois.repoqa</groupId>
5 | <artifactId>java-analysis</artifactId>
6 | <version>1.0-SNAPSHOT</version>
7 | <build>
8 | <plugins>
9 | <plugin>
10 | <artifactId>maven-shade-plugin</artifactId>
11 | <version>3.2.4</version>
12 | <executions>
13 | <execution>
14 | <phase>package</phase>
15 | <goals>
16 | <goal>shade</goal>
17 | </goals>
18 | <configuration>
19 | <transformers>
20 | <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
21 | <mainClass>edu.cs.illinois.repoqa.DepAnalyze</mainClass>
22 | </transformer>
23 | </transformers>
24 | </configuration>
25 | </execution>
26 | </executions>
27 | </plugin>
28 | <plugin>
29 | <artifactId>maven-antrun-plugin</artifactId>
30 | <version>3.0.0</version>
31 | <executions>
32 | <execution>
33 | <phase>package</phase>
34 | <goals>
35 | <goal>run</goal>
36 | </goals>
37 | <configuration>
38 | <target>
39 | <!-- ... -->
40 | </target>
41 | </configuration>
42 | </execution>
43 | </executions>
44 | </plugin>
45 | </plugins>
46 | </build>
47 | <properties>
48 | <maven.compiler.source>8</maven.compiler.source>
49 | <maven.compiler.target>8</maven.compiler.target>
50 | </properties>
51 | </project>
52 |
--------------------------------------------------------------------------------
/scripts/curate/dep_analysis/java-analysis/java-lib/java-analysis-1.0-SNAPSHOT.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/evalplus/repoqa/ae876deb1365dbf5a15b0533723c8ed123eee586/scripts/curate/dep_analysis/java-analysis/java-lib/java-analysis-1.0-SNAPSHOT.jar
--------------------------------------------------------------------------------
/scripts/curate/dep_analysis/java-analysis/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 | xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 | xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 | <modelVersion>4.0.0</modelVersion>
6 |
7 | <groupId>edu.cs.illinois.repoqa</groupId>
8 | <artifactId>java-analysis</artifactId>
9 | <version>1.0-SNAPSHOT</version>
10 |
11 | <properties>
12 | <maven.compiler.source>8</maven.compiler.source>
13 | <maven.compiler.target>8</maven.compiler.target>
14 | </properties>
15 |
16 | <dependencies>
17 | <dependency>
18 | <groupId>com.github.javaparser</groupId>
19 | <artifactId>javaparser-core</artifactId>
20 | <version>3.25.9</version>
21 | </dependency>
22 | <dependency>
23 | <groupId>com.github.javaparser</groupId>
24 | <artifactId>javaparser-symbol-solver-core</artifactId>
25 | <version>3.25.9</version>
26 | </dependency>
27 | </dependencies>
28 |
29 | <build>
30 | <plugins>
31 | <plugin>
32 | <groupId>org.apache.maven.plugins</groupId>
33 | <artifactId>maven-shade-plugin</artifactId>
34 | <version>3.2.4</version>
35 | <executions>
36 | <execution>
37 | <phase>package</phase>
38 | <goals>
39 | <goal>shade</goal>
40 | </goals>
41 | <configuration>
42 | <transformers>
43 | <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
44 | <mainClass>edu.cs.illinois.repoqa.DepAnalyze</mainClass>
45 | </transformer>
46 | </transformers>
47 | </configuration>
48 | </execution>
49 | </executions>
50 | </plugin>
51 | <plugin>
52 | <groupId>org.apache.maven.plugins</groupId>
53 | <artifactId>maven-antrun-plugin</artifactId>
54 | <version>3.0.0</version>
55 | <executions>
56 | <execution>
57 | <phase>package</phase>
58 | <goals>
59 | <goal>run</goal>
60 | </goals>
61 | <configuration>
62 | <target>
63 | <!-- ... -->
64 | </target>
65 | </configuration>
66 | </execution>
67 | </executions>
68 | </plugin>
69 | </plugins>
70 | </build>
71 | </project>
72 |
--------------------------------------------------------------------------------
/scripts/curate/dep_analysis/java-analysis/src/main/java/edu/cs/illinois/repoqa/DepAnalyze.java:
--------------------------------------------------------------------------------
1 | package edu.cs.illinois.repoqa;
2 |
3 | import java.io.File;
4 | import java.nio.file.Path;
5 | import java.nio.file.Paths;
6 | import java.util.ArrayList;
7 | import java.util.Collections;
8 | import java.util.HashSet;
9 | import java.util.List;
10 |
11 | import com.github.javaparser.ParserConfiguration;
12 | import com.github.javaparser.ast.CompilationUnit;
13 | import com.github.javaparser.ast.ImportDeclaration;
14 | import com.github.javaparser.ast.expr.NameExpr;
15 | import com.github.javaparser.symbolsolver.JavaSymbolSolver;
16 | import com.github.javaparser.symbolsolver.resolution.typesolvers.CombinedTypeSolver;
17 | import com.github.javaparser.symbolsolver.resolution.typesolvers.ReflectionTypeSolver;
18 | import com.github.javaparser.utils.SourceRoot;
19 |
20 | public class DepAnalyze {
21 | public static String getStartPackage(Path srcRootPath, Path filePath) {
22 | return srcRootPath.relativize(filePath.getParent()).toString().replace(File.separator, ".");
23 | }
24 |
25 | public static void analyze(String repoPath, String entryPoint, String filePath) {
26 | Path srcRootPath = Paths.get(repoPath, entryPoint).toAbsolutePath(); // src/main/java
27 | CombinedTypeSolver combinedTypeSolver = new CombinedTypeSolver();
28 | combinedTypeSolver.add(new ReflectionTypeSolver());
29 | JavaSymbolSolver symbolSolver = new JavaSymbolSolver(combinedTypeSolver);
30 | ParserConfiguration parserConfiguration = new ParserConfiguration().setSymbolResolver(symbolSolver).setLanguageLevel(ParserConfiguration.LanguageLevel.JAVA_14);
31 | SourceRoot sourceRoot = new SourceRoot(srcRootPath, parserConfiguration);
32 | CompilationUnit cu = sourceRoot.parse(getStartPackage(srcRootPath, Paths.get(filePath)), new File(filePath).getName());
33 |
34 | List<Path> depPaths = new ArrayList<>();
35 | depPaths.addAll(getImportDepPaths(cu, srcRootPath));
36 | depPaths.addAll(getSamePackageDepPaths(cu, Paths.get(filePath).toAbsolutePath()));
37 |
38 | depPaths = new ArrayList<>(new HashSet<>(depPaths)); // remove duplicates
39 | Collections.sort(depPaths);
40 | depPaths.remove(Paths.get(repoPath)); // remove the current file from the dependencies
41 | depPaths.forEach(p -> System.out.println(Paths.get(repoPath).relativize(p)));
42 | }
43 |
44 | private static List<Path> getImportDepPaths(CompilationUnit cu, Path srcRootPath) {
45 | List<Path> depPaths = new ArrayList<>();
46 | for (ImportDeclaration importDeclaration : cu.getImports()) {
47 | String importStr = importDeclaration.getNameAsString();
48 | if (!importDeclaration.isAsterisk()) {
49 | Path depPath = srcRootPath.resolve(importStr.replace(".", File.separator) + ".java");
50 | if (depPath.toFile().exists()) {
51 | depPaths.add(depPath);
52 | } else {
53 | Path possibleJavaPath = srcRootPath.resolve(importStr.substring(0, importStr.lastIndexOf(".")).replace(".", File.separator) + ".java");
54 | if (possibleJavaPath.toFile().exists()) {
55 | // this indicates that the import is like "import static com.example.Main.main;"
56 | depPaths.add(possibleJavaPath);
57 | }
58 | }
59 | }
60 | else {
61 | Path depDirPath = srcRootPath.resolve(importStr.replace(".", File.separator));
62 | Path possibleJavaPath = srcRootPath.resolve(importStr.substring(0, importStr.length()).replace(".", File.separator) + ".java");
63 | if (possibleJavaPath.toFile().exists()) {
64 | // this indicates that the import is like "import com.example.Main.*;"
65 | depPaths.add(possibleJavaPath);
66 | continue;
67 | }
68 | if (depDirPath.toFile().exists()) {
69 | File[] files = depDirPath.toFile().listFiles();
70 | for (File file : files) {
71 | if (file.isFile() && file.getName().endsWith(".java")) {
72 | depPaths.add(file.toPath());
73 | }
74 | }
75 | }
76 | }
77 | }
78 |
79 | return depPaths;
80 | }
81 |
82 | private static List<Path> getSamePackageDepPaths(CompilationUnit cu, Path filePath) {
83 | List<Path> depPaths = new ArrayList<>();
84 | List<String> siblingClassSimpleNameList = new ArrayList<>();
85 | for (File siblingFile : filePath.getParent().toFile().listFiles()) {
86 | if (siblingFile.getAbsolutePath().endsWith(".java") && !siblingFile.getAbsolutePath().equals(filePath.toFile().getAbsolutePath())) {
87 | siblingClassSimpleNameList.add(siblingFile.getName().substring(0, siblingFile.getName().length() - 5));
88 | }
89 | }
90 |
91 | // check all identifiers in the current file
92 | for (NameExpr name : cu.findAll(NameExpr.class)) {
93 | String identifier = name.getNameAsString();
94 | if (siblingClassSimpleNameList.contains(identifier)) {
95 | Path depPath = filePath.getParent().resolve(identifier + ".java");
96 | if (depPath.toFile().exists()) {
97 | depPaths.add(depPath);
98 | }
99 | }
100 | }
101 |
102 | return depPaths;
103 | }
104 |
105 | public static void main(String[] args) {
106 | String repoPath = args[0];
107 | String entryPoint = args[1]; // entry point is relative to the repoPath
108 | String filePath = args[2];
109 | analyze(repoPath, entryPoint, filePath);
110 | }
111 | }
112 |
--------------------------------------------------------------------------------
/scripts/curate/dep_analysis/java.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 | import os
7 | import subprocess
8 | from pathlib import Path
9 |
10 | import git
11 | import tempdir
12 | from fire import Fire
13 | from tqdm.auto import tqdm
14 |
15 | TOOL_JAR_PATH = (
16 | Path(__file__).resolve().parent
17 | / "java-analysis"
18 | / "java-lib"
19 | / "java-analysis-1.0-SNAPSHOT.jar"
20 | )
21 |
22 |
23 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
24 |
25 | # Analyze Java import dependencies for the cherry-picked repos (using the prebuilt analysis jar) and write them to data/java.json
26 | def main():
27 | with open("scripts/cherrypick/lists.json") as f:
28 | lists = json.load(f)
29 |
30 | repos = lists["java"]
31 | for repo in tqdm(repos):
32 | repo_name = repo["repo"]
33 | commit_sha = repo["commit_sha"]
34 | entrypoint = repo["entrypoint_path"]
35 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}")
36 |
37 | with tempdir.TempDir() as temp_dir:
38 | gh_repo = git.Repo.clone_from(
39 | f"https://github.com/{repo_name}.git",
40 | temp_dir,
41 | )
42 | gh_repo.git.checkout(commit_sha)
43 |
44 | output_dependency = {}
45 | abs_prefix = os.path.join(temp_dir, entrypoint)
46 | for java_file in Path(abs_prefix).rglob("*.java"):
47 | command_list = f"java -jar {TOOL_JAR_PATH} {temp_dir} {entrypoint} {java_file.absolute()}".split()
48 | # run inside `temp_dir` and capture the output
49 | output = subprocess.check_output(command_list, cwd=temp_dir)
50 | dependencies = output.decode("utf-8").strip().split("\n")
51 |
52 | if dependencies:
53 | if "" in dependencies:
54 | dependencies.remove("")
55 | output_dependency[
56 | java_file.relative_to(Path(temp_dir)).as_posix()
57 | ] = dependencies
58 | else:
59 | output_dependency[
60 | java_file.relative_to(Path(temp_dir)).as_posix()
61 | ] = []
62 |
63 | repo["dependency"] = output_dependency
64 |
65 | with open(os.path.join(CURRENT_DIR, "data", "java.json"), "w") as f_out:
66 | json.dump({"java": repos}, f_out)
67 |
68 |
69 | if __name__ == "__main__":
70 | Fire(main)
71 |
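72 | # Usage sketch (assumes a Java runtime is available and the prebuilt jar exists at
73 | # java-analysis/java-lib/java-analysis-1.0-SNAPSHOT.jar; run from the repository root):
74 | #   python scripts/curate/dep_analysis/java.py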
--------------------------------------------------------------------------------
/scripts/curate/dep_analysis/python.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 | import os
7 | import subprocess
8 |
9 | import git
10 | import tempdir
11 | from fire import Fire
12 | from tqdm.auto import tqdm
13 |
14 | from scripts.curate.dataset_ensemble_clone import get_files_to_include
15 | from scripts.curate.utility import lang2suffix
16 |
17 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
18 |
19 |
20 | # Analyze Python import dependencies (via pydeps) for the cherry-picked repos and write them to data/python.json
21 | def main():
22 | with open("scripts/cherrypick/lists.json") as f:
23 | lists = json.load(f)
24 |
25 | lang_suffix = lang2suffix["python"]
26 | repos = lists["python"]
27 | for repo in tqdm(repos):
28 | repo_name = repo["repo"]
29 | commit_sha = repo["commit_sha"]
30 | entrypoint = repo["entrypoint_path"]
31 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}")
32 |
33 | with tempdir.TempDir() as temp_dir:
34 | gh_repo = git.Repo.clone_from(
35 | f"https://github.com/{repo_name}.git",
36 | temp_dir,
37 | )
38 | gh_repo.git.checkout(commit_sha)
39 | command_list = f"pydeps {entrypoint} --show-deps --no-show".split()
40 | # run inside `temp_dir` and capture the JSON output
41 | output = subprocess.check_output(command_list, cwd=temp_dir)
42 | dependencies = json.loads(output)
43 |
44 | output_dependency = {}
45 | mod2path = {}
46 |
47 | abs_prefix = os.path.join(temp_dir, entrypoint)
48 | for mod, v in dependencies.items():
49 | if v["path"] and v["path"].startswith(abs_prefix):
50 | mod2path[mod] = v["path"]
51 |
52 | for mod, v in dependencies.items():
53 | if v["path"] and v["path"].startswith(abs_prefix):
54 | if "imports" in v:
55 | relative_path = os.path.relpath(v["path"], temp_dir)
56 | output_dependency[relative_path] = [
57 | os.path.relpath(mod2path[imp], temp_dir)
58 | for imp in v["imports"]
59 | if imp in mod2path
60 | ]
61 |
62 | files_to_include = get_files_to_include(gh_repo, entrypoint, lang_suffix)
63 | for path, _ in files_to_include:
64 | if path not in output_dependency:
65 | output_dependency[path] = []
66 |
67 | assert output_dependency, "Empty output_dependency"
68 | repo["dependency"] = output_dependency
69 |
70 | with open(os.path.join(CURRENT_DIR, "data", "python.json"), "w") as f_out:
71 | json.dump({"python": repos}, f_out)
72 |
73 |
74 | if __name__ == "__main__":
75 | Fire(main)
76 |
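77 | # Usage sketch (assumes pydeps is installed and the repository root is on PYTHONPATH
78 | # so the `scripts.curate` imports above resolve):
79 | #   PYTHONPATH=$(pwd) python scripts/curate/dep_analysis/python.py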
--------------------------------------------------------------------------------
/scripts/curate/dep_analysis/rust.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | import json
5 | import os
6 | import subprocess
7 | from enum import Enum
8 | from pathlib import Path
9 | from typing import List, Tuple
10 |
11 | import git
12 | import pygraphviz as pgv
13 | import tempdir
14 | from fire import Fire
15 | from pygraphviz import Node
16 | from tqdm.auto import tqdm
17 |
18 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
19 |
20 |
21 | class CrateType(Enum):
22 | LIB = "lib"
23 | BIN = "bin"
24 |
25 |
26 | # Given a namespace, find its location in the file directory
27 | def find_path(node: Node, temp_dir: str, entrypoint: str, type: CrateType) -> str:
28 | module_path = node.get_name()
29 | temp_dir = Path(temp_dir)
30 | entrypoint = Path(entrypoint)
31 | paths = module_path.split("::")
32 |
33 | current_path = entrypoint
34 | best_path = None
35 |
36 | for path in paths[1:]:
37 | current_path = current_path / path
38 | if (temp_dir / current_path / "mod.rs").exists():
39 | best_path = current_path / "mod.rs"
40 | continue
41 | elif (temp_dir / current_path).with_suffix(".rs").exists():
42 | best_path = current_path.with_suffix(".rs")
43 | break
44 |
45 | if not (best_path):
46 | if type == CrateType.LIB:
47 | return os.path.join(entrypoint, "lib.rs")
48 | elif type == CrateType.BIN:
49 | return os.path.join(entrypoint, "main.rs")
50 | return str(best_path)
51 |
52 |
53 | # Find all buildable packages within the entrypoint
54 | def find_packages(temp_dir: str, entrypoint: str) -> List[Tuple[str, str, CrateType]]:
55 | command_list = f"cargo metadata -q --no-deps --format-version=1".split()
56 | output = subprocess.check_output(command_list, cwd=temp_dir)
57 | decoded_output = output.decode("utf-8")
58 |
59 | result = json.loads(decoded_output)
60 | packages = result["packages"]
61 | results = []
62 | for package in packages:
63 | for target in package["targets"]:
64 | if not (os.path.join(temp_dir, entrypoint) in target["src_path"]):
65 | continue
66 | if target["kind"][0] == "lib":
67 | results.append((target["name"], target["src_path"], CrateType.LIB))
68 | elif target["kind"][0] == "bin":
69 | results.append((target["name"], target["src_path"], CrateType.BIN))
70 | return results
71 |
72 |
73 | def is_cargo_modules_available() -> bool:
74 | try:
75 | command_list = ["cargo-modules", "--help"]
76 | subprocess.run(command_list, capture_output=True)
77 | return True
78 | except (subprocess.CalledProcessError, FileNotFoundError):
79 | return False
80 |
81 |
82 | # Analyze Rust module dependencies (via cargo-modules) for the cherry-picked repos and write them to data/rust.json
83 | def main():
84 | if not (is_cargo_modules_available()):
85 | print("cargo-modules tool not found, exiting...")
86 | return
87 |
88 | with open("scripts/cherrypick/lists.json") as f:
89 | lists = json.load(f)
90 |
91 | repos = lists["rust"]
92 | for repo in tqdm(repos):
93 | repo_name = repo["repo"]
94 | commit_sha = repo["commit_sha"]
95 | entrypoint = repo["entrypoint_path"]
96 |
97 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}")
98 |
99 | with tempdir.TempDir() as temp_dir:
100 | gh_repo = git.Repo.clone_from(
101 | f"https://github.com/{repo_name}.git",
102 | temp_dir,
103 | )
104 | gh_repo.git.checkout(commit_sha)
105 | packages = find_packages(temp_dir, entrypoint)
106 | mapping = {}
107 | dependency = {}
108 | edges = []
109 | for package_name, src_path, crate_type in packages:
110 | if "_" in package_name:
111 | package_name = "-".join(package_name.split("_"))
112 | analysis_param = ""
113 |
114 | # Run cargo-modules tool
115 | if crate_type == CrateType.BIN:
116 | analysis_param = f"--bin {package_name}"
117 | else:
118 | analysis_param = f"--lib"
119 | command_list = f"cargo modules dependencies --package {package_name} {analysis_param} --cfg-test --no-sysroot --no-traits --no-types".split()
120 | entrypoint = os.path.dirname(os.path.relpath(src_path, temp_dir))
121 |
122 | # run inside `temp_dir` and capture the JSON output
123 | output = subprocess.check_output(command_list, cwd=temp_dir)
124 | decoded_output = output.decode("utf-8")
125 |
126 | # Parse graph and get mapping
127 | graph = pgv.AGraph(string=decoded_output)
128 | for node in graph.nodes():
129 | node_name = node.get_name()
130 | mapping[node_name] = find_path(
131 | node, temp_dir, entrypoint, crate_type
132 | )
133 | if mapping[node_name] in dependency:
134 | continue
135 | dependency[mapping[node_name]] = set()
136 |
137 | # Save edges for later
138 | for start_node, end_node in graph.edges():
139 | edges.append((start_node.get_name(), end_node.get_name()))
140 |
141 | # Parse every edge for dependencies
142 | for start_name, end_name in edges:
143 | if (
144 | not (start_name in mapping and end_name in mapping)
145 | or mapping[start_name] == mapping[end_name]
146 | ):
147 | continue
148 | dependency[mapping[start_name]].add(mapping[end_name])
149 |
150 | for key in dependency:
151 | dependency[key] = list(dependency[key])
152 |
153 | repo["dependency"] = dependency
154 |
155 | with open(os.path.join(CURRENT_DIR, "data", "rust.json"), "w") as f_out:
156 | json.dump({"rust": repos}, f_out)
157 |
158 |
159 | if __name__ == "__main__":
160 | Fire(main)
161 |
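162 | # Usage sketch (assumes cargo and the cargo-modules subcommand are installed,
163 | # e.g. via `cargo install cargo-modules`; run from the repository root):
164 | #   python scripts/curate/dep_analysis/rust.py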
--------------------------------------------------------------------------------
/scripts/curate/dep_analysis/typescript.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 | import os
7 | import re
8 | import subprocess
9 | from pathlib import Path
10 | from typing import Dict, List
11 |
12 | import git
13 | import tempdir
14 | from fire import Fire
15 | from tqdm.auto import tqdm
16 |
17 | from scripts.curate.dataset_ensemble_clone import get_files_to_include
18 | from scripts.curate.utility import lang2suffix
19 |
20 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
21 |
22 |
23 | def traverse_dep(
24 | current_tree: Dict, current_file: str, entrypoint: str, dependencies: Dict
25 | ) -> None:
26 | if current_tree is None:
27 | dependencies[current_file] = set()
28 | return
29 |
30 | results = set(
31 | [
32 | file
33 | for file in current_tree.keys()
34 | if entrypoint in file and Path(file).suffix in [".ts", ".js"]
35 | ]
36 | )
37 | dependencies[current_file] = results
38 |
39 | for result in results:
40 | traverse_dep(current_tree[result], result, entrypoint, dependencies)
41 |
42 |
43 | def add_circular(circular_dependencies: Dict, dependencies: Dict):
44 | for cycle in circular_dependencies:
45 | for index, from_file in enumerate(cycle[:-1]):
46 | to_file = cycle[index + 1]
47 | dependencies[from_file].add(to_file)
48 |
49 |
50 | def get_dependencies(
51 | file_path: Path, temp_dir: str, entrypoint: str, dependencies: Dict
52 | ) -> None:
53 | command_list = f"dep-tree tree {file_path} --json".split()
54 | output_string = subprocess.check_output(command_list, cwd=temp_dir)
55 | output_string = output_string.decode("utf-8")[1:-2].split()  # dep-tree emits a bracketed list of byte values
56 | json_string = "".join(chr(int(value)) for value in output_string)  # rebuild the JSON string from them
57 |
58 | json_data = json.loads(json_string)
59 | tree_data = json_data["tree"]
60 | circular_dependencies = json_data["circularDependencies"]
61 | current_file = list(tree_data.keys())[0]
62 |
63 | traverse_dep(tree_data[current_file], current_file, entrypoint, dependencies)
64 | add_circular(circular_dependencies, dependencies)
65 |
66 |
67 | # Analyze TypeScript/JavaScript import dependencies (via dep-tree) for the cherry-picked repos and write them to data/typescript.json
68 | def main():
69 | with open("scripts/cherrypick/lists.json") as f:
70 | lists = json.load(f)
71 |
72 | lang_suffix = lang2suffix["typescript"]
73 | repos = lists["typescript"]
74 | for repo in tqdm(repos):
75 | repo_name = repo["repo"]
76 | commit_sha = repo["commit_sha"]
77 | entrypoint = repo["entrypoint_path"]
78 | print(f"Visiting https://github.com/{repo_name}/tree/{commit_sha}")
79 |
80 | dependencies = {}
81 | with tempdir.TempDir() as temp_dir:
82 | gh_repo = git.Repo.clone_from(
83 | f"https://github.com/{repo_name}.git",
84 | temp_dir,
85 | )
86 | gh_repo.git.checkout(commit_sha)
87 | abs_prefix = Path(os.path.join(temp_dir, entrypoint))
88 | if "/" in entrypoint:
89 | suffix_path = entrypoint.split("/")[-1]
90 | else:
91 | suffix_path = entrypoint
92 | for file_path in abs_prefix.rglob("*"):
93 | if file_path.is_file() and file_path.suffix in lang_suffix:
94 | current_name = os.path.relpath(str(file_path), temp_dir)
95 | if current_name in dependencies:
96 | continue
97 | get_dependencies(file_path, temp_dir, suffix_path, dependencies)
98 |
99 | for key, value in dependencies.items():
100 | dependencies[key] = list(value)
101 |
102 | if "/" in entrypoint:
103 | updated_dependencies = {}
104 | append_prefix = "/".join(entrypoint.split("/")[:-1]) + "/"
105 | for key, values in dependencies.items():
106 | new_key = append_prefix + key
107 | if not ((Path(temp_dir) / Path(new_key)).exists()):
108 | continue
109 | new_values = []
110 | for value in values:
111 | new_value = append_prefix + value
112 | if (Path(temp_dir) / Path(new_value)).exists():
113 | new_values.append(new_value)
114 | updated_dependencies[new_key] = new_values
115 | dependencies = updated_dependencies
116 | repo["dependency"] = dependencies
117 | with open(os.path.join(CURRENT_DIR, "data", "typescript.json"), "w") as f_out:
118 | json.dump({"typescript": repos}, f_out)
119 |
120 |
121 | if __name__ == "__main__":
122 | Fire(main)
123 |
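124 | # Usage sketch (assumes the dep-tree CLI is available -- see python-dep-tree in
125 | # scripts/curate/requirements.txt -- and the repository root is on PYTHONPATH):
126 | #   PYTHONPATH=$(pwd) python scripts/curate/dep_analysis/typescript.py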
--------------------------------------------------------------------------------
/scripts/curate/function_analysis.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 |
7 | from tqdm import tqdm
8 | from tree_sitter_languages import get_language, get_parser
9 |
10 | from repoqa.utility import COMMENT_QUERY, FUNCTION_QUERY, topological_sort
11 |
12 | _default_name_parser = lambda node: node.child_by_field_name("name").text.decode()
13 | _cpp_name_parser = (
14 | lambda node: node.child_by_field_name("declarator")
15 | .child_by_field_name("declarator")
16 | .text.decode()
17 | )
18 |
19 |
20 | def comment_analysis(code: bytes, language: str) -> float:
21 | query_texts = COMMENT_QUERY[language]
22 | parser = get_parser(language)
23 | tree = parser.parse(code)
24 | characters = 0
25 | for query_text in query_texts:
26 | comment_query = get_language(language).query(query_text)
27 | for node, _ in comment_query.captures(tree.root_node):
28 | comment_text = code[node.start_byte : node.end_byte]
29 | characters += len(comment_text)
30 | return characters / len(code)
31 |
32 |
33 | # Annotate an incomplete repoqa dataset with function and class information
34 | def main(dataset_path: str, overwrite_analysis: bool = False):
35 | assert dataset_path.endswith(".json"), "Dataset must be a JSON file, check README"
36 | with open(dataset_path, "r") as f:
37 | lists = json.load(f)
38 |
39 | for lang, repos in lists.items():
40 | assert (
41 | lang in FUNCTION_QUERY
42 | ), f"Unsupported language: {lang} -- supported: {FUNCTION_QUERY.keys()}"
43 |
44 | fn_query_text = FUNCTION_QUERY[lang]
45 | print(f"🔥 Querying {lang} functions with `{fn_query_text}`...")
46 |
47 | parser = get_parser(lang)
48 | fn_query = get_language(lang).query(fn_query_text)
49 | fn_name_parser = _cpp_name_parser if lang == "cpp" else _default_name_parser
50 |
51 | for repo in tqdm(repos):
52 | # skip if the repo already has function information
53 | if not overwrite_analysis and repo.get("functions"):
54 | continue
55 |
56 | if not repo.get("dependency"):
57 | print(
58 | f"⚠️ Skipping {repo['repo']} ({lang}) as it does not have `dependency` -- do dependency analysis first"
59 | )
60 | continue
61 |
62 | ordered_paths = topological_sort(repo["dependency"])
63 | global_byte_idx = 0
64 | global_line_idx = 0
65 | functions = {} # path to a list of functions
66 | for path in ordered_paths:
67 | code = repo["content"][path]
68 | code_bytes = bytes(code, "utf8")
69 | tree = parser.parse(code_bytes)
70 | extracted_functions = []
71 | for capture in fn_query.captures(tree.root_node):
72 | node, _ = capture
73 | function_content = code_bytes[node.start_byte : node.end_byte]
74 | code_ratio = comment_analysis(function_content, lang)
75 | extracted_functions.append(
76 | {
77 | "name": fn_name_parser(node),
78 | "start_line": node.start_point[0],
79 | "end_line": node.end_point[0] + 1,
80 | "start_byte": node.start_byte,
81 | "end_byte": node.end_byte,
82 | "global_start_line": global_line_idx + node.start_point[0],
83 | "global_end_line": global_line_idx + node.end_point[0] + 1,
84 | "global_start_byte": global_byte_idx + node.start_byte,
85 | "global_end_byte": global_byte_idx + node.end_byte,
86 | "code_ratio": code_ratio,
87 | }
88 | )
89 | functions[path] = extracted_functions
90 | global_byte_idx += len(code)
91 | global_line_idx += code.count("\n") + 1
92 | repo["functions"] = functions
93 | print(
94 | f"🎉 Found {sum(len(v) for v in functions.values())} functions in {repo['repo']} ({lang})"
95 | )
96 |
97 | # update the dataset
98 | with open(dataset_path, "w") as f_out:
99 | json.dump(lists, f_out)
100 |
101 |
102 | if __name__ == "__main__":
103 | from fire import Fire
104 |
105 | Fire(main)
106 |
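107 | # Data-shape sketch (illustrative values only): each repo's `dependency` field maps
108 | # a file path to the paths it depends on, e.g.
109 | #   {"pkg/a.py": ["pkg/b.py"], "pkg/b.py": []}
110 | # and each entry appended to `functions[path]` carries the keys built above
111 | # (name, start/end lines and bytes, their global counterparts, and code_ratio).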
--------------------------------------------------------------------------------
/scripts/curate/github_fetch.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 | import os
7 | from datetime import datetime
8 | from typing import TypedDict
9 | from zoneinfo import ZoneInfo
10 |
11 | from fire import Fire
12 | from github import Auth, Github
13 | from tqdm.auto import tqdm
14 |
15 | from scripts.curate.utility import lang2suffix
16 |
17 |
18 | class GitHubRepoMeta(TypedDict):
19 | repo_name: str
20 | repo_owner: str
21 | commit_sha: str
22 | repo_size: int
23 |
24 |
25 | class GitHubDocument(GitHubRepoMeta):
26 | timestamp: str
27 | path: str
28 | content: str
29 |
30 |
31 | def main(
32 | language: str = "python",
33 | stars: int = 100,
34 | minimal_new_commits: int = 50,
35 | new_commit_since: str = "2024-01-01",
36 | minimal_lang_bytes: int = 1024 * 64, # 64k ideally
37 | ):
38 | token = os.getenv("GITHUB_TOKEN")
39 | assert token is not None, "Make a token at https://github.com/settings/tokens"
40 | auth = Auth.Token(token)
41 |
42 | # See repo selection criteria at https://github.com/evalplus/repoqa/issues/1
43 | query = []
44 | query.append(f"language:{language}")
45 | query.append(f"stars:>={stars}")
46 | query.append("license:mit license:apache-2.0")
47 | query.append(f"pushed:>={new_commit_since}")
48 | # 128KB to 32MB
49 | query.append("size:128..32000")
50 | query.append("sort:stars")
51 |
52 | # compile query
53 | query = " ".join(query)
54 | print(f"{query=}")
55 | g = Github(auth=auth, per_page=100)
56 |
57 | lang_suffix = lang2suffix[language]
58 |
59 | with open(f"{language}-{datetime.now().isoformat()}.jsonl", "w") as f_out:
60 | repos = g.search_repositories(query)
61 | print("Found ", repos.totalCount, "repositories for", language)
62 | for repo in tqdm(repos, total=repos.totalCount):
63 | # filter: keep repos with at least `minimal_new_commits` commits since `new_commit_since`
64 | commits = repo.get_commits()
65 | if (
66 | count_ := sum(
67 | True
68 | for commit in commits
69 | if commit.last_modified_datetime
70 | >= datetime.strptime(new_commit_since, "%Y-%m-%d").replace(
71 | tzinfo=ZoneInfo("UTC")
72 | )
73 | )
74 | ) < minimal_new_commits:
75 | print(
76 | f"Skip {repo.html_url} for have less than {minimal_new_commits} after {new_commit_since} (only {count_} commits)"
77 | )
78 | continue
79 |
80 | # filter out repos that do not have enough source code
81 | git_tree = repo.get_git_tree(repo.default_branch, recursive=True)
82 |
83 | tree_iter = list(
84 | filter(
85 | lambda item: item.type == "blob"
86 | and any([item.path.endswith(suffix) for suffix in lang_suffix]),
87 | tqdm(git_tree.tree, leave=False),
88 | )
89 | )
90 |
91 | code_file_size = int(sum(item.size for item in tree_iter))
92 | if code_file_size < minimal_lang_bytes:
93 | print(
94 | f"Skip {repo.html_url} for have less than {minimal_lang_bytes} bytes source file after {new_commit_since} (only {code_file_size} bytes)"
95 | )
96 | continue
97 |
98 | schema = dict(
99 | repo_name=repo.name,
100 | repo_size=code_file_size,
101 | repo_owner=repo.owner.login,
102 | repo_url=repo.html_url,
103 | commit_sha=git_tree.sha,
104 | last_modified_time=git_tree.last_modified_datetime,
105 | content={},
106 | )
107 | for item in tree_iter:
108 | # Fetch the content of each matching source file
109 | content = repo.get_contents(item.path)
110 | assert not isinstance(content, list)
111 | if content.encoding != "base64":
112 | continue
113 | file_content = content.decoded_content.decode("utf-8")
114 | schema["content"][item.path] = file_content
115 | f_out.write(json.dumps(schema) + "\n")
116 | f_out.flush()
117 |
118 |
119 | if __name__ == "__main__":
120 | Fire(main)
121 |
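122 | # Usage sketch (assumes a GitHub token with public-repo read access; the token value is a placeholder):
123 | #   export GITHUB_TOKEN=<token>
124 | #   python scripts/curate/github_fetch.py --language python --stars 100
125 | # Results are streamed to a `<language>-<timestamp>.jsonl` file in the working directory.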
--------------------------------------------------------------------------------
/scripts/curate/merge_annotation.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 |
7 | from fire import Fire
8 |
9 |
10 | def main(dataset_path: str, annotation_path: str):
11 | assert dataset_path.endswith(".json"), "Dataset must be a JSON file, check README"
12 | assert annotation_path.endswith(
13 | ".jsonl"
14 | ), "Annotation must be a JSONL file, check README"
15 |
16 | with open(dataset_path) as f:
17 | dataset = json.load(f)
18 |
19 | with open(annotation_path) as f:
20 | annotations = [json.loads(line) for line in f]
21 |
22 | def make_key(repo_name, func_name):
23 | return f"{repo_name}::{func_name}"
24 |
25 | key2annotation = {make_key(a["repo"], a["name"]): a for a in annotations}
26 |
27 | for lang, repos in dataset.items():
28 | for repo in repos:
29 | if "needles" not in repo:
30 | print(
31 | f"⚠️ Skipping {repo['repo']} ({lang}) as it does not have `needles` -- do needle analysis first"
32 | )
33 | continue
34 | for needle in repo["needles"]:
35 | needle_name = needle["name"]
36 | key = make_key(repo["repo"], needle_name)
37 | annotation = key2annotation.get(key, None)
38 | if annotation is None:
39 | print(
40 | f"⚠️ Missing annotation for {key} for lang {lang} -- skipping"
41 | )
42 | continue
43 | needle["description"] = annotation["annotation"]
44 |
45 | with open(dataset_path, "w") as f_out:
46 | json.dump(dataset, f_out)
47 |
48 |
49 | if __name__ == "__main__":
50 | Fire(main)
51 |
--------------------------------------------------------------------------------
/scripts/curate/merge_dep.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 | import os
7 |
8 | from fire import Fire
9 |
10 |
11 | def main(dataset_path: str):
12 | # iterate json files under scripts/curate/dep_analysis/data
13 | repo2dep = {}
14 | for file in os.listdir(os.path.join("scripts/curate/dep_analysis/data")):
15 | if file.endswith(".json"):
16 | with open(os.path.join("scripts/curate/dep_analysis/data", file)) as f:
17 | data = json.load(f)
18 |
19 | repos = list(data.values())[0]
20 | for repo in repos:
21 | repo2dep[repo["repo"]] = repo["dependency"]
22 |
23 | with open(dataset_path) as f:
24 | dataset = json.load(f)
25 |
26 | for lang, repos in dataset.items():
27 | for repo in repos:
28 | if repo["repo"] not in repo2dep:
29 | print(f"{lang} -- Repo {repo['repo']} not found in dep analysis data")
30 | continue
31 | repo["dependency"] = repo2dep[repo["repo"]]
32 | print(f"{lang} -- Repo {repo['repo']} has dependency added in the dataset")
33 |
34 | with open(dataset_path, "w") as f_out:
35 | json.dump(dataset, f_out)
36 |
37 |
38 | if __name__ == "__main__":
39 | Fire(main)
40 |
--------------------------------------------------------------------------------
/scripts/curate/needle_annotation.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 | import os
7 |
8 | import openai
9 | from tqdm import tqdm
10 |
11 | from repoqa.provider.request.openai import make_auto_request
12 | from repoqa.utility import topological_sort
13 |
14 | CAPTURE_HEAD = "<desc>"
15 | CAPTURE_TAIL = "</desc>"
16 |
17 |
18 | def make_prompt(fn_name: str, code: str):
19 | instruction = f'Can you **briefly** describe the purpose, input, output, and procedure of "{fn_name}"?'
20 | return f"""\
21 | {instruction}
22 |
23 | ```
24 | {code}
25 | ```
26 |
27 | {instruction}
28 |
29 | Please follow the format to complete the skeleton below:
30 |
31 | {CAPTURE_HEAD}
32 | 1. **Purpose**: ...
33 | 2. **Input**: ...
34 | 3. **Output**: ...
35 | 4. **Procedure**: ...
36 | {CAPTURE_TAIL}
37 |
38 | {instruction}
39 |
40 | Notes:
41 | 1. DO NOT reveal function names ({fn_name}) and variable names
42 | 2. Start with {CAPTURE_HEAD} and end with {CAPTURE_TAIL}
43 | 3. Customize the description to differentiate it from other functions
44 | """
45 |
46 |
47 | # Annotate an incomplete repoqa dataset with function and class information
48 | def main(
49 | dataset_path: str,
50 | code_prefix_lines: int = 100,
51 | output_desc_path: str = "function_description.jsonl",
52 | use_batch_api: bool = False,
53 | verbose: bool = False,
54 | debug: bool = False,
55 | ):
56 | assert not use_batch_api, "Batch API is not supported yet."
57 |
58 | assert dataset_path.endswith(".json"), "Dataset must be a JSON file, check README"
59 | with open(dataset_path, "r") as f:
60 | lists = json.load(f)
61 |
62 | # resume from output_desc_path
63 | if output_desc_path.endswith(".jsonl") and os.path.exists(output_desc_path):
64 | with open(output_desc_path, "r") as f:
65 | results = [json.loads(line) for line in f]
66 | else:
67 | # {repo, name, prompt, annotation}
68 | results = []
69 |
70 | # a set of inference task to run; each item is a tuple of {repo, name, prompt}
71 | tasks = []
72 | for lang, repos in lists.items():
73 | print(f"🔥 Collecting unannotated needle functions for {lang}")
74 | for repo in tqdm(repos):
75 | if not repo.get("dependency"):
76 | print(
77 | f"⚠️ Skipping {repo['repo']} ({lang}) as it does not have `dependency` -- do dependency analysis first"
78 | )
79 | continue
80 | ordered_paths = topological_sort(repo["dependency"])
81 | repo_lines = []
82 | for path in ordered_paths:
83 | repo_lines.extend(repo["content"][path].split("\n"))
84 |
85 | def get_code(global_start_line, global_end_line):
86 | return "\n".join(
87 | repo_lines[
88 | max(0, global_start_line - code_prefix_lines) : global_end_line
89 | ]
90 | )
91 |
92 | existing_needles = set(
93 | [item["name"] for item in results if item["repo"] == repo["repo"]]
94 | )
95 |
96 | for needle in repo["needles"]:
97 | fn_name = needle["name"]
98 | if fn_name in existing_needles:
99 | continue
100 | code = get_code(needle["global_start_line"], needle["global_end_line"])
101 | prompt = make_prompt(fn_name, code)
102 | if verbose:
103 | print(prompt)
104 | print("-" * 80)
105 | tasks.append(
106 | {
107 | "repo": repo["repo"],
108 | "name": fn_name,
109 | "prompt": prompt,
110 | }
111 | )
112 |
113 | print(f"🔥 {len(tasks)} needle functions to be annotated in total")
114 | client = openai.Client()
115 | with open(output_desc_path, "+a") as f_out:
116 | for task in tqdm(tasks):
117 | print(f"🔥 Annotating {task['name']} in {task['repo']}")
118 | output = make_auto_request(
119 | client,
120 | task["prompt"],
121 | model="gpt-4-turbo",
122 | max_tokens=2048,
123 | temperature=0.2,
124 | n=1,
125 | )
126 | annotation = output.choices[0].message.content
127 | result = {
128 | "repo": task["repo"],
129 | "name": task["name"],
130 | "prompt": task["prompt"],
131 | "raw_annotation": annotation,
132 | "annotation": annotation.split(CAPTURE_HEAD)[-1].split(CAPTURE_TAIL)[0],
133 | }
134 | json.dump(result, f_out)
135 | f_out.write("\n")
136 | f_out.flush()
137 |
138 | if debug:
139 | print("[PROMPT]", "-" * 80)
140 | print(task["prompt"])
141 | print("[ANNOTATION]", "-" * 80)
142 | print(annotation)
143 | print("-" * 80)
144 | print("Enter to continue... or b to break:")
145 | if input() == "b":
146 | break
147 |
148 |
149 | if __name__ == "__main__":
150 | from fire import Fire
151 |
152 | Fire(main)
153 |
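154 | # Output sketch: every line appended to `output_desc_path` is a JSON object with the
155 | # fields assembled in `result` above, e.g. (placeholder values):
156 | #   {"repo": "owner/name", "name": "func", "prompt": "...",
157 | #    "raw_annotation": "...", "annotation": "..."}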
--------------------------------------------------------------------------------
/scripts/curate/needle_selection.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | import json
6 | from collections import Counter
7 | from random import sample
8 |
9 | from tqdm import tqdm
10 |
11 |
12 | # Annotate an incomplete repoqa dataset with function and class information
13 | def main(
14 | dataset_path: str,
15 | overwrite_analysis: bool = False,
16 | max_len: int = 2000,
17 | num_bins: int = 64,
18 | max_fn_per_repo: int = 10,
19 | ):
20 | assert (
21 | num_bins >= max_fn_per_repo
22 | ), "Number of bins must be greater than max functions per repo"
23 | assert dataset_path.endswith(".json"), "Dataset must be a JSON file, check README"
24 | with open(dataset_path, "r") as f:
25 | lists = json.load(f)
26 |
27 | for lang, repos in lists.items():
28 | print(f"🔥 Selecting needle functions for {lang}")
29 | # FIXME(@ganler): enable more dependency analysis!
30 | for repo in tqdm(repos):
31 | # skip if the repo already has function information
32 | if not overwrite_analysis and repo.get("needles"):
33 | continue
34 |
35 | if not repo.get("functions"):
36 | print(
37 | f"⚠️ Skipping {repo['repo']} ({lang}) as it does not have `functions` field -- do function analysis first"
38 | )
39 | continue
40 |
41 | repo_size_bytes = sum(len(content) for content in repo["content"].values())
42 |
43 | selected_bins = set()
44 | bin_size = repo_size_bytes // num_bins
45 |
46 | function_names = Counter()
47 | for funcs in repo["functions"].values():
48 | for fn in funcs:
49 | function_names.update([fn["name"]])
50 | # get function names that only appear once
51 | function_names = {k for k, v in function_names.items() if v == 1}
52 |
53 | needle_candidates = []
54 | for path, funcs in repo["functions"].items():
55 | for fn in funcs:
56 | # criteria 1: no repeated function names
57 | if fn["name"] not in function_names:
58 | continue
59 |
60 | # criteria 2: length <= max_len
61 | if fn["end_byte"] - fn["start_byte"] > max_len:
62 | continue
63 |
64 | # criteria 3: not in the same bin
65 | bin_idx = fn["global_start_byte"] // bin_size
66 | if bin_idx in selected_bins:
67 | continue
68 |
69 | # criteria 4: TODO -- select those with more code!
70 | selected_bins |= {bin_idx}
71 | needle_candidates.append((path, fn))
72 |
73 | len_total_fn = sum(len(v) for v in repo["functions"].values())
74 | print(
75 | f"🎉 Selected {len(needle_candidates)} needles from {len_total_fn} functions in {repo['repo']} ({lang})"
76 | )
77 | needles = []
78 | for path, fn in sample(
79 | needle_candidates, min(max_fn_per_repo, len(needle_candidates))
80 | ):
81 | needles.append({**fn, "path": path})
82 | repo["needles"] = needles
83 |
84 | # update the dataset
85 | with open(dataset_path, "w") as f_out:
86 | json.dump(lists, f_out)
87 |
88 |
89 | if __name__ == "__main__":
90 | from fire import Fire
91 |
92 | Fire(main)
93 |
--------------------------------------------------------------------------------
/scripts/curate/obfuscate_nl.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | # Run with ```python scripts/curate/obfuscate_nl.py repoqa-2024-04-20.json```
6 | # Will save to repoqa-2024-04-20-obfuscated.json
7 |
8 | import json
9 | import os
10 | import re
11 |
12 | from fire import Fire
13 | from tqdm import tqdm
14 | from tree_sitter_languages import get_language, get_parser
15 |
16 | from repoqa.utility import COMMENT_QUERY, FUNCTION_NAME_QUERY
17 |
18 |
19 | def remove_comments(code, language):
20 | query_texts = COMMENT_QUERY[language]
21 | parser = get_parser(language)
22 | code_bytes = bytes(code, "utf8")
23 | tree = parser.parse(code_bytes)
24 | comment_chunks = []
25 | for query_text in query_texts:
26 | comment_query = get_language(language).query(query_text)
27 | for node, _ in comment_query.captures(tree.root_node):
28 | comment_chunks.append(node.text.decode("utf-8"))
29 | comment_chunks.sort(key=len, reverse=True)
30 | for chunk in comment_chunks:  # blank each comment out while preserving its byte length
31 | chunk_lines = chunk.splitlines()
32 | chunk_lines_len = [len(bytes(line, "utf-8")) for line in chunk_lines]
33 | chunk_lines_empty = [
34 | (bytes("", "utf-8").ljust(llen, b"\0")).decode("utf-8")
35 | for llen in chunk_lines_len
36 | ]
37 | chunk_empty = "\0".join(chunk_lines_empty)
38 | chunk_empty = chunk_empty[:-1] + "\n"
39 | code = code.replace(chunk, chunk_empty)
40 | return code
41 |
42 |
43 | def rename_functions(code, language, starting_index=0):
44 | func_name_query = get_language(language).query(FUNCTION_NAME_QUERY[language])
45 | parser = get_parser(language)
46 | print(f"Running rename_functions: {code}, {language}")
47 | code_bytes = bytes(code, "utf8")
48 | tree = parser.parse(code_bytes)
49 | function_names = set()
50 | for capture in func_name_query.captures(tree.root_node):
51 | node, _ = capture
52 | function_names.add(node.text.decode("utf-8"))
53 | function_map = {}
54 | current_index = starting_index
55 | for name in function_names:
56 | function_map[name] = f"function_{current_index}"
57 | code = code.replace(name, function_map[name])
58 | current_index += 1
59 | return code, function_map
60 |
61 |
62 | def main(ds_filepath: str):
63 | dataset_file = open(ds_filepath, "r")
64 | dataset = dataset_file.read()
65 | dataset = json.loads(dataset)
66 | dataset_file.close()
67 |
68 | for lang in dataset.keys():
69 | print(f"🔥 Processing language: {lang}")
70 | for repo_idx in tqdm(range(len(dataset[lang]))):
71 | for filepath in dataset[lang][repo_idx]["content"].keys():
72 | prev_byte_len = len(
73 | bytes(dataset[lang][repo_idx]["content"][filepath], "utf-8")
74 | )
75 | dataset[lang][repo_idx]["content"][filepath] = remove_comments(
76 | dataset[lang][repo_idx]["content"][filepath], lang
77 | )
78 | new_byte_len = len(
79 | bytes(dataset[lang][repo_idx]["content"][filepath], "utf-8")
80 | )
81 | assert prev_byte_len == new_byte_len
82 |
83 | dataset_dir = "/".join(ds_filepath.split("/")[:-1])
84 | ds_filepath = ds_filepath.split("/")[-1]
85 | ds_fname = ".".join(ds_filepath.split(".")[:-1])
86 | ds_ext = ds_filepath.split(".")[-1]
87 |
88 | obfs_ds_file = open(
89 | os.path.join(dataset_dir, f"{ds_fname}-obfuscated.{ds_ext}"), "w+"
90 | )
91 | obfs_ds_file.write(json.dumps(dataset))
92 | obfs_ds_file.close()
93 |
94 |
95 | if __name__ == "__main__":
96 | Fire(main)
97 |
--------------------------------------------------------------------------------
/scripts/curate/requirements.txt:
--------------------------------------------------------------------------------
1 | PyGithub
2 | fire
3 | tqdm
4 | tempdir
5 | GitPython
6 | pydeps
7 | tree_sitter
8 | tree_sitter_languages
9 | pygraphviz
10 | matplotlib
11 | transformers
12 | python-dep-tree
13 |
--------------------------------------------------------------------------------
/scripts/curate/utility.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | lang2suffix = {
6 | "python": [".py"],
7 | "go": [".go"],
8 | "cpp": [".cpp", ".hpp", ".cc", ".hh", ".cxx", ".hxx", ".c", ".h"],
9 | "java": [".java"],
10 | "typescript": [".ts", ".js"],
11 | "php": [".php"],
12 | "rust": [".rs"],
13 | }
14 |
--------------------------------------------------------------------------------
/scripts/demos/model_request_oai.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | if __name__ == "__main__":
6 | import openai
7 |
8 | client = openai.OpenAI( # Note: you need UIUC VPN or the UIUC network to access this server!
9 | api_key="none", base_url="http://ise-dynamo.cs.illinois.edu:8888/v1"
10 | )
11 |
12 | task_prefix = "def fibonacci(n):\n"
13 | prompt = f"""This is the fastest implementation for Fibonacci:
14 | ```python
15 | {task_prefix}"""
16 |
17 | # completion
18 | responses = client.completions.create(
19 | model="deepseek-ai/deepseek-coder-6.7b-instruct",
20 | prompt=prompt,
21 | max_tokens=256,
22 | n=3,
23 | stop=["\n```"],
24 | )
25 |
26 | for c in responses.choices:
27 | print(task_prefix + c.text)
28 | print("=" * 8)
29 |
--------------------------------------------------------------------------------
/scripts/dev/license-hdr.txt:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 |
3 | SPDX-License-Identifier: Apache-2.0
4 |
--------------------------------------------------------------------------------
/scripts/eval/recompute_all_scores.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
4 | #
5 | # SPDX-License-Identifier: Apache-2.0
6 |
7 | set -e
8 |
9 | export PYTHONPATH=$(pwd)
10 |
11 | for path in "$(pwd)"/results/**/*.jsonl; do
12 | # only recompute scores for result files of at least 10MB
13 | file_size_mb=$(du -m "$path" | cut -f1)
14 | echo "Size of $path: $file_size_mb MB"
15 | if [ $file_size_mb -lt 10 ]; then
16 | echo "File size is less than 10MB. Skipping..."
17 | continue
18 | fi
19 | yes | python repoqa/compute_score.py --model-output-path $path
20 | done
21 |
--------------------------------------------------------------------------------
/scripts/misc/estimate_max_char.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 |
6 | def main(model="deepseek-ai/deepseek-coder-6.7b-instruct", max_tokens=8 * 2**10):
7 | import time
8 |
9 | import openai
10 |
11 | client = openai.OpenAI( # Note: you need UIUC VPN or the UIUC network to access this server!
12 | api_key="none", base_url="http://ise-dynamo.cs.illinois.edu:8888/v1"
13 | )
14 |
15 | prompt = "def "
16 |
17 | tstart = time.time()
18 | responses = client.completions.create(
19 | model=model, prompt=prompt, n=1, max_tokens=max_tokens, temperature=0
20 | )
21 | print(f"Time taken: {time.time() - tstart:.1f}s")
22 | print("Finish reason:", responses.choices[0].finish_reason)
23 | print(
24 | "Estimated max context length in #chars is: ",
25 | len(prompt) + len(responses.choices[0].text),
26 | )
27 |
28 |
29 | if __name__ == "__main__":
30 | from fire import Fire
31 |
32 | Fire(main)
33 |
--------------------------------------------------------------------------------
/scripts/misc/repo_token_size.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | # Estimate the token size of each repository in a dataset and plot the distribution
6 |
7 | import json
8 |
9 | import matplotlib.pyplot as plt
10 | from transformers import AutoTokenizer
11 |
12 | from repoqa.utility import topological_sort
13 |
14 | COMMENT_PREFIX = {
15 | "python": "#",
16 | "java": "//",
17 | "typescript": "//",
18 | "rust": "//",
19 | "cpp": "//",
20 | }
21 |
22 |
23 | def get_max_token_size(dataset_path: str, model: str):
24 | with open(dataset_path, "r") as f:
25 | dataset = json.load(f)
26 |
27 | tokenizer = AutoTokenizer.from_pretrained(model)
28 |
29 | min_token_size = 1e9
30 | min_token_size_repo = None
31 | token_sizes = []
32 |
33 | for lang, repos in dataset.items():
34 | for repo in repos:
35 | if "dependency" not in repo:
36 | print(
37 | f"⚠️ Skipping {repo['repo']} ({lang}) as it does not have `dependency` -- do dependency analysis first"
38 | )
39 | continue
40 |
41 | ordered_paths = topological_sort(repo["dependency"])
42 |
43 | bigfile = ""
44 | for path in ordered_paths:
45 | bigfile += (
46 | COMMENT_PREFIX[lang]
47 | + f" current path: {path}\n"
48 | + repo["content"][path]
49 | )
50 |
51 | # estimate the maximum token size
52 | token_size = tokenizer(bigfile, return_tensors="pt")["input_ids"].shape[1]
53 | token_sizes.append(token_size)
54 | min_token_size = min(min_token_size, token_size)
55 | if min_token_size == token_size:
56 | min_token_size_repo = repo["repo"]
57 | print(f"[{lang}] {repo['repo']:<32}: {token_size:>20} tokens")
58 |
59 | print(f"Estimated minimum token size: {min_token_size} by {min_token_size_repo}")
60 |
61 | # visualize the distribution
62 | plt.figure(figsize=(8, 4))
63 | plt.hist(token_sizes, bins=64)
64 | # xtick at every 100k
65 | unit = 100 * 1000
66 | plt.xticks(range(0, max(token_sizes) + 1, unit))
67 | plt.xlim(0, max(token_sizes) + 1000)
68 | # xtick using "k"
69 | plt.gca().xaxis.set_major_formatter(
70 | plt.FuncFormatter(lambda x, _: f"{x / unit:.0f}")
71 | )
72 | plt.xlabel("Token size (100k)")
73 | plt.ylabel("Frequency")
74 | plt.title("Token size distribution")
75 | # compact layout
76 | plt.tight_layout()
77 | plt.savefig("token_size_distribution.png", dpi=164, bbox_inches="tight")
78 |
79 |
80 | if __name__ == "__main__":
81 | from fire import Fire
82 |
83 | Fire(get_max_token_size)
84 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = repoqa
3 | description = "RepoQA for Evaluating Long-Context Code Understanding"
4 | long_description = file: README.md
5 | long_description_content_type = text/markdown
6 | url = https://github.com/evalplus/repoqa
7 | license = Apache-2.0
8 | license_files = LICENSE
9 | platform = any
10 | classifiers =
11 | Operating System :: OS Independent
12 | Programming Language :: Python :: 3
13 | License :: OSI Approved :: Apache Software License
14 |
15 | [options]
16 | packages = find:
17 | python_requires = >=3.8
18 | dependency_links =
19 | install_requires =
20 | tempdir>=0.7.1
21 | appdirs>=1.4.4
22 | wget>=3.2
23 | fire>=0.6.0
24 | nltk>=3.8.1
25 | rich>=13.5.2
26 | numpy>=1.25.2
27 | tree_sitter<=0.21.3
28 | tree_sitter_languages>=1.10.2
29 | transformers>=4.40.0
30 | openai>=1.23.2
31 | anthropic>=0.25.6
32 | google-generativeai>=0.5.2
33 |
34 | [options.entry_points]
35 | console_scripts =
36 | repoqa.search_needle_function = repoqa.search_needle_function:main
37 | repoqa.compute_score = repoqa.compute_score:main
38 |
39 | [options.extras_require]
40 | vllm = vllm>=0.3.3
41 |
--------------------------------------------------------------------------------