├── AI_ETHICS.md ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── SECURITY.md ├── __init__.py ├── agent_humaneval.sh ├── bfs.py ├── common.py ├── config.py ├── data ├── apps_selected.json ├── code_contests_test.json ├── humaneval.jsonl └── mbpp.jsonl ├── dfs_real.py ├── executors ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-311.pyc │ ├── executor_types.cpython-310.pyc │ ├── executor_types.cpython-311.pyc │ ├── executor_utils.cpython-310.pyc │ ├── executor_utils.cpython-311.pyc │ ├── factory.cpython-310.pyc │ ├── factory.cpython-311.pyc │ ├── py_executor.cpython-310.pyc │ └── py_executor.cpython-311.pyc ├── executor_types.py ├── executor_utils.py ├── factory.py └── py_executor.py ├── generators ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-311.pyc │ ├── factory.cpython-310.pyc │ ├── factory.cpython-311.pyc │ ├── generator_types.cpython-310.pyc │ ├── generator_types.cpython-311.pyc │ ├── generator_utils.cpython-310.pyc │ ├── generator_utils.cpython-311.pyc │ ├── model.cpython-310.pyc │ ├── model.cpython-311.pyc │ ├── parse.cpython-310.pyc │ ├── parse.cpython-311.pyc │ ├── py_generate.cpython-310.pyc │ └── py_generate.cpython-311.pyc ├── factory.py ├── generator_types.py ├── generator_utils.py ├── model.py ├── parse.py └── py_generate.py ├── license_info.md ├── llm_agent_guide.py ├── main.py ├── reflexion.py ├── reflexion_codecontests.sh ├── requirements.txt ├── resample_baseline.py ├── root ├── check_test.py └── get_acc.py ├── strategy.py └── utils.py /AI_ETHICS.md: -------------------------------------------------------------------------------- 1 | ## Ethics disclaimer for Salesforce AI models, data, code 2 | 3 | This release is for research purposes only in support of an academic 4 | paper. Our models, datasets, and code are not specifically designed or 5 | evaluated for all downstream purposes. We strongly recommend users 6 | evaluate and address potential concerns related to accuracy, safety, and 7 | fairness before deploying this model. We encourage users to consider the 8 | common limitations of AI, comply with applicable laws, and leverage best 9 | practices when selecting use cases, particularly for high-risk scenarios 10 | where errors or misuse could significantly impact people’s lives, rights, 11 | or safety. For further guidance on use cases, refer to our standard 12 | [AUP](https://www.salesforce.com/content/dam/web/en_us/www/documents/legal/Agreements/policies/ExternalFacing_Services_Policy.pdf) 13 | and [AI AUP](https://www.salesforce.com/content/dam/web/en_us/www/documents/legal/Agreements/policies/ai-acceptable-use-policy.pdf). 14 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related other information. Please be careful while editing. 2 | #ECCN:Open Source 3 | #GUSINFO:Open Source,Open Source Workflow -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. 
We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. 
Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guide For CodeTree 2 | 3 | This page lists the operational governance model of this project, as well as the recommendations and requirements for how to best contribute to CodeTree. We strive to obey these as best as possible. As always, thanks for contributing – we hope these guidelines make it easier and shed some light on our approach and processes. 4 | 5 | # Governance Model 6 | 7 | ## Community Based 8 | 9 | The intent and goal of open sourcing this project is to increase the contributor and user base. The governance model is one where new project leads (`admins`) will be added to the project based on their contributions and efforts, a so-called "do-acracy" or "meritocracy" similar to that used by all Apache Software Foundation projects. 10 | 11 | 12 | # Getting started 13 | 14 | # Issues, requests & ideas 15 | 16 | Use GitHub Issues page to submit issues, enhancement requests and discuss ideas. 17 | 18 | ### Bug Reports and Fixes 19 | - If you find a bug, please search for it in the [Issues](https://github.com/SalesforceAIResearch/CodeTree/issues), and if it isn't already tracked, 20 | [create a new issue](https://github.com/SalesforceAIResearch/CodeTree/issues/new). Fill out the "Bug Report" section of the issue template. 
Even if an Issue is closed, feel free to comment and add details; it will still 21 | be reviewed. 22 | - Issues that have already been identified as a bug (note: able to reproduce) will be labelled `bug`. 23 | - If you'd like to submit a fix for a bug, [send a Pull Request](#creating_a_pull_request) and mention the Issue number. 24 | - Include tests that isolate the bug and verify that it was fixed. 25 | 26 | ### New Features 27 | - If you'd like to add new functionality to this project, describe the problem you want to solve in a [new Issue](https://github.com/SalesforceAIResearch/CodeTree/issues/new). 28 | - Issues that have been identified as a feature request will be labelled `enhancement`. 29 | - If you'd like to implement the new feature, please wait for feedback from the project 30 | maintainers before spending too much time writing the code. In some cases, `enhancement`s may 31 | not align well with the project objectives at the time. 32 | 33 | ### Tests, Documentation, Miscellaneous 34 | - If you'd like to improve the tests, you want to make the documentation clearer, you have an 35 | alternative implementation of something that may have advantages over the way it's currently 36 | done, or you have any other change, we would be happy to hear about it! 37 | - If it's a trivial change, go ahead and [send a Pull Request](#creating_a_pull_request) with the changes you have in mind. 38 | - If not, [open an Issue](https://github.com/SalesforceAIResearch/CodeTree/issues/new) to discuss the idea first. 39 | 40 | If you're new to our project and looking for some way to make your first contribution, look for 41 | Issues labelled `good first contribution`. 42 | 43 | # Contribution Checklist 44 | 45 | - [x] Clean, simple, well-styled code 46 | - [x] Commits should be atomic and messages must be descriptive. Related issues should be mentioned by Issue number. 47 | - [x] Comments 48 | - Module-level & function-level comments. 49 | - Comments on complex blocks of code or algorithms (include references to sources). 50 | - [x] Tests 51 | - The test suite, if provided, must be complete and pass 52 | - Increase code coverage, not vice versa. 53 | - Use any of our testkits, which contain the testing facilities you would need (for example: `import com.salesforce.op.test._`), and borrow inspiration from existing tests. 54 | - [x] Dependencies 55 | - Minimize number of dependencies. 56 | - Prefer Apache 2.0, BSD3, MIT, ISC and MPL licenses. 57 | - [x] Reviews 58 | - Changes must be approved via peer code review 59 | 60 | # Creating a Pull Request 61 | 62 | 1. **Ensure the bug/feature was not already reported** by searching on GitHub under Issues. If none exists, create a new issue so that other contributors can keep track of what you are trying to add/fix and offer suggestions (or let you know if there is already an effort in progress). 63 | 2. **Clone** the forked repo to your machine. 64 | 3. **Create** a new branch to contain your work (e.g. `git checkout -b fix-issue-11`). 65 | 4. **Commit** changes to your own branch. 66 | 5. **Push** your work back up to your fork (e.g. `git push origin fix-issue-11`). 67 | 6. **Submit** a Pull Request against the `main` branch and refer to the issue(s) you are fixing. Try not to pollute your pull request with unintended changes. Keep it simple and small. 68 | 7. **Sign** the Salesforce CLA (you will be prompted to do so when submitting the Pull Request). 69 | 70 | > **NOTE**: Be sure to [sync your fork](https://help.github.com/articles/syncing-a-fork/) before making a pull request. A typical command sequence for these steps is sketched below.
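For reference, the steps above might look like the following end to end. This is only a sketch, not part of the official workflow: the fork URL, branch name, and issue number (`11`, reused from the example above) are placeholders, and your local Git setup may differ.

```
# clone your fork of the repository (replace <your-username> with your GitHub handle)
git clone https://github.com/<your-username>/CodeTree.git
cd CodeTree

# create a branch to contain your work
git checkout -b fix-issue-11

# ...edit code, add tests, run them locally...

# commit the change and push the branch back up to your fork
git add -A
git commit -m "Fix <short description> (#11)"
git push origin fix-issue-11
```

After pushing, open a Pull Request against the `main` branch on GitHub, mention the issue number, and sign the Salesforce CLA when prompted.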
71 | 72 | 73 | # Code of Conduct 74 | Please follow our [Code of Conduct](CODE_OF_CONDUCT.md). 75 | 76 | # License 77 | By contributing your code, you agree to license your contribution under the terms of our project [LICENSE](LICENSE.txt) and to sign the [Salesforce CLA](https://cla.salesforce.com/sign-cla) 78 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License Version 2.0 2 | 3 | Copyright (c) 2024 Salesforce, Inc. 4 | All rights reserved. 5 | 6 | Apache License 7 | Version 2.0, January 2004 8 | http://www.apache.org/licenses/ 9 | 10 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 11 | 12 | 1. Definitions. 13 | 14 | "License" shall mean the terms and conditions for use, reproduction, 15 | and distribution as defined by Sections 1 through 9 of this document. 16 | 17 | "Licensor" shall mean the copyright owner or entity authorized by 18 | the copyright owner that is granting the License. 19 | 20 | "Legal Entity" shall mean the union of the acting entity and all 21 | other entities that control, are controlled by, or are under common 22 | control with that entity. For the purposes of this definition, 23 | "control" means (i) the power, direct or indirect, to cause the 24 | direction or management of such entity, whether by contract or 25 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 26 | outstanding shares, or (iii) beneficial ownership of such entity. 27 | 28 | "You" (or "Your") shall mean an individual or Legal Entity 29 | exercising permissions granted by this License. 30 | 31 | "Source" form shall mean the preferred form for making modifications, 32 | including but not limited to software source code, documentation 33 | source, and configuration files. 34 | 35 | "Object" form shall mean any form resulting from mechanical 36 | transformation or translation of a Source form, including but 37 | not limited to compiled object code, generated documentation, 38 | and conversions to other media types. 39 | 40 | "Work" shall mean the work of authorship, whether in Source or 41 | Object form, made available under the License, as indicated by a 42 | copyright notice that is included in or attached to the work 43 | (an example is provided in the Appendix below). 44 | 45 | "Derivative Works" shall mean any work, whether in Source or Object 46 | form, that is based on (or derived from) the Work and for which the 47 | editorial revisions, annotations, elaborations, or other modifications 48 | represent, as a whole, an original work of authorship. For the purposes 49 | of this License, Derivative Works shall not include works that remain 50 | separable from, or merely link (or bind by name) to the interfaces of, 51 | the Work and Derivative Works thereof. 52 | 53 | "Contribution" shall mean any work of authorship, including 54 | the original version of the Work and any modifications or additions 55 | to that Work or Derivative Works thereof, that is intentionally 56 | submitted to Licensor for inclusion in the Work by the copyright owner 57 | or by an individual or Legal Entity authorized to submit on behalf of 58 | the copyright owner. 
For the purposes of this definition, "submitted" 59 | means any form of electronic, verbal, or written communication sent 60 | to the Licensor or its representatives, including but not limited to 61 | communication on electronic mailing lists, source code control systems, 62 | and issue tracking systems that are managed by, or on behalf of, the 63 | Licensor for the purpose of discussing and improving the Work, but 64 | excluding communication that is conspicuously marked or otherwise 65 | designated in writing by the copyright owner as "Not a Contribution." 66 | 67 | "Contributor" shall mean Licensor and any individual or Legal Entity 68 | on behalf of whom a Contribution has been received by Licensor and 69 | subsequently incorporated within the Work. 70 | 71 | 2. Grant of Copyright License. Subject to the terms and conditions of 72 | this License, each Contributor hereby grants to You a perpetual, 73 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 74 | copyright license to reproduce, prepare Derivative Works of, 75 | publicly display, publicly perform, sublicense, and distribute the 76 | Work and such Derivative Works in Source or Object form. 77 | 78 | 3. Grant of Patent License. Subject to the terms and conditions of 79 | this License, each Contributor hereby grants to You a perpetual, 80 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 81 | (except as stated in this section) patent license to make, have made, 82 | use, offer to sell, sell, import, and otherwise transfer the Work, 83 | where such license applies only to those patent claims licensable 84 | by such Contributor that are necessarily infringed by their 85 | Contribution(s) alone or by combination of their Contribution(s) 86 | with the Work to which such Contribution(s) was submitted. If You 87 | institute patent litigation against any entity (including a 88 | cross-claim or counterclaim in a lawsuit) alleging that the Work 89 | or a Contribution incorporated within the Work constitutes direct 90 | or contributory patent infringement, then any patent licenses 91 | granted to You under this License for that Work shall terminate 92 | as of the date such litigation is filed. 93 | 94 | 4. Redistribution. 
You may reproduce and distribute copies of the 95 | Work or Derivative Works thereof in any medium, with or without 96 | modifications, and in Source or Object form, provided that You 97 | meet the following conditions: 98 | 99 | (a) You must give any other recipients of the Work or 100 | Derivative Works a copy of this License; and 101 | 102 | (b) You must cause any modified files to carry prominent notices 103 | stating that You changed the files; and 104 | 105 | (c) You must retain, in the Source form of any Derivative Works 106 | that You distribute, all copyright, patent, trademark, and 107 | attribution notices from the Source form of the Work, 108 | excluding those notices that do not pertain to any part of 109 | the Derivative Works; and 110 | 111 | (d) If the Work includes a "NOTICE" text file as part of its 112 | distribution, then any Derivative Works that You distribute must 113 | include a readable copy of the attribution notices contained 114 | within such NOTICE file, excluding those notices that do not 115 | pertain to any part of the Derivative Works, in at least one 116 | of the following places: within a NOTICE text file distributed 117 | as part of the Derivative Works; within the Source form or 118 | documentation, if provided along with the Derivative Works; or, 119 | within a display generated by the Derivative Works, if and 120 | wherever such third-party notices normally appear. The contents 121 | of the NOTICE file are for informational purposes only and 122 | do not modify the License. You may add Your own attribution 123 | notices within Derivative Works that You distribute, alongside 124 | or as an addendum to the NOTICE text from the Work, provided 125 | that such additional attribution notices cannot be construed 126 | as modifying the License. 127 | 128 | You may add Your own copyright statement to Your modifications and 129 | may provide additional or different license terms and conditions 130 | for use, reproduction, or distribution of Your modifications, or 131 | for any such Derivative Works as a whole, provided Your use, 132 | reproduction, and distribution of the Work otherwise complies with 133 | the conditions stated in this License. 134 | 135 | 5. Submission of Contributions. Unless You explicitly state otherwise, 136 | any Contribution intentionally submitted for inclusion in the Work 137 | by You to the Licensor shall be under the terms and conditions of 138 | this License, without any additional terms or conditions. 139 | Notwithstanding the above, nothing herein shall supersede or modify 140 | the terms of any separate license agreement you may have executed 141 | with Licensor regarding such Contributions. 142 | 143 | 6. Trademarks. This License does not grant permission to use the trade 144 | names, trademarks, service marks, or product names of the Licensor, 145 | except as required for reasonable and customary use in describing the 146 | origin of the Work and reproducing the content of the NOTICE file. 147 | 148 | 7. Disclaimer of Warranty. Unless required by applicable law or 149 | agreed to in writing, Licensor provides the Work (and each 150 | Contributor provides its Contributions) on an "AS IS" BASIS, 151 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 152 | implied, including, without limitation, any warranties or conditions 153 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 154 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 155 | appropriateness of using or redistributing the Work and assume any 156 | risks associated with Your exercise of permissions under this License. 157 | 158 | 8. Limitation of Liability. In no event and under no legal theory, 159 | whether in tort (including negligence), contract, or otherwise, 160 | unless required by applicable law (such as deliberate and grossly 161 | negligent acts) or agreed to in writing, shall any Contributor be 162 | liable to You for damages, including any direct, indirect, special, 163 | incidental, or consequential damages of any character arising as a 164 | result of this License or out of the use or inability to use the 165 | Work (including but not limited to damages for loss of goodwill, 166 | work stoppage, computer failure or malfunction, or any and all 167 | other commercial damages or losses), even if such Contributor 168 | has been advised of the possibility of such damages. 169 | 170 | 9. Accepting Warranty or Additional Liability. While redistributing 171 | the Work or Derivative Works thereof, You may choose to offer, 172 | and charge a fee for, acceptance of support, warranty, indemnity, 173 | or other liability obligations and/or rights consistent with this 174 | License. However, in accepting such obligations, You may act only 175 | on Your own behalf and on Your sole responsibility, not on behalf 176 | of any other Contributor, and only if You agree to indemnify, 177 | defend, and hold each Contributor harmless for any liability 178 | incurred by, or claims asserted against, such Contributor by reason 179 | of your accepting any such warranty or additional liability. 180 | 181 | END OF TERMS AND CONDITIONS 182 | 183 | APPENDIX: How to apply the Apache License to your work. 184 | 185 | To apply the Apache License to your work, attach the following 186 | boilerplate notice, with the fields enclosed by brackets "{}" 187 | replaced with your own identifying information. (Don't include 188 | the brackets!) The text should be enclosed in the appropriate 189 | comment syntax for the file format. We also recommend that a 190 | file or class name and description of purpose be included on the 191 | same "printed page" as the copyright notice for easier 192 | identification within third-party archives. 193 | 194 | Copyright {yyyy} {name of copyright owner} 195 | 196 | Licensed under the Apache License, Version 2.0 (the "License"); 197 | you may not use this file except in compliance with the License. 198 | You may obtain a copy of the License at 199 | 200 | http://www.apache.org/licenses/LICENSE-2.0 201 | 202 | Unless required by applicable law or agreed to in writing, software 203 | distributed under the License is distributed on an "AS IS" BASIS, 204 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 205 | See the License for the specific language governing permissions and 206 | limitations under the License. 207 | 208 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Repo of CodeTree: Agent-guided Tree Search for Code Generation with Large Language Models 2 | 3 | This is the repo of the paper: [CodeTree: Agent-guided Tree Search for Code Generation with Large Language Models](https://arxiv.org/abs/2411.04329) 4 | 5 | ## Run Scripts 6 | 7 | To run the repo, you first need to install requirements. 
8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | Set the `OPENAI_API_KEY` environment variable to your OpenAI API key if you want to use the OpenAI models: 14 | 15 | ``` 16 | export OPENAI_API_KEY= 17 | ``` 18 | 19 | Then run the script for the full CodeTree method: 20 | 21 | ``` 22 | bash agent_humaneval.sh 23 | ``` 24 | 25 | ## Details 26 | 27 | We currently support the following options as the `strategy` argument (corresponding to the methods in the paper): 28 | 29 | * `agent`: Full CodeTree method 30 | * `bfs`: CodeTree-BFS 31 | * `dfs`: CodeTree-DFS 32 | * `reflexion`: Reflexion 33 | * `resample`: Resample 34 | 35 | We currently support the following models as the `model` argument (corresponding to the models in the paper): 36 | 37 | * `GPT-4o-mini`: gpt-4o-mini-2024-07-18 38 | * `GPT-4o`: gpt-4o-2024-08-06 39 | * `GPT-3.5-turbo`: GPT-3.5-turbo (outdated and not recommended) 40 | * `GPT-4`: GPT-4 (outdated and not recommended) 41 | * `Llama-3.1-8B-Instruct`: meta-llama/Llama-3.1-8B-Instruct 42 | 43 | Datasets are in `CodeTree/data/` 44 | 45 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered. This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as possible, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 18 | # */ 19 | # 20 | 21 | 22 | -------------------------------------------------------------------------------- /agent_humaneval.sh: -------------------------------------------------------------------------------- 1 | # export OPENAI_API_KEY="" # input your openai key if not already 2 | python main.py \ 3 | --run_name "code_4o-mini-agent" \ 4 | --root_dir "root" \ 5 | --dataset_path data/humaneval.jsonl \ 6 | --strategy "agent" \ 7 | --language "py" \ 8 | --model "gpt-4o-mini" \ 9 | --pass_at_k "1" \ 10 | --max_iters 20 \ 11 | --function \ 12 | --verbose 13 | -------------------------------------------------------------------------------- /bfs.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc.
5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 18 | # */ 19 | # 20 | 21 | import openai 22 | from utils import enumerate_resume, make_printv, write_jsonl, resume_success_count 23 | from executors import executor_factory 24 | from generators import generator_factory, model_factory 25 | from typing import List, Dict, Tuple, Any 26 | import math 27 | import re 28 | import sys 29 | from collections import Counter 30 | from common import gen_test_eval 31 | # bfs real 32 | sys.set_int_max_str_digits(100000) # Increase the limit to 10000 digits 33 | # Many are passed but not solved, so if one has passed, use the agreement function to select one 34 | 35 | ADD_HINT = "To solve the problem, you can refer the hint given by an expert, and complete the details by analyzing it first.\nHint:" 36 | 37 | # TODO: From sample to list 38 | class Node: 39 | def __init__(self, solution: str, parent=None, context="", depth=0): 40 | self.solution = solution 41 | self.parent = parent 42 | self.children = [] 43 | self.value = 0 44 | self.visits = 0 45 | self.context = "" 46 | self.depth = depth 47 | self.reflection = "" 48 | self.test_feedback = "" 49 | self.strategy="" 50 | 51 | def uct(self, exploration_weight=1.0): 52 | if self.visits == 0: 53 | # return float('inf') 54 | return self.value 55 | return (self.value / self.visits) + exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits) 56 | 57 | def best_child(self): 58 | if not self.children: # Check if children list is empty 59 | return None 60 | return max(self.children, key=lambda child: child.uct()) 61 | 62 | def best_child_value(self): 63 | if not self.children: # Check if children list is empty 64 | return None 65 | return max(self.children, key=lambda child: child.value) 66 | 67 | def sort_children_by_value(self): 68 | self.children.sort(key=lambda x: x.value, reverse=True) 69 | 70 | def update(self, reward: float): 71 | self.visits += 1 72 | self.value += reward 73 | 74 | 75 | def rerank_list_of_nodes(list_of_nodes): 76 | return sorted(list_of_nodes, key=lambda x:x.value, reverse=True) # small value in the front 77 | 78 | def run_bfs( 79 | dataset: List[dict], 80 | model_name: str, 81 | language: str, 82 | log_path: str, 83 | verbose: bool, 84 | max_iters: int, 85 | is_leetcode: bool = False, 86 | max_depth: int = 3, 87 | search_width: int = 3, 88 | Codecontests: bool = False 89 | ) -> None: 90 | if Codecontests: exe = executor_factory("code_contests") 91 | else: exe = executor_factory(language, is_leet=is_leetcode) 92 | 93 | gen = generator_factory(language) 94 | model = model_factory(model_name) 95 | print_v = make_printv(verbose) 96 | count, sad_case, debug_thoroughed_case, enter_debug_case = 0, 0, 0, 0 97 | num_items = len(dataset) 98 | num_success, weak_success = 0, 0 # Counter for successful solutions 99 | passed_at_sample, solve_or_not = [], [] 100 | debug_case, skip = 0, 0 101 | pass_problem_subset = [] 102 | for idx, item 
in enumerate(dataset): 103 | tests_i = item["given_tests"] 104 | if Codecontests: 105 | item["entry_point"] = "" 106 | else: 107 | tests_i = [test for test in tests_i if item['entry_point'] in test and 'assert False' not in test] 108 | 109 | hints = gen.strategy(item["prompt"], model, num_strategy=search_width, task="strategy", temperature=0) 110 | if len(hints) > search_width: hints = hints[:search_width] 111 | stack, memory_stack = [], [] 112 | is_solved, is_weaker_solved = False, False 113 | num_try = 0 114 | if len(hints) < search_width: 115 | count += 1 116 | for hint in hints: 117 | cur_func_impl = gen.func_impl(item["prompt"] + f"{ADD_HINT} {hint}\n", model, "simple", 118 | temperature=0) 119 | new_node = Node(cur_func_impl) 120 | num_try += 1 121 | is_passing, feedback, reward = gen_test_eval(exe, cur_func_impl, tests_i, prev=item["prev"]) 122 | if is_passing: 123 | is_solved = exe.evaluate( 124 | item["entry_point"], cur_func_impl, item["test"], timeout=1, prev=item["prev"]) # early exit 125 | if "weaker_test" in item.keys(): 126 | is_weaker_solved = exe.evaluate( 127 | item["entry_point"], cur_func_impl, item["weaker_test"], timeout=1, prev=item["prev"]) 128 | break 129 | new_node.test_feedback = feedback 130 | new_node.update(reward) 131 | new_node.strategy = hint 132 | stack.append(new_node) 133 | # Exit when passed public test cases. 134 | 135 | if is_passing: 136 | if is_solved: 137 | num_success += int(is_solved) 138 | passed_at_sample.append(num_try) 139 | if "difficulty" in item.keys(): pass_problem_subset.append(item["difficulty"]) 140 | else: 141 | print("SAD, passed but not solved.") 142 | sad_case += 1 143 | if is_weaker_solved: 144 | weak_success += int(is_weaker_solved) 145 | item["weak_acc"] = round(weak_success / (idx + 1), 3) 146 | item["acc"] = round(num_success / (idx + 1), 3) 147 | write_jsonl(log_path, [item], append=True) 148 | print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={item["weak_acc"]}, pass no solve: {sad_case}, enter debug: {debug_case}') 149 | continue # early stop on this case if passsed 150 | 151 | print("Entering Debugging Stage") 152 | debuged = True; debug_case += 1 153 | stack = rerank_list_of_nodes(stack) # out of all stack 154 | print("Stack after sorting: ", [a.value for a in stack]) 155 | while stack and num_try < max_iters and not is_passing: 156 | this_node = stack.pop(0) 157 | if this_node.depth >= max_depth: continue 158 | this_node.visits += 1 159 | reflections = gen.strategy(item["prompt"], 160 | model, task="reflection", 161 | num_strategy=search_width, 162 | prev_func_impl=this_node.solution, 163 | feedback=this_node.test_feedback, 164 | temperature=0, 165 | given_strategy=this_node.strategy) 166 | if len(reflections) < 2: "print not enough reflections!" 
167 | for reflection in reflections: 168 | if num_try >= max_iters: break 169 | new_solution, while_cnt = None, 0 170 | while new_solution is None and while_cnt < 3: 171 | while_cnt += 1 172 | new_solution = gen.func_impl( 173 | func_sig=item["prompt"], 174 | model=model, 175 | strategy="reflexion", 176 | prev_func_impl=this_node.solution, 177 | feedback=this_node.test_feedback, 178 | self_reflection=reflection, 179 | temperature=0 180 | ) 181 | is_passing, feedback, reward = gen_test_eval(exe, new_solution, tests_i, prev=item["prev"]) 182 | num_try += 1 183 | if is_passing: 184 | is_solved = exe.evaluate( 185 | item["entry_point"], new_solution, item["test"], timeout=1, prev=item["prev"]) 186 | if "weaker_test" in item.keys(): 187 | is_weaker_solved = exe.evaluate( 188 | item["entry_point"], new_solution, item["weaker_test"], timeout=1, prev=item["prev"]) 189 | break 190 | new_node = Node(new_solution, depth=this_node.depth + 1) 191 | new_node.test_feedback = feedback 192 | new_node.update(reward) 193 | new_node.strategy = this_node.strategy 194 | this_node.children.append(new_node) 195 | if is_passing: break 196 | this_node.sort_children_by_value() 197 | stack.extend(this_node.children) 198 | print("Children after sorting: ", [a.value for a in stack]) 199 | if num_try >= max_iters: debug_thoroughed_case += 1 200 | if is_passing: 201 | if debuged: enter_debug_case += 1 202 | if is_solved: 203 | num_success += int(is_solved) 204 | passed_at_sample.append(num_try) 205 | if "difficulty" in item.keys(): pass_problem_subset.append(item["difficulty"]) 206 | else: 207 | sad_case += 1 208 | print("Sad, pass but not solve") 209 | if is_weaker_solved: 210 | weak_success += int(is_weaker_solved) 211 | item["weak_acc"] = round(weak_success / (idx + 1), 3) 212 | item["acc"] = round(num_success / (idx + 1), 3) 213 | write_jsonl(log_path, [item], append=True) 214 | print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={item["weak_acc"]}, pass no solve: {sad_case}, enter debug: {debug_case}') 215 | continue 216 | 217 | print("_______________________________") 218 | print(passed_at_sample) 219 | print(sorted(passed_at_sample)) 220 | print(len(passed_at_sample)) 221 | print(Counter(passed_at_sample)) 222 | print("Passed but not solved case", sad_case) 223 | print("not sample 2 even when asked: ", count) 224 | print("20 tries used still not solve:", debug_thoroughed_case) 225 | print("Pass not solve after debugging", enter_debug_case) 226 | print(Counter(pass_problem_subset)) 227 | print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={round(weak_success / (idx + 1), 3)}') 228 | 229 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 18 | # */ 19 | # 20 | 21 | import ast 22 | import textwrap 23 | from copy import copy 24 | 25 | def add_prev_funcs(prompts, cur_function, entry_point): 26 | all_functions = find_functions_with_implementation(prompts) 27 | if len(all_functions) == 1: 28 | return cur_function 29 | else: 30 | assert all_functions[-1]["func_name"] == entry_point # last one is to be implemented 31 | return "\n".join(ele["func_code"] for ele in all_functions[:-1]) + f"\n{cur_function}" 32 | 33 | 34 | def wrap_mbpp_data(dataset): 35 | list_of_data = [] 36 | for i in range(len(dataset)): 37 | code = dataset[i]["code"] 38 | test_case_string = "\n".join(dataset[i]["test_list"]) 39 | docstring = f"\"\"\"\n{dataset[i]['prompt']}\n\nExamples for reference:\n{test_case_string}\n\"\"\"" 40 | lines = code.strip().splitlines() 41 | for line in reversed(lines): # from last function definition 42 | if line[:3] == "def": 43 | header = line 44 | break 45 | # how_many_functions += 1 46 | 47 | comment = textwrap.indent(docstring, code.strip().splitlines()[-1].split("return")[0]) 48 | new_point = dict(copy(dataset[i])) 49 | new_point["prompt"] = f"{header}\n{comment}" 50 | new_point["given_tests"] = copy(dataset[i]["test_list"]) 51 | new_point["weaker_test"] = "\n".join(new_point["given_tests"]) # add lines to the end of program 52 | new_point["entry_point"] = header 53 | new_point["prev"] = "" 54 | list_of_data.append(new_point) 55 | print("mbpp question example:", list_of_data[0]["prompt"], sep="\n") 56 | return list_of_data 57 | 58 | def gen_test_eval(exe, solution, test_cases, prev=""): 59 | is_passing, feedback, _ = exe.execute(solution, test_cases, timeout=1, prev=prev) 60 | if is_passing: reward = 1 61 | else: 62 | reward = _.count(True)/len(_) 63 | return is_passing, feedback, reward 64 | def wrap_human_eval_data(dataset_loaded, dataset_evalplus): 65 | dataset_dict = {entry['task_id']: entry for entry in dataset_loaded} 66 | list_of_data = [] 67 | for i in range(len(dataset_evalplus)): 68 | # print("processing") 69 | task_id = dataset_evalplus[i]['task_id'] 70 | entry_point = dataset_evalplus[i]['entry_point'] 71 | new_point = dict(copy(dataset_evalplus[i])) 72 | new_point["weaker_test"] = dataset_dict[task_id]['test'] + f"\ncheck({entry_point})\n" 73 | new_point["test"] = dataset_evalplus[i]["test"] + f"\ncheck({entry_point})\n" 74 | new_point["given_tests"] = copy(dataset_dict[task_id]['given_tests']) 75 | temp = extract_implemented_functions( 76 | dataset_evalplus[i]["prompt"]) # find_functions_with_implementation(dataset_evalplus[i]["prompt"]) 77 | if temp: 78 | new_point["prev"] = f"\n{temp}\n" 79 | print("Multiple implementation!") 80 | print(temp) 81 | else: 82 | new_point["prev"] = "" 83 | list_of_data.append(new_point) 84 | 85 | # print(dataset_evalplus[i]["given_tests"]) 86 | print(list_of_data[0]["weaker_test"]) 87 | print(list_of_data[3]["prompt"]) 88 | # print(list_of_data[0]["test"]) 89 | return list_of_data # dataset_evalplus 90 | def has_docstring(func_node): 91 | if func_node.body and isinstance(func_node.body[0], ast.Expr): 92 | expr = func_node.body[0].value 93 | if isinstance(expr, ast.Constant) and isinstance(expr.value, str): 94 | return True 95 | return False 96 | 97 | def contains_return_with_value(func_node): 98 | for node in ast.walk(func_node): 99 | if isinstance(node, ast.Return): 100 | if node.value is not None: 101 | # Optionally, ensure the return value is not None 102 | if 
not (isinstance(node.value, ast.Constant) and node.value.value is None): 103 | return True 104 | return False 105 | 106 | def is_effective_function(func_node): 107 | if not func_node.body: 108 | return False 109 | 110 | # Exclude functions that only contain 'pass' or 'Ellipsis' 111 | for stmt in func_node.body: 112 | if isinstance(stmt, ast.Pass): 113 | return False 114 | if isinstance(stmt, ast.Expr): 115 | expr = stmt.value 116 | if isinstance(expr, ast.Constant) and expr.value == Ellipsis: 117 | return False 118 | 119 | # Check for return statements with values 120 | if contains_return_with_value(func_node): 121 | return True 122 | 123 | return False 124 | 125 | def find_functions_with_implementation(source_code): 126 | node = ast.parse(source_code) 127 | functions_info = [] 128 | 129 | for n in ast.walk(node): 130 | if isinstance(n, ast.FunctionDef): 131 | functions_info.append({ 132 | "func_name": n.name, 133 | "func_code": ast.get_source_segment(source_code, n), 134 | "implemented": is_effective_function(n), 135 | "has_docstring": has_docstring(n) 136 | }) 137 | 138 | return functions_info 139 | 140 | def extract_implemented_functions(source_code): 141 | functions_info = find_functions_with_implementation(source_code) 142 | try: 143 | implemented_funcs = [func['func_code'] for func in functions_info if func['implemented']] 144 | except: 145 | print(functions_info) 146 | return "\n\n".join(implemented_funcs) 147 | 148 | def cal_metrics(decisions): 149 | # Initialize counts for TP, TN, FP, and FN 150 | TP = TN = FP = FN = 0 151 | 152 | # Iterate through decisions and count each case 153 | for predict, label in decisions: 154 | if predict == 1 and label == 1: 155 | TP += 1 156 | elif predict == 0 and label == 0: 157 | TN += 1 158 | elif predict == 1 and label == 0: 159 | FP += 1 160 | elif predict == 0 and label == 1: 161 | FN += 1 162 | 163 | # Calculate accuracy 164 | total = TP + TN + FP + FN 165 | accuracy = (TP + TN) / total if total > 0 else 0 166 | 167 | # Return the metrics as a dictionary 168 | return { 169 | 'TP': TP, 170 | 'TN': TN, 171 | 'FP': FP, 172 | 'FN': FN, 173 | 'accuracy': accuracy 174 | } 175 | 176 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 
18 | # */ 19 | # 20 | 21 | import argparse 22 | _args=None 23 | def get_parsed_args(): 24 | global _args 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--run_name", type=str, help="The name of the run") 27 | parser.add_argument("--root_dir", type=str, 28 | help="The root logging directory", default="root") 29 | parser.add_argument("--dataset_path", type=str, 30 | help="The path to the benchmark dataset", default="root") 31 | parser.add_argument("--language", type=str, help=" `py` only") 32 | parser.add_argument("--model", type=str, help="GPT models, and LLaMA 3.1 models") 33 | parser.add_argument("--pass_at_k", type=int, 34 | help="Pass@k metric, only implemented pass@1", default=1) 35 | parser.add_argument("--max_iters", type=int, 36 | help="The maximum number of total tries in code implementation(budget)", default=10) 37 | parser.add_argument("--max_depth", type=int) 38 | parser.add_argument("--strategy", type=str, help="run methods, [reflexion, dfs, bfs, agent, resample, strategy]") 39 | parser.add_argument("--search_width", type=int) 40 | parser.add_argument("--verbose", action='store_true', help="To print live logs") 41 | parser.add_argument("--function", action='store_true', 42 | help="if it's function implementation task or a program implementation task, codecontests=False, mbpp/humaneval=True") 43 | _args = parser.parse_args() 44 | return _args -------------------------------------------------------------------------------- /dfs_real.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 
18 | # */ 19 | # 20 | 21 | import openai 22 | from utils import make_printv, write_jsonl 23 | from executors import executor_factory 24 | from generators import generator_factory, model_factory 25 | from typing import List, Dict, Tuple, Any 26 | import math 27 | 28 | import sys 29 | from collections import Counter 30 | from common import wrap_mbpp_data, wrap_human_eval_data, gen_test_eval 31 | 32 | 33 | ADD_HINT = "To solve the problem, you can refer the hint given by an expert: " 34 | 35 | # TODO: ADD feature to combine nodes 36 | class Node: 37 | def __init__(self, solution: str, parent=None, context="", depth=0): 38 | self.solution = solution 39 | self.parent = parent 40 | self.children = [] 41 | self.value = 0 42 | self.visits = 0 43 | self.context = "" 44 | self.depth = depth 45 | self.reflection = "" 46 | self.test_feedback = "" 47 | self.strategy = "" 48 | 49 | def uct(self, exploration_weight=1.0): 50 | if self.visits == 0: 51 | return self.value 52 | return (self.value / self.visits) + exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits) 53 | 54 | def best_child(self): 55 | if not self.children: # Check if children list is empty 56 | return None 57 | return max(self.children, key=lambda child: child.uct()) 58 | 59 | def best_child_value(self): 60 | if not self.children: # Check if children list is empty 61 | return None 62 | return max(self.children, key=lambda child: child.value) 63 | 64 | def sort_children_by_value(self): 65 | self.children.sort(key=lambda x: x.value) 66 | 67 | def update(self, reward: float): 68 | self.visits += 1 69 | self.value += reward 70 | 71 | 72 | def rerank_list_of_nodes(list_of_nodes): 73 | return sorted(list_of_nodes, key=lambda x:x.value) # small value in the front 74 | 75 | def run_dfs( 76 | dataset: List[dict], 77 | model_name: str, 78 | language: str, 79 | max_iters: int, 80 | log_path: str, 81 | verbose: bool, 82 | is_leetcode: bool = False, 83 | max_depth: int = 3, 84 | search_width: int = 3, 85 | Codecontests: bool = False 86 | 87 | ) -> None: 88 | print("max_depth",max_depth, "search_width",search_width) 89 | from datasets import load_dataset 90 | pass_problem_subset = [] 91 | if Codecontests: 92 | exe = executor_factory("code_contests") 93 | else: exe = executor_factory(language, is_leet=is_leetcode) 94 | gen = generator_factory(language) 95 | model = model_factory(model_name) 96 | print_v = make_printv(verbose) 97 | count, sad_case, debug_thoroughed_case, enter_debug_case = 0, 0, 0, 0 98 | num_items = len(dataset) 99 | num_success, weak_success = 0, 0 # Counter for successful solutions 100 | passed_at_sample, solve_or_not = [], [] 101 | debug_case = 0 102 | skip = 0 103 | for idx, item in enumerate(dataset): 104 | print("STARTING EXAMPLE", idx) 105 | tests_i = item["given_tests"] 106 | if Codecontests: 107 | item["entry_point"] = "" 108 | else: 109 | tests_i = [test for test in tests_i if item['entry_point'] in test and 'assert False' not in test] 110 | hints = gen.strategy(item["prompt"], model, num_strategy=search_width, task="strategy", temperature=0) 111 | 112 | if len(hints) > search_width: hints = hints[:search_width] 113 | stack = [] 114 | is_solved = False 115 | is_weaker_solved=False 116 | num_try = 0 117 | for hint in hints: 118 | cur_func_impl = gen.func_impl(item["prompt"] + f"{ADD_HINT}{hint}\n", model, "simple", 119 | temperature=0) 120 | new_node = Node(cur_func_impl) 121 | num_try += 1 122 | is_passing, feedback, reward = gen_test_eval(exe, cur_func_impl, tests_i, prev=item["prev"]) 123 | if is_passing: 
124 | is_solved = exe.evaluate( 125 | item["entry_point"], cur_func_impl, item["test"], timeout=1, prev=item["prev"]) # early exit 126 | if "weaker_test" in item.keys(): 127 | is_weaker_solved = exe.evaluate( 128 | item["entry_point"], cur_func_impl, item["weaker_test"], timeout=1, prev=item["prev"]) 129 | break 130 | new_node.test_feedback = feedback 131 | new_node.update(reward) 132 | new_node.strategy = hint 133 | stack.append(new_node) 134 | if is_passing: 135 | if is_solved: 136 | num_success += int(is_solved) 137 | passed_at_sample.append(num_try) 138 | if "difficulty" in item.keys(): pass_problem_subset.append(item["difficulty"]) 139 | else: 140 | print("SAD, passed but not solved.") 141 | sad_case += 1 142 | if is_weaker_solved: 143 | weak_success += int(is_weaker_solved) 144 | item["acc"] = round(num_success / (idx + 1), 3) 145 | item["weak_acc"] = round(weak_success / (idx + 1), 3) 146 | write_jsonl(log_path, [item], append=True) 147 | print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={item["weak_acc"]}, pass no solve: {sad_case}, enter debug: {debug_case}, "exhaust": {debug_thoroughed_case}', "skip", skip) 148 | continue # early stop on this case if passsed 149 | 150 | print("Entering Debugging Stage") 151 | debuged = True; debug_case += 1 152 | stack = list(reversed(stack)) 153 | stack = rerank_list_of_nodes(stack) # out of all stack 154 | print("Stack after sorting: ", [a.value for a in stack]) 155 | while stack and num_try < max_iters and not is_passing: 156 | if len(stack) == 0: break 157 | this_node = stack.pop() 158 | if this_node.depth >= max_depth: continue 159 | this_node.visits += 1 160 | reflections = gen.strategy(item["prompt"], 161 | model, task="reflection", 162 | num_strategy=search_width, 163 | prev_func_impl=this_node.solution, 164 | feedback=this_node.test_feedback, 165 | temperature=0, 166 | given_strategy=this_node.strategy) 167 | if len(reflections) < 2: print("not enough reflections!") 168 | for reflection in reflections: 169 | new_solution, while_cnt = None, 0 170 | while new_solution is None and while_cnt < 3: 171 | while_cnt += 1 172 | new_solution = gen.func_impl( 173 | func_sig=item["prompt"], 174 | model=model, 175 | strategy="reflexion", 176 | prev_func_impl=this_node.solution, 177 | feedback=this_node.test_feedback, 178 | self_reflection=reflection, 179 | temperature=0 180 | ) 181 | is_passing, feedback, reward = gen_test_eval(exe, new_solution, tests_i,prev=item["prev"]) 182 | num_try += 1 183 | if is_passing: 184 | is_solved = exe.evaluate( 185 | item["entry_point"], new_solution, item["test"], timeout=1, prev=item["prev"]) 186 | if "weaker_test" in item.keys(): 187 | is_weaker_solved = exe.evaluate( 188 | item["entry_point"], new_solution, item["weaker_test"], timeout=1, prev=item["prev"]) 189 | break 190 | new_node = Node(new_solution, depth=this_node.depth + 1) 191 | new_node.test_feedback = feedback 192 | new_node.update(reward) 193 | new_node.strategy = this_node.strategy 194 | this_node.children.append(new_node) 195 | if is_passing: break 196 | this_node.children.reverse() 197 | this_node.sort_children_by_value() 198 | stack.extend(this_node.children) 199 | print("Children after sorting: ", [a.value for a in stack]) 200 | if num_try >= max_iters: debug_thoroughed_case += 1 201 | if is_passing: 202 | if is_solved: 203 | num_success += int(is_solved) 204 | passed_at_sample.append(num_try) 205 | if "difficulty" in item.keys(): pass_problem_subset.append(item["difficulty"]) 206 | else: 207 | sad_case 
+= 1 208 | print("Sad, pass but not solve") 209 | if debuged: enter_debug_case += 1 210 | if is_weaker_solved: 211 | weak_success += int(is_weaker_solved) 212 | item["acc"] = round(num_success / (idx + 1), 3) 213 | item["weak_acc"] = round(weak_success / (idx + 1), 3) 214 | write_jsonl(log_path, [item], append=True) 215 | print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={item["weak_acc"]}, pass no solve: {sad_case}, enter debug: {debug_case}, exhaust: {debug_thoroughed_case}',"skip", skip) 216 | continue # early stop on this case if passsed 217 | 218 | print("_______________________________") 219 | print(passed_at_sample) 220 | print(sorted(passed_at_sample)) 221 | print(len(passed_at_sample)) 222 | print(Counter(passed_at_sample)) 223 | print("Passed but not solved case", sad_case) 224 | print("not sample 2 even when asked: ", count) 225 | print("20 tries used still not solve:", debug_thoroughed_case) 226 | print("Pass not solve after debugging", enter_debug_case) 227 | print(Counter(pass_problem_subset)) 228 | print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={round(weak_success / (idx + 1), 3)}') 229 | -------------------------------------------------------------------------------- /executors/__init__.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 
18 | # */ 19 | # 20 | 21 | from .py_executor import PyExecutor 22 | from .factory import executor_factory -------------------------------------------------------------------------------- /executors/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/executors/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /executors/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/executors/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /executors/__pycache__/executor_types.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/executors/__pycache__/executor_types.cpython-310.pyc -------------------------------------------------------------------------------- /executors/__pycache__/executor_types.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/executors/__pycache__/executor_types.cpython-311.pyc -------------------------------------------------------------------------------- /executors/__pycache__/executor_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/executors/__pycache__/executor_utils.cpython-310.pyc -------------------------------------------------------------------------------- /executors/__pycache__/executor_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/executors/__pycache__/executor_utils.cpython-311.pyc -------------------------------------------------------------------------------- /executors/__pycache__/factory.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/executors/__pycache__/factory.cpython-310.pyc -------------------------------------------------------------------------------- /executors/__pycache__/factory.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/executors/__pycache__/factory.cpython-311.pyc -------------------------------------------------------------------------------- /executors/__pycache__/py_executor.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/executors/__pycache__/py_executor.cpython-310.pyc -------------------------------------------------------------------------------- 
/executors/__pycache__/py_executor.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/executors/__pycache__/py_executor.cpython-311.pyc -------------------------------------------------------------------------------- /executors/executor_types.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 18 | # */ 19 | # 20 | 21 | from typing import NamedTuple, List, Tuple 22 | from abc import ABC, abstractmethod 23 | 24 | class ExecuteResult(NamedTuple): 25 | is_passing: bool 26 | feedback: str 27 | state: Tuple[bool] 28 | 29 | class Executor(ABC): 30 | @abstractmethod 31 | def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult: 32 | ... 33 | 34 | @abstractmethod 35 | def evaluate(self, name: str, func: str, test: str, timeout: int = 5) -> bool: 36 | ... 37 | 38 | # class Executor: 39 | # def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult: 40 | # raise NotImplementedError 41 | 42 | # def evaluate(self, name: str, func: str, test: str, timeout: int = 5) -> bool: 43 | # raise NotImplementedError 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /executors/executor_utils.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 
18 | # */ 19 | # 20 | 21 | def timeout_handler(_, __): 22 | raise TimeoutError() 23 | import os, json 24 | 25 | 26 | def to_jsonl(dict_data, file_path): 27 | with open(file_path, 'a') as file: 28 | json_line = json.dumps(dict_data) 29 | file.write(json_line + os.linesep) 30 | 31 | 32 | from threading import Thread, Event 33 | import threading 34 | 35 | class PropagatingThread(Thread): 36 | def __init__(self, *args, **kwargs): 37 | super().__init__(*args, **kwargs) 38 | self.stop_event = Event() 39 | self.exc = None 40 | def run(self): 41 | # get the start time 42 | self.exc = None 43 | try: 44 | if hasattr(self, '_Thread__target'): 45 | # Thread uses name mangling prior to Python 3. 46 | self.ret = self._Thread__target(*self._Thread__args, **self._Thread__kwargs) 47 | else: 48 | self.ret = self._target(*self._args, **self._kwargs) 49 | except BaseException as e: 50 | self.exc = e 51 | 52 | def join(self, timeout=None): 53 | super(PropagatingThread, self).join(timeout=timeout) 54 | if self.exc: 55 | raise self.exc 56 | # if not hasattr(self, 'res'): raise SystemExit() 57 | return self.ret if hasattr(self, 'ret') else None 58 | 59 | def stop(self): 60 | self.stop_event.set() 61 | 62 | def should_stop(self): 63 | return self.stop_event.is_set() 64 | 65 | 66 | def function_with_timeout(func, args, timeout): 67 | # print("func",func, end="!!!\n") 68 | result_container = [] 69 | 70 | def wrapper(): 71 | result_container.append(func(*args)) 72 | 73 | thread = PropagatingThread(target=wrapper) 74 | thread.start() 75 | 76 | 77 | thread.join(timeout=timeout) 78 | 79 | if thread.is_alive(): 80 | print("Still Alive") 81 | thread.stop_event.set() 82 | thread.join(0.01) 83 | raise TimeoutError() 84 | else: 85 | return result_container[0] 86 | 87 | # Py tests 88 | 89 | # if __name__ == "__main__": 90 | # formatter = PySubmissionFormatter() 91 | # leetcode_1 = 'class Solution:\n def solveSudoku(self, board: List[List[str]]) -> None:\n """\n Do not return anything, modify board in-place instead.\n """\n ' 92 | # humaneval_1 = 'def solveSudoku(self, board: List[List[str]]) -> None:\n """\n Do not return anything, modify board in-place instead.\n """\n' 93 | 94 | # assert leetcode_1 == formatter.to_leetcode(humaneval_1) 95 | # assert humaneval_1 == formatter.to_humaneval(leetcode_1) -------------------------------------------------------------------------------- /executors/factory.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 
18 | # */ 19 | # 20 | 21 | from .py_executor import PyExecutor, PyExecutorPro 22 | from .executor_types import Executor 23 | 24 | def executor_factory(lang: str, is_leet: bool = False) -> Executor: 25 | if lang == "code_contests": return PyExecutorPro() 26 | if lang == "py" or lang == "python": 27 | return PyExecutor() 28 | else: 29 | raise ValueError(f"Invalid language for executor: {lang}") 30 | -------------------------------------------------------------------------------- /executors/py_executor.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 18 | # */ 19 | # 20 | 21 | import ast 22 | import signal 23 | import astunparse 24 | import subprocess 25 | from .executor_utils import function_with_timeout 26 | import sys 27 | import random 28 | import os 29 | from typing import List 30 | from .executor_types import ExecuteResult, Executor 31 | def check_syntax_error(file_path): 32 | with open(file_path, 'r') as file: 33 | source_code = file.read() 34 | try: 35 | # Try to parse the source code. 36 | ast.parse(source_code) 37 | return False, None # without error 38 | except SyntaxError as e: 39 | # If a syntax error is found, report the error. 40 | print(f"SyntaxError in code: {e}") 41 | return True, e # with error 42 | def check_syntax_error_source(source_code): 43 | try: 44 | # Try to parse the source code. 45 | ast.parse(source_code) 46 | return False, None # without error 47 | except SyntaxError as e: 48 | # If a syntax error is found, report the error. 
49 | print(f"SyntaxError in file code: {e}") 50 | return True, e # with error 51 | 52 | class PyExecutor(Executor): 53 | def execute(self, func: str, tests: List[str], timeout: int = 5, prev="") -> ExecuteResult: 54 | # Combine function code and assert statement 55 | error, syn = check_syntax_error_source(func) 56 | if error: return ExecuteResult(False, str(error), tuple([False]*len(tests))) 57 | imports = 'from typing import *\nfrom collections import *\nimport math\nimport sys\nimport random\nimport re\nimport itertools, functools, bisect, heapq, string\n' 58 | func_test_list = [f'{imports}\n{func}\n{test}\n' for test in tests] 59 | if prev: imports += f"{prev}\n" 60 | # Run the tests and collect the results 61 | success_tests = [] 62 | failed_tests = [] 63 | is_passing = True 64 | num_tests = len(func_test_list) 65 | for i in range(num_tests): 66 | try: 67 | pid = os.getpid() 68 | path = f"temp_files/temp_test{pid}.py" 69 | with open(path, "w") as temp_file: 70 | temp_file.write(func_test_list[i]) 71 | 72 | result = subprocess.run( 73 | ['python', f'{path}'], 74 | capture_output=True, 75 | text=True, 76 | timeout=timeout 77 | ) 78 | if result.returncode == 0: 79 | success_tests.append(tests[i]) 80 | continue 81 | else: 82 | new_test_line = f'{imports}\n{func}\nprint({get_call_str(tests[i])})' 83 | 84 | path = f"temp_files/temp_test_output{pid}.py" 85 | with open(path, "w") as temp_file: 86 | temp_file.write(new_test_line) 87 | 88 | result = subprocess.run( 89 | ['python', f'{path}'], 90 | capture_output=True, 91 | text=True, 92 | timeout=timeout 93 | ) 94 | output = result.stdout 95 | error = result.stderr 96 | if error: output = error + output 97 | 98 | failed_tests += [f"{tests[i]} # incorrect program's output: {output}"] 99 | is_passing = False 100 | except subprocess.TimeoutExpired: 101 | failed_tests.append(f"{tests[i]} # incorrect program's output: Program Timed Out after {timeout} seconds.") 102 | is_passing = False 103 | except Exception as e: 104 | 105 | failed_tests += [f"{tests[i]} # incorrect program's output: {e}"] 106 | is_passing = False 107 | 108 | state = [] 109 | for test in tests: 110 | if test in success_tests: 111 | state += [True] 112 | else: 113 | state += [False] 114 | 115 | state = tuple(state) 116 | feedback = "" 117 | if success_tests: feedback += "Tested passed:" 118 | for test in success_tests: 119 | feedback += f"\n{test}" 120 | 121 | if failed_tests: feedback += "\nTests failed:" 122 | for test in failed_tests: 123 | feedback += f"\n{test}" 124 | if len(feedback) > 2500: feedback = feedback[:2500] # feedback length cut 125 | return ExecuteResult(is_passing, feedback, state) 126 | 127 | def evaluate(self, name: str, func: str, test: str, timeout: int = 5, prev="") -> bool: 128 | """ 129 | Evaluates the implementation on Human-Eval Python. 
130 | 131 | probably should be written in a dataset-agnostic way but not now 132 | """ 133 | code = f"""from typing import * 134 | from collections import * 135 | import math, itertools, functools, bisect, heapq, string 136 | import sys 137 | import re 138 | {prev} 139 | {func} 140 | 141 | {test} 142 | """ 143 | try: 144 | pid = os.getpid() 145 | path = f"temp_files/temp_test{pid}.py" 146 | 147 | with open(path, "w") as temp_file: 148 | temp_file.write(code) 149 | result = subprocess.run( 150 | ['python', f'{path}'], 151 | capture_output=True, 152 | text=True, 153 | timeout=timeout 154 | ) 155 | if result.returncode == 0: 156 | return True 157 | 158 | return False 159 | except Exception as e: 160 | # print(code) 161 | # print("error during handling", e) 162 | return False 163 | 164 | def get_call_str(assert_statement: str) -> str: 165 | ast_parsed = ast.parse(assert_statement) 166 | try: 167 | call_str = ast_parsed.body[0].test.left # type: ignore 168 | except: 169 | call_str = ast_parsed.body[0].test # type: ignore 170 | 171 | return astunparse.unparse(call_str).strip() 172 | 173 | class PyExecutorPro(): 174 | def execute(self, func: str, tests: List[dict], timeout: int = 5, prev="") -> ExecuteResult: 175 | num_tests = len(tests["input"]) 176 | assert len(tests["output"]) == num_tests 177 | pid = os.getpid() 178 | path = f"temp_files/temp_test{pid}.py" 179 | with open(path, "w") as temp_file: 180 | temp_file.write(func) 181 | error, syn = check_syntax_error(path) 182 | if error: return ExecuteResult(False, str(error), tuple([False]*num_tests)) 183 | success_tests = [] 184 | failed_tests = [] 185 | is_passing = True 186 | 187 | for i in range(num_tests): 188 | input_data = tests["input"][i] 189 | expected_output = tests["output"][i] 190 | try: 191 | result = subprocess.run([sys.executable, path], input=input_data, text=True, capture_output=True, 192 | check=True, timeout=timeout) 193 | code_output = result.stdout 194 | error_output = result.stderr 195 | if result.returncode == 0: 196 | if code_output.strip() == expected_output.strip(): 197 | success_tests.append(f"Input:\n{input_data.strip()}\nOutput:\n{expected_output.strip()}") 198 | else: 199 | failed_tests.append(f"Input:\n{input_data.strip()}\nExpected Output:\n{expected_output.strip()}\nProgram's Output:\n{code_output.strip()}\n-----------------------------") 200 | is_passing = False 201 | else: 202 | is_passing = False 203 | failed_tests.append( 204 | f"Input:\n{input_data.strip()}\nExpected Output:\n{expected_output.strip()}\nProgram's Output:\n{code_output.strip()}\n-----------------------------") 205 | except subprocess.CalledProcessError as e: 206 | is_passing = False 207 | failed_tests.append(f"{input_data.strip()}\nProgram's Output: {e.stderr}") 208 | except subprocess.SubprocessError as e: 209 | failed_tests.append(f"{input_data.strip()}\nProgram's Output: {e}") 210 | is_passing = False 211 | except subprocess.TimeoutExpired: 212 | failed_tests.append(f"{input_data.strip()}\nProgram's Output: Program Timed Out after {timeout} seconds, could be read in format problem where program waiting on input.") 213 | is_passing = False 214 | except Exception as e: 215 | failed_tests.append(f"{input_data.strip()}\nProgram's Output: {e}") 216 | is_passing = False 217 | state = [] 218 | for test in tests: 219 | if test in success_tests: 220 | state += [True] 221 | else: 222 | state += [False] 223 | 224 | state = tuple(state) 225 | feedback = "Tests succeeded:" 226 | if success_tests: 227 | for test in success_tests: 228 | feedback += 
f"\n{test}" 229 | else: feedback += "\nNone" 230 | feedback += "\nTests failed:" 231 | if failed_tests: 232 | for test in failed_tests: 233 | feedback += f"\n{test}" 234 | else: feedback += "\nNone" 235 | if len(feedback) > 2500: feedback = feedback[:2500] 236 | return ExecuteResult(is_passing, feedback, state) 237 | 238 | def evaluate(self, name: str, func: str, test: List[dict], timeout: int = 5, prev="") -> bool: 239 | num_tests = len(test["input"]) 240 | assert len(test["output"]) == num_tests 241 | pid = os.getpid() 242 | path = f"temp_files/temp_test{pid}.py" 243 | # path = "temp_files/temp_test.py" 244 | with open(path, "w") as temp_file: 245 | temp_file.write(func) 246 | error, syn = check_syntax_error(path) 247 | if error: return ExecuteResult(False, str(error), tuple([False]*num_tests)) 248 | success_tests = [] 249 | failed_tests = [] 250 | is_passing = True 251 | 252 | for i in range(num_tests): 253 | input_data = test["input"][i] 254 | expected_output = test["output"][i] 255 | try: 256 | result = subprocess.run([sys.executable, path], input=input_data, text=True, capture_output=True, 257 | check=True, timeout=timeout) 258 | code_output = result.stdout 259 | if result.returncode == 0: 260 | if code_output.strip() == expected_output.strip(): 261 | success_tests.append(f"Input:\n{input_data.strip()}\nOutput:\n{expected_output.strip()}") 262 | else: 263 | failed_tests.append(f"Input:\n{input_data.strip()}\nExpected Output:\n{expected_output.strip()}\nProgram's Output:\n{code_output.strip()}") 264 | is_passing = False 265 | else: raise ValueError("Doesn't return output correctly") 266 | except Exception as e: 267 | failed_tests.append(f"{input_data.strip()}\nProgram's Output: {e}") 268 | is_passing = False 269 | except subprocess.TimeoutExpired: 270 | failed_tests.append(f"{input_data.strip()}\nProgram's Output: Program Timed Out after {timeout} seconds.") 271 | is_passing = False 272 | if not is_passing: break # already found failed case 273 | return is_passing 274 | 275 | 276 | 277 | def get_output(func: str, assert_statement: str, timeout: int = 5) -> str: 278 | try: 279 | exec(f"from typing import *\n{func}", globals()) 280 | func_call = get_call_str(assert_statement) 281 | output = function_with_timeout(eval, (func_call, globals()), timeout=timeout) 282 | return output 283 | except TimeoutError: 284 | return "TIMEOUT" 285 | except Exception as e: 286 | return str(e) 287 | 288 | if __name__ == "__main__": 289 | func1 = """ 290 | def add(a, b): 291 | return a + b 292 | """ 293 | func = "def add(a, b):\n while True:\n x = 1\n return a + b" 294 | tests = ["assert add(1, 2) == 3", "assert add(2, 2) == 3"] 295 | print(func) 296 | print(PyExecutor().execute(func, tests, timeout=1)) 297 | -------------------------------------------------------------------------------- /generators/__init__.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 
9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 18 | # */ 19 | # 20 | 21 | import sys, os 22 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) 23 | from .py_generate import PyGenerator 24 | from .factory import generator_factory, model_factory 25 | from .model import ModelBase, GPT4, GPT35 26 | -------------------------------------------------------------------------------- /generators/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /generators/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /generators/__pycache__/factory.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/factory.cpython-310.pyc -------------------------------------------------------------------------------- /generators/__pycache__/factory.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/factory.cpython-311.pyc -------------------------------------------------------------------------------- /generators/__pycache__/generator_types.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/generator_types.cpython-310.pyc -------------------------------------------------------------------------------- /generators/__pycache__/generator_types.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/generator_types.cpython-311.pyc -------------------------------------------------------------------------------- /generators/__pycache__/generator_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/generator_utils.cpython-310.pyc -------------------------------------------------------------------------------- /generators/__pycache__/generator_utils.cpython-311.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/generator_utils.cpython-311.pyc -------------------------------------------------------------------------------- /generators/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /generators/__pycache__/model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/model.cpython-311.pyc -------------------------------------------------------------------------------- /generators/__pycache__/parse.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/parse.cpython-310.pyc -------------------------------------------------------------------------------- /generators/__pycache__/parse.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/parse.cpython-311.pyc -------------------------------------------------------------------------------- /generators/__pycache__/py_generate.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/py_generate.cpython-310.pyc -------------------------------------------------------------------------------- /generators/__pycache__/py_generate.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/CodeTree/72c9c874f831e1b4f8b4c449ac8ffeb141dc8fbd/generators/__pycache__/py_generate.cpython-311.pyc -------------------------------------------------------------------------------- /generators/factory.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 
18 | # */ 19 | # 20 | 21 | #from ..main import Codecontests 22 | import sys, os 23 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) 24 | 25 | from .py_generate import PyGenerator # TODO: change source back 26 | from .generator_types import Generator 27 | from .model import myLLM, CodeLlama, ModelBase, GPT4, GPT35, StarChat, GPTDavinci, GPT4Omini, GPT4O 28 | 29 | 30 | def generator_factory(lang: str) -> Generator: 31 | if lang == "py" or lang == "python": 32 | return PyGenerator() 33 | else: 34 | raise ValueError(f"Invalid language for generator: {lang}") 35 | 36 | 37 | def model_factory(model_name: str) -> ModelBase: 38 | if model_name == "gpt-4": 39 | return GPT4() 40 | elif "gpt-3.5-turbo" in model_name: 41 | return GPT35() 42 | elif "gpt-4o-mini" in model_name: 43 | return GPT4Omini() 44 | elif "gpt-4o" in model_name: return GPT4O() 45 | elif "meta-llama" in model_name: return myLLM(model_name) 46 | elif model_name == "starchat": 47 | return StarChat() 48 | elif model_name.startswith("codellama"): 49 | # if it has `-` in the name, version was specified 50 | kwargs = {} 51 | if "-" in model_name: 52 | kwargs["version"] = model_name.split("-")[1] 53 | return CodeLlama(**kwargs) 54 | elif model_name.startswith("text-davinci"): 55 | return GPTDavinci(model_name) 56 | else: 57 | return myLLM(model_name) 58 | #raise ValueError(f"Invalid model name: {model_name}") 59 | -------------------------------------------------------------------------------- /generators/generator_types.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 18 | # */ 19 | # 20 | 21 | from typing import List, Optional, Union 22 | from abc import abstractmethod, ABC 23 | 24 | from generators.model import ModelBase 25 | 26 | 27 | class Generator: 28 | @abstractmethod 29 | def self_reflection(self, func: str, feedback: str, model: ModelBase) -> str: 30 | ... 31 | 32 | @abstractmethod 33 | def func_impl( 34 | self, 35 | func_sig: str, 36 | model: ModelBase, 37 | strategy: str, 38 | prev_func_impl: Optional[str] = None, 39 | feedback: Optional[str] = None, 40 | self_reflection: Optional[str] = None, 41 | num_comps: int = 1, 42 | temperature: float = 0.0, 43 | ) -> Union[str, List[str]]: 44 | ... 45 | 46 | @abstractmethod 47 | def internal_tests( 48 | self, 49 | func_sig: str, 50 | model: ModelBase, 51 | max_num_tests: int = 5 52 | ) -> List[str]: 53 | 54 | ... 55 | @abstractmethod 56 | def strategy(self, 57 | func_sig: str, 58 | model: ModelBase, 59 | num_strategy: str, 60 | temperature: float = 0.0, 61 | prev_func_impl: Optional[str] = None, 62 | feedback: Optional[str] = None, 63 | given_strategy: Optional[str] = None, 64 | task: str="strategy") -> List[str]: 65 | ... 
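    # Usage sketch: the search drivers in this repository (see dfs_real.py) exercise
    # this Generator interface roughly as outlined below. Argument names follow the
    # call sites visible in dfs_real.py and generators/factory.py; identifiers such as
    # `prev_solution`, `prev_strategy`, and `test_feedback` are illustrative placeholders,
    # not symbols defined in this module.
    #
    #   from generators import generator_factory, model_factory
    #   gen = generator_factory("py")
    #   model = model_factory("gpt-4o-mini")
    #
    #   # propose several repair strategies for a failing implementation
    #   reflections = gen.strategy(prompt, model, task="reflection",
    #                              num_strategy=3,
    #                              prev_func_impl=prev_solution,
    #                              feedback=test_feedback,
    #                              given_strategy=prev_strategy,
    #                              temperature=0)
    #
    #   # regenerate the function conditioned on one of the reflections
    #   new_solution = gen.func_impl(func_sig=prompt, model=model,
    #                                strategy="reflexion",
    #                                prev_func_impl=prev_solution,
    #                                feedback=test_feedback,
    #                                self_reflection=reflections[0],
    #                                temperature=0)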
66 | 67 | def agent_eval(self, 68 | func_sig: str, 69 | model: ModelBase, 70 | temperature: float = 0.0, 71 | prev_func_impl: Optional[str] = None, 72 | feedback: Optional[str] = None, 73 | given_strategy: Optional[str] = None, 74 | task: str="stop") -> List[str]: 75 | ... -------------------------------------------------------------------------------- /generators/generator_utils.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 18 | # */ 19 | # 20 | 21 | from .model import ModelBase, Message, messages_to_str 22 | import random 23 | #from ..main import Codecontests 24 | import sys, os 25 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) 26 | from config import get_parsed_args 27 | args = get_parsed_args() 28 | Codecontests = False if args.function else True 29 | from typing import Union, List, Optional, Callable 30 | PY_STRATEGY = ("You are an AI assistant that provides strategy for Python programmers to code. You will be given a function signature and its docstring by the user. " 31 | "Your goal is to think of {py_strategy_k} " 32 | "strategies in English(Not Code) on how to approach this problem and solve it. Describe each strategy with a FEW sentences in a SINGLE Line. List and Number your strategies line by line using \"1. \"; \"2. \"; \"3. \" and so on.") 33 | Prompt_flexible = "The number of alternatives(either one or multiple) should be determined given this specific case." 34 | PY_IMPLEMENT = """You are an AI assistant who helps the user write code. The user will give you a function signature and its docstring and also suggest a strategy. You should instruct (in English) the user to implement their strategy, adding details not provided in the strategy. You must give {py_strategy_k} alternatives on how to implement the strategy exactly. Each alternative should be FEW sentences in a SINGLE line. List and number your {py_strategy_k} implementation alternatives using \"1. \", \"2. \".""" 35 | PY_REFELCTION = ("You are an AI assistant who can reflect on problem-solving solution program. You will be given a task, an incorrect function implementation and feedbacks from executing the code. Your goal is to describe the existing issue(s) and suggest methods on how to improve the code. Rules:\n" 36 | "1. From the algorithm and implementation level, there could be multiple methods to fix the error, you should provide {py_strategy_k} alternative reflections using various strategies. If the bug is clouded and ambigious, you can use alternatives as different interpretations, too.\n" 37 | "2. Each reflection should briefly describe the issues and bugs, what kind of improvement is needed, then describe how to implement the correction. 
You are allowed to restate the bug for each reflection if needed. Each reflection should start be complete and self-contained. In other words, if there are more than one bugs, they should be presented in one reflection rather than separately.\n" 38 | "3. Answer format: List and number your alternatives line by line, starting with \"1. \", \"2. \" and so on. Each reflection alternative is in a single line within a few sentences.\n") 39 | 40 | 41 | system_stop_simplified = """The user will provide a programming task along with a solution that passes all visible test cases. Your task is to further review the solution before it is judged against hidden test cases. Determine whether the solution is robust and general enough to pass unseen, valid test cases. Guideline: 42 | 1. Generalization Check: Verify that the solution uses general methods, avoiding hardcoding specific values or cases unless explicitly required. Confirm that the approach logically extends to unseen cases without special assumptions. 43 | 2. Boundary Check: Ensure all boundaries are correctly handled, including list indexing, loop start and end points, if-else conditions, and recursion exits. Look for potential off-by-one errors or boundary misses that could cause functional errors. 44 | 3. Edge Case Check: Confirm that the solution correctly handles valid edge/corner cases, such as zero, negative, empty, boundary values, or other special problem-specific situations. Note: All unseen test cases are guaranteed to follow stated data types, formats, conditions, and other constraints in the problem, no need to handle unallowed inputs. Do NOT apply redundant handling for cases that the current solution inherently manages, such as empty lists in sorting algorithms (`sorted([]) → []`), unless they explicitly fail (e.g., `max([]) → error`). 45 | 4. Major Efficiency Check: Check if the solution is within polynomial time/space complexity, if NOT, fail this check. 46 | 47 | **Response Format**: 48 | Firstly, within several sentences, follow the guideline and briefly analyze. 49 | On a new line, respond with “True” if the solution is ACCEPTABLE as-is, or “False” if NECESSARY modifications are required to handle unseen valid test cases. 50 | 51 | The following is one example of how to review: 52 | : 53 | ```python 54 | def find_first_unique(nums: list[int]) -> int: 55 | \"\"\" 56 | Find the first unique integer in a list of integers. 57 | Args: nums (list[int]): A list of integers to search through. 58 | Returns: int: The first unique integer in the list, or -1 if no unique integer is found. 59 | Examples: 60 | >>> find_first_unique([4, 5, 1, 2, 0, 4]) ==> 5 61 | >>> find_first_unique([7, 3]) ==> 7 62 | \"\"\" 63 | for i, num in enumerate(nums): 64 | if num not in nums[i:]: return num 65 | return -1 66 | ``` 67 | : 68 | 1. Generalization Check: `num not in nums[i:]` won’t handle cases where the number appears previous positions, `find_first_unique([7, 7]) ==> 7` instead of -1. Other checks are omitted for now since the solution logic is wrong. 
69 | False 70 | """ 71 | if Codecontests: 72 | PY_STRATEGY = PY_STRATEGY.replace("a function signature and its docstring", "a programming problem and its required input/output formats") 73 | PY_IMPLEMENT = PY_IMPLEMENT.replace("a function signature and its docstring", "a programming problem and its required input/output formats") 74 | PY_REFELCTION = PY_REFELCTION.replace("function implementation", "solution program to a problem") 75 | 76 | def generic_generate_func_impl( 77 | func_sig: str, 78 | model: ModelBase, 79 | strategy: str, 80 | prev_func_impl, 81 | feedback, 82 | self_reflection, 83 | num_comps, 84 | temperature, 85 | reflection_chat_instruction: str, 86 | reflection_few_shot: str, 87 | simple_chat_instruction: str, 88 | reflection_completion_instruction: str, 89 | simple_completion_instruction: str, 90 | code_block_instruction: str, 91 | parse_code_block: Callable[[str], str], 92 | add_code_block: Callable[[str], str] 93 | ) -> Union[str, List[str]]: 94 | if strategy != "reflexion" and strategy != "simple" and strategy != "self-repair": 95 | raise ValueError( 96 | f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`") 97 | if strategy == "reflexion" and (prev_func_impl is None or feedback is None or self_reflection is None): 98 | raise ValueError( 99 | f"Invalid arguments: given `strategy=reflexion` but `prev_func_impl`, `feedback`, or `self_reflection` is None") 100 | if model.is_chat: 101 | func_bodies = None 102 | if strategy == "reflexion": 103 | # message = f"{reflection_few_shot}\n[previous impl]:\n{add_code_block(prev_func_impl)}\n\n[unit test results from previous impl]:\n{feedback}\n\n[reflection on previous impl]:\n{self_reflection}\n\n[improved impl]:\n{func_sig}" 104 | prompt = f"{reflection_chat_instruction}\n{code_block_instruction}" 105 | # func_bodies is a really bad name, as it can also be just 1 string 106 | #print_messages(prompt, message) 107 | messages = [ 108 | Message( 109 | role="system", 110 | content=prompt, 111 | ), 112 | Message( 113 | role="user", # TODO: check this 114 | content=f"Here's the challenge for you:\n{func_sig}\n[implement]:\n", 115 | ), 116 | Message( 117 | role="assistant", 118 | content=f"{add_code_block(prev_func_impl)}" 119 | ), 120 | Message( 121 | role="user", 122 | content=f"[unit test results from previous implement]:\n{feedback}\n\n[reflection on previous implement]:\n", 123 | ), 124 | Message( 125 | role="assistant", 126 | content=self_reflection+"\n", 127 | ), 128 | Message( 129 | role="user", 130 | content=f"[improved implement]:\n", 131 | ), 132 | ] 133 | print(messages_to_str(messages)) 134 | func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature, max_tokens=4096) 135 | else: # Simple 136 | messages = [ 137 | Message( 138 | role="system", 139 | content=f"{simple_chat_instruction}\n{code_block_instruction}", 140 | ), 141 | Message( 142 | role="user", 143 | content=func_sig, 144 | ), 145 | ] 146 | func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature, max_tokens=4096) 147 | 148 | else: 149 | if strategy == "reflexion": 150 | prompt = f"{reflection_completion_instruction}\n{add_code_block(prev_func_impl)}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}" 151 | func_bodies = model.generate( 152 | prompt, num_comps=num_comps, temperature=temperature) 153 | else: 154 | prompt = 
f"{simple_completion_instruction}\n{func_sig}\n{code_block_instruction}" 155 | func_bodies = model.generate( 156 | prompt, num_comps=num_comps, temperature=temperature) 157 | 158 | if num_comps == 1: 159 | assert isinstance(func_bodies, str) 160 | #print("model responses!", func_bodies) 161 | func_body_str = parse_code_block(func_bodies) 162 | print_generated_func_body(func_body_str) 163 | return func_body_str 164 | 165 | else: 166 | try: 167 | func_bodies = [parse_code_block(func_body) for func_body in func_bodies] 168 | print_generated_func_body("\n\n".join(func_bodies)) 169 | except: 170 | print(func_bodies) 171 | 172 | return func_bodies 173 | 174 | 175 | def generate_with_accumulated_context( 176 | func_sig: str, 177 | model: ModelBase, 178 | strategy: str, 179 | prev_func_impl, 180 | accumulated_feedback, 181 | accumulated_reflection, 182 | num_comps, 183 | temperature, 184 | reflection_chat_instruction: str, 185 | reflection_few_shot: str, 186 | simple_chat_instruction: str, 187 | reflection_completion_instruction: str, 188 | simple_completion_instruction: str, 189 | code_block_instruction: str, 190 | parse_code_block: Callable[[str], str], 191 | add_code_block: Callable[[str], str] 192 | ) -> Union[str, List[str]]: 193 | # Ensure that the strategy is valid 194 | if strategy != "reflexion" and strategy != "simple": 195 | raise ValueError( 196 | f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`") 197 | if strategy == "reflexion" and (prev_func_impl is None or accumulated_feedback is None or accumulated_reflection is None): 198 | raise ValueError( 199 | f"Invalid arguments: given `strategy=reflexion` but `prev_func_impl`, `feedback`, or `self_reflection` is None") 200 | 201 | # Build the accumulated context from the provided feedback and reflections 202 | accumulated_context = "\n\n".join( 203 | [f"[previous impl {i+1}]:\n{add_code_block(impl)}\n[unit test results from previous impl {i+1}]:\n{feedback}\n[reflection on previous impl {i+1}]:\n{reflection}" 204 | for i, (impl, feedback, reflection) in enumerate(zip(prev_func_impl, accumulated_feedback, accumulated_reflection))] 205 | ) 206 | 207 | if model.is_chat: 208 | if strategy == "reflexion": 209 | # Constructing the message using a loop for accumulated context 210 | messages = [ 211 | Message(role="system", content=f"{reflection_chat_instruction}\n{code_block_instruction}"), 212 | Message(role="user", content=reflection_few_shot) 213 | ] 214 | 215 | for impl, feedback, reflection in zip(prev_func_impl, accumulated_feedback, accumulated_reflection): 216 | messages.append(Message(role="assistant", content=add_code_block(impl))) 217 | messages.append(Message(role="user", content=f"[unit test results from previous impl]:\n{feedback}\n\n[reflection on previous impl]:\n{reflection}")) 218 | 219 | messages.append(Message(role="user", content=f"[improved impl]:\n{func_sig}")) 220 | prompt = "\n".join([message.content for message in messages]) 221 | message = (f"{reflection_few_shot}\n{accumulated_context}\n\n[improved impl]:\n{func_sig}") 222 | print_messages(prompt, message) 223 | 224 | func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature) 225 | else: 226 | system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}" 227 | print_messages(system_prompt, func_sig) 228 | messages = [ 229 | Message(role="system", content=f"{simple_chat_instruction}\n{code_block_instruction}"), 230 | Message(role="user", content=func_sig) 231 | ] 232 | func_bodies = 
model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature) 233 | else: 234 | if strategy == "reflexion": 235 | prompt = f"{reflection_completion_instruction}\n{accumulated_context}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}" 236 | func_bodies = model.generate(prompt, num_comps=num_comps, temperature=temperature) 237 | print_messages(prompt, "") 238 | else: 239 | prompt = f"{simple_completion_instruction}\n{func_sig}\n{code_block_instruction}" 240 | func_bodies = model.generate(prompt, num_comps=num_comps, temperature=temperature) 241 | print_messages(prompt, "") 242 | 243 | if num_comps == 1: 244 | assert isinstance(func_bodies, str) 245 | func_body_str = parse_code_block(func_bodies) 246 | print_generated_func_body(func_body_str) 247 | return func_body_str 248 | 249 | else: 250 | func_bodies = [parse_code_block(func_body) for func_body in func_bodies] 251 | print_generated_func_body("\n\n".join(func_bodies)) 252 | return func_bodies 253 | 254 | 255 | def generic_generate_internal_tests( 256 | func_sig: str, 257 | model: ModelBase, 258 | max_num_tests: int, 259 | test_generation_few_shot: str, 260 | test_generation_chat_instruction: str, 261 | test_generation_completion_instruction: str, 262 | parse_tests: Callable[[str], List[str]], 263 | is_syntax_valid: Callable[[str], bool], 264 | is_react: bool = False 265 | ) -> List[str]: 266 | """Generates tests for a function.""" 267 | if model.is_chat: 268 | if is_react: 269 | messages = [ 270 | Message( 271 | role="system", 272 | content=test_generation_chat_instruction, 273 | ), 274 | Message( 275 | role="user", 276 | content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[think]:" 277 | ) 278 | ] 279 | output = model.generate_chat(messages=messages, max_tokens=2048) 280 | print(f'React test generation output: {output}') 281 | else: 282 | messages = [ 283 | Message( 284 | role="system", 285 | content=f"{test_generation_chat_instruction}\n\n{test_generation_few_shot}", 286 | ), 287 | Message( 288 | role="user", 289 | content=f"[func signature]:\n{func_sig}\n\n[unit tests]:", 290 | ) 291 | ] 292 | output = model.generate_chat(messages=messages, max_tokens=2048) 293 | else: 294 | prompt = f'{test_generation_completion_instruction}\n\nfunc signature:\n{func_sig}\nunit tests:' 295 | output = model.generate(prompt, max_tokens=2048) 296 | all_tests = parse_tests(output) # type: ignore 297 | valid_tests = [test for test in all_tests if is_syntax_valid(test)] 298 | 299 | return sample_n_random(valid_tests, max_num_tests) 300 | 301 | 302 | def generic_generate_self_reflection( 303 | func: str, 304 | feedback: str, 305 | model: ModelBase, 306 | self_reflection_chat_instruction: str, 307 | self_reflection_completion_instruction: str, 308 | add_code_block: Callable[[str], str], 309 | self_reflection_few_shot: Optional[str] = None, 310 | task = "evaluation" 311 | ) -> str: 312 | if model.is_chat: 313 | if task == "evaluation": 314 | messages = [ 315 | Message( 316 | role="system", 317 | content=self_reflection_chat_instruction, 318 | ), 319 | Message( 320 | role="user", 321 | content=f'[function impl]:\n{add_code_block(func)}\n\n[unit test results]:\n{feedback}\n\nThis function passed visible tests, please further evaluate the code. Your options are "1. Correct implementation of desired function", "2. 
Mostly correct implementation, didn\'t consider edge/corner cases", "Only fits some situations, not the desired functionality."', 322 | ) 323 | ] 324 | if self_reflection_few_shot is not None: 325 | messages = [ 326 | Message( 327 | role="system", 328 | content=self_reflection_chat_instruction, 329 | ), 330 | Message( 331 | role="user", 332 | content=f'{self_reflection_few_shot}\n\n[function impl]:\n{add_code_block(func)}\n\n[unit test results]:\n{feedback}\n\n[self-reflection]:', 333 | ) 334 | ] 335 | reflection = model.generate_chat(messages=messages) 336 | print(f'Self reflection output: {reflection}') 337 | else: 338 | messages = [ 339 | Message( 340 | role="system", 341 | content=self_reflection_chat_instruction, 342 | ), 343 | Message( 344 | role="user", 345 | content=f'[function impl]:\n{add_code_block(func)}\n\n[unit test results]:\n{feedback}\n\n[self-reflection]:', 346 | ) 347 | ] 348 | reflection = model.generate_chat(messages=messages) 349 | else: 350 | reflection = model.generate( 351 | f'{self_reflection_completion_instruction}\n{add_code_block(func)}\n\n{feedback}\n\nExplanation:') 352 | return reflection # type: ignore 353 | def generic_evaluate(func_sig: str, 354 | model: ModelBase, 355 | parse_response: Callable, 356 | task = "stop", 357 | code = "", 358 | temperature=0.0, 359 | lang="python", 360 | exe_feedback = "", 361 | code_impr=None 362 | ): 363 | """ 364 | 1. Which strategy to explore first 365 | 2. When pass_public_test, whether the current solution is acceptable, or keep exploring 366 | 3. Whether rollback to before-fix; keeping the summary of this_fix, mark as fail 367 | """ 368 | if task == "eval": 369 | messages = [ 370 | Message(role="system", content="Your task is to evaluate a strategy and corresponding implementation for solving a programming problem. You should score from 1 to 5 separately on the following aspects.\n" 371 | "Correctness: How well can the solution solve the task?\n" 372 | "Simpleness: How straightforward is the implementation given the difficulty of the problem?\n" 373 | "Generalizability: How well can this solution cover all cases, even ones not mentioned in examples?\n" 374 | "Insightfulness: Even when the solution is incorrect, how well does it point out a good direction to solve the problem?\n" 375 | "Your scores should use the follwing standards. 1: bad, 2: not too bad, 3: fair, 4: good, 5: excellent"), 376 | Message(role="user", content=f"Task Description:\n{func_sig}\n\nCode to Evaluate:\n```{lang}\n{code}\n```\nFeedback from executing the code on visible test cases:\n\n{exe_feedback}") 377 | ] 378 | elif task == "stop": 379 | messages = [ 380 | Message(role="system", content=system_stop_simplified), 381 | Message(role="user", content=f"Task Description:\n{func_sig}\n\nCode to Evaluate:\n```{lang}\n{code}\n```\nFeedback from executing the code on visible test cases:\n```\n{exe_feedback}\n```") 382 | ] 383 | print(messages_to_str(messages)) 384 | elif task == "tests": 385 | messages = [ 386 | Message(role="system", 387 | content="Your task is to evaluate the execution outputs of a code implementation. The statement and code is given by the user, and the output/expected output on a set of test cases." 388 | "Your should analyze the expected outputs and execution outputs. From a 0 to 5 range, you should give a score on how far the execution outputs are from the expected ones. 
Standards are below:\n" 389 | "\n0: Errors or time out when executing.\n" 390 | "\n1: No pattern found when comparing pairs of , errors are hard to interpret.\n" 391 | "\n2: Results abnormal for a part of cases(e.g., cannot handle negative elements; only half of it sorted).\n" 392 | "\n3: Result pairs have clear patterns(e.g., all elements offset by 1; all elements + 1; corp by value; reverse all elements...)\n" 393 | "\n4: Lack consideration of edge condition/corner cases(e.g., error only when elements are equal), otherwise correct.\n" 394 | "\n5: Results matched.\n" 395 | "\nGive your brief analysis first. Afterwards, start a new line with A SINGLE INTEGER NUMBER as your final score(0 to 5)."), 396 | Message(role="user", 397 | content=f"Task Description:\n{func_sig}\n\nCode to Evaluate:\n```{lang}\n{code}\n```\nFeedback from executing the code on visible test cases:\n\n{exe_feedback}") 398 | ] 399 | # input: test; goal: score the test output vs. expected output; for CodeContests 400 | elif task == "compare": 401 | assert code_impr is not None 402 | messages = [ 403 | Message(role="system", 404 | content="Your task is to compare a pair of solutions. The SECOND solution is a bug-fixing attempt to the FIRST solution, which fails to fix the bug. You should evaluate the attempt on whether it should be rollbacked. You should first analyze, and answer 'Rollback.' or 'Keep.'as the last word of your response."), 405 | Message(role="user", 406 | content=f"Task Description:\n{func_sig}\n\nCode to Evaluate:\n```{lang}\n{code}\n```\n\n```{lang}\n{code_impr}\n```\nFeedback from executing the code on visible test cases:\n\n{exe_feedback}") 407 | ] 408 | else: raise ValueError("task not in one of eval, tests, stop, compare") 409 | response = model.generate_chat(messages=messages, max_tokens=2048, temperature=temperature) 410 | result = parse_response(response) 411 | return result 412 | 413 | def sample_n_random(items: List[str], n: int) -> List[str]: 414 | """Sample min(n, len(items)) random items from a list""" 415 | assert n >= 0 416 | if n >= len(items): 417 | return items 418 | return random.sample(items, n) 419 | def generic_gen_strategy( 420 | func_sig: str, 421 | model: ModelBase, 422 | parse_strategy: Callable[[str], List[str]], 423 | code_combine: Callable, 424 | task = "strategy", 425 | given_strategy="", 426 | incorrect_code="", 427 | test_feedback="", 428 | temperature=0.0, 429 | lang="python", 430 | num_list = "3", 431 | ) -> List[str]: 432 | """Generates tests for a function.""" 433 | if model.is_chat: 434 | if task == "strategy": 435 | system_prompt = PY_STRATEGY.replace("{py_strategy_k}", str(num_list)) 436 | if "multiple" in str(num_list): system_prompt += Prompt_flexible 437 | # print_messages(system_prompt, func_sig) 438 | messages = [ 439 | Message(role="system", content=system_prompt), 440 | Message(role="user", content=func_sig) 441 | ] 442 | elif task == "implementation": 443 | system_prompt = PY_IMPLEMENT.replace("{py_strategy_k}", str(num_list)) #+ "\n" 444 | messages = [ 445 | Message(role="system", content=system_prompt), 446 | Message(role="user", content=f"```{lang}\n{func_sig}\n```\nHigh Level Strategy: {given_strategy}") 447 | ] 448 | elif task == "reflection": 449 | system_prompt = PY_REFELCTION.replace("{py_strategy_k}", str(num_list)) #+ "\n" 450 | if "multiple" in str(num_list): system_prompt += Prompt_flexible 451 | if given_strategy is None: given_strategy = "" 452 | messages = [ 453 | Message(role="system", content=system_prompt), 454 | Message(role="user", content= 
f"[problem] {func_sig}\n\n[proposed strategy]{given_strategy}\n```{lang}\n{incorrect_code}\n```\n"), 455 | Message(role="user", content=f"[unit test results]:\n{test_feedback}") 456 | ] 457 | else: raise ValueError("Must be in one of strategy/reflection/implementation") 458 | print(messages_to_str(messages)) 459 | func_bodies = model.generate_chat(messages=messages, max_tokens=2048, temperature=temperature) 460 | assert isinstance(func_bodies, str) 461 | func_body_str = parse_strategy(func_bodies) # how many strategies are given 462 | print_generated_func_body(func_body_str) 463 | return func_body_str 464 | else: 465 | raise ValueError("For chat models only.") 466 | def print_messages(system_message_text: str, user_message_text: str) -> None: 467 | print(f"""----------------------- SYSTEM MESSAGE -----------------------) 468 | {system_message_text} 469 | ---------------------------------------------- 470 | ----------------------- USER MESSAGE ----------------------- 471 | {user_message_text} 472 | ---------------------------------------------- 473 | """, flush=True) 474 | 475 | def print_generated_func_body(func_body_str: str) -> None: 476 | print(f"""--------------------- GENERATED FUNC BODY --------------------- 477 | {func_body_str} 478 | ------------------------------------------""") 479 | -------------------------------------------------------------------------------- /generators/model.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 
18 | # */ 19 | # 20 | 21 | from typing import List, Union, Optional, Literal 22 | import dataclasses 23 | import httpx 24 | from vllm import LLM, SamplingParams 25 | from transformers import AutoTokenizer 26 | import random 27 | import numpy as np 28 | 29 | timeout = httpx.Timeout(100.0) 30 | 31 | from tenacity import ( 32 | retry, 33 | stop_after_attempt, # type: ignore 34 | wait_random_exponential, # type: ignore 35 | ) 36 | import openai 37 | 38 | 39 | MessageRole = Literal["system", "user", "assistant"] 40 | client = openai.OpenAI(timeout=60) 41 | 42 | @dataclasses.dataclass() 43 | class Message(): 44 | role: MessageRole 45 | content: str 46 | 47 | 48 | def message_to_str(message: Message) -> str: 49 | return f"{message.role}: {message.content}" 50 | 51 | 52 | def messages_to_str(messages: List[Message]) -> str: 53 | return "\n".join([message_to_str(message) for message in messages]) 54 | 55 | 56 | @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) 57 | def gpt_completion( 58 | model: str, 59 | prompt: str, 60 | max_tokens: int = 1024, 61 | stop_strs: Optional[List[str]] = None, 62 | temperature: float = 0.0, 63 | num_comps=1, 64 | ) -> Union[List[str], str]: 65 | response = openai.Completion.create( 66 | model=model, 67 | prompt=prompt, 68 | temperature=temperature, 69 | max_tokens=max_tokens, 70 | top_p=1, 71 | frequency_penalty=0.0, 72 | presence_penalty=0.0, 73 | stop=stop_strs, 74 | n=num_comps, 75 | request_timeout=120, 76 | ) 77 | if num_comps == 1: 78 | return response.choices[0].text # type: ignore 79 | 80 | return [choice.text for choice in response.choices] # type: ignore 81 | 82 | 83 | # @retry(wait=wait_random_exponential(min=1, max=180), stop=stop_after_attempt(6)) 84 | def gpt_chat( 85 | model: str, 86 | messages: List[Message], 87 | max_tokens: int = 1024, 88 | temperature: float = 0.0, 89 | num_comps=1, 90 | ) -> Union[List[str], str]: 91 | messages=[dataclasses.asdict(message) for message in messages] 92 | messages[0]["content"] = "\nWrap any code snippet with a pair of ``` code fences in your response.\n" + messages[0]["content"] 93 | response = client.chat.completions.create( 94 | model=model, 95 | messages=messages, 96 | max_tokens=max_tokens, 97 | temperature=temperature, 98 | top_p=1, 99 | frequency_penalty=0.0, 100 | presence_penalty=0.0, 101 | n=num_comps, 102 | # request_timeout=120, 103 | ) 104 | if num_comps == 1: 105 | return response.choices[0].message.content # type: ignore 106 | print("temp", temperature) 107 | return [choice.message.content for choice in response.choices] # type: ignore 108 | 109 | 110 | class ModelBase(): 111 | def __init__(self, name: str): 112 | self.name = name 113 | self.is_chat = False 114 | 115 | def __repr__(self) -> str: 116 | return f'{self.name}' 117 | 118 | def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]: 119 | raise NotImplementedError 120 | 121 | def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0.0, num_comps=1) -> Union[List[str], str]: 122 | raise NotImplementedError 123 | 124 | 125 | class GPTChat(ModelBase): 126 | def __init__(self, model_name: str): 127 | self.name = model_name 128 | self.is_chat = True 129 | 130 | def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]: 131 | #messages[0]["content"] = "\nWrap any code snippet with a pair of 
``` code fences in your response.\n" + messages[0]["content"]
132 | return gpt_chat(self.name, messages, max_tokens, temperature, num_comps)
133 |
134 |
135 | class GPT4(GPTChat):
136 | def __init__(self):
137 | super().__init__("gpt-4")
138 |
139 | class GPT4Omini(GPTChat):
140 | def __init__(self):
141 | super().__init__("gpt-4o-mini-2024-07-18")
142 |
143 | class GPT4O(GPTChat):
144 | def __init__(self):
145 | super().__init__("gpt-4o-2024-08-06")
146 |
147 | class GPT35(GPTChat):
148 | def __init__(self):
149 | super().__init__("gpt-3.5-turbo")
150 |
151 |
152 | class GPTDavinci(ModelBase):
153 | def __init__(self, model_name: str):
154 | self.name = model_name
155 |
156 | def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0, num_comps=1) -> Union[List[str], str]:
157 | return gpt_completion(self.name, prompt, max_tokens, stop_strs, temperature, num_comps)
158 |
159 |
160 | class HFModelBase(ModelBase):
161 | """
162 | Base for huggingface chat models
163 | """
164 |
165 | def __init__(self, model_name: str, model, tokenizer, eos_token_id=None):
166 | self.name = model_name
167 | self.model = model
168 | self.tokenizer = tokenizer
169 | self.eos_token_id = eos_token_id if eos_token_id is not None else self.tokenizer.eos_token_id
170 | self.is_chat = True
171 |
172 | def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
173 | # NOTE: HF does not like temp of 0.0.
174 | if temperature < 0.0001:
175 | temperature = 0.0001
176 |
177 | prompt = self.prepare_prompt(messages)
178 |
179 | outputs = self.model.generate(
180 | prompt,
181 | max_new_tokens=min(
182 | max_tokens, self.model.config.max_position_embeddings),
183 | use_cache=True,
184 | do_sample=True,
185 | temperature=temperature,
186 | top_p=0.95,
187 | eos_token_id=self.eos_token_id,
188 | num_return_sequences=num_comps,
189 |
190 | )
191 |
192 | outs = self.tokenizer.batch_decode(outputs, skip_special_tokens=False)
193 | assert isinstance(outs, list)
194 | for i, out in enumerate(outs):
195 | assert isinstance(out, str)
196 | outs[i] = self.extract_output(out)
197 |
198 | if len(outs) == 1:
199 | return outs[0] # type: ignore
200 | else:
201 | return outs # type: ignore
202 |
203 | def prepare_prompt(self, messages: List[Message]):
204 | raise NotImplementedError
205 |
206 | def extract_output(self, output: str) -> str:
207 | raise NotImplementedError
208 |
209 |
210 | class StarChat(HFModelBase):
211 | def __init__(self):
212 | import torch
213 | from transformers import AutoModelForCausalLM, AutoTokenizer
214 | model = AutoModelForCausalLM.from_pretrained(
215 | "HuggingFaceH4/starchat-beta",
216 | torch_dtype=torch.bfloat16,
217 | device_map="auto",
218 | )
219 | tokenizer = AutoTokenizer.from_pretrained(
220 | "HuggingFaceH4/starchat-beta",
221 | )
222 | super().__init__("starchat", model, tokenizer, eos_token_id=49155)
223 |
224 | def prepare_prompt(self, messages: List[Message]):
225 | prompt = ""
226 | for i, message in enumerate(messages):
227 | prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
228 | if i == len(messages) - 1:
229 | prompt += "<|assistant|>\n"
230 |
231 | return self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
232 |
233 | def extract_output(self, output: str) -> str:
234 | out = output.split("<|assistant|>")[1]
235 | if out.endswith("<|end|>"):
236 | out = out[:-len("<|end|>")]
237 |
238 |
return out
239 |
240 | class myLLM(ModelBase):
241 | def __init__(self, model_name: str):
242 | self.name = model_name
243 | self.is_chat = True
244 | self.model = LLM(model_name, max_model_len=8192, gpu_memory_utilization=1.0)
245 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
246 |
247 | def generate_chat(self, messages: List[Message], max_tokens: int = 4096, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
248 | # tokenizer = AutoTokenizer.from_pretrained(model_name)
249 | # print(temperature)
250 | sampling_params = SamplingParams(temperature=temperature, top_p=0.95, max_tokens=max_tokens)
251 | messages = [dataclasses.asdict(message) for message in messages]
252 | messages[0]["content"] = "Wrap any code snippet with a pair of ``` code fences in your response.\n" + messages[0]["content"]
253 | # formatted_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
254 | output = self.model.chat(messages, sampling_params=sampling_params)
255 | random.seed()
256 | np.random.seed()
257 | def print_outputs(outputs):
258 | res = []
259 | for output in outputs:
260 | prompt = output.prompt
261 | generated_text = output.outputs[0].text
262 | res.append(generated_text)
263 | return "\n".join(res)
264 | return print_outputs(output)
265 |
266 | def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0.0, num_comps=1) -> Union[List[str], str]:
267 | raise NotImplementedError
268 |
269 |
270 | class CodeLlama(HFModelBase):
271 | B_INST, E_INST = "[INST]", "[/INST]"
272 | B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
273 |
274 | DEFAULT_SYSTEM_PROMPT = """\
275 | You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
276 |
277 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don't know the answer to a question, please don't share false information."""
278 |
279 | def __init__(self, version: Literal["34b", "13b", "7b"] = "34b"):
280 | import torch
281 | from transformers import AutoModelForCausalLM, AutoTokenizer
282 | tokenizer = AutoTokenizer.from_pretrained(
283 | f"codellama/CodeLlama-{version}-Instruct-hf",
284 | add_eos_token=True,
285 | add_bos_token=True,
286 | padding_side='left'
287 | )
288 | model = AutoModelForCausalLM.from_pretrained(
289 | f"codellama/CodeLlama-{version}-Instruct-hf",
290 | torch_dtype=torch.bfloat16,
291 | device_map="auto",
292 | )
293 | super().__init__("codellama", model, tokenizer)
294 |
295 | def prepare_prompt(self, messages: List[Message]):
296 | if messages[0].role != "system":
297 | messages = [
298 | Message(role="system", content=self.DEFAULT_SYSTEM_PROMPT)
299 | ] + messages
300 | messages = [
301 | Message(role=messages[1].role, content=self.B_SYS +
302 | messages[0].content + self.E_SYS + messages[1].content)
303 | ] + messages[2:]
304 | assert all([msg.role == "user" for msg in messages[::2]]) and all(
305 | [msg.role == "assistant" for msg in messages[1::2]]
306 | ), (
307 | "model only supports 'system', 'user' and 'assistant' roles, "
308 | "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
309 | )
310 | messages_tokens: List[int] = sum(
311 | [
312 | self.tokenizer.encode(
313 | f"{self.B_INST} {(prompt.content).strip()} {self.E_INST} {(answer.content).strip()} ",
314 | )
315 | for prompt, answer in zip(
316 | messages[::2],
317 | messages[1::2],
318 | )
319 | ],
320 | [],
321 | )
322 | assert messages[-1].role == "user", f"Last message must be from user, got {messages[-1].role}"
323 | messages_tokens += self.tokenizer.encode(
324 | f"{self.B_INST} {(messages[-1].content).strip()} {self.E_INST}",
325 | )
326 | # remove eos token from last message
327 | messages_tokens = messages_tokens[:-1]
328 | import torch
329 | return torch.tensor([messages_tokens]).to(self.model.device)
330 |
331 | def extract_output(self, output: str) -> str:
332 | out = output.split("[/INST]")[-1].split("</s>")[0].strip()
333 | return out
334 | -------------------------------------------------------------------------------- /generators/parse.py: --------------------------------------------------------------------------------
1 | # LICENSE HEADER MANAGED BY add-license-header
2 | #
3 | # /*
4 | # * Copyright (c) 2023, Salesforce, Inc.
5 | # * SPDX-License-Identifier: Apache-2
6 | # *
7 | # * Licensed under the Apache License, Version 2.0 (the "License");
8 | # * you may not use this file except in compliance with the License.
9 | # * You may obtain a copy of the License at
10 | # *
11 | # * http://www.apache.org/licenses/LICENSE-2.0
12 | # *
13 | # * Unless required by applicable law or agreed to in writing, software
14 | # * distributed under the License is distributed on an "AS IS" BASIS,
15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # * See the License for the specific language governing permissions and
17 | # * limitations under the License.
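For reference, `CodeLlama.prepare_prompt` above targets the standard Llama-2 instruct layout. The literal below is only an illustration of that layout for a single system + user turn (system text abbreviated), not output captured from the tokenizer.

```python
# Illustrative single-turn prompt in the Llama-2 / CodeLlama instruct format (assumption,
# based on the standard Llama-2 chat template rather than on a dump from this repo).
example_prompt = (
    "[INST] <<SYS>>\n"
    "You are a helpful, respectful and honest assistant. ...\n"
    "<</SYS>>\n\n"
    "Write a function that returns the sum of two integers. [/INST]"
)
```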
18 | # */ 19 | # 20 | 21 | import re 22 | from typing import Optional 23 | 24 | 25 | def parse_code_block(string: str, lang: str) -> Optional[str]: 26 | code = None 27 | code_pattern = fr"```{lang}\n(.*?)\n```" 28 | match = re.search(code_pattern, string, re.DOTALL) 29 | if match: 30 | code = match.group(1) 31 | if code is None: 32 | generic_code_pattern = r"```\n(.*?)\n```" 33 | match = re.search(generic_code_pattern, string, re.DOTALL) 34 | if match: 35 | code = match.group(1) 36 | if code is None: return parse_first_func(string, lang) 37 | # Remove unit tests written by itself 38 | return code 39 | # def_appear = False 40 | # new_code_lines = [] 41 | # for code_line in code.splitlines(): 42 | # if code_line.startswith("\t") or code_line.startswith(" ") or "import" in code_line: 43 | # new_code_lines.append(code_line) 44 | # elif "def" in code_line and ":" in code_line: 45 | # new_code_lines.append(code_line) 46 | # def_appear = True 47 | # elif def_appear is False: # before fisrt function appears, you can write some constant definitions. 48 | # new_code_lines.append(code_line) 49 | # return "\n".join(new_code_lines) 50 | 51 | 52 | 53 | def combine_function(docstring: str, implementation: str) -> str: 54 | impl_lines = implementation.strip().split("\n") 55 | if docstring.count("def") < 1: return None 56 | if docstring.count("def") > 1 or implementation.count("def") > 1: 57 | print("Error, many functions found.") 58 | return None 59 | # Find the function definition line in the implementation 60 | func_def_line = None 61 | for i, line in enumerate(impl_lines): 62 | if line.strip().startswith("def "): 63 | func_def_line = i 64 | break 65 | if func_def_line is None: 66 | return None 67 | impl_lines = docstring + "\n".join(impl_lines[func_def_line+1:]) 68 | return impl_lines 69 | 70 | def parse_multiple_code_block(string: str, lang: str) -> Optional[str]: 71 | list_of_code = [] 72 | for code_pattern in [fr"```{lang}\n(.*?)\n```", r"```\n(.*?)\n```"]: 73 | if re.search(code_pattern, string, re.DOTALL): 74 | matches = re.finditer(code_pattern, string, re.DOTALL) 75 | for match in matches: 76 | list_of_code.append(match.group(1)) 77 | return list_of_code 78 | return parse_first_func(string, lang) 79 | 80 | 81 | def parse_first_func(code: str, lang: str) -> Optional[str]: 82 | assert lang == "python", "Only python is supported for now. TODO: Rust" 83 | code_lines = code.split("\n") 84 | def_i = -1 85 | last_i = 0 86 | got_return = False 87 | for i, line in enumerate(code_lines): 88 | if line.startswith("def "): 89 | if def_i == -1: 90 | def_i = i 91 | else: 92 | break 93 | elif "return" in line and def_i != -1: 94 | got_return = True 95 | if line == "" and def_i != -1 and got_return: 96 | last_i = i 97 | break 98 | 99 | if last_i == 0: 100 | last_i = len(code_lines) - 1 101 | 102 | if def_i == -1: 103 | return "" 104 | 105 | return "\n".join(code_lines[def_i:last_i+1]).rstrip("[/PYTHON]") 106 | 107 | 108 | def add_code_block(string: str, lang: str) -> str: 109 | return f"```{lang}\n{string}\n```" 110 | 111 | 112 | def parse_functions_and_imports(code: str, lang: str) -> Optional[str]: 113 | assert lang == "python", "Only python is supported for now. 
TODO: Rust" 114 | code_lines = code.split("\n") 115 | filtered_lines = [] 116 | inside_function_or_class = False 117 | inside_triple_quotes = False 118 | 119 | for line in code_lines: 120 | stripped_line = line.strip() 121 | if stripped_line.startswith("import ") or stripped_line.startswith("from "): 122 | filtered_lines.append(line) 123 | elif stripped_line.startswith("def ") or stripped_line.startswith("class "): 124 | inside_function_or_class = True 125 | filtered_lines.append(line) 126 | elif inside_function_or_class: 127 | filtered_lines.append(line) 128 | if stripped_line.startswith('"""') or stripped_line.startswith("'''"): 129 | inside_triple_quotes = not inside_triple_quotes 130 | if not inside_triple_quotes and not stripped_line.startswith(" "): 131 | inside_function_or_class = False 132 | elif inside_triple_quotes: 133 | filtered_lines.append(line) 134 | if stripped_line.endswith('"""') or stripped_line.endswith("'''"): 135 | inside_triple_quotes = False 136 | 137 | # Remove lines that should be excluded 138 | result = [line for line in filtered_lines if 139 | line.strip() and not line.strip().startswith("[") and not line.strip().startswith("my_wonderful_func()")] 140 | 141 | return "\n".join(result) 142 | 143 | 144 | if __name__ == "__main__": 145 | CODE = """ 146 | import collections 147 | a = 1 148 | b = 2 149 | sub_parser = parser.add_subparsers().add_parser("frf 150 | a") 151 | 152 | def my_wonderful_func(): 153 | def useless_helper(): 154 | return 1 155 | if 1: 156 | return 1 157 | else: 158 | return ( 159 | 1, 160 | 2, 161 | ) 162 | [1,2,3,4,5] 163 | 164 | def bleh(): 165 | return aaa 166 | """ 167 | #print(parse_code_block(CODE, "python")) 168 | print(parse_functions_and_imports(CODE, "python")) 169 | CODE = """def total_match(lst1: List[str], lst2: List[str]) -> List[str]: 170 | \"\"\" 171 | Write a function that accepts two lists of strings and returns the list that has 172 | total number of chars in the all strings of the list less than the other list. 173 | 174 | if the two lists have the same number of chars, return the first list. 175 | 176 | Examples 177 | >>> total_match([], []) 178 | [] 179 | >>> total_match(['hi', 'admin'], ['hI', 'Hi']) 180 | ['hI', 'Hi'] 181 | >>> total_match(['hi', 'admin'], ['hi', 'hi', 'admin', 'project']) 182 | ['hi', 'admin'] 183 | >>> total_match(['hi', 'admin'], ['hI', 'hi', 'hi']) 184 | ['hI', 'hi', 'hi'] 185 | >>> total_match(['4'], ['1', '2', '3', '4', '5']) 186 | ['4'] 187 | \"\"\" 188 | total_chars_lst1 = sum(len(word) for word in lst1) 189 | total_chars_lst2 = sum(len(word) for word in lst2) 190 | 191 | if total_chars_lst1 < total_chars_lst2: 192 | return lst1 193 | elif total_chars_lst1 > total_chars_lst2: 194 | return lst2 195 | else: 196 | return lst1 197 | """ 198 | #print(parse_code_block(CODE, "python")) 199 | print(parse_functions_and_imports(CODE, "python")) 200 | -------------------------------------------------------------------------------- /generators/py_generate.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 
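A quick, hypothetical round trip through the helpers in `generators/parse.py` above, assuming the repository root is on `sys.path`; the `response` string is made up, but it matches the fenced-block format that `parse_code_block` expects and that `add_code_block` produces.

```python
from generators.parse import parse_code_block, add_code_block

FENCE = "`" * 3  # literal ``` fence, built here to avoid nesting code fences in this example

response = (
    "Here is my solution:\n"
    f"{FENCE}python\n"
    "def add(a, b):\n"
    "    return a + b\n"
    f"{FENCE}"
)

code = parse_code_block(response, "python")       # extracts the body between the fences
assert code is not None and code.startswith("def add")
print(add_code_block(code, "python"))             # re-wraps the extracted code in fenced form
```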
9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 18 | # */ 19 | # 20 | 21 | from .model import ModelBase, message_to_str 22 | from .generator_types import Generator 23 | from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection, generate_with_accumulated_context, generic_gen_strategy, generic_evaluate 24 | import sys, os 25 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) 26 | from typing import Optional, List, Union 27 | import ast 28 | import re 29 | from .parse import parse_code_block, add_code_block, parse_multiple_code_block, combine_function 30 | from config import get_parsed_args 31 | args = get_parsed_args() 32 | Codecontests = False if args.function else True 33 | 34 | PY_SIMPLE_COMPLETION_INSTRUCTION = "# Write the body of this function only." 35 | PY_REFLEXION_COMPLETION_INSTRUCTION = "You are a Python writing assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. Write your full implementation (restate the function signature).\n\n-----" 36 | PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = "You are a Python writing assistant. You will be given a function implementation and a series of unit tests. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as a hint when you try again later. Only provide the few sentence description in your answer, not the implementation.\n\n-----" 37 | USE_PYTHON_CODEBLOCK_INSTRUCTION = "Use a Python code block to write your response. For example:\n```python\nprint('Hello world!')\n```" 38 | PY_SIMPLE_CHAT_INSTRUCTION = "You are an AI that only responds with python code, NOT ENGLISH. You will be given a function signature and its docstring by the user. Write your full implementation (restate the function signature). Don't include test cases or printing statements in the code block." 39 | PY_SIMPLE_CHAT_INSTRUCTION_V2 = "You are an AI that only responds with only python code. You will be given a function signature and its docstring by the user. Write your full implementation (restate the function signature)." 40 | PY_REFLEXION_CHAT_INSTRUCTION = "You are an AI Python assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. Write your full implementation (restate the function signature)." 41 | PY_REFLEXION_CHAT_INSTRUCTION_V2 = "You are an AI Python assistant. You will be given your previous implementation of a function, a series of unit tests results, and your self-reflection on your previous implementation. Write your full implementation (restate the function signature)." 42 | PY_SELF_REFLECTION_CHAT_INSTRUCTION = "You are a Python programming assistant. You will be given a function implementation and a series of unit tests. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as a hint when you try again later. 
Only provide the few sentence description in your answer, not the implementation." 43 | PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = "You are a Python programming assistant. You will be given a function implementation and a series of unit test results. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as guidance when you try again later. Only provide the few sentence description in your answer, not the implementation. You will be given a few examples by the user." 44 | PY_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring.""" 45 | PY_TEST_GENERATION_CHAT_INSTRUCTION = """You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring.""" 46 | 47 | if Codecontests: 48 | PY_REFLEXION_COMPLETION_INSTRUCTION = "You are a Python writing assistant. You will be given your past solution to a problem, a series of unit tests, and a hint to improve the solution appropriately. Write your full program(include read input/print output).\n\n-----" 49 | PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = "You are a Python writing assistant. You will be given a solution to a problem and a series of unit tests. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as a hint when you try again later. Only provide the few sentence description in your answer, not the implementation.\n\n-----" 50 | USE_PYTHON_CODEBLOCK_INSTRUCTION = "Use a Python code block to write your response. For example:\n```python\nprint('Hello world!')\n```" 51 | PY_SIMPLE_CHAT_INSTRUCTION = "You are an AI that only responds with python code, NOT ENGLISH. You will be given a programming problem and its required input/output formats. Write your full implementation (include read input/print output; exclude test cases) in a code block." 52 | PY_REFLEXION_CHAT_INSTRUCTION = "You are an AI Python assistant. You will be given your past solution to a problem, a series of unit tests, and a hint to change the implementation appropriately. 
Write your full program(include read input/print output; exclude test cases).\n\n-----" 53 | 54 | 55 | class PyGenerator(Generator): 56 | def self_reflection(self, func: str, feedback: str, model: ModelBase) -> str: 57 | return generic_generate_self_reflection( 58 | func=func, 59 | feedback=feedback, 60 | model=model, 61 | self_reflection_chat_instruction=PY_SELF_REFLECTION_CHAT_INSTRUCTION, 62 | self_reflection_completion_instruction=PY_SELF_REFLECTION_COMPLETION_INSTRUCTION, 63 | add_code_block=lambda x: add_code_block(x, "python"), 64 | self_reflection_few_shot=PY_SELF_REFLECTION_CHAT_INSTRUCTION 65 | ) 66 | 67 | def func_impl( 68 | self, 69 | func_sig: str, 70 | model: ModelBase, 71 | strategy: str, 72 | prev_func_impl: Optional[str] = None, 73 | feedback: Optional[str] = None, 74 | self_reflection: Optional[str] = None, 75 | num_comps: int = 1, 76 | temperature: float = 0.8, 77 | acc_feedback: Optional[str] = None, 78 | acc_reflection: Optional[str] = None, 79 | ) -> Union[str, List[str]]: 80 | if strategy == "mcts": 81 | return generate_with_accumulated_context( 82 | func_sig=func_sig, 83 | model=model, 84 | strategy="reflexion", 85 | prev_func_impl=prev_func_impl, 86 | accumulated_feedback=acc_feedback, 87 | accumulated_reflection=acc_reflection, 88 | num_comps=num_comps, 89 | temperature=temperature, 90 | reflection_chat_instruction=PY_REFLEXION_CHAT_INSTRUCTION, 91 | reflection_few_shot=PY_REFLEXION_CHAT_INSTRUCTION, 92 | simple_chat_instruction=PY_SIMPLE_CHAT_INSTRUCTION, 93 | reflection_completion_instruction=PY_REFLEXION_COMPLETION_INSTRUCTION, 94 | simple_completion_instruction=PY_SIMPLE_COMPLETION_INSTRUCTION, 95 | code_block_instruction=USE_PYTHON_CODEBLOCK_INSTRUCTION, 96 | parse_code_block=lambda x: parse_code_block(x, "python"), 97 | add_code_block=lambda x: add_code_block(x, "python"), 98 | ) 99 | else: 100 | return generic_generate_func_impl( 101 | func_sig=func_sig, 102 | model=model, 103 | strategy=strategy, 104 | prev_func_impl=prev_func_impl, 105 | feedback=feedback, 106 | self_reflection=self_reflection, 107 | num_comps=num_comps, 108 | temperature=temperature, 109 | reflection_chat_instruction=PY_REFLEXION_CHAT_INSTRUCTION, 110 | reflection_few_shot=PY_REFLEXION_CHAT_INSTRUCTION, 111 | simple_chat_instruction=PY_SIMPLE_CHAT_INSTRUCTION, 112 | reflection_completion_instruction=PY_REFLEXION_COMPLETION_INSTRUCTION, 113 | simple_completion_instruction=PY_SIMPLE_COMPLETION_INSTRUCTION, 114 | code_block_instruction=USE_PYTHON_CODEBLOCK_INSTRUCTION, 115 | parse_code_block=lambda x: parse_code_block(x, "python"), 116 | add_code_block=lambda x: add_code_block(x, "python"), 117 | ) 118 | 119 | def internal_tests(self, func_sig: str, model: ModelBase, max_num_tests: int = 12) -> List[str]: 120 | def parse_tests(tests: str) -> List[str]: 121 | return [test.strip() for test in tests.splitlines() if "assert" in test] 122 | """ 123 | Generates tests for a function. 
124 | """ 125 | return generic_generate_internal_tests( 126 | func_sig=func_sig, 127 | model=model, 128 | max_num_tests=max_num_tests, 129 | test_generation_few_shot=PY_TEST_GENERATION_CHAT_INSTRUCTION, 130 | test_generation_chat_instruction=PY_TEST_GENERATION_CHAT_INSTRUCTION, 131 | test_generation_completion_instruction=PY_TEST_GENERATION_COMPLETION_INSTRUCTION, 132 | parse_tests=parse_tests, 133 | is_syntax_valid=py_is_syntax_valid, 134 | ) 135 | 136 | def strategy(self, 137 | func_sig: str, 138 | model: ModelBase, 139 | num_strategy: int=3, 140 | temperature: float = 0.0, 141 | prev_func_impl: Optional[str] = None, 142 | feedback: Optional[str] = None, 143 | given_strategy: Optional[str] = None, 144 | task: str="strategy") -> List[str]: 145 | def parse_strategy(strategies: str) -> List[str]: 146 | pattern = r"^\s*<\d+>(.*)" 147 | pattern2 = r"\d+\.(.*)" 148 | new_strategies = [] 149 | lines = strategies.splitlines() 150 | lines = [ele for ele in lines if ele.strip() != ''] 151 | for line in lines: 152 | if len(line) < 5: continue 153 | a = re.search(pattern, line.strip()) 154 | if a: new_strategies.append(a.groups()[0]) 155 | else: 156 | a = re.search(pattern2, line.strip()) 157 | if a: new_strategies.append(a.groups()[0]) 158 | # else: new_strategies.append(line) 159 | return new_strategies 160 | return generic_gen_strategy( 161 | func_sig=func_sig, 162 | model=model, 163 | parse_strategy=parse_strategy, 164 | code_combine=combine_function, 165 | task=task, 166 | incorrect_code=prev_func_impl, 167 | test_feedback=feedback, 168 | given_strategy=given_strategy, 169 | num_list=num_strategy 170 | ) 171 | 172 | def agent_eval(self, 173 | func_sig: str, 174 | model: ModelBase, 175 | temperature: float = 0.0, 176 | prev_func_impl: Optional[str] = None, 177 | feedback: Optional[str] = None, 178 | given_strategy: Optional[str] = None, 179 | task: str="stop") -> List[str]: 180 | 181 | def binary_stop_parser(response): 182 | lines_of_response = response.splitlines() 183 | lines_of_response = [ele for ele in lines_of_response if ele.strip() != ''] 184 | judge = True 185 | if "false" in lines_of_response[-1].lower(): judge=False 186 | elif "true" in lines_of_response[-1].lower(): judge=True 187 | else: 188 | print("Sorry, this parse of judgement doesn't seem to work.") 189 | print("Response:", response) 190 | if len(lines_of_response) > 2: return judge, "\n".join(lines_of_response[:-1]) 191 | return judge, lines_of_response[0] 192 | 193 | 194 | def test_eval_parser(response): 195 | response_last_line = response.splitlines()[-1] 196 | score = 0 197 | for ele in ["0","1","2", "3", "4", "5"]: 198 | if ele in response_last_line: score = int(ele) 199 | return score, "\n".join(response.splitlines()[:-1]) 200 | 201 | return generic_evaluate( 202 | func_sig=func_sig, 203 | model=model, 204 | parse_response=binary_stop_parser if task=="stop" else test_eval_parser, 205 | task=task, 206 | code=prev_func_impl, 207 | exe_feedback=feedback, 208 | lang="python", 209 | code_impr=None 210 | ) 211 | 212 | 213 | 214 | 215 | DUMMY_FUNC_SIG = "def func():" 216 | DUMMY_FUNC_CALL = "func()" 217 | 218 | 219 | def handle_first_line_indent(func_body: str) -> str: 220 | if func_body.startswith(" "): 221 | return func_body 222 | split = func_body.splitlines() 223 | return f" {split[0]}\n" + "\n".join(split[1:]) 224 | 225 | 226 | def handle_entire_body_indent(func_body: str) -> str: 227 | split = func_body.splitlines() 228 | res = "\n".join([" " + line for line in split]) 229 | return res 230 | 231 | 232 | def 
fix_turbo_response(func_body: str) -> str: 233 | return fix_markdown(remove_unindented_signatures(func_body)) 234 | 235 | 236 | def fix_markdown(func_body: str) -> str: 237 | return re.sub("`{3}", "", func_body) 238 | 239 | 240 | def remove_unindented_signatures(code: str) -> str: 241 | regex = r"^def\s+\w+\s*\(" 242 | 243 | before_signature = [] 244 | after_signature = [] 245 | signature_found = False 246 | 247 | for line in code.split("\n"): 248 | if re.match(regex, line): 249 | signature_found = True 250 | continue 251 | 252 | if signature_found: 253 | after_signature.append(line) 254 | else: 255 | if not line.startswith(" ") and line.strip(): 256 | line = " " + line 257 | before_signature.append(line) 258 | 259 | return "\n".join(before_signature + after_signature) 260 | 261 | 262 | def py_fix_indentation(func_body: str) -> str: 263 | func_body = fix_turbo_response(func_body) 264 | """ 265 | 3 cases: 266 | 1. good syntax 267 | 2. first line not good 268 | 3. entire body not good 269 | """ 270 | def parse_indent_rec(f_body: str, cur_state: int) -> str: 271 | f_body = fix_markdown(f_body) 272 | if cur_state > 1: 273 | return f_body 274 | code = f'{DUMMY_FUNC_SIG}\n{f_body}\n{DUMMY_FUNC_CALL}' 275 | try: 276 | exec(code) 277 | return f_body 278 | except (IndentationError, SyntaxError): 279 | p_func = handle_first_line_indent if cur_state == 0 else handle_entire_body_indent 280 | return parse_indent_rec(p_func(func_body), cur_state + 1) 281 | except Exception: 282 | return f_body 283 | return parse_indent_rec(func_body, 0) 284 | 285 | 286 | def py_is_syntax_valid(code: str) -> bool: 287 | try: 288 | ast.parse(code) 289 | return True 290 | except Exception: 291 | return False 292 | -------------------------------------------------------------------------------- /license_info.md: -------------------------------------------------------------------------------- 1 | License Info 2 | ------------ 3 | 4 | Most projects we open source should use the [Apache License v2](https://opensource.org/license/apache-2-0/) license. Samples, demos, and blog / doc code examples should instead use [CC-0](https://creativecommons.org/publicdomain/zero/1.0/). If you strongly feel your project should perhaps use a different license clause, please engage with legal team. 5 | 6 | For the ALv2 license, create a `LICENSE.txt` file (or use the one in this template repo) in the root of your repo containing: 7 | ``` 8 | Apache License Version 2.0 9 | 10 | Copyright (c) 2023 Salesforce, Inc. 11 | All rights reserved. 12 | 13 | Apache License 14 | Version 2.0, January 2004 15 | http://www.apache.org/licenses/ 16 | 17 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 18 | 19 | 1. Definitions. 20 | 21 | "License" shall mean the terms and conditions for use, reproduction, 22 | and distribution as defined by Sections 1 through 9 of this document. 23 | 24 | "Licensor" shall mean the copyright owner or entity authorized by 25 | the copyright owner that is granting the License. 26 | 27 | "Legal Entity" shall mean the union of the acting entity and all 28 | other entities that control, are controlled by, or are under common 29 | control with that entity. For the purposes of this definition, 30 | "control" means (i) the power, direct or indirect, to cause the 31 | direction or management of such entity, whether by contract or 32 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 33 | outstanding shares, or (iii) beneficial ownership of such entity. 
34 | 35 | "You" (or "Your") shall mean an individual or Legal Entity 36 | exercising permissions granted by this License. 37 | 38 | "Source" form shall mean the preferred form for making modifications, 39 | including but not limited to software source code, documentation 40 | source, and configuration files. 41 | 42 | "Object" form shall mean any form resulting from mechanical 43 | transformation or translation of a Source form, including but 44 | not limited to compiled object code, generated documentation, 45 | and conversions to other media types. 46 | 47 | "Work" shall mean the work of authorship, whether in Source or 48 | Object form, made available under the License, as indicated by a 49 | copyright notice that is included in or attached to the work 50 | (an example is provided in the Appendix below). 51 | 52 | "Derivative Works" shall mean any work, whether in Source or Object 53 | form, that is based on (or derived from) the Work and for which the 54 | editorial revisions, annotations, elaborations, or other modifications 55 | represent, as a whole, an original work of authorship. For the purposes 56 | of this License, Derivative Works shall not include works that remain 57 | separable from, or merely link (or bind by name) to the interfaces of, 58 | the Work and Derivative Works thereof. 59 | 60 | "Contribution" shall mean any work of authorship, including 61 | the original version of the Work and any modifications or additions 62 | to that Work or Derivative Works thereof, that is intentionally 63 | submitted to Licensor for inclusion in the Work by the copyright owner 64 | or by an individual or Legal Entity authorized to submit on behalf of 65 | the copyright owner. For the purposes of this definition, "submitted" 66 | means any form of electronic, verbal, or written communication sent 67 | to the Licensor or its representatives, including but not limited to 68 | communication on electronic mailing lists, source code control systems, 69 | and issue tracking systems that are managed by, or on behalf of, the 70 | Licensor for the purpose of discussing and improving the Work, but 71 | excluding communication that is conspicuously marked or otherwise 72 | designated in writing by the copyright owner as "Not a Contribution." 73 | 74 | "Contributor" shall mean Licensor and any individual or Legal Entity 75 | on behalf of whom a Contribution has been received by Licensor and 76 | subsequently incorporated within the Work. 77 | 78 | 2. Grant of Copyright License. Subject to the terms and conditions of 79 | this License, each Contributor hereby grants to You a perpetual, 80 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 81 | copyright license to reproduce, prepare Derivative Works of, 82 | publicly display, publicly perform, sublicense, and distribute the 83 | Work and such Derivative Works in Source or Object form. 84 | 85 | 3. Grant of Patent License. Subject to the terms and conditions of 86 | this License, each Contributor hereby grants to You a perpetual, 87 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 88 | (except as stated in this section) patent license to make, have made, 89 | use, offer to sell, sell, import, and otherwise transfer the Work, 90 | where such license applies only to those patent claims licensable 91 | by such Contributor that are necessarily infringed by their 92 | Contribution(s) alone or by combination of their Contribution(s) 93 | with the Work to which such Contribution(s) was submitted. 
If You 94 | institute patent litigation against any entity (including a 95 | cross-claim or counterclaim in a lawsuit) alleging that the Work 96 | or a Contribution incorporated within the Work constitutes direct 97 | or contributory patent infringement, then any patent licenses 98 | granted to You under this License for that Work shall terminate 99 | as of the date such litigation is filed. 100 | 101 | 4. Redistribution. You may reproduce and distribute copies of the 102 | Work or Derivative Works thereof in any medium, with or without 103 | modifications, and in Source or Object form, provided that You 104 | meet the following conditions: 105 | 106 | (a) You must give any other recipients of the Work or 107 | Derivative Works a copy of this License; and 108 | 109 | (b) You must cause any modified files to carry prominent notices 110 | stating that You changed the files; and 111 | 112 | (c) You must retain, in the Source form of any Derivative Works 113 | that You distribute, all copyright, patent, trademark, and 114 | attribution notices from the Source form of the Work, 115 | excluding those notices that do not pertain to any part of 116 | the Derivative Works; and 117 | 118 | (d) If the Work includes a "NOTICE" text file as part of its 119 | distribution, then any Derivative Works that You distribute must 120 | include a readable copy of the attribution notices contained 121 | within such NOTICE file, excluding those notices that do not 122 | pertain to any part of the Derivative Works, in at least one 123 | of the following places: within a NOTICE text file distributed 124 | as part of the Derivative Works; within the Source form or 125 | documentation, if provided along with the Derivative Works; or, 126 | within a display generated by the Derivative Works, if and 127 | wherever such third-party notices normally appear. The contents 128 | of the NOTICE file are for informational purposes only and 129 | do not modify the License. You may add Your own attribution 130 | notices within Derivative Works that You distribute, alongside 131 | or as an addendum to the NOTICE text from the Work, provided 132 | that such additional attribution notices cannot be construed 133 | as modifying the License. 134 | 135 | You may add Your own copyright statement to Your modifications and 136 | may provide additional or different license terms and conditions 137 | for use, reproduction, or distribution of Your modifications, or 138 | for any such Derivative Works as a whole, provided Your use, 139 | reproduction, and distribution of the Work otherwise complies with 140 | the conditions stated in this License. 141 | 142 | 5. Submission of Contributions. Unless You explicitly state otherwise, 143 | any Contribution intentionally submitted for inclusion in the Work 144 | by You to the Licensor shall be under the terms and conditions of 145 | this License, without any additional terms or conditions. 146 | Notwithstanding the above, nothing herein shall supersede or modify 147 | the terms of any separate license agreement you may have executed 148 | with Licensor regarding such Contributions. 149 | 150 | 6. Trademarks. This License does not grant permission to use the trade 151 | names, trademarks, service marks, or product names of the Licensor, 152 | except as required for reasonable and customary use in describing the 153 | origin of the Work and reproducing the content of the NOTICE file. 154 | 155 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 156 | agreed to in writing, Licensor provides the Work (and each 157 | Contributor provides its Contributions) on an "AS IS" BASIS, 158 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 159 | implied, including, without limitation, any warranties or conditions 160 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 161 | PARTICULAR PURPOSE. You are solely responsible for determining the 162 | appropriateness of using or redistributing the Work and assume any 163 | risks associated with Your exercise of permissions under this License. 164 | 165 | 8. Limitation of Liability. In no event and under no legal theory, 166 | whether in tort (including negligence), contract, or otherwise, 167 | unless required by applicable law (such as deliberate and grossly 168 | negligent acts) or agreed to in writing, shall any Contributor be 169 | liable to You for damages, including any direct, indirect, special, 170 | incidental, or consequential damages of any character arising as a 171 | result of this License or out of the use or inability to use the 172 | Work (including but not limited to damages for loss of goodwill, 173 | work stoppage, computer failure or malfunction, or any and all 174 | other commercial damages or losses), even if such Contributor 175 | has been advised of the possibility of such damages. 176 | 177 | 9. Accepting Warranty or Additional Liability. While redistributing 178 | the Work or Derivative Works thereof, You may choose to offer, 179 | and charge a fee for, acceptance of support, warranty, indemnity, 180 | or other liability obligations and/or rights consistent with this 181 | License. However, in accepting such obligations, You may act only 182 | on Your own behalf and on Your sole responsibility, not on behalf 183 | of any other Contributor, and only if You agree to indemnify, 184 | defend, and hold each Contributor harmless for any liability 185 | incurred by, or claims asserted against, such Contributor by reason 186 | of your accepting any such warranty or additional liability. 187 | 188 | END OF TERMS AND CONDITIONS 189 | ``` 190 | 191 | The shorter version of license text should be added as a comment to all Salesforce-authored source code and configuration files that support comments. This include file formats like HTML, CSS, JavaScript, XML, etc. which aren't directly code, but are still critical to your project code. Like: 192 | ``` 193 | /* 194 | * Copyright (c) 2023, Salesforce, Inc. 195 | * SPDX-License-Identifier: Apache-2 196 | * 197 | * Licensed under the Apache License, Version 2.0 (the "License"); 198 | * you may not use this file except in compliance with the License. 199 | * You may obtain a copy of the License at 200 | * 201 | * http://www.apache.org/licenses/LICENSE-2.0 202 | * 203 | * Unless required by applicable law or agreed to in writing, software 204 | * distributed under the License is distributed on an "AS IS" BASIS, 205 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 206 | * See the License for the specific language governing permissions and 207 | * limitations under the License. 208 | */ 209 | ``` 210 | 211 | Note that there are many tools that exist to do this sort of thing in an automated fashion, without having to manually edit every single file in your project. It is highly recommended that you research some of these tools for your particular language / build system. 
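As a rough illustration of the kind of automation referred to above, the sketch below prepends a short SPDX header to Python files that lack one. It is a toy example, not the `add-license-header` tool named in this repository's file headers.

```python
# Toy header-stamping sketch (illustrative only; real projects should prefer a maintained tool).
from pathlib import Path

SHORT_HEADER = (
    "# Copyright (c) 2023, Salesforce, Inc.\n"
    "# SPDX-License-Identifier: Apache-2.0\n\n"
)

def prepend_headers(root: str) -> None:
    for path in Path(root).rglob("*.py"):
        text = path.read_text(encoding="utf-8")
        if "SPDX-License-Identifier" not in text:
            path.write_text(SHORT_HEADER + text, encoding="utf-8")
```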
212 | 213 | For sample, demo, and example code, we recommend the [Unlicense](https://opensource.org/license/unlicense/) license. Create a `LICENSE.txt` file containing: 214 | ``` 215 | This is free and unencumbered software released into the public domain. 216 | 217 | Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means. 218 | 219 | In jurisdictions that recognize copyright laws, the author or authors of this software dedicate any and all copyright interest in the software to the public domain. We make this dedication for the benefit of the public at large and to the detriment of our heirs and successors. We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law. 220 | 221 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 222 | ``` 223 | 224 | No license header is required for samples, demos, and example code. 225 | -------------------------------------------------------------------------------- /llm_agent_guide.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 
18 | # */ 19 | # 20 | 21 | import openai 22 | from utils import make_printv, write_jsonl 23 | from executors import executor_factory 24 | from generators import generator_factory, model_factory 25 | from typing import List 26 | import sys 27 | from common import gen_test_eval 28 | from collections import Counter 29 | 30 | sys.set_int_max_str_digits(100000) # Increase the limit to 10000 digits 31 | 32 | ADD_HINT = "To solve the problem, you can refer the hint given by an expert, and complete the details by analyzing it first.\nHint:" 33 | 34 | # TODO: From sample to list 35 | class Node: 36 | def __init__(self, solution: str, parent=None, strategy="", reflection="", depth=0): 37 | self.solution = solution 38 | self.parent = parent 39 | self.children = [] 40 | self.value = 0 41 | self.visits = 0 42 | self.context = "" 43 | self.depth = depth 44 | self.reflection = reflection 45 | self.test_feedback = "" 46 | self.strategy=strategy 47 | self.pass_visible = False 48 | 49 | def best_child(self): 50 | if not self.children: # Check if children list is empty 51 | return None 52 | return max(self.children, key=lambda child: child.uct()) 53 | 54 | def best_child_value(self): 55 | if not self.children: # Check if children list is empty 56 | return None 57 | return max(self.children, key=lambda child: child.value) 58 | 59 | def sort_children_by_value(self): 60 | self.children.sort(key=lambda x: x.value, reverse=True) 61 | 62 | def update(self, reward: float): 63 | self.visits += 1 64 | self.value += reward 65 | 66 | # only keeps the most recent blocks 67 | 68 | def eval_node(prompt, node:Node, gen, model, max_depth=3): 69 | """ 70 | The evaluation shows not fully correct, decide whether to go on. 71 | """ 72 | if node.parent is None: parent_value=0 73 | else: parent_value = node.parent.value 74 | if node.depth >= max_depth: return False, node.value, "" 75 | agent_reward, analysis = gen.agent_eval(prompt, model, prev_func_impl=node.solution, task="tests", 76 | feedback=node.test_feedback.split("[additional review]:")[0].strip()) 77 | node.value += float(agent_reward) / 15 78 | if node.value < parent_value: return False, node.value, "" 79 | elif node.value == parent_value and node.depth > agent_reward: return False, node.value, "" 80 | return True, node.value, analysis 81 | 82 | def step_verify(gen, exe, item, solution, model): 83 | """ 84 | if pass all public test cases, run one agent review step 85 | """ 86 | is_passing, feedback, reward = gen_test_eval(exe, solution, item["given_tests"], prev=item["prev"]) 87 | if not is_passing: 88 | return False, feedback, reward 89 | else: 90 | reward = 1 91 | option, analysis = gen.agent_eval(item["prompt"], model, prev_func_impl=solution, 92 | task="stop", feedback=feedback, temperature=0) 93 | if option: return True, feedback, reward 94 | else: 95 | return False, f"{feedback}\n\n[additional review]:\n\n{analysis}", reward 96 | 97 | 98 | def rerank_list_of_nodes(list_of_nodes): 99 | return sorted(list_of_nodes, key=lambda x:x.value, reverse=True) # small value in the front 100 | 101 | def agent_guide( 102 | dataset: List[dict], 103 | model_name: str, 104 | language: str, 105 | log_path: str, 106 | verbose: bool, 107 | max_depth: int = 3, 108 | search_width: int = 10, 109 | max_iters: int=20, 110 | Codecontests: bool = False 111 | ) -> None: 112 | print("max_depth", max_depth, "search_width", search_width) 113 | pass_problem_subset = [] 114 | if Codecontests: 115 | exe = executor_factory("code_contests") 116 | else: 117 | exe = executor_factory(language, 
is_leet=False) 118 | 119 | print("Len(dataset)", len(dataset), dataset[0].keys()) 120 | gen = generator_factory(language) 121 | model = model_factory(model_name) 122 | print_v = make_printv(verbose) 123 | count, sad_case, debug_thoroughed_case, enter_debug_case = 0, 0, 0, 0 124 | num_items = len(dataset) 125 | num_success, weak_success = 0, 0 # Counter for successful solutions 126 | passed_at_sample, solve_or_not = [], [] 127 | for idx, item in enumerate(dataset): 128 | print("STARTING EXAMPLE", idx) 129 | if Codecontests: 130 | item["entry_point"] = "" 131 | else: item["given_tests"] = [test for test in item["given_tests"] if 'assert False' not in test] 132 | # Thinker Agent Preparation 133 | hints = gen.strategy(item["prompt"], model, num_strategy="multiple", task="strategy", temperature=0) 134 | if len(hints) > search_width: hints = hints[:search_width] # width cut 135 | stack = [] 136 | is_passing, is_solved, is_weaker_solved = False, False, False 137 | num_try = 0 138 | for hint in reversed(hints): 139 | new_node = Node("", strategy=hint, depth=1) 140 | stack.append(new_node) # initial placeholders for new nodes 141 | 142 | # Tree Search Start 143 | found_solution = None 144 | candidate_solution = None 145 | while stack and num_try < max_iters and not is_passing: 146 | if len(stack) == 0: break 147 | this_node = stack.pop() 148 | if this_node.depth > max_depth: continue 149 | # Solver Agent 150 | if not this_node.solution: 151 | cur_solution = gen.func_impl(item["prompt"] + f"{ADD_HINT}{this_node.strategy}\n", 152 | model, "simple", temperature=0) 153 | if not candidate_solution: candidate_solution = cur_solution 154 | 155 | # Debugger Agent 156 | else: 157 | cur_solution = gen.func_impl( 158 | func_sig=item["prompt"], 159 | model=model, 160 | strategy="reflexion", 161 | prev_func_impl=this_node.solution, 162 | feedback=this_node.test_feedback.split("[additional review]:")[0].strip(), 163 | self_reflection=this_node.reflection, 164 | temperature=0 165 | ) 166 | num_try += 1 167 | 168 | # Execute and get Feedback 169 | is_passing, feedback, reward = step_verify(gen, exe, item, cur_solution, model) 170 | print("cur solution judge", is_passing) 171 | 172 | # Update node information as parent 173 | this_node.solution = cur_solution # update the solution to real solution 174 | this_node.test_feedback = feedback # With additional critic feedback 175 | this_node.value = reward 176 | 177 | if reward > 0.99: 178 | this_node.pass_visible = True 179 | if this_node.parent and this_node.parent.pass_visible and this_node.depth==max_depth: 180 | is_passing = True 181 | this_node.value += 5/15 # reward for passing all visible 182 | if is_passing: 183 | found_solution = cur_solution 184 | break 185 | 186 | elif reward <= 0.99: # didn't pass, need debugging 187 | candidate_solution = cur_solution 188 | go_on, values, analysis = eval_node(prompt=item["prompt"],node=this_node,gen=gen,model=model,max_depth=max_depth) 189 | this_node.value = values 190 | 191 | # Continue on this node 192 | else: go_on = this_node.pass_visible 193 | 194 | if go_on: 195 | # Thinker Agent, init startegies for potential agents 196 | reflections = gen.strategy(item["prompt"], 197 | model, task="reflection", 198 | num_strategy="one or multiple (if there is)", 199 | prev_func_impl=this_node.solution, 200 | feedback=this_node.test_feedback, 201 | temperature=0, 202 | given_strategy=this_node.strategy) 203 | if len(reflections) > search_width: reflections = reflections[:search_width] 204 | for reflection in reversed(reflections): 
205 | if not reflection: continue 206 | new_node = Node(cur_solution, reflection=reflection, parent=this_node, strategy=this_node.strategy, depth=this_node.depth + 1) # init with previous code 207 | new_node.test_feedback = this_node.test_feedback 208 | this_node.children.append(new_node) # children in a reverse order 209 | stack.extend(this_node.children) 210 | if num_try >= max_iters: debug_thoroughed_case += 1 211 | # Verify that values are actually fair for all nodes. 212 | 213 | if found_solution is None: found_solution = candidate_solution 214 | 215 | is_solved = exe.evaluate( 216 | item["entry_point"], found_solution, item["test"], timeout=10, prev=item["prev"]) # early exit 217 | if "weaker_test" in item.keys(): 218 | is_weaker_solved = exe.evaluate( 219 | item["entry_point"], found_solution, item["weaker_test"], timeout=10, prev=item["prev"]) 220 | if is_solved: 221 | num_success += int(is_solved) 222 | passed_at_sample.append(num_try) 223 | if "difficulty" in item.keys(): pass_problem_subset.append(item["difficulty"]) 224 | else: 225 | sad_case += 1 226 | print("Sad, Pass but not solve") 227 | 228 | if is_weaker_solved: 229 | weak_success += int(is_weaker_solved) 230 | item["acc"] = round(num_success / (idx + 1), 3) 231 | item["weak_acc"] = round(weak_success / (idx + 1), 3) 232 | write_jsonl(log_path, [item], append=True) 233 | print_v( 234 | f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={item["weak_acc"]}, pass no solve: {sad_case}, exhaust: {debug_thoroughed_case}') 235 | 236 | print("_______________________________") 237 | print(passed_at_sample) 238 | print(sorted(passed_at_sample)) 239 | print(len(passed_at_sample)) 240 | print(Counter(passed_at_sample)) 241 | print("Passed but not solved case", sad_case) 242 | print(f"{max_iters} tries used still not solve:", debug_thoroughed_case) 243 | print(Counter(pass_problem_subset)) 244 | print_v( 245 | f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={round(weak_success / (idx + 1), 3)}') 246 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 
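To summarize the control flow of `agent_guide` above: a stack of strategy nodes is explored depth-first, each node is either freshly implemented or debugged, and new reflection nodes are pushed until the visible tests pass or the iteration budget is exhausted. The sketch below is a heavily condensed, hypothetical rendering of that loop; `propose_strategies`, `implement`, `reflect`, and `run_tests` are placeholder callables standing in for the Thinker/Solver/Debugger agents and the executor, not APIs from this repo.

```python
def tree_debug(problem, propose_strategies, implement, reflect, run_tests,
               max_depth=3, max_iters=20):
    # Seed the stack with one node per high-level strategy (reversed so the
    # first proposed strategy is explored first, as in agent_guide).
    stack = [{"strategy": s, "code": None, "depth": 1}
             for s in reversed(propose_strategies(problem))]
    tries = 0
    while stack and tries < max_iters:
        node = stack.pop()
        if node["depth"] > max_depth:
            continue
        code = implement(problem, node)      # fresh attempt or bug-fix attempt
        tries += 1
        passed, feedback = run_tests(code)
        if passed:
            return code                      # all visible tests pass: stop early
        for hint in reversed(reflect(problem, code, feedback)):
            stack.append({"strategy": hint, "code": code, "depth": node["depth"] + 1})
    return None                              # caller falls back to the best candidate seen
```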
18 | # */ 19 | # 20 | 21 | import os 22 | import argparse 23 | from reflexion import run_reflexion 24 | from utils import read_jsonl, read_jsonl_gz, read_json 25 | from bfs import run_bfs 26 | from dfs_real import run_dfs as my_dfs 27 | from strategy import strategy_guide 28 | from llm_agent_guide import agent_guide 29 | from resample_baseline import resample 30 | from common import wrap_mbpp_data, wrap_human_eval_data 31 | from datasets import load_dataset 32 | from config import get_parsed_args 33 | 34 | def strategy_factory(strategy: str): 35 | def kwargs_wrapper_gen(func, delete_keys=[]): 36 | def kwargs_wrapper(**kwargs): 37 | for key in delete_keys: 38 | del kwargs[key] 39 | return func(**kwargs) 40 | return kwargs_wrapper 41 | 42 | if strategy == "reflexion": 43 | return kwargs_wrapper_gen(run_reflexion) 44 | elif strategy == "dfs": 45 | return kwargs_wrapper_gen(my_dfs) 46 | elif strategy == "bfs": 47 | return kwargs_wrapper_gen(run_bfs) 48 | elif strategy == "strategy": 49 | return kwargs_wrapper_gen(strategy_guide) 50 | elif strategy == "agent": 51 | return kwargs_wrapper_gen(agent_guide) 52 | elif strategy == "resample": 53 | return kwargs_wrapper_gen(resample) 54 | else: 55 | raise ValueError(f"Strategy `{strategy}` is not supported") 56 | 57 | 58 | def main(args): 59 | # check if the root dir exists and create it if not 60 | if not os.path.exists(args.root_dir): 61 | os.makedirs(args.root_dir) 62 | dataset_name = os.path.basename(args.dataset_path).replace("jsonl", "") 63 | log_dir = os.path.join(args.root_dir, args.run_name) 64 | log_path = os.path.join( 65 | log_dir, f"{dataset_name}_{args.strategy}_{args.max_iters}_{args.model}_pass_at_k_{args.pass_at_k}_{args.language}.jsonl") 66 | if not os.path.exists(log_dir): 67 | os.makedirs(log_dir) 68 | 69 | # check if the strategy is valid 70 | run_strategy = strategy_factory(args.strategy) 71 | 72 | # print starting message 73 | if args.verbose: 74 | print(f""" 75 | Starting run with the following parameters: 76 | strategy: {args.strategy} 77 | pass@k: {args.pass_at_k} 78 | """) 79 | else: 80 | print(f"Logs will be saved in `{log_dir}`") 81 | 82 | # load the dataset 83 | print(f'Loading the dataset...') 84 | if args.dataset_path.endswith(".json"): 85 | dataset = read_json(args.dataset_path) 86 | elif args.dataset_path.endswith(".jsonl"): 87 | dataset = read_jsonl(args.dataset_path) 88 | elif args.dataset_path.endswith(".jsonl.gz"): 89 | dataset = read_jsonl_gz(args.dataset_path) 90 | else: 91 | raise ValueError( 92 | f"Dataset path `{args.dataset_path}` is not supported") 93 | print(f"Loaded {len(dataset)} examples") 94 | if "mbpp" in args.dataset_path: 95 | dataset = load_dataset("evalplus/mbppplus") 96 | dataset = wrap_mbpp_data(dataset["test"]) # half-half 97 | if "humaneval" in args.dataset_path: 98 | new_dataset = load_dataset("evalplus/humanevalplus") 99 | dataset = wrap_human_eval_data(dataset, new_dataset["test"]) 100 | Codecontests = False if args.function else True 101 | run_strategy( 102 | dataset=dataset, 103 | model_name=args.model, 104 | language=args.language, 105 | max_iters=args.max_iters, 106 | log_path=log_path, 107 | verbose=args.verbose, 108 | Codecontests=Codecontests 109 | ) 110 | 111 | print(f"Done! 
Check out the logs in `{log_path}`") 112 | 113 | if __name__ == "__main__": 114 | args = get_parsed_args() 115 | main(args) 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /reflexion.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 18 | # */ 19 | # 20 | 21 | from utils import enumerate_resume, make_printv, write_jsonl, resume_success_count 22 | from executors import executor_factory 23 | from generators import generator_factory, model_factory 24 | from typing import List, Dict, Tuple, Any 25 | 26 | import sys 27 | from collections import Counter 28 | sys.set_int_max_str_digits(100000) # Increase the limit to 10000 digits 29 | from common import gen_test_eval 30 | def run_reflexion( 31 | dataset: List[dict], 32 | model_name: str, 33 | max_iters: int, 34 | language: str, 35 | log_path: str, 36 | verbose: bool, 37 | is_leetcode: bool = False, 38 | Codecontests: bool = False 39 | ) -> None: 40 | if Codecontests: 41 | exe = executor_factory("code_contests") 42 | else: exe = executor_factory(language, is_leet=is_leetcode) 43 | gen = generator_factory(language) 44 | model = model_factory(model_name) 45 | print_v = make_printv(verbose) 46 | 47 | num_items = len(dataset) 48 | num_success, skip, weak_success = 0, 0, 0 # Counter for successful solutions 49 | passed_at_sample, solve_or_not = [], [] 50 | pass_problem_subset = [] 51 | 52 | for idx, item in enumerate(dataset): 53 | tests_i = item["given_tests"] 54 | if Codecontests: 55 | item["entry_point"] = "" 56 | else: 57 | tests_i = [test for test in tests_i if item['entry_point'] in test and 'assert False' not in test] 58 | is_solved, is_weaker_solved, is_passing = False, False, False 59 | num_try = 0 60 | cur_func_impl = gen.func_impl(item["prompt"], model, "simple") 61 | is_passing, feedback, reward = gen_test_eval(exe, cur_func_impl, tests_i, prev=item["prev"]) 62 | num_try += 1 63 | cur_feedback = feedback 64 | for i in range(max_iters-1): 65 | if is_passing: break 66 | reflection = gen.self_reflection( 67 | cur_func_impl, cur_feedback, model) 68 | cur_func_impl = gen.func_impl( 69 | func_sig=item["prompt"], 70 | model=model, 71 | strategy="reflexion", 72 | prev_func_impl=cur_func_impl, 73 | feedback=cur_feedback, 74 | self_reflection=reflection, 75 | ) 76 | is_passing, cur_feedback, reward = gen_test_eval(exe, cur_func_impl, tests_i, prev=item["prev"]) 77 | num_try += 1 78 | 79 | # Exit when passed public test cases. 
80 | if is_passing: 81 | is_solved = exe.evaluate( 82 | item["entry_point"], cur_func_impl, item["test"], timeout=1, prev=item["prev"]) # early exit 83 | if "weaker_test" in item.keys(): 84 | is_weaker_solved = exe.evaluate( 85 | item["entry_point"], cur_func_impl, item["weaker_test"], timeout=1, prev=item["prev"]) 86 | if is_solved: 87 | num_success += int(is_solved) 88 | passed_at_sample.append(num_try) 89 | if "difficulty" in item.keys(): pass_problem_subset.append(item["difficulty"]) 90 | if is_weaker_solved: 91 | weak_success += int(is_weaker_solved) 92 | item["weak_acc"] = round(weak_success / (idx + 1), 3) 93 | item["acc"] = round(num_success / (idx + 1), 3) 94 | write_jsonl(log_path, [item], append=True) 95 | print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={item["weak_acc"]}') 96 | continue # early stop on this case if passsed 97 | print("_______________________________") 98 | print(passed_at_sample) 99 | print(sorted(passed_at_sample)) 100 | print(len(passed_at_sample)) 101 | print(Counter(passed_at_sample)) 102 | print(Counter(pass_problem_subset)) 103 | 104 | 105 | # write_jsonl(log_path, [item], append=True) 106 | print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={round(weak_success / (idx + 1), 3)}') 107 | -------------------------------------------------------------------------------- /reflexion_codecontests.sh: -------------------------------------------------------------------------------- 1 | # export OPENAI_API_KEY="" # input your openai key if not already 2 | python main.py \ 3 | --run_name "code_4o-mini-reflexion" \ 4 | --root_dir "root" \ 5 | --dataset_path data/code_contests_test.json \ 6 | --strategy "reflexion" \ 7 | --language "py" \ 8 | --model "gpt-4o-mini" \ 9 | --pass_at_k "1" \ 10 | --max_iters 20 \ 11 | --verbose -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | astunparse==1.6.3 2 | datasets==2.18.0 3 | httpx==0.23.0 4 | jsonlines==2.0.0 5 | numpy==2.1.3 6 | openai==0.28.0 7 | tenacity==8.2.3 8 | torch==2.5.1 9 | transformers==4.32.1 10 | vllm==0.6.3 11 | -------------------------------------------------------------------------------- /resample_baseline.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 
18 | # */ 19 | # 20 | 21 | from utils import enumerate_resume, make_printv, write_jsonl, resume_success_count 22 | from executors import executor_factory 23 | import sys 24 | from common import gen_test_eval 25 | from generators import generator_factory, model_factory 26 | from typing import List, Dict, Tuple, Any 27 | import math 28 | import sys 29 | from collections import Counter 30 | sys.set_int_max_str_digits(100000) # Increase the limit to 10000 digits 31 | 32 | class Node: 33 | def __init__(self, solution: str, parent=None, context="", depth=0): 34 | self.solution = solution 35 | self.parent = parent 36 | self.children = [] 37 | self.value = 0 38 | self.visits = 0 39 | self.context = "" 40 | self.depth = depth 41 | self.reflection = "" 42 | self.test_feedback = "" 43 | 44 | def uct(self, exploration_weight=1.0): 45 | if self.visits == 0: 46 | # return float('inf') 47 | return self.value 48 | return (self.value / self.visits) + exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits) 49 | 50 | def best_child(self): 51 | if not self.children: # Check if children list is empty 52 | return None 53 | return max(self.children, key=lambda child: child.uct()) 54 | 55 | def best_child_value(self): 56 | if not self.children: # Check if children list is empty 57 | return None 58 | return max(self.children, key=lambda child: child.value) 59 | 60 | def sort_children_by_value(self): 61 | self.children.sort(key=lambda x: x.value) 62 | 63 | def update(self, reward: float): 64 | self.visits += 1 65 | self.value += reward 66 | 67 | def resample( 68 | dataset: List[dict], 69 | model_name: str, 70 | language: str, 71 | max_iters: int, 72 | log_path: str, 73 | verbose: bool, 74 | is_leetcode: bool = False, 75 | Codecontests: bool = False 76 | ) -> None: 77 | if Codecontests: 78 | exe = executor_factory("code_contests") 79 | else: exe = executor_factory(language, is_leet=is_leetcode) 80 | 81 | pass_problem_subset = [] 82 | gen = generator_factory(language) 83 | model = model_factory(model_name) 84 | print_v = make_printv(verbose) 85 | 86 | num_items = len(dataset) 87 | num_success, weak_success = 0, 0 # Counter for successful solutions 88 | passed_at_sample, solve_or_not = [], [] 89 | 90 | for idx, item in enumerate(dataset): 91 | print("STARTING EXAMPLE", idx) 92 | tests_i = item["given_tests"] 93 | if Codecontests: 94 | item["entry_point"] = "" 95 | else: 96 | tests_i = [test for test in tests_i if item['entry_point'] in test and 'assert False' not in test] 97 | root = Node("") 98 | stack = [root] # implementations 99 | is_solved, is_weaker_solved = False, False 100 | num_try = 0 101 | for i in range(max_iters): 102 | cur_func_impl = None 103 | while cur_func_impl is None: 104 | cur_func_impl = gen.func_impl(item["prompt"], model, "simple", temperature=1.0) 105 | stack.append(Node(cur_func_impl)) 106 | stack[0].children.append(stack[-1]) 107 | is_passing, feedback, reward = gen_test_eval(exe, cur_func_impl, tests_i, prev=item["prev"]) 108 | num_try += 1 109 | stack[-1].update(reward) 110 | stack[-1].test_feedback = feedback 111 | if is_passing: 112 | is_solved = exe.evaluate( 113 | item["entry_point"], cur_func_impl, item["test"], timeout=1, prev=item["prev"]) # early exit 114 | if "weaker_test" in item.keys(): 115 | is_weaker_solved = exe.evaluate( 116 | item["entry_point"], cur_func_impl, item["weaker_test"], timeout=1, prev=item["prev"]) 117 | break 118 | # Exit when passed public test cases. 
119 | if is_passing: 120 | if is_solved: 121 | num_success += int(is_solved) 122 | passed_at_sample.append(num_try) 123 | if "difficulty" in item.keys(): pass_problem_subset.append(item["difficulty"]) 124 | else: print("sad, passed but no solve.") 125 | if is_weaker_solved: 126 | weak_success += int(is_weaker_solved) 127 | item["acc"] = round(num_success / (idx + 1), 3) 128 | item["weak_acc"] = round(weak_success / (idx + 1), 3) 129 | 130 | write_jsonl(log_path, [item], append=True) 131 | print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={round(weak_success / (idx + 1), 3)}') 132 | continue # early stop on this case if passsed 133 | print("_______________________________") 134 | print(passed_at_sample) 135 | print(sorted(passed_at_sample)) 136 | print(len(passed_at_sample)) 137 | print(Counter(passed_at_sample)) 138 | print(Counter(pass_problem_subset)) 139 | 140 | 141 | # write_jsonl(log_path, [item], append=True) 142 | print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={round(weak_success / (idx + 1), 3)}') 143 | -------------------------------------------------------------------------------- /root/check_test.py: -------------------------------------------------------------------------------- 1 | # LICENSE HEADER MANAGED BY add-license-header 2 | # 3 | # /* 4 | # * Copyright (c) 2023, Salesforce, Inc. 5 | # * SPDX-License-Identifier: Apache-2 6 | # * 7 | # * Licensed under the Apache License, Version 2.0 (the "License"); 8 | # * you may not use this file except in compliance with the License. 9 | # * You may obtain a copy of the License at 10 | # * 11 | # * http://www.apache.org/licenses/LICENSE-2.0 12 | # * 13 | # * Unless required by applicable law or agreed to in writing, software 14 | # * distributed under the License is distributed on an "AS IS" BASIS, 15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # * See the License for the specific language governing permissions and 17 | # * limitations under the License. 18 | # */ 19 | # 20 | 21 | import sys 22 | 23 | 24 | def combine_function(docstring: str, implementation: str) -> str: 25 | impl_lines = implementation.strip().split("\n") 26 | # Find the function definition line in the implementation 27 | func_def_line = None 28 | for i, line in enumerate(impl_lines): 29 | if line.strip().startswith("def "): 30 | func_def_line = i 31 | break 32 | if func_def_line is None: 33 | raise ValueError("Function definition not found in the implementation") 34 | impl_lines = docstring + "\n".join(impl_lines[func_def_line+1:]) 35 | return impl_lines 36 | 37 | if __name__ == "__main__": 38 | # Example usage 39 | docstring = ''' 40 | def has_close_elements(numbers: List[float], threshold: float) -> bool: 41 | """ Check if in given list of numbers, are any two numbers closer to each other than 42 | given threshold. 
43 |     >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
44 |     False
45 |     >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
46 |     True
47 |     """
48 | '''
49 |     implementation = """def has_close_elements(numbers: List[float], threshold: float) -> bool:
50 | 
51 |     for idx, elem in enumerate(numbers):
52 |         for idx2, elem2 in enumerate(numbers):
53 |             if idx != idx2:
54 |                 distance = abs(elem - elem2)
55 |                 if distance < threshold:
56 |                     return True
57 | 
58 |     return False
59 | """
60 | 
61 |     combined_function = combine_function(docstring, implementation)
62 |     print(combined_function)
63 |     PY_IMPLEMENT = ("You are an AI assistant that helps the user write code. You will be given a function signature and its docstring by the user. The user would also suggest a strategy. "
64 |                 "You should instruct the user how to implement it (in English, not with code), adding details not provided in the strategy. You should give {py_strategy_k} distinct alternatives, each potentially correct, for how to ground this idea. "
65 |                 "e.g., if the idea was linear greedy search, it could be done through a forward or backward scan. Each alternative should be several sentences in one line. List and number your {py_strategy_k} implementation alternatives line by line using \"1. \", \"2. \", \"3. \" and so on.")
66 |     PY_REFLECTION = ("You are an AI assistant that provides reflection. You will be given a function implementation and a series of unit tests. Your goal is to explain why the implementation is wrong as indicated by the tests, "
67 |                  "then point out a direction to fix the bug. You must provide 2 alternatives for the different possible bugs/fixes. List and number your alternatives line by line using \"1. \" and \"2. \". For each line, use a few sentences to analyze the issue from an angle, guess a possible bug, and describe how to fix it. Do not use new lines or list steps within each alternative.")
68 |     sys.stdout.write(PY_REFLECTION.replace("{py_strategy_k}","2"))
69 | 
70 |     # USER:
71 |     # function signature:
72 |     # def say_hi() -> str:
73 |     #     """
74 |     #     Greet as a computer.
75 |     #     """
76 |     #
77 |     # strategy: output "hello world" is a good way to greet.
78 |     #
79 |     # MODEL:
80 |     # 1. use the built-in `print(string)` function in python
81 |     # 2. use `sys.stdout.write(string)` to directly write to standard output.
82 |     #
83 |     # But the model uses new lines and lists steps one by one for a single implementation.
84 | 
85 | 
86 | 
87 | 
--------------------------------------------------------------------------------
/root/get_acc.py:
--------------------------------------------------------------------------------
1 | # LICENSE HEADER MANAGED BY add-license-header
2 | #
3 | # /*
4 | # * Copyright (c) 2023, Salesforce, Inc.
5 | # * SPDX-License-Identifier: Apache-2
6 | # *
7 | # * Licensed under the Apache License, Version 2.0 (the "License");
8 | # * you may not use this file except in compliance with the License.
9 | # * You may obtain a copy of the License at
10 | # *
11 | # * http://www.apache.org/licenses/LICENSE-2.0
12 | # *
13 | # * Unless required by applicable law or agreed to in writing, software
14 | # * distributed under the License is distributed on an "AS IS" BASIS,
15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # * See the License for the specific language governing permissions and
17 | # * limitations under the License.
18 | # */
19 | #
20 | 
21 | import json
22 | 
23 | def calculate_overall_accuracy(filename):
24 |     overall_count = 0  # total number of instances
25 |     overall_correct = 0  # total number of correct instances
26 |     running_avg = 0
27 |     prev_acc = 0
28 | 
29 |     with open(filename, 'r') as f:
30 |         count = 0  # number of instances for the current run
31 |         for line in f:
32 |             data = json.loads(line)
33 |             acc = data['acc']
34 | 
35 |             # Check for reset: a new run starts when acc jumps back to exactly 1.0 or 0.0
36 |             if (acc == 1.0 or acc == 0.0) and prev_acc != 1.0 and prev_acc != 0.0:
37 |                 # Use the last running average to find the number of correct instances for this run
38 |                 correct = int(running_avg * count)
39 | 
40 |                 # Update overall counters
41 |                 overall_count += count
42 |                 overall_correct += correct
43 | 
44 |                 # Reset for the next run
45 |                 count = 0
46 | 
47 |             # Update count for the current run
48 |             count += 1
49 | 
50 |             # Keep track of the current running average
51 |             running_avg = acc
52 |             prev_acc = acc
53 | 
54 |     # Don't forget the last run
55 |     if count > 0:
56 |         correct = int(running_avg * count)
57 |         overall_count += count
58 |         overall_correct += correct
59 | 
60 |     # Calculate overall accuracy
61 |     if overall_count == 0:
62 |         return 0, count
63 |     else:
64 |         return overall_correct / overall_count, overall_count
65 | 
66 | filename = "/Users/andyzhou/Documents/Research/LLMPlanning/programming/root/test_mcts_hard_acc_full_4tst_temp_gpt4/humaneval-py._mcts_8_gpt-4_pass_at_k_1_py.jsonl"
67 | res = calculate_overall_accuracy(filename)
68 | overall_avg = res[0]
69 | count = res[1]
70 | print(f"Overall average accuracy: {overall_avg}")
71 | print(f"Count: {count}")
72 | 
73 | 
--------------------------------------------------------------------------------
/strategy.py:
--------------------------------------------------------------------------------
1 | # LICENSE HEADER MANAGED BY add-license-header
2 | #
3 | # /*
4 | # * Copyright (c) 2023, Salesforce, Inc.
5 | # * SPDX-License-Identifier: Apache-2
6 | # *
7 | # * Licensed under the Apache License, Version 2.0 (the "License");
8 | # * you may not use this file except in compliance with the License.
9 | # * You may obtain a copy of the License at
10 | # *
11 | # * http://www.apache.org/licenses/LICENSE-2.0
12 | # *
13 | # * Unless required by applicable law or agreed to in writing, software
14 | # * distributed under the License is distributed on an "AS IS" BASIS,
15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # * See the License for the specific language governing permissions and
17 | # * limitations under the License.
18 | # */ 19 | # 20 | 21 | from utils import make_printv, write_jsonl 22 | from executors import executor_factory 23 | from generators import generator_factory, model_factory 24 | from typing import List, Dict, Tuple, Any 25 | import math 26 | import sys 27 | from collections import Counter 28 | from common import gen_test_eval 29 | sys.set_int_max_str_digits(100000) # Increase the limit to 10000 digits 30 | from config import get_parsed_args 31 | args = get_parsed_args() 32 | Codecontests = False if args.function else True 33 | ADD_HINT = "\nTo solve the problem, You can follow the hint given by an expert: " 34 | 35 | 36 | class Node: 37 | def __init__(self, solution: str, parent=None, context="", depth=0): 38 | self.solution = solution 39 | self.parent = parent 40 | self.children = [] 41 | self.value = 0 42 | self.visits = 0 43 | self.context = "" 44 | self.depth = depth 45 | self.reflection = "" 46 | self.test_feedback = "" 47 | 48 | def uct(self, exploration_weight=1.0): 49 | if self.visits == 0: 50 | return self.value 51 | return (self.value / self.visits) + exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits) 52 | 53 | def best_child(self): 54 | if not self.children: # Check if children list is empty 55 | return None 56 | return max(self.children, key=lambda child: child.uct()) 57 | 58 | def best_child_value(self): 59 | if not self.children: # Check if children list is empty 60 | return None 61 | return max(self.children, key=lambda child: child.value) 62 | 63 | def sort_children_by_value(self): 64 | self.children.sort(key=lambda x: x.value) 65 | 66 | def update(self, reward: float): 67 | self.visits += 1 68 | self.value += reward 69 | 70 | def strategy_guide( 71 | dataset: List[dict], 72 | model_name: str, 73 | language: str, 74 | log_path: str, 75 | verbose: bool, 76 | max_iters: int, 77 | Codecontests: bool, 78 | is_leetcode: bool = False 79 | ) -> None: 80 | if Codecontests: 81 | exe = executor_factory("code_contests") 82 | else: exe = executor_factory(language, is_leet=is_leetcode) 83 | 84 | gen = generator_factory(language) 85 | model = model_factory(model_name) 86 | print_v = make_printv(verbose) 87 | 88 | num_items = len(dataset) 89 | num_success, skip, weak_success = 0, 0, 0 # Counter for successful solutions 90 | passed_at_sample, solve_or_not = [], [] 91 | is_weaker_solved = False 92 | pass_problem_subset = [] 93 | 94 | for idx, item in enumerate(dataset): 95 | 96 | tests_i = item["given_tests"] 97 | if Codecontests: 98 | item["entry_point"] = "" 99 | else: 100 | tests_i = [test for test in tests_i if item['entry_point'] in test and 'assert False' not in test] 101 | 102 | hints = gen.strategy(item["prompt"], model, num_strategy=20) 103 | 104 | root = Node("") 105 | stack = [root] # implementations 106 | is_solved = False 107 | num_try = 0 108 | if len(hints) > max_iters: hints = hints[:max_iters] 109 | elif len(hints) < max_iters: hints = hints + gen.strategy(item["prompt"]+"\nThink Carefully.", model, num_strategy=20 - len(hints)) 110 | for hint in hints: 111 | cur_func_impl = None 112 | while cur_func_impl is None: 113 | cur_func_impl = gen.func_impl(item["prompt"] + f"{ADD_HINT}{hint}\n", model, "simple") 114 | stack.append(Node(cur_func_impl)) 115 | stack[0].children.append(stack[-1]) # adding children to root 116 | is_passing, feedback, reward = gen_test_eval(exe, cur_func_impl, tests_i, prev=item["prev"]) 117 | num_try += 1 118 | stack[-1].update(reward) 119 | stack[-1].test_feedback = feedback 120 | if is_passing: 121 | is_solved = 
exe.evaluate(
122 |                     item["entry_point"], cur_func_impl, item["test"], timeout=1, prev=item["prev"]) # early exit
123 |                 if "weaker_test" in item.keys():
124 |                     is_weaker_solved = exe.evaluate(
125 |                         item["entry_point"], cur_func_impl, item["weaker_test"], timeout=1, prev=item["prev"])
126 |                 break
127 |         # Exit when passed public test cases.
128 |         if is_passing:
129 |             if is_solved:
130 |                 num_success += int(is_solved)
131 |                 passed_at_sample.append(num_try)
132 |                 if "difficulty" in item.keys(): pass_problem_subset.append(item["difficulty"])
133 |             if is_weaker_solved:
134 |                 weak_success += int(is_weaker_solved)
135 |             item["weak_acc"] = round(weak_success / (idx + 1), 3)
136 |             item["acc"] = round(num_success / (idx + 1), 3)
137 |             write_jsonl(log_path, [item], append=True)
138 |             print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={item["weak_acc"]}')
139 |             continue # early stop on this case if passed
140 |         print("_______________________________")
141 |         print(passed_at_sample)
142 |         print(sorted(passed_at_sample))
143 |         print(len(passed_at_sample))
144 |         print(Counter(passed_at_sample))
145 |         print(Counter(pass_problem_subset))
146 |     # write_jsonl(log_path, [item], append=True)
147 |     print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={round(weak_success / (idx + 1), 3)}')
148 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # LICENSE HEADER MANAGED BY add-license-header
2 | #
3 | # /*
4 | # * Copyright (c) 2023, Salesforce, Inc.
5 | # * SPDX-License-Identifier: Apache-2
6 | # *
7 | # * Licensed under the Apache License, Version 2.0 (the "License");
8 | # * you may not use this file except in compliance with the License.
9 | # * You may obtain a copy of the License at
10 | # *
11 | # * http://www.apache.org/licenses/LICENSE-2.0
12 | # *
13 | # * Unless required by applicable law or agreed to in writing, software
14 | # * distributed under the License is distributed on an "AS IS" BASIS,
15 | # * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # * See the License for the specific language governing permissions and
17 | # * limitations under the License.
18 | # */ 19 | # 20 | 21 | import os 22 | import gzip 23 | import json 24 | import openai 25 | import jsonlines 26 | 27 | from typing import List 28 | 29 | openai.api_key = os.getenv("OPENAI_API_KEY") 30 | 31 | 32 | def make_printv(verbose: bool): 33 | def print_v(*args, **kwargs): 34 | if verbose: 35 | kwargs["flush"] = True 36 | print(*args, **kwargs) 37 | else: 38 | pass 39 | return print_v 40 | 41 | def merge_test_subsets(test_subsets): 42 | # Initialize the lists for full_tests inputs and outputs 43 | full_inputs = [] 44 | full_outputs = [] 45 | 46 | # Iterate through each subset and extend the full_tests lists 47 | for subset in ['public_tests', 'private_tests']: #, 'generated_tests']: 48 | # Check if the subset exists in the test_subsets to avoid KeyError 49 | if subset in test_subsets: 50 | full_inputs.extend(test_subsets[subset]['input']) 51 | full_outputs.extend(test_subsets[subset]['output']) 52 | 53 | # Construct the full_tests dictionary 54 | full_tests = { 55 | 'input': full_inputs, 56 | 'output': full_outputs 57 | } 58 | return full_tests 59 | 60 | def read_json(path:str): 61 | with open(path) as f: 62 | data = json.load(f) 63 | for i in range(len(data)): 64 | data[i]["prompt"] = data[i]["description"].replace("\n\n\n","\n\n").replace("\n\n\n","\n\n") 65 | data[i]["given_tests"] = data[i]["public_tests"] 66 | data[i]["test"] =data[i]["private_tests"] # merge_test_subsets(data[i]) 67 | data[i]["prev"] = "" 68 | if "cf_rating" not in data[i].keys(): data[i]["cf_rating"] = 0 69 | return data 70 | 71 | def read_jsonl(path: str) -> List[dict]: 72 | if not os.path.exists(path): 73 | raise FileNotFoundError(f"File `{path}` does not exist.") 74 | elif not path.endswith(".jsonl"): 75 | raise ValueError(f"File `{path}` is not a jsonl file.") 76 | items = [] 77 | with jsonlines.open(path) as reader: 78 | for item in reader: 79 | items += [item] 80 | return items 81 | 82 | 83 | def write_jsonl(path: str, data: List[dict], append: bool = False): 84 | os.makedirs(os.path.dirname(path), exist_ok=True) 85 | with jsonlines.open(path, mode='a' if append else 'w') as writer: 86 | for item in data: 87 | writer.write(item) 88 | 89 | 90 | def read_jsonl_gz(path: str) -> List[dict]: 91 | if not path.endswith(".jsonl.gz"): 92 | raise ValueError(f"File `{path}` is not a jsonl.gz file.") 93 | with gzip.open(path, "rt") as f: 94 | data = [json.loads(line) for line in f] 95 | return data 96 | 97 | 98 | # generator that returns the item and the index in the dataset. 99 | # if the results_path exists, it will skip all items that have been processed 100 | # before. 101 | def enumerate_resume(dataset, results_path): 102 | if not os.path.exists(results_path): 103 | for i, item in enumerate(dataset): 104 | yield i, item 105 | else: 106 | count = 0 107 | with jsonlines.open(results_path) as reader: 108 | for item in reader: 109 | count += 1 110 | 111 | for i, item in enumerate(dataset): 112 | # skip items that have been processed before 113 | if i < count: 114 | continue 115 | yield i, item 116 | 117 | 118 | def resume_success_count(dataset) -> int: 119 | count = 0 120 | for item in dataset: 121 | if "is_solved" in item and item["is_solved"]: 122 | count += 1 123 | return count 124 | 125 | --------------------------------------------------------------------------------