├── .gitignore ├── LICENSE ├── README.md ├── VOCAB_22E4 ├── code_syntax_vocabulary.json ├── iodata_syntax_vocabulary.txt ├── the_great_wall.txt └── wild_content_vocabulary.txt ├── assets ├── teaser.png ├── vocab-dist-v3.pdf └── vocab-dist-v3.png ├── create_data_step1_tokenize_vocab_and_raw.py ├── create_data_step2_regularize.py ├── create_data_step3_raw_to_codeAug.py ├── create_data_step4_SG_to_pickle.py ├── dataloaders ├── .DS_Store ├── apps.py ├── check_exec_match.py ├── code_contests.py ├── data_augmentation.py ├── loader_utils.py └── sttd.py ├── dataset_examples ├── apps_test_codes_readable_nameReplaced_007810.py ├── apps_test_codes_readable_nameReplaced_007819.py ├── apps_test_codes_readable_raw_007810.py ├── apps_test_codes_readable_raw_007819.py ├── apps_test_description_007810.txt ├── apps_test_description_007819.txt ├── apps_test_iodatas_readable_007810.py ├── apps_test_iodatas_readable_007819.py ├── apps_train_codes_readable_nameReplaced_011174.py ├── apps_train_codes_readable_nameReplaced_011175.py ├── apps_train_description_011174.txt ├── apps_train_description_011175.txt ├── apps_train_iodatas_readable_011174.py └── apps_train_iodatas_readable_011175.py ├── evals.py ├── model ├── .DS_Store ├── __init__.py ├── embedders.py ├── model_wrapper.py ├── sklearn_wrapper.py ├── transformer.py └── utils_wrapper.py ├── parsers.py ├── quick_start_tokenizer.py ├── quick_start_tokenizer_output.txt ├── tokenizer ├── .DS_Store ├── astunparse │ ├── __init__.py │ ├── __main__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── printer.cpython-38.pyc │ │ └── unparser.cpython-38.pyc │ ├── printer.py │ └── unparser.py ├── python_syntax.py ├── tokenization_algorithm.py └── tokenizerAPI.py ├── train.py └── trainer ├── .DS_Store ├── logger.py ├── metrics.py ├── optim.py ├── slurm.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # customized ignorance 2 | 3 | .DS_Store 4 | data/ 5 | 6 | 7 | 8 | 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 VITA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🔖 Overview 2 | This repository is the official implementation of "[Outline, Then Details: Syntactically Guided Coarse-To-Fine Code Generation](https://icml.cc/virtual/2023/poster/25091)", accepted in ICML 2023. 
This work aims to improve code generation for competitive programming problems. To better leverage the domain knowledge of the programming language, we tokenize the code _not_ based on its plain string representation, but based on the Abstract Syntax Tree (AST). Our tokenizer decouples the code into a _syntax_-only subcomponent ($S_3$ in the paper) and a _content_-only subcomponent ($S_4$). We also use a transformer architecture that takes as input the natural language description, as well as the program input/output data aligned across multiple samples according to their syntax roles. 3 | 4 | 5 | 6 | 7 | # ✅ Quick Start 8 | 9 | ## How to play with the proposed syntax-aware tokenizer alone, _without loading the model_ (how to generate $S_3$ / $S_4$ from a code snippet): 10 | Check out the model-independent demos in `quick_start_tokenizer.py`. The outputs of this file are printed in `quick_start_tokenizer_output.txt`. A minimal usage sketch is also given at the end of this Quick Start section. 11 | 12 | If you do not get the same results as the `quick_start_tokenizer_output.txt` we provided, it is most likely a path issue: the folder `VOCAB_22E4` is not found by Python on your system. You may need to change the first 5 lines of `tokenizer/tokenization_algorithm.py` to use an absolute path, so that `VOCAB_SO_FILE` is read correctly in that file. 13 | 14 | The `dataset_examples` folder provides a few example problems from the dataset (already processed by `create_data_step2_regularize.py`, which runs the I/O data through the provided code, filters out "unfaithful" samples, and cleans the code with our tokenizer). Feel free to play with the examples in the `dataset_examples` folder using the `quick_start_tokenizer.py` script. Different samples in one file under the `dataset_examples` folder are separated by a `# 🟨 🟨 🟨 🟨 ` line. 15 | 16 | 17 | ## How to load the pre-trained model and run inference: 18 | 1. Download our pre-trained weights [here](https://github.com/VITA-Group/ChainCoder/files/22169605/ckpt.zip). 19 | 20 | 2. Download the processed CodeContests and APPS datasets [here](https://drive.google.com/file/d/1yOZXFmqTE_6ct2YLkHfRhVYAvAxZP88Y/view?usp=share_link). You can optionally generate them yourself by running `create_data_step1_tokenize_vocab_and_raw.py`, `create_data_step2_regularize.py`, and `create_data_step4_SG_to_pickle.py`. See details below. 21 | 22 | 3. Clone the repo, then enter the folder with `cd ChainCoder`. 23 | 24 | 4. Run the following command in this folder: 25 | `python evals.py --testing_load_ckpt_from=/your/model/weights/dirname` 26 | 27 | 28 | ## Steps to train a new model from scratch 29 | To run the training script, you need the training data of [CodeContest](https://github.com/deepmind/code_contests) and [APPS](https://github.com/hendrycks/apps) on your disk. You can optionally generate the processed data by running `create_data_step1_tokenize_vocab_and_raw.py`, `create_data_step2_regularize.py`, and `create_data_step4_SG_to_pickle.py`. 30 | 31 | 32 | The script to run the training code is `python train.py`. The training/evaluation parameters are configured in `parsers.py`. Important parameters to pay attention to: 33 | 34 | - `--pickle_data_root`: the dataset folder. There should be one or more folders named 'difficulty_introductory', 'difficulty_competition', etc., under this folder. Required for both training and testing. 35 | 36 | - `--training_ckpt_dump_path`: where training checkpoints are periodically saved. Required only for training. 37 | 38 | - `--run_on_cpu`: set to False to use CUDA.
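For orientation, here is a minimal sketch of the tokenizer round trip mentioned at the top of this Quick Start. It uses the `tokenizerAPI_OR2T` / `tokenizerAPI_OT2R` functions that the data-preparation scripts import from `tokenizer/tokenizerAPI.py`; the toy program below is only an illustration, and `quick_start_tokenizer.py` remains the authoritative demo.

```python
# Minimal sketch: readable code -> (S3 syntax tokens, S4 content tokens) -> code.
from tokenizer.tokenizerAPI import tokenizerAPI_OR2T, tokenizerAPI_OT2R

code = "n = int(input())\nprint(n * 2)\n"          # toy example program

syn_seq, cont_seq = tokenizerAPI_OR2T(code)       # S3 (syntax-only) and S4 (content-only)
print(syn_seq)                                    # grouped AST syntax roles
print(cont_seq)                                   # identifiers, constants, etc.

recovered = tokenizerAPI_OT2R(syn_seq, cont_seq)  # back to a (name-normalized) program
print(recovered)
```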
39 | 40 | 41 | 42 | 43 | # 💡 A quick glance of the algorithm 44 | 45 | In this work, we propose to improve the generalization and performance of program synthesis for competitive programming, i.e., generating code for challenge/competition-style problems. We focus on generating Python solutions. These problems specify inputs and outputs, which can be integers, strings, lists, floats, or bool values. We propose a novel tokenizer, together with a transformer architecture, to achieve this goal. Our method has the following features: 46 | 47 | (1) We tokenize the Python code not from its plain string representation, but from the Abstract Syntax Tree (AST). In this way, we can decouple the code into a _syntax_ component and a _content_ component ($S_3$ and $S_4$ in the paper, respectively); a toy illustration of this split is sketched further below. 48 | 49 | (2) We generate shorter summaries $S_1$ and $S_2$ from $S_3$ and $S_4$, composing a hierarchical coarse-to-fine generation scheme that facilitates inference. 50 | 51 | (3) On the input side, the model uses a cross-sample alignment step to encode the I/O data, which aligns different samples according to their syntax roles. 52 | 53 | 54 | These tokenization-related steps are implemented in `tokenizer/tokenization_algorithm.py` and further wrapped in `tokenizer/tokenizerAPI.py`. `tokenizer/python_syntax.py` and `tokenizer/astunparse` (modified from an [existing package](https://pypi.org/project/astunparse/)) provide helper functions to convert between Python code and the Abstract Syntax Tree (AST). The implementations in `tokenizer/python_syntax.py` are closely tied to the AST, and the Python version used at implementation time was 3.8. Different Python versions differ slightly in how the AST represents code. If you use a Python version other than 3.8, the AST will look slightly different, and you will need to modify `tokenizer/python_syntax.py` for perfect support, though such modifications are easy to make. 55 | 56 | 57 | 58 | 59 | # 📋 Complete guidelines to the code 60 | 61 | By default, you should be able to run our code successfully if your current dir is this folder (e.g., `/Users/xxx/.../ChainCoder`). However, our implementation requires two important paths to be configured correctly: the _vocabulary_ and the _dataset_. If you have issues running our code, please check the sections below. 62 | 63 | 64 | ## Why is the "_vocabulary_" important when running our code? 65 | 66 | 67 | The ChainCoder tokenizer is syntax-aware: it is built on a unique set of syntax vocabularies drawn from the Abstract Syntax Tree (AST) of the code. To make the token sequence shorter, we apply a grouping technique that combines several adjacent syntax roles into one single token. Therefore, the number of combinatorial syntax roles is theoretically unbounded, though most combined syntax roles are quite rare; see Figure 3 of the paper, reproduced below. 68 | 69 | ![Figure 3. Token frequency distribution. Visibly, the majority of syntax combinatorial tokens are rarely met.](assets/vocab-dist-v3.png) 70 | 71 | This means that if you want perfect support for all syntax roles of a new dataset, you first need to do a "sweep" over that dataset to collect all the necessary syntax and content patterns. We have already swept the APPS and CodeContests datasets, and provide their syntax vocabularies in the `VOCAB_22E4` folder. These vocabularies are imported in `tokenizer/tokenization_algorithm.py`.
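To make the syntax/content split concrete, here is a small, self-contained illustration that uses only Python's standard `ast` module. This is not ChainCoder's actual tokenization (which additionally groups adjacent syntax roles and relies on the `VOCAB_22E4` vocabularies); it is only a sketch of the underlying idea that AST node types carry the syntax while identifiers and constants carry the content.

```python
# Conceptual sketch only -- the real S3/S4 construction lives in
# tokenizer/tokenization_algorithm.py.
import ast

code = "x = a + 1"
tree = ast.parse(code)

# "Syntax" view: AST node types, roughly what a syntax-role vocabulary covers.
syntax_roles = [type(node).__name__ for node in ast.walk(tree)]
print(syntax_roles)   # e.g. ['Module', 'Assign', 'Name', 'BinOp', 'Name', 'Add', 'Constant', ...]

# "Content" view: identifiers and constants, stripped of the structure.
content = [node.id if isinstance(node, ast.Name) else node.value
           for node in ast.walk(tree)
           if isinstance(node, (ast.Name, ast.Constant))]
print(content)        # ['x', 'a', 1]
```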
If you want to work with a new dataset and you do not sweep that dataset, our tokenizer API may be able to tokenize most, but not all, of the code in it. However, even for a code snippet that cannot be perfectly tokenized due to missing syntax tokens, there exist other ways to write code with the same logic but different syntax that are supported by the existing vocabularies. Therefore, the model will still be able to solve the problem. 72 | 73 | ## How to make sure you load the "_vocabulary_" correctly? 74 | We have swept the CodeContests and APPS datasets to collect the vocabulary stored in the folder `VOCAB_22E4`. The path to `VOCAB_22E4` is globally configured in one place: the beginning of `tokenizer/tokenization_algorithm.py`. Please make sure that `os.path.exists(VOCAB_SO_FILE)==True` within `tokenizer/tokenization_algorithm.py`. If not, try using an absolute path for `VOCAB_ROOT` in `tokenizer/tokenization_algorithm.py`. 75 | 76 | 77 | 78 | ## How to reproduce results from scratch? 79 | The stages of training ChainCoder from scratch are: first, collect the vocabulary using `create_data_step1_tokenize_vocab_and_raw.py`; then regularize/augment the data and convert it to pickles using the `create_data_step2` to `create_data_step4` scripts (the code augmentation in step 3 is optional); lastly, train the model using `train.py`. Detailed steps: 80 | 81 | 82 | ### 1. Sweep across the entire dataset to collect the vocabulary. 83 | This step creates a raw data folder containing the string-format code, the readable-format code, the string-format I/O data, and the readable-format I/O data. 84 | #### 📋 checklist: 85 | - 📋 In `tokenizer/tokenization_algorithm.py`, set `VOCAB_ROOT` to a non-existent dir on your machine: the four vocabulary files will be saved and updated within that dir. 86 | - 📋 Run: `create_data_step1_tokenize_vocab_and_raw.py --STTD_output_root=some/new/path/where/raw/files/are/dumped/into --apps_data_root=where/you/store/APPS/data/downloaded`, where `--STTD_output_root` points to a non-existent dir on your machine: this will be the location where all the processed files are stored. The APPS download is available [here](https://github.com/hendrycks/apps). 87 | 88 | 89 | 2. Convert the raw data to the pickle files. 90 | 2.1. (Optional) Convert the raw data to I/O-augmented data. 91 | 92 | 93 | 3. (Optional) Convert the raw files to code-augmented data. 94 | 95 | 4. Use the training API, which reads from the pickle folder. 96 | 97 | 98 | 99 | # 📑 Requirements 100 | Since our syntax tokenizer is deeply tied to the `ast` package (Abstract Syntax Tree), and the fundamental AST classes may change names and/or arguments across Python version updates, we have some loose requirements on the Python version. The Python versions that have passed our tests include: 101 | 102 | python 3.8.13 103 | 104 | The Python versions that have failed our tests include: 105 | 106 | python 2.x 107 | 108 | python 3.7 109 | 110 | If your Python version is not already 3.8, the easiest way is to install a new environment with Python 3.8. You can do this via `conda env create -f py38.yml`, or initialize an empty env and install torch. We require no special packages that are hard to install; all packages except torch can simply and robustly be installed via `pip install xxx`.
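If you want a quick, self-contained way to verify both constraints above (the interpreter version and the vocabulary path), a check along the following lines may help. `vocabulary_defs.VOCAB_SO_FILE` is the attribute used by the data-preparation scripts in this repo, so treat the exact import as tied to the current layout.

```python
# Sanity-check sketch: confirm the Python version and that the VOCAB_22E4
# vocabulary files resolve, before training or inference.
import os
import sys

assert sys.version_info[:2] == (3, 8), \
    f"Tested on Python 3.8, found {sys.version.split()[0]}"

from tokenizer.tokenizerAPI import vocabulary_defs  # path configured via VOCAB_ROOT

assert os.path.exists(vocabulary_defs.VOCAB_SO_FILE), (
    "VOCAB_22E4 not found -- set VOCAB_ROOT in tokenizer/tokenization_algorithm.py "
    "to an absolute path."
)
print("Python version and vocabulary path look OK.")
```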
111 | 112 | If you would like to try to run the syntax tokenizer or the model under other python version, or if you need to develop syntax tokenizer to encode other programming language data, you can easily do so via changing the `python_syntax/python_syntax.py`. This file defines the syntax classes needed for the `python_syntax/tokenization_algorithm.py`. The class name and arguments in the `python_syntax/python_syntax.py` exactly match the core classes in the python AST package. 113 | 114 | 115 | 116 | # Citation 117 | 118 | ``` 119 | @article{zheng2023outline, 120 | title={Outline, then details: Syntactically guided coarse-to-fine code generation}, 121 | author={Zheng, Wenqing and Sharan, SP and Jaiswal, Ajay Kumar and Wang, Kevin and Xi, Yihan and Xu, Dejia and Wang, Zhangyang}, 122 | journal={arXiv preprint arXiv:2305.00909}, 123 | year={2023} 124 | } 125 | ``` 126 | 127 | -------------------------------------------------------------------------------- /VOCAB_22E4/iodata_syntax_vocabulary.txt: -------------------------------------------------------------------------------- 1 | {',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None))],ctx=Load())],ctx=Load())],ctx=Load()))],type_ignores=[])', 'Module(body=[Expr(value=List(elts=[List(elts=[List(elts=[],ctx=Load())],ctx=Load()),List(elts=[List(elts=[],ctx=Load())],ctx=Load())],ctx=Load()))],type_ignores=[])', ',ctx=Load()),Name(id=', ',kind=None))],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[UnaryOp(op=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None)],ctx=Load())],ctx=Load()),List(elts=[List(elts=[Constant(value=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[UnaryOp(op=', ',kind=None))],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None))],ctx=Load()),List(elts=[Constant(value=', 'Module(body=[Expr(value=List(elts=[List(elts=[List(elts=[UnaryOp(op=', ',ctx=Load())],ctx=Load()),List(elts=[Constant(value=', 
',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None)],ctx=Load())],ctx=Load()),List(elts=[List(elts=[UnaryOp(op=', ',ctx=Load())],ctx=Load())],ctx=Load())],ctx=Load()))],type_ignores=[])', ',ctx=Load()))],ctx=Load())],ctx=Load())],ctx=Load()))],type_ignores=[])', ',ctx=Load())],ctx=Load()),List(elts=[Name(id=', ',kind=None))],ctx=Load()),List(elts=[UnaryOp(op=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None))],ctx=Load()),List(elts=[Name(id=', ',operand=Name(id=', ',ctx=Load())],ctx=Load()),List(elts=[UnaryOp(op=', ',operand=Constant(value=', ',kind=None)],ctx=Load())],ctx=Load()),List(elts=[List(elts=[],ctx=Load())],ctx=Load())],ctx=Load()))],type_ignores=[])', ',kind=None)),UnaryOp(op=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None))],ctx=Load())],ctx=Load()),List(elts=[List(elts=[Constant(value=', ',kind=None)),Name(id=', ',kind=None)],ctx=Load())],ctx=Load())],ctx=Load()))],type_ignores=[])', ',ctx=Load()),Constant(value=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None))],ctx=Load())],ctx=Load()),List(elts=[List(elts=[Name(id=', ',kind=None)],ctx=Load())],ctx=Load()),List(elts=[List(elts=[Name(id=', ',kind=None)),Constant(value=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load())],ctx=Load()),List(elts=[List(elts=[UnaryOp(op=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', 
',kind=None))],ctx=Load())],ctx=Load()),List(elts=[List(elts=[UnaryOp(op=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None),Constant(value=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None)],ctx=Load()),List(elts=[UnaryOp(op=', ',ctx=Load()),UnaryOp(op=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[UnaryOp(op=', ',kind=None)],ctx=Load()),List(elts=[Name(id=', ',kind=None),Name(id=', ',kind=None)],ctx=Load()),List(elts=[],ctx=Load()),List(elts=[Constant(value=', ',kind=None),UnaryOp(op=', 'Module(body=[Expr(value=List(elts=[List(elts=[List(elts=[Constant(value=', 'Module(body=[Expr(value=List(elts=[List(elts=[List(elts=[],ctx=Load())],ctx=Load()),List(elts=[List(elts=[Constant(value=', ',kind=None)],ctx=Load()),List(elts=[Constant(value='} 2 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/ChainCoder/a9ae00758fe9185c7f13f9b9dc591d05de0d4445/assets/teaser.png -------------------------------------------------------------------------------- /assets/vocab-dist-v3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/ChainCoder/a9ae00758fe9185c7f13f9b9dc591d05de0d4445/assets/vocab-dist-v3.pdf -------------------------------------------------------------------------------- /assets/vocab-dist-v3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/ChainCoder/a9ae00758fe9185c7f13f9b9dc591d05de0d4445/assets/vocab-dist-v3.png -------------------------------------------------------------------------------- /create_data_step1_tokenize_vocab_and_raw.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | import time 4 | import sys 5 | import re 6 | 7 | from argparse import ArgumentParser 8 | 9 | from dataloaders.apps import get_apps_rawloader 10 | from dataloaders.code_contests import get_contest_rawloader 11 | from dataloaders.check_exec_match import check_io_match_one_sample_obj 12 | from dataloaders.loader_utils import timeout, save_raw 13 | from tokenizer.tokenizerAPI import ( 14 | vocabulary_defs, load_txt, 15 | tokenizerAPI_OR2T, 16 | tokenizerAPI_OT2R, 17 | 
tokenizerAPI_IT2R, 18 | tokenizerAPI_IR2T, 19 | ) 20 | 21 | 22 | 23 | 24 | def parse_args(): 25 | parser = ArgumentParser() 26 | 27 | parser.add_argument( 28 | '--apps_data_root', 29 | type = str, 30 | default = '/path/to/apps/APPS', 31 | help = 'Root of data downloaded from APPS (https://github.com/hendrycks/apps).' 32 | ) 33 | parser.add_argument( 34 | '--one_sample_tokenization_timelimit', 35 | type = int, 36 | default = 5, 37 | help = 'Time limit for tokenizing one sample. If exceed, discard this sample. Suggested: 2s is way more than enough, if one sample even exceeds 2s, it may take more than tens of seconds.' 38 | ) 39 | parser.add_argument( 40 | '--one_instance_vocab_collection_timelimit', 41 | type = int, 42 | default = 5000, 43 | help = 'If one instance exceed this number and did not finish all samples, it will be discarded and the results not saved.' 44 | ) 45 | 46 | parser.add_argument( 47 | '--STTD_output_root', 48 | type = str, 49 | default = 'some_new_dump_path', 50 | help = 'This will be the location where all the STTD files are dumped into. Set this to non-existent dir on your machine.' 51 | ) 52 | 53 | args = parser.parse_args() 54 | 55 | return args 56 | 57 | args = parse_args() 58 | 59 | 60 | def add_wrap(code): 61 | code = code.split('\n') 62 | code = ['def syntaxformer_added_top_wrap_func():'] + [' ' + l for l in code] + ['syntaxformer_added_top_wrap_func()\n'] 63 | code = '\n'.join(code) 64 | return code 65 | 66 | def tok_vocab_and_save(): 67 | vocabulary_defs.refuse_unseen_tokens = False 68 | 69 | apps_trainloader = get_apps_rawloader( 70 | mode="train", 71 | difficulties=["introductory", "interview", "competition"], 72 | apps_data_root=args.apps_data_root 73 | ) 74 | apps_testloader = get_apps_rawloader( 75 | mode="test", 76 | difficulties=["introductory", "interview", "competition"], 77 | apps_data_root=args.apps_data_root 78 | ) 79 | 80 | 81 | contest_train_loader = get_contest_rawloader('train') 82 | contest_test_loader = get_contest_rawloader('test') 83 | 84 | 85 | dataloaders = [ 86 | (apps_testloader, 'apps_test'), 87 | (apps_trainloader, 'apps_train'), 88 | (contest_train_loader, 'contest_train'), 89 | (contest_test_loader, 'contest_test'), 90 | ] 91 | 92 | 93 | instance_id = 0 94 | (the_wall_dir, vo_dir) = None, None 95 | 96 | for idataloader in range(len(dataloaders)): 97 | 98 | dataloader, which_loader = dataloaders[idataloader] 99 | print(f'\n\n\n Now using dataloader {which_loader} \n\n') 100 | 101 | for i, one_instance in enumerate(tqdm(dataloader)): 102 | if one_instance is None: continue 103 | 104 | try: 105 | status = ensureTokenization_and_save_raw(which_loader, instance_id, one_instance) 106 | except: 107 | print(f'instance error: {which_loader}, loader_output_id = {i}, error is: {sys.exc_info()[:-1]}') 108 | continue 109 | 110 | if status[0]=='👌': 111 | instance_id += 1 112 | (the_wall_dir, vo_dir) = status[1] 113 | 114 | if i%100==10: 115 | print(f'\n\n 🟨 🟨 checking: \n Now using {which_loader}, idataloader = {idataloader}, this-loader-id = {i} ; Total instances so far = {instance_id}; vocab dirs = {the_wall_dir, vo_dir}') 116 | 117 | print(f'\n\n 🟨 🟨 🟨 🟨 🟨 🟨 🟨 🟨 \n In step0 finished {which_loader} generated xx instances out of {i} ; Total instances so far = {instance_id}') 118 | time.sleep(1) 119 | return 120 | 121 | 122 | @timeout(args.one_instance_vocab_collection_timelimit) 123 | def ensureTokenization_and_save_raw(which_loader, instance_id, one_instance): 124 | global vocabulary_defs 125 | 126 | pcodes_raw, pxs_raw, pys_raw, pio_objs, 
pdescription, pdifficulty = one_instance['codes_raw'], one_instance['xs_raw'], one_instance['ys_raw'], one_instance["io_objs"], one_instance["description"], one_instance["difficulty"] 127 | 128 | vocab_SO_dic = dict(eval(load_txt(vocabulary_defs.VOCAB_SO_FILE))) 129 | vocab_SI_set = set(eval(load_txt(vocabulary_defs.VOCAB_SI_FILE))) 130 | vocab_CC_set = set(eval(load_txt(vocabulary_defs.VOCAB_CC_FILE))) 131 | 132 | def update_vocab_code(synSeq_code, contSeq_code): 133 | global vocabulary_defs 134 | nonlocal vocab_SO_dic, vocab_SI_set, vocab_CC_set 135 | for x in synSeq_code: 136 | if vocabulary_defs.is_unseen_SO(x): 137 | k = vocabulary_defs.toKey(x, 'O', 'syn', need_update=False) 138 | vocabulary_defs.update([k], ['O'], ['syn']) 139 | vocab_SO_dic[x] = 1 140 | else: 141 | if x in vocab_SO_dic: 142 | vocab_SO_dic[x] += 1 143 | for x in contSeq_code: 144 | if vocabulary_defs.is_unseen_CC(x): 145 | k = vocabulary_defs.toKey(x, 'O', 'cont', need_update=False) 146 | vocabulary_defs.update([k], ['O'], ['cont']) 147 | vocab_CC_set.update([x]) 148 | 149 | 150 | def update_vocab_iodata(synSeq_io, contSeq_io): 151 | global vocabulary_defs 152 | nonlocal vocab_SO_dic, vocab_SI_set, vocab_CC_set 153 | 154 | for x in synSeq_io: 155 | if vocabulary_defs.is_unseen_SI(x): 156 | vocab_SI_set.update([x]) 157 | k = vocabulary_defs.toKey(x, 'I', 'syn', need_update=False) 158 | vocabulary_defs.update([k], ['I'], ['syn']) 159 | for x in contSeq_io: 160 | if vocabulary_defs.is_unseen_CC(x): 161 | vocab_CC_set.update([x]) 162 | k = vocabulary_defs.toKey(x, 'I', 'cont', need_update=False) 163 | vocabulary_defs.update([k], ['I'], ['cont']) 164 | 165 | 166 | codes_nameReplaced = [] 167 | codes_raw = [] 168 | 169 | # 🟩 ensure tokenization for code 170 | for code in pcodes_raw: 171 | 172 | 173 | 174 | try: 175 | @timeout(args.one_sample_tokenization_timelimit) 176 | def run_code_tokenization(code): 177 | is_match, exec_out, prt_str = check_io_match_one_sample_obj(pio_objs[0], code, sanity_check_timeout=1) # only check if there's return issue 178 | 179 | 180 | errmsg = str(exec_out) 181 | return_err_locs = re.findall(r'SyntaxError(.*)return(.*)outside function', errmsg) 182 | if len(return_err_locs)!=0: 183 | code = add_wrap(code) 184 | 185 | 186 | synSeq_code, contSeq_code = tokenizerAPI_OR2T(code) 187 | return synSeq_code, contSeq_code 188 | synSeq_code, contSeq_code = run_code_tokenization(code) 189 | update_vocab_code(synSeq_code, contSeq_code) 190 | name_replaced_recov_code_str = tokenizerAPI_OT2R(synSeq_code, contSeq_code) 191 | except: 192 | print('CODE tokenization error, Discarded:\n', sys.exc_info()[:-1]) 193 | continue 194 | 195 | if synSeq_code==[]: 196 | print('in step0, tokenization fail (too many diy names), discarded this CODE, which is:') 197 | print(code) 198 | continue 199 | 200 | codes_nameReplaced.append(name_replaced_recov_code_str) 201 | codes_raw.append(code) 202 | 203 | xs_raw = [] 204 | ys_raw = [] 205 | io_objs = [] # shape: [sample, (input, output) tuple, token dim] 206 | for x_raw, y_raw, io_2t in zip(pxs_raw, pys_raw, pio_objs): 207 | 208 | 209 | try: 210 | @timeout(args.one_sample_tokenization_timelimit) 211 | def run_possibly_super_long_tokenization(io_2t): 212 | synSeq_io, contSeq_io = tokenizerAPI_IR2T(io_2t) 213 | update_vocab_iodata(synSeq_io, contSeq_io) 214 | rev = tokenizerAPI_IT2R(synSeq_io, contSeq_io) 215 | return rev, synSeq_io, contSeq_io 216 | rev, synSeq_io, contSeq_io = run_possibly_super_long_tokenization(io_2t) 217 | 218 | except: 219 | print('SAMPLE tokenization error, 
Discarded.') 220 | continue 221 | 222 | if rev!=io_2t: 223 | print('in step0, io tokenization fail, discarded this SAMPLE, which is:') 224 | print(io_2t) 225 | continue 226 | 227 | else: 228 | 229 | io_objs.append(io_2t) 230 | xs_raw.append(x_raw) 231 | ys_raw.append(y_raw) 232 | 233 | 234 | if len(io_objs)<=2 or len(codes_nameReplaced)==0: 235 | print('in step0, valid sample num too small, discarded this INSTANCE.') 236 | return '😭', (None, None) 237 | 238 | 239 | # 🟩 save raw files 240 | 241 | raw_dir = os.path.join(args.STTD_output_root, f'difficulty_{pdifficulty}') 242 | os.makedirs(raw_dir, exist_ok=True) 243 | 244 | cross_samp_join = '\n\n# 🟨 🟨 🟨 🟨 \n\n' 245 | codes_readable_raw = cross_samp_join.join(codes_raw) 246 | codes_readable_nameReplaced = cross_samp_join.join(codes_nameReplaced) 247 | iodatas_readable = cross_samp_join.join([repr(tuple(x)) for x in io_objs]) 248 | save_raw(raw_dir, which_loader, instance_id, 249 | codes_raw, codes_nameReplaced, codes_readable_raw, codes_readable_nameReplaced, 250 | xs_raw, ys_raw, io_objs, iodatas_readable, 251 | pdescription) 252 | 253 | # 🟩 save vocabs 254 | print(vocab_SO_dic, file=open(vocabulary_defs.VOCAB_SO_FILE, 'w')) 255 | print(vocab_SI_set, file=open(vocabulary_defs.VOCAB_SI_FILE, 'w')) 256 | print(vocab_CC_set, file=open(vocabulary_defs.VOCAB_CC_FILE, 'w')) 257 | the_wall_dir = vocabulary_defs.save_the_great_wall() 258 | vo_dir = vocabulary_defs.VOCAB_SO_FILE 259 | 260 | return '👌', (the_wall_dir, vo_dir) 261 | 262 | 263 | if __name__ == '__main__': 264 | 265 | tok_vocab_and_save() 266 | 267 | 268 | -------------------------------------------------------------------------------- /create_data_step2_regularize.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | import itertools 4 | import os 5 | import pickle 6 | import numpy as np 7 | from tqdm import tqdm 8 | from transformers import BertTokenizer, DistilBertTokenizer 9 | import time 10 | import sys 11 | from argparse import ArgumentParser 12 | import copy 13 | from glob import glob 14 | 15 | from tokenizer.tokenizerAPI import ( 16 | tokenizerAPI_IN2R, 17 | tokenizerAPI_IR2N, 18 | tokenizerAPI_ON2R, 19 | tokenizerAPI_OR2N, 20 | tokenizerAPI_OT2R, 21 | tokenizerAPI_OR2T, 22 | vocabulary_defs, 23 | ) 24 | 25 | from dataloaders.check_exec_match import check_io_match_one_sample_obj 26 | from dataloaders.data_augmentation import iodata_augmentor 27 | from dataloaders.loader_utils import evalio, load_all_instances, shuffled, save_raw, parse_loadername_from_filename, MyTimeoutError, timeout 28 | 29 | 30 | 31 | 32 | 33 | vocabulary_defs.refuse_unseen_tokens = True 34 | 35 | def parse_args(): 36 | 37 | parser = ArgumentParser() 38 | 39 | one_sample_program_run_timelimit = 10 40 | 41 | 42 | parser.add_argument( 43 | "raw_data_dir", 44 | help='This dir is read-only for this script: it might read raw files from this dir (or, might read from --code_augmented_dir, depending on how you set --SG_from), then convert to int and save to --pickle_dir.' 45 | ) 46 | parser.add_argument( 47 | "reg_dir", 48 | help='output regularized result dir' 49 | ) 50 | 51 | 52 | parser.add_argument("--verbose", type=int, default=1) 53 | 54 | 55 | # 🟩 Three important dirs below. 56 | parser.add_argument( 57 | "--SG_from", 58 | default='raw', 59 | choices=['raw', 'code_augmented'], 60 | help='Choose where to read raw file and convert to int; choices are ["raw", "code_augmented"].' 
61 | ) 62 | 63 | 64 | parser.add_argument( 65 | "--code_augmented_dir", 66 | default='/path/to/your/code_augmented_dir', 67 | help='This dir is read-only for this script: it might read raw files from this dir (or, might read from --raw_data_dir, depending on how you set --SG_from), then convert to int and save to --pickle_dir.' 68 | ) 69 | 70 | 71 | parser.add_argument( 72 | "--one_sample_program_run_timelimit", 73 | type=int, 74 | default=one_sample_program_run_timelimit, 75 | ) 76 | 77 | parser.add_argument( 78 | "--only_do_subfolders", 79 | type=str, 80 | # default='0,1,2,3', 81 | default='all', 82 | help='Used to slice subfolder; values 0~3, seperate by comma, or "all".' 83 | ) 84 | parser.add_argument( 85 | "--only_do_ires", 86 | type=str, 87 | # default='?0,?1,?2,?3,?4,?5,?6,?7,?8,?9', 88 | default='all', 89 | help='Used to slice progress; usage: --only_do_ires="??", where "?" can be 0~9, seperate by comma; or, --only_do_ires="all".' 90 | ) 91 | 92 | 93 | args = parser.parse_args() 94 | if args.SG_from=='code_augmented': 95 | os.makedirs(args.code_augmented_dir, exist_ok=True) 96 | 97 | 98 | 99 | if args.only_do_subfolders=='all': 100 | args.only_do_subfolders = list(range(4)) 101 | else: 102 | args.only_do_subfolders = [int(x) for x in args.only_do_subfolders.split(',')] 103 | if args.only_do_ires=='all': 104 | args.only_do_ires = [f'?????{x}' for x in range(10)] 105 | else: 106 | tmp = [] 107 | for x in args.only_do_ires.split(','): 108 | x = '?'*(6-len(x)) + x 109 | tmp.append(x) 110 | args.only_do_ires = tmp 111 | 112 | if args.verbose: 113 | print('🙂', file=open('_log_ioaug_err.py', 'w')) 114 | 115 | 116 | return args 117 | 118 | 119 | args = parse_args() 120 | 121 | 122 | 123 | def check_match_loop(code_raw_st, io_s2t_orig): 124 | io2codes = defaultdict(list) 125 | totalnum = len(io_s2t_orig) 126 | print(f'🟧 Num samples = {totalnum}') 127 | for i_code in tqdm(range(len(code_raw_st))): 128 | code = code_raw_st[i_code] 129 | io_s2t = [] 130 | core_exec_time = 0 131 | validnum = 0 132 | for i, ioobj in enumerate(io_s2t_orig): 133 | is_match, exec_out, prt_str = check_io_match_one_sample_obj(ioobj, code, sanity_check_timeout=args.one_sample_program_run_timelimit) 134 | 135 | _ct = float(prt_str.split('core_exec_time:\n\t ')[1]) 136 | if _ct!=-1: # only add those passed time. 
137 | core_exec_time += _ct 138 | 139 | 140 | io_s2t.append(copy.deepcopy(ioobj)) 141 | # y = process(x, code) 142 | 143 | if is_match: 144 | validnum += 1 145 | else: 146 | if not (type(exec_out) is RuntimeError): 147 | io_s2t[-1][1] = exec_out 148 | validnum += 1 149 | else: 150 | io_s2t.pop() 151 | 152 | io2codes[repr(io_s2t)].append([code, core_exec_time, validnum, totalnum]) 153 | 154 | return io2codes 155 | 156 | 157 | finished_f = 'finished_reg_run.txt' 158 | failed_f = 'failed_reg_run.txt' 159 | 160 | def main(): 161 | try: 162 | main_sub() 163 | except: 164 | print(('failed somewhere', args.only_do_subfolders, args.only_do_ires), file=open(failed_f, 'a')) 165 | return 166 | 167 | 168 | def main_sub(): 169 | if args.SG_from=='raw': 170 | args.SG_root_dir = args.raw_data_dir 171 | elif args.SG_from=='code_augmented': 172 | args.SG_root_dir = args.code_augmented_dir 173 | 174 | subfolders = [ 175 | 'difficulty_introductory', 176 | 'difficulty_interview', 177 | 'difficulty_competition', 178 | 'difficulty_dm_code_contest', 179 | ] 180 | 181 | subfolders = [subfolders[i] for i in args.only_do_subfolders] 182 | 183 | 184 | converted_files = 0 185 | total_faith_div_cnts = np.array([0,0,0]) 186 | 187 | for subfolder in subfolders: 188 | SG_subdir = os.path.join(args.SG_root_dir, subfolder) 189 | 190 | reg_dir_sub = os.path.join(args.reg_dir, subfolder) 191 | os.makedirs(reg_dir_sub, exist_ok=True) 192 | print(f'🙇 Ready to regularize to {reg_dir_sub}! 🙇') 193 | 194 | for ire in tqdm(args.only_do_ires): 195 | file_id_re = ire 196 | 197 | all_instances = load_all_instances(SG_subdir, file_id_re, shuffle=False) 198 | if len(all_instances[0])==0: 199 | continue 200 | 201 | 202 | allinst_cvt = list(zip(*all_instances)) 203 | 204 | 205 | for i_inst, (code_raw_st, code_nameRep_st, x_raw_st, y_raw_st, io_s2t_orig, description, filename) in enumerate(allinst_cvt): 206 | 207 | # 🟩 From here do whatever with these variables: they loop for the entire dataset per instance 208 | 209 | print(f'🟧 beginning check match: \n\t ire/subfolder = {ire, subfolder} \n\t inst/all ins = {i_inst} / {len(allinst_cvt)}\n\t code num = {len(code_raw_st)}') 210 | 211 | io2codes = check_match_loop(code_raw_st, io_s2t_orig) 212 | 213 | which_loader, inst_id_orig = parse_loadername_from_filename(filename) 214 | 215 | for io_r, codes in io2codes.items(): 216 | total_faith_div_cnts[0] += len(codes) 217 | 218 | io_objs = evalio(io_r) 219 | if io_objs==io_s2t_orig: 220 | pdescription = description 221 | _hash_code_behavior = 'orig' 222 | total_faith_div_cnts[1] += len(codes) 223 | else: 224 | pdescription = '' 225 | _hash_code_behavior = hash(repr(io_objs)) 226 | total_faith_div_cnts[2] += len(codes) 227 | 228 | codes_nameReplaced, ctimes, codes_raw = [], [], [] 229 | for code, ctime, validnum, totalnum in codes: 230 | try: 231 | namer = tokenizerAPI_OT2R(*tokenizerAPI_OR2T(code)) 232 | except: 233 | print(f'# ❓❓ an impossible error occured: tokenization error: \n\t\t# {sys.exc_info()[:-1]}\n# Code is:\n{code}', file=open('_log_for_reg.py', 'a')) 234 | continue 235 | codes_nameReplaced.append(namer) 236 | codes_raw.append(code) 237 | ctimes.append([ctime, validnum, totalnum]) 238 | 239 | 240 | raw_readable_with_time = list(map(lambda ab: f'{ab[0]}\n# ⏳ ⏳ Meta Info\n\t# time = {ab[1]}\n\t# io samples valid / all = {ab[2]} / {ab[3]}\n', codes)) 241 | 242 | instance_id = f'{inst_id_orig}ire{hash(ire)}hash{_hash_code_behavior}' 243 | cross_samp_join = '\n\n# 🟨 🟨 🟨 🟨 \n\n' 244 | codes_readable_raw = 
cross_samp_join.join(raw_readable_with_time) 245 | codes_readable_nameReplaced = cross_samp_join.join(codes_nameReplaced) 246 | iodatas_readable = cross_samp_join.join([repr(tuple(x)) for x in io_objs]) 247 | 248 | if validnum!=0: 249 | save_dir = reg_dir_sub 250 | else: 251 | save_dir = reg_dir_sub + '_bad_codes' # bad_codes means fail to finish all inputs. 252 | 253 | save_raw(save_dir, which_loader, instance_id, 254 | codes_raw, codes_nameReplaced, codes_readable_raw, codes_readable_nameReplaced, 255 | ctimes, [], io_objs, iodatas_readable, 256 | pdescription) 257 | 258 | 259 | print(f'🟧 👌 check match finish: \n\t subfolder = {subfolder} \n\t prog = (ire): {ire} / {args.only_do_ires} (inst): {i_inst} / {len(allinst_cvt)} \n\t total_faith_div_cnts = {total_faith_div_cnts}\n\t total_faith_div_cnts = {total_faith_div_cnts/10000} ') 260 | 261 | print((subfolder, file_id_re), file=open(finished_f, 'a')) 262 | 263 | print('🤘 All finish! 🤘') 264 | 265 | return 266 | 267 | 268 | 269 | def print_stats(iodata_is, code_is): 270 | num_inst = len(iodata_is) 271 | samples_1 = [len(x) for x in iodata_is] 272 | samples_2 = [len(x) for x in code_is] 273 | mv1 = [np.median(samples_1), np.std(samples_1)] 274 | mv2 = [np.median(samples_2), np.std(samples_2)] 275 | 276 | flatten1 = itertools.chain.from_iterable(iodata_is) 277 | lens1 = list(map(lambda x: len(x), flatten1)) 278 | lens1 = [np.median(lens1), np.std(lens1)] 279 | flatten2 = itertools.chain.from_iterable(code_is) 280 | lens2 = list(map(lambda x: len(x), flatten2)) 281 | lens2 = [np.median(lens2), np.std(lens2)] 282 | 283 | print(f'Sample Num Stats of {num_inst} I/O data insts:\t\tNum Samples = {mv1[0]} ± {mv1[1]:.3f}\t\tToken Len = {lens1[0]} ± {lens1[1]:.3f}') 284 | print(f'Sample Num Stats of {num_inst} code data insts:\t\tNum Samples = {mv2[0]} ± {mv2[1]:.3f}\t\tToken Len = {lens2[0]} ± {lens2[1]:.3f}') 285 | return 286 | 287 | 288 | 289 | if __name__ == "__main__": 290 | 291 | main() 292 | 293 | -------------------------------------------------------------------------------- /create_data_step3_raw_to_codeAug.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from argparse import ArgumentParser 5 | from difflib import SequenceMatcher 6 | from typing import List 7 | 8 | import numpy as np 9 | import openai 10 | from joblib import Parallel, delayed 11 | from tqdm import tqdm # noqa 12 | 13 | 14 | from dataloaders.loader_utils import save_raw, load_all_instances, shuffled 15 | 16 | 17 | 18 | all_keys = [ 19 | 'your_api_keys', 20 | ] 21 | 22 | TEMPERATURES = [0.5] 23 | 24 | 25 | def similar(a, b): 26 | return SequenceMatcher(None, a, b).ratio() 27 | 28 | 29 | def take_atleast_one_second(func): 30 | """ 31 | Custom decorator function which makes sure func takes a min of 1 second. 32 | :param func: Function to be decorated. 33 | :return: Wrapped fucntion. 34 | """ 35 | 36 | def wrapper(*args, **kwargs): 37 | tick = time.time() 38 | val = func(*args, **kwargs) 39 | while time.time() - tick < 1: 40 | continue 41 | return val 42 | 43 | return wrapper 44 | 45 | 46 | def summarizer(args, long_str: str) -> List[str]: 47 | """ 48 | Summarizes the description. 49 | :param long_str: What you want to summarize. 50 | :return: Different summarized versions of the same long str. 
51 | """ 52 | max_tokens = max(128, int(len(long_str) * 0.1)) 53 | 54 | @take_atleast_one_second 55 | def prompt_gpt3(prompt, temp): 56 | openai.api_key = openai.api_key = all_keys[args.which_key_id] 57 | output = openai.Completion.create( 58 | model="text-davinci-002", 59 | prompt=prompt.strip() + "\n\nTl;dr: ", 60 | temperature=temp, 61 | max_tokens=max_tokens, 62 | top_p=1, 63 | frequency_penalty=0, 64 | presence_penalty=0, 65 | ) 66 | return output.choices[0].text 67 | 68 | responses = [] 69 | for each_temp in TEMPERATURES: 70 | responses.append(prompt_gpt3(long_str, each_temp)) 71 | 72 | return responses 73 | 74 | 75 | def predict_code_one_shot( 76 | args, demo_prompt: str, demo_output: str, target_prompt: str, target_starter_code: str 77 | ) -> List[str]: 78 | """ 79 | Given a prompt of demo input output, predict new code for a descrption. 80 | :param demo_prompt: Example input. 81 | :param demo_output: Example output code. 82 | :param target_prompt: Actual question. 83 | :param target_starter_code: Starting seed of answer. 84 | :return: A list of possible solutions. 85 | """ 86 | summarized_demo_prompts = summarizer(args, demo_prompt) 87 | summarized_target_prompts = summarizer(args, target_prompt) 88 | 89 | @take_atleast_one_second 90 | def prompt_gpt3(_demo_prompt, _demo_output, _target_prompt, _target_starter_code, _temp): 91 | prompt = ( 92 | "---Question---\n" 93 | + _demo_prompt.strip() 94 | + "\n---Python Code---\n" 95 | + _demo_output.strip() 96 | + "\n---Question---\n" 97 | + _target_prompt.strip() 98 | + "\n---Python Code---\n" 99 | + _target_starter_code.strip() 100 | + "\n" 101 | ) 102 | openai.api_key = all_keys[args.which_key_id] 103 | output = None 104 | while output is None: 105 | try: 106 | output = openai.Completion.create( 107 | model="text-davinci-002", 108 | prompt=prompt, 109 | temperature=_temp, 110 | max_tokens=256, 111 | stop=["---"], 112 | top_p=1, 113 | frequency_penalty=0, 114 | presence_penalty=0, 115 | ) 116 | except: 117 | time.sleep(10) 118 | 119 | return _target_starter_code.strip() + output.choices[0].text.strip() 120 | 121 | 122 | all_responses = [prompt_gpt3(demo_prompt, demo_output, target_prompt, target_starter_code, temp) 123 | for demo_prompt in summarized_demo_prompts 124 | for target_prompt in summarized_target_prompts 125 | for temp in TEMPERATURES 126 | ] 127 | 128 | 129 | return all_responses 130 | 131 | 132 | def parse_args(): 133 | parser = ArgumentParser() 134 | 135 | 136 | num_demo_prompts = 5 137 | num_random_starter_prompts = 5 138 | 139 | machine_name = 'some_name' 140 | 141 | parser.add_argument("pickle_dir") 142 | parser.add_argument("raw_data_dir") 143 | parser.add_argument("code_augmented_dir") 144 | 145 | 146 | parser.add_argument("--num_demo_prompts", default=num_demo_prompts, type=int) 147 | parser.add_argument("--num_random_starter_prompts", default=num_random_starter_prompts, type=int) 148 | parser.add_argument("--temperatures", default=["0.2", "0.7"], nargs="+", help="Provide space separated inputs") 149 | parser.add_argument("--machine_name", default=machine_name) 150 | parser.add_argument("--which_key_id", type=int, default=0) 151 | 152 | 153 | parser.add_argument("--verbose", type=int, default=1) 154 | 155 | 156 | args = parser.parse_args() 157 | os.makedirs(args.code_augmented_dir, exist_ok=True) 158 | os.makedirs(args.pickle_dir, exist_ok=True) 159 | args.machine_name += f'key_{str(args.which_key_id)}' 160 | 161 | 162 | 163 | assert len(os.listdir(args.raw_data_dir))==4 and 'difficulty_introductory' in 
os.listdir(args.raw_data_dir) 164 | 165 | return args 166 | 167 | 168 | def main(): 169 | args = parse_args() 170 | global TEMPERATURES 171 | TEMPERATURES = [float(i) for i in args.temperatures] 172 | 173 | 174 | 175 | subfolders = [ 176 | "difficulty_introductory", 177 | "difficulty_interview", 178 | 'difficulty_competition', 179 | "difficulty_dm_code_contest", 180 | ] 181 | 182 | for subfolder in shuffled(subfolders): 183 | subfolder_dir = os.path.join(args.raw_data_dir, subfolder) 184 | output_dir = os.path.join(args.code_augmented_dir, subfolder) 185 | os.makedirs(output_dir, exist_ok=True) 186 | print(f"Ready to generate to {output_dir}") 187 | 188 | 189 | codes_raw, codes_nameReplaced, xs_raw, ys_raw, iodatas_obj, descriptions, file_names = load_all_instances( 190 | subfolder_dir 191 | ) 192 | print(f"Found {len(codes_raw)} instances in {subfolder_dir}") 193 | 194 | # iterate over each code 195 | for code, code_replaced, x, y, iodata, desc, file_name in tqdm(shuffled(zip( 196 | codes_raw, codes_nameReplaced, xs_raw, ys_raw, iodatas_obj, descriptions, file_names 197 | ))): 198 | augmented_codes = [] 199 | 200 | # we use different demo prompt every time randomly 201 | for _ in range(args.num_demo_prompts): 202 | # index of demo prompt 203 | random_idx_1 = np.random.randint(0, len(codes_raw)) 204 | random_code, random_desc = codes_raw[random_idx_1], descriptions[random_idx_1] 205 | # # each demo prompt has multiple codes, so we choose one of them at random 206 | random_idx_2 = np.random.randint(0, len(random_code)) 207 | random_code = random_code[random_idx_2] 208 | 209 | # generate starter codes based on 5 random codes of target prompt 210 | seen_starter_codes = [] 211 | 212 | for each_code in np.random.choice(code, min(len(code), args.num_random_starter_prompts), replace=False): 213 | # all code till the last occurrence of "input" 214 | last_line_containing_input = max( 215 | i 216 | for i, x in enumerate(each_line.find("input") > -1 for each_line in each_code.split("\n")) 217 | if x 218 | ) 219 | last_line_containing_input = min( 220 | last_line_containing_input, int(len(each_code.split("\n")) * 0.3) 221 | ) # at max keep 30% of code as starter 222 | starter_code = "\n".join(each_code.split("\n")[: last_line_containing_input + 1]) 223 | 224 | # if generated starter code is very similar to existing starter codes, then we ignore it 225 | similarity_score = max( 226 | [similar(each_seen_starter_code, starter_code) for each_seen_starter_code in seen_starter_codes] 227 | + [0] 228 | ) 229 | seen_starter_codes.append(starter_code) 230 | if similarity_score > 0.6: 231 | continue 232 | 233 | # print('\n predict new codes using GPT3 !!') 234 | augmented_codes += predict_code_one_shot(args, random_desc, random_code, desc, starter_code) 235 | 236 | # save new codes to disk 237 | instance_id = int( 238 | re.findall(r"codes_raw_0+(\d+).py", file_name)[0] 239 | ) # some regex trick to parse out instance name 240 | save_raw(output_dir, args.machine_name, instance_id, 241 | augmented_codes, '', '', '', 242 | x, y, iodatas_obj, '', 243 | '') 244 | 245 | 246 | if __name__ == "__main__": 247 | main() 248 | -------------------------------------------------------------------------------- /dataloaders/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/ChainCoder/a9ae00758fe9185c7f13f9b9dc591d05de0d4445/dataloaders/.DS_Store -------------------------------------------------------------------------------- 
/dataloaders/apps.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from argparse import ArgumentParser 4 | from glob import glob 5 | from torch.utils.data import DataLoader, Dataset 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | VERBOSE = 1 10 | 11 | from dataloaders.check_exec_match import run_input_print_code 12 | 13 | 14 | 15 | 16 | def convert_raws_to_objs(list_of_raw_str): 17 | res = [] 18 | for i, each in enumerate(list_of_raw_str): 19 | raw_lst = each.strip().split("\n") if isinstance(each, str) else each 20 | if data_type_is_not_str(raw_lst): 21 | obj = None 22 | else: 23 | obj = convert_lines_to_obj(raw_lst) 24 | res.append(obj) 25 | return res 26 | 27 | def try_convert_number(n): 28 | # input is a string. 29 | def number_is_int(n): 30 | if n[0] in ['+', '-']: 31 | return n[1:].isdigit() 32 | else: 33 | return n.isdigit() 34 | 35 | is_number = True 36 | try: 37 | num = float(n) 38 | # check "nan" 39 | is_number = (num == num) # nan should return False 40 | except ValueError: 41 | is_number = False 42 | 43 | 44 | if is_number: 45 | if number_is_int(n): 46 | obj = int(n) 47 | else: 48 | obj = float(n) 49 | else: 50 | obj = n 51 | return obj 52 | 53 | def data_type_is_not_str(lst_1D): 54 | if type(lst_1D) is not list: 55 | return True 56 | for s_space in lst_1D: 57 | if type(s_space) is not str: 58 | return True 59 | return False 60 | 61 | 62 | def convert_lines_to_obj(lst_1D): 63 | lst_obj = [] 64 | if type(lst_1D) is not list: 65 | raise ValueError 66 | for s_space in lst_1D: 67 | if type(s_space) is str: 68 | s_lst = s_space.split() 69 | for i, s in enumerate(s_lst): 70 | 71 | sobj = try_convert_number(s) 72 | 73 | s_lst[i] = sobj 74 | lst_obj.append(s_lst) 75 | else: 76 | raise ValueError 77 | lst_obj.append(s_space) 78 | return lst_obj 79 | 80 | 81 | 82 | class APPS(Dataset): 83 | def __init__(self, mode, difficulties, apps_data_root): 84 | """ 85 | Args: 86 | modes: train, test 87 | difficulty: introductory interview competition 88 | """ 89 | all_instances = glob(os.path.join(apps_data_root, mode, '**')) 90 | self.instances = list( 91 | filter(lambda i: json.load(open(os.path.join(i, 'metadata.json')))["difficulty"] in difficulties, all_instances) 92 | ) 93 | 94 | 95 | def __len__(self): 96 | return len(self.instances) 97 | 98 | def __getitem__(self, idx): 99 | try: 100 | codes = json.load(open(os.path.join(self.instances[idx], 'solutions.json'))) 101 | iodata = json.load(open(os.path.join(self.instances[idx], 'input_output.json'))) 102 | description = open(os.path.join(self.instances[idx], 'question.txt')).read().split("-----Input-----")[0].strip() 103 | meta = json.load(open(os.path.join(self.instances[idx], "metadata.json"))) 104 | except FileNotFoundError as e: 105 | print(f'APPS file not found: {str(e)}.') 106 | return None 107 | 108 | check_io_match_now = False 109 | if check_io_match_now: 110 | for each_code in codes: 111 | try: 112 | list_str_out = run_input_print_code(each_code, x_raw[0]) 113 | except: 114 | print('in apps dataloader, exec bug, discarded this SAMPLE.') 115 | continue 116 | 117 | if list_str_out!=y_raw: 118 | print('in apps dataloader, encountered wrong dataset label, discarded this SAMPLE.') 119 | continue 120 | 121 | # sometimes inputs and outputs are having an extra first dimension (equivalent of unsqueeze(0)) 122 | # below is basically a squeeze(0) operation 123 | if len(iodata["inputs"]) == 1 and isinstance(iodata["inputs"][0], list) and len(iodata["inputs"][0]) > 0: 
124 | iodata["inputs"] = iodata["inputs"][0] 125 | if len(iodata["outputs"]) == 1 and isinstance(iodata["outputs"][0], list) and len(iodata["outputs"][0]) > 0: 126 | iodata["outputs"] = iodata["outputs"][0] 127 | 128 | # sometimes xs_raw and ys_raw have string, and sometimes directly a list 129 | # so we have an if else in those lines below 130 | xs_raw = [[each.strip().split("\n") if isinstance(each, str) else each] for each in iodata["inputs"]] 131 | ys_raw = [[each.strip().split("\n") if isinstance(each, str) else each] for each in iodata["outputs"]] 132 | # if there's no group (x,y), below codes are equal to: x_objs = convert_raws_to_objs(iodata["outputs"]) then remove None jointly 133 | 134 | 135 | 136 | io_objs = [] 137 | for x_raw, y_raw in zip(xs_raw, ys_raw): 138 | assert len(x_raw)==1 139 | assert len(y_raw)==1 140 | 141 | x = x_raw[0] # 1-D list, elem is string, but contain space; space should be further splited. e.g.: ['2 3', 'abc 4', 'd'] 142 | y = y_raw[0] 143 | 144 | if data_type_is_not_str(x) or data_type_is_not_str(y): 145 | if VERBOSE: 146 | print(f'in APPS, non-standard iodata, dropped x, x, which is:\n{repr(x)}\n{repr(y)}') 147 | continue 148 | 149 | x_obj = convert_lines_to_obj(x) # supposed to be 2-D list, final shape: [each line, obj in line after split and eval] e.g.: [[2, 3], ['abc', 4], ['d']] 150 | y_obj = convert_lines_to_obj(y) 151 | 152 | io_objs.append([x_obj, y_obj]) 153 | 154 | pcodes = codes 155 | codes = [] 156 | for code in pcodes: 157 | if 'input(' in code: 158 | codes.append(code) 159 | 160 | 161 | if len(codes)==0 or len(io_objs)<=1: 162 | print('In APPS, valid regular data number too small, discard this problem.') 163 | return None 164 | 165 | 166 | samples = { 167 | "codes_raw": codes, 168 | "xs_raw": xs_raw, 169 | "ys_raw": ys_raw, 170 | "io_objs": io_objs, 171 | 172 | "description": description, 173 | "difficulty": meta['difficulty'], 174 | "info": {}, 175 | } 176 | 177 | try: 178 | assert len(samples['xs_raw'])==len(samples['ys_raw'])==len(samples["io_objs"]), (len(samples['xs_raw']), len(samples['ys_raw']), len(samples["io_objs"])) 179 | return samples 180 | except: 181 | print(f"in apps data, data len not equal: {len(samples['xs_raw']), len(samples['ys_raw']), len(samples['io_objs'])}") 182 | return None 183 | 184 | 185 | 186 | def get_apps_rawloader(mode, difficulties, apps_data_root): 187 | def process(batch): 188 | assert len(batch)==1 189 | return batch[0] 190 | 191 | dataset = APPS(mode, difficulties, apps_data_root) 192 | dataloader = DataLoader(dataset, batch_size=1, collate_fn=process) 193 | 194 | return dataloader 195 | 196 | 197 | 198 | 199 | def parse_args(): 200 | parser = ArgumentParser() 201 | parser.add_argument("--mode", default="test", choices=["train", "test"]) 202 | parser.add_argument("--difficulty", default="introductory interview competition average", nargs="+") 203 | args = parser.parse_args() 204 | 205 | return args 206 | 207 | 208 | def test_APPS_rawloader(): 209 | args = parse_args() 210 | 211 | args.difficulty = args.difficulty.split(" ") 212 | 213 | dataloader = get_apps_rawloader(args.mode, args.difficulty) 214 | print(len(dataloader)) 215 | for i in dataloader: 216 | print(i) 217 | 218 | total = 0 219 | for i in tqdm(dataloader): 220 | assert len(i["code_strings"]) == len(i["xs_raw"]) == len(i["ys_raw"]) 221 | total += len(i["code_strings"]) 222 | print(total) 223 | 224 | 225 | if __name__ == "__main__": 226 | test_APPS_rawloader() 227 | 
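To make the I/O conversion in apps.py concrete, here is a small worked example of the helpers defined above; the expected values follow the file's own inline comments in __getitem__. A minimal sketch, assuming the repository root is on PYTHONPATH (importing dataloaders.apps also pulls in the tokenizer package via check_exec_match).

# Worked example for the conversion helpers defined above.
from dataloaders.apps import convert_lines_to_obj, try_convert_number

try_convert_number("3")       # -> 3 (int)
try_convert_number("-2.5")    # -> -2.5 (float)
try_convert_number("abc")     # -> "abc" (left as a string)

# One raw test input: a 1-D list of lines, each line a space-separated string.
x_raw = ["2 3", "abc 4", "d"]
x_obj = convert_lines_to_obj(x_raw)   # -> [[2, 3], ["abc", 4], ["d"]]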
-------------------------------------------------------------------------------- /dataloaders/check_exec_match.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | import subprocess 4 | from typing import Union 5 | 6 | from time import time as timer 7 | 8 | from dataloaders.loader_utils import timeout, MyTimeoutError 9 | 10 | from tokenizer.tokenizerAPI import ( 11 | tokenizerAPI_IN2R, 12 | tokenizerAPI_IR2N, 13 | tokenizerAPI_ON2R, 14 | tokenizerAPI_OR2N, 15 | ) 16 | 17 | 18 | 19 | # 🟩 below raw conversion codes copied from apps.py 20 | def convert_raws_to_objs(list_of_raw_str): 21 | res = [] 22 | # bad_indices = [] 23 | for i, each in enumerate(list_of_raw_str): 24 | raw_lst = each.strip().split("\n") if isinstance(each, str) else each 25 | if data_type_is_not_str(raw_lst): 26 | # bad_indices.append(i) 27 | obj = None 28 | else: 29 | obj = convert_lines_to_obj(raw_lst) 30 | res.append(obj) 31 | return res 32 | 33 | 34 | def try_convert_number(n): 35 | # input is a string. 36 | def number_is_int(n): 37 | if n[0] in ['+', '-']: 38 | return n[1:].isdigit() 39 | else: 40 | return n.isdigit() 41 | 42 | is_number = True 43 | try: 44 | num = float(n) 45 | # check "nan" 46 | is_number = (num == num) # nan should return False 47 | except ValueError: 48 | is_number = False 49 | 50 | 51 | if is_number: 52 | if number_is_int(n): 53 | obj = int(n) 54 | else: 55 | obj = float(n) 56 | else: 57 | obj = n 58 | return obj 59 | 60 | def data_type_is_not_str(lst_1D): 61 | if type(lst_1D) is not list: 62 | return True 63 | # return [lst_1D] 64 | for s_space in lst_1D: 65 | if type(s_space) is not str: 66 | return True 67 | return False 68 | 69 | 70 | def convert_lines_to_obj(lst_1D): 71 | lst_obj = [] 72 | if type(lst_1D) is not list: 73 | raise ValueError 74 | # # return [lst_1D] 75 | for s_space in lst_1D: 76 | if type(s_space) is str: 77 | s_lst = s_space.split() 78 | for i, s in enumerate(s_lst): 79 | 80 | sobj = try_convert_number(s) 81 | 82 | s_lst[i] = sobj 83 | lst_obj.append(s_lst) 84 | else: 85 | raise ValueError 86 | lst_obj.append(s_space) 87 | return lst_obj 88 | 89 | 90 | 91 | convert_obj_back_to_raw = lambda lst2: '\n'.join([' '.join([str(x) for x in lst1]) for lst1 in lst2]) 92 | 93 | def convert_io_back_to_raw(io_obj): 94 | raw_x, raw_y = io_obj 95 | raw_x = convert_obj_back_to_raw(raw_x) 96 | raw_y = convert_obj_back_to_raw(raw_y) 97 | return raw_x, raw_y 98 | 99 | def check_io_match_one_sample_int(io_ns, code_ns, sanity_check_timeout): 100 | io_obj = tokenizerAPI_IN2R(io_ns) 101 | code = tokenizerAPI_ON2R(code_ns) 102 | is_match, exec_out, prt_str = check_io_match_one_sample_obj(io_obj, code, sanity_check_timeout) 103 | return is_match, exec_out, prt_str 104 | 105 | 106 | def check_io_match_one_sample_obj(io_obj, code, sanity_check_timeout, want_print = False): 107 | assert type(sanity_check_timeout) is int, 'must provide integer timelimit' 108 | 109 | try: 110 | # if 1: 111 | @timeout(sanity_check_timeout) 112 | def run_(): 113 | raw_x, raw_y = convert_io_back_to_raw(io_obj) 114 | exec_out, core_exec_time = run_input_print_code(code, raw_x) 115 | is_match = exec_out==raw_y 116 | return is_match, exec_out, core_exec_time 117 | is_match, exec_out, core_exec_time = run_() 118 | except MyTimeoutError: 119 | # else: 120 | is_match, exec_out, core_exec_time = False, RuntimeError(f'Timeout, limit is {sanity_check_timeout}s. 
NO error though.'), -1 121 | except: 122 | is_match, exec_out, core_exec_time = False, RuntimeError(f'Not timeout, other errors.'), -1 123 | 124 | 125 | 126 | prt_str = f'In check match: \nI/O:\n\t\t{repr(io_obj)}\nExec Result:\n\t\t{repr(exec_out)}\nIs Match:\n\t\t{repr(is_match)} \n core_exec_time:\n\t {core_exec_time}' 127 | if want_print: 128 | print(prt_str) 129 | return is_match, exec_out, prt_str 130 | 131 | 132 | def check_match_one_instance(instance_io, instance_code, sanity_check_timeout): 133 | import random 134 | io_inp = random.choice(instance_io) 135 | code_inp = random.choice(instance_code) 136 | is_match, exec_out, io_obj, code = check_io_match_one_sample_int(io_inp, code_inp, sanity_check_timeout) 137 | return 138 | 139 | def forward_run_code(raw_x, code, timelimit): 140 | try: 141 | @timeout(timelimit) 142 | def forward(): 143 | return run_input_print_code(code, raw_x) 144 | exec_out = forward() 145 | except: 146 | exec_out = RuntimeError('Program + idata execution Timeout.') 147 | return exec_out 148 | 149 | 150 | def run_input_print_code(code: str, idata: str, debug: bool = False) -> Union[str, int]: 151 | """ 152 | Runs code and returns output. 153 | If output is string, code executed properly. 154 | If output is -1 (int), code failed somewhere. 155 | :param code: String of code. 156 | :param idata: String of input to script. Use \n for new line, not list!!! 157 | :param debug: Bool flag to print error. 158 | :return: Output of code (or) -1 if failed. 159 | """ 160 | rnd = "".join(random.choices(string.ascii_letters + string.digits, k=16)) 161 | tmp_file = f"/tmp/input_print_{rnd}.py" 162 | open(tmp_file, "w").write(code) 163 | t0 = timer() 164 | result = subprocess.run(["python", tmp_file], input=idata, capture_output=True, text=True) 165 | core_exec_time = timer() - t0 166 | subprocess.run(["rm", tmp_file]) 167 | 168 | if len(result.stderr): 169 | if debug: 170 | print(result.stderr) 171 | # return RuntimeError('Program execution yield error; did NOT timeout.'), core_exec_time 172 | return RuntimeError(f'Program execution procedure error; did NOT timeout. Full error message: \n\n{result.stderr}'), core_exec_time 173 | else: 174 | return result.stdout.strip(), core_exec_time 175 | 176 | -------------------------------------------------------------------------------- /dataloaders/code_contests.py: -------------------------------------------------------------------------------- 1 | import os 2 | from joblib import Parallel, delayed 3 | 4 | from datasets import load_dataset 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | 8 | 9 | 10 | 11 | 12 | 13 | def try_convert_number(n): 14 | # input is a string. 
15 | def number_is_int(n): 16 | if n[0] in ['+', '-']: 17 | return n[1:].isdigit() 18 | else: 19 | return n.isdigit() 20 | 21 | is_number = True 22 | try: 23 | num = float(n) 24 | # check "nan" 25 | is_number = (num == num) # nan should return False 26 | except ValueError: 27 | is_number = False 28 | 29 | 30 | if is_number: 31 | if number_is_int(n): 32 | obj = int(n) 33 | else: 34 | obj = float(n) 35 | else: 36 | obj = n 37 | return obj 38 | 39 | def data_type_is_not_str(lst_1D): 40 | if type(lst_1D) is not list: 41 | return True 42 | # return [lst_1D] 43 | for s_space in lst_1D: 44 | if type(s_space) is not str: 45 | return True 46 | return False 47 | 48 | 49 | def convert_lines_to_obj(lst_1D): 50 | lst_obj = [] 51 | if type(lst_1D) is not list: 52 | raise ValueError 53 | # # return [lst_1D] 54 | for s_space in lst_1D: 55 | if type(s_space) is str: 56 | s_lst = s_space.split() 57 | for i, s in enumerate(s_lst): 58 | 59 | sobj = try_convert_number(s) 60 | 61 | s_lst[i] = sobj 62 | lst_obj.append(s_lst) 63 | else: 64 | raise ValueError 65 | lst_obj.append(s_space) 66 | return lst_obj 67 | 68 | 69 | def get_contest_rawloader(split="train", datapath="~/.cache/huggingface/datasets"): 70 | dataset = load_dataset("deepmind/code_contests", split=split, cache_dir=datapath) 71 | 72 | def preprocess(batch): 73 | assert len(batch)==1 74 | each = batch[0] 75 | 76 | samples = dict() 77 | codes = [soln for lang, soln in zip(each["solutions"]["language"], each["solutions"]["solution"]) if lang == 3] 78 | 79 | xs_raw = [eachh.strip().split("\n") for eachh in each["public_tests"]["input"] + each["generated_tests"]["input"]] 80 | ys_raw = [eachh.strip().split("\n") for eachh in each["public_tests"]["output"] + each["generated_tests"]["output"]] 81 | description = each["description"] 82 | 83 | 84 | io_objs = [] 85 | for x_raw, y_raw in zip(xs_raw, ys_raw): 86 | x = x_raw # 1-D list, elem is string, but contain space; space should be further splited. 
e.g.: ['2 3', 'abc 4', 'd'] 87 | y = y_raw 88 | 89 | if data_type_is_not_str(x) or data_type_is_not_str(y): 90 | # print('in dm contest, non-standard iodata, dropped.') 91 | continue 92 | 93 | 94 | x_obj = convert_lines_to_obj(x) # supposed to be 2-D list, final shape: [each line, obj in line after split and eval] e.g.: [[2, 3], ['abc', 4], ['d']] 95 | y_obj = convert_lines_to_obj(y) 96 | 97 | 98 | io_objs.append([x_obj, y_obj]) 99 | 100 | pcodes = codes 101 | codes = [] 102 | for code in pcodes: 103 | if 'input(' in code: 104 | codes.append(code) 105 | 106 | 107 | if len(codes)==0 or len(io_objs)<=1: 108 | print('In dm contest, valid regular data number too small, discard this problem.') 109 | return None 110 | 111 | 112 | samples = { 113 | "codes_raw": codes, 114 | "xs_raw": xs_raw, 115 | "ys_raw": ys_raw, 116 | "io_objs": io_objs, 117 | 118 | "description": description, 119 | "difficulty": 'dm_code_contest', 120 | "info": {}, 121 | } 122 | 123 | try: 124 | assert len(samples['xs_raw'])==len(samples['ys_raw'])==len(samples["io_objs"]), (len(samples['xs_raw']), len(samples['ys_raw']), len(samples["io_objs"])) 125 | return samples 126 | except: 127 | print(f"in code contest loader, data len not equal: {len(samples['xs_raw']), len(samples['ys_raw']), len(samples['io_objs'])}") 128 | return None 129 | 130 | 131 | 132 | 133 | 134 | dataloader = DataLoader(dataset, batch_size=1, collate_fn=preprocess) 135 | return dataloader 136 | 137 | 138 | def save_codes_dump(): 139 | dataset = load_dataset("deepmind/code_contests", split="train") 140 | os.makedirs("codes_dump", exist_ok=True) 141 | for each_datapoint in dataset: 142 | python_codes = [soln for lang, soln in zip(each_datapoint["solutions"]["language"], each_datapoint["solutions"]["solution"]) if lang == 3] 143 | for each_code in python_codes: 144 | i = 0 145 | while os.path.exists(f"codes_dump/sample-{i:04d}.py"): 146 | i += 1 147 | if i > 9999: 148 | return 149 | open(f"codes_dump/sample-{i:04d}.py", "w").write(each_code) 150 | 151 | -------------------------------------------------------------------------------- /dataloaders/loader_utils.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import numpy as np 3 | import os 4 | 5 | import errno 6 | import signal 7 | from functools import wraps, partial 8 | import math 9 | import time 10 | 11 | 12 | def load_txt(fname, as_str=True): 13 | x = [] 14 | with open(fname, 'r') as f: 15 | for line in f.readlines(): 16 | if line[-1]=='\n': 17 | x.append(line[:-1]) 18 | else: 19 | x.append(line) 20 | if as_str: 21 | x = '\n'.join(x) 22 | return x 23 | 24 | 25 | def load_all_instances(subfolder, file_id_re='*', shuffle=True): 26 | name_to_int = lambda i: int(i.split("_")[-1].split('.')[0]) 27 | 28 | if len(glob(os.path.join(subfolder, f'*codes_nameReplaced*{file_id_re}.py')))==0: 29 | return [[] for _ in range(7)] 30 | 31 | def glob_sorted(restr): 32 | restr = os.path.join(subfolder, restr) 33 | files = glob(restr) 34 | indices = list(map(name_to_int, files)) 35 | 36 | files = sorted(zip(files, indices), key=lambda x:x[1]) 37 | return list(zip(*files)) 38 | codes_nameReplaced, ind_cre = glob_sorted(f'*codes_nameReplaced*{file_id_re}.py') 39 | codes_raw, ind_crw = glob_sorted(f'*codes_raw*{file_id_re}.py') 40 | xs_raw, ind_xr = glob_sorted(f'*xs_raw*{file_id_re}.py') 41 | ys_raw, ind_yr = glob_sorted(f'*ys_raw*{file_id_re}.py') 42 | iodatas_obj, ind_o = glob_sorted(f'*io_objs*{file_id_re}.py') 43 | descriptions, ind_d = 
glob_sorted(f'*description*{file_id_re}.txt') 44 | file_names, ind_fn = glob_sorted(f"*codes_raw*{file_id_re}.py") 45 | 46 | try: 47 | assert len(codes_raw)==len(codes_nameReplaced)==len(xs_raw)==len(ys_raw)==len(iodatas_obj)==len(descriptions)==len(file_names) 48 | except: 49 | print('file missing due to scp or generation! num of instances of each file types = ', (len(codes_raw),len(codes_nameReplaced),len(xs_raw),len(ys_raw),len(iodatas_obj),len(descriptions), len(file_names)), 'Now dropped the missing instance.') 50 | def drop_missing(): 51 | 52 | _codes_raw, _codes_nameReplaced, _xs_raw, _ys_raw, _iodatas_obj, _descriptions, _fnames = [[] for _ in range(7)] 53 | x2i_cre = {x:i for i, x in enumerate(ind_cre)} 54 | x2i_crw = {x:i for i, x in enumerate(ind_crw)} 55 | x2i_xr = {x:i for i, x in enumerate(ind_xr)} 56 | x2i_yr = {x:i for i, x in enumerate(ind_yr)} 57 | x2i_o = {x:i for i, x in enumerate(ind_o)} 58 | x2i_d = {x:i for i, x in enumerate(ind_d)} 59 | x2i_fn = {x:i for i, x in enumerate(ind_fn)} 60 | set_ind_cre, set_ind_crw, set_ind_xr, set_ind_yr, set_ind_o, set_ind_d, set_int_fn = set(ind_cre), set(ind_crw), set(ind_xr), set(ind_yr), set(ind_o), set(ind_d), set(ind_fn) 61 | inters = set_ind_cre.intersection(set_ind_crw).intersection(set_ind_xr).intersection(set_ind_yr).intersection(set_ind_o).intersection(set_ind_d).intersection(set_int_fn) 62 | 63 | for x in inters: 64 | _codes_raw.append(codes_raw[x2i_cre[x]]) 65 | _codes_nameReplaced.append(codes_nameReplaced[x2i_crw[x]]) 66 | _xs_raw.append(xs_raw[x2i_xr[x]]) 67 | _ys_raw.append(ys_raw[x2i_yr[x]]) 68 | _iodatas_obj.append(iodatas_obj[x2i_o[x]]) 69 | _descriptions.append(descriptions[x2i_d[x]]) 70 | _fnames.append(file_names[x2i_fn[x]]) 71 | 72 | return _codes_raw, _codes_nameReplaced, _xs_raw, _ys_raw, _iodatas_obj, _descriptions, _fnames 73 | codes_raw, codes_nameReplaced, xs_raw, ys_raw, iodatas_obj, descriptions, file_names = drop_missing() 74 | name_to_int_ls = lambda ls: [name_to_int(x) for x in ls] 75 | assert set(name_to_int_ls(codes_raw)) == set(name_to_int_ls(codes_nameReplaced)) == set(name_to_int_ls(xs_raw)) == set(name_to_int_ls(ys_raw)) == set(name_to_int_ls(iodatas_obj)) == set(name_to_int_ls(descriptions)) == set(name_to_int_ls(file_names)) 76 | 77 | 78 | 79 | if len(codes_raw)==0: 80 | print("empty samples") 81 | return [[] for _ in range(6)] 82 | 83 | 84 | # 🟩 load files 85 | batch_load = lambda flst: [load_txt(f) for f in flst] 86 | codes_raw, codes_nameReplaced, xs_raw, ys_raw, iodatas_obj, descriptions = batch_load(codes_raw), batch_load(codes_nameReplaced), batch_load(xs_raw), batch_load(ys_raw), batch_load(iodatas_obj), batch_load(descriptions) 87 | 88 | # 🟩 batch eval 89 | def replace_too_large_int_with_inf(io_strings): 90 | new_strings = [] 91 | max_digits = 4300 92 | for string in io_strings: 93 | res = '' 94 | prev_group = '' 95 | for s in string: 96 | if not s.isdigit(): 97 | if len(prev_group)>=max_digits: 98 | prev_group = 'inf' 99 | res += prev_group 100 | res += s 101 | prev_group = '' 102 | else: 103 | prev_group += s 104 | if len(prev_group)>=max_digits: 105 | res += 'inf' 106 | else: 107 | res += prev_group 108 | new_strings.append(res) 109 | return new_strings 110 | iodatas_obj = replace_too_large_int_with_inf(iodatas_obj) 111 | 112 | 113 | 114 | batch_eval = lambda objlst: [evalio(x) for x in objlst] 115 | codes_raw, codes_nameReplaced, xs_raw, ys_raw, iodatas_obj = batch_eval(codes_raw), batch_eval(codes_nameReplaced), batch_eval(xs_raw), batch_eval(ys_raw), batch_eval(iodatas_obj) 116 | 
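# Editor's note (assumption): the max_digits = 4300 cutoff in replace_too_large_int_with_inf above
# presumably mirrors CPython's default int<->str conversion limit (sys.get_int_max_str_digits()
# returns 4300 on Python 3.11+), so that eval()-ing the serialized I/O strings via evalio
# (defined further below) does not raise ValueError on pathologically long integer literals;
# such literals are replaced with 'inf' before evaluation instead.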
117 | 118 | 119 | 120 | # 🟩 shuffle jointly 121 | if shuffle: 122 | perm = np.random.permutation(len(codes_raw)) 123 | shuffle = lambda lst: [lst[i] for i in perm] 124 | codes_raw, codes_nameReplaced, xs_raw, ys_raw, iodatas_obj, descriptions, file_names = shuffle(codes_raw), shuffle(codes_nameReplaced), shuffle(xs_raw), shuffle(ys_raw), shuffle(iodatas_obj), shuffle(descriptions), shuffle(file_names) 125 | 126 | 127 | return codes_raw, codes_nameReplaced, xs_raw, ys_raw, iodatas_obj, descriptions, file_names 128 | 129 | def evalio(x): 130 | return eval(x, {'inf': float('inf'), 'nan': float('nan')}) 131 | 132 | 133 | def parse_loadername_from_filename(filename): 134 | basename = os.path.basename(filename) 135 | loadername = '_'.join(basename.split('_')[:2]) 136 | inst_id_orig = basename.split('_')[-1][:-3] 137 | return loadername, inst_id_orig 138 | 139 | def save_raw(_dir, which_loader, instance_id, 140 | codes_raw, codes_nameReplaced, codes_readable_raw, codes_readable_nameReplaced, 141 | xs_raw, ys_raw, io_objs, iodatas_readable, 142 | description): 143 | 144 | os.makedirs(_dir, exist_ok=True) 145 | if type(instance_id) is int: 146 | instance_id = f'{instance_id:06d}' 147 | print(codes_raw, file=open(os.path.join(_dir, f'{which_loader}_codes_raw_{instance_id}.py'), 'w')) 148 | print(codes_nameReplaced, file=open(os.path.join(_dir, f'{which_loader}_codes_nameReplaced_{instance_id}.py'), 'w')) 149 | print(xs_raw, file=open(os.path.join(_dir, f'{which_loader}_xs_raw_{instance_id}.py'), 'w')) 150 | print(ys_raw, file=open(os.path.join(_dir, f'{which_loader}_ys_raw_{instance_id}.py'), 'w')) 151 | print(io_objs, file=open(os.path.join(_dir, f'{which_loader}_io_objs_{instance_id}.py'), 'w')) 152 | print(description, file=open(os.path.join(_dir, f'{which_loader}_description_{instance_id}.txt'), 'w')) 153 | 154 | os.makedirs(os.path.join(_dir, 'readable'), exist_ok=True) 155 | print(codes_readable_raw, file=open(os.path.join(_dir, 'readable', f'{which_loader}_codes_readable_raw_{instance_id}.py'), 'w')) 156 | print(codes_readable_nameReplaced, file=open(os.path.join(_dir, 'readable', f'{which_loader}_codes_readable_nameReplaced_{instance_id}.py'), 'w')) 157 | print(iodatas_readable, file=open(os.path.join(_dir, 'readable', f'{which_loader}_iodatas_readable_{instance_id}.py'), 'w')) 158 | print(description, file=open(os.path.join(_dir, 'readable', f'{which_loader}_description_{instance_id}.txt'), 'w')) 159 | 160 | 161 | 162 | 163 | 164 | def shuffled(iterable): 165 | lst = list(iterable) 166 | np.random.shuffle(lst) 167 | return lst 168 | 169 | class MyTimeoutError(BaseException): 170 | pass 171 | def timeout(seconds=10, error_message=os.strerror(errno.ETIME)): 172 | def decorator(func): 173 | def _handle_timeout(repeat_id, signum, frame): 174 | signal.signal(signal.SIGALRM, partial(_handle_timeout, repeat_id + 1)) 175 | signal.alarm(seconds) 176 | raise MyTimeoutError(error_message) 177 | 178 | def wrapper(*args, **kwargs): 179 | old_signal = signal.signal(signal.SIGALRM, partial(_handle_timeout, 0)) 180 | old_time_left = signal.alarm(seconds) 181 | assert type(old_time_left) is int and old_time_left >= 0 182 | if 0 < old_time_left < seconds: # do not exceed previous timer 183 | signal.alarm(old_time_left) 184 | start_time = time.time() 185 | try: 186 | result = func(*args, **kwargs) 187 | finally: 188 | if old_time_left == 0: 189 | signal.alarm(0) 190 | else: 191 | sub = time.time() - start_time 192 | signal.signal(signal.SIGALRM, old_signal) 193 | signal.alarm(max(0, math.ceil(old_time_left - 
sub))) 194 | return result 195 | 196 | return wraps(func)(wrapper) 197 | 198 | return decorator 199 | -------------------------------------------------------------------------------- /dataloaders/sttd.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import pickle 4 | from glob import glob 5 | import numpy as np 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | VERBOSE = 0 9 | 10 | class STTTDDataset(Dataset): 11 | def __init__(self, params, pickle_root): 12 | if params.only_do_subfolder != '': 13 | subfolders = [params.only_do_subfolder] 14 | else: 15 | subfolders = ['difficulty_introductory', 'difficulty_interview', 'difficulty_dm_code_contest'] 16 | load_pkl_from = [os.path.join(pickle_root, x, '*.pkl') for x in subfolders] 17 | 18 | print(f'\n\nIn STTD dataloader, loading from: {load_pkl_from}') 19 | 20 | all_existing_files = list(itertools.chain.from_iterable([glob(x) for x in load_pkl_from])) 21 | if params.most_recent_pickle_num == 'all': 22 | self.files = all_existing_files 23 | print(f'When init dataloader, grabbing all {len(all_existing_files)} existing pickles.') 24 | else: 25 | most_recent_pickle_num = int(params.most_recent_pickle_num) 26 | self.files = sorted(all_existing_files, key=os.path.getmtime)[-most_recent_pickle_num:] 27 | print(f'When init dataloader, grabbing {most_recent_pickle_num} most recent pickles out of {len(all_existing_files)}.') 28 | assert len(self.files)>0, f'dataset empty: {load_pkl_from}' 29 | self.max_nlp_seq_len = 512 30 | 31 | self.max_io_seq_len, self.max_code_seq_len = params.max_io_seq_len, params.max_code_seq_len 32 | self.batch_size = params.batch_size 33 | self.samples_per_instance_io = params.samples_per_instance_io 34 | self.samples_per_instance_code = params.samples_per_instance_code 35 | self.samples_per_instance_io_hold = params.samples_per_instance_io_hold 36 | 37 | demo = pickle.load(open(self.files[0], "rb")) 38 | [valid_iodatas_int_is, valid_codes_int_is, valid_desc_distilbert, valid_desc_bert] = demo 39 | self.instances_per_file = len(valid_iodatas_int_is) 40 | 41 | self.inter_file_order = None 42 | self.intra_file_order = None 43 | self.inter_index = 0 44 | self.intra_index = 0 45 | self.inter_flag = 0 46 | self.inter_flag = 0 47 | self.file_in_memory = False 48 | self.batch_sampler() 49 | 50 | def __len__(self): 51 | return len(self.files) * (self.instances_per_file // self.batch_size) 52 | 53 | def batch_sampler(self): 54 | self.inter_file_order = np.random.permutation(np.arange(len(self.files))) 55 | assert self.instances_per_file / self.batch_size>0, (self.instances_per_file , self.batch_size) 56 | self.intra_file_order = [ 57 | np.random.permutation(np.arange(int(self.instances_per_file / self.batch_size))) 58 | for _ in range(len(self.files)) 59 | ] 60 | self.inter_index = 0 61 | self.intra_index = 0 62 | self.file_in_memory = False 63 | 64 | 65 | def __getitem__(self, item): 66 | if not self.file_in_memory: 67 | self.file_in_memory = pickle.load(open(self.files[self.inter_file_order[self.inter_index]], "rb")) 68 | self.file_in_memory[0], bad_insts1 = drop_token_len_exceeds(self.file_in_memory[0], self.max_io_seq_len) 69 | self.file_in_memory[1], bad_insts2 = drop_token_len_exceeds(self.file_in_memory[1], self.max_code_seq_len) 70 | bad_insts = set(bad_insts1+bad_insts2) 71 | if len(bad_insts)>0: 72 | for itp in range(4): # if bad, pad ... 
73 | def drop_certain_indices(lst, indices): 74 | new = [] 75 | for i, x in enumerate(lst): 76 | if i not in indices: 77 | new.append(x) 78 | return new 79 | self.file_in_memory[itp] = drop_certain_indices(self.file_in_memory[itp], bad_insts) 80 | if VERBOSE: 81 | print(f'Dataloader Find {len(bad_insts)} bad_insts (tok len too long) out of {self.instances_per_file}... padded with duplication.') 82 | good_ists = np.random.choice(len(self.file_in_memory[0]), len(bad_insts)) 83 | for itp in range(4): # if bad, pad ... 84 | self.file_in_memory[itp].extend([self.file_in_memory[itp][ig] for ig in good_ists]) 85 | 86 | 87 | idx = self.intra_file_order[self.inter_file_order[self.inter_index]][self.intra_index] 88 | batch_content = [self.file_in_memory[type_dir][idx : idx + self.batch_size] for type_dir in range(4)] 89 | batch_content = list(zip(*batch_content)) 90 | 91 | data = dict() 92 | 93 | ioData_ist2_all = list(map(lambda i: list(map(lambda ii: ii, i[0])), batch_content)) 94 | codes_ist2_in_file = list(map(lambda i: i[1], batch_content)) 95 | 96 | def select_k_samples(ist2, release_hold): 97 | ist2_rel, ist2_hold = [], [] 98 | 99 | for i, st2 in enumerate(ist2): 100 | num_samp = len(st2) 101 | if num_samp >= sum(release_hold): 102 | replace = False 103 | else: 104 | if VERBOSE: 105 | print(f'In dataloader, requested too many samples for I/O: samples_per_instance_io + samples_per_instance_io_hold > total samples: sum({release_hold}) > {num_samp}. Returned samples now have duplications!') 106 | replace = True 107 | idsall = np.random.choice(num_samp, sum(release_hold), replace=replace) 108 | ids_rel, ids_hold = idsall[:release_hold[0]], idsall[release_hold[0]:]#idsall[-release_hold[1]:] 109 | ist2_rel.append([st2[j] for j in ids_rel]) 110 | ist2_hold.append([st2[j] for j in ids_hold]) 111 | return ist2_rel, ist2_hold 112 | 113 | data["ioData_ist2"], data["ioData_ist2_holdout"] = select_k_samples(ioData_ist2_all, [self.samples_per_instance_io, self.samples_per_instance_io_hold]) 114 | 115 | data["desc_distilbert_it"] = list(map(lambda i: i[2]['input_ids'], batch_content)) 116 | data["desc_bert_it"] = list(map(lambda i: i[3]['input_ids'], batch_content)) 117 | data["desc_distilbert_it"] = [x[:self.max_nlp_seq_len] for x in data["desc_distilbert_it"]] 118 | data["desc_bert_it"] = [x[:self.max_nlp_seq_len] for x in data["desc_bert_it"]] 119 | 120 | 121 | codes_ist2 = [] 122 | for inst_id in range(self.batch_size): 123 | codes_ist2.append(select_n_from_group(codes_ist2_in_file[inst_id], self.samples_per_instance_code, 'code-sample')) 124 | 125 | data["program_sit2"] = list(zip(*codes_ist2)) 126 | 127 | 128 | self.intra_index += 1 129 | if self.intra_index == int(self.instances_per_file / self.batch_size): 130 | self.file_in_memory = None 131 | self.intra_index = 0 132 | self.inter_index += 1 133 | if self.inter_index == len(self.files): 134 | self.batch_sampler() 135 | 136 | return data 137 | 138 | def drop_token_len_exceeds(ist2, max_len): 139 | new = [] 140 | for st2 in ist2: 141 | new.append(list(filter(lambda t2: len(t2)<=max_len, st2))) 142 | bad_insts = np.where([len(x)==0 for x in new])[0].tolist() 143 | return new, bad_insts 144 | 145 | 146 | def get_ChainCoder_dataloader(params, pickle_root, return_dataset=False): 147 | 148 | def collate_fn(batch): 149 | assert len(batch)==1 150 | newbatch = batch[0] 151 | if newbatch==None: 152 | return None 153 | coarse = lambda lst: [(lst[:1] if len(lst) >= 1 else lst) + [lst[i] for i in range(len(lst)) if (i+1) % 500 == 0]] + lst # coarse subsequence 
inserted to index-0 of original [[S3, S4], ...] list. 154 | newbatch['program_sit2'] = [list(map(coarse, inner_lst)) for inner_lst in newbatch['program_sit2']] 155 | return newbatch 156 | 157 | dataset = STTTDDataset(params, pickle_root) 158 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn) 159 | if return_dataset: 160 | return dataloader, dataset 161 | else: 162 | return dataloader 163 | 164 | def pretrain_loader(params): 165 | from datasets import load_dataset 166 | dataset = load_dataset("codeparrot/codeparrot-clean", split='train') 167 | from tokenizer.tokenizerAPI import tokenizerAPI_OR2N 168 | 169 | def collate_fn(batch): 170 | 171 | ioData_ist2 = [] 172 | desc_it = [] 173 | 174 | 175 | program_it2 = [] 176 | for data in batch: 177 | code_str = data['content'] 178 | print(data['path']) 179 | desc_it.append('') 180 | 181 | code_int_t2 = tokenizerAPI_OR2N(code_str) 182 | if code_int_t2==[]: 183 | return None 184 | program_it2.append(code_int_t2) 185 | 186 | program_sit2 = [program_it2] 187 | 188 | return { 189 | 'ioData_ist2': ioData_ist2, 190 | 'program_sit2': program_sit2, 191 | 'desc_bert_it': desc_it, 192 | 'desc_distilbert_it': desc_it, 193 | } 194 | 195 | 196 | dataloader = DataLoader(dataset, batch_size=params.batch_size, shuffle=True, collate_fn=collate_fn) 197 | return dataloader 198 | 199 | 200 | 201 | def select_n_from_group(group, n, msg=''): 202 | if len(group) >= n: 203 | replace = False 204 | else: 205 | if VERBOSE: 206 | print(f'In ASTer dataloader - {msg}, requested > existing: {n} > {len(group)}. Returned samples now have duplications!') 207 | replace = True 208 | indices = np.random.choice(len(group), n, replace=replace) 209 | selected = [group[i] for i in indices] 210 | return selected 211 | 212 | -------------------------------------------------------------------------------- /dataset_examples/apps_test_codes_readable_nameReplaced_007810.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def func_0(): 4 | return list(map(int, input().split())) 5 | 6 | def func_1(var_in_0, var_in_1=1): 7 | (var_0, var_1, var_2) = ([var_in_0], ([0] * var_3), (['+'] * var_4)) 8 | var_1[var_in_0] = 1 9 | while var_0: 10 | var_5 = var_0.pop() 11 | for (var_6, var_7) in var_8[var_5]: 12 | if (var_1[var_6] == 0): 13 | if ((var_7 * var_in_1) < 0): 14 | var_2[(abs(var_7) - 1)] = '-' 15 | elif ((var_7 * var_in_1) > 0): 16 | var_2[(abs(var_7) - 1)] = '+' 17 | if ((var_in_1 == 1) or (var_7 == 0)): 18 | var_0.append(var_6) 19 | var_1[var_6] = 1 20 | return (''.join(var_2), sum(var_1)) 21 | (var_3, var_4, var_0) = func_0() 22 | var_8 = [[] for var_5 in range(var_3)] 23 | var_7 = 1 24 | for var_9 in range(var_4): 25 | (var_10, var_in_0, var_11) = func_0() 26 | (var_in_0, var_11) = ((var_in_0 - 1), (var_11 - 1)) 27 | if (var_10 == 1): 28 | var_8[var_in_0].append((var_11, 0)) 29 | else: 30 | var_8[var_in_0].append((var_11, var_7)) 31 | var_8[var_11].append((var_in_0, (- var_7))) 32 | var_7 += 1 33 | var_4 = (var_7 - 1) 34 | (var_in_0, var_11) = func_1((var_0 - 1), 1) 35 | print(var_11) 36 | print(var_in_0) 37 | (var_in_0, var_11) = func_1((var_0 - 1), (- 1)) 38 | print(var_11) 39 | print(var_in_0) 40 | 41 | 42 | 43 | # 🟨 🟨 🟨 🟨 44 | 45 | 46 | import sys 47 | input = sys.stdin.readline 48 | 49 | def func_0(): 50 | return list(map(int, input().split())) 51 | 52 | def func_1(var_in_0, var_in_1=1): 53 | (var_0, var_1, var_2) = ([var_in_0], ([0] * var_3), (['+'] * var_4)) 54 | var_1[var_in_0] = 1 55 | while var_0: 56 | 
var_5 = var_0.pop() 57 | for (var_6, var_7) in var_8[var_5]: 58 | if (var_1[var_6] == 0): 59 | if ((var_7 * var_in_1) < 0): 60 | var_2[(abs(var_7) - 1)] = '-' 61 | elif ((var_7 * var_in_1) > 0): 62 | var_2[(abs(var_7) - 1)] = '+' 63 | if ((var_in_1 == 1) or (var_7 == 0)): 64 | var_0.append(var_6) 65 | var_1[var_6] = 1 66 | return (''.join(var_2), sum(var_1)) 67 | (var_3, var_4, var_0) = func_0() 68 | var_8 = [[] for var_5 in range(var_3)] 69 | var_7 = 1 70 | for var_9 in range(var_4): 71 | (var_10, var_in_0, var_11) = func_0() 72 | (var_in_0, var_11) = ((var_in_0 - 1), (var_11 - 1)) 73 | if (var_10 == 1): 74 | var_8[var_in_0].append((var_11, 0)) 75 | else: 76 | var_8[var_in_0].append((var_11, var_7)) 77 | var_8[var_11].append((var_in_0, (- var_7))) 78 | var_7 += 1 79 | var_4 = (var_7 - 1) 80 | (var_in_0, var_11) = func_1((var_0 - 1), 1) 81 | print(var_11) 82 | print(var_in_0) 83 | (var_in_0, var_11) = func_1((var_0 - 1), (- 1)) 84 | print(var_11) 85 | print(var_in_0) 86 | 87 | 88 | 89 | # 🟨 🟨 🟨 🟨 90 | 91 | 92 | import sys 93 | input = sys.stdin.readline 94 | 95 | def func_0(): 96 | return list(map(int, input().split())) 97 | 98 | def func_1(var_in_0): 99 | var_0 = [var_in_0] 100 | var_1 = ([0] * var_2) 101 | var_3 = (['+'] * var_4) 102 | var_1[var_in_0] = 1 103 | while var_0: 104 | var_5 = var_0.pop() 105 | for (var_6, var_7) in var_8[var_5]: 106 | if (var_1[var_6] == 0): 107 | if (var_7 < 0): 108 | var_3[((- var_7) - 1)] = '-' 109 | elif (var_7 > 0): 110 | var_3[(var_7 - 1)] = '+' 111 | var_0.append(var_6) 112 | var_1[var_6] = 1 113 | return (''.join(var_3), sum(var_1)) 114 | 115 | def func_2(var_in_0): 116 | var_0 = [var_in_0] 117 | var_1 = ([0] * var_2) 118 | var_3 = (['+'] * var_4) 119 | var_1[var_in_0] = 1 120 | while var_0: 121 | var_5 = var_0.pop() 122 | for (var_6, var_7) in var_8[var_5]: 123 | if (var_1[var_6] == 0): 124 | if (var_7 < 0): 125 | var_3[((- var_7) - 1)] = '+' 126 | elif (var_7 > 0): 127 | var_3[(var_7 - 1)] = '-' 128 | if (var_7 == 0): 129 | var_0.append(var_6) 130 | var_1[var_6] = 1 131 | return (''.join(var_3), sum(var_1)) 132 | (var_2, var_4, var_0) = func_0() 133 | var_8 = [[] for var_5 in range(var_2)] 134 | var_7 = 1 135 | for var_9 in range(var_4): 136 | (var_10, var_in_0, var_11) = func_0() 137 | (var_in_0, var_11) = ((var_in_0 - 1), (var_11 - 1)) 138 | if (var_10 == 1): 139 | var_8[var_in_0].append((var_11, 0)) 140 | else: 141 | var_8[var_in_0].append((var_11, var_7)) 142 | var_8[var_11].append((var_in_0, (- var_7))) 143 | var_7 += 1 144 | var_4 = (var_7 - 1) 145 | (var_in_0, var_11) = func_1((var_0 - 1)) 146 | print(var_11) 147 | print(var_in_0) 148 | (var_in_0, var_11) = func_2((var_0 - 1)) 149 | print(var_11) 150 | print(var_in_0) 151 | 152 | 153 | -------------------------------------------------------------------------------- /dataset_examples/apps_test_codes_readable_nameReplaced_007819.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | input = sys.stdin.readline 4 | var_0 = int(input()) 5 | var_1 = [list(map(int, input().split())) for var_2 in range(var_0)] 6 | var_3 = [var_1[var_2][0] for var_2 in range(var_0)] 7 | var_4 = [var_1[var_2][1] for var_2 in range(var_0)] 8 | var_3.sort() 9 | var_4.sort() 10 | var_5 = 0 11 | for var_2 in range(var_0): 12 | var_5 += max(var_3[var_2], var_4[var_2]) 13 | print((var_5 + var_0)) 14 | 15 | 16 | 17 | # 🟨 🟨 🟨 🟨 18 | 19 | 20 | 21 | def func_0(): 22 | n = int(input()) 23 | var_0 = input() 24 | var_1 = 0 25 | for var_2 in var_0: 26 | if (var_2 == '8'): 27 | 
var_1 += 1 28 | print(min(var_1, (n // 11))) 29 | 30 | def func_1(): 31 | 32 | def get(n): 33 | var_3 = 0 34 | while (n > 0): 35 | var_3 += (n % 10) 36 | n //= 10 37 | return var_3 38 | n = int(input()) 39 | if (n <= 9): 40 | print(n) 41 | return 42 | var_4 = 9 43 | while (n > var_4): 44 | var_4 = ((var_4 * 10) + 9) 45 | var_4 //= 10 46 | print((get(var_4) + get((n - var_4)))) 47 | 48 | def func_2(): 49 | (n, var_5, var_6, var_7) = (int(input()), [], [], 0) 50 | for var_2 in range(n): 51 | (var_8, var_9) = list(map(int, input().split())) 52 | var_5.append(var_8) 53 | var_6.append(var_9) 54 | var_5.sort() 55 | var_6.sort() 56 | for var_2 in range(n): 57 | var_7 += max(var_5[var_2], var_6[var_2]) 58 | print((var_7 + n)) 59 | func_2() 60 | 61 | 62 | 63 | # 🟨 🟨 🟨 🟨 64 | 65 | 66 | var_0 = int(input()) 67 | (var_1, var_2) = ([], []) 68 | for var_3 in range(var_0): 69 | (var_4, var_5) = list(map(int, input().split())) 70 | var_1.append(var_4) 71 | var_2.append(var_5) 72 | var_1.sort(reverse=True) 73 | var_2.sort(reverse=True) 74 | var_6 = var_0 75 | for var_7 in range(var_0): 76 | var_6 += max(var_1[var_7], var_2[var_7]) 77 | print(var_6) 78 | 79 | 80 | 81 | # 🟨 🟨 🟨 🟨 82 | 83 | 84 | var_0 = [[], []] 85 | var_1 = ([[var_0[var_2].append(int(var_3)) for (var_2, var_3) in enumerate(input().split())] for var_1 in range(int(input()))], [var_3.sort() for var_3 in var_0], print((len(var_0[0]) + sum((max(var_3, var_4) for (var_3, var_4) in zip(*var_0)))))) 86 | 87 | 88 | 89 | # 🟨 🟨 🟨 🟨 90 | 91 | 92 | var_0 = (lambda var_in_0: ([[var_in_0[var_1].append(int(var_2)) for (var_1, var_2) in enumerate(input().split())] for var_0 in range(int(input()))], [var_2.sort() for var_2 in var_in_0], print((len(var_in_0[0]) + sum((max(var_2, var_3) for (var_2, var_3) in zip(*var_in_0)))))))([[], []]) 93 | 94 | 95 | 96 | # 🟨 🟨 🟨 🟨 97 | 98 | 99 | print(sum([(max(*var_0) + 1) for var_0 in zip(*list(map(sorted, list(zip(*[list(map(int, input().split())) for var_1 in range(int(input()))])))))])) 100 | 101 | 102 | 103 | # 🟨 🟨 🟨 🟨 104 | 105 | 106 | var_0 = 0 107 | var_1 = int(input()) 108 | var_2 = [] 109 | var_3 = [] 110 | for var_4 in range(var_1): 111 | (var_5, var_6) = list(map(int, input().split())) 112 | var_2.append(var_5) 113 | var_3.append(var_6) 114 | var_2.sort() 115 | var_3.sort() 116 | for var_4 in range(0, var_1): 117 | var_0 += (max(var_2[var_4], var_3[var_4]) + 1) 118 | print(var_0) 119 | 120 | 121 | 122 | # 🟨 🟨 🟨 🟨 123 | 124 | 125 | var_0 = int(input()) 126 | var_1 = [] 127 | var_2 = [] 128 | for var_3 in range(var_0): 129 | (var_4, var_5) = list(map(int, input().split())) 130 | var_1.append(var_4) 131 | var_2.append(var_5) 132 | var_1.sort() 133 | var_2.sort() 134 | print((var_0 + sum(map(max, var_1, var_2)))) 135 | 136 | 137 | 138 | # 🟨 🟨 🟨 🟨 139 | 140 | 141 | import heapq 142 | var_0 = int(input()) 143 | var_1 = [var_2 for var_2 in range(var_0)] 144 | var_3 = [] 145 | var_4 = [] 146 | for var_2 in range(var_0): 147 | (var_5, var_6) = [int(var_7) for var_7 in input().split()] 148 | var_3.append((var_5, var_2)) 149 | var_4.append((var_6, var_2)) 150 | var_3.sort() 151 | var_4.sort() 152 | var_8 = var_0 153 | for var_2 in range(var_0): 154 | var_8 += max(var_3[var_2][0], var_4[var_2][0]) 155 | print(var_8) 156 | 157 | 158 | 159 | # 🟨 🟨 🟨 🟨 160 | 161 | 162 | var_0 = int(input()) 163 | var_1 = [] 164 | var_2 = [] 165 | for var_3 in range(var_0): 166 | (var_4, var_5) = [int(var_6) for var_6 in input().strip().split()] 167 | var_1.append(var_4) 168 | var_2.append(var_5) 169 | var_1.sort() 170 | var_2.sort() 171 | 
print((var_0 + sum([max(var_1[var_6], var_2[var_6]) for var_6 in range(var_0)]))) 172 | 173 | 174 | 175 | # 🟨 🟨 🟨 🟨 176 | 177 | 178 | var_0 = int(input()) 179 | var_1 = [] 180 | var_2 = [] 181 | for var_3 in range(var_0): 182 | var_4 = [int(var_5) for var_5 in input().split(' ')] 183 | var_1.append(var_4[0]) 184 | var_2.append(var_4[1]) 185 | var_1.sort() 186 | var_2.sort() 187 | var_6 = 0 188 | for var_3 in range(var_0): 189 | var_6 = ((var_6 + max(var_1[var_3], var_2[var_3])) + 1) 190 | print(var_6) 191 | 192 | 193 | 194 | # 🟨 🟨 🟨 🟨 195 | 196 | 197 | from sys import stdin 198 | var_0 = int(10000.0) 199 | var_1 = int((- 10000.0)) 200 | 201 | def func_0(): 202 | return int(stdin.readline()) 203 | 204 | def func_1(): 205 | return [int(var_2) for var_2 in stdin.readline().split()] 206 | 207 | def func_2(): 208 | return input() 209 | 210 | def func_3(): 211 | return [var_2 for var_2 in stdin.readline().split()] 212 | var_3 = func_0() 213 | (var_4, var_5) = ([], []) 214 | for var_6 in range(var_3): 215 | (var_7, var_8) = func_1() 216 | var_4.append(var_7) 217 | var_5.append(var_8) 218 | var_4 = sorted(var_4) 219 | var_5 = sorted(var_5) 220 | var_9 = var_3 221 | for var_10 in range(var_3): 222 | var_9 += max(var_4[var_10], var_5[var_10]) 223 | print(var_9) 224 | 225 | 226 | 227 | # 🟨 🟨 🟨 🟨 228 | 229 | 230 | import sys 231 | 232 | def func_0(): 233 | return map(int, sys.stdin.readline().split()) 234 | var_0 = int(input()) 235 | var_1 = ([0] * var_0) 236 | var_2 = ([0] * var_0) 237 | for var_3 in range(var_0): 238 | (var_1[var_3], var_2[var_3]) = func_0() 239 | var_1.sort() 240 | var_2.sort() 241 | var_4 = 0 242 | for var_3 in range(var_0): 243 | var_4 += max(var_1[var_3], var_2[var_3]) 244 | print((var_4 + var_0)) 245 | 246 | 247 | 248 | # 🟨 🟨 🟨 🟨 249 | 250 | 251 | var_0 = int(input()) 252 | var_1 = [] 253 | var_2 = [] 254 | for var_3 in range(var_0): 255 | (var_4, var_5) = map(int, input().split()) 256 | var_1.append(var_4) 257 | var_2.append(var_5) 258 | var_2 = sorted(var_2) 259 | var_1 = sorted(var_1) 260 | var_6 = var_0 261 | for var_3 in range(var_0): 262 | var_6 += max(var_1[var_3], var_2[var_3]) 263 | print(var_6) 264 | 265 | 266 | 267 | # 🟨 🟨 🟨 🟨 268 | 269 | 270 | var_0 = int(input()) 271 | var_1 = [0 for var_2 in range(var_0)] 272 | var_3 = [0 for var_2 in range(var_0)] 273 | for var_2 in range(var_0): 274 | [var_1[var_2], var_3[var_2]] = map(int, input().split()) 275 | var_1.sort() 276 | var_3.sort() 277 | var_4 = var_0 278 | for var_2 in range(var_0): 279 | var_4 += max(var_1[var_2], var_3[var_2]) 280 | print(var_4) 281 | 282 | 283 | 284 | # 🟨 🟨 🟨 🟨 285 | 286 | 287 | var_0 = int(input()) 288 | var_1 = [] 289 | var_2 = [] 290 | for var_3 in range(0, var_0): 291 | (var_4, var_5) = map(int, input().split()) 292 | var_1.append(var_4) 293 | var_2.append(var_5) 294 | var_1.sort() 295 | var_2.sort() 296 | var_6 = var_0 297 | for var_3 in range(0, var_0): 298 | var_6 += max(var_1[var_3], var_2[var_3]) 299 | print(var_6) 300 | 301 | 302 | 303 | # 🟨 🟨 🟨 🟨 304 | 305 | 306 | var_0 = int(input()) 307 | var_1 = [] 308 | var_2 = [] 309 | for var_3 in range(var_0): 310 | var_4 = [int(var_5) for var_5 in input().split()] 311 | (var_6, var_7) = var_4 312 | var_1.append(var_6) 313 | var_2.append(var_7) 314 | var_1.sort() 315 | var_2.sort() 316 | var_8 = [max(var_9, var_10) for (var_9, var_10) in zip(var_1, var_2)] 317 | print((var_0 + sum(var_8))) 318 | 319 | 320 | 321 | # 🟨 🟨 🟨 🟨 322 | 323 | 324 | var_0 = int(input()) 325 | var_1 = [] 326 | var_2 = [] 327 | for var_3 in range(var_0): 328 | (var_4, 
var_5) = list(map(int, input().split(' '))) 329 | var_1 += [var_4] 330 | var_2 += [var_5] 331 | var_1.sort() 332 | var_2.sort() 333 | var_6 = 0 334 | for var_3 in range(var_0): 335 | var_6 += 1 336 | var_6 += max(var_1[var_3], var_2[var_3]) 337 | print(var_6) 338 | 339 | 340 | 341 | # 🟨 🟨 🟨 🟨 342 | 343 | 344 | var_0 = int(input()) 345 | var_1 = [] 346 | var_2 = [] 347 | for var_3 in range(var_0): 348 | (var_4, var_5) = map(int, input().split(' ')) 349 | var_1 += [var_4] 350 | var_2 += [var_5] 351 | var_1.sort() 352 | var_2.sort() 353 | var_6 = 0 354 | for var_3 in range(var_0): 355 | var_6 += 1 356 | var_6 += max(var_1[var_3], var_2[var_3]) 357 | print(var_6) 358 | 359 | 360 | 361 | # 🟨 🟨 🟨 🟨 362 | 363 | 364 | var_0 = int(input()) 365 | var_1 = [] 366 | var_2 = [] 367 | for var_3 in range(var_0): 368 | (var_4, var_5) = map(int, input().split()) 369 | var_1.append(var_4) 370 | var_2.append(var_5) 371 | var_1.sort() 372 | var_2.sort() 373 | print((var_0 + sum(map(max, var_1, var_2)))) 374 | 375 | 376 | 377 | # 🟨 🟨 🟨 🟨 378 | 379 | 380 | var_0 = int(input()) 381 | var_1 = [] 382 | var_2 = [] 383 | for var_3 in range(var_0): 384 | (var_4, var_5) = map(int, input().split()) 385 | var_1.append(var_4) 386 | var_2.append(var_5) 387 | print((var_0 + sum(map(max, sorted(var_1), sorted(var_2))))) 388 | 389 | 390 | 391 | # 🟨 🟨 🟨 🟨 392 | 393 | 394 | var_0 = int(input()) 395 | var_1 = [] 396 | var_2 = [] 397 | for var_3 in range(var_0): 398 | (var_4, var_5) = input().split() 399 | var_1.append(int(var_4)) 400 | var_2.append(int(var_5)) 401 | var_1 = sorted(var_1) 402 | var_2 = sorted(var_2) 403 | var_6 = var_0 404 | for var_3 in range(var_0): 405 | var_6 += max(var_1[var_3], var_2[var_3]) 406 | print(var_6) 407 | 408 | 409 | 410 | # 🟨 🟨 🟨 🟨 411 | 412 | 413 | 414 | def func_0(): 415 | return list(map(int, input().split())) 416 | var_0 = int(input()) 417 | var_1 = [] 418 | var_2 = [] 419 | for var_3 in range(var_0): 420 | (var_4, var_5) = func_0() 421 | var_1.append(var_4) 422 | var_2.append(var_5) 423 | var_1.sort() 424 | var_2.sort() 425 | var_6 = var_0 426 | for var_7 in range(var_0): 427 | var_6 += max(var_1[var_7], var_2[var_7]) 428 | print(var_6) 429 | 430 | 431 | 432 | # 🟨 🟨 🟨 🟨 433 | 434 | 435 | var_0 = int(input()) 436 | var_1 = [] 437 | var_2 = [] 438 | for var_3 in range(var_0): 439 | (var_4, var_5) = map(int, input().split()) 440 | var_1.append(var_4) 441 | var_2.append(var_5) 442 | print((var_0 + sum(map(max, sorted(var_1), sorted(var_2))))) 443 | 444 | 445 | -------------------------------------------------------------------------------- /dataset_examples/apps_test_codes_readable_raw_007810.py: -------------------------------------------------------------------------------- 1 | def put(): 2 | return list(map(int, input().split())) 3 | 4 | def dfs(x,flag=1): 5 | s,vis,ans = [x],[0]*n,['+']*m 6 | vis[x]= 1 7 | while s: 8 | i = s.pop() 9 | for j,k in graph[i]: 10 | if vis[j]==0: 11 | if k*flag<0: 12 | ans[abs(k)-1]='-' 13 | elif k*flag>0: 14 | ans[abs(k)-1]='+' 15 | if flag==1 or k==0: 16 | s.append(j) 17 | vis[j]=1 18 | return ''.join(ans), sum(vis) 19 | 20 | n,m,s = put() 21 | graph = [[] for i in range(n)] 22 | k=1 23 | for _ in range(m): 24 | z,x,y = put() 25 | x,y = x-1,y-1 26 | if z==1: 27 | graph[x].append((y, 0)) 28 | else: 29 | graph[x].append((y, k)) 30 | graph[y].append((x,-k)) 31 | k+=1 32 | m = k-1 33 | x,y = dfs(s-1, 1) 34 | print(y) 35 | print(x) 36 | x,y = dfs(s-1,-1) 37 | print(y) 38 | print(x) 39 | 40 | 41 | # 🟨 🟨 🟨 🟨 42 | 43 | import sys 44 | input = sys.stdin.readline 45 | 
def put(): 46 | return list(map(int, input().split())) 47 | 48 | def dfs(x,flag=1): 49 | s,vis,ans = [x],[0]*n,['+']*m 50 | vis[x]= 1 51 | while s: 52 | i = s.pop() 53 | for j,k in graph[i]: 54 | if vis[j]==0: 55 | if k*flag<0: 56 | ans[abs(k)-1]='-' 57 | elif k*flag>0: 58 | ans[abs(k)-1]='+' 59 | if flag==1 or k==0: 60 | s.append(j) 61 | vis[j]=1 62 | return ''.join(ans), sum(vis) 63 | 64 | n,m,s = put() 65 | graph = [[] for i in range(n)] 66 | k=1 67 | for _ in range(m): 68 | z,x,y = put() 69 | x,y = x-1,y-1 70 | if z==1: 71 | graph[x].append((y, 0)) 72 | else: 73 | graph[x].append((y, k)) 74 | graph[y].append((x,-k)) 75 | k+=1 76 | m = k-1 77 | x,y = dfs(s-1, 1) 78 | print(y) 79 | print(x) 80 | x,y = dfs(s-1,-1) 81 | print(y) 82 | print(x) 83 | 84 | 85 | # 🟨 🟨 🟨 🟨 86 | 87 | import sys 88 | input = sys.stdin.readline 89 | 90 | def put(): 91 | return list(map(int, input().split())) 92 | 93 | 94 | 95 | 96 | def dfs0(x): 97 | s = [x] 98 | vis = [0] * n 99 | ans = ['+'] * m 100 | vis[x] = 1 101 | while s: 102 | i = s.pop() 103 | for j, k in graph[i]: 104 | if (vis[j] == 0): 105 | if (k < 0): 106 | ans[-k - 1] = '-' 107 | elif (k > 0): 108 | ans[k - 1] = '+' 109 | 110 | 111 | s.append(j) 112 | vis[j] = 1 113 | 114 | return ''.join(ans), sum(vis) 115 | 116 | def dfs1(x): 117 | s = [x] 118 | vis = [0] * n 119 | ans = ['+'] * m 120 | vis[x] = 1 121 | while s: 122 | i = s.pop() 123 | for j, k in graph[i]: 124 | if (vis[j] == 0): 125 | if (k < 0): 126 | ans[-k - 1] = '+' 127 | elif (k > 0): 128 | ans[k - 1] = '-' 129 | if (k == 0): 130 | s.append(j) 131 | vis[j] = 1 132 | 133 | return ''.join(ans), sum(vis) 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | n,m,s = put() 143 | graph = [[] for i in range(n)] 144 | 145 | k = 1 146 | 147 | for _ in range(m): 148 | z,x,y = put() 149 | x,y = x - 1, y - 1 150 | if (z == 1): 151 | graph[x].append((y, 0)) 152 | else: 153 | graph[x].append((y, k)) 154 | graph[y].append((x, -k)) 155 | k += 1 156 | 157 | m = k - 1 158 | x, y = dfs0(s - 1) 159 | print(y) 160 | print(x) 161 | x, y = dfs1(s - 1) 162 | print(y) 163 | print(x) 164 | 165 | 166 | -------------------------------------------------------------------------------- /dataset_examples/apps_test_codes_readable_raw_007819.py: -------------------------------------------------------------------------------- 1 | import sys 2 | input = sys.stdin.readline 3 | 4 | n=int(input()) 5 | lr=[list(map(int,input().split())) for i in range(n)] 6 | 7 | L=[lr[i][0] for i in range(n)] 8 | R=[lr[i][1] for i in range(n)] 9 | L.sort() 10 | R.sort() 11 | 12 | ANS=0 13 | 14 | for i in range(n): 15 | ANS+=max(L[i],R[i]) 16 | 17 | print(ANS+n) 18 | 19 | 20 | # 🟨 🟨 🟨 🟨 21 | 22 | def mainA(): 23 | n = int(input()) 24 | s = input() 25 | cnt = 0 26 | for i in s: 27 | if i == '8': 28 | cnt += 1 29 | print(min(cnt, n // 11)) 30 | 31 | def mainB(): 32 | def get(n): 33 | ret = 0 34 | while n > 0: 35 | ret += n % 10 36 | n //= 10 37 | return ret 38 | 39 | n = int(input()) 40 | if n <= 9: 41 | print(n) 42 | return 43 | t = 9 44 | while n > t: 45 | t = t * 10 + 9 46 | t //= 10 47 | print(get(t) + get(n - t)) 48 | 49 | 50 | def mainD(): 51 | n, A, B, ans = int(input()), [], [], 0 52 | for i in range(n): 53 | a, b = list(map(int, input().split())) 54 | A.append(a) 55 | B.append(b) 56 | A.sort() 57 | B.sort() 58 | for i in range(n): 59 | ans += max(A[i], B[i]) 60 | print(ans + n) 61 | 62 | mainD() 63 | 64 | 65 | # 🟨 🟨 🟨 🟨 66 | 67 | n = int(input()) 68 | ps,qs = [], [] 69 | for _ in range(n): 70 | p,q = list(map(int, input().split())) 71 | 
ps.append(p) 72 | qs.append(q) 73 | 74 | ps.sort(reverse=True) 75 | qs.sort(reverse=True) 76 | res = n 77 | for i in range(n): 78 | res += max(ps[i], qs[i]) 79 | print(res) 80 | 81 | 82 | # 🟨 🟨 🟨 🟨 83 | 84 | a = [[], []] 85 | _ = ([[a[i].append(int(x)) for i, x in enumerate(input().split())] for _ in range(int(input()))], [x.sort() for x in a], print(len(a[0]) + sum(max(x, y) for x, y in zip(*a)))) 86 | 87 | 88 | # 🟨 🟨 🟨 🟨 89 | 90 | _ = (lambda a : ([[a[i].append(int(x)) for i, x in enumerate(input().split())] for _ in range(int(input()))], [x.sort() for x in a], print(len(a[0]) + sum(max(x, y) for x, y in zip(*a)))))([[], []]) 91 | 92 | 93 | # 🟨 🟨 🟨 🟨 94 | 95 | print(sum([max(*x)+1 for x in zip(*list(map(sorted,list(zip(*[list(map(int,input().split())) for _ in range(int(input()))])))))])) 96 | 97 | 98 | # 🟨 🟨 🟨 🟨 99 | 100 | ans=0 101 | n=int(input()) 102 | a=[] 103 | b=[] 104 | for i in range(n): 105 | x,y=list(map(int,input().split())) 106 | a.append(x) 107 | b.append(y) 108 | a.sort() 109 | b.sort() 110 | #print(a) 111 | for i in range(0,n): 112 | ans+=max(a[i],b[i])+1 113 | print(ans) 114 | 115 | 116 | 117 | # 🟨 🟨 🟨 🟨 118 | 119 | n = int(input()) 120 | a = [] 121 | b = [] 122 | for _ in range(n) : 123 | x, y = list(map(int, input().split())) 124 | a.append(x) 125 | b.append(y) 126 | a.sort() 127 | b.sort() 128 | print(n + sum(map(max, a, b))) 129 | 130 | 131 | # 🟨 🟨 🟨 🟨 132 | 133 | import heapq 134 | n=int(input()) 135 | fa=[i for i in range(n)] 136 | ls=[] 137 | rs=[] 138 | for i in range(n): 139 | l,r=[int(x) for x in input().split()] 140 | ls.append((l,i)) 141 | rs.append((r,i)) 142 | ls.sort() 143 | rs.sort() 144 | ans=n 145 | for i in range(n): 146 | ans+=max(ls[i][0],rs[i][0]) 147 | # heapq.heapify(ls) 148 | # heapq.heapify(rs) 149 | # 150 | # ans=n 151 | # if n==1: 152 | # print(max(ls[0][0],rs[0][0])+1) 153 | # quit() 154 | # for i in range(n): 155 | # ll=heapq.heappop(ls) 156 | # if fa[rs[0][1]]!=fa[ll[1]]: 157 | # rr=heapq.heappop(rs) 158 | # fa[ll[1]]=rr[1] 159 | # else: 160 | # tem=heapq.heappop(rs) 161 | # rr=heapq.heappop(rs) 162 | # fa[ll[1]]=rr[1] 163 | # heapq.heappush(rs,tem) 164 | # ans+=max(ll[0],rr[0]) 165 | print(ans) 166 | 167 | # 🟨 🟨 🟨 🟨 168 | 169 | n = int(input()) 170 | a = [] 171 | b = [] 172 | for i in range(n): 173 | l, r = [int(_) for _ in input().strip().split()] 174 | a.append(l) 175 | b.append(r) 176 | 177 | a.sort() 178 | b.sort() 179 | 180 | print(n + sum([max(a[_],b[_]) for _ in range(n)])) 181 | 182 | 183 | # 🟨 🟨 🟨 🟨 184 | 185 | n=int(input()) 186 | 187 | a=[] 188 | b=[] 189 | for i in range(n): 190 | inp=[int(x) for x in input().split(" ")] 191 | a.append(inp[0]) 192 | b.append(inp[1]) 193 | 194 | a.sort() 195 | b.sort() 196 | 197 | ans=0 198 | 199 | for i in range(n): 200 | ans=ans+max(a[i],b[i])+1 201 | 202 | print(ans) 203 | 204 | # 🟨 🟨 🟨 🟨 205 | 206 | 207 | # -*- coding: utf-8 -*- 208 | # @Date : 2018-10-02 08:00:37 209 | # @Author : raj lath (oorja.halt@gmail.com) 210 | # @Link : link 211 | # @Version : 1.0.0 212 | 213 | from sys import stdin 214 | 215 | max_val=int(10e12) 216 | min_val=int(-10e12) 217 | 218 | def read_int() : return int(stdin.readline()) 219 | def read_ints() : return [int(x) for x in stdin.readline().split()] 220 | def read_str() : return input() 221 | def read_strs() : return [x for x in stdin.readline().split()] 222 | 223 | 224 | nb_guest = read_int() 225 | left, rite = [], [] 226 | for _ in range(nb_guest): 227 | a, b = read_ints() 228 | left.append(a) 229 | rite.append(b) 230 | left = sorted(left) 231 | rite = 
sorted(rite) 232 | answ = nb_guest 233 | for i in range(nb_guest): 234 | answ += max(left[i] , rite[i]) 235 | print(answ) 236 | 237 | 238 | 239 | 240 | # 🟨 🟨 🟨 🟨 241 | 242 | #!/usr/bin/env python3 243 | import sys 244 | 245 | def rint(): 246 | return map(int, sys.stdin.readline().split()) 247 | #lines = stdin.readlines() 248 | 249 | n = int(input()) 250 | r = [0]*n 251 | l = [0]*n 252 | 253 | for i in range(n): 254 | r[i], l[i] = rint() 255 | 256 | r.sort() 257 | l.sort() 258 | 259 | ans = 0 260 | 261 | for i in range(n): 262 | ans += max(r[i], l[i]) 263 | 264 | print(ans + n) 265 | 266 | # 🟨 🟨 🟨 🟨 267 | 268 | n=int(input()) 269 | l=[] 270 | r=[] 271 | for i in range(n): 272 | a,b=map(int,input().split()) 273 | l.append(a) 274 | r.append(b) 275 | r=sorted(r) 276 | l=sorted(l) 277 | ss=n 278 | for i in range(n): 279 | ss+=max(l[i],r[i]) 280 | print(ss) 281 | 282 | # 🟨 🟨 🟨 🟨 283 | 284 | n = int(input()) 285 | l = [0 for i in range(n)] 286 | r = [0 for i in range(n)] 287 | 288 | for i in range(n): 289 | [l[i], r[i]] = map(int, input().split()) 290 | 291 | l.sort() 292 | r.sort() 293 | 294 | res = n 295 | for i in range(n): 296 | res += max(l[i], r[i]) 297 | print(res) 298 | 299 | # 🟨 🟨 🟨 🟨 300 | 301 | n = int(input()) 302 | 303 | l = [] 304 | r = [] 305 | 306 | for i in range(0, n): 307 | x, y = map(int, input().split()) 308 | l.append(x) 309 | r.append(y) 310 | 311 | l.sort() 312 | r.sort() 313 | 314 | res = n 315 | 316 | for i in range(0, n): 317 | res += max(l[i], r[i]) 318 | 319 | print(res) 320 | 321 | # 🟨 🟨 🟨 🟨 322 | 323 | n = int(input()) 324 | 325 | l = [] 326 | r = [] 327 | 328 | for i in range(n): 329 | numbers_in_line = [int(num) for num in input().split()] 330 | l_new, r_new = numbers_in_line 331 | l.append(l_new) 332 | r.append(r_new) 333 | 334 | l.sort() 335 | r.sort() 336 | 337 | maxes = [max(lv, rv) for lv, rv in zip(l, r)] 338 | 339 | print(n + sum(maxes)) 340 | 341 | 342 | # 🟨 🟨 🟨 🟨 343 | 344 | n = int(input()) 345 | leftSpaces = [] 346 | rightSpaces = [] 347 | for i in range(n): 348 | left,right = list(map(int,input().split(" "))) 349 | leftSpaces += [left] 350 | rightSpaces += [right] 351 | leftSpaces.sort() 352 | rightSpaces.sort() 353 | chairs = 0 354 | for i in range(n): 355 | chairs += 1 356 | chairs += max(leftSpaces[i],rightSpaces[i]) 357 | print(chairs) 358 | 359 | 360 | 361 | 362 | # 🟨 🟨 🟨 🟨 363 | 364 | n = int(input()) 365 | leftSpaces = [] 366 | rightSpaces = [] 367 | for i in range(n): 368 | left,right = map(int,input().split(" ")) 369 | leftSpaces += [left] 370 | rightSpaces += [right] 371 | leftSpaces.sort() 372 | rightSpaces.sort() 373 | chairs = 0 374 | for i in range(n): 375 | chairs += 1 376 | chairs += max(leftSpaces[i],rightSpaces[i]) 377 | print(chairs) 378 | 379 | # 🟨 🟨 🟨 🟨 380 | 381 | n = int(input()) 382 | a = [] 383 | b = [] 384 | for _ in range(n) : 385 | x, y = map(int, input().split()) 386 | a.append(x) 387 | b.append(y) 388 | a.sort() 389 | b.sort() 390 | print(n + sum(map(max, a, b))) 391 | 392 | # 🟨 🟨 🟨 🟨 393 | 394 | n = int(input()) 395 | l = [] 396 | r = [] 397 | for _ in range(n): 398 | x, y = map(int, input().split()) 399 | l.append(x) 400 | r.append(y) 401 | print(n+sum(map(max, sorted(l), sorted(r)))) 402 | 403 | # 🟨 🟨 🟨 🟨 404 | 405 | n = int(input()) 406 | left = [] 407 | right = [] 408 | 409 | for i in range(n): 410 | l, r = input().split() 411 | left.append(int(l)) 412 | right.append(int(r)) 413 | 414 | left = sorted(left) 415 | right = sorted(right) 416 | 417 | res = n 418 | 419 | for i in range(n): 420 | res += max(left[i], 
right[i]) 421 | 422 | print(res) 423 | 424 | 425 | # 🟨 🟨 🟨 🟨 426 | 427 | def get_input_list(): 428 | return list(map(int, input().split())) 429 | n = int(input()) 430 | l = [] 431 | r = [] 432 | for _ in range(n): 433 | li, ri = get_input_list() 434 | l.append(li) 435 | r.append(ri) 436 | l.sort() 437 | r.sort() 438 | res = n 439 | for i in range(n): 440 | res += max(l[i],r[i]) 441 | print(res) 442 | 443 | # 🟨 🟨 🟨 🟨 444 | 445 | ## Problem @ http://codeforces.com/problemset/problem/1060/D 446 | ## #greedy #math 447 | n = int(input()) 448 | left = [] 449 | right = [] 450 | for i in range(n): 451 | a,b = map(int, input().split()) 452 | left.append(a) 453 | right.append(b) 454 | print(n + sum(map(max, sorted(left), sorted(right)))) 455 | -------------------------------------------------------------------------------- /dataset_examples/apps_test_description_007810.txt: -------------------------------------------------------------------------------- 1 | Vasya has a graph containing both directed (oriented) and undirected (non-oriented) edges. There can be multiple edges between a pair of vertices. 2 | 3 | Vasya has picked a vertex s from the graph. Now Vasya wants to create two separate plans: 4 | 5 | to orient each undirected edge in one of two possible directions to maximize number of vertices reachable from vertex s; to orient each undirected edge in one of two possible directions to minimize number of vertices reachable from vertex s. 6 | 7 | In each of two plans each undirected edge must become directed. For an edge chosen directions can differ in two plans. 8 | 9 | Help Vasya find the plans. 10 | -------------------------------------------------------------------------------- /dataset_examples/apps_test_description_007819.txt: -------------------------------------------------------------------------------- 1 | You invited $n$ guests to dinner! You plan to arrange one or more circles of chairs. Each chair is going to be either occupied by one guest, or be empty. You can make any number of circles. 2 | 3 | Your guests happen to be a little bit shy, so the $i$-th guest wants to have a least $l_i$ free chairs to the left of his chair, and at least $r_i$ free chairs to the right. The "left" and "right" directions are chosen assuming all guests are going to be seated towards the center of the circle. Note that when a guest is the only one in his circle, the $l_i$ chairs to his left and $r_i$ chairs to his right may overlap. 4 | 5 | What is smallest total number of chairs you have to use? 
6 | -------------------------------------------------------------------------------- /dataset_examples/apps_test_iodatas_readable_007810.py: -------------------------------------------------------------------------------- 1 | ([[2, 2, 1], [1, 1, 2], [2, 2, 1]], [[2], ['-'], [2], ['+']]) 2 | 3 | # 🟨 🟨 🟨 🟨 4 | 5 | ([[6, 6, 3], [2, 2, 6], [1, 4, 5], [2, 3, 4], [1, 4, 1], [1, 3, 1], [2, 2, 3]], [[6], ['++-'], [2], ['+-+']]) 6 | 7 | # 🟨 🟨 🟨 🟨 8 | 9 | ([[5, 5, 5], [2, 5, 3], [1, 2, 3], [1, 4, 5], [2, 5, 2], [1, 2, 1]], [[4], ['++'], [1], ['--']]) 10 | 11 | # 🟨 🟨 🟨 🟨 12 | 13 | ([[13, 18, 9], [2, 3, 10], [1, 12, 10], [1, 11, 4], [2, 2, 8], [1, 5, 1], [1, 7, 12], [1, 5, 13], [1, 9, 7], [1, 10, 11], [2, 3, 12], [1, 9, 2], [1, 3, 9], [1, 8, 12], [2, 11, 3], [1, 3, 1], [1, 8, 4], [2, 9, 11], [1, 12, 13]], [[11], ['++-++'], [8], ['+-+-+']]) 14 | 15 | # 🟨 🟨 🟨 🟨 16 | 17 | ([[5, 10, 2], [2, 2, 4], [1, 1, 2], [2, 2, 3], [1, 3, 1], [1, 4, 1], [1, 5, 1], [1, 3, 4], [2, 5, 4], [1, 5, 2], [2, 5, 3]], [[5], ['++--'], [1], ['--++']]) 18 | 19 | # 🟨 🟨 🟨 🟨 20 | 21 | ([[5, 5, 1], [2, 5, 3], [2, 2, 5], [1, 2, 1], [2, 4, 2], [1, 1, 5]], [[5], ['+--'], [2], ['-++']]) 22 | 23 | # 🟨 🟨 🟨 🟨 24 | 25 | ([[5, 10, 3], [2, 5, 1], [2, 1, 3], [2, 3, 5], [2, 1, 4], [2, 5, 4], [2, 2, 5], [2, 3, 2], [2, 2, 1], [2, 4, 3], [2, 4, 2]], [[5], ['--+---+---'], [1], ['++-+++-+++']]) 26 | 27 | # 🟨 🟨 🟨 🟨 28 | 29 | ([[10, 10, 9], [2, 1, 6], [2, 7, 8], [1, 4, 1], [2, 5, 10], [1, 5, 2], [1, 6, 7], [1, 5, 1], [2, 9, 8], [2, 5, 3], [2, 3, 8]], [[9], ['+-++--'], [1], ['+++-++']]) 30 | 31 | # 🟨 🟨 🟨 🟨 32 | 33 | ([[10, 20, 5], [2, 3, 8], [2, 10, 2], [1, 8, 2], [1, 7, 3], [1, 1, 8], [1, 8, 5], [1, 2, 7], [1, 3, 9], [1, 6, 1], [2, 10, 8], [1, 4, 5], [1, 6, 8], [2, 3, 4], [1, 6, 5], [1, 2, 4], [1, 2, 3], [1, 5, 9], [2, 4, 9], [1, 4, 7], [1, 6, 2]], [[8], ['+----'], [2], ['+++++']]) 34 | 35 | # 🟨 🟨 🟨 🟨 36 | 37 | ([[10, 10, 6], [2, 1, 4], [1, 7, 8], [1, 6, 4], [1, 7, 2], [1, 6, 2], [1, 1, 3], [1, 9, 7], [1, 3, 10], [1, 9, 6], [1, 9, 1]], [[6], ['-'], [3], ['+']]) 38 | 39 | # 🟨 🟨 🟨 🟨 40 | 41 | ([[10, 20, 10], [2, 7, 3], [1, 7, 9], [1, 3, 6], [2, 8, 3], [2, 9, 2], [1, 5, 3], [2, 9, 8], [2, 9, 1], [1, 5, 9], [1, 10, 2], [1, 6, 7], [2, 3, 2], [2, 8, 1], [1, 6, 1], [2, 4, 6], [2, 10, 9], [2, 5, 7], [2, 10, 1], [1, 2, 7], [2, 3, 4]], [[10], ['---+----+-++'], [4], ['-++--+++++-+']]) 42 | 43 | # 🟨 🟨 🟨 🟨 44 | 45 | ([[14, 19, 14], [2, 5, 7], [1, 4, 1], [2, 9, 8], [1, 7, 3], [2, 14, 2], [2, 2, 8], [2, 6, 7], [2, 14, 7], [1, 7, 8], [2, 10, 8], [2, 11, 10], [1, 11, 7], [2, 3, 13], [1, 5, 4], [1, 14, 8], [2, 3, 1], [2, 6, 1], [2, 6, 10], [2, 8, 1]], [[13], ['--+--+--+---+'], [2], ['++-++-++++++-']]) 46 | 47 | # 🟨 🟨 🟨 🟨 48 | 49 | ([[300000, 1, 5345], [2, 5345, 23423]], [[2], ['+'], [1], ['-']]) 50 | 51 | # 🟨 🟨 🟨 🟨 52 | 53 | ([[2, 5, 1], [1, 1, 2], [1, 1, 2], [1, 1, 2], [2, 1, 2], [1, 1, 2]], [[2], ['+'], [2], ['+']]) 54 | 55 | # 🟨 🟨 🟨 🟨 56 | 57 | ([[2, 5, 2], [1, 1, 2], [1, 1, 2], [1, 1, 2], [2, 1, 2], [1, 1, 2]], [[2], ['-'], [1], ['+']]) 58 | 59 | # 🟨 🟨 🟨 🟨 60 | 61 | ([[2, 5, 2], [2, 1, 2], [2, 1, 2], [2, 1, 2], [2, 1, 2], [2, 1, 2]], [[2], ['-----'], [1], ['+++++']]) 62 | 63 | # 🟨 🟨 🟨 🟨 64 | 65 | ([[2, 5, 2], [1, 1, 2], [1, 1, 2], [1, 2, 1], [2, 1, 2], [1, 2, 1]], [[2], ['-'], [2], ['+']]) 66 | 67 | # 🟨 🟨 🟨 🟨 68 | 69 | ([[2, 5, 1], [1, 1, 2], [1, 1, 2], [1, 2, 1], [2, 1, 2], [1, 2, 1]], [[2], ['+'], [2], ['+']]) 70 | 71 | # 🟨 🟨 🟨 🟨 72 | 73 | ([[2, 2, 1], [2, 1, 2], [2, 2, 1]], [[2], ['+-'], [1], ['-+']]) 74 | 75 | # 🟨 🟨 🟨 🟨 76 | 77 | ([[2, 5, 1], [2, 1, 2], [2, 1, 
2], [2, 1, 2], [2, 1, 2], [2, 1, 2]], [[2], ['+++++'], [1], ['-----']]) 78 | -------------------------------------------------------------------------------- /dataset_examples/apps_test_iodatas_readable_007819.py: -------------------------------------------------------------------------------- 1 | ([[3], [1, 1], [1, 1], [1, 1]], [[6]]) 2 | 3 | # 🟨 🟨 🟨 🟨 4 | 5 | ([[4], [1, 2], [2, 1], [3, 5], [5, 3]], [[15]]) 6 | 7 | # 🟨 🟨 🟨 🟨 8 | 9 | ([[1], [5, 6]], [[7]]) 10 | 11 | # 🟨 🟨 🟨 🟨 12 | 13 | ([[3], [2, 3], [2, 2], [1, 1]], [[9]]) 14 | 15 | # 🟨 🟨 🟨 🟨 16 | 17 | ([[4], [1, 4], [0, 3], [4, 3], [2, 4]], [[18]]) 18 | 19 | # 🟨 🟨 🟨 🟨 20 | 21 | ([[5], [5, 0], [4, 2], [2, 0], [5, 2], [3, 0]], [[24]]) 22 | 23 | # 🟨 🟨 🟨 🟨 24 | 25 | ([[10], [3, 3], [3, 5], [6, 9], [3, 1], [7, 3], [2, 10], [8, 2], [5, 1], [3, 2], [0, 2]], [[55]]) 26 | 27 | # 🟨 🟨 🟨 🟨 28 | 29 | ([[1], [901418150, 815121916]], [[901418151]]) 30 | 31 | # 🟨 🟨 🟨 🟨 32 | 33 | ([[1], [999999996, 999999988]], [[999999997]]) 34 | 35 | # 🟨 🟨 🟨 🟨 36 | 37 | ([[10], [805513144, 38998401], [16228409, 266085559], [293487744, 471510400], [138613792, 649258082], [904651590, 244678415], [443174087, 503924246], [579288498, 219903162], [179297759, 762760972], [92837851, 728185679], [983905980, 299473031]], [[4814008190]]) 38 | 39 | # 🟨 🟨 🟨 🟨 40 | 41 | ([[1], [0, 0]], [[1]]) 42 | 43 | # 🟨 🟨 🟨 🟨 44 | 45 | ([[1], [1000000000, 0]], [[1000000001]]) 46 | 47 | # 🟨 🟨 🟨 🟨 48 | 49 | ([[1], [1000000000, 999999999]], [[1000000001]]) 50 | 51 | # 🟨 🟨 🟨 🟨 52 | 53 | ([[2], [1000, 0], [0, 1000]], [[1002]]) 54 | 55 | # 🟨 🟨 🟨 🟨 56 | 57 | ([[10], [100, 0], [1234, 0], [1032134, 0], [1, 0], [2, 0], [0, 0], [5, 0], [7, 0], [11, 0], [239, 0]], [[1033743]]) 58 | 59 | # 🟨 🟨 🟨 🟨 60 | 61 | ([[8], [100, 0], [0, 1011], [432, 0], [0, 21], [123123, 0], [0, 123124321], [0, 0], [0, 102]], [[123125463]]) 62 | 63 | # 🟨 🟨 🟨 🟨 64 | 65 | ([[2], [6, 6], [3, 3]], [[11]]) 66 | 67 | # 🟨 🟨 🟨 🟨 68 | 69 | ([[3], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000]], [[3000000003]]) 70 | 71 | # 🟨 🟨 🟨 🟨 72 | 73 | ([[6], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000]], [[6000000006]]) 74 | 75 | # 🟨 🟨 🟨 🟨 76 | 77 | ([[10], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000]], [[10000000010]]) 78 | 79 | # 🟨 🟨 🟨 🟨 80 | 81 | ([[5], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000], [1000000000, 1000000000]], [[5000000005]]) 82 | 83 | # 🟨 🟨 🟨 🟨 84 | 85 | ([[3], [1, 1], [2, 2], [3, 3]], [[9]]) 86 | -------------------------------------------------------------------------------- /dataset_examples/apps_train_codes_readable_nameReplaced_011174.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | input = sys.stdin.readline 4 | var_0 = (lambda : list(map(int, input().split()))) 5 | for var_1 in range(int(input())): 6 | (var_2, var_3) = var_0() 7 | var_4 = 0 8 | var_5 = 0 9 | for var_6 in var_0(): 10 | if (var_3 == var_6): 11 | var_4 = 1 12 | var_5 = max(var_5, var_6) 13 | if var_4: 14 | print(1) 15 | else: 16 | print(max(2, (((var_3 - 1) // var_5) + 1))) 17 | 18 | 19 | 20 | # 🟨 🟨 🟨 🟨 21 | 22 | 23 | var_0 = int(input()) 24 | for var_1 in range(var_0): 25 | (var_2, var_3) = 
list(map(int, input().split())) 26 | var_4 = [int(var_5) for var_5 in input().split()] 27 | var_4.sort() 28 | if (var_4[(- 1)] > var_3): 29 | if (var_4.count(var_3) == 0): 30 | print(2) 31 | else: 32 | print(1) 33 | elif ((var_3 % var_4[(- 1)]) == 0): 34 | print((var_3 // var_4[(- 1)])) 35 | else: 36 | print(((var_3 // var_4[(- 1)]) + 1)) 37 | 38 | 39 | 40 | # 🟨 🟨 🟨 🟨 41 | 42 | 43 | for var_0 in range(int(input())): 44 | (var_1, var_2) = list(map(int, input().split())) 45 | var_3 = list(map(int, input().split())) 46 | var_4 = (((var_2 + max(var_3)) - 1) // max(var_3)) 47 | if ((var_4 == 1) and (var_2 not in var_3)): 48 | var_4 = 2 49 | print(var_4) 50 | 51 | 52 | 53 | # 🟨 🟨 🟨 🟨 54 | 55 | 56 | 57 | def func_0(): 58 | (var_0, var_1) = list(map(int, input().split())) 59 | var_2 = list(map(int, input().split())) 60 | var_2.sort() 61 | for var_3 in range(var_0): 62 | if (var_2[var_3] == var_1): 63 | print('1') 64 | return 65 | for var_3 in range(var_0): 66 | if (var_2[var_3] >= var_1): 67 | print('2') 68 | return 69 | print(int((((var_1 - 1) / var_2[(var_0 - 1)]) + 1))) 70 | for var_4 in range(int(input())): 71 | func_0() 72 | 73 | 74 | 75 | # 🟨 🟨 🟨 🟨 76 | 77 | 78 | var_0 = int(input()) 79 | for var_1 in range(var_0): 80 | (var_2, var_3) = map(int, input().split()) 81 | var_4 = list(map(int, input().split())) 82 | print(min((((((var_3 + var_5) - 1) // var_5) if (var_5 <= var_3) else 2) for var_5 in var_4))) 83 | 84 | 85 | 86 | # 🟨 🟨 🟨 🟨 87 | 88 | 89 | 90 | def func_0(): 91 | (var_0, var_1) = list(map(int, input().split())) 92 | var_2 = set(map(int, input().split())) 93 | var_3 = max(var_2) 94 | var_4 = (((var_1 + var_3) - 1) // var_3) 95 | if ((var_4 == 1) and (var_1 not in var_2)): 96 | var_4 = 2 97 | print(var_4) 98 | var_5 = int(input()) 99 | for var_6 in range(var_5): 100 | func_0() 101 | 102 | 103 | 104 | # 🟨 🟨 🟨 🟨 105 | 106 | 107 | var_0 = int(input()) 108 | for var_1 in range(var_0): 109 | (var_2, var_3) = map(int, input().split()) 110 | var_4 = list(map(int, input().split())) 111 | var_5 = 10000.0 112 | for var_6 in var_4: 113 | var_7 = 0 114 | if ((var_3 % var_6) != 0): 115 | var_7 += 1 116 | var_7 += (var_3 // var_6) 117 | if (((var_3 // var_6) == 0) and ((var_3 % var_6) != 0)): 118 | var_7 += 1 119 | var_5 = min(var_5, var_7) 120 | print(int(var_5)) 121 | 122 | 123 | -------------------------------------------------------------------------------- /dataset_examples/apps_train_description_011174.txt: -------------------------------------------------------------------------------- 1 | Bessie has way too many friends because she is everyone's favorite cow! Her new friend Rabbit is trying to hop over so they can play! 2 | 3 | More specifically, he wants to get from $(0,0)$ to $(x,0)$ by making multiple hops. He is only willing to hop from one point to another point on the 2D plane if the Euclidean distance between the endpoints of a hop is one of its $n$ favorite numbers: $a_1, a_2, \ldots, a_n$. What is the minimum number of hops Rabbit needs to get from $(0,0)$ to $(x,0)$? Rabbit may land on points with non-integer coordinates. It can be proved that Rabbit can always reach his destination. 4 | 5 | Recall that the Euclidean distance between points $(x_i, y_i)$ and $(x_j, y_j)$ is $\sqrt{(x_i-x_j)^2+(y_i-y_j)^2}$. 6 | 7 | For example, if Rabbit has favorite numbers $1$ and $3$ he could hop from $(0,0)$ to $(4,0)$ in two hops as shown below. Note that there also exists other valid ways to hop to $(4,0)$ in $2$ hops (e.g. $(0,0)$ $\rightarrow$ $(2,-\sqrt{5})$ $\rightarrow$ $(4,0)$). 
8 | 9 | $1$ Here is a graphic for the first example. Both hops have distance $3$, one of Rabbit's favorite numbers. 10 | 11 | In other words, each time Rabbit chooses some number $a_i$ and hops with distance equal to $a_i$ in any direction he wants. The same number can be used multiple times. 12 | -------------------------------------------------------------------------------- /dataset_examples/apps_train_description_011175.txt: -------------------------------------------------------------------------------- 1 | Niwel is a little golden bear. As everyone knows, bears live in forests, but Niwel got tired of seeing all the trees so he decided to move to the city. 2 | 3 | In the city, Niwel took on a job managing bears to deliver goods. The city that he lives in can be represented as a directed graph with n nodes and m edges. Each edge has a weight capacity. A delivery consists of a bear carrying weights with their bear hands on a simple path from node 1 to node n. The total weight that travels across a particular edge must not exceed the weight capacity of that edge. 4 | 5 | Niwel has exactly x bears. In the interest of fairness, no bear can rest, and the weight that each bear carries must be exactly the same. However, each bear may take different paths if they like. 6 | 7 | Niwel would like to determine, what is the maximum amount of weight he can deliver (it's the sum of weights carried by bears). Find the maximum weight. 8 | -------------------------------------------------------------------------------- /dataset_examples/apps_train_iodatas_readable_011174.py: -------------------------------------------------------------------------------- 1 | ([[4], [2, 4], [1, 3], [3, 12], [3, 4, 5], [1, 5], [5], [2, 10], [15, 4]], [[2], [3], [1], [2]]) 2 | 3 | # 🟨 🟨 🟨 🟨 4 | 5 | ([[1], [10, 999999733], [25, 68, 91, 55, 36, 29, 96, 4, 63, 3]], [[10416664]]) 6 | 7 | # 🟨 🟨 🟨 🟨 8 | 9 | ([[1], [19, 1000000000], [15, 8, 22, 12, 10, 16, 2, 17, 14, 7, 20, 23, 9, 18, 3, 19, 21, 11, 1]], [[43478261]]) 10 | 11 | # 🟨 🟨 🟨 🟨 12 | 13 | ([[1], [1, 11], [5]], [[3]]) 14 | 15 | # 🟨 🟨 🟨 🟨 16 | 17 | ([[1], [1, 5], [2]], [[3]]) 18 | 19 | # 🟨 🟨 🟨 🟨 20 | 21 | ([[1], [2, 9], [2, 4]], [[3]]) 22 | -------------------------------------------------------------------------------- /dataset_examples/apps_train_iodatas_readable_011175.py: -------------------------------------------------------------------------------- 1 | ([[4, 4, 3], [1, 2, 2], [2, 4, 1], [1, 3, 1], [3, 4, 2]], [[1.5]]) 2 | 3 | # 🟨 🟨 🟨 🟨 4 | 5 | ([[3, 2, 100000], [1, 2, 1], [2, 3, 1]], [[1.0]]) 6 | 7 | # 🟨 🟨 🟨 🟨 8 | 9 | ([[3, 2, 100000], [1, 2, 1], [2, 3, 1000000]], [[1.0]]) 10 | 11 | # 🟨 🟨 🟨 🟨 12 | 13 | ([[2, 1, 100000], [1, 2, 1]], [[1.0]]) 14 | 15 | # 🟨 🟨 🟨 🟨 16 | 17 | ([[3, 2, 100000], [1, 2, 1], [2, 3, 100000]], [[1.0]]) 18 | -------------------------------------------------------------------------------- /evals.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import os 3 | import torch 4 | import numpy as np 5 | import pickle 6 | from tqdm import tqdm 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | 11 | 12 | from parsers import get_parser 13 | 14 | from trainer.slurm import init_signal_handler, init_distributed_mode 15 | from model import check_model_params, build_modules, load_modules 16 | from model.model_wrapper import ModelWrapper 17 | from model.embedders import get_model_tokenizer 18 | from trainer.trainer import Trainer 19 | 20 | 21 | from dataloaders.loader_utils import 
timeout 22 | from dataloaders.sttd import get_ChainCoder_dataloader 23 | from dataloaders.check_exec_match import check_io_match_one_sample_int, check_io_match_one_sample_obj 24 | from tokenizer.tokenizerAPI import ( 25 | vocabulary_defs, load_txt, 26 | tokenizerAPI_IN2R, 27 | tokenizerAPI_ON2R, 28 | tokenizerAPI_OR2T, 29 | ) 30 | 31 | 32 | def run_evals(params): 33 | 34 | params.multi_gpu=False 35 | params.is_slurm_job = False 36 | params.local_rank = -1 37 | params.master_port = -1 38 | params.num_workers = 1 39 | params.target_noise=0.0 40 | params.max_input_points=200 41 | os.environ['CUDA_VISIBLE_DEVICES'] = params.CUDA_VISIBLE_DEVICES 42 | 43 | init_distributed_mode(params) 44 | if params.is_slurm_job: 45 | init_signal_handler() 46 | 47 | # CPU / CUDA 48 | if not params.run_on_cpu: 49 | assert torch.cuda.is_available() 50 | params.eval_only=True 51 | 52 | # build environment / modules 53 | if params.batch_size_eval is None: 54 | params.batch_size_eval = int(1.5 * params.batch_size) 55 | 56 | env = vocabulary_defs 57 | modules = build_modules(env, params) 58 | load_modules(params.testing_load_ckpt_from, modules)#### 59 | trnr = Trainer(modules, vocabulary_defs, params) 60 | 61 | 62 | embedder = ( 63 | modules["embedder"].module 64 | if params.multi_gpu 65 | else modules["embedder"] 66 | ) 67 | 68 | encoder = ( 69 | modules["encoder"].module 70 | if params.multi_gpu 71 | else modules["encoder"] 72 | ) 73 | decoder = ( 74 | modules["decoder"].module 75 | if params.multi_gpu 76 | else modules["decoder"] 77 | ) 78 | embedder.eval() 79 | encoder.eval() 80 | decoder.eval() 81 | 82 | 83 | model = ModelWrapper( 84 | env=env, 85 | trnr=trnr, 86 | embedder=embedder, 87 | encoder=encoder, 88 | decoder=decoder, 89 | beam_length_penalty=params.beam_length_penalty, 90 | beam_size=params.beam_size, 91 | max_generated_output_len=params.max_generated_output_len, 92 | beam_early_stopping=params.beam_early_stopping, 93 | beam_temperature=params.beam_temperature, 94 | beam_type=params.beam_type, 95 | ) 96 | if not params.run_on_cpu: 97 | model = model.to('cuda') 98 | 99 | 100 | def control_evaluator(samples_per_instance_io): 101 | params.samples_per_instance_io = samples_per_instance_io 102 | params.samples_per_instance_io_hold = 4 103 | params.batch_size = 1 # at test time, always use batch-size = 1 104 | params.samples_per_instance_code = 2 105 | params.fine_fune_nlp = 0 106 | params.beam_size = 4 107 | return 108 | 109 | 110 | given_samples_list = [4] 111 | for i in tqdm(range(len(given_samples_list))): 112 | control_evaluator(given_samples_list[i]) 113 | testloader = get_ChainCoder_dataloader(params, params.test_pickle_dir) 114 | 115 | print(f'\n\n num samples feed is: {given_samples_list[i]} \n ') 116 | 117 | acc_syntax_error_free, acc_error_free, acc_demo_pass, acc_all_pass = evaluate_syntax_transformer(testloader, model, params) 118 | 119 | return 120 | 121 | 122 | 123 | def evaluate_syntax_transformer(testloader, model, params): 124 | acc_syntax_error_free = [] 125 | 126 | for i, samples in enumerate(tqdm(testloader)): 127 | if samples==None: 128 | continue 129 | 130 | programs_ia = model(samples) # 2D list, dims = [instance, answers] (num of instance always == 1 in test phase); output None means syntax error/etc so as to fail parsing code. 
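# Roughly speaking (illustrative, not exhaustive): programs_ia[0] is the list of decoded
# candidate programs for this single test instance (the greedy decode plus any beam-search
# or sampling candidates); an empty list is counted as "no syntax-error-free candidate" below.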
131 | 132 | assert len(programs_ia)==1 133 | answers = programs_ia[0] 134 | 135 | io_objs = [tokenizerAPI_IN2R(samples['ioData_ist2'][0][ioSamp_id]) for ioSamp_id in range(len(samples['ioData_ist2'][0]))] 136 | io_objs = list(map(lambda x: tuple(x), io_objs)) 137 | 138 | if len(answers)!=0: 139 | acc_syntax_error_free.append(1) 140 | else: 141 | acc_syntax_error_free.append(0) 142 | 143 | is_match = False 144 | is_all_bug_free = False 145 | for answer in answers: 146 | ioSamp_id = 0 147 | io_ns = samples['ioData_ist2'][0][ioSamp_id] 148 | io_obj = tokenizerAPI_IN2R(io_ns) 149 | 150 | is_match, exec_out, prt_str = check_io_match_one_sample_obj(io_obj, answer, params.program_forward_run_timeout) 151 | if type(exec_out) is not RuntimeError: 152 | is_all_bug_free = True 153 | if is_match: 154 | is_all_bug_free = True 155 | break 156 | else: 157 | print(prt_str) 158 | 159 | 160 | 161 | if __name__ == "__main__": 162 | 163 | params = get_parser() 164 | run_evals(params) 165 | -------------------------------------------------------------------------------- /model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/ChainCoder/a9ae00758fe9185c7f13f9b9dc591d05de0d4445/model/.DS_Store -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | import os 3 | import torch 4 | 5 | 6 | from .embedders import ChainCoderSampleEmbedder 7 | 8 | from .transformer import TransformerModel 9 | 10 | 11 | 12 | logger = getLogger() 13 | 14 | 15 | def check_model_params(params): 16 | """ 17 | Check models parameters. 18 | """ 19 | # model dimensions 20 | assert params.enc_emb_dim % params.n_enc_heads == 0 21 | assert params.dec_emb_dim % params.n_dec_heads == 0 22 | 23 | # reload a pretrained model 24 | if params.reload_model != "": 25 | print("Reloading model from ", params.reload_model) 26 | assert os.path.isfile(params.reload_model) 27 | 28 | 29 | def build_modules(vocabulary_defs, params): 30 | """ 31 | Build modules. 
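Returns a dict with three entries -- "embedder", "encoder", "decoder" -- moved to CUDA unless params.run_on_cpu is set, and optionally wrapped in torch.nn.DataParallel when params.torch_parallel is set.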
32 | """ 33 | modules = {} 34 | 35 | modules["embedder"] = ChainCoderSampleEmbedder(params, vocabulary_defs) 36 | # vocabulary_defs.get_length_after_batching = modules["embedder"].get_length_after_batching 37 | 38 | modules["encoder"] = TransformerModel( 39 | params, 40 | vocabulary_defs, 41 | is_encoder=True, 42 | with_output=False, 43 | use_prior_embeddings=True, 44 | positional_embeddings=params.enc_positional_embeddings 45 | ) 46 | 47 | 48 | 49 | 50 | modules["decoder"] = TransformerModel( 51 | params, 52 | vocabulary_defs, 53 | is_encoder=False, 54 | with_output=True, 55 | use_prior_embeddings=False, 56 | positional_embeddings=params.dec_positional_embeddings 57 | ) 58 | 59 | # log 60 | for k, v in modules.items(): 61 | logger.debug(f"{v}: {v}") 62 | for k, v in modules.items(): 63 | logger.info( 64 | f"Number of parameters ({k}): {sum([p.numel() for p in v.parameters() if p.requires_grad])}" 65 | ) 66 | 67 | # cuda 68 | if not params.run_on_cpu: 69 | for v in modules.values(): 70 | v.cuda() 71 | 72 | 73 | if params.torch_parallel: 74 | modules["embedder"] = torch.nn.DataParallel(modules["embedder"]) 75 | modules["encoder"] = torch.nn.DataParallel(modules["encoder"]) 76 | modules["decoder"] = torch.nn.DataParallel(modules["decoder"]) 77 | 78 | return modules 79 | 80 | 81 | 82 | def load_modules(reload_file, modules): 83 | # reload pretrained modules 84 | reloaded = torch.load(reload_file) 85 | for k, v in modules.items(): 86 | assert k in reloaded 87 | if all([k2.startswith("module.") for k2 in reloaded[k].keys()]): 88 | reloaded[k] = { 89 | k2[len("module.") :]: v2 for k2, v2 in reloaded[k].items() 90 | } 91 | v.load_state_dict(reloaded[k]) 92 | print(f"\n\n Reloading modules from {reload_file} SUCCEED! \n") 93 | -------------------------------------------------------------------------------- /model/embedders.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | from abc import ABC, abstractmethod 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import transformers 7 | import numpy as np 8 | 9 | 10 | from trainer.trainer import to_cuda 11 | from .transformer import TransformerModel 12 | 13 | 14 | MultiDimensionalFloat = List[float] 15 | XYPair = Tuple[MultiDimensionalFloat, MultiDimensionalFloat] 16 | Sequence = List[XYPair] 17 | 18 | 19 | 20 | class Embedder(ABC, nn.Module): 21 | """ 22 | Base class for embedders, transforms a sequence of pairs into a sequence of embeddings. 
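Subclasses must implement forward() and get_length_after_batching(); batch() and embed() are optional hooks that simply raise NotImplementedError here.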
23 | """ 24 | 25 | def __init__(self): 26 | super().__init__() 27 | pass 28 | 29 | @abstractmethod 30 | def forward(self, sequences: List[Sequence]) -> Tuple[torch.Tensor, torch.Tensor]: 31 | pass 32 | 33 | def batch(self, seqs: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: 34 | raise NotImplementedError 35 | 36 | def embed(self, batch: torch.Tensor) -> torch.Tensor: 37 | raise NotImplementedError 38 | 39 | @abstractmethod 40 | def get_length_after_batching(self, sequences: List[Sequence]) -> List[int]: 41 | pass 42 | 43 | 44 | class ChainCoderSampleEmbedder(Embedder): 45 | def __init__(self, params, env): 46 | from .transformer import Embedding 47 | 48 | super().__init__() 49 | self.env = env 50 | self.params = params 51 | self.use_pretrained_NLP = params.use_pretrained_NLP 52 | 53 | self.embeddings = Embedding( 54 | len(self.env.i2t_inp), 55 | params.arch_encoder_dim, 56 | ) 57 | self.pretrain_pad_emb_idx = self.env.t2i_inp['<<||SPECIAL_RESERVED_TOKEN_POSSIBLY_NEVER_USED_9||>>'] 58 | self.activation_fn = F.relu 59 | 60 | self.distill_NLP_syntax_marking_ids, = to_cuda(torch.tensor(env.distill_NLP_syntax_marking_ids), use_cpu=self.params.run_on_cpu) 61 | 62 | 63 | self.sample_embedder = TransformerModel( 64 | params, 65 | env, 66 | is_encoder=True, 67 | with_output=False, 68 | use_prior_embeddings=False, 69 | positional_embeddings=params.enc_positional_embeddings, 70 | is_sample_embedder = True, 71 | ) 72 | 73 | if bool(self.use_pretrained_NLP): 74 | 75 | self.nlp_model, self.nlp_tokenizer, self.nlp_pad_token_id = get_model_tokenizer(params.nlp_arch, 'cpu' if self.params.run_on_cpu else 'cuda') 76 | nlp_out_dim = 768 if params.nlp_arch=='distilbert' else 1024 77 | self.nlp_glue_layer = getMLP([nlp_out_dim, int(nlp_out_dim*1.5), params.arch_encoder_dim]) 78 | 79 | 80 | def compress( 81 | self, tensor_iste: torch.Tensor, lens_i: torch.Tensor, 82 | ) -> Tuple[torch.Tensor, torch.Tensor]: 83 | """ 84 | Takes: (N_max * (d_in+d_out), B, d) tensors 85 | Returns: (N_max, B, d) 86 | """ 87 | _i,_s,_t,_2 = tensor_iste.shape # _2 == 2 88 | sb = tensor_iste.transpose(0,1).reshape(_s,-1) 89 | b_s = sb.transpose(0,1) 90 | lens_it2 = lens_i.reshape(-1,1,1).expand([_i, _t, _2]).reshape(-1) 91 | 92 | bse_out = self.sample_embedder("fwd", x=b_s, lengths=lens_it2, causal=False) 93 | 94 | be_out = bse_out[:,0,:] # pooling for samples: pick first 95 | 96 | ite = be_out.reshape(_i, _t, _2, -1).sum(dim=-2) # sum up content and syntax two subtokens 97 | 98 | return ite 99 | 100 | 101 | def forward(self, tensor_ist2, samp_lens_per_inst, token_lens_per_inst, desc_inp_tensor, desc_attn_mask) -> Tuple[torch.Tensor, torch.Tensor]: 102 | 103 | tensor_ite = self.compress(tensor_ist2, samp_lens_per_inst) # shape = torch.Size([158, 63, 512]) 104 | 105 | if self.use_pretrained_NLP: # use NLP tokenizer, mark distilled tokens with syntax subtokens, and concat distilled NLP tokens with sample_embedder output 106 | if self.params.fine_fune_nlp: # in some starting epochs, it is set to False: fix params of pretrained NLP model. 
107 | description_embs_tie = self.nlp_model(desc_inp_tensor, desc_attn_mask) 108 | else: 109 | with torch.no_grad(): 110 | description_embs_tie = self.nlp_model(desc_inp_tensor, desc_attn_mask) 111 | 112 | tensor_tie = tensor_ite.transpose(0,1) 113 | nlp_reserved_len = description_embs_tie.shape[0] 114 | _t, _i, _e = tensor_tie.shape 115 | syntax_markings_tie = self.sample_embedder.embeddings(self.distill_NLP_syntax_marking_ids[:nlp_reserved_len].unsqueeze(-1).expand([nlp_reserved_len, _i])) 116 | description_embs_tie = self.nlp_glue_layer(description_embs_tie) 117 | description_embs_tie += syntax_markings_tie 118 | tensor_tie = torch.cat([description_embs_tie, tensor_tie], dim=0) 119 | token_lens_per_inst += len(self.distill_NLP_syntax_marking_ids) 120 | tensor_ite = tensor_tie.transpose(0,1) 121 | 122 | return tensor_ite, token_lens_per_inst 123 | 124 | 125 | 126 | def get_length_after_batching(self, seqs: List[Sequence]) -> torch.Tensor: 127 | lengths = torch.zeros(len(seqs), dtype=torch.long) 128 | for i, seq in enumerate(seqs): 129 | lengths[i] = len(seq) 130 | assert lengths.max() <= self.max_seq_len, "issue with lengths after batching" 131 | return lengths 132 | 133 | 134 | def getMLP(neurons, activation=nn.GELU, bias=True, dropout=0.2, last_dropout=False, normfun='layernorm'): 135 | def _init_weights(self): 136 | nn.init.xavier_uniform_(self.fc1.weight) 137 | nn.init.xavier_uniform_(self.fc2.weight) 138 | nn.init.normal_(self.fc1.bias, std=1e-6) 139 | nn.init.normal_(self.fc2.bias, std=1e-6) 140 | 141 | if len(neurons) in [0,1]: 142 | return nn.Identity() 143 | 144 | nn_list = [] 145 | n = len(neurons)-1 146 | for i in range(n-1): 147 | if normfun=='layernorm': 148 | norm = nn.LayerNorm(neurons[i+1]) 149 | elif normfun=='batchnorm': 150 | norm = nn.BatchNorm1d(neurons[i+1]) 151 | else: 152 | norm = nn.Identity() 153 | nn_list.extend([nn.Linear(neurons[i], neurons[i+1], bias=bias), norm, activation(), nn.Dropout(dropout)]) 154 | 155 | nn_list.extend([nn.Linear(neurons[n-1], neurons[n], bias=bias)]) 156 | if last_dropout: 157 | if normfun=='layernorm': 158 | norm = nn.LayerNorm(neurons[-1]) 159 | elif normfun=='batchnorm': 160 | norm = nn.BatchNorm1d(neurons[-1]) 161 | else: 162 | norm = nn.Identity() 163 | nn_list.extend([norm, activation(), nn.Dropout(dropout)]) 164 | 165 | return nn.Sequential(*nn_list) 166 | 167 | 168 | 169 | def get_model_tokenizer(which_model, device='cpu'): 170 | if which_model=='bert': 171 | from transformers import BertTokenizer, BertModel 172 | tokenizer = BertTokenizer.from_pretrained('bert-large-uncased') 173 | _model = BertModel.from_pretrained("bert-large-uncased") 174 | pad_token_id = tokenizer.encode(tokenizer.pad_token)[1] 175 | 176 | 177 | elif which_model=='distilbert': 178 | from transformers import DistilBertTokenizer, DistilBertModel 179 | tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") 180 | _model = DistilBertModel.from_pretrained("distilbert-base-uncased") 181 | pad_token_id = tokenizer.encode(tokenizer.pad_token)[1] 182 | 183 | 184 | _model = _model.to(device) 185 | 186 | 187 | def model(inp_tensor, attn_mask): 188 | 189 | 190 | out_BTE = _model(input_ids = inp_tensor, attention_mask = attn_mask).last_hidden_state 191 | first_pool = out_BTE[:,0] # BE 192 | max_pool = out_BTE.max(dim=1).values 193 | min_pool = out_BTE.min(dim=1).values 194 | last_pool = out_BTE[:,0] 195 | 196 | description_embs_tbe = torch.stack([first_pool, last_pool, max_pool, min_pool], dim=0) 197 | return description_embs_tbe 198 | 199 | 200 | 
return model, tokenizer, pad_token_id 201 | 202 | -------------------------------------------------------------------------------- /model/model_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from tokenizer.tokenizerAPI import tokenizerAPI_ON2R_flatten 6 | 7 | def chunks(lst, n): 8 | """Yield successive n-sized chunks from lst.""" 9 | for i in range(0, len(lst), n): 10 | yield lst[i : i + n] 11 | 12 | class ModelWrapper(nn.Module): 13 | """""" 14 | def __init__(self, 15 | env=None, 16 | trnr=None, 17 | embedder=None, 18 | encoder=None, 19 | decoder=None, 20 | beam_type="search", 21 | beam_length_penalty=1, 22 | beam_size=1, 23 | beam_early_stopping=True, 24 | max_generated_output_len=200, 25 | beam_temperature=1., 26 | ): 27 | super().__init__() 28 | 29 | self.env = env 30 | self.trnr = trnr 31 | self.embedder = embedder 32 | self.encoder = encoder 33 | self.decoder = decoder 34 | self.beam_type = beam_type 35 | self.beam_early_stopping = beam_early_stopping 36 | self.max_generated_output_len = max_generated_output_len 37 | self.beam_size = beam_size 38 | self.beam_length_penalty = beam_length_penalty 39 | self.beam_temperature = beam_temperature 40 | self.device=next(self.embedder.parameters()).device 41 | 42 | def set_args(self, args={}): 43 | for arg, val in args.items(): 44 | assert hasattr(self, arg), "{} arg does not exist".format(arg) 45 | setattr(self, arg, val) 46 | def decode_from_batch_output(self, generations_iat, is_wrapped_by_hyp=0): 47 | bs = len(generations_iat) 48 | if is_wrapped_by_hyp: 49 | generations_iat = [list(filter(lambda x: x is not None, [self.env.idx_to_infix(hyp.cpu().tolist()[1:], is_float=False, str_array=False) for (_, hyp) in generations_iat[i]])) for i in range(bs)] # [3(instance), 4(beam_size), tokens] 50 | else: 51 | generations_iat = [list(filter(lambda x: x is not None, [tokenizerAPI_ON2R_flatten(hyp[1:-1]) for hyp in generations_iat[i]])) for i in range(bs)] # [3,1,body] = [instance, answer, token] 52 | return generations_iat 53 | 54 | @torch.no_grad() 55 | def forward( 56 | self, 57 | samples 58 | ): 59 | 60 | """ 61 | x: bags of sequences (B, T) 62 | """ 63 | 64 | embedder, encoder, decoder = self.embedder, self.encoder, self.decoder 65 | 66 | # x1_TBE, x_len = embedder(samples) 67 | # encoded = encoder("fwd", x=x1_TBE, lengths=x_len, causal=False).transpose(0,1) 68 | 69 | encoded_BTE, x_len = self.trnr.io_embed_encode(samples) 70 | 71 | # encoded = encoded_BTE.transpose(0,1) 72 | encoded = encoded_BTE 73 | 74 | # x_len = torch.tensor([encoded_BTE.shape[0]]).to(encoded_BTE.device) 75 | 76 | outputs = [] 77 | 78 | bs = encoded.shape[0] 79 | 80 | ### Greedy solution. 
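# The greedy decode below yields one candidate per instance; depending on self.beam_type,
# beam-search or sampling candidates are appended to the same per-instance lists afterwards.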
81 | generations, gen_len = decoder.generate( # generations: torch.Size([68, 3]) = TB; gen_len = [B] 82 | encoded, # BTE 83 | x_len, # [B] 84 | sample_temperature=None, 85 | max_len=self.max_generated_output_len, 86 | ) 87 | 88 | generations = generations.unsqueeze(-1).view(generations.shape[0], bs, 1) # torch.Size([68, 3, 1]) 89 | generations = generations.transpose(0,1).transpose(1,2).cpu().tolist() # (3, 1, 68) 90 | 91 | 92 | generations = self.decode_from_batch_output(generations, is_wrapped_by_hyp=0) 93 | 94 | if self.beam_type == "search": 95 | decoded, tgt_len, search_generations = decoder.generate_beam( # decoded: torch [68, 3]; tgt_len: [3]; search_generations: [batch hyp, beam_num, score and tokens] 96 | encoded, 97 | x_len, 98 | beam_size=self.beam_size, 99 | length_penalty=self.beam_length_penalty, 100 | max_len=self.max_generated_output_len, 101 | early_stopping=self.beam_early_stopping, 102 | ) 103 | search_generations = [sorted([hyp for hyp in search_generations[i].hyp], key=lambda s: s[0], reverse=True) for i in range(bs)] # remove the hyp wrapper 104 | 105 | search_generations = self.decode_from_batch_output(search_generations, is_wrapped_by_hyp=1) 106 | 107 | for i in range(bs): 108 | generations[i].extend(search_generations[i]) # generations and search_generations both: IAT 109 | 110 | elif self.beam_type == "sampling": 111 | num_samples = self.beam_size 112 | encoded = (encoded.unsqueeze(1) 113 | .expand((bs, num_samples) + encoded.shape[1:]) 114 | .contiguous() 115 | .view((bs * num_samples,) + encoded.shape[1:]) 116 | ) 117 | x_len = x_len.unsqueeze(1).expand(bs, num_samples).contiguous().view(-1) 118 | sampling_generations, _ = decoder.generate( 119 | encoded, 120 | x_len, 121 | sample_temperature = self.beam_temperature, 122 | max_len=self.max_generated_output_len 123 | ) 124 | sampling_generations = sampling_generations.unsqueeze(-1).view(sampling_generations.shape[0], bs, num_samples) 125 | sampling_generations = sampling_generations.transpose(0, 1).transpose(1, 2).cpu().tolist() 126 | 127 | 128 | sampling_generations = self.decode_from_batch_output(sampling_generations, is_wrapped_by_hyp=0) 129 | 130 | for i in range(bs): 131 | generations[i].extend(sampling_generations[i]) 132 | else: 133 | raise NotImplementedError 134 | 135 | 136 | outputs.extend(generations) 137 | return outputs # shape= IAT 138 | 139 | 140 | -------------------------------------------------------------------------------- /model/sklearn_wrapper.py: -------------------------------------------------------------------------------- 1 | import math, time, copy 2 | import numpy as np 3 | import torch 4 | from collections import defaultdict 5 | from trainer.metrics import compute_metrics 6 | from sklearn.base import BaseEstimator 7 | import model.utils_wrapper as utils_wrapper 8 | import traceback 9 | 10 | class SyntaxTransformerRegressor(BaseEstimator): 11 | 12 | def __init__(self, 13 | model=None, 14 | max_input_points=10000, 15 | max_number_bags=-1, 16 | stop_refinement_after=1, 17 | n_trees_to_refine=1, 18 | rescale=True 19 | ): 20 | 21 | self.max_input_points = max_input_points 22 | self.max_number_bags = max_number_bags 23 | self.model = model 24 | self.stop_refinement_after = stop_refinement_after 25 | self.n_trees_to_refine = n_trees_to_refine 26 | self.rescale = rescale 27 | 28 | def set_args(self, args={}): 29 | for arg, val in args.items(): 30 | assert hasattr(self, arg), "{} arg does not exist".format(arg) 31 | setattr(self, arg, val) 32 | 33 | def fit( 34 | self, 35 | X, 36 | Y, 37 | 
verbose=False 38 | ): 39 | self.start_fit = time.time() 40 | 41 | if not isinstance(X, list): 42 | X = [X] 43 | Y = [Y] 44 | n_datasets = len(X) 45 | 46 | scaler = utils_wrapper.StandardScaler() if self.rescale else None 47 | scale_params = {} 48 | if scaler is not None: 49 | scaled_X = [] 50 | for i, x in enumerate(X): 51 | scaled_X.append(scaler.fit_transform(x)) 52 | scale_params[i]=scaler.get_params() 53 | else: 54 | scaled_X = X 55 | 56 | inputs, inputs_ids = [], [] 57 | for seq_id in range(len(scaled_X)): 58 | for seq_l in range(len(scaled_X[seq_id])): 59 | y_seq = Y[seq_id] 60 | if len(y_seq.shape)==1: 61 | y_seq = np.expand_dims(y_seq,-1) 62 | if seq_l%self.max_input_points == 0: 63 | inputs.append([]) 64 | inputs_ids.append(seq_id) 65 | inputs[-1].append([scaled_X[seq_id][seq_l], y_seq[seq_l]]) 66 | 67 | if self.max_number_bags>0: 68 | inputs = inputs[:self.max_number_bags] 69 | inputs_ids = inputs_ids[:self.max_number_bags] 70 | 71 | forward_time=time.time() 72 | outputs = self.model(inputs) ##Forward transformer: returns predicted functions 73 | if verbose: print("Finished forward in {} secs".format(time.time()-forward_time)) 74 | 75 | candidates = defaultdict(list) 76 | assert len(inputs) == len(outputs), "Problem with inputs and outputs" 77 | for i in range(len(inputs)): 78 | input_id = inputs_ids[i] 79 | candidate = outputs[i] 80 | candidates[input_id].extend(candidate) 81 | assert len(candidates.keys())==n_datasets 82 | 83 | self.tree = {} 84 | for input_id, candidates_id in candidates.items(): 85 | if len(candidates_id)==0: 86 | self.tree[input_id]=None 87 | continue 88 | 89 | refined_candidates = self.refine(scaled_X[input_id], Y[input_id], candidates_id, verbose=verbose) 90 | for i,candidate in enumerate(refined_candidates): 91 | if scaler is not None: 92 | refined_candidates[i]["predicted_tree"]=scaler.rescale_function(self.model.env, candidate["predicted_tree"], *scale_params[input_id]) 93 | else: 94 | refined_candidates[i]["predicted_tree"]=candidate["predicted_tree"] 95 | self.tree[input_id] = refined_candidates 96 | 97 | @torch.no_grad() 98 | def evaluate_tree(self, tree, X, y, metric): 99 | numexpr_fn = self.model.env.simplifier.tree_to_numexpr_fn(tree) 100 | y_tilde = numexpr_fn(X)[:,0] 101 | metrics = compute_metrics({"true": [y], "predicted": [y_tilde], "predicted_tree": [tree]}, metrics=metric) 102 | return metrics[metric][0] 103 | 104 | def order_candidates(self, X, y, candidates, metric="_mse", verbose=False): 105 | scores = [] 106 | for candidate in candidates: 107 | if metric not in candidate: 108 | score = self.evaluate_tree(candidate["predicted_tree"], X, y, metric) 109 | if math.isnan(score): 110 | score = np.infty if metric.startswith("_") else -np.infty 111 | else: 112 | score = candidates[metric] 113 | scores.append(score) 114 | ordered_idx = np.argsort(scores) 115 | if not metric.startswith("_"): ordered_idx=list(reversed(ordered_idx)) 116 | candidates = [candidates[i] for i in ordered_idx] 117 | return candidates 118 | 119 | def refine(self, X, y, candidates, verbose): 120 | refined_candidates = [] 121 | 122 | ## For skeleton model 123 | for i, candidate in enumerate(candidates): 124 | candidate_skeleton, candidate_constants = self.model.env.generator.function_to_skeleton(candidate, constants_with_idx=True) 125 | if "CONSTANT" in candidate_constants: 126 | candidates[i] = self.model.env.wrap_equation_floats(candidate_skeleton, np.random.randn(len(candidate_constants))) 127 | 128 | candidates = [{"refinement_type": "NoRef", "predicted_tree": 
candidate, "time": time.time()-self.start_fit} for candidate in candidates] 129 | candidates = self.order_candidates(X, y, candidates, metric="_mse", verbose=verbose) 130 | 131 | ## REMOVE SKELETON DUPLICATAS 132 | skeleton_candidates, candidates_to_remove = {}, [] 133 | for i, candidate in enumerate(candidates): 134 | skeleton_candidate, _ = self.model.env.generator.function_to_skeleton(candidate["predicted_tree"], constants_with_idx=False) 135 | if skeleton_candidate.infix() in skeleton_candidates: 136 | candidates_to_remove.append(i) 137 | else: 138 | skeleton_candidates[skeleton_candidate.infix()]=1 139 | if verbose: print("Removed {}/{} skeleton duplicata".format(len(candidates_to_remove), len(candidates))) 140 | 141 | candidates = [candidates[i] for i in range(len(candidates)) if i not in candidates_to_remove] 142 | if self.n_trees_to_refine>0: 143 | candidates_to_refine = candidates[:self.n_trees_to_refine] 144 | else: 145 | candidates_to_refine = copy.deepcopy(candidates) 146 | 147 | for candidate in candidates_to_refine: 148 | refinement_strategy = utils_wrapper.BFGSRefinement() 149 | candidate_skeleton, candidate_constants = self.model.env.generator.function_to_skeleton(candidate["predicted_tree"], constants_with_idx=True) 150 | try: 151 | refined_candidate = refinement_strategy.go(env=self.model.env, 152 | tree=candidate_skeleton, 153 | coeffs0=candidate_constants, 154 | X=X, 155 | y=y, 156 | downsample=1024, 157 | stop_after=self.stop_refinement_after) 158 | 159 | except Exception as e: 160 | if verbose: 161 | print(e) 162 | #traceback.format_exc() 163 | continue 164 | 165 | if refined_candidate is not None: 166 | refined_candidates.append({ 167 | "refinement_type": "BFGS", 168 | "predicted_tree": refined_candidate, 169 | }) 170 | candidates.extend(refined_candidates) 171 | candidates = self.order_candidates(X, y, candidates, metric="r2") 172 | 173 | for candidate in candidates: 174 | if "time" not in candidate: 175 | candidate["time"]=time.time()-self.start_fit 176 | return candidates 177 | 178 | def __str__(self): 179 | if hasattr(self, "tree"): 180 | for tree_idx in range(len(self.tree)): 181 | for gen in self.tree[tree_idx]: 182 | print(gen) 183 | return "Transformer" 184 | 185 | def retrieve_refinements_types(self): 186 | return ["BFGS", "NoRef"] 187 | 188 | def retrieve_tree(self, refinement_type=None, tree_idx=0, with_infos=False): 189 | if tree_idx == -1: idxs = [_ for _ in range(len(self.tree))] 190 | else: idxs = [tree_idx] 191 | best_trees = [] 192 | for idx in idxs: 193 | best_tree = copy.deepcopy(self.tree[idx]) 194 | if best_tree and refinement_type is not None: 195 | best_tree = list(filter(lambda gen: gen["refinement_type"]==refinement_type, best_tree)) 196 | if not best_tree: 197 | if with_infos: 198 | best_trees.append({"predicted_tree": None, "refinement_type": None, "time": None}) 199 | else: 200 | best_trees.append(None) 201 | else: 202 | if with_infos: 203 | best_trees.append(best_tree[0]) 204 | else: 205 | best_trees.append(best_tree[0]["predicted_tree"]) 206 | if tree_idx != -1: return best_trees[0] 207 | else: return best_trees 208 | 209 | 210 | def predict(self, X, refinement_type=None, tree_idx=0, batch=False): 211 | if not isinstance(X, list): 212 | X = [X] 213 | res = [] 214 | if batch: 215 | tree = self.retrieve_tree(refinement_type=refinement_type, tree_idx = -1) 216 | for tree_idx in range(len(tree)): 217 | X_idx = X[tree_idx] 218 | if tree[tree_idx] is None: 219 | res.append(None) 220 | else: 221 | numexpr_fn = 
self.model.env.simplifier.tree_to_numexpr_fn(tree[tree_idx]) 222 | y = numexpr_fn(X_idx)[:,0] 223 | res.append(y) 224 | return res 225 | else: 226 | X_idx = X[tree_idx] 227 | tree = self.retrieve_tree(refinement_type=refinement_type, tree_idx = tree_idx) 228 | if tree is not None: 229 | numexpr_fn = self.model.env.simplifier.tree_to_numexpr_fn(tree) 230 | y = numexpr_fn(X_idx)[:,0] 231 | return y 232 | else: 233 | return None 234 | -------------------------------------------------------------------------------- /model/utils_wrapper.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import sklearn 3 | from scipy.optimize import minimize 4 | import numpy as np 5 | import time 6 | import torch 7 | from functorch import grad 8 | from functools import partial 9 | import traceback 10 | 11 | class TimedFun: 12 | def __init__(self, fun, verbose=False, stop_after=3): 13 | self.fun_in = fun 14 | self.started = False 15 | self.stop_after = stop_after 16 | self.best_fun_value = np.infty 17 | self.best_x = None 18 | self.loss_history=[] 19 | self.verbose = verbose 20 | 21 | def fun(self, x, *args): 22 | if self.started is False: 23 | self.started = time.time() 24 | elif abs(time.time() - self.started) >= self.stop_after: 25 | self.loss_history.append(self.best_fun_value) 26 | raise ValueError("Time is over.") 27 | self.fun_value = self.fun_in(x, *args) 28 | self.loss_history.append(self.fun_value) 29 | if self.best_x is None: 30 | self.best_x=x 31 | elif self.fun_value < self.best_fun_value: 32 | self.best_fun_value=self.fun_value 33 | self.best_x=x 34 | self.x = x 35 | return self.fun_value 36 | 37 | class Scaler(ABC): 38 | """ 39 | Base class for scalers 40 | """ 41 | 42 | def __init__(self): 43 | pass 44 | 45 | @abstractmethod 46 | def fit(self, X): 47 | pass 48 | 49 | @abstractmethod 50 | def fit_transform(self, X): 51 | pass 52 | 53 | @abstractmethod 54 | def transform(self, X): 55 | pass 56 | 57 | @abstractmethod 58 | def get_params(self): 59 | pass 60 | 61 | def rescale_function(self, env, tree, a, b): 62 | prefix = tree.prefix().split(",") 63 | idx = 0 64 | while idx < len(prefix): 65 | if prefix[idx].startswith("x_"): 66 | k = int(prefix[idx][-1]) 67 | if k>=len(a): 68 | continue 69 | a_k, b_k = str(a[k]), str(b[k]) 70 | prefix_to_add = ["add", b_k, "mul", a_k, prefix[idx]] 71 | prefix = prefix[:idx] + prefix_to_add + prefix[min(idx + 1, len(prefix)):] 72 | idx += len(prefix_to_add) 73 | else: 74 | idx+=1 75 | continue 76 | rescaled_tree = env.word_to_infix(prefix, is_float=False, str_array=False) 77 | return rescaled_tree 78 | 79 | class StandardScaler(Scaler): 80 | def __init__(self): 81 | """ 82 | transformation is: 83 | x' = (x - mean)/std 84 | """ 85 | self.scaler = sklearn.preprocessing.StandardScaler() 86 | 87 | def fit(self, X): 88 | self.scaler.fit(X) 89 | 90 | def fit_transform(self, X): 91 | scaled_X = self.scaler.fit_transform(X) 92 | return scaled_X 93 | 94 | def transform(self, X): 95 | m, s = self.scaler.mean_, np.sqrt(self.scaler.var_) 96 | return (X-m)/s 97 | 98 | def get_params(self): 99 | m, s = self.scaler.mean_, np.sqrt(self.scaler.var_) 100 | a, b = 1/s, -m/s 101 | return (a, b) 102 | 103 | class MinMaxScaler(Scaler): 104 | def __init__(self): 105 | """ 106 | transformation is: 107 | x' = 2.*(x-xmin)/(xmax-xmin)-1. 
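equivalently x' = a*x + b with a = 2/(xmax - xmin) and b = -1 - 2*xmin/(xmax - xmin), which is what get_params() returns.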
108 | """ 109 | self.scaler = sklearn.preprocessing.MinMaxScaler(feature_range=(-1,1)) 110 | 111 | def fit(self, X): 112 | self.scaler.fit(X) 113 | 114 | def fit_transform(self, X): 115 | scaled_X = self.scaler.fit_transform(X) 116 | return scaled_X 117 | 118 | def transform(self, X): 119 | val_min, val_max = self.scaler.data_min_, self.scaler.data_max_ 120 | return 2*(X-val_min)/(val_max-val_min)-1. 121 | 122 | def get_params(self): 123 | val_min, val_max = self.scaler.data_min_, self.scaler.data_max_ 124 | a, b = 2./(val_max-val_min), -1.-2.*val_min/(val_max-val_min) 125 | return (a, b) 126 | 127 | class BFGSRefinement(): 128 | """ 129 | Wrapper around scipy's BFGS solver 130 | """ 131 | 132 | def __init__(self): 133 | """ 134 | Args: 135 | func: a PyTorch function that maps dependent variabels and 136 | parameters to function outputs for all data samples 137 | `func(x, coeffs) -> y` 138 | x, y: problem data as PyTorch tensors. Shape of x is (d, n) and 139 | shape of y is (n,) 140 | """ 141 | super().__init__() 142 | 143 | def go( 144 | self, env, tree, coeffs0, X, y, downsample=-1, stop_after=10 145 | ): 146 | 147 | func = env.simplifier.tree_to_torch_module(tree, dtype=torch.float64) 148 | self.X, self.y = X, y 149 | if downsample>0: 150 | self.X = self.X[:downsample] 151 | self.y = self.y[:downsample] 152 | self.X=torch.tensor(self.X, dtype=torch.float64, requires_grad=False) 153 | self.y=torch.tensor(self.y, dtype=torch.float64, requires_grad=False) 154 | self.func = partial(func, self.X) 155 | 156 | def objective_torch(coeffs): 157 | """ 158 | Compute the non-linear least-squares objective value 159 | objective(coeffs) = (1/2) sum((y - func(coeffs)) ** 2) 160 | Returns a PyTorch tensor. 161 | """ 162 | if not isinstance(coeffs, torch.Tensor): 163 | coeffs = torch.tensor(coeffs, dtype=torch.float64, requires_grad=True) 164 | y_tilde = self.func(coeffs) 165 | if y_tilde is None: return None 166 | mse = (self.y -y_tilde).pow(2).mean().div(2) 167 | return mse 168 | 169 | def objective_numpy(coeffs): 170 | """ 171 | Return the objective value as a float (for scipy). 172 | """ 173 | return objective_torch(coeffs).item() 174 | 175 | def gradient_numpy(coeffs): 176 | """ 177 | Compute the gradient of the objective at coeffs. 
178 | Returns a numpy array (for scipy) 179 | """ 180 | if not isinstance(coeffs, torch.Tensor): 181 | coeffs = torch.tensor(coeffs, dtype=torch.float64, requires_grad=True) 182 | grad_obj = grad(objective_torch)(coeffs) 183 | return grad_obj.detach().numpy() 184 | 185 | objective_numpy_timed = TimedFun(objective_numpy, stop_after=stop_after) 186 | 187 | try: 188 | minimize( 189 | objective_numpy_timed.fun, 190 | coeffs0, 191 | method="BFGS", 192 | jac=gradient_numpy, 193 | options = {"disp": False} 194 | ) 195 | except ValueError as e: 196 | traceback.format_exc() 197 | best_constants = objective_numpy_timed.best_x 198 | return env.wrap_equation_floats(tree, best_constants) -------------------------------------------------------------------------------- /quick_start_tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tokenizer.tokenizerAPI import ( 3 | tokenizerAPI_IR2T, 4 | tokenizerAPI_OR2T, 5 | tokenizerAPI_IT2R, 6 | tokenizerAPI_OT2R, 7 | tokenizerAPI_IT2N, 8 | tokenizerAPI_OT2N, 9 | tokenizerAPI_IN2T, 10 | tokenizerAPI_ON2T, 11 | tokenizerAPI_IN2R, 12 | tokenizerAPI_ON2R, 13 | tokenizerAPI_IR2N, 14 | tokenizerAPI_OR2N, 15 | ) 16 | 17 | 18 | def myprint(*x): 19 | ''' 20 | print both to terminal and an offline file 21 | ''' 22 | print(*x) 23 | print(*x, file=open('quick_start_tokenizer_output.txt', 'a')) 24 | 25 | sep = '\n\n________________________________\n' 26 | 27 | 28 | 29 | example_code_string_representation = ''' 30 | # Leetcode problem 7. Reverse Integer: 31 | # https://leetcode.com/problems/reverse-integer/description/ 32 | class Solution(object): 33 | def reverse(self, x): 34 | reverse = 0 35 | sign = -1 if x < 0 else 1 36 | x = abs(x) 37 | while x: 38 | digit = x % 10 39 | reverse = reverse * 10 + digit 40 | x /= 10 41 | result = sign * reverse 42 | if result > 2 ** 31 - 1 or result < -(2 ** 31): 43 | return 0 44 | return result 45 | ''' 46 | # example_io_data = [[1,4,'hello'], [True,False]] 47 | example_io_data = [1,4,'hello'] 48 | 49 | if os.path.exists('quick_start_tokenizer_output.txt'): 50 | os.remove('quick_start_tokenizer_output.txt') 51 | 52 | 53 | myprint(f'{sep}demo code:\n{example_code_string_representation}') 54 | 55 | # ---- Visualization for: python code string -> S3 and S4 subsequences 56 | syntax_token_S3, content_tokens_S4 = tokenizerAPI_OR2T(example_code_string_representation) 57 | myprint(f'{sep}syntax_token_S3:') 58 | for x in syntax_token_S3: 59 | myprint(x) 60 | myprint(f'{sep}content_tokens_S4:') 61 | for x in content_tokens_S4: 62 | myprint(x) 63 | 64 | 65 | # ---- Visualization for: S3/S4 -> integer sequence 66 | int_seq = tokenizerAPI_OT2N(syntax_token_S3, content_tokens_S4) 67 | myprint(f'{sep}integer sequence:\n{int_seq}') 68 | 69 | # ---- Visualization for combined one-step: python code string -> integer sequence 70 | int_seq = tokenizerAPI_OR2N(example_code_string_representation) 71 | myprint(f'{sep}integer sequence:\n{int_seq}') 72 | 73 | # ---- convert back to python string 74 | py_code_string = tokenizerAPI_ON2R(int_seq) 75 | myprint(f'{sep}Python code converted back:{py_code_string}') 76 | 77 | # ---- Visualization for: io_obj -> token represnetations 78 | myprint(f'{sep}I/O python object:\n{example_io_data}') 79 | syn_IO, cont_IO = tokenizerAPI_IR2T(example_io_data) 80 | myprint(f'{sep}IO data tokens:\n{syn_IO}\n{cont_IO}\n') 81 | 82 | # ---- Visualization for: IO data -> int sequence -> convert back to IO data 83 | # io_data_back = 
tokenizerAPI_IN2R(tokenizerAPI_IR2N(example_io_data)) 84 | # myprint(f'{sep}IO data converted back:\n{io_data_back}') 85 | 86 | 87 | os.system(f"open quick_start_tokenizer_output.txt") 88 | 89 | -------------------------------------------------------------------------------- /quick_start_tokenizer_output.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | ________________________________ 4 | demo code: 5 | 6 | # Leetcode problem 7. Reverse Integer: 7 | # https://leetcode.com/problems/reverse-integer/description/ 8 | class Solution(object): 9 | def reverse(self, x): 10 | reverse = 0 11 | sign = -1 if x < 0 else 1 12 | x = abs(x) 13 | while x: 14 | digit = x % 10 15 | reverse = reverse * 10 + digit 16 | x /= 10 17 | result = sign * reverse 18 | if result > 2 ** 31 - 1 or result < -(2 ** 31): 19 | return 0 20 | return result 21 | 22 | 23 | 24 | ________________________________ 25 | syntax_token_S3: 26 | Module(body=[ClassDef(name= 27 | ,bases=[Name(id= 28 | ,ctx=Load())],keywords=[],body=[FunctionDef(name= 29 | ,args=arguments(posonlyargs=[],args=[arg(arg= 30 | ,annotation=None,type_comment=None),arg(arg= 31 | ,annotation=None,type_comment=None)],vararg=None,kwonlyargs=[],kw_defaults=[],kwarg=None,defaults=[]),body=[Assign(targets=[Name(id= 32 | ,ctx=Store())],value=Constant(value= 33 | ,kind=None),type_comment=None),Assign(targets=[Name(id= 34 | ,ctx=Store())],value=IfExp(test=Compare(left=Name(id= 35 | ,ctx=Load()),ops=[ 36 | ],comparators=[Constant(value= 37 | ,kind=None)]),body=UnaryOp(op= 38 | ,operand=Constant(value= 39 | ,kind=None)),orelse=Constant(value= 40 | ,kind=None)),type_comment=None),Assign(targets=[Name(id= 41 | ,ctx=Store())],value=Call(func=Name(id= 42 | ,ctx=Load()),args=[Name(id= 43 | ,ctx=Load())],keywords=[]),type_comment=None),While(test=Name(id= 44 | ,ctx=Load()),body=[Assign(targets=[Name(id= 45 | ,ctx=Store())],value=BinOp(left=Name(id= 46 | ,ctx=Load()),op= 47 | ,right=Constant(value= 48 | ,kind=None)),type_comment=None),Assign(targets=[Name(id= 49 | ,ctx=Store())],value=BinOp(left=BinOp(left=Name(id= 50 | ,ctx=Load()),op= 51 | ,right=Constant(value= 52 | ,kind=None)),op= 53 | ,right=Name(id= 54 | ,ctx=Load())),type_comment=None),AugAssign(target=Name(id= 55 | ,ctx=Store()),op= 56 | ,value=Constant(value= 57 | ,kind=None))],orelse=[]),Assign(targets=[Name(id= 58 | ,ctx=Store())],value=BinOp(left=Name(id= 59 | ,ctx=Load()),op= 60 | ,right=Name(id= 61 | ,ctx=Load())),type_comment=None),If(test=BoolOp(op= 62 | ,values=[Compare(left=Name(id= 63 | ,ctx=Load()),ops=[ 64 | ],comparators=[BinOp(left=BinOp(left=Constant(value= 65 | ,kind=None),op= 66 | ,right=Constant(value= 67 | ,kind=None)),op= 68 | ,right=Constant(value= 69 | ,kind=None))]),Compare(left=Name(id= 70 | ,ctx=Load()),ops=[ 71 | ],comparators=[UnaryOp(op= 72 | ,operand=BinOp(left=Constant(value= 73 | ,kind=None),op= 74 | ,right=Constant(value= 75 | ,kind=None)))])]),body=[Return(value=Constant(value= 76 | ,kind=None))],orelse=[]),Return(value=Name(id= 77 | ,ctx=Load()))],decorator_list=[],returns=None,type_comment=None)],decorator_list=[])],type_ignores=[]) 78 | 79 | 80 | ________________________________ 81 | content_tokens_S4: 82 | Class_0 83 | object 84 | reverse 85 | self 86 | x 87 | reverse 88 | 0 89 | var_0 90 | x 91 | Lt() 92 | 0 93 | USub() 94 | 1 95 | 1 96 | x 97 | abs 98 | x 99 | x 100 | var_1 101 | x 102 | Mod() 103 | 10 104 | reverse 105 | reverse 106 | Mult() 107 | 10 108 | Add() 109 | var_1 110 | x 111 | Div() 112 | 10 113 | var_2 114 | var_0 115 | 
Mult() 116 | reverse 117 | Or() 118 | var_2 119 | Gt() 120 | 2 121 | Pow() 122 | 31 123 | Sub() 124 | 1 125 | var_2 126 | Lt() 127 | USub() 128 | 2 129 | Pow() 130 | 31 131 | 0 132 | var_2 133 | 134 | 135 | ________________________________ 136 | integer sequence: 137 | [[2397, 112], [1814, 510], [1853, 529], [1282, 1936], [1283, 672], [1284, 529], [1406, 881], [1407, 570], [2446, 672], [1321, 19], [1409, 881], [11219, 34], [1325, 983], [5441, 983], [1401, 672], [1285, 342], [1286, 672], [4210, 672], [1302, 571], [1392, 672], [1310, 22], [1311, 993], [1401, 529], [2267, 529], [1310, 23], [1311, 993], [1992, 0], [1316, 571], [1785, 672], [1355, 8], [1307, 993], [1529, 582], [1392, 570], [1310, 23], [1316, 529], [2308, 27], [1486, 582], [1321, 11], [1791, 1110], [1431, 29], [1311, 1141], [1992, 32], [1311, 983], [5312, 582], [1321, 19], [2145, 34], [4722, 1110], [1431, 29], [1311, 1141], [67442, 881], [1520, 582], [51212, 86]] 138 | 139 | 140 | ________________________________ 141 | integer sequence: 142 | [[2397, 112], [1814, 510], [1853, 529], [1282, 1936], [1283, 672], [1284, 529], [1406, 881], [1407, 570], [2446, 672], [1321, 19], [1409, 881], [11219, 34], [1325, 983], [5441, 983], [1401, 672], [1285, 342], [1286, 672], [4210, 672], [1302, 571], [1392, 672], [1310, 22], [1311, 993], [1401, 529], [2267, 529], [1310, 23], [1311, 993], [1992, 0], [1316, 571], [1785, 672], [1355, 8], [1307, 993], [1529, 582], [1392, 570], [1310, 23], [1316, 529], [2308, 27], [1486, 582], [1321, 11], [1791, 1110], [1431, 29], [1311, 1141], [1992, 32], [1311, 983], [5312, 582], [1321, 19], [2145, 34], [4722, 1110], [1431, 29], [1311, 1141], [67442, 881], [1520, 582], [51212, 86]] 143 | 144 | 145 | ________________________________ 146 | Python code converted back: 147 | 148 | class Class_0(object): 149 | 150 | def reverse(self, x): 151 | reverse = 0 152 | var_0 = ((- 1) if (x < 0) else 1) 153 | x = abs(x) 154 | while x: 155 | var_1 = (x % 10) 156 | reverse = ((reverse * 10) + var_1) 157 | x /= 10 158 | var_2 = (var_0 * reverse) 159 | if ((var_2 > ((2 ** 31) - 1)) or (var_2 < (- (2 ** 31)))): 160 | return 0 161 | return var_2 162 | 163 | 164 | 165 | 166 | ________________________________ 167 | I/O python object: 168 | [1, 4, 'hello'] 169 | 170 | 171 | ________________________________ 172 | IO data tokens: 173 | ['Module(body=[Expr(value=List(elts=[Constant(value=', ',kind=None),Constant(value=', ',kind=None),Constant(value=', '<<||syntax_on_hold_for_content_token||>>', '<<||syntax_on_hold_for_content_token||>>', '<<||syntax_on_hold_for_content_token||>>', '<<||syntax_on_hold_for_content_token||>>', ',kind=None)],ctx=Load()))],type_ignores=[])'] 174 | [1, 4, 'h', 'e', 'l', 'l', 'o'] 175 | 176 | -------------------------------------------------------------------------------- /tokenizer/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/ChainCoder/a9ae00758fe9185c7f13f9b9dc591d05de0d4445/tokenizer/.DS_Store -------------------------------------------------------------------------------- /tokenizer/astunparse/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import absolute_import 3 | from six.moves import cStringIO 4 | from .unparser import Unparser 5 | from .printer import Printer 6 | 7 | 8 | __version__ = '1.6.3' 9 | 10 | 11 | def unparse(tree): 12 | v = cStringIO() 13 | Unparser(tree, file=v) 14 | return v.getvalue() 15 | 16 | 17 | def 
dump(tree): 18 | v = cStringIO() 19 | Printer(file=v).visit(tree) 20 | return v.getvalue() 21 | -------------------------------------------------------------------------------- /tokenizer/astunparse/__main__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import argparse 5 | from .unparser import roundtrip 6 | from . import dump 7 | 8 | 9 | def roundtrip_recursive(target, dump_tree=False): 10 | if os.path.isfile(target): 11 | print(target) 12 | print("=" * len(target)) 13 | if dump_tree: 14 | dump(target) 15 | else: 16 | roundtrip(target) 17 | print() 18 | elif os.path.isdir(target): 19 | for item in os.listdir(target): 20 | if item.endswith(".py"): 21 | roundtrip_recursive(os.path.join(target, item), dump_tree) 22 | else: 23 | print( 24 | "WARNING: skipping '%s', not a file or directory" % target, 25 | file=sys.stderr 26 | ) 27 | 28 | 29 | def main(args): 30 | parser = argparse.ArgumentParser(prog="astunparse") 31 | parser.add_argument( 32 | 'target', 33 | nargs='+', 34 | help="Files or directories to show roundtripped source for" 35 | ) 36 | parser.add_argument( 37 | '--dump', 38 | type=bool, 39 | help="Show a pretty-printed AST instead of the source" 40 | ) 41 | 42 | arguments = parser.parse_args(args) 43 | for target in arguments.target: 44 | roundtrip_recursive(target, dump_tree=arguments.dump) 45 | 46 | 47 | if __name__ == "__main__": 48 | main(sys.argv[1:]) 49 | -------------------------------------------------------------------------------- /tokenizer/astunparse/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/ChainCoder/a9ae00758fe9185c7f13f9b9dc591d05de0d4445/tokenizer/astunparse/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tokenizer/astunparse/__pycache__/printer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/ChainCoder/a9ae00758fe9185c7f13f9b9dc591d05de0d4445/tokenizer/astunparse/__pycache__/printer.cpython-38.pyc -------------------------------------------------------------------------------- /tokenizer/astunparse/__pycache__/unparser.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/ChainCoder/a9ae00758fe9185c7f13f9b9dc591d05de0d4445/tokenizer/astunparse/__pycache__/unparser.cpython-38.pyc -------------------------------------------------------------------------------- /tokenizer/astunparse/printer.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | import sys 3 | import ast 4 | import six 5 | 6 | 7 | class Printer(ast.NodeVisitor): 8 | 9 | def __init__(self, file=sys.stdout, indent=" "): 10 | self.indentation = 0 11 | self.indent_with = indent 12 | self.f = file 13 | 14 | # overridden to make the API obvious 15 | def visit(self, node): 16 | super(Printer, self).visit(node) 17 | 18 | def write(self, text): 19 | self.f.write(six.text_type(text)) 20 | 21 | def generic_visit(self, node): 22 | 23 | if isinstance(node, list): 24 | nodestart = "[" 25 | nodeend = "]" 26 | children = [("", child) for child in node] 27 | else: 28 | nodestart = type(node).__name__ + "(" 29 | nodeend = ")" 30 | children = [(name + "=", value) for 
name, value in ast.iter_fields(node)] 31 | 32 | if len(children) > 1: 33 | self.indentation += 1 34 | 35 | self.write(nodestart) 36 | for i, pair in enumerate(children): 37 | attr, child = pair 38 | if len(children) > 1: 39 | self.write("\n" + self.indent_with * self.indentation) 40 | if isinstance(child, (ast.AST, list)): 41 | self.write(attr) 42 | self.visit(child) 43 | else: 44 | self.write(attr + repr(child)) 45 | 46 | if i != len(children) - 1: 47 | self.write(",") 48 | self.write(nodeend) 49 | 50 | if len(children) > 1: 51 | self.indentation -= 1 52 | -------------------------------------------------------------------------------- /tokenizer/tokenizerAPI.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import numpy as np 3 | 4 | from tokenizer.tokenization_algorithm import ( 5 | vocabulary_defs, 6 | python_repr_tokenizer, 7 | pre_api_token_to_repr, 8 | pre_api_int2token, 9 | tonp, 10 | load_txt, 11 | closest, 12 | segment_then_put_in_template, 13 | robust_tokenization_after_fix_vocab 14 | ) 15 | 16 | VERBOSE = False 17 | 18 | 19 | # 🟩 🟩 🟩 🟩 🟩 🟩 🟩 🟩 20 | # 🟩 For external useage, only use the following 12 functions, to avoid complicated repr()/eval() conversion. 21 | # 🟩 I/O means: transformer input/output side, which are iodata/program 22 | # 🟩 R,T,N means: 23 | # 🟩 R: the human readable obj (not exactly repr); for program it is string repr, for io data it is actually python Obj, NOT string repr. 24 | # 🟩 T: tokens, len(syn_tokens) = len(cont_tokens) + 1 25 | # 🟩 N: integer encoded tokens, equal length. 26 | # 🟩 🟩 🟩 🟩 🟩 🟩 🟩 🟩 27 | def tokenizerAPI_IR2T(py_obj): 28 | py_repr = repr(py_obj) 29 | syn_tokens, cont_tokens = python_repr_tokenizer(py_repr, is_iodata=True) 30 | return syn_tokens, cont_tokens 31 | def tokenizerAPI_OR2T(py_repr, coarse_level=False): 32 | if not robust_tokenization_after_fix_vocab: 33 | syn_tokens, cont_tokens = python_repr_tokenizer(py_repr, is_iodata=False, coarse_level=coarse_level) 34 | return syn_tokens, cont_tokens 35 | else: 36 | try: 37 | syn_tokens, cont_tokens = python_repr_tokenizer(py_repr, is_iodata=False, coarse_level=coarse_level) 38 | return syn_tokens, cont_tokens 39 | except: 40 | return [], [] 41 | def tokenizerAPI_IT2R(syn_tokens, cont_tokens): 42 | recov_repr = pre_api_token_to_repr(syn_tokens, cont_tokens) 43 | recov_obj = eval(recov_repr, {'inf': float('inf'), 'nan': float('nan')}) 44 | return recov_obj 45 | def tokenizerAPI_OT2R(syn_tokens, cont_tokens): 46 | recov_repr = pre_api_token_to_repr(syn_tokens, cont_tokens) 47 | return recov_repr 48 | def tokenizerAPI_IT2N(syn_tokens, cont_tokens): 49 | cont_tokens += [vocabulary_defs.content_final_chasing_syntax_token] 50 | syn_intSeq = [vocabulary_defs.token2int_I(t, 'syn') for t in syn_tokens] 51 | cont_intSeq = [vocabulary_defs.token2int_I(t, 'cont') for t in cont_tokens] 52 | int_seq = np.array(list(zip(syn_intSeq, cont_intSeq))).tolist() 53 | return int_seq 54 | def tokenizerAPI_OT2N(syn_tokens, cont_tokens, coarse_level=False): 55 | cont_tokens += [vocabulary_defs.content_final_chasing_syntax_token] 56 | if coarse_level: 57 | syn_intSeq = [vocabulary_defs.token2int_O(t, 'syn') for t in syn_tokens[1:]] 58 | cont_intSeq = [vocabulary_defs.token2int_O(t, 'cont') for t in cont_tokens[1:]] 59 | int_seq = np.array(list(zip(syn_intSeq, cont_intSeq))).tolist() 60 | 61 | syn_intSeq = [vocabulary_defs.token2int_O(t, 'syn') for t in syn_tokens[0]] 62 | cont_intSeq = [vocabulary_defs.token2int_O(t, 'cont') for t in cont_tokens[0]] 63 | 
int_seq_c = np.array(list(zip(syn_intSeq, cont_intSeq))).tolist() 64 | int_seq.insert(0, int_seq_c) 65 | else: 66 | syn_intSeq = [vocabulary_defs.token2int_O(t, 'syn') for t in syn_tokens] 67 | cont_intSeq = [vocabulary_defs.token2int_O(t, 'cont') for t in cont_tokens] 68 | int_seq = np.array(list(zip(syn_intSeq, cont_intSeq))).tolist() 69 | return int_seq 70 | def tokenizerAPI_IN2T(int_seq, drop_cross_instance_pad_token=0, drop_cross_sample_pad_token=0): 71 | decoder = vocabulary_defs.int2token_I 72 | syn_tokens, cont_tokens = pre_api_int2token(int_seq, decoder, drop_cross_instance_pad_token, drop_cross_sample_pad_token) 73 | return syn_tokens, cont_tokens 74 | def tokenizerAPI_ON2T(int_seq, drop_cross_instance_pad_token=0, drop_cross_sample_pad_token=0): 75 | decoder = vocabulary_defs.int2token_O 76 | syn_tokens, cont_tokens = pre_api_int2token(int_seq, decoder, drop_cross_instance_pad_token, drop_cross_sample_pad_token) 77 | return syn_tokens, cont_tokens 78 | def tokenizerAPI_IN2R(int_seq): 79 | syn_tokens, cont_tokens = tokenizerAPI_IN2T(int_seq) 80 | recov_repr = tokenizerAPI_IT2R(syn_tokens, cont_tokens) 81 | return recov_repr 82 | def tokenizerAPI_ON2R(int_seq): 83 | syn_tokens, cont_tokens = tokenizerAPI_ON2T(int_seq) 84 | recov_repr = tokenizerAPI_OT2R(syn_tokens, cont_tokens) 85 | return recov_repr 86 | def tokenizerAPI_IR2N(py_repr): 87 | syn_tokens, cont_tokens = tokenizerAPI_IR2T(py_repr) 88 | int_seqs = tokenizerAPI_IT2N(syn_tokens, cont_tokens) 89 | return int_seqs 90 | def tokenizerAPI_OR2N(py_repr): 91 | syn_tokens, cont_tokens = tokenizerAPI_OR2T(py_repr) 92 | int_seqs = tokenizerAPI_OT2N(syn_tokens, cont_tokens) 93 | return int_seqs 94 | def tokenizerAPI_ON2R_flatten(int_seq, coarse_level=False): 95 | """ Used for transformer 1D sequence output. 96 | """ 97 | try: 98 | if coarse_level: 99 | int_seq = tonp(int_seq).reshape(-1).tolist() 100 | which_subseq = 0 101 | s3, s4 = [], [] 102 | for n in int_seq: 103 | if n==vocabulary_defs.s1234_sep_id: 104 | which_subseq += 1 105 | if which_subseq==3: 106 | s3.append(n) 107 | if which_subseq==3: 108 | s4.append(n) 109 | s3, s4 = s3[1:], s4[1:] 110 | transpose = lambda matrix: [[matrix[j][i] for j in range(len(matrix))] for i in range(len(matrix[0]))] 111 | int_seq = transpose([s3,s4]) 112 | recov_repr = tokenizerAPI_ON2R(int_seq) 113 | return recov_repr 114 | else: 115 | int_seq = tonp(int_seq).reshape(-1) 116 | l = len(int_seq) 117 | int_seq = int_seq.reshape(l//2,2).tolist() 118 | recov_repr = tokenizerAPI_ON2R(int_seq) 119 | return recov_repr 120 | except: 121 | return None 122 | def tokenizerAPI_D2N(txt_str): 123 | """ Convert from description texts to ints; paired with APPS pre-trained model. 124 | """ 125 | tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2') 126 | int_list = tokenizer.encode(txt_str, verbose=False) 127 | return int_list 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | # 🟩 🟩 🟩 🟩 🟩 🟩 🟩 🟩 138 | # 🟩 other APIs. 139 | def cross_sample_pad_aligning_syntax(iodata_ist2): 140 | """ padding the io data to align across different samples. The inputs/outputs of this function are all integer encoded sequence. For outputs, once the lengths are aligned, convert each sample into a numpy array - this is the best place to do so. 141 | The algorithm is: 142 | first convert back to obj, decouple back to i/o, then for each elem in i/o, again do tokenize, then, get the lengths of content tokens of each elem, pad with the longest one for every sample. 
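    A minimal sketch of the basic [syn, cont] pair padding that the helper
    `irregular_pad` below performs (toy pad ids assumed here; the real ones are
    vocabulary_defs.input_cross_sample_pad_syn_id / input_cross_sample_pad_cont_id):

        syn_pad, cont_pad = -1, -1                        # assumed toy ids
        instance = [[[5, 7], [6, 8], [9, 4]],             # sample 0: 3 tokens
                    [[5, 7]]]                             # sample 1: 1 token
        target = max(len(s) for s in instance)            # -> 3
        padded = [s + [[syn_pad, cont_pad]] * (target - len(s)) for s in instance]
        # padded[1] == [[5, 7], [-1, -1], [-1, -1]]

    The actual target lengths are chosen per syntax role from the
    per-placeholder content-token counts collected below, rather than from a
    single max as in this sketch.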
143 | Args: 144 | iodata_ist2: a nested list with int values coded tokens at the inner. Shape: [instance, sample, xy tokens, syn or cont subtokens] 145 | Returns: 146 | same shape list, but padded across sample (instance len are still different.) 147 | """ 148 | 149 | cont_pad_id = vocabulary_defs.input_cross_sample_pad_cont_id 150 | syn_pad_id = vocabulary_defs.input_cross_sample_pad_syn_id 151 | 152 | nums_that_samples_give_to_each_placeholder = [] # this is a irregular shaped 3D list with shape: [N_instance, N_sample, N_placeholder in this instance]. For each sample, this variable stores how many content token it gives to each placeholders as specified by this instance. 153 | sample_seqlens = [] 154 | for instance in iodata_ist2: 155 | nums_that_samples_give_to_each_placeholder.append([]) 156 | tmp = [] 157 | for sample_intSeq in instance: 158 | tmp.append(len(sample_intSeq)) 159 | nums_that_samples_give_to_each_placeholder[-1].append([]) 160 | inp, outp = tokenizerAPI_IN2R(sample_intSeq) # this is a list with shape N_sample x 2 (2 is the io) 161 | 162 | for body in list(inp) + list(outp): 163 | retok_syn, retok_cont = tokenizerAPI_IR2T(body) 164 | nums_that_samples_give_to_each_placeholder[-1][-1].append(len(retok_cont)) 165 | 166 | sample_seqlens.append(max(tmp)) 167 | 168 | def irregular_pad(data_st2, pad_target_len): 169 | res = [] 170 | for t2 in data_st2: 171 | res.append( t2 + [[syn_pad_id, cont_pad_id]] * (pad_target_len-len(t2)) ) 172 | return res 173 | 174 | res_ISt2 = [] 175 | for inst_id, instance in enumerate(iodata_ist2): 176 | res_ISt2.append([]) 177 | try: # normally it would be padded according to syntax roles here 178 | sample_devotions = np.asarray(nums_that_samples_give_to_each_placeholder[inst_id]) 179 | max_max = sample_seqlens[inst_id]-1 180 | max_devotions = sample_devotions.max(axis=0) # this is 1-D array with shape [num_placeholders] specified by this instance. It contain int values, meaning the padding target lens for each syntax role in this instance. 181 | if sum(max_devotions)0 202 | 203 | 204 | for i, instance in enumerate(res_ISt2): 205 | res_ISt2[i] = np.array(instance).astype(int) 206 | # assert len(res_ISt2[i])>0 207 | 208 | return res_ISt2 209 | 210 | def cross_instance_pad_io(inp_ist2): 211 | """ Pad the io data across instance, to make sure 1. the num sample are the same, 2. the sample length are the same across samples. The inputs/outputs of this function are all integer encoded sequences. 
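    A minimal shape-level sketch, with hypothetical sizes:

        import numpy as np
        inst_a = np.zeros([3, 5, 2])      # 3 samples, 5 tokens, [syn, cont]
        inst_b = np.zeros([1, 8, 2])      # 1 sample,  8 tokens, [syn, cont]
        padded, samp_lens, tok_lens = cross_instance_pad_io([inst_a, inst_b])
        # padded.shape == (2, 3, 8, 2); cells not covered by an instance hold
        # the cross-instance pad ids; samp_lens == [3, 1], tok_lens == [5, 8]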
212 | Args: 213 | inp_ist2: a irregular list with np.array (i=list, st2=np.array), shape = [instance, sample, token, syn & cont] 214 | Returns: 215 | padded_ist2: np.array, shape = [instance, sample, token, syn & cont] 216 | """ 217 | def fit_small_into_large_3D(small_arr, large_arr): 218 | d0,d1,d2 = small_arr.shape 219 | large_arr[:d0, :d1, :d2] = small_arr 220 | return large_arr 221 | Ni = len(inp_ist2) 222 | cross_inst_pad_syn_id = vocabulary_defs.input_cross_instance_pad_syn_id 223 | cross_inst_pad_cont_id = vocabulary_defs.input_cross_instance_pad_cont_id 224 | 225 | # 🟩 pad for both sample and token len 226 | Ns = max([x.shape[0] for x in inp_ist2]) 227 | Nt = max([x.shape[1] for x in inp_ist2]) 228 | 229 | padded_ist2 = np.zeros([Ni,Ns,Nt,2]) 230 | padded_ist2[..., 0] = cross_inst_pad_syn_id 231 | padded_ist2[..., 1] = cross_inst_pad_cont_id 232 | token_lens_per_inst = [] 233 | samp_lens_per_inst = [] 234 | for inst_id in range(Ni): 235 | st2 = inp_ist2[inst_id] 236 | num_samp, max_tok_len, _ = st2.shape 237 | token_lens_per_inst.append(max_tok_len) 238 | samp_lens_per_inst.append(num_samp) 239 | padded_ist2[inst_id] = fit_small_into_large_3D(st2, padded_ist2[inst_id]) 240 | 241 | return padded_ist2, samp_lens_per_inst, token_lens_per_inst 242 | 243 | 244 | 245 | def cross_instance_pad_code_interleaved(inp_it2): 246 | """ Pad the code data across instance. 247 | Difference with `cross_instance_pad_io`: this one will add BOS/EOS 248 | """ 249 | 250 | Ni = len(inp_it2) 251 | pad_id = vocabulary_defs.output_pad_id 252 | bos = np.array([vocabulary_defs.bos_token_id]) 253 | eos = np.array([vocabulary_defs.eos_token_id]) 254 | 255 | # 🟩 pad for both sample and token len 256 | Nt = max([len(x) for x in inp_it2]) 257 | 258 | padded_it = np.zeros([Ni,Nt*2 + 2]) + pad_id # +2 is for BOS and EOS 259 | lens = [] 260 | 261 | for inst_id in range(Ni): 262 | t2 = np.concatenate([bos, tonp(inp_it2[inst_id]).reshape(-1), eos], axis=0) 263 | padded_it[inst_id][:len(t2)] = t2 264 | lens.append(len(t2)) 265 | 266 | return padded_it, lens 267 | 268 | 269 | def cross_instance_pad_code(inp_it2): 270 | """ Pad the code data across instance. 271 | As defined in get_ChainCoder_dataloader in sttd.py, inp_it2[instance_id] has the structure of: 272 | [ [[coarse_s0, coarse_c0], [coarse_s1, coarse_c1], ... ] , [s0,c0], [s1,c1], [s2,c2], ...] 
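    A minimal sketch of the flattened layout produced below, with hypothetical
    special ids BOS=1, EOS=2, SEP=3 (the real values come from vocabulary_defs):

        instance = [[[10, 20]], [11, 21], [12, 22]]
        # coarse pairs -> s1 = [10],     s2 = [20]
        # fine pairs   -> s3 = [11, 12], s4 = [21, 22]
        # flattened    -> [BOS, 10, SEP, 20, SEP, 11, 12, SEP, 21, 22, EOS]
        #              == [1, 10, 3, 20, 3, 11, 12, 3, 21, 22, 2]

    Any remaining positions up to the padded width are filled with
    vocabulary_defs.output_pad_id.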
273 | """ 274 | 275 | 276 | coarse_seq = [t2[0] for t2 in inp_it2] 277 | inp_it2 = [t2[1:] for t2 in inp_it2] 278 | # inp_it2 = inp_it2[1:] 279 | 280 | Ni = len(inp_it2) 281 | pad_id = vocabulary_defs.output_pad_id 282 | ssep = [vocabulary_defs.s1234_sep_id] 283 | 284 | bos = np.array([vocabulary_defs.bos_token_id]) 285 | eos = np.array([vocabulary_defs.eos_token_id]) 286 | 287 | # 🟩 pad for both sample and token len 288 | Nt = max([len(x)+len(y) for x,y in zip(coarse_seq, inp_it2)]) 289 | 290 | padded_it = np.zeros([Ni, Nt*2 + 2 + 3]) + pad_id # +2 is for BOS and EOS; +3 for s1234 sep 291 | lens = [] 292 | 293 | for inst_id in range(Ni): 294 | s1, s2 = tonp(coarse_seq[inst_id]).transpose(1,0).tolist() 295 | s3, s4 = tonp(inp_it2[inst_id]).transpose(1,0).tolist() 296 | 297 | t2 = np.concatenate([bos,s1,ssep,s2,ssep,s3,ssep,s4,eos], axis=0) 298 | padded_it[inst_id][:len(t2)] = t2 299 | lens.append(len(t2)) 300 | 301 | return padded_it, lens 302 | 303 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from tqdm import tqdm 5 | 6 | from trainer.slurm import init_distributed_mode 7 | from model import build_modules 8 | from trainer.trainer import Trainer 9 | from parsers import get_parser 10 | from tokenizer.tokenizerAPI import vocabulary_defs, closest 11 | from dataloaders.sttd import get_ChainCoder_dataloader, pretrain_loader 12 | 13 | 14 | def main(): 15 | 16 | params = get_parser() 17 | 18 | init_distributed_mode(params) 19 | # CPU / CUDA 20 | if not params.run_on_cpu: 21 | assert torch.cuda.is_available() 22 | 23 | os.makedirs(params.training_ckpt_dump_path, exist_ok=1) 24 | modules = build_modules(vocabulary_defs, params) 25 | trnr = Trainer(modules, vocabulary_defs, params) 26 | 27 | if params.training_resume_ckpt_from != '': 28 | trnr.reload_checkpoint(params.training_resume_ckpt_from) 29 | 30 | 31 | # ---- training 32 | for iepoch in range(params.max_epoch): 33 | 34 | def control_difficulty(params, iepoch): 35 | params.batch_size 36 | 37 | grid_io = [1] + list(range(4, params.training_difficulty_A_io, 4)) # [1, 4, 8, 12, ...] 38 | grid_code = list(range(1,params.training_difficulty_A_code)) 39 | def periodic(i, A, T): 40 | def period_decimal(x): 41 | decimal = x-np.floor(x) 42 | return min(decimal, 1-decimal) 43 | return 2* A * period_decimal(i/T+0.5) 44 | 45 | mixed_io = lambda i: closest(periodic(i, params.training_difficulty_A_io, params.training_difficulty_T_io), grid_io) # range(20) -> [28, 24, 16, 8, 4, 12, 20, 28, 28, 20, 8, 1, 8, 16, ...] 46 | mixed_code = lambda i: closest(periodic(i, params.training_difficulty_A_code, params.training_difficulty_T_code), grid_code) # range(20) -> [7, 7, 6, 5, 5, 4, 3, 2, 1, 1, 1, 1, 2, 3, ...] 
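            # A quick hand check of the schedule (the values below assume A=8,
            # T=8, which may differ from the actual parser defaults):
            # periodic(i, A, T) = 2*A*min(frac(i/T + 0.5), 1 - frac(i/T + 0.5)),
            # a triangle wave that sweeps A -> 0 -> A every T epochs:
            #   i = 0 -> 8.0,  i = 2 -> 4.0,  i = 4 -> 0.0,  i = 6 -> 4.0,  i = 8 -> 8.0
            # `closest` then snaps each value onto grid_io / grid_code, so the
            # number of I/O samples and code samples per instance oscillates
            # between easy and hard settings over training.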
47 | 48 | 49 | params.samples_per_instance_io = mixed_io(iepoch) 50 | params.samples_per_instance_code = mixed_code(iepoch) 51 | return 52 | 53 | control_difficulty(params, iepoch) 54 | 55 | if params.is_pretraining: 56 | trainloader = pretrain_loader(params) 57 | else: 58 | trainloader = get_ChainCoder_dataloader(params, params.pickle_data_root) 59 | 60 | print(f'\n🟨 🟨 🟨 🟨 Training Epoch {iepoch} Start 🟨 🟨 🟨 🟨\n') 61 | 62 | for samples in tqdm(trainloader): 63 | 64 | if samples is None: # for debug purpose, or robust_tokenization_after_fix_vocab, output may be None 65 | continue 66 | trnr.enc_dec_step(samples) 67 | 68 | if iepoch%20==0: 69 | trnr.save_checkpoint(f'epoch-{iepoch}.pth') 70 | 71 | 72 | if __name__ == "__main__": 73 | 74 | main() 75 | -------------------------------------------------------------------------------- /trainer/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/ChainCoder/a9ae00758fe9185c7f13f9b9dc591d05de0d4445/trainer/.DS_Store -------------------------------------------------------------------------------- /trainer/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from datetime import timedelta 4 | 5 | 6 | class LogFormatter: 7 | def __init__(self): 8 | self.start_time = time.time() 9 | 10 | def format(self, record): 11 | elapsed_seconds = round(record.created - self.start_time) 12 | 13 | prefix = "%s - %s - %s" % ( 14 | record.levelname, 15 | time.strftime("%x %X"), 16 | timedelta(seconds=elapsed_seconds), 17 | ) 18 | message = record.getMessage() 19 | message = message.replace("\n", "\n" + " " * (len(prefix) + 3)) 20 | return "%s - %s" % (prefix, message) if message else "" 21 | 22 | 23 | def create_logger(filepath, rank): 24 | """ 25 | Create a logger. 26 | Use a different log file for each process. 
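    A minimal usage sketch (the log path here is hypothetical):

        logger = create_logger("train.log", rank=0)   # rank 1 would write to "train.log-1"
        logger.info("shown on the console and written to the file")
        logger.debug("written to the file only; the console handler is INFO-level")
        logger.reset_time()                           # restart the elapsed-time prefix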
27 | """ 28 | # create log formatter 29 | log_formatter = LogFormatter() 30 | 31 | # create file handler and set level to debug 32 | if filepath is not None: 33 | if rank > 0: 34 | filepath = "%s-%i" % (filepath, rank) 35 | file_handler = logging.FileHandler(filepath, "a") 36 | file_handler.setLevel(logging.DEBUG) 37 | file_handler.setFormatter(log_formatter) 38 | 39 | # create console handler and set level to info 40 | console_handler = logging.StreamHandler() 41 | console_handler.setLevel(logging.INFO) 42 | console_handler.setFormatter(log_formatter) 43 | 44 | # create logger and set level to debug 45 | logger = logging.getLogger() 46 | logger.handlers = [] 47 | logger.setLevel(logging.DEBUG) 48 | logger.propagate = False 49 | if filepath is not None: 50 | logger.addHandler(file_handler) 51 | logger.addHandler(console_handler) 52 | 53 | # reset logger elapsed time 54 | def reset_time(): 55 | log_formatter.start_time = time.time() 56 | 57 | logger.reset_time = reset_time 58 | 59 | return logger 60 | -------------------------------------------------------------------------------- /trainer/metrics.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import r2_score, mean_squared_error 2 | from collections import defaultdict 3 | import numpy as np 4 | import scipy 5 | 6 | def compute_metrics(infos, metrics="r2"): 7 | results = defaultdict(list) 8 | if metrics == "": 9 | return {} 10 | 11 | if "true" in infos: 12 | true, predicted = infos["true"], infos["predicted"] 13 | assert len(true) == len(predicted), "issue with len, true: {}, predicted: {}".format(len(true), len(predicted)) 14 | for i in range(len(true)): 15 | if predicted[i] is None: continue 16 | if len(true[i].shape)==2: 17 | true[i]=true[i][:,0] 18 | if len(predicted[i].shape)==2: 19 | predicted[i]=predicted[i][:,0] 20 | assert true[i].shape == predicted[i].shape, "Problem with shapes: {}, {}".format(true[i].shape, predicted[i].shape) 21 | 22 | for metric in metrics.split(","): 23 | if metric == "r2": 24 | true, predicted = infos["true"], infos["predicted"] 25 | for i in range(len(true)): 26 | if predicted[i] is None or np.isnan(np.min(predicted[i])): 27 | #print(predicted[i]) 28 | results[metric].append(np.nan) 29 | else: 30 | try: 31 | results[metric].append(r2_score(true[i], predicted[i])) 32 | except Exception as e: 33 | #print(e, metric, true[i], predicted[i]) 34 | results[metric].append(np.nan) 35 | if metric == "r2_zero": 36 | true, predicted = infos["true"], infos["predicted"] 37 | for i in range(len(true)): 38 | if predicted[i] is None or np.isnan(np.min(predicted[i])): 39 | #print(predicted[i]) 40 | results[metric].append(np.nan) 41 | else: 42 | try: 43 | results[metric].append(max(0, r2_score(true[i], predicted[i]))) 44 | except Exception as e: 45 | #print(e, metric, true[i], predicted[i]) 46 | results[metric].append(np.nan) 47 | 48 | elif metric.startswith("accuracy_l1"): 49 | if metric == "accuracy_l1": 50 | atol, rtol = 0.0, 0.1 51 | tolerance_point = 0.95 52 | elif metric == "accuracy_l1_biggio": 53 | ## default is biggio et al. 
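                # A prediction counts as correct under any accuracy_l1 variant
                # when at least `tolerance_point` (0.95) of its points satisfy
                # np.isclose, i.e. |predicted - true| <= atol + rtol * |true|.
                # For names of the form "accuracy_l1_<rtol>" (the generic else
                # branch below), atol stays 0 and rtol is parsed from the
                # suffix, e.g. "accuracy_l1_0.1".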
54 | atol, rtol = 1e-3, 0.05 55 | tolerance_point = 0.95 56 | else: 57 | atol = 0 #float(metric.split("_")[-3]) 58 | rtol = float(metric.split("_")[-1]) 59 | tolerance_point = 0.95 #float(metric.split("_")[-1]) 60 | 61 | true, predicted = infos["true"], infos["predicted"] 62 | for i in range(len(true)): 63 | if predicted[i] is None or np.isnan(np.min(predicted[i])): 64 | results[metric].append(np.nan) 65 | else: 66 | try: 67 | is_close = np.isclose(predicted[i], true[i], atol=atol, rtol=rtol) 68 | results[metric].append(float(is_close.mean()>=tolerance_point)) 69 | except Exception as e: 70 | print(e, metric, true[i], predicted[i]) 71 | results[metric].append(np.nan) 72 | 73 | elif metric == "_mse": 74 | true, predicted = infos["true"], infos["predicted"] 75 | for i in range(len(true)): 76 | if predicted[i] is None or np.isnan(np.min(predicted[i])): 77 | results[metric].append(np.nan) 78 | else: 79 | try: 80 | results[metric].append(mean_squared_error(true[i], predicted[i])) 81 | except Exception as e: 82 | results[metric].append(np.nan) 83 | elif metric == "_nmse": 84 | true, predicted = infos["true"], infos["predicted"] 85 | for i in range(len(true)): 86 | if predicted[i] is None or np.isnan(np.min(predicted[i])): 87 | results[metric].append(np.nan) 88 | else: 89 | try: 90 | mean_y = np.mean(true[i]) 91 | NMSE = (np.mean(np.square(true[i]- predicted[i])))/mean_y 92 | results[metric].append(NMSE) 93 | except Exception as e: 94 | results[metric].append(np.nan) 95 | elif metric == "_rmse": 96 | true, predicted = infos["true"], infos["predicted"] 97 | for i in range(len(true)): 98 | if predicted[i] is None or np.isnan(np.min(predicted[i])): 99 | results[metric].append(np.nan) 100 | else: 101 | try: 102 | results[metric].append(mean_squared_error(true[i], predicted[i], squared=False)) 103 | except Exception as e: 104 | results[metric].append(np.nan) 105 | elif metric == "_complexity": 106 | if "predicted_tree" not in infos: 107 | results[metric].extend([np.nan for _ in range(len(infos["true"]))]) 108 | continue 109 | predicted_tree = infos["predicted_tree"] 110 | for i in range(len(predicted_tree)): 111 | if predicted_tree[i] is None: 112 | results[metric].append(np.nan) 113 | else: 114 | results[metric].append(len(predicted_tree[i].prefix().split(","))) 115 | 116 | elif metric == "_relative_complexity": 117 | if "tree" not in infos or "predicted_tree" not in infos: 118 | results[metric].extend([np.nan for _ in range(len(infos["true"]))]) 119 | continue 120 | tree = infos["tree"] 121 | predicted_tree = infos["predicted_tree"] 122 | for i in range(len(predicted_tree)): 123 | if predicted_tree[i] is None: 124 | results[metric].append(np.nan) 125 | else: 126 | results[metric].append(len(predicted_tree[i].prefix().split(",")) - len(tree[i].prefix().split(","))) 127 | 128 | elif metric == "is_symbolic_solution": 129 | 130 | true, predicted = infos["true"], infos["predicted"] 131 | for i in range(len(true)): 132 | if predicted[i] is None or np.isnan(np.min(predicted[i])): 133 | results[metric].append(np.nan) 134 | else: 135 | try: 136 | diff = true[i] - predicted[i] 137 | div = true[i] / (predicted[i] + 1e-100) 138 | std_diff = scipy.linalg.norm( 139 | np.abs(diff - diff.mean(0)) 140 | ) 141 | std_div = scipy.linalg.norm( 142 | np.abs(div - div.mean(0)) 143 | ) 144 | if std_diff<1e-10 and std_div<1e-10: results[metric].append(1.0) 145 | else: results[metric].append(0.0) 146 | except Exception as e: 147 | #print(e, metric, infos["predicted_tree"][i].infix()) 148 | results[metric].append(np.nan) 149 
| 150 | elif metric == "_l1_error": 151 | true, predicted = infos["true"], infos["predicted"] 152 | for i in range(len(true)): 153 | if predicted[i] is None or np.isnan(np.min(predicted[i])): 154 | results[metric].append(np.nan) 155 | else: 156 | try: 157 | l1_error = np.mean(np.abs((true[i] - predicted[i]))) 158 | if np.isnan(l1_error): results[metric].append(np.infty) 159 | else: results[metric].append(l1_error) 160 | except Exception as e: 161 | 162 | results[metric].append(np.nan) 163 | return results 164 | -------------------------------------------------------------------------------- /trainer/slurm.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | import os 3 | import sys 4 | import torch 5 | import socket 6 | import signal 7 | import subprocess 8 | 9 | 10 | logger = getLogger() 11 | 12 | 13 | def sig_handler(signum, frame): 14 | logger.warning("Signal handler called with signal " + str(signum)) 15 | prod_id = int(os.environ["SLURM_PROCID"]) 16 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) 17 | if prod_id == 0: 18 | logger.warning("Requeuing job " + os.environ["SLURM_JOB_ID"]) 19 | os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"]) 20 | else: 21 | logger.warning("Not the master process, no need to requeue.") 22 | sys.exit(-1) 23 | 24 | 25 | def term_handler(signum, frame): 26 | logger.warning("Signal handler called with signal " + str(signum)) 27 | logger.warning("Bypassing SIGTERM.") 28 | 29 | 30 | def init_signal_handler(): 31 | """ 32 | Handle signals sent by SLURM for time limit / pre-emption. 33 | """ 34 | signal.signal(signal.SIGUSR1, sig_handler) 35 | signal.signal(signal.SIGTERM, term_handler) 36 | logger.warning("Signal handler installed.") 37 | 38 | 39 | def init_distributed_mode(params): 40 | """ 41 | Handle single and multi-GPU / multi-node / SLURM jobs. 
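    Three launch modes are handled: a SLURM job (SLURM_JOB_ID set and
    --debug_slurm off), a job started with torch.distributed.launch
    (params.local_rank != -1; RANK / WORLD_SIZE / NGPU are read from the
    environment), and a plain single-process run (rank 0, world_size 1).
    A single-node multi-GPU launch could look roughly like the following
    (illustrative command, not taken from the repository):

        NGPU=4 python -m torch.distributed.launch --nproc_per_node=4 train.py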
42 | Initialize the following variables: 43 | - n_nodes 44 | - node_id 45 | - local_rank 46 | - global_rank 47 | - world_size 48 | """ 49 | params.is_slurm_job = "SLURM_JOB_ID" in os.environ and not params.debug_slurm 50 | print("SLURM job: %s" % str(params.is_slurm_job)) 51 | 52 | # SLURM job 53 | if params.is_slurm_job: 54 | 55 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM 56 | 57 | SLURM_VARIABLES = [ 58 | "SLURM_JOB_ID", 59 | "SLURM_JOB_NODELIST", 60 | "SLURM_JOB_NUM_NODES", 61 | "SLURM_NTASKS", 62 | "SLURM_TASKS_PER_NODE", 63 | "SLURM_MEM_PER_NODE", 64 | "SLURM_MEM_PER_CPU", 65 | "SLURM_NODEID", 66 | "SLURM_PROCID", 67 | "SLURM_LOCALID", 68 | "SLURM_TASK_PID", 69 | ] 70 | 71 | PREFIX = "%i - " % int(os.environ["SLURM_PROCID"]) 72 | for name in SLURM_VARIABLES: 73 | value = os.environ.get(name, None) 74 | print(PREFIX + "%s: %s" % (name, str(value))) 75 | 76 | # # job ID 77 | # params.job_id = os.environ['SLURM_JOB_ID'] 78 | 79 | # number of nodes / node ID 80 | params.n_nodes = int(os.environ["SLURM_JOB_NUM_NODES"]) 81 | params.node_id = int(os.environ["SLURM_NODEID"]) 82 | 83 | # local rank on the current node / global rank 84 | params.local_rank = int(os.environ["SLURM_LOCALID"]) 85 | params.global_rank = int(os.environ["SLURM_PROCID"]) 86 | 87 | # number of processes / GPUs per node 88 | params.world_size = int(os.environ["SLURM_NTASKS"]) 89 | params.n_gpu_per_node = params.world_size // params.n_nodes 90 | 91 | # define master address and master port 92 | hostnames = subprocess.check_output( 93 | ["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]] 94 | ) 95 | params.master_addr = hostnames.split()[0].decode("utf-8") 96 | # assert 10001 <= params.master_port <= 20000 or params.world_size == 1 97 | print(PREFIX + "Master address: %s" % params.master_addr) 98 | print(PREFIX + "Master port : %i" % params.master_port) 99 | 100 | # set environment variables for 'env://' 101 | os.environ["MASTER_ADDR"] = params.master_addr 102 | os.environ["MASTER_PORT"] = str(params.master_port) 103 | os.environ["WORLD_SIZE"] = str(params.world_size) 104 | os.environ["RANK"] = str(params.global_rank) 105 | 106 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch 107 | elif params.local_rank != -1: 108 | 109 | assert params.master_port == -1 110 | 111 | # read environment variables 112 | params.global_rank = int(os.environ["RANK"]) 113 | params.world_size = int(os.environ["WORLD_SIZE"]) 114 | params.n_gpu_per_node = int(os.environ["NGPU"]) 115 | 116 | # number of nodes / node ID 117 | params.n_nodes = params.world_size // params.n_gpu_per_node 118 | params.node_id = params.global_rank // params.n_gpu_per_node 119 | 120 | # local job (single GPU) 121 | else: 122 | assert params.local_rank == -1 123 | assert params.master_port == -1 124 | params.n_nodes = 1 125 | params.node_id = 0 126 | params.local_rank = 0 127 | params.global_rank = 0 128 | params.world_size = 1 129 | params.n_gpu_per_node = 1 130 | 131 | # sanity checks 132 | assert params.n_nodes >= 1 133 | assert 0 <= params.node_id < params.n_nodes 134 | assert 0 <= params.local_rank <= params.global_rank < params.world_size 135 | assert params.world_size == params.n_nodes * params.n_gpu_per_node 136 | 137 | # define whether this is the master process / if we are in distributed mode 138 | params.is_master = params.node_id == 0 and params.local_rank == 0 139 | params.multi_node = params.n_nodes > 1 140 | params.multi_gpu = params.world_size > 1 141 | 142 | # summary 143 | 
PREFIX = "%i - " % params.global_rank 144 | print(PREFIX + "Number of nodes: %i" % params.n_nodes) 145 | print(PREFIX + "Node ID : %i" % params.node_id) 146 | print(PREFIX + "Local rank : %i" % params.local_rank) 147 | print(PREFIX + "Global rank : %i" % params.global_rank) 148 | print(PREFIX + "World size : %i" % params.world_size) 149 | print(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node) 150 | print(PREFIX + "Master : %s" % str(params.is_master)) 151 | print(PREFIX + "Multi-node : %s" % str(params.multi_node)) 152 | print(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu)) 153 | print(PREFIX + "Hostname : %s" % socket.gethostname()) 154 | 155 | # set GPU device 156 | if not params.run_on_cpu: 157 | torch.cuda.set_device(params.local_rank) 158 | 159 | # initialize multi-GPU 160 | if params.multi_gpu: 161 | 162 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization 163 | # 'env://' will read these environment variables: 164 | # MASTER_PORT - required; has to be a free port on machine with rank 0 165 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node 166 | # WORLD_SIZE - required; can be set either here, or in a call to init function 167 | # RANK - required; can be set either here, or in a call to init function 168 | 169 | print("Initializing PyTorch distributed ...") 170 | torch.distributed.init_process_group( 171 | init_method="env://", 172 | backend="nccl", 173 | ) 174 | --------------------------------------------------------------------------------