├── .github
│   ├── ISSUE_TEMPLATE
│   │   └── new-puzzle.md
│   └── workflows
│       └── codeql-analysis.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── DATASHEET.md
├── ICLR2023
│   ├── README.md
│   ├── data
│   │   ├── 125M_PAPER_1M_iter_1.txt.gz
│   │   ├── 13B_PAPER_1M_iter_1.txt.gz
│   │   ├── 27B_PAPER_1M_iter_1.txt.gz
│   │   ├── 350M_PAPER_1M_iter_0.txt.gz
│   │   └── Codex_PAPER_1M_iter_0.txt.gz
│   ├── requirements.txt
│   └── src
│       ├── babysit.sh
│       ├── ds_config_gptneo.json
│       ├── fine_tune.py
│       ├── fine_tune.sh
│       ├── fine_tune1.sh
│       ├── gen.py
│       ├── gen.sh
│       ├── judge.py
│       ├── neo_train.py
│       ├── preprocess.py
│       ├── requirements.txt
│       ├── solve.py
│       └── utils.py
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── generators
│   ├── ICPC.py
│   ├── IMO.py
│   ├── __init__.py
│   ├── algebra.py
│   ├── basic.py
│   ├── chess.py
│   ├── classic_puzzles.py
│   ├── codeforces.py
│   ├── compression.py
│   ├── conways_game_of_life.py
│   ├── games.py
│   ├── graphs.py
│   ├── human_eval.py
│   ├── lattices.py
│   ├── number_theory.py
│   ├── probability.py
│   ├── problem_constants
│   │   └── rsa_challenges.json
│   ├── study.py
│   ├── trivial_inverse.py
│   └── tutorial.py
├── make_dataset.py
├── notebooks
│   ├── Demo.ipynb
│   ├── Hackathon_puzzles.ipynb
│   ├── Intro.ipynb
│   ├── demo
│   │   ├── demo.py
│   │   └── study_puzzles.py
│   └── fireworks.gif
├── puzzle_generator.py
├── puzzles
│   ├── README.md
│   ├── puzzles.json
│   └── split.json
├── solvers
│   ├── README.md
│   ├── codex
│   │   ├── 138puzzles.json
│   │   ├── 30puzzles.json
│   │   ├── 397puzzles.json
│   │   ├── README.md
│   │   ├── ezlog.py
│   │   ├── lm_solve
│   │   │   ├── __init__.py
│   │   │   ├── gpt_lib.py
│   │   │   ├── judge.py
│   │   │   ├── run.py
│   │   │   └── scratch.py
│   │   ├── requirements.txt
│   │   ├── results
│   │   │   └── results_397_cushman_codex_1k_full.json.gz
│   │   ├── run_codex_experiments.py
│   │   └── utils.py
│   ├── enumerative
│   │   ├── README.md
│   │   ├── challenges
│   │   │   ├── __init__.py
│   │   │   ├── challenge.py
│   │   │   └── solutions.py
│   │   ├── download_pretrained_roberta.sh
│   │   ├── filter_outputs.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── ml_bow_bigram.py
│   │   │   ├── model.py
│   │   │   ├── transformers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── dataset_processor.py
│   │   │   │   ├── finetune_transformer.py
│   │   │   │   ├── generate_rule_embeddings.py
│   │   │   │   ├── learn_tokenizer.py
│   │   │   │   ├── neural_classifier.py
│   │   │   │   └── preprocess_pretraining_data.py
│   │   │   └── uniform.py
│   │   ├── requirements.txt
│   │   ├── run_bigram.sh
│   │   ├── run_transformer.sh
│   │   ├── run_uniform.sh
│   │   ├── solve_challenges.py
│   │   ├── top_down.py
│   │   ├── tython
│   │   │   ├── __init__.py
│   │   │   ├── nonterminals.py
│   │   │   ├── parse.py
│   │   │   ├── program.py
│   │   │   └── rules.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── str_utils.py
│   │       └── time_utils.py
│   └── gpt3
│       ├── README.md
│       ├── create_simple_prompts.py
│       ├── ezlog.py
│       ├── lm_solve
│       │   ├── __init__.py
│       │   ├── gpt3_lib.py
│       │   ├── judge.py
│       │   └── run.py
│       ├── puzzles_with_descriptions.json
│       ├── puzzles_with_prompts.json
│       ├── requirements.txt
│       ├── run_gpt3_experiments.py
│       └── utils
│           ├── __init__.py
│           ├── str_utils.py
│           └── time_utils.py
└── utils.py
/.github/ISSUE_TEMPLATE/new-puzzle.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: New puzzle
3 | about: Create your own puzzle by defining a function
4 | title: New puzzle
5 | labels: New-puzzle
6 | assignees: ''
7 |
8 | ---
9 |
10 | ```python
11 | def sat(x: str):
12 | """optional problem description"""
13 | return "Hello " + x == "Hello world" # change this to your puzzle
14 | ```
15 |
16 | Solvers, post your solutions in the comments using the following formatting:
17 | ````
18 | Reveal solution
19 |
20 | ```python
21 | def sol():
22 | return "world" # replace with your solution
23 | ```
24 |
25 | ````
26 |
--------------------------------------------------------------------------------
/.github/workflows/codeql-analysis.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 |
14 | on:
15 |   push:
16 |     branches: [ main ]
17 |   pull_request:
18 |     # The branches below must be a subset of the branches above
19 |     branches: [ main ]
20 |   schedule:
21 |     - cron: '41 10 * * 5'
22 |
23 | jobs:
24 |   analyze:
25 |     name: Analyze
26 |     runs-on: ubuntu-latest
27 |     permissions:
28 |       actions: read
29 |       contents: read
30 |       security-events: write
31 |
32 |     strategy:
33 |       fail-fast: false
34 |       matrix:
35 |         language: [ 'python' ]
36 |         # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ]
37 |         # Learn more:
38 |         # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed
39 |
40 |     steps:
41 |     - name: Checkout repository
42 |       uses: actions/checkout@v2
43 |
44 |     # Initializes the CodeQL tools for scanning.
45 |     - name: Initialize CodeQL
46 |       uses: github/codeql-action/init@v1
47 |       with:
48 |         languages: ${{ matrix.language }}
49 |         # If you wish to specify custom queries, you can do so here or in a config file.
50 |         # By default, queries listed here will override any specified in a config file.
51 |         # Prefix the list here with "+" to use these queries and those in the config file.
52 |         # queries: ./path/to/local/query, your-org/your-repo/queries@main
53 |
54 |     # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
55 |     # If this step fails, then you should remove it and run the build manually (see below)
56 |     - name: Autobuild
57 |       uses: github/codeql-action/autobuild@v1
58 |
59 |     # ℹ️ Command-line programs to run using the OS shell.
60 |     # 📚 https://git.io/JvXDl
61 |
62 |     # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
63 |     #    and modify them (or add more) to build your code if your project
64 |     #    uses a compiled language
65 |
66 |     #- run: |
67 |     #    make bootstrap
68 |     #    make release
69 |
70 |     - name: Perform CodeQL Analysis
71 |       uses: github/codeql-action/analyze@v1
72 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | .idea
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/ICLR2023/README.md:
--------------------------------------------------------------------------------
1 | This is the code and data for the paper "Language Models Can Teach Themselves to Program Better":
2 | https://arxiv.org/abs/2207.14502
3 |
4 | LICENSE
5 | MIT License - as specified in the ../LICENSE file of the PythonProgrammingPuzzles repo
6 |
7 | GPU USAGE
8 | GPU usage was large, especially for the 2.7B model, which is roughly 20x the size of the 125M model.
9 | Data generation takes the most GPU time and took about 2,500 GPU hours for 2.7B (on V100s).
10 | Finetuning 2.7B on the 1M generated samples took about 40 GPU hours (on V100s) per epoch of finetuning - 10 epochs = 400 GPU hours.
11 | Solving the 228-problem test set with 100 attempts using the finetuned 2.7B model took about 4 hours (on a V100).
12 | We mostly used V100s, but we used whatever was available, so sometimes T4s and A100s when they were free.
13 | We tried everything at 125M first - debugged there until it worked well - then rolled out the 1.3B and 2.7B jobs.
14 |
15 | DATASETS
16 | The data directory contains the datasets used. We feel the most interesting dataset is data/Codex_PAPER_1M_iter_0.txt,
17 | which was generated by Codex and gave the best results when finetuned on. All the datasets are part of our public release.
18 |
19 | SETUP
20 | src/requirements.txt is what we install on our cluster machines - the cluster comes with NVIDIA drivers and a matching PyTorch.
21 | ./requirements.txt is what I personally have installed on my local machine and have tested this runs - but it has lots of stuff you don't need.
22 | So try src/requirements.txt first - and if that doesn't work, ./requirements.txt has the versions of everything installed on my machine.
23 | Getting a DeepSpeed 0.6.1 install matching a PyTorch matching an NVIDIA driver was tricky for me on some machines; torch 1.10 and 1.11 both work.
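   | A minimal install sketch (illustrative; run from the ICLR2023 directory):
   |   pip install -r src/requirements.txt     # try the minimal cluster set first
   |   pip install -r requirements.txt         # fall back to the full local snapshot if needed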
24 |
25 | GENERATING/FINETUNING -> run "cd src && ./babysit.sh GPU_INDEX_TO_USE" -> GPU_INDEX_TO_USE=0 typically
26 | src/babysit.sh is the script that generates data and finetunes on that data in a loop, finetuning the GPT-Neo 125M/1.3B/2.7B models.
27 | In src/babysit.sh, TEST_LOCAL=1 runs locally on the machine's GPUs, which is great for fast testing; TEST_LOCAL=0 launches on the cluster, which is slow to start but has lots of GPUs.
28 | Realistically you have to train on a cluster - data generation takes a long time, so having lots of machines all generating data is the feasible approach.
29 | But given enough time, this will run locally on 1 GPU: about 1 year for 2.7B, or 2 weeks for 125M.
30 | We found that generating 75k samples after deduping worked for iteration 0 - we finetune on that data.
31 | Then, using that finetuned model in iteration 1, generating data goes more quickly - the finetuned model solves many more problems.
32 | Repeating that process works well.
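   | As a sketch, one iteration of the loop looks like this (hypothetical helper names; the real entry points are gen.py, preprocess.py, fine_tune.py, and solve.py):
   |   for iteration in range(ITER_MAX):
   |       data = generate_puzzles_and_solutions(model)     # gen.py proposes puzzles and candidate solutions
   |       data = [d for d in data if verified(d)]          # keep only solutions where assert f(g()) passes
   |       model = finetune(base_model, dedup(data))        # fine_tune.py launches deepspeed on the merged data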
33 | On 125M we compared training only on 125M-generated data from iter_0 versus iter_1 versus iter_2, generating 600K samples for each iteration.
34 | Finetuning on iter_2 data was best on the test set: 26.9/228 solved, vs 26.1/228 for iter_1 and 22.2/228 for iter_0.
35 | With 1M samples of 125M-generated data drawn across all of iterations 0, 1, and 2, we got 26.75/228.
36 | We understand why it's faster to generate iter_2 data with a finetuned model - it solves more problems.
37 | But why are the generated puzzles & solutions better for training the model on?
38 | We will explore that more in the future - and try iterating well beyond 3 iterations - although our preliminary experiments on 125M show it tops out at 3 iterations.
39 |
40 | FINETUNING ONLY -> run "cd src && ./fine_tune1.sh GPU_INDEX_TO_USE" -> GPU_INDEX_TO_USE=0 typically
41 | # ./fine_tune1.sh GPU MODEL_TO_TRAIN EXPERIMENT_NAME_DIRECTORY TRAIN_DATA EPOCHS
42 | This allows repeated finetuning on a specific dataset.
43 | Use it to do a temperature grid search, or to try different parameter variations on a specific dataset.
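   | An illustrative invocation (mirrors the Figure 5 commands below):
   | # ./fine_tune1.sh 0 125M ft1_Codex_PAPER_1M_iter_0 Codex_PAPER_1M_iter_0.txt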
44 |
45 | Detailed instructions for reproducing experiments:
46 | # Generating Codex data
47 | python gen.py -n=32 -max_tokens=4096 -model_path=openai/code-davinci-002 -model_path_solve=openai/code-cushman-001 -out=../data/codex/iter_0 -seed=2022
48 |
49 | # Measuring codex accuracy via API calls
50 | ./solve2.sh
51 | python solve.py -prefix=../data/train_prefix.txt -attempts=1 -model_path=openai/code-cushman-001 -gpu=0 -fixed_temp=0.8 -out=../data/codex -puzzles=../data/test_228.json -seed=2022 -batch_size=64
52 |
53 | # Producing verified Codex_PAPER_1M_iter_0.txt from the puzzle/solution old style data generated by Codex
54 | python preprocess.py -path=../data/codex/old_verified -f_name=Codex_PAPER_1M_iter_0.txt -max_sols_per_puzzle=8 -old_style_json=True -max_examples=1000000 -include_failures=False -seed=2022
55 | cp ../data/codex/old_verified/Codex_PAPER_1M_iter_0.txt ../data/Codex_PAPER_1M_iter_0.txt
56 |
57 | # Producing unverified Codex_unverified_PAPER_1M_iter_0.txt from the puzzle/solution old style data generated by Codex
58 | python preprocess.py -path=../data/codex/old_unverified -f_name=Codex_unverified_PAPER_1M_iter_0.txt -max_sols_per_puzzle=8 -old_style_json=True -max_examples=1000000 -include_failures=True -seed=2022
59 | cp ../data/codex/old_unverified/Codex_unverified_PAPER_1M_iter_0.txt ../data/Codex_unverified_PAPER_1M_iter_0.txt
60 |
61 | # Producing 125M_PAPER_25K_iter_0.txt from the puzzle/solution new style data
62 | python preprocess.py ../data/125M_PAPER/iter_0 125M_PAPER_25K_iter_0.txt 8 False 25000 False -seed=2022
63 | cp ../data/125M_PAPER/iter_0/125M_PAPER_25K_iter_0.txt ../data/125M_PAPER_25K_iter_0.txt
64 |
65 | # Producing 125M_PAPER_1M_iter_1.txt from the puzzle/solution new style data
66 | python preprocess.py ../data/125M_PAPER/iter_1 125M_PAPER_1M_iter_1.txt 8 False 1000000 False -seed=2022
67 | cp ../data/125M_PAPER/iter_1/125M_PAPER_1M_iter_1.txt ../data/125M_PAPER_1M_iter_1.txt
68 |
69 | # Producing 125M_PAPER_1M_iter_2.txt from the puzzle/solution new style data
70 | python preprocess.py ../data/125M_PAPER/iter_2 125M_PAPER_1M_iter_2.txt 8 False 1000000 False -seed=2022
71 | cp ../data/125M_PAPER/iter_2/125M_PAPER_1M_iter_2.txt ../data/125M_PAPER_1M_iter_2.txt
72 |
73 | # Producing 13B_PAPER_25K_iter_0.txt from the puzzle/solution new style data
74 | python preprocess.py ../data/13B_PAPER/iter_0 13B_PAPER_25K_iter_0.txt 8 False 25000 False -seed=2022
75 | cp ../data/13B_PAPER/iter_0/13B_PAPER_25K_iter_0.txt ../data/13B_PAPER_25K_iter_0.txt
76 |
77 | # Producing 13B_PAPER_1M_iter_1.txt from the puzzle/solution new style data
78 | python preprocess.py ../data/13B_PAPER/iter_1 13B_PAPER_1M_iter_1.txt 8 False 1000000 False -seed=2022
79 | cp ../data/13B_PAPER/iter_1/13B_PAPER_1M_iter_1.txt ../data/13B_PAPER_1M_iter_1.txt
80 |
81 | # Producing 13B_PAPER_1M_iter_2.txt from the puzzle/solution new style data
82 | python preprocess.py ../data/13B_PAPER/iter_2 13B_PAPER_1M_iter_2.txt 8 False 1000000 False -seed=2022
83 | cp ../data/13B_PAPER/iter_2/13B_PAPER_1M_iter_2.txt ../data/13B_PAPER_1M_iter_2.txt
84 |
85 | # Producing 27B_PAPER_25K_iter_0.txt from the puzzle/solution new style data
86 | python preprocess.py ../data/27B_PAPER/iter_0 27B_PAPER_25K_iter_0.txt 8 False 25000 False -seed=2022
87 | cp ../data/27B_PAPER/iter_0/27B_PAPER_25K_iter_0.txt ../data/27B_PAPER_25K_iter_0.txt
88 |
89 | # Producing 27B_PAPER_1M_iter_1.txt from the puzzle/solution new style data
90 | python preprocess.py ../data/27B_PAPER/iter_1 27B_PAPER_1M_iter_1.txt 8 False 1000000 False -seed=2022
91 | cp ../data/27B_PAPER/iter_1/27B_PAPER_1M_iter_1.txt ../data/27B_PAPER_1M_iter_1.txt
92 |
93 | # Producing 27B_PAPER_1M_iter_2.txt from the puzzle/solution new style data
94 | python preprocess.py ../data/27B_PAPER/iter_2 27B_PAPER_1M_iter_2.txt 8 False 1000000 False -seed=2022
95 | cp ../data/27B_PAPER/iter_2/27B_PAPER_1M_iter_2.txt ../data/27B_PAPER_1M_iter_2.txt
96 |
97 | # Data files produced by babysit.sh - generating data from gpt-neo-* and Codex
98 | # At the time the experiments were run, Codex wasn't finetunable, so only iteration-0 data was available
99 | Codex_PAPER_1M_iter_0.txt
100 | 125M_PAPER_25K_iter_0.txt
101 | 13B_PAPER_25K_iter_0.txt
102 | 27B_PAPER_25K_iter_0.txt
103 | 125M_PAPER_1M_iter_1.txt
104 | 13B_PAPER_1M_iter_1.txt
105 | 27B_PAPER_1M_iter_1.txt
106 | 125M_PAPER_1M_iter_2.txt
107 | 13B_PAPER_1M_iter_2.txt
108 | 27B_PAPER_1M_iter_2.txt
109 |
110 | # Figure 5 - 3 diagrams - showing the 3 GPT-Neo models trained on verified Codex data vs unverified Codex data vs baseline
111 | # 5a GPT-NEO 125M
112 | ./fine_tune1.sh 0 125M ft1_Codex_PAPER_1M_iter_0 Codex_PAPER_1M_iter_0.txt
113 | ./fine_tune1.sh 0 125M ft1_Codex_unverified_PAPER_1M_iter_0 Codex_unverified_PAPER_1M_iter_0.txt
114 | ./solve1.sh 0 125M 10 228
115 | # 5b GPT-NEO 13B
116 | ./fine_tune1.sh 0 13B ft1_Codex_PAPER_1M_iter_0 Codex_PAPER_1M_iter_0.txt
117 | ./fine_tune1.sh 0 13B ft1_Codex_unverified_PAPER_1M_iter_0 Codex_unverified_PAPER_1M_iter_0.txt
118 | ./solve1.sh 0 13B 10 228 5
119 | # 5c GPT-NEO 27B
120 | ./fine_tune1.sh 0 27B ft1_Codex_PAPER_1M_iter_0 Codex_PAPER_1M_iter_0.txt
121 | ./fine_tune1.sh 0 27B ft1_Codex_unverified_PAPER_1M_iter_0 Codex_unverified_PAPER_1M_iter_0.txt
122 | ./solve1.sh 0 27B 10 228 5
123 |
124 | # Figure 6 - 3 diagrams - showing test228 Pass@K for the 3 GPT-Neo models trained on data from 4 generators (Codex and the 3 GPT-Neo models) and baseline
125 | # 6a - GPT-NEO 125M trained on 4 different datasets and baseline
126 | # ./fine_tune1.sh 0 125M ft1_Codex_PAPER_1M_iter_0 Codex_PAPER_1M_iter_0.txt (dupe of 5a)
127 | ./fine_tune1.sh 0 125M ft1_125M_PAPER_1M_iter_2 125M_PAPER_1M_iter_2.txt
128 | ./fine_tune1.sh 0 125M ft1_13B_PAPER_1M_iter_2 13B_PAPER_1M_iter_2.txt
129 | ./fine_tune1.sh 0 125M ft1_27B_PAPER_1M_iter_2 27B_PAPER_1M_iter_2.txt
130 |
131 | # 6b - GPT-NEO 13B trained on 4 different datasets and baseline
132 | # ./fine_tune1.sh 0 13B ft1_Codex_PAPER_1M_iter_0 Codex_PAPER_1M_iter_0.txt (dupe of 5b)
133 | ./fine_tune1.sh 0 13B ft1_125M_PAPER_1M_iter_2 125M_PAPER_1M_iter_2.txt
134 | ./fine_tune1.sh 0 13B ft1_13B_PAPER_1M_iter_2 13B_PAPER_1M_iter_2.txt
135 | ./fine_tune1.sh 0 13B ft1_27B_PAPER_1M_iter_2 27B_PAPER_1M_iter_2.txt
136 |
137 | # 6c - GPT-NEO 27B trained on 4 different datasets and baseline
138 | # ./fine_tune1.sh 0 27B ft1_Codex_PAPER_1M_iter_0 Codex_PAPER_1M_iter_0.txt (dupe of 5c)
139 | ./fine_tune1.sh 0 27B ft1_125M_PAPER_1M_iter_2 125M_PAPER_1M_iter_2.txt
140 | ./fine_tune1.sh 0 27B ft1_13B_PAPER_1M_iter_2 13B_PAPER_1M_iter_2.txt
141 | ./fine_tune1.sh 0 27B ft1_27B_PAPER_1M_iter_2 27B_PAPER_1M_iter_2.txt
142 |
143 | # Launch on torch2020 - edit solve.yaml to set the correct model and epoch parameters
144 | ./tst_human_eval_base.sh 0 125M 1024
145 | ./tst_human_eval_ft1.sh 0 125M 1024
146 | ./tst_human_eval_ft5.sh 0 125M 1024
147 | ./tst_human_eval_ft10.sh 0 125M 1024
148 |
--------------------------------------------------------------------------------
/ICLR2023/data/125M_PAPER_1M_iter_1.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/ICLR2023/data/125M_PAPER_1M_iter_1.txt.gz
--------------------------------------------------------------------------------
/ICLR2023/data/13B_PAPER_1M_iter_1.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/ICLR2023/data/13B_PAPER_1M_iter_1.txt.gz
--------------------------------------------------------------------------------
/ICLR2023/data/27B_PAPER_1M_iter_1.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/ICLR2023/data/27B_PAPER_1M_iter_1.txt.gz
--------------------------------------------------------------------------------
/ICLR2023/data/350M_PAPER_1M_iter_0.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/ICLR2023/data/350M_PAPER_1M_iter_0.txt.gz
--------------------------------------------------------------------------------
/ICLR2023/data/Codex_PAPER_1M_iter_0.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/ICLR2023/data/Codex_PAPER_1M_iter_0.txt.gz
--------------------------------------------------------------------------------
/ICLR2023/requirements.txt:
--------------------------------------------------------------------------------
1 | adal==1.2.7
2 | aiohttp==3.8.5
3 | aiosignal==1.2.0
4 | amlt==8.0.9
5 | applicationinsights==0.11.10
6 | asn1crypto==0.24.0
7 | astor==0.8.1
8 | async-timeout==4.0.1
9 | attrs==17.4.0
10 | Automat==0.6.0
11 | azure-common==1.1.27
12 | azure-core==1.17.0
13 | azure-data-tables==12.0.0b6
14 | azure-graphrbac==0.61.1
15 | azure-identity==1.4.1
16 | azure-mgmt-authorization==0.61.0
17 | azure-mgmt-containerregistry==2.8.0
18 | azure-mgmt-keyvault==2.2.0
19 | azure-mgmt-resource==13.0.0
20 | azure-mgmt-storage==11.2.0
21 | azure-storage-blob==2.1.0
22 | azure-storage-common==2.1.0
23 | azure-storage-file==2.1.0
24 | azureml-automl-core==1.26.0
25 | azureml-contrib-k8s==0.1.16
26 | azureml-contrib-pipeline-steps==1.26.0
27 | azureml-core==1.26.0
28 | azureml-dataprep==2.13.2
29 | azureml-dataprep-native==32.0.0
30 | azureml-dataprep-rslex==1.11.2
31 | azureml-dataset-runtime==1.26.0
32 | azureml-k8s-mt==1.0.4
33 | azureml-pipeline-core==1.26.0
34 | azureml-pipeline-steps==1.26.0
35 | azureml-telemetry==1.26.0
36 | azureml-train-automl-client==1.26.0
37 | azureml-train-core==1.26.0
38 | azureml-train-restclients-hyperdrive==1.26.0
39 | backcall==0.2.0
40 | backports.tempfile==1.0
41 | backports.weakref==1.0.post1
42 | beautifulsoup4==4.9.3
43 | bitstring==3.1.9
44 | black==21.8b0
45 | blinker==1.4
46 | blis==0.7.4
47 | blobxfer==1.10.0
48 | cachetools==4.2.2
49 | catalogue==2.0.6
50 | certifi==2023.7.22
51 | cffi==1.14.6
52 | chardet==3.0.4
53 | charset-normalizer==2.0.7
54 | click==7.1.2
55 | click-completion @ git+https://github.com/temporaer/click-completion.git@41b21868cac0781d25b37da624bae2fd1f36be88
56 | click-option-group==0.5.3
57 | click-plugins==1.1.1
58 | cloud-init==20.2
59 | cloudpickle==1.6.0
60 | colorama==0.3.7
61 | colorlog==6.4.1
62 | command-not-found==0.3
63 | configobj==5.0.6
64 | configparser==5.0.2
65 | constantly==15.1.0
66 | contextlib2==21.6.0
67 | cryptography==41.0.4
68 | cycler==0.10.0
69 | cymem==2.0.5
70 | datasets==1.15.1
71 | debugpy==1.4.3
72 | decorator==5.0.9
73 | deepspeed==0.5.1
74 | dill==0.3.4
75 | distro==1.6.0
76 | distro-info===0.18ubuntu0.18.04.1
77 | docker==5.0.1
78 | docker-pycreds==0.4.0
79 | dotnetcore2==2.1.21
80 | ecdsa==0.17.0
81 | entrypoints==0.3
82 | et-xmlfile==1.1.0
83 | fail2ban==0.10.2
84 | fastai==2.5.2
85 | fastcore==1.3.26
86 | fastdownload==0.0.5
87 | fastprogress==1.0.0
88 | filelock==3.0.12
89 | Flask==2.3.2
90 | Flask-Cors==3.0.10
91 | Flask-Executor==0.9.4
92 | Flask-FontAwesome==0.1.5
93 | frozenlist==1.2.0
94 | fsspec==2021.11.0
95 | gitdb==4.0.7
96 | GitPython==3.1.35
97 | httplib2==0.19.0
98 | huggingface-hub==0.1.2
99 | humanize==3.11.0
100 | hyperlink==17.3.1
101 | idna==2.6
102 | incremental==16.10.1
103 | ipdb==0.13.9
104 | ipykernel==6.4.1
105 | ipython==8.10.0
106 | ipython-genutils==0.2.0
107 | isodate==0.6.0
108 | itsdangerous==2.0.1
109 | jedi==0.18.0
110 | Jinja2==3.0.1
111 | jmespath==0.10.0
112 | joblib==1.2.0
113 | jsonpatch==1.16
114 | jsonpickle==2.0.0
115 | jsonpointer==1.10
116 | jsonschema==2.6.0
117 | jupyter-client==7.0.5
118 | jupyter-core==4.11.2
119 | keyring==10.6.0
120 | keyrings.alt==3.0
121 | kiwisolver==1.3.2
122 | language-selector==0.1
123 | libtmux==0.10.1
124 | Mako==1.2.2
125 | MarkupSafe==2.0.1
126 | marshmallow==3.10.0
127 | matplotlib==3.4.3
128 | matplotlib-inline==0.1.3
129 | mlb-core==0.0.4
130 | msal==1.14.0
131 | msal-extensions==0.2.2
132 | msrest==0.6.19
133 | msrestazure==0.6.4
134 | multidict==5.2.0
135 | multiprocess==0.70.12.2
136 | murmurhash==1.0.5
137 | mypy-extensions==0.4.3
138 | ndg-httpsclient==0.5.1
139 | nest-asyncio==1.5.1
140 | netifaces==0.10.4
141 | ninja==1.10.2
142 | ntlm-auth==1.5.0
143 | numpy==1.22.0
144 | oauthlib==3.2.2
145 | openai==0.13.0
146 | openpyxl==3.0.9
147 | orderedset==2.0.3
148 | packaging==21.0
149 | PAM==0.4.2
150 | pandas==1.3.2
151 | pandas-stubs==1.2.0.45
152 | parso==0.8.2
153 | passpy==1.0.2
154 | pathspec==0.9.0
155 | pathtools==0.1.2
156 | pathy==0.6.0
157 | Pebble==4.6.3
158 | petname==2.6
159 | pexpect==4.8.0
160 | pickleshare==0.7.5
161 | Pillow==10.0.1
162 | platformdirs==2.3.0
163 | portalocker==1.7.1
164 | preshed==3.0.5
165 | promise==2.3
166 | prompt-toolkit==3.0.20
167 | protobuf==3.18.3
168 | psb2==1.0.0
169 | psutil==5.8.0
170 | ptyprocess==0.7.0
171 | pyarrow==1.0.1
172 | pyasn1==0.4.2
173 | pyasn1-modules==0.2.1
174 | pycparser==2.20
175 | pycrypto==2.6.1
176 | pydantic==1.8.2
177 | Pygments==2.15.0
178 | PyGObject==3.26.1
179 | PyJWT==2.4.0
180 | pyOpenSSL==17.5.0
181 | pyparsing==2.4.7
182 | pyperclip==1.8.2
183 | pyserial==3.4
184 | python-apt==1.6.5+ubuntu0.3
185 | python-dateutil==2.8.2
186 | python-debian==0.1.32
187 | python-gnupg==0.4.7
188 | pytz==2021.1
189 | pyxdg==0.26
190 | PyYAML==5.4.1
191 | pyzmq==22.3.0
192 | regex==2021.8.28
193 | requests==2.31.0
194 | requests-ntlm==1.1.0
195 | requests-oauthlib==1.3.0
196 | requests-unixsocket==0.1.5
197 | ruamel.yaml==0.17.16
198 | ruamel.yaml.clib==0.2.6
199 | sacremoses==0.0.45
200 | scikit-learn==0.24.2
201 | scipy==1.10.0
202 | SecretStorage==2.3.1
203 | sentry-sdk==1.14.0
204 | service-identity==16.0.0
205 | shellingham==1.4.0
206 | shortuuid==1.0.1
207 | six==1.16.0
208 | sklearn==0.0
209 | smart-open==5.2.1
210 | smmap==4.0.0
211 | soupsieve==2.2.1
212 | spacy==3.1.2
213 | spacy-legacy==3.0.8
214 | srsly==2.4.1
215 | ssh-import-id==5.7
216 | sshpubkeys==3.3.1
217 | strictfire==0.4.1
218 | subprocess32==3.5.4
219 | systemd-python==234
220 | tabulate==0.8.9
221 | tensorboardX==1.8
222 | termcolor==1.1.0
223 | thinc==8.0.10
224 | threadpoolctl==2.2.0
225 | tokenizers==0.10.3
226 | toml==0.10.2
227 | tomli==1.2.1
228 | torch==1.13.1
229 | torchvision==0.10.0
230 | tornado==6.3.3
231 | tqdm==4.62.2
232 | traitlets==5.1.0
233 | transformers==4.30.0
234 | Twisted==22.10.0
235 | typer==0.3.2
236 | typing-extensions==3.10.0.2
237 | ufw==0.36
238 | unattended-upgrades==0.1
239 | urllib3==1.26.17
240 | virtualenv==15.1.0
241 | WALinuxAgent==2.2.45
242 | wasabi==0.8.2
243 | wcwidth==0.2.5
244 | websocket-client==1.2.1
245 | Werkzeug==2.2.3
246 | xdg==5.1.1
247 | xxhash==2.0.2
248 | yarl==1.7.2
249 | zope.interface==4.3.2
250 |
--------------------------------------------------------------------------------
/ICLR2023/src/babysit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
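   | # Usage: ./babysit.sh GPU_INDEX   (typically 0; runs the generate/finetune loop described in ../README.md)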
2 | # All Experiment Settings - constant through the experiment run - passed to gen.sh and fine_tune.sh as needed
3 | GPU=0 # which GPU to use
4 | MODEL="125M" # MODEL is the size of the model: 125M, 13B, 27B
5 | EXPERIMENT=$MODEL"_PAPER" # Name of Experiment directory under data/* and models/base-model/* to store results
6 | TEST_LOCAL=1 # 0 means run gen/fine_tune on cluster remotely, 1 means run gen/fine_tune locally
7 | TARGET_NUM_FILES=1 # How many files to generate in each iteration before starting fine_tuning. Count of unique examples would have been better.
8 | ITER_START=0 # inclusive index to start processing at - creates iter_# under data&models at each iteration. Can continue prev runs by start at prev ITER_MAX
9 | ITER_MAX=5 # exclusive index to stop processing iterations at
10 | EPOCHS_START=1 # inclusive index of epochs to start processing at - could continue prev run by starting at prev EPOCHS_MAX+1 - 0th epoch is the default model so epoch starts at 1
11 | EPOCHS_MAX=4 # inclusive index of epochs to stop processing at
12 | EPOCHS_PER_STEP=1 # How many EPOCHS through the data to do in each step
13 | TRAIN_INCREMENTAL=0 # Only train on data from the latest iteration, and start finetuning on the last finetuned model - otherwise start from scratch and use all the data generated
14 | TRAIN_BOOST=0 # Initial generation of data from the default model is slow - 1 means look in 125M_RL_ALL and use previously generated initial data to bootstrap.
15 | PASS_AT_K=100 # PASS_AT_K says do K trials to solve to compute Pass@K
16 | LINE_LOG_K=11 # LINE_LOG_K is how many lines of results from solve have results for saving
17 |
18 | echo babysit args: $# $0 $1 $2 $3 $4
19 |
20 | if (( $# \!= 1 ))
21 | then
22 | echo babysit.sh only takes 1 argument, unless called by another script to initialize configuration variables
23 | return
24 | fi
25 |
26 | if (( $# \>= 1 ))
27 | then
28 | GPU=$1
29 | fi
30 |
31 | echo babysit GPU $GPU
32 |
33 | for (( iteration=$ITER_START; iteration<$ITER_MAX; iteration++ ))
34 | do
35 | FULLNAME="${EXPERIMENT}---${iteration}"
36 | echo FULLNAME $FULLNAME
37 | export FULLNAME # Needed to pass variable off to yaml job
38 | DATAPATH=data/${EXPERIMENT}/iter_$iteration
39 | echo DATAPATH $DATAPATH
40 |
41 | if (( $TEST_LOCAL \> 0 ))
42 | then
43 | count=`ls -lt ../${DATAPATH} | grep json | wc -l`
44 | else
45 | count=`amlt sto list ${DATAPATH} | grep json | wc -l`
46 | fi
47 | echo count $count
48 |
49 | # Instead of file count we might want to check if the amount of data from preprocess is sufficient
50 | # If not we call to generate more
51 |
52 | if (( $count \> 0 ))
53 | then
54 | echo "$FULLNAME has already been started"
55 | echo "You are resuming at iteration $iteration"
56 | echo "You already have $count files of data this iteration"
57 | else
58 | echo "$FULLNAME is starting generation for iteration $iteration"
59 | fi
60 |
61 | if (( $count \< $TARGET_NUM_FILES ))
62 | then
63 | if (( $TEST_LOCAL \> 0 ))
64 | then
65 | # ./gen.sh $GPU 2560 100 $FULLNAME -1
66 | # 2.7B 384 100 runs ~10 hours
67 | # 2.7B 160 100 runs ~4.5 hours
68 | ./gen.sh $GPU 256000 100 $FULLNAME -1
69 | else
70 | amlt run hyper_gen_octows.yaml $FULLNAME -d "$FULLNAME"
71 | exit
72 | fi
73 | fi
74 |
75 | # Running locally you are done at this point, but launching on the cloud you have to wait
76 | for (( poll=0; poll<500; poll++ ))
77 | do
78 | if (( $TEST_LOCAL \> 0 ))
79 | then
80 | count=`ls -lt ../${DATAPATH} | grep json | wc -l`
81 | else
82 | count=`amlt sto list ${DATAPATH} | grep json | wc -l`
83 | fi
84 |
85 | echo "gen wait - Iteration: $iteration, Poll: $poll, Count: $count"
86 |
87 | if (( $count \>= $TARGET_NUM_FILES ))
88 | then
89 | echo "Finished generation iteration $iteration after $poll polls"
90 | break
91 | fi
92 | sleep 3m
93 | done
94 |
95 | # Start a finetune job
96 | if (( $TEST_LOCAL \> 0 ))
97 | then
98 | ./fine_tune.sh $GPU $FULLNAME
99 | else
100 | # Pass environment variable FULLNAME to amlt.yaml
101 | amlt run amlt_octo.yaml $FULLNAME -d "$FULLNAME"
102 | exit
103 | fi
104 |
105 | # On the cluster we need to wait for the finetune job to finish; run locally it is already done
106 | # Check the log files for starting the running of solve have been created for the last epoch of training
107 |
108 | MODELPATH=models/gpt-neo-$MODEL/${EXPERIMENT}/iter_$iteration
109 | SOLVE_PATH=$MODELPATH/"epoch_"$EPOCHS_MAX/"solve_"$PASS_AT_K
110 | echo babysit.sh SOLVE_PATH $SOLVE_PATH
111 |
112 | for (( poll=0; poll<500; poll++ ))
113 | do
114 | if (( $TEST_LOCAL \> 0 ))
115 | then
116 | count=`ls -lt ../$SOLVE_PATH | grep json | wc -l`
117 | else
118 | count=`amlt sto list $SOLVE_PATH | grep json | wc -l`
119 | fi
120 |
121 | echo "fine_tune wait - Iteration: $iteration, Poll: $poll, Count: $count"
122 |
123 | if (( $count \>= 1 ))
124 | then
125 | echo "Finished fine_tune iteration $iteration after $poll polls"
126 | break
127 | fi
128 | sleep 3m
129 | done
130 |
131 | done
132 |
133 | # Pull all the results into 1 log file to look at more easily
134 |
135 | if [[ -z "${AMLT_DATA_DIR}" ]];
136 | then
137 | # running locally on torch2020 we don't have AMLT environment variables defined, so we need to set them up
138 | AMLT_DATA_DIR=../data
139 | else
140 | # On remote we don't have access to the log files - maybe we could use amlt sto download to produce this summary below?
141 | exit
142 | fi
143 |
144 | BASE_MODEL_PATH=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL
145 | LOG_FILE=$BASE_MODEL_PATH/$EXPERIMENT/"solve_"$PASS_AT_K".txt"
146 | echo solve LOG_FILE for babysit.sh is $LOG_FILE
147 | rm $LOG_FILE
148 |
149 | for (( iteration=$ITER_START; iteration<$ITER_MAX; iteration++ ))
150 | do
151 | for (( epochs=$EPOCHS_START; epochs<=$EPOCHS_MAX; epochs++ ))
152 | do
153 | EPOCH_NAME="epoch_"$epochs
154 | STEP_PATH=$BASE_MODEL_PATH/$EXPERIMENT/iter_$iteration/$EPOCH_NAME
155 | MODEL_PATH=$STEP_PATH/finetuned
156 | echo iteration $iteration epoch $epochs >> $LOG_FILE
157 | head -$LINE_LOG_K $STEP_PATH/"solve_"$PASS_AT_K/results.json >> $LOG_FILE
158 | done
159 | done
160 |
161 | cat $LOG_FILE
162 |
--------------------------------------------------------------------------------
/ICLR2023/src/ds_config_gptneo.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "optimizer": {
11 | "type": "AdamW",
12 | "params": {
13 | "lr": "auto",
14 | "betas": "auto",
15 | "eps": "auto",
16 | "weight_decay": "auto"
17 | }
18 | },
19 | "scheduler": {
20 | "type": "WarmupLR",
21 | "params": {
22 | "warmup_min_lr": "auto",
23 | "warmup_max_lr": "auto",
24 | "warmup_num_steps": "auto"
25 | }
26 | },
27 | "zero_optimization": {
28 | "stage": 2,
29 | "allgather_partitions": true,
30 | "allgather_bucket_size": 2e8,
31 | "overlap_comm": true,
32 | "reduce_scatter": true,
33 | "reduce_bucket_size": 2e8,
34 | "contiguous_gradients": true,
35 | "cpu_offload": true
36 | },
37 | "gradient_accumulation_steps": "auto",
38 | "gradient_clipping": "auto",
39 | "steps_per_print": 2000,
40 | "train_batch_size": "auto",
41 | "train_micro_batch_size_per_gpu": "auto",
42 | "wall_clock_breakdown": false
43 | }
44 |
--------------------------------------------------------------------------------
/ICLR2023/src/fine_tune.py:
--------------------------------------------------------------------------------
1 | from strictfire import StrictFire as Fire  # aborts early on invalid arguments
2 | import os
3 | import csv
4 | import subprocess
5 | import shlex
6 | import random
7 | import numpy as np
8 | import torch
9 | import utils
10 |
11 | def fine_tune(
12 |     train_txt="../data/generated_sol_100.txt",
13 |     output_dir="../outputs/",
14 |     subdir="out",
15 |     model_path="EleutherAI/gpt-neo-2.7B",
16 |     gpu=0,
17 |     num_gpus=1,
18 |     epochs=4,
19 |     seed=0,
20 | ):
21 |     """
22 |     Fine-tune the model on the puzzles in the train_txt file and save the results to output_dir/subdir
23 |
24 |     train_txt: the (possibly gzipped) file containing the text to fine-tune on (default: ../data/generated_sol_100.txt)
25 |     subdir: the subdirectory to save the results to (default "out")
26 |     model_path: the path to the model to fine-tune (default "EleutherAI/gpt-neo-2.7B")
27 |     gpu: which GPU(s) to use, e.g.: 0,1 (default 0)
28 |     epochs: how many epochs to train for (default 4)
29 |     seed: the random seed to use, not sure if this affects fine-tuning (default 0)
30 |     """
31 |     random.seed(seed)
32 |     np.random.seed(seed)
33 |     torch.manual_seed(seed)
34 |
35 |     # create output dir if necessary
36 |     output_path = os.path.join(output_dir, subdir)
37 |     if not os.path.exists(output_path):
38 |         os.makedirs(output_path)
39 |
40 |
41 |     text = utils.load_text_file(train_txt)  # decompresses if ends in .gz
42 |     tokenizer = utils.load_tokenizer(model_path)
43 |     num_toks = utils.num_tokens(text, tokenizer, verbose=True)
44 |     assert num_toks > 1024, "Not enough tokens in text to fine tune"
45 |
46 |     # create a one-column csv holding the full training text
47 |     train_file = os.path.join(output_path, "train.csv")
48 |     with open(train_file, mode="w", encoding="utf-8") as csv_file:
49 |         fieldnames = ["text"]
50 |         writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
51 |         writer.writeheader()
52 |         writer.writerow({"text": text})
53 |
54 |     output_path_finetuned = os.path.join(output_path, "finetuned")
55 |
56 |     # keep gradient_accumulation_steps at 1 bc setting it to 2 effectively doubles the batch
57 |     # size which gets tricky when batch sizes are small (ft_tokens will no longer be accurate)
58 |     gradient_accumulation_steps = 1
59 |     per_device_train_batch_size = 4
60 |
61 |     cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
62 |     if len(cuda_visible_devices):
63 |         print("os.environ(CUDA_VISIBLE_DEVICES)", cuda_visible_devices)
64 |         del os.environ["CUDA_VISIBLE_DEVICES"]
65 |         print("os.environ(CUDA_VISIBLE_DEVICES)", os.environ.get("CUDA_VISIBLE_DEVICES", ""))
66 |
67 |     master_port = 29600  # During training deepspeed uses a port to synchronize. 2 jobs need different ports to run in parallel
68 |     if type(gpu) in [list, tuple]:
69 |         master_port += gpu[0]
70 |         gpu = ",".join([str(g) for g in gpu])
71 |     else:
72 |         master_port += gpu
73 |
74 |     gpu_string = f'--include=localhost:{gpu}'
75 |
76 |     if num_gpus > 1:
77 |         gpu_string = f"--num_nodes=1 --num_gpus={num_gpus}"
78 |     # If gpu is passed in as negative - it's the count of GPUs to use - a bit of a hack
79 |     if gpu < 0:
80 |         num_gpus = abs(gpu)
81 |         gpu_string = f"--num_nodes=1 --num_gpus={num_gpus}"
82 |
83 |     print("gpu_string", gpu_string)
84 |
85 |     cmd = " ".join(
86 |         [
87 |             "deepspeed",
88 |             f"--master_port={master_port}",
89 |             gpu_string,
90 |             # f'--include=localhost:{gpu}',
91 |             # "--num_nodes=1",
92 |             # f"--num_gpus={num_gpus}",
93 |             "neo_train.py",
94 |             f"--model_name_or_path={model_path}",
95 |             f"--train_file={train_file}",
96 |             f"--output_dir={output_path_finetuned}",
97 |             "--overwrite_output_dir",
98 |             "--ignore_data_skip",
99 |             "--deepspeed",
100 |             "ds_config_gptneo.json",
101 |             f"--save_strategy=no",  # ATK remove checkpointing for large datasets
102 |             # pretty sure this is just dataset cache
103 |             "--overwrite_cache",
104 |             # logging frequency
105 |             "--logging_steps=5",
106 |             "--do_train",
107 |             "--report_to none",  # turns off report_to WANDB for instance
108 |             "--fp16",
109 |             f"--num_train_epochs={epochs}",
110 |             # max_steps would override num_train_epochs if set to a positive value (the total number of gradient steps)
111 |             f"--per_device_train_batch_size={per_device_train_batch_size}",
112 |             "--use_fast_tokenizer=False",
113 |             f"--gradient_accumulation_steps={gradient_accumulation_steps}",
114 |             "--learning_rate=5e-06",
115 |             # linear increase from 0 up to learning_rate over warmup_steps, then the LR schedule (itself linear decreasing until max_steps) takes over
116 |             "--warmup_steps=10",
117 |         ]
118 |     )
119 |
120 |     utils.info(f"running command: {cmd}")
121 |     print(f"Command to run:{cmd}")  # differs from what utils.info prints because utils.info truncates it
122 |     # exit()
123 |     res = subprocess.run(shlex.split(cmd), check=True)
124 |     utils.info(str(res))
125 |
126 |
127 | if __name__ == "__main__":
128 |     Fire(fine_tune)
129 |
--------------------------------------------------------------------------------
/ICLR2023/src/fine_tune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
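   | # Usage: ./fine_tune.sh GPU FULLNAME   where FULLNAME is EXPERIMENT---ITERATION, e.g. "125M_PAPER---0" (as built by babysit.sh)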
2 | echo fine_tune.sh args: $# $0 $1 $2 $3 $4
3 | # Grab the configuration variables
4 | . babysit.sh
5 |
6 | # On AMLT machines we don't specify which GPU to use
7 | # GPU="-1"
8 | if [[ -z "${AMLT_DATA_DIR}" ]]; then
9 | # running locally on torch2020 we don't have AMLT environment variables defined, so we need to set them up
10 | AMLT_DATA_DIR=../data
11 | # On torch2020 we do specify which GPU to use
12 | # GPU="0"
13 | fi
14 |
15 | # assert that there are at least 2 arguments
16 | if (( $# \< 2 ))
17 | then
18 | echo "Usage: $0 GPU FULLNAME"
19 | exit
20 | fi
21 |
22 | GPU=$1
23 | FULLNAME=$2
24 |
25 | # split the FULLNAME string on --- into experiment name and iteration
26 | # e.g. "125M_RL---0" --> EXPERIMENT="125M_RL", ITERATION="0"
27 | SPLIT=(${FULLNAME//---/ })
28 | EXPERIMENT=${SPLIT[0]}
29 | ITERATION=${SPLIT[1]}
30 | OUTPATH=$AMLT_DATA_DIR/$EXPERIMENT/iter_$ITERATION
31 |
32 | echo GPU $GPU
33 | echo EXPERIMENT $EXPERIMENT
34 | echo ITERATION $ITERATION
35 | echo OUTPATH $OUTPATH
36 |
37 | # GPU_SOLVE is the GPU we want solve to use. Solve currently only uses 1 GPU - it would be great to make it use more when they are available.
38 | # if GPU is negative - it tells fine_tune how many GPUs to use on the cluster - and we need to set GPU for solve to 0 on the cluster
39 | # if GPU is positive - we are running locally on torch2020 - and we need to leave the GPU set as given
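   | # e.g. GPU=-4 means fine_tune.py trains on 4 cluster GPUs, while solve.py falls back to GPU 0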
40 | GPU_SOLVE=$GPU
41 | if (( $GPU \< 0 ))
42 | then
43 | GPU_SOLVE=0
44 | fi
45 | echo GPU_SOLVE $GPU_SOLVE
46 |
47 | python preprocess.py $OUTPATH
48 |
49 | TRN_FILE=$OUTPATH/gen_ps_filtered.txt
50 | echo TRN_FILE $TRN_FILE
51 | TST_FILE=$AMLT_DATA_DIR/test_228.json
52 |
53 | BASE_MODEL_PATH=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL
54 | # 125M is copied locally to start
55 | MODEL_PATH=$BASE_MODEL_PATH
56 | MODEL_PATH=EleutherAI/gpt-neo-125M # !!! Just for paper release
57 | # 13B is off in the cloud to start
58 | if [[ "$MODEL" == "13B" ]]; then
59 | MODEL_PATH=EleutherAI/gpt-neo-1.3B
60 | fi
61 | # 27B is off in the cloud to start
62 | if [[ "$MODEL" == "27B" ]]; then
63 | MODEL_PATH=EleutherAI/gpt-neo-2.7B
64 | fi
65 |
66 | echo MODEL MODEL_PATH $MODEL $MODEL_PATH
67 |
68 | # Training incremental means use the previous iterations trained model, and just the additional iteration's new data to fine_tune on.
69 | # Otherwise use the base model - and retrain from scratch on all the data from all previous iterations.
70 | # They are roughly equivalent - except that from-scratch picks up any extra data that was generated and mixes all the iterations' data together - but it is slower.
71 | if (( $TRAIN_INCREMENTAL \> 0 ))
72 | then
73 | PREV_ITERATION=$((ITERATION-1))
74 | echo $PREV_ITERATION
75 | TEST=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL/$EXPERIMENT/iter_$PREV_ITERATION/epoch_$EPOCHS_MAX/finetuned
76 |
77 | if [ -a $TEST ] # exists
78 | then
79 | MODEL_PATH=$TEST
80 | echo fine_tune.sh using previous iteration model
81 | fi
82 | fi
83 |
84 | echo "fine_tune.sh starting from NEO model at: ${MODEL_PATH}"
85 |
86 | # Pull all the results into 1 log file to look at more easily
87 | LOG_FILE=$BASE_MODEL_PATH/$EXPERIMENT/iter_$ITERATION/"solve.txt"
88 | echo solve LOG_FILE for fine_tune.sh is $LOG_FILE
89 | rm $LOG_FILE
90 |
91 | for (( epochs=$EPOCHS_START; epochs<=$EPOCHS_MAX; epochs++ ))
92 | do
93 | EPOCH_NAME="epoch_"$epochs
94 | EPOCHS_STEP=$(($EPOCHS_PER_STEP * $epochs))
95 | python fine_tune.py -train_txt=$TRN_FILE -gpu=$GPU -output_dir=$BASE_MODEL_PATH/$EXPERIMENT/iter_$ITERATION -subdir=$EPOCH_NAME -model_path=$MODEL_PATH -epochs=$EPOCHS_STEP
96 | # measure the finetuned model's accuracy
97 | STEP_PATH=$BASE_MODEL_PATH/$EXPERIMENT/iter_$ITERATION/$EPOCH_NAME
98 | MODEL_PATH=$STEP_PATH/finetuned
99 | python solve.py -prefix=$AMLT_DATA_DIR/train_prefix.txt -attempts=$PASS_AT_K -model_path=$MODEL_PATH -gpu=$GPU_SOLVE -fixed_temp=0.8 -out=$STEP_PATH/"solve_"$PASS_AT_K"/" -puzzles=$TST_FILE
100 | head -$LINE_LOG_K $STEP_PATH/"solve_"$PASS_AT_K/results.json >> $LOG_FILE
101 | done
102 |
103 | cat $LOG_FILE
--------------------------------------------------------------------------------
/ICLR2023/src/fine_tune1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This is for finetuning a model on 1 dataset only
3 | echo fine_tune1.sh args: $# $0 $1 $2 $3 $4
4 |
5 | # All Experiment Settings - constant through the experiment run
6 | GPU=0 # which GPU to use
7 | MODEL="125M" # MODEL is the size of the model: 125M, 13B, 27B
8 | EXPERIMENT=$MODEL"_PAPER1" # Name of Experiment directory under data/* and models/base-model/* to store results
9 | ITERATION=0 # Iteration directory index; also used as the random seed for finetuning
10 | EPOCHS_START=1 # inclusive index of epochs to start processing at - could continue prev run by starting at prev EPOCHS_MAX+1 - 0th epoch is the default model so epoch starts at 1
11 | EPOCHS_MAX=10 # inclusive index of epochs to stop processing at
12 | EPOCHS_PER_STEP=1 # How many EPOCHS through the data to do in each step
13 | PASS_AT_K=100 # PASS_AT_K says do K trials to solve to compute Pass@K
14 | LINE_LOG_K=11 # LINE_LOG_K is how many lines of results from solve have results for saving
15 |
16 | # On AMLT machines we don't specify which GPU to use
17 | if [[ -z "${AMLT_DATA_DIR}" ]]; then
18 | # running locally on torch2020 we don't have AMLT environment variables defined, so we need to set them up
19 | AMLT_DATA_DIR=../data
20 | fi
21 |
22 | if (( $# \>= 1 ))
23 | then
24 | GPU=$1
25 | fi
26 |
27 | echo GPU $GPU
28 | echo EXPERIMENT $EXPERIMENT
29 | echo ITERATION $ITERATION
30 |
31 | TRN_FILE=$AMLT_DATA_DIR/generated_sol_950k.txt
32 | echo TRN_FILE $TRN_FILE
33 | TST_FILE=$AMLT_DATA_DIR/test_228.json
34 |
35 | # GPU_SOLVE is the GPU we want solve to use. Solve currently only uses 1 GPU - it would be great to make it use more when they are available.
36 | # if GPU is negative - that tells fine_tune how many GPU to use on cluster - and we need to set GPU for solve to 0 on cluster
37 | # if GPU is positive - we are running locally on torch2020 - and we need to leave the GPU set properly
38 | GPU_SOLVE=$GPU
39 | if (( $GPU \< 0 ))
40 | then
41 | GPU_SOLVE=0
42 | fi
43 | echo GPU_SOLVE $GPU_SOLVE
44 |
45 | # measure the base model's accuracy - don't really need to do this very often - it doesn't change
46 |
47 | BASE_MODEL_PATH=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL
48 | # 125M is copied locally to start
49 | MODEL_PATH=$BASE_MODEL_PATH
50 | MODEL_PATH=EleutherAI/gpt-neo-125M # !!! Just for paper release
51 | # 13B is off in the cloud to start
52 | if [[ "$MODEL" == "13B" ]]; then
53 | MODEL_PATH=EleutherAI/gpt-neo-1.3B
54 | fi
55 | # 27B is off in the cloud to start
56 | if [[ "$MODEL" == "27B" ]]; then
57 | MODEL_PATH=EleutherAI/gpt-neo-2.7B
58 | fi
59 |
60 | echo MODEL MODEL_PATH $MODEL $MODEL_PATH
61 |
62 | # Training incrementally means starting from the previous epoch's finetuned model
63 | # Otherwise use the base model and retrain from scratch
64 | if (( $EPOCHS_START \> 1 ))
65 | then
66 | PREV_EPOCH=$((EPOCHS_START-1))
67 | echo $PREV_EPOCH
68 | TEST=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL/$EXPERIMENT/iter_$ITERATION/epoch_$PREV_EPOCH/finetuned
69 |
70 | if [ -a $TEST ] # exists
71 | then
72 | MODEL_PATH=$TEST
73 | echo fine_tune1.sh using previous epoch model
74 | fi
75 | fi
76 |
77 | echo "fine_tune.sh starting from NEO model at: ${MODEL_PATH}"
78 |
79 | # Pull all the results into 1 log file to look at more easily
80 | LOG_FILE=$BASE_MODEL_PATH/$EXPERIMENT/iter_$ITERATION/"solve.txt"
81 | echo solve LOG_FILE for fine_tune.sh is $LOG_FILE
82 | rm $LOG_FILE
83 |
84 | for (( epochs=$EPOCHS_START; epochs<=$EPOCHS_MAX; epochs++ ))
85 | do
86 | EPOCH_NAME="epoch_"$epochs
87 | EPOCHS_STEP=$(($EPOCHS_PER_STEP * $epochs))
88 | python fine_tune.py -train_txt=$TRN_FILE -gpu=$GPU -output_dir=$BASE_MODEL_PATH/$EXPERIMENT/iter_$ITERATION -subdir=$EPOCH_NAME -model_path=$MODEL_PATH -epochs=$EPOCHS_STEP -seed=$ITERATION
89 | # measure the finetuned model's accuracy
90 | STEP_PATH=$BASE_MODEL_PATH/$EXPERIMENT/iter_$ITERATION/$EPOCH_NAME
91 | MODEL_PATH=$STEP_PATH/finetuned
92 | python solve.py -prefix=$AMLT_DATA_DIR/train_prefix.txt -attempts=$PASS_AT_K -model_path=$MODEL_PATH -gpu=$GPU_SOLVE -fixed_temp=0.8 -out=$STEP_PATH/"solve_"$PASS_AT_K"/" -puzzles=$TST_FILE -seed=$ITERATION -batch_size=256
93 | head -$LINE_LOG_K $STEP_PATH/"solve_"$PASS_AT_K/results.json >> $LOG_FILE
94 | done
95 |
96 | cat $LOG_FILE
--------------------------------------------------------------------------------
/ICLR2023/src/gen.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
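   | # Usage: ./gen.sh [GPU] [PUZZLE_CNT] [SOLUTION_CNT] [FULLNAME] [RANDOM_SEED]   (defaults below; babysit.sh calls e.g. ./gen.sh $GPU 256000 100 $FULLNAME -1)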
2 | # Grab the configuration variables
3 | . babysit.sh
4 |
5 | if [[ -z "${AMLT_DATA_DIR}" ]]; then
6 | # running locally on torch2020 we don't have AMLT environment variables defined, so set them up
7 | AMLT_DATA_DIR="../data"
8 | fi
9 |
10 | echo RANK is:
11 | echo $RANK
12 |
13 | if [[ -z "${RANK}" ]]; then
14 | # running locally on torch2020 we don't have the RANK environment variable defined, so set it up
15 | RANK=0
16 | fi
17 |
18 | GPU=$RANK
19 | PUZZLE_CNT=32
20 | SOLUTION_CNT=32
21 | FULLNAME="125M_RL_TEST---0"
22 |
23 | echo $# $0 $1 $2 $3 $4
24 | if (( $# \>= 1 ))
25 | then
26 | GPU=$1
27 | fi
28 |
29 | echo $RANK
30 | echo $GPU
31 |
32 | if (( $# \>= 2 ))
33 | then
34 | PUZZLE_CNT=$2
35 | fi
36 |
37 | if (( $# \>= 3 ))
38 | then
39 | SOLUTION_CNT=$3
40 | fi
41 |
42 | if (( $# \>= 4 ))
43 | then
44 | FULLNAME=$4
45 |
46 | fi
47 |
48 | RANDOM_SEED=-1
49 |
50 | if (( $# \>= 5 ))
51 | then
52 | RANDOM_SEED=$5
53 | echo "Random seed is $RANDOM_SEED"
54 | fi
55 |
56 | SPLIT=(${FULLNAME//---/ })
57 | EXPERIMENT=${SPLIT[0]}
58 | ITERATION=${SPLIT[1]}
59 | OUTPATH=$AMLT_DATA_DIR/$EXPERIMENT/iter_$ITERATION
60 |
61 | echo GPU $GPU
62 | echo EXPERIMENT $EXPERIMENT
63 | echo ITERATION $ITERATION
64 | echo OUTPATH $OUTPATH
65 |
66 | BASE_MODEL_PATH=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL
67 | # 125M is copied locally to start
68 | MODEL_PATH=$BASE_MODEL_PATH
69 | MODEL_PATH=EleutherAI/gpt-neo-125M # !!! Just for paper release
70 | # 13B is off in the cloud to start
71 | if [[ "$MODEL" == "13B" ]]; then
72 | MODEL_PATH=EleutherAI/gpt-neo-1.3B
73 | fi
74 | # 27B is off in the cloud to start
75 | if [[ "$MODEL" == "27B" ]]; then
76 | MODEL_PATH=EleutherAI/gpt-neo-2.7B
77 | fi
78 |
79 | echo MODEL MODEL_PATH $MODEL $MODEL_PATH
80 |
81 | PREV_ITERATION=$((ITERATION-1))
82 | echo $PREV_ITERATION
83 | TEST=$AMLT_DATA_DIR/../models/gpt-neo-$MODEL/$EXPERIMENT/iter_$PREV_ITERATION/epoch_$EPOCHS_MAX/finetuned
84 |
85 | if [ -a $TEST ] # exists
86 | then
87 | MODEL_PATH=$TEST
88 | echo gen.sh using previous iteration model
89 | fi
90 |
91 | echo "gen.sh using NEO model at: ${MODEL_PATH}"
92 |
93 | python gen.py -out="$OUTPATH" -n=$PUZZLE_CNT -seed=$RANDOM_SEED -gpu=$GPU -train=$AMLT_DATA_DIR/155_train.json -prefix=$AMLT_DATA_DIR/train_prefix.txt -model_path=$MODEL_PATH -attempts=$SOLUTION_CNT -only_good=True
94 |
--------------------------------------------------------------------------------
/ICLR2023/src/judge.py:
--------------------------------------------------------------------------------
1 | from utils import load_json
2 | from pebble import ProcessPool
3 | import multiprocessing as mp
4 | from concurrent.futures import TimeoutError
5 | from typing import List, Set, Tuple, Dict
6 |
7 | import utils
8 | import sys
9 | import re
10 | from copy import deepcopy
11 |
12 | sys.setrecursionlimit(5000)
13 |
14 |
15 | def no_print(*_args, **_kwargs):
16 |     pass
17 |
18 |
19 | def run_judge(judge, f, tests):
20 |     answer_type = list(judge.__annotations__.values())[0]
21 |     for x in tests:
22 |         y = f(**deepcopy(x))  # so f cannot cheat and change the input x
23 |         if not utils.type_check(y, answer_type):
24 |             raise TypeError
25 |         assert judge(y, **x) is True, f"{f} failed on test {x}"
26 |
27 |
   | # Restricted globals for exec'ing generated code: dangerous builtins are stubbed out or disabled.
28 | _ENV = dict(
29 |     List=List,
30 |     Set=Set,
31 |     Tuple=Tuple,
32 |     Dict=Dict,
33 |     type_check=utils.type_check,
34 |     run_judge=run_judge,
35 |     test_puzzle=utils.test_puzzle,
36 |     os=None,
37 |     sys=None,
38 |     input=None,
39 |     open=None,
40 |     print=no_print,
41 |     compile=None,
42 |     copyright=None,
43 | )
44 |
   | # Substrings that are always rejected, and the only modules generated code may import.
45 | _UNSAFE = ["builtin", "__class", "open("]
46 | _SAFE_IMPORTS = {"collections", "copy", "hashlib", "math", "random", "re", "string", "typing"}
47 |
48 | MAX_WORKERS = mp.cpu_count() // 2
49 |
50 |
51 |
52 | def unsafe_imports(code):
53 |     """Check if code imports any unsafe modules.
54 |
55 |     Args:
56 |         code (str): The code to check.
57 |
58 |     Returns:
59 |         bool: True if code imports unsafe modules.
60 |     """
61 |     if "import" not in code:
62 |         return False
63 |     for line in code.split("\n"):
64 |         if "import" in line:
65 |             match = re.search(r"^\s*from\s+([\w\.]+)\s+import\s", line)
66 |             if match:
67 |                 modules = [match.group(1)]
68 |             else:
69 |                 match = re.search(r"^\s*import\s+(.+)", line)
70 |                 if match:
71 |                     modules = match.group(1).split(",")
72 |                 else:
73 |                     return True  # a line mentions import but matches neither form: reject conservatively
74 |             if any(m.strip() not in _SAFE_IMPORTS for m in modules):
75 |                 return True
76 |     return False
77 |
78 |
79 | def _judge(code_env):
80 |     code, env = code_env
81 |     if unsafe_imports(code) or any(u in code for u in _UNSAFE):
82 |         return False, Exception("unsafe code"), code
83 |     try:
84 |         exec(code, env.copy())
85 |         return True, None, code
86 |     except Exception as e:
87 |         return False, e, code
88 |
89 |
90 | def judge_parallel(src_codes, timeout, max_workers=MAX_WORKERS, env=_ENV):
91 |     codes = utils.dedup(src_codes)
92 |     utils.info(
93 |         f"Judging {len(src_codes):,} codes ({len(src_codes)-len(codes):,} duplicates) with {max_workers} workers"
94 |     )
95 |     successes = set()
96 |
97 |     # print("writing to file for debugging before judging")
98 |     # from train import save_json
99 |     #
100 |     # save_json(new_codes, "results/tmp/new_codes.json")
101 |     utils.silence_std_err(True)
102 |     with ProcessPool(max_workers=max_workers) as pool:
103 |         future = pool.map(_judge, [(code, env) for code in codes], timeout=timeout)
104 |
105 |         results = future.result()
106 |         i = 0
107 |         while True:
108 |             try:
109 |                 success, exc, code = next(results)
110 |                 if success:
111 |                     successes.add(codes[i])
112 |             except StopIteration:
113 |                 break
114 |             except (TimeoutError, Exception) as error:
115 |                 pass
116 |             assert i < len(codes)
117 |             i += 1
118 |         assert i == len(codes)
119 |     utils.silence_std_err(False)
120 |     return [code in successes for code in src_codes]
121 |
122 |
123 |
124 |
125 | def test():
126 |     import time
127 |
128 |     tests = [
129 |         ("def sol(a: int=10000200001):\n return (list(range(3 * a))[str(a)])\nx = sol()", False),
130 |         ("print(12)", True),
131 |         ("while True: pass", False),
132 |         ("def sol(): sol()\nsol()", False),
133 |         ("2+2", True),
134 |         ("""1+1""", True),
135 |         ("""assert False,'cats'""", False),
136 |         ("""assert False""", False),
137 |         ("""1[2]""", False),
138 |         ("""1/0""", False),
139 |         (
140 |             """while True:
141 |     pass""",
142 |             False,
143 |         ),
144 |         (
145 |             """for i in range(10**4):
146 |     pass""",
147 |             True,
148 |         ),
149 |         ("print('hello')", True),
150 |     ]
151 |
152 |     scores = {}
153 |     tests2 = tests
154 |     pad = " "
155 |     for _ in range(6):
156 |         print(f"n={len(tests2)} timing test" + "*" * 20)
157 |         times = []
158 |         for max_workers in [4, 16, 32, 64, 128]:
159 |             time0 = time.perf_counter()
160 |             res = judge_parallel([test for test, r in tests2], timeout=1, max_workers=max_workers)
161 |             for (test, expected), r in zip(tests2, res):
162 |                 assert expected == r, f"Failed expected {expected}, got {r} for {test}"
163 |
164 |             times.append((max_workers, time.perf_counter() - time0))
165 |
166 |         scores[len(tests2)] = times
167 |         tests2 = tests2 + [(t + pad, r) for (t, r) in tests2]
168 |         pad = pad * 2
169 |     print("mp.cpu_count() =", mp.cpu_count())
170 |
171 |     for n, times in scores.items():
172 |         print(n, "tests, [(max_workers, time)] =", times)
173 |
174 |
175 | if __name__ == "__main__":
176 |     test()
177 |
--------------------------------------------------------------------------------
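A minimal usage sketch for the `judge_parallel` helper above (illustrative only; it assumes this module is importable as `judge`, and the candidate codes and timeout are made up):

```python
# Hypothetical usage of judge.judge_parallel: one boolean is returned per input
# code, in the original order (duplicate codes are only executed once).
import judge

codes = [
    "assert 1 + 1 == 2",   # runs cleanly -> judged True
    "import os",           # unsafe import -> judged False without executing
    "while True: pass",    # exceeds the timeout -> judged False
]
results = judge.judge_parallel(codes, timeout=1, max_workers=4)
for code, ok in zip(codes, results):
    print(ok, repr(code))
```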
/ICLR2023/src/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import utils
4 | import glob
5 | import json
6 | from typing import List
7 | from strictfire import StrictFire as Fire # aborts early on invalid arguments
8 |
9 | class WeirdInputsException(Exception):
10 | pass
11 |
12 | def get_inputs(sat: str):
13 | """Extacts arguments past the first from a function string
14 | def f(a: Dict[int, str], b=12):
15 | test
16 |
17 | should give 'b=12'
18 | """
19 | sat = sat.replace(" -> bool", "")
20 | first_line = sat.split("\n")[0].strip()
21 | if not first_line.endswith("):") and "#" in first_line:
22 | if "):" in first_line:
23 | n = first_line.index("):")
24 | if "#" in first_line[n:]:
25 | first_line = first_line[:n + first_line[n:].index("#")].strip()
26 | else:
27 | first_line = "" # raises exception below
28 | if not (first_line.endswith("):") and first_line.startswith("def")):
29 | raise WeirdInputsException("Weird puzzle, cannot extract inputs", json.dumps(sat))
30 | arg_str = first_line[first_line.index("("):-len("):")]
31 | depth = 0
32 | for i, c in enumerate(arg_str):
33 | if c == "," and depth == 0:
34 | return arg_str[i + 1:].strip()
35 | elif c == "[":
36 | depth += 1
37 | elif c == "]":
38 | depth -= 1
39 | return ""
40 |
41 | def main(
42 | path,
43 | filtered_name="gen_ps_filtered.txt",
44 | unfiltered_name=None, # "gen_ps_unfiltered.txt",
45 | max_sols_per_puzzle=8,
46 | seed=0):
47 | """
48 | Merge the puzzles from the given json input files. Examples:
49 | python preprocess.py -unfiltered_name=gen_ps_unfiltered.txt -- ~/aicoder/data/gen_125M_RL/*.json
50 |
51 | path: path to search for json files
52 | filtered_name: filename to write the filtered puzzles to (default: gen_ps_filtered.txt)
53 | unfiltered_name: filename to write the unfiltered puzzles to (optional)
54 | max_sols_per_puzzle: maximum number of solutions per puzzle (default 8)
55 | seed: random seed (default 0) for reproducibility
56 | Input json files are discovered by globbing for *.json files relative to path.
57 | """
58 | random.seed(seed)  # use the seed argument so the shuffles below are reproducible
59 | # Make the path so enumeration off that path works, even if it doesn't exist yet
60 | filtered_path = os.path.join(path, filtered_name)
61 | os.makedirs(os.path.dirname(filtered_path), exist_ok=True)
62 |
63 | codes = []
64 | all_codes = []
65 |
66 | # grab all the iter_* data for just this experiment
67 | gen_paths = [os.path.join(path, "../*/*.json")]
68 |
69 | # grab just the data for this iter_# for this experiment
70 | # gen_paths = [os.path.join(path, "*.json")]
71 |
72 | for gen_path in gen_paths:
73 | for filename in sorted(glob.glob(gen_path)):
74 | print("preprocess filename:", filename)
75 | js = utils.load_json(filename)
76 | for f, successes, failures in js:
77 | for body in sorted(utils.dedup(successes), key=len)[:max_sols_per_puzzle]:
78 |
79 | try:
80 | g = f"def g({get_inputs(f)}):{body}".strip("\\").strip()
81 | codes.append(f + "\n\n" + g + "\n\n" + "assert f(g())\n\n")
82 | except WeirdInputsException:
83 | print("failed to create g")
84 | pass
85 | print(f"{len(codes):,}/{len(all_codes):,} puzzles of preprocessing {filename}")
86 |
87 | print("len(codes)", len(codes))
88 | codes = utils.dedup(codes)
89 | print("len(codes) after dedup", len(codes))
90 |
91 | random.shuffle(codes)
92 | random.shuffle(all_codes)
93 |
94 | # Make it the same number of examples as we got from codex
95 | codes = codes[:950511]
96 | print("len(codes) after truncation", len(codes))
97 |
98 | code = "".join(codes)
99 |
100 | utils.save_text_file(code, filtered_path)
101 | print(f"Wrote filtered results to {filtered_path}")
102 |
103 | assert unfiltered_name is None, "Not supported now, go back to earlier version"
104 | if unfiltered_name:
105 | unfiltered_path = os.path.join(path, unfiltered_name)
106 | utils.save_text_file("".join(all_codes), unfiltered_path)
107 | print(f"Wrote unfiltered results to {unfiltered_path}")
108 |
109 |
110 | if __name__ == "__main__":
111 | Fire(main)
112 |
--------------------------------------------------------------------------------
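For concreteness, a small illustration of what `get_inputs` extracts (hypothetical puzzle strings; assumes `preprocess.py` is importable):

```python
# get_inputs returns everything after the first argument of a puzzle's def line,
# skipping commas nested inside type brackets like List[...] or Dict[...].
from preprocess import get_inputs

sat = "def sat(s: str, a: List[int]=[1, 2], b=12):\n    return True"
assert get_inputs(sat) == "a: List[int]=[1, 2], b=12"

sat2 = "def f(a: Dict[int, str], b=12):\n    return True"
assert get_inputs(sat2) == "b=12"  # the comma inside Dict[...] is ignored
```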
/ICLR2023/src/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | orderedset
3 | numpy
4 | astor
5 | scikit-learn
6 | fire
7 | strictfire
8 | pebble
9 | deepspeed == 0.6.1
10 | transformers == 4.30.0
--------------------------------------------------------------------------------
/ICLR2023/src/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import inspect
4 | import io
5 | import os
6 | import sys
7 | import time
8 | from transformers import AutoTokenizer
9 |
10 |
11 | os.environ["WANDB_DISABLED"] = "true"
12 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
13 | my_path = os.path.dirname(__file__)
14 |
15 |
16 | def load_tokenizer(model_path):
17 | tokenizer = AutoTokenizer.from_pretrained(model_path)
18 | tokenizer.padding_side = "left"
19 | tokenizer.pad_token = tokenizer.eos_token
20 | return tokenizer
21 |
22 |
23 | def num_tokens(s: str, tokenizer, verbose=False):
24 |
25 | start_time = time.time()
26 | if verbose:
27 | info(f"Tokenizing {pretty_int(len(s))} chars ({pretty_int(len(s.splitlines()))} lines)")
28 | # ans = _tokenizer(s, return_tensors="pt").input_ids.shape[1] # produces annoying warnings
29 | ans = tokenizer(s, return_tensors="pt", max_length=10 + len(s), truncation=True).input_ids.shape[1]
30 |
31 | duration_mins = (time.time() - start_time)/60
32 | if verbose:
33 | info(f"Num tokens: {ans:,} in {duration_mins:.2f} mins")
34 | return ans
35 |
36 |
37 | def create_experiment_outpath(out: str, bSaveCommand=True):
38 | """
39 | Create the output directory and return its name. Also stores the command line in command.sh
40 | Any "<date>" in out is replaced with a timestamp like Jan01-13-45-22
41 | """
42 | output_path = str(out).replace("<date>", time.strftime("%b%d-%H-%M-%S"))
43 | os.makedirs(output_path, exist_ok=True) # ran into error due to non-atomic check
44 | if bSaveCommand:
45 | save_text_file(' '.join([sys.executable] + sys.argv) + "\n", f"{output_path}/command.sh")
46 | # make command.sh executable:
47 | os.chmod(f"{output_path}/command.sh", 0o755)
48 | return output_path
49 |
50 | def pretty_int(n: int) -> str:
51 | """Converts an integer to a string with commas, with M for millions and B for billions"""
52 | if n > 1_000_000_000:
53 | return f"{n/1_000_000_000:.1f}B"
54 | if n > 1_000_000:
55 | return f"{n/1_000_000:.1f}M"
56 | return f"{n:,}"
57 |
58 |
59 |
60 | def test_puzzle(f, x):
61 | """Checks if x is of the correct type and makes f return True (literally True, not an integer or whatever)
62 |
63 | :param f: Puzzle
64 | :param x: candidate answer
65 | :return:
66 | """
67 | answer_type = list(f.__annotations__.values())[0]
68 | if not type_check(x, answer_type):
69 | raise TypeError
70 | return f(x) is True
71 |
72 |
73 |
74 | def type_check(obj, typ):
75 | """
76 | check if obj is of type `typ` where `typ` is a `typing` module type annotation, eg List[int]
77 | The way we do this to be compatible across versions is we first convert the type to a string.
78 | """
79 |
80 | type_str = str(typ).replace("typing.", "")
81 | if type_str.startswith("<class"):
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | <!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 | <!-- END MICROSOFT SECURITY.MD BLOCK -->
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/generators/__init__.py:
--------------------------------------------------------------------------------
1 | # problems are identified automatically by looking for subclasses (and sub-subclasses) of the Problem class
2 |
3 | from . import study
4 | from . import classic_puzzles
5 | from . import human_eval
6 | from . import codeforces
7 | from . import algebra
8 | from . import basic
9 | from . import chess
10 | from . import compression
11 | from . import conways_game_of_life
12 | from . import games
13 | from . import graphs
14 | from . import ICPC
15 | from . import IMO
16 | from . import lattices
17 | from . import number_theory
18 | from . import probability
19 | from . import trivial_inverse
20 | from . import tutorial
--------------------------------------------------------------------------------
/generators/algebra.py:
--------------------------------------------------------------------------------
1 | """Roots of polynomials"""
2 |
3 | from puzzle_generator import PuzzleGenerator, Tags
4 | from typing import List
5 |
6 |
7 | # See https://github.com/microsoft/PythonProgrammingPuzzles/wiki/How-to-add-a-puzzle to learn about adding puzzles
8 |
9 |
10 | class QuadraticRoot(PuzzleGenerator):
11 | """See [quadratic equations](https://en.wikipedia.org/wiki/Quadratic_formula)"""
12 |
13 | tags = [Tags.math, Tags.famous]
14 |
15 | @staticmethod
16 | def sat(x: float, coeffs=[2.5, 1.3, -0.5]):
17 | """
18 | Find any (real) solution to: a x^2 + b x + c where coeffs = [a, b, c].
19 | For example, since x^2 - 3x + 2 has a root at 1, sat(x = 1., coeffs = [1., -3., 2.]) is True.
20 | """
21 | a, b, c = coeffs
22 | return abs(a * x ** 2 + b * x + c) < 1e-6
23 |
24 | @staticmethod
25 | def sol(coeffs):
26 | a, b, c = coeffs
27 | if a == 0:
28 | ans = -c / b if b != 0 else 0.0
29 | else:
30 | ans = ((-b + (b ** 2 - 4 * a * c) ** 0.5) / (2 * a))
31 | return ans
32 |
33 | @staticmethod
34 | def sol2(coeffs):
35 | a, b, c = coeffs
36 | if a == 0:
37 | ans = -c / b if b != 0 else 0.0
38 | else:
39 | ans = (-b - (b ** 2 - 4 * a * c) ** 0.5) / (2 * a)
40 | return ans
41 |
42 | def gen_random(self):
43 | x, a, b = [self.random.heavy_tail_float() for _ in range(3)]
44 | c = -(a * x ** 2 + b * x) # make sure it has a real-valued solution
45 | coeffs = [a, b, c]
46 | self.add(dict(coeffs=coeffs))
47 |
48 |
49 | class AllQuadraticRoots(PuzzleGenerator):
50 | """See [quadratic equations](https://en.wikipedia.org/wiki/Quadratic_formula)."""
51 |
52 | tags = [Tags.math, Tags.famous]
53 |
54 | @staticmethod
55 | def sat(roots: List[float], coeffs=[1.3, -0.5]):
56 | """Find all (real) solutions to: x^2 + b x + c (i.e., factor into roots), here coeffs = [b, c]"""
57 | b, c = coeffs
58 | r1, r2 = roots
59 | return abs(r1 + r2 + b) + abs(r1 * r2 - c) < 1e-6
60 |
61 | @staticmethod
62 | def sol(coeffs):
63 | b, c = coeffs
64 | delta = (b ** 2 - 4 * c) ** 0.5
65 | return [(-b + delta) / 2, (-b - delta) / 2]
66 |
67 | def gen_random(self):
68 | x, b = [self.random.heavy_tail_float() for _ in range(2)]
69 | c = -(x ** 2 + b * x) # make sure it has a real-valued solution
70 | coeffs = [b, c]
71 | self.add(dict(coeffs=coeffs))
72 |
73 |
74 | class CubicRoot(PuzzleGenerator):
75 | """See [cubic equation](https://en.wikipedia.org/wiki/Cubic_formula)."""
76 |
77 | tags = [Tags.math, Tags.famous]
78 |
79 | @staticmethod
80 | def sat(x: float, coeffs=[2.0, 1.0, 0.0, 8.0]):
81 | """
82 | Find any (real) solution to: a x^3 + b x^2 + c x + d where coeffs = [a, b, c, d]
83 | For example, since (x-1)(x-2)(x-3) = x^3 - 6x^2 + 11x - 6, sat(x = 1., coeffs = [1., -6., 11., -6.]) is True.
84 | """
85 | return abs(sum(c * x ** (3 - i) for i, c in enumerate(coeffs))) < 1e-6
86 |
87 | @staticmethod
88 | def sol(coeffs):
89 | a2, a1, a0 = [c / coeffs[0] for c in coeffs[1:]]
90 | p = (3 * a1 - a2 ** 2) / 3
91 | q = (9 * a1 * a2 - 27 * a0 - 2 * a2 ** 3) / 27
92 | delta = (q ** 2 + 4 * p ** 3 / 27) ** 0.5
93 | omega = (-(-1) ** (1 / 3))
94 | for cube in [(q + delta) / 2, (q - delta) / 2]:
95 | c = cube ** (1 / 3)
96 | for w in [c, c * omega, c * omega.conjugate()]:
97 | if w != 0:
98 | x = complex(w - p / (3 * w) - a2 / 3).real
99 | if abs(sum(c * x ** (3 - i) for i, c in enumerate(coeffs))) < 1e-6:
100 | return x
101 |
102 | def gen_random(self):
103 | x, a, b, c = [self.random.heavy_tail_float() for _ in range(4)]
104 | d = -(a * x ** 3 + b * x ** 2 + c * x) # make sure it has a real-valued solution
105 | coeffs = [a, b, c, d]
106 | if self.sol(coeffs) is not None:
107 | self.add(dict(coeffs=coeffs))
108 |
109 |
110 | class AllCubicRoots(PuzzleGenerator):
111 | """See [cubic equation](https://en.wikipedia.org/wiki/Cubic_formula)."""
112 |
113 | tags = [Tags.math, Tags.famous]
114 |
115 | @staticmethod
116 | def sat(roots: List[float], coeffs=[1.0, -2.0, -1.0]):
117 | """Find all 3 distinct real roots of x^3 + a x^2 + b x + c, i.e., factor into (x-r1)(x-r2)(x-r3).
118 | coeffs = [a, b, c]. For example, since (x-1)(x-2)(x-3) = x^3 - 6x^2 + 11x - 6,
119 | sat(roots = [1., 2., 3.], coeffs = [-6., 11., -6.]) is True.
120 | """
121 | r1, r2, r3 = roots
122 | a, b, c = coeffs
123 | return abs(r1 + r2 + r3 + a) + abs(r1 * r2 + r1 * r3 + r2 * r3 - b) + abs(r1 * r2 * r3 + c) < 1e-6
124 |
125 | @staticmethod
126 | def sol(coeffs):
127 | a, b, c = coeffs
128 | p = (3 * b - a ** 2) / 3
129 | q = (9 * b * a - 27 * c - 2 * a ** 3) / 27
130 | delta = (q ** 2 + 4 * p ** 3 / 27) ** 0.5
131 | omega = (-(-1) ** (1 / 3))
132 | ans = []
133 | for cube in [(q + delta) / 2, (q - delta) / 2]:
134 | v = cube ** (1 / 3)
135 | for w in [v, v * omega, v * omega.conjugate()]:
136 | if w != 0.0:
137 | x = complex(w - p / (3 * w) - a / 3).real
138 | if abs(x ** 3 + a * x ** 2 + b * x + c) < 1e-4:
139 | if not ans or min(abs(z - x) for z in ans) > 1e-6:
140 | ans.append(x)
141 | if len(ans) == 3:
142 | return ans
143 |
144 | def gen_random(self):
145 | r1, r2, r3 = [self.random.heavy_tail_float() for _ in range(3)]
146 | coeffs = [-r1 - r2 - r3, r1 * r2 + r1 * r3 + r2 * r3, -r1 * r2 * r3] # to ensure solvability
147 | if self.sol(coeffs) is not None:
148 | self.add(dict(coeffs=coeffs)) # won't add duplicates
149 |
150 |
151 | if __name__ == "__main__":
152 | PuzzleGenerator.debug_problems()
153 |
--------------------------------------------------------------------------------
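As a quick sanity check of the closed-form solutions above (illustrative; assumes the repository root is on `sys.path` so that `puzzle_generator` and `generators` import cleanly):

```python
# Each sol should produce an input that its sat accepts for the default coeffs.
from generators.algebra import QuadraticRoot, AllQuadraticRoots

coeffs = [2.5, 1.3, -0.5]
assert QuadraticRoot.sat(QuadraticRoot.sol(coeffs), coeffs)   # "+" branch of the formula
assert QuadraticRoot.sat(QuadraticRoot.sol2(coeffs), coeffs)  # "-" branch of the formula

coeffs2 = [1.3, -0.5]
assert AllQuadraticRoots.sat(AllQuadraticRoots.sol(coeffs2), coeffs2)
```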
/generators/compression.py:
--------------------------------------------------------------------------------
1 | """Puzzles relating to de/compression."""
2 |
3 | from puzzle_generator import PuzzleGenerator, Tags
4 | from typing import List
5 |
6 |
7 | # See https://github.com/microsoft/PythonProgrammingPuzzles/wiki/How-to-add-a-puzzle to learn about adding puzzles
8 |
9 |
10 | def _compress_LZW(text): # for development
11 | index = {chr(i): i for i in range(256)}
12 |
13 | seq = []
14 | buffer = ""
15 | for c in text:
16 | if buffer + c in index:
17 | buffer += c
18 | continue
19 | seq.append(index[buffer])
20 | index[buffer + c] = len(index) + 1
21 | buffer = c
22 |
23 | if text != "":
24 | seq.append(index[buffer])
25 |
26 | return seq
27 |
28 |
29 | def _decompress_LZW(seq: List[int]): # for development
30 | index = [chr(i) for i in range(256)]
31 | pieces = [""]
32 | for i in seq:
33 | pieces.append(pieces[-1] + pieces[-1][0] if i == len(index) else index[i])
34 | index.append(pieces[-2] + pieces[-1][0])
35 | return "".join(pieces)
36 |
37 |
38 | class LZW(PuzzleGenerator):
39 | """
40 | We have provided a simple version of the *decompression* algorithm of
41 | [Lempel-Ziv-Welch](https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv%E2%80%93Welch)
42 | so the solution is the *compression* algorithm.
43 | """
44 |
45 | tags = [Tags.strings, Tags.famous]
46 |
47 |
48 | # _compress_LZW("Hellooooooooooooooooooooo world!") is length-17
49 |
50 | @staticmethod
51 | def sat(seq: List[int], compressed_len=17, text="Hellooooooooooooooooooooo world!"):
52 | """
53 | Find a (short) compression that decompresses to the given string for the provided implementation of the
54 | Lempel-Ziv decompression algorithm from https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv%E2%80%93Welch
55 | """
56 | index = [chr(i) for i in range(256)]
57 | pieces = [""]
58 | for i in seq:
59 | pieces.append((pieces[-1] + pieces[-1][0]) if i == len(index) else index[i])
60 | index.append(pieces[-2] + pieces[-1][0])
61 | return "".join(pieces) == text and len(seq) <= compressed_len
62 |
63 | @staticmethod
64 | def sol(compressed_len, text):
65 | # compressed_len is ignored
66 | index = {chr(i): i for i in range(256)}
67 | seq = []
68 | buffer = ""
69 | for c in text:
70 | if buffer + c in index:
71 | buffer += c
72 | continue
73 | seq.append(index[buffer])
74 | index[buffer + c] = len(index) + 1
75 | buffer = c
76 |
77 | if text != "":
78 | seq.append(index[buffer])
79 |
80 | return seq
81 |
82 | def gen(self, _target_num_instances):
83 | self.add({"text": "", "compressed_len": 0})
84 | self.add({"text": "c" * 1000, "compressed_len": len(_compress_LZW("c" * 1000))})
85 |
86 | def gen_random(self):
87 | max_len = self.random.choice([10, 100, 1000])
88 | text = self.random.pseudo_word(0, max_len)
89 | self.add({"text": text, "compressed_len": len(_compress_LZW(text))})
90 |
91 | # Removed this puzzle because the puzzle statement would give away the solution to LZW, haha!
92 | # class LZW_decompress(PuzzleGenerator):
93 | # """We have provided a simple version of the
94 | # [Lempel-Ziv-Welch](https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv%E2%80%93Welch)
95 | # and the solution is the *decompression* algorithm.
96 | # """
97 | #
98 | # @staticmethod
99 | # def sat(text: str, seq=[72, 101, 108, 108, 111, 32, 119, 111, 114, 100, 262, 264, 266, 263, 265, 33]):
100 | # """
101 | # Find a string that compresses to the target sequence for the provided implementation of the
102 | # Lempel-Ziv algorithm from https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv%E2%80%93Welch
103 | # """
104 | # index = {chr(i): i for i in range(256)}
105 | # seq2 = []
106 | # buffer = ""
107 | # for c in text:
108 | # if buffer + c in index:
109 | # buffer += c
110 | # continue
111 | # seq2.append(index[buffer])
112 | # index[buffer + c] = len(index) + 1
113 | # buffer = c
114 | #
115 | # if text != "":
116 | # seq2.append(index[buffer])
117 | #
118 | # return seq2 == seq
119 | #
120 | # @staticmethod
121 | # def sol(seq):
122 | # index = [chr(i) for i in range(256)]
123 | # pieces = [""]
124 | # for i in seq:
125 | # pieces.append(pieces[-1] + pieces[-1][0] if i == len(index) else index[i])
126 | # index.append(pieces[-2] + pieces[-1][0])
127 | # return "".join(pieces)
128 | #
129 | # def gen(self, _target_num_instances):
130 | # for s in ['', 'a', 'b' * 1000, 'ab' * 1000 + '!']:
131 | # self.add({"seq": _compress_LZW(s)})
132 | #
133 | # def gen_random(self):
134 | # max_len = self.random.choice([10, 100, 1000])
135 | # text = self.random.pseudo_word(0, max_len)
136 | # self.add({"seq": _compress_LZW(text)})
137 |
138 |
139 | class PackingHam(PuzzleGenerator):
140 | """
141 | This packing problem is a [classic problem](https://en.wikipedia.org/wiki/Sphere_packing#Other_spaces)
142 | in coding theory.
143 | """
144 |
145 | tags = [Tags.strings, Tags.famous]
146 |
147 | @staticmethod
148 | def sat(words: List[str], num=100, bits=100, dist=34):
149 | """Pack a certain number of binary strings so that they have a minimum hamming distance between each other."""
150 | assert len(words) == num and all(len(word) == bits and set(word) <= {"0", "1"} for word in words)
151 | return all(sum([a != b for a, b in zip(words[i], words[j])]) >= dist for i in range(num) for j in range(i))
152 |
153 | @staticmethod
154 | def sol(num, bits, dist):
155 | import random # key insight, use randomness!
156 | r = random.Random(0)
157 | while True:
158 | seqs = [r.getrandbits(bits) for _ in range(num)]
159 | if all(bin(seqs[i] ^ seqs[j]).count("1") >= dist for i in range(num) for j in range(i)):
160 | return [bin(s)[2:].rjust(bits, '0') for s in seqs]
161 |
162 | def gen_random(self):
163 | bits = self.random.randrange(1, self.random.choice([10, 100]))
164 | num = self.random.randrange(2, self.random.choice([10, 100]))
165 |
166 | def score(seqs):
167 | return min(bin(seqs[i] ^ seqs[j]).count("1") for i in range(num) for j in range(i))
168 |
169 | # best of 5
170 | seqs = min([[self.random.getrandbits(bits) for _ in range(num)] for _ in range(5)], key=score)
171 | dist = score(seqs)
172 | if dist > 0:
173 | self.add(dict(num=num, bits=bits, dist=dist))
174 |
175 |
176 | if __name__ == "__main__":
177 | PuzzleGenerator.debug_problems()
178 |
--------------------------------------------------------------------------------
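A round-trip check of the two development helpers above (illustrative):

```python
# Compressing and then decompressing should reproduce the original text, and
# the example string from the LZW puzzle compresses to 17 codes.
from generators.compression import _compress_LZW, _decompress_LZW

text = "Hellooooooooooooooooooooo world!"
seq = _compress_LZW(text)
assert len(seq) == 17
assert _decompress_LZW(seq) == text
```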
/generators/lattices.py:
--------------------------------------------------------------------------------
1 | """Lattice problems with and without noise"""
2 |
3 | from puzzle_generator import PuzzleGenerator, Tags
4 | from typing import List
5 |
6 |
7 | # See https://github.com/microsoft/PythonProgrammingPuzzles/wiki/How-to-add-a-puzzle to learn about adding puzzles
8 |
9 |
10 | class LearnParity(PuzzleGenerator):
11 | """Parity learning (Gaussian elimination)
12 |
13 | The canonical solution to this
14 | [Parity learning problem](https://en.wikipedia.org/w/index.php?title=Parity_learning)
15 | is to use
16 | [Gaussian Elimination](https://en.wikipedia.org/w/index.php?title=Gaussian_elimination).
17 |
18 | The vectors are encoded as binary integers for succinctness.
19 | """
20 |
21 | @staticmethod
22 | def sat(inds: List[int], vecs=[169, 203, 409, 50, 37, 479, 370, 133, 53, 159, 161, 367, 474, 107, 82, 447, 385]):
23 | """
24 | Parity learning: Given binary vectors in a subspace, find the secret set S of indices such that:
25 | $\\sum_{i \\in S} x_i = 1 \\pmod 2$
26 | """
27 | return all(sum((v >> i) & 1 for i in inds) % 2 == 1 for v in vecs)
28 |
29 | @staticmethod
30 | def sol(vecs):
31 | # Gaussian elimination
32 | d = 0 # decode vectors into arrays
33 | m = max(vecs)
34 | while m:
35 | m >>= 1
36 | d += 1
37 | vecs = [[(n >> i) & 1 for i in range(d)] for n in vecs]
38 | ans = []
39 | pool = [[0] * (d + 1) for _ in range(d)] + [v + [1] for v in vecs]
40 | for i in range(d):
41 | pool[i][i] = 1
42 |
43 | for i in range(d): # zero out bit i
44 | for v in pool[d:]:
45 | if v[i] == 1:
46 | break
47 | if v[i] == 0:
48 | v = pool[i]
49 | assert v[i] == 1 # found a vector with v[i] = 1, subtract it off from those with a 1 in the ith coordinate
50 | w = v[:]
51 | for v in pool:
52 | if v[i] == 1:
53 | for j in range(d + 1):
54 | v[j] ^= w[j]
55 |
56 | return [i for i in range(d) if pool[i][-1]]
57 |
58 | @staticmethod
59 | def rand_parity_problem(rand, d=63):
60 | secret = rand.sample(range(d), d // 2)
61 | num_vecs = d + 9
62 | vecs = [[rand.randrange(2) for _ in range(d)] for i in range(num_vecs)]
63 | for v in vecs:
64 | v[secret[0]] = (1 + sum([v[i] for i in secret[1:]])) % 2
65 | vecs = [sum(1 << i for i, b in enumerate(v) if b) for v in vecs] # encode into ints
66 | return vecs
67 |
68 | def gen(self, target_num_instances):
69 | vecs = self.rand_parity_problem(self.random, d=63)
70 | self.add(dict(vecs=vecs), multiplier=10)
71 |
72 | def gen_random(self):
73 | d = self.random.randrange(2, self.random.choice([5, 10, 20, 100]))
74 | vecs = self.rand_parity_problem(
75 | self.random,
76 | d=d,
77 | )
78 | self.add(dict(vecs=vecs), multiplier=10 if d > 9 else 1)
79 |
80 |
81 | class LearnParityWithNoise(PuzzleGenerator):
82 | """Learn parity with noise (*unsolved*)
83 |
84 | The fastest known algorithm to this
85 | [Parity learning problem](https://en.wikipedia.org/w/index.php?title=Parity_learning)
86 | runs in time $2^{d/\\log d}$
87 |
88 | The example puzzle has small dimension so is easily solvable, but other instances are much harder.
89 | """
90 |
91 | @staticmethod
92 | def sat(inds: List[int], vecs=[26, 5, 32, 3, 15, 18, 31, 13, 24, 25, 34, 5, 15, 24, 16, 13, 0, 27, 37]):
93 | """
94 | Learning parity with noise: Given binary vectors, find the secret set $S$ of indices such that, for at least
95 | 3/4 of the vectors, $$\\sum_{i \\in S} x_i = 1 \\pmod 2$$
96 | """
97 | return sum(sum((v >> i) & 1 for i in inds) % 2 for v in vecs) >= len(vecs) * 3 / 4
98 |
99 | @staticmethod
100 | def sol(vecs):
101 | # brute force
102 | d = 0 # decode vectors into arrays
103 | m = max(vecs)
104 | while m:
105 | m >>= 1
106 | d += 1
107 | vecs = [[(n >> i) & 1 for i in range(d)] for n in vecs]
108 |
109 | import random
110 | rand = random.Random(0)
111 | target = (len(vecs) * 3) // 4
112 | max_attempts = 10 ** 5
113 | for _ in range(max_attempts):
114 | ans = [i for i in range(d) if rand.randrange(2)]
115 | if sum(sum(v[i] for i in ans) % 2 for v in vecs) >= len(vecs) * 3 / 4:
116 | return ans
117 |
118 | @staticmethod
119 | def rand_parity_problem(rand, d=63):
120 | secret = rand.sample(range(d), d // 2)
121 | num_vecs = 2 * d + 5
122 | vecs = [[rand.randrange(2) for _ in range(d)] for i in range(num_vecs)]
123 | for v in vecs:
124 | v[secret[0]] = (1 + sum([v[i] for i in secret[1:]])) % 2
125 | mistakes = rand.sample(vecs, int(len(vecs) * rand.random() * 1 / 4))
126 | for v in mistakes:
127 | v[secret[0]] ^= 1 # flip bit in mistakes
128 | vecs = [sum(1 << i for i, b in enumerate(v) if b) for v in vecs] # encode into ints
129 | return vecs
130 |
131 | def gen(self, target_num_instances):
132 | vecs = self.rand_parity_problem(self.random, d=63)
133 | self.add(dict(vecs=vecs), test=False, multiplier=1000)
134 |
135 | def gen_random(self):
136 | d = self.random.randrange(2, self.random.choice([11, 100])) # number of dimensions
137 | vecs = self.rand_parity_problem(
138 | self.random,
139 | d=d
140 | )
141 | self.add(dict(vecs=vecs), test=d < 19, multiplier=1000 if d > 40 else 30 if d >= 19 else 1)
142 |
143 | if __name__ == "__main__":
144 | PuzzleGenerator.debug_problems()
145 |
--------------------------------------------------------------------------------
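To see the Gaussian-elimination solver above in action (illustrative; dimension kept small):

```python
# rand_parity_problem plants a secret parity among random vectors; sol should
# recover an index set that sat accepts.
import random
from generators.lattices import LearnParity

vecs = LearnParity.rand_parity_problem(random.Random(0), d=20)
inds = LearnParity.sol(vecs)
assert LearnParity.sat(inds, vecs)
```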
/generators/probability.py:
--------------------------------------------------------------------------------
1 | """Probability problems"""
2 |
3 | from puzzle_generator import PuzzleGenerator, Tags
4 | from typing import List
5 |
6 |
7 | # See https://github.com/microsoft/PythonProgrammingPuzzles/wiki/How-to-add-a-puzzle to learn about adding puzzles
8 |
9 |
10 | class BirthdayParadox(PuzzleGenerator):
11 | """
12 | Adaptation of the classic
13 | [Birthday Problem](https://en.wikipedia.org/wiki/Birthday_problem).
14 |
15 | The year length is year_len (365 on Earth, while a Neptune year is 60,182 days).
16 | """
17 |
18 | @staticmethod
19 | def sat(n: int, year_len=365):
20 | """Find n such that the probability of two people having the same birthday in a group of n is near 1/2."""
21 | prob = 1.0
22 | for i in range(n):
23 | prob *= (year_len - i) / year_len
24 | return (prob - 0.5) ** 2 <= 1/year_len
25 |
26 | @staticmethod
27 | def sol(year_len):
28 | n = 1
29 | distinct_prob = 1.0
30 | best = (0.5, 1) # (difference between probability and 1/2, n)
31 | while distinct_prob > 0.5:
32 | distinct_prob *= (year_len - n) / year_len
33 | n += 1
34 | best = min(best, (abs(0.5 - distinct_prob), n))
35 |
36 | return best[1]
37 |
38 | def safe_add(self, **inputs):
39 | if self.sat(self.sol(**inputs), **inputs):
40 | self.add(inputs)
41 |
42 | def gen(self, target_num_instances):
43 | self.safe_add(year_len=60182) # Neptune year!
44 | for year_len in range(2, target_num_instances):
45 | self.safe_add(year_len=year_len)
46 |
47 |
48 |
49 | class BirthdayParadoxMonteCarlo(BirthdayParadox):
50 | """A slower, Monte Carlo version of the above Birthday Paradox problem."""
51 |
52 | @staticmethod
53 | def sat(n: int, year_len=365):
54 | """Find n such that the probability of two people having the same birthday in a group of n is near 1/2."""
55 | import random
56 | random.seed(0)
57 | K = 1000 # number of samples
58 | prob = sum(len({random.randrange(year_len) for i in range(n)}) < n for j in range(K)) / K
59 | return (prob - 0.5) ** 2 <= 1/year_len
60 |
61 |
62 |
63 |
64 |
65 | class BallotProblem(PuzzleGenerator):
66 | """
67 | See the [Wikipedia article](https://en.wikipedia.org/wiki/Bertrand%27s_ballot_theorem) or
68 | [Addario-Berry L., Reed B.A. (2008) Ballot Theorems, Old and New. In: Gyori E., Katona G.O.H., Lovász L.,
69 | Sági G. (eds) Horizons of Combinatorics. Bolyai Society Mathematical Studies, vol 17.
70 | Springer, Berlin, Heidelberg.](https://doi.org/10.1007/978-3-540-77200-2_1)
71 | """
72 |
73 | @staticmethod
74 | def sat(counts: List[int], target_prob=0.5):
75 | """
76 | Suppose a list of m 1's and n -1's are permuted at random.
77 | What is the probability that all of the cumulative sums are positive?
78 | The goal is to find counts = [m, n] that make the probability of the ballot problem close to target_prob.
79 | """
80 | m, n = counts # m = num 1's, n = num -1's
81 | probs = [1.0] + [0.0] * n # probs[n] is probability for current m, starting with m = 1
82 | for i in range(2, m + 1): # compute probs using dynamic programming for m = i
83 | old_probs = probs
84 | probs = [1.0] + [0.0] * n
85 | for j in range(1, min(n + 1, i)):
86 | probs[j] = (
87 | j / (i + j) * probs[j - 1] # last element is a -1 so use probs
88 | +
89 | i / (i + j) * old_probs[j] # last element is a 1 so use old_probs, m = i - 1
90 | )
91 | return abs(probs[n] - target_prob) < 1e-6
92 |
93 | @staticmethod
94 | def sol(target_prob):
95 | for m in range(1, 10000):
96 | n = round(m * (1 - target_prob) / (1 + target_prob))
97 | if abs(target_prob - (m - n) / (m + n)) < 1e-6:
98 | return [m, n]
99 |
100 | def gen_random(self):
101 | m = self.random.randrange(1, self.random.choice([10, 100, 200, 300, 400, 500, 1000]))
102 | n = self.random.randrange(1, m + 1)
103 | target_prob = (m - n) / (m + n)
104 | self.add(dict(target_prob=target_prob))
105 |
106 |
107 | class BinomialProbabilities(PuzzleGenerator):
108 | """See [Binomial distribution](https://en.wikipedia.org/wiki/Binomial_distribution)"""
109 |
110 | @staticmethod
111 | def sat(counts: List[int], p=0.5, target_prob=1 / 16.0):
112 | """Find counts = [a, b] so that the probability of a H's and b T's among a + b coin flips is ~ target_prob."""
113 | from itertools import product
114 | a, b = counts
115 | n = a + b
116 | prob = (p ** a) * ((1-p) ** b)
117 | tot = sum([prob for sample in product([0, 1], repeat=n) if sum(sample) == a])
118 | return abs(tot - target_prob) < 1e-6
119 |
120 | @staticmethod
121 | def sol(p, target_prob):
122 | probs = [1.0]
123 | q = 1 - p
124 | while len(probs) < 20:
125 | probs = [(p * a + q * b) for a, b in zip([0] + probs, probs + [0])]
126 | answers = [i for i, p in enumerate(probs) if abs(p - target_prob) < 1e-6]
127 | if answers:
128 | return [answers[0], len(probs) - 1 - answers[0]]
129 |
130 | def gen_random(self):
131 | probs = [1.0]
132 | p = self.random.random()
133 | q = 1 - p
134 | for n in range(self.random.randrange(1, 11)):
135 | probs = [(p * a + q * b) for a, b in zip([0] + probs, probs + [0])]
136 | target_prob = self.random.choice(probs)
137 | self.add(dict(p=p, target_prob=target_prob))
138 |
139 |
140 |
141 | class ExponentialProbability(PuzzleGenerator):
142 | """See [Exponential distribution](https://en.wikipedia.org/wiki/Exponential_distribution)"""
143 |
144 | @staticmethod
145 | def sat(p_stop: float, steps=10, target_prob=0.5):
146 | """
147 | Find p_stop so that the probability of stopping in steps or fewer time steps is the given target_prob if you
148 | stop each step with probability p_stop
149 | """
150 | prob = sum(p_stop*(1-p_stop)**t for t in range(steps))
151 | return abs(prob - target_prob) < 1e-6
152 |
153 | @staticmethod
154 | def sol(steps, target_prob):
155 | return 1 - (1 - target_prob) ** (1.0/steps)
156 |
157 | def gen_random(self):
158 | steps = self.random.randrange(1, 100)
159 | target_prob = self.random.random()
160 | self.add(dict(steps=steps, target_prob=target_prob))
161 |
162 |
163 |
164 | if __name__ == "__main__":
165 | PuzzleGenerator.debug_problems()
166 |
--------------------------------------------------------------------------------
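A quick check of the closed-form Birthday Paradox solver above (illustrative):

```python
# For a 365-day year, the classic answer is 23 people.
from generators.probability import BirthdayParadox

n = BirthdayParadox.sol(year_len=365)
assert n == 23
assert BirthdayParadox.sat(n, year_len=365)
```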
/generators/tutorial.py:
--------------------------------------------------------------------------------
1 | """
2 | A few example puzzles that were presented with solutions to participants of the study.
3 | """
4 |
5 | from puzzle_generator import PuzzleGenerator, Tags
6 | from typing import List
7 |
8 |
9 | # See https://github.com/microsoft/PythonProgrammingPuzzles/wiki/How-to-add-a-puzzle to learn about adding puzzles
10 |
11 |
12 | class Tutorial1(PuzzleGenerator):
13 | @staticmethod
14 | def sat(s: str):
15 | """Find a string that when concatenated onto 'Hello ' gives 'Hello world'."""
16 | return "Hello " + s == "Hello world"
17 |
18 | @staticmethod
19 | def sol():
20 | return "world"
21 |
22 |
23 | class Tutorial2(PuzzleGenerator):
24 | @staticmethod
25 | def sat(s: str):
26 | """Find a string that when reversed and concatenated onto 'Hello ' gives 'Hello world'."""
27 | return "Hello " + s[::-1] == "Hello world"
28 |
29 | @staticmethod
30 | def sol():
31 | return "world"[::-1]
32 |
33 |
34 | class Tutorial3(PuzzleGenerator):
35 | @staticmethod
36 | def sat(x: List[int]):
37 | """Find a list of two integers whose sum is 3."""
38 | return len(x) == 2 and sum(x) == 3
39 |
40 | @staticmethod
41 | def sol():
42 | return [1, 2]
43 |
44 |
45 | class Tutorial4(PuzzleGenerator):
46 | @staticmethod
47 | def sat(s: List[str]):
48 | """Find a list of 1000 distinct strings which each have more 'a's than 'b's and at least one 'b'."""
49 | return len(set(s)) == 1000 and all((x.count("a") > x.count("b")) and ('b' in x) for x in s)
50 |
51 | @staticmethod
52 | def sol():
53 | return ["a" * (i + 2) + "b" for i in range(1000)]
54 |
55 |
56 | class Tutorial5(PuzzleGenerator):
57 | @staticmethod
58 | def sat(n: int):
59 | """Find an integer whose perfect square begins with 123456789 in its decimal representation."""
60 | return str(n * n).startswith("123456789")
61 |
62 | @staticmethod
63 | def sol():
64 | return int(int("123456789" + "0" * 9) ** 0.5) + 1
65 |
66 |
67 | # Not clear this last one is necessary/helpful
68 | # # class Tutorial6(Problem):
69 | # """Find a string corresponding to a decimal number whose negation is 1337."""
70 | #
71 | # @staticmethod
72 | # def sat(s: str):
73 | # return -1 * int(s) == 1337
74 | #
75 | # @staticmethod
76 | # def sol():
77 | # return str(-1337)
78 |
79 | if __name__ == "__main__":
80 | PuzzleGenerator.debug_problems()
81 |
--------------------------------------------------------------------------------
/notebooks/Demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "deletable": false
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# Please run this cell first.\n",
12 | "from demo.demo import next_puzzle, cur_puzzle, puzzle, give_up"
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {},
18 | "source": [
19 | "# Instructions\n",
20 | "\n",
21 | "* This [Jupyter](https://jupyter.org/) notebook is a demo simulating the study used for the [**Programming Puzzles**](https://arxiv.org/abs/2106.05784) paper (with some modifications to run locally).\n",
22 | "* You can reuse the same notebook cells to solve all puzzles (run cell with Ctrl+Enter), or open new cells as you advance (Shift+Enter).\n",
23 | "\n",
24 | "\n",
25 | "* **Visit our repository** to explore the full dataset and **contribute your own new puzzles**: https://github.com/microsoft/PythonProgrammingPuzzles.\n",
26 | "* Share this demo with friends by sending this link: https://aka.ms/python_puzzles_study\n",
27 | "* For a shorter intro, and to see which puzzles the AI baselines solved, visit: https://aka.ms/python_puzzles\n",
28 | "\n",
29 | "## Overview\n",
30 | "\n",
31 | "\n",
32 | "* The first 3 problems are \"practice\" and the time you take will not be counted. This is a good chance to see how the system works.\n",
33 | "* Puzzles are defined by `def puzzle(...)`. For each puzzle, you will try to find an input `x` which makes `puzzle(x)` return `True`.\n",
34 | "* Type `next_puzzle()` when you are ready to start the first problem or to advance to the next problem.\n",
35 | "* There is **no option to revisit a puzzle** and once you call `next_puzzle()` the clock starts ticking (you have up to 6 minutes per puzzle).\n",
36 | "* **If you get lost, call `cur_puzzle()`** to see the current puzzle you are on (and time bar).\n",
37 | "\n",
38 | "## Timing\n",
39 | "\n",
40 | "* You have up to 6 minutes for each puzzle.\n",
41 | "* If you do not solve a problem in **6 minutes**, move to the next puzzle by typing `next_puzzle()`.\n",
42 | "* If you would like to give up and skip to the next puzzle without waiting, you can call `give_up()`.\n",
43 | "\n",
44 | "\n",
45 | "## Possible issues\n",
46 | "\n",
47 | "* Remember to use `cur_puzzle()` if you lost the code of the current puzzle.\n",
48 | "* For any feedback, please visit our [GitHub repository](https://github.com/microsoft/PythonProgrammingPuzzles) and reach out to us by email.\n",
49 | "\n",
50 | "\n",
51 | "## Summary of functions\n",
52 | "\n",
53 | "| Function \t| Description \t|\n",
54 | "|:----------------------\t|:------------------------------------------------------------------------------------------\t|\n",
55 | "|`next_puzzle()` \t| Start the next puzzle (call only when you are ready to start! no revisiting) \t|\n",
56 | "|`cur_puzzle()` \t| Present the current puzzle (useful if you got lost or accidentally overridden `puzzle()`) \t|\n",
57 | "|`puzzle(...)` \t| Submit a solution to the current puzzle \t|\n",
58 | "|`give_up()` \t| Give up and skip to the next puzzle *before* 6 minutes have passed. Please avoid this option if possible. |"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "# The first 3 puzzles are warmup. Begin the warmup part by running this cell.\n",
68 | "\n",
69 | "next_puzzle()"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "# Solve the first puzzle by running this cell.\n",
79 | "\n",
80 | "puzzle(\"world\")"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "# when you are ready to continue, run this cell.\n",
90 | "\n",
91 | "next_puzzle()"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "# You're on your own. Have fun.\n"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": []
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": []
116 | }
117 | ],
118 | "metadata": {
119 | "kernelspec": {
120 | "display_name": "Python 3",
121 | "language": "python",
122 | "name": "python3"
123 | },
124 | "language_info": {
125 | "codemirror_mode": {
126 | "name": "ipython",
127 | "version": 3
128 | },
129 | "file_extension": ".py",
130 | "mimetype": "text/x-python",
131 | "name": "python",
132 | "nbconvert_exporter": "python",
133 | "pygments_lexer": "ipython3",
134 | "version": "3.8.8"
135 | }
136 | },
137 | "nbformat": 4,
138 | "nbformat_minor": 4
139 | }
--------------------------------------------------------------------------------
/notebooks/demo/study_puzzles.py:
--------------------------------------------------------------------------------
1 | # provides get_puzzles
2 |
3 | # parts separated by ======== (exactly 8)
4 | # puzzles by ----'s (exactly 4)
5 |
6 |
7 | def get_puzzles():
8 | """
9 | Creates the puzzles for the study
10 |
11 | :return: list of {"src": str, "name": str, "part": str}
12 | """
13 | part_names = ["WARM UP", "PART 1/3", "PART 2/3", "PART 3/3"]
14 |
15 | raw_puzzles = '''
16 |
17 | def puzzle(s: str):
18 | """
19 | Warmup problem.
20 | """
21 | return "Hello " + s == 'Hello world'
22 |
23 | ----
24 |
25 | def puzzle(n: int):
26 | """
27 | Hint: puzzle(111111111) works.
28 | """
29 | return str(n * n).startswith("123456789")
30 |
31 | ----
32 |
33 | def puzzle(x: str):
34 | """
35 | Hint: note that the input should be a string.
36 | """
37 | return -1 * int(x) == 1337
38 |
39 |
40 | ========
41 |
42 | def puzzle(s: str):
43 | return s.count('o') == 1000 and s.count('oo') == 0
44 |
45 |
46 | ----
47 |
48 |
49 | def puzzle(s: str):
50 | return s.count('o') == 1000 and s.count('oo') == 100 and s.count('ho') == 801
51 |
52 | ----
53 |
54 | def puzzle(x: List[int]):
55 | return sorted(x) == list(range(999)) and all(x[i] != i for i in range(len(x)))
56 |
57 | ----
58 |
59 | def puzzle(x: List[int]):
60 | return len(x) == 10 and x.count(x[3]) == 2
61 |
62 | ----
63 |
64 |
65 | def puzzle(x: List[int]):
66 | return all([x.count(i) == i for i in range(10)])
67 |
68 | ----
69 |
70 | def puzzle(n: int):
71 | return n % 123 == 4 and n > 10**10
72 |
73 |
74 | ----
75 |
76 | def puzzle(s: str):
77 | return str(8**2888).count(s) > 8 and len(s) == 3
78 |
79 | ----
80 |
81 | def puzzle(s: List[str]):
82 | return s[1234] in s[1235] and s[1234] != s[1235]
83 |
84 | ----
85 |
86 | def puzzle(x: List[int]):
87 | return ["The quick brown fox jumps over the lazy dog"[i] for i in x] == list("The five boxing wizards jump quickly")
88 |
89 | ----
90 |
91 | def puzzle(s: str):
92 | return s in str(8**1818) and s==s[::-1] and len(s)>11
93 |
94 | ========
95 |
96 | def puzzle(x: List[str]):
97 | return min(x) == max(x) == str(len(x))
98 |
99 |
100 | ----
101 |
102 | def puzzle(x: List[int]):
103 | return all(a + b == 9 for a, b in zip([4] + x, x)) and len(x) == 1000
104 |
105 |
106 | ----
107 |
108 |
109 | def puzzle(x: float):
110 | return str(x - 3.1415).startswith("123.456")
111 |
112 | ----
113 |
114 | def puzzle(x: List[int]):
115 | return all([sum(x[:i]) == i for i in range(20)])
116 |
117 | ----
118 |
119 | def puzzle(x: List[int]):
120 | return all(sum(x[:i]) == 2 ** i - 1 for i in range(20))
121 |
122 | ----
123 |
124 | def puzzle(x: str):
125 | return float(x) + len(x) == 4.5
126 |
127 | ----
128 |
129 | def puzzle(n: int):
130 | return len(str(n + 1000)) > len(str(n + 1001))
131 |
132 | ----
133 |
134 | def puzzle(x: List[str]):
135 | return [s + t for s in x for t in x if s!=t] == 'berlin berger linber linger gerber gerlin'.split()
136 |
137 | ----
138 |
139 | def puzzle(x: Set[int]):
140 | return {i + j for i in x for j in x} == {0, 1, 2, 3, 4, 5, 6, 17, 18, 19, 20, 34}
141 |
142 | ----
143 |
144 |
145 | def puzzle(x: List[int]):
146 | return all(b in {a-1, a+1, 3*a} for a, b in zip([0] + x, x + [128]))
147 |
148 |
149 | ========
150 |
151 | def puzzle(x: List[int]):
152 | return all([x[i] != x[i + 1] for i in range(10)]) and len(set(x)) == 3
153 |
154 | ----
155 |
156 | def puzzle(x: str):
157 | return x[::2] in x and len(set(x)) == 5
158 |
159 | ----
160 |
161 | def puzzle(x: List[str]):
162 | return tuple(x) in zip('dee', 'doo', 'dah!')
163 |
164 | ----
165 |
166 | def puzzle(x: List[int]):
167 | return x.count(17) == 3 and x.count(3) >= 2
168 |
169 |
170 | ----
171 |
172 | def puzzle(s: str):
173 | return sorted(s)==sorted('Permute me true') and s==s[::-1]
174 |
175 |
176 | ----
177 | def puzzle(x: List[str]):
178 | return "".join(x) == str(8**88) and all(len(s)==8 for s in x)
179 |
180 |
181 | ----
182 |
183 | def puzzle(x: List[int]):
184 | return x[x[0]] != x[x[1]] and x[x[x[0]]] == x[x[x[1]]]
185 |
186 | ----
187 |
188 | def puzzle(x: Set[int]):
189 | return all(i in range(1000) and abs(i-j) >= 10 for i in x for j in x if i != j) and len(x)==100
190 |
191 | ----
192 |
193 |
194 | def puzzle(x: Set[int]):
195 | return all(i in range(1000) and abs(i*i - j*j) >= 10 for i in x for j in x if i != j) and len(x) > 995
196 |
197 | ----
198 |
199 | def puzzle(x: List[int]):
200 | return all([123*x[i] % 1000 < 123*x[i+1] % 1000 and x[i] in range(1000) for i in range(20)])
201 |
202 |
203 |
204 |
205 | '''
206 |
207 | parts = [[src.strip() for src in part.split("----")] for part in raw_puzzles.split("========")]
208 | assert len(part_names) == len(parts)
209 | ans = []
210 | for part_name, part in zip(part_names, parts):
211 | per_part_num = 1
212 | for src in part:
213 | ans.append({"src": src,
214 | "part": part_name,
215 | "name": f"PUZZLE {per_part_num}/{len(part)}"})
216 | per_part_num += 1
217 | return ans
218 |
219 |
220 |
--------------------------------------------------------------------------------
/notebooks/fireworks.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/notebooks/fireworks.gif
--------------------------------------------------------------------------------
/puzzles/split.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": [
3 | "Abbreviate",
4 | "AllQuadraticRoots",
5 | "AnyPath",
6 | "AnyTriangle",
7 | "ArithmeticSequence",
8 | "BallotProblem",
9 | "BasicStrCounts",
10 | "BillSums",
11 | "BirthdayParadoxMonteCarlo",
12 | "BoxVolume",
13 | "CapitalizeFirstLetter",
14 | "CenteredString",
15 | "CheckersPosition",
16 | "ClockAngle",
17 | "CollatzCycleUnsolved",
18 | "CollatzDelay",
19 | "CollatzGeneralizedUnsolved",
20 | "CombinationLock",
21 | "CombinationLockObfuscated",
22 | "CommonCase",
23 | "CompareInAnyCase",
24 | "CompleteParens",
25 | "ConcatStrings",
26 | "Conway99",
27 | "Count47",
28 | "CumulativeSum",
29 | "Dada",
30 | "DecreasingCountComparison",
31 | "DistinctDigits",
32 | "DistinctOddSum",
33 | "DominoTile",
34 | "Easy63",
35 | "EasySum",
36 | "EasyTwos",
37 | "EightQueensOrFewer",
38 | "EvenPath",
39 | "FermatsLastTheorem",
40 | "FindProductiveList",
41 | "FindRepeats",
42 | "FivePowers",
43 | "FloatNegSquareRoot",
44 | "FloatSquareRoot",
45 | "FloatWithDecimalValue",
46 | "FourSquares",
47 | "GCD",
48 | "GCD_multi",
49 | "GimmeChars",
50 | "GraphIsomorphism",
51 | "HalfPairs",
52 | "Harder63",
53 | "HelloWorld",
54 | "IfCases",
55 | "IfProblemWithOr",
56 | "IncDec",
57 | "IntDiv",
58 | "IntDiv2",
59 | "IntMul",
60 | "IntNeg",
61 | "IntSquareRoot",
62 | "IntSub",
63 | "IntSub2",
64 | "IntSum",
65 | "InvertIndices",
66 | "InvertPermutation",
67 | "Kirkman",
68 | "KnightsTour",
69 | "LCM",
70 | "LCM_multi",
71 | "LearnParity",
72 | "LearnParityWithNoise",
73 | "Lehmer",
74 | "LineIntersection",
75 | "ListAt",
76 | "ListDistinctSum",
77 | "ListIn",
78 | "ListMul",
79 | "ListNegAt",
80 | "ListPosSum",
81 | "ListSetLen",
82 | "ListSlice",
83 | "LongestSubsetString",
84 | "MatchingMarkers",
85 | "MaxConsecutiveProduct",
86 | "MaxConsecutiveSum",
87 | "MaxDelta",
88 | "MaybeReversed",
89 | "MinConsecutiveSum",
90 | "MinRotations",
91 | "MonkeyAndCoconuts",
92 | "Nash",
93 | "NecklaceSplit",
94 | "Nim",
95 | "No3Colinear",
96 | "NoRelativePrimes",
97 | "OddPath",
98 | "OnesAndTwos",
99 | "PandigitalSquare",
100 | "PenultimateRevString",
101 | "PenultimateString",
102 | "PostageStamp",
103 | "QuadraticRoot",
104 | "RepeatDec",
105 | "ReverseCat",
106 | "ReverseLifeStep",
107 | "RockPaperScissors",
108 | "SameDifferent",
109 | "ShortestPath",
110 | "Spaceship",
111 | "SquareTiles",
112 | "SquaringTheSquare",
113 | "Sstriiinggssuubb",
114 | "StrAdd",
115 | "StrCount",
116 | "StrIn",
117 | "StrIn2",
118 | "StrIndex",
119 | "StrIndex2",
120 | "StrJoiner",
121 | "StrLen",
122 | "StrParts",
123 | "StrSlice",
124 | "Study_1",
125 | "Study_11",
126 | "Study_13",
127 | "Study_15",
128 | "Study_17",
129 | "Study_19",
130 | "Study_21",
131 | "Study_23",
132 | "Study_25",
133 | "Study_27",
134 | "Study_29",
135 | "Study_3",
136 | "Study_5",
137 | "Study_7",
138 | "Study_9",
139 | "SublistSum",
140 | "SubstrCount",
141 | "SumOfDigits",
142 | "TotalDifference",
143 | "Triple0",
144 | "TripleDouble",
145 | "Tutorial1",
146 | "Tutorial2",
147 | "Tutorial3",
148 | "Tutorial4",
149 | "Tutorial5",
150 | "UNSOLVED_UncrossedKnightsPath",
151 | "UncrossedKnightsPath",
152 | "UnweightedShortestPath",
153 | "VerbalArithmetic",
154 | "VowelDrop",
155 | "WaterPouring",
156 | "ZipStr",
157 | "Znam"
158 | ],
159 | "test": [
160 | "AllCubicRoots",
161 | "AllPandigitalSquares",
162 | "AllPrefixes",
163 | "AlternatingFactorials",
164 | "AntiShuffle",
165 | "AnyEdge",
166 | "ArrayDiff",
167 | "BackWorlds",
168 | "BackwardsDigits",
169 | "BiPermutations",
170 | "BigOdds",
171 | "BiggestEven",
172 | "BiggestK",
173 | "Binarize",
174 | "BinaryAverage",
175 | "BinarySort",
176 | "BinaryStrXOR",
177 | "BinomialProbabilities",
178 | "BirthdayParadox",
179 | "BitSum",
180 | "BooleanPythagoreanTriples",
181 | "CardGame24",
182 | "CatStrings",
183 | "CeilingSquares",
184 | "CertifiedGCD",
185 | "ChangeBase",
186 | "CharCounts",
187 | "CharSum",
188 | "ClosestInteger",
189 | "ClosestPalindrome",
190 | "CommonNumbers",
191 | "CompleteSplit",
192 | "ConsonantFilter",
193 | "CubeRoot",
194 | "CubicRoot",
195 | "CumulativeSums",
196 | "DateDiff",
197 | "Dedup",
198 | "DeepestParens",
199 | "DelPalindrome",
200 | "Derivative",
201 | "DiffChars",
202 | "DiscreteLog",
203 | "DistinctChars",
204 | "Drops",
205 | "EngineerNumbers",
206 | "EvaluateOperators",
207 | "Even4Sum",
208 | "EvenBetween",
209 | "EvenOddDigits",
210 | "EvenOddSum",
211 | "EvenPalindromeNumbers",
212 | "EvenWords",
213 | "ExpandSpaces",
214 | "ExponentialCoinMoves",
215 | "ExponentialProbability",
216 | "Factor47",
217 | "FactorString",
218 | "Factoring",
219 | "FermatComposites",
220 | "Fib3",
221 | "Fib4",
222 | "Fibonacci",
223 | "FilenameOK",
224 | "FilterInts",
225 | "FindBored",
226 | "FindCloseElements",
227 | "FindClosePair",
228 | "FindContainers",
229 | "FindExtensions",
230 | "FindHomogeneousSubstring",
231 | "FindPositives",
232 | "FindVowels",
233 | "FirstNegCumulative",
234 | "FlipCase",
235 | "Frac",
236 | "GCD17",
237 | "GeometricSequence",
238 | "Grader",
239 | "GreatestHIndex",
240 | "HalfSorted",
241 | "HalfTag",
242 | "HeronTriangle",
243 | "HexPrimes",
244 | "IdentifyZeroTrips",
245 | "IfProblem",
246 | "IfProblemWithAnd",
247 | "IncreasingViolation",
248 | "IntNegSquareRoot",
249 | "IntegerLog",
250 | "Intersperse",
251 | "InverseSuperFactorial",
252 | "InvestigateCrash",
253 | "IsEven",
254 | "LZW",
255 | "LargestDivisor",
256 | "LargestNegSmallestPos",
257 | "LargestPrimeDigitSum",
258 | "LargestPrimeFactor",
259 | "LargestStringNum",
260 | "LastLetters",
261 | "LexPath",
262 | "ListInc",
263 | "ListIndex",
264 | "ListIndex2",
265 | "ListLen",
266 | "LittleFermat",
267 | "LongEarlySum",
268 | "LongestMonotonicSubstring",
269 | "LongestMonotonicSubstringTricky",
270 | "LongestStr",
271 | "Mastermind",
272 | "MaxInt",
273 | "Median",
274 | "MinBigger",
275 | "MinSquaredDeviation",
276 | "MissingBananas",
277 | "Monotonic",
278 | "MoreQueens",
279 | "MostUnique",
280 | "Moving0s",
281 | "NarrowerList",
282 | "NearbyDuplicates",
283 | "NumPasses",
284 | "OddCase",
285 | "OddCollatz",
286 | "OddDegreePolynomialRoot",
287 | "OddProduct",
288 | "OneEnded",
289 | "OptimalBridges",
290 | "Oscillators",
291 | "OverlappingCount",
292 | "PackingHam",
293 | "PairZeroSum",
294 | "Palindrome",
295 | "PalindromeContaining",
296 | "ParenDepth",
297 | "ParenthesesPermutation",
298 | "ParityExchange",
299 | "ParseMusic",
300 | "PickNearNeighbors",
301 | "PlanetRange",
302 | "PlantedClique",
303 | "PositiveDigitSums",
304 | "PrimeFactorization",
305 | "PrimeFib",
306 | "PrimeIntervalIntersection",
307 | "PrimeSel",
308 | "PrimeWords",
309 | "PrimesUpTo",
310 | "ProductSigns",
311 | "PythagoreanTriples",
312 | "Quine",
313 | "Rescale",
314 | "RevQuine",
315 | "RollingMax",
316 | "RomanNumerals",
317 | "RotateSort",
318 | "RotateString",
319 | "SecondSmallestUnique",
320 | "SeparateParenGroups",
321 | "SevenElevenThirteen",
322 | "ShiftChars",
323 | "ShortIntegerPath",
324 | "ShortestDecDelta",
325 | "SimplifyProductFraction",
326 | "SlidingOne",
327 | "SlidingPuzzle",
328 | "SmallExponentBigSolution",
329 | "SmallestEven",
330 | "SortByDigitSum",
331 | "SortNumbers",
332 | "SortPlusPlus",
333 | "SortedOdds",
334 | "SpaceyRange",
335 | "Sssuubbstriiingg",
336 | "StonePiles",
337 | "StrAt",
338 | "StrLength",
339 | "StrMul",
340 | "StrMul2",
341 | "StrNegAt",
342 | "StrSetLen",
343 | "StrSplit",
344 | "StrSplitter",
345 | "StrangeSplit",
346 | "Study_10",
347 | "Study_12",
348 | "Study_14",
349 | "Study_16",
350 | "Study_18",
351 | "Study_2",
352 | "Study_20",
353 | "Study_22",
354 | "Study_24",
355 | "Study_26",
356 | "Study_28",
357 | "Study_30",
358 | "Study_4",
359 | "Study_6",
360 | "Study_8",
361 | "SubstitutionCypher",
362 | "Sudoku",
363 | "SumProduct",
364 | "ThreeCubes",
365 | "ThreeCycle",
366 | "ThreePrimes",
367 | "Threeples",
368 | "TicTacToeO",
369 | "TicTacToeX",
370 | "TowersOfHanoi",
371 | "TowersOfHanoiArbitrary",
372 | "TriangleArea",
373 | "Tribonacci",
374 | "TripleZeroSum",
375 | "TwoThirdsSorted",
376 | "UnevenFind",
377 | "UniqueSorted",
378 | "UnitsProduct",
379 | "UpDownSort",
380 | "UppercaseEven",
381 | "ValidBracketSubsequence",
382 | "VowelSandwich",
383 | "WeirdDecodeVowels",
384 | "WildSort",
385 | "Zarankiewicz",
386 | "ZeroSum",
387 | "ZobristCollision"
388 | ]
389 | }
--------------------------------------------------------------------------------
/solvers/README.md:
--------------------------------------------------------------------------------
1 | # Solvers
2 |
3 | This folder contains three subfolders for recreating the benchmarks in the [paper](https://arxiv.org/abs/2106.05784).
4 | * [codex](/solvers/codex) The Codex experiments.
5 | * [gpt3](/solvers/gpt3) The GPT-3 experiments.
6 | * [enumerative](/solvers/enumerative) The enumerative top-down search solvers.
7 | 
8 | Each folder has a separate README explaining how to run the experiments.
--------------------------------------------------------------------------------
/solvers/codex/README.md:
--------------------------------------------------------------------------------
1 | # Running Codex experiments
2 |
3 | These are instructions for re-running the Codex experiments. The results will be slightly different from those in
4 | the paper because the API is non-deterministic.
5 | 
6 | `run_codex_experiments.py` runs the Codex experiments and prints the results to stdout. Change the
7 | parameters in that file to run it on 397puzzles.json (v0.2), on 138puzzles.json (v0.1, used in the
8 | first experiment), or on 30puzzles.json (the user study), or to switch between the davinci-codex and
9 | cushman-codex engines.
10 | 
11 | ## Installation and execution
12 | 
13 | You will need an OpenAI Codex API access key, which you can sign up for [here](https://openai.com/join/).
14 | You will then need to set it as the `OPENAI_API_KEY` environment variable. If you want a suffix added
15 | to the engine names, such as "-msft", set the environment variable `export OPEN_AI_ENGINE_SUFFIX=-msft`.
16 | We also recommend setting the environment variable `export PYTHONHASHSEED=0` for determinism.
17 | 
18 | The requirements can be installed with `pip3 install -r requirements.txt`.
19 | 
20 | We ran with Python 3.6.9 (sys.version = '3.6.9 (default, Jan 26 2021, 15:33:00) \n[GCC 8.4.0]'), but the
21 | code should be compatible with later versions as well.
22 | 
23 | Then simply run `python run_codex_experiments.py`. It uses caching mechanisms, so the first run is
24 | quite slow and verbose while it queries the API; subsequent runs are much faster and just output the
25 | results. The caching also makes the script deterministic, so it should give exactly the same results
26 | when re-run.
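27 | 
28 | For a quick programmatic check of the caching layer, something along these lines should work (a
29 | sketch based on the `query` signature in `lm_solve/gpt_lib.py`, run from this directory; this is an
30 | illustration, not a supported API):
31 | 
32 | ```python
33 | from lm_solve import gpt_lib
34 | 
35 | # The first call queries the API; identical parameters are then served from .cache/
36 | completions = gpt_lib.query("def g():\n    return ", n=2, max_tokens=20, temp=0.9, stop="\n")
37 | print(completions)
38 | ```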
--------------------------------------------------------------------------------
/solvers/codex/ezlog.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import inspect
4 | import io
5 |
6 | my_path = os.path.dirname(__file__)
7 |
8 |
9 | def color_str(obj, code="\033[0;36m"):
10 | return code + str(obj) + '\033[0m'
11 |
12 |
13 | _configured = False
14 |
15 |
16 | def configure_logging(stdio_level=logging.INFO,
17 | file_level=logging.DEBUG,
18 | filename=".easy.log",
19 | filepath=os.path.join(my_path, "logs")):
20 | os.makedirs(filepath, exist_ok=True)
21 | filename = os.path.join(filepath, filename)
22 | global _configured
23 | if _configured:
24 | warning("Re-configuring logging")
25 | stdio_handler = logging.StreamHandler()
26 | stdio_handler.setLevel(stdio_level)
27 | file_handler = logging.FileHandler(filename)
28 | file_handler.setLevel(file_level)
29 |
30 | logging.basicConfig(
31 | format="%(asctime)s - %(levelname)s - %(name)s - %(message).200s",
32 | datefmt="%m/%d/%Y %H:%M:%S",
33 | level=min(stdio_level, file_level),
34 | handlers=[stdio_handler, file_handler]
35 | )
36 |
37 | _configured = True
38 | _get_or_create_logger().debug("Configured logging")
39 |
40 |
41 | _loggers = {}
42 |
43 |
44 | def _get_or_create_logger():
45 | global _configured, _loggers
46 | if not _configured:
47 | configure_logging()
48 | try:
49 | for frame in inspect.stack():
50 | name = inspect.getmodule(frame[0]).__name__
51 | if name != __name__:
52 | break
53 | except Exception: # e.g., inspect.getmodule() returned None for a frame
54 | name = "_"
55 | if name not in _loggers:
56 | _loggers[name] = logging.getLogger(name)
57 | return _loggers[name]
58 |
59 |
60 | def print_to_string(*args, end="", **kwargs):
61 | with io.StringIO() as buf:
62 | print(*args, file=buf, end=end, **kwargs)
63 | return buf.getvalue()
64 |
65 |
66 | def debug(*args, **kwargs):
67 | _get_or_create_logger().debug(print_to_string(*args, **kwargs))
68 |
69 |
70 | def info(*args, **kwargs):
71 | _get_or_create_logger().info(print_to_string(*args, **kwargs))
72 |
73 |
74 | log = info
75 |
76 |
77 | def warning(*args, **kwargs):
78 | _get_or_create_logger().warning(print_to_string(*args, **kwargs))
79 |
80 |
81 | warn = warning
82 |
83 |
84 | def error(*args, **kwargs):
85 | _get_or_create_logger().error(print_to_string(*args, **kwargs))
86 |
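87 | 
88 | # Usage sketch (illustrative): logging configures itself lazily on first use, and
89 | # loggers are named after the calling module.
90 | #
91 | # import ezlog
92 | # ezlog.configure_logging() # optional; called automatically on the first log call
93 | # ezlog.info("solved", 5, "puzzles") # print-style arguments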
--------------------------------------------------------------------------------
/solvers/codex/lm_solve/__init__.py:
--------------------------------------------------------------------------------
1 | from lm_solve.run import *
2 |
--------------------------------------------------------------------------------
/solvers/codex/lm_solve/gpt_lib.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import openai
4 | import ezlog
5 | import time
6 | import datetime
7 |
8 | assert 'OPENAI_API_KEY' in os.environ, "Need to set environment variable `OPENAI_API_KEY`"
9 | openai.api_key = os.environ['OPENAI_API_KEY']
10 | OPEN_AI_ENGINE_SUFFIX = os.environ.get('OPEN_AI_ENGINE_SUFFIX', '') # add extension such as -msft to engine names
11 |
12 | _CACHE_PATH = os.path.join(os.path.dirname(__file__), "../.cache")
13 | _CACHE_ENCODING = "utf-8"
14 |
15 |
16 | # Each cache-file line holds one (key, values) pair: the key is the query-params dict (minus n) encoded
17 | # as a string, and the values are the completions. Queries differing only in n are merged into one list.
18 | class Cache:
19 | def __init__(self, filename):
20 | self.filename = filename
21 | self._cache = None
22 |
23 | def _load_cache(self):
24 | """for lazy loading"""
25 | assert self._cache is None, "gpt cache already loaded"
26 |
27 | if not os.path.exists(_CACHE_PATH):
28 | ezlog.warn("Creating cache path")
29 | os.makedirs(_CACHE_PATH)
30 |
31 | self._cache = {}
32 |
33 | if os.path.exists(self.filename):
34 | time0 = time.perf_counter()
35 | with open(self.filename, "r", encoding=_CACHE_ENCODING) as f:
36 | for k, v in [eval(line) for line in f.readlines()]:
37 | if k not in self._cache:
38 | self._cache[k] = v
39 | else:
40 | self._cache[k].extend(v)
41 | ezlog.info(f"Loaded cache `{self.filename}` in {time.perf_counter() - time0:.1f}s")
42 | else:
43 | ezlog.warn("No gpt cache yet")
44 |
45 | def defrag(self):
46 | if self._cache is None:
47 | self._load_cache()
48 |
49 | # def helper(k): # remove max_batch
50 | # k2 = eval(k)
51 | # del k2["max_batch"]
52 | # return str(k2)
53 |
54 | if self._cache:
55 | with open(self.filename, "w", encoding=_CACHE_ENCODING) as f:
56 | # f.write("\n".join([str((helper(k), v)) for k, v in self._cache.items()]+[""]))
57 | f.write("\n".join([str((k, v)) for k, v in self._cache.items()]+[""]))
58 | ezlog.info("Defragged cache")
59 | else:
60 | ezlog.warn("No cache to defrag")
61 |
62 |
63 | def get(self, item):
64 | if self._cache is None:
65 | self._load_cache()
66 |
67 | return self._cache.get(item, []).copy() # no monkey business changing cache
68 |
69 | def extend(self, key, values):
70 | if self._cache is None:
71 | self._load_cache()
72 |
73 | v = self._cache.setdefault(key, [])
74 | v.extend(values)
75 |
76 | with open(self.filename, "a", encoding=_CACHE_ENCODING) as f:
77 | f.write(str((key, values)) + "\n")
78 |
79 | return v.copy() # no monkey business changing cache
80 |
81 |
82 | BATCH_SIZES = {
83 | "davinci": 32,
84 | "davinci-codex": 128,
85 | "cushman-codex": 128
86 | }
87 |
88 | CACHES = {cache: Cache(os.path.join(_CACHE_PATH, cache + ".cache")) for cache in BATCH_SIZES}
89 |
90 | def query(prompt, n=10, max_tokens=150, temp=1.0, stop=None, notes=None, cache_only=False, verbose=True,
91 | max_retries=10, engine="cushman-codex"):
92 | """Query gpt
93 |
94 | :param prompt: Up to 2048 tokens (about 3-4k chars)
95 | :param n: number of answers, None returns all cached answers
96 | :param max_tokens:
97 | :param temp: 0.9 seems to work well
98 | :param stop: string to stop at or '' if not to stop
99 | :param notes: notes you want to save or change in case you want to run the same query more than once!
100 | :return: list of the n completion strings (or all cached completions if n is None)
101 | """
102 | global BATCH_SIZES
103 | global CACHES
104 | cur_cache = CACHES[engine]
105 | max_batch = BATCH_SIZES[engine]
106 | engine += OPEN_AI_ENGINE_SUFFIX # add tail to engine name
107 |
108 | if temp == 0 and n > 1:
109 | ezlog.debug("Temp 0: no point in running more than one query")
110 | n = 1
111 |
112 | key = str(dict(prompt=prompt, max_tokens=max_tokens, temp=temp, stop=stop, rep=notes))
113 |
114 | cached = cur_cache.get(key)
115 |
116 | if n is None:
117 | return cached
118 |
119 | if len(cached) >= n:
120 | return cached[:n]
121 |
122 | assert not cache_only, f'Entry not found in cache with prompt "{json.dumps(prompt)}"'
123 | if verbose:
124 | print("/" * 100)
125 | print(f"Querying GPT {engine} with prompt:")
126 | print(prompt)
127 | s = stop and stop.replace('\n', '\\n')
128 | print(f"/// n={n} ({n - len(cached)} new) max_tokens={max_tokens} temp={temp} max_batch={max_batch} stop={s}")
129 | print("/" * 100)
130 |
131 | time0 = time.perf_counter()
132 |
133 | new = []
134 | n -= len(cached)
135 |
136 | while n > 0:
137 | m = min(n, max_batch)
138 |
139 | try_number = 0
140 | while True:
141 | try:
142 | res = openai.Completion.create(
143 | engine=engine,
144 | prompt=prompt,
145 | max_tokens=max_tokens,
146 | temperature=temp,
147 | n=m,
148 | stop=stop or None
149 | )
150 | break
151 | except (openai.error.RateLimitError, openai.error.APIError):
152 | if try_number == max_retries:
153 | print("Rate limit error: Giving up!")
154 | raise
155 | sleep_secs = 10 * (2 ** try_number)
156 | try_number += 1
157 | print(f"Rate limit error #{try_number}: Sleeping for {sleep_secs} seconds...")
158 | time.sleep(sleep_secs)
159 |
160 | new += [c["text"] for c in res["choices"]]
161 | n -= m
162 |
163 | return cur_cache.extend(key, new)
164 |
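165 | 
166 | # Usage sketch (illustrative; requires OPENAI_API_KEY to be set): repeated calls
167 | # with the same parameters are answered from the on-disk cache, so re-runs are
168 | # fast and deterministic.
169 | #
170 | # completions = query("def g():\n    return ", n=4, max_tokens=20, temp=0.9, stop="\n")
171 | # cached = query("def g():\n    return ", n=None) # n=None returns all cached answers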
--------------------------------------------------------------------------------
/solvers/codex/lm_solve/scratch.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | # Ad-hoc scratch notes: inspect the completion cache files (character statistics, file-read and parse timings).
3 | def chars(filename, max_ord=300):
4 | with open(filename, "r", encoding="utf8") as f:
5 | counts = Counter(ord(c) for c in f.read())
6 | print(counts.most_common())
7 | print("max", max(counts))
8 | missing = [i for i in range(max_ord) if i not in counts]
9 | print("missing", missing)
10 |
11 | chars(".cache/davinci-codex.cache")
12 |
13 | import time
14 | time0 = time.perf_counter()
15 | with open(".cache/davinci-codex.cache", "r", encoding="utf8") as f:
16 | lines = f.readlines()
17 | time1 = time.perf_counter()
18 | print("duration", time1 - time0)
19 |
20 | time0 = time.perf_counter()
21 | with open(".cache/davinci-codex.cache", "r", encoding="utf8") as f:
22 | lines = f.readlines()
23 | time1 = time.perf_counter()
24 | print("duration", time1 - time0)
25 |
26 | import json
27 | time0 = time.perf_counter()
28 | elines = [json.loads(l) for l in lines]
29 | time1 = time.perf_counter()
30 | print("duration", time1 - time0)
31 |
32 | len(lines), len(set(lines))
33 | len(elines[0][0])
34 | list(eval(elines[0][0]))
35 |
--------------------------------------------------------------------------------
/solvers/codex/requirements.txt:
--------------------------------------------------------------------------------
1 | astor==0.8.1
2 | numpy==1.22.0
3 | openai==0.6.3
4 | tqdm==4.60.0
5 | transformers==4.30.0
6 | Pebble==4.6.1
7 |
8 | # we ran with Python version sys.version = '3.6.9 (default, Jan 26 2021, 15:33:00) \n[GCC 8.4.0]'
9 | # distro-info===0.18ubuntu0.18.04.1
10 |
--------------------------------------------------------------------------------
/solvers/codex/results/results_397_cushman_codex_1k_full.json.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/solvers/codex/results/results_397_cushman_codex_1k_full.json.gz
--------------------------------------------------------------------------------
/solvers/codex/run_codex_experiments.py:
--------------------------------------------------------------------------------
1 | """
2 | This script runs the codex experiments.
3 | For GPT-3 experiments see run_gpt3_experiments.py in https://github.com/microsoft/PythonProgrammingPuzzles/tree/v0.1
4 | It uses caching mechanisms so that if run twice with the same parameters, it will give exactly the same
5 | results, without querying the API again or re-judging the resulting solutions. Hence, the first
6 | time you run it, it will be slow, but subsequent runs will be fast. It runs the
7 | experiment SEEDS times (once by default), with different seeds, to get different results.
8 | """
9 |
10 | import lm_solve
11 | import utils
12 | import math
13 | import numpy as np
14 |
15 | OUTPUT_FILENAME = "results_30_cushman_codex_32.json"
16 | SEEDS = 1 # number of times to run it
17 | PARAMS = dict(
18 | temp=0.9,
19 | timeout=1.0, # seconds to judge
20 | n=32, # number of attempts per puzzle, usually 1000, or set small for a fast run
21 | filename="30puzzles.json", # set to 397puzzles.json for a run on full v0.2 dataset
22 | cache_only=False, # change this to True if you want to run a 2nd time without risking hitting API
23 | engine="cushman-codex", # FAST-CODEX: "cushman-codex" CODEX: "davinci-codex" GPT3: "davinci"
24 | )
25 |
26 | BOOTSTRAP_PARAMS = dict(
27 | temp=PARAMS["temp"],
28 | n=PARAMS["n"],
29 | timeout=PARAMS["timeout"],
30 | filename=PARAMS["filename"],
31 | cache_only=PARAMS["cache_only"],
32 | ppi=32, # puzzles per iteration
33 | engine=PARAMS["engine"],
34 | prefix="from typing import List\n\n",
35 | )
36 |
37 | PREFIX = """from typing import List
38 |
39 | def f1(s: str):
40 | return "Hello " + s == "Hello world"
41 |
42 | def g1():
43 | return "world"
44 |
45 | assert f1(g1())
46 |
47 | def f2(s: str):
48 | return "Hello " + s[::-1] == "Hello world"
49 |
50 | def g2():
51 | return "world"[::-1]
52 |
53 | assert f2(g2())
54 |
55 | def f3(x: List[int]):
56 | return len(x) == 2 and sum(x) == 3
57 |
58 | def g3():
59 | return [1, 2]
60 |
61 | assert f3(g3())
62 |
63 | def f4(s: List[str]):
64 | return len(set(s)) == 1000 and all((x.count("a") > x.count("b")) and ('b' in x) for x in s)
65 |
66 | def g4():
67 | return ["a"*(i+2)+"b" for i in range(1000)]
68 |
69 | assert f4(g4())
70 |
71 | def f5(n: int):
72 | return str(n * n).startswith("123456789")
73 |
74 | def g5():
75 | return int(int("123456789" + "0"*9) ** 0.5) + 1
76 |
77 | assert f5(g5())
78 |
79 | """ # trailing newlines important
80 |
81 |
82 | PREFIX_DOCSTR = '''from typing import List
83 |
84 | def f1(s: str):
85 | return "Hello " + s == "Hello world"
86 |
87 | def g1():
88 | """Find a string that when concatenated onto 'Hello ' gives 'Hello world'."""
89 | return "world"
90 |
91 | assert f1(g1())
92 |
93 | def f2(s: str):
94 | return "Hello " + s[::-1] == "Hello world"
95 |
96 | def g2():
97 | """Find a string that when reversed and concatenated onto 'Hello ' gives 'Hello world'."""
98 | return "world"[::-1]
99 |
100 | assert f2(g2())
101 |
102 | def f3(x: List[int]):
103 | return len(x) == 2 and sum(x) == 3
104 |
105 | def g3():
106 | """Find a list of two integers whose sum is 3."""
107 | return [1, 2]
108 |
109 | assert f3(g3())
110 |
111 | def f4(s: List[str]):
112 | return len(set(s)) == 1000 and all((x.count("a") > x.count("b")) and ('b' in x) for x in s)
113 |
114 | def g4():
115 | """Find a list of 1000 distinct strings which each have more 'a's than 'b's and at least one 'b'."""
116 | return ["a"*(i+2)+"b" for i in range(1000)]
117 |
118 | assert f4(g4())
119 |
120 | def f5(n: int):
121 | return str(n * n).startswith("123456789")
122 |
123 | def g5():
124 | """Find an integer whose perfect square begins with 123456789 in its decimal representation."""
125 | return int(int("123456789" + "0"*9) ** 0.5) + 1
126 |
127 | assert f5(g5())
128 |
129 | ''' # trailing newlines important
130 |
131 |
132 | def pass_at_k(k: int, successes: int, attempts: int):
133 | fail_prob = 1.0 # computes the unbiased estimator 1 - C(attempts - successes, k) / C(attempts, k)
134 | for i in range(k):
135 | fail_prob *= (attempts - successes)/attempts # gets right answer of 0 when attempts == successes
136 | attempts -= 1
137 | return 1.0 - fail_prob
138 |
139 |
140 |
141 | def run(seed=0):
142 | PARAMS_0 = {**PARAMS, "n": 0}
143 | BOOTSTRAP_PARAMS_0 = {**BOOTSTRAP_PARAMS, "n": 0}
144 | sols = [lm_solve.prompt_experiment(**PARAMS, experiment="short", prefix="", seed=seed),
145 | lm_solve.prompt_experiment(**PARAMS, experiment="med", prefix=PREFIX, remove_docstring=True, seed=seed),
146 | lm_solve.prompt_experiment(**PARAMS, experiment="long", prefix=PREFIX_DOCSTR, remove_docstring=False, seed=seed),
147 | ]
148 | num_puzzles = len(sols[0]["sat_sols"])
149 | assert all(len(s["sat_sols"]) == num_puzzles for s in sols)
150 |
151 | n = PARAMS["n"]
152 | ks = [1]
153 | while ks[-1] < n:
154 | ks += [ks[-1] * i for i in [10]] # for i in [2, 5, 10]]
155 | ks = [k for k in ks if k <= n]
156 | if ks[-1] != n:
157 | ks.append(n)
158 | for s in sols:
159 | s["pass@k"] = [
160 | (
161 | k,
162 | np.mean([pass_at_k(k, s_s["n_sols"], n) for s_s in s["sat_sols"]])
163 | )
164 | for k in ks]
165 |
166 | bootstrap = lm_solve.bootstrap(**BOOTSTRAP_PARAMS, seed=seed)
167 | bootstrap["pass@k"] = [(k, np.mean([s_s["failures"] < k for s_s in bootstrap["sat_sols"]])) for k in ks]
168 | sols.append(bootstrap)
169 |
170 | print(f"run={seed} ALL DONE!\n\n")
171 | print(f"run={seed} RESULTS " + "=" * 50)
172 | print()
173 |
174 | for s in sols:
175 | print(s["experiment"], "prefix:", s["prefix"].replace("\n", "\\n")[:250])
176 | print(" ", s["tot_solved"], "solved, pass@k", " ".join(f'{k} {p:.5f}' for k, p in s["pass@k"]))
177 |
178 | print(f"Pass at k [(k, {', '.join(s['experiment'] for s in sols)}) ...]")
179 | print(list(zip([k for k, _p in sols[0]["pass@k"]], *[[p for _k, p in s["pass@k"]] for s in sols])))
180 |
181 | return sols
182 |
183 | def main():
184 | res = [s for seed in range(SEEDS) for s in run(seed)]
185 |
186 | if OUTPUT_FILENAME:
187 | FULL_FILENAME = OUTPUT_FILENAME.replace(".json", "_full.json.gz")
188 | utils.save_json(res, FULL_FILENAME)
189 | for s in res:
190 | if "sat_sols" in s:
191 | for t in s["sat_sols"]:
192 | if "sol_counts" in t:
193 | if t["sol_counts"]:
194 | t["shortest_sol"] = min([s for s, c in t["sol_counts"]], key=len)
195 | t["longest_sol"] = max([s for s, c in t["sol_counts"]], key=len)
196 | t["common_sol"] = max(t["sol_counts"], key=lambda z: z[1])[0]
197 | del t["sol_counts"]
198 | utils.save_json(res, OUTPUT_FILENAME)
199 | print(f"saved results to {OUTPUT_FILENAME} and {FULL_FILENAME}")
200 |
201 |
202 | if __name__ == "__main__":
203 | main()
204 |
205 |
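206 | # Sanity check for pass_at_k (illustrative): with 32 attempts and 8 successes,
207 | # pass_at_k(1, 8, 32) == 0.25, and pass_at_k(32, 8, 32) == 1.0, since any 32 draws
208 | # out of 32 attempts must include a success.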
--------------------------------------------------------------------------------
/solvers/codex/utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import operator
3 | import os
4 | import re
5 | import json
6 | import time
7 | import logging
8 |
9 | def get_lambda_arg_name(lam):
10 | assert lam.startswith("lambda ")
11 | return lam[len("lambda "):lam.index(":")].strip()
12 |
13 |
14 | def stringify(const):
15 | if type(const) is str:
16 | return json.dumps(const)
17 | return str(const)
18 |
19 |
20 | def color_str(obj, code="\033[0;36m"):
21 | return code + str(obj) + '\033[0m'
22 |
23 |
24 | def prod(iterable): # like sum but product
25 | return functools.reduce(operator.mul, iterable, 1)
26 |
27 |
28 | def flatten(it):
29 | return (e for a in it for e in (flatten(a) if isinstance(a, (tuple, list)) else (a,)))
30 |
31 | def save_json(obj, filename, make_dirs_if_necessary=False, indent=2, **kwargs):
32 | """Saves compressed file if filename ends with '.gz'"""
33 | import json
34 | if make_dirs_if_necessary:
35 | os.makedirs(os.path.dirname(filename), exist_ok=True)
36 | if filename.endswith(".gz"):
37 | import gzip
38 | with gzip.open(filename, "wt") as f:
39 | return json.dump(obj, f, indent=indent, **kwargs)
40 | with open(filename, "w", encoding="utf8") as f:
41 | return json.dump(obj, f, indent=indent, **kwargs)
42 |
43 | def load_json(filename):
44 | """Loads compressed file if filename ends with '.gz'"""
45 | import json
46 | if filename.endswith(".gz"):
47 | import gzip
48 | with gzip.open(filename, "rt") as f:
49 | return json.load(f)
50 | with open(filename, "r", encoding="utf8") as f:
51 | return json.load(f)
52 |
53 |
54 | def viz_py(py):
55 | import astor, ast
56 | print(astor.dump_tree(ast.parse(py)))
57 |
58 |
59 | def dedup(li):
60 | seen = set()
61 | return [x for x in li if x not in seen and not seen.add(x)]
62 |
63 |
64 | def test_puzzle(f, x):
65 | """Checks if x is of the correct type and makes f return True (literally True, not an integer or whatever)
66 |
67 | :param f: Puzzle
68 | :param x: candidate answer
69 | :return:
70 | """
71 | answer_type = list(f.__annotations__.values())[0]
72 | if not type_check(x, answer_type):
73 | raise TypeError
74 | return f(x) is True
75 |
76 |
77 |
78 | def type_check(obj, typ):
79 | """
80 | check if obj is of type `typ` where `typ` is a `typing` module type annotation, eg List[int]
81 | The way we do this to be compatible across versions is we first convert the type to a string.
82 | """
83 |
84 | type_str = str(typ).replace("typing.", "")
85 | if type_str.startswith("<class '"): # e.g., str(int) == "<class 'int'>"
86 | type_str = type_str[8:-2]
87 | 
88 | # (The body below is a reconstruction matching the docstring above: base types are
89 | # compared directly and List[...] / Set[...] annotations are checked recursively.)
90 | t = {"str": str, "int": int, "float": float, "bool": bool}.get(type_str)
91 | if t is not None:
92 | return type(obj) is t
93 | assert type_str.endswith("]"), f"Unsupported type annotation `{type_str}`"
94 | inside = type_str[type_str.index("[") + 1:-1]
95 | if type_str.startswith("List["):
96 | return isinstance(obj, list) and all(type_check(elem, inside) for elem in obj)
97 | if type_str.startswith("Set["):
98 | return isinstance(obj, set) and all(type_check(elem, inside) for elem in obj)
99 | assert False, f"Unsupported type annotation `{type_str}`"
100 | 
101 | 
102 | def rename_src_var(old, new, src, count=0):
103 | """Replace whole-word occurrences of `old` with `new` in `src`, skipping the contents
104 | of docstrings; with count == 1, only the code before the first docstring is renamed.
105 | (Name, signature, and base case here are assumed; the recursive body below is original.)"""
106 | 
107 | def helper(src):
108 | if src.count('"""') < 2: # no complete docstring left: plain whole-word rename
109 | return re.sub(r"\b" + re.escape(old) + r"\b", new, src)
110 | 
111 | # Otherwise find the first complete docstring src[a:b], keep it unchanged, and
112 | # recurse on the code before and after it:
113 | a = src.index('"""')
114 | b = src.index('"""', a+1) + 3
115 | if count == 1:
116 | h = helper(src[:a])
117 | if h != src[:a]:
118 | return h + src[a:]
119 | return helper(src[:a]) + src[a:b] + helper(src[b:])
120 |
121 | return helper(src)
122 |
123 | logger = None
124 |
125 | def timeit(method):
126 | global logger
127 | if logger is None:
128 | logger = logging.getLogger(__name__)
129 | def timed(*args, **kw):
130 | tick = time.time()
131 | result = method(*args, **kw)
132 | tock = time.time()
133 | logger.debug(f'{method.__name__}: {tock - tick:.3f}s')
134 |
135 | return result
136 | return timed
137 |
138 |
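139 | 
140 | # Usage sketch (illustrative):
141 | #
142 | # from typing import List
143 | # def sat(x: List[int]): return len(x) == 2 and sum(x) == 3
144 | # test_puzzle(sat, [1, 2]) # True: correct annotated type, and sat returns literally True
145 | # dedup(["a", "b", "a"]) # -> ['a', 'b']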
--------------------------------------------------------------------------------
/solvers/enumerative/README.md:
--------------------------------------------------------------------------------
1 | # Enumerative puzzle solvers
2 |
3 | This folder contains the code for the enumerative models used in our Programming Puzzles paper.
4 | We used Python 3.8.0 and the libraries in the `requirements.txt` file.
5 | 
6 | On a Linux machine with Python 3.8.0 installed, the following commands will set up the environment:
7 | ```
8 | virtualenv -p /usr/bin/python3.8 env_solvers
9 | source env_solvers/bin/activate
10 | pip install -r requirements.txt
11 | ```
12 |
13 | ## Uniform solver
14 | ```
15 | bash run_uniform.sh
16 | ```
17 | This will run the uniform solver for a maximum of 10k trials per puzzle. This is required before training the other parameterized solvers.
18 |
19 | To run the uniform solver with 1M trials per puzzle, change the `max_n_progs` argument in the bash script.
20 |
21 | ## Bigram random forest solver
22 | ```
23 | bash run_bigram.sh
24 | ```
25 | This will first train a parameterized model with self-bootstrapping (the first iteration is based on the uniform solutions). The last command in the script trains a model without self-bootstrapping.
26 |
27 | ## Transformers solver
28 | ```
29 | bash download_pretrained_roberta.sh
30 | bash run_transformer.sh
31 | ```
32 | The first script will download the RoBERTa-Base model that we trained on Python code.
33 |
34 | The second script will first train a parameterized model with self-bootstrapping (the first iteration is based on the uniform solutions). The last command in the script trains a model without self-bootstrapping.
35 |
--------------------------------------------------------------------------------
/solvers/enumerative/challenges/__init__.py:
--------------------------------------------------------------------------------
1 | from challenges.challenge import *
2 | from challenges.solutions import *
3 |
4 | def contains_node(root, x_node):
5 | return root is x_node or (hasattr(root, "children") and any(contains_node(k, x_node) for k in root.children))
6 |
--------------------------------------------------------------------------------
/solvers/enumerative/challenges/challenge.py:
--------------------------------------------------------------------------------
1 | import utils
2 | import tython
3 | import logging
4 | from typing import List, Dict, Callable, Tuple, Generator, Set, Sequence
5 | from tython import Program, nt
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def extract_constants(prog) -> Dict:
11 | '''
12 | Extract all constants from program. Does not (yet) allow copying of comprehensions, e.g., '[i*i for i in range(10)]'
13 | '''
14 |
15 | from collections import defaultdict
16 | consts = defaultdict(list)
17 |
18 | def handle_args(args_node):
19 |
20 | if args_node.rule.name == 'cast:ARGS':
21 | handle_args(args_node.children[0])
22 | else:
23 | if len(args_node.children) >= 3 and args_node.children[1].nt == nt.TYPE:
24 | annotation_node = args_node.children[1]
25 | t = nt.type2nt(eval(annotation_node.src()))
26 | consts[t].append(args_node.children[0])
27 | if args_node.children and args_node.children[-1].nt in {nt.ARGS, nt.DEFAULT_ARGS}:
28 | handle_args(args_node.children[-1])
29 |
30 | def helper(node):
31 | if node.rule.name == 'def': # it's a function
32 | name_node, args_node, body_node = node.children
33 | if name_node.src() == 'sat':
34 | handle_args(args_node.children[-1]) # skip first arg for `def sat`
35 | else:
36 | handle_args(args_node)
37 | helper(body_node)
38 | return False
39 | elif node.nt in {nt.NAME}:
40 | return False
41 | elif node.nt in {nt.STMT}:
42 | for c in node.children:
43 | helper(c)
44 | return False
45 | if node.rule.name not in {"int-const", "str-const"} and not all([helper(c) for c in node.children]):
46 | return False
47 | if node.nt.isa(nt.LIST, nt.SET, nt.DICT, nt.TUPLE, nt.RANGE,
48 | nt.INT, nt.FLOAT, nt.BOOL, nt.STR):
49 | consts[node.nt].append(node)
50 | return True
51 |
52 | if prog is not None:
53 | helper(prog.tree)
54 |
55 | return dict(consts)
56 |
57 | #
58 | # q = Program("""
59 | # def sat(i: List[str], a=5):
60 | # return i==['5']
61 | # """)
62 | #
63 | # extract_constants(q)
64 | #
65 | #
66 | # %%
67 | class Solution():
68 | def __init__(self, string=None, prog=None, likelihood=None, time=None, count=None):
69 | self.string = string
70 | self.prog = prog
71 | self.likelihood = likelihood
72 | self.time = time
73 | self.count = count
74 |
75 |
76 | class SolverSolution(Solution):
77 | def __init__(self, string=None, prog=None, likelihood=None, time=None, count=None):
78 | super().__init__(string=string, prog=prog, likelihood=likelihood)
79 | self.time = time
80 | self.count = count
81 |
82 |
83 | def get_arg_type_str(sat_str):
84 | assert sat_str.startswith("def sat(") and ":" in sat_str
85 | depth = 0
86 | for i, c in enumerate(sat_str):
87 | if c == '[':
88 | depth += 1
89 | elif c == ']':
90 | depth -= 1
91 | elif c in ")," and depth == 0:
92 | return sat_str[sat_str.index(":") + 1:i].lstrip()
93 | assert False
94 |
95 |
96 | class Challenge():
97 | def __init__(self, challenge_config, max_ticks=100000000):
98 | self.name = challenge_config["name"]
99 | self.f_str = challenge_config["sat"]
100 | self.type_str = get_arg_type_str(challenge_config["sat"])
101 | self.type = eval(self.type_str)
102 | self.gold_solutions = []
103 | self.solver_solutions = []
104 | for sol in challenge_config["sols"]:
105 | self.gold_solutions.append(Solution(string=sol))
106 | if "sol_tries" in challenge_config:
107 | for i, x in enumerate(challenge_config["sol_tries"]):
108 | self.gold_solutions[i].count = x
109 |
110 | if "sol_time" in challenge_config:
111 | for i, x in enumerate(challenge_config["sol_time"]):
112 | self.gold_solutions[i].time = x
113 |
114 | self.solution_strs = challenge_config["sols"]
115 | self.max_ticks = max_ticks
116 |
117 | self._parse_challenge()
118 |
119 | def _parse_challenge(self):
120 | '''
121 | Converts the challenge string to a tython program.
122 | '''
123 | self.sol_kind = tython.nt.type2nt(self.type)
124 | self.prog = None
125 | self.f = None
126 | try:
127 | self.prog = tython.Program(
128 | self.f_str)
129 | self.f = self.prog.run(max_ticks=self.max_ticks)
130 | except Program.EvalException as e:
131 | logger.warning(f"Exception evaluating {self.name} '{self.f_str}': {e}")
132 | except Exception as e:
133 | logger.warning(f"Exception parsing {self.name} '{self.f_str}': {e}")
134 |
--------------------------------------------------------------------------------
/solvers/enumerative/challenges/solutions.py:
--------------------------------------------------------------------------------
1 | from typing import List, Set, Dict, Callable, Tuple
2 | import logging
3 | from challenges import extract_constants
4 |
5 | from tython import Program, TastNode, _RULES_BY_KIND, RULES, Rule, str2name
6 | from tython.rules import DEF_RULE
7 |
8 | logger = logging.getLogger(__name__)
9 | logger.setLevel(logging.INFO)
10 |
11 |
12 | def generatable_answer(q: Program, a: Program):
13 | def get_src(node):
14 | return Program(node).src(safe=False, simplify=False)
15 |
16 | consts = extract_constants(q)
17 | const_srcs = {k: [get_src(c) for c in consts[k]] for k in consts}
18 |
19 | def helper(anode):
20 | if anode.rule.name == "COPY":
21 | return get_src(anode.children[0]) in const_srcs[anode.rule.nt]
22 | return anode.rule.name != "literal" and all(helper(n) for n in anode.children)
23 |
24 | return helper(a.tree)
25 |
26 |
27 | def verify_solutions(challenges):
28 | '''
29 | Verify all provided solutions to the given challenges and store the parsed solution
30 | program in the challenge object.
31 | '''
32 | successes = 0
33 | all_correct = True
34 | for ch in challenges:
35 | verified_sols = []
36 | f = ch.prog.run(max_ticks=ch.max_ticks)['sat']
37 | args_node = ch.prog.tree.children[0].children[1].children[-1]
38 | for sol_p in ch.gold_solutions:
39 | s = sol_p.string
40 | if s == '':
41 | continue
42 | # Verify solution is correct.
43 | if True:
44 | assert s.startswith('def ')
45 | try:
46 | sol_prog = Program(s)
47 | except Exception as e:
48 | logger.error(f"Exception parsing solution for {ch.name} '{s}': {e}")
49 | continue
50 |
51 | # Inject the value assignments for the variables to the call of the sol func.
52 | p_body = sol_prog.tree.children[0].children[2]
53 | sol_prog.tree.children[0].children[1].children = args_node.children
54 | sol_prog = Program(TastNode(DEF_RULE, [str2name("sol"), args_node, p_body]))
55 |
56 | a_safe = sol_prog.src(simplify=False)
57 | x = sol_prog.run(max_ticks=ch.max_ticks)["sol"]()
58 | ch.prog.reset_clock()
59 |
60 | v = f(x)
61 | assert isinstance(v, bool)
62 |
63 | if not generatable_answer(ch.prog, sol_prog):
64 | logger.error(f'Challenge "{ch.name}" cannot be used to automatically generate solution "{s}"')
65 |
66 | # TODO
67 | # if type(y) != ch.type:
68 | # print(f'Challenge "{ch.name}" has wrong solution type: "{type(y)}"')
69 | # all_correct = False
70 | if v is not True: # checks both False and None
71 | logger.error(f'Challenge "{ch.name}" not satisfied by solution "{s}"')
72 | else:
73 | sol_p.prog = sol_prog
74 | verified_sols.append(sol_p)
75 | successes += 1
76 |
77 | ch.gold_solutions = verified_sols
78 |
79 | logger.info(
80 | f"Tython confirmed {successes:,} solutions to {len(challenges)} challenges."
81 | )
82 | return
83 |
--------------------------------------------------------------------------------
/solvers/enumerative/download_pretrained_roberta.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | # Linux commands to download our Roberta model pretrained on Python code.
4 | # Newer versions of huggingface transformers don't require this, but the rest of the code would need to be adjusted for them.
5 |
6 | set -ex
7 |
8 | mkdir tals
9 | mkdir tals/roberta_python
10 |
11 | cd tals/roberta_python
12 |
13 | wget https://huggingface.co/tals/roberta_python/resolve/main/config.json
14 | wget https://huggingface.co/tals/roberta_python/resolve/main/merges.txt
15 | wget https://huggingface.co/tals/roberta_python/resolve/main/pytorch_model.bin
16 | wget https://huggingface.co/tals/roberta_python/resolve/main/special_tokens_map.json
17 | wget https://huggingface.co/tals/roberta_python/resolve/main/tokenizer_config.json
18 | wget https://huggingface.co/tals/roberta_python/resolve/main/training_args.bin
19 | wget https://huggingface.co/tals/roberta_python/resolve/main/vocab.json
20 |
21 | cd ../..
22 |
--------------------------------------------------------------------------------
/solvers/enumerative/filter_outputs.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import os
4 |
5 |
6 | inp_file = sys.argv[1]
7 | if len(sys.argv) > 2:
8 | shift = int(sys.argv[2])
9 | else:
10 | shift = 0
11 | out_file = os.path.splitext(inp_file)[0]
12 |
13 |
14 | with open(inp_file, 'r') as f:
15 | data = json.load(f)
16 |
17 | thresholds = [100, 1000, 10000, 100000, 1000000]
18 | for t in thresholds:
19 | # t is the cutoff on the solver's number of tries
20 | out = []
21 | suc = 0
22 | for p in data:
23 | #if not p["name"].startswith("Study"):
24 | # continue
25 | if p["sols"][-1] != "" and p["sol_tries"][-1] + shift <= t:
26 | out.append(p)
27 | suc += 1
28 | else:
29 | out.append(dict(name=p["name"], sat=p["sat"], sols=[]))
30 |
31 | print(f"t={t}: solutions: {suc}/{len(out)}")
32 |
33 | with open(out_file + f"_{t}.json", "w") as fw:
34 | json.dump(out, fw, indent=4)
35 |
--------------------------------------------------------------------------------
/solvers/enumerative/models/__init__.py:
--------------------------------------------------------------------------------
1 | MODEL_REGISTRY = {}
2 | def RegisterModel(model_name):
3 | def decorator(m):
4 | MODEL_REGISTRY[model_name] = m
5 | return m
6 |
7 | return decorator
8 |
9 | from models.uniform import *
10 | #from models.bigram import *
11 | #from models.ml_bow_unigram import *
12 | from models.ml_bow_bigram import *
13 |
--------------------------------------------------------------------------------
/solvers/enumerative/models/ml_bow_bigram.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Callable, Tuple, Generator, Set, Sequence
2 | from collections import defaultdict
3 | import numpy as np
4 | from copy import deepcopy
5 | import sklearn
6 | import sklearn.linear_model
7 | import sklearn.ensemble
8 | import sklearn.neighbors
9 | import sklearn.tree
10 | from scipy import sparse
11 | from tqdm import tqdm
12 | import time
13 | import logging
14 |
15 | from tython import Program, TastNode, _RULES_BY_KIND, Rule, RULES, nt
16 | from models.model import CandidateGenerator, reachable_rules_by_kind
17 | from models import RegisterModel
18 | from challenges import extract_constants
19 |
20 | logger = logging.getLogger(__name__)
21 | logger.setLevel(logging.INFO)
22 |
23 |
24 | @RegisterModel("ml_bow_bigram")
25 | class BOWCondParentLRModel(CandidateGenerator):
26 | '''
27 | learn rule weights with logistic regression.
28 | Features are: BOW for question and parent rule + child ind
29 | '''
30 |
31 | def __init__(self, ml_model='lr') -> None:
32 | '''
33 | Learn model per kind.
34 | featurization: feature extraction method.
35 | '''
36 | super().__init__()
37 | self.vocabs = _RULES_BY_KIND
38 | self.ml_model = ml_model
39 |
40 | if ml_model == 'lr':
41 | self.sk_model = sklearn.linear_model.LogisticRegression(penalty='l2')
42 | elif ml_model == 'rf':
43 | self.sk_model = sklearn.ensemble.RandomForestClassifier()
44 | elif ml_model == 'knn':
45 | self.sk_model = sklearn.neighbors.KNeighborsClassifier()
46 | elif ml_model == 'dt':
47 | self.sk_model = sklearn.tree.DecisionTreeClassifier()
48 | else:
49 | raise Exception("Unknown ml model %s" % ml_model)
50 |
51 | def featurize_question(self, prog: Program):
52 | '''
53 | Convert program to features with bow (bag of words/ rules)
54 | '''
55 |
56 | qf = np.zeros(len(RULES), dtype=int)
57 |
58 | # Get all rules in prog.
59 |
60 | queue = [prog.tree]
61 | while queue:
62 | node = queue.pop()
63 | qf[node.rule.index] += 1
64 | for child in node.children:
65 | queue.append(child)
66 |
67 | return sparse.csr_matrix(qf)
68 |
69 | def add_parent_features(self, qf: np.array, parent_rule: int, child_num: int):
70 | # Assume max 10 children
71 | assert child_num < 10
72 | parent_feature = np.zeros(len(RULES) + 1 + 10, dtype=int)
73 | parent_feature[parent_rule] = 1
74 | parent_feature[child_num + len(RULES) + 1] = 1
75 |
76 | # features = np.concatenate((qf, parent_feature))
77 | parent_feature = sparse.csr_matrix(parent_feature)
78 | features = sparse.hstack([qf, parent_feature])
79 | return features
80 |
81 | def traverse_ans_tree(self, prog: Program):
82 | '''
83 | Collect all rules with their features.
84 | '''
85 | # Initiate with root, parent_ind=len(RULES)
86 | rules = [[prog.tree.rule.index, len(RULES), 0]]
87 | queue = [prog.tree]
88 | while queue:
89 | node = queue.pop()
90 | parent_ind = node.rule.index
91 | for num_child, child in enumerate(node.children):
92 | rules.append([child.rule.index, parent_ind, num_child])
93 | queue.append(child)
94 |
95 | return rules
96 |
97 | def learn(self, QAs: List[Tuple[Program, Program]]) -> None:
98 | '''
99 | Optimize the ML models with the given examples.
100 | QAs: list of QAs for building vocab.
101 | '''
102 | # xs = None
103 | xs = []
104 | targets = []
105 | logger.info("Creating features from puzzle solution pairs")
106 | for q, a in tqdm(QAs):
107 | qf = self.featurize_question(q)
108 | for (target_rule, parent_rule, child_num) in self.traverse_ans_tree(a):
109 | input_features = self.add_parent_features(qf, parent_rule, child_num)
110 |
111 | # if xs is None:
112 | # xs = sparse.csr_matrix(input_features)
113 | # else:
114 | # xs = sparse.vstack([xs, input_features])
115 | xs.append(sparse.csr_matrix(input_features))
116 | targets.append(target_rule)
117 |
118 | logger.info(f"Collected {len(targets)} rules")
119 | xs = sparse.vstack(xs)
120 | self.sk_model.fit(xs, targets)
121 |
122 | def get_candidates(self, q: Program) -> Dict:
123 | '''
124 | Predict candidates for each question
125 | '''
126 | st_time = time.time()
127 | qf = self.featurize_question(q)
128 |
129 | consts = extract_constants(q)
130 |
131 | # Our question is always a function that returns a bool.
132 | sol_type_annotation = q.tree.children[0].children[1].children[1]
133 | sol_kind = nt.type2nt(eval(sol_type_annotation.src()))
134 |
135 | rules_by_kind = _RULES_BY_KIND # zzz reachable_rules_by_kind(sol_kind, consts)
136 | rules_by_kind_sets = {k: {r.index for r in rules} for k, rules in rules_by_kind.items()}
137 |
138 | ans = {}
139 |
140 | times1 = []
141 | times2 = []
142 | for parent_kind in rules_by_kind:
143 | # None for root.
144 | for r in rules_by_kind[parent_kind] + ([None] if parent_kind == sol_kind else []):
145 | # Don't include parents that we won't reach anyway (a.k.a massive pruning).
146 | if r is not None and r.index not in self.sk_model.classes_:
147 | continue
148 | if r and r.name == "COPY":
149 | assert len(r.kids) == 1
150 | if r.nt in consts:
151 | p = 1/len(consts[r.nt])
152 | relevant_consts = [(p, n) for n in consts[r.nt]]
153 | else:
154 | relevant_consts = []
155 | ans[r] = [(relevant_consts, [])]
156 | continue
157 |
158 | ans[r] = []
159 | for child_num, kind in enumerate([sol_kind] if r is None else r.kids):
160 | tik2 = time.time()
161 | class_mask = np.array([c in rules_by_kind_sets[kind] for c in self.sk_model.classes_])
162 | assert child_num < 9
163 | # child_num = min(child_num, 9)
164 | if r is None:
165 | parent_rule = len(RULES)
166 | else:
167 | parent_rule = r.index
168 |
169 | input_features = self.add_parent_features(qf, parent_rule, child_num)
170 | tik1 = time.time()
171 | # TODO: create batches to reduce run time.
172 | rule_probs = self.sk_model.predict_proba(input_features.reshape(1, -1))[0]
173 | tok1 = time.time() - tik1
174 | times1.append(tok1)
175 |
176 | # Mask out predictions to rules of different kind
177 | sum_probs = np.sum(rule_probs * class_mask)
178 | if sum_probs > 0:
179 | rule_probs = rule_probs * class_mask / sum_probs
180 | rule_probs = [(rule_probs[i], RULES[self.sk_model.classes_[i]]) for i in
181 | rule_probs.nonzero()[0]]
182 | assert all(r.nt == kind for _, r in rule_probs)
183 | else:
184 | rule_probs = []
185 |
186 |
187 |
188 | ans[r].append(([], rule_probs))
189 | tok2 = time.time() - tik2
190 | times2.append(tok2)
191 |
192 | # print(np.mean(times1))
193 | # print(np.mean(times2))
194 | end_time = time.time() - st_time
195 | logger.debug("Get candidates took {:.2f}s".format(end_time))
196 |
197 | return ans
198 |
199 | def get_likelihood(self, q: TastNode, ans: Program) -> float:
200 | '''
201 | Get the probability of the given answer program for the question.
202 | '''
203 | ans_rules = self.traverse_ans_tree(ans)
204 | qf = self.featurize_question(q)
205 | rule_probs = []
206 | for r in ans_rules:
207 | target_rule_ind = r[0]
208 | parent_rule, child_num = r[1:]
209 | input_features = self.add_parent_features(qf, parent_rule, child_num)
210 | out_probs = self.sk_model.predict_proba(input_features.reshape(1, -1))[0]
211 | target_prob = out_probs[np.where(self.sk_model.classes_ == target_rule_ind)].item() \
212 | if target_rule_ind in self.sk_model.classes_ else 0
213 | rule_probs.append(target_prob)
214 |
215 | return np.prod(rule_probs)
216 |
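217 | 
218 | # Feature-layout sketch (for reference): each example given to the sklearn model is
219 | # [BOW counts of rule indices in the question : len(RULES)]
220 | # [one-hot parent rule : len(RULES) + 1 slots, the last reserved for the root]
221 | # [one-hot child position : 10 slots]
222 | # and the prediction target is the child's rule index.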
--------------------------------------------------------------------------------
/solvers/enumerative/models/model.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Callable, Tuple, Generator, Set, Sequence
2 |
3 | from tython import Program, TastNode, Rule, _RULES_BY_KIND
4 |
5 | def reachable_rules_by_kind(sol_kind, terminal_kinds) -> Set:
6 | '''
7 | Finds all kinds that could be in a tree with sol_kind as the root
8 | and terminal kinds as possible kinds for constants.
9 | '''
10 | # some kinds are unusable because there is no way to generate a complete tree
11 | # starting with them.
12 | completable = set(terminal_kinds) # force copy
13 | signatures = {k: {r.kids for r in _RULES_BY_KIND[k]} for k in _RULES_BY_KIND}
14 | uncompletable = set(signatures) - completable
15 | while True:
16 | completable.update({k for k in uncompletable
17 | if any(all(k in completable for k in s) for s in signatures[k])})
18 | if not completable.intersection(uncompletable):
19 | break
20 | uncompletable -= completable
21 | for k in signatures:
22 | signatures[k] = {s for s in signatures[k] if all(k2 in completable for k2 in s)}
23 | signatures = {k: s for k, s in signatures.items() if s}
24 | if sol_kind not in completable:
25 | return {}
26 |
27 | good_kinds = set()
28 | kind_queue = {sol_kind}
29 |
30 | while kind_queue:
31 | parent_kind = kind_queue.pop()
32 | good_kinds.add(parent_kind)
33 | kind_queue.update({k for sig in signatures[parent_kind]
34 | for k in sig if k not in good_kinds})
35 |
36 | good_rules_by_kind = {k: [r for r in _RULES_BY_KIND[k] if all(k2 in good_kinds for k2 in r.kids)]
37 | for k in _RULES_BY_KIND if k in good_kinds}
38 |
39 | return {k: v for k, v in good_rules_by_kind.items() if v}
40 |
41 | class CandidateGenerator:
42 | '''
43 | A model that generates candidate solutions for challenges.
44 | '''
45 |
46 | def __init__(self) -> None:
47 | pass
48 |
49 | def learn(self, QAs):
50 | return
51 |
52 | def get_candidates(self, q: TastNode) -> Dict[Rule,
53 | List[Tuple[List[Tuple[float, TastNode]], List[Tuple[float, Rule]]]]]:
54 | '''
55 | Get solution candidates for a question q.
56 | TODO: program instead of TastNode?
57 | '''
58 | pass
59 |
60 |
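61 | 
62 | # Return-shape sketch (see the uniform and ml_bow_bigram models for concrete
63 | # instances): each rule r (None standing for the root) maps to one entry per child
64 | # slot, pairing copyable constants and candidate rules with their probabilities:
65 | # {r: [([(p, const_node), ...], [(p, rule), ...]), ...]}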
--------------------------------------------------------------------------------
/solvers/enumerative/models/transformers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/PythonProgrammingPuzzles/017f488b8f6320b8220cda38e78f2f72a36eab20/solvers/enumerative/models/transformers/__init__.py
--------------------------------------------------------------------------------
/solvers/enumerative/models/transformers/generate_rule_embeddings.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import logging
4 | import csv
5 | import os
6 | import random
7 | import json
8 | import string
9 |
10 | import numpy as np
11 | import torch
12 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
13 | from tqdm import tqdm, trange
14 |
15 | from transformers import (
16 | WEIGHTS_NAME,
17 | AutoConfig,
18 | AutoModel,
19 | AutoTokenizer,
20 | )
21 |
22 | from dataset_processor import PuzzleProcessor, InputExample, InputFeatures, convert_examples_to_features
23 |
24 | from tython import Program, TastNode, _RULES_BY_KIND, Rule, RULES
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 | def set_seed(args):
29 | random.seed(args.seed)
30 | np.random.seed(args.seed)
31 | torch.manual_seed(args.seed)
32 | if args.n_gpu > 0:
33 | torch.cuda.manual_seed_all(args.seed)
34 |
35 | def save_embeds(file_path, embeds, vocab_file):
36 | n_tokens, emb_dim = embeds.shape
37 | with open(file_path, 'w') as f:
38 | writer = csv.writer(f, delimiter='\t')
39 | for i in range(n_tokens):
40 | writer.writerow(list(embeds[i,:]))
41 |
42 | with open(vocab_file, 'w') as f:
43 | writer = csv.writer(f, delimiter='\t')
44 | for i in range(n_tokens):
45 | desc = generate_rule_string(RULES[i])
46 | writer.writerow([desc])
47 |
48 | return
49 |
50 |
51 | # First attempt to get random values for rules.
52 |
53 | #thing_list = ['x', 'y', 'z', 'a', 'b', 'i', 'j', 'k']
54 | #bool_list = ['True', 'False']
55 | #
56 | #
57 | #kind_gens = {
58 | # 'int': lambda: random.choice([str(random.randint(-1000, 1000)), 'i', 'j']) ,
59 | # 'thing': lambda: random.choice(thing_list),
60 | # 'float' : lambda: round(random.random() * random.randint(-100,100), 3),
61 | # 'bool': lambda: random.choice(bool_list),
62 | # 'range': lambda: random.choice(bool_list),
63 | # 'str': lambda: '"{}"'.format(''.join([random.choice(string.ascii_uppercase + string.ascii_lowercase +string.digits) for _ in range(random.randint(1,10))])),
64 | # 'range': lambda: random.choice(['range({})'.format(random.randint(-100,100)),
65 | # 'range({},{})'.format(random.randint(-100,100), random.randint(-100,100)),
66 | # 'range({},{},{})'.format(random.randint(-100,100), random.randint(-100,100), random.randint(-10,10))]),
67 | # 'none': lambda: 'None',
68 | #}
69 | #
70 | #
71 | #special_gens = {
72 | # 'Callable': lambda: 'func({})',
73 | # 'Tuple': lambda: '({})',
74 | # 'Set': lambda: '{{}}',
75 | #}
76 | #
77 | #
78 | #def gen_random_kind(kind):
79 | # '''
80 | # Given a string or list of kind, returns a string with an instance of that kind
81 | # '''
82 | # if isinstance(kind, (tuple, list)):
83 | # var_instances = []
84 | # for k in kind[1:]:
85 | # var_instances.append(gen_random_kind(k))
86 | #
87 | # base_str = special_gens[kind[0]]()
88 | # combine_str = base_str.format(', '.join([str(i) for i in var_instances]))
89 | # return combine_str
90 | # else:
91 | # return kind_gens[kind]()
92 | #
93 | #
94 | #def open_kind(kind):
95 | # '''
96 | # return list of the names of kinds that create it.
97 | # '''
98 | # if isinstance(kind, (tuple, list)):
99 | # return [open_kind(k) for k in kind]
100 | # else:
101 | # return kind.__name__ if hasattr(kind, "__name__") else kind._name if hasattr(kind, "_name") else str(kind)
102 | #
103 | #
104 | #def generate_rule_strings(rule, max_num = 100):
105 | # children_kinds = []
106 | # for kind in rule.children:
107 | # children_kinds.append(open_kind(kind))
108 | #
109 | # rule_strings = []
110 | # for _ in range(max_num):
111 | # children_vals = []
112 | # for k in children_kinds:
113 | # children_vals.append(gen_random_kind(k))
114 | #
115 | # desc = rule.to_python(*children_vals)
116 | # if desc not in rule_strings:
117 | # rule_strings.append(desc)
118 | #
119 | # if 'Callable' in str(rule):
120 | # return rule_strings
121 |
122 | def generate_rule_string(rule):
123 | return rule.var_repr()
124 |
125 | def main():
126 | parser = argparse.ArgumentParser()
127 |
128 | # Required parameters
129 | parser.add_argument(
130 | "--model_name_or_path",
131 | default=None,
132 | type=str,
133 | required=True,
134 | help="Path to pretrained model or model identifier from huggingface.co/models",
135 | )
136 | parser.add_argument(
137 | "--output_dir",
138 | default=None,
139 | type=str,
140 | required=True,
141 | help="The output directory where the model predictions and checkpoints will be written.",
142 | )
143 |
144 | # Other parameters
145 | parser.add_argument(
146 | "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
147 | )
148 | parser.add_argument(
149 | "--tokenizer_name",
150 | default="",
151 | type=str,
152 | help="Pretrained tokenizer name or path if not the same as model_name",
153 | )
154 | parser.add_argument(
155 | "--cache_dir",
156 | default=None,
157 | type=str,
158 | help="Where do you want to store the pre-trained models downloaded from s3",
159 | )
160 | parser.add_argument(
161 | "--max_seq_length",
162 | default=128,
163 | type=int,
164 | help="The maximum total input sequence length after tokenization. Sequences longer "
165 | "than this will be truncated, sequences shorter will be padded.",
166 | )
167 | parser.add_argument(
168 | "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
169 | )
170 |
171 | parser.add_argument(
172 | "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
173 | )
174 |
175 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
176 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
177 |
178 | args = parser.parse_args()
179 |
180 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
181 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
182 | args.device = device
183 |
184 | # Setup logging
185 | logging.basicConfig(
186 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
187 | datefmt="%m/%d/%Y %H:%M:%S",
188 | level=logging.INFO,
189 | )
190 | logger.warning(
191 | "Process device: %s, n_gpu: %s",
192 | device,
193 | args.n_gpu,
194 | )
195 |
196 | # Set seed
197 | set_seed(args)
198 |
199 | # Load pretrained model and tokenizer
200 | config = AutoConfig.from_pretrained(
201 | args.config_name if args.config_name else args.model_name_or_path,
202 | cache_dir=args.cache_dir,
203 | )
204 | args.model_type = config.model_type
205 | tokenizer = AutoTokenizer.from_pretrained(
206 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
207 | do_lower_case=args.do_lower_case,
208 | cache_dir=args.cache_dir,
209 | )
210 | model = AutoModel.from_pretrained(
211 | args.model_name_or_path,
212 | from_tf=bool(".ckpt" in args.model_name_or_path),
213 | config=config,
214 | cache_dir=args.cache_dir,
215 | )
216 |
217 | model.to(args.device)
218 | logger.info("Generating embeddings for %s rules", len(RULES))
219 |
220 | examples = []
221 | for i, rule in enumerate(RULES):
222 | desc = generate_rule_string(rule)
223 | examples.append(
224 | InputExample(guid=i,
225 | puzzle_str=desc,
226 | parent_ind=None,
227 | child_num=None,
228 | label=None))
229 |
230 | features = convert_examples_to_features(examples, tokenizer, max_length=args.max_seq_length)
231 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
232 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.int)
233 | dataset = TensorDataset(all_input_ids, all_attention_mask)
234 |
235 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
236 | eval_sampler = SequentialSampler(dataset)
237 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
238 |
239 | hidden_size = model.config.hidden_size
240 | # multi-gpu eval
241 | if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
242 | model = torch.nn.DataParallel(model)
243 |
244 | logger.info("***** Running evaluation *****")
245 | logger.info(" Num examples = %d", len(dataset))
246 | logger.info(" Batch size = %d", args.eval_batch_size)
247 | rule_embeddings = np.zeros((len(RULES), hidden_size))
248 | for b, batch in tqdm(enumerate(eval_dataloader), desc="Evaluating", total=(len(dataset) + args.eval_batch_size - 1) // args.eval_batch_size):
249 | model.eval()
250 | batch = tuple(t.to(args.device) for t in batch)
251 |
252 | with torch.no_grad():
253 | mask = batch[1]
254 | inputs = {"input_ids": batch[0], "attention_mask": mask}
255 | outputs, _ = model(**inputs)
256 |
257 | # Get Average of token representations
258 | mask_ex = mask.unsqueeze(-1).expand_as(outputs)
259 | y = (outputs * mask_ex).sum(1)
260 | y = y / mask_ex.sum(1)
261 |
262 | y = y.detach().cpu().numpy()
263 | for i, emb in enumerate(y):
264 | rule_embeddings[b * args.eval_batch_size + i, :] = emb
265 |
266 | os.makedirs(args.output_dir, exist_ok=True)
267 | out_file = os.path.join(args.output_dir, "rule_embeds.tsv")
268 | vocab_file = os.path.join(args.output_dir, "rule_embeds_vocab.tsv")
269 | save_embeds(out_file, rule_embeddings, vocab_file)
270 | logger.info("Saved embeddings to %s", out_file)
271 |
272 | if __name__ == "__main__":
273 | main()
274 |
--------------------------------------------------------------------------------
/solvers/enumerative/models/transformers/learn_tokenizer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | from os.path import join
5 | from tokenizers import ByteLevelBPETokenizer
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument(
9 | "--files",
10 | metavar="path",
11 | type=str,
12 | default="python_data/train.txt",
13 | help="The files to use as training; accept '**/*.txt' type of patterns \
14 | if enclosed in quotes",
15 | )
16 | parser.add_argument(
17 | "--out",
18 | default="trained_models/roberta_python_tokenizer",
19 | type=str,
20 | help="Path to the output directory, where the files will be saved",
21 | )
22 | parser.add_argument(
23 | "--name", default="bpe-bytelevel", type=str, help="The name of the output vocab files"
24 | )
25 | args = parser.parse_args()
26 |
27 | files = glob.glob(args.files)
28 | if not files:
29 | print(f"File does not exist: {args.files}")
30 | exit(1)
31 |
32 |
33 | # Initialize an empty tokenizer
34 | tokenizer = ByteLevelBPETokenizer(add_prefix_space=True)
35 |
36 | # And then train
37 | tokenizer.train(
38 | files,
39 | vocab_size=50265,
40 | min_frequency=2,
41 | show_progress=True,
42 | special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
43 | )
44 |
45 | # Save the files
46 | os.makedirs(args.out, exist_ok=True)
47 | tokenizer.save_model(args.out, args.name)
48 |
49 | # Restoring model from learned vocab/merges
50 | tokenizer = ByteLevelBPETokenizer(
51 | join(args.out, "{}-vocab.json".format(args.name)),
52 | join(args.out, "{}-merges.txt".format(args.name)),
53 | add_prefix_space=True,
54 | )
55 |
56 | # Test encoding
57 | print(tokenizer.encode("Training ByteLevel BPE is very easy").tokens)
58 |
--------------------------------------------------------------------------------
/solvers/enumerative/models/transformers/preprocess_pretraining_data.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import gzip
3 | import json
4 | import random
5 | from tqdm import tqdm
6 |
7 | for split in ['train', 'valid', 'test']:
8 | data_path = 'python_data/{}/*.jsonl.gz'.format(split)
9 | #output_file = 'python_data/{}.jsonl'.format(split)
10 | output_file = 'python_data/{}.txt'.format(split)
11 |
12 |     if False:  # disabled variant: write (code, docstring) pairs as JSON lines
13 | of = open(output_file, 'w')
14 |
15 | for file_name in glob.glob(data_path):
16 | with gzip.open(file_name, 'rb') as f:
17 | for line in tqdm(f, desc=file_name):
18 | example = json.loads(line)
19 | code = example['code']
20 | doc = example['docstring']
21 | assert doc in code
22 | new_c = code[:code.index(doc)] + code[code.index(doc) + len(doc):]
23 | #new_c = new_c.replace('\n', '\\n')
24 | #doc = doc.replace('\n', '\\n')
25 | if random.random() < 0.5:
26 | inp = new_c
27 | out = doc
28 | else:
29 | inp = doc
30 | out = new_c
31 |
32 | of.write(json.dumps(dict(inp=inp, out=out)) + '\n')
33 |
34 | of.close()
35 |
36 |     if True:  # enabled variant: write raw source strings, one per line (newlines escaped)
37 | of = open(output_file, 'w')
38 |
39 | for file_name in glob.glob(data_path):
40 | with gzip.open(file_name, 'rb') as f:
41 | for line in tqdm(f, desc=file_name):
42 | example = json.loads(line)
43 | of.write(example['original_string'].replace('\n', '\\n'))
44 | of.write('\n')
45 |
46 | of.close()
47 |
48 |     if False:  # disabled variant: write parallel .source/.target files
49 | out_splt = split if split != 'valid' else 'val'
50 | src = 'python_data/{}.source'.format(out_splt)
51 | tgt = 'python_data/{}.target'.format(out_splt)
52 | of_src = open(src, 'w')
53 | of_tgt = open(tgt, 'w')
54 | for file_name in glob.glob(data_path):
55 | with gzip.open(file_name, 'rb') as f:
56 | for line in tqdm(f, desc=file_name):
57 | example = json.loads(line)
58 | code = example['code']
59 | doc = example['docstring']
60 | assert doc in code
61 | new_c = code[:code.index(doc)] + code[code.index(doc) + len(doc):]
62 | new_c = new_c.replace('\n', '\\n')
63 | doc = doc.replace('\n', '\\n')
64 | if random.random() < 0.5:
65 | of_src.write(new_c)
66 | of_src.write('\n')
67 | of_tgt.write(doc)
68 | of_tgt.write('\n')
69 | else:
70 | of_src.write(doc)
71 | of_src.write('\n')
72 | of_tgt.write(new_c)
73 | of_tgt.write('\n')
74 | of_src.close()
75 | of_tgt.close()
76 |
--------------------------------------------------------------------------------
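To make the transformation in `preprocess_pretraining_data.py` concrete, here is a toy illustration of the docstring-removal step (hypothetical input; like the script above, it leaves the empty triple-quote pair behind):

```python
code = 'def add(a, b):\n    """Add two numbers."""\n    return a + b\n'
doc = 'Add two numbers.'

assert doc in code
new_c = code[:code.index(doc)] + code[code.index(doc) + len(doc):]
print(new_c)
# def add(a, b):
#     """"""
#     return a + b
```
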
/solvers/enumerative/models/uniform.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Callable, Tuple, Generator, Set, Sequence
2 | import functools
3 | import operator
4 | from collections import defaultdict
5 | import random
6 |
7 | from tython import Program, TastNode, _RULES_BY_KIND, Rule, nt
8 | from models.model import CandidateGenerator, reachable_rules_by_kind
9 | from models import RegisterModel
10 | from challenges import extract_constants
11 |
12 |
13 | def prod(iterable): # like sum but product
14 | return functools.reduce(operator.mul, iterable, 1)
15 |
16 |
17 | @RegisterModel("uniform")
18 | class UniformModel(CandidateGenerator):
19 | '''
20 | Uniformly samples from all rules.
21 | '''
22 |
23 | def __init__(self, copy_prob=0.5) -> None:
24 | super().__init__()
25 | self.copy_prob = copy_prob
26 | # self.intended_depth = intended_depth
27 | #self._random_cache = defaultdict(lambda: defaultdict(dict))
28 |         self._random_cache = {}  # defaultdict is not picklable (needed for multiprocessing)
29 |
30 | def random(self, kind, max_depth=5, nodes_by_kind=None):
31 | if nodes_by_kind is None:
32 | nodes_by_kind = {}
33 | key = sum(hash(n) for n in nodes_by_kind)
34 | if key not in self._random_cache:
35 | self._random_cache[key] = {}
36 | cache = self._random_cache[key]
37 |
38 | # cache[kind][depth] is a list of available rules
39 |
40 | def available_rules(kind, depth):
41 | if depth <= 0:
42 | return []
43 | if kind in cache and depth in cache[kind]:
44 | return cache[kind][depth]
45 | rules = [r for r in _RULES_BY_KIND[kind] if
46 | all([nodes_by_kind.get(k) or available_rules(k, depth - 1) for k in r.kids])]
47 | if kind not in cache:
48 | cache[kind] = {}
49 | cache[kind][depth] = rules
50 | return rules
51 |
52 | def helper(kind, depth):
53 | assert depth >= 0
54 | rules = available_rules(kind, depth)
55 | assert rules or nodes_by_kind.get(kind), f"Cannot generate random {kind} of depth <= {depth}"
56 |
57 | if nodes_by_kind.get(kind) and (not rules or random.random() < self.copy_prob):
58 | return random.choice(nodes_by_kind[kind])
59 |
60 | rule = random.choice(rules)
61 |
62 | return TastNode(rule, [helper(k, depth - 1) for k in rule.kids])
63 |
64 | return Program(helper(kind, max_depth))
65 |
66 | def get_candidates_by_nodes(self, kind, nodes_by_kind):
67 |         rules_by_kind = _RULES_BY_KIND  # TODO: restrict to reachable_rules_by_kind(kind, nodes_by_kind)
68 |
69 | if kind not in rules_by_kind:
70 | return {}
71 |
72 | # p_kind_rules = {k: (1 - self.copy_prob if nodes_by_kind.get(k) else 1) / max(1, len(rules_by_kind[k]))
73 | # for k in rules_by_kind}
74 |
75 | by_kind = {}
76 |
77 | for parent_kind in rules_by_kind:
78 | rules = rules_by_kind[parent_kind]
79 | has_copy = sum(r.name == "COPY" for r in rules)
80 | if has_copy:
81 | assert has_copy == 1
82 | by_kind[parent_kind] = [], [
83 | (self.copy_prob if r.name == "COPY" else (1 - self.copy_prob) / len(rules), r)
84 | for r in rules]
85 | else:
86 | by_kind[parent_kind] = [], [(1 / len(rules), r) for r in rules]
87 |
88 | ans = {r: [by_kind[k] for k in r.kids]
89 | for rules in rules_by_kind.values() for r in rules if r.name != "COPY"}
90 | ans.update({r: [([(1. / len(nodes_by_kind[r.nt]), n) for n in nodes_by_kind.get(r.nt, [])], [])]
91 | for rules in rules_by_kind.values() for r in rules if r.name == "COPY"})
92 | ans[None] = [by_kind[kind]]
93 |
94 | return ans
95 |
96 | def get_candidates(self, q: Program) -> Dict[Rule,
97 | List[Tuple[List[Tuple[float, TastNode]], List[Tuple[float, Rule]]]]]:
98 |
99 | consts = extract_constants(q)
100 |
101 | sol_type_annotation = q.tree.children[0].children[1].children[1]
102 | sol_kind = nt.type2nt(eval(sol_type_annotation.src()))
103 |
104 | return self.get_candidates_by_nodes(sol_kind, consts)
105 |
106 |
107 | if __name__ == "__main__":
108 | u = UniformModel()
109 | random.seed(0)
110 | for _ in range(100):
111 | p = u.random((List, int), 5)
112 | print(p.src(safe=False))
113 | try:
114 | p.val(max_ticks=1000)
115 | except Program.EvalException:
116 | pass
117 |
--------------------------------------------------------------------------------
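The recursion in `UniformModel.random` above can be illustrated on a toy grammar. This is a minimal standalone sketch (hypothetical rules, not the actual tython grammar): a rule is available at a given depth only if every child kind can still bottom out within the remaining depth, and the sampler chooses uniformly among the available rules.

```python
import random

RULES = {  # kind -> list of (rule name, child kinds)
    "expr": [("lit", []), ("neg", ["expr"]), ("add", ["expr", "expr"])],
}

def available(kind, depth):
    """Rules for `kind` that can complete a tree within the remaining depth."""
    if depth <= 0:
        return []
    return [r for r in RULES[kind] if all(available(k, depth - 1) for k in r[1])]

def sample(kind, depth):
    name, kids = random.choice(available(kind, depth))
    return (name, [sample(k, depth - 1) for k in kids])

random.seed(0)
print(sample("expr", 3))  # a random expression tree of depth at most 3
```
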
/solvers/enumerative/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | orderedset
3 | numpy
4 | astor
5 | scikit-learn
6 | torch==1.13.1
7 | transformers==4.30.0
8 | tensorboardX
9 |
--------------------------------------------------------------------------------
/solvers/enumerative/run_bigram.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | set -ex
4 |
5 | # Bootstrapping process, starting from the uniform model.
6 | python solve_challenges.py \
7 | --challenges_path "results/uniform/out.json" \
8 | --eval_challenges_path "results/uniform/out.json" \
9 | -m ml_bow_bigram \
10 | --ml_model rf \
11 | --copy_prob 0.5 \
12 | --max_n_progs 10000 \
13 | --timeout_secs 3600 \
14 | --threads 40 \
15 | --learn_from_train_sols \
16 | --logging_dir results/bootstrap/ml_bigram_0 \
17 | --out_file results/bootstrap/ml_bigram_0/out.json
18 |
19 |
20 | for i in {1..5}; do
21 | python solve_challenges.py \
22 | --challenges_path "results/bootstrap/ml_bigram_$(($i-1))/out.json" \
23 | --eval_challenges_path "results/bootstrap/ml_bigram_$(($i-1))/out.json" \
24 | -m ml_bow_bigram \
25 | --ml_model rf \
26 | --copy_prob 0.5 \
27 | --max_n_progs 10000 \
28 | --timeout_secs 3600 \
29 | --threads 40 \
30 | --learn_from_train_sols \
31 | --logging_dir results/bootstrap/ml_bigram_$i \
32 | --out_file results/bootstrap/ml_bigram_$i/out.json
33 | done
34 |
35 | # Final run, up to 1M programs.
36 | python solve_challenges.py \
37 | --challenges_path "results/bootstrap/ml_bigram_5/out.json" \
38 | --eval_challenges_path "results/bootstrap/ml_bigram_5/out.json" \
39 | -m ml_bow_bigram \
40 | --ml_model rf \
41 | --copy_prob 0.5 \
42 | --max_n_progs 1000000 \
43 | --timeout_secs 3600 \
44 | --threads 20 \
45 | --learn_from_train_sols \
46 | --logging_dir results/bootstrap/ml_bigram_6 \
47 | --out_file results/bootstrap/ml_bigram_6/out.json
48 |
49 |
50 | # Run without self-bootstrapping (training only on uniform-model solutions), up to 1M programs.
51 | python solve_challenges.py \
52 | --challenges_path "results/uniform/out.json" \
53 | --eval_challenges_path "results/uniform/out.json" \
54 | -m ml_bow_bigram \
55 | --ml_model rf \
56 | --copy_prob 0.5 \
57 | --max_n_progs 1000000 \
58 | --timeout_secs 3600 \
59 | --threads 40 \
60 | --learn_from_train_sols \
61 | --logging_dir results/ml_bigram \
62 | --out_file results/ml_bigram/out.json
63 |
64 |
--------------------------------------------------------------------------------
/solvers/enumerative/run_transformer.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | set -ex
4 |
5 | # First, extract rule embeddings.
6 | PYTHONPATH=./ python models/transformers/generate_rule_embeddings.py \
7 | --model_name_or_path tals/roberta_python \
8 | --output_dir results/roberta_rule_embeddings
9 |
10 | # Bootstrapping process, starting from the uniform model.
11 | PYTHONPATH=./ python models/transformers/finetune_transformer.py \
12 | --challenges_path results/uniform/out.json \
13 | --eval_challenges_path results/uniform/out.json \
14 | --model_name_or_path tals/roberta_python \
15 | --output_dir results/bootstrap/roberta_0 \
16 | --num_train_epochs 20 \
17 | --do_train \
18 | --do_infer \
19 | --max_n_progs 10000 \
20 | --timeout_secs 3600 \
21 | --threads 40 \
22 | --per_gpu_eval_batch_size 128 \
23 | --per_gpu_train_batch_size 16 \
24 | --rule_emb_dir results/roberta_rule_embeddings \
25 | --overwrite_cache \
26 | --max_ticks 10000
27 |
28 | for i in {1..5}; do
29 | PYTHONPATH=./ python models/transformers/finetune_transformer.py \
30 | --challenges_path "results/bootstrap/roberta_$(($i-1))/solutions.json" \
31 | --eval_challenges_path "results/bootstrap/roberta_$(($i-1))/solutions.json" \
32 | --model_name_or_path tals/roberta_python \
33 | --output_dir results/bootstrap/roberta_$i \
34 | --num_train_epochs 20 \
35 | --do_train \
36 | --do_infer \
37 | --max_n_progs 10000 \
38 | --timeout_secs 3600 \
39 | --threads 40 \
40 | --per_gpu_eval_batch_size 128 \
41 | --per_gpu_train_batch_size 16 \
42 | --rule_emb_dir results/roberta_rule_embeddings \
43 | --overwrite_cache \
44 | --max_ticks 10000
45 | done
46 |
47 | # Final run, up to 1M programs.
48 | PYTHONPATH=./ python models/transformers/finetune_transformer.py \
49 | --challenges_path "results/bootstrap/roberta_5/solutions.json" \
50 | --eval_challenges_path "results/bootstrap/roberta_5/solutions.json" \
51 | --model_name_or_path tals/roberta_python \
52 | --output_dir results/bootstrap/roberta_6 \
53 | --num_train_epochs 20 \
54 | --do_train \
55 | --do_infer \
56 | --max_n_progs 1000000 \
57 | --timeout_secs 3600 \
58 | --threads 40 \
59 | --per_gpu_eval_batch_size 128 \
60 | --per_gpu_train_batch_size 16 \
61 | --rule_emb_dir results/roberta_rule_embeddings \
62 | --overwrite_cache \
63 | --max_ticks 10000
64 |
65 |
66 | # Run without self-bootstrapping (training only on uniform-model solutions), up to 1M programs.
67 | PYTHONPATH=./ python models/transformers/finetune_transformer.py \
68 | --challenges_path results/uniform/out.json \
69 | --eval_challenges_path results/uniform/out.json \
70 | --model_name_or_path tals/roberta_python \
71 | --output_dir results/roberta \
72 | --num_train_epochs 20 \
73 | --do_train \
74 | --do_infer \
75 | --max_n_progs 1000000 \
76 | --timeout_secs 3600 \
77 | --threads 40 \
78 | --per_gpu_eval_batch_size 128 \
79 | --per_gpu_train_batch_size 16 \
80 | --rule_emb_dir results/roberta_rule_embeddings \
81 | --overwrite_cache \
82 | --max_ticks 10000
83 |
--------------------------------------------------------------------------------
/solvers/enumerative/run_uniform.sh:
--------------------------------------------------------------------------------
1 | python solve_challenges.py \
2 | -p "../../problems/*.json" \
3 | -m uniform \
4 | --solve_uniform \
5 | --copy_prob 0.5 \
6 | --max_n_progs 10000 \
7 | --timeout_secs 3600 \
8 | --threads 40 \
9 | --logging_dir results/uniform \
10 | --out_file results/uniform/out.json
11 |
--------------------------------------------------------------------------------
/solvers/enumerative/tython/__init__.py:
--------------------------------------------------------------------------------
1 | from .program import *
2 | from . import nonterminals as nt
3 | from .rules import Rule, _rules_by_nt as _RULES_BY_KIND, rules as RULES
4 | from .parse import str2name
5 |
--------------------------------------------------------------------------------
/solvers/enumerative/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from utils.str_utils import *
2 | from utils.time_utils import *
3 |
4 | import functools
5 | import operator
6 |
7 |
8 | def prod(iterable): # like sum but product
9 | return functools.reduce(operator.mul, iterable, 1)
10 |
11 |
12 | def flatten(it):
13 | return (e for a in it for e in (flatten(a) if isinstance(a, (tuple, list)) else (a,)))
14 |
15 |
16 | def load_json(filename):
17 | import json
18 | with open(filename, "r") as f:
19 | return json.load(f)
20 |
21 |
22 | def viz_py(py):
23 | import astor, ast
24 | print(astor.dump_tree(ast.parse(py)))
25 |
26 |
27 | def dedup(li):
28 | seen = set()
29 | return [x for x in li if x not in seen and not seen.add(x)]
30 |
--------------------------------------------------------------------------------
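A few hypothetical calls to the helpers above (assuming the package root is on `PYTHONPATH`, as in the run scripts); note that `dedup` keeps first occurrences and preserves order:

```python
from utils import dedup, flatten, prod

print(dedup([3, 1, 3, 2, 1]))            # [3, 1, 2]
print(list(flatten([1, [2, [3, 4]]])))   # [1, 2, 3, 4]
print(prod([2, 3, 4]))                   # 24
```
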
/solvers/enumerative/utils/str_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | def get_lambda_arg_name(lam):
4 | assert lam.startswith("lambda ")
5 | return lam[len("lambda "):lam.index(":")].strip()
6 |
7 |
8 | def stringify(const):
9 | if type(const) is str:
10 | return json.dumps(const)
11 | return str(const)
12 |
13 |
14 | def color_str(obj, code="\033[0;36m"):
15 | return code + str(obj) + '\033[0m'
16 |
17 |
--------------------------------------------------------------------------------
/solvers/enumerative/utils/time_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 |
4 | logger = logging.getLogger(__name__)
5 |
6 | def timeit(method):
7 | def timed(*args, **kw):
8 | tick = time.time()
9 | result = method(*args, **kw)
10 | tock = time.time()
11 | logger.debug(f'{method.__name__}: {tock - tick:.3f}s')
12 |
13 | return result
14 | return timed
15 |
--------------------------------------------------------------------------------
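A hypothetical use of the `timeit` decorator above (again assuming the package root is on `PYTHONPATH`); timings go to the module logger at DEBUG level, so logging must be configured to see them:

```python
import logging
from utils.time_utils import timeit

logging.basicConfig(level=logging.DEBUG)

@timeit
def slow_sum(n):
    return sum(range(n))

slow_sum(10**6)  # logs a line like "slow_sum: 0.031s"
```
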
/solvers/gpt3/README.md:
--------------------------------------------------------------------------------
1 | # Running GPT-3 experiments
2 |
3 | These are instructions for re-running the GPT-3 experiments. The results will differ slightly from those in
4 | the paper because the API is non-deterministic.
5 |
6 | `run_gpt3_experiments.py` runs the GPT-3 experiments and prints the results to stdout.
7 |
8 | ## Installation and execution
9 | You will need an OpenAI GPT-3 access key, which you can sign up for [here](https://openai.com/join/).
10 | You will then need to set it as the `OPENAI_API_KEY` environment variable.
11 |
12 | The requirements can be installed with `pip3 install -r requirements.txt`.
13 |
14 | The experiments were run with Python 3.6.9, sys.version = '3.6.9 (default, Jan 26 2021, 15:33:00) \n[GCC 8.4.0]',
15 | but they should be compatible with later versions as well.
16 |
17 | Then simply run `python run_gpt3_experiments.py`; the results are written to stdout. The script caches both
18 | API responses and judging results, so the first run is slow and verbose while it queries the API, but
19 | subsequent runs are much faster and just print the results. The caching also makes re-runs deterministic:
20 | running the script again gives exactly the same results.
--------------------------------------------------------------------------------
/solvers/gpt3/create_simple_prompts.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 |
4 | from run_gpt3_experiments import PARAMS, PREFIX, PREFIX_DOCSTR
5 | import lm_solve
6 |
7 | OUTFILENAME = "puzzles_with_prompts.json"
8 |
9 |
10 | def run(outfilename=OUTFILENAME, filename=PARAMS["filename"]):
11 | with open(filename, "r") as f:
12 | entries = json.load(f)
13 |
14 | for entry in entries:
15 | entry["prompts"] = {}
16 |
17 | for mode in ["short", "medium", "long"]:
18 | prefix = {
19 | "short": "",
20 | "medium": PREFIX,
21 | "long": PREFIX_DOCSTR
22 | }[mode]
23 | prefix = re.sub(r" +$", "", (prefix or "").lstrip(),
24 | flags=re.M) # delete leading/trailing whitespace on each line
25 | puzzles = lm_solve.load_puzzles(filename, mode == "long")
26 | prompts = lm_solve.get_prompts(prefix, [f for ft, f in puzzles])
27 | assert len(puzzles) == len(prompts) == len(entries)
28 | for entry, prompt in zip(entries, prompts):
29 | entry["prompts"][mode] = prompt
30 |
31 | with open(outfilename, "w") as f:
32 | json.dump(entries, f, indent=4)
33 |
34 |
35 | if __name__ == "__main__":
36 | run()
37 |
--------------------------------------------------------------------------------
/solvers/gpt3/ezlog.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import inspect
4 | import io
5 |
6 | my_path = os.path.dirname(__file__)
7 |
8 |
9 | def color_str(obj, code="\033[0;36m"):
10 | return code + str(obj) + '\033[0m'
11 |
12 |
13 | _configured = False
14 |
15 |
16 | def configure_logging(stdio_level=logging.INFO,
17 | file_level=logging.DEBUG,
18 | filename=".easy.log",
19 | filepath=os.path.join(my_path, "logs")):
20 | os.makedirs(filepath, exist_ok=True)
21 | filename = os.path.join(filepath, filename)
22 | global _configured
23 | if _configured:
24 | warning("Re-configuring logging")
25 | stdio_handler = logging.StreamHandler()
26 | stdio_handler.setLevel(stdio_level)
27 |     file_handler = logging.FileHandler(filename)
28 |     file_handler.setLevel(file_level)
29 |
30 | logging.basicConfig(
31 | format="%(asctime)s - %(levelname)s - %(name)s - %(message).200s",
32 | datefmt="%m/%d/%Y %H:%M:%S",
33 | level=min(stdio_level, file_level),
34 |         handlers=[stdio_handler, file_handler]
35 | )
36 |
37 | _configured = True
38 | _get_or_create_logger().debug("Configured logging")
39 |
40 |
41 | _loggers = {}
42 |
43 |
44 | def _get_or_create_logger():
45 | global _configured, _loggers
46 | if not _configured:
47 | configure_logging()
48 | try:
49 | for frame in inspect.stack():
50 | name = inspect.getmodule(frame[0]).__name__
51 | if name != __name__:
52 | break
53 |     except Exception:
54 | name = "_"
55 | if name not in _loggers:
56 | _loggers[name] = logging.getLogger(name)
57 | return _loggers[name]
58 |
59 |
60 | def print_to_string(*args, end="", **kwargs):
61 | with io.StringIO() as buf:
62 | print(*args, file=buf, end=end, **kwargs)
63 | return buf.getvalue()
64 |
65 |
66 | def debug(*args, **kwargs):
67 | _get_or_create_logger().debug(print_to_string(*args, **kwargs))
68 |
69 |
70 | def info(*args, **kwargs):
71 | _get_or_create_logger().info(print_to_string(*args, **kwargs))
72 |
73 |
74 | log = info
75 |
76 |
77 | def warning(*args, **kwargs):
78 | _get_or_create_logger().warning(print_to_string(*args, **kwargs))
79 |
80 |
81 | warn = warning
82 |
83 |
84 | def error(*args, **kwargs):
85 | _get_or_create_logger().error(print_to_string(*args, **kwargs))
86 |
--------------------------------------------------------------------------------
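A hypothetical usage sketch of the `ezlog` helpers above; `configure_logging` is invoked lazily on the first log call, and each message is tagged with the calling module's name:

```python
import ezlog

ezlog.configure_logging()           # optional; done automatically on first use
ezlog.info("solved", 3, "puzzles")  # print-style args are joined into one message
ezlog.warn("cache miss")
ezlog.debug("details go to the log file at the default levels")
```
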
/solvers/gpt3/lm_solve/__init__.py:
--------------------------------------------------------------------------------
1 | from lm_solve.run import *
2 |
--------------------------------------------------------------------------------
/solvers/gpt3/lm_solve/gpt3_lib.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import openai
4 | import ezlog
5 | import time
6 | import datetime
7 |
8 | assert 'OPENAI_API_KEY' in os.environ, "Need to set environment variable `OPENAI_API_KEY`"
9 | openai.api_key = os.environ['OPENAI_API_KEY']
10 |
11 |
12 | _CACHE_PATH = os.path.join(os.path.dirname(__file__), "../.cache")
13 | _CACHE_FILENAME = os.path.join(_CACHE_PATH, "gpt3.cache")
14 | _ENCODING = "utf-8"
15 |
16 | _cache = None
17 |
18 |
19 | # The cache file is a list of (key, results) lines, where the key is the query-params dict (minus n)
20 | # encoded as a string; multiple queries with the same params (apart from n) are merged into one big list.
21 | def _save_line(item, comment=None):
22 | global _cache
23 | assert _cache is not None
24 | with open(_CACHE_FILENAME, "a", encoding=_ENCODING) as f:
25 | f.write(str(item)+ ((" # " + comment + "\n") if comment else "\n"))
26 |
27 | def _load_cache():
28 | global _cache
29 |
30 | assert _cache is None, "gpt3 cache already loaded"
31 |
32 | if not os.path.exists(_CACHE_PATH):
33 | ezlog.warn("Creating cache path")
34 | os.makedirs(_CACHE_PATH)
35 |
36 | _cache = {}
37 |
38 | if os.path.exists(_CACHE_FILENAME):
39 | time0 = time.perf_counter()
40 | with open(_CACHE_FILENAME, "r", encoding=_ENCODING) as f:
41 | for k, v in [eval(line) for line in f.readlines()]:
42 | if k not in _cache:
43 | _cache[k] = v
44 | else:
45 | _cache[k].extend(v)
46 | ezlog.info(f"Loaded gpt3 cache in {time.perf_counter()-time0:.1f}s")
47 | else:
48 | ezlog.warn("No gpt3 cache yet")
49 |
50 |
51 |
52 | def query(prompt, n=10, max_tokens=150, temp=1.0, max_batch=32, stop=None, notes=None, cache_only=False, verbose=True):
53 | """Query gpt3
54 |
55 | :param prompt: Up to 2048 tokens (about 3-4k chars)
56 | :param n: number of answers, None returns all cached answers
57 | :param max_tokens:
58 | :param temp: 0.9 seems to work well
59 | :param max_batch: max to query at once
60 | :param stop: string to stop at or '' if not to stop
61 | :param notes: notes you want to save or change in case you want to run the same query more than once!
62 | :return: list of answers and then the response items
63 | """
64 | global _cache
65 | if _cache is None:
66 | _load_cache()
67 |
68 | if temp == 0 and n > 1:
69 | ezlog.debug("Temp 0: no point in running more than one query")
70 | n = 1
71 |
72 | key = str(dict(prompt=prompt, max_tokens=max_tokens, temp=temp, max_batch=max_batch, stop=stop, rep=notes))
73 | cached = _cache.get(key, [])
74 | if n is None:
75 | return cached[:]
76 |
77 | if len(cached) >= n:
78 | return cached[:n]
79 |
80 |     if cache_only:
81 |         # a cache miss in cache_only mode should fail loudly rather than query the API
82 |         raise KeyError("Entry not found in cache")
83 |
84 | if verbose:
85 | print("/"*100)
86 | print("Querying GPT3 with prompt:")
87 | print(prompt)
88 | s = stop and stop.replace('\n', '\\n')
89 | print(f"/// n={n} ({n-len(cached)} new) max_tokens={max_tokens} temp={temp} max_batch={max_batch} stop={s}")
90 | print("/"*100)
91 |
92 | time0 = time.perf_counter()
93 |
94 | new = []
95 | n -= len(cached)
96 |
97 | while n > 0:
98 | m = min(n, max_batch)
99 |
100 | res = openai.Completion.create(
101 | engine="davinci-msft",
102 | prompt=prompt,
103 | max_tokens=max_tokens,
104 | temperature=temp,
105 | n=m,
106 | stop=stop or None
107 | )
108 |
109 | new += [c["text"] for c in res["choices"]]
110 | n -= m
111 |
112 | _save_line((key, new), f"{time.perf_counter() - time0:.1f}s {datetime.datetime.now()}")
113 | ans = _cache[key] = cached + new
114 | return ans[:]
115 |
116 | # old code
117 | # # to persist calls to the API...
118 | # _disk_cache = joblib.Memory(os.path.join(os.path.dirname(__file__), ".cache"), verbose=1).cache
119 | #
120 | #
121 | # @_disk_cache
122 | # def query(prompt, n=10, max_tokens=150, temperature=1.0, max_batch=32):
123 | # """Query gpt3
124 | #
125 | # :param prompt: Up to 2048 tokens (about 3-4k chars)
126 | # :param n: number of answers
127 | # :param max_tokens:
128 | # :param temperature:
129 | # :param max_batch: max to query at once
130 | # :return: list of answers and then the response items
131 | # """
132 | # if temperature == 0 and n > 1:
133 | # ezlog.debug("Temp 0: no point in running more than one query")
134 | # n = 1
135 | #
136 | # responses = []
137 | # while n > 0:
138 | # m = min(n, max_batch)
139 | # prompt_summary = prompt if len(prompt) < 80 else f"{prompt[:40]}...{prompt[-40:]}"
140 | # ezlog.warn(f"**** Running GPT3 query: temp {temperature}, n={m}, prompt={prompt_summary}")
141 | # time0 = time.perf_counter()
142 | # responses.append(openai.Completion.create(
143 | # engine="davinci-msft",
144 | # prompt=prompt,
145 | # max_tokens=max_tokens,
146 | # temperature=temperature,
147 | # n=m
148 | # ))
149 | # ezlog.info(f"**** Got response in {time.perf_counter()-time0}s...")
150 | # n -= m
151 | #
152 | # return [c["text"] for r in responses for c in r["choices"]], responses
153 |
154 |
155 |
--------------------------------------------------------------------------------
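A hypothetical call to the cached wrapper above (requires `OPENAI_API_KEY`; the import path assumes you run from this package, as the run scripts do). The first call queries the API and appends to the cache file; re-running with identical parameters is served entirely from the cache:

```python
from lm_solve import gpt3_lib

answers = gpt3_lib.query("def f(s: str):\n    return ", n=5, max_tokens=30,
                         temp=0.9, stop="\n")
print(len(answers), repr(answers[0]))
```
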
/solvers/gpt3/lm_solve/judge.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import time
4 | from pebble import ProcessPool
5 | import multiprocessing
6 | from concurrent.futures import TimeoutError
7 | from typing import List, Set, Tuple, Dict, Union, Any
8 |
9 | import utils
10 | import ezlog
11 |
12 | def _COPY(x):
13 | return x
14 |
15 | _ENV = dict(List=List, Set=Set, Tuple=Tuple, Dict=Dict, COPY=_COPY,
16 | os=None, sys=None, input=None, open=None, print=None, compile=None, copyright=None)
17 | _UNSAFE = ["import", "builtin", "__class"]
18 |
19 | MAX_WORKERS = multiprocessing.cpu_count() // 2
20 |
21 | _CACHE_PATH = os.path.join(os.path.dirname(__file__), "../.cache") # ".cache"
22 |
23 |
24 | class SetCache:
25 | """Simple cache that stores a set of keys. Cannot remove values. Haven't yet implemented iteration."""
26 | BYTE_LEN = 256 // 8 # for sha256
27 |
28 | def __init__(self, name, path=_CACHE_PATH):
29 | self._path = path
30 | self.name = name
31 | self._filename = os.path.join(path, f"{name}_set.cache")
32 | self._set = None # the main set, loaded lazily
33 |
34 | def update(self, keys):
35 | self._load()
36 | hashes = [self._hash(k) for k in keys]
37 | additions = {h for h in hashes if h not in self._set}
38 |
39 | if additions:
40 | with open(self._filename, "ab") as f: # append to binary file
41 | for h in additions:
42 | f.write(h)
43 | self._set.update(additions)
44 |
45 | def add(self, key):
46 | self.update([key])
47 |
48 | def _hash(self, key):
49 | return hashlib.sha256(bytes(str(key), encoding='utf8')).digest()
50 |
51 | def __contains__(self, key: str):
52 | self._load()
53 | h = self._hash(key)
54 | return h in self._set
55 |
56 | def __delitem__(self, key):
57 | raise NotImplementedError
58 |
59 | def __len__(self):
60 | self._load()
61 | return len(self._set)
62 |
63 | def _load(self):
64 | if self._set is None: # only load if not already loaded
65 | if not os.path.exists(self._path):
66 | ezlog.warn(f"Creating path for `{self.name}` cache")
67 | os.makedirs(self._path)
68 |
69 | time0 = time.perf_counter()
70 | if os.path.exists(self._filename):
71 | with open(self._filename, "rb") as f: # read binary file
72 | data = f.read()
73 | self._set = {data[j:j + self.BYTE_LEN] for j in range(0, len(data), self.BYTE_LEN)}
74 | else:
75 | self._set = set()
76 | dur = time.perf_counter() - time0
77 | ezlog.info(f"Loaded `{self.name}` cache of {len(self):,} items in {dur:.1f}s")
78 |
79 |
80 | def _judge(code_env):
81 | code, env = code_env
82 | for u in _UNSAFE:
83 | if u in code:
84 | return False
85 |
86 | try:
87 | exec(code, env.copy()) # not sure if copy() is necessary
88 | return True
89 | except Exception as e:
90 | return False
91 |
92 |
93 | # Cache judge results (which are nondeterministic due to timeout) for reproducibility
94 | _judge_success = SetCache('judge_success')
95 | _judged_batches = SetCache('judged_batches')
96 |
97 | def judge_parallel(src_codes, timeout, max_workers=MAX_WORKERS, env=_ENV, force_compute=False):
98 | global _judge_success, _judged_batches
99 |
100 | if force_compute or src_codes not in _judged_batches:
101 | new_codes = utils.dedup(code for code in src_codes if code not in _judge_success)
102 | if new_codes:
103 | ezlog.info(f"Judging {len(new_codes)}/{len(src_codes)} new codes (removing duplicates/things in cache)")
104 | successes = []
105 |
106 |
107 | with ProcessPool(max_workers=max_workers) as pool:
108 | future = pool.map(_judge, [(code, env) for code in new_codes], timeout=timeout)
109 |
110 | results = future.result()
111 |
112 | i = 0
113 | while True:
114 | try:
115 | if next(results):
116 | successes.append(new_codes[i])
117 | except StopIteration:
118 | _judge_success.update(successes)
119 | break
120 | except (TimeoutError, Exception) as error:
121 | pass
122 | assert i < len(new_codes)
123 | i += 1
124 | assert i == len(new_codes)
125 | _judged_batches.add(src_codes)
126 |
127 | return [code in _judge_success for code in src_codes]
128 |
129 |
130 | if __name__ == "__main__":
131 | res = judge_parallel([
132 | """1+1
133 | """,
134 | """assert False,'cats'""",
135 | """assert False""",
136 | """1[2]""",
137 | """1/0""",
138 | """while True:
139 | pass""",
140 | """for i in range(10**5):
141 | pass"""
142 | ], timeout=1.0)
143 | print(res)
144 |
--------------------------------------------------------------------------------
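A small sketch of how `SetCache` above behaves: an append-only set of SHA-256 hashes persisted to disk, so membership survives across processes (the cache name here is hypothetical):

```python
from lm_solve.judge import SetCache

seen = SetCache("demo")            # backed by ../.cache/demo_set.cache
seen.add("assert 1 + 1 == 2")
seen.update(["x = 3", "y = 4"])
print("x = 3" in seen)             # True, and still True in a fresh process
print(len(seen))                   # 3
```
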
/solvers/gpt3/requirements.txt:
--------------------------------------------------------------------------------
1 | astor==0.8.1
2 | numpy==1.22.0
3 | openai==0.6.3
4 | tqdm==4.60.0
5 | transformers==4.30.0
6 | Pebble==4.6.1
7 |
8 | # we ran with Python version sys.version = '3.6.9 (default, Jan 26 2021, 15:33:00) \n[GCC 8.4.0]'
9 | # distro-info===0.18ubuntu0.18.04.1
10 |
--------------------------------------------------------------------------------
/solvers/gpt3/run_gpt3_experiments.py:
--------------------------------------------------------------------------------
1 | """
2 | This script runs the GPT_3 experiments and prints the results to stdout.
3 | It uses cacheing mechanisms so that if run twice with the same parameters, it will give exactly the same
4 | results and will not query the GPT3 API again and will not judge the resulting solutions again. Hence, the first
5 | time you run it, it will be slow, but you can subsequently run it again and it will be fast. It will run the
6 | experiment three times, with different seeds to get different results.
7 |
8 | """
9 |
10 | import lm_solve
11 | import numpy as np
12 |
13 | PARAMS = dict(
14 | temp=0.9,
15 | timeout=1.0, # seconds to judge
16 | n=10 * 1000,
17 | filename="puzzles_with_descriptions.json",
18 | stop="\n",
19 | cache_only=False, # change this to True if you want to run a 2nd time without risking hitting API
20 | )
21 |
22 | BOOTSTRAP_PARAMS = dict(
23 | temp=PARAMS["temp"],
24 | timeout=PARAMS["timeout"],
25 | filename=PARAMS["filename"],
26 | stop=PARAMS["stop"],
27 | cache_only=PARAMS["cache_only"],
28 | ppi=32, # puzzles per iteration
29 | iterations=(PARAMS["n"] + 31) // 32,
30 | )
31 |
32 | STUDY = range(107, 137) # the range of puzzles used in the study
33 |
34 | PREFIX = """
35 | def f1(s: str):
36 | return "Hello " + s == "Hello world"
37 |
38 | assert True == f1("world")
39 |
40 | ---
41 |
42 | def f2(s: str):
43 | return "Hello " + s[::-1] == "Hello world"
44 |
45 | assert True == f2("world"[::-1])
46 |
47 | ---
48 |
49 | def f3(x: List[int]):
50 | return len(x) == 2 and sum(x) == 3
51 |
52 | assert True == f3([1, 2])
53 |
54 | ---
55 |
56 | def f4(s: List[str]):
57 | return len(set(s)) == 1000 and all((x.count("a") > x.count("b")) and ('b' in x) for x in s)
58 |
59 | assert True == f4(["a"*(i+2)+"b" for i in range(1000)])
60 |
61 | ---
62 |
63 | def f5(n: int):
64 | return str(n * n).startswith("123456789")
65 |
66 | assert True == f5(int(int("123456789" + "0"*9) ** 0.5) + 1)
67 |
68 | ---
69 |
70 | """ # trailing newlines important
71 |
72 | PREFIX_DOCSTR = '''
73 | def f1(s: str):
74 | """Find a string that when concatenated onto 'Hello ' gives 'Hello world'."""
75 | return "Hello " + s == "Hello world"
76 |
77 | assert True == f1("world")
78 |
79 | ---
80 |
81 | def f2(s: str):
82 | """Find a string that when reversed and concatenated onto 'Hello ' gives 'Hello world'."""
83 | return "Hello " + s[::-1] == "Hello world"
84 |
85 | assert True == f2("world"[::-1])
86 |
87 | ---
88 |
89 | def f3(x: List[int]):
90 | """Find a list of two integers whose sum is 3."""
91 | return len(x) == 2 and sum(x) == 3
92 |
93 | assert True == f3([1, 2])
94 |
95 | ---
96 |
97 | def f4(s: List[str]):
98 | """Find a list of 1000 distinct strings which each have more 'a's than 'b's and at least one 'b'."""
99 | return len(set(s)) == 1000 and all((x.count("a") > x.count("b")) and ('b' in x) for x in s)
100 |
101 | assert True == f4(["a"*(i+2)+"b" for i in range(1000)])
102 |
103 | ---
104 |
105 | def f5(n: int):
106 | """Find an integer whose perfect square begins with 123456789 in its decimal representation."""
107 | return str(n * n).startswith("123456789")
108 |
109 | assert True == f5(int(int("123456789" + "0"*9) ** 0.5) + 1)
110 |
111 | ---
112 |
113 | ''' # trailing newlines important
114 |
115 |
116 | def run(seed=0):
117 | sols = [lm_solve.prompt_experiment(**PARAMS, prefix="", seed=seed),
118 | lm_solve.prompt_experiment(**PARAMS, prefix=PREFIX, seed=seed),
119 | lm_solve.prompt_experiment(**PARAMS, prefix=PREFIX_DOCSTR, add_docstring=True, seed=seed)]
120 | problems_solved = [sorted([i for i, (f, gs) in enumerate(s) if gs]) for s in sols]
121 | bootstrap = lm_solve.bootstrap(**BOOTSTRAP_PARAMS, seed=seed)
122 | print(f"run={seed} ALL DONE!\n\n")
123 | print(f"run={seed} RESULTS " + "=" * 50)
124 | print()
125 |
126 | # Instead of running until first success, and outputting number of attempts, we do something more accurate.
127 | # We run for N tries for each problem and do not stop on first success. Then, we use the number of successes
128 | # for a better estimate of the average number of attempts required for first success. If we have s successes
129 | # out of N attempts, then the expected number of attempts is (N - s) / (1 + s). This is the expectation of the
130 | # random variable that is: when you permute the attempts uniformly at random, how many attempts before the
131 | # first success. If s=N, it's 0, if s=1, it's (N-1)/2, etc.
132 | counts = [[(PARAMS["n"] - len(gs)) / (1 + len(gs)) for f, gs in s if gs] for s in sols]
133 | counts.append([m for m, _i, _f, _a in bootstrap])
134 | counts = [[1 + z for z in c] for c in counts] # add 1 to make it 1-based
135 | for c in counts:
136 | c.sort()
137 | print(f"run={seed} (Expected) number of attempts before a problem is solved [short, med, long, bootstrap]:")
138 | print(counts)
139 | problems_solved.append([i for _m, i, _f, _a in bootstrap])
140 | print()
141 | print(f"run={seed} Which problems were solved [short, med, long, bootstrap]:")
142 | print(problems_solved)
143 | print()
144 | print(f"run={seed} Number of problems solved [short, med, long, bootstrap]:")
145 | print([len(c) for c in counts])
146 | print()
147 | print(f"run={seed} Number of 30 study problems solved [short, med, long, bootstrap]:")
148 | print([len([i for i in s if i in STUDY]) for s in problems_solved])
149 | print()
150 | difficulties = [1.0 for _ in range(len(sols[0]))]
151 |
152 | k = 1
153 | for m, i, f, a in bootstrap:
154 | difficulties[i] = np.log(m + 1) / np.log(PARAMS["n"])
155 |
156 | # These commented lines print the problems that bootstrap solved
157 | # print()
158 | # print(f"# Bootstrap solved after {m + 1} tries:")
159 | # print(f.replace("def f", "def f" + str(k)))
160 | # import json
161 | #
162 | # print(f"SOL:", json.dumps(a))
163 | k += 1
164 | print(f"run={seed} Bootstrap difficulties for study puzzles:")
165 | print([difficulties[i] for i in STUDY])
166 |
167 |
168 | if __name__ == "__main__":
169 | for seed in range(3):
170 | run(seed)
171 |
--------------------------------------------------------------------------------
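The `(N - s) / (1 + s)` estimate in `run` above can be sanity-checked with a toy simulation; this sketch is illustrative only, with hypothetical values of N and s:

```python
import random

N, s = 10, 3                       # attempts and successes
outcomes = [True] * s + [False] * (N - s)
trials, total = 100_000, 0
for _ in range(trials):
    random.shuffle(outcomes)
    total += outcomes.index(True)  # failures before the first success
print(total / trials)              # ~1.75
print((N - s) / (1 + s))           # 1.75
```
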
/solvers/gpt3/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from utils.str_utils import *
2 | from utils.time_utils import *
3 |
4 | import functools
5 | import operator
6 |
7 |
8 | def prod(iterable): # like sum but product
9 | return functools.reduce(operator.mul, iterable, 1)
10 |
11 |
12 | def flatten(it):
13 | return (e for a in it for e in (flatten(a) if isinstance(a, (tuple, list)) else (a,)))
14 |
15 |
16 | def load_json(filename):
17 | import json
18 | with open(filename, "r") as f:
19 | return json.load(f)
20 |
21 |
22 | def viz_py(py):
23 | import astor, ast
24 | print(astor.dump_tree(ast.parse(py)))
25 |
26 |
27 | def dedup(li):
28 | seen = set()
29 | return [x for x in li if x not in seen and not seen.add(x)]
30 |
--------------------------------------------------------------------------------
/solvers/gpt3/utils/str_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | def get_lambda_arg_name(lam):
4 | assert lam.startswith("lambda ")
5 | return lam[len("lambda "):lam.index(":")].strip()
6 |
7 |
8 | def stringify(const):
9 | if type(const) is str:
10 | return json.dumps(const)
11 | return str(const)
12 |
13 |
14 | def color_str(obj, code="\033[0;36m"):
15 | return code + str(obj) + '\033[0m'
16 |
17 |
--------------------------------------------------------------------------------
/solvers/gpt3/utils/time_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 |
4 | logger = logging.getLogger(__name__)
5 |
6 | def timeit(method):
7 | def timed(*args, **kw):
8 | tick = time.time()
9 | result = method(*args, **kw)
10 | tock = time.time()
11 | logger.debug(f'{method.__name__}: {tock - tick:.3f}s')
12 |
13 | return result
14 | return timed
15 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import inspect
4 | import io
5 | import os
6 |
7 | my_path = os.path.dirname(__file__)
8 |
9 | def check_hashseed(desired_seed=0):
10 |     if os.environ.get('PYTHONHASHSEED') != str(desired_seed):  # env values are strings
11 |         info(f"Ideally set PYTHONHASHSEED={desired_seed} for perfect reproducibility")
12 |         return False
13 |     return True
14 |
15 |
16 | def inv_dict(d):
17 | ans = {}
18 | for k, v in d.items():
19 | if v not in ans:
20 | ans[v] = []
21 | ans[v].append(k)
22 | return ans
23 |
24 |
25 | def remove_docstring(f):
26 | """Remove docstring"""
27 | assert '\n """' in f, f"No triple quote docstring (after four spaces) in: \n{f}"
28 | i = f.index('\n """')
29 | j = f.index('"""', i + 8)
30 | return f[:i + 1] + f[j + 4:]
31 |
32 |
33 | def get_docstring(f):
34 | assert '\n """' in f, f"No triple quote docstring (after four spaces) in: \n{f}"
35 | i = f.index('\n """')
36 | j = f.index('"""', i + 8)
37 | docstring = f[i + 1:j + 3]
38 | if not docstring.strip(' "'):
39 | warn(f"Empty docstring in:\n{f}")
40 | return docstring
41 |
42 |
43 | def flatten(it):
44 | return (e for a in it for e in (flatten(a) if isinstance(a, (tuple, list)) else (a,)))
45 |
46 | def save_json(obj, filename, make_dirs_if_necessary=False, indent=2, **kwargs):
47 | """Saves compressed file if filename ends with '.gz'"""
48 | import json
49 | if make_dirs_if_necessary:
50 | os.makedirs(os.path.dirname(filename), exist_ok=True)
51 | if filename.endswith(".gz"):
52 | import gzip
53 | with gzip.open(filename, "wt") as f:
54 | return json.dump(obj, f, indent=indent, **kwargs)
55 | with open(filename, "w", encoding="utf8") as f:
56 | return json.dump(obj, f, indent=indent, **kwargs)
57 |
58 | def load_json(filename):
59 | """Loads compressed file if filename ends with '.gz'"""
60 | import json
61 | if filename.endswith(".gz"):
62 | import gzip
63 | with gzip.open(filename, "rt") as f:
64 | return json.load(f)
65 | with open(filename, "r", encoding="utf8") as f:
66 | return json.load(f)
67 |
68 |
69 | def stringify(const):
70 | if type(const) is str:
71 | return json.dumps(const)
72 | return str(const)
73 |
74 |
75 | def dedup(stuff):
76 | seen = set()
77 | return [a for a in stuff if a not in seen and not seen.add(a)]
78 |
79 |
80 | def color_str(obj, code="\033[0;36m"):
81 | return code + str(obj) + '\033[0m'
82 |
83 |
84 | _configured = False
85 |
86 |
87 | def configure_logging(stdio_level=logging.INFO,
88 | file_level=logging.DEBUG,
89 | filename=os.path.join(my_path, ".problems.log")):
90 | global _configured
91 | if _configured:
92 | warning("Re-configuring logging")
93 | stdio_handler = logging.StreamHandler()
94 | stdio_handler.setLevel(stdio_level)
95 |     file_handler = logging.FileHandler(filename)
96 |     file_handler.setLevel(file_level)
97 |
98 | logging.basicConfig(
99 | format="%(asctime)s - %(levelname)s - %(name)s - %(message).200s",
100 | datefmt="%m/%d/%Y %H:%M:%S",
101 | level=min(stdio_level, file_level),
102 |         handlers=[stdio_handler, file_handler]
103 | )
104 |
105 | _configured = True
106 | _get_or_create_logger().debug("Configured logging")
107 |
108 |
109 | _loggers = {}
110 |
111 |
112 | def _get_or_create_logger():
113 | global _configured, _loggers
114 | if not _configured:
115 | configure_logging()
116 | name = "_"
117 | for frame in inspect.stack():
118 |         name = getattr(inspect.getmodule(frame[0]), "__name__", "_")
119 | if name != __name__:
120 | break
121 | if name not in _loggers:
122 | _loggers[name] = logging.getLogger(name)
123 | return _loggers[name]
124 |
125 |
126 | def print_to_string(*args, end="", **kwargs):
127 | with io.StringIO() as buf:
128 | print(*args, file=buf, end=end, **kwargs)
129 | return buf.getvalue()
130 |
131 |
132 | def debug(*args, **kwargs):
133 | _get_or_create_logger().debug(print_to_string(*args, **kwargs))
134 |
135 |
136 | def info(*args, **kwargs):
137 | _get_or_create_logger().info(print_to_string(*args, **kwargs))
138 |
139 |
140 | log = info
141 |
142 |
143 | def warning(*args, **kwargs):
144 | _get_or_create_logger().warning(print_to_string(*args, **kwargs))
145 |
146 |
147 | warn = warning
148 |
149 |
150 | def error(*args, **kwargs):
151 | _get_or_create_logger().error(print_to_string(*args, **kwargs))
152 |
--------------------------------------------------------------------------------
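A hypothetical round trip with the JSON helpers above; a `.gz` suffix switches both functions to gzip transparently:

```python
import utils

utils.save_json({"answer": 42}, "tmp/demo.json.gz", make_dirs_if_necessary=True)
print(utils.load_json("tmp/demo.json.gz"))  # {'answer': 42}
```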