├── .gitignore
├── LICENSE
├── README.md
├── lake-manifest.json
├── lakefile.lean
├── lean-toolchain
├── partII_dsp
├── README.md
├── dsp_utils.py
├── isabelle_setup.md
└── notebooks
│ ├── II_dsp__part1_intro.ipynb
│ ├── II_dsp__part2_dsp.ipynb
│ └── images
│ ├── .gitkeep
│ ├── dsp.png
│ ├── dsp_example.png
│ ├── dsp_plot.png
│ ├── dsp_search.png
│ ├── prove.png
│ ├── sketch.png
│ └── sledgehammer.png
└── partI_nextstep
├── README.md
├── notebooks
├── I_nextstep_lean__part0_intro.ipynb
├── I_nextstep_lean__part1_data.ipynb
├── I_nextstep_lean__part2_learn.ipynb
├── I_nextstep_lean__part3_proofsearch.ipynb
├── I_nextstep_lean__part4_evaluation.ipynb
├── I_nextstep_lean__part5_llmsuggest.ipynb
├── data
│ └── successes_mathlib4_200_wellecks_llmstep-mathlib4-pythia2.8b.json
└── images
│ ├── banach
│ ├── banach_1.png
│ ├── banach_2.png
│ ├── banach_3.png
│ ├── banach_4.png
│ └── banach_5.png
│ ├── leandojo_1.png
│ ├── llmsuggest
│ ├── llmstep_gif.gif
│ ├── llmsuggest.gif
│ └── llmsuggest_examples.png
│ ├── proof_state_1.png
│ ├── proof_state_2.png
│ └── proof_state_3.png
├── ntp_lean
├── ExtractSimple.lean
├── LLMsuggest.lean
└── examples
│ ├── example0.lean
│ └── example_demo.lean
├── ntp_python
├── __init__.py
├── data.py
├── llmsuggest
│ ├── server.py
│ └── suggest.py
├── postprocess_ast.py
├── proofsearch_dojo.py
├── proofsearch_pylean.py
└── tune.py
└── scripts
├── ds_config.json
└── tune_proofstep.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | *model/
2 | *.ast.json
3 | *.dep_paths
4 | lake-packages
5 | *.ckpt
6 | data
7 | generated/
8 |
9 | *.pyc
10 | *.tar.gz
11 | # JetBrains PyCharm IDE
12 | .idea/
13 |
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 | *$py.class
18 |
19 | # C extensions
20 | *.so
21 |
22 | # macOS dir files
23 | .DS_Store
24 |
25 | # Distribution / packaging
26 | .Python
27 | env/
28 | build/
29 | develop-eggs/
30 | dist/
31 | downloads/
32 | eggs/
33 | .eggs/
34 | lib/
35 | lib64/
36 | parts/
37 | sdist/
38 | var/
39 | wheels/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 |
44 | # Checkpoints
45 | checkpoints
46 |
47 | # PyInstaller
48 | # Usually these files are written by a python script from a template
49 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
50 | *.manifest
51 | *.spec
52 |
53 | # Installer logs
54 | pip-log.txt
55 | pip-delete-this-directory.txt
56 |
57 | # Unit test / coverage reports
58 | htmlcov/
59 | .tox/
60 | .coverage
61 | .coverage.*
62 | .cache
63 | nosetests.xml
64 | coverage.xml
65 | *.cover
66 | .hypothesis/
67 |
68 | # Translations
69 | *.mo
70 | *.pot
71 |
72 | # Django stuff:
73 | *.log
74 | local_settings.py
75 |
76 | # Flask stuff:
77 | instance/
78 | .webassets-cache
79 |
80 | # Scrapy stuff:
81 | .scrapy
82 |
83 | # Sphinx documentation
84 | docs/_build/
85 |
86 | # PyBuilder
87 | target/
88 |
89 | # Jupyter Notebook
90 | .ipynb_checkpoints
91 |
92 | # pyenv
93 | .python-version
94 |
95 | # celery beat schedule file
96 | celerybeat-schedule
97 |
98 | # SageMath parsed files
99 | *.sage.py
100 |
101 | # dotenv
102 | .env
103 |
104 | # virtualenv
105 | .venv
106 | venv/
107 | ENV/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Sean Welleck
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A tutorial on neural theorem proving
2 |
3 | *Neural theorem proving* combines neural language models with formal proof assistants.\
4 | This tutorial introduces two research threads in neural theorem proving via interactive Jupyter notebooks.
5 |
6 |
7 | ## Part I : Next-step suggestion
8 |
9 | Builds a neural next-step suggestion tool, introducing concepts and past work in neural theorem proving along the way.
10 |
11 |
12 |
13 | #### Notebooks:
14 | | Topic | Notebook |
15 | |:-----------------------|-------:|
16 | | 0. Intro | [notebook](./partI_nextstep/notebooks/I_nextstep_lean__part0_intro.ipynb) |
17 | | 1. Data | [notebook](./partI_nextstep/notebooks/I_nextstep_lean__part1_data.ipynb) |
18 | | 2. Learning | [notebook](./partI_nextstep/notebooks/I_nextstep_lean__part2_learn.ipynb) |
19 | | 3. Proof Search | [notebook](./partI_nextstep/notebooks/I_nextstep_lean__part3_proofsearch.ipynb) |
20 | | 4. Evaluation | [notebook](./partI_nextstep/notebooks/I_nextstep_lean__part4_evaluation.ipynb) |
21 | | 5. `llmsuggest` | [notebook](./partI_nextstep/notebooks/I_nextstep_lean__part5_llmsuggest.ipynb) |
22 |
23 | All notebooks are in ([`partI_nextstep/notebooks`](./partI_nextstep/notebooks)). See [`partI_nextstep/ntp_python`](./partI_nextstep/ntp_python) and [`partI_nextstep/ntp_lean`](./partI_nextstep/ntp_lean) for the Python and Lean files covered in the notebooks.
24 |
25 | #### Setup:
26 | Please follow the setup instructions in [`partI_nextstep/README.md`](./partI_nextstep/README.md).
27 |
28 | ## Part II : Language cascades
29 | Chain together language models to guide formal proof search with informal proofs.
30 |
31 |
32 | #### Notebooks:
33 | | Topic | Notebook |
34 | |:-----------------------|-------:|
35 | | 1. Language model cascades | [notebook](./partII_dsp/notebooks/II_dsp__part1_intro.ipynb) |
36 | | 2. Draft, Sketch, Prove | [notebook](./partII_dsp/notebooks/II_dsp__part2_dsp.ipynb) |
37 |
38 | All notebooks are in ([`partII_dsp/notebooks`](./partII_dsp/notebooks)).
39 |
40 | #### Setup:
41 | Please follow the setup instructions in [`partII_dsp/README.md`](./partII_dsp/README.md).
42 |
43 |
44 | -------
45 | ### History
46 | These materials were originally developed as part of an IJCAI 2023 tutorial. \
47 | Slides for the 1 hour summary presentation given at IJCAI 2023 are [here](https://wellecks.com/data/welleck2023ntp_tutorial.pdf).
48 |
49 | #### Citation
50 |
51 | If you find this tutorial or repository useful in your work, please cite:
52 | ```
53 | @misc{ntptutorial,
54 | author = {Sean Welleck},
55 | title = {Neural theorem proving tutorial},
56 | year = {2023},
57 | publisher = {GitHub},
58 | journal = {GitHub repository},
59 | howpublished = {\url{https://github.com/wellecks/ntptutorial}},
60 | }
61 | ```
62 |
--------------------------------------------------------------------------------
/lake-manifest.json:
--------------------------------------------------------------------------------
1 | {"version": 4,
2 | "packagesDir": "lake-packages",
3 | "packages":
4 | [{"git":
5 | {"url": "https://github.com/EdAyers/ProofWidgets4",
6 | "subDir?": null,
7 | "rev": "c43db94a8f495dad37829e9d7ad65483d68c86b8",
8 | "name": "proofwidgets",
9 | "inputRev?": "v0.0.11"}},
10 | {"git":
11 | {"url": "https://github.com/leanprover-community/mathlib4.git",
12 | "subDir?": null,
13 | "rev": "3de751e5b518e96b9181328441aa1ab1677d8cf0",
14 | "name": "mathlib",
15 | "inputRev?": null}},
16 | {"git":
17 | {"url": "https://github.com/gebner/quote4",
18 | "subDir?": null,
19 | "rev": "c71f94e34c1cda52eef5c93dc9da409ab2727420",
20 | "name": "Qq",
21 | "inputRev?": "master"}},
22 | {"git":
23 | {"url": "https://github.com/JLimperg/aesop",
24 | "subDir?": null,
25 | "rev": "ca73109cc40837bc61df8024c9016da4b4f99d4c",
26 | "name": "aesop",
27 | "inputRev?": "master"}},
28 | {"git":
29 | {"url": "https://github.com/leanprover/std4",
30 | "subDir?": null,
31 | "rev": "d5471b83378e8ace4845f9a029af92f8b0cf10cb",
32 | "name": "std",
33 | "inputRev?": "main"}}]}
34 |
--------------------------------------------------------------------------------
/lakefile.lean:
--------------------------------------------------------------------------------
1 | import Lake
2 | open Lake DSL
3 |
4 | package «ntptutorial»
5 |
6 | require mathlib from git
7 | "https://github.com/leanprover-community/mathlib4.git"
8 |
--------------------------------------------------------------------------------
/lean-toolchain:
--------------------------------------------------------------------------------
1 | leanprover/lean4:nightly-2023-06-10
2 |
--------------------------------------------------------------------------------
/partII_dsp/README.md:
--------------------------------------------------------------------------------
1 | ## Part II : Language cascades
2 | Chain together language models to guide formal proof search with informal proofs.
3 |
4 |
5 | #### Notebooks:
6 | | Topic | Notebook |
7 | |:-----------------------|-------:|
8 | | 1. Language model cascades | [notebook](./notebooks/II_dsp__part1_intro.ipynb) |
9 | | 2. Draft, Sketch, Prove | [notebook](./notebooks/II_dsp__part2_dsp.ipynb) |
10 |
11 | All notebooks are in ([`partII_dsp/notebooks`](./notebooks)).
12 |
13 | #### Setup
14 | The Draft, Sketch, Prove notebook requires setting up an Isabelle proof checker for the "sketch" stage.
15 |
16 | Please follow this guide to set up the Isabelle Proof Checker: [Isabelle Proof Checker Setup](./isabelle_setup.md)
--------------------------------------------------------------------------------
/partII_dsp/dsp_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import openai
4 | import sys
5 |
6 |
class LMFunction(object):
    """A stochastic text-to-text function backed by the OpenAI chat API.

    Calling `f(prompt, x)` sends `prompt + x` to the chat model and returns
    the generated text, or '' if the API repeatedly fails or the response
    is malformed.
    """

    def __init__(self, engine='gpt-3.5-turbo', max_tokens=512):
        self.engine = engine
        self.max_tokens = max_tokens
        self.openai = openai
        # Fails fast with KeyError if the API key is not configured.
        openai.api_key = os.environ['OPENAI_API_KEY']

    def _call_api(self, prompt, engine, max_tokens, max_retries=10, retry_wait=2):
        """Call the chat-completion API, retrying up to `max_retries` times.

        Returns the raw API response; if every attempt fails, returns a
        response-shaped dict with empty content so callers see '' from f().
        """
        for _ in range(max_retries):
            try:
                return self.openai.ChatCompletion.create(
                    model=engine,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": prompt}
                    ],
                    max_tokens=max_tokens,
                    temperature=1.0
                )
            except self.openai.error.OpenAIError:
                # Transient API error (rate limit, timeout, ...): wait, retry.
                time.sleep(retry_wait)
        # Best-effort fallback: mimic the response structure, empty content.
        return {'choices': [{'message': {'content': ''}}]}

    def _parse_message(self, msg):
        """Extract the generated text from an API response; '' if malformed."""
        try:
            content = msg['choices'][0]['message']['content']
        except (IndexError, KeyError, TypeError):
            # TypeError additionally covers None / non-subscriptable responses,
            # which the original (IndexError, KeyError) pair let propagate.
            content = ''
        return content

    def f(self, prompt, x):
        """Return the model's completion for the concatenation `prompt + x`."""
        msg = self._call_api(
            prompt=prompt + x,
            engine=self.engine,
            max_tokens=self.max_tokens
        )
        return self._parse_message(msg)
45 |
46 |
class Checker(object):
    """A modified version of the Draft, Sketch, Prove proof-checking client.
    (https://github.com/albertqjiang/draft_sketch_prove/blob/main/autoformalization/checker.py)

    This checker supports Isabelle2022 via the new version of PISA
    (https://albertqjiang.github.io/Portal-to-ISAbelle/).

    It supports checking a miniF2F-style proof via `check`.

    Finally, it replaces `sledgehammer` with a call to `normalhammer`.
    """

    def __init__(self, working_dir, isa_path, theory_file, port=9000):
        """Store paths/port and import PISA's python client from $PISA_PATH."""
        sys.path.append(os.environ['PISA_PATH'])
        try:
            from pisa_client import initialise_env
            self.initialise_env = initialise_env
        except ImportError:
            # Only import failures are expected here; anything else should
            # surface to the caller (was a bare `except:` that hid all errors).
            print("Set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python")

        self.working_dir = working_dir
        self.isa_path = isa_path
        self.theory_file = theory_file
        self.port = port

    def _initialize(self):
        """Create a PISA environment pointed at the configured theory file."""
        env = self.initialise_env(
            self.port,
            isa_path=self.isa_path,
            theory_file_path=self.theory_file,
            working_directory=self.working_dir
        )
        return env

    def _exit(self, env):
        """Best-effort shutdown of the PISA environment and Isabelle processes."""
        try:
            env.post('exit')
        except Exception:
            # Shutdown is best-effort; a hung server must not raise here.
            print("env.post('exit') timed out")
        # Kill any lingering Isabelle / poly(ML) processes left behind by PISA.
        os.system("ps aux | grep Isabelle | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")
        os.system("ps aux | grep poly | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")

    def _parse_output(self, obs):
        """Parse the sledgehammer output, otherwise return an empty string."""
        # NOTE(review): the '<hammer>' delimiter was stripped to '' in this
        # copy of the file (likely HTML sanitization), making the test
        # always-true and `split('')` a runtime ValueError. Restored per the
        # upstream DSP checker.
        if '<hammer>' in obs:
            output = obs.split('<hammer>')[0]
        else:
            output = ''
        return output

    def _run_step(self, step, i, tls_name, env):
        """Run one proof step; returns (obs, reward, done, metadata, error)."""
        obs, reward, done, metadata = env.step_to_top_level_state(
            action=step,
            tls_name=tls_name,
            new_name='default_%d' % i
        )
        error = None
        if 'error:' in obs or 'Step error' in obs or 'Unknown error' in obs:
            error = obs
        return obs, reward, done, metadata, error

    def _run_sledgehammer(self, step, i, tls_name, env):
        """Try cheap proof heuristics first; fall back to sledgehammer."""
        # First try heuristics
        for heuristic in ['by auto', 'by simp', 'by blast', 'by fastforce',
                          'by force', 'by eval', 'by presburger', 'by sos',
                          'by arith', 'by linarith', 'by (auto simp: field_simps)']:
            step_ = step.replace('normalhammer', heuristic)
            obs, reward, done, metadata, error = self._run_step(step_, i, tls_name, env)
            if error is None:
                obs = '%s %s' % (heuristic, obs)
                return obs, reward, done, metadata, error
        # Try sledgehammer
        out = self._run_step(step, i, tls_name, env)
        return out

    def check(self, statement_and_proof):
        """Check a full theorem statement + proof; returns a result dict."""
        # Initialize environment
        env = self._initialize()
        env.initialise()

        # Wrap and parse theorem
        theory = Checker.wrap_theorem(statement_and_proof)
        steps = Checker.get_parsed(env, theory)

        result = self._check(env, steps)
        return result

    def _check(self, env, steps):
        """Replay `steps` in the environment, recording per-step results.

        Returns a dict with keys: success, reason, num_steps, last_step,
        step_results, theorem_and_proof.
        """
        done = False
        reason = ''
        success = False
        step_results = []
        tls_name = 'default'
        for i, step in enumerate(steps):
            try:
                time0 = time.time()
                if 'normalhammer' in step:
                    obs, reward, done, metadata, error = self._run_sledgehammer(step, i, tls_name, env)
                else:
                    obs, reward, done, metadata, error = self._run_step(step, i, tls_name, env)
                step_time = time.time() - time0
                step_results.append(dict(index=i, step=step, output=self._parse_output(obs), step_time=step_time))
                if error is not None:
                    reason = error
                    success = False
                    done = False
                    break
            except Exception:
                # Timeout (or any env failure) - end the proof attempt.
                # Deliberately broad: a network/env error should fail the
                # attempt, not crash the caller.
                success = False
                done = False
                reason = 'timeout (%d)' % len(step_results)
                step_results.append(dict(index=i, step=step, output=''))
                break

            # Change when successful
            tls_name = 'default_%d' % i

        if done and reward == 1.0:
            success = True

        result = {
            'success': success,
            'reason': reason,
            'num_steps': len(steps),
            'last_step': len(step_results),
            'step_results': step_results,
            'theorem_and_proof': self.reconstruct(step_results) if success else ''
        }
        # Exit environment
        self._exit(env)
        return result

    @staticmethod
    def reconstruct(step_results):
        """Rebuild the proof text, preferring each step's checker output.

        Skips step_results[0] (presumably the theory header introduced by
        `wrap_theorem` — confirm against get_parsed's output).
        """
        steps = []
        for step_result in step_results[1:]:
            if step_result['output'] != '':
                steps.append(step_result['output'].strip())
            else:
                steps.append(step_result['step'].strip())
        theorem_and_proof = '\n'.join(steps)
        return theorem_and_proof

    @staticmethod
    def wrap_theorem(theorem):
        """Wrap a bare theorem+proof in a theory header with common imports."""
        return 'theory Interactive imports HOL.HOL Complex_Main "HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" "Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" "HOL-Number_Theory.Number_Theory" \n begin\n%s' % theorem

    @staticmethod
    def get_parsed(env, theory, tls_name='default'):
        """Ask PISA to split `theory` into individual proof steps."""
        # HACK: the parsing doesn't work well with `normalhammer`, so we replace
        # all hammer calls with sorry, then replace sorry to normalhammer after parsing.
        theory = theory.replace('sledgehammer', 'sorry')
        theory = theory.replace('normalhammer', 'sorry')

        # NOTE(review): '<parse text>' and '<SEP>' were stripped to '' in this
        # copy (HTML sanitization); restored from the upstream DSP checker.
        steps = env.post(f"<parse text> ${theory}")
        steps = steps.split('<SEP>')
        # remove weird '$' step and whitespace steps (two redundant filters
        # in the original, merged into one equivalent pass)
        steps = [s for s in steps if s != '$' and s.strip() != '']
        steps = [s.replace('sorry', 'normalhammer') for s in steps]
        return steps
