├── .gitignore ├── LICENSE ├── README.md ├── lake-manifest.json ├── lakefile.lean ├── lean-toolchain ├── partII_dsp ├── README.md ├── dsp_utils.py ├── isabelle_setup.md └── notebooks │ ├── II_dsp__part1_intro.ipynb │ ├── II_dsp__part2_dsp.ipynb │ └── images │ ├── .gitkeep │ ├── dsp.png │ ├── dsp_example.png │ ├── dsp_plot.png │ ├── dsp_search.png │ ├── prove.png │ ├── sketch.png │ └── sledgehammer.png └── partI_nextstep ├── README.md ├── notebooks ├── I_nextstep_lean__part0_intro.ipynb ├── I_nextstep_lean__part1_data.ipynb ├── I_nextstep_lean__part2_learn.ipynb ├── I_nextstep_lean__part3_proofsearch.ipynb ├── I_nextstep_lean__part4_evaluation.ipynb ├── I_nextstep_lean__part5_llmsuggest.ipynb ├── data │ └── successes_mathlib4_200_wellecks_llmstep-mathlib4-pythia2.8b.json └── images │ ├── banach │ ├── banach_1.png │ ├── banach_2.png │ ├── banach_3.png │ ├── banach_4.png │ └── banach_5.png │ ├── leandojo_1.png │ ├── llmsuggest │ ├── llmstep_gif.gif │ ├── llmsuggest.gif │ └── llmsuggest_examples.png │ ├── proof_state_1.png │ ├── proof_state_2.png │ └── proof_state_3.png ├── ntp_lean ├── ExtractSimple.lean ├── LLMsuggest.lean └── examples │ ├── example0.lean │ └── example_demo.lean ├── ntp_python ├── __init__.py ├── data.py ├── llmsuggest │ ├── server.py │ └── suggest.py ├── postprocess_ast.py ├── proofsearch_dojo.py ├── proofsearch_pylean.py └── tune.py └── scripts ├── ds_config.json └── tune_proofstep.sh /.gitignore: -------------------------------------------------------------------------------- 1 | *model/ 2 | *.ast.json 3 | *.dep_paths 4 | lake-packages 5 | *.ckpt 6 | data 7 | generated/ 8 | 9 | *.pyc 10 | *.tar.gz 11 | # JetBrains PyCharm IDE 12 | .idea/ 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # macOS dir files 23 | .DS_Store 24 | 25 | # Distribution / packaging 26 | .Python 27 | env/ 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | 
.eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | 44 | # Checkpoints 45 | checkpoints 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | .hypothesis/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # celery beat schedule file 96 | celerybeat-schedule 97 | 98 | # SageMath parsed files 99 | *.sage.py 100 | 101 | # dotenv 102 | .env 103 | 104 | # virtualenv 105 | .venv 106 | venv/ 107 | ENV/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sean Welleck 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 
| copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A tutorial on neural theorem proving 2 | 3 | *Neural theorem proving* combines neural language models with formal proof assistants.\ 4 | This tutorial introduces two research threads in neural theorem proving via interactive Jupyter notebooks. 5 | 6 | 7 | ## Part I : Next-step suggestion 8 | 9 | Builds a neural next-step suggestion tool, introducing concepts and past work in neural theorem proving along the way. 10 | 11 | 12 | 13 | #### Notebooks: 14 | | Topic | Notebook | 15 | |:-----------------------|-------:| 16 | | 0. Intro | [notebook](./partI_nextstep/notebooks/I_nextstep_lean__part0_intro.ipynb) | 17 | | 1. Data | [notebook](./partI_nextstep/notebooks/I_nextstep_lean__part1_data.ipynb) | 18 | | 2. Learning | [notebook](./partI_nextstep/notebooks/I_nextstep_lean__part2_learn.ipynb) | 19 | | 3. Proof Search | [notebook](./partI_nextstep/notebooks/I_nextstep_lean__part3_proofsearch.ipynb) | 20 | | 4. Evaluation | [notebook](./partI_nextstep/notebooks/I_nextstep_lean__part4_evaluation.ipynb) | 21 | | 5. 
`llmsuggest` | [notebook](./partI_nextstep/notebooks/I_nextstep_lean__part5_llmsuggest.ipynb) | 22 | 23 | All notebooks are in ([`partI_nextstep/notebooks`](./partI_nextstep/notebooks)). See [`partI_nextstep/ntp_python`](./partI_nextstep/ntp_python) and [`partI_nextstep/ntp_lean`](./partI_nextstep/ntp_lean) for the Python and Lean files covered in the notebooks. 24 | 25 | #### Setup: 26 | Please follow the setup instructions in [`partI_nextstep/README.md`](./partI_nextstep/README.md). 27 | 28 | ## Part II : Language cascades 29 | Chain together language models to guide formal proof search with informal proofs. 30 | 31 | 32 | #### Notebooks: 33 | | Topic | Notebook | 34 | |:-----------------------|-------:| 35 | | 1. Language model cascades | [notebook](./partII_dsp/notebooks/II_dsp__part1_intro.ipynb) | 36 | | 2. Draft, Sketch, Prove | [notebook](./partII_dsp/notebooks/II_dsp__part2_dsp.ipynb) | 37 | 38 | All notebooks are in ([`partII_dsp/notebooks`](./partII_dsp/notebooks)). 39 | 40 | #### Setup: 41 | Please follow the setup instructions in [`partII_dsp/README.md`](./partII_dsp/README.md). 42 | 43 | 44 | ------- 45 | ### History 46 | These materials were originally developed as part of an IJCAI 2023 tutorial. \ 47 | Slides for the 1-hour summary presentation given at IJCAI 2023 are [here](https://wellecks.com/data/welleck2023ntp_tutorial.pdf). 
48 | 49 | #### Citation 50 | 51 | If you find this tutorial or repository useful in your work, please cite: 52 | ``` 53 | @misc{ntptutorial, 54 | author = {Sean Welleck}, 55 | title = {Neural theorem proving tutorial}, 56 | year = {2023}, 57 | publisher = {GitHub}, 58 | journal = {GitHub repository}, 59 | howpublished = {\url{https://github.com/wellecks/ntptutorial}}, 60 | } 61 | ``` 62 | -------------------------------------------------------------------------------- /lake-manifest.json: -------------------------------------------------------------------------------- 1 | {"version": 4, 2 | "packagesDir": "lake-packages", 3 | "packages": 4 | [{"git": 5 | {"url": "https://github.com/EdAyers/ProofWidgets4", 6 | "subDir?": null, 7 | "rev": "c43db94a8f495dad37829e9d7ad65483d68c86b8", 8 | "name": "proofwidgets", 9 | "inputRev?": "v0.0.11"}}, 10 | {"git": 11 | {"url": "https://github.com/leanprover-community/mathlib4.git", 12 | "subDir?": null, 13 | "rev": "3de751e5b518e96b9181328441aa1ab1677d8cf0", 14 | "name": "mathlib", 15 | "inputRev?": null}}, 16 | {"git": 17 | {"url": "https://github.com/gebner/quote4", 18 | "subDir?": null, 19 | "rev": "c71f94e34c1cda52eef5c93dc9da409ab2727420", 20 | "name": "Qq", 21 | "inputRev?": "master"}}, 22 | {"git": 23 | {"url": "https://github.com/JLimperg/aesop", 24 | "subDir?": null, 25 | "rev": "ca73109cc40837bc61df8024c9016da4b4f99d4c", 26 | "name": "aesop", 27 | "inputRev?": "master"}}, 28 | {"git": 29 | {"url": "https://github.com/leanprover/std4", 30 | "subDir?": null, 31 | "rev": "d5471b83378e8ace4845f9a029af92f8b0cf10cb", 32 | "name": "std", 33 | "inputRev?": "main"}}]} 34 | -------------------------------------------------------------------------------- /lakefile.lean: -------------------------------------------------------------------------------- 1 | import Lake 2 | open Lake DSL 3 | 4 | package «ntptutorial» 5 | 6 | require mathlib from git 7 | "https://github.com/leanprover-community/mathlib4.git" 8 | 
-------------------------------------------------------------------------------- /lean-toolchain: -------------------------------------------------------------------------------- 1 | leanprover/lean4:nightly-2023-06-10 2 | -------------------------------------------------------------------------------- /partII_dsp/README.md: -------------------------------------------------------------------------------- 1 | ## Part II : Language cascades 2 | Chain together language models to guide formal proof search with informal proofs. 3 | 4 | 5 | #### Notebooks: 6 | | Topic | Notebook | 7 | |:-----------------------|-------:| 8 | | 1. Language model cascades | [notebook](./notebooks/II_dsp__part1_intro.ipynb) | 9 | | 2. Draft, Sketch, Prove | [notebook](./notebooks/II_dsp__part2_dsp.ipynb) | 10 | 11 | All notebooks are in ([`partII_dsp/notebooks`](./notebooks)). 12 | 13 | #### Setup 14 | The Draft, Sketch, Prove notebook requires setting up an Isabelle proof checker for the "sketch" stage. 15 | 16 | Please follow this guide to set up the Isabelle Proof Checker: [Isabelle Proof Checker Setup](./isabelle_setup.md) -------------------------------------------------------------------------------- /partII_dsp/dsp_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import openai 4 | import sys 5 | 6 | 7 | class LMFunction(object): 8 | def __init__(self, engine='gpt-3.5-turbo', max_tokens=512): 9 | self.engine = engine 10 | self.max_tokens = max_tokens 11 | self.openai = openai 12 | openai.api_key = os.environ['OPENAI_API_KEY'] 13 | 14 | def _call_api(self, prompt, engine, max_tokens, max_retries=10, retry_wait=2): 15 | for i in range(max_retries): 16 | try: 17 | return self.openai.ChatCompletion.create( 18 | model=engine, 19 | messages=[ 20 | {"role": "system", "content": "You are a helpful assistant."}, 21 | {"role": "user", "content": prompt} 22 | ], 23 | max_tokens=max_tokens, 24 | temperature=1.0 25 | ) 26 | 
except self.openai.error.OpenAIError as e: 27 | time.sleep(retry_wait) 28 | return {'choices': [{'message': {'content': ''}}]} 29 | 30 | def _parse_message(self, msg): 31 | try: 32 | content = msg['choices'][0]['message']['content'] 33 | except (IndexError, KeyError): 34 | content = '' 35 | return content 36 | 37 | def f(self, prompt, x): 38 | msg = self._call_api( 39 | prompt=prompt+x, 40 | engine=self.engine, 41 | max_tokens=self.max_tokens 42 | ) 43 | evaluation = self._parse_message(msg) 44 | return evaluation 45 | 46 | 47 | class Checker(object): 48 | """A modified version of the Draft, Sketch, Prove proof-checking client. 49 | (https://github.com/albertqjiang/draft_sketch_prove/blob/main/autoformalization/checker.py) 50 | 51 | This checker supports Isabelle2022 via the new version of PISA 52 | (https://albertqjiang.github.io/Portal-to-ISAbelle/). 53 | 54 | It supports checking a miniF2F-style proof via `check`. 55 | 56 | Finally, it replaces `sledgehammer` with a call to `normalhammer`. 
57 | """ 58 | def __init__(self, working_dir, isa_path, theory_file, port=9000): 59 | sys.path.append(os.environ['PISA_PATH']) 60 | try: 61 | from pisa_client import initialise_env 62 | self.initialise_env = initialise_env 63 | except: 64 | print("Set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python") 65 | 66 | self.working_dir = working_dir 67 | self.isa_path = isa_path 68 | self.theory_file = theory_file 69 | self.port = port 70 | 71 | def _initialize(self): 72 | env = self.initialise_env( 73 | self.port, 74 | isa_path=self.isa_path, 75 | theory_file_path=self.theory_file, 76 | working_directory=self.working_dir 77 | ) 78 | return env 79 | 80 | def _exit(self, env): 81 | try: 82 | env.post('exit') 83 | except: 84 | print("env.post('exit') timed out") 85 | pass 86 | os.system("ps aux | grep Isabelle | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1") 87 | os.system("ps aux | grep poly | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1") 88 | 89 | def _parse_output(self, obs): 90 | """Parse the sledgehammer output, otherwise return an empty string""" 91 | if '' in obs: 92 | output = obs.split('')[0] 93 | else: 94 | output = '' 95 | return output 96 | 97 | def _run_step(self, step, i, tls_name, env): 98 | obs, reward, done, metadata = env.step_to_top_level_state( 99 | action=step, 100 | tls_name=tls_name, 101 | new_name='default_%d' % i 102 | ) 103 | error = None 104 | if 'error:' in obs or 'Step error' in obs or 'Unknown error' in obs: 105 | error = obs 106 | return obs, reward, done, metadata, error 107 | 108 | def _run_sledgehammer(self, step, i, tls_name, env): 109 | # First try heuristics 110 | for heuristic in ['by auto', 'by simp', 'by blast', 'by fastforce', 'by force', 'by eval', 'by presburger', 'by sos', 'by arith', 'by linarith', 'by (auto simp: field_simps)']: 111 | step_ = step.replace('normalhammer', heuristic) 112 | obs, reward, done, metadata, error = self._run_step(step_, i, tls_name, env) 113 | if error is None: 114 | obs = '%s %s' 
% (heuristic, obs) 115 | return obs, reward, done, metadata, error 116 | # Try sledgehammer 117 | out = self._run_step(step, i, tls_name, env) 118 | return out 119 | 120 | def check(self, statement_and_proof): 121 | # Initialize environment 122 | env = self._initialize() 123 | env.initialise() 124 | 125 | # Wrap and parse theorem 126 | theory = Checker.wrap_theorem(statement_and_proof) 127 | steps = Checker.get_parsed(env, theory) 128 | 129 | result = self._check(env, steps) 130 | return result 131 | 132 | def _check(self, env, steps): 133 | done = False 134 | reason = '' 135 | success = False 136 | step_results = [] 137 | tls_name = 'default' 138 | for i, step in enumerate(steps): 139 | try: 140 | time0 = time.time() 141 | if 'normalhammer' in step: 142 | obs, reward, done, metadata, error = self._run_sledgehammer(step, i, tls_name, env) 143 | else: 144 | obs, reward, done, metadata, error = self._run_step(step, i, tls_name, env) 145 | step_time = time.time() - time0 146 | step_results.append(dict(index=i, step=step, output=self._parse_output(obs), step_time=step_time)) 147 | if error is not None: 148 | reason = error 149 | success = False 150 | done = False 151 | break 152 | except: 153 | # Timeout - end the proof attempt 154 | success = False 155 | done = False 156 | reason = 'timeout (%d)' % len(step_results) 157 | step_results.append(dict(index=i, step=step, output='')) 158 | break 159 | 160 | # Change when successful 161 | tls_name = 'default_%d' % i 162 | 163 | if done and reward == 1.0: 164 | success = True 165 | 166 | result = { 167 | 'success': success, 168 | 'reason': reason, 169 | 'num_steps': len(steps), 170 | 'last_step': len(step_results), 171 | 'step_results': step_results, 172 | 'theorem_and_proof': self.reconstruct(step_results) if success else '' 173 | } 174 | # Exit environment 175 | self._exit(env) 176 | return result 177 | 178 | @staticmethod 179 | def reconstruct(step_results): 180 | steps = [] 181 | for step_result in step_results[1:]: 182 | 
if step_result['output'] != '': 183 | steps.append(step_result['output'].strip()) 184 | else: 185 | steps.append(step_result['step'].strip()) 186 | theorem_and_proof = '\n'.join(steps) 187 | return theorem_and_proof 188 | 189 | @staticmethod 190 | def wrap_theorem(theorem): 191 | return 'theory Interactive imports HOL.HOL Complex_Main "HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" "Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" "HOL-Number_Theory.Number_Theory" \n begin\n%s' % theorem 192 | 193 | @staticmethod 194 | def get_parsed(env, theory, tls_name='default'): 195 | # HACK: the parsing doesn't work well with `normalhammer`, so we replace 196 | # all hammer calls with sorry, then replace sorry to normalhammer after parsing. 197 | theory = theory.replace('sledgehammer', 'sorry') 198 | theory = theory.replace('normalhammer', 'sorry') 199 | 200 | steps = env.post(f" ${theory}") 201 | steps = steps.split('') 202 | steps = [s for s in steps if s.strip() != ''] 203 | # remove weird '$' step and whitespace steps 204 | steps = [s for s in steps if s != '$' and s.strip() != ''] 205 | steps = [s.replace('sorry', 'normalhammer') for s in steps] 206 | return steps 207 | -------------------------------------------------------------------------------- /partII_dsp/isabelle_setup.md: -------------------------------------------------------------------------------- 1 | # Setup: Isabelle Proof Checker 2 | 3 | Follow this guide to set up Isabelle proof checking. At the end, we will have a Python interface for checking a theorem and proof, e.g. 4 | ```python 5 | theorem_and_proof = """theorem ...""" 6 | result = checker.check(theorem_and_proof) 7 | ``` 8 | 9 | ## Setup 10 | 11 | Proof checking is done via [PISA](https://github.com/albertqjiang/Portal-to-ISAbelle/tree/56def2c39f85d211e1f40cc5765581a567879106). We implement a client that interacts with PISA (`Checker` in [dsp_utils.py](./dsp_utils.py)). 
12 | 13 | Here are setup steps for a non-dockerized environment. The setup is heavily based on the [PISA readme](https://github.com/albertqjiang/Portal-to-ISAbelle/tree/56def2c39f85d211e1f40cc5765581a567879106) and [Dockerfile](https://github.com/albertqjiang/Portal-to-ISAbelle/blob/main/docker/Dockerfile). You may need to refer to those if something goes wrong. 14 | 15 | ### Installation (PISA and Isabelle) 16 | First, we need to set up PISA and Isabelle. 17 | ```bash 18 | # -- PISA setup 19 | # Download Portal-to-ISAbelle (PISA) 20 | cd ~/ 21 | git clone https://github.com/albertqjiang/Portal-to-ISAbelle.git 22 | 23 | # Scala installation 24 | sudo apt-get install zip 25 | curl -s "https://get.sdkman.io" | bash 26 | source "~/.sdkman/bin/sdkman-init.sh" 27 | sdk install java 11.0.11-open 28 | sdk install sbt 29 | 30 | # Compile PISA 31 | cd ~/Portal-to-ISAbelle 32 | sbt compile 33 | sbt assembly 34 | 35 | # -- Isabelle setup 36 | # Download Isabelle 37 | wget https://isabelle.in.tum.de/dist/Isabelle2022_linux.tar.gz && \ 38 | tar -xzf Isabelle2022_linux.tar.gz 39 | 40 | # Install Isabelle (i.e., move to WORK_DIR, make an alias). 41 | export WORK_DIR=~/ 42 | mv Isabelle2022 ${WORK_DIR}/ 43 | echo 'alias isabelle=${WORK_DIR}/Isabelle2022/bin/isabelle' >> ~/.bashrc 44 | source ~/.bashrc 45 | 46 | # Build Isabelle HOL (creates heaps in ~/.isabelle) 47 | isabelle build -b -D ${WORK_DIR}/Isabelle2022/src/HOL/ -j 20 48 | ``` 49 | 50 | At the end, here's what the setup looks like: 51 | - Portal-to-ISAbelle github repo in `~/Portal-to-ISAbelle` 52 | - Isabelle in `~/Isabelle2022`, e.g. 53 | ``` 54 | ls ~/Isabelle2022 55 | 56 | => ANNOUNCE bin contrib ... 57 | ``` 58 | - Isabelle heaps in `~/.isabelle`, e.g. 59 | ``` 60 | ls ~/.isabelle/Isabelle2022/heaps/polyml-5.9_x86_64_32-linux/ 61 | 62 | => Group-Ring-Module HOL-Corec_Examples HOL-Isar_Examples ... 
63 | ``` 64 | You can test out the installation so far by starting a PISA server: 65 | ```bash 66 | cd ~/Portal-to-ISAbelle 67 | sbt "runMain pisa.server.PisaOneStageServer9000" 68 | ``` 69 | 70 | The next step is to specify a configuration that allows the Python client to talk to the Scala PISA server, as described below. 71 | 72 | ### Configuration 73 | 74 | At a high-level, we have three components: 75 | 1. The PISA Scala server 76 | 2. The PISA python library 77 | 3. Our python client, [Checker](./dsp_utils.py) 78 | 79 | We need to set environment variables and configuration so that all three can talk to each other. 80 | 81 | #### Set PISA_PATH 82 | 83 | First, set a `PISA_PATH` environment variable that points to PISA's python directory: 84 | ```bash 85 | export PISA_PATH=~/Portal-to-ISAbelle/src/main/python 86 | ``` 87 | The variable is used to import PISA's python client (`Portal-to-Isabelle/src/main/python/pisa_client.py`) in Checker. \ 88 | This links components 2 and 3. 89 | 90 | 91 | #### Set up a working directory and working file 92 | PISA is initialized by providing a particular working directory and file. \ 93 | We will create a file called `Interactive.thy` and put it in the `HOL/Examples` directory: 94 | 95 | ```bash 96 | vim ~/Isabelle2022/src/HOL/Examples/Interactive.thy 97 | ``` 98 | ``` 99 | theory Interactive 100 | imports Complex_Main 101 | begin 102 | 103 | end 104 | ``` 105 | We will use this working directory and file when initializing the checker. 106 | 107 | #### Initializing the checker (in Python, e.g. in the [Draft, Sketch, Prove notebook](./notebooks/II_dsp__part2_dsp.ipynb)) 108 | 109 | To initialize the checker, we need to specify the path to Isabelle, the working directory, and the working file (theory file). \ 110 | These are used to initialize a working Isabelle instance. This links components 1 and 2. 
111 | 112 | Here is an example command found in the [Draft, Sketch, Prove notebook](./notebooks/II_dsp__part2_dsp.ipynb) based on the setup above (here, the home directory `~` is `/home/seanw`): 113 | ```python 114 | checker = dsp_utils.Checker( 115 | working_dir='/home/seanw/Isabelle2022/src/HOL/Examples', 116 | isa_path='/home/seanw/Isabelle2022', 117 | theory_file='/home/seanw/Isabelle2022/src/HOL/Examples/Interactive.thy', 118 | port=9000 119 | ) 120 | ``` 121 | 122 | #### Start the PISA server 123 | Finally, start a PISA server in a separate tmux window (similar to what was done above in Installation): 124 | ```bash 125 | cd ~/Portal-to-ISAbelle 126 | sbt "runMain pisa.server.PisaOneStageServer9000" 127 | ``` 128 | The port passed when initializing the checker (here `port=9000`) should match the number that appears in the command (`PisaOneStageServer9000`). 129 | 130 | We *leave the server running while running the notebook* (hence, the separate tmux window). 131 | 132 | #### Run the proof checker! 133 | Now try running the proof checker (e.g. in the [Draft, Sketch, Prove notebook](./notebooks/II_dsp__part2_dsp.ipynb))!
The notebook uses a call to `checker` that looks like: 134 | ```python 135 | theorem_and_sledgehammer_proof = """theorem gcd_lcm: 136 | assumes "gcd (n :: nat) 4 = 1" 137 | and "lcm (n :: nat) 4 = 28" 138 | shows "n = 7" 139 | proof - 140 | have c1: "1*28 = n*4" using assms 141 | sledgehammer 142 | then have c2: "n = 1*28/4" 143 | sledgehammer 144 | then show ?thesis 145 | sledgehammer 146 | qed""" 147 | 148 | result = checker.check(theorem_and_sledgehammer_proof) 149 | 150 | print("\n==== Success: %s" % result['success']) 151 | print("--- Complete proof:\n%s" % result['theorem_and_proof']) 152 | ``` 153 | -------------------------------------------------------------------------------- /partII_dsp/notebooks/II_dsp__part1_intro.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "#### Language cascades | part 1: introduction\n", 8 | "Tutorial on neural theorem proving\\\n", 9 | "Author: Sean Welleck\n", 10 | "\n", 11 | "----------------\n", 12 | "\n", 13 | "Tools such as [Chat-GPT]() show the flexibility of modern neural language generators.\\\n", 14 | "Namely, a single generation system can often perform a task by simply providing a suitable *prompt*:\n", 15 | "\n", 16 | "$\\quad y=f(p_\\theta(\\cdot|x;P)),$\n", 17 | "\n", 18 | "where $x$ is an input, $P$ is a prompt, and $f(\\cdot)$ is a decoding algorithm.\n", 19 | "\n", 20 | "\n", 21 | "Let's look at one of these functions:\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import time\n", 31 | "import os\n", 32 | "import openai\n", 33 | "\n", 34 | "class LMFunction(object):\n", 35 | " def __init__(self, engine='gpt-3.5-turbo', max_tokens=512):\n", 36 | " self.engine = engine\n", 37 | " self.max_tokens = max_tokens\n", 38 | " self.openai = openai\n", 39 | " openai.api_key = 
os.environ['OPENAI_API_KEY']\n", 40 | "\n", 41 | " def _call_api(self, prompt, engine, max_tokens, max_retries=10, retry_wait=2):\n", 42 | " for i in range(max_retries):\n", 43 | " try:\n", 44 | " return self.openai.ChatCompletion.create(\n", 45 | " model=engine,\n", 46 | " messages=[\n", 47 | " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", 48 | " {\"role\": \"user\", \"content\": prompt}\n", 49 | " ],\n", 50 | " max_tokens=max_tokens,\n", 51 | " temperature=1.0\n", 52 | " )\n", 53 | " except self.openai.error.OpenAIError as e:\n", 54 | " time.sleep(retry_wait)\n", 55 | " return {'choices': [{'message': {'content': ''}}]}\n", 56 | "\n", 57 | " def _parse_message(self, msg):\n", 58 | " try:\n", 59 | " content = msg['choices'][0]['message']['content']\n", 60 | " content = content.strip().split('\\n')[0]\n", 61 | " except (IndexError, KeyError):\n", 62 | " content = ''\n", 63 | " return content\n", 64 | "\n", 65 | " def f(self, prompt, x):\n", 66 | " msg = self._call_api(\n", 67 | " prompt=prompt+x,\n", 68 | " engine=self.engine,\n", 69 | " max_tokens=self.max_tokens\n", 70 | " )\n", 71 | " evaluation = self._parse_message(msg)\n", 72 | " return evaluation" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 2, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "data": { 82 | "text/plain": [ 83 | "[('1947256', False),\n", 84 | " ('1945256', False),\n", 85 | " ('1946256', False),\n", 86 | " ('1947176', True),\n", 87 | " ('1947176', True),\n", 88 | " ('1947056', False),\n", 89 | " ('1947456', False),\n", 90 | " ('1947256', False),\n", 91 | " ('1947256', False),\n", 92 | " ('1947256', False)]" 93 | ] 94 | }, 95 | "execution_count": 2, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 | "source": [ 101 | "prompt = \"\"\"Multiply two numbers. 
Here are some examples:\n", 102 | "432*342=147744\n", 103 | "98*19=1862\n", 104 | "\"\"\"\n", 105 | "\n", 106 | "p = LMFunction('gpt-4')\n", 107 | "\n", 108 | "outputs = [p.f(prompt, '872*2233=') for _ in range(10)]\n", 109 | "[(output, output==str(872*2233)) for output in outputs]" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 13, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "data": { 119 | "text/plain": [ 120 | "1947176" 121 | ] 122 | }, 123 | "execution_count": 13, 124 | "metadata": {}, 125 | "output_type": "execute_result" 126 | } 127 | ], 128 | "source": [ 129 | "872*2233" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "The result above shows two interesting things:\n", 137 | "1. The function is stochastic; it can return different answers each time it is called.\n", 138 | "2. The function is capable of producing a correct answer; it gets it correct 2 times.\n", 139 | "\n", 140 | "Therefore, one way a stochastic function like this is useful is to pair it with a reliable verifier.\n", 141 | "\n", 142 | "\n", 143 | "#### Why is this useful?\n", 144 | "The main attraction of these functions is their flexibility. 
For instance,\n", 145 | "it is easy to implement a function that maps language input to a call to Sympy:\n", 146 | "\n", 147 | "```\n", 148 | " sympy_expression ~ f(\"872 times 2233 =\")\n", 149 | " answer = g(sympy_expression)\n", 150 | "```\n", 151 | "where $g(\\cdot)$ is Sympy evaluation.\n", 152 | "\n", 153 | "This yields functionality that is difficult to program manually:" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 8, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "(800+72)*2233\n", 166 | "1947176\n" 167 | ] 168 | } 169 | ], 170 | "source": [ 171 | "from sympy.parsing.sympy_parser import parse_expr\n", 172 | "\n", 173 | "prompt = \"\"\"Solve the multiplication problem by writing input to a sympy function.\n", 174 | "Do not add additional text. Here are some examples:\n", 175 | "\n", 176 | "432 multiplied by 342 is: 432*342\n", 177 | "98* 19 is how much? 98*19\n", 178 | "\"\"\"\n", 179 | "\n", 180 | "p = LMFunction('gpt-4')\n", 181 | "\n", 182 | "g = parse_expr\n", 183 | "\n", 184 | "sympy_expression = p.f(prompt, 'There are 800+72 apples in a barrel. 
How many apples in 2233 barrels?')\n", 185 | "answer = g(sympy_expression)\n", 186 | "print(sympy_expression)\n", 187 | "print(answer)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 7, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "data": { 197 | "text/plain": [ 198 | "1947176" 199 | ] 200 | }, 201 | "execution_count": 7, 202 | "metadata": {}, 203 | "output_type": "execute_result" 204 | } 205 | ], 206 | "source": [ 207 | "872*2233" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "#### Language cascades\n", 215 | "\n", 216 | "A [language model cascade [Dohan et al 2022]](https://arxiv.org/abs/2207.10342) formalizes the idea of composing multiple functions, some of which are stochastic functions implemented by a language model.\n", 217 | "\n", 218 | "The result can be seen as a probabilistic program, whose samples \"execute the function\", e.g.:" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 56, 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "data": { 228 | "text/latex": [ 229 | "$\\displaystyle 4544$" 230 | ], 231 | "text/plain": [ 232 | "4544" 233 | ] 234 | }, 235 | "execution_count": 56, 236 | "metadata": {}, 237 | "output_type": "execute_result" 238 | } 239 | ], 240 | "source": [ 241 | "def multiply(x):\n", 242 | " y1 = p.f(prompt, x)\n", 243 | " y2 = g(y1)\n", 244 | " return y2\n", 245 | "\n", 246 | "\n", 247 | "multiply('I bought 32 cases of apples, with one hundred and 42 apples per case. 
How many total apples?')" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 57, 253 | "metadata": {}, 254 | "outputs": [ 255 | { 256 | "data": { 257 | "text/plain": [ 258 | "4544" 259 | ] 260 | }, 261 | "execution_count": 57, 262 | "metadata": {}, 263 | "output_type": "execute_result" 264 | } 265 | ], 266 | "source": [ 267 | "32*142" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "#### Cascades for neural theorem proving\n", 275 | "\n", 276 | "The two ideas mentioned above: composing multiple functions and using a verifier, make neural theorem proving a natural setting for language cascades.\n", 277 | "\n", 278 | "Namely, the goal will be to decompose theorem proving into different functions, then use the proof assistant to verify the final output.\n", 279 | "\n", 280 | "In the next notebook, we will see a cascade called [Draft, Sketch, Prove [Jiang et al ICLR 2023]](https://arxiv.org/abs/2210.12283) that does so with three components: \\\n", 281 | "**draft** an informal proof, **sketch** a formal proof, and **prove** the remaining gaps.\n", 282 | "\n", 283 | "The end result is a model and proof search procedure that is qualitatively much different than the next-step predictors we used in Part I." 
284 | ] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "metadata": {}, 289 | "source": [] 290 | } 291 | ], 292 | "metadata": { 293 | "kernelspec": { 294 | "display_name": "Python 3 (ipykernel)", 295 | "language": "python", 296 | "name": "python3" 297 | }, 298 | "language_info": { 299 | "codemirror_mode": { 300 | "name": "ipython", 301 | "version": 3 302 | }, 303 | "file_extension": ".py", 304 | "mimetype": "text/x-python", 305 | "name": "python", 306 | "nbconvert_exporter": "python", 307 | "pygments_lexer": "ipython3", 308 | "version": "3.10.11" 309 | } 310 | }, 311 | "nbformat": 4, 312 | "nbformat_minor": 4 313 | } 314 | -------------------------------------------------------------------------------- /partII_dsp/notebooks/II_dsp__part2_dsp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "#### Language cascades | part 2: Draft, Sketch, Prove\n", 8 | "Tutorial on neural theorem proving\\\n", 9 | "Author: Sean Welleck\n", 10 | "\n", 11 | "----------------\n", 12 | "\n", 13 | "### High-level goal\n", 14 | "\n", 15 | "This notebook will implement a prototype version of [Draft, Sketch, Prove [Jiang et al ICLR 2023]](https://arxiv.org/pdf/2210.12283.pdf):\n", 16 | "\n", 17 | "\n", 18 | "\n", 19 | "As pictured above, Draft, Sketch, Prove frames theorem proving as the following procedure. \\\n", 20 | "Given an informal (i.e., Latex) theorem statement $x_I$ and formal theorem statement $x_F$:\n", 21 | "\n", 22 | "1. Generate an *informal* proof $y_{I}\\sim p(\\cdot|x_I,P_{\\text{draft}})$, called a *draft*\n", 23 | "2. Generate a *formal sketch* $z_{F}\\sim p(\\cdot|y_{I}, x_I, x_F, P_{\\text{sketch}})$\n", 24 | "3. Prove the remaining conjectures in the sketch, $y_{F}=f(x_F,z_F)$.\n", 25 | "\n", 26 | "If step $3$ is successful, we will have a verified formal proof of $x_F$. Otherwise we try again. 
\\\n", 27 | "Conceptually, these steps can be viewed as the following program:" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 1, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "def dsp(xi, xf, f_draft, f_sketch, f_proof):\n", 37 | " yi = f_draft(xi)\n", 38 | " zf = f_sketch(yi, xi, xf)\n", 39 | " yf = f_proof(xf, zf)\n", 40 | " return yf" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "Next, we will discuss how to implement these three modules.\n", 48 | "\n", 49 | "We start by introducing the [Isabelle proof assistant](https://isabelle.in.tum.de/), since it is relevant to the implementation.\n", 50 | "\n", 51 | "\n", 52 | "### Isabelle proof assistant\n", 53 | "\n", 54 | "Draft, Sketch, Prove was originally proposed using a proof assistant called [**Isabelle**](https://isabelle.in.tum.de/).\n", 55 | "\n", 56 | "Isabelle has two relevant properties that are helpful to introduce.\n", 57 | "\n", 58 | "#### 1. Declarative proofs\n", 59 | "First, many Isabelle proofs are structured as a sequence of intermediate conjectures (referred to as a *declarative proof*).\n", 60 | "For example, consider the proof below:\n", 61 | "\n", 62 | "\n", 63 | "" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "In **blue** are the intermediate steps, which include two conjectures:\n", 71 | "1. `c1`, which says that $1*28=n*4$\n", 72 | "2. `c2`, which says that $n=1*28/4$\n", 73 | "\n", 74 | "The final step `then show ?thesis` can be thought of as \"the result follows\".\\\n", 75 | "Intuitively, intermediate conjectures can resemble steps in an informal (latex) proof. \n", 76 | "\n", 77 | "\n", 78 | "Isabelle requires a proof of each step, shown in **green**. 
These often involve lower-level premises or calls to external automation.\\\n", 79 | "To obtain these proofs we can use Sledgehammer:" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "#### 2. Sledgehammer\n", 87 | "\n", 88 | "Isabelle has a powerful automation tool called **[Sledgehammer](https://isabelle.in.tum.de/website-Isabelle2009-1/sledgehammer.html)** for producing proofs similar to those shown in green.\n", 89 | "\n", 90 | "In practice, a user would write an intermediate conjecture, e.g. `have c1: \"...`, then call Sledgehammer to find a proof of the conjecture:\n", 91 | "\n", 92 | "\n", 93 | "\n", 94 | "\n", 95 | "Under the covers, Sledgehammer calls out to classical provers that excel at short, low-level proofs. However, fully proving complex theorems with Sledgehammer is typically intractable due to the large search space (for instance, Sledgehammer wouldn't produce the `have c1` *statement*, even though it can prove the statement).\n", 96 | "\n", 97 | "Our use of Sledgehammer is implemented in the `Checker` class in `dsp_utils.py`. 
Please see [isabelle_setup.md](../isabelle_setup.md) to set up the proof checker, and modify the `working_dir`, `isa_path`, and `theory_file` paths below accordingly.\n", 98 | "\n", 99 | "Below, we initialize an Isabelle proof checker that can run Sledgehammer:" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 2, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "import sys\n", 109 | "import os\n", 110 | "sys.path.append('../')\n", 111 | "os.environ['PISA_PATH'] = '/home/seanw/Portal-to-ISAbelle/src/main/python'\n", 112 | "\n", 113 | "import dsp_utils\n", 114 | "\n", 115 | "checker = dsp_utils.Checker(\n", 116 | " working_dir='/home/seanw/Isabelle2022/src/HOL/Examples',\n", 117 | " isa_path='/home/seanw/Isabelle2022',\n", 118 | " theory_file='/home/seanw/Isabelle2022/src/HOL/Examples/Interactive.thy',\n", 119 | " port=9000\n", 120 | ")" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "Now we send the theorem and proof to the checker. 
If sledgehammer succeeds at a given step, the result of sledgehammer is added to the proof, and checking proceeds to the next step.\n", 128 | "\n", 129 | "At the end, we get a completed proof:" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 3, 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "----------Path to Isabelle source----------\n", 142 | "/home/seanw/Isabelle2022\n", 143 | "----------Path to Isabelle working directory----------\n", 144 | "/home/seanw/Isabelle2022/src/HOL/Examples\n", 145 | "----------Path to Isabelle theory file----------\n", 146 | "/home/seanw/Isabelle2022/src/HOL/Examples/Interactive.thy\n", 147 | "\n", 148 | "==== Success: True\n", 149 | "--- Complete proof:\n", 150 | "theorem gcd_lcm:\n", 151 | " assumes \"gcd (n :: nat) 4 = 1\" \n", 152 | " and \"lcm (n :: nat) 4 = 28\"\n", 153 | " shows \"n = 7\"\n", 154 | "proof -\n", 155 | "have c1: \"1*28 = n*4\"\n", 156 | "using assms\n", 157 | "by (metis prod_gcd_lcm_nat)\n", 158 | "then\n", 159 | "have c2: \"n = 1*28/4\"\n", 160 | "by auto\n", 161 | "then\n", 162 | "show ?thesis\n", 163 | "by auto\n", 164 | "qed\n" 165 | ] 166 | } 167 | ], 168 | "source": [ 169 | "theorem_and_sledgehammer_proof = \"\"\"theorem gcd_lcm:\n", 170 | " assumes \"gcd (n :: nat) 4 = 1\" \n", 171 | " and \"lcm (n :: nat) 4 = 28\"\n", 172 | " shows \"n = 7\"\n", 173 | "proof -\n", 174 | " have c1: \"1*28 = n*4\" using assms\n", 175 | " sledgehammer\n", 176 | " then have c2: \"n = 1*28/4\"\n", 177 | " sledgehammer\n", 178 | " then show ?thesis\n", 179 | " sledgehammer\n", 180 | "qed\"\"\"\n", 181 | "\n", 182 | "result = checker.check(theorem_and_sledgehammer_proof)\n", 183 | "\n", 184 | "print(\"\\n==== Success: %s\" % result['success'])\n", 185 | "print(\"--- Complete proof:\\n%s\" % result['theorem_and_proof'])" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | 
"Notice that the `sledgehammer` steps are now filled in (e.g. `by (metis prod_gcd_lcm_nat)`)." 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "## Cascade\n", 200 | "\n", 201 | "The declarative steps (in blue) and the lower-level proofs of those steps (in green) lead to a nice opportunity for a language model cascade.\n", 202 | "\n", 203 | "Namely, Draft, Sketch, Prove uses a neural language model to \"draft\" an informal proof, \"sketch\" the declarative steps based on the informal proof, then attempts to \"close the gaps\" with Sledgehammer (i.e., prove each step). If the gaps are closed, the steps together with their proofs constitute a verified proof of the original theorem (and of course, this is checked by Isabelle).\n", 204 | "\n", 205 | "\n", 206 | "## Draft and Sketch examples\n", 207 | "To implement the drafting and sketching steps, we provide a few examples in the prompt.\n", 208 | "\n", 209 | "We will derive these examples from ones used in the paper:" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 27, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "import sys\n", 219 | "sys.path.append('../')\n", 220 | "import dsp_utils\n", 221 | "\n", 222 | "import json\n", 223 | "examples = [\n", 224 | " {\"tag\": \"aimeI_2000_p7\", \"category\": \"algebra\", \"metadata\": {}, \"prompt\": \"Informal:\\n(*### Problem\\n\\nSuppose that $x,$ $y,$ and $z$ are three positive numbers that satisfy the equations $xyz = 1,$ $x + \\\\frac {1}{z} = 5,$ and $y + \\\\frac {1}{x} = 29.$ Then $z + \\\\frac {1}{y} = \\\\frac {m}{n},$ where $m$ and $n$ are [[relatively prime]] positive integers. Find $m + n$. 
Show that it is 5.\\n\\n\\nnote: this is the type of problem that makes you think symmetry, but actually can be solved easily with substitution, and other normal technniques\\n\\n### Solution\\n\\nWe can rewrite $xyz=1$ as $\\\\frac{1}{z}=xy$.\\n\\nSubstituting into one of the given equations, we have \\n$x+xy=5$\\n$x(1+y)=5$\\n$\\\\frac{1}{x}=\\\\frac{1+y}{5}.$\\n\\nWe can substitute back into $y+\\\\frac{1}{x}=29$ to obtain\\n$y+\\\\frac{1+y}{5}=29$\\n$5y+1+y=145$\\n$y=24.$\\n\\nWe can then substitute once again to get\\n$x=\\\\frac15$\\n$z=\\\\frac{5}{24}.$\\nThus, $z+\\\\frac1y=\\\\frac{5}{24}+\\\\frac{1}{24}=\\\\frac{1}{4}$, so $m+n=005$.*)\\n\\nFormal:\\ntheorem\\n fixes x y z :: real\\n and p :: rat\\n assumes \\\"0 < x \\\\ 0 < y \\\\ 0 < z\\\"\\n and \\\"x * y * z = 1\\\"\\n and \\\"x + 1 / z = 5\\\"\\n and \\\"y + 1 / x = 29\\\"\\n and \\\"z + 1 / y = p\\\"\\n and \\\"0 < p\\\" \\n shows \\\"let (m,n) = quotient_of p in m + n = 5\\\"\\nproof -\\n (* We can rewrite $xyz=1$ as $\\\\frac{1}{z}=xy$. *)\\n have c0: \\\"z = 1 / (x*y)\\\"\\n sledgehammer\\n (* Substituting into one of the given equations, we have \\n $x+xy=5$\\n $x(1+y)=5$\\n $\\\\frac{1}{x}=\\\\frac{1+y}{5}.$ *)\\n have c1: \\\"1 / x = (1+y) / 5\\\" \\n proof -\\n have \\\"x + x * y = 5\\\" using assms(3) unfolding c0\\n sledgehammer\\n then have \\\"x * (1 + y) = 5\\\"\\n sledgehammer\\n then have t1: \\\"x = 5 / (1+y)\\\"\\n sledgehammer\\n then show ?thesis\\n sledgehammer\\n qed\\n (* We can substitute back into $y+\\\\frac{1}{x}=29$ to obtain\\n $y+\\\\frac{1+y}{5}=29$\\n $5y+1+y=145$\\n $y=24.$ *)\\n have \\\"y + (1+y)/5 = 29\\\" using assms(4) unfolding c1 sledgehammer\\n then have \\\"5* (y + (1+y)/5) = 5 * 29\\\" sledgehammer\\n also have \\\"... = 145\\\" sledgehammer\\n finally have c2_1: \\\"5* (y + (1+y)/5) = 145\\\" sledgehammer\\n have \\\"5* (y + (1+y)/5) = 5*y + (1+y)\\\" sledgehammer\\n also have \\\"... 
= 6*y + 1\\\" sledgehammer\\n finally have c2_2: \\\"5* (y + (1+y)/5) = 6*y + 1\\\" sledgehammer\\n have \\\"6*y + 1 = 145\\\" using c2_1 c2_2 sledgehammer\\n then have c2: \\\"y = 24\\\" sledgehammer\\n (* We can then substitute once again to get\\n $x=\\\\frac15$\\n $z=\\\\frac{5}{24}.$ *)\\n have \\\"1/x = 5\\\" using c1 unfolding c2 sledgehammer\\n then have c3: \\\"x = 1/5\\\"\\n sledgehammer\\n then have c4: \\\"z = 5/24\\\"\\n sledgehammer\\n (* Thus, $z+\\\\frac1y=\\\\frac{5}{24}+\\\\frac{1}{24}=\\\\frac{1}{4}$, so $m+n=005$. *)\\n have \\\"p = z + 1/y\\\" using assms(5) sledgehammer\\n also have \\\"... = 5/24 + 1/24\\\" unfolding c2 c4 sledgehammer\\n also have \\\"... = 1/4\\\" sledgehammer\\n finally have c5: \\\"p = 1/4\\\"\\n sledgehammer\\n have \\\"quotient_of p = (1, 4)\\\" unfolding c5 sledgehammer\\n then show ?thesis sledgehammer\\nqed\"},\n", 225 | " {\"tag\": \"algebra_2rootsintpoly_am10tap11eqasqpam110\", \"category\": \"algebra\", \"metadata\": {}, \"prompt\": \"Informal:\\n(*### Problem\\n\\nShow that for any complex number a, $(a-10)(a+11) = a^2 + a - 110$.\\n\\n### Solution\\n\\nWe first expand all terms of the left hand side to get $a^2 - 10a + 11a - 10*11$.\\nThis equals $a^2 + a - 10*11 = a^2 + a - 110$.*)\\n\\nFormal:\\ntheorem\\n fixes a :: complex\\n shows \\\"(a-10) * (a+11) = a^2 + a -110\\\"\\nproof -\\n (* We first expand all terms of the left hand side to get $a^2 - 10a + 11a - 10*11$. *)\\n have \\\"(a-10) * (a+11) = a^2 - 10*a + 11*a - 10 *11\\\"\\n sledgehammer\\n (* This equals $a^2 + a - 10*11 = a^2 + a - 110$. *)\\n also have \\\"\\\\ = a^2 + a - 10 * 11\\\"\\n sledgehammer\\n also have \\\"\\\\ = a^2 + a - 110\\\"\\n sledgehammer\\n finally show ?thesis\\n sledgehammer\\nqed\"},\n", 226 | " {\"tag\": \"mathd_numbertheory_335\", \"category\": \"number_theory\", \"metadata\": {}, \"prompt\": \"Informal:\\n(*### Problem\\n\\nWhen Rachel divides her favorite number by 7, she gets a remainder of 5. 
What will the remainder be if she multiplies her favorite number by 5 and then divides by 7? Show that it is 4.\\n\\n### Solution\\n\\nLet $n$ be Rachel's favorite number. \\nThen $n \\\\equiv 5 \\\\pmod{7}$, so $5n \\\\equiv 5 \\\\cdot 5 \\\\equiv 25 \\\\equiv 4 \\\\pmod{7}$.\\n*)\\n\\nFormal:\\ntheorem\\n fixes n :: nat\\n assumes h0 : \\\"n mod 7 = 5\\\"\\n shows \\\"(5 * n) mod 7 = 4\\\"\\nproof -\\n (* Then $n \\\\equiv 5 \\\\pmod{7}$, so $5n \\\\equiv 5 \\\\cdot 5 \\\\equiv 25 \\\\equiv 4 \\\\pmod{7}$. *)\\n have c0:\\\"(5 * n) mod 7 = (5 * 5) mod 7\\\" using h0\\n sledgehammer\\n then have \\\"\\\\ = 4\\\" sledgehammer\\n then have \\\"(5 * n) mod 7 = 4\\\" using c0 sledgehammer\\n then show ?thesis sledgehammer\\nqed\"}\n", 227 | "]" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": {}, 233 | "source": [ 234 | "## Draft\n", 235 | "\n", 236 | "This function generates an *informal* proof $y_{I}\\sim p(\\cdot|x_I,P_{\\text{draft}})$, called a *draft*.\n", 237 | "\n", 238 | "Here, $P_{\\text{draft}}$ is a prompt containing examples of mapping the informal theorem statement $x_I$ to an informal proof $y_I$:\n" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 51, 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "name": "stdout", 248 | "output_type": "stream", 249 | "text": [ 250 | "Draft an informal solution similar to below. \n", 251 | "The informal solution will be used to sketch a formal Isabelle proof.\n", 252 | "Here are some examples:\n", 253 | "Example:\n", 254 | "Informal:\n", 255 | "(*### Problem\n", 256 | "\n", 257 | "Suppose that $x,$ $y,$ and $z$ are three positive numbers that satisfy the equations $xyz = 1,$ $x + \\frac {1}{z} = 5,$ and $y + \\frac {1}{x} = 29.$ Then $z + \\frac {1}{y} = \\frac {m}{n},$ where $m$ and $n$ are [[relatively prime]] positive integers. Find $m + n$. 
Show that it is 5.\n", 258 | "\n", 259 | "\n", 260 | "note: this is the type of problem that makes you think symmetry, but actually can be solved easily with substitution, and other normal technniques\n", 261 | "\n", 262 | "### Solution\n", 263 | "\n", 264 | "We can rewrite $xyz=1$ as $\\frac{1}{z}=xy$.\n", 265 | "\n", 266 | "Substituting into one of the given equations, we have \n", 267 | "$x+xy=5$\n", 268 | "$x(1+y)=5$\n", 269 | "$\\frac{1}{x}=\\frac{1+y}{5}.$\n", 270 | "\n", 271 | "We can substitute back into $y+\\frac{1}{x}=29$ to obtain\n", 272 | "$y+\\frac{1+y}{5}=29$\n", 273 | "$5y+1+y=145$\n", 274 | "$y=24.$\n", 275 | "\n", 276 | "We can then substitute once again to get\n", 277 | "$x=\\frac15$\n", 278 | "$z=\\frac{5}{24}.$\n", 279 | "Thus, $z+\\frac1y=\\frac{5}{24}+\\frac{1}{24}=\\frac{1}{4}$, so $m+n=005$.*)\n", 280 | "\n", 281 | "\n", 282 | "\n", 283 | "Example:\n", 284 | "Informal:\n", 285 | "(*### Problem\n", 286 | "\n", 287 | "Show that for any complex number a, $(a-10)(a+11) = a^2 + a - 110$.\n", 288 | "\n", 289 | "### Solution\n", 290 | "\n", 291 | "We first expand all terms of the left hand side to get $a^2 - 10a + 11a - 10*11$.\n", 292 | "This equals $a^2 + a - 10*11 = a^2 + a - 110$.*)\n", 293 | "\n", 294 | "\n", 295 | "\n", 296 | "Example:\n", 297 | "Informal:\n", 298 | "(*### Problem\n", 299 | "\n", 300 | "When Rachel divides her favorite number by 7, she gets a remainder of 5. What will the remainder be if she multiplies her favorite number by 5 and then divides by 7? Show that it is 4.\n", 301 | "\n", 302 | "### Solution\n", 303 | "\n", 304 | "Let $n$ be Rachel's favorite number. \n", 305 | "Then $n \\equiv 5 \\pmod{7}$, so $5n \\equiv 5 \\cdot 5 \\equiv 25 \\equiv 4 \\pmod{7}$.\n", 306 | "*)\n", 307 | "\n", 308 | "\n", 309 | "\n", 310 | "Informal:\n", 311 | "(*### Problem\n", 312 | "\n", 313 | "\n" 314 | ] 315 | } 316 | ], 317 | "source": [ 318 | "prompt = \"\"\"Draft an informal solution similar to below. 
\n", 319 | "The informal solution will be used to sketch a formal Isabelle proof.\n", 320 | "Here are some examples:\n", 321 | "\"\"\"\n", 322 | "for x in examples:\n", 323 | " prompt += (\"Example:\\n\" + x['prompt'][:x['prompt'].find('Formal:')] + \"\\n\\n\")\n", 324 | "prompt += \"\"\"Informal:\n", 325 | "(*### Problem\n", 326 | "\n", 327 | "\"\"\"\n", 328 | "\n", 329 | "print(prompt)" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 44, 335 | "metadata": {}, 336 | "outputs": [ 337 | { 338 | "name": "stdout", 339 | "output_type": "stream", 340 | "text": [ 341 | "### Solution\n", 342 | "\n", 343 | "If x is even, then it can be represented as 2n where n is some integer.\n", 344 | "So x + 5 equals 2n + 5.\n", 345 | "We can rewrite 2n + 5 as 2(n + 2) + 1, which is in the form of 2k + 1 where k is an integer (in this case, n + 2). Thus, x + 5 is odd.\n" 346 | ] 347 | } 348 | ], 349 | "source": [ 350 | "p = dsp_utils.LMFunction('gpt-4')\n", 351 | "xi = 'Show that if x is even, then x+5 is odd'\n", 352 | "yi = p.f(prompt, xi)\n", 353 | "print(yi)" 354 | ] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "metadata": {}, 359 | "source": [ 360 | "## Sketch\n", 361 | "\n", 362 | "Generate a *formal sketch* $z_{F}\\sim p(\\cdot|y_{I}, x_I, x_F, P_{\\text{sketch}})$\n", 363 | "\n", 364 | "Here, $P_{\\text{sketch}}$ is a prompt containing the examples from the drafting step with an additional formal sketch." 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": 55, 370 | "metadata": {}, 371 | "outputs": [ 372 | { 373 | "name": "stdout", 374 | "output_type": "stream", 375 | "text": [ 376 | "Formal:\n", 377 | "proof -\n", 378 | " (* If x is even, then it can be represented as 2n where n is some integer. *)\n", 379 | " obtain n where c1: \"x = 2*n\"\n", 380 | " using evenE assms\n", 381 | " sledgehammer\n", 382 | " (* So x + 5 equals 2n + 5. 
*)\n", 383 | " then have \"x + 5 = 2*n + 5\" \n", 384 | " sledgehammer\n", 385 | " (* We can rewrite 2n + 5 as 2(n + 2) + 1, which is in the form of 2k + 1 where k is an integer (in this case, n + 2). Thus, x + 5 is odd. *)\n", 386 | " also have \"\\ = 2*(n+2) + 1\"\n", 387 | " sledgehammer\n", 388 | " then have exI: \"\\k. x + 5 = 2*k+1\" \n", 389 | " sledgehammer\n", 390 | " then have \"odd (x+5)\" \n", 391 | " sledgehammer\n", 392 | " then show ?thesis \n", 393 | " sledgehammer\n", 394 | "qed\n" 395 | ] 396 | } 397 | ], 398 | "source": [ 399 | "prompt = \"\"\"Translate the informal solution into a sketch of the\n", 400 | "formal Isabelle proof. Add `sledgehammer` in the sketch whenever\n", 401 | "possible. `sledgehammer` will be used to call the automated Sledgehammer prover. \n", 402 | "Here are some examples:\n", 403 | "\"\"\"\n", 404 | "for x in examples:\n", 405 | " prompt += (x['prompt'] + \"\\n\\n\")\n", 406 | "prompt += \"\"\"Informal:\n", 407 | "(*### Problem\n", 408 | "\n", 409 | "\"\"\"\n", 410 | "\n", 411 | "xf = \"\"\"theorem\n", 412 | "fixes x :: int\n", 413 | "assumes h0: \"even x\"\n", 414 | "shows \"odd (x+5)\" \"\"\"\n", 415 | "\n", 416 | "zi = p.f(prompt, xi + '\\n\\n' + yi + '\\n\\n' + xf)\n", 417 | "print(zi)" 418 | ] 419 | }, 420 | { 421 | "cell_type": "markdown", 422 | "metadata": {}, 423 | "source": [ 424 | "## Proof\n", 425 | "\n", 426 | "Finally, we call [Sledgehammer](https://isabelle.in.tum.de/website-Isabelle2009-1/sledgehammer.html) to prove the remaining intermediate conjectures.\n", 427 | "\n", 428 | "You can see the completed proof printed in the cell output:\n" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": 4, 434 | "metadata": {}, 435 | "outputs": [ 436 | { 437 | "name": "stdout", 438 | "output_type": "stream", 439 | "text": [ 440 | "----------Path to Isabelle source----------\n", 441 | "/home/seanw/Isabelle2022\n", 442 | "----------Path to Isabelle working directory----------\n", 443 | 
"/home/seanw/Isabelle2022/src/HOL/Examples\n", 444 | "----------Path to Isabelle theory file----------\n", 445 | "/home/seanw/Isabelle2022/src/HOL/Examples/Interactive.thy\n", 446 | "\n", 447 | "==== Success: True\n", 448 | "--- Complete proof:\n", 449 | "theorem\n", 450 | "fixes x :: int\n", 451 | "assumes h0: \"even x\"\n", 452 | "shows \"odd (x+5)\"\n", 453 | "proof -\n", 454 | "(* If x is even, then it can be represented as 2n where n is some integer. *)\n", 455 | "obtain n where c1: \"x = 2*n\"\n", 456 | "using evenE assms\n", 457 | "by auto\n", 458 | "(* So x + 5 equals 2n + 5. *)\n", 459 | "then\n", 460 | "have \"x + 5 = 2*n + 5\"\n", 461 | "by auto\n", 462 | "(* We can rewrite 2n + 5 as 2(n + 2) + 1, which is in the form of 2k + 1 where k is an integer (in this case, n + 2). Thus, x + 5 is odd. *)\n", 463 | "also\n", 464 | "have \"\\ = 2*(n+2) + 1\"\n", 465 | "by auto\n", 466 | "then\n", 467 | "have exI: \"\\k. x + 5 = 2*k+1\"\n", 468 | "using c1 by blast\n", 469 | "then\n", 470 | "have \"odd (x+5)\"\n", 471 | "by presburger\n", 472 | "then\n", 473 | "show ?thesis\n", 474 | "by auto\n", 475 | "qed\n" 476 | ] 477 | } 478 | ], 479 | "source": [ 480 | "theorem_with_proof = \"\"\"theorem\n", 481 | "fixes x :: int\n", 482 | "assumes h0: \"even x\"\n", 483 | "shows \"odd (x+5)\"\n", 484 | "proof -\n", 485 | " (* If x is even, then it can be represented as 2n where n is some integer. *)\n", 486 | " obtain n where c1: \"x = 2*n\"\n", 487 | " using evenE assms\n", 488 | " sledgehammer\n", 489 | " (* So x + 5 equals 2n + 5. *)\n", 490 | " then have \"x + 5 = 2*n + 5\" \n", 491 | " sledgehammer\n", 492 | " (* We can rewrite 2n + 5 as 2(n + 2) + 1, which is in the form of 2k + 1 where k is an integer (in this case, n + 2). Thus, x + 5 is odd. *)\n", 493 | " also have \"\\ = 2*(n+2) + 1\"\n", 494 | " sledgehammer\n", 495 | " then have exI: \"\\k. 
x + 5 = 2*k+1\" \n", 496 | "    sledgehammer\n", 497 | "  then have \"odd (x+5)\" \n", 498 | "    sledgehammer\n", 499 | "  then show ?thesis \n", 500 | "    sledgehammer\n", 501 | "qed\"\"\"\n", 502 | "\n", 503 | "result = checker.check(theorem_with_proof)\n", 504 | "\n", 505 | "print(\"\\n==== Success: %s\" % result['success'])\n", 506 | "print(\"--- Complete proof:\\n%s\" % result['theorem_and_proof'])" 507 | ] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "metadata": {}, 512 | "source": [ 513 | "We now have a verified formal proof of the claim \"If x is even, then x+5 is odd\", and the proof is annotated with informal proof steps as comments (the text inside of `(*....*)`). Pretty cool!" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | "## Proof search\n", 521 | "\n", 522 | "In the simple example above, the formal proof was successful on the first try.\n", 523 | "\n", 524 | "In more complex settings we need to try multiple times. \n", 525 | "\n", 526 | "Namely, we can sample multiple drafts, sample multiple formal sketches for each draft, then see if any of them can be successfully proved with Sledgehammer:\n", 527 | "\n", 528 | "\n", 529 | "\n", 530 | "This proof search algorithm is different from the best-first search in next-step suggestion. Namely, it:\n", 531 | "1. narrows the search space to proofs that are \"similar to\" the informal proof\\*\n", 532 | "2. does not interact with the proof assistant after each step\n", 533 | "\n", 534 | "\\*Though ultimately it is up to the neural language model to decide how to use the informal proof (if at all)."
535 | ] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "metadata": {}, 540 | "source": [ 541 | "### Scaling up proof search\n", 542 | "\n", 543 | "The Draft, Sketch, Prove paper shows the effect of scaling up proof search; that is, sampling multiple drafts and/or multiple formal sketches and attempting to verify them.\n", 544 | "\n", 545 | "Namely, proof search with a single sampled sequence yielded less than 100 successful proofs, but scaling to 100 sampled drafts and/or sketches yielded almost 200 successful proofs:\n", 546 | "\n", 547 | "" 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": {}, 553 | "source": [ 554 | "Naturally, a better model would require a smaller search budget when measured in terms of number of calls to the model. For instance, imagine a very good model that proved the same ~200 theorems on its first try. This would require fewer calls to the model than above.\n", 555 | "\n", 556 | "Second, the search algorithm used was fairly naive (temperature sampling that only interacts with the proof assistant once). A better search algorithm could yield more successful proofs, and/or a smaller search budget to reach a given number of successful proofs." 557 | ] 558 | }, 559 | { 560 | "cell_type": "markdown", 561 | "metadata": {}, 562 | "source": [ 563 | "#### Examples\n", 564 | "\n", 565 | "Here is an interesting example from the Draft, Sketch, Prove paper. It shows how the neural model can generate an informal proof that differs from the human-written informal proof. The resulting sketches, and formal proofs produced by the system are much different:\n", 566 | "\n", 567 | "" 568 | ] 569 | }, 570 | { 571 | "cell_type": "markdown", 572 | "metadata": {}, 573 | "source": [ 574 | "Feel free to play around with the gpt-4 based implementation above to see what it can do." 
575 | ] 576 | }, 577 | { 578 | "cell_type": "markdown", 579 | "metadata": {}, 580 | "source": [ 581 | "----------------------\n", 582 | "\n", 583 | "# Other cascades\n", 584 | "\n", 585 | "[Baldur [First et al 2023]](https://arxiv.org/pdf/2303.04910.pdf) develop a **refinement**, or *proof repair*, module to correct proofs using error messages $e$:\n", 586 | "\n", 587 | "- $y_F^1\\sim p(\\cdot|x_F)$\n", 588 | "- $y_F^{2}\\sim p(\\cdot|x_F,y_F^1,e)$\n", 589 | "\n", 590 | "\n", 591 | "[pySagredo [Azerbayev 2023]](https://github.com/zhangir-azerbayev/pySagredo) is an experimental tactic that uses refinement and GPT-4 to prove theorems in Lean4.\n", 592 | "\n", 593 | "\n", 594 | "More generally, integrating modern language models with proof assistants remains an active area of research." 595 | ] 596 | }, 597 | { 598 | "cell_type": "markdown", 599 | "metadata": {}, 600 | "source": [] 601 | } 602 | ], 603 | "metadata": { 604 | "kernelspec": { 605 | "display_name": "Python 3 (ipykernel)", 606 | "language": "python", 607 | "name": "python3" 608 | }, 609 | "language_info": { 610 | "codemirror_mode": { 611 | "name": "ipython", 612 | "version": 3 613 | }, 614 | "file_extension": ".py", 615 | "mimetype": "text/x-python", 616 | "name": "python", 617 | "nbconvert_exporter": "python", 618 | "pygments_lexer": "ipython3", 619 | "version": "3.10.11" 620 | } 621 | }, 622 | "nbformat": 4, 623 | "nbformat_minor": 4 624 | } 625 | -------------------------------------------------------------------------------- /partII_dsp/notebooks/images/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/.gitkeep -------------------------------------------------------------------------------- /partII_dsp/notebooks/images/dsp.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/dsp.png -------------------------------------------------------------------------------- /partII_dsp/notebooks/images/dsp_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/dsp_example.png -------------------------------------------------------------------------------- /partII_dsp/notebooks/images/dsp_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/dsp_plot.png -------------------------------------------------------------------------------- /partII_dsp/notebooks/images/dsp_search.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/dsp_search.png -------------------------------------------------------------------------------- /partII_dsp/notebooks/images/prove.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/prove.png -------------------------------------------------------------------------------- /partII_dsp/notebooks/images/sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/sketch.png -------------------------------------------------------------------------------- /partII_dsp/notebooks/images/sledgehammer.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/sledgehammer.png -------------------------------------------------------------------------------- /partI_nextstep/README.md: -------------------------------------------------------------------------------- 1 | ## Part I : Next-step suggestion 2 | 3 | Builds a neural next-step suggestion tool, introducing key concepts and overviewing past work in neural theorem proving along the way. 4 | 5 | 6 | 7 | #### Notebooks: 8 | | Topic | Notebook | 9 | |:-----------------------|-------:| 10 | | 0. Intro | [notebook](./notebooks/I_nextstep_lean__part0_intro.ipynb) | 11 | | 1. Data | [notebook](./notebooks/I_nextstep_lean__part1_data.ipynb) | 12 | | 2. Learning | [notebook](./notebooks/I_nextstep_lean__part2_learn.ipynb) | 13 | | 3. Proof Search | [notebook](./notebooks/I_nextstep_lean__part3_proofsearch.ipynb) | 14 | | 4. Evaluation | [notebook](./notebooks/I_nextstep_lean__part4_evaluation.ipynb) | 15 | | 5. `llmsuggest` | [notebook](./notebooks/I_nextstep_lean__part5_llmsuggest.ipynb) | 16 | 17 | All notebooks are in ([`partI_nextstep/notebooks`](./notebooks)). See [`partI_nextstep/ntp_python`](./ntp_python) and [`partI_nextstep/ntp_lean`](./ntp_lean) for the Python and Lean files covered in the notebooks. 18 | 19 | ## Setup 20 | The notebooks use several tools: Lean (in VSCode), `pylean`, and LeanDojo. It also uses Pytorch and Huggingface for language modeling. Below are setup steps: 21 | 22 | ### Setup Lean 23 | 24 | #### Setup Lean in VS Code 25 | To try the interactive tool, you will need VS Code and Lean 4. 
26 | 27 | Please follow the [official instructions for installing Lean 4 in VS Code](https://leanprover-community.github.io/install/macos_details.html#installing-and-configuring-an-editor): [Installing and configuring an editor](https://leanprover-community.github.io/install/macos_details.html#installing-and-configuring-an-editor). 28 | 29 | 30 | #### Setup Lean on the command line 31 | 32 | Additionally, to run the notebooks you will need Lean 4 installed on your laptop/server. 33 | 34 | On Linux or in a Colab notebook, run this command: 35 | ``` 36 | # from https://leanprover-community.github.io/install/linux.html 37 | curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh 38 | source $HOME/.elan/env 39 | lake 40 | ``` 41 | 42 | For MacOS see: 43 | ``` 44 | https://leanprover-community.github.io/install/macos.html 45 | ``` 46 | 47 | ### Setup [`pylean`](https://github.com/zhangir-azerbayev/repl/tree/master) 48 | The `Proof Search` notebook uses [`pylean`](https://github.com/zhangir-azerbayev/repl/tree/master). 49 | 50 | ```bash 51 | git clone 52 | https://github.com/zhangir-azerbayev/repl 53 | cd repl 54 | 55 | git checkout bddf452deda0df2240b248e651bcc37fb8e59d01 56 | 57 | cd pylean 58 | python setup.py develop 59 | ``` 60 | 61 | Then add the following to `repl/lakefile.lean`: 62 | ``` 63 | require mathlib from git 64 | "https://github.com/leanprover-community/mathlib4.git" @ "38dbcd8285bc4b1391619c12f158a7409f3dfc12" 65 | ``` 66 | 67 | 68 | ### Setup LeanDojo 69 | If you want to reproduce the evaluation discussed in the `Evaluation` notebook and implemented in `proofsearch_dojo.py`, you will need to install Lean Dojo: 70 | ``` 71 | pip install lean-dojo==1.1.2 72 | export CONTAINER="native" 73 | ``` 74 | The second line is needed to run LeanDojo outside of Docker. 75 | 76 | ### Setup language modeling tools 77 | The notebooks use `pytorch` and Huggingface `transformers` and `datasets`. 
78 | Here is a Conda environment file for an environment that this ran on (excluding the additional libraries above): 79 | ``` 80 | name: ntp 81 | channels: 82 | - defaults 83 | dependencies: 84 | - python==3.10.11 85 | - pip 86 | - pip: 87 | - accelerate==0.21.0 88 | - datasets==2.13.1 89 | - huggingface-hub==0.15.1 90 | - ndjson==0.3.1 91 | - nest-asyncio==1.5.6 92 | - networkx==3.1 93 | - nh3==0.2.14 94 | - ninja==1.11.1 95 | - numpy==1.25.0 96 | - torch==2.0.1 97 | - tqdm==4.65.0 98 | - transformers==4.31.0 99 | prefix: /home/seanw/.conda/envs/ntp 100 | ``` 101 | -------------------------------------------------------------------------------- /partI_nextstep/notebooks/I_nextstep_lean__part0_intro.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "#### Neural next-step prediction \n", 8 | "Tutorial on neural theorem proving\\\n", 9 | "Author: Sean Welleck\n", 10 | "\n", 11 | "----------------" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "### High-level goal\n", 19 | "\n", 20 | "In *[Generative Language Modeling for Automated Theorem Proving](https://arxiv.org/abs/2009.03393)*, Stanislas Polu and Ilya Sutskever showed that a language model called *gptf* was capable of generating novel proofs that were adopted by a formal mathematics community. 
Their key insight was to frame *theorem proving as language generation*: first, use a neural language model $p_\\theta(y_t|x_t)$ to model the distribution over potential next-steps to take $y_t$ given the current state of a proof $x_t$, by simply treating $x_t$ and $y_t$ as discrete token sequences; then, use the language model within a search algorithm to generate a full proof.\n", 21 | "\n", 22 | "Since their work in 2020, this approach-which we refer to as *next-step prediction*-has formed the basis for research on combining language models with formal proof assistants, which we refer to as *neural theorem proving*. \n", 23 | "Neural theorem proving involves *interacting* with a proof assistant in order to generate a *verifiable* formal proof. This differs from traditional language generation tasks such as long-form question answering, which are often performed without interaction and are difficult to reliably evaluate in a scalable way.\n", 24 | "On the other hand, formal code is extremely scarce compared to the natural language found in everyday questions, or the Python code used to benchmark language-model-based code generators.\n", 25 | "As a result, neural theorem proving offers a unique, yet challenging, playground for creative algorithmic development with language models.\n", 26 | "\n", 27 | "\n", 28 | "\n", 29 | "\n", 30 | "Beyond algorithmic development, neural theorem proving offers an opportunity to enable new and useful tools.\n", 31 | "This tutorial walks through building an interactive tool for receiving next-step suggestions from a neural language model. 
Doing so involves [collecting data (part 1)](./I_nextstep_lean__part1_data.ipynb), [learning a model (part 2)](./I_nextstep_lean__part2_learn.ipynb), measuring performance with [proof search (part 3)](./I_nextstep_lean__part3_proofsearch.ipynb) and [evaluation sets (part 4)](./I_nextstep_lean__part4_evaluation.ipynb), and deploying the model as an [interactive tool (part 5)](./I_nextstep_lean__part5_llmsuggest.ipynb). \n", 32 | "By working through these notebooks, you will see core research ideas and get a practical introduction to neural theorem proving. We hope that doing so lays a foundation for doing research in this area, and that the end product shows a glimpse of \"human-machine collaboration\" that we expect will only expand further in the future." 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "
\n", 40 | "" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [] 47 | } 48 | ], 49 | "metadata": { 50 | "kernelspec": { 51 | "display_name": "Python 3 (ipykernel)", 52 | "language": "python", 53 | "name": "python3" 54 | }, 55 | "language_info": { 56 | "codemirror_mode": { 57 | "name": "ipython", 58 | "version": 3 59 | }, 60 | "file_extension": ".py", 61 | "mimetype": "text/x-python", 62 | "name": "python", 63 | "nbconvert_exporter": "python", 64 | "pygments_lexer": "ipython3", 65 | "version": "3.10.11" 66 | } 67 | }, 68 | "nbformat": 4, 69 | "nbformat_minor": 4 70 | } 71 | -------------------------------------------------------------------------------- /partI_nextstep/notebooks/I_nextstep_lean__part1_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "#### Neural next-step prediction | part 1: data\n", 9 | "Tutorial on neural theorem proving\\\n", 10 | "Author: Sean Welleck\n", 11 | "\n", 12 | "----------------" 13 | ] 14 | }, 15 | { 16 | "attachments": {}, 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "#### High-level goal\n", 21 | "\n", 22 | "Our goal is to train a neural next-step prediction model, $p(y_t|x_t)$. Here $x_t$ is a _proof state_, and $y_t$ is a next-step.\n", 23 | "\n", 24 | "To do so, we will create a dataset $\\mathcal{D}=\\{(x_t,y_t)\\}$ from human-written proofs. \n", 25 | "\n", 26 | "We can then train a neural next-step prediction model using a next-token prediction loss on the dataset." 
27 | ] 28 | }, 29 | { 30 | "attachments": {}, 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "#### Simple example\n", 35 | "\n", 36 | "To see what proof states and next-steps look like, let's look at an example human-written theorem and proof:\n", 37 | "\n" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 1, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "name": "stdout", 47 | "output_type": "stream", 48 | "text": [ 49 | "import Mathlib.Data.Nat.Prime\n", 50 | "\n", 51 | "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by \n", 52 | " rw [Nat.coprime] at h \n", 53 | " exact h " 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "!cat ../ntp_lean/examples/example0.lean" 59 | ] 60 | }, 61 | { 62 | "attachments": {}, 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "We would like to transform this theorem and proof into a sequence of (proof_state, next_step) examples.\n", 67 | "\n", 68 | "First, notice that the proof has two steps:\n", 69 | "\n", 70 | "1. $y_1=$ `rw [Nat.coprime] at h`\n", 71 | "2. $y_2=$ `exact h`" 72 | ] 73 | }, 74 | { 75 | "attachments": {}, 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "We can manually see the proof states by looking in VSCode. 
\n", 80 | "\n", 81 | "For example, placing the cursor before $y_1$ gives us the proof state $x_1$ (shown as \"Tactic state\"):" 82 | ] 83 | }, 84 | { 85 | "attachments": {}, 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "![title](images/proof_state_1.png)" 90 | ] 91 | }, 92 | { 93 | "attachments": {}, 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "That is, the image above corresponds to $(x_1,y_1)$ defined as:\n", 98 | "\n", 99 | " $x_1$: \n", 100 | " ```\n", 101 | " m n : ℕ\n", 102 | " h : Nat.coprime m n\n", 103 | " ⊢ Nat.gcd m n = 1\n", 104 | " ```\n", 105 | "\n", 106 | " $y_1$: `rw [Nat.coprime] at h`\n", 107 | "\n", 108 | "\n", 109 | "Similarly, we can get the proof state $x_2$ prior to the step $y_2$ (`exact h`):\n", 110 | "\n", 111 | "![title](images/proof_state_2.png)\n", 112 | "\n", 113 | "After step $y_2$, the proof is complete: the proof state $x_3$ says we have \"No goals\":\n", 114 | "\n", 115 | "![title](images/proof_state_3.png)\n", 116 | "\n", 117 | "In summary, it is possible to *manually* transform the theorem and proof into a sequence $[(x_1,y_1),(x_2,y_2),(x_3)]$." 118 | ] 119 | }, 120 | { 121 | "attachments": {}, 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## Automatically extracting proof states and next-steps \n", 126 | "\n", 127 | "To scale up data collection, we need a way to *automatically* extract proof states and next-steps from human-written proofs.\n", 128 | "\n", 129 | "\n", 130 | "\n", 131 | "A new open-source library by Kaiyu Yang et al. called [LeanDojo](https://leandojo.org/) can automatically extract (proof state, next-step) pairs from Lean proofs. This idea originated in [Han et al ICLR 2022](https://github.com/jesse-michael-han/lean-step-public). 
We will look at a simplified version of what LeanDojo does.\n", 132 | "\n", 133 | "The core idea is to (1) transform a Lean file into abstract syntax trees using Lean, and (2) postprocess the abstract syntax tree into a dataset. Lean4's powerful metaprogramming functionality give us the tools to do this." 134 | ] 135 | }, 136 | { 137 | "attachments": {}, 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "#### 1. Transform a Lean file\n", 142 | "\n", 143 | "Conceptually, we want a script:\n", 144 | "\n", 145 | "$\\quad f_{\\text{extract}}(\\text{lean file})\\rightarrow \\text{ASTs}$,\n", 146 | "\n", 147 | "We run a simplified version of the script `ExtractData.lean` from LeanDojo:\n", 148 | "" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 2, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "name": "stdout", 158 | "output_type": "stream", 159 | "text": [ 160 | "Input file: partI_nextstep/ntp_lean/examples/example0.lean\n", 161 | "AST: partI_nextstep/ntp_lean/examples/example0.ast.json\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "!cd ../../ && lake env lean --run partI_nextstep/ntp_lean/ExtractSimple.lean partI_nextstep/ntp_lean/examples/example0.lean" 167 | ] 168 | }, 169 | { 170 | "attachments": {}, 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "The output file `example.ast.json` includes proof states and abstract syntax trees for the commands in `example0.lean`.\n", 175 | "\n", 176 | "Here are the proof states for our example:" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 3, 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "data": { 186 | "text/plain": [ 187 | "[{'stateBefore': 'm n : ℕ h : Nat.coprime m n ⊢ Nat.gcd m n = 1',\n", 188 | " 'stateAfter': 'm n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1',\n", 189 | " 'pos': 101,\n", 190 | " 'endPos': 122},\n", 191 | " {'stateBefore': 'm n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 
1',\n", 192 | " 'stateAfter': 'no goals',\n", 193 | " 'pos': 127,\n", 194 | " 'endPos': 134}]" 195 | ] 196 | }, 197 | "execution_count": 3, 198 | "metadata": {}, 199 | "output_type": "execute_result" 200 | } 201 | ], 202 | "source": [ 203 | "import json\n", 204 | "ast = json.load(open('../../partI_nextstep/ntp_lean/examples/example0.ast.json'))\n", 205 | "ast['tactics']" 206 | ] 207 | }, 208 | { 209 | "attachments": {}, 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "Notice that the proof states are the ones we saw above in VSCode.\n", 214 | "\n", 215 | "Here is the theorem statement's abstract syntax tree:" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 4, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "{'args': [{'node': {'args': [...],\n", 228 | " 'info': 'none',\n", 229 | " 'kind': 'Lean.Parser.Command.declModifiers'}},\n", 230 | " {'node': {'args': [...],\n", 231 | " 'info': 'none',\n", 232 | " 'kind': 'Lean.Parser.Command.theorem'}}],\n", 233 | " 'info': 'none',\n", 234 | " 'kind': 'Lean.Parser.Command.declaration'}\n" 235 | ] 236 | } 237 | ], 238 | "source": [ 239 | "import pprint\n", 240 | "pprint.pprint(ast['commandASTs'][1]['node'], depth=4)" 241 | ] 242 | }, 243 | { 244 | "attachments": {}, 245 | "cell_type": "markdown", 246 | "metadata": {}, 247 | "source": [ 248 | "#### Post-processing\n", 249 | "\n", 250 | "Next, we post-process the extracted data into a dataset:\n", 251 | "\n", 252 | "$\\quad f_{\\text{post-process}}(\\text{ASTs}, \\text{lean file})\\rightarrow \\{(x_t,y_t)\\}.$\n", 253 | "\n", 254 | "To do so, we use the collected proof states, traverse the AST, and recover the next-steps from the original Lean file.\\\n", 255 | "See `ntp_python.postprocess_ast` for an example (naive) traversal which extracts the theorem name.\n", 256 | "\n", 257 | "Postprocessing `example0.lean` in this way gives us two 
$(x_t,y_t)$ pairs:" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 5, 263 | "metadata": {}, 264 | "outputs": [ 265 | { 266 | "name": "stdout", 267 | "output_type": "stream", 268 | "text": [ 269 | "Theorem: theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 ...\n", 270 | "--- x1 ---\n", 271 | "m n : ℕ h : Nat.coprime m n ⊢ Nat.gcd m n = 1\n", 272 | "--- y1 ---\n", 273 | "rw [Nat.coprime] at h\n", 274 | "\n", 275 | "--- x2 ---\n", 276 | "m n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1\n", 277 | "--- y2 ---\n", 278 | "exact h\n", 279 | "\n" 280 | ] 281 | } 282 | ], 283 | "source": [ 284 | "import sys\n", 285 | "sys.path.append('../')\n", 286 | "from ntp_python.postprocess_ast import get_theorem\n", 287 | "from collections import defaultdict\n", 288 | "\n", 289 | "theorem2examples = defaultdict(list)\n", 290 | "\n", 291 | "lean_file = open('../../partI_nextstep/ntp_lean/examples/example0.lean').read()\n", 292 | "for item in ast['tactics']:\n", 293 | " theorem = get_theorem(item['pos'], ast)\n", 294 | " theorem2examples[theorem].append({\n", 295 | " 'x': item['stateBefore'],\n", 296 | " 'y': lean_file[item['pos']:item['endPos']],\n", 297 | " })\n", 298 | "\n", 299 | "for theorem, examples in theorem2examples.items():\n", 300 | " print(\"Theorem: \", theorem[:60], '...', sep=' ')\n", 301 | " for t, example in enumerate(examples):\n", 302 | " print(f\"--- x{t+1} ---\", example['x'], sep='\\n')\n", 303 | " print(f\"--- y{t+1} ---\", example['y'], sep='\\n')\n", 304 | " print()" 305 | ] 306 | }, 307 | { 308 | "attachments": {}, 309 | "cell_type": "markdown", 310 | "metadata": {}, 311 | "source": [ 312 | "The core extraction code in LeanDojo is in [ExtractData.lean](https://github.com/lean-dojo/LeanDojo/blob/main/src/lean_dojo/data_extraction/ExtractData.lean) if you are curious.\n", 313 | "\n", 314 | "## Scaling up data collection\n", 315 | "In general, Lean projects are more complex than the simple example above. 
For instance, projects may:\n", 316 | "1. have a large number of files\n", 317 | "2. have dependencies on other files or projects\n", 318 | "3. have complex file structure that our naive postprocessing doesn't handle\n", 319 | "\n", 320 | "An example is the [mathlib project](https://leanprover-community.github.io/mathlib-overview.html). Mathlib itself changes rapidly, and other Lean projects may depend on specific versions. [LeanDojo](https://leandojo.readthedocs.io/en/latest/index.html) gives tools for handling this complexity." 321 | ] 322 | }, 323 | { 324 | "attachments": {}, 325 | "cell_type": "markdown", 326 | "metadata": {}, 327 | "source": [ 328 | "#### Extracting 90k+ theorems with LeanDojo\n", 329 | "\n", 330 | "The LeanDojo tool allows for extracting data from an *arbitrary Lean Github repository*. Conceptually,\n", 331 | "\n", 332 | "$\\quad f_{\\text{leandojo}}(\\text{lean repository})\\rightarrow \\mathcal{D}.$\n", 333 | "\n", 334 | "It supports parallelism, keeps track of versions and dependencies for extracted data, and its post-processing handles more complex scenarios." 335 | ] 336 | }, 337 | { 338 | "attachments": {}, 339 | "cell_type": "markdown", 340 | "metadata": {}, 341 | "source": [ 342 | "**Example**\\\n", 343 | "Here is what the interface would look like for [extracting a dataset from Mathlib4](https://github.com/lean-dojo/LeanDojo/blob/main/scripts/generate-benchmark-lean4.ipynb):\n", 344 | "\n", 345 | "```python\n", 346 | "    URL = \"https://github.com/leanprover-community/mathlib4\"\n", 347 | "    COMMIT = \"5a919533f110b7d76410134a237ee374f24eaaad\"\n", 348 | "    repo = LeanGitRepo(URL, COMMIT)\n", 349 | "    traced_repo = trace(repo)\n", 350 | "```\n", 351 | "\n", 352 | "To avoid possible dependency issues, we won't run LeanDojo directly here. 
However, the LeanDojo authors provide the extracted data online, so we will download it for this tutorial:" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": 6, 358 | "metadata": {}, 359 | "outputs": [ 360 | { 361 | "name": "stdout", 362 | "output_type": "stream", 363 | "text": [ 364 | "Number of non-empty training proofs: 41944\n", 365 | "{'commit': '5a919533f110b7d76410134a237ee374f24eaaad',\n", 366 | " 'end': [308, 76],\n", 367 | " 'file_path': 'Mathlib/Analysis/BoxIntegral/Box/Basic.lean',\n", 368 | " 'full_name': 'BoxIntegral.Box.withBotCoe_inj',\n", 369 | " 'start': [307, 1],\n", 370 | " 'traced_tactics': [{'state_after': 'no goals',\n", 371 | " 'state_before': 'ι : Type u_1\\n'\n", 372 | " 'I✝ J✝ : Box ι\\n'\n", 373 | " 'x y : ι → ℝ\\n'\n", 374 | " 'I J : WithBot (Box ι)\\n'\n", 375 | " '⊢ ↑I = ↑J ↔ I = J',\n", 376 | " 'tactic': 'simp only [Subset.antisymm_iff, ← '\n", 377 | " 'le_antisymm_iff, withBotCoe_subset_iff]'}],\n", 378 | " 'url': 'https://github.com/leanprover-community/mathlib4'}\n" 379 | ] 380 | } 381 | ], 382 | "source": [ 383 | "import json\n", 384 | "import sys\n", 385 | "import pprint\n", 386 | "sys.path.append('../')\n", 387 | "from ntp_python.data import _download_and_unpack\n", 388 | "\n", 389 | "_download_and_unpack(\n", 390 | " tarball_url='https://zenodo.org/record/8040110/files/leandojo_benchmark_4_v1.tar.gz',\n", 391 | " data_dir='../data',\n", 392 | " overwrite=False\n", 393 | ")\n", 394 | "\n", 395 | "train = json.load(open('../data/leandojo_benchmark_4/random/train.json'))\n", 396 | "train = [x for x in train if len(x['traced_tactics']) > 0]\n", 397 | "print(\"Number of non-empty training proofs: \", len(train), sep=' ')\n", 398 | "pprint.pprint(train[0])" 399 | ] 400 | }, 401 | { 402 | "attachments": {}, 403 | "cell_type": "markdown", 404 | "metadata": {}, 405 | "source": [ 406 | "#### Next steps\n", 407 | "In part 2, we'll train a neural next-step generation model on this mathlib4 dataset." 
408 | ] 409 | } 410 | ], 411 | "metadata": { 412 | "kernelspec": { 413 | "display_name": "Python 3 (ipykernel)", 414 | "language": "python", 415 | "name": "python3" 416 | }, 417 | "language_info": { 418 | "codemirror_mode": { 419 | "name": "ipython", 420 | "version": 3 421 | }, 422 | "file_extension": ".py", 423 | "mimetype": "text/x-python", 424 | "name": "python", 425 | "nbconvert_exporter": "python", 426 | "pygments_lexer": "ipython3", 427 | "version": "3.10.11" 428 | } 429 | }, 430 | "nbformat": 4, 431 | "nbformat_minor": 4 432 | } 433 | -------------------------------------------------------------------------------- /partI_nextstep/notebooks/I_nextstep_lean__part2_learn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "#### Neural next-step prediction | part 2: learning\n", 9 | "Tutorial on neural theorem proving\\\n", 10 | "Author: Sean Welleck\n", 11 | "\n", 12 | "----------------" 13 | ] 14 | }, 15 | { 16 | "attachments": {}, 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "#### High-level goal\n", 21 | "\n", 22 | "Our goal is to train a neural next-step predictor $p_\\theta(y_t|x_t)$ on the dataset that we collected in the previous notebook.\n", 23 | "\n", 24 | "To do so, we will fine-tune a pretrained language model on the dataset $\\mathcal{D}=\\{(x_t,y_t)\\}$ using the standard supervised fine-tuning approach:\n", 25 | "\n", 26 | "$$\n", 27 | "\\max_\\theta \\sum_{(x_t,y_t)\\in \\mathcal{D}}-\\log p_\\theta(y_t|x_t).\n", 28 | "$$\n", 29 | "\n", 30 | "That is, we maximize the conditional likelihood of a next-step $y_t$ given the context $x_t$. \\\n", 31 | "This corresponds to minimizing a cross-entropy loss at each position of the next-step, $\\sum_{\\ell=1}^{{|y_t|}}-\\log p_\\theta(y_t^\\ell|y_t^{<\\ell})$." 
32 | ] 33 | }, 34 | { 35 | "attachments": {}, 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "### Implementation\n", 40 | "\n", 41 | "The implementation consists of two steps:\n", 42 | "\n", 43 | "1. **Data formatting** ([data.py](../ntp_python/data.py)): formatting the examples.\n", 44 | "2. **Tuning** ([tune.py](../ntp_python/tune.py)): using a standard language model fine-tuning script.\n", 45 | "\n" 46 | ] 47 | }, 48 | { 49 | "attachments": {}, 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "#### 1. Data formatting\n", 54 | "\n", 55 | "We format each (tactic-state, next-step) pair $(x_t, y_t)$ as:\n", 56 | "\n", 57 | " [GOAL]tacticstate[PROOFSTEP]next-step<|endoftext|>\n", 58 | "\n", 59 | "Here, `[GOAL]...[PROOFSTEP]` is the input and `next-step<|endoftext|>` is the output.\n", 60 | "\n", 61 | "This format comes from [Han et al ICLR 2022]: \\\n", 62 | "[Proof Artifact Co-training for Theorem Proving with Language Models](https://arxiv.org/pdf/2102.06203.pdf).\n", 63 | "\n", 64 | "\n", 65 | "\n", 66 | "" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 1, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "Saving split to disk...\n" 79 | ] 80 | }, 81 | { 82 | "name": "stderr", 83 | "output_type": "stream", 84 | "text": [ 85 | "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.87it/s]\n" 86 | ] 87 | }, 88 | { 89 | "name": "stdout", 90 | "output_type": "stream", 91 | "text": [ 92 | "train\t169530\n", 93 | "val\t4053\n", 94 | "test\t3606\n" 95 | ] 96 | } 97 | ], 98 | "source": [ 99 | "import sys\n", 100 | "sys.path.append('../ntp_python')\n", 101 | "import data\n", 102 | "\n", 103 | "datasets = data.proofstep(\n", 104 | " data_dir='../data'\n", 105 | ")" 106 | ] 107 | }, 108 | { 109 | "cell_type": 
"code", 110 | "execution_count": 2, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "Input:\n", 118 | "[GOAL]ι : Type u_1\n", 119 | "I✝ J✝ : Box ι\n", 120 | "x y : ι → ℝ\n", 121 | "I J : WithBot (Box ι)\n", 122 | "⊢ ↑I = ↑J ↔ I = J[PROOFSTEP]\n", 123 | "\n", 124 | "Output:\n", 125 | "simp only [Subset.antisymm_iff, ← le_antisymm_iff, withBotCoe_subset_iff]<|endoftext|>\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "example = datasets['train'][0]\n", 131 | "print(\"Input:\", example['input'], '', sep='\\n')\n", 132 | "print(\"Output:\", example['output'], sep='\\n')" 133 | ] 134 | }, 135 | { 136 | "attachments": {}, 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "#### 4. Tuning\n", 141 | "\n", 142 | "We minimally adapt a standard language-model fine-tuning script from [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py). \n", 143 | "\n", 144 | "You can check out the full script at [partI_nextstep/ntp_python/tune.py](../ntp_python/tune.py). \\\n", 145 | "See [partI_nextstep/scripts/tune_proofstep.sh](../scripts/tune_proofstep.sh) for a command that trains on 8 GPUs with deepspeed." 
146 | ] 147 | }, 148 | { 149 | "attachments": {}, 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "Here's an example command for training a 1.4b model on 1 GPU (and you can adjust the model size to be smaller to fit your compute constraints):" 154 | ] 155 | }, 156 | { 157 | "attachments": {}, 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "```bash\n", 162 | " REPO_DIR=\"..\"\n", 163 | " TRAIN_FILE=${REPO_DIR}/data/leandojo_benchmark_4/processed/proofstep-train.jsonl\n", 164 | " VALID_FILE=${REPO_DIR}/data/leandojo_benchmark_4/processed/proofstep-val.jsonl\n", 165 | " MODEL=EleutherAI/pythia-1.4b-deduped\n", 166 | "\n", 167 | " OUTDIR=${REPO_DIR}/model/${MODEL}\n", 168 | "\n", 169 | " python ../ntp_python/tune.py \\\n", 170 | " --model_name_or_path ${MODEL} \\\n", 171 | " --train_data_path ${TRAIN_FILE} \\\n", 172 | " --valid_data_path ${VALID_FILE} \\\n", 173 | " --fp16 \\\n", 174 | " --output_dir ${OUTDIR} \\\n", 175 | " --num_train_epochs 10 \\\n", 176 | " --per_device_train_batch_size 4 \\\n", 177 | " --per_device_eval_batch_size 4 \\\n", 178 | " --gradient_accumulation_steps 16 \\\n", 179 | " --evaluation_strategy \"steps\" \\\n", 180 | " --eval_steps 500 \\\n", 181 | " --save_strategy \"steps\" \\\n", 182 | " --save_steps 500 \\\n", 183 | " --save_total_limit 1 \\\n", 184 | " --learning_rate 1e-5 \\\n", 185 | " --load_best_model_at_end 1 \\\n", 186 | " --weight_decay 0. 
\\\n", 187 | " --warmup_ratio 0.03 \\\n", 188 | " --lr_scheduler_type \"cosine\" \\\n", 189 | " --logging_steps 10 \\\n", 190 | " --logging_dir \"$OUTDIR\" \\\n", 191 | " --report_to=\"tensorboard\"\n", 192 | "\n", 193 | "```" 194 | ] 195 | }, 196 | { 197 | "attachments": {}, 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "#### After training\n", 202 | "\n", 203 | "If everything went well, you should have a model in `../model/{MODEL_NAME}/checkpoint-{BEST_STEP}`.\n", 204 | "\n", 205 | "We have fine-tuned an `EleutherAI/pythia-2.8b-deduped` model that can be accessed through HuggingFace ([link](https://huggingface.co/wellecks/llmstep-mathlib4-pythia2.8b)):" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "import transformers\n", 215 | "\n", 216 | "MODEL = 'wellecks/llmstep-mathlib4-pythia2.8b'\n", 217 | "model = transformers.GPTNeoXForCausalLM.from_pretrained(MODEL)\n", 218 | "tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(MODEL)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "You can use your own model by setting `MODEL = \"../model/{MODEL_NAME}/checkpoint-{BEST_STEP}\"` \\\n", 226 | "(e.g., `../model/EleutherAI/pythia-2.8b-deduped/checkpoint-5000`)." 
227 | ] 228 | }, 229 | { 230 | "attachments": {}, 231 | "cell_type": "markdown", 232 | "metadata": {}, 233 | "source": [ 234 | "Let's generate a next-step suggestion for the proof state from our original example:\n", 235 | "\n", 236 | "```lean\n", 237 | "    theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1\n", 238 | "```\n", 239 | "Recall from the previous notebook that the initial proof state $x_0$ is:\n", 240 | "\n", 241 | "    m n : ℕ\n", 242 | "    h : Nat.coprime m n\n", 243 | "    ⊢ Nat.gcd m n = 1" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 4, 249 | "metadata": {}, 250 | "outputs": [ 251 | { 252 | "name": "stdout", 253 | "output_type": "stream", 254 | "text": [ 255 | "rw [← h.gcd_eq_one]\n" 256 | ] 257 | } 258 | ], 259 | "source": [ 260 | "prompt = \"\"\"[GOAL]m n : ℕ\n", 261 | " h : Nat.coprime m n\n", 262 | " ⊢ Nat.gcd m n = 1[PROOFSTEP]\"\"\"\n", 263 | "\n", 264 | "input_ids = tokenizer.encode(prompt, return_tensors='pt')\n", 265 | "out = model.generate(\n", 266 | "    input_ids,\n", 267 | "    max_new_tokens=256,\n", 268 | "    pad_token_id=tokenizer.eos_token_id\n", 269 | ")\n", 270 | "text = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)\n", 271 | "print(text)" 272 | ] 273 | }, 274 | { 275 | "attachments": {}, 276 | "cell_type": "markdown", 277 | "metadata": {}, 278 | "source": [ 279 | "### Next steps\n", 280 | "\n", 281 | "In the next notebook, we will prove theorems with the trained model by interacting with the Lean proof assistant.\n", 282 | "\n", 283 | "This will let us automatically check whether a generated proof (e.g., one containing the step above) is correct.\n", 284 | "\n", 285 | "Later on, we will build a VSCode plugin that returns next-step suggestions from the language model." 
286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [] 292 | } 293 | ], 294 | "metadata": { 295 | "kernelspec": { 296 | "display_name": "Python 3 (ipykernel)", 297 | "language": "python", 298 | "name": "python3" 299 | }, 300 | "language_info": { 301 | "codemirror_mode": { 302 | "name": "ipython", 303 | "version": 3 304 | }, 305 | "file_extension": ".py", 306 | "mimetype": "text/x-python", 307 | "name": "python", 308 | "nbconvert_exporter": "python", 309 | "pygments_lexer": "ipython3", 310 | "version": "3.10.11" 311 | } 312 | }, 313 | "nbformat": 4, 314 | "nbformat_minor": 4 315 | } 316 | -------------------------------------------------------------------------------- /partI_nextstep/notebooks/I_nextstep_lean__part3_proofsearch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "#### Neural next-step prediction | part 3: proof search\n", 8 | "Tutorial on neural theorem proving\\\n", 9 | "Author: Sean Welleck\n", 10 | "\n", 11 | "----------------\n", 12 | "\n", 13 | "#### High-level goal\n", 14 | "\n", 15 | "Our next goal is to prove theorems with our neural next-step predictor, and check whether the theorems are correct.\n", 16 | "\n", 17 | "Proving and checking a theorem involves generating a next-step candidate with our model, giving it to Lean, and receiving a next state from Lean (or an error message). \\\n", 18 | "To do so, we will need two components:\n", 19 | "\n", 20 | "1. **Interacting** with Lean: an automated way to give a next-step to Lean and receive a next state (or an error).\n", 21 | "\n", 22 | "2. A **search strategy** that uses the next-step model and Lean to find a proof (e.g. generate one next-step, get the next state, repeat).\n", 23 | "\n", 24 | "\n", 25 | "Below, we'll walk through a simple example of each. 
" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "-------------------\n", 33 | "\n", 34 | "### 1. Interaction\n", 35 | "\n", 36 | "To start, we'll walk through proving this theorem:" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "```lean4\n", 44 | "import Mathlib.Data.Nat.Prime\n", 45 | "\n", 46 | "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by \n", 47 | " rw [Nat.coprime] at h \n", 48 | " exact h \n", 49 | "```" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "#### Interaction with `pylean`\n", 57 | "\n", 58 | "The [`pylean`](https://github.com/zhangir-azerbayev/repl/tree/master) library gives us a lightweight interface to a lean REPL.\n", 59 | "\n", 60 | "We can pass `pylean` the import and theorem statement:" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 1, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "{'env': 0,\n", 73 | " 'messages': [{'data': 'unsolved goals\\n'\n", 74 | " 'm n : ℕ\\n'\n", 75 | " 'h : Nat.coprime m n\\n'\n", 76 | " '⊢ Nat.gcd m n = 1',\n", 77 | " 'endPos': {'column': 69, 'line': 4},\n", 78 | " 'pos': {'column': 68, 'line': 4},\n", 79 | " 'severity': 'error'}],\n", 80 | " 'sorries': []}\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "from pylean import LeanServer\n", 86 | "from pprint import pprint\n", 87 | "\n", 88 | "code = \"\"\"\n", 89 | "import Mathlib.Data.Nat.Prime\n", 90 | "\n", 91 | "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by {}\n", 92 | "\"\"\"\n", 93 | "\n", 94 | "lean = LeanServer()\n", 95 | "state = lean.run_code(code)\n", 96 | "lean.proc.close()\n", 97 | "pprint(state)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "We see that inside of `'data'`, `pylean` gives us the current proof state 
$x_t$; here's basic parsing code:" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 2, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "m n : ℕ\n", 117 | "h : Nat.coprime m n\n", 118 | "⊢ Nat.gcd m n = 1\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "def get_goal(state):\n", 124 | " goal = None\n", 125 | " for msg in state['messages']:\n", 126 | " if msg['data'].startswith('unsolved goals\\n'):\n", 127 | " goal = '\\n'.join(msg['data'].split('\\n')[1:])\n", 128 | " elif msg['severity'] == 'error':\n", 129 | " return None\n", 130 | " return goal\n", 131 | "\n", 132 | "print(get_goal(state))" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "We can use $x_t$ as input to our model $p_\\theta(y_t|x_t)$.\\\n", 140 | "Next, we load the trained model and generate a next step, $\\hat y_t\\sim q(p_\\theta(y_t|x_t))$.\n", 141 | "\n", 142 | "(Here $q(\\cdot)$ is a decoding algorithm such as greedy decoding or temperature sampling.)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "# Load model and tokenizer\n", 152 | "import os\n", 153 | "import transformers\n", 154 | "model_name = 'wellecks/llmstep-mathlib4-pythia2.8b'\n", 155 | "model = transformers.GPTNeoXForCausalLM.from_pretrained(model_name)\n", 156 | "tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(model_name)\n", 157 | "os.environ['TOKENIZERS_PARALLELISM'] = 'false' # prevents an annoying warning\n", 158 | "\n", 159 | "\n", 160 | "def generate(prompt):\n", 161 | " input_ids = tokenizer.encode(prompt, return_tensors='pt')\n", 162 | " out = model.generate(\n", 163 | " input_ids,\n", 164 | " max_new_tokens=256,\n", 165 | " pad_token_id=tokenizer.eos_token_id\n", 166 | " )\n", 167 | " text = tokenizer.decode(out[0][input_ids.shape[1]:], 
skip_special_tokens=True)\n", 168 | " return text" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 4, 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "name": "stdout", 178 | "output_type": "stream", 179 | "text": [ 180 | "rw [← h.gcd_eq_one]\n" 181 | ] 182 | } 183 | ], 184 | "source": [ 185 | "# Generate a next step\n", 186 | "prompt = f\"[GOAL]{get_goal(state)}[PROOFSTEP]\"\n", 187 | "\n", 188 | "next_step = generate(prompt)\n", 189 | "print(next_step)" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "Finally, we can give the generated next step to Lean and receive the next state." 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 5, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "{'env': 0, 'messages': [], 'sorries': []}\n" 209 | ] 210 | } 211 | ], 212 | "source": [ 213 | "code = \"\"\"\n", 214 | "import Mathlib.Data.Nat.Prime\n", 215 | "\n", 216 | "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by \n", 217 | "\n", 218 | "\"\"\" + next_step\n", 219 | "\n", 220 | "lean = LeanServer()\n", 221 | "state = lean.run_code(code)\n", 222 | "lean.proc.close()\n", 223 | "\n", 224 | "pprint(state)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [ 231 | "There are no error messages, and no remaining goals - the proof is complete! If you want, paste this into VS Code to convince yourself that it's complete:\n", 232 | "\n", 233 | "```lean4\n", 234 | "import Mathlib.Data.Nat.Prime\n", 235 | "\n", 236 | "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by\n", 237 | " rw [← h.gcd_eq_one]\n", 238 | "```\n", 239 | "\n", 240 | "Also, notice that the machine-generated proof is different from the human written one shown at the starting of this section." 
241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": {}, 246 | "source": [ 247 | "-----------------\n", 248 | "\n", 249 | "### 2. Search strategy\n", 250 | "\n", 251 | "In the proof above, we simply generated one next step and the proof was complete.\n", 252 | "\n", 253 | "In general, proofs are multiple steps. Therefore we need an algorithm for generating a multiple step proof, which we refer to as a *search algorithm*.\n", 254 | "\n", 255 | "\n", 256 | "First, let's consider a naive algorithm that generates a next step, then continues to the next state. Upon receiving an error message\n", 257 | "the algorithm generates another next step." 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 6, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "import sys\n", 267 | "sys.path.append('../ntp_python/')\n", 268 | "\n", 269 | "import proofsearch_pylean as proofsearch # some utilities for running code (as we did above) and parsing states/model outputs" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 7, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "name": "stdout", 279 | "output_type": "stream", 280 | "text": [ 281 | "== Current (0): \n", 282 | "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n", 283 | "\n", 284 | "-- Goal: \n", 285 | "a b c : ℕ\n", 286 | "⊢ a + b = c → a ≤ c\n", 287 | "\n", 288 | "== Current (1): \n", 289 | "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n", 290 | "rintro rfl\n", 291 | "-- Goal: \n", 292 | "a b : ℕ\n", 293 | "⊢ a ≤ a + b\n", 294 | "\n", 295 | "== Current (2): \n", 296 | "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n", 297 | "rintro rfl\n", 298 | "exact le_add_left _ _\n", 299 | "-- Error: backtracking\n", 300 | "-- Goal: \n", 301 | "a b : ℕ\n", 302 | "⊢ a ≤ a + b\n", 303 | "\n", 304 | "== Current (3): \n", 305 | "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n", 306 | "rintro rfl\n", 307 | "apply 
Nat.le_add_right sperr a\n", 308 | "-- Error: backtracking\n", 309 | "-- Goal: \n", 310 | "a b : ℕ\n", 311 | "⊢ a ≤ a + b\n", 312 | "\n", 313 | "== Current (4): \n", 314 | "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n", 315 | "rintro rfl\n", 316 | "apply Nat.le_add_right\n", 317 | "\n", 318 | "SUCCESS!\n", 319 | "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n", 320 | " rintro rfl\n", 321 | " apply Nat.le_add_right\n" 322 | ] 323 | } 324 | ], 325 | "source": [ 326 | "transformers.set_seed(43)\n", 327 | "\n", 328 | "def prove_simple(model, tokenizer, header, theorem_statement, search_budget):\n", 329 | " success = False\n", 330 | "\n", 331 | " code = header + theorem_statement\n", 332 | " steps = []\n", 333 | " proof = ''\n", 334 | "\n", 335 | " for i in range(search_budget):\n", 336 | " print(\"== Current (%d): \" % i, theorem_statement[:-3] + '\\n' + proof, sep='\\n')\n", 337 | "\n", 338 | " # Run the code (header + proof-so-far)\n", 339 | " state = proofsearch.run_code(code)\n", 340 | " \n", 341 | " # Stop if the proof is complete.\n", 342 | " if proofsearch.is_done(state):\n", 343 | " success = True\n", 344 | " break\n", 345 | "\n", 346 | " # Get the new state.\n", 347 | " goal_candidate = proofsearch.get_goal(state)\n", 348 | " if goal_candidate is None:\n", 349 | " print(\"-- Error: backtracking\")\n", 350 | " steps = steps[:-1]\n", 351 | " else:\n", 352 | " goal = goal_candidate\n", 353 | "\n", 354 | " print(\"-- Goal: \", goal, sep='\\n')\n", 355 | "\n", 356 | " # Generate a next-step\n", 357 | " prompt = f\"[GOAL]{goal}[PROOFSTEP]\"\n", 358 | " texts, _= proofsearch.generate(prompt, model, tokenizer, temperatures=[0.5], num_samples=1)\n", 359 | " step = proofsearch.parse_step(texts[0])\n", 360 | "\n", 361 | " # Add the next-step to the proof-so-far\n", 362 | " steps.append(step)\n", 363 | " proof = '\\n'.join(steps)\n", 364 | " code = header + theorem_statement.replace(\" {}\", \"\") + '\\n' + proof\n", 365 | " print()\n", 366 | "\n", 367 | 
" if success:\n", 368 | " print(\"\\nSUCCESS!\")\n", 369 | " else:\n", 370 | " print(\"\\nFAILED\")\n", 371 | " \n", 372 | " print(theorem_statement.replace(\" {}\", \"\"))\n", 373 | " print (' ' + proof.replace('\\n', '\\n '))\n", 374 | " \n", 375 | " return {'theorem_statement': theorem_statement, 'proof': proof, 'success': success}\n", 376 | "\n", 377 | "\n", 378 | "header = \"\"\"\n", 379 | "import Mathlib.Data.Nat.Prime\n", 380 | "\n", 381 | "\"\"\"\n", 382 | "theorem_statement = \"\"\"theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}\"\"\"\n", 383 | "\n", 384 | "\n", 385 | "out = prove_simple(\n", 386 | " model, \n", 387 | " tokenizer,\n", 388 | " header, \n", 389 | " theorem_statement, \n", 390 | " search_budget=100\n", 391 | ")" 392 | ] 393 | }, 394 | { 395 | "cell_type": "markdown", 396 | "metadata": {}, 397 | "source": [ 398 | "Above (setting `seed = 43` for reproducibility) the model generates `rintro rfl`. \\\n", 399 | "Next it generates `exact le_add_left _ _`, which receives an error, so the model tries again (backtracks). \\\n", 400 | "After backtracking one more time, the model generates `apply Nat.le_add_right` and the proof is complete." 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "metadata": {}, 406 | "source": [ 407 | "### Best-first search\n", 408 | "\n", 409 | "Typically a less naive search procedure is used. These searches are usually variants of a tree search, in which nodes are states and edges are next-steps. \n", 410 | "\n", 411 | "The most common search in neural theorem proving is *best-first search*. This search:\n", 412 | "\n", 413 | "- generates multiple next-step suggestions to form (proof-so-far + next-step) *candidates*\n", 414 | "- scores all candidates so far\n", 415 | "- selects the highest scoring candidate\n", 416 | "\n", 417 | "A typical scoring function is the model's log probability, $\\log p_\\theta(y_t|x_t)$, summed across steps. 
Next-steps that lead to an error receive a score of $-\\infty$ (in practice, we discard these steps). In the literature, the scoring function is called a *value function* $v(y_{\\leq t}, x_t)$.\n", 418 | "\n", 419 | "#### Intuition\n", 420 | "\n", 421 | "A key idea is generating multiple suggestions at each step, ${y_t^{(1)},\\ldots,y_t^{(k)}}\\sim p_\\theta(\\cdot|x_t)$. Intuitively, the goal is to select a next-step that will lead to a correct proof. In general, we do not know whether a next-step will lead to a correct proof, so we use a heuristic value function for selecting a next-step.\n", 422 | "\n", 423 | "Here's what multiple suggestions and their (normalized) log-probabilities look like in our example:" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": 8, 429 | "metadata": {}, 430 | "outputs": [ 431 | { 432 | "name": "stdout", 433 | "output_type": "stream", 434 | "text": [ 435 | "-0.277\trw [Nat.coprime, gcd_comm] at h\n", 436 | "-0.279\trw [← h.gcd_eq_one]\n", 437 | "-0.335\tapply Nat.eq_one_of_dvd_dvd\n", 438 | "-0.349\trw [Nat.coprime] at h\n", 439 | "-0.350\trw [gcd_comm]\n" 440 | ] 441 | } 442 | ], 443 | "source": [ 444 | "prompt = '[GOAL]m n : ℕ\\nh : Nat.coprime m n\\n⊢ Nat.gcd m n = 1[PROOFSTEP]'\n", 445 | "texts, scores = proofsearch.generate(prompt, model, tokenizer, temperatures=[0.0], num_samples=5)\n", 446 | "for text, score in zip(texts, scores):\n", 447 | " print('%.3f' % score, text, sep='\\t')" 448 | ] 449 | }, 450 | { 451 | "cell_type": "markdown", 452 | "metadata": {}, 453 | "source": [ 454 | "### Implementation\n", 455 | "\n", 456 | "A minimal implementation of best first search is available in `proofsearch_pylean.py`.\\\n", 457 | "A version that uses LeanDojo for interaction is in `proofsearch_dojo.py`.\n", 458 | "\n", 459 | "We will use these in the next notebook to evaluate our model on a set of evaluation theorems.\\\n", 460 | "Below, we run best first search and print out the search trajectory:" 461 | ] 
462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 9, 466 | "metadata": {}, 467 | "outputs": [ 468 | { 469 | "name": "stdout", 470 | "output_type": "stream", 471 | "text": [ 472 | "--- current:\n", 473 | "\ttheorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by \n", 474 | "\t\n" 475 | ] 476 | }, 477 | { 478 | "name": "stderr", 479 | "output_type": "stream", 480 | "text": [ 481 | "100%|██████████| 4/4 [00:03<00:00, 1.10it/s]\n" 482 | ] 483 | }, 484 | { 485 | "name": "stdout", 486 | "output_type": "stream", 487 | "text": [ 488 | "--- type-checked candidates:\n", 489 | "\t(-0.066) rintro rfl\n", 490 | "\t(-0.307) rintro ⟨rfl, rfl⟩\n", 491 | "\t(-0.035) intro h\n", 492 | "\t(-0.230) rintro ⟨d, rfl⟩\n", 493 | "--- current:\n", 494 | "\ttheorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by \n", 495 | "\tintro h\n" 496 | ] 497 | }, 498 | { 499 | "name": "stderr", 500 | "output_type": "stream", 501 | "text": [ 502 | "100%|██████████| 4/4 [00:03<00:00, 1.11it/s]\n" 503 | ] 504 | }, 505 | { 506 | "name": "stdout", 507 | "output_type": "stream", 508 | "text": [ 509 | "--- type-checked candidates:\n", 510 | "\t(-0.172) apply le_of_add_le_add_right\n", 511 | "\t(-0.093) rw [← h]\n", 512 | "\t(-0.453) cases c\n", 513 | "--- current:\n", 514 | "\ttheorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by \n", 515 | "\trintro rfl\n" 516 | ] 517 | }, 518 | { 519 | "name": "stderr", 520 | "output_type": "stream", 521 | "text": [ 522 | "100%|██████████| 4/4 [00:03<00:00, 1.10it/s]" 523 | ] 524 | }, 525 | { 526 | "name": "stdout", 527 | "output_type": "stream", 528 | "text": [ 529 | "--- type-checked candidates:\n", 530 | "\t(-0.109) apply Nat.le_add_right\n", 531 | "\t(-0.173) exact Nat.le_add_right _ _\n" 532 | ] 533 | }, 534 | { 535 | "name": "stderr", 536 | "output_type": "stream", 537 | "text": [ 538 | "\n" 539 | ] 540 | }, 541 | { 542 | "data": { 543 | "text/plain": [ 544 | "{'theorem_statement': 'theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}',\n", 545 | " 
'proof': ['rintro rfl', 'apply Nat.le_add_right'],\n", 546 | " 'state': {'sorries': [], 'messages': [], 'env': 0},\n", 547 | " 'score': 0.1747819110751152,\n", 548 | " 'success': True}" 549 | ] 550 | }, 551 | "execution_count": 9, 552 | "metadata": {}, 553 | "output_type": "execute_result" 554 | } 555 | ], 556 | "source": [ 557 | "proofsearch.best_first_search(\n", 558 | " model, tokenizer, header, theorem_statement, \n", 559 | " max_iters=32,\n", 560 | " num_samples=4,\n", 561 | " temperatures=[0.0],\n", 562 | " verbose=True\n", 563 | ")" 564 | ] 565 | }, 566 | { 567 | "cell_type": "markdown", 568 | "metadata": {}, 569 | "source": [ 570 | "The search selects a candidate trajectory, and generates 4 next-step suggestions.\\\n", 571 | "`intro h` is selected at the first step. The best expansion of `intro h` has score -0.093. \\\n", 572 | "This is less than the score of `rintro rfl` (-0.066), so `rintro rfl` is picked. This is backtracking, since `intro h` is no longer in the proof.\\\n", 573 | "Then `apply Nat.le_add_right` is suggested and the proof is complete.\n" 574 | ] 575 | }, 576 | { 577 | "cell_type": "markdown", 578 | "metadata": {}, 579 | "source": [ 580 | "--------------------\n", 581 | "\n", 582 | "\n", 583 | "## Extensions\n", 584 | "\n", 585 | "Several works have proposed to improve the search strategy, either with a learned value function or a sophisticated search:\n", 586 | "\n", 587 | "- [Polu & Sutskever 2020](https://arxiv.org/pdf/2009.03393.pdf) propose to learn a value function $v(y_{\\leq t}, x_t)$ that estimates the probability of successfully proving the theorem with the model $p_\\theta$ starting at state $x_t$. To do so, they use proof search trajectories obtained by doing proof search with the model.\n", 588 | "\n", 589 | "- [Polu et al ICLR 2023](https://openreview.net/pdf?id=-P7G-8dmSh4) train the value function to predict the eventual length of the proof (or 0 if it is predicted to fail). 
The learned value function improves pass rate by ~10\\% on mathlib theorems compared to log-probability, with a ~1\\% improvement over the learned value function from [Polu & Sutskever 2020].\n", 590 | "\n", 591 | "- [Lample et al NeurIPS 2022](https://openreview.net/pdf?id=J4pX8Q8cxHH) propose a sophisticated MCTS-like search that explores multiple trajectories in parallel, collecting statistics on visited states in order to prioritize search trajectories.\n", 592 | "\n", 593 | "Reproducing, analyzing, and improving the search algorithm remains an open area for future work in neural theorem proving (for instance, these works were not open-sourced).\n", 594 | "\n", 595 | "Search algorithms are also an active area of research in LLMs, including methods like [tree-of-thought](https://arxiv.org/abs/2305.10601), [stepwise beam search](https://arxiv.org/pdf/2205.12910.pdf), [self-consistency](https://arxiv.org/pdf/2203.11171.pdf), and search with [learned stepwise verifiers](https://arxiv.org/pdf/2305.20050.pdf). In theorem proving, the final output is verifiable, but the quality of intermediate steps is difficult to evaluate." 
596 | ] 597 | } 598 | ], 599 | "metadata": { 600 | "kernelspec": { 601 | "display_name": "Python 3 (ipykernel)", 602 | "language": "python", 603 | "name": "python3" 604 | }, 605 | "language_info": { 606 | "codemirror_mode": { 607 | "name": "ipython", 608 | "version": 3 609 | }, 610 | "file_extension": ".py", 611 | "mimetype": "text/x-python", 612 | "name": "python", 613 | "nbconvert_exporter": "python", 614 | "pygments_lexer": "ipython3", 615 | "version": "3.10.11" 616 | } 617 | }, 618 | "nbformat": 4, 619 | "nbformat_minor": 4 620 | } 621 | -------------------------------------------------------------------------------- /partI_nextstep/notebooks/I_nextstep_lean__part5_llmsuggest.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "#### Neural next-step prediction | part 5: `llmsuggest` co-pilot \n", 9 | "Tutorial on neural theorem proving\\\n", 10 | "Author: Sean Welleck\n", 11 | "\n", 12 | "----------------" 13 | ] 14 | }, 15 | { 16 | "attachments": {}, 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "#### High-level goal\n", 21 | "\n", 22 | "Finally, we will see how our neural model can act as a helpful \"co-pilot\" when we are writing proofs. \\\n", 23 | "We will make an interactive tool that uses our neural next-step model to suggest next-steps in VSCode.\n", 24 | "\n", 25 | "Concretely, we'll create a `llmsuggest` tactic (in essence, a function) that displays generated suggestions in VSCode. 
`llmsuggest` is a minimal version of the [**llmstep** [Welleck & Saha 2023]](https://github.com/wellecks/llmstep) tactic, intended for learning from and building on.
To learn more, see the [Lean 4 metaprogramming book](https://github.com/leanprover-community/lean4-metaprogramming-book/tree/master).\n", 78 | "\n", 79 | "`llmsuggest` is implemented in `ntp_lean/LLMsuggest.lean`. The main definition specifies the syntax (i.e., `\"llmsuggest\"`), then defines the tactic. \\\n", 80 | "You can see below that the tactic gets the main goal (the \"tactic state\"), pretty-prints it, and converts it to a string.\n", 81 | "Then it runs a `runSuggest` function, and passes the output to an `addSuggestions` function:\n", 82 | "\n", 83 | "```lean\n", 84 | "-- `llmsuggest` tactic.\n", 85 | "syntax \"llmsuggest\" : tactic\n", 86 | "elab_rules : tactic\n", 87 | " | `(tactic | llmsuggest%$tac) =>\n", 88 | " Lean.Elab.Tactic.withMainContext do\n", 89 | " let goal ← Lean.Elab.Tactic.getMainGoal\n", 90 | " let ppgoal ← Lean.Meta.ppGoal goal\n", 91 | " let ppgoalstr := toString ppgoal\n", 92 | " let suggest ← runSuggest #[ppgoalstr]\n", 93 | " addSuggestions tac $ suggest.splitOn \"[SUGGESTION]\"\n", 94 | "```\n", 95 | "\n", 96 | "The `runSuggest` function calls a Python script (step 2 above), and the `addSuggestions` uses a Lean widget to display the results in VSCode. \\\n", 97 | "We won't look at these in detail, but please see `ntp_lean/LLMsuggest.lean` if you are curious. \\\n", 98 | "Hopefully with a small amount of effort, you can modify the tactic or make your own in the future.\n", 99 | "\n", 100 | "\n", 101 | "#### 2. Python script\n", 102 | "\n", 103 | "The `runSuggest` function in the tactic calls a Python script, `ntp_python/llmsuggest/suggest.py`. 
It passes the current tactic state as a command line argument.\\\n", 104 | "The script is simple: it sends a POST request containing the current tactic state to a server:\n", 105 | "\n", 106 | "```python\n", 107 | "def suggest(tactic_state):\n", 108 | " conn = http.client.HTTPConnection(\"localhost\", 5000)\n", 109 | " headers = {'Content-type': 'application/json'}\n", 110 | " body = json.dumps({\"tactic_state\": sys.argv[1]})\n", 111 | " conn.request(\"POST\", \"/\", body, headers)\n", 112 | " response = conn.getresponse()\n", 113 | " data = response.read()\n", 114 | " data_dict = json.loads(data)\n", 115 | " print('[SUGGESTION]'.join(data_dict['suggestions']))\n", 116 | " conn.close()\n", 117 | "\n", 118 | "if __name__ == \"__main__\":\n", 119 | " suggest(sys.argv[1])\n", 120 | "```\n", 121 | "\n", 122 | "After receiving suggestions, it prints the suggestions, and the printed suggestions will be received in the `runSuggest` function.\n", 123 | "\n", 124 | "#### 3. Server\n", 125 | "Finally, in `ntp_python/llmsuggest/server.py` we define a web server that handles the POST request, and hosts the language model.\n", 126 | "Specifically, the server initializes our language model, and uses the model to\n", 127 | "generate suggestions given a tactic state received in a POST request.\n", 128 | "```python\n", 129 | "model = transformers.GPTNeoXForCausalLM.from_pretrained('wellecks/llmstep-mathlib4-pythia2.8b')\n", 130 | "\n", 131 | "def generate ...\n", 132 | "\n", 133 | "app = Flask(__name__)\n", 134 | "\n", 135 | "@app.route('/', methods=['POST'])\n", 136 | "def process_request():\n", 137 | " data = request.get_json()\n", 138 | " tactic_state = data.get('tactic_state')\n", 139 | " prompt = \"\"\"[GOAL]%s[PROOFSTEP]\"\"\" % (tactic_state)\n", 140 | " texts = generate(prompt, args.num_samples)\n", 141 | " response = {\"suggestions\": texts}\n", 142 | " return jsonify(response)\n", 143 | "\n", 144 | "if __name__ == '__main__':\n", 145 | " app.run(host='0.0.0.0', 
port=args.port)\n", 146 | "```\n", 147 | "\n", 148 | "This server is minimal; one can imagine adding several features." 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": {}, 154 | "source": [ 155 | "### Running `llmsuggest`\n", 156 | "\n", 157 | "To run `llmsuggest`, first start the server:\n", 158 | "```bash\n", 159 | "python python/server.py\n", 160 | "```\n", 161 | "\n", 162 | "Then open `ntp_lean/LLMsuggest.lean` in VS Code and try out `llmsuggest`. There are some example theorems and proofs at the bottom of the page:\n", 163 | "\n", 164 | "\"\"" 165 | ] 166 | }, 167 | { 168 | "attachments": {}, 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "-----------------------\n", 173 | "\n", 174 | "### `llmstep`: [L]LM proofstep suggestions in Lean\n", 175 | "\n", 176 | "[`llmstep`](https://github.com/wellecks/llmstep) is an expanded version of the `llm_suggest` tactic: https://github.com/wellecks/llmstep\n", 177 | "\n", 178 | "`llmstep` includes features such as:\n", 179 | "1. **Type checking**: suggestions are checked by Lean and marked as completing a proof, valid, or invalid (but still possibly useful).\n", 180 | "2. **Prefixed generation**: e.g. `llmstep \"exact\"` returns suggestions that start with `\"exact\"`\n", 181 | "3. **Fast inference**: fast inference via [PagedAttention](https://vllm.ai/) for near real-time suggestions\n", 182 | "4. **Other models**: support for other models, e.g. `llmstep-llama2`\n", 183 | "\n", 184 | "Here's an example of using `llmstep`:\n", 185 | "\n", 186 | "\"\"\n" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "The first invocation (`llmstep \"\"`) gives 5 suggestions, with `intro h n` and `intro h` outlined in blue since they type check.\n", 194 | "\n", 195 | "The second invocation (`llmstep \"exact\"`) gives suggestions that start with `exact`. The first three are outlined in green since they complete the proof." 
196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "----------\n", 203 | "\n", 204 | "## Next steps\n", 205 | "\n", 206 | "This concludes part 1 of the tutorial. We have seen how to build a neural next-step suggestion tool from scratch: collecting data, learning a model, measuring performance with proof search and evaluation sets, and deploying the model as an interactive tactic.\n", 207 | "\n", 208 | "In part 2, we will look at a generalization called language cascades, in which a language model implements a \"function\" that does more than predict the next step. We will see example cascades for drafting informal proofs, sketching the high-level structure of a proof, and refining proofs." 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [] 215 | } 216 | ], 217 | "metadata": { 218 | "kernelspec": { 219 | "display_name": "Python 3 (ipykernel)", 220 | "language": "python", 221 | "name": "python3" 222 | }, 223 | "language_info": { 224 | "codemirror_mode": { 225 | "name": "ipython", 226 | "version": 3 227 | }, 228 | "file_extension": ".py", 229 | "mimetype": "text/x-python", 230 | "name": "python", 231 | "nbconvert_exporter": "python", 232 | "pygments_lexer": "ipython3", 233 | "version": "3.10.11" 234 | } 235 | }, 236 | "nbformat": 4, 237 | "nbformat_minor": 4 238 | } 239 | -------------------------------------------------------------------------------- /partI_nextstep/notebooks/data/successes_mathlib4_200_wellecks_llmstep-mathlib4-pythia2.8b.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": [ 3 | { 4 | "theorem": "NonemptyInterval.pure_injective", 5 | "proof": [ 6 | "unfold pure", 7 | "simp [Injective]" 8 | ], 9 | "score": 0.06666778214275837, 10 | "success": true 11 | }, 12 | { 13 | "theorem": "Finset.sym2_univ", 14 | "proof": [ 15 | "rfl" 16 | ], 17 | "score": 0.057796161621809006, 18 | "success": true 19 | }, 20 | 
{ 21 | "theorem": "Subsemigroup.monotone_map", 22 | "proof": [ 23 | "intro m n h", 24 | "rintro _ \u27e8x, hx, rfl\u27e9", 25 | "exact \u27e8x, h hx, rfl\u27e9" 26 | ], 27 | "score": 0.10295426659286022, 28 | "success": true 29 | }, 30 | { 31 | "theorem": "MeasureTheory.Measurable.comp_nullMeasurable", 32 | "proof": [ 33 | "exact hg.comp hf" 34 | ], 35 | "score": 0.03515021502971649, 36 | "success": true 37 | }, 38 | { 39 | "theorem": "Finset.prod_pow_boole", 40 | "proof": [ 41 | "split_ifs with h <;> simp [h]" 42 | ], 43 | "score": 0.022501567378640175, 44 | "success": true 45 | }, 46 | { 47 | "theorem": "isOpenMap_div_right", 48 | "proof": [ 49 | "simp only [div_eq_mul_inv, isOpenMap_mul_right]" 50 | ], 51 | "score": 0.06886392831802368, 52 | "success": true 53 | }, 54 | { 55 | "theorem": "Fin.Ici_eq_finset_subtype", 56 | "proof": [ 57 | "ext", 58 | "simp" 59 | ], 60 | "score": 0.04452210105955601, 61 | "success": true 62 | }, 63 | { 64 | "theorem": "Polynomial.toFinsupp_zero", 65 | "proof": [ 66 | "cases p", 67 | "rfl" 68 | ], 69 | "score": 0.0956510566174984, 70 | "success": true 71 | }, 72 | { 73 | "theorem": "Continuous.sum_map", 74 | "proof": [ 75 | "continuity" 76 | ], 77 | "score": 0.04079575836658478, 78 | "success": true 79 | }, 80 | { 81 | "theorem": "Set.smul_set_inter\u2080", 82 | "proof": [ 83 | "ext", 84 | "simp [mem_smul_set_iff_inv_smul_mem\u2080 ha]" 85 | ], 86 | "score": 0.07694211788475513, 87 | "success": true 88 | }, 89 | { 90 | "theorem": "Subtype.coe_image_univ", 91 | "proof": [ 92 | "simp [image_univ]" 93 | ], 94 | "score": 0.08483579009771347, 95 | "success": true 96 | }, 97 | { 98 | "theorem": "AlgEquiv.toLinearEquiv_trans", 99 | "proof": [ 100 | "ext <;> simp" 101 | ], 102 | "score": 0.02171783335506916, 103 | "success": true 104 | }, 105 | { 106 | "theorem": "LieSubmodule.coe_add", 107 | "proof": [ 108 | "simp" 109 | ], 110 | "score": 0.014644467271864414, 111 | "success": true 112 | }, 113 | { 114 | "theorem": 
"IsAntichain.preimage_iso", 115 | "proof": [ 116 | "intro b\u2081 b\u2082 hb\u2081 hb\u2082", 117 | "simpa using ht b\u2082 hb\u2082" 118 | ], 119 | "score": 0.06171885505318642, 120 | "success": true 121 | }, 122 | { 123 | "theorem": "PrimeSpectrum.comap_asIdeal", 124 | "proof": [ 125 | "rfl" 126 | ], 127 | "score": 0.026252638548612595, 128 | "success": true 129 | }, 130 | { 131 | "theorem": "MonoidWithZeroHom.coe_mk", 132 | "proof": [ 133 | "rfl" 134 | ], 135 | "score": 0.0029693180695176125, 136 | "success": true 137 | }, 138 | { 139 | "theorem": "OreLocalization.one_def", 140 | "proof": [ 141 | "rfl" 142 | ], 143 | "score": 0.05921659246087074, 144 | "success": true 145 | }, 146 | { 147 | "theorem": "LinearPMap.mem_domain_of_mem_graph", 148 | "proof": [ 149 | "rw [mem_graph_iff] at h", 150 | "rcases h with \u27e8x', rfl, h'\u27e9", 151 | "exact x'.2" 152 | ], 153 | "score": 0.04480455256998539, 154 | "success": true 155 | }, 156 | { 157 | "theorem": "Vector.empty_toList_eq_ff", 158 | "proof": [ 159 | "cases v", 160 | "simp [toList, List.isEmpty]", 161 | "split", 162 | "all_goals tauto" 163 | ], 164 | "score": 0.10822779312729836, 165 | "success": true 166 | }, 167 | { 168 | "theorem": "Mathlib.Tactic.IntervalCases.of_not_lt_right", 169 | "proof": [ 170 | "exact eq \u25b8 le_of_not_lt h" 171 | ], 172 | "score": 0.09517894685268402, 173 | "success": true 174 | }, 175 | { 176 | "theorem": "List.Pairwise.and", 177 | "proof": [ 178 | "induction hR", 179 | "case nil => simp", 180 | "simp_all" 181 | ], 182 | "score": 0.06705544702708721, 183 | "success": true 184 | }, 185 | { 186 | "theorem": "Int.lcm_comm", 187 | "proof": [ 188 | "rw [Int.lcm]", 189 | "apply Nat.lcm_comm" 190 | ], 191 | "score": 0.11606144160032272, 192 | "success": true 193 | }, 194 | { 195 | "theorem": "Finsupp.ker_lsingle", 196 | "proof": [ 197 | "ext", 198 | "simp" 199 | ], 200 | "score": 0.022008214611560106, 201 | "success": true 202 | }, 203 | { 204 | "theorem": "Setoid.ext_iff", 205 | 
"proof": [ 206 | "constructor", 207 | "rintro rfl _ _", 208 | "rfl", 209 | "exact fun h => Setoid.ext h" 210 | ], 211 | "score": 0.15640811529010534, 212 | "success": true 213 | }, 214 | { 215 | "theorem": "Set.Finite.toFinset_offDiag", 216 | "proof": [ 217 | "ext", 218 | "simp" 219 | ], 220 | "score": 0.010765764862298965, 221 | "success": true 222 | }, 223 | { 224 | "theorem": "Finsupp.embDomain_inj", 225 | "proof": [ 226 | "constructor", 227 | "apply embDomain_injective", 228 | "aesop" 229 | ], 230 | "score": 0.06850366480648518, 231 | "success": true 232 | }, 233 | { 234 | "theorem": "PowerSeries.smul_eq_C_mul", 235 | "proof": [ 236 | "simp [PowerSeries.ext_iff]" 237 | ], 238 | "score": 0.09802203625440598, 239 | "success": true 240 | }, 241 | { 242 | "theorem": "OrderHom.orderHom_eq_id", 243 | "proof": [ 244 | "apply Subsingleton.elim" 245 | ], 246 | "score": 0.032435908913612366, 247 | "success": true 248 | }, 249 | { 250 | "theorem": "Behrend.sphere_subset_box", 251 | "proof": [ 252 | "simp [sphere, box]" 253 | ], 254 | "score": 0.07471803575754166, 255 | "success": true 256 | }, 257 | { 258 | "theorem": "Polynomial.aeval_algHom", 259 | "proof": [ 260 | "ext", 261 | "simp" 262 | ], 263 | "score": 0.008154324954375625, 264 | "success": true 265 | }, 266 | { 267 | "theorem": "OrderMonoidWithZeroHom.comp_mul", 268 | "proof": [ 269 | "ext <;> simp" 270 | ], 271 | "score": 0.03107544593513012, 272 | "success": true 273 | }, 274 | { 275 | "theorem": "Dfinsupp.mapRange_apply", 276 | "proof": [ 277 | "rfl" 278 | ], 279 | "score": 0.021800542250275612, 280 | "success": true 281 | }, 282 | { 283 | "theorem": "Submonoid.center_toSubsemigroup", 284 | "proof": [ 285 | "rfl" 286 | ], 287 | "score": 0.1160610243678093, 288 | "success": true 289 | }, 290 | { 291 | "theorem": "Subgroup.toSubmonoid_eq", 292 | "proof": [ 293 | "constructor", 294 | "swap", 295 | "rintro rfl", 296 | "rfl", 297 | "apply toSubmonoid_injective" 298 | ], 299 | "score": 0.10547792294528335, 300 | 
"success": true 301 | }, 302 | { 303 | "theorem": "MonoidHom.prod_comp_prodMap", 304 | "proof": [ 305 | "ext <;> simp" 306 | ], 307 | "score": 0.010844556614756584, 308 | "success": true 309 | }, 310 | { 311 | "theorem": "LocalEquiv.EqOnSource.source_eq", 312 | "proof": [ 313 | "cases e", 314 | "cases e'", 315 | "exact h.1" 316 | ], 317 | "score": 0.039306161692366004, 318 | "success": true 319 | }, 320 | { 321 | "theorem": "ENNReal.toNNReal_le_toNNReal", 322 | "proof": [ 323 | "rw [\u2190 ENNReal.coe_le_coe, coe_toNNReal ha, coe_toNNReal hb]" 324 | ], 325 | "score": 0.05514577403664589, 326 | "success": true 327 | }, 328 | { 329 | "theorem": "DirectSum.toAddMonoidAlgebra_toDirectSum", 330 | "proof": [ 331 | "simp [DirectSum.toAddMonoidAlgebra]", 332 | "simp [AddMonoidAlgebra.toDirectSum]" 333 | ], 334 | "score": 0.07822933420538902, 335 | "success": true 336 | }, 337 | { 338 | "theorem": "bddBelow_Ici", 339 | "proof": [ 340 | "use a", 341 | "simp [lowerBounds]" 342 | ], 343 | "score": 0.09038608893752098, 344 | "success": true 345 | }, 346 | { 347 | "theorem": "Int.lor_bit", 348 | "proof": [ 349 | "dsimp [lor]", 350 | "cases a <;> cases b <;> simp", 351 | "all_goals cases m <;> cases n <;> simp" 352 | ], 353 | "score": 0.060909992549568415, 354 | "success": true 355 | }, 356 | { 357 | "theorem": "IndexedPartition.some_index", 358 | "proof": [ 359 | "cases hs", 360 | "simp", 361 | "aesop" 362 | ], 363 | "score": 0.10609651356935501, 364 | "success": true 365 | }, 366 | { 367 | "theorem": "nndist_vadd_right", 368 | "proof": [ 369 | "simp [nndist]", 370 | "ext", 371 | "simp", 372 | "rw [dist_comm, dist_vadd_left]" 373 | ], 374 | "score": 0.06943780835717916, 375 | "success": true 376 | }, 377 | { 378 | "theorem": "Irrational.ne_nat", 379 | "proof": [ 380 | "rw [Irrational] at h", 381 | "rintro rfl", 382 | "exact h (Set.mem_range_self _)" 383 | ], 384 | "score": 0.23603598587214947, 385 | "success": true 386 | }, 387 | { 388 | "theorem": 
"StarSubalgebra.subset_adjoin", 389 | "proof": [ 390 | "simp [subset_adjoin]", 391 | "tauto" 392 | ], 393 | "score": 0.05541197769343853, 394 | "success": true 395 | }, 396 | { 397 | "theorem": "UniqueFactorizationMonoid.normalizedFactors_one", 398 | "proof": [ 399 | "simp [normalizedFactors]" 400 | ], 401 | "score": 0.035825345665216446, 402 | "success": true 403 | }, 404 | { 405 | "theorem": "Filter.bliminf_eq_liminf_subtype", 406 | "proof": [ 407 | "simp [bliminf_eq, liminf_eq]", 408 | "congr!", 409 | "aesop" 410 | ], 411 | "score": 0.09382953867316246, 412 | "success": true 413 | }, 414 | { 415 | "theorem": "Set.biUnion_singleton", 416 | "proof": [ 417 | "simp" 418 | ], 419 | "score": 0.002082447987049818, 420 | "success": true 421 | }, 422 | { 423 | "theorem": "RightOrdContinuous.iterate", 424 | "proof": [ 425 | "induction' n with n ih", 426 | "case zero => exact RightOrdContinuous.id \u03b1", 427 | "exact ih.comp hf" 428 | ], 429 | "score": 0.08558873645961285, 430 | "success": true 431 | }, 432 | { 433 | "theorem": "one_le_zpow", 434 | "proof": [ 435 | "lift n to \u2115 using hn", 436 | "rw [zpow_ofNat]", 437 | "exact one_le_pow_of_one_le' H n" 438 | ], 439 | "score": 0.0882943207398057, 440 | "success": true 441 | }, 442 | { 443 | "theorem": "not_lt_zero'", 444 | "proof": [ 445 | "simp" 446 | ], 447 | "score": 0.03989856690168381, 448 | "success": true 449 | }, 450 | { 451 | "theorem": "AlgEquiv.toLinearMap_injective", 452 | "proof": [ 453 | "intro f g h", 454 | "ext", 455 | "rw [\u2190 toLinearMap_apply, \u2190 toLinearMap_apply, h]" 456 | ], 457 | "score": 0.04303271742537618, 458 | "success": true 459 | }, 460 | { 461 | "theorem": "Algebra.algebraMap_ofSubsemiring_apply", 462 | "proof": [ 463 | "rfl" 464 | ], 465 | "score": 0.017106426879763603, 466 | "success": true 467 | }, 468 | { 469 | "theorem": "PiTensorProduct.lift_tprod", 470 | "proof": [ 471 | "ext <;> simp" 472 | ], 473 | "score": 0.026965346187353134, 474 | "success": true 475 | }, 476 | { 477 
| "theorem": "div_eq_iff_mul_eq", 478 | "proof": [ 479 | "rw [div_eq_iff hb]", 480 | "rw [eq_comm]" 481 | ], 482 | "score": 0.05084596015512943, 483 | "success": true 484 | }, 485 | { 486 | "theorem": "ENNReal.one_lt_two", 487 | "proof": [ 488 | "norm_num" 489 | ], 490 | "score": 0.004520223010331392, 491 | "success": true 492 | }, 493 | { 494 | "theorem": "MeasureTheory.volume_preserving_finTwoArrow", 495 | "proof": [ 496 | "haveI : Encodable \u03b9 := Fintype.toEncodable \u03b9", 497 | "apply measurePreserving_finTwoArrow" 498 | ], 499 | "score": 0.07828811183571815, 500 | "success": true 501 | }, 502 | { 503 | "theorem": "Subsingleton.eq_univ_of_nonempty", 504 | "proof": [ 505 | "rintro \u27e8x, hx\u27e9", 506 | "refine' eq_univ_iff_forall.mpr fun y => _", 507 | "rwa [Subsingleton.elim y x]" 508 | ], 509 | "score": 0.171775184571743, 510 | "success": true 511 | }, 512 | { 513 | "theorem": "Complex.two_cos", 514 | "proof": [ 515 | "rw [two_mul]", 516 | "simp [cos]" 517 | ], 518 | "score": 0.15968229621648788, 519 | "success": true 520 | }, 521 | { 522 | "theorem": "IsLocalization.map_mk'", 523 | "proof": [ 524 | "apply IsLocalization.lift_mk'" 525 | ], 526 | "score": 0.02295919507741928, 527 | "success": true 528 | }, 529 | { 530 | "theorem": "Subsemigroup.coe_comap", 531 | "proof": [ 532 | "rfl" 533 | ], 534 | "score": 0.03112742491066456, 535 | "success": true 536 | }, 537 | { 538 | "theorem": "CategoryTheory.LaxMonoidalFunctor.prod'_\u03bc", 539 | "proof": [ 540 | "dsimp [prod']", 541 | "simp" 542 | ], 543 | "score": 0.021082513965666294, 544 | "success": true 545 | }, 546 | { 547 | "theorem": "Matrix.SpecialLinearGroup.coe_int_neg", 548 | "proof": [ 549 | "ext", 550 | "simp" 551 | ], 552 | "score": 0.0304072555154562, 553 | "success": true 554 | }, 555 | { 556 | "theorem": "ENNReal.toReal_nonneg", 557 | "proof": [ 558 | "simp [ENNReal.toReal]" 559 | ], 560 | "score": 0.05232922360301018, 561 | "success": true 562 | }, 563 | { 564 | "theorem": 
"Fin.le_coe_natAdd", 565 | "proof": [ 566 | "simp" 567 | ], 568 | "score": 0.05793900787830353, 569 | "success": true 570 | }, 571 | { 572 | "theorem": "SemiconjBy.inv_inv_symm", 573 | "proof": [ 574 | "intro h", 575 | "simp [h]" 576 | ], 577 | "score": 0.12453563511371613, 578 | "success": true 579 | }, 580 | { 581 | "theorem": "nhds_ofAdd", 582 | "proof": [ 583 | "rfl" 584 | ], 585 | "score": 0.03388547897338867, 586 | "success": true 587 | }, 588 | { 589 | "theorem": "MatrixEquivTensor.invFun_zero", 590 | "proof": [ 591 | "simp [invFun]" 592 | ], 593 | "score": 0.018612481653690338, 594 | "success": true 595 | }, 596 | { 597 | "theorem": "List.toFinset_reverse", 598 | "proof": [ 599 | "simp [toFinset]" 600 | ], 601 | "score": 0.026263730600476265, 602 | "success": true 603 | }, 604 | { 605 | "theorem": "CategoryTheory.Monoidal.leftUnitor_hom_app", 606 | "proof": [ 607 | "dsimp [leftUnitor]" 608 | ], 609 | "score": 0.01946549117565155, 610 | "success": true 611 | }, 612 | { 613 | "theorem": "List.sum_smul", 614 | "proof": [ 615 | "induction' l with hd tl hl", 616 | "simp", 617 | "simp [add_smul, hl]" 618 | ], 619 | "score": 0.056858822237700224, 620 | "success": true 621 | }, 622 | { 623 | "theorem": "ProbabilityTheory.kernel.integral_const", 624 | "proof": [ 625 | "rw [kernel.const_apply]" 626 | ], 627 | "score": 0.0064316014759242535, 628 | "success": true 629 | }, 630 | { 631 | "theorem": "CauSeq.coe_sup", 632 | "proof": [ 633 | "rfl" 634 | ], 635 | "score": 0.0693473070859909, 636 | "success": true 637 | }, 638 | { 639 | "theorem": "MeasureTheory.NullMeasurableSet.insert", 640 | "proof": [ 641 | "unfold NullMeasurableSet", 642 | "rw [insert_eq]", 643 | "measurability" 644 | ], 645 | "score": 0.08922873809933662, 646 | "success": true 647 | }, 648 | { 649 | "theorem": "Set.subset_iInter\u2082", 650 | "proof": [ 651 | "intro x hx", 652 | "exact mem_iInter.2 fun i => mem_iInter.2 fun j => h i j hx" 653 | ], 654 | "score": 0.042423089034855366, 655 | "success": 
true 656 | }, 657 | { 658 | "theorem": "ClopenUpperSet.coe_mk", 659 | "proof": [ 660 | "rfl" 661 | ], 662 | "score": 0.01648038811981678, 663 | "success": true 664 | }, 665 | { 666 | "theorem": "IsLprojection.compl_mul", 667 | "proof": [ 668 | "simp [sub_mul]" 669 | ], 670 | "score": 0.061795979738235474, 671 | "success": true 672 | }, 673 | { 674 | "theorem": "Equiv.mulLeft_symm_apply", 675 | "proof": [ 676 | "simp [Equiv.mulLeft]" 677 | ], 678 | "score": 0.028932297602295876, 679 | "success": true 680 | }, 681 | { 682 | "theorem": "Multiset.map_hcongr", 683 | "proof": [ 684 | "subst h", 685 | "revert hf", 686 | "simp (config := { contextual := true }) [heq_iff_eq]" 687 | ], 688 | "score": 0.1003159754909575, 689 | "success": true 690 | }, 691 | { 692 | "theorem": "EMetric.le_infEdist", 693 | "proof": [ 694 | "simp [infEdist]" 695 | ], 696 | "score": 0.044103752821683884, 697 | "success": true 698 | }, 699 | { 700 | "theorem": "DifferentiableOn.const_smul", 701 | "proof": [ 702 | "intro y hy", 703 | "rcases h y hy with \u27e8f', hf'\u27e9", 704 | "exact hf'.differentiableWithinAt.const_smul _" 705 | ], 706 | "score": 0.043711879290640354, 707 | "success": true 708 | }, 709 | { 710 | "theorem": "Set.diff_empty", 711 | "proof": [ 712 | "simp [diff_eq]" 713 | ], 714 | "score": 0.028194859623908997, 715 | "success": true 716 | }, 717 | { 718 | "theorem": "Equiv.symm_apply_apply", 719 | "proof": [ 720 | "cases e", 721 | "aesop" 722 | ], 723 | "score": 0.09066382050514221, 724 | "success": true 725 | } 726 | ], 727 | "args": { 728 | "model_name": "wellecks/llmstep-mathlib4-pythia2.8b", 729 | "dataset_path": "./data/val.json", 730 | "max_iters": 50, 731 | "num_samples": 32, 732 | "num_examples": 200 733 | } 734 | } -------------------------------------------------------------------------------- /partI_nextstep/notebooks/images/banach/banach_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/banach/banach_1.png -------------------------------------------------------------------------------- /partI_nextstep/notebooks/images/banach/banach_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/banach/banach_2.png -------------------------------------------------------------------------------- /partI_nextstep/notebooks/images/banach/banach_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/banach/banach_3.png -------------------------------------------------------------------------------- /partI_nextstep/notebooks/images/banach/banach_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/banach/banach_4.png -------------------------------------------------------------------------------- /partI_nextstep/notebooks/images/banach/banach_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/banach/banach_5.png -------------------------------------------------------------------------------- /partI_nextstep/notebooks/images/leandojo_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/leandojo_1.png 
-------------------------------------------------------------------------------- /partI_nextstep/notebooks/images/llmsuggest/llmstep_gif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/llmsuggest/llmstep_gif.gif -------------------------------------------------------------------------------- /partI_nextstep/notebooks/images/llmsuggest/llmsuggest.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/llmsuggest/llmsuggest.gif -------------------------------------------------------------------------------- /partI_nextstep/notebooks/images/llmsuggest/llmsuggest_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/llmsuggest/llmsuggest_examples.png -------------------------------------------------------------------------------- /partI_nextstep/notebooks/images/proof_state_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/proof_state_1.png -------------------------------------------------------------------------------- /partI_nextstep/notebooks/images/proof_state_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/proof_state_2.png -------------------------------------------------------------------------------- 
/partI_nextstep/notebooks/images/proof_state_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/proof_state_3.png -------------------------------------------------------------------------------- /partI_nextstep/ntp_lean/ExtractSimple.lean: -------------------------------------------------------------------------------- 1 | -- A minimal version of LeanDojo's ExtractData.lean for instructional purposes. 2 | -- Please see LeanDojo's ExtractData.lean for a full working script to use in practice. 3 | -- 4 | -- Credits: original script is from LeanDojo https://github.com/lean-dojo/LeanDojo/ 5 | -- @article{yang2023leandojo, 6 | -- title={{LeanDojo}: Theorem Proving with Retrieval-Augmented Language Models}, 7 | -- author={Yang, Kaiyu and Swope, Aidan and Gu, Alex and Chalamala, Rahul and Song, 8 | -- Peiyang and Yu, Shixing and Godil, Saad and Prenger, Ryan and Anandkumar, Anima}, 9 | -- journal={arXiv preprint arXiv:2306.15626}, 10 | -- year={2023} 11 | -- } 12 | -- This script is essentially a slightly refactored subset of LeanDojo's script. 
13 | 14 | import Lean 15 | 16 | open Lean Elab System 17 | 18 | instance : ToJson Substring where 19 | toJson s := toJson s.toString 20 | 21 | instance : ToJson String.Pos where 22 | toJson n := toJson n.1 23 | 24 | deriving instance ToJson for SourceInfo 25 | deriving instance ToJson for Syntax.Preresolved 26 | deriving instance ToJson for Syntax 27 | 28 | structure TacticTrace where 29 | stateBefore: String 30 | stateAfter: String 31 | pos: String.Pos 32 | endPos: String.Pos 33 | deriving ToJson 34 | 35 | structure Trace where 36 | commandASTs : Array Syntax 37 | tactics: Array TacticTrace 38 | deriving ToJson 39 | 40 | abbrev TraceM := StateT Trace IO 41 | 42 | 43 | def ppGoals (ctx : ContextInfo) (goals : List MVarId) : IO String := 44 | if goals.isEmpty then 45 | return "no goals" 46 | else 47 | let fmt := ctx.runMetaM {} (return Std.Format.prefixJoin "\n\n" (← goals.mapM (Meta.ppGoal ·))) 48 | return (← fmt).pretty.trim 49 | 50 | 51 | private def visitTacticInfo (ctx : ContextInfo) (ti : TacticInfo) (parent : InfoTree) : TraceM Unit := do 52 | match parent with 53 | | .node (Info.ofTacticInfo i) _ => 54 | match i.stx.getKind with 55 | | `Lean.Parser.Tactic.tacticSeq1Indented | `Lean.Parser.Tactic.tacticSeqBracketed => 56 | let ctxBefore := { ctx with mctx := ti.mctxBefore } 57 | let ctxAfter := { ctx with mctx := ti.mctxAfter } 58 | let stateBefore ← ppGoals ctxBefore ti.goalsBefore 59 | let stateAfter ← ppGoals ctxAfter ti.goalsAfter 60 | let some posBefore := ti.stx.getPos? true | pure () 61 | let some posAfter := ti.stx.getTailPos? 
true | pure () 62 | match ti.stx with 63 | | .node _ _ _ => 64 | modifyGet fun trace => ((), 65 | { trace with tactics := trace.tactics.push { 66 | stateBefore := stateBefore, 67 | stateAfter := stateAfter, 68 | pos := posBefore, 69 | endPos := posAfter } } 70 | ) 71 | | _ => pure () 72 | | _ => pure () 73 | | _ => pure () 74 | 75 | 76 | private def visitInfo (ctx : ContextInfo) (i : Info) (parent : InfoTree) : TraceM Unit := do 77 | match i with 78 | | .ofTacticInfo ti => visitTacticInfo ctx ti parent 79 | | _ => pure () 80 | 81 | 82 | private partial def traverseTree (ctx: ContextInfo) (tree : InfoTree) (parent : InfoTree) : TraceM Unit := do 83 | match tree with 84 | | .context ctx' t => traverseTree ctx' t tree 85 | | .node i children => 86 | visitInfo ctx i parent 87 | for x in children do 88 | traverseTree ctx x tree 89 | | _ => pure () 90 | 91 | 92 | private def traverseTopLevelTree (tree : InfoTree) : TraceM Unit := do 93 | match tree with 94 | | .context ctx t => traverseTree ctx t tree 95 | | _ => throw $ IO.userError "Errors in traverseTopLevelTree; aborting" 96 | 97 | 98 | def traverseForest (trees : Array InfoTree) : TraceM Trace := do 99 | for t in trees do 100 | traverseTopLevelTree t 101 | get 102 | 103 | 104 | def relativeTo (path parent : FilePath) : Option FilePath := 105 | let rec componentsRelativeTo (pathComps parentComps : List String) : Option FilePath := 106 | match pathComps, parentComps with 107 | | _, [] => mkFilePath pathComps 108 | | [], _ => none 109 | | (h₁ :: t₁), (h₂ :: t₂) => 110 | if h₁ == h₂ then 111 | componentsRelativeTo t₁ t₂ 112 | else 113 | none 114 | 115 | componentsRelativeTo path.components parent.components 116 | 117 | 118 | def toAbsolute (path : FilePath) : IO FilePath := do 119 | if path.isAbsolute then 120 | pure path 121 | else 122 | let cwd ← IO.currentDir 123 | pure $ cwd / path 124 | 125 | 126 | unsafe def processFile (path : FilePath) : IO Unit := do 127 | let input ← IO.FS.readFile path 128 | let opts := 
Options.empty.setBool `trace.Elab.info true 129 | enableInitializersExecution 130 | let inputCtx := Parser.mkInputContext input path.toString 131 | let (header, parserState, messages) ← Parser.parseHeader inputCtx 132 | let (env, messages) ← processHeader header opts messages inputCtx 133 | 134 | if messages.hasErrors then 135 | for msg in messages.toList do 136 | if msg.severity == .error then 137 | println! "ERROR: {← msg.toString}" 138 | throw $ IO.userError "Errors during import; aborting" 139 | 140 | let some modName := path.fileStem | throw $ IO.userError s!"Invalid path: {path}" 141 | let env := env.setMainModule modName.toName 142 | let commandState := { Command.mkState env messages opts with infoState.enabled := true } 143 | let s ← IO.processCommands inputCtx parserState commandState 144 | let commands := s.commands.pop -- Remove EOI command. 145 | let trees := s.commandState.infoState.trees.toArray 146 | let trace ← (traverseForest trees).run' ⟨#[header] ++ commands, #[]⟩ 147 | 148 | let cwd ← IO.currentDir 149 | let some relativePath := relativeTo path cwd | throw $ IO.userError s!"Invalid path: {path}" 150 | println! "Input file: {relativePath}" 151 | let json_path := ( 152 | relativePath.withExtension "ast.json" 153 | ) 154 | IO.FS.writeFile json_path (toJson trace).pretty 155 | println! "AST: {json_path}" 156 | 157 | 158 | unsafe def main (args : List String) : IO Unit := do 159 | match args with 160 | | path :: _ => 161 | processFile (← toAbsolute ⟨path⟩) 162 | | [] => 163 | println! "Please provide a .lean file (lake env lean --run ExtractData.lean FILENAME.lean)" -------------------------------------------------------------------------------- /partI_nextstep/ntp_lean/LLMsuggest.lean: -------------------------------------------------------------------------------- 1 | /- 2 | `llmsuggest` tactic for LLM-based next-step suggestions in Lean4. 3 | 4 | This is a minimal version of `llmstep` built for tutorial purposes. 
5 | `llmstep`: https://github.com/wellecks/llmstep 6 | -/ 7 | 8 | import Mathlib.Tactic 9 | 10 | 11 | open Lean 12 | 13 | /- Calls a `suggest.py` python script with the given `args`. -/ 14 | def runSuggest (args : Array String) : IO String := do 15 | let cwd ← IO.currentDir 16 | let path := cwd / "partI_nextstep" / "ntp_python" / "llmsuggest" / "suggest.py" 17 | unless ← path.pathExists do 18 | dbg_trace f!"{path}" 19 | throw <| IO.userError "could not find python script suggest.py" 20 | let s ← IO.Process.run { cmd := "python3", args := #[path.toString] ++ args } 21 | return s 22 | 23 | /- Display clickable suggestions in the VSCode Lean Infoview. 24 | When a suggestion is clicked, this widget replaces the `llmstep` call 25 | with the suggestion, and saves the call in an adjacent comment. 26 | Code based on `Std.Tactic.TryThis.tryThisWidget`. -/ 27 | @[widget] def llmstepTryThisWidget : Widget.UserWidgetDefinition where 28 | name := "llmstep suggestions" 29 | javascript := " 30 | import * as React from 'react'; 31 | import { EditorContext } from '@leanprover/infoview'; 32 | const e = React.createElement; 33 | export default function(props) { 34 | const editorConnection = React.useContext(EditorContext) 35 | function onClick(suggestion) { 36 | editorConnection.api.applyEdit({ 37 | changes: { [props.pos.uri]: [{ range: 38 | props.range, 39 | newText: suggestion + ' -- ' + props.tactic 40 | }] } 41 | }) 42 | } 43 | return e('div', 44 | {className: 'ml1'}, 45 | e('ul', {className: 'font-code pre-wrap'}, [ 46 | 'Try this: ', 47 | ...(props.suggestions.map(suggestion => 48 | e('li', {onClick: () => onClick(suggestion), 49 | className: 'link pointer dim', title: 'Apply suggestion'}, 50 | suggestion 51 | ) 52 | )), 53 | props.info 54 | ])) 55 | }" 56 | 57 | 58 | /- Adds multiple suggestions to the Lean InfoView. 59 | Code based on `Std.Tactic.addSuggestion`. -/ 60 | def addSuggestions (tacRef : Syntax) (suggestions: List String) 61 | (origSpan? 
: Option Syntax := none) 62 | (extraMsg : String := "") : MetaM Unit := do 63 | if let some tacticRange := (origSpan?.getD tacRef).getRange? then 64 | let map ← getFileMap 65 | let start := findLineStart map.source tacticRange.start 66 | let body := map.source.findAux (· ≠ ' ') tacticRange.start start 67 | let texts := suggestions.map fun text => ( 68 | Std.Format.prettyExtra text 69 | (indent := (body - start).1) 70 | (column := (tacticRange.start - start).1) 71 | ) 72 | let start := (tacRef.getRange?.getD tacticRange).start 73 | let stop := (tacRef.getRange?.getD tacticRange).stop 74 | let stxRange := 75 | { start := map.lineStart (map.toPosition start).line 76 | stop := map.lineStart ((map.toPosition stop).line + 1) } 77 | let tacticRange := map.utf8RangeToLspRange tacticRange 78 | let tactic := Std.Format.prettyExtra f!"{tacRef.prettyPrint}" 79 | let json := Json.mkObj [ 80 | ("tactic", tactic), 81 | ("suggestions", toJson texts), 82 | ("range", toJson tacticRange), 83 | ("info", extraMsg) 84 | ] 85 | Widget.saveWidgetInfo ``llmstepTryThisWidget json (.ofRange stxRange) 86 | 87 | 88 | -- `llmsuggest` tactic. 
89 | syntax "llmsuggest" : tactic 90 | elab_rules : tactic 91 | | `(tactic | llmsuggest%$tac) => 92 | Lean.Elab.Tactic.withMainContext do 93 | let goal ← Lean.Elab.Tactic.getMainGoal 94 | let ppgoal ← Lean.Meta.ppGoal goal 95 | let ppgoalstr := toString ppgoal 96 | let suggest ← runSuggest #[ppgoalstr] 97 | addSuggestions tac $ suggest.splitOn "[SUGGESTION]" 98 | 99 | 100 | /- Examples -/ 101 | example : 2 = 2 := by 102 | rfl -- llmsuggest 103 | 104 | 105 | example (f : ℕ → ℕ) : Monotone f → ∀ n, f n ≤ f (n + 1) := by 106 | intro h n -- llmsuggest 107 | exact h (Nat.le_succ _) -- llmsuggest 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /partI_nextstep/ntp_lean/examples/example0.lean: -------------------------------------------------------------------------------- 1 | import Mathlib.Data.Nat.Prime 2 | 3 | theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by 4 | rw [Nat.coprime] at h 5 | exact h -------------------------------------------------------------------------------- /partI_nextstep/ntp_lean/examples/example_demo.lean: -------------------------------------------------------------------------------- 1 | import Mathlib.Data.Nat.Prime 2 | 3 | variable (α: Type) (R S T : Set α) 4 | 5 | 6 | example (h1: R ⊆ S) (h2: S ⊆ T) : (R ⊆ T) := by 7 | exact h1.trans h2 -------------------------------------------------------------------------------- /partI_nextstep/ntp_python/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/ntp_python/__init__.py -------------------------------------------------------------------------------- /partI_nextstep/ntp_python/data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Dict, Optional, Sequence 3 | from tqdm 
def _save_splits(splits, data_dir, tag):
    """Write each split to ``<data_dir>/processed/<tag>-<split>.jsonl``.

    One JSON object per line (JSON Lines format), one file per split.

    Improvement: the output directory is created once up front; previously
    the (idempotent) ``mkdir`` ran on every loop iteration.
    """
    print("Saving split to disk...")
    out_dir = os.path.join(data_dir, 'processed')
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    for split, examples in tqdm(splits.items(), total=len(splits)):
        out_file = os.path.join(
            out_dir, '%s-%s.jsonl' % (tag, split)
        )
        with open(out_file, 'w') as f:
            for example in examples:
                f.write(json.dumps(example))
                f.write('\n')
