├── .github └── workflows │ ├── build.yml │ └── docs.yml ├── .gitignore ├── CHANGELOG.md ├── Cargo.toml ├── LICENSE ├── Makefile ├── README.md ├── doc ├── egg.svg └── egraphs.drawio ├── rust-toolchain ├── src ├── dot.rs ├── eclass.rs ├── egraph.rs ├── explain.rs ├── extract.rs ├── language.rs ├── lib.rs ├── lp_extract.rs ├── machine.rs ├── macros.rs ├── multipattern.rs ├── pattern.rs ├── rewrite.rs ├── run.rs ├── subst.rs ├── test.rs ├── tutorials │ ├── _01_background.rs │ ├── _02_getting_started.rs │ ├── _03_explanations.rs │ └── mod.rs ├── unionfind.rs └── util.rs └── tests ├── datalog.rs ├── lambda.rs ├── math.rs ├── prop.rs └── simple.rs /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | on: [push, pull_request] 3 | 4 | jobs: 5 | test: 6 | runs-on: ubuntu-20.04 7 | steps: 8 | - uses: actions/checkout@v2 9 | - name: Cache cargo bin 10 | uses: actions/cache@v2 11 | with: 12 | path: ~/.cargo/bin 13 | key: ${{ runner.os }}-cargo-bin 14 | - name: Install cbc 15 | run: sudo apt-get install coinor-libcbc-dev 16 | - name: Test 17 | run: make test 18 | nits: 19 | runs-on: ubuntu-20.04 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Install cargo deadlinks 23 | run: | 24 | curl -L -o ~/.cargo/bin/cargo-deadlinks https://github.com/deadlinks/cargo-deadlinks/releases/download/0.4.2/deadlinks-linux 25 | chmod +x ~/.cargo/bin/cargo-deadlinks 26 | cargo deadlinks --version 27 | - name: Nits 28 | run: make nits -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Publish Docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | jobs: 7 | docs: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v2 11 | with: 12 | persist-credentials: false 13 | - name: Build Docs 14 | run: | 15 | cargo doc --no-deps --all-features 16 | touch 
target/doc/.nojekyll # prevent jekyll from running 17 | - name: Deploy 🚀 18 | uses: JamesIves/github-pages-deploy-action@4.1.5 19 | with: 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | BRANCH: gh-pages # The branch the action should deploy to. 22 | FOLDER: target/doc # The folder the action should deploy. 23 | CLEAN: true # Automatically remove deleted files from the deploy branch 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | **/*.rs.bk 3 | Cargo.lock 4 | 5 | dots/ 6 | egg/data/ 7 | 8 | # nix stuff 9 | .envrc 10 | default.nix 11 | 12 | # other stuff 13 | TODO.md 14 | flamegraph.svg 15 | perf.data* 16 | *.bench 17 | .vscode/settings.json 18 | out.csv 19 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changes 2 | 3 | ## [Unreleased] - ReleaseDate 4 | - Change the API of `make` to have mutable access to the e-graph for some [advanced use cases](https://github.com/egraphs-good/egg/pull/277). 5 | 6 | 7 | ## [0.9.5] - 2023-06-29 8 | - Fixed a few edge cases in proof size optimization that caused egg to crash. 9 | 10 | ## [0.9.4] - 2023-05-23 11 | - [#253] Improved rebuilding algorithm using a queue. 12 | - [#259] Fixed another overflow bug in proof size optimization. 13 | - Various typo fixes (Thanks @nlewycky) 14 | 15 | ## [0.9.3] - 2023-02-06 16 | 17 | ### Added 18 | - [#215](https://github.com/egraphs-good/egg/pull/215) Added a better intersection algorithm on two egraphs based on "Join Algorithms for the Theory of Uninterpreted Functions". The old intersection algorithm was not complete on the terms in both egraphs, but the new one is. Unfortunately, the new algorithm is quadratic. 
19 | 20 | ### Changed 21 | - [#230](https://github.com/egraphs-good/egg/pull/230) Fixed a performance bug in `get_string_with_let` that caused printing let-bound proofs to be extremely inefficient. 22 | 23 | 24 | ## [0.9.2] - 2022-12-15 25 | 26 | ### Added 27 | 28 | - [#210](https://github.com/egraphs-good/egg/pull/210) Fix crashes in proof generation due to proof size calculations overflowing. 29 | 30 | ## [0.9.1] - 2022-09-22 31 | 32 | ### Added 33 | 34 | - [#186](https://github.com/egraphs-good/egg/pull/186) Added proof minimization (enabled by default), a greedy algorithm to find smaller proofs 35 | - with and `without_explanation_length_optimization` for turning this on and off 36 | - `copy_without_unions` for copying an egraph but without any equality information 37 | - `id_to_expr` for getting an expression corresponding to a particular enode's id 38 | - [#197](https://github.com/egraphs-good/egg/pull/197) Added `search_with_limit`, so that matching also stops when it hits scheduling limits. 39 | 40 | ### Changed 41 | 42 | - Changed the `pre_union` hook to support explanations 43 | - now provides the `Id` of the two specific enodes that are merged, not canonical ids. 44 | - It also provides the reason for the merge in the form of a `Justification`. 45 | 46 | ## [0.9.0] - 2022-06-12 47 | 48 | ### Added 49 | - Added a way to update analysis data and have it propagate through the e-graph 50 | 51 | ### Changed 52 | - Improved documentation 53 | - Updated dependencies 54 | - `union` is now allowed when explanations are on 55 | 56 | ## [0.8.1] - 2022-05-04 57 | 58 | ### Changed 59 | - Improved documentation for features. 60 | 61 | ## [0.8.0] - 2022-04-28 62 | 63 | ### Added 64 | - ([#128](https://github.com/egraphs-good/egg/pull/128)) Add an ILP-based extractor. 65 | - ([#168](https://github.com/egraphs-good/egg/pull/168)) Added MultiPatterns. 
66 | 67 | ### Changed 68 | - ([#165](https://github.com/egraphs-good/egg/pull/165)) Unions now happen "instantly", restoring the pre-0.7 behavior. 69 | - The tested MSRV is now 1.60.0. 70 | - Several small documentation enhancements. 71 | - ([#162](https://github.com/egraphs-good/egg/pull/162), [#163](https://github.com/egraphs-good/egg/pull/163)) 72 | Extracted the `Symbol` logic into the [`symbol_table`](https://crates.io/crates/symbol_table) crate. 73 | 74 | ## [0.7.1] - 2021-12-14 75 | 76 | This patch fixes a pretty bad e-matching bug introduced in 0.7.0. Please upgrade! 77 | 78 | ### Fixed 79 | - (#143) Non-linear patterns e-match correctly again 80 | - (#141) Loosen requirement on FromOp::Error 81 | 82 | ## [0.7.0] - 2021-11-23 83 | 84 | It's been a long time since a release! 85 | There's a lot in this one, hopefully I can cut releases more frequently in the future, 86 | because there are definitely more features coming :) 87 | 88 | ### Added 89 | - The egraph now has an `EGraph::with_explanations_enabled` mode that allows for 90 | explaining why two terms are equivalent in the egraph. 91 | In explanations mode, all unions must be done through `union_instantiations` in order 92 | to justify the union. 93 | Calling `explain_equivalence` returns an `Explanation` 94 | which has both a `FlatExplanation` form and a 95 | `TreeExplanation` form. 96 | See #115 and #119 for more details. 97 | - The `BackoffScheduler` is now more flexible. 98 | - `EGraph::pre_union` allows inspection of unions, which can be useful for debugging. 99 | - The dot printer is now more flexible. 100 | 101 | ### Changed 102 | 103 | - `Analysis::merge` now gets a `&mut self`, so it can store data on the `Analysis` itself. 104 | - `Analysis::merge` has a different signature. 105 | - Pattern compilation and execution is faster, especially when there are ground terms involved. 106 | - All unions are now delayed until rebuilding, so `EGraph::rebuild` must be called to observe effects. 
107 | - The `apply_one` function on appliers *now needs to perform unions*. 108 | - The congruence closure algorithm now keeps the egraph congruent before 109 | doing any analysis (calling `make`). It does this by interleaving rebuilding 110 | and doing analysis. 111 | - `EGraph::add_expr` now proceeds linearly through the given `RecExpr`, which 112 | should be faster and include _all_ e-nodes from the expression. 113 | - `Rewrite` now has public `searcher` and `applier` fields and no `long_name`. 114 | - ([#61](https://github.com/egraphs-good/egg/pull/61)) 115 | Rebuilding is much improved! 116 | The new algorithm's congruence closure part is closer to 117 | [Downey, Sethi, Tarjan](https://dl.acm.org/doi/pdf/10.1145/322217.322228), 118 | and the analysis data propagation is more precise with respect to merging. 119 | Overall, the algorithm is simpler, easier to reason about, and more than twice as fast! 120 | - ([#86](https://github.com/egraphs-good/egg/pull/86)) 121 | `Language::display_op` has been removed. Languages should implement `Display` 122 | to display the operator instead. `define_language!` now implements `Display` 123 | accordingly. 124 | - `Language::from_op_str` has been replaced by a new `FromOp` trait. 125 | `define_language!` implements this trait automatically. 126 | 127 | ## [0.6.0] - 2020-07-16 128 | 129 | ### Added 130 | - `Id` is now a struct not a type alias. This should help prevent some bugs. 131 | - `Runner` hooks allow you to modify the `Runner` each iteration and stop early if you want. 132 | - Added a way to lookup an e-node without adding it. 133 | - `define_language!` now supports variants with data _and_ children. 134 | - Added a tutorial in the documentation! 135 | 136 | ### Fixed 137 | - Fixed a bug when making `Pattern`s from `RecExpr`s. 138 | - Improved the `RecExpr` API. 
139 | 140 | ## [0.5.0] - 2020-06-22 141 | 142 | ### Added 143 | - `egg` now provides `Symbol`s, a simple interned string that users can (and 144 | should) use in their `Language`s. 145 | - `egg` will now warn you when you try to use `Rewrite`s with the same name. 146 | - Rewrite creation will now fail if the searcher doesn't bind the right variables. 147 | - The `rewrite!` macro supports bidirectional rewrites now. 148 | - `define_language!` now supports variable numbers of children with `Box<[Id]>`. 149 | 150 | ### Fixed 151 | - The `rewrite!` macro builds conditional rewrites in the correct order now. 152 | 153 | ## [0.4.1] - 2020-05-26 154 | 155 | ### Added 156 | - Added various Debug and Display impls. 157 | 158 | ### Fixed 159 | - Fixed the way applications were counted by the Runner. 160 | 161 | ## [0.4.0] - 2020-05-21 162 | 163 | ### Added 164 | - The rebuilding algorithm is now _precise_ meaning it avoids a lot of 165 | unnecessary work. This leads to an across-the-board speedup of up to 2x. 166 | - `Language` elements are now much more compact, leading to speed ups across the board. 167 | 168 | ### Changed 169 | - Replaced `Metadata` with `Analysis`, which can hold egraph-global data as well 170 | as per-eclass data. 171 | - **Fix:** 172 | An eclass's metadata will now get updated by 173 | congruence. 174 | ([commit](https://github.com/egraphs-good/egg/commit/0de75c9c9b0a80adb67fb78cc98cce3da383621a)) 175 | - The `BackoffScheduler` will now fast-forward if all rules are banned. 176 | ([commit](https://github.com/egraphs-good/egg/commit/dd172ef77279e28448d0bf8147e0171a8175228d)) 177 | - Improve benchmark reporting 178 | ([commit](https://github.com/egraphs-good/egg/commit/ca2ea5e239feda7eb6971942e119075f55f869ab)) 179 | - The egraph now skips irrelevant eclasses while searching for a ~40% search speed up. 
180 | ([PR](https://github.com/egraphs-good/egg/pull/21)) 181 | 182 | ## [0.3.0] - 2020-02-27 183 | 184 | ### Added 185 | - `Runner` can now be configured with user-defined `RewriteScheduler`s 186 | and `IterationData`. 187 | 188 | ### Changed 189 | - Reworked the `Runner` API. It's now a generic struct instead of a 190 | trait. 191 | - Patterns are now compiled into a small virtual machine bytecode inspired 192 | by [this paper](https://link.springer.com/chapter/10.1007/978-3-540-73595-3_13). 193 | This gets about a 40% speed up. 194 | 195 | ## [0.2.0] - 2020-02-19 196 | 197 | ### Added 198 | 199 | - A dumb little benchmarking system called `egg_bench` that can help 200 | benchmark tests. 201 | - String interning for `Var`s (née `QuestionMarkName`s). 202 | This speeds up things by ~35%. 203 | - Add a configurable time limit to `SimpleRunner` 204 | 205 | ### Changed 206 | 207 | - Renamed `WildMap` to `Subst`, `QuestionMarkName` to `Var`. 208 | 209 | ### Removed 210 | 211 | - Multi-matching patterns (ex: `?a...`). 212 | They were a hack and undocumented. 213 | If we can come up with a better way to do it, then we can put them back. 214 | 215 | ## [0.1.2] - 2020-02-14 216 | 217 | This release completes the documentation 218 | (at least every public item is documented). 219 | 220 | ### Changed 221 | - Replaced `Pattern::{from_expr, to_expr}` with `From` and `TryFrom` 222 | implementations. 223 | 224 | ## [0.1.1] - 2020-02-13 225 | 226 | ### Added 227 | - A lot of documentation 228 | 229 | ### Changed 230 | - The graphviz visualization now looks a lot better; enode arguments 231 | come out from the "correct" position based on which argument they 232 | are. 233 | 234 | ## [0.1.0] - 2020-02-11 235 | 236 | This is egg's first real release! 237 | 238 | Hard to make a changelog on the first release, since basically 239 | everything has changed! 240 | But hopefully things will be a little more stable from here on out 241 | since the API is a lot nicer. 
242 | 243 | 244 | [Unreleased]: https://github.com/egraphs-good/egg/compare/v0.9.5...HEAD 245 | [0.9.5]: https://github.com/egraphs-good/egg/compare/v0.9.4...v0.9.5 246 | [0.9.4]: https://github.com/egraphs-good/egg/compare/v0.9.3...v0.9.4 247 | [0.9.3]: https://github.com/egraphs-good/egg/compare/v0.9.2...v0.9.3 248 | [0.9.2]: https://github.com/egraphs-good/egg/compare/v0.9.1...v0.9.2 249 | [0.9.1]: https://github.com/egraphs-good/egg/compare/v0.9.0...v0.9.1 250 | [0.9.0]: https://github.com/egraphs-good/egg/compare/v0.8.1...v0.9.0 251 | [0.8.1]: https://github.com/egraphs-good/egg/compare/v0.8.0...v0.8.1 252 | [0.8.0]: https://github.com/egraphs-good/egg/compare/v0.7.1...v0.8.0 253 | [0.7.1]: https://github.com/egraphs-good/egg/compare/v0.7.0...v0.7.1 254 | [0.7.0]: https://github.com/egraphs-good/egg/compare/v0.6.0...v0.7.0 255 | [0.6.0]: https://github.com/egraphs-good/egg/compare/v0.5.0...v0.6.0 256 | [0.5.0]: https://github.com/egraphs-good/egg/compare/v0.4.1...v0.5.0 257 | [0.4.1]: https://github.com/egraphs-good/egg/compare/v0.4.0...v0.4.1 258 | [0.4.0]: https://github.com/egraphs-good/egg/compare/v0.3.0...v0.4.0 259 | [0.3.0]: https://github.com/egraphs-good/egg/compare/v0.2.0...v0.3.0 260 | [0.2.0]: https://github.com/egraphs-good/egg/compare/v0.1.2...v0.2.0 261 | [0.1.2]: https://github.com/egraphs-good/egg/compare/v0.1.1...v0.1.2 262 | [0.1.1]: https://github.com/egraphs-good/egg/compare/v0.1.0...v0.1.1 263 | [0.1.0]: https://github.com/egraphs-good/egg/tree/v0.1.0 264 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Max Willsey "] 3 | categories = ["data-structures"] 4 | description = "An implementation of egraphs" 5 | edition = "2018" 6 | keywords = ["e-graphs"] 7 | license = "MIT" 8 | name = "egg" 9 | readme = "README.md" 10 | repository = "https://github.com/egraphs-good/egg" 11 | version = 
"0.9.5" 12 | 13 | [dependencies] 14 | env_logger = { version = "0.9.0", default-features = false } 15 | fxhash = "0.2.1" 16 | hashbrown = "0.12.1" 17 | indexmap = "1.8.1" 18 | instant = "0.1.12" 19 | log = "0.4.17" 20 | smallvec = { version = "1.8.0", features = ["union", "const_generics"] } 21 | symbol_table = { version = "0.2.0", features = ["global"] } 22 | symbolic_expressions = "5.0.3" 23 | thiserror = "1.0.31" 24 | 25 | # for the lp feature 26 | coin_cbc = { version = "0.1.6", optional = true } 27 | 28 | # for the serde-1 feature 29 | serde = { version = "1.0.137", features = ["derive"], optional = true } 30 | vectorize = { version = "0.2.0", optional = true } 31 | 32 | # for the reports feature 33 | serde_json = { version = "1.0.81", optional = true } 34 | saturating = "0.1.0" 35 | 36 | [dev-dependencies] 37 | ordered-float = "3.0.0" 38 | 39 | [features] 40 | # forces the use of indexmaps over hashmaps 41 | deterministic = [] 42 | lp = ["coin_cbc"] 43 | reports = ["serde-1", "serde_json"] 44 | serde-1 = [ 45 | "serde", 46 | "indexmap/serde-1", 47 | "hashbrown/serde", 48 | "symbol_table/serde", 49 | "vectorize", 50 | ] 51 | wasm-bindgen = ["instant/wasm-bindgen"] 52 | 53 | # private features for testing 54 | test-explanations = [] 55 | 56 | [package.metadata.docs.rs] 57 | all-features = true 58 | rustdoc-args = ["--cfg", "docsrs"] 59 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019 Max Willsey 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject 
to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: test nits 2 | 3 | .PHONY: test 4 | test: 5 | cargo test --release 6 | cargo test --release --features=lp 7 | # don't run examples in proof-production mode 8 | cargo test --release --features "test-explanations" 9 | 10 | 11 | .PHONY: nits 12 | nits: 13 | rustup component add rustfmt clippy 14 | cargo fmt -- --check 15 | cargo clean --doc 16 | cargo doc --no-deps --all-features 17 | cargo deadlinks 18 | 19 | cargo clippy --tests 20 | cargo clippy --tests --features "test-explanations" 21 | cargo clippy --tests --features "serde-1" 22 | cargo clippy --tests --all-features 23 | 24 | .PHONY: docs 25 | docs: 26 | RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features --open -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # egg logo egg: egraphs good 2 | 3 | [![Crates.io](https://img.shields.io/crates/v/egg.svg)](https://crates.io/crates/egg) 4 | [![Released Docs.rs](https://img.shields.io/crates/v/egg?color=blue&label=docs)](https://docs.rs/egg/) 5 | [![Main branch 
docs](https://img.shields.io/badge/docs-main-blue)](https://egraphs-good.github.io/egg/egg/) 6 | [![Zulip](https://img.shields.io/badge/zulip-join%20chat-blue)](https://egraphs.zulipchat.com) 7 | 8 | > Also check out the [egglog](https://github.com/egraphs-good/egglog) 9 | system that provides an alternative approach to 10 | equality saturation based on Datalog. 11 | It features a language-based design, incremental execution, and composable analyses. 12 | See also the [paper](//mwillsey.com/papers/egglog) and the [egglog web demo](https://egraphs-good.github.io/egglog). 13 | 14 | Are you using egg? 15 | Please cite using the BibTeX below and 16 | add your project to the `egg` 17 | [website](https://github.com/egraphs-good/egraphs-good.github.io)! 18 | 19 |
20 | BibTeX 21 |
@article{2021-egg,
22 |   author = {Willsey, Max and Nandi, Chandrakana and Wang, Yisu Remy and Flatt, Oliver and Tatlock, Zachary and Panchekha, Pavel},
23 |   title = {egg: Fast and Extensible Equality Saturation},
24 |   year = {2021},
25 |   issue_date = {January 2021},
26 |   publisher = {Association for Computing Machinery},
27 |   address = {New York, NY, USA},
28 |   volume = {5},
29 |   number = {POPL},
30 |   url = {https://doi.org/10.1145/3434304},
31 |   doi = {10.1145/3434304},
32 |   abstract = {An e-graph efficiently represents a congruence relation over many expressions. Although they were originally developed in the late 1970s for use in automated theorem provers, a more recent technique known as equality saturation repurposes e-graphs to implement state-of-the-art, rewrite-driven compiler optimizations and program synthesizers. However, e-graphs remain unspecialized for this newer use case. Equality saturation workloads exhibit distinct characteristics and often require ad-hoc e-graph extensions to incorporate transformations beyond purely syntactic rewrites.  This work contributes two techniques that make e-graphs fast and extensible, specializing them to equality saturation. A new amortized invariant restoration technique called rebuilding takes advantage of equality saturation's distinct workload, providing asymptotic speedups over current techniques in practice. A general mechanism called e-class analyses integrates domain-specific analyses into the e-graph, reducing the need for ad hoc manipulation.  We implemented these techniques in a new open-source library called egg. Our case studies on three previously published applications of equality saturation highlight how egg's performance and flexibility enable state-of-the-art results across diverse domains.},
33 |   journal = {Proc. ACM Program. Lang.},
34 |   month = jan,
35 |   articleno = {23},
36 |   numpages = {29},
37 |   keywords = {equality saturation, e-graphs}
38 | }
39 | 
40 |
41 | 42 | Check out the [egg web demo](https://egraphs-good.github.io/egg-web-demo) for some quick e-graph action! 43 | 44 | ## Using egg 45 | 46 | Add `egg` to your `Cargo.toml` like this: 47 | ```toml 48 | [dependencies] 49 | egg = "0.9.5" 50 | ``` 51 | 52 | Make sure to compile with `--release` if you are measuring performance! 53 | 54 | ## Developing 55 | 56 | It's written in [Rust](https://www.rust-lang.org/). 57 | Typically, you install Rust using [`rustup`](https://www.rust-lang.org/tools/install). 58 | 59 | Run `cargo doc --open` to build and open the documentation in a browser. 60 | 61 | Before committing/pushing, make sure to run `make`, 62 | which runs all the tests and lints that CI will (including those under feature flags). 63 | This requires the [`cbc`](https://projects.coin-or.org/Cbc) solver 64 | due to the `lp` feature. 65 | 66 | ### Tests 67 | 68 | Running `cargo test` will run the tests. 69 | Some tests may time out; try `cargo test --release` if that happens. 70 | 71 | There are a couple interesting tests in the `tests` directory: 72 | 73 | - `prop.rs` implements propositional logic and proves some simple 74 | theorems. 75 | - `math.rs` implements real arithmetic, with a little bit of symbolic differentiation. 76 | - `lambda.rs` implements a small lambda calculus, using `egg` as a partial evaluator. 77 | 78 | 79 | ### Benchmarking 80 | 81 | To get a simple csv of the runtime of each test, you set the environment variable 82 | `EGG_BENCH_CSV` to something to append a row per test to a csv. 
83 | 84 | Example: 85 | ```bash 86 | EGG_BENCH_CSV=math.csv cargo test --test math --release -- --nocapture --test --test-threads=1 87 | ``` 88 | 89 | -------------------------------------------------------------------------------- /doc/egg.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 25 | 27 | 30 | 34 | 38 | 39 | 50 | 51 | 73 | 75 | 76 | 78 | image/svg+xml 79 | 81 | 82 | 83 | 84 | 85 | 90 | 94 | 101 | 108 | 109 | 114 | 119 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /doc/egraphs.drawio: -------------------------------------------------------------------------------- 1 | 7V3bkto4EP2aedmqpZB85TGZZHYrtcmmKg9JHr3ggLMezBrPhXz9yiAZLBks25LVBpOqDJKFwN1HUqv7tHxn3T++/pEGm9XHZBHGd3i6eL2z3t1hjGYOJn/ymh2tmU6nh5plGi1o3bHiS/QrZA1p7VO0CLelhlmSxFm0KVfOk/U6nGeluiBNk5dysx9JXP7WTbAMhYov8yAWa79Gi2x1qPWd6bH+zzBartg3I3Z/jwFrTCu2q2CRvJxUWe/vrPs0SbLDu8fX+zDOpcfkcvjcw5mrxQ9Lw3Um84EP/vuZ/XPx6dfXb+mnOfrwd/Lt/neL6uc5iJ/oHdNfm+2YCJZp8rQRv43+gOcwzcLXKl0E/7AejrdLgBImj2GW7kg7hgkmoZejgJFN61YnwsUWrQyoUpdFX8f7Jm/orTcQA6qXAhHCehHmnaA76+3LKsrCL5tgnl99IdgndavsMaaXf0RxfJ/ESUrK62RNGr1dBNtV8fG88DnIsjBd72vwNK/dZmnybwEzIrG3otQva5HXxXmZu6LM3QqRu7okLoG7K5E4vepTSUJRgHVrCrANC9y+VYHvygPAlPwdQdzhgiz1tEglmKTZKlkm6yD+K0k2VJI/wyzbUdMkeMqSsh7C1yj7Rt5PJ55Di99PLr3LpTFlhR0tHNTA1HdHVrb9i1wpYyBcM5vIaq6qbfKUzsMLDd1DuyxIl+GlDml/ubwuKj4N4yCLnsumU5Ui6Uc/JxG5kQIw2LYnnj87vtwygJzJzMIuQhZCroddu9z/4S5ol0esvEnTYHfSbJM32J7/Edw87XFGFdcaoUvNyZvD1x9RW0irPZDdionEjbN8Skj2v+0Icfe/p4Rd+H27B9Ib0gDbm9c9nNh18m55+PvA+iJvD92xK53mKtITxTG2STmfRCJiZL+Jo2U+Oz1Gi0Xesb7JyOGWX3EyqrI4bV2TkadPh8GVapAbacgzrEJfUOFvQxcxN0hmhiU80zdI8JUOkhmwQcIG7SWrl9g4b3JPDSnN42C7jeZlqZZN1rLJheVNLr2mlC9pSiFJU+pEYU6FwlhdR4vLd8pw4XFwuG3BphL7mU24nmZcT2esM1U2EZLw4HTBWQPTHgbOZE32fnBWeEW7As31JlN0skPwuenOn8z86fGF+wWhhFPLyGRHsJfuDn1YVlHxvWhLCsdu9qXdaelzmEZENmFKK1VjWnYbasPCtKUI02hKQI2PoEXlbl2zmJbwE55gmrpRBECfeDoe9q/mMAKids4/IMQkpLVu13SkWbGWhGK1RX6Yw4dfFioiQZX2qsZIkG
o3bcknKwwGd/+iW4eT+sOr2subkO+KslxkXpPZ2DqjNbgxIuSMugAaPUJVTsjbVI3puBKq8iV28mLUBJbEENKGGYh5p9F6Se3E6giUdkWcjzcV60s/ivFFSwlQxKkYU5ciTqdYIJ/54eT/Og+u2h0Boq6/encKXSK0h6amwwtNeZ2aI6eHUBbS6OI1Hss6N3aar259LWXAgmRstr7CKNnwsAEt/IZFj29f8bfhKQ9YYI9ZqFcY2RseNqCFDHEzj2N/XvTmqtVsJReE+ForGZbn3FYUdbRNRx1x35xOqFtvBM0/hatchy1WGGF5QcMgjjTXILhlQMLDaCSiP+BlQNZZ0s8y4HKIa7sMYJ4TgDkkGw2gYgn37GjOnDqDYMLTaYlOrhu/ZwtF9EF3cji3Rhu9z2prRTWSWHRYnWEsjaVua95MnBrEiMH8KX0uDIu+wgftWG3KFYZhKYz97rKZGTzm5h21EsulgRuJnBPQtB/Jasb61EFOskCtWTwnrTU7iWfm9M1OcqoCQ5xmdbOTHBl2UpUy9bGTLPMOh06ehUKtrT0LprNILWCcJM4mr9CnetVAcfJYwEhIJnQBlB9mmScl9TxRVdDAikW0H5GLCYm3OhqgMcEsqX3dLTHBmE5rd4ZMlbU7Q7YYaGeCOSMTTAMTjOl5ZIJpIAU0n0GBMcHsqo3+yAQzgw1oTDAbC+C4FSZY54Ft2oNnV5F9RiaYGWxAowDYEt6uqw6dylvJWNZKhkUBmCligs1MM8Hs0S0IlgtmKzowDRoXbHgLDDiSma2ZnAOeZKZhhZH1w/QUEeV3K22XGA/yyTO2hO/X9MkzyPVYReuTZ+ANC2n3pA96WChit3k9W12iX70vdtt5xDXgvRnyqTuyhGHp3UI/bCsGUoD0OFk4QNe4DUzjlX5f0/y6ARrewIh77LEwl0dyWXtmeMuiVaR8bFrGZuMzEc7qcKHqGCZPMayJSvKe64bNGSfvrHl/sbmemKcjOr+BHSGtYItveqbR7ENuvcM/2Rm5CnZGug0I6UkKVqoxnuEJt1Xn/YxtKdE973Uc8fklENZH5UDzZIGmh+/TdRXjyDK1x372Qq5xXJ3YaT0F9mBbDQ1NeAhoqnKaP1yb4WJ8iyR1eqC6YXwyhqcTp9MoBueltaSnAU/1NNDtETMAzmj2+QCK+SwoZsHpDLbXZxK0j6n7nc9XkcyCEoJd6lSg+ompUPgOzXUDje7g4lE1QPNAXEvQxI3ngfiyJwSwRIraJZzNTGMeyCDzQNwqf92YB6ImcNR8BgWWB+IqOtZtzANRYZQCI/O5Is/zVvJAOg9s094OV+MDcfE4sJthAxpL14XKbuzL0SVvJcMi36pK7+APDOY3aJpDb27VYURjAoEJkwPa1MS+Hxy9AOzUpG4DPyYQqJviPM3PT761BAINfi0MelgMM4HAE1F+wwkE0qD1lJ/CKgUO1Dc6pNz1N5RvAA4gYq6xPcG+4yGb/t/zKlrpvL71dIUWZj6wdAVP5F4OJl3hgpYNzQ09HdTGLR4FtW5MY1AcsvNGv/tw8yM8qCcgKM6PADsN1m/8ppLz5ZhR0RD6IsMWzhIOCKnMxjZFvbnKlAxPPLoAZkrGCMdbyOlgQjWS0zF8K870BtVHotE2JotUKln3tpMbvW0fwFPTjWbTiD1CsfWDlXQ+m0Fak5oNYjY/80dN1qr43Ae7eiQ4T4dkyKVhc0a+OG/Ud2le19pu1Zy/3nIpJcU0yQkcx+Zk0Vt9TBZh3uJ/ -------------------------------------------------------------------------------- /rust-toolchain: -------------------------------------------------------------------------------- 1 | 1.60 
-------------------------------------------------------------------------------- /src/dot.rs: -------------------------------------------------------------------------------- 1 | /*! 2 | EGraph visualization with [GraphViz] 3 | 4 | Use the [`Dot`] struct to visualize an [`EGraph`] 5 | 6 | [GraphViz]: https://graphviz.gitlab.io/ 7 | !*/ 8 | 9 | use std::ffi::OsStr; 10 | use std::fmt::{self, Debug, Display, Formatter}; 11 | use std::io::{Error, ErrorKind, Result, Write}; 12 | use std::path::Path; 13 | 14 | use crate::{egraph::EGraph, Analysis, Language}; 15 | 16 | /** 17 | A wrapper for an [`EGraph`] that can output [GraphViz] for 18 | visualization. 19 | 20 | The [`EGraph::dot`](EGraph::dot()) method creates `Dot`s. 21 | 22 | # Example 23 | 24 | ```no_run 25 | use egg::{*, rewrite as rw}; 26 | 27 | let rules = &[ 28 | rw!("mul-commutes"; "(* ?x ?y)" => "(* ?y ?x)"), 29 | rw!("mul-two"; "(* ?x 2)" => "(<< ?x 1)"), 30 | ]; 31 | 32 | let mut egraph: EGraph = Default::default(); 33 | egraph.add_expr(&"(/ (* 2 a) 2)".parse().unwrap()); 34 | let egraph = Runner::default().with_egraph(egraph).run(rules).egraph; 35 | 36 | // Dot implements std::fmt::Display 37 | println!("My egraph dot file: {}", egraph.dot()); 38 | 39 | // create a Dot and then compile it assuming `dot` is on the system 40 | egraph.dot().to_svg("target/foo.svg").unwrap(); 41 | egraph.dot().to_png("target/foo.png").unwrap(); 42 | egraph.dot().to_pdf("target/foo.pdf").unwrap(); 43 | egraph.dot().to_dot("target/foo.dot").unwrap(); 44 | ``` 45 | 46 | Note that self-edges (from an enode to its containing eclass) will be 47 | rendered improperly due to a deficiency in GraphViz. 48 | So the example above will render with an edge from the "+" enode to itself 49 | instead of to its own eclass. 50 | 51 | [GraphViz]: https://graphviz.gitlab.io/ 52 | **/ 53 | pub struct Dot<'a, L: Language, N: Analysis> { 54 | pub(crate) egraph: &'a EGraph, 55 | /// A list of strings to be output at the top of the dot file. 
56 | pub config: Vec, 57 | /// Whether or not to anchor the edges in the output. 58 | /// True by default. 59 | pub use_anchors: bool, 60 | } 61 | 62 | impl<'a, L, N> Dot<'a, L, N> 63 | where 64 | L: Language + Display, 65 | N: Analysis, 66 | { 67 | /// Writes the `Dot` to a .dot file with the given filename. 68 | /// Does _not_ require a `dot` binary. 69 | pub fn to_dot(&self, filename: impl AsRef) -> Result<()> { 70 | let mut file = std::fs::File::create(filename)?; 71 | write!(file, "{}", self) 72 | } 73 | 74 | /// Adds a line to the dot output. 75 | /// Indentation and a newline will be added automatically. 76 | pub fn with_config_line(mut self, line: impl Into) -> Self { 77 | self.config.push(line.into()); 78 | self 79 | } 80 | 81 | /// Set whether or not to anchor the edges in the output. 82 | pub fn with_anchors(mut self, use_anchors: bool) -> Self { 83 | self.use_anchors = use_anchors; 84 | self 85 | } 86 | 87 | /// Renders the `Dot` to a .png file with the given filename. 88 | /// Requires a `dot` binary to be on your `$PATH`. 89 | pub fn to_png(&self, filename: impl AsRef) -> Result<()> { 90 | self.run_dot(&["-Tpng".as_ref(), "-o".as_ref(), filename.as_ref()]) 91 | } 92 | 93 | /// Renders the `Dot` to a .svg file with the given filename. 94 | /// Requires a `dot` binary to be on your `$PATH`. 95 | pub fn to_svg(&self, filename: impl AsRef) -> Result<()> { 96 | self.run_dot(&["-Tsvg".as_ref(), "-o".as_ref(), filename.as_ref()]) 97 | } 98 | 99 | /// Renders the `Dot` to a .pdf file with the given filename. 100 | /// Requires a `dot` binary to be on your `$PATH`. 101 | pub fn to_pdf(&self, filename: impl AsRef) -> Result<()> { 102 | self.run_dot(&["-Tpdf".as_ref(), "-o".as_ref(), filename.as_ref()]) 103 | } 104 | 105 | /// Invokes `dot` with the given arguments, piping this formatted 106 | /// `Dot` into stdin. 
107 | pub fn run_dot(&self, args: I) -> Result<()> 108 | where 109 | S: AsRef, 110 | I: IntoIterator, 111 | { 112 | self.run("dot", args) 113 | } 114 | 115 | /// Invokes some program with the given arguments, piping this 116 | /// formatted `Dot` into stdin. 117 | /// 118 | /// Can be used to run a different binary than `dot`: 119 | /// ```no_run 120 | /// # use egg::*; 121 | /// # let mut egraph: EGraph = Default::default(); 122 | /// egraph.dot().run( 123 | /// "/path/to/my/dot", 124 | /// &["arg1", "-o", "outfile"] 125 | /// ).unwrap(); 126 | /// ``` 127 | pub fn run(&self, program: S1, args: I) -> Result<()> 128 | where 129 | S1: AsRef, 130 | S2: AsRef, 131 | I: IntoIterator, 132 | { 133 | use std::process::{Command, Stdio}; 134 | let mut child = Command::new(program) 135 | .args(args) 136 | .stdin(Stdio::piped()) 137 | .stdout(Stdio::null()) 138 | .spawn()?; 139 | let stdin = child.stdin.as_mut().expect("Failed to open stdin"); 140 | write!(stdin, "{}", self)?; 141 | match child.wait()?.code() { 142 | Some(0) => Ok(()), 143 | Some(e) => Err(Error::new( 144 | ErrorKind::Other, 145 | format!("dot program returned error code {}", e), 146 | )), 147 | None => Err(Error::new( 148 | ErrorKind::Other, 149 | "dot program was killed by a signal", 150 | )), 151 | } 152 | } 153 | 154 | // gives back the appropriate label and anchor 155 | fn edge(&self, i: usize, len: usize) -> (String, String) { 156 | assert!(i < len); 157 | let s = |s: &str| s.to_string(); 158 | if !self.use_anchors { 159 | return (s(""), format!("label={}", i)); 160 | } 161 | match (len, i) { 162 | (1, 0) => (s(""), s("")), 163 | (2, 0) => (s(":sw"), s("")), 164 | (2, 1) => (s(":se"), s("")), 165 | (3, 0) => (s(":sw"), s("")), 166 | (3, 1) => (s(":s"), s("")), 167 | (3, 2) => (s(":se"), s("")), 168 | (_, _) => (s(""), format!("label={}", i)), 169 | } 170 | } 171 | } 172 | 173 | impl<'a, L: Language, N: Analysis> Debug for Dot<'a, L, N> { 174 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { 175 | 
f.debug_tuple("Dot").field(self.egraph).finish() 176 | } 177 | } 178 | 179 | impl<'a, L, N> Display for Dot<'a, L, N> 180 | where 181 | L: Language + Display, 182 | N: Analysis, 183 | { 184 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { 185 | writeln!(f, "digraph egraph {{")?; 186 | 187 | // set compound=true to enable edges to clusters 188 | writeln!(f, " compound=true")?; 189 | writeln!(f, " clusterrank=local")?; 190 | 191 | for line in &self.config { 192 | writeln!(f, " {}", line)?; 193 | } 194 | 195 | // define all the nodes, clustered by eclass 196 | for class in self.egraph.classes() { 197 | writeln!(f, " subgraph cluster_{} {{", class.id)?; 198 | writeln!(f, " style=dotted")?; 199 | for (i, node) in class.iter().enumerate() { 200 | writeln!(f, " {}.{}[label = \"{}\"]", class.id, i, node)?; 201 | } 202 | writeln!(f, " }}")?; 203 | } 204 | 205 | for class in self.egraph.classes() { 206 | for (i_in_class, node) in class.iter().enumerate() { 207 | let mut arg_i = 0; 208 | node.try_for_each(|child| { 209 | // write the edge to the child, but clip it to the eclass with lhead 210 | let (anchor, label) = self.edge(arg_i, node.len()); 211 | let child_leader = self.egraph.find(child); 212 | 213 | if child_leader == class.id { 214 | writeln!( 215 | f, 216 | // {}.0 to pick an arbitrary node in the cluster 217 | " {}.{}{} -> {}.{}:n [lhead = cluster_{}, {}]", 218 | class.id, i_in_class, anchor, class.id, i_in_class, class.id, label 219 | )?; 220 | } else { 221 | writeln!( 222 | f, 223 | // {}.0 to pick an arbitrary node in the cluster 224 | " {}.{}{} -> {}.0 [lhead = cluster_{}, {}]", 225 | class.id, i_in_class, anchor, child, child_leader, label 226 | )?; 227 | } 228 | arg_i += 1; 229 | Ok(()) 230 | })?; 231 | } 232 | } 233 | 234 | write!(f, "}}") 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /src/eclass.rs: -------------------------------------------------------------------------------- 1 | use 
std::fmt::Debug; 2 | use std::iter::ExactSizeIterator; 3 | 4 | use crate::*; 5 | 6 | /// An equivalence class of enodes. 7 | #[non_exhaustive] 8 | #[derive(Debug, Clone)] 9 | #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] 10 | pub struct EClass { 11 | /// This eclass's id. 12 | pub id: Id, 13 | /// The equivalent enodes in this equivalence class. 14 | pub nodes: Vec, 15 | /// The analysis data associated with this eclass. 16 | /// 17 | /// Modifying this field will _not_ cause changes to propagate through the e-graph. 18 | /// Prefer [`EGraph::set_analysis_data`] instead. 19 | pub data: D, 20 | /// The original Ids of parent enodes. 21 | pub(crate) parents: Vec, 22 | } 23 | 24 | impl EClass { 25 | /// Returns `true` if the `eclass` is empty. 26 | pub fn is_empty(&self) -> bool { 27 | self.nodes.is_empty() 28 | } 29 | 30 | /// Returns the number of enodes in this eclass. 31 | pub fn len(&self) -> usize { 32 | self.nodes.len() 33 | } 34 | 35 | /// Iterates over the enodes in this eclass. 36 | pub fn iter(&self) -> impl ExactSizeIterator { 37 | self.nodes.iter() 38 | } 39 | 40 | /// Iterates over the non-canonical ids of parent enodes of this eclass. 41 | pub fn parents(&self) -> impl ExactSizeIterator + '_ { 42 | self.parents.iter().copied() 43 | } 44 | } 45 | 46 | impl EClass { 47 | /// Iterates over the childless enodes in this eclass. 48 | pub fn leaves(&self) -> impl Iterator { 49 | self.nodes.iter().filter(|&n| n.is_leaf()) 50 | } 51 | 52 | /// Asserts that the childless enodes in this eclass are unique. 
53 | pub fn assert_unique_leaves(&self) 54 | where 55 | L: Language, 56 | { 57 | let mut leaves = self.leaves(); 58 | if let Some(first) = leaves.next() { 59 | assert!( 60 | leaves.all(|l| l == first), 61 | "Different leaves in eclass {}: {:?}", 62 | self.id, 63 | self.leaves().collect::>() 64 | ); 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/extract.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::Ordering; 2 | use std::fmt::Debug; 3 | 4 | use crate::util::HashMap; 5 | use crate::{Analysis, EClass, EGraph, Id, Language, RecExpr}; 6 | 7 | /** Extracting a single [`RecExpr`] from an [`EGraph`]. 8 | 9 | ``` 10 | use egg::*; 11 | 12 | define_language! { 13 | enum SimpleLanguage { 14 | Num(i32), 15 | "+" = Add([Id; 2]), 16 | "*" = Mul([Id; 2]), 17 | } 18 | } 19 | 20 | let rules: &[Rewrite] = &[ 21 | rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), 22 | rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"), 23 | 24 | rewrite!("add-0"; "(+ ?a 0)" => "?a"), 25 | rewrite!("mul-0"; "(* ?a 0)" => "0"), 26 | rewrite!("mul-1"; "(* ?a 1)" => "?a"), 27 | ]; 28 | 29 | let start = "(+ 0 (* 1 10))".parse().unwrap(); 30 | let runner = Runner::default().with_expr(&start).run(rules); 31 | let (egraph, root) = (runner.egraph, runner.roots[0]); 32 | 33 | let mut extractor = Extractor::new(&egraph, AstSize); 34 | let (best_cost, best) = extractor.find_best(root); 35 | assert_eq!(best_cost, 1); 36 | assert_eq!(best, "10".parse().unwrap()); 37 | ``` 38 | 39 | **/ 40 | #[derive(Debug)] 41 | pub struct Extractor<'a, CF: CostFunction, L: Language, N: Analysis> { 42 | cost_function: CF, 43 | costs: HashMap, 44 | egraph: &'a EGraph, 45 | } 46 | 47 | /** A cost function that can be used by an [`Extractor`]. 48 | 49 | To extract an expression from an [`EGraph`], the [`Extractor`] 50 | requires a cost function to performs its greedy search. 
51 | `egg` provides the simple [`AstSize`] and [`AstDepth`] cost functions. 52 | 53 | The example below illustrates a silly but realistic example of 54 | implementing a cost function that is essentially AST size weighted by 55 | the operator: 56 | ``` 57 | # use egg::*; 58 | struct SillyCostFn; 59 | impl CostFunction for SillyCostFn { 60 | type Cost = f64; 61 | fn cost(&mut self, enode: &SymbolLang, mut costs: C) -> Self::Cost 62 | where 63 | C: FnMut(Id) -> Self::Cost 64 | { 65 | let op_cost = match enode.op.as_str() { 66 | "foo" => 100.0, 67 | "bar" => 0.7, 68 | _ => 1.0 69 | }; 70 | enode.fold(op_cost, |sum, id| sum + costs(id)) 71 | } 72 | } 73 | 74 | let e: RecExpr = "(do_it foo bar baz)".parse().unwrap(); 75 | assert_eq!(SillyCostFn.cost_rec(&e), 102.7); 76 | assert_eq!(AstSize.cost_rec(&e), 4); 77 | assert_eq!(AstDepth.cost_rec(&e), 2); 78 | ``` 79 | 80 | If you'd like to access the [`Analysis`] data or anything else in the e-graph, 81 | you can put a reference to the e-graph in your [`CostFunction`]: 82 | 83 | ``` 84 | # use egg::*; 85 | # type MyAnalysis = (); 86 | struct EGraphCostFn<'a> { 87 | egraph: &'a EGraph, 88 | } 89 | 90 | impl<'a> CostFunction for EGraphCostFn<'a> { 91 | type Cost = usize; 92 | fn cost(&mut self, enode: &SymbolLang, mut costs: C) -> Self::Cost 93 | where 94 | C: FnMut(Id) -> Self::Cost 95 | { 96 | // use self.egraph however you want here 97 | println!("the egraph has {} classes", self.egraph.number_of_classes()); 98 | return 1 99 | } 100 | } 101 | 102 | let mut egraph = EGraph::::default(); 103 | let id = egraph.add_expr(&"(foo bar)".parse().unwrap()); 104 | let cost_func = EGraphCostFn { egraph: &egraph }; 105 | let mut extractor = Extractor::new(&egraph, cost_func); 106 | let _ = extractor.find_best(id); 107 | ``` 108 | 109 | Note that a particular e-class might occur in an expression multiple times. 
110 | This means that pathological, but nevertheless realistic cases 111 | might overflow `usize` if you implement a cost function like [`AstSize`], 112 | even if the actual [`RecExpr`] fits compactly in memory. 113 | You might want to use [`saturating_add`](u64::saturating_add) to 114 | ensure your cost function is still monotonic in this situation. 115 | **/ 116 | pub trait CostFunction { 117 | /// The `Cost` type. It only requires `PartialOrd` so you can use 118 | /// floating point types, but failed comparisons (`NaN`s) will 119 | /// result in a panic. 120 | type Cost: PartialOrd + Debug + Clone; 121 | 122 | /// Calculates the cost of an enode whose children are `Cost`s. 123 | /// 124 | /// For this to work properly, your cost function should be 125 | /// _monotonic_, i.e. `cost` should return a `Cost` greater than 126 | /// any of the child costs of the given enode. 127 | fn cost(&mut self, enode: &L, costs: C) -> Self::Cost 128 | where 129 | C: FnMut(Id) -> Self::Cost; 130 | 131 | /// Calculates the total cost of a [`RecExpr`]. 132 | /// 133 | /// As provided, this just recursively calls `cost` all the way 134 | /// down the [`RecExpr`]. 135 | /// 136 | fn cost_rec(&mut self, expr: &RecExpr) -> Self::Cost { 137 | let mut costs: HashMap = HashMap::default(); 138 | for (i, node) in expr.as_ref().iter().enumerate() { 139 | let cost = self.cost(node, |i| costs[&i].clone()); 140 | costs.insert(Id::from(i), cost); 141 | } 142 | let last_id = Id::from(expr.as_ref().len() - 1); 143 | costs[&last_id].clone() 144 | } 145 | } 146 | 147 | /** A simple [`CostFunction`] that counts total AST size. 
/// Compares two optional costs, treating a missing cost (`None`) as
/// greater than any present cost, so that `min_by(cmp)` prefers nodes
/// whose cost is already known.
///
/// # Panics
/// Panics if both costs are present but incomparable (e.g. a `NaN`).
fn cmp<T: PartialOrd>(a: &Option<T>, b: &Option<T>) -> Ordering {
    // None is high
    match (a, b) {
        (None, None) => Ordering::Equal,
        (None, Some(_)) => Ordering::Greater,
        (Some(_), None) => Ordering::Less,
        (Some(a), Some(b)) => a
            .partial_cmp(b)
            .expect("costs must be comparable (is a cost NaN?)"),
    }
}
211 | pub fn new(egraph: &'a EGraph, cost_function: CF) -> Self { 212 | let costs = HashMap::default(); 213 | let mut extractor = Extractor { 214 | costs, 215 | egraph, 216 | cost_function, 217 | }; 218 | extractor.find_costs(); 219 | 220 | extractor 221 | } 222 | 223 | /// Find the cheapest (lowest cost) represented `RecExpr` in the 224 | /// given eclass. 225 | pub fn find_best(&self, eclass: Id) -> (CF::Cost, RecExpr) { 226 | let (cost, root) = self.costs[&self.egraph.find(eclass)].clone(); 227 | let expr = root.build_recexpr(|id| self.find_best_node(id).clone()); 228 | (cost, expr) 229 | } 230 | 231 | /// Find the cheapest e-node in the given e-class. 232 | pub fn find_best_node(&self, eclass: Id) -> &L { 233 | &self.costs[&self.egraph.find(eclass)].1 234 | } 235 | 236 | /// Find the cost of the term that would be extracted from this e-class. 237 | pub fn find_best_cost(&self, eclass: Id) -> CF::Cost { 238 | let (cost, _) = &self.costs[&self.egraph.find(eclass)]; 239 | cost.clone() 240 | } 241 | 242 | fn node_total_cost(&mut self, node: &L) -> Option { 243 | let eg = &self.egraph; 244 | let has_cost = |id| self.costs.contains_key(&eg.find(id)); 245 | if node.all(has_cost) { 246 | let costs = &self.costs; 247 | let cost_f = |id| costs[&eg.find(id)].0.clone(); 248 | Some(self.cost_function.cost(node, cost_f)) 249 | } else { 250 | None 251 | } 252 | } 253 | 254 | fn find_costs(&mut self) { 255 | let mut did_something = true; 256 | while did_something { 257 | did_something = false; 258 | 259 | for class in self.egraph.classes() { 260 | let pass = self.make_pass(class); 261 | match (self.costs.get(&class.id), pass) { 262 | (None, Some(new)) => { 263 | self.costs.insert(class.id, new); 264 | did_something = true; 265 | } 266 | (Some(old), Some(new)) if new.0 < old.0 => { 267 | self.costs.insert(class.id, new); 268 | did_something = true; 269 | } 270 | _ => (), 271 | } 272 | } 273 | } 274 | 275 | for class in self.egraph.classes() { 276 | if 
!self.costs.contains_key(&class.id) { 277 | log::warn!( 278 | "Failed to compute cost for eclass {}: {:?}", 279 | class.id, 280 | class.nodes 281 | ) 282 | } 283 | } 284 | } 285 | 286 | fn make_pass(&mut self, eclass: &EClass) -> Option<(CF::Cost, L)> { 287 | let (cost, node) = eclass 288 | .iter() 289 | .map(|n| (self.node_total_cost(n), n)) 290 | .min_by(|a, b| cmp(&a.0, &b.0)) 291 | .unwrap_or_else(|| panic!("Can't extract, eclass is empty: {:#?}", eclass)); 292 | cost.map(|c| (c, node.clone())) 293 | } 294 | } 295 | 296 | #[cfg(test)] 297 | mod tests { 298 | use crate::*; 299 | 300 | #[test] 301 | fn ast_size_overflow() { 302 | let rules: &[Rewrite] = 303 | &[rewrite!("explode"; "(meow ?a)" => "(meow (meow ?a ?a))")]; 304 | 305 | let start = "(meow 42)".parse().unwrap(); 306 | let runner = Runner::default() 307 | .with_iter_limit(100) 308 | .with_expr(&start) 309 | .run(rules); 310 | 311 | let extractor = Extractor::new(&runner.egraph, AstSize); 312 | let (_, best_expr) = extractor.find_best(runner.roots[0]); 313 | assert_eq!(best_expr, start); 314 | } 315 | } 316 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(docsrs, feature(doc_cfg))] 2 | #![warn(missing_docs)] 3 | /*! 4 | 5 | `egg` (**e**-**g**raphs **g**ood) is a e-graph library optimized for equality saturation. 6 | 7 | This is the API documentation. 8 | 9 | The [tutorial](tutorials) is a good starting point if you're new to 10 | e-graphs, equality saturation, or Rust. 11 | 12 | The [tests](https://github.com/egraphs-good/egg/tree/main/tests) 13 | on Github provide some more elaborate examples. 14 | 15 | There is also a [paper](https://arxiv.org/abs/2004.03082) 16 | describing `egg` and some of its technical novelties. 17 | 18 | ## Logging 19 | 20 | Many parts of `egg` dump useful logging info using the [`log`](https://docs.rs/log/) crate. 
/// A key to identify [`EClass`]es within an
/// [`EGraph`].
#[derive(Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde-1", serde(transparent))]
pub struct Id(u32);

impl From<usize> for Id {
    /// Converts an index into an `Id`.
    ///
    /// The value is stored as a `u32`; indices above `u32::MAX` are
    /// truncated by the `as` cast.
    fn from(n: usize) -> Id {
        Id(n as u32)
    }
}

impl From<Id> for usize {
    /// Converts an `Id` back into the index it wraps.
    fn from(id: Id) -> usize {
        id.0 as usize
    }
}

impl std::fmt::Debug for Id {
    /// Formats the `Id` as its bare number, the same as `Display`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl std::fmt::Display for Id {
    /// Formats the `Id` as its bare number.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
TreeExplanation, TreeTerm, 96 | UnionEqualities, 97 | }, 98 | extract::*, 99 | language::*, 100 | multipattern::*, 101 | pattern::{ENodeOrVar, Pattern, PatternAst, SearchMatches}, 102 | rewrite::{Applier, Condition, ConditionEqual, ConditionalApplier, Rewrite, Searcher}, 103 | run::*, 104 | subst::{Subst, Var}, 105 | util::*, 106 | }; 107 | 108 | #[cfg(feature = "lp")] 109 | pub use lp_extract::*; 110 | 111 | #[cfg(test)] 112 | fn init_logger() { 113 | let _ = env_logger::builder().is_test(true).try_init(); 114 | } 115 | -------------------------------------------------------------------------------- /src/lp_extract.rs: -------------------------------------------------------------------------------- 1 | use coin_cbc::{Col, Model, Sense}; 2 | 3 | use crate::*; 4 | 5 | /// A cost function to be used by an [`LpExtractor`]. 6 | #[cfg_attr(docsrs, doc(cfg(feature = "lp")))] 7 | pub trait LpCostFunction> { 8 | /// Returns the cost of the given e-node. 9 | /// 10 | /// This function may look at other parts of the e-graph to compute the cost 11 | /// of the given e-node. 12 | fn node_cost(&mut self, egraph: &EGraph, eclass: Id, enode: &L) -> f64; 13 | } 14 | 15 | #[cfg_attr(docsrs, doc(cfg(feature = "lp")))] 16 | impl> LpCostFunction for AstSize { 17 | fn node_cost(&mut self, _egraph: &EGraph, _eclass: Id, _enode: &L) -> f64 { 18 | 1.0 19 | } 20 | } 21 | 22 | /// A structure to perform extraction using integer linear programming. 23 | /// This uses the [`cbc`](https://projects.coin-or.org/Cbc) solver. 24 | /// You must have it installed on your machine to use this feature. 
25 | /// You can install it using: 26 | /// 27 | /// | OS | Command | 28 | /// |------------------|------------------------------------------| 29 | /// | Fedora / Red Hat | `sudo dnf install coin-or-Cbc-devel` | 30 | /// | Ubuntu / Debian | `sudo apt-get install coinor-libcbc-dev` | 31 | /// | macOS | `brew install cbc` | 32 | /// 33 | /// On macOS, you might also need the following in your `.zshrc` file: 34 | /// `export LIBRARY_PATH=$LIBRARY_PATH:$(brew --prefix)/lib` 35 | /// 36 | /// # Example 37 | /// ``` 38 | /// use egg::*; 39 | /// let mut egraph = EGraph::::default(); 40 | /// 41 | /// let f = egraph.add_expr(&"(f x x x)".parse().unwrap()); 42 | /// let g = egraph.add_expr(&"(g (g x))".parse().unwrap()); 43 | /// egraph.union(f, g); 44 | /// egraph.rebuild(); 45 | /// 46 | /// let best = Extractor::new(&egraph, AstSize).find_best(f).1; 47 | /// let lp_best = LpExtractor::new(&egraph, AstSize).solve(f); 48 | /// 49 | /// // In regular extraction, cost is measures on the tree. 50 | /// assert_eq!(best.to_string(), "(g (g x))"); 51 | /// 52 | /// // Using ILP only counts common sub-expressions once, 53 | /// // so it can lead to a smaller DAG expression. 54 | /// assert_eq!(lp_best.to_string(), "(f x x x)"); 55 | /// assert_eq!(lp_best.as_ref().len(), 2); 56 | /// ``` 57 | #[cfg_attr(docsrs, doc(cfg(feature = "lp")))] 58 | pub struct LpExtractor<'a, L: Language, N: Analysis> { 59 | egraph: &'a EGraph, 60 | model: Model, 61 | vars: HashMap, 62 | } 63 | 64 | struct ClassVars { 65 | active: Col, 66 | order: Col, 67 | nodes: Vec, 68 | } 69 | 70 | impl<'a, L, N> LpExtractor<'a, L, N> 71 | where 72 | L: Language, 73 | N: Analysis, 74 | { 75 | /// Create an [`LpExtractor`] using costs from the given [`LpCostFunction`]. 76 | /// See those docs for details. 
77 | pub fn new(egraph: &'a EGraph, mut cost_function: CF) -> Self 78 | where 79 | CF: LpCostFunction, 80 | { 81 | let max_order = egraph.total_number_of_nodes() as f64 * 10.0; 82 | 83 | let mut model = Model::default(); 84 | 85 | let vars: HashMap = egraph 86 | .classes() 87 | .map(|class| { 88 | let cvars = ClassVars { 89 | active: model.add_binary(), 90 | order: model.add_col(), 91 | nodes: class.nodes.iter().map(|_| model.add_binary()).collect(), 92 | }; 93 | model.set_col_upper(cvars.order, max_order); 94 | (class.id, cvars) 95 | }) 96 | .collect(); 97 | 98 | let mut cycles: HashSet<(Id, usize)> = Default::default(); 99 | find_cycles(egraph, |id, i| { 100 | cycles.insert((id, i)); 101 | }); 102 | 103 | for (&id, class) in &vars { 104 | // class active == some node active 105 | // sum(for node_active in class) == class_active 106 | let row = model.add_row(); 107 | model.set_row_equal(row, 0.0); 108 | model.set_weight(row, class.active, -1.0); 109 | for &node_active in &class.nodes { 110 | model.set_weight(row, node_active, 1.0); 111 | } 112 | 113 | for (i, (node, &node_active)) in egraph[id].iter().zip(&class.nodes).enumerate() { 114 | if cycles.contains(&(id, i)) { 115 | model.set_col_upper(node_active, 0.0); 116 | model.set_col_lower(node_active, 0.0); 117 | continue; 118 | } 119 | 120 | for child in node.children() { 121 | let child_active = vars[child].active; 122 | // node active implies child active, encoded as: 123 | // node_active <= child_active 124 | // node_active - child_active <= 0 125 | let row = model.add_row(); 126 | model.set_row_upper(row, 0.0); 127 | model.set_weight(row, node_active, 1.0); 128 | model.set_weight(row, child_active, -1.0); 129 | } 130 | } 131 | } 132 | 133 | model.set_obj_sense(Sense::Minimize); 134 | for class in egraph.classes() { 135 | for (node, &node_active) in class.iter().zip(&vars[&class.id].nodes) { 136 | model.set_obj_coeff(node_active, cost_function.node_cost(egraph, class.id, node)); 137 | } 138 | } 139 | 140 | 
dbg!(max_order); 141 | 142 | Self { 143 | egraph, 144 | model, 145 | vars, 146 | } 147 | } 148 | 149 | /// Set the cbc timeout in seconds. 150 | pub fn timeout(&mut self, seconds: f64) -> &mut Self { 151 | self.model.set_parameter("seconds", &seconds.to_string()); 152 | self 153 | } 154 | 155 | /// Extract a single rooted term. 156 | /// 157 | /// This is just a shortcut for [`LpExtractor::solve_multiple`]. 158 | pub fn solve(&mut self, root: Id) -> RecExpr { 159 | self.solve_multiple(&[root]).0 160 | } 161 | 162 | /// Extract (potentially multiple) roots 163 | pub fn solve_multiple(&mut self, roots: &[Id]) -> (RecExpr, Vec) { 164 | let egraph = self.egraph; 165 | 166 | for class in self.vars.values() { 167 | self.model.set_binary(class.active); 168 | } 169 | 170 | for root in roots { 171 | let root = &egraph.find(*root); 172 | self.model.set_col_lower(self.vars[root].active, 1.0); 173 | } 174 | 175 | let solution = self.model.solve(); 176 | log::info!( 177 | "CBC status {:?}, {:?}", 178 | solution.raw().status(), 179 | solution.raw().secondary_status() 180 | ); 181 | 182 | let mut todo: Vec = roots.iter().map(|id| self.egraph.find(*id)).collect(); 183 | let mut expr = RecExpr::default(); 184 | // converts e-class ids to e-node ids 185 | let mut ids: HashMap = HashMap::default(); 186 | 187 | while let Some(&id) = todo.last() { 188 | if ids.contains_key(&id) { 189 | todo.pop(); 190 | continue; 191 | } 192 | let v = &self.vars[&id]; 193 | assert!(solution.col(v.active) > 0.0); 194 | let node_idx = v.nodes.iter().position(|&n| solution.col(n) > 0.0).unwrap(); 195 | let node = &self.egraph[id].nodes[node_idx]; 196 | if node.all(|child| ids.contains_key(&child)) { 197 | let new_id = expr.add(node.clone().map_children(|i| ids[&self.egraph.find(i)])); 198 | ids.insert(id, new_id); 199 | todo.pop(); 200 | } else { 201 | todo.extend_from_slice(node.children()) 202 | } 203 | } 204 | 205 | let root_idxs = roots.iter().map(|root| ids[root]).collect(); 206 | 207 | 
assert!(expr.is_dag(), "LpExtract found a cyclic term!: {:?}", expr); 208 | (expr, root_idxs) 209 | } 210 | } 211 | 212 | fn find_cycles(egraph: &EGraph, mut f: impl FnMut(Id, usize)) 213 | where 214 | L: Language, 215 | N: Analysis, 216 | { 217 | enum Color { 218 | White, 219 | Gray, 220 | Black, 221 | } 222 | type Enter = bool; 223 | 224 | let mut color: HashMap = egraph.classes().map(|c| (c.id, Color::White)).collect(); 225 | let mut stack: Vec<(Enter, Id)> = egraph.classes().map(|c| (true, c.id)).collect(); 226 | 227 | while let Some((enter, id)) = stack.pop() { 228 | if enter { 229 | *color.get_mut(&id).unwrap() = Color::Gray; 230 | stack.push((false, id)); 231 | for (i, node) in egraph[id].iter().enumerate() { 232 | for child in node.children() { 233 | match &color[child] { 234 | Color::White => stack.push((true, *child)), 235 | Color::Gray => f(id, i), 236 | Color::Black => (), 237 | } 238 | } 239 | } 240 | } else { 241 | *color.get_mut(&id).unwrap() = Color::Black; 242 | } 243 | } 244 | } 245 | 246 | #[cfg(test)] 247 | mod tests { 248 | use crate::{SymbolLang as S, *}; 249 | 250 | #[test] 251 | fn simple_lp_extract_two() { 252 | let mut egraph = EGraph::::default(); 253 | let a = egraph.add(S::leaf("a")); 254 | let plus = egraph.add(S::new("+", vec![a, a])); 255 | let f = egraph.add(S::new("f", vec![plus])); 256 | let g = egraph.add(S::new("g", vec![plus])); 257 | 258 | let mut ext = LpExtractor::new(&egraph, AstSize); 259 | ext.timeout(10.0); // way too much time 260 | let (exp, ids) = ext.solve_multiple(&[f, g]); 261 | println!("{:?}", exp); 262 | println!("{}", exp); 263 | assert_eq!(exp.as_ref().len(), 4); 264 | assert_eq!(ids.len(), 2); 265 | } 266 | } 267 | -------------------------------------------------------------------------------- /src/machine.rs: -------------------------------------------------------------------------------- 1 | use crate::*; 2 | use std::result; 3 | 4 | type Result = result::Result<(), ()>; 5 | 6 | #[derive(Default)] 7 | 
struct Machine { 8 | reg: Vec, 9 | // a buffer to re-use for lookups 10 | lookup: Vec, 11 | } 12 | 13 | #[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] 14 | struct Reg(u32); 15 | 16 | #[derive(Debug, Clone, PartialEq, Eq)] 17 | pub struct Program { 18 | instructions: Vec>, 19 | subst: Subst, 20 | } 21 | 22 | #[derive(Debug, Clone, PartialEq, Eq)] 23 | enum Instruction { 24 | Bind { node: L, i: Reg, out: Reg }, 25 | Compare { i: Reg, j: Reg }, 26 | Lookup { term: Vec>, i: Reg }, 27 | Scan { out: Reg }, 28 | } 29 | 30 | #[derive(Debug, Clone, PartialEq, Eq)] 31 | enum ENodeOrReg { 32 | ENode(L), 33 | Reg(Reg), 34 | } 35 | 36 | #[inline(always)] 37 | fn for_each_matching_node( 38 | eclass: &EClass, 39 | node: &L, 40 | mut f: impl FnMut(&L) -> Result, 41 | ) -> Result 42 | where 43 | L: Language, 44 | { 45 | if eclass.nodes.len() < 50 { 46 | eclass 47 | .nodes 48 | .iter() 49 | .filter(|n| node.matches(n)) 50 | .try_for_each(f) 51 | } else { 52 | debug_assert!(node.all(|id| id == Id::from(0))); 53 | debug_assert!(eclass.nodes.windows(2).all(|w| w[0] < w[1])); 54 | let mut start = eclass.nodes.binary_search(node).unwrap_or_else(|i| i); 55 | let discrim = node.discriminant(); 56 | while start > 0 { 57 | if eclass.nodes[start - 1].discriminant() == discrim { 58 | start -= 1; 59 | } else { 60 | break; 61 | } 62 | } 63 | let mut matching = eclass.nodes[start..] 
64 | .iter() 65 | .take_while(|&n| n.discriminant() == discrim) 66 | .filter(|n| node.matches(n)); 67 | debug_assert_eq!( 68 | matching.clone().count(), 69 | eclass.nodes.iter().filter(|n| node.matches(n)).count(), 70 | "matching node {:?}\nstart={}\n{:?} != {:?}\nnodes: {:?}", 71 | node, 72 | start, 73 | matching.clone().collect::>(), 74 | eclass 75 | .nodes 76 | .iter() 77 | .filter(|n| node.matches(n)) 78 | .collect::>(), 79 | eclass.nodes 80 | ); 81 | matching.try_for_each(&mut f) 82 | } 83 | } 84 | 85 | impl Machine { 86 | #[inline(always)] 87 | fn reg(&self, reg: Reg) -> Id { 88 | self.reg[reg.0 as usize] 89 | } 90 | 91 | fn run( 92 | &mut self, 93 | egraph: &EGraph, 94 | instructions: &[Instruction], 95 | subst: &Subst, 96 | yield_fn: &mut impl FnMut(&Self, &Subst) -> Result, 97 | ) -> Result 98 | where 99 | L: Language, 100 | N: Analysis, 101 | { 102 | let mut instructions = instructions.iter(); 103 | while let Some(instruction) = instructions.next() { 104 | match instruction { 105 | Instruction::Bind { i, out, node } => { 106 | let remaining_instructions = instructions.as_slice(); 107 | return for_each_matching_node(&egraph[self.reg(*i)], node, |matched| { 108 | self.reg.truncate(out.0 as usize); 109 | matched.for_each(|id| self.reg.push(id)); 110 | self.run(egraph, remaining_instructions, subst, yield_fn) 111 | }); 112 | } 113 | Instruction::Scan { out } => { 114 | let remaining_instructions = instructions.as_slice(); 115 | for class in egraph.classes() { 116 | self.reg.truncate(out.0 as usize); 117 | self.reg.push(class.id); 118 | self.run(egraph, remaining_instructions, subst, yield_fn)? 
119 | } 120 | return Ok(()); 121 | } 122 | Instruction::Compare { i, j } => { 123 | if egraph.find(self.reg(*i)) != egraph.find(self.reg(*j)) { 124 | return Ok(()); 125 | } 126 | } 127 | Instruction::Lookup { term, i } => { 128 | self.lookup.clear(); 129 | for node in term { 130 | match node { 131 | ENodeOrReg::ENode(node) => { 132 | let look = |i| self.lookup[usize::from(i)]; 133 | match egraph.lookup(node.clone().map_children(look)) { 134 | Some(id) => self.lookup.push(id), 135 | None => return Ok(()), 136 | } 137 | } 138 | ENodeOrReg::Reg(r) => { 139 | self.lookup.push(egraph.find(self.reg(*r))); 140 | } 141 | } 142 | } 143 | 144 | let id = egraph.find(self.reg(*i)); 145 | if self.lookup.last().copied() != Some(id) { 146 | return Ok(()); 147 | } 148 | } 149 | } 150 | } 151 | 152 | yield_fn(self, subst) 153 | } 154 | } 155 | 156 | struct Compiler { 157 | v2r: IndexMap, 158 | free_vars: Vec>, 159 | subtree_size: Vec, 160 | todo_nodes: HashMap<(Id, Reg), L>, 161 | instructions: Vec>, 162 | next_reg: Reg, 163 | } 164 | 165 | impl Compiler { 166 | fn new() -> Self { 167 | Self { 168 | free_vars: Default::default(), 169 | subtree_size: Default::default(), 170 | v2r: Default::default(), 171 | todo_nodes: Default::default(), 172 | instructions: Default::default(), 173 | next_reg: Reg(0), 174 | } 175 | } 176 | 177 | fn add_todo(&mut self, pattern: &PatternAst, id: Id, reg: Reg) { 178 | match &pattern[id] { 179 | ENodeOrVar::Var(v) => { 180 | if let Some(&j) = self.v2r.get(v) { 181 | self.instructions.push(Instruction::Compare { i: reg, j }) 182 | } else { 183 | self.v2r.insert(*v, reg); 184 | } 185 | } 186 | ENodeOrVar::ENode(pat) => { 187 | self.todo_nodes.insert((id, reg), pat.clone()); 188 | } 189 | } 190 | } 191 | 192 | fn load_pattern(&mut self, pattern: &PatternAst) { 193 | let len = pattern.as_ref().len(); 194 | self.free_vars = Vec::with_capacity(len); 195 | self.subtree_size = Vec::with_capacity(len); 196 | 197 | for node in pattern.as_ref() { 198 | let mut free 
= HashSet::default(); 199 | let mut size = 0; 200 | match node { 201 | ENodeOrVar::ENode(n) => { 202 | size = 1; 203 | for &child in n.children() { 204 | free.extend(&self.free_vars[usize::from(child)]); 205 | size += self.subtree_size[usize::from(child)]; 206 | } 207 | } 208 | ENodeOrVar::Var(v) => { 209 | free.insert(*v); 210 | } 211 | } 212 | self.free_vars.push(free); 213 | self.subtree_size.push(size); 214 | } 215 | } 216 | 217 | fn next(&mut self) -> Option<((Id, Reg), L)> { 218 | // we take the max todo according to this key 219 | // - prefer grounded 220 | // - prefer more free variables 221 | // - prefer smaller term 222 | let key = |(id, _): &&(Id, Reg)| { 223 | let i = usize::from(*id); 224 | let n_bound = self.free_vars[i] 225 | .iter() 226 | .filter(|v| self.v2r.contains_key(*v)) 227 | .count(); 228 | let n_free = self.free_vars[i].len() - n_bound; 229 | let size = self.subtree_size[i] as isize; 230 | (n_free == 0, n_free, -size) 231 | }; 232 | 233 | self.todo_nodes 234 | .keys() 235 | .max_by_key(key) 236 | .copied() 237 | .map(|k| (k, self.todo_nodes.remove(&k).unwrap())) 238 | } 239 | 240 | /// check to see if this e-node corresponds to a term that is grounded by 241 | /// the variables bound at this point 242 | fn is_ground_now(&self, id: Id) -> bool { 243 | self.free_vars[usize::from(id)] 244 | .iter() 245 | .all(|v| self.v2r.contains_key(v)) 246 | } 247 | 248 | fn compile(&mut self, patternbinder: Option, pattern: &PatternAst) { 249 | self.load_pattern(pattern); 250 | let last_i = pattern.as_ref().len() - 1; 251 | 252 | let mut next_out = self.next_reg; 253 | 254 | // Check if patternbinder already bound in v2r 255 | // Behavior common to creating a new pattern 256 | let add_new_pattern = |comp: &mut Compiler| { 257 | if !comp.instructions.is_empty() { 258 | // After first pattern needs scan 259 | comp.instructions 260 | .push(Instruction::Scan { out: comp.next_reg }); 261 | } 262 | comp.add_todo(pattern, Id::from(last_i), comp.next_reg); 263 | 
}; 264 | 265 | if let Some(v) = patternbinder { 266 | if let Some(&i) = self.v2r.get(&v) { 267 | // patternbinder already bound 268 | self.add_todo(pattern, Id::from(last_i), i); 269 | } else { 270 | // patternbinder is new variable 271 | next_out.0 += 1; 272 | add_new_pattern(self); 273 | self.v2r.insert(v, self.next_reg); //add to known variables. 274 | } 275 | } else { 276 | // No pattern binder 277 | next_out.0 += 1; 278 | add_new_pattern(self); 279 | } 280 | 281 | while let Some(((id, reg), node)) = self.next() { 282 | if self.is_ground_now(id) && !node.is_leaf() { 283 | let extracted = pattern.extract(id); 284 | self.instructions.push(Instruction::Lookup { 285 | i: reg, 286 | term: extracted 287 | .as_ref() 288 | .iter() 289 | .map(|n| match n { 290 | ENodeOrVar::ENode(n) => ENodeOrReg::ENode(n.clone()), 291 | ENodeOrVar::Var(v) => ENodeOrReg::Reg(self.v2r[v]), 292 | }) 293 | .collect(), 294 | }); 295 | } else { 296 | let out = next_out; 297 | next_out.0 += node.len() as u32; 298 | 299 | // zero out the children so Bind can use it to sort 300 | let op = node.clone().map_children(|_| Id::from(0)); 301 | self.instructions.push(Instruction::Bind { 302 | i: reg, 303 | node: op, 304 | out, 305 | }); 306 | 307 | for (i, &child) in node.children().iter().enumerate() { 308 | self.add_todo(pattern, child, Reg(out.0 + i as u32)); 309 | } 310 | } 311 | } 312 | self.next_reg = next_out; 313 | } 314 | 315 | fn extract(self) -> Program { 316 | let mut subst = Subst::default(); 317 | for (v, r) in self.v2r { 318 | subst.insert(v, Id::from(r.0 as usize)); 319 | } 320 | Program { 321 | instructions: self.instructions, 322 | subst, 323 | } 324 | } 325 | } 326 | 327 | impl Program { 328 | pub(crate) fn compile_from_pat(pattern: &PatternAst) -> Self { 329 | let mut compiler = Compiler::new(); 330 | compiler.compile(None, pattern); 331 | let program = compiler.extract(); 332 | log::debug!("Compiled {:?} to {:?}", pattern.as_ref(), program); 333 | program 334 | } 335 | 336 | 
pub(crate) fn compile_from_multi_pat(patterns: &[(Var, PatternAst)]) -> Self { 337 | let mut compiler = Compiler::new(); 338 | for (var, pattern) in patterns { 339 | compiler.compile(Some(*var), pattern); 340 | } 341 | compiler.extract() 342 | } 343 | 344 | pub fn run_with_limit( 345 | &self, 346 | egraph: &EGraph, 347 | eclass: Id, 348 | mut limit: usize, 349 | ) -> Vec 350 | where 351 | A: Analysis, 352 | { 353 | assert!(egraph.clean, "Tried to search a dirty e-graph!"); 354 | 355 | if limit == 0 { 356 | return vec![]; 357 | } 358 | 359 | let mut machine = Machine::default(); 360 | assert_eq!(machine.reg.len(), 0); 361 | machine.reg.push(eclass); 362 | 363 | let mut matches = Vec::new(); 364 | machine 365 | .run( 366 | egraph, 367 | &self.instructions, 368 | &self.subst, 369 | &mut |machine, subst| { 370 | let subst_vec = subst 371 | .vec 372 | .iter() 373 | // HACK we are reusing Ids here, this is bad 374 | .map(|(v, reg_id)| (*v, machine.reg(Reg(usize::from(*reg_id) as u32)))) 375 | .collect(); 376 | matches.push(Subst { vec: subst_vec }); 377 | limit -= 1; 378 | if limit != 0 { 379 | Ok(()) 380 | } else { 381 | Err(()) 382 | } 383 | }, 384 | ) 385 | .unwrap_or_default(); 386 | 387 | log::trace!("Ran program, found {:?}", matches); 388 | matches 389 | } 390 | } 391 | -------------------------------------------------------------------------------- /src/macros.rs: -------------------------------------------------------------------------------- 1 | #[allow(unused_imports)] 2 | use crate::*; 3 | 4 | /** A macro to easily create a [`Language`]. 5 | 6 | `define_language` derives `Debug`, `PartialEq`, `Eq`, `PartialOrd`, `Ord`, 7 | `Hash`, and `Clone` on the given `enum` so it can implement [`Language`]. 8 | The macro also implements [`Display`] and [`FromOp`] for the `enum` 9 | based on either the data of variants or the provided strings. 10 | 11 | The final variant **must have a trailing comma**; this is due to limitations in 12 | macro parsing. 
13 | 14 | The language discriminant will use the cases of the enum (the enum discriminant). 15 | 16 | See [`LanguageChildren`] for acceptable types of children `Id`s. 17 | 18 | Note that you can always implement [`Language`] yourself by just not using this 19 | macro. 20 | 21 | Data variants may also carry children as a second field; see the `Other` variant 22 | in the example below. 23 | 24 | # Example 25 | 26 | The following macro invocation shows the accepted forms of variants: 27 | ``` 28 | # use egg::*; 29 | define_language! { 30 | enum SimpleLanguage { 31 | // string variant with no children 32 | "pi" = Pi, 33 | 34 | // string variants with an array of child `Id`s (any static size) 35 | // any type that implements LanguageChildren may be used here 36 | "+" = Add([Id; 2]), 37 | "-" = Sub([Id; 2]), 38 | "*" = Mul([Id; 2]), 39 | 40 | // can also do a variable number of children in a boxed slice 41 | // this will only match if the lengths are the same 42 | "list" = List(Box<[Id]>), 43 | 44 | // string variants with a single child `Id` 45 | // note that this is distinct from `Sub`, even though it has the same 46 | // string, because it has a different number of children 47 | "-" = Neg(Id), 48 | 49 | // data variants with a single field 50 | // this field must implement `FromStr` and `Display` 51 | Num(i32), 52 | // language items are parsed in order, and we want symbol to 53 | // be a fallback, so we put it last 54 | Symbol(Symbol), 55 | // This is the ultimate fallback, it will parse any operator (as a string) 56 | // and any number of children. 57 | // Note that if there were 0 children, the previous branch would have succeeded 58 | Other(Symbol, Vec), 59 | } 60 | } 61 | ``` 62 | 63 | [`Display`]: std::fmt::Display 64 | **/ 65 | #[macro_export] 66 | macro_rules! 
define_language { 67 | ($(#[$meta:meta])* $vis:vis enum $name:ident $variants:tt) => { 68 | $crate::__define_language!($(#[$meta])* $vis enum $name $variants -> {} {} {} {} {} {}); 69 | }; 70 | } 71 | 72 | #[doc(hidden)] 73 | #[macro_export] 74 | macro_rules! __define_language { 75 | ($(#[$meta:meta])* $vis:vis enum $name:ident {} -> 76 | $decl:tt {$($matches:tt)*} $children:tt $children_mut:tt 77 | $display:tt {$($from_op:tt)*} 78 | ) => { 79 | $(#[$meta])* 80 | #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] 81 | $vis enum $name $decl 82 | 83 | impl $crate::Language for $name { 84 | type Discriminant = std::mem::Discriminant; 85 | 86 | #[inline(always)] 87 | fn discriminant(&self) -> Self::Discriminant { 88 | std::mem::discriminant(self) 89 | } 90 | 91 | #[inline(always)] 92 | fn matches(&self, other: &Self) -> bool { 93 | ::std::mem::discriminant(self) == ::std::mem::discriminant(other) && 94 | match (self, other) { $($matches)* _ => false } 95 | } 96 | 97 | fn children(&self) -> &[Id] { match self $children } 98 | fn children_mut(&mut self) -> &mut [Id] { match self $children_mut } 99 | } 100 | 101 | impl ::std::fmt::Display for $name { 102 | fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { 103 | // We need to pass `f` to the match expression for hygiene 104 | // reasons. 
105 | match (self, f) $display 106 | } 107 | } 108 | 109 | impl $crate::FromOp for $name { 110 | type Error = $crate::FromOpError; 111 | 112 | fn from_op(op: &str, children: ::std::vec::Vec<$crate::Id>) -> ::std::result::Result { 113 | match (op, children) { 114 | $($from_op)* 115 | (op, children) => Err($crate::FromOpError::new(op, children)), 116 | } 117 | } 118 | } 119 | }; 120 | 121 | ($(#[$meta:meta])* $vis:vis enum $name:ident 122 | { 123 | $string:literal = $variant:ident, 124 | $($variants:tt)* 125 | } -> 126 | { $($decl:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* } 127 | { $($display:tt)* } { $($from_op:tt)* } 128 | ) => { 129 | $crate::__define_language!( 130 | $(#[$meta])* $vis enum $name 131 | { $($variants)* } -> 132 | { $($decl)* $variant, } 133 | { $($matches)* ($name::$variant, $name::$variant) => true, } 134 | { $($children)* $name::$variant => &[], } 135 | { $($children_mut)* $name::$variant => &mut [], } 136 | { $($display)* ($name::$variant, f) => f.write_str($string), } 137 | { $($from_op)* ($string, children) if children.is_empty() => Ok($name::$variant), } 138 | ); 139 | }; 140 | 141 | ($(#[$meta:meta])* $vis:vis enum $name:ident 142 | { 143 | $string:literal = $variant:ident ($ids:ty), 144 | $($variants:tt)* 145 | } -> 146 | { $($decl:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* } 147 | { $($display:tt)* } { $($from_op:tt)* } 148 | ) => { 149 | $crate::__define_language!( 150 | $(#[$meta])* $vis enum $name 151 | { $($variants)* } -> 152 | { $($decl)* $variant($ids), } 153 | { $($matches)* ($name::$variant(l), $name::$variant(r)) => $crate::LanguageChildren::len(l) == $crate::LanguageChildren::len(r), } 154 | { $($children)* $name::$variant(ids) => $crate::LanguageChildren::as_slice(ids), } 155 | { $($children_mut)* $name::$variant(ids) => $crate::LanguageChildren::as_mut_slice(ids), } 156 | { $($display)* ($name::$variant(..), f) => f.write_str($string), } 157 | { $($from_op)* (op, 
children) if op == $string && <$ids as $crate::LanguageChildren>::can_be_length(children.len()) => { 158 | let children = <$ids as $crate::LanguageChildren>::from_vec(children); 159 | Ok($name::$variant(children)) 160 | }, 161 | } 162 | ); 163 | }; 164 | 165 | ($(#[$meta:meta])* $vis:vis enum $name:ident 166 | { 167 | $variant:ident ($data:ty), 168 | $($variants:tt)* 169 | } -> 170 | { $($decl:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* } 171 | { $($display:tt)* } { $($from_op:tt)* } 172 | ) => { 173 | $crate::__define_language!( 174 | $(#[$meta])* $vis enum $name 175 | { $($variants)* } -> 176 | { $($decl)* $variant($data), } 177 | { $($matches)* ($name::$variant(data1), $name::$variant(data2)) => data1 == data2, } 178 | { $($children)* $name::$variant(_data) => &[], } 179 | { $($children_mut)* $name::$variant(_data) => &mut [], } 180 | { $($display)* ($name::$variant(data), f) => ::std::fmt::Display::fmt(data, f), } 181 | { $($from_op)* (op, children) if op.parse::<$data>().is_ok() && children.is_empty() => Ok($name::$variant(op.parse().unwrap())), } 182 | ); 183 | }; 184 | 185 | ($(#[$meta:meta])* $vis:vis enum $name:ident 186 | { 187 | $variant:ident ($data:ty, $ids:ty), 188 | $($variants:tt)* 189 | } -> 190 | { $($decl:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* } 191 | { $($display:tt)* } { $($from_op:tt)* } 192 | ) => { 193 | $crate::__define_language!( 194 | $(#[$meta])* $vis enum $name 195 | { $($variants)* } -> 196 | { $($decl)* $variant($data, $ids), } 197 | { $($matches)* ($name::$variant(d1, l), $name::$variant(d2, r)) => d1 == d2 && $crate::LanguageChildren::len(l) == $crate::LanguageChildren::len(r), } 198 | { $($children)* $name::$variant(_, ids) => $crate::LanguageChildren::as_slice(ids), } 199 | { $($children_mut)* $name::$variant(_, ids) => $crate::LanguageChildren::as_mut_slice(ids), } 200 | { $($display)* ($name::$variant(data, _), f) => ::std::fmt::Display::fmt(data, f), } 201 | { 
$($from_op)* (op, children) if op.parse::<$data>().is_ok() && <$ids as $crate::LanguageChildren>::can_be_length(children.len()) => { 202 | let data = op.parse::<$data>().unwrap(); 203 | let children = <$ids as $crate::LanguageChildren>::from_vec(children); 204 | Ok($name::$variant(data, children)) 205 | }, 206 | } 207 | ); 208 | }; 209 | } 210 | 211 | /** A macro to easily make [`Rewrite`]s. 212 | 213 | The `rewrite!` macro greatly simplifies creating simple, purely 214 | syntactic rewrites while also allowing more complex ones. 215 | 216 | This panics if [`Rewrite::new`](Rewrite::new()) fails. 217 | 218 | The simplest form `rewrite!(a; b => c)` creates a [`Rewrite`] 219 | with name `a`, [`Searcher`] `b`, and [`Applier`] `c`. 220 | Note that in the `b` and `c` position, the macro only accepts a single 221 | token tree (see the [macros reference][macro] for more info). 222 | In short, that means you should pass in an identifier, literal, or 223 | something surrounded by parentheses or braces. 224 | 225 | If you pass in a literal to the `b` or `c` position, the macro will 226 | try to parse it as a [`Pattern`] which implements both [`Searcher`] 227 | and [`Applier`]. 228 | 229 | The macro also accepts any number of `if ` forms at the end, 230 | where the given expression should implement [`Condition`]. 231 | For each of these, the macro will wrap the given applier in a 232 | [`ConditionalApplier`] with the given condition, with the first condition being 233 | the outermost, and the last condition being the innermost. 234 | 235 | # Example 236 | ``` 237 | # use egg::*; 238 | use std::borrow::Cow; 239 | use std::sync::Arc; 240 | define_language! 
{ 241 | enum SimpleLanguage { 242 | Num(i32), 243 | "+" = Add([Id; 2]), 244 | "-" = Sub([Id; 2]), 245 | "*" = Mul([Id; 2]), 246 | "/" = Div([Id; 2]), 247 | } 248 | } 249 | 250 | type EGraph = egg::EGraph; 251 | 252 | let mut rules: Vec> = vec![ 253 | rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), 254 | rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"), 255 | 256 | rewrite!("mul-0"; "(* ?a 0)" => "0"), 257 | 258 | rewrite!("silly"; "(* ?a 1)" => { MySillyApplier("foo") }), 259 | 260 | rewrite!("something_conditional"; 261 | "(/ ?a ?b)" => "(* ?a (/ 1 ?b))" 262 | if is_not_zero("?b")), 263 | ]; 264 | 265 | // rewrite! supports bidirectional rules too 266 | // it returns a Vec of length 2, so you need to concat 267 | rules.extend(vec![ 268 | rewrite!("add-0"; "(+ ?a 0)" <=> "?a"), 269 | rewrite!("mul-1"; "(* ?a 1)" <=> "?a"), 270 | ].concat()); 271 | 272 | #[derive(Debug)] 273 | struct MySillyApplier(&'static str); 274 | impl Applier for MySillyApplier { 275 | fn apply_one(&self, _: &mut EGraph, _: Id, _: &Subst, _: Option<&PatternAst>, _: Symbol) -> Vec { 276 | panic!() 277 | } 278 | } 279 | 280 | // This returns a function that implements Condition 281 | fn is_not_zero(var: &'static str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { 282 | let var = var.parse().unwrap(); 283 | let zero = SimpleLanguage::Num(0); 284 | // note this check is just an example, 285 | // checking for the absence of 0 is insufficient since 0 could be merged in later 286 | // see https://github.com/egraphs-good/egg/issues/297 287 | move |egraph, _, subst| !egraph[subst[var]].nodes.contains(&zero) 288 | } 289 | ``` 290 | 291 | [macro]: https://doc.rust-lang.org/stable/reference/macros-by-example.html#metavariables 292 | **/ 293 | #[macro_export] 294 | macro_rules! 
rewrite { 295 | ( 296 | $name:expr; 297 | $lhs:tt => $rhs:tt 298 | $(if $cond:expr)* 299 | ) => {{ 300 | let searcher = $crate::__rewrite!(@parse Pattern $lhs); 301 | let core_applier = $crate::__rewrite!(@parse Pattern $rhs); 302 | let applier = $crate::__rewrite!(@applier core_applier; $($cond,)*); 303 | $crate::Rewrite::new($name.to_string(), searcher, applier).unwrap() 304 | }}; 305 | ( 306 | $name:expr; 307 | $lhs:tt <=> $rhs:tt 308 | $(if $cond:expr)* 309 | ) => {{ 310 | let name = $name; 311 | let name2 = String::from(name.clone()) + "-rev"; 312 | vec![ 313 | $crate::rewrite!(name; $lhs => $rhs $(if $cond)*), 314 | $crate::rewrite!(name2; $rhs => $lhs $(if $cond)*) 315 | ] 316 | }}; 317 | } 318 | 319 | /** A macro to easily make [`Rewrite`]s using [`MultiPattern`]s. 320 | 321 | Similar to the [`rewrite!`] macro, 322 | this macro uses the form `multi_rewrite!(name; multipattern => multipattern)`. 323 | String literals will be parsed as [`MultiPattern`]s. 324 | 325 | **/ 326 | #[macro_export] 327 | macro_rules! multi_rewrite { 328 | // limited multipattern support 329 | ( 330 | $name:expr; 331 | $lhs:tt => $rhs:tt 332 | ) => {{ 333 | let searcher = $crate::__rewrite!(@parse MultiPattern $lhs); 334 | let applier = $crate::__rewrite!(@parse MultiPattern $rhs); 335 | $crate::Rewrite::new($name.to_string(), searcher, applier).unwrap() 336 | }}; 337 | } 338 | 339 | #[doc(hidden)] 340 | #[macro_export] 341 | macro_rules! __rewrite { 342 | (@parse $t:ident $rhs:literal) => { 343 | $rhs.parse::<$crate::$t<_>>().unwrap() 344 | }; 345 | (@parse $t:ident $rhs:expr) => { $rhs }; 346 | (@applier $applier:expr;) => { $applier }; 347 | (@applier $applier:expr; $cond:expr, $($conds:expr,)*) => { 348 | $crate::ConditionalApplier { 349 | condition: $cond, 350 | applier: $crate::__rewrite!(@applier $applier; $($conds,)*) 351 | } 352 | }; 353 | } 354 | 355 | #[cfg(test)] 356 | mod tests { 357 | 358 | use crate::*; 359 | 360 | define_language! 
{ 361 | enum Simple { 362 | "+" = Add([Id; 2]), 363 | "-" = Sub([Id; 2]), 364 | "*" = Mul([Id; 2]), 365 | "-" = Neg(Id), 366 | "list" = List(Box<[Id]>), 367 | "pi" = Pi, 368 | Int(i32), 369 | Var(Symbol), 370 | } 371 | } 372 | 373 | #[test] 374 | fn modify_children() { 375 | let mut add = Simple::Add([0.into(), 0.into()]); 376 | add.for_each_mut(|id| *id = 1.into()); 377 | assert_eq!(add, Simple::Add([1.into(), 1.into()])); 378 | } 379 | 380 | #[test] 381 | fn some_rewrites() { 382 | let mut rws: Vec> = vec![ 383 | // here it should parse the rhs 384 | rewrite!("rule"; "cons" => "f"), 385 | // here it should just accept the rhs without trying to parse 386 | rewrite!("rule"; "f" => { "pat".parse::>().unwrap() }), 387 | ]; 388 | rws.extend(rewrite!("two-way"; "foo" <=> "bar")); 389 | } 390 | 391 | #[test] 392 | #[should_panic(expected = "refers to unbound var ?x")] 393 | fn rewrite_simple_panic() { 394 | let _: Rewrite = rewrite!("bad"; "?a" => "?x"); 395 | } 396 | 397 | #[test] 398 | #[should_panic(expected = "refers to unbound var ?x")] 399 | fn rewrite_conditional_panic() { 400 | let x: Pattern = "?x".parse().unwrap(); 401 | let _: Rewrite = rewrite!( 402 | "bad"; "?a" => "?a" if ConditionEqual::new(x.clone(), x) 403 | ); 404 | } 405 | } 406 | -------------------------------------------------------------------------------- /src/multipattern.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | use thiserror::Error; 3 | 4 | use crate::*; 5 | 6 | /// A set of open expressions bound to variables. 7 | /// 8 | /// Multipatterns bind many expressions to variables, 9 | /// allowing for simultaneous searching or application of many terms 10 | /// constrained to the same substitution. 11 | /// 12 | /// Multipatterns are good for writing graph rewrites or datalog-style rules. 13 | /// 14 | /// You can create multipatterns via the [`MultiPattern::new`] function or the 15 | /// [`multi_rewrite!`] macro. 
16 | /// 17 | /// [`MultiPattern`] implements both [`Searcher`] and [`Applier`]. 18 | /// When searching a multipattern, the result ensures that 19 | /// patterns bound to the same variable are equivalent. 20 | /// When applying a multipattern, patterns bound to a variable occurring in the 21 | /// searcher are unioned with that e-class. 22 | /// 23 | /// Multipatterns currently do not support the explanations feature. 24 | #[derive(Debug, PartialEq, Eq, Clone)] 25 | pub struct MultiPattern { 26 | asts: Vec<(Var, PatternAst)>, 27 | program: machine::Program, 28 | } 29 | 30 | impl MultiPattern { 31 | /// Creates a new multipattern, binding the given patterns to the corresponding variables. 32 | /// 33 | /// ``` 34 | /// use egg::*; 35 | /// 36 | /// let mut egraph = EGraph::::default(); 37 | /// egraph.add_expr(&"(f a a)".parse().unwrap()); 38 | /// egraph.add_expr(&"(f a b)".parse().unwrap()); 39 | /// egraph.add_expr(&"(g a a)".parse().unwrap()); 40 | /// egraph.add_expr(&"(g a b)".parse().unwrap()); 41 | /// egraph.rebuild(); 42 | /// 43 | /// let f_pat: PatternAst = "(f ?x ?y)".parse().unwrap(); 44 | /// let g_pat: PatternAst = "(g ?x ?y)".parse().unwrap(); 45 | /// let v1: Var = "?v1".parse().unwrap(); 46 | /// let v2: Var = "?v2".parse().unwrap(); 47 | /// 48 | /// let multipattern = MultiPattern::new(vec![(v1, f_pat), (v2, g_pat)]); 49 | /// // you can also parse multipatterns 50 | /// assert_eq!(multipattern, "?v1 = (f ?x ?y), ?v2 = (g ?x ?y)".parse().unwrap()); 51 | /// 52 | /// assert_eq!(multipattern.n_matches(&egraph), 2); 53 | /// ``` 54 | pub fn new(asts: Vec<(Var, PatternAst)>) -> Self { 55 | let program = machine::Program::compile_from_multi_pat(&asts); 56 | Self { asts, program } 57 | } 58 | } 59 | 60 | #[derive(Debug, Error)] 61 | /// An error raised when parsing a [`MultiPattern`] 62 | pub enum MultiPatternParseError { 63 | /// One of the patterns in the multipattern failed to parse. 
64 | #[error(transparent)] 65 | PatternParseError(E), 66 | /// One of the clauses in the multipattern wasn't of the form `?var (= pattern)+`. 67 | #[error("Bad clause in the multipattern: {0}")] 68 | PatternAssignmentError(String), 69 | /// One of the variables failed to parse. 70 | #[error(transparent)] 71 | VariableError(::Err), 72 | } 73 | 74 | impl FromStr for MultiPattern { 75 | type Err = MultiPatternParseError< as FromStr>::Err>; 76 | 77 | fn from_str(s: &str) -> Result { 78 | use MultiPatternParseError::*; 79 | let mut asts = vec![]; 80 | for split in s.trim().split(',') { 81 | let split = split.trim(); 82 | if split.is_empty() { 83 | continue; 84 | } 85 | let mut parts = split.split('='); 86 | let vs: &str = parts 87 | .next() 88 | .ok_or_else(|| PatternAssignmentError(split.into()))?; 89 | let v: Var = vs.trim().parse().map_err(VariableError)?; 90 | let ps = parts 91 | .map(|p| p.trim().parse()) 92 | .collect::>, _>>() 93 | .map_err(PatternParseError)?; 94 | if ps.is_empty() { 95 | return Err(PatternAssignmentError(split.into())); 96 | } 97 | asts.extend(ps.into_iter().map(|p| (v, p))) 98 | } 99 | Ok(MultiPattern::new(asts)) 100 | } 101 | } 102 | 103 | impl> Searcher for MultiPattern { 104 | fn search_eclass_with_limit( 105 | &self, 106 | egraph: &EGraph, 107 | eclass: Id, 108 | limit: usize, 109 | ) -> Option> { 110 | let substs = self.program.run_with_limit(egraph, eclass, limit); 111 | if substs.is_empty() { 112 | None 113 | } else { 114 | Some(SearchMatches { 115 | eclass, 116 | substs, 117 | ast: None, 118 | }) 119 | } 120 | } 121 | 122 | fn vars(&self) -> Vec { 123 | let mut vars = vec![]; 124 | for (v, pat) in &self.asts { 125 | vars.push(*v); 126 | for n in pat.as_ref() { 127 | if let ENodeOrVar::Var(v) = n { 128 | vars.push(*v) 129 | } 130 | } 131 | } 132 | vars.sort(); 133 | vars.dedup(); 134 | vars 135 | } 136 | } 137 | 138 | impl> Applier for MultiPattern { 139 | fn apply_one( 140 | &self, 141 | _egraph: &mut EGraph, 142 | _eclass: Id, 143 | 
_subst: &Subst, 144 | _searcher_ast: Option<&PatternAst>, 145 | _rule_name: Symbol, 146 | ) -> Vec { 147 | panic!("Multipatterns do not support apply_one") 148 | } 149 | 150 | fn apply_matches( 151 | &self, 152 | egraph: &mut EGraph, 153 | matches: &[SearchMatches], 154 | _rule_name: Symbol, 155 | ) -> Vec { 156 | // TODO explanations? 157 | // the ids returned are kinda garbage 158 | let mut added = vec![]; 159 | for mat in matches { 160 | for subst in &mat.substs { 161 | let mut subst = subst.clone(); 162 | let mut id_buf = vec![]; 163 | for (i, (v, p)) in self.asts.iter().enumerate() { 164 | id_buf.resize(p.as_ref().len(), 0.into()); 165 | let id1 = crate::pattern::apply_pat(&mut id_buf, p.as_ref(), egraph, &subst); 166 | if let Some(id2) = subst.insert(*v, id1) { 167 | egraph.union(id1, id2); 168 | } 169 | if i == 0 { 170 | added.push(id1) 171 | } 172 | } 173 | } 174 | } 175 | added 176 | } 177 | 178 | fn vars(&self) -> Vec { 179 | let mut bound_vars = HashSet::default(); 180 | let mut vars = vec![]; 181 | for (bv, pat) in &self.asts { 182 | for n in pat.as_ref() { 183 | if let ENodeOrVar::Var(v) = n { 184 | // using vars that are already bound doesn't count 185 | if !bound_vars.contains(v) { 186 | vars.push(*v) 187 | } 188 | } 189 | } 190 | bound_vars.insert(bv); 191 | } 192 | vars.sort(); 193 | vars.dedup(); 194 | vars 195 | } 196 | } 197 | 198 | #[cfg(test)] 199 | mod tests { 200 | use crate::{SymbolLang as S, *}; 201 | 202 | type EGraph = crate::EGraph; 203 | 204 | impl EGraph { 205 | fn add_string(&mut self, s: &str) -> Id { 206 | self.add_expr(&s.parse().unwrap()) 207 | } 208 | } 209 | 210 | #[test] 211 | #[should_panic = "unbound var ?z"] 212 | fn bad_unbound_var() { 213 | let _: Rewrite = multi_rewrite!("foo"; "?x = (foo ?y)" => "?x = ?z"); 214 | } 215 | 216 | #[test] 217 | fn ok_unbound_var() { 218 | let _: Rewrite = multi_rewrite!("foo"; "?x = (foo ?y)" => "?z = (baz ?y), ?x = ?z"); 219 | } 220 | 221 | #[test] 222 | fn multi_patterns() { 223 | 
crate::init_logger(); 224 | let mut egraph = EGraph::default(); 225 | let _ = egraph.add_expr(&"(f a a)".parse().unwrap()); 226 | let ab = egraph.add_expr(&"(f a b)".parse().unwrap()); 227 | let ac = egraph.add_expr(&"(f a c)".parse().unwrap()); 228 | egraph.union(ab, ac); 229 | egraph.rebuild(); 230 | 231 | let n_matches = |multipattern: &str| -> usize { 232 | let mp: MultiPattern = multipattern.parse().unwrap(); 233 | mp.n_matches(&egraph) 234 | }; 235 | 236 | assert_eq!(n_matches("?x = (f a a), ?y = (f ?c b)"), 1); 237 | assert_eq!(n_matches("?x = (f a a), ?y = (f a b)"), 1); 238 | assert_eq!(n_matches("?x = (f a a), ?y = (f a a)"), 1); 239 | assert_eq!(n_matches("?x = (f ?a ?b), ?y = (f ?c ?d)"), 9); 240 | assert_eq!(n_matches("?x = (f ?a a), ?y = (f ?a b)"), 1); 241 | 242 | assert_eq!(n_matches("?x = (f a a), ?x = (f a c)"), 0); 243 | assert_eq!(n_matches("?x = (f a b), ?x = (f a c)"), 1); 244 | } 245 | 246 | #[test] 247 | fn unbound_rhs() { 248 | let mut egraph = EGraph::default(); 249 | let _x = egraph.add_expr(&"(x)".parse().unwrap()); 250 | let rules = vec![ 251 | // Rule creates y and z if they don't exist. 252 | multi_rewrite!("rule1"; "?x = (x)" => "?y = (y), ?y = (z)"), 253 | // Can't fire without above rule. 
// Equalities known in `ctx1` must be transferred to `ctx2` (because
// `(lte ctx1 ctx2)` holds), but equalities known only in `ctx2` must NOT
// leak back into `ctx1` (there is no `(lte ctx2 ctx1)`).
#[test]
fn ctx_transfer() {
    let mut egraph = EGraph::default();
    egraph.add_string("(lte ctx1 ctx2)");
    egraph.add_string("(lte ctx2 ctx2)");
    egraph.add_string("(lte ctx1 ctx1)");
    let x2 = egraph.add_string("(tag x ctx2)");
    let y2 = egraph.add_string("(tag y ctx2)");
    let z2 = egraph.add_string("(tag z ctx2)");

    let x1 = egraph.add_string("(tag x ctx1)");
    let y1 = egraph.add_string("(tag y ctx1)");
    // This used to be "(tag z ctx2)", which is the very same term as `z2`;
    // the negative assertions below then never looked at a `ctx1` term.
    let z1 = egraph.add_string("(tag z ctx1)");
    egraph.union(x1, y1);
    egraph.union(y2, z2);
    let rules = vec![multi_rewrite!("context-transfer";
                 "?x = (tag ?a ?ctx1) = (tag ?b ?ctx1),
                  ?t = (lte ?ctx1 ?ctx2),
                  ?a1 = (tag ?a ?ctx2),
                  ?b1 = (tag ?b ?ctx2)"
                 =>
                 "?a1 = ?b1")];
    let runner = Runner::default().with_egraph(egraph).run(&rules);

    // the unions made by hand are still there
    assert_eq!(runner.egraph.find(x1), runner.egraph.find(y1));
    assert_eq!(runner.egraph.find(y2), runner.egraph.find(z2));

    // x = y was transferred from ctx1 to ctx2
    assert_eq!(runner.egraph.find(x2), runner.egraph.find(y2));
    assert_eq!(runner.egraph.find(x2), runner.egraph.find(z2));

    // y = z was NOT transferred from ctx2 back to ctx1
    assert_ne!(runner.egraph.find(y1), runner.egraph.find(z1));
    assert_ne!(runner.egraph.find(x1), runner.egraph.find(z1));
}
std::{convert::TryFrom, str::FromStr}; 6 | 7 | use thiserror::Error; 8 | 9 | use crate::*; 10 | 11 | /// A pattern that can function as either a [`Searcher`] or [`Applier`]. 12 | /// 13 | /// A [`Pattern`] is essentially a for-all quantified expression with 14 | /// [`Var`]s as the variables (in the logical sense). 15 | /// 16 | /// When creating a [`Rewrite`], the most common thing to use as either 17 | /// the left hand side (the [`Searcher`]) or the right hand side 18 | /// (the [`Applier`]) is a [`Pattern`]. 19 | /// 20 | /// As a [`Searcher`], a [`Pattern`] does the intuitive 21 | /// thing. 22 | /// Here is a somewhat verbose formal-ish statement: 23 | /// Searching for a pattern in an egraph yields substitutions 24 | /// ([`Subst`]s) _s_ such that, for any _s'_—where instead of 25 | /// mapping a variable to an eclass as _s_ does, _s'_ maps 26 | /// a variable to an arbitrary expression represented by that 27 | /// eclass—_p[s']_ (the pattern under substitution _s'_) is also 28 | /// represented by the egraph. 29 | /// 30 | /// As an [`Applier`], a [`Pattern`] performs the given substitution 31 | /// and adds the result to the [`EGraph`]. 32 | /// 33 | /// Importantly, [`Pattern`] implements [`FromStr`] if the 34 | /// [`Language`] does. 35 | /// This is probably how you'll create most [`Pattern`]s. 36 | /// 37 | /// ``` 38 | /// use egg::*; 39 | /// define_language!
{ 40 | /// enum Math { 41 | /// Num(i32), 42 | /// "+" = Add([Id; 2]), 43 | /// } 44 | /// } 45 | /// 46 | /// let mut egraph = EGraph::::default(); 47 | /// let a11 = egraph.add_expr(&"(+ 1 1)".parse().unwrap()); 48 | /// let a22 = egraph.add_expr(&"(+ 2 2)".parse().unwrap()); 49 | /// 50 | /// // use Var syntax (leading question mark) to get a 51 | /// // variable in the Pattern 52 | /// let same_add: Pattern = "(+ ?a ?a)".parse().unwrap(); 53 | /// 54 | /// // Rebuild before searching 55 | /// egraph.rebuild(); 56 | /// 57 | /// // This is the search method from the Searcher trait 58 | /// let matches = same_add.search(&egraph); 59 | /// let matched_eclasses: Vec = matches.iter().map(|m| m.eclass).collect(); 60 | /// assert_eq!(matched_eclasses, vec![a11, a22]); 61 | /// ``` 62 | /// 63 | /// [`FromStr`]: std::str::FromStr 64 | #[derive(Debug, PartialEq, Eq, Clone)] 65 | pub struct Pattern { 66 | /// The actual pattern as a [`RecExpr`] 67 | pub ast: PatternAst, 68 | program: machine::Program, 69 | } 70 | 71 | /// A [`RecExpr`] that represents a 72 | /// [`Pattern`]. 73 | pub type PatternAst = RecExpr>; 74 | 75 | impl PatternAst { 76 | /// Returns a new `PatternAst` with the variables renames canonically 77 | pub fn alpha_rename(&self) -> Self { 78 | let mut vars = HashMap::::default(); 79 | let mut new = PatternAst::default(); 80 | 81 | fn mkvar(i: usize) -> Var { 82 | let vs = &["?x", "?y", "?z", "?w"]; 83 | match vs.get(i) { 84 | Some(v) => v.parse().unwrap(), 85 | None => format!("?v{}", i - vs.len()).parse().unwrap(), 86 | } 87 | } 88 | 89 | for n in self.as_ref() { 90 | new.add(match n { 91 | ENodeOrVar::ENode(_) => n.clone(), 92 | ENodeOrVar::Var(v) => { 93 | let i = vars.len(); 94 | ENodeOrVar::Var(*vars.entry(*v).or_insert_with(|| mkvar(i))) 95 | } 96 | }); 97 | } 98 | 99 | new 100 | } 101 | } 102 | 103 | impl Pattern { 104 | /// Creates a new pattern from the given pattern ast. 
105 | pub fn new(ast: PatternAst) -> Self { 106 | let ast = ast.compact(); 107 | let program = machine::Program::compile_from_pat(&ast); 108 | Pattern { ast, program } 109 | } 110 | 111 | /// Returns a list of the [`Var`]s in this pattern. 112 | pub fn vars(&self) -> Vec { 113 | let mut vars = vec![]; 114 | for n in self.ast.as_ref() { 115 | if let ENodeOrVar::Var(v) = n { 116 | if !vars.contains(v) { 117 | vars.push(*v) 118 | } 119 | } 120 | } 121 | vars 122 | } 123 | } 124 | 125 | impl Pattern { 126 | /// Pretty print this pattern as a sexp with the given width 127 | pub fn pretty(&self, width: usize) -> String { 128 | self.ast.pretty(width) 129 | } 130 | } 131 | 132 | /// The language of [`Pattern`]s. 133 | /// 134 | #[derive(Debug, Hash, PartialEq, Eq, Clone, PartialOrd, Ord)] 135 | pub enum ENodeOrVar { 136 | /// An enode from the underlying [`Language`] 137 | ENode(L), 138 | /// A pattern variable 139 | Var(Var), 140 | } 141 | 142 | /// The discriminant for the language of [`Pattern`]s. 
143 | #[derive(Debug, Hash, PartialEq, Eq, Clone)] 144 | pub enum ENodeOrVarDiscriminant { 145 | ENode(L::Discriminant), 146 | Var(Var), 147 | } 148 | 149 | impl Language for ENodeOrVar { 150 | type Discriminant = ENodeOrVarDiscriminant; 151 | 152 | #[inline(always)] 153 | fn discriminant(&self) -> Self::Discriminant { 154 | match self { 155 | ENodeOrVar::ENode(n) => ENodeOrVarDiscriminant::ENode(n.discriminant()), 156 | ENodeOrVar::Var(v) => ENodeOrVarDiscriminant::Var(*v), 157 | } 158 | } 159 | 160 | fn matches(&self, _other: &Self) -> bool { 161 | panic!("Should never call this") 162 | } 163 | 164 | fn children(&self) -> &[Id] { 165 | match self { 166 | ENodeOrVar::ENode(n) => n.children(), 167 | ENodeOrVar::Var(_) => &[], 168 | } 169 | } 170 | 171 | fn children_mut(&mut self) -> &mut [Id] { 172 | match self { 173 | ENodeOrVar::ENode(n) => n.children_mut(), 174 | ENodeOrVar::Var(_) => &mut [], 175 | } 176 | } 177 | } 178 | 179 | impl Display for ENodeOrVar { 180 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 181 | match self { 182 | Self::ENode(node) => Display::fmt(node, f), 183 | Self::Var(var) => Display::fmt(var, f), 184 | } 185 | } 186 | } 187 | 188 | #[derive(Debug, Error)] 189 | pub enum ENodeOrVarParseError { 190 | #[error(transparent)] 191 | BadVar(::Err), 192 | 193 | #[error("tried to parse pattern variable {0:?} as an operator")] 194 | UnexpectedVar(String), 195 | 196 | #[error(transparent)] 197 | BadOp(E), 198 | } 199 | 200 | impl FromOp for ENodeOrVar { 201 | type Error = ENodeOrVarParseError; 202 | 203 | fn from_op(op: &str, children: Vec) -> Result { 204 | use ENodeOrVarParseError::*; 205 | 206 | if op.starts_with('?') && op.len() > 1 { 207 | if children.is_empty() { 208 | op.parse().map(Self::Var).map_err(BadVar) 209 | } else { 210 | Err(UnexpectedVar(op.to_owned())) 211 | } 212 | } else { 213 | L::from_op(op, children).map(Self::ENode).map_err(BadOp) 214 | } 215 | } 216 | } 217 | 218 | impl std::str::FromStr for Pattern { 219 | type Err 
= RecExprParseError>; 220 | 221 | fn from_str(s: &str) -> Result { 222 | PatternAst::from_str(s).map(Self::from) 223 | } 224 | } 225 | 226 | impl<'a, L: Language> From<&'a [L]> for Pattern { 227 | fn from(expr: &'a [L]) -> Self { 228 | let nodes: Vec<_> = expr.iter().cloned().map(ENodeOrVar::ENode).collect(); 229 | let ast = RecExpr::from(nodes); 230 | Self::new(ast) 231 | } 232 | } 233 | 234 | impl From<&RecExpr> for Pattern { 235 | fn from(expr: &RecExpr) -> Self { 236 | Self::from(expr.as_ref()) 237 | } 238 | } 239 | 240 | impl From> for Pattern { 241 | fn from(ast: PatternAst) -> Self { 242 | Self::new(ast) 243 | } 244 | } 245 | 246 | impl TryFrom> for RecExpr { 247 | type Error = Var; 248 | fn try_from(pat: Pattern) -> Result { 249 | let nodes = pat.ast.as_ref().iter().cloned(); 250 | let ns: Result, _> = nodes 251 | .map(|n| match n { 252 | ENodeOrVar::ENode(n) => Ok(n), 253 | ENodeOrVar::Var(v) => Err(v), 254 | }) 255 | .collect(); 256 | ns.map(RecExpr::from) 257 | } 258 | } 259 | 260 | impl Display for Pattern { 261 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 262 | Display::fmt(&self.ast, f) 263 | } 264 | } 265 | 266 | /// The result of searching a [`Searcher`] over one eclass. 267 | /// 268 | /// Note that one [`SearchMatches`] can contain many found 269 | /// substitutions. So taking the length of a list of [`SearchMatches`] 270 | /// tells you how many eclasses something was matched in, _not_ how 271 | /// many matches were found total. 272 | /// 273 | #[derive(Debug)] 274 | pub struct SearchMatches<'a, L: Language> { 275 | /// The eclass id that these matches were found in. 276 | pub eclass: Id, 277 | /// The substitutions for each match. 278 | pub substs: Vec, 279 | /// Optionally, an ast for the matches used in proof production. 
280 | pub ast: Option>>, 281 | } 282 | 283 | impl> Searcher for Pattern { 284 | fn get_pattern_ast(&self) -> Option<&PatternAst> { 285 | Some(&self.ast) 286 | } 287 | 288 | fn search_with_limit(&self, egraph: &EGraph, limit: usize) -> Vec> { 289 | match self.ast.as_ref().last().unwrap() { 290 | ENodeOrVar::ENode(e) => { 291 | let key = e.discriminant(); 292 | match egraph.classes_by_op.get(&key) { 293 | None => vec![], 294 | Some(ids) => rewrite::search_eclasses_with_limit( 295 | self, 296 | egraph, 297 | ids.iter().cloned(), 298 | limit, 299 | ), 300 | } 301 | } 302 | ENodeOrVar::Var(_) => rewrite::search_eclasses_with_limit( 303 | self, 304 | egraph, 305 | egraph.classes().map(|e| e.id), 306 | limit, 307 | ), 308 | } 309 | } 310 | 311 | fn search_eclass_with_limit( 312 | &self, 313 | egraph: &EGraph, 314 | eclass: Id, 315 | limit: usize, 316 | ) -> Option> { 317 | let substs = self.program.run_with_limit(egraph, eclass, limit); 318 | if substs.is_empty() { 319 | None 320 | } else { 321 | let ast = Some(Cow::Borrowed(&self.ast)); 322 | Some(SearchMatches { 323 | eclass, 324 | substs, 325 | ast, 326 | }) 327 | } 328 | } 329 | 330 | fn vars(&self) -> Vec { 331 | Pattern::vars(self) 332 | } 333 | } 334 | 335 | impl Applier for Pattern 336 | where 337 | L: Language, 338 | A: Analysis, 339 | { 340 | fn get_pattern_ast(&self) -> Option<&PatternAst> { 341 | Some(&self.ast) 342 | } 343 | 344 | fn apply_matches( 345 | &self, 346 | egraph: &mut EGraph, 347 | matches: &[SearchMatches], 348 | rule_name: Symbol, 349 | ) -> Vec { 350 | let mut added = vec![]; 351 | let ast = self.ast.as_ref(); 352 | let mut id_buf = vec![0.into(); ast.len()]; 353 | for mat in matches { 354 | let sast = mat.ast.as_ref().map(|cow| cow.as_ref()); 355 | for subst in &mat.substs { 356 | let did_something; 357 | let id; 358 | if egraph.are_explanations_enabled() { 359 | let (id_temp, did_something_temp) = 360 | egraph.union_instantiations(sast.unwrap(), &self.ast, subst, rule_name); 361 | 
did_something = did_something_temp; 362 | id = id_temp; 363 | } else { 364 | id = apply_pat(&mut id_buf, ast, egraph, subst); 365 | did_something = egraph.union(id, mat.eclass); 366 | } 367 | 368 | if did_something { 369 | added.push(id) 370 | } 371 | } 372 | } 373 | added 374 | } 375 | 376 | fn apply_one( 377 | &self, 378 | egraph: &mut EGraph, 379 | eclass: Id, 380 | subst: &Subst, 381 | searcher_ast: Option<&PatternAst>, 382 | rule_name: Symbol, 383 | ) -> Vec { 384 | let ast = self.ast.as_ref(); 385 | let mut id_buf = vec![0.into(); ast.len()]; 386 | let id = apply_pat(&mut id_buf, ast, egraph, subst); 387 | 388 | if let Some(ast) = searcher_ast { 389 | let (from, did_something) = 390 | egraph.union_instantiations(ast, &self.ast, subst, rule_name); 391 | if did_something { 392 | vec![from] 393 | } else { 394 | vec![] 395 | } 396 | } else if egraph.union(eclass, id) { 397 | vec![eclass] 398 | } else { 399 | vec![] 400 | } 401 | } 402 | 403 | fn vars(&self) -> Vec { 404 | Pattern::vars(self) 405 | } 406 | } 407 | 408 | pub(crate) fn apply_pat>( 409 | ids: &mut [Id], 410 | pat: &[ENodeOrVar], 411 | egraph: &mut EGraph, 412 | subst: &Subst, 413 | ) -> Id { 414 | debug_assert_eq!(pat.len(), ids.len()); 415 | trace!("apply_rec {:2?} {:?}", pat, subst); 416 | 417 | for (i, pat_node) in pat.iter().enumerate() { 418 | let id = match pat_node { 419 | ENodeOrVar::Var(w) => subst[*w], 420 | ENodeOrVar::ENode(e) => { 421 | let n = e.clone().map_children(|child| ids[usize::from(child)]); 422 | trace!("adding: {:?}", n); 423 | egraph.add(n) 424 | } 425 | }; 426 | ids[i] = id; 427 | } 428 | 429 | *ids.last().unwrap() 430 | } 431 | 432 | #[cfg(test)] 433 | mod tests { 434 | 435 | use crate::{SymbolLang as S, *}; 436 | 437 | type EGraph = crate::EGraph; 438 | 439 | #[test] 440 | fn simple_match() { 441 | crate::init_logger(); 442 | let mut egraph = EGraph::default(); 443 | 444 | let (plus_id, _) = egraph.union_instantiations( 445 | &"(+ x y)".parse().unwrap(), 446 | &"(+ z 
// Patterns that mention a variable more than once only match when every
// occurrence is bound to the same eclass.
#[test]
fn nonlinear_patterns() {
    crate::init_logger();
    let mut egraph = EGraph::default();
    for expr in [
        "(f a a)",
        "(f a (g a))))",
        "(f a (g b))))",
        "(h (foo a b) 0 1)",
        "(h (foo a b) 1 0)",
        "(h (foo a b) 0 0)",
    ] {
        egraph.add_expr(&expr.parse().unwrap());
    }
    egraph.rebuild();

    let n_matches = |s: &str| -> usize {
        let pattern: Pattern<S> = s.parse().unwrap();
        pattern.n_matches(&egraph)
    };

    assert_eq!(n_matches("(f ?x ?y)"), 3);
    assert_eq!(n_matches("(f ?x ?x)"), 1);
    assert_eq!(n_matches("(f ?x (g ?y))))"), 2);
    assert_eq!(n_matches("(f ?x (g ?x))))"), 1);
    assert_eq!(n_matches("(h ?x 0 0)"), 1);
}
(+ 5 6)))))".parse().unwrap(); 506 | let rules: Vec> = vec![ 507 | rewrite!("comm"; "(+ ?x ?y)" => "(+ ?y ?x)"), 508 | rewrite!("assoc"; "(+ ?x (+ ?y ?z))" => "(+ (+ ?x ?y) ?z)"), 509 | ]; 510 | let runner = Runner::default().with_expr(init_expr).run(&rules); 511 | let egraph = &runner.egraph; 512 | 513 | let len = |m: &Vec>| -> usize { m.iter().map(|m| m.substs.len()).sum() }; 514 | 515 | let pat = &"(+ ?x (+ ?y ?z))".parse::>().unwrap(); 516 | let m = pat.search(egraph); 517 | let match_size = 2100; 518 | assert_eq!(len(&m), match_size); 519 | 520 | for limit in [1, 10, 100, 1000, 10000] { 521 | let m = pat.search_with_limit(egraph, limit); 522 | assert_eq!(len(&m), usize::min(limit, match_size)); 523 | } 524 | 525 | let id = egraph.lookup_expr(init_expr).unwrap(); 526 | let m = pat.search_eclass(egraph, id).unwrap(); 527 | let match_size = 540; 528 | assert_eq!(m.substs.len(), match_size); 529 | 530 | for limit in [1, 10, 100, 1000] { 531 | let m1 = pat.search_eclass_with_limit(egraph, id, limit).unwrap(); 532 | assert_eq!(m1.substs.len(), usize::min(limit, match_size)); 533 | } 534 | } 535 | } 536 | -------------------------------------------------------------------------------- /src/subst.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | use std::str::FromStr; 3 | 4 | use crate::*; 5 | use fmt::{Debug, Display, Formatter}; 6 | use thiserror::Error; 7 | 8 | /// A variable for use in [`Pattern`]s or [`Subst`]s. 9 | /// 10 | /// This implements [`FromStr`], and will only parse if it has a 11 | /// leading `?`. 12 | /// 13 | /// [`FromStr`]: std::str::FromStr 14 | #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] 15 | pub struct Var(VarInner); 16 | 17 | impl Var { 18 | /// Create a new variable from a u32. 19 | /// 20 | /// You can also use special syntax `?#3`, `?#42` to denote a numeric variable. 
21 | /// These avoid some symbol interning, and can also be created manually from 22 | /// using this function or the `From` impl. 23 | /// 24 | /// ```rust 25 | /// # use egg::*; 26 | /// assert_eq!(Var::from(12), "?#12".parse().unwrap()); 27 | /// assert_eq!(Var::from_u32(12), "?#12".parse().unwrap()); 28 | /// ``` 29 | pub fn from_u32(num: u32) -> Self { 30 | Var(VarInner::Num(num)) 31 | } 32 | 33 | /// If this variable was created from a u32, get it back out. 34 | pub fn as_u32(&self) -> Option { 35 | match self.0 { 36 | VarInner::Num(num) => Some(num), 37 | _ => None, 38 | } 39 | } 40 | } 41 | 42 | #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] 43 | enum VarInner { 44 | Sym(Symbol), 45 | Num(u32), 46 | } 47 | 48 | #[derive(Debug, Error)] 49 | pub enum VarParseError { 50 | #[error("pattern variable {0:?} should have a leading question mark")] 51 | MissingQuestionMark(String), 52 | #[error("number pattern variable {0:?} was malformed")] 53 | BadNumber(String), 54 | } 55 | 56 | impl FromStr for Var { 57 | type Err = VarParseError; 58 | 59 | fn from_str(s: &str) -> Result { 60 | use VarParseError::*; 61 | 62 | match s.as_bytes() { 63 | [b'?', b'#', ..] => s[2..] 64 | .parse() 65 | .map(|num| Var(VarInner::Num(num))) 66 | .map_err(|_| BadNumber(s.to_owned())), 67 | [b'?', ..] 
if s.len() > 1 => Ok(Var(VarInner::Sym(Symbol::from(s)))), 68 | _ => Err(MissingQuestionMark(s.to_owned())), 69 | } 70 | } 71 | } 72 | 73 | impl Display for Var { 74 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 75 | match self.0 { 76 | VarInner::Sym(sym) => write!(f, "{}", sym), 77 | VarInner::Num(num) => write!(f, "?#{}", num), 78 | } 79 | } 80 | } 81 | 82 | impl Debug for Var { 83 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 84 | match self.0 { 85 | VarInner::Sym(sym) => write!(f, "{:?}", sym), 86 | VarInner::Num(num) => write!(f, "?#{}", num), 87 | } 88 | } 89 | } 90 | 91 | impl From for Var { 92 | fn from(num: u32) -> Self { 93 | Var(VarInner::Num(num)) 94 | } 95 | } 96 | 97 | /// A substitution mapping [`Var`]s to eclass [`Id`]s. 98 | /// 99 | #[derive(Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] 100 | pub struct Subst { 101 | pub(crate) vec: smallvec::SmallVec<[(Var, Id); 3]>, 102 | } 103 | 104 | impl Subst { 105 | /// Create a `Subst` with the given initial capacity 106 | pub fn with_capacity(capacity: usize) -> Self { 107 | Self { 108 | vec: smallvec::SmallVec::with_capacity(capacity), 109 | } 110 | } 111 | 112 | /// Insert something, returning the old `Id` if present. 113 | pub fn insert(&mut self, var: Var, id: Id) -> Option { 114 | for pair in &mut self.vec { 115 | if pair.0 == var { 116 | return Some(std::mem::replace(&mut pair.1, id)); 117 | } 118 | } 119 | self.vec.push((var, id)); 120 | None 121 | } 122 | 123 | /// Retrieve a `Var`, returning `None` if not present. 
124 | #[inline(never)] 125 | pub fn get(&self, var: Var) -> Option<&Id> { 126 | self.vec 127 | .iter() 128 | .find_map(|(v, id)| if *v == var { Some(id) } else { None }) 129 | } 130 | } 131 | 132 | impl std::ops::Index for Subst { 133 | type Output = Id; 134 | 135 | fn index(&self, var: Var) -> &Self::Output { 136 | match self.get(var) { 137 | Some(id) => id, 138 | None => panic!("Var '{}={}' not found in {:?}", var, var, self), 139 | } 140 | } 141 | } 142 | 143 | impl Debug for Subst { 144 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 145 | let len = self.vec.len(); 146 | write!(f, "{{")?; 147 | for i in 0..len { 148 | let (var, id) = &self.vec[i]; 149 | write!(f, "{}: {}", var, id)?; 150 | if i < len - 1 { 151 | write!(f, ", ")?; 152 | } 153 | } 154 | write!(f, "}}") 155 | } 156 | } 157 | 158 | #[cfg(test)] 159 | mod tests { 160 | use super::*; 161 | 162 | #[test] 163 | fn var_parse() { 164 | assert_eq!(Var::from_str("?a").unwrap().to_string(), "?a"); 165 | assert_eq!(Var::from_str("?abc 123").unwrap().to_string(), "?abc 123"); 166 | assert!(Var::from_str("a").is_err()); 167 | assert!(Var::from_str("a?").is_err()); 168 | assert!(Var::from_str("?").is_err()); 169 | assert!(Var::from_str("?#").is_err()); 170 | assert!(Var::from_str("?#foo").is_err()); 171 | 172 | // numeric vars 173 | assert_eq!(Var::from_str("?#0").unwrap(), Var(VarInner::Num(0))); 174 | assert_eq!(Var::from_str("?#010").unwrap(), Var(VarInner::Num(10))); 175 | assert_eq!( 176 | Var::from_str("?#10").unwrap(), 177 | Var::from_str("?#0010").unwrap() 178 | ); 179 | assert_eq!(Var::from_str("?#010").unwrap(), Var(VarInner::Num(10))); 180 | } 181 | } 182 | -------------------------------------------------------------------------------- /src/test.rs: -------------------------------------------------------------------------------- 1 | /*! Utilities for testing / benchmarking egg. 2 | 3 | These are not considered part of the public api. 
/// Reads the environment variable `s` and parses it as a `T`.
///
/// Returns `None` when the variable is unset or set to the empty string.
///
/// # Panics
///
/// Panics when the value is not valid unicode or does not parse as `T`.
pub fn env_var<T>(s: &str) -> Option<T>
where
    T: std::str::FromStr,
    T::Err: std::fmt::Debug,
{
    use std::env::VarError;

    let raw = match std::env::var(s) {
        Ok(value) => value,
        Err(VarError::NotPresent) => return None,
        Err(VarError::NotUnicode(_)) => panic!("Environment variable {} isn't unicode", s),
    };

    // an empty value is treated like an unset variable
    if raw.is_empty() {
        return None;
    }

    match raw.parse::<T>() {
        Ok(parsed) => Some(parsed),
        Err(err) => panic!("Couldn't parse environment variable {}={}, {:?}", s, raw, err),
    }
}
We can't 63 | // use egraph.find_expr(start) because it may have been pruned 64 | // away 65 | let id = runner.egraph.find(*runner.roots.last().unwrap()); 66 | 67 | if check_fn.is_none() { 68 | let goals = goals.to_vec(); 69 | runner = runner.with_hook(move |r| { 70 | if goals 71 | .iter() 72 | .all(|g: &Pattern<_>| g.search_eclass(&r.egraph, id).is_some()) 73 | { 74 | Err("Proved all goals".into()) 75 | } else { 76 | Ok(()) 77 | } 78 | }); 79 | } 80 | let mut runner = runner.run(rules); 81 | 82 | if should_check { 83 | let report = runner.report(); 84 | println!("{report}"); 85 | runner.egraph.check_goals(id, goals); 86 | 87 | if let Some(filename) = env_var::("EGG_BENCH_CSV") { 88 | let mut file = File::options() 89 | .create(true) 90 | .append(true) 91 | .open(&filename) 92 | .unwrap_or_else(|_| panic!("Couldn't open {:?}", filename)); 93 | writeln!(file, "{},{}", name, runner.report().total_time).unwrap(); 94 | } 95 | 96 | if runner.egraph.are_explanations_enabled() { 97 | for goal in goals { 98 | let matches = goal.search_eclass(&runner.egraph, id).unwrap(); 99 | let subst = matches.substs[0].clone(); 100 | // don't optimize the length for the first egraph 101 | runner = runner.without_explanation_length_optimization(); 102 | let mut explained = runner.explain_matches(&start, &goal.ast, &subst); 103 | explained.get_string_with_let(); 104 | let flattened = explained.make_flat_explanation().clone(); 105 | let vanilla_len = flattened.len(); 106 | explained.check_proof(rules); 107 | assert!(explained.get_tree_size() > Saturating(0)); 108 | 109 | runner = runner.with_explanation_length_optimization(); 110 | let mut explained_short = runner.explain_matches(&start, &goal.ast, &subst); 111 | explained_short.get_string_with_let(); 112 | let short_len = explained_short.get_flat_strings().len(); 113 | assert!(short_len <= vanilla_len); 114 | assert!(explained_short.get_tree_size() > Saturating(0)); 115 | explained_short.check_proof(rules); 116 | } 117 | } 118 | 119 | if 
let Some(check_fn) = check_fn { 120 | check_fn(runner) 121 | } 122 | } 123 | } 124 | 125 | fn percentile(k: f64, data: &[u128]) -> u128 { 126 | // assumes data is sorted 127 | assert!((0.0..=1.0).contains(&k)); 128 | let i = (data.len() as f64 * k) as usize; 129 | let i = i.min(data.len() - 1); 130 | data[i] 131 | } 132 | 133 | pub fn bench_egraph( 134 | _name: &str, 135 | rules: Vec>, 136 | exprs: &[&str], 137 | extra_patterns: &[&str], 138 | ) -> EGraph 139 | where 140 | L: Language + FromOp + 'static + Display, 141 | N: Analysis + Default + 'static, 142 | { 143 | let mut patterns: Vec> = vec![]; 144 | for rule in &rules { 145 | if let Some(ast) = rule.searcher.get_pattern_ast() { 146 | patterns.push(ast.alpha_rename().into()) 147 | } 148 | if let Some(ast) = rule.applier.get_pattern_ast() { 149 | patterns.push(ast.alpha_rename().into()) 150 | } 151 | } 152 | for extra in extra_patterns { 153 | let p: Pattern = extra.parse().unwrap(); 154 | patterns.push(p.ast.alpha_rename().into()); 155 | } 156 | 157 | eprintln!("{} patterns", patterns.len()); 158 | 159 | patterns.retain(|p| p.ast.as_ref().len() > 1); 160 | patterns.sort_by_key(|p| p.to_string()); 161 | patterns.dedup(); 162 | patterns.sort_by_key(|p| p.ast.as_ref().len()); 163 | 164 | let iter_limit = env_var("EGG_ITER_LIMIT").unwrap_or(1); 165 | let node_limit = env_var("EGG_NODE_LIMIT").unwrap_or(1_000_000); 166 | let time_limit = env_var("EGG_TIME_LIMIT").unwrap_or(1000); 167 | let n_samples = env_var("EGG_SAMPLES").unwrap_or(100); 168 | eprintln!("Benching {} samples", n_samples); 169 | eprintln!( 170 | "Limits: {} iters, {} nodes, {} seconds", 171 | iter_limit, node_limit, time_limit 172 | ); 173 | 174 | let mut runner = Runner::default() 175 | .with_scheduler(SimpleScheduler) 176 | .with_hook(move |runner| { 177 | let n_nodes = runner.egraph.total_number_of_nodes(); 178 | eprintln!("Iter {}, {} nodes", runner.iterations.len(), n_nodes); 179 | if n_nodes > node_limit { 180 | Err("Bench stopped".into()) 181 
| } else { 182 | Ok(()) 183 | } 184 | }) 185 | .with_iter_limit(iter_limit) 186 | .with_node_limit(node_limit) 187 | .with_time_limit(Duration::from_secs(time_limit)); 188 | 189 | for expr in exprs { 190 | runner = runner.with_expr(&expr.parse().unwrap()); 191 | } 192 | 193 | let runner = runner.run(&rules); 194 | eprintln!("{}", runner.report()); 195 | let egraph = runner.egraph; 196 | 197 | let get_len = |pat: &Pattern| pat.to_string().len(); 198 | let max_width = patterns.iter().map(get_len).max().unwrap_or(0); 199 | for pat in &patterns { 200 | let mut times: Vec = (0..n_samples) 201 | .map(|_| { 202 | let start = Instant::now(); 203 | let matches = pat.search(&egraph); 204 | let time = start.elapsed(); 205 | let _n_results = matches.iter().map(|m| m.substs.len()).sum::(); 206 | time.as_nanos() 207 | }) 208 | .collect(); 209 | times.sort_unstable(); 210 | 211 | println!( 212 | "test {name:10} ns/iter (+/- {iqr})", 213 | name = pat.to_string().replace(' ', "_"), 214 | width = max_width, 215 | time = percentile(0.05, ×), 216 | iqr = percentile(0.75, ×) - percentile(0.25, ×), 217 | ); 218 | } 219 | 220 | egraph 221 | } 222 | 223 | /// Utility to make a test proving expressions equivalent 224 | /// 225 | /// # Example 226 | /// 227 | /// ``` 228 | /// # use egg::*; 229 | /// egg::test_fn! 
{ 230 | /// // name of the generated test function 231 | /// my_test_name, 232 | /// // the rules to use 233 | /// [ 234 | /// rewrite!("my_silly_rewrite"; "(foo ?a)" => "(bar ?a)"), 235 | /// rewrite!("my_other_rewrite"; "(bar ?a)" => "(baz ?a)"), 236 | /// ], 237 | /// // the `runner = ...` is optional 238 | /// // if included, this must come right after the rules 239 | /// runner = Runner::::default(), 240 | /// // the initial expression 241 | /// "(foo 1)" => 242 | /// // 1 or more goal expressions, all of which will be checked to be 243 | /// // equivalent to the initial one 244 | /// "(bar 1)", 245 | /// "(baz 1)", 246 | /// } 247 | /// ``` 248 | #[macro_export] 249 | macro_rules! test_fn { 250 | ( 251 | $(#[$meta:meta])* 252 | $name:ident, $rules:expr, 253 | $(runner = $runner:expr,)? 254 | $start:literal 255 | => 256 | $($goal:literal),+ $(,)? 257 | $(@check $check_fn:expr)? 258 | ) => { 259 | 260 | $(#[$meta])* 261 | #[test] 262 | pub fn $name() { 263 | // NOTE this is no longer needed, we always check 264 | let check = true; 265 | $crate::test::test_runner( 266 | stringify!($name), 267 | None $(.or(Some($runner)))?, 268 | &$rules, 269 | $start.parse().unwrap(), 270 | &[$( $goal.parse().unwrap() ),+], 271 | None $(.or(Some($check_fn)))?, 272 | check, 273 | ) 274 | }}; 275 | } 276 | -------------------------------------------------------------------------------- /src/tutorials/_02_getting_started.rs: -------------------------------------------------------------------------------- 1 | // -*- mode: markdown; markdown-fontify-code-block-default-mode: rustic-mode; evil-shift-width: 2; -*- 2 | /*! 3 | 4 | # My first `egg` 🐣 5 | 6 | This tutorial is aimed at getting you up and running with `egg`, 7 | even if you have little Rust experience. 8 | If you haven't heard about e-graphs, you may want to read the 9 | [background tutorial](super::_01_background). 10 | If you do have prior Rust experience, you may want to skim around in this section. 
11 | 12 | ## Getting started with Rust 13 | 14 | [Rust](https://rust-lang.org) 15 | is one of the reasons why `egg` 16 | is fast (systems programming + optimizing compiler) and flexible (generics and traits). 17 | 18 | The Rust folks have put together a great, free [book](https://doc.rust-lang.org/stable/book/) for learning Rust, 19 | and there are a bunch of other fantastic resources collected on the 20 | ["Learn" page of the Rust site](https://www.rust-lang.org/learn). 21 | This tutorial is no replacement for those, 22 | but instead it aims to get you up and running as fast as possible. 23 | 24 | First, 25 | [install](https://www.rust-lang.org/tools/install) Rust 26 | and let's [create a project](https://doc.rust-lang.org/cargo/getting-started/first-steps.html) 27 | with `cargo`, Rust's package management and build tool: `cargo new my-first-egg`.[^lib] 28 | 29 | [^lib]: By default `cargo` will create a binary project. 30 | If you are just getting started with Rust, 31 | it might be easier to stick with a binary project, 32 | just put all your code in `main`, and use `cargo run`. 33 | Library projects (`cargo new --lib my-first-egg`) 34 | can be easier to build on once you want to start writing tests. 35 | 36 | Now we can add `egg` as a project dependency by adding a line to `Cargo.toml`: 37 | ```toml 38 | [dependencies] 39 | egg = "0.9.5" 40 | ``` 41 | 42 | All of the code samples below work, but you'll have to `use` the relevant types. 43 | You can just bring them all in with a `use egg::*;` at the top of your file. 44 | 45 | ## Now you're speaking my [`Language`] 46 | 47 | [`EGraph`]s (and almost everything else in this crate) are 48 | parameterized over the [`Language`] given by the user. 49 | While `egg` supports the ability to easily create your own [`Language`], 50 | we will instead start with the provided [`SymbolLang`]. 51 | 52 | [`Language`] is a trait, 53 | and values of types that implement [`Language`] are e-nodes. 
54 | An e-node may have any number of children, which are [`Id`]s. 55 | An [`Id`] is basically just a number that `egg` uses to coordinate what children 56 | an e-node is associated with. 57 | In an [`EGraph`], e-node children refer to e-classes. 58 | In a [`RecExpr`] (`egg`'s version of a plain old expression), 59 | e-node children refer to other e-nodes in that [`RecExpr`]. 60 | 61 | Most [`Language`]s, including [`SymbolLang`], can be parsed and pretty-printed. 62 | That means that [`RecExpr`]s in those languages 63 | implement the [`FromStr`] and [`Display`] traits from the Rust standard library. 64 | ``` 65 | # use egg::*; 66 | // Since parsing can return an error, `unwrap` just panics if the result doesn't return Ok 67 | let my_expression: RecExpr = "(foo a b)".parse().unwrap(); 68 | println!("this is my expression {}", my_expression); 69 | 70 | // let's try to create an e-node, but hmmm, what do I put as the children? 71 | let my_enode = SymbolLang::new("bar", vec![]); 72 | ``` 73 | 74 | Some e-nodes are just constants and have no children (also called leaves). 75 | But it's intentionally kind of awkward to create e-nodes with children in isolation, 76 | since you would have to add meaningless [`Id`]s as children. 
77 | The way to make meaningful [`Id`]s is by adding e-nodes to either an [`EGraph`] or a [`RecExpr`]: 78 | 79 | ``` 80 | # use egg::*; 81 | let mut expr = RecExpr::default(); 82 | let a = expr.add(SymbolLang::leaf("a")); 83 | let b = expr.add(SymbolLang::leaf("b")); 84 | let foo = expr.add(SymbolLang::new("foo", vec![a, b])); 85 | 86 | // we can do the same thing with an EGraph 87 | let mut egraph: EGraph = Default::default(); 88 | let a = egraph.add(SymbolLang::leaf("a")); 89 | let b = egraph.add(SymbolLang::leaf("b")); 90 | let foo = egraph.add(SymbolLang::new("foo", vec![a, b])); 91 | 92 | // we can also add RecExprs to an egraph 93 | let foo2 = egraph.add_expr(&expr); 94 | // note that if you add the same thing to an e-graph twice, you'll get back equivalent Ids 95 | assert_eq!(foo, foo2); 96 | ``` 97 | 98 | ## Searching an [`EGraph`] with [`Pattern`]s 99 | 100 | Now that we can add stuff to an [`EGraph`], let's see if we can find it. 101 | We'll use a [`Pattern`], which implements the [`Searcher`] trait, 102 | to search the e-graph for matches: 103 | 104 | ``` 105 | # use egg::*; 106 | // let's make an e-graph 107 | let mut egraph: EGraph = Default::default(); 108 | let a = egraph.add(SymbolLang::leaf("a")); 109 | let b = egraph.add(SymbolLang::leaf("b")); 110 | let foo = egraph.add(SymbolLang::new("foo", vec![a, b])); 111 | 112 | // rebuild the e-graph since we modified it 113 | egraph.rebuild(); 114 | 115 | // we can make Patterns by parsing, similar to RecExprs 116 | // names preceded by ? 
are parsed as Pattern variables and will match anything 117 | let pat: Pattern = "(foo ?x ?x)".parse().unwrap(); 118 | 119 | // since we use ?x twice, it must match the same thing, 120 | // so this search will return nothing 121 | let matches = pat.search(&egraph); 122 | assert!(matches.is_empty()); 123 | 124 | egraph.union(a, b); 125 | // recall that rebuild must be called to "see" the effects of adds or unions 126 | egraph.rebuild(); 127 | 128 | // now we can find a match since a = b 129 | let matches = pat.search(&egraph); 130 | assert!(!matches.is_empty()); 131 | ``` 132 | 133 | 134 | 135 | ## Using [`Runner`] to make an optimizer 136 | 137 | Now that we can make [`Pattern`]s and work with [`RecExpr`]s, we can make an optimizer! 138 | We'll use the [`rewrite!`] macro to easily create [`Rewrite`]s which consist of a name, 139 | left-hand pattern to search for, 140 | and right-hand pattern to apply. 141 | From there we can use the [`Runner`] API to run `egg`'s equality saturation algorithm. 142 | Finally, we can use an [`Extractor`] to get the best result. 143 | ``` 144 | use egg::{*, rewrite as rw}; 145 | 146 | let rules: &[Rewrite] = &[ 147 | rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"), 148 | rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"), 149 | 150 | rw!("add-0"; "(+ ?x 0)" => "?x"), 151 | rw!("mul-0"; "(* ?x 0)" => "0"), 152 | rw!("mul-1"; "(* ?x 1)" => "?x"), 153 | ]; 154 | 155 | // While it may look like we are working with numbers, 156 | // SymbolLang stores everything as strings. 157 | // We can make our own Language later to work with other types. 158 | let start = "(+ 0 (* 1 a))".parse().unwrap(); 159 | 160 | // That's it! We can run equality saturation now. 
161 | let runner = Runner::default().with_expr(&start).run(rules); 162 | 163 | // Extractors can take a user-defined cost function, 164 | // we'll use the egg-provided AstSize for now 165 | let extractor = Extractor::new(&runner.egraph, AstSize); 166 | 167 | // We want to extract the best expression represented in the 168 | // same e-class as our initial expression, not from the whole e-graph. 169 | // Luckily the runner stores the eclass Id where we put the initial expression. 170 | let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); 171 | 172 | // we found the best thing, which is just "a" in this case 173 | assert_eq!(best_expr, "a".parse().unwrap()); 174 | assert_eq!(best_cost, 1); 175 | ``` 176 | 177 | [`EGraph`]: super::super::EGraph 178 | [`Id`]: super::super::Id 179 | [`Language`]: super::super::Language 180 | [`Searcher`]: super::super::Searcher 181 | [`Pattern`]: super::super::Pattern 182 | [`RecExpr`]: super::super::RecExpr 183 | [`SymbolLang`]: super::super::SymbolLang 184 | [`define_language!`]: super::super::define_language! 185 | [`rewrite!`]: super::super::rewrite! 186 | [`FromStr`]: std::str::FromStr 187 | [`Display`]: std::fmt::Display 188 | [`Rewrite`]: super::super::Rewrite 189 | [`Runner`]: super::super::Runner 190 | [`Extractor`]: super::super::Extractor 191 | 192 | */ 193 | -------------------------------------------------------------------------------- /src/tutorials/_03_explanations.rs: -------------------------------------------------------------------------------- 1 | // -*- mode: markdown; markdown-fontify-code-block-default-mode: rustic-mode; evil-shift-width: 2; -*- 2 | /*! 3 | 4 | # Explanations 5 | 6 | It's often useful to know exactly why two terms are equivalent in 7 | the egraph. 8 | For example, if you are trying to debug incorrect rules, 9 | it would be useful to have a trace of rewrites showing how an example 10 | given bad equivalence was found. 
11 | `egg` uses an algorithm adapted from 12 | [Proof-Producing Congruence Closure](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.76.1716&rep=rep1&type=pdf) 13 | in order to generate such [`Explanation`]s between two given terms. 14 | 15 | Consider this program, which prints a [`FlatExplanation`] showing how 16 | `(/ (* (/ 2 3) (/ 3 2)) 1)` can be simplified to `1`: 17 | ``` 18 | use egg::{*, rewrite as rw}; 19 | let rules: &[Rewrite] = &[ 20 | rw!("div-one"; "?x" => "(/ ?x 1)"), 21 | rw!("unsafe-invert-division"; "(/ ?a ?b)" => "(/ 1 (/ ?b ?a))"), 22 | rw!("simplify-frac"; "(/ ?a (/ ?b ?c))" => "(/ (* ?a ?c) (* (/ ?b ?c) ?c))"), 23 | rw!("cancel-denominator"; "(* (/ ?a ?b) ?b)" => "?a"), 24 | rw!("times-zero"; "(* ?a 0)" => "0"), 25 | ]; 26 | 27 | let start = "(/ (* (/ 2 3) (/ 3 2)) 1)".parse().unwrap(); 28 | let end = "1".parse().unwrap(); 29 | let mut runner = Runner::default().with_explanations_enabled().with_expr(&start).run(rules); 30 | 31 | println!("{}", runner.explain_equivalence(&start, &end).get_flat_string()); 32 | ``` 33 | 34 | The output of the program is a series of s-expressions annotated 35 | with the rewrite being performed: 36 | ```text 37 | (/ (* (/ 2 3) (/ 3 2)) 1) 38 | (Rewrite<= div-one (* (/ 2 3) (/ 3 2))) 39 | (* (Rewrite=> unsafe-invert-division (/ 1 (/ 3 2))) (/ 3 2)) 40 | (Rewrite=> cancel-denominator 1) 41 | ``` 42 | At each step, the part of the term being rewritten is annotated 43 | with the rule being applied. 44 | Each term besides the first term has exactly one rewrite annotation. 45 | `Rewrite=>` indicates that the previous term is rewritten to the current term 46 | and `Rewrite<=` indicates that the current term is rewritten to the previous term. 47 | 48 | It turns out that these rules can easily lead to undesirable results in the egraph. 49 | For example, with just `0` as the starting term, the egraph finds that `0` is equivalent 50 | to `1` within a few iterations. 
51 | Here's the flattened explanation that `egg` generates: 52 | ```text 53 | 0 54 | (Rewrite<= times-zero (* (/ 1 0) 0)) 55 | (Rewrite=> cancel-denominator 1) 56 | ``` 57 | 58 | This tells you how the egraph got from `0` to `1`, but it's not clear why. 59 | In fact, normally the rules `times-zero` and `cancel-denominator` are perfectly 60 | reasonable. 61 | However, in the presence of a division by zero, they lead to arbitrary unions in the egraph. 62 | So the true problem is the presence of the term `(/ 1 0)`. 63 | For these kinds of questions, `egg` provides the `explain_existance` function which can be used to get an explanation 64 | of why a term exists in the egraph in the first place. 65 | 66 | 67 | # Explanation Trees 68 | 69 | So far we have looked at the [`FlatExplanation`] representation of explanations because 70 | they are the most human-readable. 71 | But explanations can also be used for automatic testing or translation validation of egraph results, 72 | so the flat representation is not always necessary. 73 | In fact, the flattened representation misses the opportunity to share parts of the explanation 74 | among several different terms. 75 | Egraphs tend to generate explanations with a large amount of duplication of explanations 76 | from one term to another, making explanation-sharing very important. 77 | To solve this problem, `egg` provides the [`TreeExplanation`] representation. 78 | 79 | Here's an example [`TreeExplanation`] in string form: 80 | ```text 81 | (+ 1 (- a (* (- 2 1) a))) 82 | (+ 83 | 1 84 | (Explanation 85 | (- a (* (- 2 1) a)) 86 | (- 87 | a 88 | (Explanation 89 | (* (- 2 1) a) 90 | (* (Explanation (- 2 1) (Rewrite=> constant_fold 1)) a) 91 | (Rewrite=> comm-mul (* a 1)) 92 | (Rewrite<= mul-one a))) 93 | (Rewrite=> cancel-sub 0))) 94 | (Rewrite=> constant_fold 1) 95 | ``` 96 | 97 | The big difference between [`FlatExplanation`] and [`TreeExplanation`] is that now 98 | children of terms can contain explanations themselves. 
99 | So a [`TreeTerm`] can have each of its children be rewritten from an initial term 100 | to a final term, making the representation more compact. 101 | In addition, the string format supports let bindings in order to allow sharing of explanations: 102 | 103 | ```text 104 | (let 105 | (v_0 (- 2 1)) 106 | (let 107 | (v_1 (- 2 (Explanation v_0 (Rewrite=> constant_fold 1)))) 108 | (Explanation 109 | (* (- 2 (- 2 1)) (- 2 (- 2 1))) 110 | (* 111 | (Explanation (- 2 (- 2 1)) v_1 (Rewrite=> constant_fold 1)) 112 | (Explanation (- 2 (- 2 1)) v_1 (Rewrite=> constant_fold 1))) 113 | (Rewrite=> constant_fold 1)))) 114 | ``` 115 | As you can see, the let binding allows for sharing the term `v_1`. 116 | There are other duplicate expressions that could be let bound, but are not because 117 | `egg` only binds shared sub-terms found during the explanation generation process. 118 | 119 | Besides the string forms, [`TreeExplanation`] and [`FlatExplanation`] encode the same information 120 | as Rust objects. 121 | For proof sharing, each `Rc` in the [`TreeExplanation`] can be checked for pointer 122 | equality with other terms. 123 | 124 | 125 | [`EGraph`]: super::super::EGraph 126 | [`Runner`]: super::super::Runner 127 | [`Explanation`]: super::super::Explanation 128 | [`TreeExplanation`]: super::super::TreeExplanation 129 | [`FlatExplanation`]: super::super::FlatExplanation 130 | [`TreeTerm`]: super::super::TreeTerm 131 | [`with_explanations_enabled`]: super::super::Runner::with_explanations_enabled 132 | 133 | */ 134 | -------------------------------------------------------------------------------- /src/tutorials/mod.rs: -------------------------------------------------------------------------------- 1 | // -*- mode: markdown; markdown-fontify-code-block-default-mode: rustic-mode; -*- 2 | /*! 3 | 4 | # A Guide-level Explanation of `egg` 5 | 6 | `egg` is an e-graph library optimized for equality saturation. 
7 | Using these techniques, one can pretty easily build an optimizer or synthesizer for a language. 8 | 9 | This tutorial is targeted at readers who may not have seen e-graphs, equality saturation, or Rust. 10 | If you already know some of that stuff, you may just want to skim or skip those chapters. 11 | 12 | This is intended to be a guide-level introduction using examples to build intuition. 13 | For more detail, check out the [API documentation](super), 14 | which the tutorials will link to frequently. 15 | Most of the code examples here are typechecked and run, so you may copy-paste them to get started. 16 | 17 | There is also a [paper](https://arxiv.org/abs/2004.03082) 18 | describing `egg`, if you are keen to read more about its technical novelties. 19 | 20 | The tutorials are a work-in-progress with more to be added soon. 21 | 22 | */ 23 | 24 | pub mod _01_background; 25 | pub mod _02_getting_started; 26 | pub mod _03_explanations; 27 | -------------------------------------------------------------------------------- /src/unionfind.rs: -------------------------------------------------------------------------------- 1 | use crate::Id; 2 | use std::fmt::Debug; 3 | 4 | #[derive(Debug, Clone, Default)] 5 | #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] 6 | pub struct UnionFind { 7 | parents: Vec, 8 | } 9 | 10 | impl UnionFind { 11 | pub fn make_set(&mut self) -> Id { 12 | let id = Id::from(self.parents.len()); 13 | self.parents.push(id); 14 | id 15 | } 16 | 17 | pub fn size(&self) -> usize { 18 | self.parents.len() 19 | } 20 | 21 | fn parent(&self, query: Id) -> Id { 22 | self.parents[usize::from(query)] 23 | } 24 | 25 | fn parent_mut(&mut self, query: Id) -> &mut Id { 26 | &mut self.parents[usize::from(query)] 27 | } 28 | 29 | pub fn find(&self, mut current: Id) -> Id { 30 | while current != self.parent(current) { 31 | current = self.parent(current) 32 | } 33 | current 34 | } 35 | 36 | pub fn find_mut(&mut self, mut current: Id) 
-> Id { 37 | while current != self.parent(current) { 38 | let grandparent = self.parent(self.parent(current)); 39 | *self.parent_mut(current) = grandparent; 40 | current = grandparent; 41 | } 42 | current 43 | } 44 | 45 | /// Given two leader ids, unions the two eclasses making root1 the leader. 46 | pub fn union(&mut self, root1: Id, root2: Id) -> Id { 47 | *self.parent_mut(root2) = root1; 48 | root1 49 | } 50 | } 51 | 52 | #[cfg(test)] 53 | mod tests { 54 | use super::*; 55 | 56 | fn ids(us: impl IntoIterator) -> Vec { 57 | us.into_iter().map(|u| u.into()).collect() 58 | } 59 | 60 | #[test] 61 | fn union_find() { 62 | let n = 10; 63 | let id = Id::from; 64 | 65 | let mut uf = UnionFind::default(); 66 | for _ in 0..n { 67 | uf.make_set(); 68 | } 69 | 70 | // test the initial condition of everyone in their own set 71 | assert_eq!(uf.parents, ids(0..n)); 72 | 73 | // build up one set 74 | uf.union(id(0), id(1)); 75 | uf.union(id(0), id(2)); 76 | uf.union(id(0), id(3)); 77 | 78 | // build up another set 79 | uf.union(id(6), id(7)); 80 | uf.union(id(6), id(8)); 81 | uf.union(id(6), id(9)); 82 | 83 | // this should compress all paths 84 | for i in 0..n { 85 | uf.find_mut(id(i)); 86 | } 87 | 88 | // indexes: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 89 | let expected = vec![0, 0, 0, 0, 4, 5, 6, 6, 6, 6]; 90 | assert_eq!(uf.parents, ids(expected)); 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /src/util.rs: -------------------------------------------------------------------------------- 1 | use std::{fmt, iter::FromIterator}; 2 | use symbolic_expressions::Sexp; 3 | 4 | use fmt::{Debug, Display, Formatter}; 5 | 6 | #[cfg(feature = "serde-1")] 7 | use ::serde::{Deserialize, Serialize}; 8 | 9 | #[allow(unused_imports)] 10 | use crate::*; 11 | 12 | /// An interned string. 13 | /// 14 | /// This is provided by the [`symbol_table`](https://crates.io/crates/symbol_table) crate. 
15 | /// 16 | /// Internally, `egg` frequently compares [`Var`]s and elements of 17 | /// [`Language`]s. To keep comparisons fast, `egg` provides [`Symbol`], a simple 18 | /// wrapper providing interned strings. 19 | /// 20 | /// You may wish to use [`Symbol`] in your own [`Language`]s to increase 21 | /// performance and keep enode sizes down (a [`Symbol`] is only 4 bytes, 22 | /// compared to 24 for a `String`.) 23 | /// 24 | /// A [`Symbol`] is simply a wrapper around an integer. 25 | /// When creating a [`Symbol`] from a string, `egg` looks it up in a global 26 | /// table, returning the index (inserting it if not found). 27 | /// That integer is used to cheaply implement 28 | /// `Copy`, `Clone`, `PartialEq`, `Eq`, `PartialOrd`, `Ord`, and `Hash`. 29 | /// 30 | /// The internal symbol cache leaks the strings, which should be 31 | /// fine if you only put in things like variable names and identifiers. 32 | /// 33 | /// # Example 34 | /// ```rust 35 | /// use egg::Symbol; 36 | /// 37 | /// assert_eq!(Symbol::from("foo"), Symbol::from("foo")); 38 | /// assert_eq!(Symbol::from("foo"), "foo".parse().unwrap()); 39 | /// 40 | /// assert_ne!(Symbol::from("foo"), Symbol::from("bar")); 41 | /// ``` 42 | /// 43 | pub use symbol_table::GlobalSymbol as Symbol; 44 | 45 | pub(crate) type BuildHasher = fxhash::FxBuildHasher; 46 | 47 | // pub(crate) type HashMap = hashbrown::HashMap; 48 | // pub(crate) type HashSet = hashbrown::HashSet; 49 | 50 | pub(crate) use hashmap::*; 51 | 52 | #[cfg(feature = "deterministic")] 53 | mod hashmap { 54 | pub(crate) type HashMap = super::IndexMap; 55 | pub(crate) type HashSet = super::IndexSet; 56 | } 57 | #[cfg(not(feature = "deterministic"))] 58 | mod hashmap { 59 | use super::BuildHasher; 60 | pub(crate) type HashMap = hashbrown::HashMap; 61 | pub(crate) type HashSet = hashbrown::HashSet; 62 | } 63 | 64 | pub(crate) type IndexMap = indexmap::IndexMap; 65 | pub(crate) type IndexSet = indexmap::IndexSet; 66 | 67 | pub(crate) type Instant = 
instant::Instant; 68 | pub(crate) type Duration = instant::Duration; 69 | 70 | pub(crate) fn concat_vecs(to: &mut Vec, mut from: Vec) { 71 | if to.len() < from.len() { 72 | std::mem::swap(to, &mut from) 73 | } 74 | to.extend(from); 75 | } 76 | 77 | pub(crate) fn pretty_print( 78 | buf: &mut String, 79 | sexp: &Sexp, 80 | width: usize, 81 | level: usize, 82 | ) -> std::fmt::Result { 83 | use std::fmt::Write; 84 | if let Sexp::List(list) = sexp { 85 | let indent = sexp.to_string().len() > width; 86 | write!(buf, "(")?; 87 | 88 | for (i, val) in list.iter().enumerate() { 89 | if indent && i > 0 { 90 | writeln!(buf)?; 91 | for _ in 0..level { 92 | write!(buf, " ")?; 93 | } 94 | } 95 | pretty_print(buf, val, width, level + 1)?; 96 | if !indent && i < list.len() - 1 { 97 | write!(buf, " ")?; 98 | } 99 | } 100 | 101 | write!(buf, ")")?; 102 | Ok(()) 103 | } else { 104 | // I don't care about quotes 105 | write!(buf, "{}", sexp.to_string().trim_matches('"')) 106 | } 107 | } 108 | 109 | /// A wrapper that uses display implementation as debug 110 | pub(crate) struct DisplayAsDebug(pub T); 111 | 112 | impl Debug for DisplayAsDebug { 113 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 114 | Display::fmt(&self.0, f) 115 | } 116 | } 117 | 118 | /** A data structure to maintain a queue of unique elements. 119 | 120 | Notably, insert/pop operations have O(1) expected amortized runtime complexity. 
121 | */ 122 | #[derive(Clone)] 123 | #[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] 124 | pub(crate) struct UniqueQueue 125 | where 126 | T: Eq + std::hash::Hash + Clone, 127 | { 128 | set: hashbrown::HashSet, 129 | queue: std::collections::VecDeque, 130 | } 131 | 132 | impl Default for UniqueQueue 133 | where 134 | T: Eq + std::hash::Hash + Clone, 135 | { 136 | fn default() -> Self { 137 | UniqueQueue { 138 | set: hashbrown::HashSet::default(), 139 | queue: std::collections::VecDeque::new(), 140 | } 141 | } 142 | } 143 | 144 | impl UniqueQueue 145 | where 146 | T: Eq + std::hash::Hash + Clone, 147 | { 148 | pub fn insert(&mut self, t: T) { 149 | if self.set.insert(t.clone()) { 150 | self.queue.push_back(t); 151 | } 152 | } 153 | 154 | pub fn extend(&mut self, iter: I) 155 | where 156 | I: IntoIterator, 157 | { 158 | for t in iter.into_iter() { 159 | self.insert(t); 160 | } 161 | } 162 | 163 | pub fn pop(&mut self) -> Option { 164 | let res = self.queue.pop_front(); 165 | res.as_ref().map(|t| self.set.remove(t)); 166 | res 167 | } 168 | 169 | pub fn is_empty(&self) -> bool { 170 | let r = self.queue.is_empty(); 171 | debug_assert_eq!(r, self.set.is_empty()); 172 | r 173 | } 174 | } 175 | 176 | impl IntoIterator for UniqueQueue 177 | where 178 | T: Eq + std::hash::Hash + Clone, 179 | { 180 | type Item = T; 181 | 182 | type IntoIter = as IntoIterator>::IntoIter; 183 | 184 | fn into_iter(self) -> Self::IntoIter { 185 | self.queue.into_iter() 186 | } 187 | } 188 | 189 | impl FromIterator for UniqueQueue 190 | where 191 | A: Eq + std::hash::Hash + Clone, 192 | { 193 | fn from_iter>(iter: T) -> Self { 194 | let mut queue = UniqueQueue::default(); 195 | for t in iter { 196 | queue.insert(t); 197 | } 198 | queue 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /tests/datalog.rs: -------------------------------------------------------------------------------- 1 | use egg::*; 2 | 3 | define_language! 
{ 4 | enum Lang { 5 | "true" = True, 6 | Int(i32), 7 | Relation(Symbol, Box<[Id]>), 8 | } 9 | } 10 | 11 | trait DatalogExtTrait { 12 | fn assert(&mut self, s: &str); 13 | fn check(&mut self, s: &str); 14 | fn check_not(&mut self, s: &str); 15 | } 16 | 17 | impl DatalogExtTrait for EGraph { 18 | fn assert(&mut self, s: &str) { 19 | let true_id = self.add(Lang::True); 20 | for e in s.split(',') { 21 | let exp = e.trim().parse().unwrap(); 22 | let id = self.add_expr(&exp); 23 | self.union(true_id, id); 24 | } 25 | } 26 | 27 | fn check(&mut self, s: &str) { 28 | let true_id = self.add(Lang::True); 29 | for e in s.split(',') { 30 | let exp = e.trim().parse().unwrap(); 31 | let id = self.add_expr(&exp); 32 | assert_eq!(true_id, id, "{} is not true", e); 33 | } 34 | } 35 | 36 | fn check_not(&mut self, s: &str) { 37 | let true_id = self.add(Lang::True); 38 | for e in s.split(',') { 39 | let exp = e.trim().parse().unwrap(); 40 | let id = self.add_expr(&exp); 41 | assert!(true_id != id, "{} is true", e); 42 | } 43 | } 44 | } 45 | 46 | #[test] 47 | fn path() { 48 | let mut egraph = EGraph::::default(); 49 | egraph.assert("(edge 1 2), (edge 2 3), (edge 3 4)"); 50 | let rules = vec![ 51 | multi_rewrite!("base-case"; "?x = true = (edge ?a ?b)" => "?x = (path ?a ?b)"), 52 | multi_rewrite!("transitive"; "?x = true = (path ?a ?b) = (edge ?b ?c)" => "?x = (path ?a ?c)"), 53 | ]; 54 | 55 | let mut runner = Runner::default().with_egraph(egraph).run(&rules); 56 | runner.egraph.check("(path 1 4)"); 57 | runner.egraph.check_not("(path 4 1)"); 58 | } 59 | 60 | #[test] 61 | fn path2() { 62 | // `pred` function symbol allows us to insert without truth. 
63 | let mut egraph = EGraph::::default(); 64 | egraph.assert("(edge 1 2), (edge 2 3), (edge 3 4), (edge 1 4)"); 65 | let rules = vec![ 66 | multi_rewrite!("base-case"; "?x = (edge ?a ?b), ?t = true" => "?t = (pred (path ?a ?b))"), 67 | multi_rewrite!("transitive"; "?x = (path ?a ?b), ?y = (edge ?b ?c), ?t = true" => "?t = (pred (path ?a ?c))"), 68 | ]; 69 | let mut runner = Runner::default().with_egraph(egraph).run(&rules); 70 | runner.egraph.check("(pred (path 1 4))"); 71 | runner.egraph.check("(pred (path 2 3))"); 72 | runner.egraph.check_not("(pred (path 4 1))"); 73 | runner.egraph.check_not("(pred (path 3 1))"); 74 | } 75 | -------------------------------------------------------------------------------- /tests/lambda.rs: -------------------------------------------------------------------------------- 1 | use egg::{rewrite as rw, *}; 2 | use fxhash::FxHashSet as HashSet; 3 | 4 | define_language! { 5 | enum Lambda { 6 | Bool(bool), 7 | Num(i32), 8 | 9 | "var" = Var(Id), 10 | 11 | "+" = Add([Id; 2]), 12 | "=" = Eq([Id; 2]), 13 | 14 | "app" = App([Id; 2]), 15 | "lam" = Lambda([Id; 2]), 16 | "let" = Let([Id; 3]), 17 | "fix" = Fix([Id; 2]), 18 | 19 | "if" = If([Id; 3]), 20 | 21 | Symbol(egg::Symbol), 22 | } 23 | } 24 | 25 | impl Lambda { 26 | fn num(&self) -> Option { 27 | match self { 28 | Lambda::Num(n) => Some(*n), 29 | _ => None, 30 | } 31 | } 32 | } 33 | 34 | type EGraph = egg::EGraph; 35 | 36 | #[derive(Default)] 37 | struct LambdaAnalysis; 38 | 39 | #[derive(Debug)] 40 | struct Data { 41 | free: HashSet, 42 | constant: Option<(Lambda, PatternAst)>, 43 | } 44 | 45 | fn eval(egraph: &EGraph, enode: &Lambda) -> Option<(Lambda, PatternAst)> { 46 | let x = |i: &Id| egraph[*i].data.constant.as_ref().map(|c| &c.0); 47 | match enode { 48 | Lambda::Num(n) => Some((enode.clone(), format!("{}", n).parse().unwrap())), 49 | Lambda::Bool(b) => Some((enode.clone(), format!("{}", b).parse().unwrap())), 50 | Lambda::Add([a, b]) => Some(( 51 | 
Lambda::Num(x(a)?.num()?.checked_add(x(b)?.num()?)?), 52 | format!("(+ {} {})", x(a)?, x(b)?).parse().unwrap(), 53 | )), 54 | Lambda::Eq([a, b]) => Some(( 55 | Lambda::Bool(x(a)? == x(b)?), 56 | format!("(= {} {})", x(a)?, x(b)?).parse().unwrap(), 57 | )), 58 | _ => None, 59 | } 60 | } 61 | 62 | impl Analysis for LambdaAnalysis { 63 | type Data = Data; 64 | fn merge(&mut self, to: &mut Data, from: Data) -> DidMerge { 65 | let before_len = to.free.len(); 66 | // to.free.extend(from.free); 67 | to.free.retain(|i| from.free.contains(i)); 68 | // compare lengths to see if I changed to or from 69 | DidMerge( 70 | before_len != to.free.len(), 71 | to.free.len() != from.free.len(), 72 | ) | merge_option(&mut to.constant, from.constant, |a, b| { 73 | assert_eq!(a.0, b.0, "Merged non-equal constants"); 74 | DidMerge(false, false) 75 | }) 76 | } 77 | 78 | fn make(egraph: &mut EGraph, enode: &Lambda) -> Data { 79 | let f = |i: &Id| egraph[*i].data.free.iter().cloned(); 80 | let mut free = HashSet::default(); 81 | match enode { 82 | Lambda::Var(v) => { 83 | free.insert(*v); 84 | } 85 | Lambda::Let([v, a, b]) => { 86 | free.extend(f(b)); 87 | free.remove(v); 88 | free.extend(f(a)); 89 | } 90 | Lambda::Lambda([v, a]) | Lambda::Fix([v, a]) => { 91 | free.extend(f(a)); 92 | free.remove(v); 93 | } 94 | _ => enode.for_each(|c| free.extend(&egraph[c].data.free)), 95 | } 96 | let constant = eval(egraph, enode); 97 | Data { constant, free } 98 | } 99 | 100 | fn modify(egraph: &mut EGraph, id: Id) { 101 | if let Some(c) = egraph[id].data.constant.clone() { 102 | if egraph.are_explanations_enabled() { 103 | egraph.union_instantiations( 104 | &c.0.to_string().parse().unwrap(), 105 | &c.1, 106 | &Default::default(), 107 | "analysis".to_string(), 108 | ); 109 | } else { 110 | let const_id = egraph.add(c.0); 111 | egraph.union(id, const_id); 112 | } 113 | } 114 | } 115 | } 116 | 117 | fn var(s: &str) -> Var { 118 | s.parse().unwrap() 119 | } 120 | 121 | fn is_not_same_var(v1: Var, v2: Var) 
-> impl Fn(&mut EGraph, Id, &Subst) -> bool { 122 | move |egraph, _, subst| egraph.find(subst[v1]) != egraph.find(subst[v2]) 123 | } 124 | 125 | fn is_const(v: Var) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { 126 | move |egraph, _, subst| egraph[subst[v]].data.constant.is_some() 127 | } 128 | 129 | fn rules() -> Vec> { 130 | vec![ 131 | // open term rules 132 | rw!("if-true"; "(if true ?then ?else)" => "?then"), 133 | rw!("if-false"; "(if false ?then ?else)" => "?else"), 134 | rw!("if-elim"; "(if (= (var ?x) ?e) ?then ?else)" => "?else" 135 | if ConditionEqual::parse("(let ?x ?e ?then)", "(let ?x ?e ?else)")), 136 | rw!("add-comm"; "(+ ?a ?b)" => "(+ ?b ?a)"), 137 | rw!("add-assoc"; "(+ (+ ?a ?b) ?c)" => "(+ ?a (+ ?b ?c))"), 138 | rw!("eq-comm"; "(= ?a ?b)" => "(= ?b ?a)"), 139 | // subst rules 140 | rw!("fix"; "(fix ?v ?e)" => "(let ?v (fix ?v ?e) ?e)"), 141 | rw!("beta"; "(app (lam ?v ?body) ?e)" => "(let ?v ?e ?body)"), 142 | rw!("let-app"; "(let ?v ?e (app ?a ?b))" => "(app (let ?v ?e ?a) (let ?v ?e ?b))"), 143 | rw!("let-add"; "(let ?v ?e (+ ?a ?b))" => "(+ (let ?v ?e ?a) (let ?v ?e ?b))"), 144 | rw!("let-eq"; "(let ?v ?e (= ?a ?b))" => "(= (let ?v ?e ?a) (let ?v ?e ?b))"), 145 | rw!("let-const"; 146 | "(let ?v ?e ?c)" => "?c" if is_const(var("?c"))), 147 | rw!("let-if"; 148 | "(let ?v ?e (if ?cond ?then ?else))" => 149 | "(if (let ?v ?e ?cond) (let ?v ?e ?then) (let ?v ?e ?else))" 150 | ), 151 | rw!("let-var-same"; "(let ?v1 ?e (var ?v1))" => "?e"), 152 | rw!("let-var-diff"; "(let ?v1 ?e (var ?v2))" => "(var ?v2)" 153 | if is_not_same_var(var("?v1"), var("?v2"))), 154 | rw!("let-lam-same"; "(let ?v1 ?e (lam ?v1 ?body))" => "(lam ?v1 ?body)"), 155 | rw!("let-lam-diff"; 156 | "(let ?v1 ?e (lam ?v2 ?body))" => 157 | { CaptureAvoid { 158 | fresh: var("?fresh"), v2: var("?v2"), e: var("?e"), 159 | if_not_free: "(lam ?v2 (let ?v1 ?e ?body))".parse().unwrap(), 160 | if_free: "(lam ?fresh (let ?v1 ?e (let ?v2 (var ?fresh) ?body)))".parse().unwrap(), 161 | }} 162 | 
if is_not_same_var(var("?v1"), var("?v2"))), 163 | ] 164 | } 165 | 166 | struct CaptureAvoid { 167 | fresh: Var, 168 | v2: Var, 169 | e: Var, 170 | if_not_free: Pattern, 171 | if_free: Pattern, 172 | } 173 | 174 | impl Applier for CaptureAvoid { 175 | fn apply_one( 176 | &self, 177 | egraph: &mut EGraph, 178 | eclass: Id, 179 | subst: &Subst, 180 | searcher_ast: Option<&PatternAst>, 181 | rule_name: Symbol, 182 | ) -> Vec { 183 | let e = subst[self.e]; 184 | let v2 = subst[self.v2]; 185 | let v2_free_in_e = egraph[e].data.free.contains(&v2); 186 | if v2_free_in_e { 187 | let mut subst = subst.clone(); 188 | let sym = Lambda::Symbol(format!("_{}", eclass).into()); 189 | subst.insert(self.fresh, egraph.add(sym)); 190 | self.if_free 191 | .apply_one(egraph, eclass, &subst, searcher_ast, rule_name) 192 | } else { 193 | self.if_not_free 194 | .apply_one(egraph, eclass, subst, searcher_ast, rule_name) 195 | } 196 | } 197 | } 198 | 199 | egg::test_fn! { 200 | lambda_under, rules(), 201 | "(lam x (+ 4 202 | (app (lam y (var y)) 203 | 4)))" 204 | => 205 | // "(lam x (+ 4 (let y 4 (var y))))", 206 | // "(lam x (+ 4 4))", 207 | "(lam x 8))", 208 | } 209 | 210 | egg::test_fn! { 211 | lambda_if_elim, rules(), 212 | "(if (= (var a) (var b)) 213 | (+ (var a) (var a)) 214 | (+ (var a) (var b)))" 215 | => 216 | "(+ (var a) (var b))" 217 | } 218 | 219 | egg::test_fn! { 220 | lambda_let_simple, rules(), 221 | "(let x 0 222 | (let y 1 223 | (+ (var x) (var y))))" 224 | => 225 | // "(let ?a 0 226 | // (+ (var ?a) 1))", 227 | // "(+ 0 1)", 228 | "1", 229 | } 230 | 231 | egg::test_fn! { 232 | #[should_panic(expected = "Could not prove goal 0")] 233 | lambda_capture, rules(), 234 | "(let x 1 (lam x (var x)))" => "(lam x 1)" 235 | } 236 | 237 | egg::test_fn! { 238 | #[should_panic(expected = "Could not prove goal 0")] 239 | lambda_capture_free, rules(), 240 | "(let y (+ (var x) (var x)) (lam x (var y)))" => "(lam x (+ (var x) (var x)))" 241 | } 242 | 243 | egg::test_fn! 
{ 244 | #[should_panic(expected = "Could not prove goal 0")] 245 | lambda_closure_not_seven, rules(), 246 | "(let five 5 247 | (let add-five (lam x (+ (var x) (var five))) 248 | (let five 6 249 | (app (var add-five) 1))))" 250 | => 251 | "7" 252 | } 253 | 254 | egg::test_fn! { 255 | lambda_compose, rules(), 256 | "(let compose (lam f (lam g (lam x (app (var f) 257 | (app (var g) (var x)))))) 258 | (let add1 (lam y (+ (var y) 1)) 259 | (app (app (var compose) (var add1)) (var add1))))" 260 | => 261 | "(lam ?x (+ 1 262 | (app (lam ?y (+ 1 (var ?y))) 263 | (var ?x))))", 264 | "(lam ?x (+ (var ?x) 2))" 265 | } 266 | 267 | egg::test_fn! { 268 | lambda_if_simple, rules(), 269 | "(if (= 1 1) 7 9)" => "7" 270 | } 271 | 272 | egg::test_fn! { 273 | lambda_compose_many, rules(), 274 | "(let compose (lam f (lam g (lam x (app (var f) 275 | (app (var g) (var x)))))) 276 | (let add1 (lam y (+ (var y) 1)) 277 | (app (app (var compose) (var add1)) 278 | (app (app (var compose) (var add1)) 279 | (app (app (var compose) (var add1)) 280 | (app (app (var compose) (var add1)) 281 | (app (app (var compose) (var add1)) 282 | (app (app (var compose) (var add1)) 283 | (var add1)))))))))" 284 | => 285 | "(lam ?x (+ (var ?x) 7))" 286 | } 287 | 288 | egg::test_fn! 
{ 289 | #[cfg(not(debug_assertions))] 290 | #[cfg_attr(feature = "test-explanations", ignore)] 291 | lambda_function_repeat, rules(), 292 | runner = Runner::default() 293 | .with_time_limit(std::time::Duration::from_secs(20)) 294 | .with_node_limit(150_000) 295 | .with_iter_limit(60), 296 | "(let compose (lam f (lam g (lam x (app (var f) 297 | (app (var g) (var x)))))) 298 | (let repeat (fix repeat (lam fun (lam n 299 | (if (= (var n) 0) 300 | (lam i (var i)) 301 | (app (app (var compose) (var fun)) 302 | (app (app (var repeat) 303 | (var fun)) 304 | (+ (var n) -1))))))) 305 | (let add1 (lam y (+ (var y) 1)) 306 | (app (app (var repeat) 307 | (var add1)) 308 | 2))))" 309 | => 310 | "(lam ?x (+ (var ?x) 2))" 311 | } 312 | 313 | egg::test_fn! { 314 | lambda_if, rules(), 315 | "(let zeroone (lam x 316 | (if (= (var x) 0) 317 | 0 318 | 1)) 319 | (+ (app (var zeroone) 0) 320 | (app (var zeroone) 10)))" 321 | => 322 | // "(+ (if false 0 1) (if true 0 1))", 323 | // "(+ 1 0)", 324 | "1", 325 | } 326 | 327 | egg::test_fn! 
{ 328 | #[cfg(not(debug_assertions))] 329 | #[cfg_attr(feature = "test-explanations", ignore)] 330 | lambda_fib, rules(), 331 | runner = Runner::default() 332 | .with_iter_limit(60) 333 | .with_node_limit(500_000), 334 | "(let fib (fix fib (lam n 335 | (if (= (var n) 0) 336 | 0 337 | (if (= (var n) 1) 338 | 1 339 | (+ (app (var fib) 340 | (+ (var n) -1)) 341 | (app (var fib) 342 | (+ (var n) -2))))))) 343 | (app (var fib) 4))" 344 | => "3" 345 | } 346 | 347 | #[test] 348 | fn lambda_ematching_bench() { 349 | let exprs = &[ 350 | "(let zeroone (lam x 351 | (if (= (var x) 0) 352 | 0 353 | 1)) 354 | (+ (app (var zeroone) 0) 355 | (app (var zeroone) 10)))", 356 | "(let compose (lam f (lam g (lam x (app (var f) 357 | (app (var g) (var x)))))) 358 | (let repeat (fix repeat (lam fun (lam n 359 | (if (= (var n) 0) 360 | (lam i (var i)) 361 | (app (app (var compose) (var fun)) 362 | (app (app (var repeat) 363 | (var fun)) 364 | (+ (var n) -1))))))) 365 | (let add1 (lam y (+ (var y) 1)) 366 | (app (app (var repeat) 367 | (var add1)) 368 | 2))))", 369 | "(let fib (fix fib (lam n 370 | (if (= (var n) 0) 371 | 0 372 | (if (= (var n) 1) 373 | 1 374 | (+ (app (var fib) 375 | (+ (var n) -1)) 376 | (app (var fib) 377 | (+ (var n) -2))))))) 378 | (app (var fib) 4))", 379 | ]; 380 | 381 | let extra_patterns = &[ 382 | "(if (= (var ?x) ?e) ?then ?else)", 383 | "(+ (+ ?a ?b) ?c)", 384 | "(let ?v (fix ?v ?e) ?e)", 385 | "(app (lam ?v ?body) ?e)", 386 | "(let ?v ?e (app ?a ?b))", 387 | "(app (let ?v ?e ?a) (let ?v ?e ?b))", 388 | "(let ?v ?e (+ ?a ?b))", 389 | "(+ (let ?v ?e ?a) (let ?v ?e ?b))", 390 | "(let ?v ?e (= ?a ?b))", 391 | "(= (let ?v ?e ?a) (let ?v ?e ?b))", 392 | "(let ?v ?e (if ?cond ?then ?else))", 393 | "(if (let ?v ?e ?cond) (let ?v ?e ?then) (let ?v ?e ?else))", 394 | "(let ?v1 ?e (var ?v1))", 395 | "(let ?v1 ?e (var ?v2))", 396 | "(let ?v1 ?e (lam ?v1 ?body))", 397 | "(let ?v1 ?e (lam ?v2 ?body))", 398 | "(lam ?v2 (let ?v1 ?e ?body))", 399 | "(lam ?fresh (let ?v1 ?e 
(let ?v2 (var ?fresh) ?body)))", 400 | ]; 401 | 402 | egg::test::bench_egraph("lambda", rules(), exprs, extra_patterns); 403 | } 404 | -------------------------------------------------------------------------------- /tests/math.rs: -------------------------------------------------------------------------------- 1 | use egg::{rewrite as rw, *}; 2 | use ordered_float::NotNan; 3 | 4 | pub type EGraph = egg::EGraph; 5 | pub type Rewrite = egg::Rewrite; 6 | 7 | pub type Constant = NotNan; 8 | 9 | define_language! { 10 | pub enum Math { 11 | "d" = Diff([Id; 2]), 12 | "i" = Integral([Id; 2]), 13 | 14 | "+" = Add([Id; 2]), 15 | "-" = Sub([Id; 2]), 16 | "*" = Mul([Id; 2]), 17 | "/" = Div([Id; 2]), 18 | "pow" = Pow([Id; 2]), 19 | "ln" = Ln(Id), 20 | "sqrt" = Sqrt(Id), 21 | 22 | "sin" = Sin(Id), 23 | "cos" = Cos(Id), 24 | 25 | Constant(Constant), 26 | Symbol(Symbol), 27 | } 28 | } 29 | 30 | // You could use egg::AstSize, but this is useful for debugging, since 31 | // it will really try to get rid of the Diff operator 32 | pub struct MathCostFn; 33 | impl egg::CostFunction for MathCostFn { 34 | type Cost = usize; 35 | fn cost(&mut self, enode: &Math, mut costs: C) -> Self::Cost 36 | where 37 | C: FnMut(Id) -> Self::Cost, 38 | { 39 | let op_cost = match enode { 40 | Math::Diff(..) => 100, 41 | Math::Integral(..) => 100, 42 | _ => 1, 43 | }; 44 | enode.fold(op_cost, |sum, i| sum + costs(i)) 45 | } 46 | } 47 | 48 | #[derive(Default)] 49 | pub struct ConstantFold; 50 | impl Analysis for ConstantFold { 51 | type Data = Option<(Constant, PatternAst)>; 52 | 53 | fn make(egraph: &mut EGraph, enode: &Math) -> Self::Data { 54 | let x = |i: &Id| egraph[*i].data.as_ref().map(|d| d.0); 55 | Some(match enode { 56 | Math::Constant(c) => (*c, format!("{}", c).parse().unwrap()), 57 | Math::Add([a, b]) => ( 58 | x(a)? + x(b)?, 59 | format!("(+ {} {})", x(a)?, x(b)?).parse().unwrap(), 60 | ), 61 | Math::Sub([a, b]) => ( 62 | x(a)? 
- x(b)?, 63 | format!("(- {} {})", x(a)?, x(b)?).parse().unwrap(), 64 | ), 65 | Math::Mul([a, b]) => ( 66 | x(a)? * x(b)?, 67 | format!("(* {} {})", x(a)?, x(b)?).parse().unwrap(), 68 | ), 69 | Math::Div([a, b]) if x(b) != Some(NotNan::new(0.0).unwrap()) => ( 70 | x(a)? / x(b)?, 71 | format!("(/ {} {})", x(a)?, x(b)?).parse().unwrap(), 72 | ), 73 | _ => return None, 74 | }) 75 | } 76 | 77 | fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { 78 | merge_option(to, from, |a, b| { 79 | assert_eq!(a.0, b.0, "Merged non-equal constants"); 80 | DidMerge(false, false) 81 | }) 82 | } 83 | 84 | fn modify(egraph: &mut EGraph, id: Id) { 85 | let data = egraph[id].data.clone(); 86 | if let Some((c, pat)) = data { 87 | if egraph.are_explanations_enabled() { 88 | egraph.union_instantiations( 89 | &pat, 90 | &format!("{}", c).parse().unwrap(), 91 | &Default::default(), 92 | "constant_fold".to_string(), 93 | ); 94 | } else { 95 | let added = egraph.add(Math::Constant(c)); 96 | egraph.union(id, added); 97 | } 98 | // to not prune, comment this out 99 | egraph[id].nodes.retain(|n| n.is_leaf()); 100 | 101 | #[cfg(debug_assertions)] 102 | egraph[id].assert_unique_leaves(); 103 | } 104 | } 105 | } 106 | 107 | fn is_const_or_distinct_var(v: &str, w: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { 108 | let v = v.parse().unwrap(); 109 | let w = w.parse().unwrap(); 110 | move |egraph, _, subst| { 111 | egraph.find(subst[v]) != egraph.find(subst[w]) 112 | && (egraph[subst[v]].data.is_some() 113 | || egraph[subst[v]] 114 | .nodes 115 | .iter() 116 | .any(|n| matches!(n, Math::Symbol(..)))) 117 | } 118 | } 119 | 120 | fn is_const(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { 121 | let var = var.parse().unwrap(); 122 | move |egraph, _, subst| egraph[subst[var]].data.is_some() 123 | } 124 | 125 | fn is_sym(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { 126 | let var = var.parse().unwrap(); 127 | move |egraph, _, subst| { 128 | egraph[subst[var]] 
129 | .nodes 130 | .iter() 131 | .any(|n| matches!(n, Math::Symbol(..))) 132 | } 133 | } 134 | 135 | fn is_not_zero(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { 136 | let var = var.parse().unwrap(); 137 | move |egraph, _, subst| { 138 | if let Some(n) = &egraph[subst[var]].data { 139 | *(n.0) != 0.0 140 | } else { 141 | true 142 | } 143 | } 144 | } 145 | 146 | #[rustfmt::skip] 147 | pub fn rules() -> Vec { vec![ 148 | rw!("comm-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), 149 | rw!("comm-mul"; "(* ?a ?b)" => "(* ?b ?a)"), 150 | rw!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"), 151 | rw!("assoc-mul"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"), 152 | 153 | rw!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"), 154 | rw!("div-canon"; "(/ ?a ?b)" => "(* ?a (pow ?b -1))" if is_not_zero("?b")), 155 | // rw!("canon-sub"; "(+ ?a (* -1 ?b))" => "(- ?a ?b)"), 156 | // rw!("canon-div"; "(* ?a (pow ?b -1))" => "(/ ?a ?b)" if is_not_zero("?b")), 157 | 158 | rw!("zero-add"; "(+ ?a 0)" => "?a"), 159 | rw!("zero-mul"; "(* ?a 0)" => "0"), 160 | rw!("one-mul"; "(* ?a 1)" => "?a"), 161 | 162 | rw!("add-zero"; "?a" => "(+ ?a 0)"), 163 | rw!("mul-one"; "?a" => "(* ?a 1)"), 164 | 165 | rw!("cancel-sub"; "(- ?a ?a)" => "0"), 166 | rw!("cancel-div"; "(/ ?a ?a)" => "1" if is_not_zero("?a")), 167 | 168 | rw!("distribute"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"), 169 | rw!("factor" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"), 170 | 171 | rw!("pow-mul"; "(* (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (+ ?b ?c))"), 172 | rw!("pow0"; "(pow ?x 0)" => "1" 173 | if is_not_zero("?x")), 174 | rw!("pow1"; "(pow ?x 1)" => "?x"), 175 | rw!("pow2"; "(pow ?x 2)" => "(* ?x ?x)"), 176 | rw!("pow-recip"; "(pow ?x -1)" => "(/ 1 ?x)" 177 | if is_not_zero("?x")), 178 | rw!("recip-mul-div"; "(* ?x (/ 1 ?x))" => "1" if is_not_zero("?x")), 179 | 180 | rw!("d-variable"; "(d ?x ?x)" => "1" if is_sym("?x")), 181 | rw!("d-constant"; "(d ?x ?c)" => "0" if is_sym("?x") if 
is_const_or_distinct_var("?c", "?x")), 182 | 183 | rw!("d-add"; "(d ?x (+ ?a ?b))" => "(+ (d ?x ?a) (d ?x ?b))"), 184 | rw!("d-mul"; "(d ?x (* ?a ?b))" => "(+ (* ?a (d ?x ?b)) (* ?b (d ?x ?a)))"), 185 | 186 | rw!("d-sin"; "(d ?x (sin ?x))" => "(cos ?x)"), 187 | rw!("d-cos"; "(d ?x (cos ?x))" => "(* -1 (sin ?x))"), 188 | 189 | rw!("d-ln"; "(d ?x (ln ?x))" => "(/ 1 ?x)" if is_not_zero("?x")), 190 | 191 | rw!("d-power"; 192 | "(d ?x (pow ?f ?g))" => 193 | "(* (pow ?f ?g) 194 | (+ (* (d ?x ?f) 195 | (/ ?g ?f)) 196 | (* (d ?x ?g) 197 | (ln ?f))))" 198 | if is_not_zero("?f") 199 | if is_not_zero("?g") 200 | ), 201 | 202 | rw!("i-one"; "(i 1 ?x)" => "?x"), 203 | rw!("i-power-const"; "(i (pow ?x ?c) ?x)" => 204 | "(/ (pow ?x (+ ?c 1)) (+ ?c 1))" if is_const("?c")), 205 | rw!("i-cos"; "(i (cos ?x) ?x)" => "(sin ?x)"), 206 | rw!("i-sin"; "(i (sin ?x) ?x)" => "(* -1 (cos ?x))"), 207 | rw!("i-sum"; "(i (+ ?f ?g) ?x)" => "(+ (i ?f ?x) (i ?g ?x))"), 208 | rw!("i-dif"; "(i (- ?f ?g) ?x)" => "(- (i ?f ?x) (i ?g ?x))"), 209 | rw!("i-parts"; "(i (* ?a ?b) ?x)" => 210 | "(- (* ?a (i ?b ?x)) (i (* (d ?x ?a) (i ?b ?x)) ?x))"), 211 | ]} 212 | 213 | egg::test_fn! { 214 | math_associate_adds, [ 215 | rw!("comm-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), 216 | rw!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"), 217 | ], 218 | runner = Runner::default() 219 | .with_iter_limit(7) 220 | .with_scheduler(SimpleScheduler), 221 | "(+ 1 (+ 2 (+ 3 (+ 4 (+ 5 (+ 6 7))))))" 222 | => 223 | "(+ 7 (+ 6 (+ 5 (+ 4 (+ 3 (+ 2 1))))))" 224 | @check |r: Runner| assert_eq!(r.egraph.number_of_classes(), 127) 225 | } 226 | 227 | egg::test_fn! { 228 | #[should_panic(expected = "Could not prove goal 0")] 229 | math_fail, rules(), 230 | "(+ x y)" => "(/ x y)" 231 | } 232 | 233 | egg::test_fn! {math_simplify_add, rules(), "(+ x (+ x (+ x x)))" => "(* 4 x)" } 234 | egg::test_fn! {math_powers, rules(), "(* (pow 2 x) (pow 2 y))" => "(pow 2 (+ x y))"} 235 | 236 | egg::test_fn! 
{ 237 | math_simplify_const, rules(), 238 | "(+ 1 (- a (* (- 2 1) a)))" => "1" 239 | } 240 | 241 | egg::test_fn! { 242 | math_simplify_root, rules(), 243 | runner = Runner::default().with_node_limit(75_000), 244 | r#" 245 | (/ 1 246 | (- (/ (+ 1 (sqrt five)) 247 | 2) 248 | (/ (- 1 (sqrt five)) 249 | 2)))"# 250 | => 251 | "(/ 1 (sqrt five))" 252 | } 253 | 254 | egg::test_fn! { 255 | math_simplify_factor, rules(), 256 | "(* (+ x 3) (+ x 1))" 257 | => 258 | "(+ (+ (* x x) (* 4 x)) 3)" 259 | } 260 | 261 | egg::test_fn! {math_diff_same, rules(), "(d x x)" => "1"} 262 | egg::test_fn! {math_diff_different, rules(), "(d x y)" => "0"} 263 | egg::test_fn! {math_diff_simple1, rules(), "(d x (+ 1 (* 2 x)))" => "2"} 264 | egg::test_fn! {math_diff_simple2, rules(), "(d x (+ 1 (* y x)))" => "y"} 265 | egg::test_fn! {math_diff_ln, rules(), "(d x (ln x))" => "(/ 1 x)"} 266 | 267 | egg::test_fn! { 268 | diff_power_simple, rules(), 269 | "(d x (pow x 3))" => "(* 3 (pow x 2))" 270 | } 271 | 272 | egg::test_fn! { 273 | diff_power_harder, rules(), 274 | runner = Runner::default() 275 | .with_time_limit(std::time::Duration::from_secs(10)) 276 | .with_iter_limit(60) 277 | .with_node_limit(100_000) 278 | .with_explanations_enabled() 279 | // HACK this needs to "see" the end expression 280 | .with_expr(&"(* x (- (* 3 x) 14))".parse().unwrap()), 281 | "(d x (- (pow x 3) (* 7 (pow x 2))))" 282 | => 283 | "(* x (- (* 3 x) 14))" 284 | } 285 | 286 | egg::test_fn! { 287 | integ_one, rules(), "(i 1 x)" => "x" 288 | } 289 | 290 | egg::test_fn! { 291 | integ_sin, rules(), "(i (cos x) x)" => "(sin x)" 292 | } 293 | 294 | egg::test_fn! { 295 | integ_x, rules(), "(i (pow x 1) x)" => "(/ (pow x 2) 2)" 296 | } 297 | 298 | egg::test_fn! { 299 | integ_part1, rules(), "(i (* x (cos x)) x)" => "(+ (* x (sin x)) (cos x))" 300 | } 301 | 302 | egg::test_fn! { 303 | integ_part2, rules(), 304 | "(i (* (cos x) x) x)" => "(+ (* x (sin x)) (cos x))" 305 | } 306 | 307 | egg::test_fn! 
{ 308 | integ_part3, rules(), "(i (ln x) x)" => "(- (* x (ln x)) x)" 309 | } 310 | 311 | #[test] 312 | fn assoc_mul_saturates() { 313 | let expr: RecExpr = "(* x 1)".parse().unwrap(); 314 | 315 | let runner: Runner = Runner::default() 316 | .with_iter_limit(3) 317 | .with_expr(&expr) 318 | .run(&rules()); 319 | 320 | assert!(matches!(runner.stop_reason, Some(StopReason::Saturated))); 321 | } 322 | 323 | #[test] 324 | fn test_union_trusted() { 325 | let expr: RecExpr = "(+ (* x 1) y)".parse().unwrap(); 326 | let expr2 = "20".parse().unwrap(); 327 | let mut runner: Runner = Runner::default() 328 | .with_explanations_enabled() 329 | .with_iter_limit(3) 330 | .with_expr(&expr) 331 | .run(&rules()); 332 | let lhs = runner.egraph.add_expr(&expr); 333 | let rhs = runner.egraph.add_expr(&expr2); 334 | runner.egraph.union_trusted(lhs, rhs, "whatever"); 335 | let proof = runner.explain_equivalence(&expr, &expr2).get_flat_strings(); 336 | assert_eq!(proof, vec!["(+ (* x 1) y)", "(Rewrite=> whatever 20)"]); 337 | } 338 | 339 | #[cfg(feature = "lp")] 340 | #[test] 341 | fn math_lp_extract() { 342 | let expr: RecExpr = "(pow (+ x (+ x x)) (+ x x))".parse().unwrap(); 343 | 344 | let runner: Runner = Runner::default() 345 | .with_iter_limit(3) 346 | .with_expr(&expr) 347 | .run(&rules()); 348 | let root = runner.roots[0]; 349 | 350 | let best = Extractor::new(&runner.egraph, AstSize).find_best(root).1; 351 | let lp_best = LpExtractor::new(&runner.egraph, AstSize).solve(root); 352 | 353 | println!("input [{}] {}", expr.as_ref().len(), expr); 354 | println!("normal [{}] {}", best.as_ref().len(), best); 355 | println!("ilp cse [{}] {}", lp_best.as_ref().len(), lp_best); 356 | 357 | assert_ne!(best, lp_best); 358 | assert_eq!(lp_best.as_ref().len(), 4); 359 | } 360 | 361 | #[test] 362 | fn math_ematching_bench() { 363 | let exprs = &[ 364 | "(i (ln x) x)", 365 | "(i (+ x (cos x)) x)", 366 | "(i (* (cos x) x) x)", 367 | "(d x (+ 1 (* 2 x)))", 368 | "(d x (- (pow x 3) (* 7 (pow x 
2))))", 369 | "(+ (* y (+ x y)) (- (+ x 2) (+ x x)))", 370 | "(/ 1 (- (/ (+ 1 (sqrt five)) 2) (/ (- 1 (sqrt five)) 2)))", 371 | ]; 372 | 373 | let extra_patterns = &[ 374 | "(+ ?a (+ ?b ?c))", 375 | "(+ (+ ?a ?b) ?c)", 376 | "(* ?a (* ?b ?c))", 377 | "(* (* ?a ?b) ?c)", 378 | "(+ ?a (* -1 ?b))", 379 | "(* ?a (pow ?b -1))", 380 | "(* ?a (+ ?b ?c))", 381 | "(pow ?a (+ ?b ?c))", 382 | "(+ (* ?a ?b) (* ?a ?c))", 383 | "(* (pow ?a ?b) (pow ?a ?c))", 384 | "(* ?x (/ 1 ?x))", 385 | "(d ?x (+ ?a ?b))", 386 | "(+ (d ?x ?a) (d ?x ?b))", 387 | "(d ?x (* ?a ?b))", 388 | "(+ (* ?a (d ?x ?b)) (* ?b (d ?x ?a)))", 389 | "(d ?x (sin ?x))", 390 | "(d ?x (cos ?x))", 391 | "(* -1 (sin ?x))", 392 | "(* -1 (cos ?x))", 393 | "(i (cos ?x) ?x)", 394 | "(i (sin ?x) ?x)", 395 | "(d ?x (ln ?x))", 396 | "(d ?x (pow ?f ?g))", 397 | "(* (pow ?f ?g) (+ (* (d ?x ?f) (/ ?g ?f)) (* (d ?x ?g) (ln ?f))))", 398 | "(i (pow ?x ?c) ?x)", 399 | "(/ (pow ?x (+ ?c 1)) (+ ?c 1))", 400 | "(i (+ ?f ?g) ?x)", 401 | "(i (- ?f ?g) ?x)", 402 | "(+ (i ?f ?x) (i ?g ?x))", 403 | "(- (i ?f ?x) (i ?g ?x))", 404 | "(i (* ?a ?b) ?x)", 405 | "(- (* ?a (i ?b ?x)) (i (* (d ?x ?a) (i ?b ?x)) ?x))", 406 | ]; 407 | 408 | egg::test::bench_egraph("math", rules(), exprs, extra_patterns); 409 | } 410 | 411 | #[test] 412 | fn test_basic_egraph_union_intersect() { 413 | let mut egraph1 = EGraph::new(ConstantFold {}).with_explanations_enabled(); 414 | let mut egraph2 = EGraph::new(ConstantFold {}).with_explanations_enabled(); 415 | egraph1.union_instantiations( 416 | &"x".parse().unwrap(), 417 | &"y".parse().unwrap(), 418 | &Default::default(), 419 | "", 420 | ); 421 | egraph1.union_instantiations( 422 | &"y".parse().unwrap(), 423 | &"z".parse().unwrap(), 424 | &Default::default(), 425 | "", 426 | ); 427 | egraph2.union_instantiations( 428 | &"x".parse().unwrap(), 429 | &"y".parse().unwrap(), 430 | &Default::default(), 431 | "", 432 | ); 433 | egraph2.union_instantiations( 434 | &"x".parse().unwrap(), 435 | &"a".parse().unwrap(), 436 
| &Default::default(), 437 | "", 438 | ); 439 | 440 | let mut egraph3 = egraph1.egraph_intersect(&egraph2, ConstantFold {}); 441 | 442 | egraph2.egraph_union(&egraph1); 443 | 444 | assert_eq!( 445 | egraph2.add_expr(&"x".parse().unwrap()), 446 | egraph2.add_expr(&"y".parse().unwrap()) 447 | ); 448 | assert_eq!( 449 | egraph3.add_expr(&"x".parse().unwrap()), 450 | egraph3.add_expr(&"y".parse().unwrap()) 451 | ); 452 | 453 | assert_eq!( 454 | egraph2.add_expr(&"x".parse().unwrap()), 455 | egraph2.add_expr(&"z".parse().unwrap()) 456 | ); 457 | assert_ne!( 458 | egraph3.add_expr(&"x".parse().unwrap()), 459 | egraph3.add_expr(&"z".parse().unwrap()) 460 | ); 461 | assert_eq!( 462 | egraph2.add_expr(&"x".parse().unwrap()), 463 | egraph2.add_expr(&"a".parse().unwrap()) 464 | ); 465 | assert_ne!( 466 | egraph3.add_expr(&"x".parse().unwrap()), 467 | egraph3.add_expr(&"a".parse().unwrap()) 468 | ); 469 | 470 | assert_eq!( 471 | egraph2.add_expr(&"y".parse().unwrap()), 472 | egraph2.add_expr(&"a".parse().unwrap()) 473 | ); 474 | assert_ne!( 475 | egraph3.add_expr(&"y".parse().unwrap()), 476 | egraph3.add_expr(&"a".parse().unwrap()) 477 | ); 478 | } 479 | 480 | #[test] 481 | fn test_intersect_basic() { 482 | let mut egraph1 = EGraph::new(ConstantFold {}).with_explanations_enabled(); 483 | let mut egraph2 = EGraph::new(ConstantFold {}).with_explanations_enabled(); 484 | egraph1.union_instantiations( 485 | &"(+ x 0)".parse().unwrap(), 486 | &"(+ y 0)".parse().unwrap(), 487 | &Default::default(), 488 | "", 489 | ); 490 | egraph2.union_instantiations( 491 | &"x".parse().unwrap(), 492 | &"y".parse().unwrap(), 493 | &Default::default(), 494 | "", 495 | ); 496 | egraph2.add_expr(&"(+ x 0)".parse().unwrap()); 497 | egraph2.add_expr(&"(+ y 0)".parse().unwrap()); 498 | 499 | let mut egraph3 = egraph1.egraph_intersect(&egraph2, ConstantFold {}); 500 | 501 | assert_ne!( 502 | egraph3.add_expr(&"x".parse().unwrap()), 503 | egraph3.add_expr(&"y".parse().unwrap()) 504 | ); 505 | assert_eq!( 
506 | egraph3.add_expr(&"(+ x 0)".parse().unwrap()), 507 | egraph3.add_expr(&"(+ y 0)".parse().unwrap()) 508 | ); 509 | } 510 | 511 | #[test] 512 | fn test_medium_intersect() { 513 | let mut egraph1 = egg::EGraph::::new(()); 514 | 515 | egraph1.add_expr(&"(sqrt (ln 1))".parse().unwrap()); 516 | let ln = egraph1.add_expr(&"(ln 1)".parse().unwrap()); 517 | let a = egraph1.add_expr(&"(sqrt (sin pi))".parse().unwrap()); 518 | let b = egraph1.add_expr(&"(* 1 pi)".parse().unwrap()); 519 | let pi = egraph1.add_expr(&"pi".parse().unwrap()); 520 | egraph1.union(a, b); 521 | egraph1.union(a, pi); 522 | let c = egraph1.add_expr(&"(+ pi pi)".parse().unwrap()); 523 | egraph1.union(ln, c); 524 | let k = egraph1.add_expr(&"k".parse().unwrap()); 525 | let one = egraph1.add_expr(&"1".parse().unwrap()); 526 | egraph1.union(k, one); 527 | egraph1.rebuild(); 528 | 529 | assert_eq!( 530 | egraph1.add_expr(&"(ln k)".parse().unwrap()), 531 | egraph1.add_expr(&"(+ (* k pi) (* k pi))".parse().unwrap()) 532 | ); 533 | 534 | let mut egraph2 = egg::EGraph::::new(()); 535 | let ln = egraph2.add_expr(&"(ln 2)".parse().unwrap()); 536 | let k = egraph2.add_expr(&"k".parse().unwrap()); 537 | let mk1 = egraph2.add_expr(&"(* k 1)".parse().unwrap()); 538 | egraph2.union(mk1, k); 539 | let two = egraph2.add_expr(&"2".parse().unwrap()); 540 | egraph2.union(mk1, two); 541 | let mul2pi = egraph2.add_expr(&"(+ (* 2 pi) (* 2 pi))".parse().unwrap()); 542 | egraph2.union(ln, mul2pi); 543 | egraph2.rebuild(); 544 | 545 | assert_eq!( 546 | egraph2.add_expr(&"(ln k)".parse().unwrap()), 547 | egraph2.add_expr(&"(+ (* k pi) (* k pi))".parse().unwrap()) 548 | ); 549 | 550 | let mut egraph3 = egraph1.egraph_intersect(&egraph2, ()); 551 | 552 | assert_eq!( 553 | egraph3.add_expr(&"(ln k)".parse().unwrap()), 554 | egraph3.add_expr(&"(+ (* k pi) (* k pi))".parse().unwrap()) 555 | ); 556 | } 557 | -------------------------------------------------------------------------------- /tests/prop.rs: 
-------------------------------------------------------------------------------- 1 | use egg::*; 2 | 3 | define_language! { 4 | enum Prop { 5 | Bool(bool), 6 | "&" = And([Id; 2]), 7 | "~" = Not(Id), 8 | "|" = Or([Id; 2]), 9 | "->" = Implies([Id; 2]), 10 | Symbol(Symbol), 11 | } 12 | } 13 | 14 | type EGraph = egg::EGraph; 15 | type Rewrite = egg::Rewrite; 16 | 17 | #[derive(Default)] 18 | struct ConstantFold; 19 | impl Analysis for ConstantFold { 20 | type Data = Option<(bool, PatternAst)>; 21 | fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { 22 | merge_option(to, from, |a, b| { 23 | assert_eq!(a.0, b.0, "Merged non-equal constants"); 24 | DidMerge(false, false) 25 | }) 26 | } 27 | 28 | fn make(egraph: &mut EGraph, enode: &Prop) -> Self::Data { 29 | let x = |i: &Id| egraph[*i].data.as_ref().map(|c| c.0); 30 | let result = match enode { 31 | Prop::Bool(c) => Some((*c, c.to_string().parse().unwrap())), 32 | Prop::Symbol(_) => None, 33 | Prop::And([a, b]) => Some(( 34 | x(a)? && x(b)?, 35 | format!("(& {} {})", x(a)?, x(b)?).parse().unwrap(), 36 | )), 37 | Prop::Not(a) => Some((!x(a)?, format!("(~ {})", x(a)?).parse().unwrap())), 38 | Prop::Or([a, b]) => Some(( 39 | x(a)? || x(b)?, 40 | format!("(| {} {})", x(a)?, x(b)?).parse().unwrap(), 41 | )), 42 | Prop::Implies([a, b]) => Some(( 43 | !x(a)? || x(b)?, 44 | format!("(-> {} {})", x(a)?, x(b)?).parse().unwrap(), 45 | )), 46 | }; 47 | println!("Make: {:?} -> {:?}", enode, result); 48 | result 49 | } 50 | 51 | fn modify(egraph: &mut EGraph, id: Id) { 52 | if let Some(c) = egraph[id].data.clone() { 53 | egraph.union_instantiations( 54 | &c.1, 55 | &c.0.to_string().parse().unwrap(), 56 | &Default::default(), 57 | "analysis".to_string(), 58 | ); 59 | } 60 | } 61 | } 62 | 63 | macro_rules! 
rule { 64 | ($name:ident, $left:literal, $right:literal) => { 65 | #[allow(dead_code)] 66 | fn $name() -> Rewrite { 67 | rewrite!(stringify!($name); $left => $right) 68 | } 69 | }; 70 | ($name:ident, $name2:ident, $left:literal, $right:literal) => { 71 | rule!($name, $left, $right); 72 | rule!($name2, $right, $left); 73 | }; 74 | } 75 | 76 | rule! {def_imply, def_imply_flip, "(-> ?a ?b)", "(| (~ ?a) ?b)" } 77 | rule! {double_neg, double_neg_flip, "(~ (~ ?a))", "?a" } 78 | rule! {assoc_or, "(| ?a (| ?b ?c))", "(| (| ?a ?b) ?c)" } 79 | rule! {dist_and_or, "(& ?a (| ?b ?c))", "(| (& ?a ?b) (& ?a ?c))"} 80 | rule! {dist_or_and, "(| ?a (& ?b ?c))", "(& (| ?a ?b) (| ?a ?c))"} 81 | rule! {comm_or, "(| ?a ?b)", "(| ?b ?a)" } 82 | rule! {comm_and, "(& ?a ?b)", "(& ?b ?a)" } 83 | rule! {lem, "(| ?a (~ ?a))", "true" } 84 | rule! {or_true, "(| ?a true)", "true" } 85 | rule! {and_true, "(& ?a true)", "?a" } 86 | rule! {contrapositive, "(-> ?a ?b)", "(-> (~ ?b) (~ ?a))" } 87 | 88 | // this has to be a multipattern since (& (-> ?a ?b) (-> (~ ?a) ?c)) != (| ?b ?c) 89 | // see https://github.com/egraphs-good/egg/issues/185 90 | fn lem_imply() -> Rewrite { 91 | multi_rewrite!( 92 | "lem_imply"; 93 | "?value = true = (& (-> ?a ?b) (-> (~ ?a) ?c))" 94 | => 95 | "?value = (| ?b ?c)" 96 | ) 97 | } 98 | 99 | fn prove_something(name: &str, start: &str, rewrites: &[Rewrite], goals: &[&str]) { 100 | let _ = env_logger::builder().is_test(true).try_init(); 101 | println!("Proving {}", name); 102 | 103 | let start_expr: RecExpr<_> = start.parse().unwrap(); 104 | let goal_exprs: Vec> = goals.iter().map(|g| g.parse().unwrap()).collect(); 105 | 106 | let mut runner = Runner::default() 107 | .with_iter_limit(20) 108 | .with_node_limit(5_000) 109 | .with_expr(&start_expr); 110 | 111 | // we are assume the input expr is true 112 | // this is needed for the soundness of lem_imply 113 | let true_id = runner.egraph.add(Prop::Bool(true)); 114 | let root = runner.roots[0]; 115 | runner.egraph.union(root, 
true_id); 116 | runner.egraph.rebuild(); 117 | 118 | let egraph = runner.run(rewrites).egraph; 119 | 120 | for (i, (goal_expr, goal_str)) in goal_exprs.iter().zip(goals).enumerate() { 121 | println!("Trying to prove goal {}: {}", i, goal_str); 122 | let equivs = egraph.equivs(&start_expr, goal_expr); 123 | if equivs.is_empty() { 124 | panic!("Couldn't prove goal {}: {}", i, goal_str); 125 | } 126 | } 127 | } 128 | 129 | #[test] 130 | fn prove_contrapositive() { 131 | let _ = env_logger::builder().is_test(true).try_init(); 132 | let rules = &[def_imply(), def_imply_flip(), double_neg_flip(), comm_or()]; 133 | prove_something( 134 | "contrapositive", 135 | "(-> x y)", 136 | rules, 137 | &[ 138 | "(-> x y)", 139 | "(| (~ x) y)", 140 | "(| (~ x) (~ (~ y)))", 141 | "(| (~ (~ y)) (~ x))", 142 | "(-> (~ y) (~ x))", 143 | ], 144 | ); 145 | } 146 | 147 | #[test] 148 | fn prove_chain() { 149 | let _ = env_logger::builder().is_test(true).try_init(); 150 | let rules = &[ 151 | // rules needed for contrapositive 152 | def_imply(), 153 | def_imply_flip(), 154 | double_neg_flip(), 155 | comm_or(), 156 | // and some others 157 | comm_and(), 158 | lem_imply(), 159 | ]; 160 | prove_something( 161 | "chain", 162 | "(& (-> x y) (-> y z))", 163 | rules, 164 | &[ 165 | "(& (-> x y) (-> y z))", 166 | "(& (-> (~ y) (~ x)) (-> y z))", 167 | "(& (-> y z) (-> (~ y) (~ x)))", 168 | "(| z (~ x))", 169 | "(| (~ x) z)", 170 | "(-> x z)", 171 | ], 172 | ); 173 | } 174 | 175 | #[test] 176 | fn const_fold() { 177 | let start = "(| (& false true) (& true false))"; 178 | let start_expr = start.parse().unwrap(); 179 | let end = "false"; 180 | let end_expr = end.parse().unwrap(); 181 | let mut eg = EGraph::default(); 182 | eg.add_expr(&start_expr); 183 | eg.rebuild(); 184 | assert!(!eg.equivs(&start_expr, &end_expr).is_empty()); 185 | } 186 | -------------------------------------------------------------------------------- /tests/simple.rs: 
-------------------------------------------------------------------------------- 1 | use egg::*; 2 | 3 | define_language! { 4 | enum SimpleLanguage { 5 | Num(i32), 6 | "+" = Add([Id; 2]), 7 | "*" = Mul([Id; 2]), 8 | Symbol(Symbol), 9 | } 10 | } 11 | 12 | fn make_rules() -> Vec> { 13 | vec![ 14 | rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), 15 | rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"), 16 | rewrite!("add-0"; "(+ ?a 0)" => "?a"), 17 | rewrite!("mul-0"; "(* ?a 0)" => "0"), 18 | rewrite!("mul-1"; "(* ?a 1)" => "?a"), 19 | ] 20 | } 21 | 22 | /// parse an expression, simplify it using egg, and pretty print it back out 23 | fn simplify(s: &str) -> String { 24 | // parse the expression, the type annotation tells it which Language to use 25 | let expr: RecExpr = s.parse().unwrap(); 26 | 27 | // simplify the expression using a Runner, which creates an e-graph with 28 | // the given expression and runs the given rules over it 29 | let runner = Runner::default().with_expr(&expr).run(&make_rules()); 30 | 31 | // the Runner knows which e-class the expression given with `with_expr` is in 32 | let root = runner.roots[0]; 33 | 34 | // use an Extractor to pick the best element of the root eclass 35 | let extractor = Extractor::new(&runner.egraph, AstSize); 36 | let (best_cost, best) = extractor.find_best(root); 37 | println!("Simplified {} to {} with cost {}", expr, best, best_cost); 38 | best.to_string() 39 | } 40 | 41 | #[test] 42 | fn simple_tests() { 43 | assert_eq!(simplify("(* 0 42)"), "0"); 44 | assert_eq!(simplify("(+ 0 (* 1 foo))"), "foo"); 45 | } 46 | --------------------------------------------------------------------------------