├── .bumpversion.cfg ├── .github └── workflows │ └── test.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── assets ├── programmer-demo-1080.gif └── programmer-ui.png ├── dev-notebooks └── BadPatchExperiment.ipynb ├── docs ├── design │ ├── file_line_tools.md │ ├── git_integration_rework.md │ ├── settings_management.md │ └── weave_query_ref_expansion.md └── dev │ └── devlog-shawn.md ├── programmer-ui └── ui.py ├── programmer ├── __init__.py ├── agent.py ├── agent_texteditor.py ├── agent_texteditor_o1.py ├── config.py ├── console.py ├── containerserver │ ├── README.md │ ├── checkserver.py │ └── cmserver.py ├── environment.py ├── evals │ └── eval_repeated_edits.py ├── evaluate.py ├── file_protocol.py ├── frog.jpg ├── git.py ├── io_context.py ├── programmer.py ├── settings_manager.py ├── swebench │ ├── README.md │ ├── __init__.py │ ├── data │ │ ├── ensembled_annotations_public.csv │ │ ├── samples_with_3_annotations_public.csv │ │ └── swebench-verified.parquet │ ├── evaluate.py │ ├── ingest │ │ ├── README.md │ │ ├── ingest_eval.py │ │ ├── make_dataset.py │ │ └── requirements.txt │ ├── run_instance.py │ ├── score.py │ ├── scripts │ │ ├── example_v_models.py │ │ └── verified_difficulty_labels.py │ └── swebench_model.py ├── tests │ ├── conftest.py │ ├── test_file_line_tools.py │ ├── test_git_integration.py │ ├── test_settings_manager.py │ ├── test_text_editor.py │ ├── test_tool_calling.py │ └── test_weave_query.py ├── text_editor.py ├── tool_calling.py ├── tools.py └── weave_next │ ├── api.py │ └── weave_query.py ├── pyproject.toml ├── pytest.ini ├── release.sh └── requirements-dev.txt /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.1.9 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:pyproject.toml] 7 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: 
-------------------------------------------------------------------------------- 1 | name: Run Unit Tests and Pyright 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | 13 | strategy: 14 | matrix: 15 | python-version: ["3.10"] 16 | 17 | steps: 18 | - name: Check out code 19 | uses: actions/checkout@v2 20 | 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install -r requirements-dev.txt 30 | pip install -e . 31 | 32 | - name: Run tests 33 | run: pytest 34 | 35 | - name: Run Pyright 36 | run: pyright 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | weave.db 2 | __pycache__ 3 | *-checkpoint.ipynb 4 | dist/ 5 | programmer.egg-info 6 | 7 | # PyPI uploads 8 | .pypirc 9 | .programmer 10 | .python-version 11 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 0.1.9 4 | 5 | Full Changelog: [v0.1.8...v0.1.9](https://github.com/wandb/programmer/compare/v0.1.8...v0.1.9) 6 | 7 | ### Features 8 | 9 | - Support o1 based agents 10 | - llm playground in "programmer ui" 11 | - "TextEditor" tools for agents 12 | - lots of eval work: swebench, eval_repeated_edits micro eval 13 | 14 | ## 0.1.8 15 | 16 | Full Changelog: [v0.1.7...v0.1.8](https://github.com/wandb/programmer/compare/v0.1.7...v0.1.8) 17 | 18 | ### Chores 19 | 20 | - fix build for "programmer ui" 21 | 22 | ## 0.1.7 23 | 24 | Full Changelog: [v0.1.6...v0.1.7](https://github.com/wandb/programmer/compare/v0.1.6...v0.1.7) 25 | 26 | ### Chores 27 | 28 | - 
fix build to include sub-packages 29 | 30 | 31 | ## 0.1.6 32 | 33 | Full Changelog: [v0.1.5...v0.1.6](https://github.com/wandb/programmer/compare/v0.1.5...v0.1.6) 34 | 35 | ### Features 36 | 37 | - weave cloud logging 38 | - git state tracking 39 | - settings management 40 | - programmer ui 41 | 42 | ## 0.1.5 43 | 44 | Initial working release -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # programmer 2 | 3 | programmer is a command-line based AI programmer, that will help you get stuff done. 4 | 5 | **WARNING**: programmer has direct access to your machine, it can run commands, and read and write files, without safety checks. You have been warned! 6 | 7 | ![Demo](./assets/programmer-demo-1080.gif) 8 | 9 | ## Quickstart 10 | 11 | ``` 12 | pip install programmer 13 | programmer 14 | ``` 15 | 16 | ## Switching to o1-based agents 17 | 18 | Programmer works with OpenAI's just released [o1 models](https://openai.com/index/introducing-openai-o1-preview/). 19 | 20 | **GIANT WARNING**: Remember, programmer directly runs agent commands on your machine, without prompting you first. The o1 models are brand new and should not be trusted to do this! 
You have been **GIANT WARNING** warned. 21 | 22 | ``` 23 | programmer settings set agent o1-preview-2024-09-12_o1harness 24 | # OR 25 | programmer settings set agent o1-mini-2024-09-12_o1harness 26 | ``` 27 | 28 | The o1 agents currently don't work very well, and might do dangerous things. But they do work! 29 | 30 | ## Examples 31 | 32 | - "What processes are listening on port 4512?" ... "ok, please kill them" 33 | - "What's in frog.jpg?" 34 | - "Write a function to determine if a tic-tac-toe game is won in a file called tictactoe.py. also write unit tests, and iterate until they pass." 35 | - "Fix all the type errors in this project" 36 | 37 | 38 | ## Usage 39 | 40 | Just 41 | 42 | ``` 43 | programmer 44 | ``` 45 | 46 | Alternatively: 47 | ``` 48 | programmer prompt 49 | ``` 50 | 51 | To resume from an earlier state: 52 | ``` 53 | programmer --state 54 | ``` 55 | 56 | ## Tracking 57 | 58 | Programmer is designed to get better over time. For that we need to track trajectories, identify good and bad ones to add to Evaluations (like unit tests for AI), and then iterate on programmer's prompts and architecture to improve against the Evaluations. 59 | 60 | By default all trajectories are logged to `.programmer/weave.db`. You can turn on cloud logging with `programmer settings set weave_logging cloud`. Trajectories will be saved to Weave at wandb.ai 61 | 62 | You can turn on git tracking with `programmer settings set git_tracking on` to get programmer to track all of its work in "programmer-*" branches. Each git state will be associated with the Weave trajectories, and you can browse the diffs with `programmer ui` 63 | 64 | ## UI 65 | 66 | Run 67 | 68 | ``` 69 | programmer ui 70 | ``` 71 | 72 | to run the local streamlit UI. This should work with either weave_logging:cloud or weave_logging:local, but there are some bugs with local mode at the moment. 
73 | 74 | ![Programmer UI screenshot](./assets/programmer-ui.png) 75 | 76 | ## Weave UI 77 | 78 | When weave_logging is set to "cloud" you can use the Weave UI at wandb.ai to browse traces. 79 | 80 | ## Settings 81 | 82 | Settings are stored in .programmer/settings 83 | 84 | programmer settings set weave_logging 85 | - off: no logging 86 | - local: log to local sqlite db 87 | - cloud: log to weave cloud at wandb.ai 88 | 89 | programmer settings set git_tracking 90 | - off: no git tracking 91 | - on: programmer will make programmer-* branches and track changes 92 | 93 | ## Improving programmer 94 | 95 | programmer is designed to be improved using [weave](https://wandb.me/weave), our toolkit for AI application development. What does this mean? 96 | 97 | - you can browse traces and evals in the Weave UI at https://wandb.ai 98 | - programmer can resume from earlier states, with the --state argument 99 | - programmer will log all of your interactions to a local sqlite database, or the central Weave service. 100 | - This data can be used to improve programmer over time, by building Evaluations, fine-tuning, and other techniques. 
101 | 102 | To run the evaluation: 103 | 104 | ``` 105 | python evaluate.py 106 | ``` 107 | 108 | ## roadmap 109 | 110 | - [x] weave server tracking 111 | - [x] git state tracking 112 | - [x] basic trajectory UI 113 | - [ ] user-annotation of good and bad behaviors 114 | - [ ] eval generation -------------------------------------------------------------------------------- /assets/programmer-demo-1080.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/programmer/9b25bc7d4d4e4f4d685737de66147c248d1159d8/assets/programmer-demo-1080.gif -------------------------------------------------------------------------------- /assets/programmer-ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/programmer/9b25bc7d4d4e4f4d685737de66147c248d1159d8/assets/programmer-ui.png -------------------------------------------------------------------------------- /docs/design/file_line_tools.md: -------------------------------------------------------------------------------- 1 | Author: Shawn 2 | 3 | We need to build some new tool functions for programmer to do finer grained edits to files. 4 | 5 | I want to add two functions to programmer/tools.py: 6 | 7 | Line numbers are 1-indexed 8 | 9 | *read_lines_from_file(file_path, start_line)* 10 | 11 | this should read up to 500 lines from the file, starting at the given line number. 12 | 13 | lines should be prefixed with ":" 14 | 15 | return value is a string 16 | 17 | *replace_lines_in_file(file_path, start_line, end_line, new_lines)* 18 | 19 | for the file at file_path, replace the lines from start_line (inclusive) to end_line (exclusive) with the new_lines. 20 | 21 | new_lines is a string, not a list of lines. 
22 | 23 | returns the modified region in the same format as read_lines_from_file, with a 5-line buffer on either side of the modified region 24 | 25 | if the file does not exist, this can be used to create a new file by setting start_line to 1 and end_line to 1. 26 | 27 | invalid line ranges should cause an exception. 28 | 29 | 30 | ## additional requirements 31 | 32 | These functions need to have python typing types, doc strings in the correct format, and throw actionable Exception messages, as they will be used by an LLM. 33 | 34 | And we need comprehensive unit tests that check all the edge cases -------------------------------------------------------------------------------- /docs/design/git_integration_rework.md: -------------------------------------------------------------------------------- 1 | # Design Document: Git Integration Rework for Programmer 2 | 3 | ## Overview 4 | This document outlines the proposed changes to the Git integration feature of the `programmer` tool. The goal is to allow real-time edits in the user's working state while maintaining a separate commit history in a `programmer-` branch, without affecting the user's visible working state. 5 | 6 | ## Objectives 7 | - Allow users to see changes in real-time using tools like VSCode's Git view. 8 | - Maintain a separate commit history in the background in a `programmer-` branch. 9 | - Avoid using stash or switching the HEAD, ensuring the user's working state remains unchanged. 10 | 11 | ## Proposed Solution 12 | To achieve the above objectives, we propose the following solution: 13 | 14 | ### Session Branch Management 15 | 1. **Initialization**: At the start of a session, initialize a `programmer-` branch based on the current state of the user's branch. 16 | 17 | 2. **Change Tracking**: Monitor file changes in the working directory to reflect them in the session branch, keeping the user's current branch and working directory unchanged. 18 | 19 | 3. 
**Commit History**: Maintain a separate commit history in the `programmer-` branch to allow for session management and review without interfering with the user's workflow. 20 | 21 | ## Benefits 22 | - Users can continue using their preferred development tools and see live changes. 23 | - Maintains a clean separation of session history, enabling better session management and review. 24 | 25 | ## Challenges 26 | - Requires careful handling of Git's internal mechanisms to ensure seamless integration. 27 | 28 | ## Conclusion 29 | By leveraging a separate session branch, we can achieve seamless integration of the `programmer` tool with the user's workflow, providing real-time feedback and maintaining a comprehensive session history without disrupting the user's development environment. -------------------------------------------------------------------------------- /docs/design/settings_management.md: -------------------------------------------------------------------------------- 1 | # Settings Management Feature Design 2 | 3 | Author: programmer 4 | 5 | ## Introduction 6 | 7 | This document outlines the design for the settings management feature to be implemented in the 'programmer' project. The feature will allow users to manage settings related to weave logging and git tracking. These settings should persist across sessions and be stored in the user’s current directory. 8 | 9 | ## Feature Overview 10 | 11 | The settings management feature will provide the following functionalities: 12 | 13 | 1. **Weave Logging Control**: Users can control the state of weave logging with three options: 14 | - Off 15 | - Local 16 | - Cloud 17 | 18 | 2. **Git Tracking Control**: Users can control the state of git tracking with two options: 19 | - Off 20 | - On 21 | 22 | ## Requirements 23 | 24 | - The settings should persist across sessions. 25 | - The settings should be stored in a file located in the user’s current directory. 
26 | - The feature should provide an easy interface for users to change settings. 27 | 28 | ## Design Details 29 | 30 | ### Settings Storage 31 | 32 | - The settings will be stored in a directory named `.programmer` in the user's current directory. 33 | - Within this directory, settings will be saved in a file named `settings`. 34 | - The file will use a simple key-value format for storing settings: 35 | 36 | ``` 37 | weave_logging=off 38 | git_tracking=on 39 | ``` 40 | 41 | ### Interface 42 | 43 | - A command-line interface will be provided to change settings. Users will be able to run commands such as: 44 | 45 | ``` 46 | programmer settings set weave_logging local 47 | programmer settings get weave_logging 48 | ``` 49 | 50 | ### Implementation Steps 51 | 52 | 1. **Create Settings Directory and File Structure**: Define the structure and location of the settings file within the `.programmer` directory. 53 | 2. **Implement CLI for Settings Management**: Develop commands to get and set the settings. 54 | 3. **Persist Settings Across Sessions**: Ensure that changes to settings are saved to the file and reloaded when the application starts. 55 | 56 | ## Conclusion 57 | 58 | This design document provides a comprehensive overview of the settings management feature for the 'programmer' project. By following this design, we aim to implement a robust settings management system that allows users to control weave logging and git tracking effectively. -------------------------------------------------------------------------------- /docs/design/weave_query_ref_expansion.md: -------------------------------------------------------------------------------- 1 | Author: Shawn 2 | 3 | We're working on programmer, an interactive command line tool that you can chat with to write programs 4 | 5 | programmer uses the weave library to trace trajectories. 6 | 7 | weave's core concept is the decorator `weave.op`, which saves Call records into a database. 
Call records include inputs, output, and other metadata. 8 | 9 | ``` 10 | @weave.op 11 | def add2(a, b): 12 | return a + b 13 | 14 | client = weave.init_trace_local() 15 | add2(2, 3) 16 | add2(5, 6) 17 | 18 | for call in client.calls(): 19 | ... 20 | ``` 21 | 22 | Weave also has a concept called objects. To use a weave Object, create a class that inherits from weave.Object, and add annotated type attributes. Objects inherit from pydantic.BaseModel. When a weave.Object descendent is encountered as an input or output of an op call, weave publishes the Object as a top-level record, and stores a ref uri (in the form of weave:///) to the Object record in the Call record. 23 | 24 | Here is an Object example: 25 | 26 | ``` 27 | class Point(weave.Object): 28 | x: float 29 | y: float 30 | 31 | @weave.op 32 | def add_points(p1: Point, p2: Point): 33 | return Point(x=p1.x + p2.x, y=p1.y + p2.y) 34 | 35 | ... 36 | ``` 37 | 38 | Objects may be nested inside each-other. 39 | 40 | In programmer we've built a new weave query interface in `programmer/weave_next/weave_query.py`. 41 | 42 | You can use weave_query.py's functions to resolve calls with Object refs, which may contain other refs, but its a bit cumbersome at the moment. 43 | 44 | The goal of this project is to improve the interface, into a single calls interface that allows us to fetch calls and expand refs and nested refs in one-shot. Something like this: 45 | 46 | ``` 47 | calls_query = calls(weave_client, 'my_op', expand_refs=['output', 'output.field_a']) 48 | calls_df = calls_query.to_pandas() 49 | ``` 50 | 51 | Please implement the feature including comprehensive unit tests. -------------------------------------------------------------------------------- /docs/dev/devlog-shawn.md: -------------------------------------------------------------------------------- 1 | 8/17/24 2 | ------- 3 | 4 | Git tracking is working in a branch now. 
Here's what it does: 5 | - If you're in a git repo, programmer automatically creates branches while it works. 6 | - The git state is stored in the trajectories that are auto-saved to Weave. 7 | - This means you can roll back programmer to any prior point, and both the conversation, and git repo state will be restored. 8 | 9 | Why do this? To improve an AI application like programmer, you need to experiment. 10 | 11 | Let's use an example. Suppose you're using programmer, you ask it to run some unit tests, and programmer says something like "OK here's a plan, I'll run the `pytest` command, would you like to proceed?" 12 | 13 | This is annoying; we just want it to run the command instead of stopping and asking the user. 14 | 15 | We can try to fix this with prompt engineering. We want to experiment with a bunch of different prompts, starting from the prior state of conversation and file system. 16 | 17 | 18 | OK, above is the beginning of a write-up of how to talk about this feature... 19 | 20 | Now I want to do a few things: 21 | - think about whether the git feature is ready. 22 | - write a new feature using programmer: programmer settings controls. 
23 | 24 | Bug: 25 | - programmer fails to restore my original branch -------------------------------------------------------------------------------- /programmer-ui/ui.py: -------------------------------------------------------------------------------- 1 | # Streamlit UI for browsing programmer sessions 2 | 3 | import pandas as pd 4 | from typing import Optional, Union, Sequence, Dict, Callable, Any 5 | import json 6 | import streamlit as st 7 | import weave 8 | import os 9 | import openai 10 | import copy 11 | from weave.trace.weave_client import WeaveClient 12 | 13 | from programmer.weave_next.api import init_local_client 14 | from programmer.weave_next.weave_query import ( 15 | calls, 16 | expand_refs, 17 | get_call, 18 | expand_json_refs, 19 | ) 20 | from programmer.settings_manager import SettingsManager 21 | 22 | st.set_page_config(layout="wide") 23 | 24 | ST_HASH_FUNCS: Dict[Any, Callable] = {WeaveClient: lambda x: x._project_id()} 25 | 26 | 27 | @st.cache_resource 28 | def init_local_weave(db_path: str = "weave.db"): 29 | return init_local_client(db_path) 30 | 31 | 32 | @st.cache_resource 33 | def init_remote_weave(project: str): 34 | return weave.init(project) 35 | 36 | 37 | def init_from_settings() -> WeaveClient: 38 | SettingsManager.initialize_settings() 39 | weave_logging_setting = SettingsManager.get_setting("weave_logging") 40 | if weave_logging_setting == "off": 41 | st.error( 42 | "Weave logging is off. Please set weave_logging to 'on' in settings to use this feature." 
43 | ) 44 | st.stop() 45 | raise Exception("Should never get here") 46 | elif weave_logging_setting == "local": 47 | return init_local_weave( 48 | os.path.join(SettingsManager.PROGRAMMER_DIR, "weave.db") 49 | ) 50 | elif weave_logging_setting == "cloud": 51 | curdir = os.path.basename(os.path.abspath(os.curdir)) 52 | return init_remote_weave(f"programmer-{curdir}") 53 | else: 54 | raise ValueError(f"Invalid weave_logging setting: {weave_logging_setting}") 55 | 56 | 57 | # Add sidebar for Weave project configuration 58 | with st.sidebar: 59 | st.header("Weave Project Configuration") 60 | 61 | # Initialize from settings 62 | initial_weave_logging = SettingsManager.get_setting("weave_logging") 63 | initial_project_type = "local" if initial_weave_logging == "local" else "cloud" 64 | initial_project_path = ( 65 | os.path.join(SettingsManager.PROGRAMMER_DIR, "weave.db") 66 | if initial_weave_logging == "local" 67 | else "" 68 | ) 69 | initial_project_name = ( 70 | f"programmer-{os.path.basename(os.path.abspath(os.curdir))}" 71 | if initial_weave_logging == "cloud" 72 | else "" 73 | ) 74 | 75 | project_type = st.radio( 76 | "Project Type", 77 | ["local", "cloud"], 78 | index=0 if initial_project_type == "local" else 1, 79 | ) 80 | 81 | if project_type == "local": 82 | project_path = st.text_input("Local DB Path", value=initial_project_path) 83 | # SettingsManager.set_setting("weave_logging", "local") 84 | # SettingsManager.set_setting("weave_db_path", project_path) 85 | client = init_local_weave(project_path) 86 | print("C2", client._project_id()) 87 | else: 88 | # SettingsManager.set_setting("weave_logging", "cloud") 89 | # SettingsManager.set_setting("weave_project_name", project_name) 90 | project_name = st.text_input("Cloud Project Name", value=initial_project_name) 91 | client = init_remote_weave(project_name) 92 | print("C3", client._project_id()) 93 | 94 | # Initialize client based on current settings 95 | # client = init_from_settings() 96 | print("CLIENT", 
client._project_id()) 97 | 98 | 99 | def set_focus_step_id(call_id): 100 | st.session_state["focus_step_id"] = call_id 101 | 102 | 103 | @st.cache_data(hash_funcs=ST_HASH_FUNCS) 104 | def cached_calls( 105 | wc: WeaveClient, 106 | op_names: Optional[Union[str, Sequence[str]]] = None, 107 | parent_ids: Optional[Union[str, Sequence[str]]] = None, 108 | limit: Optional[int] = None, 109 | expand_refs: Optional[list[str]] = None, 110 | ): 111 | return calls( 112 | wc, 113 | op_names=op_names, 114 | parent_ids=parent_ids, 115 | limit=limit, 116 | expand_refs=expand_refs, 117 | ).to_pandas() 118 | 119 | 120 | @st.cache_data(hash_funcs=ST_HASH_FUNCS) 121 | def cached_expand_refs(wc: WeaveClient, refs: Sequence[str]): 122 | return expand_refs(wc, refs).to_pandas() 123 | 124 | 125 | @st.cache_data(hash_funcs=ST_HASH_FUNCS) 126 | def cached_get_call(wc: WeaveClient, call_id: str): 127 | return get_call(wc, call_id) 128 | 129 | 130 | @st.cache_data(hash_funcs=ST_HASH_FUNCS) 131 | def cached_expand_json_refs(wc: WeaveClient, json: dict): 132 | return expand_json_refs(wc, json) 133 | 134 | 135 | def print_step_call(call): 136 | start_history = call["inputs.state.history"] 137 | end_history = call["output.history"] 138 | if isinstance(end_history, float): 139 | st.write("STEP WITH NO OUTPUT") 140 | return 141 | step_messages = list(end_history)[len(start_history) :] 142 | assistant_message = step_messages[0] 143 | tool_response_messages = step_messages[1:] 144 | 145 | if not assistant_message["role"] == "assistant": 146 | raise ValueError(f"Expected assistant message, got {assistant_message['role']}") 147 | 148 | with st.chat_message("assistant"): 149 | st.write(f"https://wandb.ai/shawn/programmer-sympy/weave/calls/{call.id}") 150 | st.write(f"State ref:", call["inputs.state._ref"]) 151 | if "content" in assistant_message: 152 | st.write(assistant_message["content"]) 153 | if "tool_calls" in assistant_message: 154 | for t in assistant_message["tool_calls"]: 155 | t_id = t["id"] 
156 | f_name = t["function"]["name"] 157 | f_args = json.loads(t["function"]["arguments"]) 158 | arg0 = list(f_args.values())[0] 159 | for t_response in tool_response_messages: 160 | if t_response["tool_call_id"] == t_id: 161 | break 162 | else: 163 | raise ValueError(f"Tool call response not found for id {t_id}") 164 | with st.expander(f"{f_name}({arg0}, ...)"): 165 | if ( 166 | f_name == "replace_lines_in_file" 167 | or f_name == "read_lines_from_file" 168 | ): 169 | st.write(f_args) 170 | st.text(t_response["content"]) 171 | 172 | def set_focus_step_closure(): 173 | set_focus_step_id(call.id) 174 | 175 | try: 176 | start_snapshot_commit = call[ 177 | "inputs.state.env_snapshot_key.snapshot_info.commit" 178 | ] 179 | end_snapshot_commit = call["output.env_snapshot_key.snapshot_info.commit"] 180 | 181 | if start_snapshot_commit is not None and end_snapshot_commit is not None: 182 | if start_snapshot_commit != end_snapshot_commit: 183 | with st.expander( 184 | f"git diff {start_snapshot_commit} {end_snapshot_commit}" 185 | ): 186 | diff_output = os.popen( 187 | f"git diff {start_snapshot_commit} {end_snapshot_commit}" 188 | ).read() 189 | st.code(diff_output, language="diff") 190 | 191 | except KeyError: 192 | pass 193 | 194 | # st.button("focus", key=f"focus-{call.id}", on_click=set_focus_step_closure) 195 | 196 | 197 | def print_run_call( 198 | call, 199 | steps_df, 200 | ): 201 | st.write("RUN CALL", call.id) 202 | start_history = steps_df.iloc[0]["inputs.state.history"] 203 | user_input = start_history[-1]["content"] 204 | with st.chat_message("user"): 205 | st.write(user_input) 206 | for _, step in steps_df.iterrows(): 207 | print_step_call(step) 208 | 209 | 210 | def print_session_call(session_id): 211 | runs_df = cached_calls(client, "Agent.run", parent_ids=session_id) 212 | steps_df = cached_calls( 213 | client, 214 | "Agent.step", 215 | parent_ids=runs_df["id"].tolist(), 216 | expand_refs=[ 217 | "inputs.state", 218 | "inputs.state.env_snapshot_key", 219 | 
"output", 220 | "output.env_snapshot_key", 221 | ], 222 | ) 223 | 224 | for _, run_call_data in runs_df.iterrows(): 225 | run_steps_df = steps_df[steps_df["parent_id"] == run_call_data["id"]] 226 | 227 | print_run_call( 228 | run_call_data, 229 | run_steps_df, 230 | ) 231 | 232 | 233 | def sessions_page(): 234 | session_calls_df = cached_calls( 235 | client, "session", expand_refs=["inputs.agent_state"] 236 | ) 237 | if len(session_calls_df) == 0: 238 | st.error("No programmer sessions found.") 239 | st.stop() 240 | session_user_message_df = session_calls_df["inputs.agent_state.history"].apply( 241 | lambda v: v[-1]["content"] 242 | ) 243 | with st.sidebar: 244 | st.header("Session Selection") 245 | if st.button("Refresh"): 246 | st.cache_data.clear() 247 | st.rerun() 248 | message_ids = { 249 | f"{cid[-5:]}: {m}": cid 250 | for cid, m in reversed( 251 | list(zip(session_calls_df["id"], session_user_message_df)) 252 | ) 253 | } 254 | sel_message = st.radio("Session", options=message_ids.keys()) 255 | sel_id = None 256 | if sel_message: 257 | sel_id = message_ids.get(sel_message) 258 | if sel_id: 259 | st.header(f"Session: {sel_id}") 260 | print_session_call(sel_id) 261 | 262 | 263 | sessions_pg = st.Page(sessions_page, title="Sessions") 264 | 265 | 266 | # def write_chat_message(m, key): 267 | # with st.chat_message(m["role"]): 268 | # if "content" in m: 269 | # st.text_area( 270 | # "", value=str(m["content"]), label_visibility="collapsed", key=key 271 | # ) 272 | def write_chat_message(m, key, readonly=False): 273 | def on_change_content(): 274 | new_value = st.session_state[key] 275 | st.session_state.playground_state["editable_call"]["inputs"]["messages"][ 276 | m["original_index"] 277 | ]["content"] = new_value 278 | 279 | with st.chat_message(m["role"]): 280 | if m.get("content"): 281 | if readonly: 282 | st.code(m["content"]) 283 | else: 284 | st.text_area( 285 | "", 286 | value=m["content"], 287 | label_visibility="collapsed", 288 | key=key, 289 | 
on_change=on_change_content, 290 | ) 291 | if m.get("tool_calls"): 292 | for t in m["tool_calls"]: 293 | st.write(t["function"]["name"]) 294 | st.json( 295 | { 296 | "arguments": t["function"]["arguments"], 297 | "response": t.get("response", {}).get("content"), 298 | }, 299 | expanded=True, 300 | ) 301 | 302 | 303 | def attach_tool_call_responses(messages): 304 | new_messages = [] 305 | for i, m in enumerate(messages): 306 | new_m = copy.deepcopy(m) 307 | new_m["original_index"] = i 308 | if new_m["role"] == "assistant" and "tool_calls" in new_m: 309 | new_m["tool_call_responses"] = [] 310 | for t in new_m["tool_calls"]: 311 | t_id = t["id"] 312 | for j, t_response in enumerate(messages): 313 | if t_response.get("tool_call_id") == t_id: 314 | t["response"] = t_response 315 | t["response"]["original_index"] = j 316 | break 317 | if "tool_call_id" not in new_m: 318 | new_messages.append(new_m) 319 | return new_messages 320 | 321 | 322 | def playground_page(): 323 | with st.sidebar: 324 | if not st.session_state.get("playground_state"): 325 | st.session_state.playground_state = { 326 | "call_id": None, 327 | "call": None, 328 | "expanded_call": None, 329 | "editable_call": None, 330 | } 331 | playground_state = st.session_state.playground_state 332 | call_id = st.text_input("Call ID") 333 | if not call_id: 334 | st.error("Please set call ID") 335 | st.stop() 336 | 337 | # st.write(playground_state) 338 | if playground_state["expanded_call"] != playground_state["editable_call"]: 339 | st.warning("Call has been modified") 340 | if st.button("Restore original call"): 341 | st.session_state.playground_state["editable_call"] = copy.deepcopy( 342 | playground_state["expanded_call"] 343 | ) 344 | st.rerun() 345 | 346 | if call_id != st.session_state.playground_state["call_id"]: 347 | st.spinner("Loading call...") 348 | call = cached_get_call(client, call_id) 349 | editable_call = cached_expand_json_refs(client, call) 350 | st.session_state.playground_state = { 351 | 
"call_id": call_id, 352 | "call": call, 353 | "expanded_call": editable_call, 354 | "editable_call": copy.deepcopy(editable_call), 355 | } 356 | st.rerun() 357 | 358 | call = st.session_state.playground_state["call"] 359 | editable_call = st.session_state.playground_state["editable_call"] 360 | if call is None or editable_call is None: 361 | st.warning("call not yet loaded") 362 | st.stop() 363 | 364 | st.write(call["op_name"]) 365 | # st.json(call["inputs"]) 366 | # st.json(call["inputs"]["tools"]) 367 | 368 | def on_change_temperature(): 369 | st.session_state.playground_state["editable_call"]["inputs"][ 370 | "temperature" 371 | ] = st.session_state["temperature"] 372 | 373 | st.slider( 374 | "Temperature", 375 | min_value=0.0, 376 | max_value=1.0, 377 | value=editable_call["inputs"]["temperature"], 378 | key="temperature", 379 | on_change=on_change_temperature, 380 | ) 381 | 382 | tools = call["inputs"].get("tools", []) 383 | if tools: 384 | st.write("Tools") 385 | for tool_idx, t in enumerate(tools): 386 | with st.expander(t["function"]["name"]): 387 | 388 | def on_change_tool(): 389 | st.session_state.playground_state["editable_call"]["inputs"][ 390 | "tools" 391 | ][tool_idx] = json.loads(st.session_state[f"tool-{tool_idx}"]) 392 | st.rerun() 393 | 394 | st.text_area( 395 | "json", 396 | value=json.dumps(t, indent=2), 397 | height=300, 398 | key=f"tool-{tool_idx}", 399 | on_change=on_change_tool, 400 | ) 401 | 402 | def on_change_parallel_tool_calls(): 403 | st.session_state.playground_state["editable_call"]["inputs"][ 404 | "parallel_tool_calls" 405 | ] = st.session_state["parallel_tool_calls"] 406 | 407 | st.checkbox( 408 | "Parallel tool calls", 409 | value=editable_call["inputs"].get("parallel_tool_calls", True), 410 | key="parallel_tool_calls", 411 | on_change=on_change_parallel_tool_calls, 412 | ) 413 | 414 | inputs = editable_call["inputs"] 415 | all_input_messages = inputs["messages"] 416 | other_inputs = { 417 | k: v 418 | for k, v in inputs.items() 
419 | if (k != "messages" and k != "self" and k != "stream") 420 | } 421 | 422 | tool_call_attached_messages = attach_tool_call_responses(all_input_messages) 423 | for i, m in enumerate(tool_call_attached_messages): 424 | write_chat_message(m, f"message-{i}") 425 | # output = editable_call["output"]["choices"][0]["message"] 426 | n_choices = st.number_input( 427 | "Number of choices", value=1, min_value=1, max_value=100 428 | ) 429 | if st.button("Generate"): 430 | chat_inputs = {**editable_call["inputs"]} 431 | # st.json(chat_inputs, expanded=False) 432 | if "stream" in chat_inputs: 433 | del chat_inputs["stream"] 434 | if "self" in chat_inputs: 435 | del chat_inputs["self"] 436 | chat_inputs["n"] = n_choices 437 | call_resp = openai.chat.completions.create(**chat_inputs).model_dump() 438 | 439 | editable_call["output"] = call_resp 440 | st.rerun() 441 | # st.json(response, expanded=False) 442 | # output = response["choices"][0]["message"] 443 | # st.json(output) 444 | response = editable_call["output"] 445 | st.write("full response") 446 | st.json(response, expanded=False) 447 | st.write("**system fingerprint**", response["system_fingerprint"]) 448 | st.write("**usage**", response["usage"]) 449 | for i, choice in enumerate(response["choices"]): 450 | output = choice["message"] 451 | st.write(f"Choice {i+1}") 452 | write_chat_message(output, f"output_message-{i}", readonly=True) 453 | 454 | # all_messages = [*all_input_messages, output] 455 | # st.json(st.session_state.playground_state, expanded=False) 456 | # st.json(all_messages, expanded=False) 457 | 458 | # st.write(expanded_call) 459 | 460 | 461 | playground_pg = st.Page(playground_page, title="Playground") 462 | 463 | 464 | pg = st.navigation([sessions_pg, playground_pg]) 465 | pg.run() 466 | -------------------------------------------------------------------------------- /programmer/__init__.py: -------------------------------------------------------------------------------- 1 | from .programmer import * 
-------------------------------------------------------------------------------- /programmer/agent.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Union 2 | from pydantic import Field 3 | import litellm 4 | import time 5 | from openai.types.chat import ( 6 | ChatCompletionMessageParam, 7 | ) 8 | 9 | import weave 10 | from weave.trace.vals import WeaveList 11 | from weave.flow.chat_util import OpenAIStream 12 | 13 | from .console import Console 14 | from .tool_calling import chat_call_tool_params, perform_tool_calls 15 | from .environment import get_current_environment, EnvironmentSnapshotKey 16 | 17 | 18 | def get_commit_message(history: list[Any]) -> str: 19 | # Commit message is the most recent message with 'content' 20 | for i in range(len(history) - 1, -1, -1): 21 | if history[i].get("role") != "tool" and "content" in history[i]: 22 | return f'{history[i]["role"]}: {history[i]["content"]}' 23 | return "commit" 24 | 25 | 26 | # Weave bug workaround: adding two WeaveLists can create that cause 27 | # downstream crashes. 28 | # Can be removed after https://github.com/wandb/weave/pull/2165 is merged. 29 | def weavelist_add(self: Union[list, WeaveList], other: list) -> Union[list, WeaveList]: 30 | if isinstance(self, list): 31 | return self + other 32 | if not isinstance(other, list): 33 | return NotImplemented 34 | return WeaveList(list(self) + other, server=self.server) 35 | 36 | 37 | class AgentState(weave.Object): 38 | # The chat message history. 
39 | history: list[Any] = Field(default_factory=list) 40 | env_snapshot_key: Optional[EnvironmentSnapshotKey] = None 41 | 42 | def with_history(self, history: list[Any]) -> "AgentState": 43 | environment = get_current_environment() 44 | msg = get_commit_message(history) 45 | snapshot_key = environment.make_snapshot(msg) 46 | return self.__class__(history=history, env_snapshot_key=snapshot_key) 47 | 48 | 49 | def unweavify(v: Any) -> Any: 50 | if isinstance(v, list): 51 | return [unweavify(m) for m in v] 52 | elif isinstance(v, dict): 53 | return {k: unweavify(v) for k, v in v.items()} 54 | else: 55 | return v 56 | 57 | 58 | class Agent(weave.Object): 59 | model_name: str 60 | temperature: float 61 | system_message: str 62 | tools: list[Any] = Field(default_factory=list) 63 | 64 | def initial_state(self, history: list[Any]) -> AgentState: 65 | return AgentState().with_history(history) 66 | 67 | @weave.op() 68 | def step(self, state: AgentState) -> AgentState: 69 | """Run a step of the agent. 70 | 71 | Args: 72 | state: The current state of the environment. 73 | action: The action to take. 74 | 75 | Returns: 76 | The new state of the environment. 77 | """ 78 | Console.step_start("agent", "green") 79 | # Printing this is ugly 80 | # ref = weave.obj_ref(state) 81 | # if ref: 82 | # print("state ref:", ref.uri()) 83 | 84 | messages: list[ChatCompletionMessageParam] = [ 85 | {"role": "system", "content": self.system_message}, 86 | ] 87 | messages += state.history 88 | 89 | # make type checkers happy by passing NotGiven instead of None 90 | tools = None 91 | if self.tools: 92 | tools = chat_call_tool_params(self.tools) 93 | 94 | Console.chat_response_start() 95 | 96 | # Workaround a weave bug, litellm tries to deepcopy messages which has 97 | # a TraceDict. TraceDict is not pickable, because it has a reference to 98 | # a weave server, which has a lock. 
99 | messages = unweavify(messages) 100 | 101 | stream = litellm.completion( 102 | model=self.model_name, 103 | temperature=self.temperature, 104 | messages=messages, 105 | tools=tools, 106 | stream=True, 107 | timeout=60, 108 | ) 109 | wrapped_stream = OpenAIStream(stream) # type: ignore 110 | for chunk in wrapped_stream: 111 | if chunk.choices[0].delta.content: 112 | Console.chat_message_content_delta(chunk.choices[0].delta.content) 113 | 114 | response = wrapped_stream.final_response() 115 | response_message = response.choices[0].message 116 | if response_message.content: 117 | Console.chat_response_complete(response_message.content) 118 | 119 | new_messages = [] 120 | # we always store the dict representations of messages in agent state 121 | # instead of mixing in some pydantic objects. 122 | new_messages.append(response_message.model_dump(exclude_none=True)) 123 | if response_message.tool_calls: 124 | new_messages.extend( 125 | perform_tool_calls(self.tools, response_message.tool_calls) 126 | ) 127 | 128 | # new_history = state.history + new_messages 129 | new_history = weavelist_add(state.history, new_messages) 130 | 131 | return state.with_history(new_history) 132 | 133 | @weave.op() 134 | def run(self, state: AgentState, max_runtime_seconds: int = -1): 135 | start_time = time.time() 136 | while True: 137 | last_message = state.history[-1] 138 | if last_message["role"] == "assistant" and "tool_calls" not in last_message: 139 | return {"state": state, "stop_reason": "done"} 140 | state = self.step(state) 141 | if ( 142 | max_runtime_seconds > 0 143 | and time.time() - start_time > max_runtime_seconds 144 | ): 145 | return {"state": state, "stop_reason": "time_limit_exceeded"} 146 | -------------------------------------------------------------------------------- /programmer/agent_texteditor.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Union 2 | from pydantic import Field 3 | import litellm 4 
| from openai.types.chat import ( 5 | ChatCompletionMessageParam, 6 | ) 7 | 8 | import weave 9 | from weave.trace.vals import WeaveList 10 | from weave.flow.chat_util import OpenAIStream 11 | 12 | from .console import Console 13 | from .tool_calling import chat_call_tool_params, perform_tool_calls 14 | from .text_editor import ( 15 | TextEditor, 16 | TextEditorState, 17 | TextEditorStateful, 18 | open_file, 19 | close_file_range, 20 | replace_file_lines, 21 | text_editor, 22 | ) 23 | from .agent import AgentState, Agent 24 | 25 | 26 | # Weave bug workaround: adding two WeaveLists can create that cause 27 | # downstream crashes. 28 | # Can be removed after https://github.com/wandb/weave/pull/2165 is merged. 29 | def weavelist_add(self: Union[list, WeaveList], other: list) -> Union[list, WeaveList]: 30 | if isinstance(self, list): 31 | return self + other 32 | if not isinstance(other, list): 33 | return NotImplemented 34 | return WeaveList(list(self) + other, server=self.server) 35 | 36 | 37 | class AgentStateTextEditor(AgentState): 38 | text_editor_state: TextEditorState = Field(default_factory=TextEditorState) 39 | 40 | def with_history(self, history: list[Any]) -> "AgentStateTextEditor": 41 | next_state = super().with_history(history) 42 | return AgentStateTextEditor( 43 | history=next_state.history, 44 | env_snapshot_key=next_state.env_snapshot_key, 45 | text_editor_state=self.text_editor_state, 46 | ) 47 | 48 | def with_texteditor_state( 49 | self, text_editor_state: TextEditorState 50 | ) -> "AgentStateTextEditor": 51 | return AgentStateTextEditor( 52 | history=self.history, 53 | env_snapshot_key=self.env_snapshot_key, 54 | text_editor_state=text_editor_state, 55 | ) 56 | 57 | 58 | def unweavify(v: Any) -> Any: 59 | if isinstance(v, list): 60 | return [unweavify(m) for m in v] 61 | elif isinstance(v, dict): 62 | return {k: unweavify(v) for k, v in v.items()} 63 | else: 64 | return v 65 | 66 | 67 | class AgentTextEditor(Agent): 68 | parallel_tool_calls: bool = 
True 69 | text_editor: TextEditor 70 | 71 | def initial_state(self, history: list[Any]) -> AgentStateTextEditor: 72 | return AgentStateTextEditor(history=history) 73 | 74 | @weave.op() 75 | def step(self, state: AgentStateTextEditor) -> AgentStateTextEditor: 76 | """Run a step of the agent. 77 | 78 | Args: 79 | state: The current state of the environment. 80 | action: The action to take. 81 | 82 | Returns: 83 | The new state of the environment. 84 | """ 85 | Console.step_start("agent", "green") 86 | # Printing this is ugly 87 | # ref = weave.obj_ref(state) 88 | # if ref: 89 | # print("state ref:", ref.uri()) 90 | 91 | messages: list[ChatCompletionMessageParam] = [ 92 | {"role": "system", "content": self.system_message}, 93 | ] 94 | open_file_info = state.text_editor_state.get_open_file_info() 95 | 96 | messages.append( 97 | { 98 | "role": "system", 99 | "content": open_file_info.format_for_messages(), 100 | } 101 | ) 102 | 103 | messages += state.history 104 | 105 | messages.append( 106 | { 107 | "role": "system", 108 | "content": open_file_info.format_for_messages(), 109 | } 110 | ) 111 | 112 | self_tools = [*self.tools] or [] 113 | 114 | text_editor_stateful = TextEditorStateful( 115 | self.text_editor, state.text_editor_state 116 | ) 117 | 118 | # self_tools += [open_file, close_file_range, replace_file_lines] 119 | self_tools += [open_file, replace_file_lines] 120 | 121 | # make type checkers happy by passing NotGiven instead of None 122 | tools = None 123 | if self_tools: 124 | tools = chat_call_tool_params(self_tools) 125 | 126 | Console.chat_response_start() 127 | 128 | # Workaround a weave bug, litellm tries to deepcopy messages which has 129 | # a TraceDict. TraceDict is not pickable, because it has a reference to 130 | # a weave server, which has a lock. 
131 | messages = unweavify(messages) 132 | 133 | stream = litellm.completion( 134 | model=self.model_name, 135 | temperature=self.temperature, 136 | messages=messages, 137 | tools=tools, 138 | stream=True, 139 | timeout=60, 140 | parallel_tool_calls=self.parallel_tool_calls, 141 | ) 142 | wrapped_stream = OpenAIStream(stream) # type: ignore 143 | for chunk in wrapped_stream: 144 | if chunk.choices[0].delta.content: 145 | Console.chat_message_content_delta(chunk.choices[0].delta.content) 146 | 147 | response = wrapped_stream.final_response() 148 | response_message = response.choices[0].message 149 | if response_message.content: 150 | Console.chat_response_complete(response_message.content) 151 | 152 | new_messages = [] 153 | # we always store the dict representations of messages in agent state 154 | # instead of mixing in some pydantic objects. 155 | new_messages.append(response_message.model_dump(exclude_none=True)) 156 | if response_message.tool_calls: 157 | with text_editor(text_editor_stateful): 158 | new_messages.extend( 159 | perform_tool_calls(self_tools, response_message.tool_calls) 160 | ) 161 | new_history = weavelist_add(state.history, new_messages) 162 | 163 | next_state = state.with_history(new_history) 164 | next_state = next_state.with_texteditor_state(text_editor_stateful.state) 165 | return next_state 166 | -------------------------------------------------------------------------------- /programmer/agent_texteditor_o1.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Union 2 | from dataclasses import dataclass 3 | from pydantic import Field 4 | import openai 5 | from openai.types.chat import ChatCompletionMessageParam 6 | import json 7 | import re 8 | import time 9 | import uuid 10 | from openai.types.chat import ( 11 | ChatCompletionMessageToolCall, 12 | ) 13 | 14 | import weave 15 | from weave.trace.vals import WeaveList 16 | 17 | from .console import Console 18 | from .tool_calling 
import ( 19 | chat_call_tool_params, 20 | perform_tool_calls, 21 | generate_json_schema, 22 | ) 23 | from .text_editor import ( 24 | TextEditor, 25 | TextEditorState, 26 | TextEditorStateful, 27 | open_file, 28 | replace_file_lines, 29 | text_editor, 30 | ) 31 | from .agent import AgentState, Agent 32 | 33 | 34 | def weavelist_add(self: Union[list, WeaveList], other: list) -> Union[list, WeaveList]: 35 | if isinstance(self, list): 36 | return self + other 37 | if not isinstance(other, list): 38 | return NotImplemented 39 | return WeaveList(list(self) + other, server=self.server) 40 | 41 | 42 | @dataclass 43 | class ToolCallFunction: 44 | name: str 45 | arguments: str 46 | 47 | 48 | @dataclass 49 | class ToolCall: 50 | function: ToolCallFunction 51 | id: str 52 | 53 | 54 | class AgentStateTextEditor(AgentState): 55 | text_editor_state: TextEditorState = Field(default_factory=TextEditorState) 56 | 57 | def with_history(self, history: list[Any]) -> "AgentStateTextEditor": 58 | next_state = super().with_history(history) 59 | return AgentStateTextEditor( 60 | history=next_state.history, 61 | env_snapshot_key=next_state.env_snapshot_key, 62 | text_editor_state=self.text_editor_state, 63 | ) 64 | 65 | def with_texteditor_state( 66 | self, text_editor_state: TextEditorState 67 | ) -> "AgentStateTextEditor": 68 | return AgentStateTextEditor( 69 | history=self.history, 70 | env_snapshot_key=self.env_snapshot_key, 71 | text_editor_state=text_editor_state, 72 | ) 73 | 74 | 75 | class AgentTextEditorO1(Agent): 76 | parallel_tool_calls: bool = True 77 | text_editor: TextEditor 78 | 79 | def initial_state(self, history: list[Any]) -> AgentStateTextEditor: 80 | return AgentStateTextEditor(history=history) 81 | 82 | @weave.op() 83 | def step(self, state: AgentStateTextEditor) -> AgentStateTextEditor: 84 | """Run a step of the agent. 85 | 86 | Args: 87 | state: The current state of the environment. 88 | 89 | Returns: 90 | The new state of the environment. 
91 | """ 92 | Console.step_start("agent", "green") 93 | 94 | # Prepare messages 95 | messages: list[ChatCompletionMessageParam] = [] 96 | 97 | # Combine system message and open_file_info into a user message 98 | open_file_info = state.text_editor_state.get_open_file_info() 99 | initial_content = ( 100 | f"{self.system_message}\n\n{open_file_info.format_for_messages()}" 101 | ) 102 | 103 | # Include descriptions of available tools 104 | self_tools = [*self.tools] or [] 105 | text_editor_stateful = TextEditorStateful( 106 | self.text_editor, state.text_editor_state 107 | ) 108 | 109 | self_tools += [open_file, replace_file_lines] 110 | 111 | # Generate tool descriptions 112 | tools_descriptions = "" 113 | for tool in self_tools: 114 | tool_schema = generate_json_schema(tool) 115 | tool_name = tool.__name__ 116 | tool_description = tool_schema.get("function", {}).get("description", "") 117 | tool_parameters = tool_schema.get("function", {}).get("parameters", {}) 118 | tools_descriptions += f"\n- {tool_name}: {tool_description}\nParameters: {json.dumps(tool_parameters)}\n" 119 | 120 | initial_content += f"\n\nAvailable tools:{tools_descriptions}\n" 121 | 122 | # Add instructions to the assistant about how to call tools 123 | initial_content += ( 124 | "When you want to use a tool, please output the tool call in the following format:\n" 125 | "()\n" 126 | 'For example: ({"file_name": "example.txt"})\n' 127 | "Please include the tool call in your response where appropriate." 128 | "If you have achieved your goal, our you're stuck, don't call a tool!" 
129 | ) 130 | 131 | # Add the initial user message 132 | messages.append( 133 | { 134 | "role": "user", 135 | "content": f"{initial_content}", 136 | } 137 | ) 138 | 139 | # Add conversation history, ensuring only 'assistant' and 'user' roles 140 | messages += [ 141 | msg for msg in state.history if msg.get("role") in ["assistant", "user"] 142 | ] 143 | 144 | Console.chat_response_start() 145 | 146 | # Call the OpenAI API 147 | response = openai.chat.completions.create( 148 | model=self.model_name, 149 | temperature=self.temperature, 150 | messages=messages, 151 | timeout=600, 152 | ) 153 | 154 | # Get the assistant's response 155 | response_message = response.choices[0].message 156 | 157 | if response_message.content: 158 | print(response_message.content) 159 | Console.chat_response_complete(response_message.content) 160 | 161 | new_messages = [] 162 | # Store the assistant's response 163 | new_messages.append( 164 | { 165 | "role": response_message.role, 166 | "content": response_message.content, 167 | } 168 | ) 169 | 170 | # Parse any tool calls from the assistant's response 171 | tool_calls = self.parse_tool_calls(response_message.content or "") 172 | 173 | if tool_calls: 174 | with text_editor(text_editor_stateful): 175 | tool_messages = perform_tool_calls(self_tools, tool_calls) 176 | 177 | # Combine tool call responses into a single user message 178 | tool_responses = "\n" 179 | for msg in tool_messages: 180 | tool_responses += f"{msg['content']}\n" 181 | tool_responses += "" 182 | 183 | new_messages.append({"role": "user", "content": tool_responses}) 184 | 185 | new_history = weavelist_add(state.history, new_messages) 186 | 187 | next_state = state.with_history(new_history) 188 | next_state = next_state.with_texteditor_state(text_editor_stateful.state) 189 | return next_state 190 | 191 | def parse_tool_calls(self, content: str) -> list: 192 | tool_calls = [] 193 | pattern = r"<(.*?)>\((.*?)\)" 194 | matches = re.finditer(pattern, content, re.DOTALL) 195 | 
for match in matches: 196 | tool_id = match.group(1) 197 | tool_name = match.group(2) 198 | arguments = match.group(3) 199 | tool_call = ToolCall( 200 | function=ToolCallFunction( 201 | name=tool_name, 202 | arguments=arguments, 203 | ), 204 | id=tool_id, 205 | ) 206 | tool_calls.append(tool_call) 207 | return tool_calls 208 | 209 | @weave.op() 210 | def run(self, state: AgentState, max_runtime_seconds: int = -1): 211 | start_time = time.time() 212 | while True: 213 | last_message = state.history[-1] 214 | if last_message["role"] == "assistant": 215 | # Check if there are no tool calls in the content 216 | if not self.parse_tool_calls(last_message.get("content", "")): 217 | return {"state": state, "stop_reason": "done"} 218 | state = self.step(state) 219 | if ( 220 | max_runtime_seconds > 0 221 | and time.time() - start_time > max_runtime_seconds 222 | ): 223 | return {"state": state, "stop_reason": "time_limit_exceeded"} 224 | -------------------------------------------------------------------------------- /programmer/config.py: -------------------------------------------------------------------------------- 1 | SYSTEM_MESSAGE = """Assistant is a programming assistant named "programmer". 2 | programmer is autonomous, and does not stop to ask for user input until it is totally stuck. 3 | programmer always has access to a shell and local filesystem in perform tasks, via its tools. 4 | programmer writes code directly to files instead of to the terminal, unless it is showing snippets for discussion. 
5 | """ 6 | 7 | from .tools import ( 8 | list_files, 9 | write_to_file, 10 | read_from_file, 11 | run_command, 12 | view_image, 13 | read_lines_from_file, 14 | replace_lines_in_file, 15 | splice_lines_in_file, 16 | ) 17 | from .agent import Agent 18 | from .agent_texteditor import AgentTextEditor 19 | from .text_editor import TextEditor 20 | from .agent_texteditor_o1 import AgentTextEditorO1 21 | from typing import Optional, Any 22 | 23 | agent_4o_basic = Agent( 24 | name="gpt-4o-2024-08-06_basic", 25 | model_name="gpt-4o-2024-08-06", 26 | temperature=0.7, 27 | system_message=SYSTEM_MESSAGE, 28 | tools=[list_files, write_to_file, read_from_file, run_command, view_image], 29 | ) 30 | 31 | agent_4omini_basic = Agent( 32 | name="gpt-4o-mini-2024-07-08_basic", 33 | model_name="gpt-4o-mini-2024-07-18", 34 | temperature=0.7, 35 | system_message=SYSTEM_MESSAGE, 36 | tools=[list_files, write_to_file, read_from_file, run_command, view_image], 37 | ) 38 | 39 | agent_claude_basic = Agent( 40 | name="claude-3-5-sonnet-basic", 41 | model_name="claude-3-5-sonnet-20240620", 42 | temperature=0.7, 43 | system_message=SYSTEM_MESSAGE, 44 | tools=[list_files, write_to_file, read_from_file, run_command, view_image], 45 | ) 46 | 47 | agent_4o_replace = Agent( 48 | name="gpt-4o-2024-08-06_replace", 49 | model_name="gpt-4o-2024-08-06", 50 | temperature=0.7, 51 | system_message=SYSTEM_MESSAGE, 52 | tools=[ 53 | list_files, 54 | run_command, 55 | view_image, 56 | read_lines_from_file, 57 | replace_lines_in_file, 58 | ], 59 | ) 60 | 61 | agent_claude_replace = Agent( 62 | name="claude-3-5-sonnet-20240620_replace", 63 | model_name="claude-3-5-sonnet-20240620", 64 | temperature=0.7, 65 | system_message=SYSTEM_MESSAGE, 66 | tools=[ 67 | list_files, 68 | run_command, 69 | view_image, 70 | read_lines_from_file, 71 | replace_lines_in_file, 72 | ], 73 | ) 74 | 75 | 76 | agent_4o_splice = Agent( 77 | name="gpt-4o-2024-08-06_splice", 78 | model_name="gpt-4o-2024-08-06", 79 | temperature=0.7, 80 | 
system_message=SYSTEM_MESSAGE, 81 | tools=[ 82 | list_files, 83 | run_command, 84 | view_image, 85 | read_lines_from_file, 86 | splice_lines_in_file, 87 | ], 88 | ) 89 | 90 | agent_claude_splice = Agent( 91 | name="claude-3-5-sonnet-20240620_splice", 92 | model_name="claude-3-5-sonnet-20240620", 93 | temperature=0.7, 94 | system_message=SYSTEM_MESSAGE, 95 | tools=[ 96 | list_files, 97 | run_command, 98 | view_image, 99 | read_lines_from_file, 100 | splice_lines_in_file, 101 | ], 102 | ) 103 | 104 | text_editor = TextEditor(max_open_size=15000, open_chunk_size=2000) 105 | agent_texteditor_4o_basic = AgentTextEditor( 106 | name="gpt-4o-2024-08-06_texteditor_basic", 107 | model_name="gpt-4o-2024-08-06", 108 | temperature=0.7, 109 | system_message=SYSTEM_MESSAGE, 110 | text_editor=text_editor, 111 | tools=[list_files, run_command, view_image], 112 | ) 113 | 114 | agent_texteditor_4o_basic_temp0 = AgentTextEditor( 115 | name="gpt-4o-2024-08-06_texteditor_basic_temp0", 116 | model_name="gpt-4o-2024-08-06", 117 | temperature=0.0, 118 | system_message=SYSTEM_MESSAGE, 119 | text_editor=text_editor, 120 | tools=[list_files, run_command, view_image], 121 | ) 122 | 123 | agent_texteditor_4o_basic_noparalleltc = AgentTextEditor( 124 | name="gpt-4o-2024-08-06_texteditor_basic_noparalleltc", 125 | model_name="gpt-4o-2024-08-06", 126 | temperature=0.7, 127 | system_message=SYSTEM_MESSAGE, 128 | text_editor=text_editor, 129 | tools=[list_files, run_command, view_image], 130 | parallel_tool_calls=False, 131 | ) 132 | 133 | agent_texteditor_o1_gpt4o = AgentTextEditorO1( 134 | name="gpt4o_o1harness", 135 | model_name="gpt-4o-2024-08-06", 136 | temperature=0.7, 137 | system_message=SYSTEM_MESSAGE, 138 | text_editor=text_editor, 139 | tools=[list_files, run_command, view_image], 140 | ) 141 | 142 | agent_texteditor_o1_o1preview = AgentTextEditorO1( 143 | name="o1-preview-2024-09-12_o1harness", 144 | model_name="o1-preview-2024-09-12", 145 | temperature=1, 146 | 
system_message=SYSTEM_MESSAGE, 147 | text_editor=text_editor, 148 | tools=[list_files, run_command, view_image], 149 | ) 150 | 151 | agent_texteditor_o1_o1mini = AgentTextEditorO1( 152 | name="o1-mini-2024-09-12_o1harness", 153 | model_name="o1-mini-2024-09-12", 154 | temperature=1, 155 | system_message=SYSTEM_MESSAGE, 156 | text_editor=text_editor, 157 | tools=[list_files, run_command, view_image], 158 | ) 159 | 160 | 161 | def get_config_by_name(name: str) -> Optional[Any]: 162 | """ 163 | Fetch a configuration object by its name. 164 | 165 | Args: 166 | name (str): The name of the configuration to fetch. 167 | 168 | Returns: 169 | Optional[Any]: The configuration object if found, None otherwise. 170 | """ 171 | # Get all variables defined in this module 172 | all_vars = globals() 173 | 174 | # Look for a variable that matches the given name 175 | for var_name, var_value in all_vars.items(): 176 | if isinstance(var_value, Agent): 177 | if var_value.name == name: 178 | return var_value 179 | 180 | # If no matching configuration is found, return None 181 | return None 182 | 183 | 184 | def get_all_config_names() -> list[str]: 185 | """ 186 | Get a list of all valid configuration names. 187 | 188 | Returns: 189 | list[str]: A list of all configuration names. 
190 | """ 191 | all_vars = globals() 192 | config_names = [] 193 | 194 | for var_name, var_value in all_vars.items(): 195 | if isinstance(var_value, (Agent, AgentTextEditor, AgentTextEditorO1)): 196 | config_names.append(var_value.name) 197 | 198 | return sorted(config_names) 199 | -------------------------------------------------------------------------------- /programmer/console.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Optional 3 | from rich.console import Console as RichConsole 4 | from rich.padding import Padding 5 | 6 | 7 | console = RichConsole() 8 | 9 | 10 | class Console: 11 | @staticmethod 12 | def welcome(agent_name: Optional[str] = None) -> None: 13 | console.rule("[bold blue]Programmer") 14 | console.print("Welcome to programmer.") 15 | if agent_name: 16 | console.print(f"Using agent: [bold]{agent_name}[/bold]") 17 | console.print() 18 | 19 | @staticmethod 20 | def step_start(name: str, color: str) -> None: 21 | console.rule(f"[bold {color}]Begin {name} step") 22 | 23 | @staticmethod 24 | def chat_response_start() -> None: 25 | pass 26 | 27 | @staticmethod 28 | def chat_message_content_delta(message_content_delta: str) -> None: 29 | console.print(message_content_delta, end="") 30 | 31 | @staticmethod 32 | def chat_response_complete(agent_response: str) -> None: 33 | console.print("\n") 34 | 35 | @staticmethod 36 | def tool_call_start(tool_call: str) -> None: 37 | console.print(f"[bold yellow]Tool call: [/bold yellow]{tool_call}\n") 38 | 39 | @staticmethod 40 | def tool_call_complete(tool_response: str) -> None: 41 | lines = tool_response.split("\n") 42 | if len(lines) > 4: 43 | lines = lines[:4] 44 | lines.append("...") 45 | tool_response = "\n".join(lines) 46 | console.print( 47 | Padding.indent(f"{tool_response}\n", 4), 48 | no_wrap=True, 49 | overflow="ellipsis", 50 | ) 51 | 52 | @staticmethod 53 | def user_input_complete(user_input: str) -> None: 54 | console.print() 
55 | -------------------------------------------------------------------------------- /programmer/containerserver/README.md: -------------------------------------------------------------------------------- 1 | # Container Manager Server 2 | 3 | ## Build images on server 4 | 5 | We use this for running swe-bench locally against containers on a remote server. See [swe-bench README](../swe-bench/README.md) for steps to build the SWE-bench images. 6 | 7 | ## Run and check server 8 | 9 | put cmserver.py on remote machine 10 | ``` 11 | gcloud compute scp --zone "us-west1-a" --project "weave-support-367421" cmserver.py programmer-benchmark2:~/ 12 | ``` 13 | 14 | on remote machine 15 | 16 | (just 1 worker for now, there's global state) 17 | ``` 18 | uvicorn cmserver:app --host 0.0.0.0 --port 8000 --workers 1 19 | ``` 20 | 21 | tunnel from local machine to remote 22 | ``` 23 | gcloud compute ssh --zone "us-west1-a" "programmer-benchmark" --project "weave-support-367421" -- -NL 8000:localhost:8000 24 | ``` 25 | 26 | local machine 27 | ``` 28 | python checkserver.py 29 | ``` 30 | 31 | result on remote machine should be there are no more running containers when done 32 | 33 | -------------------------------------------------------------------------------- /programmer/containerserver/checkserver.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import threading 3 | import argparse 4 | 5 | # Replace with the actual host and port if different 6 | BASE_URL = "http://127.0.0.1:8000" 7 | 8 | 9 | def start_container(image_id: str): 10 | response = requests.post(f"{BASE_URL}/container/start", json={"image_id": image_id}) 11 | if response.status_code == 200: 12 | return response.json().get("container_id") 13 | else: 14 | print(f"Failed to start container: {response.text}") 15 | return None 16 | 17 | 18 | def run_command(container_id: str, workdir: str, command: str): 19 | response = requests.post( 20 | 
f"{BASE_URL}/container/run", 21 | json={"container_id": container_id, "workdir": workdir, "command": command}, 22 | ) 23 | if response.status_code == 200: 24 | return response.json() 25 | else: 26 | print(f"Failed to run command: {response.text}") 27 | return None 28 | 29 | 30 | def write_file(container_id: str, file_path: str, file_content: str): 31 | response = requests.post( 32 | f"{BASE_URL}/container/write_file", 33 | json={ 34 | "container_id": container_id, 35 | "file_path": file_path, 36 | "file_content": file_content, 37 | }, 38 | ) 39 | if response.status_code == 200: 40 | return response.json().get("status") 41 | else: 42 | print(f"Failed to write file: {response.text}") 43 | return None 44 | 45 | 46 | def read_file(container_id: str, file_path: str): 47 | response = requests.post( 48 | f"{BASE_URL}/container/read_file", 49 | json={"container_id": container_id, "file_path": file_path}, 50 | ) 51 | if response.status_code == 200: 52 | return response.json().get("file_content") 53 | else: 54 | print(f"Failed to read file: {response.text}") 55 | return None 56 | 57 | 58 | def stop_container(container_id: str, delete: bool): 59 | response = requests.post( 60 | f"{BASE_URL}/container/stop", 61 | json={"container_id": container_id, "delete": delete}, 62 | ) 63 | if response.status_code == 200: 64 | return response.json().get("status") 65 | else: 66 | print(f"Failed to stop container: {response.text}") 67 | return None 68 | 69 | 70 | def manage_container(image_id: str, container_index: int): 71 | print(f"Starting container {container_index}...") 72 | container_id = start_container(image_id) 73 | if not container_id: 74 | print(f"Failed to start container {container_index}") 75 | return 76 | 77 | print(f"Started container {container_index} with ID: {container_id}") 78 | 79 | # Run a command inside the container 80 | output = run_command(container_id, "/", "ls") 81 | if output: 82 | print(f"Container {container_index} command output:\n{output}") 83 | 84 | # Write 
a file inside the container 85 | file_path = f"test_{container_index}.txt" 86 | file_content = f"Hello, this is a test for container {container_index}." 87 | write_status = write_file(container_id, file_path, file_content) 88 | if write_status: 89 | print(f"Container {container_index} write file status: {write_status}") 90 | 91 | # Read the file back from the container 92 | read_content = read_file(container_id, file_path) 93 | if read_content: 94 | print(f"Container {container_index} file content:\n{read_content}") 95 | 96 | # Stop the container (and delete it) 97 | stop_status = stop_container(container_id, delete=True) 98 | if stop_status: 99 | print(f"Container {container_index} stop status: {stop_status}") 100 | 101 | 102 | def run_parallel_tests(image_id: str, parallelism: int): 103 | threads = [] 104 | for i in range(parallelism): 105 | thread = threading.Thread(target=manage_container, args=(image_id, i)) 106 | threads.append(thread) 107 | thread.start() 108 | 109 | for thread in threads: 110 | thread.join() 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser(description="Run parallel container tests") 115 | parser.add_argument( 116 | "--parallelism", 117 | type=int, 118 | default=1, 119 | help="Number of parallel container operations (default: 1)", 120 | ) 121 | parser.add_argument( 122 | "--image-id", 123 | type=str, 124 | default="sweb.eval.x86_64.sympy__sympy-20590", 125 | help="Image ID to test", 126 | ) 127 | args = parser.parse_args() 128 | 129 | run_parallel_tests(args.image_id, args.parallelism) 130 | -------------------------------------------------------------------------------- /programmer/containerserver/cmserver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | from io import BytesIO 4 | from concurrent.futures import ThreadPoolExecutor 5 | import asyncio 6 | import docker 7 | from docker.errors import NotFound 8 | from fastapi import 
FastAPI, HTTPException 9 | from pydantic import BaseModel 10 | 11 | 12 | # DockerContainerManager class 13 | class DockerContainerManager: 14 | def __init__(self): 15 | self.client = docker.from_env() 16 | self.executor = ThreadPoolExecutor() 17 | 18 | async def start_container(self, image_id: str): 19 | loop = asyncio.get_event_loop() 20 | container = await loop.run_in_executor( 21 | self.executor, self._run_container, image_id 22 | ) 23 | return container.short_id 24 | 25 | def _run_container(self, image_id: str): 26 | return self.client.containers.run( 27 | image_id, detach=True, command="tail -f /dev/null" 28 | ) 29 | 30 | def _get_container(self, container_id: str): 31 | return self.client.containers.get(container_id) 32 | 33 | async def run_command(self, container_id: str, workdir: str, command: str): 34 | loop = asyncio.get_event_loop() 35 | exec_result = await loop.run_in_executor( 36 | self.executor, self._exec_run, container_id, command, workdir 37 | ) 38 | return { 39 | "exit_code": exec_result.exit_code, 40 | "output": exec_result.output.decode("utf-8"), 41 | } 42 | 43 | def _exec_run(self, container_id: str, command: str, workdir: str): 44 | container = self._get_container(container_id) 45 | return container.exec_run(command, workdir=workdir) 46 | 47 | async def write_file(self, container_id: str, file_path: str, file_content: str): 48 | file_path = os.path.join("/", file_path) 49 | container = self._get_container(container_id) 50 | tarstream = BytesIO() 51 | with tarfile.open(fileobj=tarstream, mode="w") as tar: 52 | tarinfo = tarfile.TarInfo(name=os.path.basename(file_path)) 53 | tarinfo.size = len(file_content) 54 | tar.addfile(tarinfo, BytesIO(file_content.encode("utf-8"))) 55 | tarstream.seek(0) 56 | 57 | loop = asyncio.get_event_loop() 58 | await loop.run_in_executor( 59 | self.executor, 60 | container.put_archive, 61 | os.path.dirname(file_path), 62 | tarstream, 63 | ) 64 | 65 | async def read_file(self, container_id: str, file_path: str): 66 | 
container = self._get_container(container_id) 67 | loop = asyncio.get_event_loop() 68 | bits, _ = await loop.run_in_executor( 69 | self.executor, container.get_archive, file_path 70 | ) 71 | file_content = BytesIO() 72 | for chunk in bits: 73 | file_content.write(chunk) 74 | file_content.seek(0) 75 | with tarfile.open(fileobj=file_content) as tar: 76 | member = tar.getmembers()[0] 77 | extract_result = tar.extractfile(member) 78 | if extract_result is None: 79 | raise Exception(f"Unexpected tar.extractfile result for: {file_path}") 80 | file_data = extract_result.read() 81 | return file_data.decode("utf-8") 82 | 83 | async def stop_container(self, container_id: str, delete: bool = False): 84 | container = self._get_container(container_id) 85 | loop = asyncio.get_event_loop() 86 | await loop.run_in_executor(self.executor, container.stop) 87 | if delete: 88 | await loop.run_in_executor(self.executor, container.remove) 89 | 90 | 91 | # FastAPI setup 92 | app = FastAPI() 93 | container_manager = DockerContainerManager() 94 | 95 | 96 | class StartContainerRequest(BaseModel): 97 | image_id: str 98 | 99 | 100 | class StopContainerRequest(BaseModel): 101 | container_id: str 102 | delete: bool 103 | 104 | 105 | class CommandRequest(BaseModel): 106 | container_id: str 107 | workdir: str 108 | command: str 109 | 110 | 111 | class FileRequest(BaseModel): 112 | container_id: str 113 | file_path: str 114 | file_content: str 115 | 116 | 117 | class FilePathRequest(BaseModel): 118 | container_id: str 119 | file_path: str 120 | 121 | 122 | @app.post("/container/start") 123 | async def start_container(request: StartContainerRequest): 124 | try: 125 | container_id = await container_manager.start_container(request.image_id) 126 | return {"container_id": container_id} 127 | except Exception as e: 128 | raise HTTPException(status_code=500, detail=str(e)) 129 | 130 | 131 | @app.post("/container/run") 132 | async def run_command(request: CommandRequest): 133 | try: 134 | result = await 
container_manager.run_command( 135 | request.container_id, request.workdir, request.command 136 | ) 137 | return {"exit_code": result["exit_code"], "output": result["output"]} 138 | except Exception as e: 139 | raise HTTPException(status_code=500, detail=str(e)) 140 | 141 | 142 | @app.post("/container/write_file") 143 | async def write_file(request: FileRequest): 144 | try: 145 | await container_manager.write_file( 146 | request.container_id, request.file_path, request.file_content 147 | ) 148 | return {"status": "file written"} 149 | except NotFound as e: 150 | raise HTTPException(status_code=404, detail=str(e)) 151 | except Exception as e: 152 | raise HTTPException(status_code=500, detail=str(e)) 153 | 154 | 155 | @app.post("/container/read_file") 156 | async def read_file(request: FilePathRequest): 157 | try: 158 | file_content = await container_manager.read_file( 159 | request.container_id, request.file_path 160 | ) 161 | return {"file_content": file_content} 162 | except NotFound as e: 163 | raise HTTPException(status_code=404, detail=str(e)) 164 | except Exception as e: 165 | raise HTTPException(status_code=500, detail=str(e)) 166 | 167 | 168 | @app.post("/container/stop") 169 | async def stop_container(request: StopContainerRequest): 170 | try: 171 | await container_manager.stop_container(request.container_id, request.delete) 172 | return {"status": "container stopped"} 173 | except Exception as e: 174 | raise HTTPException(status_code=500, detail=str(e)) 175 | 176 | 177 | # To run the server, use: 178 | # uvicorn your_file_name:app --host 0.0.0.0 --port 8000 --workers 4 179 | -------------------------------------------------------------------------------- /programmer/environment.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Protocol 3 | from contextvars import ContextVar 4 | from contextlib import contextmanager 5 | 6 | from .git import GitRepo 7 | 8 | 
@dataclass 9 | class EnvironmentSnapshotKey: 10 | env_id: str 11 | snapshot_info: dict 12 | 13 | 14 | class Environment(Protocol): 15 | def start_session(self, session_id: str): ... 16 | 17 | def finish_session(self): ... 18 | 19 | def make_snapshot(self, message: str) -> EnvironmentSnapshotKey: ... 20 | 21 | @classmethod 22 | def restore_from_snapshot_key(cls, ref: EnvironmentSnapshotKey): ... 23 | 24 | 25 | @contextmanager 26 | def environment_session(env: Environment, session_id: str | None): 27 | if session_id is not None: 28 | env.start_session(session_id) 29 | token = environment_context.set(env) 30 | try: 31 | yield env 32 | finally: 33 | env.finish_session() 34 | environment_context.reset(token) 35 | else: 36 | yield env 37 | 38 | 39 | def get_current_environment() -> Environment: 40 | return environment_context.get() 41 | 42 | 43 | class GitEnvironment(Environment): 44 | def __init__(self, repo: GitRepo): 45 | self.repo = repo 46 | self.original_git_ref = None 47 | self.programmer_branch = None 48 | 49 | def start_session(self, session_id: str): 50 | self.original_git_ref = self.repo.get_current_head() 51 | self.programmer_branch = f"programmer-{session_id}" 52 | print("programmer_branch:", self.programmer_branch) 53 | # Create the programmer branch based on the current state 54 | self.repo.create_branch(self.programmer_branch) 55 | 56 | def finish_session(self): 57 | if self.original_git_ref is None or self.programmer_branch is None: 58 | raise ValueError("Session not started") 59 | # No need to checkout back as we never changed the branch 60 | 61 | def make_snapshot(self, message: str) -> EnvironmentSnapshotKey: 62 | if self.programmer_branch is None: 63 | raise ValueError("Programmer branch is not set") 64 | # Commit directly to the programmer branch using new method 65 | commit_hash = self.repo.commit_directly_to_branch(self.programmer_branch, message) 66 | return EnvironmentSnapshotKey( 67 | "git", {"origin": self.repo.get_origin_url(), "commit": 
commit_hash} 68 | ) 69 | 70 | @classmethod 71 | def restore_from_snapshot_key(cls, ref: EnvironmentSnapshotKey): 72 | origin = ref.snapshot_info["origin"] 73 | commit = ref.snapshot_info["commit"] 74 | repo = GitRepo.from_current_dir() 75 | if not repo: 76 | raise ValueError("No git repo found") 77 | if origin != repo.get_origin_url(): 78 | raise ValueError("Origin URL mismatch") 79 | repo.checkout_existing(commit) 80 | print("Checked out commit", commit) 81 | 82 | 83 | class NoopEnvironment(Environment): 84 | def start_session(self, session_id: str): 85 | pass 86 | 87 | def finish_session(self): 88 | pass 89 | 90 | def make_snapshot(self, message: str) -> EnvironmentSnapshotKey: 91 | return EnvironmentSnapshotKey("noop", {}) 92 | 93 | @classmethod 94 | def restore_from_snapshot_key(cls, ref: EnvironmentSnapshotKey): 95 | pass 96 | 97 | 98 | def restore_environment(snapshot_key: EnvironmentSnapshotKey) -> Environment: 99 | if snapshot_key.env_id == "git": 100 | GitEnvironment.restore_from_snapshot_key(snapshot_key) 101 | return NoopEnvironment() 102 | 103 | 104 | environment_context: ContextVar[Environment] = ContextVar( 105 | "environment", default=NoopEnvironment() 106 | ) 107 | -------------------------------------------------------------------------------- /programmer/evals/eval_repeated_edits.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import os 3 | import time 4 | import concurrent.futures 5 | from typing import TypedDict, Callable 6 | from contextlib import contextmanager 7 | 8 | import weave 9 | from weave.trace import call_context 10 | 11 | from ..agent import AgentState, Agent 12 | from ..config import * 13 | from ..io_context import LocalIOContext, io_context, get_io_context 14 | 15 | # NOTES 16 | # - Try with other LLM and tool configs now that I have this test 17 | 18 | 19 | # @pytest.fixture 20 | @contextmanager 21 | def tempdir(): 22 | with tempfile.TemporaryDirectory() as dir_: 23 | 
with io_context(LocalIOContext(dir_)) as tc: 24 | yield tc 25 | 26 | 27 | def call_descendent_error_count(call): 28 | has_error = int(call.exception is not None) 29 | descendent_errors = sum(call_descendent_error_count(c) for c in call._children) 30 | return has_error + descendent_errors 31 | 32 | 33 | class EvalEditMemoryConfig(TypedDict): 34 | n_lines: int 35 | run_timeout_seconds: int 36 | 37 | 38 | @weave.op 39 | def eval_edit_memory( 40 | config: EvalEditMemoryConfig, agent: Agent, name: str, trial_idx: int 41 | ): 42 | call = weave.get_current_call() 43 | if call: 44 | call.set_display_name(f"{name}: Trial{trial_idx}") 45 | with tempdir() as ctx: 46 | expected_lines = [] 47 | n_alpha = 10 48 | num_per_alpha = config["n_lines"] // n_alpha 49 | for attempt in range(config["n_lines"]): 50 | expected_lines.append( 51 | f"{chr(65 + attempt // num_per_alpha)}{attempt % num_per_alpha}" 52 | ) 53 | 54 | with open(ctx.resolve_path("file.txt"), "w") as f: 55 | prev_file_contents = "\n".join(expected_lines) 56 | f.write(prev_file_contents) 57 | 58 | task_correct = False 59 | state = agent.initial_state(history=[]) 60 | 61 | def step6_insert_ampersands(lines): 62 | new_lines = [] 63 | for l in lines: 64 | if l.endswith("***"): 65 | new_lines.append("&&") 66 | new_lines.append(l) 67 | return new_lines 68 | 69 | results: dict = {} 70 | task_infos = [] 71 | for task_idx, (task_name, prompt, modify_expected_fn) in enumerate( 72 | [ 73 | ( 74 | "replace_range", 75 | "file.txt contains lines like 'C4', 'D8' etc. Replace lines 'A7' through 'B4' (inclusive) with 'X\\nY\\n'.", 76 | lambda lines: lines[: lines.index("A7")] 77 | + ["X", "Y"] 78 | + lines[lines.index("B4") + 1 :], 79 | ), 80 | ( 81 | "correct_range", 82 | "Actually that edit was wrong. 
Replace them with 'Z\nZZ\nZZZ\n' instead.", 83 | lambda lines: lines[: lines.index("X")] 84 | + ["Z", "ZZ", "ZZZ"] 85 | + lines[lines.index("Y") + 1 :], 86 | ), 87 | ( 88 | "insert_beginning", 89 | "Add a line '😊😊😊' to the start of the file.", 90 | lambda lines: ["😊😊😊"] + lines, 91 | ), 92 | ( 93 | "append_end", 94 | "Add a line '😔😔😔' to the end of the file.", 95 | lambda lines: lines + ["😔😔😔"], 96 | ), 97 | ( 98 | "replace_prior_range", 99 | "Replace the Z lines we added earlier with a single blank line.", 100 | lambda lines: lines[: lines.index("Z")] 101 | + [""] 102 | + lines[lines.index("ZZZ") + 1 :], 103 | ), 104 | ( 105 | "distribute_asterisks", 106 | "Append *** to the end of each line that ends with 7.", 107 | lambda lines: [l + "***" if l.endswith("7") else l for l in lines], 108 | ), 109 | ( 110 | "distribute_ampersand_prefix", 111 | "Insert a line containing '&&' prior to each of the '***' lines we just added.", 112 | step6_insert_ampersands, 113 | ), 114 | ] 115 | ): 116 | expected_lines = modify_expected_fn(expected_lines) 117 | run_task_result = run_task( 118 | config, 119 | agent, 120 | state, 121 | expected_lines, 122 | task_idx, 123 | task_name, 124 | prompt, 125 | ) 126 | state = run_task_result["state"] 127 | task_info = run_task_result["task_info"] 128 | task_infos.append(task_info) 129 | task_correct = task_info["correct"] 130 | if not task_correct: 131 | # Don't do further tasks. 
132 | break 133 | 134 | results["success"] = task_correct 135 | results["completed_tasks"] = sum( 136 | task_info["correct"] for task_info in task_infos 137 | ) 138 | results["max_attempts"] = max( 139 | task_info["n_attempts"] for task_info in task_infos 140 | ) 141 | results["total_errors"] = sum(task_info["n_errors"] for task_info in task_infos) 142 | return results 143 | 144 | 145 | @weave.op 146 | def run_task( 147 | config: EvalEditMemoryConfig, 148 | agent: Agent, 149 | state: AgentState, 150 | expected_lines: list[str], 151 | task_idx: int, 152 | task_name: str, 153 | prompt: str, 154 | ): 155 | call = weave.get_current_call() 156 | if call: 157 | call.set_display_name(f"Task{task_idx}: {task_name}") 158 | print(f"*** TASK: {task_idx}, {prompt}") 159 | state = state.with_history( 160 | state.history 161 | + [ 162 | { 163 | "role": "user", 164 | "content": prompt, 165 | }, 166 | ], 167 | ) 168 | task_info = {"task_idx": task_idx} 169 | task_correct = False 170 | attempts = [] 171 | for attempt_idx in range(2): 172 | attempt_result = run_attempt(config, agent, state, expected_lines, attempt_idx) 173 | attempt_info = attempt_result["attempt_info"] 174 | state = attempt_result["state"] 175 | if attempt_info["correct"]: 176 | task_correct = True 177 | break 178 | 179 | attempts.append(attempt_info) 180 | 181 | print() 182 | print(f"*** FAILED ATTEMPT Task: {task_idx} Attempt: {attempt_idx}") 183 | print() 184 | state = state.with_history( 185 | state.history 186 | + [ 187 | { 188 | "role": "user", 189 | "content": "edit was incorrect, try again", 190 | }, 191 | ], 192 | ) 193 | task_info["correct"] = task_correct 194 | task_info["n_attempts"] = len(attempts) 195 | task_info["n_errors"] = sum(attempt_info["n_errors"] for attempt_info in attempts) 196 | task_info["n_messages"] = sum( 197 | attempt_info["n_messages"] for attempt_info in attempts 198 | ) 199 | 200 | return { 201 | "task_info": task_info, 202 | "state": state, 203 | } 204 | 205 | 206 | @weave.op 207 
| def run_attempt( 208 | config: EvalEditMemoryConfig, 209 | agent: Agent, 210 | state: AgentState, 211 | expected_lines: list[str], 212 | attempt_idx: int, 213 | ): 214 | call = weave.get_current_call() 215 | if call: 216 | call.set_display_name(f"Attempt{attempt_idx}") 217 | ctx = get_io_context() 218 | attempt_info: dict = { 219 | "attempt_idx": attempt_idx, 220 | "correct": False, 221 | "n_errors": 0, 222 | "n_messages": 0, 223 | } 224 | with open(ctx.resolve_path("file.txt"), "r") as f: 225 | prev_file_contents = f.read().strip() 226 | 227 | run_result, call = agent.run.call( 228 | agent, state, max_runtime_seconds=config["run_timeout_seconds"] 229 | ) 230 | attempt_info["n_errors"] = call_descendent_error_count(call) 231 | if call.exception is not None: 232 | print("*** EXCEPTION ***") 233 | print(call.exception) 234 | attempt_info["stop_reason"] = "exception" 235 | return { 236 | "attempt_info": attempt_info, 237 | "state": state, 238 | } 239 | attempt_info["stop_reason"] = run_result["stop_reason"] 240 | stop_reason = run_result["stop_reason"] 241 | if stop_reason == "time_limit_exceeded": 242 | print("*** TIME LIMIT EXCEEDED ***") 243 | 244 | next_state = run_result["state"] 245 | 246 | attempt_info["n_messages"] = len(next_state.history) - len(state.history) 247 | state = next_state 248 | 249 | with open(ctx.resolve_path("file.txt"), "r") as f: 250 | file_contents = f.read().strip() 251 | attempt_info["made_edit"] = file_contents != prev_file_contents 252 | 253 | file_lines = file_contents.split("\n") 254 | attempt_correct = file_contents == "\n".join(expected_lines) 255 | attempt_info["correct"] = attempt_correct 256 | attempt_info["error_details"] = mismatch_details(expected_lines, file_lines) 257 | 258 | return { 259 | "attempt_info": attempt_info, 260 | "state": state, 261 | } 262 | 263 | 264 | def mismatch_details(lines, file_lines): 265 | error_details = [] 266 | error_details.append("Incorrect edit") 267 | error_details.append("file.txt\texpected") 
268 | error_details.append(f"len={len(file_lines)}\tlen={len(lines)}") 269 | for i in range(len(max(lines, file_lines))): 270 | try: 271 | file_lines_i = file_lines[i] 272 | except IndexError: 273 | file_lines_i = None 274 | try: 275 | lines_i = lines[i] 276 | except IndexError: 277 | lines_i = None 278 | line_correct = file_lines_i == lines_i 279 | line_correct_str = "✅" if line_correct else "❌" 280 | error_details.append(f"{line_correct_str} {file_lines_i}\t{lines_i}") 281 | 282 | return "\n".join(error_details) 283 | 284 | 285 | @weave.op 286 | def run_trials( 287 | config: EvalEditMemoryConfig, 288 | agent: Agent, 289 | name: str, 290 | n_trials: int, 291 | max_workers: int = 16, 292 | ): 293 | call = weave.get_current_call() 294 | if call: 295 | call.set_display_name(name + f"_{n_trials}trials") 296 | current_call = call_context.get_current_call() 297 | if current_call is None: 298 | raise Exception("Should not happen, no current call") 299 | 300 | def run_single_trial(trial_idx: int): 301 | with call_context.current_call(current_call): 302 | start_time = time.time() 303 | result = eval_edit_memory(config, agent, name, trial_idx) 304 | duration = time.time() - start_time 305 | print(f"{name}: {result} {duration:.2f}s") 306 | return {**result, "duration": duration} 307 | 308 | results = [] 309 | with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: 310 | futures = [ 311 | executor.submit(run_single_trial, trial_idx) 312 | for trial_idx in range(n_trials) 313 | ] 314 | results = [ 315 | future.result() for future in concurrent.futures.as_completed(futures) 316 | ] 317 | 318 | return { 319 | "success": sum(result["success"] for result in results) / n_trials, 320 | "avg_errors": sum(result["total_errors"] for result in results) / n_trials, 321 | "avg_completed_tasks": sum(result["completed_tasks"] for result in results) 322 | / n_trials, 323 | "max_attempts": max(result["max_attempts"] for result in results), 324 | } 325 | 326 | 327 | if 
__name__ == "__main__": 328 | weave.init("programmerdev-eval-edits1") 329 | agents = [ 330 | # agent_4omini_basic, 331 | # agent_4o_basic, 332 | # agent_claude_basic, 333 | # agent_4o_replace, 334 | # agent_claude_replace, 335 | # agent_4o_splice, 336 | # agent_claude_splice, 337 | # agent_texteditor_4o_basic, 338 | # agent_texteditor_4o_basic_temp0, 339 | # agent_texteditor_4o_basic_noparalleltc, 340 | # agent_texteditor_o1_o1preview, 341 | agent_texteditor_o1_o1mini, 342 | ] 343 | 344 | config = EvalEditMemoryConfig(n_lines=100, run_timeout_seconds=600) 345 | n_trials = 1 346 | config_s = f'{config["n_lines"]}lines_{config["run_timeout_seconds"]}timeout' 347 | results = {} 348 | for agent in agents: 349 | run_name = f"{agent.name}_{config_s}" 350 | results[agent.name] = run_trials( 351 | config, agent, run_name, n_trials, max_workers=5 352 | ) 353 | from rich import print 354 | 355 | print(results) 356 | -------------------------------------------------------------------------------- /programmer/evaluate.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import weave 4 | 5 | from agent import AgentState 6 | from config import agent_4o_basic 7 | 8 | # Need to serialize AgentState as json for now since we can't save weave Objects 9 | # in Dataset set. 
10 | 11 | 12 | @weave.op() 13 | def rollout_len(model_output: str): 14 | final_state = AgentState(**json.loads(model_output)) 15 | return len(final_state.history) 16 | 17 | 18 | @weave.op() 19 | def final_answer_substr(expected_substr: str, model_output: str): 20 | final_state = AgentState(**json.loads(model_output)) 21 | final_message = final_state.history[-1] 22 | return expected_substr in final_message["content"] 23 | 24 | 25 | eval = weave.Evaluation( 26 | dataset=[ 27 | { 28 | "state": AgentState( 29 | history=[{"role": "user", "content": "what's in frog.jpeg"}] 30 | ).model_dump_json(), 31 | "expected_substr": "kitten", 32 | } 33 | ], 34 | scorers=[rollout_len, final_answer_substr], 35 | ) 36 | 37 | # Can't call a method with evaluation yet, so use this funky bridge function. 38 | # This also does our AgentState deserialization. 39 | 40 | 41 | @weave.op() 42 | def model_agent_bridge(state: str): 43 | return agent_4o_basic.run(AgentState(**json.loads(state))).model_dump_json() 44 | 45 | 46 | if __name__ == "__main__": 47 | weave.init_local_client() 48 | result = asyncio.run(eval.evaluate(model_agent_bridge)) 49 | -------------------------------------------------------------------------------- /programmer/file_protocol.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol 2 | 3 | 4 | class FileSystem(Protocol): 5 | def write_file(self, path: str, content: str) -> None: ... 6 | def read_file(self, path: str) -> str: ... 
7 | -------------------------------------------------------------------------------- /programmer/frog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/programmer/9b25bc7d4d4e4f4d685737de66147c248d1159d8/programmer/frog.jpg -------------------------------------------------------------------------------- /programmer/git.py: -------------------------------------------------------------------------------- 1 | import os 2 | from git import Repo, InvalidGitRepositoryError, GitCommandError 3 | from typing import Optional 4 | import tempfile 5 | 6 | 7 | class GitRepo: 8 | def __init__(self, repo: Repo): 9 | self.repo = repo 10 | 11 | @classmethod 12 | def from_current_dir(cls) -> Optional["GitRepo"]: 13 | try: 14 | repo = Repo(os.getcwd(), search_parent_directories=True) 15 | return cls(repo) 16 | except InvalidGitRepositoryError: 17 | return None 18 | 19 | def get_origin_url(self) -> Optional[str]: 20 | try: 21 | remote_url = self.repo.remotes.origin.url 22 | return remote_url if remote_url else None 23 | except AttributeError: 24 | return None 25 | 26 | def create_branch(self, branch_name: str) -> None: 27 | if branch_name not in self.repo.heads: 28 | # Create a new branch from the current HEAD 29 | self.repo.git.branch(branch_name, self.repo.head.commit.hexsha) 30 | 31 | def checkout_existing(self, branch_name: str) -> None: 32 | self.repo.git.checkout(branch_name) 33 | 34 | def commit_directly_to_branch(self, branch_name: str, message: str) -> str: 35 | # Ensure the branch is initialized 36 | self.create_branch(branch_name) 37 | 38 | # Use a temporary index file to stage files without affecting the actual index. 
39 | with tempfile.TemporaryDirectory() as temp_dir: 40 | temp_index_file = os.path.join(temp_dir, "index") 41 | env = os.environ.copy() 42 | env["GIT_INDEX_FILE"] = temp_index_file 43 | 44 | # Add all files from the working directory to the temporary index 45 | self.repo.git.add(A=True, env=env) 46 | 47 | # Determine the parent commit 48 | parent_commit = self.repo.commit(branch_name) 49 | 50 | # Check for changes between parent_commit and the temporary index 51 | diff_output = self.repo.git.diff(parent_commit.hexsha, "--cached", env=env) 52 | 53 | if not diff_output.strip(): 54 | # No changes to commit 55 | return parent_commit.hexsha 56 | 57 | # Write the tree from the temporary index 58 | tree = self.repo.git.write_tree(env=env) 59 | 60 | # print( 61 | # f"Committing to branch {branch_name}, parent commit: {parent_commit.hexsha}" 62 | # ) 63 | 64 | # Set author information using environment variables 65 | env["GIT_AUTHOR_NAME"] = "programmer" 66 | env["GIT_AUTHOR_EMAIL"] = "programmer-noreply@example.com" 67 | 68 | # Use the Repo's git command interface to create a commit-tree 69 | commit_hash = self.repo.git.commit_tree( 70 | tree, "-p", parent_commit.hexsha, "-m", message, env=env 71 | ) 72 | 73 | # Update the branch reference to point to the new commit 74 | self.repo.git.update_ref(f"refs/heads/{branch_name}", commit_hash) 75 | 76 | return commit_hash 77 | 78 | def get_current_head(self) -> str: 79 | if self.repo.head.is_detached: 80 | return str(self.repo.head.commit.hexsha) 81 | else: 82 | return str(self.repo.active_branch.name) 83 | -------------------------------------------------------------------------------- /programmer/io_context.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, TypedDict 2 | import os 3 | import subprocess 4 | import requests 5 | import shlex 6 | from contextlib import contextmanager 7 | from contextvars import ContextVar 8 | from typing import Optional, Union 9 | 
10 | 11 | class RunCommandResult(TypedDict): 12 | exit_code: int 13 | output: str 14 | 15 | 16 | class IOContext(Protocol): 17 | def write_file(self, path: str, content: str) -> None: ... 18 | 19 | def read_file(self, path: str) -> str: ... 20 | 21 | def run_command(self, command: str) -> RunCommandResult: ... 22 | 23 | def resolve_path(self, path: str) -> str: ... 24 | 25 | 26 | class LocalIOContext(IOContext): 27 | def __init__(self, directory): 28 | self.directory = os.path.abspath(directory) 29 | 30 | def write_file(self, path: str, content: str) -> None: 31 | full_path = self.resolve_path(path) 32 | with open(full_path, "w") as f: 33 | f.write(content) 34 | 35 | def read_file(self, path: str) -> str: 36 | full_path = self.resolve_path(path) 37 | with open(full_path, "r") as f: 38 | return f.read() 39 | 40 | def run_command(self, command: str) -> RunCommandResult: 41 | completed_process = subprocess.run( 42 | command, 43 | stdout=subprocess.PIPE, 44 | stderr=subprocess.STDOUT, 45 | text=True, 46 | shell=True, 47 | cwd=self.directory, 48 | ) 49 | exit_code = completed_process.returncode 50 | output = completed_process.stdout.strip() 51 | 52 | return { 53 | "exit_code": exit_code, 54 | "output": output, 55 | } 56 | 57 | def resolve_path(self, path: str) -> str: 58 | return os.path.join(self.directory, path) 59 | 60 | 61 | class RemoteContainerIOContext(IOContext): 62 | def __init__(self, base_url: str, directory: str, command_prefix: str): 63 | self.base_url = base_url 64 | self.container_id = None 65 | self.directory = directory 66 | self.command_prefix = command_prefix 67 | 68 | @contextmanager 69 | def context(self, image_id: str): 70 | self.start_container(image_id) 71 | try: 72 | with io_context(self): 73 | yield 74 | finally: 75 | self.stop_container() 76 | 77 | def start_container(self, image_id): 78 | response = requests.post( 79 | f"{self.base_url}/container/start", json={"image_id": image_id} 80 | ) 81 | if response.status_code == 200: 82 | 
self.container_id = response.json().get("container_id") 83 | else: 84 | print(f"Failed to start container: {response.text}") 85 | 86 | def stop_container(self): 87 | response = requests.post( 88 | f"{self.base_url}/container/stop", 89 | json={"container_id": self.container_id, "delete": True}, 90 | ) 91 | if response.status_code == 200: 92 | self.container_id = None 93 | else: 94 | print(f"Failed to stop container: {response.text}") 95 | 96 | def write_file(self, path: str, content: str) -> None: 97 | full_path = os.path.join(self.directory, path) 98 | response = requests.post( 99 | f"{self.base_url}/container/write_file", 100 | json={ 101 | "container_id": self.container_id, 102 | "file_path": full_path, 103 | "file_content": content, 104 | }, 105 | ) 106 | if response.status_code != 200: 107 | raise Exception(f"Failed to write file: {response.text}") 108 | 109 | def read_file(self, path: str) -> str: 110 | full_path = os.path.join(self.directory, path) 111 | response = requests.post( 112 | f"{self.base_url}/container/read_file", 113 | json={"container_id": self.container_id, "file_path": full_path}, 114 | ) 115 | if response.status_code == 200: 116 | return response.json().get("file_content") 117 | else: 118 | raise Exception(f"Failed to read file: {response.text}") 119 | 120 | def run_command(self, command: str) -> RunCommandResult: 121 | command = self.command_prefix + command 122 | command = f"bash -c {shlex.quote(command)}" 123 | response = requests.post( 124 | f"{self.base_url}/container/run", 125 | json={ 126 | "container_id": self.container_id, 127 | "workdir": self.directory, 128 | "command": command, 129 | }, 130 | ) 131 | if response.status_code == 200: 132 | json = response.json() 133 | return { 134 | "exit_code": json["exit_code"], 135 | "output": json["output"], 136 | } 137 | else: 138 | raise Exception(f"Failed to run command: {response.text}") 139 | 140 | def resolve_path(self, path: str) -> str: 141 | return path # For remote containers, we assume 
paths are already resolved 142 | 143 | 144 | # Create a ContextVar to store the current ToolContext 145 | _io_context: ContextVar[Optional[Union[LocalIOContext, RemoteContainerIOContext]]] = ( 146 | ContextVar("_io_context", default=None) 147 | ) 148 | 149 | 150 | @contextmanager 151 | def io_context(context: Union[LocalIOContext, RemoteContainerIOContext]): 152 | token = _io_context.set(context) 153 | try: 154 | yield context 155 | finally: 156 | _io_context.reset(token) 157 | 158 | 159 | def get_io_context() -> Union[LocalIOContext, RemoteContainerIOContext]: 160 | context = _io_context.get() 161 | if context is None: 162 | return LocalIOContext(".") 163 | return context 164 | -------------------------------------------------------------------------------- /programmer/programmer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | import subprocess 5 | 6 | from rich import print 7 | from rich.console import Console 8 | 9 | from typing import Any, Optional 10 | 11 | import weave 12 | 13 | from .agent import Agent, AgentState, get_commit_message 14 | from .console import Console, console 15 | from .config import ( 16 | # agent_4o_replace, 17 | # agent_texteditor_4o_basic, 18 | # agent_texteditor_o1_gpt4o, 19 | # agent_texteditor_o1_o1preview, 20 | get_config_by_name, 21 | ) 22 | from .environment import ( 23 | environment_session, 24 | restore_environment, 25 | get_current_environment, 26 | GitEnvironment, 27 | NoopEnvironment, 28 | ) 29 | from .weave_next.api import init_local_client 30 | from .settings_manager import SettingsManager 31 | 32 | from .git import GitRepo 33 | 34 | 35 | @weave.op 36 | def get_user_input(): 37 | return input("User input: ") 38 | 39 | 40 | @weave.op 41 | def user_input_step(state: AgentState) -> AgentState: 42 | Console.step_start("user_input", "purple") 43 | # Printing this is ugly 44 | # ref = weave.obj_ref(state) 45 | # if ref: 46 | # print("state 
ref:", ref.uri()) 47 | user_input = get_user_input() 48 | history = state.history + [ 49 | { 50 | "role": "user", 51 | "content": user_input, 52 | } 53 | ] 54 | return state.with_history(history) 55 | 56 | 57 | def settings_command(command_args): 58 | if len(command_args) < 2: 59 | console.print("[red]Invalid settings command[/red]") 60 | return 61 | action = command_args[0] 62 | key = command_args[1] 63 | if action == "get": 64 | value = SettingsManager.get_setting(key) 65 | if value is not None: 66 | console.print(f"{key} = {value}") 67 | else: 68 | console.print(f"[red]Setting '{key}' not found[/red]") 69 | elif action == "set" and len(command_args) == 3: 70 | value = command_args[2] 71 | SettingsManager.set_setting(key, value) 72 | console.print(f"[green]Setting '{key}' updated to '{value}'[/green]") 73 | else: 74 | console.print("[red]Invalid settings command[/red]") 75 | 76 | 77 | def make_environment(): 78 | git_repo = GitRepo.from_current_dir() 79 | git_tracking_enabled = SettingsManager.get_setting("git_tracking") == "on" 80 | if git_tracking_enabled and git_repo: 81 | env = GitEnvironment(git_repo) 82 | else: 83 | env = NoopEnvironment() 84 | return env 85 | 86 | 87 | @weave.op 88 | def session(agent: Agent, agent_state: AgentState): 89 | call = weave.get_current_call() 90 | 91 | session_id = None 92 | if call: 93 | session_id = call.id 94 | 95 | env = make_environment() 96 | 97 | with environment_session(env, session_id): 98 | agent_state = agent_state.with_history(agent_state.history) 99 | while True: 100 | result = agent.run(agent_state) 101 | agent_state = result["state"] 102 | agent_state = user_input_step(agent_state) 103 | 104 | 105 | def programmer(): 106 | parser = argparse.ArgumentParser(description="Programmer") 107 | subparsers = parser.add_subparsers(dest="command") 108 | 109 | # Subparser for the settings command 110 | settings_parser = subparsers.add_parser("settings", help="Manage settings") 111 | settings_parser.add_argument( 112 | 
"action", choices=["get", "set"], help="Action to perform" 113 | ) 114 | settings_parser.add_argument("key", help="The setting key") 115 | settings_parser.add_argument("value", nargs="?", help="The value to set") 116 | 117 | ui_parser = subparsers.add_parser("ui", help="Run the local UI") 118 | 119 | # Subparser for the prompt command 120 | prompt_parser = subparsers.add_parser( 121 | "prompt", help="Send initial prompt to the LLM" 122 | ) 123 | prompt_parser.add_argument( 124 | "prompt_args", nargs=argparse.REMAINDER, help="The prompt to send" 125 | ) 126 | 127 | parser.add_argument( 128 | "--state", type=str, help="weave ref of the state to begin from" 129 | ) 130 | 131 | # Initialize settings 132 | SettingsManager.initialize_settings() 133 | logging_mode = SettingsManager.get_setting("weave_logging") 134 | if logging_mode == "cloud": 135 | curdir = os.path.basename(os.path.abspath(os.curdir)) 136 | weave.init(f"programmer-{curdir}") 137 | elif logging_mode == "local": 138 | init_local_client(os.path.join(SettingsManager.PROGRAMMER_DIR, "weave.db")) 139 | 140 | args = parser.parse_args() 141 | 142 | if args.command == "settings": 143 | settings_command( 144 | [args.action, args.key, args.value] 145 | if args.value 146 | else [args.action, args.key] 147 | ) 148 | return 149 | elif args.command == "ui": 150 | module_path = os.path.abspath(__file__) 151 | module_dir = os.path.dirname(module_path) 152 | ui_path = os.path.join(module_dir, "..", "programmer-ui", "ui.py") 153 | subprocess.run(["streamlit", "run", ui_path]) 154 | return 155 | elif args.command == "prompt": 156 | # Handled later. 157 | pass 158 | 159 | # log to local sqlite db for now 160 | 161 | if args.state: 162 | state = weave.ref(args.state).get() 163 | if state.env_snapshot_key: 164 | environment = restore_environment(state.env_snapshot_key) 165 | 166 | agent_name = SettingsManager.get_setting("agent") 167 | if not agent_name: 168 | raise ValueError( 169 | "No agent name set. 
Please set the agent name in the settings." 170 | ) 171 | agent = get_config_by_name(agent_name) 172 | if not agent: 173 | raise ValueError( 174 | f"Agent {agent_name} not found. Please set a valid agent name in the settings." 175 | ) 176 | 177 | Console.welcome(agent_name=agent.name) 178 | 179 | if args.command == "prompt": 180 | initial_prompt = " ".join(args.prompt_args) 181 | print("Initial prompt:", initial_prompt) 182 | else: 183 | initial_prompt = input("Initial prompt: ") 184 | 185 | state = agent.initial_state( 186 | [ 187 | { 188 | "role": "user", 189 | "content": initial_prompt, 190 | }, 191 | ] 192 | ) 193 | 194 | session(agent, state) 195 | 196 | 197 | def main(): 198 | try: 199 | programmer() 200 | except KeyboardInterrupt: 201 | pass 202 | 203 | 204 | if __name__ == "__main__": 205 | main() 206 | -------------------------------------------------------------------------------- /programmer/settings_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .config import get_all_config_names 3 | 4 | 5 | class SettingsError(Exception): 6 | pass 7 | 8 | 9 | class SettingsManager: 10 | PROGRAMMER_DIR = ".programmer" 11 | SETTINGS_FILE = "settings" 12 | DEFAULT_SETTINGS = { 13 | "agent": "gpt-4o-2024-08-06_basic", 14 | "weave_logging": "local", 15 | "git_tracking": "off", 16 | } 17 | ALLOWED_VALUES = { 18 | "agent": get_all_config_names(), 19 | "weave_logging": ["off", "local", "cloud"], 20 | "git_tracking": ["off", "on"], 21 | } 22 | 23 | @classmethod 24 | def set_settings_dir(cls, dir_path): 25 | cls.PROGRAMMER_DIR = dir_path 26 | 27 | @staticmethod 28 | def initialize_settings(): 29 | """ 30 | Ensure that the settings directory and file exist, and populate missing settings with defaults. 
31 | """ 32 | # Import GitRepo from git module 33 | from .git import GitRepo 34 | 35 | # Check if we're in a Git repository 36 | settings_dir = None 37 | git_repo = GitRepo.from_current_dir() 38 | if git_repo: 39 | # If in a Git repo, set the settings directory to the repo root 40 | repo_root = git_repo.repo.working_tree_dir 41 | if repo_root: 42 | settings_dir = os.path.join(repo_root, SettingsManager.PROGRAMMER_DIR) 43 | if not settings_dir: 44 | # use abs path 45 | settings_dir = os.path.abspath(SettingsManager.PROGRAMMER_DIR) 46 | 47 | SettingsManager.PROGRAMMER_DIR = settings_dir 48 | 49 | if not os.path.exists(SettingsManager.PROGRAMMER_DIR): 50 | os.makedirs(SettingsManager.PROGRAMMER_DIR) 51 | settings_path = os.path.join( 52 | SettingsManager.PROGRAMMER_DIR, SettingsManager.SETTINGS_FILE 53 | ) 54 | if not os.path.exists(settings_path): 55 | SettingsManager.write_default_settings() 56 | else: 57 | SettingsManager.validate_and_complete_settings() 58 | 59 | @staticmethod 60 | def validate_and_complete_settings(): 61 | """ 62 | Validate the settings file format and complete it with default values if necessary. 63 | """ 64 | settings_path = os.path.join( 65 | SettingsManager.PROGRAMMER_DIR, SettingsManager.SETTINGS_FILE 66 | ) 67 | with open(settings_path, "r") as f: 68 | lines = f.readlines() 69 | 70 | settings = {} 71 | for line in lines: 72 | if "=" not in line: 73 | raise SettingsError( 74 | f"Malformed settings line: '{line.strip()}'.\n" 75 | f"Please ensure each setting is in 'key=value' format.\n" 76 | f"Settings file location: {settings_path}" 77 | ) 78 | key, value = line.strip().split("=", 1) 79 | if ( 80 | key in SettingsManager.ALLOWED_VALUES 81 | and value not in SettingsManager.ALLOWED_VALUES[key] 82 | ): 83 | raise SettingsError( 84 | f"Invalid value '{value}' for setting '{key}'. 
Allowed values are: {SettingsManager.ALLOWED_VALUES[key]}\n" 85 | f"Settings file location: {settings_path}" 86 | ) 87 | settings[key] = value 88 | 89 | # Add missing default settings 90 | for key, default_value in SettingsManager.DEFAULT_SETTINGS.items(): 91 | if key not in settings: 92 | settings[key] = default_value 93 | 94 | # Rewrite the settings file with complete settings 95 | with open(settings_path, "w") as f: 96 | for key, value in settings.items(): 97 | f.write(f"{key}={value}\n") 98 | 99 | @staticmethod 100 | def write_default_settings(): 101 | """ 102 | Write the default settings to the settings file. 103 | """ 104 | settings_path = os.path.join( 105 | SettingsManager.PROGRAMMER_DIR, SettingsManager.SETTINGS_FILE 106 | ) 107 | with open(settings_path, "w") as f: 108 | for key, value in SettingsManager.DEFAULT_SETTINGS.items(): 109 | f.write(f"{key}={value}\n") 110 | 111 | @staticmethod 112 | def get_setting(key): 113 | """ 114 | Retrieve a setting's value by key. 115 | """ 116 | settings_path = os.path.join( 117 | SettingsManager.PROGRAMMER_DIR, SettingsManager.SETTINGS_FILE 118 | ) 119 | if not os.path.exists(settings_path): 120 | return None 121 | with open(settings_path, "r") as f: 122 | for line in f.readlines(): 123 | if line.startswith(key): 124 | return line.split("=")[1].strip() 125 | return None 126 | 127 | @staticmethod 128 | def set_setting(key, value): 129 | """ 130 | Set a setting's value by key, validating allowed values. 131 | """ 132 | settings_path = os.path.join( 133 | SettingsManager.PROGRAMMER_DIR, SettingsManager.SETTINGS_FILE 134 | ) 135 | if ( 136 | key in SettingsManager.ALLOWED_VALUES 137 | and value not in SettingsManager.ALLOWED_VALUES[key] 138 | ): 139 | raise SettingsError( 140 | f"Invalid value '{value}' for setting '{key}'. 
Allowed values are: {SettingsManager.ALLOWED_VALUES[key]}\n" 141 | f"Settings file location: {settings_path}" 142 | ) 143 | 144 | lines = [] 145 | found = False 146 | if os.path.exists(settings_path): 147 | with open(settings_path, "r") as f: 148 | lines = f.readlines() 149 | for i, line in enumerate(lines): 150 | if line.startswith(key): 151 | lines[i] = f"{key}={value}\n" 152 | found = True 153 | break 154 | if not found: 155 | lines.append(f"{key}={value}\n") 156 | with open(settings_path, "w") as f: 157 | f.writelines(lines) 158 | -------------------------------------------------------------------------------- /programmer/swebench/README.md: -------------------------------------------------------------------------------- 1 | # SWE Bench programmer evaluation 2 | 3 | This is a custom setup to run fast SWE-bench evals on programmer. The steps are: 4 | - serve swebench docker containers from a remote machine 5 | - setup an x86 machine (I use a gcp e2-standard-32) 6 | - build the swebench instance images. For SWE-bench_Verified this builds about 550 images. 7 | - run [containerserver](../containerserver/README.md) on the machine. containerserver serves an HTTP interface into the Docker containers. 8 | - on your local machine, run python -m programmer.swebench.run_instance or python -m programmer.swebench.evaluate 9 | 10 | ## Build SWE-bench images 11 | 12 | First do setup (below) then run this command to build all the images. --cache_level instance tells the script not to delete the instance images, which are what we want to use with container-manager. 13 | 14 | ``` 15 | python -m swebench.harness.run_evaluation \ 16 | --predictions_path gold \ 17 | --max_workers 24 \ 18 | --run_id validate-gold \ 19 | --dataset_name princeton-nlp/SWE-bench_Verified \ 20 | --cache_level instance 21 | ``` 22 | 23 | ## Run containerserver 24 | 25 | See [containerserver](../containerserver/README.md) for setup and running containerserver. 
26 | 27 | 28 | ## remote machine setup instructions on gcp VM ubuntu 20.04 29 | 30 | ``` 31 | 32 | sudo snap install docker 33 | sudo groupadd docker 34 | sudo usermod -aG docker $USER 35 | sudo chown root:docker /var/run/docker.sock 36 | sudo chmod 660 /var/run/docker.sock 37 | 38 | sudo apt update 39 | sudo apt install -y \ 40 | build-essential \ 41 | libbz2-dev \ 42 | libreadline-dev \ 43 | libssl-dev \ 44 | zlib1g-dev \ 45 | libsqlite3-dev \ 46 | libffi-dev \ 47 | libncursesw5-dev \ 48 | libgdbm-dev \ 49 | liblzma-dev \ 50 | tk-dev \ 51 | libdb-dev \ 52 | libexpat1-dev \ 53 | libmpdec-dev \ 54 | libxml2-dev \ 55 | libxmlsec1-dev \ 56 | libffi-dev \ 57 | liblzma-dev 58 | 59 | # pyenv 60 | curl https://pyenv.run | bash 61 | echo 'export PYENV_ROOT="$HOME/.pyenv" 62 | [[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH" 63 | eval "$(pyenv init -)" 64 | eval "$(pyenv virtualenv-init -)"' >> ~/.bashrc 65 | ## exit and re-log in 66 | 67 | pyenv install 3.10.12 68 | pyenv virtualenv 3.10.12 swe-bench 69 | 70 | git clone https://github.com/princeton-nlp/SWE-bench.git 71 | cd SWE-bench 72 | pyenv local swe-bench 73 | pip install -e . 
74 | ``` -------------------------------------------------------------------------------- /programmer/swebench/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/programmer/9b25bc7d4d4e4f4d685737de66147c248d1159d8/programmer/swebench/__init__.py -------------------------------------------------------------------------------- /programmer/swebench/data/swebench-verified.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wandb/programmer/9b25bc7d4d4e4f4d685737de66147c248d1159d8/programmer/swebench/data/swebench-verified.parquet -------------------------------------------------------------------------------- /programmer/swebench/evaluate.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import pandas as pd 3 | from typing import Optional 4 | import random 5 | import weave 6 | 7 | from .swebench_model import SWEBenchProgrammerModel 8 | from .score import score_swebench 9 | from ..agent import Agent 10 | from ..config import agent_4o_basic 11 | 12 | 13 | def load_raw_dataset(name: str, split: str): 14 | return pd.read_parquet( 15 | f"hf://datasets/princeton-nlp/{name}/data/{split}-00000-of-00001.parquet" 16 | ) 17 | 18 | 19 | def load_weave_dataset( 20 | name: str, 21 | split: str, 22 | limit: Optional[int] = None, 23 | instance_ids: Optional[list[str]] = None, 24 | shuffle_seed: Optional[int] = None, 25 | ): 26 | df = load_raw_dataset(name, split) 27 | 28 | data_list = df.to_dict("records") 29 | if shuffle_seed is not None: 30 | random.seed(shuffle_seed) 31 | random.shuffle(data_list) 32 | data_list = [ 33 | r for r in data_list if instance_ids is None or r["instance_id"] in instance_ids 34 | ] 35 | data_list = data_list[:limit] if limit else data_list 36 | data_list = [{"instance": r} for r in data_list] 37 | 38 | return 
weave.Dataset(name=f"Verified-{limit}-{shuffle_seed}", rows=data_list) # type: ignore 39 | 40 | 41 | def main(): 42 | weave.init("weavedev-programmereval1") 43 | instance_ids = [ 44 | "django__django-16569", 45 | "django__django-11099", 46 | "scikit-learn__scikit-learn-12585", 47 | "django__django-13658", 48 | "django__django-9296", 49 | "astropy__astropy-14309", 50 | "django__django-12155", 51 | "django__django-16527", 52 | "sympy__sympy-24213", 53 | "django__django-11066", 54 | ] 55 | # ds = load_weave_dataset("SWE-bench_Verified", "test", instance_ids=instance_ids) 56 | ds = load_weave_dataset("SWE-bench_Verified", "test", limit=50, shuffle_seed=42) 57 | eval = weave.Evaluation( 58 | name="SWE-bench_Verified", dataset=ds, scorers=[score_swebench], trials=1 59 | ) 60 | 61 | model = SWEBenchProgrammerModel( 62 | agent=agent_4o_basic, 63 | max_runtime_seconds=180, 64 | ) 65 | res = asyncio.run(eval.evaluate(model)) 66 | print("RES", res) 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /programmer/swebench/ingest/README.md: -------------------------------------------------------------------------------- 1 | # weave swe-bench eval ingestion 2 | 3 | these scripts slurp existing https://github.com/swe-bench/experiments results into weave evals 4 | -------------------------------------------------------------------------------- /programmer/swebench/ingest/ingest_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import asyncio 5 | import json 6 | import contextvars 7 | from rich import print 8 | 9 | import weave 10 | from make_dataset import load_weave_dataset 11 | 12 | 13 | context_var = contextvars.ContextVar("context", default={}) 14 | 15 | 16 | def load_instance_eval_file( 17 | experiments_repo_path, dataset_name, model_name, instance_id, file_name 18 | ): 19 | dataset_name_short = 
dataset_name.split("_")[1].lower() 20 | file_path = os.path.join( 21 | experiments_repo_path, 22 | "evaluation", 23 | dataset_name_short, 24 | model_name, 25 | "logs", 26 | instance_id, 27 | file_name, 28 | ) 29 | print(f"Loading file: {file_path}") 30 | 31 | if os.path.exists(file_path): 32 | with open(file_path, "r") as file: 33 | return file.read() 34 | else: 35 | return None 36 | 37 | 38 | def load_instance_eval_from_logs( 39 | experiments_repo_path, dataset_name, model_name, instance_id 40 | ): 41 | report_json_file = load_instance_eval_file( 42 | experiments_repo_path, 43 | dataset_name, 44 | model_name, 45 | instance_id, 46 | "report.json", 47 | ) 48 | report_json = None 49 | if report_json_file is not None: 50 | report_json = json.loads(report_json_file).get(instance_id) 51 | no_report = False 52 | if report_json is None: 53 | no_report = True 54 | 55 | return { 56 | "patch": load_instance_eval_file( 57 | experiments_repo_path, dataset_name, model_name, instance_id, "patch.diff" 58 | ), 59 | "report": report_json, 60 | "no_report": no_report, 61 | } 62 | 63 | 64 | def load_instance_eval_from_results( 65 | experiments_repo_path, dataset_name, model_name, instance_id 66 | ): 67 | dataset_name_short = dataset_name.split("_")[1].lower() 68 | file_path = os.path.join( 69 | experiments_repo_path, 70 | "evaluation", 71 | dataset_name_short, 72 | model_name, 73 | "results", 74 | "results.json", 75 | ) 76 | with open(file_path, "r") as file: 77 | results = json.loads(file.read()) 78 | summary = {} 79 | for k, instance_ids in results.items(): 80 | summary[k] = instance_id in instance_ids 81 | 82 | return summary 83 | 84 | 85 | class SWEBenchOfflineModel(weave.Model): 86 | @weave.op 87 | def predict(self, instance_id: str): 88 | context = context_var.get() 89 | experiments_repo_path = context.get("experiments_repo_path") 90 | dataset_name = context.get("dataset_name") 91 | return load_instance_eval_from_results( 92 | experiments_repo_path, dataset_name, self.name, 
instance_id 93 | ) 94 | 95 | 96 | @weave.op 97 | def score_from_logs(model_output: dict): 98 | result = {} 99 | if model_output.get("report"): 100 | result.update(model_output["report"]) 101 | result["no_report"] = model_output["no_report"] 102 | return result 103 | 104 | 105 | @weave.op 106 | def score(model_output: dict): 107 | return model_output 108 | 109 | 110 | def ingest_eval(experiments_repo_path, dataset_name, model_name): 111 | print(f"Ingesting evaluation logs for:") 112 | print(f"Dataset: {dataset_name}") 113 | print(f"Model: {model_name}") 114 | print(f"From repository: {experiments_repo_path}") 115 | 116 | dataset = load_weave_dataset(dataset_name, "test") 117 | eval = weave.Evaluation(name=dataset_name, dataset=dataset, scorers=[score]) 118 | 119 | context_var.set( 120 | { 121 | "experiments_repo_path": experiments_repo_path, 122 | "dataset_name": dataset_name, 123 | } 124 | ) 125 | 126 | model = SWEBenchOfflineModel(name=model_name) 127 | # result, call = asyncio.run(eval.evaluate.call(eval, model)) 128 | result = asyncio.run(eval.evaluate(model)) 129 | 130 | print(result) 131 | # call.set_display_name(model_name) 132 | 133 | 134 | def ingest_evals(experiments_repo_path, dataset_name): 135 | dataset_name_short = dataset_name.split("_")[1].lower() 136 | models_dir = os.path.join(experiments_repo_path, "evaluation", dataset_name_short) 137 | for model_name in os.listdir(models_dir): 138 | ingest_eval(experiments_repo_path, dataset_name, model_name) 139 | 140 | 141 | def main(): 142 | parser = argparse.ArgumentParser(description="Ingest evaluation logs into Weave.") 143 | parser.add_argument( 144 | "--experiments_repo_path", help="Path to the experiments repository" 145 | ) 146 | parser.add_argument( 147 | "--dataset_name", 148 | choices=["SWE-bench", "SWE-bench_Verified", "SWE-bench_Lite"], 149 | default="SWE-bench_Verified", 150 | help="Name of the dataset", 151 | ) 152 | parser.add_argument("--model_name", help="Name of the model") 153 | 154 | args 
= parser.parse_args() 155 | 156 | if not args.experiments_repo_path or not os.path.exists(args.experiments_repo_path): 157 | print( 158 | f"Error: Experiments repository path does not exist: {args.experiments_repo_path}" 159 | ) 160 | sys.exit(1) 161 | 162 | # Initialize Weave 163 | weave.init("weavedev-swebench5") 164 | 165 | if args.model_name: 166 | ingest_eval(args.experiments_repo_path, args.dataset_name, args.model_name) 167 | else: 168 | ingest_evals(args.experiments_repo_path, args.dataset_name) 169 | 170 | 171 | if __name__ == "__main__": 172 | main() 173 | # from rich import print 174 | 175 | # print( 176 | # load_instance_eval( 177 | # "/Users/shawnlewis/code/experiments", 178 | # "SWE-bench_Verified", 179 | # "20240620_sweagent_claude3.5sonnet", 180 | # "sympy__sympy-24661", 181 | # ) 182 | # ) 183 | -------------------------------------------------------------------------------- /programmer/swebench/ingest/make_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from typing import Optional 4 | import pandas as pd 5 | import weave 6 | 7 | 8 | splits = { 9 | "dev": "data/dev-00000-of-00001.parquet", 10 | "test": "data/test-00000-of-00001.parquet", 11 | "train": "data/train-00000-of-00001.parquet", 12 | } 13 | 14 | 15 | def load_raw_dataset(name: str, split: str): 16 | return pd.read_parquet( 17 | f"hf://datasets/princeton-nlp/{name}/data/{split}-00000-of-00001.parquet" 18 | ) 19 | 20 | 21 | def load_weave_dataset(name: str, split: str, limit: Optional[int] = None): 22 | df = load_raw_dataset(name, split) 23 | 24 | data_list = df.to_dict("records") 25 | data_list = data_list[:limit] if limit else data_list 26 | 27 | return weave.Dataset(name=f"Verified-{limit}", rows=data_list) # type: ignore 28 | 29 | 30 | def main(dataset_name="SWE-bench_Verified", split="test"): 31 | valid_datasets = ["SWE-bench", "SWE-bench_Verified", "SWE-bench_Lite"] 32 | valid_splits = ["dev", "test", 
"train"] 33 | 34 | if dataset_name not in valid_datasets: 35 | print(f"Error: Invalid dataset name. Choose from {', '.join(valid_datasets)}") 36 | sys.exit(1) 37 | 38 | if split not in valid_splits: 39 | print(f"Error: Invalid split. Choose from {', '.join(valid_splits)}") 40 | sys.exit(1) 41 | 42 | print(f"Creating dataset: {dataset_name}") 43 | print(f"Split: {split}") 44 | 45 | weave.init("weavedev-swebench1") 46 | 47 | df = load_raw_dataset(dataset_name, split) 48 | 49 | data_list = df.to_dict("records") 50 | 51 | dataset = weave.Dataset(rows=data_list) # type: ignore 52 | 53 | weave.publish(dataset, f"{dataset_name}_{split}") 54 | 55 | print(f"Dataset '{dataset_name}_{split}' created and saved successfully.") 56 | 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser( 60 | description="Create a dataset with specified name and split." 61 | ) 62 | parser.add_argument( 63 | "--dataset_name", 64 | choices=["SWE-bench", "SWE-bench_Verified", "SWE-bench_Lite"], 65 | default="SWE-bench_Verified", 66 | help="Name of the dataset to create", 67 | ) 68 | parser.add_argument( 69 | "--split", 70 | choices=["dev", "test", "train"], 71 | default="test", 72 | help="Split of the dataset to create", 73 | ) 74 | 75 | args = parser.parse_args() 76 | main(args.dataset_name, args.split) 77 | -------------------------------------------------------------------------------- /programmer/swebench/ingest/requirements.txt: -------------------------------------------------------------------------------- 1 | weave 2 | pandas 3 | fsspec 4 | huggingface_hub -------------------------------------------------------------------------------- /programmer/swebench/run_instance.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | 5 | from rich import print 6 | 7 | import weave 8 | 9 | from ..weave_next.api import init_local_client 10 | from ..settings_manager import SettingsManager 11 | 
12 | from ..swebench.swebench_model import SWEBenchProgrammerModel 13 | from ..swebench.score import score_swebench 14 | from ..config import agent_4o_basic 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser(description="Programmer") 19 | parser.add_argument( 20 | "--instance_id", type=str, help="The instance id to run", required=True 21 | ) 22 | 23 | # Initialize settings 24 | SettingsManager.initialize_settings() 25 | logging_mode = SettingsManager.get_setting("weave_logging") 26 | if logging_mode == "cloud": 27 | curdir = os.path.basename(os.path.abspath(os.curdir)) 28 | weave.init(f"programmer-{curdir}") 29 | elif logging_mode == "local": 30 | init_local_client(os.path.join(SettingsManager.PROGRAMMER_DIR, "weave.db")) 31 | 32 | args = parser.parse_args() 33 | 34 | df = pd.read_parquet("programmer/swebench/data/swebench-verified.parquet") 35 | 36 | instance_id = args.instance_id 37 | instance = df[df["instance_id"] == instance_id].iloc[0] 38 | problem_statement = instance["problem_statement"] 39 | 40 | print("PROBLEM STATEMENT\n", problem_statement) 41 | print() 42 | print("SOLUTION\n", instance["patch"]) 43 | print() 44 | 45 | model = SWEBenchProgrammerModel(agent=agent_4o_basic) 46 | model_output = model.predict(instance) 47 | score = score_swebench(instance, model_output["answer"]) 48 | print("SCORE\n", score) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /programmer/swebench/score.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from swebench.harness.test_spec import make_test_spec 3 | from swebench.harness.log_parsers import MAP_REPO_TO_PARSER 4 | from swebench.harness.grading import get_eval_tests_report, get_resolution_status 5 | from swebench.harness.constants import ( 6 | FAIL_TO_PASS, 7 | KEY_INSTANCE_ID, 8 | PASS_TO_PASS, 9 | ResolvedStatus, 10 | SWEbenchInstance, 11 | ) 12 | 13 
| from ..io_context import RemoteContainerIOContext 14 | 15 | 16 | def score_swebench(instance: SWEbenchInstance, model_output): 17 | patch = model_output["answer"] 18 | tc = RemoteContainerIOContext( 19 | "http://localhost:8000", 20 | "/testbed", 21 | "source /opt/miniconda3/bin/activate && conda activate testbed && ", 22 | ) 23 | 24 | result: dict[str, Any] = {"patch_successfully_applied": False, "resolved": False} 25 | 26 | ts = make_test_spec(instance) 27 | container_id = f"sweb.eval.x86_64.{ts.instance_id}" 28 | with tc.context(container_id): 29 | print("EVAL SCRIPT\n", ts.eval_script) 30 | 31 | tc.write_file("/tmp/patch.diff", patch) 32 | patch_result = tc.run_command("git apply -v /tmp/patch.diff") 33 | if patch_result["exit_code"] == 0: 34 | result["patch_successfully_applied"] = True 35 | print("PATCH RESULT\n", patch_result) 36 | 37 | tc.write_file("/eval.sh", ts.eval_script) 38 | test_command_results = tc.run_command("chmod +x /eval.sh && /eval.sh") 39 | tc_output = test_command_results["output"] 40 | 41 | repo = "-".join( 42 | ts.instance_id.replace("__", "/").split("-")[:-1] 43 | ) # e.g. scikit-learn/scikit-learn 44 | log_parser = MAP_REPO_TO_PARSER[repo] 45 | test_name_to_passfail = log_parser(tc_output) 46 | 47 | eval_ref = { 48 | KEY_INSTANCE_ID: ts.instance_id, 49 | FAIL_TO_PASS: ts.FAIL_TO_PASS, 50 | PASS_TO_PASS: ts.PASS_TO_PASS, 51 | } 52 | 53 | report = get_eval_tests_report(test_name_to_passfail, eval_ref) 54 | resolved = get_resolution_status(report) == ResolvedStatus.FULL.value 55 | 56 | result.update({"resolved": resolved, "tests_status": report}) 57 | 58 | return result 59 | -------------------------------------------------------------------------------- /programmer/swebench/scripts/example_v_models.py: -------------------------------------------------------------------------------- 1 | # using existing swe-bench results logged to weave (see ingest dir), 2 | # produce a table with instance_id as rows, and models as columns. 
3 | # useful for finding easy / hard examples 4 | 5 | import sys 6 | import pandas as pd 7 | 8 | import weave 9 | 10 | from ...weave_next.weave_query import calls 11 | 12 | 13 | def main(): 14 | if len(sys.argv) > 1: 15 | wc = weave.init("weavedev-swebench5") 16 | c = calls(wc, "Evaluation.predict_and_score", expand_refs=["inputs.example"]) 17 | df = c.to_pandas() 18 | 19 | df.to_parquet("verified.parquet", engine="pyarrow") 20 | else: 21 | df = pd.read_parquet("verified.parquet") 22 | # Pivot the dataframe 23 | pivot_df = df.pivot( 24 | index="inputs.example.instance_id", 25 | columns="inputs.model", 26 | values="output.model_output.resolved", 27 | ) 28 | 29 | # Extract model names from the column names 30 | pivot_df.columns = pivot_df.columns.str.extract(r"object/(.+):")[0] 31 | 32 | # Count models with resolved True for each instance 33 | pivot_df["models_resolved_true"] = pivot_df.apply(lambda row: row.sum(), axis=1) 34 | 35 | # Move the count column to the leftmost position 36 | cols = pivot_df.columns.tolist() 37 | cols = cols[-1:] + cols[:-1] 38 | pivot_df = pivot_df[cols] 39 | 40 | # Sort the pivot table by 'models_resolved_true' in descending order 41 | pivot_df = pivot_df.sort_values(by="models_resolved_true", ascending=False) # type: ignore 42 | 43 | # Sort columns by the model that got the most resolved 44 | model_success_count = pivot_df.sum().sort_values(ascending=False) 45 | sorted_columns = ["models_resolved_true"] + model_success_count.index.tolist() 46 | pivot_df = pivot_df[sorted_columns] 47 | 48 | # Display the first few rows of the resulting table 49 | print(pivot_df.head()) 50 | 51 | # Optionally, save the pivot table to a new file 52 | pivot_df.to_csv("pivot_table.csv") 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /programmer/swebench/scripts/verified_difficulty_labels.py: 
-------------------------------------------------------------------------------- 1 | # Quick script for viewing swebench examples against 2 | # annotated difficulties. 3 | # TODO: update for new file paths (in ../data) 4 | 5 | import pandas as pd 6 | import textwrap 7 | 8 | swebench_df = pd.read_parquet("swebench-verified.parquet") 9 | anno_df = pd.read_csv("ensembled_annotations_public.csv") 10 | df = swebench_df.merge(anno_df, on="instance_id", how="left") 11 | difficulty_counts = df.groupby(["repo", "difficulty"]).size().unstack(fill_value=0) 12 | 13 | # Display the difficulty value counts for each repo 14 | print("Difficulty value counts for each repo:") 15 | print(difficulty_counts) 16 | 17 | 18 | filtered_df = df[ 19 | ( 20 | df["repo"].isin(["sphinx-doc/sphinx", "sympy/sympy"]) 21 | & (df["difficulty"] == "<15 min fix") 22 | ) 23 | ] 24 | 25 | # Display the count of filtered examples 26 | print("\nNumber of '<15 min fix' examples from astropy and sympy:") 27 | print(filtered_df["repo"].value_counts()) # type: ignore 28 | 29 | print(filtered_df) 30 | example = filtered_df.loc[498] 31 | print(example) 32 | 33 | # with open("problem.txt", "w") as f: 34 | # f.write( 35 | # f""" 36 | # 37 | # {example["problem_statement"]} 38 | # 39 | # 40 | # {example["hints_text"]} 41 | # 42 | # """ 43 | # ) 44 | 45 | # Programmer with new tools passed on 497 46 | 47 | # do we need hint text? 
48 | 49 | with open("problem.txt", "w") as f: 50 | f.write( 51 | f""" 52 | 53 | {example["problem_statement"]} 54 | 55 | """ 56 | ) 57 | print("FAIL_TO_PASS", example["FAIL_TO_PASS"]) 58 | print("PASS_TO_PASS", example["PASS_TO_PASS"]) 59 | 60 | print("PROBLEM\n", example["problem_statement"]) 61 | print("HINT\n", example["hints_text"]) 62 | 63 | print("PATCH\n", example["patch"]) 64 | print("TEST_PATCH\n", example["test_patch"]) 65 | 66 | with open("code.patch", "w") as f: 67 | f.write(example["patch"]) 68 | with open("test_code.patch", "w") as f: 69 | f.write(example["test_patch"]) 70 | 71 | 72 | # Display a few examples 73 | -------------------------------------------------------------------------------- /programmer/swebench/swebench_model.py: -------------------------------------------------------------------------------- 1 | import weave 2 | 3 | from ..agent import Agent, AgentState 4 | from ..io_context import RemoteContainerIOContext 5 | 6 | 7 | class SWEBenchProgrammerModel(weave.Model): 8 | agent: Agent 9 | max_runtime_seconds: int = 60 10 | 11 | def predict(self, instance): 12 | instance_id = instance["instance_id"] 13 | problem_statement = instance["problem_statement"] 14 | initial_prompt = f"""You are in a checkout of the a git repo. Please identify and fix the issue described in the problem statement. 
15 | 16 | 17 | {problem_statement} 18 | """ 19 | state = AgentState( 20 | history=[ 21 | { 22 | "role": "user", 23 | "content": initial_prompt, 24 | }, 25 | ], 26 | ) 27 | 28 | tc = RemoteContainerIOContext( 29 | "http://localhost:8000", 30 | "/testbed", 31 | "source /opt/miniconda3/bin/activate && conda activate testbed && ", 32 | ) 33 | container_id = f"sweb.eval.x86_64.{instance_id}" 34 | with tc.context(container_id): 35 | result = self.agent.run(state, max_runtime_seconds=self.max_runtime_seconds) 36 | if result["stop_reason"] == "time_limit_exceeded": 37 | return {"errorcode": "runtime", "answer": ""} 38 | answer_result = tc.run_command("git diff") 39 | answer = answer_result["output"] 40 | return {"answer": answer} 41 | return {"answer": answer} 42 | -------------------------------------------------------------------------------- /programmer/tests/conftest.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | import warnings 3 | import pytest 4 | import logging 5 | 6 | from weave.trace.weave_client import WeaveClient 7 | from weave.trace_server.sqlite_trace_server import SqliteTraceServer 8 | from weave.trace.weave_init import InitializedClient 9 | from programmer.weave_next.api import make_external_sql_server 10 | 11 | # Set up logging 12 | logging.basicConfig(level=logging.DEBUG) 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | @pytest.fixture() 17 | def weave_client() -> Generator[WeaveClient, None, None]: 18 | project_id = "pytest-test-project" # Directly use a safe project id 19 | logger.debug(f"Using project_id: {project_id}") 20 | sqlite_server = SqliteTraceServer("file::memory:?cache=shared") 21 | sqlite_server.drop_tables() 22 | sqlite_server.setup_tables() 23 | sqlite_server = make_external_sql_server(sqlite_server) 24 | client = WeaveClient("pytest", project_id, sqlite_server) 25 | logger.debug(f"Initialized WeaveClient with project_id: {client._project_id()}") 26 | 
inited_client = InitializedClient(client) 27 | # weave fixture does autopatch.autopatch, do we want that here? 28 | try: 29 | yield inited_client.client 30 | finally: 31 | inited_client.reset() 32 | -------------------------------------------------------------------------------- /programmer/tests/test_file_line_tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from tempfile import TemporaryDirectory 4 | from programmer.tools import ( 5 | read_lines_from_file, 6 | splice_lines_in_file, 7 | get_io_context, 8 | ) 9 | from programmer.io_context import LocalIOContext, io_context 10 | 11 | 12 | @pytest.fixture() 13 | def tempdir_tool_context(): 14 | with TemporaryDirectory() as tmpdir: 15 | with io_context(LocalIOContext(tmpdir)) as tc: 16 | yield tc 17 | 18 | 19 | @pytest.fixture() 20 | def test_file_path(tempdir_tool_context): 21 | file_path = "test_file.txt" 22 | tempdir_tool_context.write_file( 23 | file_path, 24 | "Line 1\nLine 2\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10\n", 25 | ) 26 | yield file_path 27 | 28 | 29 | def test_read_lines_from_file(test_file_path): 30 | # Valid read 31 | result = read_lines_from_file(test_file_path, 1) 32 | assert result.startswith("1:Line 1\n") 33 | assert "10:Line 10\n" in result 34 | 35 | # Reading from a middle line 36 | result = read_lines_from_file(test_file_path, 5) 37 | assert result.startswith("5:Line 5\n") 38 | assert "10:Line 10\n" in result 39 | 40 | # Invalid start_line 41 | with pytest.raises(Exception, match="Invalid start_line number."): 42 | read_lines_from_file(test_file_path, 0) 43 | 44 | with pytest.raises(Exception, match="Invalid start_line number."): 45 | read_lines_from_file(test_file_path, 11) 46 | 47 | 48 | def test_replace_lines_in_file(test_file_path): 49 | # Valid replacement 50 | result = splice_lines_in_file( 51 | test_file_path, 52 | 2, 53 | 3, 54 | "Line 2\nLine 3\nLine 4\n", 55 | "New Line 2\nNew Line 
3\nNew Line 4\n", 56 | ) 57 | assert "1:Line 1\n" in result 58 | assert "2:New Line 2\n" in result 59 | assert "4:New Line 4\n" in result 60 | assert "10:Line 10\n" in result 61 | 62 | # Replacement with a new file 63 | result = splice_lines_in_file( 64 | "new_test_file.txt", 1, 0, "", "First Line\nSecond Line\n" 65 | ) 66 | assert "1:First Line\n" in result 67 | assert "2:Second Line\n" in result 68 | 69 | splice_lines_in_file(test_file_path, 11, 0, "", "Out of range\n") 70 | 71 | 72 | # Test appending to the end of a file 73 | def test_append_to_file(tempdir_tool_context, test_file_path): 74 | # Read the original content 75 | original_content = tempdir_tool_context.read_file(test_file_path) 76 | 77 | # Append new lines 78 | new_lines = "New Line 11\nNew Line 12\n" 79 | result = splice_lines_in_file(test_file_path, 11, 0, "", new_lines) 80 | 81 | # Verify the file content 82 | updated_content = tempdir_tool_context.read_file(test_file_path) 83 | 84 | assert updated_content == original_content + new_lines 85 | 86 | # Verify that the original content is preserved 87 | assert original_content in updated_content 88 | 89 | # Check that we can still read all lines including the new ones 90 | all_lines = read_lines_from_file(test_file_path, 1) 91 | assert "1:Line 1\n" in all_lines 92 | assert "10:Line 10\n" in all_lines 93 | assert "11:New Line 11\n" in all_lines 94 | assert "12:New Line 12\n" in all_lines 95 | 96 | 97 | # Test inserting at the beginning of an existing file 98 | def test_insert_at_beginning(tempdir_tool_context, test_file_path): 99 | # Read the original content 100 | original_content = tempdir_tool_context.read_file(test_file_path) 101 | 102 | # Insert new lines at the beginning 103 | new_lines = "New First Line\nNew Second Line\n" 104 | result = splice_lines_in_file(test_file_path, 1, 0, "", new_lines) 105 | 106 | # Verify the result 107 | assert "1:New First Line\n" in result 108 | assert "2:New Second Line\n" in result 109 | assert "3:Line 1\n" in 
result 110 | 111 | # Verify the file content 112 | updated_content = tempdir_tool_context.read_file(test_file_path) 113 | 114 | assert updated_content == new_lines + original_content 115 | 116 | # Check that we can read all lines including the new ones 117 | all_lines = read_lines_from_file(test_file_path, 1) 118 | assert "1:New First Line\n" in all_lines 119 | assert "2:New Second Line\n" in all_lines 120 | assert "3:Line 1\n" in all_lines 121 | assert "12:Line 10\n" in all_lines # Original last line is now at position 12 122 | 123 | 124 | # Test reading, replacing, and reading again 125 | def test_read_replace_read(test_file_path): 126 | # Read the original content 127 | original_content = read_lines_from_file(test_file_path, 1) 128 | 129 | # Verify some original content 130 | assert "1:Line 1\n" in original_content 131 | assert "5:Line 5\n" in original_content 132 | assert "10:Line 10\n" in original_content 133 | 134 | # Replace lines 3-5 with new content 135 | new_lines = "Replaced Line 3\nReplaced Line 4\nReplaced Line 5\n" 136 | replace_result = splice_lines_in_file( 137 | test_file_path, 3, 3, "Line 3\nLine 4\nLine 5\n", new_lines 138 | ) 139 | 140 | # Verify the replace result 141 | assert "3:Replaced Line 3\n" in replace_result 142 | assert "4:Replaced Line 4\n" in replace_result 143 | assert "5:Replaced Line 5\n" in replace_result 144 | assert "6:Line 6\n" in replace_result # Original Line 6 is now at position 6 145 | 146 | # Read the updated content 147 | updated_content = read_lines_from_file(test_file_path, 1) 148 | 149 | # Verify the updated content 150 | assert "1:Line 1\n" in updated_content 151 | assert "2:Line 2\n" in updated_content 152 | assert "3:Replaced Line 3\n" in updated_content 153 | assert "4:Replaced Line 4\n" in updated_content 154 | assert "5:Replaced Line 5\n" in updated_content 155 | assert "6:Line 6\n" in updated_content 156 | assert ( 157 | "10:Line 10\n" in updated_content 158 | ) # Original last line is still at position 10 159 
| -------------------------------------------------------------------------------- /programmer/tests/test_git_integration.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tempfile 3 | import shutil 4 | import os 5 | from git import Repo 6 | from programmer.git import GitRepo 7 | 8 | 9 | @pytest.fixture 10 | def setup_repo(): 11 | # Create a temporary directory for the repository 12 | test_dir = tempfile.mkdtemp() 13 | repo = Repo.init(test_dir) 14 | git_repo = GitRepo(repo) 15 | 16 | # Set up user config for the test repo 17 | with repo.config_writer() as config: 18 | config.set_value("user", "name", "Test User") 19 | config.set_value("user", "email", "test@example.com") 20 | 21 | # Create an initial commit so HEAD exists 22 | initial_file_path = os.path.join(test_dir, "initial.txt") 23 | with open(initial_file_path, "w") as f: 24 | f.write("Initial content") 25 | repo.index.add([initial_file_path]) 26 | repo.index.commit("Initial commit") 27 | 28 | # Create and checkout the main branch 29 | main_branch = repo.create_head('main') 30 | main_branch.checkout() 31 | 32 | yield repo, git_repo, test_dir 33 | 34 | # Remove the temporary directory after the test 35 | shutil.rmtree(test_dir) 36 | 37 | 38 | def test_commit_directly_to_branch(setup_repo): 39 | repo, git_repo, test_dir = setup_repo 40 | 41 | # Create and commit a file in the main branch 42 | file_path = os.path.join(test_dir, "test_file.py") 43 | with open(file_path, "w") as f: 44 | f.write("print('Hello, world!')\n") 45 | 46 | repo.index.add([file_path]) 47 | repo.index.commit("Initial commit on main") 48 | 49 | # Modify the file 50 | with open(file_path, "a") as f: 51 | f.write("print('Another line')\n") 52 | 53 | # Commit changes to the programmer- branch 54 | session_branch_name = "programmer-session" 55 | git_repo.create_branch(session_branch_name) 56 | commit_message = "Commit from programmer session" 57 | 
git_repo.commit_directly_to_branch(session_branch_name, commit_message) 58 | 59 | # Verify the commit in the programmer- branch 60 | session_branch_commit = repo.commit(session_branch_name) 61 | tree_files = session_branch_commit.tree.traverse() 62 | file_names = [item.path for item in tree_files] 63 | 64 | assert "test_file.py" in file_names 65 | 66 | # Verify the content of the file in the commit 67 | blob_data = ( 68 | session_branch_commit.tree["test_file.py"].data_stream.read().decode("utf-8") 69 | ) 70 | assert "print('Another line')" in blob_data 71 | 72 | # Verify that the main branch is unaffected 73 | main_branch_commit = repo.commit("main") 74 | assert main_branch_commit.hexsha != session_branch_commit.hexsha 75 | 76 | 77 | def test_no_empty_commit(setup_repo): 78 | repo, git_repo, _ = setup_repo 79 | 80 | # Create and checkout the programmer- branch 81 | session_branch_name = "programmer-session" 82 | git_repo.create_branch(session_branch_name) 83 | 84 | # Commit changes to the programmer- branch 85 | commit_message = "Commit from programmer session" 86 | initial_commit_sha = repo.commit(session_branch_name).hexsha 87 | commit_sha = git_repo.commit_directly_to_branch(session_branch_name, commit_message) 88 | 89 | # Verify that the SHA returned is equal to the initial commit SHA 90 | assert initial_commit_sha == commit_sha 91 | 92 | # Verify no new commit is created 93 | new_commit_sha = repo.commit(session_branch_name).hexsha 94 | assert initial_commit_sha == new_commit_sha 95 | 96 | 97 | def test_multiple_commits_with_empty(setup_repo): 98 | repo, git_repo, test_dir = setup_repo 99 | 100 | # Create and checkout the programmer- branch 101 | session_branch_name = "programmer-session" 102 | git_repo.create_branch(session_branch_name) 103 | initial_commit_sha = repo.commit(session_branch_name).hexsha 104 | 105 | # First commit (empty) 106 | commit_message = "First empty commit" 107 | commit_sha_1 = git_repo.commit_directly_to_branch(session_branch_name, 
commit_message) 108 | assert commit_sha_1 == initial_commit_sha 109 | 110 | # Second commit (non-empty) 111 | file_path = os.path.join(test_dir, "test_file.py") 112 | with open(file_path, "w") as f: 113 | f.write("print('Hello, world!')\n") 114 | 115 | git_repo.commit_directly_to_branch(session_branch_name, "Second commit with changes") 116 | commit_sha_2 = repo.commit(session_branch_name).hexsha 117 | assert commit_sha_2 != initial_commit_sha 118 | 119 | # Third commit (empty) 120 | commit_sha_3 = git_repo.commit_directly_to_branch(session_branch_name, "Third empty commit") 121 | assert commit_sha_3 == commit_sha_2 122 | assert commit_sha_3 != initial_commit_sha 123 | -------------------------------------------------------------------------------- /programmer/tests/test_settings_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import tempfile 4 | from programmer.settings_manager import SettingsManager, SettingsError 5 | 6 | 7 | @pytest.fixture(scope="function") 8 | def setup_and_teardown_settings(): 9 | """Fixture to set up and tear down a temporary settings directory for testing.""" 10 | with tempfile.TemporaryDirectory() as temp_dir: 11 | original_settings_dir = SettingsManager.PROGRAMMER_DIR 12 | SettingsManager.set_settings_dir(temp_dir) 13 | test_file = os.path.join(temp_dir, SettingsManager.SETTINGS_FILE) 14 | try: 15 | yield test_file 16 | finally: 17 | # Restore original settings directory 18 | SettingsManager.set_settings_dir(original_settings_dir) 19 | 20 | 21 | def test_initialize_settings_creates_file_with_defaults(setup_and_teardown_settings): 22 | test_file = setup_and_teardown_settings 23 | SettingsManager.initialize_settings() 24 | assert os.path.exists(test_file) 25 | with open(test_file, "r") as f: 26 | settings = f.read().strip() 27 | expected_settings = "\n".join( 28 | f"{key}={value}" for key, value in SettingsManager.DEFAULT_SETTINGS.items() 29 | ) 30 | assert 
settings == expected_settings 31 | 32 | 33 | def test_get_setting(setup_and_teardown_settings): 34 | SettingsManager.initialize_settings() 35 | for key, value in SettingsManager.DEFAULT_SETTINGS.items(): 36 | assert SettingsManager.get_setting(key) == value 37 | 38 | 39 | def test_set_setting_updates_existing(setup_and_teardown_settings): 40 | SettingsManager.initialize_settings() 41 | SettingsManager.set_setting("weave_logging", "cloud") 42 | assert SettingsManager.get_setting("weave_logging") == "cloud" 43 | 44 | 45 | def test_set_setting_adds_new(setup_and_teardown_settings): 46 | SettingsManager.initialize_settings() 47 | SettingsManager.set_setting("new_setting", "value") 48 | assert SettingsManager.get_setting("new_setting") == "value" 49 | 50 | 51 | def test_validate_and_complete_settings_raises_error_on_malformed_line( 52 | setup_and_teardown_settings, 53 | ): 54 | with open(setup_and_teardown_settings, "w") as f: 55 | f.write("malformed_line\n") 56 | with pytest.raises(SettingsError): 57 | SettingsManager.validate_and_complete_settings() 58 | 59 | 60 | def test_validate_and_complete_settings_adds_missing_defaults( 61 | setup_and_teardown_settings, 62 | ): 63 | with open(setup_and_teardown_settings, "w") as f: 64 | f.write("weave_logging=local\n") # Missing git_tracking 65 | SettingsManager.validate_and_complete_settings() 66 | assert SettingsManager.get_setting("git_tracking") == "off" 67 | 68 | 69 | def test_set_setting_raises_error_on_invalid_value(setup_and_teardown_settings): 70 | SettingsManager.initialize_settings() 71 | with pytest.raises(SettingsError): 72 | SettingsManager.set_setting("weave_logging", "invalid_value") 73 | 74 | with pytest.raises(SettingsError): 75 | SettingsManager.set_setting("git_tracking", "invalid_value") 76 | -------------------------------------------------------------------------------- /programmer/tests/test_text_editor.py: -------------------------------------------------------------------------------- 1 | import pytest 
2 | from tempfile import TemporaryDirectory 3 | from programmer.text_editor import ( 4 | TextEditor, 5 | TextEditorState, 6 | OpenFileState, 7 | LineRange, 8 | OpenFileResult, 9 | WriteFileResult, 10 | TextEditorMutationResult, 11 | ) 12 | from programmer.io_context import LocalIOContext, io_context 13 | 14 | 15 | @pytest.fixture() 16 | def tempdir_tool_context(): 17 | with TemporaryDirectory() as tmpdir: 18 | with io_context(LocalIOContext(tmpdir)) as tc: 19 | yield tc 20 | 21 | 22 | @pytest.fixture() 23 | def sample_file(tempdir_tool_context): 24 | file_path = "sample.txt" 25 | content = "\n".join(f"Line {i}" for i in range(1, 201)) # 200 lines 26 | tempdir_tool_context.write_file(file_path, content) 27 | return file_path 28 | 29 | 30 | @pytest.fixture() 31 | def text_editor(tempdir_tool_context): 32 | return TextEditor(max_open_size=150, open_chunk_size=50) 33 | 34 | 35 | @pytest.fixture() 36 | def initial_state(): 37 | return TextEditorState() 38 | 39 | 40 | def test_open_file(text_editor, sample_file, initial_state): 41 | result = text_editor.open_file(initial_state, sample_file, 1) 42 | assert isinstance(result, TextEditorMutationResult) 43 | assert isinstance(result.action_result, OpenFileResult) 44 | assert result.action_result.success 45 | assert sample_file in result.new_state.open_files 46 | assert ( 47 | result.new_state.open_files[sample_file].total_lines() == 50 48 | ) # OPEN_CHUNK_SIZE 49 | 50 | 51 | def test_open_file_exceed_max_size(tempdir_tool_context, sample_file): 52 | text_editor = TextEditor(max_open_size=75, open_chunk_size=50) 53 | initial_state = TextEditorState() 54 | 55 | # Open the file once (50 lines) 56 | result1 = text_editor.open_file(initial_state, sample_file, 1) 57 | assert result1.action_result.success 58 | 59 | # Try to open another chunk, which would exceed the max_open_size 60 | result2 = text_editor.open_file(result1.new_state, sample_file, 50) 61 | assert isinstance(result2.action_result, OpenFileResult) 62 | assert not 
result2.action_result.success 63 | assert "exceeding the maximum" in result2.action_result.error 64 | 65 | 66 | def test_open_file_at_boundary(tempdir_tool_context, sample_file): 67 | text_editor = TextEditor(max_open_size=100, open_chunk_size=50) 68 | initial_state = TextEditorState() 69 | 70 | # Open exactly MAX_OPEN_SIZE lines 71 | result1 = text_editor.open_file(initial_state, sample_file, 1) 72 | result2 = text_editor.open_file(result1.new_state, sample_file, 51) 73 | assert result1.action_result.success and result2.action_result.success 74 | assert result2.new_state.total_lines() == 100 # MAX_OPEN_SIZE 75 | 76 | # Try to open one more line, which should fail 77 | result3 = text_editor.open_file(result2.new_state, sample_file, 99) 78 | assert not result3.action_result.success 79 | assert "exceeding the maximum" in result3.action_result.error 80 | 81 | 82 | def test_replace_file_lines_at_boundary(text_editor, sample_file, initial_state): 83 | state1 = text_editor.open_file(initial_state, sample_file, 1).new_state 84 | state2 = text_editor.open_file(state1, sample_file, 51).new_state 85 | state3 = text_editor.open_file(state2, sample_file, 101).new_state 86 | 87 | # Replace 5 lines with 5 new lines (no net change) 88 | result = text_editor.replace_file_lines( 89 | state3, 90 | sample_file, 91 | [{"start_line": 1, "n_lines": 5, "lines": "New Line\n" * 5}], 92 | ) 93 | assert result.action_result.success 94 | 95 | # Try to replace 5 lines with 6 new lines (net increase of 1, should fail) 96 | result = text_editor.replace_file_lines( 97 | state3, 98 | sample_file, 99 | [{"start_line": 1, "n_lines": 5, "lines": "New Line\n" * 6}], 100 | ) 101 | assert not result.action_result.success 102 | assert "exceeding the maximum" in result.action_result.error 103 | 104 | 105 | def test_replace_file_lines_middle(text_editor, sample_file, initial_state): 106 | state1 = text_editor.open_file(initial_state, sample_file, 1).new_state 107 | 108 | # Replace 5 lines with 5 new lines 
(no net change) 109 | result = text_editor.replace_file_lines( 110 | state1, 111 | sample_file, 112 | [{"start_line": 5, "n_lines": 5, "lines": "A\nB\n"}], 113 | ) 114 | assert result.action_result.success 115 | 116 | # Try to replace 5 lines with 6 new lines (net increase of 1, should fail) 117 | file_info = result.new_state.get_open_file_info() 118 | assert file_info.open_file_buffers[sample_file].total_lines == 197 119 | assert len(file_info.open_file_buffers[sample_file].buffers) == 1 120 | buffer0 = file_info.open_file_buffers[sample_file].buffers[0] 121 | assert buffer0.line_range.start_line == 1 122 | assert buffer0.line_range.n_lines == 50 123 | 124 | 125 | def test_close_file_range(text_editor, sample_file, initial_state): 126 | state1 = text_editor.open_file(initial_state, sample_file, 1).new_state 127 | result = text_editor.close_file_range(state1, sample_file, 1, 25) 128 | assert result.new_state.open_files[sample_file].total_lines() == 25 129 | 130 | 131 | def test_get_open_file_info(text_editor, sample_file, initial_state): 132 | state1 = text_editor.open_file(initial_state, sample_file, 1).new_state 133 | info = state1.get_open_file_info() 134 | assert sample_file in info.open_file_buffers 135 | assert info.open_file_buffers[sample_file].total_lines == 200 136 | assert len(info.open_file_buffers[sample_file].buffers) == 1 137 | assert info.open_file_buffers[sample_file].buffers[0].line_range.start_line == 1 138 | assert info.open_file_buffers[sample_file].buffers[0].line_range.n_lines == 50 139 | 140 | 141 | def test_open_file_multiple_ranges(text_editor, sample_file, initial_state): 142 | state1 = text_editor.open_file(initial_state, sample_file, 1).new_state 143 | state2 = text_editor.open_file(state1, sample_file, 51).new_state 144 | assert len(state2.open_files[sample_file].ranges) == 1 145 | assert state2.open_files[sample_file].ranges[0].start_line == 1 146 | assert state2.open_files[sample_file].ranges[0].n_lines == 100 147 | 148 | 149 | def 
test_open_file_beyond_end(text_editor, sample_file, initial_state): 150 | result = text_editor.open_file(initial_state, sample_file, 201) 151 | assert isinstance(result.action_result, OpenFileResult) 152 | assert not result.action_result.success 153 | assert "beyond the end of the file" in result.action_result.error 154 | 155 | 156 | def test_open_file_at_end(text_editor, sample_file, initial_state): 157 | result = text_editor.open_file(initial_state, sample_file, 199) 158 | assert isinstance(result.action_result, OpenFileResult) 159 | assert result.action_result.success 160 | assert result.new_state.open_files[sample_file].total_lines() == 1 161 | 162 | 163 | def test_open_file_near_end(text_editor, sample_file, initial_state): 164 | result = text_editor.open_file(initial_state, sample_file, 190) 165 | assert isinstance(result.action_result, OpenFileResult) 166 | assert result.action_result.success 167 | assert result.new_state.open_files[sample_file].total_lines() == 10 168 | -------------------------------------------------------------------------------- /programmer/tests/test_tool_calling.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import TypedDict 3 | 4 | import weave 5 | 6 | from programmer.tool_calling import generate_json_schema 7 | 8 | 9 | class Range(TypedDict): 10 | start: int 11 | end: int 12 | 13 | 14 | @weave.op 15 | def merge_ranges(ranges: list[Range]) -> list[Range]: 16 | """Merge a list of ranges into a single range. 17 | 18 | Args: 19 | ranges: A list of ranges to merge. 20 | 21 | Returns: 22 | A list of merged ranges. 
23 | """ 24 | return ranges 25 | 26 | 27 | def test_list_of_typeddict_schema(): 28 | schema = generate_json_schema(merge_ranges) 29 | assert schema == { 30 | "function": { 31 | "description": "Merge a list of ranges into a single range.", 32 | "name": "merge_ranges", 33 | "parameters": { 34 | "properties": { 35 | "ranges": { 36 | "description": "A list of ranges to merge.", 37 | "type": "array", 38 | "items": { 39 | "type": "object", 40 | "properties": { 41 | "start": {"type": "integer"}, 42 | "end": {"type": "integer"}, 43 | }, 44 | "required": ["start", "end"], 45 | }, 46 | } 47 | }, 48 | "required": ["ranges"], 49 | "type": "object", 50 | }, 51 | }, 52 | "type": "function", 53 | } 54 | 55 | 56 | class Color(Enum): 57 | RED = 1 58 | GREEN = 2 59 | BLUE = 3 60 | 61 | 62 | @weave.op 63 | def color_name(color: Color) -> str: 64 | """Get the name of a color. 65 | 66 | Args: 67 | color: The color to get the name of. 68 | 69 | Returns: 70 | The name of the color. 71 | """ 72 | return color.name 73 | 74 | 75 | def test_enum_schema(): 76 | schema = generate_json_schema(color_name) 77 | assert schema == { 78 | "function": { 79 | "description": "Get the name of a color.", 80 | "name": "color_name", 81 | "parameters": { 82 | "properties": { 83 | "color": { 84 | "description": "The color to get the name of.", 85 | "enum": [1, 2, 3], 86 | "type": "integer", 87 | } 88 | }, 89 | "required": ["color"], 90 | "type": "object", 91 | }, 92 | }, 93 | "type": "function", 94 | } 95 | -------------------------------------------------------------------------------- /programmer/tests/test_weave_query.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import weave 4 | from weave.trace.weave_client import WeaveClient 5 | 6 | from programmer.weave_next.weave_query import calls, expand_refs 7 | 8 | 9 | def test_weave_query_basic(weave_client: WeaveClient): 10 | @weave.op 11 | def add(a: int, b: int) -> int: 12 | return a + b 13 | 14 
| add(1, 2) 15 | 16 | calls_query = calls(weave_client, op_names="add") 17 | calls_df = calls_query.to_pandas() 18 | assert len(calls_df) == 1 19 | 20 | 21 | class Point(weave.Object): 22 | x: float 23 | y: float 24 | 25 | 26 | class NestedObject(weave.Object): 27 | point: Point 28 | value: int 29 | 30 | 31 | @weave.op 32 | def create_point(x: float, y: float) -> Point: 33 | return Point(x=x, y=y) 34 | 35 | 36 | @weave.op 37 | def add_points(p1: Point, p2: Point) -> Point: 38 | return Point(x=p1.x + p2.x, y=p1.y + p2.y) 39 | 40 | 41 | @weave.op 42 | def create_nested(p: Point, v: int) -> NestedObject: 43 | return NestedObject(point=p, value=v) 44 | 45 | 46 | def test_calls_with_expanded_refs(weave_client: WeaveClient): 47 | p1 = create_point(1.0, 2.0) 48 | p2 = create_point(3.0, 4.0) 49 | result = add_points(p1, p2) 50 | nested = create_nested(result, 42) 51 | 52 | calls_query = calls( 53 | weave_client, op_names="create_nested", expand_refs=["output", "output.point"] 54 | ) 55 | calls_df = calls_query.to_pandas() 56 | 57 | assert len(calls_df) == 1 58 | assert calls_df.iloc[0]["output.value"] == 42 59 | assert calls_df.iloc[0]["output.point.x"] == 4.0 60 | assert calls_df.iloc[0]["output.point.y"] == 6.0 61 | -------------------------------------------------------------------------------- /programmer/text_editor.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Generic, TypeVar 2 | from dataclasses import dataclass, field 3 | from contextlib import contextmanager 4 | from contextvars import ContextVar 5 | from typing import Optional, TypedDict 6 | 7 | import weave 8 | 9 | from .io_context import get_io_context 10 | 11 | 12 | @dataclass(frozen=True) 13 | class LineRange: 14 | start_line: int 15 | n_lines: int 16 | 17 | 18 | @dataclass(frozen=True) 19 | class OpenFileState: 20 | # Invariant: ranges must be non-overlapping and non-adjacent 21 | # and must be in sorted order 22 | ranges: 
tuple[LineRange, ...] = field(default_factory=tuple) 23 | 24 | def add_range(self, range: LineRange) -> "OpenFileState": 25 | # Create a new list of ranges 26 | new_ranges = list(self.ranges) 27 | 28 | # Find the correct position to insert the new range 29 | insert_index = 0 30 | for i, existing_range in enumerate(new_ranges): 31 | if range.start_line < existing_range.start_line: 32 | insert_index = i 33 | break 34 | insert_index = i + 1 35 | 36 | # Insert the new range 37 | new_ranges.insert(insert_index, range) 38 | 39 | # Merge overlapping or adjacent ranges 40 | i = 0 41 | while i < len(new_ranges) - 1: 42 | current_range = new_ranges[i] 43 | next_range = new_ranges[i + 1] 44 | 45 | if ( 46 | current_range.start_line + current_range.n_lines 47 | >= next_range.start_line 48 | ): 49 | # Merge the ranges 50 | merged_end = max( 51 | current_range.start_line + current_range.n_lines, 52 | next_range.start_line + next_range.n_lines, 53 | ) 54 | new_ranges[i] = LineRange( 55 | current_range.start_line, merged_end - current_range.start_line 56 | ) 57 | new_ranges.pop(i + 1) 58 | else: 59 | i += 1 60 | 61 | # Return a new OpenFileState with the updated ranges 62 | return OpenFileState(ranges=tuple(new_ranges)) 63 | 64 | def subtract_range(self, range: LineRange) -> "OpenFileState": 65 | new_ranges = [] 66 | for existing_range in self.ranges: 67 | if range.start_line >= existing_range.start_line + existing_range.n_lines: 68 | # The subtracted range is after this range, keep it as is 69 | new_ranges.append(existing_range) 70 | elif range.start_line + range.n_lines <= existing_range.start_line: 71 | # The subtracted range is before this range, keep it as is 72 | new_ranges.append(existing_range) 73 | else: 74 | # The ranges overlap, we need to split or adjust 75 | if range.start_line > existing_range.start_line: 76 | # Keep the part before the subtracted range 77 | new_ranges.append( 78 | LineRange( 79 | existing_range.start_line, 80 | range.start_line - 
existing_range.start_line, 81 | ) 82 | ) 83 | if ( 84 | range.start_line + range.n_lines 85 | < existing_range.start_line + existing_range.n_lines 86 | ): 87 | # Keep the part after the subtracted range 88 | new_ranges.append( 89 | LineRange( 90 | range.start_line + range.n_lines, 91 | (existing_range.start_line + existing_range.n_lines) 92 | - (range.start_line + range.n_lines), 93 | ) 94 | ) 95 | 96 | return OpenFileState(ranges=tuple(new_ranges)) 97 | 98 | def total_lines(self) -> int: 99 | return sum(r.n_lines for r in self.ranges) 100 | 101 | def is_range_open(self, start_line: int, n_lines: int) -> bool: 102 | end_line = start_line + n_lines 103 | for range in self.ranges: 104 | if ( 105 | range.start_line <= start_line 106 | and range.start_line + range.n_lines >= end_line 107 | ): 108 | return True 109 | return False 110 | 111 | 112 | @dataclass(frozen=True) 113 | class TextEditorState: 114 | open_files: dict[str, OpenFileState] = field(default_factory=dict) 115 | 116 | def total_lines(self) -> int: 117 | return sum(file.total_lines() for file in self.open_files.values()) 118 | 119 | def get_open_file_info(self) -> "OpenFileInfoResult": 120 | file_io_context = get_io_context() 121 | open_file_buffers = {} 122 | for path, open_file in self.open_files.items(): 123 | contents = file_io_context.read_file(path) 124 | lines = contents.split("\n") 125 | buffers = [] 126 | for range in open_file.ranges: 127 | buffer = Buffer( 128 | line_range=range, 129 | lines=lines[ 130 | range.start_line - 1 : range.start_line - 1 + range.n_lines 131 | ], 132 | ) 133 | buffers.append(buffer) 134 | open_file_info = OpenFileInfo( 135 | buffers=tuple(buffers), total_lines=len(lines) 136 | ) 137 | open_file_buffers[path] = open_file_info 138 | return OpenFileInfoResult(open_file_buffers=open_file_buffers) 139 | 140 | 141 | @dataclass(frozen=True) 142 | class Buffer: 143 | line_range: LineRange 144 | lines: list[str] 145 | 146 | 147 | @dataclass(frozen=True) 148 | class OpenFileInfo: 
149 | buffers: tuple[Buffer, ...] = field(default_factory=tuple) 150 | total_lines: int = 0 151 | 152 | def n_lines(self) -> int: 153 | return sum(buffer.line_range.n_lines for buffer in self.buffers) 154 | 155 | 156 | @dataclass(frozen=True) 157 | class OpenFileInfoResult: 158 | open_file_buffers: dict[str, OpenFileInfo] = field(default_factory=dict) 159 | 160 | def format_for_messages(self) -> str: 161 | lines = [ 162 | "Visible file buffers. These are the latest states of any previously opened file ranges, and reflect the results of all prior edits." 163 | ] 164 | for path, open_file_info in self.open_file_buffers.items(): 165 | lines.append(f"") 166 | # lines.append(f"") 167 | for buffer in open_file_info.buffers: 168 | lines.append("") 169 | for i, line in enumerate(buffer.lines): 170 | lines.append(f"{buffer.line_range.start_line + i}: {line}") 171 | lines.append("") 172 | lines.append("") 173 | return "\n".join(lines) 174 | 175 | 176 | @dataclass(frozen=True) 177 | class ClosedFileRange: 178 | path: str 179 | start_line: int 180 | n_lines: int 181 | 182 | 183 | @dataclass(frozen=True) 184 | class OpenFileResult: 185 | success: bool 186 | error: str 187 | 188 | 189 | @dataclass(frozen=True) 190 | class WriteFileResult: 191 | success: bool 192 | error: str 193 | 194 | 195 | T = TypeVar("T") 196 | 197 | 198 | @dataclass(frozen=True) 199 | class TextEditorMutationResult(Generic[T]): 200 | new_state: TextEditorState 201 | action_result: T 202 | 203 | 204 | class LineRangeReplacement(TypedDict): 205 | start_line: int 206 | n_lines: int 207 | lines: str 208 | 209 | 210 | class TextEditor: 211 | def __init__( 212 | self, 213 | max_open_size: int = 1500, 214 | open_chunk_size: int = 500, 215 | ): 216 | self.MAX_OPEN_SIZE = max_open_size 217 | self.OPEN_CHUNK_SIZE = open_chunk_size 218 | 219 | def open_file( 220 | self, state: TextEditorState, path: str, start_line: int 221 | ) -> TextEditorMutationResult[OpenFileResult]: 222 | file_io_context = get_io_context() 223 | 
try: 224 | file_contents = file_io_context.read_file(path) 225 | except FileNotFoundError: 226 | return TextEditorMutationResult( 227 | new_state=state, 228 | action_result=OpenFileResult(success=False, error="File not found"), 229 | ) 230 | 231 | file_lines = file_contents.split("\n") 232 | file_lines_count = len(file_lines) 233 | 234 | if start_line < 1: 235 | return TextEditorMutationResult( 236 | new_state=state, 237 | action_result=OpenFileResult( 238 | success=False, 239 | error=f"Start line {start_line} is before the start of the file.", 240 | ), 241 | ) 242 | 243 | if start_line - 1 >= file_lines_count: 244 | return TextEditorMutationResult( 245 | new_state=state, 246 | action_result=OpenFileResult( 247 | success=False, 248 | error=f"Start line {start_line} is beyond the end of the file (which has {file_lines_count} lines).", 249 | ), 250 | ) 251 | 252 | orig_open_file_state = state.open_files.get(path, OpenFileState()) 253 | new_buffer = LineRange( 254 | start_line, min(self.OPEN_CHUNK_SIZE, file_lines_count - start_line) 255 | ) 256 | new_open_file_state = orig_open_file_state.add_range(new_buffer) 257 | added_lines = ( 258 | new_open_file_state.total_lines() - orig_open_file_state.total_lines() 259 | ) 260 | 261 | if state.total_lines() + added_lines > self.MAX_OPEN_SIZE: 262 | return TextEditorMutationResult( 263 | new_state=state, 264 | action_result=OpenFileResult( 265 | success=False, 266 | error=f"This request would result in {state.total_lines() + added_lines} open lines exceeding the maximum of {self.MAX_OPEN_SIZE} lines.", 267 | ), 268 | ) 269 | 270 | new_open_files = dict(state.open_files) 271 | new_open_files[path] = new_open_file_state 272 | new_state = TextEditorState(open_files=new_open_files) 273 | 274 | return TextEditorMutationResult( 275 | new_state=new_state, 276 | action_result=OpenFileResult(success=True, error=""), 277 | ) 278 | 279 | def close_file_range( 280 | self, state: TextEditorState, path: str, start_line: int, n_lines: int 
281 | ) -> TextEditorMutationResult[None]: 282 | open_file_state = state.open_files[path] 283 | new_open_file_state = open_file_state.subtract_range( 284 | LineRange(start_line, n_lines) 285 | ) 286 | 287 | new_open_files = dict(state.open_files) 288 | if new_open_file_state.total_lines() == 0: 289 | del new_open_files[path] 290 | else: 291 | new_open_files[path] = new_open_file_state 292 | 293 | new_state = TextEditorState(open_files=new_open_files) 294 | return TextEditorMutationResult(new_state=new_state, action_result=None) 295 | 296 | def replace_file_lines( 297 | self, 298 | state: TextEditorState, 299 | path: str, 300 | replacements: list[LineRangeReplacement], 301 | ) -> TextEditorMutationResult[WriteFileResult]: 302 | file_io_context = get_io_context() 303 | 304 | # Check if the file is open 305 | open_file_state = state.open_files.get(path) 306 | if not open_file_state: 307 | return TextEditorMutationResult( 308 | new_state=state, 309 | action_result=WriteFileResult( 310 | success=False, 311 | error=f"The file {path} is not open.", 312 | ), 313 | ) 314 | 315 | # Check if all ranges are open 316 | missing_ranges = [] 317 | for replacement in replacements: 318 | if not open_file_state.is_range_open( 319 | replacement["start_line"], replacement["n_lines"] 320 | ): 321 | missing_ranges.append(replacement) 322 | if missing_ranges: 323 | return TextEditorMutationResult( 324 | new_state=state, 325 | action_result=WriteFileResult( 326 | success=False, 327 | error=f"The following ranges are not open: {missing_ranges}", 328 | ), 329 | ) 330 | 331 | # Sort replacements by start line 332 | replacements.sort(key=lambda x: x["start_line"]) 333 | 334 | # Ensure replacements are non-overlapping 335 | for i in range(len(replacements) - 1): 336 | if ( 337 | replacements[i]["start_line"] + replacements[i]["n_lines"] 338 | > replacements[i + 1]["start_line"] 339 | ): 340 | return TextEditorMutationResult( 341 | new_state=state, 342 | action_result=WriteFileResult( 343 | 
success=False, 344 | error=f"The following replacements are overlapping: {replacements[i]}, {replacements[i+1]}", 345 | ), 346 | ) 347 | 348 | all_new_lines = [l["lines"].rstrip("\n").split("\n") for l in replacements] 349 | 350 | net_change = sum(len(l) for l in all_new_lines) - sum( 351 | l["n_lines"] for l in replacements 352 | ) 353 | if state.total_lines() + net_change > self.MAX_OPEN_SIZE: 354 | return TextEditorMutationResult( 355 | new_state=state, 356 | action_result=WriteFileResult( 357 | success=False, 358 | error=f"This edit would result in {state.total_lines() + net_change} open lines exceeding the maximum of {self.MAX_OPEN_SIZE} lines.", 359 | ), 360 | ) 361 | 362 | file_io_context = get_io_context() 363 | try: 364 | file_contents = file_io_context.read_file(path) 365 | file_lines = file_contents.split("\n") 366 | except Exception as e: 367 | return TextEditorMutationResult( 368 | new_state=state, 369 | action_result=WriteFileResult( 370 | success=False, 371 | error=f"Failed to write to file: {str(e)}", 372 | ), 373 | ) 374 | 375 | # Apply replacements in reverse order to indexes don't change while iterating 376 | for i, replacement in reversed(list(enumerate(replacements))): 377 | start_line = replacement["start_line"] 378 | n_lines = replacement["n_lines"] 379 | file_lines[start_line - 1 : start_line - 1 + n_lines] = all_new_lines[i] 380 | 381 | new_contents = "\n".join(file_lines) 382 | 383 | file_io_context.write_file(path, new_contents) 384 | return TextEditorMutationResult( 385 | new_state=state, 386 | action_result=WriteFileResult(success=True, error=""), 387 | ) 388 | 389 | 390 | class TextEditorStateful: 391 | def __init__(self, text_editor: TextEditor, initial_state: TextEditorState): 392 | self.text_editor = text_editor 393 | self.state = initial_state 394 | 395 | def open_file(self, path: str, start_line: int) -> OpenFileResult: 396 | result = self.text_editor.open_file(self.state, path, start_line) 397 | self.state = result.new_state 398 
| return result.action_result 399 | 400 | def close_file_range(self, path: str, start_line: int, n_lines: int) -> None: 401 | result = self.text_editor.close_file_range( 402 | self.state, path, start_line, n_lines 403 | ) 404 | self.state = result.new_state 405 | return result.action_result 406 | 407 | def replace_file_lines( 408 | self, 409 | path: str, 410 | replacements: list[LineRangeReplacement], 411 | ) -> WriteFileResult: 412 | result = self.text_editor.replace_file_lines(self.state, path, replacements) 413 | self.state = result.new_state 414 | return result.action_result 415 | 416 | 417 | _text_editor_context: ContextVar[Optional[TextEditorStateful]] = ContextVar( 418 | "_text_editor_context", default=None 419 | ) 420 | 421 | 422 | @contextmanager 423 | def text_editor(context: TextEditorStateful): 424 | token = _text_editor_context.set(context) 425 | try: 426 | yield context 427 | finally: 428 | _text_editor_context.reset(token) 429 | 430 | 431 | def require_text_editor() -> TextEditorStateful: 432 | context = _text_editor_context.get() 433 | assert context is not None 434 | return context 435 | 436 | 437 | @weave.op 438 | def open_file(path: str, start_line: int) -> str: 439 | """Open a buffer of lines from the given file. 440 | 441 | Args: 442 | path: The path to the file. 443 | start_line: The line number to start reading from (1-indexed). 444 | 445 | Returns: 446 | "success" if the file was opened successfully, 447 | "error: " if the file was not opened successfully. 448 | """ 449 | text_editor = require_text_editor() 450 | response = text_editor.open_file(path, start_line) 451 | if response.success: 452 | return "success" 453 | else: 454 | return f"error: {response.error}" 455 | 456 | 457 | @weave.op 458 | def close_file_range(path: str, start_line: int, n_lines: int) -> str: 459 | """Close a buffer of lines from the given file. 460 | 461 | Args: 462 | path: The path to the file. 463 | start_line: The line number to start reading from (1-indexed). 
464 | n_lines: The number of lines to close. 465 | 466 | Returns: 467 | "success" if the file was closed successfully. 468 | """ 469 | text_editor = require_text_editor() 470 | response = text_editor.close_file_range(path, start_line, n_lines) 471 | return "success" 472 | 473 | 474 | class LineRangeReplacementStartEnd(TypedDict): 475 | start_line: int 476 | remove_up_to_line: int 477 | lines: str 478 | 479 | 480 | @weave.op 481 | def replace_file_lines( 482 | path: str, replacements: list[LineRangeReplacementStartEnd] 483 | ) -> str: 484 | """Replace ranges of lines within a file. Changes must be made to open ranges, and will be reflected immediately on the filesystem. First, existing lines are removed starting at start line, up to but not including replace_up_to_line. Then the new lines are added in that position. 485 | 486 | Args: 487 | path: The path to the file. 488 | replacements: A list of replacements to make. Each replacement is a dictionary with keys: start_line (1-indexed, inclusive), remove_up_to_line (1-indexed, exclusive), lines (a string of newline separated lines to insert) 489 | 490 | Returns: 491 | "success" if the file was replaced successfully, 492 | "error: " if the file was not replaced successfully. 
493 | """ 494 | text_editor = require_text_editor() 495 | replacements_list = [ 496 | LineRangeReplacement( 497 | start_line=r["start_line"], 498 | n_lines=r["remove_up_to_line"] - r["start_line"], 499 | lines=r["lines"], 500 | ) 501 | for r in replacements 502 | ] 503 | response = text_editor.replace_file_lines(path, replacements_list) 504 | if response.success: 505 | return "success" 506 | else: 507 | return f"error: {response.error}" 508 | -------------------------------------------------------------------------------- /programmer/tool_calling.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import json 3 | import traceback 4 | import typing_extensions 5 | 6 | from typing import Any, Callable, get_type_hints, TypedDict 7 | 8 | from openai.types.chat import ChatCompletionMessageToolCall, ChatCompletionToolParam 9 | 10 | from .console import Console 11 | 12 | 13 | class TypedDictLike: 14 | __required_keys__: frozenset[str] 15 | 16 | 17 | def is_typed_dict_like(t: type) -> typing_extensions.TypeGuard[TypedDictLike]: 18 | return hasattr(t, "__required_keys__") 19 | 20 | 21 | def pytype_to_jsonschema(pytype: Any) -> dict: 22 | if pytype.__name__ == "str": 23 | return {"type": "string"} 24 | elif pytype.__name__ == "int": 25 | return {"type": "integer"} 26 | elif is_typed_dict_like(pytype): 27 | return { 28 | "type": "object", 29 | "properties": { 30 | k: pytype_to_jsonschema(v) for k, v in pytype.__annotations__.items() 31 | }, 32 | "required": list(pytype.__annotations__.keys()), 33 | } 34 | elif pytype.__name__ == "list": 35 | return {"type": "array", "items": pytype_to_jsonschema(pytype.__args__[0])} 36 | elif hasattr(pytype, "__members__"): 37 | member_types = [ 38 | pytype_to_jsonschema(type(v.value)) for v in pytype.__members__.values() 39 | ] 40 | t0 = member_types[0] 41 | for t in member_types[1:]: 42 | if t != t0: 43 | raise ValueError("All member types must be the same") 44 | mem_type = t0["type"] 
45 | if mem_type != "string" and mem_type != "integer": 46 | raise ValueError(f"Enum member type {mem_type} is not supported") 47 | return {"type": mem_type, "enum": [e.value for e in pytype]} 48 | raise ValueError(f"Unsupported type: {pytype.__name__}") 49 | 50 | 51 | def generate_json_schema(func: Callable) -> dict: 52 | """Given a function, generate an OpenAI tool compatible JSON schema. 53 | 54 | WIP: This function is very basic and hacky. It will not work in many 55 | scenarios. 56 | """ 57 | # Extract function signature 58 | signature = inspect.signature(func) 59 | parameters = signature.parameters 60 | 61 | # Extract annotations 62 | type_hints = get_type_hints(func) 63 | 64 | # Initialize the schema structure 65 | schema = { 66 | "type": "function", 67 | "function": { 68 | "name": func.__name__, 69 | "description": func.__doc__.split("\n")[0] if func.__doc__ else "", 70 | "parameters": { 71 | "type": "object", 72 | "properties": {}, 73 | "required": [], 74 | }, 75 | }, 76 | } 77 | 78 | # Process each parameter 79 | for name, param in parameters.items(): 80 | # Determine if this parameter is required (no default value) 81 | is_required = param.default == inspect.Parameter.empty 82 | 83 | # Extract parameter type and description 84 | param_schema = pytype_to_jsonschema(type_hints[name]) 85 | 86 | # Attempt to extract description from docstring 87 | param_desc = "" 88 | if func.__doc__: 89 | doc_lines = func.__doc__.split("\n")[1:] 90 | for line in doc_lines: 91 | if name in line: 92 | param_desc = line.strip().split(":")[-1].strip() 93 | break 94 | if not param_desc: 95 | raise ValueError( 96 | f"Function {func.__name__} description for parameter {name} is missing" 97 | ) 98 | param_schema["description"] = param_desc 99 | 100 | schema["function"]["parameters"]["properties"][name] = param_schema # type: ignore 101 | 102 | if is_required: 103 | schema["function"]["parameters"]["required"].append(name) # type: ignore 104 | 105 | return schema 106 | 107 | 108 | 
def chat_call_tool_params(tools: list[Callable]) -> list[ChatCompletionToolParam]:
    """Build OpenAI `tools` request params from plain Python callables."""
    chat_tools = [generate_json_schema(tool) for tool in tools]
    return [ChatCompletionToolParam(**tool) for tool in chat_tools]


def get_tool(tools: list[Callable], name: str) -> Callable:
    """Return the tool whose __name__ matches name.

    Raises:
        KeyError: if no tool with that name is present.
    """
    for t in tools:
        if t.__name__ == name:
            return t
    raise KeyError(f"No tool with name {name} found")


def perform_tool_calls(
    tools: list[Callable], tool_calls: list[ChatCompletionMessageToolCall]
) -> list[dict]:
    """Execute each requested tool call and return the "tool" role messages.

    Argument-parse failures and tool exceptions are captured and returned as
    the tool's response text instead of being raised, so the conversation can
    continue. A tool may return a (response, additional_message) tuple, in
    which case additional_message is appended after the tool message (used by
    tools.view_image to attach the image as a follow-up user message).
    """
    messages = []
    for tool_call in tool_calls:
        function_name = tool_call.function.name
        tool = get_tool(tools, function_name)
        function_args = {}
        function_response = None
        tool_call_s = f"{function_name}({tool_call.function.arguments})"
        Console.tool_call_start(tool_call_s)
        try:
            function_args = json.loads(tool_call.function.arguments)
        except json.JSONDecodeError as e:
            print(f"Tool call {tool_call_s} failed to parse arguments: {e}")
            function_response = f"Argument parse error: {str(e)}"
        # Only invoke the tool if argument parsing did not already produce an
        # error response.
        if not function_response:
            try:
                function_response = tool(**function_args)
            except Exception as e:
                print(f"Error occurred in tool {function_name}:")
                traceback.print_exc()
                function_response = f"Error: {str(e)}"

        # Tuple protocol: (response_text, extra_chat_message).
        additional_message = None
        if isinstance(function_response, tuple):
            additional_message = function_response[1]
            function_response = str(function_response[0])
        else:
            function_response = str(function_response)

        Console.tool_call_complete(function_response)
        messages.append(
            {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "content": function_response,
            }
        )
        if additional_message:
            messages.append(additional_message)
    return messages
import base64
import os
import weave

from .io_context import get_io_context

# Hard cap on characters any tool reads or returns, so a single call cannot
# blow up the chat context.
LENGTH_LIMIT = 30000

# TODO:
# - get rid of resolve_path
# - must return FileNotFoundError in read_file in Remote


def _truncate(text: str) -> str:
    # Clamp text to LENGTH_LIMIT chars, appending a visible marker.
    if len(text) > LENGTH_LIMIT:
        text = text[:LENGTH_LIMIT]
        text += "\n... (truncated)"
    return text


def _numbered_lines(lines: list[str], start_idx: int, end_idx: int) -> str:
    # Render lines[start_idx:end_idx] prefixed with 1-based line numbers.
    return "".join(f"{i + 1}:{lines[i]}\n" for i in range(start_idx, end_idx))


def read_image_as_base64(path: str):
    """Read a .jpg/.jpeg/.png file and return it as a data-URL string."""
    ext = os.path.splitext(path)[1]
    if ext not in [".jpg", ".jpeg", ".png"]:
        raise ValueError("Only .jpg, .jpeg, and .png files are supported.")
    mime_type = "image/jpeg" if ext in [".jpg", ".jpeg"] else "image/png"
    with open(path, "rb") as image_file:
        base64_string = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:{mime_type};base64,{base64_string}"


@weave.op()
def view_image(path: str):
    """View a png or jpg image file.

    Args:
        path: The path to the image file.

    Returns:
        A message indicating that the image was displayed successfully.
    """
    context = get_io_context()
    full_path = context.resolve_path(path)
    base64_image = read_image_as_base64(full_path)

    # The second tuple element is an extra chat message that
    # perform_tool_calls appends after the tool response, carrying the image.
    return f"Image {full_path} displayed in next message.", {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {"url": base64_image, "detail": "high"},
            },
        ],
    }


@weave.op()
def list_files(directory: str) -> str:
    """List names of all files in a directory.

    Args:
        directory: The directory to list.

    Returns:
        The list of files in the directory.
    """
    context = get_io_context()
    # full_path = context.resolve_path(directory)
    result = context.run_command(f"ls {directory}")
    exit_code = result["exit_code"]
    output = result["output"]
    if exit_code != 0:
        raise Exception(f"Failed to list files: {output}")
    if output == "":
        return "[No files found]"
    return _truncate(output)


@weave.op()
def write_to_file(path: str, content: str) -> str:
    """Write text to a file at the given path.

    Args:
        path: The path to the file.
        content: The content to write to the file.

    Returns:
        A message indicating whether the file was written successfully.
    """
    context = get_io_context()
    # NOTE(review): over-long content is silently truncated before writing;
    # preserved from the original behavior, but the caller is not warned.
    context.write_file(path, _truncate(content))
    return "File written successfully."


@weave.op()
def read_from_file(path: str) -> str:
    """Read text from a file at the given path.

    Args:
        path: The path to the file.

    Returns:
        The content of the file.
    """
    context = get_io_context()
    return _truncate(context.read_file(path))


@weave.op()
def run_command(command: str) -> str:
    """Run a shell command and return its output.

    Args:
        command: The command to run.

    Returns:
        The output of the command.
    """
    context = get_io_context()
    cmd_result = context.run_command(command)

    exit_code = cmd_result["exit_code"]
    output = _truncate(cmd_result["output"])

    result = f"Exit code: {exit_code}\n"
    if output:
        result += f"OUTPUT\n{output}\n"
    return result


@weave.op()
def read_lines_from_file(file_path: str, start_line: int) -> str:
    """Read up to 500 lines from a file starting at a specific line number.

    Args:
        file_path: The path to the file.
        start_line: The line number to start reading from (1-indexed).

    Returns:
        A string with each line prefixed by its line number.

    Raises:
        Exception: If the file does not exist or start_line is invalid.
    """
    context = get_io_context()
    full_path = context.resolve_path(file_path)
    content = context.read_file(full_path)
    lines = content.splitlines()

    if start_line < 1 or start_line > len(lines):
        raise Exception("Invalid start_line number.")

    end_line = min(start_line + 500, len(lines) + 1)
    return _numbered_lines(lines, start_line - 1, end_line - 1)


def _splice_file_lines(
    file_path: str,
    start_line: int,
    remove_line_count: int,
    previous_lines: str,
    new_lines: str,
) -> str:
    """Shared implementation behind replace_lines_in_file/splice_lines_in_file.

    Removes remove_line_count lines starting at start_line (1-indexed) after
    verifying they match previous_lines, inserts new_lines there, writes the
    file back, and returns the edited region (with ~5 lines of surrounding
    context) numbered for the agent.
    """
    context = get_io_context()
    full_path = context.resolve_path(file_path)
    try:
        content = context.read_file(full_path)
    except FileNotFoundError:
        # A missing file is treated as empty so new files can be created.
        content = ""
    lines = content.splitlines()

    end_line = start_line + remove_line_count

    if start_line < 1 or end_line < start_line or start_line > len(lines) + 1:
        raise Exception("Invalid line range.")

    # The removed span must match what the caller believes is there; this
    # guards against edits based on a stale view of the file.
    prev_line_split = previous_lines.splitlines()
    if not lines[start_line - 1 : end_line - 1] == prev_line_split:
        raise Exception("Previous lines do not match.")

    # Adjust end_line if it exceeds the current number of lines
    end_line = min(end_line, len(lines) + 1)

    new_lines_list = new_lines.splitlines()

    # Replace the specified line range in place.
    lines[start_line - 1 : end_line - 1] = new_lines_list

    # Write the modified lines back to the file
    context.write_file(full_path, "\n".join(lines) + "\n")

    # Determine the range for the output with a 5-line buffer
    output_start = max(start_line - 6, 0)
    output_end = min(start_line - 1 + len(new_lines_list) + 6, len(lines))
    return _numbered_lines(lines, output_start, output_end)


@weave.op()
def replace_lines_in_file(
    file_path: str,
    start_line: int,
    remove_line_count: int,
    previous_lines: str,
    new_lines: str,
) -> str:
    """Replace lines in a file from start_line to end_line with new_lines. Changes are committed to the file.

    Args:
        file_path: The path to the file.
        start_line: The starting line number for replacement (1-indexed).
        remove_line_count: The number of lines to remove, starting with start_line.
        previous_lines: The previous lines to replace, as a single string. This must match the existing lines, or an exception is raised.
        new_lines: The new lines to insert, as a single string.

    Returns:
        Success message, otherwise raises an exception.

    Raises:
        Exception: If the line range is invalid or file cannot be accessed.
    """
    # Identical semantics to splice_lines_in_file; both tool names are kept
    # for backward compatibility with existing agent prompts.
    return _splice_file_lines(
        file_path, start_line, remove_line_count, previous_lines, new_lines
    )


@weave.op()
def splice_lines_in_file(
    file_path: str,
    start_line: int,
    remove_line_count: int,
    previous_lines: str,
    new_lines: str,
) -> str:
    """Remove remove_line_count lines, starting with start_line, then insert new_lines so that first line is inserted at index start_line.

    To append, use last line index + 1.

    Args:
        file_path: The path to the file.
        start_line: The starting line number for replacement (1-indexed).
        remove_line_count: The number of lines to remove, starting with start_line.
        previous_lines: The previous lines to replace, as a single string. This must match the existing lines, or an exception is raised.
        new_lines: The new lines to insert, as a single string. The first line inserted will be at index start_line.

    Returns:
        Success message, otherwise raises an exception.

    Raises:
        Exception: If the line range is invalid or file cannot be accessed.
    """
    return _splice_file_lines(
        file_path, start_line, remove_line_count, previous_lines, new_lines
    )
# weave doesn't setup the sqlite server correctly, need to wrap in an ID converter

from typing import Optional
import base64


from weave.trace.weave_client import WeaveClient
from weave.trace.weave_init import InitializedClient
from weave.trace_server.external_to_internal_trace_server_adapter import (
    IdConverter,
    ExternalTraceServer,
)
from weave.trace_server.sqlite_trace_server import SqliteTraceServer


def b64_encode(s: str) -> str:
    """Encode an ASCII string as base64 text (used for project-id mapping)."""
    return base64.b64encode(s.encode("ascii")).decode("ascii")


def b64_decode(s: str) -> str:
    """Inverse of b64_encode."""
    return base64.b64decode(s.encode("ascii")).decode("ascii")


class DummyIdConverter(IdConverter):
    """Minimal IdConverter for the local sqlite trace server.

    Only project ids are transformed (base64 both ways); run and user ids
    pass through unchanged.
    """

    def ext_to_int_project_id(self, project_id: str) -> str:
        return b64_encode(project_id)

    def int_to_ext_project_id(self, project_id: str) -> Optional[str]:
        return b64_decode(project_id)

    def ext_to_int_run_id(self, run_id: str) -> str:
        return run_id

    def int_to_ext_run_id(self, run_id: str) -> str:
        return run_id

    def ext_to_int_user_id(self, user_id: str) -> str:
        return user_id

    def int_to_ext_user_id(self, user_id: str) -> str:
        return user_id


# Exposed for conftest.py
def make_external_sql_server(internal_server: SqliteTraceServer) -> ExternalTraceServer:
    """Wrap a sqlite trace server in the external-id adapter."""
    return ExternalTraceServer(
        internal_server,
        DummyIdConverter(),
    )


def init_local_client(db_path: str = "weave.db"):
    """Create and initialize a local, sqlite-backed Weave client.

    Args:
        db_path: Path to the sqlite database file (created if missing).

    Returns:
        The initialized WeaveClient.
    """
    server = SqliteTraceServer(db_path)
    server.setup_tables()
    server = make_external_sql_server(server)
    # "none"/"none" are placeholder entity/project names for local use.
    client = WeaveClient("none", "none", server)
    return InitializedClient(client).client


# This is a batch Weave query
# API. It can move into the Weave library
# if we find it useful.

from typing import Optional, Union, Sequence, Any
import pandas as pd
from weave.trace.weave_client import WeaveClient
from weave.trace_server.trace_server_interface import (
    CallsQueryReq,
    CallsFilter,
    RefsReadBatchReq,
)


def _construct_calls_filter(
    project_id: str,
    op_names: Optional[Union[str, Sequence[str]]] = None,
    parent_ids: Optional[Union[str, Sequence[str]]] = None,
):
    """Build a CallsFilter, expanding bare op names into weave:/// ref URIs."""
    if op_names is None:
        op_names = []
    elif isinstance(op_names, str):
        op_names = [op_names]
    op_ref_uris = []
    for op_name in op_names:
        if op_name.startswith("weave:///"):
            op_ref_uris.append(op_name)
        else:
            # Bare names default to the wildcard version.
            if ":" not in op_name:
                op_name = op_name + ":*"
            op_ref_uris.append(f"weave:///{project_id}/op/{op_name}")

    if parent_ids is None:
        parent_ids = []
    elif isinstance(parent_ids, str):
        parent_ids = [parent_ids]

    return CallsFilter(op_names=op_ref_uris, parent_ids=parent_ids)  # type: ignore