def main(args):
    """Format the downloaded LeanDojo splits into [GOAL]/[PROOFSTEP] examples.

    Bug fix: argparse's ``--data-dir`` flag maps to ``args.data_dir``
    (dashes become underscores); the previous ``args.datadir`` raised
    AttributeError at runtime.
    """
    proofstep(args.data_dir)
def suggest(tactic_state):
    """POST ``tactic_state`` to the local suggestion server and print the
    candidates joined by the ``[SUGGESTION]`` sentinel.

    Bug fix: the function previously ignored its argument and read
    ``sys.argv[1]`` directly, which broke any caller other than the CLI
    entry point. The connection is now also closed even if the request fails.
    """
    conn = http.client.HTTPConnection("localhost", 5000)
    try:
        headers = {'Content-type': 'application/json'}
        body = json.dumps({"tactic_state": tactic_state})
        conn.request("POST", "/", body, headers)
        response = conn.getresponse()
        data_dict = json.loads(response.read())
        print('[SUGGESTION]'.join(data_dict['suggestions']))
    finally:
        conn.close()
def reconstruct_theorem_proof(command_ast):
    """Rebuild the source text of one Lean command from its .ast.json node.

    Walks the AST depth-first, concatenating each atom/ident token together
    with its leading and trailing trivia, and tracks the minimum/maximum
    source offsets seen. Returns ``(start_pos, end_pos, text)``.
    """
    pieces = []
    positions = []
    end_positions = []

    def _visit_arg(arg):
        # Atoms store their text under 'val', identifiers under 'rawVal';
        # both carry the same 'info.original' trivia/position record.
        for kind, text_key in (('atom', 'val'), ('ident', 'rawVal')):
            if kind in arg:
                info = arg[kind]['info']['original']
                pieces.extend([info['leading'], arg[kind][text_key], info['trailing']])
                positions.append(info['pos'])
                end_positions.append(info['endPos'])
        if 'node' in arg:
            for child in arg['node'].get('args', []):
                _visit_arg(child)

    _visit_arg({'node': command_ast['node']})
    return min(positions), max(end_positions), ''.join(pieces)
