├── .github ├── ISSUE_TEMPLATE │ └── new-puzzle.md └── workflows │ └── codeql-analysis.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── DATASHEET.md ├── ICLR2023 ├── README.md ├── data │ ├── 125M_PAPER_1M_iter_1.txt.gz │ ├── 13B_PAPER_1M_iter_1.txt.gz │ ├── 27B_PAPER_1M_iter_1.txt.gz │ ├── 350M_PAPER_1M_iter_0.txt.gz │ └── Codex_PAPER_1M_iter_0.txt.gz ├── requirements.txt └── src │ ├── babysit.sh │ ├── ds_config_gptneo.json │ ├── fine_tune.py │ ├── fine_tune.sh │ ├── fine_tune1.sh │ ├── gen.py │ ├── gen.sh │ ├── judge.py │ ├── neo_train.py │ ├── preprocess.py │ ├── requirements.txt │ ├── solve.py │ └── utils.py ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── generators ├── ICPC.py ├── IMO.py ├── __init__.py ├── algebra.py ├── basic.py ├── chess.py ├── classic_puzzles.py ├── codeforces.py ├── compression.py ├── conways_game_of_life.py ├── games.py ├── graphs.py ├── human_eval.py ├── lattices.py ├── number_theory.py ├── probability.py ├── problem_constants │ └── rsa_challenges.json ├── study.py ├── trivial_inverse.py └── tutorial.py ├── make_dataset.py ├── notebooks ├── Demo.ipynb ├── Hackathon_puzzles.ipynb ├── Intro.ipynb ├── demo │ ├── demo.py │ └── study_puzzles.py └── fireworks.gif ├── puzzle_generator.py ├── puzzles ├── README.md ├── puzzles.json └── split.json ├── solvers ├── README.md ├── codex │ ├── 138puzzles.json │ ├── 30puzzles.json │ ├── 397puzzles.json │ ├── README.md │ ├── ezlog.py │ ├── lm_solve │ │ ├── __init__.py │ │ ├── gpt_lib.py │ │ ├── judge.py │ │ ├── run.py │ │ └── scratch.py │ ├── requirements.txt │ ├── results │ │ └── results_397_cushman_codex_1k_full.json.gz │ ├── run_codex_experiments.py │ └── utils.py ├── enumerative │ ├── README.md │ ├── challenges │ │ ├── __init__.py │ │ ├── challenge.py │ │ └── solutions.py │ ├── download_pretrained_roberta.sh │ ├── filter_outputs.py │ ├── models │ │ ├── __init__.py │ │ ├── ml_bow_bigram.py │ │ ├── model.py │ │ ├── transformers │ │ │ ├── __init__.py │ │ │ ├── dataset_processor.py │ │ │ ├── 
finetune_transformer.py │ │ │ ├── generate_rule_embeddings.py │ │ │ ├── learn_tokenizer.py │ │ │ ├── neural_classifier.py │ │ │ └── preprocess_pretraining_data.py │ │ └── uniform.py │ ├── requirements.txt │ ├── run_bigram.sh │ ├── run_transformer.sh │ ├── run_uniform.sh │ ├── solve_challenges.py │ ├── top_down.py │ ├── tython │ │ ├── __init__.py │ │ ├── nonterminals.py │ │ ├── parse.py │ │ ├── program.py │ │ └── rules.py │ └── utils │ │ ├── __init__.py │ │ ├── str_utils.py │ │ └── time_utils.py └── gpt3 │ ├── README.md │ ├── create_simple_prompts.py │ ├── ezlog.py │ ├── lm_solve │ ├── __init__.py │ ├── gpt3_lib.py │ ├── judge.py │ └── run.py │ ├── puzzles_with_descriptions.json │ ├── puzzles_with_prompts.json │ ├── requirements.txt │ ├── run_gpt3_experiments.py │ └── utils │ ├── __init__.py │ ├── str_utils.py │ └── time_utils.py └── utils.py /.github/ISSUE_TEMPLATE/new-puzzle.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: New puzzle 3 | about: Create your own puzzle by defining a function 4 | title: New puzzle 5 | labels: New-puzzle 6 | assignees: '' 7 | 8 | --- 9 | 10 | ```python 11 | def sat(x: str): 12 | """optional problem description""" 13 | return "Hello " + x == "Hello world" # change this to your puzzle 14 | ``` 15 | 16 | Solvers, post your solutions in the comments using the following formatting: 17 | ```` 18 |
<details><summary>Reveal solution</summary> 19 | 20 | ```python 21 | def sol(): 22 | return "world" # replace with your solution 23 | ``` 24 | </details>
25 | ```` 26 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ main ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ main ] 20 | schedule: 21 | - cron: '41 10 * * 5' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] 37 | # Learn more: 38 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed 39 | 40 | steps: 41 | - name: Checkout repository 42 | uses: actions/checkout@v2 43 | 44 | # Initializes the CodeQL tools for scanning. 45 | - name: Initialize CodeQL 46 | uses: github/codeql-action/init@v1 47 | with: 48 | languages: ${{ matrix.language }} 49 | # If you wish to specify custom queries, you can do so here or in a config file. 50 | # By default, queries listed here will override any specified in a config file. 51 | # Prefix the list here with "+" to use these queries and those in the config file. 
52 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 53 | 54 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 55 | # If this step fails, then you should remove it and run the build manually (see below) 56 | - name: Autobuild 57 | uses: github/codeql-action/autobuild@v1 58 | 59 | # ℹ️ Command-line programs to run using the OS shell. 60 | # 📚 https://git.io/JvXDl 61 | 62 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 63 | # and modify them (or add more) to build your code if your project 64 | # uses a compiled language 65 | 66 | #- run: | 67 | # make bootstrap 68 | # make release 69 | 70 | - name: Perform CodeQL Analysis 71 | uses: github/codeql-action/analyze@v1 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | .idea 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /ICLR2023/README.md: -------------------------------------------------------------------------------- 1 | This is the code and data for the paper: Language Models can teach themselves to code better 2 | https://arxiv.org/abs/2207.14502 3 | 4 | LICENSE 5 | MIT License - as already specified in the ../LICENSE file of PythonProgrammingPuzzles repo 6 | 7 | GPU USAGE 8 | GPU usage was large , especially for the 2.7B sized model which is ~20X the 125M. 
9 | Data generation takes the most GPU usage and took about 2500 GPU hours for 2.7B (on V100) 10 | Finetuning on the 1M generated data took about 40 GPU hours for 2.7B (on V100) per epoch of finetuning - 10 epochs = 400 GPU hours 11 | Solving the 228-problem test set with 100 attempts using the finetuned 2.7B model took about 4 hours (on V100) 12 | We mostly used V100s, but we used whatever was available, so sometimes T4s and A100s if they were free. 13 | We tried everything at 125M first - debug there and make it work perfectly - then roll out the 1.3B and 2.7B jobs 14 | 15 | DATASETS 16 | The data directory contains the datasets used. We feel the most interesting dataset is data/Codex_PAPER_1M_iter_0.txt.gz 17 | which is generated by Codex and gave the best results when finetuned on. All the datasets are part of our public release. 18 | 19 | SETUP 20 | src/requirements.txt is what we install on our cluster machines - the cluster comes with NVIDIA drivers and a matching PyTorch 21 | ./requirements.txt is what I personally have installed on my local machine, where I tested that this runs - but it has lots of stuff you don't need 22 | So try src/requirements.txt only - and if that doesn't work - then ./requirements.txt has the versions of everything installed on my machine 23 | Getting deepspeed 0.6.1, PyTorch, and the NVIDIA driver versions to match each other was tricky for me on some machines; torch 1.10 and 1.11 both work 24 | 25 | GENERATING/FINETUNING -> run "cd src && ./babysit.sh GPU_INDEX_TO_USE" -> GPU_INDEX_TO_USE=0 typically 26 | src/babysit.sh is the script that generates data and finetunes on that data in a loop, finetuning the GPT-Neo 125M/1.3B/2.7B models 27 | In src/babysit.sh, TEST_LOCAL=1 runs locally on the machine's GPUs, which is great for fast testing, while TEST_LOCAL=0 launches on the cluster, which is slow but has lots of GPUs 28 | Realistically you have to train on a cluster - data generation takes a long time, so having lots of machines all generating data is the feasible
approach. 29 | But given enough time, this will run locally on 1 GPU: about 1 year for 2.7B, or 2 weeks for 125M. 30 | We found generating 75k samples after deduping worked for iteration_0 - finetune on that data. 31 | Then, using that finetuned model in iter_1, data generation happens more quickly - the finetuned model solves many more problems 32 | Repeating that process works well. 33 | On 125M we compared training only on 125M-generated data from iter_0 versus iter_1 versus iter_2 - generating 600K samples for each iteration. 34 | Finetuning on iter_2 data seemed best on the test set: 26.9/228 solved vs iter_1=26.1/228 vs iter_0=22.2/228 35 | With 1M samples of 125M-generated data sampled across all the iterations 0,1,2 we got 26.75/228 36 | We understand why it's faster to generate iter_2 data on a finetuned model - it solves more problems. 37 | But why are the generated puzzles & solutions better for training the model on? 38 | We will explore that more in the future - and try iterating a lot farther than 3 iterations - although our preliminary experiments on 125M show it tops out at 3 iterations 39 | 40 | FINETUNING ONLY -> run "cd src && ./fine_tune1.sh GPU_INDEX_TO_USE" -> GPU_INDEX_TO_USE=0 typically 41 | # ./fine_tune1.sh GPU MODEL_TO_TRAIN EXPERIMENT_NAME_DIRECTORY TRAIN_DATA EPOCHS 42 | This allows repeated finetuning on a specific dataset. 43 | Use this to do a temperature grid search, or try different variations of parameters on a specific dataset.
44 | 45 | Detailed instructions for reproducing experiments: 46 | # Generating Codex data 47 | python gen.py -n=32 -max_tokens=4096 -model_path=openai/code-davinci-002 -model_path_solve=openai/code-cushman-001 -out=../data/codex/iter_0 -seed=2022 48 | 49 | # Measuring codex accuracy via API calls 50 | ./solve2.sh 51 | python solve.py -prefix=../data/train_prefix.txt -attempts=1 -model_path=openai/code-cushman-001 -gpu=0 -fixed_temp=0.8 -out=../data/codex -puzzles=../data/test_228.json -seed=2022 -batch_size=64 52 | 53 | # Producing verified Codex_PAPER_1M_iter_0.txt from the puzzle/solution old style data generated by Codex 54 | python preprocess.py -path=../data/codex/old_verified -f_name=Codex_PAPER_1M_iter_0.txt -max_sols_per_puzzle=8 -old_style_json=True -max_examples=1000000 -include_failures=False -seed=2022 55 | cp ../data/codex/old_verified/Codex_PAPER_1M_iter_0.txt ../data/Codex_PAPER_1M_iter_0.txt 56 | 57 | # Producing unverified Codex_unverified_PAPER_1M_iter_0.txt from the puzzle/solution old style data generated by Codex 58 | python preprocess.py -path=../data/codex/old_unverified -f_name=Codex_unverified_PAPER_1M_iter_0.txt -max_sols_per_puzzle=8 -old_style_json=True -max_examples=1000000 -include_failures=True -seed=2022 59 | cp ../data/codex/old_unverified/Codex_unverified_PAPER_1M_iter_0.txt ../data/Codex_unverified_PAPER_1M_iter_0.txt 60 | 61 | # Producing 125M_PAPER_25K_iter_0.txt from the puzzle/solution new style data 62 | python preprocess.py ../data/125M_PAPER/iter_0 125M_PAPER_25K_iter_0.txt 8 False 25000 False -seed=2022 63 | cp ../data/125M_PAPER/iter_0/125M_PAPER_25K_iter_0.txt ../data/125M_PAPER_25K_iter_0.txt 64 | 65 | # Producing 125M_PAPER_1M_iter_1.txt from the puzzle/solution new style data 66 | python preprocess.py ../data/125M_PAPER/iter_1 125M_PAPER_1M_iter_1.txt 8 False 1000000 False -seed=2022 67 | cp ../data/125M_PAPER/iter_1/125M_PAPER_1M_iter_1.txt ../data/125M_PAPER_1M_iter_1.txt 68 | 69 | # Producing 125M_PAPER_1M_iter_2.txt from
the puzzle/solution new style data 70 | python preprocess.py ../data/125M_PAPER/iter_2 125M_PAPER_1M_iter_2.txt 8 False 1000000 False -seed=2022 71 | cp ../data/125M_PAPER/iter_2/125M_PAPER_1M_iter_2.txt ../data/125M_PAPER_1M_iter_2.txt 72 | 73 | # Producing 13B_PAPER_25K_iter_0.txt from the puzzle/solution new style data 74 | python preprocess.py ../data/13B_PAPER/iter_0 13B_PAPER_25K_iter_0.txt 8 False 25000 False -seed=2022 75 | cp ../data/13B_PAPER/iter_0/13B_PAPER_25K_iter_0.txt ../data/13B_PAPER_25K_iter_0.txt 76 | 77 | # Producing 13B_PAPER_1M_iter_1.txt from the puzzle/solution new style data 78 | python preprocess.py ../data/13B_PAPER/iter_1 13B_PAPER_1M_iter_1.txt 8 False 1000000 False -seed=2022 79 | cp ../data/13B_PAPER/iter_1/13B_PAPER_1M_iter_1.txt ../data/13B_PAPER_1M_iter_1.txt 80 | 81 | # Producing 13B_PAPER_1M_iter_2.txt from the puzzle/solution new style data 82 | python preprocess.py ../data/13B_PAPER/iter_2 13B_PAPER_1M_iter_2.txt 8 False 1000000 False -seed=2022 83 | cp ../data/13B_PAPER/iter_2/13B_PAPER_1M_iter_2.txt ../data/13B_PAPER_1M_iter_2.txt 84 | 85 | # Producing 27B_PAPER_25K_iter_0.txt from the puzzle/solution new style data 86 | python preprocess.py ../data/27B_PAPER/iter_0 27B_PAPER_25K_iter_0.txt 8 False 25000 False -seed=2022 87 | cp ../data/27B_PAPER/iter_0/27B_PAPER_25K_iter_0.txt ../data/27B_PAPER_25K_iter_0.txt 88 | 89 | # Producing 27B_PAPER_1M_iter_1.txt from the puzzle/solution new style data 90 | python preprocess.py ../data/27B_PAPER/iter_1 27B_PAPER_1M_iter_1.txt 8 False 1000000 False -seed=2022 91 | cp ../data/27B_PAPER/iter_1/27B_PAPER_1M_iter_1.txt ../data/27B_PAPER_1M_iter_1.txt 92 | 93 | # Producing 27B_PAPER_1M_iter_2.txt from the puzzle/solution new style data 94 | python preprocess.py ../data/27B_PAPER/iter_2 27B_PAPER_1M_iter_2.txt 8 False 1000000 False -seed=2022 95 | cp ../data/27B_PAPER/iter_2/27B_PAPER_1M_iter_2.txt ../data/27B_PAPER_1M_iter_2.txt 96 | 97 | # Data files produced by babysit.sh -
generating data from gpt-neo-* and Codex 98 | # At the time the experiments were run, Codex wasn't finetunable, so only iteration 0 data was available 99 | Codex_PAPER_1M_iter_0.txt 100 | 125M_PAPER_25K_iter_0.txt 101 | 13B_PAPER_25K_iter_0.txt 102 | 27B_PAPER_25K_iter_0.txt 103 | 125M_PAPER_1M_iter_1.txt 104 | 13B_PAPER_1M_iter_1.txt 105 | 27B_PAPER_1M_iter_1.txt 106 | 125M_PAPER_1M_iter_2.txt 107 | 13B_PAPER_1M_iter_2.txt 108 | 27B_PAPER_1M_iter_2.txt 109 | 110 | # Figure 5 - 3 diagrams - showing the 3 GPT models trained on verified codex vs unverified codex vs baseline 111 | # 5a GPT-NEO 125M 112 | ./fine_tune1.sh 0 125M ft1_Codex_PAPER_1M_iter_0 Codex_PAPER_1M_iter_0.txt 113 | ./fine_tune1.sh 0 125M ft1_Codex_unverified_PAPER_1M_iter_0 Codex_unverified_PAPER_1M_iter_0.txt 114 | ./solve1.sh 0 125M 10 228 115 | # 5b GPT-NEO 13B 116 | ./fine_tune1.sh 0 13B ft1_Codex_PAPER_1M_iter_0 Codex_PAPER_1M_iter_0.txt 117 | ./fine_tune1.sh 0 13B ft1_Codex_unverified_PAPER_1M_iter_0 Codex_unverified_PAPER_1M_iter_0.txt 118 | ./solve1.sh 0 13B 10 228 5 119 | # 5c GPT-NEO 27B 120 | ./fine_tune1.sh 0 27B ft1_Codex_PAPER_1M_iter_0 Codex_PAPER_1M_iter_0.txt 121 | ./fine_tune1.sh 0 27B ft1_Codex_unverified_PAPER_1M_iter_0 Codex_unverified_PAPER_1M_iter_0.txt 122 | ./solve1.sh 0 27B 10 228 5 123 | 124 | # Figure 6 - 3 diagrams - showing test228 Pass@ for the 3 GPT models trained on data from 4 generators (codex and 3 GPT-Neo) and baseline 125 | # 6a - GPT-NEO 125M trained on 4 different datasets and baseline 126 | # ./fine_tune1.sh 0 125M ft1_Codex_PAPER_1M_iter_0 Codex_PAPER_1M_iter_0.txt (dupe of 5a) 127 | ./fine_tune1.sh 0 125M ft1_125M_PAPER_1M_iter_2 125M_PAPER_1M_iter_2.txt 128 | ./fine_tune1.sh 0 125M ft1_13B_PAPER_1M_iter_2 13B_PAPER_1M_iter_2.txt 129 | ./fine_tune1.sh 0 125M ft1_27B_PAPER_1M_iter_2 27B_PAPER_1M_iter_2.txt 130 | 131 | # 6b - GPT-NEO 13B trained on 4 different datasets and baseline 132 | # ./fine_tune1.sh 0 13B ft1_Codex_PAPER_1M_iter_0
Codex_PAPER_1M_iter_0.txt (dupe of 5b) 133 | ./fine_tune1.sh 0 13B ft1_125M_PAPER_1M_iter_2 125M_PAPER_1M_iter_2.txt 134 | ./fine_tune1.sh 0 13B ft1_13B_PAPER_1M_iter_2 13B_PAPER_1M_iter_2.txt 135 | ./fine_tune1.sh 0 13B ft1_27B_PAPER_1M_iter_2 27B_PAPER_1M_iter_2.txt 136 | 137 | # 6c - GPT-NEO 27B trained on 4 different datasets and baseline 138 | # ./fine_tune1.sh 0 27B ft1_Codex_PAPER_1M_iter_0 Codex_PAPER_1M_iter_0.txt (dupe of 5c) 139 | ./fine_tune1.sh 0 27B ft1_125M_PAPER_1M_iter_2 125M_PAPER_1M_iter_2.txt 140 | ./fine_tune1.sh 0 27B ft1_13B_PAPER_1M_iter_2 13B_PAPER_1M_iter_2.txt 141 | ./fine_tune1.sh 0 27B ft1_27B_PAPER_1M_iter_2 27B_PAPER_1M_iter_2.txt 142 | 143 | # Launch on torch2020 - edit solve.yaml for correct parameters of model and epoch 144 | ./tst_human_eval_base.sh 0 125M 1024 145 | ./tst_human_eval_ft1.sh 0 125M 1024 146 | ./tst_human_eval_ft5.sh 0 125M 1024 147 | ./tst_human_eval_ft10.sh 0 125M 1024 148 | -------------------------------------------------------------------------------- /ICLR2023/data/125M_PAPER_1M_iter_1.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/ICLR2023/data/125M_PAPER_1M_iter_1.txt.gz -------------------------------------------------------------------------------- /ICLR2023/data/13B_PAPER_1M_iter_1.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/ICLR2023/data/13B_PAPER_1M_iter_1.txt.gz -------------------------------------------------------------------------------- /ICLR2023/data/27B_PAPER_1M_iter_1.txt.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/ICLR2023/data/27B_PAPER_1M_iter_1.txt.gz -------------------------------------------------------------------------------- /ICLR2023/data/350M_PAPER_1M_iter_0.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/ICLR2023/data/350M_PAPER_1M_iter_0.txt.gz -------------------------------------------------------------------------------- /ICLR2023/data/Codex_PAPER_1M_iter_0.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/ICLR2023/data/Codex_PAPER_1M_iter_0.txt.gz -------------------------------------------------------------------------------- /ICLR2023/requirements.txt: -------------------------------------------------------------------------------- 1 | adal==1.2.7 2 | aiohttp==3.8.5 3 | aiosignal==1.2.0 4 | amlt==8.0.9 5 | applicationinsights==0.11.10 6 | asn1crypto==0.24.0 7 | astor==0.8.1 8 | async-timeout==4.0.1 9 | attrs==17.4.0 10 | Automat==0.6.0 11 | azure-common==1.1.27 12 | azure-core==1.17.0 13 | azure-data-tables==12.0.0b6 14 | azure-graphrbac==0.61.1 15 | azure-identity==1.4.1 16 | azure-mgmt-authorization==0.61.0 17 | azure-mgmt-containerregistry==2.8.0 18 | azure-mgmt-keyvault==2.2.0 19 | azure-mgmt-resource==13.0.0 20 | azure-mgmt-storage==11.2.0 21 | azure-storage-blob==2.1.0 22 | azure-storage-common==2.1.0 23 | azure-storage-file==2.1.0 24 | azureml-automl-core==1.26.0 25 | azureml-contrib-k8s==0.1.16 26 | azureml-contrib-pipeline-steps==1.26.0 27 | azureml-core==1.26.0 28 | azureml-dataprep==2.13.2 29 | azureml-dataprep-native==32.0.0 30 | azureml-dataprep-rslex==1.11.2 31 | azureml-dataset-runtime==1.26.0 32 | azureml-k8s-mt==1.0.4 33 | 
azureml-pipeline-core==1.26.0 34 | azureml-pipeline-steps==1.26.0 35 | azureml-telemetry==1.26.0 36 | azureml-train-automl-client==1.26.0 37 | azureml-train-core==1.26.0 38 | azureml-train-restclients-hyperdrive==1.26.0 39 | backcall==0.2.0 40 | backports.tempfile==1.0 41 | backports.weakref==1.0.post1 42 | beautifulsoup4==4.9.3 43 | bitstring==3.1.9 44 | black==21.8b0 45 | blinker==1.4 46 | blis==0.7.4 47 | blobxfer==1.10.0 48 | cachetools==4.2.2 49 | catalogue==2.0.6 50 | certifi==2023.7.22 51 | cffi==1.14.6 52 | chardet==3.0.4 53 | charset-normalizer==2.0.7 54 | click==7.1.2 55 | click-completion @ git+https://github.com/temporaer/click-completion.git@41b21868cac0781d25b37da624bae2fd1f36be88 56 | click-option-group==0.5.3 57 | click-plugins==1.1.1 58 | cloud-init==20.2 59 | cloudpickle==1.6.0 60 | colorama==0.3.7 61 | colorlog==6.4.1 62 | command-not-found==0.3 63 | configobj==5.0.6 64 | configparser==5.0.2 65 | constantly==15.1.0 66 | contextlib2==21.6.0 67 | cryptography==41.0.4 68 | cycler==0.10.0 69 | cymem==2.0.5 70 | datasets==1.15.1 71 | debugpy==1.4.3 72 | decorator==5.0.9 73 | deepspeed==0.5.1 74 | dill==0.3.4 75 | distro==1.6.0 76 | distro-info===0.18ubuntu0.18.04.1 77 | docker==5.0.1 78 | docker-pycreds==0.4.0 79 | dotnetcore2==2.1.21 80 | ecdsa==0.17.0 81 | entrypoints==0.3 82 | et-xmlfile==1.1.0 83 | fail2ban==0.10.2 84 | fastai==2.5.2 85 | fastcore==1.3.26 86 | fastdownload==0.0.5 87 | fastprogress==1.0.0 88 | filelock==3.0.12 89 | Flask==2.3.2 90 | Flask-Cors==3.0.10 91 | Flask-Executor==0.9.4 92 | Flask-FontAwesome==0.1.5 93 | frozenlist==1.2.0 94 | fsspec==2021.11.0 95 | gitdb==4.0.7 96 | GitPython==3.1.35 97 | httplib2==0.19.0 98 | huggingface-hub==0.1.2 99 | humanize==3.11.0 100 | hyperlink==17.3.1 101 | idna==2.6 102 | incremental==16.10.1 103 | ipdb==0.13.9 104 | ipykernel==6.4.1 105 | ipython==8.10.0 106 | ipython-genutils==0.2.0 107 | isodate==0.6.0 108 | itsdangerous==2.0.1 109 | jedi==0.18.0 110 | Jinja2==3.0.1 111 | jmespath==0.10.0 112 
| joblib==1.2.0 113 | jsonpatch==1.16 114 | jsonpickle==2.0.0 115 | jsonpointer==1.10 116 | jsonschema==2.6.0 117 | jupyter-client==7.0.5 118 | jupyter-core==4.11.2 119 | keyring==10.6.0 120 | keyrings.alt==3.0 121 | kiwisolver==1.3.2 122 | language-selector==0.1 123 | libtmux==0.10.1 124 | Mako==1.2.2 125 | MarkupSafe==2.0.1 126 | marshmallow==3.10.0 127 | matplotlib==3.4.3 128 | matplotlib-inline==0.1.3 129 | mlb-core==0.0.4 130 | msal==1.14.0 131 | msal-extensions==0.2.2 132 | msrest==0.6.19 133 | msrestazure==0.6.4 134 | multidict==5.2.0 135 | multiprocess==0.70.12.2 136 | murmurhash==1.0.5 137 | mypy-extensions==0.4.3 138 | ndg-httpsclient==0.5.1 139 | nest-asyncio==1.5.1 140 | netifaces==0.10.4 141 | ninja==1.10.2 142 | ntlm-auth==1.5.0 143 | numpy==1.22.0 144 | oauthlib==3.2.2 145 | openai==0.13.0 146 | openpyxl==3.0.9 147 | orderedset==2.0.3 148 | packaging==21.0 149 | PAM==0.4.2 150 | pandas==1.3.2 151 | pandas-stubs==1.2.0.45 152 | parso==0.8.2 153 | passpy==1.0.2 154 | pathspec==0.9.0 155 | pathtools==0.1.2 156 | pathy==0.6.0 157 | Pebble==4.6.3 158 | petname==2.6 159 | pexpect==4.8.0 160 | pickleshare==0.7.5 161 | Pillow==10.0.1 162 | platformdirs==2.3.0 163 | portalocker==1.7.1 164 | preshed==3.0.5 165 | promise==2.3 166 | prompt-toolkit==3.0.20 167 | protobuf==3.18.3 168 | psb2==1.0.0 169 | psutil==5.8.0 170 | ptyprocess==0.7.0 171 | pyarrow==1.0.1 172 | pyasn1==0.4.2 173 | pyasn1-modules==0.2.1 174 | pycparser==2.20 175 | pycrypto==2.6.1 176 | pydantic==1.8.2 177 | Pygments==2.15.0 178 | PyGObject==3.26.1 179 | PyJWT==2.4.0 180 | pyOpenSSL==17.5.0 181 | pyparsing==2.4.7 182 | pyperclip==1.8.2 183 | pyserial==3.4 184 | python-apt==1.6.5+ubuntu0.3 185 | python-dateutil==2.8.2 186 | python-debian==0.1.32 187 | python-gnupg==0.4.7 188 | pytz==2021.1 189 | pyxdg==0.26 190 | PyYAML==5.4.1 191 | pyzmq==22.3.0 192 | regex==2021.8.28 193 | requests==2.31.0 194 | requests-ntlm==1.1.0 195 | requests-oauthlib==1.3.0 196 | requests-unixsocket==0.1.5 197 | 
ruamel.yaml==0.17.16 198 | ruamel.yaml.clib==0.2.6 199 | sacremoses==0.0.45 200 | scikit-learn==0.24.2 201 | scipy==1.10.0 202 | SecretStorage==2.3.1 203 | sentry-sdk==1.14.0 204 | service-identity==16.0.0 205 | shellingham==1.4.0 206 | shortuuid==1.0.1 207 | six==1.16.0 208 | sklearn==0.0 209 | smart-open==5.2.1 210 | smmap==4.0.0 211 | soupsieve==2.2.1 212 | spacy==3.1.2 213 | spacy-legacy==3.0.8 214 | srsly==2.4.1 215 | ssh-import-id==5.7 216 | sshpubkeys==3.3.1 217 | strictfire==0.4.1 218 | subprocess32==3.5.4 219 | systemd-python==234 220 | tabulate==0.8.9 221 | tensorboardX==1.8 222 | termcolor==1.1.0 223 | thinc==8.0.10 224 | threadpoolctl==2.2.0 225 | tokenizers==0.10.3 226 | toml==0.10.2 227 | tomli==1.2.1 228 | torch==1.13.1 229 | torchvision==0.10.0 230 | tornado==6.3.3 231 | tqdm==4.62.2 232 | traitlets==5.1.0 233 | transformers==4.30.0 234 | Twisted==22.10.0 235 | typer==0.3.2 236 | typing-extensions==3.10.0.2 237 | ufw==0.36 238 | unattended-upgrades==0.1 239 | urllib3==1.26.17 240 | virtualenv==15.1.0 241 | WALinuxAgent==2.2.45 242 | wasabi==0.8.2 243 | wcwidth==0.2.5 244 | websocket-client==1.2.1 245 | Werkzeug==2.2.3 246 | xdg==5.1.1 247 | xxhash==2.0.2 248 | yarl==1.7.2 249 | zope.interface==4.3.2 250 | -------------------------------------------------------------------------------- /ICLR2023/src/babysit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # All Experiment Settings - constant through the experiment run - passed to gen.sh and fine_tune.sh as needed 3 | GPU=0 # which GPU to use 4 | MODEL="125M" # MODEL is the size of the model: 125M, 13B, 27B 5 | EXPERIMENT=$MODEL"_PAPER" # Name of Experiment directory under data/* and models/base-model/* to store results 6 | TEST_LOCAL=1 # 0 means run gen/fine_tune on cluster remotely, 1 means run gen/fine_tune locally 7 | TARGET_NUM_FILES=1 # How many files to generate in each iteration before starting fine_tuning. 
Count of unique examples would have been better. 8 | ITER_START=0 # inclusive index to start processing at - creates iter_# under data&models at each iteration. Can continue prev runs by start at prev ITER_MAX 9 | ITER_MAX=5 # exclusive index to stop processing iterations at 10 | EPOCHS_START=1 # inclusive index of epochs to start processing at - could continue prev run by starting at prev EPOCHS_MAX+1 - 0th epoch is the default model so epoch starts at 1 11 | EPOCHS_MAX=4 # inclusive index of epochs to stop processing at 12 | EPOCHS_PER_STEP=1 # How many EPOCHS through the data to do in each step 13 | TRAIN_INCREMENTAL=0 # Only train on data from the latest iteration, and start finetuning on the last finetuned model - otherwise start from scratch and use all the data generated 14 | TRAIN_BOOST=0 # Initial generation of data from default model is slow - 1 means looks in 125M_RL_ALL to use previous generated initial data to bootstrap. 15 | PASS_AT_K=100 # PASS_AT_K says do K trials to solve to compute Pass@K 16 | LINE_LOG_K=11 # LINE_LOG_K is how many lines of results from solve have results for saving 17 | 18 | echo babysit args: $# $0 $1 $2 $3 $4 19 | 20 | if (( $# \!= 1 )) 21 | then 22 | echo babysit.sh only takes 1 argument, unless called by another script to initialize configuration variables 23 | return 24 | fi 25 | 26 | if (( $# \>= 1 )) 27 | then 28 | GPU=$1 29 | fi 30 | 31 | echo babysit GPU $GPU 32 | 33 | for (( iteration=$ITER_START; iteration<$ITER_MAX; iteration++ )) 34 | do 35 | FULLNAME="${EXPERIMENT}---${iteration}" 36 | echo FULLNAME $FULLNAME 37 | export FULLNAME # Needed to pass variable off to yaml job 38 | DATAPATH=data/${EXPERIMENT}/iter_$iteration 39 | echo DATAPATH $DATAPATH 40 | 41 | if (( $TEST_LOCAL \> 0 )) 42 | then 43 | count=`ls -lt ../${DATAPATH} | grep json | wc -l` 44 | else 45 | count=`amlt sto list ${DATAPATH} | grep json | wc -l` 46 | fi 47 | echo count $count 48 | 49 | # Instead of file count we might want to check if the amount 
of data from preprocess is sufficient 50 | # If not we call to generate more 51 | 52 | if (( $count \> 0 )) 53 | then 54 | echo "$FULLNAME has already been started" 55 | echo "You are resuming at iteration $iteration" 56 | echo "You already have $count files of data this iteration" 57 | else 58 | echo "$FULLNAME is starting generation for iteration $iteration" 59 | fi 60 | 61 | if (( $count \< $TARGET_NUM_FILES )) 62 | then 63 | if (( $TEST_LOCAL \> 0 )) 64 | then 65 | # ./gen.sh $GPU 2560 100 $FULLNAME -1 66 | # 2.7B 384 100 runs ~10 hours 67 | # 2.7B 160 100 runs ~4.5 hours 68 | ./gen.sh $GPU 256000 100 $FULLNAME -1 69 | else 70 | amlt run hyper_gen_octows.yaml $FULLNAME -d "$FULLNAME" 71 | exit 72 | fi 73 | fi 74 | 75 | # Running local you are done, but launching on the cloud, you have to wait 76 | for (( poll=0; poll<500; poll++ )) 77 | do 78 | if (( $TEST_LOCAL \> 0 )) 79 | then 80 | count=`ls -lt ../${DATAPATH} | grep json | wc -l` 81 | else 82 | count=`amlt sto list ${DATAPATH} | grep json | wc -l` 83 | fi 84 | 85 | echo "gen wait - Iteration: $iteration, Poll: $poll, Count: $count" 86 | 87 | if (( $count \>= $TARGET_NUM_FILES )) 88 | then 89 | echo "Finished generation iteration $iteration after $poll polls" 90 | break 91 | fi 92 | sleep 3m 93 | done 94 | 95 | # Start a finetune job 96 | if (( $TEST_LOCAL \> 0 )) 97 | then 98 | ./fine_tune.sh $GPU $FULLNAME 99 | else 100 | # Pass enviroment variable FULLNAME to amlt.yaml 101 | amlt run amlt_octo.yaml $FULLNAME -d "$FULLNAME" 102 | exit 103 | fi 104 | 105 | # On cluster we need to wait for finetune job to finish, run locally it's done 106 | # Check the log files for starting the running of solve have been created for the last epoch of training 107 | 108 | MODELPATH=models/gpt-neo-$MODEL/${EXPERIMENT}/iter_$iteration 109 | SOLVE_PATH=$MODELPATH/"epoch_"$EPOCHS_MAX/"solve_"$PASS_AT_K 110 | echo babysit.sh SOLVE_PATH $SOLVE_PATH 111 | 112 | for (( poll=0; poll<500; poll++ )) 113 | do 114 | if (( $TEST_LOCAL \> 
0 )) 115 | then 116 | count=`ls -lt ../$SOLVE_PATH | grep json | wc -l` 117 | else 118 | count=`amlt sto list $SOLVE_PATH | grep json | wc -l` 119 | fi 120 | 121 | echo "fine_tune wait - Iteration: $iteration, Poll: $poll, Count: $count" 122 | 123 | if (( $count \>= 1 )) 124 | then 125 | echo "Finished fine_tune iteration $iteration after $poll polls" 126 | break 127 | fi 128 | sleep 3m 129 | done 130 | 131 | done 132 | 133 | # Pull all the results into 1 log file to look at more easily 134 | 135 | if [[ -z "${AMLT_DATA_DIR}" ]]; 136 | then 137 | # running locally on torch2020 so we don't have AMLT enviroment variables defined, so need to set them up 138 | AMLT_DATA_DIR=../data 139 | else 140 | # On remote we don't have access to the log files - maybe could do amlt sto download to do this summary below ? 141 | exit 142 | fi 143 | 144 | BASE_MODEL_PATH=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL 145 | LOG_FILE=$BASE_MODEL_PATH/$EXPERIMENT/"solve_"$PASS_AT_K".txt" 146 | echo solve LOG_FILE for babysit.sh is $LOG_FILE 147 | rm $LOG_FILE 148 | 149 | for (( iteration=$ITER_START; iteration<$ITER_MAX; iteration++ )) 150 | do 151 | for (( epochs=$EPOCHS_START; epochs<=$EPOCHS_MAX; epochs++ )) 152 | do 153 | EPOCH_NAME="epoch_"$epochs 154 | STEP_PATH=$BASE_MODEL_PATH/$EXPERIMENT/iter_$iteration/$EPOCH_NAME 155 | MODEL_PATH=$STEP_PATH/finetuned 156 | echo iteration $iteration epoch $epochs >> $LOG_FILE 157 | head -$LINE_LOG_K $STEP_PATH/"solve_"$PASS_AT_K/results.json >> $LOG_FILE 158 | done 159 | done 160 | 161 | cat $LOG_FILE 162 | -------------------------------------------------------------------------------- /ICLR2023/src/ds_config_gptneo.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "optimizer": { 11 | "type": "AdamW", 12 | "params": { 13 | "lr": 
"auto", 14 | "betas": "auto", 15 | "eps": "auto", 16 | "weight_decay": "auto" 17 | } 18 | }, 19 | "scheduler": { 20 | "type": "WarmupLR", 21 | "params": { 22 | "warmup_min_lr": "auto", 23 | "warmup_max_lr": "auto", 24 | "warmup_num_steps": "auto" 25 | } 26 | }, 27 | "zero_optimization": { 28 | "stage": 2, 29 | "allgather_partitions": true, 30 | "allgather_bucket_size": 2e8, 31 | "overlap_comm": true, 32 | "reduce_scatter": true, 33 | "reduce_bucket_size": 2e8, 34 | "contiguous_gradients": true, 35 | "cpu_offload": true 36 | }, 37 | "gradient_accumulation_steps": "auto", 38 | "gradient_clipping": "auto", 39 | "steps_per_print": 2000, 40 | "train_batch_size": "auto", 41 | "train_micro_batch_size_per_gpu": "auto", 42 | "wall_clock_breakdown": false 43 | } 44 | -------------------------------------------------------------------------------- /ICLR2023/src/fine_tune.py: -------------------------------------------------------------------------------- 1 | from strictfire import StrictFire as Fire # aborts early on invalid arguments 2 | import os 3 | import csv 4 | import subprocess 5 | import shlex 6 | import random 7 | import numpy as np 8 | import torch 9 | import utils 10 | 11 | def fine_tune( 12 | train_txt="../data/generated_sol_100.txt", 13 | output_dir = "../outputs/", 14 | subdir="out", 15 | model_path="EleutherAI/gpt-neo-2.7B", 16 | gpu=0, 17 | num_gpus=1, 18 | epochs=4, 19 | seed=0, 20 | ): 21 | """ 22 | Fine tune the model on the puzzles in train_txt file and save the results to OUTPUT_DIR/output_subdir 23 | 24 | train_txt: the (possibly gzipped) file containing the text to fine tune on (default: ../data/generated_sol_100.txt) 25 | subdir: the subdirectory to save the results to (default "out") 26 | model_path: the path to the model to fine tune (default "EleutherAI/gpt-neo-2.7B") 27 | gpu: which GPU(s) to use, e.g.: 0,1 (default 0) 28 | epochs: how many epochs to train for (default 4) 29 | seed: the random seed to use, not sure if this affects fine tuning 
(default 0) 30 | """ 31 | random.seed(seed) 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | 35 | # create output dir if necessary 36 | output_path = os.path.join(output_dir, subdir) 37 | if not os.path.exists(output_path): 38 | os.makedirs(output_path) 39 | 40 | 41 | text = utils.load_text_file(train_txt) # decompresses if ends in .gz 42 | tokenizer = utils.load_tokenizer(model_path) 43 | num_toks = utils.num_tokens(text, tokenizer, verbose=True) 44 | assert num_toks > 1024, "Not enough tokens in text to fine tune" 45 | 46 | # create csv 47 | train_file = os.path.join(output_path, "train.csv") 48 | with open(train_file, mode="w", encoding="utf-8") as csv_file: 49 | fieldnames = ["text"] 50 | writer = csv.DictWriter(csv_file, fieldnames=fieldnames) 51 | writer.writeheader() 52 | writer.writerow({"text": text}) 53 | 54 | output_path_finetuned = os.path.join(output_path, "finetuned") 55 | 56 | # keep gradient_accumulation_steps at 1 bc setting it to 2 effectively doubles the batch 57 | # size which gets tricky when batch sizes are small (ft_tokens will no longer be accurate) 58 | gradient_accumulation_steps = 1 59 | per_device_train_batch_size = 4 60 | 61 | cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "") 62 | if len(cuda_visible_devices): 63 | print("os.environ(CUDA_VISIBLE_DEVICES)", cuda_visible_devices) 64 | del os.environ["CUDA_VISIBLE_DEVICES"] 65 | print("os.environ(CUDA_VISIBLE_DEVICES)", os.environ.get("CUDA_VISIBLE_DEVICES", "")) 66 | 67 | master_port = 29600 # During training deepspeed uses a port to syncronize. 
2 jobs need to set different ports to run in parallel 68 | if type(gpu) in [list, tuple]: 69 | master_port += gpu[0] 70 | gpu = ",".join([str(g) for g in gpu]) 71 | else: 72 | master_port += gpu 73 | 74 | gpu_string = f'--include=localhost:{gpu}' 75 | 76 | if num_gpus > 1: 77 | gpu_string = f"--num_nodes=1 --num_gpus={num_gpus}", 78 | # If gpu is passed in as negative - it's the count of gpu to use - a bit of a hack 79 | if gpu < 0: 80 | num_gpus = abs(gpu) 81 | gpu_string = f"--num_nodes=1 --num_gpus={num_gpus}" 82 | 83 | print("gpu_string", gpu_string) 84 | 85 | cmd = " ".join( 86 | [ 87 | "deepspeed", 88 | f"--master_port={master_port}", 89 | gpu_string, 90 | # f'--include=localhost:{gpu}', 91 | # "--num_nodes=1", 92 | # f"--num_gpus={num_gpus}", 93 | "neo_train.py", 94 | f"--model_name_or_path={model_path}", 95 | f"--train_file={train_file}", 96 | f"--output_dir={output_path_finetuned}", 97 | "--overwrite_output_dir", 98 | "--ignore_data_skip", 99 | "--deepspeed", 100 | "ds_config_gptneo.json", 101 | f"--save_strategy=no", # ATK remove checkpointing for large datasets 102 | # pretty sure this is just dataset cache 103 | "--overwrite_cache", 104 | # logging frequency 105 | "--logging_steps=5", 106 | "--do_train", 107 | "--report_to none", # turns off report_to WANDB for instance 108 | "--fp16", 109 | f"--num_train_epochs={epochs}", 110 | # overrides num_train_epochs if set to a positive value. This is the number of gradient steps that happen total. 
111 | f"--per_device_train_batch_size={per_device_train_batch_size}", 112 | "--use_fast_tokenizer=False", 113 | f"--gradient_accumulation_steps={gradient_accumulation_steps}", 114 | "--learning_rate=5e-06", 115 | # linear increase from this up to learning rate, then LR schedule happens (which itself is linear decreasing until max_steps) 116 | "--warmup_steps=10", 117 | ] 118 | ) 119 | 120 | utils.info(f"running command: {cmd}") 121 | print(f"Command to run:{cmd}") # Why is this different than what utils.info prints out, utils.info truncates it 122 | # exit() 123 | res = subprocess.run(shlex.split(cmd), check=True) 124 | utils.info(str(res)) 125 | 126 | 127 | if __name__ == "__main__": 128 | Fire(fine_tune) 129 | -------------------------------------------------------------------------------- /ICLR2023/src/fine_tune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo fine_tune.sh args: $# $0 $1 $2 $3 $4 3 | # Grab the configuration variables 4 | . babysit.sh 5 | 6 | # On AMLT machines we don't specify which GPU to use 7 | # GPU="-1" 8 | if [[ -z "${AMLT_DATA_DIR}" ]]; then 9 | # running locally on torch2020 so we don't have AMLT enviroment variables defined, so need to set them up 10 | AMLT_DATA_DIR=../data 11 | # On torch2020 we do specify which GPU to use 12 | # GPU="0" 13 | fi 14 | 15 | # assert that there are at least 2 argument 16 | if (( $# \< 2 )) 17 | then 18 | echo "Usage: $0 " 19 | exit 20 | fi 21 | 22 | GPU=$1 23 | FULLNAME=$2 24 | 25 | # split by ; fullname string into experiment name and iteration 26 | # e.g. "125M_RL---0" --> "125M_RL;0" 27 | SPLIT=(${FULLNAME//---/ }) 28 | EXPERIMENT=${SPLIT[0]} 29 | ITERATION=${SPLIT[1]} 30 | OUTPATH=$AMLT_DATA_DIR/$EXPERIMENT/iter_$ITERATION 31 | 32 | echo GPU $GPU 33 | echo EXPERIMENT $EXPERIMENT 34 | echo ITERAION $ITERATION 35 | echo OUTPATH $OUTPATH 36 | 37 | # GPU_SOLVE is the GPU we want solve to use. 
Solve currently only uses 1 GPU - it would be great to make it use more when they are available. 38 | # if GPU is negative - that tells fine_tune how many GPU to use on cluster - and we need to set GPU for solve to 0 on cluster 39 | # if GPU is positive - we are running locally on torch2020 - and we need to leave the GPU set properly 40 | GPU_SOLVE=$GPU 41 | if (( $GPU \< 0 )) 42 | then 43 | GPU_SOLVE=0 44 | fi 45 | echo GPU_SOLVE $GPU_SOLVE 46 | 47 | python preprocess.py $OUTPATH 48 | 49 | TRN_FILE=$OUTPATH/gen_ps_filtered.txt 50 | echo TRN_FILE $TRN_FILE 51 | TST_FILE=$AMLT_DATA_DIR/test_228.json 52 | 53 | BASE_MODEL_PATH=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL 54 | # 125M is copied locally to start 55 | MODEL_PATH=$BASE_MODEL_PATH 56 | MODEL_PATH=EleutherAI/gpt-neo-125M # !!! Just for paper release 57 | # 13B is off in the cloud to start 58 | if [[ "$MODEL" == "13B" ]]; then 59 | MODEL_PATH=EleutherAI/gpt-neo-1.3B 60 | fi 61 | # 27B is off in the cloud tto start 62 | if [[ "$MODEL" == "27B" ]]; then 63 | MODEL_PATH=EleutherAI/gpt-neo-2.7B 64 | fi 65 | 66 | echo MODEL MODEL_PATH $MODEL $MODEL_PATH 67 | 68 | # Training incremental means use the previous iterations trained model, and just the additional iteration's new data to fine_tune on. 69 | # Otherwise use the base model - and retrain from scratch on all the data from all previous iterations. 70 | # They are sort of equivalent - except from scratch picks up any extra data that was generated - and mixes all the iterations data together - but slower. 
71 | if (( $TRAIN_INCREMENTAL \> 0 )) 72 | then 73 | PREV_ITERATION=$((ITERATION-1)) 74 | echo $PREV_ITERATION 75 | TEST=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL/$EXPERIMENT/iter_$PREV_ITERATION/epoch_$EPOCHS_MAX/finetuned 76 | 77 | if [ -a $TEST ] # exists 78 | then 79 | MODEL_PATH=$TEST 80 | echo fine_tune.sh using previous iteration model 81 | fi 82 | fi 83 | 84 | echo "fine_tune.sh starting from NEO model at: ${MODEL_PATH}" 85 | 86 | # Pull all the results into 1 log file to look at more easily 87 | LOG_FILE=$BASE_MODEL_PATH/$EXPERIMENT/iter_$ITERATION/"solve.txt" 88 | echo solve LOG_FILE for fine_tune.sh is $LOG_FILE 89 | rm $LOG_FILE 90 | 91 | for (( epochs=$EPOCHS_START; epochs<=$EPOCHS_MAX; epochs++ )) 92 | do 93 | EPOCH_NAME="epoch_"$epochs 94 | EPOCHS_STEP=$(($EPOCHS_PER_STEP * $epochs)) 95 | python fine_tune.py -train_txt=$TRN_FILE -gpu=$GPU -output_dir=$BASE_MODEL_PATH/$EXPERIMENT/iter_$ITERATION -subdir=$EPOCH_NAME -model_path=$MODEL_PATH -epochs=$EPOCHS_STEP 96 | # measure the finetuned model's accuracy 97 | STEP_PATH=$BASE_MODEL_PATH/$EXPERIMENT/iter_$ITERATION/$EPOCH_NAME 98 | MODEL_PATH=$STEP_PATH/finetuned 99 | python solve.py -prefix=$AMLT_DATA_DIR/train_prefix.txt -attempts=$PASS_AT_K -model_path=$MODEL_PATH -gpu=$GPU_SOLVE -fixed_temp=0.8 -out=$STEP_PATH/"solve_"$PASS_AT_K"/" -puzzles=$TST_FILE 100 | head -$LINE_LOG_K $STEP_PATH/"solve_"$PASS_AT_K/results.json >> $LOG_FILE 101 | done 102 | 103 | cat $LOG_FILE -------------------------------------------------------------------------------- /ICLR2023/src/fine_tune1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This is for finetuning a model on 1 dataset only 3 | echo fine_tune1.sh args: $# $0 $1 $2 $3 $4 4 | 5 | # All Experiment Settings - constant through the experiment run 6 | GPU=0 # which GPU to use 7 | MODEL="125M" # MODEL is the size of the model: 125M, 13B, 27B 8 | EXPERIMENT=$MODEL"_PAPER1" # Name of Experiment directory under 
data/* and models/base-model/* to store results 9 | ITERATION=0 # Random seed for finetuning 10 | EPOCHS_START=1 # inclusive index of epochs to start processing at - could continue prev run by starting at prev EPOCHS_MAX+1 - 0th epoch is the default model so epoch starts at 1 11 | EPOCHS_MAX=10 # inclusive index of epochs to stop processing at 12 | EPOCHS_PER_STEP=1 # How many EPOCHS through the data to do in each step 13 | PASS_AT_K=100 # PASS_AT_K says do K trials to solve to compute Pass@K 14 | LINE_LOG_K=11 # LINE_LOG_K is how many lines of results from solve have results for saving 15 | 16 | # On AMLT machines we don't specify which GPU to use 17 | if [[ -z "${AMLT_DATA_DIR}" ]]; then 18 | # running locally on torch2020 so we don't have AMLT enviroment variables defined, so need to set them up 19 | AMLT_DATA_DIR=../data 20 | fi 21 | 22 | if (( $# \>= 1 )) 23 | then 24 | GPU=$1 25 | fi 26 | 27 | echo GPU $GPU 28 | echo EXPERIMENT $EXPERIMENT 29 | echo ITERAION $ITERATION 30 | 31 | TRN_FILE=$AMLT_DATA_DIR/generated_sol_950k.txt 32 | echo TRN_FILE $TRN_FILE 33 | TST_FILE=$AMLT_DATA_DIR/test_228.json 34 | 35 | # GPU_SOLVE is the GPU we want solve to use. Solve currently only uses 1 GPU - it would be great to make it use more when they are available. 36 | # if GPU is negative - that tells fine_tune how many GPU to use on cluster - and we need to set GPU for solve to 0 on cluster 37 | # if GPU is positive - we are running locally on torch2020 - and we need to leave the GPU set properly 38 | GPU_SOLVE=$GPU 39 | if (( $GPU \< 0 )) 40 | then 41 | GPU_SOLVE=0 42 | fi 43 | echo GPU_SOLVE $GPU_SOLVE 44 | 45 | # measure the base model's accuracy - don't really need to do this very often - it doesn't change 46 | 47 | BASE_MODEL_PATH=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL 48 | # 125M is copied locally to start 49 | MODEL_PATH=$BASE_MODEL_PATH 50 | MODEL_PATH=EleutherAI/gpt-neo-125M # !!! 
Just for paper release 51 | # 13B is off in the cloud to start 52 | if [[ "$MODEL" == "13B" ]]; then 53 | MODEL_PATH=EleutherAI/gpt-neo-1.3B 54 | fi 55 | # 27B is off in the cloud tto start 56 | if [[ "$MODEL" == "27B" ]]; then 57 | MODEL_PATH=EleutherAI/gpt-neo-2.7B 58 | fi 59 | 60 | echo MODEL MODEL_PATH $MODEL $MODEL_PATH 61 | 62 | # Training incremental means use the previous epochs model to start 63 | # Otherwise use the base model to retrain from scratch 64 | if (( $EPOCHS_START \> 1 )) 65 | then 66 | PREV_EPOCH=$((EPOCHS_START-1)) 67 | echo $PREV_EPOCH 68 | TEST=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL/$EXPERIMENT/iter_$ITERATION/epoch_$PREV_EPOCH/finetuned 69 | 70 | if [ -a $TEST ] # exists 71 | then 72 | MODEL_PATH=$TEST 73 | echo fine_tune.sh using previous iteration model 74 | fi 75 | fi 76 | 77 | echo "fine_tune.sh starting from NEO model at: ${MODEL_PATH}" 78 | 79 | # Pull all the results into 1 log file to look at more easily 80 | LOG_FILE=$BASE_MODEL_PATH/$EXPERIMENT/iter_$ITERATION/"solve.txt" 81 | echo solve LOG_FILE for fine_tune.sh is $LOG_FILE 82 | rm $LOG_FILE 83 | 84 | for (( epochs=$EPOCHS_START; epochs<=$EPOCHS_MAX; epochs++ )) 85 | do 86 | EPOCH_NAME="epoch_"$epochs 87 | EPOCHS_STEP=$(($EPOCHS_PER_STEP * $epochs)) 88 | python fine_tune.py -train_txt=$TRN_FILE -gpu=$GPU -output_dir=$BASE_MODEL_PATH/$EXPERIMENT/iter_$ITERATION -subdir=$EPOCH_NAME -model_path=$MODEL_PATH -epochs=$EPOCHS_STEP -seed=$ITERATION 89 | # measure the finetuned model's accuracy 90 | STEP_PATH=$BASE_MODEL_PATH/$EXPERIMENT/iter_$ITERATION/$EPOCH_NAME 91 | MODEL_PATH=$STEP_PATH/finetuned 92 | python solve.py -prefix=$AMLT_DATA_DIR/train_prefix.txt -attempts=$PASS_AT_K -model_path=$MODEL_PATH -gpu=$GPU_SOLVE -fixed_temp=0.8 -out=$STEP_PATH/"solve_"$PASS_AT_K"/" -puzzles=$TST_FILE -seed=$ITERATION -batch_size=256 93 | head -$LINE_LOG_K $STEP_PATH/"solve_"$PASS_AT_K/results.json >> $LOG_FILE 94 | done 95 | 96 | cat $LOG_FILE 
-------------------------------------------------------------------------------- /ICLR2023/src/gen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Grab the configuration variables 3 | . babysit.sh 4 | 5 | if [[ -z "${AMLT_DATA_DIR}" ]]; then 6 | # running locally on torch2020 we don't have AMLT enviroment variables defined, set them up 7 | AMLT_DATA_DIR="../data" 8 | fi 9 | 10 | echo RANK is: 11 | echo $RANK 12 | 13 | if [[ -z "${RANK}" ]]; then 14 | # running locally on torch2020 we don't have AMLT enviroment variables defined, set them up 15 | RANK=0 16 | fi 17 | 18 | GPU=$RANK 19 | PUZZLE_CNT=32 20 | SOLUTION_CNT=32 21 | FULLNAME="125M_RL_TEST---0" 22 | 23 | echo $# $0 $1 $2 $3 $4 24 | if (( $# \>= 1 )) 25 | then 26 | GPU=$1 27 | fi 28 | 29 | echo $RANK 30 | echo $GPU 31 | 32 | if (( $# \>= 2 )) 33 | then 34 | PUZZLE_CNT=$2 35 | fi 36 | 37 | if (( $# \>= 3 )) 38 | then 39 | SOLUTION_CNT=$3 40 | fi 41 | 42 | if (( $# \>= 4 )) 43 | then 44 | FULLNAME=$4 45 | 46 | fi 47 | 48 | RANDOM_SEED=-1 49 | 50 | if (( $# \>= 5 )) 51 | then 52 | RANDOM_SEED=$5 53 | echo "Random seed is $RANDOM_SEED" 54 | fi 55 | 56 | SPLIT=(${FULLNAME//---/ }) 57 | EXPERIMENT=${SPLIT[0]} 58 | ITERATION=${SPLIT[1]} 59 | OUTPATH=$AMLT_DATA_DIR/$EXPERIMENT/iter_$ITERATION 60 | 61 | echo GPU $GPU 62 | echo EXPERIMENT $EXPERIMENT 63 | echo ITERAION $ITERATION 64 | echo OUTPATH $OUTPATH 65 | 66 | BASE_MODEL_PATH=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL 67 | # 125M is copied locally to start 68 | MODEL_PATH=$BASE_MODEL_PATH 69 | MODEL_PATH=EleutherAI/gpt-neo-125M # !!! 
Just for paper release 70 | # 13B is off in the cloud to start 71 | if [[ "$MODEL" == "13B" ]]; then 72 | MODEL_PATH=EleutherAI/gpt-neo-1.3B 73 | fi 74 | # 27B is off in the cloud tto start 75 | if [[ "$MODEL" == "27B" ]]; then 76 | MODEL_PATH=EleutherAI/gpt-neo-2.7B 77 | fi 78 | 79 | echo MODEL MODEL_PATH $MODEL $MODEL_PATH 80 | 81 | PREV_ITERATION=$((ITERATION-1)) 82 | echo $PREV_ITERATION 83 | TEST=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL/$EXPERIMENT/iter_$PREV_ITERATION/epoch_$EPOCHS_MAX/finetuned 84 | 85 | if [ -a $TEST ] # exists 86 | then 87 | MODEL_PATH=$TEST 88 | echo fine_tune.sh using previous iteration model 89 | fi 90 | 91 | echo "gen.sh using NEO model at: ${MODEL_PATH}" 92 | 93 | python gen.py -out="$OUTPATH" -n=$PUZZLE_CNT -seed=$RANDOM_SEED -gpu=$GPU -train=$AMLT_DATA_DIR/155_train.json -prefix=$AMLT_DATA_DIR/train_prefix.txt -model_path=$MODEL_PATH -attempts=$SOLUTION_CNT -only_good=True 94 | -------------------------------------------------------------------------------- /ICLR2023/src/judge.py: -------------------------------------------------------------------------------- 1 | from utils import load_json 2 | from pebble import ProcessPool 3 | import multiprocessing as mp 4 | from concurrent.futures import TimeoutError 5 | from typing import List, Set, Tuple, Dict 6 | 7 | import utils 8 | import sys 9 | import re 10 | from copy import deepcopy 11 | 12 | sys.setrecursionlimit(5000) 13 | 14 | 15 | def no_print(*_args, **_kwargs): 16 | pass 17 | 18 | 19 | def run_judge(judge, f, tests): 20 | answer_type = list(judge.__annotations__.values())[0] 21 | for x in tests: 22 | y = f(**deepcopy(x)) # so f cannot cheat and change the input x 23 | if not utils.type_check(y, answer_type): 24 | raise TypeError 25 | assert judge(y, **x) is True, f"{f} failed on test {x}" 26 | 27 | 28 | _ENV = dict( 29 | List=List, 30 | Set=Set, 31 | Tuple=Tuple, 32 | Dict=Dict, 33 | type_check=utils.type_check, 34 | run_judge=run_judge, 35 | test_puzzle=utils.test_puzzle, 36 | 
os=None, 37 | sys=None, 38 | input=None, 39 | open=None, 40 | print=no_print, 41 | compile=None, 42 | copyright=None, 43 | ) 44 | 45 | _UNSAFE = ["builtin", "__class", "open("] 46 | _SAFE_IMPORTS = {"collections", "copy", "hashlib", "math", "random", "re", "string", "typing"} 47 | 48 | MAX_WORKERS = mp.cpu_count() // 2 49 | 50 | 51 | 52 | def unsafe_imports(code): 53 | """Check if code imports any unsafe modules. 54 | 55 | Args: 56 | code (str): The code to check. 57 | 58 | Returns: 59 | bool: True if code imports unsafe modules. 60 | """ 61 | if "import" not in code: 62 | return False 63 | for line in code.split("\n"): 64 | if "import" in line: 65 | match = re.search(r"^\s*from\s+([\w\.]+)\s+import\s", line) 66 | if match: 67 | modules = [match.group(1)] 68 | else: 69 | match = re.search(r"^\s*import\s+(.+)", line) 70 | if match: 71 | modules = match.group(1).split(",") 72 | else: 73 | return True 74 | if any(m.strip() not in _SAFE_IMPORTS for m in modules): 75 | return True 76 | return False 77 | 78 | 79 | def _judge(code_env): 80 | code, env = code_env 81 | if unsafe_imports(code) or any(u in code for u in _UNSAFE): 82 | return False, Exception(f"unsafe code"), code 83 | try: 84 | exec(code, env.copy()) 85 | return True, None, code 86 | except Exception as e: 87 | return False, e, code 88 | 89 | 90 | def judge_parallel(src_codes, timeout, max_workers=MAX_WORKERS, env=_ENV): 91 | codes = utils.dedup(src_codes) 92 | utils.info( 93 | f"Judging {len(src_codes):,} codes ({len(src_codes)-len(codes):,} duplicates) with {max_workers} workers" 94 | ) 95 | successes = set() 96 | 97 | # print("writing to file for debugging before judging") 98 | # from train import save_json 99 | # 100 | # save_json(new_codes, "results/tmp/new_codes.json") 101 | utils.silence_std_err(True) 102 | with ProcessPool(max_workers=max_workers) as pool: 103 | future = pool.map(_judge, [(code, env) for code in codes], timeout=timeout) 104 | 105 | results = future.result() 106 | i = 0 107 | while 
True: 108 | try: 109 | success, exc, code = next(results) 110 | if success: 111 | successes.add(codes[i]) 112 | except StopIteration: 113 | break 114 | except (TimeoutError, Exception) as error: 115 | pass 116 | assert i < len(codes) 117 | i += 1 118 | assert i == len(codes) 119 | utils.silence_std_err(False) 120 | return [code in successes for code in src_codes] 121 | 122 | 123 | 124 | 125 | def test(): 126 | import time 127 | 128 | tests = [ 129 | ("def sol(a: int=10000200001):\n return (list(range(3 * a))[str(a)])\nx = sol()", False), 130 | ("print(12)", True), 131 | ("while True: pass", False), 132 | ("def sol(): sol()\nsol()", False), 133 | ("2+2", True), 134 | ("""1+1""", True), 135 | ("""assert False,'cats'""", False), 136 | ("""assert False""", False), 137 | ("""1[2]""", False), 138 | ("""1/0""", False), 139 | ( 140 | """while True: 141 | pass""", 142 | False, 143 | ), 144 | ( 145 | """for i in range(10**4): 146 | pass""", 147 | True, 148 | ), 149 | ("print('hello')", True), 150 | ] 151 | 152 | scores = {} 153 | tests2 = tests 154 | pad = " " 155 | for _ in range(6): 156 | print(f"n={len(tests2)} timing test" + "*" * 20) 157 | times = [] 158 | for max_workers in [4, 16, 32, 64, 128]: 159 | time0 = time.perf_counter() 160 | res = judge_parallel([test for test, r in tests2], timeout=1, max_workers=max_workers) 161 | for (test, expected), r in zip(tests2, res): 162 | assert expected == r, f"Failed expected {expected}, got {r} for {test}" 163 | 164 | times.append((max_workers, time.perf_counter() - time0)) 165 | 166 | scores[len(tests2)] = times 167 | tests2 = tests2 + [(t + pad, r) for (t, r) in tests2] 168 | pad = pad * 2 169 | print("mp.cpu_count() =", mp.cpu_count()) 170 | 171 | for n, times in scores.items(): 172 | print(n, "tests, [(max_workers, time)] =", times) 173 | 174 | 175 | if __name__ == "__main__": 176 | test() 177 | -------------------------------------------------------------------------------- /ICLR2023/src/preprocess.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import utils 4 | import glob 5 | import json 6 | from typing import List 7 | from strictfire import StrictFire as Fire # aborts early on invalid arguments 8 | 9 | class WeirdInputsException(Exception): 10 | pass 11 | 12 | def get_inputs(sat: str): 13 | """Extacts arguments past the first from a function string 14 | def f(a: Dict[int, str], b=12): 15 | test 16 | 17 | should give 'b=12' 18 | """ 19 | sat = sat.replace(" -> bool", "") 20 | first_line = sat.split("\n")[0].strip() 21 | if not first_line.endswith("):") and "#" in first_line: 22 | if "):" in first_line: 23 | n = first_line.index("):") 24 | if "#" in first_line[n:]: 25 | first_line = first_line[:n + first_line[n:].index("#")].strip() 26 | else: 27 | first_line = "" # raises exception below 28 | if not (first_line.endswith("):") and first_line.startswith("def")): 29 | raise WeirdInputsException("Weird puzzle, cannot extract inputs", json.dumps(sat)) 30 | arg_str = first_line[first_line.index("("):-len("):")] 31 | depth = 0 32 | for i, c in enumerate(arg_str): 33 | if c == "," and depth == 0: 34 | return arg_str[i + 1:].strip() 35 | elif c == "[": 36 | depth += 1 37 | elif c == "]": 38 | depth -= 1 39 | return "" 40 | 41 | def main( 42 | path, 43 | filtered_name="gen_ps_filtered.txt", 44 | unfiltered_name=None, # "gen_ps_unfiltered.txt", 45 | max_sols_per_puzzle=8, 46 | seed=0): 47 | """ 48 | Merge the puzzles from the given json input files. 
Examples: 49 | python preprocess.py -unfiltered_name=gen_ps_unfiltered.txt -- ~/aicoder/data/gen_125M_RL/*.json 50 | 51 | path: path to search for json files 52 | filtered_name: path to write puzzles, unfiltered (default: gen_ps_filtered.txt) 53 | unfiltered_name: path to write filtered puzzles (optional) 54 | max_sols_per_puzzle: maximum number of solutions per puzzle (default 8) 55 | seed: random seed (default 0) for reproducibility 56 | infiles: list of files to read puzzles from (like /path/*.json) 57 | """ 58 | 59 | # Make the path so enumeration off that path works, even if it doesn't exist yet 60 | filtered_path = os.path.join(path, filtered_name) 61 | os.makedirs(os.path.dirname(filtered_path), exist_ok=True) 62 | 63 | codes = [] 64 | all_codes = [] 65 | 66 | # grab all the iter_* data for just this experiment 67 | gen_paths = [os.path.join(path, "../*/*.json")] 68 | 69 | # grab just the data for this iter_# for this experiment 70 | # gen_paths = [os.path.join(path, "*.json")] 71 | 72 | for gen_path in gen_paths: 73 | for filename in sorted(glob.glob(gen_path)): 74 | print("preprocess filename:", filename) 75 | js = utils.load_json(filename) 76 | for f, successes, failures in js: 77 | for body in sorted(utils.dedup(successes), key=len)[:max_sols_per_puzzle]: 78 | 79 | try: 80 | g = f"def g({get_inputs(f)}):{body}".strip("\\").strip() 81 | codes.append(f + "\n\n" + g + "\n\n" + "assert f(g())\n\n") 82 | except WeirdInputsException: 83 | print("failed to create g") 84 | pass 85 | print(f"{len(codes):,}/{len(all_codes):,} puzzles of preprocessing {filename}") 86 | 87 | print("len(codes)", len(codes)) 88 | codes = utils.dedup(codes) 89 | print("len(codes) after dedup", len(codes)) 90 | 91 | random.shuffle(codes) 92 | random.shuffle(all_codes) 93 | 94 | # Make it the same number of examples as we got from codex 95 | codes = codes[:950511] 96 | print("len(codes) after truncation", len(codes)) 97 | 98 | code = "".join(codes) 99 | 100 | utils.save_text_file(code, 
filtered_path) 101 | print(f"Wrote filtered results to {filtered_path}") 102 | 103 | assert unfiltered_name is None, "Not supported now, go back to earlier version" 104 | if unfiltered_name: 105 | unfiltered_path = os.path.join(path, filtered_name) 106 | utils.save_text_file("".join(all_codes), unfiltered_path) 107 | print(f"Wrote unfiltered results to {unfiltered_path}") 108 | 109 | 110 | if __name__ == "__main__": 111 | Fire(main) 112 | -------------------------------------------------------------------------------- /ICLR2023/src/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | orderedset 3 | numpy 4 | astor 5 | sklearn 6 | fire 7 | strictfire 8 | pebble 9 | deepspeed == 0.6.1 10 | transformers == 4.30.0 -------------------------------------------------------------------------------- /ICLR2023/src/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import inspect 4 | import io 5 | import os 6 | import sys 7 | import time 8 | from transformers import AutoTokenizer 9 | 10 | 11 | os.environ["WANDB_DISABLED"] = "true" 12 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 13 | my_path = os.path.dirname(__file__) 14 | 15 | 16 | def load_tokenizer(model_path): 17 | tokenizer = AutoTokenizer.from_pretrained(model_path) 18 | tokenizer.padding_side = "left" 19 | tokenizer.pad_token = tokenizer.eos_token 20 | return tokenizer 21 | 22 | 23 | def num_tokens(s: str, tokenizer, verbose=False): 24 | 25 | start_time = time.time() 26 | if verbose: 27 | info(f"Tokenizing {pretty_int(len(s))} chars ({pretty_int(len(s.splitlines()))} lines)") 28 | # ans = _tokenizer(s, return_tensors="pt").input_ids.shape[1] # produces annoying warnings 29 | ans = tokenizer(s, return_tensors="pt", max_length=10 + len(s), truncation=True).input_ids.shape[1] 30 | 31 | duration_mins = (time.time() - start_time)/60 32 | if verbose: 33 | info(f"Num tokens: {ans:,} in 
{duration_mins:.2f} mins") 34 | return ans 35 | 36 | 37 | def create_experiment_outpath(out: str, bSaveCommand=True): 38 | """ 39 | Create the output directory and return its name. Also stores the command line in command.sh 40 | Date format is like Jan-1-2020 41 | """ 42 | output_path = str(out).replace("", time.strftime("%b%d-%H-%M-%S")) 43 | os.makedirs(output_path, exist_ok=True) # ran into error due to non-atomic check 44 | if bSaveCommand: 45 | save_text_file(' '.join([sys.executable] + sys.argv) + "\n", f"{output_path}/command.sh") 46 | # make command.sh executable: 47 | os.chmod(f"{output_path}/command.sh", 0o755) 48 | return output_path 49 | 50 | def pretty_int(n: int) -> str: 51 | """Converts an integer to a string with commas, with M for millions and B for billions""" 52 | if n > 1_000_000_000: 53 | return f"{n/1_000_000_000:.1f}B" 54 | if n > 1_000_000: 55 | return f"{n/1_000_000:.1f}M" 56 | return f"{n:,}" 57 | 58 | 59 | 60 | def test_puzzle(f, x): 61 | """Checks if x is of the correct type and makes f return True (literally True, not an integer or whatever) 62 | 63 | :param f: Puzzle 64 | :param x: candidate answer 65 | :return: 66 | """ 67 | answer_type = list(f.__annotations__.values())[0] 68 | if not type_check(x, answer_type): 69 | raise TypeError 70 | return f(x) is True 71 | 72 | 73 | 74 | def type_check(obj, typ): 75 | """ 76 | check if obj is of type `typ` where `typ` is a `typing` module type annotation, eg List[int] 77 | The way we do this to be compatible across versions is we first convert the type to a string. 
78 | """ 79 | 80 | type_str = str(typ).replace("typing.", "") 81 | if type_str.startswith(" 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 
22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 
8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /generators/__init__.py: -------------------------------------------------------------------------------- 1 | # problems are identified automatically by looking for subclasses (and sub-subclasses) of the Problem class 2 | 3 | from . import study 4 | from . import classic_puzzles 5 | from . import human_eval 6 | from . import codeforces 7 | from . import algebra 8 | from . import basic 9 | from . import chess 10 | from . import compression 11 | from . import conways_game_of_life 12 | from . import games 13 | from . import graphs 14 | from . import ICPC 15 | from . import IMO 16 | from . import lattices 17 | from . import number_theory 18 | from . import probability 19 | from . import trivial_inverse 20 | from . 
import tutorial -------------------------------------------------------------------------------- /generators/algebra.py: -------------------------------------------------------------------------------- 1 | """Roots of polynomials""" 2 | 3 | from puzzle_generator import PuzzleGenerator, Tags 4 | from typing import List 5 | 6 | 7 | # See https://github.com/microsoft/PythonProgrammingPuzzles/wiki/How-to-add-a-puzzle to learn about adding puzzles 8 | 9 | 10 | class QuadraticRoot(PuzzleGenerator): 11 | """See [quadratic equations](https://en.wikipedia.org/wiki/Quadratic_formula)""" 12 | 13 | tags = [Tags.math, Tags.famous] 14 | 15 | @staticmethod 16 | def sat(x: float, coeffs=[2.5, 1.3, -0.5]): 17 | """ 18 | Find any (real) solution to: a x^2 + b x + c where coeffs = [a, b, c]. 19 | For example, since x^2 - 3x + 2 has a root at 1, sat(x = 1., coeffs = [1., -3., 2.]) is True. 20 | """ 21 | a, b, c = coeffs 22 | return abs(a * x ** 2 + b * x + c) < 1e-6 23 | 24 | @staticmethod 25 | def sol(coeffs): 26 | a, b, c = coeffs 27 | if a == 0: 28 | ans = -c / b if b != 0 else 0.0 29 | else: 30 | ans = ((-b + (b ** 2 - 4 * a * c) ** 0.5) / (2 * a)) 31 | return ans 32 | 33 | @staticmethod 34 | def sol2(coeffs): 35 | a, b, c = coeffs 36 | if a == 0: 37 | ans = -c / b if b != 0 else 0.0 38 | else: 39 | ans = (-b - (b ** 2 - 4 * a * c) ** 0.5) / (2 * a) 40 | return ans 41 | 42 | def gen_random(self): 43 | x, a, b = [self.random.heavy_tail_float() for _ in range(3)] 44 | c = -(a * x ** 2 + b * x) # make sure it has a real-valued solution 45 | coeffs = [a, b, c] 46 | self.add(dict(coeffs=coeffs)) 47 | 48 | 49 | class AllQuadraticRoots(PuzzleGenerator): 50 | """See [quadratic equations](https://en.wikipedia.org/wiki/Quadratic_formula).""" 51 | 52 | tags = [Tags.math, Tags.famous] 53 | 54 | @staticmethod 55 | def sat(roots: List[float], coeffs=[1.3, -0.5]): 56 | """Find all (real) solutions to: x^2 + b x + c (i.e., factor into roots), here coeffs = [b, c]""" 57 | b, c = coeffs 58 | r1, r2 
= roots 59 | return abs(r1 + r2 + b) + abs(r1 * r2 - c) < 1e-6 60 | 61 | @staticmethod 62 | def sol(coeffs): 63 | b, c = coeffs 64 | delta = (b ** 2 - 4 * c) ** 0.5 65 | return [(-b + delta) / 2, (-b - delta) / 2] 66 | 67 | def gen_random(self): 68 | x, b = [self.random.heavy_tail_float() for _ in range(2)] 69 | c = -(x ** 2 + b * x) # make sure it has a real-valued solution 70 | coeffs = [b, c] 71 | self.add(dict(coeffs=coeffs)) 72 | 73 | 74 | class CubicRoot(PuzzleGenerator): 75 | """See [cubic equation](https://en.wikipedia.org/wiki/Cubic_formula).""" 76 | 77 | tags = [Tags.math, Tags.famous] 78 | 79 | @staticmethod 80 | def sat(x: float, coeffs=[2.0, 1.0, 0.0, 8.0]): 81 | """ 82 | Find any (real) solution to: a x^3 + b x^2 + c x + d where coeffs = [a, b, c, d]. 83 | For example, since (x-1)(x-2)(x-3) = x^3 - 6x^2 + 11x - 6, sat(x = 1., coeffs = [1., -6., 11., -6.]) is True. 84 | """ 85 | return abs(sum(c * x ** (3 - i) for i, c in enumerate(coeffs))) < 1e-6 86 | 87 | @staticmethod 88 | def sol(coeffs): 89 | a2, a1, a0 = [c / coeffs[0] for c in coeffs[1:]] 90 | p = (3 * a1 - a2 ** 2) / 3 91 | q = (9 * a1 * a2 - 27 * a0 - 2 * a2 ** 3) / 27 92 | delta = (q ** 2 + 4 * p ** 3 / 27) ** 0.5 93 | omega = (-(-1) ** (1 / 3)) 94 | for cube in [(q + delta) / 2, (q - delta) / 2]: 95 | c = cube ** (1 / 3) 96 | for w in [c, c * omega, c * omega.conjugate()]: 97 | if w != 0: 98 | x = complex(w - p / (3 * w) - a2 / 3).real 99 | if abs(sum(c * x ** (3 - i) for i, c in enumerate(coeffs))) < 1e-6: 100 | return x 101 | 102 | def gen_random(self): 103 | x, a, b, c = [self.random.heavy_tail_float() for _ in range(4)] 104 | d = -(a * x ** 3 + b * x ** 2 + c * x) # make sure it has a real-valued solution 105 | coeffs = [a, b, c, d] 106 | if self.sol(coeffs) is not None: 107 | self.add(dict(coeffs=coeffs)) 108 | 109 | 110 | class AllCubicRoots(PuzzleGenerator): 111 | """See [cubic equation](https://en.wikipedia.org/wiki/Cubic_formula).""" 112 | 113 | tags = [Tags.math, Tags.famous] 114 | 115
| @staticmethod 116 | def sat(roots: List[float], coeffs=[1.0, -2.0, -1.0]): 117 | """Find all 3 distinct real roots of x^3 + a x^2 + b x + c, i.e., factor into (x-r1)(x-r2)(x-r3). 118 | coeffs = [a, b, c]. For example, since (x-1)(x-2)(x-3) = x^3 - 6x^2 + 11x - 6, 119 | sat(roots = [1., 2., 3.], coeffs = [-6., 11., -6.]) is True. 120 | """ 121 | r1, r2, r3 = roots 122 | a, b, c = coeffs 123 | return abs(r1 + r2 + r3 + a) + abs(r1 * r2 + r1 * r3 + r2 * r3 - b) + abs(r1 * r2 * r3 + c) < 1e-6 124 | 125 | @staticmethod 126 | def sol(coeffs): 127 | a, b, c = coeffs 128 | p = (3 * b - a ** 2) / 3 129 | q = (9 * b * a - 27 * c - 2 * a ** 3) / 27 130 | delta = (q ** 2 + 4 * p ** 3 / 27) ** 0.5 131 | omega = (-(-1) ** (1 / 3)) 132 | ans = [] 133 | for cube in [(q + delta) / 2, (q - delta) / 2]: 134 | v = cube ** (1 / 3) 135 | for w in [v, v * omega, v * omega.conjugate()]: 136 | if w != 0.0: 137 | x = complex(w - p / (3 * w) - a / 3).real 138 | if abs(x ** 3 + a * x ** 2 + b * x + c) < 1e-4: 139 | if not ans or min(abs(z - x) for z in ans) > 1e-6: 140 | ans.append(x) 141 | if len(ans) == 3: 142 | return ans 143 | 144 | def gen_random(self): 145 | r1, r2, r3 = [self.random.heavy_tail_float() for _ in range(3)] 146 | coeffs = [-r1 - r2 - r3, r1 * r2 + r1 * r3 + r2 * r3, -r1 * r2 * r3] # to ensure solvability 147 | if self.sol(coeffs) is not None: 148 | self.add(dict(coeffs=coeffs)) # won't add duplicates 149 | 150 | 151 | if __name__ == "__main__": 152 | PuzzleGenerator.debug_problems() 153 | -------------------------------------------------------------------------------- /generators/compression.py: -------------------------------------------------------------------------------- 1 | """Puzzles relating to de/compression.""" 2 | 3 | from puzzle_generator import PuzzleGenerator, Tags 4 | from typing import List 5 | 6 | 7 | # See https://github.com/microsoft/PythonProgrammingPuzzles/wiki/How-to-add-a-puzzle to learn about adding puzzles 8 | 9 | 10 | def _compress_LZW(text): # 
for development 11 | index = {chr(i): i for i in range(256)} 12 | 13 | seq = [] 14 | buffer = "" 15 | for c in text: 16 | if buffer + c in index: 17 | buffer += c 18 | continue 19 | seq.append(index[buffer]) 20 | index[buffer + c] = len(index) + 1 21 | buffer = c 22 | 23 | if text != "": 24 | seq.append(index[buffer]) 25 | 26 | return seq 27 | 28 | 29 | def _decompress_LZW(seq: List[int]): # for development 30 | index = [chr(i) for i in range(256)] 31 | pieces = [""] 32 | for i in seq: 33 | pieces.append(pieces[-1] + pieces[-1][0] if i == len(index) else index[i]) 34 | index.append(pieces[-2] + pieces[-1][0]) 35 | return "".join(pieces) 36 | 37 | 38 | class LZW(PuzzleGenerator): 39 | """ 40 | We have provided a simple version of the *decompression* algorithm of 41 | [Lempel-Ziv-Welch](https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv%E2%80%93Welch) 42 | so the solution is the *compression* algorithm. 43 | """ 44 | 45 | tags = [Tags.strings, Tags.famous] 46 | 47 | 48 | # _compress_LZW("Hellooooooooooooooooooooo world!") is length-17 49 | 50 | @staticmethod 51 | def sat(seq: List[int], compressed_len=17, text="Hellooooooooooooooooooooo world!"): 52 | """ 53 | Find a (short) compression that decompresses to the given string for the provided implementation of the 54 | Lempel-Ziv decompression algorithm from https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv%E2%80%93Welch 55 | """ 56 | index = [chr(i) for i in range(256)] 57 | pieces = [""] 58 | for i in seq: 59 | pieces.append((pieces[-1] + pieces[-1][0]) if i == len(index) else index[i]) 60 | index.append(pieces[-2] + pieces[-1][0]) 61 | return "".join(pieces) == text and len(seq) <= compressed_len 62 | 63 | @staticmethod 64 | def sol(compressed_len, text): 65 | # compressed_len is ignored 66 | index = {chr(i): i for i in range(256)} 67 | seq = [] 68 | buffer = "" 69 | for c in text: 70 | if buffer + c in index: 71 | buffer += c 72 | continue 73 | seq.append(index[buffer]) 74 | index[buffer + c] = len(index) + 1 75 | 
buffer = c 76 | 77 | if text != "": 78 | seq.append(index[buffer]) 79 | 80 | return seq 81 | 82 | def gen(self, _target_num_instances): 83 | self.add({"text": "", "compressed_len": 0}) 84 | self.add({"text": "c" * 1000, "compressed_len": len(_compress_LZW("c" * 1000))}) 85 | 86 | def gen_random(self): 87 | max_len = self.random.choice([10, 100, 1000]) 88 | text = self.random.pseudo_word(0, max_len) 89 | self.add({"text": text, "compressed_len": len(_compress_LZW(text))}) 90 | 91 | # Removed this puzzle because the puzzle statement would give away the solution to LZW, haha! 92 | # class LZW_decompress(PuzzleGenerator): 93 | # """We have provided a simple version of the 94 | # [Lempel-Ziv-Welch](https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv%E2%80%93Welch) 95 | # and the solution is the *decompression* algorithm. 96 | # """ 97 | # 98 | # @staticmethod 99 | # def sat(text: str, seq=[72, 101, 108, 108, 111, 32, 119, 111, 114, 100, 262, 264, 266, 263, 265, 33]): 100 | # """ 101 | # Find a string that compresses to the target sequence for the provided implementation of the 102 | # Lempel-Ziv algorithm from https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv%E2%80%93Welch 103 | # """ 104 | # index = {chr(i): i for i in range(256)} 105 | # seq2 = [] 106 | # buffer = "" 107 | # for c in text: 108 | # if buffer + c in index: 109 | # buffer += c 110 | # continue 111 | # seq2.append(index[buffer]) 112 | # index[buffer + c] = len(index) + 1 113 | # buffer = c 114 | # 115 | # if text != "": 116 | # seq2.append(index[buffer]) 117 | # 118 | # return seq2 == seq 119 | # 120 | # @staticmethod 121 | # def sol(seq): 122 | # index = [chr(i) for i in range(256)] 123 | # pieces = [""] 124 | # for i in seq: 125 | # pieces.append(pieces[-1] + pieces[-1][0] if i == len(index) else index[i]) 126 | # index.append(pieces[-2] + pieces[-1][0]) 127 | # return "".join(pieces) 128 | # 129 | # def gen(self, _target_num_instances): 130 | # for s in ['', 'a', 'b' * 1000, 'ab' * 1000 + '!']: 131 | # 
self.add({"seq": _compress_LZW(s)}) 132 | # 133 | # def gen_random(self): 134 | # max_len = self.random.choice([10, 100, 1000]) 135 | # text = self.random.pseudo_word(0, max_len) 136 | # self.add({"seq": _compress_LZW(text)}) 137 | 138 | 139 | class PackingHam(PuzzleGenerator): 140 | """ 141 | This packing problem is a [classic problem](https://en.wikipedia.org/wiki/Sphere_packing#Other_spaces) 142 | in coding theory. 143 | """ 144 | 145 | tags = [Tags.strings, Tags.famous] 146 | 147 | @staticmethod 148 | def sat(words: List[str], num=100, bits=100, dist=34): 149 | """Pack a certain number of binary strings so that they have a minimum hamming distance between each other.""" 150 | assert len(words) == num and all(len(word) == bits and set(word) <= {"0", "1"} for word in words) 151 | return all(sum([a != b for a, b in zip(words[i], words[j])]) >= dist for i in range(num) for j in range(i)) 152 | 153 | @staticmethod 154 | def sol(num, bits, dist): 155 | import random # key insight, use randomness!
156 | r = random.Random(0) 157 | while True: 158 | seqs = [r.getrandbits(bits) for _ in range(num)] 159 | if all(bin(seqs[i] ^ seqs[j]).count("1") >= dist for i in range(num) for j in range(i)): 160 | return [bin(s)[2:].rjust(bits, '0') for s in seqs] 161 | 162 | def gen_random(self): 163 | bits = self.random.randrange(1, self.random.choice([10, 100])) 164 | num = self.random.randrange(2, self.random.choice([10, 100])) 165 | 166 | def score(seqs): 167 | return min(bin(seqs[i] ^ seqs[j]).count("1") for i in range(num) for j in range(i)) 168 | 169 | # best of 5 170 | seqs = min([[self.random.getrandbits(bits) for _ in range(num)] for _ in range(5)], key=score) 171 | dist = score(seqs) 172 | if dist > 0: 173 | self.add(dict(num=num, bits=bits, dist=dist)) 174 | 175 | 176 | if __name__ == "__main__": 177 | PuzzleGenerator.debug_problems() 178 | -------------------------------------------------------------------------------- /generators/lattices.py: -------------------------------------------------------------------------------- 1 | """Lattice problems with and without noise""" 2 | 3 | from puzzle_generator import PuzzleGenerator, Tags 4 | from typing import List 5 | 6 | 7 | # See https://github.com/microsoft/PythonProgrammingPuzzles/wiki/How-to-add-a-puzzle to learn about adding puzzles 8 | 9 | 10 | class LearnParity(PuzzleGenerator): 11 | """Parity learning (Gaussian elimination) 12 | 13 | The canonical solution to this 14 | [Parity learning problem](https://en.wikipedia.org/w/index.php?title=Parity_learning) 15 | is to use 16 | [Gaussian Elimination](https://en.wikipedia.org/w/index.php?title=Gaussian_elimination). 17 | 18 | The vectors are encoded as binary integers for succinctness. 
19 | """ 20 | 21 | @staticmethod 22 | def sat(inds: List[int], vecs=[169, 203, 409, 50, 37, 479, 370, 133, 53, 159, 161, 367, 474, 107, 82, 447, 385]): 23 | """ 24 | Parity learning: Given binary vectors in a subspace, find the secret set S of indices such that: 25 | $\\sum_{i \in S} x_i = 1 (mod 2)$ 26 | """ 27 | return all(sum((v >> i) & 1 for i in inds) % 2 == 1 for v in vecs) 28 | 29 | @staticmethod 30 | def sol(vecs): 31 | # Gaussian elimination 32 | d = 0 # decode vectors into arrays 33 | m = max(vecs) 34 | while m: 35 | m >>= 1 36 | d += 1 37 | vecs = [[(n >> i) & 1 for i in range(d)] for n in vecs] 38 | ans = [] 39 | pool = [[0] * (d + 1) for _ in range(d)] + [v + [1] for v in vecs] 40 | for i in range(d): 41 | pool[i][i] = 1 42 | 43 | for i in range(d): # zero out bit i 44 | for v in pool[d:]: 45 | if v[i] == 1: 46 | break 47 | if v[i] == 0: 48 | v = pool[i] 49 | assert v[i] == 1 # found a vector with v[i] = 1, subtract it off from those with a 1 in the ith coordinate 50 | w = v[:] 51 | for v in pool: 52 | if v[i] == 1: 53 | for j in range(d + 1): 54 | v[j] ^= w[j] 55 | 56 | return [i for i in range(d) if pool[i][-1]] 57 | 58 | @staticmethod 59 | def rand_parity_problem(rand, d=63): 60 | secret = rand.sample(range(d), d // 2) 61 | num_vecs = d + 9 62 | vecs = [[rand.randrange(2) for _ in range(d)] for i in range(num_vecs)] 63 | for v in vecs: 64 | v[secret[0]] = (1 + sum([v[i] for i in secret[1:]])) % 2 65 | vecs = [sum(1 << i for i, b in enumerate(v) if b) for v in vecs] # encode into ints 66 | return vecs 67 | 68 | def gen(self, target_num_instances): 69 | vecs = self.rand_parity_problem(self.random, d=63) 70 | self.add(dict(vecs=vecs), multiplier=10) 71 | 72 | def gen_random(self): 73 | d = self.random.randrange(2, self.random.choice([5, 10, 20, 100])) 74 | vecs = self.rand_parity_problem( 75 | self.random, 76 | d=d, 77 | ) 78 | self.add(dict(vecs=vecs), multiplier=10 if d > 9 else 1) 79 | 80 | 81 | class LearnParityWithNoise(PuzzleGenerator): 82 | 
"""Learn parity with noise (*unsolved*) 83 | 84 | The fastest known algorithm to this 85 | [Parity learning problem](https://en.wikipedia.org/w/index.php?title=Parity_learning) 86 | runs in time $2^(d/(log d))$ 87 | 88 | The example puzzle has small dimension so is easily solvable, but other instances are much harder. 89 | """ 90 | 91 | @staticmethod 92 | def sat(inds: List[int], vecs=[26, 5, 32, 3, 15, 18, 31, 13, 24, 25, 34, 5, 15, 24, 16, 13, 0, 27, 37]): 93 | """ 94 | Learning parity with noise: Given binary vectors, find the secret set $S$ of indices such that, for at least 95 | 3/4 of the vectors, $$sum_{i \in S} x_i = 1 (mod 2)$$ 96 | """ 97 | return sum(sum((v >> i) & 1 for i in inds) % 2 for v in vecs) >= len(vecs) * 3 / 4 98 | 99 | @staticmethod 100 | def sol(vecs): 101 | # brute force 102 | d = 0 # decode vectors into arrays 103 | m = max(vecs) 104 | while m: 105 | m >>= 1 106 | d += 1 107 | vecs = [[(n >> i) & 1 for i in range(d)] for n in vecs] 108 | 109 | import random 110 | rand = random.Random(0) 111 | target = (len(vecs) * 3) // 4 112 | max_attempts = 10 ** 5 113 | for _ in range(max_attempts): 114 | ans = [i for i in range(d) if rand.randrange(2)] 115 | if sum(sum(v[i] for i in ans) % 2 for v in vecs) >= len(vecs) * 3 / 4: 116 | return ans 117 | 118 | @staticmethod 119 | def rand_parity_problem(rand, d=63): 120 | secret = rand.sample(range(d), d // 2) 121 | num_vecs = 2 * d + 5 122 | vecs = [[rand.randrange(2) for _ in range(d)] for i in range(num_vecs)] 123 | for v in vecs: 124 | v[secret[0]] = (1 + sum([v[i] for i in secret[1:]])) % 2 125 | mistakes = rand.sample(vecs, int(len(vecs) * rand.random() * 1 / 4)) 126 | for v in mistakes: 127 | v[secret[0]] ^= 1 # flip bit in mistakes 128 | vecs = [sum(1 << i for i, b in enumerate(v) if b) for v in vecs] # encode into ints 129 | return vecs 130 | 131 | def gen(self, target_num_instances): 132 | vecs = self.rand_parity_problem(self.random, d=63) 133 | self.add(dict(vecs=vecs), test=False, 
multiplier=1000) 134 | 135 | def gen_random(self): 136 | d = self.random.randrange(2, self.random.choice([11, 100])) # number of dimensions 137 | vecs = self.rand_parity_problem( 138 | self.random, 139 | d=d 140 | ) 141 | self.add(dict(vecs=vecs), test=d < 19, multiplier=1000 if d > 40 else 30 if d >= 19 else 1) 142 | 143 | if __name__ == "__main__": 144 | PuzzleGenerator.debug_problems() 145 | -------------------------------------------------------------------------------- /generators/probability.py: -------------------------------------------------------------------------------- 1 | """Probability problems""" 2 | 3 | from puzzle_generator import PuzzleGenerator, Tags 4 | from typing import List 5 | 6 | 7 | # See https://github.com/microsoft/PythonProgrammingPuzzles/wiki/How-to-add-a-puzzle to learn about adding puzzles 8 | 9 | 10 | class BirthdayParadox(PuzzleGenerator): 11 | """ 12 | Adaptation of the classic 13 | [Birthday Problem](https://en.wikipedia.org/wiki/Birthday_problem (Mathematical Problems category)). 14 | 15 | The year length is year_len (365 is earth, while Neptune year is 60,182). 16 | """ 17 | 18 | @staticmethod 19 | def sat(n: int, year_len=365): 20 | """Find n such that the probability of two people having the same birthday in a group of n is near 1/2.""" 21 | prob = 1.0 22 | for i in range(n): 23 | prob *= (year_len - i) / year_len 24 | return (prob - 0.5) ** 2 <= 1/year_len 25 | 26 | @staticmethod 27 | def sol(year_len): 28 | n = 1 29 | distinct_prob = 1.0 30 | best = (0.5, 1) # (difference between probability and 1/2, n) 31 | while distinct_prob > 0.5: 32 | distinct_prob *= (year_len - n) / year_len 33 | n += 1 34 | best = min(best, (abs(0.5 - distinct_prob), n)) 35 | 36 | return best[1] 37 | 38 | def safe_add(self, **inputs): 39 | if self.sat(self.sol(**inputs), **inputs): 40 | self.add(inputs) 41 | 42 | def gen(self, target_num_instances): 43 | self.safe_add(year_len=60182) # Neptune year! 
44 | for year_len in range(2, target_num_instances): 45 | self.safe_add(year_len=year_len) 46 | 47 | 48 | 49 | class BirthdayParadoxMonteCarlo(BirthdayParadox): 50 | """A slower, Monte Carlo version of the above Birthday Paradox problem.""" 51 | 52 | @staticmethod 53 | def sat(n: int, year_len=365): 54 | """Find n such that the probability of two people having the same birthday in a group of n is near 1/2.""" 55 | import random 56 | random.seed(0) 57 | K = 1000 # number of samples 58 | prob = sum(len({random.randrange(year_len) for i in range(n)}) < n for j in range(K)) / K 59 | return (prob - 0.5) ** 2 <= year_len 60 | 61 | 62 | 63 | 64 | 65 | class BallotProblem(PuzzleGenerator): 66 | """ 67 | See the [Wikipedia article](https://en.wikipedia.org/wiki/Bertrand%27s_ballot_theorem) 68 | or [Addario-Berry L., Reed B.A. (2008) Ballot Theorems, Old and New. In: Gyori E., Katona G.O.H., Lovász L., 69 | Sági G. (eds) Horizons of Combinatorics. Bolyai Society Mathematical Studies, vol 17. 70 | Springer, Berlin, Heidelberg.](https://doi.org/10.1007/978-3-540-77200-2_1) 71 | """ 72 | 73 | @staticmethod 74 | def sat(counts: List[int], target_prob=0.5): 75 | """ 76 | Suppose a list of m 1's and n -1's is permuted at random. 77 | What is the probability that all of the cumulative sums are positive? 78 | The goal is to find counts = [m, n] that make the probability of the ballot problem close to target_prob.
79 | """ 80 | m, n = counts # m = num 1's, n = num -1's 81 | probs = [1.0] + [0.0] * n # probs[n] is probability for current m, starting with m = 1 82 | for i in range(2, m + 1): # compute probs using dynamic programming for m = i 83 | old_probs = probs 84 | probs = [1.0] + [0.0] * n 85 | for j in range(1, min(n + 1, i)): 86 | probs[j] = ( 87 | j / (i + j) * probs[j - 1] # last element is a -1 so use probs 88 | + 89 | i / (i + j) * old_probs[j] # last element is a 1 so use old_probs, m = i - 1 90 | ) 91 | return abs(probs[n] - target_prob) < 1e-6 92 | 93 | @staticmethod 94 | def sol(target_prob): 95 | for m in range(1, 10000): 96 | n = round(m * (1 - target_prob) / (1 + target_prob)) 97 | if abs(target_prob - (m - n) / (m + n)) < 1e-6: 98 | return [m, n] 99 | 100 | def gen_random(self): 101 | m = self.random.randrange(1, self.random.choice([10, 100, 200, 300, 400, 500, 1000])) 102 | n = self.random.randrange(1, m + 1) 103 | target_prob = (m - n) / (m + n) 104 | self.add(dict(target_prob=target_prob)) 105 | 106 | 107 | class BinomialProbabilities(PuzzleGenerator): 108 | """See [Binomial distribution](https://en.wikipedia.org/wiki/Binomial_distribution)""" 109 | 110 | @staticmethod 111 | def sat(counts: List[int], p=0.5, target_prob=1 / 16.0): 112 | """Find counts = [a, b] so that the probability of a H's and b T's among a + b coin flips is ~ target_prob.""" 113 | from itertools import product 114 | a, b = counts 115 | n = a + b 116 | prob = (p ** a) * ((1-p) ** b) 117 | tot = sum([prob for sample in product([0, 1], repeat=n) if sum(sample) == a]) 118 | return abs(tot - target_prob) < 1e-6 119 | 120 | @staticmethod 121 | def sol(p, target_prob): 122 | probs = [1.0] 123 | q = 1 - p 124 | while len(probs) < 20: 125 | probs = [(p * a + q * b) for a, b in zip([0] + probs, probs + [0])] 126 | answers = [i for i, p in enumerate(probs) if abs(p - target_prob) < 1e-6] 127 | if answers: 128 | return [answers[0], len(probs) - 1 - answers[0]] 129 | 130 | def gen_random(self): 
131 | probs = [1.0] 132 | p = self.random.random() 133 | q = 1 - p 134 | for n in range(self.random.randrange(1, 11)): 135 | probs = [(p * a + q * b) for a, b in zip([0] + probs, probs + [0])] 136 | target_prob = self.random.choice(probs) 137 | self.add(dict(p=p, target_prob=target_prob)) 138 | 139 | 140 | 141 | class ExponentialProbability(PuzzleGenerator): 142 | """See [Exponential distribution](https://en.wikipedia.org/wiki/Exponential_distribution)""" 143 | 144 | @staticmethod 145 | def sat(p_stop: float, steps=10, target_prob=0.5): 146 | """ 147 | Find p_stop so that the probability of stopping in steps or fewer time steps is the given target_prob if you 148 | stop each step with probability p_stop 149 | """ 150 | prob = sum(p_stop*(1-p_stop)**t for t in range(steps)) 151 | return abs(prob - target_prob) < 1e-6 152 | 153 | @staticmethod 154 | def sol(steps, target_prob): 155 | return 1 - (1 - target_prob) ** (1.0/steps) 156 | 157 | def gen_random(self): 158 | steps = self.random.randrange(1, 100) 159 | target_prob = self.random.random() 160 | self.add(dict(steps=steps, target_prob=target_prob)) 161 | 162 | 163 | 164 | if __name__ == "__main__": 165 | PuzzleGenerator.debug_problems() 166 | -------------------------------------------------------------------------------- /generators/tutorial.py: -------------------------------------------------------------------------------- 1 | """ 2 | A few example puzzles that were presented with solutions to participants of the study. 
3 | """ 4 | 5 | from puzzle_generator import PuzzleGenerator, Tags 6 | from typing import List 7 | 8 | 9 | # See https://github.com/microsoft/PythonProgrammingPuzzles/wiki/How-to-add-a-puzzle to learn about adding puzzles 10 | 11 | 12 | class Tutorial1(PuzzleGenerator): 13 | @staticmethod 14 | def sat(s: str): 15 | """Find a string that when concatenated onto 'Hello ' gives 'Hello world'.""" 16 | return "Hello " + s == "Hello world" 17 | 18 | @staticmethod 19 | def sol(): 20 | return "world" 21 | 22 | 23 | class Tutorial2(PuzzleGenerator): 24 | @staticmethod 25 | def sat(s: str): 26 | """Find a string that when reversed and concatenated onto 'Hello ' gives 'Hello world'.""" 27 | return "Hello " + s[::-1] == "Hello world" 28 | 29 | @staticmethod 30 | def sol(): 31 | return "world"[::-1] 32 | 33 | 34 | class Tutorial3(PuzzleGenerator): 35 | @staticmethod 36 | def sat(x: List[int]): 37 | """Find a list of two integers whose sum is 3.""" 38 | return len(x) == 2 and sum(x) == 3 39 | 40 | @staticmethod 41 | def sol(): 42 | return [1, 2] 43 | 44 | 45 | class Tutorial4(PuzzleGenerator): 46 | @staticmethod 47 | def sat(s: List[str]): 48 | """Find a list of 1000 distinct strings which each have more 'a's than 'b's and at least one 'b'.""" 49 | return len(set(s)) == 1000 and all((x.count("a") > x.count("b")) and ('b' in x) for x in s) 50 | 51 | @staticmethod 52 | def sol(): 53 | return ["a" * (i + 2) + "b" for i in range(1000)] 54 | 55 | 56 | class Tutorial5(PuzzleGenerator): 57 | @staticmethod 58 | def sat(n: int): 59 | """Find an integer whose perfect square begins with 123456789 in its decimal representation.""" 60 | return str(n * n).startswith("123456789") 61 | 62 | @staticmethod 63 | def sol(): 64 | return int(int("123456789" + "0" * 9) ** 0.5) + 1 65 | 66 | 67 | # Not clear this last one is necessary/helpful 68 | # # class Tutorial6(Problem): 69 | # """Find a string corresponding to a decimal number whose negation is 1337.""" 70 | # 71 | # @staticmethod 72 | # def 
sat(s: str): 73 | # return -1 * int(s) == 1337 74 | # 75 | # @staticmethod 76 | # def sol(): 77 | # return str(-1337) 78 | 79 | if __name__ == "__main__": 80 | PuzzleGenerator.debug_problems() 81 | -------------------------------------------------------------------------------- /notebooks/Demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "deletable": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# Please run this cell first.\n", 12 | "from demo.demo import next_puzzle, cur_puzzle, puzzle, give_up" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "# Instructions\n", 20 | "\n", 21 | "* This [Jupyter](https://jupyter.org/) notebook is a demo simulating the study used for the [**Programming Puzzles**](https://arxiv.org/abs/2106.05784) paper (with some modifications to run locally).\n", 22 | "* You can reuse the same notebook cells to solve all puzzles (run cell with Ctrl+Enter), or open new cells as you advance (Shift+Enter).\n", 23 | "\n", 24 | "\n", 25 | "* **Visit our repository** to explore the full dataset and **contribute your own new puzzles**: https://github.com/microsoft/PythonProgrammingPuzzles.\n", 26 | "* Share this demo with friends by sending this link: https://aka.ms/python_puzzles_study\n", 27 | "* For a shorter intro, and to see which puzzles the AI baselines solved, visit: https://aka.ms/python_puzzles\n", 28 | "\n", 29 | "## Overview\n", 30 | "\n", 31 | "\n", 32 | "* The first 3 problems are \"practice\" and the time you take will not be counted. This is a good chance to see how the system works.\n", 33 | "* Puzzles are defined by `def puzzle(...)`. 
For each puzzle, you will try to find an input `x` which makes `puzzle(x)` return `True`.\n", 34 | "* Type `next_puzzle()` when you are ready to start the first problem or to advance to the next problem.\n", 35 | "* There is **no option to revisit a puzzle** and once you call `next_puzzle()` the clock starts ticking (you have up to 6 minutes per puzzle).\n", 36 | "* **If you get lost, call `cur_puzzle()`** to see the current puzzle you are on (and time bar).\n", 37 | "\n", 38 | "## Timing\n", 39 | "\n", 40 | "* You have up to 6 minutes for each puzzle.\n", 41 | "* If you do not solve a problem in **6 minutes**, move to the next puzzle by typing `next_puzzle()`.\n", 42 | "* If you would like to give up and skip to the next puzzle without waiting, you can call `give_up()`.\n", 43 | "\n", 44 | "\n", 45 | "## Possible issues\n", 46 | "\n", 47 | "* Remember to use `cur_puzzle()` if you lost the code of the current puzzle.\n", 48 | "* For any feedback, please visit our [GitHub repository](https://github.com/microsoft/PythonProgrammingPuzzles) and reach out to us by email.\n", 49 | "\n", 50 | "\n", 51 | "## Summary of functions\n", 52 | "\n", 53 | "| Function \t| Description \t|\n", 54 | "|:----------------------\t|:------------------------------------------------------------------------------------------\t|\n", 55 | "|`next_puzzle()` \t| Start the next puzzle (call only when you are ready to start! no revisiting) \t|\n", 56 | "|`cur_puzzle()` \t| Present the current puzzle (useful if you got lost or accidentally overridden `puzzle()`) \t|\n", 57 | "|`puzzle(...)` \t| Submit a solution to the current puzzle \t|\n", 58 | "|`give_up()` \t| Give up and skip to the next puzzle *before* 6 minutes have passed. Please avoid this option if possible. |" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# The first 3 puzzles are warmup. 
Begin the warmup part by running this cell.\n", 68 | "\n", 69 | "next_puzzle()" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "# Solve the first puzzle by running this cell.\n", 79 | "\n", 80 | "puzzle(\"world\")" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "# when you are ready to continue, run this cell.\n", 90 | "\n", 91 | "next_puzzle()" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "# You're on your own. Have fun.\n" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [] 116 | } 117 | ], 118 | "metadata": { 119 | "kernelspec": { 120 | "display_name": "Python 3", 121 | "language": "python", 122 | "name": "python3" 123 | }, 124 | "language_info": { 125 | "codemirror_mode": { 126 | "name": "ipython", 127 | "version": 3 128 | }, 129 | "file_extension": ".py", 130 | "mimetype": "text/x-python", 131 | "name": "python", 132 | "nbconvert_exporter": "python", 133 | "pygments_lexer": "ipython3", 134 | "version": "3.8.8" 135 | } 136 | }, 137 | "nbformat": 4, 138 | "nbformat_minor": 4 139 | } -------------------------------------------------------------------------------- /notebooks/demo/study_puzzles.py: -------------------------------------------------------------------------------- 1 | # provides get_puzzles 2 | 3 | # parts separated by ======== (exactly 8) 4 | # puzzles by ----'s (exactly 4) 5 | 6 | 7 | def get_puzzles(): 8 | """ 9 | Creates the puzzles for the study by splitting raw_puzzles into parts on '========' and each part into puzzles on '----' (one part per entry of part_names) 10 | 11 | :return: list of {"src": str, "name": str, "part": str} 12 | """ 13 | part_names =
["WARM UP", "PART 1/3", "PART 2/3", "PART 3/3"] 14 | 15 | raw_puzzles = ''' 16 | 17 | def puzzle(s: str): 18 | """ 19 | Warmup problem. 20 | """ 21 | return "Hello " + s == 'Hello world' 22 | 23 | ---- 24 | 25 | def puzzle(n: int): 26 | """ 27 | Hint: puzzle(111111111) works. 28 | """ 29 | return str(n * n).startswith("123456789") 30 | 31 | ---- 32 | 33 | def puzzle(x: str): 34 | """ 35 | Hint: note that the input should be a string. 36 | """ 37 | return -1 * int(x) == 1337 38 | 39 | 40 | ======== 41 | 42 | def puzzle(s: str): 43 | return s.count('o') == 1000 and s.count('oo') == 0 44 | 45 | 46 | ---- 47 | 48 | 49 | def puzzle(s: str): 50 | return s.count('o') == 1000 and s.count('oo') == 100 and s.count('ho') == 801 51 | 52 | ---- 53 | 54 | def puzzle(x: List[int]): 55 | return sorted(x) == list(range(999)) and all(x[i] != i for i in range(len(x))) 56 | 57 | ---- 58 | 59 | def puzzle(x: List[int]): 60 | return len(x) == 10 and x.count(x[3]) == 2 61 | 62 | ---- 63 | 64 | 65 | def puzzle(x: List[int]): 66 | return all([x.count(i) == i for i in range(10)]) 67 | 68 | ---- 69 | 70 | def puzzle(n: int): 71 | return n % 123 == 4 and n > 10**10 72 | 73 | 74 | ---- 75 | 76 | def puzzle(s: str): 77 | return str(8**2888).count(s) > 8 and len(s) == 3 78 | 79 | ---- 80 | 81 | def puzzle(s: List[str]): 82 | return s[1234] in s[1235] and s[1234] != s[1235] 83 | 84 | ---- 85 | 86 | def puzzle(x: List[int]): 87 | return ["The quick brown fox jumps over the lazy dog"[i] for i in x] == list("The five boxing wizards jump quickly") 88 | 89 | ---- 90 | 91 | def puzzle(s: str): 92 | return s in str(8**1818) and s==s[::-1] and len(s)>11 93 | 94 | ======== 95 | 96 | def puzzle(x: List[str]): 97 | return min(x) == max(x) == str(len(x)) 98 | 99 | 100 | ---- 101 | 102 | def puzzle(x: List[int]): 103 | return all(a + b == 9 for a, b in zip([4] + x, x)) and len(x) == 1000 104 | 105 | 106 | ---- 107 | 108 | 109 | def puzzle(x: float): 110 | return str(x - 3.1415).startswith("123.456") 111 | 112
| ---- 113 | 114 | def puzzle(x: List[int]): 115 | return all([sum(x[:i]) == i for i in range(20)]) 116 | 117 | ---- 118 | 119 | def puzzle(x: List[int]): 120 | return all(sum(x[:i]) == 2 ** i - 1 for i in range(20)) 121 | 122 | ---- 123 | 124 | def puzzle(x: str): 125 | return float(x) + len(x) == 4.5 126 | 127 | ---- 128 | 129 | def puzzle(n: int): 130 | return len(str(n + 1000)) > len(str(n + 1001)) 131 | 132 | ---- 133 | 134 | def puzzle(x: List[str]): 135 | return [s + t for s in x for t in x if s!=t] == 'berlin berger linber linger gerber gerlin'.split() 136 | 137 | ---- 138 | 139 | def puzzle(x: Set[int]): 140 | return {i + j for i in x for j in x} == {0, 1, 2, 3, 4, 5, 6, 17, 18, 19, 20, 34} 141 | 142 | ---- 143 | 144 | 145 | def puzzle(x: List[int]): 146 | return all(b in {a-1, a+1, 3*a} for a, b in zip([0] + x, x + [128])) 147 | 148 | 149 | ======== 150 | 151 | def puzzle(x: List[int]): 152 | return all([x[i] != x[i + 1] for i in range(10)]) and len(set(x)) == 3 153 | 154 | ---- 155 | 156 | def puzzle(x: str): 157 | return x[::2] in x and len(set(x)) == 5 158 | 159 | ---- 160 | 161 | def puzzle(x: List[str]): 162 | return tuple(x) in zip('dee', 'doo', 'dah!') 163 | 164 | ---- 165 | 166 | def puzzle(x: List[int]): 167 | return x.count(17) == 3 and x.count(3) >= 2 168 | 169 | 170 | ---- 171 | 172 | def puzzle(s: str): 173 | return sorted(s)==sorted('Permute me true') and s==s[::-1] 174 | 175 | 176 | ---- 177 | def puzzle(x: List[str]): 178 | return "".join(x) == str(8**88) and all(len(s)==8 for s in x) 179 | 180 | 181 | ---- 182 | 183 | def puzzle(x: List[int]): 184 | return x[x[0]] != x[x[1]] and x[x[x[0]]] == x[x[x[1]]] 185 | 186 | ---- 187 | 188 | def puzzle(x: Set[int]): 189 | return all(i in range(1000) and abs(i-j) >= 10 for i in x for j in x if i != j) and len(x)==100 190 | 191 | ---- 192 | 193 | 194 | def puzzle(x: Set[int]): 195 | return all(i in range(1000) and abs(i*i - j*j) >= 10 for i in x for j in x if i != j) and len(x) > 995 196 | 197 | ----
198 | 199 | def puzzle(x: List[int]): 200 | return all([123*x[i] % 1000 < 123*x[i+1] % 1000 and x[i] in range(1000) for i in range(20)]) 201 | 202 | 203 | 204 | 205 | ''' 206 | 207 | parts = [[src.strip() for src in part.split("----")] for part in raw_puzzles.split("========")] 208 | assert len(part_names) == len(parts) 209 | ans = [] 210 | for part_name, part in zip(part_names, parts): 211 | per_part_num = 1 212 | for src in part: 213 | ans.append({"src": src, 214 | "part": part_name, 215 | "name": f"PUZZLE {per_part_num}/{len(part)}"}) 216 | per_part_num += 1 217 | return ans 218 | 219 | 220 | -------------------------------------------------------------------------------- /notebooks/fireworks.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/notebooks/fireworks.gif -------------------------------------------------------------------------------- /puzzles/split.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | "Abbreviate", 4 | "AllQuadraticRoots", 5 | "AnyPath", 6 | "AnyTriangle", 7 | "ArithmeticSequence", 8 | "BallotProblem", 9 | "BasicStrCounts", 10 | "BillSums", 11 | "BirthdayParadoxMonteCarlo", 12 | "BoxVolume", 13 | "CapitalizeFirstLetter", 14 | "CenteredString", 15 | "CheckersPosition", 16 | "ClockAngle", 17 | "CollatzCycleUnsolved", 18 | "CollatzDelay", 19 | "CollatzGeneralizedUnsolved", 20 | "CombinationLock", 21 | "CombinationLockObfuscated", 22 | "CommonCase", 23 | "CompareInAnyCase", 24 | "CompleteParens", 25 | "ConcatStrings", 26 | "Conway99", 27 | "Count47", 28 | "CumulativeSum", 29 | "Dada", 30 | "DecreasingCountComparison", 31 | "DistinctDigits", 32 | "DistinctOddSum", 33 | "DominoTile", 34 | "Easy63", 35 | "EasySum", 36 | "EasyTwos", 37 | "EightQueensOrFewer", 38 | "EvenPath", 39 | "FermatsLastTheorem", 40 | "FindProductiveList",
"FindRepeats", 42 | "FivePowers", 43 | "FloatNegSquareRoot", 44 | "FloatSquareRoot", 45 | "FloatWithDecimalValue", 46 | "FourSquares", 47 | "GCD", 48 | "GCD_multi", 49 | "GimmeChars", 50 | "GraphIsomorphism", 51 | "HalfPairs", 52 | "Harder63", 53 | "HelloWorld", 54 | "IfCases", 55 | "IfProblemWithOr", 56 | "IncDec", 57 | "IntDiv", 58 | "IntDiv2", 59 | "IntMul", 60 | "IntNeg", 61 | "IntSquareRoot", 62 | "IntSub", 63 | "IntSub2", 64 | "IntSum", 65 | "InvertIndices", 66 | "InvertPermutation", 67 | "Kirkman", 68 | "KnightsTour", 69 | "LCM", 70 | "LCM_multi", 71 | "LearnParity", 72 | "LearnParityWithNoise", 73 | "Lehmer", 74 | "LineIntersection", 75 | "ListAt", 76 | "ListDistinctSum", 77 | "ListIn", 78 | "ListMul", 79 | "ListNegAt", 80 | "ListPosSum", 81 | "ListSetLen", 82 | "ListSlice", 83 | "LongestSubsetString", 84 | "MatchingMarkers", 85 | "MaxConsecutiveProduct", 86 | "MaxConsecutiveSum", 87 | "MaxDelta", 88 | "MaybeReversed", 89 | "MinConsecutiveSum", 90 | "MinRotations", 91 | "MonkeyAndCoconuts", 92 | "Nash", 93 | "NecklaceSplit", 94 | "Nim", 95 | "No3Colinear", 96 | "NoRelativePrimes", 97 | "OddPath", 98 | "OnesAndTwos", 99 | "PandigitalSquare", 100 | "PenultimateRevString", 101 | "PenultimateString", 102 | "PostageStamp", 103 | "QuadraticRoot", 104 | "RepeatDec", 105 | "ReverseCat", 106 | "ReverseLifeStep", 107 | "RockPaperScissors", 108 | "SameDifferent", 109 | "ShortestPath", 110 | "Spaceship", 111 | "SquareTiles", 112 | "SquaringTheSquare", 113 | "Sstriiinggssuubb", 114 | "StrAdd", 115 | "StrCount", 116 | "StrIn", 117 | "StrIn2", 118 | "StrIndex", 119 | "StrIndex2", 120 | "StrJoiner", 121 | "StrLen", 122 | "StrParts", 123 | "StrSlice", 124 | "Study_1", 125 | "Study_11", 126 | "Study_13", 127 | "Study_15", 128 | "Study_17", 129 | "Study_19", 130 | "Study_21", 131 | "Study_23", 132 | "Study_25", 133 | "Study_27", 134 | "Study_29", 135 | "Study_3", 136 | "Study_5", 137 | "Study_7", 138 | "Study_9", 139 | "SublistSum", 140 | "SubstrCount", 141 | "SumOfDigits", 
142 | "TotalDifference", 143 | "Triple0", 144 | "TripleDouble", 145 | "Tutorial1", 146 | "Tutorial2", 147 | "Tutorial3", 148 | "Tutorial4", 149 | "Tutorial5", 150 | "UNSOLVED_UncrossedKnightsPath", 151 | "UncrossedKnightsPath", 152 | "UnweightedShortestPath", 153 | "VerbalArithmetic", 154 | "VowelDrop", 155 | "WaterPouring", 156 | "ZipStr", 157 | "Znam" 158 | ], 159 | "test": [ 160 | "AllCubicRoots", 161 | "AllPandigitalSquares", 162 | "AllPrefixes", 163 | "AlternatingFactorials", 164 | "AntiShuffle", 165 | "AnyEdge", 166 | "ArrayDiff", 167 | "BackWorlds", 168 | "BackwardsDigits", 169 | "BiPermutations", 170 | "BigOdds", 171 | "BiggestEven", 172 | "BiggestK", 173 | "Binarize", 174 | "BinaryAverage", 175 | "BinarySort", 176 | "BinaryStrXOR", 177 | "BinomialProbabilities", 178 | "BirthdayParadox", 179 | "BitSum", 180 | "BooleanPythagoreanTriples", 181 | "CardGame24", 182 | "CatStrings", 183 | "CeilingSquares", 184 | "CertifiedGCD", 185 | "ChangeBase", 186 | "CharCounts", 187 | "CharSum", 188 | "ClosestInteger", 189 | "ClosestPalindrome", 190 | "CommonNumbers", 191 | "CompleteSplit", 192 | "ConsonantFilter", 193 | "CubeRoot", 194 | "CubicRoot", 195 | "CumulativeSums", 196 | "DateDiff", 197 | "Dedup", 198 | "DeepestParens", 199 | "DelPalindrome", 200 | "Derivative", 201 | "DiffChars", 202 | "DiscreteLog", 203 | "DistinctChars", 204 | "Drops", 205 | "EngineerNumbers", 206 | "EvaluateOperators", 207 | "Even4Sum", 208 | "EvenBetween", 209 | "EvenOddDigits", 210 | "EvenOddSum", 211 | "EvenPalindromeNumbers", 212 | "EvenWords", 213 | "ExpandSpaces", 214 | "ExponentialCoinMoves", 215 | "ExponentialProbability", 216 | "Factor47", 217 | "FactorString", 218 | "Factoring", 219 | "FermatComposites", 220 | "Fib3", 221 | "Fib4", 222 | "Fibonacci", 223 | "FilenameOK", 224 | "FilterInts", 225 | "FindBored", 226 | "FindCloseElements", 227 | "FindClosePair", 228 | "FindContainers", 229 | "FindExtensions", 230 | "FindHomogeneousSubstring", 231 | "FindPositives", 232 | "FindVowels", 233 
| "FirstNegCumulative", 234 | "FlipCase", 235 | "Frac", 236 | "GCD17", 237 | "GeometricSequence", 238 | "Grader", 239 | "GreatestHIndex", 240 | "HalfSorted", 241 | "HalfTag", 242 | "HeronTriangle", 243 | "HexPrimes", 244 | "IdentifyZeroTrips", 245 | "IfProblem", 246 | "IfProblemWithAnd", 247 | "IncreasingViolation", 248 | "IntNegSquareRoot", 249 | "IntegerLog", 250 | "Intersperse", 251 | "InverseSuperFactorial", 252 | "InvestigateCrash", 253 | "IsEven", 254 | "LZW", 255 | "LargestDivisor", 256 | "LargestNegSmallestPos", 257 | "LargestPrimeDigitSum", 258 | "LargestPrimeFactor", 259 | "LargestStringNum", 260 | "LastLetters", 261 | "LexPath", 262 | "ListInc", 263 | "ListIndex", 264 | "ListIndex2", 265 | "ListLen", 266 | "LittleFermat", 267 | "LongEarlySum", 268 | "LongestMonotonicSubstring", 269 | "LongestMonotonicSubstringTricky", 270 | "LongestStr", 271 | "Mastermind", 272 | "MaxInt", 273 | "Median", 274 | "MinBigger", 275 | "MinSquaredDeviation", 276 | "MissingBananas", 277 | "Monotonic", 278 | "MoreQueens", 279 | "MostUnique", 280 | "Moving0s", 281 | "NarrowerList", 282 | "NearbyDuplicates", 283 | "NumPasses", 284 | "OddCase", 285 | "OddCollatz", 286 | "OddDegreePolynomialRoot", 287 | "OddProduct", 288 | "OneEnded", 289 | "OptimalBridges", 290 | "Oscillators", 291 | "OverlappingCount", 292 | "PackingHam", 293 | "PairZeroSum", 294 | "Palindrome", 295 | "PalindromeContaining", 296 | "ParenDepth", 297 | "ParenthesesPermutation", 298 | "ParityExchange", 299 | "ParseMusic", 300 | "PickNearNeighbors", 301 | "PlanetRange", 302 | "PlantedClique", 303 | "PositiveDigitSums", 304 | "PrimeFactorization", 305 | "PrimeFib", 306 | "PrimeIntervalIntersection", 307 | "PrimeSel", 308 | "PrimeWords", 309 | "PrimesUpTo", 310 | "ProductSigns", 311 | "PythagoreanTriples", 312 | "Quine", 313 | "Rescale", 314 | "RevQuine", 315 | "RollingMax", 316 | "RomanNumerals", 317 | "RotateSort", 318 | "RotateString", 319 | "SecondSmallestUnique", 320 | "SeparateParenGroups", 321 | 
"SevenElevenThirteen", 322 | "ShiftChars", 323 | "ShortIntegerPath", 324 | "ShortestDecDelta", 325 | "SimplifyProductFraction", 326 | "SlidingOne", 327 | "SlidingPuzzle", 328 | "SmallExponentBigSolution", 329 | "SmallestEven", 330 | "SortByDigitSum", 331 | "SortNumbers", 332 | "SortPlusPlus", 333 | "SortedOdds", 334 | "SpaceyRange", 335 | "Sssuubbstriiingg", 336 | "StonePiles", 337 | "StrAt", 338 | "StrLength", 339 | "StrMul", 340 | "StrMul2", 341 | "StrNegAt", 342 | "StrSetLen", 343 | "StrSplit", 344 | "StrSplitter", 345 | "StrangeSplit", 346 | "Study_10", 347 | "Study_12", 348 | "Study_14", 349 | "Study_16", 350 | "Study_18", 351 | "Study_2", 352 | "Study_20", 353 | "Study_22", 354 | "Study_24", 355 | "Study_26", 356 | "Study_28", 357 | "Study_30", 358 | "Study_4", 359 | "Study_6", 360 | "Study_8", 361 | "SubstitutionCypher", 362 | "Sudoku", 363 | "SumProduct", 364 | "ThreeCubes", 365 | "ThreeCycle", 366 | "ThreePrimes", 367 | "Threeples", 368 | "TicTacToeO", 369 | "TicTacToeX", 370 | "TowersOfHanoi", 371 | "TowersOfHanoiArbitrary", 372 | "TriangleArea", 373 | "Tribonacci", 374 | "TripleZeroSum", 375 | "TwoThirdsSorted", 376 | "UnevenFind", 377 | "UniqueSorted", 378 | "UnitsProduct", 379 | "UpDownSort", 380 | "UppercaseEven", 381 | "ValidBracketSubsequence", 382 | "VowelSandwich", 383 | "WeirdDecodeVowels", 384 | "WildSort", 385 | "Zarankiewicz", 386 | "ZeroSum", 387 | "ZobristCollision" 388 | ] 389 | } -------------------------------------------------------------------------------- /solvers/README.md: -------------------------------------------------------------------------------- 1 | # Solvers 2 | 3 | This folder contains two subfolders for recreating the benchmarks in the [paper](https://arxiv.org/abs/2106.05784). 4 | * [gpt3](/solvers/gpt3) The GPT-3 experiments. 5 | * [enumerative](/solvers/enumerative) The enumerative top-down search solvers. 6 | 7 | Each folder has a separate README explaining how to run the experiments. 
-------------------------------------------------------------------------------- /solvers/codex/README.md: -------------------------------------------------------------------------------- 1 | # Running Codex experiments 2 | 3 | These are instructions for re-running the Codex experiments. The results will be slightly different than those in 4 | the paper because the API is non-deterministic. 5 | 6 | The requirements can be installed with `pip3 install -r requirements.txt`. 7 | 8 | `run_codex_experiments.py` runs the Codex experiments and prints the results to stdout. Change the 9 | parameters in that file to run it on 397puzzles.json (v0.2), 138puzzles.json (v0.1, used in the 10 | first experiment), or 30puzzles.json (study), 11 | or to use the davinci-codex engine instead of cushman-codex. 12 | 13 | ## Installation and execution 14 | You will need an OpenAI Codex API access key, which you can sign up for [here](https://openai.com/join/). 15 | You will then need to set it as the `OPENAI_API_KEY` environment variable. If you want a suffix appended 16 | to the engine names, such as "-msft", set the environment variable `export OPEN_AI_ENGINE_SUFFIX=-msft`. 17 | We also recommend that you set the environment variable `export PYTHONHASHSEED=0` for determinism. 18 | 19 | The requirements can be installed with `pip3 install -r requirements.txt`. 20 | 21 | It was run with Python 3.6.9, sys.version = '3.6.9 (default, Jan 26 2021, 15:33:00) \n[GCC 8.4.0]', but should 22 | be compatible with later versions as well. 23 | 24 | Then you simply run 25 | `python run_codex_experiments.py`. It uses caching mechanisms, so the first run is 26 | quite slow and verbose because it queries the API. However, you can subsequently run it again and it will be 27 | much faster and just output the results. The caching makes it deterministic, so it should give the same 28 | exact results when re-run.
29 | 30 | -------------------------------------------------------------------------------- /solvers/codex/ezlog.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import inspect 4 | import io 5 | 6 | my_path = os.path.dirname(__file__) 7 | 8 | 9 | def color_str(obj, code="\033[0;36m"): 10 | return code + str(obj) + '\033[0m' 11 | 12 | 13 | _configured = False 14 | 15 | 16 | def configure_logging(stdio_level=logging.INFO, 17 | file_level=logging.DEBUG, 18 | filename=".easy.log", 19 | filepath=os.path.join(my_path, "logs")): 20 | os.makedirs(filepath, exist_ok=True) 21 | filename = os.path.join(filepath, filename) 22 | global _configured 23 | if _configured: 24 | warning("Re-configuring logging") 25 | stdio_handler = logging.StreamHandler() 26 | stdio_handler.setLevel(stdio_level) 27 | file_hanlder = logging.FileHandler(filename) 28 | file_hanlder.setLevel(file_level) 29 | 30 | logging.basicConfig( 31 | format="%(asctime)s - %(levelname)s - %(name)s - %(message).200s", 32 | datefmt="%m/%d/%Y %H:%M:%S", 33 | level=min(stdio_level, file_level), 34 | handlers=[stdio_handler, file_hanlder] 35 | ) 36 | 37 | _configured = True 38 | _get_or_create_logger().debug("Configured logging") 39 | 40 | 41 | _loggers = {} 42 | 43 | 44 | def _get_or_create_logger(): 45 | global _configured, _loggers 46 | if not _configured: 47 | configure_logging() 48 | try: 49 | for frame in inspect.stack(): 50 | name = inspect.getmodule(frame[0]).__name__ 51 | if name != __name__: 52 | break 53 | except: 54 | name = "_" 55 | if name not in _loggers: 56 | _loggers[name] = logging.getLogger(name) 57 | return _loggers[name] 58 | 59 | 60 | def print_to_string(*args, end="", **kwargs): 61 | with io.StringIO() as buf: 62 | print(*args, file=buf, end=end, **kwargs) 63 | return buf.getvalue() 64 | 65 | 66 | def debug(*args, **kwargs): 67 | _get_or_create_logger().debug(print_to_string(*args, **kwargs)) 68 | 69 | 70 | def info(*args, 
**kwargs): 71 | _get_or_create_logger().info(print_to_string(*args, **kwargs)) 72 | 73 | 74 | log = info 75 | 76 | 77 | def warning(*args, **kwargs): 78 | _get_or_create_logger().warning(print_to_string(*args, **kwargs)) 79 | 80 | 81 | warn = warning 82 | 83 | 84 | def error(*args, **kwargs): 85 | _get_or_create_logger().error(print_to_string(*args, **kwargs)) 86 | -------------------------------------------------------------------------------- /solvers/codex/lm_solve/__init__.py: -------------------------------------------------------------------------------- 1 | from lm_solve.run import * 2 | -------------------------------------------------------------------------------- /solvers/codex/lm_solve/gpt_lib.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import openai 4 | import ezlog 5 | import time 6 | import datetime 7 | 8 | assert 'OPENAI_API_KEY' in os.environ, "Need to set environment variable `OPENAI_API_KEY`" 9 | openai.api_key = os.environ['OPENAI_API_KEY'] 10 | OPEN_AI_ENGINE_SUFFIX = os.environ.get('OPEN_AI_ENGINE_SUFFIX', '') # add extension such as -msft to engine names 11 | 12 | _CACHE_PATH = os.path.join(os.path.dirname(__file__), "../.cache") 13 | _CACHE_ENCODING = "utf-8" 14 | 15 | 16 | # the cache file is just a list of (query params dictionary encoded as a string but without n, result list) 17 | # multiple queries with the same params (except for n) are merged into a single big list 18 | class Cache: 19 | def __init__(self, filename): 20 | self.filename = filename 21 | self._cache = None 22 | 23 | def _load_cache(self): 24 | """for lazy loading""" 25 | assert self._cache is None, "gpt cache already loaded" 26 | 27 | if not os.path.exists(_CACHE_PATH): 28 | ezlog.warn("Creating cache path") 29 | os.makedirs(_CACHE_PATH) 30 | 31 | self._cache = {} 32 | 33 | if os.path.exists(self.filename): 34 | time0 = time.perf_counter() 35 | with open(self.filename, "r", 
encoding=_CACHE_ENCODING) as f: 36 | for k, v in [eval(line) for line in f.readlines()]: 37 | if k not in self._cache: 38 | self._cache[k] = v 39 | else: 40 | self._cache[k].extend(v) 41 | ezlog.info(f"Loaded cache `{self.filename}` in {time.perf_counter() - time0:.1f}s") 42 | else: 43 | ezlog.warn("No gpt cache yet") 44 | 45 | def defrag(self): 46 | if self._cache is None: 47 | self._load_cache() 48 | 49 | # def helper(k): # remove max_batch 50 | # k2 = eval(k) 51 | # del k2["max_batch"] 52 | # return str(k2) 53 | 54 | if self._cache: 55 | with open(self.filename, "w", encoding=_CACHE_ENCODING) as f: 56 | # f.write("\n".join([str((helper(k), v)) for k, v in self._cache.items()]+[""])) 57 | f.write("\n".join([str((k, v)) for k, v in self._cache.items()]+[""])) 58 | ezlog.info("Defragged cache") 59 | else: 60 | ezlog.warn("No cache to defrag") 61 | 62 | 63 | def get(self, item): 64 | if self._cache is None: 65 | self._load_cache() 66 | 67 | return self._cache.get(item, []).copy() # no monkey business changing cache 68 | 69 | def extend(self, key, values): 70 | if self._cache is None: 71 | self._load_cache() 72 | 73 | v = self._cache.setdefault(key, []) 74 | v.extend(values) 75 | 76 | with open(self.filename, "a", encoding=_CACHE_ENCODING) as f: 77 | f.write(str((key, values)) + "\n") 78 | 79 | return v.copy() # no monkey business changing cache 80 | 81 | 82 | BATCH_SIZES = { 83 | "davinci": 32, 84 | "davinci-codex": 128, 85 | "cushman-codex": 128 86 | } 87 | 88 | CACHES = {cache: Cache(os.path.join(_CACHE_PATH, cache + ".cache")) for cache in BATCH_SIZES} 89 | 90 | def query(prompt, n=10, max_tokens=150, temp=1.0, stop=None, notes=None, cache_only=False, verbose=True, 91 | max_retries=10, engine="cushman-codex"): 92 | """Query gpt 93 | 94 | :param prompt: Up to 2048 tokens (about 3-4k chars) 95 | :param n: number of answers, None returns all cached answers 96 | :param max_tokens: 97 | :param temp: 0.9 seems to work well 98 | :param stop: string to stop at or '' if 
not to stop 99 | :param notes: notes you want to save or change in case you want to run the same query more than once! 100 | :return: list of answers and then the response items 101 | """ 102 | global BATCH_SIZES 103 | global CACHES 104 | cur_cache = CACHES[engine] 105 | max_batch = BATCH_SIZES[engine] 106 | engine += OPEN_AI_ENGINE_SUFFIX # add tail to engine name 107 | 108 | if temp == 0 and n > 1: 109 | ezlog.debug("Temp 0: no point in running more than one query") 110 | n = 1 111 | 112 | key = str(dict(prompt=prompt, max_tokens=max_tokens, temp=temp, stop=stop, rep=notes)) 113 | 114 | cached = cur_cache.get(key) 115 | 116 | if n is None: 117 | return cached 118 | 119 | if len(cached) >= n: 120 | return cached[:n] 121 | 122 | assert not cache_only, f'Entry not found in cache with prompt "{json.dumps(prompt)}"' 123 | if verbose: 124 | print("/" * 100) 125 | print(f"Querying GPT {engine} with prompt:") 126 | print(prompt) 127 | s = stop and stop.replace('\n', '\\n') 128 | print(f"/// n={n} ({n - len(cached)} new) max_tokens={max_tokens} temp={temp} max_batch={max_batch} stop={s}") 129 | print("/" * 100) 130 | 131 | time0 = time.perf_counter() 132 | 133 | new = [] 134 | n -= len(cached) 135 | 136 | while n > 0: 137 | m = min(n, max_batch) 138 | 139 | try_number = 0 140 | while True: 141 | try: 142 | res = openai.Completion.create( 143 | engine=engine, 144 | prompt=prompt, 145 | max_tokens=max_tokens, 146 | temperature=temp, 147 | n=m, 148 | stop=stop or None 149 | ) 150 | break 151 | except (openai.error.RateLimitError, openai.error.APIError): 152 | if try_number == max_retries: 153 | print("Rate limit error: Giving up!") 154 | raise 155 | sleep_secs = 10 * (2 ** try_number) 156 | try_number += 1 157 | print(f"Rate limit error #{try_number}: Sleeping for {sleep_secs} seconds...") 158 | time.sleep(sleep_secs) 159 | 160 | new += [c["text"] for c in res["choices"]] 161 | n -= m 162 | 163 | return cur_cache.extend(key, new) 164 | 
-------------------------------------------------------------------------------- /solvers/codex/lm_solve/scratch.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | def chars(filename, max_ord=300): 4 | with open(filename, "r", encoding="utf8") as f: 5 | counts = Counter(ord(c) for c in f.read()) 6 | print(counts.most_common()) 7 | print("max", max(counts)) 8 | missing = [i for i in range(max_ord) if i not in counts] 9 | print("missing", missing) 10 | 11 | chars(".cache/davinci-codex.cache") 12 | 13 | import time 14 | time0 = time.perf_counter() 15 | with open(".cache/davinci-codex.cache", "r", encoding="utf8") as f: 16 | lines = f.readlines() 17 | time1 = time.perf_counter() 18 | print("duration", time1 - time0) 19 | 20 | time0 = time.perf_counter() 21 | with open(".cache/davinci-codex.cache", "r", encoding="utf8") as f: 22 | lines = f.readlines() 23 | time1 = time.perf_counter() 24 | print("duration", time1 - time0) 25 | 26 | import json 27 | time0 = time.perf_counter() 28 | elines = [json.loads(l) for l in lines] 29 | time1 = time.perf_counter() 30 | print("duration", time1 - time0) 31 | 32 | len(lines), len(set(lines)) 33 | len(elines[0][0]) 34 | list(eval(elines[0][0])) 35 | -------------------------------------------------------------------------------- /solvers/codex/requirements.txt: -------------------------------------------------------------------------------- 1 | astor==0.8.1 2 | numpy==1.22.0 3 | openai==0.6.3 4 | tqdm==4.60.0 5 | transformers==4.30.0 6 | Pebble==4.6.1 7 | 8 | # we ran with Python version sys.version = '3.6.9 (default, Jan 26 2021, 15:33:00) \n[GCC 8.4.0]' 9 | # distro-info===0.18ubuntu0.18.04.1 10 | -------------------------------------------------------------------------------- /solvers/codex/results/results_397_cushman_codex_1k_full.json.gz: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/solvers/codex/results/results_397_cushman_codex_1k_full.json.gz -------------------------------------------------------------------------------- /solvers/codex/run_codex_experiments.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script runs the codex experiments. 3 | For GPT-3 experiments see run_gpt3_experiments.py in https://github.com/microsoft/PythonProgrammingPuzzles/tree/v0.1 4 | It uses caching mechanisms so that if run twice with the same parameters, it will give exactly the same 5 | results and will not query the API again and will not judge the resulting solutions again. Hence, the first 6 | time you run it, it will be slow, but you can subsequently run it again and it will be fast. It will run the 7 | experiment SEEDS times (see SEEDS below), with different seeds to get different results. 8 | """ 9 | 10 | import lm_solve 11 | import utils 12 | import math 13 | import numpy as np 14 | 15 | OUTPUT_FILENAME = "results_30_cushman_codex_32.json" 16 | SEEDS = 1 # number of times to run it 17 | PARAMS = dict( 18 | temp=0.9, 19 | timeout=1.0, # seconds to judge 20 | n=32, # number of attempts per puzzle, usually 1000, or set small for a fast run 21 | filename="30puzzles.json", # set to 397puzzles.json for a run on full v0.2 dataset 22 | cache_only=False, # change this to True if you want to run a 2nd time without risking hitting API 23 | engine="cushman-codex", # FAST-CODEX: "cushman-codex" CODEX: "davinci-codex" GPT3: "davinci" 24 | ) 25 | 26 | BOOTSTRAP_PARAMS = dict( 27 | temp=PARAMS["temp"], 28 | n=PARAMS["n"], 29 | timeout=PARAMS["timeout"], 30 | filename=PARAMS["filename"], 31 | cache_only=PARAMS["cache_only"], 32 | ppi=32, # puzzles per iteration 33 | engine=PARAMS["engine"], 34 | prefix="from typing import List\n\n", 35 | ) 36 | 37 | PREFIX = """from typing import List 38 | 39 | def f1(s: str): 40 | return
"Hello " + s == "Hello world" 41 | 42 | def g1(): 43 | return "world" 44 | 45 | assert f1(g1()) 46 | 47 | def f2(s: str): 48 | return "Hello " + s[::-1] == "Hello world" 49 | 50 | def g2(): 51 | return "world"[::-1] 52 | 53 | assert f2(g2()) 54 | 55 | def f3(x: List[int]): 56 | return len(x) == 2 and sum(x) == 3 57 | 58 | def g3(): 59 | return [1, 2] 60 | 61 | assert f3(g3()) 62 | 63 | def f4(s: List[str]): 64 | return len(set(s)) == 1000 and all((x.count("a") > x.count("b")) and ('b' in x) for x in s) 65 | 66 | def g4(): 67 | return ["a"*(i+2)+"b" for i in range(1000)] 68 | 69 | assert f4(g4()) 70 | 71 | def f5(n: int): 72 | return str(n * n).startswith("123456789") 73 | 74 | def g5(): 75 | return int(int("123456789" + "0"*9) ** 0.5) + 1 76 | 77 | assert f5(g5()) 78 | 79 | """ # trailing newlines important 80 | 81 | 82 | PREFIX_DOCSTR = '''from typing import List 83 | 84 | def f1(s: str): 85 | return "Hello " + s == "Hello world" 86 | 87 | def g1(): 88 | """Find a string that when concatenated onto 'Hello ' gives 'Hello world'.""" 89 | return "world" 90 | 91 | assert f1(g1()) 92 | 93 | def f2(s: str): 94 | return "Hello " + s[::-1] == "Hello world" 95 | 96 | def g2(): 97 | """Find a string that when reversed and concatenated onto 'Hello ' gives 'Hello world'.""" 98 | return "world"[::-1] 99 | 100 | assert f2(g2()) 101 | 102 | def f3(x: List[int]): 103 | return len(x) == 2 and sum(x) == 3 104 | 105 | def g3(): 106 | """Find a list of two integers whose sum is 3.""" 107 | return [1, 2] 108 | 109 | assert f3(g3()) 110 | 111 | def f4(s: List[str]): 112 | return len(set(s)) == 1000 and all((x.count("a") > x.count("b")) and ('b' in x) for x in s) 113 | 114 | def g4(): 115 | """Find a list of 1000 distinct strings which each have more 'a's than 'b's and at least one 'b'.""" 116 | return ["a"*(i+2)+"b" for i in range(1000)] 117 | 118 | assert f4(g4()) 119 | 120 | def f5(n: int): 121 | return str(n * n).startswith("123456789") 122 | 123 | def g5(): 124 | """Find an integer
whose perfect square begins with 123456789 in its decimal representation.""" 125 | return int(int("123456789" + "0"*9) ** 0.5) + 1 126 | 127 | assert f5(g5()) 128 | 129 | ''' # trailing newlines important 130 | 131 | 132 | def pass_at_k(k: int, successes: int, attempts: int): 133 | fail_prob = 1.0 134 | for i in range(k): 135 | fail_prob *= (attempts - successes)/attempts # gets right answer of 0 when attempts == successes 136 | attempts -= 1 137 | return 1.0 - fail_prob 138 | 139 | 140 | 141 | def run(seed=0): 142 | PARAMS_0 = {**PARAMS, "n": 0} 143 | BOOTSTRAP_PARAMS_0 = {**BOOTSTRAP_PARAMS, "n": 0} 144 | sols = [lm_solve.prompt_experiment(**PARAMS, experiment="short", prefix="", seed=seed), 145 | lm_solve.prompt_experiment(**PARAMS, experiment="med", prefix=PREFIX, remove_docstring=True, seed=seed), 146 | lm_solve.prompt_experiment(**PARAMS, experiment="long", prefix=PREFIX_DOCSTR, remove_docstring=False, seed=seed), 147 | ] 148 | num_puzzles = len(sols[0]["sat_sols"]) 149 | assert all(len(s["sat_sols"]) == num_puzzles for s in sols) 150 | 151 | n = PARAMS["n"] 152 | ks = [1] 153 | while ks[-1] < n: 154 | ks += [ks[-1] * i for i in [10]] # for i in [2, 5, 10]] 155 | ks = [k for k in ks if k <= n] 156 | if ks[-1] != n: 157 | ks.append(n) 158 | for s in sols: 159 | s["pass@k"] = [ 160 | ( 161 | k, 162 | np.mean([pass_at_k(k, s_s["n_sols"], n) for s_s in s["sat_sols"]]) 163 | ) 164 | for k in ks] 165 | 166 | bootstrap = lm_solve.bootstrap(**BOOTSTRAP_PARAMS, seed=seed) 167 | bootstrap["pass@k"] = [(k, np.mean([s_s["failures"] < k for s_s in bootstrap["sat_sols"]])) for k in ks] 168 | sols.append(bootstrap) 169 | 170 | print(f"run={seed} ALL DONE!\n\n") 171 | print(f"run={seed} RESULTS " + "=" * 50) 172 | print() 173 | 174 | for s in sols: 175 | print(s["experiment"], "prefix:", s["prefix"].replace("\n", "\\n")[:250]) 176 | print(" ", s["tot_solved"], "solved, pass@k", " ".join(f'{k} {p:.5f}' for k, p in s["pass@k"])) 177 | 178 | print(f"Pass at k [(k, {',
'.join(s['experiment'] for s in sols)}) ...]") 179 | print(list(zip([k for k, _p in sols[0]["pass@k"]], *[[p for _k, p in s["pass@k"]] for s in sols]))) 180 | 181 | return sols 182 | 183 | def main(): 184 | res = [s for seed in range(SEEDS) for s in run(seed)] 185 | 186 | if OUTPUT_FILENAME: 187 | FULL_FILENAME = OUTPUT_FILENAME.replace(".json", "_full.json.gz") 188 | utils.save_json(res, FULL_FILENAME) 189 | for s in res: 190 | if "sat_sols" in s: 191 | for t in s["sat_sols"] : 192 | if "sol_counts" in t: 193 | if t["sol_counts"]: 194 | t["shortest_sol"] = min([s for s, c in t["sol_counts"]], key=len) 195 | t["longest_sol"] = max([s for s, c in t["sol_counts"]], key=len) 196 | t["common_sol"] = max(t["sol_counts"], key=lambda z: z[1])[0] 197 | del t["sol_counts"] 198 | utils.save_json(res, OUTPUT_FILENAME) 199 | print(f"saved results to {OUTPUT_FILENAME} and {FULL_FILENAME}") 200 | 201 | 202 | if __name__ == "__main__": 203 | main() 204 | 205 | -------------------------------------------------------------------------------- /solvers/codex/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import operator 3 | import os 4 | import re 5 | import json 6 | import time 7 | import logging 8 | 9 | def get_lambda_arg_name(lam): 10 | assert lam.startswith("lambda ") 11 | return lam[len("lambda "):lam.index(":")].strip() 12 | 13 | 14 | def stringify(const): 15 | if type(const) is str: 16 | return json.dumps(const) 17 | return str(const) 18 | 19 | 20 | def color_str(obj, code="\033[0;36m"): 21 | return code + str(obj) + '\033[0m' 22 | 23 | 24 | def prod(iterable): # like sum but product 25 | return functools.reduce(operator.mul, iterable, 1) 26 | 27 | 28 | def flatten(it): 29 | return (e for a in it for e in (flatten(a) if isinstance(a, (tuple, list)) else (a,))) 30 | 31 | def save_json(obj, filename, make_dirs_if_necessary=False, indent=2, **kwargs): 32 | """Saves compressed file if filename ends with '.gz'"""
33 | import json 34 | if make_dirs_if_necessary: 35 | os.makedirs(os.path.dirname(filename), exist_ok=True) 36 | if filename.endswith(".gz"): 37 | import gzip 38 | with gzip.open(filename, "wt") as f: 39 | return json.dump(obj, f, indent=indent, **kwargs) 40 | with open(filename, "w", encoding="utf8") as f: 41 | return json.dump(obj, f, indent=indent, **kwargs) 42 | 43 | def load_json(filename): 44 | """Loads compressed file if filename ends with '.gz'""" 45 | import json 46 | if filename.endswith(".gz"): 47 | import gzip 48 | with gzip.open(filename, "rt") as f: 49 | return json.load(f) 50 | with open(filename, "r", encoding="utf8") as f: 51 | return json.load(f) 52 | 53 | 54 | def viz_py(py): 55 | import astor, ast 56 | print(astor.dump_tree(ast.parse(py))) 57 | 58 | 59 | def dedup(li): 60 | seen = set() 61 | return [x for x in li if x not in seen and not seen.add(x)] 62 | 63 | 64 | def test_puzzle(f, x): 65 | """Checks if x is of the correct type and makes f return True (literally True, not an integer or whatever) 66 | 67 | :param f: Puzzle 68 | :param x: candidate answer 69 | :return: 70 | """ 71 | answer_type = list(f.__annotations__.values())[0] 72 | if not type_check(x, answer_type): 73 | raise TypeError 74 | return f(x) is True 75 | 76 | 77 | 78 | def type_check(obj, typ): 79 | """ 80 | check if obj is of type `typ` where `typ` is a `typing` module type annotation, eg List[int] 81 | The way we do this to be compatible across versions is we first convert the type to a string. 
82 | """ 83 | 84 | type_str = str(typ).replace("typing.", "") 85 | if type_str.startswith("= 2: 113 | a = src.index('"""') 114 | b = src.index('"""', a+1) + 3 115 | if count == 1: 116 | h = helper(src[:a]) 117 | if h != src[:a]: 118 | return h + src[a:] 119 | return helper(src[:a]) + src[a:b] + helper(src[b:]) 120 | 121 | return helper(src) 122 | 123 | logger = None 124 | 125 | def timeit(method): 126 | global logger 127 | if logger is None: 128 | logger = logging.getLogger(__name__) 129 | def timed(*args, **kw): 130 | tick = time.time() 131 | result = method(*args, **kw) 132 | tock = time.time() 133 | logger.debug(f'{method.__name__}: {tock - tick:.3f}s') 134 | 135 | return result 136 | return timed 137 | 138 | -------------------------------------------------------------------------------- /solvers/enumerative/README.md: -------------------------------------------------------------------------------- 1 | # Enumerative puzzle solvers 2 | 3 | This folder contains the code for the enumerative models used in our Programming Puzzles paper. 4 | We used python 3.8.0 and the libraries in the `requirements.txt` file. 5 | 6 | In a linux machine with python3.8.0 installed, the following commands will set up the environment: 7 | ``` 8 | virtualenv -p /usr/bin/python3.8 env_solvers 9 | source env_solvers/bin/activate 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Uniform solver 14 | ``` 15 | bash run_uniform.sh 16 | ``` 17 | This will run the uniform solver for a maximum of 10k trials per puzzle. This is required before training the other parameterized solvers. 18 | 19 | To run the uniform with 1M trials per puzzle, simply change the `max_n_progs` argument in the bash script. 20 | 21 | ## Bigram random forest solver 22 | ``` 23 | bash run_bigram.sh 24 | ``` 25 | This will first train a parameterized model with self-bootsrapping (first iteration is based on the unifrom solutions). The last command will train a model without self-bootsrapping. 
26 | 27 | ## Transformers solver 28 | ``` 29 | bash download_pretrained_roberta.sh 30 | bash run_transformer.sh 31 | ``` 32 | The first script will download the RoBERTa-Base model that we trained on Python code. 33 | 34 | The second script will first train a parameterized model with self-bootsrapping (first iteration is based on the unifrom solutions). The last command will train a model without self-bootsrapping. 35 | -------------------------------------------------------------------------------- /solvers/enumerative/challenges/__init__.py: -------------------------------------------------------------------------------- 1 | from challenges.challenge import * 2 | from challenges.solutions import * 3 | 4 | def contains_node(root, x_node): 5 | return root is x_node or (hasattr(root, "children") and any(contains_node(k, x_node) for k in root.children)) 6 | -------------------------------------------------------------------------------- /solvers/enumerative/challenges/challenge.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import tython 3 | import logging 4 | from typing import List, Dict, Callable, Tuple, Generator, Set, Sequence 5 | from tython import Program, nt 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def extract_constants(prog) -> Dict: 11 | ''' 12 | Extract all constants from program. 
Does not (yet) allow copying of comprehensions, e.g., '[i*i for i in range(10)]' 13 | ''' 14 | 15 | from collections import defaultdict 16 | consts = defaultdict(list) 17 | 18 | def handle_args(args_node): 19 | 20 | if args_node.rule.name == 'cast:ARGS': 21 | handle_args(args_node.children[0]) 22 | else: 23 | if len(args_node.children) >= 3 and args_node.children[1].nt == nt.TYPE: 24 | annotation_node = args_node.children[1] 25 | t = nt.type2nt(eval(annotation_node.src())) 26 | consts[t].append(args_node.children[0]) 27 | if args_node.children and args_node.children[-1].nt in {nt.ARGS, nt.DEFAULT_ARGS}: 28 | handle_args(args_node.children[-1]) 29 | 30 | def helper(node): 31 | if node.rule.name == 'def': # it's a function 32 | name_node, args_node, body_node = node.children 33 | if name_node.src() == 'sat': 34 | handle_args(args_node.children[-1]) # skip first arg for `def sat` 35 | else: 36 | handle_args(args_node) 37 | helper(body_node) 38 | return False 39 | elif node.nt in {nt.NAME}: 40 | return False 41 | elif node.nt in {nt.STMT}: 42 | for c in node.children: 43 | helper(c) 44 | return False 45 | if node.rule.name not in {"int-const", "str-const"} and not all([helper(c) for c in node.children]): 46 | return False 47 | if node.nt.isa(nt.LIST, nt.SET, nt.DICT, nt.TUPLE, nt.RANGE, 48 | nt.INT, nt.FLOAT, nt.BOOL, nt.STR): 49 | consts[node.nt].append(node) 50 | return True 51 | 52 | if prog is not None: 53 | helper(prog.tree) 54 | 55 | return dict(consts) 56 | 57 | # 58 | # q = Program(""" 59 | # def sat(i: List[str], a=5): 60 | # return i==['5'] 61 | # """) 62 | # 63 | # extract_constants(q) 64 | # 65 | # 66 | # %% 67 | class Solution(): 68 | def __init__(self, string=None, prog=None, likelihood=None, time=None, count=None): 69 | self.string = string 70 | self.prog = prog 71 | self.likelihood = likelihood 72 | self.time = time 73 | self.count = count 74 | 75 | 76 | class SolverSolution(Solution): 77 | def __init__(self, string=None, prog=None, likelihood=None, 
time=None, count=None): 78 | super().__init__(string=string, prog=prog, likelihood=likelihood) 79 | self.time = time 80 | self.count = count 81 | 82 | 83 | def get_arg_type_str(sat_str): 84 | assert sat_str.startswith("def sat(") and ":" in sat_str 85 | depth = 0 86 | for i, c in enumerate(sat_str): 87 | if c == '[': 88 | depth += 1 89 | elif c == ']': 90 | depth -= 1 91 | elif c in ")," and depth == 0: 92 | return sat_str[sat_str.index(":") + 1:i].lstrip() 93 | assert False 94 | 95 | 96 | class Challenge(): 97 | def __init__(self, challenge_config, max_ticks=100000000): 98 | self.name = challenge_config["name"] 99 | self.f_str = challenge_config["sat"] 100 | self.type_str = get_arg_type_str(challenge_config["sat"]) 101 | self.type = eval(self.type_str) 102 | self.gold_solutions = [] 103 | self.solver_solutions = [] 104 | for sol in challenge_config["sols"]: 105 | self.gold_solutions.append(Solution(string=sol)) 106 | if "sol_tries" in challenge_config: 107 | for i, x in enumerate(challenge_config["sol_tries"]): 108 | self.gold_solutions[i].count = x 109 | 110 | if "sol_time" in challenge_config: 111 | for i, x in enumerate(challenge_config["sol_time"]): 112 | self.gold_solutions[i].time = x 113 | 114 | self.solution_strs = challenge_config["sols"] 115 | self.max_ticks = max_ticks 116 | 117 | self._parse_challenge() 118 | 119 | def _parse_challenge(self): 120 | ''' 121 | Converts the challenge string to a tython program. 
122 | ''' 123 | self.sol_kind = tython.nt.type2nt(self.type) 124 | self.prog = None 125 | self.f = None 126 | try: 127 | self.prog = tython.Program( 128 | self.f_str) 129 | self.f = self.prog.run(max_ticks=self.max_ticks) 130 | except Program.EvalException as e: 131 | logger.warning(f"Exception evaluating {self.name} '{self.f_str}': {e}") 132 | except Exception as e: 133 | logger.warning(f"Exception parsing {self.name} '{self.f_str}': {e}") 134 | -------------------------------------------------------------------------------- /solvers/enumerative/challenges/solutions.py: -------------------------------------------------------------------------------- 1 | from typing import List, Set, Dict, Callable, Tuple 2 | import logging 3 | from challenges import extract_constants 4 | 5 | from tython import Program, TastNode, _RULES_BY_KIND, RULES, Rule, str2name 6 | from tython.rules import DEF_RULE 7 | 8 | logger = logging.getLogger(__name__) 9 | logger.setLevel(logging.INFO) 10 | 11 | 12 | def generatable_answer(q: Program, a: Program): 13 | def get_src(node): 14 | return Program(node).src(safe=False, simplify=False) 15 | 16 | consts = extract_constants(q) 17 | const_srcs = {k: [get_src(c) for c in consts[k]] for k in consts} 18 | 19 | def helper(anode): 20 | if anode.rule.name == "COPY": 21 | return get_src(anode.children[0]) in const_srcs[anode.rule.nt] 22 | return anode.rule.name != "literal" and all(helper(n) for n in anode.children) 23 | 24 | return helper(a.tree) 25 | 26 | 27 | def verify_solutions(challenges): 28 | ''' 29 | Verify all provided solutions to the given challenges and store the parsed solution 30 | program in the challenge object. 
31 | ''' 32 | successes = 0 33 | all_correct = True 34 | for ch in challenges: 35 | verified_sols = [] 36 | f = ch.prog.run(max_ticks=ch.max_ticks)['sat'] 37 | args_node = ch.prog.tree.children[0].children[1].children[-1] 38 | for sol_p in ch.gold_solutions: 39 | s = sol_p.string 40 | if s == '': 41 | continue 42 | # Verify solution is correct. 43 | if True: 44 | assert s.startswith('def ') 45 | try: 46 | sol_prog = Program(s) 47 | except Exception as e: 48 | logger.error(f"Exception parsing solution for {ch.name} '{s}': {e}") 49 | continue 50 | 51 | # Inject the value assignments for the variables to the call of the sol func. 52 | p_body = sol_prog.tree.children[0].children[2] 53 | sol_prog.tree.children[0].children[1].children = args_node.children 54 | sol_prog = Program(TastNode(DEF_RULE, [str2name("sol"), args_node, p_body])) 55 | 56 | a_safe = sol_prog.src(simplify=False) 57 | x = sol_prog.run(max_ticks=ch.max_ticks)["sol"]() 58 | ch.prog.reset_clock() 59 | 60 | v = f(x) 61 | assert isinstance(v, bool) 62 | 63 | if not generatable_answer(ch.prog, sol_prog): 64 | logger.error(f'Challenge "{ch.name}" cannot be used to automatically generate solution "{s}"') 65 | 66 | # TODO 67 | # if type(y) != ch.type: 68 | # print(f'Challenge "{ch.name}" has wrong solution type: "{type(y)}"') 69 | # all_correct = False 70 | if v is not True: # checks both False and None 71 | logger.error(f'Challenge "{ch.name}" not satisfied by solution "{s}"') 72 | else: 73 | sol_p.prog = sol_prog 74 | verified_sols.append(sol_p) 75 | successes += 1 76 | 77 | ch.gold_solutions = verified_sols 78 | 79 | logger.info( 80 | f"Tython confirmed {successes:,} solutions to {len(challenges)} challenges." 81 | ) 82 | return 83 | -------------------------------------------------------------------------------- /solvers/enumerative/download_pretrained_roberta.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | # Linux commands to download our Roberta model pretrained on Python code. 4 | # Newer vesrions of huggingface transformers don't require this but we need to adjust the rest of the code for them. 5 | 6 | set -ex 7 | 8 | mkdir tals 9 | mkdir tals/roberta_python 10 | 11 | cd tals/roberta_python 12 | 13 | wget https://huggingface.co/tals/roberta_python/resolve/main/config.json 14 | wget https://huggingface.co/tals/roberta_python/resolve/main/merges.txt 15 | wget https://huggingface.co/tals/roberta_python/resolve/main/pytorch_model.bin 16 | wget https://huggingface.co/tals/roberta_python/resolve/main/special_tokens_map.json 17 | wget https://huggingface.co/tals/roberta_python/resolve/main/tokenizer_config.json 18 | wget https://huggingface.co/tals/roberta_python/resolve/main/training_args.bin 19 | wget https://huggingface.co/tals/roberta_python/resolve/main/vocab.json 20 | 21 | cd ../.. 22 | -------------------------------------------------------------------------------- /solvers/enumerative/filter_outputs.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import os 4 | 5 | 6 | inp_file = sys.argv[1] 7 | if len(sys.argv) > 2: 8 | shift = int(sys.argv[2]) 9 | else: 10 | shift = 0 11 | out_file = os.path.splitext(inp_file)[0] 12 | 13 | 14 | with open(inp_file, 'r') as f: 15 | data = json.load(f) 16 | 17 | thresholds = [100, 1000, 10000, 100000, 1000000] 18 | for t in thresholds: 19 | t = t 20 | out = [] 21 | suc = 0 22 | for p in data: 23 | #if not p["name"].startswith("Study"): 24 | # continue 25 | if p["sols"][-1] != "" and p["sol_tries"][-1] + shift <= t: 26 | out.append(p) 27 | suc += 1 28 | else: 29 | out.append(dict(name=p["name"], sat=p["sat"], sols=[])) 30 | 31 | print(f"t={t}: solutions: {suc}/ {len(out)}") 32 | 33 | with open(out_file + f"_{t}.json", "w") as fw: 34 | json.dump(out, fw, indent=4) 35 | 
-------------------------------------------------------------------------------- /solvers/enumerative/models/__init__.py: -------------------------------------------------------------------------------- 1 | MODEL_REGISTRY = {} 2 | def RegisterModel(model_name): 3 | def decorator(m): 4 | MODEL_REGISTRY[model_name] = m 5 | return m 6 | 7 | return decorator 8 | 9 | from models.uniform import * 10 | #from models.bigram import * 11 | #from models.ml_bow_unigram import * 12 | from models.ml_bow_bigram import * 13 | -------------------------------------------------------------------------------- /solvers/enumerative/models/ml_bow_bigram.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Callable, Tuple, Generator, Set, Sequence 2 | from collections import defaultdict 3 | import numpy as np 4 | from copy import deepcopy 5 | import sklearn 6 | import sklearn.linear_model 7 | import sklearn.ensemble 8 | import sklearn.neighbors 9 | import sklearn.tree 10 | from scipy import sparse 11 | from tqdm import tqdm 12 | import time 13 | import logging 14 | 15 | from tython import Program, TastNode, _RULES_BY_KIND, Rule, RULES, nt 16 | from models.model import CandidateGenerator, reachable_rules_by_kind 17 | from models import RegisterModel 18 | from challenges import extract_constants 19 | 20 | logger = logging.getLogger(__name__) 21 | logger.setLevel(logging.INFO) 22 | 23 | 24 | @RegisterModel("ml_bow_bigram") 25 | class BOWCondParentLRModel(CandidateGenerator): 26 | ''' 27 | learn rule weights with logistic regression. 28 | Features are: BOW for question and parent rule + child ind 29 | ''' 30 | 31 | def __init__(self, ml_model='lr') -> None: 32 | ''' 33 | Learn model per kind. 34 | featurization: feature extraction method. 
35 | ''' 36 | super().__init__() 37 | self.vocabs = _RULES_BY_KIND 38 | self.ml_model = ml_model 39 | 40 | if ml_model == 'lr': 41 | self.sk_model = sklearn.linear_model.LogisticRegression(penalty='l2') 42 | elif ml_model == 'rf': 43 | self.sk_model = sklearn.ensemble.RandomForestClassifier() 44 | elif ml_model == 'knn': 45 | self.sk_model = sklearn.neighbors.KNeighborsClassifier() 46 | elif ml_model == 'dt': 47 | self.sk_model = sklearn.tree.DecisionTreeClassifier() 48 | else: 49 | raise Exception("Unknown ml model %s", ml_model) 50 | 51 | def featurize_question(self, prog: Program): 52 | ''' 53 | Convert program to features with bow (bag of words/ rules) 54 | ''' 55 | 56 | qf = np.zeros(len(RULES), dtype=int) 57 | 58 | # Get all rules in prog. 59 | 60 | queue = [prog.tree] 61 | while queue: 62 | node = queue.pop() 63 | qf[node.rule.index] += 1 64 | for child in node.children: 65 | queue.append(child) 66 | 67 | return sparse.csr_matrix(qf) 68 | 69 | def add_parent_features(self, qf: np.array, parent_rule: int, child_num: int): 70 | # Assume max 10 children 71 | assert child_num < 10 72 | parent_feature = np.zeros(len(RULES) + 1 + 10, dtype=int) 73 | parent_feature[parent_rule] = 1 74 | parent_feature[child_num + len(RULES) + 1] = 1 75 | 76 | # features = np.concatenate((qf, parent_feature)) 77 | parent_feature = sparse.csr_matrix(parent_feature) 78 | features = sparse.hstack([qf, parent_feature]) 79 | return features 80 | 81 | def traverse_ans_tree(self, prog: Program): 82 | ''' 83 | Collect all rules with their features. 
84 | ''' 85 | # Initiate with root, parent_ind=len(RULES) 86 | rules = [[prog.tree.rule.index, len(RULES), 0]] 87 | queue = [prog.tree] 88 | while queue: 89 | node = queue.pop() 90 | parent_ind = node.rule.index 91 | for num_child, child in enumerate(node.children): 92 | rules.append([child.rule.index, parent_ind, num_child]) 93 | queue.append(child) 94 | 95 | return rules 96 | 97 | def learn(self, QAs: List[Tuple[Program, Program]]) -> None: 98 | ''' 99 | Optimize the ML models with the given examples. 100 | QAs: list of QAs for building vocab. 101 | ''' 102 | # xs = None 103 | xs = [] 104 | targets = [] 105 | logger.info("Creating features from puzzle solution pairs") 106 | for q, a in tqdm(QAs): 107 | qf = self.featurize_question(q) 108 | for (target_rule, parent_rule, child_num) in self.traverse_ans_tree(a): 109 | input_features = self.add_parent_features(qf, parent_rule, child_num) 110 | 111 | # if xs is None: 112 | # xs = sparse.csr_matrix(input_features) 113 | # else: 114 | # xs = sparse.vstack([xs, input_features]) 115 | xs.append(sparse.csr_matrix(input_features)) 116 | targets.append(target_rule) 117 | 118 | logger.info(f"Collected {len(targets)} rules") 119 | xs = sparse.vstack(xs) 120 | self.sk_model.fit(xs, targets) 121 | 122 | def get_candidates(self, q: Program) -> Dict: 123 | ''' 124 | Predict candidates for each question 125 | ''' 126 | st_time = time.time() 127 | qf = self.featurize_question(q) 128 | 129 | consts = extract_constants(q) 130 | 131 | # Our question is always a function that returns a bool. 132 | sol_type_annotation = q.tree.children[0].children[1].children[1] 133 | sol_kind = nt.type2nt(eval(sol_type_annotation.src())) 134 | 135 | rules_by_kind = _RULES_BY_KIND # zzz reachable_rules_by_kind(sol_kind, consts) 136 | rules_by_kind_sets = {k: {r.index for r in rules} for k, rules in rules_by_kind.items()} 137 | 138 | ans = {} 139 | 140 | times1 = [] 141 | times2 = [] 142 | for parent_kind in rules_by_kind: 143 | # None for root. 
144 | for r in rules_by_kind[parent_kind] + ([None] if parent_kind == sol_kind else []): 145 | # Don't include parents that we won't reach anyway (a.k.a massive pruning). 146 | if r is not None and r.index not in self.sk_model.classes_: 147 | continue 148 | if r and r.name == "COPY": 149 | assert len(r.kids) == 1 150 | if r.nt in consts: 151 | p = 1/len(consts[r.nt]) 152 | relevant_consts = [(p, n) for n in consts[r.nt]] 153 | else: 154 | relevant_consts = [] 155 | ans[r] = [(relevant_consts, [])] 156 | continue 157 | 158 | ans[r] = [] 159 | for child_num, kind in enumerate([sol_kind] if r is None else r.kids): 160 | tik2 = time.time() 161 | class_mask = np.array([c in rules_by_kind_sets[kind] for c in self.sk_model.classes_]) 162 | assert child_num < 9 163 | # child_num = min(child_num, 9) 164 | if r is None: 165 | parent_rule = len(RULES) 166 | else: 167 | parent_rule = r.index 168 | 169 | input_features = self.add_parent_features(qf, parent_rule, child_num) 170 | tik1 = time.time() 171 | # TODO: create batches to reduce run time. 
172 | rule_probs = self.sk_model.predict_proba(input_features.reshape(1, -1))[0] 173 | tok1 = time.time() - tik1 174 | times1.append(tok1) 175 | 176 | # Mask out predictions to rules of different kind 177 | sum_probs = np.sum(rule_probs * class_mask) 178 | if sum_probs > 0: 179 | rule_probs = rule_probs * class_mask / sum_probs 180 | rule_probs = [(rule_probs[i], RULES[self.sk_model.classes_[i]]) for i in 181 | rule_probs.nonzero()[0]] 182 | assert all(r.nt == kind for _, r in rule_probs) 183 | else: 184 | rule_probs = [] 185 | 186 | 187 | 188 | ans[r].append(([], rule_probs)) 189 | tok2 = time.time() - tik2 190 | times2.append(tok2) 191 | 192 | # print(np.mean(times1)) 193 | # print(np.mean(times2)) 194 | end_time = time.time() - st_time 195 | logger.debug("Get candidates took {:.2f}s".format(end_time)) 196 | 197 | return ans 198 | 199 | def get_likelihood(self, q: TastNode, ans: Program) -> float: 200 | ''' 201 | Get the probaility of the given answer program to the question. 202 | ''' 203 | ans_rules = self.traverse_ans_tree(ans) 204 | qf = self.featurize_question(q) 205 | rule_probs = [] 206 | for r in ans_rules: 207 | target_rule_ind = r[0] 208 | parent_rule, child_num = r[1:] 209 | input_features = self.add_parent_features(qf, parent_rule, child_num) 210 | out_probs = self.sk_model.predict_proba(input_features.reshape(1, -1))[0] 211 | target_prob = out_probs[np.where(self.sk_model.classes_ == target_rule_ind)].item() \ 212 | if target_rule_ind in self.sk_model.classes_ else 0 213 | rule_probs.append(target_prob) 214 | 215 | return np.prod(rule_probs) 216 | -------------------------------------------------------------------------------- /solvers/enumerative/models/model.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Callable, Tuple, Generator, Set, Sequence 2 | 3 | from tython import Program, TastNode, Rule, _RULES_BY_KIND 4 | 5 | def reachable_rules_by_kind(sol_kind, terminal_kinds) -> 
Set: 6 | ''' 7 | Finds all kinds that could be in a tree with sol_kind as the root 8 | and terminal kinds as possible kinds for constants. 9 | ''' 10 | # some kinds are unusuable because there is no way to generate a complete tree 11 | # starting with them. 12 | completable = set(terminal_kinds) # force copy 13 | signatures = {k: {r.kids for r in _RULES_BY_KIND[k]} for k in _RULES_BY_KIND} 14 | uncompletable = set(signatures) - completable 15 | while True: 16 | completable.update({k for k in uncompletable 17 | if any(all(k in completable for k in s) for s in signatures[k])}) 18 | if not completable.intersection(uncompletable): 19 | break 20 | uncompletable -= completable 21 | for k in signatures: 22 | signatures[k] = {s for s in signatures[k] if all(k2 in completable for k2 in s)} 23 | signatures = {k: s for k, s in signatures.items() if s} 24 | if sol_kind not in completable: 25 | return {} 26 | 27 | good_kinds = set() 28 | kind_queue = {sol_kind} 29 | 30 | while kind_queue: 31 | parent_kind = kind_queue.pop() 32 | good_kinds.add(parent_kind) 33 | kind_queue.update({k for sig in signatures[parent_kind] 34 | for k in sig if k not in good_kinds}) 35 | 36 | good_rules_by_kind = {k: [r for r in _RULES_BY_KIND[k] if all(k2 in good_kinds for k2 in r.kids)] 37 | for k in _RULES_BY_KIND if k in good_kinds} 38 | 39 | return {k: v for k, v in good_rules_by_kind.items() if v} 40 | 41 | class CandidateGenerator: 42 | ''' 43 | A model that generate candidate solutions for challenges. 44 | ''' 45 | 46 | def __init__(self) -> None: 47 | pass 48 | 49 | def learn(self, QAs): 50 | return 51 | 52 | def get_candidates(self, q: TastNode) -> Dict[Rule, 53 | List[Tuple[List[Tuple[float, TastNode]], List[Tuple[float, Rule]]]]]: 54 | ''' 55 | Get solution candidates for a question q. 56 | TODO: program instead of TastNode? 
57 | ''' 58 | pass 59 | 60 | -------------------------------------------------------------------------------- /solvers/enumerative/models/transformers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/solvers/enumerative/models/transformers/__init__.py -------------------------------------------------------------------------------- /solvers/enumerative/models/transformers/generate_rule_embeddings.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import logging 4 | import csv 5 | import os 6 | import random 7 | import json 8 | import string 9 | 10 | import numpy as np 11 | import torch 12 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset 13 | from tqdm import tqdm, trange 14 | 15 | from transformers import ( 16 | WEIGHTS_NAME, 17 | AutoConfig, 18 | AutoModel, 19 | AutoTokenizer, 20 | ) 21 | 22 | from dataset_processor import PuzzleProcessor, InputExample, InputFeatures, convert_examples_to_features 23 | 24 | from tython import Program, TastNode, _RULES_BY_KIND, Rule, RULES 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | def set_seed(args): 29 | random.seed(args.seed) 30 | np.random.seed(args.seed) 31 | torch.manual_seed(args.seed) 32 | if args.n_gpu > 0: 33 | torch.cuda.manual_seed_all(args.seed) 34 | 35 | def save_embeds(file_path, embeds, vocab_file): 36 | n_tokens, emb_dim = embeds.shape 37 | with open(file_path, 'w') as f: 38 | writer = csv.writer(f, delimiter='\t') 39 | for i in range(n_tokens): 40 | writer.writerow(list(embeds[i,:])) 41 | 42 | with open(vocab_file, 'w') as f: 43 | writer = csv.writer(f, delimiter='\t') 44 | for i in range(n_tokens): 45 | desc = generate_rule_string(RULES[i]) 46 | writer.writerow([desc]) 47 | 48 | return 49 | 50 | 51 | # First attempt to get random 
values for rules. 52 | 53 | #thing_list = ['x', 'y', 'z', 'a', 'b', 'i', 'j', 'k'] 54 | #bool_list = ['True', 'False'] 55 | # 56 | # 57 | #kind_gens = { 58 | # 'int': lambda: random.choice([str(random.randint(-1000, 1000)), 'i', 'j']) , 59 | # 'thing': lambda: random.choice(thing_list), 60 | # 'float' : lambda: round(random.random() * random.randint(-100,100), 3), 61 | # 'bool': lambda: random.choice(bool_list), 62 | # 'range': lambda: random.choice(bool_list), 63 | # 'str': lambda: '"{}"'.format(''.join([random.choice(string.ascii_uppercase + string.ascii_lowercase +string.digits) for _ in range(random.randint(1,10))])), 64 | # 'range': lambda: random.choice(['range({})'.format(random.randint(-100,100)), 65 | # 'range({},{})'.format(random.randint(-100,100), random.randint(-100,100)), 66 | # 'range({},{},{})'.format(random.randint(-100,100), random.randint(-100,100), random.randint(-10,10))]), 67 | # 'none': lambda: 'None', 68 | #} 69 | # 70 | # 71 | #special_gens = { 72 | # 'Callable': lambda: 'func({})', 73 | # 'Tuple': lambda: '({})', 74 | # 'Set': lambda: '{{}}', 75 | #} 76 | # 77 | # 78 | #def gen_random_kind(kind): 79 | # ''' 80 | # Given a string or list of kind, returns a string with an instance of that kind 81 | # ''' 82 | # if isinstance(kind, (tuple, list)): 83 | # var_instances = [] 84 | # for k in kind[1:]: 85 | # var_instances.append(gen_random_kind(k)) 86 | # 87 | # base_str = special_gens[kind[0]]() 88 | # combine_str = base_str.format(', '.join([str(i) for i in var_instances])) 89 | # return combine_str 90 | # else: 91 | # return kind_gens[kind]() 92 | # 93 | # 94 | #def open_kind(kind): 95 | # ''' 96 | # return list of the names of kinds that create it. 
97 | # ''' 98 | # if isinstance(kind, (tuple, list)): 99 | # return [open_kind(k) for k in kind] 100 | # else: 101 | # return kind.__name__ if hasattr(kind, "__name__") else kind._name if hasattr(kind, "_name") else str(kind) 102 | # 103 | # 104 | #def generate_rule_strings(rule, max_num = 100): 105 | # children_kinds = [] 106 | # for kind in rule.children: 107 | # children_kinds.append(open_kind(kind)) 108 | # 109 | # rule_strings = [] 110 | # for _ in range(max_num): 111 | # children_vals = [] 112 | # for k in children_kinds: 113 | # children_vals.append(gen_random_kind(k)) 114 | # 115 | # desc = rule.to_python(*children_vals) 116 | # if desc not in rule_strings: 117 | # rule_strings.append(desc) 118 | # 119 | # if 'Callable' in str(rule): 120 | # return rule_strings 121 | 122 | def generate_rule_string(rule): 123 | return rule.var_repr() 124 | 125 | def main(): 126 | parser = argparse.ArgumentParser() 127 | 128 | # Required parameters 129 | parser.add_argument( 130 | "--model_name_or_path", 131 | default=None, 132 | type=str, 133 | required=True, 134 | help="Path to pretrained model or model identifier from huggingface.co/models", 135 | ) 136 | parser.add_argument( 137 | "--output_dir", 138 | default=None, 139 | type=str, 140 | required=True, 141 | help="The output directory where the model predictions and checkpoints will be written.", 142 | ) 143 | 144 | # Other parameters 145 | parser.add_argument( 146 | "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" 147 | ) 148 | parser.add_argument( 149 | "--tokenizer_name", 150 | default="", 151 | type=str, 152 | help="Pretrained tokenizer name or path if not the same as model_name", 153 | ) 154 | parser.add_argument( 155 | "--cache_dir", 156 | default=None, 157 | type=str, 158 | help="Where do you want to store the pre-trained models downloaded from s3", 159 | ) 160 | parser.add_argument( 161 | "--max_seq_length", 162 | default=128, 163 | type=int, 164 | 
help="The maximum total input sequence length after tokenization. Sequences longer " 165 | "than this will be truncated, sequences shorter will be padded.", 166 | ) 167 | parser.add_argument( 168 | "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model." 169 | ) 170 | 171 | parser.add_argument( 172 | "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation." 173 | ) 174 | 175 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") 176 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 177 | 178 | args = parser.parse_args() 179 | 180 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 181 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() 182 | args.device = device 183 | 184 | # Setup logging 185 | logging.basicConfig( 186 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 187 | datefmt="%m/%d/%Y %H:%M:%S", 188 | level=logging.INFO, 189 | ) 190 | logger.warning( 191 | "Process device: %s, n_gpu: %s", 192 | device, 193 | args.n_gpu, 194 | ) 195 | 196 | # Set seed 197 | set_seed(args) 198 | 199 | # Load pretrained model and tokenizer 200 | config = AutoConfig.from_pretrained( 201 | args.config_name if args.config_name else args.model_name_or_path, 202 | cache_dir=args.cache_dir, 203 | ) 204 | args.model_type = config.model_type 205 | tokenizer = AutoTokenizer.from_pretrained( 206 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 207 | do_lower_case=args.do_lower_case, 208 | cache_dir=args.cache_dir, 209 | ) 210 | model = AutoModel.from_pretrained( 211 | args.model_name_or_path, 212 | from_tf=bool(".ckpt" in args.model_name_or_path), 213 | config=config, 214 | cache_dir=args.cache_dir, 215 | ) 216 | 217 | model.to(args.device) 218 | logger.info("Generating embeddings for %s rules", len(RULES)) 219 | 220 | examples = [] 221 
| for i, rule in enumerate(RULES): 222 | desc = generate_rule_string(rule) 223 | examples.append( 224 | InputExample(guid=i, 225 | puzzle_str=desc, 226 | parent_ind=None, 227 | child_num=None, 228 | label=None)) 229 | 230 | features = convert_examples_to_features(examples, tokenizer, max_length=args.max_seq_length) 231 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 232 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.int) 233 | dataset = TensorDataset(all_input_ids, all_attention_mask) 234 | 235 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 236 | eval_sampler = SequentialSampler(dataset) 237 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 238 | 239 | hidden_size = model.config.hidden_size 240 | # multi-gpu eval 241 | if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): 242 | model = torch.nn.DataParallel(model) 243 | 244 | logger.info("***** Running evaluation *****") 245 | logger.info(" Num examples = %d", len(dataset)) 246 | logger.info(" Batch size = %d", args.eval_batch_size) 247 | rule_embeddings = np.zeros((len(RULES), hidden_size)) 248 | for b, batch in tqdm(enumerate(eval_dataloader), desc="Evaluating", total=len(dataset)/args.eval_batch_size): 249 | model.eval() 250 | batch = tuple(t.to(args.device) for t in batch) 251 | 252 | with torch.no_grad(): 253 | mask = batch[1] 254 | inputs = {"input_ids": batch[0], "attention_mask": mask} 255 | outputs, _ = model(**inputs) 256 | 257 | # Get Average of token representations 258 | mask_ex = mask.unsqueeze(-1).expand_as(outputs) 259 | y = (outputs * mask_ex).sum(1) 260 | y = y / mask_ex.sum(1) 261 | 262 | y = y.detach().cpu().numpy() 263 | for i, emb in enumerate(y): 264 | rule_embeddings[b * args.eval_batch_size + i, :] = emb 265 | 266 | os.makedirs(args.output_dir, exist_ok=True) 267 | out_file = os.path.join(args.output_dir, "rule_embeds.tsv") 268 | 
vocab_file = os.path.join(args.output_dir, "rule_embeds_vocab.tsv") 269 | save_embeds(out_file, rule_embeddings, vocab_file) 270 | logger.info("Saved embeddings to %s", out_file) 271 | 272 | if __name__ == "__main__": 273 | main() 274 | -------------------------------------------------------------------------------- /solvers/enumerative/models/transformers/learn_tokenizer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | from os.path import join 4 | 5 | from tokenizers import ByteLevelBPETokenizer 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--files", 10 | metavar="path", 11 | type=str, 12 | default="python_data/train.txt", 13 | help="The files to use as training; accept '**/*.txt' type of patterns \ 14 | if enclosed in quotes", 15 | ) 16 | parser.add_argument( 17 | "--out", 18 | default="trained_models/roberta_python_tokenizer", 19 | type=str, 20 | help="Path to the output directory, where the files will be saved", 21 | ) 22 | parser.add_argument( 23 | "--name", default="bpe-bytelevel", type=str, help="The name of the output vocab files" 24 | ) 25 | args = parser.parse_args() 26 | 27 | files = glob.glob(args.files) 28 | if not files: 29 | print(f"File does not exist: {args.files}") 30 | exit(1) 31 | 32 | 33 | # Initialize an empty tokenizer 34 | tokenizer = ByteLevelBPETokenizer(add_prefix_space=True) 35 | 36 | # And then train 37 | tokenizer.train( 38 | files, 39 | vocab_size=50265, 40 | min_frequency=2, 41 | show_progress=True, 42 | special_tokens=["", "", "", "", ""], 43 | ) 44 | 45 | # Save the files 46 | os.makedirs(args.out, exist_ok=True) 47 | tokenizer.save_model(args.out, args.name) 48 | 49 | # Restoring model from learned vocab/merges 50 | tokenizer = ByteLevelBPETokenizer( 51 | join(args.out, "{}-vocab.json".format(args.name)), 52 | join(args.out, "{}-merges.txt".format(args.name)), 53 | add_prefix_space=True, 54 | ) 55 | 56 | # Test encoding 57 | 
print(tokenizer.encode("Training ByteLevel BPE is very easy").tokens) 58 | -------------------------------------------------------------------------------- /solvers/enumerative/models/transformers/preprocess_pretraining_data.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import gzip 3 | import json 4 | import random 5 | from tqdm import tqdm 6 | 7 | for split in ['train', 'valid', 'test']: 8 | data_path = 'python_data/{}/*.jsonl.gz'.format(split) 9 | #output_file = 'python_data/{}.jsonl'.format(split) 10 | output_file = 'python_data/{}.txt'.format(split) 11 | 12 | if False: 13 | of = open(output_file, 'w') 14 | 15 | for file_name in glob.glob(data_path): 16 | with gzip.open(file_name, 'rb') as f: 17 | for line in tqdm(f, desc=file_name): 18 | example = json.loads(line) 19 | code = example['code'] 20 | doc = example['docstring'] 21 | assert doc in code 22 | new_c = code[:code.index(doc)] + code[code.index(doc) + len(doc):] 23 | #new_c = new_c.replace('\n', '\\n') 24 | #doc = doc.replace('\n', '\\n') 25 | if random.random() < 0.5: 26 | inp = new_c 27 | out = doc 28 | else: 29 | inp = doc 30 | out = new_c 31 | 32 | of.write(json.dumps(dict(inp=inp, out=out)) + '\n') 33 | 34 | of.close() 35 | 36 | if True: 37 | of = open(output_file, 'w') 38 | 39 | for file_name in glob.glob(data_path): 40 | with gzip.open(file_name, 'rb') as f: 41 | for line in tqdm(f, desc=file_name): 42 | example = json.loads(line) 43 | of.write(example['original_string'].replace('\n', '\\n')) 44 | of.write('\n') 45 | 46 | of.close() 47 | 48 | if False: 49 | out_splt = split if split != 'valid' else 'val' 50 | src = 'python_data/{}.source'.format(out_splt) 51 | tgt = 'python_data/{}.target'.format(out_splt) 52 | of_src = open(src, 'w') 53 | of_tgt = open(tgt, 'w') 54 | for file_name in glob.glob(data_path): 55 | with gzip.open(file_name, 'rb') as f: 56 | for line in tqdm(f, desc=file_name): 57 | example = json.loads(line) 58 | code = 
example['code'] 59 | doc = example['docstring'] 60 | assert doc in code 61 | new_c = code[:code.index(doc)] + code[code.index(doc) + len(doc):] 62 | new_c = new_c.replace('\n', '\\n') 63 | doc = doc.replace('\n', '\\n') 64 | if random.random() < 0.5: 65 | of_src.write(new_c) 66 | of_src.write('\n') 67 | of_tgt.write(doc) 68 | of_tgt.write('\n') 69 | else: 70 | of_src.write(doc) 71 | of_src.write('\n') 72 | of_tgt.write(new_c) 73 | of_tgt.write('\n') 74 | of_src.close() 75 | of_tgt.close() 76 | -------------------------------------------------------------------------------- /solvers/enumerative/models/uniform.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Callable, Tuple, Generator, Set, Sequence 2 | import functools 3 | import operator 4 | from collections import defaultdict 5 | import random 6 | 7 | from tython import Program, TastNode, _RULES_BY_KIND, Rule, nt 8 | from models.model import CandidateGenerator, reachable_rules_by_kind 9 | from models import RegisterModel 10 | from challenges import extract_constants 11 | 12 | 13 | def prod(iterable): # like sum but product 14 | return functools.reduce(operator.mul, iterable, 1) 15 | 16 | 17 | @RegisterModel("uniform") 18 | class UniformModel(CandidateGenerator): 19 | ''' 20 | Uniformly samples from all rules. 
21 | ''' 22 | 23 | def __init__(self, copy_prob=0.5) -> None: 24 | super().__init__() 25 | self.copy_prob = copy_prob 26 | # self.intended_depth = intended_depth 27 | #self._random_cache = defaultdict(lambda: defaultdict(dict)) 28 | self._random_cache = {} # defaultdict is not pickable (for multiproc) 29 | 30 | def random(self, kind, max_depth=5, nodes_by_kind=None): 31 | if nodes_by_kind is None: 32 | nodes_by_kind = {} 33 | key = sum(hash(n) for n in nodes_by_kind) 34 | if key not in self._random_cache: 35 | self._random_cache[key] = {} 36 | cache = self._random_cache[key] 37 | 38 | # cache[kind][depth] is a list of available rules 39 | 40 | def available_rules(kind, depth): 41 | if depth <= 0: 42 | return [] 43 | if kind in cache and depth in cache[kind]: 44 | return cache[kind][depth] 45 | rules = [r for r in _RULES_BY_KIND[kind] if 46 | all([nodes_by_kind.get(k) or available_rules(k, depth - 1) for k in r.kids])] 47 | if kind not in cache: 48 | cache[kind] = {} 49 | cache[kind][depth] = rules 50 | return rules 51 | 52 | def helper(kind, depth): 53 | assert depth >= 0 54 | rules = available_rules(kind, depth) 55 | assert rules or nodes_by_kind.get(kind), f"Cannot generate random {kind} of depth <= {depth}" 56 | 57 | if nodes_by_kind.get(kind) and (not rules or random.random() < self.copy_prob): 58 | return random.choice(nodes_by_kind[kind]) 59 | 60 | rule = random.choice(rules) 61 | 62 | return TastNode(rule, [helper(k, depth - 1) for k in rule.kids]) 63 | 64 | return Program(helper(kind, max_depth)) 65 | 66 | def get_candidates_by_nodes(self, kind, nodes_by_kind): 67 | rules_by_kind = _RULES_BY_KIND # zzz reachable_rules_by_kind(kind, nodes_by_kind) 68 | 69 | if kind not in rules_by_kind: 70 | return {} 71 | 72 | # p_kind_rules = {k: (1 - self.copy_prob if nodes_by_kind.get(k) else 1) / max(1, len(rules_by_kind[k])) 73 | # for k in rules_by_kind} 74 | 75 | by_kind = {} 76 | 77 | for parent_kind in rules_by_kind: 78 | rules = rules_by_kind[parent_kind] 79 | 
has_copy = sum(r.name == "COPY" for r in rules) 80 | if has_copy: 81 | assert has_copy == 1 82 | by_kind[parent_kind] = [], [ 83 | (self.copy_prob if r.name == "COPY" else (1 - self.copy_prob) / len(rules), r) 84 | for r in rules] 85 | else: 86 | by_kind[parent_kind] = [], [(1 / len(rules), r) for r in rules] 87 | 88 | ans = {r: [by_kind[k] for k in r.kids] 89 | for rules in rules_by_kind.values() for r in rules if r.name != "COPY"} 90 | ans.update({r: [([(1. / len(nodes_by_kind[r.nt]), n) for n in nodes_by_kind.get(r.nt, [])], [])] 91 | for rules in rules_by_kind.values() for r in rules if r.name == "COPY"}) 92 | ans[None] = [by_kind[kind]] 93 | 94 | return ans 95 | 96 | def get_candidates(self, q: Program) -> Dict[Rule, 97 | List[Tuple[List[Tuple[float, TastNode]], List[Tuple[float, Rule]]]]]: 98 | 99 | consts = extract_constants(q) 100 | 101 | sol_type_annotation = q.tree.children[0].children[1].children[1] 102 | sol_kind = nt.type2nt(eval(sol_type_annotation.src())) 103 | 104 | return self.get_candidates_by_nodes(sol_kind, consts) 105 | 106 | 107 | if __name__ == "__main__": 108 | u = UniformModel() 109 | random.seed(0) 110 | for _ in range(100): 111 | p = u.random((List, int), 5) 112 | print(p.src(safe=False)) 113 | try: 114 | p.val(max_ticks=1000) 115 | except Program.EvalException: 116 | pass 117 | -------------------------------------------------------------------------------- /solvers/enumerative/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | orderedset 3 | numpy 4 | astor 5 | sklearn 6 | torch==1.13.1 7 | transformers==4.30.0 8 | tensorboardX 9 | -------------------------------------------------------------------------------- /solvers/enumerative/run_bigram.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | set -ex 4 | 5 | # Bootsrapping process. 
Starting with Uniform 6 | python solve_challenges.py \ 7 | --challenges_path "results/uniform/out.json" \ 8 | --eval_challenges_path "results/uniform/out.json" \ 9 | -m ml_bow_bigram \ 10 | --ml_model rf \ 11 | --copy_prob 0.5 \ 12 | --max_n_progs 10000 \ 13 | --timeout_secs 3600 \ 14 | --threads 40 \ 15 | --learn_from_train_sols \ 16 | --logging_dir results/bootstrap/ml_bigram_0 \ 17 | --out_file results/bootstrap/ml_bigram_0/out.json 18 | 19 | 20 | for i in {1..5}; do 21 | python solve_challenges.py \ 22 | --challenges_path "results/bootstrap/ml_bigram_$(($i-1))/out.json" \ 23 | --eval_challenges_path "results/bootstrap/ml_bigram_$(($i-1))/out.json" \ 24 | -m ml_bow_bigram \ 25 | --ml_model rf \ 26 | --copy_prob 0.5 \ 27 | --max_n_progs 10000 \ 28 | --timeout_secs 3600 \ 29 | --threads 40 \ 30 | --learn_from_train_sols \ 31 | --logging_dir results/bootstrap/ml_bigram_$i \ 32 | --out_file results/bootstrap/ml_bigram_$i/out.json 33 | done 34 | 35 | # Last run until 1M. 36 | python solve_challenges.py \ 37 | --challenges_path "results/bootstrap/ml_bigram_5/out.json" \ 38 | --eval_challenges_path "results/bootstrap/ml_bigram_5/out.json" \ 39 | -m ml_bow_bigram \ 40 | --ml_model rf \ 41 | --copy_prob 0.5 \ 42 | --max_n_progs 1000000 \ 43 | --timeout_secs 3600 \ 44 | --threads 20 \ 45 | --learn_from_train_sols \ 46 | --logging_dir results/bootstrap/ml_bigram_6 \ 47 | --out_file results/bootstrap/ml_bigram_6/out.json 48 | 49 | 50 | # Run without self-bootrapping (only over unifrom) until 1M. 
51 | python solve_challenges.py \ 52 | --challenges_path "results/uniform/out.json" \ 53 | --eval_challenges_path "results/uniform/out.json" \ 54 | -m ml_bow_bigram \ 55 | --ml_model rf \ 56 | --copy_prob 0.5 \ 57 | --max_n_progs 1000000 \ 58 | --timeout_secs 3600 \ 59 | --threads 40 \ 60 | --learn_from_train_sols \ 61 | --logging_dir results/ml_bigram \ 62 | --out_file results/ml_bigram/out.json 63 | 64 | -------------------------------------------------------------------------------- /solvers/enumerative/run_transformer.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | set -ex 4 | 5 | # First, extract rule embeddings. 6 | PYTHONPATH=./ python models/transformers/generate_rule_embeddings.py \ 7 | --model_name_or_path tals/roberta_python \ 8 | --output_dir results/roberta_rule_embeddings 9 | 10 | # Bootsrapping process. Starting with Uniform 11 | PYTHONPATH=./ python models/transformers/finetune_transformer.py \ 12 | --challenges_path results/uniform/out.json \ 13 | --eval_challenges_path results/uniform/out.json \ 14 | --model_name_or_path tals/roberta_python \ 15 | --output_dir results/bootstrap/roberta_0 \ 16 | --num_train_epochs 20 \ 17 | --do_train \ 18 | --do_infer \ 19 | --max_n_progs 10000 \ 20 | --timeout_secs 3600 \ 21 | --threads 40 \ 22 | --per_gpu_eval_batch_size 128 \ 23 | --per_gpu_train_batch_size 16 \ 24 | --rule_emb_dir results/roberta_rule_embeddings \ 25 | --overwrite_cache \ 26 | --max_ticks 10000 27 | 28 | for i in {1..5}; do 29 | PYTHONPATH=./ python models/transformers/finetune_transformer.py \ 30 | --challenges_path "results/bootstrap/roberta_$(($i-1))/solutions.json" \ 31 | --eval_challenges_path "results/bootstrap/roberta_$(($i-1))/solutions.json" \ 32 | --model_name_or_path tals/roberta_python \ 33 | --output_dir results/bootstrap/roberta_$i \ 34 | --num_train_epochs 20 \ 35 | --do_train \ 36 | --do_infer \ 37 | --max_n_progs 10000 \ 38 | --timeout_secs 3600 \ 39 | --threads 40 
\ 40 | --per_gpu_eval_batch_size 128 \ 41 | --per_gpu_train_batch_size 16 \ 42 | --rule_emb_dir results/roberta_rule_embeddings \ 43 | --overwrite_cache \ 44 | --max_ticks 10000 45 | done 46 | 47 | # Last run until 1M. 48 | PYTHONPATH=./ python models/transformers/finetune_transformer.py \ 49 | --challenges_path "results/bootstrap/roberta_5/solutions.json" \ 50 | --eval_challenges_path "results/bootstrap/roberta_5/solutions.json" \ 51 | --model_name_or_path tals/roberta_python \ 52 | --output_dir results/bootstrap/roberta_6 \ 53 | --num_train_epochs 20 \ 54 | --do_train \ 55 | --do_infer \ 56 | --max_n_progs 1000000 \ 57 | --timeout_secs 3600 \ 58 | --threads 40 \ 59 | --per_gpu_eval_batch_size 128 \ 60 | --per_gpu_train_batch_size 16 \ 61 | --rule_emb_dir results/roberta_rule_embeddings \ 62 | --overwrite_cache \ 63 | --max_ticks 10000 64 | 65 | 66 | # Run without self-bootrapping (only over unifrom) until 1M. 67 | PYTHONPATH=./ python models/transformers/finetune_transformer.py \ 68 | --challenges_path results/uniform/out.json \ 69 | --eval_challenges_path results/uniform/out.json \ 70 | --model_name_or_path tals/roberta_python \ 71 | --output_dir results/roberta \ 72 | --num_train_epochs 20 \ 73 | --do_train \ 74 | --do_infer \ 75 | --max_n_progs 1000000 \ 76 | --timeout_secs 3600 \ 77 | --threads 40 \ 78 | --per_gpu_eval_batch_size 128 \ 79 | --per_gpu_train_batch_size 16 \ 80 | --rule_emb_dir results/roberta_rule_embeddings \ 81 | --overwrite_cache \ 82 | --max_ticks 10000 83 | -------------------------------------------------------------------------------- /solvers/enumerative/run_uniform.sh: -------------------------------------------------------------------------------- 1 | python solve_challenges.py \ 2 | -p "../../problems/*.json" \ 3 | -m uniform \ 4 | --solve_uniform \ 5 | --copy_prob 0.5 \ 6 | --max_n_progs 10000 \ 7 | --timeout_secs 3600 \ 8 | --threads 40 \ 9 | --logging_dir results/uniform \ 10 | --out_file results/uniform/out.json 11 | 
-------------------------------------------------------------------------------- /solvers/enumerative/tython/__init__.py: -------------------------------------------------------------------------------- 1 | from .program import * 2 | from . import nonterminals as nt 3 | from .rules import Rule, _rules_by_nt as _RULES_BY_KIND, rules as RULES 4 | from .parse import str2name 5 | -------------------------------------------------------------------------------- /solvers/enumerative/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.str_utils import * 2 | from utils.time_utils import * 3 | 4 | import functools 5 | import operator 6 | 7 | 8 | def prod(iterable): # like sum but product 9 | return functools.reduce(operator.mul, iterable, 1) 10 | 11 | 12 | def flatten(it): 13 | return (e for a in it for e in (flatten(a) if isinstance(a, (tuple, list)) else (a,))) 14 | 15 | 16 | def load_json(filename): 17 | import json 18 | with open(filename, "r") as f: 19 | return json.load(f) 20 | 21 | 22 | def viz_py(py): 23 | import astor, ast 24 | print(astor.dump_tree(ast.parse(py))) 25 | 26 | 27 | def dedup(li): 28 | seen = set() 29 | return [x for x in li if x not in seen and not seen.add(x)] 30 | -------------------------------------------------------------------------------- /solvers/enumerative/utils/str_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def get_lambda_arg_name(lam): 4 | assert lam.startswith("lambda ") 5 | return lam[len("lambda "):lam.index(":")].strip() 6 | 7 | 8 | def stringify(const): 9 | if type(const) is str: 10 | return json.dumps(const) 11 | return str(const) 12 | 13 | 14 | def color_str(obj, code="\033[0;36m"): 15 | return code + str(obj) + '\033[0m' 16 | 17 | -------------------------------------------------------------------------------- /solvers/enumerative/utils/time_utils.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | def timeit(method): 7 | def timed(*args, **kw): 8 | tick = time.time() 9 | result = method(*args, **kw) 10 | tock = time.time() 11 | logger.debug(f'{method.__name__}: {tock - tick:.3f}s') 12 | 13 | return result 14 | return timed 15 | -------------------------------------------------------------------------------- /solvers/gpt3/README.md: -------------------------------------------------------------------------------- 1 | # Running GPT-3 experiments 2 | 3 | These are instructions for re-running the GPT-3 experiments. The results will be slightly different than those in 4 | the paper because the API is non-deterministic. 5 | 6 | The requirements can be installed with `pip3 install -r requirements.txt`. 7 | 8 | `run_gpt3_experiments.py` runs the GPT-3 experiments and prints the results to stdout. 9 | 10 | ## Installation and execution. 11 | You will need an open-ai GPT-3 access key which can be signed up for [here](https://openai.com/join/). 12 | You will then need to set it as the `OPENAI_API_KEY` environmnet variable. 13 | 14 | The requirements can be installed with `pip3 install -r requirements.txt`. 15 | 16 | It was run with Python 3.6.9, sys.version = '3.6.9 (default, Jan 26 2021, 15:33:00) \n[GCC 8.4.0]', but should 17 | be compatible with later versions as well. 18 | 19 | Then you simply run 20 | `python run_gpt3_experiments.py` and the results are written to stdout. It uses cacheing mechanisms with the first run 21 | being quite slow and verbose, querying the API. However you can subsequently run it again and it will be 22 | much faster and just output the results. The cacheing makes it deterministic so it should give the same 23 | exact results when re-run. 
24 | -------------------------------------------------------------------------------- /solvers/gpt3/create_simple_prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | 4 | from run_gpt3_experiments import PARAMS, PREFIX, PREFIX_DOCSTR 5 | import lm_solve 6 | 7 | OUTFILENAME = "puzzles_with_prompts.json" 8 | 9 | 10 | def run(outfilename=OUTFILENAME, filename=PARAMS["filename"]): 11 | with open(filename, "r") as f: 12 | entries = json.load(f) 13 | 14 | for entry in entries: 15 | entry["prompts"] = {} 16 | 17 | for mode in ["short", "medium", "long"]: 18 | prefix = { 19 | "short": "", 20 | "medium": PREFIX, 21 | "long": PREFIX_DOCSTR 22 | }[mode] 23 | prefix = re.sub(r" +$", "", (prefix or "").lstrip(), 24 | flags=re.M) # delete leading/trailing whitespace on each line 25 | puzzles = lm_solve.load_puzzles(filename, mode == "long") 26 | prompts = lm_solve.get_prompts(prefix, [f for ft, f in puzzles]) 27 | assert len(puzzles) == len(prompts) == len(entries) 28 | for entry, prompt in zip(entries, prompts): 29 | entry["prompts"][mode] = prompt 30 | 31 | with open(outfilename, "w") as f: 32 | json.dump(entries, f, indent=4) 33 | 34 | 35 | if __name__ == "__main__": 36 | run() 37 | -------------------------------------------------------------------------------- /solvers/gpt3/ezlog.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import inspect 4 | import io 5 | 6 | my_path = os.path.dirname(__file__) 7 | 8 | 9 | def color_str(obj, code="\033[0;36m"): 10 | return code + str(obj) + '\033[0m' 11 | 12 | 13 | _configured = False 14 | 15 | 16 | def configure_logging(stdio_level=logging.INFO, 17 | file_level=logging.DEBUG, 18 | filename=".easy.log", 19 | filepath=os.path.join(my_path, "logs")): 20 | os.makedirs(filepath, exist_ok=True) 21 | filename = os.path.join(filepath, filename) 22 | global _configured 23 | if _configured: 24 | 
warning("Re-configuring logging") 25 | stdio_handler = logging.StreamHandler() 26 | stdio_handler.setLevel(stdio_level) 27 | file_hanlder = logging.FileHandler(filename) 28 | file_hanlder.setLevel(file_level) 29 | 30 | logging.basicConfig( 31 | format="%(asctime)s - %(levelname)s - %(name)s - %(message).200s", 32 | datefmt="%m/%d/%Y %H:%M:%S", 33 | level=min(stdio_level, file_level), 34 | handlers=[stdio_handler, file_hanlder] 35 | ) 36 | 37 | _configured = True 38 | _get_or_create_logger().debug("Configured logging") 39 | 40 | 41 | _loggers = {} 42 | 43 | 44 | def _get_or_create_logger(): 45 | global _configured, _loggers 46 | if not _configured: 47 | configure_logging() 48 | try: 49 | for frame in inspect.stack(): 50 | name = inspect.getmodule(frame[0]).__name__ 51 | if name != __name__: 52 | break 53 | except: 54 | name = "_" 55 | if name not in _loggers: 56 | _loggers[name] = logging.getLogger(name) 57 | return _loggers[name] 58 | 59 | 60 | def print_to_string(*args, end="", **kwargs): 61 | with io.StringIO() as buf: 62 | print(*args, file=buf, end=end, **kwargs) 63 | return buf.getvalue() 64 | 65 | 66 | def debug(*args, **kwargs): 67 | _get_or_create_logger().debug(print_to_string(*args, **kwargs)) 68 | 69 | 70 | def info(*args, **kwargs): 71 | _get_or_create_logger().info(print_to_string(*args, **kwargs)) 72 | 73 | 74 | log = info 75 | 76 | 77 | def warning(*args, **kwargs): 78 | _get_or_create_logger().warning(print_to_string(*args, **kwargs)) 79 | 80 | 81 | warn = warning 82 | 83 | 84 | def error(*args, **kwargs): 85 | _get_or_create_logger().error(print_to_string(*args, **kwargs)) 86 | -------------------------------------------------------------------------------- /solvers/gpt3/lm_solve/__init__.py: -------------------------------------------------------------------------------- 1 | from lm_solve.run import * 2 | -------------------------------------------------------------------------------- /solvers/gpt3/lm_solve/gpt3_lib.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import openai 4 | import ezlog 5 | import time 6 | import datetime 7 | 8 | assert 'OPENAI_API_KEY' in os.environ, "Need to set environment variable `OPENAI_API_KEY`" 9 | openai.api_key = os.environ['OPENAI_API_KEY'] 10 | 11 | 12 | _CACHE_PATH = os.path.join(os.path.dirname(__file__), "../.cache") 13 | _CACHE_FILENAME = os.path.join(_CACHE_PATH, "gpt3.cache") 14 | _ENCODING = "utf-8" 15 | 16 | _cache = None 17 | 18 | 19 | # the cache file is just a list of (query params dictionary encoded as a string but without n, result list) 20 | # multiple queries with the same params (except for n) are merged into a single big list 21 | def _save_line(item, comment=None): 22 | global _cache 23 | assert _cache is not None 24 | with open(_CACHE_FILENAME, "a", encoding=_ENCODING) as f: 25 | f.write(str(item)+ ((" # " + comment + "\n") if comment else "\n")) 26 | 27 | def _load_cache(): 28 | global _cache 29 | 30 | assert _cache is None, "gpt3 cache already loaded" 31 | 32 | if not os.path.exists(_CACHE_PATH): 33 | ezlog.warn("Creating cache path") 34 | os.makedirs(_CACHE_PATH) 35 | 36 | _cache = {} 37 | 38 | if os.path.exists(_CACHE_FILENAME): 39 | time0 = time.perf_counter() 40 | with open(_CACHE_FILENAME, "r", encoding=_ENCODING) as f: 41 | for k, v in [eval(line) for line in f.readlines()]: 42 | if k not in _cache: 43 | _cache[k] = v 44 | else: 45 | _cache[k].extend(v) 46 | ezlog.info(f"Loaded gpt3 cache in {time.perf_counter()-time0:.1f}s") 47 | else: 48 | ezlog.warn("No gpt3 cache yet") 49 | 50 | 51 | 52 | def query(prompt, n=10, max_tokens=150, temp=1.0, max_batch=32, stop=None, notes=None, cache_only=False, verbose=True): 53 | """Query gpt3 54 | 55 | :param prompt: Up to 2048 tokens (about 3-4k chars) 56 | :param n: number of answers, None returns all cached answers 57 | :param max_tokens: 58 | :param temp: 0.9 seems to work well 59 | :param max_batch: max to 
query at once 60 | :param stop: string to stop at or '' if not to stop 61 | :param notes: notes you want to save or change in case you want to run the same query more than once! 62 | :return: list of answers and then the response items 63 | """ 64 | global _cache 65 | if _cache is None: 66 | _load_cache() 67 | 68 | if temp == 0 and n > 1: 69 | ezlog.debug("Temp 0: no point in running more than one query") 70 | n = 1 71 | 72 | key = str(dict(prompt=prompt, max_tokens=max_tokens, temp=temp, max_batch=max_batch, stop=stop, rep=notes)) 73 | cached = _cache.get(key, []) 74 | if n is None: 75 | return cached[:] 76 | 77 | if len(cached) >= n: 78 | return cached[:n] 79 | 80 | if cache_only: 81 | pass 82 | 1/0 83 | assert not cache_only, "Entry not found in cache" 84 | if verbose: 85 | print("/"*100) 86 | print("Querying GPT3 with prompt:") 87 | print(prompt) 88 | s = stop and stop.replace('\n', '\\n') 89 | print(f"/// n={n} ({n-len(cached)} new) max_tokens={max_tokens} temp={temp} max_batch={max_batch} stop={s}") 90 | print("/"*100) 91 | 92 | time0 = time.perf_counter() 93 | 94 | new = [] 95 | n -= len(cached) 96 | 97 | while n > 0: 98 | m = min(n, max_batch) 99 | 100 | res = openai.Completion.create( 101 | engine="davinci-msft", 102 | prompt=prompt, 103 | max_tokens=max_tokens, 104 | temperature=temp, 105 | n=m, 106 | stop=stop or None 107 | ) 108 | 109 | new += [c["text"] for c in res["choices"]] 110 | n -= m 111 | 112 | _save_line((key, new), f"{time.perf_counter() - time0:.1f}s {datetime.datetime.now()}") 113 | ans = _cache[key] = cached + new 114 | return ans[:] 115 | 116 | # old code 117 | # # to persist calls to the API... 
118 | # _disk_cache = joblib.Memory(os.path.join(os.path.dirname(__file__), ".cache"), verbose=1).cache 119 | # 120 | # 121 | # @_disk_cache 122 | # def query(prompt, n=10, max_tokens=150, temperature=1.0, max_batch=32): 123 | # """Query gpt3 124 | # 125 | # :param prompt: Up to 2048 tokens (about 3-4k chars) 126 | # :param n: number of answers 127 | # :param max_tokens: 128 | # :param temperature: 129 | # :param max_batch: max to query at once 130 | # :return: list of answers and then the response items 131 | # """ 132 | # if temperature == 0 and n > 1: 133 | # ezlog.debug("Temp 0: no point in running more than one query") 134 | # n = 1 135 | # 136 | # responses = [] 137 | # while n > 0: 138 | # m = min(n, max_batch) 139 | # prompt_summary = prompt if len(prompt) < 80 else f"{prompt[:40]}...{prompt[-40:]}" 140 | # ezlog.warn(f"**** Running GPT3 query: temp {temperature}, n={m}, prompt={prompt_summary}") 141 | # time0 = time.perf_counter() 142 | # responses.append(openai.Completion.create( 143 | # engine="davinci-msft", 144 | # prompt=prompt, 145 | # max_tokens=max_tokens, 146 | # temperature=temperature, 147 | # n=m 148 | # )) 149 | # ezlog.info(f"**** Got response in {time.perf_counter()-time0}s...") 150 | # n -= m 151 | # 152 | # return [c["text"] for r in responses for c in r["choices"]], responses 153 | 154 | 155 | -------------------------------------------------------------------------------- /solvers/gpt3/lm_solve/judge.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import time 4 | from pebble import ProcessPool 5 | import multiprocessing 6 | from concurrent.futures import TimeoutError 7 | from typing import List, Set, Tuple, Dict, Union, Any 8 | 9 | import utils 10 | import ezlog 11 | 12 | def _COPY(x): 13 | return x 14 | 15 | _ENV = dict(List=List, Set=Set, Tuple=Tuple, Dict=Dict, COPY=_COPY, 16 | os=None, sys=None, input=None, open=None, print=None, compile=None, copyright=None) 
17 | _UNSAFE = ["import", "builtin", "__class"] 18 | 19 | MAX_WORKERS = multiprocessing.cpu_count() // 2 20 | 21 | _CACHE_PATH = os.path.join(os.path.dirname(__file__), "../.cache") # ".cache" 22 | 23 | 24 | class SetCache: 25 | """Simple cache that stores a set of keys. Cannot remove values. Haven't yet implemented iteration.""" 26 | BYTE_LEN = 256 // 8 # for sha256 27 | 28 | def __init__(self, name, path=_CACHE_PATH): 29 | self._path = path 30 | self.name = name 31 | self._filename = os.path.join(path, f"{name}_set.cache") 32 | self._set = None # the main set, loaded lazily 33 | 34 | def update(self, keys): 35 | self._load() 36 | hashes = [self._hash(k) for k in keys] 37 | additions = {h for h in hashes if h not in self._set} 38 | 39 | if additions: 40 | with open(self._filename, "ab") as f: # append to binary file 41 | for h in additions: 42 | f.write(h) 43 | self._set.update(additions) 44 | 45 | def add(self, key): 46 | self.update([key]) 47 | 48 | def _hash(self, key): 49 | return hashlib.sha256(bytes(str(key), encoding='utf8')).digest() 50 | 51 | def __contains__(self, key: str): 52 | self._load() 53 | h = self._hash(key) 54 | return h in self._set 55 | 56 | def __delitem__(self, key): 57 | raise NotImplementedError 58 | 59 | def __len__(self): 60 | self._load() 61 | return len(self._set) 62 | 63 | def _load(self): 64 | if self._set is None: # only load if not already loaded 65 | if not os.path.exists(self._path): 66 | ezlog.warn(f"Creating path for `{self.name}` cache") 67 | os.makedirs(self._path) 68 | 69 | time0 = time.perf_counter() 70 | if os.path.exists(self._filename): 71 | with open(self._filename, "rb") as f: # read binary file 72 | data = f.read() 73 | self._set = {data[j:j + self.BYTE_LEN] for j in range(0, len(data), self.BYTE_LEN)} 74 | else: 75 | self._set = set() 76 | dur = time.perf_counter() - time0 77 | ezlog.info(f"Loaded `{self.name}` cache of {len(self):,} items in {dur:.1f}s") 78 | 79 | 80 | def _judge(code_env): 81 | code, env = 
code_env 82 | for u in _UNSAFE: 83 | if u in code: 84 | return False 85 | 86 | try: 87 | exec(code, env.copy()) # not sure if copy() is necessary 88 | return True 89 | except Exception as e: 90 | return False 91 | 92 | 93 | # Cache judge results (which are nondeterministic due to timeout) for reproducibility 94 | _judge_success = SetCache('judge_success') 95 | _judged_batches = SetCache('judged_batches') 96 | 97 | def judge_parallel(src_codes, timeout, max_workers=MAX_WORKERS, env=_ENV, force_compute=False): 98 | global _judge_success, _judged_batches 99 | 100 | if force_compute or src_codes not in _judged_batches: 101 | new_codes = utils.dedup(code for code in src_codes if code not in _judge_success) 102 | if new_codes: 103 | ezlog.info(f"Judging {len(new_codes)}/{len(src_codes)} new codes (removing duplicates/things in cache)") 104 | successes = [] 105 | 106 | 107 | with ProcessPool(max_workers=max_workers) as pool: 108 | future = pool.map(_judge, [(code, env) for code in new_codes], timeout=timeout) 109 | 110 | results = future.result() 111 | 112 | i = 0 113 | while True: 114 | try: 115 | if next(results): 116 | successes.append(new_codes[i]) 117 | except StopIteration: 118 | _judge_success.update(successes) 119 | break 120 | except (TimeoutError, Exception) as error: 121 | pass 122 | assert i < len(new_codes) 123 | i += 1 124 | assert i == len(new_codes) 125 | _judged_batches.add(src_codes) 126 | 127 | return [code in _judge_success for code in src_codes] 128 | 129 | 130 | if __name__ == "__main__": 131 | res = judge_parallel([ 132 | """1+1 133 | """, 134 | """assert False,'cats'""", 135 | """assert False""", 136 | """1[2]""", 137 | """1/0""", 138 | """while True: 139 | pass""", 140 | """for i in range(10**5): 141 | pass""" 142 | ], timeout=1.0) 143 | print(res) 144 | -------------------------------------------------------------------------------- /solvers/gpt3/requirements.txt: -------------------------------------------------------------------------------- 1 
| astor==0.8.1 2 | numpy==1.22.0 3 | openai==0.6.3 4 | tqdm==4.60.0 5 | transformers==4.30.0 6 | Pebble==4.6.1 7 | 8 | # we ran with Python version sys.version = '3.6.9 (default, Jan 26 2021, 15:33:00) \n[GCC 8.4.0]' 9 | # distro-info===0.18ubuntu0.18.04.1 10 | -------------------------------------------------------------------------------- /solvers/gpt3/run_gpt3_experiments.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script runs the GPT_3 experiments and prints the results to stdout. 3 | It uses cacheing mechanisms so that if run twice with the same parameters, it will give exactly the same 4 | results and will not query the GPT3 API again and will not judge the resulting solutions again. Hence, the first 5 | time you run it, it will be slow, but you can subsequently run it again and it will be fast. It will run the 6 | experiment three times, with different seeds to get different results. 7 | 8 | """ 9 | 10 | import lm_solve 11 | import numpy as np 12 | 13 | PARAMS = dict( 14 | temp=0.9, 15 | timeout=1.0, # seconds to judge 16 | n=10 * 1000, 17 | filename="puzzles_with_descriptions.json", 18 | stop="\n", 19 | cache_only=False, # change this to True if you want to run a 2nd time without risking hitting API 20 | ) 21 | 22 | BOOTSTRAP_PARAMS = dict( 23 | temp=PARAMS["temp"], 24 | timeout=PARAMS["timeout"], 25 | filename=PARAMS["filename"], 26 | stop=PARAMS["stop"], 27 | cache_only=PARAMS["cache_only"], 28 | ppi=32, # puzzles per iteration 29 | iterations=(PARAMS["n"] + 31) // 32, 30 | ) 31 | 32 | STUDY = range(107, 137) # the range of puzzles used in the study 33 | 34 | PREFIX = """ 35 | def f1(s: str): 36 | return "Hello " + s == "Hello world" 37 | 38 | assert True == f1("world") 39 | 40 | --- 41 | 42 | def f2(s: str): 43 | return "Hello " + s[::-1] == "Hello world" 44 | 45 | assert True == f2("world"[::-1]) 46 | 47 | --- 48 | 49 | def f3(x: List[int]): 50 | return len(x) == 2 and sum(x) == 3 51 | 52 | 
assert True == f3([1, 2]) 53 | 54 | --- 55 | 56 | def f4(s: List[str]): 57 | return len(set(s)) == 1000 and all((x.count("a") > x.count("b")) and ('b' in x) for x in s) 58 | 59 | assert True == f4(["a"*(i+2)+"b" for i in range(1000)]) 60 | 61 | --- 62 | 63 | def f5(n: int): 64 | return str(n * n).startswith("123456789") 65 | 66 | assert True == f5(int(int("123456789" + "0"*9) ** 0.5) + 1) 67 | 68 | --- 69 | 70 | """ # trailing newlines important 71 | 72 | PREFIX_DOCSTR = ''' 73 | def f1(s: str): 74 | """Find a string that when concatenated onto 'Hello ' gives 'Hello world'.""" 75 | return "Hello " + s == "Hello world" 76 | 77 | assert True == f1("world") 78 | 79 | --- 80 | 81 | def f2(s: str): 82 | """Find a string that when reversed and concatenated onto 'Hello ' gives 'Hello world'.""" 83 | return "Hello " + s[::-1] == "Hello world" 84 | 85 | assert True == f2("world"[::-1]) 86 | 87 | --- 88 | 89 | def f3(x: List[int]): 90 | """Find a list of two integers whose sum is 3.""" 91 | return len(x) == 2 and sum(x) == 3 92 | 93 | assert True == f3([1, 2]) 94 | 95 | --- 96 | 97 | def f4(s: List[str]): 98 | """Find a list of 1000 distinct strings which each have more 'a's than 'b's and at least one 'b'.""" 99 | return len(set(s)) == 1000 and all((x.count("a") > x.count("b")) and ('b' in x) for x in s) 100 | 101 | assert True == f4(["a"*(i+2)+"b" for i in range(1000)]) 102 | 103 | --- 104 | 105 | def f5(n: int): 106 | """Find an integer whose perfect square begins with 123456789 in its decimal representation.""" 107 | return str(n * n).startswith("123456789") 108 | 109 | assert True == f5(int(int("123456789" + "0"*9) ** 0.5) + 1) 110 | 111 | --- 112 | 113 | ''' # trailing newlines important 114 | 115 | 116 | def run(seed=0): 117 | sols = [lm_solve.prompt_experiment(**PARAMS, prefix="", seed=seed), 118 | lm_solve.prompt_experiment(**PARAMS, prefix=PREFIX, seed=seed), 119 | lm_solve.prompt_experiment(**PARAMS, prefix=PREFIX_DOCSTR, add_docstring=True, seed=seed)] 120 | 
problems_solved = [sorted([i for i, (f, gs) in enumerate(s) if gs]) for s in sols] 121 | bootstrap = lm_solve.bootstrap(**BOOTSTRAP_PARAMS, seed=seed) 122 | print(f"run={seed} ALL DONE!\n\n") 123 | print(f"run={seed} RESULTS " + "=" * 50) 124 | print() 125 | 126 | # Instead of running until first success, and outputting number of attempts, we do something more accurate. 127 | # We run for N tries for each problem and do not stop on first success. Then, we use the number of successes 128 | # for a better estimate of the average number of attempts required for first success. If we have s successes 129 | # out of N attempts, then the expected number of attempts is (N - s) / (1 + s). This is the expectation of the 130 | # random variable that is: when you permute the attempts uniformly at random, how many attempts before the 131 | # first success. If s=N, it's 0, if s=1, it's (N-1)/2, etc. 132 | counts = [[(PARAMS["n"] - len(gs)) / (1 + len(gs)) for f, gs in s if gs] for s in sols] 133 | counts.append([m for m, _i, _f, _a in bootstrap]) 134 | counts = [[1 + z for z in c] for c in counts] # add 1 to make it 1-based 135 | for c in counts: 136 | c.sort() 137 | print(f"run={seed} (Expected) number of attempts before a problem is solved [short, med, long, bootstrap]:") 138 | print(counts) 139 | problems_solved.append([i for _m, i, _f, _a in bootstrap]) 140 | print() 141 | print(f"run={seed} Which problems were solved [short, med, long, bootstrap]:") 142 | print(problems_solved) 143 | print() 144 | print(f"run={seed} Number of problems solved [short, med, long, bootstrap]:") 145 | print([len(c) for c in counts]) 146 | print() 147 | print(f"run={seed} Number of 30 study problems solved [short, med, long, bootstrap]:") 148 | print([len([i for i in s if i in STUDY]) for s in problems_solved]) 149 | print() 150 | difficulties = [1.0 for _ in range(len(sols[0]))] 151 | 152 | k = 1 153 | for m, i, f, a in bootstrap: 154 | difficulties[i] = np.log(m + 1) / np.log(PARAMS["n"]) 155 | 
156 | # These commented lines print the problems that bootstrap solved 157 | # print() 158 | # print(f"# Bootstrap solved after {m + 1} tries:") 159 | # print(f.replace("def f", "def f" + str(k))) 160 | # import json 161 | # 162 | # print(f"SOL:", json.dumps(a)) 163 | k += 1 164 | print(f"run={seed} Bootstrap difficulties for study puzzles:") 165 | print([difficulties[i] for i in STUDY]) 166 | 167 | 168 | if __name__ == "__main__": 169 | for seed in range(3): 170 | run(seed) 171 | -------------------------------------------------------------------------------- /solvers/gpt3/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.str_utils import * 2 | from utils.time_utils import * 3 | 4 | import functools 5 | import operator 6 | 7 | 8 | def prod(iterable): # like sum but product 9 | return functools.reduce(operator.mul, iterable, 1) 10 | 11 | 12 | def flatten(it): 13 | return (e for a in it for e in (flatten(a) if isinstance(a, (tuple, list)) else (a,))) 14 | 15 | 16 | def load_json(filename): 17 | import json 18 | with open(filename, "r") as f: 19 | return json.load(f) 20 | 21 | 22 | def viz_py(py): 23 | import astor, ast 24 | print(astor.dump_tree(ast.parse(py))) 25 | 26 | 27 | def dedup(li): 28 | seen = set() 29 | return [x for x in li if x not in seen and not seen.add(x)] 30 | -------------------------------------------------------------------------------- /solvers/gpt3/utils/str_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def get_lambda_arg_name(lam): 4 | assert lam.startswith("lambda ") 5 | return lam[len("lambda "):lam.index(":")].strip() 6 | 7 | 8 | def stringify(const): 9 | if type(const) is str: 10 | return json.dumps(const) 11 | return str(const) 12 | 13 | 14 | def color_str(obj, code="\033[0;36m"): 15 | return code + str(obj) + '\033[0m' 16 | 17 | 
-------------------------------------------------------------------------------- /solvers/gpt3/utils/time_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | def timeit(method): 7 | def timed(*args, **kw): 8 | tick = time.time() 9 | result = method(*args, **kw) 10 | tock = time.time() 11 | logger.debug(f'{method.__name__}: {tock - tick:.3f}s') 12 | 13 | return result 14 | return timed 15 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import inspect 4 | import io 5 | import os 6 | 7 | my_path = os.path.dirname(__file__) 8 | 9 | def check_hashseed(desired_seed = 0): 10 | if os.environ.get('PYTHONHASHSEED') != desired_seed: 11 | info(f"Ideally set PYTHONHASHSEED={desired_seed} for perfect reproducibility") 12 | return False 13 | return True 14 | 15 | 16 | def inv_dict(d): 17 | ans = {} 18 | for k, v in d.items(): 19 | if v not in ans: 20 | ans[v] = [] 21 | ans[v].append(k) 22 | return ans 23 | 24 | 25 | def remove_docstring(f): 26 | """Remove docstring""" 27 | assert '\n """' in f, f"No triple quote docstring (after four spaces) in: \n{f}" 28 | i = f.index('\n """') 29 | j = f.index('"""', i + 8) 30 | return f[:i + 1] + f[j + 4:] 31 | 32 | 33 | def get_docstring(f): 34 | assert '\n """' in f, f"No triple quote docstring (after four spaces) in: \n{f}" 35 | i = f.index('\n """') 36 | j = f.index('"""', i + 8) 37 | docstring = f[i + 1:j + 3] 38 | if not docstring.strip(' "'): 39 | warn(f"Empty docstring in:\n{f}") 40 | return docstring 41 | 42 | 43 | def flatten(it): 44 | return (e for a in it for e in (flatten(a) if isinstance(a, (tuple, list)) else (a,))) 45 | 46 | def save_json(obj, filename, make_dirs_if_necessary=False, indent=2, **kwargs): 47 | """Saves compressed file if 
filename ends with '.gz'""" 48 | import json 49 | if make_dirs_if_necessary: 50 | os.makedirs(os.path.dirname(filename), exist_ok=True) 51 | if filename.endswith(".gz"): 52 | import gzip 53 | with gzip.open(filename, "wt") as f: 54 | return json.dump(obj, f, indent=indent, **kwargs) 55 | with open(filename, "w", encoding="utf8") as f: 56 | return json.dump(obj, f, indent=indent, **kwargs) 57 | 58 | def load_json(filename): 59 | """Loads compressed file if filename ends with '.gz'""" 60 | import json 61 | if filename.endswith(".gz"): 62 | import gzip 63 | with gzip.open(filename, "rt") as f: 64 | return json.load(f) 65 | with open(filename, "r", encoding="utf8") as f: 66 | return json.load(f) 67 | 68 | 69 | def stringify(const): 70 | if type(const) is str: 71 | return json.dumps(const) 72 | return str(const) 73 | 74 | 75 | def dedup(stuff): 76 | seen = set() 77 | return [a for a in stuff if a not in seen and not seen.add(a)] 78 | 79 | 80 | def color_str(obj, code="\033[0;36m"): 81 | return code + str(obj) + '\033[0m' 82 | 83 | 84 | _configured = False 85 | 86 | 87 | def configure_logging(stdio_level=logging.INFO, 88 | file_level=logging.DEBUG, 89 | filename=os.path.join(my_path, ".problems.log")): 90 | global _configured 91 | if _configured: 92 | warning("Re-configuring logging") 93 | stdio_handler = logging.StreamHandler() 94 | stdio_handler.setLevel(stdio_level) 95 | file_hanlder = logging.FileHandler(filename) 96 | file_hanlder.setLevel(file_level) 97 | 98 | logging.basicConfig( 99 | format="%(asctime)s - %(levelname)s - %(name)s - %(message).200s", 100 | datefmt="%m/%d/%Y %H:%M:%S", 101 | level=min(stdio_level, file_level), 102 | handlers=[stdio_handler, file_hanlder] 103 | ) 104 | 105 | _configured = True 106 | _get_or_create_logger().debug("Configured logging") 107 | 108 | 109 | _loggers = {} 110 | 111 | 112 | def _get_or_create_logger(): 113 | global _configured, _loggers 114 | if not _configured: 115 | configure_logging() 116 | name = "_" 117 | for frame in 
inspect.stack(): 118 | name = inspect.getmodule(frame[0]).__name__ 119 | if name != __name__: 120 | break 121 | if name not in _loggers: 122 | _loggers[name] = logging.getLogger(name) 123 | return _loggers[name] 124 | 125 | 126 | def print_to_string(*args, end="", **kwargs): 127 | with io.StringIO() as buf: 128 | print(*args, file=buf, end=end, **kwargs) 129 | return buf.getvalue() 130 | 131 | 132 | def debug(*args, **kwargs): 133 | _get_or_create_logger().debug(print_to_string(*args, **kwargs)) 134 | 135 | 136 | def info(*args, **kwargs): 137 | _get_or_create_logger().info(print_to_string(*args, **kwargs)) 138 | 139 | 140 | log = info 141 | 142 | 143 | def warning(*args, **kwargs): 144 | _get_or_create_logger().warning(print_to_string(*args, **kwargs)) 145 | 146 | 147 | warn = warning 148 | 149 | 150 | def error(*args, **kwargs): 151 | _get_or_create_logger().error(print_to_string(*args, **kwargs)) 152 | --------------------------------------------------------------------------------