├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── collect_mbpp.py ├── collect_nl2bash.py ├── collect_spider.py ├── collect_zeroshot.py ├── collectors.py ├── data.py ├── evaluate.py ├── exec_spider_gold.py ├── execution.py ├── fewshot_reviewer.py ├── multi_exec.py ├── process_sql.py ├── pyminifier_canonicalize.py ├── requirements.txt ├── sample_selectors.py ├── utils.py ├── utils_sql.py └── zeroshot_reviewer.py /.gitignore: -------------------------------------------------------------------------------- 1 | dataset 2 | dataset.zip 3 | samples 4 | samples.zip 5 | result_db 6 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Open Source Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 
6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | Using welcoming and inclusive language 12 | Being respectful of differing viewpoints and experiences 13 | Gracefully accepting constructive criticism 14 | Focusing on what is best for the community 15 | Showing empathy towards other community members 16 | Examples of unacceptable behavior by participants include: 17 | 18 | The use of sexualized language or imagery and unwelcome sexual attention or advances 19 | Trolling, insulting/derogatory comments, and personal or political attacks 20 | Public or private harassment 21 | Publishing others’ private information, such as a physical or electronic address, without explicit permission 22 | Other conduct which could reasonably be considered inappropriate in a professional setting 23 | 24 | ## Our Responsibilities 25 | 26 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 27 | 28 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 29 | 30 | ## Scope 31 | 32 | This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 
33 | 34 | ## Enforcement 35 | 36 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource-conduct@fb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 37 | 38 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership. 39 | 40 | ## Attribution 41 | 42 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 43 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 44 | 45 | [homepage]: https://www.contributor-covenant.org -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. 
Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## Coding Style 30 | * 2 spaces for indentation rather than tabs 31 | * 80 character line length 32 | * ... 33 | 34 | ## License 35 | By contributing, you agree that your contributions will be licensed 36 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Attribution-NonCommercial 4.0 International 2 | 3 | Creative Commons Corporation ("Creative Commons") is not a law firm and 4 | does not provide legal services or legal advice. Distribution of 5 | Creative Commons public licenses does not create a lawyer-client or 6 | other relationship. Creative Commons makes its licenses and related 7 | information available on an "as-is" basis. Creative Commons gives no 8 | warranties regarding its licenses, any material licensed under their 9 | terms and conditions, or any related information. Creative Commons 10 | disclaims all liability for damages resulting from their use to the 11 | fullest extent possible. 12 | 13 | Using Creative Commons Public Licenses 14 | 15 | Creative Commons public licenses provide a standard set of terms and 16 | conditions that creators and other rights holders may use to share 17 | original works of authorship and other material subject to copyright and 18 | certain other rights specified in the public license below. The 19 | following considerations are for informational purposes only, are not 20 | exhaustive, and do not form part of our licenses. 
21 | 22 | - Considerations for licensors: Our public licenses are intended for 23 | use by those authorized to give the public permission to use 24 | material in ways otherwise restricted by copyright and certain other 25 | rights. Our licenses are irrevocable. Licensors should read and 26 | understand the terms and conditions of the license they choose 27 | before applying it. Licensors should also secure all rights 28 | necessary before applying our licenses so that the public can reuse 29 | the material as expected. Licensors should clearly mark any material 30 | not subject to the license. This includes other CC-licensed 31 | material, or material used under an exception or limitation to 32 | copyright. More considerations for licensors : 33 | wiki.creativecommons.org/Considerations\_for\_licensors 34 | 35 | - Considerations for the public: By using one of our public licenses, 36 | a licensor grants the public permission to use the licensed material 37 | under specified terms and conditions. If the licensor's permission 38 | is not necessary for any reason–for example, because of any 39 | applicable exception or limitation to copyright–then that use is not 40 | regulated by the license. Our licenses grant only permissions under 41 | copyright and certain other rights that a licensor has authority to 42 | grant. Use of the licensed material may still be restricted for 43 | other reasons, including because others have copyright or other 44 | rights in the material. A licensor may make special requests, such 45 | as asking that all changes be marked or described. Although not 46 | required by our licenses, you are encouraged to respect those 47 | requests where reasonable. 
More considerations for the public : 48 | wiki.creativecommons.org/Considerations\_for\_licensees 49 | 50 | Creative Commons Attribution-NonCommercial 4.0 International Public 51 | License 52 | 53 | By exercising the Licensed Rights (defined below), You accept and agree 54 | to be bound by the terms and conditions of this Creative Commons 55 | Attribution-NonCommercial 4.0 International Public License ("Public 56 | License"). To the extent this Public License may be interpreted as a 57 | contract, You are granted the Licensed Rights in consideration of Your 58 | acceptance of these terms and conditions, and the Licensor grants You 59 | such rights in consideration of benefits the Licensor receives from 60 | making the Licensed Material available under these terms and conditions. 61 | 62 | - Section 1 – Definitions. 63 | 64 | - a. Adapted Material means material subject to Copyright and 65 | Similar Rights that is derived from or based upon the Licensed 66 | Material and in which the Licensed Material is translated, 67 | altered, arranged, transformed, or otherwise modified in a 68 | manner requiring permission under the Copyright and Similar 69 | Rights held by the Licensor. For purposes of this Public 70 | License, where the Licensed Material is a musical work, 71 | performance, or sound recording, Adapted Material is always 72 | produced where the Licensed Material is synched in timed 73 | relation with a moving image. 74 | - b. Adapter's License means the license You apply to Your 75 | Copyright and Similar Rights in Your contributions to Adapted 76 | Material in accordance with the terms and conditions of this 77 | Public License. 78 | - c. Copyright and Similar Rights means copyright and/or similar 79 | rights closely related to copyright including, without 80 | limitation, performance, broadcast, sound recording, and Sui 81 | Generis Database Rights, without regard to how the rights are 82 | labeled or categorized. 
For purposes of this Public License, the 83 | rights specified in Section 2(b)(1)-(2) are not Copyright and 84 | Similar Rights. 85 | - d. Effective Technological Measures means those measures that, 86 | in the absence of proper authority, may not be circumvented 87 | under laws fulfilling obligations under Article 11 of the WIPO 88 | Copyright Treaty adopted on December 20, 1996, and/or similar 89 | international agreements. 90 | - e. Exceptions and Limitations means fair use, fair dealing, 91 | and/or any other exception or limitation to Copyright and 92 | Similar Rights that applies to Your use of the Licensed 93 | Material. 94 | - f. Licensed Material means the artistic or literary work, 95 | database, or other material to which the Licensor applied this 96 | Public License. 97 | - g. Licensed Rights means the rights granted to You subject to 98 | the terms and conditions of this Public License, which are 99 | limited to all Copyright and Similar Rights that apply to Your 100 | use of the Licensed Material and that the Licensor has authority 101 | to license. 102 | - h. Licensor means the individual(s) or entity(ies) granting 103 | rights under this Public License. 104 | - i. NonCommercial means not primarily intended for or directed 105 | towards commercial advantage or monetary compensation. For 106 | purposes of this Public License, the exchange of the Licensed 107 | Material for other material subject to Copyright and Similar 108 | Rights by digital file-sharing or similar means is NonCommercial 109 | provided there is no payment of monetary compensation in 110 | connection with the exchange. 111 | - j. 
Share means to provide material to the public by any means or 112 | process that requires permission under the Licensed Rights, such 113 | as reproduction, public display, public performance, 114 | distribution, dissemination, communication, or importation, and 115 | to make material available to the public including in ways that 116 | members of the public may access the material from a place and 117 | at a time individually chosen by them. 118 | - k. Sui Generis Database Rights means rights other than copyright 119 | resulting from Directive 96/9/EC of the European Parliament and 120 | of the Council of 11 March 1996 on the legal protection of 121 | databases, as amended and/or succeeded, as well as other 122 | essentially equivalent rights anywhere in the world. 123 | - l. You means the individual or entity exercising the Licensed 124 | Rights under this Public License. Your has a corresponding 125 | meaning. 126 | 127 | - Section 2 – Scope. 128 | 129 | - a. License grant. 130 | - 1. Subject to the terms and conditions of this Public 131 | License, the Licensor hereby grants You a worldwide, 132 | royalty-free, non-sublicensable, non-exclusive, irrevocable 133 | license to exercise the Licensed Rights in the Licensed 134 | Material to: 135 | - A. reproduce and Share the Licensed Material, in whole 136 | or in part, for NonCommercial purposes only; and 137 | - B. produce, reproduce, and Share Adapted Material for 138 | NonCommercial purposes only. 139 | - 2. Exceptions and Limitations. For the avoidance of doubt, 140 | where Exceptions and Limitations apply to Your use, this 141 | Public License does not apply, and You do not need to comply 142 | with its terms and conditions. 143 | - 3. Term. The term of this Public License is specified in 144 | Section 6(a). 145 | - 4. Media and formats; technical modifications allowed. 
The 146 | Licensor authorizes You to exercise the Licensed Rights in 147 | all media and formats whether now known or hereafter 148 | created, and to make technical modifications necessary to do 149 | so. The Licensor waives and/or agrees not to assert any 150 | right or authority to forbid You from making technical 151 | modifications necessary to exercise the Licensed Rights, 152 | including technical modifications necessary to circumvent 153 | Effective Technological Measures. For purposes of this 154 | Public License, simply making modifications authorized by 155 | this Section 2(a)(4) never produces Adapted Material. 156 | - 5. Downstream recipients. 157 | - A. Offer from the Licensor – Licensed Material. Every 158 | recipient of the Licensed Material automatically 159 | receives an offer from the Licensor to exercise the 160 | Licensed Rights under the terms and conditions of this 161 | Public License. 162 | - B. No downstream restrictions. You may not offer or 163 | impose any additional or different terms or conditions 164 | on, or apply any Effective Technological Measures to, 165 | the Licensed Material if doing so restricts exercise of 166 | the Licensed Rights by any recipient of the Licensed 167 | Material. 168 | - 6. No endorsement. Nothing in this Public License 169 | constitutes or may be construed as permission to assert or 170 | imply that You are, or that Your use of the Licensed 171 | Material is, connected with, or sponsored, endorsed, or 172 | granted official status by, the Licensor or others 173 | designated to receive attribution as provided in Section 174 | 3(a)(1)(A)(i). 175 | - b. Other rights. 176 | - 1. 
Moral rights, such as the right of integrity, are not 177 | licensed under this Public License, nor are publicity, 178 | privacy, and/or other similar personality rights; however, 179 | to the extent possible, the Licensor waives and/or agrees 180 | not to assert any such rights held by the Licensor to the 181 | limited extent necessary to allow You to exercise the 182 | Licensed Rights, but not otherwise. 183 | - 2. Patent and trademark rights are not licensed under this 184 | Public License. 185 | - 3. To the extent possible, the Licensor waives any right to 186 | collect royalties from You for the exercise of the Licensed 187 | Rights, whether directly or through a collecting society 188 | under any voluntary or waivable statutory or compulsory 189 | licensing scheme. In all other cases the Licensor expressly 190 | reserves any right to collect such royalties, including when 191 | the Licensed Material is used other than for NonCommercial 192 | purposes. 193 | 194 | - Section 3 – License Conditions. 195 | 196 | Your exercise of the Licensed Rights is expressly made subject to 197 | the following conditions. 198 | 199 | - a. Attribution. 200 | - 1. If You Share the Licensed Material (including in modified 201 | form), You must: 202 | - A. retain the following if it is supplied by the 203 | Licensor with the Licensed Material: 204 | - i. identification of the creator(s) of the Licensed 205 | Material and any others designated to receive 206 | attribution, in any reasonable manner requested by 207 | the Licensor (including by pseudonym if designated); 208 | - ii. a copyright notice; 209 | - iii. a notice that refers to this Public License; 210 | - iv. a notice that refers to the disclaimer of 211 | warranties; 212 | - v. a URI or hyperlink to the Licensed Material to 213 | the extent reasonably practicable; 214 | - B. indicate if You modified the Licensed Material and 215 | retain an indication of any previous modifications; and 216 | - C. 
indicate the Licensed Material is licensed under this 217 | Public License, and include the text of, or the URI or 218 | hyperlink to, this Public License. 219 | - 2. You may satisfy the conditions in Section 3(a)(1) in any 220 | reasonable manner based on the medium, means, and context in 221 | which You Share the Licensed Material. For example, it may 222 | be reasonable to satisfy the conditions by providing a URI 223 | or hyperlink to a resource that includes the required 224 | information. 225 | - 3. If requested by the Licensor, You must remove any of the 226 | information required by Section 3(a)(1)(A) to the extent 227 | reasonably practicable. 228 | - 4. If You Share Adapted Material You produce, the Adapter's 229 | License You apply must not prevent recipients of the Adapted 230 | Material from complying with this Public License. 231 | 232 | - Section 4 – Sui Generis Database Rights. 233 | 234 | Where the Licensed Rights include Sui Generis Database Rights that 235 | apply to Your use of the Licensed Material: 236 | 237 | - a. for the avoidance of doubt, Section 2(a)(1) grants You the 238 | right to extract, reuse, reproduce, and Share all or a 239 | substantial portion of the contents of the database for 240 | NonCommercial purposes only; 241 | - b. if You include all or a substantial portion of the database 242 | contents in a database in which You have Sui Generis Database 243 | Rights, then the database in which You have Sui Generis Database 244 | Rights (but not its individual contents) is Adapted Material; 245 | and 246 | - c. You must comply with the conditions in Section 3(a) if You 247 | Share all or a substantial portion of the contents of the 248 | database. 249 | 250 | For the avoidance of doubt, this Section 4 supplements and does not 251 | replace Your obligations under this Public License where the 252 | Licensed Rights include other Copyright and Similar Rights. 
253 | 254 | - Section 5 – Disclaimer of Warranties and Limitation of Liability. 255 | 256 | - a. Unless otherwise separately undertaken by the Licensor, to 257 | the extent possible, the Licensor offers the Licensed Material 258 | as-is and as-available, and makes no representations or 259 | warranties of any kind concerning the Licensed Material, whether 260 | express, implied, statutory, or other. This includes, without 261 | limitation, warranties of title, merchantability, fitness for a 262 | particular purpose, non-infringement, absence of latent or other 263 | defects, accuracy, or the presence or absence of errors, whether 264 | or not known or discoverable. Where disclaimers of warranties 265 | are not allowed in full or in part, this disclaimer may not 266 | apply to You. 267 | - b. To the extent possible, in no event will the Licensor be 268 | liable to You on any legal theory (including, without 269 | limitation, negligence) or otherwise for any direct, special, 270 | indirect, incidental, consequential, punitive, exemplary, or 271 | other losses, costs, expenses, or damages arising out of this 272 | Public License or use of the Licensed Material, even if the 273 | Licensor has been advised of the possibility of such losses, 274 | costs, expenses, or damages. Where a limitation of liability is 275 | not allowed in full or in part, this limitation may not apply to 276 | You. 277 | - c. The disclaimer of warranties and limitation of liability 278 | provided above shall be interpreted in a manner that, to the 279 | extent possible, most closely approximates an absolute 280 | disclaimer and waiver of all liability. 281 | 282 | - Section 6 – Term and Termination. 283 | 284 | - a. This Public License applies for the term of the Copyright and 285 | Similar Rights licensed here. However, if You fail to comply 286 | with this Public License, then Your rights under this Public 287 | License terminate automatically. 288 | - b. 
Where Your right to use the Licensed Material has terminated 289 | under Section 6(a), it reinstates: 290 | 291 | - 1. automatically as of the date the violation is cured, 292 | provided it is cured within 30 days of Your discovery of the 293 | violation; or 294 | - 2. upon express reinstatement by the Licensor. 295 | 296 | For the avoidance of doubt, this Section 6(b) does not affect 297 | any right the Licensor may have to seek remedies for Your 298 | violations of this Public License. 299 | 300 | - c. For the avoidance of doubt, the Licensor may also offer the 301 | Licensed Material under separate terms or conditions or stop 302 | distributing the Licensed Material at any time; however, doing 303 | so will not terminate this Public License. 304 | - d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 305 | License. 306 | 307 | - Section 7 – Other Terms and Conditions. 308 | 309 | - a. The Licensor shall not be bound by any additional or 310 | different terms or conditions communicated by You unless 311 | expressly agreed. 312 | - b. Any arrangements, understandings, or agreements regarding the 313 | Licensed Material not stated herein are separate from and 314 | independent of the terms and conditions of this Public License. 315 | 316 | - Section 8 – Interpretation. 317 | 318 | - a. For the avoidance of doubt, this Public License does not, and 319 | shall not be interpreted to, reduce, limit, restrict, or impose 320 | conditions on any use of the Licensed Material that could 321 | lawfully be made without permission under this Public License. 322 | - b. To the extent possible, if any provision of this Public 323 | License is deemed unenforceable, it shall be automatically 324 | reformed to the minimum extent necessary to make it enforceable. 325 | If the provision cannot be reformed, it shall be severed from 326 | this Public License without affecting the enforceability of the 327 | remaining terms and conditions. 328 | - c. 
No term or condition of this Public License will be waived 329 | and no failure to comply consented to unless expressly agreed to 330 | by the Licensor. 331 | - d. Nothing in this Public License constitutes or may be 332 | interpreted as a limitation upon, or waiver of, any privileges 333 | and immunities that apply to the Licensor or You, including from 334 | the legal processes of any jurisdiction or authority. 335 | 336 | Creative Commons is not a party to its public licenses. Notwithstanding, 337 | Creative Commons may elect to apply one of its public licenses to 338 | material it publishes and in those instances will be considered the 339 | "Licensor." The text of the Creative Commons public licenses is 340 | dedicated to the public domain under the CC0 Public Domain Dedication. 341 | Except for the limited purpose of indicating that material is shared 342 | under a Creative Commons public license or as otherwise permitted by the 343 | Creative Commons policies published at creativecommons.org/policies, 344 | Creative Commons does not authorize the use of the trademark "Creative 345 | Commons" or any other trademark or logo of Creative Commons without its 346 | prior written consent including, without limitation, in connection with 347 | any unauthorized modifications to any of its public licenses or any 348 | other arrangements, understandings, or agreements concerning use of 349 | licensed material. For the avoidance of doubt, this paragraph does not 350 | form part of the public licenses. 351 | 352 | Creative Commons may be contacted at creativecommons.org. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Coder Reviewer Reranking for Code Generation 2 | [![made-with-python](https://img.shields.io/badge/Made%20with-Python-red.svg)](#python) 3 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 4 | 5 | Official code release for the paper [Coder Reviewer Reranking for Code Generation](https://arxiv.org/abs/2211.16490). 6 | 7 | 8 | ## Setup 9 | ### Downloading data and cached outputs 10 | 1. For convenience, we include the data used for this project in [`dataset.zip`](https://dl.fbaipublicfiles.com/coder-reviewer/dataset.zip). You need to download and unzip this file before using this repo. 11 | These include: 12 | - [HumanEval](https://github.com/openai/human-eval). We also include the prompt used in the [CodeT](https://github.com/microsoft/CodeT/tree/main/CodeT) paper. 13 | - [MBPP](https://github.com/google-research/google-research/tree/master/mbpp), which includes both the sanitized version and the initial version. 14 | - [Spider](https://github.com/taoyds/spider), which includes the evaluation script and the data. We also include the cached outputs from executing the ground-truth SQL queries. 15 | - [NL2BASH](https://github.com/TellinaTool/nl2bash/tree/master/data) 16 | 2. Samples and precomputed execution results can be found in [`samples.zip`](https://dl.fbaipublicfiles.com/coder-reviewer/samples.zip). 17 | 18 | ### Installing the software environment 19 | 1. All experiments are run with `python==3.8.13`. 20 | 2. Install [pyminifier](https://github.com/liftoff/pyminifier/tree/master) from source. 21 | Installing `pyminifier` requires reverting `setuptools` to an older version (`pip install setuptools==57.5.0`). 
22 | For other issues with installing `pyminifier`, check out their [issues](https://github.com/liftoff/pyminifier/issues) for potential fixes. 23 | 3. Install `torch==1.12.1`. You should install a distribution that matches your hardware environment. 24 | 4. Install the other packages with: 25 | ```bash 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | ## Usage 30 | 31 | ### Running the selector with released outputs 32 | 1. We release samples obtained from the OpenAI Codex API in `samples.zip`. Unzipping this file, you should see a folder with the structure below: 33 | ```bash 34 | samples 35 | ├── codex-cushman 36 | │   ├── codet_humaneval 37 | │   └── mbpp_sanitized 38 | ├── codex001 39 | └── codex002 40 | ``` 41 | We will go over the code/commands you need to collect these samples in a later section. 42 | 2. Run the following script to compare different reranking methods. 43 | ```bash 44 | model="codex002" 45 | dataset="mbpp_sanitized" 46 | outdir="result_db" 47 | python sample_selectors.py --model ${model} \ 48 | --num_samples_end 25 \ 49 | --num_samples_gap 5 \ 50 | --data_path samples \ 51 | --out_dir ${outdir} \ 52 | --dataset ${dataset} \ 53 | --num_procs 10 \ 54 | --num_bootstraps 50 \ 55 | --temperature 0.4 \ 56 | --verbose 57 | ``` 58 | 3. We have included the execution results of all generated samples in `samples.zip`. If you want to execute the generated programs yourself, you can run the following command. Typically, we leverage aggressive multiprocessing to speed up this process; you can change the number of processes by modifying `nprocs`. 59 | Modify the `model` and `dataset` arguments to execute other models and datasets. 
60 | ```bash 61 | model="codex002" 62 | dataset="codet_humaneval" 63 | nprocs=25 64 | torchrun --nproc_per_node=${nprocs} multi_exec.py --temperature 0.4 --world_size 25 --dataset ${dataset} --in_data_path samples/${model} --batch_size 4 --num_seeds 25 --num_samples 5 --num_prompts 0 65 | ``` 66 | 67 | The outputs will look like the following, and a dictionary object containing the results will be saved into `result_db`: 68 | ``` 69 | sum_logprob 0.5587 0.01 70 | avg_logprob 0.5832 0.01 71 | avg_reverse_logprob 0.5626 0.01 72 | random 0.5562 0.01 73 | sumreverselogprob-ensemble#0.5 0.6152 0.01 74 | avgreverselogprob-ensemble#0.5 0.5963 0.01 75 | executability-sum_logprob 0.5976 0.01 76 | executability-avg_logprob 0.6049 0.01 77 | executability-avg_reverse_logprob 0.5952 0.01 78 | executability-random 0.5881 0.01 79 | executability-sumreverselogprob-ensemble#0.5 0.6440 0.01 80 | executability-avgreverselogprob-ensemble#0.5 0.6159 0.01 81 | mbr_exec 0.6389 0.01 82 | oracle 0.7891 0.01 83 | ``` 84 | 85 | ### Collecting Samples 86 | 1. The example command below collects 125 (5x25) samples for zero-shot HumanEval with codex002. Explore `collect*.py` for collecting samples on other datasets. These scripts collect programs given the language instructions, i.e., they implement the Coder model. 87 | ``` 88 | python collect_zeroshot.py --num_samples 5 --num_seeds 25 --dataset codet_humaneval collect --output-path samples/codex002 --engine-name codex002 --temperature 0.4 --split test --n-procs 1 --batch-size 20 --mode sample --n-prompts 0 89 | ``` 90 | 2. We collect the reviewer model score p(instruction|generated program) with `fewshot_reviewer.py` and `zeroshot_reviewer.py`. 
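To make the relationship between the criteria in the table above concrete, here is an illustrative sketch, not the repository's actual implementation (see `sample_selectors.py` for that), of how Coder-Reviewer reranking combines the Coder log-probability log p(program|instruction) (`sum_logprob` above) with the Reviewer log-probability log p(instruction|program) (`sumreverselogprob` above); the `#0.5` suffix in the ensemble names corresponds to an equal-weight interpolation. All function names, scores, and candidate programs below are hypothetical:

```python
# Illustrative sketch of Coder-Reviewer reranking (hypothetical, not the repo's code).
# Each candidate program carries two log-probabilities from the language model:
#   coder_lp    = log p(program | instruction)   (the "sum_logprob" criterion)
#   reviewer_lp = log p(instruction | program)   (the "sumreverselogprob" criterion)

def rerank(candidates, weight=0.5):
    """Return the candidate maximizing an interpolation of the two log-probs.

    `weight` plays the role of the `#0.5` suffix in the result names above;
    an "executability-" variant would first drop candidates that fail to execute.
    """
    def score(c):
        return (1 - weight) * c["coder_lp"] + weight * c["reviewer_lp"]
    return max(candidates, key=score)

candidates = [
    # The Coder slightly prefers the second program, but the Reviewer finds the
    # first far more consistent with the instruction, so it wins at weight=0.5.
    {"program": "def f(x): return x + 1", "coder_lp": -4.0, "reviewer_lp": -9.0},
    {"program": "def f(x): return x - 1", "coder_lp": -3.5, "reviewer_lp": -15.0},
]
best = rerank(candidates)  # -> the "x + 1" candidate
```

With `weight=0.0` this degenerates to plain Coder (`sum_logprob`) ranking, and with `weight=1.0` to pure Reviewer ranking, which is why the interpolated variants are reported as ensembles.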
Here's an example command for HumanEval with codex002: 91 | ``` 92 | python zeroshot_reviewer.py --num_procs 1 --batch_size 20 --temperature 0.4 --num_samples 5 --split test --dataset codet_humaneval --model codex002 --data_path samples/codex002 --canonicalize --clean-print 93 | ``` 94 | This code will update the cached results with the reviewer model probability. Explore other arguments to run different models and datasets. 95 | 96 | #### Authors 97 | - [Tianyi Zhang](https://tiiiger.github.io/) 98 | - [Tao Yu](https://taoyds.github.io/) 99 | - [Tatsunori Hashimoto](https://thashim.github.io/) 100 | - [Mike Lewis](https://research.facebook.com/people/lewis-mike/) 101 | - [Scott Wen-tau Yih](https://scottyih.org/) 102 | - [Daniel Fried](https://dpfried.github.io/) 103 | - [Sida I. Wang](http://www.sidaw.xyz/) 104 | 105 | #### Acknowledgement 106 | This codebase is largely adapted from [MBR-Exec](https://github.com/facebookresearch/mbr-exec). 107 | 108 | #### License 109 | This work is licensed under a 110 | [Creative Commons Attribution-NonCommercial 4.0 International License][cc-by-nc]. 111 | 112 | [![CC BY-NC 4.0][cc-by-nc-image]][cc-by-nc] 113 | 114 | [cc-by-nc]: http://creativecommons.org/licenses/by-nc/4.0/ 115 | [cc-by-nc-image]: https://licensebuttons.net/l/by-nc/4.0/88x31.png 116 | 117 | #### Citation 118 | If you find our work helpful, please cite as 119 | ``` 120 | @article{Zhang2022Coder, 121 | title={Coder Reviewer Reranking for Code Generation}, 122 | author={Tianyi Zhang and Tao Yu and Tatsunori B. Hashimoto and Mike Lewis and Wen-tau Yih and Daniel Fried and Sida I. Wang}, 123 | journal={ArXiv}, 124 | year={2022}, 125 | volume={abs/2211.16490} 126 | } 127 | ``` 128 | -------------------------------------------------------------------------------- /collect_mbpp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | import argparse 4 | 5 | import data 6 | from collectors import CollectorWithInfo 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "--info-mode", 13 | type=str, 14 | default="assertion", 15 | choices=["function_name", "assertion"], 16 | ) 17 | parser.add_argument( 18 | "--dataset-type", 19 | type=str, 20 | default="MBPPGoogleDataset", 21 | choices=["MBPPDataset", "MBPPGoogleDataset"], 22 | ) 23 | parser.add_argument("--num_seeds", type=int, default=25) 24 | parser.add_argument("--num_samples", type=int, default=5) 25 | args = CollectorWithInfo.parse_args(parser) 26 | args.seed = list(range(args.num_seeds)) 27 | args.dataset = "mbpp" 28 | args.split = "test" 29 | dataset = getattr(data, args.dataset_type)(mode=args.info_mode) 30 | collector = CollectorWithInfo.from_args(args, dataset) 31 | for i in range(args.num_samples): 32 | collector(i, i, 5) 33 | -------------------------------------------------------------------------------- /collect_nl2bash.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | from data import NL2BashDataset 4 | from collectors import CollectorWithInfo 5 | import argparse 6 | 7 | 8 | if __name__ == "__main__": 9 | dataset = NL2BashDataset() 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--num_seeds", type=int, default=25) 12 | parser.add_argument("--num_samples", type=int, default=5) 13 | args = CollectorWithInfo.parse_args(parser) 14 | args.dataset = "nl2bash" 15 | args.seed = list(range(args.num_seeds)) 16 | args.prompt_template = "{src}\n{trg}\n" 17 | args.example_template = "{src}\n" 18 | collector = CollectorWithInfo.from_args(args, dataset) 19 | for i in range(args.num_samples): 20 | collector(i, i, 5) 21 | -------------------------------------------------------------------------------- /collect_spider.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from data import SpiderDataset 4 | from collectors import CollectorWithInfo 5 | import argparse 6 | 7 | 8 | if __name__ == "__main__": 9 | dataset = SpiderDataset() 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--num_seeds", type=int, default=25) 12 | parser.add_argument("--num_samples", type=int, default=5) 13 | args = CollectorWithInfo.parse_args(parser) 14 | args.dataset = "spider" 15 | args.seed = list(range(args.num_seeds)) 16 | collector = CollectorWithInfo.from_args(args, dataset) 17 | for i in range(args.num_samples): 18 | collector(i, i, 5) 19 | -------------------------------------------------------------------------------- /collect_zeroshot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | import argparse 4 | 5 | import data 6 | from collectors import CollectorWithInfo 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "--info-mode", 13 | type=str, 14 | default="assertion", 15 | choices=["function_name", "assertion"], 16 | ) 17 | parser.add_argument( 18 | "--dataset", 19 | type=str, 20 | choices=["mbpp_sanitized", "codet_humaneval", "humaneval"], 21 | ) 22 | parser.add_argument("--num_seeds", type=int, default=25) 23 | parser.add_argument("--num_samples", type=int, default=5) 24 | args = CollectorWithInfo.parse_args(parser) 25 | args.output_path = args.output_path + "/" + str(args.dataset) 26 | if args.dataset == "codet_humaneval": 27 | data_file_path = "dataset/human_eval/dataset/CodeTHumanEval.jsonl" 28 | elif args.dataset == "mbpp_sanitized": 29 | data_file_path = "dataset/mbpp/mbpp_sanitized_for_code_generation.jsonl" 30 | args.end_template = ["\nclass", "\ndef", "\n#", "\nif"] 31 | dataset = getattr(data, "HumanEvalDataset")(path=data_file_path, mode="prompt_only") 32 | collector = CollectorWithInfo.from_args(args, dataset) 33 | if args.temperature > 0: 34 | args.seed = list(range(args.num_seeds)) 35 | for i in range(args.num_samples): 36 | collector(i, i, 5) 37 | else: 38 | args.seed = list(range(args.num_seeds)) 39 | collector(0, 0, 1) 40 | -------------------------------------------------------------------------------- /collectors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | import argparse 4 | import copy 5 | import json 6 | import openai 7 | import os 8 | import pickle 9 | import random 10 | import signal 11 | import time 12 | from glob import glob 13 | from nltk.translate.bleu_score import sentence_bleu 14 | from tqdm import tqdm, trange 15 | import re 16 | 17 | codex_name_mapping = { 18 | "codex-cushman": "code-cushman-001", 19 | "codex002": "code-davinci-002", 20 | "codex001": "code-davinci-001", 21 | } 22 | 23 | 24 | def codex_greedy(configs, prompt, max_tokens=512): 25 | response = openai.Completion.create( 26 | engine=codex_name_mapping[configs.engine_name] 27 | if configs.engine_name is not None 28 | else "davinci-codex", 29 | prompt=prompt, 30 | temperature=0, 31 | max_tokens=max_tokens, 32 | top_p=1, 33 | frequency_penalty=0, 34 | presence_penalty=0, 35 | stop=configs.end_template, 36 | ) 37 | return response["choices"][0]["text"], None, None 38 | 39 | 40 | def codex_sample(configs, prompt, max_tokens=512): 41 | response = openai.Completion.create( 42 | engine=codex_name_mapping[configs.engine_name] 43 | if configs.engine_name is not None 44 | else "davinci-codex", 45 | prompt=prompt, 46 | temperature=configs.temperature, 47 | max_tokens=max_tokens, 48 | top_p=configs.top_p, 49 | frequency_penalty=0, 50 | presence_penalty=0, 51 | logprobs=1, 52 | stop=configs.end_template, 53 | ) 54 | return ( 55 | response["choices"][0]["text"], 56 | response["choices"][0]["logprobs"]["tokens"], 57 | response["choices"][0]["logprobs"]["token_logprobs"], 58 | ) 59 | 60 | 61 | def codex_batch_greedy(configs, batch_prompts, max_tokens=512): 62 | raise NotImplementedError 63 | 64 | 65 | def codex_batch_sample(configs, batch_prompts, max_tokens=512): 66 | response = openai.Completion.create( 67 | engine=codex_name_mapping[configs.engine_name] 68 | if configs.engine_name is not None 69 | else "davinci-codex", 70 | prompt=batch_prompts, 71 | temperature=configs.temperature, 72 | max_tokens=max_tokens, 73 | top_p=configs.top_p, 74 | 
frequency_penalty=0, 75 | presence_penalty=0, 76 | logprobs=1, 77 | stop=configs.end_template, 78 | ) 79 | return [ 80 | ( 81 | response["choices"][batch_i]["text"], 82 | response["choices"][batch_i]["logprobs"]["tokens"], 83 | response["choices"][batch_i]["logprobs"]["token_logprobs"], 84 | ) 85 | for batch_i in range(len(batch_prompts)) 86 | ] 87 | 88 | 89 | def process_batch_examples(args_with_idx): 90 | batch_i, batch_args = args_with_idx 91 | all_prompts = [] 92 | for args in batch_args: 93 | src, trg, info, prompt_prefix, configs = args 94 | if configs.dataset in ["mbpp_sanitized", "humaneval", "codet_humaneval"]: 95 | prompt = src 96 | else: 97 | prompt = prompt_prefix + configs.example_template.format(src=src, info=info) 98 | all_prompts.append(prompt) 99 | max_tokens = configs.max_tokens 100 | while True: 101 | if configs.engine_name == "codex002": 102 | openai.organization = os.getenv(f"OPENAI_ORG{(batch_i%3)+1}") 103 | else: 104 | openai.organization = os.getenv("OPENAI_ORG1") 105 | try: 106 | batch_results = ( 107 | codex_batch_greedy(configs, all_prompts, max_tokens) 108 | if configs.mode == "greedy" 109 | else codex_batch_sample(configs, all_prompts, max_tokens) 110 | ) 111 | break 112 | except openai.error.InvalidRequestError as e: 113 | print(f"Context len: halving gen tokens, curr: {max_tokens}", end="\r") 114 | max_tokens = max_tokens // 2 115 | if max_tokens < 32: 116 | raise ValueError("Prompt too long") 117 | except openai.error.RateLimitError as e: 118 | print(type(e), re.search("Current: .+ / min", str(e))[0], end="\r") 119 | time.sleep(30) 120 | except Exception as e: 121 | print(type(e), e) 122 | time.sleep(10) 123 | 124 | all_results = [] 125 | for args, prompt, (trg_prediction, tokens, logprobs) in zip( 126 | batch_args, all_prompts, batch_results 127 | ): 128 | src, trg, info, prompt_prefix, configs = args 129 | if "humaneval" in configs.dataset or configs.dataset == "mbpp_sanitized": 130 | if "\nprint" in trg_prediction: 131 | for i in 
range(0, len(tokens) - 1): 132 | if tokens[i : i + 2] == ["\n", "print"]: 133 | break 134 | tokens = tokens[:i] 135 | logprobs = logprobs[:i] 136 | trg_prediction = "".join(tokens) 137 | if i == len(tokens) - 1: 138 | raise ValueError("not matched") 139 | result = { 140 | "prompt": prompt, 141 | "src": src, 142 | "trg_prediction": trg_prediction, 143 | "reference": trg, 144 | "tokens": tokens, 145 | "logprobs": logprobs, 146 | } 147 | all_results.append(json.dumps(result)) 148 | return all_results 149 | 150 | 151 | def process_one_example(args): 152 | src, trg, info, prompt_prefix, configs = args 153 | if configs.dataset in ["mbpp_sanitized", "humaneval", "codet_humaneval"]: 154 | prompt = src 155 | else: 156 | prompt = prompt_prefix + configs.example_template.format(src=src, info=info) 157 | max_tokens = configs.max_tokens 158 | while True: 159 | try: 160 | trg_prediction, tokens, logprobs = ( 161 | codex_greedy(configs, prompt, max_tokens) 162 | if configs.mode == "greedy" 163 | else codex_sample(configs, prompt, max_tokens) 164 | ) 165 | break 166 | except openai.error.InvalidRequestError as e: 167 | print(f"Context len: halving gen tokens, curr: {max_tokens}", end="\r") 168 | max_tokens = max_tokens // 2 169 | if max_tokens < 32: 170 | raise ValueError("Prompt too long") 171 | except openai.error.RateLimitError as e: 172 | print(type(e), re.search("Current: .+ / min", str(e))[0], end="\r") 173 | time.sleep(30) 174 | except Exception as e: 175 | print(type(e), e) 176 | time.sleep(10) 177 | 178 | import warnings 179 | 180 | with warnings.catch_warnings(): 181 | warnings.simplefilter("ignore") 182 | try: 183 | bleu_score = sentence_bleu( 184 | [[ch for ch in trg]], [ch for ch in trg_prediction] 185 | ) 186 | except: 187 | bleu_score = 0 188 | if "humaneval" in configs.dataset or configs.dataset == "mbpp_sanitized": 189 | if "\nprint" in trg_prediction: 190 | for i in range(0, len(tokens) - 1): 191 | if tokens[i : i + 2] == ["\n", "print"]: 192 | break 193 | tokens 
= tokens[:i] 194 | logprobs = logprobs[:i] 195 | trg_prediction = "".join(tokens) 196 | if i == len(tokens) - 1: 197 | raise ValueError("not matched") 198 | return json.dumps( 199 | { 200 | "prompt": prompt, 201 | "src": src, 202 | "trg_prediction": trg_prediction, 203 | "reference": trg, 204 | "tokens": tokens, 205 | "logprobs": logprobs, 206 | "bleu": bleu_score, 207 | } 208 | ) 209 | 210 | 211 | def codex_with_info(configs, dataset, prefixes): 212 | # model 213 | openai.api_key = os.getenv("OPENAI_API_KEY") 214 | 215 | prompt_prefix = "".join( 216 | [ 217 | configs.prompt_template.format(src=x[0], trg=x[1], info=x[2]) 218 | for x in prefixes 219 | ] 220 | ) 221 | 222 | # save folder 223 | if configs.top_p == 1: 224 | save_dir = f"{configs.output_path}/seed-{configs.seed}/{configs.n_prompts}-shot/{configs.mode}-{configs.temperature}" 225 | else: 226 | save_dir = f"{configs.output_path}/seed-{configs.seed}/{configs.n_prompts}-shot/{configs.mode}-{configs.temperature}-p{configs.top_p}" 227 | if configs.max_tokens != 512: 228 | save_dir += f"-max{configs.max_tokens}" 229 | os.system(f"mkdir -p {save_dir}") 230 | # save configs and prefixes 231 | if configs.rank == 0: 232 | with open(f"{save_dir}/prefixes.json", "w") as fout: 233 | json.dump(prefixes, fout) 234 | fout.close() 235 | with open(f"{save_dir}/configs.pkl", "wb") as fout: 236 | pickle.dump(configs, fout) 237 | fout.close() 238 | ofname = f"{save_dir}/{configs.split}-{configs.rank}.jsonl" 239 | if os.path.exists(ofname): 240 | return 241 | 242 | from multiprocessing import Pool 243 | 244 | all_args = [] 245 | for (src, trg, info) in dataset: 246 | all_args.append((src, trg, info, prompt_prefix, configs)) 247 | all_jsons = [] 248 | if configs.n_procs > 1: 249 | all_args = [ 250 | all_args[chunk_start : chunk_start + configs.batch_size] 251 | for chunk_start in range(0, len(all_args), configs.batch_size) 252 | ] 253 | with Pool(processes=configs.n_procs) as pool: 254 | for result_json in tqdm( 255 | 
pool.imap(process_batch_examples, enumerate(all_args)), 256 | total=len(all_args), 257 | ): 258 | all_jsons.extend(result_json) 259 | else: 260 | for batch_i, batch_start in enumerate( 261 | trange(0, len(all_args), configs.batch_size) 262 | ): 263 | batch_args = all_args[batch_start : batch_start + configs.batch_size] 264 | all_jsons.extend(process_batch_examples((batch_i, batch_args))) 265 | 266 | with open(ofname, "w") as fout: 267 | for jsonf in all_jsons: 268 | fout.write(jsonf + "\n") 269 | 270 | 271 | """ example collector: """ 272 | 273 | 274 | class CollectorWithInfo(object): 275 | def __init__(self, configs, dataset): 276 | self.configs = configs 277 | self.dataset = dataset 278 | 279 | def __call__(self, rank, local_rank, world_size): 280 | configs = copy.deepcopy(self.configs) 281 | configs.rank = rank 282 | configs.gpu = local_rank 283 | configs.world_size = world_size 284 | args = [] 285 | for seed in self.configs.seed: 286 | for n_prompts in self.configs.n_prompts: 287 | args.append((seed, n_prompts, configs.temperature)) 288 | for seed, n_prompts, temperature in tqdm(args): 289 | configs.n_prompts = n_prompts 290 | configs.seed = seed 291 | configs.temperature = temperature 292 | random.seed(configs.seed) 293 | if configs.n_prompts == 0: 294 | prefixes = [] 295 | else: 296 | if configs.saved_prefixes_path_template is not None: 297 | prefix_pool = list() 298 | for path in glob( 299 | configs.saved_prefixes_path_template, recursive=True 300 | ): 301 | prefix_pool.extend(json.load(open(path))) 302 | prefix_pool = sorted(set([tuple(x) for x in prefix_pool])) 303 | prefixes = random.sample(prefix_pool, configs.n_prompts) 304 | else: 305 | prefixes = random.sample( 306 | self.dataset.data["train"], configs.n_prompts 307 | ) 308 | if configs.shuffle_prefix: 309 | original_prefixes = copy.deepcopy(prefixes) 310 | while original_prefixes == prefixes: 311 | random.shuffle(prefixes) 312 | codex_with_info(configs, self.dataset.data[configs.split], prefixes) 313 
| 314 | @staticmethod 315 | def parse_args(main_parser=None): 316 | if main_parser is None: 317 | main_parser = argparse.ArgumentParser() 318 | subparsers = main_parser.add_subparsers(title="commands", dest="mode") 319 | # collect 320 | parser = subparsers.add_parser("collect", help="collecting stage") 321 | parser.add_argument("--output-path", type=str, required=True) 322 | parser.add_argument( 323 | "--split", type=str, default="dev", choices=["train", "dev", "test"] 324 | ) 325 | parser.add_argument("--seed", type=int, nargs="+", default=[0]) 326 | parser.add_argument("--n-procs", type=int, default=1) 327 | parser.add_argument( 328 | "--n-prompts", 329 | type=int, 330 | nargs="+", 331 | default=[3], 332 | help="number of few-shot prompt examples", 333 | ) 334 | parser.add_argument( 335 | "--mode", type=str, default="greedy", choices=["greedy", "sample"] 336 | ) 337 | parser.add_argument( 338 | "--batch-size", 339 | type=int, 340 | default=5, 341 | help="number of examples batched into a single API request", 342 | ) 343 | parser.add_argument( 344 | "--max-tokens", 345 | type=int, 346 | default=512, 347 | help="maximum number of tokens to generate per completion", 348 | ) 349 | parser.add_argument( 350 | "--temperature", type=float, default=0.3, help="sample temperature" 351 | ) 352 | parser.add_argument( 353 | "--top_p", type=float, default=1.0, help="top-p (nucleus) sampling parameter" 354 | ) 355 | parser.add_argument( 356 | "--prompt-template", 357 | type=str, 358 | default="{info}\n{src}\n{trg}\n", 359 | ) 360 | parser.add_argument( 361 | "--example-template", 362 | type=str, 363 | default="{info}\n{src}\n", 364 | ) 365 | parser.add_argument("--end-template", type=str, default="") 366 | parser.add_argument("--shuffle-prefix", action="store_true", default=False) 367 | parser.add_argument("--saved-prefixes-path-template", type=str, default=None) 368 | parser.add_argument( 369 | "--engine-name", 370 | type=str, 371 | default="codex-cushman", 372 | choices=["codex-cushman",
"codex001", "codex002"], 373 | ) 374 | # slurm arguments 375 | parser.add_argument("--slurm-ntasks", type=int, default=None) 376 | parser.add_argument("--slurm-ngpus", type=int, default=0) 377 | parser.add_argument("--slurm-nnodes", type=int, default=1) 378 | parser.add_argument("--slurm-partition", type=str, default="devlab") 379 | 380 | args = main_parser.parse_args() 381 | 382 | return args 383 | 384 | @classmethod 385 | def from_args(cls, args=None, dataset=None): 386 | if args is None: 387 | args = cls.parse_args() 388 | assert dataset is not None 389 | return cls(args, dataset) 390 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import collections 4 | import json 5 | import os 6 | import regex 7 | 8 | 9 | class NL2BashDataset(object): 10 | def __init__(self, path="dataset/nl2bash/data/bash"): 11 | self.data = collections.defaultdict() 12 | for split in ["train", "dev", "test"]: 13 | nls = [x.strip() for x in open(os.path.join(path, f"{split}.nl.filtered"))] 14 | cms = [x.strip() for x in open(os.path.join(path, f"{split}.cm.filtered"))] 15 | infos = ["" for x in open(os.path.join(path, f"{split}.cm.filtered"))] 16 | self.data[split] = list(zip(nls, cms, infos)) 17 | 18 | 19 | class SpiderDataset(object): 20 | def __init__(self, path="dataset/spider"): 21 | self.data = collections.defaultdict() 22 | self.dbs = json.load(open(f"{path}/tables.json")) 23 | self.id2db = {item["db_id"]: item for item in self.dbs} 24 | for split in ["train", "dev"]: 25 | split_fname = "train_spider" if split == "train" else split 26 | data = json.load(open(f"{path}/{split_fname}.json")) 27 | nls = [x["question"] for x in data] 28 | cms = [x["query"] for x in data] 29 | db_info = [self.extract_db_info(x["db_id"]) for x in data] 30 | self.data[split] = list(zip(nls, cms, db_info)) 31 | 32 | def 
extract_db_info(self, db_id): 33 | db = self.id2db[db_id] 34 | id2table = { 35 | i: table_name for i, table_name in enumerate(db["table_names_original"]) 36 | } 37 | info = f"{db_id} " 38 | used_table_id = set() 39 | for table_id, column_name in db["column_names_original"]: 40 | if table_id == -1: 41 | info += f"| {column_name} " 42 | elif table_id not in used_table_id: 43 | info += f"| {id2table[table_id]} : {column_name} " 44 | used_table_id.add(table_id) 45 | else: 46 | info += f", {column_name} " 47 | return info.strip() 48 | 49 | 50 | class MBPPGoogleDataset(object): 51 | def __init__(self, path="dataset/mbpp/mbpp.jsonl", mode="function_name"): 52 | raw_data = sorted( 53 | [json.loads(x) for x in open(path)], key=lambda x: x["task_id"] 54 | ) 55 | for i, data_item in enumerate(raw_data): 56 | assert data_item["task_id"] == i + 1 57 | self.raw_data = collections.defaultdict() 58 | self.mode = mode 59 | # tasks 11-510 (500 problems) form the test split; the remaining 474 tasks are pooled as the training set 60 | self.raw_data["train"] = raw_data[:10] + raw_data[510:] 61 | self.raw_data["test"] = raw_data[10:510] 62 | # data for codex collector, in input-output-info format 63 | self.data = collections.defaultdict() 64 | for split in self.raw_data: 65 | self.data[split] = self.extract_data(self.raw_data[split], mode) 66 | 67 | @staticmethod 68 | def extract_data(raw_data, mode): 69 | if mode == "function_name": 70 | get_function_name = lambda test_example: regex.match( 71 | "assert [\(]*([^\(]+)\(", test_example 72 | ).group(1) 73 | info = [get_function_name(x["test_list"][0]) for x in raw_data] 74 | elif mode == "assertion": 75 | info = [x["test_list"][0] for x in raw_data] 76 | elif mode == "assertion-full": 77 | info = [x["test_list"] for x in raw_data] 78 | else: 79 | raise Exception(f"Mode {mode} not supported.") 80 | nls = [x["text"] for x in raw_data] 81 | codes = [x["code"] for x in raw_data] 82 | return list(zip(nls, codes, info)) 83 | 84 | 85 | from dataset.human_eval.human_eval.data import read_problems 86 | 87 | 88
| class HumanEvalDataset(object): 89 | def __init__( 90 | self, 91 | path="dataset/human_eval/dataset/HumanEval.jsonl", 92 | assertion_path="", 93 | mode="assertion", 94 | ): 95 | self.path = path 96 | self.data = dict() 97 | self.raw_data = read_problems(path) 98 | self.mode = mode 99 | if assertion_path != "": 100 | self.assertion_data = read_problems(assertion_path) 101 | else: 102 | self.assertion_data = self.raw_data 103 | 104 | self.data["test"] = self.extract_data() 105 | 106 | def extract_data(self): 107 | nls = [] 108 | codes = [] 109 | info = [] 110 | for pid, prob in self.raw_data.items(): 111 | assert_prob = self.assertion_data[pid] 112 | nls.append(prob["prompt"]) 113 | docstring, func_header, func_context, doc_start = extract_docstring( 114 | assert_prob["prompt"] 115 | ) 116 | self.raw_data[pid]["func_header"] = func_header.strip() + "\n" 117 | self.raw_data[pid]["func_context"] = func_context 118 | codes.append(prob["canonical_solution"]) 119 | 120 | if self.mode != "prompt_only": 121 | assertions = extract_test(pid, prob["entry_point"], docstring) 122 | if self.mode == "assertion": 123 | self.raw_data[pid]["assertion"] = assertions[0] 124 | info.append(assertions[0]) 125 | else: 126 | self.raw_data[pid]["assertion"] = assertions 127 | info.append(assertions) 128 | else: 129 | info.append([]) 130 | return list(zip(nls, codes, info)) 131 | 132 | 133 | class MBPPSanDataset(HumanEvalDataset): 134 | def extract_data(self): 135 | nls = [] 136 | codes = [] 137 | info = [] 138 | for pid, prob in self.raw_data.items(): 139 | nls.append(prob["prompt"]) 140 | docstring, func_header, func_context, doc_start = extract_docstring( 141 | prob["prompt"] 142 | ) 143 | self.raw_data[pid]["func_header"] = func_header.strip() + "\n" 144 | self.raw_data[pid]["func_context"] = func_context 145 | codes.append(prob["canonical_solution"]) 146 | 147 | if self.mode != "prompt_only": 148 | assertions = [ 149 | l.strip() for l in prob["test"].split("\n")[1:] if l.strip() != "" 
150 | ] 151 | if self.mode == "assertion": 152 | self.raw_data[pid]["assertion"] = assertions[0] 153 | info.append(assertions[0]) 154 | elif self.mode == "assertion-all": 155 | self.raw_data[pid]["assertion"] = assertions 156 | info.append(assertions) 157 | else: 158 | raise ValueError("invalid mode") 159 | else: 160 | info.append([]) 161 | return list(zip(nls, codes, info)) 162 | 163 | 164 | def rindex(lst, value): 165 | return len(lst) - lst[::-1].index(value) - 1 166 | 167 | 168 | def _check_test_case_validation(test_case): 169 | if len(test_case.strip()) < 1: 170 | return False 171 | if "assert" not in test_case: 172 | return False 173 | try: 174 | multi_line_test_case = test_case.replace("\n", "\n ") 175 | assert_in_a_block = f"try:\n {multi_line_test_case}\nexcept:\n pass\n" 176 | compile(assert_in_a_block, "", "exec") 177 | return True 178 | except Exception: 179 | return False 180 | 181 | 182 | def extract_generated_tests(content, entry_point): 183 | def _truncate(content): 184 | for identifier in ["\nclass", "\ndef", "\n#", "\nif", "\nprint"]: 185 | if identifier in content: 186 | content = content.split(identifier)[0] 187 | return content.strip() 188 | 189 | split_by_assert = [ 190 | f"assert {part}".strip() 191 | for part in f"assert {content}".split("assert ") 192 | if (entry_point.strip() in part) and len(part.strip()) > 0 193 | ] 194 | truncated_test_cases = [_truncate(i) for i in split_by_assert] 195 | checked_assertions = [ 196 | i for i in truncated_test_cases if _check_test_case_validation(i) 197 | ] 198 | return checked_assertions 199 | 200 | 201 | def extract_docstring(prompt): 202 | func_start = max(rindex(prompt, " fed") - 4, 0)  # " fed" is "def "[::-1]; since rindex searches the reversed prompt, this finds the last "def ", i.e., the header of the function being completed 203 | clean_prompt = prompt[func_start:] 204 | if '"""' in prompt: 205 | doc_start = '"""' 206 | else: 207 | doc_start = "'''" 208 | docstring = clean_prompt[clean_prompt.strip().index(doc_start) :] 209 | func_header = clean_prompt[: clean_prompt.strip().index(doc_start)] 210 | func_context = prompt[:func_start] 211 |
return docstring, func_header, func_context, doc_start 212 | 213 | 214 | def extract_test(pid, func_name, docstring): 215 | if pid in manual_extract: 216 | return manual_extract[pid] 217 | else: 218 | return _extract_tests(func_name, docstring) 219 | 220 | 221 | def _extract_tests(func_name, docstring): 222 | all_tests = [] 223 | doc_lines = docstring.strip().split("\n") 224 | 225 | test_start = False 226 | 227 | if ">>>" in docstring: 228 | for l in doc_lines: 229 | if not test_start: 230 | if ">>>" in l and func_name in l: 231 | test_start = True 232 | if test_start: 233 | if ">>>" in l and func_name in l: 234 | l = l.strip()[3:].strip() 235 | all_tests.append(l) 236 | elif l.strip() != "" and '"""' not in l: 237 | all_tests[-1] = "assert " + all_tests[-1] + f" == {l.strip()}" 238 | test_start = False 239 | elif any( 240 | ["==>" in docstring, "=>" in docstring, "->" in docstring, "➞" in docstring] 241 | ): 242 | for special_char in ["==>", "=>", "->", "➞", "==>"]: 243 | if special_char in docstring: 244 | break 245 | for l in doc_lines: 246 | if not test_start: 247 | if special_char in l and func_name in l: 248 | test_start = True 249 | if test_start and (special_char in l and func_name in l): 250 | l = l.strip().replace(special_char, "==") 251 | l = "assert " + l 252 | all_tests.append(l) 253 | elif any(["==" in docstring, "returns" in docstring]): 254 | for special_char in ["==", "returns"]: 255 | if special_char in docstring: 256 | break 257 | for l in doc_lines: 258 | if not test_start: 259 | if special_char in l and func_name + "(" in l: 260 | test_start = True 261 | if test_start and (special_char in l and func_name in l): 262 | l = "assert " + l.strip().replace(special_char, "==") 263 | all_tests.append(l) 264 | 265 | return all_tests 266 | 267 | 268 | manual_extract = { 269 | "HumanEval/12": [ 270 | "assert longest(['a', 'b', 'c']) == 'a'", 271 | "assert longest(['a', 'bb', 'ccc']) == 'ccc'", 272 | ], 273 | "HumanEval/38": ["assert True == True"], # 
empty assertion to handle no doc test 274 | "HumanEval/41": ["assert True == True"], # empty assertion to handle no doc test 275 | "HumanEval/50": ["assert True == True"], # empty assertion to handle no doc test 276 | "HumanEval/67": [ 277 | 'assert fruit_distribution("5 apples and 6 oranges", 19) == 8', 278 | 'assert fruit_distribution("0 apples and 1 oranges",3) == 2', 279 | 'assert fruit_distribution("2 apples and 3 oranges", 100) == 95', 280 | 'assert fruit_distribution("100 apples and 1 oranges",120) == 19', 281 | ], 282 | "HumanEval/68": [ 283 | "assert pluck([4,2,3]) == [2, 1]", 284 | "assert pluck([1,2,3]) == [2, 1]", 285 | "assert pluck([]) == []", 286 | "assert pluck([5, 0, 3, 0, 4, 2]) == [0, 1]", 287 | ], 288 | "HumanEval/78": [ 289 | "assert hex_key('AB') == 1", 290 | "assert hex_key('1077E') == 2", 291 | "assert hex_key('ABED1A33') == 4", 292 | "assert hex_key('123456789ABCDEF0') == 6", 293 | "assert hex_key('2020') == 2", 294 | ], 295 | "HumanEval/79": [ 296 | "assert decimal_to_binary(15) == 'db1111db'", 297 | "assert decimal_to_binary(32) == 'db100000db'", 298 | ], 299 | "HumanEval/81": [ 300 | "assert grade_equation([4.0, 3, 1.7, 2, 3.5]) == ['A+', 'B', 'C-', 'C', 'A-']" 301 | ], 302 | "HumanEval/83": ["assert True == True"], # empty assertion to handle no doc test 303 | "HumanEval/84": ["assert True == True"], # empty assertion to handle no doc test 304 | "HumanEval/86": [ 305 | "assert anti_shuffle('Hi') == 'Hi'", 306 | "assert anti_shuffle('hello') == 'ehllo'", 307 | "assert anti_shuffle('Hello World!!!') == 'Hello !!!Wdlor'", 308 | ], 309 | "HumanEval/88": [ 310 | "assert sort_array([]) == []", 311 | "assert sort_array([5]) == [5]", 312 | "assert sort_array([2, 4, 3, 0, 1, 5]) == [0, 1, 2, 3, 4, 5]", 313 | "assert sort_array([2, 4, 3, 0, 1, 5, 6]) == [6, 5, 4, 3, 2, 1, 0]", 314 | ], 315 | "HumanEval/94": [ 316 | "assert skjkasdkd([0,3,2,1,3,5,7,4,5,5,5,2,181,32,4,32,3,2,32,324,4,3]) == 10", 317 | "assert 
skjkasdkd([1,0,1,8,2,4597,2,1,3,40,1,2,1,2,4,2,5,1]) == 25", 318 | "assert skjkasdkd([1,3,1,32,5107,34,83278,109,163,23,2323,32,30,1,9,3]) == 13", 319 | "assert skjkasdkd([0,724,32,71,99,32,6,0,5,91,83,0,5,6]) == 11", 320 | "assert skjkasdkd([0,81,12,3,1,21]) == 3", 321 | "assert skjkasdkd([0,8,1,2,1,7]) == 7", 322 | ], 323 | "HumanEval/95": [ 324 | 'assert check_dict_case({"a":"apple", "b":"banana"}) == True', 325 | 'assert check_dict_case({"a":"apple", "A":"banana", "B":"banana"}) == False', 326 | 'assert check_dict_case({"a":"apple", 8:"banana", "a":"apple"}) == False', 327 | 'assert check_dict_case({"Name":"John", "Age":"36", "City":"Houston"}) == False', 328 | 'assert check_dict_case({"STATE":"NC", "ZIP":"12345" }) == True', 329 | ], 330 | "HumanEval/97": [ 331 | "assert multiply(148, 412) == 16", 332 | "assert multiply(19, 28) == 72", 333 | "assert multiply(2020, 1851) == 0", 334 | "assert multiply(14,-15) == 20", 335 | ], 336 | "HumanEval/102": [ 337 | "assert choose_num(12, 15) == 14", 338 | "assert choose_num(13, 12) == -1", 339 | ], 340 | "HumanEval/105": ["assert True == True"], 341 | "HumanEval/107": [ 342 | "assert even_odd_palindrome(3) == (1, 3)", 343 | "assert even_odd_palindrome(12) == (4, 6)", 344 | ], 345 | "HumanEval/108": [ 346 | "assert count_nums([]) == 0", 347 | "assert count_nums([-1, 11, -11]) == 1", 348 | "assert count_nums([1, 1, 2]) == 3", 349 | ], 350 | "HumanEval/115": [ 351 | "assert max_fill([[0,0,1,0], [0,1,0,0], [1,1,1,1]]) == 1", 352 | "assert max_fill([[0,0,1,1], [0,0,0,0], [1,1,1,1], [0,1,1,1]]) == 2", 353 | "assert max_fill([[0,0,0], [0,0,0]]) == 0", 354 | ], 355 | "HumanEval/116": [ 356 | "assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]", 357 | "assert sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]", 358 | "assert sort_array([1, 0, 2, 3, 4]) == [0, 1, 2, 3, 4]", 359 | ], 360 | "HumanEval/112": [ 361 | "assert reverse_delete('abcde', 'ae') == ('bcd',False)", 362 | "assert reverse_delete('abcdef', 'b') == 
('acdef',False)", 363 | "assert reverse_delete('abcdedcba', 'ab') == ('cdedc',True)", 364 | ], 365 | "HumanEval/120": [ 366 | "assert maximum([-3, -4, 5], 3) == [-4, -3, 5]", 367 | "assert maximum([4, -4, 4], 2) == [4, 4]", 368 | "assert maximum([-3, 2, 1, 2, -1, -2, 1], 1) == [2]", 369 | ], 370 | "HumanEval/122": [ 371 | "assert add_elements([111,21,3,4000,5,6,7,8,9], 4) == 24", 372 | ], 373 | "HumanEval/128": [ 374 | "assert prod_signs([1, 2, 2, -4]) == -9", 375 | "assert prod_signs([0, 1]) == 0", 376 | "assert prod_signs([]) == None", 377 | ], 378 | "HumanEval/129": [ 379 | "assert minPath([[1,2,3], [4,5,6], [7,8,9]], 3) == [1, 2, 1]", 380 | "assert minPath([[5,9,3], [4,1,6], [7,8,2]], 1) == [1]", 381 | ], 382 | "HumanEval/130": ["assert tri(3) == [1, 3, 2, 8]"], 383 | "HumanEval/133": [ 384 | "assert sum_squares([1,2,3]) == 14", 385 | "assert sum_squares([1,4,9]) == 98", 386 | "assert sum_squares([1,3,5,7]) == 84", 387 | "assert sum_squares([1.4,4.2,0]) == 29", 388 | "assert sum_squares([-2.4,1,1]) == 6", 389 | ], 390 | "HumanEval/135": [ 391 | "assert can_arrange([1,2,4,3,5]) == 3", 392 | "assert can_arrange([1,2,3]) == -1", 393 | ], 394 | "HumanEval/141": [ 395 | "assert file_name_check('example.txt') == 'Yes'", 396 | "assert file_name_check('1example.dll') == 'No'", 397 | ], 398 | "HumanEval/142": [ 399 | "assert sum_squares([1,2,3]) == 6", 400 | "assert sum_squares([]) == 0", 401 | "assert sum_squares([-1,-5,2,-1,-5]) == -126", 402 | ], 403 | "HumanEval/143": [ 404 | "assert words_in_sentence('This is a test') == 'is'", 405 | "assert words_in_sentence('lets go for swimming') == 'go for'", 406 | ], 407 | "HumanEval/144": [ 408 | 'assert simplify("1/5", "5/1") == True', 409 | 'assert simplify("1/6", "2/1") == False', 410 | 'assert simplify("7/10", "10/2") == False', 411 | ], 412 | "HumanEval/145": [ 413 | "assert order_by_points([1, 11, -1, -11, -12]) == [-1, -11, 1, -12, 11]", 414 | "assert order_by_points([]) == []", 415 | ], 416 | "HumanEval/156": [ 417 | 
"assert int_to_mini_roman(19) == 'xix'", 418 | "assert int_to_mini_roman(152) == 'clii'", 419 | "assert int_to_mini_roman(426) == 'cdxxvi'", 420 | ], 421 | "HumanEval/147": [ 422 | "assert get_max_triples(5) == 1", 423 | ], 424 | "HumanEval/149": [ 425 | 'assert list_sort(["aa", "a", "aaa"]) == ["aa"]', 426 | 'assert list_sort(["ab", "a", "aaa", "cd"]) == ["ab", "cd"]', 427 | ], 428 | "HumanEval/159": [ 429 | "assert eat(5, 6, 10) == [11, 4]", 430 | "assert eat(4, 8, 9) == [12, 1]", 431 | "assert eat(1, 10, 10) == [11, 0]", 432 | "assert eat(2, 11, 5) == [7, 0]", 433 | ], 434 | "HumanEval/160": [ 435 | "assert do_algebra([2, 3, 4, 5], ['+', '*', '-']) == 9", 436 | ], 437 | "HumanEval/161": [ 438 | 'assert solve("1234") == "4321"', 439 | 'assert solve("ab") == "AB"', 440 | 'assert solve("#a@C") == "#A@c"', 441 | ], 442 | "HumanEval/162": [ 443 | "assert string_to_md5('Hello world') == '3e25960a79dbc69b674cd4ec67a72c62'" 444 | ], 445 | } 446 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | import os 4 | import tempfile 5 | from datasets import load_metric 6 | from tqdm import tqdm 7 | import pickle as pkl 8 | from data import MBPPGoogleDataset 9 | from execution import Command 10 | import sys 11 | from utils import time_limit 12 | 13 | 14 | """ dataset keys: src, trg_prediction, reference """ 15 | 16 | 17 | def evaluate_charbleu(dataset): 18 | bleu = load_metric("bleu") 19 | predictions = [[ch for ch in item["trg_prediction"]] for item in dataset] 20 | references = [[[ch for ch in item["reference"]]] for item in dataset] 21 | return bleu.compute(predictions=predictions, references=references) 22 | 23 | 24 | """ dataset keys: src, trg_prediction, reference (only trg_prediction useful) """ 25 | 26 | 27 | def evaluate_spider_with_cached_results(selected): 28 | all_pred_results = [item["execution_result"] for item in selected] 29 | all_gold_results = pkl.load( 30 | open( 31 | "./dataset/spider/cached_gold_results.pkl", 32 | "rb", 33 | ) 34 | ) 35 | 36 | total_correct = 0 37 | for p_res, g_res in tqdm( 38 | zip(all_pred_results, all_gold_results), 39 | total=len(all_gold_results), 40 | ): 41 | total_correct += int(p_res[1] == g_res) 42 | 43 | return total_correct / len(all_gold_results) 44 | 45 | 46 | def evaluate_one_mbpp(args, tempdir, dataset, timeout): 47 | i, item = args 48 | if "execution_result_full_pass" in dataset[i]: 49 | return int( 50 | all( 51 | isinstance(x[1], bool) and x[1] == True 52 | for x in dataset[i]["execution_result_full_pass"] 53 | ) 54 | ) 55 | else: 56 | test_cases = item["test_list"] 57 | test_setups = item["test_setup_code"] 58 | code = dataset[i]["trg_prediction"] 59 | # write code to file 60 | with open(f"{tempdir.name}/code-{i}.py", "w") as fout: 61 | print(code, file=fout) 62 | print(test_setups, file=fout) 63 | for case in test_cases: 64 | print(case, file=fout) 65 | fout.close() 66 | command = Command(f"python {tempdir.name}/code-{i}.py >/dev/null 2>&1") 67 | execution_result = command.run(timeout=timeout) == 0 
68 | return execution_result 69 | 70 | 71 | from functools import partial 72 | from multiprocessing import Pool 73 | 74 | """ dataset keys: src, trg_prediction, reference (only trg_prediction useful) """ 75 | 76 | 77 | def evaluate_google_mbpp( 78 | dataset, 79 | reference_path, 80 | split="test", 81 | timeout=10, 82 | return_details=False, 83 | num_procs=1, 84 | verbose=False, 85 | ): 86 | references = MBPPGoogleDataset(reference_path) 87 | assert len(dataset) == len(references.raw_data[split]) 88 | tempdir = tempfile.TemporaryDirectory() 89 | passed_information = list() 90 | partial_evaluate_one = partial( 91 | evaluate_one_mbpp, tempdir=tempdir, dataset=dataset, timeout=timeout 92 | ) 93 | 94 | if num_procs > 1: 95 | with Pool(processes=num_procs) as pool: 96 | for result_json in tqdm( 97 | pool.imap( 98 | partial_evaluate_one, list(enumerate(references.raw_data[split])) 99 | ), 100 | total=len(references.raw_data[split]), 101 | leave=False, 102 | disable=not verbose, 103 | ): 104 | passed_information.append(result_json) 105 | else: 106 | for args in tqdm( 107 | list(enumerate(references.raw_data[split])), disable=not verbose 108 | ): 109 | passed_information.append(partial_evaluate_one(args)) 110 | tempdir.cleanup() 111 | if return_details: 112 | return passed_information 113 | else: 114 | return sum(passed_information) / len(passed_information) 115 | 116 | 117 | def evaluate_humaneval(dataset): 118 | all_passed = [d["execution_result_full_pass"] for d in dataset] 119 | return sum(all_passed) / len(all_passed) 120 | -------------------------------------------------------------------------------- /exec_spider_gold.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | from tqdm import tqdm 4 | import os 5 | import sqlite3 6 | import pickle as pkl 7 | 8 | # CONSTANT 9 | db_dir = "./dataset/spider/database/" 10 | # preloading spider data to reduce io 11 | from dataset.spider_official.evaluation import ( 12 | build_foreign_key_map_from_json, 13 | build_valid_col_units, 14 | rebuild_sql_val, 15 | rebuild_sql_col, 16 | ) 17 | from dataset.spider_official.process_sql import ( 18 | get_schema, 19 | Schema, 20 | get_sql, 21 | ) 22 | 23 | kmaps = build_foreign_key_map_from_json("./dataset/spider/tables.json") 24 | with open("dataset/spider/dev_gold.sql") as f: 25 | glist = [l.strip().split("\t") for l in f.readlines() if len(l.strip()) > 0] 26 | 27 | 28 | all_g_res = [] 29 | for gold_sql in tqdm(glist, total=len(glist)): 30 | g_str, db = gold_sql 31 | db_name = db 32 | db = os.path.join(db_dir, db, db + ".sqlite") 33 | schema = Schema(get_schema(db)) 34 | g_sql = get_sql(schema, g_str) 35 | kmap = kmaps[db_name] 36 | g_valid_col_units = build_valid_col_units(g_sql["from"]["table_units"], schema) 37 | g_sql = rebuild_sql_val(g_sql) 38 | g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) 39 | 40 | conn = sqlite3.connect(f"file:{db}?mode=ro", uri=True) 41 | cursor = conn.cursor() 42 | # there are potential utf-8 errors 43 | try: 44 | cursor.execute(g_str) 45 | g_res = cursor.fetchall() 46 | except: 47 | g_res = [] 48 | 49 | def res_map(res, val_units): 50 | rmap = {} 51 | for idx, val_unit in enumerate(val_units): 52 | key = ( 53 | tuple(val_unit[1]) 54 | if not val_unit[2] 55 | else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2])) 56 | ) 57 | rmap[key] = [r[idx] for r in res] 58 | return rmap 59 | 60 | g_val_units = [unit[1] for unit in g_sql["select"][1]] 61 | g_res = res_map(g_res, g_val_units) 62 | all_g_res.append(g_res) 63 | 64 | pkl.dump( 65 | all_g_res, 66 | open( 67 | "./dataset/spider/cached_gold_results.pkl", 68 | "wb", 69 | ), 70 | ) 71 | 
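The `res_map` helper above re-keys fetched rows column by column, so predicted and gold executions can be compared per select-clause value unit rather than by raw cursor output. A minimal sketch with toy data (the `val_units` tuples here are hypothetical, following the `(unit_op, col_unit, col_unit)` shape used by the Spider evaluation scripts):

```python
def res_map(res, val_units):
    # Re-key result rows by select-clause value unit (column-major).
    rmap = {}
    for idx, val_unit in enumerate(val_units):
        key = (
            tuple(val_unit[1])
            if not val_unit[2]
            else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
        )
        rmap[key] = [r[idx] for r in res]
    return rmap

# Toy example: two selected columns, two result rows.
rows = [("alice", 30), ("bob", 25)]
val_units = [(0, (0, "t1.name", False), None), (0, (0, "t1.age", False), None)]
columns = res_map(rows, val_units)
# columns[(0, "t1.name", False)] == ["alice", "bob"]
```

Because the keys are derived from the parsed SQL rather than column position, two queries that select the same values in a different order still map to the same dictionary.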
-------------------------------------------------------------------------------- /execution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import bashlex 4 | import json 5 | import os 6 | import pickle 7 | import regex 8 | import signal 9 | import subprocess 10 | import tempfile 11 | import threading 12 | from datasets import load_metric 13 | from glob import glob 14 | from nltk.translate.bleu_score import sentence_bleu 15 | from tqdm import tqdm 16 | from dataset.human_eval.human_eval.evaluation import evaluate_functional_correctness 17 | import numpy as np 18 | from collections import Counter 19 | 20 | 21 | from data import MBPPGoogleDataset, HumanEvalDataset, MBPPSanDataset 22 | from utils_sql import * 23 | from time import sleep 24 | 25 | 26 | class Command(object): 27 | def __init__(self, cmd): 28 | self.cmd = cmd 29 | self.process = None 30 | 31 | def run(self, timeout): 32 | def target(): 33 | self.process = subprocess.Popen(self.cmd, shell=True, preexec_fn=os.setsid) 34 | self.process.communicate() 35 | 36 | thread = threading.Thread(target=target) 37 | thread.start() 38 | 39 | thread.join(timeout) 40 | if thread.is_alive(): 41 | os.killpg(self.process.pid, signal.SIGTERM) 42 | thread.join() 43 | return self.process.returncode 44 | 45 | 46 | class PythonFunctionExecutor(object): 47 | def __init__(self, function_content, function_call, timeout=10): 48 | self.function_content = function_content 49 | self.function_content = self.function_content.replace("", "") 50 | self.function_call = function_call 51 | self.timeout = timeout 52 | 53 | def __call__(self, i, use_json=False): 54 | tempdir = tempfile.TemporaryDirectory() 55 | with open(f"{tempdir.name}/code-{i}.py", "w") as fout: 56 | print(self.function_content, file=fout) 57 | print(f"result = {self.function_call}", file=fout) 58 | print(f"import pickle", file=fout) 59 | print( 60 | f'pickle.dump(result, 
open("{tempdir.name}/execution_result-{i}.pkl", "wb"))', 61 | file=fout, 62 | ) 63 | command = Command(f"python {tempdir.name}/code-{i}.py >/dev/null 2>&1") 64 | execution_status = command.run(timeout=self.timeout) 65 | if execution_status == 0: 66 | try: 67 | execution_results = pickle.load( 68 | open(f"{tempdir.name}/execution_result-{i}.pkl", "rb") 69 | ) 70 | except: 71 | execution_results = None 72 | else: 73 | execution_results = None 74 | tempdir.cleanup() 75 | return execution_status, execution_results 76 | 77 | 78 | def mbpp_execute_one_assertion(args): 79 | data_item, code_item, i = args 80 | assertion = data_item[-1] 81 | command = regex.match(f"assert (.+)==.+", assertion).group(1) 82 | python_function = code_item["trg_prediction"] 83 | executor = PythonFunctionExecutor(python_function, command) 84 | execution_result = executor(i) 85 | return execution_result 86 | 87 | 88 | def mbpp_execute_multiple_assertion(args): 89 | data_item, code_item, i = args 90 | execution_result = list() 91 | python_function = code_item["trg_prediction"] 92 | for assertion_i, assertion in enumerate(data_item[-1]): 93 | command = regex.match(f"assert (.+)==.+", assertion).group(1) 94 | executor = PythonFunctionExecutor(python_function, command) 95 | execution_result.append(executor(f"{i}-{assertion_i}")) 96 | return execution_result 97 | 98 | 99 | def mbpp_execute_multiple_assertion_pass(args): 100 | data_item, code_item, i = args 101 | execution_result = list() 102 | python_function = code_item["trg_prediction"] 103 | for assertion_i, assertion in enumerate(data_item[-1]): 104 | command = regex.match(f"assert (.+==.+)", assertion).group(1) 105 | executor = PythonFunctionExecutor(python_function, f"({command})") 106 | execute_stats, execute_result = executor(f"{i}-{assertion_i}") 107 | # if isinstance(execute_result, tuple) and len(execute_result) == 2: 108 | # execute_result = execute_result[0] 109 | # assert execute_result is None or isinstance(execute_result, bool) 110 | 
execution_result.append((execute_stats, execute_result)) 111 | return execution_result 112 | 113 | 114 | from multiprocessing import Pool 115 | 116 | 117 | def execute_mbpp_google_folder(base_path, num_procs=10, verbose=False): 118 | # single assertion 119 | dataset = MBPPGoogleDataset(mode="assertion") 120 | for path in tqdm( 121 | glob(f"{base_path}/*jsonl"), leave=False, desc="exec one", disable=not verbose 122 | ): # execute first assertion call 123 | if "with-reverse" in path: 124 | continue 125 | if os.path.exists(path.replace("jsonl", "exec.pkl")): 126 | continue 127 | split = os.path.basename(path).split("-")[0] 128 | execution_results = list() 129 | all_args = [] 130 | for i, line in enumerate(open(path).readlines()): 131 | data_item = dataset.data[split][i] 132 | code_item = json.loads(line) 133 | all_args.append((data_item, code_item, i)) 134 | if num_procs > 1: 135 | with Pool(processes=num_procs) as pool: 136 | for execution_result in pool.imap(mbpp_execute_one_assertion, all_args): 137 | execution_results.append(execution_result) 138 | else: 139 | for execution_result in map(mbpp_execute_one_assertion, all_args): 140 | execution_results.append(execution_result) 141 | with open(path.replace("jsonl", "exec.pkl"), "wb") as fout: 142 | pickle.dump(execution_results, fout) 143 | # multiple assertions (cheating) 144 | dataset = MBPPGoogleDataset(mode="assertion-full") 145 | for path in tqdm( 146 | glob(f"{base_path}/*jsonl"), 147 | leave=False, 148 | desc="exec multiple", 149 | disable=not verbose, 150 | ): # execute all assertion calls 151 | if "with-reverse" in path: 152 | continue 153 | if os.path.exists(path.replace("jsonl", "execfull.pkl")): 154 | continue 155 | split = os.path.basename(path).split("-")[0] 156 | execution_results = list() 157 | all_args = [] 158 | for i, line in enumerate(open(path).readlines()): 159 | data_item = dataset.data[split][i] 160 | code_item = json.loads(line) 161 | import uuid 162 | 163 | all_args.append((data_item, 
code_item, str(uuid.uuid4()))) 164 | if num_procs > 1: 165 | with Pool(processes=num_procs) as pool: 166 | for execution_result in pool.imap( 167 | mbpp_execute_multiple_assertion, all_args 168 | ): 169 | execution_results.append(execution_result) 170 | else: 171 | for execution_result in map(mbpp_execute_multiple_assertion, all_args): 172 | execution_results.append(execution_result) 173 | with open(path.replace("jsonl", "execfull.pkl"), "wb") as fout: 174 | pickle.dump(execution_results, fout) 175 | # multiple assertions (pass or fail) 176 | for path in tqdm( 177 | glob(f"{base_path}/*jsonl"), 178 | leave=False, 179 | desc="exec-multiple-pass", 180 | disable=not verbose, 181 | ): 182 | if "with-reverse" in path: 183 | continue 184 | if os.path.exists(path.replace("jsonl", "execfullpass.pkl")): 185 | continue 186 | split = os.path.basename(path).split("-")[0] 187 | execution_results = list() 188 | all_args = [] 189 | for i, line in enumerate(open(path).readlines()): 190 | data_item = dataset.data[split][i] 191 | code_item = json.loads(line) 192 | all_args.append((data_item, code_item, i)) 193 | if num_procs > 1: 194 | with Pool(processes=num_procs) as pool: 195 | for execution_result in pool.imap( 196 | mbpp_execute_multiple_assertion_pass, all_args 197 | ): 198 | execution_results.append(execution_result) 199 | else: 200 | for execution_result in map(mbpp_execute_multiple_assertion_pass, all_args): 201 | execution_results.append(execution_result) 202 | # with open(path.replace('jsonl', 'execfullpass.pkl'), 'rb') as fout: 203 | # gt_execution_results = pickle.load(fout) 204 | # for i, (a, b) in enumerate(zip(execution_results, gt_execution_results)): 205 | # if a != b: 206 | # print(i, (a, b)) 207 | with open(path.replace("jsonl", "execfullpass.pkl"), "wb") as fout: 208 | pickle.dump(execution_results, fout) 209 | 210 | 211 | def execute_spider_folder( 212 | base_path, 213 | db_path="dataset/spider/database", 214 | gold_path="dataset/spider", 215 | 
table_path="dataset/spider/tables.json", 216 | timeout=10, 217 | ): 218 | kmaps = build_foreign_key_map_from_json(table_path) 219 | for path in glob(f"{base_path}/*jsonl"): 220 | if "with-reverse" in path: 221 | continue 222 | if os.path.exists(path.replace("jsonl", "exec.pkl")): 223 | continue 224 | execution_results = list() 225 | split = os.path.basename(path).split("-")[0] 226 | file_gold_path = f"{gold_path}/{split}_gold.sql" 227 | with open(file_gold_path) as f: 228 | glist = [l.strip().split("\t") for l in f if len(l.strip()) > 0] 229 | with open(path) as f: 230 | plist = [json.loads(l)["trg_prediction"] for l in f] 231 | for p_str, (_, db_name) in tqdm(list(zip(plist, glist))): 232 | db = os.path.join(db_path, db_name, db_name + ".sqlite") 233 | schema = Schema(get_schema(db)) 234 | try: 235 | p_sql = get_sql(schema, p_str) 236 | except: 237 | # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql 238 | p_sql = { 239 | "except": None, 240 | "from": {"conds": [], "table_units": []}, 241 | "groupBy": [], 242 | "having": [], 243 | "intersect": None, 244 | "limit": None, 245 | "orderBy": [], 246 | "select": [False, []], 247 | "union": None, 248 | "where": [], 249 | } 250 | # rebuild sql for value evaluation 251 | kmap = kmaps[db_name] 252 | p_valid_col_units = build_valid_col_units( 253 | p_sql["from"]["table_units"], schema 254 | ) 255 | p_sql = rebuild_sql_val(p_sql) 256 | p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) 257 | execution_result = execute(db, p_str, p_sql, timeout) 258 | execution_results.append(execution_result) 259 | with open(path.replace("jsonl", "exec.pkl"), "wb") as fout: 260 | pickle.dump(execution_results, fout) 261 | 262 | 263 | def simulate_bash_exec(command): 264 | return list(bashlex.split(command)) 265 | 266 | 267 | def execute_mbpp_google_folder_one(base_path, num_procs=5, verbose=False, tag=""): 268 | # single assertion 269 | path = str(base_path) 270 | dataset = 
MBPPGoogleDataset(mode="assertion") 271 | out_name = "exec.pkl" 272 | if not (os.path.exists(path.replace("jsonl", out_name))): 273 | split = os.path.basename(path).split("-")[0] 274 | execution_results = list() 275 | all_args = [] 276 | for i, line in enumerate(open(path).readlines()): 277 | data_item = dataset.data[split][i] 278 | code_item = json.loads(line) 279 | all_args.append((data_item, code_item, i)) 280 | if num_procs > 1: 281 | with Pool(processes=num_procs) as pool: 282 | for execution_result in tqdm( 283 | pool.imap(mbpp_execute_one_assertion, all_args), 284 | total=len(all_args), 285 | leave=False, 286 | disable=not verbose, 287 | desc="exec one", 288 | ): 289 | execution_results.append(execution_result) 290 | else: 291 | for execution_result in map( 292 | mbpp_execute_one_assertion, tqdm(all_args, disable=not verbose) 293 | ): 294 | execution_results.append(execution_result) 295 | with open(path.replace("jsonl", out_name), "wb") as fout: 296 | pickle.dump(execution_results, fout) 297 | # multiple assertions (cheating) 298 | dataset = MBPPGoogleDataset(mode="assertion-full") 299 | path = str(base_path) 300 | out_name = "execfull.pkl" 301 | if not (os.path.exists(path.replace("jsonl", out_name))): 302 | split = os.path.basename(path).split("-")[0] 303 | execution_results = list() 304 | all_args = [] 305 | for i, line in enumerate(open(path).readlines()): 306 | data_item = dataset.data[split][i] 307 | code_item = json.loads(line) 308 | all_args.append((data_item, code_item, i)) 309 | if num_procs > 1: 310 | with Pool(processes=num_procs) as pool: 311 | for execution_result in tqdm( 312 | pool.imap(mbpp_execute_multiple_assertion, all_args), 313 | total=len(all_args), 314 | leave=False, 315 | disable=not verbose, 316 | desc="exec all", 317 | ): 318 | execution_results.append(execution_result) 319 | else: 320 | for execution_result in map( 321 | mbpp_execute_multiple_assertion, tqdm(all_args, disable=not verbose) 322 | ): 323 | 
execution_results.append(execution_result) 324 | with open(path.replace("jsonl", out_name), "wb") as fout: 325 | pickle.dump(execution_results, fout) 326 | # multiple assertions (pass or fail) 327 | path = str(base_path) 328 | out_name = "execfullpass.pkl" 329 | if not (os.path.exists(path.replace("jsonl", out_name))): 330 | split = os.path.basename(path).split("-")[0] 331 | execution_results = list() 332 | all_args = [] 333 | for i, line in enumerate(open(path).readlines()): 334 | data_item = dataset.data[split][i] 335 | code_item = json.loads(line) 336 | all_args.append((data_item, code_item, i)) 337 | if num_procs > 1: 338 | with Pool(processes=num_procs) as pool: 339 | for execution_result in tqdm( 340 | pool.imap(mbpp_execute_multiple_assertion_pass, all_args), 341 | total=len(all_args), 342 | leave=False, 343 | disable=not verbose, 344 | desc="pass or fail", 345 | ): 346 | execution_results.append(execution_result) 347 | else: 348 | for execution_result in map( 349 | mbpp_execute_multiple_assertion_pass, 350 | tqdm(all_args, disable=not verbose), 351 | ): 352 | execution_results.append(execution_result) 353 | with open(path.replace("jsonl", out_name), "wb") as fout: 354 | pickle.dump(execution_results, fout) 355 | 356 | 357 | def execute_spider_folder_one( 358 | base_path, 359 | db_path="dataset/spider/database", 360 | gold_path="dataset/spider", 361 | table_path="dataset/spider/tables.json", 362 | timeout=10, 363 | verbose=False, 364 | tag="", 365 | ): 366 | kmaps = build_foreign_key_map_from_json(table_path) 367 | path = str(base_path) 368 | out_name = "exec.pkl" 369 | if not (os.path.exists(path.replace("jsonl", out_name))): 370 | execution_results = list() 371 | split = os.path.basename(path).split("-")[0] 372 | file_gold_path = f"{gold_path}/{split}_gold.sql" 373 | with open(file_gold_path) as f: 374 | glist = [l.strip().split("\t") for l in f if len(l.strip()) > 0] 375 | with open(path) as f: 376 | plist = 
[json.loads(l)["trg_prediction"] for l in f] 377 | 378 | count = 0 379 | for p_str, (_, db_name) in tqdm( 380 | list(zip(plist, glist)), disable=not verbose, desc="SQL exec" 381 | ): 382 | db = os.path.join(db_path, db_name, db_name + ".sqlite") 383 | schema = Schema(get_schema(db)) 384 | try: 385 | p_sql = get_sql(schema, p_str) 386 | except: 387 | # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql 388 | p_sql = { 389 | "except": None, 390 | "from": {"conds": [], "table_units": []}, 391 | "groupBy": [], 392 | "having": [], 393 | "intersect": None, 394 | "limit": None, 395 | "orderBy": [], 396 | "select": [False, []], 397 | "union": None, 398 | "where": [], 399 | } 400 | # rebuild sql for value evaluation 401 | kmap = kmaps[db_name] 402 | p_valid_col_units = build_valid_col_units( 403 | p_sql["from"]["table_units"], schema 404 | ) 405 | p_sql = rebuild_sql_val(p_sql) 406 | p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) 407 | execution_result = execute(db, p_str, p_sql, timeout) 408 | execution_results.append(execution_result) 409 | count += 1 410 | with open(path.replace("jsonl", out_name), "wb") as fout: 411 | pickle.dump(execution_results, fout) 412 | 413 | 414 | def humaneval_postprocess( 415 | completion, 416 | ): 417 | keep_lines = [] 418 | for l in completion.split("\n"): 419 | if not l.startswith("print"): 420 | keep_lines.append(l) 421 | return "\n".join(keep_lines) 422 | 423 | 424 | def humaneval_execute_one_assertion(problem): 425 | assertion = problem["assertion"] 426 | try: 427 | command = regex.match(f"assert (.+)==.+", assertion).group(1) 428 | except: 429 | command = regex.match(f"assert (.+)", assertion).group(1) 430 | python_function = problem["prompt"] + problem["completion"] 431 | executor = PythonFunctionExecutor(python_function, command) 432 | execution_result = executor(problem["task_id"].split("/")[1]) 433 | return execution_result 434 | 435 | 436 | def 
humaneval_execute_multiple_assertion(problem): 437 | execution_result = list() 438 | python_function = problem["prompt"] + problem["completion"] 439 | task_id = problem["task_id"].split("/")[1] 440 | for assertion_i, assertion in enumerate(problem["assertion"]): 441 | try: 442 | try: 443 | command = regex.match(f"assert (.+)==.+", assertion).group(1) 444 | except: 445 | command = regex.match(f"assert (.+)", assertion).group(1) 446 | except: 447 | print(problem["assertion"]) 448 | print(problem["task_id"]) 449 | breakpoint() 450 | executor = PythonFunctionExecutor(python_function, command) 451 | execution_result.append(executor(f"{task_id}-{assertion_i}")) 452 | return execution_result 453 | 454 | 455 | def humaneval_execute_generated_assertion(problem): 456 | execution_result = list() 457 | python_function = problem["prompt"] + problem["completion"] 458 | task_id = problem["task_id"].split("/")[1] 459 | 460 | total_matched = 0 461 | for assertion_i, assertion in enumerate(problem["gen_assertion"]): 462 | matched = False 463 | for pattern in ["assert (.+)==.+", "assert (.+) is .+", "assert (.+)"]: 464 | try: 465 | command = regex.match(pattern, assertion).group(1) 466 | matched = True 467 | break 468 | except: 469 | pass 470 | 471 | if matched: 472 | executor = PythonFunctionExecutor(python_function, command) 473 | execution_result.append(executor(f"{task_id}-{assertion_i}")) 474 | total_matched += int(matched) 475 | 476 | if total_matched > 20: 477 | break 478 | return execution_result 479 | 480 | 481 | def execute_humaneval_folder_one( 482 | base_path, 483 | timeout=10, 484 | verbose=False, 485 | tag="", 486 | num_procs=1, 487 | dataset_choice="humaneval", 488 | ): 489 | path = str(base_path) 490 | if dataset_choice in ["humaneval", "codet_humaneval"]: 491 | dataset_cls = HumanEvalDataset 492 | if dataset_choice == "codet_humaneval": 493 | dataset_problem_file = "dataset/human_eval/dataset/CodeTHumanEval.jsonl" 494 | assertion_file = 
"dataset/human_eval/dataset/HumanEval.jsonl" 495 | else: 496 | dataset_problem_file = "dataset/human_eval/dataset/HumanEval.jsonl" 497 | assertion_file = "" 498 | elif dataset_choice == "mbpp_sanitized": 499 | dataset_problem_file = "dataset/mbpp/mbpp_sanitized_for_code_generation.jsonl" 500 | assertion_file = "" 501 | dataset_cls = MBPPSanDataset 502 | else: 503 | raise ValueError("Invalid data choice") 504 | 505 | dataset = dataset_cls( 506 | path=dataset_problem_file, assertion_path=assertion_file, mode="assertion" 507 | ) 508 | prompt_to_problem = {p["prompt"]: p for task_id, p in dataset.raw_data.items()} 509 | 510 | out_name = "exec.pkl" 511 | problem_with_completions = [] 512 | for line in open(path).readlines(): 513 | code_item = json.loads(line) 514 | problem = prompt_to_problem[code_item["prompt"]] 515 | problem["completion"] = humaneval_postprocess(code_item["trg_prediction"]) 516 | problem_with_completions.append(problem) 517 | 518 | if not (os.path.exists(path.replace("jsonl", out_name))): 519 | execution_results = [] 520 | if num_procs > 1: 521 | with Pool(processes=num_procs) as pool: 522 | for execution_result in pool.imap( 523 | humaneval_execute_one_assertion, problem_with_completions 524 | ): 525 | execution_results.append(execution_result) 526 | else: 527 | for execution_result in map( 528 | humaneval_execute_one_assertion, problem_with_completions 529 | ): 530 | execution_results.append(execution_result) 531 | with open(path.replace("jsonl", out_name), "wb") as fout: 532 | pickle.dump(execution_results, fout) 533 | 534 | dataset = dataset_cls( 535 | path=dataset_problem_file, assertion_path=assertion_file, mode="assertion-all" 536 | ) 537 | prompt_to_problem = {p["prompt"]: p for task_id, p in dataset.raw_data.items()} 538 | problem_with_completions = [] 539 | for line in open(path).readlines(): 540 | code_item = json.loads(line) 541 | problem = prompt_to_problem[code_item["prompt"]] 542 | problem["completion"] = 
humaneval_postprocess(code_item["trg_prediction"]) 543 | problem_with_completions.append(problem) 544 | 545 | out_name = "execfull.pkl" 546 | if not (os.path.exists(path.replace("jsonl", out_name))): 547 | execution_results = [] 548 | if num_procs > 1: 549 | with Pool(processes=num_procs) as pool: 550 | for execution_result in pool.imap( 551 | humaneval_execute_multiple_assertion, problem_with_completions 552 | ): 553 | execution_results.append(execution_result) 554 | else: 555 | for execution_result in map( 556 | humaneval_execute_multiple_assertion, problem_with_completions 557 | ): 558 | execution_results.append(execution_result) 559 | with open(path.replace("jsonl", out_name), "wb") as fout: 560 | pickle.dump(execution_results, fout) 561 | 562 | out_name = "execfullpass.pkl" 563 | if not (os.path.exists(path.replace("jsonl", out_name))): 564 | results, pass_at_k, extras = evaluate_functional_correctness( 565 | samples=problem_with_completions, 566 | sample_file=None, 567 | k=[1], 568 | problem_file=dataset_problem_file, 569 | suppress=True, 570 | timeout=timeout, 571 | ) 572 | all_passed = [] 573 | for result in results.values(): 574 | result.sort() 575 | passed = [r[1]["passed"] for r in result] 576 | assert len(passed) == 1 577 | all_passed.append(passed[0]) 578 | with open(path.replace("jsonl", out_name), "wb") as fout: 579 | pickle.dump(all_passed, fout) 580 | else: 581 | all_passed = pickle.load(open(path.replace("jsonl", out_name), "rb")) 582 | 583 | 584 | def execute_nl2bash_folder_one( 585 | base_path, 586 | ): 587 | bleu = load_metric("bleu") 588 | path = str(base_path) 589 | 590 | if all( 591 | ( 592 | os.path.exists(path.replace(".jsonl", ".exec.pkl")), 593 | os.path.exists(path.replace(".jsonl", ".exec.splitted.pkl")), 594 | os.path.exists(path.replace(".jsonl", ".exec.simulate.pkl")), 595 | os.path.exists(path.replace(".jsonl", ".exec.bleu.pkl")), 596 | ) 597 | ): 598 | # return 599 | pass 600 | 601 | all_exec_results = [] 602 | 
all_exec_splitted_results = [] 603 | all_simulate_exec = [] 604 | all_char_bleu = [] 605 | for line in tqdm(open(path).readlines()): 606 | code_item = json.loads(line) 607 | 608 | try: 609 | with time_limit(10): 610 | bashlex.parse(code_item["trg_prediction"]) 611 | all_exec_results.append(True) 612 | except: 613 | all_exec_results.append(False) 614 | 615 | try: 616 | with time_limit(10): 617 | splitted_trg_pred = simulate_bash_exec(code_item["trg_prediction"]) 618 | except: 619 | splitted_trg_pred = list() 620 | simulate_exec = Counter(splitted_trg_pred) 621 | all_exec_splitted_results.append(splitted_trg_pred) 622 | all_simulate_exec.append(simulate_exec) 623 | 624 | try: 625 | with time_limit(10): 626 | all_char_bleu.append( 627 | bleu.compute( 628 | predictions=[[ch for ch in code_item["reference"]]], 629 | references=[[[ch for ch in code_item["trg_prediction"]]]], 630 | )["bleu"] 631 | ) 632 | except: 633 | all_char_bleu.append(0) 634 | 635 | with open(path.replace(".jsonl", ".exec.pkl"), "wb") as fout: 636 | pickle.dump(all_exec_results, fout) 637 | with open(path.replace(".jsonl", ".exec.splitted.pkl"), "wb") as fout: 638 | pickle.dump(all_exec_splitted_results, fout) 639 | with open(path.replace(".jsonl", ".exec.simulate.pkl"), "wb") as fout: 640 | pickle.dump(all_simulate_exec, fout) 641 | with open(path.replace(".jsonl", ".exec.bleu.pkl"), "wb") as fout: 642 | pickle.dump(all_char_bleu, fout) 643 | -------------------------------------------------------------------------------- /fewshot_reviewer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | from pathlib import Path 4 | import os 5 | from glob import glob 6 | from argparse import ArgumentParser 7 | import html 8 | import json 9 | from utils import * 10 | from tqdm import tqdm, trange 11 | from functools import partial 12 | from utils import write_jsonl, parse_prompt, make_new_context 13 | from pyminifier_canonicalize import remove_print, clean_comment 14 | 15 | parser = ArgumentParser() 16 | parser.add_argument("--model", type=str, default="codex") 17 | parser.add_argument( 18 | "--dataset", type=str, default="mbpp", choices=["mbpp", "spider", "nl2bash"] 19 | ) 20 | parser.add_argument("--tag", type=str, default="") 21 | parser.add_argument("--split", type=str, default="test") 22 | parser.add_argument("--batch_size", type=int, default=20) 23 | parser.add_argument("--max_tokens", type=int, default=512) 24 | parser.add_argument("--top_p", type=float, default=1.0) 25 | parser.add_argument("--num_samples", type=int, default=5) 26 | parser.add_argument("--num_procs", type=int, default=40) 27 | parser.add_argument("--canonicalize", action="store_true", default=False) 28 | parser.add_argument( 29 | "--data_path", 30 | type=str, 31 | default="/private/home/tianyizzz/projects/mbr-exec-data/mbr-exec-release/", 32 | ) 33 | parser.add_argument("--temperature", type=float, default=0.3) 34 | 35 | args = parser.parse_args() 36 | args.data_path = Path(args.data_path) 37 | out_dir = f"seed-*/**/*-{args.temperature}" 38 | if args.top_p != 1.0: 39 | out_dir += f"-p{args.top_p}" 40 | if args.max_tokens != 512: 41 | out_dir += f"-max{args.max_tokens}" 42 | args.data_path = args.data_path / args.dataset / out_dir 43 | paths = list(sorted(glob(str(args.data_path), recursive=True))) 44 | 45 | 46 | def find_start(tokens, dataset="mbpp"): 47 | if dataset == "mbpp": 48 | match_token = ["<", "info", ">"] 49 | else: 50 | match_token = ["<", "text", ">"] 51 | for i in range(len(tokens) - 3, 0, -1): 52 | if tokens[i : i + 3] == match_token: 53 | break 54 | return i 55 | 56 | 
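As a quick illustration, `find_start` above scans a token list backwards and returns the index of the last occurrence of the opening tag (`<info>` for mbpp, `<text>` otherwise), so that only the log-probabilities after the final tag are scored. A self-contained copy with toy tokens (not real Codex BPE output):

```python
def find_start(tokens, dataset="mbpp"):
    # Find the index of the last "<info>" (mbpp) or "<text>" opening tag,
    # scanning backwards so later occurrences win.
    match_token = ["<", "info", ">"] if dataset == "mbpp" else ["<", "text", ">"]
    for i in range(len(tokens) - 3, 0, -1):
        if tokens[i : i + 3] == match_token:
            break
    return i

# Two "<info>" tags; the later one (index 4) is returned.
toks = ["<", "info", ">", "desc", "<", "info", ">", "code"]
```

Note that the backward scan stops at index 1, so a tag at the very start of the token list is never matched; this mirrors the original function's behavior.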
57 | def batch_query_reverse_logp(all_codex_data, args, verbose=False): 58 | for outer_i, batch_start in enumerate( 59 | trange(0, len(all_codex_data), args.batch_size, disable=not verbose) 60 | ): 61 | batch_data = all_codex_data[batch_start : batch_start + args.batch_size] 62 | batch_prompts = [] 63 | batch_prompts_without_ref = [] 64 | for codex_data in batch_data: 65 | prompt = codex_data["prompt"] 66 | prompt_parse = parse_prompt(prompt, dataset=args.dataset) 67 | code_sample = codex_data["trg_prediction"] 68 | if args.dataset == "mbpp" and args.canonicalize: 69 | try: 70 | code_sample = clean_comment(code_sample) 71 | except Exception: 72 | pass 73 | code_sample = remove_print(code_sample) 74 | prompt_parse[-1]["code"] = f"<code>{code_sample}</code>" 75 | with_ref_prompt, without_ref_prompt = make_new_context( 76 | prompt_parse, dataset=args.dataset 77 | ) 78 | batch_prompts.append(with_ref_prompt) 79 | batch_prompts_without_ref.append(without_ref_prompt) 80 | with_ref_response, _ = safe_codex_call( 81 | args, 82 | batch_prompts, 83 | temperature=1.0, 84 | echo=True, 85 | max_tokens=0, 86 | api_i=outer_i % 3, 87 | ) 88 | for batch_i, (codex_data, with_ref_prompt, without_ref_prompt) in enumerate( 89 | zip(batch_data, batch_prompts, batch_prompts_without_ref) 90 | ): 91 | num_api_tokens = find_start( 92 | with_ref_response["choices"][batch_i]["logprobs"]["tokens"], 93 | dataset=args.dataset, 94 | ) 95 | gt_prompt_logprob = with_ref_response["choices"][batch_i]["logprobs"][ 96 | "token_logprobs" 97 | ][num_api_tokens:] 98 | gt_prompt_tokens = with_ref_response["choices"][batch_i]["logprobs"][ 99 | "tokens" 100 | ][num_api_tokens:] 101 | codex_data["reverse_prompt_with_ref"] = with_ref_prompt 102 | codex_data["reverse_prompt_without_ref"] = without_ref_prompt 103 | codex_data["prompt_reverse_tokens"] = gt_prompt_tokens 104 | codex_data["prompt_reverse_logprobs"] = gt_prompt_logprob 105 | codex_data["prompt_reverse_full_tokens"] = with_ref_response["choices"][ 106 | batch_i 107 | ]["logprobs"]["tokens"] 108 | codex_data["prompt_reverse_full_logprobs"] = with_ref_response["choices"][ 109 | batch_i 110 | ]["logprobs"]["token_logprobs"] 111 | return all_codex_data 112 | 113 | 114 | paths = sorted(paths) 115 | print(paths) 116 | for path in tqdm(paths, desc="total seeds", disable=False): 117 | path = Path(path) 118 | for sample_i in trange(args.num_samples, leave=False): 119 | if len(args.tag) == 0: 120 | output_file_name = f"{args.split}-{sample_i}-with-reverse.jsonl" 121 | else: 122 | output_file_name = f"{args.split}-{sample_i}-with-reverse-{args.tag}.jsonl" 123 | 124 | try: 125 | all_codex_data = [] 126 | with open(path / f"{args.split}-{sample_i}.jsonl", "r") as f: 127 | for i, line in enumerate(f): 128 | codex_data = json.loads(line) 130 | all_codex_data.append(codex_data) 131 | except Exception as e: 132 | print(e) 133 | print(f"{path / output_file_name} not ready yet. skipping.") 134 | continue 135 | 136 | if (path / output_file_name).exists(): 137 | with open(path / output_file_name, "r") as f: 138 | line_num = len(f.readlines()) 139 | if line_num == len(all_codex_data): 140 | continue 141 | 142 | from multiprocessing import Pool 143 | 144 | if args.num_procs > 1: 145 | all_codex_data_with_reverse = [] 146 | chunk_size = len(all_codex_data) // args.num_procs + 1 147 | chunked_all_codex_data = [ 148 | all_codex_data[chunk_start : chunk_start + chunk_size] 149 | for chunk_start in range(0, len(all_codex_data), chunk_size) 150 | ] 151 | with Pool(processes=args.num_procs) as pool: 152 | for codex_data_with_reverse in tqdm( 153 | pool.imap( 154 | partial(batch_query_reverse_logp, args=args, verbose=True), 155 | chunked_all_codex_data, 156 | ), 157 | total=len(chunked_all_codex_data), 158 | ): 159 | all_codex_data_with_reverse.extend(codex_data_with_reverse) 160 | else: 161 | all_codex_data_with_reverse = batch_query_reverse_logp( 162 | all_codex_data, args, verbose=True 163 | ) 164 | 165 |
with open(path / output_file_name, "w") as f: 166 | for codex_data_with_reverse in all_codex_data_with_reverse: 167 | codex_data_json = json.dumps(codex_data_with_reverse) 168 | f.write(codex_data_json + "\n") 169 | -------------------------------------------------------------------------------- /multi_exec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import shutil 4 | import torch 5 | from pathlib import Path 6 | import os 7 | from glob import glob 8 | from argparse import ArgumentParser 9 | from tqdm import tqdm, trange 10 | import torch.distributed as dist 11 | from execution import ( 12 | execute_humaneval_folder_one, 13 | execute_mbpp_google_folder_one, 14 | execute_spider_folder_one, 15 | execute_nl2bash_folder_one, 16 | ) 18 | 19 | 20 | parser = ArgumentParser() 21 | parser.add_argument("--batch_size", type=int, default=2) 22 | parser.add_argument("--dataset", type=str, default="mbpp") 23 | parser.add_argument("--tag", type=str, default="") 24 | parser.add_argument("--split", type=str, default="test") 25 | parser.add_argument("--num_seeds", type=int, default=5) 26 | parser.add_argument("--num_samples", type=int, default=5) 27 | parser.add_argument("--num_prompts", type=int, default=1) 28 | parser.add_argument( 29 | "--in_data_path", 30 | type=str, 31 | default="/private/home/tianyizzz/projects/mbr-exec-data/mbr-exec-codex001/", 32 | ) 33 | parser.add_argument("--temperature", type=float, default=0.3) 34 | parser.add_argument("--max_tokens", type=int, default=512) 35 | parser.add_argument("--top_p", type=float, default=1.0) 36 | parser.add_argument("--rank", type=int, default=0) 37 | parser.add_argument("--local_rank", type=int, default=0) 38 | parser.add_argument("--world_size", type=int, default=1) 39 | 40 | args = parser.parse_args() 41 | args.rank = int(os.environ.get("LOCAL_RANK", 0)) 42 | # if args.world_size > 1: 43 | #
dist.init_process_group("gloo", rank=args.rank, world_size=args.world_size) 44 | 45 | paths = [] 46 | if args.temperature > 0: 47 | for seed in range(args.num_seeds): 48 | for i in range(args.num_samples): 49 | if (seed * args.num_samples + i) % args.world_size == args.rank: 50 | out_dir = f"sample-{args.temperature}" 51 | if args.top_p != 1.0: 52 | out_dir += f"-p{args.top_p}" 53 | if args.max_tokens != 512: 54 | out_dir += f"-max{args.max_tokens}" 55 | if args.tag == "": 56 | result_file = f"{args.split}-{i}.jsonl" 57 | else: 58 | result_file = f"{args.split}-{i}-{args.tag}.jsonl" 59 | path = ( 60 | Path(args.in_data_path) 61 | / args.dataset 62 | / f"seed-{seed}" 63 | / f"{args.num_prompts}-shot" 64 | / out_dir 65 | / result_file 66 | ) 67 | paths.append(path) 68 | else: 69 | for seed in range(args.num_seeds): 70 | i = 0 71 | if (seed * 5 + i) % args.world_size == args.rank: 72 | out_dir = f"sample-{args.temperature}" 73 | if args.max_tokens != 512: 74 | out_dir += f"-max{args.max_tokens}" 75 | if args.tag == "": 76 | result_file = f"{args.split}-{i}.jsonl" 77 | else: 78 | result_file = f"{args.split}-{i}-{args.tag}.jsonl" 79 | paths.append( 80 | Path(args.in_data_path) 81 | / args.dataset 82 | / f"seed-{seed}" 83 | / f"{args.num_prompts}-shot" 84 | / out_dir 85 | / result_file 86 | ) 87 | 88 | for path in tqdm(paths, disable=not args.rank == 0): 89 | if args.dataset == "mbpp": 90 | execute_mbpp_google_folder_one(path, verbose=args.rank == 0, tag=args.tag) 91 | elif args.dataset == "spider": 92 | execute_spider_folder_one(path, verbose=args.rank == 0, tag=args.tag) 93 | elif "humaneval" in args.dataset or args.dataset == "mbpp_sanitized": 94 | execute_humaneval_folder_one( 95 | path, verbose=args.rank == 0, tag=args.tag, dataset_choice=args.dataset 96 | ) 97 | elif args.dataset == "nl2bash": 98 | execute_nl2bash_folder_one(path) 99 | 100 | else: 101 | raise ValueError("invalid dataset") 102 | 
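The sample-directory naming used above (and by the collection scripts) encodes only the non-default decoding settings. A minimal sketch of that convention, using a hypothetical helper name `sample_dir`:

```python
def sample_dir(temperature, top_p=1.0, max_tokens=512):
    # Non-default top_p / max_tokens are appended as suffixes;
    # default values are omitted from the directory name.
    out_dir = f"sample-{temperature}"
    if top_p != 1.0:
        out_dir += f"-p{top_p}"
    if max_tokens != 512:
        out_dir += f"-max{max_tokens}"
    return out_dir

assert sample_dir(0.3) == "sample-0.3"
assert sample_dir(0.3, top_p=0.95) == "sample-0.3-p0.95"
assert sample_dir(0.3, max_tokens=1024) == "sample-0.3-max1024"
```

Keeping defaults out of the name means runs with the standard settings land in the same directory regardless of which flags were passed explicitly.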
-------------------------------------------------------------------------------- /process_sql.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | ################################ 4 | # Assumptions: 5 | # 1. sql is correct 6 | # 2. only table name has alias 7 | # 3. only one intersect/union/except 8 | # 9 | # val: number(float)/string(str)/sql(dict) 10 | # col_unit: (agg_id, col_id, isDistinct(bool)) 11 | # val_unit: (unit_op, col_unit1, col_unit2) 12 | # table_unit: (table_type, col_unit/sql) 13 | # cond_unit: (not_op, op_id, val_unit, val1, val2) 14 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 15 | # sql { 16 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 17 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 18 | # 'where': condition 19 | # 'groupBy': [col_unit1, col_unit2, ...] 20 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 21 | # 'having': condition 22 | # 'limit': None/limit value 23 | # 'intersect': None/sql 24 | # 'except': None/sql 25 | # 'union': None/sql 26 | # } 27 | ################################ 28 | 29 | import json 30 | import sqlite3 31 | from nltk import word_tokenize 32 | 33 | CLAUSE_KEYWORDS = ( 34 | "select", 35 | "from", 36 | "where", 37 | "group", 38 | "order", 39 | "limit", 40 | "intersect", 41 | "union", 42 | "except", 43 | ) 44 | JOIN_KEYWORDS = ("join", "on", "as") 45 | 46 | WHERE_OPS = ( 47 | "not", 48 | "between", 49 | "=", 50 | ">", 51 | "<", 52 | ">=", 53 | "<=", 54 | "!=", 55 | "in", 56 | "like", 57 | "is", 58 | "exists", 59 | ) 60 | UNIT_OPS = ("none", "-", "+", "*", "/") 61 | AGG_OPS = ("none", "max", "min", "count", "sum", "avg") 62 | TABLE_TYPE = { 63 | "sql": "sql", 64 | "table_unit": "table_unit", 65 | } 66 | 67 | COND_OPS = ("and", "or") 68 | SQL_OPS = ("intersect", "union", "except") 69 | ORDER_OPS = ("desc", "asc") 70 | 71 | 72 | class 
Schema: 73 | """ 74 | Simple schema which maps table&column to a unique identifier 75 | """ 76 | 77 | def __init__(self, schema): 78 | self._schema = schema 79 | self._idMap = self._map(self._schema) 80 | 81 | @property 82 | def schema(self): 83 | return self._schema 84 | 85 | @property 86 | def idMap(self): 87 | return self._idMap 88 | 89 | def _map(self, schema): 90 | idMap = {"*": "__all__"} 91 | id = 1 92 | for key, vals in schema.items(): 93 | for val in vals: 94 | idMap[key.lower() + "." + val.lower()] = ( 95 | "__" + key.lower() + "." + val.lower() + "__" 96 | ) 97 | id += 1 98 | 99 | for key in schema: 100 | idMap[key.lower()] = "__" + key.lower() + "__" 101 | id += 1 102 | 103 | return idMap 104 | 105 | 106 | def get_schema(db): 107 | """ 108 | Get database's schema, which is a dict with table name as key 109 | and list of column names as value 110 | :param db: database path 111 | :return: schema dict 112 | """ 113 | 114 | schema = {} 115 | conn = sqlite3.connect(db) 116 | cursor = conn.cursor() 117 | 118 | # fetch table names 119 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 120 | tables = [str(table[0].lower()) for table in cursor.fetchall()] 121 | 122 | # fetch table info 123 | for table in tables: 124 | cursor.execute("PRAGMA table_info({})".format(table)) 125 | schema[table] = [str(col[1].lower()) for col in cursor.fetchall()] 126 | 127 | return schema 128 | 129 | 130 | def get_schema_from_json(fpath): 131 | with open(fpath) as f: 132 | data = json.load(f) 133 | 134 | schema = {} 135 | for entry in data: 136 | table = str(entry["table"].lower()) 137 | cols = [str(col["column_name"].lower()) for col in entry["col_data"]] 138 | schema[table] = cols 139 | 140 | return schema 141 | 142 | 143 | def tokenize(string): 144 | string = str(string) 145 | string = string.replace( 146 | "'", '"' 147 | ) # ensures all string values wrapped by "" problem?? 
148 | quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] 149 | assert len(quote_idxs) % 2 == 0, "Unexpected quote" 150 | 151 | # keep string value as token 152 | vals = {} 153 | for i in range(len(quote_idxs) - 1, -1, -2): 154 | qidx1 = quote_idxs[i - 1] 155 | qidx2 = quote_idxs[i] 156 | val = string[qidx1 : qidx2 + 1] 157 | key = "__val_{}_{}__".format(qidx1, qidx2) 158 | string = string[:qidx1] + key + string[qidx2 + 1 :] 159 | vals[key] = val 160 | 161 | toks = [word.lower() for word in word_tokenize(string)] 162 | # replace with string value token 163 | for i in range(len(toks)): 164 | if toks[i] in vals: 165 | toks[i] = vals[toks[i]] 166 | 167 | # find if there exists !=, >=, <= 168 | eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] 169 | eq_idxs.reverse() 170 | prefix = ("!", ">", "<") 171 | for eq_idx in eq_idxs: 172 | pre_tok = toks[eq_idx - 1] 173 | if pre_tok in prefix: 174 | toks = toks[: eq_idx - 1] + [pre_tok + "="] + toks[eq_idx + 1 :] 175 | 176 | return toks 177 | 178 | 179 | def scan_alias(toks): 180 | """Scan the index of 'as' and build the map for all alias""" 181 | as_idxs = [idx for idx, tok in enumerate(toks) if tok == "as"] 182 | alias = {} 183 | for idx in as_idxs: 184 | alias[toks[idx + 1]] = toks[idx - 1] 185 | return alias 186 | 187 | 188 | def get_tables_with_alias(schema, toks): 189 | tables = scan_alias(toks) 190 | for key in schema: 191 | assert key not in tables, "Alias {} has the same name in table".format(key) 192 | tables[key] = key 193 | return tables 194 | 195 | 196 | def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): 197 | """ 198 | :returns next idx, column id 199 | """ 200 | tok = toks[start_idx] 201 | if tok == "*": 202 | return start_idx + 1, schema.idMap[tok] 203 | 204 | if "." in tok: # if token is a composite 205 | alias, col = tok.split(".") 206 | key = tables_with_alias[alias] + "." 
+ col 207 | return start_idx + 1, schema.idMap[key] 208 | 209 | assert ( 210 | default_tables is not None and len(default_tables) > 0 211 | ), "Default tables should not be None or empty" 212 | 213 | for alias in default_tables: 214 | table = tables_with_alias[alias] 215 | if tok in schema.schema[table]: 216 | key = table + "." + tok 217 | return start_idx + 1, schema.idMap[key] 218 | 219 | assert False, "Error col: {}".format(tok) 220 | 221 | 222 | def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 223 | """ 224 | :returns next idx, (agg_op id, col_id) 225 | """ 226 | idx = start_idx 227 | len_ = len(toks) 228 | isBlock = False 229 | isDistinct = False 230 | if toks[idx] == "(": 231 | isBlock = True 232 | idx += 1 233 | 234 | if toks[idx] in AGG_OPS: 235 | agg_id = AGG_OPS.index(toks[idx]) 236 | idx += 1 237 | assert idx < len_ and toks[idx] == "(" 238 | idx += 1 239 | if toks[idx] == "distinct": 240 | idx += 1 241 | isDistinct = True 242 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 243 | assert idx < len_ and toks[idx] == ")" 244 | idx += 1 245 | return idx, (agg_id, col_id, isDistinct) 246 | 247 | if toks[idx] == "distinct": 248 | idx += 1 249 | isDistinct = True 250 | agg_id = AGG_OPS.index("none") 251 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 252 | 253 | if isBlock: 254 | assert toks[idx] == ")" 255 | idx += 1 # skip ')' 256 | 257 | return idx, (agg_id, col_id, isDistinct) 258 | 259 | 260 | def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 261 | idx = start_idx 262 | len_ = len(toks) 263 | isBlock = False 264 | if toks[idx] == "(": 265 | isBlock = True 266 | idx += 1 267 | 268 | col_unit1 = None 269 | col_unit2 = None 270 | unit_op = UNIT_OPS.index("none") 271 | 272 | idx, col_unit1 = parse_col_unit( 273 | toks, idx, tables_with_alias, schema, default_tables 274 | ) 275 | if idx < len_ and toks[idx] in UNIT_OPS: 276 | 
unit_op = UNIT_OPS.index(toks[idx]) 277 | idx += 1 278 | idx, col_unit2 = parse_col_unit( 279 | toks, idx, tables_with_alias, schema, default_tables 280 | ) 281 | 282 | if isBlock: 283 | assert toks[idx] == ")" 284 | idx += 1 # skip ')' 285 | 286 | return idx, (unit_op, col_unit1, col_unit2) 287 | 288 | 289 | def parse_table_unit(toks, start_idx, tables_with_alias, schema): 290 | """ 291 | :returns next idx, table id, table name 292 | """ 293 | idx = start_idx 294 | len_ = len(toks) 295 | key = tables_with_alias[toks[idx]] 296 | 297 | if idx + 1 < len_ and toks[idx + 1] == "as": 298 | idx += 3 299 | else: 300 | idx += 1 301 | 302 | return idx, schema.idMap[key], key 303 | 304 | 305 | def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): 306 | idx = start_idx 307 | len_ = len(toks) 308 | 309 | isBlock = False 310 | if toks[idx] == "(": 311 | isBlock = True 312 | idx += 1 313 | 314 | if toks[idx] == "select": 315 | idx, val = parse_sql(toks, idx, tables_with_alias, schema) 316 | elif '"' in toks[idx]: # token is a string value 317 | val = toks[idx] 318 | idx += 1 319 | else: 320 | try: 321 | val = float(toks[idx]) 322 | idx += 1 323 | except: 324 | end_idx = idx 325 | while ( 326 | end_idx < len_ 327 | and toks[end_idx] != "," 328 | and toks[end_idx] != ")" 329 | and toks[end_idx] != "and" 330 | and toks[end_idx] not in CLAUSE_KEYWORDS 331 | and toks[end_idx] not in JOIN_KEYWORDS 332 | ): 333 | end_idx += 1 334 | 335 | idx, val = parse_col_unit( 336 | toks[start_idx:end_idx], 0, tables_with_alias, schema, default_tables 337 | ) 338 | idx = end_idx 339 | 340 | if isBlock: 341 | assert toks[idx] == ")" 342 | idx += 1 343 | 344 | return idx, val 345 | 346 | 347 | def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): 348 | idx = start_idx 349 | len_ = len(toks) 350 | conds = [] 351 | 352 | while idx < len_: 353 | idx, val_unit = parse_val_unit( 354 | toks, idx, tables_with_alias, schema, default_tables 355 | ) 
356 | not_op = False 357 | if toks[idx] == "not": 358 | not_op = True 359 | idx += 1 360 | 361 | assert ( 362 | idx < len_ and toks[idx] in WHERE_OPS 363 | ), "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) 364 | op_id = WHERE_OPS.index(toks[idx]) 365 | idx += 1 366 | val1 = val2 = None 367 | if op_id == WHERE_OPS.index( 368 | "between" 369 | ): # between..and... special case: dual values 370 | idx, val1 = parse_value( 371 | toks, idx, tables_with_alias, schema, default_tables 372 | ) 373 | assert toks[idx] == "and" 374 | idx += 1 375 | idx, val2 = parse_value( 376 | toks, idx, tables_with_alias, schema, default_tables 377 | ) 378 | else: # normal case: single value 379 | idx, val1 = parse_value( 380 | toks, idx, tables_with_alias, schema, default_tables 381 | ) 382 | val2 = None 383 | 384 | conds.append((not_op, op_id, val_unit, val1, val2)) 385 | 386 | if idx < len_ and ( 387 | toks[idx] in CLAUSE_KEYWORDS 388 | or toks[idx] in (")", ";") 389 | or toks[idx] in JOIN_KEYWORDS 390 | ): 391 | break 392 | 393 | if idx < len_ and toks[idx] in COND_OPS: 394 | conds.append(toks[idx]) 395 | idx += 1 # skip and/or 396 | 397 | return idx, conds 398 | 399 | 400 | def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): 401 | idx = start_idx 402 | len_ = len(toks) 403 | 404 | assert toks[idx] == "select", "'select' not found" 405 | idx += 1 406 | isDistinct = False 407 | if idx < len_ and toks[idx] == "distinct": 408 | idx += 1 409 | isDistinct = True 410 | val_units = [] 411 | 412 | while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: 413 | agg_id = AGG_OPS.index("none") 414 | if toks[idx] in AGG_OPS: 415 | agg_id = AGG_OPS.index(toks[idx]) 416 | idx += 1 417 | idx, val_unit = parse_val_unit( 418 | toks, idx, tables_with_alias, schema, default_tables 419 | ) 420 | val_units.append((agg_id, val_unit)) 421 | if idx < len_ and toks[idx] == ",": 422 | idx += 1 # skip ',' 423 | 424 | return idx, (isDistinct, val_units) 425 | 426 | 427 | def 
parse_from(toks, start_idx, tables_with_alias, schema): 428 | """ 429 | Assume in the from clause, all table units are combined with join 430 | """ 431 | assert "from" in toks[start_idx:], "'from' not found" 432 | 433 | len_ = len(toks) 434 | idx = toks.index("from", start_idx) + 1 435 | default_tables = [] 436 | table_units = [] 437 | conds = [] 438 | 439 | while idx < len_: 440 | isBlock = False 441 | if toks[idx] == "(": 442 | isBlock = True 443 | idx += 1 444 | 445 | if toks[idx] == "select": 446 | idx, sql = parse_sql(toks, idx, tables_with_alias, schema) 447 | table_units.append((TABLE_TYPE["sql"], sql)) 448 | else: 449 | if idx < len_ and toks[idx] == "join": 450 | idx += 1 # skip join 451 | idx, table_unit, table_name = parse_table_unit( 452 | toks, idx, tables_with_alias, schema 453 | ) 454 | table_units.append((TABLE_TYPE["table_unit"], table_unit)) 455 | default_tables.append(table_name) 456 | if idx < len_ and toks[idx] == "on": 457 | idx += 1 # skip on 458 | idx, this_conds = parse_condition( 459 | toks, idx, tables_with_alias, schema, default_tables 460 | ) 461 | if len(conds) > 0: 462 | conds.append("and") 463 | conds.extend(this_conds) 464 | 465 | if isBlock: 466 | assert toks[idx] == ")" 467 | idx += 1 468 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 469 | break 470 | 471 | return idx, table_units, conds, default_tables 472 | 473 | 474 | def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): 475 | idx = start_idx 476 | len_ = len(toks) 477 | 478 | if idx >= len_ or toks[idx] != "where": 479 | return idx, [] 480 | 481 | idx += 1 482 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 483 | return idx, conds 484 | 485 | 486 | def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): 487 | idx = start_idx 488 | len_ = len(toks) 489 | col_units = [] 490 | 491 | if idx >= len_ or toks[idx] != "group": 492 | return idx, col_units 493 | 494 | idx 
+= 1 495 | assert toks[idx] == "by" 496 | idx += 1 497 | 498 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 499 | idx, col_unit = parse_col_unit( 500 | toks, idx, tables_with_alias, schema, default_tables 501 | ) 502 | col_units.append(col_unit) 503 | if idx < len_ and toks[idx] == ",": 504 | idx += 1 # skip ',' 505 | else: 506 | break 507 | 508 | return idx, col_units 509 | 510 | 511 | def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): 512 | idx = start_idx 513 | len_ = len(toks) 514 | val_units = [] 515 | order_type = "asc" # default type is 'asc' 516 | 517 | if idx >= len_ or toks[idx] != "order": 518 | return idx, val_units 519 | 520 | idx += 1 521 | assert toks[idx] == "by" 522 | idx += 1 523 | 524 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 525 | idx, val_unit = parse_val_unit( 526 | toks, idx, tables_with_alias, schema, default_tables 527 | ) 528 | val_units.append(val_unit) 529 | if idx < len_ and toks[idx] in ORDER_OPS: 530 | order_type = toks[idx] 531 | idx += 1 532 | if idx < len_ and toks[idx] == ",": 533 | idx += 1 # skip ',' 534 | else: 535 | break 536 | 537 | return idx, (order_type, val_units) 538 | 539 | 540 | def parse_having(toks, start_idx, tables_with_alias, schema, default_tables): 541 | idx = start_idx 542 | len_ = len(toks) 543 | 544 | if idx >= len_ or toks[idx] != "having": 545 | return idx, [] 546 | 547 | idx += 1 548 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 549 | return idx, conds 550 | 551 | 552 | def parse_limit(toks, start_idx): 553 | idx = start_idx 554 | len_ = len(toks) 555 | 556 | if idx < len_ and toks[idx] == "limit": 557 | idx += 2 558 | return idx, int(toks[idx - 1]) 559 | 560 | return idx, None 561 | 562 | 563 | def parse_sql(toks, start_idx, tables_with_alias, schema): 564 | isBlock = False # indicate whether this is a block of sql/sub-sql 565 | len_ = len(toks) 566 | idx = 
start_idx 567 | 568 | sql = {} 569 | if toks[idx] == "(": 570 | isBlock = True 571 | idx += 1 572 | 573 | # parse from clause in order to get default tables 574 | from_end_idx, table_units, conds, default_tables = parse_from( 575 | toks, start_idx, tables_with_alias, schema 576 | ) 577 | sql["from"] = {"table_units": table_units, "conds": conds} 578 | # select clause 579 | _, select_col_units = parse_select( 580 | toks, idx, tables_with_alias, schema, default_tables 581 | ) 582 | idx = from_end_idx 583 | sql["select"] = select_col_units 584 | # where clause 585 | idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables) 586 | sql["where"] = where_conds 587 | # group by clause 588 | idx, group_col_units = parse_group_by( 589 | toks, idx, tables_with_alias, schema, default_tables 590 | ) 591 | sql["groupBy"] = group_col_units 592 | # having clause 593 | idx, having_conds = parse_having( 594 | toks, idx, tables_with_alias, schema, default_tables 595 | ) 596 | sql["having"] = having_conds 597 | # order by clause 598 | idx, order_col_units = parse_order_by( 599 | toks, idx, tables_with_alias, schema, default_tables 600 | ) 601 | sql["orderBy"] = order_col_units 602 | # limit clause 603 | idx, limit_val = parse_limit(toks, idx) 604 | sql["limit"] = limit_val 605 | 606 | idx = skip_semicolon(toks, idx) 607 | if isBlock: 608 | assert toks[idx] == ")" 609 | idx += 1 # skip ')' 610 | idx = skip_semicolon(toks, idx) 611 | 612 | # intersect/union/except clause 613 | for op in SQL_OPS: # initialize IUE 614 | sql[op] = None 615 | if idx < len_ and toks[idx] in SQL_OPS: 616 | sql_op = toks[idx] 617 | idx += 1 618 | idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema) 619 | sql[sql_op] = IUE_sql 620 | return idx, sql 621 | 622 | 623 | def load_data(fpath): 624 | with open(fpath) as f: 625 | data = json.load(f) 626 | return data 627 | 628 | 629 | def get_sql(schema, query): 630 | toks = tokenize(query) 631 | tables_with_alias = 
get_tables_with_alias(schema.schema, toks) 632 | _, sql = parse_sql(toks, 0, tables_with_alias, schema) 633 | 634 | return sql 635 | 636 | 637 | def skip_semicolon(toks, start_idx): 638 | idx = start_idx 639 | while idx < len(toks) and toks[idx] == ";": 640 | idx += 1 641 | return idx 642 | -------------------------------------------------------------------------------- /pyminifier_canonicalize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import keyword, sys 4 | from pyminifier import analyze 5 | from pyminifier.minification import remove_comments_and_docstrings, remove_blank_lines 6 | import re 7 | 8 | RESERVED_WORDS = keyword.kwlist + analyze.builtins 9 | 10 | 11 | def clean_comment(code): 12 | code = remove_comments_and_docstrings(code) 13 | code = remove_blank_lines(code) 14 | return code 15 | 16 | 17 | def remove_print(code): 18 | code = re.sub("print(.+)", "print('')", code) 19 | code = re.sub("Error(.+)", "Error('')", code) 20 | code = re.sub("Exception(.+)", "Exception('')", code) 21 | code = re.sub("assert (.+), +['\"].+['\"]", "assert \\1", code) 22 | return code 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.13.1 2 | astor==0.8.1 3 | asttokens==2.0.5 4 | astunparse==1.6.3 5 | async-timeout==4.0.2 6 | attrs==21.4.0 7 | bashlex==0.16 8 | datasets==2.3.2 9 | filelock==3.7.1 10 | huggingface-hub==0.10.1 11 | matplotlib==3.5.2 12 | matplotlib-inline==0.1.3 13 | nltk==3.7 14 | openai==0.23.0 15 | pandas==1.4.2 16 | pandas-datareader==0.10.0 17 | pandas-stubs==1.2.0.61 18 | pandocfilters==1.5.0 19 | Pillow==9.2.0 20 | platformdirs==2.5.2 21 | pluggy==1.0.0 22 | regex==2022.6.2 23 | sacrebleu==2.1.0 24 | seaborn==0.11.2 25 | tokenizers==0.12.1 26 | tqdm==4.64.1 27 | transformers==4.23.1 
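To illustrate the effect of `remove_print` from `pyminifier_canonicalize.py` above, here is a self-contained copy of its regex substitutions (only `re` is needed). The patterns are line-oriented: everything after the keyword up to the end of the line is blanked out, which canonicalizes away message strings before comparing samples.

```python
import re

def remove_print(code):
    # Blank out print arguments and Error/Exception messages, and drop
    # assert messages, so semantically equivalent samples compare equal.
    code = re.sub("print(.+)", "print('')", code)
    code = re.sub("Error(.+)", "Error('')", code)
    code = re.sub("Exception(.+)", "Exception('')", code)
    code = re.sub("assert (.+), +['\"].+['\"]", "assert \\1", code)
    return code

assert remove_print("print(x, y)") == "print('')"
assert remove_print("raise ValueError('bad input')") == "raise ValueError('')"
assert remove_print("assert f(1) == 2, 'wrong'") == "assert f(1) == 2"
```

Because the parentheses in `print(.+)` are an unescaped capture group, the pattern matches the keyword followed by any trailing text, not a balanced call; that is adequate for the single-line MBPP samples it canonicalizes.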
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from time import sleep 4 | import os 5 | import random 6 | import openai 7 | import re 8 | import json 9 | 10 | 11 | def safe_codex_call( 12 | args, api_text, temperature=None, stop=None, echo=False, max_tokens=256, api_i=0 13 | ): 14 | temperature = temperature if temperature else args.temperature 15 | while True: 16 | try: 17 | if args.model == "codex002": 18 | openai.organization = os.getenv(f"OPENAI_ORG{api_i+1}") 19 | else: 20 | openai.organization = os.getenv("OPENAI_ORG1") 21 | codex_response = codex_greedy( 22 | api_text, 23 | temperature=temperature, 24 | codex_config=args.model, 25 | stop=stop, 26 | echo=echo, 27 | max_tokens=max_tokens, 28 | ) 29 | break 30 | except openai.error.InvalidRequestError as e: 31 | codex_response = None 32 | if isinstance(api_text, list): 33 | api_text = [t.replace("\n", "") for t in api_text] 34 | else: 35 | api_text = api_text.replace("\n", "") 36 | print("Invalid Request: Removing newlines") 37 | except openai.error.RateLimitError as e: 38 | print(type(e), f"API {api_i}:", e, end="\r") 39 | sleep(30) 40 | api_i = (api_i + 1) % 3 41 | except Exception as e: 42 | print(type(e), e) 43 | sleep(10) 44 | 45 | if codex_response is None: 46 | codex_text = "" 47 | else: 48 | codex_text = "".join(codex_response["choices"][0]["logprobs"]["tokens"]) 49 | return codex_response, codex_text 50 | 51 | 52 | def codex_greedy( 53 | prompt, temperature=0.3, codex_config="codex", stop=None, echo=False, max_tokens=256 54 | ): 55 | if stop is None: 56 | stop = ["#SOLUTION END", "# SOLUTION END", "SOLUTION END"] 57 | if codex_config == "codex001": 58 | codex_code = "code-davinci-001" 59 | elif codex_config == "codex002": 60 | codex_code = "code-davinci-002" 61 | elif codex_config == "codex-cushman": 62 | 
codex_code = "code-cushman-001" 63 | else: 64 | raise ValueError 65 | 66 | response = openai.Completion.create( 67 | engine=codex_code, 68 | prompt=prompt, 69 | temperature=temperature, 70 | stop=stop, 71 | max_tokens=max_tokens, 72 | top_p=0.95, 73 | logprobs=1, 74 | frequency_penalty=0, 75 | presence_penalty=0, 76 | echo=echo, 77 | ) 78 | return response 79 | 80 | 81 | def write_jsonl(data_list, file_path): 82 | with open(file_path, "w") as f: 83 | for d in data_list: 84 | f.write(json.dumps(d) + "\n") 85 | 86 | 87 | def parse_prompt(prompt, dataset="mbpp"): 88 | prompt_data = [] 89 | fewshot_examples = [ 90 | p.strip() + "</code>" for p in prompt.split("</code>") if len(p) > 1 91 | ] 92 | for example in fewshot_examples: 93 | example_data = dict() 94 | if dataset in ["mbpp", "spider"]: 95 | all_fields = ["info", "text", "code"] 96 | elif dataset == "nl2bash": 97 | all_fields = ["text", "code"] 98 | for field in all_fields: 99 | field_start = example.index(f"<{field}>") 100 | field_end = example.index(f"</{field}>") 101 | example_data[field] = example[field_start : field_end + len(f"</{field}>")] 102 | prompt_data.append(example_data) 103 | return prompt_data 104 | 105 | 106 | def make_new_context(prompt_parse, dataset="mbpp"): 107 | without_ref = "" 108 | with_ref = "" 109 | 110 | if dataset == "mbpp": 111 | full_prompt_fields = ["code", "info", "text"] 112 | elif dataset == "spider": 113 | full_prompt_fields = ["info", "code", "text"] 114 | else: 115 | full_prompt_fields = ["code", "text"] 116 | 117 | if dataset == "mbpp" or dataset == "nl2bash": 118 | partial_prompt_fields = ["code"] 119 | elif dataset == "spider": 120 | partial_prompt_fields = ["info", "code"] 121 | 122 | for i, example in enumerate(prompt_parse): 123 | for field in full_prompt_fields: 124 | with_ref += example[field] + "\n" 125 | if i < len(prompt_parse) - 1: 126 | for field in full_prompt_fields: 127 | without_ref += example[field] + "\n" 128 | else: 129 | for field in partial_prompt_fields: 130 | without_ref +=
example[field] + "\n" 131 | return with_ref.strip(), without_ref.strip() 132 | 133 | 134 | from contextlib import contextmanager 135 | import signal 136 | 137 | 138 | class TimeoutException(Exception): 139 | pass 140 | 141 | 142 | @contextmanager 143 | def time_limit(seconds: float): 144 | def signal_handler(signum, frame): 145 | raise TimeoutException("Timed out!") 146 | 147 | signal.setitimer(signal.ITIMER_REAL, seconds) 148 | signal.signal(signal.SIGALRM, signal_handler) 149 | try: 150 | yield 151 | finally: 152 | signal.setitimer(signal.ITIMER_REAL, 0) 153 | -------------------------------------------------------------------------------- /utils_sql.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | ################################ 4 | # val: number(float)/string(str)/sql(dict) 5 | # col_unit: (agg_id, col_id, isDistinct(bool)) 6 | # val_unit: (unit_op, col_unit1, col_unit2) 7 | # table_unit: (table_type, col_unit/sql) 8 | # cond_unit: (not_op, op_id, val_unit, val1, val2) 9 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 10 | # sql { 11 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 12 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 13 | # 'where': condition 14 | # 'groupBy': [col_unit1, col_unit2, ...] 
15 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 16 | # 'having': condition 17 | # 'limit': None/limit value 18 | # 'intersect': None/sql 19 | # 'except': None/sql 20 | # 'union': None/sql 21 | # } 22 | ################################ 23 | 24 | from __future__ import print_function 25 | import os 26 | import json 27 | import sqlite3 28 | import signal 29 | from contextlib import contextmanager 30 | import argparse 31 | from process_sql import get_schema, Schema, get_sql 32 | import sys 33 | from time import sleep 34 | 35 | # Flag to disable value evaluation 36 | DISABLE_VALUE = True 37 | # Flag to disable distinct in select evaluation 38 | DISABLE_DISTINCT = True 39 | 40 | 41 | CLAUSE_KEYWORDS = ( 42 | "select", 43 | "from", 44 | "where", 45 | "group", 46 | "order", 47 | "limit", 48 | "intersect", 49 | "union", 50 | "except", 51 | ) 52 | JOIN_KEYWORDS = ("join", "on", "as") 53 | 54 | WHERE_OPS = ( 55 | "not", 56 | "between", 57 | "=", 58 | ">", 59 | "<", 60 | ">=", 61 | "<=", 62 | "!=", 63 | "in", 64 | "like", 65 | "is", 66 | "exists", 67 | ) 68 | UNIT_OPS = ("none", "-", "+", "*", "/") 69 | AGG_OPS = ("none", "max", "min", "count", "sum", "avg") 70 | TABLE_TYPE = { 71 | "sql": "sql", 72 | "table_unit": "table_unit", 73 | } 74 | 75 | COND_OPS = ("and", "or") 76 | SQL_OPS = ("intersect", "union", "except") 77 | ORDER_OPS = ("desc", "asc") 78 | 79 | 80 | HARDNESS = { 81 | "component1": ("where", "group", "order", "limit", "join", "or", "like"), 82 | "component2": ("except", "union", "intersect"), 83 | } 84 | 85 | 86 | class TimeoutException(Exception): 87 | pass 88 | 89 | 90 | @contextmanager 91 | def time_limit(seconds: float): 92 | def signal_handler(signum, frame): 93 | raise TimeoutException("Timed out!") 94 | 95 | signal.setitimer(signal.ITIMER_REAL, seconds) 96 | signal.signal(signal.SIGALRM, signal_handler) 97 | try: 98 | yield 99 | finally: 100 | signal.setitimer(signal.ITIMER_REAL, 0) 101 | 102 | 103 | def condition_has_or(conds): 104 | return 
"or" in conds[1::2] 105 | 106 | 107 | def condition_has_like(conds): 108 | return WHERE_OPS.index("like") in [cond_unit[1] for cond_unit in conds[::2]] 109 | 110 | 111 | def condition_has_sql(conds): 112 | for cond_unit in conds[::2]: 113 | val1, val2 = cond_unit[3], cond_unit[4] 114 | if val1 is not None and type(val1) is dict: 115 | return True 116 | if val2 is not None and type(val2) is dict: 117 | return True 118 | return False 119 | 120 | 121 | def val_has_op(val_unit): 122 | return val_unit[0] != UNIT_OPS.index("none") 123 | 124 | 125 | def has_agg(unit): 126 | return unit[0] != AGG_OPS.index("none") 127 | 128 | 129 | def accuracy(count, total): 130 | if count == total: 131 | return 1 132 | return 0 133 | 134 | 135 | def recall(count, total): 136 | if count == total: 137 | return 1 138 | return 0 139 | 140 | 141 | def F1(acc, rec): 142 | if (acc + rec) == 0: 143 | return 0 144 | return (2.0 * acc * rec) / (acc + rec) 145 | 146 | 147 | def get_scores(count, pred_total, label_total): 148 | if pred_total != label_total: 149 | return 0, 0, 0 150 | elif count == pred_total: 151 | return 1, 1, 1 152 | return 0, 0, 0 153 | 154 | 155 | def eval_sel(pred, label): 156 | pred_sel = pred["select"][1] 157 | label_sel = label["select"][1] 158 | label_wo_agg = [unit[1] for unit in label_sel] 159 | pred_total = len(pred_sel) 160 | label_total = len(label_sel) 161 | cnt = 0 162 | cnt_wo_agg = 0 163 | 164 | for unit in pred_sel: 165 | if unit in label_sel: 166 | cnt += 1 167 | label_sel.remove(unit) 168 | if unit[1] in label_wo_agg: 169 | cnt_wo_agg += 1 170 | label_wo_agg.remove(unit[1]) 171 | 172 | return label_total, pred_total, cnt, cnt_wo_agg 173 | 174 | 175 | def eval_where(pred, label): 176 | pred_conds = [unit for unit in pred["where"][::2]] 177 | label_conds = [unit for unit in label["where"][::2]] 178 | label_wo_agg = [unit[2] for unit in label_conds] 179 | pred_total = len(pred_conds) 180 | label_total = len(label_conds) 181 | cnt = 0 182 | cnt_wo_agg = 0 183 | 184 
| for unit in pred_conds: 185 | if unit in label_conds: 186 | cnt += 1 187 | label_conds.remove(unit) 188 | if unit[2] in label_wo_agg: 189 | cnt_wo_agg += 1 190 | label_wo_agg.remove(unit[2]) 191 | 192 | return label_total, pred_total, cnt, cnt_wo_agg 193 | 194 | 195 | def eval_group(pred, label): 196 | pred_cols = [unit[1] for unit in pred["groupBy"]] 197 | label_cols = [unit[1] for unit in label["groupBy"]] 198 | pred_total = len(pred_cols) 199 | label_total = len(label_cols) 200 | cnt = 0 201 | pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols] 202 | label_cols = [ 203 | label.split(".")[1] if "." in label else label for label in label_cols 204 | ] 205 | for col in pred_cols: 206 | if col in label_cols: 207 | cnt += 1 208 | label_cols.remove(col) 209 | return label_total, pred_total, cnt 210 | 211 | 212 | def eval_having(pred, label): 213 | pred_total = label_total = cnt = 0 214 | if len(pred["groupBy"]) > 0: 215 | pred_total = 1 216 | if len(label["groupBy"]) > 0: 217 | label_total = 1 218 | 219 | pred_cols = [unit[1] for unit in pred["groupBy"]] 220 | label_cols = [unit[1] for unit in label["groupBy"]] 221 | if ( 222 | pred_total == label_total == 1 223 | and pred_cols == label_cols 224 | and pred["having"] == label["having"] 225 | ): 226 | cnt = 1 227 | 228 | return label_total, pred_total, cnt 229 | 230 | 231 | def eval_order(pred, label): 232 | pred_total = label_total = cnt = 0 233 | if len(pred["orderBy"]) > 0: 234 | pred_total = 1 235 | if len(label["orderBy"]) > 0: 236 | label_total = 1 237 | if ( 238 | len(label["orderBy"]) > 0 239 | and pred["orderBy"] == label["orderBy"] 240 | and ( 241 | (pred["limit"] is None and label["limit"] is None) 242 | or (pred["limit"] is not None and label["limit"] is not None) 243 | ) 244 | ): 245 | cnt = 1 246 | return label_total, pred_total, cnt 247 | 248 | 249 | def eval_and_or(pred, label): 250 | pred_ao = pred["where"][1::2] 251 | label_ao = label["where"][1::2] 252 | pred_ao = 
set(pred_ao) 253 | label_ao = set(label_ao) 254 | 255 | if pred_ao == label_ao: 256 | return 1, 1, 1 257 | return len(pred_ao), len(label_ao), 0 258 | 259 | 260 | def get_nestedSQL(sql): 261 | nested = [] 262 | for cond_unit in sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2]: 263 | if type(cond_unit[3]) is dict: 264 | nested.append(cond_unit[3]) 265 | if type(cond_unit[4]) is dict: 266 | nested.append(cond_unit[4]) 267 | if sql["intersect"] is not None: 268 | nested.append(sql["intersect"]) 269 | if sql["except"] is not None: 270 | nested.append(sql["except"]) 271 | if sql["union"] is not None: 272 | nested.append(sql["union"]) 273 | return nested 274 | 275 | 276 | def eval_nested(pred, label): 277 | label_total = 0 278 | pred_total = 0 279 | cnt = 0 280 | if pred is not None: 281 | pred_total += 1 282 | if label is not None: 283 | label_total += 1 284 | if pred is not None and label is not None: 285 | cnt += Evaluator().eval_exact_match(pred, label) 286 | return label_total, pred_total, cnt 287 | 288 | 289 | def eval_IUEN(pred, label): 290 | lt1, pt1, cnt1 = eval_nested(pred["intersect"], label["intersect"]) 291 | lt2, pt2, cnt2 = eval_nested(pred["except"], label["except"]) 292 | lt3, pt3, cnt3 = eval_nested(pred["union"], label["union"]) 293 | label_total = lt1 + lt2 + lt3 294 | pred_total = pt1 + pt2 + pt3 295 | cnt = cnt1 + cnt2 + cnt3 296 | return label_total, pred_total, cnt 297 | 298 | 299 | def get_keywords(sql): 300 | res = set() 301 | if len(sql["where"]) > 0: 302 | res.add("where") 303 | if len(sql["groupBy"]) > 0: 304 | res.add("group") 305 | if len(sql["having"]) > 0: 306 | res.add("having") 307 | if len(sql["orderBy"]) > 0: 308 | res.add(sql["orderBy"][0]) 309 | res.add("order") 310 | if sql["limit"] is not None: 311 | res.add("limit") 312 | if sql["except"] is not None: 313 | res.add("except") 314 | if sql["union"] is not None: 315 | res.add("union") 316 | if sql["intersect"] is not None: 317 | res.add("intersect") 318 | 319 | # 
or keyword 320 | ao = sql["from"]["conds"][1::2] + sql["where"][1::2] + sql["having"][1::2] 321 | if len([token for token in ao if token == "or"]) > 0: 322 | res.add("or") 323 | 324 | cond_units = sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2] 325 | # not keyword 326 | if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0: 327 | res.add("not") 328 | 329 | # in keyword 330 | if ( 331 | len( 332 | [ 333 | cond_unit 334 | for cond_unit in cond_units 335 | if cond_unit[1] == WHERE_OPS.index("in") 336 | ] 337 | ) 338 | > 0 339 | ): 340 | res.add("in") 341 | 342 | # like keyword 343 | if ( 344 | len( 345 | [ 346 | cond_unit 347 | for cond_unit in cond_units 348 | if cond_unit[1] == WHERE_OPS.index("like") 349 | ] 350 | ) 351 | > 0 352 | ): 353 | res.add("like") 354 | 355 | return res 356 | 357 | 358 | def eval_keywords(pred, label): 359 | pred_keywords = get_keywords(pred) 360 | label_keywords = get_keywords(label) 361 | pred_total = len(pred_keywords) 362 | label_total = len(label_keywords) 363 | cnt = 0 364 | 365 | for k in pred_keywords: 366 | if k in label_keywords: 367 | cnt += 1 368 | return label_total, pred_total, cnt 369 | 370 | 371 | def count_agg(units): 372 | return len([unit for unit in units if has_agg(unit)]) 373 | 374 | 375 | def count_component1(sql): 376 | count = 0 377 | if len(sql["where"]) > 0: 378 | count += 1 379 | if len(sql["groupBy"]) > 0: 380 | count += 1 381 | if len(sql["orderBy"]) > 0: 382 | count += 1 383 | if sql["limit"] is not None: 384 | count += 1 385 | if len(sql["from"]["table_units"]) > 0: # JOIN 386 | count += len(sql["from"]["table_units"]) - 1 387 | 388 | ao = sql["from"]["conds"][1::2] + sql["where"][1::2] + sql["having"][1::2] 389 | count += len([token for token in ao if token == "or"]) 390 | cond_units = sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2] 391 | count += len( 392 | [ 393 | cond_unit 394 | for cond_unit in cond_units 395 | if cond_unit[1] == 
WHERE_OPS.index("like") 396 | ] 397 | ) 398 | 399 | return count 400 | 401 | 402 | def count_component2(sql): 403 | nested = get_nestedSQL(sql) 404 | return len(nested) 405 | 406 | 407 | def count_others(sql): 408 | count = 0 409 | # number of aggregation 410 | agg_count = count_agg(sql["select"][1]) 411 | agg_count += count_agg(sql["where"][::2]) 412 | agg_count += count_agg(sql["groupBy"]) 413 | if len(sql["orderBy"]) > 0: 414 | agg_count += count_agg( 415 | [unit[1] for unit in sql["orderBy"][1] if unit[1]] 416 | + [unit[2] for unit in sql["orderBy"][1] if unit[2]] 417 | ) 418 | agg_count += count_agg(sql["having"]) 419 | if agg_count > 1: 420 | count += 1 421 | 422 | # number of select columns 423 | if len(sql["select"][1]) > 1: 424 | count += 1 425 | 426 | # number of where conditions 427 | if len(sql["where"]) > 1: 428 | count += 1 429 | 430 | # number of group by clauses 431 | if len(sql["groupBy"]) > 1: 432 | count += 1 433 | 434 | return count 435 | 436 | 437 | class Evaluator: 438 | """A simple evaluator""" 439 | 440 | def __init__(self): 441 | self.partial_scores = None 442 | 443 | def eval_hardness(self, sql): 444 | count_comp1_ = count_component1(sql) 445 | count_comp2_ = count_component2(sql) 446 | count_others_ = count_others(sql) 447 | 448 | if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0: 449 | return "easy" 450 | elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or ( 451 | count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0 452 | ): 453 | return "medium" 454 | elif ( 455 | (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) 456 | or (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) 457 | or (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1) 458 | ): 459 | return "hard" 460 | else: 461 | return "extra" 462 | 463 | def eval_exact_match(self, pred, label): 464 | partial_scores = self.eval_partial_match(pred, label) 465 | self.partial_scores = partial_scores 
466 | 467 | for _, score in partial_scores.items(): 468 | if score["f1"] != 1: 469 | return 0 470 | if len(label["from"]["table_units"]) > 0: 471 | label_tables = sorted(label["from"]["table_units"]) 472 | pred_tables = sorted(pred["from"]["table_units"]) 473 | return label_tables == pred_tables 474 | return 1 475 | 476 | def eval_partial_match(self, pred, label): 477 | res = {} 478 | 479 | label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label) 480 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 481 | res["select"] = { 482 | "acc": acc, 483 | "rec": rec, 484 | "f1": f1, 485 | "label_total": label_total, 486 | "pred_total": pred_total, 487 | } 488 | acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) 489 | res["select(no AGG)"] = { 490 | "acc": acc, 491 | "rec": rec, 492 | "f1": f1, 493 | "label_total": label_total, 494 | "pred_total": pred_total, 495 | } 496 | 497 | label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label) 498 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 499 | res["where"] = { 500 | "acc": acc, 501 | "rec": rec, 502 | "f1": f1, 503 | "label_total": label_total, 504 | "pred_total": pred_total, 505 | } 506 | acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) 507 | res["where(no OP)"] = { 508 | "acc": acc, 509 | "rec": rec, 510 | "f1": f1, 511 | "label_total": label_total, 512 | "pred_total": pred_total, 513 | } 514 | 515 | label_total, pred_total, cnt = eval_group(pred, label) 516 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 517 | res["group(no Having)"] = { 518 | "acc": acc, 519 | "rec": rec, 520 | "f1": f1, 521 | "label_total": label_total, 522 | "pred_total": pred_total, 523 | } 524 | 525 | label_total, pred_total, cnt = eval_having(pred, label) 526 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 527 | res["group"] = { 528 | "acc": acc, 529 | "rec": rec, 530 | "f1": f1, 531 | "label_total": label_total, 532 | "pred_total": pred_total, 533 | } 534 | 535 | 
label_total, pred_total, cnt = eval_order(pred, label) 536 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 537 | res["order"] = { 538 | "acc": acc, 539 | "rec": rec, 540 | "f1": f1, 541 | "label_total": label_total, 542 | "pred_total": pred_total, 543 | } 544 | 545 | label_total, pred_total, cnt = eval_and_or(pred, label) 546 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 547 | res["and/or"] = { 548 | "acc": acc, 549 | "rec": rec, 550 | "f1": f1, 551 | "label_total": label_total, 552 | "pred_total": pred_total, 553 | } 554 | 555 | label_total, pred_total, cnt = eval_IUEN(pred, label) 556 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 557 | res["IUEN"] = { 558 | "acc": acc, 559 | "rec": rec, 560 | "f1": f1, 561 | "label_total": label_total, 562 | "pred_total": pred_total, 563 | } 564 | 565 | label_total, pred_total, cnt = eval_keywords(pred, label) 566 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 567 | res["keywords"] = { 568 | "acc": acc, 569 | "rec": rec, 570 | "f1": f1, 571 | "label_total": label_total, 572 | "pred_total": pred_total, 573 | } 574 | 575 | return res 576 | 577 | 578 | def isValidSQL(sql, db): 579 | conn = sqlite3.connect(db) 580 | cursor = conn.cursor() 581 | try: 582 | cursor.execute(sql) 583 | except: 584 | return False 585 | return True 586 | 587 | 588 | def print_scores(scores, etype): 589 | levels = ["easy", "medium", "hard", "extra", "all"] 590 | partial_types = [ 591 | "select", 592 | "select(no AGG)", 593 | "where", 594 | "where(no OP)", 595 | "group(no Having)", 596 | "group", 597 | "order", 598 | "and/or", 599 | "IUEN", 600 | "keywords", 601 | ] 602 | 603 | print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels)) 604 | counts = [scores[level]["count"] for level in levels] 605 | print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts)) 606 | 607 | if etype in ["all", "exec"]: 608 | print("===================== EXECUTION ACCURACY =====================") 609 | 
this_scores = [scores[level]["exec"] for level in levels] 610 | print( 611 | "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format( 612 | "execution", *this_scores 613 | ) 614 | ) 615 | 616 | if etype in ["all", "match"]: 617 | print("\n====================== EXACT MATCHING ACCURACY =====================") 618 | exact_scores = [scores[level]["exact"] for level in levels] 619 | print( 620 | "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format( 621 | "exact match", *exact_scores 622 | ) 623 | ) 624 | print("\n---------------------PARTIAL MATCHING ACCURACY----------------------") 625 | for type_ in partial_types: 626 | this_scores = [scores[level]["partial"][type_]["acc"] for level in levels] 627 | print( 628 | "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format( 629 | type_, *this_scores 630 | ) 631 | ) 632 | 633 | print("---------------------- PARTIAL MATCHING RECALL ----------------------") 634 | for type_ in partial_types: 635 | this_scores = [scores[level]["partial"][type_]["rec"] for level in levels] 636 | print( 637 | "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format( 638 | type_, *this_scores 639 | ) 640 | ) 641 | 642 | print("---------------------- PARTIAL MATCHING F1 --------------------------") 643 | for type_ in partial_types: 644 | this_scores = [scores[level]["partial"][type_]["f1"] for level in levels] 645 | print( 646 | "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format( 647 | type_, *this_scores 648 | ) 649 | ) 650 | 651 | 652 | def evaluate(gold, predict, db_dir, etype, kmaps): 653 | with open(gold) as f: 654 | glist = [l.strip().split("\t") for l in f.readlines() if len(l.strip()) > 0] 655 | 656 | with open(predict) as f: 657 | plist = [l.strip().split("\t") for l in f.readlines() if len(l.strip()) > 0] 658 | # plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")] 659 | # glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE 
!= 'Live final'", "orchestra")] 660 | evaluator = Evaluator() 661 | 662 | levels = ["easy", "medium", "hard", "extra", "all"] 663 | partial_types = [ 664 | "select", 665 | "select(no AGG)", 666 | "where", 667 | "where(no OP)", 668 | "group(no Having)", 669 | "group", 670 | "order", 671 | "and/or", 672 | "IUEN", 673 | "keywords", 674 | ] 675 | entries = [] 676 | scores = {} 677 | 678 | for level in levels: 679 | scores[level] = {"count": 0, "partial": {}, "exact": 0.0} 680 | scores[level]["exec"] = 0 681 | for type_ in partial_types: 682 | scores[level]["partial"][type_] = { 683 | "acc": 0.0, 684 | "rec": 0.0, 685 | "f1": 0.0, 686 | "acc_count": 0, 687 | "rec_count": 0, 688 | } 689 | 690 | eval_err_num = 0 691 | for p, g in zip(plist, glist): 692 | p_str = p[0] 693 | g_str, db = g 694 | db_name = db 695 | db = os.path.join(db_dir, db, db + ".sqlite") 696 | schema = Schema(get_schema(db)) 697 | g_sql = get_sql(schema, g_str) 698 | hardness = evaluator.eval_hardness(g_sql) 699 | scores[hardness]["count"] += 1 700 | scores["all"]["count"] += 1 701 | 702 | try: 703 | p_sql = get_sql(schema, p_str) 704 | except: 705 | # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql 706 | p_sql = { 707 | "except": None, 708 | "from": {"conds": [], "table_units": []}, 709 | "groupBy": [], 710 | "having": [], 711 | "intersect": None, 712 | "limit": None, 713 | "orderBy": [], 714 | "select": [False, []], 715 | "union": None, 716 | "where": [], 717 | } 718 | eval_err_num += 1 719 | print("eval_err_num:{}".format(eval_err_num)) 720 | 721 | # rebuild sql for value evaluation 722 | kmap = kmaps[db_name] 723 | g_valid_col_units = build_valid_col_units(g_sql["from"]["table_units"], schema) 724 | g_sql = rebuild_sql_val(g_sql) 725 | g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) 726 | p_valid_col_units = build_valid_col_units(p_sql["from"]["table_units"], schema) 727 | p_sql = rebuild_sql_val(p_sql) 728 | p_sql = rebuild_sql_col(p_valid_col_units, 
p_sql, kmap) 729 | 730 | if etype in ["all", "exec"]: 731 | exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql) 732 | if exec_score: 733 | scores[hardness]["exec"] += 1.0 734 | scores["all"]["exec"] += 1.0 735 | 736 | if etype in ["all", "match"]: 737 | exact_score = evaluator.eval_exact_match(p_sql, g_sql) 738 | partial_scores = evaluator.partial_scores 739 | if exact_score == 0: 740 | print("{} pred: {}".format(hardness, p_str)) 741 | print("{} gold: {}".format(hardness, g_str)) 742 | print("") 743 | scores[hardness]["exact"] += exact_score 744 | scores["all"]["exact"] += exact_score 745 | for type_ in partial_types: 746 | if partial_scores[type_]["pred_total"] > 0: 747 | scores[hardness]["partial"][type_]["acc"] += partial_scores[type_][ 748 | "acc" 749 | ] 750 | scores[hardness]["partial"][type_]["acc_count"] += 1 751 | if partial_scores[type_]["label_total"] > 0: 752 | scores[hardness]["partial"][type_]["rec"] += partial_scores[type_][ 753 | "rec" 754 | ] 755 | scores[hardness]["partial"][type_]["rec_count"] += 1 756 | scores[hardness]["partial"][type_]["f1"] += partial_scores[type_]["f1"] 757 | if partial_scores[type_]["pred_total"] > 0: 758 | scores["all"]["partial"][type_]["acc"] += partial_scores[type_][ 759 | "acc" 760 | ] 761 | scores["all"]["partial"][type_]["acc_count"] += 1 762 | if partial_scores[type_]["label_total"] > 0: 763 | scores["all"]["partial"][type_]["rec"] += partial_scores[type_][ 764 | "rec" 765 | ] 766 | scores["all"]["partial"][type_]["rec_count"] += 1 767 | scores["all"]["partial"][type_]["f1"] += partial_scores[type_]["f1"] 768 | 769 | entries.append( 770 | { 771 | "predictSQL": p_str, 772 | "goldSQL": g_str, 773 | "hardness": hardness, 774 | "exact": exact_score, 775 | "partial": partial_scores, 776 | } 777 | ) 778 | 779 | for level in levels: 780 | if scores[level]["count"] == 0: 781 | continue 782 | if etype in ["all", "exec"]: 783 | scores[level]["exec"] /= scores[level]["count"] 784 | 785 | if etype in ["all", "match"]: 
786 | scores[level]["exact"] /= scores[level]["count"] 787 | for type_ in partial_types: 788 | if scores[level]["partial"][type_]["acc_count"] == 0: 789 | scores[level]["partial"][type_]["acc"] = 0 790 | else: 791 | scores[level]["partial"][type_]["acc"] = ( 792 | scores[level]["partial"][type_]["acc"] 793 | / scores[level]["partial"][type_]["acc_count"] 794 | * 1.0 795 | ) 796 | if scores[level]["partial"][type_]["rec_count"] == 0: 797 | scores[level]["partial"][type_]["rec"] = 0 798 | else: 799 | scores[level]["partial"][type_]["rec"] = ( 800 | scores[level]["partial"][type_]["rec"] 801 | / scores[level]["partial"][type_]["rec_count"] 802 | * 1.0 803 | ) 804 | if ( 805 | scores[level]["partial"][type_]["acc"] == 0 806 | and scores[level]["partial"][type_]["rec"] == 0 807 | ): 808 | scores[level]["partial"][type_]["f1"] = 1 809 | else: 810 | scores[level]["partial"][type_]["f1"] = ( 811 | 2.0 812 | * scores[level]["partial"][type_]["acc"] 813 | * scores[level]["partial"][type_]["rec"] 814 | / ( 815 | scores[level]["partial"][type_]["rec"] 816 | + scores[level]["partial"][type_]["acc"] 817 | ) 818 | ) 819 | 820 | print_scores(scores, etype) 821 | 822 | 823 | def eval_exec_match(db, p_str, g_str, pred, gold): 824 | """ 825 | return 1 if the values between prediction and gold are matching 826 | in the corresponding index. Currently does not support multiple col_unit(pairs). 
827 | """ 828 | conn = sqlite3.connect(db) 829 | cursor = conn.cursor() 830 | try: 831 | cursor.execute(p_str) 832 | p_res = cursor.fetchall() 833 | except: 834 | return False 835 | 836 | cursor.execute(g_str) 837 | q_res = cursor.fetchall() 838 | 839 | def res_map(res, val_units): 840 | rmap = {} 841 | for idx, val_unit in enumerate(val_units): 842 | key = ( 843 | tuple(val_unit[1]) 844 | if not val_unit[2] 845 | else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2])) 846 | ) 847 | rmap[key] = [r[idx] for r in res] 848 | return rmap 849 | 850 | p_val_units = [unit[1] for unit in pred["select"][1]] 851 | q_val_units = [unit[1] for unit in gold["select"][1]] 852 | return res_map(p_res, p_val_units) == res_map(q_res, q_val_units) 853 | 854 | 855 | from multiprocessing import Manager, Process 856 | 857 | 858 | def execute(db, p_str, pred, timeout): 859 | conn = sqlite3.connect(f"file:{db}?mode=ro", uri=True, timeout=30) 860 | cursor = conn.cursor() 861 | 862 | with Manager() as manager: 863 | result = manager.list() 864 | 865 | def unsafe_execute(result): 866 | # sys.stdout = open(os.devnull, "w") 867 | sys.stderr = open(os.devnull, "w") 868 | try: 869 | cursor.execute(p_str) 870 | p_res = cursor.fetchall() 871 | except sqlite3.OperationalError as e: 872 | if "locked" in str(e): 873 | print(e) 874 | raise ValueError("Invalid") 875 | result.append(p_res) 876 | 877 | p = Process(target=unsafe_execute, args=(result,)) 878 | p.start() 879 | p.join(timeout=timeout) 880 | 881 | if p.exitcode != 0: 882 | return False, None 883 | else: 884 | try: 885 | p_res = result[0] 886 | 887 | def res_map(res, val_units): 888 | rmap = {} 889 | for idx, val_unit in enumerate(val_units): 890 | key = ( 891 | tuple(val_unit[1]) 892 | if not val_unit[2] 893 | else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2])) 894 | ) 895 | rmap[key] = [r[idx] for r in res] 896 | return rmap 897 | 898 | p_val_units = [unit[1] for unit in pred["select"][1]] 899 | return True, res_map(p_res, 
p_val_units) 900 | except: 901 | return False, None 902 | 903 | 904 | # Rebuild SQL functions for value evaluation 905 | def rebuild_cond_unit_val(cond_unit): 906 | if cond_unit is None or not DISABLE_VALUE: 907 | return cond_unit 908 | 909 | not_op, op_id, val_unit, val1, val2 = cond_unit 910 | if type(val1) is not dict: 911 | val1 = None 912 | else: 913 | val1 = rebuild_sql_val(val1) 914 | if type(val2) is not dict: 915 | val2 = None 916 | else: 917 | val2 = rebuild_sql_val(val2) 918 | return not_op, op_id, val_unit, val1, val2 919 | 920 | 921 | def rebuild_condition_val(condition): 922 | if condition is None or not DISABLE_VALUE: 923 | return condition 924 | 925 | res = [] 926 | for idx, it in enumerate(condition): 927 | if idx % 2 == 0: 928 | res.append(rebuild_cond_unit_val(it)) 929 | else: 930 | res.append(it) 931 | return res 932 | 933 | 934 | def rebuild_sql_val(sql): 935 | if sql is None or not DISABLE_VALUE: 936 | return sql 937 | 938 | sql["from"]["conds"] = rebuild_condition_val(sql["from"]["conds"]) 939 | sql["having"] = rebuild_condition_val(sql["having"]) 940 | sql["where"] = rebuild_condition_val(sql["where"]) 941 | sql["intersect"] = rebuild_sql_val(sql["intersect"]) 942 | sql["except"] = rebuild_sql_val(sql["except"]) 943 | sql["union"] = rebuild_sql_val(sql["union"]) 944 | 945 | return sql 946 | 947 | 948 | # Rebuild SQL functions for foreign key evaluation 949 | def build_valid_col_units(table_units, schema): 950 | col_ids = [ 951 | table_unit[1] 952 | for table_unit in table_units 953 | if table_unit[0] == TABLE_TYPE["table_unit"] 954 | ] 955 | prefixs = [col_id[:-2] for col_id in col_ids] 956 | valid_col_units = [] 957 | for value in schema.idMap.values(): 958 | if "." 
in value and value[: value.index(".")] in prefixs: 959 | valid_col_units.append(value) 960 | return valid_col_units 961 | 962 | 963 | def rebuild_col_unit_col(valid_col_units, col_unit, kmap): 964 | if col_unit is None: 965 | return col_unit 966 | 967 | agg_id, col_id, distinct = col_unit 968 | if col_id in kmap and col_id in valid_col_units: 969 | col_id = kmap[col_id] 970 | if DISABLE_DISTINCT: 971 | distinct = None 972 | return agg_id, col_id, distinct 973 | 974 | 975 | def rebuild_val_unit_col(valid_col_units, val_unit, kmap): 976 | if val_unit is None: 977 | return val_unit 978 | 979 | unit_op, col_unit1, col_unit2 = val_unit 980 | col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap) 981 | col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap) 982 | return unit_op, col_unit1, col_unit2 983 | 984 | 985 | def rebuild_table_unit_col(valid_col_units, table_unit, kmap): 986 | if table_unit is None: 987 | return table_unit 988 | 989 | table_type, col_unit_or_sql = table_unit 990 | if isinstance(col_unit_or_sql, tuple): 991 | col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap) 992 | return table_type, col_unit_or_sql 993 | 994 | 995 | def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap): 996 | if cond_unit is None: 997 | return cond_unit 998 | 999 | not_op, op_id, val_unit, val1, val2 = cond_unit 1000 | val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap) 1001 | return not_op, op_id, val_unit, val1, val2 1002 | 1003 | 1004 | def rebuild_condition_col(valid_col_units, condition, kmap): 1005 | for idx in range(len(condition)): 1006 | if idx % 2 == 0: 1007 | condition[idx] = rebuild_cond_unit_col( 1008 | valid_col_units, condition[idx], kmap 1009 | ) 1010 | return condition 1011 | 1012 | 1013 | def rebuild_select_col(valid_col_units, sel, kmap): 1014 | if sel is None: 1015 | return sel 1016 | distinct, _list = sel 1017 | new_list = [] 1018 | for it in _list: 1019 | agg_id, val_unit = it 1020 | 
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
1021 |     if DISABLE_DISTINCT:
1022 |         distinct = None
1023 |     return distinct, new_list
1024 | 
1025 | 
1026 | def rebuild_from_col(valid_col_units, from_, kmap):
1027 |     if from_ is None:
1028 |         return from_
1029 | 
1030 |     from_["table_units"] = [
1031 |         rebuild_table_unit_col(valid_col_units, table_unit, kmap)
1032 |         for table_unit in from_["table_units"]
1033 |     ]
1034 |     from_["conds"] = rebuild_condition_col(valid_col_units, from_["conds"], kmap)
1035 |     return from_
1036 | 
1037 | 
1038 | def rebuild_group_by_col(valid_col_units, group_by, kmap):
1039 |     if group_by is None:
1040 |         return group_by
1041 | 
1042 |     return [
1043 |         rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by
1044 |     ]
1045 | 
1046 | 
1047 | def rebuild_order_by_col(valid_col_units, order_by, kmap):
1048 |     if order_by is None or len(order_by) == 0:
1049 |         return order_by
1050 | 
1051 |     direction, val_units = order_by
1052 |     new_val_units = [
1053 |         rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units
1054 |     ]
1055 |     return direction, new_val_units
1056 | 
1057 | 
1058 | def rebuild_sql_col(valid_col_units, sql, kmap):
1059 |     if sql is None:
1060 |         return sql
1061 | 
1062 |     sql["select"] = rebuild_select_col(valid_col_units, sql["select"], kmap)
1063 |     sql["from"] = rebuild_from_col(valid_col_units, sql["from"], kmap)
1064 |     sql["where"] = rebuild_condition_col(valid_col_units, sql["where"], kmap)
1065 |     sql["groupBy"] = rebuild_group_by_col(valid_col_units, sql["groupBy"], kmap)
1066 |     sql["orderBy"] = rebuild_order_by_col(valid_col_units, sql["orderBy"], kmap)
1067 |     sql["having"] = rebuild_condition_col(valid_col_units, sql["having"], kmap)
1068 |     sql["intersect"] = rebuild_sql_col(valid_col_units, sql["intersect"], kmap)
1069 |     sql["except"] = rebuild_sql_col(valid_col_units, sql["except"], kmap)
1070 |     sql["union"] = rebuild_sql_col(valid_col_units, sql["union"], kmap)
1071 | 
1072 |     return sql
1073 | 
1074 | 
1075 | def build_foreign_key_map(entry):
1076 |     cols_orig = entry["column_names_original"]
1077 |     tables_orig = entry["table_names_original"]
1078 | 
1079 |     # rebuild cols corresponding to idmap in Schema
1080 |     cols = []
1081 |     for col_orig in cols_orig:
1082 |         if col_orig[0] >= 0:
1083 |             t = tables_orig[col_orig[0]]
1084 |             c = col_orig[1]
1085 |             cols.append("__" + t.lower() + "." + c.lower() + "__")
1086 |         else:
1087 |             cols.append("__all__")
1088 | 
1089 |     def keyset_in_list(k1, k2, k_list):
1090 |         for k_set in k_list:
1091 |             if k1 in k_set or k2 in k_set:
1092 |                 return k_set
1093 |         new_k_set = set()
1094 |         k_list.append(new_k_set)
1095 |         return new_k_set
1096 | 
1097 |     foreign_key_list = []
1098 |     foreign_keys = entry["foreign_keys"]
1099 |     for fkey in foreign_keys:
1100 |         key1, key2 = fkey
1101 |         key_set = keyset_in_list(key1, key2, foreign_key_list)
1102 |         key_set.add(key1)
1103 |         key_set.add(key2)
1104 | 
1105 |     foreign_key_map = {}
1106 |     for key_set in foreign_key_list:
1107 |         sorted_list = sorted(list(key_set))
1108 |         midx = sorted_list[0]
1109 |         for idx in sorted_list:
1110 |             foreign_key_map[cols[idx]] = cols[midx]
1111 | 
1112 |     return foreign_key_map
1113 | 
1114 | 
1115 | def build_foreign_key_map_from_json(table):
1116 |     with open(table) as f:
1117 |         data = json.load(f)
1118 |     tables = {}
1119 |     for entry in data:
1120 |         tables[entry["db_id"]] = build_foreign_key_map(entry)
1121 |     return tables
1122 | 
1123 | 
1124 | if __name__ == "__main__":
1125 |     parser = argparse.ArgumentParser()
1126 |     parser.add_argument("--gold", dest="gold", type=str)
1127 |     parser.add_argument("--pred", dest="pred", type=str)
1128 |     parser.add_argument("--db", dest="db", type=str)
1129 |     parser.add_argument("--table", dest="table", type=str)
1130 |     parser.add_argument("--etype", dest="etype", type=str)
1131 |     args = parser.parse_args()
1132 | 
1133 |     gold = args.gold
1134 |     pred = args.pred
1135 |     db_dir = args.db
1136 |     table = args.table
1137 |     etype = args.etype
1138 | 
1139 |     assert etype in ["all", "exec", "match"], "Unknown evaluation method"
1140 | 
1141 |     kmaps = build_foreign_key_map_from_json(table)
1142 | 
1143 |     evaluate(gold, pred, db_dir, etype, kmaps)
1144 | 
--------------------------------------------------------------------------------
/zeroshot_reviewer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | 
3 | from pathlib import Path
4 | import os
5 | from glob import glob
6 | from argparse import ArgumentParser
7 | import html
8 | import json
9 | from utils import *
10 | from tqdm import tqdm, trange
11 | from data import HumanEvalDataset, rindex, extract_docstring
12 | from functools import partial
13 | from pyminifier_canonicalize import clean_comment, remove_print
14 | 
15 | 
16 | def postprocess_func_only(code, tokens):
17 |     lines = []
18 |     for line in code.split("\n"):
19 |         if len(line.strip()) > 0 and not line.startswith(" "):
20 |             continue
21 |         else:
22 |             lines.append(line)
23 | 
24 |     code = "\n".join(lines)
25 |     code = code.rstrip()
26 | 
27 |     curr = ""
28 |     for i, tok in enumerate(tokens):
29 |         curr += tok
30 |         if len(curr) >= len(code):
31 |             break
32 | 
33 |     return code, tokens[: i + 1]
34 | 
35 | 
36 | def make_new_context(
37 |     codex_data,
38 |     problem,
39 |     canonicalize=False,
40 |     clean_print=False,
41 | ):
42 |     prompt = codex_data["prompt"]
43 |     code_sample = codex_data["trg_prediction"]
44 |     if canonicalize:
45 |         try:
46 |             code_sample = clean_comment(code_sample)
47 |         except:
48 |             # static error; keep the sample unchanged
49 |             code_sample = code_sample
50 |     if clean_print:
51 |         code_sample = remove_print(code_sample)
52 |     func_name = problem["entry_point"]
53 |     docstring, func_header, func_context, doc_start = extract_docstring(prompt)
54 |     if canonicalize:
55 |         func_header = func_header.replace(f"{func_name}(", "f(")
56 |         docstring = docstring.replace(f"{func_name}(", "f(")
57 |         code_sample = code_sample.replace(f"{func_name}(", "f(")
58 |     reverse_prompt = "\n\n# write the docstring for the above function\n"
59 |     without_ref = (
60 |         func_context
61 |         + "\n"
62 |         + func_header.strip()
63 |         + "\n"
64 |         + code_sample
65 |         + reverse_prompt
66 |         + func_header.strip()
67 |         + "\n"
68 |         + f"    {doc_start}"
69 |     )
70 |     with_ref = without_ref + docstring.strip()[3:]
71 |     return with_ref.rstrip(), without_ref
72 | 
73 | 
74 | def rindex(lst, value):
75 |     return len(lst) - lst[::-1].index(value) - 1
76 | 
77 | 
78 | def find_start(tokens):
79 |     tokens = tokens[:-2]  # remove last docstring marker
80 |     for marker in [' """', " '''", ' ""', "''"]:
81 |         if marker in tokens:
82 |             return rindex(tokens[:-1], marker) + 1
83 |     raise ValueError("not found")
84 | 
85 | 
86 | def batch_query_reverse_logp(all_codex_data, args):
87 |     for outer_i, batch_start in enumerate(
88 |         range(0, len(all_codex_data), args.batch_size)
89 |     ):
90 |         batch_data = all_codex_data[batch_start : batch_start + args.batch_size]
91 |         batch_prompts = []
92 |         batch_data_with_prompt = []
93 |         for codex_data, problem in batch_data:
94 |             # TODO: postprocessing, should move elsewhere
95 |             codex_data["trg_prediction"], codex_data["tokens"] = postprocess_func_only(
96 |                 codex_data["trg_prediction"], codex_data["tokens"]
97 |             )
98 |             codex_data["logprobs"] = codex_data["logprobs"][: len(codex_data["tokens"])]
99 | 
100 |             with_ref_prompt, without_ref_prompt = make_new_context(
101 |                 codex_data,
102 |                 problem,
103 |                 canonicalize=args.canonicalize,
104 |                 clean_print=args.clean_print,
105 |             )
106 |             batch_prompts.append(with_ref_prompt)
107 |             batch_data_with_prompt.append(
108 |                 (codex_data, problem, with_ref_prompt, without_ref_prompt)
109 |             )
110 | 
111 |         with_ref_response, _ = safe_codex_call(
112 |             args,
113 |             batch_prompts,
114 |             temperature=1.0,
115 |             echo=True,
116 |             max_tokens=0,
117 |             api_i=(outer_i % 3),
118 |         )
119 |         for (
120 |             batch_i,
121 |             (codex_data, problem, with_ref_prompt, without_ref_prompt),
122 |         ) in enumerate(batch_data_with_prompt):
123 |             num_api_tokens = find_start(
124 |                 with_ref_response["choices"][batch_i]["logprobs"]["tokens"]
125 |             )
126 |             gt_prompt_logprob = with_ref_response["choices"][batch_i]["logprobs"][
127 |                 "token_logprobs"
128 |             ][num_api_tokens:]
129 |             gt_prompt_tokens = with_ref_response["choices"][batch_i]["logprobs"][
130 |                 "tokens"
131 |             ][num_api_tokens:]
132 |             codex_data["reverse_prompt_with_ref"] = with_ref_prompt
133 |             codex_data["reverse_prompt_without_ref"] = without_ref_prompt
134 |             codex_data["prompt_reverse_logprobs"] = gt_prompt_logprob
135 |             codex_data["prompt_reverse_tokens"] = gt_prompt_tokens
136 |             codex_data["prompt_reverse_full_tokens"] = with_ref_response["choices"][
137 |                 batch_i
138 |             ]["logprobs"]["tokens"]
139 |             codex_data["prompt_reverse_full_logprobs"] = with_ref_response["choices"][
140 |                 batch_i
141 |             ]["logprobs"]["token_logprobs"]
142 |     all_codex_data = [d[0] for d in all_codex_data]
143 |     return all_codex_data
144 | 
145 | 
146 | if __name__ == "__main__":
147 |     parser = ArgumentParser()
148 |     parser.add_argument("--model", type=str, default="codex001")
149 |     parser.add_argument(
150 |         "--dataset",
151 |         type=str,
152 |         default="humaneval",
153 |         choices=["humaneval", "codet_humaneval", "mbpp_sanitized"],
154 |     )
155 |     parser.add_argument("--tag", type=str, default="")
156 |     parser.add_argument("--split", type=str, default="test")
157 |     parser.add_argument("--num_samples", type=int, default=5)
158 |     parser.add_argument("--num_procs", type=int, default=40)
159 |     parser.add_argument(
160 |         "--data_path",
161 |         type=str,
162 |         default="./samples/codex002",
163 |     )
164 |     parser.add_argument("--temperature", type=float, default=0.3)
165 |     parser.add_argument("--max_tokens", type=int, default=512)
166 |     parser.add_argument("--batch_size", type=int, default=20)
167 |     parser.add_argument("--top_p", type=float, default=1.0)
168 |     parser.add_argument("--canonicalize", default=False, action="store_true")
169 |     parser.add_argument("--clean-print", default=False, action="store_true")
170 |     parser.add_argument("--overwrite-output-dir", default=False, action="store_true")
171 | 
172 |     args = parser.parse_args()
173 |     args.data_path = Path(args.data_path)
174 |     out_dir = f"seed-*/**/*-{args.temperature}"
175 |     if args.top_p != 1.0:
176 |         out_dir += f"-p{args.top_p}"
177 |     if args.max_tokens != 512:
178 |         out_dir += f"-max{args.max_tokens}"
179 |     args.data_path = args.data_path / args.dataset / out_dir
180 |     paths = list(sorted(glob(str(args.data_path), recursive=True)))
181 | 
182 |     if args.dataset == "codet_humaneval":
183 |         dataset = HumanEvalDataset(
184 |             "dataset/human_eval/dataset/CodeTHumanEval.jsonl", mode="prompt_only"
185 |         )
186 |     else:
187 |         dataset = HumanEvalDataset(
188 |             path="dataset/mbpp/mbpp_sanitized_for_code_generation.jsonl",
189 |             mode="prompt_only",
190 |         )
191 |     prompt_to_data = {p["prompt"]: p for task_id, p in dataset.raw_data.items()}
192 | 
193 |     paths = sorted(paths)
194 |     for path in tqdm(paths, desc="total seeds", disable=False):
195 |         path = Path(path)
196 |         for sample_i in trange(args.num_samples):
197 |             if len(args.tag) == 0:
198 |                 output_file_name = f"{args.split}-{sample_i}.jsonl"
199 |             else:
200 |                 output_file_name = f"{args.split}-{sample_i}-{args.tag}.jsonl"
201 | 
202 |             try:
203 |                 all_codex_data = []
204 |                 with open(path / f"{args.split}-{sample_i}.jsonl", "r") as f:
205 |                     for i, line in enumerate(f):
206 |                         codex_data = json.loads(line)
207 |                         raw_data = prompt_to_data[codex_data["prompt"]]
208 |                         all_codex_data.append((codex_data, raw_data))
209 |             except Exception as e:
210 |                 print(e)
211 |                 print(f"{path / output_file_name} not ready yet. skipping.")
212 |                 continue
213 | 
214 |             if (path / output_file_name).exists() and not args.overwrite_output_dir:
215 |                 with open(path / output_file_name, "r") as f:
216 |                     line_num = len(f.readlines())
217 |                 if line_num == len(all_codex_data):
218 |                     print(f"skipping {path / output_file_name}")
219 |                     continue
220 | 
221 |             from multiprocessing import Pool
222 | 
223 |             if args.num_procs > 1:
224 |                 all_codex_data_with_reverse = []
225 |                 chunk_size = len(all_codex_data) // args.num_procs + 1
226 |                 chunked_all_codex_data = [
227 |                     all_codex_data[chunk_start : chunk_start + chunk_size]
228 |                     for chunk_start in range(0, len(all_codex_data), chunk_size)
229 |                 ]
230 |                 with Pool(processes=args.num_procs) as pool:
231 |                     for codex_data_with_reverse in pool.imap(
232 |                         partial(batch_query_reverse_logp, args=args),
233 |                         chunked_all_codex_data,
234 |                     ):
235 |                         all_codex_data_with_reverse.extend(codex_data_with_reverse)
236 |             else:
237 |                 all_codex_data_with_reverse = batch_query_reverse_logp(
238 |                     all_codex_data, args
239 |                 )
240 | 
241 |             with open(path / output_file_name, "w") as f:
242 |                 for codex_data_with_reverse in all_codex_data_with_reverse:
243 |                     codex_data_json = json.dumps(codex_data_with_reverse)
244 |                     f.write(codex_data_json + "\n")
245 | 
--------------------------------------------------------------------------------