├── .gitignore ├── .travis.yml ├── Cargo.toml ├── LICENSE ├── README.md ├── build.rs ├── rustfmt.toml ├── src ├── constraints │ ├── mod.rs │ ├── numerical.rs │ ├── symbolic.rs │ └── tests.rs ├── core │ ├── mod.rs │ └── tests.rs ├── datum │ ├── mod.rs │ └── tests.rs ├── infer │ ├── mod.rs │ └── tests.rs ├── lib.rs ├── pedigree │ ├── mod.rs │ └── tests.rs ├── planner │ ├── mod.rs │ └── tests.rs ├── rule │ ├── mod.rs │ └── tests.rs ├── test_utils.rs └── utils.rs └── tests └── skeptic.rs /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here http://doc.crates.io/guide.html#cargotoml-vs-cargolock 7 | Cargo.lock 8 | *.bk 9 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: rust 3 | rust: 4 | - nightly 5 | - beta 6 | - stable 7 | matrix: 8 | allow_failures: 9 | - rust: nightly 10 | before_script: 11 | - | 12 | pip install 'travis-cargo<0.2' --user && 13 | export PATH=$HOME/.local/bin:$PATH 14 | script: 15 | - | 16 | travis-cargo build && 17 | travis-cargo test && 18 | travis-cargo bench && 19 | travis-cargo --only stable doc 20 | after_success: 21 | - travis-cargo --only stable doc-upload 22 | addons: 23 | apt: 24 | packages: 25 | - libcurl4-openssl-dev 26 | - libelf-dev 27 | - libdw-dev 28 | notifications: 29 | email: 30 | on_success: never 31 | env: 32 | global: 33 | - secure: 
rkoyRxHD5DEgkFjMvuNg/7Z3iNTD0UZOgZzaNOzAy5mwAIsAztMNcltC2LogtFkS8oFtkn6Hvw898yAWJhRUu6Z3IP4U55YJxsaLVYXve8ULvoX3T3hHvGWt55MLJeDW2M2CPmQgGvNQEtLpJLHFcZaQ8ZzUD+UUUnODI94Hvrcc1x5lfOuinXlyJPV9HGqGqTPxTgSrfq6UIWkn4ErtQjLgdd7v3dz5L1pOhLZL8W7rrZCdRJw4cHU/upknDOe3Ki32f25pF5WlJqgtc5vly/67QBjJZnepQyULWRFwgitefC2Pp3ZnQhz4YKYOHALn9BVq+Q8uZV1I1mKK6V8H+AfOiPjugC96yscJt8joaLz3uVbL4PIWc2VDUS+x8w0xWlYtAYKRvMZHt1VgEw+afja8oHXzPxqUvCSKXQ4/WUFVcM75I/+m7xP8B4aG7ybTSM4Lh4WnKDVjqmoZhO5fEhJJPtQfDcyi8XBpp7PArD6vFuVcrrSCO/zC2MkJUL7QfTAIAm/Tu99Dt+i1eqDvHcoztLULgbi3ne/W2BK7MYu5/YUWhygKBLVjqVeAR/arPEkH4J65WYyqCL8UoEHqM+e+ZAeORmRiOcwe/SERETqz7wH4HReiYRy07zwQjRXmrsHkM/+G7LrZf1e9unX1rq0KUe1ecySvhN2ksc3Vlts= 34 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Josh Marlow "] 3 | build = "build.rs" 4 | categories = ["algorithms", ] 5 | description = "Collection of classic AI algorithms with convenient interfaces" 6 | documentation = "https://joshmarlow.github.io/ai_kit-rs/" 7 | keywords = ["ai", "inference", "planning", "constraint", "backtrack"] 8 | license = "MIT" 9 | name = "ai_kit" 10 | repository = "https://github.com/joshmarlow/ai_kit-rs.git" 11 | version = "0.1.1" 12 | 13 | [badges] 14 | travis-ci = { repository = "joshmarlow/ai_kit-rs" } 15 | 16 | [dependencies] 17 | itertools = "0.5.1" 18 | serde = "0.9.5" 19 | serde_derive = "0.9.5" 20 | serde_json = "0.9" 21 | uuid = { version = "0.4", features = ["v4"] } 22 | 23 | [build-dependencies] 24 | skeptic = "0.13.2" 25 | 26 | [dev-dependencies] 27 | skeptic = "0.13.2" 28 | 29 | [features] 30 | default = ["with-planner", "with-forward-inference", "with-datum", "with-rule", "with-pedigree", "with-constraint"] 31 | 32 | with-planner = ["with-constraint"] 33 | with-forward-inference = ["with-planner"] 34 | with-datum = [] 35 | with-rule = ["with-constraint"] 36 | with-constraint = [] 37 | 
with-pedigree = [] 38 | unstable = [] 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Josh Marlow 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/joshmarlow/ai_kit-rs.svg?branch=master)](https://travis-ci.org/joshmarlow/ai_kit-rs) 2 | 3 | AI_Kit 4 | ====== 5 | 6 | AI_Kit aims to be a single dependency for various classic AI algorithms. 7 | 8 | Core project goals are: 9 | 10 | * convenient and ergonomic interfaces to various algorithms by building around traits. 
11 | * only build what you need through the use of feature flags 12 | * performance 13 | * easy to understand implementations 14 | 15 | All of the algorithms (documented below) operate on several core traits, `BindingsValue`, `Unify`, `Operation`. 16 | 17 | `ai_kit` provides [optional data structures](#default-trait-implementations) that implement these traits, allowing all algorithms to be usable out of the box - see [Datum](#datum) and [Rule](#rule). 18 | Quick examples are provided first, followed by more in-depth documentation. 19 | 20 | Installation 21 | ============ 22 | 23 | You can use this library by adding the following lines to your Cargo.toml file: 24 | 25 | ``` 26 | [dependencies] 27 | ai_kit = "0.1.0" 28 | ``` 29 | 30 | and adding `extern crate ai_kit` to your crate root. 31 | 32 | Documentation 33 | ============= 34 | 35 | This README provides an introduction. 36 | 37 | API Documentation is available [here](http://joshmarlow.github.io/ai_kit-rs/) 38 | 39 | Core Concepts 40 | ============= 41 | 42 | There are three traits and one structure to understand when using this library: 43 | 44 | `Bindings` - similar to a key/value lookup, but with utilities for ensuring that two (or more) keys have the same value. 45 | 46 | `BindingsValue` - a trait allowing a data structure to be used by the `Bindings` data structure. 47 | 48 | `Unify` - a trait for data structures that can be unified with another of the same type. 49 | 50 | `Operation` - a trait for mapping some number of `Unify` instances to some number of other `Unify`s. 51 | This is used for implementing [Forward](#forward-inference) and [Backward](#backward-inference) inferencing. 52 | 53 | 54 | Unify 55 | ----- 56 | 57 | Two data structures can be unified if all of their components are the same or at least one of the fields that differ is a variable. The [Datum](#datum) structure implements `Unify`. [Here](#datum-implements-unify) is an example of unifying datums.
58 | 59 | Bindings 60 | -------- 61 | 62 | When successful, the unification process returns a `Bindings` structure, which maps variable names to 63 | their values (when known). It also allows for specifying that two variables are equivalent; in that case, 64 | when the value for one variable is found, it is considered the value for the other. 65 | 66 | Anything that implements `ai_kit::core::BindingsValue` can be used with `Bindings`; [Datum](#datum) implements `BindingsValue`: 67 | 68 | ```rust 69 | // Example of using the ai_kit::datum::Datum for variable bindings. 70 | 71 | extern crate ai_kit; 72 | 73 | use ai_kit::core::Bindings; 74 | use ai_kit::datum::Datum; 75 | 76 | fn main() { 77 | // Create empty bindings 78 | let bindings : Bindings<Datum> = Bindings::new(); 79 | 80 | // Set the variables "?x" and "?y" equal to each other 81 | let bindings = bindings 82 | .set_binding(&"?x".to_string(), Datum::Variable("?y".to_string())); 83 | 84 | // Set the value of "?x" 85 | let bindings = bindings.set_binding(&"?x".to_string(), Datum::Float(1.0)); 86 | 87 | // Verify that "?y" now has the same value 88 | 89 | assert_eq!(bindings.get_binding(&"?y".to_string()), Some(Datum::Float(1.0))); 90 | } 91 | ``` 92 | 93 | Operation 94 | --------- 95 | 96 | There are times when a program has certain facts from which further facts can be inferred. 97 | This is implemented by the `Operation` trait. This is used to implement [forward inference](#forward-inference) and [planning](#planning). An example of forward chaining reasoning (also called Modus Ponens) would be the following: 98 | 99 | ``` 100 | All men are mortal. 101 | Socrates is a man. 102 | Therefore Socrates is mortal. 103 | ``` 104 | 105 | The [Rule](#rule) struct implements `Operation`, and we [use it to perform the above inference in rust](#rule-implements-operation).
106 | 107 | Algorithms 108 | ========== 109 | 110 | ## Constraints 111 | 112 | Feature `with-constraint` 113 | 114 | A simple and limited library for checking and satisfying constraints. 115 | 116 | ## Forward Inference 117 | 118 | Feature `with-forward-inference` 119 | 120 | Implementation of forward-chaining inference - essentially this is inference via Modus Ponens. 121 | 122 | [Example](#rule-implements-operation). 123 | 124 | ## Planning 125 | 126 | Feature `with-planner` 127 | 128 | Planning with backtracking. 129 | 130 | [Example](#planning-examples) 131 | 132 | ## Pedigree 133 | 134 | Misc data-structures and code for representing the path taken to derive a given inference. 135 | 136 | ## Default Trait Implementations 137 | 138 | The above algorithms operate on any structures that implement the appropriate core traits (`BindingsValue`, `Unify` and `Operation`). 139 | 140 | `ai_kit` provides default structures that implement the core traits which should be sufficient for many use-cases. 141 | 142 | ## Datum 143 | 144 | Feature `with-datum`. 145 | 146 | The `datum::Datum` structure implements the `BindingsValue` and `Unify` traits. 147 | 148 | ``` 149 | #[derive(Clone, Debug, Serialize, Deserialize, PartialOrd)] 150 | pub enum Datum { 151 | Nil, 152 | String(String), 153 | Int(i64), 154 | Float(f64), 155 | Variable(String), 156 | Vector(Vec), 157 | } 158 | ``` 159 | 160 | ## Datum Implements Unify 161 | 162 | Because `Datum` implements the `Unify` trait, `Datum` can be unified. 
163 | 164 | 165 | ```rust 166 | 167 | extern crate ai_kit; 168 | 169 | use ai_kit::core::{Bindings, Unify}; 170 | use ai_kit::datum::Datum; 171 | 172 | fn main() { 173 | let d = Datum::Float(0.0); 174 | let empty_bindings : Bindings<Datum> = Bindings::new(); 175 | 176 | // These datums are the same, so they can be unified 177 | let bindings = d.unify(&Datum::Float(0.0), &empty_bindings); 178 | assert!(bindings.is_some()); 179 | 180 | // These datums are not the same, so they cannot be unified 181 | let bindings = d.unify(&Datum::Float(1.0), &empty_bindings); 182 | assert!(bindings.is_none()); 183 | 184 | // These datums differ, but the second is a variable, so they can be unified 185 | let bindings = d.unify(&Datum::Variable("?x".to_string()), &empty_bindings); 186 | assert!(bindings.is_some()); 187 | 188 | // The bindings returned by unification record that the variable ?x now has the same value as d! 189 | assert_eq!(bindings.unwrap().get_binding(&"?x".to_string()), Some(d)); 190 | } 191 | ``` 192 | 193 | ## Rule 194 | 195 | Feature `with-rule`. 196 | 197 | The `rule::Rule` structure implements the `Operation` trait.
198 | 199 | ``` 200 | #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] 201 | pub struct Rule> { 202 | pub constraints: Vec, 203 | pub lhs: Vec, 204 | pub rhs: U, 205 | _marker: PhantomData, 206 | } 207 | ``` 208 | 209 | ## Rule Implements Operation 210 | 211 | ```rust 212 | 213 | extern crate ai_kit; 214 | 215 | use ai_kit::datum::Datum; 216 | use ai_kit::infer::InferenceEngine; 217 | use ai_kit::rule::Rule; 218 | use std::marker::PhantomData; 219 | 220 | fn main() { 221 | // Encode knowledge about mortality 222 | let rules = vec![ 223 | ( 224 | "rule_of_mortality".to_string(), // Rules need ids for inferencing 225 | Rule { 226 | constraints: Vec::new(), 227 | lhs: vec![ 228 | Datum::Vector( 229 | vec![ 230 | Datum::Variable("?x".to_string()), 231 | Datum::String("isa".to_string()), 232 | Datum::String("human".to_string()), 233 | ] 234 | ) 235 | ], 236 | rhs: Datum::Vector(vec![ 237 | Datum::Variable("?x".to_string()), 238 | Datum::String("isa".to_string()), 239 | Datum::String("mortal".to_string()), 240 | ]), 241 | _marker: PhantomData, 242 | } 243 | ), 244 | ]; 245 | 246 | // Set up our initial knowledge about Socrates 247 | let facts = vec![ 248 | ( 249 | "socrates_is_human".to_string(), // Facts need ids for inferencing 250 | Datum::Vector( 251 | vec![ 252 | Datum::String("socrates".to_string()), 253 | Datum::String("isa".to_string()), 254 | Datum::String("human".to_string()), 255 | ] 256 | ) 257 | ), 258 | ]; 259 | 260 | // Infer new knowledge!
261 | let mut inf_engine = InferenceEngine::new( 262 | "demo".to_string(), 263 | rules.iter().map(|&(ref id, ref f)| (id, f)).collect(), 264 | facts.iter().map(|&(ref id, ref r)| (id, r)).collect()); 265 | let inferences = inf_engine.chain_forward(); 266 | assert_eq!(inferences.len(), 1); 267 | } 268 | ``` 269 | 270 | ## Planning Examples 271 | 272 | Sudoku Solver 273 | ============= 274 | 275 | Forthcoming 276 | 277 | N-Queens Solver 278 | =============== 279 | 280 | Forthcoming 281 | 282 | 283 | NLP Parser 284 | ========== 285 | 286 | This example takes a bunch of words, and rules for aggregating words into phrases and sentences, and constructs a valid parse of the words. 287 | It then saves a graph in [GraphViz .dot notation](http://www.graphviz.org/) of the actual goal tree constructed into a "parse.dot" in the current working directory. 288 | 289 | ```rust 290 | 291 | extern crate ai_kit; 292 | #[macro_use] 293 | extern crate serde_json; 294 | 295 | use ai_kit::core::Bindings; 296 | use ai_kit::datum::Datum; 297 | use ai_kit::planner::*; 298 | use ai_kit::rule::Rule; 299 | use std::fs::File; 300 | use std::io::Write; 301 | use std::path; 302 | 303 | macro_rules! from_json { 304 | ($type: ty, $json: tt) => ({ 305 | use serde_json; 306 | let x: $type = serde_json::from_value(json!($json)).expect("Expected json decoding"); 307 | x 308 | }) 309 | } 310 | 311 | #[allow(unused)] 312 | fn main() { 313 | /* 314 | * Inference rules that encode the parts of speech for each word and how to 315 | * compose parts of speech into sentences. 
316 | */ 317 | let rules: Vec> = from_json!(Vec>, [ 318 | // Parts of speech for each word 319 | {"lhs": [{"str": "a"}], "rhs": {"str": "det"}}, 320 | {"lhs": [{"str": "the"}], "rhs": {"str": "det"}}, 321 | {"lhs": [{"str": "chased"}], "rhs": {"str": "verb"}}, 322 | {"lhs": [{"str": "chased"}], "rhs": {"str": "verb"}}, 323 | {"lhs": [{"str": "dog"}], "rhs": {"str": "noun"}}, 324 | {"lhs": [{"str": "cat"}], "rhs": {"str": "noun"}}, 325 | // Building phrases into sentences 326 | {"lhs": [{"str": "det"}, {"str": "noun"}], "rhs": {"str": "np"}}, 327 | {"lhs": [{"str": "verb"}, {"str": "np"}], "rhs": {"str": "vp"}}, 328 | {"lhs": [{"str": "np"}, {"str": "vp"}], "rhs": {"str": "sen"}} 329 | ]); 330 | 331 | // Our input data - a series of words 332 | let data: Vec = from_json!(Vec, [ 333 | {"str": "a"}, 334 | {"str": "the"}, 335 | {"str": "dog"}, 336 | {"str": "cat"}, 337 | {"str": "chased"} 338 | ]); 339 | 340 | // Specify that our goal is to construct a sentence from the provided data using the provided rules 341 | let mut planner = Planner::new(&Goal::with_pattern(from_json!(Datum, {"str": "sen"})), 342 | &Bindings::new(), 343 | &PlanningConfig { 344 | max_depth: 5, 345 | max_increments: 50, 346 | // Don't reuse a given piece of data (ie, a word) 347 | reuse_data: false, 348 | }, 349 | data.iter().collect(), 350 | rules.iter().collect()); 351 | 352 | // Construct the first interpretation 353 | let result = planner.next(); 354 | assert_eq!(result.is_some(), true); 355 | let (final_goal, bindings) = result.unwrap(); 356 | 357 | // What are our expected leaves of the goal (ie, the order of parsed sentences) 358 | let expected_leaves: Vec = vec![ 359 | "a".to_string(), 360 | "dog".to_string(), 361 | "chased".to_string(), 362 | "the".to_string(), 363 | "cat".to_string() 364 | ] 365 | .into_iter() 366 | .map(|s| Datum::String(s)) 367 | .collect(); 368 | 369 | // Verify that the leaves of our plan are as expected 370 | assert_eq!(final_goal.gather_leaves(&bindings), 
expected_leaves); 371 | 372 | // Render the plan using graphviz notation 373 | let graphviz_rendering : String = final_goal.render_as_graphviz(); 374 | 375 | // Save the plan in the current working directory 376 | File::create(path::Path::new(&"parse.dot")) 377 | .and_then(|mut file| file.write_all(graphviz_rendering.as_str().as_bytes())); 378 | } 379 | ``` 380 | 381 | Here is the expected content of "parse.dot": 382 | 383 | ``` 384 | graph "goal tree 'sen'" { 385 | 386 | "'sen' [Actor(8)]" -- "'np' [Actor(6)]"; 387 | "'np' [Actor(6)]" -- "'det' [Actor(0)]"; 388 | "'det' [Actor(0)]" -- "'a' [Datum(0)]"; 389 | 390 | "'np' [Actor(6)]" -- "'noun' [Actor(4)]"; 391 | "'noun' [Actor(4)]" -- "'dog' [Datum(2)]"; 392 | 393 | "'sen' [Actor(8)]" -- "'vp' [Actor(7)]"; 394 | "'vp' [Actor(7)]" -- "'verb' [Actor(2)]"; 395 | "'verb' [Actor(2)]" -- "'chased' [Datum(4)]"; 396 | 397 | "'vp' [Actor(7)]" -- "'np' [Actor(6)]"; 398 | "'np' [Actor(6)]" -- "'det' [Actor(1)]"; 399 | "'det' [Actor(1)]" -- "'the' [Datum(1)]"; 400 | 401 | "'np' [Actor(6)]" -- "'noun' [Actor(5)]"; 402 | "'noun' [Actor(5)]" -- "'cat' [Datum(3)]"; 403 | 404 | } 405 | ``` 406 | 407 | If you have graphviz installed locally, you can convert this graph into a PNG file: 408 | 409 | ``` 410 | dot -Tpng parse.dot > parse.png 411 | ``` 412 | 413 | Feature Matrix 414 | ============== 415 | 416 | Some features depend on other features. This is summarized in the following table: 417 | 418 | | Feature | Requires | 419 | |---------|----------| 420 | | `with-planner` | `with-constraint` | 421 | | `with-forward-inference` | `with-planner` `with-constraint` | 422 | | `with-rule` | `with-constraint` | 423 | | `with-constraint` | N/A| 424 | | `with-pedigree` | N/A | 425 | | `with-datum` | N/A | 426 | 427 | 428 | Skeptic Testing 429 | =============== 430 | 431 | Examples in this document are tested as part of the build process using [skeptic](https://github.com/brson/rust-skeptic). 
432 | -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | extern crate skeptic; 2 | 3 | fn main() { 4 | skeptic::generate_doc_tests(&["README.md"]); 5 | } 6 | -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | # Maximum line length 2 | max_width = 150 3 | format_strings = true 4 | reorder_imports = true 5 | -------------------------------------------------------------------------------- /src/constraints/mod.rs: -------------------------------------------------------------------------------- 1 | //! The constraints module implements a very basic system for checking and solving constraints. 2 | //! 3 | //! Examples 4 | //! 5 | //! ``` 6 | //! use ai_kit::core::Bindings; 7 | //! use ai_kit::constraints::{Constraint, Number, NumericalConstraint, SolveResult}; 8 | //! 9 | //! // Construct a vector of constraints 10 | //! let constraints = vec![ 11 | //! // Specify a value for a variable 12 | //! Constraint::Numerical(NumericalConstraint::Set{ 13 | //! variable: "x".to_string(), 14 | //! constant: 1.0, 15 | //! }), 16 | //! // Specify a value for a second variable 17 | //! Constraint::Numerical(NumericalConstraint::Set{ 18 | //! variable: "sum".to_string(), 19 | //! constant: 3.0, 20 | //! }), 21 | //! // Specify the relation of the first two variables and a third unknown variable 22 | //! Constraint::Numerical(NumericalConstraint::Sum{ 23 | //! first: "x".to_string(), 24 | //! second: "y".to_string(), 25 | //! third: "sum".to_string() 26 | //! }), 27 | //! ]; 28 | //! 29 | //! // Check that constraints are satisfied and infer any new values that we can 30 | //! let result : SolveResult = Constraint::solve_many( 31 | //! constraints.iter().collect(), 32 | //! &Bindings::new(), 33 | //! ); 34 | //! 35 | //! 
// Verify that the unknown variable (y) has the expected inferred value 36 | //! let expected_inferred_value = Number { value: 2.0 }; 37 | //! 38 | //! if let SolveResult::Success(bindings) = result { 39 | //! assert_eq!( 40 | //! bindings.get_binding(&"y".to_string()).unwrap(), 41 | //! expected_inferred_value 42 | //! ); 43 | //! } 44 | //! ``` 45 | 46 | pub use self::numerical::*; 47 | use core::{Bindings, BindingsValue}; 48 | use serde_json; 49 | use std; 50 | use std::collections::HashMap; 51 | use std::ops::{Add, Div, Mul, Sub}; 52 | use utils; 53 | mod numerical; 54 | 55 | pub use self::symbolic::*; 56 | mod symbolic; 57 | 58 | pub trait ConstraintValue: BindingsValue { 59 | /// Construct a ConstraintValue from a float 60 | fn float(f64) -> Self; 61 | /// Attempt to convert this value to a float 62 | fn to_float(&self) -> Option; 63 | } 64 | 65 | #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] 66 | pub enum SolveResult { 67 | /// A conflict was found, no solution possible 68 | Conflict, 69 | /// Incomplete solution; could not solve all constraints 70 | Partial(Bindings), 71 | /// Successful solution 72 | Success(Bindings), 73 | } 74 | 75 | impl SolveResult { 76 | pub fn ok(&self) -> Option> { 77 | match *self { 78 | SolveResult::Success(ref bindings) => Some(bindings.clone()), 79 | SolveResult::Partial(ref bindings) => Some(bindings.clone()), 80 | _ => None, 81 | } 82 | } 83 | 84 | pub fn and_then(&self, f: &Fn(&Bindings) -> Self) -> Self { 85 | match *self { 86 | SolveResult::Success(ref bindings) => f(bindings), 87 | SolveResult::Partial(ref bindings) => f(bindings), 88 | _ => self.clone(), 89 | } 90 | } 91 | 92 | pub fn if_partial(&self, f: &Fn(&Bindings) -> Self) -> Self { 93 | match *self { 94 | SolveResult::Partial(ref bindings) => f(bindings), 95 | _ => self.clone(), 96 | } 97 | } 98 | } 99 | 100 | impl std::fmt::Display for SolveResult { 101 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 102 | write!(f, "{}", 
serde_json::to_string(&self).unwrap()) 103 | } 104 | } 105 | 106 | #[derive(Clone, Debug, Deserialize, PartialEq, PartialOrd, Serialize)] 107 | pub enum Constraint { 108 | #[serde(rename = "numerical")] Numerical(NumericalConstraint), 109 | #[serde(rename = "symbolic")] Symbolic(SymbolicConstraint), 110 | } 111 | 112 | impl Eq for Constraint {} 113 | impl Ord for Constraint { 114 | fn cmp(&self, other: &Self) -> std::cmp::Ordering { 115 | self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Less) 116 | } 117 | } 118 | 119 | impl std::fmt::Display for Constraint { 120 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 121 | write!(f, "{}", serde_json::to_string(&self).unwrap()) 122 | } 123 | } 124 | 125 | impl Constraint { 126 | pub fn solve(&self, bindings: &Bindings) -> SolveResult { 127 | match *self { 128 | Constraint::Numerical(ref numerical_constraint) => numerical_constraint.solve(bindings), 129 | Constraint::Symbolic(ref symbolic_constraint) => symbolic_constraint.solve(bindings), 130 | } 131 | } 132 | 133 | pub fn solve_many(constraints: Vec<&Constraint>, bindings: &Bindings) -> SolveResult { 134 | // Aggregate all bindings from the constraints that we can solve 135 | let fold_result = utils::fold_while_some( 136 | (Vec::new(), bindings.clone()), 137 | &mut constraints.iter(), 138 | &|(mut remaining_constraints, bindings), constraint| { 139 | let result: SolveResult = constraint.solve(&bindings); 140 | match result { 141 | SolveResult::Conflict => None, 142 | SolveResult::Partial(bindings) => { 143 | remaining_constraints.push(constraint.clone()); 144 | Some((remaining_constraints, bindings.clone())) 145 | } 146 | SolveResult::Success(bindings) => Some((remaining_constraints, bindings.clone())), 147 | } 148 | }, 149 | ); 150 | match fold_result { 151 | Some((remaining_constraints, bindings)) => { 152 | if remaining_constraints.is_empty() { 153 | SolveResult::Success(bindings) 154 | } else if remaining_constraints.len() == constraints.len() 
{ 155 | // We've made no progress, this is unsolvable 156 | SolveResult::Partial(bindings) 157 | } else { 158 | Constraint::solve_many(remaining_constraints, &bindings) 159 | } 160 | } 161 | None => SolveResult::Conflict, 162 | } 163 | } 164 | 165 | pub fn rename_variables(&self, renamed_variables: &HashMap) -> Self { 166 | match *self { 167 | Constraint::Numerical(ref numerical_constraint) => Constraint::Numerical(numerical_constraint.rename_variables(renamed_variables)), 168 | Constraint::Symbolic(ref symbolic_constraint) => Constraint::Symbolic(symbolic_constraint.rename_variables(renamed_variables)), 169 | } 170 | } 171 | 172 | pub fn variables(&self) -> Vec { 173 | match *self { 174 | Constraint::Numerical(ref numerical_constraint) => numerical_constraint.variables(), 175 | Constraint::Symbolic(ref symbolic_constraint) => symbolic_constraint.variables(), 176 | } 177 | } 178 | } 179 | 180 | #[cfg(test)] 181 | mod tests; 182 | -------------------------------------------------------------------------------- /src/constraints/numerical.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use serde_json; 3 | use std; 4 | 5 | #[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, PartialOrd, Serialize)] 6 | pub struct Number { 7 | pub value: f64, 8 | } 9 | 10 | impl std::fmt::Display for Number { 11 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 12 | write!(f, "{}", self.value) 13 | } 14 | } 15 | 16 | impl Eq for Number {} 17 | impl BindingsValue for Number {} 18 | 19 | impl ConstraintValue for Number { 20 | fn float(f: f64) -> Self { 21 | Number { value: f } 22 | } 23 | 24 | fn to_float(&self) -> Option { 25 | Some(self.value) 26 | } 27 | } 28 | 29 | impl Number { 30 | pub fn new(f: f64) -> Self { 31 | Number { value: f } 32 | } 33 | } 34 | 35 | #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] 36 | pub enum NumericalConstraint { 37 | /// ?x = CONSTANT 38 | #[serde(rename = 
"set")] 39 | Set { variable: String, constant: f64 }, 40 | /// ?x + ?y = ?z 41 | #[serde(rename = "sum")] 42 | Sum { 43 | first: String, 44 | second: String, 45 | third: String, 46 | }, 47 | /// ?x * ?y = ?z 48 | #[serde(rename = "mul")] 49 | Mul { 50 | first: String, 51 | second: String, 52 | third: String, 53 | }, 54 | /// ?x > ?y 55 | #[serde(rename = ">")] 56 | GreaterThan { left: String, right: String }, 57 | /// ?x != ?y 58 | #[serde(rename = "neq")] 59 | NotEqual { left: String, right: String }, 60 | } 61 | 62 | impl Eq for NumericalConstraint {} 63 | impl Ord for NumericalConstraint { 64 | fn cmp(&self, other: &Self) -> std::cmp::Ordering { 65 | self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Less) 66 | } 67 | } 68 | impl PartialOrd for NumericalConstraint { 69 | fn partial_cmp(&self, other: &Self) -> Option { 70 | match (self, other) { 71 | ( 72 | &NumericalConstraint::Set { 73 | ref variable, 74 | ref constant, 75 | .. 76 | }, 77 | &NumericalConstraint::Set { 78 | variable: ref variable2, 79 | constant: ref constant2, 80 | .. 81 | }, 82 | ) => match variable.partial_cmp(variable2) { 83 | Some(std::cmp::Ordering::Equal) => constant.partial_cmp(constant2), 84 | ordering => ordering, 85 | }, 86 | ( 87 | &NumericalConstraint::Sum { 88 | ref first, 89 | ref second, 90 | ref third, 91 | .. 92 | }, 93 | &NumericalConstraint::Sum { 94 | first: ref first2, 95 | second: ref second2, 96 | third: ref third2, 97 | .. 98 | }, 99 | ) => match first.partial_cmp(first2) { 100 | Some(std::cmp::Ordering::Equal) => match second.partial_cmp(second2) { 101 | Some(std::cmp::Ordering::Equal) => third.partial_cmp(third2), 102 | ordering => ordering, 103 | }, 104 | ordering => ordering, 105 | }, 106 | ( 107 | &NumericalConstraint::Mul { 108 | ref first, 109 | ref second, 110 | ref third, 111 | .. 112 | }, 113 | &NumericalConstraint::Mul { 114 | first: ref first2, 115 | second: ref second2, 116 | third: ref third2, 117 | .. 
118 | }, 119 | ) => match first.partial_cmp(first2) { 120 | Some(std::cmp::Ordering::Equal) => match second.partial_cmp(second2) { 121 | Some(std::cmp::Ordering::Equal) => third.partial_cmp(third2), 122 | ordering => ordering, 123 | }, 124 | ordering => ordering, 125 | }, 126 | ( 127 | &NumericalConstraint::GreaterThan { 128 | ref left, 129 | ref right, 130 | .. 131 | }, 132 | &NumericalConstraint::GreaterThan { 133 | left: ref left2, 134 | right: ref right2, 135 | .. 136 | }, 137 | ) => match left.partial_cmp(left2) { 138 | Some(std::cmp::Ordering::Equal) => right.partial_cmp(right2), 139 | ordering => ordering, 140 | }, 141 | ( 142 | &NumericalConstraint::NotEqual { 143 | ref left, 144 | ref right, 145 | .. 146 | }, 147 | &NumericalConstraint::NotEqual { 148 | left: ref left2, 149 | right: ref right2, 150 | .. 151 | }, 152 | ) => match left.partial_cmp(left2) { 153 | Some(std::cmp::Ordering::Equal) => right.partial_cmp(right2), 154 | ordering => ordering, 155 | }, 156 | (&NumericalConstraint::Set { .. }, &NumericalConstraint::Sum { .. }) => Some(std::cmp::Ordering::Less), 157 | (&NumericalConstraint::Set { .. }, &NumericalConstraint::Mul { .. }) => Some(std::cmp::Ordering::Less), 158 | (&NumericalConstraint::Set { .. }, &NumericalConstraint::GreaterThan { .. }) => Some(std::cmp::Ordering::Less), 159 | (&NumericalConstraint::Set { .. }, &NumericalConstraint::NotEqual { .. }) => Some(std::cmp::Ordering::Less), 160 | 161 | (&NumericalConstraint::Sum { .. }, &NumericalConstraint::Set { .. }) => Some(std::cmp::Ordering::Greater), 162 | (&NumericalConstraint::Sum { .. }, &NumericalConstraint::Mul { .. }) => Some(std::cmp::Ordering::Less), 163 | (&NumericalConstraint::Sum { .. }, &NumericalConstraint::GreaterThan { .. }) => Some(std::cmp::Ordering::Less), 164 | (&NumericalConstraint::Sum { .. }, &NumericalConstraint::NotEqual { .. }) => Some(std::cmp::Ordering::Less), 165 | 166 | (&NumericalConstraint::Mul { .. }, &NumericalConstraint::Set { .. 
}) => Some(std::cmp::Ordering::Greater), 167 | (&NumericalConstraint::Mul { .. }, &NumericalConstraint::Sum { .. }) => Some(std::cmp::Ordering::Greater), 168 | (&NumericalConstraint::Mul { .. }, &NumericalConstraint::GreaterThan { .. }) => Some(std::cmp::Ordering::Less), 169 | (&NumericalConstraint::Mul { .. }, &NumericalConstraint::NotEqual { .. }) => Some(std::cmp::Ordering::Less), 170 | 171 | (&NumericalConstraint::GreaterThan { .. }, &NumericalConstraint::Set { .. }) => Some(std::cmp::Ordering::Greater), 172 | (&NumericalConstraint::GreaterThan { .. }, &NumericalConstraint::Sum { .. }) => Some(std::cmp::Ordering::Greater), 173 | (&NumericalConstraint::GreaterThan { .. }, &NumericalConstraint::Mul { .. }) => Some(std::cmp::Ordering::Greater), 174 | (&NumericalConstraint::GreaterThan { .. }, &NumericalConstraint::NotEqual { .. }) => Some(std::cmp::Ordering::Less), 175 | 176 | (&NumericalConstraint::NotEqual { .. }, &NumericalConstraint::Set { .. }) => Some(std::cmp::Ordering::Greater), 177 | (&NumericalConstraint::NotEqual { .. }, &NumericalConstraint::Sum { .. }) => Some(std::cmp::Ordering::Greater), 178 | (&NumericalConstraint::NotEqual { .. }, &NumericalConstraint::Mul { .. }) => Some(std::cmp::Ordering::Greater), 179 | (&NumericalConstraint::NotEqual { .. }, &NumericalConstraint::GreaterThan { .. }) => Some(std::cmp::Ordering::Greater), 180 | } 181 | } 182 | } 183 | 184 | impl std::fmt::Display for NumericalConstraint { 185 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 186 | write!(f, "{}", serde_json::to_string(&self).unwrap()) 187 | } 188 | } 189 | 190 | impl NumericalConstraint { 191 | /// Try to solve this constraint using the information in the bindings 192 | pub fn solve(&self, bindings: &Bindings) -> SolveResult { 193 | let apply_op = |x: &T, y: &T, op: &Fn(f64, f64) -> f64| -> Option { 194 | x.to_float() 195 | .and_then(|x| y.to_float().and_then(|y| Some(T::float(op(x, y))))) 196 | }; 197 | macro_rules! 
apply_op_or_return { 198 | ($x: ident, $y: ident, $z: ident, $op: expr) => ({ 199 | if let Some(value) = apply_op($x, $y, &$op) { 200 | ($z, value) 201 | } else { 202 | return SolveResult::Conflict 203 | } 204 | }) 205 | } 206 | let (key, value) = match *self { 207 | NumericalConstraint::Set { 208 | ref variable, 209 | ref constant, 210 | .. 211 | } => match bindings.get_binding(variable) { 212 | None => (variable, T::float(constant.clone())), 213 | Some(ref value) if value.to_float() == Some(*constant) => (variable, value.clone()), 214 | Some(_) => return SolveResult::Conflict, 215 | }, 216 | NumericalConstraint::Sum { 217 | ref first, 218 | ref second, 219 | ref third, 220 | .. 221 | } => match ( 222 | bindings.get_binding(first), 223 | bindings.get_binding(second), 224 | bindings.get_binding(third), 225 | ) { 226 | (Some(ref value), Some(ref value2), None) => apply_op_or_return!(value, value2, third, &Add::add), 227 | (Some(ref value), None, Some(ref value3)) => apply_op_or_return!(value3, value, second, &Sub::sub), 228 | (None, Some(ref value2), Some(ref value3)) => apply_op_or_return!(value3, value2, first, &Sub::sub), 229 | (Some(ref value), Some(ref value2), Some(ref value3)) => { 230 | if Some(value) == apply_op(value3, value2, &Sub::sub).as_ref() { 231 | return SolveResult::Partial(bindings.clone()); 232 | } else { 233 | return SolveResult::Conflict; 234 | } 235 | } 236 | _ => return SolveResult::Partial(bindings.clone()), 237 | }, 238 | NumericalConstraint::Mul { 239 | ref first, 240 | ref second, 241 | ref third, 242 | .. 
243 | } => match ( 244 | bindings.get_binding(first), 245 | bindings.get_binding(second), 246 | bindings.get_binding(third), 247 | ) { 248 | (Some(ref value), Some(ref value2), None) => apply_op_or_return!(value, value2, third, &Mul::mul), 249 | (Some(ref value), None, Some(ref value3)) => apply_op_or_return!(value3, value, second, &Div::div), 250 | (None, Some(ref value2), Some(ref value3)) => apply_op_or_return!(value3, value2, first, &Div::div), 251 | (Some(ref value), Some(ref value2), Some(ref value3)) => { // all three bound: verify first * second == third directly; dividing third by second yields NaN/inf when second is 0 and would reject valid bindings such as 0 * 0 = 0 252 | if Some(value3) == apply_op(value, value2, &Mul::mul).as_ref() { 253 | return SolveResult::Partial(bindings.clone()); 254 | } else { 255 | return SolveResult::Conflict; 256 | } 257 | } 258 | _ => return SolveResult::Partial(bindings.clone()), 259 | }, 260 | NumericalConstraint::GreaterThan { 261 | ref left, 262 | ref right, 263 | .. 264 | } => match (bindings.get_binding(left), bindings.get_binding(right)) { 265 | (Some(ref left_value), Some(ref right_value)) if left_value > right_value => return SolveResult::Success(bindings.clone()), 266 | (Some(_), Some(_)) => return SolveResult::Conflict, 267 | _ => return SolveResult::Partial(bindings.clone()), 268 | }, 269 | NumericalConstraint::NotEqual { 270 | ref left, 271 | ref right, 272 | ..
273 | } => match (bindings.get_binding(left), bindings.get_binding(right)) { 274 | (Some(ref left_value), Some(ref right_value)) if left_value != right_value => return SolveResult::Success(bindings.clone()), 275 | (Some(_), Some(_)) => return SolveResult::Conflict, 276 | _ => return SolveResult::Partial(bindings.clone()), 277 | }, 278 | }; 279 | SolveResult::Success(bindings.set_binding(key, value)) 280 | } 281 | /// Return a copy of this constraint (same variant) with every variable replaced by its entry in `renamed_variables`; variables without an entry keep their name 282 | pub fn rename_variables(&self, renamed_variables: &HashMap) -> Self { 283 | let lookup = |v: &String| -> String { 284 | renamed_variables 285 | .get(v) 286 | .cloned() 287 | // No rename registered for this variable: keep the original name 288 | .unwrap_or_else(|| v.clone()) 289 | }; 290 | match *self { 291 | NumericalConstraint::Set { 292 | ref variable, 293 | ref constant, 294 | .. 295 | } => NumericalConstraint::Set { 296 | variable: lookup(variable), 297 | constant: constant.clone(), 298 | }, 299 | NumericalConstraint::Sum { 300 | ref first, 301 | ref second, 302 | ref third, 303 | .. 304 | } => NumericalConstraint::Sum { 305 | first: lookup(first), 306 | second: lookup(second), 307 | third: lookup(third), 308 | }, 309 | NumericalConstraint::Mul { 310 | ref first, 311 | ref second, 312 | ref third, 313 | .. 314 | } => NumericalConstraint::Mul { 315 | first: lookup(first), 316 | second: lookup(second), 317 | third: lookup(third), 318 | }, 319 | NumericalConstraint::GreaterThan { 320 | ref left, 321 | ref right, 322 | .. 323 | } => NumericalConstraint::GreaterThan { 324 | left: lookup(left), 325 | right: lookup(right), 326 | }, 327 | NumericalConstraint::NotEqual { 328 | ref left, 329 | ref right, 330 | .. 331 | } => NumericalConstraint::NotEqual { 332 | left: lookup(left), 333 | right: lookup(right), 334 | }, 335 | } 336 | } 337 | 338 | pub fn variables(&self) -> Vec { 339 | match *self { 340 | NumericalConstraint::Set { ref variable, .. } => vec![variable.clone()], 341 | NumericalConstraint::Sum { 342 | ref first, 343 | ref second, 344 | ref third, 345 | ..
346 | } => vec![first.clone(), second.clone(), third.clone()], 347 | NumericalConstraint::Mul { 348 | ref first, 349 | ref second, 350 | ref third, 351 | .. 352 | } => vec![first.clone(), second.clone(), third.clone()], 353 | NumericalConstraint::GreaterThan { 354 | ref left, 355 | ref right, 356 | .. 357 | } => vec![left.clone(), right.clone()], 358 | NumericalConstraint::NotEqual { 359 | ref left, 360 | ref right, 361 | .. 362 | } => vec![left.clone(), right.clone()], 363 | } 364 | } 365 | } 366 | -------------------------------------------------------------------------------- /src/constraints/symbolic.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use serde_json; 3 | use std; 4 | 5 | #[derive(Clone, Debug, Deserialize, PartialEq, PartialOrd, Serialize)] 6 | pub enum SymbolicConstraint { 7 | Eq { v1: String, v2: String }, 8 | Neq { v1: String, v2: String }, 9 | } 10 | 11 | impl std::fmt::Display for SymbolicConstraint { 12 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 13 | write!(f, "{}", serde_json::to_string(&self).unwrap()) 14 | } 15 | } 16 | 17 | impl SymbolicConstraint { 18 | /// Try to solve this constraint using the information in the bindings 19 | pub fn solve(&self, bindings: &Bindings) -> SolveResult { 20 | match *self { 21 | SymbolicConstraint::Eq { ref v1, ref v2 } => { 22 | // Both variables bound: Success if the values are equal, otherwise None, 23 | // which the trailing unwrap_or turns into SolveResult::Conflict. 24 | // Either variable unbound: nothing can be decided yet, so the bindings 25 | // are returned unchanged as Partial. 26 | // No debug printing here: library code must not write to stdout as a 27 | // side effect of solving. 28 | match (bindings.get_binding(v1), bindings.get_binding(v2)) { 29 | (Some(ref val1), Some(ref val2)) => { 30 | if val1.eq(val2) { 31 | Some(SolveResult::Success(bindings.clone())) 32 | } else { 33 | None 34 | } 35 | } 36 | _ => Some(SolveResult::Partial(bindings.clone())), 37 | } 38 | } 39 | SymbolicConstraint::Neq { ref v1, ref v2 } => match (bindings.get_binding(v1), bindings.get_binding(v2)) {
40 | (Some(ref val1), Some(ref val2)) => { 41 | if val1.eq(val2) { 42 | None 43 | } else { 44 | Some(SolveResult::Success(bindings.clone())) 45 | } 46 | } 47 | _ => Some(SolveResult::Partial(bindings.clone())), 48 | }, 49 | }.unwrap_or(SolveResult::Conflict) 50 | } 51 | 52 | pub fn rename_variables(&self, renamed_variables: &HashMap) -> Self { 53 | let lookup = |v: &String| -> String { 54 | renamed_variables 55 | .get(v) 56 | .cloned() 57 | .or_else(|| Some(v.clone())) 58 | .unwrap() 59 | }; 60 | match *self { 61 | SymbolicConstraint::Eq { ref v1, ref v2 } => SymbolicConstraint::Eq { 62 | v1: lookup(v1), 63 | v2: lookup(v2), 64 | }, 65 | SymbolicConstraint::Neq { ref v1, ref v2 } => SymbolicConstraint::Neq { 66 | v1: lookup(v1), 67 | v2: lookup(v2), 68 | }, 69 | } 70 | } 71 | 72 | pub fn variables(&self) -> Vec { 73 | match *self { 74 | SymbolicConstraint::Eq { ref v1, ref v2 } => vec![v1.clone(), v2.clone()], 75 | SymbolicConstraint::Neq { ref v1, ref v2 } => vec![v1.clone(), v2.clone()], 76 | } 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/constraints/tests.rs: -------------------------------------------------------------------------------- 1 | use constraints::*; 2 | 3 | #[cfg(test)] 4 | mod solver_tests { 5 | use super::*; 6 | 7 | #[test] 8 | fn test_solve_multi_constraint() { 9 | let constraints = from_json!(Vec, [ 10 | { 11 | "numerical": { 12 | "set": { 13 | "variable": "?diff", 14 | "constant": 5.0, 15 | }, 16 | }, 17 | }, 18 | { 19 | "numerical": { 20 | "sum": { 21 | "first": "?x", 22 | "second": "?y", 23 | "third": "?diff", 24 | }, 25 | }, 26 | }, 27 | { 28 | "numerical": { 29 | "sum": { 30 | "first": "?w", 31 | "second": "?x", 32 | "third": "?diff", 33 | }, 34 | } 35 | }]); 36 | let bindings: Bindings = Bindings::new().set_binding(&"?w".to_string(), Number::new(5.0)); 37 | let expected_bindings: Bindings = Bindings::new() 38 | .set_binding(&"?diff".to_string(), Number::new(5.0)) 39 | 
.set_binding(&"?w".to_string(), Number::new(5.0)) 40 | .set_binding(&"?x".to_string(), Number::new(0.0)) 41 | .set_binding(&"?y".to_string(), Number::new(5.0)); 42 | 43 | assert_eq!( 44 | Constraint::solve_many(constraints.iter().collect(), &bindings), 45 | SolveResult::Success(expected_bindings) 46 | ); 47 | } 48 | 49 | #[test] 50 | fn test_solve_multi_constraint_terminates_when_unsolvable() { 51 | let constraints = from_json!(Vec, [ 52 | { 53 | "numerical": { 54 | "set": { 55 | "variable": "?diff", 56 | "constant": 5.0, 57 | }, 58 | }, 59 | }, 60 | { 61 | "numerical": { 62 | "sum": { 63 | "first": "?z", 64 | "second": "?y", 65 | "third": "?diff", 66 | }, 67 | }, 68 | }, 69 | { 70 | "numerical": { 71 | "sum": { 72 | "first": "?w", 73 | "second": "?x", 74 | "third": "?diff", 75 | }, 76 | } 77 | }]); 78 | let bindings: Bindings = Bindings::new().set_binding(&"?w".to_string(), Number::new(5.0)); 79 | let expected_bindings: Bindings = Bindings::new() 80 | .set_binding(&"?diff".to_string(), Number::new(5.0)) 81 | .set_binding(&"?w".to_string(), Number::new(5.0)) 82 | .set_binding(&"?x".to_string(), Number::new(0.0)); 83 | 84 | assert_eq!( 85 | Constraint::solve_many(constraints.iter().collect(), &bindings), 86 | SolveResult::Partial(expected_bindings) 87 | ); 88 | } 89 | } 90 | 91 | #[cfg(test)] 92 | mod numerical_tests { 93 | use super::*; 94 | 95 | #[test] 96 | fn test_solve_sum_constraint_forward() { 97 | let constraint: Constraint = Constraint::Numerical(NumericalConstraint::Sum { 98 | first: "?x".to_string(), 99 | second: "?y".to_string(), 100 | third: "?z".to_string(), 101 | }); 102 | let bindings: Bindings = vec![ 103 | ("?x".to_string(), Number::new(10.0)), 104 | ("?y".to_string(), Number::new(5.0)), 105 | ].into_iter() 106 | .collect(); 107 | let expected_bindings: Bindings = vec![ 108 | ("?x".to_string(), Number::new(10.0)), 109 | ("?y".to_string(), Number::new(5.0)), 110 | ("?z".to_string(), Number::new(15.0)), 111 | ].into_iter() 112 | .collect(); 113 | 
assert_eq!( 114 | constraint.solve(&bindings), 115 | SolveResult::Success(expected_bindings) 116 | ); 117 | } 118 | 119 | #[test] 120 | fn test_solve_sum_constraint_backward() { 121 | let constraint: Constraint = Constraint::Numerical(NumericalConstraint::Sum { 122 | first: "?x".to_string(), 123 | second: "?y".to_string(), 124 | third: "?z".to_string(), 125 | }); 126 | let bindings: Bindings = vec![ 127 | ("?x".to_string(), Number::new(10.0)), 128 | ("?z".to_string(), Number::new(15.0)), 129 | ].into_iter() 130 | .collect(); 131 | let expected_bindings: Bindings = vec![ 132 | ("?x".to_string(), Number::new(10.0)), 133 | ("?y".to_string(), Number::new(5.0)), 134 | ("?z".to_string(), Number::new(15.0)), 135 | ].into_iter() 136 | .collect(); 137 | assert_eq!( 138 | constraint.solve(&bindings), 139 | SolveResult::Success(expected_bindings) 140 | ); 141 | } 142 | 143 | #[test] 144 | fn test_solve_mul_constraint_forward() { 145 | let constraint: Constraint = Constraint::Numerical(NumericalConstraint::Mul { 146 | first: "?x".to_string(), 147 | second: "?y".to_string(), 148 | third: "?z".to_string(), 149 | }); 150 | let bindings: Bindings = vec![ 151 | ("?x".to_string(), Number::new(3.0)), 152 | ("?y".to_string(), Number::new(5.0)), 153 | ].into_iter() 154 | .collect(); 155 | let expected_bindings: Bindings = vec![ 156 | ("?x".to_string(), Number::new(3.0)), 157 | ("?y".to_string(), Number::new(5.0)), 158 | ("?z".to_string(), Number::new(15.0)), 159 | ].into_iter() 160 | .collect(); 161 | assert_eq!( 162 | constraint.solve(&bindings), 163 | SolveResult::Success(expected_bindings) 164 | ); 165 | } 166 | 167 | #[test] 168 | fn test_solve_mul_constraint_backward() { 169 | let constraint: Constraint = Constraint::Numerical(NumericalConstraint::Mul { 170 | first: "?x".to_string(), 171 | second: "?y".to_string(), 172 | third: "?z".to_string(), 173 | }); 174 | let bindings: Bindings = vec![ 175 | ("?x".to_string(), Number::new(3.0)), 176 | ("?z".to_string(), Number::new(15.0)), 
177 | ].into_iter() 178 | .collect(); 179 | let expected_bindings: Bindings = vec![ 180 | ("?x".to_string(), Number::new(3.0)), 181 | ("?y".to_string(), Number::new(5.0)), 182 | ("?z".to_string(), Number::new(15.0)), 183 | ].into_iter() 184 | .collect(); 185 | assert_eq!( 186 | constraint.solve(&bindings), 187 | SolveResult::Success(expected_bindings) 188 | ); 189 | } 190 | 191 | #[test] 192 | fn test_solve_greater_than_constraint_succeeds() { 193 | let constraint: Constraint = Constraint::Numerical(NumericalConstraint::GreaterThan { 194 | left: "?x".to_string(), 195 | right: "?y".to_string(), 196 | }); 197 | let bindings: Bindings = vec![ 198 | ("?x".to_string(), Number::new(15.0)), 199 | ("?y".to_string(), Number::new(5.0)), 200 | ].into_iter() 201 | .collect(); 202 | assert_eq!(constraint.solve(&bindings), SolveResult::Success(bindings)); 203 | } 204 | 205 | #[test] 206 | fn test_solve_greater_than_constraint_fails() { 207 | let constraint: Constraint = Constraint::Numerical(NumericalConstraint::GreaterThan { 208 | left: "?x".to_string(), 209 | right: "?y".to_string(), 210 | }); 211 | let bindings: Bindings = vec![ 212 | ("?x".to_string(), Number::new(5.0)), 213 | ("?y".to_string(), Number::new(15.0)), 214 | ].into_iter() 215 | .collect(); 216 | assert_eq!(constraint.solve(&bindings), SolveResult::Conflict); 217 | } 218 | } 219 | 220 | #[cfg(test)] 221 | mod symbolic_tests { 222 | use super::*; 223 | 224 | #[test] 225 | fn test_eq_returns_success() { 226 | let constraint: Constraint = Constraint::Symbolic(SymbolicConstraint::Eq { 227 | v1: "?x".to_string(), 228 | v2: "?y".to_string(), 229 | }); 230 | let bindings: Bindings = vec![ 231 | ("?x".to_string(), Number::new(3.0)), 232 | ("?y".to_string(), Number::new(3.0)), 233 | ].into_iter() 234 | .collect(); 235 | assert_eq!(constraint.solve(&bindings), SolveResult::Success(bindings)); 236 | } 237 | 238 | #[test] 239 | fn test_eq_returns_conflict() { 240 | let constraint: Constraint = 
Constraint::Symbolic(SymbolicConstraint::Eq { 241 | v1: "?x".to_string(), 242 | v2: "?y".to_string(), 243 | }); 244 | let bindings: Bindings = vec![ 245 | ("?x".to_string(), Number::new(2.0)), 246 | ("?y".to_string(), Number::new(3.0)), 247 | ].into_iter() 248 | .collect(); 249 | assert_eq!(constraint.solve(&bindings), SolveResult::Conflict); 250 | } 251 | 252 | #[test] 253 | fn test_eq_returns_partial_if_first_variable_is_undefined() { 254 | let constraint: Constraint = Constraint::Symbolic(SymbolicConstraint::Eq { 255 | v1: "?x".to_string(), 256 | v2: "?y".to_string(), 257 | }); 258 | let bindings: Bindings = vec![("?y".to_string(), Number::new(3.0))] 259 | .into_iter() 260 | .collect(); 261 | assert_eq!(constraint.solve(&bindings), SolveResult::Partial(bindings)); 262 | } 263 | 264 | #[test] 265 | fn test_eq_returns_partial_if_second_variable_is_undefined() { 266 | let constraint: Constraint = Constraint::Symbolic(SymbolicConstraint::Eq { 267 | v1: "?x".to_string(), 268 | v2: "?y".to_string(), 269 | }); 270 | let bindings: Bindings = vec![("?x".to_string(), Number::new(3.0))] 271 | .into_iter() 272 | .collect(); 273 | assert_eq!(constraint.solve(&bindings), SolveResult::Partial(bindings)); 274 | } 275 | 276 | #[test] 277 | fn test_neq_returns_success() { 278 | let constraint: Constraint = Constraint::Symbolic(SymbolicConstraint::Neq { 279 | v1: "?x".to_string(), 280 | v2: "?y".to_string(), 281 | }); 282 | let bindings: Bindings = vec![ 283 | ("?x".to_string(), Number::new(2.0)), 284 | ("?y".to_string(), Number::new(3.0)), 285 | ].into_iter() 286 | .collect(); 287 | assert_eq!(constraint.solve(&bindings), SolveResult::Success(bindings)); 288 | } 289 | 290 | #[test] 291 | fn test_neq_returns_conflict() { 292 | let constraint: Constraint = Constraint::Symbolic(SymbolicConstraint::Neq { 293 | v1: "?x".to_string(), 294 | v2: "?y".to_string(), 295 | }); 296 | let bindings: Bindings = vec![ 297 | ("?x".to_string(), Number::new(3.0)), 298 | ("?y".to_string(), 
Number::new(3.0)), 299 | ].into_iter() 300 | .collect(); 301 | assert_eq!(constraint.solve(&bindings), SolveResult::Conflict); 302 | } 303 | 304 | #[test] 305 | fn test_neq_returns_partial_if_first_variable_is_undefined() { 306 | let constraint: Constraint = Constraint::Symbolic(SymbolicConstraint::Neq { 307 | v1: "?x".to_string(), 308 | v2: "?y".to_string(), 309 | }); 310 | let bindings: Bindings = vec![("?y".to_string(), Number::new(3.0))] 311 | .into_iter() 312 | .collect(); 313 | assert_eq!(constraint.solve(&bindings), SolveResult::Partial(bindings)); 314 | } 315 | 316 | #[test] 317 | fn test_neq_returns_partial_if_second_variable_is_undefined() { 318 | let constraint: Constraint = Constraint::Symbolic(SymbolicConstraint::Neq { 319 | v1: "?x".to_string(), 320 | v2: "?y".to_string(), 321 | }); 322 | let bindings: Bindings = vec![("?x".to_string(), Number::new(3.0))] 323 | .into_iter() 324 | .collect(); 325 | assert_eq!(constraint.solve(&bindings), SolveResult::Partial(bindings)); 326 | } 327 | } 328 | -------------------------------------------------------------------------------- /src/core/mod.rs: -------------------------------------------------------------------------------- 1 | //! The core module contains the core data structures and traits used by all other modules. 2 | 3 | #[cfg(feature = "with-constraint")] 4 | use constraints; 5 | 6 | use serde::{Deserialize, Serialize}; 7 | use std::collections::{BTreeSet, HashMap}; 8 | use std::fmt::{Debug, Display, Formatter, Result}; 9 | use std::iter::{Extend, FromIterator}; 10 | 11 | /// A type must implement BindingsValue in order to be used as the value for some variable. 
12 | pub trait BindingsValue 13 | : Clone + Debug + Default + Deserialize + Display + Eq + PartialEq + PartialOrd + Serialize 14 | { 15 | /// Construct a BindingsValue variable using the specified string as it's name 16 | fn variable(_s: &String) -> Option { 17 | None 18 | } 19 | /// Extract the name of this BindingsValue variable (if it is a variable) 20 | fn to_variable(&self) -> Option { 21 | None 22 | } 23 | } 24 | 25 | /// Bindings is used for storing variables and operating on variables and their bindings. 26 | #[derive(Debug, Deserialize, Clone, Serialize)] 27 | pub struct Bindings { 28 | #[serde(default)] 29 | data: HashMap, 30 | #[serde(default)] 31 | equivalences: HashMap>, 32 | } 33 | 34 | impl PartialEq for Bindings { 35 | fn eq(&self, other: &Bindings) -> bool { 36 | self.data.eq(&other.data) && self.equivalences.eq(&other.equivalences) 37 | } 38 | } 39 | impl Eq for Bindings {} 40 | 41 | impl Default for Bindings { 42 | fn default() -> Self { 43 | Bindings { 44 | data: HashMap::new(), 45 | equivalences: HashMap::new(), 46 | } 47 | } 48 | } 49 | 50 | impl Bindings { 51 | pub fn new() -> Bindings { 52 | Bindings::default() 53 | } 54 | 55 | pub fn len(&self) -> usize { 56 | self.data.len() 57 | } 58 | 59 | pub fn has_binding(&self, variable: &String) -> bool { 60 | self.data.contains_key(variable) 61 | } 62 | 63 | pub fn set_binding(&self, variable: &String, val: T) -> Bindings { 64 | let mut bindings_copy = self.clone(); 65 | 66 | bindings_copy.set_binding_mut(variable, val); 67 | 68 | bindings_copy 69 | } 70 | 71 | fn set_binding_mut(&mut self, variable: &String, val: T) { 72 | self.ensure_equivalence_exists_mut(variable); 73 | 74 | if let Some(variable2) = val.to_variable() { 75 | self.add_equivalence(variable, &variable2); 76 | } else { 77 | for equivalent_variable in self.equivalences.get(variable).unwrap().iter() { 78 | self.data.insert(equivalent_variable.clone(), val.clone()); 79 | } 80 | } 81 | } 82 | 83 | fn add_equivalence(&mut self, variable: 
&String, variable2: &String) { 84 | self.ensure_equivalence_exists_mut(&variable2); 85 | self.merge_equivalences_mut(variable, &variable2); 86 | } 87 | 88 | fn ensure_equivalence_exists_mut(&mut self, variable: &String) { 89 | if !self.equivalences.contains_key(variable) { 90 | self.equivalences.insert(variable.clone(), 91 | vec![variable.clone()].into_iter().collect()); 92 | } 93 | } 94 | /// Merge the equivalence classes of `variable` and `variable2` and record the merged class for every one of its members; panics if either variable has no equivalence entry yet 95 | fn merge_equivalences_mut(&mut self, variable: &String, variable2: &String) { 96 | let mut merge = self.equivalences.get(variable).cloned().unwrap(); 97 | merge.extend(self.equivalences.get(variable2).cloned().unwrap()); 98 | self.equivalences.extend(merge.clone().into_iter().map(|member| (member, merge.clone()))); // keep every existing member of the merged class in sync, not only the two named variables (otherwise x~y then y~z leaves x with the stale class {x, y}) 99 | self.equivalences.insert(variable.clone(), merge.clone()); 100 | self.equivalences.insert(variable2.clone(), merge); 101 | } 102 | 103 | pub fn get_binding(&self, variable: &String) -> Option { 104 | match self.data.get(variable) { 105 | Some(val) => Some(val.clone()), 106 | None => None, 107 | } 108 | } 109 | 110 | pub fn update_bindings(&self, variable: &String, value: &T) -> Option { 111 | // If we are setting a variable to itself, then do nothing 112 | if Some(variable.clone()) == value.to_variable() { 113 | return Some(self.clone()); 114 | } 115 | match self.get_binding(&variable) { 116 | Some(ref val) if val == value => Some(self.clone()), 117 | Some(_) => None, 118 | None => Some(self.set_binding(variable, value.clone())), 119 | } 120 | } 121 | 122 | pub fn merge(&self, other: &Self) -> Self { 123 | let mut bindings = self.clone(); 124 | 125 | // Merge in equivalences 126 | for (ref key, ref equivalences) in other.equivalences.iter() { 127 | for equivalent_key in equivalences.iter() { 128 | bindings.ensure_equivalence_exists_mut(key); 129 | bindings.add_equivalence(key, equivalent_key); 130 | } 131 | } 132 | 133 | // Merge in values 134 | for (key, value) in other.data.iter() { 135 | bindings.set_binding_mut(&key, value.clone()) 136 | } 137 | 138 | bindings 139 | } 140 | 141 | pub fn equivalences_string(&self) -> String { 142 | let
equivalent_v: Vec = self.equivalences.iter().map(|(key, value)| format!("{} => {:?}", key, value)).collect(); 143 | equivalent_v.join(",") 144 | } 145 | } 146 | 147 | impl FromIterator<(String, T)> for Bindings { 148 | fn from_iter>(iter: I) -> Bindings { 149 | let mut bindings: Bindings = Bindings::new(); 150 | for (key, value) in iter { 151 | bindings.set_binding_mut(&key, value); 152 | } 153 | bindings 154 | } 155 | } 156 | 157 | impl Extend<(String, T)> for Bindings { 158 | fn extend(&mut self, iter: It) 159 | where It: IntoIterator 160 | { 161 | for (key, value) in iter { 162 | self.set_binding_mut(&key, value); 163 | } 164 | } 165 | } 166 | 167 | impl Display for Bindings { 168 | fn fmt(&self, f: &mut Formatter) -> Result { 169 | try!(write!(f, "(")); 170 | let mut sorted_keys: Vec = self.data.keys().cloned().collect(); 171 | sorted_keys.sort(); 172 | for key in sorted_keys.into_iter() { 173 | let ref val = self.data[&key]; 174 | try!(write!(f, "{} => {}, ", key, val)); 175 | } 176 | try!(write!(f, "Equivalences: {}", self.equivalences_string())); 177 | write!(f, ")") 178 | } 179 | } 180 | 181 | /// A type must implement the Unify trait for it to be unifiable - this only requirement for a data structure 182 | /// that the algorithms in this library operate on. 183 | pub trait Unify 184 | : Clone + Debug + Display + Eq + Serialize + Deserialize + PartialEq { 185 | /// Check if this structure can be unified with another of the same type. 
186 | fn unify(&self, &Self, &Bindings) -> Option>; 187 | 188 | /// Given some bindings, construct a new instance with any variables replaced by their values 189 | fn apply_bindings(&self, &Bindings) -> Option; 190 | 191 | /// Return all variables in this structure 192 | fn variables(&self) -> Vec; 193 | 194 | /// Rename any variables in this structure with another variable name 195 | fn rename_variables(&self, &HashMap) -> Self; 196 | 197 | /// Return a 'nil' sentinel value unique to this type of structure 198 | fn nil() -> Self; 199 | } 200 | 201 | /// A type that implements Operation constructs new Unifys from existing Unifys that match it's input patterns. 202 | pub trait Operation> 203 | : Clone + Debug + Display + Eq + PartialEq + Deserialize + Serialize { 204 | // NOTE: replace constraints with validate_bindings? 205 | #[cfg(feature = "with-constraint")] 206 | fn constraints<'a>(&'a self) -> Vec<&'a constraints::Constraint>; 207 | 208 | /// Creates a new instance of this Operation where all variables are unique 209 | fn snowflake(&self, String) -> Self; 210 | 211 | /// Return a vector of input patterns that must be unified with in order to apply this Operation. 
212 | fn input_patterns(&self) -> Vec { 213 | Vec::new() 214 | } 215 | 216 | /// Given some bindings, construct a set of output patterns 217 | fn apply_match(&self, _bindings: &Bindings) -> Option>; 218 | 219 | /// Given some bindings, construct a set of input patterns that would match 220 | fn r_apply_match(&self, _fact: &U) -> Option<(Vec, Bindings)>; 221 | } 222 | 223 | #[cfg(test)] 224 | mod tests; 225 | -------------------------------------------------------------------------------- /src/core/tests.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | mod bindings_tests { 3 | use core::{Bindings, BindingsValue}; 4 | use std::collections::HashMap; 5 | 6 | impl BindingsValue for String { 7 | fn variable(s: &String) -> Option { 8 | Some(s.clone()) 9 | } 10 | fn to_variable(&self) -> Option { 11 | if self.starts_with("?") { 12 | Some(self.clone()) 13 | } else { 14 | None 15 | } 16 | } 17 | } 18 | 19 | #[test] 20 | fn test_setting_variable_as_value_adds_to_equivalence() { 21 | let expected_bindings: Bindings = Bindings { 22 | data: HashMap::new(), 23 | equivalences: vec![("?y".to_string(), vec!["?y".to_string(), "?x".to_string()].into_iter().collect()), 24 | ("?x".to_string(), vec!["?y".to_string(), "?x".to_string()].into_iter().collect())] 25 | .into_iter() 26 | .collect(), 27 | }; 28 | let bindings: Bindings = Bindings::new().set_binding(&"?x".to_string(), "?y".to_string()); 29 | assert_eq!(bindings.equivalences, expected_bindings.equivalences); 30 | } 31 | 32 | #[test] 33 | fn test_setting_variable_sets_value_for_all_equivalents() { 34 | let expected_bindings = Bindings { 35 | data: vec![("?x".to_string(), "5.0".to_string()), ("?y".to_string(), "5.0".to_string())] 36 | .into_iter() 37 | .collect(), 38 | equivalences: vec![("?y".to_string(), vec!["?y".to_string(), "?x".to_string()].into_iter().collect()), 39 | ("?x".to_string(), vec!["?y".to_string(), "?x".to_string()].into_iter().collect())] 40 | .into_iter() 
41 | .collect(), 42 | }; 43 | let bindings = Bindings::new() 44 | .set_binding(&"?x".to_string(), "?y".to_string()) 45 | .set_binding(&"?x".to_string(), "5.0".to_string()); 46 | assert_eq!(bindings, expected_bindings); 47 | } 48 | 49 | #[test] 50 | fn test_merge() { 51 | let bindings = Bindings::new() 52 | .set_binding(&"?x".to_string(), "?y".to_string()) 53 | .set_binding(&"?x".to_string(), "5.0".to_string()); 54 | let bindings_2 = Bindings::new() 55 | .set_binding(&"?a".to_string(), "?b".to_string()) 56 | .set_binding(&"?a".to_string(), "10.0".to_string()); 57 | 58 | let expected_bindings = Bindings::new() 59 | .set_binding(&"?x".to_string(), "?y".to_string()) 60 | .set_binding(&"?x".to_string(), "5.0".to_string()) 61 | .set_binding(&"?a".to_string(), "?b".to_string()) 62 | .set_binding(&"?a".to_string(), "10.0".to_string()); 63 | assert_eq!(bindings.merge(&bindings_2), expected_bindings); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/datum/mod.rs: -------------------------------------------------------------------------------- 1 | //! The datum module provides a data structure, Datum, that implements the Unify trait. 2 | //! Datum aims to be a drop-in for any algorithm in ai_kit that operates on the Unify trait. 
3 | 4 | #[cfg(feature = "with-constraint")] 5 | use constraints::ConstraintValue; 6 | use core; 7 | use std; 8 | use std::collections::{BTreeMap, HashMap}; 9 | use std::str; 10 | use utils; 11 | 12 | #[derive(Clone, Debug, Serialize, Deserialize, PartialOrd)] 13 | pub enum Datum { 14 | #[serde(rename = "null")] Nil, 15 | #[serde(rename = "bool")] Bool(bool), 16 | #[serde(rename = "str")] String(String), 17 | #[serde(rename = "int")] Int(i64), 18 | #[serde(rename = "float")] Float(f64), 19 | #[serde(rename = "var")] Variable(String), 20 | #[serde(rename = "vec")] Vector(Vec), 21 | #[serde(rename = "map")] Map(BTreeMap), 22 | #[serde(rename = "fn")] Function { head: Box, args: Vec }, 23 | } 24 | 25 | impl Datum { 26 | pub fn to_bool(&self) -> Option { 27 | match *self { 28 | Datum::Bool(ref value) => Some(value.clone()), 29 | _ => None, 30 | } 31 | } 32 | 33 | pub fn to_string(&self) -> Option { 34 | match *self { 35 | Datum::String(ref value) => Some(value.clone()), 36 | _ => None, 37 | } 38 | } 39 | 40 | pub fn to_int(&self) -> Option { 41 | match *self { 42 | Datum::Int(ref value) => Some(value.clone()), 43 | _ => None, 44 | } 45 | } 46 | 47 | pub fn to_float(&self) -> Option { 48 | match *self { 49 | Datum::Float(ref f_value) => Some(f_value.clone()), 50 | Datum::Int(ref i_value) => Some(i_value.clone() as f64), 51 | _ => None, 52 | } 53 | } 54 | 55 | pub fn to_variable(&self) -> Option { 56 | match *self { 57 | Datum::Variable(ref value) => Some(value.clone()), 58 | _ => None, 59 | } 60 | } 61 | 62 | pub fn function_head<'a>(&'a self) -> Option<&'a Box> { 63 | match *self { 64 | Datum::Function { ref head, .. } => Some(head), 65 | _ => None, 66 | } 67 | } 68 | 69 | pub fn function_args<'a>(&'a self) -> Option<&'a Vec> { 70 | match *self { 71 | Datum::Function { ref args, .. 
} => Some(args), 72 | _ => None, 73 | } 74 | } 75 | 76 | pub fn pprint(&self) -> String { 77 | match *self { 78 | Datum::Nil => format!("nil"), 79 | Datum::Bool(ref b) => format!("{}", b), 80 | Datum::String(ref s) => format!("{}", s), 81 | Datum::Int(ref i) => format!("{}", i), 82 | Datum::Float(ref f) => format!("{}", f), 83 | Datum::Variable(ref v) => format!("{}", v), 84 | Datum::Vector(ref args) => { 85 | let elements: Vec = args.iter().map(|e| e.pprint()).collect(); 86 | format!("({})", elements.join(",")) 87 | } 88 | Datum::Map(ref args) => { 89 | let elements: Vec = args.iter() 90 | .map(|(k, v)| format!("{} => {}", k, v.pprint())) 91 | .collect(); 92 | format!("({})", elements.join(",")) 93 | } 94 | Datum::Function { ref head, ref args } => { 95 | let elements: Vec = args.iter().map(|e| e.pprint()).collect(); 96 | format!("({} ({}))", head, elements.join(",")) 97 | } 98 | } 99 | } 100 | 101 | pub fn is_nil(&self) -> bool { 102 | match *self { 103 | Datum::Nil => true, 104 | _ => false, 105 | } 106 | } 107 | 108 | pub fn is_bool(&self) -> bool { 109 | match *self { 110 | Datum::Bool(_) => true, 111 | _ => false, 112 | } 113 | } 114 | 115 | pub fn is_int(&self) -> bool { 116 | match *self { 117 | Datum::Int(_) => true, 118 | _ => false, 119 | } 120 | } 121 | 122 | pub fn is_float(&self) -> bool { 123 | match *self { 124 | Datum::Float(_) => true, 125 | _ => false, 126 | } 127 | } 128 | 129 | pub fn is_string(&self) -> bool { 130 | match *self { 131 | Datum::String(_) => true, 132 | _ => false, 133 | } 134 | } 135 | 136 | pub fn is_variable(&self) -> bool { 137 | match *self { 138 | Datum::Variable(_) => true, 139 | _ => false, 140 | } 141 | } 142 | 143 | pub fn is_vector(&self) -> bool { 144 | match *self { 145 | Datum::Vector(_) => true, 146 | _ => false, 147 | } 148 | } 149 | 150 | pub fn is_map(&self) -> bool { 151 | match *self { 152 | Datum::Map { .. 
} => true, 153 | _ => false, 154 | } 155 | } 156 | 157 | pub fn is_function(&self) -> bool { 158 | match *self { 159 | Datum::Function { .. } => true, 160 | _ => false, 161 | } 162 | } 163 | 164 | pub fn has_same_shape_as(&self, other: &Datum) -> bool { 165 | match (self, other) { 166 | (&Datum::Nil, &Datum::Nil) => true, 167 | (&Datum::Bool(_), &Datum::Bool(_)) => true, 168 | (&Datum::Int(_), &Datum::Int(_)) => true, 169 | (&Datum::Float(_), &Datum::Float(_)) => true, 170 | (&Datum::String(_), &Datum::String(_)) => true, 171 | (&Datum::Variable(_), &Datum::Variable(_)) => true, 172 | (&Datum::Vector(ref x), &Datum::Vector(ref y)) => x.len() == y.len(), 173 | (&Datum::Map(ref x), &Datum::Map(ref y)) => x.len() == y.len(), 174 | ( 175 | &Datum::Function { ref head, ref args }, 176 | &Datum::Function { 177 | head: ref head2, 178 | args: ref args2, 179 | }, 180 | ) => { 181 | head.has_same_shape_as(head2) 182 | && args.iter() 183 | .zip(args2.iter()) 184 | .all(|(x, y)| x.has_same_shape_as(y)) 185 | } 186 | _ => false, 187 | } 188 | } 189 | } 190 | 191 | impl Default for Datum { 192 | fn default() -> Self { 193 | Datum::Nil 194 | } 195 | } 196 | 197 | impl PartialEq for Datum { 198 | fn eq(&self, other: &Datum) -> bool { 199 | match *self { 200 | Datum::Nil => match *other { 201 | Datum::Nil => true, 202 | _ => false, 203 | }, 204 | Datum::Bool(b) => match *other { 205 | Datum::Bool(b2) => b == b2, 206 | _ => false, 207 | }, 208 | Datum::String(ref s) => match *other { 209 | Datum::String(ref s2) => s == s2, 210 | _ => false, 211 | }, 212 | Datum::Int(ref i) => match *other { 213 | Datum::Int(ref i2) => i == i2, 214 | _ => false, 215 | }, 216 | Datum::Float(ref f) => match *other { 217 | Datum::Float(ref f2) => f == f2, 218 | _ => false, 219 | }, 220 | Datum::Variable(ref v) => match *other { 221 | Datum::Variable(ref v2) => v == v2, 222 | _ => false, 223 | }, 224 | Datum::Vector(ref args) => match *other { 225 | Datum::Vector(ref args2) => args == args2, 226 | _ => 
false, 227 | }, 228 | Datum::Map(ref args) => match *other { 229 | Datum::Map(ref args2) => args == args2, 230 | _ => false, 231 | }, 232 | Datum::Function { ref head, ref args } => match *other { 233 | Datum::Function { 234 | head: ref head2, 235 | args: ref args2, 236 | } if args.len() == args2.len() => 237 | { 238 | head == head2 && args == args2 239 | } 240 | _ => false, 241 | }, 242 | } 243 | } 244 | } 245 | 246 | impl Eq for Datum {} 247 | impl Ord for Datum { 248 | fn cmp(&self, other: &Self) -> std::cmp::Ordering { 249 | self.partial_cmp(other).unwrap() 250 | } 251 | } 252 | 253 | impl core::BindingsValue for Datum { 254 | fn variable(s: &String) -> Option { 255 | Some(Datum::Variable(s.clone())) 256 | } 257 | fn to_variable(&self) -> Option { 258 | self.to_variable() 259 | } 260 | } 261 | 262 | #[cfg(feature = "with-constraint")] 263 | impl ConstraintValue for Datum { 264 | fn to_float(&self) -> Option { 265 | self.to_float() 266 | } 267 | fn float(c: f64) -> Self { 268 | Datum::Float(c) 269 | } 270 | } 271 | 272 | impl core::Unify for Datum { 273 | fn unify(&self, other: &Datum, bindings: &core::Bindings) -> Option> { 274 | fn unify_args(args: &Vec, args2: &Vec, bindings: &core::Bindings) -> Option> { 275 | utils::fold_while_some( 276 | bindings.clone(), 277 | &mut args.iter().zip(args2.iter()), 278 | &|bindings: core::Bindings, (ref a, ref b): (&Datum, &Datum)| a.unify(&b, &bindings), 279 | ) 280 | } 281 | 282 | fn unify_maps( 283 | args: &BTreeMap, 284 | args2: &BTreeMap, 285 | bindings: &core::Bindings, 286 | ) -> Option> { 287 | if args.len() != args2.len() { 288 | None 289 | } else { 290 | utils::fold_while_some( 291 | bindings.clone(), 292 | &mut args.iter(), 293 | &|bindings: core::Bindings, (ref k, ref v): (&String, &Datum)| args2.get(*k).and_then(|v2| v.unify(v2, &bindings)), 294 | ) 295 | } 296 | } 297 | match *self { 298 | Datum::Variable(ref var_name) => bindings.update_bindings(var_name, other), 299 | Datum::Vector(ref args) => match *other { 
300 | Datum::Vector(ref args2) if args.len() == args2.len() => unify_args(args, args2, &bindings), 301 | _ => None, 302 | }, 303 | Datum::Map(ref args) => match *other { 304 | Datum::Map(ref args2) if args.len() == args2.len() => unify_maps(args, args2, &bindings), 305 | _ => None, 306 | }, 307 | _ => match *other { 308 | Datum::Variable(ref var_name) => bindings.update_bindings(var_name, self), 309 | _ => { 310 | if self == other { 311 | Some(bindings.clone()) 312 | } else { 313 | None 314 | } 315 | } 316 | }, 317 | } 318 | } 319 | 320 | fn apply_bindings(&self, bindings: &core::Bindings) -> Option { 321 | fn apply_bindings_to_args(args: &Vec, bindings: &core::Bindings) -> Option> { 322 | utils::map_while_some(&mut args.iter(), &|arg| { 323 | arg.apply_bindings(bindings) 324 | }) 325 | } 326 | fn apply_bindings_to_map(args: &BTreeMap, bindings: &core::Bindings) -> Option> { 327 | utils::map_while_some(&mut args.iter(), &|(k, v)| { 328 | v.apply_bindings(bindings) 329 | .and_then(|v| Some((k.clone(), v))) 330 | }).and_then(|tuple_vec| Some(tuple_vec.into_iter().collect::>())) 331 | } 332 | match *self { 333 | Datum::Variable(ref var_name) => bindings 334 | .get_binding(var_name) 335 | .or_else(|| Some(self.clone())), 336 | Datum::Vector(ref args) => apply_bindings_to_args(args, bindings).and_then(|args| Some(Datum::Vector(args))), 337 | Datum::Map(ref args) => apply_bindings_to_map(args, bindings).and_then(|args| Some(Datum::Map(args))), 338 | Datum::Function { ref head, ref args } => head.apply_bindings(bindings).and_then(|head| { 339 | apply_bindings_to_args(args, bindings).map(|args| Datum::Function { 340 | head: Box::new(head), 341 | args: args, 342 | }) 343 | }), 344 | _ => Some(self.clone()), 345 | } 346 | } 347 | 348 | fn variables(&self) -> Vec { 349 | match *self { 350 | Datum::Variable(ref v) => vec![v.clone()], 351 | Datum::Vector(ref args) => { 352 | let mut variables = Vec::new(); 353 | for arg in args.iter() { 354 | 
variables.extend(arg.variables().into_iter()); 355 | } 356 | variables 357 | } 358 | Datum::Map(ref args) => { 359 | let mut variables = Vec::new(); 360 | for v in args.values() { 361 | variables.extend(v.variables().into_iter()); 362 | } 363 | variables 364 | } 365 | _ => Vec::new(), 366 | } 367 | } 368 | 369 | fn rename_variables(&self, renamed_variables: &HashMap) -> Self { 370 | match *self { 371 | Datum::Variable(ref v) => renamed_variables 372 | .get(v) 373 | .and_then(|new_v| Some(Datum::Variable(new_v.clone()))) 374 | .or_else(|| Some(self.clone())) 375 | .unwrap(), 376 | Datum::Vector(ref args) => Datum::Vector( 377 | args.iter() 378 | .map(|arg| arg.rename_variables(renamed_variables)) 379 | .collect(), 380 | ), 381 | Datum::Map(ref args) => Datum::Map( 382 | args.iter() 383 | .map(|(k, v)| (k.clone(), v.rename_variables(renamed_variables))) 384 | .collect(), 385 | ), 386 | _ => self.clone(), 387 | } 388 | } 389 | 390 | fn nil() -> Self { 391 | Datum::Nil 392 | } 393 | } 394 | 395 | impl std::fmt::Display for Datum { 396 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 397 | write!(f, "{}", self.pprint()) 398 | } 399 | } 400 | 401 | #[cfg(test)] 402 | mod tests; 403 | -------------------------------------------------------------------------------- /src/datum/tests.rs: -------------------------------------------------------------------------------- 1 | use super::super::core::{Bindings, Unify}; 2 | use datum::*; 3 | 4 | macro_rules! assert_some_value { 5 | ($x:expr, $y:expr) => (match $x { 6 | Some(val) => assert_eq!(val, $y), 7 | None => panic!("Expected value but received 'None'"), 8 | }) 9 | } 10 | 11 | macro_rules! assert_none { 12 | ($x:expr) => (match $x { 13 | None => (), 14 | Some(val) => panic!("Expected 'None' received {}", val), 15 | }) 16 | } 17 | 18 | macro_rules! 
datum_json { 19 | ($json: tt) => ({ 20 | use serde_json; 21 | let d: Datum = serde_json::from_value(json!($json)).expect("Expected json decoding"); 22 | d 23 | }) 24 | } 25 | 26 | #[test] 27 | fn test_unify_passes_when_variables_match() { 28 | let d = from_json!(Datum, { 29 | "vec": [{"str": "action"}, {"int": 1}, {"var": "?::t1"}] 30 | }); 31 | let bindings = Bindings::new().set_binding(&"?::t1".to_string(), Datum::Float(2.0)); 32 | assert_eq!(d.unify(&d, &bindings), Some(bindings)); 33 | } 34 | 35 | #[test] 36 | fn test_unify_passes_when_match_with_new_variable_in_self() { 37 | let d = from_json!(Datum, { 38 | "vec": [{"str": "isa"}, {"str": "socrates"}, {"var": "?x"}] 39 | }); 40 | let d2 = from_json!(Datum, { 41 | "vec": [{"str": "isa"}, {"str": "socrates"}, {"str": "man"}] 42 | }); 43 | let expected_bindings = Bindings::new().set_binding(&"?x".to_string(), Datum::String("man".to_string())); 44 | let actual_bindings = d.unify(&d2, &Bindings::new()); 45 | assert_some_value!(actual_bindings, expected_bindings); 46 | } 47 | 48 | #[test] 49 | fn test_unify_passes_when_match_with_new_variable_in_other() { 50 | let d = from_json!(Datum, { 51 | "vec": [{"str": "isa"}, {"str": "socrates"}, {"var": "?x"}] 52 | }); 53 | let d2 = from_json!(Datum, { 54 | "vec": [{"str": "isa"}, {"str": "socrates"}, {"str": "man"}] 55 | }); 56 | let expected_bindings = Bindings::new().set_binding(&"?x".to_string(), Datum::String("man".to_string())); 57 | let actual_bindings = d2.unify(&d, &Bindings::new()); 58 | assert_some_value!(actual_bindings, expected_bindings); 59 | } 60 | 61 | #[test] 62 | fn test_unify_passes_with_matching_vectors() { 63 | let d = datum_json!( 64 | {"vec": [{"vec": [{"str": "current-state"}, {"var": "?s1"}]}, 65 | {"vec": [{"str": "time"}, {"var": "?t1"}]}]}); 66 | let d2 = datum_json!({"vec": [ 67 | {"vec":[{"str": "current-state"}, {"float": 0}]}, 68 | {"vec":[{"str": "time"}, {"float": 0}]}, 69 | ]}); 70 | let expected_bindings = Bindings::new() 71 | 
.set_binding(&"?s1".to_string(), Datum::Float(0.0)) 72 | .set_binding(&"?t1".to_string(), Datum::Float(0.0)); 73 | let actual_bindings = d2.unify(&d, &Bindings::new()); 74 | assert_some_value!(actual_bindings, expected_bindings); 75 | } 76 | 77 | #[test] 78 | fn test_unify_passes_when_bindings_in_self_match() { 79 | let d = from_json!(Datum, { 80 | "vec": [{"str": "isa"}, {"str": "socrates"}, {"var": "?x"}] 81 | }); 82 | let d2 = from_json!(Datum, { 83 | "vec": [{"str": "isa"}, {"str": "socrates"}, {"str": "man"}] 84 | }); 85 | let bindings = Bindings::new().set_binding(&"?x".to_string(), Datum::String("man".to_string())); 86 | let actual_bindings = d.unify(&d2, &bindings); 87 | assert_some_value!(actual_bindings, bindings); 88 | } 89 | 90 | #[test] 91 | fn test_unify_passes_when_bindings_in_other_match() { 92 | let d = from_json!(Datum, { 93 | "vec": [{"str": "isa"}, {"str": "socrates"}, {"var": "?x"}] 94 | }); 95 | let d2 = from_json!(Datum, { 96 | "vec": [{"str": "isa"}, {"str": "socrates"}, {"str": "man"}] 97 | }); 98 | let bindings = Bindings::new().set_binding(&"?x".to_string(), Datum::String("man".to_string())); 99 | let actual_bindings = d2.unify(&d, &bindings); 100 | assert_some_value!(actual_bindings, bindings); 101 | } 102 | 103 | #[test] 104 | fn test_unify_fails_when_bindings_conflict() { 105 | let d = from_json!(Datum, { 106 | "vec": [{"str": "isa"}, {"str": "socrates"}, {"var": "?x"}] 107 | }); 108 | let d2 = from_json!(Datum, { 109 | "vec": [{"str": "isa"}, {"str": "socrates"}, {"str": "man"}] 110 | }); 111 | let bindings = Bindings::new().set_binding(&"?x".to_string(), Datum::String("mortal".to_string())); 112 | let actual_bindings = d.unify(&d2, &bindings); 113 | assert_none!(actual_bindings); 114 | } 115 | 116 | #[test] 117 | fn test_unify_with_nesting() { 118 | let d = from_json!(Datum, { 119 | "vec": [ 120 | {"str": "reward"}, 121 | {"vec": [{"str": "value"}, {"int": 5}]}, 122 | {"vec": [{"str": "time"}, {"int": 608356800}]}, 123 | {"vec": 
[{"str": "type"}, {"str": "observation"}]} 124 | ] 125 | }); 126 | let d2 = from_json!(Datum, { 127 | "vec": [ 128 | {"str": "reward"}, 129 | {"vec": [{"str": "value"}, {"var": "?rv"}]}, 130 | {"vec": [{"str": "time"}, {"var": "?t"}]}, 131 | {"vec": [{"str": "type"}, {"str": "observation"}]} 132 | ] 133 | }); 134 | let expected_bindings = Bindings::new() 135 | .set_binding(&"?t".to_string(), Datum::Int(608356800)) 136 | .set_binding(&"?rv".to_string(), Datum::Int(5)); 137 | assert_eq!(d.unify(&d2, &Bindings::new()), Some(expected_bindings)); 138 | } 139 | 140 | #[test] 141 | fn test_unify_fails_when_no_match() { 142 | let d = from_json!(Datum, { 143 | "vec": [{"str": "isa"}, {"str": "socrates"}, {"str": "mortal"}] 144 | }); 145 | let d2 = from_json!(Datum, { 146 | "vec": [{"str": "isa"}, {"str": "socrates"}, {"str": "man"}] 147 | }); 148 | let actual_bindings = d.unify(&d2, &Bindings::new()); 149 | assert_none!(actual_bindings); 150 | } 151 | -------------------------------------------------------------------------------- /src/infer/mod.rs: -------------------------------------------------------------------------------- 1 | //! The infer module implements basic forward chaining inference by applying any applicable Operations to a vector of Unifys. 
2 | 3 | use constraints::ConstraintValue; 4 | use core::{Bindings, BindingsValue, Operation, Unify}; 5 | use pedigree::{Origin, Pedigree, RenderType}; 6 | use planner::{ConjunctivePlanner, Goal, PlanningConfig}; 7 | use serde_json; 8 | use std; 9 | use std::collections::{BTreeMap, BTreeSet}; 10 | use std::collections::HashMap; 11 | use std::marker::PhantomData; 12 | use utils; 13 | 14 | #[derive(Clone, Debug, Deserialize, PartialEq, PartialOrd, Serialize)] 15 | pub struct Negatable> { 16 | content: U, 17 | #[serde(default)] 18 | is_negative: bool, 19 | #[serde(default)] 20 | _marker: PhantomData, 21 | } 22 | 23 | impl Eq for Negatable 24 | where 25 | B: BindingsValue, 26 | U: Unify, 27 | { 28 | } 29 | 30 | impl std::fmt::Display for Negatable 31 | where 32 | B: BindingsValue, 33 | U: Unify, 34 | { 35 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 36 | write!(f, "{}", serde_json::to_string(&self).unwrap()) 37 | } 38 | } 39 | 40 | impl Unify for Negatable 41 | where 42 | B: BindingsValue, 43 | U: Unify, 44 | { 45 | fn unify(&self, other: &Self, bindings: &Bindings) -> Option> { 46 | self.content.unify(&other.content, bindings) 47 | } 48 | fn apply_bindings(&self, bindings: &Bindings) -> Option { 49 | self.content 50 | .apply_bindings(bindings) 51 | .and_then(|bound_content| { 52 | Some(Negatable { 53 | content: bound_content, 54 | is_negative: self.is_negative, 55 | _marker: PhantomData, 56 | }) 57 | }) 58 | } 59 | fn variables(&self) -> Vec { 60 | self.content.variables() 61 | } 62 | fn rename_variables(&self, renamed_variables: &HashMap) -> Self { 63 | Negatable { 64 | content: self.content.rename_variables(renamed_variables), 65 | is_negative: self.is_negative, 66 | _marker: PhantomData, 67 | } 68 | } 69 | fn nil() -> Self { 70 | Negatable { 71 | content: U::nil(), 72 | is_negative: false, 73 | _marker: PhantomData, 74 | } 75 | } 76 | } 77 | 78 | #[derive(Clone, Debug, PartialEq)] 79 | pub struct OriginCache { 80 | items: BTreeSet, 81 | } 82 | 
83 | impl OriginCache { 84 | pub fn new() -> Self { 85 | OriginCache { 86 | items: BTreeSet::new(), 87 | } 88 | } 89 | 90 | pub fn has_item(&self, item: &Origin) -> bool { 91 | self.items.contains(item) 92 | } 93 | 94 | pub fn insert_item_mut(&mut self, item: Origin) { 95 | self.items.insert(item); 96 | } 97 | } 98 | 99 | #[derive(Clone, Debug, PartialEq)] 100 | pub struct InferenceEngine<'a, T, U, A> 101 | where 102 | T: 'a + ConstraintValue, 103 | U: 'a + Unify, 104 | A: 'a + Operation, 105 | { 106 | pub rules: Vec<(&'a String, &'a A)>, 107 | pub facts: Vec<(&'a String, &'a U)>, 108 | // Facts derived from this inference process 109 | pub derived_facts: Vec<(String, U)>, 110 | pub pedigree: Pedigree, 111 | pub prefix: String, 112 | // Used to check if an inference has already been performed, 113 | // allowing us to short-circuit a potentially expensive unification process. 114 | pub origin_cache: OriginCache, 115 | _marker: PhantomData, 116 | } 117 | 118 | impl<'a, T, U, A> InferenceEngine<'a, T, U, A> 119 | where 120 | T: 'a + ConstraintValue, 121 | U: 'a + Unify, 122 | A: 'a + Operation, 123 | { 124 | pub fn new(prefix: String, rules: Vec<(&'a String, &'a A)>, facts: Vec<(&'a String, &'a U)>) -> Self { 125 | InferenceEngine { 126 | rules: rules, 127 | facts: facts, 128 | derived_facts: Vec::new(), 129 | pedigree: Pedigree::new(), 130 | prefix: prefix, 131 | origin_cache: OriginCache::new(), 132 | _marker: PhantomData, 133 | } 134 | } 135 | 136 | pub fn all_facts(&'a self) -> Vec<(&'a String, &'a U)> { 137 | self.derived_facts 138 | .iter() 139 | .map(|&(ref id, ref f)| (id, f)) 140 | .chain(self.facts.iter().map(|&(id, f)| (id, f))) 141 | .collect() 142 | } 143 | 144 | pub fn chain_until_match(&self, max_iterations: usize, goal: &U) -> (Option<(U, String)>, Self) { 145 | self.chain_until(max_iterations, &|f| { 146 | goal.unify(f, &Bindings::new()).is_some() 147 | }) 148 | } 149 | 150 | pub fn chain_until(&self, max_iterations: usize, satisfied: &Fn(&U) -> bool) 
-> (Option<(U, String)>, Self) { 151 | let mut engine = self.clone(); 152 | let mut target: Option<(U, String)> = None; 153 | for _idx in 0..max_iterations { 154 | for (fact, _bindings, origin) in engine.chain_forward().into_iter() { 155 | let id = engine.construct_id(&fact); 156 | 157 | if satisfied(&fact) { 158 | target = Some((fact.clone(), id.clone())); 159 | } 160 | 161 | engine.pedigree.insert_mut(id.clone(), origin); 162 | engine.derived_facts.push((id, fact)); 163 | } 164 | if target.is_some() { 165 | break; 166 | } 167 | } 168 | (target, engine) 169 | } 170 | 171 | pub fn chain_forward(&mut self) -> Vec<(U, Bindings, Origin)> { 172 | let mut origin_cache = self.origin_cache.clone(); 173 | let results = chain_forward(self.all_facts(), self.rules.clone(), &mut origin_cache); 174 | self.origin_cache = origin_cache; 175 | results 176 | } 177 | 178 | fn construct_id(&self, _fact: &U) -> String { 179 | format!("{}-{}", self.prefix, self.derived_facts.len()) 180 | } 181 | 182 | pub fn render_inference_tree(&'a self, id: &String, render_type: RenderType) -> String { 183 | let all_facts_map: BTreeMap<&'a String, &'a U> = self.all_facts().into_iter().collect(); 184 | let rule_map: BTreeMap<&'a String, &'a A> = self.rules.clone().into_iter().collect(); 185 | 186 | let node_renderer = |x| { 187 | all_facts_map 188 | .get(&x) 189 | .and_then(|y| Some(format!("{}", y))) 190 | .or_else(|| rule_map.get(&x).and_then(|y| Some(format!("{}", y)))) 191 | .unwrap_or(format!("{}?", x)) 192 | }; 193 | 194 | self.pedigree.render_inference_tree( 195 | id, 196 | &node_renderer, 197 | &node_renderer, 198 | &|x, _y| x.clone(), 199 | render_type, 200 | ) 201 | } 202 | } 203 | 204 | pub fn chain_forward(facts: Vec<(&String, &U)>, rules: Vec<(&String, &A)>, origin_cache: &mut OriginCache) -> Vec<(U, Bindings, Origin)> 205 | where 206 | T: ConstraintValue, 207 | U: Unify, 208 | A: Operation, 209 | { 210 | let mut derived_facts: Vec<(U, Bindings, Origin)> = Vec::new(); 211 | let 
just_the_facts: Vec<&U> = facts.iter().map(|&(_id, u)| u).collect(); 212 | 213 | for (ref rule_id, ref rule) in rules.into_iter() { 214 | let planner: ConjunctivePlanner = ConjunctivePlanner::new( 215 | rule.input_patterns() 216 | .into_iter() 217 | .map(Goal::with_pattern) 218 | .collect(), 219 | &Bindings::new(), 220 | &PlanningConfig::default(), 221 | just_the_facts.clone(), 222 | Vec::new(), 223 | ); 224 | let application_successful = 225 | |(input_goals, bindings): (Vec>, Bindings)| -> Option<(Vec>, Vec, Bindings)> { 226 | let bound_input_goals: Vec> = input_goals 227 | .iter() 228 | .map(|input_goal| { 229 | input_goal 230 | .apply_bindings(&bindings) 231 | .expect("Should be applicable") 232 | }) 233 | .collect(); 234 | rule.apply_match(&bindings) 235 | .and_then(|new_facts| Some((bound_input_goals, new_facts, bindings))) 236 | }; 237 | 238 | for (matched_inputs, new_facts, bindings) in planner.filter_map(application_successful) { 239 | let fact_ids: Vec = extract_datum_indexes(&matched_inputs) 240 | .iter() 241 | .map(|idx| facts[*idx].0.clone()) 242 | .collect(); 243 | let origin = Origin { 244 | source_id: (*rule_id).clone(), 245 | args: fact_ids, 246 | }; 247 | if origin_cache.has_item(&origin) { 248 | continue; 249 | } else { 250 | origin_cache.insert_item_mut(origin.clone()); 251 | } 252 | for new_fact in new_facts { 253 | if is_new_fact(&new_fact, &facts) { 254 | derived_facts.push((new_fact, bindings.clone(), origin.clone())) 255 | } 256 | } 257 | } 258 | } 259 | derived_facts 260 | } 261 | 262 | pub fn chain_forward_with_negative_goals( 263 | facts: Vec<(&String, &Negatable)>, 264 | rules: Vec<(&String, &A)>, 265 | origin_cache: &mut OriginCache, 266 | ) -> Vec<(Negatable, Bindings, Origin)> 267 | where 268 | T: ConstraintValue, 269 | IU: Unify, 270 | A: Operation>, 271 | { 272 | let mut derived_facts: Vec<(Negatable, Bindings, Origin)> = Vec::new(); 273 | let just_the_facts: Vec<&Negatable> = facts.iter().map(|&(_id, u)| u).collect(); 274 | 275 | 
for (ref rule_id, ref rule) in rules.into_iter() { 276 | let (negative_inputs, positive_inputs): (Vec>, Vec>) = rule.input_patterns() 277 | .into_iter() 278 | .partition(|input| input.is_negative); 279 | let planner: ConjunctivePlanner, A> = ConjunctivePlanner::new( 280 | positive_inputs 281 | .into_iter() 282 | .map(Goal::with_pattern) 283 | .collect(), 284 | &Bindings::new(), 285 | &PlanningConfig::default(), 286 | just_the_facts.clone().into_iter().collect(), 287 | Vec::new(), 288 | ); 289 | 290 | let negative_patterns_are_satisfied = |(input_goals, bindings)| { 291 | utils::map_while_some(&mut negative_inputs.iter(), &|pattern| { 292 | pattern.apply_bindings(&bindings) 293 | }).and_then(|bound_negative_patterns| { 294 | if any_patterns_match(&bound_negative_patterns.iter().collect(), &just_the_facts) { 295 | None 296 | } else { 297 | Some((input_goals, bindings)) 298 | } 299 | }) 300 | }; 301 | let application_successful = |(input_goals, bindings)| { 302 | rule.apply_match(&bindings) 303 | .and_then(|new_facts| Some((input_goals, new_facts, bindings))) 304 | }; 305 | 306 | for (matched_inputs, new_facts, bindings) in planner 307 | .filter_map(negative_patterns_are_satisfied) 308 | .filter_map(application_successful) 309 | { 310 | let fact_ids: Vec = extract_datum_indexes(&matched_inputs) 311 | .iter() 312 | .map(|idx| facts[*idx].0.clone()) 313 | .collect(); 314 | let origin = Origin { 315 | source_id: (*rule_id).clone(), 316 | args: fact_ids, 317 | }; 318 | if origin_cache.has_item(&origin) { 319 | continue; 320 | } else { 321 | origin_cache.insert_item_mut(origin.clone()); 322 | } 323 | for new_fact in new_facts { 324 | if is_new_fact(&new_fact, &facts) { 325 | derived_facts.push((new_fact, bindings.clone(), origin.clone())) 326 | } 327 | } 328 | } 329 | } 330 | derived_facts 331 | } 332 | 333 | fn any_patterns_match(patterns: &Vec<&U>, patterns2: &Vec<&U>) -> bool 334 | where 335 | B: BindingsValue, 336 | U: Unify, 337 | { 338 | let empty_bindings: Bindings 
= Bindings::new(); 339 | patterns.iter().any(|patt| { 340 | patterns2 341 | .iter() 342 | .any(|f| f.unify(patt, &empty_bindings).is_some()) 343 | }) 344 | } 345 | 346 | fn extract_datum_indexes(goals: &Vec>) -> Vec 347 | where 348 | T: ConstraintValue, 349 | U: Unify, 350 | A: Operation, 351 | { 352 | goals 353 | .iter() 354 | .map(|goal| { 355 | goal.unification_index 356 | .datum_idx() 357 | .expect("Only datum idx should be here!") 358 | }) 359 | .collect() 360 | } 361 | 362 | fn is_new_fact(f: &U, facts: &Vec<(&String, &U)>) -> bool 363 | where 364 | T: ConstraintValue, 365 | U: Unify, 366 | { 367 | let empty_bindings = Bindings::new(); 368 | !facts 369 | .iter() 370 | .any(|&(_id, fact)| fact.unify(f, &empty_bindings).is_some()) 371 | } 372 | 373 | #[cfg(test)] 374 | mod tests; 375 | -------------------------------------------------------------------------------- /src/infer/tests.rs: -------------------------------------------------------------------------------- 1 | use core::Bindings; 2 | use datum::Datum; 3 | use infer::{chain_forward_with_negative_goals, InferenceEngine, Negatable, OriginCache}; 4 | use pedigree::{InferenceChain, Origin}; 5 | use rule::Rule; 6 | 7 | #[test] 8 | fn test_forward_chain() { 9 | let r_id = "rule-0".to_string(); 10 | let r = from_json!(Rule, { 11 | "lhs": [{"vec": [{"str": "has-features"}, {"var": "?x"}]}], 12 | "rhs": {"vec": [{"str": "bird"}, {"var": "?x"}]}, 13 | }); 14 | let rules: Vec<(&String, &Rule)> = vec![(&r_id, &r)]; 15 | 16 | let f_id = "fact-0".to_string(); 17 | let f = from_json!(Datum, {"vec": [{"str": "has-features"}, {"str": "bonnie"}]}); 18 | let facts: Vec<(&String, &Datum)> = vec![(&f_id, &f)]; 19 | 20 | let mut engine = InferenceEngine::new("test".to_string(), rules, facts); 21 | let new_facts = engine.chain_forward(); 22 | 23 | let expected_new_fact = from_json!(Datum, {"vec": [{"str": "bird"}, {"str": "bonnie"}]}); 24 | let expected_bindings = Bindings::new().set_binding(&"?x".to_string(), 
Datum::String("bonnie".to_string())); 25 | 26 | assert_eq!(new_facts.len(), 1); 27 | assert_eq!( 28 | new_facts, 29 | vec![ 30 | ( 31 | expected_new_fact, 32 | expected_bindings, 33 | Origin { 34 | source_id: "rule-0".to_string(), 35 | args: vec!["fact-0".to_string()], 36 | }, 37 | ), 38 | ] 39 | ); 40 | } 41 | 42 | #[test] 43 | fn test_chain_until_match() { 44 | let rules = from_json!(Vec>, [ 45 | { 46 | "lhs": [{"vec": [{"str": "current-value"}, {"var": "?x"}]}], 47 | "rhs": {"vec": [{"str": "current-value"}, {"var": "?y"}]}, 48 | "constraints": [ 49 | {"numerical": {"set": {"variable": "?diff", "constant": 1}}}, 50 | {"numerical": {"sum": {"first": "?x", "second": "?diff", "third": "?y"}}} 51 | ] 52 | }, 53 | { 54 | "lhs": [{"vec": [{"str": "current-value"}, {"var": "?x"}]}], 55 | "rhs": {"vec": [{"str": "current-value"}, {"var": "?y"}]}, 56 | "constraints": [ 57 | {"numerical": {"set": {"variable": "?diff", "constant": -1}}}, 58 | {"numerical": {"sum": {"first": "?x", "second": "?diff", "third": "?y"}}} 59 | ] 60 | }, 61 | { 62 | "lhs": [{"vec": [{"str": "current-value"}, {"var": "?x"}]}], 63 | "rhs": {"vec": [{"str": "current-value"}, {"var": "?y"}]}, 64 | "constraints": [ 65 | {"numerical": {"set": {"variable": "?factor", "constant": 2}}}, 66 | {"numerical": {"mul": {"first": "?x", "second": "?factor", "third": "?y"}}} 67 | ] 68 | }, 69 | { 70 | "lhs": [{"vec": [{"str": "current-value"}, {"var": "?x"}]}], 71 | "rhs": {"vec": [{"str": "current-value"}, {"var": "?y"}]}, 72 | "constraints": [ 73 | {"numerical": {"set": {"variable": "?factor", "constant": 0.5}}}, 74 | {"numerical": {"mul": {"first": "?x", "second": "?factor", "third": "?y"}}} 75 | ] 76 | } 77 | ]); 78 | 79 | let rule_ids = vec![ 80 | "add_one".to_string(), 81 | "subtract_one".to_string(), 82 | "double".to_string(), 83 | "halve".to_string(), 84 | ]; 85 | 86 | let rules: Vec<(&String, &Rule)> = rule_ids.iter().zip(rules.iter()).collect(); 87 | 88 | let f_id = "fact-0".to_string(); 89 | let f = 
from_json!(Datum, {"vec": [{"str": "current-value"}, {"float": 0.0}]}); 90 | let facts: Vec<(&String, &Datum)> = vec![(&f_id, &f)]; 91 | 92 | let goal = from_json!(Datum, {"vec": [{"str": "current-value"}, {"float": 4.0}]}); 93 | let engine = InferenceEngine::new("test".to_string(), rules, facts); 94 | let (result, _engine) = engine.chain_until_match(4, &goal); 95 | assert_eq!(result.is_some(), true); 96 | let (target_fact, _target_fact_id) = result.unwrap(); 97 | assert_eq!(target_fact, goal); 98 | } 99 | 100 | #[test] 101 | fn test_chain_until_match_updates_pedigree() { 102 | let rules = from_json!(Vec>, [ 103 | { 104 | "lhs": [{"vec": [{"str": "current-value"}, {"var": "?x"}]}], 105 | "rhs": {"vec": [{"str": "current-value"}, {"var": "?y"}]}, 106 | "constraints": [ 107 | {"numerical": {"set": {"variable": "?diff", "constant": 1}}}, 108 | {"numerical": {"sum": {"first": "?x", "second": "?diff", "third": "?y"}}} 109 | ] 110 | }, 111 | { 112 | "lhs": [{"vec": [{"str": "current-value"}, {"var": "?x"}]}], 113 | "rhs": {"vec": [{"str": "current-value"}, {"var": "?y"}]}, 114 | "constraints": [ 115 | {"numerical": {"set": {"variable": "?factor", "constant": 2}}}, 116 | {"numerical": {"mul": {"first": "?x", "second": "?factor", "third": "?y"}}} 117 | ] 118 | } 119 | ]); 120 | 121 | let add_one_id = "add_one".to_string(); 122 | let double_id = "double".to_string(); 123 | let rule_ids = vec![add_one_id.clone(), double_id.clone()]; 124 | 125 | let rules: Vec<(&String, &Rule)> = rule_ids.iter().zip(rules.iter()).collect(); 126 | 127 | let f_id = "fact-0".to_string(); 128 | let f = from_json!(Datum, {"vec": [{"str": "current-value"}, {"float": 0}]}); 129 | let facts: Vec<(&String, &Datum)> = vec![(&f_id, &f)]; 130 | 131 | let goal = from_json!(Datum, {"vec": [{"str": "current-value"}, {"float": 4}]}); 132 | let engine = InferenceEngine::new("test".to_string(), rules, facts); 133 | let (result, engine) = engine.chain_until_match(4, &goal); 134 | assert_eq!(result.is_some(), 
true); 135 | let (target_fact, target_fact_id) = result.unwrap(); 136 | assert_eq!(target_fact, goal); 137 | 138 | let test_id_0 = "test-0".to_string(); 139 | let test_0_origin = Origin { 140 | source_id: "add_one".to_string(), 141 | args: vec!["fact-0".to_string()], 142 | }; 143 | 144 | let test_id_2 = "test-2".to_string(); 145 | let test_2_origin = Origin { 146 | source_id: "double".to_string(), 147 | args: vec!["test-0".to_string()], 148 | }; 149 | 150 | let test_id_6 = "test-6".to_string(); 151 | let test_6_origin = Origin { 152 | source_id: "double".to_string(), 153 | args: vec!["test-2".to_string()], 154 | }; 155 | 156 | let inference_chain = engine.pedigree.extract_inference_chain(&target_fact_id); 157 | let expected_inference_chain = InferenceChain { 158 | elements: vec![ 159 | (vec![(test_id_6.clone(), Some(test_6_origin.clone()))]), 160 | (vec![ 161 | (double_id.clone(), None), 162 | (test_id_2.clone(), Some(test_2_origin.clone())), 163 | ]), 164 | (vec![ 165 | (double_id.clone(), None), 166 | (test_id_0.clone(), Some(test_0_origin.clone())), 167 | ]), 168 | (vec![(add_one_id.clone(), None), (f_id.clone(), None)]), 169 | ], 170 | }; 171 | assert_eq!(inference_chain, expected_inference_chain); 172 | } 173 | 174 | #[test] 175 | fn test_chain_forward_with_negative_goals() { 176 | let rules = from_json!(Vec>>, [ 177 | { 178 | "lhs": [ 179 | {"content": {"vec": [{"str": "current-value"}, {"var": "?x"}]}}, 180 | {"content": {"vec": [{"str": "even-value"}, {"var": "?x"}]}, "is_negative": true}, 181 | ], 182 | "rhs": {"content": {"vec": [{"str": "odd-value"}, {"var": "?x"}]}}, 183 | }, 184 | ]); 185 | let data = from_json!(Vec>, [ 186 | {"content": {"vec": [{"str": "current-value"}, {"int": 0}]}}, 187 | {"content": {"vec": [{"str": "even-value"}, {"int": 0}]}}, 188 | {"content": {"vec": [{"str": "current-value"}, {"int": 1}]}}, 189 | {"content": {"vec": [{"str": "current-value"}, {"int": 2}]}}, 190 | {"content": {"vec": [{"str": "even-value"}, {"int": 2}]}}, 191 
| {"content": {"vec": [{"str": "current-value"}, {"int": 3}]}}, 192 | ]); 193 | let expected_derived_facts = from_json!(Vec>, [ 194 | {"content": {"vec": [{"str": "odd-value"}, {"int": 1}]}}, 195 | {"content": {"vec": [{"str": "odd-value"}, {"int": 3}]}}, 196 | ]); 197 | 198 | let data_ids = vec![ 199 | "d0".to_string(), 200 | "d1".to_string(), 201 | "d2".to_string(), 202 | "d3".to_string(), 203 | "d4".to_string(), 204 | "d5".to_string(), 205 | ]; 206 | let data_with_ids: Vec<(&String, &Negatable)> = data_ids.iter().zip(data.iter()).collect(); 207 | let rule_ids = vec!["r0".to_string()]; 208 | let rules_with_ids: Vec<(&String, &Rule>)> = rule_ids.iter().zip(rules.iter()).collect(); 209 | let results = chain_forward_with_negative_goals(data_with_ids, rules_with_ids, &mut OriginCache::new()); 210 | let derived_facts: Vec> = results.into_iter().map(|(f, _, _)| f).collect(); 211 | assert_eq!(derived_facts, expected_derived_facts); 212 | } 213 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! AI_Kit aims to be a single dependency for various classic AI algorithms. 2 | //! 3 | //! Core project goals are: 4 | //! 5 | //! * convenient and ergonomic interfaces to various algorithms by building around traits. 6 | //! 7 | //! * only build what you need through the use of feature flags 8 | //! 9 | //! * performance 10 | //! 11 | //! * easy to understand implementations 12 | //! 13 | //! All of the algorithms (documented below) operate on several core traits, `BindingsValue`, `Unify`, `Operation`. 14 | //! 
15 | 16 | // Needed for tests 17 | #![recursion_limit="128"] 18 | #![deny(missing_debug_implementations, missing_copy_implementations, 19 | trivial_casts, trivial_numeric_casts, 20 | unsafe_code, 21 | unstable_features, 22 | unused_import_braces)] 23 | 24 | extern crate itertools; 25 | extern crate serde; 26 | #[allow(unused_imports)] 27 | #[macro_use] 28 | extern crate serde_json; 29 | #[macro_use] 30 | extern crate serde_derive; 31 | extern crate uuid; 32 | 33 | #[cfg(test)] 34 | #[macro_use] 35 | mod test_utils; 36 | #[cfg(feature = "with-constraint")] 37 | pub mod constraints; 38 | #[macro_use] 39 | pub mod core; 40 | #[cfg(feature = "with-datum")] 41 | pub mod datum; 42 | #[cfg(feature = "with-forward-inference")] 43 | pub mod infer; 44 | #[cfg(feature = "with-pedigree")] 45 | pub mod pedigree; 46 | #[cfg(feature = "with-planner")] 47 | pub mod planner; 48 | #[cfg(feature = "with-rule")] 49 | pub mod rule; 50 | pub mod utils; 51 | -------------------------------------------------------------------------------- /src/pedigree/mod.rs: -------------------------------------------------------------------------------- 1 | //! The pedigree module implements functionality for tracking which Unify and Operation structures were used to derive a new Unify. 
2 | 3 | use itertools::Itertools; 4 | use serde_json; 5 | use std; 6 | use std::collections::btree_map::BTreeMap; 7 | use std::collections::btree_set::BTreeSet; 8 | 9 | // Describes how a given inference tree should be rendered 10 | #[derive(Clone, Copy, Debug)] 11 | pub enum RenderType { 12 | // Render all inferences that were derived in the process of deriving the specified inference 13 | Full, 14 | // Just render the direct ancestry for the specified inference 15 | Pedigree, 16 | } 17 | 18 | /// Represent the origin of a particular Unify 19 | #[derive(Clone, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)] 20 | pub struct Origin { 21 | /// What Operation does this Origin correspond to 22 | pub source_id: String, 23 | /// What data did the source use to construct the entity that this Origin corresponds to 24 | pub args: Vec, 25 | } 26 | 27 | impl std::fmt::Display for Origin { 28 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 29 | write!(f, "{}", serde_json::to_string(&self).unwrap()) 30 | } 31 | } 32 | 33 | impl Origin { 34 | pub fn new() -> Self { 35 | Origin { 36 | source_id: String::new(), 37 | args: Vec::new(), 38 | } 39 | } 40 | 41 | pub fn with_source(source: String) -> Self { 42 | Origin { 43 | source_id: source, 44 | args: Vec::new(), 45 | } 46 | } 47 | 48 | pub fn ancestors(&self) -> Vec<&String> { 49 | self.args 50 | .iter() 51 | .chain(std::iter::once(&self.source_id)) 52 | .collect() 53 | } 54 | } 55 | 56 | /// Used for iterating through the ancestry for a given inference 57 | #[derive(Clone, Debug, Eq, PartialEq)] 58 | pub struct InferenceGraphBackwardIterator<'a> { 59 | inf_graph: &'a InferenceGraph, 60 | next_generation: Option, 61 | } 62 | 63 | impl<'a> Iterator for InferenceGraphBackwardIterator<'a> { 64 | type Item = ( 65 | Vec<(&'a String, Option<&'a Origin>)>, 66 | Vec<&'a BTreeSet>, 67 | ); 68 | 69 | fn next( 70 | &mut self, 71 | ) -> Option< 72 | ( 73 | Vec<(&'a String, Option<&'a Origin>)>, 74 | Vec<&'a 
BTreeSet>, 75 | ), 76 | > { 77 | let current_generation = self.next_generation; 78 | 79 | self.next_generation = self.next_generation 80 | .and_then(|idx| if idx == 0 { None } else { Some(idx - 1) }); 81 | 82 | let construct_id_origin_tuple = |current_id| (current_id, self.inf_graph.pedigree.get_ancestor(current_id)); 83 | 84 | let construct_id_origin_tuples_for_generation = |generation: &'a BTreeSet| generation.iter().map(construct_id_origin_tuple).collect(); 85 | 86 | current_generation.and_then(|generation_idx| { 87 | self.inf_graph 88 | .entries_by_generation 89 | .get(generation_idx) 90 | .and_then(|generation| { 91 | Some(( 92 | construct_id_origin_tuples_for_generation(generation), 93 | self.inf_graph.subsequent_inferences(generation_idx + 1), 94 | )) 95 | }) 96 | }) 97 | } 98 | } 99 | 100 | /// Provide a convenient interface to a particular inference graph. 101 | #[derive(Clone, Debug, Eq, PartialEq)] 102 | pub struct InferenceGraph { 103 | pedigree: Pedigree, 104 | /// Track which generation each unify was created in; the root ocurrs in the last generation 105 | entries_by_generation: Vec>, 106 | /// Inverse of entries_by_generation; map each id into the generation it was introduced in 107 | entries_to_generation: BTreeMap, 108 | /// The endpoint data for the inference graph 109 | /// For forward chaining, this is the initial data used in the inference. 110 | /// For backward chaining, this is the final grounded results. 
111 | leaves: BTreeSet, 112 | /// The goal of the infernece 113 | root: String, 114 | } 115 | 116 | impl<'a> InferenceGraph { 117 | pub fn new(root: String) -> Self { 118 | InferenceGraph { 119 | pedigree: Pedigree::new(), 120 | entries_by_generation: Vec::new(), 121 | entries_to_generation: BTreeMap::new(), 122 | leaves: BTreeSet::new(), 123 | root: root, 124 | } 125 | } 126 | 127 | pub fn back_iter(&self) -> InferenceGraphBackwardIterator { 128 | InferenceGraphBackwardIterator { 129 | inf_graph: self, 130 | next_generation: Some(self.entries_by_generation.len() - 1), 131 | } 132 | } 133 | 134 | pub fn root(&self) -> &String { 135 | &self.root 136 | } 137 | 138 | pub fn leaves(&'a self) -> &'a BTreeSet { 139 | &self.leaves 140 | } 141 | 142 | pub fn ancestor(&self, id: &String) -> Option<&Origin> { 143 | self.pedigree.get_ancestor(id) 144 | } 145 | 146 | pub fn descendent_inferences(&'a self, id: &String) -> Option<&'a BTreeSet> { 147 | // Return unify derived from this one 148 | self.pedigree.get_descendents(id) 149 | } 150 | 151 | pub fn subsequent_inferences(&'a self, generation: usize) -> Vec<&'a BTreeSet> { 152 | // Return all unify derived in and after the specified generation 153 | let mut subsequent_inferences = Vec::new(); 154 | 155 | for entries in self.entries_by_generation.iter().skip(generation) { 156 | subsequent_inferences.push(entries); 157 | } 158 | 159 | subsequent_inferences 160 | } 161 | 162 | pub fn all_ids(&'a self) -> BTreeSet<&'a String> { 163 | self.entries_to_generation.keys().collect() 164 | } 165 | } 166 | 167 | #[derive(Clone, Debug, Eq, PartialEq)] 168 | pub struct InferenceGraphBuilder { 169 | pedigree: Pedigree, 170 | entries_by_generation: Vec>, 171 | entries_to_generation: BTreeMap, 172 | leaves: BTreeSet, 173 | root: String, 174 | } 175 | 176 | impl InferenceGraphBuilder { 177 | pub fn new() -> Self { 178 | InferenceGraphBuilder { 179 | pedigree: Pedigree::new(), 180 | entries_by_generation: Vec::new(), 181 | 
entries_to_generation: BTreeMap::new(), 182 | leaves: BTreeSet::new(), 183 | root: String::new(), 184 | } 185 | } 186 | 187 | pub fn finalize(self) -> InferenceGraph { 188 | InferenceGraph { 189 | pedigree: self.pedigree, 190 | entries_by_generation: self.entries_by_generation, 191 | entries_to_generation: self.entries_to_generation, 192 | leaves: self.leaves, 193 | root: self.root, 194 | } 195 | } 196 | 197 | pub fn pedigree(self, pedigree: Pedigree) -> Self { 198 | let mut igraph = self.clone(); 199 | igraph.pedigree = pedigree; 200 | igraph 201 | } 202 | 203 | pub fn update_pedigree(self, id: String, origin: Origin) -> Self { 204 | let mut igraph = self.clone(); 205 | igraph.pedigree = igraph.pedigree.insert(id, origin); 206 | igraph 207 | } 208 | 209 | pub fn entries_by_generation(self, entries_by_generation: Vec>) -> Self { 210 | let mut igraph = self.clone(); 211 | igraph.entries_by_generation = entries_by_generation; 212 | igraph 213 | } 214 | 215 | pub fn extend_entries_by_generation(self, generation_idx: usize, entries: Vec) -> Self { 216 | let mut entries_by_generation = self.entries_by_generation.clone(); 217 | let new_entries: BTreeSet = entries.into_iter().collect(); 218 | 219 | if entries_by_generation.len() <= generation_idx { 220 | entries_by_generation.push(new_entries); 221 | } else { 222 | if let Some(generation) = entries_by_generation.get_mut(generation_idx) { 223 | generation.extend(new_entries); 224 | } 225 | } 226 | 227 | let mut igraph = self.clone(); 228 | igraph.entries_by_generation = entries_by_generation; 229 | igraph 230 | } 231 | 232 | pub fn entries_to_generation(self, entries_to_generation: BTreeMap) -> Self { 233 | let mut igraph = self.clone(); 234 | igraph.entries_to_generation = entries_to_generation; 235 | igraph 236 | } 237 | 238 | pub fn leaves(self, leaves: BTreeSet) -> Self { 239 | let mut igraph = self.clone(); 240 | igraph.leaves = leaves; 241 | igraph 242 | } 243 | 244 | pub fn update_leaves(self, id: String) -> Self { 
245 | let mut igraph = self.clone(); 246 | igraph.leaves.insert(id); 247 | igraph 248 | } 249 | 250 | pub fn root(self, root: String) -> Self { 251 | let mut igraph = self.clone(); 252 | igraph.root = root; 253 | igraph 254 | } 255 | } 256 | 257 | #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] 258 | pub struct InferenceChain { 259 | pub elements: Vec<(Vec<(String, Option)>)>, 260 | } 261 | 262 | impl std::fmt::Display for InferenceChain { 263 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 264 | write!(f, "{}", serde_json::to_string(&self).unwrap()) 265 | } 266 | } 267 | 268 | #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] 269 | pub struct Pedigree { 270 | ancestors: BTreeMap, 271 | descendents: BTreeMap>, 272 | } 273 | 274 | impl Pedigree { 275 | pub fn new() -> Self { 276 | Pedigree { 277 | ancestors: BTreeMap::new(), 278 | descendents: BTreeMap::new(), 279 | } 280 | } 281 | 282 | pub fn insert(&self, id: String, org: Origin) -> Self { 283 | let mut pedigree = self.clone(); 284 | pedigree.insert_mut(id, org); 285 | pedigree 286 | } 287 | 288 | pub fn insert_mut(&mut self, id: String, org: Origin) { 289 | let source_id = org.source_id.clone(); 290 | self.ancestors.insert(id.clone(), org); 291 | 292 | if !self.descendents.contains_key(&source_id) { 293 | self.descendents.insert(source_id.clone(), BTreeSet::new()); 294 | } 295 | let inner_descendents = self.descendents.get_mut(&source_id).unwrap(); 296 | inner_descendents.insert(id); 297 | } 298 | 299 | pub fn get_ancestor(&self, id: &String) -> Option<&Origin> { 300 | self.ancestors.get(id) 301 | } 302 | 303 | pub fn get_descendents(&self, id: &String) -> Option<&BTreeSet> { 304 | self.descendents.get(id) 305 | } 306 | 307 | /// Remove all reference to the specified id 308 | pub fn purge(&self, id: &String) -> Self { 309 | let mut pedigree = self.clone(); 310 | pedigree.purge_mut(id); 311 | pedigree 312 | } 313 | 314 | pub fn purge_mut(&mut self, id: &String) { 315 
| self.ancestors.remove(id); 316 | 317 | let mut descendents = self.descendents.clone(); 318 | 319 | for (ancestor_id, these_descendents) in self.descendents.iter() { 320 | if these_descendents.contains(id) { 321 | let mut these_descendents = these_descendents.clone(); 322 | these_descendents.remove(id); 323 | descendents.insert(ancestor_id.clone(), these_descendents); 324 | } 325 | } 326 | 327 | self.descendents = descendents; 328 | } 329 | 330 | pub fn extract_inference_chain(&self, root: &String) -> InferenceChain { 331 | InferenceChain { 332 | elements: self.extract_inference_graph(root) 333 | .back_iter() 334 | .map(|(generation, _subsequent_generations)| { 335 | generation 336 | .into_iter() 337 | .map(|(ref id, ref origin)| ((*id).clone(), (*origin).cloned())) 338 | .collect() 339 | }) 340 | .collect(), 341 | } 342 | } 343 | 344 | pub fn extract_inference_graph(&self, root: &String) -> InferenceGraph { 345 | let builder = InferenceGraphBuilder::new().root(root.clone()); 346 | 347 | // Note: we construct the entries_by_generation vector in reverse since we don't know in advance the how big it will be 348 | let mut builder = self.extract_inference_graph_helper(builder, vec![root], 0); 349 | 350 | // Reverse the order of the entries_by_generation so initial inferences are first and the root is last 351 | builder.entries_by_generation.reverse(); 352 | let mut entries_to_generation: BTreeMap = BTreeMap::new(); 353 | for (idx, entries) in builder.entries_by_generation.iter().enumerate() { 354 | for entry in entries.iter() { 355 | entries_to_generation.insert(entry.clone(), idx); 356 | } 357 | } 358 | builder.entries_to_generation = entries_to_generation; 359 | builder.finalize() 360 | } 361 | 362 | fn extract_inference_graph_helper( 363 | &self, 364 | builder: InferenceGraphBuilder, 365 | current_generation: Vec<&String>, 366 | generations_from_root: usize, 367 | ) -> InferenceGraphBuilder { 368 | if current_generation.len() > 0 { 369 | let builder = 
builder.extend_entries_by_generation( 370 | generations_from_root, 371 | current_generation.iter().cloned().cloned().collect(), 372 | ); 373 | current_generation 374 | .into_iter() 375 | .fold(builder, |builder, current_id| { 376 | match self.get_ancestor(current_id).cloned() { 377 | None => builder, 378 | Some(ref origin) => { 379 | let builder = builder.update_pedigree(current_id.clone(), origin.clone()); 380 | let builder = if origin.args.is_empty() { 381 | builder.update_leaves(current_id.clone()) 382 | } else { 383 | builder 384 | }; 385 | self.extract_inference_graph_helper( 386 | builder, 387 | origin.ancestors().iter().cloned().collect(), 388 | generations_from_root + 1, 389 | ) 390 | } 391 | } 392 | }) 393 | } else { 394 | builder 395 | } 396 | } 397 | 398 | pub fn render_inference_tree( 399 | &self, 400 | d_id: &String, 401 | root_renderer: &Fn(String) -> String, 402 | node_renderer: &Fn(String) -> String, 403 | relation_renderder: &Fn(String, String) -> String, 404 | render_type: RenderType, 405 | ) -> String { 406 | let s = match render_type { 407 | RenderType::Pedigree => self.render_inference_tree_pedigree(d_id, node_renderer), 408 | RenderType::Full => self.render_inference_tree_full(d_id, node_renderer, relation_renderder), 409 | }; 410 | format!( 411 | "graph \"inference chain for {}\" {{\n{}\n}}", 412 | root_renderer(d_id.clone()), 413 | s 414 | ) 415 | } 416 | 417 | pub fn render_inference_tree_full( 418 | &self, 419 | d_id: &String, 420 | node_renderer: &Fn(String) -> String, 421 | relation_renderder: &Fn(String, String) -> String, 422 | ) -> String { 423 | let inf_graph = self.extract_inference_graph(d_id); 424 | let mut relationships: Vec = Vec::new(); 425 | 426 | for (current_generation, future_generations) in inf_graph.back_iter() { 427 | // For each generation 428 | for future_generation in future_generations.iter() { 429 | // For each following generation 430 | for (&(ref current_id, ref _origin), descendent_id) in current_generation 431 | 
.iter() 432 | .cartesian_product(future_generation.iter()) 433 | { 434 | // For each pairing of current generation 435 | relationships.push(format!( 436 | "\"{}\" -- \"{}\" \"{}\"", 437 | node_renderer((*current_id).clone()), 438 | node_renderer(descendent_id.clone()), 439 | relation_renderder((*current_id).clone(), (*descendent_id).clone()) 440 | )); 441 | } 442 | } 443 | } 444 | relationships.iter().join(";\n") 445 | } 446 | 447 | pub fn render_inference_tree_pedigree(&self, d_id: &String, node_renderer: &Fn(String) -> String) -> String { 448 | let mut relationships: Vec = Vec::new(); 449 | if let Some(origin) = self.ancestors.get(d_id) { 450 | // Render the relationship to the parent actor 451 | relationships.push(format!( 452 | r#""{}" -- "{}""#, 453 | node_renderer(d_id.clone()), 454 | node_renderer(origin.source_id.clone()) 455 | )); 456 | for arg in origin.args.iter() { 457 | // Render the relationship to each parent actor argument 458 | relationships.push(format!( 459 | r#""{}" -- "{}""#, 460 | node_renderer(d_id.clone()), 461 | node_renderer(arg.clone()) 462 | )); 463 | // Render each parent actor argument's relationships 464 | relationships.push(self.render_inference_tree_pedigree(arg, node_renderer)); 465 | } 466 | } 467 | relationships.iter().join(";\n") 468 | } 469 | } 470 | 471 | #[cfg(test)] 472 | mod tests; 473 | -------------------------------------------------------------------------------- /src/pedigree/tests.rs: -------------------------------------------------------------------------------- 1 | use pedigree::{InferenceGraphBuilder, Origin, Pedigree}; 2 | 3 | #[test] 4 | fn test_backward_iterator() { 5 | // Expected in the inference graph 6 | let d_id1 = "test::datum::1".to_string(); 7 | let d_id2 = "test::datum::2".to_string(); 8 | let d_id3 = "test::datum::3".to_string(); 9 | 10 | let a_id0 = "test::actor::0".to_string(); 11 | let a_id1 = "test::actor::1".to_string(); 12 | let a_id2 = "test::actor::2".to_string(); 13 | 14 | let inf_graph = 
vec![ 15 | ( 16 | d_id3.clone(), 17 | Origin { 18 | source_id: a_id2.clone(), 19 | args: vec![d_id2.clone()], 20 | }, 21 | ), 22 | ( 23 | d_id2.clone(), 24 | Origin { 25 | source_id: a_id1.clone(), 26 | args: vec![d_id1.clone()], 27 | }, 28 | ), 29 | ( 30 | d_id1.clone(), 31 | Origin { 32 | source_id: a_id0.clone(), 33 | args: Vec::new(), 34 | }, 35 | ), 36 | ].into_iter() 37 | .fold(Pedigree::new(), |ancs, (id, origin)| { 38 | ancs.insert(id, origin) 39 | }) 40 | .extract_inference_graph(&d_id3); 41 | 42 | let mut iter = inf_graph.back_iter(); 43 | assert_eq!( 44 | iter.next(), 45 | Some(( 46 | vec![ 47 | ( 48 | &d_id3, 49 | Some(&Origin { 50 | source_id: a_id2.clone(), 51 | args: vec![d_id2.clone()], 52 | }), 53 | ), 54 | ], 55 | Vec::new() 56 | )) 57 | ); 58 | assert_eq!( 59 | iter.next(), 60 | Some(( 61 | vec![ 62 | (&a_id2, None), 63 | ( 64 | &d_id2, 65 | Some(&Origin { 66 | source_id: a_id1.clone(), 67 | args: vec![d_id1.clone()], 68 | }), 69 | ), 70 | ], 71 | vec![&vec![d_id3.clone()].into_iter().collect()] 72 | .into_iter() 73 | .collect() 74 | )) 75 | ); 76 | assert_eq!( 77 | iter.next(), 78 | Some(( 79 | vec![ 80 | (&a_id1, None), 81 | ( 82 | &d_id1, 83 | Some(&Origin { 84 | source_id: a_id0.clone(), 85 | args: Vec::new(), 86 | }), 87 | ), 88 | ], 89 | vec![ 90 | &vec![a_id2.clone(), d_id2.clone()].into_iter().collect(), 91 | &vec![d_id3.clone()].into_iter().collect(), 92 | ].into_iter() 93 | .collect() 94 | )) 95 | ); 96 | assert_eq!( 97 | iter.next(), 98 | Some(( 99 | vec![(&a_id0, None)], 100 | vec![ 101 | &vec![a_id1.clone(), d_id1.clone()].into_iter().collect(), 102 | &vec![a_id2.clone(), d_id2.clone()].into_iter().collect(), 103 | &vec![d_id3.clone()].into_iter().collect(), 104 | ] 105 | )) 106 | ); 107 | assert_eq!(iter.next(), None); 108 | } 109 | 110 | #[test] 111 | fn test_extract_inference_graph() { 112 | // Expected in the inference graph 113 | let d_id1 = "test::datum::1".to_string(); 114 | let d_id2 = "test::datum::2".to_string(); 115 | let 
d_id3 = "test::datum::3".to_string(); 116 | 117 | let a_id0 = "test::actor::0".to_string(); 118 | let a_id1 = "test::actor::1".to_string(); 119 | let a_id2 = "test::actor::2".to_string(); 120 | 121 | // Not expected in the inference graph 122 | let d_id4 = "test::datum::4".to_string(); 123 | 124 | let pedigree = vec![ 125 | ( 126 | d_id3.clone(), 127 | Origin { 128 | source_id: a_id2.clone(), 129 | args: vec![d_id2.clone()], 130 | }, 131 | ), 132 | ( 133 | d_id2.clone(), 134 | Origin { 135 | source_id: a_id1.clone(), 136 | args: vec![d_id1.clone()], 137 | }, 138 | ), 139 | ( 140 | d_id1.clone(), 141 | Origin { 142 | source_id: a_id0.clone(), 143 | args: Vec::new(), 144 | }, 145 | ), 146 | ( 147 | d_id4.clone(), 148 | Origin { 149 | source_id: a_id0.clone(), 150 | args: Vec::new(), 151 | }, 152 | ), 153 | ].into_iter() 154 | .fold(Pedigree::new(), |ancs, (id, origin)| { 155 | ancs.insert(id, origin) 156 | }); 157 | 158 | let expected_pedigree = vec![ 159 | ( 160 | d_id3.clone(), 161 | Origin { 162 | source_id: a_id2.clone(), 163 | args: vec![d_id2.clone()], 164 | }, 165 | ), 166 | ( 167 | d_id2.clone(), 168 | Origin { 169 | source_id: a_id1.clone(), 170 | args: vec![d_id1.clone()], 171 | }, 172 | ), 173 | ( 174 | d_id1.clone(), 175 | Origin { 176 | source_id: a_id0.clone(), 177 | args: Vec::new(), 178 | }, 179 | ), 180 | ].into_iter() 181 | .fold(Pedigree::new(), |ancs, (id, origin)| { 182 | ancs.insert(id, origin) 183 | }); 184 | 185 | let expected_inference_graph = InferenceGraphBuilder::new() 186 | .root(d_id3.clone()) 187 | .leaves(vec![d_id1.clone()].into_iter().collect()) 188 | .pedigree(expected_pedigree) 189 | .entries_by_generation(vec![ 190 | vec![a_id0.clone()].into_iter().collect(), 191 | vec![d_id1.clone(), a_id1.clone()].into_iter().collect(), 192 | vec![d_id2.clone(), a_id2.clone()].into_iter().collect(), 193 | vec![d_id3.clone()].into_iter().collect(), 194 | ]) 195 | .entries_to_generation( 196 | vec![ 197 | (a_id0.clone(), 0), 198 | (d_id1.clone(), 
1), 199 | (a_id1.clone(), 1), 200 | (d_id2.clone(), 2), 201 | (a_id2.clone(), 2), 202 | (d_id3.clone(), 3), 203 | ].into_iter() 204 | .collect(), 205 | ) 206 | .finalize(); 207 | 208 | let actual_inference_graph = pedigree.extract_inference_graph(&d_id3); 209 | assert_eq!( 210 | actual_inference_graph.entries_by_generation, expected_inference_graph.entries_by_generation, 211 | "Checking entries_by_generation" 212 | ); 213 | assert_eq!( 214 | actual_inference_graph.entries_to_generation, expected_inference_graph.entries_to_generation, 215 | "Checking entries_to_generation" 216 | ); 217 | assert_eq!( 218 | actual_inference_graph.leaves, expected_inference_graph.leaves, 219 | "Checking leaves" 220 | ); 221 | assert_eq!( 222 | actual_inference_graph.root, expected_inference_graph.root, 223 | "Checking root" 224 | ); 225 | assert_eq!( 226 | actual_inference_graph.pedigree, expected_inference_graph.pedigree, 227 | "Checking pedigree" 228 | ); 229 | assert_eq!( 230 | actual_inference_graph, expected_inference_graph, 231 | "Checking all" 232 | ); 233 | } 234 | -------------------------------------------------------------------------------- /src/planner/mod.rs: -------------------------------------------------------------------------------- 1 | //! The planner module implements a basic backtracking planner. 2 | //! 3 | //! It supports features like: 4 | //! 5 | //! * specifying a goal with constraints that must be satisfied by the resultant plan. 6 | //! 7 | //! * the ability to solve a conjunction of goals 8 | //! 9 | //! * rendering of a plan in graphviz format for easier visualization 10 | //! 
11 | 12 | use constraints::{Constraint, ConstraintValue}; 13 | use core::{Operation, Bindings, Unify}; 14 | use itertools::FoldWhile::{Continue, Done}; 15 | use itertools::Itertools; 16 | use std; 17 | use std::collections::HashSet; 18 | use std::marker::PhantomData; 19 | use utils; 20 | 21 | #[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] 22 | pub enum UnificationIndex { 23 | #[serde(rename="init")] 24 | Init, 25 | #[serde(rename="actor")] 26 | Actor(usize), 27 | #[serde(rename="datum")] 28 | Datum(usize), 29 | #[serde(rename="exhausted")] 30 | Exhausted, 31 | } 32 | 33 | impl UnificationIndex { 34 | pub fn datum_idx(&self) -> Option { 35 | match *self { 36 | UnificationIndex::Datum(datum_idx) => Some(datum_idx), 37 | _ => None, 38 | } 39 | } 40 | 41 | pub fn actor_idx(&self) -> Option { 42 | match *self { 43 | UnificationIndex::Actor(actor_idx) => Some(actor_idx), 44 | _ => None, 45 | } 46 | } 47 | 48 | pub fn is_exhausted(&self) -> bool { 49 | match *self { 50 | UnificationIndex::Exhausted => true, 51 | _ => false, 52 | } 53 | } 54 | 55 | pub fn is_init(&self) -> bool { 56 | match *self { 57 | UnificationIndex::Init => true, 58 | _ => false, 59 | } 60 | } 61 | } 62 | 63 | impl std::fmt::Display for UnificationIndex { 64 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 65 | match *self { 66 | UnificationIndex::Init => write!(f, "Init"), 67 | UnificationIndex::Actor(idx) => write!(f, "Actor({})", idx), 68 | UnificationIndex::Datum(idx) => write!(f, "Datum({})", idx), 69 | UnificationIndex::Exhausted => write!(f, "Exhausted"), 70 | } 71 | } 72 | } 73 | 74 | impl Default for UnificationIndex { 75 | fn default() -> Self { 76 | UnificationIndex::Init 77 | } 78 | } 79 | 80 | fn increment_unification_index(current_unification_index: &UnificationIndex, datum_count: usize, rule_count: usize) -> UnificationIndex { 81 | let initial_rule_index = if rule_count > 0 { 82 | UnificationIndex::Actor(0) 83 | } else { 84 | 
UnificationIndex::Exhausted 85 | }; 86 | 87 | let initial_datum_index = if datum_count > 0 { 88 | UnificationIndex::Datum(0) 89 | } else { 90 | initial_rule_index.clone() 91 | }; 92 | 93 | match *current_unification_index { 94 | UnificationIndex::Exhausted => UnificationIndex::Exhausted, 95 | UnificationIndex::Init => initial_datum_index.clone(), 96 | UnificationIndex::Datum(current_idx) => { 97 | if current_idx + 1 < datum_count { 98 | UnificationIndex::Datum(current_idx + 1) 99 | } else { 100 | initial_rule_index 101 | } 102 | } 103 | UnificationIndex::Actor(current_idx) => { 104 | if current_idx + 1 < rule_count { 105 | UnificationIndex::Actor(current_idx + 1) 106 | } else { 107 | UnificationIndex::Exhausted 108 | } 109 | } 110 | } 111 | } 112 | 113 | /// Determine the first goal to increment 114 | pub fn first_goal_to_increment(unification_indices: &Vec) -> Option { 115 | if unification_indices.is_empty() { 116 | None 117 | } else { 118 | // Check subgoals find index that is in the Init state 119 | let last_index = unification_indices.len() - 1; 120 | let mut idx = 0; 121 | 122 | loop { 123 | if unification_indices[idx] == UnificationIndex::Init { 124 | return Some(idx); 125 | } 126 | if unification_indices[idx] == UnificationIndex::Exhausted { 127 | if idx == 0 { 128 | return None; 129 | } else { 130 | return Some(idx - 1); 131 | } 132 | } 133 | if idx == last_index { 134 | return Some(idx); 135 | } 136 | idx += 1; 137 | } 138 | } 139 | } 140 | 141 | #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] 142 | pub struct Goal, A: Operation> { 143 | #[serde(default)] 144 | pub bindings_at_creation: Bindings, 145 | #[serde(default="Vec::new")] 146 | pub constraints: Vec, 147 | #[serde(default="Vec::new")] 148 | pub parental_constraints: Vec, 149 | pub pattern: U, 150 | #[serde(default="Vec::new")] 151 | pub subgoals: Vec>, 152 | #[serde(default)] 153 | pub unification_index: UnificationIndex, 154 | #[serde(default)] 155 | _a_marker: PhantomData, 156 | 
#[serde(default)] 157 | _t_marker: PhantomData, 158 | } 159 | 160 | impl Goal 161 | where T: ConstraintValue, 162 | U: Unify, 163 | A: Operation 164 | { 165 | pub fn new(pattern: U, 166 | parental_constraints: Vec, 167 | constraints: Vec, 168 | bindings_at_creation: Bindings, 169 | unification_index: UnificationIndex, 170 | subgoals: Vec) 171 | -> Self { 172 | Goal { 173 | bindings_at_creation: bindings_at_creation, 174 | constraints: constraints, 175 | pattern: pattern, 176 | parental_constraints: parental_constraints, 177 | subgoals: subgoals, 178 | unification_index: unification_index, 179 | _a_marker: PhantomData, 180 | _t_marker: PhantomData, 181 | } 182 | } 183 | 184 | pub fn constraints(&self) -> Vec<&Constraint> { 185 | self.parental_constraints.iter().chain(self.constraints.iter()).collect() 186 | } 187 | 188 | pub fn with_pattern(pattern: U) -> Self { 189 | Goal::new(pattern, 190 | Vec::new(), 191 | Vec::new(), 192 | Bindings::new(), 193 | UnificationIndex::Init, 194 | Vec::new()) 195 | } 196 | 197 | pub fn solve(goal: &Self, data: &Vec<&U>, rules: &Vec<&A>, increments: usize, config: &PlanningConfig) -> Option<(usize, Self, Bindings)> { 198 | if increments < config.max_increments { 199 | goal.increment(&data, &rules, increments, config.max_depth).and_then(|goal| match goal.validate(data, rules, &Bindings::new(), config) { 200 | Ok(bindings) => Some((increments, goal.clone(), bindings.clone())), 201 | Err(_) => Goal::solve(&goal, data, rules, increments + 1, config), 202 | }) 203 | } else { 204 | None 205 | } 206 | } 207 | 208 | pub fn solve_conjunction(goals: Vec<&Self>, 209 | data: &Vec<&U>, 210 | rules: &Vec<&A>, 211 | increments: usize, 212 | config: &PlanningConfig) 213 | -> Option<(Vec, Bindings)> { 214 | if increments < config.max_increments { 215 | Goal::increment_conjunction(goals, data, rules, increments, config.max_depth).and_then(|goals| { 216 | let validated_bindings = utils::fold_while_some(Bindings::new(), 217 | &mut goals.iter(), 218 | 
&|bindings, goal| goal.validate(data, rules, &bindings, config).ok()); 219 | match validated_bindings { 220 | Some(bindings) => Some((goals, bindings)), 221 | None => Goal::solve_conjunction(goals.iter().collect(), data, rules, increments + 1, config), 222 | } 223 | }) 224 | } else { 225 | None 226 | } 227 | } 228 | 229 | pub fn solve_conjunction_with_criteria(goals: Vec<&Self>, 230 | data: &Vec<&U>, 231 | rules: &Vec<&A>, 232 | increments: usize, 233 | config: &PlanningConfig, 234 | criteria: &Fn(&Vec, &Bindings) -> Option) 235 | -> Option<(Vec, Bindings, X)> { 236 | if increments < config.max_increments { 237 | Goal::solve_conjunction(goals, data, rules, increments, config).and_then(|(goals, bindings)| match criteria(&goals, &bindings) { 238 | Some(result) => Some((goals, bindings, result)), 239 | None => { 240 | Goal::solve_conjunction_with_criteria(goals.iter().collect(), 241 | data, 242 | rules, 243 | increments, 244 | config, 245 | criteria) 246 | } 247 | }) 248 | } else { 249 | None 250 | } 251 | } 252 | 253 | /// Verify that this plan does not break any of the planning specifications and that it is consistent 254 | pub fn validate(&self, data: &Vec<&U>, rules: &Vec<&A>, bindings: &Bindings, config: &PlanningConfig) -> Result, InvalidPlan> { 255 | config.validate_plan(self).and_then(|_| match self.satisified(&data, &rules, bindings) { 256 | Some(bindings) => Ok(bindings), 257 | None => Err(InvalidPlan::BindingsConflict), 258 | }) 259 | } 260 | 261 | /// Construct a mutated plan 262 | pub fn increment(&self, data: &Vec<&U>, rules: &Vec<&A>, snowflake_prefix_id: usize, max_depth: usize) -> Option { 263 | if max_depth == 0 { 264 | return None; 265 | } 266 | let mut goal = self.clone(); 267 | 268 | // If there are any subgoals, increment them first 269 | if !self.subgoals.is_empty() { 270 | if let Some(subgoals) = Goal::increment_conjunction(self.subgoals.iter().collect(), 271 | data, 272 | rules, 273 | snowflake_prefix_id, 274 | max_depth) { 275 | goal.subgoals 
= subgoals; 276 | return Some(goal); 277 | } 278 | } 279 | 280 | // If subgoals cannot be incremented, increment this goal 281 | loop { 282 | goal.unification_index = increment_unification_index(&goal.unification_index, data.len(), rules.len()); 283 | 284 | // Attempt to increment this goal 285 | match goal.unification_index { 286 | UnificationIndex::Datum(_idx) => { 287 | if goal.satisified(data, rules, &goal.bindings_at_creation).is_some() { 288 | return Some(goal); 289 | } 290 | } 291 | UnificationIndex::Actor(idx) => { 292 | let rule = rules[idx].snowflake(format!("{}", snowflake_prefix_id)); 293 | 294 | if let Some(subgoals) = Self::create_subgoals(&self.pattern, 295 | &rule, 296 | &goal.constraints(), 297 | data, 298 | rules, 299 | snowflake_prefix_id, 300 | max_depth) { 301 | goal.subgoals = subgoals; 302 | return Some(goal); 303 | } 304 | } 305 | // If this goal cannot be incremented, return None 306 | UnificationIndex::Exhausted => return None, 307 | UnificationIndex::Init => panic!("Init after incrementing; this should never happen"), 308 | } 309 | } 310 | } 311 | 312 | /// Determine if the plan is valid 313 | pub fn satisified(&self, data: &Vec<&U>, rules: &Vec<&A>, bindings: &Bindings) -> Option> { 314 | let bindings = bindings.merge(&self.bindings_at_creation); 315 | match self.unification_index { 316 | UnificationIndex::Datum(datum_idx) => { 317 | self.pattern 318 | .unify(data[datum_idx], &bindings) 319 | .and_then(|bindings| Constraint::solve_many(self.constraints(), &bindings).ok()) 320 | } 321 | UnificationIndex::Actor(_actor_idx) => { 322 | self.subgoals 323 | .iter() 324 | .fold_while(Some(bindings), 325 | |bindings, subgoal| match subgoal.satisified(data, rules, bindings.as_ref().unwrap()) { 326 | Some(subgoal_bindings) => Continue(Some(subgoal_bindings.clone())), 327 | None => Done(None), 328 | }) 329 | } 330 | UnificationIndex::Init => self.pattern.unify(&U::nil(), &bindings), 331 | UnificationIndex::Exhausted => None, 332 | } 333 | } 334 | 
335 | pub fn create_subgoals(r_pattern: &U, 336 | rule: &A, 337 | parent_constraints: &Vec<&Constraint>, 338 | data: &Vec<&U>, 339 | rules: &Vec<&A>, 340 | snowflake_prefix_id: usize, 341 | max_depth: usize) 342 | -> Option> { 343 | rule.r_apply_match(r_pattern).and_then(|(subgoal_patterns, bindings)| { 344 | let subgoals: Vec> = subgoal_patterns.into_iter() 345 | .map(|pattern| { 346 | Goal::new(pattern.apply_bindings(&bindings).unwrap(), 347 | parent_constraints.iter().map(|c| (*c).clone()).collect(), 348 | rule.constraints().iter().map(|c| (*c).clone()).collect(), 349 | bindings.clone(), 350 | UnificationIndex::default(), 351 | Vec::new()) 352 | .increment(data, rules, snowflake_prefix_id + 1, max_depth - 1) 353 | }) 354 | .collect(); 355 | 356 | if subgoals.iter().any(|x| x.is_none()) { 357 | None 358 | } else { 359 | Some(subgoals.into_iter().map(|sg| sg.unwrap()).collect()) 360 | } 361 | }) 362 | } 363 | 364 | pub fn increment_conjunction(goals: Vec<&Self>, 365 | data: &Vec<&U>, 366 | rules: &Vec<&A>, 367 | snowflake_prefix_id: usize, 368 | max_depth: usize) 369 | -> Option> { 370 | let mut goals: Vec = goals.into_iter().map(|g| g.clone()).collect(); 371 | let goal_count = goals.len(); 372 | let is_last_goal = |idx| idx + 1 == goal_count; 373 | loop { 374 | let unification_indices = goals.iter().map(|sg| sg.unification_index.clone()).collect(); 375 | let goal_idx_to_increment = first_goal_to_increment(&unification_indices); 376 | match goal_idx_to_increment { 377 | None => return None, 378 | Some(idx) => { 379 | if let Some(new_goal) = goals[idx].increment(data, rules, snowflake_prefix_id, max_depth - 1) { 380 | goals[idx] = new_goal; 381 | 382 | if is_last_goal(idx) { 383 | return Some(goals); 384 | } else { 385 | goals[idx + 1].unification_index = UnificationIndex::Init; 386 | } 387 | } else { 388 | goals[idx].unification_index = UnificationIndex::Exhausted; 389 | } 390 | } 391 | } 392 | } 393 | } 394 | 395 | pub fn render_as_graphviz(&self) -> String { 396 
| let subtree_string = self.render_subtree_as_graphviz(None); 397 | format!("graph \"goal tree {}\" {{\n{}\n}}", 398 | self.pattern, 399 | subtree_string) 400 | } 401 | 402 | fn render_subtree_as_graphviz(&self, parent: Option) -> String { 403 | let goal_rendering = format!("{} [{}]", self.pattern, self.unification_index); 404 | let subtree_string_vec: Vec = self.subgoals 405 | .iter() 406 | .map(|subgoal| subgoal.render_subtree_as_graphviz(Some(goal_rendering.clone()))) 407 | .collect(); 408 | let subtree_string = subtree_string_vec.join("\n"); 409 | let goal_parent_str = if let Some(parent_goal) = parent { 410 | format!("\"{}\" -- \"{}\";", parent_goal, goal_rendering) 411 | } else { 412 | String::new() 413 | }; 414 | format!("{}\n{}", goal_parent_str, subtree_string) 415 | } 416 | 417 | pub fn pprint(&self, ntabs: usize, only_render_spine: bool) -> String { 418 | let tabs = utils::concat_tabs(ntabs); 419 | let subgoal_v: Vec = self.subgoals.iter().map(|sg| sg.pprint(ntabs + 1, only_render_spine)).collect(); 420 | let subgoal_s = subgoal_v.join("\n"); 421 | 422 | let parental_constraint_v: Vec = 423 | self.parental_constraints.iter().map(|c| format!("{}{}", utils::concat_tabs(ntabs + 1), c)).collect(); 424 | let parental_constraint_s = parental_constraint_v.join("\n"); 425 | 426 | let constraint_v: Vec = self.constraints.iter().map(|c| format!("{}{}", utils::concat_tabs(ntabs + 1), c)).collect(); 427 | let constraint_s = constraint_v.join("\n"); 428 | 429 | if only_render_spine { 430 | format!("{}{} @ {}\n{}", 431 | tabs, 432 | self.pattern, 433 | self.unification_index, 434 | subgoal_s) 435 | } else { 436 | format!("{}{} @ {}\n\t{}bindings at creation: {}\n\t{}parental constraints:\n{}\n\t{}constraints:\n{}\n{}", 437 | tabs, 438 | self.pattern, 439 | self.unification_index, 440 | tabs, 441 | self.bindings_at_creation, 442 | tabs, 443 | parental_constraint_s, 444 | tabs, 445 | constraint_s, 446 | subgoal_s) 447 | } 448 | } 449 | 450 | pub fn apply_bindings(&self, 
bindings: &Bindings) -> Option { 451 | self.pattern.apply_bindings(&bindings).and_then(|pattern| { 452 | let mut clone = Goal::new(pattern, 453 | self.parental_constraints.clone(), 454 | self.constraints.clone(), 455 | self.bindings_at_creation.clone(), 456 | self.unification_index.clone(), 457 | Vec::with_capacity(self.subgoals.len())); 458 | for subgoal in self.subgoals.iter() { 459 | if let Some(applied_subgoal) = subgoal.apply_bindings(&bindings) { 460 | clone.subgoals.push(applied_subgoal); 461 | } else { 462 | return None; 463 | } 464 | } 465 | Some(clone) 466 | }) 467 | } 468 | 469 | /// Traverse the tree and determine if any datum is being used more than once 470 | pub fn find_reused_datum(&self, used_data: &mut HashSet) -> Option { 471 | match self.unification_index { 472 | UnificationIndex::Datum(ref datum_idx) => { 473 | if used_data.contains(datum_idx) { 474 | return Some(datum_idx.clone()); 475 | } else { 476 | used_data.insert(datum_idx.clone()); 477 | return None; 478 | } 479 | } 480 | UnificationIndex::Actor(_actor_idx) => { 481 | for subgoal in self.subgoals.iter() { 482 | if let Some(idx) = subgoal.find_reused_datum(used_data) { 483 | return Some(idx); 484 | } 485 | } 486 | None 487 | } 488 | _ => None, 489 | } 490 | } 491 | 492 | /// Traverse the goal tree using a depth-first search and gather the leaves of the plan 493 | pub fn gather_leaves(&self, bindings: &Bindings) -> Vec { 494 | let mut leaves = Vec::new(); 495 | 496 | if self.subgoals.is_empty() { 497 | leaves.push(self.pattern.apply_bindings(&bindings).expect("Bindings should be applicable")); 498 | } else { 499 | for sg in self.subgoals.iter() { 500 | leaves.extend(sg.gather_leaves(bindings).into_iter()); 501 | } 502 | } 503 | 504 | leaves 505 | } 506 | } 507 | 508 | impl std::fmt::Display for Goal 509 | where T: ConstraintValue, 510 | U: Unify, 511 | A: Operation 512 | { 513 | fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { 514 | write!(f, "Goal tree:\n{}", 
self.pprint(1, true)) 515 | } 516 | } 517 | 518 | #[derive(Clone, Copy, Debug, Eq, PartialEq)] 519 | pub struct PlanningConfig { 520 | pub max_depth: usize, 521 | pub max_increments: usize, 522 | pub reuse_data: bool, 523 | } 524 | 525 | impl Default for PlanningConfig { 526 | fn default() -> Self { 527 | PlanningConfig { 528 | max_depth: 3, 529 | max_increments: 1000, 530 | reuse_data: true, 531 | } 532 | } 533 | } 534 | 535 | impl PlanningConfig { 536 | /// Verify that nothing in plan contradicts specifications set in the PlanningConfig 537 | pub fn validate_plan(&self, goal: &Goal) -> Result<(), InvalidPlan> 538 | where T: ConstraintValue, 539 | U: Unify, 540 | A: Operation 541 | { 542 | if !self.reuse_data { 543 | if let Some(idx) = goal.find_reused_datum(&mut HashSet::new()) { 544 | return Err(InvalidPlan::ReusedData { idx: idx }); 545 | } 546 | } 547 | Ok(()) 548 | } 549 | } 550 | 551 | #[derive(Clone, Copy, Debug, Eq, PartialEq)] 552 | pub enum InvalidPlan { 553 | BindingsConflict, 554 | ReusedData { idx: usize }, 555 | } 556 | 557 | #[derive(Clone, Debug, Eq, PartialEq)] 558 | pub struct Planner<'a, T, U, A> 559 | where T: ConstraintValue, 560 | U: 'a + Unify, 561 | A: 'a + Operation 562 | { 563 | bindings: Bindings, 564 | config: PlanningConfig, 565 | data: Vec<&'a U>, 566 | goal: Goal, 567 | total_increments: usize, 568 | rules: Vec<&'a A>, 569 | } 570 | 571 | impl<'a, T, U, A> Planner<'a, T, U, A> 572 | where T: ConstraintValue, 573 | U: 'a + Unify, 574 | A: 'a + Operation 575 | { 576 | pub fn new(goal: &Goal, bindings: &Bindings, config: &PlanningConfig, data: Vec<&'a U>, rules: Vec<&'a A>) -> Self { 577 | Planner { 578 | bindings: bindings.clone(), 579 | config: config.clone(), 580 | data: data, 581 | goal: goal.clone(), 582 | total_increments: 0, 583 | rules: rules, 584 | } 585 | } 586 | } 587 | 588 | impl<'a, T, U, A> Iterator for Planner<'a, T, U, A> 589 | where T: ConstraintValue, 590 | U: 'a + Unify, 591 | A: 'a + Operation 592 | { 593 | type Item 
= (Goal, Bindings); 594 | 595 | fn next(&mut self) -> Option<(Goal, Bindings)> { 596 | for _i in self.total_increments..self.config.max_increments { 597 | self.total_increments += 1; 598 | 599 | if let Some((increments, goal, bindings)) = 600 | Goal::solve(&self.goal, 601 | &self.data, 602 | &self.rules, 603 | self.total_increments, 604 | &self.config) { 605 | self.goal = goal; 606 | self.total_increments += increments; 607 | return Some((self.goal.clone(), bindings.clone())); 608 | } else { 609 | break; 610 | } 611 | } 612 | None 613 | } 614 | } 615 | 616 | #[derive(Clone, Debug, Eq, PartialEq)] 617 | pub struct ConjunctivePlanner<'a, T, U, A> 618 | where T: ConstraintValue, 619 | U: 'a + Unify, 620 | A: 'a + Operation 621 | { 622 | bindings: Bindings, 623 | config: PlanningConfig, 624 | data: Vec<&'a U>, 625 | goals: Vec>, 626 | total_increments: usize, 627 | rules: Vec<&'a A>, 628 | } 629 | 630 | impl<'a, T, U, A> ConjunctivePlanner<'a, T, U, A> 631 | where T: ConstraintValue, 632 | U: 'a + Unify, 633 | A: 'a + Operation 634 | { 635 | pub fn new(goals: Vec>, bindings: &Bindings, config: &PlanningConfig, data: Vec<&'a U>, rules: Vec<&'a A>) -> Self { 636 | ConjunctivePlanner { 637 | bindings: bindings.clone(), 638 | config: config.clone(), 639 | data: data, 640 | goals: goals, 641 | total_increments: 0, 642 | rules: rules, 643 | } 644 | } 645 | } 646 | 647 | impl<'a, T, U, A> Iterator for ConjunctivePlanner<'a, T, U, A> 648 | where T: ConstraintValue, 649 | U: 'a + Unify, 650 | A: 'a + Operation 651 | { 652 | type Item = (Vec>, Bindings); 653 | 654 | fn next(&mut self) -> Option<(Vec>, Bindings)> { 655 | for _i in self.total_increments..self.config.max_increments { 656 | self.total_increments += 1; 657 | 658 | if let Some((goals, bindings)) = 659 | Goal::solve_conjunction(self.goals.iter().collect(), 660 | &self.data, 661 | &self.rules, 662 | self.total_increments, 663 | &self.config) { 664 | self.goals = goals; 665 | return Some((self.goals.clone(), 
bindings.clone())); 666 | } else { 667 | break; 668 | } 669 | } 670 | None 671 | } 672 | } 673 | 674 | #[cfg(test)] 675 | mod tests; 676 | -------------------------------------------------------------------------------- /src/rule/mod.rs: -------------------------------------------------------------------------------- 1 | //! The rule module provides two data structures, Rule and MultiRule, that implement the Operation trait. 2 | //! Rule aims to be a drop-in for any algorithm in ai_kit that operates on the Operation trait. It is useful when you have one possible result for an operation. 3 | //! MultiRule aims to be a drop-in for any algorithm in ai_kit that operates on the Operation trait. It is useful when you have multiple results for an operation. 4 | 5 | use constraints::{Constraint, ConstraintValue}; 6 | use core::{Bindings, Operation, Unify}; 7 | use std; 8 | use std::collections::HashMap; 9 | use std::marker::PhantomData; 10 | use utils; 11 | 12 | #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] 13 | pub struct Rule> { 14 | #[serde(default)] 15 | pub constraints: Vec, 16 | #[serde(default = "Vec::default")] 17 | pub lhs: Vec, 18 | pub rhs: U, 19 | #[serde(default)] 20 | pub _marker: PhantomData, 21 | } 22 | 23 | impl Operation for Rule 24 | where 25 | T: ConstraintValue, 26 | U: Unify, 27 | { 28 | fn input_patterns(&self) -> Vec { 29 | self.lhs.clone() 30 | } 31 | 32 | fn apply_match(&self, bindings: &Bindings) -> Option> { 33 | self.solve_constraints(&bindings).and_then(|bindings| { 34 | self.rhs 35 | .apply_bindings(&bindings) 36 | .and_then(|bound_rhs| Some(vec![bound_rhs])) 37 | }) 38 | } 39 | 40 | fn r_apply_match(&self, fact: &U) -> Option<(Vec, Bindings)> { 41 | self.rhs 42 | .unify(fact, &Bindings::new()) 43 | .and_then(|bindings| self.solve_constraints(&bindings)) 44 | .and_then(|bindings| { 45 | utils::map_while_some(&mut self.lhs.iter(), &|f| { 46 | f.apply_bindings(&bindings) 47 | }).and_then(|inputs| Some((inputs, bindings))) 48 | }) 
49 | } 50 | 51 | fn constraints<'a>(&'a self) -> Vec<&'a Constraint> { 52 | self.constraints.iter().collect() 53 | } 54 | 55 | /// Construct a new version of this rule but with all variables updated to be unique for this invocation 56 | fn snowflake(&self, suffix: String) -> Self { 57 | // Gather all variables 58 | let mut variables = self.rhs.variables(); 59 | for lhs in self.lhs.iter() { 60 | variables.extend(lhs.variables()); 61 | } 62 | for constraint in self.constraints.iter() { 63 | variables.extend(constraint.variables()); 64 | } 65 | 66 | let renamed_variable: HashMap = variables 67 | .into_iter() 68 | .map(|var| (var.clone(), format!("{}::{}", var, suffix))) 69 | .collect(); 70 | 71 | let rhs = self.rhs.rename_variables(&renamed_variable); 72 | let lhs: Vec = self.lhs 73 | .iter() 74 | .map(|lhs| lhs.rename_variables(&renamed_variable)) 75 | .collect(); 76 | let constraints: Vec = self.constraints 77 | .iter() 78 | .map(|constraint| constraint.rename_variables(&renamed_variable)) 79 | .collect(); 80 | 81 | Rule { 82 | constraints: constraints, 83 | lhs: lhs, 84 | rhs: rhs, 85 | _marker: PhantomData, 86 | } 87 | } 88 | } 89 | 90 | impl Rule 91 | where 92 | T: ConstraintValue, 93 | U: Unify, 94 | { 95 | pub fn new(lhs: Vec, rhs: U, constraints: Vec) -> Self { 96 | Rule { 97 | constraints: constraints, 98 | lhs: lhs, 99 | rhs: rhs, 100 | _marker: PhantomData, 101 | } 102 | } 103 | 104 | pub fn unify(&self, facts: &Vec<&U>, bindings: &Bindings) -> Option> { 105 | utils::fold_while_some( 106 | bindings.clone(), 107 | &mut self.lhs.iter().zip(facts.iter()), 108 | &|bindings, (t1, t2)| t1.unify(t2, &bindings), 109 | ) 110 | } 111 | 112 | pub fn apply_bindings(&self, bindings: &Bindings) -> Option { 113 | self.rhs.apply_bindings(bindings) 114 | } 115 | 116 | fn solve_constraints(&self, bindings: &Bindings) -> Option> { 117 | Constraint::solve_many(self.constraints.iter().collect(), bindings).ok() 118 | } 119 | 120 | pub fn pprint(&self) -> String { 121 | let 
lhs_string_vec: Vec = self.lhs.iter().map(|o| format!("{}", o)).collect(); 122 | let lhs_string = lhs_string_vec.join("\n\t\t"); 123 | let constraints_string_vec: Vec = self.constraints.iter().map(|o| format!("{}", o)).collect(); 124 | let constraints_string = constraints_string_vec.join("\n"); 125 | format!( 126 | "Rule {{\n\tlhs:\n\t\t{},\n\trhs:\n\t\t{},\n\tconstraints:\n\t\t{} }}", 127 | lhs_string, self.rhs, constraints_string 128 | ) 129 | } 130 | } 131 | 132 | impl std::fmt::Display for Rule 133 | where 134 | T: ConstraintValue, 135 | U: Unify, 136 | { 137 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 138 | write!(f, "{}", self.pprint()) 139 | } 140 | } 141 | 142 | impl Eq for Rule 143 | where 144 | T: ConstraintValue, 145 | U: Unify, 146 | { 147 | } 148 | 149 | #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] 150 | pub struct MultiRule> { 151 | #[serde(default)] 152 | pub constraints: Vec, 153 | #[serde(default = "Vec::default")] 154 | pub lhs: Vec, 155 | pub rhs: Vec, 156 | #[serde(default)] 157 | pub _marker: PhantomData, 158 | } 159 | 160 | impl Operation for MultiRule 161 | where 162 | T: ConstraintValue, 163 | U: Unify, 164 | { 165 | fn input_patterns(&self) -> Vec { 166 | self.lhs.clone() 167 | } 168 | 169 | fn apply_match(&self, bindings: &Bindings) -> Option> { 170 | self.solve_constraints(&bindings) 171 | .and_then(|bindings| self.apply_bindings(&bindings)) 172 | } 173 | 174 | /// Note: MultiRule's implementation of r_apply_match will only return Some result if 175 | /// a consistent set of bindings can be used when unifying all of the rhs elements 176 | fn r_apply_match(&self, fact: &U) -> Option<(Vec, Bindings)> { 177 | utils::fold_while_some( 178 | (Vec::new(), Bindings::new()), 179 | &mut self.rhs.iter(), 180 | &|(mut bound_rhs, bindings), rhs| { 181 | let results = rhs.unify(fact, &bindings) 182 | .and_then(|bindings| self.solve_constraints(&bindings)) 183 | .and_then(|bindings| { 184 | utils::map_while_some(&mut 
self.lhs.iter(), &|f| { 185 | f.apply_bindings(&bindings) 186 | }).and_then(|inputs| Some((inputs, bindings))) 187 | }); 188 | if let Some((inputs, new_bindings)) = results { 189 | bound_rhs.extend(inputs.into_iter()); 190 | Some((bound_rhs, new_bindings)) 191 | } else { 192 | None 193 | } 194 | }, 195 | ) 196 | } 197 | 198 | fn constraints<'a>(&'a self) -> Vec<&'a Constraint> { 199 | self.constraints.iter().collect() 200 | } 201 | 202 | /// Construct a new version of this rule but with all variables updated to be unique for this invocation 203 | fn snowflake(&self, suffix: String) -> Self { 204 | // Gather all variables 205 | let mut variables: Vec = self.rhs.iter().flat_map(|rhs| rhs.variables()).collect(); 206 | 207 | for lhs in self.lhs.iter() { 208 | variables.extend(lhs.variables()); 209 | } 210 | for constraint in self.constraints.iter() { 211 | variables.extend(constraint.variables()); 212 | } 213 | 214 | let renamed_variable: HashMap = variables 215 | .into_iter() 216 | .map(|var| (var.clone(), format!("{}::{}", var, suffix))) 217 | .collect(); 218 | 219 | let rhs: Vec = self.rhs 220 | .iter() 221 | .map(|rhs| rhs.rename_variables(&renamed_variable)) 222 | .collect(); 223 | let lhs: Vec = self.lhs 224 | .iter() 225 | .map(|lhs| lhs.rename_variables(&renamed_variable)) 226 | .collect(); 227 | let constraints: Vec = self.constraints 228 | .iter() 229 | .map(|constraint| constraint.rename_variables(&renamed_variable)) 230 | .collect(); 231 | 232 | MultiRule { 233 | constraints: constraints, 234 | lhs: lhs, 235 | rhs: rhs, 236 | _marker: PhantomData, 237 | } 238 | } 239 | } 240 | 241 | impl MultiRule 242 | where 243 | T: ConstraintValue, 244 | U: Unify, 245 | { 246 | pub fn new(lhs: Vec, rhs: Vec, constraints: Vec) -> Self { 247 | MultiRule { 248 | constraints: constraints, 249 | lhs: lhs, 250 | rhs: rhs, 251 | _marker: PhantomData, 252 | } 253 | } 254 | 255 | pub fn unify(&self, facts: &Vec<&U>, bindings: &Bindings) -> Option> { 256 | utils::fold_while_some( 
257 | bindings.clone(), 258 | &mut self.lhs.iter().zip(facts.iter()), 259 | &|bindings, (t1, t2)| t1.unify(t2, &bindings), 260 | ) 261 | } 262 | 263 | pub fn apply_bindings(&self, bindings: &Bindings) -> Option> { 264 | utils::map_while_some(&mut self.rhs.iter(), &|rhs| { 265 | rhs.apply_bindings(bindings) 266 | }) 267 | } 268 | 269 | fn solve_constraints(&self, bindings: &Bindings) -> Option> { 270 | Constraint::solve_many(self.constraints.iter().collect(), bindings).ok() 271 | } 272 | 273 | pub fn pprint(&self) -> String { 274 | let lhs_string: String = self.lhs 275 | .iter() 276 | .map(|o| format!("{}", o)) 277 | .collect::>() 278 | .join("\n\t\t"); 279 | let rhs_string: String = self.rhs 280 | .iter() 281 | .map(|o| format!("{}", o)) 282 | .collect::>() 283 | .join("\n\t\t"); 284 | 285 | let constraints_string_vec: Vec = self.constraints.iter().map(|o| format!("{}", o)).collect(); 286 | let constraints_string = constraints_string_vec.join("\n"); 287 | 288 | format!( 289 | "MultiRule {{\n\tlhs:\n\t\t{},\n\trhs:\n\t\t{},\n\tconstraints:\n\t\t{} }}", 290 | lhs_string, rhs_string, constraints_string 291 | ) 292 | } 293 | 294 | /// Rename any variables in this structure with another variable name 295 | pub fn rename_variables(&self, renamed_variables: &HashMap) -> Self { 296 | MultiRule::new( 297 | self.lhs 298 | .iter() 299 | .map(|u| u.rename_variables(renamed_variables)) 300 | .collect(), 301 | self.rhs 302 | .iter() 303 | .map(|u| u.rename_variables(renamed_variables)) 304 | .collect(), 305 | self.constraints 306 | .iter() 307 | .map(|c| c.rename_variables(renamed_variables)) 308 | .collect(), 309 | ) 310 | } 311 | } 312 | 313 | impl std::fmt::Display for MultiRule 314 | where 315 | T: ConstraintValue, 316 | U: Unify, 317 | { 318 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 319 | write!(f, "{}", self.pprint()) 320 | } 321 | } 322 | 323 | impl Eq for MultiRule 324 | where 325 | T: ConstraintValue, 326 | U: Unify, 327 | { 328 | } 329 | 330 | 
#[cfg(test)] 331 | mod tests; 332 | -------------------------------------------------------------------------------- /src/rule/tests.rs: -------------------------------------------------------------------------------- 1 | use datum::Datum; 2 | use rule::*; 3 | 4 | fn setup() -> Rule { 5 | from_json!(Rule, { 6 | "lhs": [{"vec": [{"str": "x"}, {"var": "?x"}]}], 7 | "rhs": {"vec": [{"str": "y"}, {"var": "?y"}]}, 8 | "constraints": [ 9 | {"numerical": {"set": {"variable": "?diff", "constant": 25}}}, 10 | {"numerical": {"sum": {"first": "?x", "second": "?y", "third": "?diff"}}} 11 | ], 12 | }) 13 | } 14 | 15 | #[test] 16 | fn test_snowflake() { 17 | let rule: Rule = setup(); 18 | 19 | let expected_snowflake = from_json!(Rule, { 20 | "lhs": [{"vec": [{"str": "x"}, {"var": "?x::test"}]}], 21 | "rhs": {"vec": [{"str": "y"}, {"var": "?y::test"}]}, 22 | "constraints": [ 23 | {"numerical": {"set": {"variable": "?diff::test", "constant": 25}}}, 24 | {"numerical": {"sum": {"first": "?x::test", "second": "?y::test", "third": "?diff::test"}}} 25 | ], 26 | }); 27 | 28 | assert_eq!(rule.snowflake("test".to_string()), expected_snowflake); 29 | } 30 | 31 | #[test] 32 | fn test_rule_application() { 33 | let rule: Rule = setup(); 34 | let initial_datum = from_json!(Datum, {"vec": [{"str": "x"}, {"float": 10}]}); 35 | let expected_datum = from_json!(Datum, {"vec": [{"str": "y"}, {"float": 15}]}); 36 | let expected_bindings: Bindings = Bindings::new().set_binding(&"?x".to_string(), Datum::Float(10.0)); 37 | let bindings = rule.input_patterns()[0] 38 | .unify(&initial_datum, &Bindings::new()) 39 | .unwrap(); 40 | assert_eq!(bindings, expected_bindings); 41 | assert_eq!(rule.apply_match(&bindings), Some(vec![expected_datum])); 42 | } 43 | 44 | #[test] 45 | fn test_rule_application_with_no_antecedents() { 46 | let rule = from_json!(Rule, { 47 | "rhs": {"vec": [{"str": "y"}, {"float": 1}]}, 48 | }); 49 | let expected_datum = from_json!(Datum, {"vec": [{"str": "y"}, {"float": 1}]}); 50 | 
assert_eq!( 51 | rule.apply_match(&Bindings::new()), 52 | Some(vec![expected_datum]) 53 | ); 54 | } 55 | 56 | #[test] 57 | fn test_rule_reverse_application() { 58 | let rule: Rule = setup(); 59 | let expected_datum = from_json!(Datum, {"vec": [{"str": "x"}, {"float": 10}]}); 60 | let initial_datum = from_json!(Datum, {"vec": [{"str": "y"}, {"float": 15}]}); 61 | let expected_bindings = Bindings::new() 62 | .set_binding(&"?diff".to_string(), Datum::Float(25.0)) 63 | .set_binding(&"?x".to_string(), Datum::Float(10.0)) 64 | .set_binding(&"?y".to_string(), Datum::Float(15.0)); 65 | assert_eq!( 66 | rule.r_apply_match(&initial_datum), 67 | Some((vec![expected_datum], expected_bindings)) 68 | ); 69 | } 70 | -------------------------------------------------------------------------------- /src/test_utils.rs: -------------------------------------------------------------------------------- 1 | macro_rules! from_json { 2 | ($type: ty, $json: tt) => ({ 3 | use serde_json; 4 | let x: $type = serde_json::from_value(json!($json)).expect("Expected json decoding"); 5 | x 6 | }) 7 | } 8 | -------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | //! 
Internal utilities 2 | 3 | use itertools::FoldWhile::{Continue, Done}; 4 | 5 | use itertools::Itertools; 6 | 7 | /// Map across the iterator, terminating early if a mapping returns None 8 | pub fn map_while_some(iter: &mut Iterator, f: &Fn(E) -> Option) -> Option> { 9 | let mut results: Vec = Vec::new(); 10 | for x in iter { 11 | if let Some(result) = f(x) { 12 | results.push(result); 13 | } else { 14 | return None; 15 | } 16 | } 17 | return Some(results); 18 | } 19 | 20 | pub fn fold_while_some(init_acc: A, iter: &mut Iterator, f: &Fn(A, E) -> Option) -> Option { 21 | iter.fold_while(Some(init_acc), |acc, x| match f(acc.unwrap(), x) { 22 | Some(value) => Continue(Some(value)), 23 | None => Done(None), 24 | }) 25 | } 26 | 27 | /// Create a string with the specified number of tabs 28 | pub fn concat_tabs(ntabs: usize) -> String { 29 | let tabv: Vec = (0..ntabs).map(|_| "\t".to_string()).collect(); 30 | tabv.join("") 31 | } 32 | -------------------------------------------------------------------------------- /tests/skeptic.rs: -------------------------------------------------------------------------------- 1 | include!(concat!(env!("OUT_DIR"), "/skeptic-tests.rs")); 2 | --------------------------------------------------------------------------------