def _server_call_pages(
    wc: WeaveClient,
    filt: CallsFilter,
    limit: Optional[int] = None,
):
    """Yield pages (lists of call dicts) from the server, honoring `limit`."""
    page_index = 0
    page_size = 1000
    remaining = limit
    while True:
        response = wc.server.calls_query(
            CallsQueryReq(
                project_id=wc._project_id(),
                filter=filt,
                offset=page_index * page_size,
                limit=page_size,
            )
        )
        page_data = []
        for v in response.calls:
            v = v.model_dump()
            page_data.append(v)
        if remaining is not None:
            page_data = page_data[:remaining]
            remaining -= len(page_data)
        yield page_data
        # Stop on a short page (no more data) or once the limit is exhausted
        # (avoids one extra empty server query when limit is a multiple of
        # page_size).
        if len(page_data) < page_size or remaining == 0:
            break
        page_index += 1


def _server_refs(wc: WeaveClient, refs: Sequence[Union[str, Any]]):
    """Batch-resolve weave:// refs; non-ref values pass through unchanged.

    Returns a list parallel to `refs` where each weave:// URI is replaced by
    its fetched value.
    """
    ref_uris = []
    non_refs = []
    ref_indices = {}
    for i, item in enumerate(refs):
        if isinstance(item, str) and item.startswith("weave://"):
            ref_uris.append(item)
            if item not in ref_indices:
                ref_indices[item] = []
            ref_indices[item].append(i)
        else:
            non_refs.append((i, item))

    results = []
    # Chunk requests to 1000 refs per server round-trip.
    for offset in range(0, len(ref_uris), 1000):
        batch = ref_uris[offset : offset + 1000]
        read_res = wc.server.refs_read_batch(RefsReadBatchReq(refs=batch))
        results.extend(read_res.vals)

    ref_to_result = dict(zip(ref_uris, results))

    # Scatter fetched values back to every position each ref occupied.
    final_results: list[Any] = [None] * len(refs)
    for ref, result in ref_to_result.items():
        for index in ref_indices[ref]:
            final_results[index] = result
    for index, item in non_refs:
        final_results[index] = item

    return final_results


