├── .idea ├── vcs.xml ├── .gitignore ├── modules.xml ├── egg.iml └── misc.xml ├── .gitignore ├── scripts ├── nightly-static │ ├── .eslintrc.json │ ├── index.html │ └── main.js └── run-nightly.sh ├── src ├── eggstentions │ ├── pretty_string.rs │ ├── mod.rs │ ├── appliers.rs │ ├── costs.rs │ ├── searchers.rs │ ├── tree.rs │ ├── expression_ops.rs │ └── reconstruct.rs ├── explanation.rs ├── tutorials │ ├── mod.rs │ └── _02_getting_started.rs ├── eclass.rs ├── util.rs ├── lib.rs ├── expression_ops.rs ├── subst.rs ├── tools.rs ├── test.rs ├── unionfind.rs ├── extract.rs ├── colors.rs ├── macros.rs └── rewrite.rs ├── Makefile ├── LICENSE ├── release.toml ├── .github └── workflows │ └── build.yml ├── Cargo.toml ├── doc ├── egraphs.drawio └── egg.svg ├── README.md ├── CHANGELOG.md └── tests ├── prop.rs ├── lambda.rs └── math.rs /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | **/*.rs.bk 3 | Cargo.lock 4 | 5 | dots/ 6 | egg/data/ 7 | 8 | # nix stuff 9 | .envrc 10 | default.nix 11 | 12 | # other stuff 13 | TODO.md 14 | 15 | *.dot 16 | .idea/easter-egg.iml 17 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | workspace.xml -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /scripts/nightly-static/.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "parserOptions": {"ecmaVersion": 8 }, 3 | "env": { 4 | "es6": true, 5 | "browser": true 6 | }, 7 | "globals": { 8 | "d3": "readonly", 9 | "tippy": "readonly" 10 | }, 11 | "extends": "eslint:recommended" 12 | } 13 | -------------------------------------------------------------------------------- /src/eggstentions/pretty_string.rs: -------------------------------------------------------------------------------- 1 | use crate::{Language, Pattern}; 2 | 3 | /// A trait for pretty printing a pattern. 4 | pub trait PrettyString { 5 | /// Returns a string representation of the pattern. 6 | fn pretty_string(&self) -> String; 7 | } 8 | 9 | impl PrettyString for Pattern { 10 | fn pretty_string(&self) -> String { 11 | self.pretty(500) 12 | } 13 | } -------------------------------------------------------------------------------- /src/eggstentions/mod.rs: -------------------------------------------------------------------------------- 1 | /// Module for code related to ematching an EGraph. 2 | pub mod searchers; 3 | /// Module for special appliers to create specialized Rewrites. 4 | pub mod appliers; 5 | /// Module for code related to reconstructing a term from the EGraph. 6 | pub mod reconstruct; 7 | /// Module for code related to the cost functions. 8 | pub mod costs; 9 | /// Module for utilities when dealing with expressions. 10 | pub mod expression_ops; 11 | /// Module for code related to the pretty printing of terms. 12 | pub mod pretty_string; 13 | /// Module for code related to the tree structure of terms. 
14 | pub mod tree; 15 | -------------------------------------------------------------------------------- /.idea/egg.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /src/explanation.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused_imports)] 2 | use derive_new::new; 3 | use crate::{unionfind::UnionFindWrapper, Analysis, Id, Language}; 4 | 5 | // /// Explanation object collects all information needed to explain existence and equality of enodes in the egraph. 6 | // /// It holds a mapping from an Id to the node added when said Id was first added to the egraph. 7 | // /// It also holds a union-find structure to query equivalencies between nodes that isn't merged upwards. 8 | // #[derive(Debug, new)] 9 | // pub struct Explanation> { 10 | // /// Just handle everything with a non collapsing union find.
11 | // uf: UnionFindWrapper<(), Id>, 12 | // } -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: 2 | test: test-egg 3 | cargo fmt -- --check 4 | cargo clean --doc 5 | cargo doc --no-deps 6 | cargo deadlinks 7 | 8 | .PHONY: test-egg 9 | test-egg: 10 | cargo build 11 | cargo test --release 12 | 13 | cargo clippy --tests 14 | cargo clippy --tests --features "serde-1" 15 | cargo clippy --tests --features "reports" 16 | 17 | .PHONY: deploy-nightlies 18 | deploy-nightlies: 19 | rsync -ri --exclude=".*" scripts/nightly-static/ ~/public/egg-nightlies/ 20 | 21 | .PHONY: nightly 22 | nightly: 23 | bash scripts/run-nightly.sh 24 | 25 | # makefile hack to run my hacky benchmarks 26 | bench: 27 | cargo test --features "reports" --release -- --test-threads=1 --nocapture 28 | bench-%: 29 | cargo test --features "reports" --release -- --test-threads=1 --nocapture $* 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019 Max Willsey 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or 
substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /release.toml: -------------------------------------------------------------------------------- 1 | dev-version-ext = "dev" 2 | 3 | [[pre-release-replacements]] 4 | file = "README.md" 5 | search = "egg = .*" 6 | replace="{{crate_name}} = \"{{version}}\"" 7 | 8 | [[pre-release-replacements]] 9 | file = "src/tutorials/_02_getting_started.rs" 10 | search = "egg = .*" 11 | replace="{{crate_name}} = \"{{version}}\"" 12 | 13 | [[pre-release-replacements]] 14 | file = "CHANGELOG.md" 15 | search = "Unreleased" 16 | replace = "{{version}}" 17 | 18 | [[pre-release-replacements]] 19 | file = "CHANGELOG.md" 20 | search = "\\.\\.\\.HEAD" 21 | replace = "...{{tag_name}}" 22 | exactly=1 23 | 24 | [[pre-release-replacements]] 25 | file = "CHANGELOG.md" 26 | search = "ReleaseDate" 27 | replace = "{{date}}" 28 | 29 | [[pre-release-replacements]] 30 | file = "CHANGELOG.md" 31 | search = "" 32 | replace = "\n\n## [Unreleased] - ReleaseDate" 33 | exactly = 1 34 | 35 | [[pre-release-replacements]] 36 | file = "CHANGELOG.md" 37 | search = "" 38 | replace = "\n[Unreleased]: https://github.com/mwillsey/egg/compare/{{tag_name}}...HEAD" 39 | exactly = 1 40 | -------------------------------------------------------------------------------- /src/tutorials/mod.rs: -------------------------------------------------------------------------------- 1 | // -*- mode: markdown; 
markdown-fontify-code-block-default-mode: rustic-mode; -*- 2 | /*! 3 | 4 | # A Guide-level Explanation of `egg` 5 | 6 | `egg` is an e-graph library optimized for equality saturation. 7 | Using these techniques, one can pretty easily build an optimizer or synthesizer for a language. 8 | 9 | This tutorial is targeted at readers who may not have seen e-graphs, equality saturation, or Rust. 10 | If you already know some of that stuff, you may just want to skim or skip those chapters. 11 | 12 | This is intended to be a guide-level introduction using examples to build intuition. 13 | For more detail, check out the [API documentation](../index.html), 14 | which the tutorials will link to frequently. 15 | Most of the code examples here are typechecked and run, so you may copy-paste them to get started. 16 | 17 | There is also a [paper](https://arxiv.org/abs/2004.03082) 18 | describing `egg` if you are keen to read more about its technical novelties. 19 | 20 | The tutorials are a work-in-progress with more to be added soon. 21 | 22 | */ 23 | 24 | pub mod _01_background; 25 | pub mod _02_getting_started; 26 | -------------------------------------------------------------------------------- /scripts/nightly-static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Egg Nightly 6 | 7 | 8 | 9 | 10 | 11 | 44 | 45 | 46 |
    47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /scripts/run-nightly.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | function run() { 4 | echo $@ 5 | $@ 6 | } 7 | 8 | data_dir="$HOME/public/egg-nightlies/data" 9 | test_dir="tests/" 10 | 11 | rev=$(git describe --long --dirty --abbrev=1000) 12 | if [[ "$rev" == *-dirty ]]; then 13 | # if we are dirty, use the current time 14 | now=$(date --iso-8601=seconds) 15 | else 16 | # if clean, use the commit time 17 | now=$(git log -1 --format=%cI) 18 | fi 19 | 20 | 21 | # get the branch either from an environment variable or git 22 | # and replace silly characters 23 | branch=${NIGHTLY_BRANCH:-$(git rev-parse --abbrev-ref HEAD)} 24 | branch=$(echo -n "$branch" | tr -c "[:alnum:]-_" "@") 25 | 26 | run_dir="${data_dir}/${now}___${branch}___${rev}" 27 | 28 | echo "Running nightly into ${run_dir}" 29 | if [ -d "$run_dir" ]; then 30 | echo "Already exists: $run_dir" 31 | exit 0 32 | else 33 | mkdir "$run_dir" 34 | fi 35 | 36 | suites=$(ls $test_dir) 37 | echo -e "Found test suites in ${test_dir}\n${suites}" 38 | 39 | EGG_BENCH=${EGG_BENCH:-10} 40 | export EGG_BENCH 41 | 42 | for suite in $suites; do 43 | suite=${suite/%.rs/} 44 | test_dir="$run_dir/$suite/" 45 | mkdir $test_dir 46 | cmd="cargo test --features reports --no-fail-fast --release \ 47 | --test ${suite} -- --test-threads=1" 48 | EGG_BENCH_DIR=$test_dir run $cmd 49 | done 50 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | # Controls when the action will run. 
