├── .github
│   └── workflows
│       ├── codeql.yml
│       └── dependency-review.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── example.png
├── execution
│   ├── __init__.py
│   ├── execution_evaluation.py
│   ├── program_tracing.py
│   └── safe_execution_util.py
├── lightning_modules
│   ├── __init__.py
│   ├── callbacks
│   │   ├── __init__.py
│   │   └── save_prediction_callback.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── mathqa_line_reader.py
│   │   └── reader_utils.py
│   ├── loggers
│   │   ├── __init__.py
│   │   └── patched_loggers.py
│   └── models
│       ├── __init__.py
│       ├── gpt_seq2seq_model.py
│       ├── gpt_stmt_mml_model.py
│       ├── gpt_stmt_partial_mml_model.py
│       ├── gpt_stmt_state_model.py
│       └── gpt_util.py
├── preprocessing
│   ├── __init__.py
│   ├── mathqa_python_resplit_info.json
│   ├── preprocess_gsm8k.py
│   ├── preprocess_mathqa_python.py
│   └── py-tree-sitter.so
├── requirements.txt
├── trainer.py
└── training_configs
    ├── gpt_mle.yaml
    ├── gpt_self_sampling.yaml
    └── gpt_self_sampling_partial.yaml
/.github/workflows/codeql.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 |
14 | on:
15 | push:
16 | branches: [ "main" ]
17 | pull_request:
18 | # The branches below must be a subset of the branches above
19 | branches: [ "main" ]
20 | schedule:
21 | - cron: '15 7 * * 0'
22 |
23 | jobs:
24 | analyze:
25 | name: Analyze
26 | runs-on: ubuntu-latest
27 | permissions:
28 | actions: read
29 | contents: read
30 | security-events: write
31 |
32 | strategy:
33 | fail-fast: false
34 | matrix:
35 | language: [ 'python' ]
36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
38 |
39 | steps:
40 | - name: Checkout repository
41 | uses: actions/checkout@v3
42 |
43 | # Initializes the CodeQL tools for scanning.
44 | - name: Initialize CodeQL
45 | uses: github/codeql-action/init@v2
46 | with:
47 | languages: ${{ matrix.language }}
48 | # If you wish to specify custom queries, you can do so here or in a config file.
49 | # By default, queries listed here will override any specified in a config file.
50 | # Prefix the list here with "+" to use these queries and those in the config file.
51 |
52 | # For details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
53 | # queries: security-extended,security-and-quality
54 |
55 |
56 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
57 | # If this step fails, then you should remove it and run the build manually (see below)
58 | - name: Autobuild
59 | uses: github/codeql-action/autobuild@v2
60 |
61 | # ℹ️ Command-line programs to run using the OS shell.
62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
63 |
64 | # If the Autobuild step above fails, remove it and uncomment the following three lines.
65 | # Modify them (or add more) to build your code; please refer to the EXAMPLE below for guidance.
66 |
67 | # - run: |
68 | # echo "Run, Build Application using script"
69 | # ./location_of_script_within_repo/buildscript.sh
70 |
71 | - name: Perform CodeQL Analysis
72 | uses: github/codeql-action/analyze@v2
73 | with:
74 | category: "/language:${{matrix.language}}"
75 |
--------------------------------------------------------------------------------
/.github/workflows/dependency-review.yml:
--------------------------------------------------------------------------------
1 | # Dependency Review Action
2 | #
3 | # This Action will scan dependency manifest files that change as part of a Pull Request, surfacing known-vulnerable versions of the packages declared or updated in the PR. Once installed, if the workflow run is marked as required, PRs introducing known-vulnerable packages will be blocked from merging.
4 | #
5 | # Source repository: https://github.com/actions/dependency-review-action
6 | # Public documentation: https://docs.github.com/en/code-security/supply-chain-security/understanding-your-software-supply-chain/about-dependency-review#dependency-review-enforcement
7 | name: 'Dependency Review'
8 | on: [pull_request]
9 |
10 | permissions:
11 | contents: read
12 |
13 | jobs:
14 | dependency-review:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - name: 'Checkout Repository'
18 | uses: actions/checkout@v3
19 | - name: 'Dependency Review'
20 | uses: actions/dependency-review-action@v2
21 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # defined by Ansong
132 | data/
133 | results/
134 | wandb/
135 | debug-tmp/
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Math Reasoning from Self-Sampled Correct and Partially-Correct Solutions (ICLR'23)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | Code for the paper [Learning Math Reasoning from Self-Sampled Correct and Partially-Correct Solutions](https://arxiv.org/abs/2205.14318). In this work, we propose to let the model perform sampling during training and learn from self-sampled correct or partially-correct solutions, which are automatically identified by comparing their final or intermediate execution states. An example is shown below.
16 |
17 |
18 | ![example](example.png)
19 |
20 | ## Updates
21 | - **2023-03-08**: Initial code release
22 | - **2023-02-17**: Camera-ready version updated on [arxiv](https://arxiv.org/abs/2205.14318)
23 | - **2023-01-20**: Paper is accepted at ICLR 2023
24 |
25 | ## Environment Setup
26 | > **Note: all of the following has only been tested on Linux machines; you may need to build your own `tree-sitter` parser if you are using a different platform.**
27 |
28 | *(Recommended)* Create a new conda environment
29 | ```bash
30 | conda create -n trace-codegen python=3.8
31 | conda activate trace-codegen
32 | ```
33 | Clone the code and install the dependencies
34 | ```bash
35 | git clone git@github.com:microsoft/TraceCodegen
36 | cd TraceCodegen
37 | pip install -r requirements.txt
38 | ```
39 | *(Optional)* Set up `wandb` for experiment tracking. First follow the [wandb documentation](https://docs.wandb.ai/ref/cli/wandb-login) to log in, then fill in the following fields under `trainer.logger+` in the `yaml` config file you would like to run:
40 | ```yaml
41 | entity: <your_wandb_entity>
42 | project: <your_wandb_project>
43 | ```
44 | *(Optional)* At any point, if you run into Python import problems (e.g., `ModuleNotFoundError`), try running the following in the main (`TraceCodegen`) directory:
45 | ```bash
46 | export PYTHONPATH=`pwd`
47 | ```
48 |
49 | ## Data and Preprocessing
50 | We conduct experiments on the [MathQA-Python](https://arxiv.org/abs/2108.07732) and [GSM8k](https://github.com/openai/grade-school-math) datasets. As they have different licenses and preprocessing pipelines, here we describe them separately. But first, let's make a `data` directory:
51 | ```bash
52 | mkdir data
53 | ```
54 | ### MathQA-Python
55 | First follow [this script](https://github.com/google/trax/blob/master/trax/examples/MathQA_Python_generation_notebook.ipynb) to generate the MathQA-Python dataset from the [original MathQA dataset](https://math-qa.github.io/). After that, make sure your data directory looks something like this:
56 | ```
57 | data
58 | |-- mathqa
59 | | |-- train-python.jsonl
60 | | |-- val-python.jsonl
61 | | |-- test-python.jsonl
62 | |---...
63 | ```
64 | We preprocess MathQA-Python by resplitting the data with template-based deduplication (see details in the paper). To do this, run the preprocessing script:
65 | ```bash
66 | python preprocessing/preprocess_mathqa_python.py
67 | ```
68 | After this, your `data` directory should now look something like this:
69 | ```
70 | data
71 | |-- mathqa
72 | | |-- train-python.jsonl
73 | | |-- val-python.jsonl
74 | | |-- test-python.jsonl
75 | | |-- train_dedup.jsonl
76 | | |-- val_dedup.jsonl
77 | |---...
78 | ```
79 | Note that we only combine and resplit the original train and validation sets; the test set is kept untouched.
80 |
81 | ### GSM8k
82 | As the solutions to GSM8k questions are originally annotated as math formulas, we use a script to automatically extract
83 | MathQA-Python-style programs as solutions. To replicate this, first download the data from the
84 | [original GSM8k repo](https://github.com/openai/grade-school-math/tree/master/grade_school_math/data). After that, your `data` directory should look like this:
85 | ```
86 | data
87 | |-- gsmath
88 | | |-- train.jsonl
89 | | |-- test.jsonl
90 | | |-- ...
91 | |---...
92 | ```
93 | Now run the preprocessing script for GSM8k:
94 | ```bash
95 | python preprocessing/preprocess_gsm8k.py
96 | ```
97 | After this, your `data` directory should look like this:
98 | ```
99 | data
100 | |-- gsmath
101 | | |-- train.jsonl
102 | | |-- test.jsonl
103 | | |-- gsmath_train.jsonl
104 | | |-- gsmath_val.jsonl
105 | | |-- gsmath_test.jsonl
106 | | |-- ...
107 | |---...
108 | ```
109 |
110 | ## Model Training
111 | Our training framework is built on top of [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) (version 1.5.10). More specifically, you only need to change the `yaml` configuration files if you would like to adjust the hyperparameters (e.g., batch size, GPUs, dataset files).
112 |
113 | > **Note: To run model training, we recommend using GPUs with at least 32GiB of memory, or decreasing the training batch size accordingly. All our experiments are conducted on 8x V100-32GB GPUs.**
114 |
115 | ### Basic usage
116 | ```bash
117 | python trainer.py fit --config <path_to_config>.yaml
118 | ```
119 | Existing `yaml` config files can be found in `training_configs`; you can also find all the hyperparameter settings (e.g., batch size) in the Appendix of the paper.
120 |
121 | ### Using different transformer models
122 | If you would like to switch between `GPT-Neo-125M` and `GPT-Neo-2.7B` models, be sure to change the following fields in the `yaml` config file:
123 | ```yaml
124 | trainer:
125 | ...
126 | strategy: deepspeed_stage_2_offload # for 2.7B, or "ddp_find_unused_parameters_false" for 125M
127 | ...
128 | model:
129 | class_path: ...
130 | init_args:
131 | transformer_model_name: &transformer EleutherAI/gpt-neo-2.7B # or EleutherAI/gpt-neo-125M
132 | ...
133 | data:
134 | class_path: ...
135 | init_args:
136 | ...
137 | batch_size: 2 # [Optional] change this according to the GPU memory
138 | val_batch_size: 4 # [Optional] change this according to the GPU memory
139 | ...
140 | ```
141 |
142 | ### Using different datasets
143 | Since the MathQA-Python and GSM8k datasets are in the same format after preprocessing, you just need to change the file paths in the following fields of the `yaml` config file:
144 | ```yaml
145 | data:
146 | ...
147 | train_file_path: data/mathqa/train_dedup.jsonl # or "data/gsmath/gsmath_train.jsonl" for gsm8k
148 | val_file_path: data/mathqa/val_dedup.jsonl # or "data/gsmath/gsmath_val.jsonl" for gsm8k
149 | ```
150 |
151 | ### Using fully- or partially-correct self-sampled solutions
152 | To switch between these two settings, use the corresponding `yaml` config files in `training_configs`:
153 | ```bash
154 | training_configs/gpt_self_sampling.yaml # for using self-sampled fully-correct solutions only
155 | training_configs/gpt_self_sampling_partial.yaml # for also using self-sampled partially-correct solutions
156 | ```
157 |
158 | ### Using different learning objectives
159 | - For the MLE baseline, just run with the config file of `training_configs/gpt_mle.yaml`
160 | - For running MML, set the following in the `yaml` config file:
161 | ```yaml
162 | model:
163 | ...
164 | init_args:
165 | ...
166 | mle_lambda: 0.0
167 | mml_lambda: 1.0
168 | ...
169 | ```
170 | - For running $\beta$-MML, keep the above and additionally set `beta_smoothing: <beta_value>`
171 | - For running MLE-Aug, set `mle_lambda: 1.0` and `mml_lambda: 0.0` in the above (see the sketch below).
172 |
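As a rough illustration of how these fields interact (the actual loss computation lives in `lightning_modules/models/gpt_stmt_mml_model.py` and `gpt_stmt_partial_mml_model.py`; the function and tensor names below are made up for illustration), the following sketch combines the objectives over the log-probabilities of the self-sampled solutions saved for one problem. With `beta_smoothing: 1.0` it reduces to plain MML:

```python
import torch

def combined_loss(solution_logprobs: torch.Tensor,
                  mle_lambda: float, mml_lambda: float,
                  beta_smoothing: float = 1.0) -> torch.Tensor:
    # solution_logprobs[i] = log p(solution_i | problem) for each saved (partially-)correct solution
    mle_aug_loss = -solution_logprobs.sum()  # MLE-Aug: treat every saved solution as a target
    mml_loss = -torch.logsumexp(beta_smoothing * solution_logprobs, dim=0)  # (beta-)MML over the buffer
    return mle_lambda * mle_aug_loss + mml_lambda * mml_loss
```
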
173 | ### All other hyperparameters
174 | For all other hyperparameters, please read the remaining fields in the `yaml` file and the `__init__` function of the corresponding class, or refer to the [pytorch-lightning documentation](https://pytorch-lightning.readthedocs.io/en/1.5.10/).
175 |
176 | ## Model Inference
177 | For running model inference (e.g., on the test set), use the following command:
178 | ```bash
179 | python trainer.py validate --config <path_to_config>.yaml --model.init_args.load_ckpt_file <path_to_checkpoint>
180 | ```
181 |
182 | ## Citation
183 | If you use the code in this repository, please consider citing:
184 | ```bibtex
185 | @inproceedings{ni2023selfsampling,
186 | title={Learning Math Reasoning from Self-Sampled Correct and Partially-Correct Solutions},
187 | author={Ni, Ansong and Inala, Jeevana Priya and Wang, Chenglong and Polozov, Alex and Meek, Christopher and Radev, Dragomir and Gao, Jianfeng},
188 | booktitle={The 2023 International Conference on Learning Representations}
189 | }
190 | ```
191 | For any questions, please open an issue. PRs are definitely welcome; please also check the following section about contributing to this repo.
192 |
193 | ## Contributing
194 |
195 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
196 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
197 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
198 |
199 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
200 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
201 | provided by the bot. You will only need to do this once across all repos using our CLA.
202 |
203 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
204 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
205 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
206 |
207 | ## Trademarks
208 |
209 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
210 | trademarks or logos is subject to and must follow
211 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
212 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
213 | Any use of third-party trademarks or logos is subject to those third parties' policies.
214 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/example.png
--------------------------------------------------------------------------------
/execution/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/execution/__init__.py
--------------------------------------------------------------------------------
/execution/execution_evaluation.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import multiprocessing.pool
3 | import ast
4 | import math
5 |
6 | from typing import List, Dict, Tuple, Any
7 | from concurrent.futures import ProcessPoolExecutor as Pool
8 | from execution.safe_execution_util import execute
9 |
10 | ######################################################################################################
11 | # following are some dataset specific functions for getting the execution result
12 | ######################################################################################################
13 |
14 | def mathqa_answer_eq(prediction: Any, gold_answer: Any):
15 | try:
16 | # if the execution result is a numpy array, a ValueError will be raised
17 | if math.isclose(float(prediction), float(gold_answer), abs_tol=1e-4):
18 | return True
19 | else:
20 | return False
21 | except (ValueError, TypeError, OverflowError):
22 | return False
23 |
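# A few illustrative examples of the tolerance-based comparison above (with abs_tol=1e-4):
#   mathqa_answer_eq(3.14159, 3.1416)  -> True   (difference of 1e-5 is within tolerance)
#   mathqa_answer_eq("7", 7.0)         -> True   (both sides are cast to float first)
#   mathqa_answer_eq(None, 3.0)        -> False  (float(None) raises TypeError, which is caught)
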
24 | def mathqa_execution(program: str) -> Any:
25 | """
26 | for mathqa-python, we should be getting the answers from the "answer" variable in the local() variables
27 | """
28 |
29 | result = execute(program)
30 |
31 | if result["result"] == "passed":
32 | if "answer" in result["locals"]:
33 | executed_answer = result["locals"]["answer"]
34 | else:
35 | # FIXME: this is so ad-hoc
36 | executed_answer = -10000
37 | else:
38 | executed_answer = None
39 |
40 | return executed_answer
41 |
42 |
43 | ######################################################################################################
44 | # following are different metrics for evaluating the execution result
45 | # FIXME: Right now we only consider single test cases
46 | ######################################################################################################
47 |
48 | def batch_exec_programs(programs: List[str], exec_func: callable, n_processes: int = 20) -> Tuple[List[Any], int]:
49 | # build a dict to optimize for potential same programs
50 | program_dict = {}
51 | for program in programs:
52 | if program not in program_dict:
53 | program_dict[program] = None
54 | unique_programs = list(program_dict.keys())
55 |
56 | idx = 0
57 | parsable_unique_programs = []
58 | for program in unique_programs:
59 | try:
60 | ast.parse(program, mode="exec")
61 | parsable_unique_programs.append(program)
62 | program_dict[program] = idx
63 | idx += 1
64 | except SyntaxError:
65 | program_dict[program] = -1
66 | except MemoryError:
67 | print(f"MemoryError when parsing {program}")
68 | program_dict[program] = -1
69 | except ValueError:
70 | print(f"ValueError when parsing {program}")
71 | program_dict[program] = -1
72 |
73 | with Pool(n_processes) as p:
74 | unique_executed_answers = p.map(exec_func, parsable_unique_programs)
75 | unique_executed_answers = list(unique_executed_answers)
76 | unique_executed_answers.append(None) # unparsable programs were indexed as -1, which maps to this trailing None
77 |
78 | # build the original programs answer list
79 | executed_answers = [unique_executed_answers[program_dict[program]] for program in programs]
80 |
81 | return executed_answers, len(unique_programs)
82 |
83 | def batch_execution_acc(programs: List[str], exec_func: callable, answers: List[str],
84 | n_examples: int, eval_at_k: int, n_processes: int = 20) -> List[Tuple[float, float]]:
85 | """
86 | This function evaluates execution accuracy for a batch of programs using multiprocessing.
87 |
88 | Returns: a list of (accuracy_at_k, pass_at_k) tuples, one per task, plus the fraction of unique programs
89 | """
90 | assert len(programs) == len(answers) * eval_at_k
91 | assert n_examples * eval_at_k == len(programs)
92 |
93 | executed_answers, n_unique_programs = batch_exec_programs(programs, exec_func, n_processes)
94 | print(f"Evaluating {len(programs)} generated programs for {n_examples} tasks, " + \
95 | f"but only {n_unique_programs} unique programs")
96 | pct_unique_programs = n_unique_programs / len(programs)
97 |
98 | # separate the results for each task
99 | grouped_executed_answers = [executed_answers[i*eval_at_k:(i+1)*eval_at_k] for i in range(0, n_examples)]
100 | grouped_execution_evals = []
101 | for predicted_answers, gold_answer in zip(grouped_executed_answers, answers):
102 | correct_count = 0.0
103 | for predicted_answer in predicted_answers:
104 | if mathqa_answer_eq(predicted_answer, gold_answer):
105 | correct_count += 1
106 |
107 | accuracy_at_k = correct_count / eval_at_k
108 | pass_at_k = correct_count > 0.0
109 |
110 | grouped_execution_evals.append((accuracy_at_k, pass_at_k))
111 |
112 | return grouped_execution_evals, pct_unique_programs
113 |
114 | def execution_acc(program: str, exec_func: callable, answer: str) -> Tuple[float, float]:
115 | """
116 | This function is used to evaluate the accuracy of the execution of the program.
117 |
118 | Returns: execution accuracy, execution rate
119 | """
120 | executed_answer = exec_func(program)
121 | if executed_answer is not None and mathqa_answer_eq(executed_answer, answer):
122 | return 1.0, 1.0
123 | elif executed_answer is not None:
124 | return 0.0, 1.0
125 | else:
126 | return 0.0, 0.0
127 |
128 | def execution_eval_at_k(programs: List[str], exec_func: callable, answer: str, k: int) -> Tuple[float, float]:
129 | """
130 | Assign 1.0 when at least one out of the k programs execute to the correct answer
131 |
132 | Returns: (accuracy_at_k, pass_at_k)
133 | """
134 | assert len(programs) >= k, "The number of programs should be larger than k"
135 |
136 | correct_count = 0.0
137 | with Pool(20) as p:
138 | executed_answers = p.map(exec_func, programs[:k])
139 | for executed_answer in executed_answers:
140 | if mathqa_answer_eq(executed_answer, answer):
141 | correct_count += 1
142 |
143 | accuracy_at_k = correct_count / k
144 | pass_at_k = correct_count > 0.0
145 |
146 | return accuracy_at_k, pass_at_k
147 |
--------------------------------------------------------------------------------
/execution/program_tracing.py:
--------------------------------------------------------------------------------
1 | import math
2 | import scipy
3 | import json
4 | import os
5 |
6 | from concurrent.futures import ProcessPoolExecutor as Pool
7 | from typing import List, Dict, Tuple, Any, Union, NamedTuple, Set
8 | from scipy import special
9 |
10 | from typing import List, Dict, Any
11 | from tqdm import tqdm
12 | from lightning_modules.datasets.reader_utils import get_statements_from_code, byte_idx_to_char_idx
13 | from execution.safe_execution_util import execute, canonicalize_var_dict
14 | from tree_sitter import Language, Parser
15 |
16 | ProgState = Dict[str, float]
17 | HashableProgState = Tuple[str]
18 | ProgTraceUnit = NamedTuple("ProgTraceUnit", [("code", str), ("type", str), ("state", ProgState)])
19 | ProgTrace = List[ProgTraceUnit]
20 | Program = NamedTuple("Program", [("code", str), ("code_lite", str), ("trace", ProgTrace)])
21 |
22 | # initialize the parser for the code
23 | language_build_path = os.path.join(os.path.dirname(__file__)+'/../preprocessing/', 'py-tree-sitter.so')
24 | PY_LANGUAGE = Language(language_build_path, 'python')
25 | parser = Parser()
26 | parser.set_language(PY_LANGUAGE)
27 |
28 | """
29 | Tracing the execution of a program:
30 | 1. It parses the program into a sequence of tracing units (currently stmts);
31 | 2. Make some markings of the tracing units;
32 | 3. Insert tracing code to the program, after every tracing unit;
33 | 4. Run the program with tracing;
34 | 5. Collect the variable tracing information.
35 | """
36 |
37 | from copy import deepcopy
38 | from types import ModuleType
39 |
40 | tracing_local_list = []
41 | def record_state(local_var_dict):
42 | copied_local_var_dict = canonicalize_var_dict(local_var_dict)
43 | tracing_local_list.append(copied_local_var_dict)
44 |
45 | def get_execution_states(program: str, debug=False) -> Union[ProgTrace, None]:
46 | # first parse the program with tree-sitter
47 | stmts = get_statements_from_code(program, parser)
48 |
49 | if stmts is None:
50 | if debug:
51 | print(f'skipping unparseable example')
52 | print(f"##########\n{program}\n##########")
53 | return None
54 |
55 | # extract the stmt strings
56 | idx = 0
57 | stmt_states = []
58 | for stmt in stmts:
59 | start_idx = byte_idx_to_char_idx(stmt['start_byte'], program)
60 | end_idx = byte_idx_to_char_idx(stmt['end_byte'], program)
61 |
62 | if start_idx != idx:
63 | # add the gap
64 | stmt_states.append({"code": program[idx:start_idx], "type": "gap"})
65 |
66 | # add the stmt
67 | stmt_states.append({"code": program[start_idx:end_idx], "type": "stmt"})
68 | idx = end_idx
69 |
70 |
71 | # NOTE: FIXME: This only works for straight-line programs since it does not consider indentation
72 | for stmt in stmt_states:
73 | if stmt["type"] == "stmt":
74 | stmt["code"] += "\nrecord_state(locals())"
75 |
76 | # now assemble the program back together
77 | traced_program = "".join([stmt["code"] for stmt in stmt_states])
78 |
79 | # execute the program with tracing code
80 | result = execute(traced_program, {},
81 | globals={"tracing_local_list": tracing_local_list, "deepcopy": deepcopy,
82 | "record_state": record_state, "ModuleType": ModuleType,
83 | "math": math, "scipy": scipy, "scipy.special": special}, use_tracing=True)
84 |
85 | if result["result"] == "passed":
86 | # add the *output* states for each statement and remove the tracing code to restore the original program
87 | stmt_idx = 0
88 | for stmt in stmt_states:
89 | if stmt["type"] == "stmt":
90 | stmt["execution_state"] = result["tracing_local_list"][stmt_idx]
91 | stmt["code"] = stmt["code"].replace("\nrecord_state(locals())", "")
92 | stmt_idx += 1
93 | prog_trace = [ProgTraceUnit(stmt["code"], stmt["type"],
94 | stmt["execution_state"] if stmt["type"] == "stmt" else {}) for stmt in stmt_states]
95 | return prog_trace
96 | else:
97 | if debug:
98 | print(f'skipping example of error: {result["result"]}')
99 | print(f"##########\n{program}\n##########")
100 | return None
101 |
102 | def batch_program_tracing(programs: List[str], n_processes=20) -> List[Union[ProgTrace, None]]:
103 | with Pool(n_processes) as p:
104 | tracing_outputs = p.map(get_execution_states, programs)
105 | return list(tracing_outputs)
106 |
107 | def exec_stmt_in_context(stmt: str, context: Dict[str, Any]):
108 | # NOTE: FIXME: This only works for straight-line programs since it does not consider indentation
109 | traced_stmt = stmt + "\nrecord_state(locals())"
110 |
111 | # execute the program with tracing code
112 | if "math" in context:
113 | context["math"] = math
114 | if "scipy" in context:
115 | context["scipy"] = scipy
116 | context["scipy.special"] = special
117 |
118 | result = execute(traced_stmt, locals=context,
119 | globals={"tracing_local_list": tracing_local_list, "deepcopy": deepcopy,
120 | "record_state": record_state, "ModuleType": ModuleType}, use_tracing=True)
121 |
122 | if result["result"] == "passed":
123 | # return the execution states as the local var list
124 | assert len(result["tracing_local_list"]) == 1, f"tracing_local_list: {result['tracing_local_list']}"
125 | stmt_output_state = result["tracing_local_list"][0]
126 | return stmt_output_state
127 | else:
128 | return None
129 |
130 | def is_trivial_state(state_dict: Dict[str, Any], prev_stmt: str):
131 | if len(state_dict) == 0:
132 | return True
133 |
134 | assert prev_stmt is not None, "prev_stmt must be provided to determine trivial states unless the state is empty"
135 |
136 | if prev_stmt.split(" ")[0] in ["#", "import"]:
137 | return True
138 |
139 | assert len(state_dict) == 1, f"prev_stmt {prev_stmt}; original state dict {state_dict}"
140 |
141 | return f"{list(state_dict.keys())[0]} = {list(state_dict.values())[0]}" in prev_stmt
142 |
143 | def get_state_repr(state_dict: Dict[str, Any], prev_stmt: str = None, only_include_keys: List[str] = None,
144 | prev_state_dict: Dict[str, Any] = None, use_diff=False, skip_trivial_states: bool = False):
145 | if use_diff:
146 | raise NotImplementedError
147 |
148 | if only_include_keys is not None:
149 | state_dict = {k: v for k, v in state_dict.items() if k in only_include_keys}
150 |
151 | if skip_trivial_states and is_trivial_state(state_dict, prev_stmt):
152 | return ""
153 |
154 | repr = "# "
155 | for key, value in state_dict.items():
156 | repr += f"{key} = {value}; "
157 | repr += "\n"
158 |
159 | return repr
160 |
161 | if __name__ == "__main__":
162 | # load some sample programs
163 | with open('data/mathqa/val-python.jsonl', 'r') as f:
164 | lines = f.readlines()
165 |
166 | json_examples = [json.loads(line) for line in lines]
167 |
168 | with open('data/mathqa/val_python_with_states.jsonl', 'w+') as f:
169 | success_count = 0
170 | failed_count = 0
171 | for json_example in tqdm(json_examples):
172 |
173 | program = json_example["code"]
174 | stmt_states = get_execution_states(program)
175 |
176 | if stmt_states is not None:
177 | json_example["states"] = stmt_states
178 | f.write(json.dumps(json_example) + "\n")
179 | success_count += 1
180 | else:
181 | failed_count += 1
182 |
183 | print(f"Successfully traced {success_count}/{success_count+failed_count} programs")
--------------------------------------------------------------------------------
/execution/safe_execution_util.py:
--------------------------------------------------------------------------------
1 | ##### much of the code is from https://raw.githubusercontent.com/openai/human-eval/463c980b59e818ace59f6f9803cd92c749ceae61/human_eval/execution.py ####
2 | from tree_sitter import Language, Parser
3 | from typing import Optional, Callable, Dict, Any, List
4 | import ast
5 | import contextlib
6 | import faulthandler
7 | import io
8 | import os
9 | import multiprocessing
10 | import platform
11 | import signal
12 | import tempfile
13 |
14 | from copy import deepcopy
15 | from types import ModuleType
16 | def canonicalize_var_dict(var_dict):
17 | copied_var_dict = {}
18 | for key, value in var_dict.items():
19 | if key in ["canonicalize_var_dict", "tracing_local_list", "deepcopy", "record_state", "ModuleType"]:
20 | continue
21 |
22 | if isinstance(value, ModuleType):
23 | assert key in ["math", "scipy"]
24 | copied_var_dict[key] = "module: " + str(value)
25 | elif str(type(value)) == "<class 'numpy.int64'>":
26 | copied_var_dict[key] = int(value)
27 | else:
28 | copied_var_dict[key] = deepcopy(value)
29 | return copied_var_dict
30 |
31 | def execute(code: str,
32 | locals: Dict[str, Any] = {},
33 | globals: Dict[str, Any] = {},
34 | task_id: Optional[str] = "tmp",
35 | sol_id: Optional[str] = "tmp",
36 | timeout: Optional[int] = 2,
37 | stdi_str: Optional[str] = None,
38 | use_tracing: bool = False) -> Dict[str, Any]:
39 | manager = multiprocessing.Manager()
40 | result = manager.dict()
41 |
42 | p = multiprocessing.Process(target=unsafe_execute, args=(code, globals, locals, timeout, result, stdi_str, use_tracing))
43 | p.start()
44 | p.join(timeout=timeout + 1)
45 | if p.is_alive():
46 | p.kill()
47 |
48 | if len(result) == 0:
49 | result["result"] = "timed out"
50 |
51 | result.update({"task_id": task_id, "sol_id": sol_id})
52 | return dict(result)
53 |
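# Illustrative usage of `execute` (a sketch; the exact keys populated depend on the outcome):
#   result = execute("answer = 1 + 2")
#   result["result"]           -> "passed"
#   result["locals"]["answer"] -> 3
#   execute("while True: pass")["result"]  -> "timed out"
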
54 | def unsafe_execute(check_program: str,
55 | globals_: Dict[str, Any],
56 | locals_: Dict[str, Any],
57 | timeout: float,
58 | result: Dict[str, Any],
59 | stdi_str: Optional[str],
60 | use_tracing: bool = False) -> None:
61 |
62 | with create_tempdir():
63 |
64 | # These system calls are needed when cleaning up tempdir.
65 | import os
66 | import shutil
67 | rmtree = shutil.rmtree
68 | rmdir = os.rmdir
69 | chdir = os.chdir
70 |
71 | # Disable functionalities that can make destructive changes to the test.
72 | reliability_guard()
73 |
74 | try:
75 | with redirect_io(stdi_str) as ostream:
76 | with time_limit(timeout):
77 | # WARNING
78 | # This program exists to execute untrusted model-generated code. Although
79 | # it is highly unlikely that model-generated code will do something overtly
80 | # malicious in response to this test suite, model-generated code may act
81 | # destructively due to a lack of model capability or alignment.
82 | # Users are strongly encouraged to sandbox this evaluation suite so that it
83 | # does not perform destructive actions on their host or network. For more
84 | # information on how OpenAI sandboxes its code, see the accompanying paper.
85 | # Once you have read this disclaimer and taken appropriate precautions,
86 | # uncomment the following line and proceed at your own risk:
87 | exec(check_program, globals_, locals_)
88 | # result["globals"] = globals_
89 | result["locals"] = canonicalize_var_dict(locals_)
90 | result["result"] = "passed"
91 | if use_tracing:
92 | result["tracing_local_list"] = globals_["tracing_local_list"]
93 | except TimeoutException:
94 | result["result"] = "timed out"
95 | except BaseException as e:
96 | result["result"] = f"failed: {e}"
97 | finally:
98 | result["ostream"] = ostream.getvalue()
99 |
100 | # Needed for cleaning up.
101 | shutil.rmtree = rmtree
102 | os.rmdir = rmdir
103 | os.chdir = chdir
104 |
105 | def check_correctness_human_eval(problem: Dict, completion: str, timeout: float,
106 | completion_id: Optional[int] = None) -> Dict:
107 | """
108 | Evaluates the functional correctness of a completion by running the test
109 | suite provided in the problem.
110 |
111 | :param completion_id: an optional completion ID so we can match
112 | the results later even if execution finishes asynchronously.
113 | """
114 |
115 | # Construct the check program and run it.
116 | check_program = (
117 | problem["prompt"] + completion + "\n" +
118 | problem["test"] + "\n" +
119 | f"check({problem['entry_point']})"
120 | )
121 |
122 | return check_correctness(check_program, timeout, problem['task_id'], completion_id)
123 |
124 | def check_correctness(code: str, timeout: float, task_id: Optional[int],
125 | completion_id: Optional[int]) -> Dict:
126 |
127 | manager = multiprocessing.Manager()
128 | result = manager.list()
129 |
130 | p = multiprocessing.Process(target=unsafe_execute, args=(code, timeout, result))
131 | p.start()
132 | p.join(timeout=timeout + 1)
133 | if p.is_alive():
134 | p.kill()
135 |
136 | if not result:
137 | result.append("timed out")
138 |
139 | return dict(
140 | task_id=task_id,
141 | passed=result[0] == "passed",
142 | result=result[0],
143 | completion_id=completion_id,
144 | )
145 |
146 |
147 | @contextlib.contextmanager
148 | def time_limit(seconds: float):
149 | def signal_handler(signum, frame):
150 | raise TimeoutException("Timed out!")
151 | signal.setitimer(signal.ITIMER_REAL, seconds)
152 | signal.signal(signal.SIGALRM, signal_handler)
153 | try:
154 | yield
155 | finally:
156 | signal.setitimer(signal.ITIMER_REAL, 0)
157 |
158 |
159 | @contextlib.contextmanager
160 | def swallow_io():
161 | stream = WriteOnlyStringIO()
162 | with contextlib.redirect_stdout(stream):
163 | with contextlib.redirect_stderr(stream):
164 | with redirect_stdin(stream):
165 | yield
166 |
167 | @contextlib.contextmanager
168 | def redirect_io(stdi_str: Optional[str]):
169 | stream = io.StringIO()
170 | istream = io.StringIO(stdi_str) if stdi_str else WriteOnlyStringIO()
171 | with contextlib.redirect_stdout(stream):
172 | with contextlib.redirect_stderr(stream):
173 | with redirect_stdin(istream):
174 | yield stream
175 |
176 | @contextlib.contextmanager
177 | def create_tempdir():
178 | with tempfile.TemporaryDirectory() as dirname:
179 | with chdir(dirname):
180 | yield dirname
181 |
182 |
183 | class TimeoutException(Exception):
184 | pass
185 |
186 |
187 | class WriteOnlyStringIO(io.StringIO):
188 | """ StringIO that throws an exception when it's read from """
189 |
190 | def read(self, *args, **kwargs):
191 | raise IOError
192 |
193 | def readline(self, *args, **kwargs):
194 | raise IOError
195 |
196 | def readlines(self, *args, **kwargs):
197 | raise IOError
198 |
199 | def readable(self, *args, **kwargs):
200 | """ Returns True if the IO object can be read. """
201 | return False
202 |
203 |
204 | class redirect_stdin(contextlib._RedirectStream): # type: ignore
205 | _stream = 'stdin'
206 |
207 |
208 | @contextlib.contextmanager
209 | def chdir(root):
210 | if root == ".":
211 | yield
212 | return
213 | cwd = os.getcwd()
214 | os.chdir(root)
215 | try:
216 | yield
217 | except BaseException as exc:
218 | raise exc
219 | finally:
220 | os.chdir(cwd)
221 |
222 |
223 | def reliability_guard(maximum_memory_bytes: Optional[int] = None):
224 | """
225 | This disables various destructive functions and prevents the generated code
226 | from interfering with the test (e.g. fork bomb, killing other processes,
227 | removing filesystem files, etc.)
228 |
229 | WARNING
230 | This function is NOT a security sandbox. Untrusted code, including, model-
231 | generated code, should not be blindly executed outside of one. See the
232 | Codex paper for more information about OpenAI's code sandbox, and proceed
233 | with caution.
234 | """
235 |
236 | if maximum_memory_bytes is not None:
237 | import resource
238 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
239 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
240 | if not platform.uname().system == 'Darwin':
241 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
242 |
243 | faulthandler.disable()
244 |
245 | import builtins
246 | builtins.exit = None
247 | builtins.quit = None
248 |
249 | import os
250 | os.environ['OMP_NUM_THREADS'] = '1'
251 |
252 | os.kill = None
253 | os.system = None
254 | os.putenv = None
255 | os.remove = None
256 | os.removedirs = None
257 | os.rmdir = None
258 | os.fchdir = None
259 | os.setuid = None
260 | os.fork = None
261 | os.forkpty = None
262 | os.killpg = None
263 | os.rename = None
264 | os.renames = None
265 | os.truncate = None
266 | os.replace = None
267 | os.unlink = None
268 | os.fchmod = None
269 | os.fchown = None
270 | os.chmod = None
271 | os.chown = None
272 | os.chroot = None
273 | os.fchdir = None
274 | os.lchflags = None
275 | os.lchmod = None
276 | os.lchown = None
277 | os.getcwd = None
278 | os.chdir = None
279 |
280 | import shutil
281 | shutil.rmtree = None
282 | shutil.move = None
283 | shutil.chown = None
284 |
285 | import subprocess
286 | subprocess.Popen = None # type: ignore
287 |
288 | __builtins__['help'] = None
289 |
290 | import sys
291 | sys.modules['ipdb'] = None
292 | sys.modules['joblib'] = None
293 | sys.modules['resource'] = None
294 | sys.modules['psutil'] = None
295 | sys.modules['tkinter'] = None
296 |
297 |
298 | if __name__ == '__main__':
299 | import sys
300 | istream = io.StringIO('hello\nworld\n')
301 |
302 | outputs = []
303 | with redirect_io(istream) as ostream:
304 | code = """
305 | a = input()
306 | print(f'a is {a}')
307 | b = input()
308 | print(f'b is {b}')
309 | """
310 |
311 | exec(code, {})
312 |
313 | outputs.append(ostream.getvalue())
314 |
315 | print(outputs)
--------------------------------------------------------------------------------
/lightning_modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/lightning_modules/__init__.py
--------------------------------------------------------------------------------
/lightning_modules/callbacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/lightning_modules/callbacks/__init__.py
--------------------------------------------------------------------------------
/lightning_modules/callbacks/save_prediction_callback.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import pytorch_lightning as pl
5 |
6 | from typing import Any, Dict, Optional, List
7 | from pytorch_lightning.callbacks import Callback
8 | from pathlib import Path
9 |
10 | class SavePredictionCallback(Callback):
11 | def __init__(self):
12 | self.predictions = list()
13 | self.prediction_save_dir = None
14 | self.metrics_save_dir = None
15 |
16 | def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str]) -> None:
17 | if pl_module.global_rank == 0:
18 | # make dirs for predictions and metrics saving paths
19 | pred_save_dir = os.path.join(trainer.log_dir, 'predictions')
20 | metrics_save_dir = os.path.join(trainer.log_dir, 'metrics')
21 |
22 | Path(pred_save_dir).mkdir(parents=True, exist_ok=True)
23 | Path(metrics_save_dir).mkdir(parents=True, exist_ok=True)
24 |
25 | self.prediction_save_dir = pred_save_dir
26 | self.metrics_save_dir = metrics_save_dir
27 |
28 | def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
29 | outputs: List[Dict[str, Any]], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
30 | self.predictions.extend(outputs)
31 |
32 | def on_validation_epoch_end(self, trainer, pl_module) -> None:
33 | # save the predictions
34 | save_pred_file_path = os.path.join(self.prediction_save_dir,
35 | f'predictions_step_{trainer.global_step}_rank_{trainer.global_rank}.jsonl')
36 | with open(save_pred_file_path, 'w+') as f:
37 | for prediction in self.predictions:
38 | f.write(json.dumps(prediction)+'\n')
39 | print(f"{len(self.predictions)} predictions saved to {save_pred_file_path}")
40 | self.predictions = []
41 |
42 |
43 | if pl_module.global_rank == 0:
44 | self.save_metrics(trainer, pl_module)
45 |
46 | pl_module._rouge_metric.reset()
47 | pl_module._bleu_metric.reset()
48 | pl_module._em_metric.reset()
49 | pl_module._stmt_length.reset()
50 | pl_module._cell_stmt_num.reset()
51 | pl_module._edit_distance.reset()
52 |
53 | def save_metrics(self, trainer, pl_module) -> None:
54 | metrics = {}
55 |
56 | rouge_dict = dict([(k, float(v)) for k, v in pl_module._rouge_metric.compute().items() if k.endswith('fmeasure')])
57 | metrics.update(rouge_dict)
58 | metrics["bleu"] = float(pl_module._bleu_metric.compute())
59 | metrics["cell_exact_match"] = float(pl_module._em_metric.compute())
60 | metrics["output_stmt_len"] = float(pl_module._stmt_length.compute())
61 | metrics["output_stmt_num"] = float(pl_module._cell_stmt_num.compute())
62 | metrics["cell_edit_dist"] = float(pl_module._edit_distance.compute())
63 |
64 | # save the evaluation metrics
65 | save_metrics_file_path = os.path.join(self.metrics_save_dir, f'metrics_step_{trainer.global_step}.json')
66 | with open(save_metrics_file_path, 'w+') as f:
67 | f.write(json.dumps(metrics, indent=4))
68 |
69 | print(f"Eval metrics saved to {save_metrics_file_path}")
--------------------------------------------------------------------------------
/lightning_modules/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/lightning_modules/datasets/__init__.py
--------------------------------------------------------------------------------
/lightning_modules/datasets/mathqa_line_reader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import sys
4 | import os
5 | from overrides import overrides
6 | import torch
7 |
8 | from typing import Dict, Iterable, List, Any, Optional, Union
9 |
10 | from pytorch_lightning import LightningDataModule
11 | from torch.utils.data import Dataset
12 |
13 | from lightning_modules.models.gpt_util import get_gpt, left_pad_sequences
14 | from execution.program_tracing import get_state_repr, is_trivial_state
15 |
16 | from torch.utils.data import DataLoader
17 |
18 | # set environment variable to avoid deadlocks, see:
19 | # https://docs.allennlp.org/main/api/data/data_loaders/multiprocess_data_loader/#multiprocessdataloader.common_issues
20 | os.environ['TOKENIZERS_PARALLELISM']='0'
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 | FEW_SHOT_RESERVED = 10
25 |
26 |
27 | class MathQADataset(Dataset):
28 |
29 | def __init__(
30 | self,
31 | file_path: str,
32 | transformer_model_name: str,
33 | max_instances: int,
34 | few_shot_n: int = 0,
35 | mode: str = "train",
36 | multi_example_instance: bool = False,
37 | **kwargs):
38 | super().__init__(**kwargs)
39 |
40 | # mode is one of ["train", "test", "test_few_shot"]
41 | assert mode in ["train", "test", "test_few_shot"]
42 |
43 | _, self.tokenizer = get_gpt(transformer_model_name, tokenizer_only=True)
44 |
45 | self.max_instances = max_instances
46 | self.mode = mode
47 | self.multi_example_instance = multi_example_instance
48 |
49 | assert few_shot_n <= FEW_SHOT_RESERVED, f"few_shot_n should be smaller than {FEW_SHOT_RESERVED}"
50 | self.few_shot_n = few_shot_n
51 |
52 | self.instances = self.read(file_path)
53 |
54 | def get_train_instance(self, example: Dict[str, Any]) -> Dict[str, Any]:
55 | example_dict = {"metadata": example}
56 |
57 | tokenizer_outputs = self.tokenizer("\n".join([example["text"], example["code"]]))
58 |
59 |
60 | example_dict["input_ids"] = tokenizer_outputs["input_ids"] + [self.tokenizer.eos_token_id]
61 | example_dict["attention_mask"] = tokenizer_outputs["attention_mask"] + [1]
62 | example_dict["metadata"]["pad_token_id"] = self.tokenizer.pad_token_id
63 |
64 | return example_dict
65 |
66 | def get_test_instance(self, example: Dict[str, Any]) -> Dict[str, Any]:
67 | example_dict = {"metadata": example}
68 |
69 | tokenizer_outputs = self.tokenizer(example["text"] + "\n")
70 |
71 |
72 | example_dict["input_ids"] = tokenizer_outputs["input_ids"]
73 | example_dict["attention_mask"] = tokenizer_outputs["attention_mask"]
74 | example_dict["metadata"]["pad_token_id"] = self.tokenizer.pad_token_id
75 |
76 | return example_dict
77 |
78 | def get_test_few_shot_instance(self, example: Dict[str, Any],
79 | few_shot_text_list: List[str],
80 | few_shot_code_list: List[str]) -> Dict[str, Any]:
81 | raise NotImplementedError("get_test_few_shot_instance is deprecated.")
82 |
83 | def read(self, file_path: str) -> Iterable[Dict[str, Any]]:
84 | print("Reading dataset files at %s", file_path)
85 |
86 | all_yield_instances = []
87 |
88 | # load the mathqa dataset with states
89 | mathqa_json_examples = []
90 | with open(file_path, 'r') as f:
91 | if self.mode == "test_few_shot":
92 | lines = f.readlines()[:self.max_instances+FEW_SHOT_RESERVED]
93 | else:
94 | lines = f.readlines()[:self.max_instances]
95 | for line in lines:
96 | mathqa_json_examples.append(json.loads(line))
97 |
98 | if self.mode == "test_few_shot":
99 | # holdout for few-shot prompting
100 | few_shot_examples = mathqa_json_examples[:self.few_shot_n]
101 | few_shot_text_list = [example['text'] for example in few_shot_examples]
102 | few_shot_code_list = [example['code'] for example in few_shot_examples]
103 |
104 | mathqa_json_examples = mathqa_json_examples[FEW_SHOT_RESERVED:]
105 |
106 | for exp in mathqa_json_examples:
107 | if self.mode == "train":
108 | example_dict = self.get_train_instance(exp)
109 | elif self.mode == "test":
110 | example_dict = self.get_test_instance(exp)
111 | elif self.mode == "test_few_shot":
112 | example_dict = self.get_test_few_shot_instance(exp, few_shot_text_list, few_shot_code_list)
113 | else:
114 | raise ValueError(f"Unknown mode: {self.mode}")
115 |
116 | all_yield_instances.append(example_dict)
117 |
118 | logger.info(f"loaded {len(all_yield_instances)} instances")
119 |
120 | return all_yield_instances
121 |
122 | def __getitem__(self, idx: int):
123 | return self.instances[idx]
124 |
125 | def __len__(self):
126 | return len(self.instances)
127 |
128 | def truncate(self, max_instances):
129 | truncated_instances = self.instances[max_instances:]
130 | self.instances = self.instances[:max_instances]
131 | return truncated_instances
132 |
133 | def extend(self, instances):
134 | self.instances.extend(instances)
135 |
136 | class MathQAMmlDataset(MathQADataset):
137 |
138 | def __init__(
139 | self,
140 | file_path: str,
141 | transformer_model_name: str,
142 | max_instances: int,
143 | **kwargs):
144 | super().__init__(file_path=file_path, transformer_model_name=transformer_model_name,
145 | max_instances=max_instances, **kwargs)
146 |
147 | def get_train_instance(self, example: Dict[str, Any]) -> Dict[str, Any]:
148 | example_dict = {"metadata": example}
149 |
150 | tokenizer_outputs = self.tokenizer(example["text"] + "\n")
151 |
152 | example_dict["input_ids"] = tokenizer_outputs["input_ids"]
153 | example_dict["attention_mask"] = tokenizer_outputs["attention_mask"]
154 | example_dict["metadata"]["pad_token_id"] = self.tokenizer.pad_token_id
155 |
156 | return example_dict
157 |
158 |
159 | def customized_collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, Any]:
160 | result_dict = {}
161 |
162 | pad_token_id = examples[0]["metadata"]["pad_token_id"]
163 |
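    # note: the tensor fields below are left-padded (rather than right-padded) so that, for the
    # decoder-only GPT models in this repo, the end of every context lines up with the start of
    # generation and a single batched generate() call can be used at validation time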
164 | for k in examples[0].keys():
165 | if k == "metadata":
166 | result_dict[k] = [ex[k] for ex in examples]
167 | elif k == "input_ids":
168 | result_dict[k] = left_pad_sequences([torch.tensor(ex[k]) for ex in examples],
169 | batch_first=True, padding_value=pad_token_id)
170 | elif k == "attention_mask":
171 | result_dict[k] = left_pad_sequences([torch.tensor(ex[k]) for ex in examples],
172 | batch_first=True, padding_value=0)
173 | elif k == "state_mask":
174 | result_dict[k] = left_pad_sequences([torch.tensor(ex[k]) for ex in examples],
175 | batch_first=True, padding_value=0)
176 | elif k == "labels":
177 | result_dict[k] = left_pad_sequences([torch.tensor(ex[k]) for ex in examples],
178 | batch_first=True, padding_value=pad_token_id)
179 | else:
180 | raise ValueError(f"Unknown key {k} in example instance")
181 |
182 | return result_dict
183 |
184 | class MathQADataModule(LightningDataModule):
185 | def __init__(self,
186 | transformer_model_name: str,
187 | batch_size: int = 1,
188 | val_batch_size: int = 1,
189 | few_shot_n: int = 0,
190 | train_file_path: str = None,
191 | val_file_path: str = None,
192 | test_file_path: str = None,
193 | train_max_instances: int = sys.maxsize,
194 | val_max_instances: int = sys.maxsize):
195 | super().__init__()
196 | self.transformer_model_name = transformer_model_name
197 | self.batch_size = batch_size
198 | self.val_batch_size = val_batch_size
199 | self.few_shot_n = few_shot_n
200 |
201 | self.train_file_path = train_file_path
202 | self.val_file_path = val_file_path
203 | self.test_file_path = test_file_path
204 |
205 | self.train_max_instances = train_max_instances
206 | self.val_max_instances = val_max_instances
207 |
208 | self.train_data = None
209 | self.val_data = None
210 |
211 | def assign_data(self):
212 | train_data = MathQADataset(file_path=self.train_file_path,
213 | transformer_model_name=self.transformer_model_name,
214 | max_instances=self.train_max_instances,
215 | mode="train", few_shot_n=self.few_shot_n)
216 | self.train_data = train_data
217 |
218 | val_data = MathQADataset(file_path=self.val_file_path,
219 | transformer_model_name=self.transformer_model_name,
220 | max_instances=self.val_max_instances,
221 | mode="test", few_shot_n=self.few_shot_n)
222 | self.val_data = val_data
223 |         print("assign_data is called!")
224 |
225 | # OPTIONAL, called for every GPU/machine (assigning state is OK)
226 | def setup(self, stage: Optional[str] = None):
227 | assert stage in ["fit", "validate", "test"]
228 |
229 | self.assign_data()
230 |
231 | def train_dataloader(self):
232 | if self.train_data is None:
233 | self.assign_data()
234 |
235 | dtloader = DataLoader(self.train_data, batch_size=self.batch_size,
236 | shuffle=True, drop_last=True, collate_fn=customized_collate_fn)
237 | return dtloader
238 |
239 | def val_dataloader(self):
240 | if self.val_data is None:
241 | self.assign_data()
242 |
243 | dtloader = DataLoader(self.val_data, batch_size=self.val_batch_size,
244 | shuffle=False, drop_last=True, collate_fn=customized_collate_fn)
245 | return dtloader
246 |
247 | def test_dataloader(self):
248 | raise NotImplementedError
249 |
250 | class MathQAMmlDataModule(MathQADataModule):
251 | def __init__(self, transformer_model_name: str, **kwargs):
252 | super().__init__(transformer_model_name=transformer_model_name, **kwargs)
253 |
254 | @overrides
255 | def assign_data(self):
256 | train_data = MathQAMmlDataset(file_path=self.train_file_path,
257 | transformer_model_name=self.transformer_model_name,
258 | max_instances=self.train_max_instances,
259 | mode="train", few_shot_n=self.few_shot_n)
260 | self.train_data = train_data
261 |
262 | val_data = MathQAMmlDataset(file_path=self.val_file_path,
263 | transformer_model_name=self.transformer_model_name,
264 | max_instances=self.val_max_instances,
265 | mode="test", few_shot_n=self.few_shot_n)
266 | self.val_data = val_data
267 |         print("assign_data in the MML data module is called!")
268 |
--------------------------------------------------------------------------------
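For reference, a minimal usage sketch of the dataset and collate function above. The jsonl path and model name are placeholders, and `MathQADataset` is assumed to accept the same keyword arguments that `MathQADataModule.assign_data` passes:

from torch.utils.data import DataLoader
from lightning_modules.datasets.mathqa_line_reader import MathQADataset, customized_collate_fn

# placeholder path: any jsonl file with one {"text": ..., "code": ...} example per line
dataset = MathQADataset(file_path="data/mathqa/train_dedup.jsonl",
                        transformer_model_name="EleutherAI/gpt-neo-125M",  # placeholder model
                        max_instances=8, mode="train", few_shot_n=0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True,
                    collate_fn=customized_collate_fn)
batch = next(iter(loader))
print(batch["input_ids"].shape, batch["attention_mask"].shape)  # both left-padded to the same length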
/lightning_modules/datasets/reader_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import logging
4 |
5 | from functools import reduce
6 | from math import gcd
7 | from typing import List, Any, Union, Dict, Tuple
8 |
9 | from tree_sitter import Parser, Language
10 |
11 | # TODO: add those tokens to the vocab of whatever transformer models
12 | INDENT_TOKEN = "@@INDENT@@"
13 | NEWLINE_TOKEN = "@@NEWLINE@@"
14 | END_OF_CELL_TOKEN = "<|endofcell|>"
15 | COM_STMTS = ['if_statement', 'for_statement', 'while_statement', 'try_statement', 'with_statement',
16 | 'function_definition', 'class_definition']
17 | PY_MODULES = ['module', 'block', 'decorated_definition']
18 |
19 | logger = logging.getLogger('reader')
20 |
21 | def get_indent_level(line: str) -> int:
22 | """return the indent level of the line, measuring the leading tabs or spaces"""
23 | code_start_idx = line.find(line.lstrip())
24 |
25 | if code_start_idx == 0:
26 | return 0
27 |
28 | indent_str = line[:code_start_idx]
29 | if ' ' in indent_str and '\t' in indent_str:
30 | logger.info('indentation string contains both spaces and tabs')
31 |
32 | if indent_str.replace(' ', '').replace('\t', '') != '':
33 | logger.info('indentation string is not all spaces or tabs')
34 |
35 | return len(indent_str)
36 |
37 |
38 | def preprocess_cell(cell: Dict[str, Any], line_offset: int, add_special_token_for_code: bool=False) -> Tuple[List[str], Dict]:
39 | """preprocess the cell"""
40 |
41 | cell_type = cell['type']
42 | cell_lines = cell['lines']
43 |
44 | if not add_special_token_for_code:
45 | # special NEWLINE token will be added, so no need to add it here
46 | cell_lines = [line+'\n' for line in cell_lines]
47 |
48 | if 'stmts' in cell:
49 | stmts = cell['stmts']
50 | else:
51 | stmts = []
52 |
53 | for stmt in stmts:
54 | start, end = stmt['start_point'], stmt['end_point']
55 | stmt['start_point'] = (start[0]+line_offset, start[1])
56 | stmt['end_point'] = (end[0]+line_offset, end[1])
57 |
58 | if cell_type == 'markdown':
59 | commented_lines = [f"# {line}" for line in cell_lines]
60 | return commented_lines, stmts
61 |
62 | if cell_type == 'code' and add_special_token_for_code:
63 | # first figure out the indent levels
64 | indent_levels = [get_indent_level(line) for line in cell_lines]
65 |
66 |         # there are cases where the indentation is actually a line continuation of a long line,
67 |         # so use 20 spaces as a threshold to distinguish real indentation from line continuations
68 | indent_levels = [x if (x > 1 and x % 2 == 0 and x <= 20) else 0 for x in indent_levels]
69 |
70 | gcd_indent_level = reduce(gcd, indent_levels)
71 | if gcd_indent_level not in [0, 2, 4, 8]:
72 | # logger.info(f'indentation level is not a power of 2 but {gcd_indent_level}, setting all indentations to 0')
73 | indent_levels = [0] * len(indent_levels)
74 | if gcd_indent_level != 0:
75 | indent_levels = [i // gcd_indent_level for i in indent_levels]
76 |
77 | # add indentation and newline tokens
78 | lines = [' '.join([INDENT_TOKEN]*i + [line[i*gcd_indent_level:]] + [NEWLINE_TOKEN]) for i, line in zip(indent_levels, cell_lines)]
79 | return lines, stmts
80 |
81 | return cell_lines, stmts
82 |
83 | def construct_from_lines(all_lines: List[str], start: Tuple[int, int],
84 | end: Tuple[int, int], is_byte_idx: bool=False):
85 | if is_byte_idx:
86 | start = (start[0], byte_idx_to_char_idx(start[1], all_lines[start[0]]))
87 | end = (end[0], byte_idx_to_char_idx(end[1], all_lines[end[0]]))
88 |
89 | # construct back the statement string
90 | statement_str = ''
91 | for i in range(start[0], end[0] + 1):
92 | if i == start[0]:
93 | if i == end[0]: # same line statement
94 | statement_str += (all_lines[i][start[1]:end[1]])
95 | else:
96 | statement_str += (all_lines[i][start[1]:])
97 | elif i == end[0]:
98 | statement_str += (all_lines[i][:end[1]])
99 | else:
100 | statement_str += all_lines[i]
101 |
102 | return statement_str
103 |
104 | def byte_idx_to_char_idx(byte_idx: int, line: str) -> int:
105 | """convert byte index to char index"""
106 | return len(bytes(line, 'utf-8')[:byte_idx].decode('utf-8'))
107 |
108 | def last_char_idx_to_token_idx(char_idx: int, tokens: List[str]) -> int:
109 | """ find the token that ends with the given char index """
110 | # calculate the token end indices
111 | token_end_indices = []
112 | total_len = 0
113 | for token in tokens:
114 | total_len += len(token)
115 | token_end_indices.append(total_len)
116 |
117 | if char_idx+1 in token_end_indices:
118 | return token_end_indices.index(char_idx+1)
119 | else:
120 |         # special case: the last char of a stmt is not the ending char of a token.
121 |         # An example is `a = f(x);`, where ')' is the last char of the stmt but BPE
122 |         # produces ');' as a single token, so we have to include the whole token
123 | return token_end_indices.index(char_idx+2)
124 |
125 | def last_byte_idx_to_token_idx(byte_idx: int, tokens: List[str], tokenizer) -> int:
126 | """ find the token that ends with the given byte index """
127 | # calculate the token end indices
128 | token_end_indices = []
129 | total_len = 0
130 | for token in tokens:
131 | total_len += len(bytearray([tokenizer.byte_decoder[c] for c in token]))
132 | token_end_indices.append(total_len)
133 |
134 | if byte_idx+1 in token_end_indices:
135 | return token_end_indices.index(byte_idx+1)
136 | else:
137 |         # special case: the last byte of a stmt is not the ending byte of a token.
138 |         # An example is `a = f(x);`, where ')' is the last char of the stmt but BPE
139 |         # produces ');' as a single token, so we have to include the whole token
140 | # NOTE: the example above is actually the only one we observe
141 | return token_end_indices.index(byte_idx+2)
142 |
143 | def get_statements_from_code(code: str, parser, tolerate_errors: bool=False) -> List[Dict[str, Any]]:
144 | parsed_tree = parser.parse(bytes(code, 'utf-8'))
145 |
146 | # do a dfs on the parsed tree to record all the simple statements
147 | target_stmts: List[Dict] = []
148 | node_stack = [parsed_tree.root_node]
149 | while len(node_stack) > 0:
150 | node = node_stack.pop()
151 |
152 | if (node.type.endswith('statement') or node.type in ['comment', 'decorator']) \
153 | and node.type not in COM_STMTS:
154 | # this is a simple statement or a comment, so we can add it to the list
155 | target_stmts.append({'type': node.type, 'start_point': node.start_point,
156 | 'end_point': node.end_point, 'start_byte': node.start_byte,
157 | 'end_byte': node.end_byte})
158 | elif node.type in COM_STMTS or node.type.endswith('clause'):
159 | # separate the header and the body by the ":" token
160 | children_types = [c.type for c in node.children]
161 | separator_idx = children_types.index(':')
162 | assert separator_idx != -1
163 |
164 | # start of the header is the starter of the complex stmt, end is the end of the ":" token
165 | target_stmts.append({'type': node.type+'_header', 'start_point': node.start_point,
166 | 'start_byte': node.children[separator_idx].start_byte,
167 | 'end_point': node.children[separator_idx].end_point,
168 | 'end_byte': node.children[separator_idx].end_byte})
169 | node_stack.extend(node.children[separator_idx+1:][::-1])
170 | elif node.type in PY_MODULES:
171 | node_stack.extend(node.children[::-1])
172 | elif node.type == 'ERROR':
173 | # err_code_line = code[:byte_idx_to_char_idx(node.end_byte, code)].split('\n')[-1]
174 | # print(f"failed to parse code: #########\n{err_code_line}\n#########")
175 | if tolerate_errors:
176 | continue
177 | else:
178 | # failed to parse tree, return None NOTE: previously returning [], but this will get
179 | # confused with blank cells
180 | return None
181 | else:
182 | # other types, not sure what it contains, but assume it doesn't contain more statements
183 | print(f'unexpected node type: {node.type}')
184 | assert 'statement' not in node.sexp()
185 |
186 | return target_stmts
187 |
188 |
189 | if __name__ == '__main__':
190 | language_build_path = os.path.join(os.path.dirname(__file__), '../../preprocessing/py-tree-sitter.so')
191 | PY_LANGUAGE = Language(language_build_path, 'python')
192 | parser = Parser()
193 | parser.set_language(PY_LANGUAGE)
194 |
195 | code = "#a simple test\nif a == 0:\n if b == 0:\n b=1\nelif a>1:\n a*=2\nelse:\n a/=2"
196 | print(code)
197 |
198 | stmts = get_statements_from_code(code, parser)
199 |
200 | start = 0
201 | stmt_str = []
202 | for stmt in stmts:
203 | stmt_str.append(code[byte_idx_to_char_idx(start, code): byte_idx_to_char_idx(stmt['end_byte'], code)])
204 | start = stmt['end_byte']
205 |
206 |
207 | tree = parser.parse(bytes(code, 'utf-8'))
208 |
209 | root = tree.root_node
210 |
211 | print("")
212 |
213 |
--------------------------------------------------------------------------------
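As a small aside, the indentation normalization in `preprocess_cell` above boils down to dividing each line's leading-space count by the gcd of all counts; the line-continuation heuristic is omitted in this simplified sketch:

from functools import reduce
from math import gcd

cell_lines = ["def f(x):", "    if x > 0:", "        return x", "    return -x"]
indent_widths = [len(l) - len(l.lstrip()) for l in cell_lines]    # [0, 4, 8, 4]
g = reduce(gcd, indent_widths)                                    # 4
levels = [w // g for w in indent_widths] if g else indent_widths  # [0, 1, 2, 1]
print(levels)  # each level is later rendered as one @@INDENT@@ token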
/lightning_modules/loggers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/lightning_modules/loggers/__init__.py
--------------------------------------------------------------------------------
/lightning_modules/loggers/patched_loggers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import neptune
3 |
4 | from typing import Optional, Union, List
5 | from pytorch_lightning.loggers import NeptuneLogger, CSVLogger, TensorBoardLogger, WandbLogger
6 | from pytorch_lightning.utilities import rank_zero_only
7 |
8 | class PatchedWandbLogger(WandbLogger):
9 | def __init__(self, entity: str, project: str, name: str, log_model: bool, save_code: bool,
10 | tags: List[str] = None, *args, **kwargs):
11 |
12 | kwargs['entity'] = entity
13 | kwargs['save_code'] = save_code
14 |
15 | # try load exp name from env var (for using batch jobs)
16 | exp_name_var = os.getenv('EXP_NAME')
17 | if exp_name_var is not None:
18 | processed_name = exp_name_var
19 | else:
20 |             # remove the preceding folder name
21 | processed_name = name.split('/')[-1]
22 |
23 | # automatically get the tags from the name
24 | if tags is None:
25 | kwargs['tags'] = processed_name.split('-')
26 | else:
27 | kwargs['tags'] = tags
28 |
29 | super().__init__(name=processed_name, project=project, log_model=log_model, *args, **kwargs)
30 |
31 | @rank_zero_only
32 | def log_code(self):
33 | # log the yaml and py files
34 | root = "."
35 | print(f"saving all files in {os.path.abspath(root)}")
36 | result = self.experiment.log_code(root=root,
37 | include_fn=(lambda path: path.endswith(".py") or \
38 | path.endswith(".yaml") or \
39 | path.endswith(".sh")),
40 | exclude_fn=(lambda path: ".venv" in path or \
41 | "debug-tmp" in path))
42 | if result is not None:
43 | print("########################################")
44 | print("######## Logged code to wandb. #########")
45 | print("########################################")
46 | else:
47 |             print("######## logger initialized but code was not saved #########")
48 |
49 | class PatchedNeptuneLogger(NeptuneLogger):
50 | def __init__(self, project_name: str, *args, **kwargs):
51 | api_key = os.getenv('NEPTUNE_API_KEY')
52 | if api_key is None:
53 | raise ValueError("Please provide an API key for the neptune logger in the env vars.")
54 | # exp_name = os.getenv('PL_LOG_DIR').split('/')[0]
55 | # exp_name = os.getenv('AMLT_JOB_NAME')
56 |
57 | kwargs['api_key'] = api_key
58 | # kwargs['experiment_id'] = exp_name
59 | kwargs['project'] = project_name
60 | kwargs['source_files'] = ['**/*.py', '**/*.yaml', '**/*.sh']
61 |
62 | super().__init__(*args, **kwargs)
63 |
64 | class PatchedCSVLogger(CSVLogger):
65 | def __init__(self,
66 | name: Optional[str] = "default",
67 | version: Optional[Union[int, str]] = None,
68 | prefix: str = "",
69 | ):
70 | save_dir = os.getenv('PL_LOG_DIR')
71 | super().__init__(save_dir, name, version, prefix)
72 |
73 | class PatchedTensorBoardLogger(TensorBoardLogger):
74 | def __init__(self,
75 | name: Optional[str] = "default",
76 | version: Optional[Union[int, str]] = None,
77 | log_graph: bool = False,
78 | default_hp_metric: bool = True,
79 | prefix: str = "",
80 | sub_dir: Optional[str] = None,
81 | ):
82 | save_dir = os.getenv('PL_LOG_DIR')
83 | super().__init__(save_dir, name, version, log_graph, default_hp_metric, prefix, sub_dir)
--------------------------------------------------------------------------------
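A hedged sketch of constructing one of the patched loggers by hand; the entity, project, and experiment names are placeholders, and in the actual repo the loggers are presumably instantiated from the YAML files under training_configs/:

import os
from pytorch_lightning import Trainer
from lightning_modules.loggers.patched_loggers import PatchedWandbLogger

os.environ.setdefault("EXP_NAME", "gpt-mle-debug")  # used to derive the run name and tags
wandb_logger = PatchedWandbLogger(entity="my-entity", project="my-project",
                                  name="training_configs/gpt_mle", log_model=False,
                                  save_code=True)
trainer = Trainer(logger=wandb_logger, max_steps=10)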
/lightning_modules/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/lightning_modules/models/__init__.py
--------------------------------------------------------------------------------
/lightning_modules/models/gpt_seq2seq_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | import os
4 | import math
5 | import torch.nn.functional as F
6 | import pytorch_lightning as pl
7 | import io, tokenize, re
8 | import ast, astunparse
9 | import numpy as np
10 |
11 | from types import ModuleType
12 | from typing import Optional, Dict, Any, Tuple, List
13 | from transformers.optimization import AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup
14 | from transformers.optimization import get_cosine_schedule_with_warmup
15 |
16 |
17 | from torchmetrics import Metric, MeanMetric, MetricCollection
18 | from pytorch_lightning import LightningModule
19 |
20 | from .gpt_util import get_gpt
21 | from execution.execution_evaluation import execution_acc, mathqa_execution
22 | from execution.execution_evaluation import execution_eval_at_k, batch_execution_acc
23 |
24 | def get_overlap_example_ids(set_name: str, min_allow_dist: int) -> List[int]:
25 | assert set_name in ['train', 'test', 'val']
26 |
27 | save_path = f"analysis/train_{set_name}_overlap.npy"
28 | sim_matrix = np.load(save_path)
29 |
30 | # exclude itself in computing the most similar example when using the same set
31 | if set_name == 'train':
32 | np.fill_diagonal(sim_matrix, np.inf)
33 |
34 | min_list = np.min(sim_matrix, axis=0)
35 |
36 | overlapping_ids = []
37 | for i, min_dist in enumerate(min_list):
38 | if min_dist > min_allow_dist:
39 | continue
40 | else:
41 | overlapping_ids.append(i)
42 |
43 | return overlapping_ids
44 |
45 | # from https://stackoverflow.com/questions/1769332/script-to-remove-python-comments-docstrings
46 | def remove_comments_and_docstrings(source):
47 | io_obj = io.StringIO(source)
48 | out = ""
49 | prev_toktype = tokenize.INDENT
50 | last_lineno = -1
51 | last_col = 0
52 | for tok in tokenize.generate_tokens(io_obj.readline):
53 | token_type = tok[0]
54 | token_string = tok[1]
55 | start_line, start_col = tok[2]
56 | end_line, end_col = tok[3]
57 | ltext = tok[4]
58 | if start_line > last_lineno:
59 | last_col = 0
60 | if start_col > last_col:
61 | out += (" " * (start_col - last_col))
62 | if token_type == tokenize.COMMENT:
63 | pass
64 | elif token_type == tokenize.STRING:
65 | if prev_toktype != tokenize.INDENT:
66 | if prev_toktype != tokenize.NEWLINE:
67 | if start_col > 0:
68 | out += token_string
69 | else:
70 | out += token_string
71 | prev_toktype = token_type
72 | last_col = end_col
73 | last_lineno = end_line
74 | out = '\n'.join(l for l in out.splitlines() if l.strip())
75 | return out
76 |
77 | def post_process_code(code, remove_comments=True, remove_extra_lines=False, ast_back_parse=True):
78 | """ a series of post-processing steps to clean up the code and avoid duplicated code """
79 |
80 | if remove_comments:
81 | code = remove_comments_and_docstrings(code)
82 |
83 | if ast_back_parse:
84 | code = astunparse.unparse(ast.parse(code))
85 |
86 | if remove_extra_lines:
87 | # remove the code after "answer" is generated
88 | result = []
89 | for line in code.split("\n"):
90 | result.append(line)
91 | if line.startswith("answer"):
92 | break
93 | code = "\n".join(result)
94 |
95 | code = code.strip()
96 |
97 | return code
98 |
99 | # From the Codex Paper
100 | def estimate_pass_at_k(n, c, k):
101 | """
102 | :param n: total number of samples
103 | :param c: number of correct samples
104 | :param k: k in pass@$k$
105 | """
106 | if n - c < k:
107 | return 1.0
108 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
109 |
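# Worked example (added note): with n=5 samples, c=2 correct, and k=3, the unbiased estimator
# 1 - C(n-c, k) / C(n, k) = 1 - C(3, 3) / C(5, 3) = 1 - 1/10 = 0.9, which matches the
# numerically stable product form above: 1 - (1 - 3/4) * (1 - 3/5) = 0.9.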
110 | class GptSeq2SeqModel(LightningModule):
111 | def __init__(self,
112 | transformer_model_name: str,
113 | max_gen_len: int = 100,
114 | sampling_temp: float = 0.2,
115 | sampling_temp_at_k: float = 0.8,
116 | gradient_ckpt: bool = False,
117 | pass_at_k: int = 1,
118 | additional_pass_at_k: List[int] = [],
119 | eval_pass_at_k_every_n_epochs: int = 1,
120 | always_eval_pass_at_k_first_n_epochs: int = -1,
121 | max_generation_batches: int = 100,
122 | max_steps: int = -1,
123 | warmup_steps: int = 0,
124 | eval_greedy_search: bool = False,
125 | measure_dedup_metrics: bool = False,
126 | optimizer: Dict[str, Any] = None,
127 | lr_scheduler: Dict[str, Any] = None,
128 | load_ckpt_file: str = None) -> None:
129 | super().__init__()
130 |
131 | self.max_gen_len = max_gen_len
132 | self.sampling_temp = sampling_temp
133 | self.sampling_temp_at_k = sampling_temp_at_k
134 | self.max_steps = max_steps
135 | self.warmup_steps = warmup_steps
136 |
137 | self.pass_at_k = pass_at_k
138 | self.additional_pass_at_k = additional_pass_at_k
139 | self.eval_pass_at_k_every_n_epochs = eval_pass_at_k_every_n_epochs
140 | self.always_eval_pass_at_k_first_n_epochs = always_eval_pass_at_k_first_n_epochs
141 | self.max_generation_batches = max_generation_batches
142 | self.eval_greedy_search = eval_greedy_search
143 | self.measure_dedup_metrics = measure_dedup_metrics
144 |
145 | # We only instantiate this when we need it.
146 | self.gpt, self.tokenizer = get_gpt(transformer_model_name, gradient_ckpt=gradient_ckpt)
147 |
148 |         # save the prediction results for every validation epoch
149 | self.predictions: List[Dict[str, Any]] = []
150 |
151 | # the optimizer and lr scheduler settings
152 | self.opt_params = optimizer["init_args"]
153 | self.lrs_params = lr_scheduler
154 | assert self.lrs_params["name"] in ["linear", "cosine", "constant"], "lr_scheduler must be one of 'linear', 'cosine', 'constant'"
155 |
156 | # keep track of the number of validation epochs for pass at k
157 | self.eval_epoch = 0
158 |
159 | # load the state dict from the checkpoint file
160 | if load_ckpt_file is not None:
161 | checkpoint = torch.load(load_ckpt_file, map_location=torch.device("cpu"))
162 | self.load_state_dict(checkpoint["state_dict"], strict=False)
163 | print(f"loaded weights from {load_ckpt_file}")
164 |
165 | self.metrics_dict: Dict[str, Metric] = MetricCollection({})
166 |
167 | self.metrics_dict["exec_acc"] = MeanMetric()
168 | self.metrics_dict["exec_rate"] = MeanMetric()
169 | self.metrics_dict["program_len_diff"] = MeanMetric()
170 | self.metrics_dict["unique_pct_in_k"] = MeanMetric()
171 |
172 | assert len(self.additional_pass_at_k) == 0 or self.pass_at_k > max(self.additional_pass_at_k), \
173 | f"pass_at_k ({self.pass_at_k}) must be greater than all additional_pass_at_k ({self.additional_pass_at_k})"
174 | if self.pass_at_k > 1:
175 | self.metrics_dict[f"acc@{self.pass_at_k}"]= MeanMetric()
176 | self.metrics_dict[f"pass@{self.pass_at_k}"]= MeanMetric()
177 |
178 | for additional_k in self.additional_pass_at_k:
179 | self.metrics_dict[f"pass@{additional_k}"]= MeanMetric()
180 |
181 | if self.eval_greedy_search:
182 | self.metrics_dict["greedy_exec_acc"]= MeanMetric()
183 | self.metrics_dict["greedy_exec_rate"]= MeanMetric()
184 |
185 | if self.measure_dedup_metrics:
186 | # evaluation without the overlap
187 | self.val_overlap_ids: Dict[str, List[int]] = dict()
188 | self.val_overlap_ids["2"] = get_overlap_example_ids("val", 2)
189 | self.val_overlap_ids["4"] = get_overlap_example_ids("val", 4)
190 | self.val_overlap_ids["8"] = get_overlap_example_ids("val", 8)
191 |
192 | self.metrics_dict["dedup_exec_acc_2"]= MeanMetric()
193 | self.metrics_dict["dedup_exec_acc_4"]= MeanMetric()
194 | self.metrics_dict["dedup_exec_acc_8"]= MeanMetric()
195 |
196 |
197 | def forward( # type: ignore
198 | self,
199 | input_ids: torch.Tensor,
200 | attention_mask: torch.Tensor,
201 | metadata: Optional[List[Dict[str, Any]]] = None,
202 | ) -> List[Dict[str, Any]]:
203 | """
204 | The inference time behavior of the model.
205 |
206 | Args:
207 | input_ids [torch.Tensor]: Tokens from the context.
208 | metadata (Optional[List[Dict[str, Any]]], optional): All additional information, `List` for the batch. Defaults to None.
209 |
210 | Returns:
211 |             List[Dict[str, Any]]: one result `Dict` per example in the batch.
212 | """
213 |
214 | generated_token_ids = self.gpt.generate(input_ids=input_ids, attention_mask=attention_mask, do_sample=True,
215 | max_length=input_ids.shape[1]+self.max_gen_len,
216 | temperature=self.sampling_temp)
217 |
218 | generated_token_ids = generated_token_ids[:, input_ids.shape[1]:]
219 |
220 | generated_strs = self.tokenizer.batch_decode(generated_token_ids)
221 |
222 |         # truncate at the first eos token, consistent with the truncation used in the codex prompting experiments
223 | generated_programs = [s.split(self.tokenizer.eos_token)[0] for s in generated_strs]
224 |
225 | output_dicts = [{"generated_program": generated_programs[i], "metadata": metadata[i]} \
226 | for i in range(len(generated_programs))]
227 |
228 | if self.eval_greedy_search:
229 | generated_token_ids = self.gpt.generate(input_ids=input_ids, attention_mask=attention_mask, do_sample=False,
230 | max_length=input_ids.shape[1]+self.max_gen_len)
231 | generated_token_ids = generated_token_ids[:, input_ids.shape[1]:]
232 | generated_strs = self.tokenizer.batch_decode(generated_token_ids)
233 |             # truncate at the first eos token, consistent with the truncation used in the codex prompting experiments
234 | generated_programs = [s.split(self.tokenizer.eos_token)[0] for s in generated_strs]
235 |
236 | for i in range(len(metadata)):
237 | output_dicts[i]["greedy_generated_program"] = generated_programs[i]
238 |
239 | # evaluate pass at k FIXME: a lot of overlapping code here
240 | if (self.eval_epoch % self.eval_pass_at_k_every_n_epochs == 0 \
241 | or self.eval_epoch < self.always_eval_pass_at_k_first_n_epochs) and self.pass_at_k > 1:
242 | generated_strs_list = [[] for _ in range(len(metadata))]
243 | remaining_k = self.pass_at_k
244 | while remaining_k > 0:
245 | generate_batch_size = min(remaining_k, self.max_generation_batches)
246 | remaining_k -= generate_batch_size
247 | batch_generated_token_ids = self.gpt.generate(input_ids=input_ids, attention_mask=attention_mask,
248 | do_sample=True,
249 | max_length=input_ids.shape[1]+self.max_gen_len,
250 | temperature=self.sampling_temp_at_k,
251 | num_return_sequences=generate_batch_size)
252 |
253 | batch_generated_token_ids = batch_generated_token_ids[:, input_ids.shape[1]:]
254 | batch_generated_strs = self.tokenizer.batch_decode(batch_generated_token_ids)
255 | batch_generated_programs = [s.split(self.tokenizer.eos_token)[0] for s in batch_generated_strs]
256 |
257 | for i in range(len(metadata)):
258 | generated_strs_list[i].extend(batch_generated_programs[i*generate_batch_size:(i+1)*generate_batch_size])
259 |
260 | for i in range(len(metadata)):
261 | output_dicts[i]["generated_k_programs"] = generated_strs_list[i]
262 |
263 |
264 | return output_dicts
265 |
266 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]:
267 | input_ids = batch["input_ids"]
268 | attention_mask = batch["attention_mask"]
269 | labels = batch["labels"] if "labels" in batch else input_ids
270 |
271 | gpt_result = self.gpt(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
272 |
273 | self.log("loss", gpt_result.loss, on_step=True, on_epoch=True)
274 | return {"loss": gpt_result.loss}
275 |
276 | def validation_step(self, batch: torch.Tensor, batch_idx: int) -> Dict[str, torch.Tensor]:
277 | # input_tokens, target_mask, context_tokens, target_tokens, metadata = batch
278 | return self.forward(batch["input_ids"], batch["attention_mask"], batch["metadata"])
279 |
280 | def validation_step_end(self, outputs: List[Dict[str, Any]]) -> None:
281 | # update the evaluation metrics
282 | for output_dict in outputs:
283 | exec_acc = execution_acc(output_dict["generated_program"], mathqa_execution, output_dict["metadata"]["answer"])
284 | program_len = len(list(filter(lambda x: not x.startswith("#") and not len(x.strip()) == 0,
285 | output_dict["generated_program"].split("\n"))))
286 | gold_program_len = len(list(filter(lambda x: not x.startswith("#") and not len(x.strip()) == 0,
287 | post_process_code(output_dict["metadata"]["code"]).split("\n"))))
288 |
289 | program_len_diff = program_len - gold_program_len
290 |
291 | self.metrics_dict["exec_acc"](exec_acc[0])
292 | self.metrics_dict["exec_rate"](exec_acc[1])
293 | self.metrics_dict["program_len_diff"](program_len_diff)
294 |
295 | # also save the results in the json output file
296 | output_dict["metrics"] = {"exec_acc": float(exec_acc[0]),
297 | "exec_rate": float(exec_acc[1]),
298 | "program_len_diff": float(program_len_diff)}
299 |
300 | # measuring conditional metrics if they are enabled
301 | if self.measure_dedup_metrics:
302 | task_id = int(output_dict["metadata"]["task_id"])
303 | for dedup_allow_k in ["2", "4", "8"]:
304 | if not task_id in self.val_overlap_ids[dedup_allow_k]:
305 | self.metrics_dict[f"dedup_exec_acc_{dedup_allow_k}"](exec_acc[0])
306 |
307 | if self.eval_greedy_search:
308 | exec_acc = execution_acc(output_dict["greedy_generated_program"], mathqa_execution,
309 | output_dict["metadata"]["answer"])
310 |
311 | self.metrics_dict["greedy_exec_acc"](exec_acc[0])
312 | self.metrics_dict["greedy_exec_rate"](exec_acc[1])
313 |
314 | output_dict["metrics"].update({"greedy_exec_acc": float(exec_acc[0]),
315 | "greedy_exec_rate": float(exec_acc[1])})
316 |
317 |             # canonicalization of the states to avoid errors when saving the modules
318 | if "generated_program_state_list" in output_dict:
319 | for state_dict in output_dict["generated_program_state_list"]:
320 | if state_dict is not None:
321 | for key, value in state_dict.items():
322 | if isinstance(value, ModuleType):
323 | state_dict[key] = str(value)
324 |
325 | # save the outputs to the model
326 | self.predictions.extend(outputs)
327 |
328 | def validation_epoch_end_extra(self, outputs: List[Dict[str, Any]]) -> None:
329 | # compute the eval_at_k metrics
330 | if (self.eval_epoch % self.eval_pass_at_k_every_n_epochs == 0 \
331 | or self.eval_epoch < self.always_eval_pass_at_k_first_n_epochs) and self.pass_at_k > 1:
332 | print("evaluating pass at k...")
333 |
334 | all_generated_k_programs = [p["generated_k_programs"] for p in self.predictions]
335 |             all_generated_k_programs_flattened = [item for sublist in all_generated_k_programs for item in sublist]
336 |             gold_answers = [p["metadata"]["answer"] for p in self.predictions]
337 | 
338 |             result_list, pct_unique_progs = batch_execution_acc(all_generated_k_programs_flattened,
339 |                 mathqa_execution, gold_answers, len(self.predictions), self.pass_at_k)
340 |
341 | self.metrics_dict["unique_pct_in_k"](pct_unique_progs)
342 | for acc_at_k, pass_at_k in result_list:
343 | self.metrics_dict[f"acc@{self.pass_at_k}"](acc_at_k)
344 | self.metrics_dict[f"pass@{self.pass_at_k}"](pass_at_k)
345 |
346 | if len(self.additional_pass_at_k) > 0:
347 | for additional_k in self.additional_pass_at_k:
348 | correct = int(self.pass_at_k * acc_at_k)
349 | estimated_pass_at_k = estimate_pass_at_k(self.pass_at_k, correct, additional_k)
350 | self.metrics_dict[f"pass@{additional_k}"](estimated_pass_at_k)
351 |
352 | self.eval_epoch += 1
353 |
354 | def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:
355 | # extra steps for using the predictions
356 | self.validation_epoch_end_extra(outputs)
357 |
358 | # compute the metrics
359 | eval_metrics_dict = {}
360 | for k in self.metrics_dict.keys():
361 | if k.startswith("dedup_"):
362 | dedup_exec_acc = float(self.metrics_dict[k].compute())
363 | # for the dedup metrics, it's possible that a batch is all duplicates thus manually set nan to 0.0
364 | eval_metrics_dict[k] = dedup_exec_acc if not math.isnan(dedup_exec_acc) else 0.0
365 | else:
366 | eval_metrics_dict[k] = float(self.metrics_dict[k].compute())
367 |
368 |         # log and save the evaluation metrics
369 | print(f"validation result: {eval_metrics_dict}")
370 | self.log_dict(eval_metrics_dict)
371 |
372 | # reset all the metrics
373 | for k in self.metrics_dict.keys():
374 | self.metrics_dict[k].reset()
375 |
376 | # save the predictions
377 | save_pred_file_path = os.path.join(self.trainer.log_dir,
378 | f'predictions_step_{self.trainer.global_step}_rank_{self.trainer.global_rank}.jsonl')
379 | with open(save_pred_file_path, 'w+') as f:
380 | for prediction in self.predictions:
381 | f.write(json.dumps(prediction)+'\n')
382 | print(f"{len(self.predictions)} predictions saved to {save_pred_file_path}")
383 |
384 | # reset the predictions
385 | self.predictions = []
386 |
387 | # FIXME: debug setting only
388 | # self.sampling_temp += 0.1
389 | # self.sampling_temp_at_k += 0.2
390 | # print(f"sampling temp is now {self.sampling_temp}, sampling temp at k is now {self.sampling_temp_at_k}")
391 |
392 | def test_step(self, batch: torch.Tensor, batch_idx: int) -> Dict[str, torch.Tensor]:
393 | raise NotImplementedError
394 |
395 | def on_fit_start(self) -> None:
396 | # save the code using wandb
397 | if self.logger:
398 | # if logger is initialized, save the code
399 | self.logger[0].log_code()
400 | else:
401 | print("logger is not initialized, code will not be saved")
402 |
403 | return super().on_fit_start()
404 |
405 | def configure_optimizers(self):
406 | optimizer = AdamW(self.parameters(), **self.opt_params)
407 | if self.lrs_params["name"] == "cosine":
408 | lr_scheduler = get_cosine_schedule_with_warmup(optimizer, **self.lrs_params["init_args"])
409 | elif self.lrs_params["name"] == "linear":
410 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, **self.lrs_params["init_args"])
411 | elif self.lrs_params["name"] == "constant":
412 | lr_scheduler = get_constant_schedule_with_warmup(optimizer, **self.lrs_params["init_args"])
413 | else:
414 | raise ValueError(f"lr_scheduler {self.lrs_params} is not supported")
415 |
416 | return {"optimizer": optimizer,
417 | "lr_scheduler": {
418 | "scheduler": lr_scheduler,
419 | "interval": "step"
420 | }
421 | }
--------------------------------------------------------------------------------
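For orientation, a minimal construction sketch of the model above. The model name and hyperparameters are illustrative placeholders (the real values are presumably set in the training_configs/*.yaml files), and the transformer name must be one that get_gpt supports:

from lightning_modules.models.gpt_seq2seq_model import GptSeq2SeqModel

model = GptSeq2SeqModel(
    transformer_model_name="EleutherAI/gpt-neo-125M",   # placeholder
    max_gen_len=256,
    sampling_temp=0.2,
    pass_at_k=1,
    optimizer={"init_args": {"lr": 1e-4, "weight_decay": 0.01}},
    lr_scheduler={"name": "linear",
                  "init_args": {"num_warmup_steps": 100, "num_training_steps": 10000}},
)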
/lightning_modules/models/gpt_stmt_mml_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import os
4 | import json
5 | import random
6 |
7 | from itertools import chain
8 | from concurrent.futures import ProcessPoolExecutor as Pool
9 |
10 | from typing import Optional, Dict, Any, Tuple, List, Set, Union
11 | from torch.nn import CrossEntropyLoss
12 |
13 | from torchmetrics import MeanMetric
14 | from pytorch_lightning import LightningModule
15 |
16 | from .gpt_stmt_state_model import GptStmtStateModel
17 | from .gpt_seq2seq_model import post_process_code
18 | from .gpt_util import left_pad_sequences
19 | from execution.execution_evaluation import execution_acc, mathqa_execution, batch_exec_programs, mathqa_answer_eq
20 |
21 |
22 | class GptStmtMmlModel(GptStmtStateModel):
23 | def __init__(self,
24 | transformer_model_name: str,
25 | load_gold_programs: bool = True,
26 | on_policy_sample_num: Union[int, float] = 5,
27 | on_policy_sample_temp: float = 0.8,
28 | max_sampling_len: int = 100,
29 | length_diff_tolerance: int = 100,
30 | marg_set_size: int = 100,
31 | max_buffer_size: int = 100,
32 | load_samples_file: str = None,
33 | exclude_context_loss: bool = False,
34 | beta_smoothing: float = 1.0,
35 | mle_lambda: float = 1.0,
36 | mml_lambda: float = 0.0,
37 | mle_aug_norm: bool = False,
38 | **kwargs) -> None:
39 | if "eval_with_states" in kwargs and kwargs["eval_with_states"]:
40 | raise ValueError("eval_with_states is not supported for GptStmtMmlModel")
41 |
42 | super().__init__(transformer_model_name, **kwargs)
43 |
44 | self.load_gold_programs = load_gold_programs
45 | self.on_policy_sample_num = on_policy_sample_num
46 | self.on_policy_sample_temp = on_policy_sample_temp
47 | self.max_sampling_len = max_sampling_len
48 | self.marg_set_size = marg_set_size
49 | self.max_buffer_size = max_buffer_size
50 | self.length_diff_tolerance = length_diff_tolerance
51 | self.exclude_context_loss = exclude_context_loss
52 | self.beta_smoothing = beta_smoothing
53 | assert self.beta_smoothing > 0.0 and self.beta_smoothing <= 1.0, \
54 | f"beta_smoothing must be in (0, 1], but got {self.beta_smoothing}"
55 | self.mle_lambda = mle_lambda
56 | self.mml_lambda = mml_lambda
57 | self.mle_aug_norm = mle_aug_norm
58 |
59 | # load the large sample of programs from the buffer
60 | self.loaded_samples: Dict[str, List[str]] = {}
61 | if load_samples_file is not None:
62 | self.load_samples_from_file(load_samples_file)
63 | print(f"loaded samples from {load_samples_file}")
64 |
65 | # the key being the task id and the value being the list of correct programs (in strs, *w/o* eos token)
66 | self.correct_program_buffer: Dict[str, Set[str]] = {}
67 |
68 | # define some debugging or eval metrics
69 | self.metrics_dict["pct_unique_programs"] = MeanMetric()
70 |
71 | def load_samples_from_file(self, file_path: str) -> None:
72 | with open(file_path, 'r') as f:
73 | for line in f:
74 | json_dict = json.loads(line)
75 | self.loaded_samples[json_dict["metadata"]["task_id"]] = json_dict["generated_k_programs"]
76 |
77 | def try_save_programs(self, generated_programs: List[str], task_ids: List[str],
78 | correct_answers: List[float], samples_per_task: int,
79 | log_unique_program_pct: bool = True, verbose: bool = False) -> None:
80 | # execute the programs first
81 | program_exec_results, n_unique_programs = batch_exec_programs(generated_programs, mathqa_execution)#, n_processes=5)
82 | if log_unique_program_pct:
83 | self.metrics_dict["pct_unique_programs"](n_unique_programs / (len(task_ids) * samples_per_task))
84 |
85 | # save the programs with correct results into the buffer if it's not already in there
86 | correct_count = 0
87 | saved_count = 0
88 | for i, exec_result in enumerate(program_exec_results):
89 | example_idx = i // samples_per_task
90 | if mathqa_answer_eq(exec_result, correct_answers[example_idx]):
91 | correct_count += 1
92 | task_id = task_ids[example_idx]
93 | generated_program = post_process_code(generated_programs[i], ast_back_parse=False)
94 | if generated_program not in self.correct_program_buffer[task_id]:
95 |                     # check whether the length difference tolerance is satisfied
96 | min_prog_len = min([len(prog.split("\n")) for prog in self.correct_program_buffer[task_id]])
97 | if len(generated_program.split("\n")) - min_prog_len <= self.length_diff_tolerance:
98 | # save the program into the buffer
99 | self.correct_program_buffer[task_id].add(generated_program)
100 | saved_count += 1
101 | if verbose:
102 | print(f"{len(generated_programs)} in total, {n_unique_programs} are unique, " + \
103 | f"{correct_count} programs are correct, saved {saved_count} programs into the buffer")
104 |
105 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]:
106 | # use task ids as the identifier
107 | task_ids = [ex["task_id"] for ex in batch["metadata"]]
108 |
109 | # add the gold programs to the buffer
110 | for i, example in enumerate(batch["metadata"]):
111 | task_id = task_ids[i]
112 | if task_id in self.correct_program_buffer:
113 | # already have loaded the correct programs for this task
114 | continue
115 | elif self.load_gold_programs:
116 | gold_program = example["code"]
117 | # the program in the buffer is always post-processed
118 | self.correct_program_buffer[task_id] = set((post_process_code(gold_program, ast_back_parse=False),))
119 |
120 | # load the program from the samples, if available
121 | if task_id in self.loaded_samples:
122 | loaded_examples_for_task = self.loaded_samples.pop(task_id) # this also removes the key
123 | self.try_save_programs(loaded_examples_for_task, [task_id], [batch["metadata"][i]["answer"]],
124 | len(loaded_examples_for_task), log_unique_program_pct=False)
125 |
126 | else:
127 | self.correct_program_buffer[task_id] = set()
128 |
129 | # do on-policy sampling for the current tasks
130 | input_ids = batch["input_ids"]
131 | attention_mask = batch["attention_mask"]
132 |
133 | if self.on_policy_sample_num > 0:
134 | with torch.no_grad():
135 | if not all([len(self.correct_program_buffer[task_id]) >= self.max_buffer_size for task_id in task_ids]):
136 | # generate the programs and get their execution results
137 | max_context_len = input_ids.size(1)
138 | output_seqs = self.gpt.generate(input_ids=input_ids, attention_mask=attention_mask,
139 | do_sample=True, max_new_tokens=self.max_sampling_len,
140 | num_return_sequences=self.on_policy_sample_num,
141 | temperature=self.on_policy_sample_temp)
142 | generated_seqs = output_seqs[:, max_context_len:].cpu()
143 | generated_programs = self.tokenizer.batch_decode(generated_seqs, skip_special_tokens=True)
144 |
145 | # try to save the programs
146 | correct_answers = [x["answer"] for x in batch["metadata"]]
147 | self.try_save_programs(generated_programs, task_ids, correct_answers, self.on_policy_sample_num)
148 |
149 | def sample_from_buffer(task_id: str) -> List[str]:
150 | cached_programs = list(self.correct_program_buffer[task_id])
151 | if len(cached_programs) <= self.marg_set_size:
152 | return cached_programs
153 | else:
154 | return random.sample(cached_programs, self.marg_set_size)
155 |
156 | # remove the left paddings first and concat the context and cached programs
157 | context_seqs = [input_ids[i, -context_len:] for i, context_len in enumerate(attention_mask.sum(dim=1))]
158 | cached_program_seqs: List[List[torch.Tensor]] = [[self.tokenizer(prog_str, return_tensors="pt")['input_ids'][0]
159 | for prog_str in sample_from_buffer(task_id)]
160 | for task_id in task_ids]
161 | cached_program_nums = [len(cached_programs) for cached_programs in cached_program_seqs]
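        # note: each context is repeated once per cached correct program for its task, so the
        # flattened batch built below contains sum(cached_program_nums) sequences; they are
        # scored in one forward pass and regrouped later (via torch.split) to form the marginal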
162 | flatten_input_ids = []
163 | for i, program_seqs in enumerate(cached_program_seqs):
164 | for program_seq in program_seqs:
165 | flatten_input_ids.append(torch.cat((context_seqs[i], program_seq.to(self.device),
166 | torch.tensor([self.tokenizer.eos_token_id], device=self.device)), dim=0))
167 | flatten_attention_mask = [torch.ones_like(flatten_ids) for flatten_ids in flatten_input_ids]
168 | flatten_input_ids = left_pad_sequences(flatten_input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
169 | flatten_attention_mask = left_pad_sequences(flatten_attention_mask, batch_first=True, padding_value=0)
170 |
171 | # exclude the loss from context by setting the labels of them to be -100
172 | if self.exclude_context_loss:
173 | flatten_labels = []
174 | for i, program_seqs in enumerate(cached_program_seqs):
175 | for j, program_seq in enumerate(program_seqs):
176 | concat_labels = torch.cat((-100 * torch.ones_like(context_seqs[i]), program_seq.to(self.device),
177 | torch.tensor([self.tokenizer.eos_token_id], device=self.device)), dim=0)
178 | flatten_labels.append(concat_labels)
179 | flatten_labels = left_pad_sequences(flatten_labels, batch_first=True, padding_value=self.tokenizer.pad_token_id)
180 | else:
181 | flatten_labels = flatten_input_ids
182 | assert flatten_labels.shape == flatten_input_ids.shape == flatten_attention_mask.shape, \
183 | f"{flatten_labels.shape}, {flatten_input_ids.shape}, {flatten_attention_mask.shape}"
184 |
185 | # go through the gpt model as if it's a new batch
186 | gpt_result = self.gpt(input_ids=flatten_input_ids, attention_mask=flatten_attention_mask, labels=flatten_labels)
187 |
188 | # reorganize the logits to compute individual program log probs
189 | shift_logits = gpt_result.logits[..., :-1, :].contiguous()
190 | shift_labels = flatten_labels[..., 1:].contiguous()
191 | loss_fct = CrossEntropyLoss(reduction="none")
192 | flatten_unreduced_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
193 | unreduced_loss = flatten_unreduced_loss.view(shift_labels.shape) * flatten_attention_mask[..., 1:]
194 |
195 | # compute the marginal log prob
196 | loss = self.get_mixed_mle_mml_loss(unreduced_loss, cached_program_nums)
197 | self.log("loss", loss, on_step=True, on_epoch=True)
198 | self.log("fcp_buffer_size", sum(cached_program_nums) / len(cached_program_nums), on_step=False, on_epoch=True)
199 |
200 | return {"loss": loss}
201 |
202 | def get_mixed_mle_mml_loss(self, unreduced_loss: torch.Tensor, cached_program_nums: List[int],
203 | log_prob_dist: bool = True) -> torch.Tensor:
204 | """
205 | Compute the loss for the MML and MLE.
206 | """
207 | # compute the marginal log prob and the sum of the log probs
208 | grouped_example_log_probs = torch.split(-self.beta_smoothing * torch.sum(unreduced_loss, dim=1), cached_program_nums)
209 | marginal_log_probs = torch.stack([-1.0 * torch.logsumexp(log_probs, dim=0) / self.beta_smoothing for log_probs in grouped_example_log_probs])
210 | norm_func = (lambda x: 1.0 ) if not self.mle_aug_norm else (lambda x: 1.0 / len(x))
211 | sum_log_probs = torch.stack([-norm_func(log_probs) * torch.sum(log_probs, dim=0) for log_probs in grouped_example_log_probs])
212 | loss = torch.mean(self.mml_lambda * marginal_log_probs + self.mle_lambda * sum_log_probs)
213 |
214 | if log_prob_dist:
215 | # some additional metrics to evaluate the distribution of the programs
216 | max_prob = [sorted(torch.exp(log_probs), reverse=True)[0] for log_probs in grouped_example_log_probs]
217 | second_max_prob = [sorted(torch.exp(log_probs), reverse=True)[1]
218 | if len(log_probs) > 1 else None for log_probs in grouped_example_log_probs]
219 | second_max_prob = list(filter(lambda x: x is not None, second_max_prob))
220 |
221 | max_prob_avg = float(torch.pow(torch.stack(max_prob).mean(), 1.0 / self.beta_smoothing))
222 | second_max_prob_avg = float(torch.pow(torch.stack(second_max_prob).mean(), 1.0 / self.beta_smoothing)) \
223 | if len(second_max_prob) > 0 else 0.0
224 |
225 | self.log("max_prob", max_prob_avg, on_step=False, on_epoch=True)
226 | self.log("second_max_prob", second_max_prob_avg, on_step=False, on_epoch=True)
227 |
228 | return loss
229 |
230 | def forward( # type: ignore
231 | self,
232 | input_ids: torch.Tensor,
233 | attention_mask: torch.Tensor,
234 | metadata: Optional[List[Dict[str, Any]]] = None,
235 | ) -> List[Dict[str, Any]]:
236 | return super(GptStmtStateModel, self).forward(input_ids, attention_mask, metadata)
237 |
238 | def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:
239 | # first do all the things that the base model do
240 | super().validation_epoch_end(outputs)
241 |
242 | # then save the buffer status
243 | save_buffer_file_path = os.path.join(self.trainer.log_dir,
244 | f'buffer_step_{self.trainer.global_step}_rank_{self.trainer.global_rank}.jsonl')
245 | with open(save_buffer_file_path, 'w+') as f:
246 | for task_id, program_seqs in self.correct_program_buffer.items():
247 | json_dict = {"task_id": task_id, "saved_programs": list(program_seqs)}
248 | f.write(json.dumps(json_dict) + '\n')
249 | print(f"buffer saved to {save_buffer_file_path}")
250 |
251 | def training_epoch_end(self, outputs) -> None:
252 | if not torch.distributed.is_initialized():
253 | print("training_epoch_end: not using distributed training")
254 | return
255 |
256 | # gather all the buffers from all processes
257 | world_size = torch.distributed.get_world_size()
258 | all_buffer_list = [{} for _ in range(world_size)]
259 | torch.distributed.all_gather_object(all_buffer_list, self.correct_program_buffer)
260 |
261 | # merge all the buffers
262 | prev_avg_buffer_size = sum(map(lambda x: len(x[1]), self.correct_program_buffer.items())) / len(self.correct_program_buffer)
263 | merged_buffer: Dict[str, Set[str]] = {}
264 | for buffer in all_buffer_list:
265 | for task_id, programs in buffer.items():
266 | if task_id not in merged_buffer:
267 | merged_buffer[task_id] = programs
268 | else:
269 | merged_buffer[task_id].update(programs)
270 |
271 | self.correct_program_buffer = merged_buffer
272 | after_avg_buffer_size = sum(map(lambda x: len(x[1]), self.correct_program_buffer.items())) / len(self.correct_program_buffer)
273 | print(f"buffer size increased from {prev_avg_buffer_size} to {after_avg_buffer_size}, by {after_avg_buffer_size - prev_avg_buffer_size}")
274 |
275 |
--------------------------------------------------------------------------------
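A self-contained numeric sketch (made-up loss values, beta_smoothing = 1.0) of the two terms that get_mixed_mle_mml_loss mixes for a single example:

import torch

# summed token-level losses for 3 cached correct programs of one task, i.e. -log p(z_i | x)
unreduced_loss = torch.tensor([[2.0], [2.5], [3.0]])  # shape: (num_programs, seq_len)
cached_program_nums = [3]                             # all three belong to the same example

log_probs = -torch.sum(unreduced_loss, dim=1)         # [-2.0, -2.5, -3.0]
mml_term = -torch.logsumexp(log_probs, dim=0)         # -log sum_i p(z_i | x)  ~= 1.32
mle_term = -torch.sum(log_probs, dim=0)               # -sum_i log p(z_i | x)   = 7.5
print(float(mml_term), float(mle_term))
# the training loss averages mml_lambda * mml_term + mle_lambda * mle_term over the examples
# in the batch, with the optional 1/beta smoothing and 1/|programs| normalization applied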
/lightning_modules/models/gpt_stmt_partial_mml_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import os
4 | import json
5 | import random
6 |
7 | from itertools import chain
8 | from concurrent.futures import ProcessPoolExecutor as Pool
9 |
10 | from typing import Optional, Dict, Any, Tuple, List, Set, Union
11 | from torch.nn import CrossEntropyLoss
12 |
13 | from torchmetrics import MeanMetric
14 | from pytorch_lightning import LightningModule
15 |
16 | from .gpt_stmt_mml_model import GptStmtMmlModel
17 | from .gpt_stmt_state_model import GptStmtStateModel
18 | from .gpt_seq2seq_model import post_process_code
19 | from .gpt_util import left_pad_sequences
20 | from execution.execution_evaluation import execution_acc, mathqa_execution, batch_exec_programs, mathqa_answer_eq
21 | from execution.program_tracing import get_execution_states, batch_program_tracing
22 | from execution.program_tracing import ProgState, ProgTrace, ProgTraceUnit, HashableProgState, Program
23 |
24 | def get_hashable_state(state_dict: ProgState) -> HashableProgState:
25 | if len(state_dict) == 0:
26 |         raise ValueError("state_dict is empty, check if the code is empty or is just a gap")
27 |
28 |     # get the values and sort them to make sure the order is the same FIXME: This only works for mathqa
29 | str_vars = [str(value) for _, value in state_dict.items()]
30 |
31 | # make it immutable to be hashable so that we can use it as a key
32 | return tuple(sorted(str_vars))
33 |
34 | def mathqa_identify_fully_correct_state(state: ProgState, gold_answer: Any):
35 | if "type" in state and "code" in state:
36 | raise ValueError("mathqa_identify_fully_correct_state function only accepts states, not tracing units")
37 |
38 | if "answer" not in state:
39 | return False
40 | else:
41 | return mathqa_answer_eq(state["answer"], gold_answer)
42 |
43 | def mathqa_identify_output(state: ProgState):
44 | if "type" in state and "code" in state:
45 | raise ValueError("mathqa_identify_output function only accepts states, not tracing units")
46 |
47 | if "answer" in state:
48 | return True
49 | else:
50 | return False
51 |
52 | def construct_program_from_trace(trace: ProgTrace) -> Program:
53 | code = "".join([unit.code for unit in trace])
54 | code_lite = post_process_code(code)
55 | return Program(code, code_lite, trace)
56 |
57 | def get_program_str_set(programs: List[Program]) -> Set[str]:
58 | return set([program.code_lite for program in programs])
59 |
60 | def not_as_prefix_of_programs(program: Program, programs: List[Program]) -> bool:
61 | for prog in programs:
62 | if prog.code_lite.startswith(program.code_lite):
63 | return False
64 | return True
65 |
66 | def get_empty_program() -> Program:
67 | return Program("", "", [])
68 |
69 | def get_num_stmts(program: Program) -> int:
70 | return len(list(filter(lambda x: x.type == "stmt", program.trace)))
71 |
72 | class GptStmtPartialMmlModel(GptStmtMmlModel):
73 | def __init__(self,
74 | transformer_model_name: str,
75 | n_pcp_samples: int = 1,
76 | prioritize_fcp: bool = True,
77 | length_diff_tolerance: int = 100,
78 | sampling_from_states: bool = False,
79 | sampling_full_prog_only: bool = False,
80 | gold_program_only: bool = False,
81 | fcp_only: bool = False,
82 | norm_marg_by_len: bool = False,
83 | **kwargs) -> None:
84 |
85 | super().__init__(transformer_model_name, **kwargs)
86 |
87 | self.n_pcp_samples = n_pcp_samples
88 | assert n_pcp_samples == 1, "currently only support n_pcp_samples = 1"
89 | self.prioritize_fcp = prioritize_fcp
90 | self.length_diff_tolerance = length_diff_tolerance
91 | self.sampling_from_states = sampling_from_states
92 | self.sampling_full_prog_only = sampling_full_prog_only
93 | self.gold_program_only = gold_program_only
94 | self.fcp_only = fcp_only
95 | self.norm_marg_by_len = norm_marg_by_len
96 |
97 | # for each task id as the key, the value dict maps states to known sub-programs
98 | self.state_programs_dict: Dict[str, Dict[HashableProgState, List[Program]]] = {}
99 |
100 | # redefine it here to use new type hints
101 | self.correct_program_buffer: Dict[str, List[Program]] = {}
102 |
103 | # this needs to be differentiated from the fully correct programs
104 | # because we do not sample for the fully correct programs
105 | self.partially_correct_program_buffer: Dict[str, List[Program]] = {}
106 |
107 |
108 | def save_program_by_trace(self, task_id: str, tracing_states: Union[ProgTrace, None],
109 | is_fully_correct: bool) -> str:
110 |         """ every trace passed in ends with a correct (or known partially-correct) state,
111 |             so we try to save all of its sub-programs by state
112 | 
113 |         return: the status of the saving attempt, one of
114 |         ["S3: not valid", "S4: existing fcp", "S5: existing pcp", "S6: saved"] """
115 |
116 | # nothing to save if the program executes to error
117 | if tracing_states is None:
118 | return "S3: not valid"
119 |
120 | # construct the program tuple
121 | trace_program = construct_program_from_trace(tracing_states)
122 |
123 | def check_buffer_and_remove(buffer: Dict[str, List[Program]], program: Program):
124 | # check any of the existing program is the prefix of this new program, if so, remove the shorter one
125 | programs_to_remove = []
126 | for i, saved_program in enumerate(buffer[task_id]):
127 | # empty program doesn't need to be removed
128 | if len(saved_program.code_lite) > 0 and program.code_lite.startswith(saved_program.code_lite):
129 | programs_to_remove.insert(0, i)
130 | for idx in programs_to_remove:
131 | buffer[task_id].pop(idx)
132 |
133 | # we only save the longest program
134 | if is_fully_correct:
135 | if trace_program.code_lite not in get_program_str_set(self.correct_program_buffer[task_id]):
136 | # check if any partial program is the prefix of this new fully correct program
137 | check_buffer_and_remove(self.partially_correct_program_buffer, trace_program)
138 | self.correct_program_buffer[task_id].append(trace_program)
139 | else:
140 | # there is nothing to save for the states dict since the full program is already in the buffer
141 | return "S4: existing fcp"
142 | elif not_as_prefix_of_programs(trace_program, self.partially_correct_program_buffer[task_id]) and \
143 | not_as_prefix_of_programs(trace_program, self.correct_program_buffer[task_id]):
144 | # check if any partial program is the prefix of this new partially correct program
145 | check_buffer_and_remove(self.partially_correct_program_buffer, trace_program)
146 | self.partially_correct_program_buffer[task_id].append(trace_program)
147 | else:
148 | return "S5: existing pcp"
149 |
150 | # we try to save all the sub-programs of the program by state, excluding the final correct state
151 | tracing_states_to_save = tracing_states[:-1] if is_fully_correct else tracing_states
152 | for i, trace_unit in enumerate(tracing_states_to_save):
153 | if trace_unit.type != "stmt":
154 | continue
155 |
156 | sub_program = construct_program_from_trace(tracing_states_to_save[:i+1])
157 |
158 | # check the state
159 | hashable_state = get_hashable_state(trace_unit.state)
160 | if hashable_state not in self.state_programs_dict[task_id]:
161 | self.state_programs_dict[task_id][hashable_state] = [sub_program]
162 | else:
163 | if sub_program.code_lite not in get_program_str_set(self.state_programs_dict[task_id][hashable_state]):
164 | self.state_programs_dict[task_id][hashable_state].append(sub_program)
165 |
166 | return "S6: saved"
167 |
168 | def check_and_save_partially_correct_program(self, task_id: str,
169 | program_trace: Union[ProgTrace, None],
170 | gold_answer: Any) -> str:
171 |         """ check if there is a sub-program that is worth saving """
172 |
173 | if program_trace is None:
174 | return "S0: not executable"
175 |
176 |         # see if any of the states matches the saved correct states
177 | saved_states: Dict[HashableProgState, List[Program]] = self.state_programs_dict.get(task_id)
178 |
179 | # identify the longest partially (or fully) correct program that fits the constraints
180 | furthest_correct_state_idx = -1
181 | stmt_num = 0
182 | for i, tracing_unit in enumerate(program_trace):
183 | if tracing_unit.type != "stmt":
184 | continue
185 | else:
186 | stmt_num += 1
187 |
188 | if len(tracing_unit.state) == 0:
189 | continue # most likely because a comment is generated as the first line
190 | else:
191 | hashable_state = get_hashable_state(tracing_unit.state)
192 |
193 | reach_full_correct_state = mathqa_identify_fully_correct_state(tracing_unit.state, gold_answer)
194 | reach_output = mathqa_identify_output(tracing_unit.state)
195 | if reach_full_correct_state:
196 | min_fcp_len = min(list(map(get_num_stmts, self.correct_program_buffer[task_id])))
197 | if stmt_num - min_fcp_len <= self.length_diff_tolerance:
198 | furthest_correct_state_idx = i
199 | break
200 | else:
201 | return "S1: fully correct but length exceeds tolerance"
202 | elif reach_output and not reach_full_correct_state:
203 |                     # if the program produces an incorrect output, we don't need to save it (but the prefix might still be useful)
204 | break
205 | elif hashable_state in saved_states:
206 | min_pcp_len = min(list(map(get_num_stmts, saved_states[hashable_state])))
207 | if stmt_num - min_pcp_len <= self.length_diff_tolerance:
208 | furthest_correct_state_idx = i
209 |
210 |         # save all the sub-programs of this program by state
211 | if furthest_correct_state_idx != -1:
212 | is_fully_correct = mathqa_identify_fully_correct_state(program_trace[furthest_correct_state_idx].state, gold_answer)
213 | return self.save_program_by_trace(task_id, program_trace[:furthest_correct_state_idx + 1], is_fully_correct)
214 | else:
215 |             return "S2: not partially correct or partially correct but length exceeds tolerance"
216 |
217 | def concat_context_with_multiple_programs(self, context_input_ids: torch.Tensor,
218 | context_attention_mask: torch.Tensor,
219 | programs: List[List[Program]],
220 |                                                  is_fully_correct: List[List[bool]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
221 | """ concatenate each context tensor with multiple programs """
222 |
223 | # remove the left paddings first and concat the context and cached programs
224 | context_seqs = [context_input_ids[i, -context_len:]
225 | for i, context_len in enumerate(context_attention_mask.sum(dim=1))]
226 | cached_program_seqs: List[List[torch.Tensor]] = [[self.tokenizer(prog.code, return_tensors="pt")['input_ids'][0]
227 | for prog in task_programs]
228 | for task_programs in programs]
229 | flatten_input_ids = []
230 | for i, program_seqs in enumerate(cached_program_seqs):
231 | for j, program_seq in enumerate(program_seqs):
232 | concat_context_program = torch.cat([context_seqs[i], program_seq.to(dtype=context_seqs[i].dtype, device=self.device)], dim=0)
233 | if is_fully_correct[i][j]:
234 | concat_context_program = torch.cat((concat_context_program, torch.tensor([self.tokenizer.eos_token_id], device=self.device)), dim=0)
235 | flatten_input_ids.append(concat_context_program)
236 | flatten_attention_mask = [torch.ones_like(flatten_ids) for flatten_ids in flatten_input_ids]
237 |
238 | flatten_input_ids = left_pad_sequences(flatten_input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
239 | flatten_attention_mask = left_pad_sequences(flatten_attention_mask, batch_first=True, padding_value=0)
240 |
241 |         # exclude the context tokens from the loss by setting their labels to -100
242 | if self.exclude_context_loss:
243 | flatten_labels = []
244 | for i, program_seqs in enumerate(cached_program_seqs):
245 | for j, program_seq in enumerate(program_seqs):
246 | concat_labels = torch.cat([-100 * torch.ones_like(context_seqs[i]),
247 | program_seq.to(dtype=context_seqs[i].dtype, device=self.device)], dim=0)
248 | if is_fully_correct[i][j]:
249 | concat_labels = torch.cat((concat_labels, torch.tensor([self.tokenizer.eos_token_id], device=self.device)), dim=0)
250 | flatten_labels.append(concat_labels)
251 | flatten_labels = left_pad_sequences(flatten_labels, batch_first=True, padding_value=self.tokenizer.pad_token_id)
252 | else:
253 | flatten_labels = flatten_input_ids
254 |
255 | assert flatten_input_ids.shape[0] == len(flatten_attention_mask) == sum([len(progs) for progs in programs]) == flatten_labels.shape[0]
256 |
257 | return flatten_input_ids, flatten_attention_mask, flatten_labels
258 |
259 | def get_marg_program_set(self, task_id: str):
260 | """ get the programs to marginalize over for each of the task according to the specific settings """
261 |
262 | # TODO: (maybe with fancier policy) first fill in the fully correct programs,
263 | # if not enough, fill in the partially correct programs
264 | programs, is_fully_correct = [], []
265 | programs.extend(self.correct_program_buffer[task_id])
266 | is_fully_correct.extend([True] * len(self.correct_program_buffer[task_id]))
267 |
268 | if self.gold_program_only:
269 | return programs[:1], is_fully_correct[:1]
270 | if self.fcp_only:
271 | return programs[:self.marg_set_size], is_fully_correct[:self.marg_set_size]
272 |
273 | non_empty_pcp_buffer = list(filter(lambda x: len(x.code_lite) > 0, self.partially_correct_program_buffer[task_id]))
274 | programs.extend(non_empty_pcp_buffer)
275 | is_fully_correct.extend([False] * len(non_empty_pcp_buffer))
276 |
277 | if len(programs) > self.marg_set_size:
278 | if self.prioritize_fcp:
279 | return programs[:self.marg_set_size], is_fully_correct[:self.marg_set_size]
280 | else:
281 |                 # randomly sample the program indices, regardless of being pcp or fcp
282 | idx_sample = random.sample(range(len(programs)), self.marg_set_size)
283 | return [programs[i] for i in idx_sample], [is_fully_correct[i] for i in idx_sample]
284 | else:
285 | return programs, is_fully_correct
286 |
287 |     def get_samples_for_completion(self, task_ids: List[str]) -> List[List[Program]]:
288 |         """ sample some unfinished programs from the buffer to do the completion sampling """
289 |         if self.sampling_from_states:
290 |             # first sample a state that the saved programs reach
291 |             states = [random.sample(list(self.state_programs_dict[task_id].keys()), 1)[0] for task_id in task_ids]
292 |
293 | # then we sample a program that reaches the state
294 | programs = [random.sample(self.state_programs_dict[task_id][state], self.n_pcp_samples)
295 | for task_id, state in zip(task_ids, states)]
296 | elif self.sampling_full_prog_only:
297 |             # use the blank program so that a full program is sampled from scratch
298 | programs = [[self.partially_correct_program_buffer[task_id][0]] for task_id in task_ids]
299 | else:
300 |             # sample a program from the pcp buffer
301 | programs = [random.sample(self.partially_correct_program_buffer[task_id], self.n_pcp_samples)
302 | for task_id in task_ids]
303 |
304 | return programs
305 |
306 |
307 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]:
308 | # use task ids as the identifier
309 | task_ids = [ex["task_id"] for ex in batch["metadata"]]
310 |
311 | # add the gold programs to the correct program buffer and empty program to the partially correct program buffer
312 | for i, example in enumerate(batch["metadata"]):
313 | task_id = task_ids[i]
314 |             # need to initialize the three data structures in this order
315 |             if task_id not in self.state_programs_dict:
316 |                 # so that the empty program can be sampled when sampling programs to be completed
317 | self.state_programs_dict[task_id] = {"NULL": [get_empty_program()]}
318 |
319 | if task_id not in self.partially_correct_program_buffer:
320 | # add empty program to the partially correct program buffer
321 | self.partially_correct_program_buffer[task_id] = [get_empty_program()]
322 |
323 | if task_id not in self.correct_program_buffer:
324 | # add the gold programs to the correct program buffer
325 | self.correct_program_buffer[task_id] = []
326 | tracing_states = get_execution_states(example["code"])
327 | return_msg = self.save_program_by_trace(task_id, tracing_states, is_fully_correct=True)
328 | assert return_msg == "S6: saved"
329 |
330 | # do on-policy sampling for the current tasks from the program prefixes
331 | context_input_ids = batch["input_ids"]
332 | context_attention_mask = batch["attention_mask"]
333 |
334 | if not self.gold_program_only:
335 | # TODO: maybe with fancier sampling methods, e.g., sampling from in-between the program
336 | pcp_samples: List[List[Program]] = self.get_samples_for_completion(task_ids)
337 | pcp_input_ids, pcp_attention_mask, _ = self.concat_context_with_multiple_programs(context_input_ids,
338 | context_attention_mask, pcp_samples,
339 | is_fully_correct=[[False]*len(progs) for progs in pcp_samples])
340 |
341 | with torch.no_grad():
342 | # generate the programs and get their execution results
343 | max_context_len = pcp_input_ids.size(1)
344 | output_seqs = self.gpt.generate(input_ids=pcp_input_ids, attention_mask=pcp_attention_mask,
345 | do_sample=True, max_new_tokens=self.max_sampling_len,
346 | num_return_sequences=self.on_policy_sample_num,
347 | temperature=self.on_policy_sample_temp)
348 | generated_seqs = output_seqs[:, max_context_len:].cpu()
349 | generated_program_completions = self.tokenizer.batch_decode(generated_seqs, skip_special_tokens=True)
350 | generated_programs = [pcp_samples[i // self.on_policy_sample_num][0].code + completion # FIXME: only works when self.n_pcp_samples == 1
351 | for i, completion in enumerate(generated_program_completions)]
352 | program_traces = batch_program_tracing(generated_programs)
353 |
354 |             # save the programs with correct results into the buffer if they're not already in there
355 | results_count_dict = {f"S{i}": 0 for i in range(7)}
356 | for i, program_trace in enumerate(program_traces):
357 | example_idx = i // self.on_policy_sample_num
358 | gold_answer = batch["metadata"][example_idx]["answer"]
359 | task_id = task_ids[example_idx]
360 | result_msg = self.check_and_save_partially_correct_program(task_id, program_trace, gold_answer)
361 | results_count_dict[result_msg[:2]] += 1
362 |
363 | for k in results_count_dict.keys():
364 | self.log(f"save_msg_{k}", results_count_dict[k] / len(program_traces), on_step=False, on_epoch=True)
365 |
366 | # select the set of programs to marginalize over for each task and tokenize + concatenate with context
367 | marg_programs: List[List[Program]] = []
368 | marg_is_fully_correct: List[List[bool]] = []
369 | marg_program_nums: List[int] = []
370 | for task_id in task_ids:
371 | programs, is_fully_correct = self.get_marg_program_set(task_id)
372 | marg_programs.append(programs)
373 | marg_is_fully_correct.append(is_fully_correct)
374 | marg_program_nums.append(len(programs))
375 | flatten_input_ids, flatten_attention_mask, flatten_labels = self.concat_context_with_multiple_programs(context_input_ids,
376 | context_attention_mask, marg_programs, marg_is_fully_correct)
377 |
378 | # go through the gpt model as if it's a new batch
379 | gpt_result = self.gpt(input_ids=flatten_input_ids, attention_mask=flatten_attention_mask, labels=flatten_labels)
380 |
381 | # reorganize the logits to compute individual program log probs
382 | shift_logits = gpt_result.logits[..., :-1, :].contiguous()
383 | shift_labels = flatten_labels[..., 1:].contiguous()
384 | loss_fct = CrossEntropyLoss(reduction="none")
385 | flatten_unreduced_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
386 |         unreduced_loss = flatten_unreduced_loss.view(shift_labels.shape) * flatten_attention_mask[..., 1:]
387 |
388 | if self.norm_marg_by_len:
389 | # normalizing by length so we can avoid favoring shorter (e.g., partially-correct) programs
390 | seq_lens = torch.sum(flatten_attention_mask, dim=-1, keepdim=True)
391 | grouped_seq_len = torch.split(seq_lens, marg_program_nums)
392 | max_group_len = torch.cat([torch.full_like(seq_len, torch.max(seq_len)) for seq_len in grouped_seq_len], dim=0)
393 | assert seq_lens.shape == max_group_len.shape
394 | unreduced_loss = unreduced_loss / seq_lens * max_group_len
395 |
396 | # compute the marginal log prob
397 | loss = self.get_mixed_mle_mml_loss(unreduced_loss, marg_program_nums)
398 | self.log("loss", loss, on_step=True, on_epoch=True)
399 |
400 | self.log("marg_size", sum(marg_program_nums) / len(marg_program_nums), on_step=False, on_epoch=True)
401 | self.log("fcp_buffer_size", sum([len(self.correct_program_buffer[task_id]) for task_id in task_ids]) \
402 | / len(task_ids), on_step=False, on_epoch=True)
403 | self.log("pcp_buffer_size", sum([len(self.partially_correct_program_buffer[task_id]) for task_id in task_ids]) \
404 | / len(task_ids), on_step=False, on_epoch=True)
405 | self.log("state_prog_dict_size", sum([len(self.state_programs_dict[task_id]) for task_id in task_ids]) \
406 | / len(task_ids), on_step=False, on_epoch=True)
407 |
408 | return {"loss": loss}
409 |
410 | def forward( # type: ignore
411 | self,
412 | input_ids: torch.Tensor,
413 | attention_mask: torch.Tensor,
414 | metadata: Optional[List[Dict[str, Any]]] = None,
415 | ) -> List[Dict[str, Any]]:
416 | return super(GptStmtStateModel, self).forward(input_ids, attention_mask, metadata)
417 |
418 | def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:
419 |         # first do all the things that the base model does
420 | super(GptStmtMmlModel, self).validation_epoch_end(outputs)
421 |
422 | # then save the buffer status for the fully correct programs and the partially correct programs
423 | save_buffer_file_path = os.path.join(self.trainer.log_dir,
424 | f'buffer_step_{self.trainer.global_step}_rank_{self.trainer.global_rank}.jsonl')
425 | with open(save_buffer_file_path, 'w+') as f:
426 | for task_id, programs in self.correct_program_buffer.items():
427 | json_dict = {"task_id": task_id,
428 | "saved_fcp_programs": list([prog.code for prog in programs]),
429 | "saved_pcp_programs": list([prog.code for prog in self.partially_correct_program_buffer[task_id]])}
430 | f.write(json.dumps(json_dict) + '\n')
431 | print(f"buffer saved to {save_buffer_file_path}")
432 |
433 | def merge_buffers(self, is_fcp_buffer: bool):
434 | # first identify which buffer to merge (full or partial)
435 | buffer_to_merge = self.correct_program_buffer if is_fcp_buffer else self.partially_correct_program_buffer
436 |
437 | world_size = torch.distributed.get_world_size()
438 | all_buffer_list: List[Dict[str, List[Program]]] = [{} for _ in range(world_size)]
439 | torch.distributed.all_gather_object(all_buffer_list, buffer_to_merge)
440 |
441 | # merge all the buffers
442 | prev_avg_buffer_size = sum(map(lambda x: len(x[1]), buffer_to_merge.items())) / len(buffer_to_merge)
443 |
444 | for buffer in all_buffer_list:
445 | for task_id, programs in buffer.items():
446 | if task_id not in buffer_to_merge:
447 | assert task_id not in self.correct_program_buffer and \
448 | task_id not in self.partially_correct_program_buffer and \
449 | task_id not in self.state_programs_dict, \
450 | f"task_id {task_id} should not be in any buffer"
451 |
452 | # init all three data structures
453 | # TODO: this can be optimized by simply assigning the three data structures since they are empty
454 | self.state_programs_dict[task_id] = {"NULL": [get_empty_program()]}
455 | self.partially_correct_program_buffer[task_id] = [get_empty_program()]
456 | self.correct_program_buffer[task_id] = []
457 |
458 | for program in programs:
459 | self.save_program_by_trace(task_id, program.trace, is_fully_correct=is_fcp_buffer)
460 |
461 | after_avg_buffer_size = sum(map(lambda x: len(x[1]), buffer_to_merge.items())) / len(buffer_to_merge)
462 | print(f"{'fcp' if is_fcp_buffer else 'pcp'} buffer size increased " \
463 | f"from {prev_avg_buffer_size} to {after_avg_buffer_size}, " \
464 | f"by {after_avg_buffer_size - prev_avg_buffer_size}")
465 |
466 |
467 | def training_epoch_end(self, outputs) -> None:
468 | if not torch.distributed.is_initialized():
469 | print("training_epoch_end: not using distributed training")
470 | return
471 |
472 | self.merge_buffers(is_fcp_buffer=True)
473 | self.merge_buffers(is_fcp_buffer=False)
474 |
475 |
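A minimal, self-contained sketch of the core idea behind `save_program_by_trace` above: every statement-level prefix of a (partially) correct program is indexed by the execution state it reaches, so later samples that hit the same state can be recognized as partially correct. The helper and toy trace below are illustrative only and do not reuse the module's actual `Program`/`ProgTrace` types.

    # illustrative sketch: index program prefixes by the execution state they reach
    from typing import Dict, List, Tuple

    HashableState = Tuple[Tuple[str, float], ...]   # e.g. (("n0", 3.0), ("t0", 6.0))

    def save_prefixes_by_state(trace: List[Tuple[str, HashableState]],
                               state_programs: Dict[HashableState, List[str]]) -> None:
        prefix_lines: List[str] = []
        for stmt, state in trace:
            prefix_lines.append(stmt)
            sub_program = "\n".join(prefix_lines)
            known = state_programs.setdefault(state, [])
            if sub_program not in known:   # de-duplicate, as get_program_str_set does above
                known.append(sub_program)

    state_programs: Dict[HashableState, List[str]] = {}
    save_prefixes_by_state([("n0=3", (("n0", 3.0),)),
                            ("t0=n0*2", (("n0", 3.0), ("t0", 6.0)))],
                           state_programs)
    # every reachable state now maps to the program prefixes known to reach it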
--------------------------------------------------------------------------------
/lightning_modules/models/gpt_stmt_state_model.py:
--------------------------------------------------------------------------------
1 | from numpy import dtype
2 | import torch
3 | import json
4 | import os
5 | import torch.nn.functional as F
6 | import pytorch_lightning as pl
7 |
8 | from typing import Optional, Dict, Any, Tuple, List
9 | from transformers.optimization import AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup
10 | from torch.nn import CrossEntropyLoss
11 |
12 | from torchmetrics import MeanMetric
13 | from pytorch_lightning import LightningModule
14 |
15 | from .gpt_util import get_gpt, sanity_check, left_pad_sequences
16 | from .gpt_seq2seq_model import GptSeq2SeqModel
17 | from execution.execution_evaluation import execution_acc, mathqa_execution
18 | from execution.execution_evaluation import execution_eval_at_k, batch_execution_acc
19 | from execution.program_tracing import exec_stmt_in_context, get_state_repr, is_trivial_state
20 | from execution.safe_execution_util import canonicalize_var_dict
21 |
22 | class GptStmtStateModel(GptSeq2SeqModel):
23 | def __init__(self,
24 | transformer_model_name: str,
25 | max_stmt_len: int = 20,
26 | max_stmt_num: int = 20,
27 | max_context_len: int = 1024,
28 | eval_with_states: bool = False,
29 | skip_trivial_states: bool = False,
30 | **kwargs) -> None:
31 | super().__init__(transformer_model_name, **kwargs)
32 |
33 | self.max_stmt_len = max_stmt_len
34 | self.max_stmt_num = max_stmt_num
35 | self.max_context_len = max_context_len
36 | self.eval_with_states = eval_with_states
37 | self.skip_trivial_states = skip_trivial_states
38 |
39 | def forward( # type: ignore
40 | self,
41 | input_ids: torch.Tensor,
42 | attention_mask: torch.Tensor,
43 | metadata: Optional[List[Dict[str, Any]]] = None,
44 | ) -> List[Dict[str, Any]]:
45 | """
46 | The inference time behavior of the model.
47 |
48 | Args:
49 | input_ids [torch.Tensor]: Tokens from the context.
50 | metadata (Optional[List[Dict[str, Any]]], optional): All additional information, `List` for the batch. Defaults to None.
51 |
52 | Returns:
53 | Dict[str, Any]: results saved in a `Dict` object.
54 | """
55 | output_dicts = [{"metadata": metadata[i]} for i in range(len(metadata))]
56 |
57 | # iteratively generate the stmts until eos is generated
58 | batch_size = len(metadata)
59 | completed_programs = [[] for _ in range(batch_size)]
60 | program_states_list = [[{}] for _ in range(batch_size)]
61 | incomplete_program_indices = list(range(batch_size))
62 | for stmt_idx in range(self.max_stmt_num):
63 | # this is how many examples are left in the batch
64 | inner_batch_size = len(input_ids)
65 |
66 | max_context_len = input_ids.size(1)
67 | context_len_list = attention_mask.sum(dim=1)
68 |
69 | output_seqs = self.gpt.generate(input_ids=input_ids, attention_mask=attention_mask,
70 | do_sample=False, max_length=self.max_gen_len+max_context_len) # FIXME: this should be self.max_stmt_len
71 | # temperature=self.sampling_temp) # NOTE: now we assume only one seq is returned
72 |
73 | # remove the context and the tokens after the first newline token in the generated seq
74 | generated_seqs = [output_seqs[:, max_context_len:][i] for i in range(inner_batch_size)]
75 | for i, output_seq in enumerate(generated_seqs):
76 | nl_indices = (output_seq == self.tokenizer._convert_token_to_id(self.tokenizer.tokenize("\n")[0])).nonzero(as_tuple=True)[0]
77 | if len(nl_indices) > 0:
78 | first_nl_idx = int(nl_indices[0])
79 | generated_seqs[i] = output_seq[:first_nl_idx+1] # +1 because we need to include the newline token
80 | else:
81 |                 # this means that the generation hit the statement length limit before the first newline
82 |                 # token, or an eos token was generated early. Either way, we need to stop the generation,
83 |                 # so we append a (possibly redundant) eos token to the end of the generated seq.
84 | generated_seqs[i] = torch.cat([output_seq, torch.tensor([self.tokenizer.eos_token_id], device=output_seq.device)])
85 |
86 | # concat the context with the generated seqs
87 | # full_seqs = []
88 | # for i in range(inner_batch_size):
89 | # full_seq = torch.cat([input_ids[i][:context_len_list[i]], generated_seqs[i]])
90 | # full_seqs.append(full_seq)
91 |
92 | # check if the end_of_sequence token is in the generated output
93 | incomplete_output_seqs = []
94 | incomplete_program_indices_new = []
95 | for i, output_seq in enumerate(generated_seqs):
96 | eos_indices = (output_seq == self.tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
97 | if len(eos_indices) > 0: # eos detected, end the stmt generation loop
98 | # cut off **at** the eos token and it's not added to the incomplete list (it's finished)
99 | first_eos_idx = int(eos_indices[0])
100 | output_seq = output_seq[:first_eos_idx]
101 | completed_programs[incomplete_program_indices[i]].extend(output_seq)
102 |
103 | # get the last state
104 | if self.eval_with_states:
105 | stmt_str = self.tokenizer.decode(output_seq, skip_special_tokens=True)
106 | last_state_dict = program_states_list[incomplete_program_indices[i]][-1]
107 | output_state = exec_stmt_in_context(stmt_str, last_state_dict)
108 | program_states_list[incomplete_program_indices[i]].append(output_state)
109 | else:
110 | if self.eval_with_states:
111 | # the generation is not finished yet, we need to augment the next steps with state information
112 | stmt_str = self.tokenizer.decode(output_seq, skip_special_tokens=True)
113 | last_state_dict = program_states_list[incomplete_program_indices[i]][-1]
114 | output_state = exec_stmt_in_context(stmt_str, last_state_dict)
115 | program_states_list[incomplete_program_indices[i]].append(output_state)
116 |
117 | if output_state == None:
118 |                         # the program cannot be successfully executed, so we remove it from the incomplete list
119 | completed_programs[incomplete_program_indices[i]].extend(output_seq)
120 | else:
121 | # incorporate the state into the context
122 | state_str = get_state_repr(output_state, only_include_keys=[stmt_str.split(" ")[0]],
123 | prev_stmt=stmt_str, skip_trivial_states=self.skip_trivial_states)
124 | state_tensor = self.tokenizer.encode(state_str, add_special_tokens=False, return_tensors="pt")[0] \
125 | .to(device=output_seq.device, dtype=output_seq.dtype)
126 | output_seq = torch.cat([output_seq, state_tensor])
127 |
128 | incomplete_output_seqs.append(torch.cat([input_ids[i][-context_len_list[i]:], output_seq]))
129 | completed_programs[incomplete_program_indices[i]].extend(output_seq)
130 | incomplete_program_indices_new.append(incomplete_program_indices[i])
131 | else:
132 | incomplete_output_seqs.append(torch.cat([input_ids[i][-context_len_list[i]:], output_seq]))
133 | completed_programs[incomplete_program_indices[i]].extend(output_seq)
134 | incomplete_program_indices_new.append(incomplete_program_indices[i])
135 |
136 | incomplete_program_indices = incomplete_program_indices_new
137 |
138 | if len(incomplete_output_seqs) == 0:
139 | # all seqs have been completed by generating the eos token
140 | break
141 | elif stmt_idx == self.max_stmt_num - 1:
142 | # reach the max stmt num, but still not all the seqs are completed
143 | # for i, incomplete_cell in enumerate(incomplete_output_seqs):
144 | # completed_programs[incomplete_program_indices[i]].extend(
145 | # torch.tensor([self.tokenizer.eos_token_id], device=output_seq.device))
146 | break
147 |
148 | # reformulate the input_ids and attention_mask with the newly generated output
149 | incomplete_output_seqs = [output_seq[-self.max_context_len:] for output_seq in incomplete_output_seqs]
150 | attention_mask_list = [torch.ones_like(incomplete_output_seq) for incomplete_output_seq in incomplete_output_seqs]
151 | # pad to the same length and stack to be the new input_ids
152 | input_ids = left_pad_sequences(incomplete_output_seqs, batch_first=True,
153 | padding_value=self.tokenizer.eos_token_id)
154 | attention_mask = left_pad_sequences(attention_mask_list, batch_first=True, padding_value=False)
155 |
156 |         # completed_programs are the accumulated stmts for each program, not including the original context; decode back to strs
157 |         generated_programs = self.tokenizer.batch_decode(completed_programs)
158 |
159 | for i in range(len(metadata)):
160 | output_dicts[i].update({"generated_program": generated_programs[i]})
161 | output_dicts[i].update({"generated_program_state_list": program_states_list[i]})
162 |
163 | return output_dicts
164 |
165 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]:
166 | input_ids = batch["input_ids"]
167 | attention_mask = batch["attention_mask"]
168 | labels = batch["labels"] if "labels" in batch else input_ids
169 |
170 | gpt_result = self.gpt(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
171 | self.log("loss", gpt_result.loss, on_step=True, on_epoch=True)
172 |
173 | if "state_mask" in batch:
174 | # log separately for loss on state tokens and non-state tokens
175 | state_mask = batch["state_mask"]
176 |
177 | # Shift so that tokens < n predict n
178 | shift_logits = gpt_result.logits[..., :-1, :].contiguous()
179 | shift_labels = labels[..., 1:].contiguous()
180 | shift_state_mask = state_mask[..., 1:].contiguous()
181 |
182 | # Flatten the tokens
183 | loss_fct = CrossEntropyLoss(reduction="none")
184 | unreduced_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
185 |
186 | code_token_loss = torch.sum(shift_state_mask.view(-1) * unreduced_loss) / torch.sum(shift_state_mask)
187 | state_token_loss = torch.sum((1 - shift_state_mask.view(-1)) * unreduced_loss) / torch.sum(1 - shift_state_mask)
188 |
189 | self.log("code_token_loss", code_token_loss, on_step=True, on_epoch=True)
190 | self.log("state_token_loss", state_token_loss, on_step=True, on_epoch=True)
191 |
192 | return {"loss": gpt_result.loss}
193 |
194 | def sanity_check_validation_step_end(self, outputs: List[Dict[str, Any]]) -> None:
195 | # update the evaluation metrics
196 | for output_dict in outputs:
197 | last_state_dict = output_dict["generated_program_state_list"][-1]
198 | if last_state_dict is not None and "answer" in last_state_dict:
199 | exec_rate = 1.0
200 | if last_state_dict["answer"] == output_dict["metadata"]["answer"]:
201 | exec_acc = 1.0
202 | else:
203 | exec_acc = 0.0
204 | else:
205 | exec_rate = 0.0
206 | exec_acc = 0.0
207 | output_dict.pop("generated_program_state_list")
208 |
209 | program_len = len(list(filter(lambda x: not x.startswith("#"),
210 | output_dict["generated_program"].split("\n"))))
211 | gold_program_len = len(list(filter(lambda x: not x.startswith("#"), output_dict["metadata"]["code"].split("\n"))))
212 | program_len_diff = program_len - gold_program_len
213 |
214 | self._num_metric_1(exec_acc)
215 | self._num_metric_2(exec_rate)
216 |
217 | self._num_metric_3(program_len_diff)
218 |
219 | output_dict["metrics"] = {"exec_acc": float(exec_acc),
220 | "exec_rate": float(exec_rate),
221 | "program_len_diff": float(program_len_diff)}
222 |
223 | # save the outputs to the model
224 | self.predictions.extend(outputs)
225 |
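A schematic sketch of the statement-by-statement decoding loop implemented in `forward` above: each statement is generated and truncated at its first newline, optionally executed to obtain a state that is folded back into the context, until an eos token or the statement limit is hit. `generate_one_stmt` and `exec_stmt` below are hypothetical stand-ins for the `gpt.generate` call (plus newline truncation) and `exec_stmt_in_context`.

    # schematic sketch only, not the batched implementation above
    from typing import Callable, Dict, Optional

    def decode_program(context: str,
                       generate_one_stmt: Callable[[str], str],        # returns "" once eos is produced
                       exec_stmt: Callable[[str, Dict], Optional[Dict]],
                       max_stmt_num: int = 20) -> str:
        program_lines, state = [], {}
        for _ in range(max_stmt_num):
            stmt = generate_one_stmt(context)
            if stmt == "":                      # eos: the program is complete
                break
            program_lines.append(stmt.rstrip("\n"))
            state = exec_stmt(stmt, state)      # execute the new statement in the current state
            if state is None:                   # execution error: stop decoding this program
                break
            context += stmt                     # (a state string can also be appended here)
        return "\n".join(program_lines)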
--------------------------------------------------------------------------------
/lightning_modules/models/gpt_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Tuple, Optional, List, Union
3 |
4 | from transformers import GPTNeoForCausalLM, GPT2Tokenizer
5 | from transformers import PreTrainedModel, PreTrainedTokenizer, GPT2LMHeadModel
6 | from transformers import GPT2Tokenizer, GPTJForCausalLM
7 |
8 | def get_gpt(model_name: str,
9 | tokenizer_only: bool = False,
10 | gradient_ckpt: bool = False,
11 | additional_special_tokens: Optional[List[str]] = None) \
12 | -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
13 | if additional_special_tokens is None:
14 | additional_special_tokens = []
15 |
16 | if not tokenizer_only:
17 | print(f"using pretrained model: {model_name}, gradient_ckpt: {gradient_ckpt}")
18 |
19 | if model_name == "microsoft/CodeGPT-small-py":
20 | tokenizer = GPT2Tokenizer.from_pretrained(model_name, additional_special_tokens=additional_special_tokens)
21 | if not tokenizer_only:
22 | model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)
23 | if len(additional_special_tokens) > 0:
24 | model.resize_token_embeddings(len(tokenizer))
25 |     elif model_name == "EleutherAI/gpt-j-6B":
26 | tokenizer = GPT2Tokenizer.from_pretrained(model_name)
27 | tokenizer.pad_token = tokenizer.eos_token
28 |
29 | if not tokenizer_only:
30 | model = GPTJForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
31 | gradient_checkpointing=gradient_ckpt, use_cache=not gradient_ckpt)
32 | if len(additional_special_tokens) > 0:
33 | model.resize_token_embeddings(len(tokenizer))
34 | elif model_name in ["EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-2.7B"]:
35 | tokenizer = GPT2Tokenizer.from_pretrained(model_name, additional_special_tokens=additional_special_tokens)
36 | tokenizer.pad_token = tokenizer.eos_token
37 |
38 | if not tokenizer_only:
39 | model = GPTNeoForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
40 | gradient_checkpointing=gradient_ckpt, use_cache=not gradient_ckpt)
41 | if len(additional_special_tokens) > 0:
42 | model.resize_token_embeddings(len(tokenizer))
43 | else:
44 | raise NotImplementedError
45 |
46 | if tokenizer_only:
47 | return None, tokenizer
48 | else:
49 | return model, tokenizer
50 |
51 | def left_pad_sequences(sequences: List[torch.Tensor], batch_first: bool = True, padding_value: Union[int, bool] = 0,
52 | max_len: int = -1, device: torch.device = None) -> torch.Tensor:
53 | assert all([len(seq.shape) == 1 for seq in sequences])
54 | max_len = max_len if max_len > 0 else max(len(s) for s in sequences)
55 | device = device if device is not None else sequences[0].device
56 |
57 | padded_seqs = []
58 | for seq in sequences:
59 | padded_seqs.append(torch.cat((torch.full((max_len - seq.shape[0],), padding_value, dtype=torch.long).to(device), seq)))
60 | return torch.stack(padded_seqs)
61 |
62 | def sanity_check(test_str: str, model, tokenizer):
63 | print(f"test str is: ###############{test_str}##############")
64 |
65 | input_ids = tokenizer.encode(test_str, add_special_tokens=False, return_tensors="pt").to(model.device)
66 | attention_mask = torch.where(input_ids == tokenizer.eos_token_id, torch.zeros_like(input_ids), torch.ones_like(input_ids)).to(model.device)
67 |
68 | output_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=40, num_return_sequences=1)
69 |
70 | output_str = tokenizer.decode(output_ids[0], skip_special_tokens=False, clean_up_tokenization_spaces=False)
71 | output_str_no_sp_tokens = tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
72 |
73 | print(f"output str is: ###############{output_str}##############")
74 |
75 | new_test_str = " ".join(output_str_no_sp_tokens.split("\n")[:-1])
76 |
77 | print(f"new test str is: ###############{new_test_str}###############")
78 |
79 | input_ids = tokenizer.encode(new_test_str, add_special_tokens=False, return_tensors="pt").to(model.device)
80 | attention_mask = torch.where(input_ids == tokenizer.eos_token_id, torch.zeros_like(input_ids), torch.ones_like(input_ids)).to(model.device)
81 |
82 | output_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=40, num_return_sequences=1)
83 |
84 | output_str = tokenizer.decode(output_ids[0], skip_special_tokens=False, clean_up_tokenization_spaces=False)
85 |
86 | print(f"new output str is: ###############{output_str}###############")
87 |
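A minimal usage sketch for `left_pad_sequences`; the toy tensors and pad value are illustrative:

    import torch

    seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
    padded = left_pad_sequences(seqs, batch_first=True, padding_value=0)
    # padded == tensor([[5, 6, 7],
    #                   [0, 8, 9]])   # shorter sequences are padded on the left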
--------------------------------------------------------------------------------
/preprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/preprocessing/__init__.py
--------------------------------------------------------------------------------
/preprocessing/preprocess_gsm8k.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import json
4 | import math
5 |
6 | from typing import List, Dict, Any
7 | from tqdm import tqdm
8 |
9 | from execution.execution_evaluation import mathqa_execution, execution_acc
10 | from tree_sitter import Language, Parser
11 |
12 | # initialize the parser for the code
13 | language_build_path = os.path.join(os.path.dirname(__file__), 'py-tree-sitter.so')
14 | PY_LANGUAGE = Language(language_build_path, 'python')
15 | parser = Parser()
16 | parser.set_language(PY_LANGUAGE)
17 |
18 | train_file = "./data/gsmath/train.jsonl"
19 | test_file = "./data/gsmath/test.jsonl"
20 |
21 | def get_answer_from_answer_str(answer_str: str) -> float:
22 | result_str = answer_str.split("\n")[-1].split(" ")[-1]
23 | result = float(result_str.replace(",", ""))
24 |
25 | return result
26 |
27 | def get_code_from_answer_str(answer_str: str, question_str: str) -> str:
28 |     # reverse_var_dict only keeps the constants; t_lines does not contain the constant inits
29 | reverse_var_dict: Dict[float, str] = {}
30 | reverse_temp_var_dict: Dict[float, str] = {}
31 | temp_var_num = 0
32 | t_lines = []
33 |
34 | for line in answer_str.split("\n")[:-1]:
35 | if not ("<<" in line and ">>" in line):
36 | continue
37 |
38 | # first extract the formula
39 | formula = line[line.index("<<") + 2: line.index(">>")]
40 |
41 | def get_var_name(var_str: str, allow_new: bool = True) -> str:
42 | num = float(var_str)
43 | if num in reverse_temp_var_dict:
44 | var_name = reverse_temp_var_dict[num]
45 | elif num in reverse_var_dict:
46 | var_name = reverse_var_dict[num]
47 | elif allow_new:
48 | # a new constant
49 | var_name = f"n{len(reverse_var_dict)}"
50 | reverse_var_dict[num] = var_name
51 | else:
52 | raise ValueError(f"{var_str} not found in var/temp dict")
53 |
54 | return var_name
55 |
56 | def get_node_text(node, text) -> str:
57 | return text[node.start_byte: node.end_byte]
58 |
59 | # make sure that the formula is valid
60 | expression, result = formula.split("=")
61 | if "/" in result:
62 | result = eval(result)
63 | if not eval(expression) == float(result):
64 | return "NULL"
65 |
66 | # interpret the formula with a parse tree
67 |         assert expression.isascii(), f"{expression} is not ascii"
68 | parsed_tree = parser.parse(bytes(expression, 'utf-8'))
69 |
70 | # do a dfs on the parsed tree to get the values replaced with names
71 | formula_bits = []
72 | node_stack = [parsed_tree.root_node.children[0].children[0]]
73 | while len(node_stack) > 0:
74 | node = node_stack.pop()
75 |
76 | if node.type in ["integer", "float"]:
77 | var_name = get_var_name(get_node_text(node, expression))
78 | formula_bits.append(var_name)
79 | elif node.type in ["+", "-", "*", "/", "**", "(", ")", "//"]:
80 | formula_bits.append(get_node_text(node, expression))
81 | elif node.type in ["binary_operator", "parenthesized_expression"]:
82 | node_stack.extend(node.children[::-1])
83 | elif node.type == "unary_operator":
84 | if node.children[0].type == "+":
85 | var_name = get_var_name(get_node_text(node, expression))
86 | formula_bits.append(var_name)
87 | elif node.children[0].type == "-":
88 | val = -float(get_node_text(node, expression))
89 | if val in reverse_temp_var_dict or val in reverse_var_dict:
90 | formula_bits.append(get_var_name(val, allow_new=False))
91 | elif -val in reverse_temp_var_dict or val in reverse_var_dict:
92 | formula_bits.append("-"+get_var_name(-val, allow_new=False))
93 | else:
94 | formula_bits.append(get_var_name(val, allow_new=True))
95 | else:
96 | raise ValueError(f"{expression} has unary operator {node.children[0].type}")
97 | else:
98 | raise ValueError(f"{expression} has {node.type}")
99 |
100 | right_formula = "".join(formula_bits)
101 |
102 | # add the temporary var
103 |         # NOTE: we can't use len(reverse_temp_var_dict) because the same intermediate value may appear in multiple lines
104 | temp_var_name = f"t{temp_var_num}"
105 | temp_var_num += 1
106 | reverse_temp_var_dict[float(result)] = temp_var_name
107 |
108 | # add the line
109 | t_lines.append(f"{temp_var_name}={right_formula}")
110 |
111 | # add the const var inits
112 | init_lines = []
113 | sorted_var_dict = sorted(reverse_var_dict.items(), key=lambda x: int(x[1][1:]))
114 | for var_val, var_name in sorted_var_dict:
115 |         # if the float value does not appear verbatim in the question and can be cast to an int, initialize it as an int
116 | if not str(var_val) in question_str and math.isclose(int(var_val), var_val, abs_tol=1e-4):
117 | init_lines.append(f"{var_name}={int(var_val)}")
118 | else:
119 | init_lines.append(f"{var_name}={var_val}")
120 |
121 |
122 | if len(t_lines) == 0:
123 | # no <> are given for this example, simply skip
124 | return "NULL"
125 |
126 | # replace the last line's temp var name with "answer"
127 | t_lines[-1] = "answer=" + t_lines[-1].split("=")[1]
128 |
129 | return "\n".join(init_lines + t_lines)
130 |
131 | def verify_code(code: str, gold_answer: str) -> bool:
132 | try:
133 | exec(code)
134 | if float(gold_answer) == float(eval("answer")):
135 | return True
136 | else:
137 | return False
138 | except Exception as e:
139 | return False
140 |
141 | def process_gsmath(instances: List[Dict[str, str]], set_name: str) -> List[Dict[str, Any]]:
142 | failed_code_extraction_indices = []
143 | for i, instance in tqdm(enumerate(instances)):
144 | # put it in the mathqa style: text, code, answer, task_id
145 | instance["text"] = instance["question"]
146 | instance.pop("question")
147 |
148 | instance["original_answer"] = instance["answer"]
149 | instance["task_id"] = f"{set_name}_{i}"
150 |
151 | instance["code"] = get_code_from_answer_str(instance["original_answer"], instance["text"])
152 | instance["answer"] = get_answer_from_answer_str(instance["original_answer"])
153 |
154 | if instance["code"] == "NULL":
155 | # failed to extract code, will skip this example in training, and only record for dev/test
156 | failed_code_extraction_indices.append(i)
157 |
158 | # verify the validity of the code
159 | failed_code_execution_indices = []
160 | for i, instance in enumerate(instances):
161 | if i in failed_code_extraction_indices:
162 | continue
163 |
164 | if not verify_code(instance["code"], instance["answer"]):
165 | failed_code_execution_indices.append(i)
166 | # print(f"{instance['task_id']} failed to verify, " \
167 | # f"original_answer: {instance['original_answer']}, " \
168 | # f"code: \n{instance['code']}\nanswer: {instance['answer']}")
169 |
170 | all_failed_indices = sorted(failed_code_extraction_indices + failed_code_execution_indices)
171 |
172 | print(f"{len(failed_code_extraction_indices)}/{len(instances)} failed to extract code")
173 | print(f"{len(failed_code_execution_indices)}/{len(instances)} failed to execute to the correct result")
174 | print(f"{len(all_failed_indices)}/{len(instances)} failed in total")
175 |
176 |     # remove the failed examples if this is the training set
177 | if set_name == "train":
178 | for i in all_failed_indices[::-1]:
179 | instances.pop(i)
180 |
181 | return instances
182 |
183 | def main():
184 | # load the train and test data
185 | with open(train_file, "r") as f:
186 | train_lines = f.readlines()
187 | train_data = [json.loads(line) for line in train_lines]
188 |
189 | with open(test_file, "r") as f:
190 | test_lines = f.readlines()
191 | test_data = [json.loads(line) for line in test_lines]
192 |
193 | # split the train data to train and dev
194 | train_data, dev_data = train_data[:int(len(train_data) * 0.8)], train_data[int(len(train_data) * 0.8):]
195 |
196 | # process all the data
197 | processed_train_data = process_gsmath(train_data, "train")
198 | processed_dev_data = process_gsmath(dev_data, "val")
199 | processed_test_data = process_gsmath(test_data, "test")
200 |
201 | # write the processed data to files
202 | with open("./data/gsmath/gsmath_train.jsonl", "w") as f:
203 | f.write("\n".join([json.dumps(data) for data in processed_train_data]))
204 | with open("./data/gsmath/gsmath_val.jsonl", "w") as f:
205 | f.write("\n".join([json.dumps(data) for data in processed_dev_data]))
206 | with open("./data/gsmath/gsmath_test.jsonl", "w") as f:
207 | f.write("\n".join([json.dumps(data) for data in processed_test_data]))
208 |
209 | def prune_gsmath(file_name: str) -> None:
210 | assert file_name.endswith(".jsonl")
211 |
212 | # load the data
213 | with open(file_name, "r") as f:
214 | train_lines = f.readlines()
215 | instances = [json.loads(line) for line in train_lines]
216 |
217 | failed_code_extraction_indices = []
218 | for i, instance in tqdm(enumerate(instances)):
219 | if instance["code"] == "NULL":
220 | # failed to extract code, will skip this example in training, and only record for dev/test
221 | failed_code_extraction_indices.append(i)
222 |
223 | # verify the validity of the code
224 | failed_code_execution_indices = []
225 | for i, instance in enumerate(instances):
226 | if i in failed_code_extraction_indices:
227 | continue
228 |
229 | if not verify_code(instance["code"], instance["answer"]):
230 | failed_code_execution_indices.append(i)
231 |
232 | all_failed_indices = sorted(failed_code_extraction_indices + failed_code_execution_indices)
233 |
234 | print(f"{len(failed_code_extraction_indices)}/{len(instances)} failed to extract code")
235 | print(f"{len(failed_code_execution_indices)}/{len(instances)} failed to execute to the correct result")
236 | print(f"{len(all_failed_indices)}/{len(instances)} failed in total")
237 |
238 |     # remove the failed examples
239 | for i in all_failed_indices[::-1]:
240 | instances.pop(i)
241 |
242 | with open(f"{file_name[:-6]}_pruned.jsonl", "w") as f:
243 | f.write("\n".join([json.dumps(ins) for ins in instances]))
244 |
245 | if __name__ == "__main__":
246 | # prune_gsmath("./data/gsmath/gsmath_val.jsonl")
247 | main()
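A worked example of the answer-to-code conversion above. The answer string is an illustrative GSM8K-style annotation (not taken from the dataset), and the expected output assumes neither constant appears as a float literal in the question text:

    answer_str = "He runs 3*60=<<3*60=180>>180 minutes a week.\n#### 180"
    question_str = "Tom runs 3 hours a week. How many minutes does he run?"
    print(get_code_from_answer_str(answer_str, question_str))
    # n0=3
    # n1=60
    # answer=n0*n1
    print(get_answer_from_answer_str(answer_str))   # 180.0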
--------------------------------------------------------------------------------
/preprocessing/preprocess_mathqa_python.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from typing import List, Dict, Tuple, Any, Union
4 |
5 | def read_jsonl_file(file_path: str) -> List[Dict[str, Any]]:
6 | with open(file_path, "r") as f:
7 | examples = [json.loads(line) for line in f]
8 |
9 | return examples
10 |
11 | def create_resplit_index():
12 | original_train_examples = read_jsonl_file("data/mathqa/train-python.jsonl")
13 | original_dev_examples = read_jsonl_file("data/mathqa/val-python.jsonl")
14 | original_all_examples = original_train_examples + original_dev_examples
15 |
16 | text_to_example_idx = {x["text"].strip(): i for i, x in enumerate(original_all_examples)}
17 |
18 | dedup_train_examples = read_jsonl_file("data/mathqa/train_dedup.jsonl")
19 | dedup_dev_examples = read_jsonl_file("data/mathqa/val_dedup.jsonl")
20 |
21 | dedup_train_idx = [text_to_example_idx[x["text"].strip()] for x in dedup_train_examples]
22 | dedup_dev_idx = [text_to_example_idx[x["text"].strip()] for x in dedup_dev_examples]
23 |
24 | with open("preprocessing/mathqa_python_resplit_info.json", "w+") as f:
25 | info = {
26 | "train_first_examples": dedup_train_examples[:5],
27 | "dev_first_examples": dedup_dev_examples[:5],
28 | "train_idx": dedup_train_idx,
29 | "dev_idx": dedup_dev_idx
30 | }
31 | json.dump(info, f)
32 |
33 | def recreate_split():
34 | # load the precomputed resplit info
35 | with open("preprocessing/mathqa_python_resplit_info.json", "r") as f:
36 | resplit_info = json.load(f)
37 |
38 | # load the original examples
39 | original_train_examples = read_jsonl_file("data/mathqa/train-python.jsonl")
40 | original_dev_examples = read_jsonl_file("data/mathqa/val-python.jsonl")
41 | original_all_examples = original_train_examples + original_dev_examples
42 |
43 | # recreate the split using the resplit info
44 | dedup_train_examples = [original_all_examples[i] for i in resplit_info["train_idx"]]
45 | dedup_dev_examples = [original_all_examples[i] for i in resplit_info["dev_idx"]]
46 |
47 | # rename the task ids
48 | for i, instance in enumerate(dedup_train_examples):
49 | instance['task_id'] = f"train_{i}"
50 | for i, instance in enumerate(dedup_dev_examples):
51 | instance['task_id'] = f"val_{i}"
52 |
53 | # verify that the split is correct
54 | assert all([x[0]["text"] == x[1]["text"] for x in zip(dedup_train_examples[:5], resplit_info["train_first_examples"])]), "train split is incorrect"
55 | assert all([x[0]["text"] == x[1]["text"] for x in zip(dedup_dev_examples[:5], resplit_info["dev_first_examples"])]), "dev split is incorrect"
56 |
57 | # write the recreated split to disk
58 | with open("data/mathqa/train_dedup.jsonl", "w+") as f:
59 | for example in dedup_train_examples:
60 | f.write(json.dumps(example) + "\n")
61 |
62 | with open("data/mathqa/val_dedup.jsonl", "w+") as f:
63 | for example in dedup_dev_examples:
64 | f.write(json.dumps(example) + "\n")
65 |
66 | def main():
67 | # create_resplit_index()
68 | recreate_split()
69 |
70 | if __name__ == "__main__":
71 | main()
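To regenerate the deduplicated split, `recreate_split` is presumably run from the repository root so that the relative `data/mathqa/...` paths resolve, e.g.:

    from preprocessing.preprocess_mathqa_python import recreate_split

    recreate_split()   # writes data/mathqa/train_dedup.jsonl and data/mathqa/val_dedup.jsonl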
--------------------------------------------------------------------------------
/preprocessing/py-tree-sitter.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/TraceCodegen/92bada8c9090de69cca037ea7c5449df420b40a5/preprocessing/py-tree-sitter.so
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # azureml-sdk == 1.26.0
2 | tensorboard ~= 2.4
3 | huggingface-hub == 0.1.2
4 | tree-sitter ~= 0.19.0
5 | torchmetrics == 0.6.0
6 |
7 | transformers == 4.16.2
8 | torch == 1.10.2
9 | pytorch-lightning == 1.5.10
10 | deepspeed ~= 0.5.10
11 |
12 | # to avoid a "fit.optimizer" namespace conflict error
13 | jsonargparse == 3.19.4
14 |
15 | nltk
16 | rouge-score
17 | jsonargparse[signatures]
18 | overrides
19 | neptune-client
20 | scipy
21 | editdistance
22 | bashplotlib
23 | wandb
24 | astunparse
25 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning import LightningModule, LightningDataModule
2 | from pytorch_lightning.utilities.cli import LightningCLI
3 |
4 | # see https://github.com/PyTorchLightning/pytorch-lightning/issues/10349
5 | import warnings
6 |
7 | warnings.filterwarnings(
8 | "ignore", ".*Trying to infer the `batch_size` from an ambiguous collection.*"
9 | )
10 |
11 | cli = LightningCLI(LightningModule, LightningDataModule,
12 | subclass_mode_model=True, subclass_mode_data=True,
13 | save_config_callback=None)
14 |
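Since `trainer.py` only instantiates a `LightningCLI` in subclass mode, a run is presumably launched by pointing the CLI at one of the YAML files under `training_configs/` (the exact flags depend on the pinned pytorch-lightning version), roughly:

    python trainer.py --config training_configs/gpt_mle.yaml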
--------------------------------------------------------------------------------
/training_configs/gpt_mle.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 333
2 | trainer:
3 | gpus: 2
4 | gradient_clip_val: 1.0
5 | default_root_dir: debug-tmp
6 | # val_check_interval: 1.0
7 | max_steps: &max_steps 50000
8 | check_val_every_n_epoch: 2
9 | log_every_n_steps: 1
10 | num_sanity_val_steps: 0
11 | logger:
12 | - class_path: lightning_modules.loggers.patched_loggers.PatchedWandbLogger
13 | init_args:
14 | entity: niansong1996
15 | project: trace-codegen
16 | name: debug-tmp
17 | log_model: False
18 | save_code: True
19 | offline: False
20 | callbacks:
21 | - class_path: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
22 | init_args:
23 | monitor: exec_acc
24 | mode: max
25 | filename: '{step}-{exec_acc:.4f}-{exec_rate:.4f}'
26 | save_top_k: 5
27 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor
28 | init_args:
29 | logging_interval: step
30 | - class_path: pytorch_lightning.callbacks.progress.TQDMProgressBar
31 | init_args:
32 | refresh_rate: 1
33 |
34 | accelerator: gpu
35 | # replace_sampler_ddp: False
36 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/8262
37 | strategy: deepspeed_stage_2
38 | # strategy: ddp_find_unused_parameters_false
39 | # precision: 16
40 | # accumulate_grad_batches: 4
41 |
42 | model:
43 | class_path: lightning_modules.models.gpt_seq2seq_model.GptSeq2SeqModel
44 | init_args:
45 | transformer_model_name: &transformer EleutherAI/gpt-neo-2.7B
46 | max_gen_len: 256
47 | sampling_temp: 0.2
48 | sampling_temp_at_k: 0.8
49 | # pass_at_k: 80
50 | # eval_pass_at_k_every_n_epochs: 4
51 | # additional_pass_at_k: [5, 10, 20, 50]
52 | # max_generation_batches: 50
53 | # always_eval_pass_at_k_first_n_epochs: 10
54 | gradient_ckpt: true
55 | measure_dedup_metrics: false
56 | # eval_greedy_search: true
57 | # load_ckpt_file: /home/v-ansongni/data/trace-codegen-data-wuphillyblob/data/mathqa/model_ckpts/step=6904-exec_acc=0.4480-exec_rate=0.6445.ckpt
58 | # load_ckpt_file: /home/v-ansongni/Code/trace-codegen/amlt/mathqa-finetune-gpt-neo-125M-pad-left/gpt-neo-mathqa-finetuning/lightning_logs/version_0/checkpoints/step=54044-exec_acc=0.7715-exec_rate=0.9893.ckpt
59 | optimizer:
60 | class_path: torch.optim.adamw.AdamW
61 | init_args:
62 | lr: 1.0e-4
63 | # lr: 0.0
64 | betas:
65 | - 0.9
66 | - 0.999
67 | eps: 1.0e-8
68 | weight_decay: 0.1
69 | lr_scheduler:
70 | name: linear
71 | init_args:
72 | num_warmup_steps: 100
73 | num_training_steps: *max_steps
74 |
75 | data:
76 | class_path: lightning_modules.datasets.mathqa_line_reader.MathQADataModule
77 | init_args:
78 | transformer_model_name: *transformer
79 | batch_size: 2
80 | val_batch_size: 4
81 | # train_file_path: data/mathqa/train_dedup.jsonl
82 | # val_file_path: data/mathqa/val_dedup.jsonl
83 | train_file_path: data/gsmath/gsmath_train_val.jsonl
84 | val_file_path: data/gsmath/gsmath_test.jsonl
85 | # train_max_instances: 40
86 | # val_max_instances: 20
87 | # few_shot_n: 4
--------------------------------------------------------------------------------
/training_configs/gpt_self_sampling.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 333
2 | trainer:
3 | gpus: 1
4 | gradient_clip_val: 1.0
5 | default_root_dir: debug-tmp
6 | # val_check_interval: 1.0
7 | max_steps: &max_steps 50000
8 | check_val_every_n_epoch: 2
9 | log_every_n_steps: 1
10 | num_sanity_val_steps: 0
11 | logger:
12 | - class_path: lightning_modules.loggers.patched_loggers.PatchedWandbLogger
13 | init_args:
14 | entity: niansong1996
15 | project: trace-codegen
16 | name: debug-tmp
17 | log_model: False
18 | save_code: True
19 | offline: False
20 | callbacks:
21 | - class_path: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
22 | init_args:
23 | monitor: exec_acc
24 | mode: max
25 | filename: '{step}-{exec_acc:.4f}-{exec_rate:.4f}'
26 | save_top_k: 5
27 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor
28 | init_args:
29 | logging_interval: step
30 | - class_path: pytorch_lightning.callbacks.progress.TQDMProgressBar
31 | init_args:
32 | refresh_rate: 1
33 |
34 | accelerator: gpu
35 | # replace_sampler_ddp: False
36 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/8262
37 | strategy: deepspeed_stage_2_offload
38 | # strategy: ddp_find_unused_parameters_false
39 | # precision: 16
40 | # accumulate_grad_batches: 4
41 | model:
42 | class_path: lightning_modules.models.gpt_stmt_mml_model.GptStmtMmlModel
43 | init_args:
44 | transformer_model_name: &transformer EleutherAI/gpt-neo-2.7B
45 | max_gen_len: 256
46 | max_sampling_len: 100
47 | sampling_temp: 0.2
48 | on_policy_sample_num: 1
49 | on_policy_sample_temp: 0.8
50 | sampling_temp_at_k: 0.8
51 | # pass_at_k: 80
52 | # additional_pass_at_k: [5, 10, 20, 50]
53 | # eval_pass_at_k_every_n_epochs: 1
54 | # max_generation_batches: 50
55 | gradient_ckpt: true
56 | measure_dedup_metrics: false
57 | length_diff_tolerance: 0
58 | exclude_context_loss: false
59 | mle_lambda: 1.0
60 | mml_lambda: 0.0
61 | # beta_smoothing: 0.25
62 | # marg_set_size: 10
63 | # max_buffer_size: 10
64 | # eval_greedy_search: true
65 | # load_samples_file: /home/v-ansongni/Code/trace-codegen/mathqa_dedup_train_samples.jsonl
66 | # load_ckpt_file: /home/v-ansongni/Code/trace-codegen/amlt/mathqa-state-finetuning-125M-line-state-all-mask-ctx-len-1648/gpt-neo-mathqa-state-finetuning/lightning_logs/version_0/checkpoints/step=46175-exec_acc=0.6650-exec_rate=0.9956.ckpt
67 | # load_ckpt_file: /home/v-ansongni/Code/trace-codegen/amlt/mathqa-dedup-partial-mml-len_norm-low-temp/gpt-neo-mathqa-state-finetuning/lightning_logs/version_0/checkpoints/step=55755-exec_acc=0.1263-exec_rate=0.9727.ckpt
68 | # load_ckpt_file: /home/v-ansongni/Code/trace-codegen/amlt/mathqa-125M-state-skip-ts/gpt-neo-mathqa-state-finetuning/lightning_logs/version_0/checkpoints/step=49727-exec_acc=0.7559-exec_rate=0.9956.ckpt
69 | # load_ckpt_file: /home/v-ansongni/Code/trace-codegen/amlt/mathqa-finetune-gpt-neo-125M-pad-left/gpt-neo-mathqa-finetuning/lightning_logs/version_0/checkpoints/step=54044-exec_acc=0.7715-exec_rate=0.9893.ckpt
70 | optimizer:
71 | class_path: torch.optim.adamw.AdamW
72 | init_args:
73 | lr: 1.0e-4
74 | # lr: 0.0
75 | betas:
76 | - 0.9
77 | - 0.999
78 | eps: 1.0e-8
79 | weight_decay: 0.1
80 | lr_scheduler:
81 | name: linear
82 | init_args:
83 | num_warmup_steps: 100
84 | num_training_steps: *max_steps
85 |
86 | data:
87 | class_path: lightning_modules.datasets.mathqa_line_reader.MathQAMmlDataModule
88 | init_args:
89 | transformer_model_name: *transformer
90 | batch_size: 2
91 | val_batch_size: 4
92 | train_file_path: data/mathqa/train_dedup.jsonl
93 | val_file_path: data/mathqa/val_dedup.jsonl
94 | # train_max_instances: 40
95 | # val_max_instances: 20
96 | # few_shot_n: 4
--------------------------------------------------------------------------------
/training_configs/gpt_self_sampling_partial.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 333
2 | trainer:
3 | gpus: 1
4 | gradient_clip_val: 1.0
5 | default_root_dir: debug-tmp
6 | # val_check_interval: 1.0
7 | max_steps: &max_steps 50000
8 | check_val_every_n_epoch: 2
9 | log_every_n_steps: 1
10 | num_sanity_val_steps: 0
11 | logger:
12 | - class_path: lightning_modules.loggers.patched_loggers.PatchedWandbLogger
13 | init_args:
14 | entity: niansong1996
15 | project: trace-codegen
16 | name: debug-tmp
17 | log_model: False
18 | save_code: True
19 | offline: False
20 | callbacks:
21 | - class_path: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
22 | init_args:
23 | monitor: exec_acc
24 | mode: max
25 | filename: '{step}-{exec_acc:.4f}-{exec_rate:.4f}'
26 | save_top_k: 5
27 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor
28 | init_args:
29 | logging_interval: step
30 | - class_path: pytorch_lightning.callbacks.progress.TQDMProgressBar
31 | init_args:
32 | refresh_rate: 1
33 |
34 | accelerator: gpu
35 | # replace_sampler_ddp: False
36 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/8262
37 | # strategy: deepspeed_stage_2
38 | strategy: ddp_find_unused_parameters_false
39 | # precision: 16
40 | # accumulate_grad_batches: 4
41 |
42 | model:
43 | class_path: lightning_modules.models.gpt_stmt_partial_mml_model.GptStmtPartialMmlModel
44 | init_args:
45 | transformer_model_name: &transformer EleutherAI/gpt-neo-2.7B
46 | max_gen_len: 256
47 | max_sampling_len: 100
48 | sampling_temp: 0.2
49 | on_policy_sample_num: 1
50 | on_policy_sample_temp: 0.8
51 | sampling_temp_at_k: 0.8
52 | # pass_at_k: 80
53 | # eval_pass_at_k_every_n_epochs: 1
54 | # max_generation_batches: 50
55 | # additional_pass_at_k: [5, 10, 20, 50]
56 | gradient_ckpt: true
57 | measure_dedup_metrics: false
58 | length_diff_tolerance: 0
59 | sampling_from_states: false
60 | mle_lambda: 1.0
61 | mml_lambda: 0.0
62 | # beta_smoothing: 0.25
63 | # containment_based_pc: true
64 | # sampling_full_prog_only: true
65 | # norm_marg_by_len: true
66 | # fcp_only: true
67 | # gold_program_only: true
68 | exclude_context_loss: false
69 | # prioritize_fcp: false
70 | # marg_set_size: 10
71 | # max_buffer_size: 10
72 | # eval_greedy_search: true
73 | optimizer:
74 | class_path: torch.optim.adamw.AdamW
75 | init_args:
76 | lr: 1.0e-4
77 | # lr: 0.0
78 | betas:
79 | - 0.9
80 | - 0.999
81 | eps: 1.0e-8
82 | weight_decay: 0.1
83 | lr_scheduler:
84 | name: linear
85 | init_args:
86 | num_warmup_steps: 100
87 | num_training_steps: *max_steps
88 |
89 | data:
90 | class_path: lightning_modules.datasets.mathqa_line_reader.MathQAMmlDataModule
91 | init_args:
92 | transformer_model_name: *transformer
93 | batch_size: 2
94 | val_batch_size: 4
95 | train_file_path: data/mathqa/train_dedup.jsonl
96 | val_file_path: data/mathqa/val_dedup.jsonl
97 | train_max_instances: 40
98 | val_max_instances: 20
99 | # few_shot_n: 4
100 |
--------------------------------------------------------------------------------