207 |
--------------------------------------------------------------------------------
/partII_dsp/isabelle_setup.md:
--------------------------------------------------------------------------------
1 | # Setup: Isabelle Proof Checker
2 |
3 | Follow this guide to set up Isabelle proof checking. At the end, we will have a Python interface for checking a theorem and proof, e.g.
4 | ```python
5 | theorem_and_proof = """theorem ..."""
6 | result = checker.check(theorem_and_proof)
7 | ```
8 |
9 | ## Setup
10 |
11 | Proof checking is done via [PISA](https://github.com/albertqjiang/Portal-to-ISAbelle/tree/56def2c39f85d211e1f40cc5765581a567879106). We implement a client that interacts with PISA (`Checker` in [dsp_utils.py](./dsp_utils.py)).
12 |
13 | Here are setup steps for a non-dockerized environment. The setup is heavily based on the [PISA readme](https://github.com/albertqjiang/Portal-to-ISAbelle/tree/56def2c39f85d211e1f40cc5765581a567879106) and [Dockerfile](https://github.com/albertqjiang/Portal-to-ISAbelle/blob/main/docker/Dockerfile). You may need to refer to those if something goes wrong.
14 |
15 | ### Installation (PISA and Isabelle)
16 | First, we need to set up PISA and Isabelle.
17 | ```bash
18 | # -- PISA setup
19 | # Download Portal-to-ISAbelle (PISA)
20 | cd ~/
21 | git clone https://github.com/albertqjiang/Portal-to-ISAbelle.git
22 |
23 | # Scala installation
24 | sudo apt-get install zip
25 | curl -s "https://get.sdkman.io" | bash
26 | source "$HOME/.sdkman/bin/sdkman-init.sh"
27 | sdk install java 11.0.11-open
28 | sdk install sbt
29 |
30 | # Compile PISA
31 | cd ~/Portal-to-ISAbelle
32 | sbt compile
33 | sbt assembly
34 |
35 | # -- Isabelle setup
36 | # Download Isabelle
37 | wget https://isabelle.in.tum.de/dist/Isabelle2022_linux.tar.gz && \
38 | tar -xzf Isabelle2022_linux.tar.gz
39 |
40 | # Install Isabelle (i.e., move to WORK_DIR, make an alias).
41 | export WORK_DIR=~/
42 | mv Isabelle2022 ${WORK_DIR}/
43 | echo "alias isabelle=${WORK_DIR}/Isabelle2022/bin/isabelle" >> ~/.bashrc
44 | source ~/.bashrc
45 |
46 | # Build Isabelle HOL (creates heaps in ~/.isabelle)
47 | isabelle build -b -D ${WORK_DIR}/Isabelle2022/src/HOL/ -j 20
48 | ```
49 |
50 | At the end, here's what the setup looks like:
51 | - Portal-to-ISAbelle github repo in `~/Portal-to-ISAbelle`
52 | - Isabelle in `~/Isabelle2022`, e.g.
53 | ```
54 | ls ~/Isabelle2022
55 |
56 | => ANNOUNCE bin contrib ...
57 | ```
58 | - Isabelle heaps in `~/.isabelle`, e.g.
59 | ```
60 | ls ~/.isabelle/Isabelle2022/heaps/polyml-5.9_x86_64_32-linux/
61 |
62 | => Group-Ring-Module HOL-Corec_Examples HOL-Isar_Examples ...
63 | ```
64 | You can test out the installation so far by starting a PISA server:
65 | ```bash
66 | cd ~/Portal-to-ISAbelle
67 | sbt "runMain pisa.server.PisaOneStageServer9000"
68 | ```
69 |
70 | The next step is to specify a configuration that allows the Python client to talk to the Scala PISA server, as described below.
71 |
72 | ### Configuration
73 |
74 | At a high-level, we have three components:
75 | 1. The PISA Scala server
76 | 2. The PISA python library
77 | 3. Our python client, [Checker](./dsp_utils.py)
78 |
79 | We need to set environment variables and configuration so that all three can talk to each other.
80 |
81 | #### Set PISA_PATH
82 |
83 | First, set a `PISA_PATH` environment variable that points to PISA's python directory:
84 | ```bash
85 | export PISA_PATH=~/Portal-to-ISAbelle/src/main/python
86 | ```
87 | The variable is used to import PISA's python client (`Portal-to-Isabelle/src/main/python/pisa_client.py`) in Checker. \
88 | This links components 2 and 3.
89 |
90 |
91 | #### Setup a working directory and working file
92 | PISA is initialized by providing a particular working directory and file. \
93 | We will create a file called `Interactive.thy` and put it in the `HOL/Examples` directory:
94 |
95 | ```bash
96 | vim ~/Isabelle2022/src/HOL/Examples/Interactive.thy
97 | ```
98 | ```
99 | theory Interactive
100 | imports Complex_Main
101 | begin
102 |
103 | end
104 | ```
105 | We will use this working directory and file when initializing the checker.
106 |
107 | #### Initializing the checker (in Python, e.g. in the [Draft, Sketch, Prove notebook](./notebooks/II_dsp__part2_dsp.ipynb))
108 |
109 | To initialize the checker, we need to specify the path to Isabelle, the working directory, and the working file (theory file). \
110 | These are used to initialize a working Isabelle instance. This links components 1 and 2.
111 |
112 | Here is an example command found in the [Draft, Sketch, Prove notebook](./notebooks/II_dsp__part2_dsp.ipynb) based on the setup above (here, the home directory `~` is `/home/seanw`):
113 | ```python
114 | checker = dsp_utils.Checker(
115 | working_dir='/home/seanw/Isabelle2022/src/HOL/Examples',
116 | isa_path='/home/seanw/Isabelle2022',
117 | theory_file='/home/seanw/Isabelle2022/src/HOL/Examples/Interactive.thy',
118 | port=9000
119 | )
120 | ```
121 |
122 | #### Start the PISA server
123 | Finally, start a PISA server in a separate tmux window (similar to what was done above in Installation):
124 | ```bash
125 | cd ~/Portal-to-ISAbelle
126 | sbt "runMain pisa.server.PisaOneStageServer9000"
127 | ```
128 | The port specified in the config (here `"port": 9000`) should match the number that appears in the command (`PisaOneStageServer9000`).
129 |
130 | We *leave the server running while running the notebook* (hence, the separate tmux window).
131 |
132 | #### Run the proof checker!
133 | Now try running the proof checker (e.g. in the [Draft, Sketch, Prove notebook](./notebooks/II_dsp__part2_dsp.ipynb))! The notebook uses a call to `checker` that looks like:
134 | ```python
135 | theorem_and_sledgehammer_proof = """theorem gcd_lcm:
136 | assumes "gcd (n :: nat) 4 = 1"
137 | and "lcm (n :: nat) 4 = 28"
138 | shows "n = 7"
139 | proof -
140 | have c1: "1*28 = n*4" using assms
141 | sledgehammer
142 | then have c2: "n = 1*28/4"
143 | sledgehammer
144 | then show ?thesis
145 | sledgehammer
146 | qed"""
147 |
148 | result = checker.check(theorem_and_sledgehammer_proof)
149 |
150 | print("\n==== Success: %s" % result['success'])
151 | print("--- Complete proof:\n%s" % result['theorem_and_proof'])
152 | ```
153 |
--------------------------------------------------------------------------------
/partII_dsp/notebooks/II_dsp__part1_intro.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "#### Language cascades | part 1: introduction\n",
8 | "Tutorial on neural theorem proving\\\n",
9 | "Author: Sean Welleck\n",
10 | "\n",
11 | "----------------\n",
12 | "\n",
13 |     "Tools such as [ChatGPT](https://chat.openai.com/) show the flexibility of modern neural language generators.\\\n",
14 | "Namely, a single generation system can often perform a task by simply providing a suitable *prompt*:\n",
15 | "\n",
16 | "$\\quad y=f(p_\\theta(\\cdot|x;P)),$\n",
17 | "\n",
18 | "where $x$ is an input, $P$ is a prompt, and $f(\\cdot)$ is a decoding algorithm.\n",
19 | "\n",
20 | "\n",
21 | "Let's look at one of these functions:\n"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 1,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "import time\n",
31 | "import os\n",
32 | "import openai\n",
33 | "\n",
34 | "class LMFunction(object):\n",
35 | " def __init__(self, engine='gpt-3.5-turbo', max_tokens=512):\n",
36 | " self.engine = engine\n",
37 | " self.max_tokens = max_tokens\n",
38 | " self.openai = openai\n",
39 | " openai.api_key = os.environ['OPENAI_API_KEY']\n",
40 | "\n",
41 | " def _call_api(self, prompt, engine, max_tokens, max_retries=10, retry_wait=2):\n",
42 | " for i in range(max_retries):\n",
43 | " try:\n",
44 | " return self.openai.ChatCompletion.create(\n",
45 | " model=engine,\n",
46 | " messages=[\n",
47 | " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
48 | " {\"role\": \"user\", \"content\": prompt}\n",
49 | " ],\n",
50 | " max_tokens=max_tokens,\n",
51 | " temperature=1.0\n",
52 | " )\n",
53 | " except self.openai.error.OpenAIError as e:\n",
54 | " time.sleep(retry_wait)\n",
55 | " return {'choices': [{'message': {'content': ''}}]}\n",
56 | "\n",
57 | " def _parse_message(self, msg):\n",
58 | " try:\n",
59 | " content = msg['choices'][0]['message']['content']\n",
60 | " content = content.strip().split('\\n')[0]\n",
61 | " except (IndexError, KeyError):\n",
62 | " content = ''\n",
63 | " return content\n",
64 | "\n",
65 | " def f(self, prompt, x):\n",
66 | " msg = self._call_api(\n",
67 | " prompt=prompt+x,\n",
68 | " engine=self.engine,\n",
69 | " max_tokens=self.max_tokens\n",
70 | " )\n",
71 | " evaluation = self._parse_message(msg)\n",
72 | " return evaluation"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 2,
78 | "metadata": {},
79 | "outputs": [
80 | {
81 | "data": {
82 | "text/plain": [
83 | "[('1947256', False),\n",
84 | " ('1945256', False),\n",
85 | " ('1946256', False),\n",
86 | " ('1947176', True),\n",
87 | " ('1947176', True),\n",
88 | " ('1947056', False),\n",
89 | " ('1947456', False),\n",
90 | " ('1947256', False),\n",
91 | " ('1947256', False),\n",
92 | " ('1947256', False)]"
93 | ]
94 | },
95 | "execution_count": 2,
96 | "metadata": {},
97 | "output_type": "execute_result"
98 | }
99 | ],
100 | "source": [
101 | "prompt = \"\"\"Multiply two numbers. Here are some examples:\n",
102 | "432*342=147744\n",
103 | "98*19=1862\n",
104 | "\"\"\"\n",
105 | "\n",
106 | "p = LMFunction('gpt-4')\n",
107 | "\n",
108 | "outputs = [p.f(prompt, '872*2233=') for _ in range(10)]\n",
109 | "[(output, output==str(872*2233)) for output in outputs]"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 13,
115 | "metadata": {},
116 | "outputs": [
117 | {
118 | "data": {
119 | "text/plain": [
120 | "1947176"
121 | ]
122 | },
123 | "execution_count": 13,
124 | "metadata": {},
125 | "output_type": "execute_result"
126 | }
127 | ],
128 | "source": [
129 | "872*2233"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "The result above shows two interesting things:\n",
137 | "1. The function is stochastic; it can return different answers each time it is called.\n",
138 | "2. The function is capable of producing a correct answer; it gets it correct 2 times.\n",
139 | "\n",
140 | "Therefore, one way a stochastic function like this is useful is to pair it with a reliable verifier.\n",
141 | "\n",
142 | "\n",
143 | "#### Why is this useful?\n",
144 | "The main attraction of these functions is their flexibility. For instance,\n",
145 | "it is easy to implement a function that maps language input to a call to Sympy:\n",
146 | "\n",
147 | "```\n",
148 | " sympy_expression ~ f(\"872 times 2233 =\")\n",
149 | " answer = g(sympy_expression)\n",
150 | "```\n",
151 | "where $g(\\cdot)$ is Sympy evaluation.\n",
152 | "\n",
153 | "This yields functionality that is difficult to program manually:"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": 8,
159 | "metadata": {},
160 | "outputs": [
161 | {
162 | "name": "stdout",
163 | "output_type": "stream",
164 | "text": [
165 | "(800+72)*2233\n",
166 | "1947176\n"
167 | ]
168 | }
169 | ],
170 | "source": [
171 | "from sympy.parsing.sympy_parser import parse_expr\n",
172 | "\n",
173 | "prompt = \"\"\"Solve the multiplication problem by writing input to a sympy function.\n",
174 | "Do not add additional text. Here are some examples:\n",
175 | "\n",
176 | "432 multiplied by 342 is: 432*342\n",
177 | "98* 19 is how much? 98*19\n",
178 | "\"\"\"\n",
179 | "\n",
180 | "p = LMFunction('gpt-4')\n",
181 | "\n",
182 | "g = parse_expr\n",
183 | "\n",
184 | "sympy_expression = p.f(prompt, 'There are 800+72 apples in a barrel. How many apples in 2233 barrels?')\n",
185 | "answer = g(sympy_expression)\n",
186 | "print(sympy_expression)\n",
187 | "print(answer)"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 7,
193 | "metadata": {},
194 | "outputs": [
195 | {
196 | "data": {
197 | "text/plain": [
198 | "1947176"
199 | ]
200 | },
201 | "execution_count": 7,
202 | "metadata": {},
203 | "output_type": "execute_result"
204 | }
205 | ],
206 | "source": [
207 | "872*2233"
208 | ]
209 | },
210 | {
211 | "cell_type": "markdown",
212 | "metadata": {},
213 | "source": [
214 | "#### Language cascades\n",
215 | "\n",
216 | "A [language model cascade [Dohan et al 2022]](https://arxiv.org/abs/2207.10342) formalizes the idea of composing multiple functions, some of which are stochastic functions implemented by a language model.\n",
217 | "\n",
218 | "The result can be seen as a probabilistic program, whose samples \"execute the function\", e.g.:"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": 56,
224 | "metadata": {},
225 | "outputs": [
226 | {
227 | "data": {
228 | "text/latex": [
229 | "$\\displaystyle 4544$"
230 | ],
231 | "text/plain": [
232 | "4544"
233 | ]
234 | },
235 | "execution_count": 56,
236 | "metadata": {},
237 | "output_type": "execute_result"
238 | }
239 | ],
240 | "source": [
241 | "def multiply(x):\n",
242 | " y1 = p.f(prompt, x)\n",
243 | " y2 = g(y1)\n",
244 | " return y2\n",
245 | "\n",
246 | "\n",
247 | "multiply('I bought 32 cases of apples, with one hundred and 42 apples per case. How many total apples?')"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": 57,
253 | "metadata": {},
254 | "outputs": [
255 | {
256 | "data": {
257 | "text/plain": [
258 | "4544"
259 | ]
260 | },
261 | "execution_count": 57,
262 | "metadata": {},
263 | "output_type": "execute_result"
264 | }
265 | ],
266 | "source": [
267 | "32*142"
268 | ]
269 | },
270 | {
271 | "cell_type": "markdown",
272 | "metadata": {},
273 | "source": [
274 | "#### Cascades for neural theorem proving\n",
275 | "\n",
276 | "The two ideas mentioned above: composing multiple functions and using a verifier, make neural theorem proving a natural setting for language cascades.\n",
277 | "\n",
278 | "Namely, the goal will be to decompose theorem proving into different functions, then use the proof assistant to verify the final output.\n",
279 | "\n",
280 | "In the next notebook, we will see a cascade called [Draft, Sketch, Prove [Jiang et al ICLR 2023]](https://arxiv.org/abs/2210.12283) that does so with three components: \\\n",
281 | "**draft** an informal proof, **sketch** a formal proof, and **prove** the remaining gaps.\n",
282 | "\n",
283 | "The end result is a model and proof search procedure that is qualitatively much different than the next-step predictors we used in Part I."
284 | ]
285 | },
286 | {
287 | "cell_type": "markdown",
288 | "metadata": {},
289 | "source": []
290 | }
291 | ],
292 | "metadata": {
293 | "kernelspec": {
294 | "display_name": "Python 3 (ipykernel)",
295 | "language": "python",
296 | "name": "python3"
297 | },
298 | "language_info": {
299 | "codemirror_mode": {
300 | "name": "ipython",
301 | "version": 3
302 | },
303 | "file_extension": ".py",
304 | "mimetype": "text/x-python",
305 | "name": "python",
306 | "nbconvert_exporter": "python",
307 | "pygments_lexer": "ipython3",
308 | "version": "3.10.11"
309 | }
310 | },
311 | "nbformat": 4,
312 | "nbformat_minor": 4
313 | }
314 |
--------------------------------------------------------------------------------
/partII_dsp/notebooks/II_dsp__part2_dsp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "#### Language cascades | part 2: Draft, Sketch, Prove\n",
8 | "Tutorial on neural theorem proving\\\n",
9 | "Author: Sean Welleck\n",
10 | "\n",
11 | "----------------\n",
12 | "\n",
13 | "### High-level goal\n",
14 | "\n",
15 | "This notebook will implement a prototype version of [Draft, Sketch, Prove [Jiang et al ICLR 2023]](https://arxiv.org/pdf/2210.12283.pdf):\n",
16 | "\n",
17 | "
\n",
18 | "\n",
19 | "As pictured above, Draft, Sketch, Prove frames theorem proving as the following procedure. \\\n",
20 | "Given an informal (i.e., Latex) theorem statement $x_I$ and formal theorem statement $x_F$:\n",
21 | "\n",
22 | "1. Generate an *informal* proof $y_{I}\\sim p(\\cdot|x_I,P_{\\text{draft}})$, called a *draft*\n",
23 | "2. Generate a *formal sketch* $z_{F}\\sim p(\\cdot|y_{I}, x_I, x_F, P_{\\text{sketch}})$\n",
24 | "3. Prove the remaining conjectures in the sketch, $y_{F}=f(x_F,z_F)$.\n",
25 | "\n",
26 | "If step $3$ is successful, we will have a verified formal proof of $x_F$. Otherwise we try again. \\\n",
27 | "Conceptually, these steps can be viewed as the following program:"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 1,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "def dsp(xi, xf, f_draft, f_sketch, f_proof):\n",
37 | " yi = f_draft(xi)\n",
38 | " zf = f_sketch(yi, xi, xf)\n",
39 | " yf = f_proof(xf, zf)\n",
40 | " return yf"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {},
46 | "source": [
47 | "Next, we will discuss how to implement these three modules.\n",
48 | "\n",
49 | "We start by introducing the [Isabelle proof assistant](https://isabelle.in.tum.de/), since it is relevant to the implementation.\n",
50 | "\n",
51 | "\n",
52 | "### Isabelle proof assistant\n",
53 | "\n",
54 | "Draft, Sketch, Prove was originally proposed using a proof assistant called [**Isabelle**](https://isabelle.in.tum.de/).\n",
55 | "\n",
56 | "Isabelle has two relevant properties that are helpful to introduce.\n",
57 | "\n",
58 | "#### 1. Declarative proofs\n",
59 | "First, many Isabelle proofs are structured as a sequence of intermediate conjectures (referred to as a *declarative proof*).\n",
60 | "For example, consider the proof below:\n",
61 | "\n",
62 | "\n",
63 | "
"
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {},
69 | "source": [
70 | "In **blue** are the intermediate steps, which include two conjectures:\n",
71 | "1. `c1`, which says that $1*28=n*4$\n",
72 | "2. `c2`, which says that $n=1*28/4$\n",
73 | "\n",
74 | "The final step `then show ?thesis` can be thought of as \"the result follows\".\\\n",
75 | "Intuitively, intermediate conjectures can resemble steps in an informal (latex) proof. \n",
76 | "\n",
77 | "\n",
78 | "Isabelle requires a proof of each step, shown in **green**. These often involve lower-level premises or calls to external automation.\\\n",
79 | "To obtain these proofs we can use Sledgehammer:"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {},
85 | "source": [
86 | "#### 2. Sledgehammer\n",
87 | "\n",
88 | "Isabelle has a powerful automation tool called **[Sledgehammer](https://isabelle.in.tum.de/website-Isabelle2009-1/sledgehammer.html)** for producing proofs similar to those shown in green.\n",
89 | "\n",
90 | "In practice, a user would write an intermediate conjecture, e.g. `have c1: \"...`, then call Sledgehammer to find a proof of the conjecture:\n",
91 | "\n",
92 | "
\n",
93 | "\n",
94 | "\n",
95 | "Under the covers, Sledgehammer calls out to classical provers that excel at short, low-level proofs. However, fully proving complex theorems with Sledgehammer is typically intractable due to the large search space (for instance, Sledgehammer wouldn't produce the `have c1` *statement*, even though it can prove the statement).\n",
96 | "\n",
97 | "Our use of Sledgehammer is implemented in the `Checker` class in `dsp_utils.py`. Please see [isabelle_setup.md](../isabelle_setup.md) to set up the proof checker, and modify the `working_dir`, `isa_path`, and `theory_file` paths below accordingly.\n",
98 | "\n",
99 | "Below, we initialize an Isabelle proof checker that can run Sledgehammer:"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 2,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "import sys\n",
109 | "import os\n",
110 | "sys.path.append('../')\n",
111 | "os.environ['PISA_PATH'] = '/home/seanw/Portal-to-ISAbelle/src/main/python'\n",
112 | "\n",
113 | "import dsp_utils\n",
114 | "\n",
115 | "checker = dsp_utils.Checker(\n",
116 | " working_dir='/home/seanw/Isabelle2022/src/HOL/Examples',\n",
117 | " isa_path='/home/seanw/Isabelle2022',\n",
118 | " theory_file='/home/seanw/Isabelle2022/src/HOL/Examples/Interactive.thy',\n",
119 | " port=9000\n",
120 | ")"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {},
126 | "source": [
127 |     "Now we send the theorem and proof to the checker. If Sledgehammer succeeds at a given step, the result of Sledgehammer is added to the proof, and checking proceeds to the next step.\n",
128 | "\n",
129 | "At the end, we get a completed proof:"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 3,
135 | "metadata": {},
136 | "outputs": [
137 | {
138 | "name": "stdout",
139 | "output_type": "stream",
140 | "text": [
141 | "----------Path to Isabelle source----------\n",
142 | "/home/seanw/Isabelle2022\n",
143 | "----------Path to Isabelle working directory----------\n",
144 | "/home/seanw/Isabelle2022/src/HOL/Examples\n",
145 | "----------Path to Isabelle theory file----------\n",
146 | "/home/seanw/Isabelle2022/src/HOL/Examples/Interactive.thy\n",
147 | "\n",
148 | "==== Success: True\n",
149 | "--- Complete proof:\n",
150 | "theorem gcd_lcm:\n",
151 | " assumes \"gcd (n :: nat) 4 = 1\" \n",
152 | " and \"lcm (n :: nat) 4 = 28\"\n",
153 | " shows \"n = 7\"\n",
154 | "proof -\n",
155 | "have c1: \"1*28 = n*4\"\n",
156 | "using assms\n",
157 | "by (metis prod_gcd_lcm_nat)\n",
158 | "then\n",
159 | "have c2: \"n = 1*28/4\"\n",
160 | "by auto\n",
161 | "then\n",
162 | "show ?thesis\n",
163 | "by auto\n",
164 | "qed\n"
165 | ]
166 | }
167 | ],
168 | "source": [
169 | "theorem_and_sledgehammer_proof = \"\"\"theorem gcd_lcm:\n",
170 | " assumes \"gcd (n :: nat) 4 = 1\" \n",
171 | " and \"lcm (n :: nat) 4 = 28\"\n",
172 | " shows \"n = 7\"\n",
173 | "proof -\n",
174 | " have c1: \"1*28 = n*4\" using assms\n",
175 | " sledgehammer\n",
176 | " then have c2: \"n = 1*28/4\"\n",
177 | " sledgehammer\n",
178 | " then show ?thesis\n",
179 | " sledgehammer\n",
180 | "qed\"\"\"\n",
181 | "\n",
182 | "result = checker.check(theorem_and_sledgehammer_proof)\n",
183 | "\n",
184 | "print(\"\\n==== Success: %s\" % result['success'])\n",
185 | "print(\"--- Complete proof:\\n%s\" % result['theorem_and_proof'])"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {},
191 | "source": [
192 | "Notice that the `sledgehammer` steps are now filled in (e.g. `by (metis prod_gcd_lcm_nat)`)."
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "metadata": {},
198 | "source": [
199 | "## Cascade\n",
200 | "\n",
201 | "The declarative steps (in blue) and the lower-level proofs of those steps (in green) lead to a nice opportunity for a language model cascade.\n",
202 | "\n",
203 | "Namely, Draft, Sketch, Prove uses a neural language model to \"draft\" an informal proof, \"sketch\" the declarative steps based on the informal proof, then attempts to \"close the gaps\" with Sledgehammer (i.e., prove each step). If the gaps are closed, the steps together with their proofs constitute a verified proof of the original theorem (and of course, this is checked by Isabelle).\n",
204 | "\n",
205 | "\n",
206 | "## Draft and Sketch examples\n",
207 | "To implement the drafting and sketching steps, we provide a few examples in the prompt.\n",
208 | "\n",
209 | "We will derive these examples from ones used in the paper:"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 27,
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "import sys\n",
219 | "sys.path.append('../')\n",
220 | "import dsp_utils\n",
221 | "\n",
222 | "import json\n",
223 | "examples = [\n",
224 | " {\"tag\": \"aimeI_2000_p7\", \"category\": \"algebra\", \"metadata\": {}, \"prompt\": \"Informal:\\n(*### Problem\\n\\nSuppose that $x,$ $y,$ and $z$ are three positive numbers that satisfy the equations $xyz = 1,$ $x + \\\\frac {1}{z} = 5,$ and $y + \\\\frac {1}{x} = 29.$ Then $z + \\\\frac {1}{y} = \\\\frac {m}{n},$ where $m$ and $n$ are [[relatively prime]] positive integers. Find $m + n$. Show that it is 5.\\n\\n\\nnote: this is the type of problem that makes you think symmetry, but actually can be solved easily with substitution, and other normal technniques\\n\\n### Solution\\n\\nWe can rewrite $xyz=1$ as $\\\\frac{1}{z}=xy$.\\n\\nSubstituting into one of the given equations, we have \\n$x+xy=5$\\n$x(1+y)=5$\\n$\\\\frac{1}{x}=\\\\frac{1+y}{5}.$\\n\\nWe can substitute back into $y+\\\\frac{1}{x}=29$ to obtain\\n$y+\\\\frac{1+y}{5}=29$\\n$5y+1+y=145$\\n$y=24.$\\n\\nWe can then substitute once again to get\\n$x=\\\\frac15$\\n$z=\\\\frac{5}{24}.$\\nThus, $z+\\\\frac1y=\\\\frac{5}{24}+\\\\frac{1}{24}=\\\\frac{1}{4}$, so $m+n=005$.*)\\n\\nFormal:\\ntheorem\\n fixes x y z :: real\\n and p :: rat\\n assumes \\\"0 < x \\\\ 0 < y \\\\ 0 < z\\\"\\n and \\\"x * y * z = 1\\\"\\n and \\\"x + 1 / z = 5\\\"\\n and \\\"y + 1 / x = 29\\\"\\n and \\\"z + 1 / y = p\\\"\\n and \\\"0 < p\\\" \\n shows \\\"let (m,n) = quotient_of p in m + n = 5\\\"\\nproof -\\n (* We can rewrite $xyz=1$ as $\\\\frac{1}{z}=xy$. 
*)\\n have c0: \\\"z = 1 / (x*y)\\\"\\n sledgehammer\\n (* Substituting into one of the given equations, we have \\n $x+xy=5$\\n $x(1+y)=5$\\n $\\\\frac{1}{x}=\\\\frac{1+y}{5}.$ *)\\n have c1: \\\"1 / x = (1+y) / 5\\\" \\n proof -\\n have \\\"x + x * y = 5\\\" using assms(3) unfolding c0\\n sledgehammer\\n then have \\\"x * (1 + y) = 5\\\"\\n sledgehammer\\n then have t1: \\\"x = 5 / (1+y)\\\"\\n sledgehammer\\n then show ?thesis\\n sledgehammer\\n qed\\n (* We can substitute back into $y+\\\\frac{1}{x}=29$ to obtain\\n $y+\\\\frac{1+y}{5}=29$\\n $5y+1+y=145$\\n $y=24.$ *)\\n have \\\"y + (1+y)/5 = 29\\\" using assms(4) unfolding c1 sledgehammer\\n then have \\\"5* (y + (1+y)/5) = 5 * 29\\\" sledgehammer\\n also have \\\"... = 145\\\" sledgehammer\\n finally have c2_1: \\\"5* (y + (1+y)/5) = 145\\\" sledgehammer\\n have \\\"5* (y + (1+y)/5) = 5*y + (1+y)\\\" sledgehammer\\n also have \\\"... = 6*y + 1\\\" sledgehammer\\n finally have c2_2: \\\"5* (y + (1+y)/5) = 6*y + 1\\\" sledgehammer\\n have \\\"6*y + 1 = 145\\\" using c2_1 c2_2 sledgehammer\\n then have c2: \\\"y = 24\\\" sledgehammer\\n (* We can then substitute once again to get\\n $x=\\\\frac15$\\n $z=\\\\frac{5}{24}.$ *)\\n have \\\"1/x = 5\\\" using c1 unfolding c2 sledgehammer\\n then have c3: \\\"x = 1/5\\\"\\n sledgehammer\\n then have c4: \\\"z = 5/24\\\"\\n sledgehammer\\n (* Thus, $z+\\\\frac1y=\\\\frac{5}{24}+\\\\frac{1}{24}=\\\\frac{1}{4}$, so $m+n=005$. *)\\n have \\\"p = z + 1/y\\\" using assms(5) sledgehammer\\n also have \\\"... = 5/24 + 1/24\\\" unfolding c2 c4 sledgehammer\\n also have \\\"... = 1/4\\\" sledgehammer\\n finally have c5: \\\"p = 1/4\\\"\\n sledgehammer\\n have \\\"quotient_of p = (1, 4)\\\" unfolding c5 sledgehammer\\n then show ?thesis sledgehammer\\nqed\"},\n",
225 | " {\"tag\": \"algebra_2rootsintpoly_am10tap11eqasqpam110\", \"category\": \"algebra\", \"metadata\": {}, \"prompt\": \"Informal:\\n(*### Problem\\n\\nShow that for any complex number a, $(a-10)(a+11) = a^2 + a - 110$.\\n\\n### Solution\\n\\nWe first expand all terms of the left hand side to get $a^2 - 10a + 11a - 10*11$.\\nThis equals $a^2 + a - 10*11 = a^2 + a - 110$.*)\\n\\nFormal:\\ntheorem\\n fixes a :: complex\\n shows \\\"(a-10) * (a+11) = a^2 + a -110\\\"\\nproof -\\n (* We first expand all terms of the left hand side to get $a^2 - 10a + 11a - 10*11$. *)\\n have \\\"(a-10) * (a+11) = a^2 - 10*a + 11*a - 10 *11\\\"\\n sledgehammer\\n (* This equals $a^2 + a - 10*11 = a^2 + a - 110$. *)\\n also have \\\"\\\\ = a^2 + a - 10 * 11\\\"\\n sledgehammer\\n also have \\\"\\\\ = a^2 + a - 110\\\"\\n sledgehammer\\n finally show ?thesis\\n sledgehammer\\nqed\"},\n",
226 | " {\"tag\": \"mathd_numbertheory_335\", \"category\": \"number_theory\", \"metadata\": {}, \"prompt\": \"Informal:\\n(*### Problem\\n\\nWhen Rachel divides her favorite number by 7, she gets a remainder of 5. What will the remainder be if she multiplies her favorite number by 5 and then divides by 7? Show that it is 4.\\n\\n### Solution\\n\\nLet $n$ be Rachel's favorite number. \\nThen $n \\\\equiv 5 \\\\pmod{7}$, so $5n \\\\equiv 5 \\\\cdot 5 \\\\equiv 25 \\\\equiv 4 \\\\pmod{7}$.\\n*)\\n\\nFormal:\\ntheorem\\n fixes n :: nat\\n assumes h0 : \\\"n mod 7 = 5\\\"\\n shows \\\"(5 * n) mod 7 = 4\\\"\\nproof -\\n (* Then $n \\\\equiv 5 \\\\pmod{7}$, so $5n \\\\equiv 5 \\\\cdot 5 \\\\equiv 25 \\\\equiv 4 \\\\pmod{7}$. *)\\n have c0:\\\"(5 * n) mod 7 = (5 * 5) mod 7\\\" using h0\\n sledgehammer\\n then have \\\"\\\\ = 4\\\" sledgehammer\\n then have \\\"(5 * n) mod 7 = 4\\\" using c0 sledgehammer\\n then show ?thesis sledgehammer\\nqed\"}\n",
227 | "]"
228 | ]
229 | },
230 | {
231 | "cell_type": "markdown",
232 | "metadata": {},
233 | "source": [
234 | "## Draft\n",
235 | "\n",
236 | "This function generates an *informal* proof $y_{I}\\sim p(\\cdot|x_I,P_{\\text{draft}})$, called a *draft*.\n",
237 | "\n",
238 | "Here, $P_{\\text{draft}}$ is a prompt containing examples of mapping the informal theorem statement $x_I$ to an informal proof $y_I$:\n"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 51,
244 | "metadata": {},
245 | "outputs": [
246 | {
247 | "name": "stdout",
248 | "output_type": "stream",
249 | "text": [
250 | "Draft an informal solution similar to below. \n",
251 | "The informal solution will be used to sketch a formal Isabelle proof.\n",
252 | "Here are some examples:\n",
253 | "Example:\n",
254 | "Informal:\n",
255 | "(*### Problem\n",
256 | "\n",
257 | "Suppose that $x,$ $y,$ and $z$ are three positive numbers that satisfy the equations $xyz = 1,$ $x + \\frac {1}{z} = 5,$ and $y + \\frac {1}{x} = 29.$ Then $z + \\frac {1}{y} = \\frac {m}{n},$ where $m$ and $n$ are [[relatively prime]] positive integers. Find $m + n$. Show that it is 5.\n",
258 | "\n",
259 | "\n",
260 | "note: this is the type of problem that makes you think symmetry, but actually can be solved easily with substitution, and other normal technniques\n",
261 | "\n",
262 | "### Solution\n",
263 | "\n",
264 | "We can rewrite $xyz=1$ as $\\frac{1}{z}=xy$.\n",
265 | "\n",
266 | "Substituting into one of the given equations, we have \n",
267 | "$x+xy=5$\n",
268 | "$x(1+y)=5$\n",
269 | "$\\frac{1}{x}=\\frac{1+y}{5}.$\n",
270 | "\n",
271 | "We can substitute back into $y+\\frac{1}{x}=29$ to obtain\n",
272 | "$y+\\frac{1+y}{5}=29$\n",
273 | "$5y+1+y=145$\n",
274 | "$y=24.$\n",
275 | "\n",
276 | "We can then substitute once again to get\n",
277 | "$x=\\frac15$\n",
278 | "$z=\\frac{5}{24}.$\n",
279 | "Thus, $z+\\frac1y=\\frac{5}{24}+\\frac{1}{24}=\\frac{1}{4}$, so $m+n=005$.*)\n",
280 | "\n",
281 | "\n",
282 | "\n",
283 | "Example:\n",
284 | "Informal:\n",
285 | "(*### Problem\n",
286 | "\n",
287 | "Show that for any complex number a, $(a-10)(a+11) = a^2 + a - 110$.\n",
288 | "\n",
289 | "### Solution\n",
290 | "\n",
291 | "We first expand all terms of the left hand side to get $a^2 - 10a + 11a - 10*11$.\n",
292 | "This equals $a^2 + a - 10*11 = a^2 + a - 110$.*)\n",
293 | "\n",
294 | "\n",
295 | "\n",
296 | "Example:\n",
297 | "Informal:\n",
298 | "(*### Problem\n",
299 | "\n",
300 | "When Rachel divides her favorite number by 7, she gets a remainder of 5. What will the remainder be if she multiplies her favorite number by 5 and then divides by 7? Show that it is 4.\n",
301 | "\n",
302 | "### Solution\n",
303 | "\n",
304 | "Let $n$ be Rachel's favorite number. \n",
305 | "Then $n \\equiv 5 \\pmod{7}$, so $5n \\equiv 5 \\cdot 5 \\equiv 25 \\equiv 4 \\pmod{7}$.\n",
306 | "*)\n",
307 | "\n",
308 | "\n",
309 | "\n",
310 | "Informal:\n",
311 | "(*### Problem\n",
312 | "\n",
313 | "\n"
314 | ]
315 | }
316 | ],
317 | "source": [
318 | "prompt = \"\"\"Draft an informal solution similar to below. \n",
319 | "The informal solution will be used to sketch a formal Isabelle proof.\n",
320 | "Here are some examples:\n",
321 | "\"\"\"\n",
322 | "for x in examples:\n",
323 | " prompt += (\"Example:\\n\" + x['prompt'][:x['prompt'].find('Formal:')] + \"\\n\\n\")\n",
324 | "prompt += \"\"\"Informal:\n",
325 | "(*### Problem\n",
326 | "\n",
327 | "\"\"\"\n",
328 | "\n",
329 | "print(prompt)"
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": 44,
335 | "metadata": {},
336 | "outputs": [
337 | {
338 | "name": "stdout",
339 | "output_type": "stream",
340 | "text": [
341 | "### Solution\n",
342 | "\n",
343 | "If x is even, then it can be represented as 2n where n is some integer.\n",
344 | "So x + 5 equals 2n + 5.\n",
345 | "We can rewrite 2n + 5 as 2(n + 2) + 1, which is in the form of 2k + 1 where k is an integer (in this case, n + 2). Thus, x + 5 is odd.\n"
346 | ]
347 | }
348 | ],
349 | "source": [
350 | "p = dsp_utils.LMFunction('gpt-4')\n",
351 | "xi = 'Show that if x is even, then x+5 is odd'\n",
352 | "yi = p.f(prompt, xi)\n",
353 | "print(yi)"
354 | ]
355 | },
356 | {
357 | "cell_type": "markdown",
358 | "metadata": {},
359 | "source": [
360 | "## Sketch\n",
361 | "\n",
362 | "Generate a *formal sketch* $z_{F}\\sim p(\\cdot|y_{I}, x_I, x_F, P_{\\text{sketch}})$\n",
363 | "\n",
364 | "Here, $P_{\\text{sketch}}$ is a prompt containing the examples from the drafting step with an additional formal sketch."
365 | ]
366 | },
367 | {
368 | "cell_type": "code",
369 | "execution_count": 55,
370 | "metadata": {},
371 | "outputs": [
372 | {
373 | "name": "stdout",
374 | "output_type": "stream",
375 | "text": [
376 | "Formal:\n",
377 | "proof -\n",
378 | " (* If x is even, then it can be represented as 2n where n is some integer. *)\n",
379 | " obtain n where c1: \"x = 2*n\"\n",
380 | " using evenE assms\n",
381 | " sledgehammer\n",
382 | " (* So x + 5 equals 2n + 5. *)\n",
383 | " then have \"x + 5 = 2*n + 5\" \n",
384 | " sledgehammer\n",
385 | " (* We can rewrite 2n + 5 as 2(n + 2) + 1, which is in the form of 2k + 1 where k is an integer (in this case, n + 2). Thus, x + 5 is odd. *)\n",
386 | " also have \"\\ = 2*(n+2) + 1\"\n",
387 | " sledgehammer\n",
388 | " then have exI: \"\\k. x + 5 = 2*k+1\" \n",
389 | " sledgehammer\n",
390 | " then have \"odd (x+5)\" \n",
391 | " sledgehammer\n",
392 | " then show ?thesis \n",
393 | " sledgehammer\n",
394 | "qed\n"
395 | ]
396 | }
397 | ],
398 | "source": [
399 | "prompt = \"\"\"Translate the informal solution into a sketch of the\n",
400 | "formal Isabelle proof. Add `sledgehammer` in the sketch whenever\n",
401 | "possible. `sledgehammer` will be used to call the automated Sledgehammer prover. \n",
402 | "Here are some examples:\n",
403 | "\"\"\"\n",
404 | "for x in examples:\n",
405 | " prompt += (x['prompt'] + \"\\n\\n\")\n",
406 | "prompt += \"\"\"Informal:\n",
407 | "(*### Problem\n",
408 | "\n",
409 | "\"\"\"\n",
410 | "\n",
411 | "xf = \"\"\"theorem\n",
412 | "fixes x :: int\n",
413 | "assumes h0: \"even x\"\n",
414 | "shows \"odd (x+5)\" \"\"\"\n",
415 | "\n",
416 | "zi = p.f(prompt, xi + '\\n\\n' + yi + '\\n\\n' + xf)\n",
417 | "print(zi)"
418 | ]
419 | },
420 | {
421 | "cell_type": "markdown",
422 | "metadata": {},
423 | "source": [
424 | "## Proof\n",
425 | "\n",
426 | "Finally, we call [Sledgehammer](https://isabelle.in.tum.de/website-Isabelle2009-1/sledgehammer.html) to prove the remaining intermediate conjectures.\n",
427 | "\n",
428 | "You can see the completed proof printed in the cell output:\n"
429 | ]
430 | },
431 | {
432 | "cell_type": "code",
433 | "execution_count": 4,
434 | "metadata": {},
435 | "outputs": [
436 | {
437 | "name": "stdout",
438 | "output_type": "stream",
439 | "text": [
440 | "----------Path to Isabelle source----------\n",
441 | "/home/seanw/Isabelle2022\n",
442 | "----------Path to Isabelle working directory----------\n",
443 | "/home/seanw/Isabelle2022/src/HOL/Examples\n",
444 | "----------Path to Isabelle theory file----------\n",
445 | "/home/seanw/Isabelle2022/src/HOL/Examples/Interactive.thy\n",
446 | "\n",
447 | "==== Success: True\n",
448 | "--- Complete proof:\n",
449 | "theorem\n",
450 | "fixes x :: int\n",
451 | "assumes h0: \"even x\"\n",
452 | "shows \"odd (x+5)\"\n",
453 | "proof -\n",
454 | "(* If x is even, then it can be represented as 2n where n is some integer. *)\n",
455 | "obtain n where c1: \"x = 2*n\"\n",
456 | "using evenE assms\n",
457 | "by auto\n",
458 | "(* So x + 5 equals 2n + 5. *)\n",
459 | "then\n",
460 | "have \"x + 5 = 2*n + 5\"\n",
461 | "by auto\n",
462 | "(* We can rewrite 2n + 5 as 2(n + 2) + 1, which is in the form of 2k + 1 where k is an integer (in this case, n + 2). Thus, x + 5 is odd. *)\n",
463 | "also\n",
464 | "have \"\\ = 2*(n+2) + 1\"\n",
465 | "by auto\n",
466 | "then\n",
467 | "have exI: \"\\k. x + 5 = 2*k+1\"\n",
468 | "using c1 by blast\n",
469 | "then\n",
470 | "have \"odd (x+5)\"\n",
471 | "by presburger\n",
472 | "then\n",
473 | "show ?thesis\n",
474 | "by auto\n",
475 | "qed\n"
476 | ]
477 | }
478 | ],
479 | "source": [
480 | "theorem_with_proof = \"\"\"theorem\n",
481 | "fixes x :: int\n",
482 | "assumes h0: \"even x\"\n",
483 | "shows \"odd (x+5)\"\n",
484 | "proof -\n",
485 | " (* If x is even, then it can be represented as 2n where n is some integer. *)\n",
486 | " obtain n where c1: \"x = 2*n\"\n",
487 | " using evenE assms\n",
488 | " sledgehammer\n",
489 | " (* So x + 5 equals 2n + 5. *)\n",
490 | " then have \"x + 5 = 2*n + 5\" \n",
491 | " sledgehammer\n",
492 | " (* We can rewrite 2n + 5 as 2(n + 2) + 1, which is in the form of 2k + 1 where k is an integer (in this case, n + 2). Thus, x + 5 is odd. *)\n",
493 | " also have \"\\ = 2*(n+2) + 1\"\n",
494 | " sledgehammer\n",
495 | " then have exI: \"\\k. x + 5 = 2*k+1\" \n",
496 | " sledgehammer\n",
497 | " then have \"odd (x+5)\" \n",
498 | " sledgehammer\n",
499 | " then show ?thesis \n",
500 | " sledgehammer\n",
501 | "qed\"\"\"\n",
502 | "\n",
503 | "result = checker.check(theorem_with_proof)\n",
504 | "\n",
505 | "print(\"\\n==== Success: %s\" % result['success'])\n",
506 | "print(\"--- Complete proof:\\n%s\" % result['theorem_and_proof'])"
507 | ]
508 | },
509 | {
510 | "cell_type": "markdown",
511 | "metadata": {},
512 | "source": [
513 |     "We now have a verified formal proof of the claim \"If x is even, then x+5 is odd\", and the proof is annotated with informal proof steps as comments (the text inside of `(*....*)`). Pretty cool!"
514 | ]
515 | },
516 | {
517 | "cell_type": "markdown",
518 | "metadata": {},
519 | "source": [
520 | "## Proof search\n",
521 | "\n",
522 | "In the simple example above, the formal proof was successful on the first try.\n",
523 | "\n",
524 | "In more complex settings we need to try multiple times. \n",
525 | "\n",
526 | "Namely, we can sample multiple drafts, sample multiple formal sketches for each draft, then see if any of them can be successfully proved with Sledgehammer:\n",
527 | "\n",
528 | "
\n",
529 | "\n",
530 | "This proof search algorithm is different from the best-first search in next-step suggestion. Namely, it:\n",
531 | "1. narrows the search space to proofs that are \"similar to\" the informal proof\\*\n",
532 | "2. does not interact with the proof assistant after each step\n",
533 | "\n",
534 | "\\*Though ultimately it is up to the neural language model to decide how to use the informal proof (if at all)."
535 | ]
536 | },
537 | {
538 | "cell_type": "markdown",
539 | "metadata": {},
540 | "source": [
541 | "### Scaling up proof search\n",
542 | "\n",
543 | "The Draft, Sketch, Prove paper shows the effect of scaling up proof search; that is, sampling multiple drafts and/or multiple formal sketches and attempting to verify them.\n",
544 | "\n",
545 | "Namely, proof search with a single sampled sequence yielded less than 100 successful proofs, but scaling to 100 sampled drafts and/or sketches yielded almost 200 successful proofs:\n",
546 | "\n",
547 | "
"
548 | ]
549 | },
550 | {
551 | "cell_type": "markdown",
552 | "metadata": {},
553 | "source": [
554 | "Naturally, a better model would require a smaller search budget when measured in terms of number of calls to the model. For instance, imagine a very good model that proved the same ~200 theorems on its first try. This would require fewer calls to the model than above.\n",
555 | "\n",
556 | "Second, the search algorithm used was fairly naive (temperature sampling that only interacts with the proof assistant once). A better search algorithm could yield more successful proofs, and/or a smaller search budget to reach a given number of successful proofs."
557 | ]
558 | },
559 | {
560 | "cell_type": "markdown",
561 | "metadata": {},
562 | "source": [
563 | "#### Examples\n",
564 | "\n",
565 | "Here is an interesting example from the Draft, Sketch, Prove paper. It shows how the neural model can generate an informal proof that differs from the human-written informal proof. The resulting sketches, and formal proofs produced by the system are much different:\n",
566 | "\n",
567 | "
"
568 | ]
569 | },
570 | {
571 | "cell_type": "markdown",
572 | "metadata": {},
573 | "source": [
574 | "Feel free to play around with the gpt-4 based implementation above to see what it can do."
575 | ]
576 | },
577 | {
578 | "cell_type": "markdown",
579 | "metadata": {},
580 | "source": [
581 | "----------------------\n",
582 | "\n",
583 | "# Other cascades\n",
584 | "\n",
585 | "[Baldur [First et al 2023]](https://arxiv.org/pdf/2303.04910.pdf) develop a **refinement**, or *proof repair*, module to correct proofs using error messages $e$:\n",
586 | "\n",
587 | "- $y_F^1\\sim p(\\cdot|x_F)$\n",
588 | "- $y_F^{2}\\sim p(\\cdot|x_F,y_F^1,e)$\n",
589 | "\n",
590 | "\n",
591 | "[pySagredo [Azerbayev 2023]](https://github.com/zhangir-azerbayev/pySagredo) is an experimental tactic that uses refinement and GPT-4 to prove theorems in Lean4.\n",
592 | "\n",
593 | "\n",
594 | "More generally, integrating modern language models with proof assistants remains an active area of research."
595 | ]
596 | },
597 | {
598 | "cell_type": "markdown",
599 | "metadata": {},
600 | "source": []
601 | }
602 | ],
603 | "metadata": {
604 | "kernelspec": {
605 | "display_name": "Python 3 (ipykernel)",
606 | "language": "python",
607 | "name": "python3"
608 | },
609 | "language_info": {
610 | "codemirror_mode": {
611 | "name": "ipython",
612 | "version": 3
613 | },
614 | "file_extension": ".py",
615 | "mimetype": "text/x-python",
616 | "name": "python",
617 | "nbconvert_exporter": "python",
618 | "pygments_lexer": "ipython3",
619 | "version": "3.10.11"
620 | }
621 | },
622 | "nbformat": 4,
623 | "nbformat_minor": 4
624 | }
625 |
--------------------------------------------------------------------------------
/partII_dsp/notebooks/images/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/.gitkeep
--------------------------------------------------------------------------------
/partII_dsp/notebooks/images/dsp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/dsp.png
--------------------------------------------------------------------------------
/partII_dsp/notebooks/images/dsp_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/dsp_example.png
--------------------------------------------------------------------------------
/partII_dsp/notebooks/images/dsp_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/dsp_plot.png
--------------------------------------------------------------------------------
/partII_dsp/notebooks/images/dsp_search.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/dsp_search.png
--------------------------------------------------------------------------------
/partII_dsp/notebooks/images/prove.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/prove.png
--------------------------------------------------------------------------------
/partII_dsp/notebooks/images/sketch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/sketch.png
--------------------------------------------------------------------------------
/partII_dsp/notebooks/images/sledgehammer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partII_dsp/notebooks/images/sledgehammer.png
--------------------------------------------------------------------------------
/partI_nextstep/README.md:
--------------------------------------------------------------------------------
1 | ## Part I : Next-step suggestion
2 |
3 | Builds a neural next-step suggestion tool, introducing key concepts and overviewing past work in neural theorem proving along the way.
4 |
5 |
6 |
7 | #### Notebooks:
8 | | Topic | Notebook |
9 | |:-----------------------|-------:|
10 | | 0. Intro | [notebook](./notebooks/I_nextstep_lean__part0_intro.ipynb) |
11 | | 1. Data | [notebook](./notebooks/I_nextstep_lean__part1_data.ipynb) |
12 | | 2. Learning | [notebook](./notebooks/I_nextstep_lean__part2_learn.ipynb) |
13 | | 3. Proof Search | [notebook](./notebooks/I_nextstep_lean__part3_proofsearch.ipynb) |
14 | | 4. Evaluation | [notebook](./notebooks/I_nextstep_lean__part4_evaluation.ipynb) |
15 | | 5. `llmsuggest` | [notebook](./notebooks/I_nextstep_lean__part5_llmsuggest.ipynb) |
16 |
17 | All notebooks are in ([`partI_nextstep/notebooks`](./notebooks)). See [`partI_nextstep/ntp_python`](./ntp_python) and [`partI_nextstep/ntp_lean`](./ntp_lean) for the Python and Lean files covered in the notebooks.
18 |
19 | ## Setup
20 | The notebooks use several tools: Lean (in VSCode), `pylean`, and LeanDojo. They also use PyTorch and Hugging Face for language modeling. Below are setup steps:
21 |
22 | ### Setup Lean
23 |
24 | #### Setup Lean in VS Code
25 | To try the interactive tool, you will need VS Code and Lean 4.
26 |
27 | Please follow the [official instructions for installing Lean 4 in VS Code](https://leanprover-community.github.io/install/macos_details.html#installing-and-configuring-an-editor): [Installing and configuring an editor](https://leanprover-community.github.io/install/macos_details.html#installing-and-configuring-an-editor).
28 |
29 |
30 | #### Setup Lean on the command line
31 |
32 | Additionally, to run the notebooks you will need Lean 4 installed on your laptop/server.
33 |
34 | On Linux or in a Colab notebook, run this command:
35 | ```
36 | # from https://leanprover-community.github.io/install/linux.html
37 | curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh
38 | source $HOME/.elan/env
39 | lake
40 | ```
41 |
42 | For MacOS see:
43 | ```
44 | https://leanprover-community.github.io/install/macos.html
45 | ```
46 |
47 | ### Setup [`pylean`](https://github.com/zhangir-azerbayev/repl/tree/master)
48 | The `Proof Search` notebook uses [`pylean`](https://github.com/zhangir-azerbayev/repl/tree/master).
49 |
50 | ```bash
51 | git clone https://github.com/zhangir-azerbayev/repl
52 |
53 | cd repl
54 |
55 | git checkout bddf452deda0df2240b248e651bcc37fb8e59d01
56 |
57 | cd pylean
58 | python setup.py develop
59 | ```
60 |
61 | Then add the following to `repl/lakefile.lean`:
62 | ```
63 | require mathlib from git
64 | "https://github.com/leanprover-community/mathlib4.git" @ "38dbcd8285bc4b1391619c12f158a7409f3dfc12"
65 | ```
66 |
67 |
68 | ### Setup LeanDojo
69 | If you want to reproduce the evaluation discussed in the `Evaluation` notebook and implemented in `proofsearch_dojo.py`, you will need to install Lean Dojo:
70 | ```
71 | pip install lean-dojo==1.1.2
72 | export CONTAINER="native"
73 | ```
74 | The second line is needed to run LeanDojo outside of Docker.
75 |
76 | ### Setup language modeling tools
77 | The notebooks use `pytorch` and Huggingface `transformers` and `datasets`.
78 | Here is a Conda environment file for an environment that this ran on (excluding the additional libraries above):
79 | ```
80 | name: ntp
81 | channels:
82 | - defaults
83 | dependencies:
84 | - python==3.10.11
85 | - pip
86 | - pip:
87 | - accelerate==0.21.0
88 | - datasets==2.13.1
89 | - huggingface-hub==0.15.1
90 | - ndjson==0.3.1
91 | - nest-asyncio==1.5.6
92 | - networkx==3.1
93 | - nh3==0.2.14
94 | - ninja==1.11.1
95 | - numpy==1.25.0
96 | - torch==2.0.1
97 | - tqdm==4.65.0
98 | - transformers==4.31.0
99 | prefix: /home/seanw/.conda/envs/ntp
100 | ```
101 |
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/I_nextstep_lean__part0_intro.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "#### Neural next-step prediction \n",
8 | "Tutorial on neural theorem proving\\\n",
9 | "Author: Sean Welleck\n",
10 | "\n",
11 | "----------------"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "### High-level goal\n",
19 | "\n",
20 | "In *[Generative Language Modeling for Automated Theorem Proving](https://arxiv.org/abs/2009.03393)*, Stanislas Polu and Ilya Sutskever showed that a language model called *gptf* was capable of generating novel proofs that were adopted by a formal mathematics community. Their key insight was to frame *theorem proving as language generation*: first, use a neural language model $p_\\theta(y_t|x_t)$ to model the distribution over potential next-steps to take $y_t$ given the current state of a proof $x_t$, by simply treating $x_t$ and $y_t$ as discrete token sequences; then, use the language model within a search algorithm to generate a full proof.\n",
21 | "\n",
22 |     "Since their work in 2020, this approach—which we refer to as *next-step prediction*—has formed the basis for research on combining language models with formal proof assistants, which we refer to as *neural theorem proving*. \n",
23 | "Neural theorem proving involves *interacting* with a proof assistant in order to generate a *verifiable* formal proof. This differs from traditional language generation tasks such as long-form question answering, which are often performed without interaction and are difficult to reliably evaluate in a scalable way.\n",
24 | "On the other hand, formal code is extremely scarce compared to the natural language found in everyday questions, or the Python code used to benchmark language-model-based code generators.\n",
25 | "As a result, neural theorem proving offers a unique, yet challenging, playground for creative algorithmic development with language models.\n",
26 | "\n",
27 | "\n",
28 | "\n",
29 | "\n",
30 | "Beyond algorithmic development, neural theorem proving offers an opportunity to enable new and useful tools.\n",
31 | "This tutorial walks through building an interactive tool for receiving next-step suggestions from a neural language model. Doing so involves [collecting data (part 1)](./I_nextstep_lean__part1_data.ipynb), [learning a model (part 2)](./I_nextstep_lean__part2_learn.ipynb), measuring performance with [proof search (part 3)](./I_nextstep_lean__part3_proofsearch.ipynb) and [evaluation sets (part 4)](./I_nextstep_lean__part4_evaluation.ipynb), and deploying the model as an [interactive tool (part 5)](./I_nextstep_lean__part5_llmsuggest.ipynb). \n",
32 | "By working through these notebooks, you will see core research ideas and get a practical introduction to neural theorem proving. We hope that doing so lays a foundation for doing research in this area, and that the end product shows a glimpse of \"human-machine collaboration\" that we expect will only expand further in the future."
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "
\n",
40 | "
"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {},
46 | "source": []
47 | }
48 | ],
49 | "metadata": {
50 | "kernelspec": {
51 | "display_name": "Python 3 (ipykernel)",
52 | "language": "python",
53 | "name": "python3"
54 | },
55 | "language_info": {
56 | "codemirror_mode": {
57 | "name": "ipython",
58 | "version": 3
59 | },
60 | "file_extension": ".py",
61 | "mimetype": "text/x-python",
62 | "name": "python",
63 | "nbconvert_exporter": "python",
64 | "pygments_lexer": "ipython3",
65 | "version": "3.10.11"
66 | }
67 | },
68 | "nbformat": 4,
69 | "nbformat_minor": 4
70 | }
71 |
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/I_nextstep_lean__part1_data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "#### Neural next-step prediction | part 1: data\n",
9 | "Tutorial on neural theorem proving\\\n",
10 | "Author: Sean Welleck\n",
11 | "\n",
12 | "----------------"
13 | ]
14 | },
15 | {
16 | "attachments": {},
17 | "cell_type": "markdown",
18 | "metadata": {},
19 | "source": [
20 | "#### High-level goal\n",
21 | "\n",
22 | "Our goal is to train a neural next-step prediction model, $p(y_t|x_t)$. Here $x_t$ is a _proof state_, and $y_t$ is a next-step.\n",
23 | "\n",
24 | "To do so, we will create a dataset $\\mathcal{D}=\\{(x_t,y_t)\\}$ from human-written proofs. \n",
25 | "\n",
26 | "We can then train a neural next-step prediction model using a next-token prediction loss on the dataset."
27 | ]
28 | },
29 | {
30 | "attachments": {},
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "#### Simple example\n",
35 | "\n",
36 | "To see what proof states and next-steps look like, let's look at an example human-written theorem and proof:\n",
37 | "\n"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 1,
43 | "metadata": {},
44 | "outputs": [
45 | {
46 | "name": "stdout",
47 | "output_type": "stream",
48 | "text": [
49 | "import Mathlib.Data.Nat.Prime\n",
50 | "\n",
51 | "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by \n",
52 | " rw [Nat.coprime] at h \n",
53 | " exact h "
54 | ]
55 | }
56 | ],
57 | "source": [
58 | "!cat ../ntp_lean/examples/example0.lean"
59 | ]
60 | },
61 | {
62 | "attachments": {},
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "We would like to transform this theorem and proof into a sequence of (proof_state, next_step) examples.\n",
67 | "\n",
68 | "First, notice that the proof has two steps:\n",
69 | "\n",
70 | "1. $y_1=$ `rw [Nat.coprime] at h`\n",
71 | "2. $y_2=$ `exact h`"
72 | ]
73 | },
74 | {
75 | "attachments": {},
76 | "cell_type": "markdown",
77 | "metadata": {},
78 | "source": [
79 | "We can manually see the proof states by looking in VSCode. \n",
80 | "\n",
81 | "For example, placing the cursor before $y_1$ gives us the proof state $x_1$ (shown as \"Tactic state\"):"
82 | ]
83 | },
84 | {
85 | "attachments": {},
86 | "cell_type": "markdown",
87 | "metadata": {},
88 | "source": [
89 | ""
90 | ]
91 | },
92 | {
93 | "attachments": {},
94 | "cell_type": "markdown",
95 | "metadata": {},
96 | "source": [
97 | "That is, the image above corresponds to $(x_1,y_1)$ defined as:\n",
98 | "\n",
99 | " $x_1$: \n",
100 | " ```\n",
101 | " m n : ℕ\n",
102 | " h : Nat.coprime m n\n",
103 | " ⊢ Nat.gcd m n = 1\n",
104 | " ```\n",
105 | "\n",
106 | " $y_1$: `rw [Nat.coprime] at h`\n",
107 | "\n",
108 | "\n",
109 | "Similarly, we can get the proof state $x_2$ prior to the step $y_2$ (`exact h`):\n",
110 | "\n",
111 | "\n",
112 | "\n",
113 | "After step $y_2$, the proof is complete: the proof state $x_3$ says we have \"No goals\":\n",
114 | "\n",
115 | "\n",
116 | "\n",
117 | "In summary, it is possible to *manually* transform the theorem and proof into a sequence $[(x_1,y_1),(x_2,y_2),(x_3)]$."
118 | ]
119 | },
120 | {
121 | "attachments": {},
122 | "cell_type": "markdown",
123 | "metadata": {},
124 | "source": [
125 | "## Automatically extracting proof states and next-steps \n",
126 | "\n",
127 | "To scale up data collection, we need a way to *automatically* extract proof states and next-steps from human-written proofs.\n",
128 | "\n",
129 | "\n",
130 | "\n",
131 | "A new open-source library by Kaiyu Yang et al. called [LeanDojo](https://leandojo.org/) can automatically extract (proof state, next-step) pairs from Lean proofs. This idea originated in [Han et al ICLR 2022](https://github.com/jesse-michael-han/lean-step-public). We will look at a simplified version of what LeanDojo does.\n",
132 | "\n",
133 |     "The core idea is to (1) transform a Lean file into abstract syntax trees using Lean, and (2) postprocess the abstract syntax tree into a dataset. Lean4's powerful metaprogramming functionality gives us the tools to do this."
134 | ]
135 | },
136 | {
137 | "attachments": {},
138 | "cell_type": "markdown",
139 | "metadata": {},
140 | "source": [
141 | "#### 1. Transform a Lean file\n",
142 | "\n",
143 | "Conceptually, we want a script:\n",
144 | "\n",
145 | "$\\quad f_{\\text{extract}}(\\text{lean file})\\rightarrow \\text{ASTs}$,\n",
146 | "\n",
147 | "We run a simplified version of the script `ExtractData.lean` from LeanDojo:\n",
148 | ""
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 2,
154 | "metadata": {},
155 | "outputs": [
156 | {
157 | "name": "stdout",
158 | "output_type": "stream",
159 | "text": [
160 | "Input file: partI_nextstep/ntp_lean/examples/example0.lean\n",
161 | "AST: partI_nextstep/ntp_lean/examples/example0.ast.json\n"
162 | ]
163 | }
164 | ],
165 | "source": [
166 | "!cd ../../ && lake env lean --run partI_nextstep/ntp_lean/ExtractSimple.lean partI_nextstep/ntp_lean/examples/example0.lean"
167 | ]
168 | },
169 | {
170 | "attachments": {},
171 | "cell_type": "markdown",
172 | "metadata": {},
173 | "source": [
174 | "The output file `example.ast.json` includes proof states and abstract syntax trees for the commands in `example0.lean`.\n",
175 | "\n",
176 | "Here are the proof states for our example:"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": 3,
182 | "metadata": {},
183 | "outputs": [
184 | {
185 | "data": {
186 | "text/plain": [
187 | "[{'stateBefore': 'm n : ℕ h : Nat.coprime m n ⊢ Nat.gcd m n = 1',\n",
188 | " 'stateAfter': 'm n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1',\n",
189 | " 'pos': 101,\n",
190 | " 'endPos': 122},\n",
191 | " {'stateBefore': 'm n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1',\n",
192 | " 'stateAfter': 'no goals',\n",
193 | " 'pos': 127,\n",
194 | " 'endPos': 134}]"
195 | ]
196 | },
197 | "execution_count": 3,
198 | "metadata": {},
199 | "output_type": "execute_result"
200 | }
201 | ],
202 | "source": [
203 | "import json\n",
204 | "ast = json.load(open('../../partI_nextstep/ntp_lean/examples/example0.ast.json'))\n",
205 | "ast['tactics']"
206 | ]
207 | },
208 | {
209 | "attachments": {},
210 | "cell_type": "markdown",
211 | "metadata": {},
212 | "source": [
213 | "Notice that the proof states are the ones we saw above in VSCode.\n",
214 | "\n",
215 | "Here is the theorem statement's abstract syntax tree:"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 4,
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "name": "stdout",
225 | "output_type": "stream",
226 | "text": [
227 | "{'args': [{'node': {'args': [...],\n",
228 | " 'info': 'none',\n",
229 | " 'kind': 'Lean.Parser.Command.declModifiers'}},\n",
230 | " {'node': {'args': [...],\n",
231 | " 'info': 'none',\n",
232 | " 'kind': 'Lean.Parser.Command.theorem'}}],\n",
233 | " 'info': 'none',\n",
234 | " 'kind': 'Lean.Parser.Command.declaration'}\n"
235 | ]
236 | }
237 | ],
238 | "source": [
239 | "import pprint\n",
240 | "pprint.pprint(ast['commandASTs'][1]['node'], depth=4)"
241 | ]
242 | },
243 | {
244 | "attachments": {},
245 | "cell_type": "markdown",
246 | "metadata": {},
247 | "source": [
248 | "#### Post-processing\n",
249 | "\n",
250 | "Next, we post-process the extracted data into a dataset:\n",
251 | "\n",
252 | "$\\quad f_{\\text{post-process}}(\\text{ASTs}, \\text{lean file})\\rightarrow \\{(x_t,y_t)\\}.$\n",
253 | "\n",
254 | "To do so, we use the collected proof states, traverse the AST, and recover the next-steps from the original Lean file.\\\n",
255 | "See `ntp_python.postprocess_ast` for an example (naive) traversal which extracts the theorem name.\n",
256 | "\n",
257 | "Postprocessing `example0.lean` in this way gives us two $(x_t,y_t)$ pairs:"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": 5,
263 | "metadata": {},
264 | "outputs": [
265 | {
266 | "name": "stdout",
267 | "output_type": "stream",
268 | "text": [
269 | "Theorem: theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 ...\n",
270 | "--- x1 ---\n",
271 | "m n : ℕ h : Nat.coprime m n ⊢ Nat.gcd m n = 1\n",
272 | "--- y1 ---\n",
273 | "rw [Nat.coprime] at h\n",
274 | "\n",
275 | "--- x2 ---\n",
276 | "m n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1\n",
277 | "--- y2 ---\n",
278 | "exact h\n",
279 | "\n"
280 | ]
281 | }
282 | ],
283 | "source": [
284 | "import sys\n",
285 | "sys.path.append('../')\n",
286 | "from ntp_python.postprocess_ast import get_theorem\n",
287 | "from collections import defaultdict\n",
288 | "\n",
289 | "theorem2examples = defaultdict(list)\n",
290 | "\n",
291 | "lean_file = open('../../partI_nextstep/ntp_lean/examples/example0.lean').read()\n",
292 | "for item in ast['tactics']:\n",
293 | " theorem = get_theorem(item['pos'], ast)\n",
294 | " theorem2examples[theorem].append({\n",
295 | " 'x': item['stateBefore'],\n",
296 | " 'y': lean_file[item['pos']:item['endPos']],\n",
297 | " })\n",
298 | "\n",
299 | "for theorem, examples in theorem2examples.items():\n",
300 | " print(\"Theorem: \", theorem[:60], '...', sep=' ')\n",
301 | " for t, example in enumerate(examples):\n",
302 | " print(f\"--- x{t+1} ---\", example['x'], sep='\\n')\n",
303 | " print(f\"--- y{t+1} ---\", example['y'], sep='\\n')\n",
304 | " print()"
305 | ]
306 | },
307 | {
308 | "attachments": {},
309 | "cell_type": "markdown",
310 | "metadata": {},
311 | "source": [
312 | "The core extraction code in LeanDojo is in [ExtractData.lean](https://github.com/lean-dojo/LeanDojo/blob/main/src/lean_dojo/data_extraction/ExtractData.lean) if you are curious.\n",
313 | "\n",
314 | "## Scaling up data collection\n",
315 | "In general, Lean projects are more complex than the simple example above. For instance, projects may:\n",
316 | "1. have a large number of files\n",
317 | "2. have dependencies on other files or projects\n",
318 | "3. have complex file structure that our naive postprocessing doesn't handle\n",
319 | "\n",
320 |     "An example is the [mathlib project](https://leanprover-community.github.io/mathlib-overview.html). Mathlib itself changes rapidly, and other Lean projects may depend on specific versions. [LeanDojo](https://leandojo.readthedocs.io/en/latest/index.html) gives tools for handling this complexity."
321 | ]
322 | },
323 | {
324 | "attachments": {},
325 | "cell_type": "markdown",
326 | "metadata": {},
327 | "source": [
328 | "#### Extracting 90k+ theorems with LeanDojo\n",
329 | "\n",
330 | "The LeanDojo tool allows for extracting data from an *arbitrary Lean Github repository*. Conceptually,\n",
331 | "\n",
332 | "$\\quad f_{\\text{leandojo}}(\\text{lean repository})\\rightarrow \\mathcal{D}.$\n",
333 | "\n",
334 | "It supports parallelism, keeps track of versions and dependencies for extracted data, and its post-processing handles more complex scenarios."
335 | ]
336 | },
337 | {
338 | "attachments": {},
339 | "cell_type": "markdown",
340 | "metadata": {},
341 | "source": [
342 | "**Example**\\\n",
343 | "Here is what the interface would look like for [extracting a dataset from Mathlib4](https://github.com/lean-dojo/LeanDojo/blob/main/scripts/generate-benchmark-lean4.ipynb):\n",
344 | "\n",
345 | "```python\n",
346 | " URL = \"https://github.com/leanprover-community/mathlib4\"\n",
347 | " COMMIT = \"5a919533f110b7d76410134a237ee374f24eaaad\"\n",
348 | " repo = LeanGitRepo(URL, COMMIT)\n",
349 | " traced_repo = trace(repo)\n",
350 | "```\n",
351 | "\n",
352 | "To avoid possible dependency issues, we won't run LeanDojo directly here. However, the LeanDojo authors provide the extracted data online, so we will download it for this tutorial:"
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "execution_count": 6,
358 | "metadata": {},
359 | "outputs": [
360 | {
361 | "name": "stdout",
362 | "output_type": "stream",
363 | "text": [
364 | "Number of non-empty training proofs: 41944\n",
365 | "{'commit': '5a919533f110b7d76410134a237ee374f24eaaad',\n",
366 | " 'end': [308, 76],\n",
367 | " 'file_path': 'Mathlib/Analysis/BoxIntegral/Box/Basic.lean',\n",
368 | " 'full_name': 'BoxIntegral.Box.withBotCoe_inj',\n",
369 | " 'start': [307, 1],\n",
370 | " 'traced_tactics': [{'state_after': 'no goals',\n",
371 | " 'state_before': 'ι : Type u_1\\n'\n",
372 | " 'I✝ J✝ : Box ι\\n'\n",
373 | " 'x y : ι → ℝ\\n'\n",
374 | " 'I J : WithBot (Box ι)\\n'\n",
375 | " '⊢ ↑I = ↑J ↔ I = J',\n",
376 | " 'tactic': 'simp only [Subset.antisymm_iff, ← '\n",
377 | " 'le_antisymm_iff, withBotCoe_subset_iff]'}],\n",
378 | " 'url': 'https://github.com/leanprover-community/mathlib4'}\n"
379 | ]
380 | }
381 | ],
382 | "source": [
383 | "import json\n",
384 | "import sys\n",
385 | "import pprint\n",
386 | "sys.path.append('../')\n",
387 | "from ntp_python.data import _download_and_unpack\n",
388 | "\n",
389 | "_download_and_unpack(\n",
390 | " tarball_url='https://zenodo.org/record/8040110/files/leandojo_benchmark_4_v1.tar.gz',\n",
391 | " data_dir='../data',\n",
392 | " overwrite=False\n",
393 | ")\n",
394 | "\n",
395 | "train = json.load(open('../data/leandojo_benchmark_4/random/train.json'))\n",
396 | "train = [x for x in train if len(x['traced_tactics']) > 0]\n",
397 | "print(\"Number of non-empty training proofs: \", len(train), sep=' ')\n",
398 | "pprint.pprint(train[0])"
399 | ]
400 | },
401 | {
402 | "attachments": {},
403 | "cell_type": "markdown",
404 | "metadata": {},
405 | "source": [
406 | "#### Next steps\n",
407 | "In part 2, we'll train a neural next-step generation model on this mathlib4 dataset."
408 | ]
409 | }
410 | ],
411 | "metadata": {
412 | "kernelspec": {
413 | "display_name": "Python 3 (ipykernel)",
414 | "language": "python",
415 | "name": "python3"
416 | },
417 | "language_info": {
418 | "codemirror_mode": {
419 | "name": "ipython",
420 | "version": 3
421 | },
422 | "file_extension": ".py",
423 | "mimetype": "text/x-python",
424 | "name": "python",
425 | "nbconvert_exporter": "python",
426 | "pygments_lexer": "ipython3",
427 | "version": "3.10.11"
428 | }
429 | },
430 | "nbformat": 4,
431 | "nbformat_minor": 4
432 | }
433 |
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/I_nextstep_lean__part2_learn.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "#### Neural next-step prediction | part 2: learning\n",
9 | "Tutorial on neural theorem proving\\\n",
10 | "Author: Sean Welleck\n",
11 | "\n",
12 | "----------------"
13 | ]
14 | },
15 | {
16 | "attachments": {},
17 | "cell_type": "markdown",
18 | "metadata": {},
19 | "source": [
20 | "#### High-level goal\n",
21 | "\n",
22 | "Our goal is to train a neural next-step predictor $p_\\theta(y_t|x_t)$ on the dataset that we collected in the previous notebook.\n",
23 | "\n",
24 | "To do so, we will fine-tune a pretrained language model on the dataset $\\mathcal{D}=\\{(x_t,y_t)\\}$ using the standard supervised fine-tuning approach:\n",
25 | "\n",
26 | "$$\n",
27 | "\\max_\\theta \\sum_{(x_t,y_t)\\in \\mathcal{D}}-\\log p_\\theta(y_t|x_t).\n",
28 | "$$\n",
29 | "\n",
30 | "That is, we maximize the conditional likelihood of a next-step $y_t$ given the context $x_t$. \\\n",
31 | "This corresponds to minimizing a cross-entropy loss at each position of the next-step, $\\sum_{\\ell=1}^{{|y_t|}}-\\log p_\\theta(y_t^\\ell|y_t^{<\\ell})$."
32 | ]
33 | },
34 | {
35 | "attachments": {},
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "### Implementation\n",
40 | "\n",
41 | "The implementation consists of two steps:\n",
42 | "\n",
43 | "1. **Data formatting** ([data.py](../ntp_python/data.py)): formatting the examples.\n",
44 | "2. **Tuning** ([tune.py](../ntp_python/tune.py)): using a standard language model fine-tuning script.\n",
45 | "\n"
46 | ]
47 | },
48 | {
49 | "attachments": {},
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "#### 1. Data formatting\n",
54 | "\n",
55 | "We format each (tactic-state, next-step) pair $(x_t, y_t)$ as:\n",
56 | "\n",
57 | " [GOAL]tacticstate[PROOFSTEP]next-step<|endoftext|>\n",
58 | "\n",
59 | "Here, `[GOAL]...[PROOFSTEP]` is the input and `next-step<|endoftext|>` is the output.\n",
60 | "\n",
61 | "This format comes from [Han et al ICLR 2022]: \\\n",
62 | "[Proof Artifact Co-training for Theorem Proving with Language Models](https://arxiv.org/pdf/2102.06203.pdf).\n",
63 | "\n",
64 | "\n",
65 | "\n",
66 | ""
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 1,
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "name": "stdout",
76 | "output_type": "stream",
77 | "text": [
78 | "Saving split to disk...\n"
79 | ]
80 | },
81 | {
82 | "name": "stderr",
83 | "output_type": "stream",
84 | "text": [
85 | "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.87it/s]\n"
86 | ]
87 | },
88 | {
89 | "name": "stdout",
90 | "output_type": "stream",
91 | "text": [
92 | "train\t169530\n",
93 | "val\t4053\n",
94 | "test\t3606\n"
95 | ]
96 | }
97 | ],
98 | "source": [
99 | "import sys\n",
100 | "sys.path.append('../ntp_python')\n",
101 | "import data\n",
102 | "\n",
103 | "datasets = data.proofstep(\n",
104 | " data_dir='../data'\n",
105 | ")"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 2,
111 | "metadata": {},
112 | "outputs": [
113 | {
114 | "name": "stdout",
115 | "output_type": "stream",
116 | "text": [
117 | "Input:\n",
118 | "[GOAL]ι : Type u_1\n",
119 | "I✝ J✝ : Box ι\n",
120 | "x y : ι → ℝ\n",
121 | "I J : WithBot (Box ι)\n",
122 | "⊢ ↑I = ↑J ↔ I = J[PROOFSTEP]\n",
123 | "\n",
124 | "Output:\n",
125 | "simp only [Subset.antisymm_iff, ← le_antisymm_iff, withBotCoe_subset_iff]<|endoftext|>\n"
126 | ]
127 | }
128 | ],
129 | "source": [
130 | "example = datasets['train'][0]\n",
131 | "print(\"Input:\", example['input'], '', sep='\\n')\n",
132 | "print(\"Output:\", example['output'], sep='\\n')"
133 | ]
134 | },
135 | {
136 | "attachments": {},
137 | "cell_type": "markdown",
138 | "metadata": {},
139 | "source": [
140 |     "#### 2. Tuning\n",
141 | "\n",
142 | "We minimally adapt a standard language-model fine-tuning script from [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py). \n",
143 | "\n",
144 | "You can check out the full script at [partI_nextstep/ntp_python/tune.py](../ntp_python/tune.py). \\\n",
145 | "See [partI_nextstep/scripts/tune_proofstep.sh](../scripts/tune_proofstep.sh) for a command that trains on 8 GPUs with deepspeed."
146 | ]
147 | },
148 | {
149 | "attachments": {},
150 | "cell_type": "markdown",
151 | "metadata": {},
152 | "source": [
153 | "Here's an example command for training a 1.4b model on 1 GPU (and you can adjust the model size to be smaller to fit your compute constraints):"
154 | ]
155 | },
156 | {
157 | "attachments": {},
158 | "cell_type": "markdown",
159 | "metadata": {},
160 | "source": [
161 | "```bash\n",
162 | " REPO_DIR=\"..\"\n",
163 | " TRAIN_FILE=${REPO_DIR}/data/leandojo_benchmark_4/processed/proofstep-train.jsonl\n",
164 | " VALID_FILE=${REPO_DIR}/data/leandojo_benchmark_4/processed/proofstep-val.jsonl\n",
165 | " MODEL=EleutherAI/pythia-1.4b-deduped\n",
166 | "\n",
167 | " OUTDIR=${REPO_DIR}/model/${MODEL}\n",
168 | "\n",
169 | " python ../ntp_python/tune.py \\\n",
170 | " --model_name_or_path ${MODEL} \\\n",
171 | " --train_data_path ${TRAIN_FILE} \\\n",
172 | " --valid_data_path ${VALID_FILE} \\\n",
173 | " --fp16 \\\n",
174 | " --output_dir ${OUTDIR} \\\n",
175 | " --num_train_epochs 10 \\\n",
176 | " --per_device_train_batch_size 4 \\\n",
177 | " --per_device_eval_batch_size 4 \\\n",
178 | " --gradient_accumulation_steps 16 \\\n",
179 | " --evaluation_strategy \"steps\" \\\n",
180 | " --eval_steps 500 \\\n",
181 | " --save_strategy \"steps\" \\\n",
182 | " --save_steps 500 \\\n",
183 | " --save_total_limit 1 \\\n",
184 | " --learning_rate 1e-5 \\\n",
185 | " --load_best_model_at_end 1 \\\n",
186 | " --weight_decay 0. \\\n",
187 | " --warmup_ratio 0.03 \\\n",
188 | " --lr_scheduler_type \"cosine\" \\\n",
189 | " --logging_steps 10 \\\n",
190 | " --logging_dir \"$OUTDIR\" \\\n",
191 | " --report_to=\"tensorboard\"\n",
192 | "\n",
193 | "```"
194 | ]
195 | },
196 | {
197 | "attachments": {},
198 | "cell_type": "markdown",
199 | "metadata": {},
200 | "source": [
201 | "#### After training\n",
202 | "\n",
203 | "If everything went well, you should have a model in `../model/{MODEL_NAME}/checkpoint-{BEST_STEP}`.\n",
204 | "\n",
205 | "We have fine-tuned an `EleutherAI/pythia-2.8b-deduped` model that can be accessed through HuggingFace ([link](https://huggingface.co/wellecks/llmstep-mathlib4-pythia2.8b)):"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": null,
211 | "metadata": {},
212 | "outputs": [],
213 | "source": [
214 | "import transformers\n",
215 | "\n",
216 | "MODEL = 'wellecks/llmstep-mathlib4-pythia2.8b'\n",
217 | "model = transformers.GPTNeoXForCausalLM.from_pretrained(MODEL)\n",
218 | "tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(MODEL)"
219 | ]
220 | },
221 | {
222 | "cell_type": "markdown",
223 | "metadata": {},
224 | "source": [
225 | "You can use your own model by setting `MODEL = \"../model/{MODEL_NAME}/checkpoint-{BEST_STEP}\"` \\\n",
226 | "(e.g., `../model/EleutherAI/pythia-2.8b-deduped/checkpoint-5000`)."
227 | ]
228 | },
229 | {
230 | "attachments": {},
231 | "cell_type": "markdown",
232 | "metadata": {},
233 | "source": [
234 | "Let's generate a next-step suggestion for the proof state from our original example:\n",
235 | "\n",
236 | "```lean\n",
237 | " theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1\n",
238 | "```\n",
239 |     "Recall from the previous notebook that the initial proof state $x_0$ is:\n",
240 | "\n",
241 | " m n : ℕ\n",
242 | " h : Nat.coprime m n\n",
243 | " ⊢ Nat.gcd m n = 1"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 4,
249 | "metadata": {},
250 | "outputs": [
251 | {
252 | "name": "stdout",
253 | "output_type": "stream",
254 | "text": [
255 | "rw [← h.gcd_eq_one]\n"
256 | ]
257 | }
258 | ],
259 | "source": [
260 | "prompt = \"\"\"[GOAL]m n : ℕ\n",
261 | " h : Nat.coprime m n\n",
262 | " ⊢ Nat.gcd m n = 1[PROOFSTEP]\"\"\"\n",
263 | "\n",
264 | "input_ids = tokenizer.encode(prompt, return_tensors='pt')\n",
265 | "out = model.generate(\n",
266 | " input_ids,\n",
267 | " max_new_tokens=256,\n",
268 | " pad_token_id=tokenizer.eos_token_id\n",
269 | ")\n",
270 | "text = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)\n",
271 | "print(text)"
272 | ]
273 | },
274 | {
275 | "attachments": {},
276 | "cell_type": "markdown",
277 | "metadata": {},
278 | "source": [
279 | "### Next steps\n",
280 | "\n",
281 | "In the next notebook, we will prove theorems with the trained model by interacting with the Lean proof assistant.\n",
282 | "\n",
283 | "This will let us automatically check whether a generated proof (e.g., one containing the step above) is correct.\n",
284 | "\n",
285 | "Later on, we will build a VSCode plugin that returns next-step suggestions from the language model."
286 | ]
287 | },
288 | {
289 | "cell_type": "markdown",
290 | "metadata": {},
291 | "source": []
292 | }
293 | ],
294 | "metadata": {
295 | "kernelspec": {
296 | "display_name": "Python 3 (ipykernel)",
297 | "language": "python",
298 | "name": "python3"
299 | },
300 | "language_info": {
301 | "codemirror_mode": {
302 | "name": "ipython",
303 | "version": 3
304 | },
305 | "file_extension": ".py",
306 | "mimetype": "text/x-python",
307 | "name": "python",
308 | "nbconvert_exporter": "python",
309 | "pygments_lexer": "ipython3",
310 | "version": "3.10.11"
311 | }
312 | },
313 | "nbformat": 4,
314 | "nbformat_minor": 4
315 | }
316 |
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/I_nextstep_lean__part3_proofsearch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "#### Neural next-step prediction | part 3: proof search\n",
8 | "Tutorial on neural theorem proving\\\n",
9 | "Author: Sean Welleck\n",
10 | "\n",
11 | "----------------\n",
12 | "\n",
13 | "#### High-level goal\n",
14 | "\n",
15 | "Our next goal is to prove theorems with our neural next-step predictor, and check whether the theorems are correct.\n",
16 | "\n",
17 | "Proving and checking a theorem involves generating a next-step candidate with our model, giving it to Lean, and receiving a next state from Lean (or an error message). \\\n",
18 | "To do so, we will need two components:\n",
19 | "\n",
20 | "1. **Interacting** with Lean: an automated way to give a next-step to Lean and receive a next state (or an error).\n",
21 | "\n",
22 | "2. A **search strategy** that uses the next-step model and Lean to find a proof (e.g. generate one next-step, get the next state, repeat).\n",
23 | "\n",
24 | "\n",
25 | "Below, we'll walk through a simple example of each. "
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "-------------------\n",
33 | "\n",
34 | "### 1. Interaction\n",
35 | "\n",
36 | "To start, we'll walk through proving this theorem:"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "```lean4\n",
44 | "import Mathlib.Data.Nat.Prime\n",
45 | "\n",
46 | "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by \n",
47 | " rw [Nat.coprime] at h \n",
48 | " exact h \n",
49 | "```"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "#### Interaction with `pylean`\n",
57 | "\n",
58 | "The [`pylean`](https://github.com/zhangir-azerbayev/repl/tree/master) library gives us a lightweight interface to a lean REPL.\n",
59 | "\n",
60 | "We can pass `pylean` the import and theorem statement:"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 1,
66 | "metadata": {},
67 | "outputs": [
68 | {
69 | "name": "stdout",
70 | "output_type": "stream",
71 | "text": [
72 | "{'env': 0,\n",
73 | " 'messages': [{'data': 'unsolved goals\\n'\n",
74 | " 'm n : ℕ\\n'\n",
75 | " 'h : Nat.coprime m n\\n'\n",
76 | " '⊢ Nat.gcd m n = 1',\n",
77 | " 'endPos': {'column': 69, 'line': 4},\n",
78 | " 'pos': {'column': 68, 'line': 4},\n",
79 | " 'severity': 'error'}],\n",
80 | " 'sorries': []}\n"
81 | ]
82 | }
83 | ],
84 | "source": [
85 | "from pylean import LeanServer\n",
86 | "from pprint import pprint\n",
87 | "\n",
88 | "code = \"\"\"\n",
89 | "import Mathlib.Data.Nat.Prime\n",
90 | "\n",
91 | "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by {}\n",
92 | "\"\"\"\n",
93 | "\n",
94 | "lean = LeanServer()\n",
95 | "state = lean.run_code(code)\n",
96 | "lean.proc.close()\n",
97 | "pprint(state)"
98 | ]
99 | },
100 | {
101 | "cell_type": "markdown",
102 | "metadata": {},
103 | "source": [
104 | "We see that inside of `'data'`, `pylean` gives us the current proof state $x_t$; here's basic parsing code:"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 2,
110 | "metadata": {},
111 | "outputs": [
112 | {
113 | "name": "stdout",
114 | "output_type": "stream",
115 | "text": [
116 | "m n : ℕ\n",
117 | "h : Nat.coprime m n\n",
118 | "⊢ Nat.gcd m n = 1\n"
119 | ]
120 | }
121 | ],
122 | "source": [
123 | "def get_goal(state):\n",
124 | " goal = None\n",
125 | " for msg in state['messages']:\n",
126 | " if msg['data'].startswith('unsolved goals\\n'):\n",
127 | " goal = '\\n'.join(msg['data'].split('\\n')[1:])\n",
128 | " elif msg['severity'] == 'error':\n",
129 | " return None\n",
130 | " return goal\n",
131 | "\n",
132 | "print(get_goal(state))"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {},
138 | "source": [
139 | "We can use $x_t$ as input to our model $p_\\theta(y_t|x_t)$.\\\n",
140 | "Next, we load the trained model and generate a next step, $\\hat y_t\\sim q(p_\\theta(y_t|x_t))$.\n",
141 | "\n",
142 | "(Here $q(\\cdot)$ is a decoding algorithm such as greedy decoding or temperature sampling.)"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "# Load model and tokenizer\n",
152 | "import os\n",
153 | "import transformers\n",
154 | "model_name = 'wellecks/llmstep-mathlib4-pythia2.8b'\n",
155 | "model = transformers.GPTNeoXForCausalLM.from_pretrained(model_name)\n",
156 | "tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(model_name)\n",
157 | "os.environ['TOKENIZERS_PARALLELISM'] = 'false' # prevents an annoying warning\n",
158 | "\n",
159 | "\n",
160 | "def generate(prompt):\n",
161 | " input_ids = tokenizer.encode(prompt, return_tensors='pt')\n",
162 | " out = model.generate(\n",
163 | " input_ids,\n",
164 | " max_new_tokens=256,\n",
165 | " pad_token_id=tokenizer.eos_token_id\n",
166 | " )\n",
167 | " text = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)\n",
168 | " return text"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 4,
174 | "metadata": {},
175 | "outputs": [
176 | {
177 | "name": "stdout",
178 | "output_type": "stream",
179 | "text": [
180 | "rw [← h.gcd_eq_one]\n"
181 | ]
182 | }
183 | ],
184 | "source": [
185 | "# Generate a next step\n",
186 | "prompt = f\"[GOAL]{get_goal(state)}[PROOFSTEP]\"\n",
187 | "\n",
188 | "next_step = generate(prompt)\n",
189 | "print(next_step)"
190 | ]
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "metadata": {},
195 | "source": [
196 | "Finally, we can give the generated next step to Lean and receive the next state."
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 5,
202 | "metadata": {},
203 | "outputs": [
204 | {
205 | "name": "stdout",
206 | "output_type": "stream",
207 | "text": [
208 | "{'env': 0, 'messages': [], 'sorries': []}\n"
209 | ]
210 | }
211 | ],
212 | "source": [
213 | "code = \"\"\"\n",
214 | "import Mathlib.Data.Nat.Prime\n",
215 | "\n",
216 | "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by \n",
217 | "\n",
218 | "\"\"\" + next_step\n",
219 | "\n",
220 | "lean = LeanServer()\n",
221 | "state = lean.run_code(code)\n",
222 | "lean.proc.close()\n",
223 | "\n",
224 | "pprint(state)"
225 | ]
226 | },
227 | {
228 | "cell_type": "markdown",
229 | "metadata": {},
230 | "source": [
231 | "There are no error messages, and no remaining goals - the proof is complete! If you want, paste this into VS Code to convince yourself that it's complete:\n",
232 | "\n",
233 | "```lean4\n",
234 | "import Mathlib.Data.Nat.Prime\n",
235 | "\n",
236 | "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by\n",
237 | " rw [← h.gcd_eq_one]\n",
238 | "```\n",
239 | "\n",
240 | "Also, notice that the machine-generated proof is different from the human-written one shown at the start of this section."
241 | ]
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "metadata": {},
246 | "source": [
247 | "-----------------\n",
248 | "\n",
249 | "### 2. Search strategy\n",
250 | "\n",
251 | "In the proof above, we simply generated one next step and the proof was complete.\n",
252 | "\n",
253 | "In general, proofs are multiple steps. Therefore we need an algorithm for generating a multi-step proof, which we refer to as a *search algorithm*.\n",
254 | "\n",
255 | "\n",
256 | "First, let's consider a naive algorithm that generates a next step, then continues to the next state. Upon receiving an error message\n",
257 | "the algorithm generates another next step."
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": 6,
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 | "import sys\n",
267 | "sys.path.append('../ntp_python/')\n",
268 | "\n",
269 | "import proofsearch_pylean as proofsearch # some utilities for running code (as we did above) and parsing states/model outputs"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 7,
275 | "metadata": {},
276 | "outputs": [
277 | {
278 | "name": "stdout",
279 | "output_type": "stream",
280 | "text": [
281 | "== Current (0): \n",
282 | "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n",
283 | "\n",
284 | "-- Goal: \n",
285 | "a b c : ℕ\n",
286 | "⊢ a + b = c → a ≤ c\n",
287 | "\n",
288 | "== Current (1): \n",
289 | "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n",
290 | "rintro rfl\n",
291 | "-- Goal: \n",
292 | "a b : ℕ\n",
293 | "⊢ a ≤ a + b\n",
294 | "\n",
295 | "== Current (2): \n",
296 | "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n",
297 | "rintro rfl\n",
298 | "exact le_add_left _ _\n",
299 | "-- Error: backtracking\n",
300 | "-- Goal: \n",
301 | "a b : ℕ\n",
302 | "⊢ a ≤ a + b\n",
303 | "\n",
304 | "== Current (3): \n",
305 | "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n",
306 | "rintro rfl\n",
307 | "apply Nat.le_add_right sperr a\n",
308 | "-- Error: backtracking\n",
309 | "-- Goal: \n",
310 | "a b : ℕ\n",
311 | "⊢ a ≤ a + b\n",
312 | "\n",
313 | "== Current (4): \n",
314 | "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n",
315 | "rintro rfl\n",
316 | "apply Nat.le_add_right\n",
317 | "\n",
318 | "SUCCESS!\n",
319 | "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n",
320 | " rintro rfl\n",
321 | " apply Nat.le_add_right\n"
322 | ]
323 | }
324 | ],
325 | "source": [
326 | "transformers.set_seed(43)\n",
327 | "\n",
328 | "def prove_simple(model, tokenizer, header, theorem_statement, search_budget):\n",
329 | " success = False\n",
330 | "\n",
331 | " code = header + theorem_statement\n",
332 | " steps = []\n",
333 | " proof = ''\n",
334 | "\n",
335 | " for i in range(search_budget):\n",
336 | " print(\"== Current (%d): \" % i, theorem_statement[:-3] + '\\n' + proof, sep='\\n')\n",
337 | "\n",
338 | " # Run the code (header + proof-so-far)\n",
339 | " state = proofsearch.run_code(code)\n",
340 | " \n",
341 | " # Stop if the proof is complete.\n",
342 | " if proofsearch.is_done(state):\n",
343 | " success = True\n",
344 | " break\n",
345 | "\n",
346 | " # Get the new state.\n",
347 | " goal_candidate = proofsearch.get_goal(state)\n",
348 | " if goal_candidate is None:\n",
349 | " print(\"-- Error: backtracking\")\n",
350 | " steps = steps[:-1]\n",
351 | " else:\n",
352 | " goal = goal_candidate\n",
353 | "\n",
354 | " print(\"-- Goal: \", goal, sep='\\n')\n",
355 | "\n",
356 | " # Generate a next-step\n",
357 | " prompt = f\"[GOAL]{goal}[PROOFSTEP]\"\n",
358 | " texts, _= proofsearch.generate(prompt, model, tokenizer, temperatures=[0.5], num_samples=1)\n",
359 | " step = proofsearch.parse_step(texts[0])\n",
360 | "\n",
361 | " # Add the next-step to the proof-so-far\n",
362 | " steps.append(step)\n",
363 | " proof = '\\n'.join(steps)\n",
364 | " code = header + theorem_statement.replace(\" {}\", \"\") + '\\n' + proof\n",
365 | " print()\n",
366 | "\n",
367 | " if success:\n",
368 | " print(\"\\nSUCCESS!\")\n",
369 | " else:\n",
370 | " print(\"\\nFAILED\")\n",
371 | " \n",
372 | " print(theorem_statement.replace(\" {}\", \"\"))\n",
373 | " print (' ' + proof.replace('\\n', '\\n '))\n",
374 | " \n",
375 | " return {'theorem_statement': theorem_statement, 'proof': proof, 'success': success}\n",
376 | "\n",
377 | "\n",
378 | "header = \"\"\"\n",
379 | "import Mathlib.Data.Nat.Prime\n",
380 | "\n",
381 | "\"\"\"\n",
382 | "theorem_statement = \"\"\"theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}\"\"\"\n",
383 | "\n",
384 | "\n",
385 | "out = prove_simple(\n",
386 | " model, \n",
387 | " tokenizer,\n",
388 | " header, \n",
389 | " theorem_statement, \n",
390 | " search_budget=100\n",
391 | ")"
392 | ]
393 | },
394 | {
395 | "cell_type": "markdown",
396 | "metadata": {},
397 | "source": [
398 | "Above (setting `seed = 43` for reproducibility) the model generates `rintro rfl`. \\\n",
399 | "Next it generates `exact le_add_left _ _`, which receives an error, so the model tries again (backtracks). \\\n",
400 | "After backtracking one more time, the model generates `apply Nat.le_add_right` and the proof is complete."
401 | ]
402 | },
403 | {
404 | "cell_type": "markdown",
405 | "metadata": {},
406 | "source": [
407 | "### Best-first search\n",
408 | "\n",
409 | "Typically a less naive search procedure is used. These searches are usually variants of a tree search, in which nodes are states and edges are next-steps. \n",
410 | "\n",
411 | "The most common search in neural theorem proving is *best-first search*. This search:\n",
412 | "\n",
413 | "- generates multiple next-step suggestions to form (proof-so-far + next-step) *candidates*\n",
414 | "- scores all candidates so far\n",
415 | "- selects the highest scoring candidate\n",
416 | "\n",
417 | "A typical scoring function is the model's log probability, $\\log p_\\theta(y_t|x_t)$, summed across steps. Next-steps that lead to an error receive a score of $-\\infty$ (in practice, we discard these steps). In the literature, the scoring function is called a *value function* $v(y_{\\leq t}, x_t)$.\n",
418 | "\n",
419 | "#### Intuition\n",
420 | "\n",
421 | "A key idea is generating multiple suggestions at each step, ${y_t^{(1)},\\ldots,y_t^{(k)}}\\sim p_\\theta(\\cdot|x_t)$. Intuitively, the goal is to select a next-step that will lead to a correct proof. In general, we do not know whether a next-step will lead to a correct proof, so we use a heuristic value function for selecting a next-step.\n",
422 | "\n",
423 | "Here's what multiple suggestions and their (normalized) log-probabilities look like in our example:"
424 | ]
425 | },
426 | {
427 | "cell_type": "code",
428 | "execution_count": 8,
429 | "metadata": {},
430 | "outputs": [
431 | {
432 | "name": "stdout",
433 | "output_type": "stream",
434 | "text": [
435 | "-0.277\trw [Nat.coprime, gcd_comm] at h\n",
436 | "-0.279\trw [← h.gcd_eq_one]\n",
437 | "-0.335\tapply Nat.eq_one_of_dvd_dvd\n",
438 | "-0.349\trw [Nat.coprime] at h\n",
439 | "-0.350\trw [gcd_comm]\n"
440 | ]
441 | }
442 | ],
443 | "source": [
444 | "prompt = '[GOAL]m n : ℕ\\nh : Nat.coprime m n\\n⊢ Nat.gcd m n = 1[PROOFSTEP]'\n",
445 | "texts, scores = proofsearch.generate(prompt, model, tokenizer, temperatures=[0.0], num_samples=5)\n",
446 | "for text, score in zip(texts, scores):\n",
447 | " print('%.3f' % score, text, sep='\\t')"
448 | ]
449 | },
450 | {
451 | "cell_type": "markdown",
452 | "metadata": {},
453 | "source": [
454 | "### Implementation\n",
455 | "\n",
456 | "A minimal implementation of best first search is available in `proofsearch_pylean.py`.\\\n",
457 | "A version that uses LeanDojo for interaction is in `proofsearch_dojo.py`.\n",
458 | "\n",
459 | "We will use these in the next notebook to evaluate our model on a set of evaluation theorems.\\\n",
460 | "Below, we run best first search and print out the search trajectory:"
461 | ]
462 | },
463 | {
464 | "cell_type": "code",
465 | "execution_count": 9,
466 | "metadata": {},
467 | "outputs": [
468 | {
469 | "name": "stdout",
470 | "output_type": "stream",
471 | "text": [
472 | "--- current:\n",
473 | "\ttheorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by \n",
474 | "\t\n"
475 | ]
476 | },
477 | {
478 | "name": "stderr",
479 | "output_type": "stream",
480 | "text": [
481 | "100%|██████████| 4/4 [00:03<00:00, 1.10it/s]\n"
482 | ]
483 | },
484 | {
485 | "name": "stdout",
486 | "output_type": "stream",
487 | "text": [
488 | "--- type-checked candidates:\n",
489 | "\t(-0.066) rintro rfl\n",
490 | "\t(-0.307) rintro ⟨rfl, rfl⟩\n",
491 | "\t(-0.035) intro h\n",
492 | "\t(-0.230) rintro ⟨d, rfl⟩\n",
493 | "--- current:\n",
494 | "\ttheorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by \n",
495 | "\tintro h\n"
496 | ]
497 | },
498 | {
499 | "name": "stderr",
500 | "output_type": "stream",
501 | "text": [
502 | "100%|██████████| 4/4 [00:03<00:00, 1.11it/s]\n"
503 | ]
504 | },
505 | {
506 | "name": "stdout",
507 | "output_type": "stream",
508 | "text": [
509 | "--- type-checked candidates:\n",
510 | "\t(-0.172) apply le_of_add_le_add_right\n",
511 | "\t(-0.093) rw [← h]\n",
512 | "\t(-0.453) cases c\n",
513 | "--- current:\n",
514 | "\ttheorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by \n",
515 | "\trintro rfl\n"
516 | ]
517 | },
518 | {
519 | "name": "stderr",
520 | "output_type": "stream",
521 | "text": [
522 | "100%|██████████| 4/4 [00:03<00:00, 1.10it/s]"
523 | ]
524 | },
525 | {
526 | "name": "stdout",
527 | "output_type": "stream",
528 | "text": [
529 | "--- type-checked candidates:\n",
530 | "\t(-0.109) apply Nat.le_add_right\n",
531 | "\t(-0.173) exact Nat.le_add_right _ _\n"
532 | ]
533 | },
534 | {
535 | "name": "stderr",
536 | "output_type": "stream",
537 | "text": [
538 | "\n"
539 | ]
540 | },
541 | {
542 | "data": {
543 | "text/plain": [
544 | "{'theorem_statement': 'theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}',\n",
545 | " 'proof': ['rintro rfl', 'apply Nat.le_add_right'],\n",
546 | " 'state': {'sorries': [], 'messages': [], 'env': 0},\n",
547 | " 'score': 0.1747819110751152,\n",
548 | " 'success': True}"
549 | ]
550 | },
551 | "execution_count": 9,
552 | "metadata": {},
553 | "output_type": "execute_result"
554 | }
555 | ],
556 | "source": [
557 | "proofsearch.best_first_search(\n",
558 | " model, tokenizer, header, theorem_statement, \n",
559 | " max_iters=32,\n",
560 | " num_samples=4,\n",
561 | " temperatures=[0.0],\n",
562 | " verbose=True\n",
563 | ")"
564 | ]
565 | },
566 | {
567 | "cell_type": "markdown",
568 | "metadata": {},
569 | "source": [
570 | "The search selects a candidate trajectory, and generates 4 next-step suggestions.\\\n",
571 | "`intro h` is selected at the first step. The best expansion of `intro h` has score -0.093. \\\n",
572 | "This is less than the score of `rintro rfl` (-0.066), so `rintro rfl` is picked. This is backtracking, since `intro h` is no longer in the proof.\\\n",
573 | "Then `apply Nat.le_add_right` is suggested and the proof is complete.\n"
574 | ]
575 | },
576 | {
577 | "cell_type": "markdown",
578 | "metadata": {},
579 | "source": [
580 | "--------------------\n",
581 | "\n",
582 | "\n",
583 | "## Extensions\n",
584 | "\n",
585 | "Several works have proposed to improve the search strategy, either with a learned value function or a sophisticated search:\n",
586 | "\n",
587 | "- [Polu & Sutskever 2020](https://arxiv.org/pdf/2009.03393.pdf) propose to learn a value function $v(y_{\\leq t}, x_t)$ that estimates the probability of successfully proving the theorem with the model $p_\\theta$ starting at state $x_t$. To do so, they use proof search trajectories obtained by doing proof search with the model.\n",
588 | "\n",
589 | "- [Polu et al ICLR 2023](https://openreview.net/pdf?id=-P7G-8dmSh4) train the value function to predict the eventual length of the proof (or 0 if it is predicted to fail). The learned value function improves pass rate by ~10\\% on mathlib theorems compared to log-probability, with a ~1\\% improvement over the learned value function from [Polu & Sutskever 2020].\n",
590 | "\n",
591 | "- [Lample et al NeurIPS 2022](https://openreview.net/pdf?id=J4pX8Q8cxHH) propose a sophisticated MCTS-like search that explores multiple trajectories in parallel, collecting statistics on visited states in order to prioritize search trajectories.\n",
592 | "\n",
593 | "Reproducing, analyzing, and improving the search algorithm remains an open area for future work in neural theorem proving (for instance, these works were not open-sourced).\n",
594 | "\n",
595 | "Search algorithms are also an active area of research in LLMs, including methods like [tree-of-thought](https://arxiv.org/abs/2305.10601), [stepwise beam search](https://arxiv.org/pdf/2205.12910.pdf), [self-consistency](https://arxiv.org/pdf/2203.11171.pdf), and search with [learned stepwise verifiers](https://arxiv.org/pdf/2305.20050.pdf). In theorem proving, the final output is verifiable, but the quality of intermediate steps is difficult to evaluate."
596 | ]
597 | }
598 | ],
599 | "metadata": {
600 | "kernelspec": {
601 | "display_name": "Python 3 (ipykernel)",
602 | "language": "python",
603 | "name": "python3"
604 | },
605 | "language_info": {
606 | "codemirror_mode": {
607 | "name": "ipython",
608 | "version": 3
609 | },
610 | "file_extension": ".py",
611 | "mimetype": "text/x-python",
612 | "name": "python",
613 | "nbconvert_exporter": "python",
614 | "pygments_lexer": "ipython3",
615 | "version": "3.10.11"
616 | }
617 | },
618 | "nbformat": 4,
619 | "nbformat_minor": 4
620 | }
621 |
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/I_nextstep_lean__part5_llmsuggest.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "#### Neural next-step prediction | part 5: `llmsuggest` co-pilot \n",
9 | "Tutorial on neural theorem proving\\\n",
10 | "Author: Sean Welleck\n",
11 | "\n",
12 | "----------------"
13 | ]
14 | },
15 | {
16 | "attachments": {},
17 | "cell_type": "markdown",
18 | "metadata": {},
19 | "source": [
20 | "#### High-level goal\n",
21 | "\n",
22 | "Finally, we will see how our neural model can act as a helpful \"co-pilot\" when we are writing proofs. \\\n",
23 | "We will make an interactive tool that uses our neural next-step model to suggest next-steps in VSCode.\n",
24 | "\n",
25 | "Concretely, we'll create a `llmsuggest` tactic (in essence, a function) that displays generated suggestions in VSCode. `llmsuggest` is a minimal version of the [**llmstep** [Welleck & Saha 2023]](https://github.com/wellecks/llmstep) tactic, aimed at learning and building off of.\n",
26 | "\n",
27 | "Here is a preview of `llmsuggest`:"
28 | ]
29 | },
30 | {
31 | "attachments": {},
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "
\n"
36 | ]
37 | },
38 | {
39 | "attachments": {},
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "On the top, the user entered `llmsuggest`, then suggestions from our next-step prediction model appear in the Lean Infoview. Clicking a suggestion adds it to the proof."
44 | ]
45 | },
46 | {
47 | "attachments": {},
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "-----------------------\n",
52 | "\n",
53 | "### High-level approach\n",
54 | "\n",
55 | "Implementing `llmsuggest` involves three components:\n",
56 | "\n",
57 | "1. A Lean *tactic* that sends the current state to a Python script.\n",
58 | "2. A Python script that sends the current state to a server via a POST request.\n",
59 | "3. A Python server that runs our next-step suggestion model on the current state.\n",
60 | "\n",
61 | "The suggestions (3) are sent back to (2), and the tactic (1) displays the result in VSCode.\n",
62 | "\n",
63 | "\n",
64 | "### Implementing `llmsuggest`"
65 | ]
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "metadata": {},
70 | "source": [
71 | "#### 1. Tactic\n",
72 | "\n",
73 | "At a technical level, the proofs we have seen are actually sequences of *tactics*. \n",
74 | "For instance, `intro` is a tactic and `rw [...]` is a tactic. In general, a *tactic* is a Lean program that manipulates a state. \n",
75 | "\n",
76 | "To build a new tactic, we use *Lean metaprogramming*, which gives us tools to define new syntax, access the proof state, and more. \\\n",
77 | "`llmsuggest` only requires basic metaprogramming. To learn more, see the [Lean 4 metaprogramming book](https://github.com/leanprover-community/lean4-metaprogramming-book/tree/master).\n",
78 | "\n",
79 | "`llmsuggest` is implemented in `ntp_lean/LLMsuggest.lean`. The main definition specifies the syntax (i.e., `\"llmsuggest\"`), then defines the tactic. \\\n",
80 | "You can see below that the tactic gets the main goal (the \"tactic state\"), pretty-prints it, and converts it to a string.\n",
81 | "Then it runs a `runSuggest` function, and passes the output to an `addSuggestions` function:\n",
82 | "\n",
83 | "```lean\n",
84 | "-- `llmsuggest` tactic.\n",
85 | "syntax \"llmsuggest\" : tactic\n",
86 | "elab_rules : tactic\n",
87 | " | `(tactic | llmsuggest%$tac) =>\n",
88 | " Lean.Elab.Tactic.withMainContext do\n",
89 | " let goal ← Lean.Elab.Tactic.getMainGoal\n",
90 | " let ppgoal ← Lean.Meta.ppGoal goal\n",
91 | " let ppgoalstr := toString ppgoal\n",
92 | " let suggest ← runSuggest #[ppgoalstr]\n",
93 | " addSuggestions tac $ suggest.splitOn \"[SUGGESTION]\"\n",
94 | "```\n",
95 | "\n",
96 | "The `runSuggest` function calls a Python script (step 2 above), and the `addSuggestions` uses a Lean widget to display the results in VSCode. \\\n",
97 | "We won't look at these in detail, but please see `ntp_lean/LLMsuggest.lean` if you are curious. \\\n",
98 | "Hopefully with a small amount of effort, you can modify the tactic or make your own in the future.\n",
99 | "\n",
100 | "\n",
101 | "#### 2. Python script\n",
102 | "\n",
103 | "The `runSuggest` function in the tactic calls a Python script, `ntp_python/llmsuggest/suggest.py`. It passes the current tactic state as a command line argument.\\\n",
104 | "The script is simple: it sends a POST request containing the current tactic state to a server:\n",
105 | "\n",
106 | "```python\n",
107 | "def suggest(tactic_state):\n",
108 | " conn = http.client.HTTPConnection(\"localhost\", 5000)\n",
109 | " headers = {'Content-type': 'application/json'}\n",
110 | " body = json.dumps({\"tactic_state\": sys.argv[1]})\n",
111 | " conn.request(\"POST\", \"/\", body, headers)\n",
112 | " response = conn.getresponse()\n",
113 | " data = response.read()\n",
114 | " data_dict = json.loads(data)\n",
115 | " print('[SUGGESTION]'.join(data_dict['suggestions']))\n",
116 | " conn.close()\n",
117 | "\n",
118 | "if __name__ == \"__main__\":\n",
119 | " suggest(sys.argv[1])\n",
120 | "```\n",
121 | "\n",
122 | "After receiving suggestions, it prints the suggestions, and the printed suggestions will be received in the `runSuggest` function.\n",
123 | "\n",
124 | "#### 3. Server\n",
125 | "Finally, in `ntp_python/llmsuggest/server.py` we define a web server that handles the POST request, and hosts the language model.\n",
126 | "Specifically, the server initializes our language model, and uses the model to\n",
127 | "generate suggestions given a tactic state received in a POST request.\n",
128 | "```python\n",
129 | "model = transformers.GPTNeoXForCausalLM.from_pretrained('wellecks/llmstep-mathlib4-pythia2.8b')\n",
130 | "\n",
131 | "def generate ...\n",
132 | "\n",
133 | "app = Flask(__name__)\n",
134 | "\n",
135 | "@app.route('/', methods=['POST'])\n",
136 | "def process_request():\n",
137 | " data = request.get_json()\n",
138 | " tactic_state = data.get('tactic_state')\n",
139 | " prompt = \"\"\"[GOAL]%s[PROOFSTEP]\"\"\" % (tactic_state)\n",
140 | " texts = generate(prompt, args.num_samples)\n",
141 | " response = {\"suggestions\": texts}\n",
142 | " return jsonify(response)\n",
143 | "\n",
144 | "if __name__ == '__main__':\n",
145 | " app.run(host='0.0.0.0', port=args.port)\n",
146 | "```\n",
147 | "\n",
148 | "This server is minimal; one can imagine adding several features."
149 | ]
150 | },
151 | {
152 | "cell_type": "markdown",
153 | "metadata": {},
154 | "source": [
155 | "### Running `llmsuggest`\n",
156 | "\n",
157 | "To run `llmsuggest`, first start the server:\n",
158 | "```bash\n",
159 | "python ntp_python/llmsuggest/server.py\n",
160 | "```\n",
161 | "\n",
162 | "Then open `ntp_lean/LLMsuggest.lean` in VS Code and try out `llmsuggest`. There are some example theorems and proofs at the bottom of the page:\n",
163 | "\n",
164 | "
"
165 | ]
166 | },
167 | {
168 | "attachments": {},
169 | "cell_type": "markdown",
170 | "metadata": {},
171 | "source": [
172 | "-----------------------\n",
173 | "\n",
174 | "### `llmstep`: [L]LM proofstep suggestions in Lean\n",
175 | "\n",
176 | "[`llmstep`](https://github.com/wellecks/llmstep) is an expanded version of the `llmsuggest` tactic: https://github.com/wellecks/llmstep\n",
177 | "\n",
178 | "`llmstep` includes features such as:\n",
179 | "1. **Type checking**: suggestions are checked by Lean and marked as completing a proof, valid, or invalid (but still possibly useful).\n",
180 | "2. **Prefixed generation**: e.g. `llmstep \"exact\"` returns suggestions that start with `\"exact\"`\n",
181 | "3. **Fast inference**: fast inference via [PagedAttention](https://vllm.ai/) for near real-time suggestions\n",
182 | "4. **Other models**: support for other models, e.g. `llmstep-llama2`\n",
183 | "\n",
184 | "Here's an example of using `llmstep`:\n",
185 | "\n",
186 | "
\n"
187 | ]
188 | },
189 | {
190 | "cell_type": "markdown",
191 | "metadata": {},
192 | "source": [
193 | "The first invocation (`llmstep \"\"`) gives 5 suggestions, with `intro h n` and `intro h` outlined in blue since they type check.\n",
194 | "\n",
195 | "The second invocation (`llmstep \"exact\"`) gives suggestions that start with `exact`. The first three are outlined in green since they complete the proof."
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {},
201 | "source": [
202 | "----------\n",
203 | "\n",
204 | "## Next steps\n",
205 | "\n",
206 | "This concludes part 1 of the tutorial. We have seen how to build a neural next-step suggestion tool from scratch: collecting data, learning a model, measuring performance with proof search and evaluation sets, and deploying the model as an interactive tactic.\n",
207 | "\n",
208 | "In part 2, we will look at a generalization called language cascades, in which a language model implements a \"function\" that does more than predict the next step. We will see example cascades for drafting informal proofs, sketching the high-level structure of a proof, and refining proofs."
209 | ]
210 | },
211 | {
212 | "cell_type": "markdown",
213 | "metadata": {},
214 | "source": []
215 | }
216 | ],
217 | "metadata": {
218 | "kernelspec": {
219 | "display_name": "Python 3 (ipykernel)",
220 | "language": "python",
221 | "name": "python3"
222 | },
223 | "language_info": {
224 | "codemirror_mode": {
225 | "name": "ipython",
226 | "version": 3
227 | },
228 | "file_extension": ".py",
229 | "mimetype": "text/x-python",
230 | "name": "python",
231 | "nbconvert_exporter": "python",
232 | "pygments_lexer": "ipython3",
233 | "version": "3.10.11"
234 | }
235 | },
236 | "nbformat": 4,
237 | "nbformat_minor": 4
238 | }
239 |
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/data/successes_mathlib4_200_wellecks_llmstep-mathlib4-pythia2.8b.json:
--------------------------------------------------------------------------------
1 | {
2 | "results": [
3 | {
4 | "theorem": "NonemptyInterval.pure_injective",
5 | "proof": [
6 | "unfold pure",
7 | "simp [Injective]"
8 | ],
9 | "score": 0.06666778214275837,
10 | "success": true
11 | },
12 | {
13 | "theorem": "Finset.sym2_univ",
14 | "proof": [
15 | "rfl"
16 | ],
17 | "score": 0.057796161621809006,
18 | "success": true
19 | },
20 | {
21 | "theorem": "Subsemigroup.monotone_map",
22 | "proof": [
23 | "intro m n h",
24 | "rintro _ \u27e8x, hx, rfl\u27e9",
25 | "exact \u27e8x, h hx, rfl\u27e9"
26 | ],
27 | "score": 0.10295426659286022,
28 | "success": true
29 | },
30 | {
31 | "theorem": "MeasureTheory.Measurable.comp_nullMeasurable",
32 | "proof": [
33 | "exact hg.comp hf"
34 | ],
35 | "score": 0.03515021502971649,
36 | "success": true
37 | },
38 | {
39 | "theorem": "Finset.prod_pow_boole",
40 | "proof": [
41 | "split_ifs with h <;> simp [h]"
42 | ],
43 | "score": 0.022501567378640175,
44 | "success": true
45 | },
46 | {
47 | "theorem": "isOpenMap_div_right",
48 | "proof": [
49 | "simp only [div_eq_mul_inv, isOpenMap_mul_right]"
50 | ],
51 | "score": 0.06886392831802368,
52 | "success": true
53 | },
54 | {
55 | "theorem": "Fin.Ici_eq_finset_subtype",
56 | "proof": [
57 | "ext",
58 | "simp"
59 | ],
60 | "score": 0.04452210105955601,
61 | "success": true
62 | },
63 | {
64 | "theorem": "Polynomial.toFinsupp_zero",
65 | "proof": [
66 | "cases p",
67 | "rfl"
68 | ],
69 | "score": 0.0956510566174984,
70 | "success": true
71 | },
72 | {
73 | "theorem": "Continuous.sum_map",
74 | "proof": [
75 | "continuity"
76 | ],
77 | "score": 0.04079575836658478,
78 | "success": true
79 | },
80 | {
81 | "theorem": "Set.smul_set_inter\u2080",
82 | "proof": [
83 | "ext",
84 | "simp [mem_smul_set_iff_inv_smul_mem\u2080 ha]"
85 | ],
86 | "score": 0.07694211788475513,
87 | "success": true
88 | },
89 | {
90 | "theorem": "Subtype.coe_image_univ",
91 | "proof": [
92 | "simp [image_univ]"
93 | ],
94 | "score": 0.08483579009771347,
95 | "success": true
96 | },
97 | {
98 | "theorem": "AlgEquiv.toLinearEquiv_trans",
99 | "proof": [
100 | "ext <;> simp"
101 | ],
102 | "score": 0.02171783335506916,
103 | "success": true
104 | },
105 | {
106 | "theorem": "LieSubmodule.coe_add",
107 | "proof": [
108 | "simp"
109 | ],
110 | "score": 0.014644467271864414,
111 | "success": true
112 | },
113 | {
114 | "theorem": "IsAntichain.preimage_iso",
115 | "proof": [
116 | "intro b\u2081 b\u2082 hb\u2081 hb\u2082",
117 | "simpa using ht b\u2082 hb\u2082"
118 | ],
119 | "score": 0.06171885505318642,
120 | "success": true
121 | },
122 | {
123 | "theorem": "PrimeSpectrum.comap_asIdeal",
124 | "proof": [
125 | "rfl"
126 | ],
127 | "score": 0.026252638548612595,
128 | "success": true
129 | },
130 | {
131 | "theorem": "MonoidWithZeroHom.coe_mk",
132 | "proof": [
133 | "rfl"
134 | ],
135 | "score": 0.0029693180695176125,
136 | "success": true
137 | },
138 | {
139 | "theorem": "OreLocalization.one_def",
140 | "proof": [
141 | "rfl"
142 | ],
143 | "score": 0.05921659246087074,
144 | "success": true
145 | },
146 | {
147 | "theorem": "LinearPMap.mem_domain_of_mem_graph",
148 | "proof": [
149 | "rw [mem_graph_iff] at h",
150 | "rcases h with \u27e8x', rfl, h'\u27e9",
151 | "exact x'.2"
152 | ],
153 | "score": 0.04480455256998539,
154 | "success": true
155 | },
156 | {
157 | "theorem": "Vector.empty_toList_eq_ff",
158 | "proof": [
159 | "cases v",
160 | "simp [toList, List.isEmpty]",
161 | "split",
162 | "all_goals tauto"
163 | ],
164 | "score": 0.10822779312729836,
165 | "success": true
166 | },
167 | {
168 | "theorem": "Mathlib.Tactic.IntervalCases.of_not_lt_right",
169 | "proof": [
170 | "exact eq \u25b8 le_of_not_lt h"
171 | ],
172 | "score": 0.09517894685268402,
173 | "success": true
174 | },
175 | {
176 | "theorem": "List.Pairwise.and",
177 | "proof": [
178 | "induction hR",
179 | "case nil => simp",
180 | "simp_all"
181 | ],
182 | "score": 0.06705544702708721,
183 | "success": true
184 | },
185 | {
186 | "theorem": "Int.lcm_comm",
187 | "proof": [
188 | "rw [Int.lcm]",
189 | "apply Nat.lcm_comm"
190 | ],
191 | "score": 0.11606144160032272,
192 | "success": true
193 | },
194 | {
195 | "theorem": "Finsupp.ker_lsingle",
196 | "proof": [
197 | "ext",
198 | "simp"
199 | ],
200 | "score": 0.022008214611560106,
201 | "success": true
202 | },
203 | {
204 | "theorem": "Setoid.ext_iff",
205 | "proof": [
206 | "constructor",
207 | "rintro rfl _ _",
208 | "rfl",
209 | "exact fun h => Setoid.ext h"
210 | ],
211 | "score": 0.15640811529010534,
212 | "success": true
213 | },
214 | {
215 | "theorem": "Set.Finite.toFinset_offDiag",
216 | "proof": [
217 | "ext",
218 | "simp"
219 | ],
220 | "score": 0.010765764862298965,
221 | "success": true
222 | },
223 | {
224 | "theorem": "Finsupp.embDomain_inj",
225 | "proof": [
226 | "constructor",
227 | "apply embDomain_injective",
228 | "aesop"
229 | ],
230 | "score": 0.06850366480648518,
231 | "success": true
232 | },
233 | {
234 | "theorem": "PowerSeries.smul_eq_C_mul",
235 | "proof": [
236 | "simp [PowerSeries.ext_iff]"
237 | ],
238 | "score": 0.09802203625440598,
239 | "success": true
240 | },
241 | {
242 | "theorem": "OrderHom.orderHom_eq_id",
243 | "proof": [
244 | "apply Subsingleton.elim"
245 | ],
246 | "score": 0.032435908913612366,
247 | "success": true
248 | },
249 | {
250 | "theorem": "Behrend.sphere_subset_box",
251 | "proof": [
252 | "simp [sphere, box]"
253 | ],
254 | "score": 0.07471803575754166,
255 | "success": true
256 | },
257 | {
258 | "theorem": "Polynomial.aeval_algHom",
259 | "proof": [
260 | "ext",
261 | "simp"
262 | ],
263 | "score": 0.008154324954375625,
264 | "success": true
265 | },
266 | {
267 | "theorem": "OrderMonoidWithZeroHom.comp_mul",
268 | "proof": [
269 | "ext <;> simp"
270 | ],
271 | "score": 0.03107544593513012,
272 | "success": true
273 | },
274 | {
275 | "theorem": "Dfinsupp.mapRange_apply",
276 | "proof": [
277 | "rfl"
278 | ],
279 | "score": 0.021800542250275612,
280 | "success": true
281 | },
282 | {
283 | "theorem": "Submonoid.center_toSubsemigroup",
284 | "proof": [
285 | "rfl"
286 | ],
287 | "score": 0.1160610243678093,
288 | "success": true
289 | },
290 | {
291 | "theorem": "Subgroup.toSubmonoid_eq",
292 | "proof": [
293 | "constructor",
294 | "swap",
295 | "rintro rfl",
296 | "rfl",
297 | "apply toSubmonoid_injective"
298 | ],
299 | "score": 0.10547792294528335,
300 | "success": true
301 | },
302 | {
303 | "theorem": "MonoidHom.prod_comp_prodMap",
304 | "proof": [
305 | "ext <;> simp"
306 | ],
307 | "score": 0.010844556614756584,
308 | "success": true
309 | },
310 | {
311 | "theorem": "LocalEquiv.EqOnSource.source_eq",
312 | "proof": [
313 | "cases e",
314 | "cases e'",
315 | "exact h.1"
316 | ],
317 | "score": 0.039306161692366004,
318 | "success": true
319 | },
320 | {
321 | "theorem": "ENNReal.toNNReal_le_toNNReal",
322 | "proof": [
323 | "rw [\u2190 ENNReal.coe_le_coe, coe_toNNReal ha, coe_toNNReal hb]"
324 | ],
325 | "score": 0.05514577403664589,
326 | "success": true
327 | },
328 | {
329 | "theorem": "DirectSum.toAddMonoidAlgebra_toDirectSum",
330 | "proof": [
331 | "simp [DirectSum.toAddMonoidAlgebra]",
332 | "simp [AddMonoidAlgebra.toDirectSum]"
333 | ],
334 | "score": 0.07822933420538902,
335 | "success": true
336 | },
337 | {
338 | "theorem": "bddBelow_Ici",
339 | "proof": [
340 | "use a",
341 | "simp [lowerBounds]"
342 | ],
343 | "score": 0.09038608893752098,
344 | "success": true
345 | },
346 | {
347 | "theorem": "Int.lor_bit",
348 | "proof": [
349 | "dsimp [lor]",
350 | "cases a <;> cases b <;> simp",
351 | "all_goals cases m <;> cases n <;> simp"
352 | ],
353 | "score": 0.060909992549568415,
354 | "success": true
355 | },
356 | {
357 | "theorem": "IndexedPartition.some_index",
358 | "proof": [
359 | "cases hs",
360 | "simp",
361 | "aesop"
362 | ],
363 | "score": 0.10609651356935501,
364 | "success": true
365 | },
366 | {
367 | "theorem": "nndist_vadd_right",
368 | "proof": [
369 | "simp [nndist]",
370 | "ext",
371 | "simp",
372 | "rw [dist_comm, dist_vadd_left]"
373 | ],
374 | "score": 0.06943780835717916,
375 | "success": true
376 | },
377 | {
378 | "theorem": "Irrational.ne_nat",
379 | "proof": [
380 | "rw [Irrational] at h",
381 | "rintro rfl",
382 | "exact h (Set.mem_range_self _)"
383 | ],
384 | "score": 0.23603598587214947,
385 | "success": true
386 | },
387 | {
388 | "theorem": "StarSubalgebra.subset_adjoin",
389 | "proof": [
390 | "simp [subset_adjoin]",
391 | "tauto"
392 | ],
393 | "score": 0.05541197769343853,
394 | "success": true
395 | },
396 | {
397 | "theorem": "UniqueFactorizationMonoid.normalizedFactors_one",
398 | "proof": [
399 | "simp [normalizedFactors]"
400 | ],
401 | "score": 0.035825345665216446,
402 | "success": true
403 | },
404 | {
405 | "theorem": "Filter.bliminf_eq_liminf_subtype",
406 | "proof": [
407 | "simp [bliminf_eq, liminf_eq]",
408 | "congr!",
409 | "aesop"
410 | ],
411 | "score": 0.09382953867316246,
412 | "success": true
413 | },
414 | {
415 | "theorem": "Set.biUnion_singleton",
416 | "proof": [
417 | "simp"
418 | ],
419 | "score": 0.002082447987049818,
420 | "success": true
421 | },
422 | {
423 | "theorem": "RightOrdContinuous.iterate",
424 | "proof": [
425 | "induction' n with n ih",
426 | "case zero => exact RightOrdContinuous.id \u03b1",
427 | "exact ih.comp hf"
428 | ],
429 | "score": 0.08558873645961285,
430 | "success": true
431 | },
432 | {
433 | "theorem": "one_le_zpow",
434 | "proof": [
435 | "lift n to \u2115 using hn",
436 | "rw [zpow_ofNat]",
437 | "exact one_le_pow_of_one_le' H n"
438 | ],
439 | "score": 0.0882943207398057,
440 | "success": true
441 | },
442 | {
443 | "theorem": "not_lt_zero'",
444 | "proof": [
445 | "simp"
446 | ],
447 | "score": 0.03989856690168381,
448 | "success": true
449 | },
450 | {
451 | "theorem": "AlgEquiv.toLinearMap_injective",
452 | "proof": [
453 | "intro f g h",
454 | "ext",
455 | "rw [\u2190 toLinearMap_apply, \u2190 toLinearMap_apply, h]"
456 | ],
457 | "score": 0.04303271742537618,
458 | "success": true
459 | },
460 | {
461 | "theorem": "Algebra.algebraMap_ofSubsemiring_apply",
462 | "proof": [
463 | "rfl"
464 | ],
465 | "score": 0.017106426879763603,
466 | "success": true
467 | },
468 | {
469 | "theorem": "PiTensorProduct.lift_tprod",
470 | "proof": [
471 | "ext <;> simp"
472 | ],
473 | "score": 0.026965346187353134,
474 | "success": true
475 | },
476 | {
477 | "theorem": "div_eq_iff_mul_eq",
478 | "proof": [
479 | "rw [div_eq_iff hb]",
480 | "rw [eq_comm]"
481 | ],
482 | "score": 0.05084596015512943,
483 | "success": true
484 | },
485 | {
486 | "theorem": "ENNReal.one_lt_two",
487 | "proof": [
488 | "norm_num"
489 | ],
490 | "score": 0.004520223010331392,
491 | "success": true
492 | },
493 | {
494 | "theorem": "MeasureTheory.volume_preserving_finTwoArrow",
495 | "proof": [
496 | "haveI : Encodable \u03b9 := Fintype.toEncodable \u03b9",
497 | "apply measurePreserving_finTwoArrow"
498 | ],
499 | "score": 0.07828811183571815,
500 | "success": true
501 | },
502 | {
503 | "theorem": "Subsingleton.eq_univ_of_nonempty",
504 | "proof": [
505 | "rintro \u27e8x, hx\u27e9",
506 | "refine' eq_univ_iff_forall.mpr fun y => _",
507 | "rwa [Subsingleton.elim y x]"
508 | ],
509 | "score": 0.171775184571743,
510 | "success": true
511 | },
512 | {
513 | "theorem": "Complex.two_cos",
514 | "proof": [
515 | "rw [two_mul]",
516 | "simp [cos]"
517 | ],
518 | "score": 0.15968229621648788,
519 | "success": true
520 | },
521 | {
522 | "theorem": "IsLocalization.map_mk'",
523 | "proof": [
524 | "apply IsLocalization.lift_mk'"
525 | ],
526 | "score": 0.02295919507741928,
527 | "success": true
528 | },
529 | {
530 | "theorem": "Subsemigroup.coe_comap",
531 | "proof": [
532 | "rfl"
533 | ],
534 | "score": 0.03112742491066456,
535 | "success": true
536 | },
537 | {
538 | "theorem": "CategoryTheory.LaxMonoidalFunctor.prod'_\u03bc",
539 | "proof": [
540 | "dsimp [prod']",
541 | "simp"
542 | ],
543 | "score": 0.021082513965666294,
544 | "success": true
545 | },
546 | {
547 | "theorem": "Matrix.SpecialLinearGroup.coe_int_neg",
548 | "proof": [
549 | "ext",
550 | "simp"
551 | ],
552 | "score": 0.0304072555154562,
553 | "success": true
554 | },
555 | {
556 | "theorem": "ENNReal.toReal_nonneg",
557 | "proof": [
558 | "simp [ENNReal.toReal]"
559 | ],
560 | "score": 0.05232922360301018,
561 | "success": true
562 | },
563 | {
564 | "theorem": "Fin.le_coe_natAdd",
565 | "proof": [
566 | "simp"
567 | ],
568 | "score": 0.05793900787830353,
569 | "success": true
570 | },
571 | {
572 | "theorem": "SemiconjBy.inv_inv_symm",
573 | "proof": [
574 | "intro h",
575 | "simp [h]"
576 | ],
577 | "score": 0.12453563511371613,
578 | "success": true
579 | },
580 | {
581 | "theorem": "nhds_ofAdd",
582 | "proof": [
583 | "rfl"
584 | ],
585 | "score": 0.03388547897338867,
586 | "success": true
587 | },
588 | {
589 | "theorem": "MatrixEquivTensor.invFun_zero",
590 | "proof": [
591 | "simp [invFun]"
592 | ],
593 | "score": 0.018612481653690338,
594 | "success": true
595 | },
596 | {
597 | "theorem": "List.toFinset_reverse",
598 | "proof": [
599 | "simp [toFinset]"
600 | ],
601 | "score": 0.026263730600476265,
602 | "success": true
603 | },
604 | {
605 | "theorem": "CategoryTheory.Monoidal.leftUnitor_hom_app",
606 | "proof": [
607 | "dsimp [leftUnitor]"
608 | ],
609 | "score": 0.01946549117565155,
610 | "success": true
611 | },
612 | {
613 | "theorem": "List.sum_smul",
614 | "proof": [
615 | "induction' l with hd tl hl",
616 | "simp",
617 | "simp [add_smul, hl]"
618 | ],
619 | "score": 0.056858822237700224,
620 | "success": true
621 | },
622 | {
623 | "theorem": "ProbabilityTheory.kernel.integral_const",
624 | "proof": [
625 | "rw [kernel.const_apply]"
626 | ],
627 | "score": 0.0064316014759242535,
628 | "success": true
629 | },
630 | {
631 | "theorem": "CauSeq.coe_sup",
632 | "proof": [
633 | "rfl"
634 | ],
635 | "score": 0.0693473070859909,
636 | "success": true
637 | },
638 | {
639 | "theorem": "MeasureTheory.NullMeasurableSet.insert",
640 | "proof": [
641 | "unfold NullMeasurableSet",
642 | "rw [insert_eq]",
643 | "measurability"
644 | ],
645 | "score": 0.08922873809933662,
646 | "success": true
647 | },
648 | {
649 | "theorem": "Set.subset_iInter\u2082",
650 | "proof": [
651 | "intro x hx",
652 | "exact mem_iInter.2 fun i => mem_iInter.2 fun j => h i j hx"
653 | ],
654 | "score": 0.042423089034855366,
655 | "success": true
656 | },
657 | {
658 | "theorem": "ClopenUpperSet.coe_mk",
659 | "proof": [
660 | "rfl"
661 | ],
662 | "score": 0.01648038811981678,
663 | "success": true
664 | },
665 | {
666 | "theorem": "IsLprojection.compl_mul",
667 | "proof": [
668 | "simp [sub_mul]"
669 | ],
670 | "score": 0.061795979738235474,
671 | "success": true
672 | },
673 | {
674 | "theorem": "Equiv.mulLeft_symm_apply",
675 | "proof": [
676 | "simp [Equiv.mulLeft]"
677 | ],
678 | "score": 0.028932297602295876,
679 | "success": true
680 | },
681 | {
682 | "theorem": "Multiset.map_hcongr",
683 | "proof": [
684 | "subst h",
685 | "revert hf",
686 | "simp (config := { contextual := true }) [heq_iff_eq]"
687 | ],
688 | "score": 0.1003159754909575,
689 | "success": true
690 | },
691 | {
692 | "theorem": "EMetric.le_infEdist",
693 | "proof": [
694 | "simp [infEdist]"
695 | ],
696 | "score": 0.044103752821683884,
697 | "success": true
698 | },
699 | {
700 | "theorem": "DifferentiableOn.const_smul",
701 | "proof": [
702 | "intro y hy",
703 | "rcases h y hy with \u27e8f', hf'\u27e9",
704 | "exact hf'.differentiableWithinAt.const_smul _"
705 | ],
706 | "score": 0.043711879290640354,
707 | "success": true
708 | },
709 | {
710 | "theorem": "Set.diff_empty",
711 | "proof": [
712 | "simp [diff_eq]"
713 | ],
714 | "score": 0.028194859623908997,
715 | "success": true
716 | },
717 | {
718 | "theorem": "Equiv.symm_apply_apply",
719 | "proof": [
720 | "cases e",
721 | "aesop"
722 | ],
723 | "score": 0.09066382050514221,
724 | "success": true
725 | }
726 | ],
727 | "args": {
728 | "model_name": "wellecks/llmstep-mathlib4-pythia2.8b",
729 | "dataset_path": "./data/val.json",
730 | "max_iters": 50,
731 | "num_samples": 32,
732 | "num_examples": 200
733 | }
734 | }
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/images/banach/banach_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/banach/banach_1.png
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/images/banach/banach_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/banach/banach_2.png
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/images/banach/banach_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/banach/banach_3.png
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/images/banach/banach_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/banach/banach_4.png
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/images/banach/banach_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/banach/banach_5.png
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/images/leandojo_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/leandojo_1.png
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/images/llmsuggest/llmstep_gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/llmsuggest/llmstep_gif.gif
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/images/llmsuggest/llmsuggest.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/llmsuggest/llmsuggest.gif
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/images/llmsuggest/llmsuggest_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/llmsuggest/llmsuggest_examples.png
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/images/proof_state_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/proof_state_1.png
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/images/proof_state_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/proof_state_2.png
--------------------------------------------------------------------------------
/partI_nextstep/notebooks/images/proof_state_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/notebooks/images/proof_state_3.png
--------------------------------------------------------------------------------
/partI_nextstep/ntp_lean/ExtractSimple.lean:
--------------------------------------------------------------------------------
1 | -- A minimal version of LeanDojo's ExtractData.lean for instructional purposes.
2 | -- Please see LeanDojo's ExtractData.lean for a full working script to use in practice.
3 | --
4 | -- Credits: original script is from LeanDojo https://github.com/lean-dojo/LeanDojo/
5 | -- @article{yang2023leandojo,
6 | -- title={{LeanDojo}: Theorem Proving with Retrieval-Augmented Language Models},
7 | -- author={Yang, Kaiyu and Swope, Aidan and Gu, Alex and Chalamala, Rahul and Song,
8 | -- Peiyang and Yu, Shixing and Godil, Saad and Prenger, Ryan and Anandkumar, Anima},
9 | -- journal={arXiv preprint arXiv:2306.15626},
10 | -- year={2023}
11 | -- }
12 | -- This script is essentially a slightly refactored subset of LeanDojo's script.
13 |
14 | import Lean
15 |
16 | open Lean Elab System
17 |
/-- Serialize a `Substring` by materializing it as a plain string. -/
instance : ToJson Substring where
  toJson s := toJson s.toString