def best_first_search(theorem, model, tokenizer, max_iters, num_samples, timeout=600) -> dict:
    """Best-first proof search over LeanDojo tactic states.

    Maintains a min-heap keyed on the cumulative negative log-probability of
    the tactic sequence (lower = more likely). Returns a result dict with
    'success', and on success the 'proof' (list of tactic strings) and 'score'.

    Bug fix: heap entries now carry a monotonically increasing tie-breaker so
    that equal-score, equal-steps entries never fall through to comparing
    `TacticState` objects, which are not orderable and raised TypeError.
    """
    try:
        with Dojo(theorem, hard_timeout=60 + timeout) as (dojo, init_state):
            # Each entry: (cumulative score, tie-breaker, steps so far, state).
            tie = 0
            queue = [(0.0, tie, [], init_state)]
            visited = set()
            for _ in trange(max_iters):
                if len(queue) == 0:
                    break

                total_score, _, steps, state = heapq.heappop(queue)
                ts = _tactic_state(state)
                visited.add(ts)

                step_cands, step_scores = generate(
                    _prompt(ts), model, tokenizer, num_samples
                )

                for step, score in zip(step_cands, step_scores):
                    result = dojo.run_tac(state, step)
                    if isinstance(result, ProofFinished):
                        return {
                            'theorem': theorem.full_name,
                            'proof': steps + [step],
                            'score': total_score - score,
                            'success': True,
                            'failure_reason': ''
                        }
                    elif isinstance(result, TacticState):
                        if _tactic_state(result) not in visited:
                            # Score is negative log probability summed across steps
                            new_score = (total_score - score)
                            tie += 1
                            heapq.heappush(
                                queue, (new_score, tie, steps + [step], result)
                            )
    except (DojoInitError, DojoHardTimeoutError, DojoCrashError) as e:
        return {'theorem': theorem.full_name, 'success': False, 'failure_reason': str(e)}

    return {'theorem': theorem.full_name, 'success': False, 'failure_reason': 'SearchEnded'}
