├── .github └── workflows │ ├── codeql.yml │ └── dependency-review.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── example.png ├── execution ├── __init__.py ├── execution_evaluation.py ├── program_tracing.py └── safe_execution_util.py ├── lightning_modules ├── __init__.py ├── callbacks │ ├── __init__.py │ └── save_prediction_callback.py ├── datasets │ ├── __init__.py │ ├── mathqa_line_reader.py │ └── reader_utils.py ├── loggers │ ├── __init__.py │ └── patched_loggers.py └── models │ ├── __init__.py │ ├── gpt_seq2seq_model.py │ ├── gpt_stmt_mml_model.py │ ├── gpt_stmt_partial_mml_model.py │ ├── gpt_stmt_state_model.py │ └── gpt_util.py ├── preprocessing ├── __init__.py ├── mathqa_python_resplit_info.json ├── preprocess_gsm8k.py ├── preprocess_mathqa_python.py └── py-tree-sitter.so ├── requirements.txt ├── trainer.py └── training_configs ├── gpt_mle.yaml ├── gpt_self_sampling.yaml └── gpt_self_sampling_partial.yaml /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '15 7 * * 0' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 
62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | with: 74 | category: "/language:${{matrix.language}}" 75 | -------------------------------------------------------------------------------- /.github/workflows/dependency-review.yml: -------------------------------------------------------------------------------- 1 | # Dependency Review Action 2 | # 3 | # This Action will scan dependency manifest files that change as part of a Pull Request, surfacing known-vulnerable versions of the packages declared or updated in the PR. Once installed, if the workflow run is marked as required, PRs introducing known-vulnerable packages will be blocked from merging. 4 | # 5 | # Source repository: https://github.com/actions/dependency-review-action 6 | # Public documentation: https://docs.github.com/en/code-security/supply-chain-security/understanding-your-software-supply-chain/about-dependency-review#dependency-review-enforcement 7 | name: 'Dependency Review' 8 | on: [pull_request] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | dependency-review: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: 'Checkout Repository' 18 | uses: actions/checkout@v3 19 | - name: 'Dependency Review' 20 | uses: actions/dependency-review-action@v2 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # defined by Ansong 132 | data/ 133 | results/ 134 | wandb/ 135 | debug-tmp/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Math Reasoning from Self-Sampled Correct and Partially-Correct Solutions (ICLR'23) 2 |


14 | 15 | Code for the paper [Learning Math Reasoning from Self-Sampled Correct and Partially-Correct Solutions](https://arxiv.org/abs/2205.14318). In this work, we propose to let the model perform sampling during training and learn from those self-sampled correct or partially-correct solutions, which are automatically identified by comparing the final or intermediate execution states. An example is shown below. 16 |

17 | ![Example](example.png) 18 |
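To make the idea concrete, here is a minimal, self-contained sketch of how a sampled solution could be labeled as fully- or partially-correct by executing it and comparing variable states against a known-correct solution. This is an illustrative toy, not the repository's implementation: the function name and the state representation are assumptions, a bare `exec` is used only for brevity, and the actual logic lives under `execution/` (e.g., `program_tracing.py` and `execution_evaluation.py`), which runs untrusted programs in a sandboxed subprocess.

```python
import math

def judge_sampled_solution(sampled_lines, gold_states, gold_answer):
    """Toy sketch (not the repo implementation): run a sampled solution
    statement by statement and compare its variable states to the gold ones."""
    state, reached_states = {}, []
    for line in sampled_lines:
        try:
            exec(line, {}, state)            # execute one statement (unsafe; sketch only)
        except Exception:
            break                            # stop at the first failing statement
        reached_states.append(dict(state))   # snapshot the variables defined so far

    # Fully-correct: the final `answer` variable matches the (numeric) gold answer.
    final_answer = reached_states[-1].get("answer") if reached_states else None
    if isinstance(final_answer, (int, float)) and math.isclose(final_answer, gold_answer, abs_tol=1e-4):
        return "fully-correct"
    # Partially-correct: some prefix reproduces an intermediate state of a gold solution.
    if any(s in gold_states for s in reached_states):
        return "partially-correct"
    return "incorrect"
```

For example, `judge_sampled_solution(["n0 = 2", "answer = n0 * 3"], [], 6.0)` would return `"fully-correct"`.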

19 | 20 | ## Updates 21 | - **2023-03-08**: Initial code release 22 | - **2023-02-17**: Camera-ready version updated on [arxiv](https://arxiv.org/abs/2205.14318) 23 | - **2023-01-20**: Paper is accepted at ICLR 2023 24 | 25 | ## Environment Setup 26 | > **Note: all of the following has only been tested on Linux machines; you may need to build your own `tree-sitter` parsers if a different platform is used.** 27 | 28 | *(Recommended)* Create a new conda environment 29 | ```bash 30 | conda create -n trace-codegen python=3.8 31 | conda activate trace-codegen 32 | ``` 33 | Clone the code and install the dependencies 34 | ```bash 35 | git clone git@github.com:microsoft/TraceCodegen 36 | cd TraceCodegen 37 | pip install -r requirements.txt 38 | ``` 39 | *(Optional)* Set up `wandb` for experiment tracking. First follow the [wandb documentation](https://docs.wandb.ai/ref/cli/wandb-login) to log in, then change the following lines in the `trainer.logger+` fields of the `yaml` config file you would like to run: 40 | ```yaml 41 | entity: 42 | project: 43 | ``` 44 | *(Optional)* At any point, if you run into Python import problems (e.g., `ModuleNotFoundError`), try running this in the main (`TraceCodegen`) directory: 45 | ```bash 46 | export PYTHONPATH=`pwd` 47 | ``` 48 | 49 | ## Data and Preprocessing 50 | We conduct experiments on the [MathQA-Python](https://arxiv.org/abs/2108.07732) and [GSM8k](https://github.com/openai/grade-school-math) datasets. As they have different licenses and preprocessing pipelines, we describe them separately here. But first, let's make a `data` directory: 51 | ```bash 52 | mkdir data 53 | ``` 54 | ### MathQA-Python 55 | First follow [this script](https://github.com/google/trax/blob/master/trax/examples/MathQA_Python_generation_notebook.ipynb) to generate the MathQA-Python dataset from the [original MathQA dataset](https://math-qa.github.io/). After that, make sure your data directory looks something like this: 56 | ``` 57 | data 58 | |-- mathqa 59 | | |-- train-python.jsonl 60 | | |-- val-python.jsonl 61 | | |-- test-python.jsonl 62 |---... 63 | ``` 64 | We preprocess MathQA-Python by resplitting the data with template-based deduplication (see details in the paper). To do this, run the preprocessing script: 65 | ```bash 66 | python preprocessing/preprocess_mathqa_python.py 67 | ``` 68 | After this, your `data` directory should now look something like this: 69 | ``` 70 | data 71 | |-- mathqa 72 | | |-- train-python.jsonl 73 | | |-- val-python.jsonl 74 | | |-- test-python.jsonl 75 | | |-- train_dedup.jsonl 76 | | |-- val_dedup.jsonl 77 |---... 78 | ``` 79 | Note that we only combine and resplit the original train and validation sets; the test set is kept untouched. 80 | 81 | ### GSM8k 82 | As the solutions to GSM8k questions are originally annotated as math formulas, we used a script to automatically extract 83 | MathQA-Python-style programs as solutions. To replicate this, first download the data from the 84 | [original GSM8k repo](https://github.com/openai/grade-school-math/tree/master/grade_school_math/data). After that, your `data` directory should look like this: 85 | ``` 86 | data 87 | |-- gsmath 88 | | |-- train.jsonl 89 | | |-- test.jsonl 90 | | |-- ... 91 |---...
92 | ``` 93 | Now run the preprocessing script for GSM8k: 94 | ```bash 95 | python preprocessing/preprocess_gsm8k.py 96 | ``` 97 | After this, your `data` directory should look like this: 98 | ``` 99 | data 100 | |-- gsmath 101 | | |-- train.jsonl 102 | | |-- test.jsonl 103 | | |-- gsmath_train.jsonl 104 | | |-- gsmath_val.jsonl 105 | | |-- gsmath_test.jsonl 106 | | |-- ... 107 | |---... 108 | ``` 109 | 110 | ## Model Training 111 | Our training framework is built on top of [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) (version 1.5.10). More specifically, you would only need to change the `yaml` configuration files if you would like to adjust the hyperparameters (e.g., batch size, gpus, dataset file, etc). 112 | 113 | > **Note: To run model training, we recommend using GPUs that have at least 32GiB of memory, or decrease the training batch size accordingly. All our experiments are conducted on 8x V100-32GB GPUs.** 114 | 115 | ### Basic usage 116 | ```bash 117 | python trainer.py fit --config .yaml 118 | ``` 119 | Existing `yaml` config files can be found in `training_configs`, you can also find all the hyperparameter settings (e.g., batch size) in the Appendix of the paper. 120 | 121 | ### Using different transformer models 122 | If you would like to switch between `GPT-Neo-125M` and `GPT-Neo-2.7B` models, be sure to change the following fields in the `yaml` config file: 123 | ```yaml 124 | trainer: 125 | ... 126 | strategy: deepspeed_stage_2_offload # for 2.7B, or "ddp_find_unused_parameters_false" for 125M 127 | ... 128 | model: 129 | class_path: ... 130 | init_args: 131 | transformer_model_name: &transformer EleutherAI/gpt-neo-2.7B # or EleutherAI/gpt-neo-125M 132 | ... 133 | data: 134 | class_path: ... 135 | init_args: 136 | ... 137 | batch_size: 2 # [Optional] change this according to the GPU memory 138 | val_batch_size: 4 # [Optional] change this according to the GPU memory 139 | ... 140 | ``` 141 | 142 | ### Using different datasets 143 | Since the MathQA-Python and GSM8k datasets are in the same format after preprocessing, you just need to change the file paths in the following fields of the `yaml` config file: 144 | ```yaml 145 | data: 146 | ... 147 | train_file_path: data/mathqa/train_dedup.jsonl # or "data/gsmath/gsmath_train.jsonl" for gsm8k 148 | val_file_path: data/mathqa/val_dedup.jsonl # or "data/gsmath/gsmath_val.jsonl" for gsm8k 149 | ``` 150 | 151 | ### Use fully- or partially-correct self-sampled solutions 152 | To this end, you just need to use different `yaml` config files in `training_configs`: 153 | ```bash 154 | training_configs/gpt_self_sampling.yaml # for using self-sampled fully-correct solutions only 155 | training_configs/gpt_self_sampling_partial.yaml # for also using self-sampled partially-correct solutions 156 | ``` 157 | 158 | ### Using different learning objectives 159 | - For the MLE baseline, just run with the config file of `training_configs/gpt_mle.yaml` 160 | - For running MML, set the following in the `yaml` config file: 161 | ```yaml 162 | model: 163 | ... 164 | init_args: 165 | ... 166 | mle_lambda: 0.0 167 | mml_lambda: 1.0 168 | ... 169 | ``` 170 | - For running $\beta$-MML, keep the above and set `beta_smoothing: ` 171 | - For running MLE-Aug, set `mle_lambda: 1.0` and `mml_lambda: 0.0` in above. 
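For a rough sense of how these objectives differ, let $\mathcal{B}(x)$ denote the buffer of self-sampled solutions kept for a problem $x$. The formulas below are a simplified sketch (the exact buffer construction, weighting, and smoothing details are defined in the paper and in the model classes under `lightning_modules/models/`, not here):

$$
\mathcal{J}_{\text{MML}}(x) = \log \sum_{\hat{y} \in \mathcal{B}(x)} p_\theta(\hat{y} \mid x),
\qquad
\mathcal{J}_{\text{MLE-Aug}}(x) = \sum_{\hat{y} \in \mathcal{B}(x)} \log p_\theta(\hat{y} \mid x)
$$

$\beta$-MML smooths the MML objective by raising each likelihood term inside the sum to the power $\beta$ (the `beta_smoothing` value); `mle_lambda` and `mml_lambda` then weight the two kinds of terms in the overall training loss. See the paper for the precise definitions.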
172 | 173 | ### All other hyperparameters 174 | For all other hyperparameters, please read the rest of the fields in the `yaml` file and the corresponding `__init__` function in the corresponding class, or refer to the [pytorch-lightning documents](https://pytorch-lightning.readthedocs.io/en/1.5.10/). 175 | 176 | ## Model Inference 177 | For running model inference (e.g., on the test set), use the following command: 178 | ```bash 179 | python trainer.py validate --config --model.init_args.load_ckpt_file 180 | ``` 181 | 182 | ## Citation 183 | If you use the code in this repository, consider cite: 184 | ```bibtex 185 | @inproceedings{ni2023selfsampling, 186 | title={Learning Math Reasoning from Self-Sampled Correct and Partially-Correct Solutions}, 187 | author={Ni, Ansong and Inala, Jeevana Priya and Wang, Chenglong and Polozov, Alex and Meek, Christopher and Radev, Dragomir and Gao, Jianfeng}, 188 | booktitle={The 2023 International Conference on Learning Representations} 189 | } 190 | ``` 191 | For any questions, please open an issue. PRs are definitely welcomed, and please check the following section about contributing to this repo. 192 | 193 | ## Contributing 194 | 195 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 196 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 197 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 198 | 199 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 200 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 201 | provided by the bot. You will only need to do this once across all repos using our CLA. 202 | 203 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 204 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 205 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 206 | 207 | ## Trademarks 208 | 209 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 210 | trademarks or logos is subject to and must follow 211 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 212 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 213 | Any use of third-party trademarks or logos are subject to those third-party's policies. 214 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 
6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 
18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/example.png -------------------------------------------------------------------------------- /execution/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/execution/__init__.py -------------------------------------------------------------------------------- /execution/execution_evaluation.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import multiprocessing.pool 3 | import ast 4 | import math 5 | 6 | from typing import List, Dict, Tuple, Any 7 | from concurrent.futures import ProcessPoolExecutor as Pool 8 | from execution.safe_execution_util import execute 9 | 10 | ###################################################################################################### 11 | # following are some dataset specific functions for getting the execution result 12 | ###################################################################################################### 13 | 14 | def mathqa_answer_eq(prediction: Any, gold_answer: Any): 15 | try: 16 | # if the execution result is a numpy array, valueError will be raised 17 | if math.isclose(float(prediction), float(gold_answer), abs_tol=1e-4): 18 | return True 19 | else: 20 | return False 21 | except (ValueError, TypeError, OverflowError): 22 | return False 23 | 24 | def mathqa_execution(program: str) -> Any: 25 | """ 26 | for mathqa-python, we should be getting the answers from the "answer" variable in the local() variables 27 | """ 28 | 29 | result = execute(program) 30 | 31 | if result["result"] == "passed": 32 | if "answer" in result["locals"]: 33 | executed_answer = result["locals"]["answer"] 34 | else: 35 | # FIXME: this is so ad-hoc 36 | executed_answer = -10000 37 | else: 38 | executed_answer = None 39 | 40 | return executed_answer 41 | 42 | 43 | ###################################################################################################### 44 | # following are different metrics for evaluating the execution result 45 | # FIXME: Right now we only consider single test cases 46 | ###################################################################################################### 47 | 48 | def batch_exec_programs(programs: List[str], exec_func: callable, n_processes: int = 20) -> List[Any]: 49 | # build a dict to optimize for potential same programs 50 | program_dict = {} 51 | for program in programs: 52 | if program not in program_dict: 53 | program_dict[program] = None 54 | unique_programs = list(program_dict.keys()) 55 | 56 | idx = 0 57 | parsable_unique_programs = [] 58 | for program in unique_programs: 59 | try: 60 | ast.parse(program, mode="exec") 61 | parsable_unique_programs.append(program) 62 | program_dict[program] = idx 63 | idx += 1 64 | except SyntaxError: 65 | 
program_dict[program] = -1 66 | except MemoryError: 67 | print(f"MemoryError when parsing {program}") 68 | program_dict[program] = -1 69 | except ValueError: 70 | print(f"ValueError when parsing {program}") 71 | program_dict[program] = -1 72 | 73 | with Pool(n_processes) as p: 74 | unique_executed_answers = p.map(exec_func, parsable_unique_programs) 75 | unique_executed_answers = list(unique_executed_answers) 76 | unique_executed_answers.append(None) # all syntax error will be assigned to None 77 | 78 | # build the original programs answer list 79 | executed_answers = [unique_executed_answers[program_dict[program]] for program in programs] 80 | 81 | return executed_answers, len(unique_programs) 82 | 83 | def batch_execution_acc(programs: List[str], exec_func: callable, answers: List[str], 84 | n_examples: int, eval_at_k: int, n_processes: int = 20) -> List[Tuple[float, float]]: 85 | """ 86 | This function evaluates execution accuracy for a batch of programs using multiprocessing. 87 | 88 | Returns: execution accuracy, execution rate 89 | """ 90 | assert len(programs) == len(answers) * eval_at_k 91 | assert n_examples * eval_at_k == len(programs) 92 | 93 | executed_answers, n_unique_programs = batch_exec_programs(programs, exec_func, n_processes) 94 | print(f"Evaluating {len(programs)} generated programs for {n_examples} tasks, " + \ 95 | f"but only {n_unique_programs} unique programs") 96 | pct_unique_programs = n_unique_programs / len(programs) 97 | 98 | # separate the results for each task 99 | grouped_executed_answers = [executed_answers[i*eval_at_k:(i+1)*eval_at_k] for i in range(0, n_examples)] 100 | grouped_execution_evals = [] 101 | for predicted_answers, gold_answer in zip(grouped_executed_answers, answers): 102 | correct_count = 0.0 103 | for predicted_answer in predicted_answers: 104 | if mathqa_answer_eq(predicted_answer, gold_answer): 105 | correct_count += 1 106 | 107 | accuracy_at_k = correct_count / eval_at_k 108 | pass_at_k = correct_count > 0.0 109 | 110 | grouped_execution_evals.append((accuracy_at_k, pass_at_k)) 111 | 112 | return grouped_execution_evals, pct_unique_programs 113 | 114 | def execution_acc(program: str, exec_func: callable, answer: str) -> Tuple[float, float]: 115 | """ 116 | This function is used to evaluate the accuracy of the execution of the program. 
117 | 118 | Returns: execution accuracy, execution rate 119 | """ 120 | executed_answer = exec_func(program) 121 | if executed_answer is not None and mathqa_answer_eq(executed_answer, answer): 122 | return 1.0, 1.0 123 | elif executed_answer is not None: 124 | return 0.0, 1.0 125 | else: 126 | return 0.0, 0.0 127 | 128 | def execution_eval_at_k(programs: List[str], exec_func: callable, answer: str, k: int) -> Tuple[float, float]: 129 | """ 130 | Assign 1.0 when at least one out of the k programs execute to the correct answer 131 | 132 | Returns: (accuracy_at_k, pass_at_k) 133 | """ 134 | assert len(programs) >= k, "The number of programs should be larger than k" 135 | 136 | correct_count = 0.0 137 | with Pool(20) as p: 138 | executed_answers = p.map(exec_func, programs[:k]) 139 | for executed_answer in executed_answers: 140 | if mathqa_answer_eq(executed_answer, answer): 141 | correct_count += 1 142 | 143 | accuracy_at_k = correct_count / k 144 | pass_at_k = correct_count > 0.0 145 | 146 | return accuracy_at_k, pass_at_k 147 | -------------------------------------------------------------------------------- /execution/program_tracing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import scipy 3 | import json 4 | import os 5 | 6 | from concurrent.futures import ProcessPoolExecutor as Pool 7 | from typing import List, Dict, Tuple, Any, Union, NamedTuple, Set 8 | from scipy import special 9 | 10 | from typing import List, Dict, Any 11 | from tqdm import tqdm 12 | from lightning_modules.datasets.reader_utils import get_statements_from_code, byte_idx_to_char_idx 13 | from execution.safe_execution_util import execute, canonicalize_var_dict 14 | from tree_sitter import Language, Parser 15 | 16 | ProgState = Dict[str, float] 17 | HashableProgState = Tuple[str] 18 | ProgTraceUnit = NamedTuple("ProgTraceUnit", [("code", str), ("type", str), ("state", ProgState)]) 19 | ProgTrace = List[ProgTraceUnit] 20 | Program = NamedTuple("Program", [("code", str), ("code_lite", str), ("trace", ProgTrace)]) 21 | 22 | # initialize the parser for the code 23 | language_build_path = os.path.join(os.path.dirname(__file__)+'/../preprocessing/', 'py-tree-sitter.so') 24 | PY_LANGUAGE = Language(language_build_path, 'python') 25 | parser = Parser() 26 | parser.set_language(PY_LANGUAGE) 27 | 28 | """ 29 | Tracing the execution of a program: 30 | 1. It parses the program into a sequence of tracing units (currently stmts); 31 | 2. Make some markings of the tracing units; 32 | 3. Insert tracing code to the program, after every tracing unit; 33 | 4. Run the program with tracing; 34 | 5. Collect the variable tracing information. 
35 | """ 36 | 37 | from copy import deepcopy 38 | from types import ModuleType 39 | 40 | tracing_local_list = [] 41 | def record_state(local_var_dict): 42 | copied_local_var_dict = canonicalize_var_dict(local_var_dict) 43 | tracing_local_list.append(copied_local_var_dict) 44 | 45 | def get_execution_states(program: str, debug=False) -> Union[ProgTrace, None]: 46 | # first parse the program with tree-sitter 47 | stmts = get_statements_from_code(program, parser) 48 | 49 | if stmts is None: 50 | if debug: 51 | print(f'skipping unparseable example') 52 | print(f"##########\n{program}\n##########") 53 | return None 54 | 55 | # extract the stmt strings 56 | idx = 0 57 | stmt_states = [] 58 | for stmt in stmts: 59 | start_idx = byte_idx_to_char_idx(stmt['start_byte'], program) 60 | end_idx = byte_idx_to_char_idx(stmt['end_byte'], program) 61 | 62 | if start_idx != idx: 63 | # add the gap 64 | stmt_states.append({"code": program[idx:start_idx], "type": "gap"}) 65 | 66 | # add the stmt 67 | stmt_states.append({"code": program[start_idx:end_idx], "type": "stmt"}) 68 | idx = end_idx 69 | 70 | 71 | # NOTE: FIXME: This only works for straight-line programs since it does not consider indentation 72 | for stmt in stmt_states: 73 | if stmt["type"] == "stmt": 74 | stmt["code"] += "\nrecord_state(locals())" 75 | 76 | # now assemble the program back together 77 | traced_program = "".join([stmt["code"] for stmt in stmt_states]) 78 | 79 | # execute the program with tracing code 80 | result = execute(traced_program, {}, 81 | globals={"tracing_local_list": tracing_local_list, "deepcopy": deepcopy, 82 | "record_state": record_state, "ModuleType": ModuleType, 83 | "math": math, "scipy": scipy, "scipy.special": special}, use_tracing=True) 84 | 85 | if result["result"] == "passed": 86 | # add the *output* states for each statement and remove the tracing code to restore orginal program 87 | stmt_idx = 0 88 | for stmt in stmt_states: 89 | if stmt["type"] == "stmt": 90 | stmt["execution_state"] = result["tracing_local_list"][stmt_idx] 91 | stmt["code"] = stmt["code"].replace("\nrecord_state(locals())", "") 92 | stmt_idx += 1 93 | prog_trace = [ProgTraceUnit(stmt["code"], stmt["type"], 94 | stmt["execution_state"] if stmt["type"] == "stmt" else {}) for stmt in stmt_states] 95 | return prog_trace 96 | else: 97 | if debug: 98 | print(f'skipping example of error: {result["result"]}') 99 | print(f"##########\n{program}\n##########") 100 | return None 101 | 102 | def batch_program_tracing(programs: List[str], n_processes=20) -> List[Union[ProgTrace, None]]: 103 | with Pool(n_processes) as p: 104 | tracing_outputs = p.map(get_execution_states, programs) 105 | return list(tracing_outputs) 106 | 107 | def exec_stmt_in_context(stmt: str, context: Dict[str, Any]): 108 | # NOTE: FIXME: This only works for straight-line programs since it does not consider indentation 109 | traced_stmt = stmt + "\nrecord_state(locals())" 110 | 111 | # execute the program with tracing code 112 | if "math" in context: 113 | context["math"] = math 114 | if "scipy" in context: 115 | context["scipy"] = scipy 116 | context["scipy.special"] = special 117 | 118 | result = execute(traced_stmt, locals=context, 119 | globals={"tracing_local_list": tracing_local_list, "deepcopy": deepcopy, 120 | "record_state": record_state, "ModuleType": ModuleType}, use_tracing=True) 121 | 122 | if result["result"] == "passed": 123 | # return the execution states as the local var list 124 | assert len(result["tracing_local_list"]) == 1, f"tracing_local_list: 
{result['tracing_local_list']}" 125 | stmt_output_state = result["tracing_local_list"][0] 126 | return stmt_output_state 127 | else: 128 | return None 129 | 130 | def is_trivial_state(state_dict: Dict[str, Any], prev_stmt: str): 131 | if len(state_dict) == 0: 132 | return True 133 | 134 | assert prev_stmt is not None, "prev_stmt must be provided to determine trivial states unless the state is empty" 135 | 136 | if prev_stmt.split(" ")[0] in ["#", "import"]: 137 | return True 138 | 139 | assert len(state_dict) == 1, f"prev_stmt {prev_stmt}; original state dict {state_dict}" 140 | 141 | return f"{list(state_dict.keys())[0]} = {list(state_dict.values())[0]}" in prev_stmt 142 | 143 | def get_state_repr(state_dict: Dict[str, Any], prev_stmt: str = None, only_include_keys: List[str] = None, 144 | prev_state_dict: Dict[str, Any] = None, use_diff=False, skip_trivial_states: bool = False): 145 | if use_diff: 146 | raise NotImplementedError 147 | 148 | if only_include_keys is not None: 149 | state_dict = {k: v for k, v in state_dict.items() if k in only_include_keys} 150 | 151 | if skip_trivial_states and is_trivial_state(state_dict, prev_stmt): 152 | return "" 153 | 154 | repr = "# " 155 | for key, value in state_dict.items(): 156 | repr += f"{key} = {value}; " 157 | repr += "\n" 158 | 159 | return repr 160 | 161 | if __name__ == "__main__": 162 | # load some sample programs 163 | with open('data/mathqa/val-python.jsonl', 'r') as f: 164 | lines = f.readlines() 165 | 166 | json_examples = [json.loads(line) for line in lines] 167 | 168 | with open('data/mathqa/val_python_with_states.jsonl', 'w+') as f: 169 | success_count = 0 170 | failed_count = 0 171 | for json_example in tqdm(json_examples): 172 | 173 | program = json_example["code"] 174 | stmt_states = get_execution_states(program) 175 | 176 | if stmt_states is not None: 177 | json_example["states"] = stmt_states 178 | f.write(json.dumps(json_example) + "\n") 179 | success_count += 1 180 | else: 181 | failed_count += 1 182 | 183 | print(f"Successfully traced {success_count}/{success_count+failed_count} programs") -------------------------------------------------------------------------------- /execution/safe_execution_util.py: -------------------------------------------------------------------------------- 1 | ##### much of the code is from https://raw.githubusercontent.com/openai/human-eval/463c980b59e818ace59f6f9803cd92c749ceae61/human_eval/execution.py #### 2 | from tree_sitter import Language, Parser 3 | from typing import Optional, Callable, Dict, Any, List 4 | import ast 5 | import contextlib 6 | import faulthandler 7 | import io 8 | import os 9 | import multiprocessing 10 | import platform 11 | import signal 12 | import tempfile 13 | 14 | from copy import deepcopy 15 | from types import ModuleType 16 | def canonicalize_var_dict(var_dict): 17 | copied_var_dict = {} 18 | for key, value in var_dict.items(): 19 | if key in ["canonicalize_var_dict", "tracing_local_list", "deepcopy", "record_state", "ModuleType"]: 20 | continue 21 | 22 | if isinstance(value, ModuleType): 23 | assert key in ["math", "scipy"] 24 | copied_var_dict[key] = "module: " + str(value) 25 | elif str(type(value)) == "": 26 | copied_var_dict[key] = int(value) 27 | else: 28 | copied_var_dict[key] = deepcopy(value) 29 | return copied_var_dict 30 | 31 | def execute(code: str, 32 | locals: Dict[str, Any] = {}, 33 | globals: Dict[str, Any] = {}, 34 | task_id: Optional[str] = "tmp", 35 | sol_id: Optional[str] = "tmp", 36 | timeout: Optional[int] = 2, 37 | stdi_str: Optional[str] = 
None, 38 | use_tracing: bool = False) -> None: 39 | manager = multiprocessing.Manager() 40 | result = manager.dict() 41 | 42 | p = multiprocessing.Process(target=unsafe_execute, args=(code, globals, locals, timeout, result, stdi_str, use_tracing)) 43 | p.start() 44 | p.join(timeout=timeout + 1) 45 | if p.is_alive(): 46 | p.kill() 47 | 48 | if len(result) == 0: 49 | result["result"] = "timed out" 50 | 51 | result.update({"task_id": task_id, "sol_id": sol_id}) 52 | return dict(result) 53 | 54 | def unsafe_execute(check_program: str, 55 | globals_: Dict[str, Any], 56 | locals_: Dict[str, Any], 57 | timeout: float, 58 | result: Dict[str, Any], 59 | stdi_str: Optional[str], 60 | use_tracing: bool = False) -> None: 61 | 62 | with create_tempdir(): 63 | 64 | # These system calls are needed when cleaning up tempdir. 65 | import os 66 | import shutil 67 | rmtree = shutil.rmtree 68 | rmdir = os.rmdir 69 | chdir = os.chdir 70 | 71 | # Disable functionalities that can make destructive changes to the test. 72 | reliability_guard() 73 | 74 | try: 75 | with redirect_io(stdi_str) as ostream: 76 | with time_limit(timeout): 77 | # WARNING 78 | # This program exists to execute untrusted model-generated code. Although 79 | # it is highly unlikely that model-generated code will do something overtly 80 | # malicious in response to this test suite, model-generated code may act 81 | # destructively due to a lack of model capability or alignment. 82 | # Users are strongly encouraged to sandbox this evaluation suite so that it 83 | # does not perform destructive actions on their host or network. For more 84 | # information on how OpenAI sandboxes its code, see the accompanying paper. 85 | # Once you have read this disclaimer and taken appropriate precautions, 86 | # uncomment the following line and proceed at your own risk: 87 | exec(check_program, globals_, locals_) 88 | # result["globals"] = globals_ 89 | result["locals"] = canonicalize_var_dict(locals_) 90 | result["result"] = "passed" 91 | if use_tracing: 92 | result["tracing_local_list"] = globals_["tracing_local_list"] 93 | except TimeoutException: 94 | result["result"] = "timed out" 95 | except BaseException as e: 96 | result["result"] = f"failed: {e}" 97 | finally: 98 | result["ostream"] = ostream.getvalue() 99 | 100 | # Needed for cleaning up. 101 | shutil.rmtree = rmtree 102 | os.rmdir = rmdir 103 | os.chdir = chdir 104 | 105 | def check_correctness_human_eval(problem: Dict, completion: str, timeout: float, 106 | completion_id: Optional[int] = None) -> Dict: 107 | """ 108 | Evaluates the functional correctness of a completion by running the test 109 | suite provided in the problem. 110 | 111 | :param completion_id: an optional completion ID so we can match 112 | the results later even if execution finishes asynchronously. 113 | """ 114 | 115 | # Construct the check program and run it. 
116 | check_program = ( 117 | problem["prompt"] + completion + "\n" + 118 | problem["test"] + "\n" + 119 | f"check({problem['entry_point']})" 120 | ) 121 | 122 | return check_correctness(check_program, timeout, problem['task_id'], completion_id) 123 | 124 | def check_correctness(code: str, timeout: float, task_id: Optional[int], 125 | completion_id: Optional[int]) -> Dict: 126 | 127 | manager = multiprocessing.Manager() 128 | result = manager.list() 129 | 130 | p = multiprocessing.Process(target=unsafe_execute, args=(code, timeout, result)) 131 | p.start() 132 | p.join(timeout=timeout + 1) 133 | if p.is_alive(): 134 | p.kill() 135 | 136 | if not result: 137 | result.append("timed out") 138 | 139 | return dict( 140 | task_id=task_id, 141 | passed=result[0] == "passed", 142 | result=result[0], 143 | completion_id=completion_id, 144 | ) 145 | 146 | 147 | @contextlib.contextmanager 148 | def time_limit(seconds: float): 149 | def signal_handler(signum, frame): 150 | raise TimeoutException("Timed out!") 151 | signal.setitimer(signal.ITIMER_REAL, seconds) 152 | signal.signal(signal.SIGALRM, signal_handler) 153 | try: 154 | yield 155 | finally: 156 | signal.setitimer(signal.ITIMER_REAL, 0) 157 | 158 | 159 | @contextlib.contextmanager 160 | def swallow_io(): 161 | stream = WriteOnlyStringIO() 162 | with contextlib.redirect_stdout(stream): 163 | with contextlib.redirect_stderr(stream): 164 | with redirect_stdin(stream): 165 | yield 166 | 167 | @contextlib.contextmanager 168 | def redirect_io(stdi_str: Optional[str]): 169 | stream = io.StringIO() 170 | istream = io.StringIO(stdi_str) if stdi_str else WriteOnlyStringIO() 171 | with contextlib.redirect_stdout(stream): 172 | with contextlib.redirect_stderr(stream): 173 | with redirect_stdin(istream): 174 | yield stream 175 | 176 | @contextlib.contextmanager 177 | def create_tempdir(): 178 | with tempfile.TemporaryDirectory() as dirname: 179 | with chdir(dirname): 180 | yield dirname 181 | 182 | 183 | class TimeoutException(Exception): 184 | pass 185 | 186 | 187 | class WriteOnlyStringIO(io.StringIO): 188 | """ StringIO that throws an exception when it's read from """ 189 | 190 | def read(self, *args, **kwargs): 191 | raise IOError 192 | 193 | def readline(self, *args, **kwargs): 194 | raise IOError 195 | 196 | def readlines(self, *args, **kwargs): 197 | raise IOError 198 | 199 | def readable(self, *args, **kwargs): 200 | """ Returns True if the IO object can be read. """ 201 | return False 202 | 203 | 204 | class redirect_stdin(contextlib._RedirectStream): # type: ignore 205 | _stream = 'stdin' 206 | 207 | 208 | @contextlib.contextmanager 209 | def chdir(root): 210 | if root == ".": 211 | yield 212 | return 213 | cwd = os.getcwd() 214 | os.chdir(root) 215 | try: 216 | yield 217 | except BaseException as exc: 218 | raise exc 219 | finally: 220 | os.chdir(cwd) 221 | 222 | 223 | def reliability_guard(maximum_memory_bytes: Optional[int] = None): 224 | """ 225 | This disables various destructive functions and prevents the generated code 226 | from interfering with the test (e.g. fork bomb, killing other processes, 227 | removing filesystem files, etc.) 228 | 229 | WARNING 230 | This function is NOT a security sandbox. Untrusted code, including, model- 231 | generated code, should not be blindly executed outside of one. See the 232 | Codex paper for more information about OpenAI's code sandbox, and proceed 233 | with caution. 
234 | """ 235 | 236 | if maximum_memory_bytes is not None: 237 | import resource 238 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) 239 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) 240 | if not platform.uname().system == 'Darwin': 241 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) 242 | 243 | faulthandler.disable() 244 | 245 | import builtins 246 | builtins.exit = None 247 | builtins.quit = None 248 | 249 | import os 250 | os.environ['OMP_NUM_THREADS'] = '1' 251 | 252 | os.kill = None 253 | os.system = None 254 | os.putenv = None 255 | os.remove = None 256 | os.removedirs = None 257 | os.rmdir = None 258 | os.fchdir = None 259 | os.setuid = None 260 | os.fork = None 261 | os.forkpty = None 262 | os.killpg = None 263 | os.rename = None 264 | os.renames = None 265 | os.truncate = None 266 | os.replace = None 267 | os.unlink = None 268 | os.fchmod = None 269 | os.fchown = None 270 | os.chmod = None 271 | os.chown = None 272 | os.chroot = None 273 | os.fchdir = None 274 | os.lchflags = None 275 | os.lchmod = None 276 | os.lchown = None 277 | os.getcwd = None 278 | os.chdir = None 279 | 280 | import shutil 281 | shutil.rmtree = None 282 | shutil.move = None 283 | shutil.chown = None 284 | 285 | import subprocess 286 | subprocess.Popen = None # type: ignore 287 | 288 | __builtins__['help'] = None 289 | 290 | import sys 291 | sys.modules['ipdb'] = None 292 | sys.modules['joblib'] = None 293 | sys.modules['resource'] = None 294 | sys.modules['psutil'] = None 295 | sys.modules['tkinter'] = None 296 | 297 | 298 | if __name__ == '__main__': 299 | import sys 300 | istream = io.StringIO('hello\nworld\n') 301 | 302 | outputs = [] 303 | with redirect_io(istream) as ostream: 304 | code = """ 305 | a = input() 306 | print(f'a is {a}') 307 | b = input() 308 | print(f'b is {b}') 309 | """ 310 | 311 | exec(code, {}) 312 | 313 | outputs.append(ostream.getvalue()) 314 | 315 | print(outputs) -------------------------------------------------------------------------------- /lightning_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/lightning_modules/__init__.py -------------------------------------------------------------------------------- /lightning_modules/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/lightning_modules/callbacks/__init__.py -------------------------------------------------------------------------------- /lightning_modules/callbacks/save_prediction_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import pytorch_lightning as pl 5 | 6 | from typing import Any, Dict, Optional, List 7 | from pytorch_lightning.callbacks import Callback 8 | from pathlib import Path 9 | 10 | class SavePredictionCallback(Callback): 11 | def __init__(self): 12 | self.predictions = list() 13 | self.prediction_save_dir = None 14 | self.metrics_save_dir = None 15 | 16 | def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str]) -> None: 17 | if pl_module.global_rank == 0: 18 | # make dirs for predictions and metrics saving paths 19 | pred_save_dir = 
os.path.join(trainer.log_dir, 'predictions') 20 | metrics_save_dir = os.path.join(trainer.log_dir, 'metrics') 21 | 22 | Path(pred_save_dir).mkdir(parents=True, exist_ok=True) 23 | Path(metrics_save_dir).mkdir(parents=True, exist_ok=True) 24 | 25 | self.prediction_save_dir = pred_save_dir 26 | self.metrics_save_dir = metrics_save_dir 27 | 28 | def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", 29 | outputs: List[Dict[str, Any]], batch: Any, batch_idx: int, dataloader_idx: int) -> None: 30 | self.predictions.extend(outputs) 31 | 32 | def on_validation_epoch_end(self, trainer, pl_module) -> None: 33 | # save the predictions 34 | save_pred_file_path = os.path.join(self.prediction_save_dir, 35 | f'predictions_step_{trainer.global_step}_rank_{trainer.global_rank}.jsonl') 36 | with open(save_pred_file_path, 'w+') as f: 37 | for prediction in self.predictions: 38 | f.write(json.dumps(prediction)+'\n') 39 | print(f"{len(self.predictions)} predictions saved to {save_pred_file_path}") 40 | self.predictions = [] 41 | 42 | 43 | if pl_module.global_rank == 0: 44 | self.save_metrics(trainer, pl_module) 45 | 46 | pl_module._rouge_metric.reset() 47 | pl_module._bleu_metric.reset() 48 | pl_module._em_metric.reset() 49 | pl_module._stmt_length.reset() 50 | pl_module._cell_stmt_num.reset() 51 | pl_module._edit_distance.reset() 52 | 53 | def save_metrics(self, trainer, pl_module) -> None: 54 | metrics = {} 55 | 56 | rouge_dict = dict([(k, float(v)) for k, v in pl_module._rouge_metric.compute().items() if k.endswith('fmeasure')]) 57 | metrics.update(rouge_dict) 58 | metrics["bleu"] = float(pl_module._bleu_metric.compute()) 59 | metrics["cell_exact_match"] = float(pl_module._em_metric.compute()) 60 | metrics["output_stmt_len"] = float(pl_module._stmt_length.compute()) 61 | metrics["output_stmt_num"] = float(pl_module._cell_stmt_num.compute()) 62 | metrics["cell_edit_dist"] = float(pl_module._edit_distance.compute()) 63 | 64 | # save the evaluation metrics 65 | save_metrics_file_path = os.path.join(self.metrics_save_dir, f'metrics_step_{trainer.global_step}.json') 66 | with open(save_metrics_file_path, 'w+') as f: 67 | f.write(json.dumps(metrics, indent=4)) 68 | 69 | print(f"Eval metrics saved to {save_metrics_file_path}") -------------------------------------------------------------------------------- /lightning_modules/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/lightning_modules/datasets/__init__.py -------------------------------------------------------------------------------- /lightning_modules/datasets/mathqa_line_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import sys 4 | import os 5 | from overrides import overrides 6 | import torch 7 | 8 | from typing import Dict, Iterable, List, Any, Optional, Union 9 | 10 | from pytorch_lightning import LightningDataModule 11 | from torch.utils.data import Dataset 12 | 13 | from lightning_modules.models.gpt_util import get_gpt, left_pad_sequences 14 | from execution.program_tracing import get_state_repr, is_trivial_state 15 | 16 | from torch.utils.data import DataLoader 17 | 18 | # set environment variable to avoid deadlocks, see: 19 | # https://docs.allennlp.org/main/api/data/data_loaders/multiprocess_data_loader/#multiprocessdataloader.common_issues 20 | 
os.environ['TOKENIZERS_PARALLELISM']='0' 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | FEW_SHOT_RESERVED = 10 25 | 26 | 27 | class MathQADataset(Dataset): 28 | 29 | def __init__( 30 | self, 31 | file_path: str, 32 | transformer_model_name: str, 33 | max_instances: int, 34 | few_shot_n: int = 0, 35 | mode: str = "train", 36 | multi_example_instance: bool = False, 37 | **kwargs): 38 | super().__init__(**kwargs) 39 | 40 | # mode is one of ["train", "test", "test_few_shot"] 41 | assert mode in ["train", "test", "test_few_shot"] 42 | 43 | _, self.tokenizer = get_gpt(transformer_model_name, tokenizer_only=True) 44 | 45 | self.max_instances = max_instances 46 | self.mode = mode 47 | self.multi_example_instance = multi_example_instance 48 | 49 | assert few_shot_n <= FEW_SHOT_RESERVED, f"few_shot_n should be smaller than {FEW_SHOT_RESERVED}" 50 | self.few_shot_n = few_shot_n 51 | 52 | self.instances = self.read(file_path) 53 | 54 | def get_train_instance(self, example: Dict[str, Any]) -> Dict[str, Any]: 55 | example_dict = {"metadata": example} 56 | 57 | tokenizer_outputs = self.tokenizer("\n".join([example["text"], example["code"]])) 58 | 59 | 60 | example_dict["input_ids"] = tokenizer_outputs["input_ids"] + [self.tokenizer.eos_token_id] 61 | example_dict["attention_mask"] = tokenizer_outputs["attention_mask"] + [1] 62 | example_dict["metadata"]["pad_token_id"] = self.tokenizer.pad_token_id 63 | 64 | return example_dict 65 | 66 | def get_test_instance(self, example: Dict[str, Any]) -> Dict[str, Any]: 67 | example_dict = {"metadata": example} 68 | 69 | tokenizer_outputs = self.tokenizer(example["text"] + "\n") 70 | 71 | 72 | example_dict["input_ids"] = tokenizer_outputs["input_ids"] 73 | example_dict["attention_mask"] = tokenizer_outputs["attention_mask"] 74 | example_dict["metadata"]["pad_token_id"] = self.tokenizer.pad_token_id 75 | 76 | return example_dict 77 | 78 | def get_test_few_shot_instance(self, example: Dict[str, Any], 79 | few_shot_text_list: List[str], 80 | few_shot_code_list: List[str]) -> Dict[str, Any]: 81 | raise NotImplementedError("get_test_few_shot_instance is deprecated.") 82 | 83 | def read(self, file_path: str) -> Iterable[Dict[str, Any]]: 84 | print("Reading dataset files at %s", file_path) 85 | 86 | all_yield_instances = [] 87 | 88 | # load the mathqa dataset with states 89 | mathqa_json_examples = [] 90 | with open(file_path, 'r') as f: 91 | if self.mode == "test_few_shot": 92 | lines = f.readlines()[:self.max_instances+FEW_SHOT_RESERVED] 93 | else: 94 | lines = f.readlines()[:self.max_instances] 95 | for line in lines: 96 | mathqa_json_examples.append(json.loads(line)) 97 | 98 | if self.mode == "test_few_shot": 99 | # holdout for few-shot prompting 100 | few_shot_examples = mathqa_json_examples[:self.few_shot_n] 101 | few_shot_text_list = [example['text'] for example in few_shot_examples] 102 | few_shot_code_list = [example['code'] for example in few_shot_examples] 103 | 104 | mathqa_json_examples = mathqa_json_examples[FEW_SHOT_RESERVED:] 105 | 106 | for exp in mathqa_json_examples: 107 | if self.mode == "train": 108 | example_dict = self.get_train_instance(exp) 109 | elif self.mode == "test": 110 | example_dict = self.get_test_instance(exp) 111 | elif self.mode == "test_few_shot": 112 | example_dict = self.get_test_few_shot_instance(exp, few_shot_text_list, few_shot_code_list) 113 | else: 114 | raise ValueError(f"Unknown mode: {self.mode}") 115 | 116 | all_yield_instances.append(example_dict) 117 | 118 | logger.info(f"loaded {len(all_yield_instances)} 
instances") 119 | 120 | return all_yield_instances 121 | 122 | def __getitem__(self, idx: int): 123 | return self.instances[idx] 124 | 125 | def __len__(self): 126 | return len(self.instances) 127 | 128 | def truncate(self, max_instances): 129 | truncated_instances = self.instances[max_instances:] 130 | self.instances = self.instances[:max_instances] 131 | return truncated_instances 132 | 133 | def extend(self, instances): 134 | self.instances.extend(instances) 135 | 136 | class MathQAMmlDataset(MathQADataset): 137 | 138 | def __init__( 139 | self, 140 | file_path: str, 141 | transformer_model_name: str, 142 | max_instances: int, 143 | **kwargs): 144 | super().__init__(file_path=file_path, transformer_model_name=transformer_model_name, 145 | max_instances=max_instances, **kwargs) 146 | 147 | def get_train_instance(self, example: Dict[str, Any]) -> Dict[str, Any]: 148 | example_dict = {"metadata": example} 149 | 150 | tokenizer_outputs = self.tokenizer(example["text"] + "\n") 151 | 152 | example_dict["input_ids"] = tokenizer_outputs["input_ids"] 153 | example_dict["attention_mask"] = tokenizer_outputs["attention_mask"] 154 | example_dict["metadata"]["pad_token_id"] = self.tokenizer.pad_token_id 155 | 156 | return example_dict 157 | 158 | 159 | def customized_collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, Any]: 160 | result_dict = {} 161 | 162 | pad_token_id = examples[0]["metadata"]["pad_token_id"] 163 | 164 | for k in examples[0].keys(): 165 | if k == "metadata": 166 | result_dict[k] = [ex[k] for ex in examples] 167 | elif k == "input_ids": 168 | result_dict[k] = left_pad_sequences([torch.tensor(ex[k]) for ex in examples], 169 | batch_first=True, padding_value=pad_token_id) 170 | elif k == "attention_mask": 171 | result_dict[k] = left_pad_sequences([torch.tensor(ex[k]) for ex in examples], 172 | batch_first=True, padding_value=0) 173 | elif k == "state_mask": 174 | result_dict[k] = left_pad_sequences([torch.tensor(ex[k]) for ex in examples], 175 | batch_first=True, padding_value=0) 176 | elif k == "labels": 177 | result_dict[k] = left_pad_sequences([torch.tensor(ex[k]) for ex in examples], 178 | batch_first=True, padding_value=pad_token_id) 179 | else: 180 | raise ValueError(f"Unknown key {k} in example instance") 181 | 182 | return result_dict 183 | 184 | class MathQADataModule(LightningDataModule): 185 | def __init__(self, 186 | transformer_model_name: str, 187 | batch_size: int = 1, 188 | val_batch_size: int = 1, 189 | few_shot_n: int = 0, 190 | train_file_path: str = None, 191 | val_file_path: str = None, 192 | test_file_path: str = None, 193 | train_max_instances: int = sys.maxsize, 194 | val_max_instances: int = sys.maxsize): 195 | super().__init__() 196 | self.transformer_model_name = transformer_model_name 197 | self.batch_size = batch_size 198 | self.val_batch_size = val_batch_size 199 | self.few_shot_n = few_shot_n 200 | 201 | self.train_file_path = train_file_path 202 | self.val_file_path = val_file_path 203 | self.test_file_path = test_file_path 204 | 205 | self.train_max_instances = train_max_instances 206 | self.val_max_instances = val_max_instances 207 | 208 | self.train_data = None 209 | self.val_data = None 210 | 211 | def assign_data(self): 212 | train_data = MathQADataset(file_path=self.train_file_path, 213 | transformer_model_name=self.transformer_model_name, 214 | max_instances=self.train_max_instances, 215 | mode="train", few_shot_n=self.few_shot_n) 216 | self.train_data = train_data 217 | 218 | val_data = MathQADataset(file_path=self.val_file_path, 219 | 
transformer_model_name=self.transformer_model_name, 220 | max_instances=self.val_max_instances, 221 | mode="test", few_shot_n=self.few_shot_n) 222 | self.val_data = val_data 223 | print(f"assigning data is called!") 224 | 225 | # OPTIONAL, called for every GPU/machine (assigning state is OK) 226 | def setup(self, stage: Optional[str] = None): 227 | assert stage in ["fit", "validate", "test"] 228 | 229 | self.assign_data() 230 | 231 | def train_dataloader(self): 232 | if self.train_data is None: 233 | self.assign_data() 234 | 235 | dtloader = DataLoader(self.train_data, batch_size=self.batch_size, 236 | shuffle=True, drop_last=True, collate_fn=customized_collate_fn) 237 | return dtloader 238 | 239 | def val_dataloader(self): 240 | if self.val_data is None: 241 | self.assign_data() 242 | 243 | dtloader = DataLoader(self.val_data, batch_size=self.val_batch_size, 244 | shuffle=False, drop_last=True, collate_fn=customized_collate_fn) 245 | return dtloader 246 | 247 | def test_dataloader(self): 248 | raise NotImplementedError 249 | 250 | class MathQAMmlDataModule(MathQADataModule): 251 | def __init__(self, transformer_model_name: str, **kwargs): 252 | super().__init__(transformer_model_name=transformer_model_name, **kwargs) 253 | 254 | @overrides 255 | def assign_data(self): 256 | train_data = MathQAMmlDataset(file_path=self.train_file_path, 257 | transformer_model_name=self.transformer_model_name, 258 | max_instances=self.train_max_instances, 259 | mode="train", few_shot_n=self.few_shot_n) 260 | self.train_data = train_data 261 | 262 | val_data = MathQAMmlDataset(file_path=self.val_file_path, 263 | transformer_model_name=self.transformer_model_name, 264 | max_instances=self.val_max_instances, 265 | mode="test", few_shot_n=self.few_shot_n) 266 | self.val_data = val_data 267 | print(f"assigning data in MML data module is called!") 268 | -------------------------------------------------------------------------------- /lightning_modules/datasets/reader_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | 5 | from functools import reduce 6 | from math import gcd 7 | from typing import List, Any, Union, Dict, Tuple 8 | 9 | from tree_sitter import Parser, Language 10 | 11 | # TODO: add those tokens to the vocab of whatever transformer models 12 | INDENT_TOKEN = "@@INDENT@@" 13 | NEWLINE_TOKEN = "@@NEWLINE@@" 14 | END_OF_CELL_TOKEN = "<|endofcell|>" 15 | COM_STMTS = ['if_statement', 'for_statement', 'while_statement', 'try_statement', 'with_statement', 16 | 'function_definition', 'class_definition'] 17 | PY_MODULES = ['module', 'block', 'decorated_definition'] 18 | 19 | logger = logging.getLogger('reader') 20 | 21 | def get_indent_level(line: str) -> int: 22 | """return the indent level of the line, measuring the leading tabs or spaces""" 23 | code_start_idx = line.find(line.lstrip()) 24 | 25 | if code_start_idx == 0: 26 | return 0 27 | 28 | indent_str = line[:code_start_idx] 29 | if ' ' in indent_str and '\t' in indent_str: 30 | logger.info('indentation string contains both spaces and tabs') 31 | 32 | if indent_str.replace(' ', '').replace('\t', '') != '': 33 | logger.info('indentation string is not all spaces or tabs') 34 | 35 | return len(indent_str) 36 | 37 | 38 | def preprocess_cell(cell: Dict[str, Any], line_offset: int, add_special_token_for_code: bool=False) -> Tuple[List[str], Dict]: 39 | """preprocess the cell""" 40 | 41 | cell_type = cell['type'] 42 | cell_lines = cell['lines'] 43 | 44 | if not 
add_special_token_for_code: 45 | # special NEWLINE token will be added, so no need to add it here 46 | cell_lines = [line+'\n' for line in cell_lines] 47 | 48 | if 'stmts' in cell: 49 | stmts = cell['stmts'] 50 | else: 51 | stmts = [] 52 | 53 | for stmt in stmts: 54 | start, end = stmt['start_point'], stmt['end_point'] 55 | stmt['start_point'] = (start[0]+line_offset, start[1]) 56 | stmt['end_point'] = (end[0]+line_offset, end[1]) 57 | 58 | if cell_type == 'markdown': 59 | commented_lines = [f"# {line}" for line in cell_lines] 60 | return commented_lines, stmts 61 | 62 | if cell_type == 'code' and add_special_token_for_code: 63 | # first figure out the indent levels 64 | indent_levels = [get_indent_level(line) for line in cell_lines] 65 | 66 | # there are cases where the indentations are actually line breakers for long line, 67 | # use 20 as a threshold to distinguish between actual indent and line breakers 68 | indent_levels = [x if (x > 1 and x % 2 == 0 and x <= 20) else 0 for x in indent_levels] 69 | 70 | gcd_indent_level = reduce(gcd, indent_levels) 71 | if gcd_indent_level not in [0, 2, 4, 8]: 72 | # logger.info(f'indentation level is not a power of 2 but {gcd_indent_level}, setting all indentations to 0') 73 | indent_levels = [0] * len(indent_levels) 74 | if gcd_indent_level != 0: 75 | indent_levels = [i // gcd_indent_level for i in indent_levels] 76 | 77 | # add indentation and newline tokens 78 | lines = [' '.join([INDENT_TOKEN]*i + [line[i*gcd_indent_level:]] + [NEWLINE_TOKEN]) for i, line in zip(indent_levels, cell_lines)] 79 | return lines, stmts 80 | 81 | return cell_lines, stmts 82 | 83 | def construct_from_lines(all_lines: List[str], start: Tuple[int, int], 84 | end: Tuple[int, int], is_byte_idx: bool=False): 85 | if is_byte_idx: 86 | start = (start[0], byte_idx_to_char_idx(start[1], all_lines[start[0]])) 87 | end = (end[0], byte_idx_to_char_idx(end[1], all_lines[end[0]])) 88 | 89 | # construct back the statement string 90 | statement_str = '' 91 | for i in range(start[0], end[0] + 1): 92 | if i == start[0]: 93 | if i == end[0]: # same line statement 94 | statement_str += (all_lines[i][start[1]:end[1]]) 95 | else: 96 | statement_str += (all_lines[i][start[1]:]) 97 | elif i == end[0]: 98 | statement_str += (all_lines[i][:end[1]]) 99 | else: 100 | statement_str += all_lines[i] 101 | 102 | return statement_str 103 | 104 | def byte_idx_to_char_idx(byte_idx: int, line: str) -> int: 105 | """convert byte index to char index""" 106 | return len(bytes(line, 'utf-8')[:byte_idx].decode('utf-8')) 107 | 108 | def last_char_idx_to_token_idx(char_idx: int, tokens: List[str]) -> int: 109 | """ find the token that ends with the given char index """ 110 | # calculate the token end indices 111 | token_end_indices = [] 112 | total_len = 0 113 | for token in tokens: 114 | total_len += len(token) 115 | token_end_indices.append(total_len) 116 | 117 | if char_idx+1 in token_end_indices: 118 | return token_end_indices.index(char_idx+1) 119 | else: 120 | # here is a special case, when the last char of a stmt is not the ending 121 | # char of a token. 
An example is `a = f(x);`, while ')' is the last char 122 | # of the stmt, BPE gives ');' as a token, thus we will have to add the whole token 123 | return token_end_indices.index(char_idx+2) 124 | 125 | def last_byte_idx_to_token_idx(byte_idx: int, tokens: List[str], tokenizer) -> int: 126 | """ find the token that ends with the given byte index """ 127 | # calculate the token end indices 128 | token_end_indices = [] 129 | total_len = 0 130 | for token in tokens: 131 | total_len += len(bytearray([tokenizer.byte_decoder[c] for c in token])) 132 | token_end_indices.append(total_len) 133 | 134 | if byte_idx+1 in token_end_indices: 135 | return token_end_indices.index(byte_idx+1) 136 | else: 137 | # here is a special case, when the last byte of a stmt is not the ending 138 | # char of a token. An example is `a = f(x);`, while ')' is the last char 139 | # of the stmt, BPE gives ');' as a token, thus we will have to add the whole token 140 | # NOTE: the example above is actually the only one we observe 141 | return token_end_indices.index(byte_idx+2) 142 | 143 | def get_statements_from_code(code: str, parser, tolerate_errors: bool=False) -> List[Dict[str, Any]]: 144 | parsed_tree = parser.parse(bytes(code, 'utf-8')) 145 | 146 | # do a dfs on the parsed tree to record all the simple statements 147 | target_stmts: List[Dict] = [] 148 | node_stack = [parsed_tree.root_node] 149 | while len(node_stack) > 0: 150 | node = node_stack.pop() 151 | 152 | if (node.type.endswith('statement') or node.type in ['comment', 'decorator']) \ 153 | and node.type not in COM_STMTS: 154 | # this is a simple statement or a comment, so we can add it to the list 155 | target_stmts.append({'type': node.type, 'start_point': node.start_point, 156 | 'end_point': node.end_point, 'start_byte': node.start_byte, 157 | 'end_byte': node.end_byte}) 158 | elif node.type in COM_STMTS or node.type.endswith('clause'): 159 | # separate the header and the body by the ":" token 160 | children_types = [c.type for c in node.children] 161 | separator_idx = children_types.index(':') 162 | assert separator_idx != -1 163 | 164 | # start of the header is the starter of the complex stmt, end is the end of the ":" token 165 | target_stmts.append({'type': node.type+'_header', 'start_point': node.start_point, 166 | 'start_byte': node.children[separator_idx].start_byte, 167 | 'end_point': node.children[separator_idx].end_point, 168 | 'end_byte': node.children[separator_idx].end_byte}) 169 | node_stack.extend(node.children[separator_idx+1:][::-1]) 170 | elif node.type in PY_MODULES: 171 | node_stack.extend(node.children[::-1]) 172 | elif node.type == 'ERROR': 173 | # err_code_line = code[:byte_idx_to_char_idx(node.end_byte, code)].split('\n')[-1] 174 | # print(f"failed to parse code: #########\n{err_code_line}\n#########") 175 | if tolerate_errors: 176 | continue 177 | else: 178 | # failed to parse tree, return None NOTE: previously returning [], but this will get 179 | # confused with blank cells 180 | return None 181 | else: 182 | # other types, not sure what it contains, but assume it doesn't contain more statements 183 | print(f'unexpected node type: {node.type}') 184 | assert 'statement' not in node.sexp() 185 | 186 | return target_stmts 187 | 188 | 189 | if __name__ == '__main__': 190 | language_build_path = os.path.join(os.path.dirname(__file__), '../../preprocessing/py-tree-sitter.so') 191 | PY_LANGUAGE = Language(language_build_path, 'python') 192 | parser = Parser() 193 | parser.set_language(PY_LANGUAGE) 194 | 195 | code = "#a simple 
test\nif a == 0:\n if b == 0:\n b=1\nelif a>1:\n a*=2\nelse:\n a/=2" 196 | print(code) 197 | 198 | stmts = get_statements_from_code(code, parser) 199 | 200 | start = 0 201 | stmt_str = [] 202 | for stmt in stmts: 203 | stmt_str.append(code[byte_idx_to_char_idx(start, code): byte_idx_to_char_idx(stmt['end_byte'], code)]) 204 | start = stmt['end_byte'] 205 | 206 | 207 | tree = parser.parse(bytes(code, 'utf-8')) 208 | 209 | root = tree.root_node 210 | 211 | print("") 212 | 213 | -------------------------------------------------------------------------------- /lightning_modules/loggers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/lightning_modules/loggers/__init__.py -------------------------------------------------------------------------------- /lightning_modules/loggers/patched_loggers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import neptune 3 | 4 | from typing import Optional, Union, List 5 | from pytorch_lightning.loggers import NeptuneLogger, CSVLogger, TensorBoardLogger, WandbLogger 6 | from pytorch_lightning.utilities import rank_zero_only 7 | 8 | class PatchedWandbLogger(WandbLogger): 9 | def __init__(self, entity: str, project: str, name: str, log_model: bool, save_code: bool, 10 | tags: List[str] = None, *args, **kwargs): 11 | 12 | kwargs['entity'] = entity 13 | kwargs['save_code'] = save_code 14 | 15 | # try load exp name from env var (for using batch jobs) 16 | exp_name_var = os.getenv('EXP_NAME') 17 | if exp_name_var is not None: 18 | processed_name = exp_name_var 19 | else: 20 | # remove the preceeding folder name 21 | processed_name = name.split('/')[-1] 22 | 23 | # automatically get the tags from the name 24 | if tags is None: 25 | kwargs['tags'] = processed_name.split('-') 26 | else: 27 | kwargs['tags'] = tags 28 | 29 | super().__init__(name=processed_name, project=project, log_model=log_model, *args, **kwargs) 30 | 31 | @rank_zero_only 32 | def log_code(self): 33 | # log the yaml and py files 34 | root = "." 35 | print(f"saving all files in {os.path.abspath(root)}") 36 | result = self.experiment.log_code(root=root, 37 | include_fn=(lambda path: path.endswith(".py") or \ 38 | path.endswith(".yaml") or \ 39 | path.endswith(".sh")), 40 | exclude_fn=(lambda path: ".venv" in path or \ 41 | "debug-tmp" in path)) 42 | if result is not None: 43 | print("########################################") 44 | print("######## Logged code to wandb. 
#########") 45 | print("########################################") 46 | else: 47 | print("######## logger inited but not successfully saved #########") 48 | 49 | class PatchedNeptuneLogger(NeptuneLogger): 50 | def __init__(self, project_name: str, *args, **kwargs): 51 | api_key = os.getenv('NEPTUNE_API_KEY') 52 | if api_key is None: 53 | raise ValueError("Please provide an API key for the neptune logger in the env vars.") 54 | # exp_name = os.getenv('PL_LOG_DIR').split('/')[0] 55 | # exp_name = os.getenv('AMLT_JOB_NAME') 56 | 57 | kwargs['api_key'] = api_key 58 | # kwargs['experiment_id'] = exp_name 59 | kwargs['project'] = project_name 60 | kwargs['source_files'] = ['**/*.py', '**/*.yaml', '**/*.sh'] 61 | 62 | super().__init__(*args, **kwargs) 63 | 64 | class PatchedCSVLogger(CSVLogger): 65 | def __init__(self, 66 | name: Optional[str] = "default", 67 | version: Optional[Union[int, str]] = None, 68 | prefix: str = "", 69 | ): 70 | save_dir = os.getenv('PL_LOG_DIR') 71 | super().__init__(save_dir, name, version, prefix) 72 | 73 | class PatchedTensorBoardLogger(TensorBoardLogger): 74 | def __init__(self, 75 | name: Optional[str] = "default", 76 | version: Optional[Union[int, str]] = None, 77 | log_graph: bool = False, 78 | default_hp_metric: bool = True, 79 | prefix: str = "", 80 | sub_dir: Optional[str] = None, 81 | ): 82 | save_dir = os.getenv('PL_LOG_DIR') 83 | super().__init__(save_dir, name, version, log_graph, default_hp_metric, prefix, sub_dir) -------------------------------------------------------------------------------- /lightning_modules/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/lightning_modules/models/__init__.py -------------------------------------------------------------------------------- /lightning_modules/models/gpt_seq2seq_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | import math 5 | import torch.nn.functional as F 6 | import pytorch_lightning as pl 7 | import io, tokenize, re 8 | import ast, astunparse 9 | import numpy as np 10 | 11 | from types import ModuleType 12 | from typing import Optional, Dict, Any, Tuple, List 13 | from transformers.optimization import AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup 14 | from transformers.optimization import get_cosine_schedule_with_warmup 15 | 16 | 17 | from torchmetrics import Metric, MeanMetric, MetricCollection 18 | from pytorch_lightning import LightningModule 19 | 20 | from .gpt_util import get_gpt 21 | from execution.execution_evaluation import execution_acc, mathqa_execution 22 | from execution.execution_evaluation import execution_eval_at_k, batch_execution_acc 23 | 24 | def get_overlap_example_ids(set_name: str, min_allow_dist: int) -> List[int]: 25 | assert set_name in ['train', 'test', 'val'] 26 | 27 | save_path = f"analysis/train_{set_name}_overlap.npy" 28 | sim_matrix = np.load(save_path) 29 | 30 | # exclude itself in computing the most similar example when using the same set 31 | if set_name == 'train': 32 | np.fill_diagonal(sim_matrix, np.inf) 33 | 34 | min_list = np.min(sim_matrix, axis=0) 35 | 36 | overlapping_ids = [] 37 | for i, min_dist in enumerate(min_list): 38 | if min_dist > min_allow_dist: 39 | continue 40 | else: 41 | overlapping_ids.append(i) 42 | 43 | return overlapping_ids 44 | 45 | # from 
https://stackoverflow.com/questions/1769332/script-to-remove-python-comments-docstrings 46 | def remove_comments_and_docstrings(source): 47 | io_obj = io.StringIO(source) 48 | out = "" 49 | prev_toktype = tokenize.INDENT 50 | last_lineno = -1 51 | last_col = 0 52 | for tok in tokenize.generate_tokens(io_obj.readline): 53 | token_type = tok[0] 54 | token_string = tok[1] 55 | start_line, start_col = tok[2] 56 | end_line, end_col = tok[3] 57 | ltext = tok[4] 58 | if start_line > last_lineno: 59 | last_col = 0 60 | if start_col > last_col: 61 | out += (" " * (start_col - last_col)) 62 | if token_type == tokenize.COMMENT: 63 | pass 64 | elif token_type == tokenize.STRING: 65 | if prev_toktype != tokenize.INDENT: 66 | if prev_toktype != tokenize.NEWLINE: 67 | if start_col > 0: 68 | out += token_string 69 | else: 70 | out += token_string 71 | prev_toktype = token_type 72 | last_col = end_col 73 | last_lineno = end_line 74 | out = '\n'.join(l for l in out.splitlines() if l.strip()) 75 | return out 76 | 77 | def post_process_code(code, remove_comments=True, remove_extra_lines=False, ast_back_parse=True): 78 | """ a series of post-processing steps to clean up the code and avoid duplicated code """ 79 | 80 | if remove_comments: 81 | code = remove_comments_and_docstrings(code) 82 | 83 | if ast_back_parse: 84 | code = astunparse.unparse(ast.parse(code)) 85 | 86 | if remove_extra_lines: 87 | # remove the code after "answer" is generated 88 | result = [] 89 | for line in code.split("\n"): 90 | result.append(line) 91 | if line.startswith("answer"): 92 | break 93 | code = "\n".join(result) 94 | 95 | code = code.strip() 96 | 97 | return code 98 | 99 | # From the Codex Paper 100 | def estimate_pass_at_k(n, c, k): 101 | """ 102 | :param n: total number of samples 103 | :param c: number of correct samples 104 | :param k: k in pass@$k$ 105 | """ 106 | if n - c < k: 107 | return 1.0 108 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 109 | 110 | class GptSeq2SeqModel(LightningModule): 111 | def __init__(self, 112 | transformer_model_name: str, 113 | max_gen_len: int = 100, 114 | sampling_temp: float = 0.2, 115 | sampling_temp_at_k: float = 0.8, 116 | gradient_ckpt: bool = False, 117 | pass_at_k: int = 1, 118 | additional_pass_at_k: List[int] = [], 119 | eval_pass_at_k_every_n_epochs: int = 1, 120 | always_eval_pass_at_k_first_n_epochs: int = -1, 121 | max_generation_batches: int = 100, 122 | max_steps: int = -1, 123 | warmup_steps: int = 0, 124 | eval_greedy_search: bool = False, 125 | measure_dedup_metrics: bool = False, 126 | optimizer: Dict[str, Any] = None, 127 | lr_scheduler: Dict[str, Any] = None, 128 | load_ckpt_file: str = None) -> None: 129 | super().__init__() 130 | 131 | self.max_gen_len = max_gen_len 132 | self.sampling_temp = sampling_temp 133 | self.sampling_temp_at_k = sampling_temp_at_k 134 | self.max_steps = max_steps 135 | self.warmup_steps = warmup_steps 136 | 137 | self.pass_at_k = pass_at_k 138 | self.additional_pass_at_k = additional_pass_at_k 139 | self.eval_pass_at_k_every_n_epochs = eval_pass_at_k_every_n_epochs 140 | self.always_eval_pass_at_k_first_n_epochs = always_eval_pass_at_k_first_n_epochs 141 | self.max_generation_batches = max_generation_batches 142 | self.eval_greedy_search = eval_greedy_search 143 | self.measure_dedup_metrics = measure_dedup_metrics 144 | 145 | # We only instantiate this when we need it. 
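        # get_gpt is expected to return a (model, tokenizer) pair for the named checkpoint (the dataset
        # readers call the same helper with tokenizer_only=True); gradient_ckpt presumably switches on
        # gradient checkpointing to trade extra compute for lower activation memory during fine-tuning.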
146 | self.gpt, self.tokenizer = get_gpt(transformer_model_name, gradient_ckpt=gradient_ckpt) 147 | 148 | # save the prediction results for every valiation epoch 149 | self.predictions: List[Dict[str, Any]] = [] 150 | 151 | # the optimizer and lr scheduler settings 152 | self.opt_params = optimizer["init_args"] 153 | self.lrs_params = lr_scheduler 154 | assert self.lrs_params["name"] in ["linear", "cosine", "constant"], "lr_scheduler must be one of 'linear', 'cosine', 'constant'" 155 | 156 | # keep track of the number of validation epochs for pass at k 157 | self.eval_epoch = 0 158 | 159 | # load the state dict from the checkpoint file 160 | if load_ckpt_file is not None: 161 | checkpoint = torch.load(load_ckpt_file, map_location=torch.device("cpu")) 162 | self.load_state_dict(checkpoint["state_dict"], strict=False) 163 | print(f"loaded weights from {load_ckpt_file}") 164 | 165 | self.metrics_dict: Dict[str, Metric] = MetricCollection({}) 166 | 167 | self.metrics_dict["exec_acc"] = MeanMetric() 168 | self.metrics_dict["exec_rate"] = MeanMetric() 169 | self.metrics_dict["program_len_diff"] = MeanMetric() 170 | self.metrics_dict["unique_pct_in_k"] = MeanMetric() 171 | 172 | assert len(self.additional_pass_at_k) == 0 or self.pass_at_k > max(self.additional_pass_at_k), \ 173 | f"pass_at_k ({self.pass_at_k}) must be greater than all additional_pass_at_k ({self.additional_pass_at_k})" 174 | if self.pass_at_k > 1: 175 | self.metrics_dict[f"acc@{self.pass_at_k}"]= MeanMetric() 176 | self.metrics_dict[f"pass@{self.pass_at_k}"]= MeanMetric() 177 | 178 | for additional_k in self.additional_pass_at_k: 179 | self.metrics_dict[f"pass@{additional_k}"]= MeanMetric() 180 | 181 | if self.eval_greedy_search: 182 | self.metrics_dict["greedy_exec_acc"]= MeanMetric() 183 | self.metrics_dict["greedy_exec_rate"]= MeanMetric() 184 | 185 | if self.measure_dedup_metrics: 186 | # evaluation without the overlap 187 | self.val_overlap_ids: Dict[str, List[int]] = dict() 188 | self.val_overlap_ids["2"] = get_overlap_example_ids("val", 2) 189 | self.val_overlap_ids["4"] = get_overlap_example_ids("val", 4) 190 | self.val_overlap_ids["8"] = get_overlap_example_ids("val", 8) 191 | 192 | self.metrics_dict["dedup_exec_acc_2"]= MeanMetric() 193 | self.metrics_dict["dedup_exec_acc_4"]= MeanMetric() 194 | self.metrics_dict["dedup_exec_acc_8"]= MeanMetric() 195 | 196 | 197 | def forward( # type: ignore 198 | self, 199 | input_ids: torch.Tensor, 200 | attention_mask: torch.Tensor, 201 | metadata: Optional[List[Dict[str, Any]]] = None, 202 | ) -> List[Dict[str, Any]]: 203 | """ 204 | The inference time behavior of the model. 205 | 206 | Args: 207 | input_ids [torch.Tensor]: Tokens from the context. 208 | metadata (Optional[List[Dict[str, Any]]], optional): All additional information, `List` for the batch. Defaults to None. 209 | 210 | Returns: 211 | Dict[str, Any]: results saved in a `Dict` object. 
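        Note: the generation below samples up to `max_gen_len` new tokens at temperature `sampling_temp`
        and truncates each decoded string at the first EOS token. When `eval_greedy_search` is set, the
        same inputs are additionally decoded greedily; on pass@k evaluation epochs, `pass_at_k` samples
        per example are drawn at `sampling_temp_at_k`, in chunks of at most `max_generation_batches`
        return sequences.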
212 | """ 213 | 214 | generated_token_ids = self.gpt.generate(input_ids=input_ids, attention_mask=attention_mask, do_sample=True, 215 | max_length=input_ids.shape[1]+self.max_gen_len, 216 | temperature=self.sampling_temp) 217 | 218 | generated_token_ids = generated_token_ids[:, input_ids.shape[1]:] 219 | 220 | generated_strs = self.tokenizer.batch_decode(generated_token_ids) 221 | 222 | # truncate after the first '#' to be consistent with the codex prompting experiments 223 | generated_programs = [s.split(self.tokenizer.eos_token)[0] for s in generated_strs] 224 | 225 | output_dicts = [{"generated_program": generated_programs[i], "metadata": metadata[i]} \ 226 | for i in range(len(generated_programs))] 227 | 228 | if self.eval_greedy_search: 229 | generated_token_ids = self.gpt.generate(input_ids=input_ids, attention_mask=attention_mask, do_sample=False, 230 | max_length=input_ids.shape[1]+self.max_gen_len) 231 | generated_token_ids = generated_token_ids[:, input_ids.shape[1]:] 232 | generated_strs = self.tokenizer.batch_decode(generated_token_ids) 233 | # truncate after the first '#' to be consistent with the codex prompting experiments 234 | generated_programs = [s.split(self.tokenizer.eos_token)[0] for s in generated_strs] 235 | 236 | for i in range(len(metadata)): 237 | output_dicts[i]["greedy_generated_program"] = generated_programs[i] 238 | 239 | # evaluate pass at k FIXME: a lot of overlapping code here 240 | if (self.eval_epoch % self.eval_pass_at_k_every_n_epochs == 0 \ 241 | or self.eval_epoch < self.always_eval_pass_at_k_first_n_epochs) and self.pass_at_k > 1: 242 | generated_strs_list = [[] for _ in range(len(metadata))] 243 | remaining_k = self.pass_at_k 244 | while remaining_k > 0: 245 | generate_batch_size = min(remaining_k, self.max_generation_batches) 246 | remaining_k -= generate_batch_size 247 | batch_generated_token_ids = self.gpt.generate(input_ids=input_ids, attention_mask=attention_mask, 248 | do_sample=True, 249 | max_length=input_ids.shape[1]+self.max_gen_len, 250 | temperature=self.sampling_temp_at_k, 251 | num_return_sequences=generate_batch_size) 252 | 253 | batch_generated_token_ids = batch_generated_token_ids[:, input_ids.shape[1]:] 254 | batch_generated_strs = self.tokenizer.batch_decode(batch_generated_token_ids) 255 | batch_generated_programs = [s.split(self.tokenizer.eos_token)[0] for s in batch_generated_strs] 256 | 257 | for i in range(len(metadata)): 258 | generated_strs_list[i].extend(batch_generated_programs[i*generate_batch_size:(i+1)*generate_batch_size]) 259 | 260 | for i in range(len(metadata)): 261 | output_dicts[i]["generated_k_programs"] = generated_strs_list[i] 262 | 263 | 264 | return output_dicts 265 | 266 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]: 267 | input_ids = batch["input_ids"] 268 | attention_mask = batch["attention_mask"] 269 | labels = batch["labels"] if "labels" in batch else input_ids 270 | 271 | gpt_result = self.gpt(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 272 | 273 | self.log("loss", gpt_result.loss, on_step=True, on_epoch=True) 274 | return {"loss": gpt_result.loss} 275 | 276 | def validation_step(self, batch: torch.Tensor, batch_idx: int) -> Dict[str, torch.Tensor]: 277 | # input_tokens, target_mask, context_tokens, target_tokens, metadata = batch 278 | return self.forward(batch["input_ids"], batch["attention_mask"], batch["metadata"]) 279 | 280 | def validation_step_end(self, outputs: List[Dict[str, Any]]) -> None: 281 | # update the 
evaluation metrics 282 | for output_dict in outputs: 283 | exec_acc = execution_acc(output_dict["generated_program"], mathqa_execution, output_dict["metadata"]["answer"]) 284 | program_len = len(list(filter(lambda x: not x.startswith("#") and not len(x.strip()) == 0, 285 | output_dict["generated_program"].split("\n")))) 286 | gold_program_len = len(list(filter(lambda x: not x.startswith("#") and not len(x.strip()) == 0, 287 | post_process_code(output_dict["metadata"]["code"]).split("\n")))) 288 | 289 | program_len_diff = program_len - gold_program_len 290 | 291 | self.metrics_dict["exec_acc"](exec_acc[0]) 292 | self.metrics_dict["exec_rate"](exec_acc[1]) 293 | self.metrics_dict["program_len_diff"](program_len_diff) 294 | 295 | # also save the results in the json output file 296 | output_dict["metrics"] = {"exec_acc": float(exec_acc[0]), 297 | "exec_rate": float(exec_acc[1]), 298 | "program_len_diff": float(program_len_diff)} 299 | 300 | # measuring conditional metrics if they are enabled 301 | if self.measure_dedup_metrics: 302 | task_id = int(output_dict["metadata"]["task_id"]) 303 | for dedup_allow_k in ["2", "4", "8"]: 304 | if not task_id in self.val_overlap_ids[dedup_allow_k]: 305 | self.metrics_dict[f"dedup_exec_acc_{dedup_allow_k}"](exec_acc[0]) 306 | 307 | if self.eval_greedy_search: 308 | exec_acc = execution_acc(output_dict["greedy_generated_program"], mathqa_execution, 309 | output_dict["metadata"]["answer"]) 310 | 311 | self.metrics_dict["greedy_exec_acc"](exec_acc[0]) 312 | self.metrics_dict["greedy_exec_rate"](exec_acc[1]) 313 | 314 | output_dict["metrics"].update({"greedy_exec_acc": float(exec_acc[0]), 315 | "greedy_exec_rate": float(exec_acc[1])}) 316 | 317 | # canonocalization of the states to avoid error on saving the modules 318 | if "generated_program_state_list" in output_dict: 319 | for state_dict in output_dict["generated_program_state_list"]: 320 | if state_dict is not None: 321 | for key, value in state_dict.items(): 322 | if isinstance(value, ModuleType): 323 | state_dict[key] = str(value) 324 | 325 | # save the outputs to the model 326 | self.predictions.extend(outputs) 327 | 328 | def validation_epoch_end_extra(self, outputs: List[Dict[str, Any]]) -> None: 329 | # compute the eval_at_k metrics 330 | if (self.eval_epoch % self.eval_pass_at_k_every_n_epochs == 0 \ 331 | or self.eval_epoch < self.always_eval_pass_at_k_first_n_epochs) and self.pass_at_k > 1: 332 | print("evaluating pass at k...") 333 | 334 | all_generated_k_programs = [p["generated_k_programs"] for p in self.predictions] 335 | all_generated_k_programs_faltten = [item for sublist in all_generated_k_programs for item in sublist] 336 | gold_answers = [p["metadata"]["answer"] for p in self.predictions] 337 | 338 | result_list, pct_unique_progs = batch_execution_acc(all_generated_k_programs_faltten, 339 | mathqa_execution, gold_answers, len(self.predictions), self.pass_at_k) 340 | 341 | self.metrics_dict["unique_pct_in_k"](pct_unique_progs) 342 | for acc_at_k, pass_at_k in result_list: 343 | self.metrics_dict[f"acc@{self.pass_at_k}"](acc_at_k) 344 | self.metrics_dict[f"pass@{self.pass_at_k}"](pass_at_k) 345 | 346 | if len(self.additional_pass_at_k) > 0: 347 | for additional_k in self.additional_pass_at_k: 348 | correct = int(self.pass_at_k * acc_at_k) 349 | estimated_pass_at_k = estimate_pass_at_k(self.pass_at_k, correct, additional_k) 350 | self.metrics_dict[f"pass@{additional_k}"](estimated_pass_at_k) 351 | 352 | self.eval_epoch += 1 353 | 354 | def validation_epoch_end(self, outputs: List[Dict[str, 
Any]]) -> None: 355 | # extra steps for using the predictions 356 | self.validation_epoch_end_extra(outputs) 357 | 358 | # compute the metrics 359 | eval_metrics_dict = {} 360 | for k in self.metrics_dict.keys(): 361 | if k.startswith("dedup_"): 362 | dedup_exec_acc = float(self.metrics_dict[k].compute()) 363 | # for the dedup metrics, it's possible that a batch is all duplicates thus manually set nan to 0.0 364 | eval_metrics_dict[k] = dedup_exec_acc if not math.isnan(dedup_exec_acc) else 0.0 365 | else: 366 | eval_metrics_dict[k] = float(self.metrics_dict[k].compute()) 367 | 368 | # log and save the evalution metrics 369 | print(f"validation result: {eval_metrics_dict}") 370 | self.log_dict(eval_metrics_dict) 371 | 372 | # reset all the metrics 373 | for k in self.metrics_dict.keys(): 374 | self.metrics_dict[k].reset() 375 | 376 | # save the predictions 377 | save_pred_file_path = os.path.join(self.trainer.log_dir, 378 | f'predictions_step_{self.trainer.global_step}_rank_{self.trainer.global_rank}.jsonl') 379 | with open(save_pred_file_path, 'w+') as f: 380 | for prediction in self.predictions: 381 | f.write(json.dumps(prediction)+'\n') 382 | print(f"{len(self.predictions)} predictions saved to {save_pred_file_path}") 383 | 384 | # reset the predictions 385 | self.predictions = [] 386 | 387 | # FIXME: debug setting only 388 | # self.sampling_temp += 0.1 389 | # self.sampling_temp_at_k += 0.2 390 | # print(f"sampling temp is now {self.sampling_temp}, sampling temp at k is now {self.sampling_temp_at_k}") 391 | 392 | def test_step(self, batch: torch.Tensor, batch_idx: int) -> Dict[str, torch.Tensor]: 393 | raise NotImplementedError 394 | 395 | def on_fit_start(self) -> None: 396 | # save the code using wandb 397 | if self.logger: 398 | # if logger is initialized, save the code 399 | self.logger[0].log_code() 400 | else: 401 | print("logger is not initialized, code will not be saved") 402 | 403 | return super().on_fit_start() 404 | 405 | def configure_optimizers(self): 406 | optimizer = AdamW(self.parameters(), **self.opt_params) 407 | if self.lrs_params["name"] == "cosine": 408 | lr_scheduler = get_cosine_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 409 | elif self.lrs_params["name"] == "linear": 410 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 411 | elif self.lrs_params["name"] == "constant": 412 | lr_scheduler = get_constant_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 413 | else: 414 | raise ValueError(f"lr_scheduler {self.lrs_params} is not supported") 415 | 416 | return {"optimizer": optimizer, 417 | "lr_scheduler": { 418 | "scheduler": lr_scheduler, 419 | "interval": "step" 420 | } 421 | } -------------------------------------------------------------------------------- /lightning_modules/models/gpt_stmt_mml_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import os 4 | import json 5 | import random 6 | 7 | from itertools import chain 8 | from concurrent.futures import ProcessPoolExecutor as Pool 9 | 10 | from typing import Optional, Dict, Any, Tuple, List, Set, Union 11 | from torch.nn import CrossEntropyLoss 12 | 13 | from torchmetrics import MeanMetric 14 | from pytorch_lightning import LightningModule 15 | 16 | from .gpt_stmt_state_model import GptStmtStateModel 17 | from .gpt_seq2seq_model import post_process_code 18 | from .gpt_util import left_pad_sequences 19 | from execution.execution_evaluation import 
execution_acc, mathqa_execution, batch_exec_programs, mathqa_answer_eq 20 | 21 | 22 | class GptStmtMmlModel(GptStmtStateModel): 23 | def __init__(self, 24 | transformer_model_name: str, 25 | load_gold_programs: bool = True, 26 | on_policy_sample_num: Union[int, float] = 5, 27 | on_policy_sample_temp: float = 0.8, 28 | max_sampling_len: int = 100, 29 | length_diff_tolerance: int = 100, 30 | marg_set_size: int = 100, 31 | max_buffer_size: int = 100, 32 | load_samples_file: str = None, 33 | exclude_context_loss: bool = False, 34 | beta_smoothing: float = 1.0, 35 | mle_lambda: float = 1.0, 36 | mml_lambda: float = 0.0, 37 | mle_aug_norm: bool = False, 38 | **kwargs) -> None: 39 | if "eval_with_states" in kwargs and kwargs["eval_with_states"]: 40 | raise ValueError("eval_with_states is not supported for GptStmtMmlModel") 41 | 42 | super().__init__(transformer_model_name, **kwargs) 43 | 44 | self.load_gold_programs = load_gold_programs 45 | self.on_policy_sample_num = on_policy_sample_num 46 | self.on_policy_sample_temp = on_policy_sample_temp 47 | self.max_sampling_len = max_sampling_len 48 | self.marg_set_size = marg_set_size 49 | self.max_buffer_size = max_buffer_size 50 | self.length_diff_tolerance = length_diff_tolerance 51 | self.exclude_context_loss = exclude_context_loss 52 | self.beta_smoothing = beta_smoothing 53 | assert self.beta_smoothing > 0.0 and self.beta_smoothing <= 1.0, \ 54 | f"beta_smoothing must be in (0, 1], but got {self.beta_smoothing}" 55 | self.mle_lambda = mle_lambda 56 | self.mml_lambda = mml_lambda 57 | self.mle_aug_norm = mle_aug_norm 58 | 59 | # load the large sample of programs from the buffer 60 | self.loaded_samples: Dict[str, List[str]] = {} 61 | if load_samples_file is not None: 62 | self.load_samples_from_file(load_samples_file) 63 | print(f"loaded samples from {load_samples_file}") 64 | 65 | # the key being the task id and the value being the list of correct programs (in strs, *w/o* eos token) 66 | self.correct_program_buffer: Dict[str, Set[str]] = {} 67 | 68 | # define some debugging or eval metrics 69 | self.metrics_dict["pct_unique_programs"] = MeanMetric() 70 | 71 | def load_samples_from_file(self, file_path: str) -> None: 72 | with open(file_path, 'r') as f: 73 | for line in f: 74 | json_dict = json.loads(line) 75 | self.loaded_samples[json_dict["metadata"]["task_id"]] = json_dict["generated_k_programs"] 76 | 77 | def try_save_programs(self, generated_programs: List[str], task_ids: List[str], 78 | correct_answers: List[float], samples_per_task: int, 79 | log_unique_program_pct: bool = True, verbose: bool = False) -> None: 80 | # execute the programs first 81 | program_exec_results, n_unique_programs = batch_exec_programs(generated_programs, mathqa_execution)#, n_processes=5) 82 | if log_unique_program_pct: 83 | self.metrics_dict["pct_unique_programs"](n_unique_programs / (len(task_ids) * samples_per_task)) 84 | 85 | # save the programs with correct results into the buffer if it's not already in there 86 | correct_count = 0 87 | saved_count = 0 88 | for i, exec_result in enumerate(program_exec_results): 89 | example_idx = i // samples_per_task 90 | if mathqa_answer_eq(exec_result, correct_answers[example_idx]): 91 | correct_count += 1 92 | task_id = task_ids[example_idx] 93 | generated_program = post_process_code(generated_programs[i], ast_back_parse=False) 94 | if generated_program not in self.correct_program_buffer[task_id]: 95 | # check whether satisfied the length difference tolerence 96 | min_prog_len = min([len(prog.split("\n")) for prog in 
self.correct_program_buffer[task_id]]) 97 | if len(generated_program.split("\n")) - min_prog_len <= self.length_diff_tolerance: 98 | # save the program into the buffer 99 | self.correct_program_buffer[task_id].add(generated_program) 100 | saved_count += 1 101 | if verbose: 102 | print(f"{len(generated_programs)} in total, {n_unique_programs} are unique, " + \ 103 | f"{correct_count} programs are correct, saved {saved_count} programs into the buffer") 104 | 105 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]: 106 | # use task ids as the identifier 107 | task_ids = [ex["task_id"] for ex in batch["metadata"]] 108 | 109 | # add the gold programs to the buffer 110 | for i, example in enumerate(batch["metadata"]): 111 | task_id = task_ids[i] 112 | if task_id in self.correct_program_buffer: 113 | # already have loaded the correct programs for this task 114 | continue 115 | elif self.load_gold_programs: 116 | gold_program = example["code"] 117 | # the program in the buffer is always post-processed 118 | self.correct_program_buffer[task_id] = set((post_process_code(gold_program, ast_back_parse=False),)) 119 | 120 | # load the program from the samples, if available 121 | if task_id in self.loaded_samples: 122 | loaded_examples_for_task = self.loaded_samples.pop(task_id) # this also removes the key 123 | self.try_save_programs(loaded_examples_for_task, [task_id], [batch["metadata"][i]["answer"]], 124 | len(loaded_examples_for_task), log_unique_program_pct=False) 125 | 126 | else: 127 | self.correct_program_buffer[task_id] = set() 128 | 129 | # do on-policy sampling for the current tasks 130 | input_ids = batch["input_ids"] 131 | attention_mask = batch["attention_mask"] 132 | 133 | if self.on_policy_sample_num > 0: 134 | with torch.no_grad(): 135 | if not all([len(self.correct_program_buffer[task_id]) >= self.max_buffer_size for task_id in task_ids]): 136 | # generate the programs and get their execution results 137 | max_context_len = input_ids.size(1) 138 | output_seqs = self.gpt.generate(input_ids=input_ids, attention_mask=attention_mask, 139 | do_sample=True, max_new_tokens=self.max_sampling_len, 140 | num_return_sequences=self.on_policy_sample_num, 141 | temperature=self.on_policy_sample_temp) 142 | generated_seqs = output_seqs[:, max_context_len:].cpu() 143 | generated_programs = self.tokenizer.batch_decode(generated_seqs, skip_special_tokens=True) 144 | 145 | # try to save the programs 146 | correct_answers = [x["answer"] for x in batch["metadata"]] 147 | self.try_save_programs(generated_programs, task_ids, correct_answers, self.on_policy_sample_num) 148 | 149 | def sample_from_buffer(task_id: str) -> List[str]: 150 | cached_programs = list(self.correct_program_buffer[task_id]) 151 | if len(cached_programs) <= self.marg_set_size: 152 | return cached_programs 153 | else: 154 | return random.sample(cached_programs, self.marg_set_size) 155 | 156 | # remove the left paddings first and concat the context and cached programs 157 | context_seqs = [input_ids[i, -context_len:] for i, context_len in enumerate(attention_mask.sum(dim=1))] 158 | cached_program_seqs: List[List[torch.Tensor]] = [[self.tokenizer(prog_str, return_tensors="pt")['input_ids'][0] 159 | for prog_str in sample_from_buffer(task_id)] 160 | for task_id in task_ids] 161 | cached_program_nums = [len(cached_programs) for cached_programs in cached_program_seqs] 162 | flatten_input_ids = [] 163 | for i, program_seqs in enumerate(cached_program_seqs): 164 | for program_seq in program_seqs: 
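                # each cached correct program is concatenated to its un-padded context and terminated with
                # an explicit EOS token; the flattened sequences are then left-padded into a single batch below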
165 | flatten_input_ids.append(torch.cat((context_seqs[i], program_seq.to(self.device), 166 | torch.tensor([self.tokenizer.eos_token_id], device=self.device)), dim=0)) 167 | flatten_attention_mask = [torch.ones_like(flatten_ids) for flatten_ids in flatten_input_ids] 168 | flatten_input_ids = left_pad_sequences(flatten_input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 169 | flatten_attention_mask = left_pad_sequences(flatten_attention_mask, batch_first=True, padding_value=0) 170 | 171 | # exclude the loss from context by setting the labels of them to be -100 172 | if self.exclude_context_loss: 173 | flatten_labels = [] 174 | for i, program_seqs in enumerate(cached_program_seqs): 175 | for j, program_seq in enumerate(program_seqs): 176 | concat_labels = torch.cat((-100 * torch.ones_like(context_seqs[i]), program_seq.to(self.device), 177 | torch.tensor([self.tokenizer.eos_token_id], device=self.device)), dim=0) 178 | flatten_labels.append(concat_labels) 179 | flatten_labels = left_pad_sequences(flatten_labels, batch_first=True, padding_value=self.tokenizer.pad_token_id) 180 | else: 181 | flatten_labels = flatten_input_ids 182 | assert flatten_labels.shape == flatten_input_ids.shape == flatten_attention_mask.shape, \ 183 | f"{flatten_labels.shape}, {flatten_input_ids.shape}, {flatten_attention_mask.shape}" 184 | 185 | # go through the gpt model as if it's a new batch 186 | gpt_result = self.gpt(input_ids=flatten_input_ids, attention_mask=flatten_attention_mask, labels=flatten_labels) 187 | 188 | # reorganize the logits to compute individual program log probs 189 | shift_logits = gpt_result.logits[..., :-1, :].contiguous() 190 | shift_labels = flatten_labels[..., 1:].contiguous() 191 | loss_fct = CrossEntropyLoss(reduction="none") 192 | flatten_unreduced_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 193 | unreduced_loss = flatten_unreduced_loss.view(shift_labels.shape) * flatten_attention_mask[..., 1:] 194 | 195 | # compute the marginal log prob 196 | loss = self.get_mixed_mle_mml_loss(unreduced_loss, cached_program_nums) 197 | self.log("loss", loss, on_step=True, on_epoch=True) 198 | self.log("fcp_buffer_size", sum(cached_program_nums) / len(cached_program_nums), on_step=False, on_epoch=True) 199 | 200 | return {"loss": loss} 201 | 202 | def get_mixed_mle_mml_loss(self, unreduced_loss: torch.Tensor, cached_program_nums: List[int], 203 | log_prob_dist: bool = True) -> torch.Tensor: 204 | """ 205 | Compute the loss for the MML and MLE. 
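        Given per-token cross-entropy `unreduced_loss` (so each program's log-probability is
        log p_i = -sum of its token losses), grouped per example by `cached_program_nums`, the code below
        computes, for every example,
            L_MML = -(1 / beta_smoothing) * logsumexp_i(beta_smoothing * log p_i)
            L_MLE = -c * sum_i(beta_smoothing * log p_i),  with c = 1/|group| if mle_aug_norm else 1
        and returns the batch mean of mml_lambda * L_MML + mle_lambda * L_MLE. When `log_prob_dist` is
        True, summary statistics of the largest and second-largest program probabilities per example are
        also logged as `max_prob` and `second_max_prob`.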
206 | """ 207 | # compute the marginal log prob and the sum of the log probs 208 | grouped_example_log_probs = torch.split(-self.beta_smoothing * torch.sum(unreduced_loss, dim=1), cached_program_nums) 209 | marginal_log_probs = torch.stack([-1.0 * torch.logsumexp(log_probs, dim=0) / self.beta_smoothing for log_probs in grouped_example_log_probs]) 210 | norm_func = (lambda x: 1.0 ) if not self.mle_aug_norm else (lambda x: 1.0 / len(x)) 211 | sum_log_probs = torch.stack([-norm_func(log_probs) * torch.sum(log_probs, dim=0) for log_probs in grouped_example_log_probs]) 212 | loss = torch.mean(self.mml_lambda * marginal_log_probs + self.mle_lambda * sum_log_probs) 213 | 214 | if log_prob_dist: 215 | # some additional metrics to evaluate the distribution of the programs 216 | max_prob = [sorted(torch.exp(log_probs), reverse=True)[0] for log_probs in grouped_example_log_probs] 217 | second_max_prob = [sorted(torch.exp(log_probs), reverse=True)[1] 218 | if len(log_probs) > 1 else None for log_probs in grouped_example_log_probs] 219 | second_max_prob = list(filter(lambda x: x is not None, second_max_prob)) 220 | 221 | max_prob_avg = float(torch.pow(torch.stack(max_prob).mean(), 1.0 / self.beta_smoothing)) 222 | second_max_prob_avg = float(torch.pow(torch.stack(second_max_prob).mean(), 1.0 / self.beta_smoothing)) \ 223 | if len(second_max_prob) > 0 else 0.0 224 | 225 | self.log("max_prob", max_prob_avg, on_step=False, on_epoch=True) 226 | self.log("second_max_prob", second_max_prob_avg, on_step=False, on_epoch=True) 227 | 228 | return loss 229 | 230 | def forward( # type: ignore 231 | self, 232 | input_ids: torch.Tensor, 233 | attention_mask: torch.Tensor, 234 | metadata: Optional[List[Dict[str, Any]]] = None, 235 | ) -> List[Dict[str, Any]]: 236 | return super(GptStmtStateModel, self).forward(input_ids, attention_mask, metadata) 237 | 238 | def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: 239 | # first do all the things that the base model do 240 | super().validation_epoch_end(outputs) 241 | 242 | # then save the buffer status 243 | save_buffer_file_path = os.path.join(self.trainer.log_dir, 244 | f'buffer_step_{self.trainer.global_step}_rank_{self.trainer.global_rank}.jsonl') 245 | with open(save_buffer_file_path, 'w+') as f: 246 | for task_id, program_seqs in self.correct_program_buffer.items(): 247 | json_dict = {"task_id": task_id, "saved_programs": list(program_seqs)} 248 | f.write(json.dumps(json_dict) + '\n') 249 | print(f"buffer saved to {save_buffer_file_path}") 250 | 251 | def training_epoch_end(self, outputs) -> None: 252 | if not torch.distributed.is_initialized(): 253 | print("training_epoch_end: not using distributed training") 254 | return 255 | 256 | # gather all the buffers from all processes 257 | world_size = torch.distributed.get_world_size() 258 | all_buffer_list = [{} for _ in range(world_size)] 259 | torch.distributed.all_gather_object(all_buffer_list, self.correct_program_buffer) 260 | 261 | # merge all the buffers 262 | prev_avg_buffer_size = sum(map(lambda x: len(x[1]), self.correct_program_buffer.items())) / len(self.correct_program_buffer) 263 | merged_buffer: Dict[str, Set[str]] = {} 264 | for buffer in all_buffer_list: 265 | for task_id, programs in buffer.items(): 266 | if task_id not in merged_buffer: 267 | merged_buffer[task_id] = programs 268 | else: 269 | merged_buffer[task_id].update(programs) 270 | 271 | self.correct_program_buffer = merged_buffer 272 | after_avg_buffer_size = sum(map(lambda x: len(x[1]), self.correct_program_buffer.items())) 
/ len(self.correct_program_buffer) 273 | print(f"buffer size increased from {prev_avg_buffer_size} to {after_avg_buffer_size}, by {after_avg_buffer_size - prev_avg_buffer_size}") 274 | 275 | -------------------------------------------------------------------------------- /lightning_modules/models/gpt_stmt_partial_mml_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import os 4 | import json 5 | import random 6 | 7 | from itertools import chain 8 | from concurrent.futures import ProcessPoolExecutor as Pool 9 | 10 | from typing import Optional, Dict, Any, Tuple, List, Set, Union 11 | from torch.nn import CrossEntropyLoss 12 | 13 | from torchmetrics import MeanMetric 14 | from pytorch_lightning import LightningModule 15 | 16 | from .gpt_stmt_mml_model import GptStmtMmlModel 17 | from .gpt_stmt_state_model import GptStmtStateModel 18 | from .gpt_seq2seq_model import post_process_code 19 | from .gpt_util import left_pad_sequences 20 | from execution.execution_evaluation import execution_acc, mathqa_execution, batch_exec_programs, mathqa_answer_eq 21 | from execution.program_tracing import get_execution_states, batch_program_tracing 22 | from execution.program_tracing import ProgState, ProgTrace, ProgTraceUnit, HashableProgState, Program 23 | 24 | def get_hashable_state(state_dict: ProgState) -> HashableProgState: 25 | if len(state_dict) == 0: 26 | raise ValueError("state_dict is empty, check if the code is empty or being gap") 27 | 28 | # get the values and sort it to make sure the order is the same FIXME: This only works for mathqa 29 | str_vars = [str(value) for _, value in state_dict.items()] 30 | 31 | # make it immutable to be hashable so that we can use it as a key 32 | return tuple(sorted(str_vars)) 33 | 34 | def mathqa_identify_fully_correct_state(state: ProgState, gold_answer: Any): 35 | if "type" in state and "code" in state: 36 | raise ValueError("mathqa_identify_fully_correct_state function only accepts states, not tracing units") 37 | 38 | if "answer" not in state: 39 | return False 40 | else: 41 | return mathqa_answer_eq(state["answer"], gold_answer) 42 | 43 | def mathqa_identify_output(state: ProgState): 44 | if "type" in state and "code" in state: 45 | raise ValueError("mathqa_identify_output function only accepts states, not tracing units") 46 | 47 | if "answer" in state: 48 | return True 49 | else: 50 | return False 51 | 52 | def construct_program_from_trace(trace: ProgTrace) -> Program: 53 | code = "".join([unit.code for unit in trace]) 54 | code_lite = post_process_code(code) 55 | return Program(code, code_lite, trace) 56 | 57 | def get_program_str_set(programs: List[Program]) -> Set[str]: 58 | return set([program.code_lite for program in programs]) 59 | 60 | def not_as_prefix_of_programs(program: Program, programs: List[Program]) -> bool: 61 | for prog in programs: 62 | if prog.code_lite.startswith(program.code_lite): 63 | return False 64 | return True 65 | 66 | def get_empty_program() -> Program: 67 | return Program("", "", []) 68 | 69 | def get_num_stmts(program: Program) -> int: 70 | return len(list(filter(lambda x: x.type == "stmt", program.trace))) 71 | 72 | class GptStmtPartialMmlModel(GptStmtMmlModel): 73 | def __init__(self, 74 | transformer_model_name: str, 75 | n_pcp_samples: int = 1, 76 | prioritize_fcp: bool = True, 77 | length_diff_tolerance: int = 100, 78 | sampling_from_states: bool = False, 79 | sampling_full_prog_only: bool = False, 80 | gold_program_only: bool = False, 81 | 
fcp_only: bool = False, 82 | norm_marg_by_len: bool = False, 83 | **kwargs) -> None: 84 | 85 | super().__init__(transformer_model_name, **kwargs) 86 | 87 | self.n_pcp_samples = n_pcp_samples 88 | assert n_pcp_samples == 1, "currently only support n_pcp_samples = 1" 89 | self.prioritize_fcp = prioritize_fcp 90 | self.length_diff_tolerance = length_diff_tolerance 91 | self.sampling_from_states = sampling_from_states 92 | self.sampling_full_prog_only = sampling_full_prog_only 93 | self.gold_program_only = gold_program_only 94 | self.fcp_only = fcp_only 95 | self.norm_marg_by_len = norm_marg_by_len 96 | 97 | # for each task id as the key, the value dict maps states to known sub-programs 98 | self.state_programs_dict: Dict[str, Dict[HashableProgState, List[Program]]] = {} 99 | 100 | # redefine it here to use new type hints 101 | self.correct_program_buffer: Dict[str, List[Program]] = {} 102 | 103 | # this needs to be differentiated from the fully correct programs 104 | # because we do not sample for the fully correct programs 105 | self.partially_correct_program_buffer: Dict[str, List[Program]] = {} 106 | 107 | 108 | def save_program_by_trace(self, task_id: str, tracing_states: Union[ProgTrace, None], 109 | is_fully_correct: bool) -> str: 110 | """ all the traces of the programs will have the final state being correct, 111 | so we try to save all of its subprograms by state 112 | 113 | return: the status of the saving attempt, either one of 114 | ["saved", "existing fcp", "existing pcp", "not valid"] """ 115 | 116 | # nothing to save if the program executes to error 117 | if tracing_states is None: 118 | return "S3: not valid" 119 | 120 | # construct the program tuple 121 | trace_program = construct_program_from_trace(tracing_states) 122 | 123 | def check_buffer_and_remove(buffer: Dict[str, List[Program]], program: Program): 124 | # check any of the existing program is the prefix of this new program, if so, remove the shorter one 125 | programs_to_remove = [] 126 | for i, saved_program in enumerate(buffer[task_id]): 127 | # empty program doesn't need to be removed 128 | if len(saved_program.code_lite) > 0 and program.code_lite.startswith(saved_program.code_lite): 129 | programs_to_remove.insert(0, i) 130 | for idx in programs_to_remove: 131 | buffer[task_id].pop(idx) 132 | 133 | # we only save the longest program 134 | if is_fully_correct: 135 | if trace_program.code_lite not in get_program_str_set(self.correct_program_buffer[task_id]): 136 | # check if any partial program is the prefix of this new fully correct program 137 | check_buffer_and_remove(self.partially_correct_program_buffer, trace_program) 138 | self.correct_program_buffer[task_id].append(trace_program) 139 | else: 140 | # there is nothing to save for the states dict since the full program is already in the buffer 141 | return "S4: existing fcp" 142 | elif not_as_prefix_of_programs(trace_program, self.partially_correct_program_buffer[task_id]) and \ 143 | not_as_prefix_of_programs(trace_program, self.correct_program_buffer[task_id]): 144 | # check if any partial program is the prefix of this new partially correct program 145 | check_buffer_and_remove(self.partially_correct_program_buffer, trace_program) 146 | self.partially_correct_program_buffer[task_id].append(trace_program) 147 | else: 148 | return "S5: existing pcp" 149 | 150 | # we try to save all the sub-programs of the program by state, excluding the final correct state 151 | tracing_states_to_save = tracing_states[:-1] if is_fully_correct else tracing_states 152 | for i, 
trace_unit in enumerate(tracing_states_to_save): 153 | if trace_unit.type != "stmt": 154 | continue 155 | 156 | sub_program = construct_program_from_trace(tracing_states_to_save[:i+1]) 157 | 158 | # check the state 159 | hashable_state = get_hashable_state(trace_unit.state) 160 | if hashable_state not in self.state_programs_dict[task_id]: 161 | self.state_programs_dict[task_id][hashable_state] = [sub_program] 162 | else: 163 | if sub_program.code_lite not in get_program_str_set(self.state_programs_dict[task_id][hashable_state]): 164 | self.state_programs_dict[task_id][hashable_state].append(sub_program) 165 | 166 | return "S6: saved" 167 | 168 | def check_and_save_partially_correct_program(self, task_id: str, 169 | program_trace: Union[ProgTrace, None], 170 | gold_answer: Any) -> str: 171 | """ check if there is a sub-program that worth saving """ 172 | 173 | if program_trace is None: 174 | return "S0: not executable" 175 | 176 | # see if the any of the states matches the saved correct states 177 | saved_states: Dict[HashableProgState, List[Program]] = self.state_programs_dict.get(task_id) 178 | 179 | # identify the longest partially (or fully) correct program that fits the constraints 180 | furthest_correct_state_idx = -1 181 | stmt_num = 0 182 | for i, tracing_unit in enumerate(program_trace): 183 | if tracing_unit.type != "stmt": 184 | continue 185 | else: 186 | stmt_num += 1 187 | 188 | if len(tracing_unit.state) == 0: 189 | continue # most likely because a comment is generated as the first line 190 | else: 191 | hashable_state = get_hashable_state(tracing_unit.state) 192 | 193 | reach_full_correct_state = mathqa_identify_fully_correct_state(tracing_unit.state, gold_answer) 194 | reach_output = mathqa_identify_output(tracing_unit.state) 195 | if reach_full_correct_state: 196 | min_fcp_len = min(list(map(get_num_stmts, self.correct_program_buffer[task_id]))) 197 | if stmt_num - min_fcp_len <= self.length_diff_tolerance: 198 | furthest_correct_state_idx = i 199 | break 200 | else: 201 | return "S1: fully correct but length exceeds tolerance" 202 | elif reach_output and not reach_full_correct_state: 203 | # if the program produce the incorrect output, we don't need to save it (but the prefix might still be useful) 204 | break 205 | elif hashable_state in saved_states: 206 | min_pcp_len = min(list(map(get_num_stmts, saved_states[hashable_state]))) 207 | if stmt_num - min_pcp_len <= self.length_diff_tolerance: 208 | furthest_correct_state_idx = i 209 | 210 | # save the all the sub-programs of this program by state 211 | if furthest_correct_state_idx != -1: 212 | is_fully_correct = mathqa_identify_fully_correct_state(program_trace[furthest_correct_state_idx].state, gold_answer) 213 | return self.save_program_by_trace(task_id, program_trace[:furthest_correct_state_idx + 1], is_fully_correct) 214 | else: 215 | return "S2: not partiallly correct or partially correct but length exceeds tolerance" 216 | 217 | def concat_context_with_multiple_programs(self, context_input_ids: torch.Tensor, 218 | context_attention_mask: torch.Tensor, 219 | programs: List[List[Program]], 220 | is_fully_correct: List[List[bool]]) -> str: 221 | """ concatenate each context tensor with multiple programs """ 222 | 223 | # remove the left paddings first and concat the context and cached programs 224 | context_seqs = [context_input_ids[i, -context_len:] 225 | for i, context_len in enumerate(context_attention_mask.sum(dim=1))] 226 | cached_program_seqs: List[List[torch.Tensor]] = [[self.tokenizer(prog.code, 
return_tensors="pt")['input_ids'][0] 227 | for prog in task_programs] 228 | for task_programs in programs] 229 | flatten_input_ids = [] 230 | for i, program_seqs in enumerate(cached_program_seqs): 231 | for j, program_seq in enumerate(program_seqs): 232 | concat_context_program = torch.cat([context_seqs[i], program_seq.to(dtype=context_seqs[i].dtype, device=self.device)], dim=0) 233 | if is_fully_correct[i][j]: 234 | concat_context_program = torch.cat((concat_context_program, torch.tensor([self.tokenizer.eos_token_id], device=self.device)), dim=0) 235 | flatten_input_ids.append(concat_context_program) 236 | flatten_attention_mask = [torch.ones_like(flatten_ids) for flatten_ids in flatten_input_ids] 237 | 238 | flatten_input_ids = left_pad_sequences(flatten_input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 239 | flatten_attention_mask = left_pad_sequences(flatten_attention_mask, batch_first=True, padding_value=0) 240 | 241 | # exclude the loss from context by setting the labels of them to be -100 242 | if self.exclude_context_loss: 243 | flatten_labels = [] 244 | for i, program_seqs in enumerate(cached_program_seqs): 245 | for j, program_seq in enumerate(program_seqs): 246 | concat_labels = torch.cat([-100 * torch.ones_like(context_seqs[i]), 247 | program_seq.to(dtype=context_seqs[i].dtype, device=self.device)], dim=0) 248 | if is_fully_correct[i][j]: 249 | concat_labels = torch.cat((concat_labels, torch.tensor([self.tokenizer.eos_token_id], device=self.device)), dim=0) 250 | flatten_labels.append(concat_labels) 251 | flatten_labels = left_pad_sequences(flatten_labels, batch_first=True, padding_value=self.tokenizer.pad_token_id) 252 | else: 253 | flatten_labels = flatten_input_ids 254 | 255 | assert flatten_input_ids.shape[0] == len(flatten_attention_mask) == sum([len(progs) for progs in programs]) == flatten_labels.shape[0] 256 | 257 | return flatten_input_ids, flatten_attention_mask, flatten_labels 258 | 259 | def get_marg_program_set(self, task_id: str): 260 | """ get the programs to marginalize over for each of the task according to the specific settings """ 261 | 262 | # TODO: (maybe with fancier policy) first fill in the fully correct programs, 263 | # if not enough, fill in the partially correct programs 264 | programs, is_fully_correct = [], [] 265 | programs.extend(self.correct_program_buffer[task_id]) 266 | is_fully_correct.extend([True] * len(self.correct_program_buffer[task_id])) 267 | 268 | if self.gold_program_only: 269 | return programs[:1], is_fully_correct[:1] 270 | if self.fcp_only: 271 | return programs[:self.marg_set_size], is_fully_correct[:self.marg_set_size] 272 | 273 | non_empty_pcp_buffer = list(filter(lambda x: len(x.code_lite) > 0, self.partially_correct_program_buffer[task_id])) 274 | programs.extend(non_empty_pcp_buffer) 275 | is_fully_correct.extend([False] * len(non_empty_pcp_buffer)) 276 | 277 | if len(programs) > self.marg_set_size: 278 | if self.prioritize_fcp: 279 | return programs[:self.marg_set_size], is_fully_correct[:self.marg_set_size] 280 | else: 281 | # random sample the program indices, regardless of being pcp or fcp 282 | idx_sample = random.sample(range(len(programs)), self.marg_set_size) 283 | return [programs[i] for i in idx_sample], [is_fully_correct[i] for i in idx_sample] 284 | else: 285 | return programs, is_fully_correct 286 | 287 | def get_samples_for_completion(self, task_ids) -> List[List[Program]]: 288 | """ sample some unfinished programs from the buffer to do the completion sampling """ 289 | if 
self.sampling_from_states: 290 | # first sample a state for which the programs reach 291 | states = [random.sample(self.state_programs_dict[task_id].keys(), 1)[0] for task_id in task_ids] 292 | 293 | # then we sample a program that reaches the state 294 | programs = [random.sample(self.state_programs_dict[task_id][state], self.n_pcp_samples) 295 | for task_id, state in zip(task_ids, states)] 296 | elif self.sampling_full_prog_only: 297 | # get the blank program to sample the full program 298 | programs = [[self.partially_correct_program_buffer[task_id][0]] for task_id in task_ids] 299 | else: 300 | # sample a program from the pcp buffera 301 | programs = [random.sample(self.partially_correct_program_buffer[task_id], self.n_pcp_samples) 302 | for task_id in task_ids] 303 | 304 | return programs 305 | 306 | 307 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]: 308 | # use task ids as the identifier 309 | task_ids = [ex["task_id"] for ex in batch["metadata"]] 310 | 311 | # add the gold programs to the correct program buffer and empty program to the partially correct program buffer 312 | for i, example in enumerate(batch["metadata"]): 313 | task_id = task_ids[i] 314 | # need to initialize the three data structures with this order 315 | if task_id not in self.state_programs_dict: 316 | # in order to be able to sample the empty program when sampling programs to be complete 317 | self.state_programs_dict[task_id] = {"NULL": [get_empty_program()]} 318 | 319 | if task_id not in self.partially_correct_program_buffer: 320 | # add empty program to the partially correct program buffer 321 | self.partially_correct_program_buffer[task_id] = [get_empty_program()] 322 | 323 | if task_id not in self.correct_program_buffer: 324 | # add the gold programs to the correct program buffer 325 | self.correct_program_buffer[task_id] = [] 326 | tracing_states = get_execution_states(example["code"]) 327 | return_msg = self.save_program_by_trace(task_id, tracing_states, is_fully_correct=True) 328 | assert return_msg == "S6: saved" 329 | 330 | # do on-policy sampling for the current tasks from the program prefixes 331 | context_input_ids = batch["input_ids"] 332 | context_attention_mask = batch["attention_mask"] 333 | 334 | if not self.gold_program_only: 335 | # TODO: maybe with fancier sampling methods, e.g., sampling from in-between the program 336 | pcp_samples: List[List[Program]] = self.get_samples_for_completion(task_ids) 337 | pcp_input_ids, pcp_attention_mask, _ = self.concat_context_with_multiple_programs(context_input_ids, 338 | context_attention_mask, pcp_samples, 339 | is_fully_correct=[[False]*len(progs) for progs in pcp_samples]) 340 | 341 | with torch.no_grad(): 342 | # generate the programs and get their execution results 343 | max_context_len = pcp_input_ids.size(1) 344 | output_seqs = self.gpt.generate(input_ids=pcp_input_ids, attention_mask=pcp_attention_mask, 345 | do_sample=True, max_new_tokens=self.max_sampling_len, 346 | num_return_sequences=self.on_policy_sample_num, 347 | temperature=self.on_policy_sample_temp) 348 | generated_seqs = output_seqs[:, max_context_len:].cpu() 349 | generated_program_completions = self.tokenizer.batch_decode(generated_seqs, skip_special_tokens=True) 350 | generated_programs = [pcp_samples[i // self.on_policy_sample_num][0].code + completion # FIXME: only works when self.n_pcp_samples == 1 351 | for i, completion in enumerate(generated_program_completions)] 352 | program_traces = 
batch_program_tracing(generated_programs) 353 | 354 | # save the programs with correct results into the buffer if it's not already in there 355 | results_count_dict = {f"S{i}": 0 for i in range(7)} 356 | for i, program_trace in enumerate(program_traces): 357 | example_idx = i // self.on_policy_sample_num 358 | gold_answer = batch["metadata"][example_idx]["answer"] 359 | task_id = task_ids[example_idx] 360 | result_msg = self.check_and_save_partially_correct_program(task_id, program_trace, gold_answer) 361 | results_count_dict[result_msg[:2]] += 1 362 | 363 | for k in results_count_dict.keys(): 364 | self.log(f"save_msg_{k}", results_count_dict[k] / len(program_traces), on_step=False, on_epoch=True) 365 | 366 | # select the set of programs to marginalize over for each task and tokenize + concatenate with context 367 | marg_programs: List[List[Program]] = [] 368 | marg_is_fully_correct: List[List[bool]] = [] 369 | marg_program_nums: List[int] = [] 370 | for task_id in task_ids: 371 | programs, is_fully_correct = self.get_marg_program_set(task_id) 372 | marg_programs.append(programs) 373 | marg_is_fully_correct.append(is_fully_correct) 374 | marg_program_nums.append(len(programs)) 375 | flatten_input_ids, flatten_attention_mask, flatten_labels = self.concat_context_with_multiple_programs(context_input_ids, 376 | context_attention_mask, marg_programs, marg_is_fully_correct) 377 | 378 | # go through the gpt model as if it's a new batch 379 | gpt_result = self.gpt(input_ids=flatten_input_ids, attention_mask=flatten_attention_mask, labels=flatten_labels) 380 | 381 | # reorganize the logits to compute individual program log probs 382 | shift_logits = gpt_result.logits[..., :-1, :].contiguous() 383 | shift_labels = flatten_labels[..., 1:].contiguous() 384 | loss_fct = CrossEntropyLoss(reduction="none") 385 | flatten_unreduced_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 386 | unreduced_loss =flatten_unreduced_loss.view(shift_labels.shape) * flatten_attention_mask[..., 1:] 387 | 388 | if self.norm_marg_by_len: 389 | # normalizing by length so we can avoid favoring shorter (e.g., partially-correct) programs 390 | seq_lens = torch.sum(flatten_attention_mask, dim=-1, keepdim=True) 391 | grouped_seq_len = torch.split(seq_lens, marg_program_nums) 392 | max_group_len = torch.cat([torch.full_like(seq_len, torch.max(seq_len)) for seq_len in grouped_seq_len], dim=0) 393 | assert seq_lens.shape == max_group_len.shape 394 | unreduced_loss = unreduced_loss / seq_lens * max_group_len 395 | 396 | # compute the marginal log prob 397 | loss = self.get_mixed_mle_mml_loss(unreduced_loss, marg_program_nums) 398 | self.log("loss", loss, on_step=True, on_epoch=True) 399 | 400 | self.log("marg_size", sum(marg_program_nums) / len(marg_program_nums), on_step=False, on_epoch=True) 401 | self.log("fcp_buffer_size", sum([len(self.correct_program_buffer[task_id]) for task_id in task_ids]) \ 402 | / len(task_ids), on_step=False, on_epoch=True) 403 | self.log("pcp_buffer_size", sum([len(self.partially_correct_program_buffer[task_id]) for task_id in task_ids]) \ 404 | / len(task_ids), on_step=False, on_epoch=True) 405 | self.log("state_prog_dict_size", sum([len(self.state_programs_dict[task_id]) for task_id in task_ids]) \ 406 | / len(task_ids), on_step=False, on_epoch=True) 407 | 408 | return {"loss": loss} 409 | 410 | def forward( # type: ignore 411 | self, 412 | input_ids: torch.Tensor, 413 | attention_mask: torch.Tensor, 414 | metadata: Optional[List[Dict[str, Any]]] = None, 415 | ) -> 
List[Dict[str, Any]]: 416 | return super(GptStmtStateModel, self).forward(input_ids, attention_mask, metadata) 417 | 418 | def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: 419 | # first do all the things that the base model do 420 | super(GptStmtMmlModel, self).validation_epoch_end(outputs) 421 | 422 | # then save the buffer status for the fully correct programs and the partially correct programs 423 | save_buffer_file_path = os.path.join(self.trainer.log_dir, 424 | f'buffer_step_{self.trainer.global_step}_rank_{self.trainer.global_rank}.jsonl') 425 | with open(save_buffer_file_path, 'w+') as f: 426 | for task_id, programs in self.correct_program_buffer.items(): 427 | json_dict = {"task_id": task_id, 428 | "saved_fcp_programs": list([prog.code for prog in programs]), 429 | "saved_pcp_programs": list([prog.code for prog in self.partially_correct_program_buffer[task_id]])} 430 | f.write(json.dumps(json_dict) + '\n') 431 | print(f"buffer saved to {save_buffer_file_path}") 432 | 433 | def merge_buffers(self, is_fcp_buffer: bool): 434 | # first identify which buffer to merge (full or partial) 435 | buffer_to_merge = self.correct_program_buffer if is_fcp_buffer else self.partially_correct_program_buffer 436 | 437 | world_size = torch.distributed.get_world_size() 438 | all_buffer_list: List[Dict[str, List[Program]]] = [{} for _ in range(world_size)] 439 | torch.distributed.all_gather_object(all_buffer_list, buffer_to_merge) 440 | 441 | # merge all the buffers 442 | prev_avg_buffer_size = sum(map(lambda x: len(x[1]), buffer_to_merge.items())) / len(buffer_to_merge) 443 | 444 | for buffer in all_buffer_list: 445 | for task_id, programs in buffer.items(): 446 | if task_id not in buffer_to_merge: 447 | assert task_id not in self.correct_program_buffer and \ 448 | task_id not in self.partially_correct_program_buffer and \ 449 | task_id not in self.state_programs_dict, \ 450 | f"task_id {task_id} should not be in any buffer" 451 | 452 | # init all three data structures 453 | # TODO: this can be optimized by simply assigning the three data structures since they are empty 454 | self.state_programs_dict[task_id] = {"NULL": [get_empty_program()]} 455 | self.partially_correct_program_buffer[task_id] = [get_empty_program()] 456 | self.correct_program_buffer[task_id] = [] 457 | 458 | for program in programs: 459 | self.save_program_by_trace(task_id, program.trace, is_fully_correct=is_fcp_buffer) 460 | 461 | after_avg_buffer_size = sum(map(lambda x: len(x[1]), buffer_to_merge.items())) / len(buffer_to_merge) 462 | print(f"{'fcp' if is_fcp_buffer else 'pcp'} buffer size increased " \ 463 | f"from {prev_avg_buffer_size} to {after_avg_buffer_size}, " \ 464 | f"by {after_avg_buffer_size - prev_avg_buffer_size}") 465 | 466 | 467 | def training_epoch_end(self, outputs) -> None: 468 | if not torch.distributed.is_initialized(): 469 | print("training_epoch_end: not using distributed training") 470 | return 471 | 472 | self.merge_buffers(is_fcp_buffer=True) 473 | self.merge_buffers(is_fcp_buffer=False) 474 | 475 | -------------------------------------------------------------------------------- /lightning_modules/models/gpt_stmt_state_model.py: -------------------------------------------------------------------------------- 1 | from numpy import dtype 2 | import torch 3 | import json 4 | import os 5 | import torch.nn.functional as F 6 | import pytorch_lightning as pl 7 | 8 | from typing import Optional, Dict, Any, Tuple, List 9 | from transformers.optimization import AdamW, 
get_constant_schedule_with_warmup, get_linear_schedule_with_warmup 10 | from torch.nn import CrossEntropyLoss 11 | 12 | from torchmetrics import MeanMetric 13 | from pytorch_lightning import LightningModule 14 | 15 | from .gpt_util import get_gpt, sanity_check, left_pad_sequences 16 | from .gpt_seq2seq_model import GptSeq2SeqModel 17 | from execution.execution_evaluation import execution_acc, mathqa_execution 18 | from execution.execution_evaluation import execution_eval_at_k, batch_execution_acc 19 | from execution.program_tracing import exec_stmt_in_context, get_state_repr, is_trivial_state 20 | from execution.safe_execution_util import canonicalize_var_dict 21 | 22 | class GptStmtStateModel(GptSeq2SeqModel): 23 | def __init__(self, 24 | transformer_model_name: str, 25 | max_stmt_len: int = 20, 26 | max_stmt_num: int = 20, 27 | max_context_len: int = 1024, 28 | eval_with_states: bool = False, 29 | skip_trivial_states: bool = False, 30 | **kwargs) -> None: 31 | super().__init__(transformer_model_name, **kwargs) 32 | 33 | self.max_stmt_len = max_stmt_len 34 | self.max_stmt_num = max_stmt_num 35 | self.max_context_len = max_context_len 36 | self.eval_with_states = eval_with_states 37 | self.skip_trivial_states = skip_trivial_states 38 | 39 | def forward( # type: ignore 40 | self, 41 | input_ids: torch.Tensor, 42 | attention_mask: torch.Tensor, 43 | metadata: Optional[List[Dict[str, Any]]] = None, 44 | ) -> List[Dict[str, Any]]: 45 | """ 46 | The inference time behavior of the model. 47 | 48 | Args: 49 | input_ids [torch.Tensor]: Tokens from the context. 50 | metadata (Optional[List[Dict[str, Any]]], optional): All additional information, `List` for the batch. Defaults to None. 51 | 52 | Returns: 53 | Dict[str, Any]: results saved in a `Dict` object. 54 | """ 55 | output_dicts = [{"metadata": metadata[i]} for i in range(len(metadata))] 56 | 57 | # iteratively generate the stmts until eos is generated 58 | batch_size = len(metadata) 59 | completed_programs = [[] for _ in range(batch_size)] 60 | program_states_list = [[{}] for _ in range(batch_size)] 61 | incomplete_program_indices = list(range(batch_size)) 62 | for stmt_idx in range(self.max_stmt_num): 63 | # this is how many examples are left in the batch 64 | inner_batch_size = len(input_ids) 65 | 66 | max_context_len = input_ids.size(1) 67 | context_len_list = attention_mask.sum(dim=1) 68 | 69 | output_seqs = self.gpt.generate(input_ids=input_ids, attention_mask=attention_mask, 70 | do_sample=False, max_length=self.max_gen_len+max_context_len) # FIXME: this should be self.max_stmt_len 71 | # temperature=self.sampling_temp) # NOTE: now we assume only one seq is returned 72 | 73 | # remove the context and the tokens after the first newline token in the generated seq 74 | generated_seqs = [output_seqs[:, max_context_len:][i] for i in range(inner_batch_size)] 75 | for i, output_seq in enumerate(generated_seqs): 76 | nl_indices = (output_seq == self.tokenizer._convert_token_to_id(self.tokenizer.tokenize("\n")[0])).nonzero(as_tuple=True)[0] 77 | if len(nl_indices) > 0: 78 | first_nl_idx = int(nl_indices[0]) 79 | generated_seqs[i] = output_seq[:first_nl_idx+1] # +1 because we need to include the newline token 80 | else: 81 | # this means that the generation hits the max_stmt_len before the first newline token or an early 82 | # eos token is generated. Either way, we need to stop the generation, so we snap an (possibly 83 | # additional) eos token to the end of the generated seq. 
84 | generated_seqs[i] = torch.cat([output_seq, torch.tensor([self.tokenizer.eos_token_id], device=output_seq.device)]) 85 | 86 | # concat the context with the generated seqs 87 | # full_seqs = [] 88 | # for i in range(inner_batch_size): 89 | # full_seq = torch.cat([input_ids[i][:context_len_list[i]], generated_seqs[i]]) 90 | # full_seqs.append(full_seq) 91 | 92 | # check if the end_of_sequence token is in the generated output 93 | incomplete_output_seqs = [] 94 | incomplete_program_indices_new = [] 95 | for i, output_seq in enumerate(generated_seqs): 96 | eos_indices = (output_seq == self.tokenizer.eos_token_id).nonzero(as_tuple=True)[0] 97 | if len(eos_indices) > 0: # eos detected, end the stmt generation loop 98 | # cut off **at** the eos token and it's not added to the incomplete list (it's finished) 99 | first_eos_idx = int(eos_indices[0]) 100 | output_seq = output_seq[:first_eos_idx] 101 | completed_programs[incomplete_program_indices[i]].extend(output_seq) 102 | 103 | # get the last state 104 | if self.eval_with_states: 105 | stmt_str = self.tokenizer.decode(output_seq, skip_special_tokens=True) 106 | last_state_dict = program_states_list[incomplete_program_indices[i]][-1] 107 | output_state = exec_stmt_in_context(stmt_str, last_state_dict) 108 | program_states_list[incomplete_program_indices[i]].append(output_state) 109 | else: 110 | if self.eval_with_states: 111 | # the generation is not finished yet, we need to augment the next steps with state information 112 | stmt_str = self.tokenizer.decode(output_seq, skip_special_tokens=True) 113 | last_state_dict = program_states_list[incomplete_program_indices[i]][-1] 114 | output_state = exec_stmt_in_context(stmt_str, last_state_dict) 115 | program_states_list[incomplete_program_indices[i]].append(output_state) 116 | 117 | if output_state == None: 118 | # the program will be not successfully executed, so we remove the program from the incomplete list 119 | completed_programs[incomplete_program_indices[i]].extend(output_seq) 120 | else: 121 | # incorporate the state into the context 122 | state_str = get_state_repr(output_state, only_include_keys=[stmt_str.split(" ")[0]], 123 | prev_stmt=stmt_str, skip_trivial_states=self.skip_trivial_states) 124 | state_tensor = self.tokenizer.encode(state_str, add_special_tokens=False, return_tensors="pt")[0] \ 125 | .to(device=output_seq.device, dtype=output_seq.dtype) 126 | output_seq = torch.cat([output_seq, state_tensor]) 127 | 128 | incomplete_output_seqs.append(torch.cat([input_ids[i][-context_len_list[i]:], output_seq])) 129 | completed_programs[incomplete_program_indices[i]].extend(output_seq) 130 | incomplete_program_indices_new.append(incomplete_program_indices[i]) 131 | else: 132 | incomplete_output_seqs.append(torch.cat([input_ids[i][-context_len_list[i]:], output_seq])) 133 | completed_programs[incomplete_program_indices[i]].extend(output_seq) 134 | incomplete_program_indices_new.append(incomplete_program_indices[i]) 135 | 136 | incomplete_program_indices = incomplete_program_indices_new 137 | 138 | if len(incomplete_output_seqs) == 0: 139 | # all seqs have been completed by generating the eos token 140 | break 141 | elif stmt_idx == self.max_stmt_num - 1: 142 | # reach the max stmt num, but still not all the seqs are completed 143 | # for i, incomplete_cell in enumerate(incomplete_output_seqs): 144 | # completed_programs[incomplete_program_indices[i]].extend( 145 | # torch.tensor([self.tokenizer.eos_token_id], device=output_seq.device)) 146 | break 147 | 148 | # reformulate the input_ids 
and attention_mask with the newly generated output 149 | incomplete_output_seqs = [output_seq[-self.max_context_len:] for output_seq in incomplete_output_seqs] 150 | attention_mask_list = [torch.ones_like(incomplete_output_seq) for incomplete_output_seq in incomplete_output_seqs] 151 | # pad to the same length and stack to be the new input_ids 152 | input_ids = left_pad_sequences(incomplete_output_seqs, batch_first=True, 153 | padding_value=self.tokenizer.eos_token_id) 154 | attention_mask = left_pad_sequences(attention_mask_list, batch_first=True, padding_value=False) 155 | 156 | # completed_cells are accumlated stmts for each program, not including the original context; decode back to the strs 157 | generated_programs= self.tokenizer.batch_decode(completed_programs) 158 | 159 | for i in range(len(metadata)): 160 | output_dicts[i].update({"generated_program": generated_programs[i]}) 161 | output_dicts[i].update({"generated_program_state_list": program_states_list[i]}) 162 | 163 | return output_dicts 164 | 165 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]: 166 | input_ids = batch["input_ids"] 167 | attention_mask = batch["attention_mask"] 168 | labels = batch["labels"] if "labels" in batch else input_ids 169 | 170 | gpt_result = self.gpt(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 171 | self.log("loss", gpt_result.loss, on_step=True, on_epoch=True) 172 | 173 | if "state_mask" in batch: 174 | # log separately for loss on state tokens and non-state tokens 175 | state_mask = batch["state_mask"] 176 | 177 | # Shift so that tokens < n predict n 178 | shift_logits = gpt_result.logits[..., :-1, :].contiguous() 179 | shift_labels = labels[..., 1:].contiguous() 180 | shift_state_mask = state_mask[..., 1:].contiguous() 181 | 182 | # Flatten the tokens 183 | loss_fct = CrossEntropyLoss(reduction="none") 184 | unreduced_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 185 | 186 | code_token_loss = torch.sum(shift_state_mask.view(-1) * unreduced_loss) / torch.sum(shift_state_mask) 187 | state_token_loss = torch.sum((1 - shift_state_mask.view(-1)) * unreduced_loss) / torch.sum(1 - shift_state_mask) 188 | 189 | self.log("code_token_loss", code_token_loss, on_step=True, on_epoch=True) 190 | self.log("state_token_loss", state_token_loss, on_step=True, on_epoch=True) 191 | 192 | return {"loss": gpt_result.loss} 193 | 194 | def sanity_check_validation_step_end(self, outputs: List[Dict[str, Any]]) -> None: 195 | # update the evaluation metrics 196 | for output_dict in outputs: 197 | last_state_dict = output_dict["generated_program_state_list"][-1] 198 | if last_state_dict is not None and "answer" in last_state_dict: 199 | exec_rate = 1.0 200 | if last_state_dict["answer"] == output_dict["metadata"]["answer"]: 201 | exec_acc = 1.0 202 | else: 203 | exec_acc = 0.0 204 | else: 205 | exec_rate = 0.0 206 | exec_acc = 0.0 207 | output_dict.pop("generated_program_state_list") 208 | 209 | program_len = len(list(filter(lambda x: not x.startswith("#"), 210 | output_dict["generated_program"].split("\n")))) 211 | gold_program_len = len(list(filter(lambda x: not x.startswith("#"), output_dict["metadata"]["code"].split("\n")))) 212 | program_len_diff = program_len - gold_program_len 213 | 214 | self._num_metric_1(exec_acc) 215 | self._num_metric_2(exec_rate) 216 | 217 | self._num_metric_3(program_len_diff) 218 | 219 | output_dict["metrics"] = {"exec_acc": float(exec_acc), 220 | "exec_rate": float(exec_rate), 
221 | "program_len_diff": float(program_len_diff)} 222 | 223 | # save the outputs to the model 224 | self.predictions.extend(outputs) 225 | -------------------------------------------------------------------------------- /lightning_modules/models/gpt_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple, Optional, List, Union 3 | 4 | from transformers import GPTNeoForCausalLM, GPT2Tokenizer 5 | from transformers import PreTrainedModel, PreTrainedTokenizer, GPT2LMHeadModel 6 | from transformers import GPT2Tokenizer, GPTJForCausalLM 7 | 8 | def get_gpt(model_name: str, 9 | tokenizer_only: bool = False, 10 | gradient_ckpt: bool = False, 11 | additional_special_tokens: Optional[List[str]] = None) \ 12 | -> Tuple[PreTrainedModel, PreTrainedTokenizer]: 13 | if additional_special_tokens is None: 14 | additional_special_tokens = [] 15 | 16 | if not tokenizer_only: 17 | print(f"using pretrained model: {model_name}, gradient_ckpt: {gradient_ckpt}") 18 | 19 | if model_name == "microsoft/CodeGPT-small-py": 20 | tokenizer = GPT2Tokenizer.from_pretrained(model_name, additional_special_tokens=additional_special_tokens) 21 | if not tokenizer_only: 22 | model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id) 23 | if len(additional_special_tokens) > 0: 24 | model.resize_token_embeddings(len(tokenizer)) 25 | if model_name == "EleutherAI/gpt-j-6B": 26 | tokenizer = GPT2Tokenizer.from_pretrained(model_name) 27 | tokenizer.pad_token = tokenizer.eos_token 28 | 29 | if not tokenizer_only: 30 | model = GPTJForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, 31 | gradient_checkpointing=gradient_ckpt, use_cache=not gradient_ckpt) 32 | if len(additional_special_tokens) > 0: 33 | model.resize_token_embeddings(len(tokenizer)) 34 | elif model_name in ["EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-2.7B"]: 35 | tokenizer = GPT2Tokenizer.from_pretrained(model_name, additional_special_tokens=additional_special_tokens) 36 | tokenizer.pad_token = tokenizer.eos_token 37 | 38 | if not tokenizer_only: 39 | model = GPTNeoForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, 40 | gradient_checkpointing=gradient_ckpt, use_cache=not gradient_ckpt) 41 | if len(additional_special_tokens) > 0: 42 | model.resize_token_embeddings(len(tokenizer)) 43 | else: 44 | raise NotImplementedError 45 | 46 | if tokenizer_only: 47 | return None, tokenizer 48 | else: 49 | return model, tokenizer 50 | 51 | def left_pad_sequences(sequences: List[torch.Tensor], batch_first: bool = True, padding_value: Union[int, bool] = 0, 52 | max_len: int = -1, device: torch.device = None) -> torch.Tensor: 53 | assert all([len(seq.shape) == 1 for seq in sequences]) 54 | max_len = max_len if max_len > 0 else max(len(s) for s in sequences) 55 | device = device if device is not None else sequences[0].device 56 | 57 | padded_seqs = [] 58 | for seq in sequences: 59 | padded_seqs.append(torch.cat((torch.full((max_len - seq.shape[0],), padding_value, dtype=torch.long).to(device), seq))) 60 | return torch.stack(padded_seqs) 61 | 62 | def sanity_check(test_str: str, model, tokenizer): 63 | print(f"test str is: ###############{test_str}##############") 64 | 65 | input_ids = tokenizer.encode(test_str, add_special_tokens=False, return_tensors="pt").to(model.device) 66 | attention_mask = torch.where(input_ids == tokenizer.eos_token_id, torch.zeros_like(input_ids), 
torch.ones_like(input_ids)).to(model.device) 67 | 68 | output_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=40, num_return_sequences=1) 69 | 70 | output_str = tokenizer.decode(output_ids[0], skip_special_tokens=False, clean_up_tokenization_spaces=False) 71 | output_str_no_sp_tokens = tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) 72 | 73 | print(f"output str is: ###############{output_str}##############") 74 | 75 | new_test_str = " ".join(output_str_no_sp_tokens.split("\n")[:-1]) 76 | 77 | print(f"new test str is: ###############{new_test_str}###############") 78 | 79 | input_ids = tokenizer.encode(new_test_str, add_special_tokens=False, return_tensors="pt").to(model.device) 80 | attention_mask = torch.where(input_ids == tokenizer.eos_token_id, torch.zeros_like(input_ids), torch.ones_like(input_ids)).to(model.device) 81 | 82 | output_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=40, num_return_sequences=1) 83 | 84 | output_str = tokenizer.decode(output_ids[0], skip_special_tokens=False, clean_up_tokenization_spaces=False) 85 | 86 | print(f"new output str is: ###############{output_str}###############") 87 | -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/preprocessing/__init__.py -------------------------------------------------------------------------------- /preprocessing/preprocess_gsm8k.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import json 4 | import math 5 | 6 | from typing import List, Dict, Any 7 | from tqdm import tqdm 8 | 9 | from execution.execution_evaluation import mathqa_execution, execution_acc 10 | from tree_sitter import Language, Parser 11 | 12 | # initialize the parser for the code 13 | language_build_path = os.path.join(os.path.dirname(__file__), 'py-tree-sitter.so') 14 | PY_LANGUAGE = Language(language_build_path, 'python') 15 | parser = Parser() 16 | parser.set_language(PY_LANGUAGE) 17 | 18 | train_file = "./data/gsmath/train.jsonl" 19 | test_file = "./data/gsmath/test.jsonl" 20 | 21 | def get_answer_from_answer_str(answer_str: str) -> float: 22 | result_str = answer_str.split("\n")[-1].split(" ")[-1] 23 | result = float(result_str.replace(",", "")) 24 | 25 | return result 26 | 27 | def get_code_from_answer_str(answer_str: str, question_str: str) -> str: 28 | # reverse_var_dict only keeps the constants and the t_lines does not contain the constant inits 29 | reverse_var_dict: Dict[float, str] = {} 30 | reverse_temp_var_dict: Dict[float, str] = {} 31 | temp_var_num = 0 32 | t_lines = [] 33 | 34 | for line in answer_str.split("\n")[:-1]: 35 | if not ("<<" in line and ">>" in line): 36 | continue 37 | 38 | # first extract the formula 39 | formula = line[line.index("<<") + 2: line.index(">>")] 40 | 41 | def get_var_name(var_str: str, allow_new: bool = True) -> str: 42 | num = float(var_str) 43 | if num in reverse_temp_var_dict: 44 | var_name = reverse_temp_var_dict[num] 45 | elif num in reverse_var_dict: 46 | var_name = reverse_var_dict[num] 47 | elif allow_new: 48 | # a new constant 49 | var_name = f"n{len(reverse_var_dict)}" 50 | reverse_var_dict[num] = var_name 51 | else: 52 | raise ValueError(f"{var_str} not found in var/temp dict") 53 | 54 | return 
var_name 55 | 56 | def get_node_text(node, text) -> str: 57 | return text[node.start_byte: node.end_byte] 58 | 59 | # make sure that the formula is valid 60 | expression, result = formula.split("=") 61 | if "/" in result: 62 | result = eval(result) 63 | if not eval(expression) == float(result): 64 | return "NULL" 65 | 66 | # interpret the formula with a parse tree 67 | assert expression.isascii, f"{expression} is not ascii" 68 | parsed_tree = parser.parse(bytes(expression, 'utf-8')) 69 | 70 | # do a dfs on the parsed tree to get the values replaced with names 71 | formula_bits = [] 72 | node_stack = [parsed_tree.root_node.children[0].children[0]] 73 | while len(node_stack) > 0: 74 | node = node_stack.pop() 75 | 76 | if node.type in ["integer", "float"]: 77 | var_name = get_var_name(get_node_text(node, expression)) 78 | formula_bits.append(var_name) 79 | elif node.type in ["+", "-", "*", "/", "**", "(", ")", "//"]: 80 | formula_bits.append(get_node_text(node, expression)) 81 | elif node.type in ["binary_operator", "parenthesized_expression"]: 82 | node_stack.extend(node.children[::-1]) 83 | elif node.type == "unary_operator": 84 | if node.children[0].type == "+": 85 | var_name = get_var_name(get_node_text(node, expression)) 86 | formula_bits.append(var_name) 87 | elif node.children[0].type == "-": 88 | val = -float(get_node_text(node, expression)) 89 | if val in reverse_temp_var_dict or val in reverse_var_dict: 90 | formula_bits.append(get_var_name(val, allow_new=False)) 91 | elif -val in reverse_temp_var_dict or val in reverse_var_dict: 92 | formula_bits.append("-"+get_var_name(-val, allow_new=False)) 93 | else: 94 | formula_bits.append(get_var_name(val, allow_new=True)) 95 | else: 96 | raise ValueError(f"{expression} has unary operator {node.children[0].type}") 97 | else: 98 | raise ValueError(f"{expression} has {node.type}") 99 | 100 | right_formula = "".join(formula_bits) 101 | 102 | # add the temporary var 103 | # NOTE: we can't use the len(reverse_temp_var_dict) because we may have the same temp var in different lines 104 | temp_var_name = f"t{temp_var_num}" 105 | temp_var_num += 1 106 | reverse_temp_var_dict[float(result)] = temp_var_name 107 | 108 | # add the line 109 | t_lines.append(f"{temp_var_name}={right_formula}") 110 | 111 | # add the const var inits 112 | init_lines = [] 113 | sorted_var_dict = sorted(reverse_var_dict.items(), key=lambda x: int(x[1][1:])) 114 | for var_val, var_name in sorted_var_dict: 115 | # if the float var is not directly used, and it can be casted as int, do cast as init 116 | if not str(var_val) in question_str and math.isclose(int(var_val), var_val, abs_tol=1e-4): 117 | init_lines.append(f"{var_name}={int(var_val)}") 118 | else: 119 | init_lines.append(f"{var_name}={var_val}") 120 | 121 | 122 | if len(t_lines) == 0: 123 | # no <> are given for this example, simply skip 124 | return "NULL" 125 | 126 | # replace the last line's temp var name with "answer" 127 | t_lines[-1] = "answer=" + t_lines[-1].split("=")[1] 128 | 129 | return "\n".join(init_lines + t_lines) 130 | 131 | def verify_code(code: str, gold_answer: str) -> bool: 132 | try: 133 | exec(code) 134 | if float(gold_answer) == float(eval("answer")): 135 | return True 136 | else: 137 | return False 138 | except Exception as e: 139 | return False 140 | 141 | def process_gsmath(instances: List[Dict[str, str]], set_name: str) -> List[Dict[str, Any]]: 142 | failed_code_extraction_indices = [] 143 | for i, instance in tqdm(enumerate(instances)): 144 | # put it in the mathqa style: text, code, answer, 
task_id 145 | instance["text"] = instance["question"] 146 | instance.pop("question") 147 | 148 | instance["original_answer"] = instance["answer"] 149 | instance["task_id"] = f"{set_name}_{i}" 150 | 151 | instance["code"] = get_code_from_answer_str(instance["original_answer"], instance["text"]) 152 | instance["answer"] = get_answer_from_answer_str(instance["original_answer"]) 153 | 154 | if instance["code"] == "NULL": 155 | # failed to extract code, will skip this example in training, and only record for dev/test 156 | failed_code_extraction_indices.append(i) 157 | 158 | # verify the validity of the code 159 | failed_code_execution_indices = [] 160 | for i, instance in enumerate(instances): 161 | if i in failed_code_extraction_indices: 162 | continue 163 | 164 | if not verify_code(instance["code"], instance["answer"]): 165 | failed_code_execution_indices.append(i) 166 | # print(f"{instance['task_id']} failed to verify, " \ 167 | # f"original_answer: {instance['original_answer']}, " \ 168 | # f"code: \n{instance['code']}\nanswer: {instance['answer']}") 169 | 170 | all_failed_indices = sorted(failed_code_extraction_indices + failed_code_execution_indices) 171 | 172 | print(f"{len(failed_code_extraction_indices)}/{len(instances)} failed to extract code") 173 | print(f"{len(failed_code_execution_indices)}/{len(instances)} failed to execute to the correct result") 174 | print(f"{len(all_failed_indices)}/{len(instances)} failed in total") 175 | 176 | # remove the failed examples if this is training set 177 | if set_name == "train": 178 | for i in all_failed_indices[::-1]: 179 | instances.pop(i) 180 | 181 | return instances 182 | 183 | def main(): 184 | # load the train and test data 185 | with open(train_file, "r") as f: 186 | train_lines = f.readlines() 187 | train_data = [json.loads(line) for line in train_lines] 188 | 189 | with open(test_file, "r") as f: 190 | test_lines = f.readlines() 191 | test_data = [json.loads(line) for line in test_lines] 192 | 193 | # split the train data to train and dev 194 | train_data, dev_data = train_data[:int(len(train_data) * 0.8)], train_data[int(len(train_data) * 0.8):] 195 | 196 | # process all the data 197 | processed_train_data = process_gsmath(train_data, "train") 198 | processed_dev_data = process_gsmath(dev_data, "val") 199 | processed_test_data = process_gsmath(test_data, "test") 200 | 201 | # write the processed data to files 202 | with open("./data/gsmath/gsmath_train.jsonl", "w") as f: 203 | f.write("\n".join([json.dumps(data) for data in processed_train_data])) 204 | with open("./data/gsmath/gsmath_val.jsonl", "w") as f: 205 | f.write("\n".join([json.dumps(data) for data in processed_dev_data])) 206 | with open("./data/gsmath/gsmath_test.jsonl", "w") as f: 207 | f.write("\n".join([json.dumps(data) for data in processed_test_data])) 208 | 209 | def prune_gsmath(file_name: str) -> None: 210 | assert file_name.endswith(".jsonl") 211 | 212 | # load the data 213 | with open(file_name, "r") as f: 214 | train_lines = f.readlines() 215 | instances = [json.loads(line) for line in train_lines] 216 | 217 | failed_code_extraction_indices = [] 218 | for i, instance in tqdm(enumerate(instances)): 219 | if instance["code"] == "NULL": 220 | # failed to extract code, will skip this example in training, and only record for dev/test 221 | failed_code_extraction_indices.append(i) 222 | 223 | # verify the validity of the code 224 | failed_code_execution_indices = [] 225 | for i, instance in enumerate(instances): 226 | if i in failed_code_extraction_indices: 227 | 
continue 228 | 229 | if not verify_code(instance["code"], instance["answer"]): 230 | failed_code_execution_indices.append(i) 231 | 232 | all_failed_indices = sorted(failed_code_extraction_indices + failed_code_execution_indices) 233 | 234 | print(f"{len(failed_code_extraction_indices)}/{len(instances)} failed to extract code") 235 | print(f"{len(failed_code_execution_indices)}/{len(instances)} failed to execute to the correct result") 236 | print(f"{len(all_failed_indices)}/{len(instances)} failed in total") 237 | 238 | # remove the failed examples if this is training set 239 | for i in all_failed_indices[::-1]: 240 | instances.pop(i) 241 | 242 | with open(f"{file_name[:-6]}_pruned.jsonl", "w") as f: 243 | f.write("\n".join([json.dumps(ins) for ins in instances])) 244 | 245 | if __name__ == "__main__": 246 | # prune_gsmath("./data/gsmath/gsmath_val.jsonl") 247 | main() -------------------------------------------------------------------------------- /preprocessing/preprocess_mathqa_python.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from typing import List, Dict, Tuple, Any, Union 4 | 5 | def read_jsonl_file(file_path: str) -> List[Dict[str, Any]]: 6 | with open(file_path, "r") as f: 7 | examples = [json.loads(line) for line in f] 8 | 9 | return examples 10 | 11 | def create_resplit_index(): 12 | original_train_examples = read_jsonl_file("data/mathqa/train-python.jsonl") 13 | original_dev_examples = read_jsonl_file("data/mathqa/val-python.jsonl") 14 | original_all_examples = original_train_examples + original_dev_examples 15 | 16 | text_to_example_idx = {x["text"].strip(): i for i, x in enumerate(original_all_examples)} 17 | 18 | dedup_train_examples = read_jsonl_file("data/mathqa/train_dedup.jsonl") 19 | dedup_dev_examples = read_jsonl_file("data/mathqa/val_dedup.jsonl") 20 | 21 | dedup_train_idx = [text_to_example_idx[x["text"].strip()] for x in dedup_train_examples] 22 | dedup_dev_idx = [text_to_example_idx[x["text"].strip()] for x in dedup_dev_examples] 23 | 24 | with open("preprocessing/mathqa_python_resplit_info.json", "w+") as f: 25 | info = { 26 | "train_first_examples": dedup_train_examples[:5], 27 | "dev_first_examples": dedup_dev_examples[:5], 28 | "train_idx": dedup_train_idx, 29 | "dev_idx": dedup_dev_idx 30 | } 31 | json.dump(info, f) 32 | 33 | def recreate_split(): 34 | # load the precomputed resplit info 35 | with open("preprocessing/mathqa_python_resplit_info.json", "r") as f: 36 | resplit_info = json.load(f) 37 | 38 | # load the original examples 39 | original_train_examples = read_jsonl_file("data/mathqa/train-python.jsonl") 40 | original_dev_examples = read_jsonl_file("data/mathqa/val-python.jsonl") 41 | original_all_examples = original_train_examples + original_dev_examples 42 | 43 | # recreate the split using the resplit info 44 | dedup_train_examples = [original_all_examples[i] for i in resplit_info["train_idx"]] 45 | dedup_dev_examples = [original_all_examples[i] for i in resplit_info["dev_idx"]] 46 | 47 | # rename the task ids 48 | for i, instance in enumerate(dedup_train_examples): 49 | instance['task_id'] = f"train_{i}" 50 | for i, instance in enumerate(dedup_dev_examples): 51 | instance['task_id'] = f"val_{i}" 52 | 53 | # verify that the split is correct 54 | assert all([x[0]["text"] == x[1]["text"] for x in zip(dedup_train_examples[:5], resplit_info["train_first_examples"])]), "train split is incorrect" 55 | assert all([x[0]["text"] == x[1]["text"] for x in zip(dedup_dev_examples[:5], 
resplit_info["dev_first_examples"])]), "dev split is incorrect" 56 | 57 | # write the recreated split to disk 58 | with open("data/mathqa/train_dedup.jsonl", "w+") as f: 59 | for example in dedup_train_examples: 60 | f.write(json.dumps(example) + "\n") 61 | 62 | with open("data/mathqa/val_dedup.jsonl", "w+") as f: 63 | for example in dedup_dev_examples: 64 | f.write(json.dumps(example) + "\n") 65 | 66 | def main(): 67 | # create_resplit_index() 68 | recreate_split() 69 | 70 | if __name__ == "__main__": 71 | main() -------------------------------------------------------------------------------- /preprocessing/py-tree-sitter.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/preprocessing/py-tree-sitter.so -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # azureml-sdk == 1.26.0 2 | tensorboard ~= 2.4 3 | huggingface-hub == 0.1.2 4 | tree-sitter ~= 0.19.0 5 | torchmetrics == 0.6.0 6 | 7 | transformers == 4.16.2 8 | torch == 1.10.2 9 | pytorch-lightning == 1.5.10 10 | deepspeed ~= 0.5.10 11 | 12 | # to avoid getting an error of "fit.optimizer" conflict namespace error 13 | jsonargparse == 3.19.4 14 | 15 | nltk 16 | rouge-score 17 | jsonargparse[signatures] 18 | overrides 19 | neptune-client 20 | scipy 21 | editdistance 22 | bashplotlib 23 | wandb 24 | astunparse 25 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningModule, LightningDataModule 2 | from pytorch_lightning.utilities.cli import LightningCLI 3 | 4 | # see https://github.com/PyTorchLightning/pytorch-lightning/issues/10349 5 | import warnings 6 | 7 | warnings.filterwarnings( 8 | "ignore", ".*Trying to infer the `batch_size` from an ambiguous collection.*" 9 | ) 10 | 11 | cli = LightningCLI(LightningModule, LightningDataModule, 12 | subclass_mode_model=True, subclass_mode_data=True, 13 | save_config_callback=None) 14 | -------------------------------------------------------------------------------- /training_configs/gpt_mle.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 333 2 | trainer: 3 | gpus: 2 4 | gradient_clip_val: 1.0 5 | default_root_dir: debug-tmp 6 | # val_check_interval: 1.0 7 | max_steps: &max_steps 50000 8 | check_val_every_n_epoch: 2 9 | log_every_n_steps: 1 10 | num_sanity_val_steps: 0 11 | logger: 12 | - class_path: lightning_modules.loggers.patched_loggers.PatchedWandbLogger 13 | init_args: 14 | entity: niansong1996 15 | project: trace-codegen 16 | name: debug-tmp 17 | log_model: False 18 | save_code: True 19 | offline: False 20 | callbacks: 21 | - class_path: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 22 | init_args: 23 | monitor: exec_acc 24 | mode: max 25 | filename: '{step}-{exec_acc:.4f}-{exec_rate:.4f}' 26 | save_top_k: 5 27 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 28 | init_args: 29 | logging_interval: step 30 | - class_path: pytorch_lightning.callbacks.progress.TQDMProgressBar 31 | init_args: 32 | refresh_rate: 1 33 | 34 | accelerator: gpu 35 | # replace_sampler_ddp: False 36 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/8262 37 | strategy: deepspeed_stage_2 38 | 
# strategy: ddp_find_unused_parameters_false 39 | # precision: 16 40 | # accumulate_grad_batches: 4 41 | 42 | model: 43 | class_path: lightning_modules.models.gpt_seq2seq_model.GptSeq2SeqModel 44 | init_args: 45 | transformer_model_name: &transformer EleutherAI/gpt-neo-2.7B 46 | max_gen_len: 256 47 | sampling_temp: 0.2 48 | sampling_temp_at_k: 0.8 49 | # pass_at_k: 80 50 | # eval_pass_at_k_every_n_epochs: 4 51 | # additional_pass_at_k: [5, 10, 20, 50] 52 | # max_generation_batches: 50 53 | # always_eval_pass_at_k_first_n_epochs: 10 54 | gradient_ckpt: true 55 | measure_dedup_metrics: false 56 | # eval_greedy_search: true 57 | # load_ckpt_file: /home/v-ansongni/data/trace-codegen-data-wuphillyblob/data/mathqa/model_ckpts/step=6904-exec_acc=0.4480-exec_rate=0.6445.ckpt 58 | # load_ckpt_file: /home/v-ansongni/Code/trace-codegen/amlt/mathqa-finetune-gpt-neo-125M-pad-left/gpt-neo-mathqa-finetuning/lightning_logs/version_0/checkpoints/step=54044-exec_acc=0.7715-exec_rate=0.9893.ckpt 59 | optimizer: 60 | class_path: torch.optim.adamw.AdamW 61 | init_args: 62 | lr: 1.0e-4 63 | # lr: 0.0 64 | betas: 65 | - 0.9 66 | - 0.999 67 | eps: 1.0e-8 68 | weight_decay: 0.1 69 | lr_scheduler: 70 | name: linear 71 | init_args: 72 | num_warmup_steps: 100 73 | num_training_steps: *max_steps 74 | 75 | data: 76 | class_path: lightning_modules.datasets.mathqa_line_reader.MathQADataModule 77 | init_args: 78 | transformer_model_name: *transformer 79 | batch_size: 2 80 | val_batch_size: 4 81 | # train_file_path: data/mathqa/train_dedup.jsonl 82 | # val_file_path: data/mathqa/val_dedup.jsonl 83 | train_file_path: data/gsmath/gsmath_train_val.jsonl 84 | val_file_path: data/gsmath/gsmath_test.jsonl 85 | # train_max_instances: 40 86 | # val_max_instances: 20 87 | # few_shot_n: 4 -------------------------------------------------------------------------------- /training_configs/gpt_self_sampling.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 333 2 | trainer: 3 | gpus: 1 4 | gradient_clip_val: 1.0 5 | default_root_dir: debug-tmp 6 | # val_check_interval: 1.0 7 | max_steps: &max_steps 50000 8 | check_val_every_n_epoch: 2 9 | log_every_n_steps: 1 10 | num_sanity_val_steps: 0 11 | logger: 12 | - class_path: lightning_modules.loggers.patched_loggers.PatchedWandbLogger 13 | init_args: 14 | entity: niansong1996 15 | project: trace-codegen 16 | name: debug-tmp 17 | log_model: False 18 | save_code: True 19 | offline: False 20 | callbacks: 21 | - class_path: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 22 | init_args: 23 | monitor: exec_acc 24 | mode: max 25 | filename: '{step}-{exec_acc:.4f}-{exec_rate:.4f}' 26 | save_top_k: 5 27 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 28 | init_args: 29 | logging_interval: step 30 | - class_path: pytorch_lightning.callbacks.progress.TQDMProgressBar 31 | init_args: 32 | refresh_rate: 1 33 | 34 | accelerator: gpu 35 | # replace_sampler_ddp: False 36 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/8262 37 | strategy: deepspeed_stage_2_offload 38 | # strategy: ddp_find_unused_parameters_false 39 | # precision: 16 40 | # accumulate_grad_batches: 4 41 | model: 42 | class_path: lightning_modules.models.gpt_stmt_mml_model.GptStmtMmlModel 43 | init_args: 44 | transformer_model_name: &transformer EleutherAI/gpt-neo-2.7B 45 | max_gen_len: 256 46 | max_sampling_len: 100 47 | sampling_temp: 0.2 48 | on_policy_sample_num: 1 49 | on_policy_sample_temp: 0.8 50 | sampling_temp_at_k: 0.8 51 | # 
pass_at_k: 80 52 | # additional_pass_at_k: [5, 10, 20, 50] 53 | # eval_pass_at_k_every_n_epochs: 1 54 | # max_generation_batches: 50 55 | gradient_ckpt: true 56 | measure_dedup_metrics: false 57 | length_diff_tolerance: 0 58 | exclude_context_loss: false 59 | mle_lambda: 1.0 60 | mml_lambda: 0.0 61 | # beta_smoothing: 0.25 62 | # marg_set_size: 10 63 | # max_buffer_size: 10 64 | # eval_greedy_search: true 65 | # load_samples_file: /home/v-ansongni/Code/trace-codegen/mathqa_dedup_train_samples.jsonl 66 | # load_ckpt_file: /home/v-ansongni/Code/trace-codegen/amlt/mathqa-state-finetuning-125M-line-state-all-mask-ctx-len-1648/gpt-neo-mathqa-state-finetuning/lightning_logs/version_0/checkpoints/step=46175-exec_acc=0.6650-exec_rate=0.9956.ckpt 67 | # load_ckpt_file: /home/v-ansongni/Code/trace-codegen/amlt/mathqa-dedup-partial-mml-len_norm-low-temp/gpt-neo-mathqa-state-finetuning/lightning_logs/version_0/checkpoints/step=55755-exec_acc=0.1263-exec_rate=0.9727.ckpt 68 | # load_ckpt_file: /home/v-ansongni/Code/trace-codegen/amlt/mathqa-125M-state-skip-ts/gpt-neo-mathqa-state-finetuning/lightning_logs/version_0/checkpoints/step=49727-exec_acc=0.7559-exec_rate=0.9956.ckpt 69 | # load_ckpt_file: /home/v-ansongni/Code/trace-codegen/amlt/mathqa-finetune-gpt-neo-125M-pad-left/gpt-neo-mathqa-finetuning/lightning_logs/version_0/checkpoints/step=54044-exec_acc=0.7715-exec_rate=0.9893.ckpt 70 | optimizer: 71 | class_path: torch.optim.adamw.AdamW 72 | init_args: 73 | lr: 1.0e-4 74 | # lr: 0.0 75 | betas: 76 | - 0.9 77 | - 0.999 78 | eps: 1.0e-8 79 | weight_decay: 0.1 80 | lr_scheduler: 81 | name: linear 82 | init_args: 83 | num_warmup_steps: 100 84 | num_training_steps: *max_steps 85 | 86 | data: 87 | class_path: lightning_modules.datasets.mathqa_line_reader.MathQAMmlDataModule 88 | init_args: 89 | transformer_model_name: *transformer 90 | batch_size: 2 91 | val_batch_size: 4 92 | train_file_path: data/mathqa/train_dedup.jsonl 93 | val_file_path: data/mathqa/val_dedup.jsonl 94 | # train_max_instances: 40 95 | # val_max_instances: 20 96 | # few_shot_n: 4 -------------------------------------------------------------------------------- /training_configs/gpt_self_sampling_partial.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 333 2 | trainer: 3 | gpus: 1 4 | gradient_clip_val: 1.0 5 | default_root_dir: debug-tmp 6 | # val_check_interval: 1.0 7 | max_steps: &max_steps 50000 8 | check_val_every_n_epoch: 2 9 | log_every_n_steps: 1 10 | num_sanity_val_steps: 0 11 | logger: 12 | - class_path: lightning_modules.loggers.patched_loggers.PatchedWandbLogger 13 | init_args: 14 | entity: niansong1996 15 | project: trace-codegen 16 | name: debug-tmp 17 | log_model: False 18 | save_code: True 19 | offline: False 20 | callbacks: 21 | - class_path: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 22 | init_args: 23 | monitor: exec_acc 24 | mode: max 25 | filename: '{step}-{exec_acc:.4f}-{exec_rate:.4f}' 26 | save_top_k: 5 27 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 28 | init_args: 29 | logging_interval: step 30 | - class_path: pytorch_lightning.callbacks.progress.TQDMProgressBar 31 | init_args: 32 | refresh_rate: 1 33 | 34 | accelerator: gpu 35 | # replace_sampler_ddp: False 36 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/8262 37 | # strategy: deepspeed_stage_2 38 | strategy: ddp_find_unused_parameters_false 39 | # precision: 16 40 | # accumulate_grad_batches: 4 41 | 42 | model: 43 | class_path: 
lightning_modules.models.gpt_stmt_partial_mml_model.GptStmtPartialMmlModel 44 | init_args: 45 | transformer_model_name: &transformer EleutherAI/gpt-neo-2.7B 46 | max_gen_len: 256 47 | max_sampling_len: 100 48 | sampling_temp: 0.2 49 | on_policy_sample_num: 1 50 | on_policy_sample_temp: 0.8 51 | sampling_temp_at_k: 0.8 52 | # pass_at_k: 80 53 | # eval_pass_at_k_every_n_epochs: 1 54 | # max_generation_batches: 50 55 | # additional_pass_at_k: [5, 10, 20, 50] 56 | gradient_ckpt: true 57 | measure_dedup_metrics: false 58 | length_diff_tolerance: 0 59 | sampling_from_states: false 60 | mle_lambda: 1.0 61 | mml_lambda: 0.0 62 | # beta_smoothing: 0.25 63 | # containment_based_pc: true 64 | # sampling_full_prog_only: true 65 | # norm_marg_by_len: true 66 | # fcp_only: true 67 | # gold_program_only: true 68 | exclude_context_loss: false 69 | # prioritize_fcp: false 70 | # marg_set_size: 10 71 | # max_buffer_size: 10 72 | # eval_greedy_search: true 73 | optimizer: 74 | class_path: torch.optim.adamw.AdamW 75 | init_args: 76 | lr: 1.0e-4 77 | # lr: 0.0 78 | betas: 79 | - 0.9 80 | - 0.999 81 | eps: 1.0e-8 82 | weight_decay: 0.1 83 | lr_scheduler: 84 | name: linear 85 | init_args: 86 | num_warmup_steps: 100 87 | num_training_steps: *max_steps 88 | 89 | data: 90 | class_path: lightning_modules.datasets.mathqa_line_reader.MathQAMmlDataModule 91 | init_args: 92 | transformer_model_name: *transformer 93 | batch_size: 2 94 | val_batch_size: 4 95 | train_file_path: data/mathqa/train_dedup.jsonl 96 | val_file_path: data/mathqa/val_dedup.jsonl 97 | train_max_instances: 40 98 | val_max_instances: 20 99 | # few_shot_n: 4 100 | --------------------------------------------------------------------------------
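A minimal sketch of the left-padding behavior of left_pad_sequences from lightning_modules/models/gpt_util.py shown above: shorter sequences are padded on the left so that newly generated tokens can always be appended on the right. The token-id tensors here are made-up toy inputs, not taken from the repository.

import torch
from lightning_modules.models.gpt_util import left_pad_sequences

# two made-up token-id sequences of different lengths
seqs = [torch.tensor([11, 12, 13]), torch.tensor([21])]
padded = left_pad_sequences(seqs, batch_first=True, padding_value=0)
# padded is:
# tensor([[11, 12, 13],
#         [ 0,  0, 21]])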
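For a concrete sense of what preprocessing/preprocess_gsm8k.py produces, here is a hypothetical GSM8K-style record (the question and answer strings are invented for illustration) traced through get_answer_from_answer_str and get_code_from_answer_str; running this assumes the bundled preprocessing/py-tree-sitter.so is loadable on the current platform and that the script is run from the repository root.

from preprocessing.preprocess_gsm8k import (
    get_answer_from_answer_str, get_code_from_answer_str, verify_code)

# invented example in the GSM8K annotation format (<<...>> formula markup, "#### answer" last line)
question = "A farmer has 3 baskets with 4 apples in each basket. How many apples are there?"
answer_str = "He has 3*4=<<3*4=12>>12 apples.\n#### 12"

answer = get_answer_from_answer_str(answer_str)        # last token of the last line -> 12.0
code = get_code_from_answer_str(answer_str, question)
# expected code, given how constants (n*) and temporaries (t*) are named
# and how the final temporary is renamed to `answer`:
#   n0=3
#   n1=4
#   answer=n0*n1
assert verify_code(code, answer)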
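Finally, a sketch of how a training run might be launched with trainer.py and one of the configs under training_configs/. The `fit` subcommand syntax assumes the LightningCLI behavior of the pinned pytorch-lightning 1.5.x and may differ in other versions; the shell equivalent would be `python trainer.py fit --config training_configs/gpt_mle.yaml`.

import sys
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.utilities.cli import LightningCLI

# equivalent to: python trainer.py fit --config training_configs/gpt_mle.yaml
sys.argv = ["trainer.py", "fit", "--config", "training_configs/gpt_mle.yaml"]
LightningCLI(LightningModule, LightningDataModule,
             subclass_mode_model=True, subclass_mode_data=True,
             save_config_callback=None)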