4 | on: 5 | # Triggers the workflow on push or pull request events but only for the dev and master branches 6 | push: 7 | branches: [ dev, master ] 8 | pull_request: 9 | branches: [ dev, master ] 10 | 11 | # Allows you to run this workflow manually from the Actions tab 12 | workflow_dispatch: 13 | 14 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v2 23 | - name: Setup rust 24 | uses: actions-rs/toolchain@v1 25 | with: 26 | toolchain: stable 27 | components: rustfmt, clippy 28 | - name: Setup Clippy 29 | uses: actions-rs/clippy-check@v1 30 | with: 31 | token: ${{ secrets.GITHUB_TOKEN }} 32 | - name: Install graphviz 33 | run: sudo apt-get update && sudo apt-get install graphviz 34 | - name: Install cargo deadlinks 35 | run: which cargo-deadlinks || cargo install cargo-deadlinks 36 | - name: Build 37 | uses: actions-rs/cargo@v1 38 | with: 39 | command: build 40 | toolchain: stable 41 | - name: Test 42 | uses: actions-rs/cargo@v1 43 | with: 44 | command: test 45 | toolchain: stable 46 | -------------------------------------------------------------------------------- /src/eggstentions/appliers.rs: -------------------------------------------------------------------------------- 1 | use crate::{Applier, EGraph, Id, Pattern, SearchMatches, Subst, SymbolLang}; 2 | use std::fmt::Formatter; 3 | 4 | /// A wrapper around an Applier that applies the wrapped applier to every match; its `apply_matches` discards the resulting ids and always returns an empty vector. 5 | pub struct DiffApplier> { 6 | applier: T 7 | } 8 | 9 | impl> DiffApplier { 10 | /// Create a new DiffApplier. 11 | pub fn new(applier: T) -> DiffApplier { 12 | DiffApplier { applier } 13 | } 14 | } 15 | 16 | impl DiffApplier> { 17 | /// Returns a string representation of the pattern.
18 | pub fn pretty(&self, width: usize) -> String { 19 | self.applier.pretty(width) 20 | } 21 | } 22 | 23 | impl> std::fmt::Display for DiffApplier { 24 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 25 | write!(f, "-|> {}", self.applier) 26 | } 27 | } 28 | 29 | impl> Applier for DiffApplier { 30 | fn apply_matches(&self, egraph: &mut EGraph, matches: &Option) -> Vec { 31 | let added = vec![]; 32 | if let Some(mat) = matches { 33 | for (eclass, substs) in &mat.matches { 34 | for subst in substs { 35 | let _ids = self.apply_one(egraph, *eclass, subst); 36 | } 37 | } 38 | } 39 | added 40 | } 41 | 42 | fn apply_one(&self, egraph: &mut EGraph, eclass: Id, subst: &Subst) -> Vec { 43 | self.applier.apply_one(egraph, eclass, subst) 44 | } 45 | } -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "easter-egg" 3 | version = "0.1.1-dev" 4 | authors = ["Max Willsey ", "Eytan Boaz Singher "] 5 | edition = "2018" 6 | description = "An implementation of colored egraphs" 7 | repository = "https://github.com/eytans/easter-egg" 8 | readme = "README.md" 9 | license = "MIT" 10 | keywords = ["e-graphs"] 11 | categories = ["data-structures"] 12 | 13 | [dependencies] 14 | symbolic_expressions = "5" 15 | log = "0.4" 16 | smallvec = "1" 17 | indexmap = "2.8" 18 | instant = "0.1" 19 | once_cell = "1" 20 | 21 | serde = { version = "1", features = ["derive"] } 22 | serde_cbor = "0.11.2" 23 | 24 | # for the reports feature 25 | multimap = "0.10.0" 26 | itertools = "0.14.0" 27 | either = "1.9.0" 28 | global_counter = "0.2.1" 29 | bimap = "0.6.2" 30 | bitvec = "1.0.0" 31 | invariants = "0.1.4" 32 | # for very silly reasons 33 | num-traits = "0.2.15" 34 | ordered-float = "5.0" 35 | lazy_static = "1.4.0" 36 | thiserror = "2.0" 37 | serde_json = "1.0.100" 38 | regex = "1.9.1" 39 | rayon = { version = "1.10.0", optional = true } 40 | 
crossbeam = { version = "0.8.4", optional = true, features = ["crossbeam-channel"] } 41 | num_cpus = "1.16.0" 42 | derive-new = "0.7.0" 43 | maplit = "1.0.2" 44 | derive_more = {version = "2.0", features = ["display"]} 45 | as-any = "0.3.2" 46 | symbol_table = { version = "0.4.0", features = ["global"] } 47 | 48 | [dev-dependencies] 49 | env_logger = {version = "0.11", default-features = false} 50 | 51 | [features] 52 | default = ["parallel"] 53 | #default = ["colored_no_cremove"] 54 | #default = ["colored_no_cmemo"] 55 | upward-merging = [] 56 | wasm-bindgen = [ "instant/wasm-bindgen" ] 57 | reports = [] 58 | colored_no_cremove = [] 59 | colored_no_cmemo = ["colored_no_cremove"] 60 | keep_splits = [] 61 | parallel = ["rayon", "crossbeam", "concurrent_cufind"] 62 | concurrent_cufind = [] 63 | invariants-off = ["invariants/off"] 64 | serde = ["smallvec/serde", "indexmap/serde", "either/serde", "bimap/serde", "ordered-float/serde", "symbol_table/serde"] 65 | -------------------------------------------------------------------------------- /src/eclass.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use std::iter::ExactSizeIterator; 3 | use indexmap::IndexMap; 4 | 5 | use crate::{ColorId, Id, Language}; 6 | 7 | /// An equivalence class of enodes. 8 | #[non_exhaustive] 9 | #[derive(Debug, Clone)] 10 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 11 | pub struct EClass { 12 | /// This eclass's id. 13 | pub id: Id, 14 | /// The equivalent enodes in this equivalence class with their associated colors. 15 | /// No color set means it is a black edge (which will never be removed). 16 | pub nodes: Vec, 17 | /// The analysis data associated with this eclass. 18 | pub data: D, 19 | pub(crate) parents: Vec<(L, Id)>, 20 | 21 | /// Each EClass has a unique color (None is black, i.e. default). 
22 | pub(crate) color: Option, 23 | /// Colored parents are colored_canonized pointing to the black ID of the class. 24 | pub(crate) colored_parents: IndexMap>, 25 | } 26 | 27 | impl EClass { 28 | /// Returns `true` if the `eclass` is empty. 29 | pub fn is_empty(&self) -> bool { 30 | self.nodes.is_empty() 31 | } 32 | 33 | /// Returns the number of enodes in this eclass. 34 | pub fn len(&self) -> usize { 35 | self.nodes.len() 36 | } 37 | 38 | /// Iterates over the enodes in this eclass. 39 | pub fn iter(&self) -> impl ExactSizeIterator { 40 | self.nodes.iter() 41 | } 42 | 43 | /// Returns the color of the EClass if it exists. None means it's a black class. 44 | pub fn color(&self) -> Option { 45 | self.color 46 | } 47 | } 48 | 49 | impl EClass { 50 | /// Iterates over the childless enodes in this eclass. 51 | pub fn leaves(&self) -> impl Iterator { 52 | self.nodes.iter().filter(|n| n.is_leaf()) 53 | } 54 | 55 | /// Asserts that all childless enodes in this eclass are equal (i.e. there is at most one distinct leaf). 56 | pub fn assert_unique_leaves(&self) 57 | where 58 | L: Language, 59 | { 60 | let mut leaves = self.leaves(); 61 | if let Some(first) = leaves.next() { 62 | assert!( 63 | leaves.all(|l| l == first), 64 | "Different leaves in eclass {}: {:?}", 65 | self.id, 66 | self.leaves().collect::>() 67 | ); 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /doc/egraphs.drawio: -------------------------------------------------------------------------------- 1 |
7V3bkto4EP2aedmqpZB85TGZZHYrtcmmKg9JHr3ggLMezBrPhXz9yiAZLBks25LVBpOqDJKFwN1HUqv7tHxn3T++/pEGm9XHZBHGd3i6eL2z3t1hjGYOJn/ymh2tmU6nh5plGi1o3bHiS/QrZA1p7VO0CLelhlmSxFm0KVfOk/U6nGeluiBNk5dysx9JXP7WTbAMhYov8yAWa79Gi2x1qPWd6bH+zzBartg3I3Z/jwFrTCu2q2CRvJxUWe/vrPs0SbLDu8fX+zDOpcfkcvjcw5mrxQ9Lw3Um84EP/vuZ/XPx6dfXb+mnOfrwd/Lt/neL6uc5iJ/oHdNfm+2YCJZp8rQRv43+gOcwzcLXKl0E/7AejrdLgBImj2GW7kg7hgkmoZejgJFN61YnwsUWrQyoUpdFX8f7Jm/orTcQA6qXAhHCehHmnaA76+3LKsrCL5tgnl99IdgndavsMaaXf0RxfJ/ESUrK62RNGr1dBNtV8fG88DnIsjBd72vwNK/dZmnybwEzIrG3otQva5HXxXmZu6LM3QqRu7okLoG7K5E4vepTSUJRgHVrCrANC9y+VYHvygPAlPwdQdzhgiz1tEglmKTZKlkm6yD+K0k2VJI/wyzbUdMkeMqSsh7C1yj7Rt5PJ55Di99PLr3LpTFlhR0tHNTA1HdHVrb9i1wpYyBcM5vIaq6qbfKUzsMLDd1DuyxIl+GlDml/ubwuKj4N4yCLnsumU5Ui6Uc/JxG5kQIw2LYnnj87vtwygJzJzMIuQhZCroddu9z/4S5ol0esvEnTYHfSbJM32J7/Edw87XFGFdcaoUvNyZvD1x9RW0irPZDdionEjbN8Skj2v+0Icfe/p4Rd+H27B9Ib0gDbm9c9nNh18m55+PvA+iJvD92xK53mKtITxTG2STmfRCJiZL+Jo2U+Oz1Gi0Xesb7JyOGWX3EyqrI4bV2TkadPh8GVapAbacgzrEJfUOFvQxcxN0hmhiU80zdI8JUOkhmwQcIG7SWrl9g4b3JPDSnN42C7jeZlqZZN1rLJheVNLr2mlC9pSiFJU+pEYU6FwlhdR4vLd8pw4XFwuG3BphL7mU24nmZcT2esM1U2EZLw4HTBWQPTHgbOZE32fnBWeEW7As31JlN0skPwuenOn8z86fGF+wWhhFPLyGRHsJfuDn1YVlHxvWhLCsdu9qXdaelzmEZENmFKK1VjWnYbasPCtKUI02hKQI2PoEXlbl2zmJbwE55gmrpRBECfeDoe9q/mMAKids4/IMQkpLVu13SkWbGWhGK1RX6Yw4dfFioiQZX2qsZIkGo3bcknKwwGd/+iW4eT+sOr2subkO+KslxkXpPZ2DqjNbgxIuSMugAaPUJVTsjbVI3puBKq8iV28mLUBJbEENKGGYh5p9F6Se3E6giUdkWcjzcV60s/ivFFSwlQxKkYU5ciTqdYIJ/54eT/Og+u2h0Boq6/encKXSK0h6amwwtNeZ2aI6eHUBbS6OI1Hss6N3aar259LWXAgmRstr7CKNnwsAEt/IZFj29f8bfhKQ9YYI9ZqFcY2RseNqCFDHEzj2N/XvTmqtVsJReE+ForGZbn3FYUdbRNRx1x35xOqFtvBM0/hatchy1WGGF5QcMgjjTXILhlQMLDaCSiP+BlQNZZ0s8y4HKIa7sMYJ4TgDkkGw2gYgn37GjOnDqDYMLTaYlOrhu/ZwtF9EF3cji3Rhu9z2prRTWSWHRYnWEsjaVua95MnBrEiMH8KX0uDIu+wgftWG3KFYZhKYz97rKZGTzm5h21EsulgRuJnBPQtB/Jasb61EFOskCtWTwnrTU7iWfm9M1OcqoCQ5xmdbOTHBl2UpUy9bGTLPMOh06ehUKtrT0LprNILWCcJM4mr9CnetVAcfJYwEhIJnQBlB9mmScl9TxRVdDAikW0H5GLCYm3OhqgMcEsqX3dLTHBmE5rd4ZMlbU7Q7YYaGeCOSMTTAMTjOl5ZIJpIAU0n0GBMcHsqo3+yAQzgw1oTDAb
C+C4FSZY54Ft2oNnV5F9RiaYGWxAowDYEt6uqw6dylvJWNZKhkUBmCligs1MM8Hs0S0IlgtmKzowDRoXbHgLDDiSma2ZnAOeZKZhhZH1w/QUEeV3K22XGA/yyTO2hO/X9MkzyPVYReuTZ+ANC2n3pA96WChit3k9W12iX70vdtt5xDXgvRnyqTuyhGHp3UI/bCsGUoD0OFk4QNe4DUzjlX5f0/y6ARrewIh77LEwl0dyWXtmeMuiVaR8bFrGZuMzEc7qcKHqGCZPMayJSvKe64bNGSfvrHl/sbmemKcjOr+BHSGtYItveqbR7ENuvcM/2Rm5CnZGug0I6UkKVqoxnuEJt1Xn/YxtKdE973Uc8fklENZH5UDzZIGmh+/TdRXjyDK1x372Qq5xXJ3YaT0F9mBbDQ1NeAhoqnKaP1yb4WJ8iyR1eqC6YXwyhqcTp9MoBueltaSnAU/1NNDtETMAzmj2+QCK+SwoZsHpDLbXZxK0j6n7nc9XkcyCEoJd6lSg+ompUPgOzXUDje7g4lE1QPNAXEvQxI3ngfiyJwSwRIraJZzNTGMeyCDzQNwqf92YB6ImcNR8BgWWB+IqOtZtzANRYZQCI/O5Is/zVvJAOg9s094OV+MDcfE4sJthAxpL14XKbuzL0SVvJcMi36pK7+APDOY3aJpDb27VYURjAoEJkwPa1MS+Hxy9AOzUpG4DPyYQqJviPM3PT761BAINfi0MelgMM4HAE1F+wwkE0qD1lJ/CKgUO1Dc6pNz1N5RvAA4gYq6xPcG+4yGb/t/zKlrpvL71dIUWZj6wdAVP5F4OJl3hgpYNzQ09HdTGLR4FtW5MY1AcsvNGv/tw8yM8qCcgKM6PADsN1m/8ppLz5ZhR0RD6IsMWzhIOCKnMxjZFvbnKlAxPPLoAZkrGCMdbyOlgQjWS0zF8K870BtVHotE2JotUKln3tpMbvW0fwFPTjWbTiD1CsfWDlXQ+m0Fak5oNYjY/80dN1qr43Ae7eiQ4T4dkyKVhc0a+OG/Ud2le19pu1Zy/3nIpJcU0yQkcx+Zk0Vt9TBZh3uJ/ -------------------------------------------------------------------------------- /src/util.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | use std::str::FromStr; 3 | use std::sync::Mutex; 4 | 5 | use indexmap::{IndexMap, IndexSet}; 6 | use once_cell::sync::Lazy; 7 | use std::fmt::Formatter; 8 | use std::iter::FromIterator; 9 | use itertools::Itertools; 10 | use serde::de::{Error, MapAccess}; 11 | use serde::{Deserialize, Deserializer, Serialize}; 12 | use serde::ser::SerializeMap; 13 | 14 | pub use symbol_table::GlobalSymbol as Symbol; 15 | 16 | pub(crate) trait JoinDisp { 17 | #[allow(dead_code)] 18 | fn disp_string(self) -> String; 19 | fn sep_string(self, sep: &str) -> String; 20 | } 21 | 22 | impl JoinDisp for I where I: Iterator, 23 | I::Item: fmt::Display { 24 | fn disp_string(self) -> String { 25 | self.sep_string(", ") 26 | } 27 | 28 | fn sep_string(self, 
sep: &str) -> String { 29 | self.map(|x| format!("{}", x)).join(sep) 30 | } 31 | } 32 | 33 | pub trait Singleton { 34 | fn singleton(t: T) -> Self; 35 | } 36 | 37 | impl Singleton for FI 38 | where FI: FromIterator { 39 | fn singleton(t: T) -> Self { 40 | FI::from_iter(std::iter::once(t)) 41 | } 42 | } 43 | 44 | 45 | 46 | /** A data structure to maintain a queue of unique elements. 47 | 48 | Notably, insert/pop operations have O(1) expected amortized runtime complexity. 49 | */ 50 | #[derive(Clone)] 51 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 52 | pub(crate) struct UniqueQueue 53 | where 54 | T: Eq + std::hash::Hash + Clone, 55 | { 56 | set: IndexSet, 57 | queue: std::collections::VecDeque, 58 | } 59 | 60 | impl Default for UniqueQueue 61 | where 62 | T: Eq + std::hash::Hash + Clone, 63 | { 64 | fn default() -> Self { 65 | UniqueQueue { 66 | set: IndexSet::default(), 67 | queue: std::collections::VecDeque::new(), 68 | } 69 | } 70 | } 71 | 72 | #[allow(dead_code)] 73 | impl UniqueQueue 74 | where 75 | T: Eq + std::hash::Hash + Clone, 76 | { 77 | pub fn insert(&mut self, t: T) { 78 | if self.set.insert(t.clone()) { 79 | self.queue.push_back(t); 80 | } 81 | } 82 | 83 | pub fn extend(&mut self, iter: I) 84 | where 85 | I: IntoIterator, 86 | { 87 | for t in iter.into_iter() { 88 | self.insert(t); 89 | } 90 | } 91 | 92 | pub fn pop(&mut self) -> Option { 93 | let res = self.queue.pop_front(); 94 | res.as_ref().map(|t| self.set.remove(t)); 95 | res 96 | } 97 | 98 | pub fn is_empty(&self) -> bool { 99 | let r = self.queue.is_empty(); 100 | debug_assert_eq!(r, self.set.is_empty()); 101 | r 102 | } 103 | } 104 | 105 | impl IntoIterator for UniqueQueue 106 | where 107 | T: Eq + std::hash::Hash + Clone, 108 | { 109 | type Item = T; 110 | 111 | type IntoIter = as IntoIterator>::IntoIter; 112 | 113 | fn into_iter(self) -> Self::IntoIter { 114 | self.queue.into_iter() 115 | } 116 | } 117 | 118 | impl FromIterator for UniqueQueue 119 | where 
120 | A: Eq + std::hash::Hash + Clone, 121 | { 122 | fn from_iter>(iter: T) -> Self { 123 | let mut queue = UniqueQueue::default(); 124 | for t in iter { 125 | queue.insert(t); 126 | } 127 | queue 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # egg logo Easter Egg: Colored E-Graphs 2 | 3 | [![Build Status](https://github.com/eytans/egg/workflows/Build%20and%20Test/badge.svg?branch=features/color_splits)](https://github.com/eytans/egg/actions) 4 | [![Crates.io](https://img.shields.io/crates/v/egg.svg)](https://crates.io/crates/egg) 5 | [![Released Docs.rs](https://docs.rs/egg/badge.svg)](https://docs.rs/egg/) 6 | [![Master docs](https://img.shields.io/badge/docs-master-blue)](https://egraphs-good.github.io/egg/egg/) 7 | 8 | Easter Egg is an extension of the egg library that implements colored e-graphs, enabling efficient representation of multiple congruence relations in a single e-graph structure. 9 | 10 | ## Features 11 | 12 | - Support for multiple congruence relations (colors) in a single e-graph 13 | - Memory-efficient representation of coarsened equality relations 14 | - Optimized algorithms for colored e-graph operations 15 | - Compatible with existing egg functionality 16 | - A parallel backoff scheduler that runs rule searches in parallel. 17 | - The e-matching machine supports early stopping, since it is now a stack-based machine rather than a recursive one (it could potentially be parallelized as well). 18 | 19 | ## Using Easter Egg 20 | 21 | Add `easter_egg` to your `Cargo.toml` like this: 22 | ```toml 23 | [dependencies] 24 | easter_egg = { git = "https://github.com/eytans/egg.git" } 25 | ``` 26 | 27 | ## Developing 28 | Easter Egg is written in Rust. 29 | Typically, you install Rust using rustup. 30 | Run cargo doc --open to build and open the documentation in a browser.
31 | Before committing/pushing, make sure to run cargo test, which runs all the tests. 32 | You should also run cargo fmt to format your code and cargo clippy to lint it. 33 | 34 | ## Tests 35 | You will need graphviz to run the tests. 36 | Running cargo test will run the tests. 37 | Some tests may time out; try cargo test --release if that happens. 38 | There are several interesting tests in the tests directory: 39 | 40 | egraph.rs implements basic functionality tests for colored e-graphs. 41 | 42 | ## Key Concepts 43 | 44 | Colored E-Graphs: An extension of e-graphs that efficiently represents multiple congruence relations. 45 | Colored E-Classes: E-classes associated with specific colors, representing coarsened equality relations. 46 | Colored Operations: Modified e-graph operations (merge, rebuild, e-matching) that work with colored e-graphs. 47 | 48 | ## Performance 49 | Colored e-graphs offer significant memory savings compared to maintaining separate e-graphs for different assumptions, with a slight trade-off in performance for some operations. 50 | 51 | ## Contributing 52 | Contributions to Easter Egg are welcome! Please feel free to submit issues, pull requests, or reach out with any questions or suggestions. 53 | 54 | 55 | ## Differences from egg 56 | 57 | Easter Egg extends the functionality of egg while maintaining a similar API. However, there are some key differences to be aware of: 58 | 59 | ### API and Stability 60 | 61 | - Easter Egg provides a similar API to egg, but it is currently less stable as it's a newer extension. 62 | - A new "colored" API has been added, allowing users to create colors and perform searches with colors. 63 | 64 | ### Multi-Pattern Handling 65 | 66 | Easter Egg handles multi-patterns slightly differently from egg: 67 | 68 | - The "=" operator is not allowed inside patterns as an operation. 69 | - Instead, we introduce "|=" and "!=" as replacements for conditions.
70 | - These new operators represent new machine operations used during search. 71 | 72 | ### Example Usage 73 | 74 | ```rust 75 | use egg::{EGraph, Rewrite, Runner, Symbol}; 76 | 77 | // Create a new colored e-graph 78 | let mut egraph = EGraph::new(()); 79 | 80 | // Create a new color 81 | let blue_color = egraph.add_color(); 82 | 83 | // Add two terms to the e-graph 84 | let x = egraph.add(Symbol::new("x")); 85 | let y = egraph.add(Symbol::new("y")); 86 | 87 | // Perform a colored merge 88 | egraph.colored_union(blue_color, x, y); 89 | let runner = .... 90 | ``` -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | 2 | /*! 3 | 4 | `easter-egg` is a colored e-graph implementation built on top of [egg](https://github.com/mwillsey/egg) 5 | (**e**-**g**raphs **g**ood), an e-graph library optimized for equality saturation. 6 | 7 | This is the API documentation. 8 | 9 | The [tutorial](tutorials/index.html) is a good starting point if you're new to 10 | e-graphs, equality saturation, or Rust. 11 | 12 | The [tests](https://github.com/eytans/easter-egg/tree/master/tests) 13 | on GitHub provide some more elaborate examples. 14 | 15 | There is also a [paper](https://arxiv.org/abs/2004.03082) 16 | describing `egg` and some of its technical novelties, and a paper on easter egg can be found 17 | [here](https://repositum.tuwien.at/bitstream/20.500.12708/200780/1/Singher-2024-Easter%20Egg%20Equality%20Reasoning%20Based%20on%20E-Graphs%20with%20Multipl...-vor.pdf).
18 | 19 | !*/ 20 | 21 | /* needs to be public for trait `GetOp` */ 22 | pub mod macros; 23 | pub mod explanation; 24 | 25 | #[macro_use] 26 | extern crate global_counter; 27 | 28 | pub mod tutorials; 29 | 30 | mod dot; 31 | mod eclass; 32 | mod egraph; 33 | mod extract; 34 | mod language; 35 | mod machine; 36 | mod pattern; 37 | mod rewrite; 38 | mod run; 39 | mod subst; 40 | mod unionfind; 41 | mod util; 42 | 43 | /// A key to identify [`EClass`](struct.EClass.html)es within an 44 | /// [`EGraph`](struct.EGraph.html). 45 | #[derive(Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Hash)] 46 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 47 | pub struct Id(pub u32); 48 | 49 | 50 | #[derive(Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Hash)] 51 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 52 | pub struct ColorId(usize); 53 | 54 | 55 | impl From for Id { 56 | fn from(n: usize) -> Id { 57 | Id(n as u32) 58 | } 59 | } 60 | 61 | impl From for usize { 62 | fn from(id: Id) -> usize { 63 | id.0 as usize 64 | } 65 | } 66 | 67 | impl Into for u32 { 68 | fn into(self) -> Id { 69 | Id(self) 70 | } 71 | } 72 | 73 | impl Into for i32 { 74 | fn into(self) -> Id { 75 | Id(self as u32) 76 | } 77 | } 78 | 79 | impl std::fmt::Debug for Id { 80 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 81 | write!(f, "{}", self.0) 82 | } 83 | } 84 | 85 | impl std::fmt::Display for Id { 86 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 87 | write!(f, "{}", self.0) 88 | } 89 | } 90 | 91 | impl From for ColorId { 92 | fn from(n: usize) -> ColorId { 93 | ColorId(n as usize) 94 | } 95 | } 96 | 97 | impl From for usize { 98 | fn from(id: ColorId) -> usize { 99 | id.0 as usize 100 | } 101 | } 102 | 103 | impl std::fmt::Debug for ColorId { 104 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 105 | write!(f, "{}", self.0) 106 | } 107 | } 108 | 109 | impl 
std::fmt::Display for ColorId { 110 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 111 | write!(f, "{}", self.0) 112 | } 113 | } 114 | 115 | pub(crate) use unionfind::SimpleUnionFind; 116 | 117 | pub use { 118 | dot::Dot, 119 | eclass::EClass, 120 | egraph::EGraph, 121 | extract::*, 122 | language::*, 123 | pattern::{ENodeOrVar, Pattern, PatternAst, SearchMatches}, 124 | multipattern::MultiPattern, 125 | rewrite::{Applier, Rewrite, Searcher}, 126 | run::*, 127 | subst::{Subst, Var}, 128 | util::*, 129 | eggstentions::*, 130 | machine::set_global_bind_limit, 131 | }; 132 | 133 | #[cfg(test)] 134 | fn init_logger() { 135 | invariants::set_max_level(log::LevelFilter::Trace); 136 | let _ = env_logger::builder().is_test(true).filter_level(log::LevelFilter::Debug).try_init(); 137 | } 138 | 139 | #[doc(hidden)] 140 | pub mod test; 141 | mod colors; 142 | mod eggstentions; 143 | pub mod tools; 144 | mod multipattern; 145 | 146 | -------------------------------------------------------------------------------- /src/eggstentions/costs.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::Ordering; 2 | 3 | use crate::{CostFunction, Id, SymbolLang}; 4 | use itertools::Itertools; 5 | 6 | /// A cost function that minimizes the depth, number of nodes and maximizes variables in the egraph. 7 | #[derive(Clone, Debug, PartialEq, Eq)] 8 | pub struct RepOrder { 9 | vars: Vec, 10 | depth: usize, 11 | size: usize, 12 | } 13 | 14 | impl RepOrder { 15 | /// Returns the current depth. 
16 | pub fn get_depth(&self) -> usize { 17 | self.depth 18 | } 19 | 20 | fn count_ph1(it: &Vec) -> usize { // counts entries (duplicates included) whose name ends in "1" 21 | it.iter().filter(|x| x.ends_with("1")).count() 22 | } 23 | 24 | fn compare_vars(&self, other: &Self) -> Option { // compares distinct-var counts; ties fall back to `count_ph1` 25 | match self.vars.iter().unique().count().partial_cmp(&other.vars.iter().unique().count()) { 26 | None => { Self::count_ph1(&self.vars).partial_cmp(&Self::count_ph1(&other.vars)) } 27 | Some(ord) => { match ord { 28 | Ordering::Less => { Some(Ordering::Less) } 29 | Ordering::Equal => { Self::count_ph1(&self.vars).partial_cmp(&Self::count_ph1(&other.vars)) } 30 | Ordering::Greater => { Some(Ordering::Greater) } 31 | }} 32 | } 33 | 34 | } 35 | } 36 | 37 | impl PartialOrd for RepOrder { 38 | fn partial_cmp(&self, other: &Self) -> Option { // note: `depth` is not part of this ordering 39 | match self.size.partial_cmp(&other.size) { // smaller size compares as Less (cheaper) 40 | None => { other.compare_vars(self) } 41 | Some(x) => { 42 | match x { 43 | Ordering::Less => {Some(Ordering::Less)} 44 | Ordering::Equal => { other.compare_vars(self) } // operands swapped: MORE distinct vars compares as Less 45 | Ordering::Greater => {Some(Ordering::Greater)} 46 | } 47 | } 48 | } 49 | } 50 | } 51 | 52 | impl Ord for RepOrder { // NOTE(review): derived `PartialEq`/`Eq` compare every field (incl. `depth`, `vars` order), so `cmp` can return Equal for values that are `!=` 53 | fn cmp(&self, other: &Self) -> Ordering { 54 | self.partial_cmp(other).unwrap_or(Ordering::Equal) 55 | } 56 | } 57 | 58 | /// A struct to create a cost function out of RepOrder. 59 | pub struct MinRep; 60 | 61 | impl CostFunction for MinRep { 62 | /// The `Cost` type. It only requires `PartialOrd` so you can use 63 | /// floating point types, but failed comparisons (`NaN`s) will 64 | /// result in a panic. 65 | /// Here the cost is a [`RepOrder`]: smaller term size wins, and ties prefer more distinct `ts_ph*` placeholder variables. 66 | type Cost = RepOrder; 67 | 68 | /// Calculates the cost of an enode whose children are `Cost`s. 69 | /// 70 | /// For this to work properly, your cost function should be 71 | /// _monotonic_, i.e. `cost` should return a `Cost` greater than 72 | /// any of the child costs of the given enode.
73 | fn cost(&mut self, enode: &SymbolLang, mut costs: C) -> Self::Cost where 74 | C: FnMut(Id) -> Self::Cost { 75 | let current_depth = enode.children.iter().map(|i| costs(*i).depth).max().unwrap_or(0); 76 | let current_size = enode.children.iter().map(|i| costs(*i).size).sum1().unwrap_or(0); 77 | let mut vars = enode.children.iter().flat_map(|i| costs(*i).vars).collect_vec(); 78 | if enode.op.as_str().starts_with("ts_ph") { 79 | vars.push(enode.op.to_string()); 80 | } 81 | RepOrder{depth: current_depth + 1, size: current_size + 1, vars} 82 | } 83 | } 84 | 85 | #[cfg(test)] 86 | mod tests { 87 | use std::iter::FromIterator; 88 | 89 | use crate::eggstentions::costs::RepOrder; 90 | 91 | #[test] 92 | fn compare_two_different_sizes() { 93 | assert!(RepOrder{vars: Vec::new(), depth: 0, size: 1} < RepOrder{vars: Vec::new(), depth: 0, size: 2}); 94 | assert!(RepOrder{vars: Vec::from_iter(vec![":".to_string(), "a".to_string(), "b".to_string()]), depth: 0, size: 1} < RepOrder{vars: Vec::new(), depth: 0, size: 2}); 95 | } 96 | 97 | #[test] 98 | fn compare_two_different_vars() { 99 | assert!(RepOrder{vars: Vec::from_iter(vec![":".to_string(), "a".to_string(), "b".to_string()]), depth: 0, size: 2} < RepOrder{vars: Vec::from_iter(vec![":".to_string(), "a".to_string()]), depth: 0, size: 2}); 100 | } 101 | } -------------------------------------------------------------------------------- /src/expression_ops.rs: -------------------------------------------------------------------------------- 1 | use std::collections::{HashMap, HashSet}; 2 | use std::fmt::Formatter; 3 | use std::hash::{Hash, Hasher}; 4 | 5 | use itertools::Itertools; 6 | use smallvec::alloc::fmt::Display; 7 | use crate::{Language, RecExpr, EGraph, Id}; 8 | 9 | #[derive(Clone, Debug)] 10 | pub struct RecExpSlice<'a, L: Language> { 11 | index: usize, 12 | exp: &'a RecExpr 13 | } 14 | 15 | impl<'a, L: Language> RecExpSlice<'a, L> { 16 | pub fn new(index: usize, exp: &'a RecExpr) -> RecExpSlice<'a, L> { 17 | 
RecExpSlice{index, exp} 18 | } 19 | 20 | pub fn add_to_graph(&self, graph: &mut EGraph) -> Id { 21 | graph.add_expr(&RecExpr::from(self.exp.as_ref()[..self.index+1].iter().cloned().collect_vec())) 22 | } 23 | 24 | pub fn to_spaceless_string(&self) -> String { 25 | self.to_sexp_string() 26 | .replace(" ", "_") 27 | .replace("(", "PO") 28 | .replace(")", "PC") 29 | .replace("->", "fn") 30 | } 31 | 32 | pub fn to_sexp_string(&self) -> String { 33 | if self.is_leaf() { 34 | format!("{}", self.root().display_op().to_string()) 35 | } else { 36 | format!("({} {})", self.root().display_op().to_string(), self.children().iter().map(|t| t.to_sexp_string()).intersperse(" ".to_string()).collect::()) 37 | } 38 | } 39 | } 40 | 41 | impl<'a, L: Language> PartialEq for RecExpSlice<'a, L> { 42 | fn eq(&self, other: &Self) -> bool { 43 | self.root() == other.root() && self.children() == other.children() 44 | } 45 | } 46 | 47 | impl<'a, L: Language> Hash for RecExpSlice<'a, L> { 48 | fn hash(&self, state: &mut H) { 49 | (self.root().hash(state), self.children().hash(state)).hash(state) 50 | } 51 | } 52 | 53 | impl<'a, L: Language> From<&'a RecExpr> for RecExpSlice<'a, L> { 54 | fn from(expr: &'a RecExpr) -> Self { 55 | RecExpSlice{index: expr.as_ref().len() - 1, exp: expr} 56 | } 57 | } 58 | 59 | impl<'a, L: Language + Clone> From<&'a RecExpSlice<'a, L>> for RecExpr { 60 | fn from(expr: &'a RecExpSlice<'a, L>) -> Self { 61 | // Need to remove unneeded nodes because recexpr comparison works straigt on vec 62 | let mut nodes: Vec> = vec![]; 63 | nodes.push(expr.clone()); 64 | let mut indices = HashSet::new(); 65 | while !nodes.is_empty() { 66 | let current = nodes.pop().unwrap(); 67 | indices.insert(current.index); 68 | for n in current.children() { 69 | nodes.push(n); 70 | } 71 | } 72 | let mut res: Vec = vec![]; 73 | let mut id_trans: HashMap = HashMap::new(); 74 | for i in indices.iter().sorted() { 75 | id_trans.insert(Id::from(*i), Id::from(res.len())); 76 | 
res.push(expr.exp.as_ref()[*i].clone().map_children(|id| *id_trans.get(&id).unwrap())); 77 | } 78 | RecExpr::from(res) 79 | } 80 | } 81 | 82 | impl<'a, L: Language> Into> for RecExpSlice<'a, L> { 83 | fn into(self) -> RecExpr { 84 | RecExpr::from(self.exp.as_ref()[..self.index + 1].iter().cloned().collect_vec()) 85 | } 86 | } 87 | 88 | pub trait IntoTree<'a, T: Language> { 89 | fn into_tree(&'a self) -> RecExpSlice<'a, T>; 90 | } 91 | 92 | impl<'a, T: Language> IntoTree<'a, T> for RecExpr { 93 | fn into_tree(&'a self) -> RecExpSlice<'a, T> { 94 | RecExpSlice::from(self) 95 | } 96 | } 97 | 98 | pub trait Tree<'a, T: 'a + Language> { 99 | fn root(&self) -> &'a T; 100 | 101 | fn children(&self) -> Vec>; 102 | 103 | fn is_leaf(&self) -> bool { 104 | self.children().is_empty() 105 | } 106 | } 107 | 108 | impl<'a ,L: Language> Tree<'a, L> for RecExpSlice<'a, L> { 109 | fn root(&self) -> &'a L { 110 | &self.exp.as_ref()[self.index] 111 | } 112 | 113 | fn children(&self) -> Vec> { 114 | self.exp.as_ref()[self.index].children().iter().map(|t| 115 | RecExpSlice::new(usize::from(*t), self.exp)).collect_vec() 116 | } 117 | } 118 | 119 | impl<'a, T: 'a + Language + Display> Display for RecExpSlice<'a, T> { 120 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 121 | write!(f, "{}", &self.to_sexp_string()) 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /doc/egg.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 25 | 27 | 30 | 34 | 38 | 39 | 50 | 51 | 73 | 75 | 76 | 78 | image/svg+xml 79 | 81 | 82 | 83 | 84 | 85 | 90 | 94 | 101 | 108 | 109 | 114 | 119 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changes 2 | 3 | 4 | 5 | ## [Unreleased] - ReleaseDate 6 | 7 | ### Added 8 | - The `BackoffScheduler` is now 
more flexible. 9 | - `EGraph::pre_union` allows inspection of unions, which can be useful for debugging. 10 | 11 | ### Changed 12 | - `EGraph::add_expr` now proceeds linearly through the given `RecExpr`, which 13 | should be faster and include _all_ e-nodes from the expression. 14 | - `Rewrite` now has public `searcher` and `applier` fields and no `long_name`. 15 | 16 | ## [0.6.0] - 2020-07-16 17 | 18 | ### Added 19 | - `Id` is now a struct, not a type alias. This should help prevent some bugs. 20 | - `Runner` hooks allow you to modify the `Runner` each iteration and stop early if you want. 21 | - Added a way to look up an e-node without adding it. 22 | - `define_language!` now supports variants with data _and_ children. 23 | - Added a tutorial in the documentation! 24 | 25 | ### Fixed 26 | - Fixed a bug when making `Pattern`s from `RecExpr`s. 27 | - Improved the `RecExpr` API. 28 | 29 | ## [0.5.0] - 2020-06-22 30 | 31 | ### Added 32 | - `egg` now provides `Symbol`s, a simple interned string that users can (and 33 | should) use in their `Language`s. 34 | - `egg` will now warn you when you try to use `Rewrite`s with the same name. 35 | - Rewrite creation will now fail if the searcher doesn't bind the right variables. 36 | - The `rewrite!` macro supports bidirectional rewrites now. 37 | - `define_language!` now supports variable numbers of children with `Box<[Id]>`. 38 | 39 | ### Fixed 40 | - The `rewrite!` macro builds conditional rewrites in the correct order now. 41 | 42 | ## [0.4.1] - 2020-05-26 43 | 44 | ### Added 45 | - Added various Debug and Display impls. 46 | 47 | ### Fixed 48 | - Fixed the way applications were counted by the Runner. 49 | 50 | ## [0.4.0] - 2020-05-21 51 | 52 | ### Added 53 | - The rebuilding algorithm is now _precise_, meaning it avoids a lot of 54 | unnecessary work. This leads to across-the-board speedups of up to 2x. 55 | - `Language` elements are now much more compact, leading to speedups across the board.
56 | 57 | ### Changed 58 | - Replaced `Metadata` with `Analysis`, which can hold egraph-global data as well 59 | as per-eclass data. 60 | - **Fix:** 61 | An eclass's metadata will now get updated by 62 | congruence. 63 | ([commit](https://github.com/mwillsey/egg/commit/0de75c9c9b0a80adb67fb78cc98cce3da383621a)) 64 | - The `BackoffScheduler` will now fast-forward if all rules are banned. 65 | ([commit](https://github.com/mwillsey/egg/commit/dd172ef77279e28448d0bf8147e0171a8175228d)) 66 | - Improve benchmark reporting 67 | ([commit](https://github.com/mwillsey/egg/commit/ca2ea5e239feda7eb6971942e119075f55f869ab)) 68 | - The egraph now skips irrelevant eclasses while searching for a ~40% search speed up. 69 | ([PR](https://github.com/mwillsey/egg/pull/21)) 70 | 71 | ## [0.3.0] - 2020-02-27 72 | 73 | ### Added 74 | - `Runner` can now be configured with user-defined `RewriteScheduler`s 75 | and `IterationData`. 76 | 77 | ### Changed 78 | - Reworked the `Runner` API. It's now a generic struct instead of a 79 | trait. 80 | - Patterns are now compiled into a small virtual machine bytecode inspired 81 | by [this paper](https://link.springer.com/chapter/10.1007/978-3-540-73595-3_13). 82 | This gets about a 40% speed up. 83 | 84 | ## [0.2.0] - 2020-02-19 85 | 86 | ### Added 87 | 88 | - A dumb little benchmarking system called `egg_bench` that can help 89 | benchmark tests. 90 | - String interning for `Var`s (née `QuestionMarkName`s). 91 | This speeds up things by ~35%. 92 | - Add a configurable time limit to `SimpleRunner` 93 | 94 | ### Changed 95 | 96 | - Renamed `WildMap` to `Subst`, `QuestionMarkName` to `Var`. 97 | 98 | ### Removed 99 | 100 | - Multi-matching patterns (ex: `?a...`). 101 | They were a hack and undocumented. 102 | If we can come up with better way to do it, then we can put them back. 103 | 104 | ## [0.1.2] - 2020-02-14 105 | 106 | This release completes the documentation 107 | (at least every public item is documented). 
108 | 109 | ### Changed 110 | - Replaced `Pattern::{from_expr, to_expr}` with `From` and `TryFrom` 111 | implementations. 112 | 113 | ## [0.1.1] - 2020-02-13 114 | 115 | ### Added 116 | - A lot of documentation 117 | 118 | ### Changed 119 | - The graphviz visualization now looks a lot better; enode argument 120 | come out from the "correct" position based on which argument they 121 | are. 122 | 123 | ## [0.1.0] - 2020-02-11 124 | 125 | This is egg's first real release! 126 | 127 | Hard to make a changelog on the first release, since basically 128 | everything has changed! 129 | But hopefully things will be a little more stable from here on out 130 | since the API is a lot nicer. 131 | 132 | 133 | [Unreleased]: https://github.com/mwillsey/egg/compare/v0.6.0...HEAD 134 | [0.6.0]: https://github.com/mwillsey/egg/compare/v0.5.0...v0.6.0 135 | [0.5.0]: https://github.com/mwillsey/egg/compare/v0.4.1...v0.5.0 136 | [0.4.1]: https://github.com/mwillsey/egg/compare/v0.4.0...v0.4.1 137 | [0.4.0]: https://github.com/mwillsey/egg/compare/v0.3.0...v0.4.0 138 | [0.3.0]: https://github.com/mwillsey/egg/compare/v0.2.0...v0.3.0 139 | [0.2.0]: https://github.com/mwillsey/egg/compare/v0.1.2...v0.2.0 140 | [0.1.2]: https://github.com/mwillsey/egg/compare/v0.1.1...v0.1.2 141 | [0.1.1]: https://github.com/mwillsey/egg/compare/v0.1.0...v0.1.1 142 | [0.1.0]: https://github.com/mwillsey/egg/tree/v0.1.0 143 | -------------------------------------------------------------------------------- /tests/prop.rs: -------------------------------------------------------------------------------- 1 | use easter_egg::*; 2 | 3 | define_language! 
{ 4 | enum Prop { 5 | Bool(bool), 6 | "&" = And([Id; 2]), 7 | "~" = Not(Id), 8 | "|" = Or([Id; 2]), 9 | "->" = Implies([Id; 2]), 10 | Symbol(Symbol), 11 | } 12 | } 13 | 14 | type EGraph = easter_egg::EGraph; 15 | type Rewrite = easter_egg::Rewrite; 16 | 17 | #[derive(Default, Clone)] 18 | struct ConstantFold; 19 | impl Analysis for ConstantFold { 20 | type Data = Option; 21 | fn merge(&self, to: &mut Self::Data, from: Self::Data) -> bool { 22 | merge_if_different(to, to.or(from)) 23 | } 24 | fn make(egraph: &EGraph, enode: &Prop) -> Self::Data { 25 | let x = |i: &Id| egraph[*i].data; 26 | let result = match enode { 27 | Prop::Bool(c) => Some(*c), 28 | Prop::Symbol(_) => None, 29 | Prop::And([a, b]) => Some(x(a)? && x(b)?), 30 | Prop::Not(a) => Some(!x(a)?), 31 | Prop::Or([a, b]) => Some(x(a)? || x(b)?), 32 | Prop::Implies([a, b]) => Some(x(a)? || !x(b)?), 33 | }; 34 | println!("Make: {:?} -> {:?}", enode, result); 35 | result 36 | } 37 | fn modify(egraph: &mut EGraph, id: Id) { 38 | println!("Modifying {}", id); 39 | if let Some(c) = egraph[id].data { 40 | let const_id = egraph.add(Prop::Bool(c)); 41 | egraph.union(id, const_id); 42 | } 43 | } 44 | } 45 | 46 | macro_rules! rule { 47 | ($name:ident, $left:literal, $right:literal) => { 48 | #[allow(dead_code)] 49 | fn $name() -> Rewrite { 50 | rewrite!(stringify!($name); $left => $right) 51 | } 52 | }; 53 | ($name:ident, $name2:ident, $left:literal, $right:literal) => { 54 | rule!($name, $left, $right); 55 | rule!($name2, $right, $left); 56 | }; 57 | } 58 | 59 | rule! {def_imply, def_imply_flip, "(-> ?a ?b)", "(| (~ ?a) ?b)" } 60 | rule! {double_neg, double_neg_flip, "(~ (~ ?a))", "?a" } 61 | rule! {assoc_or, "(| ?a (| ?b ?c))", "(| (| ?a ?b) ?c)" } 62 | rule! {dist_and_or, "(& ?a (| ?b ?c))", "(| (& ?a ?b) (& ?a ?c))"} 63 | rule! {dist_or_and, "(| ?a (& ?b ?c))", "(& (| ?a ?b) (| ?a ?c))"} 64 | rule! {comm_or, "(| ?a ?b)", "(| ?b ?a)" } 65 | rule! {comm_and, "(& ?a ?b)", "(& ?b ?a)" } 66 | rule! 
{lem, "(| ?a (~ ?a))", "true" } 67 | rule! {or_true, "(| ?a true)", "true" } 68 | rule! {and_true, "(& ?a true)", "?a" } 69 | rule! {contrapositive, "(-> ?a ?b)", "(-> (~ ?b) (~ ?a))" } 70 | rule! {lem_imply, "(& (-> ?a ?b) (-> (~ ?a) ?c))", "(| ?b ?c)"} 71 | 72 | fn prove_something(name: &str, start: &str, rewrites: &[Rewrite], goals: &[&str]) { 73 | let _ = env_logger::builder().is_test(true).try_init(); 74 | println!("Proving {}", name); 75 | 76 | let start_expr: RecExpr<_> = start.parse().unwrap(); 77 | let goal_exprs: Vec> = goals.iter().map(|g| g.parse().unwrap()).collect(); 78 | 79 | let mut egraph = Runner::default() 80 | .with_iter_limit(20) 81 | .with_node_limit(5_000) 82 | .with_expr(&start_expr) 83 | .run(rewrites) 84 | .egraph; 85 | 86 | egraph.rebuild(); 87 | for (i, (goal_expr, goal_str)) in goal_exprs.iter().zip(goals).enumerate() { 88 | println!("Trying to prove goal {}: {}", i, goal_str); 89 | let equivs = egraph.equivs(&start_expr, &goal_expr); 90 | if equivs.is_empty() { 91 | panic!("Couldn't prove goal {}: {}", i, goal_str); 92 | } 93 | } 94 | } 95 | 96 | #[test] 97 | fn prove_contrapositive() { 98 | let _ = env_logger::builder().is_test(true).try_init(); 99 | let rules = &[def_imply(), def_imply_flip(), double_neg_flip(), comm_or()]; 100 | prove_something( 101 | "contrapositive", 102 | "(-> x y)", 103 | rules, 104 | &[ 105 | "(-> x y)", 106 | "(| (~ x) y)", 107 | "(| (~ x) (~ (~ y)))", 108 | "(| (~ (~ y)) (~ x))", 109 | "(-> (~ y) (~ x))", 110 | ], 111 | ); 112 | } 113 | 114 | #[test] 115 | fn prove_chain() { 116 | let _ = env_logger::builder().is_test(true).try_init(); 117 | let rules = &[ 118 | // rules needed for contrapositive 119 | def_imply(), 120 | def_imply_flip(), 121 | double_neg_flip(), 122 | comm_or(), 123 | // and some others 124 | comm_and(), 125 | lem_imply(), 126 | ]; 127 | prove_something( 128 | "chain", 129 | "(& (-> x y) (-> y z))", 130 | rules, 131 | &[ 132 | "(& (-> x y) (-> y z))", 133 | "(& (-> (~ y) (~ x)) (-> y z))", 
134 | "(& (-> y z) (-> (~ y) (~ x)))", 135 | "(| z (~ x))", 136 | "(| (~ x) z)", 137 | "(-> x z)", 138 | ], 139 | ); 140 | } 141 | 142 | #[test] 143 | fn const_fold() { 144 | let start = "(| (& false true) (& true false))"; 145 | let start_expr = start.parse().unwrap(); 146 | let end = "false"; 147 | let end_expr = end.parse().unwrap(); 148 | let mut eg = EGraph::default(); 149 | eg.add_expr(&start_expr); 150 | eg.rebuild(); 151 | assert!(!eg.equivs(&start_expr, &end_expr).is_empty()); 152 | } 153 | -------------------------------------------------------------------------------- /src/subst.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | use std::str::FromStr; 3 | use thiserror::Error; 4 | 5 | use crate::{Analysis, EGraph, Id, Symbol}; 6 | use crate::ColorId; 7 | use std::fmt::Formatter; 8 | 9 | /// A variable for use in [`Pattern`]s or [`Subst`]s. 10 | /// 11 | /// This implements [`FromStr`], and will only parse if it has a 12 | /// leading `?`. 
13 | /// 14 | /// [`Pattern`]: struct.Pattern.html 15 | /// [`Subst`]: struct.Subst.html 16 | /// [`FromStr`]: https://doc.rust-lang.org/std/str/trait.FromStr.html 17 | #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] 18 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 19 | pub struct Var(Symbol); 20 | 21 | #[derive(Debug, Error)] 22 | pub enum VarParseError { 23 | #[error("pattern variable {0:?} should have a leading question mark")] 24 | MissingQuestionMark(String), 25 | } 26 | 27 | impl FromStr for Var { 28 | type Err = VarParseError; 29 | 30 | fn from_str(s: &str) -> Result { 31 | use VarParseError::*; 32 | 33 | if s.starts_with('?') && s.len() > 1 { // a bare "?" has no name after the marker, so it is rejected too 34 | Ok(Var(s.into())) 35 | } else { 36 | Err(MissingQuestionMark(s.to_owned())) 37 | } 38 | } 39 | } 40 | 41 | impl fmt::Display for Var { 42 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 43 | write!(f, "{}", self.0) 44 | } 45 | } 46 | 47 | impl fmt::Debug for Var { // prints exactly the same text as `Display` 48 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 49 | write!(f, "{}", self.0) 50 | } 51 | } 52 | 53 | /// A substitution mapping [`Var`]s to eclass [`Id`]s.
54 | /// 55 | /// [`Var`]: struct.Var.html 56 | /// [`Id`]: struct.Id.html 57 | #[derive(Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] 58 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 59 | pub struct Subst { 60 | pub(crate) vec: smallvec::SmallVec<[(Var, Id); 8]>, 61 | pub(crate) color: Option, 62 | } 63 | 64 | impl Subst { 65 | pub(crate) fn fix>(&mut self, egraph: &EGraph) { 66 | let color = self.color; 67 | for (_var, id) in &mut self.vec { 68 | *id = egraph.opt_colored_find(color, *id); 69 | } 70 | } 71 | } 72 | 73 | impl Subst { 74 | /// Create a `Subst` with the given initial capacity 75 | pub fn with_capacity(capacity: usize) -> Self { 76 | Self { 77 | vec: smallvec::SmallVec::with_capacity(capacity), 78 | color: None, 79 | } 80 | } 81 | 82 | pub fn colored_with_capacity(capacity: usize, color: Option) -> Self { 83 | Self { 84 | vec: smallvec::SmallVec::with_capacity(capacity), 85 | color, 86 | } 87 | } 88 | 89 | /// Insert something, returning the old `Id` if present. 90 | pub fn insert(&mut self, var: Var, id: Id) -> Option { 91 | for pair in &mut self.vec { 92 | if pair.0 == var { 93 | return Some(std::mem::replace(&mut pair.1, id)); 94 | } 95 | } 96 | self.vec.push((var, id)); 97 | None 98 | } 99 | 100 | /// Retrieve a `Var`, returning `None` if not present. 
101 | #[inline(never)] 102 | pub fn get(&self, var: Var) -> Option<&Id> { 103 | self.vec 104 | .iter() 105 | .find_map(|(v, id)| if *v == var { Some(id) } else { None }) 106 | } 107 | 108 | pub fn color(&self) -> Option { 109 | self.color 110 | } 111 | 112 | pub fn merge(&self, sub2: Subst) -> Subst { 113 | assert!(self.color.is_none() || sub2.color.is_none() || self.color == sub2.color); 114 | let mut new = self.clone(); 115 | if new.color.is_none() && sub2.color.is_some() { 116 | new.color = sub2.color.clone(); 117 | } 118 | for (var, id) in sub2.vec { 119 | if let Some(vid) = self.get(var) { 120 | assert!(vid == &id); 121 | } else { 122 | new.insert(var, id); 123 | } 124 | } 125 | new 126 | } 127 | } 128 | 129 | impl std::ops::Index for Subst { 130 | type Output = Id; 131 | 132 | fn index(&self, var: Var) -> &Self::Output { 133 | match self.get(var) { 134 | Some(id) => id, 135 | None => panic!("Var '{}={}' not found in {:?}", var.0, var, self), 136 | } 137 | } 138 | } 139 | 140 | impl fmt::Display for Subst { 141 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 142 | write!(f, "{:#?}", self) 143 | } 144 | } 145 | 146 | impl fmt::Debug for Subst { 147 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 148 | let len = self.vec.len(); 149 | write!(f, "{{")?; 150 | for i in 0..len { 151 | let (var, id) = &self.vec[i]; 152 | write!(f, "{}: {}", var, id)?; 153 | if i < len - 1 { 154 | write!(f, ", ")?; 155 | } 156 | } 157 | write!(f, " color: {}", self.color.map_or("None".to_string(), |x| x.to_string()))?; 158 | write!(f, "}}") 159 | } 160 | } 161 | 162 | #[cfg(test)] 163 | mod tests { 164 | use super::*; 165 | 166 | #[test] 167 | fn var_parse() { 168 | assert_eq!(Var::from_str("?a").unwrap().to_string(), "?a"); 169 | assert_eq!(Var::from_str("?abc 123").unwrap().to_string(), "?abc 123"); 170 | assert!(Var::from_str("a").is_err()); 171 | assert!(Var::from_str("a?").is_err()); 172 | assert!(Var::from_str("?").is_err()); 173 | } 174 | } 175 | 
-------------------------------------------------------------------------------- /src/eggstentions/searchers.rs: -------------------------------------------------------------------------------- 1 | use crate::{EGraph, Id, Pattern, Searcher, SearchMatches, Var, Language, Analysis, ColorId}; 2 | 3 | use smallvec::alloc::fmt::Formatter; 4 | use std::rc::Rc; 5 | 6 | /// Trait for converting a type to a dynamic type behind a Rc pointer. 7 | pub trait ToDyn> { 8 | /// Convert to a dynamic type behind a Rc pointer. 9 | fn into_rc_dyn(self) -> Rc>; 10 | } 11 | 12 | impl + 'static> ToDyn for Pattern { 13 | fn into_rc_dyn(self) -> Rc> { 14 | let dyn_s: Rc> = Rc::new(self); 15 | dyn_s 16 | } 17 | } 18 | 19 | /// A searcher that wraps another searcher and returns the same result. 20 | pub struct PointerSearcher> { 21 | searcher: Rc>, 22 | } 23 | 24 | impl> PointerSearcher { 25 | /// Create a new PointerSearcher. 26 | pub fn new(searcher: Rc>) -> Self { PointerSearcher { searcher } } 27 | } 28 | 29 | impl> std::fmt::Display for PointerSearcher { 30 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 31 | write!(f, "{}", self.searcher) 32 | } 33 | } 34 | 35 | impl> Searcher for PointerSearcher { 36 | fn search_eclass_with_limit(&self, egraph: &EGraph, eclass: Id, limit: usize) -> Option { 37 | self.searcher.search_eclass_with_limit(egraph, eclass, limit) 38 | } 39 | 40 | fn search(&self, egraph: &EGraph) -> Option { 41 | self.searcher.search(egraph) 42 | } 43 | 44 | fn colored_search_eclass_with_limit(&self, egraph: &EGraph, eclass: Id, color: ColorId, limit: usize) -> Option { 45 | self.searcher.colored_search_eclass_with_limit(egraph, eclass, color, limit) 46 | } 47 | 48 | fn vars(&self) -> Vec { 49 | self.searcher.vars() 50 | } 51 | } 52 | 53 | 54 | #[cfg(test)] 55 | mod tests { 56 | use std::str::FromStr; 57 | 58 | use crate::{EGraph, RecExpr, Searcher, SymbolLang, MultiPattern, init_logger}; 59 | // use crate::system_case_splits; 60 | 61 | #[test] 62 | fn 
eq_two_trees_one_common() { 63 | init_logger(); 64 | 65 | let searcher: MultiPattern = "?x = (a ?b ?c), ?x = (a ?c ?d)".parse().unwrap(); 66 | let mut egraph: EGraph = EGraph::default(); 67 | let x = egraph.add_expr(&RecExpr::from_str("x").unwrap()); 68 | let z = egraph.add_expr(&RecExpr::from_str("z").unwrap()); 69 | let a = egraph.add_expr(&RecExpr::from_str("(a x y)").unwrap()); 70 | egraph.add_expr(&RecExpr::from_str("(a z x)").unwrap()); 71 | egraph.rebuild(); 72 | assert!(searcher.search(&egraph).is_none()); 73 | let a2 = egraph.add(SymbolLang::new("a", vec![z, x])); 74 | egraph.union(a, a2); 75 | egraph.rebuild(); 76 | assert_eq!(searcher.search(&egraph).unwrap().len(), 1); 77 | } 78 | 79 | #[test] 80 | fn diff_two_trees_one_common() { 81 | init_logger(); 82 | 83 | let searcher = MultiPattern::from_str("?v1 = (a ?b ?c), ?v2 = (a ?c ?d)").unwrap(); 84 | let mut egraph: EGraph = EGraph::default(); 85 | let _x = egraph.add_expr(&RecExpr::from_str("x").unwrap()); 86 | let _z = egraph.add_expr(&RecExpr::from_str("z").unwrap()); 87 | let _a = egraph.add_expr(&RecExpr::from_str("(a x y)").unwrap()); 88 | egraph.add_expr(&RecExpr::from_str("(a z x)").unwrap()); 89 | egraph.rebuild(); 90 | assert_eq!(searcher.search(&egraph).unwrap().len(), 1); 91 | } 92 | 93 | #[test] 94 | fn find_ind_hyp() { 95 | init_logger(); 96 | 97 | let mut egraph: EGraph = EGraph::default(); 98 | let full_pl = egraph.add_expr(&"(pl (S p0) Z)".parse().unwrap()); 99 | let after_pl = egraph.add_expr(&"(S (pl p0 Z))".parse().unwrap()); 100 | let sp0 = egraph.add_expr(&"(S p0)".parse().unwrap()); 101 | let ind_var = egraph.add_expr(&"ind_var".parse().unwrap()); 102 | egraph.union(ind_var, sp0); 103 | let _ltwf = egraph.add_expr(&"(ltwf p0 (S p0))".parse().unwrap()); 104 | egraph.union(full_pl, after_pl); 105 | egraph.rebuild(); 106 | let searcher = MultiPattern::from_str("?v1 = (ltwf ?x ind_var), ?v2 = (pl ?x Z)").unwrap(); 107 | assert!(searcher.search(&egraph).is_some()); 108 | } 109 | 110 | // 
#[cfg(feature = "split_colored")] 111 | // #[test] 112 | // fn skip_vacuity_cases() { 113 | // let mut graph: EGraph = EGraph::default(); 114 | // graph.add_expr(&RecExpr::from_str("(ite x 1 2)").unwrap()); 115 | // graph.rebuild(); 116 | // let mut case_splitter = system_case_splits(); 117 | // let pattern: Pattern = Pattern::from_str("(ite ?z ?x ?y)").unwrap(); 118 | // println!("{:?}", pattern.search(&graph)); 119 | // let splitters = case_splitter.find_splitters(&mut graph); 120 | // println!("{:?}", splitters); 121 | // assert_eq!(splitters.len(), 1); 122 | // let colors = splitters[0].create_colors(&mut graph); 123 | // graph.rebuild(); 124 | // println!("{:?}", pattern.search(&graph)); 125 | // let new_splitters = case_splitter.find_splitters(&mut graph); 126 | // println!("{:?}", new_splitters); 127 | // assert_eq!(new_splitters.len(), 1); 128 | // 129 | // } 130 | } -------------------------------------------------------------------------------- /src/tools.rs: -------------------------------------------------------------------------------- 1 | pub mod tools { 2 | use std::collections::hash_map::RandomState; 3 | use std::hash::Hash; 4 | 5 | use itertools::MultiProduct; 6 | use itertools::Itertools; 7 | use indexmap::IndexMap; 8 | use log::debug; 9 | use crate::{ENodeOrVar, Id, Language, MultiPattern, Pattern, RecExpr}; 10 | use crate::ENodeOrVar::ENode; 11 | 12 | // fn combinations<'a, T: 'a, I: Iterator + Clone>(mut sets: impl Iterator) -> impl Iterator> { 13 | // let first = sets.next(); 14 | // let second = sets.next(); 15 | // if first.is_none() || second.is_none() { 16 | // return iter::empty(); 17 | // } 18 | // 19 | // let initial = Itertools::cartesian_product(first.unwrap(), second.unwrap()) 20 | // .map(|p| vec![p.0, p.1]); 21 | // let res = sets.fold(initial, |res, i| Itertools::cartesian_product(res, i)); 22 | // res.unwrap_or(iter::empty()) 23 | // } 24 | 25 | pub fn product<'a, T: 'a + Clone>(vecs: &[&'a Vec]) -> Vec> { 26 | if vecs.is_empty() 
{ 27 | return vec![vec![]]; 28 | } 29 | 30 | if vecs.len() == 1 { 31 | return vecs[0].iter().map(|t| vec![t]).collect(); 32 | } 33 | 34 | let rec_res = product(&vecs[1..vecs.len()]); 35 | let initial_set = &vecs[0]; 36 | let mut res = Vec::new(); 37 | for s in initial_set.iter() { 38 | for r in rec_res.iter() { 39 | let mut new_r = r.clone(); 40 | new_r.push(s); 41 | res.push(new_r) 42 | } 43 | } 44 | 45 | return res; 46 | } 47 | 48 | #[allow(dead_code)] 49 | pub(crate) fn combinations>(iters: impl Iterator) -> MultiProduct { 50 | iters.multi_cartesian_product() 51 | } 52 | 53 | pub fn choose(vec: &[K], size: usize) -> Vec> { 54 | if size == 1 { 55 | let mut res = Vec::default(); 56 | vec.iter().for_each(|k| res.push(vec![k])); 57 | return res; 58 | } 59 | if size == 0 || size > vec.len() { 60 | return Vec::default(); 61 | } 62 | 63 | let mut res = Vec::default(); 64 | for (i, k) in vec.iter().enumerate() { 65 | let mut rec_res = choose(&vec[i + 1..], size - 1); 66 | for v in rec_res.iter_mut() { 67 | v.push(k); 68 | } 69 | res.extend(rec_res); 70 | } 71 | res 72 | } 73 | 74 | pub trait Grouped { 75 | fn grouped K>(&mut self, key: F) -> IndexMap>; 76 | } 77 | 78 | impl> Grouped for I { 79 | fn grouped K>(&mut self, key: F) -> IndexMap, RandomState> { 80 | let mut res = IndexMap::new(); 81 | self.for_each(|t| res.entry(key(&t)).or_insert(Vec::new()).push(t)); 82 | res 83 | } 84 | } 85 | 86 | pub fn vacuity_detector_from_ops(ops: Vec) -> Vec> { 87 | let v: crate::Var = "?multipattern_var".parse().unwrap(); 88 | let patterns = ops.into_iter().map(|o| { 89 | let p: Pattern = { 90 | let mut rec_expr: RecExpr> = Default::default(); 91 | for i in 0..o.children().len() { 92 | let var: crate::Var = format!("?var_{}_{}", o.op_id(), i).parse().unwrap(); 93 | rec_expr.add(ENodeOrVar::Var(var)); 94 | } 95 | let mut new_node = o.clone(); 96 | new_node.children_mut().iter_mut().enumerate().for_each(|(i, c)| *c = Id::from(i)); 97 | rec_expr.add(ENode(new_node, None)); 98 | 
Pattern::from(rec_expr) 99 | }; 100 | (v, p.ast) 101 | }).collect_vec(); 102 | let res = (0..patterns.len()).map(|i| { 103 | let mut new_patterns = patterns.iter().map(|(_v, x)| x).cloned().collect_vec(); 104 | let main = new_patterns.remove(i); 105 | MultiPattern::new_with_specials(vec![(v, main)], vec![(v, new_patterns)], vec![]) 106 | }).collect_vec(); 107 | debug!("Vacuity detector: {}", res.iter().join(", ")); 108 | res 109 | } 110 | } 111 | 112 | #[cfg(test)] 113 | mod tests { 114 | use std::iter::FromIterator; 115 | use indexmap::IndexSet; 116 | 117 | use itertools::Itertools; 118 | 119 | use crate::tools::tools::choose; 120 | use crate::tools::tools::combinations; 121 | 122 | #[test] 123 | fn check_comb_amount() { 124 | let v1 = vec![1, 2, 3]; 125 | let v2 = vec![4, 5, 6]; 126 | let combs = combinations(vec![v1.iter(), v2.iter()].into_iter()).collect_vec(); 127 | assert_eq!(combs.len(), 9); 128 | for v in &combs { 129 | assert_eq!(v.len(), 2); 130 | } 131 | // No doubles 132 | let as_set: IndexSet<&Vec<&i32>> = IndexSet::from_iter(combs.iter()); 133 | assert_eq!(as_set.len(), 9); 134 | } 135 | 136 | #[test] 137 | fn check_choose_amount() { 138 | let v3 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; 139 | for i in 1..9 { 140 | let chosen = choose(&v3, i); 141 | for v in &chosen { 142 | assert_eq!(v.len(), i); 143 | } 144 | let as_set: IndexSet<&Vec<&i32>> = IndexSet::from_iter(chosen.iter()); 145 | assert_eq!(chosen.len(), as_set.len()); 146 | } 147 | assert_eq!(choose(&v3, 2).len(), 36); 148 | } 149 | } -------------------------------------------------------------------------------- /src/eggstentions/tree.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::Ordering; 2 | use std::cmp::Ordering::{Equal, Greater, Less}; 3 | use std::fmt::{Display, Formatter}; 4 | use std::rc::Rc; 5 | 6 | use crate::{EGraph, Id, SymbolLang}; 7 | use itertools::{Itertools, max}; 8 | use symbolic_expressions::Sexp; 9 | 10 | macro_rules! 
bail { 11 | ($s:literal $(,)?) => { 12 | return Err($s.into()) 13 | }; 14 | ($s:literal, $($args:expr),+) => { 15 | return Err(format!($s, $($args),+).into()) 16 | }; 17 | } 18 | 19 | type ROption = Rc>; 20 | 21 | /// A term tree with a root and subtrees 22 | #[derive(Clone, Hash, PartialEq, Eq)] 23 | pub struct Tree { 24 | /// The root of the term 25 | pub root: String, 26 | /// The subtrees of the term 27 | pub subtrees: Vec, 28 | /// The type of the term 29 | pub typ: ROption, 30 | } 31 | 32 | impl Tree { 33 | /// Create a new single node tree 34 | pub fn leaf(op: String) -> Tree { 35 | Tree { root: op, subtrees: Vec::new(), typ: Rc::new(None) } 36 | } 37 | 38 | /// Create a new single node tree with a type 39 | pub fn tleaf(op: String, typ: Option) -> Tree { 40 | Tree { root: op, subtrees: Vec::new(), typ: Rc::new(typ) } 41 | } 42 | 43 | /// Create a new tree with a root and subtrees 44 | pub fn branch(op: String, subtrees: Vec) -> Tree { 45 | Tree { root: op, subtrees, typ: Rc::new(None) } 46 | } 47 | 48 | #[allow(missing_docs)] 49 | pub fn depth(&self) -> usize { 50 | return max(self.subtrees.iter().map(|x| x.depth())).unwrap_or(0) + 1 51 | } 52 | 53 | #[allow(missing_docs)] 54 | pub fn size(&self) -> usize { 55 | return self.subtrees.iter().map(|x| x.size()).sum::() + 1 56 | } 57 | 58 | // pub fn to_rec_expr(&self, op_res: Option>) -> (Id, RecExpr) { 59 | // let mut res = if op_res.is_none() { RecExpr::default() } else { op_res.unwrap() }; 60 | // return if self.is_leaf() { 61 | // (res.add(SymbolLang::leaf(&self.root)), res) 62 | // } else { 63 | // let mut ids = Vec::default(); 64 | // for s in &self.subtrees { 65 | // let (id, r) = s.to_rec_expr(Some(res)); 66 | // res = r; 67 | // ids.insert(0, id); 68 | // } 69 | // (res.add(SymbolLang::new(&self.root, ids)), res) 70 | // }; 71 | // } 72 | 73 | /// Add this term to the egraph 74 | pub fn add_to_graph(&self, graph: &mut EGraph) -> Id { 75 | let mut children = Vec::new(); 76 | for t in &self.subtrees { 77 
| children.push(t.add_to_graph(graph)); 78 | }; 79 | graph.add(SymbolLang::new(self.root.clone(), children)) 80 | } 81 | 82 | #[allow(missing_docs)] 83 | pub fn is_leaf(&self) -> bool { 84 | self.subtrees.is_empty() 85 | } 86 | 87 | #[allow(missing_docs)] 88 | pub fn to_sexp_string(&self) -> String { 89 | if self.is_leaf() { 90 | self.root.clone() 91 | } else { 92 | format!("({} {})", self.root.clone(), itertools::Itertools::intersperse(self.subtrees.iter().map(|t| t.to_string()), " ".parse().unwrap()).collect::()) 93 | } 94 | } 95 | 96 | /// Lexicographic ordering for trees, by root symbol and then by subtree ordering. 97 | pub fn tree_lexicographic_ordering(t1: &Tree, t2: &Tree) -> Ordering { 98 | match t1.root.cmp(&t2.root ) { 99 | Less => Less, 100 | Equal => { 101 | t1.subtrees.iter().zip_longest(&t2.subtrees).find_map(|x| { 102 | if !x.has_left() { 103 | Some(Less) 104 | } else if !x.has_right() { 105 | Some(Greater) 106 | } else { 107 | let l = *x.as_ref().left().unwrap(); 108 | let r = *x.as_ref().right().unwrap(); 109 | let rec_res = Self::tree_lexicographic_ordering(l, r); 110 | rec_res.is_eq().then(|| rec_res) 111 | } 112 | }).unwrap_or(Equal) 113 | }, 114 | Greater => Greater 115 | } 116 | } 117 | 118 | /// Ordering for trees, by depth and then by size and then by lexicographic ordering. 
119 | pub fn tree_size_ordering(t1: &Tree, t2: &Tree) -> Ordering { 120 | match t1.depth().cmp(&t2.depth()) { 121 | Less => Less, 122 | Equal => match t1.size().cmp(&t2.size()) { 123 | Less => Less, 124 | // Oh the horror of string semantics (but I am not going to implement a full recursive 125 | // check here) 126 | Equal => Self::tree_lexicographic_ordering(t1, t2), 127 | Greater => Greater 128 | }, 129 | Greater => Greater 130 | } 131 | } 132 | } 133 | 134 | impl Display for Tree { 135 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 136 | write!(f, "{}", &self.to_sexp_string()) 137 | } 138 | } 139 | 140 | impl std::str::FromStr for Tree { 141 | type Err = String; 142 | fn from_str(s: &str) -> Result { 143 | fn parse_sexp_tree(sexp: &Sexp) -> Result { 144 | match sexp { 145 | Sexp::Empty => Err("Found empty s-expression".into()), 146 | Sexp::String(s) => { 147 | Ok(Tree::leaf(s.clone())) 148 | } 149 | Sexp::List(list) if list.is_empty() => Err("Found empty s-expression".into()), 150 | Sexp::List(list) => match &list[0] { 151 | Sexp::Empty => unreachable!("Cannot be in head position"), 152 | // TODO: add apply 153 | Sexp::List(l) => bail!("Found a list in the head position: {:?}", l), 154 | // Sexp::String(op) if op == "typed" => { 155 | // let mut tree = parse_sexp_tree(&list[1])?; 156 | // let types = parse_sexp_tree(&list[2])?; 157 | // tree.typ = Box::new(Some(types)); 158 | // Ok(tree) 159 | // } 160 | Sexp::String(op) => { 161 | let arg_ids = list[1..].iter().map(|s| parse_sexp_tree(s).expect("Parsing should succeed")).collect::>(); 162 | let node = Tree::branch(op.clone(), arg_ids); 163 | Ok(node) 164 | } 165 | }, 166 | } 167 | } 168 | 169 | let sexp = symbolic_expressions::parser::parse_str(s.trim()).map_err(|e| e.to_string())?; 170 | Ok(parse_sexp_tree(&sexp)?) 
171 | } 172 | } 173 | -------------------------------------------------------------------------------- /src/eggstentions/expression_ops.rs: -------------------------------------------------------------------------------- 1 | use std::collections::BTreeSet; 2 | use std::fmt::Formatter; 3 | use std::hash::{Hash, Hasher}; 4 | 5 | use crate::{EGraph, Id, Language, RecExpr}; 6 | use indexmap::{IndexMap, IndexSet}; 7 | use itertools::Itertools; 8 | use smallvec::alloc::fmt::Display; 9 | 10 | /// A wrapper arround RecExp t omake it easier to use. 11 | #[derive(Clone, Debug)] 12 | pub struct RecExpSlice<'a, L: Language> { 13 | index: usize, 14 | exp: &'a RecExpr 15 | } 16 | 17 | impl<'a, L: Language> RecExpSlice<'a, L> { 18 | /// Create a new RecExpSlice. 19 | pub fn new(index: usize, exp: &'a RecExpr) -> RecExpSlice<'a, L> { 20 | RecExpSlice{index, exp} 21 | } 22 | 23 | /// Adds expression to the EGraph `graph` and returns the root of the expression. 24 | pub fn add_to_graph(&self, graph: &mut EGraph) -> Id { 25 | graph.add_expr(&RecExpr::from(self.exp.as_ref()[..self.index+1].iter().cloned().collect_vec())) 26 | } 27 | 28 | /// Returns a string representation that is easier to parse. 29 | pub fn to_spaceless_string(&self) -> String { 30 | self.to_sexp_string() 31 | .replace(" ", "_") 32 | .replace("(", "PO") 33 | .replace(")", "PC") 34 | .replace("->", "fn") 35 | } 36 | 37 | /// Returns a sexp string representation of the expression. 38 | pub fn to_sexp_string(&self) -> String { 39 | if self.is_leaf() { 40 | format!("{}", self.root().display_op().to_string()) 41 | } else { 42 | format!("({} {})", self.root().display_op().to_string(), 43 | itertools::Itertools::intersperse(self.children().iter().map(|t| t.to_sexp_string()), " ".to_string()).collect::()) 44 | } 45 | } 46 | 47 | /// Recreates the expression from the slice, but without any dangling children. 
48 | pub fn to_clean_exp(&self) -> RecExpr { 49 | fn add_to_exp<'a, L: Language>(expr: &mut Vec, child: &RecExpSlice<'a, L>) -> Id { 50 | let children = child.children(); 51 | let mut rec_res = children.iter().map(|c| add_to_exp(expr, c)); 52 | let mut root = child.root().clone(); 53 | root.update_children(|_id| rec_res.next().unwrap()); 54 | expr.push(root); 55 | Id::from(expr.len() - 1) 56 | } 57 | 58 | let mut exp = vec![]; 59 | add_to_exp(&mut exp, self); 60 | debug_assert_eq!(exp.iter().flat_map(|x| x.children()).count(), 61 | exp.iter().flat_map(|x| x.children()).unique().count()); 62 | RecExpr::from(exp) 63 | } 64 | } 65 | 66 | impl<'a, L: Language> PartialEq for RecExpSlice<'a, L> { 67 | fn eq(&self, other: &Self) -> bool { 68 | self.root() == other.root() && self.children() == other.children() 69 | } 70 | } 71 | 72 | impl<'a, L: Language> Hash for RecExpSlice<'a, L> { 73 | fn hash(&self, state: &mut H) { 74 | (self.root().hash(state), self.children().hash(state)).hash(state) 75 | } 76 | } 77 | 78 | impl<'a, L: Language> From<&'a RecExpr> for RecExpSlice<'a, L> { 79 | fn from(expr: &'a RecExpr) -> Self { 80 | RecExpSlice{index: expr.as_ref().len() - 1, exp: expr} 81 | } 82 | } 83 | 84 | impl<'a, L: Language + Clone> From<&'a RecExpSlice<'a, L>> for RecExpr { 85 | fn from(expr: &'a RecExpSlice<'a, L>) -> Self { 86 | // Need to remove unneeded nodes because recexpr comparison works straigt on vec 87 | let mut nodes: Vec> = vec![]; 88 | nodes.push(expr.clone()); 89 | let mut indices = IndexSet::new(); 90 | while !nodes.is_empty() { 91 | let current = nodes.pop().unwrap(); 92 | indices.insert(current.index); 93 | for n in current.children() { 94 | nodes.push(n); 95 | } 96 | } 97 | let mut res: Vec = vec![]; 98 | let mut id_trans: IndexMap = IndexMap::new(); 99 | for i in indices.iter().sorted() { 100 | id_trans.insert(Id::from(*i), Id::from(res.len())); 101 | res.push(expr.exp.as_ref()[*i].clone().map_children(|id| *id_trans.get(&id).unwrap())); 102 | } 103 | 
RecExpr::from(res) 104 | } 105 | } 106 | 107 | impl<'a, L: Language> Into> for RecExpSlice<'a, L> { 108 | fn into(self) -> RecExpr { 109 | RecExpr::from(self.exp.as_ref()[..self.index + 1].iter().cloned().collect_vec()) 110 | } 111 | } 112 | 113 | /// Trait to wrap a RecExpr like object into a RecExpSlice. 114 | pub trait IntoTree<'a, T: Language> { 115 | /// Wraps the object into a RecExpSlice. 116 | fn into_tree(&'a self) -> RecExpSlice<'a, T>; 117 | } 118 | 119 | impl<'a, T: Language> IntoTree<'a, T> for RecExpr { 120 | fn into_tree(&'a self) -> RecExpSlice<'a, T> { 121 | RecExpSlice::from(self) 122 | } 123 | } 124 | 125 | /// A trait for objects that can be used as trees. 126 | pub trait Tree<'a, T: 'a + Language> { 127 | /// Returns the root of the tree. 128 | fn root(&self) -> &'a T; 129 | 130 | /// Returns the children (subtrees) of the root of the tree. 131 | fn children(&self) -> Vec>; 132 | 133 | /// Returns true if the tree is a leaf. 134 | fn is_leaf(&self) -> bool { 135 | self.children().is_empty() 136 | } 137 | 138 | /// Returns true if the root of the tree is a hole. Decide if a hole is a hole by checking if the 139 | /// display op starts with a question mark. 140 | fn is_root_hole(&self) -> bool { 141 | self.root().display_op().to_string().starts_with("?") 142 | } 143 | 144 | /// Returns true if the root of the tree is not a hole. 
145 | fn is_root_ident(&self) -> bool { 146 | !self.is_root_hole() 147 | } 148 | 149 | /// Return all holes in tree 150 | fn holes(&self) -> BTreeSet { 151 | let mut res: BTreeSet = self.children().into_iter().flat_map(|c| c.holes()).collect(); 152 | if self.is_root_hole() { 153 | res.insert(self.root().clone()); 154 | } 155 | res 156 | } 157 | 158 | /// Return all non constants 159 | fn consts(&self) -> Vec { 160 | let mut res: Vec = self.children().into_iter().flat_map(|c| c.consts()).collect(); 161 | if self.is_root_ident() { 162 | res.push(self.root().clone()); 163 | } 164 | res 165 | } 166 | } 167 | 168 | impl<'a ,L: Language> Tree<'a, L> for RecExpSlice<'a, L> { 169 | fn root(&self) -> &'a L { 170 | &self.exp.as_ref()[self.index] 171 | } 172 | 173 | fn children(&self) -> Vec> { 174 | self.exp.as_ref()[self.index].children().iter().map(|t| 175 | RecExpSlice::new(usize::from(*t), self.exp)).collect_vec() 176 | } 177 | } 178 | 179 | impl<'a, T: 'a + Language + Display> Display for RecExpSlice<'a, T> { 180 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 181 | write!(f, "{}", &self.to_sexp_string()) 182 | } 183 | } 184 | 185 | -------------------------------------------------------------------------------- /scripts/nightly-static/main.js: -------------------------------------------------------------------------------- 1 | console.log("Starting!"); 2 | 3 | const runList = d3.select("#runs").selectAll("li"); 4 | 5 | const dateFormat = d3.timeFormat("%Y-%m-%d"); 6 | const timeFormat = d3.timeFormat("%H:%M:%S"); 7 | const datetimeFormat = d => `${dateFormat(d)} ${timeFormat(d)}`; 8 | 9 | const DIR_REGEX = /(.+?)___(.+?)___((.+?)-(\d+)-g([0-9a-f]+)(-dirty)?)/; 10 | const BASE_URL = `${window.location.protocol}//${window.location.host}/egg-nightlies/`; 11 | const DATA_DIR = BASE_URL + "data/"; 12 | const REPO = "https://github.com/mwillsey/egg"; 13 | 14 | function get(path) { 15 | // console.log("Fetching", path); 16 | return d3 17 | .json(path, { cache: 
"force-cache" }) 18 | .catch(err => console.warn("Failed to get", path, err)); 19 | } 20 | 21 | function link(href, text) { 22 | return `${text}` 23 | } 24 | 25 | async function updatePlot(runs) { 26 | // get the active data 27 | let active = []; 28 | await Promise.all(runs.map(r => r.load(10))); 29 | for (let run of runs) { 30 | for (let suite of run.suites) { 31 | for (let test of suite.tests) { 32 | active.push(test); 33 | } 34 | } 35 | } 36 | 37 | // set the dimensions and margins of the graph 38 | let margin = { top: 20, right: 20, bottom: 30, left: 50 }; 39 | let width = 960 - margin.left - margin.right; 40 | let height = 500 - margin.top - margin.bottom; 41 | 42 | // set the ranges 43 | let x = d3.scaleTime().range([0, width]); 44 | let y = d3.scaleLinear().range([height, 0]); 45 | 46 | let svg = d3 47 | .select("#plot") 48 | .attr("width", width + margin.left + margin.right) 49 | .attr("height", height + margin.top + margin.bottom) 50 | .append("g") 51 | .attr("transform", `translate(${margin.left}, ${margin.top})`); 52 | 53 | // Scale the range of the data 54 | x.domain(d3.extent(active, t => t.run.date)); 55 | y.domain([0, d3.max(active, t => t.avg_time())]); 56 | 57 | let byName = d3 58 | .nest() 59 | .key(d => d.data.name) 60 | .entries(active); 61 | 62 | // define the 1st line 63 | let valueline = d3 64 | .line() 65 | .x(t => x(t.run.date)) 66 | .y(t => y(t.avg_time())) 67 | .curve(d3.curveMonotoneX); 68 | 69 | let groups = svg 70 | .selectAll("g.path-group") 71 | .data(byName) 72 | .join("g") 73 | .attr("class", "path-group"); 74 | 75 | groups 76 | .selectAll("path") 77 | .data(d => [d.values]) 78 | .join("path") 79 | .attr("class", "line") 80 | .attr("d", test => valueline(test)); 81 | 82 | groups 83 | .selectAll(".dot") 84 | .data(d => d.values) 85 | .join("circle") 86 | .attr("class", "dot") 87 | .attr("cx", t => x(t.run.date)) 88 | .attr("cy", t => y(t.avg_time())) 89 | .attr("r", 5); 90 | 91 | tippy("#plot .dot", { 92 | allowHTML: true, 93 | 
interactive: true, 94 | appendTo: document.body, 95 | content(node) { 96 | let t = node.__data__; 97 | return ` 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 |
    Date ${datetimeFormat(t.run.date)}
    Rev ${t.run.branch} @ ${link(t.run.url, t.run.shortrev)}
    Suite ${link(t.suite.url, t.suite.name)}
    Test ${t.data.name}
    n ${t.data.times.length}
    avg ${t.avg_time()}
    106 | `; 107 | } 108 | }); 109 | 110 | 111 | // Add the X Axis 112 | svg 113 | .append("g") 114 | .attr("transform", `translate(0, ${height})`) 115 | .call(d3.axisBottom(x)); 116 | 117 | // Add the Y Axis 118 | svg.append("g").call(d3.axisLeft(y)); 119 | 120 | // svg 121 | // .selectAll(".text") 122 | // .data(data) 123 | // .enter() 124 | // .append("text") // Uses the enter().append() method 125 | // .attr("class", "label") // Assign a class for styling 126 | // .attr("x", d => x(d.run.date)) 127 | // .attr("y", d => y(d.data.times[0])) 128 | // .attr("dy", "-5") 129 | // .text(() => "label"); 130 | 131 | } 132 | 133 | class Run { 134 | constructor(dirname) { 135 | let match = DIR_REGEX.exec(dirname); 136 | this.path = DATA_DIR + dirname + "/"; 137 | this.name = dirname; 138 | this.date = new Date(match[1]); 139 | this.branch = match[2]; 140 | this.describe = match[3]; 141 | this.tag = match[4]; 142 | this.commitsAhead = +match[5]; 143 | this.rev = match[6]; 144 | this.dirty = !!match[7]; 145 | 146 | this.suites = null; 147 | this.url = `${REPO}/tree/${this.rev}`; 148 | this.shortrev = this.rev.slice(0,7) 149 | } 150 | 151 | async load(depth = 1) { 152 | let suites = await get(this.path); 153 | this.suites = await Promise.all( 154 | suites 155 | .filter(s => s.type == "directory") 156 | .map(async s => { 157 | let suite = new Suite(this, s.name); 158 | if (depth > 0) { 159 | await suite.load(depth - 1); 160 | } 161 | return suite; 162 | }) 163 | ); 164 | return this; 165 | } 166 | } 167 | 168 | class Suite { 169 | constructor(run, name) { 170 | this.run = run; 171 | this.name = name; 172 | this.path = run.path + name + "/"; 173 | 174 | this.tests = null; 175 | this.url = `${REPO}/blob/${run.rev}/tests/${name}.rs`; 176 | } 177 | 178 | async load(depth = 0) { 179 | let tests = await get(this.path); 180 | this.tests = await Promise.all( 181 | tests 182 | .filter(t => t.type == "file") 183 | .map(async t => { 184 | let test = new Test(this.run, this, t.name); 185 
