├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── alphageometry.py ├── alphageometry_test.py ├── ar.py ├── ar_test.py ├── beam_search.py ├── dd.py ├── dd_test.py ├── ddar.py ├── ddar_test.py ├── decoder_stack.py ├── defs.txt ├── download.sh ├── examples.txt ├── fig1.svg ├── geometry.py ├── geometry_150M_generate.gin ├── geometry_test.py ├── graph.py ├── graph_test.py ├── graph_utils.py ├── graph_utils_test.py ├── imo_ag_30.txt ├── jgex_ag_231.txt ├── lm_inference.py ├── lm_inference_test.py ├── models.py ├── numericals.py ├── numericals_test.py ├── pretty.py ├── problem.py ├── problem_test.py ├── requirements.in ├── requirements.txt ├── rules.txt ├── run.sh ├── run_tests.sh ├── trace_back.py ├── trace_back_test.py └── transformer_layer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Solving Olympiad Geometry without Human Demonstrations 3 | 4 | 5 | This repository contains the code necessary to 6 | reproduce DDAR and AlphaGeometry, 7 | the two geometry theorem provers 8 | introduced in the [Nature 2024](https://www.nature.com/articles/s41586-023-06747-5) paper: 9 | 10 | *
"Solving Olympiad Geometry without Human Demonstrations".
* 11 | 12 | 13 | ![fig1](fig1.svg) 14 | 15 | 16 | 17 | 18 |
 19 | 20 | 21 | ## Dependencies 22 | 23 | For the instructions presented below, 24 | we use Python 3.10.9 and the dependencies pinned to exact 25 | version numbers in `requirements.txt`. 26 | 27 | Our code depends on `meliad`, which is 28 | not a package registered with `pip`. See instructions below 29 | for how to manually install `meliad`. 30 | 31 | Note that one can still run the DDAR solver 32 | without the `meliad` and `sentencepiece` dependencies. 33 | 34 | ## Run the instructions 35 | 36 | All instructions in this `README.md` can be run in one go by: 37 | 38 | ``` 39 | bash run.sh 40 | ``` 41 | 42 | Below, we explain these instructions step-by-step. 43 | 44 | ## Install dependencies, download weights and vocabulary 45 | 46 | Installation is done in a virtual environment: 47 | 48 | ``` 49 | virtualenv -p python3 . 50 | source ./bin/activate 51 | pip install --require-hashes -r requirements.txt 52 | ``` 53 | 54 | Download weights and vocabulary: 55 | 56 | ``` 57 | bash download.sh 58 | DATA=ag_ckpt_vocab 59 | ``` 60 | 61 | Finally, install `meliad` separately, as it is not 62 | registered with `pip`: 63 | 64 | ``` 65 | MELIAD_PATH=meliad_lib/meliad 66 | mkdir -p $MELIAD_PATH 67 | git clone https://github.com/google-research/meliad $MELIAD_PATH 68 | export PYTHONPATH=$PYTHONPATH:$MELIAD_PATH 69 | ``` 70 | 71 | ## Set up common flags 72 | 73 | Before running the Python scripts, 74 | let us first prepare some commonly used flags. 75 | The symbolic engine needs definitions and deduction rules to operate. 76 | These definitions and rules are provided in the two text files 77 | `defs.txt` and `rules.txt`. 78 | 79 | ```shell 80 | DDAR_ARGS=( 81 | --defs_file=$(pwd)/defs.txt \ 82 | --rules_file=$(pwd)/rules.txt \ 83 | ); 84 | ``` 85 | 86 | Next, we define the flags relevant to the proof search. 87 | To reproduce the simple examples below, 88 | we use lightweight values for the proof search parameters: 89 | 90 | ```shell 91 | BATCH_SIZE=2 92 | BEAM_SIZE=2 93 | DEPTH=2 94 | 95 | SEARCH_ARGS=( 96 | --beam_size=$BEAM_SIZE 97 | --search_depth=$DEPTH 98 | ) 99 | ``` 100 | 101 | NOTE: The results in our paper can be obtained by setting 102 | `BATCH_SIZE=32`, `BEAM_SIZE=512`, `DEPTH=16` 103 | as described in the Methods section. 104 | To stay under IMO time limits, 4 V100 GPUs and 250 CPU workers 105 | are needed, as shown in Extended Data - Figure 1. 106 | Note that we also strip away other memory/speed optimizations 107 | due to internal dependencies and to promote code clarity. 108 | 109 | Assume the downloaded checkpoint and vocabulary are placed in `DATA`, 110 | and the installed `meliad` source code is at `MELIAD_PATH`. 111 | We make use of the `gin` library to manage model configurations, 112 | following `meliad` conventions.
We now define the flags relevant to the 113 | language model: 114 | 115 | ```shell 116 | LM_ARGS=( 117 | --ckpt_path=$DATA \ 118 | --vocab_path=$DATA/geometry.757.model 119 | --gin_search_paths=$MELIAD_PATH/transformer/configs,$(pwd) \ 120 | --gin_file=base_htrans.gin \ 121 | --gin_file=size/medium_150M.gin \ 122 | --gin_file=options/positions_t5.gin \ 123 | --gin_file=options/lr_cosine_decay.gin \ 124 | --gin_file=options/seq_1024_nocache.gin \ 125 | --gin_file=geometry_150M_generate.gin \ 126 | --gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True \ 127 | --gin_param=TransformerTaskConfig.batch_size=$BATCH_SIZE \ 128 | --gin_param=TransformerTaskConfig.sequence_length=128 \ 129 | --gin_param=Trainer.restore_state_variables=False 130 | ); 131 | ``` 132 | 133 | TIP: Note that you can still run the DDAR solver 134 | without defining `SEARCH_ARGS` and `LM_ARGS`. 135 | In that case, simply disable the import of the `lm_inference` module 136 | inside `alphageometry.py`. 137 | 138 | ## Run DDAR 139 | 140 | The script reads a list of problems 141 | from a text file and solves the specific problem in the list according 142 | to its name. We pass these two pieces of information through the flags 143 | `--problems_file` and `--problem_name`. 144 | We use `--mode=ddar` to indicate that we want to use the DDAR solver. 145 | 146 | Below, we show this solver solving IMO 2000 P1: 147 | 148 | ```shell 149 | python -m alphageometry \ 150 | --alsologtostderr \ 151 | --problems_file=$(pwd)/imo_ag_30.txt \ 152 | --problem_name=translated_imo_2000_p1 \ 153 | --mode=ddar \ 154 | "${DDAR_ARGS[@]}" 155 | ``` 156 | 157 | Expect the following output: 158 | 159 | ```shell 160 | graph.py:468] translated_imo_2000_p1 161 | graph.py:469] a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m = on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b, on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q 162 | ddar.py:41] Depth 1/1000 time = 1.7772269248962402 163 | ddar.py:41] Depth 2/1000 time = 5.63526177406311 164 | ddar.py:41] Depth 3/1000 time = 6.883412837982178 165 | ddar.py:41] Depth 4/1000 time = 10.275688409805298 166 | ddar.py:41] Depth 5/1000 time = 12.048273086547852 167 | alphageometry.py:190] 168 | ========================== 169 | * From theorem premises: 170 | A B G1 G2 M N C D E P Q : Points 171 | AG_1 ⟂ AB [00] 172 | BA ⟂ G_2B [01] 173 | G_2M = G_2B [02] 174 | G_1M = G_1A [03] 175 | 176 | ... 177 | [log omitted] 178 | ... 179 | 180 | 036. ∠QEB = ∠(QP-EA) [46] & ∠(BE-QP) = ∠AEP [55] ⇒ ∠EQP = ∠QPE [56] 181 | 037. ∠PQE = ∠EPQ [56] ⇒ EP = EQ 182 | 183 | ========================== 184 | ``` 185 | 186 | The output first includes a list of the relevant premises that it uses, 187 | and then proof steps that gradually build up the proof. 188 | All predicates are numbered to track how they are derived 189 | from the premises, and to show that the proof is fully justified. 190 | 191 | TIP: Additionally passing the flag `--out_file=path/to/output/text/file.txt` 192 | will write the proof to a text file. 193 | 194 | Running on all problems in `imo_ag_30.txt` will yield solutions to 195 | 14 of them, as reported in Table 1 in our paper. 196 | 197 | ## Run AlphaGeometry 198 | 199 | As a simple example, we load `--problem_name=orthocenter` 200 | from `--problems_file=examples.txt`.
201 | This time, we pass `--mode=alphageometry` to use the AlphaGeometry solver 202 | and pass the `SEARCH_ARGS` and `LM_ARGS` flags. 203 | 204 | ```shell 205 | python -m alphageometry \ 206 | --alsologtostderr \ 207 | --problems_file=$(pwd)/examples.txt \ 208 | --problem_name=orthocenter \ 209 | --mode=alphageometry \ 210 | "${DDAR_ARGS[@]}" \ 211 | "${SEARCH_ARGS[@]}" \ 212 | "${LM_ARGS[@]}" 213 | ``` 214 | 215 | Expect the following output: 216 | 217 | ```shell 218 | ... 219 | [log omitted] 220 | ... 221 | training_loop.py:725] Total parameters: 152072288 222 | training_loop.py:739] Total state size: 0 223 | training_loop.py:492] Training loop: creating task for mode beam_search 224 | 225 | graph.py:468] orthocenter 226 | graph.py:469] a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c 227 | ddar.py:41] Depth 1/1000 time = 0.009987592697143555 branch = 4 228 | ddar.py:41] Depth 2/1000 time = 0.00672602653503418 branch = 0 229 | alphageometry.py:221] DD+AR failed to solve the problem. 230 | alphageometry.py:457] Depth 0. There are 1 nodes to expand: 231 | alphageometry.py:460] {S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c {F1} x00 232 | alphageometry.py:465] Decoding from {S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c {F1} x00 233 | ... 234 | [log omitted] 235 | ... 236 | alphageometry.py:470] LM output (score=-1.102287): "e : C a c e 02 C b d e 03 ;" 237 | alphageometry.py:471] Translation: "e = on_line e a c, on_line e b d" 238 | 239 | alphageometry.py:480] Solving: "a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c" 240 | graph.py:468] 241 | graph.py:469] a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c 242 | ddar.py:41] Depth 1/1000 time = 0.021120786666870117 243 | ddar.py:41] Depth 2/1000 time = 0.033370018005371094 244 | ddar.py:41] Depth 3/1000 time = 0.04297471046447754 245 | alphageometry.py:140] 246 | ========================== 247 | * From theorem premises: 248 | A B C D : Points 249 | BD ⟂ AC [00] 250 | CD ⟂ AB [01] 251 | 252 | * Auxiliary Constructions: 253 | E : Points 254 | E,B,D are collinear [02] 255 | E,C,A are collinear [03] 256 | 257 | * Proof steps: 258 | 001. E,B,D are collinear [02] & E,C,A are collinear [03] & BD ⟂ AC [00] ⇒ ∠BEA = ∠CED [04] 259 | 002. E,B,D are collinear [02] & E,C,A are collinear [03] & BD ⟂ AC [00] ⇒ ∠BEC = ∠AED [05] 260 | 003. A,E,C are collinear [03] & E,B,D are collinear [02] & AC ⟂ BD [00] ⇒ EC ⟂ EB [06] 261 | 004. EC ⟂ EB [06] & CD ⟂ AB [01] ⇒ ∠(EC-BA) = ∠(EB-CD) [07] 262 | 005. E,C,A are collinear [03] & E,B,D are collinear [02] & ∠(EC-BA) = ∠(EB-CD) [07] ⇒ ∠BAE = ∠CDE [08] 263 | 006. ∠BEA = ∠CED [04] & ∠BAE = ∠CDE [08] (Similar Triangles)⇒ EB:EC = EA:ED [09] 264 | 007. EB:EC = EA:ED [09] & ∠BEC = ∠AED [05] (Similar Triangles)⇒ ∠BCE = ∠ADE [10] 265 | 008. EB:EC = EA:ED [09] & ∠BEC = ∠AED [05] (Similar Triangles)⇒ ∠EBC = ∠EAD [11] 266 | 009. ∠BCE = ∠ADE [10] & E,C,A are collinear [03] & E,B,D are collinear [02] & ∠EBC = ∠EAD [11] ⇒ AD ⟂ BC 267 | ========================== 268 | 269 | alphageometry.py:505] Solved. 270 | ``` 271 | 272 | NOTE: Point `H` is automatically renamed to `D`, 273 | as the LM is trained on synthetic problems 274 | where the points are named alphabetically, and so it expects 275 | the same during test time. 
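For reference, the `LM output` and `Translation` log lines above come from the translation step in `alphageometry.py` (`try_translate_constrained_to_construct`, which calls `translate_constrained_to_constructive`): each constrained predicate emitted by the LM is rewritten as a constructive action, solved for the new point. A minimal sketch of that step in isolation, mirroring the log above:

```python
# A minimal sketch of the constrained-to-constructive translation shown in
# the log above. Note: importing alphageometry also imports lm_inference,
# so meliad must be on PYTHONPATH (or that import disabled, as per the TIP
# earlier in this README).
import alphageometry

# 'C a c e' reads "a, c, e are collinear"; solve it for the new point e.
name, args = alphageometry.translate_constrained_to_constructive(
    point='e', name='C', args=['a', 'c', 'e']
)
print(name, args)  # -> on_line ['e', 'a', 'c'], i.e. "e = on_line e a c"
```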
 276 | 277 | NOTE: In this implementation of AlphaGeometry, 278 | we removed all optimizations that depend on 279 | internal infrastructure, e.g., 280 | parallelized model inference on multiple GPUs, 281 | parallelized DDAR on multiple CPUs, 282 | parallel execution of LM and DDAR, 283 | shared pool of CPU workers across different problems, etc. 284 | We also removed some memory/speed optimizations and code 285 | abstractions in favor of code clarity. 286 | 287 | As can be seen in the output, DDAR initially fails to solve the problem. 288 | The LM then proposes two auxiliary constructions (because `BATCH_SIZE=2`): 289 | 290 | * `e = eqdistance e c a b, eqdistance e b a c`, i.e., 291 | construct `E` as the intersection of the circle (center=C, radius=AB) and 292 | the circle (center=B, radius=AC). This construction has a score of `-1.186`. 293 | * `e = on_line e a c, on_line e b d`, i.e., 294 | `E` is the intersection of `AC` and `BD`. 295 | This construction has a higher score (`-1.102287`) than the previous one. 296 | 297 | Since the second construction has a higher score, DDAR attempts it 298 | first and finds the solution right away. 299 | The proof search therefore terminates, and there is no second iteration. 300 | 301 | ## Results 302 | 303 | Before attempting to reproduce the AlphaGeometry numbers in our paper, 304 | please make sure all tests in the prepared test suite pass: 305 | 306 | ``` 307 | bash run_tests.sh 308 | ``` 309 | 310 | NOTE: [Issue #14](https://github.com/google-deepmind/alphageometry/issues/14) reports that although the top beam decodes are still the same, the LM does not produce identical scores for different users. 311 | 312 | Then, pass the corresponding values for `--problems_file` (column) 313 | and `--mode` (row), and 314 | iterate on all problems to obtain the following results (a sketch of such a loop follows the table): 315 | 316 |
317 | 318 | Number of solved problems: 319 | 320 | | | `imo_ag_30.txt` | `jgex_ag_231.txt` | 321 | |----------|------------------|-------------------| 322 | | `ddar` | 14 | 198 | 323 | | `alphageometry` | 25 | 228 | 324 | 325 |
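To iterate on all problems, one can enumerate the problem names with `problem.py` and launch one solver process per problem. The sketch below is an assumption, not part of the repository's scripts; in particular, the `solutions/` output directory is an arbitrary choice:

```python
# A sketch (not part of run.sh) of looping the DD+AR solver over every
# problem in imo_ag_30.txt, writing one proof file per problem.
# Assumption: Problem.from_txt_file's defaults suffice for listing names.
import os
import subprocess

import problem as pr

problems = pr.Problem.from_txt_file('imo_ag_30.txt', to_dict=True)
os.makedirs('solutions', exist_ok=True)

for name in problems:
    subprocess.run(
        [
            'python', '-m', 'alphageometry',
            '--alsologtostderr',
            '--problems_file=imo_ag_30.txt',
            f'--problem_name={name}',
            '--mode=ddar',
            '--defs_file=defs.txt',
            '--rules_file=rules.txt',
            f'--out_file=solutions/{name}.txt',
        ],
        check=False,  # continue with the next problem even if one run fails
    )
```

With `--mode=ddar` this corresponds to the `ddar` row above; switching to `--mode=alphageometry` and adding the `SEARCH_ARGS` and `LM_ARGS` flags corresponds to the `alphageometry` row, at a much higher compute cost.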
 326 | 327 | ## Source code description 328 | 329 | Files in this repository include Python modules/scripts to run the solvers and 330 | resource files necessary for the scripts to execute. We list below 331 | each of them with its description. 332 | 333 | | File name | Description | 334 | |------------------------|------------------------------------------------------------------------------------| 335 | | `geometry.py` | Implements nodes (Point, Line, Circle, etc.) in the proof state graph. | 336 | | `numericals.py` | Implements the numerical engine in the dynamic geometry environment. | 337 | | `graph_utils.py` | Implements utilities for the proof state graph. | 338 | | `graph.py` | Implements the proof state graph. | 339 | | `problem.py` | Implements the classes that represent the problem premises, conclusion, DAG nodes. | 340 | | `dd.py` | Implements DD and its traceback. | 341 | | `ar.py` | Implements AR and its traceback. | 342 | | `trace_back.py` | Implements the recursive traceback and dependency difference algorithm. | 343 | | `ddar.py` | Implements the combination DD+AR. | 344 | | `beam_search.py` | Implements beam decoding of a language model in JAX. | 345 | | `models.py` | Implements the transformer model. | 346 | | `transformer_layer.py` | Implements the transformer layer. | 347 | | `decoder_stack.py` | Implements the transformer decoder stack. | 348 | | `lm_inference.py` | Implements an interface to a trained LM to perform decoding. | 349 | | `alphageometry.py` | Main script that loads problems, calls the DD+AR or AlphaGeometry solver, and prints solutions. | 350 | | `pretty.py` | Pretty-formats the solutions output by the solvers. | 351 | | `*_test.py` | Tests for the corresponding module. | 352 | | `download.sh` | Script to download model checkpoints and the LM vocabulary. | 353 | | `run.sh` | Script to execute the instructions in this README. | 354 | | `run_tests.sh` | Script to execute the test suite. | 355 | 356 | 357 | Resource files: 358 | 359 | | Resource file name | Description | 360 | |------------------------|------------------------------------------------------------------------------------| 361 | | `defs.txt` | Definitions of different geometric construction actions. | 362 | | `rules.txt` | Deduction rules for DD. | 363 | | `geometry_150M_generate.gin` | Gin config of the LM implemented in meliad. | 364 | | `imo_ag_30.txt` | Problems in IMO-AG-30. | 365 | | `jgex_ag_231.txt` | Problems in JGEX-AG-231. | 366 | 367 | 368 | 369 | ## Citing this work 370 | 371 | ```bibtex 372 | @Article{AlphaGeometryTrinh2024, 373 | author = {Trinh, Trieu and Wu, Yuhuai and Le, Quoc and He, He and Luong, Thang}, 374 | journal = {Nature}, 375 | title = {Solving Olympiad Geometry without Human Demonstrations}, 376 | year = {2024}, 377 | doi = {10.1038/s41586-023-06747-5} 378 | } 379 | ``` 380 | 381 | ## Acknowledgements 382 | 383 | This research is a collaboration between the Google Brain team 384 | (now Google DeepMind) and 385 | the Computer Science Department of New York University. 386 | We thank Rif A. Saurous, Denny Zhou, Christian Szegedy, Delesley Hutchins, 387 | Thomas Kipf, Hieu Pham, Petar Veličković, Debidatta Dwibedi, 388 | Kyunghyun Cho, Lerrel Pinto, Alfredo Canziani, 389 | Thomas Wies, He He’s research group, 390 | Evan Chen (the USA’s IMO team coach), 391 | Mirek Olsak, Patrik Bak, 392 | and all three of Nature's referees for their help and support.
393 | 394 | The code of AlphaGeometry communicates with and/or references the following 395 | separate libraries and packages: 396 | 397 | * [Abseil](https://github.com/abseil/abseil-py) 398 | * [JAX](https://github.com/google/jax/) 399 | * [matplotlib](https://matplotlib.org/) 400 | * [NumPy](https://numpy.org) 401 | * [SciPy](https://scipy.org) 402 | * [TensorFlow](https://github.com/tensorflow/tensorflow) 403 | * [Meliad](https://github.com/google-research/meliad) 404 | * [Flax](https://github.com/google/flax) 405 | * [Gin](https://github.com/google/gin-config) 406 | * [T5](https://github.com/google-research/text-to-text-transfer-transformer) 407 | * [SentencePiece](https://github.com/google/sentencepiece) 408 | 409 | 410 | 411 | We thank all their contributors and maintainers! 412 | 413 | 414 | ## Disclaimer 415 | 416 | This is not an officially supported Google product. 417 | 418 | This research code is provided "as-is" to the broader research community. 419 | Google does not promise to maintain or otherwise support this code in any way. 420 | 421 | ## Code License 422 | 423 | Copyright 2023 DeepMind Technologies Limited 424 | 425 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 426 | you may not use this file except in compliance with the Apache 2.0 license. 427 | You may obtain a copy of the Apache 2.0 license at: 428 | https://www.apache.org/licenses/LICENSE-2.0 429 | 430 | All other materials are licensed under the Creative Commons Attribution 4.0 431 | International License (CC-BY). You may obtain a copy of the CC-BY license at: 432 | https://creativecommons.org/licenses/by/4.0/legalcode 433 | 434 | Unless required by applicable law or agreed to in writing, all software and 435 | materials distributed here under the Apache 2.0 or CC-BY licenses are 436 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 437 | either express or implied. See the licenses for the specific language governing 438 | permissions and limitations under those licenses. 439 | 440 | ## Model Parameters License 441 | 442 | The AlphaGeometry checkpoints and vocabulary are made available 443 | under the terms of the Creative Commons Attribution 4.0 444 | International (CC BY 4.0) license. 445 | You can find details at: 446 | https://creativecommons.org/licenses/by/4.0/legalcode 447 | 448 | -------------------------------------------------------------------------------- /alphageometry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Run DD+AR or AlphaGeometry solver. 17 | 18 | Please refer to README.md for detailed instructions. 
19 | """ 20 | 21 | import traceback 22 | 23 | from absl import app 24 | from absl import flags 25 | from absl import logging 26 | import ddar 27 | import graph as gh 28 | import lm_inference as lm 29 | import pretty as pt 30 | import problem as pr 31 | 32 | 33 | _GIN_SEARCH_PATHS = flags.DEFINE_list( 34 | 'gin_search_paths', 35 | ['third_party/py/meliad/transformer/configs'], 36 | 'List of paths where the Gin config files are located.', 37 | ) 38 | _GIN_FILE = flags.DEFINE_multi_string( 39 | 'gin_file', ['base_htrans.gin'], 'List of Gin config files.' 40 | ) 41 | _GIN_PARAM = flags.DEFINE_multi_string( 42 | 'gin_param', None, 'Newline separated list of Gin parameter bindings.' 43 | ) 44 | 45 | _PROBLEMS_FILE = flags.DEFINE_string( 46 | 'problems_file', 47 | 'imo_ag_30.txt', 48 | 'text file contains the problem strings. See imo_ag_30.txt for example.', 49 | ) 50 | _PROBLEM_NAME = flags.DEFINE_string( 51 | 'problem_name', 52 | 'imo_2000_p1', 53 | 'name of the problem to solve, must be in the problem_file.', 54 | ) 55 | _MODE = flags.DEFINE_string( 56 | 'mode', 'ddar', 'either `ddar` (DD+AR) or `alphageometry`') 57 | _DEFS_FILE = flags.DEFINE_string( 58 | 'defs_file', 59 | 'defs.txt', 60 | 'definitions of available constructions to state a problem.', 61 | ) 62 | _RULES_FILE = flags.DEFINE_string( 63 | 'rules_file', 'rules.txt', 'list of deduction rules used by DD.' 64 | ) 65 | _CKPT_PATH = flags.DEFINE_string('ckpt_path', '', 'checkpoint of the LM model.') 66 | _VOCAB_PATH = flags.DEFINE_string( 67 | 'vocab_path', '', 'path to the LM vocab file.' 68 | ) 69 | _OUT_FILE = flags.DEFINE_string( 70 | 'out_file', '', 'path to the solution output file.' 71 | ) # pylint: disable=line-too-long 72 | _BEAM_SIZE = flags.DEFINE_integer( 73 | 'beam_size', 1, 'beam size of the proof search.' 74 | ) # pylint: disable=line-too-long 75 | _SEARCH_DEPTH = flags.DEFINE_integer( 76 | 'search_depth', 1, 'search depth of the proof search.' 77 | ) # pylint: disable=line-too-long 78 | 79 | DEFINITIONS = None # contains definitions of construction actions 80 | RULES = None # contains rules of deductions 81 | 82 | 83 | def natural_language_statement(logical_statement: pr.Dependency) -> str: 84 | """Convert logical_statement to natural language. 85 | 86 | Args: 87 | logical_statement: pr.Dependency with .name and .args 88 | 89 | Returns: 90 | a string of (pseudo) natural language of the predicate for human reader. 91 | """ 92 | names = [a.name.upper() for a in logical_statement.args] 93 | names = [(n[0] + '_' + n[1:]) if len(n) > 1 else n for n in names] 94 | return pt.pretty_nl(logical_statement.name, names) 95 | 96 | 97 | def proof_step_string( 98 | proof_step: pr.Dependency, refs: dict[tuple[str, ...], int], last_step: bool 99 | ) -> str: 100 | """Translate proof to natural language. 101 | 102 | Args: 103 | proof_step: pr.Dependency with .name and .args 104 | refs: dict(hash: int) to keep track of derived predicates 105 | last_step: boolean to keep track whether this is the last step. 106 | 107 | Returns: 108 | a string of (pseudo) natural language of the proof step for human reader. 
109 | """ 110 | premises, [conclusion] = proof_step 111 | 112 | premises_nl = ' & '.join( 113 | [ 114 | natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()]) 115 | for p in premises 116 | ] 117 | ) 118 | 119 | if not premises: 120 | premises_nl = 'similarly' 121 | 122 | refs[conclusion.hashed()] = len(refs) 123 | 124 | conclusion_nl = natural_language_statement(conclusion) 125 | if not last_step: 126 | conclusion_nl += ' [{:02}]'.format(refs[conclusion.hashed()]) 127 | 128 | return f'{premises_nl} \u21d2 {conclusion_nl}' 129 | 130 | 131 | def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None: 132 | """Output the solution to out_file. 133 | 134 | Args: 135 | g: gh.Graph object, containing the proof state. 136 | p: pr.Problem object, containing the theorem. 137 | out_file: file to write to, empty string to skip writing to file. 138 | """ 139 | setup, aux, proof_steps, refs = ddar.get_proof_steps( 140 | g, p.goal, merge_trivials=False 141 | ) 142 | 143 | solution = '\n==========================' 144 | solution += '\n * From theorem premises:\n' 145 | premises_nl = [] 146 | for premises, [points] in setup: 147 | solution += ' '.join([p.name.upper() for p in points]) + ' ' 148 | if not premises: 149 | continue 150 | premises_nl += [ 151 | natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()]) 152 | for p in premises 153 | ] 154 | solution += ': Points\n' + '\n'.join(premises_nl) 155 | 156 | solution += '\n\n * Auxiliary Constructions:\n' 157 | aux_premises_nl = [] 158 | for premises, [points] in aux: 159 | solution += ' '.join([p.name.upper() for p in points]) + ' ' 160 | aux_premises_nl += [ 161 | natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()]) 162 | for p in premises 163 | ] 164 | solution += ': Points\n' + '\n'.join(aux_premises_nl) 165 | 166 | # some special case where the deduction rule has a well known name. 167 | r2name = { 168 | 'r32': '(SSS)', 169 | 'r33': '(SAS)', 170 | 'r34': '(Similar Triangles)', 171 | 'r35': '(Similar Triangles)', 172 | 'r36': '(ASA)', 173 | 'r37': '(ASA)', 174 | 'r38': '(Similar Triangles)', 175 | 'r39': '(Similar Triangles)', 176 | 'r40': '(Congruent Triangles)', 177 | 'a00': '(Distance chase)', 178 | 'a01': '(Ratio chase)', 179 | 'a02': '(Angle chase)', 180 | } 181 | 182 | solution += '\n\n * Proof steps:\n' 183 | for i, step in enumerate(proof_steps): 184 | _, [con] = step 185 | nl = proof_step_string(step, refs, last_step=i == len(proof_steps) - 1) 186 | rule_name = r2name.get(con.rule_name, '') 187 | nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ') 188 | solution += '{:03}. '.format(i + 1) + nl + '\n' 189 | 190 | solution += '==========================\n' 191 | logging.info(solution) 192 | if out_file: 193 | with open(out_file, 'w') as f: 194 | f.write(solution) 195 | logging.info('Solution written to %s.', out_file) 196 | 197 | 198 | def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference: 199 | lm.parse_gin_configuration( 200 | _GIN_FILE.value, _GIN_PARAM.value, gin_paths=_GIN_SEARCH_PATHS.value 201 | ) 202 | 203 | return lm.LanguageModelInference(vocab_path, ckpt_init, mode='beam_search') 204 | 205 | 206 | def run_ddar(g: gh.Graph, p: pr.Problem, out_file: str) -> bool: 207 | """Run DD+AR. 208 | 209 | Args: 210 | g: gh.Graph object, containing the proof state. 211 | p: pr.Problem object, containing the problem statement. 212 | out_file: path to output file if solution is found. 213 | 214 | Returns: 215 | Boolean, whether DD+AR finishes successfully. 
216 | """ 217 | ddar.solve(g, RULES, p, max_level=1000) 218 | 219 | goal_args = g.names2nodes(p.goal.args) 220 | if not g.check(p.goal.name, goal_args): 221 | logging.info('DD+AR failed to solve the problem.') 222 | return False 223 | 224 | write_solution(g, p, out_file) 225 | 226 | gh.nm.draw( 227 | g.type2nodes[gh.Point], 228 | g.type2nodes[gh.Line], 229 | g.type2nodes[gh.Circle], 230 | g.type2nodes[gh.Segment]) 231 | return True 232 | 233 | 234 | def translate_constrained_to_constructive( 235 | point: str, name: str, args: list[str] 236 | ) -> tuple[str, list[str]]: 237 | """Translate a predicate from constraint-based to construction-based. 238 | 239 | Args: 240 | point: str: name of the new point 241 | name: str: name of the predicate, e.g., perp, para, etc. 242 | args: list[str]: list of predicate args. 243 | 244 | Returns: 245 | (name, args): translated to constructive predicate. 246 | """ 247 | if name in ['T', 'perp']: 248 | a, b, c, d = args 249 | if point in [c, d]: 250 | a, b, c, d = c, d, a, b 251 | if point == b: 252 | a, b = b, a 253 | if point == d: 254 | c, d = d, c 255 | if a == c and a == point: 256 | return 'on_dia', [a, b, d] 257 | return 'on_tline', [a, b, c, d] 258 | 259 | elif name in ['P', 'para']: 260 | a, b, c, d = args 261 | if point in [c, d]: 262 | a, b, c, d = c, d, a, b 263 | if point == b: 264 | a, b = b, a 265 | return 'on_pline', [a, b, c, d] 266 | 267 | elif name in ['D', 'cong']: 268 | a, b, c, d = args 269 | if point in [c, d]: 270 | a, b, c, d = c, d, a, b 271 | if point == b: 272 | a, b = b, a 273 | if point == d: 274 | c, d = d, c 275 | if a == c and a == point: 276 | return 'on_bline', [a, b, d] 277 | if b in [c, d]: 278 | if b == d: 279 | c, d = d, c # pylint: disable=unused-variable 280 | return 'on_circle', [a, b, d] 281 | return 'eqdistance', [a, b, c, d] 282 | 283 | elif name in ['C', 'coll']: 284 | a, b, c = args 285 | if point == b: 286 | a, b = b, a 287 | if point == c: 288 | a, b, c = c, a, b 289 | return 'on_line', [a, b, c] 290 | 291 | elif name in ['^', 'eqangle']: 292 | a, b, c, d, e, f = args 293 | 294 | if point in [d, e, f]: 295 | a, b, c, d, e, f = d, e, f, a, b, c 296 | 297 | x, b, y, c, d = b, c, e, d, f 298 | if point == b: 299 | a, b, c, d = b, a, d, c 300 | 301 | if point == d and x == y: # x p x b = x c x p 302 | return 'angle_bisector', [point, b, x, c] 303 | 304 | if point == x: 305 | return 'eqangle3', [x, a, b, y, c, d] 306 | 307 | return 'on_aline', [a, x, b, c, y, d] 308 | 309 | elif name in ['cyclic', 'O']: 310 | a, b, c = [x for x in args if x != point] 311 | return 'on_circum', [point, a, b, c] 312 | 313 | return name, args 314 | 315 | 316 | def check_valid_args(name: str, args: list[str]) -> bool: 317 | """Check whether a predicate is grammarically correct. 318 | 319 | Args: 320 | name: str: name of the predicate 321 | args: list[str]: args of the predicate 322 | 323 | Returns: 324 | bool: whether the predicate arg count is valid. 
325 | """ 326 | if name == 'perp': 327 | if len(args) != 4: 328 | return False 329 | a, b, c, d = args 330 | if len({a, b}) < 2: 331 | return False 332 | if len({c, d}) < 2: 333 | return False 334 | elif name == 'para': 335 | if len(args) != 4: 336 | return False 337 | a, b, c, d = args 338 | if len({a, b, c, d}) < 4: 339 | return False 340 | elif name == 'cong': 341 | if len(args) != 4: 342 | return False 343 | a, b, c, d = args 344 | if len({a, b}) < 2: 345 | return False 346 | if len({c, d}) < 2: 347 | return False 348 | elif name == 'coll': 349 | if len(args) != 3: 350 | return False 351 | a, b, c = args 352 | if len({a, b, c}) < 3: 353 | return False 354 | elif name == 'cyclic': 355 | if len(args) != 4: 356 | return False 357 | a, b, c, d = args 358 | if len({a, b, c, d}) < 4: 359 | return False 360 | elif name == 'eqangle': 361 | if len(args) != 8: 362 | return False 363 | a, b, c, d, e, f, g, h = args 364 | if len({a, b, c, d}) < 3: 365 | return False 366 | if len({e, f, g, h}) < 3: 367 | return False 368 | return True 369 | 370 | 371 | def try_translate_constrained_to_construct(string: str, g: gh.Graph) -> str: 372 | """Whether a string of aux construction can be constructed. 373 | 374 | Args: 375 | string: str: the string describing aux construction. 376 | g: gh.Graph: the current proof state. 377 | 378 | Returns: 379 | str: whether this construction is valid. If not, starts with "ERROR:". 380 | """ 381 | if string[-1] != ';': 382 | return 'ERROR: must end with ;' 383 | 384 | head, prem_str = string.split(' : ') 385 | point = head.strip() 386 | 387 | if len(point) != 1 or point == ' ': 388 | return f'ERROR: invalid point name {point}' 389 | 390 | existing_points = [p.name for p in g.all_points()] 391 | if point in existing_points: 392 | return f'ERROR: point {point} already exists.' 393 | 394 | prem_toks = prem_str.split()[:-1] # remove the EOS ' ;' 395 | prems = [[]] 396 | 397 | for i, tok in enumerate(prem_toks): 398 | if tok.isdigit(): 399 | if i < len(prem_toks) - 1: 400 | prems.append([]) 401 | else: 402 | prems[-1].append(tok) 403 | 404 | if len(prems) > 2: 405 | return 'ERROR: there cannot be more than two predicates.' 406 | 407 | clause_txt = point + ' = ' 408 | constructions = [] 409 | 410 | for prem in prems: 411 | name, *args = prem 412 | 413 | if point not in args: 414 | return f'ERROR: {point} not found in predicate args.' 415 | 416 | if not check_valid_args(pt.map_symbol(name), args): 417 | return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args) 418 | 419 | for a in args: 420 | if a != point and a not in existing_points: 421 | return f'ERROR: point {a} does not exist.' 422 | 423 | try: 424 | name, args = translate_constrained_to_constructive(point, name, args) 425 | except: # pylint: disable=bare-except 426 | return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args) 427 | 428 | if name == 'on_aline': 429 | if args.count(point) > 1: 430 | return f'ERROR: on_aline involves twice {point}' 431 | 432 | constructions += [name + ' ' + ' '.join(args)] 433 | 434 | clause_txt += ', '.join(constructions) 435 | clause = pr.Clause.from_txt(clause_txt) 436 | 437 | try: 438 | g.copy().add_clause(clause, 0, DEFINITIONS) 439 | except: # pylint: disable=bare-except 440 | return 'ERROR: ' + traceback.format_exc() 441 | 442 | return clause_txt 443 | 444 | 445 | def insert_aux_to_premise(pstring: str, auxstring: str) -> str: 446 | """Insert auxiliary constructs from proof to premise. 447 | 448 | Args: 449 | pstring: str: describing the problem to solve. 
 450 | auxstring: str: describing the auxiliary construction. 451 | 452 | Returns: 453 | str: new pstring with auxstring inserted before the conclusion. 454 | """ 455 | setup, goal = pstring.split(' ? ') 456 | return setup + '; ' + auxstring + ' ? ' + goal 457 | 458 | 459 | class BeamQueue: 460 | """Keep only the top k objects according to their values.""" 461 | 462 | def __init__(self, max_size: int = 512): 463 | self.queue = [] 464 | self.max_size = max_size 465 | 466 | def add(self, node: object, val: float) -> None: 467 | """Add a new node to this queue.""" 468 | 469 | if len(self.queue) < self.max_size: 470 | self.queue.append((val, node)) 471 | return 472 | 473 | # Find the minimum node: 474 | min_idx, (min_val, _) = min(enumerate(self.queue), key=lambda x: x[1]) 475 | 476 | # replace it if the new node has a higher value. 477 | if val > min_val: 478 | self.queue[min_idx] = (val, node) 479 | 480 | def __iter__(self): 481 | for val, node in self.queue: 482 | yield val, node 483 | 484 | def __len__(self) -> int: 485 | return len(self.queue) 486 | 487 | 488 | def run_alphageometry( 489 | model: lm.LanguageModelInference, 490 | p: pr.Problem, 491 | search_depth: int, 492 | beam_size: int, 493 | out_file: str, 494 | ) -> bool: 495 | """Simplified code to run AlphaGeometry proof search. 496 | 497 | We removed all optimizations that are infrastructure-dependent, e.g. 498 | parallelized model inference on multiple GPUs, 499 | parallelized DD+AR on multiple CPUs, 500 | parallel execution of LM and DD+AR, 501 | shared pool of CPU workers across different problems, etc. 502 | 503 | Many other speed optimizations and abstractions are also removed to 504 | better present the core structure of the proof search. 505 | 506 | Args: 507 | model: Interface with inference-related endpoints to JAX's model. 508 | p: pr.Problem object describing the problem to solve. 509 | search_depth: max proof search depth. 510 | beam_size: beam size of the proof search. 511 | out_file: path to output file if solution is found. 512 | 513 | Returns: 514 | boolean of whether the problem is solved. 515 | """ 516 | # translate the problem to a string of grammar that the LM is trained on. 517 | string = p.setup_str_from_problem(DEFINITIONS) 518 | # special tokens prompting the LM to generate auxiliary points. 519 | string += ' {F1} x00' 520 | # the graph to represent the proof state. 521 | g, _ = gh.Graph.build_problem(p, DEFINITIONS) 522 | 523 | # First we run the symbolic engine DD+AR: 524 | if run_ddar(g, p, out_file): 525 | return True 526 | 527 | # beam search for the proof 528 | # each node in the search tree is a 3-tuple: 529 | # (<graph representation of the proof state>, 530 | # <string for the LM to decode from>, 531 | # <problem string in constructive language>) 532 | beam_queue = BeamQueue(max_size=beam_size) 533 | # initially the beam search tree starts with a single node (a 3-tuple): 534 | beam_queue.add( 535 | node=(g, string, p.txt()), val=0.0 # value of the root node is simply 0. 536 | ) 537 | 538 | for depth in range(search_depth): 539 | logging.info( 540 | 'Depth %s. There are %i nodes to expand:', depth, len(beam_queue) 541 | ) 542 | for _, (_, string, _) in beam_queue: 543 | logging.info(string) 544 | 545 | new_queue = BeamQueue(max_size=beam_size) # to replace beam_queue. 546 | 547 | for prev_score, (g, string, pstring) in beam_queue: 548 | logging.info('Decoding from %s', string) 549 | outputs = model.beam_decode(string, eos_tokens=[';']) 550 | 551 | # translate lm output to the constructive language.
 552 | # so that we can update the graph representing proof states: 553 | translations = [ 554 | try_translate_constrained_to_construct(o, g) 555 | for o in outputs['seqs_str'] 556 | ] 557 | 558 | # couple the lm outputs with their translations 559 | candidates = zip(outputs['seqs_str'], translations, outputs['scores']) 560 | 561 | # bring the highest scoring candidate first 562 | candidates = reversed(list(candidates)) 563 | 564 | for lm_out, translation, score in candidates: 565 | logging.info('LM output (score=%f): "%s"', score, lm_out) 566 | logging.info('Translation: "%s"\n', translation) 567 | 568 | if translation.startswith('ERROR:'): 569 | # the construction is invalid. 570 | continue 571 | 572 | # Update the constructive statement of the problem with the aux point: 573 | candidate_pstring = insert_aux_to_premise(pstring, translation) 574 | 575 | logging.info('Solving: "%s"', candidate_pstring) 576 | p_new = pr.Problem.from_txt(candidate_pstring) 577 | 578 | # This is the new proof state graph representation: 579 | g_new, _ = gh.Graph.build_problem(p_new, DEFINITIONS) 580 | if run_ddar(g_new, p_new, out_file): 581 | logging.info('Solved.') 582 | return True 583 | 584 | # Add the candidate to the beam queue. 585 | new_queue.add( 586 | # The string for the new node is old_string + lm output + 587 | # the special token asking for a new auxiliary point ' x00': 588 | node=(g_new, string + ' ' + lm_out + ' x00', candidate_pstring), 589 | # the score of each node is the sum of the scores of all nodes 590 | # on the path to itself. For beam search, there is no need to 591 | # normalize according to path length because all nodes in the beam 592 | # are of the same path length. 593 | val=prev_score + score, 594 | ) 595 | # Note that the queue only maintains at most beam_size nodes, 596 | # so this new node may be dropped depending on its value. 597 | 598 | # replace the old queue with the new queue before the next proof search depth. 599 | beam_queue = new_queue 600 | 601 | return False 602 | 603 | 604 | def main(_): 605 | global DEFINITIONS 606 | global RULES 607 | 608 | # definitions of terms used in our domain-specific language. 609 | DEFINITIONS = pr.Definition.from_txt_file(_DEFS_FILE.value, to_dict=True) 610 | # load inference rules used in DD. 611 | RULES = pr.Theorem.from_txt_file(_RULES_FILE.value, to_dict=True) 612 | 613 | # when using the language model, 614 | # point names will be renamed to alphabetical a, b, c, d, e, ... 615 | # instead of staying with their original names, 616 | # in order to match the synthetic training data generation.
617 | need_rename = _MODE.value != 'ddar' 618 | 619 | # load problems from the problems_file, 620 | problems = pr.Problem.from_txt_file( 621 | _PROBLEMS_FILE.value, to_dict=True, translate=need_rename 622 | ) 623 | 624 | if _PROBLEM_NAME.value not in problems: 625 | raise ValueError( 626 | f'Problem name `{_PROBLEM_NAME.value}` ' 627 | + f'not found in `{_PROBLEMS_FILE.value}`' 628 | ) 629 | 630 | this_problem = problems[_PROBLEM_NAME.value] 631 | 632 | if _MODE.value == 'ddar': 633 | g, _ = gh.Graph.build_problem(this_problem, DEFINITIONS) 634 | run_ddar(g, this_problem, _OUT_FILE.value) 635 | 636 | elif _MODE.value == 'alphageometry': 637 | model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value) 638 | run_alphageometry( 639 | model, 640 | this_problem, 641 | _SEARCH_DEPTH.value, 642 | _BEAM_SIZE.value, 643 | _OUT_FILE.value, 644 | ) 645 | 646 | else: 647 | raise ValueError(f'Unknown FLAGS.mode: {_MODE.value}') 648 | 649 | 650 | if __name__ == '__main__': 651 | app.run(main) 652 | -------------------------------------------------------------------------------- /alphageometry_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit tests for alphageometry.py.""" 17 | 18 | import unittest 19 | 20 | from absl.testing import absltest 21 | import alphageometry 22 | 23 | 24 | class AlphaGeometryTest(unittest.TestCase): 25 | 26 | def test_translate_constrained_to_constructive(self): 27 | self.assertEqual( 28 | alphageometry.translate_constrained_to_constructive( 29 | 'd', 'T', list('addb') 30 | ), 31 | ('on_dia', ['d', 'b', 'a']), 32 | ) 33 | self.assertEqual( 34 | alphageometry.translate_constrained_to_constructive( 35 | 'd', 'T', list('adbc') 36 | ), 37 | ('on_tline', ['d', 'a', 'b', 'c']), 38 | ) 39 | self.assertEqual( 40 | alphageometry.translate_constrained_to_constructive( 41 | 'd', 'P', list('bcda') 42 | ), 43 | ('on_pline', ['d', 'a', 'b', 'c']), 44 | ) 45 | self.assertEqual( 46 | alphageometry.translate_constrained_to_constructive( 47 | 'd', 'D', list('bdcd') 48 | ), 49 | ('on_bline', ['d', 'c', 'b']), 50 | ) 51 | self.assertEqual( 52 | alphageometry.translate_constrained_to_constructive( 53 | 'd', 'D', list('bdcb') 54 | ), 55 | ('on_circle', ['d', 'b', 'c']), 56 | ) 57 | self.assertEqual( 58 | alphageometry.translate_constrained_to_constructive( 59 | 'd', 'D', list('bacd') 60 | ), 61 | ('eqdistance', ['d', 'c', 'b', 'a']), 62 | ) 63 | self.assertEqual( 64 | alphageometry.translate_constrained_to_constructive( 65 | 'd', 'C', list('bad') 66 | ), 67 | ('on_line', ['d', 'b', 'a']), 68 | ) 69 | self.assertEqual( 70 | alphageometry.translate_constrained_to_constructive( 71 | 'd', 'C', list('bad') 72 | ), 73 | ('on_line', ['d', 'b', 'a']), 74 | ) 75 | self.assertEqual( 76 | alphageometry.translate_constrained_to_constructive( 77 | 'd', 'O', list('abcd') 78 | ), 79 | ('on_circum', ['d', 'a', 'b', 'c']), 80 | ) 81 | 82 | def test_insert_aux_to_premise(self): 83 | pstring = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c' # pylint: disable=line-too-long 84 | auxstring = 'e = on_line e a c, on_line e b d' 85 | 86 | target = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c' # pylint: disable=line-too-long 87 | self.assertEqual( 88 | alphageometry.insert_aux_to_premise(pstring, auxstring), target 89 | ) 90 | 91 | def test_beam_queue(self): 92 | beam_queue = alphageometry.BeamQueue(max_size=2) 93 | 94 | beam_queue.add('a', 1) 95 | beam_queue.add('b', 2) 96 | beam_queue.add('c', 3) 97 | 98 | beam_queue = list(beam_queue) 99 | self.assertEqual(beam_queue, [(3, 'c'), (2, 'b')]) 100 | 101 | 102 | if __name__ == '__main__': 103 | absltest.main() 104 | -------------------------------------------------------------------------------- /ar_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit tests for ar.py.""" 17 | import unittest 18 | 19 | from absl.testing import absltest 20 | import ar 21 | import graph as gh 22 | import problem as pr 23 | 24 | 25 | class ARTest(unittest.TestCase): 26 | 27 | @classmethod 28 | def setUpClass(cls): 29 | super().setUpClass() 30 | cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True) 31 | cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True) 32 | 33 | def test_update_groups(self): 34 | """Test for update_groups.""" 35 | groups1 = [{1, 2}, {3, 4, 5}, {6, 7}] 36 | groups2 = [{2, 3, 8}, {9, 10, 11}] 37 | 38 | _, links, history = ar.update_groups(groups1, groups2) 39 | self.assertEqual( 40 | history, 41 | [ 42 | [{1, 2, 3, 4, 5, 8}, {6, 7}], 43 | [{1, 2, 3, 4, 5, 8}, {6, 7}, {9, 10, 11}], 44 | ], 45 | ) 46 | self.assertEqual(links, [(2, 3), (3, 8), (9, 10), (10, 11)]) 47 | 48 | groups1 = [{1, 2}, {3, 4}, {5, 6}, {7, 8}] 49 | groups2 = [{2, 3, 8, 9, 10}, {3, 6, 11}] 50 | 51 | _, links, history = ar.update_groups(groups1, groups2) 52 | self.assertEqual( 53 | history, 54 | [ 55 | [{1, 2, 3, 4, 7, 8, 9, 10}, {5, 6}], 56 | [{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}], 57 | ], 58 | ) 59 | self.assertEqual(links, [(2, 3), (3, 8), (8, 9), (9, 10), (3, 6), (6, 11)]) 60 | 61 | groups1 = [] 62 | groups2 = [{1, 2}, {3, 4}, {5, 6}, {2, 3}] 63 | 64 | _, links, history = ar.update_groups(groups1, groups2) 65 | self.assertEqual( 66 | history, 67 | [ 68 | [{1, 2}], 69 | [{1, 2}, {3, 4}], 70 | [{1, 2}, {3, 4}, {5, 6}], 71 | [{1, 2, 3, 4}, {5, 6}], 72 | ], 73 | ) 74 | self.assertEqual(links, [(1, 2), (3, 4), (5, 6), (2, 3)]) 75 | 76 | def test_generic_table_simple(self): 77 | tb = ar.Table() 78 | 79 | # If a-b = b-c & d-a = c-d 80 | tb.add_eq4('a', 'b', 'b', 'c', 'fact1') 81 | tb.add_eq4('d', 'a', 'c', 'd', 'fact2') 82 | tb.add_eq4('x', 'y', 'z', 't', 'fact3') # distractor fact 83 | 84 | # Then b=d, because {fact1, fact2} but not fact3. 85 | result = list(tb.get_all_eqs_and_why()) 86 | self.assertIn(('b', 'd', ['fact1', 'fact2']), result) 87 | 88 | def test_angle_table_inbisector_exbisector(self): 89 | """Test that AR can figure out bisector & ex-bisector are perpendicular.""" 90 | # Load the scenario that we have cd is bisector of acb and 91 | # ce is the ex-bisector of acb. 92 | p = pr.Problem.from_txt( 93 | 'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c ?' 
94 |         ' perp d c c e'
95 |     )
96 |     g, _ = gh.Graph.build_problem(p, ARTest.defs)
97 |
98 |     # Create an external angle table:
99 |     tb = ar.AngleTable('pi')
100 |
101 |     # Add bisector & ex-bisector facts into the table:
102 |     ca, cd, cb, ce = g.names2nodes(['d(ac)', 'd(cd)', 'd(bc)', 'd(ce)'])
103 |     tb.add_eqangle(ca, cd, cd, cb, 'fact1')
104 |     tb.add_eqangle(ce, ca, cb, ce, 'fact2')
105 |
106 |     # Add a distractor fact to make sure traceback does not include this fact
107 |     ab = g.names2nodes(['d(ab)'])[0]
108 |     tb.add_eqangle(ab, cb, cb, ca, 'fact3')
109 |
110 |     # Check for all new equalities
111 |     result = list(tb.get_all_eqs_and_why())
112 |
113 |     # halfpi is represented as the tuple (1, 2)
114 |     halfpi = (1, 2)
115 |
116 |     # Check that cd-ce == halfpi because of {fact1, fact2}, not fact3
117 |     self.assertCountEqual(
118 |         result,
119 |         [
120 |             (cd, ce, halfpi, ['fact1', 'fact2']),
121 |             (ce, cd, halfpi, ['fact1', 'fact2']),
122 |         ],
123 |     )
124 |
125 |   def test_angle_table_equilateral_triangle(self):
126 |     """Test that AR can figure out triangles with 3 equal angles => each is pi/3."""
127 |     # Load an equilateral scenario
128 |     p = pr.Problem.from_txt('a b c = ieq_triangle ? cong a b a c')
129 |     g, _ = gh.Graph.build_problem(p, ARTest.defs)
130 |
131 |     # Add two eqangle facts, because ieq_triangle only adds congruent sides
132 |     a, b, c = g.names2nodes('abc')
133 |     g.add_eqangle([a, b, b, c, b, c, c, a], pr.EmptyDependency(0, None))
134 |     g.add_eqangle([b, c, c, a, c, a, a, b], pr.EmptyDependency(0, None))
135 |
136 |     # Create an external angle table:
137 |     tb = ar.AngleTable('pi')
138 |
139 |     # Add the fact that there are three equal angles
140 |     ab, bc, ca = g.names2nodes(['d(ab)', 'd(bc)', 'd(ac)'])
141 |     tb.add_eqangle(ab, bc, bc, ca, 'fact1')
142 |     tb.add_eqangle(bc, ca, ca, ab, 'fact2')
143 |
144 |     # Now check for all new equalities
145 |     result = list(tb.get_all_eqs_and_why())
146 |     result = [(x.name, y.name, z, t) for x, y, z, t in result]
147 |
148 |     # pi/3 is represented as the tuple (1, 3), and 2*pi/3 as (2, 3)
149 |     angle_60 = (1, 3)
150 |     angle_120 = (2, 3)
151 |
152 |     # Check that angle constants are created and figured out:
153 |     self.assertCountEqual(
154 |         result,
155 |         [
156 |             ('d(bc)', 'd(ac)', angle_120, ['fact1', 'fact2']),
157 |             ('d(ab)', 'd(bc)', angle_120, ['fact1', 'fact2']),
158 |             ('d(ac)', 'd(ab)', angle_120, ['fact1', 'fact2']),
159 |             ('d(ac)', 'd(bc)', angle_60, ['fact1', 'fact2']),
160 |             ('d(bc)', 'd(ab)', angle_60, ['fact1', 'fact2']),
161 |             ('d(ab)', 'd(ac)', angle_60, ['fact1', 'fact2']),
162 |         ],
163 |     )
164 |
165 |   def test_incenter_excenter_touchpoints(self):
166 |     """Test that AR can figure out incenter/excenter touchpoints are equidistant from the midpoint of bc."""
167 |
168 |     p = pr.Problem.from_txt(
169 |         'a b c = triangle a b c; d1 d2 d3 d = incenter2 a b c; e1 e2 e3 e ='
170 |         ' excenter2 a b c ? perp d c c e',
171 |         translate=False,
172 |     )
173 |     g, _ = gh.Graph.build_problem(p, ARTest.defs)
174 |
175 |     a, b, c, ab, bc, ca, d1, d2, d3, e1, e2, e3 = g.names2nodes(
176 |         ['a', 'b', 'c', 'ab', 'bc', 'ac', 'd1', 'd2', 'd3', 'e1', 'e2', 'e3']
177 |     )
178 |
179 |     # Create an external distance table:
180 |     tb = ar.DistanceTable()
181 |
182 |     # DD can figure out the following facts;
183 |     # we manually add them to AR.
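    # Annotation (reading of the signature inferred from the usage below, not
    # stated in this file): tb.add_cong(l1, l2, p, x, q, y, why) records the
    # congruence |p x| == |q y|, with segment p-x lying on line l1 and q-y on
    # line l2. Here these are the equal tangent lengths from each vertex to
    # the incircle/excircle touchpoints.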
184 |     tb.add_cong(ab, ca, a, d3, a, d2, 'fact1')
185 |     tb.add_cong(ab, ca, a, e3, a, e2, 'fact2')
186 |     tb.add_cong(ca, bc, c, d2, c, d1, 'fact5')
187 |     tb.add_cong(ca, bc, c, e2, c, e1, 'fact6')
188 |     tb.add_cong(bc, ab, b, d1, b, d3, 'fact3')
189 |     tb.add_cong(bc, ab, b, e1, b, e3, 'fact4')
190 |
191 |     # Now we check whether tb has figured out that
192 |     # distance(b, d1) == distance(e1, c)
193 |
194 |     # Linear combination expression of each variable:
195 |     b = tb.v2e['bc:b']
196 |     c = tb.v2e['bc:c']
197 |     d1 = tb.v2e['bc:d1']
198 |     e1 = tb.v2e['bc:e1']
199 |
200 |     self.assertEqual(ar.minus(d1, b), ar.minus(c, e1))
201 |
202 |
203 | if __name__ == '__main__':
204 |   absltest.main()
205 |
--------------------------------------------------------------------------------
/beam_search.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Fast decoding routines for inference from a trained model.
17 |
18 | Modified from https://github.com/google/flax/blob/main/examples/wmt/decode.py
19 | to accommodate:
20 |
21 | (a) continued decoding from a previous beam cache.
22 | (b) init with a single beam and then expand into beam_size beams.
23 | """
24 |
25 | from typing import Any
26 |
27 | import flax
28 | import jax
29 | from jax import lax
30 | import jax.numpy as jnp
31 | import numpy as np
32 |
33 |
34 | # Constants
35 | # "Effective negative infinity" constant for masking in beam search.
36 | NEG_INF = np.array(-1.0e7)
37 |
38 | # Beam search parameters
39 | BEAM_SEARCH_DEFAULT_ALPHA = 0.6
40 | MAX_DECODE_LEN = 32
41 |
42 | # Brevity penalty parameters
43 | BREVITY_LEN_BIAS_NUMERATOR = 5.0
44 | BREVITY_LEN_BIAS_DENOMINATOR = 6.0
45 |
46 |
47 | def brevity_penalty(alpha: float, length: int):
48 |   """Brevity penalty function for beam search penalizing short sequences.
49 |
50 |   Args:
51 |     alpha: float: brevity-penalty scaling parameter.
52 |     length: int: length of considered sequence.
53 |
54 |   Returns:
55 |     Brevity penalty score as jax scalar.
56 |   """
57 |   return jnp.power(
58 |       ((BREVITY_LEN_BIAS_NUMERATOR + length) / BREVITY_LEN_BIAS_DENOMINATOR),
59 |       alpha,
60 |   )
61 |
62 |
63 | # Beam handling utility functions:
64 |
65 |
66 | def add_beam_dim(x: jnp.ndarray, beam_size: int) -> jnp.ndarray:
67 |   """Creates new beam dimension in non-scalar array and tiles into it."""
68 |   if x.ndim == 0:  # ignore scalars (e.g. cache index)
69 |     return x
70 |   x = jnp.expand_dims(x, axis=1)
71 |   tile_dims = [1] * x.ndim
72 |   tile_dims[1] = beam_size
73 |   return jnp.tile(x, tile_dims)
74 |
75 |
76 | def add_beam_dim_cache(
77 |     cache: tuple[dict[str, jnp.ndarray], ...], beam_size: int
78 | ) -> tuple[dict[str, jnp.ndarray], ...]:
79 |   """Creates new beam dimension in each layer's 'keys'/'vals' arrays and tiles into it."""
80 |   new_cache = []
81 |
82 |   for layer in cache:
83 |     new_layer = {}
84 |     for key, x in layer.items():
85 |       if key in ['keys', 'vals']:
86 |         x = add_beam_dim(x, beam_size)
87 |       new_layer[key] = x
88 |     new_cache.append(new_layer)
89 |
90 |   return tuple(new_cache)
91 |
92 |
93 | def flatten_beam_dim(x):
94 |   """Flattens the first two dimensions of a non-scalar array."""
95 |   if x.ndim < 2:  # ignore scalars (e.g. cache index)
96 |     return x
97 |   return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
98 |
99 |
100 | def unflatten_beam_dim(x, batch_size, beam_size):
101 |   """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
102 |   if x.ndim == 0:  # ignore scalars (e.g. cache index)
103 |     return x
104 |   assert batch_size * beam_size == x.shape[0]
105 |   return x.reshape((batch_size, beam_size) + x.shape[1:])
106 |
107 |
108 | def flat_batch_beam_expand(x, beam_size):
109 |   """Expands each batch item by beam_size in the batch dimension."""
110 |   return flatten_beam_dim(add_beam_dim(x, beam_size))
111 |
112 |
113 | def gather_beams(nested, beam_indices, batch_size, new_beam_size):
114 |   """Gathers the beam slices indexed by beam_indices into new beam array.
115 |
116 |   Args:
117 |     nested: pytree of arrays or scalars (the latter ignored).
118 |     beam_indices: array of beam indices to gather.
119 |     batch_size: int: size of batch.
120 |     new_beam_size: int: size of _new_ beam dimension.
121 |
122 |   Returns:
123 |     New pytree with new beam arrays.
124 |     [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
125 |   """
126 |   batch_indices = jnp.reshape(
127 |       jnp.arange(batch_size * new_beam_size) // new_beam_size,
128 |       (batch_size, new_beam_size),
129 |   )
130 |
131 |   def gather_fn(x):
132 |     if x.ndim == 0:  # ignore scalars (e.g. cache index)
133 |       return x
134 |     else:
135 |       return x[batch_indices, beam_indices]
136 |
137 |   return jax.tree_util.tree_map(gather_fn, nested)
138 |
139 |
140 | def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size):
141 |   """Gathers the top-k beam slices given by score_or_log_prob array.
142 |
143 |   Args:
144 |     nested: pytree of arrays or scalars (the latter ignored).
145 |     score_or_log_prob: [batch_size, old_beam_size] array of values to sort by
146 |       for top-k selection of beam slices.
147 |     batch_size: int: size of batch.
148 |     new_beam_size: int: size of _new_ top-k selected beam dimension.
149 |
150 |   Returns:
151 |     New pytree with new beam arrays containing top k new_beam_size slices.
152 |     [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
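
  Note:
    The gathered beams come back in ascending score order, because the
    top-k indices are flipped along the beam axis before gathering.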
153 | """ 154 | _, topk_indices = lax.top_k(score_or_log_prob, k=new_beam_size) 155 | topk_indices = jnp.flip(topk_indices, axis=1) 156 | return gather_beams(nested, topk_indices, batch_size, new_beam_size) 157 | 158 | 159 | def apply_on_cache(fn, cache, *args, **kwargs): 160 | """Apply fn(val) only when key is 'keys' or 'val'.""" 161 | new_cache = [] 162 | for layer in cache: 163 | new_layer = {} 164 | for key, val in layer.items(): 165 | if key in ['keys', 'values', 'current_index', 'relative_position_bias']: 166 | val = fn(val, *args, **kwargs) 167 | new_layer[key] = val 168 | new_cache.append(new_layer) 169 | return tuple(new_cache) 170 | 171 | 172 | # Beam search state: 173 | 174 | 175 | @flax.struct.dataclass 176 | class BeamState: 177 | """Holds beam search state data.""" 178 | 179 | # The position of the decoding loop in the length dimension. 180 | cur_index: jax.Array # scalar int32: current decoded length index 181 | # The active sequence log probabilities and finished sequence scores. 182 | live_logprobs: jax.Array # float32: [batch_size, beam_size] 183 | finished_scores: jax.Array # float32: [batch_size, beam_size] 184 | # The current active-beam-searching and finished sequences. 185 | live_seqs: jax.Array # int32: [batch_size, beam_size, max_decode_len] 186 | finished_seqs: jax.Array # int32: [batch_size, beam_size, 187 | # max_decode_len] 188 | # Records which of the 'finished_seqs' is occupied and not a filler slot. 189 | finished_flags: jax.Array # bool: [batch_size, beam_size] 190 | # The current state of the autoregressive decoding caches. 191 | cache: Any # Any pytree of arrays, e.g. flax attention Cache object 192 | 193 | 194 | def beam_init(seed_token, batch_size, beam_size, max_decode_len, cache): 195 | """Initializes the beam search state data structure.""" 196 | cur_index0 = jnp.array(0) 197 | live_logprobs0 = jnp.tile( 198 | jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1] 199 | ) 200 | finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF 201 | 202 | live_seqs0 = jnp.concatenate( 203 | [ 204 | jnp.reshape(seed_token, (batch_size, beam_size, 1)), 205 | jnp.zeros((batch_size, beam_size, max_decode_len - 1), jnp.int32), 206 | ], 207 | axis=-1, 208 | ) # (batch, beam, max_decode_len) 209 | 210 | finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) 211 | finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_) 212 | beam_cache0 = apply_on_cache(lambda x: jnp.expand_dims(x, axis=0), cache) 213 | return BeamState( 214 | cur_index=cur_index0, 215 | live_logprobs=live_logprobs0, 216 | finished_scores=finished_scores0, 217 | live_seqs=live_seqs0, 218 | finished_seqs=finished_seqs0, 219 | finished_flags=finished_flags0, 220 | cache=beam_cache0, 221 | ) 222 | 223 | 224 | # Beam search routine: 225 | 226 | 227 | def beam_search_flat( 228 | seed_token, 229 | cache, 230 | tokens_to_logits, 231 | alpha=BEAM_SEARCH_DEFAULT_ALPHA, 232 | eos=None, 233 | max_decode_len=MAX_DECODE_LEN, 234 | mask=None, 235 | ): 236 | """Beam search for LM. 237 | 238 | inputs and cache is already flat! i.e. first dimention == batch*beam. 239 | 240 | Args: 241 | seed_token: array: [beam_size, 1] int32 sequence of tokens. 242 | cache: flax attention cache. 243 | tokens_to_logits: fast autoregressive decoder function taking single token 244 | slices and cache and returning next-token logits and updated cache. 245 | alpha: float: scaling factor for brevity penalty. 246 | eos: array: [vocab] 1 for end-of-sentence tokens, 0 for not. 
247 | max_decode_len: int: maximum length of decoded translations. 248 | mask: array: [vocab] binary mask for vocab. 1 to keep the prob, 0 to set the 249 | prob := 0. 250 | 251 | Returns: 252 | Tuple of: 253 | [beam_size, max_decode_len] top-scoring sequences 254 | [beam_size] beam-search scores. 255 | """ 256 | # We liberally annotate shape information for clarity below. 257 | batch_size, beam_size = 1, seed_token.shape[0] 258 | mask = mask.reshape((1, 1, -1)) 259 | eos = eos.reshape((1, 1, -1)) 260 | mask_bias = (1 - mask) * NEG_INF 261 | 262 | # initialize beam search state 263 | beam_search_init_state = beam_init( 264 | seed_token, batch_size, beam_size, max_decode_len, cache 265 | ) 266 | 267 | def beam_search_loop_cond_fn(state): 268 | """Beam search loop termination condition.""" 269 | # Have we reached max decoding length? 270 | not_at_end = state.cur_index < max_decode_len - 1 271 | 272 | # Is no further progress in the beam search possible? 273 | # Get the best possible scores from alive sequences. 274 | min_brevity_penalty = brevity_penalty(alpha, max_decode_len) 275 | best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty 276 | # Get the worst scores from finished sequences. 277 | worst_finished_scores = jnp.min( 278 | state.finished_scores, axis=1, keepdims=True 279 | ) 280 | # Mask out scores from slots without any actual finished sequences. 281 | worst_finished_scores = jnp.where( 282 | state.finished_flags, worst_finished_scores, NEG_INF 283 | ) 284 | # If no best possible live score is better than current worst finished 285 | # scores, the search cannot improve the finished set further. 286 | search_terminated = jnp.all(worst_finished_scores > best_live_scores) 287 | 288 | # If we're not at the max decode length, and the search hasn't terminated, 289 | # continue looping. 290 | return not_at_end & (~search_terminated) 291 | 292 | def beam_search_loop_body_fn(state): 293 | """Beam search loop state update function.""" 294 | # Collect the current position slice along length to feed the fast 295 | # autoregressive decoder model. Flatten the beam dimension into batch 296 | # dimension for feeding into the model. 297 | # --> [batch * beam, 1] 298 | flat_ids = flatten_beam_dim( 299 | lax.dynamic_slice( 300 | state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1) 301 | ) 302 | ) 303 | # Flatten beam dimension into batch to be compatible with model. 304 | # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} 305 | flat_cache = apply_on_cache(flatten_beam_dim, state.cache) 306 | 307 | # Call fast-decoder model on current tokens to get next-position logits. 308 | # --> [batch * beam, vocab] 309 | flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache) 310 | 311 | # unflatten beam dimension 312 | # [batch * beam, vocab] --> [batch, beam, vocab] 313 | logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) 314 | 315 | # Unflatten beam dimension in attention cache arrays 316 | # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} 317 | new_cache = apply_on_cache( 318 | unflatten_beam_dim, new_flat_cache, batch_size, beam_size 319 | ) 320 | 321 | # Gather log probabilities from logits 322 | candidate_log_probs = jax.nn.log_softmax(logits) 323 | # Add new logprobs to existing prefix logprobs. 324 | # --> [batch, beam, vocab] 325 | log_probs = candidate_log_probs + jnp.expand_dims( 326 | state.live_logprobs, axis=2 327 | ) 328 | 329 | # We'll need the vocab size, gather it from the log probability dimension. 
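    # log_probs has shape [batch, beam, vocab], so axis 2 is the vocab axis.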
330 | vocab_size = log_probs.shape[2] 331 | 332 | # mask away some tokens. 333 | log_probs += mask_bias # [batch,beam,vocab]+[1,1,vocab] 334 | 335 | # Each item in batch has beam_size * vocab_size candidate sequences. 336 | # For each item, get the top 2*k candidates with the highest log- 337 | # probabilities. We gather the top 2*K beams here so that even if the best 338 | # K sequences reach EOS simultaneously, we have another K sequences 339 | # remaining to continue the live beam search. 340 | beams_to_keep = 2 * beam_size 341 | # Flatten beam and vocab dimensions. 342 | flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size)) 343 | # Gather the top 2*K scores from _all_ beams. 344 | # --> [batch, 2*beams], [batch, 2*beams] 345 | topk_log_probs, topk_indices = lax.top_k(flat_log_probs, k=beams_to_keep) 346 | # Recover the beam index by floor division. 347 | topk_beam_indices = topk_indices // vocab_size 348 | # Gather 2*k top beams. 349 | # --> [batch, 2*beams, length] 350 | topk_seq = gather_beams( 351 | state.live_seqs, topk_beam_indices, batch_size, beams_to_keep 352 | ) 353 | 354 | # Append the most probable 2*K token IDs to the top 2*K sequences 355 | # Recover token id by modulo division and expand Id array for broadcasting. 356 | # --> [batch, 2*beams, 1] 357 | topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) 358 | # Update sequences for the 2*K top-k new sequences. 359 | # --> [batch, 2*beams, length] 360 | topk_seq = lax.dynamic_update_slice( 361 | topk_seq, topk_ids, (0, 0, state.cur_index + 1) 362 | ) 363 | 364 | # Update LIVE (in-progress) sequences: 365 | # Did any of these sequences reach an end marker? 366 | # --> [batch, 2*beams] 367 | last_token = topk_seq[:, :, state.cur_index + 1] 368 | last_token = jax.nn.one_hot(last_token, vocab_size, dtype=jnp.bfloat16) 369 | 370 | # any([batch, 2b, vocab] * [1, 1, vocab], axis=-1) == [batch, 2b] 371 | newly_finished = jnp.any(last_token * eos, axis=-1) 372 | 373 | # To prevent these newly finished sequences from being added to the LIVE 374 | # set of active beam search sequences, set their log probs to a very large 375 | # negative value. 376 | new_log_probs = topk_log_probs + newly_finished * NEG_INF 377 | # Determine the top k beam indices (from top 2*k beams) from log probs. 378 | # --> [batch, beams] 379 | _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size) 380 | new_topk_indices = jnp.flip(new_topk_indices, axis=1) 381 | # Gather the top k beams (from top 2*k beams). 382 | # --> [batch, beams, length], [batch, beams] 383 | top_alive_seq, top_alive_log_probs = gather_beams( 384 | [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size 385 | ) 386 | 387 | # Determine the top k beam indices from the original set of all beams. 388 | # --> [batch, beams] 389 | top_alive_indices = gather_beams( 390 | topk_beam_indices, new_topk_indices, batch_size, beam_size 391 | ) 392 | # With these, gather the top k beam-associated caches. 393 | # --> {[batch, beams, ...], ...} 394 | top_alive_cache = apply_on_cache( 395 | gather_beams, new_cache, top_alive_indices, batch_size, beam_size 396 | ) 397 | 398 | # Update FINISHED (reached end of sentence) sequences: 399 | # Calculate new seq scores from log probabilities. 400 | new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1) 401 | # Mask out the still unfinished sequences by adding large negative value. 
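    # (~newly_finished) is 1 exactly on the still-live rows, so only those
    # rows receive the NEG_INF penalty; rows that just finished keep their
    # scores and can compete for the finished set below.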
402 | # --> [batch, 2*beams] 403 | new_scores += (~newly_finished) * NEG_INF 404 | 405 | # Combine sequences, scores, and flags along the beam dimension and compare 406 | # new finished sequence scores to existing finished scores and select the 407 | # best from the new set of beams. 408 | finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length] 409 | [state.finished_seqs, topk_seq], axis=1 410 | ) 411 | finished_scores = jnp.concatenate( # --> [batch, 3*beams] 412 | [state.finished_scores, new_scores], axis=1 413 | ) 414 | finished_flags = jnp.concatenate( # --> [batch, 3*beams] 415 | [state.finished_flags, newly_finished], axis=1 416 | ) 417 | # --> [batch, beams, length], [batch, beams], [batch, beams] 418 | top_finished_seq, top_finished_scores, top_finished_flags = ( 419 | gather_topk_beams( 420 | [finished_seqs, finished_scores, finished_flags], 421 | finished_scores, 422 | batch_size, 423 | beam_size, 424 | ) 425 | ) 426 | 427 | return BeamState( 428 | cur_index=state.cur_index + 1, 429 | live_logprobs=top_alive_log_probs, 430 | finished_scores=top_finished_scores, 431 | live_seqs=top_alive_seq, 432 | finished_seqs=top_finished_seq, 433 | finished_flags=top_finished_flags, 434 | cache=top_alive_cache, 435 | ) 436 | 437 | # Run while loop and get final beam search state. 438 | final_state = lax.while_loop( 439 | beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state 440 | ) 441 | 442 | # Account for the edge-case where there are no finished sequences for a 443 | # particular batch item. If so, return live sequences for that batch item. 444 | # --> [batch] 445 | none_finished = jnp.any(final_state.finished_flags, axis=1) 446 | # --> [batch, beams, length] 447 | finished_seqs = jnp.where( 448 | none_finished[:, None, None], 449 | final_state.finished_seqs, 450 | final_state.live_seqs, 451 | ) 452 | # --> [batch, beams] 453 | finished_scores = jnp.where( 454 | none_finished[:, None], 455 | final_state.finished_scores, 456 | final_state.live_logprobs, 457 | ) 458 | 459 | finished_seqs = jnp.reshape(finished_seqs, (beam_size, max_decode_len)) 460 | finished_scores = jnp.reshape(finished_scores, (beam_size,)) 461 | 462 | final_cache = apply_on_cache(flatten_beam_dim, final_state.cache) 463 | return finished_seqs, finished_scores, final_cache 464 | -------------------------------------------------------------------------------- /dd_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit tests for dd.""" 17 | import unittest 18 | 19 | from absl.testing import absltest 20 | import dd 21 | import graph as gh 22 | import problem as pr 23 | 24 | 25 | MAX_LEVEL = 1000 26 | 27 | 28 | class DDTest(unittest.TestCase): 29 | 30 | @classmethod 31 | def setUpClass(cls): 32 | super().setUpClass() 33 | cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True) 34 | cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True) 35 | 36 | def test_imo_2022_p4_should_succeed(self): 37 | p = pr.Problem.from_txt( 38 | 'a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m =' 39 | ' on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n' 40 | ' g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b,' 41 | ' on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a' 42 | ' n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q' 43 | ) 44 | g, _ = gh.Graph.build_problem(p, DDTest.defs) 45 | goal_args = g.names2nodes(p.goal.args) 46 | 47 | success = False 48 | for level in range(MAX_LEVEL): 49 | added, _, _, _ = dd.bfs_one_level(g, DDTest.rules, level, p) 50 | if g.check(p.goal.name, goal_args): 51 | success = True 52 | break 53 | if not added: # saturated 54 | break 55 | 56 | self.assertTrue(success) 57 | 58 | def test_incenter_excenter_should_fail(self): 59 | p = pr.Problem.from_txt( 60 | 'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c ?' 61 | ' perp d c c e' 62 | ) 63 | g, _ = gh.Graph.build_problem(p, DDTest.defs) 64 | goal_args = g.names2nodes(p.goal.args) 65 | 66 | success = False 67 | for level in range(MAX_LEVEL): 68 | added, _, _, _ = dd.bfs_one_level(g, DDTest.rules, level, p) 69 | if g.check(p.goal.name, goal_args): 70 | success = True 71 | break 72 | if not added: # saturated 73 | break 74 | 75 | self.assertFalse(success) 76 | 77 | 78 | if __name__ == '__main__': 79 | absltest.main() 80 | -------------------------------------------------------------------------------- /ddar.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 |
16 | """Implements the combination DD+AR."""
17 | import time
18 |
19 | from absl import logging
20 | import dd
21 | import graph as gh
22 | import problem as pr
23 | from problem import Dependency  # pylint: disable=g-importing-member
24 | import trace_back
25 |
26 |
27 | def saturate_or_goal(
28 |     g: gh.Graph,
29 |     theorems: list[pr.Theorem],
30 |     level_times: list[float],
31 |     p: pr.Problem,
32 |     max_level: int = 100,
33 |     timeout: int = 600,
34 | ) -> tuple[
35 |     list[dict[str, list[tuple[gh.Point, ...]]]],
36 |     list[dict[str, list[tuple[gh.Point, ...]]]],
37 |     list[int],
38 |     list[pr.Dependency],
39 | ]:
40 |   """Run DD until saturation or goal found."""
41 |   derives = []
42 |   eq4s = []
43 |   branching = []
44 |   all_added = []
45 |
46 |   while len(level_times) < max_level:
47 |     level = len(level_times) + 1
48 |
49 |     t = time.time()
50 |     added, derv, eq4, n_branching = dd.bfs_one_level(
51 |         g, theorems, level, p, verbose=False, nm_check=True, timeout=timeout
52 |     )
53 |     all_added += added
54 |     branching.append(n_branching)
55 |
56 |     derives.append(derv)
57 |     eq4s.append(eq4)
58 |     level_time = time.time() - t
59 |
60 |     logging.info(f'Depth {level}/{max_level} time = {level_time}')  # pylint: disable=logging-fstring-interpolation
61 |     level_times.append(level_time)
62 |
63 |     if p.goal is not None:
64 |       goal_args = list(map(lambda x: g.get(x, lambda: int(x)), p.goal.args))
65 |       if g.check(p.goal.name, goal_args):  # found goal
66 |         break
67 |
68 |     if not added:  # saturated
69 |       break
70 |
71 |     if level_time > timeout:
72 |       break
73 |
74 |   return derives, eq4s, branching, all_added
75 |
76 |
77 | def solve(
78 |     g: gh.Graph,
79 |     theorems: list[pr.Theorem],
80 |     controller: pr.Problem,
81 |     max_level: int = 1000,
82 |     timeout: int = 600,
83 | ) -> tuple[gh.Graph, list[float], str, list[int], list[pr.Dependency]]:
84 |   """Alternate between DD and AR until goal is found."""
85 |   status = 'saturated'
86 |   level_times = []
87 |
88 |   dervs, eq4 = g.derive_algebra(level=0, verbose=False)
89 |   derives = [dervs]
90 |   eq4s = [eq4]
91 |   branches = []
92 |   all_added = []
93 |
94 |   while len(level_times) < max_level:
95 |     dervs, eq4, next_branches, added = saturate_or_goal(
96 |         g, theorems, level_times, controller, max_level, timeout=timeout
97 |     )
98 |     all_added += added
99 |
100 |     derives += dervs
101 |     eq4s += eq4
102 |     branches += next_branches
103 |
104 |     # Now, it is either goal or saturated
105 |     if controller.goal is not None:
106 |       goal_args = g.names2points(controller.goal.args)
107 |       if g.check(controller.goal.name, goal_args):  # found goal
108 |         status = 'solved'
109 |         break
110 |
111 |     if not derives:  # officially saturated.
112 |       break
113 |
114 |     # Now we resort to algebra derivations.
115 |     added = []
116 |     while derives and not added:
117 |       added += dd.apply_derivations(g, derives.pop(0))
118 |
119 |     if added:
120 |       continue
121 |
122 |     # Final help from AR.
123 |     while eq4s and not added:
124 |       added += dd.apply_derivations(g, eq4s.pop(0))
125 |
126 |     all_added += added
127 |
128 |     if not added:  # Nothing left. saturated.
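      # Neither plain DD nor the queued algebraic (AR) derivations produced
      # a new fact, so the joint DD+AR search has reached a fixpoint.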
129 | break 130 | 131 | return g, level_times, status, branches, all_added 132 | 133 | 134 | def get_proof_steps( 135 | g: gh.Graph, goal: pr.Clause, merge_trivials: bool = False 136 | ) -> tuple[ 137 | list[pr.Dependency], 138 | list[pr.Dependency], 139 | list[tuple[list[pr.Dependency], list[pr.Dependency]]], 140 | dict[tuple[str, ...], int], 141 | ]: 142 | """Extract proof steps from the built DAG.""" 143 | goal_args = g.names2nodes(goal.args) 144 | query = Dependency(goal.name, goal_args, None, None) 145 | 146 | setup, aux, log, setup_points = trace_back.get_logs( 147 | query, g, merge_trivials=merge_trivials 148 | ) 149 | 150 | refs = {} 151 | setup = trace_back.point_log(setup, refs, set()) 152 | aux = trace_back.point_log(aux, refs, setup_points) 153 | 154 | setup = [(prems, [tuple(p)]) for p, prems in setup] 155 | aux = [(prems, [tuple(p)]) for p, prems in aux] 156 | 157 | return setup, aux, log, refs 158 | -------------------------------------------------------------------------------- /ddar_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Unit tests for ddar.py.""" 17 | import unittest 18 | 19 | from absl.testing import absltest 20 | import ddar 21 | import graph as gh 22 | import problem as pr 23 | 24 | 25 | class DDARTest(unittest.TestCase): 26 | 27 | @classmethod 28 | def setUpClass(cls): 29 | super().setUpClass() 30 | cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True) 31 | cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True) 32 | 33 | def test_orthocenter_should_fail(self): 34 | txt = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c' # pylint: disable=line-too-long 35 | p = pr.Problem.from_txt(txt) 36 | g, _ = gh.Graph.build_problem(p, DDARTest.defs) 37 | 38 | ddar.solve(g, DDARTest.rules, p, max_level=1000) 39 | goal_args = g.names2nodes(p.goal.args) 40 | self.assertFalse(g.check(p.goal.name, goal_args)) 41 | 42 | def test_orthocenter_aux_should_succeed(self): 43 | txt = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c' # pylint: disable=line-too-long 44 | p = pr.Problem.from_txt(txt) 45 | g, _ = gh.Graph.build_problem(p, DDARTest.defs) 46 | 47 | ddar.solve(g, DDARTest.rules, p, max_level=1000) 48 | goal_args = g.names2nodes(p.goal.args) 49 | self.assertTrue(g.check(p.goal.name, goal_args)) 50 | 51 | def test_incenter_excenter_should_succeed(self): 52 | # Note that this same problem should fail in dd_test.py 53 | p = pr.Problem.from_txt( 54 | 'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c ?' 
55 | ' perp d c c e' 56 | ) # pylint: disable=line-too-long 57 | g, _ = gh.Graph.build_problem(p, DDARTest.defs) 58 | 59 | ddar.solve(g, DDARTest.rules, p, max_level=1000) 60 | goal_args = g.names2nodes(p.goal.args) 61 | self.assertTrue(g.check(p.goal.name, goal_args)) 62 | 63 | 64 | if __name__ == '__main__': 65 | absltest.main() 66 | -------------------------------------------------------------------------------- /decoder_stack.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """The decoder stack in inference mode.""" 17 | 18 | from typing import Any, Tuple 19 | 20 | import gin 21 | from transformer import decoder_stack 22 | import transformer_layer as tl 23 | 24 | 25 | struct = decoder_stack.struct 26 | nn_components = decoder_stack.nn_components 27 | position = decoder_stack.position 28 | jnp = decoder_stack.jnp 29 | attention = decoder_stack.attention 30 | 31 | DStackWindowState = decoder_stack.DStackWindowState 32 | 33 | Array = Any 34 | 35 | TransformerTaskConfig = decoder_stack.TransformerTaskConfig 36 | 37 | DStackDecoderState = Tuple[tl.DecoderState, ...] 
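# A DStackDecoderState holds one tl.DecoderState per transformer layer, so a
# 2-layer stack yields a 2-tuple. Minimal usage sketch (constructor arguments
# and variable names here are illustrative assumptions, not part of this
# module):
#
#   stack = DecoderStackGenerate(...)  # configured through gin, as below
#   state = stack.init_decoder_state_vanilla(
#       sequence_length=128, start_of_sequence=start_flags)
#   assert len(state) == len(stack.transformer_layers)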
38 | 39 | 40 | @gin.configurable 41 | class DecoderStackGenerate(decoder_stack.DecoderStack): 42 | """Stack of transformer decoder layers.""" 43 | 44 | layer_factory = tl.TransformerLayerGenerate 45 | 46 | def init_decoder_state_vanilla( 47 | self, sequence_length: int, start_of_sequence: Array 48 | ) -> DStackDecoderState: 49 | """Return initial state for autoregressive generation.""" 50 | return tuple( 51 | [ 52 | layer.init_decoder_state_vanilla(sequence_length, start_of_sequence) 53 | for layer in self.transformer_layers 54 | ] 55 | ) 56 | -------------------------------------------------------------------------------- /defs.txt: -------------------------------------------------------------------------------- 1 | angle_bisector x a b c 2 | x : a b c x 3 | a b c = ncoll a b c 4 | x : eqangle b a b x b x b c 5 | bisect a b c 6 | 7 | angle_mirror x a b c 8 | x : a b c x 9 | a b c = ncoll a b c 10 | x : eqangle b a b c b c b x 11 | amirror a b c 12 | 13 | circle x a b c 14 | x : a b c 15 | a b c = ncoll a b c 16 | x : cong x a x b, cong x b x c 17 | bline a b, bline a c 18 | 19 | circumcenter x a b c 20 | x : a b c 21 | a b c = ncoll a b c 22 | x : cong x a x b, cong x b x c 23 | bline a b, bline a c 24 | 25 | eq_quadrangle a b c d 26 | d : a b c d 27 | = 28 | a : ; b : ; c : ; d : cong d a b c 29 | eq_quadrangle 30 | 31 | eq_trapezoid a b c d 32 | d : a b c 33 | = 34 | a : ; b : ; c : ; d : para d c a b, cong d a b c 35 | eq_trapezoid 36 | 37 | eq_triangle x b c 38 | x : b c 39 | b c = diff b c 40 | x : cong x b b c, cong b c c x; eqangle b x b c c b c x, eqangle x c x b b x b c 41 | circle b b c, circle c b c 42 | 43 | eqangle2 x a b c 44 | x : a b c x 45 | a b c = ncoll a b c 46 | x : eqangle a b a x c x c b 47 | eqangle2 a b c 48 | 49 | eqdia_quadrangle a b c d 50 | d : a b c d 51 | = 52 | a : ; b : ; c : ; d : cong d b a c 53 | eqdia_quadrangle 54 | 55 | eqdistance x a b c 56 | x : a b c x 57 | a b c = diff b c 58 | x : cong x a b c 59 | circle a b c 60 | 61 | foot x a b c 62 | x : a b c 63 | a b c = ncoll a b c 64 | x : perp x a b c, coll x b c 65 | tline a b c, line b c 66 | 67 | free a 68 | a : a 69 | = 70 | a : 71 | free 72 | 73 | incenter x a b c 74 | x : a b c 75 | a b c = ncoll a b c 76 | x : eqangle a b a x a x a c, eqangle c a c x c x c b; eqangle b c b x b x b a 77 | bisect a b c, bisect b c a 78 | 79 | incenter2 x y z i a b c 80 | i : a b c, x : i b c, y : i c a, z : i a b 81 | a b c = ncoll a b c 82 | i : eqangle a b a i a i a c, eqangle c a c i c i c b; eqangle b c b i b i b a; x : coll x b c, perp i x b c; y : coll y c a, perp i y c a; z : coll z a b, perp i z a b; cong i x i y, cong i y i z 83 | incenter2 a b c 84 | 85 | excenter x a b c 86 | x : a b c 87 | a b c = ncoll a b c 88 | x : eqangle a b a x a x a c, eqangle c a c x c x c b; eqangle b c b x b x b a 89 | bisect b a c, exbisect b c a 90 | 91 | excenter2 x y z i a b c 92 | i : a b c, x : i b c, y : i c a, z : i a b 93 | a b c = ncoll a b c 94 | i : eqangle a b a i a i a c, eqangle c a c i c i c b; eqangle b c b i b i b a; x : coll x b c, perp i x b c; y : coll y c a, perp i y c a; z : coll z a b, perp i z a b; cong i x i y, cong i y i z 95 | excenter2 a b c 96 | 97 | centroid x y z i a b c 98 | x : b c, y : c a, z : a b, i : a x b y 99 | a b c = ncoll a b c 100 | x : coll x b c, cong x b x c; y : coll y c a, cong y c y a; z : coll z a b, cong z a z b; i : coll a x i, coll b y i; coll c z i 101 | centroid a b c 102 | 103 | ninepoints x y z i a b c 104 | x : b c, y : c a, z : a b, i : x y z 105 | a b c = ncoll a b 
c 106 | x : coll x b c, cong x b x c; y : coll y c a, cong y c y a; z : coll z a b, cong z a z b; i : cong i x i y, cong i y i z 107 | ninepoints a b c 108 | 109 | intersection_cc x o w a 110 | x : o w a 111 | o w a = ncoll o w a 112 | x : cong o a o x, cong w a w x 113 | circle o o a, circle w w a 114 | 115 | intersection_lc x a o b 116 | x : a o b 117 | a o b = diff a b, diff o b, nperp b o b a 118 | x : coll x a b, cong o b o x 119 | line b a, circle o o b 120 | 121 | intersection_ll x a b c d 122 | x : a b c d 123 | a b c d = npara a b c d, ncoll a b c d 124 | x : coll x a b, coll x c d 125 | line a b, line c d 126 | 127 | intersection_lp x a b c m n 128 | x : a b c m n 129 | a b c m n = npara m n a b, ncoll a b c, ncoll c m n 130 | x : coll x a b, para c x m n 131 | line a b, pline c m n 132 | 133 | intersection_lt x a b c d e 134 | x : a b c d e 135 | a b c d e = ncoll a b c, nperp a b d e 136 | x : coll x a b, perp x c d e 137 | line a b, tline c d e 138 | 139 | intersection_pp x a b c d e f 140 | x : a b c d e f 141 | a b c d e f = diff a d, npara b c e f 142 | x : para x a b c, para x d e f 143 | pline a b c, pline d e f 144 | 145 | intersection_tt x a b c d e f 146 | x : a b c d e f 147 | a b c d e f = diff a d, npara b c e f 148 | x : perp x a b c, perp x d e f 149 | tline a b c, tline d e f 150 | 151 | iso_triangle a b c 152 | c : a b c 153 | = 154 | a : ; b : ; c : eqangle b a b c c b c a, cong a b a c 155 | isos 156 | 157 | lc_tangent x a o 158 | x : x a o 159 | a o = diff a o 160 | x : perp a x a o 161 | tline a a o 162 | 163 | midpoint x a b 164 | x : a b 165 | a b = diff a b 166 | x : coll x a b, cong x a x b 167 | midp a b 168 | 169 | mirror x a b 170 | x : a b 171 | a b = diff a b 172 | x : coll x a b, cong b a b x 173 | pmirror a b 174 | 175 | nsquare x a b 176 | x : a b 177 | a b = diff a b 178 | x : cong x a a b, perp x a a b 179 | rotaten90 a b 180 | 181 | on_aline x a b c d e 182 | x : x a b c d e 183 | a b c d e = ncoll c d e 184 | x : eqangle a x a b d c d e 185 | aline e d c b a 186 | 187 | on_aline2 x a b c d e 188 | x : x a b c d e 189 | a b c d e = ncoll c d e 190 | x : eqangle x a x b d c d e 191 | aline2 e d c b a 192 | 193 | on_bline x a b 194 | x : x a b 195 | a b = diff a b 196 | x : cong x a x b, eqangle a x a b b a b x 197 | bline a b 198 | 199 | on_circle x o a 200 | x : x o a 201 | o a = diff o a 202 | x : cong o x o a 203 | circle o o a 204 | 205 | on_line x a b 206 | x : x a b 207 | a b = diff a b 208 | x : coll x a b 209 | line a b 210 | 211 | on_pline x a b c 212 | x : x a b c 213 | a b c = diff b c, ncoll a b c 214 | x : para x a b c 215 | pline a b c 216 | 217 | on_tline x a b c 218 | x : x a b c 219 | a b c = diff b c 220 | x : perp x a b c 221 | tline a b c 222 | 223 | orthocenter x a b c 224 | x : a b c 225 | a b c = ncoll a b c 226 | x : perp x a b c, perp x b c a; perp x c a b 227 | tline a b c, tline b c a 228 | 229 | parallelogram a b c x 230 | x : a b c 231 | a b c = ncoll a b c 232 | x : para a b c x, para a x b c; cong a b c x, cong a x b c 233 | pline a b c, pline c a b 234 | 235 | pentagon a b c d e 236 | 237 | = 238 | a : ; b : ; c : ; d : ; e : 239 | pentagon 240 | 241 | psquare x a b 242 | x : a b 243 | a b = diff a b 244 | x : cong x a a b, perp x a a b 245 | rotatep90 a b 246 | 247 | quadrangle a b c d 248 | 249 | = 250 | a : ; b : ; c : ; d : 251 | quadrangle 252 | 253 | r_trapezoid a b c d 254 | d : a b c 255 | = 256 | a : ; b : ; c : ; d : para a b c d, perp a b a d 257 | r_trapezoid 258 | 259 | r_triangle a b c 260 | c : a b 
c 261 | = 262 | a : ; b : ; c : perp a b a c 263 | r_triangle 264 | 265 | rectangle a b c d 266 | c : a b c , d : a b c 267 | = 268 | a : ; b : ; c : perp a b b c ; d : para a b c d, para a d b c; perp a b a d, cong a b c d, cong a d b c, cong a c b d 269 | rectangle 270 | 271 | reflect x a b c 272 | x : a b c 273 | a b c = diff b c, ncoll a b c 274 | x : cong b a b x, cong c a c x; perp b c a x 275 | reflect a b c 276 | 277 | risos a b c 278 | c : a b 279 | = 280 | a : ; b : ; c : perp a b a c, cong a b a c; eqangle b a b c c b c a 281 | risos 282 | 283 | s_angle a b x y 284 | x : a b x 285 | a b = diff a b 286 | x : s_angle a b x y 287 | s_angle a b y 288 | 289 | segment a b 290 | 291 | = 292 | a : ; b : 293 | segment 294 | 295 | shift x b c d 296 | x : b c d 297 | b c d = diff d b 298 | x : cong x b c d, cong x c b d 299 | shift d c b 300 | 301 | square a b x y 302 | x : a b, y : a b x 303 | a b = diff a b 304 | x : perp a b b x, cong a b b x; y : para a b x y, para a y b x; perp a y y x, cong b x x y, cong x y y a, perp a x b y, cong a x b y 305 | square a b 306 | 307 | isquare a b c d 308 | c : a b , d : a b c 309 | = 310 | a : ; b : ; c : perp a b b c, cong a b b c; d : para a b c d, para a d b c; perp a d d c, cong b c c d, cong c d d a, perp a c b d, cong a c b d 311 | isquare 312 | 313 | trapezoid a b c d 314 | d : a b c d 315 | = 316 | a : ; b : ; c : ; d : para a b c d 317 | trapezoid 318 | 319 | triangle a b c 320 | 321 | = 322 | a : ; b : ; c : 323 | triangle 324 | 325 | triangle12 a b c 326 | c : a b c 327 | = 328 | a : ; b : ; c : rconst a b a c 1 2 329 | triangle12 330 | 331 | 2l1c x y z i a b c o 332 | x : a b c o y z i, y : a b c o x z i, z : a b c o x y i, i : a b c o x y z 333 | a b c o = cong o a o b, ncoll a b c 334 | x y z i : coll x a c, coll y b c, cong o a o z, coll i o z, cong i x i y, cong i y i z, perp i x a c, perp i y b c 335 | 2l1c a b c o 336 | 337 | e5128 x y a b c d 338 | x : a b c d y, y : a b c d x 339 | a b c d = cong c b c d, perp b c b a 340 | x y : cong c b c x, coll y a b, coll x y d, eqangle a b a d x a x y 341 | e5128 a b c d 342 | 343 | 3peq x y z a b c 344 | z : b c z , x : a b c z y, y : a b c z x 345 | a b c = ncoll a b c 346 | z : coll z b c ; x y : coll x a b, coll y a c, coll x y z, cong z x z y 347 | 3peq a b c 348 | 349 | trisect x y a b c 350 | x : a b c y, y : a b c x 351 | a b c = ncoll a b c 352 | x y : coll x a c, coll y a c, eqangle b a b x b x b y, eqangle b x b y b y b c 353 | trisect a b c 354 | 355 | trisegment x y a b 356 | x : a b y, y : a b x 357 | a b = diff a b 358 | x y : coll x a b, coll y a b, cong x a x y, cong y x y b 359 | trisegment a b 360 | 361 | on_dia x a b 362 | x : x a b 363 | a b = diff a b 364 | x : perp x a x b 365 | dia a b 366 | 367 | ieq_triangle a b c 368 | c : a b 369 | = 370 | a : ; b : ; c : cong a b b c, cong b c c a; eqangle a b a c c a c b, eqangle c a c b b c b a 371 | ieq_triangle 372 | 373 | on_opline x a b 374 | x : x a b 375 | a b = diff a b 376 | x : coll x a b 377 | on_opline a b 378 | 379 | cc_tangent0 x y o a w b 380 | x : o a w b y, y : o a w b x 381 | o a w b = diff o a, diff w b, diff o w 382 | x y : cong o x o a, cong w y w b, perp x o x y, perp y w y x 383 | cc_tangent0 o a w b 384 | 385 | cc_tangent x y z i o a w b 386 | x : o a w b y, y : o a w b x, z : o a w b i, i : o a w b z 387 | o a w b = diff o a, diff w b, diff o w 388 | x y : cong o x o a, cong w y w b, perp x o x y, perp y w y x; z i : cong o z o a, cong w i w b, perp z o z i, perp i w i z 389 | cc_tangent o a w b 390 | 
391 | eqangle3 x a b d e f 392 | x : x a b d e f 393 | a b d e f = ncoll d e f, diff a b, diff d e, diff e f 394 | x : eqangle x a x b d e d f 395 | eqangle3 a b d e f 396 | 397 | tangent x y a o b 398 | x y : o a b 399 | a o b = diff o a, diff o b, diff a b 400 | x : cong o x o b, perp a x o x; y : cong o y o b, perp a y o y 401 | tangent a o b 402 | 403 | on_circum x a b c 404 | x : a b c 405 | a b c = ncoll a b c 406 | x : cyclic a b c x 407 | cyclic a b c 408 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | gdown --folder https://bit.ly/alphageometry 17 | export DATA=ag_ckpt_vocab 18 | -------------------------------------------------------------------------------- /examples.txt: -------------------------------------------------------------------------------- 1 | orthocenter 2 | a b c = triangle; h = on_tline b a c, on_tline c a b ? perp a h b c 3 | orthocenter_aux 4 | a b c = triangle; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c 5 | incenter_excenter 6 | a b c = triangle a b c; d1 d2 d3 d = incenter2 a b c; e1 e2 e3 e = excenter2 a b c ? perp d c c e 7 | euler 8 | a b c = triangle a b c; h = orthocenter a b c; h1 = foot a b c; h2 = foot b c a; h3 = foot c a b; g1 g2 g3 g = centroid g1 g2 g3 g a b c; o = circle a b c ? coll h g o 9 | -------------------------------------------------------------------------------- /geometry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Implements geometric objects used in the graph representation.""" 17 | from __future__ import annotations 18 | from collections import defaultdict # pylint: disable=g-importing-member 19 | from typing import Any, Type 20 | 21 | # pylint: disable=protected-access 22 | 23 | 24 | class Node: 25 | r"""Node in the proof state graph. 26 | 27 | Can be Point, Line, Circle, etc. 
28 | 29 | Each node maintains a merge history to 30 | other nodes if they are (found out to be) equivalent 31 | 32 | a -> b - 33 | \ 34 | c -> d -> e -> f -> g 35 | 36 | d.merged_to = e 37 | d.rep = g 38 | d.merged_from = {a, b, c, d} 39 | d.equivs = {a, b, c, d, e, f, g} 40 | """ 41 | 42 | def __init__(self, name: str = '', graph: Any = None): 43 | self.name = name or str(self) 44 | self.graph = graph 45 | 46 | self.edge_graph = {} 47 | # Edge graph: what other nodes is connected to this node. 48 | # edge graph = { 49 | # other1: {self1: deps, self2: deps}, 50 | # other2: {self2: deps, self3: deps} 51 | # } 52 | 53 | self.merge_graph = {} 54 | # Merge graph: history of merges with other nodes. 55 | # merge_graph = {self1: {self2: deps1, self3: deps2}} 56 | 57 | self.rep_by = None # represented by. 58 | self.members = {self} 59 | 60 | self._val = None 61 | self._obj = None 62 | 63 | self.deps = [] 64 | 65 | # numerical representation. 66 | self.num = None 67 | self.change = set() # what other nodes' num rely on this node? 68 | 69 | def set_rep(self, node: Node) -> None: 70 | if node == self: 71 | return 72 | self.rep_by = node 73 | node.merge_edge_graph(self.edge_graph) 74 | node.members.update(self.members) 75 | 76 | def rep(self) -> Node: 77 | x = self 78 | while x.rep_by: 79 | x = x.rep_by 80 | return x 81 | 82 | def why_rep(self) -> list[Any]: 83 | return self.why_equal([self.rep()], None) 84 | 85 | def rep_and_why(self) -> tuple[Node, list[Any]]: 86 | rep = self.rep() 87 | return rep, self.why_equal([rep], None) 88 | 89 | def neighbors( 90 | self, oftype: Type[Node], return_set: bool = False, do_rep: bool = True 91 | ) -> list[Node]: 92 | """Neighbors of this node in the proof state graph.""" 93 | if do_rep: 94 | rep = self.rep() 95 | else: 96 | rep = self 97 | result = set() 98 | 99 | for n in rep.edge_graph: 100 | if oftype is None or oftype and isinstance(n, oftype): 101 | if do_rep: 102 | result.add(n.rep()) 103 | else: 104 | result.add(n) 105 | 106 | if return_set: 107 | return result 108 | return list(result) 109 | 110 | def merge_edge_graph( 111 | self, new_edge_graph: dict[Node, dict[Node, list[Node]]] 112 | ) -> None: 113 | for x, xdict in new_edge_graph.items(): 114 | if x in self.edge_graph: 115 | self.edge_graph[x].update(dict(xdict)) 116 | else: 117 | self.edge_graph[x] = dict(xdict) 118 | 119 | def merge(self, nodes: list[Node], deps: list[Any]) -> None: 120 | for node in nodes: 121 | self.merge_one(node, deps) 122 | 123 | def merge_one(self, node: Node, deps: list[Any]) -> None: 124 | node.rep().set_rep(self.rep()) 125 | 126 | if node in self.merge_graph: 127 | return 128 | 129 | self.merge_graph[node] = deps 130 | node.merge_graph[self] = deps 131 | 132 | def is_val(self, node: Node) -> bool: 133 | return ( 134 | isinstance(self, Line) 135 | and isinstance(node, Direction) 136 | or isinstance(self, Segment) 137 | and isinstance(node, Length) 138 | or isinstance(self, Angle) 139 | and isinstance(node, Measure) 140 | or isinstance(self, Ratio) 141 | and isinstance(node, Value) 142 | ) 143 | 144 | def set_val(self, node: Node) -> None: 145 | self._val = node 146 | 147 | def set_obj(self, node: Node) -> None: 148 | self._obj = node 149 | 150 | @property 151 | def val(self) -> Node: 152 | if self._val is None: 153 | return None 154 | return self._val.rep() 155 | 156 | @property 157 | def obj(self) -> Node: 158 | if self._obj is None: 159 | return None 160 | return self._obj.rep() 161 | 162 | def equivs(self) -> set[Node]: 163 | return self.rep().members 164 | 165 | def 
connect_to(self, node: Node, deps: list[Any] = None) -> None: 166 | rep = self.rep() 167 | 168 | if node in rep.edge_graph: 169 | rep.edge_graph[node].update({self: deps}) 170 | else: 171 | rep.edge_graph[node] = {self: deps} 172 | 173 | if self.is_val(node): 174 | self.set_val(node) 175 | node.set_obj(self) 176 | 177 | def equivs_upto(self, level: int) -> dict[Node, Node]: 178 | """What are the equivalent nodes up to a certain level.""" 179 | parent = {self: None} 180 | visited = set() 181 | queue = [self] 182 | i = 0 183 | 184 | while i < len(queue): 185 | current = queue[i] 186 | i += 1 187 | visited.add(current) 188 | 189 | for neighbor in current.merge_graph: 190 | if ( 191 | level is not None 192 | and current.merge_graph[neighbor].level is not None 193 | and current.merge_graph[neighbor].level >= level 194 | ): 195 | continue 196 | if neighbor not in visited: 197 | queue.append(neighbor) 198 | parent[neighbor] = current 199 | 200 | return parent 201 | 202 | def why_equal(self, others: list[Node], level: int) -> list[Any]: 203 | """BFS why this node is equal to other nodes.""" 204 | others = set(others) 205 | found = 0 206 | 207 | parent = {} 208 | queue = [self] 209 | i = 0 210 | 211 | while i < len(queue): 212 | current = queue[i] 213 | if current in others: 214 | found += 1 215 | if found == len(others): 216 | break 217 | 218 | i += 1 219 | 220 | for neighbor in current.merge_graph: 221 | if ( 222 | level is not None 223 | and current.merge_graph[neighbor].level is not None 224 | and current.merge_graph[neighbor].level >= level 225 | ): 226 | continue 227 | if neighbor not in parent: 228 | queue.append(neighbor) 229 | parent[neighbor] = current 230 | 231 | return bfs_backtrack(self, others, parent) 232 | 233 | def why_equal_groups( 234 | self, groups: list[list[Node]], level: int 235 | ) -> tuple[list[Any], list[Node]]: 236 | """BFS for why self is equal to at least one member of each group.""" 237 | others = [None for _ in groups] 238 | found = 0 239 | 240 | parent = {} 241 | queue = [self] 242 | i = 0 243 | 244 | while i < len(queue): 245 | current = queue[i] 246 | 247 | for j, grp in enumerate(groups): 248 | if others[j] is None and current in grp: 249 | others[j] = current 250 | found += 1 251 | 252 | if found == len(others): 253 | break 254 | 255 | i += 1 256 | 257 | for neighbor in current.merge_graph: 258 | if ( 259 | level is not None 260 | and current.merge_graph[neighbor].level is not None 261 | and current.merge_graph[neighbor].level >= level 262 | ): 263 | continue 264 | if neighbor not in parent: 265 | queue.append(neighbor) 266 | parent[neighbor] = current 267 | 268 | return bfs_backtrack(self, others, parent), others 269 | 270 | def why_val(self, level: int) -> list[Any]: 271 | return self._val.why_equal([self.val], level) 272 | 273 | def why_connect(self, node: Node, level: int = None) -> list[Any]: 274 | rep = self.rep() 275 | equivs = list(rep.edge_graph[node].keys()) 276 | if not equivs: 277 | return None 278 | equiv = equivs[0] 279 | dep = rep.edge_graph[node][equiv] 280 | return [dep] + self.why_equal(equiv, level) 281 | 282 | 283 | def why_connect(*pairs: list[tuple[Node, Node]]) -> list[Any]: 284 | result = [] 285 | for node1, node2 in pairs: 286 | result += node1.why_connect(node2) 287 | return result 288 | 289 | 290 | def is_equiv(x: Node, y: Node, level: int = None) -> bool: 291 | level = level or float('inf') 292 | return x.why_equal([y], level) is not None 293 | 294 | 295 | def is_equal(x: Node, y: Node, level: int = None) -> bool: 296 | if x == y: 297 
| return True 298 | if x._val is None or y._val is None: 299 | return False 300 | if x.val != y.val: 301 | return False 302 | return is_equiv(x._val, y._val, level) 303 | 304 | 305 | def bfs_backtrack( 306 | root: Node, leafs: list[Node], parent: dict[Node, Node] 307 | ) -> list[Any]: 308 | """Return the path given BFS trace of parent nodes.""" 309 | backtracked = {root} # no need to backtrack further when touching this set. 310 | deps = [] 311 | for node in leafs: 312 | if node is None: 313 | return None 314 | if node in backtracked: 315 | continue 316 | if node not in parent: 317 | return None 318 | while node not in backtracked: 319 | backtracked.add(node) 320 | deps.append(node.merge_graph[parent[node]]) 321 | node = parent[node] 322 | 323 | return deps 324 | 325 | 326 | class Point(Node): 327 | pass 328 | 329 | 330 | class Line(Node): 331 | """Node of type Line.""" 332 | 333 | def new_val(self) -> Direction: 334 | return Direction() 335 | 336 | def why_coll(self, points: list[Point], level: int = None) -> list[Any]: 337 | """Why points are connected to self.""" 338 | level = level or float('inf') 339 | 340 | groups = [] 341 | for p in points: 342 | group = [ 343 | l 344 | for l, d in self.edge_graph[p].items() 345 | if d is None or d.level < level 346 | ] 347 | if not group: 348 | return None 349 | groups.append(group) 350 | 351 | min_deps = None 352 | for line in groups[0]: 353 | deps, others = line.why_equal_groups(groups[1:], level) 354 | if deps is None: 355 | continue 356 | for p, o in zip(points, [line] + others): 357 | deps.append(self.edge_graph[p][o]) 358 | if min_deps is None or len(deps) < len(min_deps): 359 | min_deps = deps 360 | 361 | if min_deps is None: 362 | return None 363 | return [d for d in min_deps if d is not None] 364 | 365 | 366 | class Segment(Node): 367 | 368 | def new_val(self) -> Length: 369 | return Length() 370 | 371 | 372 | class Circle(Node): 373 | """Node of type Circle.""" 374 | 375 | def why_cyclic(self, points: list[Point], level: int = None) -> list[Any]: 376 | """Why points are connected to self.""" 377 | level = level or float('inf') 378 | 379 | groups = [] 380 | for p in points: 381 | group = [ 382 | c 383 | for c, d in self.edge_graph[p].items() 384 | if d is None or d.level < level 385 | ] 386 | if not group: 387 | return None 388 | groups.append(group) 389 | 390 | min_deps = None 391 | for circle in groups[0]: 392 | deps, others = circle.why_equal_groups(groups[1:], level) 393 | if deps is None: 394 | continue 395 | for p, o in zip(points, [circle] + others): 396 | deps.append(self.edge_graph[p][o]) 397 | 398 | if min_deps is None or len(deps) < len(min_deps): 399 | min_deps = deps 400 | 401 | if min_deps is None: 402 | return None 403 | return [d for d in min_deps if d is not None] 404 | 405 | 406 | def why_equal(x: Node, y: Node, level: int = None) -> list[Any]: 407 | if x == y: 408 | return [] 409 | if not x._val or not y._val: 410 | return None 411 | if x._val == y._val: 412 | return [] 413 | return x._val.why_equal([y._val], level) 414 | 415 | 416 | class Direction(Node): 417 | pass 418 | 419 | 420 | def get_lines_thru_all(*points: list[Point]) -> list[Line]: 421 | line2count = defaultdict(lambda: 0) 422 | points = set(points) 423 | for p in points: 424 | for l in p.neighbors(Line): 425 | line2count[l] += 1 426 | return [l for l, count in line2count.items() if count == len(points)] 427 | 428 | 429 | def line_of_and_why( 430 | points: list[Point], level: int = None 431 | ) -> tuple[Line, list[Any]]: 432 | """Why points are collinear.""" 
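# How the search below works: get_lines_thru_all proposes candidate lines incident to every query point; each candidate's merged equivalents are also tried, since only some members of an equivalence class record all the points in their edge_graph. why_coll then recovers the smallest dependency set it can find, or None, in which case the search moves on to the next equivalent line.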
433 | for l0 in get_lines_thru_all(*points): 434 | for l in l0.equivs(): 435 | if all([p in l.edge_graph for p in points]): 436 | x, y = l.points 437 | colls = list({x, y} | set(points)) 438 | # if len(colls) < 3: 439 | # return l, [] 440 | why = l.why_coll(colls, level) 441 | if why is not None: 442 | return l, why 443 | 444 | return None, None 445 | 446 | 447 | def get_circles_thru_all(*points: list[Point]) -> list[Circle]: 448 | circle2count = defaultdict(lambda: 0) 449 | points = set(points) 450 | for p in points: 451 | for c in p.neighbors(Circle): 452 | circle2count[c] += 1 453 | return [c for c, count in circle2count.items() if count == len(points)] 454 | 455 | 456 | def circle_of_and_why( 457 | points: list[Point], level: int = None 458 | ) -> tuple[Circle, list[Any]]: 459 | """Why points are concyclic.""" 460 | for c0 in get_circles_thru_all(*points): 461 | for c in c0.equivs(): 462 | if all([p in c.edge_graph for p in points]): 463 | cycls = list(set(points)) 464 | why = c.why_cyclic(cycls, level) 465 | if why is not None: 466 | return c, why 467 | 468 | return None, None 469 | 470 | 471 | def name_map(struct: Any) -> Any: 472 | if isinstance(struct, list): 473 | return [name_map(x) for x in struct] 474 | elif isinstance(struct, tuple): 475 | return tuple([name_map(x) for x in struct]) 476 | elif isinstance(struct, set): 477 | return set([name_map(x) for x in struct]) 478 | elif isinstance(struct, dict): 479 | return {name_map(x): name_map(y) for x, y in struct.items()} 480 | else: 481 | return getattr(struct, 'name', '') 482 | 483 | 484 | class Angle(Node): 485 | """Node of type Angle.""" 486 | 487 | def new_val(self) -> Measure: 488 | return Measure() 489 | 490 | def set_directions(self, d1: Direction, d2: Direction) -> None: 491 | self._d = d1, d2 492 | 493 | @property 494 | def directions(self) -> tuple[Direction, Direction]: 495 | d1, d2 = self._d 496 | if d1 is None or d2 is None: 497 | return d1, d2 498 | return d1.rep(), d2.rep() 499 | 500 | 501 | class Measure(Node): 502 | pass 503 | 504 | 505 | class Length(Node): 506 | pass 507 | 508 | 509 | class Ratio(Node): 510 | """Node of type Ratio.""" 511 | 512 | def new_val(self) -> Value: 513 | return Value() 514 | 515 | def set_lengths(self, l1: Length, l2: Length) -> None: 516 | self._l = l1, l2 517 | 518 | @property 519 | def lengths(self) -> tuple[Length, Length]: 520 | l1, l2 = self._l 521 | if l1 is None or l2 is None: 522 | return l1, l2 523 | return l1.rep(), l2.rep() 524 | 525 | 526 | class Value(Node): 527 | pass 528 | 529 | 530 | def all_angles( 531 | d1: Direction, d2: Direction, level: int = None 532 | ) -> tuple[Angle, list[Direction], list[Direction]]: 533 | level = level or float('inf') 534 | d1s = d1.equivs_upto(level) 535 | d2s = d2.equivs_upto(level) 536 | 537 | for ang in d1.rep().neighbors(Angle): 538 | d1_, d2_ = ang._d 539 | if d1_ in d1s and d2_ in d2s: 540 | yield ang, d1s, d2s 541 | 542 | 543 | def all_ratios( 544 | d1, d2, level=None 545 | ) -> tuple[Angle, list[Direction], list[Direction]]: 546 | level = level or float('inf') 547 | d1s = d1.equivs_upto(level) 548 | d2s = d2.equivs_upto(level) 549 | 550 | for ang in d1.rep().neighbors(Ratio): 551 | d1_, d2_ = ang._l 552 | if d1_ in d1s and d2_ in d2s: 553 | yield ang, d1s, d2s 554 | 555 | 556 | RANKING = { 557 | Point: 0, 558 | Line: 1, 559 | Segment: 2, 560 | Circle: 3, 561 | Direction: 4, 562 | Length: 5, 563 | Angle: 6, 564 | Ratio: 7, 565 | Measure: 8, 566 | Value: 9, 567 | } 568 | 569 | 570 | def val_type(x: Node) -> Type[Node]: 571 | if 
isinstance(x, Line): 572 | return Direction 573 | if isinstance(x, Segment): 574 | return Length 575 | if isinstance(x, Angle): 576 | return Measure 577 | if isinstance(x, Ratio): 578 | return Value 579 | -------------------------------------------------------------------------------- /geometry_150M_generate.gin: -------------------------------------------------------------------------------- 1 | NUM_EMBEDDINGS = 1024 2 | 3 | # Number of parameters = 152M 4 | NUM_LAYERS = 12 5 | EMBED_DIM = 1024 6 | NUM_HEADS = 8 7 | HEAD_DIM = 128 8 | MLP_DIM = 4096 9 | 10 | 11 | transformer_layer.TransformerLayerGenerate: 12 | num_heads = %NUM_HEADS 13 | head_size = %HEAD_DIM 14 | window_length = 1024 15 | use_long_xl_architecture = False 16 | max_unrolled_windows = -1 # Always unroll. 17 | relative_position_type = "t5" # Can be "fourier", "t5", or None. 18 | use_causal_mask = True 19 | attn_dropout_rate = %ATTN_DROPOUT_RATE # Attention matrix dropout. 20 | memory_num_neighbors = 0 21 | dtype = %DTYPE 22 | 23 | decoder_stack.DecoderStackGenerate: 24 | num_layers = %NUM_LAYERS 25 | embedding_size = %EMBED_DIM 26 | embedding_stddev = 1.0 27 | layer_factory = @transformer_layer.TransformerLayerGenerate 28 | dstack_window_length = 0 29 | use_absolute_positions = False 30 | use_final_layernorm = True # Final layernorm before token lookup. 31 | final_dropout_rate = %DROPOUT_RATE # Dropout before token lookup. 32 | final_mlp_factory = None # Final MLP to predict target tokens. 33 | recurrent_layer_indices = () 34 | memory_factory = None # e.g. @memory_factory.memory_on_tpu_factory 35 | memory_layer_indices = () 36 | dtype = %DTYPE 37 | 38 | 39 | models.DecoderOnlyLanguageModelGenerate: 40 | num_heads = %NUM_HEADS 41 | head_size = %HEAD_DIM 42 | task_config = @decoder_stack.TransformerTaskConfig() 43 | decoder_factory = @decoder_stack.DecoderStackGenerate 44 | 45 | 46 | training_loop.Trainer: 47 | model_definition = @models.DecoderOnlyLanguageModelGenerate 48 | -------------------------------------------------------------------------------- /geometry_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit tests for geometry.py.""" 17 | import unittest 18 | 19 | from absl.testing import absltest 20 | import geometry as gm 21 | 22 | 23 | class GeometryTest(unittest.TestCase): 24 | 25 | def _setup_equality_example(self): 26 | # Create 4 nodes a, b, c, d 27 | # and their lengths 28 | a = gm.Segment('a') 29 | la = gm.Length('l(a)') 30 | a.connect_to(la) 31 | la.connect_to(a) 32 | 33 | b = gm.Segment('b') 34 | lb = gm.Length('l(b)') 35 | b.connect_to(lb) 36 | lb.connect_to(b) 37 | 38 | c = gm.Segment('c') 39 | lc = gm.Length('l(c)') 40 | c.connect_to(lc) 41 | lc.connect_to(c) 42 | 43 | d = gm.Segment('d') 44 | ld = gm.Length('l(d)') 45 | d.connect_to(ld) 46 | ld.connect_to(d) 47 | 48 | # Now let a=b, b=c, a=c, c=d 49 | la.merge([lb], 'fact1') 50 | lb.merge([lc], 'fact2') 51 | la.merge([lc], 'fact3') 52 | lc.merge([ld], 'fact4') 53 | return a, b, c, d, la, lb, lc, ld 54 | 55 | def test_merged_node_representative(self): 56 | _, _, _, _, la, lb, lc, ld = self._setup_equality_example() 57 | 58 | # all nodes are now represented by la. 59 | self.assertEqual(la.rep(), la) 60 | self.assertEqual(lb.rep(), la) 61 | self.assertEqual(lc.rep(), la) 62 | self.assertEqual(ld.rep(), la) 63 | 64 | def test_merged_node_equivalence(self): 65 | _, _, _, _, la, lb, lc, ld = self._setup_equality_example() 66 | # all la, lb, lc, ld are equivalent 67 | self.assertCountEqual(la.equivs(), [la, lb, lc, ld]) 68 | self.assertCountEqual(lb.equivs(), [la, lb, lc, ld]) 69 | self.assertCountEqual(lc.equivs(), [la, lb, lc, ld]) 70 | self.assertCountEqual(ld.equivs(), [la, lb, lc, ld]) 71 | 72 | def test_bfs_for_equality_transitivity(self): 73 | a, _, _, d, _, _, _, _ = self._setup_equality_example() 74 | 75 | # check that a==d because fact3 & fact4, not fact1 & fact2 76 | self.assertCountEqual(gm.why_equal(a, d), ['fact3', 'fact4']) 77 | 78 | 79 | if __name__ == '__main__': 80 | absltest.main() 81 | -------------------------------------------------------------------------------- /graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit tests for graph.py.""" 17 | import unittest 18 | 19 | from absl.testing import absltest 20 | import graph as gh 21 | import numericals as nm 22 | import problem as pr 23 | 24 | 25 | MAX_LEVEL = 1000 26 | 27 | 28 | class GraphTest(unittest.TestCase): 29 | 30 | @classmethod 31 | def setUpClass(cls): 32 | super().setUpClass() 33 | 34 | cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True) 35 | cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True) 36 | 37 | # load a complex setup: 38 | txt = 'a b c = triangle a b c; h = orthocenter a b c; h1 = foot a b c; h2 = foot b c a; h3 = foot c a b; g1 g2 g3 g = centroid g1 g2 g3 g a b c; o = circle a b c ? coll h g o' # pylint: disable=line-too-long 39 | p = pr.Problem.from_txt(txt, translate=False) 40 | cls.g, _ = gh.Graph.build_problem(p, GraphTest.defs) 41 | 42 | def test_build_graph_points(self): 43 | g = GraphTest.g 44 | 45 | all_points = g.all_points() 46 | all_names = [p.name for p in all_points] 47 | self.assertCountEqual( 48 | all_names, 49 | ['a', 'b', 'c', 'g', 'h', 'o', 'g1', 'g2', 'g3', 'h1', 'h2', 'h3'], 50 | ) 51 | 52 | def test_build_graph_predicates(self): 53 | gr = GraphTest.g 54 | 55 | a, b, c, g, h, o, g1, g2, g3, h1, h2, h3 = gr.names2points( 56 | ['a', 'b', 'c', 'g', 'h', 'o', 'g1', 'g2', 'g3', 'h1', 'h2', 'h3'] 57 | ) 58 | 59 | # Explicit statements: 60 | self.assertTrue(gr.check_cong([b, g1, g1, c])) 61 | self.assertTrue(gr.check_cong([c, g2, g2, a])) 62 | self.assertTrue(gr.check_cong([a, g3, g3, b])) 63 | self.assertTrue(gr.check_perp([a, h1, b, c])) 64 | self.assertTrue(gr.check_perp([b, h2, c, a])) 65 | self.assertTrue(gr.check_perp([c, h3, a, b])) 66 | self.assertTrue(gr.check_cong([o, a, o, b])) 67 | self.assertTrue(gr.check_cong([o, b, o, c])) 68 | self.assertTrue(gr.check_cong([o, a, o, c])) 69 | self.assertTrue(gr.check_coll([a, g, g1])) 70 | self.assertTrue(gr.check_coll([b, g, g2])) 71 | self.assertTrue(gr.check_coll([g1, b, c])) 72 | self.assertTrue(gr.check_coll([g2, c, a])) 73 | self.assertTrue(gr.check_coll([g3, a, b])) 74 | self.assertTrue(gr.check_perp([a, h, b, c])) 75 | self.assertTrue(gr.check_perp([b, h, c, a])) 76 | 77 | # These are NOT part of the premises: 78 | self.assertFalse(gr.check_perp([c, h, a, b])) 79 | self.assertFalse(gr.check_coll([c, g, g3])) 80 | 81 | # These are automatically inferred by the graph datastructure: 82 | self.assertTrue(gr.check_eqangle([a, h1, b, c, b, h2, c, a])) 83 | self.assertTrue(gr.check_eqangle([a, h1, b, h2, b, c, c, a])) 84 | self.assertTrue(gr.check_eqratio([b, g1, g1, c, c, g2, g2, a])) 85 | self.assertTrue(gr.check_eqratio([b, g1, g1, c, o, a, o, b])) 86 | self.assertTrue(gr.check_para([a, h, a, h1])) 87 | self.assertTrue(gr.check_para([b, h, b, h2])) 88 | self.assertTrue(gr.check_coll([a, h, h1])) 89 | self.assertTrue(gr.check_coll([b, h, h2])) 90 | 91 | def test_enumerate_colls(self): 92 | g = GraphTest.g 93 | 94 | for a, b, c in g.all_colls(): 95 | self.assertTrue(g.check_coll([a, b, c])) 96 | self.assertTrue(nm.check_coll([a.num, b.num, c.num])) 97 | 98 | def test_enumerate_paras(self): 99 | g = GraphTest.g 100 | 101 | for a, b, c, d in g.all_paras(): 102 | self.assertTrue(g.check_para([a, b, c, d])) 103 | self.assertTrue(nm.check_para([a.num, b.num, c.num, d.num])) 104 | 105 | def test_enumerate_perps(self): 106 | g = GraphTest.g 107 | 108 | for a, b, c, d in g.all_perps(): 109 | self.assertTrue(g.check_perp([a, b, c, d])) 110 | 
self.assertTrue(nm.check_perp([a.num, b.num, c.num, d.num])) 111 | 112 | def test_enumerate_congs(self): 113 | g = GraphTest.g 114 | 115 | for a, b, c, d in g.all_congs(): 116 | self.assertTrue(g.check_cong([a, b, c, d])) 117 | self.assertTrue(nm.check_cong([a.num, b.num, c.num, d.num])) 118 | 119 | def test_enumerate_eqangles(self): 120 | g = GraphTest.g 121 | 122 | for a, b, c, d, x, y, z, t in g.all_eqangles_8points(): 123 | self.assertTrue(g.check_eqangle([a, b, c, d, x, y, z, t])) 124 | self.assertTrue( 125 | nm.check_eqangle( 126 | [a.num, b.num, c.num, d.num, x.num, y.num, z.num, t.num] 127 | ) 128 | ) 129 | 130 | def test_enumerate_eqratios(self): 131 | g = GraphTest.g 132 | 133 | for a, b, c, d, x, y, z, t in g.all_eqratios_8points(): 134 | self.assertTrue(g.check_eqratio([a, b, c, d, x, y, z, t])) 135 | self.assertTrue( 136 | nm.check_eqratio( 137 | [a.num, b.num, c.num, d.num, x.num, y.num, z.num, t.num] 138 | ) 139 | ) 140 | 141 | def test_enumerate_cyclics(self): 142 | g = GraphTest.g 143 | 144 | for a, b, c, d, x, y, z, t in g.all_cyclics(): 145 | self.assertTrue(g.check_cyclic([a, b, c, d, x, y, z, t])) 146 | self.assertTrue(nm.check_cyclic([a.num, b.num, c.num, d.num])) 147 | 148 | def test_enumerate_midps(self): 149 | g = GraphTest.g 150 | 151 | for a, b, c in g.all_midps(): 152 | self.assertTrue(g.check_midp([a, b, c])) 153 | self.assertTrue(nm.check_midp([a.num, b.num, c.num])) 154 | 155 | def test_enumerate_circles(self): 156 | g = GraphTest.g 157 | 158 | for a, b, c, d in g.all_circles(): 159 | self.assertTrue(g.check_circle([a, b, c, d])) 160 | self.assertTrue(nm.check_circle([a.num, b.num, c.num, d.num])) 161 | 162 | 163 | if __name__ == '__main__': 164 | absltest.main() 165 | -------------------------------------------------------------------------------- /graph_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities for graph representation. 17 | 18 | Mainly for listing combinations and permutations of elements.
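For example: comb2([1, 2, 3]) -> [(1, 2), (1, 3), (2, 3)] perm2([1, 2]) -> [(1, 2), (2, 1)]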
19 | """ 20 | 21 | from geometry import Point 22 | 23 | 24 | def _cross(elems1, elems2): 25 | for e1 in elems1: 26 | for e2 in elems2: 27 | yield e1, e2 28 | 29 | 30 | def cross(elems1, elems2): 31 | return list(_cross(elems1, elems2)) 32 | 33 | 34 | def _comb2(elems): 35 | if len(elems) < 2: 36 | return 37 | for i, e1 in enumerate(elems[:-1]): 38 | for e2 in elems[i + 1 :]: 39 | yield e1, e2 40 | 41 | 42 | def comb2(elems): 43 | return list(_comb2(elems)) 44 | 45 | 46 | def _comb3(elems): 47 | if len(elems) < 3: 48 | return 49 | for i, e1 in enumerate(elems[:-2]): 50 | for j, e2 in enumerate(elems[i + 1 : -1]): 51 | for e3 in elems[i + j + 2 :]: 52 | yield e1, e2, e3 53 | 54 | 55 | def comb3(elems): 56 | return list(_comb3(elems)) 57 | 58 | 59 | def _comb4(elems): 60 | if len(elems) < 4: 61 | return 62 | for i, e1 in enumerate(elems[:-3]): 63 | for j, e2 in enumerate(elems[i + 1 : -2]): 64 | for e3, e4 in _comb2(elems[i + j + 2 :]): 65 | yield e1, e2, e3, e4 66 | 67 | 68 | def comb4(elems): 69 | return list(_comb4(elems)) 70 | 71 | 72 | def _perm2(elems): 73 | for e1, e2 in comb2(elems): 74 | yield e1, e2 75 | yield e2, e1 76 | 77 | 78 | def perm2(elems): 79 | return list(_perm2(elems)) 80 | 81 | 82 | def _all_4points(l1, l2): 83 | p1s = l1.neighbors(Point) 84 | p2s = l2.neighbors(Point) 85 | for a, b in perm2(p1s): 86 | for c, d in perm2(p2s): 87 | yield a, b, c, d 88 | 89 | 90 | def all_4points(l1, l2): 91 | return list(_all_4points(l1, l2)) 92 | 93 | 94 | def _all_8points(l1, l2, l3, l4): 95 | for a, b, c, d in all_4points(l1, l2): 96 | for e, f, g, h in all_4points(l3, l4): 97 | yield (a, b, c, d, e, f, g, h) 98 | 99 | 100 | def all_8points(l1, l2, l3, l4): 101 | return list(_all_8points(l1, l2, l3, l4)) 102 | 103 | 104 | def _perm3(elems): 105 | for x in elems: 106 | for y in elems: 107 | if y == x: 108 | continue 109 | for z in elems: 110 | if z not in (x, y): 111 | yield x, y, z 112 | 113 | 114 | def perm3(elems): 115 | return list(_perm3(elems)) 116 | 117 | 118 | def _perm4(elems): 119 | for x in elems: 120 | for y in elems: 121 | if y == x: 122 | continue 123 | for z in elems: 124 | if z in (x, y): 125 | continue 126 | for t in elems: 127 | if t not in (x, y, z): 128 | yield x, y, z, t 129 | 130 | 131 | def perm4(elems): 132 | return list(_perm4(elems)) 133 | -------------------------------------------------------------------------------- /graph_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit tests for graph_utils.py.""" 17 | import unittest 18 | 19 | from absl.testing import absltest 20 | import graph_utils as gu 21 | 22 | 23 | class GraphUtilsTest(unittest.TestCase): 24 | 25 | def test_cross(self): 26 | self.assertEqual(gu.cross([], [1]), []) 27 | self.assertEqual(gu.cross([1], []), []) 28 | self.assertEqual(gu.cross([1], [2]), [(1, 2)]) 29 | self.assertEqual(gu.cross([1], [2, 3]), [(1, 2), (1, 3)]) 30 | 31 | e1 = [1, 2, 3] 32 | e2 = [4, 5] 33 | target = [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)] 34 | self.assertEqual(gu.cross(e1, e2), target) 35 | 36 | def test_comb2(self): 37 | self.assertEqual(gu.comb2([]), []) 38 | self.assertEqual(gu.comb2([1]), []) 39 | self.assertEqual(gu.comb2([1, 2]), [(1, 2)]) 40 | self.assertEqual(gu.comb2([1, 2, 3]), [(1, 2), (1, 3), (2, 3)]) 41 | 42 | def test_comb3(self): 43 | self.assertEqual(gu.comb3([]), []) 44 | self.assertEqual(gu.comb3([1]), []) 45 | self.assertEqual(gu.comb3([1, 2]), []) 46 | self.assertEqual(gu.comb3([1, 2, 3]), [(1, 2, 3)]) 47 | self.assertEqual( 48 | gu.comb3([1, 2, 3, 4]), [(1, 2, 3), (1, 2, 4), (1, 3, 4), (2, 3, 4)] 49 | ) 50 | 51 | def test_comb4(self): 52 | self.assertEqual(gu.comb4([]), []) 53 | self.assertEqual(gu.comb4([1]), []) 54 | self.assertEqual(gu.comb4([1, 2]), []) 55 | self.assertEqual(gu.comb4([1, 2, 3]), []) 56 | self.assertEqual(gu.comb4([1, 2, 3, 4]), [(1, 2, 3, 4)]) 57 | self.assertEqual( 58 | gu.comb4([1, 2, 3, 4, 5]), 59 | [(1, 2, 3, 4), (1, 2, 3, 5), (1, 2, 4, 5), (1, 3, 4, 5), (2, 3, 4, 5)], 60 | ) 61 | 62 | def test_perm2(self): 63 | self.assertEqual(gu.perm2([]), []) 64 | self.assertEqual(gu.perm2([1]), []) 65 | self.assertEqual(gu.perm2([1, 2]), [(1, 2), (2, 1)]) 66 | self.assertEqual( 67 | gu.perm2([1, 2, 3]), [(1, 2), (2, 1), (1, 3), (3, 1), (2, 3), (3, 2)] 68 | ) 69 | 70 | def test_perm3(self): 71 | self.assertEqual(gu.perm3([]), []) 72 | self.assertEqual(gu.perm3([1]), []) 73 | self.assertEqual(gu.perm3([1, 2]), []) 74 | self.assertEqual( 75 | gu.perm3([1, 2, 3]), 76 | [(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)], 77 | ) 78 | self.assertEqual( 79 | gu.perm3([1, 2, 3, 4]), 80 | [ 81 | (1, 2, 3), 82 | (1, 2, 4), 83 | (1, 3, 2), 84 | (1, 3, 4), 85 | (1, 4, 2), 86 | (1, 4, 3), 87 | (2, 1, 3), 88 | (2, 1, 4), 89 | (2, 3, 1), 90 | (2, 3, 4), 91 | (2, 4, 1), 92 | (2, 4, 3), 93 | (3, 1, 2), 94 | (3, 1, 4), 95 | (3, 2, 1), 96 | (3, 2, 4), 97 | (3, 4, 1), 98 | (3, 4, 2), 99 | (4, 1, 2), 100 | (4, 1, 3), 101 | (4, 2, 1), 102 | (4, 2, 3), 103 | (4, 3, 1), 104 | (4, 3, 2), 105 | ], 106 | ) 107 | 108 | def test_perm4(self): 109 | self.assertEqual(gu.perm4([]), []) 110 | self.assertEqual(gu.perm4([1]), []) 111 | self.assertEqual(gu.perm4([1, 2]), []) 112 | self.assertEqual(gu.perm4([1, 2, 3]), []) 113 | self.assertEqual( 114 | gu.perm4([1, 2, 3, 4]), 115 | [ 116 | (1, 2, 3, 4), 117 | (1, 2, 4, 3), 118 | (1, 3, 2, 4), 119 | (1, 3, 4, 2), 120 | (1, 4, 2, 3), 121 | (1, 4, 3, 2), # pylint: disable=line-too-long 122 | (2, 1, 3, 4), 123 | (2, 1, 4, 3), 124 | (2, 3, 1, 4), 125 | (2, 3, 4, 1), 126 | (2, 4, 1, 3), 127 | (2, 4, 3, 1), # pylint: disable=line-too-long 128 | (3, 1, 2, 4), 129 | (3, 1, 4, 2), 130 | (3, 2, 1, 4), 131 | (3, 2, 4, 1), 132 | (3, 4, 1, 2), 133 | (3, 4, 2, 1), # pylint: disable=line-too-long 134 | (4, 1, 2, 3), 135 | (4, 1, 3, 2), 136 | (4, 2, 1, 3), 137 | (4, 2, 3, 1), 138 | (4, 3, 1, 2), 139 | (4, 3, 2, 1), 140 | ], # pylint: disable=line-too-long 141 | )
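# Size sanity check: over n distinct elements, perm2, perm3 and perm4 yield n*(n-1), n*(n-1)*(n-2) and n*(n-1)*(n-2)*(n-3) ordered tuples respectively, matching the 6 pairs and the 24 triples and quadruples enumerated above.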
142 | 143 | 144 | if __name__ == '__main__': 145 | absltest.main() 146 | -------------------------------------------------------------------------------- /imo_ag_30.txt: -------------------------------------------------------------------------------- 1 | translated_imo_2000_p1 2 | a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m = on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b, on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q 3 | translated_imo_2000_p6 4 | a b c = triangle a b c; h = orthocenter h a b c; t1 t2 t3 i = incenter2 t1 t2 t3 i a b c; h1 = foot h1 a b c; h2 = foot h2 b c a; h3 = foot h3 c a b; x1 = reflect x1 h1 t1 t2; x2 = reflect x2 h2 t1 t2; y2 = reflect y2 h2 t2 t3; y3 = reflect y3 h3 t2 t3; z = on_line z x1 x2, on_line z y2 y3 ? cong i z i t1 5 | translated_imo_2002_p2a 6 | b c = segment b c; o = midpoint o b c; a = on_circle a o b; d = on_circle d o b, on_bline d a b; e = on_bline e o a, on_circle e o b; f = on_bline f o a, on_circle f o b; j = on_pline j o a d, on_line j a c ? eqangle e c e j e j e f 7 | translated_imo_2002_p2b 8 | b c = segment b c; o = midpoint o b c; a = on_circle a o b; d = on_circle d o b, on_bline d a b; e = on_bline e o a, on_circle e o b; f = on_bline f o a, on_circle f o b; j = on_pline j o a d, on_line j a c ? eqangle c e c j c j c f 9 | translated_imo_2003_p4 10 | a b c = triangle a b c; o = circle o a b c; b1 = on_circle b1 o a, on_bline b1 c a; d1 = on_circle d1 o a, on_bline d1 c a; x = on_line x b b1, on_line x a c; d = on_line d d1 x, on_circle d o a; p = foot p d b c; q = foot q d c a; r = foot r d a b ? cong p q q r 11 | translated_imo_2004_p1 12 | a b c = triangle a b c; o = midpoint o b c; m = on_circle m o b, on_line m a b; n = on_circle n o b, on_line n a c; r = angle_bisector r b a c, angle_bisector r m o n; o1 = circle o1 b m r; o2 = circle o2 c n r; p = on_circle p o1 r, on_circle p o2 r ? coll p b c 13 | translated_imo_2004_p5 14 | a b c = triangle a b c; o = circle o a b c; d = on_circle d o a; p = on_aline p b c a b d, on_aline p d c a d b ? cong a p c p 15 | translated_imo_2005_p5 16 | a b c = triangle a b c; d = eqdistance d a b c; e = on_line e b c; f = on_line f a d, eqdistance f d e b; p = on_line p a c, on_line p b d; q = on_line q e f, on_line q b d; r = on_line r e f, on_line r a c; o1 = circle o1 a p d; o2 = circle o2 b p c; m = on_circle m o1 p, on_circle m o2 p ? cyclic p q r m 17 | translated_imo_2007_p4 18 | a b c = triangle a b c; o = circle o a b c; r = on_circle r o a, on_bline r a b; l = midpoint l c a; k = midpoint k c b; p = on_line p o k, on_line p c r; q = on_line q o l, on_line q c r; l1 = foot l1 l c r; k1 = foot k1 k c r ? eqratio k k1 l l1 r q r p 19 | translated_imo_2008_p1a 20 | a b c = triangle a b c; h = orthocenter h a b c; d = midpoint d b c; e = midpoint e a c; f = midpoint f a b; a1 = on_circle a1 d h, on_line a1 b c; a2 = on_circle a2 d h, on_line a2 b c; b1 = on_circle b1 e h, on_line b1 c a; b2 = on_circle b2 e h, on_line b2 c a; c1 = on_circle c1 f h, on_line c1 a b; c2 = on_circle c2 f h, on_line c2 a b ? 
cyclic c1 c2 b1 b2 21 | translated_imo_2008_p1b 22 | a b c = triangle a b c; h = orthocenter h a b c; d = midpoint d b c; e = midpoint e a c; f = midpoint f a b; a1 = on_circle a1 d h, on_line a1 b c; a2 = on_circle a2 d h, on_line a2 b c; b1 = on_circle b1 e h, on_line b1 c a; b2 = on_circle b2 e h, on_line b2 c a; c1 = on_circle c1 f h, on_line c1 a b; c2 = on_circle c2 f h, on_line c2 a b ? cyclic c1 c2 b1 a1 23 | translated_imo_2008_p6 24 | x@4.96_-0.13 y@-1.0068968328888160_-1.2534881080682770 z@-2.8402847238575120_-4.9117762734006830 = triangle x y z; o = circle o x y z; w@6.9090049230038776_-1.3884003936987552 = on_circle w o x; a = on_tline a z o z, on_tline a x o x; b = on_tline b z o z, on_tline b w o w; c = on_tline c y o y, on_tline c w o w; d = on_tline d x o x, on_tline d y o y; i1 = incenter i1 a b c; i2 = incenter i2 a c d; f1 = foot f1 i1 a c; f2 = foot f2 i2 a c; q t p s = cc_tangent q t p s i1 f1 i2 f2; k = on_line k q t, on_line k p s ? cong o k o x 25 | translated_imo_2009_p2 26 | m l k = triangle m l k; w = circle w m l k; q = on_tline q m w m; p = mirror p q m; b = mirror b p k; c = mirror c q l; a = on_line a b q, on_line a c p; o = circle o a b c ? cong o p o q 27 | translated_imo_2010_p2 28 | a b c = triangle a b c; o = circle o a b c; i = incenter i a b c; d = on_line d a i, on_circle d o a; f = on_line f b c; e = on_aline e a c b a f, on_circle e o a; g = midpoint g i f; k = on_line k d g, on_line k e i ? cong o a o k 29 | translated_imo_2010_p4 30 | s c p = iso_triangle s c p; o = on_tline o c s c; a = on_circle a o c; b = on_circle b o c, on_line b s a; m = on_line m c p, on_circle m o c; l = on_line l b p, on_circle l o c; k = on_line k a p, on_circle k o c ? cong m k m l 31 | translated_imo_2011_p6 32 | a b c = triangle a b c; o = circle o a b c; p = on_circle p o a; q = on_tline q p o p; pa = reflect pa p b c; pb = reflect pb p c a; pc = reflect pc p a b; qa = reflect qa q b c; qb = reflect qb q c a; qc = reflect qc q a b; a1 = on_line a1 pb qb, on_line a1 pc qc; b1 = on_line b1 pa qa, on_line b1 pc qc; c1 = on_line c1 pa qa, on_line c1 pb qb; o1 = circle o1 a1 b1 c1; x = on_circle x o a, on_circle x o1 a1 ? coll x o o1 33 | translated_imo_2012_p1 34 | a b c = triangle a b c; m l k j = excenter2 m l k j a b c; f = on_line f m l, on_line f b j; g = on_line g m k, on_line g c j; s = on_line s f a, on_line s b c; t = on_line t g a, on_line t c b ? cong m s m t 35 | translated_imo_2012_p5 36 | c a b = r_triangle c a b; d = foot d c a b; x = on_line x c d; k = on_line k a x, on_circle k b c; l = on_line l b x, on_circle l a c; m = on_line m a l, on_line m b k ? cong m k m l 37 | translated_imo_2013_p4 38 | a b c = triangle a b c; h = orthocenter h a b c; m = on_line m h b, on_line m a c; n = on_line n h c, on_line n a b; w = on_line w b c; o1 = circle o1 b n w; o2 = circle o2 c m w; x = on_line x o1 w, on_circle x o1 w; y = on_line y o2 w, on_circle y o2 w ? coll x h y 39 | translated_imo_2014_p4 40 | a b c = triangle a b c; p = on_line p b c, on_aline p a b b c a; q = on_line q b c, on_aline q a c c b a; m = mirror m a p; n = mirror n a q; x = on_line x b m, on_line x c n; o = circle o a b c ? cong o x o a 41 | translated_imo_2015_p3 42 | a b c = triangle a b c; h = orthocenter h a b c; f = on_line f h a, on_line f b c; m = midpoint m b c; o = circle o a b c; q = on_dia q a h, on_circle q o a; k = on_dia k h q, on_circle k o a; o1 = circle o1 k q h; o2 = circle o2 f k m ? 
coll o1 o2 k 43 | translated_imo_2015_p4 44 | a b c = triangle a b c; o = circle o a b c; d = on_line d b c; e = on_line e b c, on_circle e a d; f = on_circle f o a, on_circle f a d; g = on_circle g o a, on_circle g a d; o1 = circle o1 f b d; o2 = circle o2 g c e; k = on_circle k o1 b, on_line k a b; l = on_circle l o2 c, on_line l a c; x = on_line x f k, on_line x l g ? coll x o a 45 | translated_imo_2016_p1 46 | a b z = triangle a b z; f = angle_bisector f b a z, on_bline f a b; c = on_tline c b f b, on_line c a f; d = on_line d a z, on_bline d a c; e = angle_mirror e c a d, on_bline e a d; m = midpoint m c f; x = parallelogram e a m x; y = on_line y f x, on_line y e m ? coll y b d 47 | translated_imo_2017_p4 48 | r s = segment r s; t = mirror t r s; o = on_bline o r s; j = on_circle j o s; o1 = circle o1 j s t; a = on_tline a r o r, on_circle a o1 s; b = on_tline b r o r, on_circle b o1 s; k = on_line k j a, on_circle k o s ? perp k t o1 t 49 | translated_imo_2018_p1 50 | a b c = triangle a b c; o = circle o a b c; d = on_line d a b; e = on_line e a c, on_circle e a d; f = on_bline f b d, on_circle f o a; g = on_bline g e c, on_circle g o a ? para d e f g 51 | translated_imo_2019_p2 52 | a b c = triangle; a1 = on_line b c; b1 = on_line a c; p = on_line a a1; q = on_line b b1, on_pline p a b; p1 = on_line p b1, eqangle3 p c a b c; q1 = on_line q a1, eqangle3 c q b c a ? cyclic p q p1 q1 53 | translated_imo_2019_p6 54 | a b c = triangle a b c; d e f i = incenter2 d e f i a b c; r = on_tline r d e f, on_circle r i d; p = on_line p r a, on_circle p i d; o1 = circle o1 p c e; o2 = circle o2 p b f; q = on_circle q o1 p, on_circle q o2 p; t = on_line t p q, on_line t i d ? perp a t a i 55 | translated_imo_2020_p1 56 | p a b = triangle p a b; x = angle_bisector p b a; y = angle_bisector p a b; z = on_aline z a p a b x; t = on_aline t p a p a z; d = on_aline d p t p b a, on_line a z; u = on_aline u b p b a y; v = on_aline v p b p b u; c = on_aline c p v p a b, on_line b u; o = angle_bisector a d p, angle_bisector p c b ? cong o a o b 57 | translated_imo_2021_p3 58 | a b c = triangle; d = angle_bisector b a c; e = on_aline d a d c b, on_line a c; f = on_aline d a d b c, on_line a b; x = on_bline b c, on_line a c; o1 = circle a d c; o2 = circle e x d; y = on_line e f, on_line b c ? coll o1 o2 y 59 | translated_imo_2022_p4 60 | b c = segment; d = free; e = eqdistance d b c; t = on_bline b d, on_bline c e; a = eqangle2 b t e; p = on_line a b, on_line c d; q = on_line a b, on_line c t; r = on_line a e, on_line c d; s = on_line a e, on_line d t ? cyclic p q r s 61 | -------------------------------------------------------------------------------- /lm_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Wrapper for language modeling inference implemented in Meliad.""" 17 | from typing import Any, Dict 18 | 19 | import jax 20 | import models # pylint: disable=unused-import 21 | import t5.data 22 | from transformer import inference_utils 23 | 24 | 25 | np = jax.numpy 26 | 27 | 28 | Trainer = inference_utils.Trainer 29 | 30 | MetricsOutput = Dict[str, Any] # Metrics output by model. 31 | 32 | 33 | parse_gin_configuration = inference_utils.parse_gin_configuration 34 | 35 | 36 | class LanguageModelInference: 37 | """Meliad wrapper for LM inference.""" 38 | 39 | def __init__(self, vocab_path: str, load_dir: str, mode='beam_search'): 40 | self.vocab = t5.data.SentencePieceVocabulary(vocab_path) 41 | 42 | # This task won't be pulling from a dataset. 43 | def null_iter_fn() -> None: 44 | return None 45 | 46 | process_summaries_f = inference_utils.models.process_summaries_function( 47 | self.vocab 48 | ) 49 | 50 | trainer = inference_utils.training_loop.Trainer( 51 | get_training_dataset_iterator=null_iter_fn, 52 | get_test_dataset_iterator=None, 53 | pretty_print_input_function=None, 54 | process_summaries_function=process_summaries_f, 55 | load_dir=load_dir, 56 | workdir='', # Don't log or save checkpoints. 57 | replicate_mode=False, 58 | ) # Run on a single device at batch size 1. 59 | self.trainer = trainer 60 | 61 | # Create and initialize the model. 62 | (tstate, _, imodel, prngs) = trainer.initialize_model() 63 | self.imodel = imodel 64 | self.batch_size = imodel.task_config.batch_size 65 | 66 | self.n = imodel.num_heads 67 | self.h = imodel.head_size 68 | 69 | # Create an inference task. 70 | writers = {} 71 | self.task = trainer.create_training_task(mode, imodel, prngs, writers) # pylint: disable=too-many-function-args 72 | 73 | # Register any additional actions. 74 | # Actions are cleared first for use with colab. 75 | inference_utils.training_loop.clear_interstep_callbacks() 76 | inference_utils.training_loop.register_interstep_callbacks() 77 | self.tstate = tstate 78 | 79 | # some default parameters. 80 | eos = [0] * 1024 81 | for idx in self.encode_list(['.', ';']): 82 | eos[idx] = 1 83 | 84 | self.eos = np.array(eos, dtype=np.bfloat16) 85 | self.mask = jax.numpy.ones([1024], dtype=np.bfloat16) 86 | 87 | def decode(self, ids: list[int]) -> str: 88 | return self.vocab.decode(ids) 89 | 90 | def decode_list(self, tokens: list[int]) -> list[str]: 91 | return [self.decode([tok]) for tok in tokens] 92 | 93 | def encode(self, inputs_str: str) -> list[int]: 94 | return self.vocab.encode(inputs_str) 95 | 96 | def encode_list(self, inputs_strs: list[str]) -> list[int]: 97 | result = [self.vocab.encode(x) for x in inputs_strs] 98 | assert all([len(x) == 1 for x in result]), [ 99 | self.decode(x) for x in result if len(x) != 1 100 | ] 101 | return [x[0] for x in result] 102 | 103 | def call( 104 | self, 105 | inputs: np.ndarray, 106 | dstate: tuple[dict[str, np.ndarray], ...] 
= None, 107 | eos: np.ndarray = None, 108 | mask: np.ndarray = None, 109 | ) -> MetricsOutput: 110 | """Call the meliad model.""" 111 | batch_size, length = inputs.shape 112 | inputs = jax.numpy.pad(inputs, [(0, 0), (0, 1024 - length)]) 113 | 114 | if eos is None: 115 | eos = self.eos 116 | if mask is None: 117 | mask = self.mask 118 | 119 | x = {'targets': inputs, 'length': length, 'eos': eos, 'mask': mask} 120 | 121 | if dstate is not None: 122 | x['start_of_sequence'] = jax.numpy.array([False] * batch_size) 123 | else: 124 | dstate = tuple( 125 | [{ # this dummy value will never be used. 126 | 'current_index': np.array([0] * batch_size, dtype=np.int32), 127 | 'keys': np.zeros( 128 | (batch_size, 2048, self.n, self.h), dtype=np.bfloat16 129 | ), 130 | 'values': np.zeros( 131 | (batch_size, 2048, self.n, self.h), dtype=np.bfloat16 132 | ), 133 | 'recurrent_kvq': None, 134 | 'relative_position_bias': np.zeros( 135 | (batch_size, self.n, 1, 1024), dtype=np.bfloat16 136 | ), 137 | }] 138 | * 12 139 | ) 140 | x['start_of_sequence'] = jax.numpy.array([True] * batch_size) 141 | 142 | x['dstate'] = dstate 143 | _, metrics_np = self.task.run_step(self.tstate, x, 0) 144 | return metrics_np 145 | 146 | def beam_decode( 147 | self, 148 | inputs: str, 149 | eos_tokens: np.ndarray = None, 150 | mask_tokens: np.ndarray = None, 151 | dstate: dict[str, np.ndarray] = None, 152 | ) -> MetricsOutput: 153 | """Beam search.""" 154 | inputs = jax.numpy.array([self.vocab.encode(inputs)] * self.batch_size) 155 | 156 | eos = self.eos 157 | if eos_tokens is not None: 158 | eos_ids = self.encode_list(eos_tokens) 159 | eos = np.array( 160 | [1 if idx in eos_ids else 0 for idx in range(1024)], dtype=np.bfloat16 161 | ).reshape((1, 1, 1024)) 162 | 163 | mask = self.mask 164 | if mask_tokens is not None: 165 | mask_ids = self.encode_list(mask_tokens) 166 | mask = np.array( 167 | [0 if idx in mask_ids else 1 for idx in range(1024)], 168 | dtype=np.bfloat16, 169 | ).reshape((1, 1, 1024)) 170 | 171 | metrics_np = self.call(inputs, dstate=dstate, eos=eos, mask=mask) 172 | 173 | finished_seqs = metrics_np['finished_seqs'] 174 | finished_scores = metrics_np['finished_scores'] 175 | 176 | seqs = [] 177 | scores = [] 178 | for seq, score in zip(finished_seqs, finished_scores): 179 | seq = self.decode(seq[1:]) 180 | seqs.append(seq) 181 | scores.append(score) 182 | 183 | return { 184 | 'finished_seqs': finished_seqs, 185 | 'finished_scores': finished_scores, 186 | 'seqs_str': seqs, 187 | 'scores': scores, 188 | 'dstate': metrics_np['dstate'], 189 | } 190 | -------------------------------------------------------------------------------- /lm_inference_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit tests for lm_inference.py.""" 17 | import os 18 | import unittest 19 | 20 | from absl import flags 21 | from absl.testing import absltest 22 | import lm_inference as lm 23 | 24 | 25 | _DATA_PATH = flags.DEFINE_string('data_path', '', 'path to ckpt and vocab.') 26 | _MELIAD_PATH = flags.DEFINE_string( 27 | 'meliad_path', '', 'path to meliad repository.' 28 | ) # pylint: disable=line-too-long 29 | 30 | 31 | class LmInferenceTest(unittest.TestCase): 32 | 33 | @classmethod 34 | def setUpClass(cls): 35 | super().setUpClass() 36 | gin_file = [ 37 | 'base_htrans.gin', 38 | 'size/medium_150M.gin', 39 | 'options/positions_t5.gin', 40 | 'options/lr_cosine_decay.gin', 41 | 'options/seq_1024_nocache.gin', 42 | 'geometry_150M_generate.gin', 43 | ] 44 | 45 | gin_param = [ 46 | 'DecoderOnlyLanguageModelGenerate.output_token_losses=True', 47 | 'TransformerTaskConfig.batch_size=2', 48 | 'TransformerTaskConfig.sequence_length=128', 49 | 'Trainer.restore_state_variables=False', 50 | ] 51 | 52 | gin_search_paths = [ 53 | os.path.join(_MELIAD_PATH.value, 'transformer/configs'), 54 | os.getcwd(), 55 | ] 56 | 57 | vocab_path = os.path.join(_DATA_PATH.value, 'geometry.757.model') 58 | 59 | lm.parse_gin_configuration(gin_file, gin_param, gin_paths=gin_search_paths) 60 | 61 | cls.loaded_lm = lm.LanguageModelInference( 62 | vocab_path, _DATA_PATH.value, mode='beam_search' 63 | ) 64 | 65 | def test_lm_decode(self): 66 | outputs = LmInferenceTest.loaded_lm.beam_decode( 67 | '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c' 68 | ' {F1} x00', 69 | eos_tokens=[';'], 70 | ) 71 | self.assertEqual( 72 | outputs['seqs_str'], 73 | ['e : D a b c e 02 D a c b e 03 ;', 'e : C a c e 02 C b d e 03 ;'], 74 | ) 75 | 76 | def test_lm_score_may_fail_numerically_for_external_meliad(self): 77 | outputs = LmInferenceTest.loaded_lm.beam_decode( 78 | '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c' 79 | ' {F1} x00', 80 | eos_tokens=[';'], 81 | ) 82 | self.assertEqual( 83 | outputs['scores'], 84 | [-1.18607294559478759765625, -1.10228693485260009765625], 85 | ) 86 | 87 | 88 | if __name__ == '__main__': 89 | absltest.main() 90 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Transformer language model generate mode.""" 17 | 18 | from typing import Any, Tuple 19 | import beam_search 20 | import decoder_stack 21 | import gin 22 | import jax 23 | import jax.numpy as jnp 24 | from transformer import models 25 | 26 | 27 | @gin.configurable 28 | class DecoderOnlyLanguageModelGenerate(models.DecoderOnlyLanguageModel): 29 | """Decoder only language modeling in inference mode.""" 30 | 31 | decoder_factory = decoder_stack.DecoderStackGenerate 32 | 33 | num_heads: int = gin.REQUIRED 34 | head_size: int = gin.REQUIRED 35 | 36 | def get_fake_input(self) -> dict[str, Any]: 37 | fake_input_dict = super().get_fake_input() 38 | b = self.task_config.batch_size 39 | n = self.num_heads 40 | h = self.head_size 41 | fake_input_dict.update({ 42 | 'dstate': tuple( 43 | [{ 44 | 'current_index': jnp.array([0] * b, dtype=jnp.int32), 45 | 'keys': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16), 46 | 'values': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16), 47 | 'recurrent_kvq': None, 48 | 'relative_position_bias': jnp.zeros( 49 | (b, n, 1, 1024), dtype=jnp.bfloat16 50 | ), 51 | }] 52 | * 12 53 | ), 54 | 'eos': jnp.zeros([1024], dtype=jnp.bfloat16), 55 | 'mask': jnp.ones([1024], dtype=jnp.bfloat16), 56 | 'length': 1, 57 | 'temperature': 1.0, 58 | }) 59 | return fake_input_dict 60 | 61 | def __call__(self, inputs: ...) -> tuple[Any, dict[str, Any]]: 62 | # Make sure this code is not used on untested cases. 63 | if self.mode not in ['init', 'beam_search']: 64 | raise ValueError(f'{type(self)} cannot do mode {self.mode}') 65 | if self.decoder.supports_generate(): 66 | raise ValueError(f'{type(self)}.decoder cannot supports_generate()') 67 | 68 | self.decoder( 69 | input_tokens=inputs['targets'][:, 0:1], 70 | target_tokens=None, 71 | start_of_sequence=inputs['start_of_sequence'], 72 | ) 73 | 74 | b = inputs['targets'].shape[0] 75 | no_start_of_seq = jnp.array([False] * b, dtype=jnp.bool_) 76 | 77 | # This fn is used in both beam_search or topk_sampling. 78 | def tokens_to_logits_fn( 79 | input_token: jnp.ndarray, dstate: tuple[dict[str, jnp.ndarray], ...] 80 | ) -> tuple[jnp.ndarray, tuple[dict[str, jnp.ndarray], ...]]: 81 | (logits, dstate, _) = self.decoder( 82 | input_tokens=input_token, 83 | target_tokens=None, 84 | start_of_sequence=no_start_of_seq, 85 | decoder_state=dstate, 86 | ) 87 | return logits[:, -1, :], dstate 88 | 89 | last_token = jax.lax.dynamic_slice_in_dim( 90 | inputs['targets'], inputs['length'] - 1, 1, axis=1 91 | ) 92 | 93 | # last token is used to seed beam_search 94 | inputs['targets'] = inputs['targets'][:, 0:-1] 95 | dstate = jax.lax.cond( 96 | inputs['start_of_sequence'][0], 97 | lambda: self.generate(inputs)[0], 98 | lambda: inputs['dstate'], 99 | ) 100 | 101 | # Then we run beam search, init with last_token & dstate. 102 | finished_seqs, finished_scores, dstate = beam_search.beam_search_flat( 103 | last_token, 104 | dstate, 105 | tokens_to_logits_fn, 106 | max_decode_len=512, 107 | eos=inputs['eos'].reshape((1, 1, -1)), 108 | mask=inputs['mask'].reshape((1, 1, -1)), 109 | ) 110 | 111 | return 0.0, { 112 | 'finished_seqs': finished_seqs, 113 | 'finished_scores': finished_scores, 114 | 'dstate': dstate, 115 | } 116 | 117 | def generate( 118 | self, inputs: ... 119 | ) -> tuple[tuple[dict[str, jnp.ndarray, ...], ...], jnp.ndarray]: 120 | """Generate an output sequence. 121 | 122 | Args: 123 | inputs: the same as argument to _call_. 
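Expected keys: 'targets', 'start_of_sequence', 'dstate' and 'length'; an optional 'temperature' is used to scale the logits.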
124 | 125 | Returns: 126 | An array of generated tokens of shape (batch_size, sequence_length). 127 | """ 128 | input_tokens = inputs['targets'] # [b,seq_len] 129 | start_of_sequence = inputs['start_of_sequence'] # [b] 130 | target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)]) 131 | batch_size = target_tokens.shape[0] 132 | 133 | # Assuming all sequences start at the same time. 134 | start0 = inputs['start_of_sequence'][0] 135 | dstate = jax.lax.cond( 136 | start0, 137 | lambda: self.decoder.init_decoder_state_vanilla( # pylint: disable=g-long-lambda 138 | 1024, start_of_sequence 139 | ), 140 | lambda: inputs['dstate'], 141 | ) 142 | 143 | first_token = input_tokens[:, 0:1] 144 | no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_) 145 | temperature = 1 146 | if 'temperature' in inputs: 147 | temperature = inputs['temperature'] 148 | 149 | num_steps = inputs['length'] 150 | if self.mode == 'beam_search': 151 | num_steps -= 1 152 | 153 | def cond_fn(scan_state) -> jnp.bool_: 154 | _, _, i, _ = scan_state 155 | return i < num_steps 156 | 157 | def loop_fn(scan_state: Any) -> Tuple[Any, Any, Any, Any]: 158 | (dstate, input_token, i, _) = scan_state 159 | 160 | (logits, dstate, _) = self.decoder( 161 | input_tokens=input_token, 162 | target_tokens=None, 163 | start_of_sequence=no_start_of_seq, 164 | decoder_state=dstate, 165 | ) 166 | 167 | logits = logits / temperature 168 | output_token = jax.lax.dynamic_slice_in_dim(target_tokens, i, 1, axis=1) 169 | 170 | return (dstate, output_token, i + 1, logits) 171 | 172 | # Scan over the sequence length. 173 | dummy_logits = jnp.zeros((batch_size, 1, 1024)) 174 | initial_scan_state = (dstate, first_token, 0, dummy_logits) 175 | dstate, _, _, logits = jax.lax.while_loop( 176 | cond_fn, loop_fn, initial_scan_state 177 | ) 178 | return dstate, logits 179 | -------------------------------------------------------------------------------- /numericals_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit testing for the geometry numericals code.""" 17 | 18 | import unittest 19 | 20 | from absl.testing import absltest 21 | import numericals as nm 22 | 23 | np = nm.np 24 | 25 | unif = nm.unif 26 | Point = nm.Point 27 | Line = nm.Line 28 | Circle = nm.Circle 29 | HalfLine = nm.HalfLine 30 | 31 | line_circle_intersection = nm.line_circle_intersection 32 | line_line_intersection = nm.line_line_intersection 33 | 34 | check_coll = nm.check_coll 35 | check_eqangle = nm.check_eqangle 36 | 37 | random_points = nm.random_points 38 | ang_between = nm.ang_between 39 | head_from = nm.head_from 40 | 41 | 42 | class NumericalTest(unittest.TestCase): 43 | 44 | def test_sketch_ieq_triangle(self): 45 | a, b, c = nm.sketch_ieq_triangle([]) 46 | self.assertAlmostEqual(a.distance(b), b.distance(c)) 47 | self.assertAlmostEqual(c.distance(a), b.distance(c)) 48 | 49 | def test_sketch_2l1c(self): 50 | p = nm.Point(0.0, 0.0) 51 | pi = np.pi 52 | anga = unif(-0.4 * pi, 0.4 * pi) 53 | a = Point(np.cos(anga), np.sin(anga)) 54 | angb = unif(0.6 * pi, 1.4 * pi) 55 | b = Point(np.cos(angb), np.sin(angb)) 56 | 57 | angc = unif(anga + 0.05 * pi, angb - 0.05 * pi) 58 | c = Point(np.cos(angc), np.sin(angc)) * unif(0.2, 0.8) 59 | 60 | x, y, z, i = nm.sketch_2l1c([a, b, c, p]) 61 | self.assertTrue(check_coll([x, c, a])) 62 | self.assertTrue(check_coll([y, c, b])) 63 | self.assertAlmostEqual(z.distance(p), 1.0) 64 | self.assertTrue(check_coll([p, i, z])) 65 | self.assertTrue(Line(i, x).is_perp(Line(c, a))) 66 | self.assertTrue(Line(i, y).is_perp(Line(c, b))) 67 | self.assertAlmostEqual(i.distance(x), i.distance(y)) 68 | self.assertAlmostEqual(i.distance(x), i.distance(z)) 69 | 70 | def test_sketch_3peq(self): 71 | a, b, c = random_points(3) 72 | x, y, z = nm.sketch_3peq([a, b, c]) 73 | 74 | self.assertTrue(check_coll([a, b, x])) 75 | self.assertTrue(check_coll([a, c, y])) 76 | self.assertTrue(check_coll([b, c, z])) 77 | self.assertTrue(check_coll([x, y, z])) 78 | self.assertAlmostEqual(z.distance(x), z.distance(y)) 79 | 80 | def test_sketch_aline(self): 81 | a, b, c, d, e = random_points(5) 82 | ex = nm.sketch_aline([a, b, c, d, e]) 83 | self.assertIsInstance(ex, HalfLine) 84 | self.assertEqual(ex.tail, e) 85 | x = ex.head 86 | self.assertAlmostEqual(ang_between(b, a, c), ang_between(e, d, x)) 87 | 88 | def test_sketch_amirror(self): 89 | a, b, c = random_points(3) 90 | bx = nm.sketch_amirror([a, b, c]) 91 | self.assertIsInstance(bx, HalfLine) 92 | assert bx.tail == b 93 | x = bx.head 94 | 95 | ang1 = ang_between(b, a, c) 96 | ang2 = ang_between(b, c, x) 97 | self.assertAlmostEqual(ang1, ang2) 98 | 99 | def test_sketch_bisect(self): 100 | a, b, c = random_points(3) 101 | line = nm.sketch_bisect([a, b, c]) 102 | self.assertAlmostEqual(b.distance(line), 0.0) 103 | 104 | l = a.perpendicular_line(line) 105 | x = line_line_intersection(l, Line(b, c)) 106 | self.assertAlmostEqual(a.distance(line), x.distance(line)) 107 | 108 | d, _ = line_circle_intersection(line, Circle(b, radius=1)) 109 | ang1 = ang_between(b, a, d) 110 | ang2 = ang_between(b, d, c) 111 | self.assertAlmostEqual(ang1, ang2) 112 | 113 | def test_sketch_bline(self): 114 | a, b = random_points(2) 115 | l = nm.sketch_bline([a, b]) 116 | self.assertTrue(Line(a, b).is_perp(l)) 117 | self.assertAlmostEqual(a.distance(l), b.distance(l)) 118 | 119 | def test_sketch_cc_tangent(self): 120 | o = Point(0.0, 0.0) 121 | w = Point(1.0, 0.0) 122 | 123 | ra = unif(0.0, 0.6) 124 | rb 
= unif(0.4, 1.0) 125 | 126 | a = unif(0.0, np.pi) 127 | b = unif(0.0, np.pi) 128 | 129 | a = o + ra * Point(np.cos(a), np.sin(a)) 130 | b = w + rb * Point(np.sin(b), np.cos(b)) 131 | 132 | x, y, z, t = nm.sketch_cc_tangent([o, a, w, b]) 133 | xy = Line(x, y) 134 | zt = Line(z, t) 135 | self.assertAlmostEqual(o.distance(xy), o.distance(a)) 136 | self.assertAlmostEqual(o.distance(zt), o.distance(a)) 137 | self.assertAlmostEqual(w.distance(xy), w.distance(b)) 138 | self.assertAlmostEqual(w.distance(zt), w.distance(b)) 139 | 140 | def test_sketch_circle(self): 141 | a, b, c = random_points(3) 142 | circle = nm.sketch_circle([a, b, c]) 143 | self.assertAlmostEqual(circle.center.distance(a), 0.0) 144 | self.assertAlmostEqual(circle.radius, b.distance(c)) 145 | 146 | def test_sketch_e5128(self): 147 | b = Point(0.0, 0.0) 148 | c = Point(0.0, 1.0) 149 | ang = unif(-np.pi / 2, 3 * np.pi / 2) 150 | d = head_from(c, ang, 1.0) 151 | a = Point(unif(0.5, 2.0), 0.0) 152 | 153 | e, g = nm.sketch_e5128([a, b, c, d]) 154 | ang1 = ang_between(a, b, d) 155 | ang2 = ang_between(e, a, g) 156 | self.assertAlmostEqual(ang1, ang2) 157 | 158 | def test_sketch_eq_quadrangle(self): 159 | a, b, c, d = nm.sketch_eq_quadrangle([]) 160 | self.assertAlmostEqual(a.distance(d), c.distance(b)) 161 | ac = Line(a, c) 162 | assert ac.diff_side(b, d), (ac(b), ac(d)) 163 | bd = Line(b, d) 164 | assert bd.diff_side(a, c), (bd(a), bd(c)) 165 | 166 | def test_sketch_eq_trapezoid(self): 167 | a, b, c, d = nm.sketch_eq_trapezoid([]) 168 | assert Line(a, b).is_parallel(Line(c, d)) 169 | self.assertAlmostEqual(a.distance(d), b.distance(c)) 170 | 171 | def test_sketch_eqangle3(self): 172 | points = random_points(5) 173 | x = nm.sketch_eqangle3(points).sample_within(points)[0] 174 | a, b, d, e, f = points 175 | self.assertTrue(check_eqangle([x, a, x, b, d, e, d, f])) 176 | 177 | def test_sketch_eqangle2(self): 178 | a, b, c = random_points(3) 179 | x = nm.sketch_eqangle2([a, b, c]) 180 | ang1 = ang_between(a, b, x) 181 | ang2 = ang_between(c, x, b) 182 | self.assertAlmostEqual(ang1, ang2) 183 | 184 | def test_sketch_edia_quadrangle(self): 185 | a, b, c, d = nm.sketch_eqdia_quadrangle([]) 186 | assert Line(a, c).diff_side(b, d) 187 | assert Line(b, d).diff_side(a, c) 188 | self.assertAlmostEqual(a.distance(c), b.distance(d)) 189 | 190 | def test_sketch_isos(self): 191 | a, b, c = nm.sketch_isos([]) 192 | self.assertAlmostEqual(a.distance(b), a.distance(c)) 193 | self.assertAlmostEqual(ang_between(b, a, c), ang_between(c, b, a)) 194 | 195 | def test_sketch_quadrange(self): 196 | a, b, c, d = nm.sketch_quadrangle([]) 197 | self.assertTrue(Line(a, c).diff_side(b, d)) 198 | self.assertTrue(Line(b, d).diff_side(a, c)) 199 | 200 | def test_sketch_r_trapezoid(self): 201 | a, b, c, d = nm.sketch_r_trapezoid([]) 202 | self.assertTrue(Line(a, b).is_perp(Line(a, d))) 203 | self.assertTrue(Line(a, b).is_parallel(Line(c, d))) 204 | self.assertTrue(Line(a, c).diff_side(b, d)) 205 | self.assertTrue(Line(b, d).diff_side(a, c)) 206 | 207 | def test_sketch_r_triangle(self): 208 | a, b, c = nm.sketch_r_triangle([]) 209 | self.assertTrue(Line(a, b).is_perp(Line(a, c))) 210 | 211 | def test_sketch_rectangle(self): 212 | a, b, c, d = nm.sketch_rectangle([]) 213 | self.assertTrue(Line(a, b).is_perp(Line(b, c))) 214 | self.assertTrue(Line(b, c).is_perp(Line(c, d))) 215 | self.assertTrue(Line(c, d).is_perp(Line(d, a))) 216 | 217 | def test_sketch_reflect(self): 218 | a, b, c = random_points(3) 219 | x = nm.sketch_reflect([a, b, c]) 220 | self.assertTrue(Line(a, 
x).is_perp(Line(b, c))) 221 | self.assertAlmostEqual(x.distance(Line(b, c)), a.distance(Line(b, c))) 222 | 223 | def test_sketch_risos(self): 224 | a, b, c = nm.sketch_risos([]) 225 | self.assertAlmostEqual(a.distance(b), a.distance(c)) 226 | self.assertTrue(Line(a, b).is_perp(Line(a, c))) 227 | 228 | def test_sketch_rotaten90(self): 229 | a, b = random_points(2) 230 | x = nm.sketch_rotaten90([a, b]) 231 | self.assertAlmostEqual(a.distance(x), a.distance(b)) 232 | self.assertTrue(Line(a, x).is_perp(Line(a, b))) 233 | d = Point(0.0, 0.0) 234 | e = Point(0.0, 1.0) 235 | f = Point(1.0, 0.0) 236 | self.assertAlmostEqual(ang_between(d, e, f), ang_between(a, b, x)) 237 | 238 | def test_sketch_rotatep90(self): 239 | a, b = random_points(2) 240 | x = nm.sketch_rotatep90([a, b]) 241 | self.assertAlmostEqual(a.distance(x), a.distance(b)) 242 | self.assertTrue(Line(a, x).is_perp(Line(a, b))) 243 | d = Point(0.0, 0.0) 244 | e = Point(0.0, 1.0) 245 | f = Point(1.0, 0.0) 246 | self.assertAlmostEqual(ang_between(d, f, e), ang_between(a, b, x)) 247 | 248 | def test_sketch_s_angle(self): 249 | a, b = random_points(2) 250 | y = unif(0.0, np.pi) 251 | bx = nm.sketch_s_angle([a, b, y / np.pi * 180]) 252 | self.assertIsInstance(bx, HalfLine) 253 | self.assertEqual(bx.tail, b) 254 | x = bx.head 255 | 256 | d = Point(1.0, 0.0) 257 | e = Point(0.0, 0.0) 258 | f = Point(np.cos(y), np.sin(y)) 259 | self.assertAlmostEqual(ang_between(e, d, f), ang_between(b, a, x)) 260 | 261 | def test_sketch_shift(self): 262 | a, b, c = random_points(3) 263 | x = nm.sketch_shift([a, b, c]) 264 | self.assertTrue((b - a).close(x - c)) 265 | 266 | def test_sketch_square(self): 267 | a, b = random_points(2) 268 | c, d = nm.sketch_square([a, b]) 269 | self.assertTrue(Line(a, b).is_perp(Line(b, c))) 270 | self.assertTrue(Line(b, c).is_perp(Line(c, d))) 271 | self.assertTrue(Line(c, d).is_perp(Line(d, a))) 272 | self.assertAlmostEqual(a.distance(b), b.distance(c)) 273 | 274 | def test_sketch_isquare(self): 275 | a, b, c, d = nm.sketch_isquare([]) 276 | self.assertTrue(Line(a, b).is_perp(Line(b, c))) 277 | self.assertTrue(Line(b, c).is_perp(Line(c, d))) 278 | self.assertTrue(Line(c, d).is_perp(Line(d, a))) 279 | self.assertAlmostEqual(a.distance(b), b.distance(c)) 280 | 281 | def test_sketch_trapezoid(self): 282 | a, b, c, d = nm.sketch_trapezoid([]) 283 | self.assertTrue(Line(a, b).is_parallel(Line(c, d))) 284 | self.assertTrue(Line(a, c).diff_side(b, d)) 285 | self.assertTrue(Line(b, d).diff_side(a, c)) 286 | 287 | def test_sketch_triangle(self): 288 | a, b, c = nm.sketch_triangle([]) 289 | self.assertFalse(check_coll([a, b, c])) 290 | 291 | def test_sketch_triangle12(self): 292 | a, b, c = nm.sketch_triangle12([]) 293 | self.assertAlmostEqual(a.distance(b) * 2, a.distance(c)) 294 | 295 | def test_sketch_trisect(self): 296 | a, b, c = random_points(3) 297 | x, y = nm.sketch_trisect([a, b, c]) 298 | self.assertAlmostEqual(ang_between(b, a, x), ang_between(b, x, y)) 299 | self.assertAlmostEqual(ang_between(b, x, y), ang_between(b, y, c)) 300 | self.assertAlmostEqual(ang_between(b, a, x) * 3, ang_between(b, a, c)) 301 | 302 | def test_sketch_trisegment(self): 303 | a, b = random_points(2) 304 | x, y = nm.sketch_trisegment([a, b]) 305 | self.assertAlmostEqual( 306 | a.distance(x) + x.distance(y) + y.distance(b), a.distance(b) 307 | ) 308 | self.assertAlmostEqual(a.distance(x), x.distance(y)) 309 | self.assertAlmostEqual(x.distance(y), y.distance(b)) 310 | 311 | 312 | if __name__ == '__main__': 313 | absltest.main() 314 | 
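The tests above double as usage documentation for the sketching API in numericals.py: each sketch_* helper consumes a list of already-constructed objects and returns new points, lines, or circles, while the check_* helpers and ang_between validate predicates numerically. As a minimal standalone sketch of that workflow (assuming only the signatures exercised by test_sketch_bisect above; the 1e-6 tolerance is an arbitrary choice, slightly looser than assertAlmostEqual's default):

import numericals as nm

# Sample three random points and sketch the bisector of angle abc.
a, b, c = nm.random_points(3)
bisector = nm.sketch_bisect([a, b, c])

# Intersect the bisector with a unit circle centered at b, then compare
# the two half-angles at b, exactly as test_sketch_bisect does.
d, _ = nm.line_circle_intersection(bisector, nm.Circle(b, radius=1))
assert abs(nm.ang_between(b, a, d) - nm.ang_between(b, d, c)) < 1e-6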
-------------------------------------------------------------------------------- /pretty.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities for string manipulation in the DSL.""" 17 | 18 | MAP_SYMBOL = { 19 | 'T': 'perp', 20 | 'P': 'para', 21 | 'D': 'cong', 22 | 'S': 'simtri', 23 | 'I': 'circle', 24 | 'M': 'midp', 25 | 'O': 'cyclic', 26 | 'C': 'coll', 27 | '^': 'eqangle', 28 | '/': 'eqratio', 29 | '%': 'eqratio', 30 | '=': 'contri', 31 | 'X': 'collx', 32 | 'A': 'acompute', 33 | 'R': 'rcompute', 34 | 'Q': 'fixc', 35 | 'E': 'fixl', 36 | 'V': 'fixb', 37 | 'H': 'fixt', 38 | 'Z': 'fixp', 39 | 'Y': 'ind', 40 | } 41 | 42 | 43 | def map_symbol(c: str) -> str: 44 | return MAP_SYMBOL[c] 45 | 46 | 47 | def map_symbol_inv(c: str) -> str: 48 | return {v: k for k, v in MAP_SYMBOL.items()}[c] 49 | 50 | 51 | def _gcd(x: int, y: int) -> int: 52 | while y: 53 | x, y = y, x % y 54 | return x 55 | 56 | 57 | def simplify(n: int, d: int) -> tuple[int, int]: 58 | g = _gcd(n, d) 59 | return (n // g, d // g) 60 | 61 | 62 | def pretty2r(a: str, b: str, c: str, d: str) -> str: 63 | if b in (c, d): 64 | a, b = b, a 65 | 66 | if a == d: 67 | c, d = d, c 68 | 69 | return f'{a} {b} {c} {d}' 70 | 71 | 72 | def pretty2a(a: str, b: str, c: str, d: str) -> str: 73 | if b in (c, d): 74 | a, b = b, a 75 | 76 | if a == d: 77 | c, d = d, c 78 | 79 | return f'{a} {b} {c} {d}' 80 | 81 | 82 | def pretty_angle(a: str, b: str, c: str, d: str) -> str: 83 | if b in (c, d): 84 | a, b = b, a 85 | if a == d: 86 | c, d = d, c 87 | 88 | if a == c: 89 | return f'\u2220{b}{a}{d}' 90 | return f'\u2220({a}{b}-{c}{d})' 91 | 92 | 93 | def pretty_nl(name: str, args: list[str]) -> str: 94 | """Natural-language formatting of a predicate.""" 95 | if name == 'aconst': 96 | a, b, c, d, y = args 97 | return f'{pretty_angle(a, b, c, d)} = {y}' 98 | if name == 'rconst': 99 | a, b, c, d, y = args 100 | return f'{a}{b}:{c}{d} = {y}' 101 | if name == 'acompute': 102 | a, b, c, d = args 103 | return f'{pretty_angle(a, b, c, d)}' 104 | if name in ['coll', 'C']: 105 | return ','.join(args) + ' are collinear' 106 | if name == 'collx': 107 | return ','.join(list(set(args))) + ' are collinear' 108 | if name in ['cyclic', 'O']: 109 | return ','.join(args) + ' are concyclic' 110 | if name in ['midp', 'midpoint', 'M']: 111 | x, a, b = args 112 | return f'{x} is midpoint of {a}{b}' 113 | if name in ['eqangle', 'eqangle6', '^']: 114 | a, b, c, d, e, f, g, h = args 115 | return f'{pretty_angle(a, b, c, d)} = {pretty_angle(e, f, g, h)}' 116 | if name in ['eqratio', 'eqratio6', '/']: 117 | return '{}{}:{}{} = {}{}:{}{}'.format(*args) 118 | if name == 'eqratio3': 119 | a, b, c, d, o, o = args # pylint: disable=redeclared-assigned-name 120 | return f'S {o} {a} {b} {o} {c} {d}' 121 | if name
in ['cong', 'D']: 122 | a, b, c, d = args 123 | return f'{a}{b} = {c}{d}' 124 | if name in ['perp', 'T']: 125 | if len(args) == 2: # this is algebraic derivation. 126 | ab, cd = args # ab = 'd( ... )' 127 | return f'{ab} \u27c2 {cd}' 128 | a, b, c, d = args 129 | return f'{a}{b} \u27c2 {c}{d}' 130 | if name in ['para', 'P']: 131 | if len(args) == 2: # this is algebraic derivation. 132 | ab, cd = args # ab = 'd( ... )' 133 | return f'{ab} \u2225 {cd}' 134 | a, b, c, d = args 135 | return f'{a}{b} \u2225 {c}{d}' 136 | if name in ['simtri2', 'simtri', 'simtri*']: 137 | a, b, c, x, y, z = args 138 | return f'\u0394{a}{b}{c} is similar to \u0394{x}{y}{z}' 139 | if name in ['contri2', 'contri', 'contri*']: 140 | a, b, c, x, y, z = args 141 | return f'\u0394{a}{b}{c} is congruent to \u0394{x}{y}{z}' 142 | if name in ['circle', 'I']: 143 | o, a, b, c = args 144 | return f'{o} is the circumcenter of \u0394{a}{b}{c}' 145 | if name == 'foot': 146 | a, b, c, d = args 147 | return f'{a} is the foot of {b} on {c}{d}' 148 | 149 | 150 | def pretty(txt: str) -> str: 151 | """Pretty formatting a predicate string.""" 152 | if isinstance(txt, str): 153 | txt = txt.split(' ') 154 | name, *args = txt 155 | if name == 'ind': 156 | return 'Y ' + ' '.join(args) 157 | if name in ['fixc', 'fixl', 'fixb', 'fixt', 'fixp']: 158 | return map_symbol_inv(name) + ' ' + ' '.join(args) 159 | if name == 'acompute': 160 | a, b, c, d = args # unpacked only to check arity. 161 | return 'A ' + ' '.join(args) 162 | if name == 'rcompute': 163 | a, b, c, d = args # unpacked only to check arity. 164 | return 'R ' + ' '.join(args) 165 | if name == 'aconst': 166 | a, b, c, d, y = args 167 | return f'^ {pretty2a(a, b, c, d)} {y}' 168 | if name == 'rconst': 169 | a, b, c, d, y = args 170 | return f'/ {pretty2r(a, b, c, d)} {y}' 171 | if name == 'coll': 172 | return 'C ' + ' '.join(args) 173 | if name == 'collx': 174 | return 'X ' + ' '.join(args) 175 | if name == 'cyclic': 176 | return 'O ' + ' '.join(args) 177 | if name in ['midp', 'midpoint']: 178 | x, a, b = args 179 | return f'M {x} {a} {b}' 180 | if name == 'eqangle': 181 | a, b, c, d, e, f, g, h = args 182 | return f'^ {pretty2a(a, b, c, d)} {pretty2a(e, f, g, h)}' 183 | if name == 'eqratio': 184 | a, b, c, d, e, f, g, h = args 185 | return f'/ {pretty2r(a, b, c, d)} {pretty2r(e, f, g, h)}' 186 | if name == 'eqratio3': 187 | a, b, c, d, o, o = args # pylint: disable=redeclared-assigned-name 188 | return f'S {o} {a} {b} {o} {c} {d}' 189 | if name == 'cong': 190 | a, b, c, d = args 191 | return f'D {a} {b} {c} {d}' 192 | if name == 'perp': 193 | if len(args) == 2: # this is algebraic derivation. 194 | ab, cd = args # ab = 'd( ... )' 195 | return f'T {ab} {cd}' 196 | a, b, c, d = args 197 | return f'T {a} {b} {c} {d}' 198 | if name == 'para': 199 | if len(args) == 2: # this is algebraic derivation. 200 | ab, cd = args # ab = 'd( ...
)' 201 | return f'P {ab} {cd}' 202 | a, b, c, d = args 203 | return f'P {a} {b} {c} {d}' 204 | if name in ['simtri2', 'simtri', 'simtri*']: 205 | a, b, c, x, y, z = args 206 | return f'S {a} {b} {c} {x} {y} {z}' 207 | if name in ['contri2', 'contri', 'contri*']: 208 | a, b, c, x, y, z = args 209 | return f'= {a} {b} {c} {x} {y} {z}' 210 | if name == 'circle': 211 | o, a, b, c = args 212 | return f'I {o} {a} {b} {c}' 213 | if name == 'foot': 214 | a, b, c, d = args 215 | return f'F {a} {b} {c} {d}' 216 | return ' '.join(txt) 217 | -------------------------------------------------------------------------------- /problem_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Unit tests for problem.py.""" 17 | import unittest 18 | 19 | from absl.testing import absltest 20 | import problem as pr 21 | 22 | 23 | class ProblemTest(unittest.TestCase): 24 | 25 | @classmethod 26 | def setUpClass(cls): 27 | super().setUpClass() 28 | cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True) 29 | 30 | def test_orthocenter_no_translate(self): 31 | txt = 'a b c = triangle a b c; h = on_tline h b a c, on_tline h c a b ? perp a h b c' # pylint: disable=line-too-long 32 | 33 | # Read the txt into a pr.Problem object; do not change point names: 34 | p = pr.Problem.from_txt(txt, translate=False) 35 | 36 | # This is fed into the LM, translating from constructive to constrained: 37 | setup_str = p.setup_str_from_problem(ProblemTest.defs) 38 | 39 | self.assertEqual( 40 | setup_str, 41 | '{S} a : ; b : ; c : ; h : T a b c h 00 T a c b h 01 ? T a h b c', 42 | ) 43 | 44 | def test_orthocenter_translate(self): 45 | txt = 'a b c = triangle a b c; h = on_tline h b a c, on_tline h c a b ? perp a h b c' # pylint: disable=line-too-long 46 | 47 | # Read the txt into a pr.Problem object; rename h -> d to match the 48 | # training data distribution. 49 | p = pr.Problem.from_txt(txt, translate=True) 50 | 51 | # This is fed into the LM, translating from constructive to constrained: 52 | setup_str = p.setup_str_from_problem(ProblemTest.defs) 53 | 54 | self.assertEqual( 55 | setup_str, 56 | '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? 
T a d b c', 57 | ) 58 | 59 | 60 | if __name__ == '__main__': 61 | absltest.main() 62 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | tensorflow==2.13.0 2 | numpy==1.23.5 3 | scipy==1.10.0 4 | matplotlib==3.7.0 5 | gdown==4.7.1 6 | jax==0.4.6 7 | jaxlib==0.4.6 8 | flax==0.5.3 9 | gin-config==0.5.0 10 | gin==0.1.6 11 | t5==0.9.4 12 | sentencepiece==0.1.99 13 | absl-py==1.4.0 14 | clu==0.0.7 15 | optax==0.1.7 16 | seqio==0.0.18 17 | tensorflow-datasets==4.9.3 18 | -------------------------------------------------------------------------------- /rules.txt: -------------------------------------------------------------------------------- 1 | perp A B C D, perp C D E F, ncoll A B E => para A B E F 2 | cong O A O B, cong O B O C, cong O C O D => cyclic A B C D 3 | eqangle A B P Q C D P Q => para A B C D 4 | cyclic A B P Q => eqangle P A P B Q A Q B 5 | eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q 6 | cyclic A B C P Q R, eqangle C A C B R P R Q => cong A B P Q 7 | midp E A B, midp F A C => para E F B C 8 | para A B C D, coll O A C, coll O B D => eqratio3 A B C D O O 9 | perp A B C D, perp E F G H, npara A B E F => eqangle A B E F C D G H 10 | eqangle a b c d m n p q, eqangle c d e f p q r u => eqangle a b e f m n r u 11 | eqratio a b c d m n p q, eqratio c d e f p q r u => eqratio a b e f m n r u 12 | eqratio6 d b d c a b a c, coll d b c, ncoll a b c => eqangle6 a b a d a d a c 13 | eqangle6 a b a d a d a c, coll d b c, ncoll a b c => eqratio6 d b d c a b a c 14 | cong O A O B, ncoll O A B => eqangle O A A B A B O B 15 | eqangle6 A O A B B A B O, ncoll O A B => cong O A O B 16 | circle O A B C, perp O A A X => eqangle A X A B C A C B 17 | circle O A B C, eqangle A X A B C A C B => perp O A A X 18 | circle O A B C, midp M B C => eqangle A B A C O B O M 19 | circle O A B C, coll M B C, eqangle A B A C O B O M => midp M B C 20 | perp A B B C, midp M A C => cong A M B M 21 | circle O A B C, coll O A C => perp A B B C 22 | cyclic A B C D, para A B C D => eqangle A D C D C D C B 23 | midp M A B, perp O M A B => cong O A O B 24 | cong A P B P, cong A Q B Q => perp A B P Q 25 | cong A P B P, cong A Q B Q, cyclic A B P Q => perp P A A Q 26 | midp M A B, midp M C D => para A C B D 27 | midp M A B, para A C B D, para A D B C => midp M C D 28 | eqratio O A A C O B B D, coll O A C, coll O B D, ncoll A B C, sameside A O C B O D => para A B C D 29 | para A B A C => coll A B C 30 | midp M A B, midp N C D => eqratio M A A B N C C D 31 | eqangle A B P Q C D U V, perp P Q U V => perp A B C D 32 | eqratio A B P Q C D U V, cong P Q U V => cong A B C D 33 | cong A B P Q, cong B C Q R, cong C A R P, ncoll A B C => contri* A B C P Q R 34 | cong A B P Q, cong B C Q R, eqangle6 B A B C Q P Q R, ncoll A B C => contri* A B C P Q R 35 | eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C => simtri A B C P Q R 36 | eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C => simtri2 A B C P Q R 37 | eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri A B C P Q R 38 | eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C, cong A B P Q => contri2 A B C P Q R 39 | eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R 40 | eqratio6 B A B C Q P Q R, eqangle6 B A B C Q P Q R, ncoll A B C => simtri* A B C P Q R 41 | eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C, cong A B P 
Q => contri* A B C P Q R 42 | para a b c d, coll m a d, coll n b c, eqratio6 m a m d n b n c, sameside m a d n b c => para m n a b 43 | para a b c d, coll m a d, coll n b c, para m n a b => eqratio6 m a m d n b n c 44 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | #!/bin/bash 17 | set -e 18 | set -x 19 | 20 | virtualenv -p python3 . 21 | source ./bin/activate 22 | 23 | pip install --require-hashes -r requirements.txt 24 | 25 | gdown --folder https://bit.ly/alphageometry 26 | DATA=ag_ckpt_vocab 27 | 28 | MELIAD_PATH=meliad_lib/meliad 29 | mkdir -p $MELIAD_PATH 30 | git clone https://github.com/google-research/meliad $MELIAD_PATH 31 | export PYTHONPATH=$PYTHONPATH:$MELIAD_PATH 32 | 33 | DDAR_ARGS=( 34 | --defs_file=$(pwd)/defs.txt \ 35 | --rules_file=$(pwd)/rules.txt \ 36 | ); 37 | 38 | BATCH_SIZE=2 39 | BEAM_SIZE=2 40 | DEPTH=2 41 | 42 | SEARCH_ARGS=( 43 | --beam_size=$BEAM_SIZE 44 | --search_depth=$DEPTH 45 | ) 46 | 47 | LM_ARGS=( 48 | --ckpt_path=$DATA \ 49 | --vocab_path=$DATA/geometry.757.model \ 50 | --gin_search_paths=$MELIAD_PATH/transformer/configs \ 51 | --gin_file=base_htrans.gin \ 52 | --gin_file=size/medium_150M.gin \ 53 | --gin_file=options/positions_t5.gin \ 54 | --gin_file=options/lr_cosine_decay.gin \ 55 | --gin_file=options/seq_1024_nocache.gin \ 56 | --gin_file=geometry_150M_generate.gin \ 57 | --gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True \ 58 | --gin_param=TransformerTaskConfig.batch_size=$BATCH_SIZE \ 59 | --gin_param=TransformerTaskConfig.sequence_length=128 \ 60 | --gin_param=Trainer.restore_state_variables=False 61 | ); 62 | 63 | echo $PYTHONPATH 64 | 65 | python -m alphageometry \ 66 | --alsologtostderr \ 67 | --problems_file=$(pwd)/examples.txt \ 68 | --problem_name=orthocenter \ 69 | --mode=alphageometry \ 70 | "${DDAR_ARGS[@]}" \ 71 | "${SEARCH_ARGS[@]}" \ 72 | "${LM_ARGS[@]}" 73 | -------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | DATA=ag_ckpt_vocab 17 | MELIAD_PATH=meliad_lib/meliad 18 | export PYTHONPATH=$PYTHONPATH:$MELIAD_PATH 19 | 20 | python problem_test.py 21 | python geometry_test.py 22 | python graph_utils_test.py 23 | python numericals_test.py 24 | python graph_test.py 25 | python dd_test.py 26 | python ar_test.py 27 | python ddar_test.py 28 | python trace_back_test.py 29 | python alphageometry_test.py 30 | python lm_inference_test.py --meliad_path=$MELIAD_PATH --data_path=$DATA 31 | -------------------------------------------------------------------------------- /trace_back.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Implements DAG-level traceback.""" 17 | 18 | from typing import Any 19 | 20 | import geometry as gm 21 | import pretty as pt 22 | import problem 23 | 24 | 25 | pretty = pt.pretty 26 | 27 | 28 | def point_levels( 29 | setup: list[problem.Dependency], existing_points: list[gm.Point] 30 | ) -> list[tuple[set[gm.Point], list[problem.Dependency]]]: 31 | """Reformat setup into levels of point constructions.""" 32 | levels = [] 33 | for con in setup: 34 | plevel = max([p.plevel for p in con.args if isinstance(p, gm.Point)]) 35 | 36 | while len(levels) - 1 < plevel: 37 | levels.append((set(), [])) 38 | 39 | for p in con.args: 40 | if not isinstance(p, gm.Point): 41 | continue 42 | if existing_points and p in existing_points: 43 | continue 44 | 45 | levels[p.plevel][0].add(p) 46 | 47 | cons = levels[plevel][1] 48 | cons.append(con) 49 | 50 | return [(p, c) for p, c in levels if p or c] 51 | 52 | 53 | def point_log( 54 | setup: list[problem.Dependency], 55 | ref_id: dict[tuple[str, ...], int], 56 | existing_points: list[gm.Point], 57 | ) -> list[tuple[list[gm.Point], list[problem.Dependency]]]: 58 | """Reformat setup into groups of point constructions.""" 59 | log = [] 60 | 61 | levels = point_levels(setup, existing_points) 62 | 63 | for points, cons in levels: 64 | for con in cons: 65 | if con.hashed() not in ref_id: 66 | ref_id[con.hashed()] = len(ref_id) 67 | 68 | log.append((points, cons)) 69 | 70 | return log 71 | 72 | 73 | def setup_to_levels( 74 | setup: list[problem.Dependency], 75 | ) -> list[list[problem.Dependency]]: 76 | """Reformat setup into levels of point constructions.""" 77 | levels = [] 78 | for d in setup: 79 | plevel = max([p.plevel for p in d.args if isinstance(p, gm.Point)]) 80 | while len(levels) - 1 < plevel: 81 | levels.append([]) 82 | 83 | levels[plevel].append(d) 84 | 85 | levels = [lvl for lvl in levels if lvl] 86 | return levels 87 | 88 | 89 | def separate_dependency_difference( 90 | query: problem.Dependency, 91 | log: list[tuple[list[problem.Dependency], list[problem.Dependency]]], 92 | ) -> tuple[ 93 |
list[tuple[list[problem.Dependency], list[problem.Dependency]]], 94 | list[problem.Dependency], 95 | list[problem.Dependency], 96 | set[gm.Point], 97 | set[gm.Point], 98 | ]: 99 | """Identify and separate the dependency difference.""" 100 | setup = [] 101 | log_, log = log, [] 102 | for prems, cons in log_: 103 | if not prems: 104 | setup.extend(cons) 105 | continue 106 | cons_ = [] 107 | for con in cons: 108 | if con.rule_name == 'c0': 109 | setup.append(con) 110 | else: 111 | cons_.append(con) 112 | if not cons_: 113 | continue 114 | 115 | prems = [p for p in prems if p.name != 'ind'] 116 | log.append((prems, cons_)) 117 | 118 | points = set(query.args) 119 | queue = list(query.args) 120 | i = 0 121 | while i < len(queue): 122 | q = queue[i] 123 | i += 1 124 | if not isinstance(q, gm.Point): 125 | continue 126 | for p in q.rely_on: 127 | if p not in points: 128 | points.add(p) 129 | queue.append(p) 130 | 131 | setup_, setup, aux_setup, aux_points = setup, [], [], set() 132 | for con in setup_: 133 | if con.name == 'ind': 134 | continue 135 | elif any([p not in points for p in con.args if isinstance(p, gm.Point)]): 136 | aux_setup.append(con) 137 | aux_points.update( 138 | [p for p in con.args if isinstance(p, gm.Point) and p not in points] 139 | ) 140 | else: 141 | setup.append(con) 142 | 143 | return log, setup, aux_setup, points, aux_points 144 | 145 | 146 | def recursive_traceback( 147 | query: problem.Dependency, 148 | ) -> list[tuple[list[problem.Dependency], list[problem.Dependency]]]: 149 | """Recursively trace back from the query, i.e. the conclusion.""" 150 | visited = set() 151 | log = [] 152 | stack = [] 153 | 154 | def read(q: problem.Dependency) -> None: 155 | q = q.remove_loop() 156 | hashed = q.hashed() 157 | if hashed in visited: 158 | return 159 | 160 | if hashed[0] in ['ncoll', 'npara', 'nperp', 'diff', 'sameside']: 161 | return 162 | 163 | nonlocal stack 164 | 165 | stack.append(hashed) 166 | prems = [] 167 | 168 | if q.rule_name != problem.CONSTRUCTION_RULE: 169 | all_deps = [] 170 | dep_names = set() 171 | for d in q.why: 172 | if d.hashed() in dep_names: 173 | continue 174 | dep_names.add(d.hashed()) 175 | all_deps.append(d) 176 | 177 | for d in all_deps: 178 | h = d.hashed() 179 | if h not in visited: 180 | read(d) 181 | if h in visited: 182 | prems.append(d) 183 | 184 | visited.add(hashed) 185 | hashs = sorted([d.hashed() for d in prems]) 186 | found = False 187 | for ps, qs in log: 188 | if sorted([d.hashed() for d in ps]) == hashs: 189 | qs += [q] 190 | found = True 191 | break 192 | if not found: 193 | log.append((prems, [q])) 194 | 195 | stack.pop(-1) 196 | 197 | read(query) 198 | 199 | # Post-process the log: separate multi-conclusion lines. 200 | log_, log = log, [] 201 | for ps, qs in log_: 202 | for q in qs: 203 | log.append((ps, [q])) 204 | 205 | return log 206 | 207 | 208 | def collx_to_coll_setup( 209 | setup: list[problem.Dependency], 210 | ) -> list[problem.Dependency]: 211 | """Convert collx to coll in setups.""" 212 | result = [] 213 | for level in setup_to_levels(setup): 214 | hashs = set() 215 | for dep in level: 216 | if dep.name == 'collx': 217 | dep.name = 'coll' 218 | dep.args = list(set(dep.args)) 219 | 220 | if dep.hashed() in hashs: 221 | continue 222 | hashs.add(dep.hashed()) 223 | result.append(dep) 224 | 225 | return result 226 | 227 | 228 | def collx_to_coll( 229 | setup: list[problem.Dependency], 230 | aux_setup: list[problem.Dependency], 231 | log: list[tuple[list[problem.Dependency], list[problem.Dependency]]], 232 | ) -> tuple[ 233 |
list[problem.Dependency], 234 | list[problem.Dependency], 235 | list[tuple[list[problem.Dependency], list[problem.Dependency]]], 236 | ]: 237 | """Convert collx to coll and dedup.""" 238 | setup = collx_to_coll_setup(setup) 239 | aux_setup = collx_to_coll_setup(aux_setup) 240 | 241 | con_set = set([p.hashed() for p in setup + aux_setup]) 242 | log_, log = log, [] 243 | for prems, cons in log_: 244 | prem_set = set() 245 | prems_, prems = prems, [] 246 | for p in prems_: 247 | if p.name == 'collx': 248 | p.name = 'coll' 249 | p.args = list(set(p.args)) 250 | if p.hashed() in prem_set: 251 | continue 252 | prem_set.add(p.hashed()) 253 | prems.append(p) 254 | 255 | cons_, cons = cons, [] 256 | for c in cons_: 257 | if c.name == 'collx': 258 | c.name = 'coll' 259 | c.args = list(set(c.args)) 260 | if c.hashed() in con_set: 261 | continue 262 | con_set.add(c.hashed()) 263 | cons.append(c) 264 | 265 | if not cons or not prems: 266 | continue 267 | 268 | log.append((prems, cons)) 269 | 270 | return setup, aux_setup, log 271 | 272 | 273 | def get_logs( 274 | query: problem.Dependency, g: Any, merge_trivials: bool = False 275 | ) -> tuple[ 276 | list[problem.Dependency], 277 | list[problem.Dependency], 278 | list[tuple[list[problem.Dependency], list[problem.Dependency]]], 279 | set[gm.Point], 280 | ]: 281 | """Given a DAG and a conclusion N, return the premises, aux and proof.""" 282 | query = query.why_me_or_cache(g, query.level) 283 | log = recursive_traceback(query) 284 | log, setup, aux_setup, setup_points, _ = separate_dependency_difference( 285 | query, log 286 | ) 287 | 288 | setup, aux_setup, log = collx_to_coll(setup, aux_setup, log) 289 | 290 | setup, aux_setup, log = shorten_and_shave( 291 | setup, aux_setup, log, merge_trivials 292 | ) 293 | 294 | return setup, aux_setup, log, setup_points 295 | 296 | 297 | def shorten_and_shave( 298 | setup: list[problem.Dependency], 299 | aux_setup: list[problem.Dependency], 300 | log: list[tuple[list[problem.Dependency], list[problem.Dependency]]], 301 | merge_trivials: bool = False, 302 | ) -> tuple[ 303 | list[problem.Dependency], 304 | list[problem.Dependency], 305 | list[tuple[list[problem.Dependency], list[problem.Dependency]]], 306 | ]: 307 | """Shorten the proof by removing unused predicates.""" 308 | log, _ = shorten_proof(log, merge_trivials=merge_trivials) 309 | 310 | all_prems = sum([list(prems) for prems, _ in log], []) 311 | all_prems = set([p.hashed() for p in all_prems]) 312 | setup = [d for d in setup if d.hashed() in all_prems] 313 | aux_setup = [d for d in aux_setup if d.hashed() in all_prems] 314 | return setup, aux_setup, log 315 | 316 | 317 | def join_prems( 318 | con: problem.Dependency, 319 | con2prems: dict[tuple[str, ...], list[problem.Dependency]], 320 | expanded: set[tuple[str, ...]], 321 | ) -> list[problem.Dependency]: 322 | """Join proof steps with the same premises.""" 323 | h = con.hashed() 324 | if h in expanded or h not in con2prems: 325 | return [con] 326 | 327 | result = [] 328 | for p in con2prems[h]: 329 | result += join_prems(p, con2prems, expanded) 330 | return result 331 | 332 | 333 | def shorten_proof( 334 | log: list[tuple[list[problem.Dependency], list[problem.Dependency]]], 335 | merge_trivials: bool = False, 336 | ) -> tuple[ 337 | list[tuple[list[problem.Dependency], list[problem.Dependency]]], 338 | dict[tuple[str, ...], list[problem.Dependency]], 339 | ]: 340 | """Join multiple trivial proof steps into one.""" 341 | pops = set() 342 | con2prem = {} 343 | for prems, cons in log: 344 | assert len(cons) == 1
345 | con = cons[0] 346 | if con.rule_name == '': # pylint: disable=g-explicit-bool-comparison 347 | con2prem[con.hashed()] = prems 348 | elif not merge_trivials: 349 | # Trivial steps that are premises of non-trivial steps must be kept. 350 | pops.update({p.hashed() for p in prems}) 351 | 352 | for p in pops: 353 | if p in con2prem: 354 | con2prem.pop(p) 355 | 356 | expanded = set() 357 | log2 = [] 358 | for i, (prems, cons) in enumerate(log): 359 | con = cons[0] 360 | if i < len(log) - 1 and con.hashed() in con2prem: 361 | continue 362 | 363 | hashs = set() 364 | new_prems = [] 365 | 366 | for p in sum([join_prems(p, con2prem, expanded) for p in prems], []): 367 | if p.hashed() not in hashs: 368 | new_prems.append(p) 369 | hashs.add(p.hashed()) 370 | 371 | log2 += [(new_prems, [con])] 372 | expanded.add(con.hashed()) 373 | 374 | return log2, con2prem 375 | -------------------------------------------------------------------------------- /trace_back_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Unit testing for the trace_back code.""" 17 | 18 | import unittest 19 | 20 | from absl.testing import absltest 21 | import ddar 22 | import graph as gh 23 | import problem as pr 24 | import trace_back as tb 25 | 26 | 27 | class TracebackTest(unittest.TestCase): 28 | 29 | @classmethod 30 | def setUpClass(cls): 31 | super().setUpClass() 32 | cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True) 33 | cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True) 34 | 35 | def test_orthocenter_dependency_difference(self): 36 | txt = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c' # pylint: disable=line-too-long 37 | p = pr.Problem.from_txt(txt) 38 | g, _ = gh.Graph.build_problem(p, TracebackTest.defs) 39 | 40 | ddar.solve(g, TracebackTest.rules, p) 41 | 42 | goal_args = g.names2nodes(p.goal.args) 43 | query = pr.Dependency(p.goal.name, goal_args, None, None) 44 | 45 | setup, aux, _, _ = tb.get_logs(query, g, merge_trivials=False) 46 | 47 | # Convert each predicate to its hash string: 48 | setup = [p.hashed() for p in setup] 49 | aux = [p.hashed() for p in aux] 50 | 51 | self.assertCountEqual( 52 | setup, [('perp', 'a', 'c', 'b', 'd'), ('perp', 'a', 'b', 'c', 'd')] 53 | ) 54 | 55 | self.assertCountEqual( 56 | aux, [('coll', 'a', 'c', 'e'), ('coll', 'b', 'd', 'e')] 57 | ) 58 | 59 | 60 | if __name__ == '__main__': 61 | absltest.main() 62 | --------------------------------------------------------------------------------
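Taken together, the modules above form the symbolic half of the pipeline: ddar.solve saturates the proof graph, trace_back.get_logs walks the dependency DAG backwards from the goal to recover the setup, auxiliary constructions, and proof steps, and pretty.pretty_nl renders individual predicates in natural language. A minimal end-to-end sketch (not part of the repository; it reuses the orthocenter problem from trace_back_test.py and assumes, as the hashed() tuples in that test suggest, that dependency args carry a .name attribute):

import ddar
import graph as gh
import pretty as pt
import problem as pr
import trace_back as tb

defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True)

# Same orthocenter problem as in trace_back_test.py.
txt = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c'
p = pr.Problem.from_txt(txt)
g, _ = gh.Graph.build_problem(p, defs)
ddar.solve(g, rules, p)

# Trace back from the goal and print the setup in natural language.
query = pr.Dependency(p.goal.name, g.names2nodes(p.goal.args), None, None)
setup, aux, proof_steps, _ = tb.get_logs(query, g, merge_trivials=False)
for dep in setup + aux:
  print(pt.pretty_nl(dep.name, [a.name for a in dep.args]))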