├── .gitignore ├── .gitlab-ci.yml ├── LICENSE ├── doc ├── dune └── index.mld ├── dune-project ├── ego.opam ├── lib ├── basic.ml ├── basic.mli ├── dune ├── ego.ml ├── ego.mli ├── equivalence.ml ├── equivalence.mli ├── generic.ml ├── generic.mli ├── id.ml ├── id.mli ├── language.ml ├── ordered_set.ml ├── ordered_set.mli ├── query.ml ├── scheduler.ml ├── symbol.ml ├── symbol.mli ├── term.ml └── types.ml ├── macros ├── dune └── ppx_sexp.ml ├── readme.md └── test ├── dune ├── test_basic.ml ├── test_generic.ml ├── test_math.ml └── test_prop.ml /.gitignore: -------------------------------------------------------------------------------- 1 | _build/ -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | image: ruby:2.7 2 | 3 | pages: 4 | script: 5 | - echo 'Nothing to do...' 6 | artifacts: 7 | paths: 8 | - public/ 9 | only: 10 | - pages 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. 
We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 
49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. 
The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 
122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 
155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 
186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. 
This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 
287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 
317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 
386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail.
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 | -------------------------------------------------------------------------------- /doc/dune: -------------------------------------------------------------------------------- 1 | (documentation 2 | (package ego)) 3 | 4 | -------------------------------------------------------------------------------- /dune-project: -------------------------------------------------------------------------------- 1 | (lang dune 2.9) 2 | (package 3 | (name ego) 4 | (synopsis "Ego (EGraphs OCaml) is an extensible EGraph library for OCaml.") 5 | (description 6 | "Ego is an extensible egraph library for OCaml loosely based on the egg library in Rust.") 7 | (depends 8 | (ocaml (>= 4.0.8)) 9 | (containers (>= 3.3)) 10 | (containers-data (>= 3.3)) 11 | (iter (>= 1.2.1)) 12 | (ppx_deriving (>= 4.4)) 13 | (ocamldot (>= 1.1)) 14 | (sexplib (>= v0.14.0)))) 15 | (version 0.0.6) 16 | (name ego) 17 | (generate_opam_files true) 18 | (license GPL-3.0+) 19 | (source (uri git+https://gitlab.com/gopiandcode/ego.git)) 20 | (bug_reports https://gitlab.com/gopiandcode/ego/issues) 21 | (homepage https://gitlab.com/gopiandcode/ego) 22 | (authors "Kiran Gopinathan") 23 | (maintainers "kirang@comp.nus.edu.sg") 24 | -------------------------------------------------------------------------------- /ego.opam: -------------------------------------------------------------------------------- 1 | # This file is generated by dune, edit dune-project instead 2 | opam-version: "2.0" 3 | version: "0.0.6" 4 | synopsis: "Ego (EGraphs OCaml) is an extensible EGraph library for OCaml." 5 | description: 6 | "Ego is an extensible egraph library for OCaml loosely based on the egg library in Rust."
7 | maintainer: ["kirang@comp.nus.edu.sg"] 8 | authors: ["Kiran Gopinathan"] 9 | license: "GPL-3.0+" 10 | homepage: "https://gitlab.com/gopiandcode/ego" 11 | bug-reports: "https://gitlab.com/gopiandcode/ego/issues" 12 | depends: [ 13 | "dune" {>= "2.9"} 14 | "ocaml" {>= "4.0.8"} 15 | "containers" {>= "3.3"} 16 | "containers-data" {>= "3.3"} 17 | "iter" {>= "1.2.1"} 18 | "ppx_deriving" {>= "4.4"} 19 | "ocamldot" {>= "1.1"} 20 | "sexplib" {>= "v0.14.0"} 21 | "odoc" {with-doc} 22 | "ppx_inline_alcotest" {with-test} 23 | ] 24 | build: [ 25 | ["dune" "subst"] {dev} 26 | [ 27 | "dune" 28 | "build" 29 | "-p" 30 | name 31 | "-j" 32 | jobs 33 | "--promote-install-files=false" 34 | "@install" 35 | "@runtest" {with-test} 36 | "@doc" {with-doc} 37 | ] 38 | ["dune" "install" "-p" name "--create-install-files" name] 39 | ] 40 | dev-repo: "git+https://gitlab.com/gopiandcode/ego.git" 41 | -------------------------------------------------------------------------------- /lib/basic.ml: -------------------------------------------------------------------------------- 1 | open [@warning "-33"] Containers 2 | open Language 3 | open Types 4 | 5 | module StringMap = Map.Make(String) 6 | let str p v = Format.to_string p v 7 | let lappend_pair a (b,c) = (a,b,c) 8 | 9 | module Symbol = Symbol 10 | 11 | module Query = struct 12 | module Query = Query 13 | 14 | type 'a query = 'a Query.t = V of string | Q of 'a * 'a query list 15 | type t = Symbol.t query 16 | 17 | let of_sexp = Query.of_sexp Symbol.intern 18 | let to_sexp = Query.to_sexp (Format.to_string Symbol.pp) 19 | 20 | let pp = Query.pp Symbol.pp 21 | let show = Format.to_string pp 22 | 23 | let variables = Query.variables 24 | 25 | end 26 | 27 | module Term = Term 28 | 29 | module Rule : sig 30 | type 'sym rule = Query.t * Query.t 31 | type t = Symbol.t rule [@@deriving show] 32 | val make: from:Query.t -> into:Query.t -> t option 33 | end = struct 34 | type 'sym rule = Query.t * Query.t 35 | type t = Symbol.t rule 36 | 37 | let make 
~from ~into = 38 | let pattern_vars = Query.variables from in 39 | let rewrite_vars = Query.variables into in 40 | if StringSet.subset rewrite_vars pattern_vars 41 | then Some (from, into) 42 | else None 43 | 44 | let pp fmt (lhs,rhs) = 45 | Format.pp_open_hbox fmt (); 46 | Query.pp fmt lhs; 47 | Format.pp_print_string fmt " -> "; 48 | Query.pp fmt rhs; 49 | Format.pp_close_box fmt () 50 | 51 | let show = str pp 52 | 53 | let%test "rules are printed as expected" = 54 | Alcotest.(check string) 55 | "prints as expected" 56 | "(<< ?a 1) -> (* ?a 2)" (str pp ((Q (Symbol.intern "<<", [V "a"; Q (Symbol.intern "1", [])])), 57 | (Q (Symbol.intern "*", [V "a"; Q (Symbol.intern "2", [])])))) 58 | 59 | end 60 | 61 | 62 | 63 | (* * Egraphs *) 64 | (* ** Types *) 65 | type egraph = { 66 | mutable version: int; 67 | 68 | uf: Id.store; (* tracks equivalence classes of 69 | class ids *) 70 | class_members: (enode * Id.t) Vector.vector Id.Map.t; (* maps classes to the canonical nodes 71 | they contain, and any classes that are 72 | children of these nodes *) 73 | hash_cons: (enode, Id.t) Hashtbl.t; (* maps cannonical nodes to their 74 | equivalence classes *) 75 | worklist: Id.t Vector.vector; (* List of equivalence classes for which 76 | nodes are out of date - i.e 77 | cannoncial(node) != node *) 78 | } 79 | 80 | 81 | 82 | 83 | 84 | 85 | (* ** Graph *) 86 | module EGraph = struct 87 | 88 | type t = egraph 89 | 90 | (* *** Pretty printing *) 91 | let pp ?(pp_id=EClassId.pp) fmt self = 92 | let open Format in 93 | pp_print_string fmt "(egraph"; 94 | pp_open_hovbox fmt 1; 95 | pp_print_space fmt (); 96 | pp_print_string fmt "(eclasses "; 97 | pp_open_hvbox fmt 1; 98 | Id.Map.to_seq self.class_members 99 | |> Seq.to_list 100 | |> pp_print_list ~pp_sep:pp_print_space 101 | (fun fmt (cls, elts) -> 102 | pp_print_string fmt "("; 103 | pp_open_hvbox fmt 1; 104 | pp_id fmt cls; 105 | if not @@ Vector.is_empty elts then 106 | pp_print_space fmt (); 107 | Vector.pp ~pp_sep:pp_print_space 
108 | (fun fmt (node, id) -> 109 | pp_print_string fmt "("; 110 | pp_open_hbox fmt (); 111 | pp_id fmt id; 112 | pp_print_space fmt (); 113 | ENode.pp ~pp_id fmt node; 114 | pp_close_box fmt (); 115 | pp_print_string fmt ")"; 116 | ) fmt elts; 117 | pp_close_box fmt (); 118 | pp_print_string fmt ")"; 119 | ) fmt; 120 | pp_close_box fmt (); 121 | pp_print_string fmt ")"; 122 | pp_print_space fmt (); 123 | pp_print_string fmt "(enodes "; 124 | pp_open_hvbox fmt 1; 125 | Hashtbl.to_seq self.hash_cons 126 | |> Seq.to_list 127 | |> pp_print_list ~pp_sep:pp_print_space 128 | (fun fmt (node, cls) -> 129 | pp_print_string fmt "("; 130 | pp_open_hvbox fmt 1; 131 | pp_id fmt cls; 132 | pp_print_space fmt (); 133 | ENode.pp ~pp_id fmt node; 134 | pp_close_box fmt (); 135 | pp_print_string fmt ")"; 136 | ) fmt; 137 | pp_close_box fmt (); 138 | pp_print_string fmt ")"; 139 | pp_close_box fmt (); 140 | pp_print_string fmt ")" 141 | 142 | let (.@[]) self fn = fn self [@@inline always] 143 | 144 | 145 | (* *** Initialization *) 146 | let init () = { 147 | version=0; 148 | uf=Id.create_store (); 149 | class_members=Id.Map.create 10; 150 | hash_cons=Hashtbl.create 10; 151 | worklist=Vector.create (); 152 | } 153 | 154 | (* *** Eclasses *) 155 | let new_class self = 156 | let id = Id.make self.uf () in 157 | id 158 | 159 | let get_class_members self id = 160 | match Id.Map.find_opt self.class_members id with 161 | | Some classes -> classes 162 | | None -> 163 | let cls = Vector.create () in 164 | Id.Map.add self.class_members id cls; 165 | cls 166 | 167 | (* Adds a node into the egraph, assuming that the cannonical version 168 | of the node is up to date in the hash cons or 169 | *) 170 | let add_enode self node = 171 | let node = ENode.canonicalise self.uf node in 172 | let id = match Hashtbl.find_opt self.hash_cons node with 173 | | None -> 174 | self.version <- self.version + 1; 175 | (* There are no nodes congruent to this node in the graph *) 176 | let id = self.@[new_class] in 
177 | let cls = self.@[get_class_members] id in 178 | Vector.append_list cls @@ List.map (fun child -> 179 | (node, child) 180 | ) (ENode.children node); 181 | Hashtbl.replace self.hash_cons node id; 182 | id 183 | | Some id -> id in 184 | Id.find self.uf id 185 | 186 | let rec subst self pat env = 187 | match pat with 188 | | Query.V id -> StringMap.find id env 189 | | Q (sym, args) -> 190 | let enode = (sym, List.map (fun arg -> self.@[subst] arg env) args) in 191 | self.@[add_enode] enode 192 | 193 | let add_node self ((sym, children) : Term.t) = 194 | add_enode self (Symbol.intern sym, children) 195 | 196 | let add_sexp self sexp = add_node self @@ Term.of_sexp (add_node self) sexp 197 | 198 | let find self vl = Id.find self.uf vl 199 | 200 | let append_to_worklist self vl = 201 | Vector.push self.worklist vl 202 | 203 | let merge self a b = 204 | let (+=) va vb = Vector.append va vb in 205 | let a = Id.find self.uf a in 206 | let b = Id.find self.uf b in 207 | if Id.eq_id a b then () 208 | else begin 209 | self.version <- self.version + 1; 210 | assert (Id.eq_id a (Id.union self.uf a b)); 211 | assert (Id.eq_id a (Id.find self.uf a)); 212 | assert (Id.eq_id a (Id.find self.uf b)); 213 | self.@[get_class_members] b += self.@[get_class_members] a; 214 | Vector.clear (self.@[get_class_members] a); 215 | self.@[append_to_worklist] b; 216 | end 217 | 218 | let repair self ecls_id = 219 | let (+=) va vb = Vector.append_iter va vb in 220 | let uses = self.@[get_class_members] ecls_id in 221 | let uses = 222 | let res = Vector.copy uses in 223 | Vector.clear uses; 224 | res in 225 | (* update canonical uses in hashcons *) 226 | Vector.to_iter uses (fun (p_node, p_eclass) -> 227 | Hashtbl.remove self.hash_cons p_node; 228 | let p_node = self.uf.@[ENode.canonicalise] p_node in 229 | Hashtbl.replace self.hash_cons p_node (self.@[find] p_eclass) 230 | ); 231 | let new_uses = Hashtbl.create 10 in 232 | Vector.to_iter uses (fun (p_node, p_eclass) -> 233 | let p_node = 
self.uf.@[ENode.canonicalise] p_node in 234 | begin match Hashtbl.find_opt new_uses p_node with 235 | | None -> () 236 | | Some nd -> self.@[merge] p_eclass nd 237 | end; 238 | Hashtbl.replace new_uses p_node (self.@[find] p_eclass) 239 | ); 240 | (self.@[get_class_members] (self.@[find] ecls_id)) += (Hashtbl.to_iter new_uses) 241 | 242 | let rebuild self = 243 | while not @@ Vector.is_empty self.worklist do 244 | let worklist = Id.Set.of_iter (Vector.to_iter self.worklist |> Iter.map (self.@[find])) in 245 | Vector.clear self.worklist; 246 | Id.Set.to_iter worklist (fun ecls_id -> 247 | self.@[repair] ecls_id 248 | ) 249 | done 250 | 251 | (* *** Exports *) 252 | (* **** Export eclasses *) 253 | let eclasses self = 254 | let r = Id.Map.create 10 in 255 | Hashtbl.iter (fun node eid -> 256 | let eid = Id.find self.uf eid in 257 | match Id.Map.find_opt r eid with 258 | | None -> let ls = Vector.of_list [node] in Id.Map.add r eid ls 259 | | Some ls -> Vector.push ls node 260 | ) self.hash_cons; 261 | r 262 | 263 | (* **** Export as dot *) 264 | let to_dot self = 265 | let eclasses = eclasses self in 266 | let stmt_list = 267 | let rev_map = 268 | Hashtbl.to_seq self.hash_cons 269 | |> Seq.map Pair.swap 270 | |> Id.Map.of_seq in 271 | let to_label id = 272 | let rec to_str id = 273 | match Id.Map.find_opt rev_map id with 274 | | None -> Format.to_string EClassId.pp id 275 | | Some (sym, []) -> Format.to_string Symbol.pp sym 276 | | Some (sym, children) -> 277 | Printf.sprintf "(%s %s)" 278 | (Format.to_string Symbol.pp sym) 279 | (List.to_string ~sep:" " to_str children) in 280 | to_str id in 281 | let to_label_node (sym,children) = 282 | match children with 283 | | [] -> Format.to_string Symbol.pp sym 284 | | children -> 285 | Printf.sprintf "(%s %s)" 286 | (Format.to_string Symbol.pp sym) 287 | (List.to_string ~sep:" " to_label children) in 288 | let to_id id = 289 | Odot.Double_quoted_id (to_label id) in 290 | let to_node_id node = 291 | Odot.Double_quoted_id 
(to_label_node node) in 292 | let to_subgraph_id id = 293 | Odot.Simple_id (Printf.sprintf "cluster_%d" (Id.repr id)) in 294 | let sub_graphs = 295 | (fun f -> Fun.flip Id.Map.iter eclasses (Fun.curry f)) 296 | |> Iter.map (fun (eclass, enodes) -> 297 | let nodes = 298 | Vector.to_iter enodes 299 | |> Iter.map (fun (node: enode) -> 300 | let node_id = to_node_id node in 301 | let attrs = Odot.[Simple_id "label", 302 | Some (Double_quoted_id 303 | (Format.to_string Symbol.pp (fst node)))] in 304 | Odot.Stmt_node ((node_id, None), attrs)) 305 | |> Iter.to_list in 306 | Odot.(Stmt_subgraph { 307 | sub_id= Some (to_subgraph_id eclass); 308 | sub_stmt_list= 309 | Stmt_attr ( 310 | Attr_graph [ 311 | (Simple_id "label", Some (Simple_id (Format.to_string EClassId.pp eclass))) 312 | ]) :: nodes; 313 | }) 314 | ) 315 | |> Iter.to_list in 316 | let edges = 317 | (fun f -> Fun.flip Id.Map.iter eclasses (Fun.curry f)) 318 | |> Iter.flat_map (fun (_eclass, enodes) -> 319 | Vector.to_iter enodes 320 | |> Iter.flat_map (fun node -> 321 | let label = to_node_id node in 322 | Iter.of_list (ENode.children node) 323 | |> Iter.map (fun child -> 324 | let child_label = to_id child in 325 | Odot.(Stmt_edge ( 326 | Edge_node_id (label, None), 327 | [Edge_node_id (child_label, None)], 328 | [] 329 | )) 330 | ) 331 | ) 332 | ) 333 | |> Iter.to_list in 334 | (List.append sub_graphs edges) in 335 | Odot.{ 336 | strict=true; 337 | kind=Digraph; 338 | id=None; 339 | stmt_list; 340 | } 341 | 342 | (* **** Print as dot *) 343 | let pp_dot fmt st = 344 | Format.pp_print_string fmt (Odot.string_of_graph (to_dot st)) 345 | 346 | let extract cost eg = 347 | let eclasses = eg.@[eclasses] in 348 | let cost_map = Id.Map.create 10 in 349 | let node_total_cost node = 350 | let has_cost id = Id.Map.mem cost_map (eg.@[find] id) in 351 | if List.for_all has_cost (Term.children node) 352 | then let cost_f id = fst @@ Id.Map.find cost_map (eg.@[find] id) in Some (cost cost_f node) 353 | else None in 354 | let 
make_pass enodes = 355 | let cost, node = 356 | Vector.to_iter enodes 357 | |> Iter.map (fun n -> (node_total_cost n, n)) 358 | |> Iter.min_exn ~lt:(fun (c1, _) (c2, _) -> 359 | (match c1, c2 with 360 | | None, None -> 0 361 | | Some _, None -> -1 362 | | None, Some _ -> 1 363 | | Some c1, Some c2 -> Float.compare c1 c2) = -1) in 364 | Option.map (fun cost -> (cost, node)) cost in 365 | let find_costs () = 366 | let any_changes = ref true in 367 | while !any_changes do 368 | any_changes := false; 369 | Fun.flip Id.Map.iter eclasses (fun eclass enodes -> 370 | let pass = make_pass enodes in 371 | match Id.Map.find_opt cost_map eclass, pass with 372 | | None, Some nw -> Id.Map.replace cost_map eclass nw; 373 | any_changes := true 374 | | Some ((cold, _)), Some ((cnew, _) as nw) 375 | when Float.compare cnew cold = -1 -> 376 | Id.Map.replace cost_map eclass nw; 377 | any_changes := true 378 | | _ -> () 379 | ) 380 | done in 381 | let rec extract eid = 382 | let eid = find eg eid in 383 | let enode = Id.Map.find cost_map eid |> snd in 384 | let head = Atom (Format.to_string Symbol.pp @@ fst enode) in 385 | match ENode.children enode with 386 | | [] -> head 387 | | children -> List (head :: List.map extract children) in 388 | find_costs (); 389 | fun result -> extract result 390 | 391 | 392 | (* ** Matching *) 393 | let ematch eg classes pattern = 394 | let concat_map f l = Iter.concat (Iter.map f l) in 395 | let rec enode_matches p enode env = 396 | match[@warning "-8"] p,enode with 397 | | Query.(Q (f, _), (f', _)) when not @@ (Equal.map Symbol.repr Equal.int) f f' -> 398 | Iter.empty 399 | | (Q (_, args), (_, args')) -> 400 | (fun f -> List.iter2 (Fun.curry f) args args') 401 | |> Iter.fold (fun envs (qvar, trm) -> 402 | concat_map (fun env' -> match_in qvar trm env') envs) (Iter.singleton env) 403 | and match_in p eid env = 404 | let eid = find eg eid in 405 | match p with 406 | | V id -> begin 407 | match StringMap.find_opt id env with 408 | | None -> 
Iter.singleton (StringMap.add id eid env) 409 | | Some eid' when Id.eq_id eid eid' -> Iter.singleton env 410 | | _ -> Iter.empty 411 | end 412 | | p -> 413 | match Id.Map.find_opt classes eid with 414 | | Some v -> Vector.to_iter v |> concat_map (fun enode -> enode_matches p enode env) 415 | | None -> Iter.empty 416 | in 417 | (fun f -> Id.Map.iter (Fun.curry f) classes) 418 | |> concat_map (fun (eid, _) -> 419 | Iter.map (fun s -> (eid, s)) (match_in pattern eid StringMap.empty)) 420 | 421 | (* ** Rewriting System *) 422 | let apply_rules eg rules = 423 | let eclasses = eclasses eg in 424 | let find_matches (from_rule, to_rule) = 425 | ematch eg eclasses from_rule |> Iter.map (lappend_pair to_rule) in 426 | let for_each_match = Iter.of_list rules |> Iter.flat_map find_matches in 427 | for_each_match begin fun (to_rule, eid, env) -> 428 | let new_eid = subst eg to_rule env in 429 | merge eg eid new_eid 430 | end; 431 | rebuild eg 432 | 433 | let run_until_saturation ?fuel eg rules = 434 | match fuel with 435 | | None -> 436 | let rec loop last_version = 437 | apply_rules eg rules; 438 | if not @@ Int.equal eg.version last_version 439 | then loop eg.version 440 | else () in 441 | loop eg.version; true 442 | | Some fuel -> 443 | let rec loop fuel last_version = 444 | apply_rules eg rules; 445 | if not @@ Int.equal eg.version last_version 446 | then if fuel > 0 447 | then loop (fuel - 1) eg.version 448 | else false 449 | else true in 450 | loop fuel eg.version 451 | 452 | 453 | end 454 | 455 | let%test "test egraph matching" = 456 | let g = EGraph.init () in 457 | let g1 = EGraph.add_sexp g [%s g 1] in 458 | let g2 = EGraph.add_sexp g [%s g 2] in 459 | EGraph.merge g g1 g2; 460 | EGraph.rebuild g; 461 | let query = Query.of_sexp [%s g "?a"] in 462 | let matches = EGraph.ematch g (EGraph.eclasses g) query |> Iter.to_list in 463 | (* Should have two matches: (g 1) and (g 2) *) 464 | Alcotest.(check int) "(g ?a) has 2 matches" 465 | 2 (List.length matches) 466 | 467 | 
let%test "test egraph matching" = 468 | let g = EGraph.init () in 469 | let g1 = EGraph.add_sexp g [%s g 1] in 470 | let g2 = EGraph.add_sexp g [%s g 2] in 471 | let g3 = EGraph.add_sexp g [%s g 3] in 472 | let f1 = EGraph.add_sexp g [%s (f 1 (g 2))] in 473 | let f2 = EGraph.add_sexp g [%s (f 2 (g 3))] in 474 | let f3 = EGraph.add_sexp g [%s (f 3 (g 1))] in 475 | EGraph.merge g g1 g2; 476 | EGraph.merge g g2 g3; 477 | EGraph.merge g f1 f2; 478 | EGraph.merge g f2 f3; 479 | EGraph.rebuild g; 480 | let query = Query.of_sexp [%s f "?a" (g "?a")] in 481 | let matches = EGraph.ematch g (EGraph.eclasses g) query |> Iter.to_list in 482 | Alcotest.(check int) "has 3 matches" 483 | 3 (List.length matches) 484 | -------------------------------------------------------------------------------- /lib/basic.mli: -------------------------------------------------------------------------------- 1 | type egraph 2 | 3 | module Symbol : sig 4 | type t = private int 5 | val intern : string -> t 6 | val to_string : t -> string 7 | end 8 | 9 | module Query : sig 10 | type t [@@deriving show] 11 | val of_sexp : Sexplib0.Sexp.t -> t 12 | val to_sexp : t -> Sexplib0.Sexp.t 13 | end 14 | 15 | module Rule : sig 16 | 17 | type t [@@deriving show] 18 | 19 | val make: from:Query.t -> into:Query.t -> t option 20 | 21 | end 22 | 23 | module EGraph : sig 24 | type t = egraph 25 | 26 | val pp : ?pp_id:(Format.formatter -> Id.t -> unit) -> Format.formatter -> t -> unit 27 | val pp_dot : Format.formatter -> t -> unit 28 | 29 | val init : unit -> t 30 | 31 | val add_sexp: t -> Sexplib.Sexp.t -> Id.t 32 | 33 | val to_dot : t -> Odot.graph 34 | 35 | val merge : t -> Id.t -> Id.t -> unit 36 | 37 | val rebuild : t -> unit 38 | 39 | val extract: ((Id.t -> float) -> (Symbol.t * Id.t list) -> float) -> t -> Id.t -> Sexplib0.Sexp.t 40 | 41 | val apply_rules : t -> Rule.t list -> unit 42 | 43 | val run_until_saturation: ?fuel:int -> t -> Rule.t list -> bool 44 | 45 | end 46 | 47 | 
-------------------------------------------------------------------------------- /lib/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name ego) 3 | (public_name ego) 4 | (libraries containers iter dot containers-data sexplib) 5 | (inline_tests) 6 | (preprocess (pps ppx_sexp ppx_inline_alcotest ppx_deriving.std))) 7 | 8 | (env 9 | (dev 10 | (flags (:standard -w -58)))) 11 | -------------------------------------------------------------------------------- /lib/ego.ml: -------------------------------------------------------------------------------- 1 | module Id = Id 2 | module Basic = Basic 3 | module Generic = struct 4 | module Query = Query 5 | module Scheduler = Scheduler 6 | include Language 7 | include Generic 8 | end 9 | -------------------------------------------------------------------------------- /lib/ego.mli: -------------------------------------------------------------------------------- 1 | (** Ego is an extensible egraph library for OCaml. The interface to 2 | Ego is loosely based on Rust's egg library and reimplements 3 | its EClass analysis in pure OCaml. 4 | 5 | {{:#top}Ego} provides two interfaces to its equality saturation 6 | engine: 7 | 8 | 1. {!Ego.Basic} - an out-of-the-box interface to pure equality 9 | saturation (i.e. supporting only syntactic rewrites). 10 | 11 | 2. {!Ego.Generic} - a higher order interface to equality saturation, 12 | parameterised over custom user-defined analyses. 13 | 14 | You may want to check out the {{:../index.html} quick start guide}. 15 | *) 16 | 17 | module Id : sig 18 | (** This module provides an implementation of an {i efficient} 19 | {b union-find} data-structure. Its main exported type, 20 | {!t}, is used to represent equivalence classes in the EGraph 21 | data-structures provided by {!Ego}. *) 22 | 23 | 24 | type t = private int 25 | (** An abstract datatype used to represent equivalence classes in 26 | {!Ego}.
*) 27 | 28 | end 29 | 30 | 31 | 32 | module Basic: sig 33 | 34 | (** This module implements a {i fairly efficient} 35 | "syntactic-rewrite-only" EGraph-based equality saturation engine 36 | that operates over Sexps. 37 | 38 | The main interface to EGraph is under the module {!EGraph}. 39 | 40 | Note: This module is not safe for serialization as it uses 41 | {!Symbol.t} internally to represent strings, and so will be 42 | dependent on the execution context. If you wish to persist 43 | EGraphs across executions, check out the EGraphs defined in 44 | {!Ego.Generic} *) 45 | 46 | module Symbol : sig 47 | (** Implements an efficient encoding of strings 48 | 49 | Note: Datatypes using this module are not safe for 50 | serialization as the tag associated with each string is dependent on 51 | the execution context. 52 | 53 | If you wish to persist EGraphs across executions, check out the 54 | EGraphs defined in {!Ego.Generic} *) 55 | 56 | type t = private int 57 | (** Abstract type providing an efficient encoding of some string value. *) 58 | 59 | val intern : string -> t 60 | (** [intern s] returns a symbol representing the string [s]. *) 61 | 62 | val to_string : t -> string 63 | (** [to_string t] returns the string associated with symbol [t]. *) 64 | end 65 | 66 | module Query : sig 67 | (** This module encodes patterns (for both matching and 68 | transformation) over Sexprs and is part of {!Ego.Basic}'s API 69 | for expressing syntactic rewrites. *) 70 | 71 | type t 72 | (** Encodes a pattern over S-expressions. *) 73 | 74 | val pp: Format.formatter -> t -> unit 75 | (** [pp fmt s] pretty prints the query [s]. *) 76 | 77 | val show: t -> string 78 | (** [show s] converts the query [s] to a string *) 79 | 80 | val of_sexp : Sexplib0.Sexp.t -> t 81 | (** [of_sexp s] builds a pattern from an s-expression 82 | 83 | Note: Any atom prefixed with "?" will be treated as a pattern 84 | variable.
85 | 86 | For example, the following pattern will match any multiplication expressions: 87 | {[ 88 | List [Atom "*"; Atom "?a"; Atom "?b"] 89 | ]} 90 | *) 91 | 92 | val to_sexp : t -> Sexplib0.Sexp.t 93 | (** [to_sexp s] converts a pattern back into an s-expression. This is idempotent with {!of_sexp}. *) 94 | 95 | end 96 | 97 | module Rule : sig 98 | (** This module encodes syntactic rewrite rules over Sexprs and is part of {!Ego.Basic}'s API 99 | for expressing syntactic rewrites. *) 100 | 101 | type t 102 | (** Encodes a rewrite rule over S-expressions. *) 103 | 104 | val pp: Format.formatter -> t -> unit 105 | (** [pp fmt r] pretty prints the rewrite rule [r]. *) 106 | 107 | val show: t -> string 108 | (** [show r] converts the rewrite rule [r] to a string *) 109 | 110 | 111 | val make: from:Query.t -> into:Query.t -> t option 112 | (** [make ~from ~into] builds a syntactic rewrite rule from a 113 | matching pattern [from] and a result pattern [into]. 114 | 115 | Iff [into] contains variables that are not bound in [from], 116 | then the rule is invalid, and the function will return [None]. *) 117 | 118 | end 119 | 120 | module EGraph : sig 121 | (** This module defines the main interface to the EGraph provided 122 | by {!Ego.Basic}. *) 123 | 124 | type t 125 | (** Represents a syntactic-rewrite-only EGraph that operates over 126 | Sexps. *) 127 | 128 | val pp : ?pp_id:(Format.formatter -> Id.t -> unit) -> Format.formatter -> t -> unit 129 | (** [pp ?pp_id fmt graph] prints an internal representation of the 130 | [graph]. 131 | 132 | {b Note}: This is primarily intended for debugging, and the 133 | output format is not guaranteed to remain consistent over 134 | versions. *) 135 | 136 | val pp_dot : Format.formatter -> t -> unit 137 | (** [pp_dot fmt graph] pretty prints [graph] in a Graphviz format. *) 138 | 139 | val init : unit -> t 140 | (** [init ()] creates a new EGraph. 
*) 141 | 142 | val add_sexp : t -> Sexplib0.Sexp.t -> Id.t 143 | (** [add_sexp graph sexp] adds [sexp] to [graph] and returns the 144 | equivalence class associated with the term. *) 145 | 146 | val to_dot : t -> Odot.graph 147 | (** [to_dot graph] converts [graph] into a Graphviz representation. *) 148 | 149 | val merge : t -> Id.t -> Id.t -> unit 150 | (** [merge graph id1 id2] merges the equivalence classes 151 | associated with [id1] and [id2]. 152 | 153 | {b Note}: If you call {!merge} manually, you must call 154 | {!rebuild} before running any queries or extraction. *) 155 | 156 | val rebuild : t -> unit 157 | (** [rebuild graph] restores the internal invariants of the EGraph 158 | [graph]. 159 | 160 | {b Note}: If you call {!merge} manually, you must call 161 | {!rebuild} before running any queries or extraction. *) 162 | 163 | val extract: ((Id.t -> float) -> (Symbol.t * Id.t list) -> float) -> t -> Id.t -> Sexplib0.Sexp.t 164 | (** [extract cost_fn graph] computes an extraction function [Id.t 165 | -> Sexplib0.Sexp.t] to extract terms (specified by [Id.t]) from 166 | the EGraph. 167 | 168 | [cost_fn f (sym,children)] should assign costs to the node 169 | with tag [sym] and children [children] - it can use [f] to 170 | determine the cost of a child. *) 171 | 172 | val apply_rules : t -> Rule.t list -> unit 173 | (** [apply_rules graph rules] runs each of the rewrites in [rules] 174 | exactly once over the egraph [graph] and then returns. *) 175 | 176 | val run_until_saturation: ?fuel:int -> t -> Rule.t list -> bool 177 | (** [run_until_saturation ?fuel graph rules] repeatedly applies each one 178 | of the rewrites in [rules] until no further changes occur ({i 179 | i.e equality saturation }), or until it runs out of [fuel]. 180 | 181 | It returns a boolean indicating whether it reached equality 182 | saturation or had to terminate early.
*) 183 | end 184 | 185 | end 186 | 187 | module Generic : sig 188 | 189 | (** This module implements a generic EGraph-based equality 190 | saturation engine that operates over arbitrary user-defined 191 | languages and provides support for extensible custom user-defined 192 | EClass analyses. 193 | 194 | The main interface to EGraph is provided by the functor {!Make} 195 | which constructs an EGraph given a {!LANGUAGE}, an {!ANALYSIS} and its 196 | {!ANALYSIS_OPS}. 197 | 198 | You may want to check out the {{:../../index.html} quick start 199 | guide}. *) 200 | 201 | 202 | type ('node, 'analysis, 'data, 'permission) egraph 203 | (** A generic representation of an EGraph, parameterised over the 204 | language term types ['node], analysis state ['analysis] and data 205 | ['data] and read permissions ['permission]. *) 206 | 207 | module StringMap : Map.S with type key = string (** Maps keyed by [string]; the [Id.t StringMap.t] values in this interface carry the eclasses of bound pattern variables. *) 208 | 209 | (** The module {!Query} encodes generic patterns (for both matching 210 | and transformation) over expressions and is part of 211 | {!Ego.Generic}'s API for expressing rewrites. *) 212 | module Query : sig 213 | 214 | type 'sym t 215 | (** Represents a query over expressions in a language with 216 | operators of type ['sym]. *) 217 | 218 | val of_sexp : (string -> 'a) -> Sexplib0.Sexp.t -> 'a t 219 | (** [of_sexp f s] constructs a query from a sexpression [s] using 220 | [f] to convert operator tags. *) 221 | 222 | val to_sexp : ('a -> string) -> 'a t -> Sexplib0.Sexp.t 223 | (** [to_sexp f q] converts a query [q] to a sexpression using [f] 224 | to convert operators in the query to strings. *) 225 | 226 | val pp : (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a t -> unit 227 | (** [pp f fmt q] pretty prints a query [q] using [f] to print the 228 | operators within the query. *) 229 | 230 | val show : (Format.formatter -> 'a -> unit) -> 'a t -> string 231 | (** [show f q] prints a query [q] to string using [f] to print 232 | operators within the query.
*) 233 | 234 | end 235 | 236 | (** The module {!Scheduler} provides implementations of some generic 237 | schedulers for Ego's equality saturation engine. 238 | 239 | See {!Make.BuildRunner} on how to compose a schedule with an 240 | EGraph definition. *) 241 | module Scheduler : sig 242 | 243 | (** The module {!Backoff} implements an exponential backoff 244 | scheduler. The scheduler works by tracking a maximum match 245 | limit, and (BEB) banning rules which exceed their limit. *) 246 | module Backoff : sig 247 | 248 | type t 249 | (** Represents the persistent state of the scheduler - it really 250 | just tracks the match limit and ban_length parameters chosen 251 | for this particular instantiation. *) 252 | 253 | type data 254 | (** Represents the metadata about rules tracked by the 255 | scheduler. *) 256 | 257 | val with_params : match_limit:int -> ban_length:int -> t 258 | (** [with_params ~match_limit ~ban_length] creates a new backoff 259 | scheduler with the threshold for banning rules set to 260 | [match_limit] and the length for which rules are banned set 261 | to [ban_length]. *) 262 | 263 | val default : unit -> t 264 | (** [default ()] returns a default instance of the backoff 265 | scheduler with the threshold for banning rules set to 1_000 266 | and the initial ban_length set to 5. *) 267 | 268 | (* */ *) 269 | 270 | val create_rule_metadata : t -> 'a -> data 271 | 272 | val should_stop : t -> int -> data Iter.t -> bool 273 | 274 | val guard_rule_usage : 275 | ('node, 'analysis, 'data, 'permission) egraph -> 276 | t -> data -> int -> 277 | (unit -> (Id.t * Id.t StringMap.t) Iter.t) -> 278 | (Id.t * Id.t StringMap.t) Iter.t 279 | 280 | end 281 | 282 | 283 | (** The module {!Simple} implements a scheduler that runs every 284 | rule each time - i.e applies no scheduling at all. This works 285 | fine for rewrite systems with a finite number of EClasses but 286 | can become a problem if the number of EClasses is too large or 287 | unbounded.
*) 288 | module Simple : sig 289 | type t 290 | type data 291 | val init : unit -> data 292 | val create_rule_metadata : t -> 'b -> data 293 | val should_stop : t -> int -> data -> bool 294 | val guard_rule_usage : 295 | ('node, 'analysis, 'data, 'permission) egraph -> 296 | t -> 297 | data -> 298 | int -> 299 | (data -> (Id.t * Id.t StringMap.t) Iter.t) -> 300 | (Id.t * Id.t StringMap.t) Iter.t 301 | end 302 | end 303 | 304 | (** {1:permissions Read/Write permissions} 305 | 306 | For convenience, the operations over the EGraph are split into 307 | those which {b read and write} to the graph [rw t] and those that 308 | are {b read-only} [ro t]. When defining the analysis operations, 309 | certain operations assume that the graph is not modified, so 310 | these annotations will help users to avoid violating the internal 311 | invariants of the data structure. 312 | *) 313 | 314 | type rw 315 | (** Encodes a read/write permission for a graph. *) 316 | 317 | type ro 318 | (** Encodes a read-only permission for a graph. *) 319 | 320 | (** {1:interfaces Interfaces} *) 321 | 322 | (** The {!LANGUAGE} module type represents the definition of an 323 | arbitrary language for use with an EGraph. *) 324 | module type LANGUAGE = sig 325 | 326 | type 'a shape 327 | (** Encodes the "shape" of an expression in the language over 328 | sub-expressions of type ['a]. *) 329 | 330 | type op 331 | (** Represents the tags that discriminate the expression 332 | constructors of the language. *) 333 | 334 | (** Represents concrete terms of the language by "tying-the-knot". *) 335 | type t = Mk of t shape [@@unboxed] 336 | 337 | val equal_op : op -> op -> bool 338 | (** [equal_op op1 op2] returns true if the operators [op1], [op2] are equal. *) 339 | 340 | val pp_shape : (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a shape -> unit 341 | (** [pp_shape f fmt s] pretty prints expressions of the language.
*) 342 | 343 | val compare_shape : ('a -> 'a -> int) -> 'a shape -> 'a shape -> int 344 | (** [compare_shape cmp a b] compares expressions [a] and [b] using [cmp] 345 | to compare subexpressions. *) 346 | 347 | val op : 'a shape -> op 348 | (** [op expr] retrieves the tag that discriminates the shape of 349 | the expression [expr]. *) 350 | 351 | val children : 'a shape -> 'a list 352 | (** [children exp] returns the subexpressions of expression [exp]. *) 353 | 354 | val map_children : 'a shape -> ('a -> 'b) -> 'b shape 355 | (** [map_children exp f] maps the function [f] over the 356 | sub-expressions of the expression [exp] *) 357 | 358 | val make : op -> 'a list -> 'a shape 359 | (** [make op ls] constructs an expression from the tag [op] and 360 | children [ls]. 361 | 362 | {b Note}: If called with invalid arity of arguments for the 363 | operator [op] the function may throw an error. *) 364 | 365 | end 366 | 367 | (** The module type {!ANALYSIS} encodes the data-types for an 368 | abstract EClass analysis over EGraphs. *) 369 | module type ANALYSIS = sig 370 | 371 | type t 372 | (** Represents any persistent state that an analysis may need to 373 | track separately from each EClass. 374 | 375 | {b Note}: Terms of this type must be mutated imperatively as 376 | the EGraph API doesn't provide any functions to functionally 377 | update the persisted state. *) 378 | 379 | type data 380 | (** Represents the additional analysis information that will be 381 | attached to each EClass. *) 382 | 383 | val pp_data : Format.formatter -> data -> unit 384 | (** [pp_data fmt data] pretty prints [data] using the formatter [fmt]. *) 385 | 386 | val show_data : data -> string 387 | (** [show_data data] converts [data] into a string. *) 388 | 389 | val equal_data : data -> data -> bool 390 | (** [equal_data d1 d2] returns true iff [d1], [d2] are equal. *) 391 | 392 | val default: data 393 | (** Represents a default abstract value for new nodes.
*) 394 | 395 | end 396 | 397 | (** The module type {!ANALYSIS_OPS} defines the main operations for 398 | an EClass analysis over an EGraph. *) 399 | module type ANALYSIS_OPS = sig 400 | 401 | type 'a t 402 | (** Represents the EGraph over which the analysis operates. *) 403 | 404 | type analysis 405 | (** Represents the persistent state of the analysis. *) 406 | 407 | type node 408 | (** Represents expressions of the language over which the analysis 409 | operates. *) 410 | 411 | type data 412 | (** Represents the additional analysis information that will be 413 | attached to each EClass. *) 414 | 415 | val make : ro t -> node -> data 416 | (** [make graph node] returns the analysis data for [node]. 417 | 418 | This function is called whenever a new node is added and 419 | should generate a new abstract value for the node, usually 420 | using the abstract values of its children. 421 | 422 | {b Note}: In terms of abstract interpretation, this function 423 | can be thought of as the "abstraction" function, mapping concrete 424 | terms to their corresponding values in the abstract domain. *) 425 | 426 | val merge : analysis -> data -> data -> data * (bool * bool) 427 | (** [merge st d1 d2] returns the analysis data that represents the 428 | combination of [d1] and [d2] and a tuple indicating whether the 429 | result differs from [d1] and/or [d2]. 430 | 431 | This function is called whenever two equivalence classes are 432 | merged and should produce a new abstract value that represents 433 | the merge of their corresponding abstract values. 434 | 435 | {b Note}: In terms of abstract interpretation, this function 436 | can be thought of as the least upper bound (lub), exposing the 437 | semi-lattice structure of the abstract domain. *) 438 | 439 | val modify : rw t -> Id.t -> unit 440 | (** [modify graph class] is used to introduce new children of an 441 | equivalence class whenever new information about its elements 442 | is found by the analysis.
443 | 444 | This function is called whenever the children or abstract 445 | values of an eclass are modified and may use the abstract value 446 | of its to modify the egraph. 447 | 448 | {b Note}: In terms of abstract interpretation, this function 449 | can be thought of the "abstraction" function, mapping concrete 450 | terms to their corresponding values in the abstract domain. *) 451 | 452 | end 453 | 454 | (** The module type {!COST} represents the definition of some 455 | arbitrary cost system for ranking expressions over some language. 456 | *) 457 | module type COST = sig 458 | 459 | type t 460 | (** Represents the type of a cost of a node. *) 461 | 462 | type node 463 | (** Represents terms of the language *) 464 | 465 | val compare : t -> t -> int 466 | (** [compare c1 c2] compares the costs [c1] and [c2] *) 467 | 468 | val cost : (Id.t -> t) -> node -> t 469 | (** [cost f node] should assign costs to the node [node]. It can 470 | use the provided function [f] to determine the cost of a 471 | child. *) 472 | 473 | end 474 | 475 | (** The module type {!SCHEDULER} represents the definition of some 476 | scheduling system for ranking rule applications during equality 477 | saturation. 478 | 479 | See {!Scheduler} for some generic schedulers. 480 | *) 481 | module type SCHEDULER = sig 482 | 483 | type 'p egraph 484 | (** Represents an EGraph with read/write permissions 485 | ['p]. *) 486 | 487 | type t 488 | (** Represents any persistent state of the scheduler that must be 489 | maintained separately to its rules. *) 490 | 491 | type data 492 | (** Represents metadata about a rule that the scheduler keeps 493 | track of in order to schedule rules. *) 494 | 495 | type rule 496 | (** Represents the type of rules over which this scheduler operates *) 497 | 498 | val default : unit -> t 499 | (** Create a default instance of the scheduler.
*) 500 | 501 | val should_stop: t -> int -> data Iter.t -> bool 502 | (** [should_stop scheduler iteration data] is called whenever the 503 | EGraph reaches saturation (with the rules that have been 504 | scheduled), and should return whether further iterations should 505 | be run (i.e we will be trying a different schedule) or whether 506 | we have actually truly reached saturation. *) 507 | 508 | val create_rule_metadata: t -> rule -> data 509 | (** [create_rule_metadata scheduler rule] returns the initial 510 | metadata for a rule [rule]. *) 511 | 512 | val guard_rule_usage: 513 | rw egraph -> t -> data -> int -> 514 | (unit -> (Id.t * Id.t StringMap.t) Iter.t) -> (Id.t * Id.t StringMap.t) Iter.t 515 | (** [guard_rule_usage graph scheduler data iteration 516 | gen_matches] is called before the execution of a particular 517 | rule (represented by the callback [gen_matches]), and should 518 | return a filtered set of matches according to the scheduling 519 | of the rule. *) 520 | 521 | end 522 | 523 | (** The module type {!GRAPH_API} represents the interface through which 524 | EClass analyses can interact with an EGraph. *) 525 | module type GRAPH_API = sig 526 | 527 | type 'p t 528 | (** Represents an EGraph with read permissions ['p]. *) 529 | 530 | type data 531 | (** Represents the additional analysis information that will be 532 | attached to each EClass. *) 533 | 534 | type analysis 535 | (** Represents the persistent state of the analysis. *) 536 | 537 | type 'a shape 538 | (** Represents the shape of expressions in the language. *) 539 | 540 | type node 541 | (** Represents concrete terms of expressions in the language *) 542 | 543 | val freeze : rw t -> ro t 544 | (** [freeze graph] returns a read-only reference to the EGraph. 545 | 546 | {b Note}: it is safe to modify [graph] after passing it to 547 | freeze, this method is mainly intended to allow using the 548 | read-only APIs of the EGraph when you have a RW instance of 549 | the EGraph.
*) 550 | 551 | val class_equal : ro t -> Id.t -> Id.t -> bool 552 | (** [class_equal graph cls1 cls2] returns true if and only if 553 | [cls1] and [cls2] are congruent in the EGraph [graph]. *) 554 | 555 | val iter_children : ro t -> Id.t -> Id.t shape Iter.t 556 | (** [iter_children graph cls] returns an iterator over the 557 | children of the EClass [cls]. *) 558 | 559 | val set_data : rw t -> Id.t -> data -> unit 560 | (** [set_data graph cls data] sets the analysis data for EClass 561 | [cls] in EGraph [graph] to be [data]. *) 562 | 563 | val get_data : ro t -> Id.t -> data 564 | (** [get_data graph cls] returns the analysis data for EClass 565 | [cls] in EGraph [graph]. *) 566 | 567 | val get_analysis: rw t -> analysis 568 | (** [get_analysis graph] returns the persistent analysis state 569 | for an EGraph. *) 570 | 571 | val add_node : rw t -> node -> Id.t 572 | (** [add_node graph term] adds the term [term] into the EGraph 573 | [graph] and returns the corresponding equivalence class. *) 574 | 575 | val merge : rw t -> Id.t -> Id.t -> unit 576 | (** [merge graph cls1 cls2] merges the two equivalence classes 577 | [cls1] and [cls2]. *) 578 | 579 | end 580 | 581 | (** This module type {!RULE} defines the rewrite interface for an 582 | EGraph, allowing users to express relatively complex 583 | transformations of expressions over some language. *) 584 | module type RULE = sig 585 | 586 | type t 587 | (** Represents rewrite rules over the language of the EGraph. *) 588 | 589 | type query 590 | (** Represents a pattern over the language of the EGraph - it can 591 | either be used to {i match} and {i bind} a particular 592 | subpattern in an expression, or can be used to express the 593 | output schema for a rewrite. *) 594 | 595 | type 'p egraph 596 | (** Represents an EGraph with read/write permissions 597 | ['p].
*) 598 | 599 | val make_constant : from:query -> into:query -> t 600 | (** [make_constant ~from ~into] creates a rewrite rule from a 601 | pattern [from] into a schema [into] that applies a purely 602 | syntactic transformation. *) 603 | 604 | val make_conditional : 605 | from:query -> 606 | into:query -> 607 | cond:(rw egraph -> Id.t -> Id.t StringMap.t -> bool) -> 608 | t 609 | (** [make_conditional ~from ~into ~cond] creates a syntactic 610 | rewrite rule from [from] to [into] that is conditionally 611 | applied based on some property [cond] of the EGraph, the root 612 | eclass of the sub-expression being transformed and the eclasses 613 | of all bound variables. *) 614 | 615 | val make_dynamic : 616 | from:query -> 617 | generator:(rw egraph -> Id.t -> Id.t StringMap.t -> query option) -> t 618 | (** [make_dynamic ~from ~generator] creates a dynamic rewrite 619 | rule from a pattern [from] into a schema that is 620 | conditionally generated based on properties of the EGraph, 621 | the root eclass of the sub-expression being transformed and 622 | the eclasses of all bound variables *) 623 | 624 | end 625 | 626 | (** {1:constructors EGraph Constructors} *) 627 | 628 | (** This functor {!MakePrinter} allows users to construct EGraph 629 | printing utilities for a given {!LANGUAGE} and {!ANALYSIS}. *) 630 | module MakePrinter : functor (L : LANGUAGE) (A : ANALYSIS) -> sig 631 | 632 | (* val pp : Format.formatter -> (Id.t L.shape, A.t, A.data, 'b) egraph -> unit 633 | * (\** [pp fmt graph] pretty prints an internal representation of 634 | * the graph. 635 | * 636 | * {b Note}: This is primarily intended for debugging, and the 637 | * output format is not guaranteed to remain consistent over 638 | * versions. *\) *) 639 | 640 | val to_dot : (Id.t L.shape, A.t, A.data, 'b) egraph -> Odot.graph 641 | (** [to_dot graph] converts an EGraph into a Graphviz 642 | representation for debugging. 
*) 643 | 644 | end 645 | 646 | (** This functor {!MakeExtractor} allows users to construct an 647 | EGraph extraction procedure for a given {!LANGUAGE} and {!COST} 648 | system. *) 649 | module MakeExtractor : functor 650 | (L : LANGUAGE) 651 | (E : COST with type node := Id.t L.shape) -> sig 652 | 653 | val extract : (Id.t L.shape, 'a, 'b, rw) egraph -> Id.t -> L.t 654 | (** [extract graph] computes an extraction function [Id.t -> 655 | L.t] to extract concrete terms of the language {!L} 656 | from their respective EClasses (specified by [Id.t]) from the 657 | EGraph according to the cost system {!E}. *) 658 | 659 | end 660 | 661 | 662 | (** This functor {!Make} serves as the main interface to Ego's 663 | generic EGraphs, and constructs an EGraph given a {!LANGUAGE}, an 664 | {!ANALYSIS} and its {!ANALYSIS_OPS}. *) 665 | module Make : 666 | functor 667 | (L : LANGUAGE) 668 | (A : ANALYSIS) 669 | (MakeAnalysisOps : functor 670 | (S : GRAPH_API with type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph 671 | and type analysis := A.t 672 | and type data := A.data 673 | and type 'a shape := 'a L.shape 674 | and type node := L.t) -> 675 | ANALYSIS_OPS with type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph 676 | and type analysis := A.t 677 | and type data := A.data 678 | and type node := Id.t L.shape) -> 679 | sig 680 | 681 | type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph 682 | (** This type represents an EGraph parameterised over a 683 | particular language {!L} and analysis {!A}. *) 684 | 685 | (** This module {!Rule} defines the rewrite interface for the 686 | EGraph, allowing users to express relatively complex 687 | transformations of expressions of the Language {!L}. *) 688 | module Rule: 689 | RULE with type query := L.op Query.t 690 | and type 'a egraph := (Id.t L.shape, A.t, A.data, 'a) egraph 691 | 692 | 693 | val freeze : rw t -> ro t 694 | (** [freeze graph] returns a read-only reference to the EGraph.
695 | 696 | {b Note}: it is safe to modify [graph] after passing it to 697 | freeze, this method is mainly intended to allow using the 698 | read-only APIs of the EGraph when you have a RW instance of 699 | the EGraph. *) 700 | 701 | val init : A.t -> 'p t 702 | (** [init analysis] creates a new EGraph with an initial 703 | persistent analysis state of [analysis]. *) 704 | 705 | val class_equal: ro t -> Id.t -> Id.t -> bool 706 | (** [class_equal graph cls1 cls2] returns true if and only if 707 | [cls1] and [cls2] are congruent in the EGraph [graph]. *) 708 | 709 | val set_data : rw t -> Id.t -> A.data -> unit 710 | (** [set_data graph cls data] sets the analysis data for EClass 711 | [cls] in EGraph [graph] to be [data]. *) 712 | 713 | val get_data : _ t -> Id.t -> A.data 714 | (** [get_data graph cls] returns the analysis data for EClass 715 | [cls] in EGraph [graph]. *) 716 | 717 | val get_analysis: rw t -> A.t 718 | (** [get_analysis graph] returns the persistent analysis state 719 | for an EGraph. *) 720 | 721 | val iter_children : ro t -> Id.t -> Id.t L.shape Iter.t 722 | (** [iter_children graph cls] returns an iterator over the 723 | elements of an eclass [cls]. *) 724 | 725 | (* val pp : Format.formatter -> (Id.t L.shape, 'a, A.data, _) egraph -> unit 726 | * (\** [pp fmt graph] pretty prints an internal representation of 727 | * the graph. 728 | * 729 | * {b Note}: This is primarily intended for debugging, and the 730 | * output format is not guaranteed to remain consistent over 731 | * versions. *\) *) 732 | 733 | val to_dot : (Id.t L.shape, A.t, A.data, _) egraph -> Odot.graph 734 | (** [to_dot graph] converts an EGraph into a Graphviz 735 | representation for debugging. *) 736 | 737 | val add_node : rw t -> L.t -> Id.t 738 | (** [add_node graph term] adds the term [term] into the EGraph 739 | [graph] and returns the corresponding equivalence class.
*) 740 | 741 | val merge : rw t -> Id.t -> Id.t -> unit 742 | (** [merge graph cls1 cls2] merges the two equivalence classes 743 | [cls1] and [cls2]. *) 744 | 745 | val rebuild : rw t -> unit 746 | (** [rebuild graph] restores the internal invariants of the 747 | graph. 748 | 749 | {b Note}: If you call {!merge} manually (i.e outside of 750 | analysis functions), you must call {!rebuild} before running 751 | any queries or extraction. *) 752 | 753 | val find_matches : ro t -> L.op Query.t -> (Id.t * Id.t StringMap.t) Iter.t 754 | (** [find_matches graph query] returns an iterator over each 755 | match of the query [query] in the EGraph. *) 756 | 757 | val apply_rules : (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> unit 758 | (** [apply_rules graph rules] runs each of the rewrites in [rules] 759 | exactly once over the egraph [graph] and then returns. *) 760 | 761 | val run_until_saturation: 762 | ?scheduler:Scheduler.Backoff.t -> 763 | ?node_limit:[`Bounded of int | `Unbounded] -> 764 | ?fuel:[`Bounded of int | `Unbounded] -> 765 | ?until:((Id.t L.shape, A.t, A.data, rw) egraph -> bool) -> 766 | (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> bool 767 | (** [run_until_saturation ?scheduler ?node_limit ?fuel ?until 768 | graph rules] repeatedly applies each one of the rewrites in [rules] 769 | according to the scheduler [scheduler] until no further 770 | changes occur ({i i.e equality saturation }), or until it 771 | runs out of [fuel] (defaults to 30) or reaches a [node_limit] 772 | if supplied (defaults to 10_000) or some predicate [until] is 773 | satisfied. 774 | 775 | It returns a boolean indicating whether it reached equality 776 | saturation or had to terminate early.
*) 777 | 778 | (** The module {!BuildRunner} allows users to supply their own 779 | custom domain-specific scheduling strategies for equality 780 | saturation by supplying a corresponding Scheduling module 781 | satisfying {!SCHEDULER} *) 782 | module BuildRunner (S : SCHEDULER 783 | with type 'a egraph := (Id.t L.shape, A.t, A.data, rw) egraph 784 | and type rule := Rule.t) : 785 | sig 786 | 787 | val run_until_saturation : 788 | ?scheduler:S.t -> 789 | ?node_limit:[`Bounded of int | `Unbounded] -> 790 | ?fuel:[`Bounded of int | `Unbounded] -> 791 | ?until:((Id.t L.shape, A.t, A.data, rw) egraph -> bool) -> 792 | (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> bool 793 | (** [run_until_saturation ?scheduler ?node_limit ?fuel 794 | ?until graph rules] repeatedly applies each one of the rewrites 795 | in [rules] according to the scheduler [scheduler] until 796 | no further changes occur ({i i.e equality saturation }), 797 | or until it runs out of [fuel] (defaults to 30) or 798 | reaches some [node_limit] (defaults to 10_000) or some 799 | predicate [until] is satisfied. 800 | 801 | It returns a boolean indicating whether it reached 802 | equality saturation or had to terminate early.
*) 803 | 804 | end 805 | 806 | end 807 | 808 | end 809 | -------------------------------------------------------------------------------- /lib/equivalence.ml: -------------------------------------------------------------------------------- 1 | open Containers 2 | (* module IntMap = Map.Make (Int) *) 3 | module IntMap = Hashtbl.Make (Int) 4 | module IntSet = CCHashSet.Make (Int) 5 | 6 | module Make = functor () -> struct 7 | (* Cells of the store: a [Root] marks a class representative, a [Link] points at the key of another cell. *) 8 | type elem = 9 | | Root of int 10 | | Link of elem_ref 11 | and elem_ref = int 12 | and store = { 13 | mutable limit: int; 14 | content: elem IntMap.t 15 | } 16 | 17 | type t = elem_ref 18 | 19 | let repr v = v 20 | 21 | let (.@[]) store rf = IntMap.find store.content rf 22 | let (.@[]<-) store rf vl = IntMap.replace store.content rf vl 23 | 24 | let create_store () = {limit=0; content=IntMap.create 100} 25 | 26 | let hash = Int.hash 27 | (* [rref store vl] stores [vl] under the next unused key [store.limit], bumps the limit and returns that key. *) 28 | let rref (store: store) vl = 29 | let x = store.limit in 30 | store.limit <- x + 1; 31 | IntMap.replace store.content x vl; 32 | x 33 | 34 | let make_raw = 35 | let id = ref 0 in 36 | fun () -> incr id; (Root !id) 37 | 38 | let make store () = 39 | rref store @@ make_raw () 40 | (* [find store x] follows [Link] cells from [x] until it reaches a [Root] and returns the key of that root; when the root differs from the direct target of [x], [x] is re-pointed straight at the root. *) 41 | let rec find store x = 42 | match store.@[x] with 43 | | Root _ -> x 44 | | Link y -> 45 | let z = find store y in 46 | if not @@ Equal.physical z y then 47 | store.@[x] <- Link z; 48 | z 49 | let equal store t1 t2 = 50 | let t1 = find store t1 in 51 | let t2 = find store t2 in 52 | Equal.physical t1 t2 53 | (* [link store x y] returns [x] unchanged when both keys are the same; otherwise both keys must hold [Root] cells, [y] is re-pointed at [x] and [x] is returned. The match is deliberately partial (warning 8 is silenced), so any other combination raises [Match_failure]. *) 54 | let link store x y = 55 | if Equal.physical x y then x 56 | else match[@warning "-8"] store.@[x], store.@[y] with 57 | | Root _, Root _ -> store.@[y] <- Link x; x 58 | (* if vx < vy then (store.@[x] <- Link y; y) 59 | * else if vy > vx then (store.@[y] <- Link x; x) 60 | * else (store.@[y] <- Link x; 61 | * store.@[x] <- make_raw (); 62 | * x) *) 63 | (* [union store x y] merges the classes of [x] and [y]: the representative of [x] becomes the representative of the merged class and is returned. *) 64 | let union store x y = 65 | let x = find store x in 66 | let y = find store y in 67 | link store x y 68 | 69 | module Map = IntMap 70 | 71 | module Set = IntSet
72 | 73 | end 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /lib/equivalence.mli: -------------------------------------------------------------------------------- 1 | module Make : functor () -> sig 2 | type store 3 | type t = private int 4 | 5 | val repr : t -> int (** [repr x] exposes the underlying integer of [x]. *) 6 | val create_store : unit -> store (** [create_store ()] returns a new, empty store. *) 7 | val make : store -> unit -> t (** [make store ()] allocates a fresh element in [store] that is its own class representative. *) 8 | val find : store -> t -> t (** [find store x] returns the representative of the class of [x]. *) 9 | 10 | val equal : store -> t -> t -> bool (** [equal store a b] holds when [a] and [b] have the same representative. *) 11 | val union : store -> t -> t -> t (** [union store a b] merges the classes of [a] and [b] and returns the representative of the merged class. *) 12 | val hash: t -> int 13 | 14 | module Map : Hashtbl.S with type key = t 15 | 16 | module Set : CCHashSet.S with type elt = t 17 | end 18 | -------------------------------------------------------------------------------- /lib/generic.ml: -------------------------------------------------------------------------------- 1 | open [@warning "-33"] Containers 2 | open Language 3 | open Types 4 | module Id = Id 5 | (* [dedup cmp vec] filters [vec] in place: an element is dropped when [cmp] returns 0 against the most recently kept element, so runs of equal neighbours collapse to their first member. *) 6 | let dedup cmp vec = 7 | let prev = ref None in 8 | Vector.filter_in_place (fun elt -> 9 | match !prev with 10 | | None -> prev := Some elt; true 11 | | Some last_value -> 12 | if Int.equal (cmp last_value elt) 0 13 | then false 14 | else begin 15 | prev := Some elt; 16 | true 17 | end 18 | ) vec 19 | 20 | 21 | (* let lappend_pair a (b,c) = (a,b,c) *) 22 | type 'a query = 'a Query.t 23 | 24 | type ('node, 'data) eclass = { 25 | mutable id: Id.t; 26 | nodes: 'node Vector.vector; 27 | mutable data: 'data; 28 | parents: ('node * Id.t) Vector.vector; 29 | } 30 | 31 | type ('node, 'analysis, 'data, 'permission) egraph = { 32 | mutable version: int; 33 | analysis: 'analysis; 34 | 35 | uf: Id.store; (* tracks equivalence classes of 36 | class ids *) 37 | class_data: 38 | ('node, 'data) eclass Id.Map.t; (* maps classes to the canonical nodes 39 | they contain, and any classes that are 40 | children of these nodes *) 41 | hash_cons: ('node, Id.t) Hashtbl.t; (* maps canonical nodes to their 42 | equivalence classes *) 43 | pending: ('node * Id.t) Vector.vector; 44 |
45 | pending_analysis: ('node * Id.t) Vector.vector; 46 | } 47 | 48 | 49 | module MakeInt (L: LANGUAGE) (* (A: ANALYSIS) *) = struct 50 | 51 | let (.@[]) self fn = fn self [@@inline always] 52 | 53 | (* *** Initialization *) 54 | let init analysis = { 55 | version=0; 56 | analysis; 57 | uf=Id.create_store (); 58 | class_data=Id.Map.create 10; 59 | hash_cons=Hashtbl.create 10; 60 | pending=Vector.create (); 61 | pending_analysis=Vector.create (); 62 | } 63 | 64 | (* *** Eclasses *) 65 | let get_analysis self = self.analysis 66 | (* [get_class_data self id] returns the eclass record stored under [id] exactly as given (no lookup through the union-find) and fails when there is none. *) 67 | let get_class_data self id = 68 | match Id.Map.find_opt self.class_data id with 69 | | Some data -> data 70 | | None -> failwith @@ Printf.sprintf "attempted to get the class data of an unbound class %s " (EClassId.show id) 71 | (* [remove_class_data self id] removes and returns the record stored under [id], or returns [None] when there is none. *) 72 | let remove_class_data self id = 73 | match Id.Map.find_opt self.class_data id with 74 | | Some classes -> Id.Map.remove self.class_data id; Some classes 75 | | None -> None 76 | (* [set_data self id data] overwrites the analysis data of the class stored under [id]. NOTE(review): unlike [get_data] this looks [id] up as given rather than through [Id.find]; presumably callers pass canonical ids - confirm. *) 77 | let set_data self id data = 78 | match Id.Map.find_opt self.class_data id with 79 | | None -> failwith @@ Printf.sprintf "attempted to set the data of an unbound class %s " (EClassId.show id) 80 | | Some class_data -> class_data.data <- data 81 | (* [get_data self id] returns the analysis data of the class stored under the representative of [id]. *) 82 | let get_data self id = 83 | match Id.Map.find_opt self.class_data (Id.find self.uf id) with 84 | | None -> failwith @@ Printf.sprintf "attempted to get the data of an unbound class %s " (EClassId.show id) 85 | | Some class_data -> class_data.data 86 | 87 | let canonicalise self node = L.map_children node (Id.find self.uf) 88 | 89 | let find self vl = Id.find self.uf vl 90 | 91 | (* *** Exports *) 92 | (* **** Export eclasses *) (* [eclasses self] builds a fresh map from each representative id to a vector of the [hash_cons] nodes whose class has that representative. *) 93 | let eclasses self = 94 | let r = Id.Map.create 10 in 95 | Hashtbl.iter (fun node eid -> 96 | let eid = Id.find self.uf eid in 97 | match Id.Map.find_opt r eid with 98 | | None -> let ls = Vector.of_list [node] in Id.Map.add r eid ls 99 | | Some ls -> Vector.push ls node 100 | ) self.hash_cons; 101 | r 102 | 103 | let class_equal self cls1 cls2 = 104 |
Id.equal self.uf cls1 cls2 105 | 106 | end 107 | 108 | module MakePrinter (L: LANGUAGE) (A: ANALYSIS) = struct 109 | 110 | open (MakeInt(L)) 111 | 112 | (* **** Export as dot *) 113 | let to_dot self = 114 | let eclasses = eclasses self in 115 | 116 | let pp_node_by_id fmt id = 117 | let pp_node_by_id fmt id = 118 | let id = self.@[find] id in 119 | begin 120 | let vls = Id.Map.find_opt eclasses id |> Option.get_lazy Vector.create in 121 | let open Format in 122 | pp_print_string fmt "{"; 123 | pp_open_hovbox fmt 1; 124 | Vector.pp 125 | ~pp_sep:(fun fmt () -> pp_print_string fmt ","; pp_print_space fmt ()) 126 | (L.pp_shape EClassId.pp) fmt vls; 127 | pp_close_box fmt (); 128 | pp_print_string fmt "}" 129 | end in 130 | pp_node_by_id fmt id in 131 | let stmt_list = 132 | let rev_map = 133 | Hashtbl.to_seq self.hash_cons 134 | |> Seq.map Pair.swap 135 | |> Id.Map.of_seq in 136 | let to_label id = 137 | let to_str id = 138 | match Id.Map.find_opt rev_map id with 139 | | None -> Format.to_string EClassId.pp id 140 | | Some node -> Format.to_string (L.pp_shape pp_node_by_id) node in 141 | to_str id in 142 | let to_id id = 143 | Odot.Double_quoted_id (to_label id) in 144 | let to_node_id node = 145 | Odot.Double_quoted_id (Format.to_string (L.pp_shape pp_node_by_id) node) in 146 | let to_subgraph_id id = 147 | Odot.Simple_id (Printf.sprintf "cluster_%d" (Id.repr id)) in 148 | let eclass_label eclass = 149 | let eclass_txt = Format.to_string EClassId.pp eclass in 150 | let data = get_data self eclass |> A.show_data in 151 | eclass_txt ^ " = " ^ data in 152 | let sub_graphs = 153 | (fun f -> Fun.flip Id.Map.iter eclasses (Fun.curry f)) 154 | |> Iter.map (fun (eclass, (enodes: (Id.t L.shape, _) Vector.t)) -> 155 | let nodes = 156 | Vector.to_iter enodes 157 | |> Iter.map (fun (node: Id.t L.shape) -> 158 | let node_id = to_node_id node in 159 | let attrs = Odot.[Simple_id "label", 160 | Some (Double_quoted_id 161 | (Format.to_string (L.pp_shape pp_node_by_id) node))] in 
162 | Odot.Stmt_node ((node_id, None), attrs)) 163 | |> Iter.to_list in 164 | Odot.(Stmt_subgraph { 165 | sub_id= Some (to_subgraph_id eclass); 166 | sub_stmt_list= 167 | Stmt_attr ( 168 | Attr_graph [ 169 | (Simple_id "label", Some (Double_quoted_id (eclass_label eclass))) 170 | ]) :: nodes; 171 | }) 172 | ) 173 | |> Iter.to_list in 174 | let edges = 175 | (fun f -> Fun.flip Id.Map.iter eclasses (Fun.curry f)) 176 | |> Iter.flat_map (fun (_eclass, enodes) -> 177 | Vector.to_iter enodes 178 | |> Iter.flat_map (fun node -> 179 | let label = to_node_id node in 180 | Iter.of_list (L.children node) 181 | |> Iter.map (fun child -> 182 | let child_label = to_id child in 183 | Odot.(Stmt_edge ( 184 | Edge_node_id (label, None), 185 | [Edge_node_id (child_label, None)], 186 | [] 187 | )) 188 | ) 189 | ) 190 | ) 191 | |> Iter.to_list in 192 | (List.append sub_graphs edges) in 193 | Odot.{ 194 | strict=true; 195 | kind=Digraph; 196 | id=None; 197 | stmt_list; 198 | } 199 | 200 | (* **** Print as dot *) 201 | let pp_dot fmt st = 202 | Format.pp_print_string fmt (Odot.string_of_graph (to_dot st)) 203 | 204 | end 205 | 206 | module MakeExtractor (L: LANGUAGE) (E: COST with type node := Id.t L.shape) = struct 207 | 208 | open (MakeInt(L)) 209 | (* extract eg: builds a table mapping each e-class to its cheapest known (cost, node) pair by iterating until no class improves, then returns a function that rebuilds the recorded term for a class id. Costs are ordered by the sign of E.compare (< 0), because a compare function is only required to return some negative integer, not exactly -1. *) 210 | let extract eg = 211 | let eclasses = eg.@[eclasses] in 212 | let cost_map = Id.Map.create 10 in 213 | let node_total_cost node = 214 | let has_cost id = Id.Map.mem cost_map (eg.@[find] id) in 215 | if List.for_all has_cost (L.children node) 216 | then let cost_f id = fst @@ Id.Map.find cost_map (eg.@[find] id) in Some (E.cost cost_f node) 217 | else None in 218 | let make_pass enodes = 219 | let cost, node = 220 | Vector.to_iter enodes 221 | |> Iter.map (fun n -> (node_total_cost n, n)) 222 | |> Iter.min_exn ~lt:(fun (c1, _) (c2, _) -> 223 | (match c1, c2 with 224 | | None, None -> 0 225 | | Some _, None -> -1 226 | | None, Some _ -> 1 227 | | Some c1, Some c2 -> E.compare c1 c2) < 0) in 228 | Option.map (fun cost -> (cost, node))
cost in 229 | let find_costs () = 230 | let any_changes = ref true in 231 | while !any_changes do 232 | any_changes := false; 233 | Fun.flip Id.Map.iter eclasses (fun eclass enodes -> 234 | let pass = make_pass enodes in 235 | match Id.Map.find_opt cost_map eclass, pass with 236 | | None, Some nw -> Id.Map.replace cost_map eclass nw; any_changes := true 237 | | Some ((cold, _)), Some ((cnew, _) as nw) 238 | when E.compare cnew cold < 0 -> 239 | Id.Map.replace cost_map eclass nw; any_changes := true 240 | | _ -> () 241 | ) 242 | done in 243 | let rec extract eid = 244 | let eid = eg.@[find] eid in 245 | let enode = Id.Map.find cost_map eid |> snd in 246 | let head = L.op enode in 247 | let children = L.children enode in 248 | L.Mk (L.make head @@ List.map extract children) in 249 | find_costs (); 250 | fun result -> extract result 251 | 252 | end 253 | 254 | (* ** Graph *) 255 | module MakeOps 256 | (L: LANGUAGE) 257 | (A: ANALYSIS) 258 | (AM: sig 259 | val make: (Id.t L.shape, A.t, A.data, ro) egraph -> Id.t L.shape -> A.data 260 | val merge: A.t -> A.data -> A.data -> A.data * (bool * bool) 261 | val modify: (Id.t L.shape, A.t, A.data, rw) egraph -> Id.t -> unit 262 | end) = 263 | struct 264 | 265 | open (MakeInt (L)) 266 | 267 | module Rule = struct 268 | 269 | type rule_output = 270 | | Constant of L.op Query.t 271 | | Conditional of 272 | L.op Query.t * 273 | ((Id.t L.shape, A.t, A.data, rw) egraph -> eclass_id -> eclass_id StringMap.t -> bool) 274 | | Dynamic of 275 | ((Id.t L.shape, A.t, A.data, rw) egraph -> eclass_id -> eclass_id StringMap.t -> L.op Query.t option) 276 | 277 | type t = L.op Query.t * rule_output 278 | 279 | let make_constant ~from ~into = (from, Constant into) 280 | let make_conditional ~from ~into ~cond = (from, Conditional (into, cond)) 281 | let make_dynamic ~from ~generator = (from, Dynamic generator) 282 | 283 | end 284 | 285 | let new_class self = 286 | let id = Id.make self.uf () in 287 | Id.Map.add self.class_data id {id;
nodes=Vector.create (); data=A.default; parents=Vector.create ()}; 288 | id 289 | 290 | let freeze (graph: (_, _, _, rw) egraph) = (graph:> (_, _, _, ro) egraph) 291 | 292 | (* Adds a node into the egraph, assuming that the canonical version 293 | of the node is up to date in the hash cons; returns the id of the 294 | class that holds it, creating a fresh class if none exists. *) 295 | let add_enode self (node: Id.t L.shape) = 296 | let node = self.@[canonicalise] node in 297 | let id = match Hashtbl.find_opt self.hash_cons node with 298 | | None -> 299 | self.version <- self.version + 1; 300 | let id = Id.make self.uf () in 301 | let cls = { 302 | id; 303 | nodes=Vector.of_list [node]; 304 | data = AM.make (freeze self) node; 305 | parents=Vector.create () 306 | } in 307 | 308 | List.iter (fun child -> 309 | let tup = (node, id) in 310 | Vector.push ((self.@[get_class_data] child).parents) tup 311 | ) (L.children node); 312 | 313 | Vector.push self.pending (node,id); 314 | 315 | Id.Map.add self.class_data id cls; 316 | 317 | Hashtbl.add self.hash_cons node id; 318 | 319 | AM.modify self id; 320 | id 321 | | Some id -> self.@[find] id in 322 | Id.find self.uf id 323 | 324 | let rec add_node self (L.Mk op: L.t) : Id.t = 325 | add_enode self @@ L.map_children op (add_node self) 326 | 327 | let rec subst self pat env = 328 | match pat with 329 | | Query.V id -> StringMap.find id env 330 | | Q (sym, args) -> 331 | let enode = L.make sym (List.map (fun arg -> self.@[subst] arg env) args) in 332 | self.@[add_enode] enode 333 | (* Merges the classes containing id1 and id2 (a no-op when they are already the same class), queueing the absorbed class's parents for hash-cons repair and any affected parents for analysis repair. *) 334 | let merge self id1 id2 = 335 | let (+=) va vb = Vector.append va vb in 336 | let id1 = Id.find self.uf id1 in 337 | let id2 = Id.find self.uf id2 in 338 | if Id.eq_id id1 id2 then () 339 | else begin 340 | self.version <- self.version + 1; 341 | (* swap so that id2 is the class with fewer (or equal) parents *) 342 | let id1, id2 = 343 | if Vector.length (self.@[get_class_data] id1).parents < Vector.length (self.@[get_class_data] id2).parents 344 | then (id2, id1) 345 | else (id1, id2) in 346 | 347 | (* make cls1 the new root; the union is performed outside the assert so that it is not skipped when assertions are disabled *) 348 | let root = Id.union self.uf id1 id2 in
349 | assert (Id.eq_id id1 root); 350 | let cls2 = self.@[remove_class_data] id2 351 | |> Option.get_exn_or "Invariant violation" in 352 | let cls1 = self.@[get_class_data] id1 in 353 | assert (Id.eq_id id1 cls1.id); 354 | 355 | self.pending += cls2.parents; 356 | 357 | let (did_update_cls1, did_update_cls2) = 358 | let data, res = (AM.merge self.analysis cls1.data cls2.data) in 359 | cls1.data <- data; 360 | res in 361 | 362 | if did_update_cls1 then self.pending_analysis += cls1.parents; 363 | if did_update_cls2 then self.pending_analysis += cls2.parents; 364 | 365 | cls1.nodes += cls2.nodes; 366 | cls1.parents += cls2.parents; 367 | AM.modify self id1 368 | end 369 | (* Re-canonicalises the nodes of every class, then sorts and de-duplicates them. *) 370 | let rebuild_classes self = 371 | Id.Map.to_seq_values self.class_data |> Seq.iter (fun cls -> 372 | Vector.map_in_place (fun node -> self.@[canonicalise] node) cls.nodes; 373 | Vector.sort' (L.compare_shape EClassId.compare) cls.nodes; 374 | dedup (L.compare_shape EClassId.compare) cls.nodes 375 | ) 376 | (* Repeatedly drains the pending worklist (re-canonicalising nodes in the hash cons and merging classes whose nodes collide) and then the pending_analysis worklist, until pending is empty. *) 377 | let process_unions self = 378 | (* let init_size = Hashtbl.length self.hash_cons in *) 379 | while not @@ Vector.is_empty self.pending do 380 | 381 | let rec update_hash_cons () = 382 | match Vector.pop self.pending with 383 | | None -> () 384 | | Some (node,cls) -> 385 | let old_node = node in 386 | let node = self.@[canonicalise] node in 387 | if not @@ ((L.compare_shape EClassId.compare old_node node) = 0) then 388 | Hashtbl.remove self.hash_cons old_node; 389 | begin match (Hashtbl.find_opt self.hash_cons node) with 390 | | None -> Hashtbl.add self.hash_cons node cls 391 | | Some memo_cls -> self.@[merge] memo_cls cls 392 | end; 393 | update_hash_cons () in 394 | update_hash_cons (); 395 | 396 | let rec update_analysis () = 397 | match Vector.pop self.pending_analysis with 398 | | None -> () 399 | | Some (node, class_id) -> 400 | let class_id = self.@[find] class_id in 401 | let node_data = AM.make (freeze self) node in 402 | let cls = self.@[get_class_data] class_id in 403 |
assert (Id.eq_id cls.id class_id); 404 | let (did_update_left, _did_update_right) = 405 | let data,res = AM.merge self.analysis cls.data node_data in 406 | cls.data <- data; 407 | res in 408 | if did_update_left then begin 409 | Vector.append self.pending_analysis cls.parents; 410 | AM.modify self class_id 411 | end; 412 | update_analysis () in 413 | update_analysis () 414 | done 415 | (* let _final_size = Hashtbl.length self.hash_cons in 416 | * print_endline @@ Printf.sprintf "after rebuilding size of nodes is %d => %d" init_size final_size *) 417 | 418 | let rebuild (self: (Id.t L.shape, 'b, 'c, rw) egraph) = 419 | process_unions self; 420 | rebuild_classes self 421 | 422 | (* ** Matching *) 423 | let ematch eg (classes: (Id.t L.shape, 'a) Vector.t Id.Map.t) pattern = 424 | let concat_map f l = Iter.concat (Iter.map f l) in 425 | let rec enode_matches p enode env = 426 | match[@warning "-8"] p with 427 | | Query.Q (f, _) when not @@ L.equal_op f (L.op enode) -> 428 | Iter.empty 429 | | Q (_, args) -> 430 | (fun f -> List.iter2 (Fun.curry f) args (L.children enode)) 431 | |> Iter.fold (fun envs (qvar, trm) -> 432 | concat_map (fun env' -> match_in qvar trm env') envs) (Iter.singleton env) 433 | and match_in p eid env = 434 | let eid = find eg eid in 435 | match p with 436 | | V id -> begin 437 | match StringMap.find_opt id env with 438 | | None -> Iter.singleton (StringMap.add id eid env) 439 | | Some eid' when Id.eq_id eid eid' -> Iter.singleton env 440 | | _ -> Iter.empty 441 | end 442 | | p -> 443 | match Id.Map.find_opt classes eid with 444 | | Some v -> Vector.to_iter v |> concat_map (fun enode -> enode_matches p enode env) 445 | | None -> Iter.empty 446 | in 447 | (fun f -> Id.Map.iter (Fun.curry f) classes) 448 | |> concat_map (fun (eid, _) -> 449 | Iter.map (fun s -> (eid, s)) (match_in pattern eid StringMap.empty)) 450 | 451 | let find_matches eg = 452 | let eclasses = eclasses eg in 453 | fun rule -> ematch eg eclasses rule 454 | 455 | let iter_children 
self cls = 456 | (* let old_cls = cls in *) 457 | let cls = (self.@[find] cls) in 458 | Id.Map.find_opt (eclasses self) cls |> Option.map Vector.to_iter |> Option.get_or ~default:Iter.empty 459 | 460 | module BuildRunner (S : SCHEDULER with type 'a egraph := (Id.t L.shape, A.t, A.data, rw) egraph 461 | and type rule := Rule.t) = struct 462 | 463 | (* ** Rewriting System *) 464 | let apply_rules scheduler iteration (eg: (Id.t L.shape, _, _, _) egraph) (rules : (Rule.t * S.data) array) = 465 | let find_matches = find_matches eg in 466 | let for_each_match = 467 | Iter.of_array rules 468 | |> Iter.flat_map (fun ((from_rule, to_rule), meta_data) -> 469 | S.guard_rule_usage eg scheduler meta_data iteration (fun () -> find_matches from_rule) 470 | |> Iter.map (fun (eid,env) -> (to_rule, eid, env)) 471 | ) in 472 | for_each_match begin fun (to_rule, eid, env) -> 473 | match to_rule with 474 | | Rule.Constant to_rule -> 475 | let new_eid = subst eg to_rule env in 476 | merge eg eid new_eid 477 | | Conditional (to_rule, cond) -> 478 | if cond eg eid env then 479 | let new_eid = subst eg to_rule env in 480 | merge eg eid new_eid 481 | else () 482 | | Dynamic cond -> 483 | match cond eg eid env with 484 | | None -> () 485 | | Some to_rule -> 486 | let new_eid = subst eg to_rule env in 487 | merge eg eid new_eid 488 | end; 489 | rebuild eg 490 | 491 | let run_until_saturation ?scheduler ?(node_limit=`Bounded 10_000) ?(fuel=`Bounded 30) ?until eg rules = 492 | let scheduler = match scheduler with None -> S.default () | Some scheduler -> scheduler in 493 | let rules = Iter.of_list rules 494 | |> Iter.map (fun rule -> (rule, S.create_rule_metadata scheduler rule)) 495 | |> Iter.to_array in 496 | let rule_data () = Array.to_iter rules |> Iter.map snd in 497 | match fuel, node_limit, until with 498 | | `Unbounded, `Unbounded, None -> 499 | let rec loop last_version ind = 500 | apply_rules scheduler ind eg rules; 501 | if not @@ Int.equal eg.version last_version 502 | then loop 
eg.version (ind + 1) 503 | else if S.should_stop scheduler ind (rule_data ()) then () else loop eg.version (ind + 1) in 504 | loop eg.version 0; true 505 | | `Unbounded, `Unbounded, Some pred -> 506 | let rec loop last_version ind = 507 | apply_rules scheduler ind eg rules; 508 | if not @@ Int.equal eg.version last_version 509 | then if pred eg then false else loop eg.version (ind + 1) 510 | else if S.should_stop scheduler ind (rule_data ()) then true else loop eg.version (ind + 1) in 511 | loop eg.version 0 512 | | `Unbounded, `Bounded node_limit, None -> 513 | let rec loop last_version ind = 514 | apply_rules scheduler ind eg rules; 515 | if not @@ Int.equal eg.version last_version 516 | then if Hashtbl.length eg.hash_cons < node_limit 517 | then loop eg.version (ind + 1) 518 | else false 519 | else if S.should_stop scheduler ind (rule_data ()) then true else loop eg.version (ind + 1) in 520 | loop eg.version 0 521 | | `Unbounded, `Bounded node_limit, Some pred -> 522 | let rec loop last_version ind = 523 | apply_rules scheduler ind eg rules; 524 | if not @@ Int.equal eg.version last_version 525 | then if Hashtbl.length eg.hash_cons < node_limit 526 | then if pred eg then false else loop eg.version (ind + 1) 527 | else false 528 | else if S.should_stop scheduler ind (rule_data ()) then true else loop eg.version (ind + 1) in 529 | loop eg.version 0 530 | | `Bounded fuel, `Unbounded, None -> 531 | let rec loop last_version ind = 532 | apply_rules scheduler ind eg rules; 533 | if not @@ Int.equal eg.version last_version 534 | then if fuel > ind 535 | then loop eg.version (ind + 1) 536 | else false 537 | else if S.should_stop scheduler ind (rule_data ()) then true else loop eg.version (ind + 1) in 538 | loop eg.version 0 539 | | `Bounded fuel, `Unbounded, Some pred -> 540 | let rec loop last_version ind = 541 | apply_rules scheduler ind eg rules; 542 | if not @@ Int.equal eg.version last_version 543 | then if fuel > ind 544 | then if pred eg then false else loop 
eg.version (ind + 1) 545 | else false 546 | else if S.should_stop scheduler ind (rule_data ()) then true else loop eg.version (ind + 1) in 547 | loop eg.version 0 548 | | `Bounded fuel, `Bounded node_limit, None -> 549 | let rec loop last_version ind = 550 | apply_rules scheduler ind eg rules; 551 | if not @@ Int.equal eg.version last_version 552 | then if fuel > ind && Hashtbl.length eg.hash_cons < node_limit 553 | then loop eg.version (ind + 1) 554 | else false 555 | else if S.should_stop scheduler ind (rule_data ()) then true else loop eg.version (ind + 1) in 556 | loop eg.version 0 557 | | `Bounded fuel, `Bounded node_limit, Some pred -> 558 | let rec loop last_version ind = 559 | apply_rules scheduler ind eg rules; 560 | if not @@ Int.equal eg.version last_version 561 | then if fuel > ind && Hashtbl.length eg.hash_cons < node_limit 562 | then if pred eg then false else loop eg.version (ind + 1) 563 | else false 564 | else if S.should_stop scheduler ind (rule_data ()) then true else loop eg.version (ind + 1) in 565 | loop eg.version 0 566 | 567 | end 568 | 569 | include (BuildRunner (Scheduler.Backoff)) 570 | 571 | let apply_rules (eg: (Id.t L.shape, _, _, _) egraph) (rules : Rule.t list) = 572 | let find_matches = find_matches eg in 573 | let for_each_match = 574 | Iter.of_list rules 575 | |> Iter.flat_map 576 | (fun (from_rule, to_rule) -> 577 | find_matches from_rule 578 | |> Iter.map (fun (eid,env) -> (to_rule, eid, env)) 579 | ) in 580 | for_each_match begin fun (to_rule, eid, env) -> 581 | match to_rule with 582 | | Rule.Constant to_rule -> 583 | let new_eid = subst eg to_rule env in 584 | merge eg eid new_eid 585 | | Conditional (to_rule, cond) -> 586 | if cond eg eid env then 587 | let new_eid = subst eg to_rule env in 588 | merge eg eid new_eid 589 | else () 590 | | Dynamic cond -> 591 | match cond eg eid env with 592 | | None -> () 593 | | Some to_rule -> 594 | let new_eid = subst eg to_rule env in 595 | merge eg eid new_eid 596 | end; 597 | rebuild 
eg 598 | end 599 | 600 | 601 | 602 | 603 | module Make 604 | (L: LANGUAGE) 605 | (A: ANALYSIS) 606 | (MakeAnalysisOps: functor 607 | (S: GRAPH_API 608 | with type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph 609 | and type analysis := A.t 610 | and type data := A.data 611 | and type 'a shape := 'a L.shape 612 | and type node := L.t) -> sig 613 | val make: (Id.t L.shape, A.t, A.data, ro) egraph -> Id.t L.shape -> A.data 614 | val merge: A.t -> A.data -> A.data -> A.data * (bool * bool) 615 | val modify: (Id.t L.shape, A.t, A.data, rw) egraph -> Id.t -> unit 616 | end) 617 | = struct 618 | 619 | 620 | module rec EGraph : sig 621 | type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph 622 | 623 | module Rule: sig 624 | type t 625 | val make_constant : from:L.op query -> into:L.op query -> t 626 | val make_conditional : 627 | from:L.op query -> 628 | into:L.op query -> 629 | cond:((Id.t L.shape, A.t, A.data, rw) egraph -> eclass_id -> eclass_id StringMap.t -> bool) -> 630 | t 631 | 632 | val make_dynamic : 633 | from:L.op query -> 634 | generator:((Id.t L.shape, A.t, A.data, rw) egraph -> 635 | eclass_id -> eclass_id StringMap.t -> L.op query option) -> 636 | t 637 | 638 | end 639 | 640 | val freeze : rw t -> ro t 641 | val init : A.t -> 'p t 642 | val class_equal: ro t -> eclass_id -> eclass_id -> bool 643 | val new_class : rw t -> eclass_id 644 | val set_data : rw t -> eclass_id -> A.data -> unit 645 | val get_data : _ t -> eclass_id -> A.data 646 | val get_analysis : rw t -> A.t 647 | val canonicalise : rw t -> Id.t L.shape -> Id.t L.shape 648 | val find : ro t -> eclass_id -> eclass_id 649 | (* val append_to_worklist : rw t -> eclass_id -> unit *) 650 | val eclasses: rw t -> (Id.t L.shape, Vector.rw) Vector.t Id.Map.t 651 | (* val pp : Format.formatter -> (Id.t L.shape, 'a, A.data, _) egraph -> unit *) 652 | val to_dot : (Id.t L.shape, A.t, A.data, _) egraph -> Odot.graph 653 | val pp_dot : Format.formatter -> (Id.t L.shape, A.t, A.data, _) egraph -> unit 654 | val 
add_node : rw t -> L.t -> eclass_id 655 | val merge : rw t -> eclass_id -> eclass_id -> unit 656 | val iter_children : ro t -> eclass_id -> Id.t L.shape Iter.t 657 | val rebuild : rw t -> unit 658 | 659 | val find_matches : ro t -> L.op query -> (eclass_id * eclass_id StringMap.t) Iter.t 660 | val apply_rules : (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> unit 661 | val run_until_saturation: 662 | ?scheduler:Scheduler.Backoff.t -> 663 | ?node_limit:[`Bounded of int | `Unbounded] -> 664 | ?fuel:[`Bounded of int | `Unbounded] -> 665 | ?until:((Id.t L.shape, A.t, A.data, rw) egraph -> bool) -> (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> bool 666 | 667 | module BuildRunner (S : SCHEDULER 668 | with type 'a egraph := (Id.t L.shape, A.t, A.data, rw) egraph 669 | and type rule := Rule.t) : 670 | sig 671 | val apply_rules : 672 | S.t -> 673 | int -> 674 | (Id.t L.shape, A.t, A.data, rw) egraph -> 675 | (Rule.t * S.data) array -> unit 676 | val run_until_saturation : 677 | ?scheduler:S.t -> 678 | ?node_limit:[`Bounded of int | `Unbounded] -> 679 | ?fuel:[`Bounded of int | `Unbounded] -> 680 | ?until:((Id.t L.shape, A.t, A.data, rw) egraph -> bool) -> 681 | (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> bool 682 | end 683 | end 684 | = struct 685 | let _unsafe = 10 686 | type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph 687 | include (MakeInt (L)) 688 | include (MakePrinter (L) (A)) 689 | include (MakeOps (L) (A) (Analysis)) 690 | end 691 | and Analysis : sig 692 | val make: (Id.t L.shape, A.t, A.data, ro) egraph -> Id.t L.shape -> A.data 693 | val merge: A.t -> A.data -> A.data -> A.data * (bool * bool) 694 | val modify: (Id.t L.shape, A.t, A.data, rw) egraph -> Id.t -> unit 695 | end = MakeAnalysisOps (EGraph) 696 | 697 | include EGraph 698 | 699 | end 700 | -------------------------------------------------------------------------------- /lib/generic.mli: -------------------------------------------------------------------------------- 1 
| open Language 2 | 3 | type ('node, 'analysis, 'data, 'permission) egraph 4 | 5 | module MakePrinter : functor (L : LANGUAGE) (A : ANALYSIS) -> sig 6 | (* val pp : Format.formatter -> (Id.t L.shape, A.t, A.data, 'b) egraph -> unit *) 7 | val to_dot : (Id.t L.shape, A.t, A.data, 'b) egraph -> Odot.graph 8 | val pp_dot : Format.formatter -> (Id.t L.shape, A.t, A.data, 'b) egraph -> unit 9 | end 10 | 11 | module MakeExtractor : functor 12 | (L : LANGUAGE) 13 | (E : COST with type node := Id.t L.shape) -> sig 14 | val extract : (Id.t L.shape, 'a, 'b, rw) egraph -> Id.t -> L.t 15 | end 16 | 17 | module Make : 18 | functor 19 | (L : LANGUAGE) 20 | (A : ANALYSIS) 21 | (MakeAnalysisOps : functor 22 | (S : GRAPH_API with type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph 23 | and type analysis := A.t 24 | and type 'a shape := 'a L.shape 25 | and type data := A.data 26 | and type node := L.t) -> 27 | ANALYSIS_OPS with type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph 28 | and type analysis := A.t 29 | and type data := A.data 30 | and type node := Id.t L.shape) -> 31 | sig 32 | type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph 33 | module Rule: 34 | RULE with type query := L.op Query.t 35 | and type 'a egraph := (Id.t L.shape, A.t, A.data, 'a) egraph 36 | 37 | val freeze : rw t -> ro t 38 | val init : A.t -> 'p t 39 | val class_equal: ro t -> Id.t -> Id.t -> bool 40 | val new_class : rw t -> Id.t 41 | val set_data : rw t -> Id.t -> A.data -> unit 42 | val get_data : _ t -> Id.t -> A.data 43 | val get_analysis: rw t -> A.t 44 | val canonicalise : rw t -> Id.t L.shape -> Id.t L.shape 45 | val find : ro t -> Id.t -> Id.t 46 | val eclasses: rw t -> (Id.t L.shape, Containers.Vector.rw) Containers.Vector.t Id.Map.t 47 | val iter_children : ro t -> Id.t -> Id.t L.shape Iter.t 48 | val to_dot : (Id.t L.shape, A.t, A.data, _) egraph -> Odot.graph 49 | val pp_dot : Format.formatter -> (Id.t L.shape, A.t, A.data, _) egraph -> unit 50 | val add_node : rw t -> L.t -> Id.t 51 | val 
merge : rw t -> Id.t -> Id.t -> unit 52 | val rebuild : rw t -> unit 53 | val find_matches : ro t -> L.op Query.t -> (Id.t * Id.t StringMap.t) Iter.t 54 | val apply_rules : (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> unit 55 | val run_until_saturation: 56 | ?scheduler:Scheduler.Backoff.t -> 57 | ?node_limit:[`Bounded of int | `Unbounded] -> 58 | ?fuel:[`Bounded of int | `Unbounded] -> 59 | ?until:((Id.t L.shape, A.t, A.data, rw) egraph -> bool) -> (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> bool 60 | 61 | module BuildRunner (S : SCHEDULER 62 | with type 'a egraph := (Id.t L.shape, A.t, A.data, rw) egraph 63 | and type rule := Rule.t) : 64 | sig 65 | val apply_rules : 66 | S.t -> 67 | int -> 68 | (Id.t L.shape, A.t, A.data, rw) egraph -> 69 | (Rule.t * S.data) array -> unit 70 | val run_until_saturation : 71 | ?scheduler:S.t -> 72 | ?node_limit:[`Bounded of int | `Unbounded] -> 73 | ?fuel:[`Bounded of int | `Unbounded] -> 74 | ?until:((Id.t L.shape, A.t, A.data, rw) egraph -> bool) -> 75 | (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> bool 76 | end 77 | 78 | end 79 | -------------------------------------------------------------------------------- /lib/id.ml: -------------------------------------------------------------------------------- 1 | open Containers 2 | 3 | include (Equivalence.Make ()) 4 | 5 | let eq_id = (Equal.map repr Equal.int) 6 | 7 | module OrderedSet = Ordered_set.Make (struct 8 | type nonrec t = t 9 | let equal = eq_id 10 | let hash = hash 11 | end) 12 | 13 | -------------------------------------------------------------------------------- /lib/id.mli: -------------------------------------------------------------------------------- 1 | type t = private int 2 | type store 3 | val eq_id : t -> t -> bool 4 | val repr : t -> int 5 | val hash : t -> int 6 | 7 | val create_store : unit -> store 8 | val make : store -> unit -> t 9 | val find : store -> t -> t 10 | val equal : store -> t -> t -> bool 11 | val union : 
store -> t -> t -> t 12 | 13 | module Map : Hashtbl.S with type key = t 14 | module Set : CCHashSet.S with type elt = t 15 | module OrderedSet : Ordered_set.S with type elt = t 16 | -------------------------------------------------------------------------------- /lib/language.ml: -------------------------------------------------------------------------------- 1 | open Containers 2 | module StringSet = Set.Make(String) 3 | module StringMap = Stdlib.Map.Make(String) 4 | 5 | let str p v = Format.to_string p v 6 | 7 | type sexp = Sexplib.Sexp.t = Atom of string | List of sexp list 8 | 9 | type rw = [`RW] 10 | type ro = [`RO] 11 | 12 | 13 | module type LANGUAGE = sig 14 | type 'a shape 15 | type op 16 | 17 | type t = Mk of t shape [@@unboxed] 18 | 19 | val equal_op: op -> op -> bool 20 | 21 | val pp_shape: (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a shape -> unit 22 | val compare_shape: ('a -> 'a -> int) -> 'a shape -> 'a shape -> int 23 | val op: 'a shape -> op 24 | val children: 'a shape -> 'a list 25 | val map_children: 'a shape -> ('a -> 'b) -> 'b shape 26 | val make : op -> 'a list -> 'a shape 27 | end 28 | 29 | module type ANALYSIS = sig 30 | type t 31 | type data [@@deriving show, eq] 32 | val default: data 33 | end 34 | 35 | module type ANALYSIS_OPS = sig 36 | type 'a t 37 | type analysis 38 | type node 39 | type data 40 | val make : ro t -> node -> data 41 | val merge : analysis -> data -> data -> data * (bool * bool) 42 | val modify : rw t -> Id.t -> unit 43 | end 44 | 45 | 46 | module type COST = sig 47 | type t 48 | type node 49 | val compare : t -> t -> int 50 | val cost : (Id.t -> t) -> node -> t 51 | end 52 | 53 | module type GRAPH_API = sig 54 | type 'p t 55 | 56 | type analysis 57 | type data 58 | type node 59 | type 'a shape 60 | 61 | val freeze : rw t -> ro t 62 | val class_equal : ro t -> Id.t -> Id.t -> bool 63 | val iter_children : ro t -> Id.t -> Id.t shape Iter.t 64 | val set_data : rw t -> Id.t -> data -> unit 65 | val get_data : 
ro t -> Id.t -> data 66 | val get_analysis : rw t -> analysis 67 | val add_node : rw t -> node -> Id.t 68 | val merge : rw t -> Id.t -> Id.t -> unit 69 | end 70 | 71 | module type RULE = sig 72 | type t 73 | type query 74 | type 'a egraph 75 | 76 | val make_constant : from:query -> into:query -> t 77 | val make_conditional : 78 | from:query -> 79 | into:query -> 80 | cond:(rw egraph -> Id.t -> Id.t StringMap.t -> bool) -> 81 | t 82 | 83 | val make_dynamic : 84 | from:query -> 85 | generator:(rw egraph -> Id.t -> Id.t StringMap.t -> query option) -> t 86 | 87 | end 88 | 89 | module type SCHEDULER = sig 90 | 91 | type 'a egraph 92 | 93 | type t 94 | 95 | type data 96 | 97 | type rule 98 | 99 | val default : unit -> t 100 | 101 | val should_stop: t -> int -> data Iter.t -> bool 102 | 103 | val create_rule_metadata: t -> rule -> data 104 | 105 | val guard_rule_usage: 106 | rw egraph -> t -> data -> int -> 107 | (unit -> (Id.t * Id.t StringMap.t) Iter.t) -> (Id.t * Id.t StringMap.t) Iter.t 108 | 109 | end 110 | -------------------------------------------------------------------------------- /lib/ordered_set.ml: -------------------------------------------------------------------------------- 1 | open Containers 2 | 3 | module type S = sig 4 | type t 5 | type elt 6 | val create : unit -> t 7 | val push : elt -> t -> unit 8 | val pop : t -> elt 9 | val pop_opt : t -> elt option 10 | val append : t -> elt list -> unit 11 | val clear : t -> unit 12 | val copy : t -> t 13 | val is_empty : t -> bool 14 | val length : t -> int 15 | val iter : (elt -> unit) -> t -> unit 16 | val fold : ('a -> elt -> 'a) -> 'a -> t -> 'a 17 | end 18 | 19 | module Make (Elt: Hashtbl.HashedType) : S with type elt = Elt.t = struct 20 | 21 | module Set = CCHashSet.Make(Elt) 22 | 23 | type t = { 24 | elts: Elt.t Queue.t; 25 | cache: Set.t 26 | } 27 | 28 | type elt = Elt.t 29 | 30 | let create () = {elts=Queue.create (); cache=Set.create 1} 31 | let push vl st = 32 | if Set.mem st.cache vl 33 | then () 
34 | else (Queue.push vl st.elts; Set.insert st.cache vl) 35 | 36 | let pop st = 37 | let hd = Queue.pop st.elts in 38 | Set.remove st.cache hd; 39 | hd 40 | 41 | let pop_opt st = 42 | match Queue.peek_opt st.elts with 43 | | None -> None 44 | | Some hd -> 45 | ignore @@ Queue.pop st.elts; 46 | Set.remove st.cache hd; 47 | Some hd 48 | 49 | let append st elts = List.iter (fun elt -> push elt st) elts 50 | 51 | let clear st = Queue.clear st.elts; Set.clear st.cache 52 | 53 | let copy st = {elts=Queue.copy st.elts; cache=Set.copy st.cache} 54 | 55 | let is_empty st = Queue.is_empty st.elts 56 | 57 | let length st = Queue.length st.elts 58 | 59 | let iter f st = Queue.iter f st.elts 60 | 61 | let fold f acc st = Queue.fold f acc st.elts 62 | 63 | end 64 | -------------------------------------------------------------------------------- /lib/ordered_set.mli: -------------------------------------------------------------------------------- 1 | module type S = sig 2 | type t 3 | type elt 4 | val create : unit -> t 5 | val push : elt -> t -> unit 6 | val pop : t -> elt 7 | val pop_opt : t -> elt option 8 | val append : t -> elt list -> unit 9 | val clear : t -> unit 10 | val copy : t -> t 11 | val is_empty : t -> bool 12 | val length : t -> int 13 | val iter : (elt -> unit) -> t -> unit 14 | val fold : ('a -> elt -> 'a) -> 'a -> t -> 'a 15 | end 16 | 17 | module Make: 18 | functor (Elt : Containers.Hashtbl.HashedType) -> S with type elt = Elt.t 19 | -------------------------------------------------------------------------------- /lib/query.ml: -------------------------------------------------------------------------------- 1 | open Containers 2 | open Language 3 | 4 | type 'sym t = 5 | | V of string 6 | | Q of 'sym * 'sym t list 7 | 8 | let rec of_sexp intern : sexp -> _ t = function 9 | | Atom str when String.prefix ~pre:"?" 
str -> V (String.drop 1 str) 10 | | Atom sym -> Q (intern sym, []) 11 | | List (Atom sym :: children) -> 12 | Q (intern sym, List.map (of_sexp intern) children) 13 | | _ -> invalid_arg "Query sexp not of the expected form" 14 | 15 | let rec to_sexp to_string : _ t -> sexp = function 16 | | V str -> Atom ("?" ^ str) 17 | | Q (head, children) -> List (Atom (to_string head) :: List.map (to_sexp to_string) children) 18 | 19 | let rec pp symbol_pp fmt = function 20 | | V sym -> Format.pp_print_string fmt ("?" ^ sym) 21 | | Q (sym, []) -> symbol_pp fmt sym 22 | | Q (sym, children) -> 23 | let open Format in 24 | pp_print_string fmt "("; 25 | pp_open_hvbox fmt 1; 26 | symbol_pp fmt sym; 27 | pp_print_space fmt (); 28 | pp_print_list ~pp_sep:pp_print_space (pp symbol_pp) fmt children; 29 | pp_close_box fmt (); 30 | pp_print_string fmt ")" 31 | 32 | let show symbol_pp = str (pp symbol_pp) 33 | 34 | let%test "terms are printed as expected" = 35 | Alcotest.(check string) 36 | "prints as expected" 37 | "(+ 1 ?a)" (str (pp Symbol.pp) (Q (Symbol.intern "+", [Q (Symbol.intern "1", []); V "a"]))) 38 | 39 | let variables query = 40 | let rec loop acc = 41 | function 42 | V sym -> StringSet.add sym acc 43 | | Q (_, children) -> 44 | List.fold_left loop acc children in 45 | loop StringSet.empty query 46 | 47 | -------------------------------------------------------------------------------- /lib/scheduler.ml: -------------------------------------------------------------------------------- 1 | open Containers 2 | open Language 3 | 4 | module Backoff = struct 5 | 6 | type t = {match_limit: int; ban_length: int} 7 | 8 | type data = { 9 | mutable times_applied: int; 10 | mutable banned_until: int; 11 | mutable times_banned: int; 12 | mutable match_limit: int; 13 | mutable ban_length: int; 14 | } 15 | 16 | let with_params ~match_limit ~ban_length = {match_limit; ban_length} 17 | 18 | let default () : t = { 19 | match_limit = 1_000; 20 | ban_length = 5; 21 | } 22 | 23 | let 
create_rule_metadata ({match_limit; ban_length}: t) _ = { 24 | times_applied = 0; 25 | banned_until = 0; 26 | times_banned = 0; 27 | match_limit; 28 | ban_length; 29 | } 30 | (* [should_stop _ iteration stats] returns [true] when no entry of [stats] has [banned_until > iteration]. Otherwise it subtracts from every such entry the gap between [iteration] and the smallest [banned_until] among them, and returns [false]. *) 31 | let should_stop _ iteration stats = 32 | let banned = stats 33 | |> Iter.filter (fun data -> data.banned_until > iteration) 34 | |> Iter.to_array in 35 | 36 | if Array.length banned = 0 37 | then true 38 | else begin 39 | let min_ban = 40 | Iter.of_array banned 41 | |> Iter.map (fun data -> data.banned_until) 42 | |> Iter.min_exn ~lt:Int.(<) in 43 | let delta = min_ban - iteration in 44 | (* shift all pending bans earlier by [delta] *) 45 | Iter.of_array banned 46 | |> Iter.iter (fun data -> data.banned_until <- data.banned_until - delta) ; 47 | 48 | false 49 | end 50 | 51 | (* [guard_rule_usage _ _ data iteration gen_matches] yields no matches while [iteration < data.banned_until], without calling [gen_matches]. Otherwise it collects the matches into an array; if there are more than [data.match_limit lsl data.times_banned] of them it records a ban lasting [data.ban_length lsl data.times_banned] iterations and yields nothing, else it increments [times_applied] and yields the collected matches. *) 52 | let guard_rule_usage _ (_ : t) (data: data) iteration 53 | (gen_matches: (unit -> (Id.t * Id.t StringMap.t) Iter.t)) : 54 | (Id.t * Id.t StringMap.t) Iter.t = 55 | if iteration < data.banned_until 56 | then Iter.empty 57 | else begin 58 | let elts = Iter.to_array (gen_matches ()) in 59 | let total_len = Array.length elts in 60 | let threshold = data.match_limit lsl data.times_banned in 61 | if total_len > threshold 62 | then begin 63 | let ban_length = data.ban_length lsl data.times_banned in 64 | data.times_banned <- data.times_banned + 1; 65 | data.banned_until <- iteration + ban_length; 66 | Iter.empty 67 | end 68 | else begin 69 | data.times_applied <- data.times_applied + 1; 70 | Iter.of_array elts 71 | end 72 | 73 | end 74 | 75 | end 76 | (* [Simple] keeps no per-rule state: [should_stop] always answers [true] and [guard_rule_usage] returns [gen_matches ()] unchanged. *) 77 | module Simple = struct 78 | 79 | 80 | type t = unit 81 | 82 | type data = unit 83 | 84 | let init () : t = () 85 | 86 | let create_rule_metadata _ _ = () 87 | 88 | let should_stop _ _iteration _stats = true 89 | 90 | let guard_rule_usage _ (_ : t) ((): data) _iteration 91 | (gen_matches: (unit -> (Id.t * Id.t StringMap.t) Iter.t)) : (Id.t * Id.t StringMap.t) Iter.t = 92 | gen_matches () 93 | 94 | end 95 | -------------------------------------------------------------------------------- /lib/symbol.ml:
-------------------------------------------------------------------------------- 1 | open Containers 2 | 3 | type t = int 4 | 5 | module SymbolMap = Map.Make(Int) 6 | module StrMap = Hashtbl.Make(String) 7 | 8 | let repr v = v 9 | let tbl = StrMap.create 10 10 | let strs = CCVector.create () 11 | 12 | let intern str = 13 | match StrMap.find_opt tbl str with 14 | | Some id -> id 15 | | None -> 16 | let id = Vector.length strs in 17 | Vector.push strs str; 18 | StrMap.add tbl str id; 19 | id 20 | 21 | let pp fmt s = 22 | Format.pp_print_string fmt (Vector.get strs s) 23 | 24 | let to_string s = Vector.get strs s 25 | -------------------------------------------------------------------------------- /lib/symbol.mli: -------------------------------------------------------------------------------- 1 | open Containers 2 | type t = private int 3 | val repr : t -> int 4 | val intern : string -> t 5 | val pp : Format.formatter -> t -> unit 6 | val to_string: t -> string 7 | 8 | module SymbolMap : Map.S with type key = t 9 | -------------------------------------------------------------------------------- /lib/term.ml: -------------------------------------------------------------------------------- 1 | open Containers 2 | 3 | type t = string * Id.t list 4 | 5 | type op = string 6 | let equal_op = String.equal 7 | let pp_op = String.pp 8 | 9 | let pp pp_children fmt = function 10 | | (sym, []) -> Format.pp_print_string fmt sym 11 | | (sym, children) -> 12 | let open Format in 13 | pp_print_string fmt "("; 14 | pp_open_hvbox fmt 1; 15 | pp_print_string fmt sym; 16 | pp_print_space fmt (); 17 | pp_print_list ~pp_sep:pp_print_space pp_children fmt children; 18 | pp_close_box fmt (); 19 | pp_print_string fmt ")" 20 | 21 | let compare = 22 | Pair.compare String.compare 23 | (List.compare (fun id1 id2 -> 24 | Fun.uncurry Int.compare @@ Pair.map_same Id.repr (id1,id2))) 25 | 26 | let op = fst 27 | let children = snd 28 | let map_children t f = Pair.map_snd (List.map f) t 29 | let make 
= Pair.make 30 | 31 | let show pp_children = Format.to_string (pp pp_children) 32 | let rec of_sexp f : Sexplib0.Sexp.t -> t = 33 | function 34 | | Atom str -> (str, []) 35 | | List (Atom head :: tail) -> 36 | (head, List.to_iter tail |> Iter.map (of_sexp f) |> Iter.map f |> Iter.to_list) 37 | | _ -> failwith "invalid sexp structure" 38 | 39 | -------------------------------------------------------------------------------- /lib/types.ml: -------------------------------------------------------------------------------- 1 | open Containers 2 | 3 | type eclass_id = Id.t 4 | type enode = Symbol.t * eclass_id list 5 | 6 | let str p v = Format.to_string p v 7 | 8 | (* ** ID *) 9 | module EClassId = struct 10 | type t = eclass_id 11 | let pp fmt id = 12 | Format.pp_print_string fmt @@ Printf.sprintf "e%d" (Id.repr id) 13 | let show = str pp 14 | 15 | let compare (a:t) (b: t) = 16 | Int.compare (a :> int) (b :> int) 17 | 18 | let%test "IDs print correctly" = 19 | let store = Id.create_store () in 20 | Alcotest.(check string) 21 | "should pretty print as e0" 22 | "e0" (str pp (Id.make store ())) 23 | 24 | end 25 | 26 | 27 | (* ** Node *) 28 | module ENode = struct 29 | 30 | type t = enode 31 | 32 | let children (_, children) = children 33 | 34 | let canonicalise uf (sym, children) = 35 | (sym, List.map (Id.find uf) children) 36 | 37 | let hash : enode Hash.t = Hash.(pair poly (list Id.hash)) 38 | 39 | let%test "node hashes correctly" = 40 | let store = Id.create_store () in 41 | let i1 = Id.make store () in 42 | Alcotest.(check int) 43 | "hash values should match" 44 | (hash (Symbol.intern "example", [i1])) 45 | (hash (Symbol.intern "example", [i1])) 46 | 47 | let%test "node hashes correctly after union" = 48 | let store = Id.create_store () in 49 | let i1 = Id.make store () in 50 | let i2 = Id.make store () in 51 | let hash_1 = hash (Symbol.intern "example", [i1]) in 52 | ignore @@ Id.union store i1 i2; 53 | let hash_2 = hash (Symbol.intern "example", [i1]) in 54 | 
Alcotest.(check int) 55 | "hash values should match" 56 | hash_1 57 | hash_2 58 | 59 | let equal : enode Equal.t = Equal.(pair poly (list Id.eq_id)) 60 | 61 | let pp ?(pp_id=EClassId.pp) fmt (sym, children) = 62 | match children with 63 | | [] -> Symbol.pp fmt sym 64 | | children -> 65 | let open Format in 66 | pp_print_string fmt "("; 67 | pp_open_hvbox fmt 1; 68 | Symbol.pp fmt sym; 69 | pp_print_space fmt (); 70 | pp_print_list ~pp_sep:(pp_print_space) pp_id fmt children; 71 | pp_close_box fmt (); 72 | pp_print_string fmt ")" 73 | 74 | let%test "leaf nodes prints correctly" = 75 | Alcotest.(check string) 76 | "should pretty print as sexp" 77 | "example" 78 | (str (pp ~pp_id:EClassId.pp) 79 | (Symbol.intern "example", [])) 80 | 81 | let%test "node prints correctly" = 82 | let store = Id.create_store () in 83 | Alcotest.(check string) 84 | "should pretty print as sexp" 85 | "(example e0 e1 e2)" 86 | (str (pp ~pp_id:EClassId.pp) 87 | (Symbol.intern "example", 88 | List.init 3 (fun _ -> Id.make store ())) 89 | ) (* Sets of e-nodes. Elements are ordered by [hash] first, so the relative order of nodes with different hashes is unchanged; nodes whose hashes collide are then ordered structurally, so two distinct nodes are never treated as the same element. *) 90 | module Set = Set.Make (struct 91 | type t = enode 92 | let compare n1 n2 = let by_hash = Int.compare (hash n1) (hash n2) in (* [Stdlib.compare] is named explicitly because [open Containers] shadows the polymorphic [compare] *) if by_hash <> 0 then by_hash else Stdlib.compare n1 n2 93 | end) 94 | 95 | end 96 | -------------------------------------------------------------------------------- /macros/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name ppx_sexp) 3 | (modules ppx_sexp) 4 | (kind ppx_rewriter) 5 | (libraries ppxlib sexplib) 6 | (preprocess (pps ppxlib.metaquot ppx_deriving.std))) 7 | -------------------------------------------------------------------------------- /macros/ppx_sexp.ml: -------------------------------------------------------------------------------- 1 | open Ppxlib 2 | 3 | let name = "s" 4 | 5 | let build_atom ~loc v = 6 | let str = Ast_helper.(Exp.constant ~loc (Pconst_string (v, loc, None))) in 7 | [%expr Sexplib0.Sexp.Atom [%e str ]] 8 | 9 | let build_list ~loc ls = 10 | let rec build_ls = function 11 | | [] -> [%expr []] 12 | | h :: t ->
[%expr [%e h] :: [%e build_ls t]] in 13 | [%expr Sexplib0.Sexp.List [%e (build_ls ls)]] 14 | 15 | let rec convert ~loc expr = 16 | match expr with 17 | | { pexp_desc=Pexp_ident { txt=Lident "()"; _ }; pexp_loc=loc; _ } -> 18 | build_list ~loc [] 19 | | { pexp_desc=Pexp_ident { txt=Lident txt; _ }; pexp_loc=loc; _ } 20 | when txt.[0] = '(' && txt.[String.length txt - 1] = ')' -> 21 | build_atom ~loc (String.sub txt 1 (String.length txt - 2)) 22 | | { pexp_desc=Pexp_ident { txt=Lident txt; _ }; pexp_loc=loc; _ } -> 23 | build_atom ~loc txt 24 | | { pexp_desc=Pexp_constant const; pexp_loc=loc; _ } -> 25 | let const = match const with 26 | | Pconst_integer (txt, _) -> txt 27 | | Pconst_char cr -> String.make 1 cr 28 | | Pconst_string (txt, _, _) -> txt 29 | | Pconst_float (txt, _) -> txt in 30 | build_atom ~loc const 31 | | { pexp_desc=Pexp_apply (expr, args); pexp_loc=loc; _ } 32 | when List.for_all (function (Nolabel, _) -> true | _ -> false) args -> 33 | let h = convert ~loc:expr.pexp_loc expr in 34 | let t = List.map (fun (_, expr) -> convert ~loc:expr.pexp_loc expr) args in 35 | build_list ~loc (h :: t) 36 | | [%expr [%e? x] x ] -> x 37 | | e -> 38 | let exp = Pprintast.expression Format.str_formatter e; Format.flush_str_formatter () in 39 | Location.raise_errorf ~loc "use of unsupported syntactic construct %s" exp 40 | 41 | let expand ~loc ~path:_ expr = convert ~loc expr 42 | 43 | let ext = 44 | Extension.declare name Extension.Context.expression 45 | Ast_pattern.(single_expr_payload __) 46 | expand 47 | 48 | let () = Driver.register_transformation name ~extensions:[ext] 49 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Ego - EGraphs in OCaml 2 | 3 | Ego (EGraphs OCaml) is an OCaml library that provides generic equality 4 | saturation using EGraphs. 
5 | 6 | The design of Ego loosely follows the design of Rust's egg library, 7 | providing a flexible interface to run equality saturation extended 8 | with custom user-defined analyses. 9 | 10 | ```ocaml 11 | (* create an egraph *) 12 | let graph = EGraph.init () 13 | (* add expressions *) 14 | let expr1 = EGraph.add_sexp graph [%s ((a << 1) / 2)] 15 | (* Convert to graphviz *) 16 | let g : Odot.graph = EGraph.to_dot graph 17 | ``` 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /test/dune: -------------------------------------------------------------------------------- 1 | (tests 2 | (names test_basic test_generic test_math test_prop) 3 | (preprocess (pps ppx_sexp ppx_deriving.std)) 4 | (libraries ego alcotest)) 5 | -------------------------------------------------------------------------------- /test/test_basic.ml: -------------------------------------------------------------------------------- 1 | open Ego.Basic 2 | 3 | let sexp = 4 | (module struct 5 | type t = Sexplib.Sexp.t 6 | let pp = Sexplib.Sexp.pp_hum 7 | let equal = Sexplib.Sexp.equal 8 | end : Alcotest.TESTABLE with type t = Sexplib.Sexp.t) 9 | 10 | let documentation_example () = 11 | let graph = EGraph.init () in 12 | let expr_id = EGraph.add_sexp graph [%s ((a << 1) / 2)] in 13 | let from = Query.of_sexp [%s ("?a" << 1)] 14 | and into = Query.of_sexp [%s ("?a" * 2)] in 15 | let rule = Rule.make ~from ~into |> function 16 | | Some rule -> rule 17 | | None -> Alcotest.fail "could not build rule" in 18 | Alcotest.(check bool) 19 | "should reach equality saturation" 20 | true (EGraph.run_until_saturation graph [rule]); 21 | let cost_function score (sym, children) = 22 | let node_score = 23 | match Symbol.to_string sym with 24 | | "*" -> 1. 25 | | "/" -> 1. 26 | | "<<" -> 2. 27 | | _ -> 0. in 28 | node_score +. List.fold_left (fun acc vl -> acc +. score vl) 0. 
children in 29 | let result = EGraph.extract cost_function graph expr_id in 30 | Alcotest.(check sexp) 31 | "extracted expression has been simplified" 32 | [%s ((a * 2) / 2)] 33 | result 34 | 35 | (* 36 | We start off with two exprs, (g 1) and (g 2), and merge these two. 37 | Then we add a rule (g ?a) -> (h ?a), creating (h 1) and (h 2) which are also equal to (g 1) and (g 2). 38 | We extract the cheapest term using a cost function constructed such that (h 2) is lowest cost term, with cost 11. 39 | Previously (h 1), which has cost 12, was extracted instead. 40 | 41 | (h 1): 12 42 | (h 2): 11 43 | (g 1): inf 44 | (g 2): inf 45 | *) 46 | let test_match () = 47 | let graph = EGraph.init () in 48 | let expr_id1 = EGraph.add_sexp graph [%s (g 1)] in 49 | let _ = EGraph.add_sexp graph [%s (g 2)] in 50 | let from = Query.of_sexp [%s (g 1)] 51 | and into = Query.of_sexp [%s (g 2)] in 52 | let rule1 = Rule.make ~from ~into |> function 53 | | Some rule -> rule 54 | | None -> Alcotest.fail "could not build rule" in 55 | let from = Query.of_sexp [%s (g "?a")] 56 | and into = Query.of_sexp [%s (h "?a")] in 57 | let rule2 = Rule.make ~from ~into |> function 58 | | Some rule -> rule 59 | | None -> Alcotest.fail "could not build rule" in 60 | Alcotest.(check bool) 61 | "should reach equality saturation" 62 | true (EGraph.run_until_saturation graph [rule1; rule2]); 63 | let cost_function score (sym, children) = 64 | let node_score = 65 | match Symbol.to_string sym with 66 | | "g" -> 9999999. 67 | | "h" -> 10. 68 | | "1" -> 2. 69 | | "2" -> 1. 70 | | _ -> 9999999. in 71 | node_score +. List.fold_left (fun acc vl -> acc +. score vl) 0. 
children in 72 | let result = EGraph.extract cost_function graph expr_id1 in 73 | Alcotest.(check sexp) 74 | "cheapest expression is (h 2)" 75 | [%s (h 2)] 76 | result 77 | 78 | 79 | let () = 80 | Alcotest.run "basic" [ 81 | ("documentation", ["example given in documentation works as written", `Quick, documentation_example]); 82 | ("test matching", ["test matching", `Quick, test_match]) 83 | ] 84 | -------------------------------------------------------------------------------- /test/test_generic.ml: -------------------------------------------------------------------------------- 1 | open Ego.Generic 2 | 3 | let sexp = 4 | (module struct 5 | type t = Sexplib.Sexp.t 6 | let pp = Sexplib.Sexp.pp_hum 7 | let equal = Sexplib.Sexp.equal 8 | end : Alcotest.TESTABLE with type t = Sexplib.Sexp.t) 9 | 10 | module L = struct 11 | 12 | type 'a shape = Add of 'a * 'a | Sub of 'a * 'a | Mul of 'a * 'a 13 | | Div of 'a * 'a | Var of string | Const of int [@@deriving ord, show] 14 | 15 | type op = AddOp | SubOp | MulOp | DivOp | VarOp of string | ConstOp of int [@@deriving eq] 16 | 17 | type t = Mk of t shape [@@unboxed] 18 | 19 | let rec of_sexp = function [@warning "-8"] 20 | | Sexplib0.Sexp.Atom s -> 21 | begin match int_of_string_opt s with 22 | | Some n -> Mk (Const n) 23 | | None -> Mk (Var s) 24 | end 25 | | List [Atom ("*" | " * "); l; r] -> Mk (Mul (of_sexp l, of_sexp r)) 26 | | List [Atom "-"; l; r] -> Mk (Sub (of_sexp l, of_sexp r)) 27 | | List [Atom "+"; l; r] -> Mk (Add (of_sexp l, of_sexp r)) 28 | | List [Atom "/"; l; r] -> Mk (Div (of_sexp l, of_sexp r)) 29 | 30 | let rec to_sexp = function 31 | | Mk (Add (l, r)) -> Sexplib0.Sexp.List [Atom "+"; to_sexp l; to_sexp r] 32 | | Mk (Sub (l, r)) -> List [Atom "-"; to_sexp l; to_sexp r] 33 | | Mk (Mul (l, r)) -> List [Atom "*"; to_sexp l; to_sexp r] 34 | | Mk (Div (l, r)) -> List [Atom "/"; to_sexp l; to_sexp r] 35 | | Mk (Var s) -> Atom s 36 | | Mk (Const n) -> Atom (Int.to_string n) 37 | 38 | let op = function 39 | | 
Add _ -> AddOp 40 | | Sub _ -> SubOp 41 | | Mul _ -> MulOp 42 | | Div _ -> DivOp 43 | | Var s -> VarOp s 44 | | Const i -> ConstOp i 45 | 46 | let op_of_string = function 47 | | "+" -> AddOp 48 | | "-" -> SubOp 49 | | ("*" | " * ") -> MulOp 50 | | "/" -> DivOp 51 | | s -> match int_of_string_opt s with 52 | | None -> VarOp s 53 | | Some n -> ConstOp n 54 | 55 | let children = function 56 | | Add (l,r) | Sub (l,r) | Mul (l,r) | Div (l,r) -> [l;r] 57 | | Var _ | Const _ -> [] 58 | 59 | let map_children term f = match term with 60 | | Add (l,r) -> Add (f l, f r) 61 | | Sub (l,r) -> Sub (f l, f r) 62 | | Mul (l,r) -> Mul (f l, f r) 63 | | Div (l,r) -> Div (f l, f r) 64 | | Var s -> Var s | Const i -> Const i 65 | 66 | let make op ls = 67 | match[@warning "-8"] op,ls with 68 | | AddOp, [l;r] -> Add (l,r) 69 | | SubOp, [l;r] -> Sub (l,r) 70 | | MulOp, [l;r] -> Mul (l,r) 71 | | DivOp, [l;r] -> Div (l,r) 72 | | VarOp s, [] -> Var s 73 | | ConstOp i, [] -> Const i 74 | 75 | end 76 | 77 | module C = struct 78 | type t = float [@@deriving ord] 79 | let cost f : Ego.Id.t L.shape -> t = function 80 | | L.Add (l, r) -> f l +. f r +. 1.0 81 | | L.Sub (l, r) -> f l +. f r +. 1.5 82 | | L.Mul (l, r) -> f l +. f r +. 2.0 83 | | L.Div (l, r) -> f l +. f r +. 
2.0 84 | | L.Var _ -> 1.0 85 | | L.Const _ -> 1.0 86 | end 87 | 88 | module A = struct type t = unit type data = int option [@@deriving eq, show] let default = None end 89 | module MA (S : GRAPH_API 90 | with type 'p t = (Ego.Id.t L.shape, A.t, A.data, 'p) egraph 91 | and type 'a shape := 'a L.shape 92 | and type analysis := A.t 93 | and type data := A.data 94 | and type node := L.t) = struct 95 | type 'p t = (Ego.Id.t L.shape, A.t, A.data, 'p) egraph 96 | 97 | let eval : A.data L.shape -> A.data = 98 | function 99 | | L.Add (Some l, Some r) -> Some (l + r) 100 | | L.Sub (Some l, Some r) -> Some (l - r) 101 | | L.Mul (Some l, Some r) -> Some (l * r) 102 | | L.Div (Some l, Some r) -> if r <> 0 then Some (l / r) else None 103 | | L.Const n -> Some n 104 | | _ -> None 105 | 106 | let make : ro t -> Ego.Id.t L.shape -> A.data = 107 | fun graph term -> 108 | eval (L.map_children term (S.get_data graph)) 109 | 110 | let merge : A.t -> A.data -> A.data -> A.data * (bool * bool) = 111 | fun () l r -> match l,r with 112 | | Some l, Some r -> assert (l = r); Some l, (false, false) 113 | | Some l, None -> Some l, (false, true) 114 | | None, Some r -> Some r, (true, false) 115 | | _ -> None, (false, false) 116 | 117 | let modify : 'a t -> Ego.Id.t -> unit = 118 | fun graph cls -> 119 | match S.get_data (S.freeze graph) cls with 120 | | None -> () 121 | | Some n -> 122 | let nw_cls = S.add_node graph (L.Mk (Const n)) in 123 | S.merge graph nw_cls cls 124 | 125 | end 126 | 127 | module EGraph = Make (L) (A) (MA) 128 | module Extractor = MakeExtractor (L) (C) 129 | 130 | 131 | let documentation_example () = 132 | let graph = EGraph.init () in 133 | let expr = EGraph.add_node graph (L.of_sexp [%s (2 * 2)]) in 134 | let result = Extractor.extract graph expr in 135 | Alcotest.(check sexp) 136 | "extracted expression has been simplified" 137 | [%s 4] 138 | (L.to_sexp result) 139 | 140 | let simple_constant_folding () = 141 | let graph = EGraph.init () in 142 | let expr = 
EGraph.add_node graph (L.of_sexp [%s (2 + (1 + (3 - 2)))]) in 143 | let result = Extractor.extract graph expr in 144 | Alcotest.(check sexp) 145 | "extracted expression has been simplified" 146 | [%s 4] 147 | (L.to_sexp result) 148 | 149 | 150 | let multiple_terms_constant_folding () = 151 | let graph = EGraph.init () in 152 | let expr1 = EGraph.add_node graph (L.of_sexp [%s (3 * (2 - (10 / 5)))]) in 153 | let expr2 = EGraph.add_node graph (L.of_sexp [%s (3 - 3)]) in 154 | Alcotest.(check sexp) 155 | "first extracted expression has been simplified" 156 | [%s 0] 157 | (L.to_sexp (Extractor.extract graph expr1)); 158 | Alcotest.(check sexp) 159 | "second extracted expression has been simplified" 160 | [%s 0] 161 | (L.to_sexp (Extractor.extract graph expr2)) 162 | 163 | let multiple_terms_variable_constant_folding () = 164 | let graph = EGraph.init () in 165 | let expr1 = EGraph.add_node graph (L.of_sexp [%s ((2 * x) + (3 * (2 - (10 / 5))))]) in 166 | let expr2 = EGraph.add_node graph (L.of_sexp [%s ((3 - 3) + (2 * x))]) in 167 | Alcotest.(check sexp) 168 | "first extracted expression has been simplified" 169 | [%s ("+" (( * ) 2 x) 0)] 170 | (L.to_sexp (Extractor.extract graph expr1)); 171 | Alcotest.(check sexp) 172 | "second extracted expression has been simplified" 173 | [%s ("+" 0 (( * ) 2 x))] 174 | (L.to_sexp (Extractor.extract graph expr2)) 175 | 176 | let syntactic_rewrite () = 177 | let graph = EGraph.init () in 178 | let rewrite = 179 | EGraph.Rule.make_constant 180 | ~from:(Query.of_sexp L.op_of_string [%s (2 * "?x")]) 181 | ~into:(Query.of_sexp L.op_of_string [%s ("?x" + "?x")]) in 182 | let expr = EGraph.add_node graph (L.of_sexp [%s (1 + (3 / (2 * a)))]) in 183 | Alcotest.(check bool) 184 | "rewrites reached saturation" true 185 | @@ EGraph.run_until_saturation graph [rewrite]; 186 | Alcotest.(check sexp) 187 | "first extracted expression has been simplified" 188 | [%s ("+" 1 ((/) 3 ("+" a a)))] 189 | (L.to_sexp (Extractor.extract graph expr)) 190 | 191 
| let conditional_rewrite () = 192 | let graph = EGraph.init () in 193 | let rewrite = 194 | EGraph.Rule.make_conditional 195 | ~from:(Query.of_sexp L.op_of_string [%s ("?x" / "?x")]) 196 | ~into:(Query.of_sexp L.op_of_string [%s 1]) 197 | ~cond:(fun graph _root env -> 198 | let x = StringMap.find "x" env in 199 | match EGraph.get_data graph x with 200 | | None | Some 0 -> false 201 | | _ -> true (* only safe to do this rewrite if x isn't 0 *)) in 202 | let expr_valid = EGraph.add_node graph (L.of_sexp [%s (10 / 10)]) in 203 | let expr_invalid = EGraph.add_node graph (L.of_sexp [%s (0 / 0)]) in 204 | let expr_invalid_compl = EGraph.add_node graph (L.of_sexp [%s ((4 * 3 - 6 * 2) / (2 * 3 - 3 * 2))]) in 205 | let expr_valid_compl = EGraph.add_node graph (L.of_sexp [%s (((4 * 3 - 6 * 2) + 1) / ((2 * 3 - 3 * 2) + 1))]) in 206 | let expr_x_x = EGraph.add_node graph (L.of_sexp [%s (x / x)]) in 207 | Alcotest.(check bool) 208 | "rewrites reached saturation" true 209 | @@ EGraph.run_until_saturation graph [rewrite]; 210 | Alcotest.(check sexp) 211 | "basic expression has been simplified" 212 | [%s 1] 213 | (L.to_sexp (Extractor.extract graph expr_valid)); 214 | Alcotest.(check sexp) 215 | "invalid expression has not been simplified" 216 | [%s ("/" 0 0)] 217 | (L.to_sexp (Extractor.extract graph expr_invalid)); 218 | Alcotest.(check sexp) 219 | "complex invalid expression has not been simplified beyond minimal" 220 | [%s ("/" 0 0)] 221 | (L.to_sexp (Extractor.extract graph expr_invalid_compl)); 222 | Alcotest.(check sexp) 223 | "complex valid expression has been simplified" 224 | [%s 1] 225 | (L.to_sexp (Extractor.extract graph expr_valid_compl)); 226 | Alcotest.(check sexp) 227 | "expression of variables has not been simplified" 228 | [%s ("/" x x)] 229 | (L.to_sexp (Extractor.extract graph expr_x_x)) 230 | 231 | 232 | let () = 233 | Alcotest.run "generic" [ 234 | ("documentation", [ 235 | "example given in documentation works as written", `Quick, documentation_example; 
236 | "simple constant folding", `Quick, simple_constant_folding; 237 | "multiple terms constant folding", `Quick, multiple_terms_constant_folding; 238 | "multiple terms with variable constant folding", `Quick, multiple_terms_variable_constant_folding; 239 | "syntactic rewriting", `Quick, syntactic_rewrite; 240 | "conditional rewriting", `Quick, conditional_rewrite; 241 | ]) 242 | ] 243 | -------------------------------------------------------------------------------- /test/test_math.ml: -------------------------------------------------------------------------------- 1 | open Ego.Generic 2 | let sexp = 3 | (module struct 4 | type t = Sexplib.Sexp.t 5 | let pp = Sexplib.Sexp.pp_hum 6 | let equal = Sexplib.Sexp.equal 7 | end : Alcotest.TESTABLE with type t = Sexplib.Sexp.t) 8 | 9 | module Symbol : sig 10 | type t 11 | val compare : t -> t -> int 12 | val equal: t -> t -> bool 13 | val pp : Format.formatter -> t -> unit 14 | val intern: string -> t 15 | val to_string: t -> string 16 | end = struct 17 | type t = int 18 | let equal = Int.equal 19 | let compare = Int.compare 20 | let intern, to_string = 21 | let tbl = ref @@ StringMap.empty in 22 | let buf = Array.make 100 "" in 23 | let limit = ref 0 in 24 | let intern s = 25 | match (StringMap.find_opt s !tbl) with 26 | Some n -> n | None -> let ind = !limit in buf.(ind) <- s; incr limit; 27 | tbl := (StringMap.add s ind !tbl); ind in 28 | let to_string n = buf.(n) in 29 | intern, to_string 30 | 31 | let pp fmt s = Format.pp_print_string fmt (to_string s) 32 | end 33 | 34 | module L = struct 35 | type 'a shape = 36 | | Diff of 'a * 'a 37 | | Integral of 'a * 'a 38 | | Add of 'a * 'a 39 | | Sub of 'a * 'a 40 | | Mul of 'a * 'a 41 | | Div of 'a * 'a 42 | | Pow of 'a * 'a 43 | | Ln of 'a 44 | | Sqrt of 'a 45 | | Sin of 'a 46 | | Cos of 'a 47 | | Constant of float 48 | | Symbol of Symbol.t 49 | [@@deriving ord, show] 50 | 51 | (* let float_equal f1 f2 = 52 | * print_endline @@ Printf.sprintf "comparing %f eq with %f = %b" 
f1 f2 (Float.equal f1 f2); 53 | * Float.equal f1 f2 *) 54 | 55 | type op = 56 | | DiffOp | IntegralOp | AddOp | SubOp | MulOp | DivOp | PowOp | LnOp | SqrtOp 57 | | SinOp | CosOp | ConstantOp of float | SymbolOp of Symbol.t [@@deriving eq] 58 | 59 | type t = Mk of t shape [@@unboxed] 60 | 61 | let rec of_sexp : Sexplib0.Sexp.t -> _ = function [@warning "-8"] 62 | | Atom s -> 63 | begin match float_of_string_opt s with 64 | | Some n -> Mk (Constant n) 65 | | None -> 66 | match int_of_string_opt s with 67 | | Some n -> Mk (Constant (Float.of_int n)) 68 | | None -> Mk (Symbol (Symbol.intern s)) 69 | end 70 | | List [Atom "d"; l; r] -> Mk (Diff (of_sexp l, of_sexp r)) 71 | | List [Atom "i"; l; r] -> Mk (Integral (of_sexp l, of_sexp r)) 72 | | List [Atom ("*" | " * "); l; r] -> Mk (Mul (of_sexp l, of_sexp r)) 73 | | List [Atom "-"; l; r] -> Mk (Sub (of_sexp l, of_sexp r)) 74 | | List [Atom "+"; l; r] -> Mk (Add (of_sexp l, of_sexp r)) 75 | | List [Atom "/"; l; r] -> Mk (Div (of_sexp l, of_sexp r)) 76 | | List [Atom "pow"; l; r] -> Mk (Pow (of_sexp l, of_sexp r)) 77 | | List [Atom "ln"; l] -> Mk (Ln (of_sexp l)) 78 | | List [Atom "sqrt"; l] -> Mk (Sqrt (of_sexp l)) 79 | | List [Atom "sin"; l] -> Mk (Sin (of_sexp l)) 80 | | List [Atom "cos"; l] -> Mk (Cos (of_sexp l)) 81 | 82 | let rec to_sexp : t -> Sexplib0.Sexp.t = function 83 | | Mk (Diff (l, r)) -> List [Atom "d"; to_sexp l; to_sexp r] 84 | | Mk (Integral (l, r)) -> List [Atom "i"; to_sexp l; to_sexp r] 85 | | Mk (Add (l, r)) -> List [Atom "+"; to_sexp l; to_sexp r] 86 | | Mk (Sub (l, r)) -> List [Atom "-"; to_sexp l; to_sexp r] 87 | | Mk (Mul (l, r)) -> List [Atom "*"; to_sexp l; to_sexp r] 88 | | Mk (Div (l, r)) -> List [Atom "/"; to_sexp l; to_sexp r] 89 | | Mk (Pow (l, r)) -> List [Atom "pow"; to_sexp l; to_sexp r] 90 | | Mk (Ln l) -> List [Atom "ln"; to_sexp l] 91 | | Mk (Sqrt l) -> List [Atom "sqrt"; to_sexp l] 92 | | Mk (Sin l) -> List [Atom "sin"; to_sexp l] 93 | | Mk (Cos l) -> List [Atom "cos"; to_sexp l] 
94 | | Mk (Constant l) -> Atom (Float.to_string l) 95 | | Mk (Symbol l) -> Atom (Symbol.to_string l) 96 | 97 | let op = function 98 | | Diff (_, _) -> DiffOp | Integral (_, _) -> IntegralOp 99 | | Add (_, _) -> AddOp | Sub (_, _) -> SubOp 100 | | Mul (_, _) -> MulOp | Div (_, _) -> DivOp 101 | | Pow (_, _) -> PowOp | Ln _ -> LnOp 102 | | Sqrt _ -> SqrtOp | Sin _ -> SinOp | Cos _ -> CosOp 103 | | Constant c -> ConstantOp c | Symbol s -> SymbolOp s 104 | 105 | let op_of_string : string -> op = function [@warning "-8"] 106 | | "d" -> DiffOp | "i" -> IntegralOp | ("*" | " * ") -> MulOp 107 | | "-" -> SubOp | "+" -> AddOp | "/" -> DivOp 108 | | "pow" -> PowOp | "ln" -> LnOp | "sqrt" -> SqrtOp 109 | | "sin" -> SinOp | "cos" -> CosOp | s -> 110 | begin match float_of_string_opt s with 111 | | Some n -> (ConstantOp n) 112 | | None -> 113 | match int_of_string_opt s with 114 | | Some n -> ConstantOp (Float.of_int n) 115 | | None -> SymbolOp (Symbol.intern s) 116 | end 117 | 118 | let children = function 119 | | Diff (l, r) | Integral (l, r) | Add (l, r) | Sub (l, r) | Mul (l, r) 120 | | Div (l, r) | Pow (l, r) -> [l;r] 121 | | Ln l | Sqrt l | Sin l | Cos l -> [l] 122 | | Constant _ | Symbol _ -> [] 123 | 124 | let map_children term f = match term with 125 | | Diff (l, r) -> Diff (f l, f r) | Integral (l, r) -> Integral (f l, f r) 126 | | Add (l, r) -> Add (f l, f r) | Sub (l, r) -> Sub (f l, f r) 127 | | Mul (l, r) -> Mul (f l, f r) | Div (l, r) -> Div (f l, f r) 128 | | Pow (l, r) -> Pow (f l, f r) | Ln l -> Ln (f l) 129 | | Sqrt l -> Sqrt (f l) | Sin l -> Sin (f l) | Cos l -> Cos (f l) 130 | | Constant c -> Constant c | Symbol s -> Symbol s 131 | 132 | let make op ls = 133 | match[@warning "-8"] op,ls with 134 | | DiffOp, [l;r] -> Diff (l, r) | IntegralOp, [l;r] -> Integral (l, r) 135 | | AddOp, [l;r] -> Add (l, r) | SubOp, [l;r] -> Sub (l, r) 136 | | MulOp, [l;r] -> Mul (l, r) | DivOp, [l;r] -> Div (l, r) 137 | | PowOp, [l;r] -> Pow (l, r) | LnOp, [l] -> Ln l 138 | | 
SqrtOp, [l] -> Sqrt l | SinOp, [l] -> Sin l | CosOp, [l] -> Cos l 139 | | ConstantOp c, [] -> Constant c | SymbolOp s, [] -> Symbol s 140 | 141 | end 142 | 143 | module C = struct 144 | type t = int [@@deriving ord] 145 | let cost f : Ego.Id.t L.shape -> t = 146 | fun term -> 147 | let base_cost = match term with Diff _ | Integral _ -> 100 | Sub _ -> 20 | _ -> 1 in 148 | L.children term |> List.fold_left (fun acc vl -> acc + f vl) base_cost 149 | end 150 | 151 | module A = struct type t = unit type data = float option [@@deriving eq, show] let default = None end 152 | module MA (S : GRAPH_API 153 | with type 'p t = (Ego.Id.t L.shape, A.t, A.data, 'p) egraph 154 | and type 'a shape := 'a L.shape 155 | and type analysis := A.t 156 | and type data := A.data 157 | and type node := L.t) = struct 158 | type 'p t = (Ego.Id.t L.shape, A.t, A.data, 'p) egraph 159 | (* [eval] folds a node whose children carry known constants: Add, Sub, Mul and Div combine two [Some] operands, [Constant n] yields [Some n], and every other shape yields [None]. *) 160 | let eval : A.data L.shape -> A.data = 161 | function 162 | | L.Add (Some l, Some r) -> Some (l +. r) 163 | | L.Sub (Some l, Some r) -> Some (l -. r) 164 | | L.Mul (Some l, Some r) -> Some (l *. r) 165 | | L.Div (Some l, Some r) -> (* fold the quotient only when the divisor is NOT within 0.01 of zero; a near-zero divisor yields None *) if not (Containers.Float.equal_precision ~epsilon:0.01 r 0.) then Some (l /.
r) else None 166 | | L.Constant n -> Some n 167 | | _ -> None 168 | 169 | let make : ro t -> Ego.Id.t L.shape -> A.data = 170 | fun graph term -> 171 | eval (L.map_children term (S.get_data graph)) 172 | 173 | let merge : A.t -> A.data -> A.data -> A.data * (bool * bool)= 174 | fun () l r -> match l,r with 175 | | Some l, Some r -> 176 | if Float.equal l r 177 | then Some l, (false,false) 178 | else failwith @@ Printf.sprintf "merge failed: float values %f <> %f " l r 179 | | Some l, _ -> Some l, (false, true) 180 | | _, Some r -> Some r, (true, false) 181 | | _ -> None, (false,false) 182 | 183 | let modify : 'a t -> Ego.Id.t -> unit = 184 | fun graph cls -> 185 | match S.get_data (S.freeze graph) cls with 186 | | None -> () 187 | | Some n -> 188 | let nw_cls = S.add_node graph (L.Mk (Constant n)) in 189 | S.merge graph nw_cls cls 190 | 191 | end 192 | 193 | module EGraph = Make (L) (A) (MA) 194 | module Extractor = MakeExtractor (L) (C) 195 | 196 | let is_const_or_distinct_var v w = 197 | fun graph _root_id env -> 198 | let v = StringMap.find v env in 199 | let w = StringMap.find w env in 200 | (not @@ EGraph.class_equal (EGraph.freeze graph) v w) 201 | && ((EGraph.get_data graph v |> Option.is_some) || 202 | EGraph.iter_children (EGraph.freeze graph) v |> Iter.exists (function L.Symbol _ -> true | _ -> false)) 203 | 204 | let is_const v = 205 | fun graph _root_id env -> 206 | let v = StringMap.find v env in 207 | EGraph.get_data graph v |> Option.is_some 208 | 209 | let is_sym v = 210 | fun graph _root_id env -> 211 | let v = StringMap.find v env in 212 | EGraph.iter_children (EGraph.freeze graph) v |> Iter.exists (function L.Symbol _ -> true | _ -> false) 213 | 214 | let is_not_zero v = 215 | fun graph _root_id env -> 216 | let v = StringMap.find v env in 217 | EGraph.get_data graph v |> function Some 0.0 -> false | _ -> true 218 | 219 | let qf = Query.of_sexp L.op_of_string 220 | let (@->) from into = EGraph.Rule.make_constant ~from:(qf from) ~into:(qf into) 
(* Builds a conditional rewrite rule from two pattern s-expressions: a match
   of [from] is rewritten to [into] only where the predicate [if_] accepts it. *)
let rewrite from into ~if_ =
  EGraph.Rule.make_conditional
    ~from:(qf from)
    ~into:(qf into)
    ~cond:if_

(* Complete rule set shared by the arithmetic and derivative test groups. *)
let rules =
  (* Conjunction of two rule predicates.  The binding is not recursive, so the
     [&&] in its body is still the ordinary short-circuit boolean operator. *)
  let[@warning "-26"] (&&) left right graph root_id env =
    left graph root_id env && right graph root_id env
  in
  [
    (* commutativity and associativity *)
    [%s ("?a" + "?b")] @-> [%s ("?b" + "?a")]; (* comm-add *)
    [%s ("?a" * "?b")] @-> [%s ("?b" * "?a")]; (* comm-mul *)
    [%s ("?a" + ("?b" + "?c"))] @-> [%s (("?a" + "?b") + "?c")]; (* assoc add *)
    [%s ("?a" * ("?b" * "?c"))] @-> [%s (("?a" * "?b") * "?c")]; (* assoc mul *)

    [%s (("?a" - "?c") + "?b")] @-> [%s ("?a" + ("?b" - "?c"))];

    (* canonical forms for subtraction and division *)
    [%s ("?a" - "?b")] @-> [%s ("?a" + ("-1." * "?b"))]; (* sub canon *)
    rewrite [%s ("?a" / "?b")] [%s ("?a" * (pow "?b" "-1.0"))]
      ~if_:(is_not_zero "b"); (* div canon *)

    (* identities *)
    [%s ("?a" + "0.")] @-> [%s "?a"]; (* zero-add *)
    [%s ("?a" * "0.")] @-> [%s "0."]; (* zero-mul *)
    [%s ("?a" * "1.")] @-> [%s "?a"]; (* one-mul *)

    [%s "?a"] @-> [%s ("?a" + "0.")]; (* add-zero *)
    [%s "?a"] @-> [%s ("?a" * "1.")]; (* mul-one *)

    (* cancellation *)
    [%s ("?a" - "?a")] @-> [%s "0."]; (* cancel sub *)
    rewrite [%s ("?a" / "?a")] [%s "1."] ~if_:(is_not_zero "a"); (* cancel div *)

    (* distribution *)
    [%s("?a" * ("?b" + "?c"))] @-> [%s (("?a" * "?b") + ("?a" * "?c"))]; (* distribute *)
    [%s (("?a" * "?b") + ("?a" * "?c"))] @-> [%s ("?a" * ("?b" + "?c"))]; (* factor *)

    (* powers *)
    [%s ((pow "?a" "?b") * (pow "?a" "?c"))] @-> [%s (pow "?a" ("?b" + "?c"))]; (* pow-mul *)
    rewrite [%s (pow "?x" "0.")] [%s "1."] ~if_:(is_not_zero "x"); (* po0 *)

    [%s (pow "?x" "1.")] @-> [%s "?x"]; (* pow1 *)

    [%s (pow "?x" "2.")] @-> [%s ("?x" * "?x")]; (* po2 *)

    rewrite [%s (pow "?x" "-1.")] [%s("1." / "?x")] ~if_:(is_not_zero "x"); (* pow-recip *)

    rewrite [%s ("?x" * ("1." / "?x"))] [%s "1."] ~if_:(is_not_zero "x"); (* recip mul div *)

    (* rules over [d] terms *)
    rewrite [%s (d "?x" "?x")] [%s "1."] ~if_:(is_sym "x"); (* d variable *)

    rewrite [%s (d "?x" "?c")] [%s"0."]
      ~if_:(is_sym "x" && is_const_or_distinct_var "c" "x");
    (* d constant *)

    [%s (d "?x" ("?a" + "?b"))] @-> [%s ((d "?x" "?a") + (d "?x" "?b"))]; (* d-add *)
    [%s (d "?x" ("?a" * "?b"))]
    @-> [%s (("?a" * (d "?x" "?b")) + ("?b" * (d "?x" "?a")))]; (* d-mul *)

    [%s (d "?x" (sin "?x"))] @-> [%s (cos "?x")]; (* d-sin *)

    [%s (d "?x" (cos "?x"))] @-> [%s ("-1." * (sin "?x"))]; (* d-cos *)

    rewrite [%s (d "?x" (ln "?x"))] [%s (1 / "?x")] ~if_:(is_not_zero "x"); (* d-ln *)

    rewrite [%s (d "?x" (pow "?f" "?g"))]
      [%s ((pow "?f" "?g") * (((d "?x" "?f") * ("?g" / "?f")) + ((d "?x" "?g") * (ln "?f"))))]
      ~if_:(is_not_zero "f" && is_not_zero "g");

    (* rules over [i] terms *)
    [%s (i "1." "?x")] @-> [%s "?x"];
    rewrite
      [%s (i (pow "?x" "?c") "?x")]
      [%s ((pow "?x" ("?c" + "1.")) / ("?c" + "1."))]
      ~if_:(is_const "c");
    [%s (i (cos "?x") "?x")] @-> [%s (sin "?x")];
    [%s (i (sin "?x") "?x")] @-> [%s ("-1." * (cos "?x"))];
    [%s (i ("?f" + "?g") "?x")] @-> [%s ((i "?f" "?x") + (i "?g" "?x"))];
    [%s (i ("?f" - "?g") "?x")] @-> [%s ((i "?f" "?x") - (i "?g" "?x"))];
    [%s (i ("?a" * "?b") "?x")]
    @-> [%s (("?a" * (i "?b" "?x")) - (i ((d "?x" "?a") * (i "?b" "?x")) "?x"))];
  ]

(* Adds [s1] to a fresh e-graph, runs [rules] on it and hands the graph and the
   class of [s1] to [f].  Saturation is asserted only when neither [fuel] nor
   [node_limit] was supplied. *)
let run_and_check1 ?node_limit ?fuel rules s1 f () =
  let graph = EGraph.init () in
  let term_1 = EGraph.add_node graph (L.of_sexp s1) in
  let reached_saturation =
    EGraph.run_until_saturation ?node_limit ?fuel graph rules
  in
  (match fuel, node_limit with
   | None, None ->
     Alcotest.(check bool) "reaches equality saturation" true reached_saturation
   | _ -> ());
  f graph term_1

(* Two-term variant of [run_and_check1]: both [s1] and [s2] are added before
   the rules are run, and [f] receives the graph and both classes. *)
let run_and_check2 ?node_limit ?fuel rules s1 s2 f () =
  let graph = EGraph.init () in
  let term_1 = EGraph.add_node graph (L.of_sexp s1) in
  let term_2 = EGraph.add_node graph (L.of_sexp s2) in
  let reached_saturation =
    EGraph.run_until_saturation ?node_limit ?fuel graph rules
  in
  (match fuel, node_limit with
   | None, None ->
     Alcotest.(check bool) "reaches equality saturation" true reached_saturation
   | _ -> ());
  f graph term_1 term_2

(* Test case asserting that [s1] and [s2] land in the same e-class.  The class
   equality test is passed as [~until] to the run, and the run's boolean result
   is discarded. *)
let check_proves_equal ?node_limit ?fuel rules s1 s2 () =
  let graph = EGraph.init () in
  let term_1 = EGraph.add_node graph (L.of_sexp s1) in
  let term_2 = EGraph.add_node graph (L.of_sexp s2) in
  let terms_are_equal g = EGraph.class_equal (EGraph.freeze g) term_1 term_2 in
  ignore
    (EGraph.run_until_saturation ~until:terms_are_equal ?node_limit ?fuel graph rules);
  (* Disabled saturation check, kept for reference:
   * match fuel, node_limit with
   *   _ , Some _ | Some _, _ -> ()
   * | None, None -> Alcotest.(check bool) "reaches equality saturation" true reached_saturation
   *)
  Alcotest.(check bool) "proves terms are equal modulo rewriting"
    true
    (terms_are_equal graph)

(* Test case asserting that [s1] and [s2] are still in different e-classes
   after the run performed by [run_and_check2]. *)
let check_cannot_prove_equal ?node_limit ?fuel rules s1 s2 =
  let assert_distinct graph term_1 term_2 =
    Alcotest.(check bool) "must not prove terms are equal modulo rewriting"
      false
      (EGraph.class_equal (EGraph.freeze graph) term_1 term_2)
  in
  run_and_check2 ?node_limit ?fuel rules s1 s2 assert_distinct

(* Test case asserting that the term extracted for [s1] after the run is
   exactly the s-expression [s2]. *)
let check_extract ?node_limit ?fuel rules s1 s2 =
  let assert_extracted graph term_1 =
    Alcotest.(check sexp)
      "extracted expression matches"
      s2
      (L.to_sexp (Extractor.extract graph term_1))
  in
  run_and_check1 ?node_limit ?fuel rules s1 assert_extracted

(* Test entry point for the "math" suite. *)
let () =
  Alcotest.run "math"
    [("proving with addition",
      (* This group runs with a local two-rule set that shadows [rules]. *)
      let rules = [
        (* add comm *) [%s ("?a" + "?b")] @-> [%s ("?b" + "?a")];
        (* add assoc *) [%s ("?a" + ("?b" + "?c"))] @-> [%s (("?a" + "?b") + "?c")];
      ] in [
        "constants are simplified", `Quick, check_proves_equal rules
          [%s (1 + (2 + (3 + (4 + (5 + (6 + 7))))))]
          [%s (7 + (6 + (5 + (4 + (3 + (2 + 1))))))];
        "constants are evaluated", `Quick, check_proves_equal rules
          [%s (1 + (2 + (3 + (4 + (5 + (6 + 7))))))]
          [%s 28];
        "symbols can be rearranged", `Quick, check_proves_equal rules
          [%s (1 + (x + (2 + (3 + (4 + (5 + (6 + 7)))))))]
          [%s (x + 28)];
      ]);
     "proving arithmetic with full rule set", [
       "subtraction works with symbols", `Quick,
       check_proves_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
         [%s (x - x)] [%s 0.];
       "subtraction works with non obvious equalities", `Quick,
       check_proves_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
         [%s (x - (x + 0))] [%s 0.];
       "subtraction works with complex expressions", `Quick,
       check_proves_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
         [%s ((sqrt 5.) - (sqrt 5.))] [%s 0.];
       "subtraction works with complex expressions and addition", `Quick,
       check_proves_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
         [%s ((1 + (sqrt 5.)) - ((sqrt 5.) + 1))] [%s 0.];
       "multiplication is rewritten", `Quick,
       check_proves_equal ~node_limit:(`Bounded 75_000) ~fuel:(`Bounded 30) rules
         [%s (1 - x)] [%s (1 + ("-1." * x))];
       "multiplication is propagated", `Quick,
       check_proves_equal ~node_limit:(`Bounded 75_000) ~fuel:(`Bounded 30) rules
         [%s ((1 - x) + x)] [%s (1 + (x + ("-1." * x)))];
       "1subtraction can be reverted", `Quick,
       check_proves_equal ~node_limit:(`Bounded 75_000) ~fuel:(`Bounded 30) rules
         [%s (1 + ("-1." * x))] [%s (1 - x)];
       "subtraction can be cancelled", `Quick,
       check_proves_equal ~node_limit:(`Bounded 75_000) ~fuel:(`Bounded 30) rules
         [%s (x + ("-1." * x))] [%s 0];
       "subtraction can be propagated and cancelled", `Quick,
       check_proves_equal ~node_limit:(`Bounded 100_000_000) ~fuel:(`Bounded 30) rules
         [%s (1 + (x - x))] [%s ((1 - x) + x) ];
       "complex subtraction can be propagated and cancelled", `Quick,
       check_proves_equal ~node_limit:(`Bounded 100_000_000) ~fuel:(`Bounded 30) rules
         [%s (1 + (x - x))] [%s ((1 - x) + x) ];
       "plus minus one can be propagated and cancelled", `Quick,
       check_proves_equal ~node_limit:(`Bounded 100_000_000) ~fuel:(`Bounded 30) rules
         [%s ((1 - x) + (1 + x)) ] [%s 2];
       "division can be simplified", `Quick,
       check_proves_equal ~node_limit:(`Bounded 100_000_000) ~fuel:(`Bounded 30) rules
         [%s (2 / 2) ] [%s 1];
       "division with numerator 0 can be simplified", `Quick,
       check_proves_equal ~node_limit:(`Bounded 100_000_000) ~fuel:(`Bounded 30) rules
         [%s ((x - x) / 2) ] [%s 0];
       "multiplication with 0 is simplified", `Quick,
       check_proves_equal ~node_limit:(`Bounded 100_000_000) ~fuel:(`Bounded 30) rules
         [%s (x * 0.) ] [%s 0];
     ];
     "does not prove invalid equalities", [
       "multiplication and addition are not equal", `Quick,
       check_cannot_prove_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
         [%s (x + y)] [%s (x / y)]
     ];
     "reasoning about derivatives", [
       "dx/dy of x is 1", `Quick,
       check_proves_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
         [%s (d x x)] [%s 1 ];
       "dx/dy of y is 0", `Quick,
       check_proves_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
         [%s (d x y)] [%s 0 ];
       "dx/dy of 1 + 2x is 2", `Quick,
       check_proves_equal ~node_limit:(`Bounded 100_000) ~fuel:(`Bounded 35) rules
         [%s (d x (1 + (2. * x)))] [%s 2. ];
       "dx/dy of xy + 1 is y", `Quick,
       check_extract ~node_limit:(`Unbounded) ~fuel:(`Bounded 15) rules
         [%s (d x (1. + (y * x)))] [%s y ];
       "dx/dy of ln x is 1 / x", `Quick,
       check_proves_equal ~node_limit:(`Bounded 100_000) ~fuel:(`Bounded 35) rules
         [%s (d x (ln x))] [%s 1 / x ];
     ];
    ]

--------------------------------------------------------------------------------
/test/test_prop.ml:
--------------------------------------------------------------------------------
open Ego.Generic

(* Alcotest testable for s-expressions, used by [check_extract]. *)
let sexp =
  (module struct
    type t = Sexplib.Sexp.t
    let pp = Sexplib.Sexp.pp_hum
    let equal = Sexplib.Sexp.equal
  end : Alcotest.TESTABLE with type t = Sexplib.Sexp.t)

(* Propositional-logic term language: shapes, operators, and conversions
   to and from s-expressions. *)
module L = struct

  (* One layer of a term; ['a] is the type of the children. *)
  type 'a shape =
    | And of 'a * 'a
    | Or of 'a * 'a
    | Not of 'a
    | Impl of 'a * 'a
    | Bool of bool
    | Symbol of string [@@deriving ord, show]

  (* A fully built term. *)
  type t = Mk of t shape [@@unboxed]

  (* The head of a shape with its children removed. *)
  type op =
    | AndOp
    | OrOp
    | NotOp
    | ImplOp
    | BoolOp of bool
    | SymbolOp of string [@@deriving eq, ord]

  (* Parses an s-expression into a term.  The match is deliberately partial
     (warning 8 silenced): lists of any other form are not handled. *)
  let rec of_sexp = function[@warning "-8"]
    | Sexplib0.Sexp.Atom "true" -> Mk (Bool true)
    | Sexplib0.Sexp.Atom "false" -> Mk (Bool false)
    | Sexplib0.Sexp.Atom s -> Mk (Symbol s)
    | List [Atom "&&"; l; r] -> Mk (And (of_sexp l, of_sexp r))
    | List [Atom "||"; l; r] -> Mk (Or (of_sexp l, of_sexp r))
    | List [Atom "not"; l] -> Mk (Not (of_sexp l))
    | List [Atom "=>"; l; r] -> Mk (Impl (of_sexp l, of_sexp r))

  (* Maps an atom's text to an operator; any unrecognised string becomes a
     symbol. *)
  let op_of_string = function[@warning "-8"]
    | "true" -> BoolOp true
    | "false" -> BoolOp false
    | "&&" -> AndOp
    | "||" -> OrOp
    | "not" -> NotOp
    | "=>" -> ImplOp
    | s -> SymbolOp s

  (* Prints a term back to the s-expression form accepted by [of_sexp]. *)
  let rec to_sexp = function[@warning "-8"]
    | Mk (Bool true) -> Sexplib0.Sexp.Atom "true"
    | Mk (Bool false) -> Sexplib0.Sexp.Atom "false"
    | Mk (Symbol s) -> Sexplib0.Sexp.Atom s
    | Mk (And (l, r)) -> List [Atom "&&"; to_sexp l; to_sexp r]
    | Mk (Or (l, r)) -> List [Atom "||"; to_sexp l; to_sexp r]
    | Mk (Not l) -> List [Atom "not"; to_sexp l]
    | Mk (Impl (l, r)) -> List [Atom "=>"; to_sexp l; to_sexp r]

  (* Operator at the head of a shape. *)
  let op shape =
    match shape with
    | Bool b -> BoolOp b
    | Symbol s -> SymbolOp s
    | Not _ -> NotOp
    | And _ -> AndOp
    | Or _ -> OrOp
    | Impl _ -> ImplOp

  (* Children of a shape, left to right. *)
  let children shape =
    match shape with
    | Bool _ | Symbol _ -> []
    | Not l -> [l]
    | And (l, r) | Or (l, r) | Impl (l, r) -> [l; r]

  (* Rebuilds a shape with [f] applied to every child. *)
  let map_children term f = match term with
    | And (l,r) -> And (f l, f r)
    | Or (l,r) -> Or (f l, f r)
    | Not l -> Not (f l)
    | Impl (l,r) -> Impl (f l, f r)
    | Bool b -> Bool b
    | Symbol s -> Symbol s

  (* Builds a shape from an operator and its children.  Partial by design
     (warning 8 silenced): a child count that does not fit the operator is
     not handled. *)
  let make op children = match[@warning "-8"] op, children with
    | AndOp, [l;r] -> And (l,r)
    | OrOp, [l;r] -> Or (l,r)
    | NotOp, [l] -> Not l
    | ImplOp, [l;r] -> Impl (l,r)
    | BoolOp b, [] -> Bool b
    | SymbolOp s, [] -> Symbol s

end

(* Cost model handed to the extractor: each node costs a fixed amount per
   constructor plus the cost [f] assigns to its children. *)
module C = struct
  type t = float [@@deriving ord]
  let cost f : Ego.Id.t L.shape -> t = function
    | L.And (l, r) -> f l +. f r +. 3.
    | L.Or (l, r) -> f l +. f r +. 2.0
    | L.Impl (l, r) -> f l +. f r +. 1.0
    | L.Not l -> f l +. 3.0
    | L.Symbol _ -> 1.0
    | L.Bool _ -> 1.0
end

(* Analysis signature: no global state, and per-class data holding the
   boolean value of the class when one is known. *)
module A = struct
  type t = unit
  type data = bool option [@@deriving eq,show]
  let default = None
end

(* Constant-folding analysis over boolean terms. *)
module MA (S : GRAPH_API
           with type 'p t = (Ego.Id.t L.shape, A.t, A.data, 'p) egraph
            and type 'a shape := 'a L.shape
            and type analysis := A.t
            and type data := A.data
            and type node := L.t) = struct
  type 'p t = (Ego.Id.t L.shape, A.t, A.data, 'p) egraph

  (* Evaluates one layer whose children already carry analysis data; yields
     [None] as soon as a required child value is unknown. *)
  let eval : A.data L.shape -> A.data =
    function
    | L.Bool c -> Some c
    | L.Not (Some b) -> Some (not b)
    | L.And (Some l, Some r) -> Some (l && r)
    | L.Or (Some l, Some r) -> Some (l || r)
    | L.Impl (Some l, Some r) -> Some ((not l) || r)
    | _ -> None

  (* Data for a new node: look up each child's data, then evaluate. *)
  let make : ro t -> Ego.Id.t L.shape -> A.data =
    fun graph term ->
      eval (L.map_children term (S.get_data graph))

  (* Joins the data of two merged classes.  Two known values must agree
     (enforced with [assert]); the boolean pair is [(true, false)] when the
     merged value came from the right argument only and [(false, true)] when
     it came from the left only. *)
  let merge : A.t -> A.data -> A.data -> A.data * (bool * bool) =
    fun () l r ->
      match l, r with
      | Some l, Some r -> assert (l = r); Some l, (false, false)
      | Some l, None -> Some l, (false, true)
      | None, Some r -> Some r, (true, false)
      | _ -> None, (false, false)

  (* When a class has a known value, adds the matching boolean literal and
     merges it into that class. *)
  let modify : 'a t -> Ego.Id.t -> unit =
    fun graph cls ->
      match S.get_data (S.freeze graph) cls with
      | Some n ->
        let nw_cls = S.add_node graph (L.Mk (Bool n)) in
        S.merge graph nw_cls cls
      | None -> ()

end

module EGraph = Make (L) (A) (MA)
module Extractor = MakeExtractor (L) (C)


(* Adds [s1] to a fresh e-graph, runs [rules] on it and hands the graph and the
   class of [s1] to [f].  Saturation is asserted only when neither [fuel] nor
   [node_limit] was supplied. *)
let run_and_check1 ?node_limit ?fuel rules s1 f () =
  let graph = EGraph.init () in
  let term_1 = EGraph.add_node graph (L.of_sexp s1) in
  let reached_saturation =
    EGraph.run_until_saturation ?node_limit ?fuel graph rules
  in
  (match fuel, node_limit with
   | None, None ->
     Alcotest.(check bool) "reaches equality saturation" true reached_saturation
   | _ -> ());
  f graph term_1

(* Two-term variant of [run_and_check1]: both [s1] and [s2] are added before
   the rules are run, and [f] receives the graph and both classes. *)
let run_and_check2 ?node_limit ?fuel rules s1 s2 f () =
  let graph = EGraph.init () in
  let term_1 = EGraph.add_node graph (L.of_sexp s1) in
  let term_2 = EGraph.add_node graph (L.of_sexp s2) in
  let reached_saturation =
    EGraph.run_until_saturation ?node_limit ?fuel graph rules
  in
  (match fuel, node_limit with
   | None, None ->
     Alcotest.(check bool) "reaches equality saturation" true reached_saturation
   | _ -> ());
  f graph term_1 term_2

(* Test case asserting that [s1] and [s2] land in the same e-class.  The class
   equality test is passed as [~until] to the run, and the run's boolean result
   is discarded. *)
let check_proves_equal ?node_limit ?fuel rules s1 s2 () =
  let graph = EGraph.init () in
  let term_1 = EGraph.add_node graph (L.of_sexp s1) in
  let term_2 = EGraph.add_node graph (L.of_sexp s2) in
  let terms_are_equal g = EGraph.class_equal (EGraph.freeze g) term_1 term_2 in
  ignore
    (EGraph.run_until_saturation ~until:terms_are_equal ?node_limit ?fuel graph rules);
  (* Disabled saturation check, kept for reference:
   * match fuel, node_limit with
   *   _ , Some _ | Some _, _ -> ()
   * | None, None -> Alcotest.(check bool) "reaches equality saturation" true reached_saturation
   *)
  Alcotest.(check bool) "proves terms are equal modulo rewriting"
    true
    (terms_are_equal graph)

(* Test case asserting that [s1] and [s2] are still in different e-classes
   after the run performed by [run_and_check2]. *)
let check_cannot_prove_equal ?node_limit ?fuel rules s1 s2 =
  let assert_distinct graph term_1 term_2 =
    Alcotest.(check bool) "must not prove terms are equal modulo rewriting"
      false
      (EGraph.class_equal (EGraph.freeze graph) term_1 term_2)
  in
  run_and_check2 ?node_limit ?fuel rules s1 s2 assert_distinct

(* Test case asserting that the term extracted for [s1] after the run is
   exactly the s-expression [s2]. *)
let check_extract ?node_limit ?fuel rules s1 s2 =
  let assert_extracted graph term_1 =
    Alcotest.(check sexp)
      "extracted expression matches"
      s2
      (L.to_sexp (Extractor.extract graph term_1))
  in
  run_and_check1 ?node_limit ?fuel rules s1 assert_extracted



(* Parses a pattern s-expression into a query over the language [L]. *)
let qf = Query.of_sexp L.op_of_string

(* [from @-> into] is the unconditional rule rewriting pattern [from] to
   pattern [into]. *)
let (@->) from into = EGraph.Rule.make_constant ~from:(qf from) ~into:(qf into)

(* Conditional variant of [(@->)]: the rule is built with [if_] as its
   [~cond] predicate. *)
let rewrite from into ~if_ =
  EGraph.Rule.make_conditional ~from:(qf from) ~into:(qf into) ~cond:if_

(* Rewrite rules for propositional logic used by the proof tests below. *)
let rules = [
  (* def_imply *)      [%s ("?a" => "?b")] @-> [%s ((not "?a") || "?b")];
  (* double_neg *)     [%s (not (not "?a"))] @-> [%s "?a"];
  (* assoc_or *)       [%s ( "?a" || ("?b" || "?c"))] @-> [%s (("?a" || "?b") || "?c")];
  (* dist_and_or *)    [%s ("?a" && ("?b" || "?c"))] @-> [%s (("?a" && "?b") || ("?a" && "?c"))];
  (* dist_or_and: the pattern must be a disjunction over a conjunction.
     Matching [?a || (?b || ?c)] here (the assoc_or pattern) would equate
     terms that are not logically equivalent. *)
  (* dist_or_and *)    [%s ("?a" || ("?b" && "?c"))] @-> [%s (("?a" || "?b") && ("?a" || "?c"))];
  (* comm_or *)        [%s ("?a" || "?b")] @-> [%s ("?b" || "?a")];
  (* comm_and *)       [%s ("?a" && "?b")] @-> [%s ("?b" && "?a")];
  (* lem *)            [%s ("?a" || (not "?a"))] @-> [%s"true"];
  (* or_true *)        [%s ("?a" || "true")] @-> [%s "true"];
  (* and_true *)       [%s ("?a" && "true")] @-> [%s"?a"];
  (* contrapositive *) [%s ("?a" => "?b")] @-> [%s ((not "?b") => (not "?a"))];
  (* lem_imply *)      [%s (("?a" => "?b") && ((not "?a") => "?c")) ] @-> [%s ("?b" || "?c") ];
]

(* Test case: adds [start] to a fresh e-graph, runs [rules] with a backoff
   scheduler built from [match_limit] and [ban_length], and only afterwards
   adds each of [goals] and asserts it is in the same e-class as [start].
   The result of the run itself is ignored. *)
let proves ?(match_limit=1_000) ?(ban_length=5) ?node_limit ?fuel start goals () =
  let graph = EGraph.init () in
  let start = EGraph.add_node graph (L.of_sexp start) in
  let scheduler = Ego.Generic.Scheduler.Backoff.with_params ~match_limit ~ban_length in
  ignore @@ EGraph.run_until_saturation ~scheduler ?fuel ?node_limit graph rules;
  List.iter (fun goal ->
      let goal = EGraph.add_node graph (L.of_sexp goal) in
      Alcotest.(check bool)
        "goal can be proved from start"
        true
        (EGraph.class_equal (EGraph.freeze graph) start goal)
    ) goals

(* Variant of [proves] that adds every goal to the graph before running the
   rules, and passes an [~until] predicate that holds once [start] and the
   last goal share an e-class.  With no goals the predicate compares [start]
   with itself. *)
let proves_cached ?(match_limit=1_000) ?(ban_length=5) ?node_limit ?fuel start goals () =
  let graph = EGraph.init () in
  let start = EGraph.add_node graph (L.of_sexp start) in
  let goals = List.map (fun goal -> EGraph.add_node graph (L.of_sexp goal)) goals in
  (* Last goal in the list, or [start] when the list is empty. *)
  let last = List.fold_left (fun _ goal -> goal) start goals in
  let scheduler = Ego.Generic.Scheduler.Backoff.with_params ~match_limit ~ban_length in
  ignore @@ EGraph.run_until_saturation ~scheduler ?fuel ?node_limit ~until:(fun graph ->
      EGraph.class_equal (EGraph.freeze graph) start last
    ) graph rules;
  List.iter (fun goal ->
      Alcotest.(check bool)
        "goal can be proved from start"
        true
        (EGraph.class_equal (EGraph.freeze graph) start goal)
    ) goals

(* Test entry point for the "prop" suite. *)
let () =
  Alcotest.run "prop" [
    "ematch tests", [
      "check matches after merging", `Quick,
      (fun () ->
         let graph = EGraph.init () in
         let n1 = EGraph.add_node graph (L.of_sexp [%s (x && z)]) in
         let n2 = EGraph.add_node graph (L.of_sexp [%s (y && z)]) in
         EGraph.merge graph n1 n2;
         EGraph.rebuild graph;
         let query = qf [%s "?a" && z] in
         let matches = EGraph.find_matches (EGraph.freeze graph) query |> Iter.length in
         Alcotest.(check int) "2 matches" 2 matches);

      "check matches after saturating", `Quick,
      (fun () ->
         let graph = EGraph.init () in
         let scheduler = Ego.Generic.Scheduler.Backoff.with_params ~match_limit:1000 ~ban_length:5 in
         let _ = EGraph.add_node graph (L.of_sexp [%s (x && y)]) in
         let query = [%s "?a" && "?b"] @-> [%s "?b" && "?a"] in
         ignore @@ EGraph.run_until_saturation ~scheduler graph [query];
         let q = qf [%s "?a" && "?b"] in
         let matches = EGraph.find_matches (EGraph.freeze graph) q |> Iter.length in
         Alcotest.(check int) "2 matches" 2 matches)
    ];
    "proving contrapositive", [
      "proves idempotent", `Quick, proves [%s (x => y)] [[%s (x => y)]];
      "proves negation", `Quick, proves [%s (x => y)] [[%s (x => y)];
                                                       [%s ((not x) || y)]];
      "proves double negation", `Quick, proves [%s (x => y)] [[%s (x => y)];
                                                              [%s ((not x) || y)];
                                                              [%s ((not x) || (not (not y)))]];
      "proves commutativity", `Quick, proves [%s (x => y)] [[%s (x => y)];
                                                            [%s ((not x) || y)];
                                                            [%s ((not x) || (not (not y)))];
                                                            [%s ((not (not y)) || (not x))];
                                                           ];
      "proves contrapositive", `Quick, proves [%s (x => y)] [[%s (x => y)];
                                                             [%s ((not x) || y)];
                                                             [%s ((not x) || (not (not y)))];
                                                             [%s ((not (not y)) || (not x))];
                                                             [%s ((not y) => (not x))];
                                                            ]];
    "proving chain", [
      "proves idempotent", `Quick, proves [%s ((x => y) && (y => z))] [[%s ((x => y) && (y => z))]];
      "proves contrapositive", `Quick, proves [%s ((x => y) && (y => z))]
        [[%s ((x => y) && (y => z))];
         [%s (((not y) => (not x)) && (y => z))]];
      "proves commutativity", `Quick, proves [%s ((x => y) && (y => z))]
        [[%s ((x => y) && (y => z))];
         [%s (((not y) => (not x)) && (y => z))];
         [%s ((y => z) && ((not y) => (not x)))]];
      "proves negation", `Quick, proves [%s ((x => y) && (y => z))]
        [[%s ((x => y) && (y => z))];
         [%s (((not y) => (not x)) && (y => z))];
         [%s ((y => z) && ((not y) => (not x)))];
         [%s (z || (not x))]
        ];
      "proves commutativity", `Quick, proves
        ~node_limit:(`Bounded 10_000)
        ~fuel:(`Bounded 60)
        [%s ((x => y) && (y => z))]
        [[%s ((x => y) && (y => z))];
         [%s (((not y) => (not x)) && (y => z))];
         [%s ((y => z) && ((not y) => (not x)))];
         [%s (z || (not x))];
         [%s ((not x) || z)]; ];
      "proves chain", `Quick, proves_cached
        ~match_limit:(10_000) ~ban_length:5
        ~node_limit:(`Bounded 600_000)
        ~fuel:(`Bounded 50)
        [%s ((x => y) && (y => z))]
        [[%s ((x => y) && (y => z))];
         [%s (((not y) => (not x)) && (y => z))];
         [%s ((y => z) && ((not y) => (not x)))];
         [%s (z || (not x))];
         [%s ((not x) || z)];
         [%s (x => z)];
        ]
    ]]

--------------------------------------------------------------------------------