| if (depth > 0) { 186 | await test.load(); 187 | } 188 | return test; 189 | }) 190 | ); 191 | return this; 192 | } 193 | } 194 | 195 | class Test { 196 | constructor(run, suite, name) { 197 | this.run = run; 198 | this.suite = suite; 199 | this.name = name; 200 | this.path = suite.path + name; 201 | this.data = null; 202 | } 203 | 204 | async load() { 205 | this.data = await get(this.path); 206 | } 207 | 208 | avg_time() { 209 | var total = 0; 210 | for (let t of this.data.times) total += t; 211 | return total / this.data.times.length; 212 | } 213 | } 214 | 215 | async function fetchRuns() { 216 | let runsListing = await d3.json(DATA_DIR); 217 | console.log("Got the listing", runsListing); 218 | let runs = runsListing 219 | .filter(i => i.type == "directory") 220 | .map(i => new Run(i.name)); 221 | console.log("Runs: ", runs); 222 | return runs; 223 | } 224 | 225 | let RUNS = null; 226 | 227 | async function init() { 228 | let runs = await fetchRuns(); 229 | RUNS = runs; 230 | 231 | runList 232 | .data(runs) 233 | .join("li") 234 | .text(run => `${run.branch} ${datetimeFormat(run.date)}`); 235 | 236 | updatePlot(runs); 237 | } 238 | 239 | init() 240 | .then(() => console.log("Successful init")) 241 | .catch(e => console.error("Failed init", e)); 242 | -------------------------------------------------------------------------------- /src/eggstentions/reconstruct.rs: -------------------------------------------------------------------------------- 1 | use std::collections::{HashMap, HashSet}; 2 | 3 | use crate::{EGraph, Id, Language, SymbolLang, RecExpr, EClass, ColorId}; 4 | use indexmap::IndexMap; 5 | use itertools::Itertools; 6 | use crate::eggstentions::tree::Tree; 7 | 8 | /// Reconstructs a RecExpr from an eclass. 9 | pub fn reconstruct(graph: &EGraph, class: Id, max_depth: usize) 10 | -> Option> { 11 | reconstruct_colored(graph, None, class, max_depth) 12 | } 13 | 14 | /// Reconstructs a RecExpr from an eclass under a specific colored assumption. 
15 | pub fn reconstruct_colored(graph: &EGraph, color: Option, class: Id, max_depth: usize) -> Option> { 16 | let mut translations: IndexMap> = IndexMap::new(); 17 | let class = graph.find(class); 18 | reconstruct_inner(graph, class, max_depth, color, &mut translations); 19 | translations.get(&class).cloned() 20 | } 21 | 22 | /// Reconstructs a RecExpr for `class` whose root node is exactly `edge`. Each child of `edge` is reconstructed without a color and with a depth budget of `max_depth - 1` (saturating, so `max_depth == 0` cannot underflow); if some child cannot be translated, nothing is built for `class` here. 23 | pub fn reconstruct_edge(graph: &EGraph, class: Id, edge: SymbolLang, max_depth: usize) -> Option> { 24 | let mut translations: IndexMap> = IndexMap::new(); 25 | for child in &edge.children { 26 | reconstruct_inner(graph, *child, max_depth.saturating_sub(1), None, &mut translations); 27 | } 28 | build_translation(graph, None, &mut translations, &edge, class); 29 | translations.get(&class).cloned() 30 | } 31 | /// Fills `translations` with a RecExpr for `class` (and, under `color`, for classes colored-equal to it), descending at most `max_depth` levels. Returns immediately if `max_depth` is 0 or `class` is already translated. 32 | fn reconstruct_inner(graph: &EGraph, class: Id, max_depth: usize, 33 | color: Option, translations: &mut IndexMap>) { 34 | if max_depth == 0 || translations.contains_key(&class) { 35 | return; 36 | } 37 | let cur_class = &graph[class]; 38 | let mut inner_ids = vec![]; 39 | check_class(graph, color, class, translations, &mut inner_ids, cur_class); 40 | if let Some(c) = color { 41 | if let Some(x) = graph.get_color(c) { 42 | let ids = x.equality_class(graph, class); 43 | for id in ids { 44 | let colored_class = &graph[id]; 45 | check_class(graph, color, id, translations, &mut inner_ids, colored_class) 46 | } 47 | } 48 | } 49 | inner_ids.sort_by_key(|c| c.children.len()); 50 | for edge in inner_ids { 51 | for id in &edge.children { 52 | reconstruct_inner(graph, *id, max_depth - 1, color, translations); 53 | } 54 | // A child is ready if it, or (under `color`) some class colored-equal to that child, is already translated; this matches the lookup done by `build_translation`. 55 | if edge.children.iter().all(|child| translations.contains_key(child) || 56 | color.map_or(false, |c_id| graph.get_color(c_id).map_or(false, |x| 57 | x.equality_class(graph, *child).any(|id| translations.contains_key(&id))))) { 58 | build_translation(graph, color, translations, edge, class); 59 | return; 60 | } 61 | } 62 | } 63 | 64 |
fn check_class<'a>(graph: &EGraph, color: Option, class: Id, translations: &mut IndexMap>, inner_ids: &mut Vec<&'a SymbolLang>, colored_class: &'a EClass) { 65 | // Translate `class` with the first node whose children are all translated; nodes visited before it are queued in `inner_ids` for the caller to explore. 66 | for edge in &colored_class.nodes { 67 | if edge.children.iter().all(|c| translations.contains_key(c)) { 68 | build_translation(graph, color, translations, edge, class); 69 | return; 70 | } 71 | inner_ids.push(edge); 72 | } 73 | } 74 | /// Builds a RecExpr whose root is `edge` by concatenating its children's translations (looking through `color`'s equality classes when a child has none) and stores it under `id`. If some child has no translation at all, nothing is inserted. 75 | fn build_translation(graph: &EGraph, color: Option, translations: &mut IndexMap>, edge: &SymbolLang, id: Id) { 76 | let mut res = vec![]; 77 | let mut children = vec![]; 78 | for c in edge.children.iter() { 79 | let cur_len = res.len(); 80 | let translation = translations.get(c).or_else(|| 81 | color.and_then(|c_id| 82 | graph.get_color(c_id).and_then(|x| 83 | // Fall back to any class colored-equal to `c` that already has a translation. 84 | x.equality_class(graph, *c).find_map(|id| translations.get(&id)))) 85 | ); 86 | let translation = match translation { Some(t) => t, None => return }; 87 | // Shift child ids by `cur_len` so they index into the concatenated `res`. 88 | res.extend(translation.as_ref().iter().cloned().map(|s| s.map_children(|child| Id::from(usize::from(child) + cur_len)))); 89 | children.push(Id::from(res.len() - 1)); 90 | }; 91 | res.push(SymbolLang::new(edge.op, children)); 92 | translations.insert(id, RecExpr::from(res)); 93 | } 94 | 95 | /// Reconstructs a RecExpr for each EClass in the graph.
96 | pub fn reconstruct_all(graph: &EGraph, color: Option, max_depth: usize) 97 | -> IndexMap { 98 | let mut translations: IndexMap = IndexMap::default(); 99 | let mut edge_in_need: HashMap> = HashMap::default(); 100 | 101 | let mut todo = HashSet::new(); 102 | 103 | let mut layers = vec![vec![]]; 104 | // Initialize data structures (translations, and which edges might be "free" next) 105 | for c in graph.classes() 106 | .filter(|c| c.color().is_none() || c.color() == color) { 107 | let fixed_id = graph.opt_colored_find(color, c.id); 108 | for n in &c.nodes { 109 | let fixed_n = if color.is_some() { 110 | graph.colored_canonize(*color.as_ref().unwrap(), n) 111 | } else { 112 | n.clone() 113 | }; 114 | if n.children().is_empty() { 115 | todo.insert(fixed_id); 116 | translations.insert(fixed_id, fixed_n); 117 | layers.last_mut().unwrap().push(fixed_id); 118 | } else { 119 | for ch in fixed_n.children() { 120 | let fixed_child = graph.opt_colored_find(color, *ch); 121 | // this might be a bit expensive to do for each edge 122 | edge_in_need.entry(fixed_child).or_default().push((fixed_id, fixed_n.clone())); 123 | } 124 | } 125 | } 126 | } 127 | let mut res = IndexMap::new(); 128 | for (id, n) in translations.iter() { 129 | res.insert(*id, Tree::leaf(n.op.to_string())); 130 | } 131 | 132 | let empty = vec![]; 133 | // Build layers 134 | for _ in 0..max_depth { 135 | layers.push(vec![]); 136 | let doing = std::mem::take(&mut todo); 137 | for c in doing { 138 | for (trg, n) in edge_in_need.get(&c).unwrap_or(&empty) { 139 | if (!translations.contains_key(trg)) && 140 | n.children().iter().all(|ch| translations.contains_key(ch)) { 141 | translations.insert(*trg, n.clone()); 142 | todo.insert(*trg); 143 | layers.last_mut().unwrap().push(*trg); 144 | } 145 | } 146 | } 147 | } 148 | 149 | // Build translations 150 | for l in layers.iter().dropping(1) { 151 | for id in l { 152 | let n = &translations[id]; 153 | let new_tree = Tree::branch(n.op.to_string(), 
n.children().iter().map(|ch| res[ch].clone()).collect()); 154 | res.insert(*id, new_tree); 155 | } 156 | } 157 | res 158 | } 159 | -------------------------------------------------------------------------------- /src/tutorials/_02_getting_started.rs: -------------------------------------------------------------------------------- 1 | // -*- mode: markdown; markdown-fontify-code-block-default-mode: rustic-mode; evil-shift-width: 2; -*- 2 | /*! 3 | 4 | # My first `egg` 🐣 5 | 6 | This tutorial is aimed at getting you up and running with `egg`, 7 | even if you have little Rust experience. 8 | If you haven't heard about e-graphs, you may want to read the 9 | [background tutorial](../_01_background/index.html). 10 | If you do have prior Rust experience, you may want to skim around in this section. 11 | 12 | ## Getting started with Rust 13 | 14 | [Rust](https://rust-lang.org) 15 | is one of the reasons why `egg` 16 | is fast (systems programming + optimizing compiler) and flexible (generics and traits). 17 | 18 | The Rust folks have put together a great, free [book](https://doc.rust-lang.org/stable/book/) for learning Rust, 19 | and there are a bunch of other fantastic resources collected on the 20 | ["Learn" page of the Rust site](https://www.rust-lang.org/learn). 21 | This tutorial is no replacement for those, 22 | but instead it aims to get you up and running as fast as possible. 23 | 24 | First, 25 | [install](https://www.rust-lang.org/tools/install) Rust 26 | and let's [create a project](https://doc.rust-lang.org/cargo/getting-started/first-steps.html) 27 | with `cargo`, Rust's package management and build tool: `cargo new my-first-egg`.[^lib] 28 | 29 | [^lib]: By default `cargo` will create a binary project. 30 | If you are just getting started with Rust, 31 | it might be easier to stick with a binary project, 32 | just put all your code in `main`, and use `cargo run`.
33 | Library projects (`cargo new --lib my-first-egg`) 34 | can be easier to build on once you want to start writing tests. 35 | 36 | Now we can add `egg` as a project dependency by adding a line to `Cargo.toml`: 37 | ```toml 38 | [dependencies] 39 | egg = "0.6.0" 40 | ``` 41 | 42 | All of the code samples below work, but you'll have to `use` the relevant types. 43 | You can just bring them all in with a `use easter_egg::*;` at the top of your file. 44 | 45 | ## Now you're speaking my [`Language`] 46 | 47 | [`EGraph`]s (and almost everything else in this crate) are 48 | parameterized over the [`Language`] given by the user. 49 | While `egg` supports the ability to easily create your own [`Language`], 50 | we will instead start with the provided [`SymbolLang`]. 51 | 52 | [`Language`] is a trait, 53 | and values of types that implement [`Language`] are e-nodes. 54 | An e-node may have any number of children, which are [`Id`]s. 55 | An [`Id`] is basically just a number that `egg` uses to coordinate what children 56 | an e-node is associated with. 57 | In an [`EGraph`], e-node children refer to e-classes. 58 | In a [`RecExpr`] (`egg`'s version of a plain old expression), 59 | e-node children refer to other e-nodes in that [`RecExpr`]. 60 | 61 | Most [`Language`]s, including [`SymbolLang`], can be parsed and pretty-printed. 62 | That means that [`RecExpr`]s in those languages 63 | implement the [`FromStr`] and [`Display`] traits from the Rust standard library. 64 | ``` 65 | # use easter_egg::*; 66 | // Since parsing can return an error, `unwrap` just panics if the result doesn't return Ok 67 | let my_expression: RecExpr = "(foo a b)".parse().unwrap(); 68 | println!("this is my expression {}", my_expression); 69 | 70 | // let's try to create an e-node, but hmmm, what do I put as the children? 71 | let my_enode = SymbolLang::new("bar", vec![]); 72 | ``` 73 | 74 | Some e-nodes are just constants and have no children (also called leaves).
75 | But it's intentionally kind of awkward to create e-nodes with children in isolation, 76 | since you would have to add meaningless [`Id`]s as children. 77 | The way to make meaningful [`Id`]s is by adding e-nodes to either an [`EGraph`] or a [`RecExpr`]: 78 | 79 | ``` 80 | # use easter_egg::*; 81 | let mut expr = RecExpr::default(); 82 | let a = expr.add(SymbolLang::leaf("a")); 83 | let b = expr.add(SymbolLang::leaf("b")); 84 | let foo = expr.add(SymbolLang::new("foo", vec![a, b])); 85 | 86 | // we can do the same thing with an EGraph 87 | let mut egraph: EGraph = Default::default(); 88 | let a = egraph.add(SymbolLang::leaf("a")); 89 | let b = egraph.add(SymbolLang::leaf("b")); 90 | let foo = egraph.add(SymbolLang::new("foo", vec![a, b])); 91 | 92 | // we can also add RecExprs to an egraph 93 | let foo2 = egraph.add_expr(&expr); 94 | // note that if you add the same thing to an e-graph twice, you'll get back equivalent Ids 95 | assert_eq!(foo, foo2); 96 | ``` 97 | 98 | ## Searching an [`EGraph`] with [`Pattern`]s 99 | 100 | Now that we can add stuff to an [`EGraph`], let's see if we can find it. 101 | We'll use a [`Pattern`], which implements the [`Searcher`] trait, 102 | to search the e-graph for matches: 103 | 104 | ``` 105 | # use easter_egg::*; 106 | // let's make an e-graph 107 | let mut egraph: EGraph = Default::default(); 108 | let a = egraph.add(SymbolLang::leaf("a")); 109 | let b = egraph.add(SymbolLang::leaf("b")); 110 | let foo = egraph.add(SymbolLang::new("foo", vec![a, b])); 111 | 112 | // we can make Patterns by parsing, similar to RecExprs 113 | // names preceded by ? 
are parsed as Pattern variables and will match anything 114 | let pat: Pattern = "(foo ?x ?x)".parse().unwrap(); 115 | 116 | // since we use ?x twice, it must match the same thing, 117 | // so this search will return nothing 118 | let matches = pat.search(&egraph); 119 | assert!(matches.is_none()); 120 | 121 | egraph.union(a, b); 122 | // recall that rebuild must be called to "see" the effects of unions 123 | egraph.rebuild(); 124 | 125 | // now we can find a match since a = b 126 | let matches = pat.search(&egraph); 127 | assert!(matches.is_some()) 128 | ``` 129 | 130 | 131 | 132 | ## Using [`Runner`] to make an optimizer 133 | 134 | Now that we can make [`Pattern`]s and work with [`RecExpr`]s, we can make an optimizer! 135 | We'll use the [`rewrite!`] macro to easily create [`Rewrite`]s which consist of a name, 136 | left-hand pattern to search for, 137 | and right-hand pattern to apply. 138 | From there we can use the [`Runner`] API to run `egg`'s equality saturation algorithm. 139 | Finally, we can use an [`Extractor`] to get the best result. 140 | ``` 141 | use easter_egg::{*, rewrite as rw}; 142 | 143 | let rules: &[Rewrite] = &[ 144 | rw!("commute-add"; "(+ ?x ?y)" => "(+ ?y ?x)"), 145 | rw!("commute-mul"; "(* ?x ?y)" => "(* ?y ?x)"), 146 | 147 | rw!("add-0"; "(+ ?x 0)" => "?x"), 148 | rw!("mul-0"; "(* ?x 0)" => "0"), 149 | rw!("mul-1"; "(* ?x 1)" => "?x"), 150 | ]; 151 | 152 | // While it may look like we are working with numbers, 153 | // SymbolLang stores everything as strings. 154 | // We can make our own Language later to work with other types. 155 | let start = "(+ 0 (* 1 a))".parse().unwrap(); 156 | 157 | // That's it! We can run equality saturation now. 
158 | let runner = Runner::default().with_expr(&start).run(rules); 159 | 160 | // Extractors can take a user-defined cost function, 161 | // we'll use the egg-provided AstSize for now 162 | let mut extractor = Extractor::new(&runner.egraph, AstSize); 163 | 164 | // We want to extract the best expression represented in the 165 | // same e-class as our initial expression, not from the whole e-graph. 166 | // Luckily the runner stores the eclass Id where we put the initial expression. 167 | let (best_cost, best_expr) = extractor.find_best(runner.roots[0]); 168 | 169 | // we found the best thing, which is just "a" in this case 170 | assert_eq!(best_expr, "a".parse().unwrap()); 171 | assert_eq!(best_cost, 1); 172 | ``` 173 | 174 | [`EGraph`]: ../../struct.EGraph.html 175 | [`Id`]: ../../struct.Id.html 176 | [`Language`]: ../../trait.Language.html 177 | [`Searcher`]: ../../trait.Searcher.html 178 | [`Pattern`]: ../../struct.Pattern.html 179 | [`RecExpr`]: ../../struct.RecExpr.html 180 | [`SymbolLang`]: ../../struct.SymbolLang.html 181 | [`define_language!`]: ../../macro.define_language.html 182 | [`rewrite!`]: ../../macro.rewrite.html 183 | [`FromStr`]: https://doc.rust-lang.org/std/str/trait.FromStr.html 184 | [`Display`]: https://doc.rust-lang.org/stable/std/fmt/trait.Display.html 185 | [`Rewrite`]: ../../struct.Rewrite.html 186 | [`Runner`]: ../../struct.Runner.html 187 | [`Extractor`]: ../../struct.Extractor.html 188 | 189 | */ 190 | -------------------------------------------------------------------------------- /src/test.rs: -------------------------------------------------------------------------------- 1 | /*! Utilities for benchmarking egg. 2 | 3 | These are not considered part of the public api. 
4 | */ 5 | #![macro_use] 6 | 7 | /* 8 | #[cfg(test)] 9 | #[macro_use] 10 | pub mod test { 11 | */ 12 | use std::path::PathBuf; 13 | use std::time::{Duration, Instant}; 14 | 15 | fn mean_stdev(data: &[f64]) -> (f64, f64) { 16 | assert_ne!(data.len(), 0); 17 | 18 | let sum = data.iter().sum::(); 19 | let n = data.len() as f64; 20 | let mean = sum / n; 21 | 22 | let variance = data 23 | .iter() 24 | .map(|value| { 25 | let diff = mean - (*value as f64); 26 | diff * diff 27 | }) 28 | .sum::() 29 | / n; 30 | 31 | (mean, variance.sqrt()) 32 | } 33 | 34 | pub fn env_var(s: &str) -> Option 35 | where 36 | T: std::str::FromStr, 37 | T::Err: std::fmt::Debug, 38 | { 39 | use std::env::VarError; 40 | match std::env::var(s) { 41 | Err(VarError::NotPresent) => None, 42 | Err(VarError::NotUnicode(_)) => panic!("Environment variable {} isn't unicode", s), 43 | Ok(v) if v.is_empty() => None, 44 | Ok(v) => match v.parse() { 45 | Ok(v) => Some(v), 46 | Err(err) => panic!("Couldn't parse environment variable {}={}, {:?}", s, v, err), 47 | }, 48 | } 49 | } 50 | 51 | pub struct Reporter { 52 | name: String, 53 | times: Option>, 54 | result: T, 55 | } 56 | 57 | impl Reporter { 58 | pub fn into_inner(self) -> T { 59 | // consume these so rust doesn't complain 60 | let _ = self.name; 61 | let _ = self.times; 62 | self.result 63 | } 64 | 65 | #[cfg(not(feature = "reports"))] 66 | pub fn report(self, to_report: impl FnOnce(&T) -> &R) -> T { 67 | if let Some(dir) = env_var::("EGG_BENCH_DIR") { 68 | eprintln!( 69 | "EGG_BENCH_DIR is set to '{:?}', but the 'reports' feature is not enabled", 70 | dir 71 | ); 72 | } 73 | to_report(&self.result); 74 | self.result 75 | } 76 | 77 | #[cfg(feature = "reports")] 78 | pub fn report(self, to_report: impl FnOnce(&T) -> &R) -> T 79 | where 80 | R: serde::Serialize, 81 | { 82 | let directory = match env_var::("EGG_BENCH_DIR") { 83 | None => { 84 | eprintln!("EGG_BENCH_DIR not set, skipping reporting"); 85 | return self.result; 86 | } 87 | Some(dir) => { 88 
| assert!(dir.is_dir(), "EGG_BENCH_DIR is not a directory: {:?}", dir); 89 | dir 90 | } 91 | }; 92 | 93 | let filename = format!("{}.json", self.name); 94 | let path = directory.join(&filename); 95 | let file = std::fs::OpenOptions::new() 96 | .write(true) 97 | .create_new(true) 98 | .open(&path) 99 | .unwrap_or_else(|err| panic!("Failed to open {:?}: {}", path, err)); 100 | 101 | let report = serde_json::json!({ 102 | "name": &self.name, 103 | "times": self.times.as_deref(), 104 | "data": to_report(&self.result), 105 | }); 106 | 107 | serde_json::to_writer_pretty(file, &report) 108 | .unwrap_or_else(|err| panic!("Failed to serialize report to {:?}: {}", path, err)); 109 | 110 | println!("Wrote report to {:?}", path); 111 | self.result 112 | } 113 | } 114 | 115 | pub fn run(name: impl Into, mut f: impl FnMut() -> T) -> Reporter { 116 | let name = name.into(); 117 | let seconds: f64 = match env_var("EGG_BENCH") { 118 | Some(s) => s, 119 | None => { 120 | return Reporter { 121 | name, 122 | times: None, 123 | result: f(), 124 | } 125 | } 126 | }; 127 | 128 | let duration = Duration::from_secs_f64(seconds); 129 | 130 | let start = Instant::now(); 131 | let mut times = vec![]; 132 | 133 | println!("benching {} for {} seconds...", name, seconds); 134 | 135 | let result = loop { 136 | let i = Instant::now(); 137 | let result = f(); 138 | times.push(i.elapsed().as_secs_f64()); 139 | 140 | if start.elapsed() > duration { 141 | break result; 142 | } 143 | }; 144 | 145 | let (mean, stdev) = mean_stdev(×); 146 | println!("bench {}:", name); 147 | println!(" n = {}", times.len()); 148 | println!(" μ = {}", mean); 149 | println!(" σ = {}", stdev); 150 | 151 | Reporter { 152 | name, 153 | times: Some(times), 154 | result, 155 | } 156 | } 157 | /* 158 | prop_compose! { 159 | fn arb_symbol(max: u32)(u in 0..=max) -> Symbol { 160 | Symbol(u) 161 | } 162 | } 163 | 164 | prop_compose! 
{ 165 | fn arb_id(max: u32)(u in 0..=max) -> Id { 166 | Id(u) 167 | } 168 | } 169 | 170 | prop_compose! { 171 | fn arb_node(max_sym: u32, max_children: u32, max_id: u32)( 172 | sym in arb_symbol(max_sym), 173 | vec in prop::collection::vec(arb_id(max_id), 0..=(max_children as usize))) -> SymbolLang { 174 | SymbolLang::new(sym, vec) 175 | } 176 | }*/ 177 | 178 | // prop_compose! { 179 | // fn arb_expression(max_sym: u32, max_children: u32, depth: u32)( 180 | // sym in arb_symbol(max_sym), 181 | // vec in prop::collection::vec(arb_id(max_id), 0..=(max_children as usize))) -> Expression { 182 | // let mut exp = RecExpr::default(); 183 | // 184 | // exp.add() 185 | // } 186 | // } 187 | 188 | #[allow(unused_imports)] 189 | #[cfg(test)] 190 | use invariants; 191 | 192 | /// Make a test function 193 | #[macro_export] 194 | macro_rules! test_fn { 195 | ( 196 | $(#[$meta:meta])* 197 | $name:ident, $rules:expr, 198 | $start:literal 199 | => 200 | $($goal:literal),+ $(,)? 201 | $(@check $check_fn:expr)? 202 | ) => { 203 | $crate::test_fn! { 204 | $(#[$meta])* 205 | $name, $rules, 206 | runner = $crate::Runner::<_, _, ()>::default(), 207 | $start => $( $goal ),+ 208 | $(@check $check_fn)? 209 | } 210 | }; 211 | 212 | ( 213 | $(#[$meta:meta])* 214 | $name:ident, $rules:expr, 215 | runner = $runner:expr, 216 | $start:literal 217 | => 218 | $($goal:literal),+ $(,)? 219 | $(@check $check_fn:expr)? 
220 | ) => { 221 | $(#[$meta])* 222 | #[test] 223 | fn $name() { 224 | let _ = env_logger::builder().is_test(true).try_init(); 225 | let level = invariants::max_level(); 226 | invariants::set_max_level(invariants::AssertLevel::Off); 227 | let name = stringify!($name); 228 | let start: $crate::RecExpr<_> = $start.parse().unwrap(); 229 | let rules = $rules; 230 | 231 | let runner: $crate::Runner<_, _, ()> = $crate::test::run(name, || { 232 | let mut runner = $runner.with_expr(&start); 233 | if let Some(lim) = $crate::test::env_var("EGG_NODE_LIMIT") { 234 | runner = runner.with_node_limit(lim) 235 | } 236 | if let Some(lim) = $crate::test::env_var("EGG_ITER_LIMIT") { 237 | runner = runner.with_iter_limit(lim) 238 | } 239 | if let Some(lim) = $crate::test::env_var("EGG_TIME_LIMIT") { 240 | runner = runner.with_time_limit(std::time::Duration::from_secs(lim)) 241 | } 242 | runner.run(&rules) 243 | }).report(|r| &r.iterations); 244 | runner.print_report(); 245 | 246 | let goals = &[$( 247 | $goal.parse().unwrap() 248 | ),+]; 249 | 250 | // NOTE this is a bit of hack, we rely on the fact that the 251 | // initial root is the last expr added by the runner. We can't 252 | // use egraph.find_expr(start) because it may have been pruned 253 | // away 254 | let id = runner.egraph.find(*runner.roots.last().unwrap()); 255 | runner.egraph.check_goals(id, goals); 256 | // This is very bad because we are accessing a gloabl variable. But *&%$ it for this tests its enough. 257 | invariants::set_max_level(level); 258 | $( ($check_fn)(runner) )? 
259 | } 260 | }; 261 | } 262 | /* 263 | pub(crate) use test_fn; 264 | } 265 | */ 266 | -------------------------------------------------------------------------------- /src/unionfind.rs: -------------------------------------------------------------------------------- 1 | use crate::Id; 2 | use std::fmt::Debug; 3 | use as_any::AsAny; 4 | use bimap::BiBTreeMap; 5 | 6 | pub trait UnionFind : AsAny + Debug + Send + Sync { 7 | /// Returns the number of elements in the union find. 8 | fn len(&self) -> usize; 9 | 10 | /// Finds the leader of the set that `current` is in. 11 | /// If K is not in the union find, it should return K. 12 | fn find(&self, current: K) -> K; 13 | 14 | /// Finds the leader of the set that `current` is in. 15 | /// This version updates the parents to compress the path. 16 | fn find_mut(&mut self, current: K) -> K; 17 | 18 | /// Given two leader ids, unions the two eclasses. 19 | /// This should run find to compress paths for efficiency. 20 | /// Returns (new leader, other id found). 21 | /// If either root is not in the union find, it should insert it or panic. 22 | fn union(&mut self, root1: K, root2: K) -> (K, K); 23 | 24 | /// Return a boxed clone of the union find. 25 | fn clone_box(&self) -> Box>; 26 | 27 | /// Return an iterator over the leaders. 28 | fn iter(&self) -> Box + '_>; 29 | } 30 | 31 | impl Clone for Box + 'static> where 32 | K: Copy + std::cmp::Eq + 'static, 33 | { 34 | fn clone(&self) -> Self { 35 | self.clone_box() 36 | } 37 | } 38 | 39 | #[derive(Debug, Clone, Default)] 40 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 41 | pub struct SimpleUnionFind { 42 | parents: Vec, 43 | } 44 | 45 | impl SimpleUnionFind { 46 | pub(crate) fn parent(&self, query: Id) -> Id { 47 | self.parents[usize::from(query)] 48 | } 49 | 50 | pub(crate) fn parent_mut(&mut self, query: Id) -> &mut Id { 51 | &mut self.parents[usize::from(query)] 52 | } 53 | 54 | /// Creates a new union find with a single element. 
55 | pub fn make_set(&mut self) -> Id { 56 | let id = Id::from(self.parents.len()); 57 | self.parents.push(id); 58 | id 59 | } 60 | } 61 | 62 | impl<'a> UnionFind for SimpleUnionFind { 63 | fn len(&self) -> usize { 64 | self.parents.len() 65 | } 66 | 67 | fn find(&self, mut current: Id) -> Id { 68 | while current != self.parent(current) { 69 | current = self.parent(current) 70 | } 71 | current 72 | } 73 | 74 | /// Given two leader ids, unions the two eclasses making root1 the leader. 75 | fn union(&mut self, root1: Id, root2: Id) -> (Id, Id) { 76 | let root1 = self.find_mut(root1); 77 | let root2 = self.find_mut(root2); 78 | if root1 > root2 { 79 | return self.union(root2, root1); 80 | } 81 | *self.parent_mut(root2) = root1; 82 | (root1, root2) 83 | } 84 | 85 | fn clone_box(&self) -> Box> { 86 | Box::new(self.clone()) 87 | } 88 | 89 | fn iter(&self) -> Box + '_> { 90 | let it = self.parents.iter() 91 | .enumerate() 92 | .filter(|(i, p)| *i == (p.0 as usize)) 93 | .map(|(_, p)| *p); 94 | Box::new(it) 95 | } 96 | 97 | fn find_mut(&mut self, mut current: Id) -> Id { 98 | let mut collected = vec![]; 99 | while current != self.parent(current) { 100 | collected.push(current); 101 | current = self.parent(current); 102 | } 103 | for c in collected { 104 | *self.parent_mut(c) = current; 105 | } 106 | current 107 | } 108 | } 109 | 110 | impl SimpleUnionFind { 111 | pub(crate) fn union_no_swap(&mut self, root1: Id, root2: Id) -> (Id, Id) { 112 | let root1 = self.find_mut(root1); 113 | let root2 = self.find_mut(root2); 114 | *self.parent_mut(root2) = root1; 115 | (root1, root2) 116 | } 117 | } 118 | 119 | /// Data inside the union find wrapper should implement a merge function. 120 | pub trait Merge { 121 | fn merge(&mut self, other: Self); 122 | } 123 | 124 | impl Merge for () { 125 | fn merge(&mut self, _: Self) {} 126 | } 127 | 128 | /// A wrapper for other union find implementations [U]. 
129 | /// This "translates" keys [K] to the internal representation so that external api can use any key. 130 | /// It also holds an object [T] for each equivalence class which is unioned with the [merge] function. 131 | /// It won't implement the union find api right now because I don't want to change it at the moment 132 | #[derive(Debug, Clone, Default)] 133 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 134 | pub struct UnionFindWrapper { 135 | uf: SimpleUnionFind, 136 | trns: BiBTreeMap 137 | } 138 | 139 | impl UnionFind for UnionFindWrapper { 140 | // This is hacky, interface should be &K but I want runtime above all 141 | fn union(&mut self, key1: K, key2: K) -> (K, K) { 142 | if !self.trns.contains_left(&key1) { 143 | self.insert(key1); 144 | } 145 | if !self.trns.contains_left(&key2) { 146 | self.insert(key2); 147 | } 148 | let mut idx1 = unsafe { *self.trns.get_by_left(&key1).unwrap_unchecked() }; 149 | let mut idx2 = unsafe { *self.trns.get_by_left(&key2).unwrap_unchecked() }; 150 | // I need to union by the keys of trns 151 | let mut k1 = self.uf.find_mut(idx1.into()); 152 | let mut k2 = self.uf.find_mut(idx2.into()); 153 | if self.trns.get_by_right(&(k1.0 as usize)) > self.trns.get_by_right(&(k2.0 as usize)) { 154 | std::mem::swap(&mut k1, &mut k2); 155 | std::mem::swap(&mut idx1, &mut idx2); 156 | } 157 | let (root1, root2) = self.uf.union_no_swap(k1, k2); 158 | let key1 = self.trns.get_by_right(&(root1.0 as usize)).unwrap(); 159 | let key2 = self.trns.get_by_right(&(root2.0 as usize)).unwrap(); 160 | (key1.clone(), key2.clone()) 161 | } 162 | 163 | fn find(&self, key: K) -> K { 164 | let idx = self.trns.get_by_left(&key); 165 | match idx { 166 | None => return key, 167 | Some(idx) => { 168 | let root = self.uf.find((*idx).into()); 169 | *self.trns.get_by_right(&(root.0 as usize)).unwrap() 170 | } 171 | } 172 | } 173 | 174 | fn len(&self) -> usize { 175 | self.trns.len() 176 | } 177 | 178 | fn clone_box(&self) -> Box> { 
179 | Box::new(self.clone()) 180 | } 181 | 182 | fn iter(&self) -> Box + '_> { 183 | Box::new(self.trns.iter().map(|(k, _)| *k)) 184 | } 185 | 186 | fn find_mut(&mut self, key: K) -> K { 187 | let idx = self.trns.get_by_left(&key); 188 | match idx { 189 | None => return key, 190 | Some(idx) => { 191 | let root = self.uf.find_mut((*idx).into()); 192 | *self.trns.get_by_right(&(root.0 as usize)).unwrap() 193 | } 194 | } 195 | } 196 | } 197 | 198 | impl UnionFindWrapper { 199 | pub fn insert(&mut self, key: K) { 200 | if self.trns.contains_left(&key) { 201 | return; 202 | } 203 | let id = self.uf.make_set(); 204 | self.trns.insert(key, id.0 as usize); 205 | } 206 | 207 | /// Swap a key with a new key. Panics if new key already exists. 208 | pub fn swap(&mut self, key: K, new_key: K) { 209 | if self.trns.contains_left(&new_key) { 210 | panic!("Key already exists"); 211 | } 212 | let (_, idx) = self.trns.remove_by_left(&key).unwrap(); 213 | self.trns.insert(new_key, idx); 214 | } 215 | 216 | /// Remove a node from the union-find. It will not remove the group, but it will remove a single node. 217 | /// Fails if the node is a leader. 
218 | // pub fn remove(&mut self, t: &K) -> Option<()> { 219 | // let leader = self.find_mut(*t); 220 | // if &leader == t { 221 | // return None; 222 | // } 223 | // self.trns.remove_by_left(t); 224 | // Some(()) 225 | // } 226 | 227 | pub fn contains(&self, key: &K) -> bool { 228 | self.trns.contains_left(key) 229 | } 230 | } 231 | 232 | impl UnionFindWrapper { 233 | #[allow(dead_code)] 234 | pub(crate) fn debug_print_all(&self) { 235 | for (k, v) in self.trns.iter() { 236 | println!("{:?}: {:?}", k, v); 237 | } 238 | for p in &self.uf.parents { 239 | println!("{:?}", p); 240 | } 241 | } 242 | } 243 | 244 | 245 | #[cfg(test)] 246 | mod tests { 247 | use crate::unionfind::SimpleUnionFind; 248 | use super::*; 249 | 250 | fn ids(us: impl IntoIterator) -> Vec { 251 | us.into_iter().map(|u| u.into()).collect() 252 | } 253 | 254 | #[test] 255 | fn union_find() { 256 | let n = 10; 257 | let id = Id::from; 258 | 259 | let mut uf = SimpleUnionFind::default(); 260 | for _ in 0..n { 261 | uf.make_set(); 262 | } 263 | 264 | // test the initial condition of everyone in their own set 265 | assert_eq!(uf.parents, ids(0..n)); 266 | 267 | // build up one set 268 | uf.union(id(0), id(1)); 269 | uf.union(id(0), id(2)); 270 | uf.union(id(0), id(3)); 271 | 272 | // build up another set 273 | uf.union(id(6), id(7)); 274 | uf.union(id(6), id(8)); 275 | uf.union(id(6), id(9)); 276 | 277 | // this should compress all paths 278 | for i in 0..n { 279 | uf.find_mut(id(i)); 280 | } 281 | 282 | // indexes: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 283 | let expected = vec![0, 0, 0, 0, 4, 5, 6, 6, 6, 6]; 284 | assert_eq!(uf.parents, ids(expected)); 285 | } 286 | } 287 | -------------------------------------------------------------------------------- /src/extract.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::Ordering; 2 | use std::fmt::Debug; 3 | use indexmap::IndexMap; 4 | 5 | use crate::{Analysis, EClass, EGraph, Id, Language, RecExpr}; 6 | 7 | /** 
Extracting a single [`RecExpr`] from an [`EGraph`]. 8 | 9 | ``` 10 | use easter_egg::*; 11 | 12 | define_language! { 13 | enum SimpleLanguage { 14 | Num(i32), 15 | "+" = Add([Id; 2]), 16 | "*" = Mul([Id; 2]), 17 | } 18 | } 19 | 20 | let rules: &[Rewrite] = &[ 21 | rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), 22 | rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"), 23 | 24 | rewrite!("add-0"; "(+ ?a 0)" => "?a"), 25 | rewrite!("mul-0"; "(* ?a 0)" => "0"), 26 | rewrite!("mul-1"; "(* ?a 1)" => "?a"), 27 | ]; 28 | 29 | let start = "(+ 0 (* 1 10))".parse().unwrap(); 30 | let runner = Runner::default().with_expr(&start).run(rules); 31 | let (egraph, root) = (runner.egraph, runner.roots[0]); 32 | 33 | let mut extractor = Extractor::new(&egraph, AstSize); 34 | let (best_cost, best) = extractor.find_best(root); 35 | assert_eq!(best_cost, 1); 36 | assert_eq!(best, "10".parse().unwrap()); 37 | ``` 38 | 39 | [`RecExpr`]: struct.RecExpr.html 40 | [`EGraph`]: struct.EGraph.html 41 | **/ 42 | pub struct Extractor<'a, CF: CostFunction, L: Language, N: Analysis> { 43 | cost_function: CF, 44 | costs: IndexMap, 45 | egraph: &'a EGraph, 46 | } 47 | 48 | /** A cost function that can be used by an [`Extractor`]. 49 | 50 | To extract an expression from an [`EGraph`], the [`Extractor`] 51 | requires a cost function to performs its greedy search. 52 | `egg` provides the simple [`AstSize`] and [`AstDepth`] cost functions. 
53 | 54 | The example below illustrates a silly but realistic example of 55 | implementing a cost function that is essentially AST size weighted by 56 | the operator: 57 | ``` 58 | # use easter_egg::*; 59 | struct SillyCostFn; 60 | impl CostFunction for SillyCostFn { 61 | type Cost = f64; 62 | fn cost(&mut self, enode: &SymbolLang, mut costs: C) -> Self::Cost 63 | where 64 | C: FnMut(Id) -> Self::Cost 65 | { 66 | let op_cost = match enode.op.as_str() { 67 | "foo" => 100.0, 68 | "bar" => 0.7, 69 | _ => 1.0 70 | }; 71 | enode.fold(op_cost, |sum, id| sum + costs(id)) 72 | } 73 | } 74 | 75 | let e: RecExpr = "(do_it foo bar baz)".parse().unwrap(); 76 | assert_eq!(SillyCostFn.cost_rec(&e), 102.7); 77 | assert_eq!(AstSize.cost_rec(&e), 4); 78 | assert_eq!(AstDepth.cost_rec(&e), 2); 79 | ``` 80 | 81 | [`AstSize`]: struct.AstSize.html 82 | [`AstDepth`]: struct.AstDepth.html 83 | [`Extractor`]: struct.Extractor.html 84 | [`EGraph`]: struct.EGraph.html 85 | **/ 86 | pub trait CostFunction { 87 | /// The `Cost` type. It only requires `PartialOrd` so you can use 88 | /// floating point types, but failed comparisons (`NaN`s) will 89 | /// result in a panic. 90 | type Cost: PartialOrd + Debug + Clone; 91 | 92 | /// Calculates the cost of an enode whose children are `Cost`s. 93 | /// 94 | /// For this to work properly, your cost function should be 95 | /// _monotonic_, i.e. `cost` should return a `Cost` greater than 96 | /// any of the child costs of the given enode. 97 | fn cost(&mut self, enode: &L, costs: C) -> Self::Cost 98 | where 99 | C: FnMut(Id) -> Self::Cost; 100 | 101 | /// Calculates the total cost of a [`RecExpr`]. 102 | /// 103 | /// As provided, this just recursively calls `cost` all the way 104 | /// down the [`RecExpr`]. 
105 | /// 106 | /// [`RecExpr`]: struct.RecExpr.html 107 | fn cost_rec(&mut self, expr: &RecExpr) -> Self::Cost { 108 | let mut costs: IndexMap = IndexMap::default(); 109 | for (i, node) in expr.as_ref().iter().enumerate() { 110 | let cost = self.cost(node, |i| costs[&i].clone()); 111 | costs.insert(Id::from(i), cost); 112 | } 113 | let last_id = Id::from(expr.as_ref().len() - 1); 114 | costs[&last_id].clone() 115 | } 116 | } 117 | 118 | /** A simple [`CostFunction`] that counts total ast size. 119 | 120 | ``` 121 | # use easter_egg::*; 122 | let e: RecExpr = "(do_it foo bar baz)".parse().unwrap(); 123 | assert_eq!(AstSize.cost_rec(&e), 4); 124 | ``` 125 | 126 | [`CostFunction`]: trait.CostFunction.html 127 | **/ 128 | pub struct AstSize; 129 | impl CostFunction for AstSize { 130 | type Cost = usize; 131 | fn cost(&mut self, enode: &L, mut costs: C) -> Self::Cost 132 | where 133 | C: FnMut(Id) -> Self::Cost, 134 | { 135 | enode.fold(1, |sum, id| sum + costs(id)) 136 | } 137 | } 138 | 139 | /** A simple [`CostFunction`] that counts maximum ast depth. 
140 | 141 | ``` 142 | # use easter_egg::*; 143 | let e: RecExpr = "(do_it foo bar baz)".parse().unwrap(); 144 | assert_eq!(AstDepth.cost_rec(&e), 2); 145 | ``` 146 | 147 | [`CostFunction`]: trait.CostFunction.html 148 | **/ 149 | pub struct AstDepth; 150 | impl CostFunction for AstDepth { 151 | type Cost = usize; 152 | fn cost(&mut self, enode: &L, mut costs: C) -> Self::Cost 153 | where 154 | C: FnMut(Id) -> Self::Cost, 155 | { 156 | 1 + enode.fold(0, |max, id| max.max(costs(id))) 157 | } 158 | } 159 | 160 | fn cmp(a: &Option, b: &Option) -> Ordering { 161 | // None is high 162 | match (a, b) { 163 | (None, None) => Ordering::Equal, 164 | (None, Some(_)) => Ordering::Greater, 165 | (Some(_), None) => Ordering::Less, 166 | (Some(a), Some(b)) => a.partial_cmp(&b).unwrap(), 167 | } 168 | } 169 | 170 | impl<'a, CF, L, N> Extractor<'a, CF, L, N> 171 | where 172 | CF: CostFunction, 173 | L: Language, 174 | N: Analysis, 175 | { 176 | /// Create a new `Extractor` given an `EGraph` and a 177 | /// `CostFunction`. 178 | /// 179 | /// The extraction does all the work on creation, so this function 180 | /// performs the greedy search for cheapest representative of each 181 | /// eclass. 182 | pub fn new(egraph: &'a EGraph, cost_function: CF) -> Self { 183 | let costs = IndexMap::default(); 184 | let mut extractor = Extractor { 185 | costs, 186 | egraph, 187 | cost_function, 188 | }; 189 | extractor.find_costs(); 190 | 191 | extractor 192 | } 193 | 194 | /// Find the cheapest (lowest cost) represented `RecExpr` in the 195 | /// given eclass. 196 | pub fn find_best(&mut self, eclass: Id) -> (CF::Cost, RecExpr) { 197 | let mut expr = RecExpr::default(); 198 | // added_memo maps eclass id to id in expr 199 | let mut added_memo: IndexMap = Default::default(); 200 | let (_, cost) = self.find_best_rec(&mut expr, eclass, &mut added_memo); 201 | (cost, expr) 202 | } 203 | 204 | /// Find the cost of the term that would be extracted from this e-class. 
205 | pub fn find_best_cost(&mut self, eclass: Id) -> CF::Cost { 206 | let (cost, _) = &self.costs[&self.egraph.find(eclass)]; 207 | cost.clone() 208 | } 209 | 210 | fn find_best_rec( 211 | &mut self, 212 | expr: &mut RecExpr, 213 | eclass: Id, 214 | added_memo: &mut IndexMap, 215 | ) -> (Id, CF::Cost) { 216 | let id = self.egraph.find(eclass); 217 | let (best_cost, best_node) = match self.costs.get(&id) { 218 | Some(result) => result.clone(), 219 | None => panic!("Failed to extract from eclass {}", id), 220 | }; 221 | 222 | match added_memo.get(&id) { 223 | Some(id_expr) => (*id_expr, best_cost), 224 | None => { 225 | let node = 226 | best_node.map_children(|child| self.find_best_rec(expr, child, added_memo).0); 227 | let id_expr = expr.add(node); 228 | assert!(added_memo.insert(id, id_expr).is_none()); 229 | (id_expr, best_cost) 230 | } 231 | } 232 | } 233 | 234 | fn node_total_cost(&mut self, node: &L) -> Option { 235 | let eg = &self.egraph; 236 | let has_cost = |&id| self.costs.contains_key(&eg.find(id)); 237 | if node.children().iter().all(has_cost) { 238 | let costs = &self.costs; 239 | let cost_f = |id| costs[&eg.find(id)].0.clone(); 240 | Some(self.cost_function.cost(&node, cost_f)) 241 | } else { 242 | None 243 | } 244 | } 245 | 246 | fn find_costs(&mut self) { 247 | let mut did_something = true; 248 | while did_something { 249 | did_something = false; 250 | 251 | for class in self.egraph.classes() { 252 | let pass = self.make_pass(class); 253 | match (self.costs.get(&class.id), pass) { 254 | (None, Some(new)) => { 255 | self.costs.insert(class.id, new); 256 | did_something = true; 257 | } 258 | (Some(old), Some(new)) if new.0 < old.0 => { 259 | self.costs.insert(class.id, new); 260 | did_something = true; 261 | } 262 | _ => (), 263 | } 264 | } 265 | } 266 | 267 | for class in self.egraph.classes() { 268 | if !self.costs.contains_key(&class.id) { 269 | log::warn!( 270 | "Failed to compute cost for eclass {}: {:?}", 271 | class.id, 272 | class.nodes 273 | 
) 274 | } 275 | } 276 | } 277 | 278 | fn make_pass(&mut self, eclass: &EClass) -> Option<(CF::Cost, L)> { 279 | let (cost, node) = eclass 280 | .iter() 281 | .map(|n| (self.node_total_cost(n), n)) 282 | .min_by(|a, b| cmp(&a.0, &b.0)) 283 | .unwrap_or_else(|| panic!("Can't extract, eclass is empty: {:#?}", eclass)); 284 | cost.map(|c| (c, node.clone())) 285 | } 286 | } 287 | -------------------------------------------------------------------------------- /tests/lambda.rs: -------------------------------------------------------------------------------- 1 | use easter_egg::{rewrite as rw, *}; 2 | use std::collections::HashSet; 3 | use std::fmt::{Display, Formatter}; 4 | use serde::{Deserialize, Serialize}; 5 | 6 | define_language! { 7 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 8 | enum Lambda { 9 | Bool(bool), 10 | Num(i32), 11 | 12 | "var" = Var(Id), 13 | 14 | "+" = Add([Id; 2]), 15 | "==" = Eq([Id; 2]), 16 | 17 | "app" = App([Id; 2]), 18 | "lam" = Lambda([Id; 2]), 19 | "let" = Let([Id; 3]), 20 | "fix" = Fix([Id; 2]), 21 | 22 | "if" = If([Id; 3]), 23 | 24 | Symbol(easter_egg::Symbol), 25 | } 26 | } 27 | 28 | impl Lambda { 29 | fn num(&self) -> Option { 30 | match self { 31 | Lambda::Num(n) => Some(*n), 32 | _ => None, 33 | } 34 | } 35 | } 36 | 37 | type EGraph = easter_egg::EGraph; 38 | 39 | #[derive(Default, Clone)] 40 | struct LambdaAnalysis; 41 | 42 | #[derive(Clone, Debug)] 43 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 44 | struct Data { 45 | free: HashSet, 46 | constant: Option, 47 | } 48 | 49 | fn eval(egraph: &EGraph, enode: &Lambda) -> Option { 50 | let x = |i: &Id| egraph[*i].data.constant.clone(); 51 | match enode { 52 | Lambda::Num(_) | Lambda::Bool(_) => Some(enode.clone()), 53 | Lambda::Add([a, b]) => Some(Lambda::Num(x(a)?.num()? + x(b)?.num()?)), 54 | Lambda::Eq([a, b]) => Some(Lambda::Bool(x(a)? 
== x(b)?)), 55 | _ => None, 56 | } 57 | } 58 | 59 | impl Analysis for LambdaAnalysis { 60 | type Data = Data; 61 | fn merge(&self, to: &mut Data, from: Data) -> bool { 62 | let before_len = to.free.len(); 63 | // to.free.extend(from.free); 64 | to.free.retain(|i| from.free.contains(i)); 65 | let did_change = before_len != to.free.len(); 66 | if to.constant.is_none() && from.constant.is_some() { 67 | to.constant = from.constant; 68 | true 69 | } else { 70 | did_change 71 | } 72 | } 73 | 74 | fn make(egraph: &EGraph, enode: &Lambda) -> Data { 75 | let f = |i: &Id| egraph[*i].data.free.iter().cloned(); 76 | let mut free = HashSet::default(); 77 | match enode { 78 | Lambda::Var(v) => { 79 | free.insert(*v); 80 | } 81 | Lambda::Let([v, a, b]) => { 82 | free.extend(f(b)); 83 | free.remove(v); 84 | free.extend(f(a)); 85 | } 86 | Lambda::Lambda([v, a]) | Lambda::Fix([v, a]) => { 87 | free.extend(f(a)); 88 | free.remove(v); 89 | } 90 | _ => enode.for_each(|c| free.extend(&egraph[c].data.free)), 91 | } 92 | let constant = eval(egraph, enode); 93 | Data { constant, free } 94 | } 95 | 96 | fn modify(egraph: &mut EGraph, id: Id) { 97 | if let Some(c) = egraph[id].data.constant.clone() { 98 | let const_id = egraph.add(c); 99 | egraph.union(id, const_id); 100 | } 101 | } 102 | } 103 | 104 | fn var(s: &str) -> Var { 105 | s.parse().unwrap() 106 | } 107 | 108 | struct IsConstApplier { 109 | v: Var, 110 | } 111 | 112 | impl Display for IsConstApplier { 113 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 114 | write!(f, "is-const({})", self.v) 115 | } 116 | } 117 | 118 | impl Applier for IsConstApplier { 119 | fn apply_one(&self, egraph: &mut easter_egg::EGraph, _eclass: Id, subst: &Subst) -> Vec { 120 | if egraph[subst[self.v]].data.constant.is_some() { 121 | vec![subst[self.v]] 122 | } else { 123 | vec![] 124 | } 125 | } 126 | } 127 | 128 | fn is_const(v: Var) -> IsConstApplier { 129 | IsConstApplier { v } 130 | } 131 | 132 | fn rules() -> Vec> { 133 | vec![ 134 | // 
open term rules 135 | rw!("if-true"; "(if true ?then ?else)" => "?then"), 136 | rw!("if-false"; "(if false ?then ?else)" => "?else"), 137 | multi_rewrite!("if-elim"; "?var1 = (if (== (var ?x) ?e) ?then ?else), ?var2 = (let ?x ?e ?then), ?var2 = (let ?x ?e ?else)" => "?var1 = ?else"), 138 | rw!("add-comm"; "(+ ?a ?b)" => "(+ ?b ?a)"), 139 | rw!("add-assoc"; "(+ (+ ?a ?b) ?c)" => "(+ ?a (+ ?b ?c))"), 140 | rw!("eq-comm"; "(== ?a ?b)" => "(== ?b ?a)"), 141 | // subst rules 142 | rw!("fix"; "(fix ?v ?e)" => "(let ?v (fix ?v ?e) ?e)"), 143 | rw!("beta"; "(app (lam ?v ?body) ?e)" => "(let ?v ?e ?body)"), 144 | rw!("let-app"; "(let ?v ?e (app ?a ?b))" => "(app (let ?v ?e ?a) (let ?v ?e ?b))"), 145 | rw!("let-add"; "(let ?v ?e (+ ?a ?b))" => "(+ (let ?v ?e ?a) (let ?v ?e ?b))"), 146 | rw!("let-eq"; "(let ?v ?e (== ?a ?b))" => "(== (let ?v ?e ?a) (let ?v ?e ?b))"), 147 | rw!("let-const"; 148 | "(let ?v ?e ?c)" => "?c" if is_const(var("?c"))), 149 | rw!("let-if"; 150 | "(let ?v ?e (if ?cond ?then ?else))" => 151 | "(if (let ?v ?e ?cond) (let ?v ?e ?then) (let ?v ?e ?else))" 152 | ), 153 | rw!("let-var-same"; "(let ?v1 ?e (var ?v1))" => "?e"), 154 | multi_rewrite!("let-var-diff"; "?x = (let ?v1 ?e (var ?v2)), ?v1 != ?v2" => "?x = (var ?v2)"), 155 | rw!("let-lam-same"; "(let ?v1 ?e (lam ?v1 ?body))" => "(lam ?v1 ?body)"), 156 | multi_rewrite!("let-lam-diff"; 157 | "?root = (let ?v1 ?e (lam ?v2 ?body)), ?v1 != ?v2" => 158 | { CaptureAvoid { 159 | root: var("?root"), fresh: var("?fresh"), v2: var("?v2"), e: var("?e"), 160 | if_not_free: "(lam ?v2 (let ?v1 ?e ?body))".parse().unwrap(), 161 | if_free: "(lam ?fresh (let ?v1 ?e (let ?v2 (var ?fresh) ?body)))".parse().unwrap(), 162 | }}), 163 | ] 164 | } 165 | /// Applier behind `let-lam-diff`: pushes `(let ?v1 ?e ...)` under `(lam ?v2 ...)`, renaming the binder to `?fresh` when `?v2` occurs free in `?e` (capture avoidance). 166 | struct CaptureAvoid { 167 | root: Var, 168 | fresh: Var, 169 | v2: Var, 170 | e: Var, 171 | if_not_free: Pattern, 172 | if_free: Pattern, 173 | } 174 | // When `?v2` is free in `?e`, binds `?fresh` to a new symbol named `_` followed by the root e-class id and applies `if_free`; otherwise applies `if_not_free`. 175 | impl Applier for CaptureAvoid { 176 | fn apply_one(&self, egraph: &mut EGraph, _eclass: Id, subst: &Subst) -> Vec {
177 | let e = subst[self.e]; 178 | let v2 = subst[self.v2]; 179 | let v2_free_in_e = egraph[e].data.free.contains(&v2); 180 | let eclass = subst[self.root]; 181 | if v2_free_in_e { 182 | let mut subst = subst.clone(); 183 | let sym = Lambda::Symbol(format!("_{}", eclass).into()); 184 | subst.insert(self.fresh, egraph.add(sym)); 185 | self.if_free.apply_one(egraph, eclass, &subst) 186 | } else { 187 | self.if_not_free.apply_one(egraph, eclass, &subst) 188 | } 189 | } 190 | } 191 | // Human-readable description of the applier: lists the pattern variables it reads. 192 | impl Display for CaptureAvoid { 193 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 194 | write!(f, "CaptureAvoid(root: {}, fresh: {}, v2: {}, e: {})", self.root, self.fresh, self.v2, self.e) 195 | } 196 | } 197 | 198 | #[cfg(test)] 199 | easter_egg::test_fn! { 200 | lambda_under, rules(), 201 | "(lam x (+ 4 202 | (app (lam y (var y)) 203 | 4)))" 204 | => 205 | // "(lam x (+ 4 (let y 4 (var y))))", 206 | // "(lam x (+ 4 4))", 207 | "(lam x 8))", 208 | } 209 | 210 | easter_egg::test_fn! { 211 | lambda_let_simple, rules(), 212 | "(let x 0 213 | (let y 1 214 | (+ (var x) (var y))))" 215 | => 216 | // "(let ?a 0 217 | // (+ (var ?a) 1))", 218 | // "(+ 0 1)", 219 | "1", 220 | } 221 | 222 | easter_egg::test_fn! { 223 | #[should_panic(expected = "Could not prove goal 0")] 224 | lambda_capture, rules(), 225 | "(let x 1 (lam x (var x)))" => "(lam x 1)" 226 | } 227 | 228 | easter_egg::test_fn! { 229 | #[should_panic(expected = "Could not prove goal 0")] 230 | lambda_capture_free, rules(), 231 | "(let y (+ (var x) (var x)) (lam x (var y)))" => "(lam x (+ (var x) (var x)))" 232 | } 233 | 234 | easter_egg::test_fn! { 235 | #[should_panic(expected = "Could not prove goal 0")] 236 | lambda_closure_not_seven, rules(), 237 | "(let five 5 238 | (let add-five (lam x (+ (var x) (var five))) 239 | (let five 6 240 | (app (var add-five) 1))))" 241 | => 242 | "7" 243 | } 244 | 245 | easter_egg::test_fn!
{ 246 | lambda_compose, rules(), 247 | "(let compose (lam f (lam g (lam x (app (var f) 248 | (app (var g) (var x)))))) 249 | (let add1 (lam y (+ (var y) 1)) 250 | (app (app (var compose) (var add1)) (var add1))))" 251 | => 252 | "(lam ?x (+ 1 253 | (app (lam ?y (+ 1 (var ?y))) 254 | (var ?x))))", 255 | "(lam ?x (+ (var ?x) 2))" 256 | } 257 | 258 | easter_egg::test_fn! { 259 | lambda_if_simple, rules(), 260 | "(if (== 1 1) 7 9)" => "7" 261 | } 262 | 263 | easter_egg::test_fn! { 264 | lambda_compose_many, rules(), 265 | "(let compose (lam f (lam g (lam x (app (var f) 266 | (app (var g) (var x)))))) 267 | (let add1 (lam y (+ (var y) 1)) 268 | (app (app (var compose) (var add1)) 269 | (app (app (var compose) (var add1)) 270 | (app (app (var compose) (var add1)) 271 | (app (app (var compose) (var add1)) 272 | (app (app (var compose) (var add1)) 273 | (app (app (var compose) (var add1)) 274 | (var add1)))))))))" 275 | => 276 | "(lam ?x (+ (var ?x) 7))" 277 | } 278 | 279 | easter_egg::test_fn! { 280 | #[cfg_attr(feature = "upward-merging", ignore)] 281 | #[ignore] 282 | lambda_function_repeat, rules(), 283 | runner = Runner::default() 284 | .with_time_limit(std::time::Duration::from_secs(20)) 285 | .with_node_limit(100_000) 286 | .with_iter_limit(60), 287 | "(let compose (lam f (lam g (lam x (app (var f) 288 | (app (var g) (var x)))))) 289 | (let repeat (fix repeat (lam fun (lam n 290 | (if (== (var n) 0) 291 | (lam i (var i)) 292 | (app (app (var compose) (var fun)) 293 | (app (app (var repeat) 294 | (var fun)) 295 | (+ (var n) -1))))))) 296 | (let add1 (lam y (+ (var y) 1)) 297 | (app (app (var repeat) 298 | (var add1)) 299 | 2))))" 300 | => 301 | "(lam ?x (+ (var ?x) 2))" 302 | } 303 | 304 | easter_egg::test_fn!
{ 305 | lambda_if, rules(), 306 | "(let zeroone (lam x 307 | (if (== (var x) 0) 308 | 0 309 | 1)) 310 | (+ (app (var zeroone) 0) 311 | (app (var zeroone) 10)))" 312 | => 313 | // "(+ (if false 0 1) (if true 0 1))", 314 | // "(+ 1 0)", 315 | "1", 316 | } 317 | 318 | easter_egg::test_fn! { 319 | #[cfg_attr(feature = "upward-merging", ignore)] 320 | #[ignore] 321 | lambda_fib, rules(), 322 | runner = Runner::default() 323 | .with_iter_limit(60) 324 | .with_node_limit(50_000), 325 | "(let fib (fix fib (lam n 326 | (if (== (var n) 0) 327 | 0 328 | (if (== (var n) 1) 329 | 1 330 | (+ (app (var fib) 331 | (+ (var n) -1)) 332 | (app (var fib) 333 | (+ (var n) -2))))))) 334 | (app (var fib) 4))" 335 | => "3" 336 | } 337 | -------------------------------------------------------------------------------- /src/colors.rs: -------------------------------------------------------------------------------- 1 | use crate::{egraph, SimpleUnionFind}; 2 | pub use crate::{Id, EGraph, Language, Analysis, ColorId}; 3 | use crate::{unionfind::UnionFind, Singleton}; 4 | use crate::util::JoinDisp; 5 | use as_any::Downcast; 6 | use invariants::dassert; 7 | use itertools::Itertools; 8 | use std::fmt::Formatter; 9 | use indexmap::{IndexMap, IndexSet}; 10 | use crate::unionfind::UnionFindWrapper; 11 | 12 | pub const BLACK_COLOR: ColorId = ColorId(0); 13 | 14 | /// Represents an e-graph layer that implements its own congruence relation. 15 | /// 16 | /// Each color represents a distinct congruence relation within the e-graph layers. 17 | /// The ids in the union find are directly taken from its parent layer. 18 | /// For efficiency, the color maintains the parent’s equality class for an id, 19 | /// speeding up search and rebuild operations. 20 | /// 21 | /// Currently, ids are not removed from the union find, which is wasteful but simpler.
22 | /// An optimization could remove an id when it is merged in the parent, eliminating 23 | /// redundancy; however, that would require updating all child colors to point to the new representative. 24 | #[derive(Clone, Debug)] 25 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 26 | pub struct Color> { 27 | color_id: ColorId, 28 | /// Used for rebuilding uf 29 | pub(crate) pending: Vec, 30 | /// Maintain which classes in black are represented in colored class (including rep) 31 | pub(crate) equality_classes: IndexMap>, 32 | /// Used to implement a union find. Opposite function of `equality_classes`. 33 | /// Supports removal of elements when they are not needed. 34 | union_find: Box>, 35 | /// Ids of the direct child colors of this color; `collect_decendents` walks them 36 | /// recursively to gather all descendants. 37 | pub(crate) children: Vec, 38 | pub(crate) parent: Option, 39 | parents: Vec, 40 | phantom: std::marker::PhantomData<(L, N)>, 41 | } 42 | 43 | impl> Color { 44 | pub(crate) fn collect_decendents(&self, egraph: &EGraph) -> Vec { 45 | let mut res = self.children.clone(); 46 | for c in self.children.iter() { 47 | res.extend(egraph.get_color(*c).unwrap().collect_decendents(egraph)); 48 | } 49 | res 50 | } 51 | } 52 | 53 | impl> Color { 54 | pub(crate) fn new(new_id: ColorId, parent: Option, graph: &EGraph) -> Color { 55 | let parents = parent.map_or_else(|| vec![], |p| { 56 | let mut res = graph.get_color(p).unwrap().parents.clone(); 57 | res.push(p); 58 | res 59 | }); 60 | // Root colors (no parent) and nested colors currently use the same union find; 61 | // `inner_base_union` downcasts `union_find` to `UnionFindWrapper` and relies on this. 62 | // TODO: When remove optional color, root colors should go back to simple 63 | // (`Box::new(SimpleUnionFind::default())`). 64 | let union_find: Box> = 65 | Box::new(UnionFindWrapper::default()); 66 | 67 | Color { 68 | color_id: new_id, 69 | pending: Default::default(), 70 | equality_classes: Default::default(), 71 | union_find, 72 | children: vec![], 73 | parent, 74 |
parents, 75 | phantom: Default::default(), 76 | } 77 | } 78 | 79 | pub fn get_id(&self) -> ColorId { 80 | self.color_id 81 | } 82 | 83 | pub fn children(&self) -> &[ColorId] { 84 | &self.children 85 | } 86 | 87 | pub fn parents(&self) -> &[ColorId] { 88 | &self.parents 89 | } 90 | 91 | pub(crate) fn verify_uf_minimal(&self, egraph: &EGraph) { 92 | let mut parents: IndexMap = IndexMap::default(); 93 | for k in self.union_find.iter() { 94 | let v = self.find(egraph, k); 95 | *parents.entry(v).or_default() += 1; 96 | } 97 | for (k, v) in parents { 98 | assert!(v >= 1, "Found {} parents for {}", v, k); // NOTE(review): always true, every counted representative has v >= 1; confirm the intended invariant 99 | } 100 | } 101 | 102 | pub fn find(&self, egraph: &EGraph, id: Id) -> Id { 103 | let fixed = self.parent().map_or_else(|| egraph.find(id), |c_id| egraph.colored_find(c_id, id)); 104 | self.union_find.find(fixed) 105 | } 106 | 107 | pub fn find_mut(&mut self, id: Id) -> Id { 108 | self.union_find.find_mut(id) 109 | } 110 | 111 | pub fn is_dirty(&self) -> bool { !self.pending.is_empty() } 112 | 113 | 114 | /// Update the color according to the union of base_to and base_from in the parent layer 115 | /// Assumes to and from canonised to the base (parent, black or colored) and != 116 | /// @returns whether children need an update as well 117 | pub(crate) fn inner_base_union(&mut self, base_to: Id, base_from: Id) -> bool { 118 | // I should update the uf and the equality classes. 119 | // This should recursively try to update children until hitting a case they were both in UF and equal? 120 | // 1. If both were present but not equal I definitely need to union them, then I potentially need to remove from 121 | // 2. If both were present and equal I potentially need to remove from. This is a special case, no need to 122 | // recurse as any future child will see to and from as the same. 123 | // 3. Any of them was missing. I think if from was missing I need to change to new rep? Does it matter?
Not really, I just need to not assume I am holding parent rep 125 | dassert!(base_to != base_from, "Should not be the same"); 126 | let uf: &mut UnionFindWrapper = self.union_find.as_mut().downcast_mut().unwrap(); 127 | let from_existed = uf.contains(&base_from); 128 | let to_existed = uf.contains(&base_to); 129 | 130 | let diff = if to_existed && from_existed { 131 | // Both ids are tracked by this color: union their colored classes and drop `base_from` from the merged class. 132 | let (colored_to, colored_from) = self.inner_colored_union(base_to, base_from); 133 | self.equality_classes.entry(colored_to).and_modify(|s| { 134 | s.swap_remove(&base_from); 135 | }); 136 | if self.equality_classes.get(&colored_to).map_or(false, |s| s.len() == 1) { 137 | dassert!(self.equality_classes.get(&colored_to).unwrap().contains(&colored_to), 138 | "We should always have the representative in the map"); 139 | self.equality_classes.swap_remove(&colored_to); 140 | } 141 | colored_to != colored_from 142 | } else if from_existed { 143 | // If from existed, we need to update the to to be the new representative. 144 | let colored_from = self.find_mut(base_from); 145 | dassert!(base_to != colored_from, 146 | "Ids in colored union should not be the same if from existed and to didnt"); 147 | let uf: &mut UnionFindWrapper = self.union_find.as_mut().downcast_mut().unwrap(); 148 | uf.swap(base_from, base_to); 149 | self.equality_classes.entry(colored_from).and_modify(|s| { 150 | s.swap_remove(&base_from); 151 | s.insert(base_to); 152 | }); 153 | self.pending.push(base_to); 154 | true 155 | } else { 156 | // TODO: I don't need to do anything here, right? NOTE(review): this branch also runs when neither id is in the union find, not only when `to` existed and `from` didn't.
157 | let to = self.find_mut(base_to); 158 | dassert!(to != base_from, 159 | "Ids in colored union should not be the same if to existed and from didnt"); 160 | self.pending.push(to); 161 | true 162 | }; 163 | 164 | diff 165 | } 166 | 167 | // Assumed id1 and id2 are parent canonized 168 | #[inline(always)] 169 | pub(crate) fn inner_colored_union(&mut self, id1: Id, id2: Id) -> (Id, Id) { 170 | // Parent classes will be updated in black union to come. 171 | let (to, from) = self.union_find.union(id1, id2); 172 | let changed = to != from; 173 | if changed { 174 | self.pending.push(to); 175 | let from_ids = self.equality_classes.swap_remove(&from).unwrap_or_else(|| IndexSet::singleton(from)); 176 | self.equality_classes.entry(to).or_insert_with(|| IndexSet::singleton(to)).extend(from_ids); 177 | } 178 | (to, from) 179 | } 180 | 181 | pub fn base_equality_class(&self, egraph: &EGraph, id: Id) -> Option<&IndexSet> { 182 | self.equality_classes.get(&self.find(egraph, id)) 183 | } 184 | 185 | pub fn equality_class<'a>(&'a self, egraph: &'a EGraph, id: Id) -> Box + 'a> { 186 | let parent = self.parent(); 187 | let fixed_id = self.find(egraph, id); 188 | let mut res: Box> = if let Some(ids) = self.equality_classes.get(&fixed_id) { 189 | if let Some(c_id) = parent { 190 | Box::new(ids.into_iter().copied() 191 | .flat_map(move |id| egraph.get_color(c_id).unwrap().equality_class(egraph, id))) 192 | } else { 193 | Box::new(ids.into_iter().copied()) 194 | } 195 | } else { 196 | if let Some(c_id) = parent { 197 | Box::new(egraph.get_color(c_id).unwrap().equality_class(egraph, id)) 198 | } else { 199 | Box::new(std::iter::once(id)) 200 | } 201 | }; 202 | dassert!({ 203 | let temp = res.collect_vec(); 204 | let r = temp.len() == temp.iter().unique().count(); 205 | res = Box::new(temp.into_iter()); 206 | r 207 | }); 208 | res 209 | } 210 | 211 | /// Returns the black representatives of the colored e-classes tracked by the current color only.
Does not 212 | /// include the parents' equality classes. 213 | pub fn current_black_reps(&self) -> impl Iterator { 214 | self.equality_classes.keys() 215 | } 216 | 217 | pub fn parent(&self) -> Option { self.parent } 218 | 219 | pub fn get_all_enodes(&self, id: Id, egraph: &EGraph) -> Vec { 220 | let mut res: IndexSet = IndexSet::default(); 221 | for cls in self.equality_class(egraph, id) { 222 | res.extend(egraph[cls].nodes.iter().map(|n: &L| egraph.colored_canonize(self.color_id, n))); 223 | } 224 | res.into_iter().collect_vec() 225 | } 226 | 227 | #[inline(always)] 228 | pub fn assert_black_ids(&self, egraph: &EGraph) { 229 | // Check that black ids are actually black representatives 230 | dassert!({ 231 | for set in self.equality_classes.values() { 232 | for id in set { 233 | dassert!(egraph.find(*id) == *id, "black id {:?} is not black rep {:?}", id, egraph.find(*id)); 234 | } 235 | } 236 | true 237 | }); 238 | } 239 | } 240 | 241 | impl std::fmt::Display for Color where L: Language, N: Analysis { 242 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 243 | write!(f, "Color(id={}, groups={})", self.color_id, self.equality_classes.iter().map(|(id, set)| format!("{} - {}", id, set.iter().sep_string(" "))).join(", ")) 244 | } 245 | } 246 | 247 | #[cfg(test)] 248 | mod test { 249 | 250 | // #[test] 251 | // fn test_black_union_alone() { 252 | // let mut g = EGraph::::new(()); 253 | // let id1 = g.add_expr(&"1".parse().unwrap()); 254 | // let id2 = g.add_expr(&"2".parse().unwrap()); 255 | // let mut color = Color::new(ColorId::from(0)); 256 | // color.black_union(&mut g, id1, id2); 257 | // color.black_union(&mut g, id1, id2); 258 | // color.black_union(&mut g, id1, id1); 259 | // assert_eq!(color.find(&g, id1), color.find(&g, id2)); 260 | // } 261 | } 262 | -------------------------------------------------------------------------------- /tests/math.rs: -------------------------------------------------------------------------------- 1 | 
// use easter_egg::{rewrite as rw, *}; 2 | // use ordered_float::NotNan; 3 | // use serde::{Deserialize, Serialize}; 4 | // 5 | // pub type EGraph = easter_egg::EGraph; 6 | // pub type Rewrite = easter_egg::Rewrite; 7 | // 8 | // pub type Constant = NotNan; 9 | // 10 | // define_language! { 11 | // pub enum Math { 12 | // "d" = Diff([Id; 2]), 13 | // "i" = Integral([Id; 2]), 14 | // 15 | // "+" = Add([Id; 2]), 16 | // "-" = Sub([Id; 2]), 17 | // "*" = Mul([Id; 2]), 18 | // "/" = Div([Id; 2]), 19 | // "pow" = Pow([Id; 2]), 20 | // "ln" = Ln(Id), 21 | // "sqrt" = Sqrt(Id), 22 | // 23 | // "sin" = Sin(Id), 24 | // "cos" = Cos(Id), 25 | // 26 | // Constant(Constant), 27 | // Symbol(Symbol), 28 | // } 29 | // } 30 | // 31 | // // You could use easter_egg::AstSize, but this is useful for debugging, since 32 | // // it will really try to get rid of the Diff operator 33 | // pub struct MathCostFn; 34 | // impl easter_egg::CostFunction for MathCostFn { 35 | // type Cost = usize; 36 | // fn cost(&mut self, enode: &Math, mut costs: C) -> Self::Cost 37 | // where 38 | // C: FnMut(Id) -> Self::Cost, 39 | // { 40 | // let op_cost = match enode { 41 | // Math::Diff(..) => 100, 42 | // Math::Integral(..) => 100, 43 | // _ => 1, 44 | // }; 45 | // enode.fold(op_cost, |sum, i| sum + costs(i)) 46 | // } 47 | // } 48 | // 49 | // #[derive(Default, Serialize, Deserialize, Clone)] 50 | // pub struct ConstantFold; 51 | // impl Analysis for ConstantFold { 52 | // type Data = Option; 53 | // 54 | // fn merge(&self, to: &mut Self::Data, from: Self::Data) -> bool { 55 | // if let (Some(c1), Some(c2)) = (to.as_ref(), from.as_ref()) { 56 | // assert_eq!(c1, c2); 57 | // } 58 | // merge_if_different(to, to.or(from)) 59 | // } 60 | // 61 | // fn make(egraph: &EGraph, enode: &Math) -> Self::Data { 62 | // let x = |i: &Id| egraph[*i].data; 63 | // Some(match enode { 64 | // Math::Constant(c) => *c, 65 | // Math::Add([a, b]) => x(a)? + x(b)?, 66 | // Math::Sub([a, b]) => x(a)?
- x(b)?, 67 | // Math::Mul([a, b]) => x(a)? * x(b)?, 68 | // Math::Div([a, b]) if x(b) != Some(0.0.into()) => x(a)? / x(b)?, 69 | // _ => return None, 70 | // }) 71 | // } 72 | // 73 | // fn modify(egraph: &mut EGraph, id: Id) { 74 | // let class = &mut egraph[id]; 75 | // if let Some(c) = class.data { 76 | // let added = egraph.add(Math::Constant(c)); 77 | // let (id, _did_something) = egraph.union(id, added); 78 | // // to not prune, comment this out 79 | // egraph[id].nodes.retain(|n| n.is_leaf()); 80 | // 81 | // assert!( 82 | // !egraph[id].nodes.is_empty(), 83 | // "empty eclass! {:#?}", 84 | // egraph[id] 85 | // ); 86 | // #[cfg(debug_assertions)] 87 | // egraph[id].assert_unique_leaves(); 88 | // } 89 | // } 90 | // } 91 | // 92 | // struct IsConstOrDistinctCondition { 93 | // v: Var, 94 | // w: Var, 95 | // } 96 | // 97 | // impl Condition for IsConstOrDistinctCondition { 98 | // fn check(&self, egraph: &mut easter_egg::EGraph, _eclass: Id, subst: &Subst) -> bool { 99 | // egraph.find(subst[self.v]) != egraph.find(subst[self.w]) 100 | // && egraph[subst[self.v]] 101 | // .nodes 102 | // .iter() 103 | // .any(|n| matches!(n, Math::Constant(..)
| Math::Symbol(..))) 104 | // } 105 | // 106 | // fn check_colored(&self, egraph: &mut easter_egg::EGraph, eclass: Id, subst: &Subst) -> Option> { 107 | // self.check(egraph, eclass, subst).then(|| vec![]) 108 | // } 109 | // 110 | // fn describe(&self) -> String { 111 | // "is_const_or_distinct".to_string() 112 | // } 113 | // } 114 | // 115 | // fn is_const_or_distinct_var(v: &str, w: &str) -> impl Condition { 116 | // let v = v.parse().unwrap(); 117 | // let w = w.parse().unwrap(); 118 | // IsConstOrDistinctCondition { v, w } 119 | // } 120 | // 121 | // struct IsConstCondition { 122 | // v: Var, 123 | // } 124 | // 125 | // impl Condition for IsConstCondition { 126 | // fn check(&self, egraph: &mut easter_egg::EGraph, _eclass: Id, subst: &Subst) -> bool { 127 | // egraph[subst[self.v]] 128 | // .nodes 129 | // .iter() 130 | // .any(|n| matches!(n, Math::Constant(..))) 131 | // } 132 | // 133 | // fn check_colored(&self, egraph: &mut easter_egg::EGraph, eclass: Id, subst: &Subst) -> Option> { 134 | // self.check(egraph, eclass, subst).then(|| vec![]) 135 | // } 136 | // 137 | // fn describe(&self) -> String { 138 | // "is_const".to_string() 139 | // } 140 | // } 141 | // 142 | // fn is_const(var: &str) -> impl Condition { 143 | // let var = var.parse().unwrap(); 144 | // IsConstCondition { v: var } 145 | // } 146 | // 147 | // struct IsSymCondition { 148 | // v: Var, 149 | // } 150 | // 151 | // impl Condition for IsSymCondition { 152 | // fn check(&self, egraph: &mut easter_egg::EGraph, _eclass: Id, subst: &Subst) -> bool { 153 | // egraph[subst[self.v]] 154 | // .nodes 155 | // .iter() 156 | // .any(|n| matches!(n, Math::Symbol(..))) 157 | // } 158 | // 159 | // fn check_colored(&self, egraph: &mut easter_egg::EGraph, eclass: Id, subst: &Subst) -> Option> { 160 | // self.check(egraph, eclass, subst).then(|| vec![]) 161 | // } 162 | // 163 | // fn describe(&self) -> String { 164 | // "is_sym".to_string() 165 | // } 166 | // } 167 | // 168 | // fn is_sym(var:
&str) -> impl Condition { 169 | // let var = var.parse().unwrap(); 170 | // IsSymCondition { v: var } 171 | // } 172 | // 173 | // struct IsNotZeroCondition { 174 | // v: Var, 175 | // } 176 | // 177 | // #[rustfmt::skip] 178 | // pub fn rules() -> Vec { vec![ 179 | // rw!("comm-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), 180 | // rw!("comm-mul"; "(* ?a ?b)" => "(* ?b ?a)"), 181 | // rw!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"), 182 | // rw!("assoc-mul"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"), 183 | // 184 | // rw!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"), 185 | // multi_rewrite!("div-canon"; "?root = (/ ?a ?b), ?b != 0" => "?root = (* ?a (pow ?b -1))"), 186 | // // rw!("canon-sub"; "(+ ?a (* -1 ?b))" => "(- ?a ?b)"), 187 | // // rw!("canon-div"; "(* ?a (pow ?b -1))" => "(/ ?a ?b)" if is_not_zero("?b")), 188 | // 189 | // rw!("zero-add"; "(+ ?a 0)" => "?a"), 190 | // rw!("zero-mul"; "(* ?a 0)" => "0"), 191 | // rw!("one-mul"; "(* ?a 1)" => "?a"), 192 | // 193 | // rw!("add-zero"; "?a" => "(+ ?a 0)"), 194 | // rw!("mul-one"; "?a" => "(* ?a 1)"), 195 | // 196 | // rw!("cancel-sub"; "(- ?a ?a)" => "0"), 197 | // multi_rewrite!("cancel-div"; "?root = (/ ?a ?a), ?a != 0" => "?root = 1"), 198 | // 199 | // rw!("distribute"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"), 200 | // rw!("factor" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"), 201 | // 202 | // rw!("pow-mul"; "(* (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (+ ?b ?c))"), 203 | // multi_rewrite!("pow0"; "?root = (pow ?x 0), ?x != 0" => "?root = 1"), 204 | // rw!("pow1"; "(pow ?x 1)" => "?x"), 205 | // rw!("pow2"; "(pow ?x 2)" => "(* ?x ?x)"), 206 | // multi_rewrite!("pow-recip"; "?root = (pow ?x -1), ?x != 0" => "?root = (/ 1 ?x)"), 207 | // multi_rewrite!("recip-mul-div"; "?root = (* ?x (/ 1 ?x)), ?x != 0" => "?root = 1"), 208 | // 209 | // rw!("d-variable"; "(d ?x ?x)" => "1" if is_sym("?x")), 210 | // rw!("d-constant"; "(d ?x ?c)" => "0" if {conditions::MutAndCondition::new(vec![Box::new(is_sym("?x")),
Box::new(is_const_or_distinct_var("?c", "?x"))])}), 211 | // 212 | // rw!("d-add"; "(d ?x (+ ?a ?b))" => "(+ (d ?x ?a) (d ?x ?b))"), 213 | // rw!("d-mul"; "(d ?x (* ?a ?b))" => "(+ (* ?a (d ?x ?b)) (* ?b (d ?x ?a)))"), 214 | // 215 | // rw!("d-sin"; "(d ?x (sin ?x))" => "(cos ?x)"), 216 | // rw!("d-cos"; "(d ?x (cos ?x))" => "(* -1 (sin ?x))"), 217 | // 218 | // rw!("d-ln"; "(d ?x (ln ?x))" => "(/ 1 ?x)" if is_not_zero("?x")), 219 | // 220 | // multi_rewrite!("d-power"; 221 | // "?root = (d ?x (pow ?f ?g)), ?f != 0, ?g != 0" => 222 | // "?root = (* (pow ?f ?g) 223 | // (+ (* (d ?x ?f) 224 | // (/ ?g ?f)) 225 | // (* (d ?x ?g) 226 | // (ln ?f))))" 227 | // ), 228 | // 229 | // rw!("i-one"; "(i 1 ?x)" => "?x"), 230 | // rw!("i-power-const"; "(i (pow ?x ?c) ?x)" => 231 | // "(/ (pow ?x (+ ?c 1)) (+ ?c 1))" if is_const("?c")), 232 | // rw!("i-cos"; "(i (cos ?x) ?x)" => "(sin ?x)"), 233 | // rw!("i-sin"; "(i (sin ?x) ?x)" => "(* -1 (cos ?x))"), 234 | // rw!("i-sum"; "(i (+ ?f ?g) ?x)" => "(+ (i ?f ?x) (i ?g ?x))"), 235 | // rw!("i-dif"; "(i (- ?f ?g) ?x)" => "(- (i ?f ?x) (i ?g ?x))"), 236 | // rw!("i-parts"; "(i (* ?a ?b) ?x)" => 237 | // "(- (* ?a (i ?b ?x)) (i (* (d ?x ?a) (i ?b ?x)) ?x))"), 238 | // ]} 239 | // 240 | // easter_egg::test_fn! { 241 | // math_associate_adds, [ 242 | // rw!("comm-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), 243 | // rw!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"), 244 | // ], 245 | // runner = Runner::default() 246 | // .with_iter_limit(7) 247 | // .with_scheduler(SimpleScheduler), 248 | // "(+ 1 (+ 2 (+ 3 (+ 4 (+ 5 (+ 6 7))))))" 249 | // => 250 | // "(+ 7 (+ 6 (+ 5 (+ 4 (+ 3 (+ 2 1))))))" 251 | // @check |r: Runner| assert_eq!(r.egraph.number_of_classes(), 127) 252 | // } 253 | // 254 | // easter_egg::test_fn! { 255 | // #[should_panic(expected = "Could not prove goal 0")] 256 | // math_fail, rules(), 257 | // "(+ x y)" => "(/ x y)" 258 | // } 259 | // 260 | // easter_egg::test_fn!
{math_simplify_add, rules(), "(+ x (+ x (+ x x)))" => "(* 4 x)" } 261 | // easter_egg::test_fn! {math_powers, rules(), "(* (pow 2 x) (pow 2 y))" => "(pow 2 (+ x y))"} 262 | // 263 | // easter_egg::test_fn! { 264 | // math_simplify_const, rules(), 265 | // "(+ 1 (- a (* (- 2 1) a)))" => "1" 266 | // } 267 | // 268 | // easter_egg::test_fn! { 269 | // math_simplify_root, rules(), 270 | // runner = Runner::default().with_node_limit(75_000), 271 | // r#" 272 | // (/ 1 273 | // (- (/ (+ 1 (sqrt five)) 274 | // 2) 275 | // (/ (- 1 (sqrt five)) 276 | // 2)))"# 277 | // => 278 | // "(/ 1 (sqrt five))" 279 | // } 280 | // 281 | // easter_egg::test_fn! { 282 | // math_simplify_factor, rules(), 283 | // "(* (+ x 3) (+ x 1))" 284 | // => 285 | // "(+ (+ (* x x) (* 4 x)) 3)" 286 | // } 287 | // 288 | // easter_egg::test_fn! {math_diff_same, rules(), "(d x x)" => "1"} 289 | // easter_egg::test_fn! {math_diff_different, rules(), "(d x y)" => "0"} 290 | // easter_egg::test_fn! {math_diff_simple1, rules(), "(d x (+ 1 (* 2 x)))" => "2"} 291 | // easter_egg::test_fn! {math_diff_simple2, rules(), "(d x (+ 1 (* y x)))" => "y"} 292 | // easter_egg::test_fn! {math_diff_ln, rules(), "(d x (ln x))" => "(/ 1 x)"} 293 | // 294 | // easter_egg::test_fn! { 295 | // diff_power_simple, rules(), 296 | // "(d x (pow x 3))" => "(* 3 (pow x 2))" 297 | // } 298 | // 299 | // easter_egg::test_fn! { 300 | // diff_power_harder, rules(), 301 | // runner = Runner::default() 302 | // .with_time_limit(std::time::Duration::from_secs(10)) 303 | // .with_iter_limit(60) 304 | // .with_node_limit(100_000) 305 | // // HACK this needs to "see" the end expression 306 | // .with_expr(&"(* x (- (* 3 x) 14))".parse().unwrap()), 307 | // "(d x (- (pow x 3) (* 7 (pow x 2))))" 308 | // => 309 | // "(* x (- (* 3 x) 14))" 310 | // } 311 | // 312 | // easter_egg::test_fn! { 313 | // integ_one, rules(), "(i 1 x)" => "x" 314 | // } 315 | // 316 | // easter_egg::test_fn!
{ 317 | // integ_sin, rules(), "(i (cos x) x)" => "(sin x)" 318 | // } 319 | // 320 | // easter_egg::test_fn! { 321 | // integ_x, rules(), "(i (pow x 1) x)" => "(/ (pow x 2) 2)" 322 | // } 323 | // 324 | // easter_egg::test_fn! { 325 | // integ_part1, rules(), "(i (* x (cos x)) x)" => "(+ (* x (sin x)) (cos x))" 326 | // } 327 | // 328 | // easter_egg::test_fn! { 329 | // integ_part2, rules(), 330 | // "(i (* (cos x) x) x)" => "(+ (* x (sin x)) (cos x))" 331 | // } 332 | // 333 | // easter_egg::test_fn! { 334 | // integ_part3, rules(), "(i (ln x) x)" => "(- (* x (ln x)) x)" 335 | // } 336 | -------------------------------------------------------------------------------- /src/macros.rs: -------------------------------------------------------------------------------- 1 | use num_traits::Float; 2 | use ordered_float::NotNan; 3 | use crate::Symbol; 4 | 5 | #[doc(hidden)] 6 | #[macro_export] 7 | macro_rules! replace_expr { 8 | ($_t:tt $sub:expr) => {$sub}; 9 | } 10 | /// Expands to a `u32` constant expression equal to the number of token trees passed in; `__define_language!` uses it to number variants when building `op_id`. 11 | #[doc(hidden)] 12 | #[macro_export] 13 | macro_rules! count { 14 | ($($tts:tt)*) => {0u32 $(+ $crate::replace_expr!($tts 1u32))*}; 15 | } 16 | 17 | /** A macro to easily create a [`Language`]. 18 | 19 | `define_language` derives `Debug`, `PartialEq`, `Eq`, `PartialOrd`, `Ord`, 20 | `Hash`, and `Clone` on the given `enum` so it can implement [`Language`]. 21 | The macro also implements [`Display`] and [`FromOp`] for the `enum` 22 | based on either the data of variants or the provided strings. 23 | 24 | The final variant **must have a trailing comma**; this is due to limitations in 25 | macro parsing. 26 | 27 | See [`LanguageChildren`] for acceptable types of children `Id`s. 28 | 29 | Note that you can always implement [`Language`] yourself by just not using this 30 | macro. 31 | 32 | Data variants may also carry children, written `Variant(Data, Children)`: the operator 33 | string is parsed into `Data` and the child ids are converted into `Children`.
34 | 35 | # Example 36 | 37 | The following macro invocation shows the accepted forms of variants: 38 | ``` 39 | # use easter_egg::*; 40 | define_language! { 41 | enum SimpleLanguage { 42 | // string variant with no children 43 | "pi" = Pi, 44 | 45 | // string variants with an array of child `Id`s (any static size) 46 | // any type that implements LanguageChildren may be used here 47 | "+" = Add([Id; 2]), 48 | "-" = Sub([Id; 2]), 49 | "*" = Mul([Id; 2]), 50 | 51 | // can also do a variable number of children in a boxed slice 52 | // this will only match if the lengths are the same 53 | "list" = List(Box<[Id]>), 54 | 55 | // string variants with a single child `Id` 56 | // note that this is distinct from `Sub`, even though it has the same 57 | // string, because it has a different number of children 58 | "-" = Neg(Id), 59 | 60 | // data variants with a single field 61 | // this field must implement `FromStr`, `Display` and `GetOp` 62 | Num(i32), 63 | // language items are parsed in order, and we want symbol to 64 | // be a fallback, so we put it last 65 | Symbol(Symbol), 66 | // This is the ultimate fallback, it will parse any operator (as a string) 67 | // and any number of children. 68 | // Note that if there were 0 children, the previous branch would have succeeded 69 | // Other(Symbol, Vec), 70 | } 71 | } 72 | ``` 73 | [`Display`]: std::fmt::Display 74 | **/ 75 | #[macro_export] 76 | macro_rules!
define_language { 77 | ($(#[$meta:meta])* $vis:vis enum $name:ident $variants:tt) => { 78 | $crate::__define_language!($(#[$meta])* $vis enum $name $variants -> {} {} {} {} {} {} {} {}); 79 | }; 80 | } 81 | /// Maps a data-variant payload to a `u32` that `__define_language!` shifts left by 16 and adds to the variant's `op_id`; the default (`0`) gives every value of the type the same `op_id`. 82 | pub trait GetOp { 83 | fn get_op(&self) -> u32 { 0 /* shrug */ } 84 | } 85 | 86 | impl GetOp for u32 { 87 | fn get_op(&self) -> u32 { 88 | *self 89 | } 90 | } 91 | 92 | impl GetOp for i32 { 93 | fn get_op(&self) -> u32 { 94 | self.unsigned_abs() // `abs()` overflows (panicking in debug builds) on `i32::MIN` 95 | } 96 | } 97 | 98 | impl GetOp for bool { 99 | fn get_op(&self) -> u32 { if *self { 1 } else { 0 } } 100 | } 101 | 102 | impl GetOp for Symbol { 103 | fn get_op(&self) -> u32 { 104 | let x: std::num::NonZeroU32 = (*self).into(); 105 | x.into() 106 | } 107 | } 108 | 109 | impl GetOp for NotNan { } 110 | 111 | #[doc(hidden)] 112 | #[macro_export] 113 | macro_rules! __define_language { 114 | ($(#[$meta:meta])* $vis:vis enum $name:ident {} -> 115 | $decl:tt $op_id:tt {$($matches:tt)*} $children:tt $children_mut:tt 116 | $display:tt {$($from_op:tt)*} $display_op:tt 117 | ) => { 118 | $(#[$meta])* 119 | #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] 120 | $vis enum $name $decl 121 | 122 | impl $crate::Language for $name { 123 | #[inline(always)] 124 | fn op_id(&self) -> OpId { 125 | match self $op_id 126 | } 127 | 128 | #[inline(always)] 129 | fn matches(&self, other: &Self) -> bool { 130 | self.op_id() == other.op_id() && 131 | match (self, other) { $($matches)* _ => false } 132 | } 133 | 134 | fn children(&self) -> &[$crate::Id] { match self $children } 135 | fn children_mut(&mut self) -> &mut [$crate::Id] { match self $children_mut } 136 | 137 | fn display_op(&self) -> &dyn ::std::fmt::Display { 138 | match self $display_op 139 | } 140 | } 141 | 142 | impl ::std::fmt::Display for $name { 143 | fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { 144 | // We need to pass `f` to the match expression for hygiene 145 | // reasons.
146 | match (self, f) $display 147 | } 148 | } 149 | 150 | impl $crate::FromOp for $name { 151 | type Error = $crate::FromOpError; 152 | 153 | fn from_op(op: &str, children: ::std::vec::Vec<$crate::Id>) -> ::std::result::Result { 154 | match (op, children) { 155 | $($from_op)* 156 | (op, children) => Err($crate::FromOpError::new(op, children)), 157 | } 158 | } 159 | } 160 | }; 161 | 162 | ($(#[$meta:meta])* $vis:vis enum $name:ident 163 | { 164 | $string:literal = $variant:ident, 165 | $($variants:tt)* 166 | } -> 167 | { $($decl:tt)* } { $($op_id:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* } 168 | { $($display:tt)* } { $($from_op:tt)* } { $($display_op:tt)* } 169 | ) => { 170 | $crate::__define_language!( 171 | $(#[$meta])* $vis enum $name 172 | { $($variants)* } -> 173 | { $($decl)* $variant, } 174 | { $($op_id)* $name::$variant => $crate::count!($($op_id)*), } 175 | { $($matches)* ($name::$variant, $name::$variant) => true, } 176 | { $($children)* $name::$variant => &[], } 177 | { $($children_mut)* $name::$variant => &mut [], } 178 | { $($display)* ($name::$variant, f) => f.write_str($string), } 179 | { $($from_op)* ($string, children) if children.is_empty() => Ok($name::$variant), } 180 | { $($display_op)* $name::$variant => &$string, } 181 | ); 182 | }; 183 | 184 | ($(#[$meta:meta])* $vis:vis enum $name:ident 185 | { 186 | $string:literal = $variant:ident ($ids:ty), 187 | $($variants:tt)* 188 | } -> 189 | { $($decl:tt)* } { $($op_id:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* } 190 | { $($display:tt)* } { $($from_op:tt)* } { $($display_op:tt)* } 191 | ) => { 192 | $crate::__define_language!( 193 | $(#[$meta])* $vis enum $name 194 | { $($variants)* } -> 195 | { $($decl)* $variant($ids), } 196 | { $($op_id)* #[allow(unused_variables)] $name::$variant(ids) => $crate::count!($($op_id)*), } 197 | { $($matches)* ($name::$variant(l), $name::$variant(r)) => $crate::LanguageChildren::len(l) == $crate::LanguageChildren::len(r), } 198
| { $($children)* $name::$variant(ids) => $crate::LanguageChildren::as_slice(ids), } 199 | { $($children_mut)* $name::$variant(ids) => $crate::LanguageChildren::as_mut_slice(ids), } 200 | { $($display)* ($name::$variant(..), f) => f.write_str($string), } 201 | { $($from_op)* (op, children) if op == $string && <$ids as $crate::LanguageChildren>::can_be_length(children.len()) => { 202 | let children = <$ids as $crate::LanguageChildren>::from_vec(children); 203 | Ok($name::$variant(children)) 204 | }, 205 | } 206 | { $($display_op)* $name::$variant(..) => &$string, } 207 | ); 208 | }; 209 | 210 | ($(#[$meta:meta])* $vis:vis enum $name:ident 211 | { 212 | $variant:ident ($data:ty), 213 | $($variants:tt)* 214 | } -> 215 | { $($decl:tt)* } { $($op_id:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* } 216 | { $($display:tt)* } { $($from_op:tt)* } { $($display_op:tt)* } 217 | ) => { 218 | $crate::__define_language!( 219 | $(#[$meta])* $vis enum $name 220 | { $($variants)* } -> 221 | { $($decl)* $variant($data), } 222 | { $($op_id)* $name::$variant(data) => $crate::count!($($op_id)*) + (($crate::macros::GetOp::get_op(data) << 16)), } 223 | { $($matches)* ($name::$variant(data1), $name::$variant(data2)) => data1 == data2, } 224 | { $($children)* $name::$variant(_data) => &[], } 225 | { $($children_mut)* $name::$variant(_data) => &mut [], } 226 | { $($display)* ($name::$variant(data), f) => ::std::fmt::Display::fmt(data, f), } 227 | { $($from_op)* (op, children) if op.parse::<$data>().is_ok() && children.is_empty() => Ok($name::$variant(op.parse().unwrap())), } 228 | { $($display_op)* $name::$variant(data) => data, } 229 | ); 230 | }; 231 | // Data variant with children, e.g. `Other(Symbol, Vec<Id>)`: the same eight accumulator groups as the other arms. 232 | ($(#[$meta:meta])* $vis:vis enum $name:ident 233 | { 234 | $variant:ident ($data:ty, $ids:ty), 235 | $($variants:tt)* 236 | } -> 237 | { $($decl:tt)* } { $($op_id:tt)* } { $($matches:tt)* } { $($children:tt)* } { $($children_mut:tt)* } 238 | { $($display:tt)* } { $($from_op:tt)* } { $($display_op:tt)* } 239 | ) => { 240 |
$crate::__define_language!( 241 | $(#[$meta])* $vis enum $name 242 | { $($variants)* } -> 243 | { $($decl)* $variant($data, $ids), } { $($op_id)* $name::$variant(data, _) => $crate::count!($($op_id)*) + ($crate::macros::GetOp::get_op(data) << 16), } 244 | { $($matches)* ($name::$variant(d1, l), $name::$variant(d2, r)) => d1 == d2 && $crate::LanguageChildren::len(l) == $crate::LanguageChildren::len(r), } 245 | { $($children)* $name::$variant(_, ids) => $crate::LanguageChildren::as_slice(ids), } 246 | { $($children_mut)* $name::$variant(_, ids) => $crate::LanguageChildren::as_mut_slice(ids), } 247 | { $($display)* ($name::$variant(data, _), f) => ::std::fmt::Display::fmt(data, f), } 248 | { $($from_op)* (op, children) if op.parse::<$data>().is_ok() && <$ids as $crate::LanguageChildren>::can_be_length(children.len()) => { 249 | let data = op.parse::<$data>().unwrap(); 250 | let children = <$ids as $crate::LanguageChildren>::from_vec(children); 251 | Ok($name::$variant(data, children)) 252 | }, 253 | } 254 | { $($display_op)* $name::$variant(data, _) => data, } 255 | ); 256 | }; 257 | } 258 | 259 | 260 | /** A macro to easily make [`Rewrite`]s. 261 | 262 | The `rewrite!` macro greatly simplifies creating simple, purely 263 | syntactic rewrites while also allowing more complex ones. 264 | 265 | This panics if [`Rewrite::new`](Rewrite::new()) fails. 266 | 267 | The simplest form `rewrite!(a; b => c)` creates a [`Rewrite`] 268 | with name `a`, [`Searcher`] `b`, and [`Applier`] `c`. 269 | Note that in the `b` and `c` position, the macro only accepts a single 270 | token tree (see the [macros reference][macro] for more info). 271 | In short, that means you should pass in an identifier, literal, or 272 | something surrounded by parentheses or braces. 273 | 274 | If you pass in a literal to the `b` or `c` position, the macro will 275 | try to parse it as a [`Pattern`] which implements both [`Searcher`] 276 | and [`Applier`].
277 | 278 | This now uses negative patterns: 279 | The macro also accepts any number of `if ` forms at the end, 280 | where the given expression should implement [`Condition`]. 281 | For each of these, the macro will wrap the given applier in a 282 | [`ConditionalApplier`] with the given condition, with the first condition being 283 | the outermost, and the last condition being the innermost. 284 | 285 | # Example 286 | ``` 287 | # use easter_egg::*; 288 | use std::borrow::Cow; 289 | use std::sync::Arc; 290 | define_language! { 291 | enum SimpleLanguage { 292 | Num(i32), 293 | "+" = Add([Id; 2]), 294 | "-" = Sub([Id; 2]), 295 | "*" = Mul([Id; 2]), 296 | "/" = Div([Id; 2]), 297 | } 298 | } 299 | 300 | type EGraph = easter_egg::EGraph; 301 | 302 | let mut rules: Vec> = vec![ 303 | rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), 304 | rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"), 305 | 306 | rewrite!("mul-0"; "(* ?a 0)" => "0"), 307 | 308 | rewrite!("silly"; "(* ?a 1)" => { MySillyApplier("foo") }), 309 | 310 | multi_rewrite!("something_conditional"; 311 | "?c = (/ ?a ?b), ?b != 0" => "?c = (* ?a (/ 1 ?b))"), 312 | ]; 313 | 314 | // rewrite!
supports bidirectional rules too 315 | // it returns a Vec of length 2, so you need to concat 316 | rules.extend(vec![ 317 | rewrite!("add-0"; "(+ ?a 0)" <=> "?a"), 318 | rewrite!("mul-1"; "(* ?a 1)" <=> "?a"), 319 | ].concat()); 320 | 321 | #[derive(Debug)] 322 | struct MySillyApplier(&'static str); 323 | impl std::fmt::Display for MySillyApplier { 324 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 325 | self.0.fmt(f) 326 | } 327 | } 328 | 329 | impl Applier for MySillyApplier { 330 | fn apply_one(&self, _: &mut EGraph, _: Id, _: &Subst) -> Vec { 331 | panic!() 332 | } 333 | } 334 | 335 | // This returns a function that implements Condition 336 | fn is_not_zero(var: &'static str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { 337 | let var = var.parse().unwrap(); 338 | let zero = SimpleLanguage::Num(0); 339 | move |egraph, _, subst| !egraph[subst[var]].nodes.contains(&zero) 340 | } 341 | ``` 342 | 343 | [macro]: https://doc.rust-lang.org/stable/reference/macros-by-example.html#metavariables 344 | **/ 345 | #[macro_export] 346 | macro_rules! rewrite { 347 | ( 348 | $name:expr; 349 | $lhs:tt => $rhs:tt 350 | $(if $cond:expr)* 351 | ) => {{ 352 | let searcher = $crate::__rewrite!(@parse Pattern $lhs); 353 | let core_applier = $crate::__rewrite!(@parse Pattern $rhs); 354 | let applier = $crate::__rewrite!(@applier core_applier; $($cond,)*); 355 | $crate::Rewrite::new($name, searcher, applier).unwrap() 356 | }}; 357 | ( 358 | $name:expr; 359 | $lhs:tt <=> $rhs:tt 360 | $(if $cond:expr)* 361 | ) => {{ 362 | let name = $name; 363 | let name2 = String::from(name.clone()) + "-rev"; 364 | vec![ 365 | $crate::rewrite!(name; $lhs => $rhs $(if $cond)*), 366 | $crate::rewrite!(name2; $rhs => $lhs $(if $cond)*) 367 | ] 368 | }}; 369 | } 370 | 371 | /** A macro to easily make [`Rewrite`]s using [`MultiPattern`]s. 372 | 373 | Similar to the [`rewrite!`] macro, 374 | this macro uses the form `multi_rewrite!(name; multipattern => multipattern)`. 
375 | String literals will be parsed a [`MultiPattern`]s. 376 | 377 | **/ 378 | #[macro_export] 379 | macro_rules! multi_rewrite { 380 | // limited multipattern support 381 | ( 382 | $name:expr; 383 | $lhs:tt => $rhs:tt 384 | ) => {{ 385 | let searcher = $crate::__rewrite!(@parse MultiPattern $lhs); 386 | let applier = $crate::__rewrite!(@parse MultiPattern $rhs); 387 | $crate::Rewrite::new($name.to_string(), searcher, applier).unwrap() 388 | }}; 389 | } 390 | 391 | #[doc(hidden)] 392 | #[macro_export] 393 | macro_rules! __rewrite { 394 | (@parse $t:ident $rhs:literal) => { 395 | $rhs.parse::<$crate::$t<_>>().unwrap() 396 | }; 397 | (@parse $t:ident $rhs:expr) => { $rhs }; 398 | (@applier $applier:expr;) => { $applier }; 399 | (@applier $applier:expr; $cond:expr, $($conds:expr,)*) => { 400 | $crate::ConditionalApplier { 401 | condition: $cond, 402 | applier: $crate::__rewrite!(@applier $applier; $($conds,)*), 403 | phantom: Default::default(), 404 | } 405 | }; 406 | } 407 | 408 | #[cfg(test)] 409 | mod tests { 410 | 411 | use crate::*; 412 | 413 | define_language! 
{ 414 | enum Simple { 415 | "+" = Add([Id; 2]), 416 | "-" = Sub([Id; 2]), 417 | "*" = Mul([Id; 2]), 418 | "-" = Neg(Id), 419 | "list" = List(Box<[Id]>), 420 | "pi" = Pi, 421 | Int(i32), 422 | Var(Symbol), 423 | } 424 | } 425 | 426 | #[test] 427 | fn modify_children() { 428 | let mut add = Simple::Add([0.into(), 0.into()]); 429 | add.for_each_mut(|id| *id = 1.into()); 430 | assert_eq!(add, Simple::Add([1.into(), 1.into()])); 431 | } 432 | 433 | #[test] 434 | fn some_rewrites() { 435 | let mut rws: Vec> = vec![ 436 | // here it should parse the rhs 437 | rewrite!("rule"; "cons" => "f"), 438 | // here it should just accept the rhs without trying to parse 439 | rewrite!("rule"; "f" => { "pat".parse::>().unwrap() }), 440 | ]; 441 | rws.extend(rewrite!("two-way"; "foo" <=> "bar")); 442 | } 443 | 444 | #[test] 445 | #[should_panic(expected = "refers to unbound var ?x")] 446 | fn rewrite_simple_panic() { 447 | let _: Rewrite = rewrite!("bad"; "?a" => "?x"); 448 | } 449 | } 450 | -------------------------------------------------------------------------------- /src/rewrite.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | use std::{any::Any, sync::Arc}; 3 | 4 | use crate::{Analysis, EGraph, Id, Language, Pattern, SearchMatches, Subst, Var, ColorId}; 5 | use std::fmt::Formatter; 6 | use invariants::iassert; 7 | 8 | /// A rewrite that searches for the lefthand side and applies the righthand side. 9 | /// 10 | /// The [`rewrite!`] is the easiest way to create rewrites. 11 | /// 12 | /// A [`Rewrite`] consists principally of a [`Searcher`] (the lefthand 13 | /// side) and an [`Applier`] (the righthand side). 14 | /// It additionally stores a name used to refer to the rewrite and a 15 | /// long name used for debugging. 
16 | /// 17 | /// [`rewrite!`]: macro.rewrite.html 18 | /// [`Searcher`]: trait.Searcher.html 19 | /// [`Applier`]: trait.Applier.html 20 | /// [`Condition`]: trait.Condition.html 21 | /// [`Rewrite`]: struct.Rewrite.html 22 | /// [`Pattern`]: struct.Pattern.html 23 | #[derive(Clone)] 24 | #[non_exhaustive] 25 | pub struct Rewrite> { 26 | /// The name of the rewrite. 27 | pub name: String, 28 | /// The searcher (left-hand side) of the rewrite. 29 | pub searcher: Arc + Send + Sync>, 30 | /// The applier (right-hand side) of the rewrite. 31 | pub applier: Arc + Send + Sync>, 32 | } 33 | 34 | impl fmt::Debug for Rewrite 35 | where 36 | L: Language + 'static, 37 | N: Analysis + 'static, 38 | { 39 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 40 | let mut d = f.debug_struct("Rewrite"); 41 | d.field("name", &self.name); 42 | 43 | if let Some(pat) = ::downcast_ref::>(&self.searcher) { 44 | d.field("searcher", &(pat)); 45 | } else { 46 | d.field("searcher", &"<< searcher >>"); 47 | } 48 | 49 | if let Some(pat) = ::downcast_ref::>(&self.applier) { 50 | d.field("applier", &(pat)); 51 | } else { 52 | d.field("applier", &"<< applier >>"); 53 | } 54 | 55 | d.finish() 56 | } 57 | } 58 | 59 | impl> Rewrite { 60 | /// Returns the name of the rewrite. 61 | pub fn name(&self) -> &str { 62 | &self.name 63 | } 64 | } 65 | 66 | impl + 'static> Rewrite { 67 | /// Create a new [`Rewrite`]. You typically want to use the 68 | /// [`rewrite!`] macro instead. 
69 | /// 70 | /// [`Rewrite`]: struct.Rewrite.html 71 | /// [`rewrite!`]: macro.rewrite.html 72 | pub fn new( 73 | name: impl Into, 74 | searcher: impl Searcher + 'static + Send + Sync, 75 | applier: impl Applier + 'static + Send + Sync, 76 | ) -> Result { 77 | let name = name.into(); 78 | let searcher = Arc::new(searcher); 79 | let applier = Arc::new(applier); 80 | 81 | let bound_vars = searcher.vars(); 82 | for v in applier.vars() { 83 | if !bound_vars.contains(&v) { 84 | return Err(format!("Rewrite {} refers to unbound var {}", name, v)); 85 | } 86 | } 87 | 88 | Ok(Self { 89 | name, 90 | searcher, 91 | applier, 92 | }) 93 | } 94 | 95 | /// Call [`search`] on the [`Searcher`]. 96 | /// 97 | /// [`Searcher`]: trait.Searcher.html 98 | /// [`search`]: trait.Searcher.html#method.search 99 | pub fn search(&self, egraph: &EGraph) -> Option { 100 | self.searcher.search(egraph) 101 | } 102 | 103 | /// Call [`search_with_limit`] on the [`Searcher`]. 104 | /// 105 | /// [`search_with_limit`]: Searcher::search_with_limit() 106 | pub fn search_with_limit(&self, egraph: &EGraph, limit: usize) -> Option { 107 | self.searcher.search_with_limit(egraph, limit) 108 | } 109 | 110 | /// Call [`apply_matches`] on the [`Applier`]. 111 | /// 112 | /// [`Applier`]: trait.Applier.html 113 | /// [`apply_matches`]: trait.Applier.html#method.apply_matches 114 | pub fn apply(&self, egraph: &mut EGraph, matches: &Option) -> Vec { 115 | self.applier.apply_matches(egraph, matches) 116 | } 117 | 118 | /// This `run` is for testing use only. 
You should use things 119 | /// from the `easter_egg::run` module 120 | #[cfg(test)] 121 | pub(crate) fn run(&self, egraph: &mut EGraph) -> Vec { 122 | let start = instant::Instant::now(); 123 | 124 | let matches = self.search(egraph); 125 | log::debug!("Found rewrite {} {} times", self.name, matches.as_ref().map_or(0, |m| m.len())); 126 | 127 | let ids = self.apply(egraph, &matches); 128 | let elapsed = start.elapsed(); 129 | log::debug!( 130 | "Applied rewrite {} {} times in {}.{:03}", 131 | self.name, 132 | ids.len(), 133 | elapsed.as_secs(), 134 | elapsed.subsec_millis() 135 | ); 136 | 137 | egraph.rebuild(); 138 | ids 139 | } 140 | } 141 | 142 | /// Searches the given list of e-classes with a limit. 143 | pub(crate) fn search_eclasses_with_limit( 144 | searcher: &S, 145 | egraph: &EGraph, 146 | eclasses: I, 147 | mut limit: usize, 148 | ) -> Option 149 | where 150 | L: Language, 151 | N: Analysis, 152 | S: Searcher + ?Sized, 153 | I: IntoIterator, 154 | { 155 | let mut ms = vec![]; 156 | for eclass in eclasses { 157 | if limit == 0 { 158 | break; 159 | } 160 | match searcher.search_eclass_with_limit(egraph, eclass, limit) { 161 | None => continue, 162 | Some(m) => { 163 | let len = m.total_substs(); 164 | assert!(len <= limit); 165 | limit -= len; 166 | ms.push(m); 167 | } 168 | } 169 | } 170 | SearchMatches::merge_matches(ms) 171 | } 172 | 173 | impl> std::fmt::Display for Rewrite { 174 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 175 | write!(f, "Rewrite({}, {}, {})", self.name, self.searcher, self.applier) 176 | } 177 | } 178 | 179 | /// The lefthand side of a [`Rewrite`]. 180 | /// 181 | /// A [`Searcher`] is something that can search the egraph and find 182 | /// matching substititions. 183 | /// Right now the only significant [`Searcher`] is [`Pattern`]. 
184 | /// 185 | /// [`Rewrite`]: struct.Rewrite.html 186 | /// [`Searcher`]: trait.Searcher.html 187 | /// [`Pattern`]: struct.Pattern.html 188 | pub trait Searcher: std::fmt::Display 189 | where 190 | L: Language, 191 | N: Analysis, 192 | { 193 | /// Search one eclass, returning None if no matches can be found. 194 | /// This should not return a SearchMatches with no substs. 195 | fn search_eclass(&self, egraph: &EGraph, eclass: Id) -> Option { 196 | self.search_eclass_with_limit(egraph, eclass, usize::MAX) 197 | } 198 | 199 | /// Similar to [`search_eclass`], but return at most `limit` many matches. 200 | /// 201 | /// Implementation of [`Searcher`] should implement 202 | /// [`search_eclass_with_limit`]. 203 | /// 204 | /// [`search_eclass`]: Searcher::search_eclass 205 | /// [`search_eclass_with_limit`]: Searcher::search_eclass_with_limit 206 | fn search_eclass_with_limit( 207 | &self, 208 | egraph: &EGraph, 209 | eclass: Id, 210 | limit: usize, 211 | ) -> Option; 212 | 213 | /// Search the whole [`EGraph`], returning a list of all the 214 | /// [`SearchMatches`] where something was found. 215 | /// This just calls [`search_eclass`] on each eclass. 216 | /// 217 | /// [`EGraph`]: struct.EGraph.html 218 | /// [`search_eclass`]: trait.Searcher.html#tymethod.search_eclass 219 | /// [`SearchMatches`]: struct.SearchMatches.html 220 | fn search(&self, egraph: &EGraph) -> Option { 221 | self.search_with_limit(egraph, usize::MAX) 222 | } 223 | 224 | /// Similar to [`search`], but return at most `limit` many matches. 
225 | /// 226 | /// [`search`]: Searcher::search 227 | fn search_with_limit(&self, egraph: &EGraph, limit: usize) -> Option { 228 | search_eclasses_with_limit(self, egraph, egraph.classes().map(|e| e.id), limit) 229 | } 230 | 231 | /// Search one eclass with starting color 232 | /// This should also check all color-equal classes 233 | fn colored_search_eclass(&self, egraph: &EGraph, eclass: Id, color: ColorId) -> Option { 234 | self.colored_search_eclass_with_limit(egraph, eclass, color, usize::MAX) 235 | } 236 | 237 | /// Search one eclass with starting color 238 | /// This should also check all color-equal classes 239 | fn colored_search_eclass_with_limit(&self, egraph: &EGraph, eclass: Id, color: ColorId, limit: usize) -> Option; 240 | 241 | /// Returns a list of the variables bound by this Searcher 242 | fn vars(&self) -> Vec; 243 | 244 | /// Returns the number of matches in the e-graph 245 | fn n_matches(&self, egraph: &EGraph) -> usize { 246 | self.search(egraph).map_or(0, |ms| ms.iter().map(|m| m.substs.len()).sum()) 247 | } 248 | } 249 | 250 | /// The righthand side of a [`Rewrite`]. 251 | /// 252 | /// An [`Applier`] is anything that can do something with a 253 | /// substitition ([`Subst`]). This allows you to implement rewrites 254 | /// that determine when and how to respond to a match using custom 255 | /// logic, including access to the [`Analysis`] data of an [`EClass`]. 256 | /// 257 | /// Notably, [`Pattern`] implements [`Applier`], which suffices in 258 | /// most cases. 259 | /// 260 | /// # Example 261 | /// ``` 262 | /// use easter_egg::{rewrite as rw, *}; 263 | /// use std::fmt::{Display, Formatter}; 264 | /// 265 | /// define_language! 
{ 266 | /// enum Math { 267 | /// Num(i32), 268 | /// "+" = Add([Id; 2]), 269 | /// "*" = Mul([Id; 2]), 270 | /// Symbol(Symbol), 271 | /// } 272 | /// } 273 | /// 274 | /// type EGraph = easter_egg::EGraph; 275 | /// 276 | /// // Our metadata in this case will be size of the smallest 277 | /// // represented expression in the eclass. 278 | /// #[derive(Default, Clone)] 279 | /// struct MinSize; 280 | /// impl Analysis for MinSize { 281 | /// type Data = usize; 282 | /// fn merge(&self, to: &mut Self::Data, from: Self::Data) -> bool { 283 | /// merge_if_different(to, (*to).min(from)) 284 | /// } 285 | /// fn make(egraph: &EGraph, enode: &Math) -> Self::Data { 286 | /// let get_size = |i: Id| egraph[i].data; 287 | /// AstSize.cost(enode, get_size) 288 | /// } 289 | /// } 290 | /// 291 | /// let rules = &[ 292 | /// rw!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), 293 | /// rw!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"), 294 | /// rw!("add-0"; "(+ ?a 0)" => "?a"), 295 | /// rw!("mul-0"; "(* ?a 0)" => "0"), 296 | /// rw!("mul-1"; "(* ?a 1)" => "?a"), 297 | /// // the rewrite macro parses the rhs as a single token tree, so 298 | /// // we wrap it in braces (parens work too). 299 | /// rw!("funky"; "(+ ?a (* ?b ?c))" => { Funky { 300 | /// a: "?a".parse().unwrap(), 301 | /// b: "?b".parse().unwrap(), 302 | /// c: "?c".parse().unwrap(), 303 | /// }}), 304 | /// ]; 305 | /// 306 | /// #[derive(Debug, Clone, PartialEq, Eq)] 307 | /// struct Funky { 308 | /// a: Var, 309 | /// b: Var, 310 | /// c: Var, 311 | /// } 312 | /// 313 | /// impl std::fmt::Display for Funky { 314 | /// fn fmt(&self,f: &mut Formatter<'_>) -> std::fmt::Result { 315 | /// write!(f, "Funky") 316 | /// } 317 | /// } 318 | /// 319 | /// impl Applier for Funky { 320 | /// fn apply_one(&self, egraph: &mut EGraph, matched_id: Id, subst: &Subst) -> Vec { 321 | /// let a: Id = subst[self.a]; 322 | /// // In a custom Applier, you can inspect the analysis data, 323 | /// // which is powerful combination! 
324 | /// let size_of_a = egraph[a].data; 325 | /// if size_of_a > 50 { 326 | /// println!("Too big! Not doing anything"); 327 | /// vec![] 328 | /// } else { 329 | /// // we're going to manually add: 330 | /// // (+ (+ ?a 0) (* (+ ?b 0) (+ ?c 0))) 331 | /// // to be unified with the original: 332 | /// // (+ ?a (* ?b ?c )) 333 | /// let b: Id = subst[self.b]; 334 | /// let c: Id = subst[self.c]; 335 | /// let zero = egraph.add(Math::Num(0)); 336 | /// let a0 = egraph.add(Math::Add([a, zero])); 337 | /// let b0 = egraph.add(Math::Add([b, zero])); 338 | /// let c0 = egraph.add(Math::Add([c, zero])); 339 | /// let b0c0 = egraph.add(Math::Mul([b0, c0])); 340 | /// let a0b0c0 = egraph.add(Math::Add([a0, b0c0])); 341 | /// // NOTE: we just return the id according to what we 342 | /// // want unified with matched_id. The `apply_matches` 343 | /// // method actually does the union, _not_ `apply_one`. 344 | /// vec![a0b0c0] 345 | /// } 346 | /// } 347 | /// } 348 | /// 349 | /// let start = "(+ x (* y z))".parse().unwrap(); 350 | /// let mut runner = Runner::default().with_expr(&start); 351 | /// runner.egraph.rebuild(); 352 | /// runner.run(rules); 353 | /// ``` 354 | /// [`Pattern`]: struct.Pattern.html 355 | /// [`EClass`]: struct.EClass.html 356 | /// [`Rewrite`]: struct.Rewrite.html 357 | /// [`Subst`]: struct.Subst.html 358 | /// [`Applier`]: trait.Applier.html 359 | /// [`Condition`]: trait.Condition.html 360 | /// [`Analysis`]: trait.Analysis.html 361 | pub trait Applier: std::fmt::Display 362 | where 363 | L: Language + 'static, 364 | N: Analysis +'static, 365 | { 366 | /// Apply many substititions. 367 | /// 368 | /// This method should call [`apply_one`] for each match and then 369 | /// unify the results with the matched eclass. 370 | /// This should return a list of [`Id`]s where the union actually 371 | /// did something. 372 | /// 373 | /// The default implementation does this and should suffice for 374 | /// most use cases. 
375 | /// 376 | /// [`Id`]: struct.Id.html 377 | /// [`apply_one`]: trait.Applier.html#method.apply_one 378 | fn apply_matches(&self, egraph: &mut EGraph, matches: &Option) -> Vec { 379 | let mut added = vec![]; 380 | if let Some(mat) = matches { 381 | for (eclass, substs) in mat.matches.iter() { 382 | for subst in substs { 383 | let ids = self 384 | .apply_one(egraph, *eclass, subst) 385 | .into_iter() 386 | .filter_map(|id| { 387 | let (to, did_something) = 388 | egraph.opt_colored_union(subst.color(), id, *eclass); 389 | if did_something { 390 | Some(to) 391 | } else { 392 | None 393 | } 394 | }); 395 | added.extend(ids) 396 | } 397 | } 398 | } 399 | added 400 | } 401 | 402 | /// Apply a single substitition. 403 | /// 404 | /// An [`Applier`] should only add things to the egraph here, 405 | /// _not_ union them with the id `eclass`. 406 | /// That is the responsibility of the [`apply_matches`] method. 407 | /// The `eclass` parameter allows the implementer to inspect the 408 | /// eclass where the match was found if they need to. 409 | /// 410 | /// This should return a list of [`Id`]s of things you'd like to 411 | /// be unioned with `eclass`. There can be zero, one, or many. 412 | /// 413 | /// [`Applier`]: trait.Applier.html 414 | /// [`Id`]: struct.Id.html 415 | /// [`apply_matches`]: trait.Applier.html#method.apply_matches 416 | fn apply_one(&self, egraph: &mut EGraph, eclass: Id, subst: &Subst) -> Vec; 417 | 418 | /// Returns a list of variables that this Applier assumes are bound. 419 | /// 420 | /// `egg` will check that the corresponding `Searcher` binds those 421 | /// variables. 422 | /// By default this return an empty `Vec`, which basically turns off the 423 | /// checking. 
424 | fn vars(&self) -> Vec { 425 | vec![] 426 | } 427 | } 428 | 429 | #[cfg(test)] 430 | mod tests { 431 | use crate::{SymbolLang as S, *}; 432 | use std::str::FromStr; 433 | use std::fmt::Formatter; 434 | 435 | type EGraph = crate::EGraph; 436 | 437 | #[test] 438 | fn conditional_rewrite() { 439 | crate::init_logger(); 440 | let mut egraph = EGraph::default(); 441 | 442 | let x = egraph.add(S::leaf("x")); 443 | let y = egraph.add(S::leaf("2")); 444 | let _mul = egraph.add(S::new("*", vec![x, y])); 445 | 446 | let true_id = egraph.add(S::leaf("TRUE")); 447 | 448 | let mul_to_shift = multi_rewrite!( 449 | "mul_to_shift"; 450 | "?x = (* ?a ?b), ?y = (is-power2 ?b), ?y = TRUE" => "?x =(>> ?a (log2 ?b))" 451 | ); 452 | 453 | println!("rewrite shouldn't do anything yet"); 454 | egraph.rebuild(); 455 | let apps = mul_to_shift.run(&mut egraph); 456 | assert!(apps.is_empty()); 457 | 458 | println!("Add the needed equality"); 459 | let two_ispow2 = egraph.add(S::new("is-power2", vec![y])); 460 | egraph.union(two_ispow2, true_id); 461 | 462 | println!("Should fire now"); 463 | egraph.rebuild(); 464 | let apps = mul_to_shift.run(&mut egraph); 465 | // Can't check ids because multipattern 466 | assert!(apps.len() > 0); 467 | } 468 | 469 | #[test] 470 | fn fn_rewrite() { 471 | crate::init_logger(); 472 | let mut egraph = EGraph::default(); 473 | 474 | let start = RecExpr::from_str("(+ x y)").unwrap(); 475 | let goal = RecExpr::from_str("xy").unwrap(); 476 | 477 | let root = egraph.add_expr(&start); 478 | 479 | fn get(egraph: &EGraph, id: Id) -> Symbol { 480 | egraph[id].nodes[0].op 481 | } 482 | 483 | #[derive(Debug)] 484 | struct Appender; 485 | impl std::fmt::Display for Appender { 486 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 487 | write!(f, "{:?}", self) 488 | } 489 | } 490 | 491 | impl Applier for Appender { 492 | fn apply_one(&self, egraph: &mut EGraph, _eclass: Id, subst: &Subst) -> Vec { 493 | let a: Var = "?a".parse().unwrap(); 494 | let b: Var = 
"?b".parse().unwrap(); 495 | let a = get(&egraph, subst[a]); 496 | let b = get(&egraph, subst[b]); 497 | let s = format!("{}{}", a, b); 498 | vec![egraph.add(S::leaf(&s))] 499 | } 500 | } 501 | 502 | let fold_add = rewrite!( 503 | "fold_add"; "(+ ?a ?b)" => { Appender } 504 | ); 505 | 506 | egraph.rebuild(); 507 | fold_add.run(&mut egraph); 508 | assert_eq!(egraph.equivs(&start, &goal), vec![egraph.find(root)]); 509 | } 510 | } 511 | --------------------------------------------------------------------------------