/-- Serialize a `String.Pos` as its raw (byte-index) numeral `n.1`. -/
instance : ToJson String.Pos where
  toJson n := toJson n.1

-- Derive JSON encoders for the syntax-tree types so whole command ASTs
-- can be written to the `.ast.json` output file.
deriving instance ToJson for SourceInfo
deriving instance ToJson for Syntax.Preresolved
deriving instance ToJson for Syntax
27 |
/-- One traced tactic step: the pretty-printed goal state before and after
the tactic ran, plus the tactic's source span. -/
structure TacticTrace where
  stateBefore: String
  stateAfter: String
  pos: String.Pos    -- start position of the tactic in the source file
  endPos: String.Pos -- end (tail) position of the tactic in the source file
deriving ToJson

/-- The full extraction result: the AST of every top-level command
together with all traced tactic steps. -/
structure Trace where
  commandASTs : Array Syntax
  tactics: Array TacticTrace
deriving ToJson

/-- Tracing monad: `IO` carrying a `Trace` accumulator in state. -/
abbrev TraceM := StateT Trace IO
41 |
42 |
/-- Pretty-print `goals` under `ctx`, joining multiple goals with blank lines;
returns the literal string "no goals" when the list is empty. -/
def ppGoals (ctx : ContextInfo) (goals : List MVarId) : IO String :=
  if goals.isEmpty then
    return "no goals"
  else
    -- Run inside MetaM so `Meta.ppGoal` sees the metavariable context of `ctx`.
    let fmt := ctx.runMetaM {} (return Std.Format.prefixJoin "\n\n" (← goals.mapM (Meta.ppGoal ·)))
    return (← fmt).pretty.trim
