, ParseError>> {
295 | let start = self.cursor;
296 | self.advance()?;
297 | let end = self.cursor;
298 |
299 | let tok = Token {
300 | text: &self.src[start..end],
301 | span: start..end,
302 | kind,
303 | };
304 |
305 | Some(Ok(tok))
306 | }
307 |
308 | fn take_while(
309 | &mut self,
310 | mut predicate: P,
311 | ) -> Option<(&'a str, Range)>
312 | where
313 | P: FnMut(char) -> bool,
314 | {
315 | let start = self.cursor;
316 |
317 | while let Some(c) = self.peek() {
318 | if !predicate(c) {
319 | break;
320 | }
321 |
322 | self.advance();
323 | }
324 |
325 | let end = self.cursor;
326 |
327 | if start != end {
328 | let text = &self.src[start..end];
329 | Some((text, start..end))
330 | } else {
331 | None
332 | }
333 | }
334 |
335 | fn chomp_number(&mut self) -> Token<'a> {
336 | let mut seen_decimal_point = false;
337 |
338 | let (text, span) = self
339 | .take_while(|c| match c {
340 | '.' if !seen_decimal_point => {
341 | seen_decimal_point = true;
342 | true
343 | },
344 | '0'..='9' => true,
345 | _ => false,
346 | })
347 | .expect("We know there is at least one digit in the input");
348 |
349 | Token {
350 | text,
351 | span,
352 | kind: TokenKind::Number,
353 | }
354 | }
355 |
356 | fn chomp_identifier(&mut self) -> Token<'a> {
357 | let mut seen_first_character = false;
358 |
359 | let (text, span) = self
360 | .take_while(|c| {
361 | if seen_first_character {
362 | c.is_alphanumeric() || c == '_'
363 | } else {
364 | seen_first_character = true;
365 | c.is_alphabetic() || c == '_'
366 | }
367 | })
368 | .expect("We know there should be at least 1 character");
369 |
370 | Token {
371 | text,
372 | span,
373 | kind: TokenKind::Identifier,
374 | }
375 | }
376 | }
377 |
378 | impl<'a> Iterator for Tokens<'a> {
379 | type Item = Result, ParseError>;
380 |
381 | fn next(&mut self) -> Option {
382 | loop {
383 | let next_character = self.peek()?;
384 |
385 | if next_character.is_whitespace() {
386 | // Skip the whitespace
387 | self.advance();
388 | continue;
389 | }
390 |
391 | return match next_character {
392 | '(' => self.chomp(TokenKind::OpenParen),
393 | ')' => self.chomp(TokenKind::CloseParen),
394 | '+' => self.chomp(TokenKind::Plus),
395 | '-' => self.chomp(TokenKind::Minus),
396 | '*' => self.chomp(TokenKind::Times),
397 | '/' => self.chomp(TokenKind::Divide),
398 | '_' | 'a'..='z' | 'A'..='Z' => {
399 | Some(Ok(self.chomp_identifier()))
400 | },
401 | '0'..='9' => Some(Ok(self.chomp_number())),
402 | other => Some(Err(ParseError::InvalidCharacter {
403 | character: other,
404 | index: self.cursor,
405 | })),
406 | };
407 | }
408 | }
409 | }
410 |
411 | #[derive(Debug, Clone, PartialEq)]
412 | struct Token<'a> {
413 | text: &'a str,
414 | span: Range,
415 | kind: TokenKind,
416 | }
417 |
/// The kinds of token that can appear in an [`Expression`]'s text form.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum TokenKind {
    /// A name such as `x`, `sin`, or `_hello_world`.
    Identifier,
    /// A numeric literal such as `3`, `31.` or `3.14`.
    Number,
    /// `(`
    OpenParen,
    /// `)`
    CloseParen,
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Times,
    /// `/`
    Divide,
}
430 |
431 | impl TokenKind {
432 | fn as_binary_op(self) -> BinaryOperation {
433 | match self {
434 | TokenKind::Plus => BinaryOperation::Plus,
435 | TokenKind::Minus => BinaryOperation::Minus,
436 | TokenKind::Times => BinaryOperation::Times,
437 | TokenKind::Divide => BinaryOperation::Divide,
438 | other => unreachable!("{:?} is not a binary op", other),
439 | }
440 | }
441 | }
442 |
#[cfg(test)]
mod tokenizer_tests {
    use super::*;