def _expand_refs_in_page(wc: WeaveClient, page: list[dict], expand_refs: list[str]):
    """Expand the given ref-valued columns of a page of call dicts.

    To hack this implementation together, I flatten on each pass instead of
    dealing with nested keys. This functionality will be available in the
    server soon, so we can get rid of most of this code.
    """
    flat_page = pd.json_normalize(page).to_dict(orient="records")
    for ref in expand_refs:
        ref_values = [call.get(ref) for call in flat_page]
        expanded_refs = _server_refs(wc, ref_values)
        for call, expanded_ref in zip(flat_page, expanded_refs):
            orig_val = call[ref]
            call[ref] = expanded_ref
            if (
                isinstance(orig_val, str)
                and orig_val.startswith("weave://")
                and isinstance(expanded_ref, dict)
            ):
                # Preserve the original ref URI on the expanded object.
                expanded_ref["_ref"] = orig_val
        flat_page = pd.json_normalize(flat_page).to_dict(orient="records")
    return flat_page


class Calls:
    """Lazy batch query over Weave calls; materialize with to_pandas()."""

    def __init__(
        self,
        wc: WeaveClient,
        filt: CallsFilter,
        expand_refs: Optional[list[str]] = None,
        limit: Optional[int] = None,
    ):
        self._wc = wc
        self._filt = filt
        self._expand_refs = expand_refs or []
        self._limit = limit

    def to_pandas(self):
        vals = []
        for page in _server_call_pages(self._wc, self._filt, limit=self._limit):
            if self._expand_refs:
                page = _expand_refs_in_page(self._wc, page, self._expand_refs)
            vals.extend(page)
        return pd.json_normalize(vals)