parser.add_argument('--max-iters', type=int, default=100) 137 | parser.add_argument('--num-samples', type=int, default=32) 138 | parser.add_argument('--num-examples', type=int, default=200) 139 | args = parser.parse_args() 140 | 141 | model, tokenizer = load_model(args.model_name, args.vllm) 142 | 143 | URL = "https://github.com/leanprover-community/mathlib4" 144 | COMMIT = "5a919533f110b7d76410134a237ee374f24eaaad" 145 | repo = LeanGitRepo(URL, COMMIT) 146 | 147 | dt = datetime.now().strftime("%d-%m-%Y-%H-%M-%S") 148 | 149 | with open(args.dataset_path) as f: 150 | data = json.load(f) 151 | 152 | random.seed(43) 153 | data = random.sample(data, args.num_examples) 154 | results = [] 155 | for example in tqdm(data, total=len(data)): 156 | file_path = example['file_path'] 157 | theorem_name = example['full_name'] 158 | theorem = Theorem(repo, file_path, theorem_name) 159 | result = best_first_search( 160 | theorem, model, tokenizer, 161 | max_iters=args.max_iters, 162 | num_samples=args.num_samples 163 | ) 164 | print(result) 165 | print('\n-----\n') 166 | results.append(result) 167 | 168 | _save(args.model_name, results, args.__dict__, dt) 169 | print(len([x for x in results if x['success']])/len(results)) -------------------------------------------------------------------------------- /partI_nextstep/ntp_python/proofsearch_pylean.py: -------------------------------------------------------------------------------- 1 | # Utilities for interacting with Lean and proof search 2 | 3 | from pylean import LeanServer 4 | import torch 5 | import heapq 6 | import concurrent 7 | import transformers 8 | import os 9 | from tqdm import tqdm 10 | from concurrent.futures import ThreadPoolExecutor 11 | from typing import List, Tuple 12 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 13 | 14 | 15 | def is_done(state): 16 | return state['sorries'] == [] and state['messages'] == [] 17 | 18 | 19 | def get_goal(state): 20 | goal = None 21 | for msg in state['messages']: 22 | if 
def format_code(header, statement, steps_so_far, next_step):
    """Assemble a complete Lean snippet for type-checking.

    Concatenates the shared import header, the theorem statement (with its
    " {}" placeholder stripped), the accepted proof steps, and the candidate
    next step — one tactic per line.
    """
    theorem_line = statement.replace(" {}", "")
    proof_body = '\n'.join(steps_so_far + [next_step])
    return header + theorem_line + '\n' + proof_body
