├── .VERSION
├── .github
│   └── workflows
│       ├── ci.yml
│       └── publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── codenav
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── agent.py
│   │   ├── cohere
│   │   │   ├── __init__.py
│   │   │   └── agent.py
│   │   ├── gpt4
│   │   │   ├── __init__.py
│   │   │   └── agent.py
│   │   ├── interaction_formatters.py
│   │   └── llm_chat_agent.py
│   ├── codenav_run.py
│   ├── constants.py
│   ├── default_eval_spec.py
│   ├── environments
│   │   ├── __init__.py
│   │   ├── abstractions.py
│   │   ├── code_env.py
│   │   ├── code_summary_env.py
│   │   ├── done_env.py
│   │   └── retrieval_env.py
│   ├── interaction
│   │   ├── __init__.py
│   │   ├── episode.py
│   │   └── messages.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── codenav
│   │   │   └── repo_description.txt
│   │   ├── default
│   │   │   ├── action__code.txt
│   │   │   ├── action__done.txt
│   │   │   ├── action__es_search.txt
│   │   │   ├── action__guidelines.txt
│   │   │   ├── action__preamble.txt
│   │   │   ├── overview.txt
│   │   │   ├── repo_description.txt
│   │   │   ├── response__code.txt
│   │   │   ├── response__done.txt
│   │   │   ├── response__es_search.txt
│   │   │   ├── response__preamble.txt
│   │   │   └── workflow.txt
│   │   ├── query_prompt.py
│   │   └── restart_prompt.py
│   ├── retrieval
│   │   ├── __init__.py
│   │   ├── code_blocks.py
│   │   ├── code_summarizer.py
│   │   └── elasticsearch
│   │       ├── README.md
│   │       ├── __init__.py
│   │       ├── create_index.py
│   │       ├── debug_add_item_to_index.py
│   │       ├── elasticsearch.yml
│   │       ├── elasticsearch_constants.py
│   │       ├── elasticsearch_retriever.py
│   │       ├── index_codebase.py
│   │       └── install_elasticsearch.py
│   └── utils
│       ├── __init__.py
│       ├── config_params.py
│       ├── eval_types.py
│       ├── evaluator.py
│       ├── hashing_utils.py
│       ├── linting_and_type_checking_utils.py
│       ├── llm_utils.py
│       ├── logging_utils.py
│       ├── omegaconf_utils.py
│       ├── parsing_utils.py
│       ├── prompt_utils.py
│       └── string_utils.py
├── codenav_examples
│   ├── create_code_env.py
│   ├── create_episode.py
│   ├── create_index.py
│   ├── create_prompt.py
│   └── parallel_evaluation.py
├── playground
│   └── .gitignore
├── pyproject.toml
├── requirements.txt
├── scripts
│   └── release.py
└── setup.py
/.VERSION:
--------------------------------------------------------------------------------
1 | 0.0.1
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | concurrency:
4 | group: ${{ github.workflow }}-${{ github.ref }}
5 | cancel-in-progress: true
6 |
7 | on:
8 | pull_request:
9 | branches:
10 | - main
11 | push:
12 | branches:
13 | - main
14 |
15 | env:
16 | # Change this to invalidate existing cache.
17 | CACHE_PREFIX: v0
18 | PYTHON_PATH: ./
19 |
20 | jobs:
21 | checks:
22 | name: python ${{ matrix.python }} - ${{ matrix.task.name }}
23 | runs-on: [ubuntu-latest]
24 | timeout-minutes: 30
25 | strategy:
26 | fail-fast: false
27 | matrix:
28 | python: [3.9]
29 | task:
30 | - name: Style
31 | run: |
32 | black --check .
33 |
34 | # TODO: Add testing back in after fixing the tests for use with GH actions.
35 | # - name: Test
36 | # run: |
37 | # pytest -v --color=yes tests/
38 |
39 | steps:
40 | - uses: actions/checkout@v3
41 |
42 | - name: Setup Python
43 | uses: actions/setup-python@v4
44 | with:
45 | python-version: ${{ matrix.python }}
46 |
47 | - name: Install prerequisites
48 | run: |
49 | pip install --upgrade pip setuptools wheel virtualenv
50 |
51 | - name: Set build variables
52 | shell: bash
53 | run: |
54 | # Get the exact Python version to use in the cache key.
55 | echo "PYTHON_VERSION=$(python --version)" >> $GITHUB_ENV
56 | echo "RUNNER_ARCH=$(uname -m)" >> $GITHUB_ENV
57 | # Use week number in cache key so we can refresh the cache weekly.
58 | echo "WEEK_NUMBER=$(date +%V)" >> $GITHUB_ENV
59 |
60 | - uses: actions/cache@v3
61 | id: virtualenv-cache
62 | with:
63 | path: .venv
64 | key: ${{ env.CACHE_PREFIX }}-${{ env.WEEK_NUMBER }}-${{ runner.os }}-${{ env.RUNNER_ARCH }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('requirements.txt') }}
65 | restore-keys: |
66 | ${{ env.CACHE_PREFIX }}-${{ env.WEEK_NUMBER }}-${{ runner.os }}-${{ env.RUNNER_ARCH }}-${{ env.PYTHON_VERSION }}-
67 |
68 | - name: Setup virtual environment (no cache hit)
69 | if: steps.virtualenv-cache.outputs.cache-hit != 'true'
70 | run: |
71 | test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv
72 | . .venv/bin/activate
73 | pip install -e .[dev]
74 |
75 | - name: Setup virtual environment (cache hit)
76 | if: steps.virtualenv-cache.outputs.cache-hit == 'true'
77 | run: |
78 | . .venv/bin/activate
79 | pip install --no-deps -e .[dev]
80 |
81 | - name: Show environment info
82 | run: |
83 | . .venv/bin/activate
84 | which python
85 | python --version
86 | pip freeze
87 |
88 | - name: ${{ matrix.task.name }}
89 | run: |
90 | . .venv/bin/activate
91 | ${{ matrix.task.run }}
92 |
93 | - name: Clean up
94 | if: always()
95 | run: |
96 | . .venv/bin/activate
97 | pip uninstall -y codenav
98 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload the codenav package using Twine (after manually triggering it)
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Publish PYPI Packages
5 |
6 | on:
7 | workflow_dispatch:
8 |
9 | jobs:
10 | deploy:
11 |
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v2
16 | - name: Set up Python
17 | uses: actions/setup-python@v2
18 | with:
19 | python-version: '3.9'
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install setuptools twine
24 | - name: Build and publish
25 | env:
26 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
27 | run: |
28 | python scripts/release.py
29 | twine upload -u __token__ dist/*
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | rag_cache
163 | __tmp*
164 | wandb
165 | tmp
166 | external_src
167 | codenav/external_src
168 | file_lock
169 | tmp.py
170 | *.jpg
171 | *.png
172 | rsync-to-*
173 | results
174 | *_out
175 | .vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | Copyright 2024 Allen Institute for AI
179 |
180 | Licensed under the Apache License, Version 2.0 (the "License");
181 | you may not use this file except in compliance with the License.
182 | You may obtain a copy of the License at
183 |
184 | http://www.apache.org/licenses/LICENSE-2.0
185 |
186 | Unless required by applicable law or agreed to in writing, software
187 | distributed under the License is distributed on an "AS IS" BASIS,
188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
189 | See the License for the specific language governing permissions and
190 | limitations under the License.
191 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CodeNav: Beyond tool-use to using real-world codebases with LLM agents 🚀](https://codenav.allenai.org/)
2 |
3 |
4 |
5 |
6 |
7 | [arXiv Paper](https://arxiv.org/abs/2406.12276)
8 | [Project Page](https://codenav.allenai.org/)
9 | [Demo](https://codenav.streamlit.app/)
10 | [License: Apache 2.0](https://opensource.org/licenses/Apache-2.0)
11 | [Python 3.9](https://www.python.org/downloads/release/python-390/)
12 | [Code style: black](https://github.com/psf/black)
13 |
14 | CodeNav is an LLM agent that navigates and leverages previously unseen code repositories to solve user queries. In contrast to tool-use LLM agents that require "registration" of all relevant tools via manual descriptions within the LLM context, CodeNav automatically indexes and searches over code blocks in the target codebase, finds relevant code snippets, imports them, and uses them to iteratively generate a solution with execution feedback.
15 |
16 | ## Getting Started 🛠️
17 |
18 | You can use CodeNav as a command line tool or programmatically as a Python module. In either case, you'll first
19 | want to install CodeNav:
20 | ```bash
21 | pip install git+https://github.com/allenai/codenav
22 | ```
23 |
24 | ### CodeNav as a command line tool
25 |
26 | After installing `codenav`, you can use it as a command line tool by running:
27 | ```bash
28 | codenav init # Downloads and starts the Elasticsearch index that CodeNav depends on to search for code snippets
29 | ```
30 | and then
31 | ```bash
32 | codenav query \
33 | --code_dir /PATH/TO/CODEBASE/YOU/WANT/CODENAV/TO/USE \
34 | --playground_dir /WORKING/DIRECTORY/FOR/CODENAV/AGENT \
35 | --query "Query you want CodeNav to answer using the above codebase"
36 | ```
37 | You can find other command line options by running `codenav --help`. For example, you might run something like
38 | ```bash
39 | codenav query \
40 | --code_dir /PATH/TO/THIS/REPO/codenav \
41 | --playground_dir /PATH/TO/THIS/REPO/playground \
42 | --query "Write a google-style documentation string for the DoneEnv class and save it to DoneEnv.py"
43 | ```
44 | Running the above results in the CodeNav agent saving a file `DoneEnv.py` with contents:
45 |
46 |
47 |
48 | class DoneEnv(CodeNavEnv):
49 | """
50 | DoneEnv is an environment class that handles the 'done' action in the CodeNav framework.
51 |
52 | Methods:
53 | check_action_validity(action: CodeNavAction) -> Tuple[bool, str]:
54 | Checks if the given action is valid for the 'done' action.
55 |
56 | step(action: CodeNavAction) -> None:
57 | Executes the 'done' action.
58 | """
59 | def check_action_validity(self, action: CodeNavAction) -> Tuple[bool, str]:
60 | """
61 | Checks if the given action is valid for the 'done' action.
62 |
63 | Args:
64 | action (CodeNavAction): The action to be validated.
65 |
66 | Returns:
67 | Tuple[bool, str]: A tuple containing a boolean indicating validity and an error message if invalid.
68 | """
69 | assert action.content is not None
70 |
71 | if action.content.strip().lower() in ["true", "false"]:
72 | return True, ""
73 | else:
74 | return (
75 | False,
76 | "When executing the done action, the content must be either 'True' or 'False'",
77 | )
78 |
79 | def step(self, action: CodeNavAction) -> None:
80 | """
81 | Executes the 'done' action.
82 |
83 | Args:
84 | action (CodeNavAction): The action to be executed.
85 | """
86 | return None
87 |
88 |
89 | Note: the `codenav` command line tool is simply an alias for running [codenav_run.py](codenav%2Fcodenav_run.py), so
90 | you can replace `codenav ...` with `python -m codenav.codenav_run ...`
91 | or `python /path/to/codenav/codenav_run.py ...` and obtain the same results.
92 |
93 | Here's a more detailed description of the arguments you can pass to `codenav query` or `python -m codenav.codenav_run query`:
94 | | Argument | Type | Description |
95 | | --- | --- | --- |
96 | | `--code_dir` | str | The path to the codebase you want CodeNav to use. By default, all files in this directory will get indexed with relative file paths. For instance, if you set `--code_dir /Users/tanmay/codebase`, which contains a `computer_vision/tool.py` file, then this file will be indexed with the relative path `computer_vision/tool.py` |
97 | | `--force_subdir` | str | If you wish to only index a subdirectory within the `code_dir`, then set this to the name of the subdirectory |
98 | | `--module` | str | If you have a module installed e.g. via `pip install transformers` and you want CodeNav to use this module, you can simply set `--module transformers` instead of providing `--code_dir` |
99 | | `--repo_description_path` | str | If you have a README file or a file with a description of the codebase you are using, you can provide the path to this file here. You may use this file to point out to CodeNav the high-level purpose and structure of the codebase (e.g. highlight important directories, files, classes or functions) |
100 | | `--force_reindex` | bool | Set this flag if you want to force CodeNav to reindex the codebase. Otherwise, CodeNav will reuse an existing index if it exists or create one if it doesn't |
101 | | `--playground_dir` | str | The path specified here will work as the current directory for CodeNav's execution environment |
102 | | `--query` | str | The query you want CodeNav to solve using the codebase |
103 | | `--query_file` | str | If your query is long, you may want to save it to a txt file and provide the path to the text file here |
104 | | `--max_steps` | int | The maximum number of interactions to allow between the CodeNav agent and its environments |
105 |
106 |
107 | ### CodeNav as a library
108 |
109 | If you'd like to use CodeNav programmatically, you can do so by importing the `codenav` module and using the various
110 | functions/classes we provide. To get a sense of how this is done, we provide a number of example scripts
111 | under the [codenav_examples](codenav_examples) directory:
112 | - [create_index.py](codenav_examples%2Fcreate_index.py): Creates an Elasticsearch index for this codebase and then uses the `RetrievalEnv` environment to search for a code snippet.
113 | - [create_episode.py](codenav_examples%2Fcreate_episode.py): Creates an `OpenAICodeNavAgent` agent and then uses it to generate a solution for the query `"Find the DoneEnv and instantiate it"` **on this codebase** (i.e. executes a CodeNav agent on the CodeNav codebase). Be sure to run the `create_index.py` script above to generate the index before running this script.
114 | - [create_code_env.py](codenav_examples%2Fcreate_code_env.py): Creates a `PythonCodeEnv` object and then executes a given code string in this environment.
115 | - [create_prompt.py](codenav_examples%2Fcreate_prompt.py): Creates a custom prompt and instantiates a CodeNav agent with that prompt.
116 | - [parallel_evaluation.py](codenav_examples%2Fparallel_evaluation.py): Demonstrates how to run multiple CodeNav agents in parallel. This is useful for evaluating on a dataset of queries using multiple processes. The EvalSpec abstraction also helps you organize the code a little better!
117 |
118 | **Note** - You will still need to launch the Elasticsearch server before running any of the above. To do so, run
119 | ```
120 | python -m codenav.codenav_run init
121 | ```
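
Once Elasticsearch is running, the example scripts above can be condensed into a single programmatic run. The sketch below is based on the helpers used by `codenav_run.py` and `codenav/default_eval_spec.py` (`build_index` and `run_codenav_on_query`); the paths, index name, and experiment name are placeholders, and keyword arguments may differ slightly across versions, so treat it as a starting point rather than a canonical recipe.

```python
# Hedged sketch of an end-to-end programmatic run; paths and names below are placeholders.
from codenav.retrieval.elasticsearch.index_codebase import build_index
from codenav.default_eval_spec import run_codenav_on_query

CODE_DIR = "/path/to/codebase"          # codebase you want CodeNav to use
PLAYGROUND_DIR = "/path/to/playground"  # working directory for the agent
INDEX_NAME = "my_codebase"              # Elasticsearch index name (the CLI uses the code_dir basename)

# Index the codebase once (requires the Elasticsearch server started via `codenav init`).
build_index(code_dir=CODE_DIR, index_uid=INDEX_NAME, delete_index=False)

# Run the agent end-to-end on a single query.
run_codenav_on_query(
    exp_name="demo_run",
    out_dir=PLAYGROUND_DIR,
    query="Write a google-style docstring for the DoneEnv class and save it to DoneEnv.py",
    code_dir=CODE_DIR,
    sys_paths=[CODE_DIR],
    index_name=INDEX_NAME,
    working_dir=PLAYGROUND_DIR,
    max_steps=20,
    repo_description_path=None,
)
```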
122 |
123 | ## Elasticsearch & Indexing Gotchas 🤔
124 |
125 | When running CodeNav you must start an Elasticsearch index on your machine (e.g. by running `codenav init`),
126 | and once you run a query on a given codebase, CodeNav will index that codebase exactly once.
127 | This process means there are a few things you should keep in mind:
128 | 1. You must manually shut down the Elasticsearch server once you are done with it. You can do this by running `codenav stop`.
129 | 2. If you modify/update the codebase you are asking CodeNav to use, the Elasticsearch index will not automatically update, and CodeNav will be writing code using stale information. In this case, add the `--force_reindex` flag when running `codenav query`; this will force CodeNav to reindex the codebase.
130 | 3. If you run CodeNav and find that it is unable to search for a file, you may want to make sure the file was indexed correctly. You can inspect all indexed files using Elasticsearch's Kibana interface at `http://localhost:5601/`. To view all the indices indexed by CodeNav, go to `http://localhost:5601/app/management/data/index_management`. Then click on the index you want to inspect and then click on "Discover Index" on the top-right side of the page. This will show you all the code blocks stored in this index. You can now use the UI to run queries against this index and see if the file you are looking for is present in the index and whether it has the correct file path. (A Python alternative to the Kibana UI is sketched below.)
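
If you would rather check the index from Python than through Kibana, a quick sketch along the following lines can help. It assumes the default local Elasticsearch instance at `http://localhost:9200`; the `text` and `type` fields appear in CodeNav's retrieval code, but `file_path` is an assumed field name, and older elasticsearch-py versions may require `body=` instead of `query=`.

```python
# Hedged sketch for inspecting a CodeNav index without Kibana; field names may differ.
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")
index_name = "my_codebase"  # by default, the basename of the indexed code_dir

print("Index exists:", es.indices.exists(index=index_name))

res = es.search(index=index_name, query={"match": {"text": "DoneEnv"}}, size=5)
for hit in res["hits"]["hits"]:
    src = hit["_source"]
    print(src.get("file_path"), src.get("type"))
```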
131 |
132 | ## Warning ⚠️
133 |
134 | CodeNav is a research project and may make errors. As CodeNav can potentially execute ANY code
135 | it wants, it is not suitable for security-sensitive applications. We strongly recommend
136 | that you run CodeNav in a sandboxed environment where data loss or security breaches are not a concern.
137 |
138 | ## Authors ✍️
139 | - [Tanmay Gupta](https://tanmaygupta.info/)
140 | - [Luca Weihs](https://lucaweihs.github.io/)
141 | - [Aniruddha Kembhavi](https://anikem.github.io/)
142 |
143 | ## License 📄
144 | This project is licensed under the Apache 2.0 License.
145 |
146 | ## Citation
147 | ```bibtex
148 | @misc{gupta2024codenavtooluseusingrealworld,
149 | title={CodeNav: Beyond tool-use to using real-world codebases with LLM agents},
150 | author={Tanmay Gupta and Luca Weihs and Aniruddha Kembhavi},
151 | year={2024},
152 | eprint={2406.12276},
153 | archivePrefix={arXiv},
154 | primaryClass={cs.AI},
155 | url={https://arxiv.org/abs/2406.12276},
156 | }
157 | ```
158 |
159 | CodeNav builds on the research direction we started exploring with VisProg (CVPR 2023 Best Paper). For more context, please visit [https://github.com/allenai/visprog/blob/main/README.md](https://github.com/allenai/visprog/blob/main/README.md).
160 |
--------------------------------------------------------------------------------
/codenav/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/__init__.py
--------------------------------------------------------------------------------
/codenav/agents/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/agents/__init__.py
--------------------------------------------------------------------------------
/codenav/agents/agent.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Optional, Sequence, List, Dict, Any
3 |
4 | import codenav.interaction.messages as msg
5 | from codenav.agents.interaction_formatters import InteractionFormatter
6 |
7 |
8 | class CodeNavAgent(abc.ABC):
9 | max_tokens: int
10 | interaction_formatter: InteractionFormatter
11 | model: str
12 |
13 | def __init__(self, allowed_action_types: Sequence[msg.ACTION_TYPES]):
14 | self.allowed_action_types = allowed_action_types
15 | # self.reset()
16 |
17 | def reset(self):
18 | """Reset episode state. Must be called before starting a new episode."""
19 | self.episode_state = self.init_episode_state()
20 |
21 | @abc.abstractmethod
22 | def init_episode_state(self) -> msg.EpisodeState:
23 | """Build system prompt and initialize the episode state with it"""
24 | raise NotImplementedError
25 |
26 | def update_state(self, interaction: msg.Interaction):
27 | self.episode_state.update(interaction)
28 |
29 | @property
30 | def system_prompt_str(self) -> str:
31 | return self.episode_state.system_prompt.content
32 |
33 | @property
34 | def user_query_prompt_str(self) -> str:
35 | queries = [
36 | i.response
37 | for i in self.episode_state.interactions
38 | if isinstance(i.response, msg.UserQueryToAgent)
39 | ]
40 | assert len(queries) == 1
41 | return queries[0].message
42 |
43 | @abc.abstractmethod
44 | def get_action(self) -> msg.CodeNavAction:
45 | raise NotImplementedError
46 |
47 | @abc.abstractmethod
48 | def summarize_episode_state_for_restart(
49 | self, episode_state: msg.EpisodeState, max_tokens: Optional[int]
50 | ) -> str:
51 | raise NotImplementedError
52 |
53 | @property
54 | @abc.abstractmethod
55 | def all_queries_and_responses(self) -> List[Dict[str, Any]]:
56 | raise NotImplementedError
57 |
--------------------------------------------------------------------------------
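
To make the abstract interface above concrete, here is a hedged sketch of a minimal (non-LLM) subclass; the `EchoAgent` name and its trivial hard-coded policy are purely illustrative, and the message constructors follow the usage seen in `llm_chat_agent.py`.

```python
# Illustrative sketch of the minimal CodeNavAgent subclass surface; a real agent
# (see codenav/agents/llm_chat_agent.py) would query an LLM in get_action().
from typing import Any, Dict, List, Optional

import codenav.interaction.messages as msg
from codenav.agents.agent import CodeNavAgent


class EchoAgent(CodeNavAgent):
    def init_episode_state(self) -> msg.EpisodeState:
        return msg.EpisodeState(system_prompt=msg.SystemPrompt(content="You are a code agent."))

    def get_action(self) -> msg.CodeNavAction:
        # A real agent would build chat context from self.episode_state and query an LLM here.
        return msg.CodeNavAction(thought="Nothing to do.", type="done", content="True")

    def summarize_episode_state_for_restart(
        self, episode_state: msg.EpisodeState, max_tokens: Optional[int]
    ) -> str:
        return "No summary available."

    @property
    def all_queries_and_responses(self) -> List[Dict[str, Any]]:
        return []


agent = EchoAgent(allowed_action_types=("done",))
agent.reset()                    # must be called before starting an episode
print(agent.get_action().type)   # -> "done"
```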
/codenav/agents/cohere/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/agents/cohere/__init__.py
--------------------------------------------------------------------------------
/codenav/agents/cohere/agent.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import time
3 | from typing import List, Literal, Optional, Sequence
4 |
5 | from cohere import ChatMessage, InternalServerError, TooManyRequestsError
6 |
7 | import codenav.interaction.messages as msg
8 | from codenav.agents.interaction_formatters import InteractionFormatter
9 | from codenav.agents.llm_chat_agent import LLMChatCodeNavAgent, LlmChatMessage
10 | from codenav.constants import COHERE_CLIENT
11 |
12 |
13 | class CohereCodeNavAgent(LLMChatCodeNavAgent):
14 | def __init__(
15 | self,
16 | prompt: str,
17 | model: Literal["command-r", "command-r-plus"] = "command-r",
18 | max_tokens: int = 50000,
19 | allowed_action_types: Sequence[msg.ACTION_TYPES] = (
20 | "code",
21 | "done",
22 | "search",
23 | "reset",
24 | ),
25 | prompt_set: str = "default",
26 | interaction_formatter: Optional[InteractionFormatter] = None,
27 | ):
28 | super().__init__(
29 | model=model,
30 | prompt=prompt,
31 | max_tokens=max_tokens,
32 | allowed_action_types=allowed_action_types,
33 | interaction_formatter=interaction_formatter,
34 | )
35 |
36 | @property
37 | def client(self):
38 | return COHERE_CLIENT
39 |
40 | def query_llm(
41 | self,
42 | messages: List[LlmChatMessage],
43 | model: Optional[str] = None,
44 | max_tokens: Optional[int] = None,
45 | ) -> str:
46 | if model is None:
47 | model = self.model
48 |
49 | if max_tokens is None:
50 | max_tokens = self.max_tokens
51 |
52 | output = None
53 | nretries = 50
54 | for retry in range(nretries):
55 | try:
56 | output = self.client.chat(
57 | message=messages[-1]["message"],
58 | model=model,
59 | chat_history=messages[:-1],
60 | prompt_truncation="OFF",
61 | temperature=0.0,
62 | max_input_tokens=max_tokens,
63 | max_tokens=3000,
64 | )
65 | break
66 | except (TooManyRequestsError, InternalServerError):
67 | pass
68 |
69 | if retry >= nretries - 1:
70 | raise RuntimeError(f"Hit max retries ({nretries})")
71 |
72 | time.sleep(5)
73 |
74 | self._all_queries_and_responses.append(
75 | {
76 | "input": copy.deepcopy(messages),
77 | "output": output.text,
78 | "input_tokens": output.meta.billed_units.input_tokens,
79 | "output_tokens": output.meta.billed_units.output_tokens,
80 | }
81 | )
82 | return output.text
83 |
84 | def create_message_from_text(self, text: str, role: str) -> ChatMessage:
85 | role = {
86 | "assistant": "CHATBOT",
87 | "user": "USER",
88 | "system": "SYSTEM",
89 | }[role]
90 |
91 | return dict(role=role, message=text)
92 |
--------------------------------------------------------------------------------
/codenav/agents/gpt4/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/agents/gpt4/__init__.py
--------------------------------------------------------------------------------
/codenav/agents/gpt4/agent.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence
2 |
3 | import codenav.interaction.messages as msg
4 | from codenav.agents.interaction_formatters import InteractionFormatter
5 | from codenav.agents.llm_chat_agent import LLMChatCodeNavAgent, LlmChatMessage
6 | from codenav.constants import DEFAULT_OPENAI_MODEL, OPENAI_CLIENT
7 | from codenav.utils.llm_utils import create_openai_message
8 |
9 |
10 | class OpenAICodeNavAgent(LLMChatCodeNavAgent):
11 | def __init__(
12 | self,
13 | prompt: str,
14 | model: str = DEFAULT_OPENAI_MODEL,
15 | max_tokens: int = 50000,
16 | allowed_action_types: Sequence[msg.ACTION_TYPES] = (
17 | "code",
18 | "done",
19 | "search",
20 | "reset",
21 | ),
22 | interaction_formatter: Optional[InteractionFormatter] = None,
23 | ):
24 | super().__init__(
25 | model=model,
26 | prompt=prompt,
27 | max_tokens=max_tokens,
28 | allowed_action_types=allowed_action_types,
29 | interaction_formatter=interaction_formatter,
30 | )
31 |
32 | @property
33 | def client(self):
34 | return OPENAI_CLIENT
35 |
36 | def create_message_from_text(self, text: str, role: str) -> LlmChatMessage:
37 | return create_openai_message(text=text, role=role)
38 |
--------------------------------------------------------------------------------
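
As a usage note, constructing the OpenAI-backed agent only requires a prompt string (normally produced by `PromptBuilder`); the sketch below uses a placeholder prompt and assumes `OPENAI_API_KEY`/`OPENAI_ORG` are set so that `OPENAI_CLIENT` exists (see `codenav/constants.py`).

```python
# Hedged sketch: the prompt below is a placeholder; in practice it comes from PromptBuilder.
from codenav.agents.gpt4.agent import OpenAICodeNavAgent

agent = OpenAICodeNavAgent(
    prompt="You are a coding agent...",   # placeholder system prompt
    model="gpt-4o-2024-05-13",            # DEFAULT_OPENAI_MODEL in codenav/constants.py
    allowed_action_types=("code", "done", "search", "reset"),
)
agent.reset()  # builds the EpisodeState with the system prompt; required before an episode
```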
/codenav/agents/interaction_formatters.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from types import MappingProxyType
3 | from typing import Mapping, Literal
4 |
5 | from codenav.interaction.messages import (
6 | CodeNavAction,
7 | RESPONSE_TYPES,
8 | Interaction,
9 | MultiRetrievalResult,
10 | RetrievalResult,
11 | )
12 | from codenav.retrieval.elasticsearch.create_index import es_doc_to_string
13 |
14 | DEFAULT_ACTION_FORMAT_KWARGS = MappingProxyType(
15 | dict(include_header=True),
16 | )
17 |
18 | DEFAULT_RESPONSE_FORMAT_KWARGS = MappingProxyType(
19 | dict(
20 | include_code=False,
21 | display_updated_vars=True,
22 | include_query=True,
23 | include_header=True,
24 | )
25 | )
26 |
27 |
28 | class InteractionFormatter(abc.ABC):
29 | def format_action(self, action: CodeNavAction):
30 | raise NotImplementedError
31 |
32 | def format_response(self, response: RESPONSE_TYPES):
33 | raise NotImplementedError
34 |
35 |
36 | class DefaultInteractionFormatter(InteractionFormatter):
37 | def __init__(
38 | self,
39 | action_format_kwargs: Mapping[str, bool] = DEFAULT_ACTION_FORMAT_KWARGS,
40 | response_format_kwargs: Mapping[str, bool] = DEFAULT_RESPONSE_FORMAT_KWARGS,
41 | ):
42 | self.action_format_kwargs = action_format_kwargs
43 | self.response_format_kwargs = response_format_kwargs
44 |
45 | def format_action(self, action: CodeNavAction):
46 | return Interaction.format_action(
47 | action,
48 | **self.action_format_kwargs,
49 | )
50 |
51 | def format_response(self, response: RESPONSE_TYPES):
52 | return Interaction.format_response(
53 | response,
54 | **self.response_format_kwargs,
55 | )
56 |
57 |
58 | class CustomRetrievalInteractionFormatter(DefaultInteractionFormatter):
59 | def __init__(
60 | self, use_summary: Literal["ifshorter", "always", "never", "prototype"]
61 | ):
62 | super().__init__()
63 | self.use_summary = use_summary
64 |
65 | def format_retrieval_result(self, rr: RetrievalResult, include_query=True):
66 | res_str = ""
67 | if include_query:
68 | res_str += f"QUERY:\n{rr.query}\n\n"
69 |
70 | res_str += "CODE BLOCKS:\n"
71 |
72 | if len(rr.es_docs) == 0:
73 | if rr.failure_reason is not None:
74 | res_str += f"Failed to retrieve code blocks: {rr.failure_reason}\n"
75 | else:
76 | res_str += "No code blocks found.\n"
77 |
78 | return res_str
79 |
80 | for doc in rr.es_docs[: rr.max_expanded]:
81 | if self.use_summary == "prototype":
82 | doc = {**doc}
83 | doc["text"] = doc["prototype"] or doc["text"]
84 |
85 | use_summary = "never"
86 | else:
87 | use_summary = self.use_summary
88 |
89 | res_str += f"---\n{es_doc_to_string(doc, use_summary=use_summary)}\n"
90 |
91 | res_str += "---\n"
92 |
93 | unexpanded_docs = rr.es_docs[rr.max_expanded :]
94 | if len(unexpanded_docs) <= rr.max_expanded:
95 | res_str += "(All code blocks matching the query were returned.)\n"
96 | else:
97 | res_str += (
98 | f"({len(unexpanded_docs)} additional code blocks not shown."
99 | f" Search again with the same query to see additional results.)\n\n"
100 | )
101 |
102 | if rr.max_prototype > 0:
103 | prototypes_docs = [
104 | doc
105 | for doc in unexpanded_docs
106 | if doc["type"] in {"CLASS", "FUNCTION"}
107 | ]
108 | num_prototype_docs_shown = min(len(prototypes_docs), rr.max_prototype)
109 | res_str += (
110 | f"Prototypes for the next {num_prototype_docs_shown} out of"
111 | f" {len(prototypes_docs)} classes/functions found in unexpanded results"
112 | f" (search again with the same query to see details):\n"
113 | )
114 | for doc in prototypes_docs[:num_prototype_docs_shown]:
115 | res_str += f"{es_doc_to_string(doc, prototype=True)}\n"
116 |
117 | return res_str
118 |
119 | def format_response(self, response: RESPONSE_TYPES):
120 | if not isinstance(response, MultiRetrievalResult):
121 | return super(CustomRetrievalInteractionFormatter, self).format_response(
122 | response
123 | )
124 | else:
125 | res_str = ""
126 | for res in response.retrieval_results:
127 | res_str += self.format_retrieval_result(res, include_query=True)
128 | res_str += "\n"
129 |
130 | return res_str
131 |
--------------------------------------------------------------------------------
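
For illustration, the sketch below swaps the default formatter for the prototype-only retrieval formatting when constructing an agent; the placeholder prompt is an assumption, and any agent accepting `interaction_formatter` would work the same way.

```python
# Hedged sketch: plugging a custom retrieval formatter into an agent.
from codenav.agents.gpt4.agent import OpenAICodeNavAgent
from codenav.agents.interaction_formatters import CustomRetrievalInteractionFormatter

# Show only prototypes (signatures) of retrieved classes/functions instead of full bodies.
formatter = CustomRetrievalInteractionFormatter(use_summary="prototype")

agent = OpenAICodeNavAgent(
    prompt="You are a coding agent...",  # placeholder system prompt
    interaction_formatter=formatter,     # defaults to DefaultInteractionFormatter when omitted
)
```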
/codenav/agents/llm_chat_agent.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import traceback
3 | from typing import Any, Dict, List, Optional, Sequence, Union, cast
4 |
5 | from openai.types.chat import (
6 | ChatCompletionAssistantMessageParam,
7 | ChatCompletionSystemMessageParam,
8 | ChatCompletionUserMessageParam,
9 | )
10 |
11 | import codenav.interaction.messages as msg
12 | from codenav.agents.agent import CodeNavAgent
13 | from codenav.agents.interaction_formatters import (
14 | DefaultInteractionFormatter,
15 | InteractionFormatter,
16 | )
17 | from codenav.constants import (
18 | TOGETHER_CLIENT,
19 | )
20 | from codenav.prompts.restart_prompt import RESTART_PROMPT
21 | from codenav.utils.llm_utils import MaxTokensExceededError, query_gpt
22 |
23 | LlmChatMessage = Union[
24 | ChatCompletionSystemMessageParam,
25 | ChatCompletionUserMessageParam,
26 | ChatCompletionAssistantMessageParam,
27 | ]
28 |
29 |
30 | class LLMChatCodeNavAgent(CodeNavAgent):
31 | def __init__(
32 | self,
33 | model: str,
34 | prompt: str,
35 | max_tokens: int = 50000,
36 | allowed_action_types: Sequence[msg.ACTION_TYPES] = (
37 | "code",
38 | "done",
39 | "search",
40 | "reset",
41 | ),
42 | interaction_formatter: Optional[InteractionFormatter] = None,
43 | ):
44 | super().__init__(allowed_action_types=allowed_action_types)
45 | self.model = model
46 | self.prompt = prompt
47 | self.max_tokens = max_tokens
48 |
49 | if interaction_formatter is None:
50 | self.interaction_formatter = DefaultInteractionFormatter()
51 | else:
52 | self.interaction_formatter = interaction_formatter
53 |
54 | self._all_queries_and_responses: List[Dict] = []
55 |
56 | @property
57 | def client(self):
58 | return TOGETHER_CLIENT
59 |
60 | def query_llm(
61 | self,
62 | messages: List[LlmChatMessage],
63 | model: Optional[str] = None,
64 | max_tokens: Optional[int] = None,
65 | ) -> str:
66 | if model is None:
67 | model = self.model
68 |
69 | if max_tokens is None:
70 | max_tokens = self.max_tokens
71 |
72 | response_dict = query_gpt(
73 | messages=messages,
74 | model=model,
75 | max_tokens=max_tokens,
76 | client=self.client,
77 | return_input_output_tokens=True,
78 | )
79 | self._all_queries_and_responses.append(
80 | {"input": copy.deepcopy(messages), **response_dict}
81 | )
82 | return response_dict["output"]
83 |
84 | @property
85 | def all_queries_and_responses(self) -> List[Dict[str, Any]]:
86 | return copy.deepcopy(self._all_queries_and_responses)
87 |
88 | def init_episode_state(self) -> msg.EpisodeState:
89 | return msg.EpisodeState(system_prompt=msg.SystemPrompt(content=self.prompt))
90 |
91 | def create_message_from_text(self, text: str, role: str) -> LlmChatMessage:
92 | return cast(LlmChatMessage, {"role": role, "content": text})
93 |
94 | def build_chat_context(
95 | self,
96 | episode_state: msg.EpisodeState,
97 | ) -> List[LlmChatMessage]:
98 | chat_messages: List[LlmChatMessage] = [
99 | self.create_message_from_text(
100 | text=episode_state.system_prompt.content, role="system"
101 | ),
102 | ]
103 | for interaction in episode_state.interactions:
104 | if interaction.hidden:
105 | continue
106 |
107 | if interaction.action is not None:
108 | chat_messages.append(
109 | self.create_message_from_text(
110 | text=self.interaction_formatter.format_action(
111 | interaction.action
112 | ),
113 | role="assistant",
114 | )
115 | )
116 |
117 | if interaction.response is not None:
118 | chat_messages.append(
119 | self.create_message_from_text(
120 | text=self.interaction_formatter.format_response(
121 | interaction.response,
122 | ),
123 | role="user",
124 | )
125 | )
126 |
127 | return chat_messages
128 |
129 | def summarize_episode_state_for_restart(
130 | self, episode_state: msg.EpisodeState, max_tokens: Optional[int]
131 | ) -> str:
132 | if max_tokens is None:
133 | max_tokens = self.max_tokens
134 |
135 | chat_messages = self.build_chat_context(episode_state)
136 |
137 | current_env_var_names = {
138 | var_name
139 | for i in episode_state.interactions
140 | if isinstance(i.response, msg.ExecutionResult)
141 | for var_name in (i.response.updated_vars or {}).keys()
142 | }
143 |
144 | chat_messages.append(
145 | self.create_message_from_text(
146 | text=RESTART_PROMPT.format(
147 | current_env_var_names=", ".join(sorted(list(current_env_var_names)))
148 | ),
149 | role="user",
150 | )
151 | )
152 |
153 | return self.query_llm(messages=chat_messages, max_tokens=max_tokens)
154 |
155 | def get_action(self) -> msg.CodeNavAction:
156 | chat_messages = self.build_chat_context(self.episode_state)
157 | try:
158 | output = self.query_llm(messages=chat_messages)
159 | return msg.CodeNavAction.from_text(output)
160 | except MaxTokensExceededError:
161 | if "reset" in self.allowed_action_types:
162 | summary = self.summarize_episode_state_for_restart(
163 | episode_state=self.episode_state, max_tokens=2 * self.max_tokens
164 | )
165 | for interaction in self.episode_state.interactions:
166 | if not isinstance(interaction.response, msg.UserQueryToAgent):
167 | if not interaction.hidden:
168 | interaction.hidden_at_index = len(
169 | self.episode_state.interactions
170 | )
171 | interaction.hidden = True
172 | action = msg.CodeNavAction.from_text(summary)
173 | action.type = "reset"
174 | return action
175 |
176 | return msg.CodeNavAction(
177 | thought=f"Max tokens exceeded:\n{traceback.format_exc()}",
178 | type="done",
179 | content="False",
180 | )
181 |
--------------------------------------------------------------------------------
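
Since `query_llm` records every request, a simple post-episode accounting pass is possible. The sketch below assumes the `input_tokens`/`output_tokens` keys used by the bookkeeping above (the Cohere agent sets them explicitly) and falls back to 0 when a backend does not report them.

```python
# Hedged sketch: summing token usage from an agent's recorded LLM calls.
from codenav.agents.gpt4.agent import OpenAICodeNavAgent

agent = OpenAICodeNavAgent(prompt="You are a coding agent...")  # in practice, reuse the agent that just ran an episode

records = agent.all_queries_and_responses  # deep copy of the per-call bookkeeping dicts
total_in = sum(r.get("input_tokens", 0) for r in records)
total_out = sum(r.get("output_tokens", 0) for r in records)
print(f"LLM calls: {len(records)}; tokens in/out: {total_in}/{total_out}")
```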
/codenav/codenav_run.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import re
4 | import subprocess
5 | import time
6 | from argparse import ArgumentParser
7 | from typing import Any, Dict, Optional, Sequence
8 |
9 | import attrs
10 | from elasticsearch import Elasticsearch
11 |
12 | from codenav.default_eval_spec import run_codenav_on_query
13 | from codenav.interaction.episode import Episode
14 | from codenav.retrieval.elasticsearch.index_codebase import (
15 | DEFAULT_ES_HOST,
16 | DEFAULT_ES_PORT,
17 | DEFAULT_KIBANA_PORT,
18 | build_index,
19 | )
20 | from codenav.retrieval.elasticsearch.install_elasticsearch import (
21 | ES_PATH,
22 | KIBANA_PATH,
23 | install_elasticsearch,
24 | is_es_installed,
25 | )
26 |
27 |
28 | def is_port_in_use(port: int) -> bool:
29 | import socket
30 |
31 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
32 | return s.connect_ex(("localhost", port)) == 0
33 |
34 |
35 | def is_es_running():
36 | es = Elasticsearch(DEFAULT_ES_HOST)
37 | return es.ping()
38 |
39 |
40 | def run_init():
41 | es = Elasticsearch(DEFAULT_ES_HOST)
42 | if es.ping():
43 | print(
44 | "Initialization complete, Elasticsearch is already running at http://localhost:9200."
45 | )
46 | return
47 |
48 | if not is_es_installed():
49 | print("Elasticsearch installation not found, downloading...")
50 | install_elasticsearch()
51 |
52 | if not is_es_installed():
53 | raise ValueError("Elasticsearch installation failed")
54 |
55 | if is_port_in_use(DEFAULT_ES_PORT) or is_port_in_use(DEFAULT_KIBANA_PORT):
56 | raise ValueError(
57 | f"The ports {DEFAULT_ES_PORT} and {DEFAULT_KIBANA_PORT} are already in use,"
58 | f" to start elasticsearch we require that these ports are free."
59 | )
60 |
61 | cmd = os.path.join(ES_PATH, "bin", "elasticsearch")
62 | print(f"Starting Elasticsearch server with command: {cmd}")
63 | es_process = subprocess.Popen(
64 | [cmd], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
65 | )
66 |
67 | cmd = os.path.join(KIBANA_PATH, "bin", "kibana")
68 | print(f"Starting Kibana server with command: {cmd}")
69 | kibana_process = subprocess.Popen(
70 | [cmd], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
71 | )
72 |
73 | es_started = False
74 | kibana_started = False
75 | try:
76 | for _ in range(10):
77 | es_started = es.ping()
78 | kibana_started = is_port_in_use(DEFAULT_KIBANA_PORT)
79 |
80 | if not es_started:
81 | print("Elasticsearch server not started yet...")
82 |
83 | if not kibana_started:
84 | print("Kibana server not started yet...")
85 |
86 | if es_started and kibana_started:
87 | break
88 |
89 | print("Waiting 10 seconds...")
90 | time.sleep(10)
91 |
92 | if not (es_started and kibana_started):
93 | raise RuntimeError("Elasticsearch failed to start")
94 |
95 | finally:
96 | if not (es_started and kibana_started):
97 | es_process.kill()
98 | kibana_process.kill()
99 |
100 | # noinspection PyUnreachableCode
101 | print(
102 | f"Initialization complete. "
103 | f" Elasticsearch server started successfully (PID {es_process.pid}) and can be accessed at {DEFAULT_ES_PORT}."
104 | f" You can also access the Kibana dashboard (PID {kibana_process.pid}) at {DEFAULT_KIBANA_PORT}."
105 | f" You will need to manually stop these processes when you are done with them."
106 | )
107 |
108 |
109 | def main():
110 | parser = ArgumentParser()
111 | parser.add_argument(
112 | "command",
113 | help="command to be executed",
114 | choices=["init", "stop", "query"],
115 | )
116 | parser.add_argument(
117 | "--code_dir",
118 | type=str,
119 | default=None,
120 | help="Path to the codebase to use. Only one of `code_dir` or `module` should be provided.",
121 | )
122 | parser.add_argument(
123 | "--module",
124 | type=str,
125 | default=None,
126 | help="Module to use for the codebase. Only one of `code_dir` or `module` should be provided.",
127 | )
128 |
129 | parser.add_argument(
130 | "--playground_dir",
131 | type=str,
132 | default=None,
133 | help="The working directory for the agent.",
134 | )
135 | parser.add_argument(
136 | "--max_steps",
137 | type=int,
138 | default=20,
139 | help="Maximum number of rounds of interaction.",
140 | )
141 | parser.add_argument(
142 | "-q",
143 | "--q",
144 | "--query",
145 | type=str,
146 | help="A description of the problem you want the the agent to solve (using `code_dir`).",
147 | )
148 | parser.add_argument(
149 | "-f",
150 | "--query_file",
151 | type=str,
152 | default=None,
153 | help="A path to a file containing your query (useful for long/detailed queries that are hard to enter on the commandline).",
154 | )
155 | parser.add_argument(
156 | "--force_reindex",
157 | action="store_true",
158 | help="Will delete the existing index (if any) and refresh it.",
159 | )
160 |
161 | parser.add_argument(
162 | "--force_subdir",
163 | type=str,
164 | default=None,
165 | help="Index only a subdirectory of the code_dir",
166 | )
167 |
168 | parser.add_argument(
169 | "--repo_description_path",
170 | type=str,
171 | default=None,
172 | help="Path to a file containing a description of the codebase.",
173 | )
174 |
175 | args = parser.parse_args()
176 |
177 | if args.command == "init":
178 | run_init()
179 | elif args.command == "stop":
180 | # Find all processes that start with ES_PATH and KIBANA_PATH and kill them
181 | for path in [ES_PATH, KIBANA_PATH]:
182 | cmd = f"ps aux | grep {path} | grep -v grep | awk '{{print $2}}' | xargs kill "
183 | subprocess.run(cmd, shell=True)
184 | elif args.command == "query":
185 | if not is_es_running():
186 | raise ValueError(
187 | "Elasticsearch not running, please run `codenav init` first."
188 | )
189 |
190 | assert (args.q is None) != (
191 | args.query_file is None
192 | ), "Exactly one of `q` or `query_file` should be provided"
193 |
194 | if args.query_file is not None:
195 | print(args.query_file)
196 | with open(args.query_file, "r") as f:
197 | args.q = f.read()
198 |
199 | if args.q is None:
200 | raise ValueError("No query provided")
201 |
202 | if args.code_dir is None == args.module is None:
203 | raise ValueError("Exactly one of `code_dir` or `module` should be provided")
204 |
205 | if args.code_dir is None and args.module is None:
206 | raise ValueError("No code_dir or module provided")
207 |
208 | if args.playground_dir is None:
209 | raise ValueError("No playground_dir provided")
210 |
211 | if args.code_dir is None:
212 | path_to_module = os.path.abspath(
213 | os.path.dirname(importlib.import_module(args.module).__file__)
214 | )
215 | args.code_dir = os.path.dirname(path_to_module)
216 | code_name = os.path.basename(path_to_module)
217 | force_subdir = code_name
218 | sys_path = os.path.dirname(path_to_module)
219 | else:
220 | force_subdir = args.force_subdir
221 | args.code_dir = os.path.abspath(args.code_dir)
222 | sys_path = args.code_dir
223 | code_name = os.path.basename(args.code_dir)
224 |
225 | args.playground_dir = os.path.abspath(args.playground_dir)
226 |
227 | if args.force_reindex or not Elasticsearch(DEFAULT_ES_HOST).indices.exists(
228 | index=code_name
229 | ):
230 | print(f"Index {code_name} not found, creating index...")
231 | build_index(
232 | code_dir=args.code_dir,
233 | index_uid=code_name,
234 | delete_index=args.force_reindex,
235 | force_subdir=force_subdir,
236 | )
237 |
238 | run_codenav_on_query(
239 | exp_name=re.sub("[^A-Za-z0–9 ]", "", args.q).replace(" ", "_")[:30],
240 | out_dir=args.playground_dir,
241 | query=args.q,
242 | code_dir=args.code_dir if args.module is None else args.module,
243 | sys_paths=[sys_path],
244 | index_name=code_name,
245 | working_dir=args.playground_dir,
246 | max_steps=args.max_steps,
247 | repo_description_path=args.repo_description_path,
248 | )
249 |
250 | else:
251 | raise ValueError(f"Unrecognized command: {args.command}")
252 |
253 |
254 | if __name__ == "__main__":
255 | main()
256 |
--------------------------------------------------------------------------------
/codenav/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | from pathlib import Path
4 | from typing import Optional
5 |
6 | ABS_PATH_OF_CODENAV_DIR = os.path.abspath(os.path.dirname(Path(__file__)))
7 | PROMPTS_DIR = os.path.join(ABS_PATH_OF_CODENAV_DIR, "prompts")
8 |
9 | DEFAULT_OPENAI_MODEL = "gpt-4o-2024-05-13"
10 | DEFAULT_RETRIEVAL_PER_QUERY = 3
11 |
12 |
13 | def get_env_var(name: str) -> Optional[str]:
14 | if name in os.environ:
15 | return os.environ[name]
16 | else:
17 | return None
18 |
19 |
20 | OPENAI_API_KEY = get_env_var("OPENAI_API_KEY")
21 | OPENAI_ORG = get_env_var("OPENAI_ORG")
22 | OPENAI_CLIENT = None
23 | try:
24 | from openai import OpenAI
25 |
26 | if OPENAI_API_KEY is not None and OPENAI_ORG is not None:
27 | OPENAI_CLIENT = OpenAI(
28 | api_key=OPENAI_API_KEY,
29 | organization=OPENAI_ORG,
30 | )
31 | else:
32 | warnings.warn(
33 | "OpenAI_API_KEY and OPENAI_ORG not set. OpenAI API will not work."
34 | )
35 | except ImportError:
36 | warnings.warn("openai package not found. OpenAI API will not work.")
37 |
38 |
39 | TOGETHER_API_KEY = get_env_var("TOGETHER_API_KEY")
40 | TOGETHER_CLIENT = None
41 | try:
42 | from together import Together
43 |
44 | if TOGETHER_API_KEY is not None:
45 | TOGETHER_CLIENT = Together(api_key=TOGETHER_API_KEY)
46 | else:
47 | warnings.warn("TOGETHER_API_KEY not set. Together API will not work.")
48 | except ImportError:
49 | warnings.warn("together package not found. Together API will not work.")
50 |
51 |
52 | COHERE_API_KEY = get_env_var("COHERE_API_KEY")
53 | COHERE_CLIENT = None
54 | try:
55 | import cohere
56 |
57 | if COHERE_API_KEY is not None:
58 | COHERE_CLIENT = cohere.Client(api_key=COHERE_API_KEY)
59 | else:
60 | warnings.warn("COHERE_API_KEY not set. Cohere API will not work.")
61 | except ImportError:
62 | warnings.warn("cohere package not found. Cohere API will not work.")
63 |
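Downstream code is expected to check these module-level clients before using them. A minimal sketch of that pattern (illustrative only; it assumes the `openai` package is installed and `OPENAI_API_KEY`/`OPENAI_ORG` are exported):

```python
from codenav.constants import DEFAULT_OPENAI_MODEL, OPENAI_CLIENT

if OPENAI_CLIENT is None:
    raise RuntimeError(
        "OpenAI client unavailable: set OPENAI_API_KEY and OPENAI_ORG and `pip install openai`."
    )

# Standard chat-completions call routed through the shared client.
response = OPENAI_CLIENT.chat.completions.create(
    model=DEFAULT_OPENAI_MODEL,
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response.choices[0].message.content)
```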
--------------------------------------------------------------------------------
/codenav/default_eval_spec.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, List, Optional
3 |
4 | from codenav.agents.gpt4.agent import OpenAICodeNavAgent
5 | from codenav.constants import ABS_PATH_OF_CODENAV_DIR, DEFAULT_OPENAI_MODEL
6 | from codenav.environments.code_env import PythonCodeEnv
7 | from codenav.environments.done_env import DoneEnv
8 | from codenav.environments.retrieval_env import EsCodeRetriever, RetrievalEnv
9 | from codenav.interaction.episode import Episode
10 | from codenav.retrieval.elasticsearch.elasticsearch_constants import RESERVED_CHARACTERS
11 | from codenav.retrieval.elasticsearch.index_codebase import DEFAULT_ES_HOST
12 | from codenav.utils.eval_types import EvalInput, EvalSpec, Str2AnyDict
13 | from codenav.utils.evaluator import CodenavEvaluator
14 | from codenav.utils.prompt_utils import PROMPTS_DIR, PromptBuilder
15 |
16 |
17 | class DefaultEvalSpec(EvalSpec):
18 | def __init__(
19 | self,
20 | episode_kwargs: Str2AnyDict,
21 | interaction_kwargs: Str2AnyDict,
22 | logging_kwargs: Str2AnyDict,
23 | ):
24 | super().__init__(episode_kwargs, interaction_kwargs, logging_kwargs)
25 |
26 | @staticmethod
27 | def build_episode(
28 | eval_input: EvalInput,
29 | episode_kwargs: Optional[Str2AnyDict] = None,
30 | ) -> Episode:
31 | assert episode_kwargs is not None
32 |
33 | prompt_builder = PromptBuilder(
34 | prompt_dirs=episode_kwargs["prompt_dirs"],
35 | repo_description=episode_kwargs["repo_description"],
36 | )
37 | prompt = prompt_builder.build(
38 | dict(
39 | AVAILABLE_ACTIONS=episode_kwargs["allowed_actions"],
40 | RESERVED_CHARACTERS=RESERVED_CHARACTERS,
41 | RETRIEVALS_PER_KEYWORD=episode_kwargs["retrievals_per_keyword"],
42 | )
43 | )
44 |
45 | return Episode(
46 | agent=OpenAICodeNavAgent(
47 | prompt=prompt,
48 | model=episode_kwargs["llm"],
49 | allowed_action_types=episode_kwargs["allowed_actions"],
50 | ),
51 | action_type_to_env=dict(
52 | code=PythonCodeEnv(
53 | code_dir=episode_kwargs["code_dir"],
54 | sys_paths=episode_kwargs["sys_paths"],
55 | working_dir=episode_kwargs["working_dir"],
56 | ),
57 | search=RetrievalEnv(
58 | code_retriever=EsCodeRetriever(
59 | index_name=episode_kwargs["index_name"],
60 | host=episode_kwargs["host"],
61 | ),
62 | expansions_per_query=episode_kwargs["retrievals_per_keyword"],
63 | prototypes_per_query=episode_kwargs["prototypes_per_keyword"],
64 | summarize_code=False,
65 | ),
66 | done=DoneEnv(),
67 | ),
68 | user_query_str=eval_input.query,
69 | )
70 |
71 | @staticmethod
72 | def run_interaction(
73 | episode: Episode,
74 | interaction_kwargs: Optional[Str2AnyDict] = None,
75 | ) -> Str2AnyDict:
76 | assert interaction_kwargs is not None
77 | episode.step_until_max_steps_or_success(
78 | max_steps=interaction_kwargs["max_steps"],
79 | verbose=interaction_kwargs["verbose"],
80 | )
81 | ipynb_str = episode.to_notebook(cur_dir=episode.code_env.working_dir)
82 | return dict(ipynb_str=ipynb_str)
83 |
84 | @staticmethod
85 | def log_output(
86 | interaction_output: Str2AnyDict,
87 | eval_input: EvalInput,
88 | logging_kwargs: Optional[Str2AnyDict] = None,
89 | ) -> Any:
90 | assert logging_kwargs is not None
91 |
92 | outfile = os.path.join(logging_kwargs["out_dir"], f"{eval_input.uid}.ipynb")
93 | with open(outfile, "w") as f:
94 | f.write(interaction_output["ipynb_str"])
95 |
96 | return outfile
97 |
98 |
99 | def run_codenav_on_query(
100 | exp_name: str,
101 | out_dir: str,
102 | query: str,
103 | code_dir: str,
104 | index_name: str,
105 | working_dir: str = os.path.join(
106 | os.path.dirname(ABS_PATH_OF_CODENAV_DIR), "playground"
107 | ),
108 | sys_paths: Optional[List[str]] = None,
109 | repo_description_path: Optional[str] = None,
110 | es_host: str = DEFAULT_ES_HOST,
111 | max_steps: int = 20,
112 | ):
113 | prompt_dirs = [PROMPTS_DIR]
114 | repo_description = "default/repo_description.txt"
115 | if repo_description_path is not None:
116 | prompt_dir, repo_description = os.path.split(repo_description_path)
117 | conflict_path = os.path.join(PROMPTS_DIR, repo_description)
118 | if os.path.exists(conflict_path):
119 | raise ValueError(
120 | f"Prompt conflict detected: {repo_description} already exists in {PROMPTS_DIR}. "
121 | f"Please rename the {repo_description_path} file to resolve this conflict."
122 | )
123 | # add the directory containing the custom repo description
124 | prompt_dirs.append(prompt_dir)
125 |
126 | episode_kwargs = dict(
127 | allowed_actions=["done", "code", "search"],
128 | repo_description=repo_description,
129 | retrievals_per_keyword=3,
130 | prototypes_per_keyword=7,
131 | llm=DEFAULT_OPENAI_MODEL,
132 | code_dir=code_dir,
133 | sys_paths=[] if sys_paths is None else sys_paths,
134 | working_dir=working_dir,
135 | index_name=index_name,
136 | host=es_host,
137 | prompt_dirs=prompt_dirs,
138 | )
139 | interaction_kwargs = dict(max_steps=max_steps, verbose=True)
140 | logging_kwargs = dict(out_dir=out_dir)
141 |
142 | # Run CodeNav on the query
143 | outfile = CodenavEvaluator.evaluate_input(
144 | eval_input=EvalInput(uid=exp_name, query=query),
145 | eval_spec=DefaultEvalSpec(
146 | episode_kwargs=episode_kwargs,
147 | interaction_kwargs=interaction_kwargs,
148 | logging_kwargs=logging_kwargs,
149 | ),
150 | )
151 |
152 | print("Output saved to ", outfile)
153 |
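`run_codenav_on_query` is the simplest programmatic entry point to this spec. A hypothetical invocation (all paths, the index name, and the query below are placeholders; it assumes Elasticsearch is running and the target codebase has already been indexed):

```python
from codenav.default_eval_spec import run_codenav_on_query

run_codenav_on_query(
    exp_name="demo_query",
    out_dir="/tmp/codenav_out",             # where the .ipynb trace is written
    query="Load the example dataset and print how many rows it has.",
    code_dir="/path/to/target_repo",        # hypothetical repo to navigate
    index_name="target_repo",               # existing Elasticsearch index for that repo
    working_dir="/tmp/codenav_playground",  # where agent-generated code runs
    max_steps=20,
)
```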
--------------------------------------------------------------------------------
/codenav/environments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/environments/__init__.py
--------------------------------------------------------------------------------
/codenav/environments/abstractions.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Tuple, Optional, Union
3 |
4 | from codenav.interaction.messages import (
5 | CodeNavAction,
6 | InvalidAction,
7 | MultiRetrievalResult,
8 | ExecutionResult,
9 | UserMessageToAgent,
10 | )
11 |
12 |
13 | class CodeNavEnv(abc.ABC):
14 | @abc.abstractmethod
15 | def check_action_validity(self, action: CodeNavAction) -> Tuple[bool, str]:
16 | raise NotImplementedError
17 |
18 | @abc.abstractmethod
19 | def step(
20 | self, action: CodeNavAction
21 | ) -> Optional[
22 | Union[InvalidAction, MultiRetrievalResult, ExecutionResult, UserMessageToAgent]
23 | ]:
24 | raise NotImplementedError
25 |
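Concrete environments subclass `CodeNavEnv` and implement these two methods. As a minimal sketch (not part of the library), an environment that validates and simply echoes the action content back to the agent could look like:

```python
from typing import Tuple

from codenav.environments.abstractions import CodeNavEnv
from codenav.interaction.messages import CodeNavAction, UserMessageToAgent


class EchoEnv(CodeNavEnv):
    """Toy environment: accepts any non-empty action content and echoes it back."""

    def check_action_validity(self, action: CodeNavAction) -> Tuple[bool, str]:
        if action.content is None or action.content.strip() == "":
            return False, "Action content must be non-empty."
        return True, ""

    def step(self, action: CodeNavAction) -> UserMessageToAgent:
        return UserMessageToAgent(message=f"Echo: {action.content}")
```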
--------------------------------------------------------------------------------
/codenav/environments/code_summary_env.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | from codenav.environments.abstractions import CodeNavEnv
4 | from codenav.interaction.messages import CodeNavAction, UserMessageToAgent
5 |
6 |
7 | class CodeSummaryEnv(CodeNavEnv):
8 | def __init__(self) -> None:
9 | self.summary: Optional[str] = None
10 |
11 | def check_action_validity(self, action: CodeNavAction) -> Tuple[bool, str]:
12 | if action.content is None or action.content.strip() == "":
13 | return False, "No summary found in the action content."
14 |
15 | return True, ""
16 |
17 | def step(self, action: CodeNavAction) -> UserMessageToAgent:
18 | assert action.content is not None
19 | self.summary = action.content.strip()
20 | return UserMessageToAgent(message="Summary received and stored.")
21 |
--------------------------------------------------------------------------------
/codenav/environments/done_env.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | from codenav.environments.abstractions import CodeNavEnv
4 | from codenav.interaction.messages import CodeNavAction
5 |
6 |
7 | class DoneEnv(CodeNavEnv):
8 | def check_action_validity(self, action: CodeNavAction) -> Tuple[bool, str]:
9 | assert action.content is not None
10 |
11 | if action.content.strip().lower() in ["true", "false"]:
12 | return True, ""
13 | else:
14 | return (
15 | False,
16 | "When executing the done action, the content must be either 'True' or 'False'",
17 | )
18 |
19 | def step(self, action: CodeNavAction) -> None:
20 | return None
21 |
--------------------------------------------------------------------------------
/codenav/environments/retrieval_env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import traceback
3 | from typing import Dict, List, Sequence, Tuple, Union
4 |
5 | import codenav.interaction.messages as msg
6 | from codenav.environments.abstractions import CodeNavEnv
7 | from codenav.interaction.messages import CodeNavAction
8 | from codenav.retrieval.code_blocks import CodeBlockType
9 | from codenav.retrieval.elasticsearch.create_index import EsDocument, es_doc_to_hash
10 | from codenav.retrieval.elasticsearch.elasticsearch_retriever import (
11 | EsCodeRetriever,
12 | parallel_add_summary_to_es_docs,
13 | )
14 |
15 |
16 | def reorder_es_docs(es_docs: List[EsDocument]) -> List[EsDocument]:
17 | scores = []
18 | for es_doc in es_docs:
19 | t = es_doc["type"]
20 | proto = es_doc.get("prototype") or ""
21 |
22 | if t in [
23 | CodeBlockType.FUNCTION.name,
24 | CodeBlockType.CLASS.name,
25 | CodeBlockType.DOCUMENTATION.name,
26 | ]:
27 | score = 0.0
28 | elif t in [CodeBlockType.IMPORT.name, CodeBlockType.ASSIGNMENT.name]:
29 | score = 1.0
30 | else:
31 | score = 2.0
32 |
33 | if (
34 | os.path.basename(es_doc["file_path"]).lower().startswith("test_")
35 | or os.path.basename(es_doc["file_path"]).lower().endswith("_test.py")
36 | or ("test_" in proto.lower() or "Test" in proto)
37 | ):
38 | score += 2.0
39 |
40 | scores.append(score)
41 |
42 | return [
43 | es_doc for _, es_doc in sorted(list(zip(scores, es_docs)), key=lambda x: x[0])
44 | ]
45 |
46 |
47 | class RetrievalEnv(CodeNavEnv):
48 | def __init__(
49 | self,
50 | code_retriever: EsCodeRetriever,
51 | expansions_per_query: int,
52 | prototypes_per_query: int,
53 | max_per_query: int = 100,
54 | summarize_code: bool = True,
55 | overwrite_existing_summary: bool = False,
56 | ):
57 | self.code_retriever = code_retriever
58 | self.expansions_per_query = expansions_per_query
59 | self.prototypes_per_query = prototypes_per_query
60 | self.max_per_query = max_per_query
61 | self.summarize_code = summarize_code
62 | self.overwrite_existing_summary = overwrite_existing_summary
63 |
64 | self.retrieved_es_docs: Dict[str, EsDocument] = {}
65 |
66 | def reset(self):
67 | self.retrieved_es_docs = dict()
68 |
69 | def check_action_validity(self, action: CodeNavAction) -> Tuple[bool, str]:
70 | return True, ""
71 |
72 | def _get_retrieval_result(self, query: str):
73 | query = query.strip()
74 |
75 | try:
76 | es_docs = self.code_retriever.search(
77 | query=query, default_n=self.max_per_query
78 | )
79 | except Exception:
80 | error_msg = traceback.format_exc()
81 | print(error_msg)
82 | if "Failed to parse query" in error_msg:
83 | error_msg = (
84 | f"Failed to parse search query: {query}\n"
85 | f"Please check the syntax and try again (be careful to escape any reserved characters)."
86 | )
87 |
88 | return msg.RetrievalResult(
89 | query=query,
90 | es_docs=[],
91 | failure_reason=error_msg,
92 | )
93 |
94 | filtered_es_docs = [
95 | es_doc
96 | for es_doc in es_docs
97 | if es_doc_to_hash(es_doc) not in self.retrieved_es_docs
98 | ]
99 |
100 | failure_reason = None
101 | if len(filtered_es_docs) == 0 and len(es_docs) > 0:
102 | failure_reason = (
103 | f"All code blocks matching the query have already been returned."
104 | )
105 |
106 | filtered_es_docs = reorder_es_docs(filtered_es_docs)
107 |
108 | # Add LLM-generated summaries to the docs that will be expanded (when summarize_code is True)
109 | if self.summarize_code:
110 | parallel_add_summary_to_es_docs(
111 | es_docs=filtered_es_docs[: self.expansions_per_query],
112 | es=self.code_retriever.es,
113 | index_name=self.code_retriever.index_name,
114 | code_summarizer=self.code_retriever.code_summarizer,
115 | overwrite_existing=self.overwrite_existing_summary,
116 | )
117 |
118 | for es_doc in filtered_es_docs[: self.expansions_per_query]:
119 | self.retrieved_es_docs[es_doc_to_hash(es_doc)] = es_doc
120 |
121 | return msg.RetrievalResult(
122 | query=query,
123 | es_docs=filtered_es_docs,
124 | failure_reason=failure_reason,
125 | max_expanded=self.expansions_per_query,
126 | max_prototype=self.prototypes_per_query,
127 | )
128 |
129 | def step(
130 | self, action: Union[CodeNavAction, str, Sequence[str]]
131 | ) -> msg.MultiRetrievalResult:
132 | if isinstance(action, str):
133 | queries = [q.strip() for q in action.split("\n") if q.strip() != ""]
134 | elif isinstance(action, CodeNavAction):
135 | if action.content is None:
136 | queries = []
137 | else:
138 | queries = [
139 | q.strip() for q in action.content.split("\n") if q.strip() != ""
140 | ]
141 | else:
142 | queries = [q.strip() for q in action if q.strip() != ""]
143 |
144 | retrieval_results: List[msg.RetrievalResult] = [
145 | self._get_retrieval_result(query=query) for query in queries
146 | ]
147 |
148 | return msg.MultiRetrievalResult(retrieval_results=retrieval_results)
149 |
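Because `step` also accepts a raw string with one query per line, the retrieval environment can be exercised outside a full episode. A minimal sketch, assuming an Elasticsearch index named `my_index` has already been built on the default host:

```python
from codenav.environments.retrieval_env import RetrievalEnv
from codenav.retrieval.elasticsearch.elasticsearch_retriever import EsCodeRetriever
from codenav.retrieval.elasticsearch.index_codebase import DEFAULT_ES_HOST

env = RetrievalEnv(
    code_retriever=EsCodeRetriever(index_name="my_index", host=DEFAULT_ES_HOST),
    expansions_per_query=3,
    prototypes_per_query=7,
    summarize_code=False,  # avoid needing an OpenAI key for this check
)

# Two queries, one per line, exactly as an agent would emit them.
result = env.step("(type: FUNCTION) AND (text: build_index)\ntext: retrieval")
print(result.format(include_query=True))
```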
--------------------------------------------------------------------------------
/codenav/interaction/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/interaction/__init__.py
--------------------------------------------------------------------------------
/codenav/interaction/episode.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple
2 |
3 | import pandas as pd
4 |
5 | import codenav.interaction.messages as msg
6 | from codenav.agents.agent import CodeNavAgent
7 | from codenav.agents.interaction_formatters import (
8 | DefaultInteractionFormatter,
9 | InteractionFormatter,
10 | )
11 | from codenav.environments.abstractions import CodeNavEnv
12 | from codenav.environments.code_env import PythonCodeEnv
13 | from codenav.environments.retrieval_env import RetrievalEnv
14 | from codenav.prompts.query_prompt import create_user_query_message
15 | from codenav.prompts.restart_prompt import PICKUP_PROMPT
16 |
17 |
18 | class Episode:
19 | def __init__(
20 | self,
21 | agent: CodeNavAgent,
22 | action_type_to_env: Dict[msg.ACTION_TYPES, CodeNavEnv],
23 | user_query_str: str,
24 | ):
25 | self.agent = agent
26 | self.action_type_to_env = action_type_to_env
27 |
28 | self.user_query_str = user_query_str
29 | self.agent.reset()
30 | # self.agent.reset(action_type_to_env=self.action_type_to_env)
31 |
32 | assert all(k in action_type_to_env for k in ["done", "code"])
33 |
34 | @property
35 | def code_env(self) -> PythonCodeEnv:
36 | code_envs = [
37 | env
38 | for env in self.action_type_to_env.values()
39 | if isinstance(env, PythonCodeEnv)
40 | ]
41 | assert len(code_envs) == 1
42 | return code_envs[0]
43 |
44 | def check_action_validity(self, action: msg.CodeNavAction) -> Tuple[bool, str]:
45 | is_valid = True
46 | error_msg = ""
47 |
48 | if action.type == "reset":
49 | return True, ""
50 |
51 | if action.thought is None:
52 | is_valid = False
53 | error_msg += "Action should always contain thought.\n"
54 |
55 | if action.type not in self.action_type_to_env:
56 | is_valid = False
57 | error_msg += f"Action type {action.type} is not supported.\n"
58 | return is_valid, error_msg
59 |
60 | assert action.type is not None
61 | env = self.action_type_to_env[action.type]
62 |
63 | content_valid, error_msg_content = env.check_action_validity(action)
64 |
65 | is_valid = is_valid and content_valid
66 | error_msg += error_msg_content
67 |
68 | return is_valid, error_msg
69 |
70 | def step(self) -> msg.Interaction:
71 | if len(self.agent.episode_state.interactions) == 0:
72 | assert self.code_env.code_dir is not None
73 |
74 | # Start of episode, add user query
75 | self.agent.update_state(
76 | msg.Interaction(
77 | action=None,
78 | response=create_user_query_message(
79 | user_query_str=self.user_query_str,
80 | code_dir=self.code_env.code_dir,
81 | working_dir=self.code_env.working_dir,
82 | added_paths=self.code_env.sys_paths,
83 | ),
84 | )
85 | )
86 |
87 | action = self.agent.get_action()
88 |
89 | action_is_valid, action_error_msg = self.check_action_validity(action)
90 |
91 | response: Optional[msg.RESPONSE_TYPES]
92 | if action_is_valid and action.type == "reset":
93 | response = msg.UserMessageToAgent(message=PICKUP_PROMPT)
94 | for env in self.action_type_to_env.values():
95 | if isinstance(env, RetrievalEnv):
96 | env.reset()
97 |
98 | elif action_is_valid:
99 | assert action.type is not None
100 | try:
101 | response = self.action_type_to_env[action.type].step(action)
102 | except KeyboardInterrupt:
103 | print(
104 | f"Keyboard interrupt occurred while attempting to execute code:{{\n{action.content}\n}}\n"
105 | f"String of notebook before interrupt: {self.to_notebook(cur_dir=self.code_env.working_dir)}\n",
106 | flush=True,
107 | )
108 | raise
109 | else:
110 | response = msg.InvalidAction(reason=action_error_msg)
111 |
112 | interaction = msg.Interaction(action=action, response=response)
113 | self.agent.update_state(interaction=interaction)
114 |
115 | return interaction
116 |
117 | def step_until_max_steps_or_success(self, max_steps: int, verbose: bool = True):
118 | for i in range(max_steps):
119 | interaction = self.step()
120 | if verbose:
121 | print("*" * 80)
122 | print(f"Step {i+1}")
123 | print("*" * 80)
124 | print("")
125 | print(
126 | Episode.format_interaction(
127 | interaction, self.agent.interaction_formatter
128 | )
129 | )
130 | if (
131 | interaction.action is not None
132 | and interaction.action.type == "done"
133 | and not isinstance(interaction.response, msg.InvalidAction)
134 | ):
135 | break
136 |
137 | @staticmethod
138 | def get_record(
139 | interaction: msg.Interaction, formatter: InteractionFormatter
140 | ) -> dict[str, Any]:
141 | if formatter is None:
142 | formatter = DefaultInteractionFormatter()
143 |
144 | if interaction.response is None:
145 | response_text = None
146 | else:
147 | response_text = formatter.format_response(
148 | interaction.response,
149 | )
150 |
151 | action = (
152 | msg.CodeNavAction() if interaction.action is None else interaction.action
153 | )
154 | return {
155 | "action/thought": action.thought,
156 | "action/type": str(action.type),
157 | "action/content": action.content,
158 | "response": response_text,
159 | "hidden": interaction.hidden,
160 | }
161 |
162 | def tabulate_interactions(self) -> pd.DataFrame:
163 | records = []
164 | for interaction in self.agent.episode_state.interactions:
165 | records.append(
166 | Episode.get_record(interaction, self.agent.interaction_formatter)
167 | )
168 |
169 | return pd.DataFrame.from_records(records)
170 |
171 | def tabulate_exec_trace(self) -> pd.DataFrame:
172 | return self.code_env.tabulate()
173 |
174 | def tabulate_prompts(self) -> pd.DataFrame:
175 | return pd.DataFrame.from_dict(
176 | dict(
177 | prompt_type=["QUERY", "SYSTEM_PROMPT"],
178 | prompt_content=[
179 | self.agent.user_query_prompt_str,
180 | self.agent.system_prompt_str,
181 | ],
182 | )
183 | )
184 |
185 | def to_notebook(self, cur_dir: str) -> str:
186 | import nbformat
187 | from nbformat import v4 as nbf
188 |
189 | nb = nbf.new_notebook()
190 | for i, row in self.tabulate_interactions().iterrows():
191 | thought = row["action/thought"]
192 | action_type = row["action/type"]
193 | content = row["action/content"]
194 | response = row["response"]
195 | thought, action_type, content, response = (
196 | x if x is not None else "None"
197 | for x in (thought, action_type, content, response)
198 | )
199 |
200 | if i == 0:
201 | nb.cells.append(
202 | nbf.new_markdown_cell(
203 | "\n\n".join(
204 | [
205 | "# Instruction to CodeNav",
206 | self.agent.user_query_prompt_str,
207 | "# Interactions",
208 | ]
209 | )
210 | )
211 | )
212 | continue
213 |
214 | nb.cells.append(
215 | nbf.new_markdown_cell(
216 | "\n\n".join(
217 | [
218 | f"## Step {i}: {action_type}",
219 | thought,
220 | ]
221 | )
222 | )
223 | )
224 |
225 | if action_type == "done":
226 | output = nbf.new_output(
227 | output_type="stream",
228 | text="Ending episode since the agent has issued a 'done' action.",
229 | )
230 | elif action_type == "code":
231 | output = nbf.new_output(
232 | output_type="execute_result",
233 | data={"text/plain": response},
234 | )
235 | else:
236 | output = nbf.new_output(
237 | output_type="stream",
238 | text=response,
239 | )
240 |
241 | nb.cells.append(nbf.new_code_cell(content, outputs=[output]))
242 | nb.cells.append(nbf.new_markdown_cell("---"))
243 |
244 | return nbformat.writes(nb)
245 |
246 | @staticmethod
247 | def format_interaction(
248 | interaction: msg.Interaction, interaction_formatter: InteractionFormatter
249 | ) -> str:
250 | record = Episode.get_record(interaction, formatter=interaction_formatter)
251 | return Episode.format_record(record)
252 |
253 | @staticmethod
254 | def format_record(record: dict[str, Any]) -> str:
255 | inter_str = "------Action------"
256 |
257 | inter_str += f"\nTHOUGHT:\n{record['action/thought']}"
258 |
259 | inter_str += f"\nACTION TYPE:\n{record.get('action/type')}"
260 |
261 | inter_str += f"\nACTION CONTENT:\n{record.get('action/content')}"
262 |
263 | inter_str += "\n\n-----Response-----"
264 | inter_str += f"\n{record['response']}"
265 | return inter_str
266 |
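For a self-contained example of wiring an `Episode` together by hand (mirroring `DefaultEvalSpec.build_episode` but without the search environment), the following sketch assumes an OpenAI key/org is configured and uses placeholder paths:

```python
import os

from codenav.agents.gpt4.agent import OpenAICodeNavAgent
from codenav.constants import DEFAULT_OPENAI_MODEL
from codenav.environments.code_env import PythonCodeEnv
from codenav.environments.done_env import DoneEnv
from codenav.interaction.episode import Episode
from codenav.retrieval.elasticsearch.elasticsearch_constants import RESERVED_CHARACTERS
from codenav.utils.prompt_utils import PROMPTS_DIR, PromptBuilder

CODE_DIR = "/path/to/target_repo"        # hypothetical repo the agent should use
WORKING_DIR = "/tmp/codenav_playground"  # scratch dir for executed code

prompt = PromptBuilder(
    prompt_dirs=[PROMPTS_DIR], repo_description="default/repo_description.txt"
).build(
    dict(
        AVAILABLE_ACTIONS=["done", "code"],
        RESERVED_CHARACTERS=RESERVED_CHARACTERS,
        RETRIEVALS_PER_KEYWORD=3,
    )
)

episode = Episode(
    agent=OpenAICodeNavAgent(
        prompt=prompt,
        model=DEFAULT_OPENAI_MODEL,
        allowed_action_types=["done", "code"],
    ),
    action_type_to_env=dict(
        code=PythonCodeEnv(
            code_dir=CODE_DIR, sys_paths=[CODE_DIR], working_dir=WORKING_DIR
        ),
        done=DoneEnv(),
    ),
    user_query_str="Import the package and print its version.",
)

episode.step_until_max_steps_or_success(max_steps=10)
with open(os.path.join(WORKING_DIR, "trace.ipynb"), "w") as f:
    f.write(episode.to_notebook(cur_dir=WORKING_DIR))
```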
--------------------------------------------------------------------------------
/codenav/interaction/messages.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Any, Dict, List, Literal, Optional, Sequence, Union, get_args
3 |
4 | import attrs
5 |
6 | from codenav.retrieval.elasticsearch.create_index import EsDocument, es_doc_to_string
7 | from codenav.utils.linting_and_type_checking_utils import (
8 | CodeAnalysisError,
9 | LintingError,
10 | TypeCheckingError,
11 | )
12 | from codenav.utils.string_utils import get_tag_content_from_text
13 |
14 |
15 | class CodeNavMessage:
16 | @abc.abstractmethod
17 | def format(self, *args, **kwargs) -> str:
18 | raise NotImplementedError
19 |
20 |
21 | @attrs.define
22 | class SystemPrompt(CodeNavMessage):
23 | content: str
24 |
25 | def format(self) -> str:
26 | return self.content
27 |
28 |
29 | ACTION_TYPES = Literal[
30 | "done", "code", "search", "reset", "code_summary", "request_user_message"
31 | ]
32 |
33 |
34 | @attrs.define
35 | class CodeNavAction(CodeNavMessage):
36 | thought: Optional[str] = None
37 | type: Optional[ACTION_TYPES] = None
38 | content: Optional[str] = None
39 |
40 | @staticmethod
41 | def get_tag_content_from_text(
42 | text: str,
43 | tag: Literal[
44 | "thought",
45 | "type",
46 | "content",
47 | "reset",
48 | "code",
49 | ],
50 | ) -> Optional[str]:
51 | return get_tag_content_from_text(text=text, tag=tag)
52 |
53 | @staticmethod
54 | def from_text(text: str) -> "CodeNavAction":
55 | thought = CodeNavAction.get_tag_content_from_text(text, "thought")
56 | type = CodeNavAction.get_tag_content_from_text(text, "type") # type: ignore
57 | content = CodeNavAction.get_tag_content_from_text(text, "content")
58 |
59 | assert type is None or type in get_args(
60 | ACTION_TYPES
61 | ), f"Invalid action type: {type} (valid types are {get_args(ACTION_TYPES)})"
62 |
63 | return CodeNavAction(
64 | thought=thought,
65 | type=type, # type: ignore
66 | content=content,
67 | )
68 |
69 | def to_tagged_text(self) -> str:
70 | return (
71 | f"\n{self.thought}\n"
72 | f"\n\n{self.type}\n"
73 | f"\n\n{self.content}\n"
74 | )
75 |
76 | def format(self) -> str:
77 | return self.to_tagged_text()
78 |
79 |
80 | @attrs.define
81 | class InvalidAction(CodeNavMessage):
82 | reason: str
83 |
84 | def format(self) -> str:
85 | return str(self)
86 |
87 |
88 | @attrs.define
89 | class ExecutionResult(CodeNavMessage):
90 | code_str: str
91 | stdout: str
92 | updated_vars: Optional[Dict[str, Any]] = None
93 | exec_error: Optional[str] = None
94 | linting_errors: Optional[List[LintingError]] = None
95 | type_checking_errors: Optional[List[TypeCheckingError]] = None
96 |
97 | @staticmethod
98 | def format_vars_with_max_len(vars: Dict[str, Any], max_len: int) -> str:
99 | """Format local variables with a maximum length per string representation."""
100 |
101 | l = []
102 | for k, v in vars.items():
103 | str_v = str(v)
104 | if len(str_v) > max_len:
105 | str_v = str_v[:max_len] + "..."
106 | l.append(f'"{k}": {str_v}')
107 |
108 | return "{" + ", ".join(l) + "}"
109 |
110 | def format(
111 | self,
112 | include_code: bool,
113 | display_updated_vars: bool,
114 | max_local_var_len: int = 500,
115 | ) -> str:
116 | res_str = ""
117 |
118 | if include_code:
119 | res_str = f"```\n{self.code_str}\n```\n"
120 |
121 | if self.stdout is not None and len(self.stdout) > 0:
122 | if len(self.stdout) > 2000:
123 | stdout_start = self.stdout[:1000]
124 | stdout_end = self.stdout[-1000:]
125 | msg = "STDOUT was too long. Showing only the start and end separated by ellipsis."
126 | res_str += f"STDOUT ({msg}):\n{stdout_start}\n\n...\n\n{stdout_end}\n"
127 | else:
128 | res_str += f"STDOUT:\n{self.stdout}\n"
129 |
130 | if self.exec_error is not None:
131 | res_str += f"EXECUTION ERROR:\n{self.exec_error}\n"
132 | elif self.stdout is None or len(self.stdout) == 0:
133 | # If there was no error or stdout, want to print something to tell the agent
134 | # that the code was executed
135 | res_str += "CODE EXECUTED WITHOUT ERROR, STDOUT WAS EMPTY\n"
136 |
137 | if (
138 | display_updated_vars
139 | and self.updated_vars is not None
140 | and len(self.updated_vars) > 0
141 | ):
142 | res_str += (
143 | f"RELEVANT VARIABLES (only shown if string rep. has changed after code exec):"
144 | f"\n{ExecutionResult.format_vars_with_max_len(self.updated_vars, max_len=max_local_var_len)}\n"
145 | )
146 |
147 | analysis_errors: List[CodeAnalysisError] = []
148 |
149 | if self.linting_errors is not None:
150 | analysis_errors.extend(self.linting_errors)
151 |
152 | if self.type_checking_errors is not None:
153 | analysis_errors.extend(self.type_checking_errors)
154 |
155 | if len(analysis_errors) > 0:
156 | res_str += "STATIC ANALYSIS ERRORS:\n"
157 | for err in analysis_errors:
158 | res_str += f"{err}\n"
159 |
160 | return res_str
161 |
162 |
163 | @attrs.define
164 | class RetrievalResult:
165 | query: str
166 | es_docs: Sequence[EsDocument]
167 | failure_reason: Optional[str] = None
168 | max_expanded: int = 3
169 | max_prototype: int = 10
170 |
171 | def format(
172 | self,
173 | include_query: bool = True,
174 | ) -> str:
175 | res_str = ""
176 | if include_query:
177 | res_str += f"QUERY:\n{self.query}\n\n"
178 |
179 | res_str += "CODE BLOCKS:\n"
180 |
181 | if len(self.es_docs) == 0:
182 | if self.failure_reason is not None:
183 | res_str += f"Failed to retrieve code blocks: {self.failure_reason}\n"
184 | else:
185 | res_str += "No code blocks found.\n"
186 |
187 | return res_str
188 |
189 | for doc in self.es_docs[: self.max_expanded]:
190 | res_str += "---\n{}\n".format(es_doc_to_string(doc, prototype=False))
191 |
192 | res_str += "---\n"
193 |
194 | unexpanded_docs = self.es_docs[self.max_expanded :]
195 | if len(unexpanded_docs) <= 0:
196 | res_str += "(All code blocks matching the query were returned.)\n"
197 | else:
198 | res_str += (
199 | f"({len(unexpanded_docs)} additional code blocks not shown."
200 | f" Search again with the same query to see additional results.)\n\n"
201 | )
202 |
203 | if self.max_prototype > 0:
204 | prototypes_docs = [
205 | doc
206 | for doc in unexpanded_docs
207 | if doc["type"] in {"CLASS", "FUNCTION"}
208 | ]
209 | num_prototype_docs_shown = min(len(prototypes_docs), self.max_prototype)
210 | res_str += (
211 | f"Prototypes for the next {num_prototype_docs_shown} out of"
212 | f" {len(prototypes_docs)} classes/functions found in unexpanded results"
213 | f" (search again with the same query to see details):\n"
214 | )
215 | for doc in prototypes_docs[:num_prototype_docs_shown]:
216 | res_str += "{}\n".format(es_doc_to_string(doc, prototype=True))
217 |
218 | return res_str
219 |
220 |
221 | @attrs.define
222 | class MultiRetrievalResult(CodeNavMessage):
223 | retrieval_results: Sequence[RetrievalResult]
224 |
225 | def format(
226 | self,
227 | include_query: bool = True,
228 | ) -> str:
229 | res_str = ""
230 | for res in self.retrieval_results:
231 | res_str += res.format(include_query)
232 | res_str += "\n"
233 |
234 | return res_str
235 |
236 |
237 | @attrs.define
238 | class UserMessageToAgent(CodeNavMessage):
239 | message: str
240 |
241 | def format(self) -> str:
242 | return self.message
243 |
244 |
245 | class UserQueryToAgent(UserMessageToAgent):
246 | pass
247 |
248 |
249 | RESPONSE_TYPES = Union[
250 | InvalidAction, MultiRetrievalResult, ExecutionResult, UserMessageToAgent
251 | ]
252 |
253 |
254 | @attrs.define
255 | class Interaction:
256 | action: Optional[CodeNavAction]
257 | response: Optional[RESPONSE_TYPES] = None
258 | hidden: bool = False
259 | hidden_at_index: int = -1
260 |
261 | @staticmethod
262 | def format_action(action: CodeNavAction, include_header=True) -> str:
263 | if include_header:
264 | return "ACTION:\n" + action.format()
265 |
266 | return action.format()
267 |
268 | @staticmethod
269 | def format_response(
270 | response: RESPONSE_TYPES,
271 | include_code,
272 | display_updated_vars,
273 | include_query,
274 | include_header,
275 | ) -> str:
276 | if isinstance(response, ExecutionResult):
277 | response_type = "Execution Result"
278 | response_text = response.format(
279 | include_code=include_code,
280 | display_updated_vars=display_updated_vars,
281 | )
282 | elif isinstance(response, InvalidAction):
283 | response_type = "Invalid Action"
284 | response_text = response.format()
285 | elif isinstance(response, MultiRetrievalResult):
286 | response_type = "Retrieval Result"
287 | response_text = response.format(include_query=include_query)
288 | elif isinstance(response, UserMessageToAgent):
289 | response_type = "User Message"
290 | response_text = response.format()
291 | else:
292 | raise NotImplementedError()
293 |
294 | if include_header:
295 | return f"RESPONSE ({response_type}):\n" + response_text
296 |
297 | return response_text
298 |
299 |
300 | class EpisodeState:
301 | def __init__(self, system_prompt: SystemPrompt):
302 | self.system_prompt = system_prompt
303 | self.interactions: List[Interaction] = []
304 |
305 | def update(self, interaction: Interaction):
306 | self.interactions.append(interaction)
307 |
308 |
309 | def format_reponse(
310 | response: RESPONSE_TYPES,
311 | include_code=False,
312 | display_updated_vars=True,
313 | include_query=True,
314 | ) -> str:
315 | if isinstance(response, ExecutionResult):
316 | response_text = response.format(
317 | include_code=include_code,
318 | display_updated_vars=display_updated_vars,
319 | )
320 | elif isinstance(response, InvalidAction):
321 | response_text = response.format()
322 | elif isinstance(response, MultiRetrievalResult):
323 | response_text = response.format(include_query=include_query)
324 | elif isinstance(response, UserMessageToAgent):
325 | response_text = response.format()
326 | else:
327 | raise NotImplementedError()
328 |
329 | return "ACTION RESPONSE:\n" + response_text
330 |
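A small illustration of the tagged-text round trip that `CodeNavAction` implements (the values are made up; `from_text` extracts the `<thought>`, `<type>` and `<content>` blocks that the agent is prompted to emit):

```python
from codenav.interaction.messages import CodeNavAction

raw = (
    "<thought>Inspect the entry point before writing any code.</thought>\n"
    "<type>search</type>\n"
    "<content>(type: FUNCTION) AND (text: main)</content>"
)

action = CodeNavAction.from_text(raw)
print(action.type)              # "search"
print(action.content)           # the query string between the <content> tags
print(action.to_tagged_text())  # serialize back into the tagged format
```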
--------------------------------------------------------------------------------
/codenav/prompts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/prompts/__init__.py
--------------------------------------------------------------------------------
/codenav/prompts/codenav/repo_description.txt:
--------------------------------------------------------------------------------
1 | ## Repo description and user requirements
2 |
3 | The codebase you will use is called `codenav`. It is a library for creating LLM agents that can interact with one of the available environments to solve user queries that require using an external codebase. For example, an agent can interact with "PythonCodeEnv" for code execution, and with "RetrievalEnv" for retrieving code snippets from an Elasticsearch index. It provides an "Episode" class for running this interaction with a specific agent and a set of environments. Here's the directory structure:
4 |
5 | codenav/agents/ - contains subdirectories that store implementations of LLM agents implemented with different LLMs
6 | codenav/environments/ - contains environments that the agent can interact with
7 | codenav/interaction/ - contains Episode and messages implementations (messages is how agent interacts with environments)
8 | codenav/prompts/ - stores various system prompts for the LLM agent
9 | codenav/retrieval/ - various files related to creating elastic search index and retrieving items from the index
10 | codenav/utils/ - contains various utility python files
11 | codenav/constants.py - contains important constants
--------------------------------------------------------------------------------
/codenav/prompts/default/action__code.txt:
--------------------------------------------------------------------------------
1 | ## ACTION TYPE: code
2 |
3 | In this case <content>...</content> should include your Python code to execute. Do not enclose this code in ```.
--------------------------------------------------------------------------------
/codenav/prompts/default/action__done.txt:
--------------------------------------------------------------------------------
1 | ## ACTION TYPE: done
2 |
3 | In this case <content>...</content> should contain ONLY the string True or False indicating whether you have been successful or not.
--------------------------------------------------------------------------------
/codenav/prompts/default/action__es_search.txt:
--------------------------------------------------------------------------------
1 | ## ACTION TYPE: search
2 |
3 | In this case <content>...</content> should include one Elasticsearch "query string" per line; these queries will be used to search an Elasticsearch index of the codebase. Elasticsearch's reserved characters are {RESERVED_CHARACTERS}; these should be escaped with a \ if you are trying to search for them explicitly (< and > can never be escaped). The index has the following fields:
4 |
5 | - file_path # The path of the file relative to the codebase's root (not absolute)
6 | - type # The type of the text block (FUNCTION, CLASS, ASSIGNMENT, IMPORT, or DOCUMENTATION)
7 | - lines # The line numbers of the block in the file
8 | - text # The code as a string
9 |
10 | You may search with ANY valid query string. Example queries:
11 |
12 | text: (classification OR "detection model") # Finds blocks containing "classification" or "detection model". The quotes ensure "detection model" is treated as a single phrase.
13 |
14 | (text: classification) AND NOT (text: linear) # Finds blocks containing "classification" but not containing "linear"
15 |
16 | ((type: FUNCTION) OR (type: CLASS)) AND (file_path: *rel\/path\/to\/code.py) # Finds all functions or classes in the file at *rel/path/to/code.py
17 |
18 | 1. Rather than searching for `def function_name` or `class ClassName` you should prefer to use the `type` field (i.e. search for `(type: FUNCTION) AND (text: function_name)` or `(type: CLASS) AND (text: ClassName)`).
19 | 2. When searching for a file_path, forward slashes MUST BE ESCAPED (ie all / should be replaced with \/).
20 | 3. If you are searching with a (file_path: rel\/path\/to\/code.py) and not retrieving relevant results, try broadening the search using a *, ie (file_path: *path\/to\/code.py)
21 | 4. Start with simple queries and only add extra constraints if the returned results are too broad.
22 |
--------------------------------------------------------------------------------
/codenav/prompts/default/action__guidelines.txt:
--------------------------------------------------------------------------------
1 | ## ACTION GUIDELINES
2 |
3 | Four guidelines that you should always follow are:
4 |
5 | 1. Do not ask the user to verify your results: do your best to verify your own results by coding or searching.
6 | 2. Before using a function or class from the code base, search for it to read its implementation.
7 | 3. Do not produce all code at once. Break down the problem into small steps and write code step by step. Before writing the next step, wait for user to execute the code and return the results. Iterate.
8 | 4. If the user query asks for you to do things a certain way, do not make assumptions or change the requirements to make it easier for you. It is better to fail following the requirements than to succeed solving a problem that should not have been solved.
9 |
--------------------------------------------------------------------------------
/codenav/prompts/default/action__preamble.txt:
--------------------------------------------------------------------------------
1 | # AGENT ACTIONS
2 |
3 | Always produce your output in the format:
4 | ```
5 | <thought>
6 | Rationale for your action describing what you've done, where the answer is, or why you have failed. Must ALWAYS be present, be concise.
7 | </thought>
8 | <type>
9 | The type of action you're taking; this should be one of: {AVAILABLE_ACTIONS}
10 | </type>
11 | <content>
12 | The content of your action corresponding to the type. How this content must be formatted for each action type is described below.
13 | </content>
14 | ```
15 | Do not include any text outside of the above tags.
16 |
--------------------------------------------------------------------------------
/codenav/prompts/default/overview.txt:
--------------------------------------------------------------------------------
1 | # OVERVIEW
2 |
3 | You are an expert Python Programmer. You will solve a user's query by writing Python code. In particular, the user will direct you to use a specific code base. We will now describe the workflow of your interaction with the user (Section "WORKFLOW"), the kinds of outputs you are allowed to produce (Section "AGENT ACTIONS"), and the responses you might expect from the user (Section "USER RESPONSES")
4 |
--------------------------------------------------------------------------------
/codenav/prompts/default/repo_description.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/prompts/default/repo_description.txt
--------------------------------------------------------------------------------
/codenav/prompts/default/response__code.txt:
--------------------------------------------------------------------------------
1 | ## USER RESPONSE TO code
2 |
3 | If you output code, the user will execute the code and return an ExecutionResult object. This object contains the result of executing the code (stdout and new variables), python execution errors, linting errors (from flake8), and type checking errors (from mypy). Pay attention to the errors - you might need to fix them in subsequent steps. Execution errors are the most important and you must fix them or try a different approach. Linting and type checking errors can be ignored if benign.
--------------------------------------------------------------------------------
/codenav/prompts/default/response__done.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/prompts/default/response__done.txt
--------------------------------------------------------------------------------
/codenav/prompts/default/response__es_search.txt:
--------------------------------------------------------------------------------
1 | ## USER RESPONSE TO search
2 |
3 | If you output search queries, the user will search the code base using your queries and return results in the format:
4 |
5 | ---
6 | file_path=<path of the file relative to the codebase root>,
7 | lines=[<start line>, <end line>],
8 | type=<FUNCTION, CLASS, ASSIGNMENT, IMPORT, or DOCUMENTATION>,
9 | content={{
10 | <the code block text>
11 | }}
12 | ---
13 | ... # Up to {RETRIEVALS_PER_KEYWORD} code blocks per query.
14 |
15 | Use these text blocks as reference for generating code in subsequent steps. By default the user will return {RETRIEVALS_PER_KEYWORD} code blocks per query. If you would like to see more code blocks for the same query then simply output the same query again. The user will return the next {RETRIEVALS_PER_KEYWORD} text blocks.
16 |
17 | If a code block has already been retrieved before, it will not show up in the search results again. So if your search is not returning the desired result, it is possible the code block has already been returned in the past.
--------------------------------------------------------------------------------
/codenav/prompts/default/response__preamble.txt:
--------------------------------------------------------------------------------
1 | # USER RESPONSES
2 |
3 | The user will respond to your actions with one of the following responses:
--------------------------------------------------------------------------------
/codenav/prompts/default/workflow.txt:
--------------------------------------------------------------------------------
1 | # WORKFLOW
2 |
3 | To solve the user's query, you might need multiple rounds of interaction with the user:
4 |
5 | First, the user will give you a query and information about the code repository they want you to use. The user will tell you the absolute path to the code repo as well as your current directory (the directory in which your code would be executed). All necessary dependencies should already be installed, do not try to install missing python dependencies. If you are missing a critical dependency, write "MISSING DEPENDENCY" in your thought and take the done action with content being "False".
6 |
7 | Given this initial task information, you will interact with the user to solve the task. The interaction may consist of multiple rounds. Each round consists of an action (ie output) from you and a response to that action from the user.
8 |
--------------------------------------------------------------------------------
/codenav/prompts/query_prompt.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from codenav.interaction.messages import UserQueryToAgent
4 |
5 | PATHS = """\
6 | Use the code base located at `{CODE_DIR}` to solve this query. Your current directory is `{WORKING_DIR}`.
7 | """
8 |
9 | ADDED_TO_PATH = """\
10 | The code base has either been installed via pip or has already been added to the system path via
11 | ```
12 | import sys
13 | sys.path.extend({ADDED_PATH})
14 | ```
15 | """
16 |
17 | # Import instructions might be different in case of an installed library rather than a local code repo. For now, we assume that the code repo is local and not installed.
18 | IMPORT_INSTRUCTIONS = """\
19 | If the import path in retrieved code block says `testing/dir_name/file_name.py` and you want to import variable, function or class called `obj` from this file, then import using `from testing.dir_name.file_name import obj`.
20 | """
21 |
22 |
23 | def create_user_query_message(
24 | user_query_str: str,
25 | code_dir: str,
26 | working_dir: str,
27 | added_paths: Sequence[str],
28 | ):
29 | return UserQueryToAgent(
30 | message=f"USER QUERY:\n\n{user_query_str}\n\nPATHS AND IMPORT INSTRUCTIONS:"
31 | + (f"\n" f"\n{PATHS}" f"\n{IMPORT_INSTRUCTIONS}" f"\n{ADDED_TO_PATH}").format(
32 | CODE_DIR=code_dir, WORKING_DIR=working_dir, ADDED_PATH=list(added_paths)
33 | ),
34 | )
35 |
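A hypothetical call, to show the shape of the message an agent receives at the start of an episode (the paths and query are placeholders):

```python
from codenav.prompts.query_prompt import create_user_query_message

message = create_user_query_message(
    user_query_str="Plot the training curve stored in metrics.json.",
    code_dir="/path/to/target_repo",
    working_dir="/tmp/codenav_playground",
    added_paths=["/path/to"],
)
print(message.message)  # USER QUERY plus path/import instructions with the values filled in
```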
--------------------------------------------------------------------------------
/codenav/prompts/restart_prompt.py:
--------------------------------------------------------------------------------
1 | RESTART_PROMPT = """\
2 | This task is about to be handed off to another AI agent. The agent will be given the same prompt you were and will be expected to continue your work in the same Python environment. The other agent will only be shown your single next response and will not have access to your previous thoughts, keywords, code, execution results, or retrieved code blocks. With this in mind, please respond in the format:
3 |
4 |
5 | A summary of everything you have done and learned so far relevant to the task. You should be concise but thorough.
6 |
7 |
8 | Code that, if run, would result in the Python environment being in its current state. This should include all necessary imports, definitions, and any other necessary setup. Ideally, this code should not produce any errors. Add comments when necessary explaining the code and describing the state of the environment. In this code do not define any variables/functions or import any libraries that are not available in the current environment state. For reference, the current variables defined in the environment are:\n{current_env_var_names}
9 |
10 | """
11 |
12 | PICKUP_PROMPT = """\
13 | You are picking up where another AI agent left off. The previous agent has provided a summary of their work and a code snippet that, if run, would result in the Python environment being in its current state.
14 | """
15 |
--------------------------------------------------------------------------------
/codenav/retrieval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/retrieval/__init__.py
--------------------------------------------------------------------------------
/codenav/retrieval/code_blocks.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import os
3 | from enum import Enum
4 | from typing import Dict, Iterable, List, Optional, Set
5 |
6 |
7 | class CodeBlockType(Enum):
8 | ASSIGNMENT = 1
9 | FUNCTION = 2
10 | CLASS = 3
11 | FILE = 4
12 | IMPORT = 5
13 | CONDITIONAL = 6
14 | DOCUMENTATION = 7
15 | OTHER = 8
16 |
17 |
18 | AST_TYPE_TO_CODE_BLOCK_TYPE = {
19 | ast.Module: CodeBlockType.FILE,
20 | ast.FunctionDef: CodeBlockType.FUNCTION,
21 | ast.ClassDef: CodeBlockType.CLASS,
22 | ast.Assign: CodeBlockType.ASSIGNMENT,
23 | ast.Import: CodeBlockType.IMPORT,
24 | ast.ImportFrom: CodeBlockType.IMPORT,
25 | ast.If: CodeBlockType.CONDITIONAL,
26 | ast.For: CodeBlockType.CONDITIONAL,
27 | }
28 |
29 |
30 | CODE_BLOCK_TEMPLATE = """file_path={file_path},
31 | lines=[{start_lineno}, {end_lineno}],
32 | type={type},
33 | content={{
34 | {code}
35 | }}"""
36 |
37 |
38 | DEFAULT_CODEBLOCK_TYPES = {
39 | CodeBlockType.FUNCTION,
40 | CodeBlockType.CLASS,
41 | CodeBlockType.ASSIGNMENT,
42 | CodeBlockType.IMPORT,
43 | CodeBlockType.CONDITIONAL,
44 | CodeBlockType.DOCUMENTATION,
45 | }
46 |
47 |
48 | class FilePath:
49 | def __init__(self, path: str, base_dir: str = "/"):
50 | self.path = os.path.abspath(path)
51 | self.base_dir = base_dir
52 |
53 | @property
54 | def file_name(self) -> str:
55 | """root/base_dir/path/to/file.py -> file.py"""
56 | return os.path.basename(self.path)
57 |
58 | @property
59 | def ext(self) -> str:
60 | """root/base_dir/path/to/file.py -> .py"""
61 | return os.path.splitext(self.path)[1]
62 |
63 | @property
64 | def rel_path(self) -> str:
65 | """root/base_dir/path/to/file.py -> path/to/file.py"""
66 | return os.path.relpath(self.path, self.base_dir)
67 |
68 | @property
69 | def import_path(self) -> str:
70 | """root/base_dir/path/to/file.py -> base_dir/path/to/file.py"""
71 | return os.path.join(os.path.basename(self.base_dir), self.rel_path)
72 |
73 | @property
74 | def abs_path(self) -> str:
75 | """root/base_dir/path/to/file.py -> /root/base_dir/path/to/file.py"""
76 | return self.path
77 |
78 | def __repr__(self) -> str:
79 | """FilePath(base_dir=root/base_dir, rel_path=path/to/file.py)"""
80 | return f"FilePath(base_dir={self.base_dir}, rel_path={self.rel_path})"
81 |
82 | def __eq__(self, other: object) -> bool:
83 | if not isinstance(other, FilePath):
84 | return NotImplemented
85 | return self.path == other.path and self.base_dir == other.base_dir
86 |
87 | def __hash__(self):
88 | return hash(str(self))
89 |
90 |
91 | def get_file_list(dir_path: str) -> list[FilePath]:
92 | file_list = []
93 | for root, dirs, files in os.walk(dir_path, followlinks=True):
94 | for file in files:
95 | file_list.append(FilePath(path=os.path.join(root, file), base_dir=dir_path))
96 | return file_list
97 |
98 |
99 | def filter_by_extension(
100 | file_list: list[FilePath], valid_extensions: Iterable[str] = (".py",)
101 | ) -> list[FilePath]:
102 | return [file_path for file_path in file_list if file_path.ext in valid_extensions]
103 |
104 |
105 | class CodeBlockASTNode:
106 | def __init__(
107 | self,
108 | ast_node: ast.AST,
109 | parent: Optional["CodeBlockASTNode"] = None,
110 | tree: Optional["CodeBlockAST"] = None,
111 | ):
112 | self.ast_node = ast_node
113 | self.parent = parent
114 | self.tree = tree
115 | self.code_summary: Optional[str] = None
116 |
117 | @staticmethod
118 | def format_code_block(
119 | file_path: str,
120 | start_lineno: Optional[int],
121 | end_lineno: Optional[int],
122 | type: str,
123 | code: str,
124 | ) -> str:
125 | return CODE_BLOCK_TEMPLATE.format(
126 | file_path=file_path,
127 | start_lineno=start_lineno,
128 | end_lineno=end_lineno,
129 | type=type,
130 | code=code,
131 | )
132 |
133 | def __repr__(self) -> str:
134 | if self.tree is None or self.tree.file_path is None:
135 | file_path = "None"
136 | else:
137 | file_path = self.tree.file_path.rel_path
138 |
139 | return self.format_code_block(
140 | file_path=file_path,
141 | start_lineno=self.ast_node.lineno - 1,
142 | end_lineno=self.ast_node.end_lineno,
143 | type=self.block_type.name,
144 | code=self.code if self.code_summary is None else self.code_summary,
145 | )
146 |
147 | @property
148 | def block_type(self) -> CodeBlockType:
149 | return AST_TYPE_TO_CODE_BLOCK_TYPE.get(type(self.ast_node), CodeBlockType.OTHER)
150 |
151 | @property
152 | def code(self):
153 | return "\n".join(
154 | self.tree.code_lines[self.ast_node.lineno - 1 : self.ast_node.end_lineno]
155 | )
156 |
157 | def children(
158 | self, included_types: Optional[Set[CodeBlockType]] = None
159 | ) -> List["CodeBlockASTNode"]:
160 | if self.tree is None:
161 | raise RuntimeError(
162 | "Cannot call children on a node that is not part of a tree."
163 | )
164 |
165 | return [
166 | self.tree.ast_node_to_node[child]
167 | for child in ast.iter_child_nodes(self.ast_node)
168 | if included_types is None
169 | or self.tree.ast_node_to_node[child].block_type in included_types
170 | ]
171 |
172 |
173 | class CodeBlockAST:
174 | def __init__(self, code: str, file_path: Optional[FilePath]):
175 | self.code = code
176 | self.code_lines = code.splitlines()
177 | self.file_path = file_path
178 | self.ast_root = ast.parse(self.code)
179 | self.ast_node_to_node: Dict[ast.AST, "CodeBlockASTNode"] = {}
180 | self._build_tree(
181 | CodeBlockASTNode(
182 | ast_node=self.ast_root,
183 | parent=None,
184 | tree=self,
185 | )
186 | )
187 |
188 | @property
189 | def root(self) -> CodeBlockASTNode:
190 | return self.ast_node_to_node[self.ast_root]
191 |
192 | @staticmethod
193 | def from_file_path(file_path: FilePath) -> "CodeBlockAST":
194 | with open(file_path.path, "r") as f:
195 | code = f.read()
196 |
197 | return CodeBlockAST(
198 | code=code,
199 | file_path=file_path,
200 | )
201 |
202 | def _build_tree(self, node: CodeBlockASTNode):
203 | self.ast_node_to_node[node.ast_node] = node
204 | for child in ast.iter_child_nodes(node.ast_node):
205 | child_node = CodeBlockASTNode(ast_node=child, parent=node, tree=self)
206 | self._build_tree(child_node)
207 |
208 | def __repr__(self) -> str:
209 | return ast.dump(self.ast_root)
210 |
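A small, self-contained sketch of how `CodeBlockAST` and `CodeBlockType` fit together (the snippet is inline, so no files on disk are required):

```python
from codenav.retrieval.code_blocks import CodeBlockAST, CodeBlockType

snippet = '''
import os

def area(w, h):
    return w * h

class Rect:
    def __init__(self, w, h):
        self.w, self.h = w, h
'''

tree = CodeBlockAST(code=snippet, file_path=None)
for node in tree.root.children(
    included_types={CodeBlockType.FUNCTION, CodeBlockType.CLASS}
):
    # Print the block type and the first line of the corresponding source.
    print(node.block_type.name, node.code.splitlines()[0])
```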
--------------------------------------------------------------------------------
/codenav/retrieval/code_summarizer.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import re
3 |
4 | from codenav.constants import DEFAULT_OPENAI_MODEL, OPENAI_CLIENT
5 | from codenav.utils.llm_utils import create_openai_message, query_gpt
6 | from codenav.utils.string_utils import get_tag_content_from_text
7 |
8 | SUMMARIZATION_PROMPT = """\
9 | You are an expert python programmer. Given a user-provided code snippet that represents a function or a class definition, you will document that code. Your documentation should be as concise as possible while still being sufficient to fully understand how the code behaves. Never mention that the code is being documented by an AI model or that it is in any particular style.
10 |
11 | If you are given a class definition which, for example, defines functions func1, func2, ..., funcN, your response should be formatted as:
12 | ```
13 | <class_name>
14 | Doc string of the class, should include Attributes (if any).
15 | </class_name>
16 |
17 | <func1>
18 | Google-style doc string of func1, should include Args, Returns, Yields, Raises, and Examples if applicable.
19 | </func1>
20 |
21 | (Doc strings for func2, ..., funcN, each in their own block)
22 | ```
23 |
24 | If you are given a function called func, then format similarly as:
25 | ```
26 | <func>
27 | Doc string of func, should include Args, Returns, Yields, Raises, and Examples if applicable.
28 | </func>
29 | ```
30 |
31 | Do not include starting or ending ``` in your response. Ensure you document the __init__ function if defined. Your output should start with the opening tag, including the < and > characters.
32 | """
33 |
34 |
35 | class DocstringTransformer(ast.NodeTransformer):
36 | def __init__(self, docstring_map):
37 | self.docstring_map = docstring_map
38 |
39 | def indent_docstring(self, docstring: str, offset: int = 0):
40 | indentation = " " * offset
41 | if "\n" in docstring:
42 | lines = docstring.split("\n") + [""]
43 | return "\n" + "\n".join([f"{indentation}{line}" for line in lines])
44 |
45 | def visit_ClassDef(self, node):
46 | # Update class docstring
47 | if node.name in self.docstring_map:
48 | docstring = self.indent_docstring(
49 | self.docstring_map[node.name], offset=node.col_offset + 4
50 | )
51 | if (
52 | node.body
53 | and isinstance(node.body[0], ast.Expr)
54 | and isinstance(node.body[0].value, (ast.Str, ast.Constant))
55 | ):
56 | # Replace the existing docstring
57 | node.body[0] = ast.Expr(value=ast.Str(s=docstring))
58 | else:
59 | # Insert new docstring if none exists
60 | node.body.insert(0, ast.Expr(value=ast.Str(s=docstring)))
61 | # node.body.insert(0, ast.Expr(value=ast.Str(s=docstring)))
62 | self.generic_visit(node) # Process methods
63 | return node
64 |
65 | def visit_FunctionDef(self, node):
66 | # Update method docstring and replace body
67 | if node.name in self.docstring_map:
68 | docstring = self.indent_docstring(
69 | self.docstring_map[node.name], offset=node.col_offset + 4
70 | )
71 | node.body = [
72 | ast.Expr(value=ast.Constant(value=docstring)),
73 | *ast.parse("...").body,
74 | ]
75 | else:
76 | node.body = ast.parse("...").body
77 | return node
78 |
79 |
80 | def rewrite_docstring(code, docstring_map):
81 | tree = ast.parse(code)
82 | transformer = DocstringTransformer(docstring_map)
83 | transformed_tree = transformer.visit(tree)
84 | return ast.unparse(transformed_tree)
85 |
86 |
87 | class CodeSummarizer:
88 | def __init__(
89 | self,
90 | model: str = DEFAULT_OPENAI_MODEL,
91 | max_tokens: int = 50000,
92 | ):
93 | self.model = model
94 | self.max_tokens = max_tokens
95 |
96 | def generate_tagged_summary(self, code: str) -> str:
97 | # Use OpenAI to summarize the code
98 | messages = [
99 | create_openai_message(text=SUMMARIZATION_PROMPT, role="system"),
100 | create_openai_message(text=code, role="user"),
101 | ]
102 | return query_gpt(
103 | messages=messages,
104 | model=self.model,
105 | max_tokens=self.max_tokens,
106 | client=OPENAI_CLIENT,
107 | )
108 |
109 | def summarize(self, code: str) -> str:
110 | tagged_summary = self.generate_tagged_summary(code)
111 | tags = [t.strip() for t in re.findall(r"<([^/]*?)>", tagged_summary)]
112 | docstring_map = dict()
113 | for tag in tags:
114 | docstring_map[tag] = get_tag_content_from_text(tagged_summary, tag=tag)
115 |
116 | return rewrite_docstring(code, docstring_map)
117 |
118 |
119 | if __name__ == "__main__":
120 | import inspect
121 | import sys
122 | import importlib
123 |
124 | func_or_class = sys.argv[1]
125 |
126 | m = importlib.import_module(".".join(func_or_class.split(".")[:-1]))
127 | to_doc = inspect.getsource(getattr(m, func_or_class.split(".")[-1]))
128 |
129 | summarizer = CodeSummarizer()
130 | print(summarizer.summarize(to_doc))
131 |
--------------------------------------------------------------------------------
/codenav/retrieval/elasticsearch/README.md:
--------------------------------------------------------------------------------
1 | # Elasticsearch
2 |
3 | ## Download
4 |
5 | You can download Elasticsearch by running
6 | ```bash
7 | python -m codenav.retrieval.elasticsearch.install_elasticsearch
8 | ```
9 | This will download Elasticsearch to `~/.cache/codenav/elasticsearch/elasticsearch-8.12.1` (the exact version and download directory are set in `install_elasticsearch.py`).
10 |
11 | ## Start
12 |
13 | Once downloaded, you can start the Elasticsearch server by running:
14 | ```bash
15 | bash ~/.cache/codenav/elasticsearch/elasticsearch-8.12.1/bin/elasticsearch
16 | ```
17 |
18 | ## Graphical interface
19 |
20 | It can be useful to use Kibana, a GUI for Elasticsearch, so that you can do things like
21 | delete an index without running commands via Python. Kibana is downloaded automatically when you run the above
22 | install script. You can start Kibana by running:
23 | ```bash
24 | bash ~/.cache/codenav/elasticsearch/kibana-8.12.1/bin/kibana
25 | ```
26 | and you can access the web interface by navigating to [http://localhost:5601](http://localhost:5601) in your browser.
--------------------------------------------------------------------------------
/codenav/retrieval/elasticsearch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/retrieval/elasticsearch/__init__.py
--------------------------------------------------------------------------------
/codenav/retrieval/elasticsearch/create_index.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, TypedDict, Union, Literal
2 |
3 | import tiktoken
4 | from elasticsearch import Elasticsearch
5 |
6 | from codenav.utils.hashing_utils import md5_hash_str
7 |
8 |
9 | class EsDocument(TypedDict):
10 | text: str
11 | file_path: str
12 | lines: dict
13 |     type: str  # Will be a string corresponding to one of the CodeBlockType types.
14 | prototype: str
15 | text_vector: Optional[list[float]]
16 | text_summary: Optional[str]
17 |
18 |
19 | def es_doc_to_string(
20 | doc: EsDocument,
21 | prototype: bool = False,
22 | use_summary: Union[bool, Literal["ifshorter", "always", "never"]] = "ifshorter",
23 | ) -> str:
24 | if prototype:
25 | return doc["prototype"] + " # " + doc["file_path"]
26 |
27 | code: Optional[str] = doc["text"]
28 | summary = doc.get("text_summary", code)
29 |
30 | if isinstance(use_summary, str):
31 | if use_summary == "ifshorter":
32 | encode_count = lambda m: len(
33 | tiktoken.get_encoding("cl100k_base").encode(m, disallowed_special=[])
34 | )
35 |
36 | use_summary = encode_count(code) >= encode_count(summary)
37 | elif use_summary == "always":
38 | use_summary = True
39 | elif use_summary == "never":
40 | use_summary = False
41 | else:
42 | raise ValueError("Invalid value for use_summary")
43 |
44 | doc_str = [
45 | f"file_path={doc['file_path']}",
46 | f"lines=[{doc['lines']['gte']}, {doc['lines']['lt']}]",
47 | f"type={doc['type']}",
48 | f"content={{\n{summary if use_summary else code}\n}}",
49 | ]
50 |
51 | return "\n".join(doc_str)
52 |
53 |
54 | def es_doc_to_hash(doc: EsDocument) -> str:
55 | return md5_hash_str(es_doc_to_string(doc, prototype=False, use_summary=False))
56 |
57 |
58 | class EsHit(TypedDict):
59 | _index: str
60 | _type: str
61 | _id: str
62 | _score: float
63 | _source: EsDocument
64 |
65 |
66 | CODENAV_INDEX_SETTINGS = {
67 | "settings": {
68 | "number_of_shards": 1,
69 | "number_of_replicas": 0,
70 | "analysis": {
71 | "tokenizer": {
72 | "code_ngram_3_4_tokenizer": {
73 | "type": "ngram",
74 | "min_gram": 3,
75 | "max_gram": 4,
76 | "token_chars": ["letter", "digit", "punctuation", "symbol"],
77 | }
78 | },
79 | "analyzer": {
80 | "code_ngram_3_4_analyzer": {
81 | "type": "custom",
82 | "tokenizer": "code_ngram_3_4_tokenizer",
83 | "filter": ["lowercase"], # You can include other filters as needed
84 | }
85 | },
86 | },
87 | },
88 | "mappings": {
89 | "properties": {
90 | "file_path": {"type": "keyword"},
91 | "type": {"type": "keyword"},
92 | "lines": {"type": "integer_range"}, # Using integer_range for line numbers
93 | "text": {
94 | "type": "text",
95 | "analyzer": "code_ngram_3_4_analyzer",
96 | },
97 | "prototype": {
98 | "type": "text",
99 | "analyzer": "code_ngram_3_4_analyzer",
100 | },
101 | "text_vector": {
102 | "type": "dense_vector",
103 | "dims": 1536, # Adjust the dimension according to your vector
104 | "index": "true",
105 | "similarity": "cosine",
106 | },
107 | "text_summary": {
108 | "type": "text",
109 | "analyzer": "code_ngram_3_4_analyzer",
110 | },
111 | }
112 | },
113 | }
114 |
115 |
116 | def create_empty_index(
117 | es: Elasticsearch,
118 | index_name: str,
119 | embedding_dim: int,
120 | ):
121 | assert embedding_dim == 1536, "Only 1536-dimensional vectors are supported"
122 | # Create the index
123 | es.indices.create(
124 | index=index_name, body=CODENAV_INDEX_SETTINGS, request_timeout=120
125 | )
126 |
127 |
128 | def create_file_path_hash_index(es: Elasticsearch, index_name: str):
129 | if es.indices.exists(index=index_name):
130 | return
131 |
132 | # Define the index settings and mappings
133 | index_body = {
134 | "settings": {"number_of_shards": 1, "number_of_replicas": 0},
135 | "mappings": {
136 | "properties": {
137 | "md5hash": {"type": "keyword"},
138 | "code_uuid": {"type": "keyword"},
139 | "file_path": {"type": "keyword"},
140 | "doc_count": {"type": "integer"},
141 | }
142 | },
143 | }
144 |
145 | # Create the index
146 | es.indices.create(index=index_name, body=index_body)
147 |
148 |
149 | if __name__ == "__main__":
150 | # Connect to the local Elasticsearch instance
151 | es = Elasticsearch(hosts="http://localhost:9200/")
152 |
153 | create_empty_index(es=es, index_name="code_repository", embedding_dim=1536)
154 |
155 | # Example document
156 | doc = {
157 | "text": "def example_function():\n pass",
158 | "file_path": "/path/to/file.py",
159 |         "lines": {"gte": 5, "lt": 30},
160 | "type": "function",
161 |         "text_vector": [0.1] * 1536,  # Example vector; must match the 1536-dim mapping above
162 | }
163 |
164 | # Index the document
165 | es.index(index="code_repository", document=doc)
166 |
167 | query = {"query": {"match": {"type": "function"}}}
168 |
169 | res = es.search(index="code_repository", body=query)
170 | print(res["hits"]["hits"])
171 |
--------------------------------------------------------------------------------
/codenav/retrieval/elasticsearch/debug_add_item_to_index.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/retrieval/elasticsearch/debug_add_item_to_index.py
--------------------------------------------------------------------------------
/codenav/retrieval/elasticsearch/elasticsearch.yml:
--------------------------------------------------------------------------------
1 | # ======================== Elasticsearch Configuration =========================
2 | #
3 | # NOTE: Elasticsearch comes with reasonable defaults for most settings.
4 | # Before you set out to tweak and tune the configuration, make sure you
5 | # understand what are you trying to accomplish and the consequences.
6 | #
7 | # The primary way of configuring a node is via this file. This template lists
8 | # the most important settings you may want to configure for a production cluster.
9 | #
10 | # Please consult the documentation for further information on configuration options:
11 | # https://www.elastic.co/guide/en/elasticsearch/reference/index.html
12 | #
13 | # ---------------------------------- Cluster -----------------------------------
14 | #
15 | # Use a descriptive name for your cluster:
16 | #
17 | #cluster.name: my-application
18 | #
19 | # ------------------------------------ Node ------------------------------------
20 | #
21 | # Use a descriptive name for the node:
22 | #
23 | #node.name: node-1
24 | #
25 | # Add custom attributes to the node:
26 | #
27 | #node.attr.rack: r1
28 | #
29 | # ----------------------------------- Paths ------------------------------------
30 | #
31 | # Path to directory where to store the data (separate multiple locations by comma):
32 | #
33 | #path.data: /path/to/data
34 | #
35 | # Path to log files:
36 | #
37 | #path.logs: /path/to/logs
38 | #
39 | # ----------------------------------- Memory -----------------------------------
40 | #
41 | # Lock the memory on startup:
42 | #
43 | #bootstrap.memory_lock: true
44 | #
45 | # Make sure that the heap size is set to about half the memory available
46 | # on the system and that the owner of the process is allowed to use this
47 | # limit.
48 | #
49 | # Elasticsearch performs poorly when the system is swapping the memory.
50 | #
51 | # ---------------------------------- Network -----------------------------------
52 | #
53 | # By default Elasticsearch is only accessible on localhost. Set a different
54 | # address here to expose this node on the network:
55 | #
56 | #network.host: 192.168.0.1
57 | #
58 | # By default Elasticsearch listens for HTTP traffic on the first free port it
59 | # finds starting at 9200. Set a specific HTTP port here:
60 | #
61 | #http.port: 9200
62 | #
63 | # For more information, consult the network module documentation.
64 | #
65 | # --------------------------------- Discovery ----------------------------------
66 | #
67 | # Pass an initial list of hosts to perform discovery when this node is started:
68 | # The default list of hosts is ["127.0.0.1", "[::1]"]
69 | #
70 | #discovery.seed_hosts: ["host1", "host2"]
71 | #
72 | # Bootstrap the cluster using an initial set of master-eligible nodes:
73 | #
74 | #cluster.initial_master_nodes: ["node-1", "node-2"]
75 | #
76 | # For more information, consult the discovery and cluster formation module documentation.
77 | #
78 | # ---------------------------------- Various -----------------------------------
79 | #
80 | # Allow wildcard deletion of indices:
81 | #
82 | #action.destructive_requires_name: false
83 |
84 | #----------------------- BEGIN SECURITY AUTO CONFIGURATION -----------------------
85 | #
86 | # The following settings, TLS certificates, and keys have been automatically
87 | # generated to configure Elasticsearch security features on 01-02-2024 00:31:46
88 | #
89 | # --------------------------------------------------------------------------------
90 |
91 | # Enable security features
92 | xpack.security.enabled: false
93 |
94 | xpack.security.enrollment.enabled: false
95 |
96 | # Enable encryption for HTTP API client connections, such as Kibana, Logstash, and Agents
97 | xpack.security.http.ssl:
98 | enabled: false
99 | # keystore.path: certs/http.p12
100 |
101 | # Enable encryption and mutual authentication between cluster nodes
102 | xpack.security.transport.ssl:
103 | enabled: false
104 | # verification_mode: certificate
105 | # keystore.path: certs/transport.p12
106 | # truststore.path: certs/transport.p12
107 |
108 | # Create a new cluster with the current node only
109 | # Additional nodes can still join the cluster later
110 | #cluster.initial_master_nodes: ["ip-172-16-20-86.us-west-2.compute.internal"]
111 |
112 | # cluster.routing.allocation.disk.watermark.high: "95%"
113 | cluster.routing.allocation.disk.threshold_enabled: false # Getting rid of high disk utilization errors
114 |
115 | # Allow HTTP API connections from anywhere
116 | # NOTE: security is disabled above, so these connections are neither encrypted nor authenticated
117 | http.host: 0.0.0.0
118 |
119 | # Allow other nodes to join the cluster from anywhere
120 | # Connections are encrypted and mutually authenticated
121 | #transport.host: 0.0.0.0
122 |
123 | #----------------------- END SECURITY AUTO CONFIGURATION -------------------------
124 |
--------------------------------------------------------------------------------
/codenav/retrieval/elasticsearch/elasticsearch_constants.py:
--------------------------------------------------------------------------------
1 | RESERVED_CHARACTERS = r"""+ - = && || > < ! ( ) { } [ ] ^ " ~ * ? : \ /"""
2 |
--------------------------------------------------------------------------------
/codenav/retrieval/elasticsearch/elasticsearch_retriever.py:
--------------------------------------------------------------------------------
1 | from concurrent.futures.thread import ThreadPoolExecutor
2 | from typing import List
3 |
4 | from elasticsearch import Elasticsearch
5 |
6 | from codenav.retrieval.code_blocks import CodeBlockType
7 | from codenav.retrieval.code_summarizer import CodeSummarizer
8 | from codenav.retrieval.elasticsearch.create_index import EsDocument, es_doc_to_hash
9 |
10 |
11 | class EsCodeRetriever:
12 | def __init__(self, index_name: str, host: str):
13 | self.index_name = index_name
14 | assert (
15 | index_name is not None and index_name != ""
16 | ), "Index name cannot be empty."
17 |
18 | self.host = host
19 | self.es = Elasticsearch(hosts=host)
20 |
21 | if not self.es.ping():
22 | raise ValueError(
23 | f"Elasticsearch is not running or could not be reached at {host}."
24 | )
25 |
26 | self.code_summarizer = CodeSummarizer()
27 |
28 | def search(self, query: str, default_n: int = 10) -> List[EsDocument]:
29 | body = {"query": {"query_string": {"query": query}}, "size": default_n}
30 | hits = self.es.search(index=self.index_name, body=body)["hits"]["hits"]
31 | return [hit["_source"] for hit in hits]
32 |
33 |
34 | def add_summary_to_es_doc(
35 | es_doc: EsDocument,
36 | es: Elasticsearch,
37 | index_name: str,
38 | code_summarizer: CodeSummarizer,
39 | overwrite_existing: bool = False,
40 | ):
41 | if es_doc["type"] == CodeBlockType.DOCUMENTATION.name:
42 | return
43 |
44 | if "text_summary" in es_doc and es_doc["text_summary"] is not None:
45 | if not overwrite_existing:
46 | return
47 |
48 | summary = code_summarizer.summarize(es_doc["text"])
49 | es_doc["text_summary"] = summary
50 | es.update(
51 | index=index_name,
52 | id=es_doc_to_hash(es_doc),
53 | body={"doc": {"text_summary": summary}},
54 | )
55 |
56 |
57 | def parallel_add_summary_to_es_docs(
58 | es_docs: List[EsDocument],
59 | es: Elasticsearch,
60 | index_name: str,
61 | code_summarizer: CodeSummarizer,
62 | overwrite_existing: bool = False,
63 | ):
64 | n = len(es_docs)
65 | with ThreadPoolExecutor() as executor:
66 | executor.map(
67 | add_summary_to_es_doc,
68 | es_docs,
69 | [es] * n,
70 | [index_name] * n,
71 | [code_summarizer] * n,
72 | [overwrite_existing] * n,
73 | )
74 |
--------------------------------------------------------------------------------
/codenav/retrieval/elasticsearch/index_codebase.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path
3 | import time
4 | from typing import List, Optional, Tuple
5 |
6 | from elasticsearch import Elasticsearch
7 | from elasticsearch.helpers import bulk
8 |
9 | from codenav.retrieval.code_blocks import (
10 | CodeBlockAST,
11 | CodeBlockType,
12 | filter_by_extension,
13 | get_file_list,
14 | )
15 | from codenav.retrieval.elasticsearch.create_index import (
16 | EsDocument,
17 | create_empty_index,
18 | es_doc_to_hash,
19 | )
20 | from codenav.utils.llm_utils import num_tokens_from_string
21 | from codenav.utils.parsing_utils import get_class_or_function_prototype
22 |
23 | DEFAULT_ES_PORT = 9200
24 | DEFAULT_KIBANA_PORT = 5601
25 | DEFAULT_ES_HOST = f"http://localhost:{DEFAULT_ES_PORT}"
26 | DEFAULT_KIBANA_HOST = f"http://localhost:{DEFAULT_KIBANA_PORT}"
27 |
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser(description="Index codebase")
31 | parser.add_argument(
32 | "--code_dir",
33 | type=str,
34 | required=True,
35 | help="Path to the codebase to index",
36 | )
37 | parser.add_argument(
38 | "--index_uid",
39 | type=str,
40 | required=True,
41 | help="Unique identifier for the index",
42 | )
43 | parser.add_argument(
44 | "--delete_index",
45 | action="store_true",
46 | help="Delete the index if it already exists",
47 | )
48 | parser.add_argument(
49 | "--host",
50 | type=str,
51 | default=DEFAULT_ES_HOST,
52 | help="Elasticsearch host",
53 | )
54 | parser.add_argument(
55 | "--force_subdir",
56 | type=str,
57 | default=None,
58 | help="If provided, only index files in this subdirectory of code_dir. Their path in the index will still be the"
59 | " path relative to code_dir.",
60 | )
61 | return parser.parse_args()
62 |
63 |
64 | def split_markdown_documentation_into_parts(
65 | docstring: str, min_tokens: int = 100, split_after_ntokens: int = 1000
66 | ) -> Tuple[List[str], List[Tuple[int, int]]]:
67 | parts: List[str] = []
68 | line_nums: List[Tuple[int, int]] = []
69 | current_lines: List[str] = []
70 | current_line_nums: List[int] = []
71 |
72 | lines = docstring.split("\n")
73 | for cur_line_idx, line in enumerate(lines):
74 | line = line.rstrip()
75 |
76 | if line == "" and len(current_lines) == 0:
77 | # Skip leading empty lines
78 | continue
79 |
80 | if len(line) == 0:
81 | current_lines.append("")
82 | current_line_nums.append(cur_line_idx)
83 | continue
84 |
85 | if (
86 | line.lstrip().startswith("#")
87 | and (len(line) == 1 or line[1] != "#")
88 | and (num_tokens_from_string("\n".join(current_lines)) > min_tokens)
89 | ):
90 | # We're at a top-level header, and we have enough tokens to split
91 | parts.append("\n".join(current_lines))
92 | line_nums.append((current_line_nums[0], current_line_nums[-1]))
93 | current_lines = []
94 | current_line_nums = []
95 |
96 | current_lines.append(line)
97 | current_line_nums.append(cur_line_idx)
98 |
99 | # We should split if we're at the end of the document or if we've reached the token limit
100 | if (
101 | len(lines) - 1 == cur_line_idx
102 | or num_tokens_from_string("\n".join(current_lines)) > split_after_ntokens
103 | ):
104 | parts.append("\n".join(current_lines))
105 | line_nums.append((current_line_nums[0], current_line_nums[-1]))
106 | current_lines = []
107 | current_line_nums = []
108 |
109 | return parts, line_nums
110 |
111 |
112 | def _should_skip_python_file(file_path: str):
113 | with open(file_path, "r") as f:
114 | first_line = f.readline().strip("\n #").lower()
115 |
116 | if first_line.startswith("index:"):
117 | return first_line.split(":")[1].strip().lower() == "false"
118 |
119 | return False
120 |
121 |
122 | def get_es_docs(code_dir: str, force_subdir: Optional[str]):
123 | all_files = get_file_list(code_dir)
124 | docs = []
125 |
126 | if force_subdir is not None:
127 | force_subdir = os.path.abspath(os.path.join(code_dir, force_subdir))
128 | if force_subdir[-1] != "/":
129 | force_subdir += "/"
130 |
131 | all_files = [
132 | file_path
133 | for file_path in all_files
134 | if file_path.abs_path.startswith(force_subdir)
135 | ]
136 |
137 | python_files = filter_by_extension(all_files, valid_extensions=[".py"])
138 | for file_path in python_files:
139 | if _should_skip_python_file(file_path.abs_path):
140 | print(f"Skipping {file_path} as it has `index: false` in the first line.")
141 | continue
142 |
143 | try:
144 | code_block_ast = CodeBlockAST.from_file_path(file_path)
145 | except SyntaxError:
146 | print(f"Syntax error in {file_path}, skipping...")
147 | continue
148 |
149 | for block in code_block_ast.root.children():
150 | # noinspection PyTypeChecker
151 | docs.append(
152 | EsDocument(
153 | file_path=file_path.rel_path,
154 | type=block.block_type.name,
155 | lines=dict(
156 | gte=block.ast_node.lineno - 1, lt=block.ast_node.end_lineno
157 | ),
158 | text=block.code,
159 | prototype=get_class_or_function_prototype(
160 | block.ast_node, include_init=False
161 | ),
162 | )
163 | )
164 |
165 | doc_files = filter_by_extension(all_files, valid_extensions=[".md"])
166 | for file_path in doc_files:
167 | with open(file_path.abs_path, "r") as f:
168 | for part, line_nums in zip(
169 | *split_markdown_documentation_into_parts(f.read())
170 | ):
171 | docs.append(
172 | EsDocument(
173 | file_path=file_path.rel_path,
174 | type=CodeBlockType.DOCUMENTATION.name,
175 | lines=dict(gte=line_nums[0], lt=line_nums[1] + 1),
176 | text=part,
177 | )
178 | )
179 |
180 | return docs
181 |
182 |
183 | def build_index(
184 | code_dir: str,
185 | index_uid: str,
186 | delete_index: bool,
187 | host: str = DEFAULT_ES_HOST,
188 | force_subdir: Optional[str] = None,
189 | ):
190 | code_dir = os.path.abspath(code_dir)
191 | print(f"Indexing codebase at {code_dir} with index_uid {index_uid}")
192 |
193 | assert os.path.exists(code_dir), f"{code_dir} does not exist"
194 |
195 | docs = get_es_docs(code_dir, force_subdir=force_subdir)
196 | bulk_insert = [
197 | {
198 | "_op_type": "index",
199 | "_index": index_uid,
200 | "_id": es_doc_to_hash(doc),
201 | "_source": doc,
202 | }
203 | for doc in docs
204 | ]
205 |
206 | if delete_index:
207 | es = Elasticsearch(host)
208 | if es.indices.exists(index=index_uid):
209 | print("Deleting existing index...")
210 | es.indices.delete(index=index_uid)
211 |
212 | assert len(bulk_insert) > 0, f"No documents to index in {code_dir}."
213 |
214 | print(f"Indexing {len(bulk_insert)} documents...")
215 | es = Elasticsearch(host)
216 | create_empty_index(es=es, index_name=index_uid, embedding_dim=1536)
217 | bulk(es, bulk_insert)
218 |
219 | time.sleep(2) # to allow for indexing to finish
220 |
221 |
222 | if __name__ == "__main__":
223 | args = parse_args()
224 | build_index(
225 | code_dir=args.code_dir,
226 | index_uid=args.index_uid,
227 | delete_index=args.delete_index,
228 | host=args.host,
229 | force_subdir=args.force_subdir,
230 | )
231 |
--------------------------------------------------------------------------------
/codenav/retrieval/elasticsearch/install_elasticsearch.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import shutil
4 | import sys
5 | import urllib
6 | import urllib.request
7 | from pathlib import Path
8 |
9 | ABS_PATH_OF_ES_DIR = os.path.abspath(os.path.dirname(Path(__file__)))
10 |
11 | DOWNLOAD_DIR = os.path.join(os.path.expanduser("~/.cache/codenav/elasticsearch"))
12 |
13 | ES_VERSION = "8.12.1"
14 | KIBANA_VERSION = ES_VERSION
15 |
16 | ES_PATH = os.path.join(DOWNLOAD_DIR, f"elasticsearch-{ES_VERSION}")
17 | KIBANA_PATH = os.path.join(DOWNLOAD_DIR, f"kibana-{ES_VERSION}")
18 |
19 | PLATFORM_TO_ES_URL = {
20 | "linux-x86_64": f"https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-{ES_VERSION}-linux-x86_64.tar.gz",
21 | "darwin-aarch64": f"https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-{ES_VERSION}-darwin-aarch64.tar.gz",
22 | "darwin-x86_64": f"https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-{ES_VERSION}-darwin-x86_64.tar.gz",
23 | }
24 |
25 | PLATFORM_TO_KIBANA_URL = {
26 | "linux-x86_64": f"https://artifacts.elastic.co/downloads/kibana/kibana-{KIBANA_VERSION}-linux-x86_64.tar.gz",
27 | "darwin-aarch64": f"https://artifacts.elastic.co/downloads/kibana/kibana-{KIBANA_VERSION}-darwin-aarch64.tar.gz",
28 | "darwin-x86_64": f"https://artifacts.elastic.co/downloads/kibana/kibana-{KIBANA_VERSION}-darwin-x86_64.tar.gz",
29 | }
30 |
31 |
32 | def compute_sha512(file_path: str):
33 | sha512_hash = hashlib.sha512()
34 | with open(file_path, "rb") as f:
35 | # Read and update hash string value in blocks of 4K
36 | for byte_block in iter(lambda: f.read(4096), b""):
37 | sha512_hash.update(byte_block)
38 | return sha512_hash.hexdigest()
39 |
40 |
41 | def install_from_url(url: str) -> None:
42 | print(f"Downloading {url}")
43 |
44 | os.makedirs(DOWNLOAD_DIR, exist_ok=True)
45 |
46 | name = "-".join(url.split("/")[-1].split("-")[:2])
47 |
48 | if os.path.exists(os.path.join(DOWNLOAD_DIR, name)):
49 | print(f"{name} already exists. Skipping download...")
50 | return
51 |
52 | tar_path = os.path.join(DOWNLOAD_DIR, f"{name}.tar.gz")
53 | tar_hash_path = tar_path + ".sha512"
54 |
55 | urllib.request.urlretrieve(url, tar_path)
56 | urllib.request.urlretrieve(
57 | url + ".sha512",
58 | tar_hash_path,
59 | )
60 | print("Download complete")
61 |
62 | # Checking SHA512
63 | print("Checking SHA512")
64 | sha512 = compute_sha512(tar_path)
65 | with open(tar_hash_path, "r") as f:
66 | expected_sha512 = f.read().strip().split(" ")[0]
67 |
68 | if sha512 != expected_sha512:
69 | raise ValueError(f"SHA512 mismatch. Expected {expected_sha512}, got {sha512}")
70 |
71 | print(f"Extracting {tar_path} to {DOWNLOAD_DIR}")
72 | os.system(f"tar -xzf {tar_path} -C {DOWNLOAD_DIR}")
73 |
74 |     assert os.path.exists(
75 |         os.path.join(DOWNLOAD_DIR, name)
76 |     ), f"Extraction failed: {name} not found in {DOWNLOAD_DIR}"
75 |
76 | os.remove(tar_path)
77 | os.remove(tar_hash_path)
78 |
79 | print(f"{name} installation complete")
80 |
81 |
82 | def install_elasticsearch():
83 | if sys.platform == "darwin":
84 | if os.uname().machine == "arm64":
85 | platform = "darwin-aarch64"
86 | else:
87 | platform = "darwin-x86_64"
88 | else:
89 | assert sys.platform == "linux"
90 | platform = "linux-x86_64"
91 |
92 | install_from_url(PLATFORM_TO_ES_URL[platform])
93 |
94 | # Copy the elasticsearch config file to the elasticsearch directory
95 | es_config_file = os.path.join(ABS_PATH_OF_ES_DIR, "elasticsearch.yml")
96 | shutil.copy(es_config_file, os.path.join(ES_PATH, "config"))
97 |
98 | install_from_url(PLATFORM_TO_KIBANA_URL[platform])
99 |
100 |
101 | def is_es_installed():
102 | return os.path.exists(ES_PATH) and os.path.exists(KIBANA_PATH)
103 |
104 |
105 | if __name__ == "__main__":
106 | if not is_es_installed():
107 | install_elasticsearch()
108 | else:
109 | print(f"Elasticsearch already installed at {ES_PATH}")
110 |
--------------------------------------------------------------------------------
/codenav/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/utils/__init__.py
--------------------------------------------------------------------------------
/codenav/utils/config_params.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/codenav/2f200d9de2ee15cc82e232b8414765fe5f2fb617/codenav/utils/config_params.py
--------------------------------------------------------------------------------
/codenav/utils/eval_types.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Dict, Optional, Union
3 |
4 | import attrs
5 |
6 | from codenav.interaction.episode import Episode
7 |
8 |
9 | @attrs.define
10 | class EvalInput:
11 | uid: Union[str, int]
12 | query: str
13 | metadata: Optional[Any] = None
14 |
15 |
16 | Str2AnyDict = Dict[str, Any]
17 |
18 |
19 | class EvalSpec(ABC):
20 | def __init__(
21 | self,
22 | episode_kwargs: Str2AnyDict,
23 | interaction_kwargs: Str2AnyDict,
24 | logging_kwargs: Str2AnyDict,
25 | ):
26 | self.episode_kwargs = episode_kwargs
27 | self.interaction_kwargs = interaction_kwargs
28 | self.logging_kwargs = logging_kwargs
29 |
30 | @staticmethod
31 | @abstractmethod
32 | def build_episode(
33 | eval_input: EvalInput,
34 | episode_kwargs: Optional[Str2AnyDict] = None,
35 | ) -> Episode:
36 | pass
37 |
38 | @staticmethod
39 | @abstractmethod
40 | def run_interaction(
41 | episode: Episode,
42 | interaction_kwargs: Optional[Str2AnyDict] = None,
43 | ) -> Optional[Str2AnyDict]:
44 | pass
45 |
46 | @staticmethod
47 | @abstractmethod
48 | def log_output(
49 | interaction_output: Str2AnyDict,
50 | eval_input: EvalInput,
51 | logging_kwargs: Optional[Str2AnyDict] = None,
52 | ) -> Any:
53 | pass
54 |
--------------------------------------------------------------------------------
/codenav/utils/evaluator.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import multiprocessing as mp
3 | import sys
4 | import time
5 | import traceback
6 | from queue import Empty
7 | from typing import Any, Iterator, Sequence
8 |
9 | from codenav.utils.eval_types import EvalInput, EvalSpec
10 |
11 | if sys.platform.lower() == "darwin":
12 | mp = mp.get_context("spawn")
13 | else:
14 | mp = mp.get_context("forkserver")
15 |
16 |
17 | def _parallel_worker(task_queue: mp.Queue, result_queue: mp.Queue, eval_spec: EvalSpec):
18 | while True:
19 | try:
20 | eval_input: EvalInput = task_queue.get(timeout=1)
21 | print(f"Starting task: {eval_input.uid}")
22 | try:
23 | result = CodenavEvaluator.evaluate_input(
24 | eval_input=eval_input, eval_spec=copy.deepcopy(eval_spec)
25 | )
26 | except:
27 | result = ("failure", eval_input, traceback.format_exc())
28 | result_queue.put(result)
29 | raise
30 | result_queue.put(result)
31 | except Empty:
32 | break
33 |
34 |
35 | class CodenavEvaluator:
36 | def __init__(self, eval_spec: EvalSpec):
37 | self.eval_spec = eval_spec
38 |
39 | @staticmethod
40 | def evaluate_input(
41 | eval_input: EvalInput,
42 | eval_spec: EvalSpec,
43 | ) -> Any:
44 | episode = eval_spec.build_episode(
45 | eval_input=eval_input, episode_kwargs=eval_spec.episode_kwargs
46 | )
47 | interaction_output = eval_spec.run_interaction(
48 | episode=episode, interaction_kwargs=eval_spec.interaction_kwargs
49 | )
50 | assert interaction_output is not None
51 | return eval_spec.log_output(
52 | interaction_output=interaction_output,
53 | eval_input=eval_input,
54 | logging_kwargs=eval_spec.logging_kwargs,
55 | )
56 |
57 | def evaluate_in_sequence(self, inputs: Sequence[EvalInput]) -> Iterator[Any]:
58 | for input in inputs:
59 | yield CodenavEvaluator.evaluate_input(input, self.eval_spec)
60 |
61 | def evaluate_in_parallel(
62 | self, inputs: Sequence[EvalInput], n_procs: int
63 | ) -> Iterator[Any]:
64 | task_queue: mp.Queue[EvalInput] = mp.Queue()
65 | result_queue: mp.Queue[Any] = mp.Queue()
66 |
67 | for input in inputs:
68 | task_queue.put(input)
69 |
70 | procs = []
71 | for proc_idx in range(n_procs):
72 | p = mp.Process(
73 | target=_parallel_worker, args=(task_queue, result_queue, self.eval_spec)
74 | )
75 | p.start()
76 | procs.append(p)
77 |
78 | for _ in range(len(inputs)):
79 | yield result_queue.get()
80 |
81 | for proc in procs:
82 | proc.join(1)
83 |
84 | def evaluate(self, samples: Sequence[EvalInput], n_procs: int = 1) -> Iterator[Any]:
85 | start_time = time.time()
86 |
87 | if n_procs > 1:
88 | yield from self.evaluate_in_parallel(samples, n_procs)
89 | else:
90 | yield from self.evaluate_in_sequence(samples)
91 |
92 | print(f"Time taken in evaluation: {time.time() - start_time}")
93 |
--------------------------------------------------------------------------------
/codenav/utils/hashing_utils.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 |
3 |
4 | def md5_hash_str(to_hash: str) -> str:
5 | return hashlib.md5(to_hash.encode()).hexdigest()
6 |
7 |
8 | def md5_hash_file(to_hash_path: str):
9 | with open(to_hash_path, "r") as f:
10 | return md5_hash_str(f.read())
11 |
--------------------------------------------------------------------------------
/codenav/utils/linting_and_type_checking_utils.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import tempfile
3 | from typing import List, Optional, Tuple
4 |
5 | import black
6 | import mypy.api
7 |
8 |
9 | def black_format_code(code: str) -> Tuple[str, Optional[str]]:
10 | """Format code using black."""
11 | error_str: Optional[str] = None
12 | formatted_code: str = ""
13 | try:
14 | formatted_code = black.format_str(code, mode=black.FileMode())
15 | except Exception as error:
16 | formatted_code = code
17 | error_str = str(error)
18 | return formatted_code, error_str
19 |
20 |
21 | def run_mypy(code: str) -> Tuple[str, str, int]:
22 | """
23 | Run mypy (static type checker) on code.
24 |
25 | Returns:
26 | stdout (str): output of mypy, errors found etc
27 | stderr (str): errors generated by mypy itself
28 | exit_code (int): 0 means no type errors, non-zero implies type errors
29 | """
30 | return mypy.api.run(
31 | [
32 | "--ignore-missing-imports",
33 | "--no-namespace-packages",
34 |             "--follow-imports",
35 | "silent",
36 | "-c",
37 | code,
38 | ]
39 | )
40 |
41 |
42 | def run_flake8(code: str) -> Tuple[str, str]:
43 | """
44 | Run flake8 (linter) on code.
45 |
46 | Returns:
47 | stdout (str): output of flake8, errors found etc
48 | stderr (str): errors generated by flake8 itself
49 | """
50 | with tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=True) as temp_file:
51 | temp_file.write(code)
52 | temp_file.flush() # Make sure all data is flushed to disk
53 | result = subprocess.run(
54 | ["flake8", temp_file.name], capture_output=True, text=True
55 | )
56 |
57 | return result.stdout, result.stderr
58 |
59 |
60 | class CodeAnalysisError:
61 | """Represents a code analysis error."""
62 |
63 | def __init__(
64 | self, code_ref: str, line_num: int, error: str, prefix: str = "CodeAnalysis"
65 | ):
66 | self.code_ref = code_ref
67 | self.line_num = line_num
68 | self.error = error
69 | self.prefix = prefix
70 |
71 | def __repr__(self) -> str:
72 | return f'{self.prefix}Error(code_ref="{self.code_ref}", line_num={self.line_num}, error="{self.error}")'
73 |
74 |
75 | class TypeCheckingError(CodeAnalysisError):
76 | """Represents a type checking error."""
77 |
78 | def __init__(self, code_ref: str, line_num: int, error: str):
79 | super().__init__(code_ref, line_num, error, prefix="TypeChecking")
80 |
81 |
82 | class LintingError(CodeAnalysisError):
83 | """Represents a linting error."""
84 |
85 | def __init__(self, code_ref: str, line_num: int, error: str):
86 | super().__init__(code_ref, line_num, error, prefix="Linting")
87 |
88 |
89 | def parse_mypy_output(mypy_output: str, code: str) -> List[TypeCheckingError]:
90 | """Parse mypy output."""
91 | if mypy_output == "":
92 | return []
93 |
94 | code_lines = code.split("\n")
95 | parse = []
96 | # skip last two items after split because - the second last is a count of error &
97 | # the last is an empty string
98 | lines = mypy_output.split("\n")
99 | assert lines[-2].startswith("Found ") or lines[-2].startswith("Success")
100 | assert lines[-1] == ""
101 | for line in lines[:-2]:
102 | path, line_num, error_type, error = line.split(":", maxsplit=3)
103 | parse.append(
104 | TypeCheckingError(
105 | code_ref=code_lines[int(line_num) - 1],
106 | line_num=int(line_num),
107 | error=error.strip(),
108 | )
109 | )
110 |
111 | return parse
112 |
113 |
114 | def parse_flake8_output(flake8_output: str, code: str) -> List[LintingError]:
115 | """Parse flake8 output."""
116 |     if flake8_output == "":
117 | return []
118 |
119 | code_lines = code.split("\n")
120 | parse = []
121 | # skip last item after split because it's an empty string
122 | lines = flake8_output.split("\n")
123 | assert lines[-1] == ""
124 | for line in lines[:-1]:
125 | path, line_num, col, error = line.split(":", maxsplit=3)
126 | parse.append(
127 | LintingError(
128 | code_ref=code_lines[int(line_num) - 1],
129 | line_num=int(line_num),
130 | error=error.strip(),
131 | )
132 | )
133 |
134 | return parse
135 |
136 |
137 | def get_linting_errors(
138 | code: str, skip_codes: List[str] = ["E402", "E501"]
139 | ) -> List[LintingError]:
140 | """Get linting errors."""
141 | lint_errs = parse_flake8_output(run_flake8(code)[0], code)
142 | to_return: List[LintingError] = []
143 | for err in lint_errs:
144 | if not any([code in err.error for code in skip_codes]):
145 | to_return.append(err)
146 |
147 | return to_return
148 |
149 |
150 | def get_type_checking_errors(code: str) -> List[TypeCheckingError]:
151 | """Get type checking errors."""
152 | type_errs = parse_mypy_output(run_mypy(code)[0], code)
153 | to_return: List[TypeCheckingError] = []
154 | for err in type_errs:
155 | if "no-redef" not in err.error:
156 | to_return.append(err)
157 |
158 | return to_return
159 |
--------------------------------------------------------------------------------
/codenav/utils/llm_utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import math
4 | import time
5 | import traceback
6 | from typing import List, Literal, Optional, Sequence, Union, cast, Any, Dict
7 |
8 | import numpy as np
9 | import tiktoken
10 | import tqdm
11 | from PIL import Image
12 | from openai import InternalServerError, OpenAI, RateLimitError
13 | from openai.resources import Chat
14 | from openai.types.chat import (
15 | ChatCompletionAssistantMessageParam,
16 | ChatCompletionMessageParam,
17 | ChatCompletionSystemMessageParam,
18 | ChatCompletionUserMessageParam,
19 | )
20 |
21 | MODEL_STR_TO_PRICE_PER_1M_INPUT_TOKENS = {
22 | # Together models, input tokens cost same as output
23 | "mistralai/Mistral-7B-Instruct-v0.2": 0.2,
24 | "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.6,
25 | "mistralai/Mixtral-8x22B-Instruct-v0.1": 1.2,
26 | "Qwen/Qwen1.5-72B-Chat": 1.0,
27 | "Qwen/Qwen1.5-110B-Chat": 1.8,
28 | "meta-llama/Llama-3-70b-chat-hf": 0.9,
29 | # OpenAI models
30 | "gpt-3.5-turbo-0301": 1.5,
31 | "gpt-3.5-turbo-0125": 1.5,
32 | "gpt-4-1106-preview": 10.0,
33 | "gpt-4o-2024-05-13": 5.0,
34 | # Cohere
35 | "command-r": 0.5,
36 | "command-r-plus": 3.0,
37 | }
38 |
39 | MODEL_STR_TO_PRICE_PER_1M_OUTPUT_TOKENS = {
40 | # Together models, input tokens cost same as output
41 | "mistralai/Mistral-7B-Instruct-v0.2": 0.2,
42 | "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.6,
43 | "mistralai/Mixtral-8x22B-Instruct-v0.1": 1.2,
44 | "Qwen/Qwen1.5-72B-Chat": 1.0,
45 | "Qwen/Qwen1.5-110B-Chat": 1.8,
46 | "meta-llama/Llama-3-70b-chat-hf": 0.9,
47 | # OpenAI models
48 | "gpt-3.5-turbo-0301": 2.0,
49 | "gpt-3.5-turbo-0125": 2.0,
50 | "gpt-4o-2024-05-13": 15.0,
51 | "gpt-4-1106-preview": 30.0,
52 | # Cohere
53 | "command-r": 1.5,
54 | "command-r-plus": 15.0,
55 | }
56 |
57 |
58 | class MaxTokensExceededError(Exception):
59 | pass
60 |
61 |
62 | def num_tokens_from_messages(
63 | messages,
64 | skip_images: bool,
65 | model="gpt-3.5-turbo-0301",
66 | ):
67 | """Returns the number of tokens used by a list of messages."""
68 | assert skip_images, "skip_images=False is not presently supported"
69 |
70 | try:
71 | encoding = tiktoken.encoding_for_model(model)
72 | except KeyError:
73 | encoding = tiktoken.get_encoding("cl100k_base")
74 |
75 | if model in [
76 | "gpt-3.5-turbo-0301",
77 | "gpt-3.5-turbo-0125",
78 | "gpt-4-1106-preview",
79 | "gpt-4-0125-preview",
80 | ]: # note: future models may deviate from this
81 | num_tokens = 0
82 | for message in messages:
83 | num_tokens += (
84 | 4 # every message follows {role/name}\n{content}\n
85 | )
86 | for key, value in message.items():
87 | if key == "content":
88 | if isinstance(value, str):
89 | num_tokens += len(encoding.encode(value, disallowed_special=[]))
90 | elif isinstance(value, List):
91 | for piece in value:
92 | if piece["type"] == "text":
93 | num_tokens += len(
94 | encoding.encode(
95 | piece["text"], disallowed_special=[]
96 | )
97 | )
98 | else:
99 | assert skip_images
100 | else:
101 | raise NotImplementedError
102 | elif key == "name": # if there's a name, the role is omitted
103 | num_tokens += -1 # role is always required and always 1 token
104 | else:
105 | num_tokens += len(encoding.encode(value, disallowed_special=[]))
106 | num_tokens += 2 # every reply is primed with assistant
107 | return num_tokens
108 | else:
109 | raise NotImplementedError(
110 | f"""num_tokens_from_messages() is not presently implemented for model {model}.
111 | See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
112 | )
113 |
114 |
115 | def image_height_width_from_path(image_path: str):
116 | with Image.open(image_path) as img:
117 | # Load only image metadata (not pixel data)
118 | img.load()
119 |
120 | # Get dimensions
121 | width, height = img.size
122 |
123 | return height, width
124 |
125 |
126 | def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
127 | """Returns the number of tokens in a text string."""
128 | import tiktoken
129 |
130 | encoding = tiktoken.get_encoding(encoding_name)
131 | num_tokens = len(encoding.encode(string, disallowed_special=[]))
132 | return num_tokens
133 |
134 |
135 | def compute_token_count_for_image(image: Union[str, np.ndarray]):
136 | # From https://platform.openai.com/docs/guides/vision
137 | if isinstance(image, str):
138 | h, w = image_height_width_from_path(image)
139 | else:
140 | h, w, _ = image.shape
141 |
142 | # First rescaled to be within 2048x2048
143 | scale = 2048 / max([2048, h, w])
144 | h = h * scale
145 | w = w * scale
146 |
147 | # Then rescaled so shortest edge is 768
148 | h, w = 768 * h / min(h, w), 768 * w / min(h, w)
149 |
150 | return math.ceil(h / 512) * math.ceil(w / 512) * 170 + 85
151 |
152 |
153 | def partition_sequence(seq: Sequence, parts: int) -> List:
154 | assert 0 < parts, f"parts [{parts}] must be greater > 0"
155 | assert parts <= len(seq), f"parts [{parts}] > len(seq) [{len(seq)}]"
156 | n = len(seq)
157 |
158 | quotient = n // parts
159 | remainder = n % parts
160 | counts = [quotient + (i < remainder) for i in range(parts)]
161 | inds = np.cumsum([0] + counts)
162 | return [seq[ind0:ind1] for ind0, ind1 in zip(inds[:-1], inds[1:])]
163 |
164 |
165 | def encode_image(image_path: str):
166 | with open(image_path, "rb") as image_file:
167 | return base64.b64encode(image_file.read()).decode("utf-8")
168 |
169 |
170 | class NumpyJSONEncoder(json.JSONEncoder):
171 | """JSON encoder for numpy objects.
172 |
173 | Based off the stackoverflow answer by Jie Yang here: https://stackoverflow.com/a/57915246.
174 | The license for this code is [BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/).
175 | """
176 |
177 | def default(self, obj):
178 | if isinstance(obj, np.void):
179 | return None
180 | elif isinstance(obj, np.bool_):
181 | return bool(obj)
182 | elif isinstance(obj, np.integer):
183 | return int(obj)
184 | elif isinstance(obj, np.floating):
185 | return float(obj)
186 | elif isinstance(obj, np.ndarray):
187 | return obj.tolist()
188 | else:
189 | return super(NumpyJSONEncoder, self).default(obj)
190 |
191 |
192 | def compute_llm_cost(input_tokens: int, output_tokens: int, model: str):
193 | assert (
194 | model in MODEL_STR_TO_PRICE_PER_1M_INPUT_TOKENS
195 | and model in MODEL_STR_TO_PRICE_PER_1M_OUTPUT_TOKENS
196 | ), f"model [{model}] must be in both MODEL_STR_TO_PRICE_PER_1M_INPUT_TOKENS and MODEL_STR_TO_PRICE_PER_1M_OUTPUT_TOKENS"
197 |
198 | input_token_cost_per_1m = MODEL_STR_TO_PRICE_PER_1M_INPUT_TOKENS[model]
199 | output_token_cost_per_1m = MODEL_STR_TO_PRICE_PER_1M_OUTPUT_TOKENS[model]
200 |
201 | return (
202 | input_tokens * input_token_cost_per_1m
203 | + output_tokens * output_token_cost_per_1m
204 | ) / 1e6
205 |
206 |
207 | def compute_cost_for_queries_and_responses(
208 | queries_and_responses: Sequence[Dict[str, Union[str, List[Dict[str, Any]]]]],
209 | model: str,
210 | encoding_name: Optional[str] = "cl100k_base",
211 | ):
212 | input_tokens = 0
213 | output_tokens = 0
214 | for query_and_response in queries_and_responses:
215 | if "input_tokens" in query_and_response:
216 | input_tokens += query_and_response["input_tokens"]
217 | else:
218 | input_tokens += num_tokens_from_messages(
219 | query_and_response["input"], model=model, skip_images=True
220 | )
221 |
222 | if "output_tokens" in query_and_response:
223 | output_tokens += query_and_response["output_tokens"]
224 | else:
225 | output_str = query_and_response["output"]
226 | output_tokens += len(
227 | tiktoken.get_encoding(encoding_name).encode(
228 | output_str, disallowed_special=[]
229 | )
230 | )
231 |
232 | return (
233 | compute_llm_cost(input_tokens, output_tokens, model=model),
234 | input_tokens,
235 | output_tokens,
236 | )
237 |
238 |
239 | def create_openai_message(
240 | text: str, role: Literal["user", "system", "assistant"] = "user"
241 | ) -> Union[
242 | ChatCompletionSystemMessageParam,
243 | ChatCompletionUserMessageParam,
244 | ChatCompletionAssistantMessageParam,
245 | ]:
246 | return { # type: ignore
247 | "role": role,
248 | "content": [
249 | {
250 | "type": "text",
251 | "text": text,
252 | },
253 | ],
254 | }
255 |
256 |
257 | def query_gpt(
258 | messages: Sequence[ChatCompletionMessageParam],
259 | model: str,
260 | client: OpenAI,
261 | pbar: Optional[tqdm.tqdm] = None,
262 | sec_wait_between_retries: float = 10,
263 | max_tokens: int = 3000,
264 | return_input_output_tokens: bool = False,
265 | ) -> Optional[Union[str, Dict[str, Union[str, int]]]]:
266 | """Query the OpenAI API with the given messages."""
267 | num_tokens = num_tokens_from_messages(messages, skip_images=True)
268 | if num_tokens > max_tokens:
269 | raise MaxTokensExceededError(
270 | f"num_tokens [{num_tokens}] > max_tokens [{max_tokens}]"
271 | )
272 |
273 | if pbar:
274 | pbar.write(f"Num tokens: {num_tokens}")
275 |
276 | response = None
277 | for retry in range(10):
278 | try:
279 | response = cast(Chat, client.chat).completions.create(
280 | model=model,
281 | messages=cast(List[ChatCompletionMessageParam], messages),
282 | max_tokens=3000, # Max number of output tokens
283 | temperature=0.0,
284 | )
285 | break
286 | except RateLimitError:
287 | if pbar:
288 | pbar.write(
289 | f"Rate limit error, waiting {sec_wait_between_retries} seconds..."
290 | )
291 | except InternalServerError:
292 | if pbar:
293 | pbar.write(
294 | f"Internal server error, waiting {sec_wait_between_retries} seconds..."
295 | )
296 | except:
297 | m = traceback.format_exc().lower()
298 | if "ratelimit" in m or "rate limit" in m:
299 | if pbar:
300 | pbar.write(
301 | f"Rate limit error, waiting {sec_wait_between_retries} seconds..."
302 | )
303 | else:
304 | raise
305 |
306 | if response is not None:
307 | break
308 |
309 | if retry >= 9:
310 | if pbar:
311 | pbar.write(f"Hit max retries, raising exception.")
312 | raise RuntimeError("Hit max retries")
313 |
314 | if pbar:
315 | pbar.write(
316 | f"Retry {retry} failed, sleeping for {sec_wait_between_retries} seconds"
317 | )
318 |
319 | time.sleep(sec_wait_between_retries)
320 |
321 | if response is None:
322 | return None
323 |
324 | output = response.choices[0].message.content
325 |
326 | if return_input_output_tokens:
327 | return {
328 | "output": output,
329 | "input_tokens": response.usage.prompt_tokens,
330 | "output_tokens": response.usage.completion_tokens,
331 | }
332 | else:
333 | return output
334 |
--------------------------------------------------------------------------------
/codenav/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import multiprocessing as mp
4 | import os
5 | import signal
6 | import tempfile
7 | from queue import Empty
8 | from typing import Dict, Any, Set
9 |
10 | import attrs
11 | import wandb
12 |
13 |
14 | @attrs.define
15 | class StringAsFileArtifact:
16 | string: str
17 | file_name: str
18 | artifact_info: Dict[str, Any]
19 |
20 |
21 | class WandbClient:
22 | def __init__(self, queue: mp.Queue, client_index: int):
23 | self.queue = queue
24 | self.client_index = client_index
25 |
26 | self.Table = wandb.Table
27 | self.Image = wandb.Image
28 | self.Video = wandb.Video
29 | self.Audio = wandb.Audio
30 | self.Html = wandb.Html
31 |
32 | self._started = False
33 | self._closed = False
34 |
35 | def _check_is_start_and_notclosed(self):
36 | if not self._started:
37 |             raise ValueError("WandbClient has not been started yet. Call .init() first.")
38 |
39 | if self._closed:
40 | raise ValueError("WandbClient is closed.")
41 |
42 | def init(self):
43 | if self._started:
44 | raise ValueError("WandbClient is already started, cannot init again.")
45 |
46 | if self._closed:
47 | raise ValueError("WandbClient is closed, cannot init.")
48 |
49 | self._started = True
50 |
51 | def log(self, data: Dict[str, Any]):
52 | self._check_is_start_and_notclosed()
53 | self.queue.put((self.client_index, False, data))
54 |
55 | def log_artifact(self, artifact):
56 | self._check_is_start_and_notclosed()
57 | self.queue.put((self.client_index, False, artifact))
58 |
59 | def log_string_as_file_artifact(self, saa: StringAsFileArtifact):
60 | self._check_is_start_and_notclosed()
61 | self.queue.put((self.client_index, False, saa))
62 |
63 | def close(self):
64 | if not self._closed:
65 | self.queue.put((self.client_index, True, None))
66 | self._closed = True
67 |
68 | def __del__(self):
69 | if self._started and not self._closed:
70 | self.close()
71 |
72 |
73 | class WandbServer:
74 | def __init__(self, queue: mp.Queue, **wandb_kwargs):
75 | self.queue = queue
76 | self._num_clients_created = 0
77 | self._open_clients: Set[int] = set()
78 |
79 | wandb.init(**wandb_kwargs)
80 |
81 | def finish(self):
82 | assert not self.any_open_clients()
83 | wandb.finish()
84 |
85 | def create_client(self):
86 | wc = WandbClient(self.queue, client_index=self._num_clients_created)
87 | self._num_clients_created += 1
88 | self._open_clients.add(wc.client_index)
89 | return wc
90 |
91 | def any_open_clients(self):
92 | return len(self._open_clients) > 0
93 |
94 | @property
95 | def num_closed_clients(self):
96 | return self._num_clients_created - len(self._open_clients)
97 |
98 | def log(self, timeout: int, verbose: bool = False):
99 | logged_data = []
100 | while True:
101 | try:
102 | client_ind, closing, data = self.queue.get(timeout=timeout)
103 |
104 | assert client_ind in self._open_clients
105 |
106 | if closing:
107 | if verbose:
108 | print(f"Closing client {client_ind}")
109 |
110 | self._open_clients.remove(client_ind)
111 | continue
112 |
113 | if verbose:
114 | print(f"Logging [from client {client_ind}]: {data}")
115 |
116 | if isinstance(data, wandb.Artifact):
117 | wandb.log_artifact(data)
118 | elif isinstance(data, StringAsFileArtifact):
119 | td = tempfile.TemporaryDirectory()
120 | with td as temp_dir:
121 | with open(
122 | os.path.join(temp_dir, data.file_name), "w"
123 | ) as temp_file:
124 | temp_file.write(data.string)
125 |
126 | artifact = wandb.Artifact(**data.artifact_info)
127 | artifact.add_file(os.path.join(temp_dir, data.file_name))
128 |
129 | wandb.log_artifact(artifact)
130 | else:
131 | try:
132 | wandb.log(data)
133 | except TypeError:
134 | if isinstance(data, dict):
135 | new_data = {}
136 | for k, v in data.items():
137 | try:
138 | json.dumps(v)
139 | new_data[k] = v
140 | except:
141 | new_data[k] = str(v)
142 |
143 | wandb.log(new_data)
144 | else:
145 | raise
146 |
147 | logged_data.append(data)
148 | except Empty:
149 | break
150 |
151 | return logged_data
152 |
153 |
154 | class DelayedKeyboardInterrupt:
155 | def __enter__(self):
156 | self.signal_received = False
157 | self.old_handler = signal.signal(signal.SIGINT, self.handler)
158 |
159 | def handler(self, sig, frame):
160 | self.signal_received = (sig, frame)
161 | logging.debug(
162 | "SIGINT received. Delaying KeyboardInterrupt as critical code is running."
163 | )
164 |
165 | def __exit__(self, type, value, traceback):
166 | signal.signal(signal.SIGINT, self.old_handler)
167 | if self.signal_received:
168 | self.old_handler(*self.signal_received)
169 |
--------------------------------------------------------------------------------
/codenav/utils/omegaconf_utils.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import copy
3 | import json
4 | import os.path
5 | import sys
6 | from typing import Optional, Sequence
7 |
8 | from omegaconf import OmegaConf, DictConfig, MissingMandatoryValue
9 | from omegaconf.errors import InterpolationKeyError
10 |
11 |
12 | def recursive_merge(base: DictConfig, other: DictConfig) -> DictConfig:
13 | base = copy.deepcopy(base)
14 |
15 | if all(key not in base for key in other):
16 | return OmegaConf.merge(base, other)
17 |
18 | for key in other:
19 | try:
20 | value = other[key]
21 | except InterpolationKeyError:
22 | value = other._get_node(key)._value()
23 | except MissingMandatoryValue:
24 | value = other._get_node(key)._value()
25 |
26 | if (
27 | key in base
28 | and isinstance(base[key], DictConfig)
29 | and isinstance(value, DictConfig)
30 | ):
31 | base[key] = recursive_merge(base[key], value)
32 | else:
33 | base[key] = value
34 |
35 | return base
36 |
37 |
38 | def load_config_from_file(config_path: str, at_key: Optional[str] = None) -> DictConfig:
39 | config_path = os.path.abspath(config_path)
40 | if not os.path.exists(config_path):
41 | raise FileNotFoundError(f"Config file not found: {config_path}")
42 |
43 | base_dir = os.path.dirname(os.path.abspath(config_path))
44 |
45 | main_conf = OmegaConf.load(config_path)
46 |
47 | defaults_paths = main_conf.get("_defaults_", [])
48 | if "_self_" not in defaults_paths:
49 | defaults_paths.insert(0, "_self_")
50 |
51 | main_conf.pop("_defaults_", None)
52 |
53 | config = OmegaConf.create()
54 | for p in defaults_paths:
55 | if p == "_self_":
56 | config = recursive_merge(config, main_conf)
57 | else:
58 | key = None
59 | if "@" in p:
60 | p, key = p.split("@")
61 |
62 | if not os.path.isabs(p):
63 | p = os.path.join(base_dir, p)
64 |
65 | other_conf = load_config_from_file(p, at_key=key)
66 |
67 | config = recursive_merge(config, other_conf)
68 |
69 | if at_key is not None:
70 | top = OmegaConf.create()
71 | cur = top
72 | sub_keys = at_key.split(".")
73 | for sub_key in sub_keys[:-1]:
74 | cur[sub_key] = OmegaConf.create()
75 | cur = cur[sub_key]
76 | cur[sub_keys[-1]] = config
77 | config = top
78 |
79 | return config
80 |
81 |
82 | def parse_omegaconf(
83 | base_config_path: Optional[str],
84 | args: Optional[Sequence[str]] = None,
85 | ) -> DictConfig:
86 | """
87 | Parses and merges configuration files and command line arguments using OmegaConf.
88 |
89 | This function loads a base configuration file (if provided), processes command line
90 | arguments to override configuration values or load additional files, and returns
91 | the final merged configuration.
92 |
93 | Basics:
94 | OmegaConf is a flexible and powerful configuration management library for Python.
95 | It supports hierarchical configurations, interpolation, and merging of configuration files.
96 | Key features of OmegaConf include:
97 | - Hierarchical Configuration: Allows configurations to be organized in nested structures.
98 | - Interpolation: Supports references to other values within the configuration.
99 | - Merging: Combines multiple configurations, with later configurations overriding earlier ones.
100 |
101 | YAML file overriding with `_defaults_`:
102 | You can use the `_defaults_` field in a YAML file to specify a list
103 | of other YAML files that should be loaded and merged. The `_defaults_` field allows
104 | you to control the order in which configurations are applied, with later files in
105 | the list overriding earlier ones.
106 |
107 | Example YAML files:
108 |
109 | ```yaml
110 | # logging.yaml
111 | log_level: info
112 | ```
113 |
114 | ```yaml
115 | # database.yaml
116 | database:
117 | host: localhost
118 | port: 5432
119 | ```
120 |
121 | ```yaml
122 | # base.yaml
123 | _defaults_:
124 | - logging.yaml@logging
125 | - database.yaml
126 |
127 | database:
128 | port: 3306
129 | ```
130 |
131 | Running `parse_omegaconf("base.yaml")` will result in the following configuration:
132 | ```yaml
133 | logging:
134 | log_level: info
135 | database:
136 | host: localhost
137 | port: 5432
138 | ```
139 |
140 | How overriding works:
141 | 1. When loading the main configuration file, `base.yaml`, the configurations listed in `_defaults_` are loaded
142 | in the order specified.
143 | 2. The main configuration file (the one containing `_defaults_`) IS LOADED FIRST BY DEFAULT.
144 | 3. If there are conflicting keys, the values in the later files override the earlier ones.
145 | 4. If a default yaml file is post-fixed with `@key`, the configuration will be placed at the specified key. Notice
146 | that logging.yaml@logging results in log_level: info being placed under the logging key in the final configuration.
147 | 5. By default, the fields in the main configuration file ARE OVERWRITTEN by the fields in the files listed in `_defaults_`,
148 | since the main file is merged first. If you'd prefer a different order, you can add `_self_` explicitly to the list of defaults. E.g.
149 | ```yaml
150 | _defaults_:
151 | - logging.yaml@logging
152 | - _self_
153 | - database.yaml
154 | ```
155 | will result in database.yaml overriding fields in base.yaml, producing the merged configuration:
156 | ```yaml
157 | logging:
158 | log_level: info
159 | database:
160 | host: localhost
161 | port: 5432
162 | ```
163 |
164 | Command line overrides:
165 | Command line arguments can be used to override configuration values. Overrides can
166 | be specified directly or by using the `_file_` keyword to load additional configuration files.
167 |
168 | Direct overrides:
169 | ```bash
170 | python script.py database.port=1234
171 | ```
172 |
173 | `_file_` override:
174 | The `_file_` override allows you to specify an additional YAML file to be merged
175 | into the existing configuration:
176 | ```bash
177 | python script.py _file_=extra_config.yaml
178 | ```
179 |
180 | `_file_` with key specification:
181 | To load a file and merge it at a specific key:
182 | ```bash
183 | python script.py _file_=extra_config.yaml@some.key
184 | ```
185 |
186 | Comparing `parse_omegaconf` with the `hydra` library:
187 | Hydra is a popular configuration management library built on top of OmegaConf that provides a large
188 | collection of additional features for managing complex configurations. Hydra is designed to be a
189 | comprehensive solution for configuration management, including support for running and managing
190 | multiple job executions. `parse_omegaconf` is a much simpler function that focuses solely on configuration
191 | loading and merging, without the additional features provided by Hydra. We wrote this function as
192 | a lightweight alternative which is less opinionated about how configurations should be structured, loaded,
193 | and used.
194 |
195 | Args:
196 | base_config_path (Optional[str]): The path to the base configuration file.
197 | If None, an empty configuration will be created.
198 | args (Optional[Sequence[str]]): A list of command line arguments to override
199 | configuration values. If None, `sys.argv[1:]` will be used.
200 |
201 | Returns:
202 | omegaconf.DictConfig: The final merged configuration.
203 |
204 | Raises:
205 | FileNotFoundError: If the specified base configuration file does not exist.
206 | ValueError: If a command line argument is not formatted as `key=value`.
207 | KeyError: If attempting to override a non-existent key without using the `+` prefix.
208 | """
209 | if base_config_path is not None:
210 | config = load_config_from_file(base_config_path)
211 | else:
212 | config = OmegaConf.create()
213 |
214 | if args is None:
215 | args = sys.argv[1:]
216 |
217 | while len(args) > 0:
218 | arg = args.pop(0)
219 |
220 | try:
221 | key, value = arg.split("=", maxsplit=1)
222 | except ValueError:
223 | raise ValueError(f"Invalid argument: {arg}. Must be formatted as key=value")
224 |
225 | if arg.startswith("_file_="):
226 | sub_key = None
227 | if "@" in value:
228 | value, sub_key = value.split("@")
229 | config = recursive_merge(
230 | config, load_config_from_file(value, at_key=sub_key)
231 | )
232 | else:
233 | try:
234 | value = ast.literal_eval(value)
235 | except (SyntaxError, ValueError):
236 | try:
237 | value = json.loads(value)
238 | except json.JSONDecodeError:
239 | pass
240 |
241 | if key.startswith("+"):
242 | key = key[1:]
243 | else:
244 | key_parts = key.split(".")
245 | sub_config = config
246 | err_msg = (
247 | f"Cannot override value for key {key} in config to {value} using a command line argument"
248 | f" as the key {key} does not already exist in the config. If you'd like to to add a"
249 | f" key that does not already exist, please use the format +key=value rather than just key=value."
250 | )
251 | for key_part in key_parts[:-1]:
252 | if key_part not in sub_config:
253 | raise KeyError(err_msg)
254 | sub_config = sub_config[key_part]
255 |
256 | if sub_config._get_node(key_parts[-1]) is None:
257 | raise KeyError(err_msg)
258 |
259 | OmegaConf.update(config, key, value)
260 |
261 | return config
262 |
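263 | # Illustrative usage sketch: the lines below are a hypothetical example, not part of the
264 | # original module. They assume the base.yaml / logging.yaml / database.yaml files from the
265 | # docstring above exist in the working directory; `database.port=1234` overrides an existing
266 | # key, while the `+` prefix adds a key that is not yet present in the config.
267 | if __name__ == "__main__":
268 |     cfg = parse_omegaconf(
269 |         "base.yaml",
270 |         args=["database.port=1234", "+run.name=local_debug"],
271 |     )
272 |     print(OmegaConf.to_yaml(cfg))  # database.port -> 1234, run.name -> "local_debug"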
--------------------------------------------------------------------------------
/codenav/utils/parsing_utils.py:
--------------------------------------------------------------------------------
1 | import ast
2 | from _ast import AST
3 | from typing import Union, cast
4 |
5 |
6 | def get_class_or_function_prototype(
7 | code: Union[str, ast.ClassDef, ast.FunctionDef],
8 | include_init: bool = True,
9 | ) -> str:
10 | """
11 | Summarizes the given Python class or function definition code.
12 |
13 | For classes, this is done by keeping the class name and its __init__ method signature,
14 | replacing the body of __init__ with an ellipsis (...).
15 |
16 | For functions, this is done by keeping the function name and its signature, replacing the body with an ellipsis (...).
17 |
18 | Args:
19 | - code: A string containing the class or function definition, or an already-parsed ast.ClassDef/ast.FunctionDef node.
20 |
21 | Returns:
22 | - A summary string of the form "class ClassName(BaseClass): def __init__(self, arg1: Type, arg2: Type, ...): ..."
23 | for a class definition, or "def function_name(arg1: Type, arg2: Type, ...) -> Type: ..." for a function definition.
24 | """
25 | if isinstance(code, str):
26 | tree = ast.parse(code)
27 | func_or_class = [
28 | node
29 | for node in ast.iter_child_nodes(tree)
30 | if isinstance(node, (ast.ClassDef, ast.FunctionDef))
31 | ]
32 | assert (
33 | len(func_or_class) == 1
34 | ), "The given code should contain exactly one class or function definition."
35 | node = func_or_class[0]
36 | else:
37 | node = code
38 |
39 | summary = ""
40 | if isinstance(node, ast.ClassDef):
41 | # Format class definition
42 | base_classes = [ast.unparse(base) for base in node.bases]
43 | class_header = f"class {node.name}({', '.join(base_classes)}):"
44 | summary += class_header
45 |
46 | if include_init:
47 | for item in node.body:
48 | if isinstance(item, ast.FunctionDef) and item.name == "__init__":
49 | # Format __init__ method signature
50 | init_signature = ast.unparse(item.args)
51 | summary += f" def __init__({init_signature}): ..."
52 | break
53 | else:
54 | summary += " ..."
55 | elif isinstance(node, ast.FunctionDef):
56 | # Format function definition
57 | function_signature = ast.unparse(node.args)
58 |
59 | try:
60 | return_suff = f" -> {ast.unparse(cast(AST, node.returns))}"
61 | except Exception:
62 | return_suff = ""
63 |
64 | summary += f"def {node.name}({function_signature}){return_suff}: ..."
65 |
66 | return summary
67 |
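68 | # Illustrative usage sketch: the snippet below uses hypothetical source strings, not code from
69 | # the repository, to demonstrate the two output forms described in the docstring above.
70 | if __name__ == "__main__":
71 |     class_src = "class Foo(Bar):\n    def __init__(self, x: int, y: str):\n        self.x = x"
72 |     print(get_class_or_function_prototype(class_src))
73 |     # -> class Foo(Bar): def __init__(self, x: int, y: str): ...
74 |     fn_src = "def add(a: int, b: int) -> int:\n    return a + b"
75 |     print(get_class_or_function_prototype(fn_src))
76 |     # -> def add(a: int, b: int) -> int: ...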
--------------------------------------------------------------------------------
/codenav/utils/prompt_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from collections import defaultdict
4 | from typing import Any, Dict, List, Sequence, Tuple
5 |
6 | import attrs
7 |
8 | from codenav.constants import PROMPTS_DIR
9 |
10 |
11 | def extract_placeholders(s):
12 | pattern = r"(? str:
48 | found_path = None
49 | for prompt_dir in self.prompt_dirs:
50 | prompt_path = os.path.join(prompt_dir, rel_path)
51 | if os.path.exists(prompt_path):
52 | found_path = prompt_path
53 | break
54 |
55 | if found_path is None:
56 | raise FileNotFoundError(
57 | f"Prompt file with relative path: '{rel_path}' not found "
58 | f"in any of the prompt directories: {self.prompt_dirs}"
59 | )
60 |
61 | with open(found_path, "r") as f:
62 | prompt_str = f.read()
63 | placeholders = extract_placeholders(prompt_str)
64 | for pl in placeholders:
65 | self.placeholder_to_paths[pl].append(found_path)
66 |
67 | return prompt_str
68 |
69 | def get_action_prompts(self) -> Sequence[str]:
70 | return [
71 | self.get_prompt(getattr(self.action_prompts, action))
72 | for action in self.actions_to_enable
73 | ]
74 |
75 | def get_response_prompts(self) -> Sequence[str]:
76 | return [
77 | self.get_prompt(getattr(self.response_prompts, action))
78 | for action in self.actions_to_enable
79 | ]
80 |
81 | def build_template(self) -> Tuple[str, Dict[str, List[str]]]:
82 | self.placeholder_to_paths = defaultdict(list)
83 | prompts = [
84 | self.get_prompt(self.overview),
85 | self.get_prompt(self.workflow),
86 | self.get_prompt(self.action_prompts.preamble),
87 | *self.get_action_prompts(),
88 | self.get_prompt(self.action_prompts.guidelines),
89 | self.get_prompt(self.response_prompts.preamble),
90 | *self.get_response_prompts(),
91 | self.get_prompt(self.repo_description),
92 | ]
93 |
94 | return "\n\n".join(prompts), dict(self.placeholder_to_paths)
95 |
96 | def build(self, placeholder_values: Dict[str, Any]) -> str:
97 | return self.build_template()[0].format(**placeholder_values)
98 |
--------------------------------------------------------------------------------
/codenav/utils/string_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Optional
3 |
4 |
5 | def get_tag_content_from_text(
6 | text: str,
7 | tag: str,
8 | ) -> Optional[str]:
9 | pattern = f"<{tag}>" + r"\s*(.*?)\s*" + f"{tag}>"
10 | match = re.search(pattern, text, re.DOTALL)
11 | content = match.group(1) if match else None
12 | if content == "":
13 | return None
14 | return content
15 |
16 |
17 | def str2bool(s: str):
18 | s = s.lower().strip()
19 | if s in ["yes", "true", "t", "y", "1"]:
20 | return True
21 | elif s in ["no", "false", "f", "n", "0"]:
22 | return False
23 | else:
24 | raise NotImplementedError
25 |
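26 | # Illustrative usage sketch with hypothetical inputs (not from the repository):
27 | if __name__ == "__main__":
28 |     print(get_tag_content_from_text("<code>x = 1</code>", "code"))  # -> "x = 1"
29 |     print(get_tag_content_from_text("no tags here", "code"))  # -> None
30 |     print(str2bool("Yes"))  # -> True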
--------------------------------------------------------------------------------
/codenav_examples/create_code_env.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from codenav.environments.code_env import PythonCodeEnv
4 |
5 | project_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
6 | print(f"Project directory: {project_dir}")
7 | env = PythonCodeEnv(
8 | sys_paths=[project_dir],
9 | working_dir=os.path.join(project_dir, "playground"),
10 | enable_type_checking=True,
11 | )
12 |
13 | exec_output1 = env.step("from codenav.interaction.messages import ACTION_TYPES")
14 | print(exec_output1.format(include_code=True, display_updated_vars=True))
15 |
16 | exec_output2 = env.step("print(ACTION_TYPES)")
17 | print(exec_output2.format(include_code=True, display_updated_vars=True))
18 |
--------------------------------------------------------------------------------
/codenav_examples/create_episode.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from codenav.agents.gpt4.agent import OpenAICodeNavAgent
4 | from codenav.constants import ABS_PATH_OF_CODENAV_DIR, DEFAULT_OPENAI_MODEL
5 | from codenav.environments.code_env import PythonCodeEnv
6 | from codenav.environments.done_env import DoneEnv
7 | from codenav.environments.retrieval_env import EsCodeRetriever, RetrievalEnv
8 | from codenav.interaction.episode import Episode
9 | from codenav.retrieval.elasticsearch.elasticsearch_constants import RESERVED_CHARACTERS
10 | from codenav.retrieval.elasticsearch.index_codebase import DEFAULT_ES_HOST
11 | from codenav.utils.prompt_utils import PromptBuilder
12 |
13 | ALLOWED_ACTIONS = ["done", "code", "search"]
14 | CODE_DIR = ABS_PATH_OF_CODENAV_DIR
15 | PARENT_DIR = os.path.dirname(CODE_DIR)
16 |
17 | # create prompt
18 | prompt_builder = PromptBuilder(repo_description="codenav/repo_description.txt")
19 | prompt = prompt_builder.build(
20 | dict(
21 | AVAILABLE_ACTIONS=ALLOWED_ACTIONS,
22 | RESERVED_CHARACTERS=RESERVED_CHARACTERS,
23 | RETRIEVALS_PER_KEYWORD=3,
24 | )
25 | )
26 |
27 | # create environments
28 | code_env = PythonCodeEnv(
29 | code_dir=CODE_DIR,
30 | sys_paths=[PARENT_DIR],
31 | working_dir=os.path.join(PARENT_DIR, "playground"),
32 | )
33 |
34 | retrieval_env = RetrievalEnv(
35 | code_retriever=EsCodeRetriever(
36 | index_name="codenav",
37 | host=DEFAULT_ES_HOST,
38 | ),
39 | expansions_per_query=3,
40 | prototypes_per_query=7,
41 | )
42 |
43 | done_env = DoneEnv()
44 |
45 |
46 | # create agent using prompt
47 | agent = OpenAICodeNavAgent(
48 | prompt=prompt,
49 | model=DEFAULT_OPENAI_MODEL,
50 | allowed_action_types=ALLOWED_ACTIONS,
51 | )
52 |
53 | # create episode using the agent and environments
54 | episode = Episode(
55 | agent,
56 | action_type_to_env=dict(
57 | code=code_env,
58 | search=retrieval_env,
59 | done=done_env,
60 | ),
61 | user_query_str="Find the DoneEnv and instantiate it",
62 | )
63 |
64 | episode.step_until_max_steps_or_success(max_steps=5)
65 |
--------------------------------------------------------------------------------
/codenav_examples/create_index.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import pandas as pd
5 | from elasticsearch import Elasticsearch
6 |
7 | from codenav.environments.retrieval_env import RetrievalEnv
8 | from codenav.retrieval.elasticsearch.elasticsearch_retriever import EsCodeRetriever
9 | from codenav.retrieval.elasticsearch.index_codebase import DEFAULT_ES_HOST, build_index
10 |
11 | print(f"Looking for a running Elasticsearch server at {DEFAULT_ES_HOST}...")
12 | es = Elasticsearch(DEFAULT_ES_HOST)
13 | if es.ping():
14 | print(f"Elasticsearch server is running at {DEFAULT_ES_HOST}")
15 | else:
16 | print(
17 | f"Elasticsearch server not found at {DEFAULT_ES_HOST}\n"
18 | "\tStart the server before running this script\n"
19 | "\tTo start the Elasticsearch server, run `condenav init` or `python -m condenav.codenav_run init`"
20 | )
21 | sys.exit(1)
22 |
23 |
24 | CODE_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
25 | SUBDIR = "codenav"
26 | print(f"Building index from subdir='{SUBDIR}' of code_dir='{CODE_DIR}' ...")
27 | build_index(
28 | code_dir=CODE_DIR,
29 | force_subdir=SUBDIR,
30 | delete_index=True,
31 | index_uid="codenav",
32 | host=DEFAULT_ES_HOST,
33 | )
34 |
35 | print("Creating EsCodeRetriever which can search the index...")
36 | es_code_retriever = EsCodeRetriever(index_name="codenav", host=DEFAULT_ES_HOST)
37 |
38 | print("Searching the ES index using `prototype: CodeEnv`...")
39 | search_query = "prototype: CodeSummaryEnv"
40 | raw_search_res = es_code_retriever.search(search_query)
41 |
42 | raw_search_res_df = pd.DataFrame.from_records(raw_search_res)
43 | print(raw_search_res_df)
44 |
45 | print("Creating retrieval environment that adds state logic...")
46 | env = RetrievalEnv(
47 | code_retriever=es_code_retriever,
48 | expansions_per_query=3,
49 | prototypes_per_query=5,
50 | summarize_code=False,
51 | )
52 | response = env.step(search_query)
53 | print(response.format())
54 |
--------------------------------------------------------------------------------
/codenav_examples/create_prompt.py:
--------------------------------------------------------------------------------
1 | from codenav.agents.gpt4.agent import DEFAULT_OPENAI_MODEL, OpenAICodeNavAgent
2 | from codenav.retrieval.elasticsearch.elasticsearch_constants import RESERVED_CHARACTERS
3 | from codenav.utils.prompt_utils import PromptBuilder
4 |
5 | # Prompt builder puts together a prompt template from text files
6 | # The template may contain placeholders for values
7 | prompt_builder = PromptBuilder(repo_description="codenav/repo_description.txt")
8 | prompt_template, placeholder_to_paths = prompt_builder.build_template()
9 |
10 | # see placeholders and the file paths they appear in
11 | print("Placeholders in template:\n", placeholder_to_paths)
12 |
13 | # provide values for these placeholders
14 | ALLOWED_ACTIONS = ["done", "code", "search"]
15 | placeholder_values = dict(
16 | AVAILABLE_ACTIONS=ALLOWED_ACTIONS,
17 | RESERVED_CHARACTERS=RESERVED_CHARACTERS,
18 | RETRIEVALS_PER_KEYWORD=3,
19 | )
20 | print("Provided values:\n", placeholder_values)
21 |
22 | # build prompt using values
23 | # the following is equivalent to prompt_template.format(**placeholder_values)
24 | prompt = prompt_builder.build(placeholder_values)
25 |
26 | # create agent using prompt
27 | agent = OpenAICodeNavAgent(
28 | prompt=prompt,
29 | model=DEFAULT_OPENAI_MODEL,
30 | allowed_action_types=ALLOWED_ACTIONS,
31 | )
32 |
--------------------------------------------------------------------------------
/codenav_examples/parallel_evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, List, Optional
3 |
4 | from codenav.agents.gpt4.agent import OpenAICodeNavAgent
5 | from codenav.constants import ABS_PATH_OF_CODENAV_DIR, DEFAULT_OPENAI_MODEL
6 | from codenav.environments.code_env import PythonCodeEnv
7 | from codenav.environments.done_env import DoneEnv
8 | from codenav.environments.retrieval_env import EsCodeRetriever, RetrievalEnv
9 | from codenav.interaction.episode import Episode
10 | from codenav.retrieval.elasticsearch.elasticsearch_constants import RESERVED_CHARACTERS
11 | from codenav.retrieval.elasticsearch.index_codebase import DEFAULT_ES_HOST
12 | from codenav.utils.eval_types import EvalInput, EvalSpec, Str2AnyDict
13 | from codenav.utils.evaluator import CodenavEvaluator
14 | from codenav.utils.prompt_utils import PromptBuilder
15 |
16 |
17 | # EvalSpec defines the components for running an evaluation
18 | # EvalSpec is then used by the CodenavEvaluator to run CodeNav on inputs using 1 or more processes
19 | # EvalSpec requires defining 3 methods: build_episode, run_interaction, log_output
20 | class CodenavEvalSpec(EvalSpec):
21 | def __init__(
22 | self,
23 | episode_kwargs: Str2AnyDict,
24 | interaction_kwargs: Str2AnyDict,
25 | logging_kwargs: Str2AnyDict,
26 | ):
27 | super().__init__(episode_kwargs, interaction_kwargs, logging_kwargs)
28 |
29 | @staticmethod
30 | def build_episode(
31 | eval_input: EvalInput,
32 | episode_kwargs: Optional[Str2AnyDict] = None,
33 | ) -> Episode:
34 | assert episode_kwargs is not None
35 | prompt_builder = PromptBuilder(
36 | repo_description=episode_kwargs["repo_description"]
37 | )
38 | prompt = prompt_builder.build(
39 | dict(
40 | AVAILABLE_ACTIONS=episode_kwargs["allowed_actions"],
41 | RESERVED_CHARACTERS=RESERVED_CHARACTERS,
42 | RETRIEVALS_PER_KEYWORD=episode_kwargs["retrievals_per_keyword"],
43 | )
44 | )
45 |
46 | return Episode(
47 | agent=OpenAICodeNavAgent(
48 | prompt=prompt,
49 | model=episode_kwargs["llm"],
50 | allowed_action_types=episode_kwargs["allowed_actions"],
51 | ),
52 | action_type_to_env=dict(
53 | code=PythonCodeEnv(
54 | code_dir=episode_kwargs["code_dir"],
55 | sys_paths=episode_kwargs["sys_paths"],
56 | working_dir=episode_kwargs["working_dir"],
57 | ),
58 | search=RetrievalEnv(
59 | code_retriever=EsCodeRetriever(
60 | index_name=episode_kwargs["index_name"],
61 | host=episode_kwargs["host"],
62 | ),
63 | expansions_per_query=episode_kwargs["retrievals_per_keyword"],
64 | prototypes_per_query=episode_kwargs["prototypes_per_keyword"],
65 | ),
66 | done=DoneEnv(),
67 | ),
68 | user_query_str=eval_input.query,
69 | )
70 |
71 | @staticmethod
72 | def run_interaction(
73 | episode: Episode,
74 | interaction_kwargs: Optional[Str2AnyDict] = None,
75 | ) -> Str2AnyDict:
76 | assert interaction_kwargs is not None
77 | episode.step_until_max_steps_or_success(
78 | max_steps=interaction_kwargs["max_steps"],
79 | verbose=interaction_kwargs["verbose"],
80 | )
81 | ipynb_str = episode.to_notebook(cur_dir=episode.code_env.working_dir)
82 | return dict(ipynb_str=ipynb_str)
83 |
84 | @staticmethod
85 | def log_output(
86 | interaction_output: Str2AnyDict,
87 | eval_input: EvalInput,
88 | logging_kwargs: Optional[Str2AnyDict] = None,
89 | ) -> Any:
90 | assert logging_kwargs is not None
91 |
92 | outfile = os.path.join(logging_kwargs["out_dir"], f"{eval_input.uid}.ipynb")
93 | with open(outfile, "w") as f:
94 | f.write(interaction_output["ipynb_str"])
95 |
96 | return outfile
97 |
98 |
99 | def run_parallel_evaluation(
100 | eval_inputs: List[EvalInput],
101 | episode_kwargs: Str2AnyDict,
102 | interaction_kwargs: Str2AnyDict,
103 | logging_kwargs: Str2AnyDict,
104 | num_processes: int = 2,
105 | ):
106 | # create an instance of the CodenavEvaluator using the eval spec
107 | evaluator = CodenavEvaluator(
108 | eval_spec=CodenavEvalSpec(
109 | episode_kwargs=episode_kwargs,
110 | interaction_kwargs=interaction_kwargs,
111 | logging_kwargs=logging_kwargs,
112 | )
113 | )
114 |
115 | # Get outputs from the output queue
116 | num_inputs = len(eval_inputs)
117 | for i, output in enumerate(evaluator.evaluate(eval_inputs, n_procs=num_processes)):
118 | print(
119 | f"Evaluated {i+1}/{num_inputs} | Input uid: {eval_inputs[i].uid} | Output saved to ",
120 | output,
121 | )
122 |
123 |
124 | if __name__ == "__main__":
125 | episode_kwargs = dict(
126 | allowed_actions=["done", "code", "search"],
127 | repo_description="codenav/repo_description.txt",
128 | retrievals_per_keyword=3,
129 | prototypes_per_keyword=7,
130 | llm=DEFAULT_OPENAI_MODEL,
131 | code_dir=ABS_PATH_OF_CODENAV_DIR,
132 | sys_paths=[os.path.dirname(ABS_PATH_OF_CODENAV_DIR)],
133 | working_dir=os.path.join(
134 | os.path.dirname(ABS_PATH_OF_CODENAV_DIR), "playground"
135 | ),
136 | index_name="codenav",
137 | host=DEFAULT_ES_HOST,
138 | )
139 | interaction_kwargs = dict(max_steps=10, verbose=True)
140 | logging_kwargs = dict(out_dir=os.path.join(os.path.dirname(ABS_PATH_OF_CODENAV_DIR), "playground"))
141 |
142 | # Define the inputs to evaluate using EvalInput
143 | # Each EvalInput instance consists of a unique id (uid), a query, and optionally any metadata
144 | eval_inputs = [
145 | EvalInput(uid=1, query="Find the DoneEnv and instantiate it"),
146 | EvalInput(
147 | uid=2,
148 | query="Build the prompt template using PromptBuilder and print all the placeholders",
149 | ),
150 | ]
151 | run_parallel_evaluation(
152 | eval_inputs,
153 | episode_kwargs,
154 | interaction_kwargs,
155 | logging_kwargs,
156 | num_processes=2,
157 | )
158 |
--------------------------------------------------------------------------------
/playground/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | extend-exclude = '''
3 | (
4 | eval_codebases/hf_transformers
5 | |eval_codebases/DAMO-ConvAI
6 | |.venv
7 | |external_src
8 | )
9 | '''
10 |
11 | [tool.mypy]
12 | exclude = [
13 | "^eval_codebases/llama2-webui.*",
14 | "^eval_codebases/hf_transformers.*",
15 | "^eval_codebases/cvxpy.*",
16 | "^eval_codebases/flask.*",
17 | "^eval_codebases/mnm.*",
18 | "^eval_codebases/sympy.*",
19 | "^eval_codebases/scikit-learn.*",
20 | "^eval_codebases/src.*",
21 | "^external_src.*",
22 | "tmp.py"
23 | ]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | openai
2 | together
3 | cohere
4 | tiktoken
5 | black==23.12.1
6 | mypy
7 | flake8
8 | pillow
9 | numpy<2.0.0
10 | pandas
11 | wandb
12 | matplotlib
13 | nbformat
14 | notebook
15 | jupyterlab
16 | elasticsearch
17 | compress_json
18 | omegaconf
19 | hydra-core
20 | attrs
21 | tqdm
22 |
--------------------------------------------------------------------------------
/scripts/release.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from pathlib import Path
4 | from subprocess import getoutput
5 |
6 | PKG_NAME = "codenav"
7 |
8 |
9 | def make_package(verbose=False):
10 | """Prepares sdist for codenav."""
11 |
12 | orig_dir = os.getcwd()
13 |
14 | base_dir = os.path.dirname(os.path.abspath(os.path.dirname(Path(__file__))))
15 | os.chdir(base_dir)
16 |
17 | with open(".VERSION", "r") as f:
18 | __version__ = f.readline().strip()
19 |
20 | # generate sdist via setuptools
21 | output = getoutput(f"{sys.executable} setup.py sdist")
22 | if verbose:
23 | print(output)
24 |
25 | os.chdir(os.path.join(base_dir, "dist"))
26 |
27 | # uncompress the tar.gz sdist
28 | output = getoutput(f"tar zxvf {PKG_NAME}-{__version__}.tar.gz")
29 | if verbose:
30 | print(output)
31 |
32 | # create new source file with version
33 | getoutput(
34 | f"printf '__version__ = \"{__version__}\"\n' >> {PKG_NAME}-{__version__}/{PKG_NAME}/_version.py"
35 | )
36 | # include it in sources
37 | getoutput(
38 | f'printf "\ncodenav/_version.py" >> {PKG_NAME}-{__version__}/{PKG_NAME}.egg-info/SOURCES.txt'
39 | )
40 |
41 | # recompress tar.gz
42 | output = getoutput(
43 | f"tar zcvf {PKG_NAME}-{__version__}.tar.gz {PKG_NAME}-{__version__}/"
44 | )
45 | if verbose:
46 | print(output)
47 |
48 | # remove temporary directory
49 | output = getoutput(f"rm -r {PKG_NAME}-{__version__}")
50 | if verbose:
51 | print(output)
52 |
53 | os.chdir(orig_dir)
54 |
55 |
56 | if __name__ == "__main__":
57 | make_package(verbose=False)
58 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | from setuptools import find_packages, setup
5 |
6 |
7 | def parse_req_file(fname, initial=None):
8 | """Reads requires.txt file generated by setuptools and outputs a
9 | new/updated dict of extras as keys and corresponding lists of dependencies
10 | as values.
11 |
12 | The input file's contents are similar to a `ConfigParser` file, e.g.
13 | pkg_1
14 | pkg_2
15 | pkg_3
16 |
17 | [extras1]
18 | pkg_4
19 | pkg_5
20 |
21 | [extras2]
22 | pkg_6
23 | pkg_7
24 | """
25 | reqs = {} if initial is None else initial
26 | cline = None
27 | with open(fname, "r") as f:
28 | for line in f.readlines():
29 | line = line[:-1].strip()
30 | if len(line) == 0:
31 | continue
32 | if line[0] == "[":
33 | # Add new key for current extras (if missing in dict)
34 | cline = line[1:-1].strip()
35 | if cline not in reqs:
36 | reqs[cline] = []
37 | else:
38 | # Only keep dependencies from extras
39 | if cline is not None:
40 | reqs[cline].append(line)
41 | return reqs
42 |
43 |
44 | def get_version(fname):
45 | """Reads PKG-INFO file generated by setuptools and extracts the Version
46 | number."""
47 | res = "UNK"
48 | with open(fname, "r") as f:
49 | for line in f.readlines():
50 | line = line[:-1]
51 | if line.startswith("Version:"):
52 | res = line.replace("Version:", "").strip()
53 | break
54 | if res in ["UNK", ""]:
55 | raise ValueError(f"Missing Version number in {fname}")
56 | return res
57 |
58 |
59 | def read_requirements(filename: str):
60 | with open(filename) as requirements_file:
61 | import re
62 |
63 | def fix_url_dependencies(req: str) -> str:
64 | """Pip and setuptools disagree about how URL dependencies should be handled."""
65 | m = re.match(
66 | r"^(git\+)?(https|ssh)://(git@)?github\.com/([\w-]+)/(?P[\w-]+)\.git",
67 | req,
68 | )
69 | if m is None:
70 | return req
71 | else:
72 | return f"{m.group('name')} @ {req}"
73 |
74 | requirements = []
75 | for line in requirements_file:
76 | line = line.strip()
77 | if line.startswith("#") or line.startswith("-e") or len(line) <= 0:
78 | continue
79 | requirements.append(fix_url_dependencies(line))
80 | return requirements
81 |
82 |
83 | def _do_setup():
84 | base_dir = os.path.abspath(os.path.dirname(Path(__file__)))
85 |
86 | if not os.path.exists(
87 | os.path.join(base_dir, "codenav.egg-info/dependency_links.txt")
88 | ):
89 | # Build mode for sdist
90 | os.chdir(base_dir)
91 |
92 | version_path = os.path.abspath(".VERSION")
93 | print(version_path)
94 | with open(version_path, "r") as f:
95 | __version__ = f.readline().strip()
96 |
97 | # Extra dependencies for development (actually unnecessary)
98 | extras = {}
99 | if os.path.exists("dev_requirements.txt"):
100 | extras = {
101 | "dev": [
102 | l.strip()
103 | for l in open("dev_requirements.txt", "r").readlines()
104 | if l.strip() != ""
105 | ]
106 | }
107 | else:
108 | # Install mode from sdist
109 | __version__ = get_version(os.path.join(base_dir, "codenav.egg-info/PKG-INFO"))
110 | extras = parse_req_file(os.path.join(base_dir, "codenav.egg-info/requires.txt"))
111 |
112 | setup(
113 | name="codenav",
114 | version=__version__,
115 | description=(
116 | "CodeNav is a LLM-powered agent that can answer queries about code."
117 | ),
118 | long_description=open("README.md").read(),
119 | long_description_content_type="text/markdown",
120 | classifiers=[
121 | "Intended Audience :: Science/Research",
122 | "Development Status :: 3 - Alpha",
123 | "License :: OSI Approved :: Apache Software License",
124 | "Programming Language :: Python :: 3",
125 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
126 | ],
127 | keywords=["code understanding", "LLM", "large language models"],
128 | url="https://github.com/allenai/codenav",
129 | author="Allen Institute for Artificial Intelligence",
130 | author_email="lucaw@allenai.org",
131 | packages=find_packages(
132 | include=["codenav", "codenav.*"],
133 | exclude=["*.tests", "*.tests.*", "tests.*", "tests"],
134 | ),
135 | license="Apache 2.0",
136 | package_data={
137 | "codenav": ["prompts/default/*"],
138 | },
139 | install_requires=read_requirements("requirements.txt"),
140 | python_requires=">=3.9",
141 | entry_points={"console_scripts": ["codenav=codenav.codenav_run:main"]},
142 | # scripts=["codenav/codenav_run.py"],
143 | # setup_requires=["pytest-runner"],
144 | # tests_require=["pytest", "pytest-cov"],
145 | extras_require=extras,
146 | )
147 |
148 |
149 | if __name__ == "__main__":
150 | _do_setup()
151 |
--------------------------------------------------------------------------------