    /// Tokenize `src` and check it yields exactly one token of kind
    /// `expected` whose span covers the entire input.
    fn assert_single_token(src: &str, expected: TokenKind) {
        let mut tokens = Tokens::new(src);

        let token = tokens.next().unwrap().unwrap();

        assert_eq!(token.span, 0..src.len());
        assert_eq!(token.kind, expected);

        assert!(
            tokens.next().is_none(),
            "{:?} should be empty",
            tokens
        );
    }

    macro_rules! tokenize_test {
        ($name:ident, $src:expr, $should_be:expr) => {
            #[test]
            fn $name() {
                assert_single_token($src, $should_be);
            }
        };
    }

    tokenize_test!(open_paren, "(", TokenKind::OpenParen);
    tokenize_test!(close_paren, ")", TokenKind::CloseParen);
    tokenize_test!(plus, "+", TokenKind::Plus);
    tokenize_test!(minus, "-", TokenKind::Minus);
    tokenize_test!(times, "*", TokenKind::Times);
    tokenize_test!(divide, "/", TokenKind::Divide);
    tokenize_test!(single_digit_integer, "3", TokenKind::Number);
    tokenize_test!(multi_digit_integer, "31", TokenKind::Number);
    tokenize_test!(number_with_trailing_dot, "31.", TokenKind::Number);
    tokenize_test!(simple_decimal, "3.14", TokenKind::Number);
    tokenize_test!(simple_identifier, "x", TokenKind::Identifier);
    tokenize_test!(longer_identifier, "hello", TokenKind::Identifier);
    tokenize_test!(
        identifiers_can_have_underscores,
        "hello_world",
        TokenKind::Identifier
    );
    tokenize_test!(
        identifiers_can_start_with_underscores,
        "_hello_world",
        TokenKind::Identifier
    );
    tokenize_test!(
        identifiers_can_contain_numbers,
        "var5",
        TokenKind::Identifier
    );
}
497 |
#[cfg(test)]
mod parser_tests {
    use super::*;

    /// Parse `src`, print the resulting expression back out, and check the
    /// printed form equals `should_be`.
    fn assert_round_trip(src: &str, should_be: &str) {
        let parsed = Parser::new(src).parse().unwrap();

        assert_eq!(parsed.to_string(), should_be);
    }

    macro_rules! parser_test {
        // Most inputs are expected to print back out exactly as written.
        ($name:ident, $src:expr) => {
            parser_test!($name, $src, $src);
        };
        ($name:ident, $src:expr, $should_be:expr) => {
            #[test]
            fn $name() {
                assert_round_trip($src, $should_be);
            }
        };
    }

