├── src ├── tensors.rs ├── utils.rs ├── lib.rs ├── printer.rs ├── combinatorics.rs ├── domains.rs ├── domains │ └── rational.rs └── numerical_integration.rs ├── .gitignore ├── examples ├── dual.rs └── integration.rs ├── .github └── workflows │ └── coverage.yml ├── License.md ├── Cargo.toml ├── Readme.md └── Contributing.md /src/tensors.rs: -------------------------------------------------------------------------------- 1 | //! Methods for tensor manipulation and linear algebra. 2 | pub mod matrix; 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | Cargo.lock 2 | target/ 3 | flake.lock 4 | flake.nix 5 | .envrc 6 | .direnv/ 7 | .devenv/ 8 | .devenv* 9 | devenv.local.nix 10 | .zed/ 11 | .direnv 12 | devenv.lock 13 | devenv.nix 14 | devenv.yaml 15 | 16 | .pre-commit-config.yaml 17 | -------------------------------------------------------------------------------- /examples/dual.rs: -------------------------------------------------------------------------------- 1 | use numerica::{ 2 | create_hyperdual_from_components, 3 | domains::{float::FloatLike, rational::Rational}, 4 | }; 5 | 6 | create_hyperdual_from_components!( 7 | Dual, 8 | [ 9 | [0, 0, 0], 10 | [1, 0, 0], 11 | [0, 1, 0], 12 | [0, 0, 1], 13 | [1, 1, 0], 14 | [1, 0, 1], 15 | [0, 1, 1], 16 | [1, 1, 1], 17 | [2, 0, 0] 18 | ] 19 | ); 20 | 21 | fn main() { 22 | let x = Dual::<Rational>::new_variable(0, (1, 1).into()); 23 | let y = Dual::new_variable(1, (2, 1).into()); 24 | let z = Dual::new_variable(2, (3, 1).into()); 25 | 26 | let t3 = x * y * z; 27 | 28 | println!("{}", t3.inv()); 29 | } 30 | -------------------------------------------------------------------------------- /.github/workflows/coverage.yml: -------------------------------------------------------------------------------- 1 | name: Coverage 2 | 3 | on: [pull_request, push] 4 | 5 | jobs: 6 | coverage: 7 | runs-on: ubuntu-latest 8 |
env: 9 | CARGO_TERM_COLOR: always 10 | steps: 11 | - uses: actions/checkout@v5 12 | - name: Install Rust 13 | run: rustup update stable 14 | - name: Install cargo-llvm-cov 15 | uses: taiki-e/install-action@cargo-llvm-cov 16 | - name: Install nextest 17 | uses: taiki-e/install-action@nextest 18 | - name: Generate code coverage 19 | run: cargo llvm-cov --workspace --codecov --output-path codecov.json 20 | - name: Upload coverage to Codecov 21 | uses: codecov/codecov-action@v5 22 | with: 23 | token: ${{ secrets.CODECOV_TOKEN }} 24 | slug: symbolica-dev/numerica 25 | files: codecov.json 26 | fail_ci_if_error: true 27 | -------------------------------------------------------------------------------- /License.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025-present Ruijl Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 6 | a copy of this software and associated documentation files (the 7 | "Software"), to deal in the Software without restriction, including 8 | without limitation the rights to use, copy, modify, merge, publish, 9 | distribute, sublicense, and/or sell copies of the Software, and to 10 | permit persons to whom the Software is furnished to do so, subject to 11 | the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be 14 | included in all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 19 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 20 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 22 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | //! Utility traits and structures. 2 | 3 | use std::ops::{Deref, DerefMut}; 4 | 5 | use dyn_clone::DynClone; 6 | 7 | /// A wrapper around a mutable reference that tracks if the value 8 | /// has been mutably accessed. 9 | #[derive(Debug)] 10 | pub struct Settable<'a, T> { 11 | value: &'a mut T, 12 | is_set: bool, 13 | } 14 | 15 | impl<T> Deref for Settable<'_, T> { 16 | type Target = T; 17 | 18 | fn deref(&self) -> &Self::Target { 19 | &self.value 20 | } 21 | } 22 | 23 | impl<T> DerefMut for Settable<'_, T> { 24 | fn deref_mut(&mut self) -> &mut Self::Target { 25 | self.is_set = true; 26 | self.value 27 | } 28 | } 29 | 30 | impl<'a, T> From<&'a mut T> for Settable<'a, T> { 31 | fn from(value: &'a mut T) -> Self { 32 | Self { 33 | value, 34 | is_set: false, 35 | } 36 | } 37 | } 38 | 39 | impl<T> Settable<'_, T> { 40 | /// Check if the value has been mutably accessed (through `DerefMut`). 41 | pub fn is_set(&self) -> bool { 42 | self.is_set 43 | } 44 | } 45 | 46 | /// A cloneable function that checks for abort.
47 | pub trait AbortCheck: Fn() -> bool + DynClone + Send + Sync {} 48 | dyn_clone::clone_trait_object!(AbortCheck); 49 | impl<T: Clone + Send + Sync + Fn() -> bool> AbortCheck for T {} 50 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "numerica" 3 | authors = ["Ben Ruijl "] 4 | keywords = ["algebra", "mathematics", "integer", "float", "rational"] 5 | version = "1.1.0" 6 | edition = "2024" 7 | license = "MIT" 8 | description = "Open-source math library for exact and floating point computations" 9 | readme = "Readme.md" 10 | repository = "https://github.com/symbolica-dev/numerica" 11 | rust-version = "1.88" 12 | 13 | [features] 14 | # import/export objects that do not depend on the state using serde 15 | serde = ["dep:serde", "rug/serde"] 16 | # allow importing/exporting of atoms with a state map using bincode 17 | bincode = ["dep:bincode", "dep:bincode-trait-derive"] 18 | python = ["pyo3", "numpy"] 19 | python_stubgen = ["python", "pyo3-stub-gen"] 20 | 21 | [dependencies] 22 | pyo3 = { version = "0.27", features = ["num-complex"], optional = true } 23 | numpy = { version = "0.27", optional = true } 24 | rand = "0.9" 25 | rand_xoshiro = "0.7" 26 | rug = "1.27.0" 27 | wide = "0.8" 28 | ahash = "0.8.7" 29 | dyn-clone = "1.0" 30 | colored = "3.0" 31 | smallvec = "1.13" 32 | serde = { version = "1.0", features = ["derive"], optional = true } 33 | bincode = { version = "2.0", optional = true } 34 | bincode-trait-derive = { version = "0.1.0", optional = true } 35 | pyo3-stub-gen = { version = "0.17", default-features = false, features = [ 36 | "numpy", 37 | ], optional = true } 38 | -------------------------------------------------------------------------------- /examples/integration.rs: -------------------------------------------------------------------------------- 1 | use std::f64::consts::PI; 2 | 3 | use numerica::numerical_integration::{ 4 |
ContinuousGrid, DiscreteGrid, Grid, MonteCarloRng, Sample, 5 | }; 6 | 7 | fn main() { 8 | // Integrate sin(x*pi) + x^2 using multi-channeling: 9 | // sin(x*pi) and x^2 will have their own Vegas grid 10 | let fs = [|x: f64| (x * PI).sin(), |x: f64| x * x]; 11 | 12 | let mut grid = DiscreteGrid::new( 13 | vec![ 14 | Some(Grid::Continuous(ContinuousGrid::new( 15 | 1, 10, 1000, None, false, 16 | ))), 17 | Some(Grid::Continuous(ContinuousGrid::new( 18 | 1, 10, 1000, None, false, 19 | ))), 20 | ], 21 | 0.01, 22 | false, 23 | ); 24 | 25 | let mut rng = MonteCarloRng::new(0, 0); 26 | 27 | let mut sample = Sample::new(); 28 | for iteration in 1..20 { 29 | // sample 10_000 times per iteration 30 | for _ in 0..10_000 { 31 | grid.sample(&mut rng, &mut sample); 32 | 33 | if let Sample::Discrete(_weight, i, cont_sample) = &sample { 34 | if let Sample::Continuous(_cont_weight, xs) = cont_sample.as_ref().unwrap().as_ref() 35 | { 36 | grid.add_training_sample(&sample, fs[*i](xs[0])).unwrap(); 37 | } 38 | } 39 | } 40 | 41 | grid.update(1.5, 1.5); 42 | 43 | println!( 44 | "Integral at iteration {:2}: {:.6} ± {:.6}", 45 | iteration, grid.accumulator.avg, grid.accumulator.err 46 | ); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Numerica is an open-source mathematics library for Rust that provides high-performance number types, such as error-tracking floats and finite field elements. 2 | //! 3 | //! It provides 4 | //! - Abstractions over rings, Euclidean domains, fields and floats 5 | //! - High-performance Integer with automatic up-and-downgrading to arbitrary precision types 6 | //! - Rational numbers with reconstruction algorithms 7 | //! - Fast finite field arithmetic 8 | //! - Error-tracking floating point types 9 | //! - Generic dual numbers for automatic differentiation 10 | //! - Matrix operations and linear system solving 11 | //!
- Numerical integration using Vegas algorithm with discrete layer support 12 | //! 13 | //! For operations on symbols, check out the sister project [Symbolica](https://symbolica.io). 14 | //! 15 | //! # Example 16 | //! Solve a linear system over the rationals: 17 | //! 18 | //! ```rust 19 | //! # use numerica::tensors::matrix::Matrix; 20 | //! # use numerica::domains::rational::Q; 21 | //! let a = Matrix::from_linear( 22 | //! vec![ 23 | //! 1.into(), 2.into(), 3.into(), 24 | //! 4.into(), 5.into(), 16.into(), 25 | //! 7.into(), 8.into(), 9.into(), 26 | //! ], 27 | //! 3, 3, Q, 28 | //! ) 29 | //! .unwrap(); 30 | //! 31 | //! let b = Matrix::from_linear(vec![1.into(), 2.into(), 3.into()], 3, 1, Q).unwrap(); 32 | //! 33 | //! let r = a.solve(&b).unwrap(); 34 | //! assert_eq!(r.into_vec(), [(-1, 3), (2, 3), (0, 1)]); 35 | //! ``` 36 | //! Solution: $(-1/3, 2/3, 0)$. 37 | pub mod combinatorics; 38 | pub mod domains; 39 | pub mod numerical_integration; 40 | pub mod printer; 41 | pub mod tensors; 42 | pub mod utils; 43 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Numerica 2 | 3 |

4 | Symbolica website 5 | Zulip Chat 6 | Numerica repository 7 | Codecov 8 |

9 | 10 | Numerica is an open-source mathematics library for Rust that provides high-performance number types, such as error-tracking floats and finite field elements. 11 | 12 | It provides 13 | - Abstractions over rings, Euclidean domains, fields and floats 14 | - High-performance Integer with automatic up-and-downgrading to arbitrary precision types 15 | - Rational numbers with reconstruction algorithms 16 | - Fast finite field arithmetic 17 | - Error-tracking floating point types 18 | - Generic dual numbers for automatic (higher-order) differentiation 19 | - Matrix operations and linear system solving 20 | - Numerical integration using Vegas algorithm with discrete layer support 21 | 22 | For operations on symbols, check out the parent project [Symbolica](https://symbolica.io). 23 | 24 | 25 | # Examples 26 | 27 | 28 | ### Solving a linear system 29 | 30 | Solve a linear system over the rationals: 31 | 32 | ```rust 33 | let a = Matrix::from_linear( 34 | vec![ 35 | 1.into(), 2.into(), 3.into(), 36 | 4.into(), 5.into(), 16.into(), 37 | 7.into(), 8.into(), 9.into(), 38 | ], 39 | 3, 3, Q, 40 | ) 41 | .unwrap(); 42 | 43 | let b = Matrix::new_vec(vec![1.into(), 2.into(), 3.into()], Q); 44 | 45 | let r = a.solve(&b).unwrap(); 46 | assert_eq!(r.into_vec(), [(-1, 3), (2, 3), (0, 1)]); 47 | ``` 48 | Solution: $(-1/3, 2/3, 0)$. 49 | 50 | Invert a matrix over the finite field $\mathbb{Z}_7$: 51 | 52 | ```rust 53 | let z_7 = Zp::new(7); 54 | 55 | let a = Matrix::from_linear( 56 | vec![ 57 | z_7.to_element(5), z_7.to_element(8), 58 | z_7.to_element(2), z_7.to_element(1), 59 | ], 60 | 2, 2, z_7, 61 | ) 62 | .unwrap(); 63 | 64 | let r = a.inv(); 65 | assert_eq!(r.into_vec(), [z_7.to_element(4), z_7.to_element(3), z_7.to_element(2), z_7.to_element(0)]); 66 | ``` 67 | 68 | ### Error-tracking floating points 69 | 70 | Wrap `f64` and `Float` in an `ErrorPropagatingFloat` to propagate errors through 71 | your computation.
For example, a number with 60 accurate digits only has 10 remaining after the following operations: 72 | 73 | ```rust 74 | let a = ErrorPropagatingFloat::new(Float::with_val(200, 1e-50), 60.); 75 | let r = (a.exp() - a.one()) / a; // large cancellation 76 | assert_eq!(format!("{r}"), "1.000000000"); 77 | assert_eq!(r.get_precision(), Some(10.205999132796238)); 78 | ``` 79 | 80 | ### Automatic differentiation with dual numbers 81 | 82 | Create a dual number that fits your needs (supports multiple variables and higher-order differentiation). 83 | Here, we create a simple dual number in three variables: 84 | 85 | ```rust 86 | create_hyperdual_single_derivative!(Dual, 3); 87 | 88 | fn main() { 89 | let x = Dual::<Rational>::new_variable(0, (1, 1).into()); 90 | let y = Dual::new_variable(1, (2, 1).into()); 91 | let z = Dual::new_variable(2, (3, 1).into()); 92 | 93 | let t3 = x * y * z; 94 | 95 | println!("{}", t3.inv()); 96 | } 97 | ``` 98 | It yields `(1/6)+(-1/6)*ε0+(-1/12)*ε1+(-1/18)*ε2`. 99 | 100 | The multiplication table is computed and unrolled at compile time for maximal performance.
101 | 102 | ### Solve integer relations 103 | 104 | Solve 105 | 106 | $$ 107 | -32.0177 c_1 + 3.1416 c_2 + 2.7183 c_3 = 0 108 | $$ 109 | 110 | over the integers using PSLQ: 111 | 112 | ```rust 113 | let result = Integer::solve_integer_relation( 114 | &[F64::from(-32.0177), F64::from(3.1416), F64::from(2.7183)], 115 | F64::from(1e-4), 116 | 1, 117 | Some(Integer::from(100000u64)), 118 | None, 119 | ) 120 | .unwrap(); 121 | 122 | assert_eq!(result, &[1, 5, 6]); 123 | ``` 124 | 125 | Or via LLL basis reduction: 126 | 127 | ```rust 128 | let v1 = Vector::new(vec![1.into(), 0.into(), 0.into(), 31416.into()], Z); 129 | let v2 = Vector::new(vec![0.into(), 1.into(), 0.into(), 27183.into()], Z); 130 | let v3 = Vector::new(vec![0.into(), 0.into(), 1.into(), (-320177).into()], Z); 131 | 132 | let basis = Vector::basis_reduction(&[v1, v2, v3], (3, 4).into()); 133 | 134 | assert_eq!(basis[0].into_vec(), [5, 6, 1, 1]); 135 | ``` 136 | 137 | ### Numerical integration 138 | 139 | Integrate $\sin(\pi x) + y$ using the Vegas algorithm with random numbers suitable for Monte Carlo integration: 140 | 141 | ```rust 142 | let f = |x: &[f64]| (x[0] * std::f64::consts::PI).sin() + x[1]; 143 | let mut grid = Grid::Continuous(ContinuousGrid::new(2, 128, 100, None, false)); 144 | let mut rng = MonteCarloRng::new(0, 0); 145 | let mut sample = Sample::new(); 146 | for iteration in 1..20 { 147 | for _ in 0..10_000 { 148 | grid.sample(&mut rng, &mut sample); 149 | 150 | if let Sample::Continuous(_cont_weight, xs) = &sample { 151 | grid.add_training_sample(&sample, f(xs)).unwrap(); 152 | } 153 | } 154 | 155 | grid.update(1.5, 1.5); 156 | 157 | println!( 158 | "Integral at iteration {}: {}", 159 | iteration, 160 | grid.get_statistics().format_uncertainty() 161 | ); 162 | } 163 | ``` 164 | 165 | 166 | 167 | 168 | ## Development 169 | 170 | Follow the development and discussions on [Zulip](https://reform.zulipchat.com)!
-------------------------------------------------------------------------------- /Contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing to Numerica 2 | 3 | Numerica is open to contributions from the community, which can be submitted via a Pull Request (PR) on GitHub. Before you submit your PR, consider the following guidelines: 4 | 5 | - Be sure that the PR describes the problem you are fixing, or documents the design for the feature you are building. Discussing the design upfront helps to ensure that we are ready to accept your work. 6 | - Please sign our Contributor License Agreement (CLA). You will be prompted to sign one when you submit a PR. We cannot accept code without a signed CLA. Make sure you author all contributed Git commits with the email address associated with your CLA signature. 7 | 8 | 9 | ## Numerica Individual Contributor License Agreement 10 | 11 | In order to clarify the intellectual property license granted with Contributions from any person or entity, Ruijl Research (CHE-330.486.212) must have a Contributor License Agreement ("CLA") on file that has been signed by each Contributor, indicating agreement to the license terms below. This license is for your protection as a Contributor as well as the protection of Ruijl Research; it does not change your rights to use your own Contributions for any other purpose. 12 | 13 | You accept and agree to the following terms and conditions for Your present and future Contributions submitted to Ruijl Research. Except for the license granted herein to Ruijl Research and recipients of software distributed by Ruijl Research, You reserve all right, title, and interest in and to Your Contributions. 14 | 15 | 1. Definitions. 16 | 17 | "You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner that is making this Agreement with Ruijl Research.
For legal entities, the entity making a Contribution and all other entities that control, are controlled by, or are under common control with that entity are considered to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 18 | 19 | "Contribution" shall mean any original work of authorship, including any modifications or additions to an existing work, that is intentionally submitted by You to Ruijl Research for inclusion in, or documentation of, any of the products owned or managed by Ruijl Research (the "Work"). For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to Ruijl Research or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, Ruijl Research for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution." 20 | 21 | 2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to Ruijl Research and to recipients of software distributed by Ruijl Research a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense, and distribute Your Contributions and such derivative works. 22 | 23 | 3. Grant of Patent License. 
Subject to the terms and conditions of this Agreement, You hereby grant to Ruijl Research and to recipients of software distributed by Ruijl Research a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which such Contribution(s) was submitted. If any entity institutes patent litigation against You or any other entity (including a cross-claim or counterclaim in a lawsuit) alleging that your Contribution, or the Work to which you have contributed, constitutes direct or contributory patent infringement, then any patent licenses granted to that entity under this Agreement for that Contribution or Work shall terminate as of the date such litigation is filed. 24 | 25 | 4. You represent that you are legally entitled to grant the above license. If your employer(s) has rights to intellectual property that you create that includes your Contributions, you represent that you have received permission to make Contributions on behalf of that employer, that your employer has waived such rights for your Contributions to Ruijl Research, or that your employer has executed a separate Corporate CLA with Ruijl Research. 26 | 27 | 5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of others). You represent that Your Contribution submissions include complete details of any third-party license or other restriction (including, but not limited to, related patents and trademarks) of which you are personally aware and which are associated with any part of Your Contributions. 28 | 29 | 6. 
You are not expected to provide support for Your Contributions, except to the extent You desire to provide support. You may provide support for free, for a fee, or not at all. Unless required by applicable law or agreed to in writing, You provide Your Contributions on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON- INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. 30 | 31 | 7. Should You wish to submit work that is not Your original creation, You may submit it to Ruijl Research separately from any Contribution, identifying the complete details of its source and of any license or other restriction (including, but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and conspicuously marking the work as "Submitted on behalf of a third-party: [named here]". 32 | 33 | 8. You agree to notify Ruijl Research of any facts or circumstances of which you become aware that would make these representations inaccurate in any respect. 34 | 35 | -------------------------------------------------------------------------------- /src/printer.rs: -------------------------------------------------------------------------------- 1 | //! Methods for printing rings. 2 | 3 | /// The overall print mode. 
4 | #[derive(Debug, Copy, Clone, PartialEq, Eq)] 5 | #[non_exhaustive] 6 | #[derive(Default)] 7 | pub enum PrintMode { 8 | #[default] 9 | Symbolica, 10 | Latex, 11 | Mathematica, 12 | Sympy, 13 | } 14 | 15 | impl PrintMode { 16 | pub fn is_symbolica(&self) -> bool { 17 | *self == PrintMode::Symbolica 18 | } 19 | 20 | pub fn is_latex(&self) -> bool { 21 | *self == PrintMode::Latex 22 | } 23 | 24 | pub fn is_mathematica(&self) -> bool { 25 | *self == PrintMode::Mathematica 26 | } 27 | 28 | pub fn is_sympy(&self) -> bool { 29 | *self == PrintMode::Sympy 30 | } 31 | } 32 | 33 | /// Various options for printing expressions. 34 | #[derive(Debug, Copy, Clone)] 35 | pub struct PrintOptions { 36 | pub mode: PrintMode, 37 | pub terms_on_new_line: bool, 38 | pub color_top_level_sum: bool, 39 | pub color_builtin_symbols: bool, 40 | pub print_ring: bool, 41 | pub symmetric_representation_for_finite_field: bool, 42 | pub explicit_rational_polynomial: bool, 43 | pub number_thousands_separator: Option<char>, 44 | pub multiplication_operator: char, 45 | pub double_star_for_exponentiation: bool, 46 | pub square_brackets_for_function: bool, 47 | pub num_exp_as_superscript: bool, 48 | pub precision: Option<usize>, 49 | pub pretty_matrix: bool, 50 | pub hide_namespace: Option<&'static str>, 51 | pub hide_all_namespaces: bool, 52 | /// Print attributes and tags 53 | pub include_attributes: bool, 54 | pub color_namespace: bool, 55 | pub max_terms: Option<usize>, 56 | /// Provides a handle to set the behavior of the custom print function. 57 | /// Symbolica does not use this option for its own printing.
58 | pub custom_print_mode: Option<(&'static str, usize)>, 59 | } 60 | 61 | impl PrintOptions { 62 | pub const fn new() -> Self { 63 | Self { 64 | terms_on_new_line: false, 65 | color_top_level_sum: true, 66 | color_builtin_symbols: true, 67 | print_ring: true, 68 | symmetric_representation_for_finite_field: false, 69 | explicit_rational_polynomial: false, 70 | number_thousands_separator: None, 71 | multiplication_operator: '*', 72 | double_star_for_exponentiation: false, 73 | square_brackets_for_function: false, 74 | num_exp_as_superscript: false, 75 | mode: PrintMode::Symbolica, 76 | precision: None, 77 | pretty_matrix: false, 78 | hide_namespace: None, 79 | hide_all_namespaces: true, 80 | include_attributes: false, 81 | color_namespace: true, 82 | max_terms: None, 83 | custom_print_mode: None, 84 | } 85 | } 86 | 87 | /// Print the output in a Mathematica-readable format. 88 | pub const fn mathematica() -> PrintOptions { 89 | Self { 90 | terms_on_new_line: false, 91 | color_top_level_sum: false, 92 | color_builtin_symbols: false, 93 | print_ring: true, 94 | symmetric_representation_for_finite_field: false, 95 | explicit_rational_polynomial: false, 96 | number_thousands_separator: None, 97 | multiplication_operator: ' ', 98 | double_star_for_exponentiation: false, 99 | square_brackets_for_function: true, 100 | num_exp_as_superscript: false, 101 | mode: PrintMode::Mathematica, 102 | precision: None, 103 | pretty_matrix: false, 104 | hide_namespace: Some("symbolica"), 105 | hide_all_namespaces: false, 106 | include_attributes: false, 107 | color_namespace: false, 108 | max_terms: None, 109 | custom_print_mode: None, 110 | } 111 | } 112 | 113 | /// Print the output in a LaTeX input format.
114 | pub const fn latex() -> PrintOptions { 115 | Self { 116 | terms_on_new_line: false, 117 | color_top_level_sum: false, 118 | color_builtin_symbols: false, 119 | print_ring: true, 120 | symmetric_representation_for_finite_field: false, 121 | explicit_rational_polynomial: false, 122 | number_thousands_separator: None, 123 | multiplication_operator: ' ', 124 | double_star_for_exponentiation: false, 125 | square_brackets_for_function: false, 126 | num_exp_as_superscript: false, 127 | mode: PrintMode::Latex, 128 | precision: None, 129 | pretty_matrix: false, 130 | hide_namespace: None, 131 | hide_all_namespaces: true, 132 | include_attributes: false, 133 | color_namespace: false, 134 | max_terms: None, 135 | custom_print_mode: None, 136 | } 137 | } 138 | 139 | /// Print the output suitable for a file. 140 | pub const fn file() -> PrintOptions { 141 | Self { 142 | terms_on_new_line: false, 143 | color_top_level_sum: false, 144 | color_builtin_symbols: false, 145 | print_ring: false, 146 | symmetric_representation_for_finite_field: false, 147 | explicit_rational_polynomial: false, 148 | number_thousands_separator: None, 149 | multiplication_operator: '*', 150 | double_star_for_exponentiation: false, 151 | square_brackets_for_function: false, 152 | num_exp_as_superscript: false, 153 | mode: PrintMode::Symbolica, 154 | precision: None, 155 | pretty_matrix: false, 156 | hide_namespace: None, 157 | hide_all_namespaces: false, 158 | include_attributes: false, 159 | color_namespace: false, 160 | max_terms: None, 161 | custom_print_mode: None, 162 | } 163 | } 164 | 165 | /// Print the output suitable for a file without namespaces. 166 | pub const fn file_no_namespace() -> PrintOptions { 167 | Self { 168 | hide_all_namespaces: true, 169 | ..Self::file() 170 | } 171 | } 172 | 173 | /// Print the output suitable for a file with namespaces 174 | /// and attributes and tags.
175 | pub const fn full() -> PrintOptions { 176 | Self { 177 | include_attributes: true, 178 | ..Self::file() 179 | } 180 | } 181 | 182 | /// Print the output with namespaces suppressed. 183 | pub const fn short() -> PrintOptions { 184 | Self { 185 | hide_all_namespaces: true, 186 | ..Self::new() 187 | } 188 | } 189 | 190 | /// Print the output in a SymPy input format (sets `mode` to `PrintMode::Sympy`). 191 | pub const fn sympy() -> PrintOptions { 192 | Self { 193 | double_star_for_exponentiation: true, 194 | mode: PrintMode::Sympy, 195 | ..Self::file() 196 | } 197 | } 198 | 199 | pub fn from_fmt(f: &std::fmt::Formatter) -> PrintOptions { 200 | PrintOptions { 201 | precision: f.precision(), 202 | hide_all_namespaces: !f.alternate(), 203 | terms_on_new_line: f.align() == Some(std::fmt::Alignment::Right), 204 | ..Default::default() 205 | } 206 | } 207 | 208 | pub fn update_with_fmt(mut self, f: &std::fmt::Formatter) -> Self { 209 | self.precision = f.precision(); 210 | 211 | if f.alternate() { 212 | self.hide_all_namespaces = false; 213 | } 214 | 215 | if let Some(a) = f.align() { 216 | self.terms_on_new_line = a == std::fmt::Alignment::Right; 217 | } 218 | self 219 | } 220 | 221 | pub const fn hide_namespace(mut self, namespace: &'static str) -> Self { 222 | self.hide_namespace = Some(namespace); 223 | self 224 | } 225 | } 226 | 227 | impl Default for PrintOptions { 228 | fn default() -> Self { 229 | Self::new() 230 | } 231 | } 232 | 233 | /// The current state useful for printing. These 234 | /// settings will control, for example, if parentheses are needed 235 | /// (e.g., a sum in a product), 236 | /// and if 1 should be suppressed (e.g. in a product).
237 | #[derive(Debug, Copy, Clone)] 238 | pub struct PrintState { 239 | pub in_sum: bool, 240 | pub in_product: bool, 241 | pub suppress_one: bool, 242 | pub in_exp: bool, 243 | pub in_exp_base: bool, 244 | pub top_level_add_child: bool, 245 | pub superscript: bool, 246 | pub level: u16, 247 | } 248 | 249 | impl Default for PrintState { 250 | fn default() -> Self { 251 | Self::new() 252 | } 253 | } 254 | 255 | impl PrintState { 256 | pub const fn new() -> PrintState { 257 | Self { 258 | in_sum: false, 259 | in_product: false, 260 | in_exp: false, 261 | in_exp_base: false, 262 | suppress_one: false, 263 | top_level_add_child: true, 264 | superscript: false, 265 | level: 0, 266 | } 267 | } 268 | 269 | pub fn from_fmt(f: &std::fmt::Formatter) -> PrintState { 270 | PrintState { 271 | in_sum: f.sign_plus(), 272 | ..Default::default() 273 | } 274 | } 275 | 276 | pub fn update_with_fmt(mut self, f: &std::fmt::Formatter) -> Self { 277 | self.in_sum = f.sign_plus(); 278 | self 279 | } 280 | 281 | pub fn step(self, in_sum: bool, in_product: bool, in_exp: bool, in_exp_base: bool) -> Self { 282 | Self { 283 | in_sum, 284 | in_product, 285 | in_exp, 286 | in_exp_base, 287 | level: self.level + 1, 288 | ..self 289 | } 290 | } 291 | } 292 | -------------------------------------------------------------------------------- /src/combinatorics.rs: -------------------------------------------------------------------------------- 1 | //! Provides combinatorial utilities for generating combinations, permutations, and partitions of sets. 2 | //! 3 | //! # Examples 4 | //! 5 | //! Combinations without replacement: 6 | //! 7 | //! ```rust 8 | //! # use numerica::combinatorics::CombinationIterator; 9 | //! 10 | //! let mut c = CombinationIterator::new(4, 3); 11 | //! let mut combinations = vec![]; 12 | //! while let Some(a) = c.next() { 13 | //! combinations.push(a.to_vec()); 14 | //! } 15 | //! 16 | //! let ans = vec![[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]]; 17 | //! 18 | //!
assert_eq!(combinations, ans); 19 | //! ``` 20 | //! 21 | //! Partitions: 22 | //! 23 | //! ```rust 24 | //! # use numerica::combinatorics::partitions; 25 | //! 26 | //! let p = partitions( 27 | //! &[1, 1, 1, 2, 2], 28 | //! &[('f', 2), ('g', 2), ('f', 1)], 29 | //! false, 30 | //! false, 31 | //! ); 32 | //! 33 | //! let res = vec![ 34 | //! (3.into(), vec![('f', vec![1]), ('f', vec![1, 1]), ('g', vec![2, 2])]), 35 | //! (12.into(), vec![('f', vec![1]), ('f', vec![1, 2]), ('g', vec![1, 2])]), 36 | //! (3.into(), vec![('f', vec![1]), ('f', vec![2, 2]), ('g', vec![1, 1])]), 37 | //! (6.into(), vec![('f', vec![2]), ('f', vec![1, 1]), ('g', vec![1, 2])]), 38 | //! (6.into(), vec![('f', vec![2]), ('f', vec![1, 2]), ('g', vec![1, 1])]), 39 | //! ]; 40 | //! 41 | //! assert_eq!(p, res); 42 | //! ``` 43 | use ahash::HashMap; 44 | use smallvec::SmallVec; 45 | use std::{cmp::Ordering, hash::Hash}; 46 | 47 | use crate::domains::integer::Integer; 48 | 49 | /// An iterator type for generating combinations of indices without replacement. 50 | /// 51 | /// # Examples 52 | /// 53 | /// Create an iterator to generate combinations of 3 elements from a total of 4: 54 | /// ```rust 55 | /// # use numerica::combinatorics::CombinationIterator; 56 | /// let mut combos = CombinationIterator::new(4, 3); 57 | /// 58 | /// while let Some(c) = combos.next() { 59 | /// println!("{:?}", c); 60 | /// } 61 | /// 62 | /// // The combinations output is: 63 | /// // [0, 1, 2] 64 | /// // [0, 1, 3] 65 | /// // [0, 2, 3] 66 | /// // [1, 2, 3] 67 | /// ``` 68 | pub struct CombinationIterator { 69 | n: usize, 70 | indices: Vec<usize>, 71 | init: bool, 72 | } 73 | 74 | impl CombinationIterator { 75 | /// Creates a new `CombinationIterator` for generating combinations of `k` elements from a set of `n` elements.
76 | pub fn new(n: usize, k: usize) -> CombinationIterator { 77 | CombinationIterator { 78 | indices: (0..k).collect(), 79 | n, 80 | init: false, 81 | } 82 | } 83 | 84 | /// Advances the iterator and returns the next combination, or `None` once all combinations have been produced (immediately if `k == 0` or `k > n`). 85 | pub fn next(&mut self) -> Option<&[usize]> { 86 | if self.indices.is_empty() || self.indices.len() > self.n { 87 | return None; 88 | } 89 | 90 | if !self.init { 91 | self.init = true; 92 | 93 | return Some(&self.indices); 94 | } 95 | 96 | if self.indices.is_empty() { 97 | return None; 98 | } 99 | 100 | let mut done = true; 101 | for (i, v) in self.indices.iter().enumerate().rev() { 102 | if *v < self.n - self.indices.len() + i { 103 | let a = *v + 1; 104 | for (p, vv) in self.indices[i..].iter_mut().enumerate() { 105 | *vv = a + p; 106 | } 107 | 108 | done = false; 109 | break; 110 | } 111 | } 112 | 113 | if done { None } else { Some(&self.indices) } 114 | } 115 | } 116 | 117 | /// An iterator that generates all combinations of size `k` of `n` items, allowing repeat selections; each combination is returned as the multiplicity of every item. 118 | /// 119 | /// The iterator will produce each combination in ascending order 120 | /// so that only unique combinations are generated, even though 121 | /// each pick is allowed to repeat items. 122 | /// 123 | /// # Example 124 | /// 125 | /// ```rust 126 | /// # use numerica::combinatorics::CombinationWithReplacementIterator; 127 | /// 128 | /// let mut comb_iter = CombinationWithReplacementIterator::new(3, 2); 129 | /// while let Some(combination) = comb_iter.next() { 130 | /// println!("{:?}", combination); 131 | /// } 132 | /// ``` 133 | /// This prints the multiplicity of each of the 3 items, i.e. `[2, 0, 0], [1, 1, 0], [1, 0, 1], [0, 2, 0], [0, 1, 1], [0, 0, 2]`, which correspond to the selections `[0, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 2]`.
134 | pub struct CombinationWithReplacementIterator { 135 | indices: SmallVec<[u32; 10]>, 136 | k: u32, 137 | init: bool, 138 | } 139 | 140 | impl CombinationWithReplacementIterator { 141 | /// Creates a new `CombinationWithReplacementIterator` for generating combinations of `k` elements from a set of `n` elements with replacement. 142 | pub fn new(n: usize, k: u32) -> CombinationWithReplacementIterator { 143 | CombinationWithReplacementIterator { 144 | indices: (0..n).map(|_| 0).collect(), 145 | k, 146 | init: false, 147 | } 148 | } 149 | 150 | /// Advances the iterator and returns the next combination with replacement as per-item multiplicities summing to `k`, or `None` once exhausted (also on every later call). 151 | pub fn next(&mut self) -> Option<&[u32]> { 152 | if self.indices.is_empty() { 153 | return None; 154 | } 155 | 156 | if !self.init { 157 | self.init = true; 158 | self.indices[0] = self.k; 159 | return Some(&self.indices); 160 | } 161 | // `k == 0` has a single all-zero combination; for `k > 0` the multiplicities always sum to `k`, so an all-zero state means the iterator is exhausted (scanning it below would underflow `i`) 162 | if self.k == 0 || self.indices.iter().all(|&c| c == 0) { 163 | return None; 164 | } 165 | 166 | // find the last non-zero index 167 | let mut i = self.indices.len() - 1; 168 | while self.indices[i] == 0 { 169 | i -= 1; 170 | } 171 | 172 | // if it is the last index, its count cannot move further right: 173 | // take its value out and find the previous non-zero index instead 174 | let mut last_val = 0; 175 | if i == self.indices.len() - 1 { 176 | last_val = self.indices[i]; 177 | self.indices[i] = 0; 178 | 179 | if self.indices.len() == 1 { 180 | return None; 181 | } 182 | 183 | i = self.indices.len() - 2; 184 | while self.indices[i] == 0 { 185 | if i == 0 { 186 | return None; 187 | } 188 | 189 | i -= 1; 190 | } 191 | } 192 | 193 | self.indices[i] -= 1; 194 | self.indices[i + 1] = last_val + 1; 195 | 196 | Some(&self.indices) 197 | } 198 | } 199 | 200 | /// Generate all unique permutations of the `list` entries. 201 | /// 202 | /// The combinatorial prefactor of each element is `list.len()! / out.len()` where 203 | /// `out` is the returned list.
204 | pub fn unique_permutations(list: &[T]) -> (Integer, Vec>) { 205 | let mut unique: HashMap<&T, usize> = HashMap::default(); 206 | for e in list { 207 | *unique.entry(e).or_insert(0) += 1; 208 | } 209 | let mut unique: Vec<_> = unique.into_iter().collect(); 210 | unique.sort(); 211 | 212 | // determine pre-factor 213 | let mut prefactor = Integer::one(); 214 | for (_, count) in &unique { 215 | prefactor *= &Integer::factorial(*count as u32); 216 | } 217 | 218 | let mut out = vec![]; 219 | unique_permutations_impl( 220 | &mut unique, 221 | &mut Vec::with_capacity(list.len()), 222 | list.len(), 223 | &mut out, 224 | ); 225 | (prefactor, out) 226 | } 227 | 228 | fn unique_permutations_impl( 229 | unique: &mut Vec<(&T, usize)>, 230 | accum: &mut Vec, 231 | len: usize, 232 | out: &mut Vec>, 233 | ) { 234 | if accum.len() == len { 235 | out.push(accum.to_vec()); 236 | } 237 | 238 | for i in 0..unique.len() { 239 | let (entry, count) = &mut unique[i]; 240 | if *count > 0 { 241 | *count -= 1; 242 | accum.push(entry.clone()); 243 | unique_permutations_impl(unique, accum, len, out); 244 | accum.pop(); 245 | unique[i].1 += 1; 246 | } 247 | } 248 | } 249 | 250 | /// Partition the unordered list `elements` into named bins of unordered lists with a given length, 251 | /// returning all partitions and their multiplicity. 252 | /// 253 | /// # Arguments 254 | /// 255 | /// * `elements` - A slice of elements to partition. 256 | /// * `bins` - A slice of tuples where each tuple contains a bin identifier and the number of elements in that bin. 257 | /// * `fill_last` - A boolean flag indicating whether to add all remaining elements to the last bin if the elements are larger than the bins. 258 | /// * `repeat` - A boolean flag indicating whether to repeat the bins to exactly fit all elements, if possible. 259 | /// 260 | /// # Returns 261 | /// 262 | /// A `Vec` of tuples where each tuple contains: 263 | /// * An `Integer` representing the multiplicity of the partition. 
264 | /// * A `Vec` of tuples where each tuple contains a bin identifier and a `Vec` of elements in that bin. 265 | /// 266 | /// # Example 267 | /// 268 | /// ``` 269 | /// # use numerica::combinatorics::partitions; 270 | /// let result = partitions( 271 | /// &[1, 1, 1, 2, 2], 272 | /// &[('f', 2), ('g', 2), ('f', 1)], 273 | /// false, 274 | /// false 275 | /// ); 276 | /// ``` 277 | /// generates all possible ways to partition the elements of three sets 278 | /// and yields: 279 | /// ```plain 280 | /// [(3, [('g', [1]), ('f', [1, 1]), ('f', [2, 2])]), (6, [('g', [1]), ('f', [1, 2]), 281 | /// ('f', [1, 2])]), (6, [('g', [2]), ('f', [1, 1]), ('f', [1, 2])])] 282 | /// ``` 283 | /// 284 | /// This generates all possible ways to partition the elements into the specified bins. 285 | pub fn partitions( 286 | elements: &[T], 287 | bins: &[(B, usize)], 288 | fill_last: bool, 289 | repeat: bool, 290 | ) -> Vec<(Integer, Vec<(B, Vec)>)> { 291 | if bins.is_empty() { 292 | return vec![]; 293 | } 294 | 295 | let bin_sum = bins.iter().map(|b| b.1).sum::(); 296 | match elements.len().cmp(&bin_sum) { 297 | Ordering::Less => { 298 | return vec![]; 299 | } 300 | Ordering::Equal => {} 301 | Ordering::Greater => { 302 | if !fill_last && (!repeat || !elements.len().is_multiple_of(bin_sum)) { 303 | return vec![]; 304 | } 305 | } 306 | } 307 | 308 | // create groups of equal elements 309 | let mut element_groups: HashMap = HashMap::default(); 310 | for e in elements { 311 | *element_groups.entry(e.clone()).or_insert(0) += 1; 312 | } 313 | 314 | let mut element_sorted: Vec<(T, usize)> = element_groups.into_iter().collect(); 315 | element_sorted.sort(); 316 | 317 | let mut sorted_bins = bins.to_vec(); 318 | 319 | // extend the bins if needed 320 | if fill_last { 321 | let last_bin = sorted_bins.last_mut().unwrap(); 322 | last_bin.1 += elements.len() - bin_sum; 323 | } 324 | 325 | if repeat { 326 | for _ in 1..elements.len() / bin_sum { 327 | sorted_bins.extend_from_slice(bins); 328 | 
} 329 | } 330 | 331 | // sort the bins from largest to smallest and based on the bin id 332 | sorted_bins.sort_by(|a, b| a.1.cmp(&b.1).then(a.0.cmp(&b.0))); 333 | 334 | fn fill_bin( 335 | len: usize, 336 | elems: &mut [(T, usize)], 337 | accum: &mut Vec, 338 | result: &mut Vec>, 339 | ) { 340 | if len == 0 { 341 | result.push(accum.clone()); 342 | return; 343 | } 344 | 345 | for i in 0..elems.len() { 346 | let (name, count) = &mut elems[i]; 347 | if *count > 0 { 348 | *count -= 1; 349 | accum.push(name.clone()); 350 | fill_bin(len - 1, &mut elems[i..], accum, result); 351 | accum.pop(); 352 | elems[i].1 += 1; 353 | } 354 | } 355 | } 356 | 357 | fn fill_rec( 358 | bins: &[(B, usize)], 359 | elems: &mut [(T, usize)], 360 | single_bin_accum: &mut Vec, 361 | single_bin_fill: &mut Vec>, 362 | accum: &mut Vec<(B, Vec)>, 363 | result: &mut Vec<(Integer, Vec<(B, Vec)>)>, 364 | ) { 365 | if bins.is_empty() { 366 | if elems.iter().all(|x| x.1 == 0) { 367 | result.push((Integer::one(), accum.clone())); 368 | } 369 | return; 370 | } 371 | debug_assert!(elems.iter().any(|x| x.1 > 0)); 372 | 373 | let (bin_id, bin_len) = &bins[0]; 374 | 375 | // find all possible ways to fill fun_len 376 | fill_bin(*bin_len, elems, single_bin_accum, single_bin_fill); 377 | 378 | let mut new_bin_fill = vec![]; 379 | for a in single_bin_fill.drain(..) 
{ 380 | // make sure we generate a descending list 381 | if let Some(l) = accum.last() 382 | && l.0 == *bin_id 383 | && a.len() == l.1.len() 384 | && a < l.1 385 | { 386 | continue; 387 | } 388 | 389 | // remove uses from the counters 390 | for x in &a { 391 | elems.iter_mut().find(|e| e.0 == *x).unwrap().1 -= 1; 392 | } 393 | 394 | accum.push((bin_id.clone(), a.clone())); 395 | fill_rec( 396 | &bins[1..], 397 | elems, 398 | single_bin_accum, 399 | &mut new_bin_fill, 400 | accum, 401 | result, 402 | ); 403 | accum.pop(); 404 | 405 | for x in &a { 406 | elems.iter_mut().find(|e| e.0 == *x).unwrap().1 += 1; 407 | } 408 | } 409 | } 410 | 411 | let mut res = vec![]; 412 | fill_rec( 413 | &mut sorted_bins, 414 | &mut element_sorted, 415 | &mut vec![], 416 | &mut vec![], 417 | &mut vec![], 418 | &mut res, 419 | ); 420 | 421 | // compute the prefactor 422 | let mut counter = vec![]; 423 | let mut bin_groups: HashMap<&(B, Vec), usize> = HashMap::default(); 424 | for (pref, sol) in &mut res { 425 | for (e, _) in &element_sorted { 426 | counter.clear(); 427 | for (_, bin) in &*sol { 428 | let c = bin.iter().filter(|be| *be == e).count(); 429 | if c > 0 { 430 | counter.push(c as u32); 431 | } 432 | } 433 | *pref *= &Integer::multinom(&counter); 434 | } 435 | 436 | // count the number of unique bins 437 | for named_bin in &*sol { 438 | *bin_groups.entry(named_bin).or_insert(0) += 1; 439 | } 440 | 441 | for (_, p) in bin_groups.drain() { 442 | *pref /= &Integer::new(p as i64); 443 | } 444 | } 445 | 446 | res 447 | } 448 | 449 | #[cfg(test)] 450 | mod test { 451 | use super::{CombinationIterator, partitions}; 452 | 453 | #[test] 454 | fn combinations() { 455 | let mut c = CombinationIterator::new(4, 3); 456 | let mut combinations = vec![]; 457 | while let Some(a) = c.next() { 458 | combinations.push(a.to_vec()); 459 | } 460 | 461 | let ans = vec![[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]]; 462 | 463 | assert_eq!(combinations, ans); 464 | } 465 | 466 | #[test] 467 | fn 
partitions_no_fill() { 468 | let p = partitions( 469 | &[1, 1, 1, 2, 2], 470 | &[('f', 2), ('g', 2), ('f', 1)], 471 | false, 472 | false, 473 | ); 474 | 475 | let res = vec![ 476 | ( 477 | 3.into(), 478 | vec![('f', vec![1]), ('f', vec![1, 1]), ('g', vec![2, 2])], 479 | ), 480 | ( 481 | 12.into(), 482 | vec![('f', vec![1]), ('f', vec![1, 2]), ('g', vec![1, 2])], 483 | ), 484 | ( 485 | 3.into(), 486 | vec![('f', vec![1]), ('f', vec![2, 2]), ('g', vec![1, 1])], 487 | ), 488 | ( 489 | 6.into(), 490 | vec![('f', vec![2]), ('f', vec![1, 1]), ('g', vec![1, 2])], 491 | ), 492 | ( 493 | 6.into(), 494 | vec![('f', vec![2]), ('f', vec![1, 2]), ('g', vec![1, 1])], 495 | ), 496 | ]; 497 | 498 | assert_eq!(p, res); 499 | } 500 | } 501 | -------------------------------------------------------------------------------- /src/domains.rs: -------------------------------------------------------------------------------- 1 | //! Defines core algebraic traits and data structures. 2 | //! 3 | //! The core trait is [Ring], which has two binary operations, addition and multiplication. 4 | //! Each ring has an associated element type, that should not be confused with the ring type itself. 5 | //! For example: 6 | //! - The ring of integers [Z](type@integer::Z) has elements of type [Integer]. 7 | //! - The ring of rational numbers [Q](type@rational::Q) has elements of type [Rational](rational::Rational). 8 | //! - The ring of finite fields [FiniteField](finite_field::FiniteField) has elements of type [FiniteField](finite_field::FiniteFieldElement). 9 | //! - The ring of polynomials [PolynomialRing](super::poly::polynomial::PolynomialRing) has elements of type [MultivariatePolynomial](super::poly::polynomial::MultivariatePolynomial). 10 | //! 11 | //! In general, the ring elements do not implement operations such as addition or multiplication, 12 | //! but rather the ring itself does. Most Symbolica structures are generic over the ring type. 13 | //! 14 | //! 
An extension of the ring trait is the [`EuclideanDomain`] trait, which adds the ability to compute remainders, quotients, and gcds. 15 | //! Another extension is the [`Field`] trait, which adds the ability to divide and invert elements. 16 | pub mod dual; 17 | pub mod finite_field; 18 | pub mod float; 19 | pub mod integer; 20 | pub mod rational; 21 | 22 | use std::borrow::Borrow; 23 | use std::fmt::{Debug, Display, Error, Formatter}; 24 | use std::hash::Hash; 25 | use std::ops::{Add, Deref, Div, Mul, Sub}; 26 | 27 | use integer::Integer; 28 | 29 | use crate::printer::{PrintOptions, PrintState}; 30 | 31 | /// The internal ordering trait is used to compare elements of a ring. 32 | /// This ordering is defined even for rings that do not have a total ordering, such 33 | /// as complex numbers. 34 | pub trait InternalOrdering { 35 | /// Compare two elements using an internal ordering. 36 | fn internal_cmp(&self, other: &Self) -> std::cmp::Ordering; 37 | } 38 | 39 | macro_rules! impl_internal_ordering { 40 | ($($t:ty),*) => { 41 | $( 42 | impl InternalOrdering for $t { 43 | fn internal_cmp(&self, other: &Self) -> std::cmp::Ordering { 44 | self.cmp(other) 45 | } 46 | } 47 | )* 48 | }; 49 | } 50 | 51 | impl_internal_ordering!(u8); 52 | impl_internal_ordering!(u64); 53 | 54 | macro_rules! 
impl_internal_ordering_range { 55 | ($($t:ty),*) => { 56 | $( 57 | impl InternalOrdering for $t { 58 | fn internal_cmp(&self, other: &Self) -> std::cmp::Ordering { 59 | match self.len().cmp(&other.len()) { 60 | std::cmp::Ordering::Equal => (), 61 | ord => return ord, 62 | } 63 | 64 | for (i, j) in self.iter().zip(other) { 65 | match i.internal_cmp(&j) { 66 | std::cmp::Ordering::Equal => {} 67 | ord => return ord, 68 | } 69 | } 70 | 71 | std::cmp::Ordering::Equal 72 | } 73 | } 74 | )* 75 | }; 76 | } 77 | 78 | impl_internal_ordering_range!([T]); 79 | impl_internal_ordering_range!(Vec); 80 | 81 | /// Rings whose elements contain all the knowledge of the ring itself, 82 | /// for example integers. A counterexample would be finite field elements, 83 | /// as they do not store the prime. 84 | pub trait SelfRing: Clone + PartialEq + Eq + Hash + InternalOrdering + Debug + Display { 85 | fn is_zero(&self) -> bool; 86 | fn is_one(&self) -> bool; 87 | fn format( 88 | &self, 89 | opts: &PrintOptions, 90 | state: PrintState, 91 | f: &mut W, 92 | ) -> Result; 93 | 94 | fn format_string(&self, opts: &PrintOptions, state: PrintState) -> String { 95 | let mut s = String::new(); 96 | self.format(opts, state, &mut s) 97 | .expect("Could not write to string"); 98 | s 99 | } 100 | } 101 | 102 | /// A set is a collection of elements. 103 | pub trait Set: Clone + PartialEq + Eq + Hash + Debug + Display { 104 | /// The element of a set. For example, the elements of the ring of integers [Z](type@integer::Z), `Z::Element`, are [Integer]. 105 | type Element: Clone + PartialEq + Eq + Hash + InternalOrdering + Debug; 106 | 107 | /// The number of elements in the set. `None` is used for infinite sets. 108 | fn size(&self) -> Option; 109 | } 110 | 111 | /// Operations on rings. They should be implemented for `T = ::Element` and `T = &::Element`. 112 | pub trait RingOps: Set { 113 | /// Compute `a + b`. 114 | fn add(&self, a: T, b: T) -> Self::Element; 115 | /// Compute `a - b`. 
116 | fn sub(&self, a: T, b: T) -> Self::Element; 117 | /// Compute `a * b`. 118 | fn mul(&self, a: T, b: T) -> Self::Element; 119 | /// Compute `-a`. 120 | fn neg(&self, a: T) -> Self::Element; 121 | 122 | /// In-place addition: `a += b`. 123 | fn add_assign(&self, a: &mut Self::Element, b: T); 124 | /// In-place subtraction: `a -= b`. 125 | fn sub_assign(&self, a: &mut Self::Element, b: T); 126 | /// In-place multiplication: `a *= b`. 127 | fn mul_assign(&self, a: &mut Self::Element, b: T); 128 | /// In-place fused multiply-add: `a += b * c`. 129 | fn add_mul_assign(&self, a: &mut Self::Element, b: T, c: T); 130 | /// In-place fused multiply-subtract: `a -= b * c`. 131 | fn sub_mul_assign(&self, a: &mut Self::Element, b: T, c: T); 132 | } 133 | 134 | /// A ring is a set with two binary operations, addition and multiplication. 135 | /// Examples of rings include the integers, rational numbers, and polynomials. 136 | /// 137 | /// Each ring has an element type, that should not be confused with the ring type itself. 138 | /// For example: 139 | /// - The ring of integers [Z](type@integer::Z) has elements of type [Integer]. 140 | /// - The ring of rational numbers [Q](type@rational::Q) has elements of type [Rational](rational::Rational). 141 | /// - The ring of finite fields [FiniteField](finite_field::FiniteField) has elements of type [FiniteField](finite_field::FiniteFieldElement). 142 | /// - The ring of polynomials [PolynomialRing](super::poly::polynomial::PolynomialRing) has elements of type [MultivariatePolynomial](super::poly::polynomial::MultivariatePolynomial). 143 | /// 144 | /// In general, the ring elements do not implement operations such as addition or multiplication, 145 | /// but rather the ring itself does. Most Symbolica structures are generic over the ring type. 146 | /// 147 | /// An extension of the ring trait is the [`EuclideanDomain`] trait, which adds the ability to compute remainders, quotients, and gcds. 
148 | /// Another extension is the [`Field`] trait, which adds the ability to divide and invert elements. 149 | pub trait Ring: 150 | Set + RingOps<::Element> + for<'a> RingOps<&'a ::Element> 151 | { 152 | /// Return the additive identity `0`. 153 | fn zero(&self) -> Self::Element; 154 | /// Return the multiplicative identity `1`. 155 | fn one(&self) -> Self::Element; 156 | /// Return the nth element by computing `n * 1`. 157 | fn nth(&self, n: Integer) -> Self::Element; 158 | /// Return `b` raised to the power of `e`. 159 | fn pow(&self, b: &Self::Element, e: u64) -> Self::Element; 160 | /// Return `true` iff `a` is the additive identity `0`. 161 | fn is_zero(&self, a: &Self::Element) -> bool; 162 | /// Return `true` iff `a` is the multiplicative identity `1`. 163 | fn is_one(&self, a: &Self::Element) -> bool; 164 | /// Should return `true` iff `gcd(1,x)` returns `1` for any `x`. 165 | /// For fraction fields, this is most often `false`, as `gcd(1,1/2)` is commonly 166 | /// defined to be `1/2`. 167 | fn one_is_gcd_unit() -> bool; 168 | /// The characteristic of the ring, i.e., the smallest positive integer `n` such that 169 | /// `n * 1 = 0`. If no such `n` exists, return `0`. 170 | fn characteristic(&self) -> Integer; 171 | 172 | /// Invert `a` if `a` is a unit in the ring. If is not, return `None`. 173 | /// For example, in [Z](type@integer::Z), only `1` and `-1` are invertible. 174 | fn try_inv(&self, a: &Self::Element) -> Option; 175 | 176 | /// Return the result of dividing `a` by `b`, if possible and if the result is unique. 177 | /// For example, in [Z](type@integer::Z), `4/2` is possible but `3/2` is not. 178 | fn try_div(&self, a: &Self::Element, b: &Self::Element) -> Option; 179 | 180 | fn sample(&self, rng: &mut impl rand::RngCore, range: (i64, i64)) -> Self::Element; 181 | 182 | /// Format a ring element with custom [PrintOptions] and [PrintState]. 
183 | fn format( 184 | &self, 185 | element: &Self::Element, 186 | opts: &PrintOptions, 187 | state: PrintState, 188 | f: &mut W, 189 | ) -> Result; 190 | 191 | /// Whether the ring does not contain additional information 192 | /// that cannot be inferred from any element of the ring. 193 | /// For example, [Z](type@integer::Z) and [Q](type@rational::Q) have independent elements, 194 | /// while [FiniteField](finite_field::FiniteField) does not, as the prime is part of the ring itself 195 | /// and not part of the elements. 196 | /// 197 | /// Types that return `true` can implement [SelfRing]. 198 | fn has_independent_elements(&self) -> bool { 199 | false 200 | } 201 | 202 | /// Format the ring itself. 203 | fn format_ring( 204 | &self, 205 | _opts: &PrintOptions, 206 | _state: PrintState, 207 | f: &mut W, 208 | ) -> Result { 209 | f.write_fmt(format_args!("{}", self)).map(|_| false) 210 | } 211 | 212 | /// Create a new printer for the given ring element that 213 | /// can be used in a [format!] macro. 214 | fn printer<'a>(&'a self, element: &'a Self::Element) -> RingPrinter<'a, Self> { 215 | RingPrinter::new(self, element) 216 | } 217 | 218 | /// Wrap an element of the ring together with the ring itself, so that 219 | /// operators such as `+` and `*` can be used. 220 | /// 221 | /// ``` 222 | /// use numerica::domains::{Ring, finite_field::{FiniteFieldCore, Zp}}; 223 | /// let z = Zp::new(5); 224 | /// let w = z.wrap(z.to_element(3)); 225 | /// let sub = w - &z.to_element(2); 226 | /// assert_eq!(*sub, z.one()); 227 | /// ``` 228 | fn wrap(&self, element: Self::Element) -> WrappedRingElement 229 | where 230 | Self: Sized, 231 | { 232 | WrappedRingElement::new(self, element) 233 | } 234 | } 235 | 236 | /// A Euclidean domain is a ring that supports division with remainder, quotients, and gcds. 
237 | pub trait EuclideanDomain: Ring { 238 | fn rem(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; 239 | fn quot_rem(&self, a: &Self::Element, b: &Self::Element) -> (Self::Element, Self::Element); 240 | 241 | fn quot(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { 242 | self.quot_rem(a, b).0 243 | } 244 | 245 | fn gcd(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; 246 | } 247 | 248 | /// A field is a ring that supports division and inversion. 249 | pub trait Field: EuclideanDomain { 250 | fn div(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; 251 | fn div_assign(&self, a: &mut Self::Element, b: &Self::Element); 252 | fn inv(&self, a: &Self::Element) -> Self::Element; 253 | 254 | /// Find the shortest linear recurrence relation for a given series `s`, 255 | /// using the Berlekamp-Massey algorithm. 256 | /// 257 | /// Yields a vector `c` such that `s[i] = sum(j, 0, m, c[j] * s[i-j-1])` for `i > m` and 258 | /// the number of stable iterations. 
259 | /// 260 | /// # Example 261 | /// ```rust 262 | /// use numerica::domains::{Field, rational::Q}; 263 | /// let (res, s) = Q.find_linear_recurrence_relation(&[0.into(), 1.into(), 1.into(), 3.into(), 264 | /// 5.into(), 11.into(), 21.into()]); 265 | /// assert_eq!(s, 3); 266 | /// assert_eq!(res, [1, 2]); // s[i] = 1 * s[i-1] + 2 * s[i-2] 267 | /// ``` 268 | fn find_linear_recurrence_relation( 269 | &self, 270 | series: &[Self::Element], 271 | ) -> (Vec, usize) { 272 | let mut c = vec![self.one()]; 273 | let mut c_old = vec![self.one()]; 274 | let mut tmp = vec![]; 275 | let mut seq_len = 0; 276 | let mut m = 1; 277 | let mut stable_count = 0; 278 | let mut b_inv = self.one(); 279 | for (n, s) in series.iter().enumerate() { 280 | let mut error = s.clone(); 281 | for i in 1..=seq_len { 282 | self.add_mul_assign(&mut error, &c[i], &series[n - i]); 283 | } 284 | 285 | if self.is_zero(&error) { 286 | m += 1; 287 | stable_count += 1; 288 | } else if 2 * seq_len <= n { 289 | tmp.clone_from(&c); 290 | 291 | let factor = self.mul(&error, &b_inv); 292 | 293 | if c.len() < m + c_old.len() { 294 | c.resize(m + c_old.len(), self.zero()); 295 | } 296 | 297 | for (j, c_j) in c_old.iter().enumerate() { 298 | self.sub_mul_assign(&mut c[j + m], c_j, &factor); 299 | } 300 | seq_len = n + 1 - seq_len; 301 | std::mem::swap(&mut c_old, &mut tmp); 302 | b_inv = self.inv(&error); 303 | m = 1; 304 | stable_count = 0; 305 | } else { 306 | let factor = self.mul(&error, &b_inv); 307 | 308 | if c.len() < m + c_old.len() { 309 | c.resize(m + c_old.len(), self.zero()); 310 | } 311 | 312 | for (j, c_j) in c_old.iter().enumerate() { 313 | self.sub_mul_assign(&mut c[j + m], c_j, &factor); 314 | } 315 | m += 1; 316 | stable_count = 0; 317 | } 318 | } 319 | 320 | c.drain(0..c.len() - seq_len); 321 | for x in &mut c { 322 | *x = self.neg(&*x); 323 | } 324 | (c, stable_count) 325 | } 326 | } 327 | 328 | /// Rings that can be upgraded to fields, such as `IntegerRing` and `PolynomialRing`. 
329 | /// The most common upgrade is by creating a fraction field, such as `Q[x]`. 330 | pub trait UpgradeToField: Ring { 331 | type Upgraded: Field; 332 | 333 | /// Upgrade the ring to a field. 334 | fn upgrade(self) -> Self::Upgraded; 335 | 336 | /// Upgrade an element of the ring to an element of the upgraded field. 337 | fn upgrade_element(&self, element: ::Element) -> ::Element; 338 | } 339 | 340 | impl UpgradeToField for T { 341 | type Upgraded = Self; 342 | 343 | fn upgrade(self) -> Self::Upgraded { 344 | self 345 | } 346 | 347 | fn upgrade_element(&self, element: ::Element) -> ::Element { 348 | element 349 | } 350 | } 351 | 352 | /// Provides an interface for printing elements of a ring with optional customization, 353 | /// suitable as an argument to [format!]. Internally, it will call [Ring::format]. 354 | pub struct RingPrinter<'a, R: Ring> { 355 | pub ring: &'a R, 356 | pub element: &'a R::Element, 357 | pub opts: PrintOptions, 358 | pub state: PrintState, 359 | } 360 | 361 | impl<'a, R: Ring> RingPrinter<'a, R> { 362 | pub fn new(ring: &'a R, element: &'a R::Element) -> RingPrinter<'a, R> { 363 | RingPrinter { 364 | ring, 365 | element, 366 | opts: PrintOptions::default(), 367 | state: PrintState::default(), 368 | } 369 | } 370 | } 371 | 372 | impl Display for RingPrinter<'_, R> { 373 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 374 | self.ring 375 | .format( 376 | self.element, 377 | &self.opts.update_with_fmt(f), 378 | self.state.update_with_fmt(f), 379 | f, 380 | ) 381 | .map(|_| ()) 382 | } 383 | } 384 | 385 | /// A ring element wrapped together with its ring. 
386 | /// 387 | /// ``` 388 | /// # use numerica::domains::{Ring, finite_field::{FiniteFieldCore, Zp}}; 389 | /// let z = Zp::new(5); 390 | /// let w = z.wrap(z.to_element(3)); 391 | /// let sub = w - &z.to_element(2); 392 | /// assert_eq!(*sub, z.one()); 393 | /// ``` 394 | #[derive(Clone)] 395 | pub struct WrappedRingElement> { 396 | pub ring: C, 397 | pub element: R::Element, 398 | } 399 | 400 | impl> AsRef for WrappedRingElement { 401 | fn as_ref(&self) -> &R::Element { 402 | &self.element 403 | } 404 | } 405 | 406 | impl> Deref for WrappedRingElement { 407 | type Target = R::Element; 408 | fn deref(&self) -> &R::Element { 409 | &self.element 410 | } 411 | } 412 | 413 | impl> WrappedRingElement { 414 | pub fn new(ring: C, element: R::Element) -> Self { 415 | WrappedRingElement { ring, element } 416 | } 417 | 418 | pub fn ring(&self) -> &R { 419 | self.ring.borrow() 420 | } 421 | } 422 | 423 | impl> Debug for WrappedRingElement { 424 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 425 | self.ring() 426 | .format( 427 | &self.element, 428 | &PrintOptions::default(), 429 | PrintState::default(), 430 | f, 431 | ) 432 | .map(|_| ()) 433 | } 434 | } 435 | 436 | impl> Display for WrappedRingElement { 437 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 438 | self.ring() 439 | .format( 440 | &self.element, 441 | &PrintOptions::default(), 442 | PrintState::default(), 443 | f, 444 | ) 445 | .map(|_| ()) 446 | } 447 | } 448 | 449 | impl> PartialEq for WrappedRingElement { 450 | fn eq(&self, other: &Self) -> bool { 451 | self.element == other.element 452 | } 453 | } 454 | 455 | impl> Eq for WrappedRingElement {} 456 | 457 | impl> Hash for WrappedRingElement { 458 | fn hash(&self, state: &mut H) { 459 | self.element.hash(state) 460 | } 461 | } 462 | 463 | impl> InternalOrdering for WrappedRingElement { 464 | fn internal_cmp(&self, other: &Self) -> std::cmp::Ordering { 465 | self.element.internal_cmp(&other.element) 466 | } 467 | } 468 | 469 | impl> 
SelfRing for WrappedRingElement { 470 | fn is_zero(&self) -> bool { 471 | self.ring().is_zero(&self.element) 472 | } 473 | 474 | fn is_one(&self) -> bool { 475 | self.ring().is_one(&self.element) 476 | } 477 | 478 | fn format( 479 | &self, 480 | opts: &PrintOptions, 481 | state: PrintState, 482 | f: &mut W, 483 | ) -> Result { 484 | self.ring().format(&self.element, opts, state, f) 485 | } 486 | } 487 | 488 | impl> Add<&R::Element> for WrappedRingElement { 489 | type Output = WrappedRingElement; 490 | 491 | fn add(self, rhs: &R::Element) -> Self::Output { 492 | WrappedRingElement { 493 | element: self.ring().add(&self.element, rhs), 494 | ring: self.ring, 495 | } 496 | } 497 | } 498 | 499 | impl> Sub<&R::Element> for WrappedRingElement { 500 | type Output = WrappedRingElement; 501 | 502 | fn sub(self, rhs: &R::Element) -> Self::Output { 503 | WrappedRingElement { 504 | element: self.ring().sub(&self.element, rhs), 505 | ring: self.ring, 506 | } 507 | } 508 | } 509 | 510 | impl> Mul<&R::Element> for WrappedRingElement { 511 | type Output = WrappedRingElement; 512 | 513 | fn mul(self, rhs: &R::Element) -> Self::Output { 514 | WrappedRingElement { 515 | element: self.ring().mul(&self.element, rhs), 516 | ring: self.ring, 517 | } 518 | } 519 | } 520 | 521 | impl> Div<&R::Element> for WrappedRingElement { 522 | type Output = WrappedRingElement; 523 | 524 | fn div(self, rhs: &R::Element) -> Self::Output { 525 | WrappedRingElement { 526 | element: self.ring().div(&self.element, rhs), 527 | ring: self.ring, 528 | } 529 | } 530 | } 531 | 532 | /// A ring that supports a derivative. 533 | pub trait Derivable: Ring { 534 | type Variable; 535 | 536 | /// Take the derivative of `e` in `x`. 
537 | fn derivative(&self, e: &::Element, x: &Self::Variable) -> ::Element; 538 | } 539 | 540 | #[cfg(test)] 541 | mod tests { 542 | use crate::domains::{Field, Ring, rational::Q}; 543 | 544 | #[test] 545 | fn linear_recurrence() { 546 | let series = [ 547 | 2.into(), 548 | 2.into(), 549 | 1.into(), 550 | 2.into(), 551 | 1.into(), 552 | 191.into(), 553 | 393.into(), 554 | 132.into(), 555 | ]; 556 | 557 | let (res, it) = Q.find_linear_recurrence_relation(&series); 558 | assert_eq!(it, 0); 559 | 560 | for (i, x) in series.iter().enumerate().skip(res.len()) { 561 | let mut c = Q.zero(); 562 | for j in 0..res.len() { 563 | c += &res[j] * &series[i - j - 1]; 564 | } 565 | 566 | assert_eq!(&c, x); 567 | } 568 | } 569 | } 570 | -------------------------------------------------------------------------------- /src/domains/rational.rs: -------------------------------------------------------------------------------- 1 | //! Fraction fields and rational numbers. 2 | 3 | use std::{ 4 | borrow::Cow, 5 | fmt::{Display, Error, Formatter}, 6 | ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}, 7 | }; 8 | 9 | use crate::{ 10 | domains::{RingOps, Set, finite_field::Zp64}, 11 | printer::{PrintOptions, PrintState}, 12 | }; 13 | 14 | use super::{ 15 | EuclideanDomain, Field, InternalOrdering, Ring, SelfRing, UpgradeToField, 16 | finite_field::{ 17 | FiniteField, FiniteFieldCore, FiniteFieldWorkspace, PrimeIteratorU64, ToFiniteField, Two, 18 | Z2, Zp, 19 | }, 20 | integer::{Integer, IntegerRing, Z}, 21 | }; 22 | 23 | /// The field of rational numbers. 24 | pub type Q = FractionField; 25 | pub type RationalField = FractionField; 26 | /// The field of rational numbers. 27 | pub const Q: FractionField = FractionField::new(Z); 28 | 29 | /// The fraction field of `R`. 
30 | #[derive(Clone, PartialEq, Eq, Hash, Debug)] 31 | pub struct FractionField { 32 | ring: R, 33 | } 34 | 35 | impl FractionField { 36 | pub const fn new(ring: R) -> FractionField { 37 | FractionField { ring } 38 | } 39 | 40 | pub fn ring(&self) -> &R { 41 | &self.ring 42 | } 43 | } 44 | 45 | impl FractionField { 46 | pub fn to_element_numerator(&self, numerator: R::Element) -> ::Element { 47 | Fraction { 48 | numerator, 49 | denominator: self.ring.one(), 50 | } 51 | } 52 | } 53 | 54 | impl FractionField { 55 | pub fn to_element( 56 | &self, 57 | mut numerator: R::Element, 58 | mut denominator: R::Element, 59 | do_gcd: bool, 60 | ) -> ::Element { 61 | if self.ring.is_zero(&denominator) { 62 | panic!("Cannot create a fraction with zero denominator"); 63 | } 64 | 65 | if do_gcd { 66 | let g = self.ring.gcd(&numerator, &denominator); 67 | if !self.ring.is_one(&g) { 68 | numerator = self.ring.quot_rem(&numerator, &g).0; 69 | denominator = self.ring.quot_rem(&denominator, &g).0; 70 | } 71 | } 72 | 73 | let f = self.ring.get_normalization_factor(&denominator); 74 | 75 | if self.ring.is_one(&f) { 76 | Fraction { 77 | numerator, 78 | denominator, 79 | } 80 | } else { 81 | Fraction { 82 | numerator: self.ring.mul(&numerator, &f), 83 | denominator: self.ring.mul(&denominator, &f), 84 | } 85 | } 86 | } 87 | } 88 | 89 | impl Display for FractionField { 90 | fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result { 91 | Ok(()) 92 | } 93 | } 94 | 95 | pub trait FractionNormalization: Ring { 96 | /// Get the factor that normalizes the element `a`. 97 | /// - For a field, this is the inverse of `a`. 98 | /// - For the integers, this is the sign of `a`. 99 | /// - For a polynomial ring, this is the normalization factor of the leading coefficient. 
100 | fn get_normalization_factor(&self, a: &Self::Element) -> Self::Element; 101 | } 102 | 103 | impl FractionNormalization for Z { 104 | fn get_normalization_factor(&self, a: &Integer) -> Integer { 105 | if *a < 0 { (-1).into() } else { 1.into() } 106 | } 107 | } 108 | 109 | impl FractionNormalization for T { 110 | fn get_normalization_factor(&self, a: &Self::Element) -> Self::Element { 111 | self.inv(a) 112 | } 113 | } 114 | 115 | /// A fraction of two elements of a ring. Create a new one through [FractionField::to_element]. 116 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 117 | #[derive(Clone, PartialEq, Eq, Hash, Debug)] 118 | pub struct Fraction { 119 | numerator: R::Element, 120 | denominator: R::Element, 121 | } 122 | 123 | #[cfg(feature = "bincode")] 124 | impl bincode::enc::Encode for Fraction 125 | where 126 | R::Element: bincode::enc::Encode, 127 | { 128 | fn encode( 129 | &self, 130 | encoder: &mut E, 131 | ) -> Result<(), bincode::error::EncodeError> { 132 | self.numerator.encode(encoder)?; 133 | self.denominator.encode(encoder) 134 | } 135 | } 136 | 137 | #[cfg(feature = "bincode")] 138 | impl bincode::de::Decode for Fraction 139 | where 140 | R::Element: bincode::de::Decode, 141 | { 142 | fn decode>( 143 | decoder: &mut D, 144 | ) -> Result { 145 | Ok(Fraction { 146 | numerator: R::Element::decode(decoder)?, 147 | denominator: R::Element::decode(decoder)?, 148 | }) 149 | } 150 | } 151 | 152 | #[cfg(feature = "bincode")] 153 | impl<'de, C, R: Ring> bincode::de::BorrowDecode<'de, C> for Fraction 154 | where 155 | R::Element: bincode::de::BorrowDecode<'de, C>, 156 | { 157 | fn borrow_decode>( 158 | decoder: &mut D, 159 | ) -> Result { 160 | Ok(Fraction { 161 | numerator: R::Element::borrow_decode(decoder)?, 162 | denominator: R::Element::borrow_decode(decoder)?, 163 | }) 164 | } 165 | } 166 | 167 | impl Fraction { 168 | pub fn numerator(&self) -> R::Element { 169 | self.numerator.clone() 170 | } 171 | 172 | pub fn 
denominator(&self) -> R::Element { 173 | self.denominator.clone() 174 | } 175 | 176 | pub fn numerator_ref(&self) -> &R::Element { 177 | &self.numerator 178 | } 179 | 180 | pub fn denominator_ref(&self) -> &R::Element { 181 | &self.denominator 182 | } 183 | 184 | pub fn from_unchecked(numerator: R::Element, denominator: R::Element) -> Self { 185 | Fraction { 186 | numerator, 187 | denominator, 188 | } 189 | } 190 | } 191 | 192 | impl InternalOrdering for Fraction { 193 | fn internal_cmp(&self, other: &Self) -> std::cmp::Ordering { 194 | self.numerator 195 | .internal_cmp(&other.numerator) 196 | .then_with(|| self.denominator.internal_cmp(&other.denominator)) 197 | } 198 | } 199 | 200 | impl Set for FractionField { 201 | type Element = Fraction; 202 | 203 | fn size(&self) -> Option { 204 | self.ring.size().map(|s| &s * (&s - 1)) 205 | } 206 | } 207 | 208 | impl RingOps< as Set>::Element> 209 | for FractionField 210 | { 211 | fn add(&self, a: Self::Element, b: Self::Element) -> Self::Element { 212 | let r = &self.ring; 213 | 214 | if a.denominator == b.denominator { 215 | let num = r.add(&a.numerator, &b.numerator); 216 | let g = r.gcd(&num, &a.denominator); 217 | if !r.is_one(&g) { 218 | return Fraction { 219 | numerator: r.quot_rem(&num, &g).0, 220 | denominator: r.quot_rem(&a.denominator, &g).0, 221 | }; 222 | } else { 223 | return Fraction { 224 | numerator: num, 225 | denominator: a.denominator.clone(), 226 | }; 227 | } 228 | } 229 | 230 | let denom_gcd = r.gcd(&a.denominator, &b.denominator); 231 | 232 | let mut a_den_red = Cow::Borrowed(&a.denominator); 233 | let mut b_den_red = Cow::Borrowed(&b.denominator); 234 | 235 | if !r.is_one(&denom_gcd) { 236 | a_den_red = Cow::Owned(r.quot_rem(&a.denominator, &denom_gcd).0); 237 | b_den_red = Cow::Owned(r.quot_rem(&b.denominator, &denom_gcd).0); 238 | } 239 | 240 | let num1 = r.mul(&a.numerator, &b_den_red); 241 | let num2 = r.mul(&b.numerator, &a_den_red); 242 | let mut num = r.add(&num1, &num2); 243 | 244 | // 
TODO: prefer small * large over medium * medium sized operations 245 | // a_denom_red.as_ref() * &other.denominator may be faster 246 | // TODO: add size hint trait with default implementation? 247 | let mut den = r.mul(b_den_red.as_ref(), &a.denominator); 248 | 249 | let g = r.gcd(&num, &denom_gcd); 250 | 251 | if !r.is_one(&g) { 252 | num = r.quot_rem(&num, &g).0; 253 | den = r.quot_rem(&den, &g).0; 254 | } 255 | 256 | Fraction { 257 | numerator: num, 258 | denominator: den, 259 | } 260 | } 261 | 262 | fn sub(&self, a: Self::Element, b: Self::Element) -> Self::Element { 263 | self.add(a, self.neg(b)) 264 | } 265 | 266 | fn mul(&self, a: Self::Element, b: Self::Element) -> Self::Element { 267 | let r = &self.ring; 268 | let gcd1 = r.gcd(&a.numerator, &b.denominator); 269 | let gcd2 = r.gcd(&a.denominator, &b.numerator); 270 | 271 | if r.is_one(&gcd1) { 272 | if r.is_one(&gcd2) { 273 | Fraction { 274 | numerator: r.mul(&a.numerator, &b.numerator), 275 | denominator: r.mul(&a.denominator, &b.denominator), 276 | } 277 | } else { 278 | Fraction { 279 | numerator: r.mul(&a.numerator, &r.quot_rem(&b.numerator, &gcd2).0), 280 | denominator: r.mul(&r.quot_rem(&a.denominator, &gcd2).0, &b.denominator), 281 | } 282 | } 283 | } else if r.is_one(&gcd2) { 284 | Fraction { 285 | numerator: r.mul(&r.quot_rem(&a.numerator, &gcd1).0, &b.numerator), 286 | denominator: r.mul(&a.denominator, &r.quot_rem(&b.denominator, &gcd1).0), 287 | } 288 | } else { 289 | Fraction { 290 | numerator: r.mul( 291 | &r.quot_rem(&a.numerator, &gcd1).0, 292 | &r.quot_rem(&b.numerator, &gcd2).0, 293 | ), 294 | denominator: r.mul( 295 | &r.quot_rem(&a.denominator, &gcd2).0, 296 | &r.quot_rem(&b.denominator, &gcd1).0, 297 | ), 298 | } 299 | } 300 | } 301 | 302 | fn add_assign(&self, a: &mut Self::Element, b: Self::Element) { 303 | // TODO: optimize 304 | *a = self.add(&*a, &b); 305 | } 306 | 307 | fn sub_assign(&self, a: &mut Self::Element, b: Self::Element) { 308 | *a = self.sub(&*a, &b); 309 | } 310 | 
311 | fn mul_assign(&self, a: &mut Self::Element, b: Self::Element) { 312 | *a = self.mul(&*a, &b); 313 | } 314 | 315 | fn add_mul_assign(&self, a: &mut Self::Element, b: Self::Element, c: Self::Element) { 316 | self.add_assign(a, &self.mul(b, c)); 317 | } 318 | 319 | fn sub_mul_assign(&self, a: &mut Self::Element, b: Self::Element, c: Self::Element) { 320 | self.sub_assign(a, &self.mul(b, c)); 321 | } 322 | 323 | fn neg(&self, a: Self::Element) -> Self::Element { 324 | Fraction { 325 | numerator: self.ring.neg(a.numerator), 326 | denominator: a.denominator, 327 | } 328 | } 329 | } 330 | 331 | impl RingOps<& as Set>::Element> 332 | for FractionField 333 | { 334 | fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { 335 | let r = &self.ring; 336 | 337 | if a.denominator == b.denominator { 338 | let num = r.add(&a.numerator, &b.numerator); 339 | let g = r.gcd(&num, &a.denominator); 340 | if !r.is_one(&g) { 341 | return Fraction { 342 | numerator: r.quot_rem(&num, &g).0, 343 | denominator: r.quot_rem(&a.denominator, &g).0, 344 | }; 345 | } else { 346 | return Fraction { 347 | numerator: num, 348 | denominator: a.denominator.clone(), 349 | }; 350 | } 351 | } 352 | 353 | let denom_gcd = r.gcd(&a.denominator, &b.denominator); 354 | 355 | let mut a_den_red = Cow::Borrowed(&a.denominator); 356 | let mut b_den_red = Cow::Borrowed(&b.denominator); 357 | 358 | if !r.is_one(&denom_gcd) { 359 | a_den_red = Cow::Owned(r.quot_rem(&a.denominator, &denom_gcd).0); 360 | b_den_red = Cow::Owned(r.quot_rem(&b.denominator, &denom_gcd).0); 361 | } 362 | 363 | let num1 = r.mul(&a.numerator, &b_den_red); 364 | let num2 = r.mul(&b.numerator, &a_den_red); 365 | let mut num = r.add(&num1, &num2); 366 | 367 | // TODO: prefer small * large over medium * medium sized operations 368 | // a_denom_red.as_ref() * &other.denominator may be faster 369 | // TODO: add size hint trait with default implementation? 
370 | let mut den = r.mul(b_den_red.as_ref(), &a.denominator); 371 | 372 | let g = r.gcd(&num, &denom_gcd); 373 | 374 | if !r.is_one(&g) { 375 | num = r.quot_rem(&num, &g).0; 376 | den = r.quot_rem(&den, &g).0; 377 | } 378 | 379 | Fraction { 380 | numerator: num, 381 | denominator: den, 382 | } 383 | } 384 | 385 | fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { 386 | // TODO: optimize 387 | self.add(a, &self.neg(b)) 388 | } 389 | 390 | fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { 391 | let r = &self.ring; 392 | let gcd1 = r.gcd(&a.numerator, &b.denominator); 393 | let gcd2 = r.gcd(&a.denominator, &b.numerator); 394 | 395 | if r.is_one(&gcd1) { 396 | if r.is_one(&gcd2) { 397 | Fraction { 398 | numerator: r.mul(&a.numerator, &b.numerator), 399 | denominator: r.mul(&a.denominator, &b.denominator), 400 | } 401 | } else { 402 | Fraction { 403 | numerator: r.mul(&a.numerator, &r.quot_rem(&b.numerator, &gcd2).0), 404 | denominator: r.mul(&r.quot_rem(&a.denominator, &gcd2).0, &b.denominator), 405 | } 406 | } 407 | } else if r.is_one(&gcd2) { 408 | Fraction { 409 | numerator: r.mul(&r.quot_rem(&a.numerator, &gcd1).0, &b.numerator), 410 | denominator: r.mul(&a.denominator, &r.quot_rem(&b.denominator, &gcd1).0), 411 | } 412 | } else { 413 | Fraction { 414 | numerator: r.mul( 415 | &r.quot_rem(&a.numerator, &gcd1).0, 416 | &r.quot_rem(&b.numerator, &gcd2).0, 417 | ), 418 | denominator: r.mul( 419 | &r.quot_rem(&a.denominator, &gcd2).0, 420 | &r.quot_rem(&b.denominator, &gcd1).0, 421 | ), 422 | } 423 | } 424 | } 425 | 426 | fn add_assign(&self, a: &mut Self::Element, b: &Self::Element) { 427 | // TODO: optimize 428 | *a = self.add(&*a, b); 429 | } 430 | 431 | fn sub_assign(&self, a: &mut Self::Element, b: &Self::Element) { 432 | *a = self.sub(&*a, b); 433 | } 434 | 435 | fn mul_assign(&self, a: &mut Self::Element, b: &Self::Element) { 436 | *a = self.mul(&*a, b); 437 | } 438 | 439 | fn add_mul_assign(&self, a: &mut Self::Element, 
b: &Self::Element, c: &Self::Element) { 440 | self.add_assign(a, &self.mul(b, c)); 441 | } 442 | 443 | fn sub_mul_assign(&self, a: &mut Self::Element, b: &Self::Element, c: &Self::Element) { 444 | self.sub_assign(a, &self.mul(b, c)); 445 | } 446 | 447 | fn neg(&self, a: &Self::Element) -> Self::Element { 448 | Fraction { 449 | numerator: self.ring.neg(&a.numerator), 450 | denominator: a.denominator.clone(), 451 | } 452 | } 453 | } 454 | 455 | impl Ring for FractionField { 456 | fn zero(&self) -> Self::Element { 457 | Fraction { 458 | numerator: self.ring.zero(), 459 | denominator: self.ring.one(), 460 | } 461 | } 462 | 463 | fn one(&self) -> Self::Element { 464 | Fraction { 465 | numerator: self.ring.one(), 466 | denominator: self.ring.one(), 467 | } 468 | } 469 | 470 | #[inline] 471 | fn nth(&self, n: Integer) -> Self::Element { 472 | Fraction { 473 | numerator: self.ring.nth(n), 474 | denominator: self.ring.one(), 475 | } 476 | } 477 | 478 | fn pow(&self, b: &Self::Element, e: u64) -> Self::Element { 479 | Fraction { 480 | numerator: self.ring.pow(&b.numerator, e), 481 | denominator: self.ring.pow(&b.denominator, e), 482 | } 483 | } 484 | 485 | fn is_zero(&self, a: &Self::Element) -> bool { 486 | self.ring.is_zero(&a.numerator) 487 | } 488 | 489 | fn is_one(&self, a: &Self::Element) -> bool { 490 | self.ring.is_one(&a.numerator) && self.ring.is_one(&a.denominator) 491 | } 492 | 493 | fn one_is_gcd_unit() -> bool { 494 | false 495 | } 496 | 497 | fn characteristic(&self) -> Integer { 498 | self.ring.characteristic() 499 | } 500 | 501 | fn try_inv(&self, a: &Self::Element) -> Option { 502 | if self.ring.is_zero(&a.numerator) { 503 | None 504 | } else { 505 | Some(self.inv(a)) 506 | } 507 | } 508 | 509 | fn try_div(&self, a: &Self::Element, b: &Self::Element) -> Option { 510 | if self.is_zero(b) { 511 | None 512 | } else { 513 | Some(self.div(a, b)) 514 | } 515 | } 516 | 517 | fn sample(&self, rng: &mut impl rand::RngCore, range: (i64, i64)) -> Self::Element { 518 | 
Fraction { 519 | numerator: self.ring.sample(rng, range), 520 | denominator: self.ring.one(), 521 | } 522 | } 523 | 524 | fn format( 525 | &self, 526 | element: &Self::Element, 527 | opts: &PrintOptions, 528 | mut state: PrintState, 529 | f: &mut W, 530 | ) -> Result { 531 | let has_denom = !self.ring.is_one(&element.denominator); 532 | 533 | let write_par = has_denom && (state.in_exp || state.in_exp_base); 534 | if write_par { 535 | if state.in_sum { 536 | state.in_sum = false; 537 | f.write_char('+')?; 538 | } 539 | 540 | f.write_char('(')?; 541 | state.in_exp = false; 542 | state.in_exp_base = false; 543 | } 544 | 545 | if self.ring.format( 546 | &element.numerator, 547 | opts, 548 | PrintState { 549 | in_product: state.in_product || has_denom, 550 | suppress_one: state.suppress_one && !has_denom, 551 | level: state.level + 1, 552 | ..state 553 | }, 554 | f, 555 | )? { 556 | return Ok(true); 557 | }; 558 | 559 | if has_denom { 560 | f.write_char('/')?; 561 | self.ring.format( 562 | &element.denominator, 563 | opts, 564 | state.step(false, true, false, true), 565 | f, 566 | )?; 567 | } 568 | 569 | if write_par { 570 | f.write_char(')')?; 571 | } 572 | 573 | Ok(false) 574 | } 575 | 576 | fn has_independent_elements(&self) -> bool { 577 | self.ring.has_independent_elements() 578 | } 579 | } 580 | 581 | impl SelfRing for Fraction 582 | where 583 | R::Element: SelfRing, 584 | { 585 | fn is_zero(&self) -> bool { 586 | self.numerator.is_zero() 587 | } 588 | 589 | fn is_one(&self) -> bool { 590 | self.numerator.is_one() && self.denominator.is_one() 591 | } 592 | 593 | fn format( 594 | &self, 595 | opts: &PrintOptions, 596 | mut state: PrintState, 597 | f: &mut W, 598 | ) -> Result { 599 | let has_denom = !self.denominator.is_one(); 600 | 601 | let write_par = has_denom && (state.in_exp || state.in_exp_base); 602 | if write_par { 603 | if state.in_sum { 604 | state.in_sum = false; 605 | f.write_char('+')?; 606 | } 607 | 608 | f.write_char('(')?; 609 | state.in_exp = 
false; 610 | state.in_exp_base = false; 611 | } 612 | 613 | if self.numerator.format( 614 | opts, 615 | PrintState { 616 | in_product: state.in_product || has_denom, 617 | suppress_one: state.suppress_one && !has_denom, 618 | level: state.level + 1, 619 | ..state 620 | }, 621 | f, 622 | )? { 623 | return Ok(true); 624 | } 625 | 626 | if has_denom { 627 | f.write_char('/')?; 628 | self.denominator 629 | .format(opts, state.step(false, true, false, true), f)?; 630 | } 631 | 632 | if write_par { 633 | f.write_char(')')?; 634 | } 635 | 636 | Ok(false) 637 | } 638 | } 639 | 640 | impl Display for Fraction 641 | where 642 | R::Element: SelfRing, 643 | { 644 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 645 | self.format(&PrintOptions::default(), PrintState::new(), f) 646 | .map(|_| ()) 647 | } 648 | } 649 | 650 | impl EuclideanDomain for FractionField { 651 | fn rem(&self, _: &Self::Element, _: &Self::Element) -> Self::Element { 652 | self.zero() 653 | } 654 | 655 | fn quot_rem(&self, a: &Self::Element, b: &Self::Element) -> (Self::Element, Self::Element) { 656 | (self.div(a, b), self.zero()) 657 | } 658 | 659 | fn gcd(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { 660 | let gcd_num = self.ring.gcd(&a.numerator, &b.numerator); 661 | let gcd_den = self.ring.gcd(&a.denominator, &b.denominator); 662 | 663 | let d1 = self.ring.quot_rem(&a.denominator, &gcd_den).0; 664 | let lcm = self.ring.mul(&d1, &b.denominator); 665 | 666 | Fraction { 667 | numerator: gcd_num, 668 | denominator: lcm, 669 | } 670 | } 671 | } 672 | 673 | impl Field for FractionField { 674 | fn div(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { 675 | // TODO: optimize 676 | self.mul(a, &self.inv(b)) 677 | } 678 | 679 | fn div_assign(&self, a: &mut Self::Element, b: &Self::Element) { 680 | *a = self.div(a, b); 681 | } 682 | 683 | fn inv(&self, a: &Self::Element) -> Self::Element { 684 | if self.ring.is_zero(&a.numerator) { 685 | panic!("Division by 0"); 686 | } 
687 | 688 | let f = self.ring.get_normalization_factor(&a.numerator); 689 | 690 | Fraction { 691 | numerator: self.ring.mul(&a.denominator, &f), 692 | denominator: self.ring.mul(&a.numerator, &f), 693 | } 694 | } 695 | } 696 | 697 | /// A rational number. 698 | pub type Rational = Fraction; 699 | 700 | impl UpgradeToField for IntegerRing { 701 | type Upgraded = Q; 702 | 703 | fn upgrade(self) -> Self::Upgraded { 704 | Q 705 | } 706 | 707 | fn upgrade_element(&self, element: ::Element) -> ::Element { 708 | Rational::from(element) 709 | } 710 | } 711 | 712 | impl Default for Rational { 713 | fn default() -> Self { 714 | Rational::zero() 715 | } 716 | } 717 | 718 | impl PartialEq for Rational { 719 | fn eq(&self, other: &Integer) -> bool { 720 | self.denominator.is_one() && &self.numerator == other 721 | } 722 | } 723 | 724 | impl + Copy> PartialEq for Rational { 725 | fn eq(&self, other: &T) -> bool { 726 | self.denominator.is_one() && self.numerator == (*other).into() 727 | } 728 | } 729 | 730 | impl + Copy> PartialEq<(T, T)> for Rational { 731 | fn eq(&self, other: &(T, T)) -> bool { 732 | self == &Rational::from((other.0.into(), other.1.into())) 733 | } 734 | } 735 | 736 | impl PartialOrd for Rational { 737 | fn partial_cmp(&self, other: &Integer) -> Option { 738 | Some(self.numerator.cmp(&(other * self.denominator_ref()))) 739 | } 740 | } 741 | 742 | impl + Copy> PartialOrd for Rational { 743 | fn partial_cmp(&self, other: &T) -> Option { 744 | Some( 745 | self.numerator 746 | .cmp(&((*other).into() * self.denominator_ref())), 747 | ) 748 | } 749 | } 750 | 751 | impl + Copy> PartialOrd<(T, T)> for Rational { 752 | fn partial_cmp(&self, other: &(T, T)) -> Option { 753 | Some(self.cmp(&Rational::from((other.0.into(), other.1.into())))) 754 | } 755 | } 756 | 757 | impl TryFrom for Rational { 758 | type Error = &'static str; 759 | 760 | /// Convert a floating point number to its exact rational number equivalent. 
761 | /// Use [`Rational::truncate_denominator`] to get an approximation with a smaller denominator. 762 | #[inline] 763 | fn try_from(f: f64) -> Result { 764 | if !f.is_finite() { 765 | return Err("Cannot convert non-finite float to rational"); 766 | } 767 | 768 | // taken from num-traits 769 | let bits: u64 = f.to_bits(); 770 | let sign: i8 = if bits >> 63 == 0 { 1 } else { -1 }; 771 | let mut exponent: i16 = ((bits >> 52) & 0x7ff) as i16; 772 | let mantissa = if exponent == 0 { 773 | (bits & 0xfffffffffffff) << 1 774 | } else { 775 | (bits & 0xfffffffffffff) | 0x10000000000000 776 | }; 777 | // Exponent bias + mantissa shift 778 | exponent -= 1023 + 52; 779 | 780 | // superfluous factors of 2 will be divided out in the conversion to rational 781 | if exponent < 0 { 782 | Ok(( 783 | (sign as i64 * mantissa as i64).into(), 784 | Integer::from(2).pow(-exponent as u64), 785 | ) 786 | .into()) 787 | } else { 788 | Ok(( 789 | &Integer::from(sign as i64 * mantissa as i64) 790 | * &Integer::from(2).pow(exponent as u64), 791 | 1.into(), 792 | ) 793 | .into()) 794 | } 795 | } 796 | } 797 | 798 | impl> From for Rational { 799 | #[inline] 800 | fn from(value: T) -> Self { 801 | Rational { 802 | numerator: value.into(), 803 | denominator: 1.into(), 804 | } 805 | } 806 | } 807 | 808 | impl From<&Integer> for Rational { 809 | fn from(value: &Integer) -> Self { 810 | Rational { 811 | numerator: value.clone(), 812 | denominator: 1.into(), 813 | } 814 | } 815 | } 816 | 817 | impl> From<(T, T)> for Rational { 818 | #[inline] 819 | fn from((num, den): (T, T)) -> Self { 820 | Q.to_element(num.into(), den.into(), true) 821 | } 822 | } 823 | 824 | impl From for Rational { 825 | fn from(value: rug::Rational) -> Self { 826 | let (num, den) = value.into_numer_denom(); 827 | Q.to_element(num.into(), den.into(), false) 828 | } 829 | } 830 | 831 | impl ToFiniteField for Rational { 832 | fn to_finite_field(&self, field: &Zp) -> ::Element { 833 | field.div( 834 | 
&self.numerator.to_finite_field(field), 835 | &self.denominator.to_finite_field(field), 836 | ) 837 | } 838 | } 839 | 840 | impl ToFiniteField for Rational { 841 | fn to_finite_field(&self, field: &Zp64) -> ::Element { 842 | field.div( 843 | &self.numerator.to_finite_field(field), 844 | &self.denominator.to_finite_field(field), 845 | ) 846 | } 847 | } 848 | 849 | impl ToFiniteField for Rational { 850 | fn to_finite_field(&self, field: &Z2) -> ::Element { 851 | field.div( 852 | &self.numerator.to_finite_field(field), 853 | &self.denominator.to_finite_field(field), 854 | ) 855 | } 856 | } 857 | 858 | impl Rational { 859 | pub fn new>(num: T, den: T) -> Rational { 860 | let d = den.into(); 861 | if d.is_zero() { 862 | panic!("Cannot create a rational number with zero denominator"); 863 | } 864 | 865 | Q.to_element(num.into(), d, true) 866 | } 867 | 868 | pub fn from_int_unchecked>(num: T, den: T) -> Rational { 869 | Q.to_element(num.into(), den.into(), false) 870 | } 871 | 872 | pub fn is_negative(&self) -> bool { 873 | self.numerator < 0 874 | } 875 | 876 | pub fn is_integer(&self) -> bool { 877 | self.denominator.is_one() 878 | } 879 | 880 | pub fn zero() -> Rational { 881 | Rational { 882 | numerator: 0.into(), 883 | denominator: 1.into(), 884 | } 885 | } 886 | 887 | pub fn one() -> Rational { 888 | Rational { 889 | numerator: 1.into(), 890 | denominator: 1.into(), 891 | } 892 | } 893 | 894 | pub fn abs(&self) -> Rational { 895 | if self.is_negative() { 896 | self.clone().neg() 897 | } else { 898 | self.clone() 899 | } 900 | } 901 | 902 | pub fn is_zero(&self) -> bool { 903 | self.numerator.is_zero() 904 | } 905 | 906 | pub fn is_one(&self) -> bool { 907 | self.numerator.is_one() && self.denominator.is_one() 908 | } 909 | 910 | pub fn pow(&self, e: u64) -> Rational { 911 | Q.pow(self, e) 912 | } 913 | 914 | pub fn inv(&self) -> Rational { 915 | Q.inv(self) 916 | } 917 | 918 | pub fn neg(&self) -> Rational { 919 | Q.neg(self) 920 | } 921 | 922 | pub fn gcd(&self, 
other: &Rational) -> Rational { 923 | Q.gcd(self, other) 924 | } 925 | 926 | pub fn to_f64(&self) -> f64 { 927 | rug::Rational::from(( 928 | self.numerator.clone().to_multi_prec(), 929 | self.denominator.clone().to_multi_prec(), 930 | )) 931 | .to_f64() 932 | } 933 | 934 | pub fn to_multi_prec(self) -> rug::Rational { 935 | rug::Rational::from(( 936 | self.numerator.to_multi_prec(), 937 | self.denominator.to_multi_prec(), 938 | )) 939 | } 940 | 941 | /// Return a best approximation of the rational number where the denominator 942 | /// is less than or equal to `max_denominator`. 943 | pub fn truncate_denominator(&self, max_denominator: &Integer) -> Rational { 944 | assert!(!max_denominator.is_zero() && !max_denominator.is_negative()); 945 | 946 | if self.denominator_ref() < max_denominator { 947 | return self.clone(); 948 | } 949 | 950 | let (mut p0, mut q0, mut p1, mut q1) = ( 951 | Integer::zero(), 952 | Integer::one(), 953 | Integer::one(), 954 | Integer::zero(), 955 | ); 956 | 957 | let (mut n, mut d) = (self.numerator_ref().abs(), self.denominator()); 958 | loop { 959 | let a = &n / &d; 960 | let q2 = &q0 + &(&a * &q1); 961 | if &q2 > max_denominator { 962 | break; 963 | } 964 | (p1, p0, q0, q1) = (p0 + &(&a * &p1), p1, q1, q2); 965 | (d, n) = (&n - &a * &d, d); 966 | } 967 | 968 | let k = &(max_denominator - &q0) / &q1; 969 | let bound1: Rational = (p0 + &(&k * &p1), &q0 + &(&k * &q1)).into(); 970 | let bound2: Rational = (p1, q1).into(); 971 | 972 | let res = if (&bound2 - self).abs() <= (&bound1 - self).abs() { 973 | bound2 974 | } else { 975 | bound1 976 | }; 977 | 978 | if self.is_negative() { res.neg() } else { res } 979 | } 980 | 981 | /// Round the rational to the one with the smallest numerator or denominator in the interval 982 | /// `[self * (1-relative_error), self * (1+relative_error)]`, where 983 | /// `0 < relative_error < 1`. 
984 | pub fn round(&self, relative_error: &Rational) -> Rational { 985 | if self.is_zero() { 986 | return Rational::zero(); 987 | } 988 | 989 | if self.is_negative() { 990 | self.round_in_interval( 991 | self.clone() * (Rational::one() + relative_error), 992 | self.clone() * (Rational::one() - relative_error), 993 | ) 994 | } else { 995 | self.round_in_interval( 996 | self.clone() * (Rational::one() - relative_error), 997 | self.clone() * (Rational::one() + relative_error), 998 | ) 999 | } 1000 | } 1001 | 1002 | /// Round the rational to the one with the smallest numerator or denominator in the interval 1003 | /// `[l, u]`, where `l < u`. 1004 | pub fn round_in_interval(&self, mut l: Rational, mut u: Rational) -> Rational { 1005 | assert!(l < u); 1006 | 1007 | let mut flip = false; 1008 | if l.is_negative() && u.is_negative() { 1009 | flip = true; 1010 | (l, u) = (-u, -l); 1011 | } else if l.is_negative() { 1012 | return Rational::zero(); 1013 | } 1014 | 1015 | // use continued fractions to find the best approximation in an interval 1016 | let (mut ln, mut ld) = (l.numerator(), l.denominator()); 1017 | let (mut un, mut ud) = (u.numerator(), u.denominator()); 1018 | 1019 | // h1/k1 accumulates the shared continued fraction terms of l and u 1020 | let (mut h1, mut h0, mut k1, mut k0): (Integer, Integer, Integer, Integer) = 1021 | (1.into(), 0.into(), 0.into(), 1.into()); 1022 | 1023 | loop { 1024 | let a = &(&ln - &1.into()) / &ld; // get next term in continued fraction 1025 | (ld, ud, ln, un) = (&un - &a * &ud, &ln - &a * &ld, ud, ld); // subtract and invert 1026 | (h1, h0) = (&a * &h1 + &h0, h1); 1027 | (k1, k0) = (&a * &k1 + &k0, k1); 1028 | if ln <= ld { 1029 | let res: Rational = (h1 + &h0, k1 + &k0).into(); 1030 | 1031 | if flip { 1032 | return -res; 1033 | } else { 1034 | return res; 1035 | } 1036 | } 1037 | } 1038 | } 1039 | 1040 | /// Round to the nearest integer towards zero. 
1041 | pub fn floor(&self) -> Integer { 1042 | self.numerator_ref() / self.denominator_ref() 1043 | } 1044 | 1045 | /// Round to the nearest integer away from zero. 1046 | pub fn ceil(&self) -> Integer { 1047 | if self.is_negative() { 1048 | (self.numerator().clone() + 1) / self.denominator_ref() - 1 1049 | } else { 1050 | ((self.numerator().clone() - 1) / self.denominator_ref()) + 1 1051 | } 1052 | } 1053 | 1054 | pub fn round_to_nearest_integer(&self) -> Integer { 1055 | if self.is_negative() { 1056 | (self - &(1, 2).into()).floor() 1057 | } else { 1058 | (self + &(1, 2).into()).floor() 1059 | } 1060 | } 1061 | 1062 | /// Reconstruct a rational number `q` from a value `v` in a prime field `p`, 1063 | /// such that `q ≡ v mod p`. 1064 | /// 1065 | /// From "Maximal Quotient Rational Reconstruction: An Almost 1066 | /// Optimal Algorithm for Rational Reconstruction" by Monagan. 1067 | pub fn maximal_quotient_reconstruction( 1068 | v: &Integer, 1069 | p: &Integer, 1070 | acceptance_scale: Option, 1071 | ) -> Result { 1072 | let mut acceptance_scale = match acceptance_scale { 1073 | Some(t) => t.clone(), 1074 | None => { 1075 | // set t to 2^20*ceil(log2(m)) 1076 | let ceil_log2 = match &p { 1077 | Integer::Single(n) => u64::BITS as u64 - (*n as u64).leading_zeros() as u64, 1078 | Integer::Double(n) => u128::BITS as u64 - (*n as u128).leading_zeros() as u64, 1079 | Integer::Large(n) => { 1080 | let mut pos = 0; 1081 | while let Some(p) = n.find_one(pos) { 1082 | if let Some(p2) = pos.checked_add(p) { 1083 | if p2 == u32::MAX { 1084 | return Err("Could not reconstruct, as the log is too large"); 1085 | } 1086 | 1087 | pos += 1; 1088 | } else { 1089 | return Err("Could not reconstruct, as the log is too large"); 1090 | } 1091 | } 1092 | pos as u64 1093 | } 1094 | }; 1095 | 1096 | &Integer::new(2i64 << 10) * &Integer::new(ceil_log2 as i64) 1097 | } 1098 | }; 1099 | 1100 | if v.is_zero() { 1101 | return if p > &acceptance_scale { 1102 | Ok(Rational::zero()) 1103 | } else 
{ 1104 | Err("Could not reconstruct: u=0 and t <= m") 1105 | }; 1106 | } 1107 | 1108 | let mut n = Integer::zero(); 1109 | let mut d = Integer::zero(); 1110 | let (mut t, mut old_t) = (Integer::one(), Integer::zero()); 1111 | let (mut r, mut old_r) = (if v.is_negative() { v + p } else { v.clone() }, p.clone()); 1112 | 1113 | while !r.is_zero() && old_r > acceptance_scale { 1114 | let q = &old_r / &r; 1115 | if q > acceptance_scale { 1116 | n = r.clone(); 1117 | d = t.clone(); 1118 | acceptance_scale = q.clone(); 1119 | } 1120 | (r, old_r) = (&old_r - &(&q * &r), r); 1121 | (t, old_t) = (&old_t - &(&q * &t), t); 1122 | } 1123 | 1124 | if d.is_zero() || !Z.gcd(&n, &d).is_one() { 1125 | return Err("Reconstruction failed"); 1126 | } 1127 | if d < Integer::zero() { 1128 | n = n.neg(); 1129 | d = d.neg(); 1130 | } 1131 | 1132 | Ok((n, d).into()) 1133 | } 1134 | 1135 | /// Return the rational number that corresponds to `f` evaluated at sample point `sample`, 1136 | /// i.e. `f(sample)`, if such a number exists and if the evaluations were not unlucky. 1137 | /// 1138 | /// The procedure can be repeated with a different starting prime, by setting `prime_start` 1139 | /// to a non-zero value. 
1140 | pub fn rational_reconstruction< 1141 | F: Fn(&Zp, &[::Element]) -> ::Element, 1142 | R: Ring, 1143 | >( 1144 | f: F, 1145 | sample: &[R::Element], 1146 | prime_start: Option, 1147 | ) -> Result 1148 | where 1149 | Zp: FiniteFieldCore, 1150 | R::Element: ToFiniteField, 1151 | { 1152 | let mut cur_result = Integer::one(); 1153 | let mut prime_accum = Integer::one(); 1154 | let mut prime_sample_point = vec![]; 1155 | let mut primes = 1156 | PrimeIteratorU64::new(u32::get_large_prime() as u64 + prime_start.unwrap_or(0) as u64); 1157 | 1158 | let mut last_guess = Rational::zero(); 1159 | for i in 0..sample.len() { 1160 | let Some(p) = primes.next() else { 1161 | return Err("Ran out of primes for rational reconstruction"); 1162 | }; 1163 | let Some(p) = u32::try_from_integer(p.into()) else { 1164 | return Err("Ran out of primes for rational reconstruction"); 1165 | }; 1166 | 1167 | let field = FiniteField::::new(p); 1168 | prime_sample_point.clear(); 1169 | for x in sample { 1170 | prime_sample_point.push(x.to_finite_field(&field)); 1171 | } 1172 | 1173 | let eval = f(&field, &prime_sample_point); 1174 | 1175 | let eval_conv = field.from_element(&eval).to_integer(); 1176 | 1177 | if i == 0 { 1178 | cur_result = eval_conv; 1179 | } else { 1180 | let new_result = Integer::chinese_remainder( 1181 | eval_conv, 1182 | cur_result.clone(), 1183 | Integer::Single(p as i64), 1184 | prime_accum.clone(), 1185 | ); 1186 | 1187 | if cur_result == new_result { 1188 | return Ok(last_guess); 1189 | } 1190 | cur_result = new_result; 1191 | } 1192 | 1193 | prime_accum *= &Integer::Single(p as i64); 1194 | 1195 | if cur_result < Integer::zero() { 1196 | cur_result += &prime_accum; 1197 | } 1198 | 1199 | if let Ok(q) = 1200 | Rational::maximal_quotient_reconstruction(&cur_result, &prime_accum, None) 1201 | { 1202 | if q == last_guess { 1203 | return Ok(q); 1204 | } else { 1205 | last_guess = q; 1206 | } 1207 | } 1208 | } 1209 | 1210 | Err("Reconstruction failed") 1211 | } 1212 | } 
1213 | 1214 | impl PartialOrd for Rational { 1215 | fn partial_cmp(&self, other: &Self) -> Option { 1216 | Some(self.cmp(other)) 1217 | } 1218 | } 1219 | 1220 | impl Ord for Rational { 1221 | fn cmp(&self, other: &Self) -> std::cmp::Ordering { 1222 | if self.denominator == other.denominator { 1223 | return self.numerator.cmp(&other.numerator); 1224 | } 1225 | 1226 | let a = self.numerator_ref() * other.denominator_ref(); 1227 | let b = self.denominator_ref() * other.numerator_ref(); 1228 | 1229 | a.cmp(&b) 1230 | } 1231 | } 1232 | 1233 | impl Add for Rational { 1234 | type Output = Rational; 1235 | 1236 | fn add(self, other: Rational) -> Self::Output { 1237 | Q.add(&self, &other) 1238 | } 1239 | } 1240 | 1241 | impl Sub for Rational { 1242 | type Output = Rational; 1243 | 1244 | fn sub(self, other: Rational) -> Self::Output { 1245 | self.add(&other.neg()) 1246 | } 1247 | } 1248 | 1249 | impl Mul for Rational { 1250 | type Output = Rational; 1251 | 1252 | fn mul(self, other: Rational) -> Self::Output { 1253 | Q.mul(&self, &other) 1254 | } 1255 | } 1256 | 1257 | impl Div for Rational { 1258 | type Output = Rational; 1259 | 1260 | fn div(self, other: Rational) -> Self::Output { 1261 | Q.div(&self, &other) 1262 | } 1263 | } 1264 | 1265 | impl<'a> Add<&'a Rational> for Rational { 1266 | type Output = Rational; 1267 | 1268 | fn add(self, other: &'a Rational) -> Self::Output { 1269 | Q.add(&self, other) 1270 | } 1271 | } 1272 | 1273 | impl<'a> Sub<&'a Rational> for Rational { 1274 | type Output = Rational; 1275 | 1276 | fn sub(self, other: &'a Rational) -> Self::Output { 1277 | self.add(&other.neg()) 1278 | } 1279 | } 1280 | 1281 | impl<'a> Mul<&'a Rational> for Rational { 1282 | type Output = Rational; 1283 | 1284 | fn mul(self, other: &'a Rational) -> Self::Output { 1285 | Q.mul(&self, other) 1286 | } 1287 | } 1288 | 1289 | impl<'a> Div<&'a Rational> for Rational { 1290 | type Output = Rational; 1291 | 1292 | fn div(self, other: &'a Rational) -> Self::Output { 1293 | 
Q.div(&self, other) 1294 | } 1295 | } 1296 | 1297 | impl<'a> Add<&'a Rational> for &Rational { 1298 | type Output = Rational; 1299 | 1300 | fn add(self, other: &'a Rational) -> Self::Output { 1301 | Q.add(self, other) 1302 | } 1303 | } 1304 | 1305 | impl<'a> Sub<&'a Rational> for &Rational { 1306 | type Output = Rational; 1307 | 1308 | fn sub(self, other: &'a Rational) -> Self::Output { 1309 | Q.sub(self, other) 1310 | } 1311 | } 1312 | 1313 | impl Neg for Rational { 1314 | type Output = Self; 1315 | fn neg(self) -> Self::Output { 1316 | Q.neg(&self) 1317 | } 1318 | } 1319 | 1320 | impl<'a> Mul<&'a Rational> for &Rational { 1321 | type Output = Rational; 1322 | 1323 | fn mul(self, other: &'a Rational) -> Self::Output { 1324 | Q.mul(self, other) 1325 | } 1326 | } 1327 | 1328 | impl<'a> Div<&'a Rational> for &Rational { 1329 | type Output = Rational; 1330 | 1331 | fn div(self, other: &'a Rational) -> Self::Output { 1332 | Q.div(self, other) 1333 | } 1334 | } 1335 | 1336 | impl<'a> AddAssign<&'a Rational> for Rational { 1337 | fn add_assign(&mut self, other: &'a Rational) { 1338 | Q.add_assign(self, other) 1339 | } 1340 | } 1341 | 1342 | impl<'a> SubAssign<&'a Rational> for Rational { 1343 | fn sub_assign(&mut self, other: &'a Rational) { 1344 | self.add_assign(&other.neg()) 1345 | } 1346 | } 1347 | 1348 | impl<'a> MulAssign<&'a Rational> for Rational { 1349 | fn mul_assign(&mut self, other: &'a Rational) { 1350 | Q.mul_assign(self, other) 1351 | } 1352 | } 1353 | 1354 | impl<'a> DivAssign<&'a Rational> for Rational { 1355 | fn div_assign(&mut self, other: &'a Rational) { 1356 | Q.div_assign(self, other) 1357 | } 1358 | } 1359 | 1360 | impl AddAssign for Rational { 1361 | fn add_assign(&mut self, other: Rational) { 1362 | Q.add_assign(self, &other) 1363 | } 1364 | } 1365 | 1366 | impl SubAssign for Rational { 1367 | fn sub_assign(&mut self, other: Rational) { 1368 | self.add_assign(&other.neg()) 1369 | } 1370 | } 1371 | 1372 | impl MulAssign for Rational { 1373 | fn 
mul_assign(&mut self, other: Rational) { 1374 | Q.mul_assign(self, &other) 1375 | } 1376 | } 1377 | 1378 | impl DivAssign for Rational { 1379 | fn div_assign(&mut self, other: Rational) { 1380 | Q.div_assign(self, &other) 1381 | } 1382 | } 1383 | 1384 | impl<'a> std::iter::Sum<&'a Self> for Rational { 1385 | fn sum>(iter: I) -> Self { 1386 | iter.fold(Rational::zero(), |a, b| a + b) 1387 | } 1388 | } 1389 | 1390 | #[cfg(test)] 1391 | mod test { 1392 | use crate::domains::{ 1393 | Field, Ring, RingOps, 1394 | integer::Z, 1395 | rational::{FractionField, Rational}, 1396 | }; 1397 | 1398 | #[test] 1399 | fn rounding() { 1400 | let r: Rational = (11, 10).into(); 1401 | let res = r.round_in_interval((1, 1).into(), (12, 10).into()); 1402 | assert_eq!(res, (1, 1)); 1403 | 1404 | let r: Rational = (11, 10).into(); 1405 | let res = r.round_in_interval((2, 1).into(), (3, 1).into()); 1406 | assert_eq!(res, (2, 1)); 1407 | 1408 | let r: Rational = (503, 1500).into(); 1409 | let res = r.round(&(1, 10).into()); 1410 | assert_eq!(res, (1, 3)); 1411 | 1412 | let r: Rational = (-503, 1500).into(); 1413 | let res = r.round(&(1, 10).into()); 1414 | assert_eq!(res, (-1, 3)); 1415 | 1416 | let r = crate::domains::float::Float::from(rug::Float::with_val( 1417 | 1000, 1418 | rug::float::Constant::Pi, 1419 | )) 1420 | .to_rational(); 1421 | let res = r.round(&(1, 100000000).into()); 1422 | assert_eq!(res, (93343, 29712)); 1423 | } 1424 | 1425 | #[test] 1426 | fn fraction_int() { 1427 | let f = FractionField::new(Z); 1428 | let b = f.neg(f.nth(3.into())); 1429 | let d = f.div(&f.add(&f.nth(100.into()), &b), &b); 1430 | assert_eq!(d, f.to_element((-97).into(), 3.into(), false)); 1431 | } 1432 | } 1433 | -------------------------------------------------------------------------------- /src/numerical_integration.rs: -------------------------------------------------------------------------------- 1 | //! Methods for numerical integration of black-box functions. 2 | //! 3 | //! 
A standard approach is to construct a [Grid] that will approximate the function 4 | //! and then sample the grid to evaluate the function at different points. The grid 5 | //! will adapt to the function based on the samples added. 6 | //! 7 | //! To use multichanneling methods, a [DiscreteGrid] can be used, which contains multiple 8 | //! [Grid]s that approximate different channels. 9 | //! 10 | //! # Examples 11 | //! 12 | //! ``` 13 | //! use numerica::numerical_integration::{ContinuousGrid, DiscreteGrid, Grid, MonteCarloRng, Sample}; 14 | //! 15 | //! let f = |x: &[f64]| (x[0] * std::f64::consts::PI).sin() + x[1]; 16 | //! 17 | //! let mut grid = Grid::Continuous(ContinuousGrid::new(2, 128, 100, None, false)); 18 | //! 19 | //! let mut rng = MonteCarloRng::new(0, 0); 20 | //! 21 | //! let mut sample = Sample::new(); 22 | //! for iteration in 1..20 { 23 | //! // sample 10_000 times per iteration 24 | //! for _ in 0..10_000 { 25 | //! grid.sample(&mut rng, &mut sample); 26 | //! 27 | //! if let Sample::Continuous(_cont_weight, xs) = &sample { 28 | //! grid.add_training_sample(&sample, f(xs)).unwrap(); 29 | //! } 30 | //! } 31 | //! 32 | //! grid.update(1.5, 1.5); 33 | //! 34 | //! println!( 35 | //! "Integral at iteration {}: {}", 36 | //! iteration, 37 | //! grid.get_statistics().format_uncertainty() 38 | //! ); 39 | //! } 40 | //! ``` 41 | 42 | use rand::{Rng, RngCore, SeedableRng}; 43 | use rand_xoshiro::Xoshiro256StarStar; 44 | 45 | use crate::domains::float::{Constructible, Real, RealLike}; 46 | 47 | /// Keep track of statistical quantities, such as the average, 48 | /// the error and the chi-squared of samples added over multiple 49 | /// iterations. 50 | /// 51 | /// Samples can be added using [`Self::add_sample()`]. When an iteration of 52 | /// samples is finished, call [`Self::update_iter()`], which 53 | /// updates the average, error and chi-squared over all iterations with the average 54 | /// and error of the current iteration in a weighted fashion. 
55 | /// 56 | /// This accumulator can be merged with another accumulator using [`Self::merge_samples()`] or 57 | /// [`Self::merge_samples_no_reset()`]. This is useful when 58 | /// samples are collected in multiple threads. 59 | /// 60 | /// The accumulator also stores which samples yielded the highest weight thus far. 61 | /// This can be used to study the input that impacted the average and error the most. 62 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 63 | #[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))] 64 | #[derive(Debug, Default, Clone)] 65 | pub struct StatisticsAccumulator { 66 | sum: T, 67 | sum_sq: T, 68 | total_sum: T, 69 | total_sum_sq: T, 70 | weight_sum: T, 71 | avg_sum: T, 72 | pub avg: T, 73 | pub err: T, 74 | guess: T, 75 | pub chi_sq: T, 76 | chi_sum: T, 77 | chi_sq_sum: T, 78 | new_samples: usize, 79 | new_zero_evaluations: usize, 80 | pub cur_iter: usize, 81 | pub processed_samples: usize, 82 | pub max_eval_positive: T, 83 | pub max_eval_positive_xs: Option>, 84 | pub max_eval_negative: T, 85 | pub max_eval_negative_xs: Option>, 86 | pub num_zero_evaluations: usize, 87 | } 88 | 89 | impl StatisticsAccumulator { 90 | /// Create a new [StatisticsAccumulator]. 
91 | pub fn new() -> StatisticsAccumulator { 92 | StatisticsAccumulator { 93 | sum: T::new_zero(), 94 | sum_sq: T::new_zero(), 95 | total_sum: T::new_zero(), 96 | total_sum_sq: T::new_zero(), 97 | weight_sum: T::new_zero(), 98 | avg_sum: T::new_zero(), 99 | avg: T::new_zero(), 100 | err: T::new_zero(), 101 | guess: T::new_zero(), 102 | chi_sq: T::new_zero(), 103 | chi_sum: T::new_zero(), 104 | chi_sq_sum: T::new_zero(), 105 | new_samples: 0, 106 | new_zero_evaluations: 0, 107 | cur_iter: 0, 108 | processed_samples: 0, 109 | max_eval_positive: T::new_zero(), 110 | max_eval_positive_xs: None, 111 | max_eval_negative: T::new_zero(), 112 | max_eval_negative_xs: None, 113 | num_zero_evaluations: 0, 114 | } 115 | } 116 | 117 | /// Copy the statistics accumulator, skipping the samples 118 | /// that evaluated to the maximum point. 119 | /// 120 | /// This function does not allocate. 121 | pub fn shallow_copy(&self) -> StatisticsAccumulator { 122 | StatisticsAccumulator { 123 | sum: self.sum, 124 | sum_sq: self.sum_sq, 125 | total_sum: self.total_sum, 126 | total_sum_sq: self.total_sum_sq, 127 | weight_sum: self.weight_sum, 128 | avg_sum: self.avg_sum, 129 | avg: self.avg, 130 | err: self.err, 131 | guess: self.guess, 132 | chi_sq: self.chi_sq, 133 | chi_sum: self.chi_sum, 134 | chi_sq_sum: self.chi_sq_sum, 135 | new_samples: self.new_samples, 136 | new_zero_evaluations: self.new_zero_evaluations, 137 | cur_iter: self.cur_iter, 138 | processed_samples: self.processed_samples, 139 | max_eval_positive: self.max_eval_positive, 140 | max_eval_positive_xs: None, 141 | max_eval_negative: self.max_eval_negative, 142 | max_eval_negative_xs: None, 143 | num_zero_evaluations: self.num_zero_evaluations, 144 | } 145 | } 146 | 147 | /// Add a new `sample` to the accumulator with its corresponding evaluation `eval`. 148 | /// Note that the average and error are only updated upon calling [`Self::update_iter()`]. 
149 | pub fn add_sample(&mut self, eval: T, sample: Option<&Sample>) { 150 | self.sum += &eval; 151 | self.sum_sq += eval * eval; 152 | self.new_samples += 1; 153 | 154 | if eval == T::new_zero() { 155 | self.new_zero_evaluations += 1; 156 | } 157 | 158 | if self.max_eval_positive_xs.is_none() || eval > self.max_eval_positive { 159 | self.max_eval_positive = eval; 160 | self.max_eval_positive_xs = sample.cloned(); 161 | } 162 | 163 | if self.max_eval_negative_xs.is_none() || eval < self.max_eval_negative { 164 | self.max_eval_negative = eval; 165 | self.max_eval_negative_xs = sample.cloned(); 166 | } 167 | } 168 | 169 | /// Add the non-processed samples of `other` to non-processed samples of this 170 | /// accumulator. The non-processed samples are removed from `other`. 171 | pub fn merge_samples(&mut self, other: &mut StatisticsAccumulator) { 172 | self.merge_samples_no_reset(other); 173 | other.clear_samples(); 174 | } 175 | 176 | /// Clear all non-processed samples from this accumulator. 177 | pub fn clear_samples(&mut self) { 178 | self.sum = T::new_zero(); 179 | self.sum_sq = T::new_zero(); 180 | self.new_samples = 0; 181 | self.new_zero_evaluations = 0; 182 | } 183 | 184 | /// Add the non-processed samples of `other` to non-processed samples of this 185 | /// accumulator without removing the samples from `other`. 
186 | pub fn merge_samples_no_reset(&mut self, other: &StatisticsAccumulator) { 187 | self.sum += &other.sum; 188 | self.sum_sq += &other.sum_sq; 189 | self.new_samples += other.new_samples; 190 | self.new_zero_evaluations += other.new_zero_evaluations; 191 | 192 | if other.max_eval_positive > self.max_eval_positive { 193 | self.max_eval_positive = other.max_eval_positive; 194 | self.max_eval_positive_xs 195 | .clone_from(&other.max_eval_positive_xs); 196 | } 197 | 198 | if other.max_eval_negative < self.max_eval_negative { 199 | self.max_eval_negative = other.max_eval_negative; 200 | self.max_eval_negative_xs 201 | .clone_from(&other.max_eval_negative_xs); 202 | } 203 | } 204 | 205 | /// Process the samples added with [`Self::add_sample()`] and 206 | /// compute a new average, error, and chi-squared. 207 | /// 208 | /// When `weighted_average=True`, a weighted average and error is computed using 209 | /// the iteration variances as a weight. 210 | pub fn update_iter(&mut self, weighted_average: bool) -> bool { 211 | // TODO: we could be throwing away events that are very rare 212 | if self.new_samples < 2 { 213 | self.cur_iter += 1; 214 | return false; 215 | } 216 | 217 | self.processed_samples += self.new_samples; 218 | self.num_zero_evaluations += self.new_zero_evaluations; 219 | let n = T::new_from_usize(self.new_samples); 220 | self.total_sum += self.sum; 221 | self.total_sum_sq += self.sum_sq; 222 | self.sum /= n; 223 | self.sum_sq /= n; 224 | let mut w = self.sum_sq.sqrt(); 225 | 226 | w = ((w + self.sum) * (w - self.sum)) / (n - T::new_one()); // compute variance 227 | if w == T::new_zero() { 228 | // all sampled points are the same 229 | // set the weight to a large number 230 | w = T::new_from_usize(usize::MAX); 231 | } else { 232 | w = w.inv(); 233 | } 234 | 235 | self.weight_sum += w; 236 | self.avg_sum += w * self.sum; 237 | 238 | let sigma_sq = self.weight_sum.inv(); 239 | let weighted_avg = sigma_sq * self.avg_sum; 240 | 241 | if weighted_average { 
242 | self.avg = weighted_avg; 243 | self.err = sigma_sq.sqrt(); 244 | } else { 245 | let n_tot = T::new_from_usize(self.processed_samples); 246 | self.avg = self.total_sum / n_tot; 247 | let mut var = (self.total_sum_sq / n_tot).sqrt(); 248 | var = ((var + self.avg) * (var - self.avg)) / (n_tot - T::new_one()); 249 | self.err = var.sqrt(); 250 | } 251 | 252 | if self.cur_iter == 0 { 253 | self.guess = self.sum; 254 | } 255 | // compute chi-squared wrt the first iteration average 256 | // TODO: use rolling average instead? 257 | w *= self.sum - self.guess; 258 | self.chi_sum += w; 259 | self.chi_sq_sum += w * self.sum; 260 | self.chi_sq = self.chi_sq_sum - weighted_avg * self.chi_sum; 261 | 262 | // reset 263 | self.sum = T::new_zero(); 264 | self.sum_sq = T::new_zero(); 265 | self.new_samples = 0; 266 | self.new_zero_evaluations = 0; 267 | self.cur_iter += 1; 268 | 269 | true 270 | } 271 | 272 | /// Get an estimate for the average, error and chi-squared, as if the current iteration 273 | /// has ended without adding more samples. 274 | pub fn get_live_estimate(&self) -> (T, T, T) { 275 | let mut a = self.shallow_copy(); 276 | a.update_iter(false); 277 | (a.avg, a.err, a.chi_sq) 278 | } 279 | 280 | /// Format the live `mean ± sdev` as `mean(sdev)` in a human-readable way with the correct number of digits. 281 | /// 282 | /// Based on the Python package [gvar](https://github.com/gplepage/gvar) by Peter Lepage. 283 | pub fn format_live_uncertainty(&self) -> String { 284 | let mut a = self.shallow_copy(); 285 | a.update_iter(false); 286 | Self::format_uncertainty_impl(a.avg.to_f64(), a.err.to_f64()) 287 | } 288 | 289 | /// Format `mean ± sdev` as `mean(sdev)` in a human-readable way with the correct number of digits. 290 | /// 291 | /// Based on the Python package [gvar](https://github.com/gplepage/gvar) by Peter Lepage. 
292 | pub fn format_uncertainty(&self) -> String { 293 | Self::format_uncertainty_impl(self.avg.to_f64(), self.err.to_f64()) 294 | } 295 | 296 | fn format_uncertainty_impl(mean: f64, sdev: f64) -> String { 297 | fn ndec(x: f64, offset: usize) -> i32 { 298 | let mut ans = (offset as f64 - x.log10()) as i32; 299 | if ans > 0 && x * 10.0f64.powi(ans) >= [0.5, 9.5, 99.5][offset] { 300 | ans -= 1; 301 | } 302 | if ans < 0 { 0 } else { ans } 303 | } 304 | let v = mean; 305 | let dv = sdev; 306 | 307 | // special cases 308 | if v.is_nan() || dv.is_nan() { 309 | format!("{v:e} ± {dv:e}") 310 | } else if dv.is_infinite() { 311 | format!("{v:e} ± inf") 312 | } else if v == 0. && !(1e-4..1e5).contains(&dv) { 313 | if dv == 0. { 314 | "0(0)".to_owned() 315 | } else { 316 | let e = format!("{dv:.1e}"); 317 | let mut ans = e.split('e'); 318 | let e1 = ans.next().unwrap(); 319 | let e2 = ans.next().unwrap(); 320 | "0.0(".to_owned() + e1 + ")e" + e2 321 | } 322 | } else if v == 0. { 323 | if dv >= 9.95 { 324 | format!("0({dv:.0})") 325 | } else if dv >= 0.995 { 326 | format!("0.0({dv:.1})") 327 | } else { 328 | let ndecimal = ndec(dv, 2); 329 | format!( 330 | "{:.*}({:.0})", 331 | ndecimal as usize, 332 | v, 333 | dv * 10.0f64.powi(ndecimal) 334 | ) 335 | } 336 | } else if dv == 0. 
{ 337 | let e = format!("{v:e}"); 338 | let mut ans = e.split('e'); 339 | let e1 = ans.next().unwrap(); 340 | let e2 = ans.next().unwrap(); 341 | if e2 != "0" { 342 | e1.to_owned() + "(0)e" + e2 343 | } else { 344 | e1.to_owned() + "(0)" 345 | } 346 | } else if dv > 1e4 * v.abs() { 347 | format!("{v:.1e} ± {dv:.2e}") 348 | } else if v.abs() >= 1e6 || v.abs() < 1e-5 { 349 | // exponential notation for large |self.mean| 350 | let exponent = v.abs().log10().floor(); 351 | let fac = 10.0.powf(&exponent); 352 | let mantissa = Self::format_uncertainty_impl(v / fac, dv / fac); 353 | let e = format!("{fac:.0e}"); 354 | let mut ee = e.split('e'); 355 | mantissa + "e" + ee.nth(1).unwrap() 356 | } 357 | // normal cases 358 | else if dv >= 9.95 { 359 | if v.abs() >= 9.5 { 360 | format!("{v:.0}({dv:.0})") 361 | } else { 362 | let ndecimal = ndec(v.abs(), 1); 363 | format!("{:.*}({:.*})", ndecimal as usize, v, ndecimal as usize, dv) 364 | } 365 | } else if dv >= 0.995 { 366 | if v.abs() >= 0.95 { 367 | format!("{v:.1}({dv:.1})") 368 | } else { 369 | let ndecimal = ndec(v.abs(), 1); 370 | format!("{:.*}({:.*})", ndecimal as usize, v, ndecimal as usize, dv) 371 | } 372 | } else { 373 | let ndecimal = ndec(v.abs(), 1).max(ndec(dv, 2)); 374 | format!( 375 | "{:.*}({:.0})", 376 | ndecimal as usize, 377 | v, 378 | dv * 10.0f64.powi(ndecimal) 379 | ) 380 | } 381 | } 382 | } 383 | 384 | /// A sample taken from a [Grid] that approximates a function. The sample is more likely to fall in a region 385 | /// where the function changes rapidly. 386 | /// 387 | /// If the sample comes from a [ContinuousGrid], it is the variant [Continuous](Sample::Continuous) 388 | /// and contains the weight and the list of sample points. 389 | /// If the sample comes from a [DiscreteGrid], it is the variant [Discrete](Sample::Discrete) and contains 390 | /// the weight, the bin and the subsample if the bin has a nested grid. 
391 | /// If the sample comes from a [Uniform](Grid::Uniform) grid, it is the variant [Uniform](Sample::Uniform) 392 | /// and contains the weight, the list of discrete bin indices and the list of continuous sample points. 393 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 394 | #[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))] 395 | #[derive(Debug, Clone)] 396 | pub enum Sample { 397 | Continuous(T, Vec), 398 | Discrete(T, usize, Option>>), 399 | Uniform(T, Vec, Vec), 400 | } 401 | 402 | impl Default for Sample { 403 | fn default() -> Self { 404 | Self::new() 405 | } 406 | } 407 | 408 | impl Sample { 409 | /// Construct a new empty sample that can be handed over to [`Grid::sample()`]. 410 | pub fn new() -> Sample { 411 | Sample::Continuous(T::new_zero(), vec![]) 412 | } 413 | 414 | /// Get the weight of the sample. 415 | pub fn get_weight(&self) -> T { 416 | match self { 417 | Sample::Continuous(w, _) | Sample::Discrete(w, _, _) | Sample::Uniform(w, _, _) => *w, 418 | } 419 | } 420 | 421 | /// Transform the sample to a discrete grid, used for recycling memory. 422 | fn to_discrete_grid(&mut self) -> (&mut T, &mut usize, &mut Option>>) { 423 | if let Sample::Continuous(..) = self { 424 | *self = Sample::Discrete(T::new_zero(), 0, None); 425 | } else if let Sample::Uniform(..) = self { 426 | *self = Sample::Discrete(T::new_zero(), 0, None); 427 | } 428 | 429 | match self { 430 | Sample::Discrete(weight, index, sub_sample) => (weight, index, sub_sample), 431 | _ => unreachable!(), 432 | } 433 | } 434 | 435 | /// Transform the sample to a continuous, used for recycling memory. 436 | fn to_continuous_grid(&mut self) -> (&mut T, &mut Vec) { 437 | if let Sample::Continuous(..) 
= self { 438 | *self = Sample::Continuous(T::new_zero(), vec![]) 439 | } else if let Sample::Uniform(_, _, g) = self { 440 | *self = Sample::Continuous(T::new_zero(), std::mem::take(g)) 441 | } 442 | 443 | match self { 444 | Sample::Continuous(weight, sub_samples) => (weight, sub_samples), 445 | _ => unreachable!(), 446 | } 447 | } 448 | 449 | /// Transform the sample to a continuous one and extract bin indices, used for recycling memory. 450 | fn to_continuous_with_uniform(&mut self) -> Vec { 451 | match self { 452 | Sample::Uniform(_, bin_indices, g) => { 453 | let b = std::mem::take(bin_indices); 454 | *self = Sample::Continuous(T::new_zero(), std::mem::take(g)); 455 | b 456 | } 457 | Self::Continuous(_, _) => { 458 | vec![] 459 | } 460 | Self::Discrete(_, _, _) => { 461 | *self = Sample::Continuous(T::new_zero(), vec![]); 462 | vec![] 463 | } 464 | } 465 | } 466 | } 467 | 468 | /// An adapting grid that captures the enhancements of an integrand. 469 | /// It supports discrete and continuous dimensions. The discrete dimensions 470 | /// can have a nested grid. 471 | /// 472 | /// Use [Grid::Uniform] to create a layered rectangular discrete grid 473 | /// that is uniformly sampled and has a shared continuous grid across all discrete bins. 474 | /// 475 | /// Use [Grid::clone_without_samples] to create a copy of the grid that can 476 | /// accumulate samples independently, and can later be merged into the current grid. 
477 | /// 478 | /// # Examples 479 | /// 480 | /// ``` 481 | /// use numerica::numerical_integration::{ContinuousGrid, DiscreteGrid, Grid, MonteCarloRng, Sample}; 482 | /// 483 | /// let f = |x: &[f64]| (x[0] * std::f64::consts::PI).sin() + x[1]; 484 | /// 485 | /// let mut grid = Grid::Continuous(ContinuousGrid::new(2, 128, 100, None, false)); 486 | /// 487 | /// let mut rng = MonteCarloRng::new(0, 0); 488 | /// 489 | /// let mut sample = Sample::new(); 490 | /// for iteration in 1..20 { 491 | /// // sample 10_000 times per iteration 492 | /// for _ in 0..10_000 { 493 | /// grid.sample(&mut rng, &mut sample); 494 | /// 495 | /// if let Sample::Continuous(_cont_weight, xs) = &sample { 496 | /// grid.add_training_sample(&sample, f(xs)).unwrap(); 497 | /// } 498 | /// } 499 | /// 500 | /// grid.update(1.5, 1.5); 501 | /// 502 | /// println!( 503 | /// "Integral at iteration {}: {}", 504 | /// iteration, 505 | /// grid.get_statistics().format_uncertainty() 506 | /// ); 507 | /// } 508 | /// ``` 509 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 510 | #[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))] 511 | #[derive(Debug, Clone)] 512 | pub enum Grid { 513 | /// A continuous grid. 514 | Continuous(ContinuousGrid), 515 | /// A discrete grid with optional nested grids. 516 | Discrete(DiscreteGrid), 517 | /// A layered rectangular uniform discrete grid `(a, g)` where `a.len()` is the number of 518 | /// discrete dimensions, `a[i]` is the number of bins in discrete dimension `i`, 519 | /// and `g` is a shared continuous grid across all discrete bins. 520 | /// Each discrete bin has equal probability. 521 | Uniform(Vec, ContinuousGrid), 522 | } 523 | 524 | impl Grid { 525 | /// Sample a position in the grid. The sample is more likely to land in a region 526 | /// where the function the grid is based on is changing rapidly. 
527 | pub fn sample(&mut self, rng: &mut R, sample: &mut Sample) { 528 | match self { 529 | Grid::Continuous(g) => g.sample(rng, sample), 530 | Grid::Discrete(g) => g.sample(rng, sample), 531 | Grid::Uniform(disc, g) => { 532 | let mut bin_indices = sample.to_continuous_with_uniform(); 533 | g.sample(rng, sample); 534 | 535 | bin_indices.clear(); 536 | for n_bins in disc.iter() { 537 | bin_indices.push(rng.random_range(0..*n_bins)); 538 | } 539 | 540 | if let Sample::Continuous(w, c) = sample { 541 | for n_bins in disc.iter() { 542 | *w *= w.from_usize(*n_bins); 543 | } 544 | 545 | *sample = Sample::Uniform(*w, bin_indices, std::mem::take(c)); 546 | } else { 547 | unreachable!() 548 | } 549 | } 550 | } 551 | } 552 | 553 | /// Add a sample point and its corresponding evaluation `eval` to the grid as training. 554 | /// Upon a call to [`Grid::update`], the grid will be adapted to better represent 555 | /// the function that is being evaluated. 556 | pub fn add_training_sample(&mut self, sample: &Sample, eval: T) -> Result<(), String> { 557 | match self { 558 | Grid::Continuous(g) => g.add_training_sample(sample, eval), 559 | Grid::Discrete(g) => g.add_training_sample(sample, eval), 560 | Grid::Uniform(_, g) => g.add_training_sample(sample, eval), 561 | } 562 | } 563 | 564 | /// Returns `Ok` when this grid can be merged with another grid, 565 | /// and `Err` when the grids have a different shape. 566 | pub fn is_mergeable(&self, grid: &Grid) -> Result<(), String> { 567 | match (self, grid) { 568 | (Grid::Continuous(c1), Grid::Continuous(c2)) => c1.is_mergeable(c2), 569 | (Grid::Discrete(d1), Grid::Discrete(d2)) => d1.is_mergeable(d2), 570 | (Grid::Uniform(d1, c1), Grid::Uniform(d2, c2)) if d1 == d2 => c1.is_mergeable(c2), 571 | _ => Err("Cannot merge a discrete and continuous grid".to_owned()), 572 | } 573 | } 574 | 575 | /// Merge a grid with exactly the same structure. 
576 | pub fn merge(&mut self, grid: &Grid) -> Result<(), String> { 577 | // first do a complete check to see if the grids are mergeable 578 | self.is_mergeable(grid)?; 579 | self.merge_unchecked(grid); 580 | 581 | Ok(()) 582 | } 583 | 584 | /// Merge a grid without checks. For internal use only. 585 | fn merge_unchecked(&mut self, grid: &Grid) { 586 | match (self, grid) { 587 | (Grid::Continuous(c1), Grid::Continuous(c2)) => c1.merge_unchecked(c2), 588 | (Grid::Discrete(d1), Grid::Discrete(d2)) => d1.merge_unchecked(d2), 589 | (Grid::Uniform(_, c1), Grid::Uniform(_, c2)) => c1.merge_unchecked(c2), 590 | _ => panic!("Cannot merge grids that have a different shape."), 591 | } 592 | } 593 | 594 | /// Update the grid based on the samples added through [`Grid::add_training_sample`]. 595 | pub fn update(&mut self, discrete_learning_rate: T, continuous_learning_rate: T) { 596 | match self { 597 | Grid::Continuous(g) => g.update(continuous_learning_rate), 598 | Grid::Discrete(g) => g.update(discrete_learning_rate, continuous_learning_rate), 599 | Grid::Uniform(_, g) => g.update(continuous_learning_rate), 600 | } 601 | } 602 | 603 | /// Get the statistics of this grid. 604 | pub fn get_statistics(&mut self) -> &StatisticsAccumulator { 605 | match self { 606 | Grid::Continuous(g) => &g.accumulator, 607 | Grid::Discrete(g) => &g.accumulator, 608 | Grid::Uniform(_, g) => &g.accumulator, 609 | } 610 | } 611 | 612 | /// Clone the grid and remove any samples that have not been processed in an update. 613 | /// The new grid can accumulate samples independently, and can later be merged into 614 | /// the current grid. 
615 | pub fn clone_without_samples(&self) -> Grid { 616 | match self { 617 | Grid::Continuous(c) => Grid::Continuous(c.clone_without_samples()), 618 | Grid::Discrete(d) => Grid::Discrete(d.clone_without_samples()), 619 | Grid::Uniform(bins, c) => Grid::Uniform(bins.clone(), c.clone_without_samples()), 620 | } 621 | } 622 | } 623 | /// A bin of a discrete grid, which may contain a subgrid. 624 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 625 | #[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))] 626 | #[derive(Debug, Clone)] 627 | pub struct Bin { 628 | pub pdf: T, 629 | pub accumulator: StatisticsAccumulator, 630 | pub sub_grid: Option>, 631 | } 632 | 633 | impl Bin { 634 | /// Returns `Ok` when this grid can be merged with another grid, 635 | /// and `Err` when the grids have a different shape. 636 | pub fn is_mergeable(&self, other: &Bin) -> Result<(), String> { 637 | if self.pdf != other.pdf { 638 | return Err("PDF not equivalent".to_owned()); 639 | } 640 | 641 | match (&self.sub_grid, &other.sub_grid) { 642 | (None, None) => Ok(()), 643 | (Some(s1), Some(s2)) => s1.is_mergeable(s2), 644 | (None, Some(_)) | (Some(_), None) => Err("Sub-grid not equivalent".to_owned()), 645 | } 646 | } 647 | 648 | /// Merge a grid without checks. For internal use only. 649 | fn merge(&mut self, other: &Bin) { 650 | self.accumulator.merge_samples_no_reset(&other.accumulator); 651 | 652 | if let (Some(s1), Some(s2)) = (&mut self.sub_grid, &other.sub_grid) { 653 | s1.merge_unchecked(s2); 654 | } 655 | } 656 | } 657 | 658 | /// A discrete grid consisting of a given number of bins. 659 | /// Each bin may have a nested grid. 660 | /// 661 | /// After adding training samples and updating, the probabilities 662 | /// of a sample from the grid landing in a bin is proportional to its 663 | /// average value if training happens on the average, or to its 664 | /// variance (recommended). 
665 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 666 | #[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))] 667 | #[derive(Debug, Clone)] 668 | pub struct DiscreteGrid { 669 | pub bins: Vec>, 670 | pub accumulator: StatisticsAccumulator, 671 | max_prob_ratio: T, 672 | train_on_avg: bool, 673 | } 674 | 675 | impl DiscreteGrid { 676 | /// Create a new discrete grid with `bins.len()` number of bins, where 677 | /// each bin may have a sub-grid. 678 | /// 679 | /// Also set the maximal probability ratio between bins, `max_prob_ratio`, 680 | /// that prevents one bin from getting oversampled. 681 | /// 682 | /// If you want to train on the average instead of the error, set `train_on_avg` to `true` (not recommended). 683 | pub fn new( 684 | bins: Vec>>, 685 | max_prob_ratio: T, 686 | train_on_avg: bool, 687 | ) -> DiscreteGrid { 688 | let pdf = T::new_from_usize(1) / T::new_from_usize(bins.len()); 689 | DiscreteGrid { 690 | bins: bins 691 | .into_iter() 692 | .map(|s| Bin { 693 | pdf, 694 | accumulator: StatisticsAccumulator::new(), 695 | sub_grid: s, 696 | }) 697 | .collect(), 698 | accumulator: StatisticsAccumulator::new(), 699 | max_prob_ratio, 700 | train_on_avg, 701 | } 702 | } 703 | 704 | /// Sample a bin from all bins based on the bin pdfs. 705 | fn sample_bin(&self, rng: &mut R) -> (usize, T) { 706 | let r: T = T::new_sample_unit(rng); 707 | 708 | let mut cdf = T::new_zero(); 709 | for (i, bin) in self.bins.iter().enumerate() { 710 | cdf += bin.pdf; 711 | if r <= cdf { 712 | // the 'volume' of the bin is 1 / pdf 713 | return (i, bin.pdf.inv()); 714 | } 715 | } 716 | unreachable!( 717 | "Could not sample discrete dimension: {:?} at point {}", 718 | self, r 719 | ); 720 | } 721 | 722 | /// Update the discrete grid probabilities of landing in a particular bin when sampling, 723 | /// and adapt all sub-grids based on the new training samples. 
724 | /// 725 | /// If `learning_rate` is set to 0, no training happens. 726 | pub fn update(&mut self, discrete_learning_rate: T, continuous_learning_rate: T) { 727 | let mut err_sum = T::new_zero(); 728 | for bin in &mut self.bins { 729 | if let Some(sub_grid) = &mut bin.sub_grid { 730 | sub_grid.update(discrete_learning_rate, continuous_learning_rate); 731 | } 732 | 733 | let acc = &mut bin.accumulator; 734 | acc.update_iter(false); 735 | 736 | if acc.processed_samples > 1 { 737 | err_sum += acc.err * T::new_from_usize(acc.processed_samples - 1).sqrt(); 738 | } 739 | } 740 | 741 | if discrete_learning_rate.is_zero() 742 | || self.bins.iter().all(|bin| { 743 | if self.train_on_avg { 744 | bin.accumulator.avg == T::new_zero() 745 | } else { 746 | bin.accumulator.err == T::new_zero() || bin.accumulator.processed_samples < 2 747 | } 748 | }) 749 | { 750 | return; 751 | } 752 | 753 | let mut max_per_bin = T::new_zero(); 754 | for bin in &mut self.bins { 755 | let acc = &mut bin.accumulator; 756 | 757 | if self.train_on_avg { 758 | bin.pdf = acc.avg.norm() 759 | } else if acc.processed_samples < 2 { 760 | bin.pdf = T::new_zero(); 761 | } else { 762 | let n_samples = T::new_from_usize(acc.processed_samples - 1); 763 | let var = acc.err * n_samples.sqrt(); 764 | bin.pdf = var; 765 | } 766 | 767 | if bin.pdf > max_per_bin { 768 | max_per_bin = bin.pdf; 769 | } 770 | } 771 | 772 | let mut sum = T::new_zero(); 773 | let r = max_per_bin / self.max_prob_ratio; 774 | for bin in &mut self.bins { 775 | if bin.pdf < r { 776 | bin.pdf = r 777 | } 778 | sum += bin.pdf; 779 | } 780 | 781 | for bin in &mut self.bins { 782 | bin.pdf /= sum; 783 | } 784 | 785 | self.accumulator.update_iter(false); 786 | } 787 | 788 | /// Sample a point form this grid, writing the result in `sample`. 
789 | pub fn sample(&mut self, rng: &mut R, sample: &mut Sample) { 790 | let (weight, vs, child) = sample.to_discrete_grid(); 791 | 792 | *weight = T::new_one(); 793 | let (v, w) = self.sample_bin(rng); 794 | *weight *= &w; 795 | *vs = v; 796 | 797 | // get the child grid for this sample 798 | if let Some(sub_grid) = &mut self.bins[v].sub_grid { 799 | let child_sample = if let Some(sub_sample) = child { 800 | sub_sample 801 | } else { 802 | *child = Some(Box::new(Sample::new())); 803 | child.as_mut().unwrap() 804 | }; 805 | 806 | sub_grid.sample(rng, child_sample); 807 | 808 | // multiply the weight of the subsample 809 | *weight *= &child_sample.get_weight(); 810 | } else { 811 | *child = None; 812 | }; 813 | } 814 | 815 | /// Add a training sample with its corresponding evaluation, i.e. `f(sample)`, to the grid. 816 | pub fn add_training_sample(&mut self, sample: &Sample, eval: T) -> Result<(), String> { 817 | if !eval.is_finite() { 818 | return Err(format!( 819 | "Added training sample that is not finite: sample={sample:?}, fx={eval}" 820 | )); 821 | } 822 | 823 | if let Sample::Discrete(weight, index, sub_sample) = sample { 824 | self.accumulator.add_sample(eval * weight, Some(sample)); 825 | 826 | // undo the weight of the bin, which is 1 / pdf 827 | let bin_weight = *weight * self.bins[*index].pdf; 828 | self.bins[*index] 829 | .accumulator 830 | .add_sample(bin_weight * eval, Some(sample)); 831 | 832 | if let Some(sg) = &mut self.bins[*index].sub_grid 833 | && let Some(sub_sample) = sub_sample 834 | { 835 | sg.add_training_sample(sub_sample, eval)?; 836 | } 837 | 838 | Ok(()) 839 | } else { 840 | Err(format!("Discrete sample expected: {sample:?}")) 841 | } 842 | } 843 | 844 | /// Returns `Ok` when this grid can be merged with another grid, 845 | /// and `Err` when the grids have a different shape. 
846 | pub fn is_mergeable(&self, other: &DiscreteGrid) -> Result<(), String> { 847 | if self.bins.len() != other.bins.len() { 848 | return Err("Discrete grid dimensions do not match".to_owned()); 849 | } 850 | 851 | for (c, o) in self.bins.iter().zip(&other.bins) { 852 | c.is_mergeable(o)?; 853 | } 854 | 855 | Ok(()) 856 | } 857 | 858 | /// Merge a grid with exactly the same structure. 859 | pub fn merge(&mut self, grid: &DiscreteGrid) -> Result<(), String> { 860 | // first do a complete check to see if the grids are mergeable 861 | self.is_mergeable(grid)?; 862 | self.merge_unchecked(grid); 863 | 864 | Ok(()) 865 | } 866 | 867 | /// Merge a grid without checks. For internal use only. 868 | fn merge_unchecked(&mut self, other: &DiscreteGrid) { 869 | for (c, o) in self.bins.iter_mut().zip(&other.bins) { 870 | c.merge(o); 871 | } 872 | 873 | self.accumulator.merge_samples_no_reset(&other.accumulator); 874 | } 875 | 876 | /// Clone the grid and remove any samples that have not been processed in an update. 877 | pub fn clone_without_samples(&self) -> DiscreteGrid { 878 | let mut d = self.clone(); 879 | 880 | for bin in &mut d.bins { 881 | bin.accumulator.clear_samples(); 882 | 883 | bin.sub_grid.as_ref().map(|g| g.clone_without_samples()); 884 | } 885 | 886 | d.accumulator.clear_samples(); 887 | d 888 | } 889 | } 890 | 891 | /// An adaptive continuous grid that uses factorized dimensions to approximate 892 | /// a function. The VEGAS algorithm is used to adapt the grid 893 | /// based on new sample points. 894 | /// 895 | /// After adding training samples and updating, the probabilities 896 | /// of a sample from the grid landing in a bin is proportional to its 897 | /// average value if training happens on the average, or to its 898 | /// variance (recommended). 
899 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 900 | #[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))] 901 | #[derive(Debug, Clone)] 902 | pub struct ContinuousGrid { 903 | pub continuous_dimensions: Vec>, 904 | pub accumulator: StatisticsAccumulator, 905 | } 906 | 907 | impl ContinuousGrid { 908 | /// Create a new grid with `n_dims` dimensions and `n_bins` bins 909 | /// per dimension. 910 | /// 911 | /// With `min_samples_for_update` grid updates can be prevented if 912 | /// there are too few samples in a certain bin. With `bin_number_evolution` 913 | /// the bin numbers can be changed based on the iteration index. If the 914 | /// `bin_number_evolution` array is smaller than the current iteration number, 915 | /// the last element is taken from the list. 916 | /// 917 | /// With `train_on_avg`, the grids will be adapted based on the average value 918 | /// of the bin, contrary to it's variance. 919 | pub fn new( 920 | n_dims: usize, 921 | n_bins: usize, 922 | min_samples_for_update: usize, 923 | bin_number_evolution: Option>, 924 | train_on_avg: bool, 925 | ) -> ContinuousGrid { 926 | ContinuousGrid { 927 | continuous_dimensions: vec![ 928 | ContinuousDimension::new( 929 | n_bins, 930 | min_samples_for_update, 931 | bin_number_evolution, 932 | train_on_avg 933 | ); 934 | n_dims 935 | ], 936 | accumulator: StatisticsAccumulator::new(), 937 | } 938 | } 939 | 940 | /// Sample a point in the grid, writing the result in `sample`. 941 | pub fn sample(&mut self, rng: &mut R, sample: &mut Sample) { 942 | let (weight, vs) = sample.to_continuous_grid(); 943 | *weight = T::new_one(); 944 | vs.clear(); 945 | vs.resize(self.continuous_dimensions.len(), T::new_zero()); 946 | for (vs, d) in vs.iter_mut().zip(&self.continuous_dimensions) { 947 | let (v, w) = d.sample(rng); 948 | *weight *= &w; 949 | *vs = v; 950 | } 951 | } 952 | 953 | /// Add a training sample with its corresponding evaluation, i.e. 
`f(sample)`, to the grid. 954 | pub fn add_training_sample(&mut self, sample: &Sample, eval: T) -> Result<(), String> { 955 | if !eval.is_finite() { 956 | return Err(format!( 957 | "Added training sample that is not finite: sample={sample:?}, fx={eval}" 958 | )); 959 | } 960 | 961 | match sample { 962 | Sample::Continuous(weight, xs) | Sample::Uniform(weight, _, xs) => { 963 | self.accumulator.add_sample(eval * weight, Some(sample)); 964 | 965 | for (d, x) in self.continuous_dimensions.iter_mut().zip(xs) { 966 | d.add_training_sample(*x, *weight, eval)?; 967 | } 968 | Ok(()) 969 | } 970 | _ => unreachable!( 971 | "Sample cannot be converted to continuous sample: {:?}", 972 | self 973 | ), 974 | } 975 | } 976 | 977 | /// Update the grid based on the added training samples. This will move the partition bounds of every dimension. 978 | /// 979 | /// The `learning_rate` determines the speed of the adaptation. If it is set to `0`, no training will be performed. 980 | pub fn update(&mut self, learning_rate: T) { 981 | for d in self.continuous_dimensions.iter_mut() { 982 | d.update(learning_rate); 983 | } 984 | 985 | self.accumulator.update_iter(false); 986 | } 987 | 988 | /// Returns `Ok` when this grid can be merged with another grid, 989 | /// and `Err` when the grids have a different shape. 990 | pub fn is_mergeable(&self, grid: &ContinuousGrid) -> Result<(), String> { 991 | if self.continuous_dimensions.len() != grid.continuous_dimensions.len() { 992 | return Err("Cannot merge grids that have a different shape.".to_owned()); 993 | } 994 | 995 | for (c, o) in self 996 | .continuous_dimensions 997 | .iter() 998 | .zip(&grid.continuous_dimensions) 999 | { 1000 | c.is_mergeable(o)?; 1001 | } 1002 | Ok(()) 1003 | } 1004 | 1005 | /// Merge a grid with exactly the same structure. 
1006 | pub fn merge(&mut self, grid: &ContinuousGrid) -> Result<(), String> { 1007 | // first do a complete check to see if the grids are mergeable 1008 | self.is_mergeable(grid)?; 1009 | self.merge_unchecked(grid); 1010 | 1011 | Ok(()) 1012 | } 1013 | 1014 | /// Merge a grid without checks. For internal use only. 1015 | fn merge_unchecked(&mut self, grid: &ContinuousGrid) { 1016 | self.accumulator.merge_samples_no_reset(&grid.accumulator); 1017 | 1018 | for (c, o) in self 1019 | .continuous_dimensions 1020 | .iter_mut() 1021 | .zip(&grid.continuous_dimensions) 1022 | { 1023 | c.merge(o); 1024 | } 1025 | } 1026 | 1027 | /// Clone the grid and remove any samples that have not been processed in an update. 1028 | pub fn clone_without_samples(&self) -> ContinuousGrid { 1029 | let mut c = self.clone(); 1030 | 1031 | for dim in &mut c.continuous_dimensions { 1032 | for bin in &mut dim.bin_accumulator { 1033 | bin.clear_samples(); 1034 | } 1035 | } 1036 | 1037 | c.accumulator.clear_samples(); 1038 | c 1039 | } 1040 | } 1041 | 1042 | /// A dimension in a continuous grid that contains a partitioning. 1043 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 1044 | #[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))] 1045 | #[derive(Debug, Clone)] 1046 | pub struct ContinuousDimension { 1047 | pub partitioning: Vec, 1048 | bin_accumulator: Vec>, 1049 | bin_importance: Vec, 1050 | counter: Vec, 1051 | min_samples_for_update: usize, 1052 | bin_number_evolution: Vec, 1053 | update_counter: usize, 1054 | train_on_avg: bool, 1055 | } 1056 | 1057 | impl ContinuousDimension { 1058 | /// Create a new dimension with `n_bins` bins. 1059 | /// 1060 | /// With `min_samples_for_update` grid updates can be prevented if 1061 | /// there are too few samples in a certain bin. With `bin_number_evolution` 1062 | /// the bin numbers can be changed based on the iteration index. 
If the 1063 | /// `bin_number_evolution` array is smaller than the current iteration number, 1064 | /// the last element is taken from the list. 1065 | /// 1066 | /// With `train_on_avg`, the grids will be adapted based on the average value 1067 | /// of the bin, contrary to it's variance. 1068 | fn new( 1069 | n_bins: usize, 1070 | min_samples_for_update: usize, 1071 | bin_number_evolution: Option>, 1072 | train_on_avg: bool, 1073 | ) -> ContinuousDimension { 1074 | ContinuousDimension { 1075 | partitioning: (0..=n_bins) 1076 | .map(|i| T::new_from_usize(i) / T::new_from_usize(n_bins)) 1077 | .collect(), 1078 | bin_importance: vec![T::new_zero(); n_bins], 1079 | bin_accumulator: vec![StatisticsAccumulator::new(); n_bins], 1080 | counter: vec![0; n_bins], 1081 | min_samples_for_update, 1082 | bin_number_evolution: bin_number_evolution.unwrap_or(vec![n_bins]), 1083 | update_counter: 0, 1084 | train_on_avg, 1085 | } 1086 | } 1087 | 1088 | /// Sample a point in this dimension, writing the result in `sample`. 1089 | fn sample(&self, rng: &mut R) -> (T, T) { 1090 | let r: T = T::new_sample_unit(rng); 1091 | 1092 | // map the point to a bin 1093 | let n_bins = T::new_from_usize(self.partitioning.len() - 1); 1094 | let bin_index = (n_bins * r).to_usize_clamped(); 1095 | let bin_width = self.partitioning[bin_index + 1] - self.partitioning[bin_index]; 1096 | 1097 | // rescale the point in the bin 1098 | let sample = 1099 | self.partitioning[bin_index] + (n_bins * r - T::new_from_usize(bin_index)) * bin_width; 1100 | let weight = n_bins as T * bin_width; // d_sample / d_r 1101 | 1102 | (sample, weight) 1103 | } 1104 | 1105 | /// Add a training sample with its corresponding evaluation, i.e. `f(sample)`, to the proper bin. 
1106 | fn add_training_sample(&mut self, sample: T, weight: T, eval: T) -> Result<(), String> { 1107 | if sample < T::new_zero() 1108 | || sample > T::new_one() 1109 | || !eval.is_finite() 1110 | || !weight.is_finite() 1111 | { 1112 | return Err(format!( 1113 | "Malformed sample point: sample={sample}, weight={weight}, fx={eval}" 1114 | )); 1115 | } 1116 | 1117 | let mut index = self 1118 | .partitioning 1119 | .binary_search_by(|v| v.partial_cmp(&sample).unwrap()) 1120 | .unwrap_or_else(|e| e); 1121 | index = index.saturating_sub(1); 1122 | 1123 | self.bin_accumulator[index].add_sample(weight * eval, None); 1124 | Ok(()) 1125 | } 1126 | 1127 | /// Update the grid based on the added training samples. This will move the partition bounds of every dimension. 1128 | /// 1129 | /// The `learning_rate` determines the speed of the adaptation. If it is set to `0`, no training will be performed. 1130 | fn update(&mut self, learning_rate: T) { 1131 | for (bi, acc) in self.bin_importance.iter_mut().zip(&self.bin_accumulator) { 1132 | if self.train_on_avg { 1133 | *bi += &acc.sum 1134 | } else { 1135 | *bi += &acc.sum_sq; 1136 | } 1137 | } 1138 | 1139 | for (c, acc) in self.counter.iter_mut().zip(&mut self.bin_accumulator) { 1140 | *c += acc.new_samples; 1141 | acc.clear_samples(); 1142 | } 1143 | 1144 | if self.counter.iter().sum::() < self.min_samples_for_update { 1145 | // do not train the grid if there is a lack of samples 1146 | return; 1147 | } 1148 | 1149 | if learning_rate.is_zero() { 1150 | self.bin_accumulator.clear(); 1151 | self.bin_accumulator 1152 | .resize(self.partitioning.len() - 1, StatisticsAccumulator::new()); 1153 | self.bin_importance.clear(); 1154 | self.bin_importance 1155 | .resize(self.partitioning.len() - 1, T::new_zero()); 1156 | self.counter.clear(); 1157 | self.counter.resize(self.partitioning.len() - 1, 0); 1158 | return; 1159 | } 1160 | 1161 | let n_bins = self.partitioning.len() - 1; 1162 | 1163 | for avg in self.bin_importance.iter_mut() { 
1164 | *avg = avg.norm(); 1165 | } 1166 | 1167 | // normalize the average 1168 | for (avg, &c) in self.bin_importance.iter_mut().zip(&self.counter) { 1169 | if c > 0 { 1170 | *avg /= T::new_from_usize(c); 1171 | } 1172 | } 1173 | 1174 | // smoothen the averages between adjacent grid points 1175 | if self.partitioning.len() > 2 { 1176 | let mut prev = self.bin_importance[0]; 1177 | let mut cur = self.bin_importance[1]; 1178 | self.bin_importance[0] = (T::new_from_usize(3) * prev + cur) / T::new_from_usize(4); 1179 | for bin in 1..n_bins - 1 { 1180 | let s = prev + cur * T::new_from_usize(6); 1181 | prev = cur; 1182 | cur = self.bin_importance[bin + 1]; 1183 | self.bin_importance[bin] = (s + cur) / T::new_from_usize(8); 1184 | } 1185 | self.bin_importance[n_bins - 1] = 1186 | (prev + T::new_from_usize(3) * cur) / T::new_from_usize(4); 1187 | } 1188 | 1189 | let mut sum = T::new_zero(); 1190 | for x in &self.bin_importance { 1191 | sum += x; 1192 | } 1193 | 1194 | let mut imp_sum = T::new_zero(); 1195 | for bi in self.bin_importance.iter_mut() { 1196 | let m = if *bi == sum { 1197 | T::new_one() 1198 | } else if *bi == T::new_zero() { 1199 | T::new_zero() 1200 | } else { 1201 | ((*bi / sum - T::new_one()) / (*bi / sum).log()).powf(&learning_rate) 1202 | }; 1203 | *bi = m; 1204 | imp_sum += m; 1205 | } 1206 | 1207 | let new_number_of_bins = *self 1208 | .bin_number_evolution 1209 | .get(self.update_counter) 1210 | .or(self.bin_number_evolution.last()) 1211 | .unwrap_or(&self.bin_accumulator.len()); 1212 | self.update_counter += 1; 1213 | let new_weight_per_bin = imp_sum / T::new_from_usize(new_number_of_bins); 1214 | 1215 | // resize the bins using their importance measure 1216 | let mut new_partitioning = vec![T::new_zero(); new_number_of_bins + 1]; 1217 | 1218 | // evenly distribute the bins such that each has weight_per_bin weight 1219 | let mut acc = T::new_zero(); 1220 | let mut j = 0; 1221 | let mut target = T::new_zero(); 1222 | for nb in &mut 
new_partitioning[1..].iter_mut() { 1223 | target += new_weight_per_bin; 1224 | // find the bin that has the accumulated weight we are looking for 1225 | while j < self.bin_importance.len() && acc + self.bin_importance[j] < target { 1226 | acc += &self.bin_importance[j]; 1227 | // prevent some rounding errors from going out of the bin 1228 | if j + 1 < self.bin_importance.len() { 1229 | j += 1; 1230 | } else { 1231 | break; 1232 | } 1233 | } 1234 | 1235 | // find out how deep we are in the current bin 1236 | let bin_depth = (target - acc) / self.bin_importance[j]; 1237 | *nb = self.partitioning[j] 1238 | + bin_depth * (self.partitioning[j + 1] - self.partitioning[j]); 1239 | } 1240 | 1241 | // it could be that all the weights are distributed before we reach 1, for example if the first bin 1242 | // has all the weights. we still force to have the complete input range 1243 | new_partitioning[new_number_of_bins] = T::new_one(); 1244 | self.partitioning = new_partitioning; 1245 | 1246 | self.bin_importance.clear(); 1247 | self.bin_importance 1248 | .resize(self.partitioning.len() - 1, T::new_zero()); 1249 | self.counter.clear(); 1250 | self.counter.resize(self.partitioning.len() - 1, 0); 1251 | self.bin_accumulator.clear(); 1252 | self.bin_accumulator 1253 | .resize(self.partitioning.len() - 1, StatisticsAccumulator::new()); 1254 | } 1255 | 1256 | /// Returns `Ok` when this grid can be merged with another grid, 1257 | /// and `Err` when the grids have a different shape. 1258 | fn is_mergeable(&self, other: &ContinuousDimension) -> Result<(), String> { 1259 | if self.partitioning != other.partitioning { 1260 | Err("Partitions do not match".to_owned()) 1261 | } else { 1262 | Ok(()) 1263 | } 1264 | } 1265 | 1266 | /// Merge a grid without checks. For internal use only. 
1267 | fn merge(&mut self, other: &ContinuousDimension) { 1268 | for (bi, obi) in self.bin_accumulator.iter_mut().zip(&other.bin_accumulator) { 1269 | bi.merge_samples_no_reset(obi); 1270 | } 1271 | } 1272 | 1273 | /// Clone the dimension and remove any samples that have not been processed in an update. 1274 | pub fn clone_without_samples(&self) -> ContinuousDimension { 1275 | let mut d = self.clone(); 1276 | 1277 | for bi in &mut d.bin_importance { 1278 | *bi = T::new_zero(); 1279 | } 1280 | for c in &mut d.counter { 1281 | *c = 0; 1282 | } 1283 | 1284 | for bin in &mut d.bin_accumulator { 1285 | bin.clear_samples(); 1286 | } 1287 | d 1288 | } 1289 | } 1290 | 1291 | /// A reproducible, fast, non-cryptographic random number generator suitable for parallel Monte Carlo simulations. 1292 | /// A `seed` has to be set, which can be any `u64` number (small numbers work just as well as large numbers). 1293 | /// 1294 | /// Each thread or instance generating samples should use the same `seed` but a different `stream_id`, 1295 | /// which is an instance counter starting at 0. 1296 | pub struct MonteCarloRng { 1297 | state: Xoshiro256StarStar, 1298 | } 1299 | 1300 | impl RngCore for MonteCarloRng { 1301 | #[inline] 1302 | fn next_u32(&mut self) -> u32 { 1303 | self.state.next_u32() 1304 | } 1305 | 1306 | #[inline] 1307 | fn next_u64(&mut self) -> u64 { 1308 | self.state.next_u64() 1309 | } 1310 | 1311 | #[inline] 1312 | fn fill_bytes(&mut self, dest: &mut [u8]) { 1313 | self.state.fill_bytes(dest) 1314 | } 1315 | } 1316 | 1317 | impl MonteCarloRng { 1318 | /// Create a new random number generator with a given `seed` and `stream_id`. For parallel runs, 1319 | /// each thread or instance generating samples should use the same `seed` but a different `stream_id`. 
1320 | pub fn new(seed: u64, stream_id: usize) -> Self { 1321 | let mut state = Xoshiro256StarStar::seed_from_u64(seed); 1322 | for _ in 0..stream_id { 1323 | state.jump(); 1324 | } 1325 | 1326 | Self { state } 1327 | } 1328 | } 1329 | 1330 | #[cfg(test)] 1331 | mod test { 1332 | use std::f64::consts::PI; 1333 | 1334 | use super::{ContinuousGrid, DiscreteGrid, Grid, MonteCarloRng, Sample}; 1335 | 1336 | #[test] 1337 | fn multichannel() { 1338 | // Integrate x*pi + x^2 using multi-channeling: 1339 | // x*pi and x^2 will have their own Vegas grid 1340 | let fs = [|x: f64| (x * PI).sin(), |x: f64| x * x]; 1341 | 1342 | let mut grid = DiscreteGrid::new( 1343 | vec![ 1344 | Some(Grid::Continuous(ContinuousGrid::new( 1345 | 1, 10, 1000, None, false, 1346 | ))), 1347 | Some(Grid::Continuous(ContinuousGrid::new( 1348 | 1, 10, 1000, None, false, 1349 | ))), 1350 | ], 1351 | 0.01, 1352 | false, 1353 | ); 1354 | 1355 | let mut rng = MonteCarloRng::new(0, 0); 1356 | 1357 | let mut sample = Sample::new(); 1358 | for _ in 1..20 { 1359 | // sample 10_000 times per iteration 1360 | for _ in 0..10_000 { 1361 | grid.sample(&mut rng, &mut sample); 1362 | 1363 | if let Sample::Discrete(_weight, i, cont_sample) = &sample { 1364 | if let Sample::Continuous(_cont_weight, xs) = 1365 | cont_sample.as_ref().unwrap().as_ref() 1366 | { 1367 | grid.add_training_sample(&sample, fs[*i](xs[0])).unwrap(); 1368 | } 1369 | } 1370 | } 1371 | 1372 | grid.update(1.5, 1.5); 1373 | } 1374 | 1375 | assert_eq!(grid.accumulator.avg, 0.9718412953459551); 1376 | assert_eq!(grid.accumulator.err, 0.000934925483808598) 1377 | } 1378 | 1379 | #[test] 1380 | fn uniform() { 1381 | let fs = [|x: f64| (x * PI).sin(), |x: f64| x * x, |x| x]; 1382 | 1383 | let mut grid = Grid::Uniform(vec![3, 10], ContinuousGrid::new(1, 10, 1000, None, false)); 1384 | 1385 | let mut rng = MonteCarloRng::new(0, 0); 1386 | 1387 | let mut sample = Sample::new(); 1388 | 1389 | grid.sample(&mut rng, &mut sample); 1390 | for _ in 1..20 { 
1391 | // sample 10_000 times per iteration 1392 | for _ in 0..10_000 { 1393 | grid.sample(&mut rng, &mut sample); 1394 | 1395 | if let Sample::Uniform(_weight, i, cont_sample) = &sample { 1396 | grid.add_training_sample(&sample, fs[i[0]](cont_sample[0]) / 10.) 1397 | .unwrap(); 1398 | } 1399 | } 1400 | 1401 | grid.update(1.5, 1.5); 1402 | } 1403 | 1404 | let r = grid.get_statistics(); 1405 | assert_eq!(r.avg, 1.4679742806412577); 1406 | assert_eq!(r.err, 0.0018395594908128354); 1407 | } 1408 | } 1409 | --------------------------------------------------------------------------------