sampling. 76 | for temp in temperatures: 77 | decoding_params = dict( 78 | max_new_tokens=256, 79 | do_sample=temp > 0, 80 | temperature=temp, 81 | pad_token_id=tokenizer.eos_token_id, 82 | num_return_sequences=num_samples, 83 | return_dict_in_generate=True, 84 | output_scores=True, 85 | ) 86 | if temp == 0.0: 87 | decoding_params['num_beams'] = num_samples 88 | out = model.generate( 89 | input_ids, **decoding_params 90 | ) 91 | 92 | texts.extend(tokenizer.batch_decode( 93 | out.sequences[:,input_ids.shape[1]:], 94 | skip_special_tokens=True 95 | )) 96 | scores_ = sequence_scores( 97 | out=out, 98 | prompt_length=input_ids.shape[1], 99 | model=model, 100 | tokenizer=tokenizer 101 | ) 102 | scores.extend(scores_.view(-1).tolist()) 103 | 104 | texts, scores = _unique_sorted(texts, scores) 105 | return texts, scores 106 | 107 | 108 | def _unique_sorted(texts, scores): 109 | texts_, scores_ = [], [] 110 | for t, s in sorted(zip(texts, scores), key=lambda x: -x[1]): 111 | if t not in texts_: 112 | texts_.append(t) 113 | scores_.append(s) 114 | return texts_, scores_ 115 | 116 | 117 | def _print_type_checked_candidates(results): 118 | print('--- type-checked candidates:\n\t' + '\n\t'.join( 119 | '(%.3f) %s' % (step_score, step) 120 | for state, step, step_score in results if ( 121 | get_goal(state) is not None or is_done(state)) 122 | )) 123 | 124 | 125 | def _print_current(theorem_statement, steps): 126 | print('--- current:\n\t%s\n\t%s' % ( 127 | theorem_statement.replace('{}', ''), 128 | '\n\t'.join(steps)) 129 | ) 130 | 131 | 132 | def best_first_search(model, tokenizer, header, statement, max_iters, temperatures, num_samples, verbose=False) -> dict: 133 | goal = get_goal(run_code(header + statement)) 134 | if goal is None: 135 | return { 136 | 'theorem_statement': statement, 137 | 'success': False, 138 | 'msg': run_code(header + statement) 139 | } 140 | 141 | # Score, steps-so-far, goal state 142 | queue = [(0.0, [], goal)] 143 | visited = set() 144 | while 
len(queue) > 0 and max_iters > 0: 145 | # Dequeue the tuple with minimum score 146 | score, steps, goal = heapq.heappop(queue) 147 | visited.add(goal) 148 | if verbose: 149 | _print_current(statement, steps) 150 | 151 | # Generate next-step candidates 152 | prompt = f"[GOAL]{goal}[PROOFSTEP]" 153 | step_cands, step_scores = generate( 154 | prompt, 155 | model, 156 | tokenizer, 157 | temperatures=temperatures, 158 | num_samples=num_samples 159 | ) 160 | 161 | # Run type checking in parallel via futures. 162 | with ThreadPoolExecutor(max_workers=16) as executor: 163 | # We need to save the step and score associated to each future. 164 | future2step = {} 165 | for step, step_score in zip(step_cands, step_scores): 166 | code = format_code(header, statement, steps, step) 167 | future = executor.submit(run_code, **dict(code=code)) 168 | future2step[future] = (step, step_score) 169 | 170 | # Collect the type checking results as they complete. 171 | results = [] 172 | for future in tqdm(concurrent.futures.as_completed(future2step.keys()), total=len(future2step)): 173 | result = future.result() 174 | results.append((result, *future2step[future])) 175 | 176 | if verbose: 177 | _print_type_checked_candidates(results) 178 | for state, step, step_score in results: 179 | # Stop if we have found a complete proof. 180 | if is_done(state): 181 | return { 182 | 'theorem_statement': statement, 183 | 'proof': steps + [step], 184 | 'state': state, 185 | 'score': score - step_score, 186 | 'success': True 187 | } 188 | goal_cand = get_goal(state) 189 | # Add new candidates to the queue. 
def _save(results):
    """Dump `results` to a timestamped JSON file in the working directory
    and print the file name."""
    from datetime import datetime
    import json
    stamp = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
    output_file = 'results__%s.json' % stamp
    with open(output_file, 'w') as handle:
        json.dump(results, handle, indent=4)
    print(output_file)