def calls(
    wc: WeaveClient,
    op_names: Optional[Union[str, Sequence[str]]] = None,
    parent_ids: Optional[Union[str, Sequence[str]]] = None,
    limit: Optional[int] = None,
    expand_refs: Optional[list[str]] = None,
):
    """Query calls by op name / parent, optionally limited and ref-expanded."""
    # BUGFIX: limit was previously accepted but never forwarded, so it was
    # silently ignored.
    return Calls(
        wc,
        _construct_calls_filter(wc._project_id(), op_names, parent_ids),
        expand_refs,
        limit,
    )


class Objs:
    """Batch object fetch for a sequence of refs; materialize via to_pandas()."""

    def __init__(self, wc: WeaveClient, refs: Sequence[str]):
        self._wc = wc
        self._refs = refs

    def to_pandas(self):
        vals = _server_refs(self._wc, self._refs)
        df = pd.json_normalize(vals)
        # Index rows by their ref URIs.
        df.index = pd.Index(self._refs)
        return df


def expand_refs(wc: WeaveClient, refs: Sequence[str]):
    """Return an Objs view over the given refs."""
    return Objs(wc, refs)


def get_call(wc: WeaveClient, call_id: str):
    """Return a raw Weave call."""
    response = wc.server.calls_query(
        CallsQueryReq(
            project_id=wc._project_id(),
            filter=CallsFilter(call_ids=[call_id]),
        )
    )
    return response.calls[0].model_dump()