49 |
50 |
/-- Record one tactic's before/after goal states and source span into the trace.
Only tactics whose *parent* node is a tactic sequence
(`tacticSeq1Indented` / `tacticSeqBracketed`) are recorded, so each proof
step is captured once rather than at every nesting level. -/
private def visitTacticInfo (ctx : ContextInfo) (ti : TacticInfo) (parent : InfoTree) : TraceM Unit := do
  match parent with
  | .node (Info.ofTacticInfo i) _ =>
    match i.stx.getKind with
    | `Lean.Parser.Tactic.tacticSeq1Indented | `Lean.Parser.Tactic.tacticSeqBracketed =>
      -- Pretty-print goals with the metavariable context as it stood
      -- before/after this tactic executed.
      let ctxBefore := { ctx with mctx := ti.mctxBefore }
      let ctxAfter := { ctx with mctx := ti.mctxAfter }
      let stateBefore ← ppGoals ctxBefore ti.goalsBefore
      let stateAfter ← ppGoals ctxAfter ti.goalsAfter
      -- If the syntax node carries no position info, silently skip it
      -- (the `| pure ()` alternative ends the do-block early).
      let some posBefore := ti.stx.getPos? true | pure ()
      let some posAfter := ti.stx.getTailPos? true | pure ()
      match ti.stx with
      | .node _ _ _ =>
        modifyGet fun trace => ((),
          { trace with tactics := trace.tactics.push {
              stateBefore := stateBefore,
              stateAfter := stateAfter,
              pos := posBefore,
              endPos := posAfter } }
        )
      | _ => pure ()
    | _ => pure ()
  | _ => pure ()