@dataclass
class ModelArguments:
    # Hugging Face hub id or local path of the base model to fine-tune.
    model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")


@dataclass
class DataArguments:
    # Paths to JSON-lines files with 'input'/'output' fields (see data.py).
    train_data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    valid_data_path: str = field(default=None, metadata={"help": "Path to the validation data."})
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string independently (no cross-example padding).

    Returns per-example token id tensors plus their unpadded lengths.
    `input_ids` and `labels` intentionally reference the same tensors;
    the caller deep-copies labels before masking.
    """
    encoded = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    ids = [enc.input_ids[0] for enc in encoded]
    lengths = [
        enc.input_ids.ne(tokenizer.pad_token_id).sum().item() for enc in encoded
    ]
    return dict(
        input_ids=ids,
        labels=ids,
        input_ids_lens=lengths,
        labels_lens=lengths,
    )
targets: Sequence[str], 94 | tokenizer: transformers.PreTrainedTokenizer, 95 | ) -> Dict: 96 | """Preprocess the data by tokenizing.""" 97 | examples = [s + t for s, t in zip(sources, targets)] 98 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 99 | input_ids = examples_tokenized["input_ids"] 100 | labels = copy.deepcopy(input_ids) 101 | if MASK_INPUT: 102 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 103 | label[:source_len] = IGNORE_INDEX 104 | return dict(input_ids=input_ids, labels=labels) 105 | 106 | 107 | class SupervisedDataset(Dataset): 108 | """Dataset for supervised fine-tuning.""" 109 | 110 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer): 111 | super(SupervisedDataset, self).__init__() 112 | logging.warning("Loading data...") 113 | list_data_dict = ndjson.load(data_path) 114 | 115 | logging.warning("Formatting inputs...") 116 | sources = [example['input'] for example in list_data_dict] 117 | targets = [example['output'] for example in list_data_dict] 118 | 119 | logging.warning("Tokenizing inputs... 
def train():
    """Entry point: parse HF arguments, load model/tokenizer, fine-tune, save."""
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Load the base causal LM to fine-tune.
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # Ensure pad/eos special tokens exist; if any were added, the embedding
    # matrix is resized and new rows initialized to the mean embedding.
    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    # Persist final weights/config to the configured output directory.
    trainer.save_model(output_dir=training_args.output_dir)
#!/bin/bash
# Fine-tune a proofstep model with DeepSpeed.
# REPO_DIR should point at the repo's partI_nextstep directory (it contains
# scripts/ds_config.json and ntp_python/tune.py).
REPO_DIR=/path/to/ntptutorial
TRAIN_FILE=${REPO_DIR}/data/leandojo_benchmark_4/processed/proofstep-train.jsonl
VALID_FILE=${REPO_DIR}/data/leandojo_benchmark_4/processed/proofstep-val.jsonl
MODEL=EleutherAI/pythia-2.8b-deduped
CONFIG=${REPO_DIR}/scripts/ds_config.json

OUTDIR=/path/to/output/ntptutorial/proofstep/${MODEL}

# Bug fix: the training script lives in ntp_python/, not ntp/.
deepspeed --include localhost:0,1,2,3,4,5,6,7 ${REPO_DIR}/ntp_python/tune.py \
    --deepspeed ${CONFIG} \
    --model_name_or_path ${MODEL} \
    --train_data_path ${TRAIN_FILE} \
    --valid_data_path ${VALID_FILE} \
    --fp16 \
    --output_dir ${OUTDIR} \
    --num_train_epochs 10 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "steps" \
    --eval_steps 500 \
    --save_strategy "steps" \
    --save_steps 500 \
    --save_total_limit 1 \
    --learning_rate 1e-5 \
    --load_best_model_at_end 1 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --logging_dir "$OUTDIR" \
    --report_to="tensorboard"