├── .gitignore
├── .gitlab-ci.yml
├── LICENSE
├── doc
│   ├── dune
│   └── index.mld
├── dune-project
├── ego.opam
├── lib
│   ├── basic.ml
│   ├── basic.mli
│   ├── dune
│   ├── ego.ml
│   ├── ego.mli
│   ├── equivalence.ml
│   ├── equivalence.mli
│   ├── generic.ml
│   ├── generic.mli
│   ├── id.ml
│   ├── id.mli
│   ├── language.ml
│   ├── ordered_set.ml
│   ├── ordered_set.mli
│   ├── query.ml
│   ├── scheduler.ml
│   ├── symbol.ml
│   ├── symbol.mli
│   ├── term.ml
│   └── types.ml
├── macros
│   ├── dune
│   └── ppx_sexp.ml
├── readme.md
└── test
    ├── dune
    ├── test_basic.ml
    ├── test_generic.ml
    ├── test_math.ml
    └── test_prop.ml
/.gitignore:
--------------------------------------------------------------------------------
1 | _build/
--------------------------------------------------------------------------------
/.gitlab-ci.yml:
--------------------------------------------------------------------------------
1 | image: ruby:2.7
2 |
3 | pages:
4 |   script:
5 |   - echo 'Nothing to do...'
6 |   artifacts:
7 |     paths:
8 |     - public/
9 |   only:
10 |   - pages
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year> <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program> Copyright (C) <year> <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/doc/dune:
--------------------------------------------------------------------------------
1 | (documentation ; package documentation pages, e.g. doc/index.mld
2 | (package ego))
3 |
4 |
--------------------------------------------------------------------------------
/dune-project:
--------------------------------------------------------------------------------
1 | (lang dune 2.9)
2 | (package
3 | (name ego)
4 | (synopsis "Ego (EGraphs OCaml) is an extensible EGraph library for OCaml.")
5 | (description
6 | "Ego is an extensible egraph library for OCaml loosely based on the egg library in Rust.")
7 | (depends
8 | (ocaml (>= 4.08)) ; not 4.0.8: opam orders 4.0.8 below 4.01, admitting much older compilers
9 | (containers (>= 3.3))
10 | (containers-data (>= 3.3))
11 | (iter (>= 1.2.1))
12 | (ppx_deriving (>= 4.4))
13 | (ocamldot (>= 1.1))
14 | (sexplib (>= v0.14.0)))) ; NOTE(review): ego.opam also lists ppx_inline_alcotest, which lib/dune uses as a ppx; consider declaring it here
15 | (version 0.0.6)
16 | (name ego)
17 | (generate_opam_files true)
18 | (license GPL-3.0+)
19 | (source (uri git+https://gitlab.com/gopiandcode/ego.git))
20 | (bug_reports https://gitlab.com/gopiandcode/ego/issues)
21 | (homepage https://gitlab.com/gopiandcode/ego)
22 | (authors "Kiran Gopinathan")
23 | (maintainers "kirang@comp.nus.edu.sg")
24 |
--------------------------------------------------------------------------------
/ego.opam:
--------------------------------------------------------------------------------
1 | # This file is generated by dune, edit dune-project instead
2 | opam-version: "2.0"
3 | version: "0.0.6"
4 | synopsis: "Ego (EGraphs OCaml) is an extensible EGraph library for OCaml."
5 | description:
6 | "Ego is an extensible egraph library for OCaml loosely based on the egg library in Rust."
7 | maintainer: ["kirang@comp.nus.edu.sg"]
8 | authors: ["Kiran Gopinathan"]
9 | license: "GPL-3.0+"
10 | homepage: "https://gitlab.com/gopiandcode/ego"
11 | bug-reports: "https://gitlab.com/gopiandcode/ego/issues"
12 | depends: [
13 | "dune" {>= "2.9"}
14 | "ocaml" {>= "4.08"}
15 | "containers" {>= "3.3"}
16 | "containers-data" {>= "3.3"}
17 | "iter" {>= "1.2.1"}
18 | "ppx_deriving" {>= "4.4"}
19 | "ocamldot" {>= "1.1"}
20 | "sexplib" {>= "v0.14.0"}
21 | "odoc" {with-doc}
22 | "ppx_inline_alcotest" {with-test}
23 | ]
24 | build: [
25 | ["dune" "subst"] {dev}
26 | [
27 | "dune"
28 | "build"
29 | "-p"
30 | name
31 | "-j"
32 | jobs
33 | "--promote-install-files=false"
34 | "@install"
35 | "@runtest" {with-test}
36 | "@doc" {with-doc}
37 | ]
38 | ["dune" "install" "-p" name "--create-install-files" name]
39 | ]
40 | dev-repo: "git+https://gitlab.com/gopiandcode/ego.git"
41 |
--------------------------------------------------------------------------------
/lib/basic.ml:
--------------------------------------------------------------------------------
1 | open [@warning "-33"] Containers
2 | open Language
3 | open Types
4 |
5 | module StringMap = Map.Make (String) (* string-keyed maps; used below for pattern-variable environments *)
6 | let str printer x = Format.to_string printer x (* render [x] to a string with [printer] *)
7 | let lappend_pair first (second, third) = (first, second, third) (* prepend [first] to a pair *)
8 |
9 | module Symbol = Symbol
10 |
11 | module Query = struct (* queries whose heads are Symbol.t values *)
12 | module Query = Query (* keeps the generic query module reachable as Query.Query *)
13 |
14 | type 'a query = 'a Query.t = V of string | Q of 'a * 'a query list (* V names a pattern variable; Q applies a head to sub-patterns *)
15 | type t = Symbol.t query
16 |
17 | let of_sexp = Query.of_sexp Symbol.intern (* heads are converted with Symbol.intern *)
18 | let to_sexp = Query.to_sexp (Format.to_string Symbol.pp) (* heads are printed back to strings *)
19 |
20 | let pp = Query.pp Symbol.pp
21 | let show = Format.to_string pp
22 |
23 | let variables = Query.variables (* set of variable names; Rule.make compares these with StringSet.subset *)
24 |
25 | end
26 |
27 | module Term = Term (* re-export of the term module (lib/term.ml) *)
28 |
29 | module Rule : sig (* a rewrite rule pairs a match pattern with a replacement pattern *)
30 | type 'sym rule = Query.t * Query.t (* the type parameter is unused *)
31 | type t = Symbol.t rule [@@deriving show]
32 | val make: from:Query.t -> into:Query.t -> t option
33 | end = struct
34 | type 'sym rule = Query.t * Query.t
35 | type t = Symbol.t rule
36 |
37 | let make ~from ~into = (* None when [into] mentions a variable that [from] does not bind *)
38 | let pattern_vars = Query.variables from in
39 | let rewrite_vars = Query.variables into in
40 | if StringSet.subset rewrite_vars pattern_vars
41 | then Some (from, into)
42 | else None
43 |
44 | let pp fmt (lhs,rhs) = (* prints lhs -> rhs inside a single horizontal box *)
45 | Format.pp_open_hbox fmt ();
46 | Query.pp fmt lhs;
47 | Format.pp_print_string fmt " -> ";
48 | Query.pp fmt rhs;
49 | Format.pp_close_box fmt ()
50 |
51 | let show = str pp
52 |
53 | let%test "rules are printed as expected" =
54 | Alcotest.(check string)
55 | "prints as expected"
56 | "(<< ?a 1) -> (* ?a 2)" (str pp ((Q (Symbol.intern "<<", [V "a"; Q (Symbol.intern "1", [])])),
57 | (Q (Symbol.intern "*", [V "a"; Q (Symbol.intern "2", [])]))))
58 |
59 | end
60 |
61 |
62 |
63 | (* * Egraphs *)
64 | (* ** Types *)
65 | type egraph = { (* mutable state of a Basic egraph *)
66 | mutable version: int; (* bumped when add_enode creates a class or merge joins two classes *)
67 |
68 | uf: Id.store; (* tracks equivalence classes of
69 | class ids *)
70 | class_members: (enode * Id.t) Vector.vector Id.Map.t; (* maps classes to the canonical nodes
71 | they contain, and any classes that are
72 | children of these nodes *) (* NOTE(review): EGraph.repair instead reads the Id.t as the class owning the node; confirm the intended meaning *)
73 | hash_cons: (enode, Id.t) Hashtbl.t; (* maps canonical nodes to their
74 | equivalence classes *)
75 | worklist: Id.t Vector.vector; (* List of equivalence classes for which
76 | nodes are out of date - i.e
77 | canonical(node) != node *)
78 | }
79 |
80 |
81 |
82 |
83 |
84 |
85 | (* ** Graph *)
86 | module EGraph = struct (* operations on the mutable [egraph] record declared above *)
87 |
88 | type t = egraph
89 |
90 | (* *** Pretty printing *)
91 | let pp ?(pp_id=EClassId.pp) fmt self = (* debug dump: class_members per class, then every hash-cons entry *)
92 | let open Format in
93 | pp_print_string fmt "(egraph";
94 | pp_open_hovbox fmt 1;
95 | pp_print_space fmt ();
96 | pp_print_string fmt "(eclasses ";
97 | pp_open_hvbox fmt 1;
98 | Id.Map.to_seq self.class_members
99 | |> Seq.to_list
100 | |> pp_print_list ~pp_sep:pp_print_space
101 | (fun fmt (cls, elts) ->
102 | pp_print_string fmt "(";
103 | pp_open_hvbox fmt 1;
104 | pp_id fmt cls;
105 | if not @@ Vector.is_empty elts then
106 | pp_print_space fmt ();
107 | Vector.pp ~pp_sep:pp_print_space
108 | (fun fmt (node, id) ->
109 | pp_print_string fmt "(";
110 | pp_open_hbox fmt ();
111 | pp_id fmt id;
112 | pp_print_space fmt ();
113 | ENode.pp ~pp_id fmt node;
114 | pp_close_box fmt ();
115 | pp_print_string fmt ")";
116 | ) fmt elts;
117 | pp_close_box fmt ();
118 | pp_print_string fmt ")";
119 | ) fmt;
120 | pp_close_box fmt ();
121 | pp_print_string fmt ")";
122 | pp_print_space fmt ();
123 | pp_print_string fmt "(enodes ";
124 | pp_open_hvbox fmt 1;
125 | Hashtbl.to_seq self.hash_cons
126 | |> Seq.to_list
127 | |> pp_print_list ~pp_sep:pp_print_space
128 | (fun fmt (node, cls) ->
129 | pp_print_string fmt "(";
130 | pp_open_hvbox fmt 1;
131 | pp_id fmt cls;
132 | pp_print_space fmt ();
133 | ENode.pp ~pp_id fmt node;
134 | pp_close_box fmt ();
135 | pp_print_string fmt ")";
136 | ) fmt;
137 | pp_close_box fmt ();
138 | pp_print_string fmt ")";
139 | pp_close_box fmt ();
140 | pp_print_string fmt ")"
141 |
142 | let (.@[]) self fn = fn self [@@inline always] (* [self.@[f]] is sugar for [f self] *)
143 |
144 |
145 | (* *** Initialization *)
146 | let init () = { (* empty graph at version 0 *)
147 | version=0;
148 | uf=Id.create_store ();
149 | class_members=Id.Map.create 10;
150 | hash_cons=Hashtbl.create 10;
151 | worklist=Vector.create ();
152 | }
153 |
154 | (* *** Eclasses *)
155 | let new_class self = (* allocates a class id from the union-find store *)
156 | let id = Id.make self.uf () in
157 | id
158 |
159 | let get_class_members self id = (* member vector of [id], created and registered empty on first access *)
160 | match Id.Map.find_opt self.class_members id with
161 | | Some classes -> classes
162 | | None ->
163 | let cls = Vector.create () in
164 | Id.Map.add self.class_members id cls;
165 | cls
166 |
167 | (* Adds [node] (after canonicalising it with the union-find) and returns
168 | the canonical id of its class, reusing the class of a congruent node
169 | already present in the hash cons. *)
170 | let add_enode self node =
171 | let node = ENode.canonicalise self.uf node in
172 | let id = match Hashtbl.find_opt self.hash_cons node with
173 | | None ->
174 | self.version <- self.version + 1;
175 | (* There are no nodes congruent to this node in the graph *)
176 | let id = self.@[new_class] in
177 | let cls = self.@[get_class_members] id in
178 | Vector.append_list cls @@ List.map (fun child -> (* NOTE(review): records (node, child) in the new class itself, whereas repair reads the second component as the class owning the node; confirm intent *)
179 | (node, child)
180 | ) (ENode.children node);
181 | Hashtbl.replace self.hash_cons node id;
182 | id
183 | | Some id -> id in
184 | Id.find self.uf id
185 |
186 | let rec subst self pat env = (* builds [pat] under bindings [env]; an unbound variable makes StringMap.find raise Not_found *)
187 | match pat with
188 | | Query.V id -> StringMap.find id env
189 | | Q (sym, args) ->
190 | let enode = (sym, List.map (fun arg -> self.@[subst] arg env) args) in
191 | self.@[add_enode] enode
192 |
193 | let add_node self ((sym, children) : Term.t) = (* interns the string head and adds the node *)
194 | add_enode self (Symbol.intern sym, children)
195 |
196 | let add_sexp self sexp = add_node self @@ Term.of_sexp (add_node self) sexp (* presumably Term.of_sexp adds subterms via add_node first; TODO confirm *)
197 |
198 | let find self vl = Id.find self.uf vl (* canonical id of class [vl] *)
199 |
200 | let append_to_worklist self vl = (* queues [vl] for repair on the next rebuild *)
201 | Vector.push self.worklist vl
202 |
203 | let merge self a b = (* unions the classes of [a] and [b]; callers must rebuild before querying *)
204 | let (+=) va vb = Vector.append va vb in
205 | let a = Id.find self.uf a in
206 | let b = Id.find self.uf b in
207 | if Id.eq_id a b then ()
208 | else begin
209 | self.version <- self.version + 1;
210 | assert (Id.eq_id a (Id.union self.uf a b));
211 | assert (Id.eq_id a (Id.find self.uf a));
212 | assert (Id.eq_id a (Id.find self.uf b));
213 | self.@[get_class_members] b += self.@[get_class_members] a; (* NOTE(review): members move into [b], which the asserts show is no longer the representative, while rebuild repairs find b = a; confirm the direction *)
214 | Vector.clear (self.@[get_class_members] a);
215 | self.@[append_to_worklist] b;
216 | end
217 |
218 | let repair self ecls_id = (* re-canonicalises the nodes recorded for [ecls_id] in the hash cons and merges classes of nodes that became congruent *)
219 | let (+=) va vb = Vector.append_iter va vb in
220 | let uses = self.@[get_class_members] ecls_id in
221 | let uses =
222 | let res = Vector.copy uses in (* take the recorded uses, leaving the stored vector empty *)
223 | Vector.clear uses;
224 | res in
225 | (* update canonical uses in hashcons *)
226 | Vector.to_iter uses (fun (p_node, p_eclass) ->
227 | Hashtbl.remove self.hash_cons p_node;
228 | let p_node = self.uf.@[ENode.canonicalise] p_node in
229 | Hashtbl.replace self.hash_cons p_node (self.@[find] p_eclass)
230 | );
231 | let new_uses = Hashtbl.create 10 in
232 | Vector.to_iter uses (fun (p_node, p_eclass) ->
233 | let p_node = self.uf.@[ENode.canonicalise] p_node in
234 | begin match Hashtbl.find_opt new_uses p_node with
235 | | None -> ()
236 | | Some nd -> self.@[merge] p_eclass nd
237 | end;
238 | Hashtbl.replace new_uses p_node (self.@[find] p_eclass)
239 | );
240 | (self.@[get_class_members] (self.@[find] ecls_id)) += (Hashtbl.to_iter new_uses)
241 |
242 | let rebuild self = (* repairs queued classes, deduplicated by canonical id, until the worklist stays empty *)
243 | while not @@ Vector.is_empty self.worklist do
244 | let worklist = Id.Set.of_iter (Vector.to_iter self.worklist |> Iter.map (self.@[find])) in
245 | Vector.clear self.worklist;
246 | Id.Set.to_iter worklist (fun ecls_id ->
247 | self.@[repair] ecls_id
248 | )
249 | done
250 |
251 | (* *** Exports *)
252 | (* **** Export eclasses *)
253 | let eclasses self = (* groups every hash-cons node under the canonical id of its class *)
254 | let r = Id.Map.create 10 in
255 | Hashtbl.iter (fun node eid ->
256 | let eid = Id.find self.uf eid in
257 | match Id.Map.find_opt r eid with
258 | | None -> let ls = Vector.of_list [node] in Id.Map.add r eid ls
259 | | Some ls -> Vector.push ls node
260 | ) self.hash_cons;
261 | r
262 |
263 | (* **** Export as dot *)
264 | let to_dot self = (* one Graphviz cluster per eclass, plus an edge from each node to each child class *)
265 | let eclasses = eclasses self in
266 | let stmt_list =
267 | let rev_map =
268 | Hashtbl.to_seq self.hash_cons
269 | |> Seq.map Pair.swap
270 | |> Id.Map.of_seq in
271 | let to_label id =
272 | let rec to_str id =
273 | match Id.Map.find_opt rev_map id with
274 | | None -> Format.to_string EClassId.pp id
275 | | Some (sym, []) -> Format.to_string Symbol.pp sym
276 | | Some (sym, children) ->
277 | Printf.sprintf "(%s %s)"
278 | (Format.to_string Symbol.pp sym)
279 | (List.to_string ~sep:" " to_str children) in
280 | to_str id in
281 | let to_label_node (sym,children) =
282 | match children with
283 | | [] -> Format.to_string Symbol.pp sym
284 | | children ->
285 | Printf.sprintf "(%s %s)"
286 | (Format.to_string Symbol.pp sym)
287 | (List.to_string ~sep:" " to_label children) in
288 | let to_id id =
289 | Odot.Double_quoted_id (to_label id) in
290 | let to_node_id node =
291 | Odot.Double_quoted_id (to_label_node node) in
292 | let to_subgraph_id id =
293 | Odot.Simple_id (Printf.sprintf "cluster_%d" (Id.repr id)) in
294 | let sub_graphs =
295 | (fun f -> Fun.flip Id.Map.iter eclasses (Fun.curry f))
296 | |> Iter.map (fun (eclass, enodes) ->
297 | let nodes =
298 | Vector.to_iter enodes
299 | |> Iter.map (fun (node: enode) ->
300 | let node_id = to_node_id node in
301 | let attrs = Odot.[Simple_id "label",
302 | Some (Double_quoted_id
303 | (Format.to_string Symbol.pp (fst node)))] in
304 | Odot.Stmt_node ((node_id, None), attrs))
305 | |> Iter.to_list in
306 | Odot.(Stmt_subgraph {
307 | sub_id= Some (to_subgraph_id eclass);
308 | sub_stmt_list=
309 | Stmt_attr (
310 | Attr_graph [
311 | (Simple_id "label", Some (Simple_id (Format.to_string EClassId.pp eclass)))
312 | ]) :: nodes;
313 | })
314 | )
315 | |> Iter.to_list in
316 | let edges =
317 | (fun f -> Fun.flip Id.Map.iter eclasses (Fun.curry f))
318 | |> Iter.flat_map (fun (_eclass, enodes) ->
319 | Vector.to_iter enodes
320 | |> Iter.flat_map (fun node ->
321 | let label = to_node_id node in
322 | Iter.of_list (ENode.children node)
323 | |> Iter.map (fun child ->
324 | let child_label = to_id child in
325 | Odot.(Stmt_edge (
326 | Edge_node_id (label, None),
327 | [Edge_node_id (child_label, None)],
328 | []
329 | ))
330 | )
331 | )
332 | )
333 | |> Iter.to_list in
334 | (List.append sub_graphs edges) in
335 | Odot.{
336 | strict=true;
337 | kind=Digraph;
338 | id=None;
339 | stmt_list;
340 | }
341 |
342 | (* **** Print as dot *)
343 | let pp_dot fmt st = (* renders to_dot output as Graphviz text *)
344 | Format.pp_print_string fmt (Odot.string_of_graph (to_dot st))
345 |
346 | let extract cost eg = (* iterates per-class costs to a fixpoint, then returns a function building the cheapest term of a class *)
347 | let eclasses = eg.@[eclasses] in
348 | let cost_map = Id.Map.create 10 in
349 | let node_total_cost node =
350 | let has_cost id = Id.Map.mem cost_map (eg.@[find] id) in
351 | if List.for_all has_cost (Term.children node)
352 | then let cost_f id = fst @@ Id.Map.find cost_map (eg.@[find] id) in Some (cost cost_f node)
353 | else None in
354 | let make_pass enodes =
355 | let cost, node =
356 | Vector.to_iter enodes
357 | |> Iter.map (fun n -> (node_total_cost n, n))
358 | |> Iter.min_exn ~lt:(fun (c1, _) (c2, _) -> (* costed nodes order before uncosted ones *)
359 | (match c1, c2 with
360 | | None, None -> 0
361 | | Some _, None -> -1
362 | | None, Some _ -> 1
363 | | Some c1, Some c2 -> Float.compare c1 c2) = -1) in
364 | Option.map (fun cost -> (cost, node)) cost in
365 | let find_costs () =
366 | let any_changes = ref true in
367 | while !any_changes do
368 | any_changes := false;
369 | Fun.flip Id.Map.iter eclasses (fun eclass enodes ->
370 | let pass = make_pass enodes in
371 | match Id.Map.find_opt cost_map eclass, pass with
372 | | None, Some nw -> Id.Map.replace cost_map eclass nw;
373 | any_changes := true
374 | | Some ((cold, _)), Some ((cnew, _) as nw)
375 | when Float.compare cnew cold = -1 ->
376 | Id.Map.replace cost_map eclass nw;
377 | any_changes := true
378 | | _ -> ()
379 | )
380 | done in
381 | let rec extract eid = (* NOTE(review): a class that never received a cost makes Id.Map.find fail here, presumably with Not_found *)
382 | let eid = find eg eid in
383 | let enode = Id.Map.find cost_map eid |> snd in
384 | let head = Atom (Format.to_string Symbol.pp @@ fst enode) in
385 | match ENode.children enode with
386 | | [] -> head
387 | | children -> List (head :: List.map extract children) in
388 | find_costs ();
389 | fun result -> extract result
390 |
391 |
392 | (* ** Matching *)
393 | let ematch eg classes pattern = (* yields (class id from [classes], variable bindings) for each match of [pattern] *)
394 | let concat_map f l = Iter.concat (Iter.map f l) in
395 | let rec enode_matches p enode env = (* V patterns are handled by match_in, hence the silenced partial-match warning; heads must agree in symbol and arity, since List.iter2 raises on unequal lengths *)
396 | match[@warning "-8"] p,enode with
397 | | Query.(Q (f, args), (f', args')) when not ((Equal.map Symbol.repr Equal.int) f f') || List.length args <> List.length args' ->
398 | Iter.empty
399 | | (Q (_, args), (_, args')) ->
400 | (fun f -> List.iter2 (Fun.curry f) args args')
401 | |> Iter.fold (fun envs (qvar, trm) ->
402 | concat_map (fun env' -> match_in qvar trm env') envs) (Iter.singleton env)
403 | and match_in p eid env =
404 | let eid = find eg eid in
405 | match p with
406 | | V id -> begin
407 | match StringMap.find_opt id env with
408 | | None -> Iter.singleton (StringMap.add id eid env)
409 | | Some eid' when Id.eq_id eid eid' -> Iter.singleton env
410 | | _ -> Iter.empty
411 | end
412 | | p ->
413 | match Id.Map.find_opt classes eid with
414 | | Some v -> Vector.to_iter v |> concat_map (fun enode -> enode_matches p enode env)
415 | | None -> Iter.empty
416 | in
417 | (fun f -> Id.Map.iter (Fun.curry f) classes)
418 | |> concat_map (fun (eid, _) ->
419 | Iter.map (fun s -> (eid, s)) (match_in pattern eid StringMap.empty))
420 |
421 | (* ** Rewriting System *)
422 | let apply_rules eg rules = (* runs every rule once against the class grouping taken on entry, then rebuilds *)
423 | let eclasses = eclasses eg in
424 | let find_matches (from_rule, to_rule) =
425 | ematch eg eclasses from_rule |> Iter.map (lappend_pair to_rule) in
426 | let for_each_match = Iter.of_list rules |> Iter.flat_map find_matches in
427 | for_each_match begin fun (to_rule, eid, env) ->
428 | let new_eid = subst eg to_rule env in
429 | merge eg eid new_eid
430 | end;
431 | rebuild eg
432 |
433 | let run_until_saturation ?fuel eg rules = (* true once a pass leaves version unchanged; false if fuel runs out first *)
434 | match fuel with
435 | | None ->
436 | let rec loop last_version =
437 | apply_rules eg rules;
438 | if not @@ Int.equal eg.version last_version
439 | then loop eg.version
440 | else () in
441 | loop eg.version; true
442 | | Some fuel ->
443 | let rec loop fuel last_version =
444 | apply_rules eg rules;
445 | if not @@ Int.equal eg.version last_version
446 | then if fuel > 0
447 | then loop (fuel - 1) eg.version
448 | else false
449 | else true in
450 | loop fuel eg.version
451 |
452 |
453 | end
454 |
455 | let%test "test egraph matching" =
456 | let g = EGraph.init () in
457 | let g1 = EGraph.add_sexp g [%s g 1] in
458 | let g2 = EGraph.add_sexp g [%s g 2] in
459 | EGraph.merge g g1 g2; (* put (g 1) and (g 2) into one class *)
460 | EGraph.rebuild g;
461 | let query = Query.of_sexp [%s g "?a"] in
462 | let matches = EGraph.ematch g (EGraph.eclasses g) query |> Iter.to_list in
463 | (* Should have two matches: (g 1) and (g 2) *)
464 | Alcotest.(check int) "(g ?a) has 2 matches"
465 | 2 (List.length matches)
467 | let%test "test egraph matching with a shared pattern variable" =
468 | let g = EGraph.init () in
469 | let g1 = EGraph.add_sexp g [%s g 1] in
470 | let g2 = EGraph.add_sexp g [%s g 2] in
471 | let g3 = EGraph.add_sexp g [%s g 3] in
472 | let f1 = EGraph.add_sexp g [%s (f 1 (g 2))] in
473 | let f2 = EGraph.add_sexp g [%s (f 2 (g 3))] in
474 | let f3 = EGraph.add_sexp g [%s (f 3 (g 1))] in
475 | EGraph.merge g g1 g2;
476 | EGraph.merge g g2 g3;
477 | EGraph.merge g f1 f2;
478 | EGraph.merge g f2 f3;
479 | EGraph.rebuild g; (* the f-class now holds (f k _) for k = 1..3 and the g-class holds (g k) for k = 1..3 *)
480 | let query = Query.of_sexp [%s f "?a" (g "?a")] in (* ?a must bind the same class in both positions *)
481 | let matches = EGraph.ematch g (EGraph.eclasses g) query |> Iter.to_list in
482 | Alcotest.(check int) "(f ?a (g ?a)) has 3 matches"
483 | 3 (List.length matches)
484 |
--------------------------------------------------------------------------------
/lib/basic.mli:
--------------------------------------------------------------------------------
1 | type egraph (* representation kept private to basic.ml *)
2 |
3 | module Symbol : sig
4 | type t = private int (* interned string *)
5 | val intern : string -> t
6 | val to_string : t -> string
7 | end
8 |
9 | module Query : sig
10 | type t [@@deriving show]
11 | val of_sexp : Sexplib0.Sexp.t -> t
12 | val to_sexp : t -> Sexplib0.Sexp.t
13 | end
14 |
15 | module Rule : sig
16 |
17 | type t [@@deriving show]
18 |
19 | val make: from:Query.t -> into:Query.t -> t option (* None if [into] uses a variable not bound by [from] *)
20 |
21 | end
22 |
23 | module EGraph : sig
24 | type t = egraph
25 |
26 | val pp : ?pp_id:(Format.formatter -> Id.t -> unit) -> Format.formatter -> t -> unit (* debugging dump of the internal tables *)
27 | val pp_dot : Format.formatter -> t -> unit (* Graphviz rendering *)
28 |
29 | val init : unit -> t
30 |
31 | val add_sexp: t -> Sexplib.Sexp.t -> Id.t
32 |
33 | val to_dot : t -> Odot.graph
34 |
35 | val merge : t -> Id.t -> Id.t -> unit (* call rebuild before querying or extracting *)
36 |
37 | val rebuild : t -> unit
38 |
39 | val extract: ((Id.t -> float) -> (Symbol.t * Id.t list) -> float) -> t -> Id.t -> Sexplib0.Sexp.t (* the cost function receives a child-cost lookup and a node *)
40 |
41 | val apply_rules : t -> Rule.t list -> unit
42 |
43 | val run_until_saturation: ?fuel:int -> t -> Rule.t list -> bool (* false if fuel runs out before a fixpoint *)
44 |
45 | end
46 |
47 |
--------------------------------------------------------------------------------
/lib/dune:
--------------------------------------------------------------------------------
1 | (library
2 | (name ego)
3 | (public_name ego)
4 | (libraries containers iter dot containers-data sexplib)
5 | (inline_tests) ; runs the let%test blocks in the library sources
6 | (preprocess (pps ppx_sexp ppx_inline_alcotest ppx_deriving.std))) ; ppx_sexp is the local macro (macros/ppx_sexp.ml), presumably providing the [%s ...] syntax
7 |
8 | (env
9 | (dev
10 | (flags (:standard -w -58)))) ; silence warning 58 (missing .cmx file) in dev builds
11 |
--------------------------------------------------------------------------------
/lib/ego.ml:
--------------------------------------------------------------------------------
1 | module Id = Id (* library entry point: re-exports the public modules *)
2 | module Basic = Basic
3 | module Generic = struct (* groups Query, Scheduler and the contents of Language and Generic *)
4 | module Query = Query
5 | module Scheduler = Scheduler
6 | include Language
7 | include Generic
8 | end
9 |
--------------------------------------------------------------------------------
/lib/ego.mli:
--------------------------------------------------------------------------------
1 | (** Ego is an extensible egraph library for OCaml. The interface to
2 | Ego is loosely based on Rust's egg library and reimplements
3 | its EClass analysis in pure OCaml.
4 |
5 | {{:#top}Ego} provides two interfaces to its equality saturation
6 | engine:
7 |
8 | 1. {!Ego.Basic} - an out-of-the-box interface to pure equality
9 | saturation (i.e. supporting only syntactic rewrites).
10 |
11 | 2. {!Ego.Generic} - a higher order interface to equality saturation,
12 | parameterised over custom user-defined analyses.
13 |
14 | You may want to check out the {{:../index.html} quick start guide}.
15 | *)
16 |
17 | module Id : sig
18 | (** This module provides an implementation of an {i efficient}
19 | {b union-find} data-structure. Its main exported type,
20 | {!t}, is used to represent equivalence classes in the EGraph
21 | data-structures provided by {!Ego}. *)
22 |
23 |
24 | type t = private int
25 | (** A private integer type used to represent equivalence classes in
26 | {!Ego}. *)
27 |
28 | end
29 |
30 |
31 |
32 | module Basic: sig
33 |
34 | (** This module implements a {i fairly efficient}
35 | "syntactic-rewrite-only" EGraph-based equality saturation engine
36 | that operates over Sexps.
37 |
38 | The main interface to EGraph is under the module {!EGraph}.
39 |
40 | Note: This module is not safe for serialization as it uses
41 | {!Symbol.t} internally to represent strings, and so will be
42 | dependent on the execution context. If you wish to persist
43 | EGraphs across executions, check out the EGraphs defined in
44 | {!Ego.Generic} *)
45 |
46 | module Symbol : sig
47 | (** Implements an efficient encoding of strings
48 |
49 | Note: Datatypes using this module are not safe for
50 | serialization as tag associated with each string dependent on
51 | the execution context.
52 |
53 | If you wish to persist EGraphs across executions, check out the
54 | EGraphs defined in {!Ego.Generic} *)
55 |
56 | type t = private int
57 | (** Abstract type providing an efficient encoding of some string value. *)
58 |
59 | val intern : string -> t
60 | (** [intern s] returns a symbol representing the string [s]. *)
61 |
62 | val to_string : t -> string
63 | (** [to_string t] returns the string associated with symbol [t]. *)
64 | end
65 |
66 | module Query : sig
67 | (** This module encodes patterns (for both matching and
68 | transformation) over Sexprs and is part of {!Ego.Basic}'s API
69 | for expressing syntactic rewrites. *)
70 |
71 | type t
72 | (** Encodes a pattern over S-expressions. *)
73 |
74 | val pp: Format.formatter -> t -> unit
75 | (** [pp fmt s] pretty prints the query [s]. *)
76 |
77 | val show: t -> string
78 | (** [show s] converts the query [s] to a string *)
79 |
80 | val of_sexp : Sexplib0.Sexp.t -> t
81 | (** [of_sexp s] builds a pattern from a s-expression
82 |
83 | Note: Any atom prefixed with "?" will be treated as a pattern
84 | variable.
85 |
86 | For example, the following pattern will match any multiplication expressions:
87 | {[
88 | List [Atom "*"; Atom "?a"; Atom "?b"]
89 | ]}
90 | *)
91 |
92 | val to_sexp : t -> Sexplib0.Sexp.t
93 | (** [to_sexp s] converts a pattern back into an s-expression. This is idempotent with {!of_sexp}. *)
94 |
95 | end
96 |
97 | module Rule : sig
98 | (** This module encodes syntactic rewrite rules over Sexprs and is part of {!Ego.Basic}'s API
99 | for expressing syntactic rewrites. *)
100 |
101 | type t
102 | (** Encodes a rewrite rule over S-expressions. *)
103 |
104 | val pp: Format.formatter -> t -> unit
105 | (** [pp fmt r] pretty prints the rewrite rule [r]. *)
106 |
107 | val show: t -> string
108 | (** [show r] converts the rewrite rule [r] to a string *)
109 |
110 |
111 | val make: from:Query.t -> into:Query.t -> t option
112 | (** [make ~from ~into] builds a syntactic rewrite rule from a
113 | matching pattern [from] and a result pattern [into].
114 |
115 | Iff [into] contains variables that are not bound in [from],
116 | then the rule is invalid, and the function will return [None]. *)
117 |
118 | end
119 |
120 | module EGraph : sig
121 | (** This module defines the main interface to the EGraph provided
122 | by {!Ego.Basic}. *)
123 |
124 | type t
125 | (** Represents a syntactic-rewrite-only EGraph that operates over
126 | Sexps. *)
127 |
128 | val pp : ?pp_id:(Format.formatter -> Id.t -> unit) -> Format.formatter -> t -> unit
129 | (** [pp ?pp_id fmt graph] prints an internal representation of the
130 | [graph].
131 |
132 | {b Note}: This is primarily intended for debugging, and the
133 | output format is not guaranteed to remain consistent over
134 | versions. *)
135 |
136 | val pp_dot : Format.formatter -> t -> unit
137 | (** [pp_dot fmt graph] pretty prints [graph] in a Graphviz format. *)
138 |
139 | val init : unit -> t
140 | (** [init ()] creates a new EGraph. *)
141 |
142 | val add_sexp : t -> Sexplib0.Sexp.t -> Id.t
143 | (** [add_sexp graph sexp] adds [sexp] to [graph] and returns the
144 | equivalence class associated with the term. *)
145 |
146 | val to_dot : t -> Odot.graph
147 | (** [to_dot graph] converts [graph] into a Graphviz representation. *)
148 |
149 | val merge : t -> Id.t -> Id.t -> unit
150 | (** [merge graph id1 id2] merges the equivalence classes
151 | associated with [id1] and [id2].
152 |
153 | {b Note}: If you call {!merge} manually, you must call
154 | {!rebuild} before running any queries or extraction. *)
155 |
156 | val rebuild : t -> unit
157 | (** [rebuild graph] restores the internal invariants of the EGraph
158 | [graph].
159 |
160 | {b Note}: If you call {!merge} manually, you must call
161 | {!rebuild} before running any queries or extraction. *)
162 |
163 | val extract: ((Id.t -> float) -> (Symbol.t * Id.t list) -> float) -> t -> Id.t -> Sexplib0.Sexp.t
164 | (** [extract cost_fn graph] computes an extraction function [Id.t
165 | -> Sexplib0.Sexp.t] to extract terms (specified by [Id.t]) from
166 | the EGraph.
167 |
168 | [cost_fn f (sym,children)] should assign costs to the node
169 | with tag [sym] and children [children] - it can use [f] to
170 | determine the cost of a child. *)
171 |
172 | val apply_rules : t -> Rule.t list -> unit
173 | (** [apply_rules graph rules] runs each of the rewrites in [rules]
174 | exactly once over the egraph [graph] and then returns. *)
175 |
176 | val run_until_saturation: ?fuel:int -> t -> Rule.t list -> bool
177 | (** [run_until_saturation ?fuel graph rules] repeatedly applies each
178 | of the rewrites in [rules] until no further changes occur ({i
179 | i.e. equality saturation}), or until it runs out of [fuel].
180 |
181 | It returns a boolean indicating whether it reached equality
182 | saturation or had to terminate early. *)
183 | end
184 |
185 | end
186 |
187 | module Generic : sig
188 |
189 | (** This module implements a generic EGraph-based equality
190 | saturation engine that operates over arbitrary user-defined
191 | languages and provides support for extensible custom user-defined
192 | EClass analyses.
193 |
194 | The main interface to EGraph is provided by the functor {!Make}
195 | which constructs an EGraph given a {!LANGUAGE}, an {!ANALYSIS}
196 | and its {!ANALYSIS_OPS}.
197 |
198 | You may want to check out the {{:../../index.html} quick start
199 | guide}. *)
200 |
201 |
202 | type ('node, 'analysis, 'data, 'permission) egraph
203 | (** A generic representation of an EGraph, parameterised over the
204 | language term types ['node], analysis state ['analysis], analysis
205 | data ['data] and read/write permissions ['permission]. *)
206 |
207 | module StringMap : Map.S with type key = string
208 |
209 | (** The module {!Query} encodes generic patterns (for both matching
210 | and transformation) over expressions and is part of
211 | {!Ego.Generic}'s API for expressing rewrites. *)
212 | module Query : sig
213 |
214 | type 'sym t
215 | (** Represents a query over expressions in a language with
216 | operators of type ['sym]. *)
217 |
218 | val of_sexp : (string -> 'a) -> Sexplib0.Sexp.t -> 'a t
219 | (** [of_sexp f s] constructs a query from a sexpression [s] using
220 | [f] to convert operator tags. *)
221 |
222 | val to_sexp : ('a -> string) -> 'a t -> Sexplib0.Sexp.t
223 | (** [to_sexp f q] converts a query [q] to a sexpression using [f]
224 | to convert operators in the query to strings. *)
225 |
226 | val pp : (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a t -> unit
227 | (** [pp f fmt q] pretty prints a query [q] using [f] to print the
228 | operators within the query. *)
229 |
230 | val show : (Format.formatter -> 'a -> unit) -> 'a t -> string
231 | (** [show f q] prints a query [q] to string using [f] to print
232 | operators within the query. *)
233 |
234 | end
235 |
236 | (** The module {!Scheduler} provides implementations of some generic
237 | schedulers for Ego's equality saturation engine.
238 |
239 | See {!Make.BuildRunner} on how to compose a schedule with an
240 | EGraph definition. *)
241 | module Scheduler : sig
242 |
243 | (** The module {!Backoff} implements an exponential backoff
244 | scheduler. The scheduler works by tracking a maximum match
245 | limit, and (BEB) banning rules which exceed their limit. *)
246 | module Backoff : sig
247 |
248 | type t
249 | (** Represents the persistent state of the scheduler - it really
250 | just tracks the match limit and ban_length parameters chosen
251 | for this particular instantiation. *)
252 |
253 | type data
254 | (** Represents the metadata about rules tracked by the
255 | scheduler. *)
256 |
257 | val with_params : match_limit:int -> ban_length:int -> t
258 | (** [with_params ~match_limit ~ban_length] creates a new backoff
259 | scheduler with the threshold for banning rules set to
260 | [match_limit] and the length for which rules are banned set
261 | to [ban_length]. *)
262 |
263 | val default : unit -> t
264 | (** [default ()] returns a default instance of the backoff
265 | scheduler with the threshold for banning rules set to 1_000
266 | and the initial ban_length set to 5. *)
267 |
268 | (**/**)
269 |
270 | val create_rule_metadata : t -> 'a -> data
271 |
272 | val should_stop : t -> int -> data Iter.t -> bool
273 |
274 | val guard_rule_usage :
275 | ('node, 'analysis, 'data, 'permission) egraph ->
276 | t -> data -> int ->
277 | (unit -> (Id.t * Id.t StringMap.t) Iter.t) ->
278 | (Id.t * Id.t StringMap.t) Iter.t
279 |
280 | end
281 |
282 |
283 | (** The module {!Simple} implements a scheduler that runs every
284 | rule each time - i.e. applies no scheduling at all. This works
285 | fine for rewrite systems with a finite number of EClasses but
286 | can become a problem if the number of EClasses is too large or
287 | unbounded. *)
288 | module Simple : sig
289 | type t
290 | type data
291 | val init : unit -> data
292 | val create_rule_metadata : t -> 'b -> data
293 | val should_stop : t -> int -> data -> bool
294 | val guard_rule_usage :
295 | ('node, 'analysis, 'data, 'permission) egraph ->
296 | t ->
297 | data ->
298 | int ->
299 | (data -> (Id.t * Id.t StringMap.t) Iter.t) ->
300 | (Id.t * Id.t StringMap.t) Iter.t
301 | end
302 | end
303 |
304 | (** {1:permissions Read/Write permissions}
305 |
306 | For convenience, the operations over the EGraph are split into
307 | those which {b read and write} to the graph [rw t] and those that
308 | are {b read-only} [ro t]. When defining the analysis operations,
309 | certain operations assume that the graph is not modified, so
310 | these annotations will help users to avoid violating the internal
311 | invariants of the data structure.
312 | *)
313 |
314 | type rw
315 | (** Encodes a read/write permission for a graph. *)
316 |
317 | type ro
318 | (** Encodes a read-only permission for a graph. *)
319 |
320 | (** {1:interfaces Interfaces} *)
321 |
322 | (** The {!LANGUAGE} module type represents the definition of an
323 | arbitrary language for use with an EGraph. *)
324 | module type LANGUAGE = sig
325 |
326 | type 'a shape
327 | (** Encodes the "shape" of an expression in the language over
328 | sub-expressions of type ['a]. *)
329 |
330 | type op
331 | (** Represents the tags that discriminate the expression
332 | constructors of the language. *)
333 |
334 | (** Represents concrete terms of the language by "tying-the-knot". *)
335 | type t = Mk of t shape [@@unboxed]
336 |
337 | val equal_op : op -> op -> bool
338 | (** [equal_op op1 op2] returns true if the operators [op1], [op2] are equal. *)
339 |
340 | val pp_shape : (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a shape -> unit
341 | (** [pp_shape f fmt s] pretty prints the expression [s], using [f] to print its sub-expressions. *)
342 |
343 | val compare_shape : ('a -> 'a -> int) -> 'a shape -> 'a shape -> int
344 | (** [compare_shape cmp a b] compares expressions [a] and [b] using [cmp]
345 | to compare subexpressions. *)
346 |
347 | val op : 'a shape -> op
348 | (** [op expr] retrieves the tag that discriminates the shape of
349 | the expression [expr]. *)
350 |
351 | val children : 'a shape -> 'a list
352 | (** [children exp] returns the subexpressions of expression [exp]. *)
353 |
354 | val map_children : 'a shape -> ('a -> 'b) -> 'b shape
355 | (** [map_children exp f] maps the function [f] over the
356 | sub-expressions of the expression [exp]. *)
357 |
358 | val make : op -> 'a list -> 'a shape
359 | (** [make op ls] constructs an expression from the tag [op] and
360 | children [ls].
361 |
362 | {b Note}: If called with the wrong number of children for the
363 | operator [op], the function may raise an exception. *)
364 |
365 | end
366 |
367 | (** The module type {!ANALYSIS} encodes the data-types for an
368 | abstract EClass analysis over EGraphs. *)
369 | module type ANALYSIS = sig
370 |
371 | type t
372 | (** Represents any persistent state that an analysis may need to
373 | track separately from each EClass.
374 |
375 | {b Note}: Terms of this type must be mutated imperatively as
376 | the EGraph API doesn't provide any functions to functionally
377 | update the persisted state. *)
378 |
379 | type data
380 | (** Represents the additional analysis information that will be
381 | attached to each EClass. *)
382 |
383 | val pp_data : Format.formatter -> data -> unit
384 | (** [pp_data fmt data] pretty prints [data] using the formatter [fmt]. *)
385 |
386 | val show_data : data -> string
387 | (** [show_data data] converts [data] into a string. *)
388 |
389 | val equal_data : data -> data -> bool
390 | (** [equal_data d1 d2] returns true iff [d1], [d2] are equal. *)
391 |
392 | val default: data
393 | (** Represents a default abstract value for new nodes. *)
394 |
395 | end
396 |
397 | (** The module type {!ANALYSIS_OPS} defines the main operations for
398 | an EClass analysis over an EGraph. *)
399 | module type ANALYSIS_OPS = sig
400 |
401 | type 'a t
402 | (** Represents the EGraph over which the analysis operates. *)
403 |
404 | type analysis
405 | (** Represents the persistent state of the analysis. *)
406 |
407 | type node
408 | (** Represents expressions of the language over which the analysis
409 | operates. *)
410 |
411 | type data
412 | (** Represents the additional analysis information that will be
413 | attached to each EClass. *)
414 |
415 | val make : ro t -> node -> data
416 | (** [make graph node] returns the analysis data for [node].
417 |
418 | This function is called whenever a new node is added and
419 | should generate a new abstract value for the node, usually
420 | using the abstract values of its children.
421 |
422 | {b Note}: In terms of abstract interpretation, this function
423 | can be thought of as the "abstraction" function, mapping concrete
424 | terms to their corresponding values in the abstract domain. *)
425 |
426 | val merge : analysis -> data -> data -> data * (bool * bool)
427 | (** [merge st d1 d2] returns the analysis data that represents the
428 | combination of [d1] and [d2] and a tuple indicating whether the
429 | result differs from [d1] and/or [d2].
430 |
431 | This function is called whenever two equivalence classes are
432 | merged and should produce a new abstract value that represents
433 | the merge of their corresponding abstract values.
434 |
435 | {b Note}: In terms of abstract interpretation, this function
436 | can be thought of as the least upper bound (lub), exposing the
437 | semi-lattice structure of the abstract domain. *)
438 |
439 | val modify : rw t -> Id.t -> unit
440 | (** [modify graph class] is used to introduce new children of an
441 | equivalence class whenever new information about its elements
442 | is found by the analysis.
443 |
444 | This function is called whenever the children or abstract
445 | values of an eclass are modified and may use the abstract value
446 | of the eclass to modify the egraph.
447 |
448 | {b Note}: Unlike {!make}, which only receives a read-only
449 | graph, this function receives a read/write graph and may
450 | therefore add nodes to, or merge classes of, the EGraph. *)
451 |
452 | end
453 |
454 | (** The module type {!COST} represents the definition of some
455 | arbitrary cost system for ranking expressions over some language.
456 | *)
457 | module type COST = sig
458 |
459 | type t
460 | (** Represents the type of a cost of a node. *)
461 |
462 | type node
463 | (** Represents terms of the language. *)
464 |
465 | val compare : t -> t -> int
466 | (** [compare c1 c2] compares the costs [c1] and [c2]. *)
467 |
468 | val cost : (Id.t -> t) -> node -> t
469 | (** [cost f node] should assign costs to the node [node]. It can
470 | use the provided function [f] to determine the cost of a
471 | child. *)
472 |
473 | end
474 |
475 | (** The module type {!SCHEDULER} represents the definition of some
476 | scheduling system for ranking rule applications during equality
477 | saturation.
478 |
479 | See {!Scheduler} for some generic schedulers.
480 | *)
481 | module type SCHEDULER = sig
482 |
483 | type 'p egraph
484 | (** Represents an EGraph with read/write permissions
485 | ['p]. *)
486 |
487 | type t
488 | (** Represents any persistent state of the scheduler that must be
489 | maintained separately from its rules. *)
490 |
491 | type data
492 | (** Represents metadata about a rule that the scheduler keeps
493 | track of in order to schedule rules. *)
494 |
495 | type rule
496 | (** Represents the type of rules over which this scheduler operates. *)
497 |
498 | val default : unit -> t
499 | (** Create a default instance of the scheduler. *)
500 |
501 | val should_stop: t -> int -> data Iter.t -> bool
502 | (** [should_stop scheduler iteration data] is called whenever the
503 | EGraph reaches saturation (with the rules that have been
504 | scheduled), and should return whether further iterations should
505 | be run (i.e. we will be trying a different schedule) or whether
506 | we have actually truly reached saturation. *)
507 |
508 | val create_rule_metadata: t -> rule -> data
509 | (** [create_rule_metadata scheduler rule] returns the initial
510 | metadata for a rule [rule]. *)
511 |
512 | val guard_rule_usage:
513 | rw egraph -> t -> data -> int ->
514 | (unit -> (Id.t * Id.t StringMap.t) Iter.t) -> (Id.t * Id.t StringMap.t) Iter.t
515 | (** [guard_rule_usage graph scheduler data iteration
516 | gen_matches] is called before the execution of a particular
517 | rule (represented by the callback [gen_matches]), and should
518 | return a filtered set of matches according to the scheduling
519 | of the rule. *)
520 |
521 | end
522 |
523 | (** The module type {!GRAPH_API} represents the interface through which
524 | EClass analyses can interact with an EGraph. *)
525 | module type GRAPH_API = sig
526 |
527 | type 'p t
528 | (** Represents an EGraph with read/write permissions ['p]. *)
529 |
530 | type data
531 | (** Represents the additional analysis information that will be
532 | attached to each EClass. *)
533 |
534 | type analysis
535 | (** Represents the persistent state of the analysis. *)
536 |
537 | type 'a shape
538 | (** Represents the shape of expressions in the language. *)
539 |
540 | type node
541 | (** Represents concrete terms of expressions in the language. *)
542 |
543 | val freeze : rw t -> ro t
544 | (** [freeze graph] returns a read-only reference to the EGraph.
545 |
546 | {b Note}: it is safe to modify [graph] after passing it to
547 | [freeze]; this function is mainly intended to allow using the
548 | read-only APIs of the EGraph when you have a RW instance of
549 | the EGraph. *)
550 |
551 | val class_equal : ro t -> Id.t -> Id.t -> bool
552 | (** [class_equal graph cls1 cls2] returns true if and only if
553 | [cls1] and [cls2] denote the same EClass in the EGraph [graph]. *)
554 |
555 | val iter_children : ro t -> Id.t -> Id.t shape Iter.t
556 | (** [iter_children graph cls] returns an iterator over the
557 | elements (nodes) of the EClass [cls]. *)
558 |
559 | val set_data : rw t -> Id.t -> data -> unit
560 | (** [set_data graph cls data] sets the analysis data for EClass
561 | [cls] in EGraph [graph] to be [data]. *)
562 |
563 | val get_data : ro t -> Id.t -> data
564 | (** [get_data graph cls] returns the analysis data for EClass
565 | [cls] in EGraph [graph]. *)
566 |
567 | val get_analysis: rw t -> analysis
568 | (** [get_analysis graph] returns the persistent analysis state
569 | for an EGraph. *)
570 |
571 | val add_node : rw t -> node -> Id.t
572 | (** [add_node graph term] adds the term [term] into the EGraph
573 | [graph] and returns the corresponding equivalence class. *)
574 |
575 | val merge : rw t -> Id.t -> Id.t -> unit
576 | (** [merge graph cls1 cls2] merges the two equivalence classes
577 | [cls1] and [cls2]. *)
578 |
579 | end
580 |
581 | (** This module type {!RULE} defines the rewrite interface for an
582 | EGraph, allowing users to express relatively complex
583 | transformations of expressions over some language. *)
584 | module type RULE = sig
585 |
586 | type t
587 | (** Represents rewrite rules over the language of the EGraph. *)
588 |
589 | type query
590 | (** Represents a pattern over the language of the EGraph - it can
591 | either be used to {i match} and {i bind} a particular
592 | subpattern in an expression, or can be used to express the
593 | output schema for a rewrite. *)
594 |
595 | type 'p egraph
596 | (** Represents an EGraph with read/write permissions
597 | ['p]. *)
598 |
599 | val make_constant : from:query -> into:query -> t
600 | (** [make_constant ~from ~into] creates a rewrite rule from a
601 | pattern [from] into a schema [into] that applies a purely
602 | syntactic transformation. *)
603 |
604 | val make_conditional :
605 | from:query ->
606 | into:query ->
607 | cond:(rw egraph -> Id.t -> Id.t StringMap.t -> bool) ->
608 | t
609 | (** [make_conditional ~from ~into ~cond] creates a syntactic
610 | rewrite rule from [from] to [into] that is conditionally
611 | applied based on some property [cond] of the EGraph, the root
612 | eclass of the sub-expression being transformed and the eclasses
613 | of all bound variables. *)
614 |
615 | val make_dynamic :
616 | from:query ->
617 | generator:(rw egraph -> Id.t -> Id.t StringMap.t -> query option) -> t
618 | (** [make_dynamic ~from ~generator] creates a dynamic rewrite
619 | rule from a pattern [from] into a schema that is
620 | conditionally generated based on properties of the EGraph,
621 | the root eclass of the sub-expression being transformed and
622 | the eclasses of all bound variables. *)
623 |
624 | end
625 |
626 | (** {1:constructors EGraph Constructors} *)
627 |
628 | (** This functor {!MakePrinter} allows users to construct EGraph
629 | printing utilities for a given {!LANGUAGE} and {!ANALYSIS}. *)
630 | module MakePrinter : functor (L : LANGUAGE) (A : ANALYSIS) -> sig
631 |
632 | (* val pp : Format.formatter -> (Id.t L.shape, A.t, A.data, 'b) egraph -> unit
633 | * (\** [pp fmt graph] pretty prints an internal representation of
634 | * the graph.
635 | *
636 | * {b Note}: This is primarily intended for debugging, and the
637 | * output format is not guaranteed to remain consistent over
638 | * versions. *\) *)
639 |
640 | val to_dot : (Id.t L.shape, A.t, A.data, 'b) egraph -> Odot.graph
641 | (** [to_dot graph] converts an EGraph into a Graphviz
642 | representation for debugging. *)
643 |
644 | end
645 |
646 | (** This functor {!MakeExtractor} allows users to construct an
647 | EGraph extraction procedure for a given {!LANGUAGE} and {!COST}
648 | system. *)
649 | module MakeExtractor : functor
650 | (L : LANGUAGE)
651 | (E : COST with type node := Id.t L.shape) -> sig
652 |
653 | val extract : (Id.t L.shape, 'a, 'b, rw) egraph -> Id.t -> L.t
654 | (** [extract graph] computes an extraction function [Id.t ->
655 | L.t] to extract concrete terms of the language {!L}
656 | from their respective EClasses (specified by [Id.t]) from the
657 | EGraph according to the cost system {!E}. *)
658 |
659 | end
660 |
661 |
662 | (** This functor {!Make} serves as the main interface to Ego's
663 | generic EGraphs, and constructs an EGraph given a {!LANGUAGE}, an
664 | {!ANALYSIS} and its {!ANALYSIS_OPS}. *)
665 | module Make :
666 | functor
667 | (L : LANGUAGE)
668 | (A : ANALYSIS)
669 | (MakeAnalysisOps : functor
670 | (S : GRAPH_API with type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph
671 | and type analysis := A.t
672 | and type data := A.data
673 | and type 'a shape := 'a L.shape
674 | and type node := L.t) ->
675 | ANALYSIS_OPS with type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph
676 | and type analysis := A.t
677 | and type data := A.data
678 | and type node := Id.t L.shape) ->
679 | sig
680 |
681 | type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph
682 | (** This type represents an EGraph parameterised over a
683 | particular language {!L} and analysis {!A}. *)
684 |
685 | (** This module {!Rule} defines the rewrite interface for the
686 | EGraph, allowing users to express relatively complex
687 | transformations of expressions of the Language {!L}. *)
688 | module Rule:
689 | RULE with type query := L.op Query.t
690 | and type 'a egraph := (Id.t L.shape, A.t, A.data, 'a) egraph
691 |
692 |
693 | val freeze : rw t -> ro t
694 | (** [freeze graph] returns a read-only reference to the EGraph.
695 |
696 | {b Note}: it is safe to modify [graph] after passing it to
697 | [freeze]; this function is mainly intended to allow using the
698 | read-only APIs of the EGraph when you have a RW instance of
699 | the EGraph. *)
700 |
701 | val init : A.t -> 'p t
702 | (** [init analysis] creates a new EGraph with an initial
703 | persistent analysis state of [analysis]. *)
704 |
705 | val class_equal: ro t -> Id.t -> Id.t -> bool
706 | (** [class_equal graph cls1 cls2] returns true if and only if
707 | [cls1] and [cls2] denote the same EClass in the EGraph [graph]. *)
708 |
709 | val set_data : rw t -> Id.t -> A.data -> unit
710 | (** [set_data graph cls data] sets the analysis data for EClass
711 | [cls] in EGraph [graph] to be [data]. *)
712 |
713 | val get_data : _ t -> Id.t -> A.data
714 | (** [get_data graph cls] returns the analysis data for EClass
715 | [cls] in EGraph [graph]. *)
716 |
717 | val get_analysis: rw t -> A.t
718 | (** [get_analysis graph] returns the persistent analysis state
719 | for an EGraph. *)
720 |
721 | val iter_children : ro t -> Id.t -> Id.t L.shape Iter.t
722 | (** [iter_children graph cls] returns an iterator over the
723 | elements (nodes) of the EClass [cls]. *)
724 |
725 | (* val pp : Format.formatter -> (Id.t L.shape, 'a, A.data, _) egraph -> unit
726 | * (\** [pp fmt graph] pretty prints an internal representation of
727 | * the graph.
728 | *
729 | * {b Note}: This is primarily intended for debugging, and the
730 | * output format is not guaranteed to remain consistent over
731 | * versions. *\) *)
732 |
733 | val to_dot : (Id.t L.shape, A.t, A.data, _) egraph -> Odot.graph
734 | (** [to_dot graph] converts an EGraph into a Graphviz
735 | representation for debugging. *)
736 |
737 | val add_node : rw t -> L.t -> Id.t
738 | (** [add_node graph term] adds the term [term] into the EGraph
739 | [graph] and returns the corresponding equivalence class. *)
740 |
741 | val merge : rw t -> Id.t -> Id.t -> unit
742 | (** [merge graph cls1 cls2] merges the two equivalence classes
743 | [cls1] and [cls2]. *)
744 |
745 | val rebuild : rw t -> unit
746 | (** [rebuild graph] restores the internal invariants of the
747 | graph.
748 |
749 | {b Note}: If you call {!merge} manually (i.e. outside of
750 | analysis functions), you must call {!rebuild} before running
751 | any queries or extraction. *)
752 |
753 | val find_matches : ro t -> L.op Query.t -> (Id.t * Id.t StringMap.t) Iter.t
754 | (** [find_matches graph query] returns an iterator over each
755 | match of the query [query] in the EGraph. *)
756 |
757 | val apply_rules : (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> unit
758 | (** [apply_rules graph rules] runs each of the rewrites in [rules]
759 | exactly once over the egraph [graph] and then returns. *)
760 |
761 | val run_until_saturation:
762 | ?scheduler:Scheduler.Backoff.t ->
763 | ?node_limit:[`Bounded of int | `Unbounded] ->
764 | ?fuel:[`Bounded of int | `Unbounded] ->
765 | ?until:((Id.t L.shape, A.t, A.data, rw) egraph -> bool) ->
766 | (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> bool
767 | (** [run_until_saturation ?scheduler ?node_limit ?fuel ?until
768 | graph rules] repeatedly applies each of the rewrites in [rules]
769 | according to the scheduler [scheduler] until no further
770 | changes occur ({i i.e. equality saturation}), or until it
771 | runs out of [fuel] (defaults to 30) or reaches a [node_limit]
772 | if supplied (defaults to 10_000) or some predicate [until] is
773 | satisfied.
774 |
775 | It returns a boolean indicating whether it reached equality
776 | saturation or had to terminate early. *)
777 |
778 | (** The module {!BuildRunner} allows users to supply their own
779 | custom domain-specific scheduling strategies for equality
780 | saturation by supplying a corresponding Scheduling module
781 | satisfying {!SCHEDULER}. *)
782 | module BuildRunner (S : SCHEDULER
783 | with type 'a egraph := (Id.t L.shape, A.t, A.data, rw) egraph
784 | and type rule := Rule.t) :
785 | sig
786 |
787 | val run_until_saturation :
788 | ?scheduler:S.t ->
789 | ?node_limit:[`Bounded of int | `Unbounded] ->
790 | ?fuel:[`Bounded of int | `Unbounded] ->
791 | ?until:((Id.t L.shape, A.t, A.data, rw) egraph -> bool) ->
792 | (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> bool
793 | (** [run_until_saturation ?scheduler ?node_limit ?fuel
794 | ?until graph rules] repeatedly applies each of the rewrites
795 | in [rules] according to the scheduler [scheduler] until
796 | no further changes occur ({i i.e. equality saturation}),
797 | or until it runs out of [fuel] (defaults to 30) or
798 | reaches some [node_limit] (defaults to 10_000) or some
799 | predicate [until] is satisfied.
800 |
801 | It returns a boolean indicating whether it reached
802 | equality saturation or had to terminate early. *)
803 |
804 | end
805 |
806 | end
807 |
808 | end
809 |
--------------------------------------------------------------------------------
/lib/equivalence.ml:
--------------------------------------------------------------------------------
1 | open Containers
2 | (* module IntMap = Map.Make (Int) *)
3 | module IntMap = Hashtbl.Make (Int)
4 | module IntSet = CCHashSet.Make (Int)
5 |
6 | module Make = functor () -> struct
7 |
8 | (* Each cell of the store is either the root of a class or a link to
9 | another cell. The payload of [Root] is a fresh tag that nothing reads. *)
10 | type elem =
11 | | Root of int
12 | | Link of elem_ref
13 | and elem_ref = int
14 | and store = {
15 | mutable limit: int; (* next unused cell index *)
16 | content: elem IntMap.t
17 | }
18 |
19 | type t = elem_ref
20 |
21 | (* Ids are plain cell indices. *)
22 | let repr id = id
23 |
24 | (* Cell accessors. *)
25 | let read store key = IntMap.find store.content key
26 | let write store key cell = IntMap.replace store.content key cell
27 |
28 | let create_store () = { limit = 0; content = IntMap.create 100 }
29 |
30 | let hash = Int.hash
31 |
32 | (* Stores [cell] at the next free index of [store] and returns that index. *)
33 | let alloc (store: store) cell =
34 | let index = store.limit in
35 | store.limit <- index + 1;
36 | write store index cell;
37 | index
38 |
39 | (* Produces a [Root] whose tag is unique within this functor instance. *)
40 | let fresh_root =
41 | let counter = ref 0 in
42 | fun () ->
43 | counter := !counter + 1;
44 | Root !counter
45 |
46 | let make store () = alloc store (fresh_root ())
47 |
48 | (* Follows links to the root of [node], re-pointing [node] directly at the
49 | root when it did not already point there (path compression). *)
50 | let rec find store node =
51 | match read store node with
52 | | Root _ -> node
53 | | Link parent ->
54 | let root = find store parent in
55 | if not (Int.equal root parent) then write store node (Link root);
56 | root
57 |
58 | let equal store a b =
59 | let root_a = find store a in
60 | let root_b = find store b in
61 | Int.equal root_a root_b
62 |
63 | (* Merges the classes of [a] and [b]; the root of [a] becomes the root of
64 | the merged class and is returned. *)
65 | let union store a b =
66 | let root_a = find store a in
67 | let root_b = find store b in
68 | if Int.equal root_a root_b then root_a
69 | else match[@warning "-8"] read store root_a, read store root_b with
70 | | Root _, Root _ -> write store root_b (Link root_a); root_a
71 |
72 | module Map = IntMap
73 |
74 | module Set = IntSet
75 |
76 | end
77 |
--------------------------------------------------------------------------------
/lib/equivalence.mli:
--------------------------------------------------------------------------------
1 | module Make : functor () -> sig (* union-find over integer ids *)
2 | type store (* mutable table mapping each id to its root/link cell *)
3 | type t = private int (* an id allocated by [make] *)
4 |
5 | val repr : t -> int (* the raw integer behind an id *)
6 | val create_store : unit -> store (* a fresh, empty store *)
7 | val make : store -> unit -> t (* allocates a new singleton class in the store *)
8 | val find : store -> t -> t (* canonical representative; compresses paths as it goes *)
9 |
10 | val equal : store -> t -> t -> bool (* true iff both ids share a representative *)
11 | val union : store -> t -> t -> t (* merges two classes; returns the new representative, the root of the first argument *)
12 | val hash: t -> int (* hashes the raw id, not its class *)
13 |
14 | module Map : Hashtbl.S with type key = t (* hashtables keyed by raw ids *)
15 |
16 | module Set : CCHashSet.S with type elt = t (* hash sets of raw ids *)
17 | end
18 |
--------------------------------------------------------------------------------
/lib/generic.ml:
--------------------------------------------------------------------------------
1 | open [@warning "-33"] Containers
2 | open Language
3 | open Types
4 | module Id = Id
5 |
6 | (* [dedup cmp vec] filters [vec] in place, dropping each element that [cmp]
7 | reports equal (result 0) to the last element kept so far; assuming in-order
8 | traversal, this collapses runs of adjacent duplicates (e.g. in a sorted vector). *)
9 | let dedup cmp vec =
10 | let last_kept = ref None in
11 | let keep elt =
12 | let is_new =
13 | match !last_kept with
14 | | Some prev -> not (Int.equal (cmp prev elt) 0)
15 | | None -> true in
16 | if is_new then last_kept := Some elt;
17 | is_new in
18 | Vector.filter_in_place keep vec
19 |
20 |
21 | (* let lappend_pair a (b,c) = (a,b,c) *)
22 | type 'a query = 'a Query.t (* local alias for queries over operators ['a] *)
23 |
24 | type ('node, 'data) eclass = {
25 | mutable id: Id.t; (* id of this class *)
26 | nodes: 'node Vector.vector; (* NOTE(review): presumably the nodes of this class — confirm *)
27 | mutable data: 'data; (* analysis data; overwritten by [set_data] *)
28 | parents: ('node * Id.t) Vector.vector; (* NOTE(review): presumably (parent node, parent class) pairs — confirm *)
29 | }
30 |
31 | type ('node, 'analysis, 'data, 'permission) egraph = { (* ['permission] is a phantom parameter: no field uses it *)
32 | mutable version: int; (* 0 on [init]; NOTE(review): presumably bumped on mutation — confirm *)
33 | analysis: 'analysis; (* persistent analysis state, see [get_analysis] *)
34 |
35 | uf: Id.store; (* tracks equivalence classes of
36 | class ids *)
37 | class_data:
38 | ('node, 'data) eclass Id.Map.t; (* maps classes to the canonical nodes
39 | they contain, and any classes that are
40 | children of these nodes *)
41 | hash_cons: ('node, Id.t) Hashtbl.t; (* maps canonical nodes to their
42 | equivalence classes *)
43 | pending: ('node * Id.t) Vector.vector; (* NOTE(review): presumably (node, class) pairs awaiting repair on rebuild — confirm *)
44 |
45 | pending_analysis: ('node * Id.t) Vector.vector; (* NOTE(review): presumably pairs whose analysis data must be recomputed — confirm *)
46 | }
47 |
48 |
49 | module MakeInt (L: LANGUAGE) (* (A: ANALYSIS) *) = struct
50 | (* Operations on the raw [egraph] record that only need the language [L] (e.g. opened by [MakePrinter]). *)
51 | (* [self.@[f]] is sugar for [f self], as in [self.@[find] id]. *)
52 | let (.@[]) self fn = fn self [@@inline always]
53 | (* *** Initialization *)
54 | (* [init analysis] creates an empty EGraph holding the analysis state [analysis]. *)
55 | let init analysis = {
56 | version=0; analysis;
57 | uf=Id.create_store ();
58 | class_data=Id.Map.create 10;
59 | hash_cons=Hashtbl.create 10;
60 | pending=Vector.create ();
61 | pending_analysis=Vector.create ();
62 | }
63 |
64 | (* *** Eclasses *)
65 | let get_analysis self = self.analysis
66 | (* [get_class_data self id] returns the eclass record stored under [id] as given (not canonicalised); fails if unbound. *)
67 | let get_class_data self id =
68 | match Id.Map.find_opt self.class_data id with
69 | | Some data -> data
70 | | None -> failwith @@ Printf.sprintf "attempted to get the data of an unbound class %s " (EClassId.show id)
71 | (* [remove_class_data self id] unbinds [id], returning its previous record if it had one. *)
72 | let remove_class_data self id =
73 | match Id.Map.find_opt self.class_data id with
74 | | Some classes -> Id.Map.remove self.class_data id; Some classes
75 | | None -> None
76 | (* [set_data self id data] overwrites the analysis data of the record under [id]; unlike [get_data], [id] is NOT canonicalised first. Fails if unbound. *)
77 | let set_data self id data =
78 | match Id.Map.find_opt self.class_data id with
79 | | None -> failwith @@ Printf.sprintf "attempted to set the data of an unbound class %s " (EClassId.show id)
80 | | Some class_data -> class_data.data <- data
81 | (* [get_data self id] returns the analysis data of the canonical class of [id]; fails if unbound. *)
82 | let get_data self id =
83 | match Id.Map.find_opt self.class_data (Id.find self.uf id) with
84 | | None -> failwith @@ Printf.sprintf "attempted to get the data of an unbound class %s " (EClassId.show id)
85 | | Some class_data -> class_data.data
86 | (* [canonicalise self node] replaces every child id of [node] with its canonical representative. *)
87 | let canonicalise self node = L.map_children node (Id.find self.uf)
88 | (* [find self id] returns the canonical representative of [id]. *)
89 | let find self vl = Id.find self.uf vl
90 |
91 | (* *** Exports *)
92 | (* **** Export eclasses: map each canonical class id to the hash-consed nodes belonging to it *)
93 | let eclasses self =
94 | let r = Id.Map.create 10 in
95 | Hashtbl.iter (fun node eid ->
96 | let eid = Id.find self.uf eid in (* group under the canonical id *)
97 | match Id.Map.find_opt r eid with
98 | | None -> let ls = Vector.of_list [node] in Id.Map.add r eid ls
99 | | Some ls -> Vector.push ls node
100 | ) self.hash_cons;
101 | r
102 | (* [class_equal self cls1 cls2] is true iff both ids have the same union-find root. *)
103 | let class_equal self cls1 cls2 =
104 | Id.equal self.uf cls1 cls2
105 |
106 | end
107 |
108 | module MakePrinter (L: LANGUAGE) (A: ANALYSIS) = struct
109 |
110 |   open (MakeInt(L))
111 |
112 |   (* **** Export as dot *)
113 |   (* [to_dot self]: a strict Graphviz digraph with one [cluster_<id>]
114 |      subgraph per canonical eclass (labelled with the class id and its
115 |      analysis data), a graph node for every hash-consed enode (named by its
116 |      rendering), and an edge from each enode to a representative enode of
117 |      each of its child classes. *)
118 |   let to_dot self =
119 |     let eclasses = eclasses self in
120 |
121 |     (* Print a class id as the set of enodes in its class, in braces. *)
122 |     let pp_node_by_id fmt id =
123 |       let id = self.@[find] id in
124 |       let vls = Id.Map.find_opt eclasses id |> Option.get_lazy Vector.create in
125 |       let open Format in
126 |       pp_print_string fmt "{";
127 |       pp_open_hovbox fmt 1;
128 |       Vector.pp
129 |         ~pp_sep:(fun fmt () -> pp_print_string fmt ","; pp_print_space fmt ())
130 |         (L.pp_shape EClassId.pp) fmt vls;
131 |       pp_close_box fmt ();
132 |       pp_print_string fmt "}" in
133 |     let stmt_list =
134 |       (* Map each canonical class id to one hash-consed enode of that class.
135 |          Keys are canonicalised so that every class with an enode gets an
136 |          entry even when the hash-cons still records a stale id, and so the
137 |          chosen enode is always one drawn inside that class subgraph. *)
138 |       let rev_map =
139 |         Hashtbl.to_seq self.hash_cons
140 |         |> Seq.map (fun (node, id) -> (self.@[find] id, node))
141 |         |> Id.Map.of_seq in
142 |       (* Graph-node name for canonical class [id]: the rendering of its
143 |          representative enode, or the bare id if it has no hash-consed enode. *)
144 |       let to_label id =
145 |         match Id.Map.find_opt rev_map id with
146 |         | None -> Format.to_string EClassId.pp id
147 |         | Some node -> Format.to_string (L.pp_shape pp_node_by_id) node in
148 |       (* Edge target for a child id, canonicalised to agree with [rev_map]. *)
149 |       let to_id id =
150 |         Odot.Double_quoted_id (to_label (self.@[find] id)) in
151 |       let to_node_id node =
152 |         Odot.Double_quoted_id (Format.to_string (L.pp_shape pp_node_by_id) node) in
153 |       let to_subgraph_id id =
154 |         Odot.Simple_id (Printf.sprintf "cluster_%d" (Id.repr id)) in
155 |       let eclass_label eclass =
156 |         let eclass_txt = Format.to_string EClassId.pp eclass in
157 |         let data = get_data self eclass |> A.show_data in
158 |         eclass_txt ^ " = " ^ data in
159 |       let sub_graphs =
160 |         (fun f -> Fun.flip Id.Map.iter eclasses (Fun.curry f))
161 |         |> Iter.map (fun (eclass, (enodes: (Id.t L.shape, _) Vector.t)) ->
162 |             let nodes =
163 |               Vector.to_iter enodes
164 |               |> Iter.map (fun (node: Id.t L.shape) ->
165 |                   let node_id = to_node_id node in
166 |                   let attrs = Odot.[Simple_id "label",
167 |                                     Some (Double_quoted_id
168 |                                             (Format.to_string (L.pp_shape pp_node_by_id) node))] in
169 |                   Odot.Stmt_node ((node_id, None), attrs))
170 |               |> Iter.to_list in
171 |             Odot.(Stmt_subgraph {
172 |                 sub_id= Some (to_subgraph_id eclass);
173 |                 sub_stmt_list=
174 |                   Stmt_attr (
175 |                     Attr_graph [
176 |                       (Simple_id "label", Some (Double_quoted_id (eclass_label eclass)))
177 |                     ]) :: nodes;
178 |               })
179 |           )
180 |         |> Iter.to_list in
181 |       let edges =
182 |         (fun f -> Fun.flip Id.Map.iter eclasses (Fun.curry f))
183 |         |> Iter.flat_map (fun (_eclass, enodes) ->
184 |             Vector.to_iter enodes
185 |             |> Iter.flat_map (fun node ->
186 |                 let label = to_node_id node in
187 |                 Iter.of_list (L.children node)
188 |                 |> Iter.map (fun child ->
189 |                     let child_label = to_id child in
190 |                     Odot.(Stmt_edge (
191 |                         Edge_node_id (label, None),
192 |                         [Edge_node_id (child_label, None)],
193 |                         []
194 |                       ))
195 |                   )
196 |               )
197 |           )
198 |         |> Iter.to_list in
199 |       (List.append sub_graphs edges) in
200 |     Odot.{
201 |       strict=true;
202 |       kind=Digraph;
203 |       id=None;
204 |       stmt_list;
205 |     }
206 |
207 |   (* **** Print as dot *)
208 |   (* Print [st] in Graphviz dot syntax. *)
209 |   let pp_dot fmt st =
210 |     Format.pp_print_string fmt (Odot.string_of_graph (to_dot st))
211 |
212 | end
205 |
206 | module MakeExtractor (L: LANGUAGE) (E: COST with type node := Id.t L.shape) = struct
207 |
208 |   open (MakeInt(L))
209 |
210 |   (* [extract eg]: compute, for every eclass, its cheapest enode under
211 |      [E.cost] (relaxing costs to a fixpoint, once), and return a function
212 |      mapping a class id to the cheapest term it represents.
213 |      The returned function raises [Not_found] when a class reached during
214 |      extraction never obtained a cost. *)
215 |   let extract eg =
216 |     let eclasses = eg.@[eclasses] in
217 |     let cost_map = Id.Map.create 10 in
218 |     (* Cost of [node], or [None] while some child class has no cost yet. *)
219 |     let node_total_cost node =
220 |       let has_cost id = Id.Map.mem cost_map (eg.@[find] id) in
221 |       if List.for_all has_cost (L.children node)
222 |       then let cost_f id = fst @@ Id.Map.find cost_map (eg.@[find] id) in Some (E.cost cost_f node)
223 |       else None in
224 |     (* Order optional costs; any known cost beats an unknown one. *)
225 |     let compare_cost c1 c2 =
226 |       match c1, c2 with
227 |       | None, None -> 0
228 |       | Some _, None -> -1
229 |       | None, Some _ -> 1
230 |       | Some c1, Some c2 -> E.compare c1 c2 in
231 |     (* Cheapest (cost, enode) of a class, or [None] if none is costed yet. *)
232 |     let make_pass enodes =
233 |       let cost, node =
234 |         Vector.to_iter enodes
235 |         |> Iter.map (fun n -> (node_total_cost n, n))
236 |         (* A comparison function only guarantees the sign of its result,
237 |            so test [< 0] rather than [= -1]. *)
238 |         |> Iter.min_exn ~lt:(fun (c1, _) (c2, _) -> compare_cost c1 c2 < 0) in
239 |       Option.map (fun cost -> (cost, node)) cost in
240 |     (* Relax class costs until a full sweep improves no class. *)
241 |     let find_costs () =
242 |       let any_changes = ref true in
243 |       while !any_changes do
244 |         any_changes := false;
245 |         Fun.flip Id.Map.iter eclasses (fun eclass enodes ->
246 |           let pass = make_pass enodes in
247 |           match Id.Map.find_opt cost_map eclass, pass with
248 |           | None, Some nw -> Id.Map.replace cost_map eclass nw; any_changes := true
249 |           | Some ((cold, _)), Some ((cnew, _) as nw)
250 |             when E.compare cnew cold < 0 ->
251 |             Id.Map.replace cost_map eclass nw; any_changes := true
252 |           | _ -> ()
253 |         )
254 |       done in
255 |     (* Rebuild the chosen term top-down from the cost table. *)
256 |     let rec extract eid =
257 |       let eid = eg.@[find] eid in
258 |       let enode = Id.Map.find cost_map eid |> snd in
259 |       let head = L.op enode in
260 |       let children = L.children enode in
261 |       L.Mk (L.make head @@ List.map extract children) in
262 |     find_costs ();
263 |     fun result -> extract result
264 |
265 | end
253 |
254 | (* ** Graph *)
255 | module MakeOps
256 | (L: LANGUAGE)
257 | (A: ANALYSIS)
258 | (AM: sig
259 | val make: (Id.t L.shape, A.t, A.data, ro) egraph -> Id.t L.shape -> A.data
260 | val merge: A.t -> A.data -> A.data -> A.data * (bool * bool)
261 | val modify: (Id.t L.shape, A.t, A.data, rw) egraph -> Id.t -> unit
262 | end) =
263 | struct
264 |
265 | open (MakeInt (L))
266 |
267 | module Rule = struct
268 |
269 |   (* What a rule turns a match into (see [apply_rules]):
270 |      - [Constant q]: instantiate [q] under the match bindings;
271 |      - [Conditional (q, cond)]: the same, but only when [cond] holds;
272 |      - [Dynamic gen]: [gen] picks the right-hand side, or [None] to skip. *)
273 |   type rule_output =
274 |     | Constant of L.op Query.t
275 |     | Conditional of
276 |         L.op Query.t *
277 |         ((Id.t L.shape, A.t, A.data, rw) egraph -> eclass_id -> eclass_id StringMap.t -> bool)
278 |     | Dynamic of
279 |         ((Id.t L.shape, A.t, A.data, rw) egraph -> eclass_id -> eclass_id StringMap.t -> L.op Query.t option)
280 |
281 |   (* A rule: the pattern to match, paired with what matches become. *)
282 |   type t = L.op Query.t * rule_output
283 |
284 |   let make_constant ~from ~into =
285 |     let output = Constant into in
286 |     (from, output)
287 |
288 |   let make_conditional ~from ~into ~cond =
289 |     let output = Conditional (into, cond) in
290 |     (from, output)
291 |
292 |   let make_dynamic ~from ~generator =
293 |     let output = Dynamic generator in
294 |     (from, output)
295 |
296 | end
284 |
285 | (* [new_class self]: allocate a fresh, empty eclass (no enodes, no parents,
286 |    data [A.default]) and return its id. The class is not entered in the
287 |    hash-cons table and the version counter is left unchanged. *)
288 | let new_class self =
289 |   let id = Id.make self.uf () in
290 |   let cls = {id; nodes = Vector.create (); data = A.default; parents = Vector.create ()} in
291 |   Id.Map.add self.class_data id cls;
292 |   id
293 |
294 | (* View a read-write e-graph as read-only (a coercion, no copy). *)
295 | let freeze (graph: (_, _, _, rw) egraph) =
296 |   (graph :> (_, _, _, ro) egraph)
291 |
292 | (* [add_enode self node]: add [node] and return the canonical id of its
293 |    class. [node] is canonicalised first. If the canonical node is already
294 |    hash-consed, the canonical id of its class is returned. Otherwise a fresh
295 |    singleton class is created, recorded as a parent of each child class,
296 |    queued in [pending], registered in the class table and the hash-cons,
297 |    and passed to [AM.modify]. Assumes the hash-cons is up to date for the
298 |    canonical form of [node]. *)
299 | let add_enode self (node: Id.t L.shape) =
300 |   let node = self.@[canonicalise] node in
301 |   match Hashtbl.find_opt self.hash_cons node with
302 |   | Some id -> self.@[find] id
303 |   | None ->
304 |     self.version <- self.version + 1;
305 |     let id = Id.make self.uf () in
306 |     (* [AM.make] sees the graph before the new class is registered. *)
307 |     let cls = {
308 |       id;
309 |       nodes=Vector.of_list [node];
310 |       data = AM.make (freeze self) node;
311 |       parents=Vector.create ()
312 |     } in
313 |     List.iter (fun child ->
314 |       Vector.push ((self.@[get_class_data] child).parents) (node, id)
315 |     ) (L.children node);
316 |     Vector.push self.pending (node, id);
317 |     Id.Map.add self.class_data id cls;
318 |     Hashtbl.add self.hash_cons node id;
319 |     AM.modify self id;
320 |     (* [AM.modify] may have merged the new class away; return its root. *)
321 |     Id.find self.uf id
323 |
324 | (* Insert a term bottom-up (children first, then the node itself) and
325 |    return the canonical id of its class. *)
326 | let rec add_node self (L.Mk op: L.t) : Id.t =
327 |   let enode = L.map_children op (add_node self) in
328 |   add_enode self enode
329 |
330 | (* Instantiate pattern [pat] under bindings [env], adding the resulting
331 |    nodes to the graph. Raises [Not_found] for a variable missing from [env]. *)
332 | let rec subst self pat env =
333 |   match pat with
334 |   | Query.V var -> StringMap.find var env
335 |   | Q (sym, args) ->
336 |     let children = List.map (fun arg -> subst self arg env) args in
337 |     add_enode self (L.make sym children)
333 |
334 | (* [merge self id1 id2]: union the classes of [id1] and [id2]; a no-op if
335 |    they are already the same class. Otherwise the version is bumped, the
336 |    class with more parents stays the root, and the other class record is
337 |    removed and folded into it. Its parents are queued in [pending] for
338 |    [rebuild] to re-canonicalise. Analysis data is combined with [AM.merge],
339 |    and parents are queued in [pending_analysis] as its two flags indicate.
340 |    Nodes and parents are appended, and [AM.modify] runs on the root. *)
341 | let merge self id1 id2 =
342 |   let (+=) va vb = Vector.append va vb in
343 |   let id1 = Id.find self.uf id1 in
344 |   let id2 = Id.find self.uf id2 in
345 |   if not (Id.eq_id id1 id2) then begin
346 |     self.version <- self.version + 1;
347 |     (* Keep the class with more parents as [id1], so fewer parents move. *)
348 |     let id1, id2 =
349 |       if Vector.length (self.@[get_class_data] id1).parents < Vector.length (self.@[get_class_data] id2).parents
350 |       then (id2, id1)
351 |       else (id1, id2) in
352 |
353 |     (* Make [id1] the new root. The union is performed outside [assert]:
354 |        asserted expressions are not evaluated when compiled with -noassert,
355 |        which would silently skip the union. *)
356 |     let root = Id.union self.uf id1 id2 in
357 |     assert (Id.eq_id id1 root);
358 |
359 |     let cls2 = self.@[remove_class_data] id2
360 |                |> Option.get_exn_or "Invariant violation" in
361 |     let cls1 = self.@[get_class_data] id1 in
362 |     assert (Id.eq_id id1 cls1.id);
363 |
364 |     self.pending += cls2.parents;
365 |
366 |     let (did_update_cls1, did_update_cls2) =
367 |       let data, res = (AM.merge self.analysis cls1.data cls2.data) in
368 |       cls1.data <- data;
369 |       res in
370 |
371 |     if did_update_cls1 then self.pending_analysis += cls1.parents;
372 |     if did_update_cls2 then self.pending_analysis += cls2.parents;
373 |
374 |     cls1.nodes += cls2.nodes;
375 |     cls1.parents += cls2.parents;
376 |     AM.modify self id1
377 |   end
369 |
370 | (* Re-canonicalise the node list of every class, then sort it and drop
371 |    duplicates (via [dedup]) under [L.compare_shape]. *)
372 | let rebuild_classes self =
373 |   let compare_nodes = L.compare_shape EClassId.compare in
374 |   Id.Map.iter (fun _id cls ->
375 |     Vector.map_in_place (canonicalise self) cls.nodes;
376 |     Vector.sort' compare_nodes cls.nodes;
377 |     dedup compare_nodes cls.nodes
378 |   ) self.class_data
376 |
377 | let process_unions self = (* Restore the hash-cons and analysis invariants after merges: re-canonicalise queued parent nodes (merging classes that become congruent), then propagate analysis changes; repeat until [pending] stays empty. *)
378 | (* let init_size = Hashtbl.length self.hash_cons in *)
379 | while not @@ Vector.is_empty self.pending do
380 | (* Drain [pending]. [merge] pushes the parents of the absorbed class back onto it, so this may take several passes. *)
381 | let rec update_hash_cons () =
382 | match Vector.pop self.pending with
383 | | None -> ()
384 | | Some (node,cls) ->
385 | let old_node = node in
386 | let node = self.@[canonicalise] node in
387 | if not @@ ((L.compare_shape EClassId.compare old_node node) = 0) then
388 | Hashtbl.remove self.hash_cons old_node; (* the [if] guards only this removal; the lookup below always runs *)
389 | begin match (Hashtbl.find_opt self.hash_cons node) with
390 | | None -> Hashtbl.add self.hash_cons node cls
391 | | Some memo_cls -> self.@[merge] memo_cls cls (* the canonical node already has a class: union the two *)
392 | end;
393 | update_hash_cons () in
394 | update_hash_cons ();
395 | (* Drain [pending_analysis]: recompute each queued node data, merge it into its class, and when the first flag from [AM.merge] is set, queue the class parents and run [AM.modify]. *)
396 | let rec update_analysis () =
397 | match Vector.pop self.pending_analysis with
398 | | None -> ()
399 | | Some (node, class_id) ->
400 | let class_id = self.@[find] class_id in
401 | let node_data = AM.make (freeze self) node in
402 | let cls = self.@[get_class_data] class_id in
403 | assert (Id.eq_id cls.id class_id);
404 | let (did_update_left, _did_update_right) =
405 | let data,res = AM.merge self.analysis cls.data node_data in
406 | cls.data <- data;
407 | res in
408 | if did_update_left then begin
409 | Vector.append self.pending_analysis cls.parents;
410 | AM.modify self class_id
411 | end;
412 | update_analysis () in
413 | update_analysis ()
414 | done (* [AM.modify] may add or merge nodes and so refill [pending]; hence the outer loop *)
415 | (* let _final_size = Hashtbl.length self.hash_cons in
416 | * print_endline @@ Printf.sprintf "after rebuilding size of nodes is %d => %d" init_size final_size *)
417 |
418 | (* Restore the e-graph invariants after merges: process pending unions and
419 |    analysis updates, then tidy the node list of every class. *)
420 | let rebuild (self: (Id.t L.shape, 'b, 'c, rw) egraph) =
421 |   begin
422 |     process_unions self;
423 |     rebuild_classes self
424 |   end
421 |
422 | (* ** Matching *)
423 | let ematch eg (classes: (Id.t L.shape, 'a) Vector.t Id.Map.t) pattern = (* Enumerate, as an [Iter.t], every (eclass id, variable bindings) pair such that [pattern] matches an enode of that class in [classes]. *)
424 | let concat_map f l = Iter.concat (Iter.map f l) in (* flat-map over iterators *)
425 | let rec enode_matches p enode env = (* match a non-variable pattern [p] against one enode *)
426 | match[@warning "-8"] p with (* [V _] is handled in [match_in] and never reaches here *)
427 | | Query.Q (f, _) when not @@ L.equal_op f (L.op enode) ->
428 | Iter.empty
429 | | Q (_, args) ->
430 | (fun f -> List.iter2 (Fun.curry f) args (L.children enode)) (* pair sub-patterns with child ids; NOTE(review): List.iter2 raises Invalid_argument if the arities differ *)
431 | |> Iter.fold (fun envs (qvar, trm) -> (* thread the candidate environments through each child *)
432 | concat_map (fun env' -> match_in qvar trm env') envs) (Iter.singleton env)
433 | and match_in p eid env = (* match [p] against eclass [eid] under bindings [env] *)
434 | let eid = find eg eid in (* compare and bind canonical ids only *)
435 | match p with
436 | | V id -> begin
437 | match StringMap.find_opt id env with
438 | | None -> Iter.singleton (StringMap.add id eid env) (* first occurrence: bind the variable *)
439 | | Some eid' when Id.eq_id eid eid' -> Iter.singleton env (* repeated variable: must be the same class *)
440 | | _ -> Iter.empty
441 | end
442 | | p -> (* non-variable pattern: try every enode of the class *)
443 | match Id.Map.find_opt classes eid with
444 | | Some v -> Vector.to_iter v |> concat_map (fun enode -> enode_matches p enode env)
445 | | None -> Iter.empty
446 | in
447 | (fun f -> Id.Map.iter (Fun.curry f) classes) (* try the pattern at the root of every class *)
448 | |> concat_map (fun (eid, _) ->
449 | Iter.map (fun s -> (eid, s)) (match_in pattern eid StringMap.empty))
450 |
451 | (* [find_matches eg]: snapshot the eclasses of [eg] once, and return a
452 |    matcher enumerating the (eclass, bindings) matches of a pattern. *)
453 | let find_matches eg =
454 |   let snapshot = eclasses eg in
455 |   fun pattern -> ematch eg snapshot pattern
456 |
457 | (* [iter_children self cls]: the hash-consed enodes of the class of [cls]
458 |    (canonicalised first), or nothing if the class has none. *)
459 | let iter_children self cls =
460 |   let root = self.@[find] cls in
461 |   match Id.Map.find_opt (eclasses self) root with
462 |   | Some nodes -> Vector.to_iter nodes
463 |   | None -> Iter.empty
459 |
460 | module BuildRunner (S : SCHEDULER with type 'a egraph := (Id.t L.shape, A.t, A.data, rw) egraph
461 |                                  and type rule := Rule.t) = struct
462 |
463 |   (* ** Rewriting System *)
464 |   (* [apply_rules scheduler iteration eg rules]: one round of rewriting.
465 |      For each rule, [S.guard_rule_usage] receives a thunk computing the
466 |      rule matches over one snapshot of the eclasses of [eg]. Every match it
467 |      yields is instantiated and merged into the matched class, and [eg] is
468 |      rebuilt at the end. *)
469 |   let apply_rules scheduler iteration (eg: (Id.t L.shape, _, _, _) egraph) (rules : (Rule.t * S.data) array) =
470 |     let find_matches = find_matches eg in
471 |     let for_each_match =
472 |       Iter.of_array rules
473 |       |> Iter.flat_map (fun ((from_rule, to_rule), meta_data) ->
474 |           S.guard_rule_usage eg scheduler meta_data iteration (fun () -> find_matches from_rule)
475 |           |> Iter.map (fun (eid,env) -> (to_rule, eid, env))
476 |         ) in
477 |     for_each_match begin fun (to_rule, eid, env) ->
478 |       match to_rule with
479 |       | Rule.Constant to_rule ->
480 |         let new_eid = subst eg to_rule env in
481 |         merge eg eid new_eid
482 |       | Conditional (to_rule, cond) ->
483 |         if cond eg eid env then
484 |           let new_eid = subst eg to_rule env in
485 |           merge eg eid new_eid
486 |         else ()
487 |       | Dynamic cond ->
488 |         match cond eg eid env with
489 |         | None -> ()
490 |         | Some to_rule ->
491 |           let new_eid = subst eg to_rule env in
492 |           merge eg eid new_eid
493 |     end;
494 |     rebuild eg
495 |
496 |   (* [run_until_saturation ?scheduler ?node_limit ?fuel ?until eg rules]:
497 |      apply [rules] round after round, numbering rounds from 0.
498 |      - After a round that leaves [eg.version] unchanged, stop with [true]
499 |        if [S.should_stop] agrees; otherwise run another round.
500 |      - After a round that changed [eg], stop with [false] if the fuel is
501 |        spent (fuel <= round), if the hash-cons holds at least [node_limit]
502 |        nodes, or if [until eg] holds; otherwise run another round.
503 |      NOTE(review): fuel and node limits are only checked after a round that
504 |      changed [eg]; unchanged rounds rely on [S.should_stop] to terminate. *)
505 |   let run_until_saturation ?scheduler ?(node_limit=`Bounded 10_000) ?(fuel=`Bounded 30) ?until eg rules =
506 |     let scheduler = match scheduler with None -> S.default () | Some scheduler -> scheduler in
507 |     let rules = Iter.of_list rules
508 |                 |> Iter.map (fun rule -> (rule, S.create_rule_metadata scheduler rule))
509 |                 |> Iter.to_array in
510 |     let rule_data () = Array.to_iter rules |> Iter.map snd in
511 |     (* Limit checks; an unbounded limit never stops the run. *)
512 |     let has_fuel ind =
513 |       match fuel with `Unbounded -> true | `Bounded limit -> limit > ind in
514 |     let below_node_limit () =
515 |       match node_limit with
516 |       | `Unbounded -> true
517 |       | `Bounded limit -> Hashtbl.length eg.hash_cons < limit in
518 |     let reached_goal () =
519 |       match until with None -> false | Some pred -> pred eg in
520 |     let rec loop last_version ind =
521 |       apply_rules scheduler ind eg rules;
522 |       if Int.equal eg.version last_version then begin
523 |         (* Nothing changed this round. *)
524 |         if S.should_stop scheduler ind (rule_data ()) then true
525 |         else loop eg.version (ind + 1)
526 |       end
527 |       else if not (has_fuel ind && below_node_limit ()) then false
528 |       else if reached_goal () then false
529 |       else loop eg.version (ind + 1) in
530 |     loop eg.version 0
531 |
532 | end
568 |
569 | include (BuildRunner (Scheduler.Backoff))
570 |
571 | (* [apply_rules eg rules]: one unscheduled round of rewriting. Every match
572 |    of every rule (over one snapshot of [eg]) is instantiated and merged into
573 |    the matched class, then [eg] is rebuilt. This shadows the scheduled
574 |    [apply_rules] brought in by the [include] just above. *)
575 | let apply_rules (eg: (Id.t L.shape, _, _, _) egraph) (rules : Rule.t list) =
576 |   let find_matches = find_matches eg in
577 |   (* Instantiate [into] under [env] and union the result with [eid]. *)
578 |   let rewrite eid env into = merge eg eid (subst eg into env) in
579 |   let apply_match (to_rule, eid, env) =
580 |     match to_rule with
581 |     | Rule.Constant into -> rewrite eid env into
582 |     | Conditional (into, cond) -> if cond eg eid env then rewrite eid env into
583 |     | Dynamic generator -> Option.iter (rewrite eid env) (generator eg eid env) in
584 |   Iter.of_list rules
585 |   |> Iter.flat_map (fun (from_rule, to_rule) ->
586 |       find_matches from_rule |> Iter.map (fun (eid, env) -> (to_rule, eid, env)))
587 |   |> Iter.iter apply_match;
588 |   rebuild eg
598 | end
599 |
600 |
601 |
602 |
603 | module Make (* Full e-graph for language [L] with analysis [A]; the analysis hooks are produced by [MakeAnalysisOps] from the e-graph API itself. *)
604 | (L: LANGUAGE)
605 | (A: ANALYSIS)
606 | (MakeAnalysisOps: functor
607 | (S: GRAPH_API
608 | with type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph
609 | and type analysis := A.t
610 | and type data := A.data
611 | and type 'a shape := 'a L.shape
612 | and type node := L.t) -> sig
613 | val make: (Id.t L.shape, A.t, A.data, ro) egraph -> Id.t L.shape -> A.data
614 | val merge: A.t -> A.data -> A.data -> A.data * (bool * bool)
615 | val modify: (Id.t L.shape, A.t, A.data, rw) egraph -> Id.t -> unit
616 | end)
617 | = struct
618 | (* [EGraph] and [Analysis] are mutually recursive: [EGraph] includes [MakeOps] instantiated with [Analysis], and [Analysis] is [MakeAnalysisOps] applied to [EGraph]. *)
619 |
620 | module rec EGraph : sig
621 | type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph
622 |
623 | module Rule: sig
624 | type t
625 | val make_constant : from:L.op query -> into:L.op query -> t
626 | val make_conditional :
627 | from:L.op query ->
628 | into:L.op query ->
629 | cond:((Id.t L.shape, A.t, A.data, rw) egraph -> eclass_id -> eclass_id StringMap.t -> bool) ->
630 | t
631 |
632 | val make_dynamic :
633 | from:L.op query ->
634 | generator:((Id.t L.shape, A.t, A.data, rw) egraph ->
635 | eclass_id -> eclass_id StringMap.t -> L.op query option) ->
636 | t
637 |
638 | end
639 |
640 | val freeze : rw t -> ro t
641 | val init : A.t -> 'p t
642 | val class_equal: ro t -> eclass_id -> eclass_id -> bool
643 | val new_class : rw t -> eclass_id
644 | val set_data : rw t -> eclass_id -> A.data -> unit
645 | val get_data : _ t -> eclass_id -> A.data
646 | val get_analysis : rw t -> A.t
647 | val canonicalise : rw t -> Id.t L.shape -> Id.t L.shape
648 | val find : ro t -> eclass_id -> eclass_id
649 | (* val append_to_worklist : rw t -> eclass_id -> unit *)
650 | val eclasses: rw t -> (Id.t L.shape, Vector.rw) Vector.t Id.Map.t
651 | (* val pp : Format.formatter -> (Id.t L.shape, 'a, A.data, _) egraph -> unit *)
652 | val to_dot : (Id.t L.shape, A.t, A.data, _) egraph -> Odot.graph
653 | val pp_dot : Format.formatter -> (Id.t L.shape, A.t, A.data, _) egraph -> unit
654 | val add_node : rw t -> L.t -> eclass_id
655 | val merge : rw t -> eclass_id -> eclass_id -> unit
656 | val iter_children : ro t -> eclass_id -> Id.t L.shape Iter.t
657 | val rebuild : rw t -> unit
658 |
659 | val find_matches : ro t -> L.op query -> (eclass_id * eclass_id StringMap.t) Iter.t
660 | val apply_rules : (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> unit
661 | val run_until_saturation:
662 | ?scheduler:Scheduler.Backoff.t ->
663 | ?node_limit:[`Bounded of int | `Unbounded] ->
664 | ?fuel:[`Bounded of int | `Unbounded] ->
665 | ?until:((Id.t L.shape, A.t, A.data, rw) egraph -> bool) -> (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> bool
666 |
667 | module BuildRunner (S : SCHEDULER
668 | with type 'a egraph := (Id.t L.shape, A.t, A.data, rw) egraph
669 | and type rule := Rule.t) :
670 | sig
671 | val apply_rules :
672 | S.t ->
673 | int ->
674 | (Id.t L.shape, A.t, A.data, rw) egraph ->
675 | (Rule.t * S.data) array -> unit
676 | val run_until_saturation :
677 | ?scheduler:S.t ->
678 | ?node_limit:[`Bounded of int | `Unbounded] ->
679 | ?fuel:[`Bounded of int | `Unbounded] ->
680 | ?until:((Id.t L.shape, A.t, A.data, rw) egraph -> bool) ->
681 | (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> bool
682 | end
683 | end
684 | = struct
685 | let _unsafe = 10 (* NOTE(review): unused and not exported by the signature above; confirm whether it is still needed *)
686 | type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph
687 | include (MakeInt (L)) (* core accessors *)
688 | include (MakePrinter (L) (A)) (* Graphviz export *)
689 | include (MakeOps (L) (A) (Analysis)) (* mutation, matching and rewriting, wired to the recursive [Analysis] *)
690 | end
691 | and Analysis : sig
692 | val make: (Id.t L.shape, A.t, A.data, ro) egraph -> Id.t L.shape -> A.data
693 | val merge: A.t -> A.data -> A.data -> A.data * (bool * bool)
694 | val modify: (Id.t L.shape, A.t, A.data, rw) egraph -> Id.t -> unit
695 | end = MakeAnalysisOps (EGraph)
696 |
697 | include EGraph (* re-export the e-graph API at the top level of [Make] *)
698 |
699 | end
700 |
--------------------------------------------------------------------------------
/lib/generic.mli:
--------------------------------------------------------------------------------
1 | open Language
2 |
3 | type ('node, 'analysis, 'data, 'permission) egraph (* abstract e-graph; ['permission] is [rw] or [ro] (see [freeze]) *)
4 |
5 | module MakePrinter : functor (L : LANGUAGE) (A : ANALYSIS) -> sig (* Graphviz export *)
6 | (* val pp : Format.formatter -> (Id.t L.shape, A.t, A.data, 'b) egraph -> unit *)
7 | val to_dot : (Id.t L.shape, A.t, A.data, 'b) egraph -> Odot.graph (* one cluster per eclass, one graph node per enode *)
8 | val pp_dot : Format.formatter -> (Id.t L.shape, A.t, A.data, 'b) egraph -> unit (* print [to_dot] in dot syntax *)
9 | end
10 |
11 | module MakeExtractor : functor (* cheapest-term extraction under the cost model [E] *)
12 | (L : LANGUAGE)
13 | (E : COST with type node := Id.t L.shape) -> sig
14 | val extract : (Id.t L.shape, 'a, 'b, rw) egraph -> Id.t -> L.t (* [extract eg] computes all class costs once; the resulting function returns the cheapest term of a class *)
15 | end
16 |
17 | module Make : (* the full e-graph API; [MakeAnalysisOps] supplies the analysis hooks from the graph API *)
18 | functor
19 | (L : LANGUAGE)
20 | (A : ANALYSIS)
21 | (MakeAnalysisOps : functor
22 | (S : GRAPH_API with type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph
23 | and type analysis := A.t
24 | and type 'a shape := 'a L.shape
25 | and type data := A.data
26 | and type node := L.t) ->
27 | ANALYSIS_OPS with type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph
28 | and type analysis := A.t
29 | and type data := A.data
30 | and type node := Id.t L.shape) ->
31 | sig
32 | type 'p t = (Id.t L.shape, A.t, A.data, 'p) egraph
33 | module Rule:
34 | RULE with type query := L.op Query.t
35 | and type 'a egraph := (Id.t L.shape, A.t, A.data, 'a) egraph
36 |
37 | val freeze : rw t -> ro t (* view a writable graph as read-only *)
38 | val init : A.t -> 'p t (* empty graph carrying the given analysis state *)
39 | val class_equal: ro t -> Id.t -> Id.t -> bool
40 | val new_class : rw t -> Id.t (* fresh class with no enodes and [A.default] data *)
41 | val set_data : rw t -> Id.t -> A.data -> unit
42 | val get_data : _ t -> Id.t -> A.data (* data of the class containing the id (canonicalised first) *)
43 | val get_analysis: rw t -> A.t
44 | val canonicalise : rw t -> Id.t L.shape -> Id.t L.shape (* replace every child id by its representative *)
45 | val find : ro t -> Id.t -> Id.t (* canonical representative of a class id *)
46 | val eclasses: rw t -> (Id.t L.shape, Containers.Vector.rw) Containers.Vector.t Id.Map.t (* freshly built map: canonical id -> hash-consed enodes *)
47 | val iter_children : ro t -> Id.t -> Id.t L.shape Iter.t (* enodes of the given class *)
48 | val to_dot : (Id.t L.shape, A.t, A.data, _) egraph -> Odot.graph
49 | val pp_dot : Format.formatter -> (Id.t L.shape, A.t, A.data, _) egraph -> unit
50 | val add_node : rw t -> L.t -> Id.t (* insert a term, returning the canonical id of its class *)
51 | val merge : rw t -> Id.t -> Id.t -> unit (* union two classes; [rebuild] processes the queued work *)
52 | val rebuild : rw t -> unit (* process pending unions and analysis updates, then tidy class node lists *)
53 | val find_matches : ro t -> L.op Query.t -> (Id.t * Id.t StringMap.t) Iter.t (* all (class, bindings) matches of a pattern *)
54 | val apply_rules : (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> unit (* one round with every rule, then [rebuild] *)
55 | val run_until_saturation:
56 | ?scheduler:Scheduler.Backoff.t ->
57 | ?node_limit:[`Bounded of int | `Unbounded] ->
58 | ?fuel:[`Bounded of int | `Unbounded] ->
59 | ?until:((Id.t L.shape, A.t, A.data, rw) egraph -> bool) -> (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> bool (* [true] on saturation, [false] when a limit or [until] stopped the run; default limits are 10_000 nodes and 30 rounds of fuel *)
60 |
61 | module BuildRunner (S : SCHEDULER (* the same runner, parameterised by a scheduler *)
62 | with type 'a egraph := (Id.t L.shape, A.t, A.data, rw) egraph
63 | and type rule := Rule.t) :
64 | sig
65 | val apply_rules :
66 | S.t ->
67 | int ->
68 | (Id.t L.shape, A.t, A.data, rw) egraph ->
69 | (Rule.t * S.data) array -> unit
70 | val run_until_saturation :
71 | ?scheduler:S.t ->
72 | ?node_limit:[`Bounded of int | `Unbounded] ->
73 | ?fuel:[`Bounded of int | `Unbounded] ->
74 | ?until:((Id.t L.shape, A.t, A.data, rw) egraph -> bool) ->
75 | (Id.t L.shape, A.t, A.data, rw) egraph -> Rule.t list -> bool
76 | end
77 |
78 | end
79 |
--------------------------------------------------------------------------------
/lib/id.ml:
--------------------------------------------------------------------------------
1 | open Containers
2 |
3 | include (Equivalence.Make ())
4 |
5 | (* Two ids are identical when their integer representations agree. This
6 |    compares raw ids; it does not consult any union-find store. *)
7 | let eq_id left right = Equal.int (repr left) (repr right)
8 |
9 | (* Insertion-ordered, duplicate-free queue of ids. *)
10 | module OrderedSet = Ordered_set.Make (struct
11 |     type nonrec t = t
12 |     let equal = eq_id
13 |     let hash = hash
14 |   end)
12 |
13 |
--------------------------------------------------------------------------------
/lib/id.mli:
--------------------------------------------------------------------------------
1 | type t = private int (* a class id *)
2 | type store (* the union-find structure that allocates and relates ids *)
3 | val eq_id : t -> t -> bool (* compares raw ids via [repr]; does not consult a store *)
4 | val repr : t -> int
5 | val hash : t -> int
6 |
7 | val create_store : unit -> store
8 | val make : store -> unit -> t (* allocate a fresh id in the store *)
9 | val find : store -> t -> t (* representative of an id in the store *)
10 | val equal : store -> t -> t -> bool (* equality of ids relative to the store *)
11 | val union : store -> t -> t -> t (* join two ids; merge in generic.ml asserts that the first argument is returned *)
12 |
13 | module Map : Hashtbl.S with type key = t
14 | module Set : CCHashSet.S with type elt = t
15 | module OrderedSet : Ordered_set.S with type elt = t (* insertion-ordered queue without duplicates *)
16 |
--------------------------------------------------------------------------------
/lib/language.ml:
--------------------------------------------------------------------------------
1 | open Containers
2 | (* [Set] here is the Containers extension of [Set]; the map is taken from
3 |    [Stdlib.Map] explicitly. *)
4 | module StringSet = Set.Make(String)
5 | module StringMap = Stdlib.Map.Make(String)
6 |
7 | (* [str printer value] renders [value] to a string with [printer]. *)
8 | let str printer value = Format.to_string printer value
9 |
10 | (* Sexplib s-expressions, re-exported together with their constructors. *)
11 | type sexp = Sexplib.Sexp.t = Atom of string | List of sexp list
12 |
13 | (* Permission tags for e-graph handles: read-write and read-only. *)
14 | type rw = [`RW]
15 | type ro = [`RO]
11 |
12 |
13 | module type LANGUAGE = sig (* A term language: ['a shape] is a single node whose children have type ['a]. *)
14 | type 'a shape
15 | type op (* the head operator of a node *)
16 |
17 | type t = Mk of t shape [@@unboxed] (* closed terms: nodes whose children are terms *)
18 |
19 | val equal_op: op -> op -> bool
20 |
21 | val pp_shape: (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a shape -> unit
22 | val compare_shape: ('a -> 'a -> int) -> 'a shape -> 'a shape -> int
23 | val op: 'a shape -> op
24 | val children: 'a shape -> 'a list
25 | val map_children: 'a shape -> ('a -> 'b) -> 'b shape (* [map_children s f] applies [f] to each child of [s] *)
26 | val make : op -> 'a list -> 'a shape (* build a node from a head and its children *)
27 | end
28 |
29 | module type ANALYSIS = sig (* Analysis state [t] and per-class [data]; classes from [new_class] start with [default]. *)
30 | type t
31 | type data [@@deriving show, eq]
32 | val default: data
33 | end
34 |
35 | module type ANALYSIS_OPS = sig (* Analysis hooks called by the e-graph. *)
36 | type 'a t
37 | type analysis
38 | type node
39 | type data
40 | val make : ro t -> node -> data (* data for a newly added node *)
41 | val merge : analysis -> data -> data -> data * (bool * bool) (* combined data plus two flags; presumably whether the result differs from the first / second argument *)
42 | val modify : rw t -> Id.t -> unit (* called on a class after it is created, merged, or its data changes *)
43 | end
44 |
45 |
46 | module type COST = sig (* Extraction cost model; the extractor keeps the smallest cost under [compare]. *)
47 | type t
48 | type node
49 | val compare : t -> t -> int
50 | val cost : (Id.t -> t) -> node -> t (* [cost child_cost node], given the cost of each child class *)
51 | end
52 |
53 | module type GRAPH_API = sig (* The e-graph operations handed to [MakeAnalysisOps] in [Generic.Make]. *)
54 | type 'p t (* graph handle; the parameter is [rw] or [ro] *)
55 |
56 | type analysis
57 | type data
58 | type node
59 | type 'a shape
60 |
61 | val freeze : rw t -> ro t (* read-only view of a writable graph *)
62 | val class_equal : ro t -> Id.t -> Id.t -> bool
63 | val iter_children : ro t -> Id.t -> Id.t shape Iter.t (* enodes of a class *)
64 | val set_data : rw t -> Id.t -> data -> unit
65 | val get_data : ro t -> Id.t -> data
66 | val get_analysis : rw t -> analysis
67 | val add_node : rw t -> node -> Id.t (* insert a term, returning its class *)
68 | val merge : rw t -> Id.t -> Id.t -> unit (* union two classes *)
69 | end
70 |
71 | module type RULE = sig (* Constructors for rewrite rules. *)
72 | type t
73 | type query
74 | type 'a egraph
75 |
76 | val make_constant : from:query -> into:query -> t (* always rewrite a match of [from] into [into] *)
77 | val make_conditional :
78 | from:query ->
79 | into:query ->
80 | cond:(rw egraph -> Id.t -> Id.t StringMap.t -> bool) -> (* [into] is used only when [cond] holds for the match *)
81 | t
82 |
83 | val make_dynamic :
84 | from:query ->
85 | generator:(rw egraph -> Id.t -> Id.t StringMap.t -> query option) -> t (* [generator] picks the right-hand side per match, or [None] to skip *)
86 |
87 | end
88 |
89 | module type SCHEDULER = sig (* Rule scheduling policy used by [run_until_saturation]. *)
90 |
91 | type 'a egraph
92 |
93 | type t
94 |
95 | type data
96 |
97 | type rule
98 |
99 | val default : unit -> t (* scheduler used when none is supplied *)
100 |
101 | val should_stop: t -> int -> data Iter.t -> bool (* consulted after a round that changed nothing *)
102 |
103 | val create_rule_metadata: t -> rule -> data (* per-rule state, created once per run *)
104 |
105 | val guard_rule_usage:
106 | rw egraph -> t -> data -> int ->
107 | (unit -> (Id.t * Id.t StringMap.t) Iter.t) -> (Id.t * Id.t StringMap.t) Iter.t (* wraps the thunk computing one rule matches in one round; presumably may skip or cap them *)
108 |
109 | end
110 |
--------------------------------------------------------------------------------
/lib/ordered_set.ml:
--------------------------------------------------------------------------------
1 | open Containers
2 |
3 | module type S = sig (* A FIFO queue holding each element at most once. *)
4 | type t
5 | type elt
6 | val create : unit -> t
7 | val push : elt -> t -> unit (* enqueue unless already queued *)
8 | val pop : t -> elt (* dequeue the oldest element; raises [Queue.Empty] when empty *)
9 | val pop_opt : t -> elt option (* like [pop], but [None] when empty *)
10 | val append : t -> elt list -> unit (* [push] each element in list order *)
11 | val clear : t -> unit
12 | val copy : t -> t
13 | val is_empty : t -> bool
14 | val length : t -> int
15 | val iter : (elt -> unit) -> t -> unit (* oldest first *)
16 | val fold : ('a -> elt -> 'a) -> 'a -> t -> 'a (* oldest first *)
17 | end
18 |
19 | module Make (Elt: Hashtbl.HashedType) : S with type elt = Elt.t = struct
20 |
21 |   module Set = CCHashSet.Make(Elt)
22 |
23 |   (* [elts] keeps insertion order; [cache] holds exactly the queued
24 |      elements, so membership tests are hash lookups. *)
25 |   type t = {
26 |     elts: Elt.t Queue.t;
27 |     cache: Set.t
28 |   }
29 |
30 |   type elt = Elt.t
31 |
32 |   let create () = {elts = Queue.create (); cache = Set.create 1}
33 |
34 |   (* Enqueue [vl] unless it is already queued. *)
35 |   let push vl st =
36 |     if not (Set.mem st.cache vl) then begin
37 |       Queue.push vl st.elts;
38 |       Set.insert st.cache vl
39 |     end
40 |
41 |   (* Dequeue the oldest element; raises [Queue.Empty] when empty. *)
42 |   let pop st =
43 |     let oldest = Queue.pop st.elts in
44 |     Set.remove st.cache oldest;
45 |     oldest
46 |
47 |   let pop_opt st =
48 |     if Queue.is_empty st.elts then None else Some (pop st)
49 |
50 |   let append st elts = List.iter (Fun.flip push st) elts
51 |
52 |   let clear st =
53 |     Queue.clear st.elts;
54 |     Set.clear st.cache
55 |
56 |   let copy st = {elts = Queue.copy st.elts; cache = Set.copy st.cache}
57 |
58 |   let is_empty st = Queue.is_empty st.elts
59 |
60 |   let length st = Queue.length st.elts
61 |
62 |   (* Traversals visit queued elements oldest first. *)
63 |   let iter f st = Queue.iter f st.elts
64 |
65 |   let fold f acc st = Queue.fold f acc st.elts
66 |
67 | end
64 |
--------------------------------------------------------------------------------
/lib/ordered_set.mli:
--------------------------------------------------------------------------------
1 | module type S = sig (* A FIFO queue holding each element at most once. *)
2 | type t
3 | type elt
4 | val create : unit -> t
5 | val push : elt -> t -> unit (* enqueue unless already queued *)
6 | val pop : t -> elt (* dequeue the oldest element; raises [Queue.Empty] when empty *)
7 | val pop_opt : t -> elt option (* like [pop], but [None] when empty *)
8 | val append : t -> elt list -> unit (* [push] each element in list order *)
9 | val clear : t -> unit
10 | val copy : t -> t
11 | val is_empty : t -> bool
12 | val length : t -> int
13 | val iter : (elt -> unit) -> t -> unit (* oldest first *)
14 | val fold : ('a -> elt -> 'a) -> 'a -> t -> 'a (* oldest first *)
15 | end
16 |
17 | module Make: (* membership is tracked with a hash set over [Elt] *)
18 | functor (Elt : Containers.Hashtbl.HashedType) -> S with type elt = Elt.t
19 |
--------------------------------------------------------------------------------
/lib/query.ml:
--------------------------------------------------------------------------------
1 | open Containers
2 | open Language
3 |
(** A query pattern over operators of type ['sym]. *)
type 'sym t =
  | V of string               (** pattern variable, stored without its leading [?] *)
  | Q of 'sym * 'sym t list   (** operator applied to sub-patterns *)

(** [of_sexp intern sexp] reads a pattern.  An atom [?x] is the variable [x],
    any other atom is a nullary operator, and [(op args...)] is [op] applied to
    the parsed [args].  Operator names are turned into symbols with [intern].
    @raise Invalid_argument for any other s-expression shape. *)
let rec of_sexp intern : sexp -> _ t = function
  | Atom str ->
    if String.prefix ~pre:"?" str
    then V (String.drop 1 str)
    else Q (intern str, [])
  | List (Atom head :: args) -> Q (intern head, List.map (of_sexp intern) args)
  | _ -> invalid_arg "Query sexp not of the expected form"

(** [to_sexp to_string q] writes [q] back as an s-expression.  Variables get a
    leading [?] and operators are rendered with [to_string]. *)
let rec to_sexp to_string : _ t -> sexp = function
  | Q (op, args) -> List (Atom (to_string op) :: List.map (to_sexp to_string) args)
  | V name -> Atom ("?" ^ name)
18 |
(** Pretty-print a query as an s-expression.  Operators are printed with
    [symbol_pp], variables as [?name], and applications inside an hv box. *)
let rec pp symbol_pp fmt = function
  | V var -> Format.fprintf fmt "?%s" var
  | Q (sym, []) -> symbol_pp fmt sym
  | Q (sym, children) ->
    Format.fprintf fmt "(@[<hv 1>%a@ %a@])"
      symbol_pp sym
      (Format.pp_print_list ~pp_sep:Format.pp_print_space (pp symbol_pp)) children

(** Render a query to a string using [pp]. *)
let show symbol_pp query = str (pp symbol_pp) query
33 |
let%test "terms are printed as expected" =
  let query = Q (Symbol.intern "+", [Q (Symbol.intern "1", []); V "a"]) in
  let printed = str (pp Symbol.pp) query in
  Alcotest.(check string) "prints as expected" "(+ 1 ?a)" printed
38 |
(** The set of pattern-variable names (without [?]) occurring in [query]. *)
let variables query =
  let rec go q acc =
    match q with
    | V name -> StringSet.add name acc
    | Q (_, args) -> List.fold_left (fun acc child -> go child acc) acc args
  in
  go query StringSet.empty
46 |
47 |
--------------------------------------------------------------------------------
/lib/scheduler.ml:
--------------------------------------------------------------------------------
1 | open Containers
2 | open Language
3 |
(** Back-off rule scheduler.

    In iteration [i], a rule may yield at most [match_limit lsl times_banned]
    matches.  Exceeding that discards the matches, bans the rule until
    iteration [i + (ban_length lsl times_banned)] and increments
    [times_banned]. *)
module Backoff = struct

  (** Scheduler parameters, copied into each rule's metadata. *)
  type t = {match_limit: int; ban_length: int}

  (** Mutable per-rule scheduling state. *)
  type data = {
    mutable times_applied: int;  (* iterations in which the rule's matches were returned *)
    mutable banned_until: int;   (* the rule is skipped while iteration < banned_until *)
    mutable times_banned: int;   (* number of bans so far; scales limit and ban length *)
    mutable match_limit: int;
    mutable ban_length: int;
  }

  let with_params ~match_limit ~ban_length : t = {match_limit; ban_length}

  let default () : t = {match_limit = 1_000; ban_length = 5}

  let create_rule_metadata ({match_limit; ban_length} : t) _ : data =
    {times_applied = 0; banned_until = 0; times_banned = 0; match_limit; ban_length}

  (** [true] when no rule is banned past [iteration].  Otherwise every
      outstanding ban is shifted earlier so that the soonest one expires at
      [iteration], and [false] is returned. *)
  let should_stop _ iteration stats =
    let banned =
      Iter.to_list (Iter.filter (fun data -> data.banned_until > iteration) stats) in
    match banned with
    | [] -> true
    | first :: rest ->
      let earliest =
        List.fold_left
          (fun acc data -> if data.banned_until < acc then data.banned_until else acc)
          first.banned_until rest in
      let delta = earliest - iteration in
      List.iter (fun data -> data.banned_until <- data.banned_until - delta) banned;
      false

  (** Return the matches produced by [gen_matches ()] unless the rule is
      banned, or unless their number exceeds the current limit, in which case
      the rule is banned and no matches are returned. *)
  let guard_rule_usage _ (_ : t) (data: data) iteration
      (gen_matches: (unit -> (Id.t * Id.t StringMap.t) Iter.t)) :
    (Id.t * Id.t StringMap.t) Iter.t =
    if iteration < data.banned_until then Iter.empty
    else begin
      let matches = Iter.to_array (gen_matches ()) in
      (* Both the limit and the ban length double with every previous ban. *)
      let scale n = n lsl data.times_banned in
      if Array.length matches > scale data.match_limit then begin
        data.banned_until <- iteration + scale data.ban_length;
        data.times_banned <- data.times_banned + 1;
        Iter.empty
      end else begin
        data.times_applied <- data.times_applied + 1;
        Iter.of_array matches
      end
    end

end
76 |
(** A scheduler with no state: [should_stop] always answers [true] and every
    rule's matches are passed through untouched. *)
module Simple = struct

  type t = unit

  type data = unit

  let init () : t = ()

  let create_rule_metadata _ _ = ()

  let should_stop _ _ _ = true

  let guard_rule_usage _ (() : t) (() : data) _
      (gen_matches : unit -> (Id.t * Id.t StringMap.t) Iter.t)
    : (Id.t * Id.t StringMap.t) Iter.t =
    gen_matches ()

end
95 |
--------------------------------------------------------------------------------
/lib/symbol.ml:
--------------------------------------------------------------------------------
1 | open Containers
2 |
(* A symbol is the index of its string in [strs]. *)
type t = int

module SymbolMap = Map.Make(Int)
module StrMap = Hashtbl.Make(String)

let repr (v : t) : int = v

(* Global intern tables: [tbl] maps each string to its symbol and [strs]
   stores the strings, indexed by symbol. *)
let tbl = StrMap.create 10
let strs = CCVector.create ()
11 |
(* Return the symbol for [str].  The first time [str] is seen it gets the next
   free id, which is also its index in [strs]. *)
let intern str =
  try StrMap.find tbl str
  with Not_found ->
    let id = Vector.length strs in
    Vector.push strs str;
    StrMap.add tbl str id;
    id
20 |
(* The string a symbol was interned from. *)
let to_string s = Vector.get strs s

(* Print a symbol as its string. *)
let pp fmt s = Format.pp_print_string fmt (to_string s)
25 |
--------------------------------------------------------------------------------
/lib/symbol.mli:
--------------------------------------------------------------------------------
1 | open Containers
(** Symbols: strings interned in a global table, represented by small
    integer ids. *)
type t = private int

(** [intern s] returns the symbol for [s], allocating a fresh one the first
    time [s] is seen; equal strings always yield the same symbol. *)
val intern : string -> t

(** The string a symbol was interned from. *)
val to_string: t -> string

(** The integer id backing a symbol. *)
val repr : t -> int

(** Print a symbol as its string. *)
val pp : Format.formatter -> t -> unit

(** Maps keyed by symbols. *)
module SymbolMap : Map.S with type key = t
9 |
--------------------------------------------------------------------------------
/lib/term.ml:
--------------------------------------------------------------------------------
1 | open Containers
2 |
(** A term: an operator name applied to the ids of its children. *)
type t = string * Id.t list

(** Operators are plain strings. *)
type op = string

let equal_op (a : op) (b : op) = String.equal a b

let pp_op fmt (o : op) = String.pp fmt o
8 |
(** Print a term as an s-expression: a leaf prints as its operator, otherwise
    the operator and the children (printed with [pp_children]) are wrapped in
    parentheses inside an hv box. *)
let pp pp_children fmt (sym, children) =
  match children with
  | [] -> Format.pp_print_string fmt sym
  | _ :: _ ->
    Format.fprintf fmt "(@[<hv 1>%s@ %a@])" sym
      (Format.pp_print_list ~pp_sep:Format.pp_print_space pp_children) children
20 |
(** Order terms by operator name, then lexicographically by the integer
    representation of their children. *)
let compare (op1, ids1) (op2, ids2) =
  match String.compare op1 op2 with
  | 0 -> List.compare (fun a b -> Int.compare (Id.repr a) (Id.repr b)) ids1 ids2
  | c -> c
25 |
(** The operator of a term. *)
let op (sym, _) = sym

(** The children of a term. *)
let children (_, kids) = kids

(** [map_children t f] applies [f] to each child of [t], keeping the operator. *)
let map_children (sym, kids) f = (sym, List.map f kids)

(** [make op children] builds a term. *)
let make sym kids = (sym, kids)

(** Render a term to a string, printing children with [pp_children]. *)
let show pp_children term = Format.to_string (pp pp_children) term
(** [of_sexp f s] parses [s] as a term.  An atom is a leaf operator.  In
    [(op c1 ... cn)], each [ci] is parsed recursively and turned into a child
    with [f], from left to right.
    @raise Failure on any other s-expression shape. *)
let rec of_sexp f : Sexplib0.Sexp.t -> t = function
  | Atom str -> (str, [])
  | List (Atom head :: tail) ->
    let convert_child acc sexp = f (of_sexp f sexp) :: acc in
    (head, List.rev (List.fold_left convert_child [] tail))
  | _ -> failwith "invalid sexp structure"
38 |
39 |
--------------------------------------------------------------------------------
/lib/types.ml:
--------------------------------------------------------------------------------
1 | open Containers
2 |
(** Identifier of an e-class. *)
type eclass_id = Id.t

(** An e-node: an operator symbol applied to child e-class ids. *)
type enode = Symbol.t * eclass_id list

(** [str pp v] renders [v] to a string using the printer [pp]. *)
let str pp v = Format.to_string pp v
7 |
8 | (* ** ID *)
module EClassId = struct
  type t = eclass_id

  (** Print an id as [e] followed by [Id.repr id], e.g. [e0]. *)
  let pp fmt id = Format.fprintf fmt "e%d" (Id.repr id)

  let show id = str pp id

  (** Order ids by their underlying integer. *)
  let compare (a : t) (b : t) =
    let ia = (a :> int) and ib = (b :> int) in
    Int.compare ia ib

  let%test "IDs print correctly" =
    let store = Id.create_store () in
    let first = Id.make store () in
    Alcotest.(check string) "should pretty print as e0" "e0" (str pp first)

end
25 |
26 |
27 | (* ** Node *)
module ENode = struct

  type t = enode

  (** The child e-class ids of a node. *)
  let children (_, children) = children

  (** [canonicalise uf n] replaces every child id of [n] with [Id.find uf id]. *)
  let canonicalise uf (sym, children) =
    (sym, List.map (Id.find uf) children)

  (** Hash a node: the symbol with [Hash.poly], the children with [Id.hash]. *)
  let hash : enode Hash.t = Hash.(pair poly (list Id.hash))

  let%test "node hashes correctly" =
    let store = Id.create_store () in
    let i1 = Id.make store () in
    Alcotest.(check int)
      "hash values should match"
      (hash (Symbol.intern "example", [i1]))
      (hash (Symbol.intern "example", [i1]))

  let%test "node hashes correctly after union" =
    let store = Id.create_store () in
    let i1 = Id.make store () in
    let i2 = Id.make store () in
    let hash_1 = hash (Symbol.intern "example", [i1]) in
    ignore @@ Id.union store i1 i2;
    let hash_2 = hash (Symbol.intern "example", [i1]) in
    Alcotest.(check int)
      "hash values should match"
      hash_1
      hash_2

  (** Node equality: symbols compared with [Equal.poly], children with
      [Id.eq_id]. *)
  let equal : enode Equal.t = Equal.(pair poly (list Id.eq_id))

  (** Print a node as an s-expression.  A leaf prints as its bare symbol;
      children are printed with [pp_id] (default [EClassId.pp]). *)
  let pp ?(pp_id=EClassId.pp) fmt (sym, children) =
    match children with
    | [] -> Symbol.pp fmt sym
    | children ->
      let open Format in
      pp_print_string fmt "(";
      pp_open_hvbox fmt 1;
      Symbol.pp fmt sym;
      pp_print_space fmt ();
      pp_print_list ~pp_sep:(pp_print_space) pp_id fmt children;
      pp_close_box fmt ();
      pp_print_string fmt ")"

  let%test "leaf nodes prints correctly" =
    Alcotest.(check string)
      "should pretty print as sexp"
      "example"
      (str (pp ~pp_id:EClassId.pp)
         (Symbol.intern "example", []))

  let%test "node prints correctly" =
    let store = Id.create_store () in
    Alcotest.(check string)
      "should pretty print as sexp"
      "(example e0 e1 e2)"
      (str (pp ~pp_id:EClassId.pp)
         (Symbol.intern "example",
          List.init 3 (fun _ -> Id.make store ()))
      )

  (** Sets of nodes under a structural total order: first by symbol id, then
      by the children ids compared lexicographically with [EClassId.compare].
      Ordering by [hash] alone would make distinct nodes with colliding
      hashes compare equal, so the set would silently keep only one of them. *)
  module Set = Set.Make (struct
      type t = enode
      let compare ((sym1, kids1) : t) ((sym2, kids2) : t) =
        match Int.compare (Symbol.repr sym1) (Symbol.repr sym2) with
        | 0 -> List.compare EClassId.compare kids1 kids2
        | c -> c
    end)

end
96 |
--------------------------------------------------------------------------------
/macros/dune:
--------------------------------------------------------------------------------
1 | (library
2 | (name ppx_sexp)
3 | (modules ppx_sexp)
4 | (kind ppx_rewriter)
5 | (libraries ppxlib sexplib)
6 | (preprocess (pps ppxlib.metaquot ppx_deriving.std)))
7 |
--------------------------------------------------------------------------------
/macros/ppx_sexp.ml:
--------------------------------------------------------------------------------
1 | open Ppxlib
2 |
(* The extension point is written [%s <expr>]. *)
let name = "s"

(* [build_atom ~loc v] is the expression [Sexplib0.Sexp.Atom "v"]. *)
let build_atom ~loc v =
  let literal = Ast_helper.Exp.constant ~loc (Pconst_string (v, loc, None)) in
  [%expr Sexplib0.Sexp.Atom [%e literal]]

(* [build_list ~loc es] is the expression [Sexplib0.Sexp.List [e1; ...; en]]. *)
let build_list ~loc ls =
  let elements =
    List.fold_right (fun hd tl -> [%expr [%e hd] :: [%e tl]]) ls [%expr []] in
  [%expr Sexplib0.Sexp.List [%e elements]]
14 |
(* Translate the payload of [%s ...] into an expression that builds the
   corresponding [Sexplib0.Sexp.t]:
   - the unit literal [()]                  -> [List []]
   - identifiers and literal constants      -> [Atom "<source text>"]
   - unlabelled applications [f a1 ... an]  -> [List [<f>; <a1>; ...; <an>]]
   Any other construct is reported as a located compile-time error. *)
let rec convert ~loc expr =
  match expr with
  | { pexp_desc = Pexp_construct ({ txt = Lident "()"; _ }, None); pexp_loc = loc; _ }
  | { pexp_desc = Pexp_ident { txt = Lident "()"; _ }; pexp_loc = loc; _ } ->
    (* The parser produces [()] as a constructor, not as an identifier, so it
       must be matched through [Pexp_construct]. *)
    build_list ~loc []
  | { pexp_desc = Pexp_ident { txt = Lident txt; _ }; pexp_loc = loc; _ }
    when String.length txt >= 2
         && txt.[0] = '('
         && txt.[String.length txt - 1] = ')' ->
    build_atom ~loc (String.sub txt 1 (String.length txt - 2))
  | { pexp_desc = Pexp_ident { txt = Lident txt; _ }; pexp_loc = loc; _ } ->
    build_atom ~loc txt
  | { pexp_desc = Pexp_constant const; pexp_loc = loc; _ } ->
    let text =
      match const with
      | Pconst_integer (txt, _) -> txt
      | Pconst_char cr -> String.make 1 cr
      | Pconst_string (txt, _, _) -> txt
      | Pconst_float (txt, _) -> txt in
    build_atom ~loc text
  | { pexp_desc = Pexp_apply (fn, args); pexp_loc = loc; _ }
    when List.for_all (function (Nolabel, _) -> true | _ -> false) args ->
    let head = convert ~loc:fn.pexp_loc fn in
    let tail = List.map (fun (_, arg) -> convert ~loc:arg.pexp_loc arg) args in
    build_list ~loc (head :: tail)
  | e ->
    (* [asprintf] avoids sharing the global [Format.str_formatter] buffer. *)
    let printed = Format.asprintf "%a" Pprintast.expression e in
    Location.raise_errorf ~loc "use of unsupported syntactic construct %s" printed
40 |
(* Expander invoked by ppxlib for each [%s ...] node; [path] is not needed. *)
let expand ~loc ~path:_ expr = convert ~loc expr

(* [%s ...] is an expression extension whose payload is a single expression. *)
let ext =
  let payload = Ast_pattern.(single_expr_payload __) in
  Extension.declare name Extension.Context.expression payload expand

let () = Driver.register_transformation ~extensions:[ext] name
49 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Ego - EGraphs in OCaml
2 |
3 | Ego (EGraphs OCaml) is an OCaml library that provides generic equality
4 | saturation using EGraphs.
5 |
6 | The design of Ego loosely follows the design of Rust's egg library,
7 | providing a flexible interface to run equality saturation extended
8 | with custom user-defined analyses.
9 |
10 | ```ocaml
11 | (* create an egraph *)
12 | let graph = EGraph.init ()
13 | (* add expressions *)
14 | let expr1 = EGraph.add_sexp graph [%s ((a << 1) / 2)]
15 | (* Convert to graphviz *)
16 | let g : Odot.graph = EGraph.to_dot graph
17 | ```
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/test/dune:
--------------------------------------------------------------------------------
1 | (tests
2 | (names test_basic test_generic test_math test_prop)
3 | (preprocess (pps ppx_sexp ppx_deriving.std))
4 | (libraries ego alcotest))
5 |
--------------------------------------------------------------------------------
/test/test_basic.ml:
--------------------------------------------------------------------------------
1 | open Ego.Basic
2 |
(* Alcotest testable for s-expressions: compared with [Sexplib.Sexp.equal]
   and printed with [Sexplib.Sexp.pp_hum]. *)
let sexp : Sexplib.Sexp.t Alcotest.testable =
  Alcotest.testable Sexplib.Sexp.pp_hum Sexplib.Sexp.equal
9 |
(* The documentation example: after rewriting (?a << 1) into (?a * 2), a cost
   function that prices [<<] above [*] extracts ((a * 2) / 2). *)
let documentation_example () =
  let graph = EGraph.init () in
  let expr_id = EGraph.add_sexp graph [%s ((a << 1) / 2)] in
  let rule =
    let from = Query.of_sexp [%s ("?a" << 1)]
    and into = Query.of_sexp [%s ("?a" * 2)] in
    match Rule.make ~from ~into with
    | Some rule -> rule
    | None -> Alcotest.fail "could not build rule" in
  Alcotest.(check bool) "should reach equality saturation" true
    (EGraph.run_until_saturation graph [rule]);
  let node_score sym =
    match Symbol.to_string sym with
    | "*" | "/" -> 1.
    | "<<" -> 2.
    | _ -> 0. in
  let cost_function score (sym, children) =
    let own = node_score sym in
    own +. List.fold_left (fun acc vl -> acc +. score vl) 0. children in
  Alcotest.(check sexp) "extracted expression has been simplified"
    [%s ((a * 2) / 2)]
    (EGraph.extract cost_function graph expr_id)
34 |
(*
  We start off with two exprs, (g 1) and (g 2), and merge them with the rule (g 1) -> (g 2).
  Then we add a rule (g ?a) -> (h ?a), creating (h 1) and (h 2), which are also equal to (g 1) and (g 2).
  We extract the cheapest term using a cost function constructed so that (h 2) is the lowest-cost term, with cost 11.
  Previously (h 1), which has cost 12, was extracted instead.

  (h 1): 12
  (h 2): 11
  (g 1): effectively infinite (g is priced at 9999999)
  (g 2): effectively infinite (g is priced at 9999999)
*)
let test_match () =
  let graph = EGraph.init () in
  let expr_id1 = EGraph.add_sexp graph [%s (g 1)] in
  let _ = EGraph.add_sexp graph [%s (g 2)] in
  let rule_of ~from ~into =
    match Rule.make ~from ~into with
    | Some rule -> rule
    | None -> Alcotest.fail "could not build rule" in
  let rule1 =
    let from = Query.of_sexp [%s (g 1)]
    and into = Query.of_sexp [%s (g 2)] in
    rule_of ~from ~into in
  let rule2 =
    let from = Query.of_sexp [%s (g "?a")]
    and into = Query.of_sexp [%s (h "?a")] in
    rule_of ~from ~into in
  Alcotest.(check bool) "should reach equality saturation" true
    (EGraph.run_until_saturation graph [rule1; rule2]);
  (* [g] is priced out so that an [h] term must be chosen; among those,
     (h 2) = 10 + 1 is cheaper than (h 1) = 10 + 2. *)
  let node_score sym =
    match Symbol.to_string sym with
    | "g" -> 9999999.
    | "h" -> 10.
    | "1" -> 2.
    | "2" -> 1.
    | _ -> 9999999. in
  let cost_function score (sym, children) =
    let own = node_score sym in
    own +. List.fold_left (fun acc vl -> acc +. score vl) 0. children in
  Alcotest.(check sexp) "cheapest expression is (h 2)" [%s (h 2)]
    (EGraph.extract cost_function graph expr_id1)
77 |
78 |
let () =
  let documentation_tests =
    [ "example given in documentation works as written", `Quick, documentation_example ] in
  let matching_tests = [ "test matching", `Quick, test_match ] in
  Alcotest.run "basic"
    [ ("documentation", documentation_tests);
      ("test matching", matching_tests) ]
84 |
--------------------------------------------------------------------------------
/test/test_generic.ml:
--------------------------------------------------------------------------------
1 | open Ego.Generic
2 |
(* Alcotest testable for s-expressions: compared with [Sexplib.Sexp.equal]
   and printed with [Sexplib.Sexp.pp_hum]. *)
let sexp : Sexplib.Sexp.t Alcotest.testable =
  Alcotest.testable Sexplib.Sexp.pp_hum Sexplib.Sexp.equal
9 |
module L = struct

  (** One layer of an integer arithmetic expression; ['a] is the child type. *)
  type 'a shape =
    | Add of 'a * 'a
    | Sub of 'a * 'a
    | Mul of 'a * 'a
    | Div of 'a * 'a
    | Var of string
    | Const of int
  [@@deriving ord, show]

  (** The head operator of a [shape], with its children stripped. *)
  type op =
    | AddOp
    | SubOp
    | MulOp
    | DivOp
    | VarOp of string
    | ConstOp of int
  [@@deriving eq]

  (** Closed expression trees. *)
  type t = Mk of t shape [@@unboxed]

  (* Accepts atoms (integers or variable names) and binary +, -, *, /
     applications; any other shape raises [Match_failure] (warning 8 is
     silenced on purpose). *)
  let rec of_sexp sexp =
    match[@warning "-8"] sexp with
    | Sexplib0.Sexp.Atom s ->
      (match int_of_string_opt s with
       | Some n -> Mk (Const n)
       | None -> Mk (Var s))
    | List [Atom ("*" | " * "); l; r] -> Mk (Mul (of_sexp l, of_sexp r))
    | List [Atom "-"; l; r] -> Mk (Sub (of_sexp l, of_sexp r))
    | List [Atom "+"; l; r] -> Mk (Add (of_sexp l, of_sexp r))
    | List [Atom "/"; l; r] -> Mk (Div (of_sexp l, of_sexp r))

  let rec to_sexp (Mk shape) : Sexplib0.Sexp.t =
    let binary name l r = Sexplib0.Sexp.List [Atom name; to_sexp l; to_sexp r] in
    match shape with
    | Add (l, r) -> binary "+" l r
    | Sub (l, r) -> binary "-" l r
    | Mul (l, r) -> binary "*" l r
    | Div (l, r) -> binary "/" l r
    | Var s -> Atom s
    | Const n -> Atom (Int.to_string n)

  let op = function
    | Var s -> VarOp s
    | Const i -> ConstOp i
    | Add _ -> AddOp
    | Sub _ -> SubOp
    | Mul _ -> MulOp
    | Div _ -> DivOp

  let op_of_string s =
    match s with
    | "+" -> AddOp
    | "-" -> SubOp
    | "*" | " * " -> MulOp
    | "/" -> DivOp
    | _ ->
      (match int_of_string_opt s with
       | Some n -> ConstOp n
       | None -> VarOp s)

  let children = function
    | Var _ | Const _ -> []
    | Add (l, r) | Sub (l, r) | Mul (l, r) | Div (l, r) -> [l; r]

  let map_children term f =
    match term with
    | Var s -> Var s
    | Const i -> Const i
    | Add (l, r) -> Add (f l, f r)
    | Sub (l, r) -> Sub (f l, f r)
    | Mul (l, r) -> Mul (f l, f r)
    | Div (l, r) -> Div (f l, f r)

  (* Rebuild a shape from an operator and its children; an arity mismatch
     raises [Match_failure]. *)
  let make op children =
    match[@warning "-8"] op, children with
    | VarOp s, [] -> Var s
    | ConstOp i, [] -> Const i
    | AddOp, [l; r] -> Add (l, r)
    | SubOp, [l; r] -> Sub (l, r)
    | MulOp, [l; r] -> Mul (l, r)
    | DivOp, [l; r] -> Div (l, r)

end
76 |
module C = struct
  type t = float [@@deriving ord]

  (* Leaves cost 1.0; each operator adds its own weight to its children's
     costs. *)
  let cost f : Ego.Id.t L.shape -> t = function
    | L.Var _ | L.Const _ -> 1.0
    | L.Add (l, r) -> f l +. f r +. 1.0
    | L.Sub (l, r) -> f l +. f r +. 1.5
    | L.Mul (l, r) -> f l +. f r +. 2.0
    | L.Div (l, r) -> f l +. f r +. 2.0
end
87 |
module A = struct
  type t = unit
  type data = int option [@@deriving eq, show]
  let default = None
end

(* Constant-folding analysis: a class's data is [Some n] when it is known to
   evaluate to the integer [n]. *)
module MA (S : GRAPH_API
           with type 'p t = (Ego.Id.t L.shape, A.t, A.data, 'p) egraph
            and type 'a shape := 'a L.shape
            and type analysis := A.t
            and type data := A.data
            and type node := L.t) = struct
  type 'p t = (Ego.Id.t L.shape, A.t, A.data, 'p) egraph

  let eval : A.data L.shape -> A.data = function
    | L.Const n -> Some n
    | L.Add (Some l, Some r) -> Some (l + r)
    | L.Sub (Some l, Some r) -> Some (l - r)
    | L.Mul (Some l, Some r) -> Some (l * r)
    | L.Div (Some l, Some r) when r <> 0 -> Some (l / r)
    | _ -> None

  let make : ro t -> Ego.Id.t L.shape -> A.data =
    fun graph term ->
    let data_of = S.get_data graph in
    eval (L.map_children term data_of)

  (* The flags report which side's data changed: (left changed, right changed). *)
  let merge : A.t -> A.data -> A.data -> A.data * (bool * bool) =
    fun () l r ->
    match l, r with
    | None, None -> None, (false, false)
    | Some l, None -> Some l, (false, true)
    | None, Some r -> Some r, (true, false)
    | Some l, Some r -> assert (l = r); Some l, (false, false)

  (* Once a class is known to be a constant, add that constant as a node and
     merge it into the class. *)
  let modify : 'a t -> Ego.Id.t -> unit =
    fun graph cls ->
    match S.get_data (S.freeze graph) cls with
    | None -> ()
    | Some n -> S.merge graph (S.add_node graph (L.Mk (Const n))) cls

end

module EGraph = Make (L) (A) (MA)
module Extractor = MakeExtractor (L) (C)
129 |
130 |
(* Add [term] to a fresh e-graph and check that extraction yields [expected]. *)
let check_simplifies_to ~expected term =
  let graph = EGraph.init () in
  let expr = EGraph.add_node graph (L.of_sexp term) in
  Alcotest.(check sexp) "extracted expression has been simplified" expected
    (L.to_sexp (Extractor.extract graph expr))

let documentation_example () =
  check_simplifies_to ~expected:[%s 4] [%s (2 * 2)]

let simple_constant_folding () =
  check_simplifies_to ~expected:[%s 4] [%s (2 + (1 + (3 - 2)))]
148 |
149 |
(* Two independent constant expressions in one graph both fold to 0. *)
let multiple_terms_constant_folding () =
  let graph = EGraph.init () in
  let add term = EGraph.add_node graph (L.of_sexp term) in
  let expr1 = add [%s (3 * (2 - (10 / 5)))] in
  let expr2 = add [%s (3 - 3)] in
  let expect msg expected id =
    Alcotest.(check sexp) msg expected (L.to_sexp (Extractor.extract graph id)) in
  expect "first extracted expression has been simplified" [%s 0] expr1;
  expect "second extracted expression has been simplified" [%s 0] expr2

(* Constant sub-terms fold even when the enclosing expression mentions a
   variable. *)
let multiple_terms_variable_constant_folding () =
  let graph = EGraph.init () in
  let add term = EGraph.add_node graph (L.of_sexp term) in
  let expr1 = add [%s ((2 * x) + (3 * (2 - (10 / 5))))] in
  let expr2 = add [%s ((3 - 3) + (2 * x))] in
  let expect msg expected id =
    Alcotest.(check sexp) msg expected (L.to_sexp (Extractor.extract graph id)) in
  expect "first extracted expression has been simplified"
    [%s ("+" (( * ) 2 x) 0)] expr1;
  expect "second extracted expression has been simplified"
    [%s ("+" 0 (( * ) 2 x))] expr2
175 |
(* The unconditional rule (2 * ?x) -> (?x + ?x) rewrites inside a larger term. *)
let syntactic_rewrite () =
  let graph = EGraph.init () in
  let query = Query.of_sexp L.op_of_string in
  let rewrite =
    EGraph.Rule.make_constant
      ~from:(query [%s (2 * "?x")])
      ~into:(query [%s ("?x" + "?x")]) in
  let expr = EGraph.add_node graph (L.of_sexp [%s (1 + (3 / (2 * a)))]) in
  Alcotest.(check bool) "rewrites reached saturation" true
    (EGraph.run_until_saturation graph [rewrite]);
  Alcotest.(check sexp) "first extracted expression has been simplified"
    [%s ("+" 1 ((/) 3 ("+" a a)))]
    (L.to_sexp (Extractor.extract graph expr))
190 |
(* (?x / ?x) -> 1 must only fire when ?x is known to be a non-zero constant. *)
let conditional_rewrite () =
  let graph = EGraph.init () in
  let query = Query.of_sexp L.op_of_string in
  let x_is_nonzero graph _root env =
    match EGraph.get_data graph (StringMap.find "x" env) with
    | Some n -> n <> 0
    | None -> false in
  let rewrite =
    EGraph.Rule.make_conditional
      ~from:(query [%s ("?x" / "?x")])
      ~into:(query [%s 1])
      ~cond:x_is_nonzero in
  let add term = EGraph.add_node graph (L.of_sexp term) in
  let expr_valid = add [%s (10 / 10)] in
  let expr_invalid = add [%s (0 / 0)] in
  let expr_invalid_compl = add [%s ((4 * 3 - 6 * 2) / (2 * 3 - 3 * 2))] in
  let expr_valid_compl = add [%s (((4 * 3 - 6 * 2) + 1) / ((2 * 3 - 3 * 2) + 1))] in
  let expr_x_x = add [%s (x / x)] in
  Alcotest.(check bool) "rewrites reached saturation" true
    (EGraph.run_until_saturation graph [rewrite]);
  List.iter
    (fun (msg, expected, id) ->
       Alcotest.(check sexp) msg expected (L.to_sexp (Extractor.extract graph id)))
    [ ("basic expression has been simplified", [%s 1], expr_valid);
      ("invalid expression has not been simplified", [%s ("/" 0 0)], expr_invalid);
      ("complex invalid expression has not been simplified beyond minimal",
       [%s ("/" 0 0)], expr_invalid_compl);
      ("complex valid expression has been simplified", [%s 1], expr_valid_compl);
      ("expression of variables has not been simplified", [%s ("/" x x)], expr_x_x) ]
230 |
231 |
let () =
  let case name fn = (name, `Quick, fn) in
  Alcotest.run "generic" [
    ("documentation", [
        case "example given in documentation works as written" documentation_example;
        case "simple constant folding" simple_constant_folding;
        case "multiple terms constant folding" multiple_terms_constant_folding;
        case "multiple terms with variable constant folding"
          multiple_terms_variable_constant_folding;
        case "syntactic rewriting" syntactic_rewrite;
        case "conditional rewriting" conditional_rewrite;
      ])
  ]
243 |
--------------------------------------------------------------------------------
/test/test_math.ml:
--------------------------------------------------------------------------------
1 | open Ego.Generic
(* Alcotest testable for s-expressions: compared with [Sexplib.Sexp.equal]
   and printed with [Sexplib.Sexp.pp_hum]. *)
let sexp : Sexplib.Sexp.t Alcotest.testable =
  Alcotest.testable Sexplib.Sexp.pp_hum Sexplib.Sexp.equal
8 |
(* Local symbol table for these tests: strings are interned to dense integer
   ids, handed out in order of first use.  The reverse table grows on demand;
   a fixed-size table would raise [Invalid_argument] once it filled up. *)
module Symbol : sig
  type t
  val compare : t -> t -> int
  val equal: t -> t -> bool
  val pp : Format.formatter -> t -> unit
  val intern: string -> t
  val to_string: t -> string
end = struct
  type t = int
  let equal = Int.equal
  let compare = Int.compare

  (* [ids] maps a string to its id, [(!names).(id)] holds the string for
     [id], and [count] is the number of symbols interned so far. *)
  let ids : (string, int) Hashtbl.t = Hashtbl.create 100
  let names = ref (Array.make 100 "")
  let count = ref 0

  let intern s =
    match Hashtbl.find_opt ids s with
    | Some id -> id
    | None ->
      let id = !count in
      if id >= Array.length !names then begin
        (* Double the capacity, keeping the strings interned so far. *)
        let bigger = Array.make (2 * Array.length !names) "" in
        Array.blit !names 0 bigger 0 id;
        names := bigger
      end;
      (!names).(id) <- s;
      incr count;
      Hashtbl.replace ids s id;
      id

  let to_string id = (!names).(id)

  let pp fmt s = Format.pp_print_string fmt (to_string s)
end
33 |
module L = struct
  (** One layer of a symbolic-maths expression; ['a] is the child type. *)
  type 'a shape =
    | Diff of 'a * 'a
    | Integral of 'a * 'a
    | Add of 'a * 'a
    | Sub of 'a * 'a
    | Mul of 'a * 'a
    | Div of 'a * 'a
    | Pow of 'a * 'a
    | Ln of 'a
    | Sqrt of 'a
    | Sin of 'a
    | Cos of 'a
    | Constant of float
    | Symbol of Symbol.t
  [@@deriving ord, show]

  (** The head operator of a [shape], with its children stripped. *)
  type op =
    | DiffOp | IntegralOp | AddOp | SubOp | MulOp | DivOp | PowOp | LnOp | SqrtOp
    | SinOp | CosOp | ConstantOp of float | SymbolOp of Symbol.t
  [@@deriving eq]

  (** Closed expression trees. *)
  type t = Mk of t shape [@@unboxed]

  (* Parse a numeric literal: float syntax first, integer syntax as a
     fallback; [None] if neither applies. *)
  let number_of_string s =
    match float_of_string_opt s with
    | Some _ as number -> number
    | None -> Option.map Float.of_int (int_of_string_opt s)

  (* Non-numeric atoms become interned symbols; list shapes not listed here
     raise [Match_failure] (warning 8 is silenced on purpose). *)
  let rec of_sexp (sexp : Sexplib0.Sexp.t) : t =
    match[@warning "-8"] sexp with
    | Atom s ->
      (match number_of_string s with
       | Some n -> Mk (Constant n)
       | None -> Mk (Symbol (Symbol.intern s)))
    | List [Atom "d"; l; r] -> Mk (Diff (of_sexp l, of_sexp r))
    | List [Atom "i"; l; r] -> Mk (Integral (of_sexp l, of_sexp r))
    | List [Atom ("*" | " * "); l; r] -> Mk (Mul (of_sexp l, of_sexp r))
    | List [Atom "-"; l; r] -> Mk (Sub (of_sexp l, of_sexp r))
    | List [Atom "+"; l; r] -> Mk (Add (of_sexp l, of_sexp r))
    | List [Atom "/"; l; r] -> Mk (Div (of_sexp l, of_sexp r))
    | List [Atom "pow"; l; r] -> Mk (Pow (of_sexp l, of_sexp r))
    | List [Atom "ln"; l] -> Mk (Ln (of_sexp l))
    | List [Atom "sqrt"; l] -> Mk (Sqrt (of_sexp l))
    | List [Atom "sin"; l] -> Mk (Sin (of_sexp l))
    | List [Atom "cos"; l] -> Mk (Cos (of_sexp l))

  let rec to_sexp (Mk term) : Sexplib0.Sexp.t =
    let app name args = Sexplib0.Sexp.List (Atom name :: List.map to_sexp args) in
    match term with
    | Diff (l, r) -> app "d" [l; r]
    | Integral (l, r) -> app "i" [l; r]
    | Add (l, r) -> app "+" [l; r]
    | Sub (l, r) -> app "-" [l; r]
    | Mul (l, r) -> app "*" [l; r]
    | Div (l, r) -> app "/" [l; r]
    | Pow (l, r) -> app "pow" [l; r]
    | Ln l -> app "ln" [l]
    | Sqrt l -> app "sqrt" [l]
    | Sin l -> app "sin" [l]
    | Cos l -> app "cos" [l]
    | Constant c -> Atom (Float.to_string c)
    | Symbol s -> Atom (Symbol.to_string s)

  let op = function
    | Diff _ -> DiffOp
    | Integral _ -> IntegralOp
    | Add _ -> AddOp
    | Sub _ -> SubOp
    | Mul _ -> MulOp
    | Div _ -> DivOp
    | Pow _ -> PowOp
    | Ln _ -> LnOp
    | Sqrt _ -> SqrtOp
    | Sin _ -> SinOp
    | Cos _ -> CosOp
    | Constant c -> ConstantOp c
    | Symbol s -> SymbolOp s

  let op_of_string : string -> op = function
    | "d" -> DiffOp
    | "i" -> IntegralOp
    | "*" | " * " -> MulOp
    | "-" -> SubOp
    | "+" -> AddOp
    | "/" -> DivOp
    | "pow" -> PowOp
    | "ln" -> LnOp
    | "sqrt" -> SqrtOp
    | "sin" -> SinOp
    | "cos" -> CosOp
    | s ->
      (match number_of_string s with
       | Some n -> ConstantOp n
       | None -> SymbolOp (Symbol.intern s))

  let children = function
    | Constant _ | Symbol _ -> []
    | Ln x | Sqrt x | Sin x | Cos x -> [x]
    | Diff (l, r) | Integral (l, r) | Add (l, r) | Sub (l, r)
    | Mul (l, r) | Div (l, r) | Pow (l, r) -> [l; r]

  let map_children term f =
    match term with
    | Constant c -> Constant c
    | Symbol s -> Symbol s
    | Ln l -> Ln (f l)
    | Sqrt l -> Sqrt (f l)
    | Sin l -> Sin (f l)
    | Cos l -> Cos (f l)
    | Diff (l, r) -> Diff (f l, f r)
    | Integral (l, r) -> Integral (f l, f r)
    | Add (l, r) -> Add (f l, f r)
    | Sub (l, r) -> Sub (f l, f r)
    | Mul (l, r) -> Mul (f l, f r)
    | Div (l, r) -> Div (f l, f r)
    | Pow (l, r) -> Pow (f l, f r)

  (* Rebuild a shape from an operator and its children; an arity mismatch
     raises [Match_failure]. *)
  let make op children =
    match[@warning "-8"] op, children with
    | ConstantOp c, [] -> Constant c
    | SymbolOp s, [] -> Symbol s
    | LnOp, [l] -> Ln l
    | SqrtOp, [l] -> Sqrt l
    | SinOp, [l] -> Sin l
    | CosOp, [l] -> Cos l
    | DiffOp, [l; r] -> Diff (l, r)
    | IntegralOp, [l; r] -> Integral (l, r)
    | AddOp, [l; r] -> Add (l, r)
    | SubOp, [l; r] -> Sub (l, r)
    | MulOp, [l; r] -> Mul (l, r)
    | DivOp, [l; r] -> Div (l, r)
    | PowOp, [l; r] -> Pow (l, r)

end
142 |
module C = struct
  type t = int [@@deriving ord]

  (* A node's own cost (Diff/Integral 100, Sub 20, anything else 1) plus the
     costs of its children. *)
  let cost f : Ego.Id.t L.shape -> t =
    fun term ->
    let own =
      match term with
      | L.Diff _ | L.Integral _ -> 100
      | L.Sub _ -> 20
      | _ -> 1 in
    List.fold_left (fun total child -> total + f child) own (L.children term)
end
150 |
module A = struct type t = unit type data = float option [@@deriving eq, show] let default = None end

(* Constant-folding analysis: a class's data is [Some x] when it is known to
   evaluate to the float [x]. *)
module MA (S : GRAPH_API
           with type 'p t = (Ego.Id.t L.shape, A.t, A.data, 'p) egraph
            and type 'a shape := 'a L.shape
            and type analysis := A.t
            and type data := A.data
            and type node := L.t) = struct
  type 'p t = (Ego.Id.t L.shape, A.t, A.data, 'p) egraph

  (* Fold one layer given the (optional) constant values of its children. *)
  let eval : A.data L.shape -> A.data =
    function
    | L.Add (Some l, Some r) -> Some (l +. r)
    | L.Sub (Some l, Some r) -> Some (l -. r)
    | L.Mul (Some l, Some r) -> Some (l *. r)
    | L.Div (Some l, Some r) ->
      (* Fold only when the divisor is not within 0.01 of zero.  This test
         used to be inverted, so division was folded exclusively when the
         divisor was (nearly) zero. *)
      if Containers.Float.equal_precision ~epsilon:0.01 r 0. then None
      else Some (l /. r)
    | L.Constant n -> Some n
    | _ -> None

  let make : ro t -> Ego.Id.t L.shape -> A.data =
    fun graph term ->
    eval (L.map_children term (S.get_data graph))

  (* Two different known constants in one class are treated as an error. The
     flags report which side's data changed: (left changed, right changed). *)
  let merge : A.t -> A.data -> A.data -> A.data * (bool * bool)=
    fun () l r -> match l,r with
      | Some l, Some r ->
        if Float.equal l r
        then Some l, (false,false)
        else failwith @@ Printf.sprintf "merge failed: float values %f <> %f " l r
      | Some l, _ -> Some l, (false, true)
      | _, Some r -> Some r, (true, false)
      | _ -> None, (false,false)

  (* Once a class is known to be a constant, add that constant as a node and
     merge it into the class. *)
  let modify : 'a t -> Ego.Id.t -> unit =
    fun graph cls ->
    match S.get_data (S.freeze graph) cls with
    | None -> ()
    | Some n ->
      let nw_cls = S.add_node graph (L.Mk (Constant n)) in
      S.merge graph nw_cls cls

end

module EGraph = Make (L) (A) (MA)
module Extractor = MakeExtractor (L) (C)
195 |
196 | (* Rule condition: the classes bound to pattern variables [v] and [w] are
197 |    different, and the class of [v] either has known constant data or has a
198 |    [Symbol] node among those listed by [EGraph.iter_children]. *)
199 | let is_const_or_distinct_var v w =
200 |   fun graph _root_id env ->
201 |     let v_id = StringMap.find v env in
202 |     let w_id = StringMap.find w env in
203 |     let frozen = EGraph.freeze graph in
204 |     if EGraph.class_equal frozen v_id w_id then false
205 |     else
206 |       Option.is_some (EGraph.get_data graph v_id)
207 |       || Iter.exists (function L.Symbol _ -> true | _ -> false)
208 |            (EGraph.iter_children frozen v_id)
203 |
204 | (* Rule condition: the class bound to pattern variable [v] has known
205 |    constant data. *)
206 | let is_const v =
207 |   fun graph _root_id env ->
208 |     match EGraph.get_data graph (StringMap.find v env) with
209 |     | Some _ -> true
210 |     | None -> false
208 |
209 | (* Rule condition: the class bound to pattern variable [v] has a [Symbol]
210 |    node among those listed by [EGraph.iter_children]. *)
211 | let is_sym v =
212 |   fun graph _root_id env ->
213 |     let v_id = StringMap.find v env in
214 |     let is_symbol = function L.Symbol _ -> true | _ -> false in
215 |     Iter.exists is_symbol (EGraph.iter_children (EGraph.freeze graph) v_id)
213 |
214 | (* Rule condition: the class bound to pattern variable [v] is not known to
215 |    be [0.]. Classes without constant data are treated as non-zero. *)
216 | let is_not_zero v =
217 |   fun graph _root_id env ->
218 |     match EGraph.get_data graph (StringMap.find v env) with
219 |     | Some 0.0 -> false
220 |     | Some _ | None -> true
218 |
219 | (* Parse a pattern s-expression into a query over the math operators. *)
220 | let qf = Query.of_sexp L.op_of_string
221 |
222 | (* [from @-> into]: unconditional rewrite from pattern [from] to [into]. *)
223 | let (@->) from into =
224 |   EGraph.Rule.make_constant ~from:(qf from) ~into:(qf into)
225 |
226 | (* [rewrite from into ~if_]: rewrite guarded by [if_], a predicate over
227 |    (graph, root id, variable environment) such as [is_not_zero "x"]. *)
228 | let rewrite from into ~if_ =
229 |   EGraph.Rule.make_conditional ~from:(qf from) ~into:(qf into) ~cond:if_
222 |
223 | let rules = (* full math rewrite system: ring laws, canonical - and /, identities, d (derivative) and i (integral) rules *)
224 | let[@warning "-26"] (&&) f1 f2 = fun graph root_id env -> (f1 graph root_id env) && (f2 graph root_id env) in (* conjunction of two rule conditions; the body's && is still Stdlib's since this let is not recursive *)
225 | [
226 |
227 | [%s ("?a" + "?b")] @-> [%s ("?b" + "?a")]; (* comm-add *)
228 | [%s ("?a" * "?b")] @-> [%s ("?b" * "?a")]; (* comm-mul *)
229 | [%s ("?a" + ("?b" + "?c"))] @-> [%s (("?a" + "?b") + "?c")]; (* assoc add *)
230 | [%s ("?a" * ("?b" * "?c"))] @-> [%s (("?a" * "?b") * "?c")]; (* assoc mul *)
231 |
232 | [%s (("?a" - "?c") + "?b")] @-> [%s ("?a" + ("?b" - "?c"))]; (* move subtraction rightwards *)
233 |
234 | [%s ("?a" - "?b")] @-> [%s ("?a" + ("-1." * "?b"))]; (* sub canon *)
235 | rewrite [%s ("?a" / "?b")] [%s ("?a" * (pow "?b" "-1.0"))] ~if_:(is_not_zero "b"); (* div canon *)
236 |
237 | [%s ("?a" + "0.")] @-> [%s "?a"]; (* zero-add *)
238 | [%s ("?a" * "0.")] @-> [%s "0."]; (* zero-mul *)
239 | [%s ("?a" * "1.")] @-> [%s "?a"]; (* one-mul *)
240 |
241 | [%s "?a"] @-> [%s ("?a" + "0.")]; (* add-zero *)
242 | [%s "?a"] @-> [%s ("?a" * "1.")]; (* mul-one *)
243 |
244 | [%s ("?a" - "?a")] @-> [%s "0."]; (* cancel sub *)
245 | rewrite [%s ("?a" / "?a")] [%s "1."] ~if_:(is_not_zero "a"); (* cancel div *)
246 |
247 | [%s("?a" * ("?b" + "?c"))] @-> [%s (("?a" * "?b") + ("?a" * "?c"))]; (* distribute *)
248 | [%s (("?a" * "?b") + ("?a" * "?c"))] @-> [%s ("?a" * ("?b" + "?c"))]; (* factor *)
249 |
250 | [%s ((pow "?a" "?b") * (pow "?a" "?c"))] @-> [%s (pow "?a" ("?b" + "?c"))]; (* pow-mul *)
251 | rewrite [%s (pow "?x" "0.")] [%s "1."] ~if_:(is_not_zero "x"); (* po0 *)
252 |
253 | [%s (pow "?x" "1.")] @-> [%s "?x"]; (* pow1 *)
254 |
255 | [%s (pow "?x" "2.")] @-> [%s ("?x" * "?x")]; (* po2 *)
256 |
257 | rewrite [%s (pow "?x" "-1.")] [%s("1." / "?x")] ~if_:(is_not_zero "x"); (* pow-recip *)
258 |
259 | rewrite [%s ("?x" * ("1." / "?x"))] [%s "1."] ~if_:(is_not_zero "x"); (* recip mul div *)
260 |
261 | rewrite [%s (d "?x" "?x")] [%s "1."] ~if_:(is_sym "x"); (* d variable *)
262 |
263 | rewrite [%s (d "?x" "?c")] [%s"0."] ~if_:(is_sym "x" && is_const_or_distinct_var "c" "x"); (* this && is the local condition combinator *)
264 | (* d constant *)
265 |
266 | [%s (d "?x" ("?a" + "?b"))] @-> [%s ((d "?x" "?a") + (d "?x" "?b"))]; (* d-add *)
267 | [%s (d "?x" ("?a" * "?b"))] @-> [%s (("?a" * (d "?x" "?b")) + ("?b" * (d "?x" "?a")))]; (* d-mul *)
268 |
269 | [%s (d "?x" (sin "?x"))] @-> [%s (cos "?x")]; (* d-sin *)
270 |
271 | [%s (d "?x" (cos "?x"))] @-> [%s ("-1." * (sin "?x"))]; (* d-cos *)
272 |
273 | rewrite [%s (d "?x" (ln "?x"))] [%s (1 / "?x")] ~if_:(is_not_zero "x"); (* d-ln *) (* NOTE(review): integer literal 1 here vs "1." elsewhere; presumably parsed to the same constant — confirm *)
274 |
275 | rewrite [%s (d "?x" (pow "?f" "?g"))] (* d-power *)
276 | [%s ((pow "?f" "?g") * (((d "?x" "?f") * ("?g" / "?f")) + ((d "?x" "?g") * (ln "?f"))))]
277 | ~if_:(is_not_zero "f" && is_not_zero "g");
278 | [%s (i "1." "?x")] @-> [%s "?x"]; (* i-one *)
279 | rewrite (* i-power-const; NOTE(review): ?c = -1 is not excluded by is_const *)
280 | [%s (i (pow "?x" "?c") "?x")]
281 | [%s ((pow "?x" ("?c" + "1.")) / ("?c" + "1."))]
282 | ~if_:(is_const "c");
283 | [%s (i (cos "?x") "?x")] @-> [%s (sin "?x")]; (* i-cos *)
284 | [%s (i (sin "?x") "?x")] @-> [%s ("-1." * (cos "?x"))]; (* i-sin *)
285 | [%s (i ("?f" + "?g") "?x")] @-> [%s ((i "?f" "?x") + (i "?g" "?x"))]; (* i-sum *)
286 | [%s (i ("?f" - "?g") "?x")] @-> [%s ((i "?f" "?x") - (i "?g" "?x"))]; (* i-difference *)
287 | [%s (i ("?a" * "?b") "?x")] @-> [%s (("?a" * (i "?b" "?x")) - (i ((d "?x" "?a") * (i "?b" "?x")) "?x"))]; (* i-parts *)
288 | ]
289 |
290 | (* Build a graph holding [s1], run [rules] on it and pass the graph and the
291 |    id of [s1] to [f]. Reaching saturation is asserted only when neither
292 |    [fuel] nor [node_limit] was supplied. *)
293 | let run_and_check1 ?node_limit ?fuel rules s1 f () =
294 |   let graph = EGraph.init () in
295 |   let term_1 = EGraph.add_node graph (L.of_sexp s1) in
296 |   let saturated = EGraph.run_until_saturation ?node_limit ?fuel graph rules in
297 |   let unlimited = Option.is_none fuel && Option.is_none node_limit in
298 |   if unlimited then
299 |     Alcotest.(check bool) "reaches equality saturation" true saturated;
300 |   f graph term_1
300 |
301 | (* Build a graph holding [s1] and then [s2], run [rules] on it and pass the
302 |    graph and both ids to [f]. Reaching saturation is asserted only when
303 |    neither [fuel] nor [node_limit] was supplied. *)
304 | let run_and_check2 ?node_limit ?fuel rules s1 s2 f () =
305 |   let graph = EGraph.init () in
306 |   let term_1 = EGraph.add_node graph (L.of_sexp s1) in
307 |   let term_2 = EGraph.add_node graph (L.of_sexp s2) in
308 |   let saturated = EGraph.run_until_saturation ?node_limit ?fuel graph rules in
309 |   if Option.is_none fuel && Option.is_none node_limit then
310 |     Alcotest.(check bool) "reaches equality saturation" true saturated;
311 |   f graph term_1 term_2
312 |
313 | (* Test body asserting that [rules] put [s1] and [s2] in the same class.
314 |    The run gets [~until] so it may stop once the two terms are equal;
315 |    whether full saturation was reached is deliberately not checked. *)
316 | let check_proves_equal ?node_limit ?fuel rules s1 s2 () =
317 |   let graph = EGraph.init () in
318 |   let add sexp = EGraph.add_node graph (L.of_sexp sexp) in
319 |   let term_1 = add s1 in
320 |   let term_2 = add s2 in
321 |   let terms_are_equal g = EGraph.class_equal (EGraph.freeze g) term_1 term_2 in
322 |   ignore (EGraph.run_until_saturation ~until:terms_are_equal ?node_limit ?fuel graph rules);
323 |   Alcotest.(check bool) "proves terms are equal modulo rewriting"
324 |     true
325 |     (terms_are_equal graph)
327 |
328 | (* Test body asserting that [rules] do NOT put [s1] and [s2] in the same class. *)
329 | let check_cannot_prove_equal ?node_limit ?fuel rules s1 s2 =
330 |   let assert_distinct graph term_1 term_2 =
331 |     Alcotest.(check bool) "must not prove terms are equal modulo rewriting"
332 |       false
333 |       (EGraph.class_equal (EGraph.freeze graph) term_1 term_2)
334 |   in
335 |   run_and_check2 ?node_limit ?fuel rules s1 s2 assert_distinct
334 |
335 | (* Test body asserting that extraction from the class of [s1], after
336 |    running [rules], yields exactly [s2]. *)
337 | let check_extract ?node_limit ?fuel rules s1 s2 =
338 |   run_and_check1 ?node_limit ?fuel rules s1 @@ fun graph root ->
339 |   let extracted = L.to_sexp (Extractor.extract graph root) in
340 |   Alcotest.(check sexp) "extracted expression matches" s2 extracted
342 |
343 | (* Alcotest entry point for the math-language tests: addition-only proofs,
344 |    proofs with the full [rules] set, a negative check and derivatives.
345 |    In the derivative cases, (d x e) is the derivative of e with respect to x. *)
346 | let () =
347 |   Alcotest.run "math"
348 |     [("proving with addition",
349 |       let rules = [
350 |         (* add comm *) [%s ("?a" + "?b")] @-> [%s ("?b" + "?a")];
351 |         (* add assoc *) [%s ("?a" + ("?b" + "?c"))] @-> [%s (("?a" + "?b") + "?c")];
352 |       ] in [
353 |         "constants are simplified", `Quick, check_proves_equal rules
354 |           [%s (1 + (2 + (3 + (4 + (5 + (6 + 7))))))]
355 |           [%s (7 + (6 + (5 + (4 + (3 + (2 + 1))))))];
356 |         "constants are evaluated", `Quick, check_proves_equal rules
357 |           [%s (1 + (2 + (3 + (4 + (5 + (6 + 7))))))]
358 |           [%s 28];
359 |         "symbols can be rearranged", `Quick, check_proves_equal rules
360 |           [%s (1 + (x + (2 + (3 + (4 + (5 + (6 + 7)))))))]
361 |           [%s (x + 28)];
362 |       ]);
363 |      "proving arithmetic with full rule set", [
364 |        "subtraction works with symbols", `Quick,
365 |        check_proves_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
366 |          [%s (x - x)] [%s 0.];
367 |        "subtraction works with non obvious equalities", `Quick,
368 |        check_proves_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
369 |          [%s (x - (x + 0))] [%s 0.];
370 |        "subtraction works with complex expressions", `Quick,
371 |        check_proves_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
372 |          [%s ((sqrt 5.) - (sqrt 5.))] [%s 0.];
373 |        "subtraction works with complex expressions and addition", `Quick,
374 |        check_proves_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
375 |          [%s ((1 + (sqrt 5.)) - ((sqrt 5.) + 1))] [%s 0.];
376 |        "multiplication is rewritten", `Quick,
377 |        check_proves_equal ~node_limit:(`Bounded 75_000) ~fuel:(`Bounded 30) rules
378 |          [%s (1 - x)] [%s (1 + ("-1." * x))];
379 |        "multiplication is propagated", `Quick,
380 |        check_proves_equal ~node_limit:(`Bounded 75_000) ~fuel:(`Bounded 30) rules
381 |          [%s ((1 - x) + x)] [%s (1 + (x + ("-1." * x)))];
382 |        "subtraction can be reverted", `Quick,
383 |        check_proves_equal ~node_limit:(`Bounded 75_000) ~fuel:(`Bounded 30) rules
384 |          [%s (1 + ("-1." * x))] [%s (1 - x)];
385 |        "subtraction can be cancelled", `Quick,
386 |        check_proves_equal ~node_limit:(`Bounded 75_000) ~fuel:(`Bounded 30) rules
387 |          [%s (x + ("-1." * x))] [%s 0];
388 |        "subtraction can be propagated and cancelled", `Quick,
389 |        check_proves_equal ~node_limit:(`Bounded 100_000_000) ~fuel:(`Bounded 30) rules
390 |          [%s (1 + (x - x))] [%s ((1 - x) + x) ];
391 |        (* TODO: this case is currently identical to the previous one; give it
392 |           a genuinely more complex input. *)
393 |        "complex subtraction can be propagated and cancelled", `Quick,
394 |        check_proves_equal ~node_limit:(`Bounded 100_000_000) ~fuel:(`Bounded 30) rules
395 |          [%s (1 + (x - x))] [%s ((1 - x) + x) ];
396 |        "plus minus one can be propagated and cancelled", `Quick,
397 |        check_proves_equal ~node_limit:(`Bounded 100_000_000) ~fuel:(`Bounded 30) rules
398 |          [%s ((1 - x) + (1 + x)) ] [%s 2];
399 |        "division can be simplified", `Quick,
400 |        check_proves_equal ~node_limit:(`Bounded 100_000_000) ~fuel:(`Bounded 30) rules
401 |          [%s (2 / 2) ] [%s 1];
402 |        "division with numerator 0 can be simplified", `Quick,
403 |        check_proves_equal ~node_limit:(`Bounded 100_000_000) ~fuel:(`Bounded 30) rules
404 |          [%s ((x - x) / 2) ] [%s 0];
405 |        "multiplication with 0 is simplified", `Quick,
406 |        check_proves_equal ~node_limit:(`Bounded 100_000_000) ~fuel:(`Bounded 30) rules
407 |          [%s (x * 0.) ] [%s 0];
408 |      ];
409 |      "does not prove invalid equalities", [
410 |        "division and addition are not equal", `Quick,
411 |        check_cannot_prove_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules [%s (x + y)] [%s (x / y)]
412 |      ];
413 |      "reasoning about derivatives", [
414 |        "d/dx of x is 1", `Quick,
415 |        check_proves_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
416 |          [%s (d x x)] [%s 1 ];
417 |        "d/dx of y is 0", `Quick,
418 |        check_proves_equal ~node_limit:(`Bounded 10_000) ~fuel:(`Bounded 30) rules
419 |          [%s (d x y)] [%s 0 ];
420 |        "d/dx of 1 + 2x is 2", `Quick,
421 |        check_proves_equal ~node_limit:(`Bounded 100_000) ~fuel:(`Bounded 35) rules
422 |          [%s (d x (1 + (2. * x)))] [%s 2. ];
423 |        "d/dx of xy + 1 is y", `Quick,
424 |        check_extract ~node_limit:(`Unbounded) ~fuel:(`Bounded 15) rules
425 |          [%s (d x (1. + (y * x)))] [%s y ];
426 |        "d/dx of ln x is 1 / x", `Quick,
427 |        check_proves_equal ~node_limit:(`Bounded 100_000) ~fuel:(`Bounded 35) rules
428 |          [%s (d x (ln x))] [%s 1 / x ];
429 |      ];
430 |     ]
426 |
--------------------------------------------------------------------------------
/test/test_prop.ml:
--------------------------------------------------------------------------------
1 | open Ego.Generic
2 |
3 | (* Alcotest testable for s-expressions: values are compared with
4 |    [Sexplib.Sexp.equal] and printed with [Sexplib.Sexp.pp_hum]. *)
5 | module Sexp_testable = struct
6 |   type t = Sexplib.Sexp.t
7 |   let pp = Sexplib.Sexp.pp_hum
8 |   let equal = Sexplib.Sexp.equal
9 | end
10 |
11 | let sexp = (module Sexp_testable : Alcotest.TESTABLE with type t = Sexplib.Sexp.t)
9 |
10 |
11 | (* Propositional-logic language: conjunction, disjunction, negation,
12 |    implication, boolean literals and named variables. *)
13 | module L = struct
14 |
15 |   type 'a shape =
16 |     | And of 'a * 'a
17 |     | Or of 'a * 'a
18 |     | Not of 'a
19 |     | Impl of 'a * 'a
20 |     | Bool of bool
21 |     | Symbol of string [@@deriving ord, show]
22 |
23 |   (* Concrete terms: [shape] tied back onto itself. *)
24 |   type t = Mk of t shape [@@unboxed]
25 |
26 |   (* Operator of a node, i.e. the node with its children erased. *)
27 |   type op =
28 |     | AndOp
29 |     | OrOp
30 |     | NotOp
31 |     | ImplOp
32 |     | BoolOp of bool
33 |     | SymbolOp of string [@@deriving eq, ord]
34 |
35 |   (* Parse a term. The atoms "true"/"false" are literals and any other atom
36 |      is a variable; lists use the prefix operators "not", "&&", "||" and
37 |      "=>". Any other shape raises [Match_failure]. *)
38 |   let rec of_sexp = function[@warning "-8"]
39 |     | Sexplib0.Sexp.Atom "true" -> Mk (Bool true)
40 |     | Sexplib0.Sexp.Atom "false" -> Mk (Bool false)
41 |     | Sexplib0.Sexp.Atom name -> Mk (Symbol name)
42 |     | List [Atom "not"; arg] -> Mk (Not (of_sexp arg))
43 |     | List [Atom "&&"; lhs; rhs] -> Mk (And (of_sexp lhs, of_sexp rhs))
44 |     | List [Atom "||"; lhs; rhs] -> Mk (Or (of_sexp lhs, of_sexp rhs))
45 |     | List [Atom "=>"; lhs; rhs] -> Mk (Impl (of_sexp lhs, of_sexp rhs))
46 |
47 |   (* Map an operator name to its operator; unknown names become symbols. *)
48 |   let op_of_string = function
49 |     | "&&" -> AndOp
50 |     | "||" -> OrOp
51 |     | "not" -> NotOp
52 |     | "=>" -> ImplOp
53 |     | ("true" | "false") as literal -> BoolOp (bool_of_string literal)
54 |     | s -> SymbolOp s
55 |
56 |   (* Print a term in the s-expression syntax read by [of_sexp]. *)
57 |   let rec to_sexp : t -> Sexplib0.Sexp.t = function
58 |     | Mk (Bool b) -> Sexplib0.Sexp.Atom (string_of_bool b)
59 |     | Mk (Symbol s) -> Sexplib0.Sexp.Atom s
60 |     | Mk (Not arg) -> List [Atom "not"; to_sexp arg]
61 |     | Mk (And (lhs, rhs)) -> List [Atom "&&"; to_sexp lhs; to_sexp rhs]
62 |     | Mk (Or (lhs, rhs)) -> List [Atom "||"; to_sexp lhs; to_sexp rhs]
63 |     | Mk (Impl (lhs, rhs)) -> List [Atom "=>"; to_sexp lhs; to_sexp rhs]
64 |
65 |   (* Operator of a node. *)
66 |   let op = function
67 |     | Bool b -> BoolOp b
68 |     | Symbol s -> SymbolOp s
69 |     | Not _ -> NotOp
70 |     | And _ -> AndOp
71 |     | Or _ -> OrOp
72 |     | Impl _ -> ImplOp
73 |
74 |   (* Direct children of a node, left to right. *)
75 |   let children = function
76 |     | And (l, r) | Or (l, r) | Impl (l, r) -> [l; r]
77 |     | Not l -> [l]
78 |     | Bool _ | Symbol _ -> []
79 |
80 |   (* Rebuild a node with [f] applied to each child. *)
81 |   let map_children term f =
82 |     match term with
83 |     | And (l, r) -> And (f l, f r)
84 |     | Or (l, r) -> Or (f l, f r)
85 |     | Impl (l, r) -> Impl (f l, f r)
86 |     | Not l -> Not (f l)
87 |     | Bool b -> Bool b
88 |     | Symbol s -> Symbol s
89 |
90 |   (* Build a node from an operator and its children; a child list of the
91 |      wrong length raises [Match_failure]. *)
92 |   let make op children =
93 |     match[@warning "-8"] op, children with
94 |     | BoolOp b, [] -> Bool b
95 |     | SymbolOp s, [] -> Symbol s
96 |     | NotOp, [l] -> Not l
97 |     | AndOp, [l; r] -> And (l, r)
98 |     | OrOp, [l; r] -> Or (l, r)
99 |     | ImplOp, [l; r] -> Impl (l, r)
100 |
101 | end
91 |
92 | (* Extraction cost model: each connective adds a fixed weight to the costs
93 |    of its children (And 3, Or 2, Impl 1, Not 3); literals and variables
94 |    cost 1. *)
95 | module C = struct
96 |   type t = float [@@deriving ord]
97 |
98 |   let cost f : Ego.Id.t L.shape -> t =
99 |     let binary weight l r = f l +. f r +. weight in
100 |     function
101 |     | L.And (l, r) -> binary 3.0 l r
102 |     | L.Or (l, r) -> binary 2.0 l r
103 |     | L.Impl (l, r) -> binary 1.0 l r
104 |     | L.Not l -> f l +. 3.0
105 |     | L.Symbol _ | L.Bool _ -> 1.0
106 | end
102 |
103 | (* Analysis parameters: no global state ([unit]); the per-class data is the
104 |    boolean the class is known to evaluate to, or [None] when unknown. *)
105 | module A = struct
106 |   type t = unit
107 |   type data = bool option [@@deriving eq, show]
108 |   let default = None
109 | end
104 | (* Constant-folding analysis for propositions: a class whose truth value is
105 |    known gets the corresponding [Bool] literal merged into it. *)
106 | module MA (S : GRAPH_API
107 |            with type 'p t = (Ego.Id.t L.shape, A.t, A.data, 'p) egraph
108 |             and type 'a shape := 'a L.shape
109 |             and type analysis := A.t
110 |             and type data := A.data
111 |             and type node := L.t) = struct
112 |   type 'p t = (Ego.Id.t L.shape, A.t, A.data, 'p) egraph
113 |
114 |   (* Evaluate one node whose children were replaced by their known values;
115 |      a connective is folded only when all of its operands are known. *)
116 |   let eval : A.data L.shape -> A.data =
117 |     let lift2 op l r =
118 |       match l, r with
119 |       | Some a, Some b -> Some (op a b)
120 |       | _ -> None
121 |     in
122 |     function
123 |     | L.Bool c -> Some c
124 |     | L.Not b -> Option.map not b
125 |     | L.And (l, r) -> lift2 ( && ) l r
126 |     | L.Or (l, r) -> lift2 ( || ) l r
127 |     | L.Impl (l, r) -> lift2 (fun a b -> (not a) || b) l r
128 |     | L.Symbol _ -> None
129 |
130 |   (* Data of a freshly added node, computed from its children's data. *)
131 |   let make : ro t -> Ego.Id.t L.shape -> A.data =
132 |     fun graph term -> L.map_children term (S.get_data graph) |> eval
133 |
134 |   (* Combine the data of two merged classes. Two different known values
135 |      would mean unequal propositions were merged, which trips the
136 |      assertion. The flags appear to report which side's data changed. *)
137 |   let merge : A.t -> A.data -> A.data -> A.data * (bool * bool) =
138 |     fun () l r ->
139 |       match l, r with
140 |       | None, None -> None, (false, false)
141 |       | Some l, None -> Some l, (false, true)
142 |       | None, Some r -> Some r, (true, false)
143 |       | Some l, Some r -> assert (l = r); Some l, (false, false)
144 |
145 |   (* When the class [cls] has a known value, add the matching literal and
146 |      merge it into [cls]. *)
147 |   let modify : 'a t -> Ego.Id.t -> unit =
148 |     fun graph cls ->
149 |       S.get_data (S.freeze graph) cls
150 |       |> Option.iter (fun value ->
151 |           let literal = S.add_node graph (L.Mk (Bool value)) in
152 |           S.merge graph literal cls)
153 |
154 | end
155 |
156 | (* E-graph over propositions, using the constant-folding analysis above. *)
157 | module EGraph = Make (L) (A) (MA)
158 | (* Extractor driven by the cost model [C]. *)
159 | module Extractor = MakeExtractor (L) (C)
143 |
144 |
145 | (* Build a graph holding [s1], run [rules] on it and pass the graph and the
146 |    id of [s1] to [f]. Reaching saturation is asserted only when neither
147 |    [fuel] nor [node_limit] was supplied. *)
148 | let run_and_check1 ?node_limit ?fuel rules s1 f () =
149 |   let graph = EGraph.init () in
150 |   let term_1 = EGraph.add_node graph (L.of_sexp s1) in
151 |   let saturated = EGraph.run_until_saturation ?node_limit ?fuel graph rules in
152 |   let unlimited = Option.is_none fuel && Option.is_none node_limit in
153 |   if unlimited then
154 |     Alcotest.(check bool) "reaches equality saturation" true saturated;
155 |   f graph term_1
155 |
156 | (* Build a graph holding [s1] and then [s2], run [rules] on it and pass the
157 |    graph and both ids to [f]. Reaching saturation is asserted only when
158 |    neither [fuel] nor [node_limit] was supplied. *)
159 | let run_and_check2 ?node_limit ?fuel rules s1 s2 f () =
160 |   let graph = EGraph.init () in
161 |   let term_1 = EGraph.add_node graph (L.of_sexp s1) in
162 |   let term_2 = EGraph.add_node graph (L.of_sexp s2) in
163 |   let saturated = EGraph.run_until_saturation ?node_limit ?fuel graph rules in
164 |   if Option.is_none fuel && Option.is_none node_limit then
165 |     Alcotest.(check bool) "reaches equality saturation" true saturated;
166 |   f graph term_1 term_2
167 |
168 | (* Test body asserting that [rules] put [s1] and [s2] in the same class.
169 |    The run gets [~until] so it may stop once the two terms are equal;
170 |    whether full saturation was reached is deliberately not checked. *)
171 | let check_proves_equal ?node_limit ?fuel rules s1 s2 () =
172 |   let graph = EGraph.init () in
173 |   let add sexp = EGraph.add_node graph (L.of_sexp sexp) in
174 |   let term_1 = add s1 in
175 |   let term_2 = add s2 in
176 |   let terms_are_equal g = EGraph.class_equal (EGraph.freeze g) term_1 term_2 in
177 |   ignore (EGraph.run_until_saturation ~until:terms_are_equal ?node_limit ?fuel graph rules);
178 |   Alcotest.(check bool) "proves terms are equal modulo rewriting"
179 |     true
180 |     (terms_are_equal graph)
182 |
183 | (* Test body asserting that [rules] do NOT put [s1] and [s2] in the same class. *)
184 | let check_cannot_prove_equal ?node_limit ?fuel rules s1 s2 =
185 |   let assert_distinct graph term_1 term_2 =
186 |     Alcotest.(check bool) "must not prove terms are equal modulo rewriting"
187 |       false
188 |       (EGraph.class_equal (EGraph.freeze graph) term_1 term_2)
189 |   in
190 |   run_and_check2 ?node_limit ?fuel rules s1 s2 assert_distinct
189 |
190 | (* Test body asserting that extraction from the class of [s1], after
191 |    running [rules], yields exactly [s2]. *)
192 | let check_extract ?node_limit ?fuel rules s1 s2 =
193 |   run_and_check1 ?node_limit ?fuel rules s1 @@ fun graph root ->
194 |   let extracted = L.to_sexp (Extractor.extract graph root) in
195 |   Alcotest.(check sexp) "extracted expression matches" s2 extracted
197 |
198 |
199 |
200 | (* Parse a pattern s-expression into a query over the propositional operators. *)
201 | let qf = Query.of_sexp L.op_of_string
202 |
203 | (* [from @-> into]: unconditional rewrite from pattern [from] to [into]. *)
204 | let (@->) from into =
205 |   EGraph.Rule.make_constant ~from:(qf from) ~into:(qf into)
206 |
207 | (* [rewrite from into ~if_]: rewrite guarded by the condition [if_]. *)
208 | let rewrite from into ~if_ =
209 |   EGraph.Rule.make_conditional ~from:(qf from) ~into:(qf into) ~cond:if_
203 |
204 | (* Rewrite rules for propositional logic. Every rule must be sound: the
205 |    analysis asserts that merged classes never carry different known values. *)
206 | let rules = [
207 |   (* def_imply *) [%s ("?a" => "?b")] @-> [%s ((not "?a") || "?b")];
208 |   (* double_neg *) [%s (not (not "?a"))] @-> [%s "?a"];
209 |   (* assoc_or *) [%s ( "?a" || ("?b" || "?c"))] @-> [%s (("?a" || "?b") || "?c")];
210 |   (* dist_and_or *) [%s ("?a" && ("?b" || "?c"))] @-> [%s (("?a" && "?b") || ("?a" && "?c"))];
211 |   (* dist_or_and: a || (b && c) = (a || b) && (a || c) *)
212 |   [%s ("?a" || ("?b" && "?c"))] @-> [%s (("?a" || "?b") && ("?a" || "?c"))];
213 |   (* comm_or *) [%s ("?a" || "?b")] @-> [%s ("?b" || "?a")];
214 |   (* comm_and *) [%s ("?a" && "?b")] @-> [%s ("?b" && "?a")];
215 |   (* lem *) [%s ("?a" || (not "?a"))] @-> [%s"true"];
216 |   (* or_true *) [%s ("?a" || "true")] @-> [%s "true"];
217 |   (* and_true *) [%s ("?a" && "true")] @-> [%s"?a"];
218 |   (* contrapositive *) [%s ("?a" => "?b")] @-> [%s ((not "?b") => (not "?a"))];
219 |   (* lem_imply *) [%s (("?a" => "?b") && ((not "?a") => "?c")) ] @-> [%s ("?b" || "?c") ];
220 | ]
218 |
219 | (* Test body: saturate a graph seeded with [start] using the backoff
220 |    scheduler, then add each goal in turn and check that it landed in the
221 |    same class as [start]. *)
222 | let proves ?(match_limit=1_000) ?(ban_length=5) ?node_limit ?fuel start goals () =
223 |   let graph = EGraph.init () in
224 |   let start_id = EGraph.add_node graph (L.of_sexp start) in
225 |   let scheduler =
226 |     Ego.Generic.Scheduler.Backoff.with_params ~match_limit ~ban_length in
227 |   let (_ : bool) = EGraph.run_until_saturation ~scheduler ?fuel ?node_limit graph rules in
228 |   let check_goal goal =
229 |     let goal_id = EGraph.add_node graph (L.of_sexp goal) in
230 |     Alcotest.(check bool)
231 |       "goal can be proved from start"
232 |       true
233 |       (EGraph.class_equal (EGraph.freeze graph) start_id goal_id)
234 |   in
235 |   List.iter check_goal goals
231 |
232 | (* Test body: add [start] and every goal up front, run the backoff
233 |    scheduler until [start] meets the last goal ([start] itself when [goals]
234 |    is empty), then check every goal against [start]. *)
235 | let proves_cached ?(match_limit=1_000) ?(ban_length=5) ?node_limit ?fuel start goals () =
236 |   let graph = EGraph.init () in
237 |   let start_id = EGraph.add_node graph (L.of_sexp start) in
238 |   let goal_ids = List.map (fun goal -> EGraph.add_node graph (L.of_sexp goal)) goals in
239 |   let final_goal = List.fold_left (fun _ id -> id) start_id goal_ids in
240 |   let scheduler =
241 |     Ego.Generic.Scheduler.Backoff.with_params ~match_limit ~ban_length in
242 |   let reached_final_goal g =
243 |     EGraph.class_equal (EGraph.freeze g) start_id final_goal in
244 |   let (_ : bool) =
245 |     EGraph.run_until_saturation ~scheduler ?fuel ?node_limit ~until:reached_final_goal graph rules in
246 |   List.iter (fun goal_id ->
247 |       Alcotest.(check bool)
248 |         "goal can be proved from start"
249 |         true
250 |         (EGraph.class_equal (EGraph.freeze graph) start_id goal_id))
251 |     goal_ids
251 |
252 | (* Alcotest entry point for the propositional-logic tests: e-matching
253 |    sanity checks, then proofs from a single implication and from a chain. *)
254 | let () =
255 |   Alcotest.run "prop" [
256 |     "ematch tests", [
257 |       "check matches after merging", `Quick,
258 |       (fun () -> let graph = EGraph.init () in
259 |         let n1 = EGraph.add_node graph (L.of_sexp [%s (x && z)]) in
260 |         let n2 = EGraph.add_node graph (L.of_sexp [%s (y && z)]) in
261 |         EGraph.merge graph n1 n2;
262 |         EGraph.rebuild graph;
263 |         let query = qf [%s "?a" && z] in
264 |         let matches = EGraph.find_matches (EGraph.freeze graph) query |> Iter.length in
265 |         Alcotest.(check int) "2 matches" 2 matches);
266 |
267 |       "check matches after saturating", `Quick,
268 |       fun () -> let graph = EGraph.init () in
269 |         let scheduler = Ego.Generic.Scheduler.Backoff.with_params ~match_limit:1000 ~ban_length:5 in
270 |         let _ = EGraph.add_node graph (L.of_sexp [%s (x && y)]) in
271 |         let query = [%s "?a" && "?b"] @-> [%s "?b" && "?a"] in
272 |         ignore @@ EGraph.run_until_saturation ~scheduler graph [query];
273 |         let q = qf [%s "?a" && "?b"] in
274 |         let matches = EGraph.find_matches (EGraph.freeze graph) q |> Iter.length in
275 |         Alcotest.(check int) "2 matches" 2 matches
276 |     ];
277 |     "proving contrapositive", [
278 |       "proves idempotent", `Quick, proves [%s (x => y)] [[%s (x => y)]];
279 |       "proves negation", `Quick, proves [%s (x => y)] [[%s (x => y)];
280 |                                                        [%s ((not x) || y)]];
281 |       "proves double negation", `Quick, proves [%s (x => y)] [[%s (x => y)];
282 |                                                               [%s ((not x) || y)];
283 |                                                               [%s ((not x) || (not (not y)))]];
284 |       "proves commutativity", `Quick, proves [%s (x => y)] [[%s (x => y)];
285 |                                                             [%s ((not x) || y)];
286 |                                                             [%s ((not x) || (not (not y)))];
287 |                                                             [%s ((not (not y)) || (not x))];
288 |                                                            ];
289 |       "proves contrapositive", `Quick, proves [%s (x => y)] [[%s (x => y)];
290 |                                                              [%s ((not x) || y)];
291 |                                                              [%s ((not x) || (not (not y)))];
292 |                                                              [%s ((not (not y)) || (not x))];
293 |                                                              [%s ((not y) => (not x))];
294 |                                                             ]];
295 |     "proving chain", [
296 |       "proves idempotent", `Quick, proves [%s ((x => y) && (y => z))] [[%s ((x => y) && (y => z))]];
297 |       "proves contrapositive", `Quick, proves [%s ((x => y) && (y => z))]
298 |         [[%s ((x => y) && (y => z))];
299 |          [%s (((not y) => (not x)) && (y => z))]];
300 |       "proves commutativity", `Quick, proves [%s ((x => y) && (y => z))]
301 |         [[%s ((x => y) && (y => z))];
302 |          [%s (((not y) => (not x)) && (y => z))];
303 |          [%s ((y => z) && ((not y) => (not x)))]];
304 |       "proves negation", `Quick, proves [%s ((x => y) && (y => z))]
305 |         [[%s ((x => y) && (y => z))];
306 |          [%s (((not y) => (not x)) && (y => z))];
307 |          [%s ((y => z) && ((not y) => (not x)))];
308 |          [%s (z || (not x))]
309 |         ];
310 |       (* Renamed from a second "proves commutativity": this case additionally
311 |          commutes the derived disjunction (z || not x). *)
312 |       "proves commutativity of the derived disjunction", `Quick, proves
313 |         ~node_limit:(`Bounded 10_000)
314 |         ~fuel:(`Bounded 60)
315 |         [%s ((x => y) && (y => z))]
316 |         [[%s ((x => y) && (y => z))];
317 |          [%s (((not y) => (not x)) && (y => z))];
318 |          [%s ((y => z) && ((not y) => (not x)))];
319 |          [%s (z || (not x))];
320 |          [%s ((not x) || z)]; ];
321 |       "proves chain", `Quick, proves_cached
322 |         ~match_limit:(10_000) ~ban_length:5
323 |         ~node_limit:(`Bounded 600_000)
324 |         ~fuel:(`Bounded 50)
325 |         [%s ((x => y) && (y => z))]
326 |         [[%s ((x => y) && (y => z))];
327 |          [%s (((not y) => (not x)) && (y => z))];
328 |          [%s ((y => z) && ((not y) => (not x)))];
329 |          [%s (z || (not x))];
330 |          [%s ((not x) || z)];
331 |          [%s (x => z)];
332 |         ]
333 |     ]]
330 |
331 |
--------------------------------------------------------------------------------