74 |
75 |
/-- Dispatch on the info node kind; only tactic info contributes to the trace. -/
private def visitInfo (ctx : ContextInfo) (i : Info) (parent : InfoTree) : TraceM Unit := do
  if let .ofTacticInfo ti := i then
    visitTacticInfo ctx ti parent
80 |
81 |
/-- Depth-first walk of an `InfoTree`, visiting every `Info` node with its
nearest enclosing `ContextInfo` and its parent tree. -/
private partial def traverseTree (ctx: ContextInfo) (tree : InfoTree) (parent : InfoTree) : TraceM Unit := do
  match tree with
  -- A `context` node refines the ambient context for everything below it.
  | .context ctx' t => traverseTree ctx' t tree
  | .node i children =>
    visitInfo ctx i parent
    for x in children do
      traverseTree ctx x tree
  | _ => pure ()
90 |
91 |
/-- Entry point for one top-level `InfoTree`: it must be a `context` node
(the elaborator produces one per command); anything else aborts extraction. -/
private def traverseTopLevelTree (tree : InfoTree) : TraceM Unit := do
  match tree with
  | .context ctx t => traverseTree ctx t tree
  | _ => throw $ IO.userError "Errors in traverseTopLevelTree; aborting"
96 |
97 |
/-- Traverse every top-level tree, then return the accumulated `Trace` state. -/
def traverseForest (trees : Array InfoTree) : TraceM Trace := do
  trees.forM traverseTopLevelTree
  get