    parser_test!(simple_integer, "1");
    parser_test!(one_plus_one, "1 + 1");
    parser_test!(one_plus_one_plus_negative_one, "1 + -1");
    parser_test!(one_plus_one_times_three, "1 + 1*3");
    parser_test!(one_plus_one_all_times_three, "(1 + 1)*3");
    parser_test!(negative_one, "-1");
    parser_test!(negative_one_plus_one, "-1 + 1");
    parser_test!(negative_one_plus_x, "-1 + x");
    parser_test!(number_in_parens, "(1)", "1");
    parser_test!(bimdas, "1*2 + 3*4/(5 - 2)*1 - 3");
    parser_test!(function_call, "sin(1)", "sin(1)");
    parser_test!(function_call_with_expression, "sin(1/0)");
    parser_test!(
        function_calls_function_calls_function_with_variable,
        "foo(bar(baz(pi)))"
    );
}
534 |
--------------------------------------------------------------------------------
/src/solve.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 | ops::{self, Context, EvaluationError},
3 | Equation, Expression, Parameter, SystemOfEquations,
4 | };
5 | use nalgebra::{DMatrix as Matrix, DVector as Vector};
6 | use std::{
7 | collections::HashMap,
8 | error::Error,
9 | fmt::{self, Debug, Display, Formatter},
10 | };
11 |
12 | pub(crate) fn solve(
13 | equations: &[Equation],
14 | unknowns: &[Parameter],
15 | system: &SystemOfEquations,
16 | ctx: &C,
17 | ) -> Result
18 | where
19 | C: Context,
20 | {
21 | let jacobian = Jacobian::for_equations(equations, &unknowns, ctx)?;
22 | let got = solve_with_newtons_method(&jacobian, &system, ctx)?;
23 |
24 | Ok(Solution {
25 | known_values: jacobian.collate_unknowns(got.as_slice()),
26 | })
27 | }
28 |
29 | #[derive(Debug, Clone, PartialEq)]
30 | pub struct Solution {
31 | pub known_values: HashMap,
32 | }
33 |
34 | #[derive(Debug, Clone, PartialEq)]
35 | pub enum SolveError {
36 | Eval(EvaluationError),
37 | DidntConverge,
38 | NoSolution,
39 | }
40 |
41 | impl From for SolveError {
42 | fn from(e: EvaluationError) -> Self { SolveError::Eval(e) }
43 | }
44 |
45 | impl Display for SolveError {
46 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
47 | match self {
48 | SolveError::Eval(_) => write!(f, "Evaluation failed"),
49 | SolveError::DidntConverge => {
50 | write!(f, "The solution didn't converge")
51 | },
52 | SolveError::NoSolution => write!(f, "No solution found"),
53 | }
54 | }
55 | }
56 |
57 | impl Error for SolveError {
58 | fn source(&self) -> Option<&(dyn Error + 'static)> {
59 | match self {
60 | SolveError::Eval(inner) => Some(inner),
61 | _ => None,
62 | }
63 | }
64 | }
65 |
66 | /// Solve a set of non-linear equations iteratively using Newton's method.
67 | ///
68 | /// The iterative equation for Newton's method when applied to a set of
69 | /// equations, `F`, is:
70 | ///
71 | /// ```text
72 | /// x_next = x_current - jacobian(F).inverse() * F(x_current)
73 | /// ```
74 | ///
75 | /// This is the multi-variable equivalent of Newton-Raphson, where the jacobian
76 | /// is the slope of our equations, and we pre-multiply by the inverse because
77 | /// that's the matrix equivalent of division.
78 | ///
79 | /// Calculating the inverse of a matrix is expensive though, so we rearrange
80 | /// it to look like this:
81 | ///
82 | /// ```text
83 | /// jacobian(F) * (x_next - x_current) = -F(x_current)
84 | /// ```
85 | ///
86 | /// ... Which is in the form `A.δx = b`.
87 | ///
88 | /// We can then solve for `δx` using gaussian elimination, then get the refined
89 | /// solution by solving `δx = x_next - x_current`.
90 | ///
91 | /// See also:
92 | ///
93 | /// - https://en.wikipedia.org/wiki/Newton%27s_method#Nonlinear_systems_of_equations
94 | /// - https://www.youtube.com/watch?v=zPDp_ewoyhM
95 | fn solve_with_newtons_method(
96 | jacobian: &Jacobian,
97 | system: &SystemOfEquations,
98 | ctx: &C,
99 | ) -> Result, SolveError>
100 | where
101 | C: Context,
102 | {
103 | const MAX_ITERATIONS: usize = 50;
104 |
105 | let mut solution = jacobian.initial_values();
106 |
107 | for _ in 0..MAX_ITERATIONS {
108 | let x_next = {
109 | let evaluated_jacobian =
110 | jacobian.evaluate(solution.as_slice(), ctx)?;
111 |
112 | let lookup = jacobian.lookup_value_by_name(solution.as_slice());
113 | let f_of_x = system.evaluate(&lookup, ctx)?;
114 | step_newtons_method(evaluated_jacobian, &solution, f_of_x)?
115 | };
116 |
117 | if approx::relative_eq!(x_next, solution) {
118 | return Ok(x_next);
119 | }
120 | solution = x_next;
121 | }
122 |
123 | Err(SolveError::DidntConverge)
124 | }
125 |
126 | fn step_newtons_method(
127 | jacobian: Matrix,
128 | x: &Vector,
129 | f_of_x: Vector,
130 | ) -> Result, SolveError> {
131 | // We're trying to solve:
132 | // x_next = x_current - jacobian(F).inverse() * F(x_current)
133 | //
134 | // Which gets rearranged as:
135 | // jacobian(F) * (x_next - x_current) = -F(x_current)
136 | //
137 | // Note that we use LU decomposition to solve equations of the form `Ax = b`
138 |
139 | let negative_f_of_x = -f_of_x;
140 | let delta_x = jacobian
141 | .lu()
142 | .solve(&negative_f_of_x)
143 | .ok_or(SolveError::NoSolution)?;
144 |
145 | Ok(delta_x + x)
146 | }
147 |
148 | /// A matrix of [`Expression`]s representing the partial derivatives for each
149 | /// parameter in each equation.
150 | #[derive(Debug, Clone, PartialEq)]
151 | pub(crate) struct Jacobian<'a> {
152 | cells: Box<[Expression]>,
153 | equations: &'a [Equation],
154 | unknowns: &'a [Parameter],
155 | }
156 |
157 | impl<'a> Jacobian<'a> {
158 | fn for_equations(
159 | equations: &'a [Equation],
160 | unknowns: &'a [Parameter],
161 | ctx: &C,
162 | ) -> Result
163 | where
164 | C: Context,
165 | {
166 | let mut cells = Vec::new();
167 |
168 | for equation in equations {
169 | for unknown in unknowns {
170 | let value = if equation.body.depends_on(unknown) {
171 | let derivative =
172 | ops::partial_derivative(&equation.body, unknown, ctx)?;
173 | ops::fold_constants(&derivative, ctx)
174 | } else {
175 | Expression::Constant(0.0)
176 | };
177 | cells.push(value);
178 | }
179 | }
180 |
181 | Ok(Jacobian {
182 | cells: cells.into_boxed_slice(),
183 | equations,
184 | unknowns,
185 | })
186 | }
187 |
188 | fn rows(&self) -> usize { self.equations.len() }
189 |
190 | fn columns(&self) -> usize { self.unknowns.len() }
191 |
192 | fn evaluate(
193 | &self,
194 | parameter_values: &[f64],
195 | ctx: &C,
196 | ) -> Result, EvaluationError>
197 | where
198 | C: Context,
199 | {
200 | assert_eq!(parameter_values.len(), self.unknowns.len());
201 |
202 | let mut values = Vec::with_capacity(self.cells.len());
203 | let lookup = self.lookup_value_by_name(parameter_values);
204 |
205 | for row in self.iter_rows() {
206 | for expression in row {
207 | values.push(ops::evaluate(&expression, &lookup, ctx)?);
208 | }
209 | }
210 |
211 | Ok(Matrix::from_vec(self.rows(), self.columns(), values))
212 | }
213 |
214 | fn lookup_value_by_name<'p>(
215 | &'p self,
216 | parameter_values: &'p [f64],
217 | ) -> impl Fn(&Parameter) -> Option + 'p {
218 | move |parameter| {
219 | self.unknowns
220 | .iter()
221 | .position(|p| p == parameter)
222 | .map(|ix| parameter_values[ix])
223 | }
224 | }
225 |
226 | pub(crate) fn collate_unknowns(
227 | &self,
228 | parameter_values: &[f64],
229 | ) -> HashMap {
230 | self.unknowns
231 | .iter()
232 | .cloned()
233 | .zip(parameter_values.iter().copied())
234 | .collect()
235 | }
236 |
237 | fn initial_values(&self) -> Vector {
238 | Vector::zeros(self.unknowns.len())
239 | }
240 |
241 | fn iter_rows(&self) -> impl Iterator- + '_ {
242 | self.cells.chunks_exact(self.columns())
243 | }
244 | }
245 |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::Builtins;

    #[test]
    fn single_equality() {
        let equation: Equation = "x = 5".parse().unwrap();
        let x = Parameter::named("x");
        let builtins = Builtins::default();

        let got = SystemOfEquations::new()
            .with(equation)
            .solve(&builtins)
            .unwrap();

        assert_eq!(got.known_values.len(), 1);
        assert_eq!(got.known_values[&x], 5.0);
    }

    #[test]
    fn calculate_jacobian_of_known_system_of_equations() {
        // See https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant#Example_5
        let system = SystemOfEquations::from_equations(&[
            "5 * b",
            "4*a*a - 2*sin(b*c)",
            "b*c",
        ])
        .unwrap();
        let ctx = Builtins::default();

        let unknowns = system.unknowns();
        let got = Jacobian::for_equations(&system.equations, &unknowns, &ctx)
            .unwrap();

        assert_eq!(
            got.columns(),
            system.num_unknowns(),
            "There are 3 unknowns"
        );
        assert_eq!(got.rows(), system.equations.len(), "There are 3 equations");
        // Note: I needed to rearrange the Wikipedia equations a bit because
        // our solver multiplied things differently (i.e. "c*2" instead of
        // "2*c")
        let should_be = [
            ["0", "5", "0"].as_ref(),
            ["8*a", "-cos(b*c)*c*2", "-cos(b*c)*b*2"].as_ref(),
            ["0", "c", "b"].as_ref(),
        ];
        assert_jacobian_eq(&got, should_be.as_ref());
    }

    /// Compare each cell of `jacobian` against the expression parsed from
    /// the corresponding entry of `should_be`.
    fn assert_jacobian_eq(jacobian: &Jacobian, should_be: &[&[&str]]) {
        assert_eq!(jacobian.rows(), should_be.len());

        for (row, row_should_be) in jacobian.iter_rows().zip(should_be) {
            assert_eq!(row.len(), row_should_be.len());

            for (value, column_should_be) in row.iter().zip(*row_should_be) {
                let should_be: Expression = column_should_be.parse().unwrap();

                // Usually I wouldn't compare strings, but it's possible to get
                // different (but equivalent!) trees when calculating the
                // jacobian vs parsing from a string
                assert_eq!(value.to_string(), should_be.to_string());
            }
        }
    }

    #[test]
    fn solve_simple_equations() {
        let system =
            SystemOfEquations::from_equations(&["x-1", "y-2", "z-3"]).unwrap();
        let ctx = Builtins::default();
        let unknowns = system.unknowns();
        let jacobian =
            Jacobian::for_equations(&system.equations, &unknowns, &ctx)
                .unwrap();

        let got = solve_with_newtons_method(&jacobian, &system, &ctx).unwrap();

        let named_parameters = jacobian.collate_unknowns(got.as_slice());
        let x = Parameter::named("x");
        let y = Parameter::named("y");
        let z = Parameter::named("z");
        assert_eq!(named_parameters[&x], 1.0);
        assert_eq!(named_parameters[&y], 2.0);
        assert_eq!(named_parameters[&z], 3.0);
    }

    #[test]
    fn work_through_youtube_example() {
        // From https://www.youtube.com/watch?v=zPDp_ewoyhM
        let system = SystemOfEquations::from_equations(&[
            "a + 2*b - 2",
            "a*a + 4*b*b - 4",
        ])
        .unwrap();
        let ctx = Builtins::default();

        // first we need to calculate the jacobian
        let unknowns = system.unknowns();
        let jacobian =
            Jacobian::for_equations(&system.equations, &unknowns, &ctx)
                .unwrap();
        assert_jacobian_eq(
            &jacobian,
            &[&["1 + -0", "2"], &["2*a + -0", "8*b"]],
        );

        // make an initial guess
        let x_0 = Vector::from_vec(vec![1.0, 2.0]);

        // evaluate the components we need
        let jacobian_of_x_0 = jacobian.evaluate(x_0.as_slice(), &ctx).unwrap();
        let lookup_parameter_value =
            jacobian.lookup_value_by_name(x_0.as_slice());
        let f_of_x_0 = system.evaluate(lookup_parameter_value, &ctx).unwrap();

        // and double-check them
        assert_eq!(
            jacobian_of_x_0,
            Matrix::from_vec(2, 2, vec![1.0, 2.0, 2.0, 16.0])
        );
        assert_eq!(f_of_x_0.as_slice(), &[3.0, 13.0]);

        // one iteration of newton's method
        let x_1 = step_newtons_method(jacobian_of_x_0, &x_0, f_of_x_0).unwrap();
        let should_be = Vector::from_vec(vec![-10.0 / 12.0, 17.0 / 12.0]);
        // `relative_eq!` merely returns a bool, so the check must go through
        // the assert form to actually fail. The LU solve's rounding puts the
        // first component a few ulps away from `-10.0 / 12.0`, which is
        // outside the default tolerance, so allow a small absolute error.
        approx::assert_relative_eq!(x_1, should_be, epsilon = 1e-12);
    }

    macro_rules! solve_test {
        ($(#[$attr:meta])* $name:ident, $equations:expr => { $( $var_name:ident : $value:expr ),* $(,)? }
        ) => {
            $(
                #[$attr]
            )*
            #[test]
            fn $name() {
                let equations = SystemOfEquations::from_equations(& $equations).unwrap();
                let ctx = Builtins::default();

                let got = equations.clone().solve(&ctx).unwrap();

                let mut should_be = HashMap::<Parameter, f64>::new();

                $(
                    let p = Parameter::named(stringify!($var_name));
                    should_be.insert(p, $value);
                )*

                // Compare as sets: HashMap iteration order isn't stable.
                assert_eq!(
                    got.known_values.keys().collect::<std::collections::HashSet<_>>(),
                    should_be.keys().collect::<std::collections::HashSet<_>>(),
                    "The keys should match",
                );

                $(
                    let p = Parameter::named(stringify!($var_name));
                    approx::assert_relative_eq!(got.known_values[&p], should_be[&p]);
                )*

                // double-check with the equations to make sure it all adds up
                for equation in equations {
                    let lookup = |p: &Parameter| got.known_values.get(p).copied();
                    let evaluated = crate::ops::evaluate(&equation.body, &lookup, &ctx).unwrap();

                    approx::assert_relative_eq!(evaluated, 0.0);
                }
            }
        };
    }

    solve_test!(x_is_5, ["x = 5"] => { x: 5.0 });
    solve_test!(unrelated_equations, ["x = 5", "y = -2"] => { x: 5.0, y: -2.0 });
    solve_test!(difference_of_numbers, ["x + y = 10", "x - y = 0"] => { x: 5.0, y: 5.0 });
    solve_test!(simple_trig, ["sin(x) = 0"] => { x: 0.0 });
    solve_test!(sin_x_take_cos_x, ["sin(x) - cos(x + 90) = 0"] => { x: 0.0 });
    solve_test!(#[ignore] simple_trig_2, ["cos(x) = 1"] => { x: 0.0 });
    solve_test!(#[ignore] difference_of_numbers_reversed, ["x - y = 0", "x + y = 10"] => { x: 5.0, y: 5.0 });
    solve_test!(#[ignore] trig_equation_with_no_solution, ["sin(x) * cos(x + 90) = 0.5"] => {x: 0.0});
}
429 |
--------------------------------------------------------------------------------