def expand_json_refs(wc: WeaveClient, json: dict):
    """Expand any nested refs in a compound python value"""

    def find_refs(obj):
        # Depth-first collection of every weave:// string in the structure.
        refs = []
        if isinstance(obj, dict):
            for value in obj.values():
                refs.extend(find_refs(value))
        elif isinstance(obj, list):
            for item in obj:
                refs.extend(find_refs(item))
        elif isinstance(obj, str) and obj.startswith("weave://"):
            refs.append(obj)
        return refs

    def replace_refs(obj, ref_values):
        # Rebuild the structure, substituting fetched values for ref strings.
        if isinstance(obj, dict):
            return {k: replace_refs(v, ref_values) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [replace_refs(item, ref_values) for item in obj]
        elif isinstance(obj, str) and obj.startswith("weave://"):
            return ref_values.get(obj, obj)
        return obj

    refs = find_refs(json)
    if not refs:
        return json

    ref_values = _server_refs(wc, refs)
    ref_dict = {ref: value for ref, value in zip(refs, ref_values)}

    return replace_refs(json, ref_dict)
12 | authors = [{name = "Shawn Lewis", email = "shawn@wandb.com"}] 13 | license = { text = "Apache-2.0" } 14 | readme = "README.md" 15 | requires-python = ">=3.10" 16 | dependencies = [ 17 | "weave>=0.51.1", "streamlit", "pandas", "litellm" 18 | ] 19 | 20 | [tool.setuptools] 21 | packages = { find = {} } 22 | 23 | [project.urls] 24 | Homepage = "https://github.com/wandb/programmer" 25 | 26 | [project.entry-points."console_scripts"] 27 | programmer = "programmer.programmer:main" 28 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | ignore::DeprecationWarning -------------------------------------------------------------------------------- /release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Ensure the script exits on the first failure 4 | set -e 5 | 6 | # Bump the version and create a tag (patch, minor, major) 7 | # Adjust 'patch' to 'minor' or 'major' as needed 8 | bump2version patch 9 | 10 | # Push changes to Git 11 | git push 12 | 13 | # Push tags to Git 14 | git push --tags 15 | 16 | # Build the package 17 | python3 -m build 18 | 19 | # Upload to PyPI (ensure you have twine installed) 20 | twine upload dist/* -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | setuptools 2 | wheel 3 | pytest 4 | pyright 5 | fastapi 6 | docker 7 | swebench --------------------------------------------------------------------------------