102 |
103 |
/-- `relativeTo path parent` strips `parent` from the front of `path`,
returning `none` when `parent` is not a component-wise prefix of `path`. -/
def relativeTo (path parent : FilePath) : Option FilePath :=
  let rec componentsRelativeTo (pathComps parentComps : List String) : Option FilePath :=
    match pathComps, parentComps with
    -- Parent exhausted: the remainder of `path` is the relative path
    -- (coerced into `some` via the `Option` coercion).
    | _, [] => mkFilePath pathComps
    -- Path exhausted before parent: parent is not a prefix.
    | [], _ => none
    | (h₁ :: t₁), (h₂ :: t₂) =>
      if h₁ == h₂ then
        componentsRelativeTo t₁ t₂
      else
        none

  componentsRelativeTo path.components parent.components
116 |
117 |
/-- Resolve `path` against the current working directory when it is relative;
absolute paths are returned unchanged. -/
def toAbsolute (path : FilePath) : IO FilePath := do
  if path.isAbsolute then
    return path
  else
    return (← IO.currentDir) / path
124 |
125 |
/-- Parse and elaborate the Lean file at `path` with info trees enabled,
then dump the command ASTs and traced tactics to `<path>.ast.json`.
`path` must live under the current working directory (the output path is
computed relative to it). -/
unsafe def processFile (path : FilePath) : IO Unit := do
  let input ← IO.FS.readFile path
  let opts := Options.empty.setBool `trace.Elab.info true
  enableInitializersExecution
  let inputCtx := Parser.mkInputContext input path.toString
  let (header, parserState, messages) ← Parser.parseHeader inputCtx
  let (env, messages) ← processHeader header opts messages inputCtx

  -- Abort early if imports failed; print each error first.
  if messages.hasErrors then
    for msg in messages.toList do
      if msg.severity == .error then
        println! "ERROR: {← msg.toString}"
    throw $ IO.userError "Errors during import; aborting"

  let some modName := path.fileStem | throw $ IO.userError s!"Invalid path: {path}"
  let env := env.setMainModule modName.toName
  -- `infoState.enabled := true` makes the elaborator record the info trees
  -- that the traversal below consumes.
  let commandState := { Command.mkState env messages opts with infoState.enabled := true }
  let s ← IO.processCommands inputCtx parserState commandState
  let commands := s.commands.pop -- Remove EOI command.
  let trees := s.commandState.infoState.trees.toArray
  -- Initial state: all command ASTs (header included), no tactics yet.
  let trace ← (traverseForest trees).run' ⟨#[header] ++ commands, #[]⟩

  let cwd ← IO.currentDir
  let some relativePath := relativeTo path cwd | throw $ IO.userError s!"Invalid path: {path}"
  println! "Input file: {relativePath}"
  let json_path := (
    relativePath.withExtension "ast.json"
  )
  IO.FS.writeFile json_path (toJson trace).pretty
  println! "AST: {json_path}"
156 |
157 |
/-- CLI entry point: extract data from the `.lean` file given as first argument. -/
unsafe def main (args : List String) : IO Unit := do
  match args with
  | [] =>
    println! "Please provide a .lean file (lake env lean --run ExtractData.lean FILENAME.lean)"
  | path :: _ =>
    processFile (← toAbsolute ⟨path⟩)
--------------------------------------------------------------------------------
/partI_nextstep/ntp_lean/LLMsuggest.lean:
--------------------------------------------------------------------------------
1 | /-
2 | `llmsuggest` tactic for LLM-based next-step suggestions in Lean4.
3 |
4 | This is a minimal version of `llmstep` built for tutorial purposes.
5 | `llmstep`: https://github.com/wellecks/llmstep
6 | -/
7 |
8 | import Mathlib.Tactic
9 |
10 |
11 | open Lean
12 |
/- Calls the `suggest.py` python script with the given `args` and returns its stdout.
   Assumes the Lean process was launched from the repository root: the script
   path is resolved relative to the current working directory. Requires a
   `python3` executable on PATH. -/
def runSuggest (args : Array String) : IO String := do
  let cwd ← IO.currentDir
  let path := cwd / "partI_nextstep" / "ntp_python" / "llmsuggest" / "suggest.py"
  unless ← path.pathExists do
    -- Print the path we looked for to aid debugging before failing.
    dbg_trace f!"{path}"
    throw <| IO.userError "could not find python script suggest.py"
  let s ← IO.Process.run { cmd := "python3", args := #[path.toString] ++ args }
  return s
22 |
/- Display clickable suggestions in the VSCode Lean Infoview.
   When a suggestion is clicked, this widget replaces the `llmstep` call
   with the suggestion, and saves the call in an adjacent comment.
   Code based on `Std.Tactic.TryThis.tryThisWidget`.
   The `javascript` payload below is runtime data executed by the infoview;
   it receives `props` containing `suggestions`, `range`, `tactic`, and `info`
   (see the JSON built in `addSuggestions`). -/
@[widget] def llmstepTryThisWidget : Widget.UserWidgetDefinition where
  name := "llmstep suggestions"
  javascript := "
import * as React from 'react';
import { EditorContext } from '@leanprover/infoview';
const e = React.createElement;
export default function(props) {
  const editorConnection = React.useContext(EditorContext)
  function onClick(suggestion) {
    editorConnection.api.applyEdit({
      changes: { [props.pos.uri]: [{ range:
        props.range,
        newText: suggestion + ' -- ' + props.tactic
      }] }
    })
  }
  return e('div',
  {className: 'ml1'},
  e('ul', {className: 'font-code pre-wrap'}, [
    'Try this: ',
    ...(props.suggestions.map(suggestion =>
        e('li', {onClick: () => onClick(suggestion),
          className: 'link pointer dim', title: 'Apply suggestion'},
          suggestion
      )
    )),
    props.info
  ]))
}"
56 |
57 |
/- Adds multiple suggestions to the Lean InfoView via `llmstepTryThisWidget`.
   Code based on `Std.Tactic.addSuggestion`.
   `tacRef` is the syntax of the `llmsuggest` call being replaced;
   `origSpan?` optionally overrides the range used for positioning;
   `extraMsg` is appended after the suggestion list in the widget. -/
def addSuggestions (tacRef : Syntax) (suggestions: List String)
    (origSpan? : Option Syntax := none)
    (extraMsg : String := "") : MetaM Unit := do
  if let some tacticRange := (origSpan?.getD tacRef).getRange? then
    let map ← getFileMap
    -- Start of the line containing the tactic, and the first non-space
    -- character on it — used to re-indent multi-line suggestions.
    let start := findLineStart map.source tacticRange.start
    let body := map.source.findAux (· ≠ ' ') tacticRange.start start
    let texts := suggestions.map fun text => (
      Std.Format.prettyExtra text
      (indent := (body - start).1)
      (column := (tacticRange.start - start).1)
    )
    let start := (tacRef.getRange?.getD tacticRange).start
    let stop := (tacRef.getRange?.getD tacticRange).stop
    -- Widen the widget's anchor to whole lines around the tactic.
    let stxRange :=
    { start := map.lineStart (map.toPosition start).line
      stop := map.lineStart ((map.toPosition stop).line + 1) }
    let tacticRange := map.utf8RangeToLspRange tacticRange
    let tactic := Std.Format.prettyExtra f!"{tacRef.prettyPrint}"
    -- This JSON shape must match what `llmstepTryThisWidget` reads from `props`.
    let json := Json.mkObj [
      ("tactic", tactic),
      ("suggestions", toJson texts),
      ("range", toJson tacticRange),
      ("info", extraMsg)
    ]
    Widget.saveWidgetInfo ``llmstepTryThisWidget json (.ofRange stxRange)
86 |
87 |
-- `llmsuggest` tactic: pretty-prints the main goal, sends it to the Python
-- suggestion script, and shows the returned next-step candidates as
-- clickable suggestions in the infoview.
syntax "llmsuggest" : tactic
elab_rules : tactic
  | `(tactic | llmsuggest%$tac) =>
    Lean.Elab.Tactic.withMainContext do
      let goal ← Lean.Elab.Tactic.getMainGoal
      let ppgoal ← Lean.Meta.ppGoal goal
      let ppgoalstr := toString ppgoal
      -- The script returns suggestions joined by the "[SUGGESTION]" separator.
      let suggest ← runSuggest #[ppgoalstr]
      addSuggestions tac $ suggest.splitOn "[SUGGESTION]"
