├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── alphageometry.py
├── alphageometry_test.py
├── ar.py
├── ar_test.py
├── beam_search.py
├── dd.py
├── dd_test.py
├── ddar.py
├── ddar_test.py
├── decoder_stack.py
├── defs.txt
├── download.sh
├── examples.txt
├── fig1.svg
├── geometry.py
├── geometry_150M_generate.gin
├── geometry_test.py
├── graph.py
├── graph_test.py
├── graph_utils.py
├── graph_utils_test.py
├── imo_ag_30.txt
├── jgex_ag_231.txt
├── lm_inference.py
├── lm_inference_test.py
├── models.py
├── numericals.py
├── numericals_test.py
├── pretty.py
├── problem.py
├── problem_test.py
├── requirements.in
├── requirements.txt
├── rules.txt
├── run.sh
├── run_tests.sh
├── trace_back.py
├── trace_back_test.py
└── transformer_layer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution;
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Solving Olympiad Geometry without Human Demonstrations
3 |
4 |
5 | This repository contains the code necessary to
6 | reproduce DDAR and AlphaGeometry,
7 | the two geometry theorem provers
8 | introduced in the [Nature 2024](https://www.nature.com/articles/s41586-023-06747-5) paper:
9 |
10 | *"Solving Olympiad Geometry without Human Demonstrations".*
11 |
12 | ![fig1](fig1.svg)
13 |
21 | ## Dependencies
22 |
23 | For the instructions presented below,
24 | we use Python 3.10.9, and dependencies with their exact
25 | version numbers listed in `requirements.txt`.
26 |
27 | Our code depends on `meliad`, which is
28 | not a registered package with `pip`. See instructions below
29 | for how to manually install `meliad`.
30 |
31 | Note that one can still run the DDAR solver
32 | without the `meliad` and `sentencepiece` dependencies.
33 |
34 | ## Run the instructions
35 |
36 | All instructions in this `README.md` can be run in one go by:
37 |
38 | ```
39 | bash run.sh
40 | ```
41 |
42 | Below, we explain these instructions step-by-step.
43 |
44 | ## Install dependencies, download weights and vocabulary.
45 |
46 | Installation is done in a virtual environment:
47 |
48 | ```
49 | virtualenv -p python3 .
50 | source ./bin/activate
51 | pip install --require-hashes -r requirements.txt
52 | ```
53 |
54 | Download weights and vocabulary:
55 |
56 | ```
57 | bash download.sh
58 | DATA=ag_ckpt_vocab
59 | ```
60 |
61 | Finally, install `meliad` separately as it is not
62 | registered with `pip`:
63 |
64 | ```
65 | MELIAD_PATH=meliad_lib/meliad
66 | mkdir -p $MELIAD_PATH
67 | git clone https://github.com/google-research/meliad $MELIAD_PATH
68 | export PYTHONPATH=$PYTHONPATH:$MELIAD_PATH
69 | ```
70 |
71 | ## Set up common flags
72 |
73 | Before running the python scripts,
74 | let us first prepare some commonly used flags.
75 | The symbolic engine needs definitions and deduction rules to operate.
76 | These definitions and rules are provided in two text files
77 | `defs.txt` and `rules.txt`.
78 |
79 | ```shell
80 | DDAR_ARGS=(
81 | --defs_file=$(pwd)/defs.txt \
82 | --rules_file=$(pwd)/rules.txt \
83 | );
84 | ```
85 |
86 | Next, we define the flags relevant to the proof search.
87 | To reproduce the simple examples below,
88 | we use lightweight values for the proof search parameters:
89 |
90 | ```shell
91 | BATCH_SIZE=2
92 | BEAM_SIZE=2
93 | DEPTH=2
94 |
95 | SEARCH_ARGS=(
96 | --beam_size=$BEAM_SIZE
97 | --search_depth=$DEPTH
98 | )
99 | ```
100 |
101 | NOTE: The results in our paper can be obtained by setting
102 | `BATCH_SIZE=32`, `BEAM_SIZE=512`, `DEPTH=16`,
103 | as described in the Methods section.
104 | To stay under IMO time limits at these settings, 4 V100 GPUs
105 | and 250 CPU workers are needed, as shown in Extended Data Figure 1.
106 | Note that we also stripped away other memory/speed optimizations
107 | due to internal dependencies and to promote code clarity.
108 |
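For reference, the same variables can be set to the paper-scale values in the same way (a sketch; it assumes the hardware noted above):

```shell
# Full-scale settings from the Methods section. Staying under IMO
# time limits at these values requires roughly 4 V100 GPUs and
# 250 CPU workers.
BATCH_SIZE=32
BEAM_SIZE=512
DEPTH=16

SEARCH_ARGS=(
  --beam_size=$BEAM_SIZE
  --search_depth=$DEPTH
)
```
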
109 | Assume the downloaded checkpoint and vocabulary are placed in `DATA`,
110 | and the installed `meliad` source code is at `MELIAD_PATH`.
111 | We make use of the `gin` library to manage model configurations,
112 | following `meliad` conventions. We now define the flags relevant to the
113 | language model:
114 |
115 | ```shell
116 | LM_ARGS=(
117 | --ckpt_path=$DATA \
118 | --vocab_path=$DATA/geometry.757.model
119 | --gin_search_paths=$MELIAD_PATH/transformer/configs,$(pwd) \
120 | --gin_file=base_htrans.gin \
121 | --gin_file=size/medium_150M.gin \
122 | --gin_file=options/positions_t5.gin \
123 | --gin_file=options/lr_cosine_decay.gin \
124 | --gin_file=options/seq_1024_nocache.gin \
125 | --gin_file=geometry_150M_generate.gin \
126 | --gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True \
127 | --gin_param=TransformerTaskConfig.batch_size=$BATCH_SIZE \
128 | --gin_param=TransformerTaskConfig.sequence_length=128 \
129 | --gin_param=Trainer.restore_state_variables=False
130 | );
131 | ```
132 |
133 | TIP: Note that you can still run the DDAR solver
134 | without defining `SEARCH_ARGS` and `LM_ARGS`.
135 | In that case, simply disable the import of the `lm_inference` module
136 | inside `alphageometry.py`.
137 |
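One way to do this is sketched below; the `from __future__` import is needed because `alphageometry.py` references `lm.` in type annotations, which would otherwise be evaluated when the functions are defined:

```python
# At the top of alphageometry.py (a sketch): make the LM import
# optional so that --mode=ddar runs without meliad/sentencepiece.
from __future__ import annotations

try:
  import lm_inference as lm
except ImportError:  # meliad / sentencepiece not installed
  lm = None  # --mode=alphageometry is unavailable in this case
```
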
138 | ## Run DDAR
139 |
140 | The script loads a list of problems from a text file and solves
141 | the problem in that list that matches a given name. We pass these two
142 | pieces of information through the flags
143 | `--problems_file` and `--problem_name`.
144 | We use `--mode=ddar` to indicate that we want to use the DDAR solver.
145 |
146 | Below, we show this solver solving IMO 2000 P1:
147 |
148 | ```shell
149 | python -m alphageometry \
150 | --alsologtostderr \
151 | --problems_file=$(pwd)/imo_ag_30.txt \
152 | --problem_name=translated_imo_2000_p1 \
153 | --mode=ddar \
154 | "${DDAR_ARGS[@]}"
155 | ```
156 |
157 | Expect the following output:
158 |
159 | ```shell
160 | graph.py:468] translated_imo_2000_p1
161 | graph.py:469] a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m = on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b, on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q
162 | ddar.py:41] Depth 1/1000 time = 1.7772269248962402
163 | ddar.py:41] Depth 2/1000 time = 5.63526177406311
164 | ddar.py:41] Depth 3/1000 time = 6.883412837982178
165 | ddar.py:41] Depth 4/1000 time = 10.275688409805298
166 | ddar.py:41] Depth 5/1000 time = 12.048273086547852
167 | alphageometry.py:190]
168 | ==========================
169 | * From theorem premises:
170 | A B G1 G2 M N C D E P Q : Points
171 | AG_1 ⟂ AB [00]
172 | BA ⟂ G_2B [01]
173 | G_2M = G_2B [02]
174 | G_1M = G_1A [03]
175 |
176 | ...
177 | [log omitted]
178 | ...
179 |
180 | 036. ∠QEB = ∠(QP-EA) [46] & ∠(BE-QP) = ∠AEP [55] ⇒ ∠EQP = ∠QPE [56]
181 | 037. ∠PQE = ∠EPQ [56] ⇒ EP = EQ
182 |
183 | ==========================
184 | ```
185 |
186 | The output first includes a list of relevant premises that it uses,
187 | and then proof steps that gradually build up the proof.
188 | All predicates are numbered to track how they are derived
189 | from the premises, and to show that the proof is fully justified.
190 |
191 | TIP: Additionally passing the flag `--out_file=path/to/output/text/file.txt`
192 | will write the proof to a text file.
193 |
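For example, the DDAR command above can be rerun to also save the proof (the output path here is only an illustration):

```shell
python -m alphageometry \
  --alsologtostderr \
  --problems_file=$(pwd)/imo_ag_30.txt \
  --problem_name=translated_imo_2000_p1 \
  --mode=ddar \
  --out_file=$(pwd)/translated_imo_2000_p1_proof.txt \
  "${DDAR_ARGS[@]}"
```
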
194 | Running on all problems in `imo_ag_30.txt` will yield solutions to
195 | 14 of them, as reported in Table 1 in our paper.
196 |
197 | ## Run AlphaGeometry
198 |
199 | As a simple example, we load `--problem_name=orthocenter`
200 | from `--problems_file=examples.txt`.
201 | This time, we pass `--mode=alphageometry` to use the AlphaGeometry solver
202 | and pass the `SEARCH_ARGS` and `LM_ARGS` flags.
203 |
204 | ```shell
205 | python -m alphageometry \
206 | --alsologtostderr \
207 | --problems_file=$(pwd)/examples.txt \
208 | --problem_name=orthocenter \
209 | --mode=alphageometry \
210 | "${DDAR_ARGS[@]}" \
211 | "${SEARCH_ARGS[@]}" \
212 | "${LM_ARGS[@]}"
213 | ```
214 |
215 | Expect the following output:
216 |
217 | ```shell
218 | ...
219 | [log omitted]
220 | ...
221 | training_loop.py:725] Total parameters: 152072288
222 | training_loop.py:739] Total state size: 0
223 | training_loop.py:492] Training loop: creating task for mode beam_search
224 |
225 | graph.py:468] orthocenter
226 | graph.py:469] a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c
227 | ddar.py:41] Depth 1/1000 time = 0.009987592697143555 branch = 4
228 | ddar.py:41] Depth 2/1000 time = 0.00672602653503418 branch = 0
229 | alphageometry.py:221] DD+AR failed to solve the problem.
230 | alphageometry.py:457] Depth 0. There are 1 nodes to expand:
231 | alphageometry.py:460] {S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c {F1} x00
232 | alphageometry.py:465] Decoding from {S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c {F1} x00
233 | ...
234 | [log omitted]
235 | ...
236 | alphageometry.py:470] LM output (score=-1.102287): "e : C a c e 02 C b d e 03 ;"
237 | alphageometry.py:471] Translation: "e = on_line e a c, on_line e b d"
238 |
239 | alphageometry.py:480] Solving: "a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c"
240 | graph.py:468]
241 | graph.py:469] a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c
242 | ddar.py:41] Depth 1/1000 time = 0.021120786666870117
243 | ddar.py:41] Depth 2/1000 time = 0.033370018005371094
244 | ddar.py:41] Depth 3/1000 time = 0.04297471046447754
245 | alphageometry.py:140]
246 | ==========================
247 | * From theorem premises:
248 | A B C D : Points
249 | BD ⟂ AC [00]
250 | CD ⟂ AB [01]
251 |
252 | * Auxiliary Constructions:
253 | E : Points
254 | E,B,D are collinear [02]
255 | E,C,A are collinear [03]
256 |
257 | * Proof steps:
258 | 001. E,B,D are collinear [02] & E,C,A are collinear [03] & BD ⟂ AC [00] ⇒ ∠BEA = ∠CED [04]
259 | 002. E,B,D are collinear [02] & E,C,A are collinear [03] & BD ⟂ AC [00] ⇒ ∠BEC = ∠AED [05]
260 | 003. A,E,C are collinear [03] & E,B,D are collinear [02] & AC ⟂ BD [00] ⇒ EC ⟂ EB [06]
261 | 004. EC ⟂ EB [06] & CD ⟂ AB [01] ⇒ ∠(EC-BA) = ∠(EB-CD) [07]
262 | 005. E,C,A are collinear [03] & E,B,D are collinear [02] & ∠(EC-BA) = ∠(EB-CD) [07] ⇒ ∠BAE = ∠CDE [08]
263 | 006. ∠BEA = ∠CED [04] & ∠BAE = ∠CDE [08] (Similar Triangles)⇒ EB:EC = EA:ED [09]
264 | 007. EB:EC = EA:ED [09] & ∠BEC = ∠AED [05] (Similar Triangles)⇒ ∠BCE = ∠ADE [10]
265 | 008. EB:EC = EA:ED [09] & ∠BEC = ∠AED [05] (Similar Triangles)⇒ ∠EBC = ∠EAD [11]
266 | 009. ∠BCE = ∠ADE [10] & E,C,A are collinear [03] & E,B,D are collinear [02] & ∠EBC = ∠EAD [11] ⇒ AD ⟂ BC
267 | ==========================
268 |
269 | alphageometry.py:505] Solved.
270 | ```
271 |
272 | NOTE: Point `H` is automatically renamed to `D`,
273 | as the LM is trained on synthetic problems
274 | where the points are named alphabetically, and so it expects
275 | the same during test time.
276 |
277 | NOTE: In this implementation of AlphaGeometry,
278 | we removed all optimizations that are dependent on
279 | internal infrastructure, e.g.,
280 | parallelized model inference on multi GPUs,
281 | parallelized DDAR on multiple CPUs,
282 | parallel execution of LM and DDAR,
283 | shared pool of CPU workers across different problems, etc.
284 | We also removed some memory/speed optimizations and code
285 | abstractions in favor of code clarity.
286 |
287 | As can be seen in the output, initially DDAR failed to solve the problem.
288 | The LM proposes two auxiliary constructions (because `BATCH_SIZE=2`):
289 |
290 | * `e = eqdistance e c a b, eqdistance e b a c`, i.e.,
291 | construct `E` as the intersection of circle (center=C, radius=AB) and
292 | circle (center=B, radius=AC). This construction has a score of `-1.186`.
293 | * `e = on_line e a c, on_line e b d`, i.e.,
294 | `E` is the intersection of `AC` and `BD`.
295 | This construction has a higher score (`-1.102287`) than the previous.
296 |
297 | Since the second construction has a higher score, DDAR attempted the second
298 | construction first and found the solution right away.
299 | The proof search therefore terminates and there is no second iteration.
300 |
301 | ## Results
302 |
303 | Before attempting to reproduce the AlphaGeometry numbers in our paper,
304 | please make sure to pass all tests in the prepared test suite:
305 |
306 | ```
307 | bash run_tests.sh
308 | ```
309 |
310 | NOTE: [Issue #14](https://github.com/google-deepmind/alphageometry/issues/14) reports that although the top beam decodes are still the same, the LM does not return identical scores for different users.
311 |
312 | Then, pass the corresponding values for `--problems_file` (column)
313 | and `--mode` (row), and
314 | iterate on all problems to obtain the following results
315 | (a sketch of such a sweep appears after the table):
315 |
316 |
317 |
318 | Number of solved problems:
319 |
320 | | | `imo_ag_30.txt` | `jgex_ag_231.txt` |
321 | |----------|------------------|-------------------|
322 | | `ddar` | 14 | 198 |
323 | | `alphageometry` | 25 | 228 |
324 |
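One way to iterate on all problems is a simple loop over the names in a
problems file (a sketch; it assumes the problem files alternate a name
line with a statement line, as `imo_ag_30.txt` does):

```shell
PROBLEMS_FILE=$(pwd)/imo_ag_30.txt
for PROBLEM_NAME in $(awk 'NR % 2 == 1' $PROBLEMS_FILE); do
  python -m alphageometry \
    --alsologtostderr \
    --problems_file=$PROBLEMS_FILE \
    --problem_name=$PROBLEM_NAME \
    --mode=ddar \
    "${DDAR_ARGS[@]}"
done
```
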
325 |
326 |
327 | ## Source code description
328 |
329 | Files in this repository include Python modules/scripts to run the solvers and
330 | resource files necessary for the scripts to execute. We list each of them
331 | below with a description.
332 |
333 | | File name | Description |
334 | |------------------------|------------------------------------------------------------------------------------|
335 | | `geometry.py` | Implements nodes (Point, Line, Circle, etc) in the proof state graph. |
336 | | `numericals.py` | Implements the numerical engine in the dynamic geometry environment. |
337 | | `graph_utils.py` | Implements utilities for the proof state graph. |
338 | | `graph.py` | Implements the proof state graph. |
339 | | `problem.py`            | Implements the classes that represent the problem premises, conclusion, and DAG nodes. |
340 | | `dd.py` | Implements DD and its traceback. |
341 | | `ar.py` | Implements AR and its traceback. |
342 | | `trace_back.py` | Implements the recursive traceback and dependency difference algorithm. |
343 | | `ddar.py` | Implements the combination DD+AR. |
344 | | `beam_search.py` | Implements beam decoding of a language model in JAX. |
345 | | `models.py` | Implements the transformer model. |
346 | | `transformer_layer.py` | Implements the transformer layer. |
347 | | `decoder_stack.py` | Implements the transformer decoder stack. |
348 | | `lm_inference.py` | Implements an interface to a trained LM to perform decoding. |
349 | | `alphageometry.py` | Main script that loads problems, calls DD+AR or AlphaGeometry solver, and prints solutions. |
350 | | `pretty.py`             | Pretty-formats the solutions output by the solvers. |
351 | | `*_test.py` | Tests for the corresponding module. |
352 | | `download.sh`           | Script to download the model checkpoint and vocabulary. |
353 | | `run.sh` | Script to execute instructions in README. |
354 | | `run_tests.sh` | Script to execute the test suite. |
355 |
356 |
357 | Resource files:
358 |
359 | | Resource file name | Description |
360 | |------------------------|------------------------------------------------------------------------------------|
361 | | `defs.txt` | Definitions of different geometric construction actions. |
362 | | `rules.txt` | Deduction rules for DD. |
363 | | `geometry_150M_generate.gin`| Gin config of the LM implemented in meliad. |
364 | | `imo_ag_30.txt` | Problems in IMO-AG-30. |
365 | | `jgex_ag_231.txt` | Problems in JGEX-AG-231. |
366 |
367 |
368 |
369 | ## Citing this work
370 |
371 | ```bibtex
372 | @Article{AlphaGeometryTrinh2024,
373 | author = {Trinh, Trieu and Wu, Yuhuai and Le, Quoc and He, He and Luong, Thang},
374 | journal = {Nature},
375 | title = {Solving Olympiad Geometry without Human Demonstrations},
376 | year = {2024},
377 | doi = {10.1038/s41586-023-06747-5}
378 | }
379 | ```
380 |
381 | ## Acknowledgements
382 |
383 | This research is a collaboration between the Google Brain team
384 | (now Google DeepMind) and
385 | the Computer Science Department of New York University.
386 | We thank Rif A. Saurous, Denny Zhou, Christian Szegedy, Delesley Hutchins,
387 | Thomas Kipf, Hieu Pham, Petar Veličković, Debidatta Dwibedi,
388 | Kyunghyun Cho, Lerrel Pinto, Alfredo Canziani,
389 | Thomas Wies, He He’s research group,
390 | Evan Chen (the USA’s IMO team coach),
391 | Mirek Olsak, Patrik Bak,
392 | and all three of Nature's referees for their help and support.
393 |
394 | The code of AlphaGeometry communicates with and/or references the following
395 | separate libraries and packages:
396 |
397 | * [Abseil](https://github.com/abseil/abseil-py)
398 | * [JAX](https://github.com/google/jax/)
399 | * [matplotlib](https://matplotlib.org/)
400 | * [NumPy](https://numpy.org)
401 | * [SciPy](https://scipy.org)
402 | * [TensorFlow](https://github.com/tensorflow/tensorflow)
403 | * [Meliad](https://github.com/google-research/meliad)
404 | * [Flax](https://github.com/google/flax)
405 | * [Gin](https://github.com/google/gin-config)
406 | * [T5](https://github.com/google-research/text-to-text-transfer-transformer)
407 | * [SentencePiece](https://github.com/google/sentencepiece)
408 |
409 |
410 |
411 | We thank all their contributors and maintainers!
412 |
413 |
414 | ## Disclaimer
415 |
416 | This is not an officially supported Google product.
417 |
418 | This research code is provided "as-is" to the broader research community.
419 | Google does not promise to maintain or otherwise support this code in any way.
420 |
421 | ## Code License
422 |
423 | Copyright 2023 DeepMind Technologies Limited
424 |
425 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
426 | you may not use this file except in compliance with the Apache 2.0 license.
427 | You may obtain a copy of the Apache 2.0 license at:
428 | https://www.apache.org/licenses/LICENSE-2.0
429 |
430 | All other materials are licensed under the Creative Commons Attribution 4.0
431 | International License (CC-BY). You may obtain a copy of the CC-BY license at:
432 | https://creativecommons.org/licenses/by/4.0/legalcode
433 |
434 | Unless required by applicable law or agreed to in writing, all software and
435 | materials distributed here under the Apache 2.0 or CC-BY licenses are
436 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
437 | either express or implied. See the licenses for the specific language governing
438 | permissions and limitations under those licenses.
439 |
440 | ## Model Parameters License
441 |
442 | The AlphaGeometry checkpoints and vocabulary are made available
443 | under the terms of the Creative Commons Attribution 4.0
444 | International (CC BY 4.0) license.
445 | You can find details at:
446 | https://creativecommons.org/licenses/by/4.0/legalcode
447 |
448 |
--------------------------------------------------------------------------------
/alphageometry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Run DD+AR or AlphaGeometry solver.
17 |
18 | Please refer to README.md for detailed instructions.
19 | """
20 |
21 | import traceback
22 |
23 | from absl import app
24 | from absl import flags
25 | from absl import logging
26 | import ddar
27 | import graph as gh
28 | import lm_inference as lm
29 | import pretty as pt
30 | import problem as pr
31 |
32 |
33 | _GIN_SEARCH_PATHS = flags.DEFINE_list(
34 | 'gin_search_paths',
35 | ['third_party/py/meliad/transformer/configs'],
36 | 'List of paths where the Gin config files are located.',
37 | )
38 | _GIN_FILE = flags.DEFINE_multi_string(
39 | 'gin_file', ['base_htrans.gin'], 'List of Gin config files.'
40 | )
41 | _GIN_PARAM = flags.DEFINE_multi_string(
42 | 'gin_param', None, 'Newline separated list of Gin parameter bindings.'
43 | )
44 |
45 | _PROBLEMS_FILE = flags.DEFINE_string(
46 | 'problems_file',
47 | 'imo_ag_30.txt',
48 |     'text file containing the problem strings. See imo_ag_30.txt for example.',
49 | )
50 | _PROBLEM_NAME = flags.DEFINE_string(
51 | 'problem_name',
52 | 'imo_2000_p1',
53 | 'name of the problem to solve, must be in the problem_file.',
54 | )
55 | _MODE = flags.DEFINE_string(
56 | 'mode', 'ddar', 'either `ddar` (DD+AR) or `alphageometry`')
57 | _DEFS_FILE = flags.DEFINE_string(
58 | 'defs_file',
59 | 'defs.txt',
60 | 'definitions of available constructions to state a problem.',
61 | )
62 | _RULES_FILE = flags.DEFINE_string(
63 | 'rules_file', 'rules.txt', 'list of deduction rules used by DD.'
64 | )
65 | _CKPT_PATH = flags.DEFINE_string('ckpt_path', '', 'checkpoint of the LM model.')
66 | _VOCAB_PATH = flags.DEFINE_string(
67 | 'vocab_path', '', 'path to the LM vocab file.'
68 | )
69 | _OUT_FILE = flags.DEFINE_string(
70 | 'out_file', '', 'path to the solution output file.'
71 | ) # pylint: disable=line-too-long
72 | _BEAM_SIZE = flags.DEFINE_integer(
73 | 'beam_size', 1, 'beam size of the proof search.'
74 | ) # pylint: disable=line-too-long
75 | _SEARCH_DEPTH = flags.DEFINE_integer(
76 | 'search_depth', 1, 'search depth of the proof search.'
77 | ) # pylint: disable=line-too-long
78 |
79 | DEFINITIONS = None # contains definitions of construction actions
80 | RULES = None # contains rules of deductions
81 |
82 |
83 | def natural_language_statement(logical_statement: pr.Dependency) -> str:
84 | """Convert logical_statement to natural language.
85 |
86 | Args:
87 | logical_statement: pr.Dependency with .name and .args
88 |
89 | Returns:
90 |     a string of (pseudo) natural language of the predicate for a human reader.
91 | """
92 | names = [a.name.upper() for a in logical_statement.args]
93 | names = [(n[0] + '_' + n[1:]) if len(n) > 1 else n for n in names]
94 | return pt.pretty_nl(logical_statement.name, names)
95 |
96 |
97 | def proof_step_string(
98 | proof_step: pr.Dependency, refs: dict[tuple[str, ...], int], last_step: bool
99 | ) -> str:
100 | """Translate proof to natural language.
101 |
102 | Args:
103 | proof_step: pr.Dependency with .name and .args
104 | refs: dict(hash: int) to keep track of derived predicates
105 |     last_step: boolean indicating whether this is the last step.
106 |
107 | Returns:
108 |     a string of (pseudo) natural language of the proof step for a human reader.
109 | """
110 | premises, [conclusion] = proof_step
111 |
112 | premises_nl = ' & '.join(
113 | [
114 | natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
115 | for p in premises
116 | ]
117 | )
118 |
119 | if not premises:
120 | premises_nl = 'similarly'
121 |
122 | refs[conclusion.hashed()] = len(refs)
123 |
124 | conclusion_nl = natural_language_statement(conclusion)
125 | if not last_step:
126 | conclusion_nl += ' [{:02}]'.format(refs[conclusion.hashed()])
127 |
128 | return f'{premises_nl} \u21d2 {conclusion_nl}'
129 |
130 |
131 | def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None:
132 | """Output the solution to out_file.
133 |
134 | Args:
135 | g: gh.Graph object, containing the proof state.
136 | p: pr.Problem object, containing the theorem.
137 | out_file: file to write to, empty string to skip writing to file.
138 | """
139 | setup, aux, proof_steps, refs = ddar.get_proof_steps(
140 | g, p.goal, merge_trivials=False
141 | )
142 |
143 | solution = '\n=========================='
144 | solution += '\n * From theorem premises:\n'
145 | premises_nl = []
146 | for premises, [points] in setup:
147 | solution += ' '.join([p.name.upper() for p in points]) + ' '
148 | if not premises:
149 | continue
150 | premises_nl += [
151 | natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
152 | for p in premises
153 | ]
154 | solution += ': Points\n' + '\n'.join(premises_nl)
155 |
156 | solution += '\n\n * Auxiliary Constructions:\n'
157 | aux_premises_nl = []
158 | for premises, [points] in aux:
159 | solution += ' '.join([p.name.upper() for p in points]) + ' '
160 | aux_premises_nl += [
161 | natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
162 | for p in premises
163 | ]
164 | solution += ': Points\n' + '\n'.join(aux_premises_nl)
165 |
166 |   # some special cases where the deduction rule has a well-known name.
167 | r2name = {
168 | 'r32': '(SSS)',
169 | 'r33': '(SAS)',
170 | 'r34': '(Similar Triangles)',
171 | 'r35': '(Similar Triangles)',
172 | 'r36': '(ASA)',
173 | 'r37': '(ASA)',
174 | 'r38': '(Similar Triangles)',
175 | 'r39': '(Similar Triangles)',
176 | 'r40': '(Congruent Triangles)',
177 | 'a00': '(Distance chase)',
178 | 'a01': '(Ratio chase)',
179 | 'a02': '(Angle chase)',
180 | }
181 |
182 | solution += '\n\n * Proof steps:\n'
183 | for i, step in enumerate(proof_steps):
184 | _, [con] = step
185 | nl = proof_step_string(step, refs, last_step=i == len(proof_steps) - 1)
186 | rule_name = r2name.get(con.rule_name, '')
187 | nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
188 | solution += '{:03}. '.format(i + 1) + nl + '\n'
189 |
190 | solution += '==========================\n'
191 | logging.info(solution)
192 | if out_file:
193 | with open(out_file, 'w') as f:
194 | f.write(solution)
195 | logging.info('Solution written to %s.', out_file)
196 |
197 |
198 | def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
199 | lm.parse_gin_configuration(
200 | _GIN_FILE.value, _GIN_PARAM.value, gin_paths=_GIN_SEARCH_PATHS.value
201 | )
202 |
203 | return lm.LanguageModelInference(vocab_path, ckpt_init, mode='beam_search')
204 |
205 |
206 | def run_ddar(g: gh.Graph, p: pr.Problem, out_file: str) -> bool:
207 | """Run DD+AR.
208 |
209 | Args:
210 | g: gh.Graph object, containing the proof state.
211 | p: pr.Problem object, containing the problem statement.
212 | out_file: path to output file if solution is found.
213 |
214 | Returns:
215 | Boolean, whether DD+AR finishes successfully.
216 | """
217 | ddar.solve(g, RULES, p, max_level=1000)
218 |
219 | goal_args = g.names2nodes(p.goal.args)
220 | if not g.check(p.goal.name, goal_args):
221 | logging.info('DD+AR failed to solve the problem.')
222 | return False
223 |
224 | write_solution(g, p, out_file)
225 |
226 | gh.nm.draw(
227 | g.type2nodes[gh.Point],
228 | g.type2nodes[gh.Line],
229 | g.type2nodes[gh.Circle],
230 | g.type2nodes[gh.Segment])
231 | return True
232 |
233 |
234 | def translate_constrained_to_constructive(
235 | point: str, name: str, args: list[str]
236 | ) -> tuple[str, list[str]]:
237 | """Translate a predicate from constraint-based to construction-based.
238 |
239 | Args:
240 | point: str: name of the new point
241 | name: str: name of the predicate, e.g., perp, para, etc.
242 | args: list[str]: list of predicate args.
243 |
244 | Returns:
245 | (name, args): translated to constructive predicate.
246 | """
247 | if name in ['T', 'perp']:
248 | a, b, c, d = args
249 | if point in [c, d]:
250 | a, b, c, d = c, d, a, b
251 | if point == b:
252 | a, b = b, a
253 | if point == d:
254 | c, d = d, c
255 | if a == c and a == point:
256 | return 'on_dia', [a, b, d]
257 | return 'on_tline', [a, b, c, d]
258 |
259 | elif name in ['P', 'para']:
260 | a, b, c, d = args
261 | if point in [c, d]:
262 | a, b, c, d = c, d, a, b
263 | if point == b:
264 | a, b = b, a
265 | return 'on_pline', [a, b, c, d]
266 |
267 | elif name in ['D', 'cong']:
268 | a, b, c, d = args
269 | if point in [c, d]:
270 | a, b, c, d = c, d, a, b
271 | if point == b:
272 | a, b = b, a
273 | if point == d:
274 | c, d = d, c
275 | if a == c and a == point:
276 | return 'on_bline', [a, b, d]
277 | if b in [c, d]:
278 | if b == d:
279 | c, d = d, c # pylint: disable=unused-variable
280 | return 'on_circle', [a, b, d]
281 | return 'eqdistance', [a, b, c, d]
282 |
283 | elif name in ['C', 'coll']:
284 | a, b, c = args
285 | if point == b:
286 | a, b = b, a
287 | if point == c:
288 | a, b, c = c, a, b
289 | return 'on_line', [a, b, c]
290 |
291 | elif name in ['^', 'eqangle']:
292 | a, b, c, d, e, f = args
293 |
294 | if point in [d, e, f]:
295 | a, b, c, d, e, f = d, e, f, a, b, c
296 |
297 | x, b, y, c, d = b, c, e, d, f
298 | if point == b:
299 | a, b, c, d = b, a, d, c
300 |
301 | if point == d and x == y: # x p x b = x c x p
302 | return 'angle_bisector', [point, b, x, c]
303 |
304 | if point == x:
305 | return 'eqangle3', [x, a, b, y, c, d]
306 |
307 | return 'on_aline', [a, x, b, c, y, d]
308 |
309 | elif name in ['cyclic', 'O']:
310 | a, b, c = [x for x in args if x != point]
311 | return 'on_circum', [point, a, b, c]
312 |
313 | return name, args
314 |
315 |
316 | def check_valid_args(name: str, args: list[str]) -> bool:
317 |   """Check whether a predicate is grammatically correct.
318 |
319 | Args:
320 | name: str: name of the predicate
321 | args: list[str]: args of the predicate
322 |
323 | Returns:
324 |     bool: whether the predicate arg count and arg distinctness are valid.
325 | """
326 | if name == 'perp':
327 | if len(args) != 4:
328 | return False
329 | a, b, c, d = args
330 | if len({a, b}) < 2:
331 | return False
332 | if len({c, d}) < 2:
333 | return False
334 | elif name == 'para':
335 | if len(args) != 4:
336 | return False
337 | a, b, c, d = args
338 | if len({a, b, c, d}) < 4:
339 | return False
340 | elif name == 'cong':
341 | if len(args) != 4:
342 | return False
343 | a, b, c, d = args
344 | if len({a, b}) < 2:
345 | return False
346 | if len({c, d}) < 2:
347 | return False
348 | elif name == 'coll':
349 | if len(args) != 3:
350 | return False
351 | a, b, c = args
352 | if len({a, b, c}) < 3:
353 | return False
354 | elif name == 'cyclic':
355 | if len(args) != 4:
356 | return False
357 | a, b, c, d = args
358 | if len({a, b, c, d}) < 4:
359 | return False
360 | elif name == 'eqangle':
361 | if len(args) != 8:
362 | return False
363 | a, b, c, d, e, f, g, h = args
364 | if len({a, b, c, d}) < 3:
365 | return False
366 | if len({e, f, g, h}) < 3:
367 | return False
368 | return True
369 |
370 |
371 | def try_translate_constrained_to_construct(string: str, g: gh.Graph) -> str:
372 |   """Check whether a string of aux construction can be constructed.
373 |
374 | Args:
375 | string: str: the string describing aux construction.
376 | g: gh.Graph: the current proof state.
377 |
378 | Returns:
379 |     str: the constructive clause if valid; otherwise a string starting with "ERROR:".
380 | """
381 | if string[-1] != ';':
382 | return 'ERROR: must end with ;'
383 |
384 | head, prem_str = string.split(' : ')
385 | point = head.strip()
386 |
387 | if len(point) != 1 or point == ' ':
388 | return f'ERROR: invalid point name {point}'
389 |
390 | existing_points = [p.name for p in g.all_points()]
391 | if point in existing_points:
392 | return f'ERROR: point {point} already exists.'
393 |
394 | prem_toks = prem_str.split()[:-1] # remove the EOS ' ;'
395 | prems = [[]]
396 |
397 | for i, tok in enumerate(prem_toks):
398 | if tok.isdigit():
399 | if i < len(prem_toks) - 1:
400 | prems.append([])
401 | else:
402 | prems[-1].append(tok)
403 |
404 | if len(prems) > 2:
405 | return 'ERROR: there cannot be more than two predicates.'
406 |
407 | clause_txt = point + ' = '
408 | constructions = []
409 |
410 | for prem in prems:
411 | name, *args = prem
412 |
413 | if point not in args:
414 | return f'ERROR: {point} not found in predicate args.'
415 |
416 | if not check_valid_args(pt.map_symbol(name), args):
417 | return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args)
418 |
419 | for a in args:
420 | if a != point and a not in existing_points:
421 | return f'ERROR: point {a} does not exist.'
422 |
423 | try:
424 | name, args = translate_constrained_to_constructive(point, name, args)
425 | except: # pylint: disable=bare-except
426 | return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args)
427 |
428 | if name == 'on_aline':
429 | if args.count(point) > 1:
430 |       return f'ERROR: on_aline involves {point} twice'
431 |
432 | constructions += [name + ' ' + ' '.join(args)]
433 |
434 | clause_txt += ', '.join(constructions)
435 | clause = pr.Clause.from_txt(clause_txt)
436 |
437 | try:
438 | g.copy().add_clause(clause, 0, DEFINITIONS)
439 | except: # pylint: disable=bare-except
440 | return 'ERROR: ' + traceback.format_exc()
441 |
442 | return clause_txt
443 |
444 |
445 | def insert_aux_to_premise(pstring: str, auxstring: str) -> str:
446 |   """Insert auxiliary constructions from the proof into the premise.
447 |
448 | Args:
449 | pstring: str: describing the problem to solve.
450 |     auxstring: str: describing the auxiliary construction.
451 |
452 | Returns:
453 | str: new pstring with auxstring inserted before the conclusion.
454 | """
455 | setup, goal = pstring.split(' ? ')
456 | return setup + '; ' + auxstring + ' ? ' + goal
457 |
458 |
459 | class BeamQueue:
460 | """Keep only the top k objects according to their values."""
461 |
462 | def __init__(self, max_size: int = 512):
463 | self.queue = []
464 | self.max_size = max_size
465 |
466 | def add(self, node: object, val: float) -> None:
467 | """Add a new node to this queue."""
468 |
469 | if len(self.queue) < self.max_size:
470 | self.queue.append((val, node))
471 | return
472 |
473 | # Find the minimum node:
474 | min_idx, (min_val, _) = min(enumerate(self.queue), key=lambda x: x[1])
475 |
476 | # replace it if the new node has higher value.
477 | if val > min_val:
478 | self.queue[min_idx] = (val, node)
479 |
480 | def __iter__(self):
481 | for val, node in self.queue:
482 | yield val, node
483 |
484 | def __len__(self) -> int:
485 | return len(self.queue)
486 |
487 |
488 | def run_alphageometry(
489 | model: lm.LanguageModelInference,
490 | p: pr.Problem,
491 | search_depth: int,
492 | beam_size: int,
493 | out_file: str,
494 | ) -> bool:
495 | """Simplified code to run AlphaGeometry proof search.
496 |
497 | We removed all optimizations that are infrastructure-dependent, e.g.
498 | parallelized model inference on multi GPUs,
499 | parallelized DD+AR on multiple CPUs,
500 | parallel execution of LM and DD+AR,
501 | shared pool of CPU workers across different problems, etc.
502 |
503 | Many other speed optimizations and abstractions are also removed to
504 | better present the core structure of the proof search.
505 |
506 | Args:
507 | model: Interface with inference-related endpoints to JAX's model.
508 | p: pr.Problem object describing the problem to solve.
509 | search_depth: max proof search depth.
510 | beam_size: beam size of the proof search.
511 | out_file: path to output file if solution is found.
512 |
513 | Returns:
514 | boolean of whether this is solved.
515 | """
516 |   # translate the problem to a string in the grammar the LM is trained on.
517 | string = p.setup_str_from_problem(DEFINITIONS)
518 | # special tokens prompting the LM to generate auxiliary points.
519 | string += ' {F1} x00'
520 | # the graph to represent the proof state.
521 | g, _ = gh.Graph.build_problem(p, DEFINITIONS)
522 |
523 | # First we run the symbolic engine DD+AR:
524 | if run_ddar(g, p, out_file):
525 | return True
526 |
527 | # beam search for the proof
528 | # each node in the search tree is a 3-tuple:
529 |   # (<graph representation of the proof state>,
530 |   #  <string for the LM to decode from>,
531 |   #  <problem string in constructive format>)
532 | beam_queue = BeamQueue(max_size=beam_size)
533 |   # initially the beam search tree starts with a single node (a 3-tuple):
534 | beam_queue.add(
535 | node=(g, string, p.txt()), val=0.0 # value of the root node is simply 0.
536 | )
537 |
538 | for depth in range(search_depth):
539 | logging.info(
540 | 'Depth %s. There are %i nodes to expand:', depth, len(beam_queue)
541 | )
542 | for _, (_, string, _) in beam_queue:
543 | logging.info(string)
544 |
545 | new_queue = BeamQueue(max_size=beam_size) # to replace beam_queue.
546 |
547 | for prev_score, (g, string, pstring) in beam_queue:
548 | logging.info('Decoding from %s', string)
549 | outputs = model.beam_decode(string, eos_tokens=[';'])
550 |
551 | # translate lm output to the constructive language.
552 | # so that we can update the graph representing proof states:
553 | translations = [
554 | try_translate_constrained_to_construct(o, g)
555 | for o in outputs['seqs_str']
556 | ]
557 |
558 | # couple the lm outputs with its translations
559 | candidates = zip(outputs['seqs_str'], translations, outputs['scores'])
560 |
561 | # bring the highest scoring candidate first
562 | candidates = reversed(list(candidates))
563 |
564 | for lm_out, translation, score in candidates:
565 | logging.info('LM output (score=%f): "%s"', score, lm_out)
566 | logging.info('Translation: "%s"\n', translation)
567 |
568 | if translation.startswith('ERROR:'):
569 | # the construction is invalid.
570 | continue
571 |
572 | # Update the constructive statement of the problem with the aux point:
573 | candidate_pstring = insert_aux_to_premise(pstring, translation)
574 |
575 | logging.info('Solving: "%s"', candidate_pstring)
576 | p_new = pr.Problem.from_txt(candidate_pstring)
577 |
578 | # This is the new proof state graph representation:
579 | g_new, _ = gh.Graph.build_problem(p_new, DEFINITIONS)
580 | if run_ddar(g_new, p_new, out_file):
581 | logging.info('Solved.')
582 | return True
583 |
584 | # Add the candidate to the beam queue.
585 | new_queue.add(
586 | # The string for the new node is old_string + lm output +
587 | # the special token asking for a new auxiliary point ' x00':
588 | node=(g_new, string + ' ' + lm_out + ' x00', candidate_pstring),
589 |           # the score of each node is the sum of the scores of all nodes
590 |           # on the path to itself. For beam search, there is no need to
591 |           # normalize by path length because all nodes in the beam
592 |           # have the same path length.
593 | val=prev_score + score,
594 | )
595 |         # Note that the queue only maintains at most beam_size nodes,
596 |         # so this new node might be dropped depending on its value.
597 |
598 |     # replace the old queue with the new queue before the next search depth.
599 | beam_queue = new_queue
600 |
601 | return False
602 |
603 |
604 | def main(_):
605 | global DEFINITIONS
606 | global RULES
607 |
608 | # definitions of terms used in our domain-specific language.
609 | DEFINITIONS = pr.Definition.from_txt_file(_DEFS_FILE.value, to_dict=True)
610 | # load inference rules used in DD.
611 | RULES = pr.Theorem.from_txt_file(_RULES_FILE.value, to_dict=True)
612 |
613 | # when using the language model,
614 | # point names will be renamed to alphabetical a, b, c, d, e, ...
615 | # instead of staying with their original names,
616 | # in order to match the synthetic training data generation.
617 | need_rename = _MODE.value != 'ddar'
618 |
619 | # load problems from the problems_file,
620 | problems = pr.Problem.from_txt_file(
621 | _PROBLEMS_FILE.value, to_dict=True, translate=need_rename
622 | )
623 |
624 | if _PROBLEM_NAME.value not in problems:
625 | raise ValueError(
626 | f'Problem name `{_PROBLEM_NAME.value}` '
627 | + f'not found in `{_PROBLEMS_FILE.value}`'
628 | )
629 |
630 | this_problem = problems[_PROBLEM_NAME.value]
631 |
632 | if _MODE.value == 'ddar':
633 | g, _ = gh.Graph.build_problem(this_problem, DEFINITIONS)
634 | run_ddar(g, this_problem, _OUT_FILE.value)
635 |
636 | elif _MODE.value == 'alphageometry':
637 | model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
638 | run_alphageometry(
639 | model,
640 | this_problem,
641 | _SEARCH_DEPTH.value,
642 | _BEAM_SIZE.value,
643 | _OUT_FILE.value,
644 | )
645 |
646 | else:
647 | raise ValueError(f'Unknown FLAGS.mode: {_MODE.value}')
648 |
649 |
650 | if __name__ == '__main__':
651 | app.run(main)
652 |
--------------------------------------------------------------------------------
/alphageometry_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Unit tests for alphageometry.py."""
17 |
18 | import unittest
19 |
20 | from absl.testing import absltest
21 | import alphageometry
22 |
23 |
24 | class AlphaGeometryTest(unittest.TestCase):
25 |
26 | def test_translate_constrained_to_constructive(self):
27 | self.assertEqual(
28 | alphageometry.translate_constrained_to_constructive(
29 | 'd', 'T', list('addb')
30 | ),
31 | ('on_dia', ['d', 'b', 'a']),
32 | )
33 | self.assertEqual(
34 | alphageometry.translate_constrained_to_constructive(
35 | 'd', 'T', list('adbc')
36 | ),
37 | ('on_tline', ['d', 'a', 'b', 'c']),
38 | )
39 | self.assertEqual(
40 | alphageometry.translate_constrained_to_constructive(
41 | 'd', 'P', list('bcda')
42 | ),
43 | ('on_pline', ['d', 'a', 'b', 'c']),
44 | )
45 | self.assertEqual(
46 | alphageometry.translate_constrained_to_constructive(
47 | 'd', 'D', list('bdcd')
48 | ),
49 | ('on_bline', ['d', 'c', 'b']),
50 | )
51 | self.assertEqual(
52 | alphageometry.translate_constrained_to_constructive(
53 | 'd', 'D', list('bdcb')
54 | ),
55 | ('on_circle', ['d', 'b', 'c']),
56 | )
57 | self.assertEqual(
58 | alphageometry.translate_constrained_to_constructive(
59 | 'd', 'D', list('bacd')
60 | ),
61 | ('eqdistance', ['d', 'c', 'b', 'a']),
62 | )
63 | self.assertEqual(
64 | alphageometry.translate_constrained_to_constructive(
65 | 'd', 'C', list('bad')
66 | ),
67 | ('on_line', ['d', 'b', 'a']),
68 | )
69 | self.assertEqual(
70 | alphageometry.translate_constrained_to_constructive(
71 | 'd', 'C', list('bad')
72 | ),
73 | ('on_line', ['d', 'b', 'a']),
74 | )
75 | self.assertEqual(
76 | alphageometry.translate_constrained_to_constructive(
77 | 'd', 'O', list('abcd')
78 | ),
79 | ('on_circum', ['d', 'a', 'b', 'c']),
80 | )
81 |
82 | def test_insert_aux_to_premise(self):
83 | pstring = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c' # pylint: disable=line-too-long
84 | auxstring = 'e = on_line e a c, on_line e b d'
85 |
86 | target = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c' # pylint: disable=line-too-long
87 | self.assertEqual(
88 | alphageometry.insert_aux_to_premise(pstring, auxstring), target
89 | )
90 |
91 | def test_beam_queue(self):
92 | beam_queue = alphageometry.BeamQueue(max_size=2)
93 |
94 | beam_queue.add('a', 1)
95 | beam_queue.add('b', 2)
96 | beam_queue.add('c', 3)
97 |
98 | beam_queue = list(beam_queue)
99 | self.assertEqual(beam_queue, [(3, 'c'), (2, 'b')])
100 |
101 |
102 | if __name__ == '__main__':
103 | absltest.main()
104 |
--------------------------------------------------------------------------------
/ar_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Unit tests for ar.py."""
17 | import unittest
18 |
19 | from absl.testing import absltest
20 | import ar
21 | import graph as gh
22 | import problem as pr
23 |
24 |
25 | class ARTest(unittest.TestCase):
26 |
27 | @classmethod
28 | def setUpClass(cls):
29 | super().setUpClass()
30 | cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
31 | cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True)
32 |
33 | def test_update_groups(self):
34 | """Test for update_groups."""
35 | groups1 = [{1, 2}, {3, 4, 5}, {6, 7}]
36 | groups2 = [{2, 3, 8}, {9, 10, 11}]
37 |
38 | _, links, history = ar.update_groups(groups1, groups2)
39 | self.assertEqual(
40 | history,
41 | [
42 | [{1, 2, 3, 4, 5, 8}, {6, 7}],
43 | [{1, 2, 3, 4, 5, 8}, {6, 7}, {9, 10, 11}],
44 | ],
45 | )
46 | self.assertEqual(links, [(2, 3), (3, 8), (9, 10), (10, 11)])
47 |
48 | groups1 = [{1, 2}, {3, 4}, {5, 6}, {7, 8}]
49 | groups2 = [{2, 3, 8, 9, 10}, {3, 6, 11}]
50 |
51 | _, links, history = ar.update_groups(groups1, groups2)
52 | self.assertEqual(
53 | history,
54 | [
55 | [{1, 2, 3, 4, 7, 8, 9, 10}, {5, 6}],
56 | [{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}],
57 | ],
58 | )
59 | self.assertEqual(links, [(2, 3), (3, 8), (8, 9), (9, 10), (3, 6), (6, 11)])
60 |
61 | groups1 = []
62 | groups2 = [{1, 2}, {3, 4}, {5, 6}, {2, 3}]
63 |
64 | _, links, history = ar.update_groups(groups1, groups2)
65 | self.assertEqual(
66 | history,
67 | [
68 | [{1, 2}],
69 | [{1, 2}, {3, 4}],
70 | [{1, 2}, {3, 4}, {5, 6}],
71 | [{1, 2, 3, 4}, {5, 6}],
72 | ],
73 | )
74 | self.assertEqual(links, [(1, 2), (3, 4), (5, 6), (2, 3)])
75 |
76 | def test_generic_table_simple(self):
77 | tb = ar.Table()
78 |
79 | # If a-b = b-c & d-a = c-d
80 | tb.add_eq4('a', 'b', 'b', 'c', 'fact1')
81 | tb.add_eq4('d', 'a', 'c', 'd', 'fact2')
82 | tb.add_eq4('x', 'y', 'z', 't', 'fact3') # distractor fact
83 |
84 |     # Then b=d, derived from {fact1, fact2} but not fact3.
85 | result = list(tb.get_all_eqs_and_why())
86 | self.assertIn(('b', 'd', ['fact1', 'fact2']), result)
87 |
88 | def test_angle_table_inbisector_exbisector(self):
89 | """Test that AR can figure out bisector & ex-bisector are perpendicular."""
90 |     # Load a scenario where cd is the bisector of angle acb and
91 |     # ce is the ex-bisector of angle acb.
92 | p = pr.Problem.from_txt(
93 | 'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c ?'
94 | ' perp d c c e'
95 | )
96 | g, _ = gh.Graph.build_problem(p, ARTest.defs)
97 |
98 | # Create an external angle table:
99 | tb = ar.AngleTable('pi')
100 |
101 | # Add bisector & ex-bisector facts into the table:
102 | ca, cd, cb, ce = g.names2nodes(['d(ac)', 'd(cd)', 'd(bc)', 'd(ce)'])
103 | tb.add_eqangle(ca, cd, cd, cb, 'fact1')
104 | tb.add_eqangle(ce, ca, cb, ce, 'fact2')
105 |
106 | # Add a distractor fact to make sure traceback does not include this fact
107 | ab = g.names2nodes(['d(ab)'])[0]
108 | tb.add_eqangle(ab, cb, cb, ca, 'fact3')
109 |
110 | # Check for all new equalities
111 | result = list(tb.get_all_eqs_and_why())
112 |
113 | # halfpi is represented as a tuple (1, 2)
114 | halfpi = (1, 2)
115 |
116 |     # Check that cd-ce == halfpi, derived from fact1 & fact2, not fact3.
117 | self.assertCountEqual(
118 | result,
119 | [
120 | (cd, ce, halfpi, ['fact1', 'fact2']),
121 | (ce, cd, halfpi, ['fact1', 'fact2']),
122 | ],
123 | )
124 |
125 | def test_angle_table_equilateral_triangle(self):
126 | """Test that AR can figure out triangles with 3 equal angles => each is pi/3."""
127 |     # Load an equilateral scenario.
128 | p = pr.Problem.from_txt('a b c = ieq_triangle ? cong a b a c')
129 | g, _ = gh.Graph.build_problem(p, ARTest.defs)
130 |
131 |     # Add two eqangle facts because ieq_triangle only adds congruent sides.
132 | a, b, c = g.names2nodes('abc')
133 | g.add_eqangle([a, b, b, c, b, c, c, a], pr.EmptyDependency(0, None))
134 | g.add_eqangle([b, c, c, a, c, a, a, b], pr.EmptyDependency(0, None))
135 |
136 | # Create an external angle table:
137 | tb = ar.AngleTable('pi')
138 |
139 | # Add the fact that there are three equal angles
140 | ab, bc, ca = g.names2nodes(['d(ab)', 'd(bc)', 'd(ac)'])
141 | tb.add_eqangle(ab, bc, bc, ca, 'fact1')
142 | tb.add_eqangle(bc, ca, ca, ab, 'fact2')
143 |
144 | # Now check for all new equalities
145 | result = list(tb.get_all_eqs_and_why())
146 | result = [(x.name, y.name, z, t) for x, y, z, t in result]
147 |
148 |     # pi/3 is represented as the tuple (1, 3); 2*pi/3 as (2, 3).
149 | angle_60 = (1, 3)
150 | angle_120 = (2, 3)
151 |
152 |     # Check that the angle constants are created and figured out:
153 | self.assertCountEqual(
154 | result,
155 | [
156 | ('d(bc)', 'd(ac)', angle_120, ['fact1', 'fact2']),
157 | ('d(ab)', 'd(bc)', angle_120, ['fact1', 'fact2']),
158 | ('d(ac)', 'd(ab)', angle_120, ['fact1', 'fact2']),
159 | ('d(ac)', 'd(bc)', angle_60, ['fact1', 'fact2']),
160 | ('d(bc)', 'd(ab)', angle_60, ['fact1', 'fact2']),
161 | ('d(ab)', 'd(ac)', angle_60, ['fact1', 'fact2']),
162 | ],
163 | )
164 |
165 | def test_incenter_excenter_touchpoints(self):
166 | """Test that AR can figure out incenter/excenter touchpoints are equidistant to midpoint."""
167 |
168 | p = pr.Problem.from_txt(
169 | 'a b c = triangle a b c; d1 d2 d3 d = incenter2 a b c; e1 e2 e3 e ='
170 | ' excenter2 a b c ? perp d c c e',
171 | translate=False,
172 | )
173 | g, _ = gh.Graph.build_problem(p, ARTest.defs)
174 |
175 | a, b, c, ab, bc, ca, d1, d2, d3, e1, e2, e3 = g.names2nodes(
176 | ['a', 'b', 'c', 'ab', 'bc', 'ac', 'd1', 'd2', 'd3', 'e1', 'e2', 'e3']
177 | )
178 |
179 | # Create an external distance table:
180 | tb = ar.DistanceTable()
181 |
182 |     # DD can figure out the following facts;
183 |     # we add them to AR manually.
184 | tb.add_cong(ab, ca, a, d3, a, d2, 'fact1')
185 | tb.add_cong(ab, ca, a, e3, a, e2, 'fact2')
186 | tb.add_cong(ca, bc, c, d2, c, d1, 'fact5')
187 | tb.add_cong(ca, bc, c, e2, c, e1, 'fact6')
188 | tb.add_cong(bc, ab, b, d1, b, d3, 'fact3')
189 | tb.add_cong(bc, ab, b, e1, b, e3, 'fact4')
190 |
191 | # Now we check whether tb has figured out that
192 | # distance(b, d1) == distance(e1, c)
193 |
194 |     # Linear-combination expression of each variable:
195 | b = tb.v2e['bc:b']
196 | c = tb.v2e['bc:c']
197 | d1 = tb.v2e['bc:d1']
198 | e1 = tb.v2e['bc:e1']
199 |
200 | self.assertEqual(ar.minus(d1, b), ar.minus(c, e1))
201 |
202 |
203 | if __name__ == '__main__':
204 | absltest.main()
205 |
--------------------------------------------------------------------------------
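
The expected links and history in test_update_groups are enough to reconstruct a reference behavior for ar.update_groups: each incoming group chains its sorted elements into links, absorbs every existing group it intersects (the merged set taking the slot of the first group hit), and a snapshot is recorded after each step. The sketch below is derived purely from those assertions; the actual ar.py implementation may differ:

def update_groups_sketch(groups1, groups2):
  """Merges groups2 into groups1; returns (groups, links, history)."""
  groups = [set(g) for g in groups1]
  links, history = [], []
  for g2 in groups2:
    elems = sorted(g2)
    links += list(zip(elems, elems[1:]))  # chain consecutive elements
    hits = [i for i, g in enumerate(groups) if g & g2]
    if hits:
      merged = set(g2).union(*(groups[i] for i in hits))
      groups[hits[0]] = merged  # merged set takes the first hit's slot
      groups = [g for i, g in enumerate(groups)
                if i == hits[0] or i not in hits]
    else:
      groups.append(set(g2))
    history.append([set(g) for g in groups])
  return groups, links, history
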
/beam_search.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Fast decoding routines for inference from a trained model.
17 |
18 | Modified https://github.com/google/flax/blob/main/examples/wmt/decode.py
19 | to accommodate
20 |
21 | (a) continued decoding from a previous beam cache.
22 | (b) init with a single beam and then expand into beam_size beams.
23 | """
24 |
25 | from typing import Any
26 |
27 | import flax
28 | import jax
29 | from jax import lax
30 | import jax.numpy as jnp
31 | import numpy as np
32 |
33 |
34 | # Constants
35 | # "Effective negative infinity" constant for masking in beam search.
36 | NEG_INF = np.array(-1.0e7)
37 |
38 | # Beam search parameters
39 | BEAM_SEARCH_DEFAULT_ALPHA = 0.6
40 | MAX_DECODE_LEN = 32
41 |
42 | # Brevity penalty parameters
43 | BREVITY_LEN_BIAS_NUMERATOR = 5.0
44 | BREVITY_LEN_BIAS_DENOMINATOR = 6.0
45 |
46 |
47 | def brevity_penalty(alpha: float, length: int):
48 | """Brevity penalty function for beam search penalizing short sequences.
49 |
50 | Args:
51 | alpha: float: brevity-penalty scaling parameter.
52 | length: int: length of considered sequence.
53 |
54 | Returns:
55 | Brevity penalty score as jax scalar.
56 | """
57 | return jnp.power(
58 | ((BREVITY_LEN_BIAS_NUMERATOR + length) / BREVITY_LEN_BIAS_DENOMINATOR),
59 | alpha,
60 | )
61 |
62 |
63 | # Beam handling utility functions:
64 |
65 |
66 | def add_beam_dim(x: jnp.ndarray, beam_size: int) -> jnp.ndarray:
67 | """Creates new beam dimension in non-scalar array and tiles into it."""
68 | if x.ndim == 0: # ignore scalars (e.g. cache index)
69 | return x
70 | x = jnp.expand_dims(x, axis=1)
71 | tile_dims = [1] * x.ndim
72 | tile_dims[1] = beam_size
73 | return jnp.tile(x, tile_dims)
74 |
75 |
76 | def add_beam_dim_cache(
77 | cache: tuple[dict[str, jnp.ndarray], ...], beam_size: int
78 | ) -> tuple[dict[str, jnp.ndarray], ...]:
79 | """Creates new beam dimension in non-scalar array and tiles into it."""
80 | new_cache = []
81 |
82 | for layer in cache:
83 | new_layer = {}
84 | for key, x in layer.items():
85 | if key in ['keys', 'vals']:
86 | x = add_beam_dim(x, beam_size)
87 | new_layer[key] = x
88 | new_cache.append(new_layer)
89 |
90 | return tuple(new_cache)
91 |
92 |
93 | def flatten_beam_dim(x):
94 | """Flattens the first two dimensions of a non-scalar array."""
95 | if x.ndim < 2: # ignore scalars (e.g. cache index)
96 | return x
97 | return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
98 |
99 |
100 | def unflatten_beam_dim(x, batch_size, beam_size):
101 | """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
102 | if x.ndim == 0: # ignore scalars (e.g. cache index)
103 | return x
104 | assert batch_size * beam_size == x.shape[0]
105 | return x.reshape((batch_size, beam_size) + x.shape[1:])
106 |
107 |
108 | def flat_batch_beam_expand(x, beam_size):
109 | """Expands the each batch item by beam_size in batch_dimension."""
110 | return flatten_beam_dim(add_beam_dim(x, beam_size))
111 |
112 |
113 | def gather_beams(nested, beam_indices, batch_size, new_beam_size):
114 | """Gathers the beam slices indexed by beam_indices into new beam array.
115 |
116 | Args:
117 | nested: pytree of arrays or scalars (the latter ignored).
118 | beam_indices: array of beam_indices
119 | batch_size: int: size of batch.
120 | new_beam_size: int: size of _new_ beam dimension.
121 |
122 | Returns:
123 | New pytree with new beam arrays.
124 | [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
125 | """
126 | batch_indices = jnp.reshape(
127 | jnp.arange(batch_size * new_beam_size) // new_beam_size,
128 | (batch_size, new_beam_size),
129 | )
130 |
131 | def gather_fn(x):
132 | if x.ndim == 0: # ignore scalars (e.g. cache index)
133 | return x
134 | else:
135 | return x[batch_indices, beam_indices]
136 |
137 | return jax.tree_util.tree_map(gather_fn, nested)
138 |
139 |
140 | def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size):
141 | """Gathers the top-k beam slices given by score_or_log_prob array.
142 |
143 | Args:
144 | nested: pytree of arrays or scalars (the latter ignored).
145 | score_or_log_prob: [batch_size, old_beam_size] array of values to sort by
146 | for top-k selection of beam slices.
147 | batch_size: int: size of batch.
148 | new_beam_size: int: size of _new_ top-k selected beam dimension
149 |
150 | Returns:
151 | New pytree with new beam arrays containing top k new_beam_size slices.
152 | [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
153 | """
154 | _, topk_indices = lax.top_k(score_or_log_prob, k=new_beam_size)
155 | topk_indices = jnp.flip(topk_indices, axis=1)
156 | return gather_beams(nested, topk_indices, batch_size, new_beam_size)
157 |
158 |
159 | def apply_on_cache(fn, cache, *args, **kwargs):
160 | """Apply fn(val) only when key is 'keys' or 'val'."""
161 | new_cache = []
162 | for layer in cache:
163 | new_layer = {}
164 | for key, val in layer.items():
165 | if key in ['keys', 'values', 'current_index', 'relative_position_bias']:
166 | val = fn(val, *args, **kwargs)
167 | new_layer[key] = val
168 | new_cache.append(new_layer)
169 | return tuple(new_cache)
170 |
171 |
172 | # Beam search state:
173 |
174 |
175 | @flax.struct.dataclass
176 | class BeamState:
177 | """Holds beam search state data."""
178 |
179 | # The position of the decoding loop in the length dimension.
180 | cur_index: jax.Array # scalar int32: current decoded length index
181 | # The active sequence log probabilities and finished sequence scores.
182 | live_logprobs: jax.Array # float32: [batch_size, beam_size]
183 | finished_scores: jax.Array # float32: [batch_size, beam_size]
184 | # The current active-beam-searching and finished sequences.
185 | live_seqs: jax.Array # int32: [batch_size, beam_size, max_decode_len]
186 | finished_seqs: jax.Array # int32: [batch_size, beam_size,
187 | # max_decode_len]
188 | # Records which of the 'finished_seqs' is occupied and not a filler slot.
189 | finished_flags: jax.Array # bool: [batch_size, beam_size]
190 | # The current state of the autoregressive decoding caches.
191 | cache: Any # Any pytree of arrays, e.g. flax attention Cache object
192 |
193 |
194 | def beam_init(seed_token, batch_size, beam_size, max_decode_len, cache):
195 | """Initializes the beam search state data structure."""
196 | cur_index0 = jnp.array(0)
197 | live_logprobs0 = jnp.tile(
198 | jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1]
199 | )
200 | finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
201 |
202 | live_seqs0 = jnp.concatenate(
203 | [
204 | jnp.reshape(seed_token, (batch_size, beam_size, 1)),
205 | jnp.zeros((batch_size, beam_size, max_decode_len - 1), jnp.int32),
206 | ],
207 | axis=-1,
208 | ) # (batch, beam, max_decode_len)
209 |
210 | finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
211 | finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
212 | beam_cache0 = apply_on_cache(lambda x: jnp.expand_dims(x, axis=0), cache)
213 | return BeamState(
214 | cur_index=cur_index0,
215 | live_logprobs=live_logprobs0,
216 | finished_scores=finished_scores0,
217 | live_seqs=live_seqs0,
218 | finished_seqs=finished_seqs0,
219 | finished_flags=finished_flags0,
220 | cache=beam_cache0,
221 | )
222 |
223 |
224 | # Beam search routine:
225 |
226 |
227 | def beam_search_flat(
228 | seed_token,
229 | cache,
230 | tokens_to_logits,
231 | alpha=BEAM_SEARCH_DEFAULT_ALPHA,
232 | eos=None,
233 | max_decode_len=MAX_DECODE_LEN,
234 | mask=None,
235 | ):
236 | """Beam search for LM.
237 |
238 |   Inputs and cache are already flat, i.e. first dimension == batch*beam.
239 |
240 | Args:
241 | seed_token: array: [beam_size, 1] int32 sequence of tokens.
242 | cache: flax attention cache.
243 | tokens_to_logits: fast autoregressive decoder function taking single token
244 | slices and cache and returning next-token logits and updated cache.
245 | alpha: float: scaling factor for brevity penalty.
246 | eos: array: [vocab] 1 for end-of-sentence tokens, 0 for not.
247 | max_decode_len: int: maximum length of decoded translations.
248 | mask: array: [vocab] binary mask for vocab. 1 to keep the prob, 0 to set the
249 | prob := 0.
250 |
251 | Returns:
252 | Tuple of:
253 | [beam_size, max_decode_len] top-scoring sequences
254 | [beam_size] beam-search scores.
255 | """
256 | # We liberally annotate shape information for clarity below.
257 | batch_size, beam_size = 1, seed_token.shape[0]
258 | mask = mask.reshape((1, 1, -1))
259 | eos = eos.reshape((1, 1, -1))
260 | mask_bias = (1 - mask) * NEG_INF
261 |
262 | # initialize beam search state
263 | beam_search_init_state = beam_init(
264 | seed_token, batch_size, beam_size, max_decode_len, cache
265 | )
266 |
267 | def beam_search_loop_cond_fn(state):
268 | """Beam search loop termination condition."""
269 | # Have we reached max decoding length?
270 | not_at_end = state.cur_index < max_decode_len - 1
271 |
272 | # Is no further progress in the beam search possible?
273 | # Get the best possible scores from alive sequences.
274 | min_brevity_penalty = brevity_penalty(alpha, max_decode_len)
275 | best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
276 | # Get the worst scores from finished sequences.
277 | worst_finished_scores = jnp.min(
278 | state.finished_scores, axis=1, keepdims=True
279 | )
280 | # Mask out scores from slots without any actual finished sequences.
281 | worst_finished_scores = jnp.where(
282 | state.finished_flags, worst_finished_scores, NEG_INF
283 | )
284 | # If no best possible live score is better than current worst finished
285 | # scores, the search cannot improve the finished set further.
286 | search_terminated = jnp.all(worst_finished_scores > best_live_scores)
287 |
288 | # If we're not at the max decode length, and the search hasn't terminated,
289 | # continue looping.
290 | return not_at_end & (~search_terminated)
291 |
292 | def beam_search_loop_body_fn(state):
293 | """Beam search loop state update function."""
294 | # Collect the current position slice along length to feed the fast
295 | # autoregressive decoder model. Flatten the beam dimension into batch
296 | # dimension for feeding into the model.
297 | # --> [batch * beam, 1]
298 | flat_ids = flatten_beam_dim(
299 | lax.dynamic_slice(
300 | state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1)
301 | )
302 | )
303 | # Flatten beam dimension into batch to be compatible with model.
304 | # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
305 | flat_cache = apply_on_cache(flatten_beam_dim, state.cache)
306 |
307 | # Call fast-decoder model on current tokens to get next-position logits.
308 | # --> [batch * beam, vocab]
309 | flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)
310 |
311 | # unflatten beam dimension
312 | # [batch * beam, vocab] --> [batch, beam, vocab]
313 | logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
314 |
315 | # Unflatten beam dimension in attention cache arrays
316 | # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
317 | new_cache = apply_on_cache(
318 | unflatten_beam_dim, new_flat_cache, batch_size, beam_size
319 | )
320 |
321 | # Gather log probabilities from logits
322 | candidate_log_probs = jax.nn.log_softmax(logits)
323 | # Add new logprobs to existing prefix logprobs.
324 | # --> [batch, beam, vocab]
325 | log_probs = candidate_log_probs + jnp.expand_dims(
326 | state.live_logprobs, axis=2
327 | )
328 |
329 | # We'll need the vocab size, gather it from the log probability dimension.
330 | vocab_size = log_probs.shape[2]
331 |
332 | # mask away some tokens.
333 | log_probs += mask_bias # [batch,beam,vocab]+[1,1,vocab]
334 |
335 | # Each item in batch has beam_size * vocab_size candidate sequences.
336 | # For each item, get the top 2*k candidates with the highest log-
337 | # probabilities. We gather the top 2*K beams here so that even if the best
338 | # K sequences reach EOS simultaneously, we have another K sequences
339 | # remaining to continue the live beam search.
340 | beams_to_keep = 2 * beam_size
341 | # Flatten beam and vocab dimensions.
342 | flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size))
343 | # Gather the top 2*K scores from _all_ beams.
344 | # --> [batch, 2*beams], [batch, 2*beams]
345 | topk_log_probs, topk_indices = lax.top_k(flat_log_probs, k=beams_to_keep)
346 | # Recover the beam index by floor division.
347 | topk_beam_indices = topk_indices // vocab_size
348 | # Gather 2*k top beams.
349 | # --> [batch, 2*beams, length]
350 | topk_seq = gather_beams(
351 | state.live_seqs, topk_beam_indices, batch_size, beams_to_keep
352 | )
353 |
354 | # Append the most probable 2*K token IDs to the top 2*K sequences
355 | # Recover token id by modulo division and expand Id array for broadcasting.
356 | # --> [batch, 2*beams, 1]
357 | topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
358 | # Update sequences for the 2*K top-k new sequences.
359 | # --> [batch, 2*beams, length]
360 | topk_seq = lax.dynamic_update_slice(
361 | topk_seq, topk_ids, (0, 0, state.cur_index + 1)
362 | )
363 |
364 | # Update LIVE (in-progress) sequences:
365 | # Did any of these sequences reach an end marker?
366 | # --> [batch, 2*beams]
367 | last_token = topk_seq[:, :, state.cur_index + 1]
368 | last_token = jax.nn.one_hot(last_token, vocab_size, dtype=jnp.bfloat16)
369 |
370 | # any([batch, 2b, vocab] * [1, 1, vocab], axis=-1) == [batch, 2b]
371 | newly_finished = jnp.any(last_token * eos, axis=-1)
372 |
373 | # To prevent these newly finished sequences from being added to the LIVE
374 | # set of active beam search sequences, set their log probs to a very large
375 | # negative value.
376 | new_log_probs = topk_log_probs + newly_finished * NEG_INF
377 | # Determine the top k beam indices (from top 2*k beams) from log probs.
378 | # --> [batch, beams]
379 | _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
380 | new_topk_indices = jnp.flip(new_topk_indices, axis=1)
381 | # Gather the top k beams (from top 2*k beams).
382 | # --> [batch, beams, length], [batch, beams]
383 | top_alive_seq, top_alive_log_probs = gather_beams(
384 | [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size
385 | )
386 |
387 | # Determine the top k beam indices from the original set of all beams.
388 | # --> [batch, beams]
389 | top_alive_indices = gather_beams(
390 | topk_beam_indices, new_topk_indices, batch_size, beam_size
391 | )
392 | # With these, gather the top k beam-associated caches.
393 | # --> {[batch, beams, ...], ...}
394 | top_alive_cache = apply_on_cache(
395 | gather_beams, new_cache, top_alive_indices, batch_size, beam_size
396 | )
397 |
398 | # Update FINISHED (reached end of sentence) sequences:
399 | # Calculate new seq scores from log probabilities.
400 | new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)
401 | # Mask out the still unfinished sequences by adding large negative value.
402 | # --> [batch, 2*beams]
403 | new_scores += (~newly_finished) * NEG_INF
404 |
405 | # Combine sequences, scores, and flags along the beam dimension and compare
406 | # new finished sequence scores to existing finished scores and select the
407 | # best from the new set of beams.
408 | finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length]
409 | [state.finished_seqs, topk_seq], axis=1
410 | )
411 | finished_scores = jnp.concatenate( # --> [batch, 3*beams]
412 | [state.finished_scores, new_scores], axis=1
413 | )
414 | finished_flags = jnp.concatenate( # --> [batch, 3*beams]
415 | [state.finished_flags, newly_finished], axis=1
416 | )
417 | # --> [batch, beams, length], [batch, beams], [batch, beams]
418 | top_finished_seq, top_finished_scores, top_finished_flags = (
419 | gather_topk_beams(
420 | [finished_seqs, finished_scores, finished_flags],
421 | finished_scores,
422 | batch_size,
423 | beam_size,
424 | )
425 | )
426 |
427 | return BeamState(
428 | cur_index=state.cur_index + 1,
429 | live_logprobs=top_alive_log_probs,
430 | finished_scores=top_finished_scores,
431 | live_seqs=top_alive_seq,
432 | finished_seqs=top_finished_seq,
433 | finished_flags=top_finished_flags,
434 | cache=top_alive_cache,
435 | )
436 |
437 | # Run while loop and get final beam search state.
438 | final_state = lax.while_loop(
439 | beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state
440 | )
441 |
442 | # Account for the edge-case where there are no finished sequences for a
443 | # particular batch item. If so, return live sequences for that batch item.
444 | # --> [batch]
445 | none_finished = jnp.any(final_state.finished_flags, axis=1)
446 | # --> [batch, beams, length]
447 | finished_seqs = jnp.where(
448 | none_finished[:, None, None],
449 | final_state.finished_seqs,
450 | final_state.live_seqs,
451 | )
452 | # --> [batch, beams]
453 | finished_scores = jnp.where(
454 | none_finished[:, None],
455 | final_state.finished_scores,
456 | final_state.live_logprobs,
457 | )
458 |
459 | finished_seqs = jnp.reshape(finished_seqs, (beam_size, max_decode_len))
460 | finished_scores = jnp.reshape(finished_scores, (beam_size,))
461 |
462 | final_cache = apply_on_cache(flatten_beam_dim, final_state.cache)
463 | return finished_seqs, finished_scores, final_cache
464 |
--------------------------------------------------------------------------------
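
A toy run of gather_beams (as defined above) makes the index bookkeeping concrete: beam_indices selects, per batch item, which old beam each new beam slot copies from. This is a standalone illustration, not part of the module:

import jax.numpy as jnp

# Dummy sequences of shape [batch=2, beam=2, length=3].
seqs = jnp.arange(12).reshape((2, 2, 3))
# Batch item 0 keeps beam 1 in both slots; batch item 1 keeps beam 0.
beam_indices = jnp.array([[1, 1], [0, 0]])
new_seqs = gather_beams(seqs, beam_indices, batch_size=2, new_beam_size=2)
# new_seqs[0, 0] == new_seqs[0, 1] == seqs[0, 1]
# new_seqs[1, 0] == new_seqs[1, 1] == seqs[1, 0]
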
/dd_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Unit tests for dd."""
17 | import unittest
18 |
19 | from absl.testing import absltest
20 | import dd
21 | import graph as gh
22 | import problem as pr
23 |
24 |
25 | MAX_LEVEL = 1000
26 |
27 |
28 | class DDTest(unittest.TestCase):
29 |
30 | @classmethod
31 | def setUpClass(cls):
32 | super().setUpClass()
33 | cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
34 | cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True)
35 |
36 | def test_imo_2022_p4_should_succeed(self):
37 | p = pr.Problem.from_txt(
38 | 'a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m ='
39 | ' on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n'
40 | ' g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b,'
41 | ' on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a'
42 | ' n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q'
43 | )
44 | g, _ = gh.Graph.build_problem(p, DDTest.defs)
45 | goal_args = g.names2nodes(p.goal.args)
46 |
47 | success = False
48 | for level in range(MAX_LEVEL):
49 | added, _, _, _ = dd.bfs_one_level(g, DDTest.rules, level, p)
50 | if g.check(p.goal.name, goal_args):
51 | success = True
52 | break
53 | if not added: # saturated
54 | break
55 |
56 | self.assertTrue(success)
57 |
58 | def test_incenter_excenter_should_fail(self):
59 | p = pr.Problem.from_txt(
60 | 'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c ?'
61 | ' perp d c c e'
62 | )
63 | g, _ = gh.Graph.build_problem(p, DDTest.defs)
64 | goal_args = g.names2nodes(p.goal.args)
65 |
66 | success = False
67 | for level in range(MAX_LEVEL):
68 | added, _, _, _ = dd.bfs_one_level(g, DDTest.rules, level, p)
69 | if g.check(p.goal.name, goal_args):
70 | success = True
71 | break
72 | if not added: # saturated
73 | break
74 |
75 | self.assertFalse(success)
76 |
77 |
78 | if __name__ == '__main__':
79 | absltest.main()
80 |
--------------------------------------------------------------------------------
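
Both tests run the same saturation loop; factored out, the pattern reads as below (a sketch of the loop used in the tests above, not an API of dd.py; it assumes a built graph g, the loaded rules, and the parsed problem p):

def dd_reaches_goal(g, rules, p, max_level=1000):
  """Runs dd.bfs_one_level until the goal checks out or DD saturates."""
  goal_args = g.names2nodes(p.goal.args)
  for level in range(max_level):
    added, _, _, _ = dd.bfs_one_level(g, rules, level, p)
    if g.check(p.goal.name, goal_args):
      return True  # goal reached at this level
    if not added:
      return False  # saturated without reaching the goal
  return False
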
/ddar.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Implements the combination DD+AR."""
17 | import time
18 |
19 | from absl import logging
20 | import dd
21 | import graph as gh
22 | import problem as pr
23 | from problem import Dependency # pylint: disable=g-importing-member
24 | import trace_back
25 |
26 |
27 | def saturate_or_goal(
28 | g: gh.Graph,
29 | theorems: list[pr.Theorem],
30 | level_times: list[float],
31 | p: pr.Problem,
32 | max_level: int = 100,
33 | timeout: int = 600,
34 | ) -> tuple[
35 | list[dict[str, list[tuple[gh.Point, ...]]]],
36 | list[dict[str, list[tuple[gh.Point, ...]]]],
37 | list[int],
38 | list[pr.Dependency],
39 | ]:
40 | """Run DD until saturation or goal found."""
41 | derives = []
42 | eq4s = []
43 | branching = []
44 | all_added = []
45 |
46 | while len(level_times) < max_level:
47 | level = len(level_times) + 1
48 |
49 | t = time.time()
50 | added, derv, eq4, n_branching = dd.bfs_one_level(
51 | g, theorems, level, p, verbose=False, nm_check=True, timeout=timeout
52 | )
53 | all_added += added
54 | branching.append(n_branching)
55 |
56 | derives.append(derv)
57 | eq4s.append(eq4)
58 | level_time = time.time() - t
59 |
60 | logging.info(f'Depth {level}/{max_level} time = {level_time}') # pylint: disable=logging-fstring-interpolation
61 | level_times.append(level_time)
62 |
63 | if p.goal is not None:
64 | goal_args = list(map(lambda x: g.get(x, lambda: int(x)), p.goal.args))
65 | if g.check(p.goal.name, goal_args): # found goal
66 | break
67 |
68 | if not added: # saturated
69 | break
70 |
71 | if level_time > timeout:
72 | break
73 |
74 | return derives, eq4s, branching, all_added
75 |
76 |
77 | def solve(
78 | g: gh.Graph,
79 |     theorems: list[pr.Theorem],
80 | controller: pr.Problem,
81 | max_level: int = 1000,
82 | timeout: int = 600,
83 | ) -> tuple[gh.Graph, list[float], str, list[int], list[pr.Dependency]]:
84 | """Alternate between DD and AR until goal is found."""
85 | status = 'saturated'
86 | level_times = []
87 |
88 | dervs, eq4 = g.derive_algebra(level=0, verbose=False)
89 | derives = [dervs]
90 | eq4s = [eq4]
91 | branches = []
92 | all_added = []
93 |
94 | while len(level_times) < max_level:
95 | dervs, eq4, next_branches, added = saturate_or_goal(
96 | g, theorems, level_times, controller, max_level, timeout=timeout
97 | )
98 | all_added += added
99 |
100 | derives += dervs
101 | eq4s += eq4
102 | branches += next_branches
103 |
104 | # Now, it is either goal or saturated
105 | if controller.goal is not None:
106 | goal_args = g.names2points(controller.goal.args)
107 | if g.check(controller.goal.name, goal_args): # found goal
108 | status = 'solved'
109 | break
110 |
111 | if not derives: # officially saturated.
112 | break
113 |
114 | # Now we resort to algebra derivations.
115 | added = []
116 | while derives and not added:
117 | added += dd.apply_derivations(g, derives.pop(0))
118 |
119 | if added:
120 | continue
121 |
122 | # Final help from AR.
123 | while eq4s and not added:
124 | added += dd.apply_derivations(g, eq4s.pop(0))
125 |
126 | all_added += added
127 |
128 | if not added: # Nothing left. saturated.
129 | break
130 |
131 | return g, level_times, status, branches, all_added
132 |
133 |
134 | def get_proof_steps(
135 | g: gh.Graph, goal: pr.Clause, merge_trivials: bool = False
136 | ) -> tuple[
137 | list[pr.Dependency],
138 | list[pr.Dependency],
139 | list[tuple[list[pr.Dependency], list[pr.Dependency]]],
140 | dict[tuple[str, ...], int],
141 | ]:
142 | """Extract proof steps from the built DAG."""
143 | goal_args = g.names2nodes(goal.args)
144 | query = Dependency(goal.name, goal_args, None, None)
145 |
146 | setup, aux, log, setup_points = trace_back.get_logs(
147 | query, g, merge_trivials=merge_trivials
148 | )
149 |
150 | refs = {}
151 | setup = trace_back.point_log(setup, refs, set())
152 | aux = trace_back.point_log(aux, refs, setup_points)
153 |
154 | setup = [(prems, [tuple(p)]) for p, prems in setup]
155 | aux = [(prems, [tuple(p)]) for p, prems in aux]
156 |
157 | return setup, aux, log, refs
158 |
--------------------------------------------------------------------------------
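
Putting solve and get_proof_steps together, an end-to-end run looks roughly like this (a sketch assuming defs.txt and rules.txt are on the working path; it mirrors the tests that follow):

import ddar
import graph as gh
import problem as pr

defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True)

p = pr.Problem.from_txt(
    'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c'
    ' ? perp d c c e'
)
g, _ = gh.Graph.build_problem(p, defs)

g, level_times, status, branches, all_added = ddar.solve(g, rules, p)
if status == 'solved':
  setup, aux, log, refs = ddar.get_proof_steps(g, p.goal)
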
/ddar_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Unit tests for ddar.py."""
17 | import unittest
18 |
19 | from absl.testing import absltest
20 | import ddar
21 | import graph as gh
22 | import problem as pr
23 |
24 |
25 | class DDARTest(unittest.TestCase):
26 |
27 | @classmethod
28 | def setUpClass(cls):
29 | super().setUpClass()
30 | cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
31 | cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True)
32 |
33 | def test_orthocenter_should_fail(self):
34 | txt = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c' # pylint: disable=line-too-long
35 | p = pr.Problem.from_txt(txt)
36 | g, _ = gh.Graph.build_problem(p, DDARTest.defs)
37 |
38 | ddar.solve(g, DDARTest.rules, p, max_level=1000)
39 | goal_args = g.names2nodes(p.goal.args)
40 | self.assertFalse(g.check(p.goal.name, goal_args))
41 |
42 | def test_orthocenter_aux_should_succeed(self):
43 | txt = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c' # pylint: disable=line-too-long
44 | p = pr.Problem.from_txt(txt)
45 | g, _ = gh.Graph.build_problem(p, DDARTest.defs)
46 |
47 | ddar.solve(g, DDARTest.rules, p, max_level=1000)
48 | goal_args = g.names2nodes(p.goal.args)
49 | self.assertTrue(g.check(p.goal.name, goal_args))
50 |
51 | def test_incenter_excenter_should_succeed(self):
52 | # Note that this same problem should fail in dd_test.py
53 | p = pr.Problem.from_txt(
54 | 'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c ?'
55 | ' perp d c c e'
56 | ) # pylint: disable=line-too-long
57 | g, _ = gh.Graph.build_problem(p, DDARTest.defs)
58 |
59 | ddar.solve(g, DDARTest.rules, p, max_level=1000)
60 | goal_args = g.names2nodes(p.goal.args)
61 | self.assertTrue(g.check(p.goal.name, goal_args))
62 |
63 |
64 | if __name__ == '__main__':
65 | absltest.main()
66 |
--------------------------------------------------------------------------------
/decoder_stack.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """The decoder stack in inference mode."""
17 |
18 | from typing import Any, Tuple
19 |
20 | import gin
21 | from transformer import decoder_stack
22 | import transformer_layer as tl
23 |
24 |
25 | struct = decoder_stack.struct
26 | nn_components = decoder_stack.nn_components
27 | position = decoder_stack.position
28 | jnp = decoder_stack.jnp
29 | attention = decoder_stack.attention
30 |
31 | DStackWindowState = decoder_stack.DStackWindowState
32 |
33 | Array = Any
34 |
35 | TransformerTaskConfig = decoder_stack.TransformerTaskConfig
36 |
37 | DStackDecoderState = Tuple[tl.DecoderState, ...]
38 |
39 |
40 | @gin.configurable
41 | class DecoderStackGenerate(decoder_stack.DecoderStack):
42 | """Stack of transformer decoder layers."""
43 |
44 | layer_factory = tl.TransformerLayerGenerate
45 |
46 | def init_decoder_state_vanilla(
47 | self, sequence_length: int, start_of_sequence: Array
48 | ) -> DStackDecoderState:
49 | """Return initial state for autoregressive generation."""
50 | return tuple(
51 | [
52 | layer.init_decoder_state_vanilla(sequence_length, start_of_sequence)
53 | for layer in self.transformer_layers
54 | ]
55 | )
56 |
--------------------------------------------------------------------------------
/defs.txt:
--------------------------------------------------------------------------------
1 | angle_bisector x a b c
2 | x : a b c x
3 | a b c = ncoll a b c
4 | x : eqangle b a b x b x b c
5 | bisect a b c
6 |
7 | angle_mirror x a b c
8 | x : a b c x
9 | a b c = ncoll a b c
10 | x : eqangle b a b c b c b x
11 | amirror a b c
12 |
13 | circle x a b c
14 | x : a b c
15 | a b c = ncoll a b c
16 | x : cong x a x b, cong x b x c
17 | bline a b, bline a c
18 |
19 | circumcenter x a b c
20 | x : a b c
21 | a b c = ncoll a b c
22 | x : cong x a x b, cong x b x c
23 | bline a b, bline a c
24 |
25 | eq_quadrangle a b c d
26 | d : a b c d
27 | =
28 | a : ; b : ; c : ; d : cong d a b c
29 | eq_quadrangle
30 |
31 | eq_trapezoid a b c d
32 | d : a b c
33 | =
34 | a : ; b : ; c : ; d : para d c a b, cong d a b c
35 | eq_trapezoid
36 |
37 | eq_triangle x b c
38 | x : b c
39 | b c = diff b c
40 | x : cong x b b c, cong b c c x; eqangle b x b c c b c x, eqangle x c x b b x b c
41 | circle b b c, circle c b c
42 |
43 | eqangle2 x a b c
44 | x : a b c x
45 | a b c = ncoll a b c
46 | x : eqangle a b a x c x c b
47 | eqangle2 a b c
48 |
49 | eqdia_quadrangle a b c d
50 | d : a b c d
51 | =
52 | a : ; b : ; c : ; d : cong d b a c
53 | eqdia_quadrangle
54 |
55 | eqdistance x a b c
56 | x : a b c x
57 | a b c = diff b c
58 | x : cong x a b c
59 | circle a b c
60 |
61 | foot x a b c
62 | x : a b c
63 | a b c = ncoll a b c
64 | x : perp x a b c, coll x b c
65 | tline a b c, line b c
66 |
67 | free a
68 | a : a
69 | =
70 | a :
71 | free
72 |
73 | incenter x a b c
74 | x : a b c
75 | a b c = ncoll a b c
76 | x : eqangle a b a x a x a c, eqangle c a c x c x c b; eqangle b c b x b x b a
77 | bisect a b c, bisect b c a
78 |
79 | incenter2 x y z i a b c
80 | i : a b c, x : i b c, y : i c a, z : i a b
81 | a b c = ncoll a b c
82 | i : eqangle a b a i a i a c, eqangle c a c i c i c b; eqangle b c b i b i b a; x : coll x b c, perp i x b c; y : coll y c a, perp i y c a; z : coll z a b, perp i z a b; cong i x i y, cong i y i z
83 | incenter2 a b c
84 |
85 | excenter x a b c
86 | x : a b c
87 | a b c = ncoll a b c
88 | x : eqangle a b a x a x a c, eqangle c a c x c x c b; eqangle b c b x b x b a
89 | bisect b a c, exbisect b c a
90 |
91 | excenter2 x y z i a b c
92 | i : a b c, x : i b c, y : i c a, z : i a b
93 | a b c = ncoll a b c
94 | i : eqangle a b a i a i a c, eqangle c a c i c i c b; eqangle b c b i b i b a; x : coll x b c, perp i x b c; y : coll y c a, perp i y c a; z : coll z a b, perp i z a b; cong i x i y, cong i y i z
95 | excenter2 a b c
96 |
97 | centroid x y z i a b c
98 | x : b c, y : c a, z : a b, i : a x b y
99 | a b c = ncoll a b c
100 | x : coll x b c, cong x b x c; y : coll y c a, cong y c y a; z : coll z a b, cong z a z b; i : coll a x i, coll b y i; coll c z i
101 | centroid a b c
102 |
103 | ninepoints x y z i a b c
104 | x : b c, y : c a, z : a b, i : x y z
105 | a b c = ncoll a b c
106 | x : coll x b c, cong x b x c; y : coll y c a, cong y c y a; z : coll z a b, cong z a z b; i : cong i x i y, cong i y i z
107 | ninepoints a b c
108 |
109 | intersection_cc x o w a
110 | x : o w a
111 | o w a = ncoll o w a
112 | x : cong o a o x, cong w a w x
113 | circle o o a, circle w w a
114 |
115 | intersection_lc x a o b
116 | x : a o b
117 | a o b = diff a b, diff o b, nperp b o b a
118 | x : coll x a b, cong o b o x
119 | line b a, circle o o b
120 |
121 | intersection_ll x a b c d
122 | x : a b c d
123 | a b c d = npara a b c d, ncoll a b c d
124 | x : coll x a b, coll x c d
125 | line a b, line c d
126 |
127 | intersection_lp x a b c m n
128 | x : a b c m n
129 | a b c m n = npara m n a b, ncoll a b c, ncoll c m n
130 | x : coll x a b, para c x m n
131 | line a b, pline c m n
132 |
133 | intersection_lt x a b c d e
134 | x : a b c d e
135 | a b c d e = ncoll a b c, nperp a b d e
136 | x : coll x a b, perp x c d e
137 | line a b, tline c d e
138 |
139 | intersection_pp x a b c d e f
140 | x : a b c d e f
141 | a b c d e f = diff a d, npara b c e f
142 | x : para x a b c, para x d e f
143 | pline a b c, pline d e f
144 |
145 | intersection_tt x a b c d e f
146 | x : a b c d e f
147 | a b c d e f = diff a d, npara b c e f
148 | x : perp x a b c, perp x d e f
149 | tline a b c, tline d e f
150 |
151 | iso_triangle a b c
152 | c : a b c
153 | =
154 | a : ; b : ; c : eqangle b a b c c b c a, cong a b a c
155 | isos
156 |
157 | lc_tangent x a o
158 | x : x a o
159 | a o = diff a o
160 | x : perp a x a o
161 | tline a a o
162 |
163 | midpoint x a b
164 | x : a b
165 | a b = diff a b
166 | x : coll x a b, cong x a x b
167 | midp a b
168 |
169 | mirror x a b
170 | x : a b
171 | a b = diff a b
172 | x : coll x a b, cong b a b x
173 | pmirror a b
174 |
175 | nsquare x a b
176 | x : a b
177 | a b = diff a b
178 | x : cong x a a b, perp x a a b
179 | rotaten90 a b
180 |
181 | on_aline x a b c d e
182 | x : x a b c d e
183 | a b c d e = ncoll c d e
184 | x : eqangle a x a b d c d e
185 | aline e d c b a
186 |
187 | on_aline2 x a b c d e
188 | x : x a b c d e
189 | a b c d e = ncoll c d e
190 | x : eqangle x a x b d c d e
191 | aline2 e d c b a
192 |
193 | on_bline x a b
194 | x : x a b
195 | a b = diff a b
196 | x : cong x a x b, eqangle a x a b b a b x
197 | bline a b
198 |
199 | on_circle x o a
200 | x : x o a
201 | o a = diff o a
202 | x : cong o x o a
203 | circle o o a
204 |
205 | on_line x a b
206 | x : x a b
207 | a b = diff a b
208 | x : coll x a b
209 | line a b
210 |
211 | on_pline x a b c
212 | x : x a b c
213 | a b c = diff b c, ncoll a b c
214 | x : para x a b c
215 | pline a b c
216 |
217 | on_tline x a b c
218 | x : x a b c
219 | a b c = diff b c
220 | x : perp x a b c
221 | tline a b c
222 |
223 | orthocenter x a b c
224 | x : a b c
225 | a b c = ncoll a b c
226 | x : perp x a b c, perp x b c a; perp x c a b
227 | tline a b c, tline b c a
228 |
229 | parallelogram a b c x
230 | x : a b c
231 | a b c = ncoll a b c
232 | x : para a b c x, para a x b c; cong a b c x, cong a x b c
233 | pline a b c, pline c a b
234 |
235 | pentagon a b c d e
236 |
237 | =
238 | a : ; b : ; c : ; d : ; e :
239 | pentagon
240 |
241 | psquare x a b
242 | x : a b
243 | a b = diff a b
244 | x : cong x a a b, perp x a a b
245 | rotatep90 a b
246 |
247 | quadrangle a b c d
248 |
249 | =
250 | a : ; b : ; c : ; d :
251 | quadrangle
252 |
253 | r_trapezoid a b c d
254 | d : a b c
255 | =
256 | a : ; b : ; c : ; d : para a b c d, perp a b a d
257 | r_trapezoid
258 |
259 | r_triangle a b c
260 | c : a b c
261 | =
262 | a : ; b : ; c : perp a b a c
263 | r_triangle
264 |
265 | rectangle a b c d
266 | c : a b c , d : a b c
267 | =
268 | a : ; b : ; c : perp a b b c ; d : para a b c d, para a d b c; perp a b a d, cong a b c d, cong a d b c, cong a c b d
269 | rectangle
270 |
271 | reflect x a b c
272 | x : a b c
273 | a b c = diff b c, ncoll a b c
274 | x : cong b a b x, cong c a c x; perp b c a x
275 | reflect a b c
276 |
277 | risos a b c
278 | c : a b
279 | =
280 | a : ; b : ; c : perp a b a c, cong a b a c; eqangle b a b c c b c a
281 | risos
282 |
283 | s_angle a b x y
284 | x : a b x
285 | a b = diff a b
286 | x : s_angle a b x y
287 | s_angle a b y
288 |
289 | segment a b
290 |
291 | =
292 | a : ; b :
293 | segment
294 |
295 | shift x b c d
296 | x : b c d
297 | b c d = diff d b
298 | x : cong x b c d, cong x c b d
299 | shift d c b
300 |
301 | square a b x y
302 | x : a b, y : a b x
303 | a b = diff a b
304 | x : perp a b b x, cong a b b x; y : para a b x y, para a y b x; perp a y y x, cong b x x y, cong x y y a, perp a x b y, cong a x b y
305 | square a b
306 |
307 | isquare a b c d
308 | c : a b , d : a b c
309 | =
310 | a : ; b : ; c : perp a b b c, cong a b b c; d : para a b c d, para a d b c; perp a d d c, cong b c c d, cong c d d a, perp a c b d, cong a c b d
311 | isquare
312 |
313 | trapezoid a b c d
314 | d : a b c d
315 | =
316 | a : ; b : ; c : ; d : para a b c d
317 | trapezoid
318 |
319 | triangle a b c
320 |
321 | =
322 | a : ; b : ; c :
323 | triangle
324 |
325 | triangle12 a b c
326 | c : a b c
327 | =
328 | a : ; b : ; c : rconst a b a c 1 2
329 | triangle12
330 |
331 | 2l1c x y z i a b c o
332 | x : a b c o y z i, y : a b c o x z i, z : a b c o x y i, i : a b c o x y z
333 | a b c o = cong o a o b, ncoll a b c
334 | x y z i : coll x a c, coll y b c, cong o a o z, coll i o z, cong i x i y, cong i y i z, perp i x a c, perp i y b c
335 | 2l1c a b c o
336 |
337 | e5128 x y a b c d
338 | x : a b c d y, y : a b c d x
339 | a b c d = cong c b c d, perp b c b a
340 | x y : cong c b c x, coll y a b, coll x y d, eqangle a b a d x a x y
341 | e5128 a b c d
342 |
343 | 3peq x y z a b c
344 | z : b c z , x : a b c z y, y : a b c z x
345 | a b c = ncoll a b c
346 | z : coll z b c ; x y : coll x a b, coll y a c, coll x y z, cong z x z y
347 | 3peq a b c
348 |
349 | trisect x y a b c
350 | x : a b c y, y : a b c x
351 | a b c = ncoll a b c
352 | x y : coll x a c, coll y a c, eqangle b a b x b x b y, eqangle b x b y b y b c
353 | trisect a b c
354 |
355 | trisegment x y a b
356 | x : a b y, y : a b x
357 | a b = diff a b
358 | x y : coll x a b, coll y a b, cong x a x y, cong y x y b
359 | trisegment a b
360 |
361 | on_dia x a b
362 | x : x a b
363 | a b = diff a b
364 | x : perp x a x b
365 | dia a b
366 |
367 | ieq_triangle a b c
368 | c : a b
369 | =
370 | a : ; b : ; c : cong a b b c, cong b c c a; eqangle a b a c c a c b, eqangle c a c b b c b a
371 | ieq_triangle
372 |
373 | on_opline x a b
374 | x : x a b
375 | a b = diff a b
376 | x : coll x a b
377 | on_opline a b
378 |
379 | cc_tangent0 x y o a w b
380 | x : o a w b y, y : o a w b x
381 | o a w b = diff o a, diff w b, diff o w
382 | x y : cong o x o a, cong w y w b, perp x o x y, perp y w y x
383 | cc_tangent0 o a w b
384 |
385 | cc_tangent x y z i o a w b
386 | x : o a w b y, y : o a w b x, z : o a w b i, i : o a w b z
387 | o a w b = diff o a, diff w b, diff o w
388 | x y : cong o x o a, cong w y w b, perp x o x y, perp y w y x; z i : cong o z o a, cong w i w b, perp z o z i, perp i w i z
389 | cc_tangent o a w b
390 |
391 | eqangle3 x a b d e f
392 | x : x a b d e f
393 | a b d e f = ncoll d e f, diff a b, diff d e, diff e f
394 | x : eqangle x a x b d e d f
395 | eqangle3 a b d e f
396 |
397 | tangent x y a o b
398 | x y : o a b
399 | a o b = diff o a, diff o b, diff a b
400 | x : cong o x o b, perp a x o x; y : cong o y o b, perp a y o y
401 | tangent a o b
402 |
403 | on_circum x a b c
404 | x : a b c
405 | a b c = ncoll a b c
406 | x : cyclic a b c x
407 | cyclic a b c
408 |
--------------------------------------------------------------------------------
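
Each construction above appears to follow a five-line pattern: the signature, which arguments each new point depends on, the non-degeneracy preconditions, the predicates the construction guarantees, and the numerical recipe used to place the points. The test files load the whole file into a dictionary keyed by construction name:

import problem as pr

defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
incenter = defs['incenter']  # look up one construction by name (illustrative)
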
/download.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | gdown --folder https://bit.ly/alphageometry
17 | export DATA=ag_ckpt_vocab
18 |
--------------------------------------------------------------------------------
/examples.txt:
--------------------------------------------------------------------------------
1 | orthocenter
2 | a b c = triangle; h = on_tline b a c, on_tline c a b ? perp a h b c
3 | orthocenter_aux
4 | a b c = triangle; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c
5 | incenter_excenter
6 | a b c = triangle a b c; d1 d2 d3 d = incenter2 a b c; e1 e2 e3 e = excenter2 a b c ? perp d c c e
7 | euler
8 | a b c = triangle a b c; h = orthocenter a b c; h1 = foot a b c; h2 = foot b c a; h3 = foot c a b; g1 g2 g3 g = centroid g1 g2 g3 g a b c; o = circle a b c ? coll h g o
9 |
--------------------------------------------------------------------------------
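
The file alternates one-line problem names with one-line problem statements. It can be parsed with the same problem.py entry point the tests use (a sketch; alphageometry.py may load it differently):

import problem as pr

lines = [l for l in open('examples.txt').read().splitlines() if l.strip()]
problems = {
    lines[i]: pr.Problem.from_txt(lines[i + 1])
    for i in range(0, len(lines), 2)
}
# problems.keys(): 'orthocenter', 'orthocenter_aux', 'incenter_excenter', 'euler'
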
/geometry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Implements geometric objects used in the graph representation."""
17 | from __future__ import annotations
18 | from collections import defaultdict # pylint: disable=g-importing-member
19 | from typing import Any, Type
20 |
21 | # pylint: disable=protected-access
22 |
23 |
24 | class Node:
25 | r"""Node in the proof state graph.
26 |
27 | Can be Point, Line, Circle, etc.
28 |
29 |   Each node maintains a merge history with
30 |   other nodes once they are found to be equivalent.
31 |
32 | a -> b -
33 | \
34 | c -> d -> e -> f -> g
35 |
36 | d.merged_to = e
37 | d.rep = g
38 | d.merged_from = {a, b, c, d}
39 | d.equivs = {a, b, c, d, e, f, g}
40 | """
41 |
42 | def __init__(self, name: str = '', graph: Any = None):
43 | self.name = name or str(self)
44 | self.graph = graph
45 |
46 | self.edge_graph = {}
47 |     # Edge graph: which other nodes are connected to this node.
48 | # edge graph = {
49 | # other1: {self1: deps, self2: deps},
50 | # other2: {self2: deps, self3: deps}
51 | # }
52 |
53 | self.merge_graph = {}
54 | # Merge graph: history of merges with other nodes.
55 | # merge_graph = {self1: {self2: deps1, self3: deps2}}
56 |
57 | self.rep_by = None # represented by.
58 | self.members = {self}
59 |
60 | self._val = None
61 | self._obj = None
62 |
63 | self.deps = []
64 |
65 | # numerical representation.
66 | self.num = None
67 |     self.change = set()  # which other nodes' num depends on this node?
68 |
69 | def set_rep(self, node: Node) -> None:
70 | if node == self:
71 | return
72 | self.rep_by = node
73 | node.merge_edge_graph(self.edge_graph)
74 | node.members.update(self.members)
75 |
76 | def rep(self) -> Node:
77 | x = self
78 | while x.rep_by:
79 | x = x.rep_by
80 | return x
81 |
82 | def why_rep(self) -> list[Any]:
83 | return self.why_equal([self.rep()], None)
84 |
85 | def rep_and_why(self) -> tuple[Node, list[Any]]:
86 | rep = self.rep()
87 | return rep, self.why_equal([rep], None)
88 |
89 | def neighbors(
90 | self, oftype: Type[Node], return_set: bool = False, do_rep: bool = True
91 | ) -> list[Node]:
92 | """Neighbors of this node in the proof state graph."""
93 | if do_rep:
94 | rep = self.rep()
95 | else:
96 | rep = self
97 | result = set()
98 |
99 | for n in rep.edge_graph:
100 | if oftype is None or oftype and isinstance(n, oftype):
101 | if do_rep:
102 | result.add(n.rep())
103 | else:
104 | result.add(n)
105 |
106 | if return_set:
107 | return result
108 | return list(result)
109 |
110 | def merge_edge_graph(
111 | self, new_edge_graph: dict[Node, dict[Node, list[Node]]]
112 | ) -> None:
113 | for x, xdict in new_edge_graph.items():
114 | if x in self.edge_graph:
115 | self.edge_graph[x].update(dict(xdict))
116 | else:
117 | self.edge_graph[x] = dict(xdict)
118 |
119 | def merge(self, nodes: list[Node], deps: list[Any]) -> None:
120 | for node in nodes:
121 | self.merge_one(node, deps)
122 |
123 | def merge_one(self, node: Node, deps: list[Any]) -> None:
124 | node.rep().set_rep(self.rep())
125 |
126 | if node in self.merge_graph:
127 | return
128 |
129 | self.merge_graph[node] = deps
130 | node.merge_graph[self] = deps
131 |
132 | def is_val(self, node: Node) -> bool:
133 | return (
134 | isinstance(self, Line)
135 | and isinstance(node, Direction)
136 | or isinstance(self, Segment)
137 | and isinstance(node, Length)
138 | or isinstance(self, Angle)
139 | and isinstance(node, Measure)
140 | or isinstance(self, Ratio)
141 | and isinstance(node, Value)
142 | )
143 |
144 | def set_val(self, node: Node) -> None:
145 | self._val = node
146 |
147 | def set_obj(self, node: Node) -> None:
148 | self._obj = node
149 |
150 | @property
151 | def val(self) -> Node:
152 | if self._val is None:
153 | return None
154 | return self._val.rep()
155 |
156 | @property
157 | def obj(self) -> Node:
158 | if self._obj is None:
159 | return None
160 | return self._obj.rep()
161 |
162 | def equivs(self) -> set[Node]:
163 | return self.rep().members
164 |
165 | def connect_to(self, node: Node, deps: list[Any] = None) -> None:
166 | rep = self.rep()
167 |
168 | if node in rep.edge_graph:
169 | rep.edge_graph[node].update({self: deps})
170 | else:
171 | rep.edge_graph[node] = {self: deps}
172 |
173 | if self.is_val(node):
174 | self.set_val(node)
175 | node.set_obj(self)
176 |
177 | def equivs_upto(self, level: int) -> dict[Node, Node]:
178 | """What are the equivalent nodes up to a certain level."""
179 | parent = {self: None}
180 | visited = set()
181 | queue = [self]
182 | i = 0
183 |
184 | while i < len(queue):
185 | current = queue[i]
186 | i += 1
187 | visited.add(current)
188 |
189 | for neighbor in current.merge_graph:
190 | if (
191 | level is not None
192 | and current.merge_graph[neighbor].level is not None
193 | and current.merge_graph[neighbor].level >= level
194 | ):
195 | continue
196 | if neighbor not in visited:
197 | queue.append(neighbor)
198 | parent[neighbor] = current
199 |
200 | return parent
201 |
202 | def why_equal(self, others: list[Node], level: int) -> list[Any]:
203 | """BFS why this node is equal to other nodes."""
204 | others = set(others)
205 | found = 0
206 |
207 | parent = {}
208 | queue = [self]
209 | i = 0
210 |
211 | while i < len(queue):
212 | current = queue[i]
213 | if current in others:
214 | found += 1
215 | if found == len(others):
216 | break
217 |
218 | i += 1
219 |
220 | for neighbor in current.merge_graph:
221 | if (
222 | level is not None
223 | and current.merge_graph[neighbor].level is not None
224 | and current.merge_graph[neighbor].level >= level
225 | ):
226 | continue
227 | if neighbor not in parent:
228 | queue.append(neighbor)
229 | parent[neighbor] = current
230 |
231 | return bfs_backtrack(self, others, parent)
232 |
233 | def why_equal_groups(
234 | self, groups: list[list[Node]], level: int
235 | ) -> tuple[list[Any], list[Node]]:
236 | """BFS for why self is equal to at least one member of each group."""
237 | others = [None for _ in groups]
238 | found = 0
239 |
240 | parent = {}
241 | queue = [self]
242 | i = 0
243 |
244 | while i < len(queue):
245 | current = queue[i]
246 |
247 | for j, grp in enumerate(groups):
248 | if others[j] is None and current in grp:
249 | others[j] = current
250 | found += 1
251 |
252 | if found == len(others):
253 | break
254 |
255 | i += 1
256 |
257 | for neighbor in current.merge_graph:
258 | if (
259 | level is not None
260 | and current.merge_graph[neighbor].level is not None
261 | and current.merge_graph[neighbor].level >= level
262 | ):
263 | continue
264 | if neighbor not in parent:
265 | queue.append(neighbor)
266 | parent[neighbor] = current
267 |
268 | return bfs_backtrack(self, others, parent), others
269 |
270 | def why_val(self, level: int) -> list[Any]:
271 | return self._val.why_equal([self.val], level)
272 |
273 | def why_connect(self, node: Node, level: int = None) -> list[Any]:
274 | rep = self.rep()
275 | equivs = list(rep.edge_graph[node].keys())
276 | if not equivs:
277 | return None
278 | equiv = equivs[0]
279 | dep = rep.edge_graph[node][equiv]
280 |     return [dep] + self.why_equal([equiv], level)
281 |
282 |
283 | def why_connect(*pairs: tuple[Node, Node]) -> list[Any]:
284 | result = []
285 | for node1, node2 in pairs:
286 | result += node1.why_connect(node2)
287 | return result
288 |
289 |
290 | def is_equiv(x: Node, y: Node, level: int = None) -> bool:
291 | level = level or float('inf')
292 | return x.why_equal([y], level) is not None
293 |
294 |
295 | def is_equal(x: Node, y: Node, level: int = None) -> bool:
296 | if x == y:
297 | return True
298 | if x._val is None or y._val is None:
299 | return False
300 | if x.val != y.val:
301 | return False
302 | return is_equiv(x._val, y._val, level)
303 |
304 |
305 | def bfs_backtrack(
306 | root: Node, leafs: list[Node], parent: dict[Node, Node]
307 | ) -> list[Any]:
308 | """Return the path given BFS trace of parent nodes."""
309 | backtracked = {root} # no need to backtrack further when touching this set.
310 | deps = []
311 | for node in leafs:
312 | if node is None:
313 | return None
314 | if node in backtracked:
315 | continue
316 | if node not in parent:
317 | return None
318 | while node not in backtracked:
319 | backtracked.add(node)
320 | deps.append(node.merge_graph[parent[node]])
321 | node = parent[node]
322 |
323 | return deps
324 |
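# A minimal standalone sketch of the parent-pointer backtracking that
# why_equal() and bfs_backtrack() implement above. Plain strings stand in
# for Node objects and string labels for dependency objects; all names are
# hypothetical:
#
#   parent = {'b': 'a', 'c': 'b', 'd': 'c'}  # BFS tree rooted at 'a'
#   reason = {('b', 'a'): 'fact1', ('c', 'b'): 'fact2', ('d', 'c'): 'fact3'}
#
#   def why(leaf):
#     deps = []
#     node = leaf
#     while node in parent:  # walk child -> parent until the root
#       deps.append(reason[(node, parent[node])])
#       node = parent[node]
#     return deps
#
#   assert why('d') == ['fact3', 'fact2', 'fact1']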
325 |
326 | class Point(Node):
327 | pass
328 |
329 |
330 | class Line(Node):
331 | """Node of type Line."""
332 |
333 | def new_val(self) -> Direction:
334 | return Direction()
335 |
336 | def why_coll(self, points: list[Point], level: int = None) -> list[Any]:
337 | """Why points are connected to self."""
338 | level = level or float('inf')
339 |
340 | groups = []
341 | for p in points:
342 | group = [
343 | l
344 | for l, d in self.edge_graph[p].items()
345 | if d is None or d.level < level
346 | ]
347 | if not group:
348 | return None
349 | groups.append(group)
350 |
351 | min_deps = None
352 | for line in groups[0]:
353 | deps, others = line.why_equal_groups(groups[1:], level)
354 | if deps is None:
355 | continue
356 | for p, o in zip(points, [line] + others):
357 | deps.append(self.edge_graph[p][o])
358 | if min_deps is None or len(deps) < len(min_deps):
359 | min_deps = deps
360 |
361 | if min_deps is None:
362 | return None
363 | return [d for d in min_deps if d is not None]
364 |
365 |
366 | class Segment(Node):
367 |
368 | def new_val(self) -> Length:
369 | return Length()
370 |
371 |
372 | class Circle(Node):
373 | """Node of type Circle."""
374 |
375 | def why_cyclic(self, points: list[Point], level: int = None) -> list[Any]:
376 | """Why points are connected to self."""
377 | level = level or float('inf')
378 |
379 | groups = []
380 | for p in points:
381 | group = [
382 | c
383 | for c, d in self.edge_graph[p].items()
384 | if d is None or d.level < level
385 | ]
386 | if not group:
387 | return None
388 | groups.append(group)
389 |
390 | min_deps = None
391 | for circle in groups[0]:
392 | deps, others = circle.why_equal_groups(groups[1:], level)
393 | if deps is None:
394 | continue
395 | for p, o in zip(points, [circle] + others):
396 | deps.append(self.edge_graph[p][o])
397 |
398 | if min_deps is None or len(deps) < len(min_deps):
399 | min_deps = deps
400 |
401 | if min_deps is None:
402 | return None
403 | return [d for d in min_deps if d is not None]
404 |
405 |
406 | def why_equal(x: Node, y: Node, level: int = None) -> list[Any]:
407 | if x == y:
408 | return []
409 | if not x._val or not y._val:
410 | return None
411 | if x._val == y._val:
412 | return []
413 | return x._val.why_equal([y._val], level)
414 |
415 |
416 | class Direction(Node):
417 | pass
418 |
419 |
420 | def get_lines_thru_all(*points: Point) -> list[Line]:
421 | line2count = defaultdict(lambda: 0)
422 | points = set(points)
423 | for p in points:
424 | for l in p.neighbors(Line):
425 | line2count[l] += 1
426 | return [l for l, count in line2count.items() if count == len(points)]
427 |
428 |
429 | def line_of_and_why(
430 | points: list[Point], level: int = None
431 | ) -> tuple[Line, list[Any]]:
432 | """Why points are collinear."""
433 | for l0 in get_lines_thru_all(*points):
434 | for l in l0.equivs():
435 | if all([p in l.edge_graph for p in points]):
436 | x, y = l.points
437 | colls = list({x, y} | set(points))
438 | # if len(colls) < 3:
439 | # return l, []
440 | why = l.why_coll(colls, level)
441 | if why is not None:
442 | return l, why
443 |
444 | return None, None
445 |
446 |
447 | def get_circles_thru_all(*points: Point) -> list[Circle]:
448 | circle2count = defaultdict(lambda: 0)
449 | points = set(points)
450 | for p in points:
451 | for c in p.neighbors(Circle):
452 | circle2count[c] += 1
453 | return [c for c, count in circle2count.items() if count == len(points)]
454 |
455 |
456 | def circle_of_and_why(
457 | points: list[Point], level: int = None
458 | ) -> tuple[Circle, list[Any]]:
459 | """Why points are concyclic."""
460 | for c0 in get_circles_thru_all(*points):
461 | for c in c0.equivs():
462 | if all([p in c.edge_graph for p in points]):
463 | cycls = list(set(points))
464 | why = c.why_cyclic(cycls, level)
465 | if why is not None:
466 | return c, why
467 |
468 | return None, None
469 |
470 |
471 | def name_map(struct: Any) -> Any:
472 | if isinstance(struct, list):
473 | return [name_map(x) for x in struct]
474 | elif isinstance(struct, tuple):
475 | return tuple([name_map(x) for x in struct])
476 | elif isinstance(struct, set):
477 | return set([name_map(x) for x in struct])
478 | elif isinstance(struct, dict):
479 | return {name_map(x): name_map(y) for x, y in struct.items()}
480 | else:
481 | return getattr(struct, 'name', '')
482 |
483 |
484 | class Angle(Node):
485 | """Node of type Angle."""
486 |
487 | def new_val(self) -> Measure:
488 | return Measure()
489 |
490 | def set_directions(self, d1: Direction, d2: Direction) -> None:
491 | self._d = d1, d2
492 |
493 | @property
494 | def directions(self) -> tuple[Direction, Direction]:
495 | d1, d2 = self._d
496 | if d1 is None or d2 is None:
497 | return d1, d2
498 | return d1.rep(), d2.rep()
499 |
500 |
501 | class Measure(Node):
502 | pass
503 |
504 |
505 | class Length(Node):
506 | pass
507 |
508 |
509 | class Ratio(Node):
510 | """Node of type Ratio."""
511 |
512 | def new_val(self) -> Value:
513 | return Value()
514 |
515 | def set_lengths(self, l1: Length, l2: Length) -> None:
516 | self._l = l1, l2
517 |
518 | @property
519 | def lengths(self) -> tuple[Length, Length]:
520 | l1, l2 = self._l
521 | if l1 is None or l2 is None:
522 | return l1, l2
523 | return l1.rep(), l2.rep()
524 |
525 |
526 | class Value(Node):
527 | pass
528 |
529 |
530 | def all_angles(
531 | d1: Direction, d2: Direction, level: int = None
532 | ) -> tuple[Angle, list[Direction], list[Direction]]:
533 | level = level or float('inf')
534 | d1s = d1.equivs_upto(level)
535 | d2s = d2.equivs_upto(level)
536 |
537 | for ang in d1.rep().neighbors(Angle):
538 | d1_, d2_ = ang._d
539 | if d1_ in d1s and d2_ in d2s:
540 | yield ang, d1s, d2s
541 |
542 |
543 | def all_ratios(
544 |     d1: Length, d2: Length, level: int = None
545 | ) -> tuple[Ratio, list[Length], list[Length]]:
546 |   level = level or float('inf')
547 |   d1s = d1.equivs_upto(level)
548 |   d2s = d2.equivs_upto(level)
549 |
550 |   for ratio in d1.rep().neighbors(Ratio):
551 |     d1_, d2_ = ratio._l
552 |     if d1_ in d1s and d2_ in d2s:
553 |       yield ratio, d1s, d2s
554 |
555 |
556 | RANKING = {
557 | Point: 0,
558 | Line: 1,
559 | Segment: 2,
560 | Circle: 3,
561 | Direction: 4,
562 | Length: 5,
563 | Angle: 6,
564 | Ratio: 7,
565 | Measure: 8,
566 | Value: 9,
567 | }
568 |
569 |
570 | def val_type(x: Node) -> Type[Node]:
571 | if isinstance(x, Line):
572 | return Direction
573 | if isinstance(x, Segment):
574 | return Length
575 | if isinstance(x, Angle):
576 | return Measure
577 | if isinstance(x, Ratio):
578 | return Value
579 |
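# A small usage sketch for RANKING and val_type (hedged: assumes this module
# is importable as `geometry`, as geometry_test.py below does):
#
#   import geometry as gm
#
#   nodes = [gm.Angle('ang'), gm.Point('p'), gm.Line('l')]
#   nodes.sort(key=lambda n: gm.RANKING[type(n)])
#   print([type(n).__name__ for n in nodes])  # ['Point', 'Line', 'Angle']
#   print(gm.val_type(gm.Line('l')))          # <class 'geometry.Direction'>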
--------------------------------------------------------------------------------
/geometry_150M_generate.gin:
--------------------------------------------------------------------------------
1 | NUM_EMBEDDINGS = 1024
2 |
3 | # Number of parameters = 152M
4 | NUM_LAYERS = 12
5 | EMBED_DIM = 1024
6 | NUM_HEADS = 8
7 | HEAD_DIM = 128
8 | MLP_DIM = 4096
9 |
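# A rough cross-check of the 152M figure above, assuming a standard decoder
# block (Q/K/V/O projections plus a two-matrix MLP; biases and layernorms
# are negligible and ignored):
#   attention per layer: 4 * EMBED_DIM * (NUM_HEADS * HEAD_DIM)
#                        = 4 * 1024 * 1024 = 4,194,304
#   MLP per layer:       2 * EMBED_DIM * MLP_DIM
#                        = 2 * 1024 * 4096 = 8,388,608
#   12 layers:           12 * 12,582,912   = 150,994,944
#   token embeddings:    NUM_EMBEDDINGS * EMBED_DIM = 1,048,576
#   total:               ~152,043,520, i.e. the 152M quoted above.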
10 |
11 | transformer_layer.TransformerLayerGenerate:
12 | num_heads = %NUM_HEADS
13 | head_size = %HEAD_DIM
14 | window_length = 1024
15 | use_long_xl_architecture = False
16 | max_unrolled_windows = -1 # Always unroll.
17 | relative_position_type = "t5" # Can be "fourier", "t5", or None.
18 | use_causal_mask = True
19 | attn_dropout_rate = %ATTN_DROPOUT_RATE # Attention matrix dropout.
20 | memory_num_neighbors = 0
21 | dtype = %DTYPE
22 |
23 | decoder_stack.DecoderStackGenerate:
24 | num_layers = %NUM_LAYERS
25 | embedding_size = %EMBED_DIM
26 | embedding_stddev = 1.0
27 | layer_factory = @transformer_layer.TransformerLayerGenerate
28 | dstack_window_length = 0
29 | use_absolute_positions = False
30 | use_final_layernorm = True # Final layernorm before token lookup.
31 | final_dropout_rate = %DROPOUT_RATE # Dropout before token lookup.
32 | final_mlp_factory = None # Final MLP to predict target tokens.
33 | recurrent_layer_indices = ()
34 | memory_factory = None # e.g. @memory_factory.memory_on_tpu_factory
35 | memory_layer_indices = ()
36 | dtype = %DTYPE
37 |
38 |
39 | models.DecoderOnlyLanguageModelGenerate:
40 | num_heads = %NUM_HEADS
41 | head_size = %HEAD_DIM
42 | task_config = @decoder_stack.TransformerTaskConfig()
43 | decoder_factory = @decoder_stack.DecoderStackGenerate
44 |
45 |
46 | training_loop.Trainer:
47 | model_definition = @models.DecoderOnlyLanguageModelGenerate
48 |
--------------------------------------------------------------------------------
/geometry_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Unit tests for geometry.py."""
17 | import unittest
18 |
19 | from absl.testing import absltest
20 | import geometry as gm
21 |
22 |
23 | class GeometryTest(unittest.TestCase):
24 |
25 | def _setup_equality_example(self):
26 | # Create 4 nodes a, b, c, d
27 | # and their lengths
28 | a = gm.Segment('a')
29 | la = gm.Length('l(a)')
30 | a.connect_to(la)
31 | la.connect_to(a)
32 |
33 | b = gm.Segment('b')
34 | lb = gm.Length('l(b)')
35 | b.connect_to(lb)
36 | lb.connect_to(b)
37 |
38 | c = gm.Segment('c')
39 | lc = gm.Length('l(c)')
40 | c.connect_to(lc)
41 | lc.connect_to(c)
42 |
43 | d = gm.Segment('d')
44 | ld = gm.Length('l(d)')
45 | d.connect_to(ld)
46 | ld.connect_to(d)
47 |
48 | # Now let a=b, b=c, a=c, c=d
49 | la.merge([lb], 'fact1')
50 | lb.merge([lc], 'fact2')
51 | la.merge([lc], 'fact3')
52 | lc.merge([ld], 'fact4')
53 | return a, b, c, d, la, lb, lc, ld
54 |
55 | def test_merged_node_representative(self):
56 | _, _, _, _, la, lb, lc, ld = self._setup_equality_example()
57 |
58 | # all nodes are now represented by la.
59 | self.assertEqual(la.rep(), la)
60 | self.assertEqual(lb.rep(), la)
61 | self.assertEqual(lc.rep(), la)
62 | self.assertEqual(ld.rep(), la)
63 |
64 | def test_merged_node_equivalence(self):
65 | _, _, _, _, la, lb, lc, ld = self._setup_equality_example()
66 | # all la, lb, lc, ld are equivalent
67 | self.assertCountEqual(la.equivs(), [la, lb, lc, ld])
68 | self.assertCountEqual(lb.equivs(), [la, lb, lc, ld])
69 | self.assertCountEqual(lc.equivs(), [la, lb, lc, ld])
70 | self.assertCountEqual(ld.equivs(), [la, lb, lc, ld])
71 |
72 | def test_bfs_for_equality_transitivity(self):
73 | a, _, _, d, _, _, _, _ = self._setup_equality_example()
74 |
75 | # check that a==d because fact3 & fact4, not fact1 & fact2
76 | self.assertCountEqual(gm.why_equal(a, d), ['fact3', 'fact4'])
77 |
78 |
79 | if __name__ == '__main__':
80 | absltest.main()
81 |
--------------------------------------------------------------------------------
/graph_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Unit tests for graph.py."""
17 | import unittest
18 |
19 | from absl.testing import absltest
20 | import graph as gh
21 | import numericals as nm
22 | import problem as pr
23 |
24 |
25 | MAX_LEVEL = 1000
26 |
27 |
28 | class GraphTest(unittest.TestCase):
29 |
30 | @classmethod
31 | def setUpClass(cls):
32 | super().setUpClass()
33 |
34 | cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
35 | cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True)
36 |
37 | # load a complex setup:
38 | txt = 'a b c = triangle a b c; h = orthocenter a b c; h1 = foot a b c; h2 = foot b c a; h3 = foot c a b; g1 g2 g3 g = centroid g1 g2 g3 g a b c; o = circle a b c ? coll h g o' # pylint: disable=line-too-long
39 | p = pr.Problem.from_txt(txt, translate=False)
40 | cls.g, _ = gh.Graph.build_problem(p, GraphTest.defs)
41 |
42 | def test_build_graph_points(self):
43 | g = GraphTest.g
44 |
45 | all_points = g.all_points()
46 | all_names = [p.name for p in all_points]
47 | self.assertCountEqual(
48 | all_names,
49 | ['a', 'b', 'c', 'g', 'h', 'o', 'g1', 'g2', 'g3', 'h1', 'h2', 'h3'],
50 | )
51 |
52 | def test_build_graph_predicates(self):
53 | gr = GraphTest.g
54 |
55 | a, b, c, g, h, o, g1, g2, g3, h1, h2, h3 = gr.names2points(
56 | ['a', 'b', 'c', 'g', 'h', 'o', 'g1', 'g2', 'g3', 'h1', 'h2', 'h3']
57 | )
58 |
59 | # Explicit statements:
60 | self.assertTrue(gr.check_cong([b, g1, g1, c]))
61 | self.assertTrue(gr.check_cong([c, g2, g2, a]))
62 | self.assertTrue(gr.check_cong([a, g3, g3, b]))
63 | self.assertTrue(gr.check_perp([a, h1, b, c]))
64 | self.assertTrue(gr.check_perp([b, h2, c, a]))
65 | self.assertTrue(gr.check_perp([c, h3, a, b]))
66 | self.assertTrue(gr.check_cong([o, a, o, b]))
67 | self.assertTrue(gr.check_cong([o, b, o, c]))
68 | self.assertTrue(gr.check_cong([o, a, o, c]))
69 | self.assertTrue(gr.check_coll([a, g, g1]))
70 | self.assertTrue(gr.check_coll([b, g, g2]))
71 | self.assertTrue(gr.check_coll([g1, b, c]))
72 | self.assertTrue(gr.check_coll([g2, c, a]))
73 | self.assertTrue(gr.check_coll([g3, a, b]))
74 | self.assertTrue(gr.check_perp([a, h, b, c]))
75 | self.assertTrue(gr.check_perp([b, h, c, a]))
76 |
77 | # These are NOT part of the premises:
78 | self.assertFalse(gr.check_perp([c, h, a, b]))
79 | self.assertFalse(gr.check_coll([c, g, g3]))
80 |
81 | # These are automatically inferred by the graph datastructure:
82 | self.assertTrue(gr.check_eqangle([a, h1, b, c, b, h2, c, a]))
83 | self.assertTrue(gr.check_eqangle([a, h1, b, h2, b, c, c, a]))
84 | self.assertTrue(gr.check_eqratio([b, g1, g1, c, c, g2, g2, a]))
85 | self.assertTrue(gr.check_eqratio([b, g1, g1, c, o, a, o, b]))
86 | self.assertTrue(gr.check_para([a, h, a, h1]))
87 | self.assertTrue(gr.check_para([b, h, b, h2]))
88 | self.assertTrue(gr.check_coll([a, h, h1]))
89 | self.assertTrue(gr.check_coll([b, h, h2]))
90 |
91 | def test_enumerate_colls(self):
92 | g = GraphTest.g
93 |
94 | for a, b, c in g.all_colls():
95 | self.assertTrue(g.check_coll([a, b, c]))
96 | self.assertTrue(nm.check_coll([a.num, b.num, c.num]))
97 |
98 | def test_enumerate_paras(self):
99 | g = GraphTest.g
100 |
101 | for a, b, c, d in g.all_paras():
102 | self.assertTrue(g.check_para([a, b, c, d]))
103 | self.assertTrue(nm.check_para([a.num, b.num, c.num, d.num]))
104 |
105 | def test_enumerate_perps(self):
106 | g = GraphTest.g
107 |
108 | for a, b, c, d in g.all_perps():
109 | self.assertTrue(g.check_perp([a, b, c, d]))
110 | self.assertTrue(nm.check_perp([a.num, b.num, c.num, d.num]))
111 |
112 | def test_enumerate_congs(self):
113 | g = GraphTest.g
114 |
115 | for a, b, c, d in g.all_congs():
116 | self.assertTrue(g.check_cong([a, b, c, d]))
117 | self.assertTrue(nm.check_cong([a.num, b.num, c.num, d.num]))
118 |
119 | def test_enumerate_eqangles(self):
120 | g = GraphTest.g
121 |
122 | for a, b, c, d, x, y, z, t in g.all_eqangles_8points():
123 | self.assertTrue(g.check_eqangle([a, b, c, d, x, y, z, t]))
124 | self.assertTrue(
125 | nm.check_eqangle(
126 | [a.num, b.num, c.num, d.num, x.num, y.num, z.num, t.num]
127 | )
128 | )
129 |
130 | def test_enumerate_eqratios(self):
131 | g = GraphTest.g
132 |
133 | for a, b, c, d, x, y, z, t in g.all_eqratios_8points():
134 | self.assertTrue(g.check_eqratio([a, b, c, d, x, y, z, t]))
135 | self.assertTrue(
136 | nm.check_eqratio(
137 | [a.num, b.num, c.num, d.num, x.num, y.num, z.num, t.num]
138 | )
139 | )
140 |
141 | def test_enumerate_cyclics(self):
142 | g = GraphTest.g
143 |
144 | for a, b, c, d, x, y, z, t in g.all_cyclics():
145 | self.assertTrue(g.check_cyclic([a, b, c, d, x, y, z, t]))
146 | self.assertTrue(nm.check_cyclic([a.num, b.num, c.num, d.num]))
147 |
148 | def test_enumerate_midps(self):
149 | g = GraphTest.g
150 |
151 | for a, b, c in g.all_midps():
152 | self.assertTrue(g.check_midp([a, b, c]))
153 | self.assertTrue(nm.check_midp([a.num, b.num, c.num]))
154 |
155 | def test_enumerate_circles(self):
156 | g = GraphTest.g
157 |
158 | for a, b, c, d in g.all_circles():
159 | self.assertTrue(g.check_circle([a, b, c, d]))
160 | self.assertTrue(nm.check_circle([a.num, b.num, c.num, d.num]))
161 |
162 |
163 | if __name__ == '__main__':
164 | absltest.main()
165 |
--------------------------------------------------------------------------------
/graph_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utilizations for graph representation.
17 |
18 | Mainly for listing combinations and permutations of elements.
19 | """
20 |
21 | from geometry import Point
22 |
23 |
24 | def _cross(elems1, elems2):
25 | for e1 in elems1:
26 | for e2 in elems2:
27 | yield e1, e2
28 |
29 |
30 | def cross(elems1, elems2):
31 | return list(_cross(elems1, elems2))
32 |
33 |
34 | def _comb2(elems):
35 | if len(elems) < 2:
36 | return
37 | for i, e1 in enumerate(elems[:-1]):
38 | for e2 in elems[i + 1 :]:
39 | yield e1, e2
40 |
41 |
42 | def comb2(elems):
43 | return list(_comb2(elems))
44 |
45 |
46 | def _comb3(elems):
47 | if len(elems) < 3:
48 | return
49 | for i, e1 in enumerate(elems[:-2]):
50 | for j, e2 in enumerate(elems[i + 1 : -1]):
51 | for e3 in elems[i + j + 2 :]:
52 | yield e1, e2, e3
53 |
54 |
55 | def comb3(elems):
56 | return list(_comb3(elems))
57 |
58 |
59 | def _comb4(elems):
60 | if len(elems) < 4:
61 | return
62 | for i, e1 in enumerate(elems[:-3]):
63 | for j, e2 in enumerate(elems[i + 1 : -2]):
64 | for e3, e4 in _comb2(elems[i + j + 2 :]):
65 | yield e1, e2, e3, e4
66 |
67 |
68 | def comb4(elems):
69 | return list(_comb4(elems))
70 |
71 |
72 | def _perm2(elems):
73 | for e1, e2 in comb2(elems):
74 | yield e1, e2
75 | yield e2, e1
76 |
77 |
78 | def perm2(elems):
79 | return list(_perm2(elems))
80 |
81 |
82 | def _all_4points(l1, l2):
83 | p1s = l1.neighbors(Point)
84 | p2s = l2.neighbors(Point)
85 | for a, b in perm2(p1s):
86 | for c, d in perm2(p2s):
87 | yield a, b, c, d
88 |
89 |
90 | def all_4points(l1, l2):
91 | return list(_all_4points(l1, l2))
92 |
93 |
94 | def _all_8points(l1, l2, l3, l4):
95 | for a, b, c, d in all_4points(l1, l2):
96 | for e, f, g, h in all_4points(l3, l4):
97 | yield (a, b, c, d, e, f, g, h)
98 |
99 |
100 | def all_8points(l1, l2, l3, l4):
101 | return list(_all_8points(l1, l2, l3, l4))
102 |
103 |
104 | def _perm3(elems):
105 | for x in elems:
106 | for y in elems:
107 | if y == x:
108 | continue
109 | for z in elems:
110 | if z not in (x, y):
111 | yield x, y, z
112 |
113 |
114 | def perm3(elems):
115 | return list(_perm3(elems))
116 |
117 |
118 | def _perm4(elems):
119 | for x in elems:
120 | for y in elems:
121 | if y == x:
122 | continue
123 | for z in elems:
124 | if z in (x, y):
125 | continue
126 | for t in elems:
127 | if t not in (x, y, z):
128 | yield x, y, z, t
129 |
130 |
131 | def perm4(elems):
132 | return list(_perm4(elems))
133 |
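# A quick cross-check against the standard library (a sketch; assumes this
# module is importable as `graph_utils`). comb2/comb3/comb4 and perm3/perm4
# enumerate in the same order as their itertools counterparts, while perm2
# yields the same pairs in a different order because it is built from comb2:
#
#   import itertools
#   import graph_utils as gu
#
#   elems = [1, 2, 3, 4]
#   assert gu.comb3(elems) == list(itertools.combinations(elems, 3))
#   assert gu.perm4(elems) == list(itertools.permutations(elems, 4))
#   assert set(gu.perm2(elems)) == set(itertools.permutations(elems, 2))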
--------------------------------------------------------------------------------
/graph_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Unit tests for graph_utils.py."""
17 | import unittest
18 |
19 | from absl.testing import absltest
20 | import graph_utils as gu
21 |
22 |
23 | class GraphUtilsTest(unittest.TestCase):
24 |
25 | def test_cross(self):
26 | self.assertEqual(gu.cross([], [1]), [])
27 | self.assertEqual(gu.cross([1], []), [])
28 | self.assertEqual(gu.cross([1], [2]), [(1, 2)])
29 | self.assertEqual(gu.cross([1], [2, 3]), [(1, 2), (1, 3)])
30 |
31 | e1 = [1, 2, 3]
32 | e2 = [4, 5]
33 | target = [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
34 | self.assertEqual(gu.cross(e1, e2), target)
35 |
36 | def test_comb2(self):
37 | self.assertEqual(gu.comb2([]), [])
38 | self.assertEqual(gu.comb2([1]), [])
39 | self.assertEqual(gu.comb2([1, 2]), [(1, 2)])
40 | self.assertEqual(gu.comb2([1, 2, 3]), [(1, 2), (1, 3), (2, 3)])
41 |
42 | def test_comb3(self):
43 | self.assertEqual(gu.comb3([]), [])
44 | self.assertEqual(gu.comb3([1]), [])
45 | self.assertEqual(gu.comb3([1, 2]), [])
46 | self.assertEqual(gu.comb3([1, 2, 3]), [(1, 2, 3)])
47 | self.assertEqual(
48 | gu.comb3([1, 2, 3, 4]), [(1, 2, 3), (1, 2, 4), (1, 3, 4), (2, 3, 4)]
49 | )
50 |
51 | def test_comb4(self):
52 | self.assertEqual(gu.comb4([]), [])
53 | self.assertEqual(gu.comb4([1]), [])
54 | self.assertEqual(gu.comb4([1, 2]), [])
55 | self.assertEqual(gu.comb4([1, 2, 3]), [])
56 | self.assertEqual(gu.comb4([1, 2, 3, 4]), [(1, 2, 3, 4)])
57 | self.assertEqual(
58 | gu.comb4([1, 2, 3, 4, 5]),
59 | [(1, 2, 3, 4), (1, 2, 3, 5), (1, 2, 4, 5), (1, 3, 4, 5), (2, 3, 4, 5)],
60 | )
61 |
62 | def test_perm2(self):
63 | self.assertEqual(gu.perm2([]), [])
64 | self.assertEqual(gu.perm2([1]), [])
65 | self.assertEqual(gu.perm2([1, 2]), [(1, 2), (2, 1)])
66 | self.assertEqual(
67 | gu.perm2([1, 2, 3]), [(1, 2), (2, 1), (1, 3), (3, 1), (2, 3), (3, 2)]
68 | )
69 |
70 | def test_perm3(self):
71 | self.assertEqual(gu.perm3([]), [])
72 | self.assertEqual(gu.perm3([1]), [])
73 | self.assertEqual(gu.perm3([1, 2]), [])
74 | self.assertEqual(
75 | gu.perm3([1, 2, 3]),
76 | [(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)],
77 | )
78 | self.assertEqual(
79 | gu.perm3([1, 2, 3, 4]),
80 | [
81 | (1, 2, 3),
82 | (1, 2, 4),
83 | (1, 3, 2),
84 | (1, 3, 4),
85 | (1, 4, 2),
86 | (1, 4, 3),
87 | (2, 1, 3),
88 | (2, 1, 4),
89 | (2, 3, 1),
90 | (2, 3, 4),
91 | (2, 4, 1),
92 | (2, 4, 3),
93 | (3, 1, 2),
94 | (3, 1, 4),
95 | (3, 2, 1),
96 | (3, 2, 4),
97 | (3, 4, 1),
98 | (3, 4, 2),
99 | (4, 1, 2),
100 | (4, 1, 3),
101 | (4, 2, 1),
102 | (4, 2, 3),
103 | (4, 3, 1),
104 | (4, 3, 2),
105 | ],
106 | )
107 |
108 | def test_perm4(self):
109 |     self.assertEqual(gu.perm4([]), [])
110 |     self.assertEqual(gu.perm4([1]), [])
111 |     self.assertEqual(gu.perm4([1, 2]), [])
112 | self.assertEqual(gu.perm4([1, 2, 3]), [])
113 | self.assertEqual(
114 | gu.perm4([1, 2, 3, 4]),
115 | [
116 | (1, 2, 3, 4),
117 | (1, 2, 4, 3),
118 | (1, 3, 2, 4),
119 | (1, 3, 4, 2),
120 | (1, 4, 2, 3),
121 | (1, 4, 3, 2), # pylint: disable=line-too-long
122 | (2, 1, 3, 4),
123 | (2, 1, 4, 3),
124 | (2, 3, 1, 4),
125 | (2, 3, 4, 1),
126 | (2, 4, 1, 3),
127 | (2, 4, 3, 1), # pylint: disable=line-too-long
128 | (3, 1, 2, 4),
129 | (3, 1, 4, 2),
130 | (3, 2, 1, 4),
131 | (3, 2, 4, 1),
132 | (3, 4, 1, 2),
133 | (3, 4, 2, 1), # pylint: disable=line-too-long
134 | (4, 1, 2, 3),
135 | (4, 1, 3, 2),
136 | (4, 2, 1, 3),
137 | (4, 2, 3, 1),
138 | (4, 3, 1, 2),
139 | (4, 3, 2, 1),
140 | ], # pylint: disable=line-too-long
141 | )
142 |
143 |
144 | if __name__ == '__main__':
145 | absltest.main()
146 |
--------------------------------------------------------------------------------
/imo_ag_30.txt:
--------------------------------------------------------------------------------
1 | translated_imo_2000_p1
2 | a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m = on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b, on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q
3 | translated_imo_2000_p6
4 | a b c = triangle a b c; h = orthocenter h a b c; t1 t2 t3 i = incenter2 t1 t2 t3 i a b c; h1 = foot h1 a b c; h2 = foot h2 b c a; h3 = foot h3 c a b; x1 = reflect x1 h1 t1 t2; x2 = reflect x2 h2 t1 t2; y2 = reflect y2 h2 t2 t3; y3 = reflect y3 h3 t2 t3; z = on_line z x1 x2, on_line z y2 y3 ? cong i z i t1
5 | translated_imo_2002_p2a
6 | b c = segment b c; o = midpoint o b c; a = on_circle a o b; d = on_circle d o b, on_bline d a b; e = on_bline e o a, on_circle e o b; f = on_bline f o a, on_circle f o b; j = on_pline j o a d, on_line j a c ? eqangle e c e j e j e f
7 | translated_imo_2002_p2b
8 | b c = segment b c; o = midpoint o b c; a = on_circle a o b; d = on_circle d o b, on_bline d a b; e = on_bline e o a, on_circle e o b; f = on_bline f o a, on_circle f o b; j = on_pline j o a d, on_line j a c ? eqangle c e c j c j c f
9 | translated_imo_2003_p4
10 | a b c = triangle a b c; o = circle o a b c; b1 = on_circle b1 o a, on_bline b1 c a; d1 = on_circle d1 o a, on_bline d1 c a; x = on_line x b b1, on_line x a c; d = on_line d d1 x, on_circle d o a; p = foot p d b c; q = foot q d c a; r = foot r d a b ? cong p q q r
11 | translated_imo_2004_p1
12 | a b c = triangle a b c; o = midpoint o b c; m = on_circle m o b, on_line m a b; n = on_circle n o b, on_line n a c; r = angle_bisector r b a c, angle_bisector r m o n; o1 = circle o1 b m r; o2 = circle o2 c n r; p = on_circle p o1 r, on_circle p o2 r ? coll p b c
13 | translated_imo_2004_p5
14 | a b c = triangle a b c; o = circle o a b c; d = on_circle d o a; p = on_aline p b c a b d, on_aline p d c a d b ? cong a p c p
15 | translated_imo_2005_p5
16 | a b c = triangle a b c; d = eqdistance d a b c; e = on_line e b c; f = on_line f a d, eqdistance f d e b; p = on_line p a c, on_line p b d; q = on_line q e f, on_line q b d; r = on_line r e f, on_line r a c; o1 = circle o1 a p d; o2 = circle o2 b p c; m = on_circle m o1 p, on_circle m o2 p ? cyclic p q r m
17 | translated_imo_2007_p4
18 | a b c = triangle a b c; o = circle o a b c; r = on_circle r o a, on_bline r a b; l = midpoint l c a; k = midpoint k c b; p = on_line p o k, on_line p c r; q = on_line q o l, on_line q c r; l1 = foot l1 l c r; k1 = foot k1 k c r ? eqratio k k1 l l1 r q r p
19 | translated_imo_2008_p1a
20 | a b c = triangle a b c; h = orthocenter h a b c; d = midpoint d b c; e = midpoint e a c; f = midpoint f a b; a1 = on_circle a1 d h, on_line a1 b c; a2 = on_circle a2 d h, on_line a2 b c; b1 = on_circle b1 e h, on_line b1 c a; b2 = on_circle b2 e h, on_line b2 c a; c1 = on_circle c1 f h, on_line c1 a b; c2 = on_circle c2 f h, on_line c2 a b ? cyclic c1 c2 b1 b2
21 | translated_imo_2008_p1b
22 | a b c = triangle a b c; h = orthocenter h a b c; d = midpoint d b c; e = midpoint e a c; f = midpoint f a b; a1 = on_circle a1 d h, on_line a1 b c; a2 = on_circle a2 d h, on_line a2 b c; b1 = on_circle b1 e h, on_line b1 c a; b2 = on_circle b2 e h, on_line b2 c a; c1 = on_circle c1 f h, on_line c1 a b; c2 = on_circle c2 f h, on_line c2 a b ? cyclic c1 c2 b1 a1
23 | translated_imo_2008_p6
24 | x@4.96_-0.13 y@-1.0068968328888160_-1.2534881080682770 z@-2.8402847238575120_-4.9117762734006830 = triangle x y z; o = circle o x y z; w@6.9090049230038776_-1.3884003936987552 = on_circle w o x; a = on_tline a z o z, on_tline a x o x; b = on_tline b z o z, on_tline b w o w; c = on_tline c y o y, on_tline c w o w; d = on_tline d x o x, on_tline d y o y; i1 = incenter i1 a b c; i2 = incenter i2 a c d; f1 = foot f1 i1 a c; f2 = foot f2 i2 a c; q t p s = cc_tangent q t p s i1 f1 i2 f2; k = on_line k q t, on_line k p s ? cong o k o x
25 | translated_imo_2009_p2
26 | m l k = triangle m l k; w = circle w m l k; q = on_tline q m w m; p = mirror p q m; b = mirror b p k; c = mirror c q l; a = on_line a b q, on_line a c p; o = circle o a b c ? cong o p o q
27 | translated_imo_2010_p2
28 | a b c = triangle a b c; o = circle o a b c; i = incenter i a b c; d = on_line d a i, on_circle d o a; f = on_line f b c; e = on_aline e a c b a f, on_circle e o a; g = midpoint g i f; k = on_line k d g, on_line k e i ? cong o a o k
29 | translated_imo_2010_p4
30 | s c p = iso_triangle s c p; o = on_tline o c s c; a = on_circle a o c; b = on_circle b o c, on_line b s a; m = on_line m c p, on_circle m o c; l = on_line l b p, on_circle l o c; k = on_line k a p, on_circle k o c ? cong m k m l
31 | translated_imo_2011_p6
32 | a b c = triangle a b c; o = circle o a b c; p = on_circle p o a; q = on_tline q p o p; pa = reflect pa p b c; pb = reflect pb p c a; pc = reflect pc p a b; qa = reflect qa q b c; qb = reflect qb q c a; qc = reflect qc q a b; a1 = on_line a1 pb qb, on_line a1 pc qc; b1 = on_line b1 pa qa, on_line b1 pc qc; c1 = on_line c1 pa qa, on_line c1 pb qb; o1 = circle o1 a1 b1 c1; x = on_circle x o a, on_circle x o1 a1 ? coll x o o1
33 | translated_imo_2012_p1
34 | a b c = triangle a b c; m l k j = excenter2 m l k j a b c; f = on_line f m l, on_line f b j; g = on_line g m k, on_line g c j; s = on_line s f a, on_line s b c; t = on_line t g a, on_line t c b ? cong m s m t
35 | translated_imo_2012_p5
36 | c a b = r_triangle c a b; d = foot d c a b; x = on_line x c d; k = on_line k a x, on_circle k b c; l = on_line l b x, on_circle l a c; m = on_line m a l, on_line m b k ? cong m k m l
37 | translated_imo_2013_p4
38 | a b c = triangle a b c; h = orthocenter h a b c; m = on_line m h b, on_line m a c; n = on_line n h c, on_line n a b; w = on_line w b c; o1 = circle o1 b n w; o2 = circle o2 c m w; x = on_line x o1 w, on_circle x o1 w; y = on_line y o2 w, on_circle y o2 w ? coll x h y
39 | translated_imo_2014_p4
40 | a b c = triangle a b c; p = on_line p b c, on_aline p a b b c a; q = on_line q b c, on_aline q a c c b a; m = mirror m a p; n = mirror n a q; x = on_line x b m, on_line x c n; o = circle o a b c ? cong o x o a
41 | translated_imo_2015_p3
42 | a b c = triangle a b c; h = orthocenter h a b c; f = on_line f h a, on_line f b c; m = midpoint m b c; o = circle o a b c; q = on_dia q a h, on_circle q o a; k = on_dia k h q, on_circle k o a; o1 = circle o1 k q h; o2 = circle o2 f k m ? coll o1 o2 k
43 | translated_imo_2015_p4
44 | a b c = triangle a b c; o = circle o a b c; d = on_line d b c; e = on_line e b c, on_circle e a d; f = on_circle f o a, on_circle f a d; g = on_circle g o a, on_circle g a d; o1 = circle o1 f b d; o2 = circle o2 g c e; k = on_circle k o1 b, on_line k a b; l = on_circle l o2 c, on_line l a c; x = on_line x f k, on_line x l g ? coll x o a
45 | translated_imo_2016_p1
46 | a b z = triangle a b z; f = angle_bisector f b a z, on_bline f a b; c = on_tline c b f b, on_line c a f; d = on_line d a z, on_bline d a c; e = angle_mirror e c a d, on_bline e a d; m = midpoint m c f; x = parallelogram e a m x; y = on_line y f x, on_line y e m ? coll y b d
47 | translated_imo_2017_p4
48 | r s = segment r s; t = mirror t r s; o = on_bline o r s; j = on_circle j o s; o1 = circle o1 j s t; a = on_tline a r o r, on_circle a o1 s; b = on_tline b r o r, on_circle b o1 s; k = on_line k j a, on_circle k o s ? perp k t o1 t
49 | translated_imo_2018_p1
50 | a b c = triangle a b c; o = circle o a b c; d = on_line d a b; e = on_line e a c, on_circle e a d; f = on_bline f b d, on_circle f o a; g = on_bline g e c, on_circle g o a ? para d e f g
51 | translated_imo_2019_p2
52 | a b c = triangle; a1 = on_line b c; b1 = on_line a c; p = on_line a a1; q = on_line b b1, on_pline p a b; p1 = on_line p b1, eqangle3 p c a b c; q1 = on_line q a1, eqangle3 c q b c a ? cyclic p q p1 q1
53 | translated_imo_2019_p6
54 | a b c = triangle a b c; d e f i = incenter2 d e f i a b c; r = on_tline r d e f, on_circle r i d; p = on_line p r a, on_circle p i d; o1 = circle o1 p c e; o2 = circle o2 p b f; q = on_circle q o1 p, on_circle q o2 p; t = on_line t p q, on_line t i d ? perp a t a i
55 | translated_imo_2020_p1
56 | p a b = triangle p a b; x = angle_bisector p b a; y = angle_bisector p a b; z = on_aline z a p a b x; t = on_aline t p a p a z; d = on_aline d p t p b a, on_line a z; u = on_aline u b p b a y; v = on_aline v p b p b u; c = on_aline c p v p a b, on_line b u; o = angle_bisector a d p, angle_bisector p c b ? cong o a o b
57 | translated_imo_2021_p3
58 | a b c = triangle; d = angle_bisector b a c; e = on_aline d a d c b, on_line a c; f = on_aline d a d b c, on_line a b; x = on_bline b c, on_line a c; o1 = circle a d c; o2 = circle e x d; y = on_line e f, on_line b c ? coll o1 o2 y
59 | translated_imo_2022_p4
60 | b c = segment; d = free; e = eqdistance d b c; t = on_bline b d, on_bline c e; a = eqangle2 b t e; p = on_line a b, on_line c d; q = on_line a b, on_line c t; r = on_line a e, on_line c d; s = on_line a e, on_line d t ? cyclic p q r s
61 |
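Each pair of lines above gives a problem name, then its statement: `;`-separated construction clauses followed by `?` and the goal predicate (the `p@x_y` form pins a point to numeric coordinates for sketching). A loading sketch, assuming the Problem API behaves as exercised in graph_test.py above:

    import problem as pr

    # to_dict=True keys the problems by name; the path is repo-relative.
    problems = pr.Problem.from_txt_file('imo_ag_30.txt', to_dict=True)
    print(len(problems))  # 30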
--------------------------------------------------------------------------------
/lm_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Wrapper for language modeling inference implemented in Meliad."""
17 | from typing import Any, Dict
18 |
19 | import jax
20 | import models # pylint: disable=unused-import
21 | import t5.data
22 | from transformer import inference_utils
23 |
24 |
25 | np = jax.numpy
26 |
27 |
28 | Trainer = inference_utils.Trainer
29 |
30 | MetricsOutput = Dict[str, Any] # Metrics output by model.
31 |
32 |
33 | parse_gin_configuration = inference_utils.parse_gin_configuration
34 |
35 |
36 | class LanguageModelInference:
37 | """Meliad wrapper for LM inference."""
38 |
39 | def __init__(self, vocab_path: str, load_dir: str, mode='beam_search'):
40 | self.vocab = t5.data.SentencePieceVocabulary(vocab_path)
41 |
42 | # This task won't be pulling from a dataset.
43 | def null_iter_fn() -> None:
44 | return None
45 |
46 | process_summaries_f = inference_utils.models.process_summaries_function(
47 | self.vocab
48 | )
49 |
50 | trainer = inference_utils.training_loop.Trainer(
51 | get_training_dataset_iterator=null_iter_fn,
52 | get_test_dataset_iterator=None,
53 | pretty_print_input_function=None,
54 | process_summaries_function=process_summaries_f,
55 | load_dir=load_dir,
56 | workdir='', # Don't log or save checkpoints.
57 | replicate_mode=False,
58 | ) # Run on a single device at batch size 1.
59 | self.trainer = trainer
60 |
61 | # Create and initialize the model.
62 | (tstate, _, imodel, prngs) = trainer.initialize_model()
63 | self.imodel = imodel
64 | self.batch_size = imodel.task_config.batch_size
65 |
66 | self.n = imodel.num_heads
67 | self.h = imodel.head_size
68 |
69 | # Create an inference task.
70 | writers = {}
71 | self.task = trainer.create_training_task(mode, imodel, prngs, writers) # pylint: disable=too-many-function-args
72 |
73 | # Register any additional actions.
74 | # Actions are cleared first for use with colab.
75 | inference_utils.training_loop.clear_interstep_callbacks()
76 | inference_utils.training_loop.register_interstep_callbacks()
77 | self.tstate = tstate
78 |
79 |     # Some default parameters.
80 | eos = [0] * 1024
81 | for idx in self.encode_list(['.', ';']):
82 | eos[idx] = 1
83 |
84 | self.eos = np.array(eos, dtype=np.bfloat16)
85 | self.mask = jax.numpy.ones([1024], dtype=np.bfloat16)
86 |
87 | def decode(self, ids: list[int]) -> str:
88 | return self.vocab.decode(ids)
89 |
90 | def decode_list(self, tokens: list[int]) -> list[str]:
91 | return [self.decode([tok]) for tok in tokens]
92 |
93 | def encode(self, inputs_str: str) -> list[int]:
94 | return self.vocab.encode(inputs_str)
95 |
96 | def encode_list(self, inputs_strs: list[str]) -> list[int]:
97 | result = [self.vocab.encode(x) for x in inputs_strs]
98 | assert all([len(x) == 1 for x in result]), [
99 | self.decode(x) for x in result if len(x) != 1
100 | ]
101 | return [x[0] for x in result]
102 |
103 | def call(
104 | self,
105 | inputs: np.ndarray,
106 | dstate: tuple[dict[str, np.ndarray], ...] = None,
107 | eos: np.ndarray = None,
108 | mask: np.ndarray = None,
109 | ) -> MetricsOutput:
110 | """Call the meliad model."""
111 | batch_size, length = inputs.shape
112 | inputs = jax.numpy.pad(inputs, [(0, 0), (0, 1024 - length)])
113 |
114 | if eos is None:
115 | eos = self.eos
116 | if mask is None:
117 | mask = self.mask
118 |
119 | x = {'targets': inputs, 'length': length, 'eos': eos, 'mask': mask}
120 |
121 | if dstate is not None:
122 | x['start_of_sequence'] = jax.numpy.array([False] * batch_size)
123 | else:
124 | dstate = tuple(
125 | [{ # this dummy value will never be used.
126 | 'current_index': np.array([0] * batch_size, dtype=np.int32),
127 | 'keys': np.zeros(
128 | (batch_size, 2048, self.n, self.h), dtype=np.bfloat16
129 | ),
130 | 'values': np.zeros(
131 | (batch_size, 2048, self.n, self.h), dtype=np.bfloat16
132 | ),
133 | 'recurrent_kvq': None,
134 | 'relative_position_bias': np.zeros(
135 | (batch_size, self.n, 1, 1024), dtype=np.bfloat16
136 | ),
137 | }]
138 | * 12
139 | )
140 | x['start_of_sequence'] = jax.numpy.array([True] * batch_size)
141 |
142 | x['dstate'] = dstate
143 | _, metrics_np = self.task.run_step(self.tstate, x, 0)
144 | return metrics_np
145 |
146 | def beam_decode(
147 | self,
148 | inputs: str,
149 | eos_tokens: np.ndarray = None,
150 | mask_tokens: np.ndarray = None,
151 | dstate: dict[str, np.ndarray] = None,
152 | ) -> MetricsOutput:
153 | """Beam search."""
154 | inputs = jax.numpy.array([self.vocab.encode(inputs)] * self.batch_size)
155 |
156 | eos = self.eos
157 | if eos_tokens is not None:
158 | eos_ids = self.encode_list(eos_tokens)
159 | eos = np.array(
160 | [1 if idx in eos_ids else 0 for idx in range(1024)], dtype=np.bfloat16
161 | ).reshape((1, 1, 1024))
162 |
163 | mask = self.mask
164 | if mask_tokens is not None:
165 | mask_ids = self.encode_list(mask_tokens)
166 | mask = np.array(
167 | [0 if idx in mask_ids else 1 for idx in range(1024)],
168 | dtype=np.bfloat16,
169 | ).reshape((1, 1, 1024))
170 |
171 | metrics_np = self.call(inputs, dstate=dstate, eos=eos, mask=mask)
172 |
173 | finished_seqs = metrics_np['finished_seqs']
174 | finished_scores = metrics_np['finished_scores']
175 |
176 | seqs = []
177 | scores = []
178 | for seq, score in zip(finished_seqs, finished_scores):
179 | seq = self.decode(seq[1:])
180 | seqs.append(seq)
181 | scores.append(score)
182 |
183 | return {
184 | 'finished_seqs': finished_seqs,
185 | 'finished_scores': finished_scores,
186 | 'seqs_str': seqs,
187 | 'scores': scores,
188 | 'dstate': metrics_np['dstate'],
189 | }
190 |
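# A minimal end-to-end driving sketch (hedged: the paths are placeholders,
# and gin flags must be parsed first, exactly as lm_inference_test.py below
# does before constructing the model):
#
#   import lm_inference as lm
#
#   model = lm.LanguageModelInference(
#       'geometry.757.model', 'ckpt_dir', mode='beam_search')
#   out = model.beam_decode(
#       '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c'
#       ' {F1} x00',
#       eos_tokens=[';'],
#   )
#   print(out['seqs_str'], out['scores'])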
--------------------------------------------------------------------------------
/lm_inference_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Unit tests for lm_inference.py."""
17 | import os
18 | import unittest
19 |
20 | from absl import flags
21 | from absl.testing import absltest
22 | import lm_inference as lm
23 |
24 |
25 | _DATA_PATH = flags.DEFINE_string('data_path', '', 'path to ckpt and vocab.')
26 | _MELIAD_PATH = flags.DEFINE_string(
27 | 'meliad_path', '', 'path to meliad repository.'
28 | ) # pylint: disable=line-too-long
29 |
30 |
31 | class LmInferenceTest(unittest.TestCase):
32 |
33 | @classmethod
34 | def setUpClass(cls):
35 | super().setUpClass()
36 | gin_file = [
37 | 'base_htrans.gin',
38 | 'size/medium_150M.gin',
39 | 'options/positions_t5.gin',
40 | 'options/lr_cosine_decay.gin',
41 | 'options/seq_1024_nocache.gin',
42 | 'geometry_150M_generate.gin',
43 | ]
44 |
45 | gin_param = [
46 | 'DecoderOnlyLanguageModelGenerate.output_token_losses=True',
47 | 'TransformerTaskConfig.batch_size=2',
48 | 'TransformerTaskConfig.sequence_length=128',
49 | 'Trainer.restore_state_variables=False',
50 | ]
51 |
52 | gin_search_paths = [
53 | os.path.join(_MELIAD_PATH.value, 'transformer/configs'),
54 | os.getcwd(),
55 | ]
56 |
57 | vocab_path = os.path.join(_DATA_PATH.value, 'geometry.757.model')
58 |
59 | lm.parse_gin_configuration(gin_file, gin_param, gin_paths=gin_search_paths)
60 |
61 | cls.loaded_lm = lm.LanguageModelInference(
62 | vocab_path, _DATA_PATH.value, mode='beam_search'
63 | )
64 |
65 | def test_lm_decode(self):
66 | outputs = LmInferenceTest.loaded_lm.beam_decode(
67 | '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c'
68 | ' {F1} x00',
69 | eos_tokens=[';'],
70 | )
71 | self.assertEqual(
72 | outputs['seqs_str'],
73 | ['e : D a b c e 02 D a c b e 03 ;', 'e : C a c e 02 C b d e 03 ;'],
74 | )
75 |
76 | def test_lm_score_may_fail_numerically_for_external_meliad(self):
77 | outputs = LmInferenceTest.loaded_lm.beam_decode(
78 | '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c'
79 | ' {F1} x00',
80 | eos_tokens=[';'],
81 | )
82 | self.assertEqual(
83 | outputs['scores'],
84 | [-1.18607294559478759765625, -1.10228693485260009765625],
85 | )
86 |
87 |
88 | if __name__ == '__main__':
89 | absltest.main()
90 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Transformer language model generate mode."""
17 |
18 | from typing import Any, Tuple
19 | import beam_search
20 | import decoder_stack
21 | import gin
22 | import jax
23 | import jax.numpy as jnp
24 | from transformer import models
25 |
26 |
27 | @gin.configurable
28 | class DecoderOnlyLanguageModelGenerate(models.DecoderOnlyLanguageModel):
29 | """Decoder only language modeling in inference mode."""
30 |
31 | decoder_factory = decoder_stack.DecoderStackGenerate
32 |
33 | num_heads: int = gin.REQUIRED
34 | head_size: int = gin.REQUIRED
35 |
36 | def get_fake_input(self) -> dict[str, Any]:
37 | fake_input_dict = super().get_fake_input()
38 | b = self.task_config.batch_size
39 | n = self.num_heads
40 | h = self.head_size
41 | fake_input_dict.update({
42 | 'dstate': tuple(
43 | [{
44 | 'current_index': jnp.array([0] * b, dtype=jnp.int32),
45 | 'keys': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
46 | 'values': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
47 | 'recurrent_kvq': None,
48 | 'relative_position_bias': jnp.zeros(
49 | (b, n, 1, 1024), dtype=jnp.bfloat16
50 | ),
51 | }]
52 | * 12
53 | ),
54 | 'eos': jnp.zeros([1024], dtype=jnp.bfloat16),
55 | 'mask': jnp.ones([1024], dtype=jnp.bfloat16),
56 | 'length': 1,
57 | 'temperature': 1.0,
58 | })
59 | return fake_input_dict
60 |
61 | def __call__(self, inputs: ...) -> tuple[Any, dict[str, Any]]:
62 | # Make sure this code is not used on untested cases.
63 | if self.mode not in ['init', 'beam_search']:
64 | raise ValueError(f'{type(self)} cannot do mode {self.mode}')
65 | if self.decoder.supports_generate():
66 |       raise ValueError(f'{type(self)}.decoder must not support generate().')
67 |
68 | self.decoder(
69 | input_tokens=inputs['targets'][:, 0:1],
70 | target_tokens=None,
71 | start_of_sequence=inputs['start_of_sequence'],
72 | )
73 |
74 | b = inputs['targets'].shape[0]
75 | no_start_of_seq = jnp.array([False] * b, dtype=jnp.bool_)
76 |
77 |     # This fn is used in both beam_search and topk_sampling.
78 | def tokens_to_logits_fn(
79 | input_token: jnp.ndarray, dstate: tuple[dict[str, jnp.ndarray], ...]
80 | ) -> tuple[jnp.ndarray, tuple[dict[str, jnp.ndarray], ...]]:
81 | (logits, dstate, _) = self.decoder(
82 | input_tokens=input_token,
83 | target_tokens=None,
84 | start_of_sequence=no_start_of_seq,
85 | decoder_state=dstate,
86 | )
87 | return logits[:, -1, :], dstate
88 |
89 | last_token = jax.lax.dynamic_slice_in_dim(
90 | inputs['targets'], inputs['length'] - 1, 1, axis=1
91 | )
92 |
93 | # last token is used to seed beam_search
94 | inputs['targets'] = inputs['targets'][:, 0:-1]
95 | dstate = jax.lax.cond(
96 | inputs['start_of_sequence'][0],
97 | lambda: self.generate(inputs)[0],
98 | lambda: inputs['dstate'],
99 | )
100 |
101 | # Then we run beam search, init with last_token & dstate.
102 | finished_seqs, finished_scores, dstate = beam_search.beam_search_flat(
103 | last_token,
104 | dstate,
105 | tokens_to_logits_fn,
106 | max_decode_len=512,
107 | eos=inputs['eos'].reshape((1, 1, -1)),
108 | mask=inputs['mask'].reshape((1, 1, -1)),
109 | )
110 |
111 | return 0.0, {
112 | 'finished_seqs': finished_seqs,
113 | 'finished_scores': finished_scores,
114 | 'dstate': dstate,
115 | }
116 |
117 | def generate(
118 | self, inputs: ...
119 |   ) -> tuple[tuple[dict[str, jnp.ndarray], ...], jnp.ndarray]:
120 | """Generate an output sequence.
121 |
122 | Args:
123 |       inputs: the same as the argument to __call__.
124 |
125 |     Returns:
126 |       The updated decoder state and the logits for the final step.
127 | """
128 | input_tokens = inputs['targets'] # [b,seq_len]
129 | start_of_sequence = inputs['start_of_sequence'] # [b]
130 | target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
131 | batch_size = target_tokens.shape[0]
132 |
133 | # Assuming all sequences start at the same time.
134 | start0 = inputs['start_of_sequence'][0]
135 | dstate = jax.lax.cond(
136 | start0,
137 | lambda: self.decoder.init_decoder_state_vanilla( # pylint: disable=g-long-lambda
138 | 1024, start_of_sequence
139 | ),
140 | lambda: inputs['dstate'],
141 | )
142 |
143 | first_token = input_tokens[:, 0:1]
144 | no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
145 | temperature = 1
146 | if 'temperature' in inputs:
147 | temperature = inputs['temperature']
148 |
149 | num_steps = inputs['length']
150 | if self.mode == 'beam_search':
151 | num_steps -= 1
152 |
153 | def cond_fn(scan_state) -> jnp.bool_:
154 | _, _, i, _ = scan_state
155 | return i < num_steps
156 |
157 | def loop_fn(scan_state: Any) -> Tuple[Any, Any, Any, Any]:
158 | (dstate, input_token, i, _) = scan_state
159 |
160 | (logits, dstate, _) = self.decoder(
161 | input_tokens=input_token,
162 | target_tokens=None,
163 | start_of_sequence=no_start_of_seq,
164 | decoder_state=dstate,
165 | )
166 |
167 | logits = logits / temperature
168 | output_token = jax.lax.dynamic_slice_in_dim(target_tokens, i, 1, axis=1)
169 |
170 | return (dstate, output_token, i + 1, logits)
171 |
172 | # Scan over the sequence length.
173 | dummy_logits = jnp.zeros((batch_size, 1, 1024))
174 | initial_scan_state = (dstate, first_token, 0, dummy_logits)
175 | dstate, _, _, logits = jax.lax.while_loop(
176 | cond_fn, loop_fn, initial_scan_state
177 | )
178 | return dstate, logits
179 |
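# A toy standalone sketch of the jax.lax.while_loop pattern used by
# generate() above: the carry is a tuple, cond_fn decides whether to keep
# looping, and loop_fn advances the carry (pure JAX, no model; the numbers
# are illustrative only):
#
#   import jax
#
#   def cond_fn(state):
#     i, _ = state
#     return i < 5
#
#   def loop_fn(state):
#     i, acc = state
#     return i + 1, acc + i
#
#   _, total = jax.lax.while_loop(cond_fn, loop_fn, (0, 0))
#   print(total)  # 10 == 0 + 1 + 2 + 3 + 4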
--------------------------------------------------------------------------------
/numericals_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Unit testing for the geometry numericals code."""
17 |
18 | import unittest
19 |
20 | from absl.testing import absltest
21 | import numericals as nm
22 |
23 | np = nm.np
24 |
25 | unif = nm.unif
26 | Point = nm.Point
27 | Line = nm.Line
28 | Circle = nm.Circle
29 | HalfLine = nm.HalfLine
30 |
31 | line_circle_intersection = nm.line_circle_intersection
32 | line_line_intersection = nm.line_line_intersection
33 |
34 | check_coll = nm.check_coll
35 | check_eqangle = nm.check_eqangle
36 |
37 | random_points = nm.random_points
38 | ang_between = nm.ang_between
39 | head_from = nm.head_from
40 |
41 |
42 | class NumericalTest(unittest.TestCase):
43 |
44 | def test_sketch_ieq_triangle(self):
45 | a, b, c = nm.sketch_ieq_triangle([])
46 | self.assertAlmostEqual(a.distance(b), b.distance(c))
47 | self.assertAlmostEqual(c.distance(a), b.distance(c))
48 |
49 | def test_sketch_2l1c(self):
50 | p = nm.Point(0.0, 0.0)
51 | pi = np.pi
52 | anga = unif(-0.4 * pi, 0.4 * pi)
53 | a = Point(np.cos(anga), np.sin(anga))
54 | angb = unif(0.6 * pi, 1.4 * pi)
55 | b = Point(np.cos(angb), np.sin(angb))
56 |
57 | angc = unif(anga + 0.05 * pi, angb - 0.05 * pi)
58 | c = Point(np.cos(angc), np.sin(angc)) * unif(0.2, 0.8)
59 |
60 | x, y, z, i = nm.sketch_2l1c([a, b, c, p])
61 | self.assertTrue(check_coll([x, c, a]))
62 | self.assertTrue(check_coll([y, c, b]))
63 | self.assertAlmostEqual(z.distance(p), 1.0)
64 | self.assertTrue(check_coll([p, i, z]))
65 | self.assertTrue(Line(i, x).is_perp(Line(c, a)))
66 | self.assertTrue(Line(i, y).is_perp(Line(c, b)))
67 | self.assertAlmostEqual(i.distance(x), i.distance(y))
68 | self.assertAlmostEqual(i.distance(x), i.distance(z))
69 |
70 | def test_sketch_3peq(self):
71 | a, b, c = random_points(3)
72 | x, y, z = nm.sketch_3peq([a, b, c])
73 |
74 | self.assertTrue(check_coll([a, b, x]))
75 | self.assertTrue(check_coll([a, c, y]))
76 | self.assertTrue(check_coll([b, c, z]))
77 | self.assertTrue(check_coll([x, y, z]))
78 | self.assertAlmostEqual(z.distance(x), z.distance(y))
79 |
80 | def test_sketch_aline(self):
81 | a, b, c, d, e = random_points(5)
82 | ex = nm.sketch_aline([a, b, c, d, e])
83 | self.assertIsInstance(ex, HalfLine)
84 | self.assertEqual(ex.tail, e)
85 | x = ex.head
86 | self.assertAlmostEqual(ang_between(b, a, c), ang_between(e, d, x))
87 |
88 | def test_sketch_amirror(self):
89 | a, b, c = random_points(3)
90 | bx = nm.sketch_amirror([a, b, c])
91 | self.assertIsInstance(bx, HalfLine)
92 |     self.assertEqual(bx.tail, b)
93 | x = bx.head
94 |
95 | ang1 = ang_between(b, a, c)
96 | ang2 = ang_between(b, c, x)
97 | self.assertAlmostEqual(ang1, ang2)
98 |
99 | def test_sketch_bisect(self):
100 | a, b, c = random_points(3)
101 | line = nm.sketch_bisect([a, b, c])
102 | self.assertAlmostEqual(b.distance(line), 0.0)
103 |
104 | l = a.perpendicular_line(line)
105 | x = line_line_intersection(l, Line(b, c))
106 | self.assertAlmostEqual(a.distance(line), x.distance(line))
107 |
108 | d, _ = line_circle_intersection(line, Circle(b, radius=1))
109 | ang1 = ang_between(b, a, d)
110 | ang2 = ang_between(b, d, c)
111 | self.assertAlmostEqual(ang1, ang2)
112 |
113 | def test_sketch_bline(self):
114 | a, b = random_points(2)
115 | l = nm.sketch_bline([a, b])
116 | self.assertTrue(Line(a, b).is_perp(l))
117 | self.assertAlmostEqual(a.distance(l), b.distance(l))
118 |
119 | def test_sketch_cc_tangent(self):
120 | o = Point(0.0, 0.0)
121 | w = Point(1.0, 0.0)
122 |
123 | ra = unif(0.0, 0.6)
124 | rb = unif(0.4, 1.0)
125 |
126 | a = unif(0.0, np.pi)
127 | b = unif(0.0, np.pi)
128 |
129 | a = o + ra * Point(np.cos(a), np.sin(a))
130 | b = w + rb * Point(np.sin(b), np.cos(b))
131 |
132 | x, y, z, t = nm.sketch_cc_tangent([o, a, w, b])
133 | xy = Line(x, y)
134 | zt = Line(z, t)
135 | self.assertAlmostEqual(o.distance(xy), o.distance(a))
136 | self.assertAlmostEqual(o.distance(zt), o.distance(a))
137 | self.assertAlmostEqual(w.distance(xy), w.distance(b))
138 | self.assertAlmostEqual(w.distance(zt), w.distance(b))
139 |
140 | def test_sketch_circle(self):
141 | a, b, c = random_points(3)
142 | circle = nm.sketch_circle([a, b, c])
143 | self.assertAlmostEqual(circle.center.distance(a), 0.0)
144 | self.assertAlmostEqual(circle.radius, b.distance(c))
145 |
146 | def test_sketch_e5128(self):
147 | b = Point(0.0, 0.0)
148 | c = Point(0.0, 1.0)
149 | ang = unif(-np.pi / 2, 3 * np.pi / 2)
150 | d = head_from(c, ang, 1.0)
151 | a = Point(unif(0.5, 2.0), 0.0)
152 |
153 | e, g = nm.sketch_e5128([a, b, c, d])
154 | ang1 = ang_between(a, b, d)
155 | ang2 = ang_between(e, a, g)
156 | self.assertAlmostEqual(ang1, ang2)
157 |
158 | def test_sketch_eq_quadrangle(self):
159 | a, b, c, d = nm.sketch_eq_quadrangle([])
160 | self.assertAlmostEqual(a.distance(d), c.distance(b))
161 | ac = Line(a, c)
162 |     self.assertTrue(ac.diff_side(b, d), (ac(b), ac(d)))
163 |     bd = Line(b, d)
164 |     self.assertTrue(bd.diff_side(a, c), (bd(a), bd(c)))
165 |
166 | def test_sketch_eq_trapezoid(self):
167 | a, b, c, d = nm.sketch_eq_trapezoid([])
168 |     self.assertTrue(Line(a, b).is_parallel(Line(c, d)))
169 | self.assertAlmostEqual(a.distance(d), b.distance(c))
170 |
171 | def test_sketch_eqangle3(self):
172 | points = random_points(5)
173 | x = nm.sketch_eqangle3(points).sample_within(points)[0]
174 | a, b, d, e, f = points
175 | self.assertTrue(check_eqangle([x, a, x, b, d, e, d, f]))
176 |
177 | def test_sketch_eqangle2(self):
178 | a, b, c = random_points(3)
179 | x = nm.sketch_eqangle2([a, b, c])
180 | ang1 = ang_between(a, b, x)
181 | ang2 = ang_between(c, x, b)
182 | self.assertAlmostEqual(ang1, ang2)
183 |
184 |   def test_sketch_eqdia_quadrangle(self):
185 |     a, b, c, d = nm.sketch_eqdia_quadrangle([])
186 |     self.assertTrue(Line(a, c).diff_side(b, d))
187 |     self.assertTrue(Line(b, d).diff_side(a, c))
188 | self.assertAlmostEqual(a.distance(c), b.distance(d))
189 |
190 | def test_sketch_isos(self):
191 | a, b, c = nm.sketch_isos([])
192 | self.assertAlmostEqual(a.distance(b), a.distance(c))
193 | self.assertAlmostEqual(ang_between(b, a, c), ang_between(c, b, a))
194 |
195 |   def test_sketch_quadrangle(self):
196 | a, b, c, d = nm.sketch_quadrangle([])
197 | self.assertTrue(Line(a, c).diff_side(b, d))
198 | self.assertTrue(Line(b, d).diff_side(a, c))
199 |
200 | def test_sketch_r_trapezoid(self):
201 | a, b, c, d = nm.sketch_r_trapezoid([])
202 | self.assertTrue(Line(a, b).is_perp(Line(a, d)))
203 | self.assertTrue(Line(a, b).is_parallel(Line(c, d)))
204 | self.assertTrue(Line(a, c).diff_side(b, d))
205 | self.assertTrue(Line(b, d).diff_side(a, c))
206 |
207 | def test_sketch_r_triangle(self):
208 | a, b, c = nm.sketch_r_triangle([])
209 | self.assertTrue(Line(a, b).is_perp(Line(a, c)))
210 |
211 | def test_sketch_rectangle(self):
212 | a, b, c, d = nm.sketch_rectangle([])
213 | self.assertTrue(Line(a, b).is_perp(Line(b, c)))
214 | self.assertTrue(Line(b, c).is_perp(Line(c, d)))
215 | self.assertTrue(Line(c, d).is_perp(Line(d, a)))
216 |
217 | def test_sketch_reflect(self):
218 | a, b, c = random_points(3)
219 | x = nm.sketch_reflect([a, b, c])
220 | self.assertTrue(Line(a, x).is_perp(Line(b, c)))
221 | self.assertAlmostEqual(x.distance(Line(b, c)), a.distance(Line(b, c)))
222 |
223 | def test_sketch_risos(self):
224 | a, b, c = nm.sketch_risos([])
225 | self.assertAlmostEqual(a.distance(b), a.distance(c))
226 | self.assertTrue(Line(a, b).is_perp(Line(a, c)))
227 |
228 | def test_sketch_rotaten90(self):
229 | a, b = random_points(2)
230 | x = nm.sketch_rotaten90([a, b])
231 | self.assertAlmostEqual(a.distance(x), a.distance(b))
232 | self.assertTrue(Line(a, x).is_perp(Line(a, b)))
233 | d = Point(0.0, 0.0)
234 | e = Point(0.0, 1.0)
235 | f = Point(1.0, 0.0)
236 | self.assertAlmostEqual(ang_between(d, e, f), ang_between(a, b, x))
237 |
238 | def test_sketch_rotatep90(self):
239 | a, b = random_points(2)
240 | x = nm.sketch_rotatep90([a, b])
241 | self.assertAlmostEqual(a.distance(x), a.distance(b))
242 | self.assertTrue(Line(a, x).is_perp(Line(a, b)))
243 | d = Point(0.0, 0.0)
244 | e = Point(0.0, 1.0)
245 | f = Point(1.0, 0.0)
246 | self.assertAlmostEqual(ang_between(d, f, e), ang_between(a, b, x))
247 |
248 | def test_sketch_s_angle(self):
249 | a, b = random_points(2)
250 | y = unif(0.0, np.pi)
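    |     # sketch_s_angle takes the angle in degrees, hence the radian conversion.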
251 | bx = nm.sketch_s_angle([a, b, y / np.pi * 180])
252 | self.assertIsInstance(bx, HalfLine)
253 | self.assertEqual(bx.tail, b)
254 | x = bx.head
255 |
256 | d = Point(1.0, 0.0)
257 | e = Point(0.0, 0.0)
258 | f = Point(np.cos(y), np.sin(y))
259 | self.assertAlmostEqual(ang_between(e, d, f), ang_between(b, a, x))
260 |
261 | def test_sketch_shift(self):
262 | a, b, c = random_points(3)
263 | x = nm.sketch_shift([a, b, c])
264 | self.assertTrue((b - a).close(x - c))
265 |
266 | def test_sketch_square(self):
267 | a, b = random_points(2)
268 | c, d = nm.sketch_square([a, b])
269 | self.assertTrue(Line(a, b).is_perp(Line(b, c)))
270 | self.assertTrue(Line(b, c).is_perp(Line(c, d)))
271 | self.assertTrue(Line(c, d).is_perp(Line(d, a)))
272 | self.assertAlmostEqual(a.distance(b), b.distance(c))
273 |
274 | def test_sketch_isquare(self):
275 | a, b, c, d = nm.sketch_isquare([])
276 | self.assertTrue(Line(a, b).is_perp(Line(b, c)))
277 | self.assertTrue(Line(b, c).is_perp(Line(c, d)))
278 | self.assertTrue(Line(c, d).is_perp(Line(d, a)))
279 | self.assertAlmostEqual(a.distance(b), b.distance(c))
280 |
281 | def test_sketch_trapezoid(self):
282 | a, b, c, d = nm.sketch_trapezoid([])
283 | self.assertTrue(Line(a, b).is_parallel(Line(c, d)))
284 | self.assertTrue(Line(a, c).diff_side(b, d))
285 | self.assertTrue(Line(b, d).diff_side(a, c))
286 |
287 | def test_sketch_triangle(self):
288 | a, b, c = nm.sketch_triangle([])
289 | self.assertFalse(check_coll([a, b, c]))
290 |
291 | def test_sketch_triangle12(self):
292 | a, b, c = nm.sketch_triangle12([])
293 | self.assertAlmostEqual(a.distance(b) * 2, a.distance(c))
294 |
295 | def test_sketch_trisect(self):
296 | a, b, c = random_points(3)
297 | x, y = nm.sketch_trisect([a, b, c])
298 | self.assertAlmostEqual(ang_between(b, a, x), ang_between(b, x, y))
299 | self.assertAlmostEqual(ang_between(b, x, y), ang_between(b, y, c))
300 | self.assertAlmostEqual(ang_between(b, a, x) * 3, ang_between(b, a, c))
301 |
302 | def test_sketch_trisegment(self):
303 | a, b = random_points(2)
304 | x, y = nm.sketch_trisegment([a, b])
305 | self.assertAlmostEqual(
306 | a.distance(x) + x.distance(y) + y.distance(b), a.distance(b)
307 | )
308 | self.assertAlmostEqual(a.distance(x), x.distance(y))
309 | self.assertAlmostEqual(x.distance(y), y.distance(b))
310 |
311 |
312 | if __name__ == '__main__':
313 | absltest.main()
314 |
--------------------------------------------------------------------------------
/pretty.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utilities for string manipulation in the DSL."""
17 |
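    | # Single-character predicate codes of the compact proof DSL, mapped to full
    | # predicate names. Note that '/' and '%' both denote eqratio.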
18 | MAP_SYMBOL = {
19 | 'T': 'perp',
20 | 'P': 'para',
21 | 'D': 'cong',
22 | 'S': 'simtri',
23 | 'I': 'circle',
24 | 'M': 'midp',
25 | 'O': 'cyclic',
26 | 'C': 'coll',
27 | '^': 'eqangle',
28 | '/': 'eqratio',
29 | '%': 'eqratio',
30 | '=': 'contri',
31 | 'X': 'collx',
32 | 'A': 'acompute',
33 | 'R': 'rcompute',
34 | 'Q': 'fixc',
35 | 'E': 'fixl',
36 | 'V': 'fixb',
37 | 'H': 'fixt',
38 | 'Z': 'fixp',
39 | 'Y': 'ind',
40 | }
41 |
42 |
43 | def map_symbol(c: str) -> str:
44 | return MAP_SYMBOL[c]
45 |
46 |
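    | # Dict inversion keeps the last entry for a duplicated value, so 'eqratio'
    | # maps back to '%' rather than '/'.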
47 | def map_symbol_inv(c: str) -> str:
48 | return {v: k for k, v in MAP_SYMBOL.items()}[c]
49 |
50 |
51 | def _gcd(x: int, y: int) -> int:
52 | while y:
53 | x, y = y, x % y
54 | return x
55 |
56 |
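    | # Reduce the fraction n/d to lowest terms, e.g. simplify(4, 6) == (2, 3).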
57 | def simplify(n: int, d: int) -> tuple[int, int]:
58 | g = _gcd(n, d)
59 | return (n // g, d // g)
60 |
61 |
62 | def pretty2r(a: str, b: str, c: str, d: str) -> str:
63 | if b in (c, d):
64 | a, b = b, a
65 |
66 | if a == d:
67 | c, d = d, c
68 |
69 | return f'{a} {b} {c} {d}'
70 |
71 |
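    | # Currently identical to pretty2r; kept as the entry point for angle pairs.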
72 | def pretty2a(a: str, b: str, c: str, d: str) -> str:
73 | if b in (c, d):
74 | a, b = b, a
75 |
76 | if a == d:
77 | c, d = d, c
78 |
79 | return f'{a} {b} {c} {d}'
80 |
81 |
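    | # Normalize argument order so that a shared endpoint, if any, comes first in
    | # both pairs; a shared vertex then prints as the angle at that vertex,
    | # otherwise the angle between segments AB and CD is printed.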
82 | def pretty_angle(a: str, b: str, c: str, d: str) -> str:
83 | if b in (c, d):
84 | a, b = b, a
85 | if a == d:
86 | c, d = d, c
87 |
88 | if a == c:
89 | return f'\u2220{b}{a}{d}'
90 | return f'\u2220({a}{b}-{c}{d})'
91 |
92 |
93 | def pretty_nl(name: str, args: list[str]) -> str:
94 | """Natural lang formatting a predicate."""
95 | if name == 'aconst':
96 | a, b, c, d, y = args
97 | return f'{pretty_angle(a, b, c, d)} = {y}'
98 | if name == 'rconst':
99 | a, b, c, d, y = args
100 | return f'{a}{b}:{c}{d} = {y}'
101 |   if name == 'acompute':
102 |     a, b, c, d = args
103 |     return pretty_angle(a, b, c, d)
104 |   if name in ['coll', 'C']:
105 |     return ','.join(args) + ' are collinear'
106 |   if name == 'collx':
107 |     return ','.join(sorted(set(args))) + ' are collinear'
108 |   if name in ['cyclic', 'O']:
109 |     return ','.join(args) + ' are concyclic'
110 | if name in ['midp', 'midpoint', 'M']:
111 | x, a, b = args
112 |     return f'{x} is the midpoint of {a}{b}'
113 | if name in ['eqangle', 'eqangle6', '^']:
114 | a, b, c, d, e, f, g, h = args
115 | return f'{pretty_angle(a, b, c, d)} = {pretty_angle(e, f, g, h)}'
116 | if name in ['eqratio', 'eqratio6', '/']:
117 | return '{}{}:{}{} = {}{}:{}{}'.format(*args)
118 | if name == 'eqratio3':
119 | a, b, c, d, o, o = args # pylint: disable=redeclared-assigned-name
120 | return f'S {o} {a} {b} {o} {c} {d}'
121 | if name in ['cong', 'D']:
122 | a, b, c, d = args
123 | return f'{a}{b} = {c}{d}'
124 | if name in ['perp', 'T']:
125 | if len(args) == 2: # this is algebraic derivation.
126 | ab, cd = args # ab = 'd( ... )'
127 | return f'{ab} \u27c2 {cd}'
128 | a, b, c, d = args
129 | return f'{a}{b} \u27c2 {c}{d}'
130 | if name in ['para', 'P']:
131 | if len(args) == 2: # this is algebraic derivation.
132 | ab, cd = args # ab = 'd( ... )'
133 | return f'{ab} \u2225 {cd}'
134 | a, b, c, d = args
135 | return f'{a}{b} \u2225 {c}{d}'
136 | if name in ['simtri2', 'simtri', 'simtri*']:
137 | a, b, c, x, y, z = args
138 | return f'\u0394{a}{b}{c} is similar to \u0394{x}{y}{z}'
139 | if name in ['contri2', 'contri', 'contri*']:
140 | a, b, c, x, y, z = args
141 | return f'\u0394{a}{b}{c} is congruent to \u0394{x}{y}{z}'
142 | if name in ['circle', 'I']:
143 | o, a, b, c = args
144 |     return f'{o} is the circumcenter of \u0394{a}{b}{c}'
145 | if name == 'foot':
146 | a, b, c, d = args
147 | return f'{a} is the foot of {b} on {c}{d}'
148 |
149 |
150 | def pretty(txt: str) -> str:
151 | """Pretty formating a predicate string."""
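    |   # Serializes the predicate back into the single-letter DSL (cf. MAP_SYMBOL).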
152 | if isinstance(txt, str):
153 | txt = txt.split(' ')
154 | name, *args = txt
155 | if name == 'ind':
156 | return 'Y ' + ' '.join(args)
157 | if name in ['fixc', 'fixl', 'fixb', 'fixt', 'fixp']:
158 | return map_symbol_inv(name) + ' ' + ' '.join(args)
159 | if name == 'acompute':
160 | a, b, c, d = args
161 | return 'A ' + ' '.join(args)
162 | if name == 'rcompute':
163 | a, b, c, d = args
164 | return 'R ' + ' '.join(args)
165 | if name == 'aconst':
166 | a, b, c, d, y = args
167 | return f'^ {pretty2a(a, b, c, d)} {y}'
168 | if name == 'rconst':
169 | a, b, c, d, y = args
170 | return f'/ {pretty2r(a, b, c, d)} {y}'
171 | if name == 'coll':
172 | return 'C ' + ' '.join(args)
173 | if name == 'collx':
174 | return 'X ' + ' '.join(args)
175 | if name == 'cyclic':
176 | return 'O ' + ' '.join(args)
177 | if name in ['midp', 'midpoint']:
178 | x, a, b = args
179 | return f'M {x} {a} {b}'
180 | if name == 'eqangle':
181 | a, b, c, d, e, f, g, h = args
182 | return f'^ {pretty2a(a, b, c, d)} {pretty2a(e, f, g, h)}'
183 | if name == 'eqratio':
184 | a, b, c, d, e, f, g, h = args
185 | return f'/ {pretty2r(a, b, c, d)} {pretty2r(e, f, g, h)}'
186 | if name == 'eqratio3':
187 | a, b, c, d, o, o = args # pylint: disable=redeclared-assigned-name
188 | return f'S {o} {a} {b} {o} {c} {d}'
189 | if name == 'cong':
190 | a, b, c, d = args
191 | return f'D {a} {b} {c} {d}'
192 | if name == 'perp':
193 | if len(args) == 2: # this is algebraic derivation.
194 | ab, cd = args # ab = 'd( ... )'
195 | return f'T {ab} {cd}'
196 | a, b, c, d = args
197 | return f'T {a} {b} {c} {d}'
198 | if name == 'para':
199 | if len(args) == 2: # this is algebraic derivation.
200 | ab, cd = args # ab = 'd( ... )'
201 | return f'P {ab} {cd}'
202 | a, b, c, d = args
203 | return f'P {a} {b} {c} {d}'
204 | if name in ['simtri2', 'simtri', 'simtri*']:
205 | a, b, c, x, y, z = args
206 | return f'S {a} {b} {c} {x} {y} {z}'
207 | if name in ['contri2', 'contri', 'contri*']:
208 | a, b, c, x, y, z = args
209 | return f'= {a} {b} {c} {x} {y} {z}'
210 | if name == 'circle':
211 | o, a, b, c = args
212 | return f'I {o} {a} {b} {c}'
213 | if name == 'foot':
214 | a, b, c, d = args
215 | return f'F {a} {b} {c} {d}'
216 | return ' '.join(txt)
217 |
--------------------------------------------------------------------------------
/problem_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Unit tests for problem.py."""
17 | import unittest
18 |
19 | from absl.testing import absltest
20 | import problem as pr
21 |
22 |
23 | class ProblemTest(unittest.TestCase):
24 |
25 | @classmethod
26 | def setUpClass(cls):
27 | super().setUpClass()
28 | cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
29 |
30 | def test_orthocenter_no_translate(self):
31 | txt = 'a b c = triangle a b c; h = on_tline h b a c, on_tline h c a b ? perp a h b c' # pylint: disable=line-too-long
32 |
33 |     # Read the txt into a pr.Problem object; do not rename points:
34 | p = pr.Problem.from_txt(txt, translate=False)
35 |
36 | # This is fed into the LM, translating from constructive to constrained:
37 | setup_str = p.setup_str_from_problem(ProblemTest.defs)
38 |
39 | self.assertEqual(
40 | setup_str,
41 | '{S} a : ; b : ; c : ; h : T a b c h 00 T a c b h 01 ? T a h b c',
42 | )
43 |
44 | def test_orthocenter_translate(self):
45 | txt = 'a b c = triangle a b c; h = on_tline h b a c, on_tline h c a b ? perp a h b c' # pylint: disable=line-too-long
46 |
47 |     # Read the txt into a pr.Problem object, change h -> d to match
48 | # training data distribution.
49 | p = pr.Problem.from_txt(txt, translate=True)
50 |
51 | # This is fed into the LM, translating from constructive to constrained:
52 | setup_str = p.setup_str_from_problem(ProblemTest.defs)
53 |
54 | self.assertEqual(
55 | setup_str,
56 | '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c',
57 | )
58 |
59 |
60 | if __name__ == '__main__':
61 | absltest.main()
62 |
--------------------------------------------------------------------------------
/requirements.in:
--------------------------------------------------------------------------------
1 | tensorflow==2.13.0
2 | numpy==1.23.5
3 | scipy==1.10.0
4 | matplotlib==3.7.0
5 | gdown==4.7.1
6 | jax==0.4.6
7 | jaxlib==0.4.6
8 | flax==0.5.3
9 | gin-config==0.5.0
10 | gin==0.1.6
11 | t5==0.9.4
12 | sentencepiece==0.1.99
13 | absl-py==1.4.0
14 | clu==0.0.7
15 | optax==0.1.7
16 | seqio==0.0.18
17 | tensorflow-datasets==4.9.3
18 |
--------------------------------------------------------------------------------
/rules.txt:
--------------------------------------------------------------------------------
1 | perp A B C D, perp C D E F, ncoll A B E => para A B E F
2 | cong O A O B, cong O B O C, cong O C O D => cyclic A B C D
3 | eqangle A B P Q C D P Q => para A B C D
4 | cyclic A B P Q => eqangle P A P B Q A Q B
5 | eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q
6 | cyclic A B C P Q R, eqangle C A C B R P R Q => cong A B P Q
7 | midp E A B, midp F A C => para E F B C
8 | para A B C D, coll O A C, coll O B D => eqratio3 A B C D O O
9 | perp A B C D, perp E F G H, npara A B E F => eqangle A B E F C D G H
10 | eqangle a b c d m n p q, eqangle c d e f p q r u => eqangle a b e f m n r u
11 | eqratio a b c d m n p q, eqratio c d e f p q r u => eqratio a b e f m n r u
12 | eqratio6 d b d c a b a c, coll d b c, ncoll a b c => eqangle6 a b a d a d a c
13 | eqangle6 a b a d a d a c, coll d b c, ncoll a b c => eqratio6 d b d c a b a c
14 | cong O A O B, ncoll O A B => eqangle O A A B A B O B
15 | eqangle6 A O A B B A B O, ncoll O A B => cong O A O B
16 | circle O A B C, perp O A A X => eqangle A X A B C A C B
17 | circle O A B C, eqangle A X A B C A C B => perp O A A X
18 | circle O A B C, midp M B C => eqangle A B A C O B O M
19 | circle O A B C, coll M B C, eqangle A B A C O B O M => midp M B C
20 | perp A B B C, midp M A C => cong A M B M
21 | circle O A B C, coll O A C => perp A B B C
22 | cyclic A B C D, para A B C D => eqangle A D C D C D C B
23 | midp M A B, perp O M A B => cong O A O B
24 | cong A P B P, cong A Q B Q => perp A B P Q
25 | cong A P B P, cong A Q B Q, cyclic A B P Q => perp P A A Q
26 | midp M A B, midp M C D => para A C B D
27 | midp M A B, para A C B D, para A D B C => midp M C D
28 | eqratio O A A C O B B D, coll O A C, coll O B D, ncoll A B C, sameside A O C B O D => para A B C D
29 | para A B A C => coll A B C
30 | midp M A B, midp N C D => eqratio M A A B N C C D
31 | eqangle A B P Q C D U V, perp P Q U V => perp A B C D
32 | eqratio A B P Q C D U V, cong P Q U V => cong A B C D
33 | cong A B P Q, cong B C Q R, cong C A R P, ncoll A B C => contri* A B C P Q R
34 | cong A B P Q, cong B C Q R, eqangle6 B A B C Q P Q R, ncoll A B C => contri* A B C P Q R
35 | eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C => simtri A B C P Q R
36 | eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C => simtri2 A B C P Q R
37 | eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri A B C P Q R
38 | eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C, cong A B P Q => contri2 A B C P Q R
39 | eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R
40 | eqratio6 B A B C Q P Q R, eqangle6 B A B C Q P Q R, ncoll A B C => simtri* A B C P Q R
41 | eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri* A B C P Q R
42 | para a b c d, coll m a d, coll n b c, eqratio6 m a m d n b n c, sameside m a d n b c => para m n a b
43 | para a b c d, coll m a d, coll n b c, para m n a b => eqratio6 m a m d n b n c
44 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | #!/bin/bash
17 | set -e
18 | set -x
19 |
20 | virtualenv -p python3 .
21 | source ./bin/activate
22 |
23 | pip install --require-hashes -r requirements.txt
24 |
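    | # Download the model checkpoint and vocabulary into ag_ckpt_vocab/.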
25 | gdown --folder https://bit.ly/alphageometry
26 | DATA=ag_ckpt_vocab
27 |
28 | MELIAD_PATH=meliad_lib/meliad
29 | mkdir -p $MELIAD_PATH
30 | git clone https://github.com/google-research/meliad $MELIAD_PATH
31 | export PYTHONPATH=$PYTHONPATH:$MELIAD_PATH
32 |
33 | DDAR_ARGS=(
34 | --defs_file=$(pwd)/defs.txt \
35 | --rules_file=$(pwd)/rules.txt \
36 | );
37 |
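    | # Small demo-scale search settings; harder problems likely need a larger
    | # beam size and search depth.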
38 | BATCH_SIZE=2
39 | BEAM_SIZE=2
40 | DEPTH=2
41 |
42 | SEARCH_ARGS=(
43 | --beam_size=$BEAM_SIZE
44 | --search_depth=$DEPTH
45 | )
46 |
47 | LM_ARGS=(
48 | --ckpt_path=$DATA \
49 | --vocab_path=$DATA/geometry.757.model \
50 | --gin_search_paths=$MELIAD_PATH/transformer/configs \
51 | --gin_file=base_htrans.gin \
52 | --gin_file=size/medium_150M.gin \
53 | --gin_file=options/positions_t5.gin \
54 | --gin_file=options/lr_cosine_decay.gin \
55 | --gin_file=options/seq_1024_nocache.gin \
56 | --gin_file=geometry_150M_generate.gin \
57 | --gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True \
58 | --gin_param=TransformerTaskConfig.batch_size=$BATCH_SIZE \
59 | --gin_param=TransformerTaskConfig.sequence_length=128 \
60 | --gin_param=Trainer.restore_state_variables=False
61 | );
62 |
63 | echo $PYTHONPATH
64 |
65 | python -m alphageometry \
66 | --alsologtostderr \
67 | --problems_file=$(pwd)/examples.txt \
68 | --problem_name=orthocenter \
69 | --mode=alphageometry \
70 | "${DDAR_ARGS[@]}" \
71 | "${SEARCH_ARGS[@]}" \
72 | "${LM_ARGS[@]}"
73 |
--------------------------------------------------------------------------------
/run_tests.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | DATA=ag_ckpt_vocab
17 | MELIAD_PATH=meliad_lib/meliad
18 | export PYTHONPATH=$PYTHONPATH:$MELIAD_PATH
19 |
20 | python problem_test.py
21 | python geometry_test.py
22 | python graph_utils_test.py
23 | python numericals_test.py
24 | python graph_test.py
25 | python dd_test.py
26 | python ar_test.py
27 | python ddar_test.py
28 | python trace_back_test.py
29 | python alphageometry_test.py
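    | # Requires the checkpoint and vocab downloaded into $DATA (see run.sh).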
30 | python lm_inference_test.py --meliad_path=$MELIAD_PATH --data_path=$DATA
31 |
--------------------------------------------------------------------------------
/trace_back.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Implements DAG-level traceback."""
17 |
18 | from typing import Any
19 |
20 | import geometry as gm
21 | import pretty as pt
22 | import problem
23 |
24 |
25 | pretty = pt.pretty
26 |
27 |
28 | def point_levels(
29 | setup: list[problem.Dependency], existing_points: list[gm.Point]
30 | ) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
31 | """Reformat setup into levels of point constructions."""
32 | levels = []
33 | for con in setup:
34 | plevel = max([p.plevel for p in con.args if isinstance(p, gm.Point)])
35 |
36 | while len(levels) - 1 < plevel:
37 | levels.append((set(), []))
38 |
39 | for p in con.args:
40 | if not isinstance(p, gm.Point):
41 | continue
42 | if existing_points and p in existing_points:
43 | continue
44 |
45 | levels[p.plevel][0].add(p)
46 |
47 | cons = levels[plevel][1]
48 | cons.append(con)
49 |
50 | return [(p, c) for p, c in levels if p or c]
51 |
52 |
53 | def point_log(
54 | setup: list[problem.Dependency],
55 | ref_id: dict[tuple[str, ...], int],
56 |     existing_points: list[gm.Point] = None,
57 | ) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
58 | """Reformat setup into groups of point constructions."""
59 | log = []
60 |
61 | levels = point_levels(setup, existing_points)
62 |
63 | for points, cons in levels:
64 | for con in cons:
65 | if con.hashed() not in ref_id:
66 | ref_id[con.hashed()] = len(ref_id)
67 |
68 | log.append((points, cons))
69 |
70 | return log
71 |
72 |
73 | def setup_to_levels(
74 | setup: list[problem.Dependency],
75 | ) -> list[list[problem.Dependency]]:
76 | """Reformat setup into levels of point constructions."""
77 | levels = []
78 | for d in setup:
79 | plevel = max([p.plevel for p in d.args if isinstance(p, gm.Point)])
80 | while len(levels) - 1 < plevel:
81 | levels.append([])
82 |
83 | levels[plevel].append(d)
84 |
85 | levels = [lvl for lvl in levels if lvl]
86 | return levels
87 |
88 |
89 | def separate_dependency_difference(
90 | query: problem.Dependency,
91 | log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
92 | ) -> tuple[
93 | list[tuple[list[problem.Dependency], list[problem.Dependency]]],
94 | list[problem.Dependency],
95 | list[problem.Dependency],
96 | set[gm.Point],
97 | set[gm.Point],
98 | ]:
99 | """Identify and separate the dependency difference."""
100 | setup = []
101 | log_, log = log, []
102 | for prems, cons in log_:
103 | if not prems:
104 | setup.extend(cons)
105 | continue
106 | cons_ = []
107 | for con in cons:
108 | if con.rule_name == 'c0':
109 | setup.append(con)
110 | else:
111 | cons_.append(con)
112 | if not cons_:
113 | continue
114 |
115 | prems = [p for p in prems if p.name != 'ind']
116 | log.append((prems, cons_))
117 |
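    |   # BFS over rely_on: collect every point the goal transitively depends on.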
118 | points = set(query.args)
119 | queue = list(query.args)
120 | i = 0
121 | while i < len(queue):
122 | q = queue[i]
123 | i += 1
124 | if not isinstance(q, gm.Point):
125 | continue
126 | for p in q.rely_on:
127 | if p not in points:
128 | points.add(p)
129 | queue.append(p)
130 |
131 | setup_, setup, aux_setup, aux_points = setup, [], [], set()
132 | for con in setup_:
133 | if con.name == 'ind':
134 | continue
135 | elif any([p not in points for p in con.args if isinstance(p, gm.Point)]):
136 | aux_setup.append(con)
137 | aux_points.update(
138 | [p for p in con.args if isinstance(p, gm.Point) and p not in points]
139 | )
140 | else:
141 | setup.append(con)
142 |
143 | return log, setup, aux_setup, points, aux_points
144 |
145 |
146 | def recursive_traceback(
147 | query: problem.Dependency,
148 | ) -> list[tuple[list[problem.Dependency], list[problem.Dependency]]]:
149 | """Recursively traceback from the query, i.e. the conclusion."""
150 | visited = set()
151 | log = []
152 | stack = []
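    |   # `stack` tracks the current DFS path; it is pushed and popped but never read.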
153 |
154 | def read(q: problem.Dependency) -> None:
155 | q = q.remove_loop()
156 | hashed = q.hashed()
157 | if hashed in visited:
158 | return
159 |
160 | if hashed[0] in ['ncoll', 'npara', 'nperp', 'diff', 'sameside']:
161 | return
162 |
163 | nonlocal stack
164 |
165 | stack.append(hashed)
166 | prems = []
167 |
168 | if q.rule_name != problem.CONSTRUCTION_RULE:
169 | all_deps = []
170 | dep_names = set()
171 | for d in q.why:
172 | if d.hashed() in dep_names:
173 | continue
174 | dep_names.add(d.hashed())
175 | all_deps.append(d)
176 |
177 | for d in all_deps:
178 | h = d.hashed()
179 | if h not in visited:
180 | read(d)
181 | if h in visited:
182 | prems.append(d)
183 |
184 | visited.add(hashed)
185 | hashs = sorted([d.hashed() for d in prems])
186 | found = False
187 | for ps, qs in log:
188 | if sorted([d.hashed() for d in ps]) == hashs:
189 | qs += [q]
190 | found = True
191 | break
192 | if not found:
193 | log.append((prems, [q]))
194 |
195 | stack.pop(-1)
196 |
197 | read(query)
198 |
199 | # post process log: separate multi-conclusion lines
200 | log_, log = log, []
201 | for ps, qs in log_:
202 | for q in qs:
203 | log.append((ps, [q]))
204 |
205 | return log
206 |
207 |
208 | def collx_to_coll_setup(
209 | setup: list[problem.Dependency],
210 | ) -> list[problem.Dependency]:
211 | """Convert collx to coll in setups."""
212 | result = []
213 | for level in setup_to_levels(setup):
214 | hashs = set()
215 | for dep in level:
216 | if dep.name == 'collx':
217 | dep.name = 'coll'
218 | dep.args = list(set(dep.args))
219 |
220 | if dep.hashed() in hashs:
221 | continue
222 | hashs.add(dep.hashed())
223 | result.append(dep)
224 |
225 | return result
226 |
227 |
228 | def collx_to_coll(
229 | setup: list[problem.Dependency],
230 | aux_setup: list[problem.Dependency],
231 | log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
232 | ) -> tuple[
233 | list[problem.Dependency],
234 | list[problem.Dependency],
235 | list[tuple[list[problem.Dependency], list[problem.Dependency]]],
236 | ]:
237 | """Convert collx to coll and dedup."""
238 | setup = collx_to_coll_setup(setup)
239 | aux_setup = collx_to_coll_setup(aux_setup)
240 |
241 | con_set = set([p.hashed() for p in setup + aux_setup])
242 | log_, log = log, []
243 | for prems, cons in log_:
244 | prem_set = set()
245 | prems_, prems = prems, []
246 | for p in prems_:
247 | if p.name == 'collx':
248 | p.name = 'coll'
249 | p.args = list(set(p.args))
250 | if p.hashed() in prem_set:
251 | continue
252 | prem_set.add(p.hashed())
253 | prems.append(p)
254 |
255 | cons_, cons = cons, []
256 | for c in cons_:
257 | if c.name == 'collx':
258 | c.name = 'coll'
259 | c.args = list(set(c.args))
260 | if c.hashed() in con_set:
261 | continue
262 | con_set.add(c.hashed())
263 | cons.append(c)
264 |
265 | if not cons or not prems:
266 | continue
267 |
268 | log.append((prems, cons))
269 |
270 | return setup, aux_setup, log
271 |
272 |
273 | def get_logs(
274 | query: problem.Dependency, g: Any, merge_trivials: bool = False
275 | ) -> tuple[
276 | list[problem.Dependency],
277 | list[problem.Dependency],
278 | list[tuple[list[problem.Dependency], list[problem.Dependency]]],
279 | set[gm.Point],
280 | ]:
281 | """Given a DAG and conclusion N, return the premise, aux, proof."""
282 | query = query.why_me_or_cache(g, query.level)
283 | log = recursive_traceback(query)
284 | log, setup, aux_setup, setup_points, _ = separate_dependency_difference(
285 | query, log
286 | )
287 |
288 | setup, aux_setup, log = collx_to_coll(setup, aux_setup, log)
289 |
290 | setup, aux_setup, log = shorten_and_shave(
291 | setup, aux_setup, log, merge_trivials
292 | )
293 |
294 | return setup, aux_setup, log, setup_points
295 |
296 |
297 | def shorten_and_shave(
298 | setup: list[problem.Dependency],
299 | aux_setup: list[problem.Dependency],
300 | log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
301 | merge_trivials: bool = False,
302 | ) -> tuple[
303 | list[problem.Dependency],
304 | list[problem.Dependency],
305 | list[tuple[list[problem.Dependency], list[problem.Dependency]]],
306 | ]:
307 | """Shorten the proof by removing unused predicates."""
308 | log, _ = shorten_proof(log, merge_trivials=merge_trivials)
309 |
310 | all_prems = sum([list(prems) for prems, _ in log], [])
311 | all_prems = set([p.hashed() for p in all_prems])
312 | setup = [d for d in setup if d.hashed() in all_prems]
313 | aux_setup = [d for d in aux_setup if d.hashed() in all_prems]
314 | return setup, aux_setup, log
315 |
316 |
317 | def join_prems(
318 | con: problem.Dependency,
319 | con2prems: dict[tuple[str, ...], list[problem.Dependency]],
320 | expanded: set[tuple[str, ...]],
321 | ) -> list[problem.Dependency]:
322 | """Join proof steps with the same premises."""
323 | h = con.hashed()
324 | if h in expanded or h not in con2prems:
325 | return [con]
326 |
327 | result = []
328 | for p in con2prems[h]:
329 | result += join_prems(p, con2prems, expanded)
330 | return result
331 |
332 |
333 | def shorten_proof(
334 | log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
335 | merge_trivials: bool = False,
336 | ) -> tuple[
337 | list[tuple[list[problem.Dependency], list[problem.Dependency]]],
338 | dict[tuple[str, ...], list[problem.Dependency]],
339 | ]:
340 | """Join multiple trivials proof steps into one."""
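    |   # A trivial step (empty rule name) is folded into the steps consuming its
    |   # conclusion, unless merge_trivials is False and that conclusion is also
    |   # a premise of a non-trivial step.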
341 | pops = set()
342 | con2prem = {}
343 | for prems, cons in log:
344 | assert len(cons) == 1
345 | con = cons[0]
346 | if con.rule_name == '': # pylint: disable=g-explicit-bool-comparison
347 | con2prem[con.hashed()] = prems
348 | elif not merge_trivials:
349 |       # Mark premises of non-trivial steps; trivial steps producing them survive.
350 | pops.update({p.hashed() for p in prems})
351 |
352 | for p in pops:
353 | if p in con2prem:
354 | con2prem.pop(p)
355 |
356 | expanded = set()
357 | log2 = []
358 | for i, (prems, cons) in enumerate(log):
359 | con = cons[0]
360 | if i < len(log) - 1 and con.hashed() in con2prem:
361 | continue
362 |
363 | hashs = set()
364 | new_prems = []
365 |
366 | for p in sum([join_prems(p, con2prem, expanded) for p in prems], []):
367 | if p.hashed() not in hashs:
368 | new_prems.append(p)
369 | hashs.add(p.hashed())
370 |
371 | log2 += [(new_prems, [con])]
372 | expanded.add(con.hashed())
373 |
374 | return log2, con2prem
375 |
--------------------------------------------------------------------------------
/trace_back_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Unit testing for the trace_back code."""
17 |
18 | import unittest
19 |
20 | from absl.testing import absltest
21 | import ddar
22 | import graph as gh
23 | import problem as pr
24 | import trace_back as tb
25 |
26 |
27 | class TracebackTest(unittest.TestCase):
28 |
29 | @classmethod
30 | def setUpClass(cls):
31 | super().setUpClass()
32 | cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
33 | cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True)
34 |
35 | def test_orthocenter_dependency_difference(self):
36 | txt = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c' # pylint: disable=line-too-long
37 | p = pr.Problem.from_txt(txt)
38 | g, _ = gh.Graph.build_problem(p, TracebackTest.defs)
39 |
40 | ddar.solve(g, TracebackTest.rules, p)
41 |
42 | goal_args = g.names2nodes(p.goal.args)
43 | query = pr.Dependency(p.goal.name, goal_args, None, None)
44 |
45 | setup, aux, _, _ = tb.get_logs(query, g, merge_trivials=False)
46 |
47 |     # Convert each predicate to its hash string:
48 | setup = [p.hashed() for p in setup]
49 | aux = [p.hashed() for p in aux]
50 |
51 | self.assertCountEqual(
52 | setup, [('perp', 'a', 'c', 'b', 'd'), ('perp', 'a', 'b', 'c', 'd')]
53 | )
54 |
55 | self.assertCountEqual(
56 | aux, [('coll', 'a', 'c', 'e'), ('coll', 'b', 'd', 'e')]
57 | )
58 |
59 |
60 | if __name__ == '__main__':
61 | absltest.main()
62 |
--------------------------------------------------------------------------------