├── .github └── workflows │ └── main.yml ├── .gitignore ├── .rustfmt.toml ├── Cargo.toml ├── LICENSE_APACHE.md ├── LICENSE_MIT.md ├── README.md ├── examples └── cli.rs └── src ├── equations.rs ├── expr.rs ├── lib.rs ├── ops.rs ├── parse.rs └── solve.rs /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | 3 | name: Continuous integration 4 | 5 | jobs: 6 | check: 7 | name: Compile and Test 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | rust: 12 | - nightly 13 | - stable 14 | # MSRV - We probably only need 1.39.0 for bind-by-move 15 | - 1.44.1 16 | steps: 17 | - uses: actions/checkout@v2 18 | - uses: actions-rs/toolchain@v1 19 | with: 20 | profile: minimal 21 | toolchain: ${{ matrix.rust }} 22 | override: true 23 | - uses: actions-rs/cargo@v1 24 | with: 25 | command: check 26 | args: --all --verbose 27 | - uses: actions-rs/cargo@v1 28 | with: 29 | command: build 30 | args: --all --verbose 31 | - uses: actions-rs/cargo@v1 32 | with: 33 | command: test 34 | args: --all --verbose 35 | 36 | api-docs: 37 | name: Publish API Docs to GitHub Pages 38 | runs-on: ubuntu-latest 39 | strategy: 40 | matrix: 41 | rust: 42 | - nightly 43 | steps: 44 | - uses: actions/checkout@v2 45 | - uses: actions-rs/toolchain@v1 46 | with: 47 | profile: minimal 48 | toolchain: ${{ matrix.rust }} 49 | override: true 50 | - uses: actions-rs/cargo@v1 51 | with: 52 | command: doc 53 | args: --all --verbose 54 | - name: Redirect top-level GitHub Pages 55 | run: "echo '' > target/doc/index.html" 56 | shell: bash 57 | - name: GitHub Pages 58 | uses: crazy-max/ghaction-github-pages@v1 59 | with: 60 | build_dir: target/doc 61 | env: 62 | GITHUB_PAT: ${{ secrets.GH_PAGES_ACCESS_TOKEN }} 63 | 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | Cargo.lock 3 | 
tarpaulin-report.html 4 | -------------------------------------------------------------------------------- /.rustfmt.toml: -------------------------------------------------------------------------------- 1 | max_width = 80 2 | tab_spaces = 4 3 | fn_single_line = true 4 | match_block_trailing_comma = true 5 | normalize_comments = true 6 | wrap_comments = true 7 | merge_imports = true 8 | reorder_impl_items = true 9 | use_field_init_shorthand = true 10 | use_try_shorthand = true 11 | normalize_doc_attributes = true 12 | report_todo = "Always" 13 | report_fixme = "Always" 14 | edition = "2018" 15 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "constraints" 3 | version = "0.1.0" 4 | authors = ["Michael-F-Bryan "] 5 | license = "MIT OR Apache-2.0" 6 | edition = "2018" 7 | description = "An experiment in writing algebraic constraint solvers for 3D CAD." 8 | publish = false 9 | readme = "README.md" 10 | 11 | [dependencies] 12 | euclid = "0.20.13" 13 | arrayvec = "0.5.1" 14 | smol_str = "0.1.15" 15 | nalgebra = "0.21.1" 16 | approx = "0.3.2" 17 | 18 | [dev-dependencies] 19 | pretty_assertions = "0.6.1" 20 | -------------------------------------------------------------------------------- /LICENSE_APACHE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /LICENSE_MIT.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Michael-F-Bryan 2 | 3 | Permission is hereby granted, free of charge, to any 4 | person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the 6 | Software without restriction, including without 7 | limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software 10 | is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice 14 | shall be included in all copies or substantial portions 15 | of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 18 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 19 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 20 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 21 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 24 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 25 | DEALINGS IN THE SOFTWARE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Constraints 2 | 3 | [![Continuous integration](https://github.com/Michael-F-Bryan/constraints/workflows/Continuous%20integration/badge.svg?branch=master)](https://github.com/Michael-F-Bryan/constraints/actions) 4 | 5 | ([API Docs]) 6 | 7 | An experiment in writing algebraic constraint solvers for 3D CAD. 
8 | ## License 9 | 10 | This project is licensed under either of 11 | 12 | * Apache License, Version 2.0, ([LICENSE_APACHE](LICENSE_APACHE.md) or 13 | http://www.apache.org/licenses/LICENSE-2.0) 14 | * MIT license ([LICENSE_MIT](LICENSE_MIT.md) or 15 | http://opensource.org/licenses/MIT) 16 | 17 | at your option. 18 | 19 | It is recommended to always use [cargo-crev][crev] to verify the 20 | trustworthiness of each of your dependencies, including this one. 21 | 22 | ### Contribution 23 | 24 | Unless you explicitly state otherwise, any contribution intentionally 25 | submitted for inclusion in the work by you, as defined in the Apache-2.0 26 | license, shall be dual licensed as above, without any additional terms or 27 | conditions. 28 | 29 | The intent of this crate is to be free of soundness bugs. The developers will 30 | do their best to avoid them, and welcome help in analysing and fixing them. 31 | 32 | 33 | [API Docs]: https://michael-f-bryan.github.io/constraints 34 | [crev]: https://github.com/crev-dev/cargo-crev 35 | -------------------------------------------------------------------------------- /examples/cli.rs: -------------------------------------------------------------------------------- 1 | use constraints::{ops::Builtins, SystemOfEquations}; 2 | use std::io::{BufRead, BufReader, Write}; 3 | 4 | fn main() -> Result<(), Box> { 5 | let mut system = SystemOfEquations::new(); 6 | let stdin = std::io::stdin(); 7 | let mut lines = BufReader::new(stdin.lock()).lines(); 8 | 9 | loop { 10 | let mut stdout = std::io::stdout(); 11 | write!(stdout, "> ")?; 12 | stdout.flush()?; 13 | 14 | match lines.next() { 15 | Some(Ok(line)) => match line.parse() { 16 | Ok(equation) => system.push(equation), 17 | Err(e) => eprintln!("Unable to parse \"{}\": {}", line, e), 18 | }, 19 | Some(Err(e)) => return Err(Box::new(e)), 20 | None => break, 21 | } 22 | } 23 | 24 | println!(); 25 | 26 | let unknowns: Vec<_> =
system.unknowns().iter().map(ToString::to_string).collect(); 28 | println!("Solving for {}", unknowns.join(", ")); 29 | 30 | let ctx = Builtins::default(); 31 | let solution = system.solve(&ctx)?; 32 | 33 | println!("Found:"); 34 | 35 | for (name, value) in &solution.known_values { 36 | println!(" {} = {}", name, value); 37 | } 38 | 39 | Ok(()) 40 | } 41 | -------------------------------------------------------------------------------- /src/equations.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | ops::{self, Context, EvaluationError}, 3 | solve::{Solution, SolveError}, 4 | Expression, Parameter, ParseError, 5 | }; 6 | use nalgebra::DVector as Vector; 7 | use std::{ 8 | fmt::Debug, 9 | iter::{Extend, FromIterator}, 10 | str::FromStr, 11 | }; 12 | 13 | #[derive(Debug, Clone, PartialEq)] 14 | pub struct Equation { 15 | pub(crate) body: Expression, 16 | } 17 | 18 | impl Equation { 19 | pub fn new(left: Expression, right: Expression) -> Self { 20 | debug_assert_ne!( 21 | left.params().count() + right.params().count(), 22 | 0, 23 | "Equations should contain at least one unknown" 24 | ); 25 | Equation { body: left - right } 26 | } 27 | } 28 | 29 | impl FromStr for Equation { 30 | type Err = ParseError; 31 | 32 | fn from_str(s: &str) -> Result { 33 | match s.find("=") { 34 | Some(index) => { 35 | let (left, right) = s.split_at(index); 36 | let right = &right[1..]; 37 | Ok(Equation::new(left.parse()?, right.parse()?)) 38 | }, 39 | None => Ok(Equation { body: s.parse()? }), 40 | } 41 | } 42 | } 43 | 44 | /// A builder for constructing a system of equations and solving them. 
45 | #[derive(Debug, Default, Clone, PartialEq)] 46 | pub struct SystemOfEquations { 47 | pub(crate) equations: Vec, 48 | } 49 | 50 | impl SystemOfEquations { 51 | pub fn new() -> Self { SystemOfEquations::default() } 52 | 53 | pub fn with(mut self, equation: Equation) -> Self { 54 | self.push(equation); 55 | self 56 | } 57 | 58 | pub fn push(&mut self, equation: Equation) { 59 | self.equations.push(equation); 60 | } 61 | 62 | pub fn solve(self, ctx: &C) -> Result 63 | where 64 | C: Context, 65 | { 66 | let unknowns = self.unknowns(); 67 | crate::solve::solve(&self.equations, &unknowns, &self, ctx) 68 | } 69 | 70 | pub fn unknowns(&self) -> Vec { 71 | let mut unknowns: Vec<_> = self 72 | .equations 73 | .iter() 74 | .flat_map(|eq| eq.body.params()) 75 | .cloned() 76 | .collect(); 77 | unknowns.sort(); 78 | unknowns.dedup(); 79 | 80 | unknowns 81 | } 82 | 83 | pub fn num_unknowns(&self) -> usize { self.unknowns().len() } 84 | 85 | pub fn from_equations(equations: E) -> Result 86 | where 87 | E: IntoIterator, 88 | S: AsRef, 89 | { 90 | let mut system = SystemOfEquations::new(); 91 | 92 | for equation in equations { 93 | system.push(equation.as_ref().parse()?); 94 | } 95 | 96 | Ok(system) 97 | } 98 | 99 | pub(crate) fn evaluate( 100 | &self, 101 | lookup_parameter_value: F, 102 | ctx: &C, 103 | ) -> Result, EvaluationError> 104 | where 105 | F: Fn(&Parameter) -> Option, 106 | C: Context, 107 | { 108 | let mut values = Vec::new(); 109 | 110 | for equation in &self.equations { 111 | values.push(ops::evaluate( 112 | &equation.body, 113 | &lookup_parameter_value, 114 | ctx, 115 | )?); 116 | } 117 | 118 | Ok(Vector::from_vec(values)) 119 | } 120 | } 121 | 122 | impl Extend for SystemOfEquations { 123 | fn extend>(&mut self, iter: T) { 124 | self.equations.extend(iter); 125 | } 126 | } 127 | 128 | impl FromIterator for SystemOfEquations { 129 | fn from_iter>(iter: T) -> Self { 130 | SystemOfEquations { 131 | equations: Vec::from_iter(iter), 132 | } 133 | } 134 | } 135 | 136 
| impl<'a> IntoIterator for &'a SystemOfEquations { 137 | type IntoIter = <&'a [Equation] as IntoIterator>::IntoIter; 138 | type Item = &'a Equation; 139 | 140 | fn into_iter(self) -> Self::IntoIter { self.equations.iter() } 141 | } 142 | 143 | impl IntoIterator for SystemOfEquations { 144 | type IntoIter = as IntoIterator>::IntoIter; 145 | type Item = Equation; 146 | 147 | fn into_iter(self) -> Self::IntoIter { self.equations.into_iter() } 148 | } 149 | -------------------------------------------------------------------------------- /src/expr.rs: -------------------------------------------------------------------------------- 1 | use crate::parse::ParseError; 2 | use smol_str::SmolStr; 3 | use std::{ 4 | fmt::{self, Display, Formatter}, 5 | ops::{Add, Div, Mul, Neg, Sub}, 6 | rc::Rc, 7 | str::FromStr, 8 | }; 9 | 10 | // PERF: Switch from Rc to Arc and use Arc::make_mut() 11 | // to get efficient copy-on-write semantics 12 | 13 | /// An expression. 14 | #[derive(Debug, Clone, PartialEq)] 15 | pub enum Expression { 16 | /// A free variable (e.g. `x`). 17 | Parameter(Parameter), 18 | /// A fixed constant (e.g. `3.14`). 19 | Constant(f64), 20 | /// An expression involving two operands. 21 | Binary { 22 | /// The left operand. 23 | left: Rc, 24 | /// The right operand. 25 | right: Rc, 26 | /// The binary operation being executed. 27 | op: BinaryOperation, 28 | }, 29 | /// Negate the expression. 30 | Negate(Rc), 31 | /// Invoke a builtin function. 32 | FunctionCall { 33 | /// The name of the function being called. 34 | name: SmolStr, 35 | /// The argument passed to this function call. 36 | argument: Rc, 37 | }, 38 | } 39 | 40 | impl Expression { 41 | /// Iterate over all [`Expression`]s in this [`Expression`] tree. 42 | pub fn iter(&self) -> impl Iterator + '_ { 43 | Iter { 44 | to_visit: vec![self], 45 | } 46 | } 47 | 48 | /// Iterate over all [`Parameter`]s mentioned in this [`Expression`]. 
49 | pub fn params(&self) -> impl Iterator + '_ { 50 | self.iter().filter_map(|expr| match expr { 51 | Expression::Parameter(p) => Some(p), 52 | _ => None, 53 | }) 54 | } 55 | 56 | /// Does this [`Expression`] involve a particular [`Parameter`]? 57 | pub fn depends_on(&self, param: &Parameter) -> bool { 58 | self.params().any(|p| p == param) 59 | } 60 | 61 | /// Is this a [`Expression::Constant`] expression? 62 | pub fn is_constant(&self) -> bool { 63 | match self { 64 | Expression::Constant(_) => true, 65 | _ => false, 66 | } 67 | } 68 | 69 | /// Iterate over all functions used by this [`Expression`]. 70 | pub fn functions(&self) -> impl Iterator + '_ { 71 | self.iter().filter_map(|expr| match expr { 72 | Expression::FunctionCall { name, .. } => Some(name.as_ref()), 73 | _ => None, 74 | }) 75 | } 76 | } 77 | 78 | /// A depth-first iterator over the sub-[`Expression`]s in an [`Expression`]. 79 | #[derive(Debug)] 80 | struct Iter<'expr> { 81 | to_visit: Vec<&'expr Expression>, 82 | } 83 | 84 | impl<'expr> Iterator for Iter<'expr> { 85 | type Item = &'expr Expression; 86 | 87 | fn next(&mut self) -> Option { 88 | let next_item = self.to_visit.pop()?; 89 | 90 | match next_item { 91 | Expression::Binary { left, right, .. } => { 92 | self.to_visit.push(right); 93 | self.to_visit.push(left); 94 | }, 95 | Expression::Negate(inner) => self.to_visit.push(inner), 96 | Expression::FunctionCall { argument, .. } => { 97 | self.to_visit.push(argument) 98 | }, 99 | _ => {}, 100 | } 101 | 102 | Some(next_item) 103 | } 104 | } 105 | 106 | /// A free variable. 107 | #[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)] 108 | pub enum Parameter { 109 | /// A variable with associated name. 110 | Named(SmolStr), 111 | /// An anonymous variable generated by the system. 112 | Anonymous { number: usize }, 113 | } 114 | 115 | impl Parameter { 116 | /// Create a new [`Parameter::Named`] parameter, automatically converting 117 | /// the `name` to a `SmolStr`. 
118 | pub fn named(name: S) -> Self 119 | where 120 | S: Into, 121 | { 122 | Parameter::Named(name.into()) 123 | } 124 | } 125 | 126 | impl Display for Parameter { 127 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 128 | match self { 129 | Parameter::Named(name) => write!(f, "{}", name), 130 | Parameter::Anonymous { number } => write!(f, "${}", number), 131 | } 132 | } 133 | } 134 | 135 | /// An operation that can be applied to two arguments. 136 | #[derive(Debug, Copy, Clone, PartialEq)] 137 | pub enum BinaryOperation { 138 | Plus, 139 | Minus, 140 | Times, 141 | Divide, 142 | } 143 | 144 | // define some operator overloads to make constructing an expression easier. 145 | 146 | impl Add for Expression { 147 | type Output = Expression; 148 | 149 | fn add(self, rhs: Expression) -> Expression { 150 | Expression::Binary { 151 | left: Rc::new(self), 152 | right: Rc::new(rhs), 153 | op: BinaryOperation::Plus, 154 | } 155 | } 156 | } 157 | 158 | impl Sub for Expression { 159 | type Output = Expression; 160 | 161 | fn sub(self, rhs: Expression) -> Expression { 162 | Expression::Binary { 163 | left: Rc::new(self), 164 | right: Rc::new(rhs), 165 | op: BinaryOperation::Minus, 166 | } 167 | } 168 | } 169 | 170 | impl Mul for Expression { 171 | type Output = Expression; 172 | 173 | fn mul(self, rhs: Expression) -> Expression { 174 | Expression::Binary { 175 | left: Rc::new(self), 176 | right: Rc::new(rhs), 177 | op: BinaryOperation::Times, 178 | } 179 | } 180 | } 181 | 182 | impl Div for Expression { 183 | type Output = Expression; 184 | 185 | fn div(self, rhs: Expression) -> Expression { 186 | Expression::Binary { 187 | left: Rc::new(self), 188 | right: Rc::new(rhs), 189 | op: BinaryOperation::Divide, 190 | } 191 | } 192 | } 193 | 194 | impl Neg for Expression { 195 | type Output = Expression; 196 | 197 | fn neg(self) -> Self::Output { Expression::Negate(Rc::new(self)) } 198 | } 199 | 200 | impl FromStr for Expression { 201 | type Err = ParseError; 202 | 203 | fn 
from_str(s: &str) -> Result { crate::parse(s) } 204 | } 205 | 206 | impl Display for Expression { 207 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 208 | match self { 209 | Expression::Parameter(p) => write!(f, "{}", p), 210 | Expression::Constant(value) => write!(f, "{}", value), 211 | Expression::Binary { left, right, op } => { 212 | write_with_precedence(op.precedence(), left, f)?; 213 | 214 | let middle = match op { 215 | BinaryOperation::Plus => " + ", 216 | BinaryOperation::Minus => " - ", 217 | BinaryOperation::Times => "*", 218 | BinaryOperation::Divide => "/", 219 | }; 220 | write!(f, "{}", middle)?; 221 | // `-` and `/` are not associative (`a - (b - c)` != `a - b - c`, `a/(b*c)` != `a/b*c`), so their right operand must be parenthesised unless it binds strictly tighter than the operator itself. 222 | let right_threshold = match op { BinaryOperation::Minus => Precedence::Md, BinaryOperation::Divide => Precedence::Bi, _ => op.precedence() }; 223 | write_with_precedence(right_threshold, right, f)?; 224 | Ok(()) 225 | }, 226 | Expression::Negate(inner) => { 227 | write!(f, "-")?; 228 | 229 | write_with_precedence( 230 | BinaryOperation::Times.precedence(), 231 | inner, 232 | f, 233 | )?; 234 | Ok(()) 235 | }, 236 | Expression::FunctionCall { name, argument } => { 237 | write!(f, "{}({})", name, argument) 238 | }, 239 | } 240 | } 241 | } 242 | 243 | impl Expression { 244 | fn precedence(&self) -> Precedence { 245 | match self { 246 | Expression::Parameter(_) 247 | | Expression::Constant(_) 248 | | Expression::FunctionCall { .. } => Precedence::Bi, 249 | Expression::Negate(_) => Precedence::Md, 250 | Expression::Binary { op, ..
} => op.precedence(), 251 | } 252 | } 253 | } 254 | 255 | impl BinaryOperation { 256 | fn precedence(self) -> Precedence { 257 | match self { 258 | BinaryOperation::Plus | BinaryOperation::Minus => Precedence::As, 259 | BinaryOperation::Times | BinaryOperation::Divide => Precedence::Md, 260 | } 261 | } 262 | } 263 | 264 | #[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] 265 | enum Precedence { 266 | Bi, 267 | Md, 268 | As, 269 | } 270 | 271 | fn write_with_precedence( 272 | current_precedence: Precedence, 273 | expr: &Expression, 274 | f: &mut Formatter<'_>, 275 | ) -> fmt::Result { 276 | if expr.precedence() > current_precedence { 277 | // we need parentheses to maintain ordering 278 | write!(f, "({})", expr) 279 | } else { 280 | write!(f, "{}", expr) 281 | } 282 | } 283 | 284 | #[cfg(test)] 285 | mod tests { 286 | use super::*; 287 | 288 | #[test] 289 | fn pretty_printing_works_similarly_to_a_human() { 290 | let inputs = vec![ 291 | (Expression::Constant(3.0), "3"), 292 | ( 293 | Expression::FunctionCall { 294 | name: "sin".into(), 295 | argument: Rc::new(Expression::Constant(5.0)), 296 | }, 297 | "sin(5)", 298 | ), 299 | (Expression::Negate(Rc::new(Expression::Constant(5.0))), "-5"), 300 | ( 301 | Expression::Negate(Rc::new(Expression::FunctionCall { 302 | name: "sin".into(), 303 | argument: Rc::new(Expression::Constant(5.0)), 304 | })), 305 | "-sin(5)", 306 | ), 307 | ( 308 | Expression::Binary { 309 | left: Rc::new(Expression::Constant(1.0)), 310 | right: Rc::new(Expression::Constant(1.0)), 311 | op: BinaryOperation::Plus, 312 | }, 313 | "1 + 1", 314 | ), 315 | ( 316 | Expression::Binary { 317 | left: Rc::new(Expression::Constant(1.0)), 318 | right: Rc::new(Expression::Constant(1.0)), 319 | op: BinaryOperation::Minus, 320 | }, 321 | "1 - 1", 322 | ), 323 | ( 324 | Expression::Binary { 325 | left: Rc::new(Expression::Constant(1.0)), 326 | right: Rc::new(Expression::Constant(1.0)), 327 | op: BinaryOperation::Times, 328 | }, 329 | "1*1", 330 | ), 331 | ( 332 
| Expression::Binary { 333 | left: Rc::new(Expression::Constant(1.0)), 334 | right: Rc::new(Expression::Constant(1.0)), 335 | op: BinaryOperation::Divide, 336 | }, 337 | "1/1", 338 | ), 339 | ( 340 | Expression::Negate(Rc::new(Expression::Binary { 341 | left: Rc::new(Expression::Constant(1.0)), 342 | right: Rc::new(Expression::Constant(1.0)), 343 | op: BinaryOperation::Plus, 344 | })), 345 | "-(1 + 1)", 346 | ), 347 | ( 348 | Expression::Negate(Rc::new(Expression::Binary { 349 | left: Rc::new(Expression::Constant(1.0)), 350 | right: Rc::new(Expression::Constant(1.0)), 351 | op: BinaryOperation::Times, 352 | })), 353 | "-1*1", 354 | ), 355 | ( 356 | Expression::Binary { 357 | left: Rc::new(Expression::Binary { 358 | left: Rc::new(Expression::Constant(1.0)), 359 | right: Rc::new(Expression::Constant(2.0)), 360 | op: BinaryOperation::Plus, 361 | }), 362 | right: Rc::new(Expression::Constant(3.0)), 363 | op: BinaryOperation::Divide, 364 | }, 365 | "(1 + 2)/3", 366 | ), 367 | ( 368 | Expression::Binary { 369 | left: Rc::new(Expression::Constant(3.0)), 370 | right: Rc::new(Expression::Binary { 371 | left: Rc::new(Expression::Constant(1.0)), 372 | right: Rc::new(Expression::Constant(2.0)), 373 | op: BinaryOperation::Times, 374 | }), 375 | op: BinaryOperation::Minus, 376 | }, 377 | "3 - 1*2", 378 | ), 379 | ]; 380 | 381 | for (expr, should_be) in inputs { 382 | let got = expr.to_string(); 383 | assert_eq!(got, should_be); 384 | } 385 | } 386 | 387 | #[test] 388 | fn iterate_over_parameters_in_an_expression() { 389 | let expr: Expression = 390 | "a + sin(5*(b + (c - d)) / -e) - a * f".parse().unwrap(); 391 | let a = Parameter::named("a"); 392 | let b = Parameter::named("b"); 393 | let c = Parameter::named("c"); 394 | let d = Parameter::named("d"); 395 | let e = Parameter::named("e"); 396 | let f = Parameter::named("f"); 397 | 398 | let got: Vec<_> = expr.params().collect(); 399 | 400 | assert_eq!(got, &[&a, &b, &c, &d, &e, &a, &f]); 401 | } 402 | } 403 | 
-------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! The symbolic algebra system. 2 | 3 | #[cfg(test)] 4 | #[macro_use] 5 | extern crate pretty_assertions; 6 | 7 | mod equations; 8 | mod expr; 9 | pub mod ops; 10 | mod parse; 11 | mod solve; 12 | 13 | pub use equations::{Equation, SystemOfEquations}; 14 | pub use expr::{BinaryOperation, Expression, Parameter}; 15 | pub use parse::{parse, ParseError, TokenKind}; 16 | -------------------------------------------------------------------------------- /src/ops.rs: -------------------------------------------------------------------------------- 1 | //! [`Expression`] operations. 2 | 3 | use crate::{BinaryOperation, Expression, Parameter}; 4 | use euclid::approxeq::ApproxEq; 5 | use smol_str::SmolStr; 6 | use std::{ 7 | error::Error, 8 | fmt::{self, Display, Formatter}, 9 | rc::Rc, 10 | }; 11 | 12 | /// Contextual information used when evaluating an [`Expression`]. 13 | pub trait Context { 14 | /// Evaluate a function by name. 15 | fn evaluate_function( 16 | &self, 17 | name: &str, 18 | argument: f64, 19 | ) -> Result; 20 | 21 | /// For some [`Parameter`], `x`, and a function called `name`, get 22 | /// `name'(x)`. 23 | fn differentiate_function( 24 | &self, 25 | name: &str, 26 | param: &Parameter, 27 | ) -> Result; 28 | } 29 | 30 | /// Errors that may occur while evaulating an [`Expression`]. 31 | #[derive(Debug, Clone, PartialEq)] 32 | pub enum EvaluationError { 33 | /// We don't know this function. 34 | UnknownFunction { 35 | /// The function's name. 36 | name: SmolStr, 37 | }, 38 | /// The [`Context`] is unable to differentiate the function. 39 | UnableToDifferentiate { 40 | /// The function that we tried to differentiate. 41 | name: SmolStr, 42 | }, 43 | /// The [`Expression`] contains a [`Parameter`] which hasn't yet been 44 | /// evaluated. 
45 | /// 46 | /// Consider using [`substitute()`] to replace the [`Parameter`] with its 47 | /// value (i.e. a [`Expression::Constant`]). 48 | UnevaluatedParameter(Parameter), 49 | } 50 | 51 | impl Display for EvaluationError { 52 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 53 | match self { 54 | EvaluationError::UnknownFunction { name } => { 55 | write!(f, "No known function called \"{}\"", name) 56 | }, 57 | EvaluationError::UnableToDifferentiate { name } => { 58 | write!(f, "Unable to differentiate \"{}\"", name) 59 | }, 60 | EvaluationError::UnevaluatedParameter(p) => { 61 | write!(f, "The parameter, {}, needs to have a value", p) 62 | }, 63 | } 64 | } 65 | } 66 | 67 | impl Error for EvaluationError {} 68 | 69 | /// The set of builtin functions. 70 | #[derive(Debug, Default)] 71 | pub struct Builtins; 72 | 73 | impl Context for Builtins { 74 | fn evaluate_function( 75 | &self, 76 | name: &str, 77 | argument: f64, 78 | ) -> Result { 79 | match name { 80 | "sin" => Ok(argument.to_radians().sin()), 81 | "cos" => Ok(argument.to_radians().cos()), 82 | "tan" => Ok(argument.to_radians().tan()), 83 | "asin" => Ok(argument.asin().to_degrees()), 84 | "acos" => Ok(argument.acos().to_degrees()), 85 | "atan" => Ok(argument.atan().to_degrees()), 86 | "sqrt" => Ok(argument.sqrt()), 87 | _ => Err(EvaluationError::UnknownFunction { name: name.into() }), 88 | } 89 | } 90 | 91 | fn differentiate_function( 92 | &self, 93 | name: &str, 94 | param: &Parameter, 95 | ) -> Result { 96 | match name { 97 | "sin" => Ok(Expression::FunctionCall { 98 | name: "cos".into(), 99 | argument: Rc::new(Expression::Parameter(param.clone())), 100 | }), 101 | "cos" => Ok(-Expression::FunctionCall { 102 | name: "sin".into(), 103 | argument: Rc::new(Expression::Parameter(param.clone())), 104 | }), 105 | "sqrt" => { 106 | let sqrt_x = Expression::FunctionCall { 107 | name: "sqrt".into(), 108 | argument: Rc::new(Expression::Parameter(param.clone())), 109 | }; 110 | Ok(Expression::Constant(0.5) 
/ sqrt_x) 111 | }, 112 | _ => Err(EvaluationError::UnableToDifferentiate { 113 | name: name.into(), 114 | }), 115 | } 116 | } 117 | } 118 | 119 | /// Simplify an expression by evaluating all constant operations. 120 | pub fn fold_constants(expr: &Expression, ctx: &C) -> Expression 121 | where 122 | C: Context, 123 | { 124 | match expr { 125 | Expression::Binary { left, right, op } => { 126 | fold_binary_op(left, right, *op, ctx) 127 | }, 128 | Expression::Negate(expr) => match fold_constants(expr, ctx) { 129 | Expression::Constant(value) => Expression::Constant(-value), 130 | // double negative 131 | Expression::Negate(inner) => inner.as_ref().clone(), 132 | other => Expression::Negate(Rc::new(other)), 133 | }, 134 | Expression::FunctionCall { name, argument } => { 135 | let argument = fold_constants(argument, ctx); 136 | 137 | if let Expression::Constant(argument) = argument { 138 | if let Ok(result) = ctx.evaluate_function(name, argument) { 139 | return Expression::Constant(result); 140 | } 141 | } 142 | 143 | Expression::FunctionCall { 144 | name: name.clone(), 145 | argument: Rc::new(argument), 146 | } 147 | }, 148 | _ => expr.clone(), 149 | } 150 | } 151 | 152 | fn fold_binary_op( 153 | left: &Expression, 154 | right: &Expression, 155 | op: BinaryOperation, 156 | ctx: &C, 157 | ) -> Expression 158 | where 159 | C: Context, 160 | { 161 | let left = fold_constants(left, ctx); 162 | let right = fold_constants(right, ctx); 163 | 164 | // If our operands contain constants, we can use arithmetic's identity laws 165 | // to simplify things 166 | match (left, right, op) { 167 | ( 168 | Expression::Parameter(p_left), 169 | Expression::Parameter(p_right), 170 | BinaryOperation::Plus, 171 | ) if p_left == p_right => { 172 | Expression::Constant(2.0) * Expression::Parameter(p_right.clone()) 173 | }, 174 | ( 175 | Expression::Parameter(p_left), 176 | Expression::Parameter(p_right), 177 | BinaryOperation::Minus, 178 | ) if p_left == p_right => Expression::Constant(0.0), 179 
| ( 180 | Expression::Parameter(p_left), 181 | Expression::Parameter(p_right), 182 | BinaryOperation::Divide, 183 | ) if p_left == p_right => Expression::Constant(1.0), 184 | 185 | // x + 0 = x 186 | (Expression::Constant(l), right, BinaryOperation::Plus) 187 | if l.approx_eq(&0.0) => 188 | { 189 | right 190 | }, 191 | (left, Expression::Constant(r), BinaryOperation::Plus) 192 | if r.approx_eq(&0.0) => 193 | { 194 | left 195 | }, 196 | 197 | // 0 * x = 0 198 | (Expression::Constant(l), _, BinaryOperation::Times) 199 | if l.approx_eq(&0.0) => 200 | { 201 | Expression::Constant(0.0) 202 | }, 203 | (_, Expression::Constant(r), BinaryOperation::Times) 204 | if r.approx_eq(&0.0) => 205 | { 206 | Expression::Constant(0.0) 207 | }, 208 | 209 | // 1 * x = x 210 | (Expression::Constant(l), right, BinaryOperation::Times) 211 | if l.approx_eq(&1.0) => 212 | { 213 | right 214 | }, 215 | (left, Expression::Constant(r), BinaryOperation::Times) 216 | if r.approx_eq(&1.0) => 217 | { 218 | left 219 | }, 220 | 221 | // 0 / x = 0 222 | (Expression::Constant(l), _, BinaryOperation::Divide) 223 | if l.approx_eq(&0.0) => 224 | { 225 | Expression::Constant(0.0) 226 | }, 227 | 228 | // x / 1 = x 229 | (left, Expression::Constant(r), BinaryOperation::Divide) 230 | if r.approx_eq(&1.0) => 231 | { 232 | left 233 | }, 234 | 235 | // 0 - x = -x 236 | (Expression::Constant(l), right, BinaryOperation::Minus) 237 | if l.approx_eq(&0.0) => 238 | { 239 | -right 240 | }, 241 | 242 | // x - 0 = x 243 | (left, Expression::Constant(r), BinaryOperation::Minus) 244 | if r.approx_eq(&0.0) => 245 | { 246 | left 247 | }, 248 | 249 | // (x * y) * z 250 | ( 251 | Expression::Constant(constant_a), 252 | Expression::Binary { 253 | left, 254 | right, 255 | op: BinaryOperation::Times, 256 | }, 257 | BinaryOperation::Times, 258 | ) if left.is_constant() || right.is_constant() => { 259 | let (constant_b, expr) = match (&*left, &*right) { 260 | (Expression::Constant(left), right) => (left, right), 261 | (left, 
Expression::Constant(right)) => (right, left), 262 | _ => unreachable!(), 263 | }; 264 | Expression::Constant(constant_a * constant_b) 265 | * Expression::clone(expr) 266 | }, 267 | ( 268 | Expression::Binary { 269 | left, 270 | right, 271 | op: BinaryOperation::Times, 272 | }, 273 | Expression::Constant(constant_a), 274 | BinaryOperation::Times, 275 | ) if left.is_constant() || right.is_constant() => { 276 | let (constant_b, expr) = match (&*left, &*right) { 277 | (Expression::Constant(left), right) => (left, right), 278 | (left, Expression::Constant(right)) => (right, left), 279 | _ => unreachable!(), 280 | }; 281 | Expression::Constant(constant_a * constant_b) 282 | * Expression::clone(expr) 283 | }, 284 | 285 | // Evaluate in-place 286 | (Expression::Constant(l), Expression::Constant(r), op) => { 287 | let value = match op { 288 | BinaryOperation::Plus => l + r, 289 | BinaryOperation::Minus => l - r, 290 | BinaryOperation::Times => l * r, 291 | BinaryOperation::Divide => l / r, 292 | }; 293 | 294 | Expression::Constant(value) 295 | }, 296 | 297 | // Oh well, we tried 298 | (left, right, op) => Expression::Binary { 299 | left: Rc::new(left), 300 | right: Rc::new(right), 301 | op, 302 | }, 303 | } 304 | } 305 | 306 | /// Replace all references to a [`Parameter`] with an [`Expression`]. 
307 | pub fn substitute( 308 | expression: &Expression, 309 | param: &Parameter, 310 | value: &Expression, 311 | ) -> Expression { 312 | match expression { 313 | Expression::Parameter(p) => { 314 | if p == param { 315 | value.clone() 316 | } else { 317 | Expression::Parameter(p.clone()) 318 | } 319 | }, 320 | Expression::Constant(value) => Expression::Constant(*value), 321 | Expression::Binary { left, right, op } => { 322 | let left = substitute(left, param, value); 323 | let right = substitute(right, param, value); 324 | Expression::Binary { 325 | left: Rc::new(left), 326 | right: Rc::new(right), 327 | op: *op, 328 | } 329 | }, 330 | Expression::Negate(inner) => -substitute(inner, param, value), 331 | Expression::FunctionCall { name, argument } => { 332 | Expression::FunctionCall { 333 | name: name.clone(), 334 | argument: Rc::new(substitute(argument, param, value)), 335 | } 336 | }, 337 | } 338 | } 339 | 340 | /// Calculate an [`Expression`]'s partial derivative with respect to a 341 | /// particular [`Parameter`]. 342 | pub fn partial_derivative( 343 | expr: &Expression, 344 | param: &Parameter, 345 | ctx: &C, 346 | ) -> Result 347 | where 348 | C: Context, 349 | { 350 | let got = match expr { 351 | Expression::Parameter(p) => { 352 | if p == param { 353 | Expression::Constant(1.0) 354 | } else { 355 | Expression::Constant(0.0) 356 | } 357 | }, 358 | Expression::Constant(_) => Expression::Constant(0.0), 359 | Expression::Binary { 360 | left, 361 | right, 362 | op: BinaryOperation::Plus, 363 | } => { 364 | partial_derivative(left, param, ctx)? 365 | + partial_derivative(right, param, ctx)? 366 | }, 367 | Expression::Binary { 368 | left, 369 | right, 370 | op: BinaryOperation::Minus, 371 | } => { 372 | partial_derivative(left, param, ctx)? 373 | - partial_derivative(right, param, ctx)? 
374 | }, 375 | Expression::Binary { 376 | left, 377 | right, 378 | op: BinaryOperation::Times, 379 | } => { 380 | // The product rule 381 | let d_left = partial_derivative(left, param, ctx)?; 382 | let d_right = partial_derivative(right, param, ctx)?; 383 | let left = Expression::clone(left); 384 | let right = Expression::clone(right); 385 | 386 | d_left * right + d_right * left 387 | }, 388 | Expression::Binary { 389 | left, 390 | right, 391 | op: BinaryOperation::Divide, 392 | } => { 393 | // The quotient rule 394 | let d_left = partial_derivative(left, param, ctx)?; 395 | let d_right = partial_derivative(right, param, ctx)?; 396 | let right = Expression::clone(right); 397 | let left = Expression::clone(left); 398 | 399 | (d_left * right.clone() + left * d_right) / (right.clone() * right) 400 | }, 401 | 402 | Expression::Negate(inner) => -partial_derivative(inner, param, ctx)?, 403 | Expression::FunctionCall { name, argument } => { 404 | // implement the chain rule: (f o g)' = (f' o g) * g' 405 | let g = Parameter::named("__temp__"); 406 | let f_dash_of_g = ctx.differentiate_function(name, &g)?; 407 | let g_dash = partial_derivative(argument, param, ctx)?; 408 | 409 | substitute(&f_dash_of_g, &g, argument) * g_dash 410 | }, 411 | }; 412 | 413 | Ok(got) 414 | } 415 | 416 | pub fn evaluate( 417 | expr: &Expression, 418 | parameter_value: &F, 419 | ctx: &C, 420 | ) -> Result 421 | where 422 | C: Context, 423 | F: Fn(&Parameter) -> Option, 424 | { 425 | match expr { 426 | Expression::Parameter(p) => match parameter_value(p) { 427 | Some(value) => Ok(value), 428 | None => Err(EvaluationError::UnevaluatedParameter(p.clone())), 429 | }, 430 | Expression::Constant(value) => Ok(*value), 431 | Expression::Binary { left, right, op } => { 432 | let left = evaluate(left, parameter_value, ctx)?; 433 | let right = evaluate(right, parameter_value, ctx)?; 434 | Ok(match op { 435 | BinaryOperation::Plus => left + right, 436 | BinaryOperation::Minus => left - right, 437 | 
BinaryOperation::Times => left * right, 438 | BinaryOperation::Divide => left / right, 439 | }) 440 | }, 441 | Expression::Negate(inner) => { 442 | let inner = evaluate(inner, parameter_value, ctx)?; 443 | Ok(-inner) 444 | }, 445 | Expression::FunctionCall { name, argument } => { 446 | let argument = evaluate(argument, parameter_value, ctx)?; 447 | ctx.evaluate_function(name, argument) 448 | }, 449 | } 450 | } 451 | 452 | #[cfg(test)] 453 | mod tests { 454 | use super::*; 455 | 456 | #[test] 457 | fn constant_fold_simple_arithmetic() { 458 | let inputs = vec![ 459 | ("1", 1.0), 460 | ("1 + 1.5", 1.0 + 1.5), 461 | ("1 - 1.5", 1.0 - 1.5), 462 | ("2 * 3", 2.0 * 3.0), 463 | ("4 / 2", 4.0 / 2.0), 464 | ("sqrt(4)", 4_f64.sqrt()), 465 | ("sqrt(2 + 2)", (2_f64 + 2.0).sqrt()), 466 | ("sin(90)", 90_f64.to_radians().sin()), 467 | ("atan(1)", 45.0), 468 | ("sqrt(2 + sqrt(4))", (2.0 + 4_f64.sqrt()).sqrt()), 469 | ("-(1 + 2)", -(1.0 + 2.0)), 470 | ("0 * x", 0.0), 471 | ("x - x", 0.0), 472 | ("x/x", 1.0), 473 | ]; 474 | let ctx = Builtins::default(); 475 | 476 | for (src, should_be) in inputs { 477 | let expr: Expression = src.parse().unwrap(); 478 | let got = fold_constants(&expr, &ctx); 479 | 480 | match got { 481 | Expression::Constant(value) => assert_eq!( 482 | value, should_be, 483 | "{} -> {} != {}", 484 | expr, value, should_be 485 | ), 486 | other => panic!( 487 | "Expected a constant expression, but got \"{}\"", 488 | other 489 | ), 490 | } 491 | } 492 | } 493 | 494 | #[test] 495 | fn constant_folding_leaves_unknowns_unevaluated() { 496 | let inputs = vec![ 497 | ("x", "x"), 498 | ("-(2 * 3 + x)", "-(6 + x)"), 499 | ("unknown_function(3)", "unknown_function(3)"), 500 | ("2 * x * 3", "6 * x"), 501 | ("x + 5", "x + 5"), 502 | ("x + 5*2", "x + 10"), 503 | ("x + x", "2*x"), 504 | ("0 + x", "x"), 505 | ("x + 0", "x"), 506 | ("1 * x", "x"), 507 | ("x * 1", "x"), 508 | ("x - 0", "x"), 509 | ("0 - x", "-x"), 510 | ("x / 1", "x"), 511 | ("--x", "x"), 512 | ("(x + x)*3 + 5", "6*x 
+ 5"), 513 | ]; 514 | let ctx = Builtins::default(); 515 | 516 | for (src, should_be) in inputs { 517 | let expr: Expression = src.parse().unwrap(); 518 | 519 | let got = fold_constants(&expr, &ctx); 520 | 521 | let should_be: Expression = should_be.parse().unwrap(); 522 | 523 | assert_eq!(got, should_be, "{} != {}", got, should_be); 524 | } 525 | } 526 | 527 | #[test] 528 | fn basic_substitutions() { 529 | let parameter = Parameter::named("x"); 530 | let inputs = vec![ 531 | ("1 + 2", "3", "1 + 2"), 532 | ("x", "5", "5"), 533 | ("y", "5", "y"), 534 | ("x + 5", "5", " 5 + 5"), 535 | ("-x", "5", "-5"), 536 | ("sin(x)", "y + y", "sin(y + y)"), 537 | ]; 538 | 539 | for (src, new_value, should_be) in inputs { 540 | let original: Expression = src.parse().unwrap(); 541 | let new_value: Expression = new_value.parse().unwrap(); 542 | let should_be: Expression = should_be.parse().unwrap(); 543 | 544 | let got = substitute(&original, ¶meter, &new_value); 545 | 546 | assert_eq!(got, should_be, "{} != {}", got, should_be); 547 | } 548 | } 549 | 550 | #[test] 551 | fn differentiate_wrt_x() { 552 | let x = Parameter::named("x"); 553 | let inputs = vec![ 554 | ("x", "1"), 555 | ("1", "0"), 556 | ("x*x", "2 * x"), 557 | ("3*x*x + 5*x + 2", "6*x + 5"), 558 | ("x - y", "1"), 559 | ("sin(x)", "cos(x)"), 560 | ("cos(x)", "-sin(x)"), 561 | ("sqrt(x)", "0.5 / sqrt(x)"), 562 | ("x/y", "y/y*y"), // = 1/y, simplification just isn't smart enough 563 | ]; 564 | let ctx = Builtins::default(); 565 | 566 | for (src, should_be) in inputs { 567 | let original: Expression = src.parse().unwrap(); 568 | let should_be: Expression = should_be.parse().unwrap(); 569 | 570 | let got = partial_derivative(&original, &x, &ctx).unwrap(); 571 | let got = fold_constants(&got, &ctx); 572 | 573 | assert_eq!(got, should_be, "{} != {}", got, should_be); 574 | } 575 | } 576 | 577 | #[test] 578 | fn evaluate_some_expressions() { 579 | let inputs = 580 | vec![("1", 1.0), ("1+1", 2.0), ("sin(90)", 1.0), ("x", 0.5)]; 
581 | let ctx = Builtins::default(); 582 | 583 | for (src, should_be) in inputs { 584 | let expr: Expression = src.parse().unwrap(); 585 | 586 | let got = evaluate(&expr, &get_parameter_by_name, &ctx).unwrap(); 587 | 588 | assert_eq!(got, should_be); 589 | } 590 | 591 | fn get_parameter_by_name(p: &Parameter) -> Option { 592 | let name = match p { 593 | Parameter::Named(name) => name.as_str(), 594 | _ => return None, 595 | }; 596 | 597 | match name { 598 | "x" => Some(0.5), 599 | "y" => Some(-10.0), 600 | "z" => Some(0.0), 601 | _ => None, 602 | } 603 | } 604 | } 605 | } 606 | -------------------------------------------------------------------------------- /src/parse.rs: -------------------------------------------------------------------------------- 1 | use crate::{BinaryOperation, Expression, Parameter}; 2 | use std::{ 3 | fmt::{self, Display, Formatter}, 4 | iter::Peekable, 5 | ops::Range, 6 | rc::Rc, 7 | }; 8 | 9 | /// Parse an [`Expression`] tree from some text. 10 | pub fn parse(s: &str) -> Result { 11 | Parser::new(s).parse() 12 | } 13 | 14 | /// A simple recursive descent parser (`LL(1)`) for converting a string into an 15 | /// expression tree. 
16 | /// 17 | /// The grammar: 18 | /// 19 | /// ```text 20 | /// expression := term "+" expression 21 | /// | term "-" expression 22 | /// | term 23 | /// 24 | /// term := factor "*" term 25 | /// | factor "/" term 26 | /// | factor 27 | /// 28 | /// factor := "-" term 29 | /// | variable_or_function_call 30 | /// | "(" expression ")" 31 | /// | NUMBER 32 | /// 33 | /// variable_or_function_call = IDENTIFIER "(" expression ")" 34 | /// | IDENTIFIER 35 | /// ``` 36 | #[derive(Debug, Clone)] 37 | pub(crate) struct Parser<'a> { 38 | tokens: Peekable>, 39 | } 40 | 41 | impl<'a> Parser<'a> { 42 | pub(crate) fn new(src: &'a str) -> Self { 43 | Parser { 44 | tokens: Tokens::new(src).peekable(), 45 | } 46 | } 47 | 48 | pub(crate) fn parse(mut self) -> Result { 49 | let expr = self.expression()?; 50 | 51 | match self.tokens.next() { 52 | None => Ok(expr), 53 | Some(Ok(token)) => { 54 | panic!("Not all tokens consumed! Found {:?}", token) 55 | }, 56 | Some(Err(e)) => Err(e), 57 | } 58 | } 59 | 60 | fn peek(&mut self) -> Option { 61 | self.tokens 62 | .peek() 63 | .and_then(|result| result.as_ref().ok()) 64 | .map(|tok| tok.kind) 65 | } 66 | 67 | fn advance(&mut self) -> Result, ParseError> { 68 | match self.tokens.next() { 69 | Some(result) => result, 70 | None => Err(ParseError::UnexpectedEndOfInput), 71 | } 72 | } 73 | 74 | fn expression(&mut self) -> Result { 75 | let left = self.term()?; 76 | 77 | self.then_right_part_of_binary_op( 78 | left, 79 | &[TokenKind::Plus, TokenKind::Minus], 80 | |p| p.expression(), 81 | ) 82 | } 83 | 84 | fn term(&mut self) -> Result { 85 | let left = self.factor()?; 86 | 87 | self.then_right_part_of_binary_op( 88 | left, 89 | &[TokenKind::Times, TokenKind::Divide], 90 | |p| p.term(), 91 | ) 92 | } 93 | 94 | fn then_right_part_of_binary_op( 95 | &mut self, 96 | left: Expression, 97 | expected: &[TokenKind], 98 | then: F, 99 | ) -> Result 100 | where 101 | F: FnOnce(&mut Parser<'_>) -> Result, 102 | { 103 | if let Some(kind) = self.peek() { 
104 | for candidate in expected { 105 | if *candidate == kind { 106 | // skip past the operator 107 | let _ = self.advance()?; 108 | // and parse the second bit 109 | let right = then(self)?; 110 | 111 | return Ok(Expression::Binary { 112 | left: Rc::new(left), 113 | right: Rc::new(right), 114 | op: candidate.as_binary_op(), 115 | }); 116 | } 117 | } 118 | } 119 | 120 | Ok(left) 121 | } 122 | 123 | fn factor(&mut self) -> Result { 124 | let expected = 125 | &[TokenKind::Number, TokenKind::Identifier, TokenKind::Minus]; 126 | 127 | match self.peek() { 128 | Some(TokenKind::Number) => { 129 | return self.number(); 130 | }, 131 | Some(TokenKind::Minus) => { 132 | let _ = self.advance()?; 133 | let operand = self.factor()?; 134 | return Ok(-operand); 135 | }, 136 | Some(TokenKind::Identifier) => { 137 | return self.variable_or_function_call() 138 | }, 139 | Some(TokenKind::OpenParen) => { 140 | let _ = self.advance()?; 141 | let expr = self.expression()?; 142 | let close_paren = self.advance()?; 143 | 144 | if close_paren.kind == TokenKind::CloseParen { 145 | return Ok(expr); 146 | } else { 147 | return Err(ParseError::UnexpectedToken { 148 | found: close_paren.kind, 149 | span: close_paren.span, 150 | expected: &[TokenKind::CloseParen], 151 | }); 152 | } 153 | }, 154 | _ => {}, 155 | } 156 | 157 | // we couldn't parse the factor, return a nice error 158 | match self.tokens.next() { 159 | Some(Ok(Token { span, kind, .. 
})) => { 160 | Err(ParseError::UnexpectedToken { 161 | found: kind, 162 | expected, 163 | span, 164 | }) 165 | }, 166 | Some(Err(e)) => Err(e), 167 | None => Err(ParseError::UnexpectedEndOfInput), 168 | } 169 | } 170 | 171 | fn variable_or_function_call(&mut self) -> Result { 172 | let ident = self.advance()?; 173 | debug_assert_eq!(ident.kind, TokenKind::Identifier); 174 | 175 | if self.peek() == Some(TokenKind::OpenParen) { 176 | self.function_call(ident) 177 | } else { 178 | Ok(Expression::Parameter(Parameter::named(ident.text))) 179 | } 180 | } 181 | 182 | fn function_call( 183 | &mut self, 184 | identifier: Token<'a>, 185 | ) -> Result { 186 | let open_paren = self.advance()?; 187 | debug_assert_eq!(open_paren.kind, TokenKind::OpenParen); 188 | 189 | let argument = self.expression()?; 190 | 191 | let Token { kind, span, .. } = self.advance()?; 192 | 193 | if kind == TokenKind::CloseParen { 194 | Ok(Expression::FunctionCall { 195 | name: identifier.text.into(), 196 | argument: Rc::new(argument), 197 | }) 198 | } else { 199 | Err(ParseError::UnexpectedToken { 200 | found: kind, 201 | span, 202 | expected: &[TokenKind::CloseParen], 203 | }) 204 | } 205 | } 206 | 207 | fn number(&mut self) -> Result { 208 | let token = self 209 | .tokens 210 | .next() 211 | .ok_or(ParseError::UnexpectedEndOfInput)??; 212 | 213 | debug_assert_eq!(token.kind, TokenKind::Number); 214 | let number = 215 | token.text.parse().expect("Guaranteed correct by the lexer"); 216 | 217 | Ok(Expression::Constant(number)) 218 | } 219 | } 220 | 221 | /// Possible errors that may occur while parsing. 
222 | #[derive(Debug, Clone, PartialEq)] 223 | pub enum ParseError { 224 | InvalidCharacter { 225 | character: char, 226 | index: usize, 227 | }, 228 | UnexpectedEndOfInput, 229 | UnexpectedToken { 230 | found: TokenKind, 231 | span: Range, 232 | expected: &'static [TokenKind], 233 | }, 234 | } 235 | 236 | impl Display for ParseError { 237 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 238 | match self { 239 | ParseError::InvalidCharacter { character, index } => write!( 240 | f, 241 | "Unexpected character, \"{}\" at index {}", 242 | character, index 243 | ), 244 | ParseError::UnexpectedEndOfInput => { 245 | write!(f, "Unexpected end of input") 246 | }, 247 | ParseError::UnexpectedToken { 248 | found, 249 | span, 250 | expected, 251 | } => { 252 | write!( 253 | f, 254 | "Found a {:?} at {:?} but was expecting one of [", 255 | found, span 256 | )?; 257 | for (i, kind) in expected.iter().enumerate() { 258 | if i > 0 { 259 | write!(f, ", ")?; 260 | } 261 | 262 | write!(f, "{:?}", kind)?; 263 | } 264 | write!(f, "]")?; 265 | 266 | Ok(()) 267 | }, 268 | } 269 | } 270 | } 271 | 272 | #[derive(Debug, Clone, PartialEq)] 273 | struct Tokens<'a> { 274 | src: &'a str, 275 | cursor: usize, 276 | } 277 | 278 | impl<'a> Tokens<'a> { 279 | fn new(src: &'a str) -> Self { Tokens { src, cursor: 0 } } 280 | 281 | fn rest(&self) -> &'a str { &self.src[self.cursor..] } 282 | 283 | fn peek(&self) -> Option { self.rest().chars().next() } 284 | 285 | fn advance(&mut self) -> Option { 286 | let c = self.peek()?; 287 | self.cursor += c.len_utf8(); 288 | Some(c) 289 | } 290 | 291 | fn chomp( 292 | &mut self, 293 | kind: TokenKind, 294 | ) -> Option, ParseError>> { 295 | let start = self.cursor; 296 | self.advance()?; 297 | let end = self.cursor; 298 | 299 | let tok = Token { 300 | text: &self.src[start..end], 301 | span: start..end, 302 | kind, 303 | }; 304 | 305 | Some(Ok(tok)) 306 | } 307 | 308 | fn take_while

( 309 | &mut self, 310 | mut predicate: P, 311 | ) -> Option<(&'a str, Range)> 312 | where 313 | P: FnMut(char) -> bool, 314 | { 315 | let start = self.cursor; 316 | 317 | while let Some(c) = self.peek() { 318 | if !predicate(c) { 319 | break; 320 | } 321 | 322 | self.advance(); 323 | } 324 | 325 | let end = self.cursor; 326 | 327 | if start != end { 328 | let text = &self.src[start..end]; 329 | Some((text, start..end)) 330 | } else { 331 | None 332 | } 333 | } 334 | 335 | fn chomp_number(&mut self) -> Token<'a> { 336 | let mut seen_decimal_point = false; 337 | 338 | let (text, span) = self 339 | .take_while(|c| match c { 340 | '.' if !seen_decimal_point => { 341 | seen_decimal_point = true; 342 | true 343 | }, 344 | '0'..='9' => true, 345 | _ => false, 346 | }) 347 | .expect("We know there is at least one digit in the input"); 348 | 349 | Token { 350 | text, 351 | span, 352 | kind: TokenKind::Number, 353 | } 354 | } 355 | 356 | fn chomp_identifier(&mut self) -> Token<'a> { 357 | let mut seen_first_character = false; 358 | 359 | let (text, span) = self 360 | .take_while(|c| { 361 | if seen_first_character { 362 | c.is_alphanumeric() || c == '_' 363 | } else { 364 | seen_first_character = true; 365 | c.is_alphabetic() || c == '_' 366 | } 367 | }) 368 | .expect("We know there should be at least 1 character"); 369 | 370 | Token { 371 | text, 372 | span, 373 | kind: TokenKind::Identifier, 374 | } 375 | } 376 | } 377 | 378 | impl<'a> Iterator for Tokens<'a> { 379 | type Item = Result, ParseError>; 380 | 381 | fn next(&mut self) -> Option { 382 | loop { 383 | let next_character = self.peek()?; 384 | 385 | if next_character.is_whitespace() { 386 | // Skip the whitespace 387 | self.advance(); 388 | continue; 389 | } 390 | 391 | return match next_character { 392 | '(' => self.chomp(TokenKind::OpenParen), 393 | ')' => self.chomp(TokenKind::CloseParen), 394 | '+' => self.chomp(TokenKind::Plus), 395 | '-' => self.chomp(TokenKind::Minus), 396 | '*' => 
self.chomp(TokenKind::Times), 397 | '/' => self.chomp(TokenKind::Divide), 398 | '_' | 'a'..='z' | 'A'..='Z' => { 399 | Some(Ok(self.chomp_identifier())) 400 | }, 401 | '0'..='9' => Some(Ok(self.chomp_number())), 402 | other => Some(Err(ParseError::InvalidCharacter { 403 | character: other, 404 | index: self.cursor, 405 | })), 406 | }; 407 | } 408 | } 409 | } 410 | 411 | #[derive(Debug, Clone, PartialEq)] 412 | struct Token<'a> { 413 | text: &'a str, 414 | span: Range, 415 | kind: TokenKind, 416 | } 417 | 418 | /// The kinds of token that can appear in an [`Expression`]'s text form. 419 | #[derive(Debug, Copy, Clone, PartialEq)] 420 | pub enum TokenKind { 421 | Identifier, 422 | Number, 423 | OpenParen, 424 | CloseParen, 425 | Plus, 426 | Minus, 427 | Times, 428 | Divide, 429 | } 430 | 431 | impl TokenKind { 432 | fn as_binary_op(self) -> BinaryOperation { 433 | match self { 434 | TokenKind::Plus => BinaryOperation::Plus, 435 | TokenKind::Minus => BinaryOperation::Minus, 436 | TokenKind::Times => BinaryOperation::Times, 437 | TokenKind::Divide => BinaryOperation::Divide, 438 | other => unreachable!("{:?} is not a binary op", other), 439 | } 440 | } 441 | } 442 | 443 | #[cfg(test)] 444 | mod tokenizer_tests { 445 | use super::*; 446 | 447 | macro_rules! 
tokenize_test { 448 | ($name:ident, $src:expr, $should_be:expr) => { 449 | #[test] 450 | fn $name() { 451 | let mut tokens = Tokens::new($src); 452 | 453 | let got = tokens.next().unwrap().unwrap(); 454 | 455 | let Range { start, end } = got.span; 456 | assert_eq!(start, 0); 457 | assert_eq!(end, $src.len()); 458 | assert_eq!(got.kind, $should_be); 459 | 460 | assert!( 461 | tokens.next().is_none(), 462 | "{:?} should be empty", 463 | tokens 464 | ); 465 | } 466 | }; 467 | } 468 | 469 | /* Each case must lex as exactly one token of the expected kind, spanning the whole input. */ tokenize_test!(open_paren, "(", TokenKind::OpenParen); 470 | tokenize_test!(close_paren, ")", TokenKind::CloseParen); 471 | tokenize_test!(plus, "+", TokenKind::Plus); 472 | tokenize_test!(minus, "-", TokenKind::Minus); 473 | tokenize_test!(times, "*", TokenKind::Times); 474 | tokenize_test!(divide, "/", TokenKind::Divide); 475 | tokenize_test!(single_digit_integer, "3", TokenKind::Number); 476 | tokenize_test!(multi_digit_integer, "31", TokenKind::Number); 477 | tokenize_test!(number_with_trailing_dot, "31.", TokenKind::Number); 478 | tokenize_test!(simple_decimal, "3.14", TokenKind::Number); 479 | tokenize_test!(simple_identifier, "x", TokenKind::Identifier); 480 | tokenize_test!(longer_identifier, "hello", TokenKind::Identifier); 481 | tokenize_test!( 482 | identifiers_can_have_underscores, 483 | "hello_world", 484 | TokenKind::Identifier 485 | ); 486 | tokenize_test!( 487 | identifiers_can_start_with_underscores, 488 | "_hello_world", 489 | TokenKind::Identifier 490 | ); 491 | tokenize_test!( 492 | identifiers_can_contain_numbers, 493 | "var5", 494 | TokenKind::Identifier 495 | ); 496 | } 497 | 498 | #[cfg(test)] 499 | mod parser_tests { 500 | use super::*; 501 | 502 | /* Round-trip test: parse `$src`, render the parsed expression with `to_string()`, and expect `$should_be` (which defaults to `$src` when omitted). */ macro_rules!
parser_test { 503 | ($name:ident, $src:expr) => { 504 | parser_test!($name, $src, $src); 505 | }; 506 | ($name:ident, $src:expr, $should_be:expr) => { 507 | #[test] 508 | fn $name() { 509 | let got = Parser::new($src).parse().unwrap(); 510 | 511 | let round_tripped = got.to_string(); 512 | assert_eq!(round_tripped, $should_be); 513 | } 514 | }; 515 | } 516 | 517 | parser_test!(simple_integer, "1"); 518 | parser_test!(one_plus_one, "1 + 1"); 519 | parser_test!(one_plus_one_plus_negative_one, "1 + -1"); 520 | parser_test!(one_plus_one_times_three, "1 + 1*3"); 521 | parser_test!(one_plus_one_all_times_three, "(1 + 1)*3"); 522 | parser_test!(negative_one, "-1"); 523 | parser_test!(negative_one_plus_one, "-1 + 1"); 524 | parser_test!(negative_one_plus_x, "-1 + x"); 525 | parser_test!(number_in_parens, "(1)", "1"); 526 | parser_test!(bimdas, "1*2 + 3*4/(5 - 2)*1 - 3"); 527 | parser_test!(function_call, "sin(1)", "sin(1)"); 528 | parser_test!(function_call_with_expression, "sin(1/0)"); 529 | parser_test!( 530 | function_calls_function_calls_function_with_variable, 531 | "foo(bar(baz(pi)))" 532 | ); 533 | } 534 | -------------------------------------------------------------------------------- /src/solve.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | ops::{self, Context, EvaluationError}, 3 | Equation, Expression, Parameter, SystemOfEquations, 4 | }; 5 | use nalgebra::{DMatrix as Matrix, DVector as Vector}; 6 | use std::{ 7 | collections::HashMap, 8 | error::Error, 9 | fmt::{self, Debug, Display, Formatter}, 10 | }; 11 | 12 | /** Solve `equations` for `unknowns` with Newton's method, returning the value found for each unknown. The jacobian is differentiated from `equations` while each step's residual comes from evaluating `system`, so the two presumably describe the same equations (NOTE(review): confirm at call sites). */ pub(crate) fn solve( 13 | equations: &[Equation], 14 | unknowns: &[Parameter], 15 | system: &SystemOfEquations, 16 | ctx: &C, 17 | ) -> Result 18 | where 19 | C: Context, 20 | { 21 | let jacobian = Jacobian::for_equations(equations, &unknowns, ctx)?; 22 | let got = solve_with_newtons_method(&jacobian, &system, ctx)?; 23 | 24 | Ok(Solution { 25 | known_values:
jacobian.collate_unknowns(got.as_slice()), 26 | }) 27 | } 28 | 29 | /** The outcome of a successful solve. */ #[derive(Debug, Clone, PartialEq)] 30 | pub struct Solution { 31 | /** The value found for each unknown parameter. */ pub known_values: HashMap, 32 | } 33 | 34 | /** The ways solving a system of equations can fail. */ #[derive(Debug, Clone, PartialEq)] 35 | pub enum SolveError { 36 | /** An expression couldn't be evaluated; the underlying error is exposed via `Error::source()`. */ Eval(EvaluationError), 37 | /** Newton's method hit its iteration limit before successive guesses stopped changing. */ DidntConverge, 38 | /** A Newton step couldn't be computed (e.g. the jacobian is singular). */ NoSolution, 39 | } 40 | 41 | impl From for SolveError { 42 | fn from(e: EvaluationError) -> Self { SolveError::Eval(e) } 43 | } 44 | 45 | /* Human-readable messages; the wrapped evaluation error is reported through `Error::source()` rather than repeated in this message. */ impl Display for SolveError { 46 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 47 | match self { 48 | SolveError::Eval(_) => write!(f, "Evaluation failed"), 49 | SolveError::DidntConverge => { 50 | write!(f, "The solution didn't converge") 51 | }, 52 | SolveError::NoSolution => write!(f, "No solution found"), 53 | } 54 | } 55 | } 56 | 57 | impl Error for SolveError { 58 | fn source(&self) -> Option<&(dyn Error + 'static)> { 59 | match self { 60 | SolveError::Eval(inner) => Some(inner), 61 | _ => None, 62 | } 63 | } 64 | } 65 | 66 | /// Solve a set of non-linear equations iteratively using Newton's method. 67 | /// 68 | /// The iterative equation for Newton's method when applied to a set of 69 | /// equations, `F`, is: 70 | /// 71 | /// ```text 72 | /// x_next = x_current - jacobian(F).inverse() * F(x_current) 73 | /// ``` 74 | /// 75 | /// This is the multi-variable equivalent of Newton-Raphson, where the jacobian 76 | /// is the slope of our equations, and we pre-multiply by the inverse because 77 | /// that's the matrix equivalent of division. 78 | /// 79 | /// Calculating the inverse of a matrix is expensive though, so we rearrange 80 | /// it to look like this: 81 | /// 82 | /// ```text 83 | /// jacobian(F) * (x_next - x_current) = -F(x_current) 84 | /// ``` 85 | /// 86 | /// ... Which is in the form `A.δx = b`. 87 | /// 88 | /// We can then solve for `δx` using gaussian elimination, then get the refined 89 | /// solution by solving `δx = x_next - x_current`.
90 | /// 91 | /// See also: 92 | /// 93 | /// - https://en.wikipedia.org/wiki/Newton%27s_method#Nonlinear_systems_of_equations 94 | /// - https://www.youtube.com/watch?v=zPDp_ewoyhM 95 | fn solve_with_newtons_method( 96 | jacobian: &Jacobian, 97 | system: &SystemOfEquations, 98 | ctx: &C, 99 | ) -> Result, SolveError> 100 | where 101 | C: Context, 102 | { 103 | const MAX_ITERATIONS: usize = 50; 104 | 105 | let mut solution = jacobian.initial_values(); 106 | 107 | for _ in 0..MAX_ITERATIONS { 108 | let x_next = { 109 | let evaluated_jacobian = 110 | jacobian.evaluate(solution.as_slice(), ctx)?; 111 | 112 | let lookup = jacobian.lookup_value_by_name(solution.as_slice()); 113 | let f_of_x = system.evaluate(&lookup, ctx)?; 114 | step_newtons_method(evaluated_jacobian, &solution, f_of_x)? 115 | }; 116 | 117 | if approx::relative_eq!(x_next, solution) { 118 | return Ok(x_next); 119 | } 120 | solution = x_next; 121 | } 122 | 123 | Err(SolveError::DidntConverge) 124 | } 125 | 126 | fn step_newtons_method( 127 | jacobian: Matrix, 128 | x: &Vector, 129 | f_of_x: Vector, 130 | ) -> Result, SolveError> { 131 | // We're trying to solve: 132 | // x_next = x_current - jacobian(F).inverse() * F(x_current) 133 | // 134 | // Which gets rearranged as: 135 | // jacobian(F) * (x_next - x_current) = -F(x_current) 136 | // 137 | // Note that we use LU decomposition to solve equations of the form `Ax = b` 138 | 139 | let negative_f_of_x = -f_of_x; 140 | let delta_x = jacobian 141 | .lu() 142 | .solve(&negative_f_of_x) 143 | .ok_or(SolveError::NoSolution)?; 144 | 145 | Ok(delta_x + x) 146 | } 147 | 148 | /// A matrix of [`Expression`]s representing the partial derivatives for each 149 | /// parameter in each equation. 
150 | #[derive(Debug, Clone, PartialEq)] 151 | pub(crate) struct Jacobian<'a> { 152 | cells: Box<[Expression]>, 153 | equations: &'a [Equation], 154 | unknowns: &'a [Parameter], 155 | } 156 | 157 | impl<'a> Jacobian<'a> { 158 | fn for_equations( 159 | equations: &'a [Equation], 160 | unknowns: &'a [Parameter], 161 | ctx: &C, 162 | ) -> Result 163 | where 164 | C: Context, 165 | { 166 | let mut cells = Vec::new(); 167 | 168 | for equation in equations { 169 | for unknown in unknowns { 170 | let value = if equation.body.depends_on(unknown) { 171 | let derivative = 172 | ops::partial_derivative(&equation.body, unknown, ctx)?; 173 | ops::fold_constants(&derivative, ctx) 174 | } else { 175 | Expression::Constant(0.0) 176 | }; 177 | cells.push(value); 178 | } 179 | } 180 | 181 | Ok(Jacobian { 182 | cells: cells.into_boxed_slice(), 183 | equations, 184 | unknowns, 185 | }) 186 | } 187 | 188 | fn rows(&self) -> usize { self.equations.len() } 189 | 190 | fn columns(&self) -> usize { self.unknowns.len() } 191 | 192 | fn evaluate( 193 | &self, 194 | parameter_values: &[f64], 195 | ctx: &C, 196 | ) -> Result, EvaluationError> 197 | where 198 | C: Context, 199 | { 200 | assert_eq!(parameter_values.len(), self.unknowns.len()); 201 | 202 | let mut values = Vec::with_capacity(self.cells.len()); 203 | let lookup = self.lookup_value_by_name(parameter_values); 204 | 205 | for row in self.iter_rows() { 206 | for expression in row { 207 | values.push(ops::evaluate(&expression, &lookup, ctx)?); 208 | } 209 | } 210 | 211 | Ok(Matrix::from_vec(self.rows(), self.columns(), values)) 212 | } 213 | 214 | fn lookup_value_by_name<'p>( 215 | &'p self, 216 | parameter_values: &'p [f64], 217 | ) -> impl Fn(&Parameter) -> Option + 'p { 218 | move |parameter| { 219 | self.unknowns 220 | .iter() 221 | .position(|p| p == parameter) 222 | .map(|ix| parameter_values[ix]) 223 | } 224 | } 225 | 226 | pub(crate) fn collate_unknowns( 227 | &self, 228 | parameter_values: &[f64], 229 | ) -> HashMap { 230 | 
self.unknowns 231 | .iter() 232 | .cloned() 233 | .zip(parameter_values.iter().copied()) 234 | .collect() 235 | } 236 | 237 | fn initial_values(&self) -> Vector { 238 | Vector::zeros(self.unknowns.len()) 239 | } 240 | 241 | fn iter_rows(&self) -> impl Iterator + '_ { 242 | self.cells.chunks_exact(self.columns()) 243 | } 244 | } 245 | 246 | #[cfg(test)] 247 | mod tests { 248 | use super::*; 249 | use crate::ops::Builtins; 250 | 251 | #[test] 252 | fn single_equality() { 253 | let equation: Equation = "x = 5".parse().unwrap(); 254 | let x = Parameter::named("x"); 255 | let builtins = Builtins::default(); 256 | 257 | let got = SystemOfEquations::new() 258 | .with(equation) 259 | .solve(&builtins) 260 | .unwrap(); 261 | 262 | assert_eq!(got.known_values.len(), 1); 263 | assert_eq!(got.known_values[&x], 5.0); 264 | } 265 | 266 | #[test] 267 | fn calculate_jacobian_of_known_system_of_equations() { 268 | // See https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant#Example_5 269 | let system = SystemOfEquations::from_equations(&[ 270 | "5 * b", 271 | "4*a*a - 2*sin(b*c)", 272 | "b*c", 273 | ]) 274 | .unwrap(); 275 | let ctx = Builtins::default(); 276 | 277 | let unknowns = system.unknowns(); 278 | let got = Jacobian::for_equations(&system.equations, &unknowns, &ctx) 279 | .unwrap(); 280 | 281 | assert_eq!( 282 | got.columns(), 283 | system.num_unknowns(), 284 | "There are 3 unknowns" 285 | ); 286 | assert_eq!(got.rows(), system.equations.len(), "There are 3 equations"); 287 | // Note: I needed to rearrange the Wikipedia equations a bit because 288 | // our solver multiplied things differently (i.e. 
"c*2" instead of 289 | // "2*c") 290 | let should_be = [ 291 | ["0", "5", "0"].as_ref(), 292 | ["8*a", "-cos(b*c)*c*2", "-cos(b*c)*b*2"].as_ref(), 293 | ["0", "c", "b"].as_ref(), 294 | ]; 295 | assert_jacobian_eq(&got, should_be.as_ref()); 296 | } 297 | 298 | fn assert_jacobian_eq(jacobian: &Jacobian, should_be: &[&[&str]]) { 299 | assert_eq!(jacobian.rows(), should_be.len()); 300 | 301 | for (row, row_should_be) in jacobian.iter_rows().zip(should_be) { 302 | assert_eq!(row.len(), row_should_be.len()); 303 | 304 | for (value, column_should_be) in row.iter().zip(*row_should_be) { 305 | let should_be: Expression = column_should_be.parse().unwrap(); 306 | 307 | // Usually I wouldn't compare strings, but it's possible to get 308 | // different (but equivalent!) trees when calculating the 309 | // jacobian vs parsing from a string 310 | assert_eq!(value.to_string(), should_be.to_string()); 311 | } 312 | } 313 | } 314 | 315 | #[test] 316 | fn solve_simple_equations() { 317 | let system = 318 | SystemOfEquations::from_equations(&["x-1", "y-2", "z-3"]).unwrap(); 319 | let ctx = Builtins::default(); 320 | let unknowns = system.unknowns(); 321 | let jacobian = 322 | Jacobian::for_equations(&system.equations, &unknowns, &ctx) 323 | .unwrap(); 324 | 325 | let got = solve_with_newtons_method(&jacobian, &system, &ctx).unwrap(); 326 | 327 | let named_parameters = jacobian.collate_unknowns(got.as_slice()); 328 | let x = Parameter::named("x"); 329 | let y = Parameter::named("y"); 330 | let z = Parameter::named("z"); 331 | assert_eq!(named_parameters[&x], 1.0); 332 | assert_eq!(named_parameters[&y], 2.0); 333 | assert_eq!(named_parameters[&z], 3.0); 334 | } 335 | 336 | #[test] 337 | fn work_through_youtube_example() { 338 | // From https://www.youtube.com/watch?v=zPDp_ewoyhM 339 | let system = SystemOfEquations::from_equations(&[ 340 | "a + 2*b - 2", 341 | "a*a + 4*b*b - 4", 342 | ]) 343 | .unwrap(); 344 | let ctx = Builtins::default(); 345 | 346 | // first we need to calculate the 
jacobian 347 | let unknowns = system.unknowns(); 348 | let jacobian = 349 | Jacobian::for_equations(&system.equations, &unknowns, &ctx) 350 | .unwrap(); 351 | assert_jacobian_eq( 352 | &jacobian, 353 | &[&["1 + -0", "2"], &["2*a + -0", "8*b"]], 354 | ); 355 | 356 | // make an initial guess 357 | let x_0 = Vector::from_vec(vec![1.0, 2.0]); 358 | 359 | // evaluate the components we need 360 | let jacobian_of_x_0 = jacobian.evaluate(x_0.as_slice(), &ctx).unwrap(); 361 | let lookup_parameter_value = 362 | jacobian.lookup_value_by_name(x_0.as_slice()); 363 | let f_of_x_0 = system.evaluate(lookup_parameter_value, &ctx).unwrap(); 364 | 365 | // and double-check them 366 | assert_eq!( 367 | jacobian_of_x_0, 368 | Matrix::from_vec(2, 2, vec![1.0, 2.0, 2.0, 16.0]) 369 | ); 370 | assert_eq!(f_of_x_0.as_slice(), &[3.0, 13.0]); 371 | 372 | // one iteration of newton's method 373 | let x_1 = step_newtons_method(jacobian_of_x_0, &x_0, f_of_x_0).unwrap(); 374 | let should_be = Vector::from_vec(vec![-10.0 / 12.0, 17.0 / 12.0]); 375 | approx::relative_eq!(x_1, should_be); 376 | } 377 | 378 | macro_rules! solve_test { 379 | ($(#[$attr:meta])* $name:ident, $equations:expr => { $( $var_name:ident : $value:expr ),* $(,)? 
} 380 | ) => { 381 | $( 382 | #[$attr] 383 | )* 384 | #[test] 385 | fn $name() { 386 | let equations = SystemOfEquations::from_equations(& $equations).unwrap(); 387 | let ctx = Builtins::default(); 388 | 389 | let got = equations.clone().solve(&ctx).unwrap(); 390 | 391 | let mut should_be = HashMap::::new(); 392 | 393 | $( 394 | let p = Parameter::named(stringify!($var_name)); 395 | should_be.insert(p, $value); 396 | )* 397 | 398 | assert_eq!( 399 | got.known_values.keys().collect::>(), 400 | should_be.keys().collect::>(), 401 | "The keys should match", 402 | ); 403 | 404 | $( 405 | let p = Parameter::named(stringify!($var_name)); 406 | approx::assert_relative_eq!(got.known_values[&p], should_be[&p]); 407 | )* 408 | 409 | // double-check with the equations to make sure it all adds up 410 | for equation in equations { 411 | let lookup = |p: &Parameter| got.known_values.get(p).copied(); 412 | let evaluated = crate::ops::evaluate(&equation.body, &lookup, &ctx).unwrap(); 413 | 414 | approx::assert_relative_eq!(evaluated, 0.0); 415 | } 416 | } 417 | }; 418 | } 419 | 420 | solve_test!(x_is_5, ["x = 5"] => { x: 5.0 }); 421 | solve_test!(unrelated_equations, ["x = 5", "y = -2"] => { x: 5.0, y: -2.0 }); 422 | solve_test!(difference_of_numbers, ["x + y = 10", "x - y = 0"] => { x: 5.0, y: 5.0 }); 423 | solve_test!(simple_trig, ["sin(x) = 0"] => { x: 0.0 }); 424 | solve_test!(sin_x_take_cos_x, ["sin(x) - cos(x + 90) = 0"] => { x: 0.0 }); 425 | solve_test!(#[ignore] simple_trig_2, ["cos(x) = 1"] => { x: 0.0 }); 426 | solve_test!(#[ignore] difference_of_numbers_reversed, ["x - y = 0", "x + y = 10"] => { x: 5.0, y: 5.0 }); 427 | solve_test!(#[ignore] trig_equation_with_no_solution, ["sin(x) * cos(x + 90) = 0.5"] => {x: 0.0}); 428 | } 429 | --------------------------------------------------------------------------------