98 |
99 |
/- Examples. Each trailing `-- llmsuggest` comment records an accepted
   suggestion: clicking a candidate replaces the `llmsuggest` call with the
   tactic and keeps the original call in a comment. -/
example : 2 = 2 := by
  rfl -- llmsuggest


example (f : ℕ → ℕ) : Monotone f → ∀ n, f n ≤ f (n + 1) := by
  intro h n -- llmsuggest
  exact h (Nat.le_succ _) -- llmsuggest
108 |
109 |
110 |
111 |
--------------------------------------------------------------------------------
/partI_nextstep/ntp_lean/examples/example0.lean:
--------------------------------------------------------------------------------
1 | import Mathlib.Data.Nat.Prime
2 |
-- `Nat.coprime m n` is definitionally `m.gcd n = 1`, so rewriting with the
-- definition turns the hypothesis into exactly the goal.
theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by
  rw [Nat.coprime] at h
  exact h
--------------------------------------------------------------------------------
/partI_nextstep/ntp_lean/examples/example_demo.lean:
--------------------------------------------------------------------------------
1 | import Mathlib.Data.Nat.Prime
2 |
variable (α: Type) (R S T : Set α)


-- Subset inclusion is transitive: any element of R lands in S via h1, then in T via h2.
example (h1: R ⊆ S) (h2: S ⊆ T) : (R ⊆ T) := by
  intro x hx
  exact h2 (h1 hx)
--------------------------------------------------------------------------------
/partI_nextstep/ntp_python/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wellecks/ntptutorial/80f56388c1e22004c8e63f6faa5c8d3b23b2e650/partI_nextstep/ntp_python/__init__.py
--------------------------------------------------------------------------------
/partI_nextstep/ntp_python/data.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Dict, Optional, Sequence
3 | from tqdm import tqdm
4 | from pathlib import Path
5 | import os
6 | import json
7 |
# Special tokens appended to / padding the tokenized training examples.
TOKEN_MAP = {
    'pad': '[PAD]',
    'eos': '<|endoftext|>',
}

# Markers delimiting the goal state and the tactic in a proofstep example
# (format: [GOAL]{state}[PROOFSTEP]{tactic}<|endoftext|>).
PHRASE_MAP = {
    'goal': '[GOAL]',
    'proofstep': '[PROOFSTEP]',
}
17 |
18 |
def _download_and_unpack(tarball_url, data_dir, overwrite):
    """Download ``tarball_url`` into ``data_dir`` and extract it there.

    No-op when ``data_dir`` already exists and ``overwrite`` is falsy.
    Requires ``wget`` and ``tar`` on PATH.

    Raises:
        subprocess.CalledProcessError: if the download or extraction fails.
            (The original used ``subprocess.call``, which ignores non-zero
            exit codes and could leave a partial/corrupt dataset on disk.)
    """
    import subprocess
    if (not overwrite) and Path(data_dir).exists():
        return
    Path(data_dir).mkdir(parents=True, exist_ok=True)
    archive_path = os.path.join(data_dir, "archive.tar.gz")
    # check_call surfaces failures instead of silently continuing.
    subprocess.check_call(['wget', '-O', archive_path, tarball_url])
    subprocess.check_call(['tar', '-xzf', archive_path, '-C', data_dir])
27 |
28 |
29 | def _load_ds(data_dir):
30 | ds = {}
31 | for split in ['train', 'val', 'test']:
32 | ds[split] = json.load(open(os.path.join(
33 | data_dir, 'leandojo_benchmark_4', 'random', f'{split}.json'), 'r')
34 | )
35 | return ds
36 |
37 |
def _save_splits(splits, data_dir, tag):
    """Write each split as JSON-lines to ``{data_dir}/processed/{tag}-{split}.jsonl``."""
    print("Saving split to disk...")
    out_dir = os.path.join(data_dir, 'processed')
    for split_name, records in tqdm(splits.items(), total=len(splits)):
        # Safe to repeat: exist_ok makes this idempotent across iterations.
        Path(out_dir).mkdir(parents=True, exist_ok=True)
        out_file = os.path.join(
            out_dir, '%s-%s.jsonl' % (tag, split_name)
        )
        with open(out_file, 'w') as fp:
            for record in records:
                fp.write(json.dumps(record) + '\n')
50 |
51 |
52 | def _print_stats(splits):
53 | for split, examples in splits.items():
54 | print("%s\t%d" % (split, len(examples)))
55 |
56 |
57 |
58 |
59 | # --- Proofstep
def _fmt_proofstep(state_before, tactic):
    """Build one (input, output) training pair for a single tactic step."""
    # [GOAL]{state_before}[PROOFSTEP]{tactic}<|endoftext|>
    prompt = PHRASE_MAP['goal'] + state_before + PHRASE_MAP['proofstep']
    completion = tactic + TOKEN_MAP['eos']
    return prompt, completion
65 |
66 |
def fmt_proofstep(split):
    """Flatten every traced tactic in ``split`` into {'input', 'output'} examples."""
    examples = []
    for theorem in split:
        for tactic_rec in theorem['traced_tactics']:
            prompt, completion = _fmt_proofstep(
                tactic_rec['state_before'], tactic_rec['tactic'])
            examples.append({'input': prompt, 'output': completion})
    return examples
77 |
78 |
def proofstep(data_dir):
    """End-to-end proofstep pipeline: load raw splits, format, save, report.

    Returns the dict of formatted splits.
    """
    raw = _load_ds(data_dir)
    out_ds = {name: fmt_proofstep(raw[name]) for name in raw}

    _save_splits(
        splits=out_ds,
        data_dir=data_dir,
        tag='proofstep'
    )
    _print_stats(
        splits=out_ds
    )
    return out_ds
94 | # ---
95 |
96 |
def main(args):
    """Build the proofstep dataset from the downloaded benchmark."""
    # Bug fix: the argparse flag is `--data-dir`, which argparse exposes as
    # `args.data_dir`; `args.datadir` raised AttributeError.
    proofstep(args.data_dir)
99 |
100 |
def setup(args):
    """Download and unpack the LeanDojo Benchmark 4 release into args.data_dir."""
    # Download data
    _download_and_unpack(
        tarball_url='https://zenodo.org/record/8040110/files/leandojo_benchmark_4_v1.tar.gz',
        data_dir=args.data_dir,
        overwrite=args.overwrite
    )
108 |
109 |
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    # Re-download and re-process even if data_dir already exists.
    parser.add_argument('--overwrite', action='store_true')
    parser.add_argument('--data-dir', type=str, default='./data/leandojo_benchmark_4')

    args = parser.parse_args()
    # First fetch the raw benchmark, then format it into training examples.
    setup(args)
    main(args)
119 |
--------------------------------------------------------------------------------
/partI_nextstep/ntp_python/llmsuggest/server.py:
--------------------------------------------------------------------------------
from flask import Flask, request, jsonify

import argparse
import transformers
import torch

# Server configuration: which checkpoint to serve, where, and how many
# suggestions to return per request.
parser = argparse.ArgumentParser()
parser.add_argument('--hf-model', type=str, default='wellecks/llmstep-mathlib4-pythia2.8b')
parser.add_argument('--port', type=int, default=5000)
parser.add_argument('--num-samples', type=int, default=5)
args = parser.parse_args()

# Load the model once at startup; requests share this global instance.
print("Loading model...")
model = transformers.GPTNeoXForCausalLM.from_pretrained(args.hf_model)
if torch.cuda.is_available():
    model.cuda()
model.eval()

tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(args.hf_model)
print("Done.")
22 |
def generate(prompt, num_samples):
    """Generate `num_samples` next-step suggestions for `prompt` via beam
    search, deduplicated and ordered by sequence score (best first)."""
    print(prompt)
    encoded = tokenizer.encode(prompt, return_tensors='pt')
    encoded = encoded.to(model.device)
    generation_kwargs = dict(
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=num_samples,
        return_dict_in_generate=True,
        num_beams=num_samples,
        output_scores=True,
    )
    out = model.generate(encoded, **generation_kwargs)
    # Drop the prompt tokens; keep only the newly generated text.
    completions = out.sequences[:, encoded.shape[1]:]
    decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)
    return _unique_sorted(decoded, out.sequences_scores.tolist())
41 |
42 |
43 | def _unique_sorted(texts, scores):
44 | texts_, scores_ = [], []
45 | for t, s in sorted(zip(texts, scores), key=lambda x: -x[1]):
46 | if t not in texts_:
47 | texts_.append(t)
48 | scores_.append(s)
49 | return texts_
50 |
app = Flask(__name__)

@app.route('/', methods=['POST'])
def process_request():
    """Handle a suggestion request: JSON {'tactic_state': ...} in,
    JSON {'suggestions': [...]} out."""
    data = request.get_json()

    tactic_state = data.get('tactic_state')

    # Same [GOAL]...[PROOFSTEP] template the model was fine-tuned on.
    prompt = """[GOAL]%s[PROOFSTEP]""" % (tactic_state)
    texts = generate(prompt, args.num_samples)

    response = {"suggestions": texts}
    return jsonify(response)

if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=args.port)
67 |
--------------------------------------------------------------------------------
/partI_nextstep/ntp_python/llmsuggest/suggest.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import http.client
3 | import json
4 | import sys
5 |
6 |
def suggest(tactic_state):
    """POST `tactic_state` to the local suggestion server and print the
    returned suggestions joined by the '[SUGGESTION]' delimiter."""
    conn = http.client.HTTPConnection("localhost", 5000)
    headers = {'Content-type': 'application/json'}
    # Bug fix: serialize the function argument instead of sys.argv[1], so
    # the function also works when called programmatically (not via CLI).
    body = json.dumps({"tactic_state": tactic_state})
    conn.request("POST", "/", body, headers)
    response = conn.getresponse()
    data = response.read()
    data_dict = json.loads(data)
    print('[SUGGESTION]'.join(data_dict['suggestions']))
    conn.close()
17 |
if __name__ == "__main__":
    # CLI usage: python suggest.py "<tactic state>"
    suggest(sys.argv[1])
--------------------------------------------------------------------------------
/partI_nextstep/ntp_python/postprocess_ast.py:
--------------------------------------------------------------------------------
1 | # A naive post-processing of the .ast.json file for instructional purposes.
2 | # Please see LeanDojo's `TracedFile` and `FileNode4` classes for a complete version.
3 |
def reconstruct_theorem_proof(command_ast):
    """Reassemble a command's source text from its AST.

    Walks the tree, concatenating each atom/ident token together with its
    surrounding whitespace, and returns (start, end, text) where start/end
    are the minimum/maximum source positions observed.
    """
    fragments = []
    positions = []
    end_positions = []

    def _emit_token(token, value_key):
        info = token['info']['original']
        fragments.append(info['leading'])
        fragments.append(token[value_key])
        fragments.append(info['trailing'])
        positions.append(info['pos'])
        end_positions.append(info['endPos'])

    def _walk_arg(arg):
        # A single arg may carry several of these keys; check each one.
        if 'atom' in arg:
            _emit_token(arg['atom'], 'val')
        if 'ident' in arg:
            _emit_token(arg['ident'], 'rawVal')
        if 'node' in arg:
            _walk_node(arg['node'])

    def _walk_node(node):
        for arg in node.get('args', []):
            _walk_arg(arg)

    _walk_node(command_ast['node'])
    return min(positions), max(end_positions), ''.join(fragments)
34 |
def get_theorem(start, ast):
    """Return the source text of the command whose span contains `start`,
    or None when no command matches."""
    for command_ast in ast['commandASTs']:
        begin, end, text = reconstruct_theorem_proof(command_ast)
        if begin <= start <= end:
            return text
    return None
--------------------------------------------------------------------------------
/partI_nextstep/ntp_python/proofsearch_dojo.py:
--------------------------------------------------------------------------------
1 | # Utilities for interacting with Lean and proof search
2 | import os
3 | import transformers
4 | from lean_dojo import *
5 | import json
6 | import torch
7 | from datetime import datetime
8 | import heapq
9 | import transformers
10 | import random
11 | from typing import List, Tuple
12 | from tqdm import tqdm, trange
13 |
14 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
15 |
16 |
def generate(prompt, model, tokenizer, num_samples):
    """Beam-search `num_samples` tactic candidates for `prompt`.

    Returns (texts, scores), deduplicated and sorted by descending
    sequence score.
    """
    encoded = tokenizer.encode(
        prompt, return_tensors='pt', truncation=True, max_length=1024
    ).to(model.device)

    with torch.no_grad():
        out = model.generate(
            encoded,
            max_new_tokens=128,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=num_samples,
            return_dict_in_generate=True,
            output_scores=True,
            num_beams=num_samples
        )

    # Strip the prompt tokens from each returned sequence before decoding.
    completions = out.sequences[:, encoded.shape[1]:]
    texts = list(tokenizer.batch_decode(completions, skip_special_tokens=True))
    scores = out.sequences_scores.view(-1).tolist()
    return _unique_sorted(texts, scores)
43 |
44 |
45 | def _unique_sorted(texts, scores):
46 | texts_, scores_ = [], []
47 | for t, s in sorted(zip(texts, scores), key=lambda x: -x[1]):
48 | if t not in texts_:
49 | texts_.append(t)
50 | scores_.append(s)
51 | return texts_, scores_
52 |
53 |
def _tactic_state(state):
    """Extract the pretty-printed tactic state: `pp` for a TacticState,
    otherwise the `unsolved_tactic_state` carried by other results."""
    if not isinstance(state, TacticState):
        return state.unsolved_tactic_state
    return state.pp
60 |
61 |
62 | def _prompt(ts):
63 | prompt = f"[GOAL]{ts}[PROOFSTEP]"
64 | return prompt
65 |
66 |
def best_first_search(theorem, model, tokenizer, max_iters, num_samples, timeout=600) -> dict:
    """Best-first proof search over single tactic steps via LeanDojo.

    Maintains a min-heap of (cumulative negative log-prob, steps-so-far,
    state) and expands the lowest-cost state each iteration, querying the
    model for `num_samples` candidate next steps.

    Returns a dict with 'theorem', 'success', 'failure_reason', and on
    success the 'proof' (list of tactic strings) and its 'score'.
    """
    try:
        with Dojo(theorem, hard_timeout=60 + timeout) as (dojo, init_state):
            # Heap entries: (score, steps, state); lower score = higher
            # cumulative model probability.
            # NOTE(review): on exact (score, steps) ties heapq falls through
            # to comparing states, which may not be orderable — assumed not
            # to happen in practice; confirm.
            queue = [(0.0, [], init_state)]
            visited = set()
            for _ in trange(max_iters):
                if len(queue) == 0:
                    break

                total_score, steps, state = heapq.heappop(queue)
                ts = _tactic_state(state)
                visited.add(ts)

                step_cands, step_scores = generate(
                    _prompt(ts), model, tokenizer, num_samples
                )

                for step, score in zip(step_cands, step_scores):
                    result = dojo.run_tac(state, step)
                    if isinstance(result, ProofFinished):
                        return {
                            'theorem': theorem.full_name,
                            'proof': steps + [step],
                            'score': total_score - score,
                            'success': True,
                            'failure_reason': ''
                        }
                    elif isinstance(result, TacticState):
                        # Skip states we have already expanded.
                        if _tactic_state(result) not in visited:
                            # Score is negative log probability summed across steps
                            new_score = (total_score - score)
                            heapq.heappush(
                                queue, (new_score, steps+[step], result)
                            )
    except (DojoInitError, DojoHardTimeoutError, DojoCrashError) as e:
        # Environment-level failures (init/timeout/crash) end the search.
        return {'theorem': theorem.full_name, 'success': False, 'failure_reason': str(e)}

    return {'theorem': theorem.full_name, 'success': False, 'failure_reason': 'SearchEnded'}
105 |
106 |
107 | def _save(model_name, results, args_dict, dt):
108 | output_file = 'results__%s__%s.json' % (model_name.replace('/', '_'), dt)
109 | with open(output_file, 'w') as f:
110 | json.dump({
111 | 'results': results,
112 | 'args': args_dict
113 | } , f, indent=4)
114 | print(output_file)
115 |
116 |
def load_model(model_name):
    """Load the fp16 GPT-NeoX checkpoint and its tokenizer in eval mode."""
    tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(model_name)
    model = transformers.GPTNeoXForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16
    )
    model.eval()
    return model, tokenizer
124 |
125 |
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model-name',
        default='wellecks/llmstep-mathlib4-pythia2.8b',
        choices=['wellecks/llmstep-mathlib4-pythia2.8b']
    )
    parser.add_argument('--dataset-path', default='data/val.json')
    parser.add_argument('--max-iters', type=int, default=100)
    parser.add_argument('--num-samples', type=int, default=32)
    parser.add_argument('--num-examples', type=int, default=200)
    args = parser.parse_args()

    # Bug fix: load_model takes only the model name; there is no `--vllm`
    # flag, so the original `load_model(args.model_name, args.vllm)` raised
    # AttributeError before passing an unexpected second argument.
    model, tokenizer = load_model(args.model_name)

    # Pin the mathlib4 commit so theorem positions match the dataset.
    URL = "https://github.com/leanprover-community/mathlib4"
    COMMIT = "5a919533f110b7d76410134a237ee374f24eaaad"
    repo = LeanGitRepo(URL, COMMIT)

    dt = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")

    with open(args.dataset_path) as f:
        data = json.load(f)

    # Evaluate on a fixed random subset for reproducibility.
    random.seed(43)
    data = random.sample(data, args.num_examples)
    results = []
    for example in tqdm(data, total=len(data)):
        file_path = example['file_path']
        theorem_name = example['full_name']
        theorem = Theorem(repo, file_path, theorem_name)
        result = best_first_search(
            theorem, model, tokenizer,
            max_iters=args.max_iters,
            num_samples=args.num_samples
        )
        print(result)
        print('\n-----\n')
        results.append(result)

    _save(args.model_name, results, args.__dict__, dt)
    # Overall success rate.
    print(len([x for x in results if x['success']])/len(results))
--------------------------------------------------------------------------------
/partI_nextstep/ntp_python/proofsearch_pylean.py:
--------------------------------------------------------------------------------
1 | # Utilities for interacting with Lean and proof search
2 |
3 | from pylean import LeanServer
4 | import torch
5 | import heapq
6 | import concurrent
7 | import transformers
8 | import os
9 | from tqdm import tqdm
10 | from concurrent.futures import ThreadPoolExecutor
11 | from typing import List, Tuple
12 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
13 |
14 |
def is_done(state):
    """A proof is complete when Lean reports no sorries and no messages."""
    if state['sorries'] != []:
        return False
    return state['messages'] == []
17 |
18 |
def get_goal(state):
    """Extract the remaining goal text from an 'unsolved goals' message.

    Returns None when there is no such message, or when any other message
    is an error (i.e. the candidate step failed to type-check).
    """
    goal = None
    for msg in state['messages']:
        data = msg['data']
        if data.startswith('unsolved goals\n'):
            # Keep everything after the 'unsolved goals' header line.
            _, _, remainder = data.partition('\n')
            goal = remainder
        elif msg['severity'] == 'error':
            return None
    return goal
27 |
28 |
def get_errors(state):
    """Return the raw list of Lean messages attached to `state`."""
    messages = state['messages']
    return messages
31 |
32 |
def parse_step(step):
    """Strip the end-of-text marker from a generated next-step string."""
    cleaned = step.replace('<|endoftext|>', '')
    return cleaned
36 |
37 |
def format_code(header, statement, steps_so_far, next_step):
    """Assemble a complete Lean snippet: header, theorem statement (with the
    ' {}' placeholder removed), then every proof step, one per line."""
    stmt = statement.replace(" {}", "")
    steps = '\n'.join(steps_so_far + [next_step])
    return header + stmt + '\n' + steps
40 |
41 |
def run_code(code):
    """Run `code` in a fresh Lean server process and return its output.

    The server process is closed even if `run_code` raises, so failed
    checks do not leak Lean subprocesses (the original only closed it on
    the success path).
    """
    lean = LeanServer()
    try:
        return lean.run_code(code)
    finally:
        lean.proc.close()
48 |
49 |
def sequence_scores(out, prompt_length, model, tokenizer):
    """Length-normalized log probability of each generated sequence.

    Re-scores the generated sequences with a fresh forward pass and averages
    token log-probs over positions after the prompt, up to and including the
    first eos token.

    NOTE(review): re-tokenizing the decoded text assumes decode/encode
    round-trips to the same token ids — confirm for the tokenizer in use.
    """
    # Returns each output sequence's log probability normalized by the number of tokens.
    # An output sequence is defined as the tokens after the prompt up to and including eos.
    text = tokenizer.batch_decode(out.sequences)
    input_ids = tokenizer(
        text, return_tensors="pt", padding='longest', truncation=True
    ).to(model.device)
    with torch.no_grad():
        out = model(**input_ids)
    # Token-level log-probabilities; shift so position i scores token i+1.
    probs = torch.log_softmax(out.logits, dim=-1).detach()
    probs = probs[:, :-1, :]
    input_ids_shifted = input_ids.input_ids[:, 1:]
    log_probs = torch.gather(probs, 2, input_ids_shifted[:, :, None]).squeeze(-1)
    log_probs = log_probs[:, prompt_length:]
    # Double-cumsum trick: the mask stays <= 1 only through the first eos
    # occurrence, selecting tokens up to and including it.
    up_to_eos_mask = (input_ids_shifted[:,prompt_length:].eq(
        tokenizer.eos_token_id).cumsum(1).cumsum(1) <= 1).type(log_probs.dtype)
    # Average rather than sum so longer sequences are not penalized.
    normalized_sequence_scores = (log_probs * up_to_eos_mask).sum(1) / up_to_eos_mask.sum(1)
    return normalized_sequence_scores
68 |
69 |
def generate(prompt, model, tokenizer, temperatures, num_samples) -> Tuple[List[str], List[float]]:
    """Sample next-step candidates at each temperature in `temperatures`.

    Temperature 0.0 uses beam search; positive temperatures use sampling.
    Returns (texts, scores), deduplicated and sorted by descending
    length-normalized log probability.
    """
    encoded = tokenizer.encode(prompt, return_tensors='pt').to(model.device)
    all_texts = []
    all_scores = []
    with torch.no_grad():
        # Does beam search at temp 0.0, otherwise temperature sampling.
        for temperature in temperatures:
            params = {
                'max_new_tokens': 256,
                'do_sample': temperature > 0,
                'temperature': temperature,
                'pad_token_id': tokenizer.eos_token_id,
                'num_return_sequences': num_samples,
                'return_dict_in_generate': True,
                'output_scores': True,
            }
            if temperature == 0.0:
                params['num_beams'] = num_samples
            out = model.generate(encoded, **params)

            # Decode only the newly generated portion of each sequence.
            decoded = tokenizer.batch_decode(
                out.sequences[:, encoded.shape[1]:],
                skip_special_tokens=True
            )
            all_texts.extend(decoded)
            normalized = sequence_scores(
                out=out,
                prompt_length=encoded.shape[1],
                model=model,
                tokenizer=tokenizer
            )
            all_scores.extend(normalized.view(-1).tolist())

    return _unique_sorted(all_texts, all_scores)
106 |
107 |
108 | def _unique_sorted(texts, scores):
109 | texts_, scores_ = [], []
110 | for t, s in sorted(zip(texts, scores), key=lambda x: -x[1]):
111 | if t not in texts_:
112 | texts_.append(t)
113 | scores_.append(s)
114 | return texts_, scores_
115 |
116 |
def _print_type_checked_candidates(results):
    """Print the candidates that type-checked (left a goal or finished)."""
    lines = []
    for state, step, step_score in results:
        # Skip candidates that neither left a goal nor completed the proof.
        if get_goal(state) is None and not is_done(state):
            continue
        lines.append('(%.3f) %s' % (step_score, step))
    print('--- type-checked candidates:\n\t' + '\n\t'.join(lines))
123 |
124 |
125 | def _print_current(theorem_statement, steps):
126 | print('--- current:\n\t%s\n\t%s' % (
127 | theorem_statement.replace('{}', ''),
128 | '\n\t'.join(steps))
129 | )
130 |
131 |
def best_first_search(model, tokenizer, header, statement, max_iters, temperatures, num_samples, verbose=False) -> dict:
    """Best-first proof search over single tactic steps, type-checking
    candidates with a Lean server.

    Pops the lowest-score (highest-probability) partial proof from a heap,
    generates candidate next steps, type-checks them in parallel, and pushes
    surviving candidates back. Returns a result dict with 'success' and, on
    success, the 'proof', final 'state', and 'score'.
    """
    goal = get_goal(run_code(header + statement))
    if goal is None:
        # The statement itself failed to elaborate.
        # NOTE(review): this re-runs the Lean check a second time just to
        # capture the message — consider reusing the first result.
        return {
            'theorem_statement': statement,
            'success': False,
            'msg': run_code(header + statement)
        }

    # Score, steps-so-far, goal state
    queue = [(0.0, [], goal)]
    visited = set()
    while len(queue) > 0 and max_iters > 0:
        # Dequeue the tuple with minimum score
        score, steps, goal = heapq.heappop(queue)
        visited.add(goal)
        if verbose:
            _print_current(statement, steps)

        # Generate next-step candidates
        prompt = f"[GOAL]{goal}[PROOFSTEP]"
        step_cands, step_scores = generate(
            prompt,
            model,
            tokenizer,
            temperatures=temperatures,
            num_samples=num_samples
        )

        # Run type checking in parallel via futures.
        with ThreadPoolExecutor(max_workers=16) as executor:
            # We need to save the step and score associated to each future.
            future2step = {}
            for step, step_score in zip(step_cands, step_scores):
                code = format_code(header, statement, steps, step)
                future = executor.submit(run_code, **dict(code=code))
                future2step[future] = (step, step_score)

            # Collect the type checking results as they complete.
            results = []
            for future in tqdm(concurrent.futures.as_completed(future2step.keys()), total=len(future2step)):
                result = future.result()
                results.append((result, *future2step[future]))

        if verbose:
            _print_type_checked_candidates(results)
        for state, step, step_score in results:
            # Stop if we have found a complete proof.
            if is_done(state):
                return {
                    'theorem_statement': statement,
                    'proof': steps + [step],
                    'state': state,
                    'score': score - step_score,
                    'success': True
                }
            goal_cand = get_goal(state)
            # Add new candidates to the queue.
            if goal_cand is not None and goal_cand not in visited:
                # Score is normalized negative log probability summed across steps
                new_score = (score - step_score)
                heapq.heappush(
                    queue, (new_score, steps+[step], goal_cand)
                )

        max_iters -= 1

    return {'theorem_statement': statement, 'success': False}
200 |
201 |
202 | def _save(results):
203 | from datetime import datetime
204 | import json
205 | now = datetime.now()
206 | dt_string = now.strftime("%d-%m-%Y-%H-%M-%S")
207 | output_file = 'results__%s.json' % (dt_string)
208 | with open(output_file, 'w') as f:
209 | json.dump(results, f, indent=4)
210 | print(output_file)
211 |
212 |
def load_model(model_name):
    """Load the GPT-NeoX checkpoint and its tokenizer in eval mode."""
    tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(model_name)
    model = transformers.GPTNeoXForCausalLM.from_pretrained(model_name)
    model.eval()
    return model, tokenizer
218 |
219 |
if __name__ == '__main__':
    model, tokenizer = load_model('wellecks/llmstep-mathlib4-pythia2.8b')

    # Small fixed evaluation set; each statement ends in `by {}`, the
    # placeholder that format_code strips before appending proof steps.
    evaluation_theorems = [
        """theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}""",
        """theorem thm2 (x y : ℝ) : x < y → 0 < y - x := by {}""",
        """theorem thm3 (n : Nat) : n ≥ 0 := by {}""",
        """theorem thm4 (x y z : ℝ) : x ≤ y → y ≤ z → x ≤ z := by {}""",
        """theorem thm5 (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by {}""",
        """theorem thm6: r ⊆ s → s ⊆ t → r ⊆ t := by {}""",
        """theorem thm7 (f : ℕ → ℕ) : Monotone f → ∀ n, f n ≤ f (n + 1) := by {}""",
        """theorem thm8 (c : ℝ) : Injective fun x => x + c := by {}""",
        """theorem thm9 (p q : Prop) : (p ∧ q) → ¬(¬p ∨ ¬q) := by {}""",
        """theorem thm10 (A B : Set ℕ) : A ⊆ B → ∀ n, n ∈ A → n ∈ B := by {}""",
        """theorem thm11 (injg : Injective g) (injf : Injective f) : Injective fun x => g (f x) := by {}""",
        """theorem thm12 (a b : ℕ) (h : a ≤ b) : a * (a + 1) ≤ b * (b + 1) := by {}""",
        """theorem thm13 (a b : ℕ) (h : a ≠ b) : a * 2 ≠ b * 2 := by {}""",
    ]

    # Shared header for the theorems above
    header = """import Mathlib.Data.Nat.Factorization.Basic
import Mathlib.Data.Nat.Prime
import Mathlib.Data.Real.Basic

open BigOperators
open Function
variable {α : Type _} (r s t : Set α)

"""

    # Search for a proof of each theorem and record the results.
    results = []
    for theorem in evaluation_theorems:
        result = best_first_search(
            model, tokenizer, header, theorem,
            max_iters=32,
            temperatures=[0.5],
            num_samples=16
        )
        print(result)
        print('\n-----\n')
        results.append(result)

    # Overall success rate, then persist the per-theorem results.
    print(len([x for x in results if x['success']])/len(results))
    _save(results)
--------------------------------------------------------------------------------
/partI_nextstep/ntp_python/tune.py:
--------------------------------------------------------------------------------
1 | # Modified Alpaca finetuning script.
2 | # Original:
3 | # https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py
4 |
5 | import copy
6 | import logging
7 | import ndjson
8 | from dataclasses import dataclass, field
9 | from typing import Dict, Optional, Sequence
10 |
11 | import torch
12 | import transformers
13 | from torch.utils.data import Dataset
14 | from transformers import Trainer
15 | from data import TOKEN_MAP
16 |
# When true, the loss covers only target (output) tokens; prompt tokens are
# masked out of the labels with IGNORE_INDEX.
MASK_INPUT = True
# Label value ignored by PyTorch's cross-entropy loss.
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = TOKEN_MAP['pad']
DEFAULT_EOS_TOKEN = TOKEN_MAP['eos']
21 |
22 |
@dataclass
class ModelArguments:
    """Which HuggingFace checkpoint (or local path) to fine-tune."""
    model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
26 |
27 |
@dataclass
class DataArguments:
    """Paths to the newline-delimited JSON training/validation files."""
    train_data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    valid_data_path: str = field(default=None, metadata={"help": "Path to the validation data."})
32 |
33 |
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Extends HF TrainingArguments with cache, optimizer, and length defaults."""
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=1024,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
42 |
43 |
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Adds `special_tokens_dict` to the tokenizer and grows the model's
    embedding matrices to match.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Warm-start each new token's embedding at the mean of all
        # pre-existing embeddings rather than random initialization.
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
65 |
66 |
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    # Each string is tokenized individually, so padding='longest' is
    # effectively a no-op here; truncation caps length at model_max_length.
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    # NOTE: input_ids and labels alias the SAME tensors here; preprocess()
    # deep-copies before mutating labels.
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    # Non-pad token counts, used later to locate the source prefix to mask.
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )
89 |
90 |
def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize source+target pairs; optionally mask source tokens in labels."""
    full_texts = [source + target for source, target in zip(sources, targets)]
    examples_tokenized = _tokenize_fn(full_texts, tokenizer)
    sources_tokenized = _tokenize_fn(sources, tokenizer)
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    if MASK_INPUT:
        # Loss should only cover the target; blank out the source prefix.
        for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
            label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
105 |
106 |
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Loads {'input': ..., 'output': ...} examples from a newline-delimited
    JSON file and tokenizes them once up front.
    """

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        # Bug fix: ndjson.load mirrors json.load and expects a file object,
        # not a path string — open the file (and close it) here.
        with open(data_path) as f:
            list_data_dict = ndjson.load(f)

        logging.warning("Formatting inputs...")
        sources = [example['input'] for example in list_data_dict]
        targets = [example['output'] for example in list_data_dict]

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
130 |
131 |
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    # Supplies pad_token_id for batch padding.
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples into batched input_ids/labels/attention_mask."""
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        # Pad labels with IGNORE_INDEX so padded positions carry no loss.
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
149 |
150 |
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    return dict(
        train_dataset=SupervisedDataset(tokenizer=tokenizer, data_path=data_args.train_data_path),
        eval_dataset=SupervisedDataset(tokenizer=tokenizer, data_path=data_args.valid_data_path),
        data_collator=DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    )
157 |
158 |
def train():
    """End-to-end fine-tuning: parse args, load model/tokenizer, build the
    data module, train, and save the final checkpoint."""
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # Register pad/eos tokens only when the tokenizer lacks them, then grow
    # the embedding matrices accordingly.
    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)
192 |
193 |
if __name__ == "__main__":
    # Entry point; all configuration comes from HfArgumentParser CLI flags.
    train()
--------------------------------------------------------------------------------
/partI_nextstep/scripts/ds_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 12,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "optimizer": {
11 | "type": "AdamW",
12 | "params": {
13 | "lr": "auto",
14 | "betas": "auto",
15 | "eps": "auto",
16 | "weight_decay": "auto"
17 | }
18 | },
19 | "scheduler": {
20 | "type": "WarmupLR",
21 | "params": {
22 | "warmup_min_lr": "auto",
23 | "warmup_max_lr": "auto",
24 | "warmup_num_steps": "auto"
25 | }
26 | },
27 | "zero_optimization": {
28 | "stage": 3,
29 | "offload_optimizer": {
30 | "device": "cpu",
31 | "pin_memory": false
32 | },
33 | "offload_param": {
34 | "device": "cpu",
35 | "pin_memory": false
36 | },
37 | "overlap_comm": true,
38 | "contiguous_gradients": true,
39 | "sub_group_size": 1e9,
40 | "reduce_bucket_size": "auto",
41 | "stage3_prefetch_bucket_size": "auto",
42 | "stage3_param_persistence_threshold": "auto",
43 | "stage3_max_live_parameters": 1e9,
44 | "stage3_max_reuse_distance": 1e9,
45 | "stage3_gather_fp16_weights_on_model_save": true
46 | },
47 | "gradient_accumulation_steps": "auto",
48 | "gradient_clipping": "auto",
49 | "steps_per_print": 2000,
50 | "train_batch_size": "auto",
51 | "train_micro_batch_size_per_gpu": "auto",
52 | "wall_clock_breakdown": false
53 | }
--------------------------------------------------------------------------------
/partI_nextstep/scripts/tune_proofstep.sh:
--------------------------------------------------------------------------------
# Fine-tune a pythia model on the proofstep dataset with DeepSpeed ZeRO-3.
REPO_DIR=/path/to/ntptutorial
TRAIN_FILE=${REPO_DIR}/data/leandojo_benchmark_4/processed/proofstep-train.jsonl
VALID_FILE=${REPO_DIR}/data/leandojo_benchmark_4/processed/proofstep-val.jsonl
MODEL=EleutherAI/pythia-2.8b-deduped
CONFIG=${REPO_DIR}/scripts/ds_config.json

OUTDIR=/path/to/output/ntptutorial/proofstep/${MODEL}

# Bug fix: the training script lives in ntp_python/, not ntp/
# (see repository layout: partI_nextstep/ntp_python/tune.py).
deepspeed --include localhost:0,1,2,3,4,5,6,7 ${REPO_DIR}/ntp_python/tune.py \
    --deepspeed ${CONFIG} \
    --model_name_or_path ${MODEL} \
    --train_data_path ${TRAIN_FILE} \
    --valid_data_path ${VALID_FILE} \
    --fp16 \
    --output_dir ${OUTDIR} \
    --num_train_epochs 10 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "steps" \
    --eval_steps 500 \
    --save_strategy "steps" \
    --save_steps 500 \
    --save_total_limit 1 \
    --learning_rate 1e-5 \
    --load_best_model_at_end 1 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --logging_dir "$OUTDIR" \
    --report_to="tensorboard"
--------------------------------------------------------------------------------