├── .github ├── ISSUE_TEMPLATE │ └── bug_report.md └── workflows │ └── ci.yml ├── .gitignore ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── benches ├── deviation.rs ├── sort.rs └── summary_statistics.rs ├── codecov.yml ├── src ├── correlation.rs ├── deviation.rs ├── entropy.rs ├── errors.rs ├── histogram │ ├── bins.rs │ ├── errors.rs │ ├── grid.rs │ ├── histograms.rs │ ├── mod.rs │ └── strategies.rs ├── lib.rs ├── maybe_nan │ ├── impl_not_none.rs │ └── mod.rs ├── quantile │ ├── interpolate.rs │ └── mod.rs ├── sort.rs └── summary_statistics │ ├── means.rs │ └── mod.rs └── tests ├── deviation.rs ├── maybe_nan.rs ├── quantile.rs ├── sort.rs └── summary_statistics.rs /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a bug report for ndarray-stats 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Description** 11 | Description of the bug. 12 | 13 | **Version Information** 14 | - `ndarray`: ??? 15 | - `ndarray-stats`: ??? 16 | - Rust: ??? 17 | 18 | Please make sure that: 19 | - the version of `ndarray-stats` you're using corresponds to the version of `ndarray` you're using 20 | - the version of the Rust compiler you're using is supported by the version of `ndarray-stats` you're using 21 | (See the "Releases" section of the README for correct version information.) 22 | 23 | **To Reproduce** 24 | Example code which doesn't work. 25 | 26 | **Expected behavior** 27 | Description of what you expected to happen. 28 | 29 | **Additional context** 30 | Add any other context about the problem here. 
31 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Continuous integration 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | RUSTFLAGS: "-D warnings" 12 | 13 | jobs: 14 | 15 | test: 16 | runs-on: ubuntu-latest 17 | strategy: 18 | matrix: 19 | rust: 20 | - stable 21 | - beta 22 | - nightly 23 | - 1.64.0 # MSRV 24 | steps: 25 | - uses: actions/checkout@v4 26 | - uses: dtolnay/rust-toolchain@master 27 | with: 28 | toolchain: ${{ matrix.rust }} 29 | - name: Pin versions for MSRV 30 | if: "${{ matrix.rust == '1.64.0' }}" 31 | run: | 32 | cargo update -p regex --precise 1.8.4 33 | - name: Build 34 | run: cargo build --verbose 35 | - name: Run tests 36 | run: cargo test --verbose 37 | 38 | cross_test: 39 | runs-on: ubuntu-latest 40 | strategy: 41 | matrix: 42 | include: 43 | # 64-bit, big-endian 44 | - rust: stable 45 | target: s390x-unknown-linux-gnu 46 | # 32-bit, little-endian 47 | - rust: stable 48 | target: i686-unknown-linux-gnu 49 | steps: 50 | - uses: actions/checkout@v4 51 | - uses: dtolnay/rust-toolchain@master 52 | with: 53 | toolchain: ${{ matrix.rust }} 54 | target: ${{ matrix.target }} 55 | - name: Install cross 56 | run: cargo install cross -f 57 | - name: Build 58 | run: cross build --verbose --target=${{ matrix.target }} 59 | - name: Run tests 60 | run: cross test --verbose --target=${{ matrix.target }} 61 | 62 | format: 63 | runs-on: ubuntu-latest 64 | strategy: 65 | matrix: 66 | rust: 67 | - stable 68 | steps: 69 | - uses: actions/checkout@v2 70 | - uses: actions-rs/toolchain@v1 71 | with: 72 | profile: minimal 73 | toolchain: ${{ matrix.rust }} 74 | override: true 75 | components: rustfmt 76 | - name: Rustfmt 77 | run: cargo fmt -- --check 78 | 79 | coverage: 80 | runs-on: ubuntu-latest 81 | strategy: 82 | matrix: 83 
| rust: 84 | - nightly 85 | steps: 86 | - uses: actions/checkout@v4 87 | - uses: dtolnay/rust-toolchain@master 88 | with: 89 | toolchain: ${{ matrix.rust }} 90 | - name: Install tarpaulin 91 | uses: taiki-e/cache-cargo-install-action@v2 92 | with: 93 | tool: cargo-tarpaulin 94 | - name: Generate code coverage 95 | run: cargo tarpaulin --verbose --all-features --workspace --timeout 120 --out Xml 96 | - name: Upload to codecov.io 97 | uses: codecov/codecov-action@v4 98 | env: 99 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 100 | with: 101 | fail_ci_if_error: true 102 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/*.rs.bk 3 | Cargo.lock 4 | 5 | # IDE-related 6 | tags 7 | rusty-tags.vi 8 | .vscode -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ndarray-stats" 3 | version = "0.6.0" 4 | authors = ["Jim Turner ", "LukeMathWalker "] 5 | edition = "2018" 6 | 7 | license = "MIT/Apache-2.0" 8 | 9 | repository = "https://github.com/rust-ndarray/ndarray-stats" 10 | documentation = "https://docs.rs/ndarray-stats/" 11 | readme = "README.md" 12 | 13 | description = "Statistical routines for ArrayBase, the n-dimensional array data structure provided by ndarray." 
14 | 15 | keywords = ["array", "multidimensional", "statistics", "matrix", "ndarray"] 16 | categories = ["data-structures", "science"] 17 | 18 | [dependencies] 19 | ndarray = "0.16.0" 20 | noisy_float = "0.2.0" 21 | num-integer = "0.1" 22 | num-traits = "0.2" 23 | rand = "0.8.3" 24 | itertools = { version = "0.13", default-features = false } 25 | indexmap = "2.4" 26 | 27 | [dev-dependencies] 28 | ndarray = { version = "0.16.1", features = ["approx"] } 29 | criterion = "0.3" 30 | quickcheck = { version = "0.9.2", default-features = false } 31 | ndarray-rand = "0.15.0" 32 | approx = "0.5" 33 | quickcheck_macros = "1.0.0" 34 | num-bigint = "0.4.0" 35 | 36 | [[bench]] 37 | name = "sort" 38 | harness = false 39 | 40 | [[bench]] 41 | name = "summary_statistics" 42 | harness = false 43 | 44 | [[bench]] 45 | name = "deviation" 46 | harness = false 47 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 
23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. 
You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. 
(Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright 2018–2024 ndarray-stats developers 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ndarray-stats 2 | 3 | [![Coverage](https://codecov.io/gh/rust-ndarray/ndarray-stats/branch/master/graph/badge.svg)](https://codecov.io/gh/rust-ndarray/ndarray-stats) 4 | [![Dependencies status](https://deps.rs/repo/github/rust-ndarray/ndarray-stats/status.svg)](https://deps.rs/repo/github/rust-ndarray/ndarray-stats) 5 | [![Crate](https://img.shields.io/crates/v/ndarray-stats.svg)](https://crates.io/crates/ndarray-stats) 6 | [![Documentation](https://docs.rs/ndarray-stats/badge.svg)](https://docs.rs/ndarray-stats) 7 | 8 | This crate provides statistical methods for [`ndarray`]'s `ArrayBase` type. 9 | 10 | Currently available routines include: 11 | - order statistics (minimum, maximum, median, quantiles, etc.); 12 | - summary statistics (mean, skewness, kurtosis, central moments, etc.) 13 | - partitioning; 14 | - correlation analysis (covariance, pearson correlation); 15 | - measures from information theory (entropy, KL divergence, etc.); 16 | - deviation functions (distances, counts, errors, etc.); 17 | - histogram computation. 18 | 19 | See the [documentation](https://docs.rs/ndarray-stats) for more information. 20 | 21 | Please feel free to contribute new functionality! A roadmap can be found [here](https://github.com/rust-ndarray/ndarray-stats/issues/1). 
22 | 23 | [`ndarray`]: https://github.com/rust-ndarray/ndarray 24 | 25 | ## Using with Cargo 26 | 27 | ```toml 28 | [dependencies] 29 | ndarray = "0.16" 30 | ndarray-stats = "0.6.0" 31 | ``` 32 | 33 | ## Releases 34 | 35 | * **0.6.0** 36 | 37 | * Breaking changes 38 | * Minimum supported Rust version: `1.64.0` 39 | * Updated to `ndarray:v0.16.0` 40 | * Updated to `approx:v0.5.0` 41 | 42 | * Updated to `ndarray-rand:v0.15.0` 43 | * Updated to `indexmap:v2.4` 44 | * Updated to `itertools:v0.13` 45 | 46 | *Contributors*: [@bluss](https://github.com/bluss) 47 | 48 | * **0.5.1** 49 | * Fixed bug in implementation of `MaybeNaN::remove_nan_mut` for `f32` and 50 | `f64` for views with non-standard layouts. Before this fix, the bug could 51 | cause incorrect results, buffer overflows, etc., in this method and others 52 | which use it. Thanks to [@JacekCzupyt](https://github.com/JacekCzupyt) for 53 | reporting the issue (#89). 54 | * Minor docs improvements. 55 | 56 | *Contributors*: [@jturner314](https://github.com/jturner314), [@BenMoon](https://github.com/BenMoon) 57 | 58 | * **0.5.0** 59 | * Breaking changes 60 | * Minimum supported Rust version: `1.49.0` 61 | * Updated to `ndarray:v0.15.0` 62 | 63 | *Contributors*: [@Armavica](https://github.com/armavica), [@cassiersg](https://github.com/cassiersg) 64 | 65 | * **0.4.0** 66 | * Breaking changes 67 | * Minimum supported Rust version: `1.42.0` 68 | * New functionality: 69 | * Summary statistics: 70 | * Weighted variance 71 | * Weighted standard deviation 72 | * Improvements / breaking changes: 73 | * Documentation improvements for Histograms 74 | * Updated to `ndarray:v0.14.0` 75 | 76 | *Contributors*: [@munckymagik](https://github.com/munckymagik), [@nilgoyette](https://github.com/nilgoyette), [@LukeMathWalker](https://github.com/LukeMathWalker), [@lebensterben](https://github.com/lebensterben), [@xd009642](https://github.com/xd009642) 77 | 78 | * **0.3.0** 79 | 80 | * Breaking changes 81 | * Minimum supported Rust 
version: `1.37` 82 | * New functionality: 83 | * Deviation functions: 84 | * Counts equal/unequal 85 | * `l1`, `l2`, `linf` distances 86 | * (Root) mean squared error 87 | * Peak signal-to-noise ratio 88 | * Summary statistics: 89 | * Weighted sum 90 | * Weighted mean 91 | * Improvements / breaking changes: 92 | * Updated to `ndarray:v0.13.0` 93 | 94 | *Contributors*: [@munckymagik](https://github.com/munckymagik), [@nilgoyette](https://github.com/nilgoyette), [@jturner314](https://github.com/jturner314), [@LukeMathWalker](https://github.com/LukeMathWalker) 95 | 96 | * **0.2.0** 97 | 98 | * Breaking changes 99 | * All `ndarray-stats`' extension traits are now impossible to implement by 100 | users of the library (see [#34]) 101 | * Redesigned error handling across the whole crate, standardising on `Result` 102 | * New functionality: 103 | * Summary statistics: 104 | * Harmonic mean 105 | * Geometric mean 106 | * Central moments 107 | * Kurtosis 108 | * Skewness 109 | * Information theory: 110 | * Entropy 111 | * Cross-entropy 112 | * Kullback-Leibler divergence 113 | * Quantiles and order statistics: 114 | * `argmin` / `argmin_skipnan` 115 | * `argmax` / `argmax_skipnan` 116 | * Optimized bulk quantile computation (`quantiles_mut`, `quantiles_axis_mut`) 117 | * Fixes: 118 | * Reduced occurrences of overflow for `interpolate::midpoint` 119 | 120 | *Contributors*: [@jturner314](https://github.com/jturner314), [@LukeMathWalker](https://github.com/LukeMathWalker), [@phungleson](https://github.com/phungleson), [@munckymagik](https://github.com/munckymagik) 121 | 122 | [#34]: https://github.com/rust-ndarray/ndarray-stats/issues/34 123 | 124 | * **0.1.0** 125 | 126 | * Initial release by @LukeMathWalker and @jturner314. 127 | 128 | ## Contributing 129 | 130 | Please feel free to create issues and submit PRs. 
131 | 132 | ## License 133 | 134 | Copyright 2018–2024 `ndarray-stats` developers 135 | 136 | Licensed under the [Apache License, Version 2.0](LICENSE-APACHE), or the [MIT 137 | license](LICENSE-MIT), at your option. You may not use this project except in 138 | compliance with those terms. 139 | -------------------------------------------------------------------------------- /benches/deviation.rs: -------------------------------------------------------------------------------- 1 | use criterion::{ 2 | black_box, criterion_group, criterion_main, AxisScale, Criterion, PlotConfiguration, 3 | }; 4 | use ndarray::prelude::*; 5 | use ndarray_rand::rand_distr::Uniform; 6 | use ndarray_rand::RandomExt; 7 | use ndarray_stats::DeviationExt; 8 | 9 | fn sq_l2_dist(c: &mut Criterion) { 10 | let lens = vec![10, 100, 1000, 10000]; 11 | let mut group = c.benchmark_group("sq_l2_dist"); 12 | group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); 13 | for len in &lens { 14 | group.bench_with_input(format!("{}", len), len, |b, &len| { 15 | let data = Array::random(len, Uniform::new(0.0, 1.0)); 16 | let data2 = Array::random(len, Uniform::new(0.0, 1.0)); 17 | 18 | b.iter(|| black_box(data.sq_l2_dist(&data2).unwrap())) 19 | }); 20 | } 21 | group.finish(); 22 | } 23 | 24 | criterion_group! 
{ 25 | name = benches; 26 | config = Criterion::default(); 27 | targets = sq_l2_dist 28 | } 29 | criterion_main!(benches); 30 | -------------------------------------------------------------------------------- /benches/sort.rs: -------------------------------------------------------------------------------- 1 | use criterion::{ 2 | black_box, criterion_group, criterion_main, AxisScale, BatchSize, Criterion, PlotConfiguration, 3 | }; 4 | use ndarray::prelude::*; 5 | use ndarray_stats::Sort1dExt; 6 | use rand::prelude::*; 7 | 8 | fn get_from_sorted_mut(c: &mut Criterion) { 9 | let lens = vec![10, 100, 1000, 10000]; 10 | let mut group = c.benchmark_group("get_from_sorted_mut"); 11 | group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); 12 | for len in &lens { 13 | group.bench_with_input(format!("{}", len), len, |b, &len| { 14 | let mut rng = StdRng::seed_from_u64(42); 15 | let mut data: Vec<_> = (0..len).collect(); 16 | data.shuffle(&mut rng); 17 | let indices: Vec<_> = (0..len).step_by(len / 10).collect(); 18 | b.iter_batched( 19 | || Array1::from(data.clone()), 20 | |mut arr| { 21 | for &i in &indices { 22 | black_box(arr.get_from_sorted_mut(i)); 23 | } 24 | }, 25 | BatchSize::SmallInput, 26 | ) 27 | }); 28 | } 29 | group.finish(); 30 | } 31 | 32 | fn get_many_from_sorted_mut(c: &mut Criterion) { 33 | let lens = vec![10, 100, 1000, 10000]; 34 | let mut group = c.benchmark_group("get_many_from_sorted_mut"); 35 | group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); 36 | for len in &lens { 37 | group.bench_with_input(format!("{}", len), len, |b, &len| { 38 | let mut rng = StdRng::seed_from_u64(42); 39 | let mut data: Vec<_> = (0..len).collect(); 40 | data.shuffle(&mut rng); 41 | let indices: Array1<_> = (0..len).step_by(len / 10).collect(); 42 | b.iter_batched( 43 | || Array1::from(data.clone()), 44 | |mut arr| { 45 | black_box(arr.get_many_from_sorted_mut(&indices)); 46 | }, 47 | BatchSize::SmallInput, 
48 | ) 49 | }); 50 | } 51 | group.finish(); 52 | } 53 | 54 | criterion_group! { 55 | name = benches; 56 | config = Criterion::default(); 57 | targets = get_from_sorted_mut, get_many_from_sorted_mut 58 | } 59 | criterion_main!(benches); 60 | -------------------------------------------------------------------------------- /benches/summary_statistics.rs: -------------------------------------------------------------------------------- 1 | use criterion::{ 2 | black_box, criterion_group, criterion_main, AxisScale, BatchSize, Criterion, PlotConfiguration, 3 | }; 4 | use ndarray::prelude::*; 5 | use ndarray_rand::rand_distr::Uniform; 6 | use ndarray_rand::RandomExt; 7 | use ndarray_stats::SummaryStatisticsExt; 8 | 9 | fn weighted_std(c: &mut Criterion) { 10 | let lens = vec![10, 100, 1000, 10000]; 11 | let mut group = c.benchmark_group("weighted_std"); 12 | group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); 13 | for len in &lens { 14 | group.bench_with_input(format!("{}", len), len, |b, &len| { 15 | let data = Array::random(len, Uniform::new(0.0, 1.0)); 16 | let mut weights = Array::random(len, Uniform::new(0.0, 1.0)); 17 | weights /= weights.sum(); 18 | b.iter_batched( 19 | || data.clone(), 20 | |arr| { 21 | black_box(arr.weighted_std(&weights, 0.0).unwrap()); 22 | }, 23 | BatchSize::SmallInput, 24 | ) 25 | }); 26 | } 27 | group.finish(); 28 | } 29 | 30 | criterion_group! 
{ 31 | name = benches; 32 | config = Criterion::default(); 33 | targets = weighted_std 34 | } 35 | criterion_main!(benches); 36 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: off 2 | coverage: 3 | status: 4 | project: 5 | default: 6 | target: auto 7 | threshold: 2 8 | base: auto 9 | patch: 10 | default: 11 | target: auto 12 | threshold: 2 13 | base: auto 14 | -------------------------------------------------------------------------------- /src/correlation.rs: -------------------------------------------------------------------------------- 1 | use crate::errors::EmptyInput; 2 | use ndarray::prelude::*; 3 | use ndarray::Data; 4 | use num_traits::{Float, FromPrimitive}; 5 | 6 | /// Extension trait for `ArrayBase` providing functions 7 | /// to compute different correlation measures. 8 | pub trait CorrelationExt 9 | where 10 | S: Data, 11 | { 12 | /// Return the covariance matrix `C` for a 2-dimensional 13 | /// array of observations `M`. 14 | /// 15 | /// Let `(r, o)` be the shape of `M`: 16 | /// - `r` is the number of random variables; 17 | /// - `o` is the number of observations we have collected 18 | /// for each random variable. 19 | /// 20 | /// Every column in `M` is an experiment: a single observation for each 21 | /// random variable. 22 | /// Each row in `M` contains all the observations for a certain random variable. 23 | /// 24 | /// The parameter `ddof` specifies the "delta degrees of freedom". For 25 | /// example, to calculate the population covariance, use `ddof = 0`, or to 26 | /// calculate the sample covariance (unbiased estimate), use `ddof = 1`. 
27 | /// 28 | /// The covariance of two random variables is defined as: 29 | /// 30 | /// ```text 31 | /// 1 n 32 | /// cov(X, Y) = ―――――――― ∑ (xᵢ - x̅)(yᵢ - y̅) 33 | /// n - ddof i=1 34 | /// ``` 35 | /// 36 | /// where 37 | /// 38 | /// ```text 39 | /// 1 n 40 | /// x̅ = ― ∑ xᵢ 41 | /// n i=1 42 | /// ``` 43 | /// and similarly for ̅y. 44 | /// 45 | /// If `M` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`. 46 | /// 47 | /// **Panics** if `ddof` is negative or greater than or equal to the number of 48 | /// observations, or if the type cast of `n_observations` from `usize` to `A` fails. 49 | /// 50 | /// # Example 51 | /// 52 | /// ``` 53 | /// use ndarray::{aview2, arr2}; 54 | /// use ndarray_stats::CorrelationExt; 55 | /// 56 | /// let a = arr2(&[[1., 3., 5.], 57 | /// [2., 4., 6.]]); 58 | /// let covariance = a.cov(1.).unwrap(); 59 | /// assert_eq!( 60 | /// covariance, 61 | /// aview2(&[[4., 4.], [4., 4.]]) 62 | /// ); 63 | /// ``` 64 | fn cov(&self, ddof: A) -> Result, EmptyInput> 65 | where 66 | A: Float + FromPrimitive; 67 | 68 | /// Return the [Pearson correlation coefficients](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient) 69 | /// for a 2-dimensional array of observations `M`. 70 | /// 71 | /// Let `(r, o)` be the shape of `M`: 72 | /// - `r` is the number of random variables; 73 | /// - `o` is the number of observations we have collected 74 | /// for each random variable. 75 | /// 76 | /// Every column in `M` is an experiment: a single observation for each 77 | /// random variable. 78 | /// Each row in `M` contains all the observations for a certain random variable. 79 | /// 80 | /// The Pearson correlation coefficient of two random variables is defined as: 81 | /// 82 | /// ```text 83 | /// cov(X, Y) 84 | /// rho(X, Y) = ―――――――――――― 85 | /// std(X)std(Y) 86 | /// ``` 87 | /// 88 | /// Let `R` be the matrix returned by this function. 
Then 89 | /// ```text 90 | /// R_ij = rho(X_i, X_j) 91 | /// ``` 92 | /// 93 | /// If `M` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`. 94 | /// 95 | /// **Panics** if the type cast of `n_observations` from `usize` to `A` fails or 96 | /// if the standard deviation of one of the random variables is zero and 97 | /// division by zero panics for type A. 98 | /// 99 | /// # Example 100 | /// 101 | /// ``` 102 | /// use approx; 103 | /// use ndarray::arr2; 104 | /// use ndarray_stats::CorrelationExt; 105 | /// use approx::AbsDiffEq; 106 | /// 107 | /// let a = arr2(&[[1., 3., 5.], 108 | /// [2., 4., 6.]]); 109 | /// let corr = a.pearson_correlation().unwrap(); 110 | /// let epsilon = 1e-7; 111 | /// assert!( 112 | /// corr.abs_diff_eq( 113 | /// &arr2(&[ 114 | /// [1., 1.], 115 | /// [1., 1.], 116 | /// ]), 117 | /// epsilon 118 | /// ) 119 | /// ); 120 | /// ``` 121 | fn pearson_correlation(&self) -> Result, EmptyInput> 122 | where 123 | A: Float + FromPrimitive; 124 | 125 | private_decl! {} 126 | } 127 | 128 | impl CorrelationExt for ArrayBase 129 | where 130 | S: Data, 131 | { 132 | fn cov(&self, ddof: A) -> Result, EmptyInput> 133 | where 134 | A: Float + FromPrimitive, 135 | { 136 | let observation_axis = Axis(1); 137 | let n_observations = A::from_usize(self.len_of(observation_axis)).unwrap(); 138 | let dof = if ddof >= n_observations { 139 | panic!( 140 | "`ddof` needs to be strictly smaller than the \ 141 | number of observations provided for each \ 142 | random variable!" 
143 | ) 144 | } else { 145 | n_observations - ddof 146 | }; 147 | let mean = self.mean_axis(observation_axis); 148 | match mean { 149 | Some(mean) => { 150 | let denoised = self - &mean.insert_axis(observation_axis); 151 | let covariance = denoised.dot(&denoised.t()); 152 | Ok(covariance.mapv_into(|x| x / dof)) 153 | } 154 | None => Err(EmptyInput), 155 | } 156 | } 157 | 158 | fn pearson_correlation(&self) -> Result, EmptyInput> 159 | where 160 | A: Float + FromPrimitive, 161 | { 162 | match self.dim() { 163 | (n, m) if n > 0 && m > 0 => { 164 | let observation_axis = Axis(1); 165 | // The ddof value doesn't matter, as long as we use the same one 166 | // for computing covariance and standard deviation 167 | // We choose 0 as it is the smallest number admitted by std_axis 168 | let ddof = A::zero(); 169 | let cov = self.cov(ddof).unwrap(); 170 | let std = self 171 | .std_axis(observation_axis, ddof) 172 | .insert_axis(observation_axis); 173 | let std_matrix = std.dot(&std.t()); 174 | // element-wise division 175 | Ok(cov / std_matrix) 176 | } 177 | _ => Err(EmptyInput), 178 | } 179 | } 180 | 181 | private_impl! 
{} 182 | } 183 | 184 | #[cfg(test)] 185 | mod cov_tests { 186 | use super::*; 187 | use ndarray::array; 188 | use ndarray_rand::rand; 189 | use ndarray_rand::rand_distr::Uniform; 190 | use ndarray_rand::RandomExt; 191 | use quickcheck_macros::quickcheck; 192 | 193 | #[quickcheck] 194 | fn constant_random_variables_have_zero_covariance_matrix(value: f64) -> bool { 195 | let n_random_variables = 3; 196 | let n_observations = 4; 197 | let a = Array::from_elem((n_random_variables, n_observations), value); 198 | abs_diff_eq!( 199 | a.cov(1.).unwrap(), 200 | &Array::zeros((n_random_variables, n_random_variables)), 201 | epsilon = 1e-8, 202 | ) 203 | } 204 | 205 | #[quickcheck] 206 | fn covariance_matrix_is_symmetric(bound: f64) -> bool { 207 | let n_random_variables = 3; 208 | let n_observations = 4; 209 | let a = Array::random( 210 | (n_random_variables, n_observations), 211 | Uniform::new(-bound.abs(), bound.abs()), 212 | ); 213 | let covariance = a.cov(1.).unwrap(); 214 | abs_diff_eq!(covariance, &covariance.t(), epsilon = 1e-8) 215 | } 216 | 217 | #[test] 218 | #[should_panic] 219 | fn test_invalid_ddof() { 220 | let n_random_variables = 3; 221 | let n_observations = 4; 222 | let a = Array::random((n_random_variables, n_observations), Uniform::new(0., 10.)); 223 | let invalid_ddof = (n_observations as f64) + rand::random::().abs(); 224 | let _ = a.cov(invalid_ddof); 225 | } 226 | 227 | #[test] 228 | fn test_covariance_zero_variables() { 229 | let a = Array2::::zeros((0, 2)); 230 | let cov = a.cov(1.); 231 | assert!(cov.is_ok()); 232 | assert_eq!(cov.unwrap().shape(), &[0, 0]); 233 | } 234 | 235 | #[test] 236 | fn test_covariance_zero_observations() { 237 | let a = Array2::::zeros((2, 0)); 238 | // Negative ddof (-1 < 0) to avoid invalid-ddof panic 239 | let cov = a.cov(-1.); 240 | assert_eq!(cov, Err(EmptyInput)); 241 | } 242 | 243 | #[test] 244 | fn test_covariance_zero_variables_zero_observations() { 245 | let a = Array2::::zeros((0, 0)); 246 | // Negative ddof (-1 
< 0) to avoid invalid-ddof panic 247 | let cov = a.cov(-1.); 248 | assert_eq!(cov, Err(EmptyInput)); 249 | } 250 | 251 | #[test] 252 | fn test_covariance_for_random_array() { 253 | let a = array![ 254 | [0.72009497, 0.12568055, 0.55705966, 0.5959984, 0.69471457], 255 | [0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245], 256 | [0.59044195, 0.10720363, 0.76573717, 0.54693675, 0.95923036], 257 | [0.24102952, 0.131347, 0.11118028, 0.21451351, 0.30515539], 258 | [0.26952473, 0.93079841, 0.8080893, 0.42814155, 0.24642258] 259 | ]; 260 | let numpy_covariance = array![ 261 | [0.05786248, 0.02614063, 0.06446215, 0.01285105, -0.06443992], 262 | [0.02614063, 0.08733569, 0.02436933, 0.01977437, -0.06715555], 263 | [0.06446215, 0.02436933, 0.10052129, 0.01393589, -0.06129912], 264 | [0.01285105, 0.01977437, 0.01393589, 0.00638795, -0.02355557], 265 | [ 266 | -0.06443992, 267 | -0.06715555, 268 | -0.06129912, 269 | -0.02355557, 270 | 0.09909855 271 | ] 272 | ]; 273 | assert_eq!(a.ndim(), 2); 274 | assert_abs_diff_eq!(a.cov(1.).unwrap(), &numpy_covariance, epsilon = 1e-8); 275 | } 276 | 277 | #[test] 278 | #[should_panic] 279 | // We lose precision, hence the failing assert 280 | fn test_covariance_for_badly_conditioned_array() { 281 | let a: Array2 = array![[1e12 + 1., 1e12 - 1.], [1e-6 + 1e-12, 1e-6 - 1e-12],]; 282 | let expected_covariance = array![[2., 2e-12], [2e-12, 2e-24]]; 283 | assert_abs_diff_eq!(a.cov(1.).unwrap(), &expected_covariance, epsilon = 1e-24); 284 | } 285 | } 286 | 287 | #[cfg(test)] 288 | mod pearson_correlation_tests { 289 | use super::*; 290 | use ndarray::array; 291 | use ndarray::Array; 292 | use ndarray_rand::rand_distr::Uniform; 293 | use ndarray_rand::RandomExt; 294 | use quickcheck_macros::quickcheck; 295 | 296 | #[quickcheck] 297 | fn output_matrix_is_symmetric(bound: f64) -> bool { 298 | let n_random_variables = 3; 299 | let n_observations = 4; 300 | let a = Array::random( 301 | (n_random_variables, n_observations), 302 | 
Uniform::new(-bound.abs(), bound.abs()), 303 | ); 304 | let pearson_correlation = a.pearson_correlation().unwrap(); 305 | abs_diff_eq!( 306 | pearson_correlation.view(), 307 | pearson_correlation.t(), 308 | epsilon = 1e-8 309 | ) 310 | } 311 | 312 | #[quickcheck] 313 | fn constant_random_variables_have_nan_correlation(value: f64) -> bool { 314 | let n_random_variables = 3; 315 | let n_observations = 4; 316 | let a = Array::from_elem((n_random_variables, n_observations), value); 317 | let pearson_correlation = a.pearson_correlation(); 318 | pearson_correlation 319 | .unwrap() 320 | .iter() 321 | .map(|x| x.is_nan()) 322 | .fold(true, |acc, flag| acc & flag) 323 | } 324 | 325 | #[test] 326 | fn test_zero_variables() { 327 | let a = Array2::::zeros((0, 2)); 328 | let pearson_correlation = a.pearson_correlation(); 329 | assert_eq!(pearson_correlation, Err(EmptyInput)) 330 | } 331 | 332 | #[test] 333 | fn test_zero_observations() { 334 | let a = Array2::::zeros((2, 0)); 335 | let pearson = a.pearson_correlation(); 336 | assert_eq!(pearson, Err(EmptyInput)); 337 | } 338 | 339 | #[test] 340 | fn test_zero_variables_zero_observations() { 341 | let a = Array2::::zeros((0, 0)); 342 | let pearson = a.pearson_correlation(); 343 | assert_eq!(pearson, Err(EmptyInput)); 344 | } 345 | 346 | #[test] 347 | fn test_for_random_array() { 348 | let a = array![ 349 | [0.16351516, 0.56863268, 0.16924196, 0.72579120], 350 | [0.44342453, 0.19834387, 0.25411802, 0.62462382], 351 | [0.97162731, 0.29958849, 0.17338142, 0.80198342], 352 | [0.91727132, 0.79817799, 0.62237124, 0.38970998], 353 | [0.26979716, 0.20887228, 0.95454999, 0.96290785] 354 | ]; 355 | let numpy_corrcoeff = array![ 356 | [1., 0.38089376, 0.08122504, -0.59931623, 0.1365648], 357 | [0.38089376, 1., 0.80918429, -0.52615195, 0.38954398], 358 | [0.08122504, 0.80918429, 1., 0.07134906, -0.17324776], 359 | [-0.59931623, -0.52615195, 0.07134906, 1., -0.8743213], 360 | [0.1365648, 0.38954398, -0.17324776, -0.8743213, 1.] 
361 | ]; 362 | assert_eq!(a.ndim(), 2); 363 | assert_abs_diff_eq!( 364 | a.pearson_correlation().unwrap(), 365 | numpy_corrcoeff, 366 | epsilon = 1e-7 367 | ); 368 | } 369 | } 370 | -------------------------------------------------------------------------------- /src/deviation.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{ArrayBase, Data, Dimension, Zip}; 2 | use num_traits::{Signed, ToPrimitive}; 3 | use std::convert::Into; 4 | use std::ops::AddAssign; 5 | 6 | use crate::errors::MultiInputError; 7 | 8 | /// An extension trait for `ArrayBase` providing functions 9 | /// to compute different deviation measures. 10 | pub trait DeviationExt 11 | where 12 | S: Data, 13 | D: Dimension, 14 | { 15 | /// Counts the number of indices at which the elements of the arrays `self` 16 | /// and `other` are equal. 17 | /// 18 | /// The following **errors** may be returned: 19 | /// 20 | /// * `MultiInputError::EmptyInput` if `self` is empty 21 | /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape 22 | fn count_eq(&self, other: &ArrayBase) -> Result 23 | where 24 | A: PartialEq, 25 | T: Data; 26 | 27 | /// Counts the number of indices at which the elements of the arrays `self` 28 | /// and `other` are not equal. 29 | /// 30 | /// The following **errors** may be returned: 31 | /// 32 | /// * `MultiInputError::EmptyInput` if `self` is empty 33 | /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape 34 | fn count_neq(&self, other: &ArrayBase) -> Result 35 | where 36 | A: PartialEq, 37 | T: Data; 38 | 39 | /// Computes the [squared L2 distance] between `self` and `other`. 40 | /// 41 | /// ```text 42 | /// n 43 | /// ∑ |aᵢ - bᵢ|² 44 | /// i=1 45 | /// ``` 46 | /// 47 | /// where `self` is `a` and `other` is `b`. 
48 | /// 49 | /// The following **errors** may be returned: 50 | /// 51 | /// * `MultiInputError::EmptyInput` if `self` is empty 52 | /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape 53 | /// 54 | /// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance 55 | fn sq_l2_dist(&self, other: &ArrayBase) -> Result 56 | where 57 | A: AddAssign + Clone + Signed, 58 | T: Data; 59 | 60 | /// Computes the [L2 distance] between `self` and `other`. 61 | /// 62 | /// ```text 63 | /// n 64 | /// √ ( ∑ |aᵢ - bᵢ|² ) 65 | /// i=1 66 | /// ``` 67 | /// 68 | /// where `self` is `a` and `other` is `b`. 69 | /// 70 | /// The following **errors** may be returned: 71 | /// 72 | /// * `MultiInputError::EmptyInput` if `self` is empty 73 | /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape 74 | /// 75 | /// **Panics** if the type cast from `A` to `f64` fails. 76 | /// 77 | /// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance 78 | fn l2_dist(&self, other: &ArrayBase) -> Result 79 | where 80 | A: AddAssign + Clone + Signed + ToPrimitive, 81 | T: Data; 82 | 83 | /// Computes the [L1 distance] between `self` and `other`. 84 | /// 85 | /// ```text 86 | /// n 87 | /// ∑ |aᵢ - bᵢ| 88 | /// i=1 89 | /// ``` 90 | /// 91 | /// where `self` is `a` and `other` is `b`. 92 | /// 93 | /// The following **errors** may be returned: 94 | /// 95 | /// * `MultiInputError::EmptyInput` if `self` is empty 96 | /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape 97 | /// 98 | /// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry 99 | fn l1_dist(&self, other: &ArrayBase) -> Result 100 | where 101 | A: AddAssign + Clone + Signed, 102 | T: Data; 103 | 104 | /// Computes the [L∞ distance] between `self` and `other`. 
105 | /// 106 | /// ```text 107 | /// max(|aᵢ - bᵢ|) 108 | /// ᵢ 109 | /// ``` 110 | /// 111 | /// where `self` is `a` and `other` is `b`. 112 | /// 113 | /// The following **errors** may be returned: 114 | /// 115 | /// * `MultiInputError::EmptyInput` if `self` is empty 116 | /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape 117 | /// 118 | /// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance 119 | fn linf_dist(&self, other: &ArrayBase) -> Result 120 | where 121 | A: Clone + PartialOrd + Signed, 122 | T: Data; 123 | 124 | /// Computes the [mean absolute error] between `self` and `other`. 125 | /// 126 | /// ```text 127 | /// n 128 | /// 1/n * ∑ |aᵢ - bᵢ| 129 | /// i=1 130 | /// ``` 131 | /// 132 | /// where `self` is `a` and `other` is `b`. 133 | /// 134 | /// The following **errors** may be returned: 135 | /// 136 | /// * `MultiInputError::EmptyInput` if `self` is empty 137 | /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape 138 | /// 139 | /// **Panics** if the type cast from `A` to `f64` fails. 140 | /// 141 | /// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error 142 | fn mean_abs_err(&self, other: &ArrayBase) -> Result 143 | where 144 | A: AddAssign + Clone + Signed + ToPrimitive, 145 | T: Data; 146 | 147 | /// Computes the [mean squared error] between `self` and `other`. 148 | /// 149 | /// ```text 150 | /// n 151 | /// 1/n * ∑ |aᵢ - bᵢ|² 152 | /// i=1 153 | /// ``` 154 | /// 155 | /// where `self` is `a` and `other` is `b`. 156 | /// 157 | /// The following **errors** may be returned: 158 | /// 159 | /// * `MultiInputError::EmptyInput` if `self` is empty 160 | /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape 161 | /// 162 | /// **Panics** if the type cast from `A` to `f64` fails. 
163 | /// 164 | /// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error 165 | fn mean_sq_err(&self, other: &ArrayBase) -> Result 166 | where 167 | A: AddAssign + Clone + Signed + ToPrimitive, 168 | T: Data; 169 | 170 | /// Computes the unnormalized [root-mean-square error] between `self` and `other`. 171 | /// 172 | /// ```text 173 | /// √ mse(a, b) 174 | /// ``` 175 | /// 176 | /// where `self` is `a`, `other` is `b` and `mse` is the mean-squared-error. 177 | /// 178 | /// The following **errors** may be returned: 179 | /// 180 | /// * `MultiInputError::EmptyInput` if `self` is empty 181 | /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape 182 | /// 183 | /// **Panics** if the type cast from `A` to `f64` fails. 184 | /// 185 | /// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation 186 | fn root_mean_sq_err(&self, other: &ArrayBase) -> Result 187 | where 188 | A: AddAssign + Clone + Signed + ToPrimitive, 189 | T: Data; 190 | 191 | /// Computes the [peak signal-to-noise ratio] between `self` and `other`. 192 | /// 193 | /// ```text 194 | /// 10 * log10(maxv^2 / mse(a, b)) 195 | /// ``` 196 | /// 197 | /// where `self` is `a`, `other` is `b`, `mse` is the mean-squared-error 198 | /// and `maxv` is the maximum possible value either array can take. 199 | /// 200 | /// The following **errors** may be returned: 201 | /// 202 | /// * `MultiInputError::EmptyInput` if `self` is empty 203 | /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape 204 | /// 205 | /// **Panics** if the type cast from `A` to `f64` fails. 206 | /// 207 | /// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 208 | fn peak_signal_to_noise_ratio( 209 | &self, 210 | other: &ArrayBase, 211 | maxv: A, 212 | ) -> Result 213 | where 214 | A: AddAssign + Clone + Signed + ToPrimitive, 215 | T: Data; 216 | 217 | private_decl! 
{} 218 | } 219 | 220 | impl DeviationExt for ArrayBase 221 | where 222 | S: Data, 223 | D: Dimension, 224 | { 225 | fn count_eq(&self, other: &ArrayBase) -> Result 226 | where 227 | A: PartialEq, 228 | T: Data, 229 | { 230 | return_err_if_empty!(self); 231 | return_err_unless_same_shape!(self, other); 232 | 233 | let mut count = 0; 234 | 235 | Zip::from(self).and(other).for_each(|a, b| { 236 | if a == b { 237 | count += 1; 238 | } 239 | }); 240 | 241 | Ok(count) 242 | } 243 | 244 | fn count_neq(&self, other: &ArrayBase) -> Result 245 | where 246 | A: PartialEq, 247 | T: Data, 248 | { 249 | self.count_eq(other).map(|n_eq| self.len() - n_eq) 250 | } 251 | 252 | fn sq_l2_dist(&self, other: &ArrayBase) -> Result 253 | where 254 | A: AddAssign + Clone + Signed, 255 | T: Data, 256 | { 257 | return_err_if_empty!(self); 258 | return_err_unless_same_shape!(self, other); 259 | 260 | let mut result = A::zero(); 261 | 262 | Zip::from(self).and(other).for_each(|self_i, other_i| { 263 | let (a, b) = (self_i.clone(), other_i.clone()); 264 | let diff = a - b; 265 | result += diff.clone() * diff; 266 | }); 267 | 268 | Ok(result) 269 | } 270 | 271 | fn l2_dist(&self, other: &ArrayBase) -> Result 272 | where 273 | A: AddAssign + Clone + Signed + ToPrimitive, 274 | T: Data, 275 | { 276 | let sq_l2_dist = self 277 | .sq_l2_dist(other)? 
278 | .to_f64() 279 | .expect("failed cast from type A to f64"); 280 | 281 | Ok(sq_l2_dist.sqrt()) 282 | } 283 | 284 | fn l1_dist(&self, other: &ArrayBase) -> Result 285 | where 286 | A: AddAssign + Clone + Signed, 287 | T: Data, 288 | { 289 | return_err_if_empty!(self); 290 | return_err_unless_same_shape!(self, other); 291 | 292 | let mut result = A::zero(); 293 | 294 | Zip::from(self).and(other).for_each(|self_i, other_i| { 295 | let (a, b) = (self_i.clone(), other_i.clone()); 296 | result += (a - b).abs(); 297 | }); 298 | 299 | Ok(result) 300 | } 301 | 302 | fn linf_dist(&self, other: &ArrayBase) -> Result 303 | where 304 | A: Clone + PartialOrd + Signed, 305 | T: Data, 306 | { 307 | return_err_if_empty!(self); 308 | return_err_unless_same_shape!(self, other); 309 | 310 | let mut max = A::zero(); 311 | 312 | Zip::from(self).and(other).for_each(|self_i, other_i| { 313 | let (a, b) = (self_i.clone(), other_i.clone()); 314 | let diff = (a - b).abs(); 315 | if diff > max { 316 | max = diff; 317 | } 318 | }); 319 | 320 | Ok(max) 321 | } 322 | 323 | fn mean_abs_err(&self, other: &ArrayBase) -> Result 324 | where 325 | A: AddAssign + Clone + Signed + ToPrimitive, 326 | T: Data, 327 | { 328 | let l1_dist = self 329 | .l1_dist(other)? 330 | .to_f64() 331 | .expect("failed cast from type A to f64"); 332 | let n = self.len() as f64; 333 | 334 | Ok(l1_dist / n) 335 | } 336 | 337 | fn mean_sq_err(&self, other: &ArrayBase) -> Result 338 | where 339 | A: AddAssign + Clone + Signed + ToPrimitive, 340 | T: Data, 341 | { 342 | let sq_l2_dist = self 343 | .sq_l2_dist(other)? 
344 | .to_f64() 345 | .expect("failed cast from type A to f64"); 346 | let n = self.len() as f64; 347 | 348 | Ok(sq_l2_dist / n) 349 | } 350 | 351 | fn root_mean_sq_err(&self, other: &ArrayBase) -> Result 352 | where 353 | A: AddAssign + Clone + Signed + ToPrimitive, 354 | T: Data, 355 | { 356 | let msd = self.mean_sq_err(other)?; 357 | Ok(msd.sqrt()) 358 | } 359 | 360 | fn peak_signal_to_noise_ratio( 361 | &self, 362 | other: &ArrayBase, 363 | maxv: A, 364 | ) -> Result 365 | where 366 | A: AddAssign + Clone + Signed + ToPrimitive, 367 | T: Data, 368 | { 369 | let maxv_f = maxv.to_f64().expect("failed cast from type A to f64"); 370 | let msd = self.mean_sq_err(&other)?; 371 | let psnr = 10. * f64::log10(maxv_f * maxv_f / msd); 372 | 373 | Ok(psnr) 374 | } 375 | 376 | private_impl! {} 377 | } 378 | -------------------------------------------------------------------------------- /src/entropy.rs: -------------------------------------------------------------------------------- 1 | //! Information theory (e.g. entropy, KL divergence, etc.). 2 | use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch}; 3 | use ndarray::{Array, ArrayBase, Data, Dimension, Zip}; 4 | use num_traits::Float; 5 | 6 | /// Extension trait for `ArrayBase` providing methods 7 | /// to compute information theory quantities 8 | /// (e.g. entropy, Kullback–Leibler divergence, etc.). 9 | pub trait EntropyExt 10 | where 11 | S: Data, 12 | D: Dimension, 13 | { 14 | /// Computes the [entropy] *S* of the array values, defined as 15 | /// 16 | /// ```text 17 | /// n 18 | /// S = - ∑ xᵢ ln(xᵢ) 19 | /// i=1 20 | /// ``` 21 | /// 22 | /// If the array is empty, `Err(EmptyInput)` is returned. 23 | /// 24 | /// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`). 
25 | /// 26 | /// ## Remarks 27 | /// 28 | /// The entropy is a measure used in [Information Theory] 29 | /// to describe a probability distribution: it only make sense 30 | /// when the array values sum to 1, with each entry between 31 | /// 0 and 1 (extremes included). 32 | /// 33 | /// The array values are **not** normalised by this function before 34 | /// computing the entropy to avoid introducing potentially 35 | /// unnecessary numerical errors (e.g. if the array were to be already normalised). 36 | /// 37 | /// By definition, *xᵢ ln(xᵢ)* is set to 0 if *xᵢ* is 0. 38 | /// 39 | /// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory) 40 | /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory 41 | fn entropy(&self) -> Result 42 | where 43 | A: Float; 44 | 45 | /// Computes the [Kullback-Leibler divergence] *Dₖₗ(p,q)* between two arrays, 46 | /// where `self`=*p*. 47 | /// 48 | /// The Kullback-Leibler divergence is defined as: 49 | /// 50 | /// ```text 51 | /// n 52 | /// Dₖₗ(p,q) = - ∑ pᵢ ln(qᵢ/pᵢ) 53 | /// i=1 54 | /// ``` 55 | /// 56 | /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned. 57 | /// If the array shapes are not identical, 58 | /// `Err(MultiInputError::ShapeMismatch)` is returned. 59 | /// 60 | /// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing 61 | /// *ln(qᵢ/pᵢ)* is a panic cause for `A`. 62 | /// 63 | /// ## Remarks 64 | /// 65 | /// The Kullback-Leibler divergence is a measure used in [Information Theory] 66 | /// to describe the relationship between two probability distribution: it only make sense 67 | /// when each array sums to 1 with entries between 0 and 1 (extremes included). 68 | /// 69 | /// The array values are **not** normalised by this function before 70 | /// computing the entropy to avoid introducing potentially 71 | /// unnecessary numerical errors (e.g. if the array were to be already normalised). 
72 | /// 73 | /// By definition, *pᵢ ln(qᵢ/pᵢ)* is set to 0 if *pᵢ* is 0. 74 | /// 75 | /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence 76 | /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory 77 | fn kl_divergence(&self, q: &ArrayBase) -> Result 78 | where 79 | S2: Data, 80 | A: Float; 81 | 82 | /// Computes the [cross entropy] *H(p,q)* between two arrays, 83 | /// where `self`=*p*. 84 | /// 85 | /// The cross entropy is defined as: 86 | /// 87 | /// ```text 88 | /// n 89 | /// H(p,q) = - ∑ pᵢ ln(qᵢ) 90 | /// i=1 91 | /// ``` 92 | /// 93 | /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned. 94 | /// If the array shapes are not identical, 95 | /// `Err(MultiInputError::ShapeMismatch)` is returned. 96 | /// 97 | /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number 98 | /// is a panic cause for `A`. 99 | /// 100 | /// ## Remarks 101 | /// 102 | /// The cross entropy is a measure used in [Information Theory] 103 | /// to describe the relationship between two probability distributions: it only makes sense 104 | /// when each array sums to 1 with entries between 0 and 1 (extremes included). 105 | /// 106 | /// The array values are **not** normalised by this function before 107 | /// computing the entropy to avoid introducing potentially 108 | /// unnecessary numerical errors (e.g. if the array were to be already normalised). 109 | /// 110 | /// The cross entropy is often used as an objective/loss function in 111 | /// [optimization problems], including [machine learning]. 112 | /// 113 | /// By definition, *pᵢ ln(qᵢ)* is set to 0 if *pᵢ* is 0. 
114 | /// 115 | /// [cross entropy]: https://en.wikipedia.org/wiki/Cross-entropy 116 | /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory 117 | /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method 118 | /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression 119 | fn cross_entropy(&self, q: &ArrayBase) -> Result 120 | where 121 | S2: Data, 122 | A: Float; 123 | 124 | private_decl! {} 125 | } 126 | 127 | impl EntropyExt for ArrayBase 128 | where 129 | S: Data, 130 | D: Dimension, 131 | { 132 | fn entropy(&self) -> Result 133 | where 134 | A: Float, 135 | { 136 | if self.is_empty() { 137 | Err(EmptyInput) 138 | } else { 139 | let entropy = -self 140 | .mapv(|x| { 141 | if x == A::zero() { 142 | A::zero() 143 | } else { 144 | x * x.ln() 145 | } 146 | }) 147 | .sum(); 148 | Ok(entropy) 149 | } 150 | } 151 | 152 | fn kl_divergence(&self, q: &ArrayBase) -> Result 153 | where 154 | A: Float, 155 | S2: Data, 156 | { 157 | if self.is_empty() { 158 | return Err(MultiInputError::EmptyInput); 159 | } 160 | if self.shape() != q.shape() { 161 | return Err(ShapeMismatch { 162 | first_shape: self.shape().to_vec(), 163 | second_shape: q.shape().to_vec(), 164 | } 165 | .into()); 166 | } 167 | 168 | let mut temp = Array::zeros(self.raw_dim()); 169 | Zip::from(&mut temp) 170 | .and(self) 171 | .and(q) 172 | .for_each(|result, &p, &q| { 173 | *result = { 174 | if p == A::zero() { 175 | A::zero() 176 | } else { 177 | p * (q / p).ln() 178 | } 179 | } 180 | }); 181 | let kl_divergence = -temp.sum(); 182 | Ok(kl_divergence) 183 | } 184 | 185 | fn cross_entropy(&self, q: &ArrayBase) -> Result 186 | where 187 | S2: Data, 188 | A: Float, 189 | { 190 | if self.is_empty() { 191 | return Err(MultiInputError::EmptyInput); 192 | } 193 | if self.shape() != q.shape() { 194 | return Err(ShapeMismatch { 195 | first_shape: self.shape().to_vec(), 196 | second_shape: q.shape().to_vec(), 197 
| } 198 | .into()); 199 | } 200 | 201 | let mut temp = Array::zeros(self.raw_dim()); 202 | Zip::from(&mut temp) 203 | .and(self) 204 | .and(q) 205 | .for_each(|result, &p, &q| { 206 | *result = { 207 | if p == A::zero() { 208 | A::zero() 209 | } else { 210 | p * q.ln() 211 | } 212 | } 213 | }); 214 | let cross_entropy = -temp.sum(); 215 | Ok(cross_entropy) 216 | } 217 | 218 | private_impl! {} 219 | } 220 | 221 | #[cfg(test)] 222 | mod tests { 223 | use super::EntropyExt; 224 | use crate::errors::{EmptyInput, MultiInputError}; 225 | use approx::assert_abs_diff_eq; 226 | use ndarray::{array, Array1}; 227 | use noisy_float::types::n64; 228 | use std::f64; 229 | 230 | #[test] 231 | fn test_entropy_with_nan_values() { 232 | let a = array![f64::NAN, 1.]; 233 | assert!(a.entropy().unwrap().is_nan()); 234 | } 235 | 236 | #[test] 237 | fn test_entropy_with_empty_array_of_floats() { 238 | let a: Array1 = array![]; 239 | assert_eq!(a.entropy(), Err(EmptyInput)); 240 | } 241 | 242 | #[test] 243 | fn test_entropy_with_array_of_floats() { 244 | // Array of probability values - normalized and positive. 
245 | let a: Array1 = array![ 246 | 0.03602474, 0.01900344, 0.03510129, 0.03414964, 0.00525311, 0.03368976, 0.00065396, 247 | 0.02906146, 0.00063687, 0.01597306, 0.00787625, 0.00208243, 0.01450896, 0.01803418, 248 | 0.02055336, 0.03029759, 0.03323628, 0.01218822, 0.0001873, 0.01734179, 0.03521668, 249 | 0.02564429, 0.02421992, 0.03540229, 0.03497635, 0.03582331, 0.026558, 0.02460495, 250 | 0.02437716, 0.01212838, 0.00058464, 0.00335236, 0.02146745, 0.00930306, 0.01821588, 251 | 0.02381928, 0.02055073, 0.01483779, 0.02284741, 0.02251385, 0.00976694, 0.02864634, 252 | 0.00802828, 0.03464088, 0.03557152, 0.01398894, 0.01831756, 0.0227171, 0.00736204, 253 | 0.01866295, 254 | ]; 255 | // Computed using scipy.stats.entropy 256 | let expected_entropy = 3.721606155686918; 257 | 258 | assert_abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = 1e-6); 259 | } 260 | 261 | #[test] 262 | fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> { 263 | let a = array![f64::NAN, 1.]; 264 | let b = array![2., 1.]; 265 | assert!(a.cross_entropy(&b)?.is_nan()); 266 | assert!(b.cross_entropy(&a)?.is_nan()); 267 | assert!(a.kl_divergence(&b)?.is_nan()); 268 | assert!(b.kl_divergence(&a)?.is_nan()); 269 | Ok(()) 270 | } 271 | 272 | #[test] 273 | fn test_cross_entropy_and_kl_with_same_n_dimension_but_different_n_elements() { 274 | let p = array![f64::NAN, 1.]; 275 | let q = array![2., 1., 5.]; 276 | assert!(q.cross_entropy(&p).is_err()); 277 | assert!(p.cross_entropy(&q).is_err()); 278 | assert!(q.kl_divergence(&p).is_err()); 279 | assert!(p.kl_divergence(&q).is_err()); 280 | } 281 | 282 | #[test] 283 | fn test_cross_entropy_and_kl_with_different_shape_but_same_n_elements() { 284 | // p: 3x2, 6 elements 285 | let p = array![[f64::NAN, 1.], [6., 7.], [10., 20.]]; 286 | // q: 2x3, 6 elements 287 | let q = array![[2., 1., 5.], [1., 1., 7.],]; 288 | assert!(q.cross_entropy(&p).is_err()); 289 | assert!(p.cross_entropy(&q).is_err()); 290 | 
assert!(q.kl_divergence(&p).is_err()); 291 | assert!(p.kl_divergence(&q).is_err()); 292 | } 293 | 294 | #[test] 295 | fn test_cross_entropy_and_kl_with_empty_array_of_floats() { 296 | let p: Array1 = array![]; 297 | let q: Array1 = array![]; 298 | assert!(p.cross_entropy(&q).unwrap_err().is_empty_input()); 299 | assert!(p.kl_divergence(&q).unwrap_err().is_empty_input()); 300 | } 301 | 302 | #[test] 303 | fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> { 304 | let p = array![1.]; 305 | let q = array![-1.]; 306 | let cross_entropy: f64 = p.cross_entropy(&q)?; 307 | let kl_divergence: f64 = p.kl_divergence(&q)?; 308 | assert!(cross_entropy.is_nan()); 309 | assert!(kl_divergence.is_nan()); 310 | Ok(()) 311 | } 312 | 313 | #[test] 314 | #[should_panic] 315 | fn test_cross_entropy_with_noisy_negative_qs() { 316 | let p = array![n64(1.)]; 317 | let q = array![n64(-1.)]; 318 | let _ = p.cross_entropy(&q); 319 | } 320 | 321 | #[test] 322 | #[should_panic] 323 | fn test_kl_with_noisy_negative_qs() { 324 | let p = array![n64(1.)]; 325 | let q = array![n64(-1.)]; 326 | let _ = p.kl_divergence(&q); 327 | } 328 | 329 | #[test] 330 | fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> { 331 | let p = array![0., 0.]; 332 | let q = array![0., 0.5]; 333 | assert_eq!(p.cross_entropy(&q)?, 0.); 334 | assert_eq!(p.kl_divergence(&q)?, 0.); 335 | Ok(()) 336 | } 337 | 338 | #[test] 339 | fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership( 340 | ) -> Result<(), MultiInputError> { 341 | let p = array![0.5, 0.5]; 342 | let mut q = array![0.5, 0.]; 343 | assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY); 344 | assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY); 345 | Ok(()) 346 | } 347 | 348 | #[test] 349 | fn test_cross_entropy() -> Result<(), MultiInputError> { 350 | // Arrays of probability values - normalized and positive. 
351 | let p: Array1 = array![ 352 | 0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189, 353 | 0.0471258, 0.05594036, 0.01630048, 0.07085162, 0.05365855, 0.01959158, 0.05020174, 354 | 0.03801479, 0.00092234, 0.08515856, 0.00580683, 0.0156542, 0.0860375, 0.0724246, 355 | 0.00727477, 0.01004402, 0.01854399, 0.03504082, 356 | ]; 357 | let q: Array1 = array![ 358 | 0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, 0.05604812, 359 | 0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, 0.0625685, 0.07381292, 360 | 0.05489067, 0.01385491, 0.03639174, 0.00511611, 0.05700415, 0.05183825, 0.06703064, 361 | 0.01813342, 0.0007763, 0.0735472, 0.05857833, 362 | ]; 363 | // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q) 364 | let expected_cross_entropy = 3.385347705020779; 365 | 366 | assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6); 367 | Ok(()) 368 | } 369 | 370 | #[test] 371 | fn test_kl() -> Result<(), MultiInputError> { 372 | // Arrays of probability values - normalized and positive. 
373 | let p: Array1 = array![ 374 | 0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516, 375 | 0.02213802, 0.02745017, 0.02163975, 0.0324602, 0.03622766, 0.00782343, 0.00222498, 376 | 0.03028156, 0.02346124, 0.00071105, 0.00794496, 0.0127609, 0.02899124, 0.01281487, 377 | 0.0230803, 0.01531864, 0.00518158, 0.02233383, 0.0220279, 0.03196097, 0.03710063, 378 | 0.01817856, 0.03524661, 0.02902393, 0.00853364, 0.01255615, 0.03556958, 0.00400151, 379 | 0.01335932, 0.01864965, 0.02371322, 0.02026543, 0.0035375, 0.01988341, 0.02621831, 380 | 0.03564644, 0.01389121, 0.03151622, 0.03195532, 0.00717521, 0.03547256, 0.00371394, 381 | 0.01108706, 382 | ]; 383 | let q: Array1 = array![ 384 | 0.02038386, 0.03143914, 0.02630206, 0.0171595, 0.0067072, 0.00911324, 0.02635717, 385 | 0.01269113, 0.0302361, 0.02243133, 0.01902902, 0.01297185, 0.02118908, 0.03309548, 386 | 0.01266687, 0.0184529, 0.01830936, 0.03430437, 0.02898924, 0.02238251, 0.0139771, 387 | 0.01879774, 0.02396583, 0.03019978, 0.01421278, 0.02078981, 0.03542451, 0.02887438, 388 | 0.01261783, 0.01014241, 0.03263407, 0.0095969, 0.01923903, 0.0051315, 0.00924686, 389 | 0.00148845, 0.00341391, 0.01480373, 0.01920798, 0.03519871, 0.03315135, 0.02099325, 390 | 0.03251755, 0.00337555, 0.03432165, 0.01763753, 0.02038337, 0.01923023, 0.01438769, 391 | 0.02082707, 392 | ]; 393 | // Computed using scipy.stats.entropy(p, q) 394 | let expected_kl = 0.3555862567800096; 395 | 396 | assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6); 397 | Ok(()) 398 | } 399 | } 400 | -------------------------------------------------------------------------------- /src/errors.rs: -------------------------------------------------------------------------------- 1 | //! Custom errors returned from our methods and functions. 2 | use noisy_float::types::N64; 3 | use std::error::Error; 4 | use std::fmt; 5 | 6 | /// An error that indicates that the input array was empty. 
7 | #[derive(Clone, Debug, Eq, PartialEq)] 8 | pub struct EmptyInput; 9 | 10 | impl fmt::Display for EmptyInput { 11 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 12 | write!(f, "Empty input.") 13 | } 14 | } 15 | 16 | impl Error for EmptyInput {} 17 | 18 | /// An error computing a minimum/maximum value. 19 | #[derive(Clone, Debug, Eq, PartialEq)] 20 | pub enum MinMaxError { 21 | /// The input was empty. 22 | EmptyInput, 23 | /// The ordering between a tested pair of values was undefined. 24 | UndefinedOrder, 25 | } 26 | 27 | impl fmt::Display for MinMaxError { 28 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 29 | match self { 30 | MinMaxError::EmptyInput => write!(f, "Empty input."), 31 | MinMaxError::UndefinedOrder => { 32 | write!(f, "Undefined ordering between a tested pair of values.") 33 | } 34 | } 35 | } 36 | } 37 | 38 | impl Error for MinMaxError {} 39 | 40 | impl From for MinMaxError { 41 | fn from(_: EmptyInput) -> MinMaxError { 42 | MinMaxError::EmptyInput 43 | } 44 | } 45 | 46 | /// An error used by methods and functions that take two arrays as argument and 47 | /// expect them to have exactly the same shape 48 | /// (e.g. `ShapeMismatch` is raised when `a.shape() == b.shape()` evaluates to `False`). 49 | #[derive(Clone, Debug, PartialEq)] 50 | pub struct ShapeMismatch { 51 | pub first_shape: Vec, 52 | pub second_shape: Vec, 53 | } 54 | 55 | impl fmt::Display for ShapeMismatch { 56 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 57 | write!( 58 | f, 59 | "Array shapes do not match: {:?} and {:?}.", 60 | self.first_shape, self.second_shape 61 | ) 62 | } 63 | } 64 | 65 | impl Error for ShapeMismatch {} 66 | 67 | /// An error for methods that take multiple non-empty array inputs. 68 | #[derive(Clone, Debug, PartialEq)] 69 | pub enum MultiInputError { 70 | /// One or more of the arrays were empty. 71 | EmptyInput, 72 | /// The arrays did not have the same shape. 
73 | ShapeMismatch(ShapeMismatch), 74 | } 75 | 76 | impl MultiInputError { 77 | /// Returns whether `self` is the `EmptyInput` variant. 78 | pub fn is_empty_input(&self) -> bool { 79 | match self { 80 | MultiInputError::EmptyInput => true, 81 | _ => false, 82 | } 83 | } 84 | 85 | /// Returns whether `self` is the `ShapeMismatch` variant. 86 | pub fn is_shape_mismatch(&self) -> bool { 87 | match self { 88 | MultiInputError::ShapeMismatch(_) => true, 89 | _ => false, 90 | } 91 | } 92 | } 93 | 94 | impl fmt::Display for MultiInputError { 95 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 96 | match self { 97 | MultiInputError::EmptyInput => write!(f, "Empty input."), 98 | MultiInputError::ShapeMismatch(e) => write!(f, "Shape mismatch: {}", e), 99 | } 100 | } 101 | } 102 | 103 | impl Error for MultiInputError {} 104 | 105 | impl From for MultiInputError { 106 | fn from(_: EmptyInput) -> Self { 107 | MultiInputError::EmptyInput 108 | } 109 | } 110 | 111 | impl From for MultiInputError { 112 | fn from(err: ShapeMismatch) -> Self { 113 | MultiInputError::ShapeMismatch(err) 114 | } 115 | } 116 | 117 | /// An error computing a quantile. 118 | #[derive(Debug, Clone, Eq, PartialEq)] 119 | pub enum QuantileError { 120 | /// The input was empty. 121 | EmptyInput, 122 | /// The `q` was not between `0.` and `1.` (inclusive). 123 | InvalidQuantile(N64), 124 | } 125 | 126 | impl fmt::Display for QuantileError { 127 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 128 | match self { 129 | QuantileError::EmptyInput => write!(f, "Empty input."), 130 | QuantileError::InvalidQuantile(q) => { 131 | write!(f, "{:} is not between 0. and 1. 
(inclusive).", q) 132 | } 133 | } 134 | } 135 | } 136 | 137 | impl Error for QuantileError {} 138 | 139 | impl From for QuantileError { 140 | fn from(_: EmptyInput) -> QuantileError { 141 | QuantileError::EmptyInput 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /src/histogram/bins.rs: -------------------------------------------------------------------------------- 1 | #![warn(missing_docs, clippy::all, clippy::pedantic)] 2 | 3 | use ndarray::prelude::*; 4 | use std::ops::{Index, Range}; 5 | 6 | /// A sorted collection of type `A` elements used to represent the boundaries of intervals, i.e. 7 | /// [`Bins`] on a 1-dimensional axis. 8 | /// 9 | /// **Note** that all intervals are left-closed and right-open. See examples below. 10 | /// 11 | /// # Examples 12 | /// 13 | /// ``` 14 | /// use ndarray_stats::histogram::{Bins, Edges}; 15 | /// use noisy_float::types::n64; 16 | /// 17 | /// let unit_edges = Edges::from(vec![n64(0.), n64(1.)]); 18 | /// let unit_interval = Bins::new(unit_edges); 19 | /// // left-closed 20 | /// assert_eq!( 21 | /// unit_interval.range_of(&n64(0.)).unwrap(), 22 | /// n64(0.)..n64(1.), 23 | /// ); 24 | /// // right-open 25 | /// assert_eq!( 26 | /// unit_interval.range_of(&n64(1.)), 27 | /// None 28 | /// ); 29 | /// ``` 30 | /// 31 | /// [`Bins`]: struct.Bins.html 32 | #[derive(Clone, Debug, Eq, PartialEq)] 33 | pub struct Edges { 34 | edges: Vec, 35 | } 36 | 37 | impl From> for Edges { 38 | /// Converts a `Vec` into an `Edges`, consuming the edges. 39 | /// The vector will be sorted in increasing order using an unstable sorting algorithm, with 40 | /// duplicates removed. 41 | /// 42 | /// # Current implementation 43 | /// 44 | /// The current sorting algorithm is the same as [`std::slice::sort_unstable()`][sort], 45 | /// which is based on [pattern-defeating quicksort][pdqsort]. 
46 | /// 47 | /// This sort is unstable (i.e., may reorder equal elements), in-place (i.e., does not allocate) 48 | /// , and O(n log n) worst-case. 49 | /// 50 | /// # Examples 51 | /// 52 | /// ``` 53 | /// use ndarray::array; 54 | /// use ndarray_stats::histogram::Edges; 55 | /// 56 | /// let edges = Edges::from(array![1, 15, 10, 10, 20]); 57 | /// // The array gets sorted! 58 | /// assert_eq!( 59 | /// edges[2], 60 | /// 15 61 | /// ); 62 | /// ``` 63 | /// 64 | /// [sort]: https://doc.rust-lang.org/stable/std/primitive.slice.html#method.sort_unstable 65 | /// [pdqsort]: https://github.com/orlp/pdqsort 66 | fn from(mut edges: Vec) -> Self { 67 | // sort the array in-place 68 | edges.sort_unstable(); 69 | // remove duplicates 70 | edges.dedup(); 71 | Edges { edges } 72 | } 73 | } 74 | 75 | impl From> for Edges { 76 | /// Converts an `Array1` into an `Edges`, consuming the 1-dimensional array. 77 | /// The array will be sorted in increasing order using an unstable sorting algorithm, with 78 | /// duplicates removed. 79 | /// 80 | /// # Current implementation 81 | /// 82 | /// The current sorting algorithm is the same as [`std::slice::sort_unstable()`][sort], 83 | /// which is based on [pattern-defeating quicksort][pdqsort]. 84 | /// 85 | /// This sort is unstable (i.e., may reorder equal elements), in-place (i.e., does not allocate) 86 | /// , and O(n log n) worst-case. 87 | /// 88 | /// # Examples 89 | /// 90 | /// ``` 91 | /// use ndarray_stats::histogram::Edges; 92 | /// 93 | /// let edges = Edges::from(vec![1, 15, 10, 20]); 94 | /// // The vec gets sorted! 
95 | /// assert_eq!( 96 | /// edges[1], 97 | /// 10 98 | /// ); 99 | /// ``` 100 | /// 101 | /// [sort]: https://doc.rust-lang.org/stable/std/primitive.slice.html#method.sort_unstable 102 | /// [pdqsort]: https://github.com/orlp/pdqsort 103 | fn from(edges: Array1) -> Self { 104 | let edges = edges.to_vec(); 105 | Self::from(edges) 106 | } 107 | } 108 | 109 | impl Index for Edges { 110 | type Output = A; 111 | 112 | /// Returns a reference to the `i`-th edge in `self`. 113 | /// 114 | /// # Panics 115 | /// 116 | /// Panics if the index `i` is out of bounds. 117 | /// 118 | /// # Examples 119 | /// 120 | /// ``` 121 | /// use ndarray_stats::histogram::Edges; 122 | /// 123 | /// let edges = Edges::from(vec![1, 5, 10, 20]); 124 | /// assert_eq!( 125 | /// edges[1], 126 | /// 5 127 | /// ); 128 | /// ``` 129 | fn index(&self, i: usize) -> &Self::Output { 130 | &self.edges[i] 131 | } 132 | } 133 | 134 | impl Edges { 135 | /// Returns the number of edges in `self`. 136 | /// 137 | /// # Examples 138 | /// 139 | /// ``` 140 | /// use ndarray_stats::histogram::Edges; 141 | /// use noisy_float::types::n64; 142 | /// 143 | /// let edges = Edges::from(vec![n64(0.), n64(1.), n64(3.)]); 144 | /// assert_eq!( 145 | /// edges.len(), 146 | /// 3 147 | /// ); 148 | /// ``` 149 | #[must_use] 150 | pub fn len(&self) -> usize { 151 | self.edges.len() 152 | } 153 | 154 | /// Returns `true` if `self` contains no edges. 155 | /// 156 | /// # Examples 157 | /// 158 | /// ``` 159 | /// use ndarray_stats::histogram::Edges; 160 | /// use noisy_float::types::{N64, n64}; 161 | /// 162 | /// let edges = Edges::::from(vec![]); 163 | /// assert_eq!(edges.is_empty(), true); 164 | /// 165 | /// let edges = Edges::from(vec![n64(0.), n64(2.), n64(5.)]); 166 | /// assert_eq!(edges.is_empty(), false); 167 | /// ``` 168 | #[must_use] 169 | pub fn is_empty(&self) -> bool { 170 | self.edges.is_empty() 171 | } 172 | 173 | /// Returns an immutable 1-dimensional array view of edges. 
174 | /// 175 | /// # Examples 176 | /// 177 | /// ``` 178 | /// use ndarray::array; 179 | /// use ndarray_stats::histogram::Edges; 180 | /// 181 | /// let edges = Edges::from(vec![0, 5, 3]); 182 | /// assert_eq!( 183 | /// edges.as_array_view(), 184 | /// array![0, 3, 5].view() 185 | /// ); 186 | /// ``` 187 | #[must_use] 188 | pub fn as_array_view(&self) -> ArrayView1<'_, A> { 189 | ArrayView1::from(&self.edges) 190 | } 191 | 192 | /// Returns indices of two consecutive `edges` in `self`, if the interval they represent 193 | /// contains the given `value`, or returns `None` otherwise. 194 | /// 195 | /// That is to say, it returns 196 | /// - `Some((left, right))`, where `left` and `right` are the indices of two consecutive edges 197 | /// in `self` and `right == left + 1`, if `self[left] <= value < self[right]`; 198 | /// - `None`, otherwise. 199 | /// 200 | /// # Examples 201 | /// 202 | /// ``` 203 | /// use ndarray_stats::histogram::Edges; 204 | /// 205 | /// let edges = Edges::from(vec![0, 2, 3]); 206 | /// // `1` is in the interval [0, 2), whose indices are (0, 1) 207 | /// assert_eq!( 208 | /// edges.indices_of(&1), 209 | /// Some((0, 1)) 210 | /// ); 211 | /// // `5` is not in any of intervals 212 | /// assert_eq!( 213 | /// edges.indices_of(&5), 214 | /// None 215 | /// ); 216 | /// ``` 217 | pub fn indices_of(&self, value: &A) -> Option<(usize, usize)> { 218 | // binary search for the correct bin 219 | let n_edges = self.len(); 220 | match self.edges.binary_search(value) { 221 | Ok(i) if i == n_edges - 1 => None, 222 | Ok(i) => Some((i, i + 1)), 223 | Err(i) => match i { 224 | 0 => None, 225 | j if j == n_edges => None, 226 | j => Some((j - 1, j)), 227 | }, 228 | } 229 | } 230 | 231 | /// Returns an iterator over the `edges` in `self`. 232 | pub fn iter(&self) -> impl Iterator { 233 | self.edges.iter() 234 | } 235 | } 236 | 237 | /// A sorted collection of non-overlapping 1-dimensional intervals. 
238 | /// 239 | /// **Note** that all intervals are left-closed and right-open. 240 | /// 241 | /// # Examples 242 | /// 243 | /// ``` 244 | /// use ndarray_stats::histogram::{Edges, Bins}; 245 | /// use noisy_float::types::n64; 246 | /// 247 | /// let edges = Edges::from(vec![n64(0.), n64(1.), n64(2.)]); 248 | /// let bins = Bins::new(edges); 249 | /// // first bin 250 | /// assert_eq!( 251 | /// bins.index(0), 252 | /// n64(0.)..n64(1.) // n64(1.) is not included in the bin! 253 | /// ); 254 | /// // second bin 255 | /// assert_eq!( 256 | /// bins.index(1), 257 | /// n64(1.)..n64(2.) 258 | /// ); 259 | /// ``` 260 | #[derive(Clone, Debug, Eq, PartialEq)] 261 | pub struct Bins { 262 | edges: Edges, 263 | } 264 | 265 | impl Bins { 266 | /// Returns a `Bins` instance where each bin corresponds to two consecutive members of the given 267 | /// [`Edges`], consuming the edges. 268 | /// 269 | /// [`Edges`]: struct.Edges.html 270 | #[must_use] 271 | pub fn new(edges: Edges) -> Self { 272 | Bins { edges } 273 | } 274 | 275 | /// Returns the number of bins in `self`. 276 | /// 277 | /// # Examples 278 | /// 279 | /// ``` 280 | /// use ndarray_stats::histogram::{Edges, Bins}; 281 | /// use noisy_float::types::n64; 282 | /// 283 | /// let edges = Edges::from(vec![n64(0.), n64(1.), n64(2.)]); 284 | /// let bins = Bins::new(edges); 285 | /// assert_eq!( 286 | /// bins.len(), 287 | /// 2 288 | /// ); 289 | /// ``` 290 | #[must_use] 291 | pub fn len(&self) -> usize { 292 | match self.edges.len() { 293 | 0 => 0, 294 | n => n - 1, 295 | } 296 | } 297 | 298 | /// Returns `true` if the number of bins is zero, i.e. if the number of edges is 0 or 1. 
299 | /// 300 | /// # Examples 301 | /// 302 | /// ``` 303 | /// use ndarray_stats::histogram::{Edges, Bins}; 304 | /// use noisy_float::types::{N64, n64}; 305 | /// 306 | /// // At least 2 edges is needed to represent 1 interval 307 | /// let edges = Edges::from(vec![n64(0.), n64(1.), n64(3.)]); 308 | /// let bins = Bins::new(edges); 309 | /// assert_eq!(bins.is_empty(), false); 310 | /// 311 | /// // No valid interval == Empty 312 | /// let edges = Edges::::from(vec![]); 313 | /// let bins = Bins::new(edges); 314 | /// assert_eq!(bins.is_empty(), true); 315 | /// let edges = Edges::from(vec![n64(0.)]); 316 | /// let bins = Bins::new(edges); 317 | /// assert_eq!(bins.is_empty(), true); 318 | /// ``` 319 | #[must_use] 320 | pub fn is_empty(&self) -> bool { 321 | self.len() == 0 322 | } 323 | 324 | /// Returns the index of the bin in `self` that contains the given `value`, 325 | /// or returns `None` if `value` does not belong to any bins in `self`. 326 | /// 327 | /// # Examples 328 | /// 329 | /// Basic usage: 330 | /// 331 | /// ``` 332 | /// use ndarray_stats::histogram::{Edges, Bins}; 333 | /// 334 | /// let edges = Edges::from(vec![0, 2, 4, 6]); 335 | /// let bins = Bins::new(edges); 336 | /// let value = 1; 337 | /// // The first bin [0, 2) contains `1` 338 | /// assert_eq!( 339 | /// bins.index_of(&1), 340 | /// Some(0) 341 | /// ); 342 | /// // No bin contains 100 343 | /// assert_eq!( 344 | /// bins.index_of(&100), 345 | /// None 346 | /// ) 347 | /// ``` 348 | /// 349 | /// Chaining [`Bins::index`] and [`Bins::index_of`] to get the boundaries of the bin containing 350 | /// the value: 351 | /// 352 | /// ``` 353 | /// # use ndarray_stats::histogram::{Edges, Bins}; 354 | /// # let edges = Edges::from(vec![0, 2, 4, 6]); 355 | /// # let bins = Bins::new(edges); 356 | /// # let value = 1; 357 | /// assert_eq!( 358 | /// // using `Option::map` to avoid panic on index out-of-bounds 359 | /// bins.index_of(&1).map(|i| bins.index(i)), 360 | /// Some(0..2) 361 | 
/// ); 362 | /// ``` 363 | pub fn index_of(&self, value: &A) -> Option { 364 | self.edges.indices_of(value).map(|t| t.0) 365 | } 366 | 367 | /// Returns a range as the bin which contains the given `value`, or returns `None` otherwise. 368 | /// 369 | /// # Examples 370 | /// 371 | /// ``` 372 | /// use ndarray_stats::histogram::{Edges, Bins}; 373 | /// 374 | /// let edges = Edges::from(vec![0, 2, 4, 6]); 375 | /// let bins = Bins::new(edges); 376 | /// // [0, 2) contains `1` 377 | /// assert_eq!( 378 | /// bins.range_of(&1), 379 | /// Some(0..2) 380 | /// ); 381 | /// // `10` is not in any interval 382 | /// assert_eq!( 383 | /// bins.range_of(&10), 384 | /// None 385 | /// ); 386 | /// ``` 387 | pub fn range_of(&self, value: &A) -> Option> 388 | where 389 | A: Clone, 390 | { 391 | let edges_indexes = self.edges.indices_of(value); 392 | edges_indexes.map(|(left, right)| Range { 393 | start: self.edges[left].clone(), 394 | end: self.edges[right].clone(), 395 | }) 396 | } 397 | 398 | /// Returns a range as the bin at the given `index` position. 399 | /// 400 | /// # Panics 401 | /// 402 | /// Panics if `index` is out of bounds. 403 | /// 404 | /// # Examples 405 | /// 406 | /// ``` 407 | /// use ndarray_stats::histogram::{Edges, Bins}; 408 | /// 409 | /// let edges = Edges::from(vec![1, 5, 10, 20]); 410 | /// let bins = Bins::new(edges); 411 | /// assert_eq!( 412 | /// bins.index(1), 413 | /// 5..10 414 | /// ); 415 | /// ``` 416 | #[must_use] 417 | pub fn index(&self, index: usize) -> Range 418 | where 419 | A: Clone, 420 | { 421 | // It was not possible to implement this functionality 422 | // using the `Index` trait unless we were willing to 423 | // allocate a `Vec>` in the struct. 424 | // Index, in fact, forces you to return a reference. 
425 | Range { 426 | start: self.edges[index].clone(), 427 | end: self.edges[index + 1].clone(), 428 | } 429 | } 430 | } 431 | 432 | #[cfg(test)] 433 | mod edges_tests { 434 | use super::{Array1, Edges}; 435 | use quickcheck_macros::quickcheck; 436 | use std::collections::BTreeSet; 437 | use std::iter::FromIterator; 438 | 439 | #[quickcheck] 440 | fn check_sorted_from_vec(v: Vec) -> bool { 441 | let edges = Edges::from(v); 442 | let n = edges.len(); 443 | for i in 1..n { 444 | if edges[i - 1] > edges[i] { 445 | return false; 446 | } 447 | } 448 | true 449 | } 450 | 451 | #[quickcheck] 452 | fn check_sorted_from_array(v: Vec) -> bool { 453 | let a = Array1::from(v); 454 | let edges = Edges::from(a); 455 | let n = edges.len(); 456 | for i in 1..n { 457 | if edges[i - 1] > edges[i] { 458 | return false; 459 | } 460 | } 461 | true 462 | } 463 | 464 | #[quickcheck] 465 | fn edges_are_right_open(v: Vec) -> bool { 466 | let edges = Edges::from(v); 467 | let view = edges.as_array_view(); 468 | if view.is_empty() { 469 | true 470 | } else { 471 | let last = view[view.len() - 1]; 472 | edges.indices_of(&last).is_none() 473 | } 474 | } 475 | 476 | #[quickcheck] 477 | fn edges_are_left_closed(v: Vec) -> bool { 478 | let edges = Edges::from(v); 479 | if let 1 = edges.len() { 480 | true 481 | } else { 482 | let view = edges.as_array_view(); 483 | if view.is_empty() { 484 | true 485 | } else { 486 | let first = view[0]; 487 | edges.indices_of(&first).is_some() 488 | } 489 | } 490 | } 491 | 492 | #[quickcheck] 493 | #[allow(clippy::needless_pass_by_value)] 494 | fn edges_are_deduped(v: Vec) -> bool { 495 | let unique_elements = BTreeSet::from_iter(v.iter()); 496 | let edges = Edges::from(v.clone()); 497 | let view = edges.as_array_view(); 498 | let unique_edges = BTreeSet::from_iter(view.iter()); 499 | unique_edges == unique_elements 500 | } 501 | } 502 | 503 | #[cfg(test)] 504 | mod bins_tests { 505 | use super::{Bins, Edges}; 506 | 507 | #[test] 508 | #[should_panic] 509 | 
#[allow(unused_must_use)] 510 | fn get_panics_for_out_of_bounds_indexes() { 511 | let edges = Edges::from(vec![0]); 512 | let bins = Bins::new(edges); 513 | // we need at least two edges to make a valid bin! 514 | bins.index(0); 515 | } 516 | } 517 | -------------------------------------------------------------------------------- /src/histogram/errors.rs: -------------------------------------------------------------------------------- 1 | use crate::errors::{EmptyInput, MinMaxError}; 2 | use std::error; 3 | use std::fmt; 4 | 5 | /// Error to denote that no bin has been found for a certain observation. 6 | #[derive(Debug, Clone)] 7 | pub struct BinNotFound; 8 | 9 | impl fmt::Display for BinNotFound { 10 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 11 | write!(f, "No bin has been found.") 12 | } 13 | } 14 | 15 | impl error::Error for BinNotFound { 16 | fn description(&self) -> &str { 17 | "No bin has been found." 18 | } 19 | } 20 | 21 | /// Error computing the set of histogram bins. 22 | #[derive(Debug, Clone)] 23 | pub enum BinsBuildError { 24 | /// The input array was empty. 25 | EmptyInput, 26 | /// The strategy for computing appropriate bins failed. 27 | Strategy, 28 | #[doc(hidden)] 29 | __NonExhaustive, 30 | } 31 | 32 | impl BinsBuildError { 33 | /// Returns whether `self` is the `EmptyInput` variant. 34 | pub fn is_empty_input(&self) -> bool { 35 | match self { 36 | BinsBuildError::EmptyInput => true, 37 | _ => false, 38 | } 39 | } 40 | 41 | /// Returns whether `self` is the `Strategy` variant. 
42 | pub fn is_strategy(&self) -> bool { 43 | match self { 44 | BinsBuildError::Strategy => true, 45 | _ => false, 46 | } 47 | } 48 | } 49 | 50 | impl fmt::Display for BinsBuildError { 51 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 52 | write!(f, "The strategy failed to determine a non-zero bin width.") 53 | } 54 | } 55 | 56 | impl error::Error for BinsBuildError { 57 | fn description(&self) -> &str { 58 | "The strategy failed to determine a non-zero bin width." 59 | } 60 | } 61 | 62 | impl From for BinsBuildError { 63 | fn from(_: EmptyInput) -> Self { 64 | BinsBuildError::EmptyInput 65 | } 66 | } 67 | 68 | impl From for BinsBuildError { 69 | fn from(err: MinMaxError) -> BinsBuildError { 70 | match err { 71 | MinMaxError::EmptyInput => BinsBuildError::EmptyInput, 72 | MinMaxError::UndefinedOrder => BinsBuildError::Strategy, 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/histogram/grid.rs: -------------------------------------------------------------------------------- 1 | #![warn(missing_docs, clippy::all, clippy::pedantic)] 2 | 3 | use super::{bins::Bins, errors::BinsBuildError, strategies::BinsBuildingStrategy}; 4 | use itertools::izip; 5 | use ndarray::{ArrayBase, Axis, Data, Ix1, Ix2}; 6 | use std::ops::Range; 7 | 8 | /// An orthogonal partition of a rectangular region in an *n*-dimensional space, e.g. 9 | /// [*a*0, *b*0) × ⋯ × [*a**n*−1, *b**n*−1), 10 | /// represented as a collection of rectangular *n*-dimensional bins. 11 | /// 12 | /// The grid is **solely determined by the Cartesian product of its projections** on each coordinate 13 | /// axis. Therefore, each element in the product set should correspond to a sub-region in the grid. 
14 | /// 15 | /// For example, this partition can be represented as a `Grid` struct: 16 | /// 17 | /// ```text 18 | /// 19 | /// g +---+-------+---+ 20 | /// | 3 | 4 | 5 | 21 | /// f +---+-------+---+ 22 | /// | | | | 23 | /// | 0 | 1 | 2 | 24 | /// | | | | 25 | /// e +---+-------+---+ 26 | /// a b c d 27 | /// 28 | /// R0: [a, b) × [e, f) 29 | /// R1: [b, c) × [e, f) 30 | /// R2: [c, d) × [e, f) 31 | /// R3: [a, b) × [f, g) 32 | /// R4: [b, d) × [f, g) 33 | /// R5: [c, d) × [f, g) 34 | /// Grid: { [a, b), [b, c), [c, d) } × { [e, f), [f, g) } == { R0, R1, R2, R3, R4, R5 } 35 | /// ``` 36 | /// 37 | /// while the next one can't: 38 | /// 39 | /// ```text 40 | /// g +---+-----+---+ 41 | /// | | 2 | 3 | 42 | /// (f) | +-----+---+ 43 | /// | 0 | | 44 | /// | | 1 | 45 | /// | | | 46 | /// e +---+-----+---+ 47 | /// a b c d 48 | /// 49 | /// R0: [a, b) × [e, g) 50 | /// R1: [b, d) × [e, f) 51 | /// R2: [b, c) × [f, g) 52 | /// R3: [c, d) × [f, g) 53 | /// // 'f', as long as 'R1', 'R2', or 'R3', doesn't appear on LHS 54 | /// // [b, c) × [e, g), [c, d) × [e, g) doesn't appear on RHS 55 | /// Grid: { [a, b), [b, c), [c, d) } × { [e, g) } != { R0, R1, R2, R3 } 56 | /// ``` 57 | /// 58 | /// # Examples 59 | /// 60 | /// Basic usage, building a `Grid` via [`GridBuilder`], with optimal grid layout determined by 61 | /// a given [`strategy`], and generating a [`histogram`]: 62 | /// 63 | /// ``` 64 | /// use ndarray::{Array, array}; 65 | /// use ndarray_stats::{ 66 | /// histogram::{strategies::Auto, Bins, Edges, Grid, GridBuilder}, 67 | /// HistogramExt, 68 | /// }; 69 | /// 70 | /// // 1-dimensional observations, as a (n_observations, n_dimension) 2-d matrix 71 | /// let observations = Array::from_shape_vec( 72 | /// (12, 1), 73 | /// vec![1, 4, 5, 2, 100, 20, 50, 65, 27, 40, 45, 23], 74 | /// ).unwrap(); 75 | /// 76 | /// // The optimal grid layout is inferred from the data, given a chosen strategy, Auto in this case 77 | /// let grid = 
GridBuilder::>::from_array(&observations).unwrap().build(); 78 | /// 79 | /// let histogram = observations.histogram(grid); 80 | /// 81 | /// let histogram_matrix = histogram.counts(); 82 | /// // Bins are left-closed, right-open! 83 | /// let expected = array![4, 3, 3, 1, 0, 1]; 84 | /// assert_eq!(histogram_matrix, expected.into_dyn()); 85 | /// ``` 86 | /// 87 | /// [`histogram`]: trait.HistogramExt.html 88 | /// [`GridBuilder`]: struct.GridBuilder.html 89 | /// [`strategy`]: strategies/index.html 90 | #[derive(Clone, Debug, Eq, PartialEq)] 91 | pub struct Grid { 92 | projections: Vec>, 93 | } 94 | 95 | impl From>> for Grid { 96 | /// Converts a `Vec>` into a `Grid`, consuming the vector of bins. 97 | /// 98 | /// The `i`-th element in `Vec>` represents the projection of the bin grid onto the 99 | /// `i`-th axis. 100 | /// 101 | /// Alternatively, a `Grid` can be built directly from data using a [`GridBuilder`]. 102 | /// 103 | /// [`GridBuilder`]: struct.GridBuilder.html 104 | fn from(projections: Vec>) -> Self { 105 | Grid { projections } 106 | } 107 | } 108 | 109 | impl Grid { 110 | /// Returns the number of dimensions of the region partitioned by the grid. 111 | /// 112 | /// # Examples 113 | /// 114 | /// ``` 115 | /// use ndarray_stats::histogram::{Edges, Bins, Grid}; 116 | /// 117 | /// let edges = Edges::from(vec![0, 1]); 118 | /// let bins = Bins::new(edges); 119 | /// let square_grid = Grid::from(vec![bins.clone(), bins.clone()]); 120 | /// 121 | /// assert_eq!(square_grid.ndim(), 2usize) 122 | /// ``` 123 | #[must_use] 124 | pub fn ndim(&self) -> usize { 125 | self.projections.len() 126 | } 127 | 128 | /// Returns the numbers of bins along each coordinate axis. 
129 | /// 130 | /// # Examples 131 | /// 132 | /// ``` 133 | /// use ndarray_stats::histogram::{Edges, Bins, Grid}; 134 | /// 135 | /// let edges_x = Edges::from(vec![0, 1]); 136 | /// let edges_y = Edges::from(vec![-1, 0, 1]); 137 | /// let bins_x = Bins::new(edges_x); 138 | /// let bins_y = Bins::new(edges_y); 139 | /// let square_grid = Grid::from(vec![bins_x, bins_y]); 140 | /// 141 | /// assert_eq!(square_grid.shape(), vec![1usize, 2usize]); 142 | /// ``` 143 | #[must_use] 144 | pub fn shape(&self) -> Vec { 145 | self.projections.iter().map(Bins::len).collect() 146 | } 147 | 148 | /// Returns the grid projections on each coordinate axis as a slice of immutable references. 149 | #[must_use] 150 | pub fn projections(&self) -> &[Bins] { 151 | &self.projections 152 | } 153 | 154 | /// Returns an `n-dimensional` index, of bins along each axis that contains the point, if one 155 | /// exists. 156 | /// 157 | /// Returns `None` if the point is outside the grid. 158 | /// 159 | /// # Panics 160 | /// 161 | /// Panics if dimensionality of the point doesn't equal the grid's. 
162 | /// 163 | /// # Examples 164 | /// 165 | /// Basic usage: 166 | /// 167 | /// ``` 168 | /// use ndarray::array; 169 | /// use ndarray_stats::histogram::{Edges, Bins, Grid}; 170 | /// use noisy_float::types::n64; 171 | /// 172 | /// let edges = Edges::from(vec![n64(-1.), n64(0.), n64(1.)]); 173 | /// let bins = Bins::new(edges); 174 | /// let square_grid = Grid::from(vec![bins.clone(), bins.clone()]); 175 | /// 176 | /// // (0., -0.7) falls in 1st and 0th bin respectively 177 | /// assert_eq!( 178 | /// square_grid.index_of(&array![n64(0.), n64(-0.7)]), 179 | /// Some(vec![1, 0]), 180 | /// ); 181 | /// // Returns `None`, as `1.` is outside the grid since bins are right-open 182 | /// assert_eq!( 183 | /// square_grid.index_of(&array![n64(0.), n64(1.)]), 184 | /// None, 185 | /// ); 186 | /// ``` 187 | /// 188 | /// A panic upon dimensionality mismatch: 189 | /// 190 | /// ```should_panic 191 | /// # use ndarray::array; 192 | /// # use ndarray_stats::histogram::{Edges, Bins, Grid}; 193 | /// # use noisy_float::types::n64; 194 | /// # let edges = Edges::from(vec![n64(-1.), n64(0.), n64(1.)]); 195 | /// # let bins = Bins::new(edges); 196 | /// # let square_grid = Grid::from(vec![bins.clone(), bins.clone()]); 197 | /// // the point has 3 dimensions, the grid expected 2 dimensions 198 | /// assert_eq!( 199 | /// square_grid.index_of(&array![n64(0.), n64(-0.7), n64(0.5)]), 200 | /// Some(vec![1, 0, 1]), 201 | /// ); 202 | /// ``` 203 | pub fn index_of(&self, point: &ArrayBase) -> Option> 204 | where 205 | S: Data, 206 | { 207 | assert_eq!( 208 | point.len(), 209 | self.ndim(), 210 | "Dimension mismatch: the point has {:?} dimensions, the grid \ 211 | expected {:?} dimensions.", 212 | point.len(), 213 | self.ndim() 214 | ); 215 | point 216 | .iter() 217 | .zip(self.projections.iter()) 218 | .map(|(v, e)| e.index_of(v)) 219 | .collect() 220 | } 221 | } 222 | 223 | impl Grid { 224 | /// Given an `n`-dimensional index, `i = (i_0, ..., i_{n-1})`, returns an 
`n`-dimensional bin, 225 | /// `I_{i_0} x ... x I_{i_{n-1}}`, where `I_{i_j}` is the `i_j`-th interval on the `j`-th 226 | /// projection of the grid on the coordinate axes. 227 | /// 228 | /// # Panics 229 | /// 230 | /// Panics if at least one in the index, `(i_0, ..., i_{n-1})`, is out of bounds on the 231 | /// corresponding coordinate axis, i.e. if there exists `j` s.t. 232 | /// `i_j >= self.projections[j].len()`. 233 | /// 234 | /// # Examples 235 | /// 236 | /// Basic usage: 237 | /// 238 | /// ``` 239 | /// use ndarray::array; 240 | /// use ndarray_stats::histogram::{Edges, Bins, Grid}; 241 | /// 242 | /// let edges_x = Edges::from(vec![0, 1]); 243 | /// let edges_y = Edges::from(vec![2, 3, 4]); 244 | /// let bins_x = Bins::new(edges_x); 245 | /// let bins_y = Bins::new(edges_y); 246 | /// let square_grid = Grid::from(vec![bins_x, bins_y]); 247 | /// 248 | /// // Query the 0-th bin on x-axis, and 1-st bin on y-axis 249 | /// assert_eq!( 250 | /// square_grid.index(&[0, 1]), 251 | /// vec![0..1, 3..4], 252 | /// ); 253 | /// ``` 254 | /// 255 | /// A panic upon out-of-bounds: 256 | /// 257 | /// ```should_panic 258 | /// # use ndarray::array; 259 | /// # use ndarray_stats::histogram::{Edges, Bins, Grid}; 260 | /// # let edges_x = Edges::from(vec![0, 1]); 261 | /// # let edges_y = Edges::from(vec![2, 3, 4]); 262 | /// # let bins_x = Bins::new(edges_x); 263 | /// # let bins_y = Bins::new(edges_y); 264 | /// # let square_grid = Grid::from(vec![bins_x, bins_y]); 265 | /// // out-of-bound on y-axis 266 | /// assert_eq!( 267 | /// square_grid.index(&[0, 2]), 268 | /// vec![0..1, 3..4], 269 | /// ); 270 | /// ``` 271 | #[must_use] 272 | pub fn index(&self, index: &[usize]) -> Vec> { 273 | assert_eq!( 274 | index.len(), 275 | self.ndim(), 276 | "Dimension mismatch: the index has {0:?} dimensions, the grid \ 277 | expected {1:?} dimensions.", 278 | index.len(), 279 | self.ndim() 280 | ); 281 | izip!(&self.projections, index) 282 | .map(|(bins, &i)| bins.index(i)) 
283 | .collect() 284 | } 285 | } 286 | 287 | /// A builder used to create [`Grid`] instances for [`histogram`] computations. 288 | /// 289 | /// # Examples 290 | /// 291 | /// Basic usage, creating a `Grid` with some observations and a given [`strategy`]: 292 | /// 293 | /// ``` 294 | /// use ndarray::Array; 295 | /// use ndarray_stats::histogram::{strategies::Auto, Bins, Edges, Grid, GridBuilder}; 296 | /// 297 | /// // 1-dimensional observations, as a (n_observations, n_dimension) 2-d matrix 298 | /// let observations = Array::from_shape_vec( 299 | /// (12, 1), 300 | /// vec![1, 4, 5, 2, 100, 20, 50, 65, 27, 40, 45, 23], 301 | /// ).unwrap(); 302 | /// 303 | /// // The optimal grid layout is inferred from the data, given a chosen strategy, Auto in this case 304 | /// let grid = GridBuilder::>::from_array(&observations).unwrap().build(); 305 | /// // Equivalently, build a Grid directly 306 | /// let expected_grid = Grid::from(vec![Bins::new(Edges::from(vec![1, 20, 39, 58, 77, 96, 115]))]); 307 | /// 308 | /// assert_eq!(grid, expected_grid); 309 | /// ``` 310 | /// 311 | /// [`Grid`]: struct.Grid.html 312 | /// [`histogram`]: trait.HistogramExt.html 313 | /// [`strategy`]: strategies/index.html 314 | #[allow(clippy::module_name_repetitions)] 315 | pub struct GridBuilder { 316 | bin_builders: Vec, 317 | } 318 | 319 | impl GridBuilder 320 | where 321 | A: Ord, 322 | B: BinsBuildingStrategy, 323 | { 324 | /// Returns a `GridBuilder` for building a [`Grid`] with a given [`strategy`] and some 325 | /// observations in a 2-dimensionalarray with shape `(n_observations, n_dimension)`. 326 | /// 327 | /// # Errors 328 | /// 329 | /// It returns [`BinsBuildError`] if it is not possible to build a [`Grid`] given 330 | /// the observed data according to the chosen [`strategy`]. 331 | /// 332 | /// # Examples 333 | /// 334 | /// See [Trait-level examples] for basic usage. 
335 | /// 336 | /// [`Grid`]: struct.Grid.html 337 | /// [`strategy`]: strategies/index.html 338 | /// [`BinsBuildError`]: errors/enum.BinsBuildError.html 339 | /// [Trait-level examples]: struct.GridBuilder.html#examples 340 | pub fn from_array(array: &ArrayBase) -> Result 341 | where 342 | S: Data, 343 | { 344 | let bin_builders = array 345 | .axis_iter(Axis(1)) 346 | .map(|data| B::from_array(&data)) 347 | .collect::, BinsBuildError>>()?; 348 | Ok(Self { bin_builders }) 349 | } 350 | 351 | /// Returns a [`Grid`] instance, with building parameters infered in [`from_array`], according 352 | /// to the specified [`strategy`] and observations provided. 353 | /// 354 | /// # Examples 355 | /// 356 | /// See [Trait-level examples] for basic usage. 357 | /// 358 | /// [`Grid`]: struct.Grid.html 359 | /// [`strategy`]: strategies/index.html 360 | /// [`from_array`]: #method.from_array.html 361 | #[must_use] 362 | pub fn build(&self) -> Grid { 363 | let projections: Vec<_> = self.bin_builders.iter().map(|b| b.build()).collect(); 364 | Grid::from(projections) 365 | } 366 | } 367 | -------------------------------------------------------------------------------- /src/histogram/histograms.rs: -------------------------------------------------------------------------------- 1 | use super::errors::BinNotFound; 2 | use super::grid::Grid; 3 | use ndarray::prelude::*; 4 | use ndarray::Data; 5 | 6 | /// Histogram data structure. 7 | pub struct Histogram { 8 | counts: ArrayD, 9 | grid: Grid, 10 | } 11 | 12 | impl Histogram { 13 | /// Returns a new instance of Histogram given a [`Grid`]. 14 | /// 15 | /// [`Grid`]: struct.Grid.html 16 | pub fn new(grid: Grid) -> Self { 17 | let counts = ArrayD::zeros(grid.shape()); 18 | Histogram { counts, grid } 19 | } 20 | 21 | /// Adds a single observation to the histogram. 22 | /// 23 | /// **Panics** if dimensions do not match: `self.ndim() != observation.len()`. 
24 | /// 25 | /// # Example: 26 | /// ``` 27 | /// use ndarray::array; 28 | /// use ndarray_stats::histogram::{Edges, Bins, Histogram, Grid}; 29 | /// use noisy_float::types::n64; 30 | /// 31 | /// let edges = Edges::from(vec![n64(-1.), n64(0.), n64(1.)]); 32 | /// let bins = Bins::new(edges); 33 | /// let square_grid = Grid::from(vec![bins.clone(), bins.clone()]); 34 | /// let mut histogram = Histogram::new(square_grid); 35 | /// 36 | /// let observation = array![n64(0.5), n64(0.6)]; 37 | /// 38 | /// histogram.add_observation(&observation)?; 39 | /// 40 | /// let histogram_matrix = histogram.counts(); 41 | /// let expected = array![ 42 | /// [0, 0], 43 | /// [0, 1], 44 | /// ]; 45 | /// assert_eq!(histogram_matrix, expected.into_dyn()); 46 | /// # Ok::<(), Box>(()) 47 | /// ``` 48 | pub fn add_observation(&mut self, observation: &ArrayBase) -> Result<(), BinNotFound> 49 | where 50 | S: Data, 51 | { 52 | match self.grid.index_of(observation) { 53 | Some(bin_index) => { 54 | self.counts[&*bin_index] += 1; 55 | Ok(()) 56 | } 57 | None => Err(BinNotFound), 58 | } 59 | } 60 | 61 | /// Returns the number of dimensions of the space the histogram is covering. 62 | pub fn ndim(&self) -> usize { 63 | debug_assert_eq!(self.counts.ndim(), self.grid.ndim()); 64 | self.counts.ndim() 65 | } 66 | 67 | /// Borrows a view on the histogram counts matrix. 68 | pub fn counts(&self) -> ArrayViewD<'_, usize> { 69 | self.counts.view() 70 | } 71 | 72 | /// Borrows an immutable reference to the histogram grid. 73 | pub fn grid(&self) -> &Grid { 74 | &self.grid 75 | } 76 | } 77 | 78 | /// Extension trait for `ArrayBase` providing methods to compute histograms. 79 | pub trait HistogramExt 80 | where 81 | S: Data, 82 | { 83 | /// Returns the [histogram](https://en.wikipedia.org/wiki/Histogram) 84 | /// for a 2-dimensional array of points `M`. 
85 | /// 86 | /// Let `(n, d)` be the shape of `M`: 87 | /// - `n` is the number of points; 88 | /// - `d` is the number of dimensions of the space those points belong to. 89 | /// It follows that every column in `M` is a `d`-dimensional point. 90 | /// 91 | /// For example: a (3, 4) matrix `M` is a collection of 3 points in a 92 | /// 4-dimensional space. 93 | /// 94 | /// Important: points outside the grid are ignored! 95 | /// 96 | /// **Panics** if `d` is different from `grid.ndim()`. 97 | /// 98 | /// # Example: 99 | /// 100 | /// ``` 101 | /// use ndarray::array; 102 | /// use ndarray_stats::{ 103 | /// HistogramExt, 104 | /// histogram::{ 105 | /// Histogram, Grid, GridBuilder, 106 | /// Edges, Bins, 107 | /// strategies::Sqrt}, 108 | /// }; 109 | /// use noisy_float::types::{N64, n64}; 110 | /// 111 | /// let observations = array![ 112 | /// [n64(1.), n64(0.5)], 113 | /// [n64(-0.5), n64(1.)], 114 | /// [n64(-1.), n64(-0.5)], 115 | /// [n64(0.5), n64(-1.)] 116 | /// ]; 117 | /// let grid = GridBuilder::>::from_array(&observations).unwrap().build(); 118 | /// let expected_grid = Grid::from( 119 | /// vec![ 120 | /// Bins::new(Edges::from(vec![n64(-1.), n64(0.), n64(1.), n64(2.)])), 121 | /// Bins::new(Edges::from(vec![n64(-1.), n64(0.), n64(1.), n64(2.)])), 122 | /// ] 123 | /// ); 124 | /// assert_eq!(grid, expected_grid); 125 | /// 126 | /// let histogram = observations.histogram(grid); 127 | /// 128 | /// let histogram_matrix = histogram.counts(); 129 | /// // Bins are left inclusive, right exclusive! 130 | /// let expected = array![ 131 | /// [1, 0, 1], 132 | /// [1, 0, 0], 133 | /// [0, 1, 0], 134 | /// ]; 135 | /// assert_eq!(histogram_matrix, expected.into_dyn()); 136 | /// ``` 137 | fn histogram(&self, grid: Grid) -> Histogram 138 | where 139 | A: Ord; 140 | 141 | private_decl! 
{} 142 | } 143 | 144 | impl HistogramExt for ArrayBase 145 | where 146 | S: Data, 147 | A: Ord, 148 | { 149 | fn histogram(&self, grid: Grid) -> Histogram { 150 | let mut histogram = Histogram::new(grid); 151 | for point in self.axis_iter(Axis(0)) { 152 | let _ = histogram.add_observation(&point); 153 | } 154 | histogram 155 | } 156 | 157 | private_impl! {} 158 | } 159 | -------------------------------------------------------------------------------- /src/histogram/mod.rs: -------------------------------------------------------------------------------- 1 | //! Histogram functionalities. 2 | pub use self::bins::{Bins, Edges}; 3 | pub use self::grid::{Grid, GridBuilder}; 4 | pub use self::histograms::{Histogram, HistogramExt}; 5 | 6 | mod bins; 7 | pub mod errors; 8 | mod grid; 9 | mod histograms; 10 | pub mod strategies; 11 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! The [`ndarray-stats`] crate exposes statistical routines for `ArrayBase`, 2 | //! the *n*-dimensional array data structure provided by [`ndarray`]. 3 | //! 4 | //! Currently available routines include: 5 | //! - [order statistics] (minimum, maximum, median, quantiles, etc.); 6 | //! - [summary statistics] (mean, skewness, kurtosis, central moments, etc.) 7 | //! - [partitioning]; 8 | //! - [correlation analysis] (covariance, pearson correlation); 9 | //! - [measures from information theory] (entropy, KL divergence, etc.); 10 | //! - [measures of deviation] (count equal, L1, L2 distances, mean squared err etc.) 11 | //! - [histogram computation]. 12 | //! 13 | //! Please feel free to contribute new functionality! A roadmap can be found [here]. 14 | //! 15 | //! Our work is inspired by other existing statistical packages such as 16 | //! [`NumPy`] (Python) and [`StatsBase.jl`] (Julia) - any contribution bringing us closer to 17 | //! 
feature parity is more than welcome! 18 | //! 19 | //! [`ndarray-stats`]: https://github.com/rust-ndarray/ndarray-stats/ 20 | //! [`ndarray`]: https://github.com/rust-ndarray/ndarray 21 | //! [order statistics]: trait.QuantileExt.html 22 | //! [partitioning]: trait.Sort1dExt.html 23 | //! [summary statistics]: trait.SummaryStatisticsExt.html 24 | //! [correlation analysis]: trait.CorrelationExt.html 25 | //! [measures of deviation]: trait.DeviationExt.html 26 | //! [measures from information theory]: trait.EntropyExt.html 27 | //! [histogram computation]: histogram/index.html 28 | //! [here]: https://github.com/rust-ndarray/ndarray-stats/issues/1 29 | //! [`NumPy`]: https://docs.scipy.org/doc/numpy-1.14.1/reference/routines.statistics.html 30 | //! [`StatsBase.jl`]: https://juliastats.github.io/StatsBase.jl/latest/ 31 | 32 | pub use crate::correlation::CorrelationExt; 33 | pub use crate::deviation::DeviationExt; 34 | pub use crate::entropy::EntropyExt; 35 | pub use crate::histogram::HistogramExt; 36 | pub use crate::maybe_nan::{MaybeNan, MaybeNanExt}; 37 | pub use crate::quantile::{interpolate, Quantile1dExt, QuantileExt}; 38 | pub use crate::sort::Sort1dExt; 39 | pub use crate::summary_statistics::SummaryStatisticsExt; 40 | 41 | #[cfg(test)] 42 | #[macro_use] 43 | extern crate approx; 44 | 45 | #[macro_use] 46 | mod multi_input_error_macros { 47 | macro_rules! return_err_if_empty { 48 | ($arr:expr) => { 49 | if $arr.len() == 0 { 50 | return Err(MultiInputError::EmptyInput); 51 | } 52 | }; 53 | } 54 | macro_rules! 
return_err_unless_same_shape { 55 | ($arr_a:expr, $arr_b:expr) => { 56 | use crate::errors::{MultiInputError, ShapeMismatch}; 57 | if $arr_a.shape() != $arr_b.shape() { 58 | return Err(MultiInputError::ShapeMismatch(ShapeMismatch { 59 | first_shape: $arr_a.shape().to_vec(), 60 | second_shape: $arr_b.shape().to_vec(), 61 | }) 62 | .into()); 63 | } 64 | }; 65 | } 66 | } 67 | 68 | #[macro_use] 69 | mod private { 70 | /// This is a public type in a private module, so it can be included in 71 | /// public APIs, but other crates can't access it. 72 | pub struct PrivateMarker; 73 | 74 | /// Defines an associated function for a trait that is impossible for other 75 | /// crates to implement. This makes it possible to add new associated 76 | /// types/functions/consts/etc. to the trait without breaking changes. 77 | macro_rules! private_decl { 78 | () => { 79 | /// This method makes this trait impossible to implement outside of 80 | /// `ndarray-stats` so that we can freely add new methods, etc., to 81 | /// this trait without breaking changes. 82 | /// 83 | /// We don't anticipate any other crates needing to implement this 84 | /// trait, but if you do have such a use-case, please let us know. 85 | /// 86 | /// **Warning** This method is not considered part of the public 87 | /// API, and client code should not rely on it being present. It 88 | /// may be removed in a non-breaking release. 89 | fn __private__(&self, _: crate::private::PrivateMarker); 90 | }; 91 | } 92 | 93 | /// Implements the associated function defined by `private_decl!`. 94 | macro_rules! 
private_impl { 95 | () => { 96 | fn __private__(&self, _: crate::private::PrivateMarker) {} 97 | }; 98 | } 99 | } 100 | 101 | mod correlation; 102 | mod deviation; 103 | mod entropy; 104 | pub mod errors; 105 | pub mod histogram; 106 | mod maybe_nan; 107 | mod quantile; 108 | mod sort; 109 | mod summary_statistics; 110 | -------------------------------------------------------------------------------- /src/maybe_nan/impl_not_none.rs: -------------------------------------------------------------------------------- 1 | use super::NotNone; 2 | use num_traits::{FromPrimitive, ToPrimitive}; 3 | use std::cmp; 4 | use std::fmt; 5 | use std::ops::{Add, Deref, DerefMut, Div, Mul, Rem, Sub}; 6 | 7 | impl Deref for NotNone { 8 | type Target = T; 9 | fn deref(&self) -> &T { 10 | match self.0 { 11 | Some(ref inner) => inner, 12 | None => unsafe { ::std::hint::unreachable_unchecked() }, 13 | } 14 | } 15 | } 16 | 17 | impl DerefMut for NotNone { 18 | fn deref_mut(&mut self) -> &mut T { 19 | match self.0 { 20 | Some(ref mut inner) => inner, 21 | None => unsafe { ::std::hint::unreachable_unchecked() }, 22 | } 23 | } 24 | } 25 | 26 | impl fmt::Display for NotNone { 27 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { 28 | self.deref().fmt(f) 29 | } 30 | } 31 | 32 | impl Eq for NotNone {} 33 | 34 | impl PartialEq for NotNone { 35 | fn eq(&self, other: &Self) -> bool { 36 | self.deref().eq(other) 37 | } 38 | } 39 | 40 | impl Ord for NotNone { 41 | fn cmp(&self, other: &Self) -> cmp::Ordering { 42 | self.deref().cmp(other) 43 | } 44 | } 45 | 46 | impl PartialOrd for NotNone { 47 | fn partial_cmp(&self, other: &Self) -> Option { 48 | self.deref().partial_cmp(other) 49 | } 50 | fn lt(&self, other: &Self) -> bool { 51 | self.deref().lt(other) 52 | } 53 | fn le(&self, other: &Self) -> bool { 54 | self.deref().le(other) 55 | } 56 | fn gt(&self, other: &Self) -> bool { 57 | self.deref().gt(other) 58 | } 59 | fn ge(&self, other: &Self) -> bool { 60 | 
self.deref().ge(other) 61 | } 62 | } 63 | 64 | impl Add for NotNone { 65 | type Output = NotNone; 66 | #[inline] 67 | fn add(self, rhs: Self) -> Self::Output { 68 | self.map(|v| v.add(rhs.unwrap())) 69 | } 70 | } 71 | 72 | impl Sub for NotNone { 73 | type Output = NotNone; 74 | #[inline] 75 | fn sub(self, rhs: Self) -> Self::Output { 76 | self.map(|v| v.sub(rhs.unwrap())) 77 | } 78 | } 79 | 80 | impl Mul for NotNone { 81 | type Output = NotNone; 82 | #[inline] 83 | fn mul(self, rhs: Self) -> Self::Output { 84 | self.map(|v| v.mul(rhs.unwrap())) 85 | } 86 | } 87 | 88 | impl Div for NotNone { 89 | type Output = NotNone; 90 | #[inline] 91 | fn div(self, rhs: Self) -> Self::Output { 92 | self.map(|v| v.div(rhs.unwrap())) 93 | } 94 | } 95 | 96 | impl Rem for NotNone { 97 | type Output = NotNone; 98 | #[inline] 99 | fn rem(self, rhs: Self) -> Self::Output { 100 | self.map(|v| v.rem(rhs.unwrap())) 101 | } 102 | } 103 | 104 | impl ToPrimitive for NotNone { 105 | #[inline] 106 | fn to_isize(&self) -> Option { 107 | self.deref().to_isize() 108 | } 109 | #[inline] 110 | fn to_i8(&self) -> Option { 111 | self.deref().to_i8() 112 | } 113 | #[inline] 114 | fn to_i16(&self) -> Option { 115 | self.deref().to_i16() 116 | } 117 | #[inline] 118 | fn to_i32(&self) -> Option { 119 | self.deref().to_i32() 120 | } 121 | #[inline] 122 | fn to_i64(&self) -> Option { 123 | self.deref().to_i64() 124 | } 125 | #[inline] 126 | fn to_i128(&self) -> Option { 127 | self.deref().to_i128() 128 | } 129 | #[inline] 130 | fn to_usize(&self) -> Option { 131 | self.deref().to_usize() 132 | } 133 | #[inline] 134 | fn to_u8(&self) -> Option { 135 | self.deref().to_u8() 136 | } 137 | #[inline] 138 | fn to_u16(&self) -> Option { 139 | self.deref().to_u16() 140 | } 141 | #[inline] 142 | fn to_u32(&self) -> Option { 143 | self.deref().to_u32() 144 | } 145 | #[inline] 146 | fn to_u64(&self) -> Option { 147 | self.deref().to_u64() 148 | } 149 | #[inline] 150 | fn to_u128(&self) -> Option { 151 | 
self.deref().to_u128() 152 | } 153 | #[inline] 154 | fn to_f32(&self) -> Option { 155 | self.deref().to_f32() 156 | } 157 | #[inline] 158 | fn to_f64(&self) -> Option { 159 | self.deref().to_f64() 160 | } 161 | } 162 | 163 | impl FromPrimitive for NotNone { 164 | #[inline] 165 | fn from_isize(n: isize) -> Option { 166 | Self::try_new(T::from_isize(n)) 167 | } 168 | #[inline] 169 | fn from_i8(n: i8) -> Option { 170 | Self::try_new(T::from_i8(n)) 171 | } 172 | #[inline] 173 | fn from_i16(n: i16) -> Option { 174 | Self::try_new(T::from_i16(n)) 175 | } 176 | #[inline] 177 | fn from_i32(n: i32) -> Option { 178 | Self::try_new(T::from_i32(n)) 179 | } 180 | #[inline] 181 | fn from_i64(n: i64) -> Option { 182 | Self::try_new(T::from_i64(n)) 183 | } 184 | #[inline] 185 | fn from_i128(n: i128) -> Option { 186 | Self::try_new(T::from_i128(n)) 187 | } 188 | #[inline] 189 | fn from_usize(n: usize) -> Option { 190 | Self::try_new(T::from_usize(n)) 191 | } 192 | #[inline] 193 | fn from_u8(n: u8) -> Option { 194 | Self::try_new(T::from_u8(n)) 195 | } 196 | #[inline] 197 | fn from_u16(n: u16) -> Option { 198 | Self::try_new(T::from_u16(n)) 199 | } 200 | #[inline] 201 | fn from_u32(n: u32) -> Option { 202 | Self::try_new(T::from_u32(n)) 203 | } 204 | #[inline] 205 | fn from_u64(n: u64) -> Option { 206 | Self::try_new(T::from_u64(n)) 207 | } 208 | #[inline] 209 | fn from_u128(n: u128) -> Option { 210 | Self::try_new(T::from_u128(n)) 211 | } 212 | #[inline] 213 | fn from_f32(n: f32) -> Option { 214 | Self::try_new(T::from_f32(n)) 215 | } 216 | #[inline] 217 | fn from_f64(n: f64) -> Option { 218 | Self::try_new(T::from_f64(n)) 219 | } 220 | } 221 | -------------------------------------------------------------------------------- /src/maybe_nan/mod.rs: -------------------------------------------------------------------------------- 1 | use ndarray::prelude::*; 2 | use ndarray::{s, Data, DataMut, RemoveAxis}; 3 | use noisy_float::types::{N32, N64}; 4 | use std::mem; 5 | 6 | /// A number 
type that can have not-a-number values. 7 | pub trait MaybeNan: Sized { 8 | /// A type that is guaranteed not to be a NaN value. 9 | type NotNan; 10 | 11 | /// Returns `true` if the value is a NaN value. 12 | fn is_nan(&self) -> bool; 13 | 14 | /// Tries to convert the value to `NotNan`. 15 | /// 16 | /// Returns `None` if the value is a NaN value. 17 | fn try_as_not_nan(&self) -> Option<&Self::NotNan>; 18 | 19 | /// Converts the value. 20 | /// 21 | /// If the value is `None`, a NaN value is returned. 22 | fn from_not_nan(_: Self::NotNan) -> Self; 23 | 24 | /// Converts the value. 25 | /// 26 | /// If the value is `None`, a NaN value is returned. 27 | fn from_not_nan_opt(_: Option) -> Self; 28 | 29 | /// Converts the value. 30 | /// 31 | /// If the value is `None`, a NaN value is returned. 32 | fn from_not_nan_ref_opt(_: Option<&Self::NotNan>) -> &Self; 33 | 34 | /// Returns a view with the NaN values removed. 35 | /// 36 | /// This modifies the input view by moving elements as necessary. The final 37 | /// order of the elements is unspecified. However, this method is 38 | /// idempotent, and given the same input data, the result is always ordered 39 | /// the same way. 40 | fn remove_nan_mut(_: ArrayViewMut1<'_, Self>) -> ArrayViewMut1<'_, Self::NotNan>; 41 | } 42 | 43 | /// Returns a view with the NaN values removed. 44 | /// 45 | /// This modifies the input view by moving elements as necessary. 46 | fn remove_nan_mut(mut view: ArrayViewMut1<'_, A>) -> ArrayViewMut1<'_, A> { 47 | if view.is_empty() { 48 | return view.slice_move(s![..0]); 49 | } 50 | let mut i = 0; 51 | let mut j = view.len() - 1; 52 | loop { 53 | // At this point, `i == 0 || !view[i-1].is_nan()` 54 | // and `j == view.len() - 1 || view[j+1].is_nan()`. 55 | while i <= j && !view[i].is_nan() { 56 | i += 1; 57 | } 58 | // At this point, `view[i].is_nan() || i == j + 1`. 59 | while j > i && view[j].is_nan() { 60 | j -= 1; 61 | } 62 | // At this point, `!view[j].is_nan() || j == i`. 
63 | if i >= j { 64 | return view.slice_move(s![..i]); 65 | } else { 66 | view.swap(i, j); 67 | i += 1; 68 | j -= 1; 69 | } 70 | } 71 | } 72 | 73 | /// Casts a view from one element type to another. 74 | /// 75 | /// # Panics 76 | /// 77 | /// Panics if `T` and `U` differ in size or alignment. 78 | /// 79 | /// # Safety 80 | /// 81 | /// The caller must ensure that qll elements in `view` are valid values for type `U`. 82 | unsafe fn cast_view_mut(mut view: ArrayViewMut1<'_, T>) -> ArrayViewMut1<'_, U> { 83 | assert_eq!(mem::size_of::(), mem::size_of::()); 84 | assert_eq!(mem::align_of::(), mem::align_of::()); 85 | let ptr: *mut U = view.as_mut_ptr().cast(); 86 | let len: usize = view.len_of(Axis(0)); 87 | let stride: isize = view.stride_of(Axis(0)); 88 | if len <= 1 { 89 | // We can use a stride of `0` because the stride is irrelevant for the `len == 1` case. 90 | let stride = 0; 91 | ArrayViewMut1::from_shape_ptr([len].strides([stride]), ptr) 92 | } else if stride >= 0 { 93 | let stride = stride as usize; 94 | ArrayViewMut1::from_shape_ptr([len].strides([stride]), ptr) 95 | } else { 96 | // At this point, stride < 0. We have to construct the view by using the inverse of the 97 | // stride and then inverting the axis, since `ArrayViewMut::from_shape_ptr` requires the 98 | // stride to be nonnegative. 99 | let neg_stride = stride.checked_neg().unwrap() as usize; 100 | // This is safe because `ndarray` guarantees that it's safe to offset the 101 | // pointer anywhere in the array. 102 | let neg_ptr = ptr.offset((len - 1) as isize * stride); 103 | let mut v = ArrayViewMut1::from_shape_ptr([len].strides([neg_stride]), neg_ptr); 104 | v.invert_axis(Axis(0)); 105 | v 106 | } 107 | } 108 | 109 | macro_rules! 
impl_maybenan_for_fxx { 110 | ($fxx:ident, $Nxx:ident) => { 111 | impl MaybeNan for $fxx { 112 | type NotNan = $Nxx; 113 | 114 | fn is_nan(&self) -> bool { 115 | $fxx::is_nan(*self) 116 | } 117 | 118 | fn try_as_not_nan(&self) -> Option<&$Nxx> { 119 | $Nxx::try_borrowed(self) 120 | } 121 | 122 | fn from_not_nan(value: $Nxx) -> $fxx { 123 | value.raw() 124 | } 125 | 126 | fn from_not_nan_opt(value: Option<$Nxx>) -> $fxx { 127 | match value { 128 | None => ::std::$fxx::NAN, 129 | Some(num) => num.raw(), 130 | } 131 | } 132 | 133 | fn from_not_nan_ref_opt(value: Option<&$Nxx>) -> &$fxx { 134 | match value { 135 | None => &::std::$fxx::NAN, 136 | Some(num) => num.as_ref(), 137 | } 138 | } 139 | 140 | fn remove_nan_mut(view: ArrayViewMut1<'_, $fxx>) -> ArrayViewMut1<'_, $Nxx> { 141 | let not_nan = remove_nan_mut(view); 142 | // This is safe because `remove_nan_mut` has removed the NaN values, and `$Nxx` is 143 | // a thin wrapper around `$fxx`. 144 | unsafe { cast_view_mut(not_nan) } 145 | } 146 | } 147 | }; 148 | } 149 | impl_maybenan_for_fxx!(f32, N32); 150 | impl_maybenan_for_fxx!(f64, N64); 151 | 152 | macro_rules! impl_maybenan_for_opt_never_nan { 153 | ($ty:ty) => { 154 | impl MaybeNan for Option<$ty> { 155 | type NotNan = NotNone<$ty>; 156 | 157 | fn is_nan(&self) -> bool { 158 | self.is_none() 159 | } 160 | 161 | fn try_as_not_nan(&self) -> Option<&NotNone<$ty>> { 162 | if self.is_none() { 163 | None 164 | } else { 165 | // This is safe because we have checked for the `None` 166 | // case, and `NotNone<$ty>` is a thin wrapper around `Option<$ty>`. 
167 | Some(unsafe { &*(self as *const Option<$ty> as *const NotNone<$ty>) }) 168 | } 169 | } 170 | 171 | fn from_not_nan(value: NotNone<$ty>) -> Option<$ty> { 172 | value.into_inner() 173 | } 174 | 175 | fn from_not_nan_opt(value: Option>) -> Option<$ty> { 176 | value.and_then(|v| v.into_inner()) 177 | } 178 | 179 | fn from_not_nan_ref_opt(value: Option<&NotNone<$ty>>) -> &Option<$ty> { 180 | match value { 181 | None => &None, 182 | // This is safe because `NotNone<$ty>` is a thin wrapper around 183 | // `Option<$ty>`. 184 | Some(num) => unsafe { &*(num as *const NotNone<$ty> as *const Option<$ty>) }, 185 | } 186 | } 187 | 188 | fn remove_nan_mut(view: ArrayViewMut1<'_, Self>) -> ArrayViewMut1<'_, Self::NotNan> { 189 | let not_nan = remove_nan_mut(view); 190 | // This is safe because `remove_nan_mut` has removed the `None` 191 | // values, and `NotNone<$ty>` is a thin wrapper around `Option<$ty>`. 192 | unsafe { 193 | ArrayViewMut1::from_shape_ptr( 194 | not_nan.dim(), 195 | not_nan.as_ptr() as *mut NotNone<$ty>, 196 | ) 197 | } 198 | } 199 | } 200 | }; 201 | } 202 | impl_maybenan_for_opt_never_nan!(u8); 203 | impl_maybenan_for_opt_never_nan!(u16); 204 | impl_maybenan_for_opt_never_nan!(u32); 205 | impl_maybenan_for_opt_never_nan!(u64); 206 | impl_maybenan_for_opt_never_nan!(u128); 207 | impl_maybenan_for_opt_never_nan!(i8); 208 | impl_maybenan_for_opt_never_nan!(i16); 209 | impl_maybenan_for_opt_never_nan!(i32); 210 | impl_maybenan_for_opt_never_nan!(i64); 211 | impl_maybenan_for_opt_never_nan!(i128); 212 | impl_maybenan_for_opt_never_nan!(N32); 213 | impl_maybenan_for_opt_never_nan!(N64); 214 | 215 | /// A thin wrapper around `Option` that guarantees that the value is not 216 | /// `None`. 217 | #[derive(Clone, Copy, Debug)] 218 | #[repr(transparent)] 219 | pub struct NotNone(Option); 220 | 221 | impl NotNone { 222 | /// Creates a new `NotNone` containing the given value. 
223 | pub fn new(value: T) -> NotNone { 224 | NotNone(Some(value)) 225 | } 226 | 227 | /// Creates a new `NotNone` containing the given value. 228 | /// 229 | /// Returns `None` if `value` is `None`. 230 | pub fn try_new(value: Option) -> Option> { 231 | if value.is_some() { 232 | Some(NotNone(value)) 233 | } else { 234 | None 235 | } 236 | } 237 | 238 | /// Returns the underling option. 239 | pub fn into_inner(self) -> Option { 240 | self.0 241 | } 242 | 243 | /// Moves the value out of the inner option. 244 | /// 245 | /// This method is guaranteed not to panic. 246 | pub fn unwrap(self) -> T { 247 | match self.0 { 248 | Some(inner) => inner, 249 | None => unsafe { ::std::hint::unreachable_unchecked() }, 250 | } 251 | } 252 | 253 | /// Maps an `NotNone` to `NotNone` by applying a function to the 254 | /// contained value. 255 | pub fn map(self, f: F) -> NotNone 256 | where 257 | F: FnOnce(T) -> U, 258 | { 259 | NotNone::new(f(self.unwrap())) 260 | } 261 | } 262 | 263 | /// Extension trait for `ArrayBase` providing NaN-related functionality. 264 | pub trait MaybeNanExt 265 | where 266 | A: MaybeNan, 267 | S: Data, 268 | D: Dimension, 269 | { 270 | /// Traverse the non-NaN array elements and apply a fold, returning the 271 | /// resulting value. 272 | /// 273 | /// Elements are visited in arbitrary order. 274 | fn fold_skipnan<'a, F, B>(&'a self, init: B, f: F) -> B 275 | where 276 | A: 'a, 277 | F: FnMut(B, &'a A::NotNan) -> B; 278 | 279 | /// Traverse the non-NaN elements and their indices and apply a fold, 280 | /// returning the resulting value. 281 | /// 282 | /// Elements are visited in arbitrary order. 283 | fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, f: F) -> B 284 | where 285 | A: 'a, 286 | F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B; 287 | 288 | /// Visit each non-NaN element in the array by calling `f` on each element. 289 | /// 290 | /// Elements are visited in arbitrary order. 
291 | fn visit_skipnan<'a, F>(&'a self, f: F) 292 | where 293 | A: 'a, 294 | F: FnMut(&'a A::NotNan); 295 | 296 | /// Fold non-NaN values along an axis. 297 | /// 298 | /// Combine the non-NaN elements of each subview with the previous using 299 | /// the fold function and initial value init. 300 | fn fold_axis_skipnan(&self, axis: Axis, init: B, fold: F) -> Array 301 | where 302 | D: RemoveAxis, 303 | F: FnMut(&B, &A::NotNan) -> B, 304 | B: Clone; 305 | 306 | /// Reduce the values along an axis into just one value, producing a new 307 | /// array with one less dimension. 308 | /// 309 | /// The NaN values are removed from the 1-dimensional lanes, then they are 310 | /// passed as mutable views to the reducer, allowing for side-effects. 311 | /// 312 | /// **Warnings**: 313 | /// 314 | /// * The lanes are visited in arbitrary order. 315 | /// 316 | /// * The order of the elements within the lanes is unspecified. However, 317 | /// if `mapping` is idempotent, this method is idempotent. Additionally, 318 | /// given the same input data, the lane is always ordered the same way. 319 | /// 320 | /// **Panics** if `axis` is out of bounds. 321 | fn map_axis_skipnan_mut<'a, B, F>(&'a mut self, axis: Axis, mapping: F) -> Array 322 | where 323 | A: 'a, 324 | S: DataMut, 325 | D: RemoveAxis, 326 | F: FnMut(ArrayViewMut1<'a, A::NotNan>) -> B; 327 | 328 | private_decl! 
{} 329 | } 330 | 331 | impl MaybeNanExt for ArrayBase 332 | where 333 | A: MaybeNan, 334 | S: Data, 335 | D: Dimension, 336 | { 337 | fn fold_skipnan<'a, F, B>(&'a self, init: B, mut f: F) -> B 338 | where 339 | A: 'a, 340 | F: FnMut(B, &'a A::NotNan) -> B, 341 | { 342 | self.fold(init, |acc, elem| { 343 | if let Some(not_nan) = elem.try_as_not_nan() { 344 | f(acc, not_nan) 345 | } else { 346 | acc 347 | } 348 | }) 349 | } 350 | 351 | fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, mut f: F) -> B 352 | where 353 | A: 'a, 354 | F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B, 355 | { 356 | self.indexed_iter().fold(init, |acc, (idx, elem)| { 357 | if let Some(not_nan) = elem.try_as_not_nan() { 358 | f(acc, (idx, not_nan)) 359 | } else { 360 | acc 361 | } 362 | }) 363 | } 364 | 365 | fn visit_skipnan<'a, F>(&'a self, mut f: F) 366 | where 367 | A: 'a, 368 | F: FnMut(&'a A::NotNan), 369 | { 370 | self.for_each(|elem| { 371 | if let Some(not_nan) = elem.try_as_not_nan() { 372 | f(not_nan) 373 | } 374 | }) 375 | } 376 | 377 | fn fold_axis_skipnan(&self, axis: Axis, init: B, mut fold: F) -> Array 378 | where 379 | D: RemoveAxis, 380 | F: FnMut(&B, &A::NotNan) -> B, 381 | B: Clone, 382 | { 383 | self.fold_axis(axis, init, |acc, elem| { 384 | if let Some(not_nan) = elem.try_as_not_nan() { 385 | fold(acc, not_nan) 386 | } else { 387 | acc.clone() 388 | } 389 | }) 390 | } 391 | 392 | fn map_axis_skipnan_mut<'a, B, F>( 393 | &'a mut self, 394 | axis: Axis, 395 | mut mapping: F, 396 | ) -> Array 397 | where 398 | A: 'a, 399 | S: DataMut, 400 | D: RemoveAxis, 401 | F: FnMut(ArrayViewMut1<'a, A::NotNan>) -> B, 402 | { 403 | self.map_axis_mut(axis, |lane| mapping(A::remove_nan_mut(lane))) 404 | } 405 | 406 | private_impl! 
{} 407 | } 408 | 409 | #[cfg(test)] 410 | mod tests { 411 | use super::*; 412 | use quickcheck_macros::quickcheck; 413 | 414 | #[quickcheck] 415 | fn remove_nan_mut_idempotent(is_nan: Vec) -> bool { 416 | let mut values: Vec<_> = is_nan 417 | .into_iter() 418 | .map(|is_nan| if is_nan { None } else { Some(1) }) 419 | .collect(); 420 | let view = ArrayViewMut1::from_shape(values.len(), &mut values).unwrap(); 421 | let removed = remove_nan_mut(view); 422 | removed == remove_nan_mut(removed.to_owned().view_mut()) 423 | } 424 | 425 | #[quickcheck] 426 | fn remove_nan_mut_only_nan_remaining(is_nan: Vec) -> bool { 427 | let mut values: Vec<_> = is_nan 428 | .into_iter() 429 | .map(|is_nan| if is_nan { None } else { Some(1) }) 430 | .collect(); 431 | let view = ArrayViewMut1::from_shape(values.len(), &mut values).unwrap(); 432 | remove_nan_mut(view).iter().all(|elem| !elem.is_nan()) 433 | } 434 | 435 | #[quickcheck] 436 | fn remove_nan_mut_keep_all_non_nan(is_nan: Vec) -> bool { 437 | let non_nan_count = is_nan.iter().filter(|&&is_nan| !is_nan).count(); 438 | let mut values: Vec<_> = is_nan 439 | .into_iter() 440 | .map(|is_nan| if is_nan { None } else { Some(1) }) 441 | .collect(); 442 | let view = ArrayViewMut1::from_shape(values.len(), &mut values).unwrap(); 443 | remove_nan_mut(view).len() == non_nan_count 444 | } 445 | } 446 | 447 | mod impl_not_none; 448 | -------------------------------------------------------------------------------- /src/quantile/interpolate.rs: -------------------------------------------------------------------------------- 1 | //! Interpolation strategies. 2 | use noisy_float::types::N64; 3 | use num_traits::{Float, FromPrimitive, NumOps, ToPrimitive}; 4 | 5 | fn float_quantile_index(q: N64, len: usize) -> N64 { 6 | q * ((len - 1) as f64) 7 | } 8 | 9 | /// Returns the fraction that the quantile is between the lower and higher indices. 
10 | /// 11 | /// This ranges from 0, where the quantile exactly corresponds the lower index, 12 | /// to 1, where the quantile exactly corresponds to the higher index. 13 | fn float_quantile_index_fraction(q: N64, len: usize) -> N64 { 14 | float_quantile_index(q, len).fract() 15 | } 16 | 17 | /// Returns the index of the value on the lower side of the quantile. 18 | pub(crate) fn lower_index(q: N64, len: usize) -> usize { 19 | float_quantile_index(q, len).floor().to_usize().unwrap() 20 | } 21 | 22 | /// Returns the index of the value on the higher side of the quantile. 23 | pub(crate) fn higher_index(q: N64, len: usize) -> usize { 24 | float_quantile_index(q, len).ceil().to_usize().unwrap() 25 | } 26 | 27 | /// Used to provide an interpolation strategy to [`quantile_axis_mut`]. 28 | /// 29 | /// [`quantile_axis_mut`]: ../trait.QuantileExt.html#tymethod.quantile_axis_mut 30 | pub trait Interpolate { 31 | /// Returns `true` iff the lower value is needed to compute the 32 | /// interpolated value. 33 | #[doc(hidden)] 34 | fn needs_lower(q: N64, len: usize) -> bool; 35 | 36 | /// Returns `true` iff the higher value is needed to compute the 37 | /// interpolated value. 38 | #[doc(hidden)] 39 | fn needs_higher(q: N64, len: usize) -> bool; 40 | 41 | /// Computes the interpolated value. 42 | /// 43 | /// **Panics** if `None` is provided for the lower value when it's needed 44 | /// or if `None` is provided for the higher value when it's needed. 45 | #[doc(hidden)] 46 | fn interpolate(lower: Option, higher: Option, q: N64, len: usize) -> T; 47 | 48 | private_decl! {} 49 | } 50 | 51 | /// Select the higher value. 52 | pub struct Higher; 53 | /// Select the lower value. 54 | pub struct Lower; 55 | /// Select the nearest value. 56 | pub struct Nearest; 57 | /// Select the midpoint of the two values (`(lower + higher) / 2`). 
58 | pub struct Midpoint; 59 | /// Linearly interpolate between the two values 60 | /// (`lower + (higher - lower) * fraction`, where `fraction` is the 61 | /// fractional part of the index surrounded by `lower` and `higher`). 62 | pub struct Linear; 63 | 64 | impl Interpolate for Higher { 65 | fn needs_lower(_q: N64, _len: usize) -> bool { 66 | false 67 | } 68 | fn needs_higher(_q: N64, _len: usize) -> bool { 69 | true 70 | } 71 | fn interpolate(_lower: Option, higher: Option, _q: N64, _len: usize) -> T { 72 | higher.unwrap() 73 | } 74 | private_impl! {} 75 | } 76 | 77 | impl Interpolate for Lower { 78 | fn needs_lower(_q: N64, _len: usize) -> bool { 79 | true 80 | } 81 | fn needs_higher(_q: N64, _len: usize) -> bool { 82 | false 83 | } 84 | fn interpolate(lower: Option, _higher: Option, _q: N64, _len: usize) -> T { 85 | lower.unwrap() 86 | } 87 | private_impl! {} 88 | } 89 | 90 | impl Interpolate for Nearest { 91 | fn needs_lower(q: N64, len: usize) -> bool { 92 | float_quantile_index_fraction(q, len) < 0.5 93 | } 94 | fn needs_higher(q: N64, len: usize) -> bool { 95 | !>::needs_lower(q, len) 96 | } 97 | fn interpolate(lower: Option, higher: Option, q: N64, len: usize) -> T { 98 | if >::needs_lower(q, len) { 99 | lower.unwrap() 100 | } else { 101 | higher.unwrap() 102 | } 103 | } 104 | private_impl! {} 105 | } 106 | 107 | impl Interpolate for Midpoint 108 | where 109 | T: NumOps + Clone + FromPrimitive, 110 | { 111 | fn needs_lower(_q: N64, _len: usize) -> bool { 112 | true 113 | } 114 | fn needs_higher(_q: N64, _len: usize) -> bool { 115 | true 116 | } 117 | fn interpolate(lower: Option, higher: Option, _q: N64, _len: usize) -> T { 118 | let denom = T::from_u8(2).unwrap(); 119 | let lower = lower.unwrap(); 120 | let higher = higher.unwrap(); 121 | lower.clone() + (higher.clone() - lower.clone()) / denom.clone() 122 | } 123 | private_impl! 
{} 124 | } 125 | 126 | impl Interpolate for Linear 127 | where 128 | T: NumOps + Clone + FromPrimitive + ToPrimitive, 129 | { 130 | fn needs_lower(_q: N64, _len: usize) -> bool { 131 | true 132 | } 133 | fn needs_higher(_q: N64, _len: usize) -> bool { 134 | true 135 | } 136 | fn interpolate(lower: Option, higher: Option, q: N64, len: usize) -> T { 137 | let fraction = float_quantile_index_fraction(q, len).to_f64().unwrap(); 138 | let lower = lower.unwrap(); 139 | let higher = higher.unwrap(); 140 | let lower_f64 = lower.to_f64().unwrap(); 141 | let higher_f64 = higher.to_f64().unwrap(); 142 | lower.clone() + T::from_f64(fraction * (higher_f64 - lower_f64)).unwrap() 143 | } 144 | private_impl! {} 145 | } 146 | -------------------------------------------------------------------------------- /src/sort.rs: -------------------------------------------------------------------------------- 1 | use indexmap::IndexMap; 2 | use ndarray::prelude::*; 3 | use ndarray::{Data, DataMut, Slice}; 4 | use rand::prelude::*; 5 | use rand::thread_rng; 6 | 7 | /// Methods for sorting and partitioning 1-D arrays. 8 | pub trait Sort1dExt 9 | where 10 | S: Data, 11 | { 12 | /// Return the element that would occupy the `i`-th position if 13 | /// the array were sorted in increasing order. 14 | /// 15 | /// The array is shuffled **in place** to retrieve the desired element: 16 | /// no copy of the array is allocated. 17 | /// After the shuffling, all elements with an index smaller than `i` 18 | /// are smaller than the desired element, while all elements with 19 | /// an index greater or equal than `i` are greater than or equal 20 | /// to the desired element. 21 | /// 22 | /// No other assumptions should be made on the ordering of the 23 | /// elements after this computation. 24 | /// 25 | /// Complexity ([quickselect](https://en.wikipedia.org/wiki/Quickselect)): 26 | /// - average case: O(`n`); 27 | /// - worst case: O(`n`^2); 28 | /// where n is the number of elements in the array. 
29 | /// 30 | /// **Panics** if `i` is greater than or equal to `n`. 31 | fn get_from_sorted_mut(&mut self, i: usize) -> A 32 | where 33 | A: Ord + Clone, 34 | S: DataMut; 35 | 36 | /// A bulk version of [`get_from_sorted_mut`], optimized to retrieve multiple 37 | /// indexes at once. 38 | /// It returns an `IndexMap`, with indexes as keys and retrieved elements as 39 | /// values. 40 | /// The `IndexMap` is sorted with respect to indexes in increasing order: 41 | /// this ordering is preserved when you iterate over it (using `iter`/`into_iter`). 42 | /// 43 | /// **Panics** if any element in `indexes` is greater than or equal to `n`, 44 | /// where `n` is the length of the array.. 45 | /// 46 | /// [`get_from_sorted_mut`]: #tymethod.get_from_sorted_mut 47 | fn get_many_from_sorted_mut(&mut self, indexes: &ArrayBase) -> IndexMap 48 | where 49 | A: Ord + Clone, 50 | S: DataMut, 51 | S2: Data; 52 | 53 | /// Partitions the array in increasing order based on the value initially 54 | /// located at `pivot_index` and returns the new index of the value. 55 | /// 56 | /// The elements are rearranged in such a way that the value initially 57 | /// located at `pivot_index` is moved to the position it would be in an 58 | /// array sorted in increasing order. The return value is the new index of 59 | /// the value after rearrangement. All elements smaller than the value are 60 | /// moved to its left and all elements equal or greater than the value are 61 | /// moved to its right. The ordering of the elements in the two partitions 62 | /// is undefined. 63 | /// 64 | /// `self` is shuffled **in place** to operate the desired partition: 65 | /// no copy of the array is allocated. 66 | /// 67 | /// The method uses Hoare's partition algorithm. 68 | /// Complexity: O(`n`), where `n` is the number of elements in the array. 
69 | /// Average number of element swaps: n/6 - 1/3 (see 70 | /// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550)) 71 | /// 72 | /// **Panics** if `pivot_index` is greater than or equal to `n`. 73 | /// 74 | /// # Example 75 | /// 76 | /// ``` 77 | /// use ndarray::array; 78 | /// use ndarray_stats::Sort1dExt; 79 | /// 80 | /// let mut data = array![3, 1, 4, 5, 2]; 81 | /// let pivot_index = 2; 82 | /// let pivot_value = data[pivot_index]; 83 | /// 84 | /// // Partition by the value located at `pivot_index`. 85 | /// let new_index = data.partition_mut(pivot_index); 86 | /// // The pivot value is now located at `new_index`. 87 | /// assert_eq!(data[new_index], pivot_value); 88 | /// // Elements less than that value are moved to the left. 89 | /// for i in 0..new_index { 90 | /// assert!(data[i] < pivot_value); 91 | /// } 92 | /// // Elements greater than or equal to that value are moved to the right. 93 | /// for i in (new_index + 1)..data.len() { 94 | /// assert!(data[i] >= pivot_value); 95 | /// } 96 | /// ``` 97 | fn partition_mut(&mut self, pivot_index: usize) -> usize 98 | where 99 | A: Ord + Clone, 100 | S: DataMut; 101 | 102 | private_decl! 
{} 103 | } 104 | 105 | impl Sort1dExt for ArrayBase 106 | where 107 | S: Data, 108 | { 109 | fn get_from_sorted_mut(&mut self, i: usize) -> A 110 | where 111 | A: Ord + Clone, 112 | S: DataMut, 113 | { 114 | let n = self.len(); 115 | if n == 1 { 116 | self[0].clone() 117 | } else { 118 | let mut rng = thread_rng(); 119 | let pivot_index = rng.gen_range(0..n); 120 | let partition_index = self.partition_mut(pivot_index); 121 | if i < partition_index { 122 | self.slice_axis_mut(Axis(0), Slice::from(..partition_index)) 123 | .get_from_sorted_mut(i) 124 | } else if i == partition_index { 125 | self[i].clone() 126 | } else { 127 | self.slice_axis_mut(Axis(0), Slice::from(partition_index + 1..)) 128 | .get_from_sorted_mut(i - (partition_index + 1)) 129 | } 130 | } 131 | } 132 | 133 | fn get_many_from_sorted_mut(&mut self, indexes: &ArrayBase) -> IndexMap 134 | where 135 | A: Ord + Clone, 136 | S: DataMut, 137 | S2: Data, 138 | { 139 | let mut deduped_indexes: Vec = indexes.to_vec(); 140 | deduped_indexes.sort_unstable(); 141 | deduped_indexes.dedup(); 142 | 143 | get_many_from_sorted_mut_unchecked(self, &deduped_indexes) 144 | } 145 | 146 | fn partition_mut(&mut self, pivot_index: usize) -> usize 147 | where 148 | A: Ord + Clone, 149 | S: DataMut, 150 | { 151 | let pivot_value = self[pivot_index].clone(); 152 | self.swap(pivot_index, 0); 153 | let n = self.len(); 154 | let mut i = 1; 155 | let mut j = n - 1; 156 | loop { 157 | loop { 158 | if i > j { 159 | break; 160 | } 161 | if self[i] >= pivot_value { 162 | break; 163 | } 164 | i += 1; 165 | } 166 | while pivot_value <= self[j] { 167 | if j == 1 { 168 | break; 169 | } 170 | j -= 1; 171 | } 172 | if i >= j { 173 | break; 174 | } else { 175 | self.swap(i, j); 176 | i += 1; 177 | j -= 1; 178 | } 179 | } 180 | self.swap(0, i - 1); 181 | i - 1 182 | } 183 | 184 | private_impl! 
{} 185 | } 186 | 187 | /// To retrieve multiple indexes from the sorted array in an optimized fashion, 188 | /// [get_many_from_sorted_mut] first of all sorts and deduplicates the 189 | /// `indexes` vector. 190 | /// 191 | /// `get_many_from_sorted_mut_unchecked` does not perform this sorting and 192 | /// deduplication, assuming that the user has already taken care of it. 193 | /// 194 | /// Useful when you have to call [get_many_from_sorted_mut] multiple times 195 | /// using the same indexes. 196 | /// 197 | /// [get_many_from_sorted_mut]: ../trait.Sort1dExt.html#tymethod.get_many_from_sorted_mut 198 | pub(crate) fn get_many_from_sorted_mut_unchecked( 199 | array: &mut ArrayBase, 200 | indexes: &[usize], 201 | ) -> IndexMap 202 | where 203 | A: Ord + Clone, 204 | S: DataMut, 205 | { 206 | if indexes.is_empty() { 207 | return IndexMap::new(); 208 | } 209 | 210 | // Since `!indexes.is_empty()` and indexes must be in-bounds, `array` must 211 | // be non-empty. 212 | let mut values = vec![array[0].clone(); indexes.len()]; 213 | _get_many_from_sorted_mut_unchecked(array.view_mut(), &mut indexes.to_owned(), &mut values); 214 | 215 | // We convert the vector to a more search-friendly `IndexMap`. 216 | indexes.iter().cloned().zip(values.into_iter()).collect() 217 | } 218 | 219 | /// This is the recursive portion of `get_many_from_sorted_mut_unchecked`. 220 | /// 221 | /// `indexes` is the list of indexes to get. `indexes` is mutable so that it 222 | /// can be used as scratch space for this routine; the value of `indexes` after 223 | /// calling this routine should be ignored. 224 | /// 225 | /// `values` is a pre-allocated slice to use for writing the output. Its 226 | /// initial element values are ignored. 
227 | fn _get_many_from_sorted_mut_unchecked( 228 | mut array: ArrayViewMut1<'_, A>, 229 | indexes: &mut [usize], 230 | values: &mut [A], 231 | ) where 232 | A: Ord + Clone, 233 | { 234 | let n = array.len(); 235 | debug_assert!(n >= indexes.len()); // because indexes must be unique and in-bounds 236 | debug_assert_eq!(indexes.len(), values.len()); 237 | 238 | if indexes.is_empty() { 239 | // Nothing to do in this case. 240 | return; 241 | } 242 | 243 | // At this point, `n >= 1` since `indexes.len() >= 1`. 244 | if n == 1 { 245 | // We can only reach this point if `indexes.len() == 1`, so we only 246 | // need to assign the single value, and then we're done. 247 | debug_assert_eq!(indexes.len(), 1); 248 | values[0] = array[0].clone(); 249 | return; 250 | } 251 | 252 | // We pick a random pivot index: the corresponding element is the pivot value 253 | let mut rng = thread_rng(); 254 | let pivot_index = rng.gen_range(0..n); 255 | 256 | // We partition the array with respect to the pivot value. 257 | // The pivot value moves to `array_partition_index`. 258 | // Elements strictly smaller than the pivot value have indexes < `array_partition_index`. 259 | // Elements greater or equal to the pivot value have indexes > `array_partition_index`. 260 | let array_partition_index = array.partition_mut(pivot_index); 261 | 262 | // We use a divide-and-conquer strategy, splitting the indexes we are 263 | // searching for (`indexes`) and the corresponding portions of the output 264 | // slice (`values`) into pieces with respect to `array_partition_index`. 
265 | let (found_exact, index_split) = match indexes.binary_search(&array_partition_index) { 266 | Ok(index) => (true, index), 267 | Err(index) => (false, index), 268 | }; 269 | let (smaller_indexes, other_indexes) = indexes.split_at_mut(index_split); 270 | let (smaller_values, other_values) = values.split_at_mut(index_split); 271 | let (bigger_indexes, bigger_values) = if found_exact { 272 | other_values[0] = array[array_partition_index].clone(); // Write exactly found value. 273 | (&mut other_indexes[1..], &mut other_values[1..]) 274 | } else { 275 | (other_indexes, other_values) 276 | }; 277 | 278 | // We search recursively for the values corresponding to strictly smaller 279 | // indexes to the left of `partition_index`. 280 | _get_many_from_sorted_mut_unchecked( 281 | array.slice_axis_mut(Axis(0), Slice::from(..array_partition_index)), 282 | smaller_indexes, 283 | smaller_values, 284 | ); 285 | 286 | // We search recursively for the values corresponding to strictly bigger 287 | // indexes to the right of `partition_index`. Since only the right portion 288 | // of the array is passed in, the indexes need to be shifted by length of 289 | // the removed portion. 
290 | bigger_indexes 291 | .iter_mut() 292 | .for_each(|x| *x -= array_partition_index + 1); 293 | _get_many_from_sorted_mut_unchecked( 294 | array.slice_axis_mut(Axis(0), Slice::from(array_partition_index + 1..)), 295 | bigger_indexes, 296 | bigger_values, 297 | ); 298 | } 299 | -------------------------------------------------------------------------------- /src/summary_statistics/means.rs: -------------------------------------------------------------------------------- 1 | use super::SummaryStatisticsExt; 2 | use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch}; 3 | use ndarray::{Array, ArrayBase, Axis, Data, Dimension, Ix1, RemoveAxis}; 4 | use num_integer::IterBinomial; 5 | use num_traits::{Float, FromPrimitive, Zero}; 6 | use std::ops::{Add, AddAssign, Div, Mul}; 7 | 8 | impl SummaryStatisticsExt for ArrayBase 9 | where 10 | S: Data, 11 | D: Dimension, 12 | { 13 | fn mean(&self) -> Result 14 | where 15 | A: Clone + FromPrimitive + Add + Div + Zero, 16 | { 17 | let n_elements = self.len(); 18 | if n_elements == 0 { 19 | Err(EmptyInput) 20 | } else { 21 | let n_elements = A::from_usize(n_elements) 22 | .expect("Converting number of elements to `A` must not fail."); 23 | Ok(self.sum() / n_elements) 24 | } 25 | } 26 | 27 | fn weighted_mean(&self, weights: &Self) -> Result 28 | where 29 | A: Copy + Div + Mul + Zero, 30 | { 31 | return_err_if_empty!(self); 32 | let weighted_sum = self.weighted_sum(weights)?; 33 | Ok(weighted_sum / weights.sum()) 34 | } 35 | 36 | fn weighted_sum(&self, weights: &ArrayBase) -> Result 37 | where 38 | A: Copy + Mul + Zero, 39 | { 40 | return_err_unless_same_shape!(self, weights); 41 | Ok(self 42 | .iter() 43 | .zip(weights) 44 | .fold(A::zero(), |acc, (&d, &w)| acc + d * w)) 45 | } 46 | 47 | fn weighted_mean_axis( 48 | &self, 49 | axis: Axis, 50 | weights: &ArrayBase, 51 | ) -> Result, MultiInputError> 52 | where 53 | A: Copy + Div + Mul + Zero, 54 | D: RemoveAxis, 55 | { 56 | return_err_if_empty!(self); 57 | let mut 
weighted_sum = self.weighted_sum_axis(axis, weights)?; 58 | let weights_sum = weights.sum(); 59 | weighted_sum.mapv_inplace(|v| v / weights_sum); 60 | Ok(weighted_sum) 61 | } 62 | 63 | fn weighted_sum_axis( 64 | &self, 65 | axis: Axis, 66 | weights: &ArrayBase, 67 | ) -> Result, MultiInputError> 68 | where 69 | A: Copy + Mul + Zero, 70 | D: RemoveAxis, 71 | { 72 | if self.shape()[axis.index()] != weights.len() { 73 | return Err(MultiInputError::ShapeMismatch(ShapeMismatch { 74 | first_shape: self.shape().to_vec(), 75 | second_shape: weights.shape().to_vec(), 76 | })); 77 | } 78 | 79 | // We could use `lane.weighted_sum` here, but we're avoiding 2 80 | // conditions and an unwrap per lane. 81 | Ok(self.map_axis(axis, |lane| { 82 | lane.iter() 83 | .zip(weights) 84 | .fold(A::zero(), |acc, (&d, &w)| acc + d * w) 85 | })) 86 | } 87 | 88 | fn harmonic_mean(&self) -> Result 89 | where 90 | A: Float + FromPrimitive, 91 | { 92 | self.map(|x| x.recip()) 93 | .mean() 94 | .map(|x| x.recip()) 95 | .ok_or(EmptyInput) 96 | } 97 | 98 | fn geometric_mean(&self) -> Result 99 | where 100 | A: Float + FromPrimitive, 101 | { 102 | self.map(|x| x.ln()) 103 | .mean() 104 | .map(|x| x.exp()) 105 | .ok_or(EmptyInput) 106 | } 107 | 108 | fn weighted_var(&self, weights: &Self, ddof: A) -> Result 109 | where 110 | A: AddAssign + Float + FromPrimitive, 111 | { 112 | return_err_if_empty!(self); 113 | return_err_unless_same_shape!(self, weights); 114 | let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); 115 | let one = A::from_usize(1).expect("Converting 1 to `A` must not fail."); 116 | assert!( 117 | !(ddof < zero || ddof > one), 118 | "`ddof` must not be less than zero or greater than one", 119 | ); 120 | inner_weighted_var(self, weights, ddof, zero) 121 | } 122 | 123 | fn weighted_std(&self, weights: &Self, ddof: A) -> Result 124 | where 125 | A: AddAssign + Float + FromPrimitive, 126 | { 127 | Ok(self.weighted_var(weights, ddof)?.sqrt()) 128 | } 129 | 130 | fn 
weighted_var_axis( 131 | &self, 132 | axis: Axis, 133 | weights: &ArrayBase, 134 | ddof: A, 135 | ) -> Result, MultiInputError> 136 | where 137 | A: AddAssign + Float + FromPrimitive, 138 | D: RemoveAxis, 139 | { 140 | return_err_if_empty!(self); 141 | if self.shape()[axis.index()] != weights.len() { 142 | return Err(MultiInputError::ShapeMismatch(ShapeMismatch { 143 | first_shape: self.shape().to_vec(), 144 | second_shape: weights.shape().to_vec(), 145 | })); 146 | } 147 | let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); 148 | let one = A::from_usize(1).expect("Converting 1 to `A` must not fail."); 149 | assert!( 150 | !(ddof < zero || ddof > one), 151 | "`ddof` must not be less than zero or greater than one", 152 | ); 153 | 154 | // `weights` must be a view because `lane` is a view in this context. 155 | let weights = weights.view(); 156 | Ok(self.map_axis(axis, |lane| { 157 | inner_weighted_var(&lane, &weights, ddof, zero).unwrap() 158 | })) 159 | } 160 | 161 | fn weighted_std_axis( 162 | &self, 163 | axis: Axis, 164 | weights: &ArrayBase, 165 | ddof: A, 166 | ) -> Result, MultiInputError> 167 | where 168 | A: AddAssign + Float + FromPrimitive, 169 | D: RemoveAxis, 170 | { 171 | Ok(self 172 | .weighted_var_axis(axis, weights, ddof)? 
173 | .mapv_into(|x| x.sqrt())) 174 | } 175 | 176 | fn kurtosis(&self) -> Result 177 | where 178 | A: Float + FromPrimitive, 179 | { 180 | let central_moments = self.central_moments(4)?; 181 | Ok(central_moments[4] / central_moments[2].powi(2)) 182 | } 183 | 184 | fn skewness(&self) -> Result 185 | where 186 | A: Float + FromPrimitive, 187 | { 188 | let central_moments = self.central_moments(3)?; 189 | Ok(central_moments[3] / central_moments[2].sqrt().powi(3)) 190 | } 191 | 192 | fn central_moment(&self, order: u16) -> Result 193 | where 194 | A: Float + FromPrimitive, 195 | { 196 | if self.is_empty() { 197 | return Err(EmptyInput); 198 | } 199 | match order { 200 | 0 => Ok(A::one()), 201 | 1 => Ok(A::zero()), 202 | n => { 203 | let mean = self.mean().unwrap(); 204 | let shifted_array = self.mapv(|x| x - mean); 205 | let shifted_moments = moments(shifted_array, n); 206 | let correction_term = -shifted_moments[1]; 207 | 208 | let coefficients = central_moment_coefficients(&shifted_moments); 209 | Ok(horner_method(coefficients, correction_term)) 210 | } 211 | } 212 | } 213 | 214 | fn central_moments(&self, order: u16) -> Result, EmptyInput> 215 | where 216 | A: Float + FromPrimitive, 217 | { 218 | if self.is_empty() { 219 | return Err(EmptyInput); 220 | } 221 | match order { 222 | 0 => Ok(vec![A::one()]), 223 | 1 => Ok(vec![A::one(), A::zero()]), 224 | n => { 225 | // We only perform these operations once, and then reuse their 226 | // result to compute all the required moments 227 | let mean = self.mean().unwrap(); 228 | let shifted_array = self.mapv(|x| x - mean); 229 | let shifted_moments = moments(shifted_array, n); 230 | let correction_term = -shifted_moments[1]; 231 | 232 | let mut central_moments = vec![A::one(), A::zero()]; 233 | for k in 2..=n { 234 | let coefficients = 235 | central_moment_coefficients(&shifted_moments[..=(k as usize)]); 236 | let central_moment = horner_method(coefficients, correction_term); 237 | central_moments.push(central_moment) 238 | 
} 239 | Ok(central_moments) 240 | } 241 | } 242 | } 243 | 244 | private_impl! {} 245 | } 246 | 247 | /// Private function for `weighted_var` without conditions and asserts. 248 | fn inner_weighted_var( 249 | arr: &ArrayBase, 250 | weights: &ArrayBase, 251 | ddof: A, 252 | zero: A, 253 | ) -> Result 254 | where 255 | S: Data, 256 | A: AddAssign + Float + FromPrimitive, 257 | D: Dimension, 258 | { 259 | let mut weight_sum = zero; 260 | let mut mean = zero; 261 | let mut s = zero; 262 | for (&x, &w) in arr.iter().zip(weights.iter()) { 263 | weight_sum += w; 264 | let x_minus_mean = x - mean; 265 | mean += (w / weight_sum) * x_minus_mean; 266 | s += w * x_minus_mean * (x - mean); 267 | } 268 | Ok(s / (weight_sum - ddof)) 269 | } 270 | 271 | /// Returns a vector containing all moments of the array elements up to 272 | /// *order*, where the *p*-th moment is defined as: 273 | /// 274 | /// ```text 275 | /// 1 n 276 | /// ― ∑ xᵢᵖ 277 | /// n i=1 278 | /// ``` 279 | /// 280 | /// The returned moments are ordered by power magnitude: 0th moment, 1st moment, etc. 281 | /// 282 | /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. 
283 | fn moments(a: ArrayBase, order: u16) -> Vec 284 | where 285 | A: Float + FromPrimitive, 286 | S: Data, 287 | D: Dimension, 288 | { 289 | let n_elements = 290 | A::from_usize(a.len()).expect("Converting number of elements to `A` must not fail"); 291 | let order = i32::from(order); 292 | 293 | // When k=0, we are raising each element to the 0th power 294 | // No need to waste CPU cycles going through the array 295 | let mut moments = vec![A::one()]; 296 | 297 | if order >= 1 { 298 | // When k=1, we don't need to raise elements to the 1th power (identity) 299 | moments.push(a.sum() / n_elements) 300 | } 301 | 302 | for k in 2..=order { 303 | moments.push(a.map(|x| x.powi(k)).sum() / n_elements) 304 | } 305 | moments 306 | } 307 | 308 | /// Returns the coefficients in the polynomial expression to compute the *p*th 309 | /// central moment as a function of the sample mean. 310 | /// 311 | /// It takes as input all moments up to order *p*, ordered by power magnitude - *p* is 312 | /// inferred to be the length of the *moments* array. 313 | fn central_moment_coefficients(moments: &[A]) -> Vec 314 | where 315 | A: Float + FromPrimitive, 316 | { 317 | let order = moments.len(); 318 | IterBinomial::new(order) 319 | .zip(moments.iter().rev()) 320 | .map(|(binom, &moment)| A::from_usize(binom).unwrap() * moment) 321 | .collect() 322 | } 323 | 324 | /// Uses [Horner's method] to evaluate a polynomial with a single indeterminate. 325 | /// 326 | /// Coefficients are expected to be sorted by ascending order 327 | /// with respect to the indeterminate's exponent. 328 | /// 329 | /// If the array is empty, `A::zero()` is returned. 330 | /// 331 | /// Horner's method can evaluate a polynomial of order *n* with a single indeterminate 332 | /// using only *n-1* multiplications and *n-1* sums - in terms of number of operations, 333 | /// this is an optimal algorithm for polynomial evaluation. 
334 | /// 335 | /// [Horner's method]: https://en.wikipedia.org/wiki/Horner%27s_method 336 | fn horner_method(coefficients: Vec, indeterminate: A) -> A 337 | where 338 | A: Float, 339 | { 340 | let mut result = A::zero(); 341 | for coefficient in coefficients.into_iter().rev() { 342 | result = coefficient + indeterminate * result 343 | } 344 | result 345 | } 346 | -------------------------------------------------------------------------------- /src/summary_statistics/mod.rs: -------------------------------------------------------------------------------- 1 | //! Summary statistics (e.g. mean, variance, etc.). 2 | use crate::errors::{EmptyInput, MultiInputError}; 3 | use ndarray::{Array, ArrayBase, Axis, Data, Dimension, Ix1, RemoveAxis}; 4 | use num_traits::{Float, FromPrimitive, Zero}; 5 | use std::ops::{Add, AddAssign, Div, Mul}; 6 | 7 | /// Extension trait for `ArrayBase` providing methods 8 | /// to compute several summary statistics (e.g. mean, variance, etc.). 9 | pub trait SummaryStatisticsExt 10 | where 11 | S: Data, 12 | D: Dimension, 13 | { 14 | /// Returns the [`arithmetic mean`] x̅ of all elements in the array: 15 | /// 16 | /// ```text 17 | /// 1 n 18 | /// x̅ = ― ∑ xᵢ 19 | /// n i=1 20 | /// ``` 21 | /// 22 | /// If the array is empty, `Err(EmptyInput)` is returned. 23 | /// 24 | /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. 25 | /// 26 | /// [`arithmetic mean`]: https://en.wikipedia.org/wiki/Arithmetic_mean 27 | fn mean(&self) -> Result 28 | where 29 | A: Clone + FromPrimitive + Add + Div + Zero; 30 | 31 | /// Returns the [`arithmetic weighted mean`] x̅ of all elements in the array. Use `weighted_sum` 32 | /// if the `weights` are normalized (they sum up to 1.0). 33 | /// 34 | /// ```text 35 | /// n 36 | /// ∑ wᵢxᵢ 37 | /// i=1 38 | /// x̅ = ――――――――― 39 | /// n 40 | /// ∑ wᵢ 41 | /// i=1 42 | /// ``` 43 | /// 44 | /// **Panics** if division by zero panics for type A. 
45 | /// 46 | /// The following **errors** may be returned: 47 | /// 48 | /// * `MultiInputError::EmptyInput` if `self` is empty 49 | /// * `MultiInputError::ShapeMismatch` if `self` and `weights` don't have the same shape 50 | /// 51 | /// [`arithmetic weighted mean`] https://en.wikipedia.org/wiki/Weighted_arithmetic_mean 52 | fn weighted_mean(&self, weights: &Self) -> Result 53 | where 54 | A: Copy + Div + Mul + Zero; 55 | 56 | /// Returns the weighted sum of all elements in the array, that is, the dot product of the 57 | /// arrays `self` and `weights`. Equivalent to `weighted_mean` if the `weights` are normalized. 58 | /// 59 | /// ```text 60 | /// n 61 | /// x̅ = ∑ wᵢxᵢ 62 | /// i=1 63 | /// ``` 64 | /// 65 | /// The following **errors** may be returned: 66 | /// 67 | /// * `MultiInputError::ShapeMismatch` if `self` and `weights` don't have the same shape 68 | fn weighted_sum(&self, weights: &Self) -> Result 69 | where 70 | A: Copy + Mul + Zero; 71 | 72 | /// Returns the [`arithmetic weighted mean`] x̅ along `axis`. Use `weighted_mean_axis ` if the 73 | /// `weights` are normalized. 74 | /// 75 | /// ```text 76 | /// n 77 | /// ∑ wᵢxᵢ 78 | /// i=1 79 | /// x̅ = ――――――――― 80 | /// n 81 | /// ∑ wᵢ 82 | /// i=1 83 | /// ``` 84 | /// 85 | /// **Panics** if `axis` is out of bounds. 86 | /// 87 | /// The following **errors** may be returned: 88 | /// 89 | /// * `MultiInputError::EmptyInput` if `self` is empty 90 | /// * `MultiInputError::ShapeMismatch` if `self` length along axis is not equal to `weights` length 91 | /// 92 | /// [`arithmetic weighted mean`] https://en.wikipedia.org/wiki/Weighted_arithmetic_mean 93 | fn weighted_mean_axis( 94 | &self, 95 | axis: Axis, 96 | weights: &ArrayBase, 97 | ) -> Result, MultiInputError> 98 | where 99 | A: Copy + Div + Mul + Zero, 100 | D: RemoveAxis; 101 | 102 | /// Returns the weighted sum along `axis`, that is, the dot product of `weights` and each lane 103 | /// of `self` along `axis`. 
Equivalent to `weighted_mean_axis` if the `weights` are normalized. 104 | /// 105 | /// ```text 106 | /// n 107 | /// x̅ = ∑ wᵢxᵢ 108 | /// i=1 109 | /// ``` 110 | /// 111 | /// **Panics** if `axis` is out of bounds. 112 | /// 113 | /// The following **errors** may be returned 114 | /// 115 | /// * `MultiInputError::ShapeMismatch` if `self` and `weights` don't have the same shape 116 | fn weighted_sum_axis( 117 | &self, 118 | axis: Axis, 119 | weights: &ArrayBase, 120 | ) -> Result, MultiInputError> 121 | where 122 | A: Copy + Mul + Zero, 123 | D: RemoveAxis; 124 | 125 | /// Returns the [`harmonic mean`] `HM(X)` of all elements in the array: 126 | /// 127 | /// ```text 128 | /// ⎛ n ⎞⁻¹ 129 | /// HM(X) = n ⎜ ∑ xᵢ⁻¹⎟ 130 | /// ⎝i=1 ⎠ 131 | /// ``` 132 | /// 133 | /// If the array is empty, `Err(EmptyInput)` is returned. 134 | /// 135 | /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. 136 | /// 137 | /// [`harmonic mean`]: https://en.wikipedia.org/wiki/Harmonic_mean 138 | fn harmonic_mean(&self) -> Result 139 | where 140 | A: Float + FromPrimitive; 141 | 142 | /// Returns the [`geometric mean`] `GM(X)` of all elements in the array: 143 | /// 144 | /// ```text 145 | /// ⎛ n ⎞¹⁄ₙ 146 | /// GM(X) = ⎜ ∏ xᵢ⎟ 147 | /// ⎝i=1 ⎠ 148 | /// ``` 149 | /// 150 | /// If the array is empty, `Err(EmptyInput)` is returned. 151 | /// 152 | /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. 153 | /// 154 | /// [`geometric mean`]: https://en.wikipedia.org/wiki/Geometric_mean 155 | fn geometric_mean(&self) -> Result 156 | where 157 | A: Float + FromPrimitive; 158 | 159 | /// Return weighted variance of all elements in the array. 160 | /// 161 | /// The weighted variance is computed using the [`West, D. H. D.`] incremental algorithm. 162 | /// Equivalent to `var_axis` if the `weights` are normalized. 163 | /// 164 | /// The parameter `ddof` specifies the "delta degrees of freedom". 
For example, to calculate the 165 | /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. 166 | /// 167 | /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds, 168 | /// or if `A::from_usize()` fails for zero or one. 169 | /// 170 | /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm 171 | fn weighted_var(&self, weights: &Self, ddof: A) -> Result 172 | where 173 | A: AddAssign + Float + FromPrimitive; 174 | 175 | /// Return weighted standard deviation of all elements in the array. 176 | /// 177 | /// The weighted standard deviation is computed using the [`West, D. H. D.`] incremental 178 | /// algorithm. Equivalent to `var_axis` if the `weights` are normalized. 179 | /// 180 | /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the 181 | /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. 182 | /// 183 | /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds, 184 | /// or if `A::from_usize()` fails for zero or one. 185 | /// 186 | /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm 187 | fn weighted_std(&self, weights: &Self, ddof: A) -> Result 188 | where 189 | A: AddAssign + Float + FromPrimitive; 190 | 191 | /// Return weighted variance along `axis`. 192 | /// 193 | /// The weighted variance is computed using the [`West, D. H. D.`] incremental algorithm. 194 | /// Equivalent to `var_axis` if the `weights` are normalized. 195 | /// 196 | /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the 197 | /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. 
198 | /// 199 | /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds, 200 | /// or if `A::from_usize()` fails for zero or one. 201 | /// 202 | /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm 203 | fn weighted_var_axis( 204 | &self, 205 | axis: Axis, 206 | weights: &ArrayBase, 207 | ddof: A, 208 | ) -> Result, MultiInputError> 209 | where 210 | A: AddAssign + Float + FromPrimitive, 211 | D: RemoveAxis; 212 | 213 | /// Return weighted standard deviation along `axis`. 214 | /// 215 | /// The weighted standard deviation is computed using the [`West, D. H. D.`] incremental 216 | /// algorithm. Equivalent to `std_axis` if the `weights` are normalized. 217 | /// 218 | /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the 219 | /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. 220 | /// 221 | /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds, 222 | /// or if `A::from_usize()` fails for zero or one. 223 | /// 224 | /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm 225 | fn weighted_std_axis( 226 | &self, 227 | axis: Axis, 228 | weights: &ArrayBase, 229 | ddof: A, 230 | ) -> Result, MultiInputError> 231 | where 232 | A: AddAssign + Float + FromPrimitive, 233 | D: RemoveAxis; 234 | 235 | /// Returns the [kurtosis] `Kurt[X]` of all elements in the array: 236 | /// 237 | /// ```text 238 | /// Kurt[X] = μ₄ / σ⁴ 239 | /// ``` 240 | /// 241 | /// where μ₄ is the fourth central moment and σ is the standard deviation of 242 | /// the elements in the array. 243 | /// 244 | /// This is sometimes referred to as _Pearson's kurtosis_. Fisher's kurtosis can be 245 | /// computed by subtracting 3 from Pearson's kurtosis. 
246 | /// 247 | /// If the array is empty, `Err(EmptyInput)` is returned. 248 | /// 249 | /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. 250 | /// 251 | /// [kurtosis]: https://en.wikipedia.org/wiki/Kurtosis 252 | fn kurtosis(&self) -> Result 253 | where 254 | A: Float + FromPrimitive; 255 | 256 | /// Returns the [Pearson's moment coefficient of skewness] γ₁ of all elements in the array: 257 | /// 258 | /// ```text 259 | /// γ₁ = μ₃ / σ³ 260 | /// ``` 261 | /// 262 | /// where μ₃ is the third central moment and σ is the standard deviation of 263 | /// the elements in the array. 264 | /// 265 | /// If the array is empty, `Err(EmptyInput)` is returned. 266 | /// 267 | /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array. 268 | /// 269 | /// [Pearson's moment coefficient of skewness]: https://en.wikipedia.org/wiki/Skewness 270 | fn skewness(&self) -> Result 271 | where 272 | A: Float + FromPrimitive; 273 | 274 | /// Returns the *p*-th [central moment] of all elements in the array, μₚ: 275 | /// 276 | /// ```text 277 | /// 1 n 278 | /// μₚ = ― ∑ (xᵢ-x̅)ᵖ 279 | /// n i=1 280 | /// ``` 281 | /// 282 | /// If the array is empty, `Err(EmptyInput)` is returned. 283 | /// 284 | /// The *p*-th central moment is computed using a corrected two-pass algorithm (see Section 3.5 285 | /// in [Pébay et al., 2016]). Complexity is *O(np)* when *n >> p*, *p > 1*. 286 | /// 287 | /// **Panics** if `A::from_usize()` fails to convert the number of elements 288 | /// in the array or if `order` overflows `i32`. 289 | /// 290 | /// [central moment]: https://en.wikipedia.org/wiki/Central_moment 291 | /// [Pébay et al., 2016]: https://www.osti.gov/pages/servlets/purl/1427275 292 | fn central_moment(&self, order: u16) -> Result 293 | where 294 | A: Float + FromPrimitive; 295 | 296 | /// Returns the first *p* [central moments] of all elements in the array, see [central moment] 297 | /// for more details. 
298 | /// 299 | /// If the array is empty, `Err(EmptyInput)` is returned. 300 | /// 301 | /// This method reuses the intermediate steps for the *k*-th moment to compute the *(k+1)*-th, 302 | /// being thus more efficient than repeated calls to [central moment] if the computation 303 | /// of central moments of multiple orders is required. 304 | /// 305 | /// **Panics** if `A::from_usize()` fails to convert the number of elements 306 | /// in the array or if `order` overflows `i32`. 307 | /// 308 | /// [central moments]: https://en.wikipedia.org/wiki/Central_moment 309 | /// [central moment]: #tymethod.central_moment 310 | fn central_moments(&self, order: u16) -> Result, EmptyInput> 311 | where 312 | A: Float + FromPrimitive; 313 | 314 | private_decl! {} 315 | } 316 | 317 | mod means; 318 | -------------------------------------------------------------------------------- /tests/deviation.rs: -------------------------------------------------------------------------------- 1 | use ndarray_stats::errors::{MultiInputError, ShapeMismatch}; 2 | use ndarray_stats::DeviationExt; 3 | 4 | use approx::assert_abs_diff_eq; 5 | use ndarray::{array, Array1}; 6 | use num_bigint::BigInt; 7 | use num_traits::Float; 8 | 9 | use std::f64; 10 | 11 | #[test] 12 | fn test_count_eq() -> Result<(), MultiInputError> { 13 | let a = array![0., 0.]; 14 | let b = array![1., 0.]; 15 | let c = array![0., 1.]; 16 | let d = array![1., 1.]; 17 | 18 | assert_eq!(a.count_eq(&a)?, 2); 19 | assert_eq!(a.count_eq(&b)?, 1); 20 | assert_eq!(a.count_eq(&c)?, 1); 21 | assert_eq!(a.count_eq(&d)?, 0); 22 | 23 | Ok(()) 24 | } 25 | 26 | #[test] 27 | fn test_count_neq() -> Result<(), MultiInputError> { 28 | let a = array![0., 0.]; 29 | let b = array![1., 0.]; 30 | let c = array![0., 1.]; 31 | let d = array![1., 1.]; 32 | 33 | assert_eq!(a.count_neq(&a)?, 0); 34 | assert_eq!(a.count_neq(&b)?, 1); 35 | assert_eq!(a.count_neq(&c)?, 1); 36 | assert_eq!(a.count_neq(&d)?, 2); 37 | 38 | Ok(()) 39 | } 40 | 41 | #[test] 42 
| fn test_sq_l2_dist() -> Result<(), MultiInputError> { 43 | let a = array![0., 1., 4., 2.]; 44 | let b = array![1., 1., 2., 4.]; 45 | 46 | assert_eq!(a.sq_l2_dist(&b)?, 9.); 47 | 48 | Ok(()) 49 | } 50 | 51 | #[test] 52 | fn test_l2_dist() -> Result<(), MultiInputError> { 53 | let a = array![0., 1., 4., 2.]; 54 | let b = array![1., 1., 2., 4.]; 55 | 56 | assert_eq!(a.l2_dist(&b)?, 3.); 57 | 58 | Ok(()) 59 | } 60 | 61 | #[test] 62 | fn test_l1_dist() -> Result<(), MultiInputError> { 63 | let a = array![0., 1., 4., 2.]; 64 | let b = array![1., 1., 2., 4.]; 65 | 66 | assert_eq!(a.l1_dist(&b)?, 5.); 67 | 68 | Ok(()) 69 | } 70 | 71 | #[test] 72 | fn test_linf_dist() -> Result<(), MultiInputError> { 73 | let a = array![0., 0.]; 74 | let b = array![1., 0.]; 75 | let c = array![1., 2.]; 76 | 77 | assert_eq!(a.linf_dist(&a)?, 0.); 78 | 79 | assert_eq!(a.linf_dist(&b)?, 1.); 80 | assert_eq!(b.linf_dist(&a)?, 1.); 81 | 82 | assert_eq!(a.linf_dist(&c)?, 2.); 83 | assert_eq!(c.linf_dist(&a)?, 2.); 84 | 85 | Ok(()) 86 | } 87 | 88 | #[test] 89 | fn test_mean_abs_err() -> Result<(), MultiInputError> { 90 | let a = array![1., 1.]; 91 | let b = array![3., 5.]; 92 | 93 | assert_eq!(a.mean_abs_err(&a)?, 0.); 94 | assert_eq!(a.mean_abs_err(&b)?, 3.); 95 | assert_eq!(b.mean_abs_err(&a)?, 3.); 96 | 97 | Ok(()) 98 | } 99 | 100 | #[test] 101 | fn test_mean_sq_err() -> Result<(), MultiInputError> { 102 | let a = array![1., 1.]; 103 | let b = array![3., 5.]; 104 | 105 | assert_eq!(a.mean_sq_err(&a)?, 0.); 106 | assert_eq!(a.mean_sq_err(&b)?, 10.); 107 | assert_eq!(b.mean_sq_err(&a)?, 10.); 108 | 109 | Ok(()) 110 | } 111 | 112 | #[test] 113 | fn test_root_mean_sq_err() -> Result<(), MultiInputError> { 114 | let a = array![1., 1.]; 115 | let b = array![3., 5.]; 116 | 117 | assert_eq!(a.root_mean_sq_err(&a)?, 0.); 118 | assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 10.0.sqrt()); 119 | assert_abs_diff_eq!(b.root_mean_sq_err(&a)?, 10.0.sqrt()); 120 | 121 | Ok(()) 122 | } 123 | 124 | #[test] 125 | 
fn test_peak_signal_to_noise_ratio() -> Result<(), MultiInputError> { 126 | let a = array![1., 1.]; 127 | assert!(a.peak_signal_to_noise_ratio(&a, 1.)?.is_infinite()); 128 | 129 | let a = array![1., 2., 3., 4., 5., 6., 7.]; 130 | let b = array![1., 3., 3., 4., 6., 7., 8.]; 131 | let maxv = 8.; 132 | let expected = 20. * Float::log10(maxv) - 10. * Float::log10(a.mean_sq_err(&b)?); 133 | let actual = a.peak_signal_to_noise_ratio(&b, maxv)?; 134 | 135 | assert_abs_diff_eq!(actual, expected); 136 | 137 | Ok(()) 138 | } 139 | 140 | #[test] 141 | fn test_deviations_with_n_by_m_ints() -> Result<(), MultiInputError> { 142 | let a = array![[0, 1], [4, 2]]; 143 | let b = array![[1, 1], [2, 4]]; 144 | 145 | assert_eq!(a.count_eq(&a)?, 4); 146 | assert_eq!(a.count_neq(&a)?, 0); 147 | 148 | assert_eq!(a.sq_l2_dist(&b)?, 9); 149 | assert_eq!(a.l2_dist(&b)?, 3.); 150 | assert_eq!(a.l1_dist(&b)?, 5); 151 | assert_eq!(a.linf_dist(&b)?, 2); 152 | 153 | assert_abs_diff_eq!(a.mean_abs_err(&b)?, 1.25); 154 | assert_abs_diff_eq!(a.mean_sq_err(&b)?, 2.25); 155 | assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 1.5); 156 | assert_abs_diff_eq!(a.peak_signal_to_noise_ratio(&b, 4)?, 8.519374645445623); 157 | 158 | Ok(()) 159 | } 160 | 161 | #[test] 162 | fn test_deviations_with_empty_receiver() { 163 | let a: Array1 = array![]; 164 | let b: Array1 = array![1.]; 165 | 166 | assert_eq!(a.count_eq(&b), Err(MultiInputError::EmptyInput)); 167 | assert_eq!(a.count_neq(&b), Err(MultiInputError::EmptyInput)); 168 | 169 | assert_eq!(a.sq_l2_dist(&b), Err(MultiInputError::EmptyInput)); 170 | assert_eq!(a.l2_dist(&b), Err(MultiInputError::EmptyInput)); 171 | assert_eq!(a.l1_dist(&b), Err(MultiInputError::EmptyInput)); 172 | assert_eq!(a.linf_dist(&b), Err(MultiInputError::EmptyInput)); 173 | 174 | assert_eq!(a.mean_abs_err(&b), Err(MultiInputError::EmptyInput)); 175 | assert_eq!(a.mean_sq_err(&b), Err(MultiInputError::EmptyInput)); 176 | assert_eq!(a.root_mean_sq_err(&b), Err(MultiInputError::EmptyInput)); 
177 | assert_eq!( 178 | a.peak_signal_to_noise_ratio(&b, 0.), 179 | Err(MultiInputError::EmptyInput) 180 | ); 181 | } 182 | 183 | #[test] 184 | fn test_deviations_do_not_panic_if_nans() -> Result<(), MultiInputError> { 185 | let a: Array1 = array![1., f64::NAN, 3., f64::NAN]; 186 | let b: Array1 = array![1., f64::NAN, 3., 4.]; 187 | 188 | assert_eq!(a.count_eq(&b)?, 2); 189 | assert_eq!(a.count_neq(&b)?, 2); 190 | 191 | assert!(a.sq_l2_dist(&b)?.is_nan()); 192 | assert!(a.l2_dist(&b)?.is_nan()); 193 | assert!(a.l1_dist(&b)?.is_nan()); 194 | assert_eq!(a.linf_dist(&b)?, 0.); 195 | 196 | assert!(a.mean_abs_err(&b)?.is_nan()); 197 | assert!(a.mean_sq_err(&b)?.is_nan()); 198 | assert!(a.root_mean_sq_err(&b)?.is_nan()); 199 | assert!(a.peak_signal_to_noise_ratio(&b, 0.)?.is_nan()); 200 | 201 | Ok(()) 202 | } 203 | 204 | #[test] 205 | fn test_deviations_with_empty_argument() { 206 | let a: Array1 = array![1.]; 207 | let b: Array1 = array![]; 208 | 209 | let shape_mismatch_err = MultiInputError::ShapeMismatch(ShapeMismatch { 210 | first_shape: a.shape().to_vec(), 211 | second_shape: b.shape().to_vec(), 212 | }); 213 | let expected_err_usize = Err(shape_mismatch_err.clone()); 214 | let expected_err_f64 = Err(shape_mismatch_err); 215 | 216 | assert_eq!(a.count_eq(&b), expected_err_usize); 217 | assert_eq!(a.count_neq(&b), expected_err_usize); 218 | 219 | assert_eq!(a.sq_l2_dist(&b), expected_err_f64); 220 | assert_eq!(a.l2_dist(&b), expected_err_f64); 221 | assert_eq!(a.l1_dist(&b), expected_err_f64); 222 | assert_eq!(a.linf_dist(&b), expected_err_f64); 223 | 224 | assert_eq!(a.mean_abs_err(&b), expected_err_f64); 225 | assert_eq!(a.mean_sq_err(&b), expected_err_f64); 226 | assert_eq!(a.root_mean_sq_err(&b), expected_err_f64); 227 | assert_eq!(a.peak_signal_to_noise_ratio(&b, 0.), expected_err_f64); 228 | } 229 | 230 | #[test] 231 | fn test_deviations_with_non_copyable() -> Result<(), MultiInputError> { 232 | let a: Array1 = array![0.into(), 1.into(), 4.into(), 2.into()]; 
233 | let b: Array1 = array![1.into(), 1.into(), 2.into(), 4.into()]; 234 | 235 | assert_eq!(a.count_eq(&a)?, 4); 236 | assert_eq!(a.count_neq(&a)?, 0); 237 | 238 | assert_eq!(a.sq_l2_dist(&b)?, 9.into()); 239 | assert_eq!(a.l2_dist(&b)?, 3.); 240 | assert_eq!(a.l1_dist(&b)?, 5.into()); 241 | assert_eq!(a.linf_dist(&b)?, 2.into()); 242 | 243 | assert_abs_diff_eq!(a.mean_abs_err(&b)?, 1.25); 244 | assert_abs_diff_eq!(a.mean_sq_err(&b)?, 2.25); 245 | assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 1.5); 246 | assert_abs_diff_eq!( 247 | a.peak_signal_to_noise_ratio(&b, 4.into())?, 248 | 8.519374645445623 249 | ); 250 | 251 | Ok(()) 252 | } 253 | 254 | #[test] 255 | fn test_deviation_computation_for_mixed_ownership() { 256 | // It's enough to check that the code compiles! 257 | let a = array![0., 0.]; 258 | let b = array![1., 0.]; 259 | 260 | let _ = a.count_eq(&b.view()); 261 | let _ = a.count_neq(&b.view()); 262 | let _ = a.l2_dist(&b.view()); 263 | let _ = a.sq_l2_dist(&b.view()); 264 | let _ = a.l1_dist(&b.view()); 265 | let _ = a.linf_dist(&b.view()); 266 | let _ = a.mean_abs_err(&b.view()); 267 | let _ = a.mean_sq_err(&b.view()); 268 | let _ = a.root_mean_sq_err(&b.view()); 269 | let _ = a.peak_signal_to_noise_ratio(&b.view(), 10.); 270 | } 271 | -------------------------------------------------------------------------------- /tests/maybe_nan.rs: -------------------------------------------------------------------------------- 1 | use ndarray::prelude::*; 2 | use ndarray_stats::MaybeNan; 3 | use noisy_float::types::{n64, N64}; 4 | 5 | #[test] 6 | fn remove_nan_mut_nonstandard_layout() { 7 | fn eq_unordered(mut a: Vec, mut b: Vec) -> bool { 8 | a.sort(); 9 | b.sort(); 10 | a == b 11 | } 12 | let a = aview1(&[1., 2., f64::NAN, f64::NAN, 3., f64::NAN, 4., 5.]); 13 | { 14 | let mut a = a.to_owned(); 15 | let v = f64::remove_nan_mut(a.slice_mut(s![..;2])); 16 | assert!(eq_unordered(v.to_vec(), vec![n64(1.), n64(3.), n64(4.)])); 17 | } 18 | { 19 | let mut a = a.to_owned(); 
20 | let v = f64::remove_nan_mut(a.slice_mut(s![..;-1])); 21 | assert!(eq_unordered( 22 | v.to_vec(), 23 | vec![n64(5.), n64(4.), n64(3.), n64(2.), n64(1.)], 24 | )); 25 | } 26 | { 27 | let mut a = a.to_owned(); 28 | let v = f64::remove_nan_mut(a.slice_mut(s![..;-2])); 29 | assert!(eq_unordered(v.to_vec(), vec![n64(5.), n64(2.)])); 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /tests/quantile.rs: -------------------------------------------------------------------------------- 1 | use itertools::izip; 2 | use ndarray::array; 3 | use ndarray::prelude::*; 4 | use ndarray_stats::{ 5 | errors::{EmptyInput, MinMaxError, QuantileError}, 6 | interpolate::{Higher, Interpolate, Linear, Lower, Midpoint, Nearest}, 7 | Quantile1dExt, QuantileExt, 8 | }; 9 | use noisy_float::types::{n64, N64}; 10 | use quickcheck_macros::quickcheck; 11 | 12 | #[test] 13 | fn test_argmin() { 14 | let a = array![[1, 5, 3], [2, 0, 6]]; 15 | assert_eq!(a.argmin(), Ok((1, 1))); 16 | 17 | let a = array![[1., 5., 3.], [2., 0., 6.]]; 18 | assert_eq!(a.argmin(), Ok((1, 1))); 19 | 20 | let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]]; 21 | assert_eq!(a.argmin(), Err(MinMaxError::UndefinedOrder)); 22 | 23 | let a: Array2 = array![[], []]; 24 | assert_eq!(a.argmin(), Err(MinMaxError::EmptyInput)); 25 | } 26 | 27 | #[quickcheck] 28 | fn argmin_matches_min(data: Vec) -> bool { 29 | let a = Array1::from(data); 30 | a.argmin().map(|i| &a[i]) == a.min() 31 | } 32 | 33 | #[test] 34 | fn test_argmin_skipnan() { 35 | let a = array![[1., 5., 3.], [2., 0., 6.]]; 36 | assert_eq!(a.argmin_skipnan(), Ok((1, 1))); 37 | 38 | let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]]; 39 | assert_eq!(a.argmin_skipnan(), Ok((0, 0))); 40 | 41 | let a = array![[::std::f64::NAN, 5., 3.], [2., ::std::f64::NAN, 6.]]; 42 | assert_eq!(a.argmin_skipnan(), Ok((1, 0))); 43 | 44 | let a: Array2 = array![[], []]; 45 | assert_eq!(a.argmin_skipnan(), Err(EmptyInput)); 46 | 47 | 
let a = arr2(&[[::std::f64::NAN; 2]; 2]); 48 | assert_eq!(a.argmin_skipnan(), Err(EmptyInput)); 49 | } 50 | 51 | #[quickcheck] 52 | fn argmin_skipnan_matches_min_skipnan(data: Vec>) -> bool { 53 | let a = Array1::from(data); 54 | let min = a.min_skipnan(); 55 | let argmin = a.argmin_skipnan(); 56 | if min.is_none() { 57 | argmin == Err(EmptyInput) 58 | } else { 59 | a[argmin.unwrap()] == *min 60 | } 61 | } 62 | 63 | #[test] 64 | fn test_min() { 65 | let a = array![[1, 5, 3], [2, 0, 6]]; 66 | assert_eq!(a.min(), Ok(&0)); 67 | 68 | let a = array![[1., 5., 3.], [2., 0., 6.]]; 69 | assert_eq!(a.min(), Ok(&0.)); 70 | 71 | let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]]; 72 | assert_eq!(a.min(), Err(MinMaxError::UndefinedOrder)); 73 | } 74 | 75 | #[test] 76 | fn test_min_skipnan() { 77 | let a = array![[1., 5., 3.], [2., 0., 6.]]; 78 | assert_eq!(a.min_skipnan(), &0.); 79 | 80 | let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]]; 81 | assert_eq!(a.min_skipnan(), &1.); 82 | } 83 | 84 | #[test] 85 | fn test_min_skipnan_all_nan() { 86 | let a = arr2(&[[::std::f64::NAN; 3]; 2]); 87 | assert!(a.min_skipnan().is_nan()); 88 | } 89 | 90 | #[test] 91 | fn test_argmax() { 92 | let a = array![[1, 5, 3], [2, 0, 6]]; 93 | assert_eq!(a.argmax(), Ok((1, 2))); 94 | 95 | let a = array![[1., 5., 3.], [2., 0., 6.]]; 96 | assert_eq!(a.argmax(), Ok((1, 2))); 97 | 98 | let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]]; 99 | assert_eq!(a.argmax(), Err(MinMaxError::UndefinedOrder)); 100 | 101 | let a: Array2 = array![[], []]; 102 | assert_eq!(a.argmax(), Err(MinMaxError::EmptyInput)); 103 | } 104 | 105 | #[quickcheck] 106 | fn argmax_matches_max(data: Vec) -> bool { 107 | let a = Array1::from(data); 108 | a.argmax().map(|i| &a[i]) == a.max() 109 | } 110 | 111 | #[test] 112 | fn test_argmax_skipnan() { 113 | let a = array![[1., 5., 3.], [2., 0., 6.]]; 114 | assert_eq!(a.argmax_skipnan(), Ok((1, 2))); 115 | 116 | let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 
::std::f64::NAN]]; 117 | assert_eq!(a.argmax_skipnan(), Ok((0, 1))); 118 | 119 | let a = array![ 120 | [::std::f64::NAN, ::std::f64::NAN, 3.], 121 | [2., ::std::f64::NAN, 6.] 122 | ]; 123 | assert_eq!(a.argmax_skipnan(), Ok((1, 2))); 124 | 125 | let a: Array2 = array![[], []]; 126 | assert_eq!(a.argmax_skipnan(), Err(EmptyInput)); 127 | 128 | let a = arr2(&[[::std::f64::NAN; 2]; 2]); 129 | assert_eq!(a.argmax_skipnan(), Err(EmptyInput)); 130 | } 131 | 132 | #[quickcheck] 133 | fn argmax_skipnan_matches_max_skipnan(data: Vec>) -> bool { 134 | let a = Array1::from(data); 135 | let max = a.max_skipnan(); 136 | let argmax = a.argmax_skipnan(); 137 | if max.is_none() { 138 | argmax == Err(EmptyInput) 139 | } else { 140 | a[argmax.unwrap()] == *max 141 | } 142 | } 143 | 144 | #[test] 145 | fn test_max() { 146 | let a = array![[1, 5, 7], [2, 0, 6]]; 147 | assert_eq!(a.max(), Ok(&7)); 148 | 149 | let a = array![[1., 5., 7.], [2., 0., 6.]]; 150 | assert_eq!(a.max(), Ok(&7.)); 151 | 152 | let a = array![[1., 5., 7.], [2., ::std::f64::NAN, 6.]]; 153 | assert_eq!(a.max(), Err(MinMaxError::UndefinedOrder)); 154 | } 155 | 156 | #[test] 157 | fn test_max_skipnan() { 158 | let a = array![[1., 5., 7.], [2., 0., 6.]]; 159 | assert_eq!(a.max_skipnan(), &7.); 160 | 161 | let a = array![[1., 5., 7.], [2., ::std::f64::NAN, 6.]]; 162 | assert_eq!(a.max_skipnan(), &7.); 163 | } 164 | 165 | #[test] 166 | fn test_max_skipnan_all_nan() { 167 | let a = arr2(&[[::std::f64::NAN; 3]; 2]); 168 | assert!(a.max_skipnan().is_nan()); 169 | } 170 | 171 | #[test] 172 | fn test_quantile_axis_mut_with_odd_axis_length() { 173 | let mut a = arr2(&[[1, 3, 2, 10], [2, 4, 3, 11], [3, 5, 6, 12]]); 174 | let p = a.quantile_axis_mut(Axis(0), n64(0.5), &Lower).unwrap(); 175 | assert!(p == a.index_axis(Axis(0), 1)); 176 | } 177 | 178 | #[test] 179 | fn test_quantile_axis_mut_with_zero_axis_length() { 180 | let mut a = Array2::::zeros((5, 0)); 181 | assert_eq!( 182 | a.quantile_axis_mut(Axis(1), n64(0.5), &Lower), 
183 | Err(QuantileError::EmptyInput) 184 | ); 185 | } 186 | 187 | #[test] 188 | fn test_quantile_axis_mut_with_empty_array() { 189 | let mut a = Array2::::zeros((5, 0)); 190 | let p = a.quantile_axis_mut(Axis(0), n64(0.5), &Lower).unwrap(); 191 | assert_eq!(p.shape(), &[0]); 192 | } 193 | 194 | #[test] 195 | fn test_quantile_axis_mut_with_even_axis_length() { 196 | let mut b = arr2(&[[1, 3, 2, 10], [2, 4, 3, 11], [3, 5, 6, 12], [4, 6, 7, 13]]); 197 | let q = b.quantile_axis_mut(Axis(0), n64(0.5), &Lower).unwrap(); 198 | assert!(q == b.index_axis(Axis(0), 1)); 199 | } 200 | 201 | #[test] 202 | fn test_quantile_axis_mut_to_get_minimum() { 203 | let mut b = arr2(&[[1, 3, 22, 10]]); 204 | let q = b.quantile_axis_mut(Axis(1), n64(0.), &Lower).unwrap(); 205 | assert!(q == arr1(&[1])); 206 | } 207 | 208 | #[test] 209 | fn test_quantile_axis_mut_to_get_maximum() { 210 | let mut b = arr1(&[1, 3, 22, 10]); 211 | let q = b.quantile_axis_mut(Axis(0), n64(1.), &Lower).unwrap(); 212 | assert!(q == arr0(22)); 213 | } 214 | 215 | #[test] 216 | fn test_quantile_axis_skipnan_mut_higher_opt_i32() { 217 | let mut a = arr2(&[[Some(4), Some(2), None, Some(1), Some(5)], [None; 5]]); 218 | let q = a 219 | .quantile_axis_skipnan_mut(Axis(1), n64(0.6), &Higher) 220 | .unwrap(); 221 | assert_eq!(q.shape(), &[2]); 222 | assert_eq!(q[0], Some(4)); 223 | assert!(q[1].is_none()); 224 | } 225 | 226 | #[test] 227 | fn test_quantile_axis_skipnan_mut_nearest_opt_i32() { 228 | let mut a = arr2(&[[Some(4), Some(2), None, Some(1), Some(5)], [None; 5]]); 229 | let q = a 230 | .quantile_axis_skipnan_mut(Axis(1), n64(0.6), &Nearest) 231 | .unwrap(); 232 | assert_eq!(q.shape(), &[2]); 233 | assert_eq!(q[0], Some(4)); 234 | assert!(q[1].is_none()); 235 | } 236 | 237 | #[test] 238 | fn test_quantile_axis_skipnan_mut_midpoint_opt_i32() { 239 | let mut a = arr2(&[[Some(4), Some(2), None, Some(1), Some(5)], [None; 5]]); 240 | let q = a 241 | .quantile_axis_skipnan_mut(Axis(1), n64(0.6), &Midpoint) 242 | 
.unwrap(); 243 | assert_eq!(q.shape(), &[2]); 244 | assert_eq!(q[0], Some(3)); 245 | assert!(q[1].is_none()); 246 | } 247 | 248 | #[test] 249 | fn test_quantile_axis_skipnan_mut_linear_f64() { 250 | let mut a = arr2(&[[1., 2., ::std::f64::NAN, 3.], [::std::f64::NAN; 4]]); 251 | let q = a 252 | .quantile_axis_skipnan_mut(Axis(1), n64(0.75), &Linear) 253 | .unwrap(); 254 | assert_eq!(q.shape(), &[2]); 255 | assert!((q[0] - 2.5).abs() < 1e-12); 256 | assert!(q[1].is_nan()); 257 | } 258 | 259 | #[test] 260 | fn test_quantile_axis_skipnan_mut_linear_opt_i32() { 261 | let mut a = arr2(&[[Some(2), Some(4), None, Some(1)], [None; 4]]); 262 | let q = a 263 | .quantile_axis_skipnan_mut(Axis(1), n64(0.75), &Linear) 264 | .unwrap(); 265 | assert_eq!(q.shape(), &[2]); 266 | assert_eq!(q[0], Some(3)); 267 | assert!(q[1].is_none()); 268 | } 269 | 270 | #[test] 271 | fn test_midpoint_overflow() { 272 | // Regression test 273 | // This triggered an overflow panic with a naive Midpoint implementation: (a+b)/2 274 | let mut a: Array1 = array![129, 130, 130, 131]; 275 | let median = a.quantile_mut(n64(0.5), &Midpoint).unwrap(); 276 | let expected_median = 130; 277 | assert_eq!(median, expected_median); 278 | } 279 | 280 | #[quickcheck] 281 | fn test_quantiles_mut(xs: Vec) -> bool { 282 | let v = Array::from(xs.clone()); 283 | 284 | // Unordered list of quantile indexes to look up, with a duplicate 285 | let quantile_indexes = Array::from(vec![ 286 | n64(0.75), 287 | n64(0.90), 288 | n64(0.95), 289 | n64(0.99), 290 | n64(1.), 291 | n64(0.), 292 | n64(0.25), 293 | n64(0.5), 294 | n64(0.5), 295 | ]); 296 | let mut correct = true; 297 | correct &= check_one_interpolation_method_for_quantiles_mut( 298 | v.clone(), 299 | quantile_indexes.view(), 300 | &Linear, 301 | ); 302 | correct &= check_one_interpolation_method_for_quantiles_mut( 303 | v.clone(), 304 | quantile_indexes.view(), 305 | &Higher, 306 | ); 307 | correct &= check_one_interpolation_method_for_quantiles_mut( 308 | v.clone(), 
309 | quantile_indexes.view(), 310 | &Lower, 311 | ); 312 | correct &= check_one_interpolation_method_for_quantiles_mut( 313 | v.clone(), 314 | quantile_indexes.view(), 315 | &Midpoint, 316 | ); 317 | correct &= check_one_interpolation_method_for_quantiles_mut( 318 | v.clone(), 319 | quantile_indexes.view(), 320 | &Nearest, 321 | ); 322 | correct 323 | } 324 | 325 | fn check_one_interpolation_method_for_quantiles_mut( 326 | mut v: Array1, 327 | quantile_indexes: ArrayView1<'_, N64>, 328 | interpolate: &impl Interpolate, 329 | ) -> bool { 330 | let bulk_quantiles = v.clone().quantiles_mut(&quantile_indexes, interpolate); 331 | 332 | if v.len() == 0 { 333 | bulk_quantiles.is_err() 334 | } else { 335 | let bulk_quantiles = bulk_quantiles.unwrap(); 336 | izip!(quantile_indexes, &bulk_quantiles).all(|(&quantile_index, &quantile)| { 337 | quantile == v.quantile_mut(quantile_index, interpolate).unwrap() 338 | }) 339 | } 340 | } 341 | 342 | #[quickcheck] 343 | fn test_quantiles_axis_mut(mut xs: Vec) -> bool { 344 | // We want a square matrix 345 | let axis_length = (xs.len() as f64).sqrt().floor() as usize; 346 | xs.truncate(axis_length * axis_length); 347 | let m = Array::from_shape_vec((axis_length, axis_length), xs).unwrap(); 348 | 349 | // Unordered list of quantile indexes to look up, with a duplicate 350 | let quantile_indexes = Array::from(vec![ 351 | n64(0.75), 352 | n64(0.90), 353 | n64(0.95), 354 | n64(0.99), 355 | n64(1.), 356 | n64(0.), 357 | n64(0.25), 358 | n64(0.5), 359 | n64(0.5), 360 | ]); 361 | 362 | // Test out all interpolation methods 363 | let mut correct = true; 364 | correct &= check_one_interpolation_method_for_quantiles_axis_mut( 365 | m.clone(), 366 | quantile_indexes.view(), 367 | Axis(0), 368 | &Linear, 369 | ); 370 | correct &= check_one_interpolation_method_for_quantiles_axis_mut( 371 | m.clone(), 372 | quantile_indexes.view(), 373 | Axis(0), 374 | &Higher, 375 | ); 376 | correct &= check_one_interpolation_method_for_quantiles_axis_mut( 377 | 
m.clone(), 378 | quantile_indexes.view(), 379 | Axis(0), 380 | &Lower, 381 | ); 382 | correct &= check_one_interpolation_method_for_quantiles_axis_mut( 383 | m.clone(), 384 | quantile_indexes.view(), 385 | Axis(0), 386 | &Midpoint, 387 | ); 388 | correct &= check_one_interpolation_method_for_quantiles_axis_mut( 389 | m.clone(), 390 | quantile_indexes.view(), 391 | Axis(0), 392 | &Nearest, 393 | ); 394 | correct 395 | } 396 | 397 | fn check_one_interpolation_method_for_quantiles_axis_mut( 398 | mut v: Array2, 399 | quantile_indexes: ArrayView1<'_, N64>, 400 | axis: Axis, 401 | interpolate: &impl Interpolate, 402 | ) -> bool { 403 | let bulk_quantiles = v 404 | .clone() 405 | .quantiles_axis_mut(axis, &quantile_indexes, interpolate); 406 | 407 | if v.len() == 0 { 408 | bulk_quantiles.is_err() 409 | } else { 410 | let bulk_quantiles = bulk_quantiles.unwrap(); 411 | izip!(quantile_indexes, bulk_quantiles.axis_iter(axis)).all( 412 | |(&quantile_index, quantile)| { 413 | quantile 414 | == v.quantile_axis_mut(axis, quantile_index, interpolate) 415 | .unwrap() 416 | }, 417 | ) 418 | } 419 | } 420 | -------------------------------------------------------------------------------- /tests/sort.rs: -------------------------------------------------------------------------------- 1 | use ndarray::prelude::*; 2 | use ndarray_stats::Sort1dExt; 3 | use quickcheck_macros::quickcheck; 4 | 5 | #[test] 6 | fn test_partition_mut() { 7 | let mut l = vec![ 8 | arr1(&[1, 1, 1, 1, 1]), 9 | arr1(&[1, 3, 2, 10, 10]), 10 | arr1(&[2, 3, 4, 1]), 11 | arr1(&[ 12 | 355, 453, 452, 391, 289, 343, 44, 154, 271, 44, 314, 276, 160, 469, 191, 138, 163, 308, 13 | 395, 3, 416, 391, 210, 354, 200, 14 | ]), 15 | arr1(&[ 16 | 84, 192, 216, 159, 89, 296, 35, 213, 456, 278, 98, 52, 308, 418, 329, 173, 286, 106, 17 | 366, 129, 125, 450, 23, 463, 151, 18 | ]), 19 | ]; 20 | for a in l.iter_mut() { 21 | let n = a.len(); 22 | let pivot_index = n - 1; 23 | let pivot_value = a[pivot_index].clone(); 24 | let 
partition_index = a.partition_mut(pivot_index); 25 | for i in 0..partition_index { 26 | assert!(a[i] < pivot_value); 27 | } 28 | assert_eq!(a[partition_index], pivot_value); 29 | for j in (partition_index + 1)..n { 30 | assert!(pivot_value <= a[j]); 31 | } 32 | } 33 | } 34 | 35 | #[test] 36 | fn test_sorted_get_mut() { 37 | let a = arr1(&[1, 3, 2, 10]); 38 | let j = a.clone().view_mut().get_from_sorted_mut(2); 39 | assert_eq!(j, 3); 40 | let j = a.clone().view_mut().get_from_sorted_mut(1); 41 | assert_eq!(j, 2); 42 | let j = a.clone().view_mut().get_from_sorted_mut(3); 43 | assert_eq!(j, 10); 44 | } 45 | 46 | #[quickcheck] 47 | fn test_sorted_get_many_mut(mut xs: Vec) -> bool { 48 | let n = xs.len(); 49 | if n == 0 { 50 | true 51 | } else { 52 | let mut v = Array::from(xs.clone()); 53 | 54 | // Insert each index twice, to get a set of indexes with duplicates, not sorted 55 | let mut indexes: Vec = (0..n).into_iter().collect(); 56 | indexes.append(&mut (0..n).collect()); 57 | 58 | let mut sorted_v = Vec::with_capacity(n); 59 | for (i, (key, value)) in v 60 | .get_many_from_sorted_mut(&Array::from(indexes)) 61 | .into_iter() 62 | .enumerate() 63 | { 64 | if i != key { 65 | return false; 66 | } 67 | sorted_v.push(value); 68 | } 69 | xs.sort(); 70 | println!("Sorted: {:?}. 
Truth: {:?}", sorted_v, xs); 71 | xs == sorted_v 72 | } 73 | } 74 | 75 | #[quickcheck] 76 | fn test_sorted_get_mut_as_sorting_algorithm(mut xs: Vec) -> bool { 77 | let n = xs.len(); 78 | if n == 0 { 79 | true 80 | } else { 81 | let mut v = Array::from(xs.clone()); 82 | let sorted_v: Vec<_> = (0..n).map(|i| v.get_from_sorted_mut(i)).collect(); 83 | xs.sort(); 84 | xs == sorted_v 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /tests/summary_statistics.rs: -------------------------------------------------------------------------------- 1 | use approx::{abs_diff_eq, assert_abs_diff_eq}; 2 | use ndarray::{arr0, array, Array, Array1, Array2, Axis}; 3 | use ndarray_rand::rand_distr::Uniform; 4 | use ndarray_rand::RandomExt; 5 | use ndarray_stats::{ 6 | errors::{EmptyInput, MultiInputError, ShapeMismatch}, 7 | SummaryStatisticsExt, 8 | }; 9 | use noisy_float::types::N64; 10 | use quickcheck::{quickcheck, TestResult}; 11 | use std::f64; 12 | 13 | #[test] 14 | fn test_with_nan_values() { 15 | let a = array![f64::NAN, 1.]; 16 | let weights = array![1.0, f64::NAN]; 17 | assert!(a.mean().unwrap().is_nan()); 18 | assert!(a.weighted_mean(&weights).unwrap().is_nan()); 19 | assert!(a.weighted_sum(&weights).unwrap().is_nan()); 20 | assert!(a 21 | .weighted_mean_axis(Axis(0), &weights) 22 | .unwrap() 23 | .into_scalar() 24 | .is_nan()); 25 | assert!(a 26 | .weighted_sum_axis(Axis(0), &weights) 27 | .unwrap() 28 | .into_scalar() 29 | .is_nan()); 30 | assert!(a.harmonic_mean().unwrap().is_nan()); 31 | assert!(a.geometric_mean().unwrap().is_nan()); 32 | assert!(a.weighted_var(&weights, 0.0).unwrap().is_nan()); 33 | assert!(a.weighted_std(&weights, 0.0).unwrap().is_nan()); 34 | assert!(a 35 | .weighted_var_axis(Axis(0), &weights, 0.0) 36 | .unwrap() 37 | .into_scalar() 38 | .is_nan()); 39 | assert!(a 40 | .weighted_std_axis(Axis(0), &weights, 0.0) 41 | .unwrap() 42 | .into_scalar() 43 | .is_nan()); 44 | } 45 | 46 | #[test] 47 | fn 
test_with_empty_array_of_floats() { 48 | let a: Array1 = array![]; 49 | let weights = array![1.0]; 50 | assert_eq!(a.mean(), None); 51 | assert_eq!(a.weighted_mean(&weights), Err(MultiInputError::EmptyInput)); 52 | assert_eq!( 53 | a.weighted_mean_axis(Axis(0), &weights), 54 | Err(MultiInputError::EmptyInput) 55 | ); 56 | assert_eq!(a.harmonic_mean(), Err(EmptyInput)); 57 | assert_eq!(a.geometric_mean(), Err(EmptyInput)); 58 | assert_eq!( 59 | a.weighted_var(&weights, 0.0), 60 | Err(MultiInputError::EmptyInput) 61 | ); 62 | assert_eq!( 63 | a.weighted_std(&weights, 0.0), 64 | Err(MultiInputError::EmptyInput) 65 | ); 66 | assert_eq!( 67 | a.weighted_var_axis(Axis(0), &weights, 0.0), 68 | Err(MultiInputError::EmptyInput) 69 | ); 70 | assert_eq!( 71 | a.weighted_std_axis(Axis(0), &weights, 0.0), 72 | Err(MultiInputError::EmptyInput) 73 | ); 74 | 75 | // The sum methods accept empty arrays 76 | assert_eq!(a.weighted_sum(&array![]), Ok(0.0)); 77 | assert_eq!(a.weighted_sum_axis(Axis(0), &array![]), Ok(arr0(0.0))); 78 | } 79 | 80 | #[test] 81 | fn test_with_empty_array_of_noisy_floats() { 82 | let a: Array1 = array![]; 83 | let weights = array![]; 84 | assert_eq!(a.mean(), None); 85 | assert_eq!(a.weighted_mean(&weights), Err(MultiInputError::EmptyInput)); 86 | assert_eq!( 87 | a.weighted_mean_axis(Axis(0), &weights), 88 | Err(MultiInputError::EmptyInput) 89 | ); 90 | assert_eq!(a.harmonic_mean(), Err(EmptyInput)); 91 | assert_eq!(a.geometric_mean(), Err(EmptyInput)); 92 | assert_eq!( 93 | a.weighted_var(&weights, N64::new(0.0)), 94 | Err(MultiInputError::EmptyInput) 95 | ); 96 | assert_eq!( 97 | a.weighted_std(&weights, N64::new(0.0)), 98 | Err(MultiInputError::EmptyInput) 99 | ); 100 | assert_eq!( 101 | a.weighted_var_axis(Axis(0), &weights, N64::new(0.0)), 102 | Err(MultiInputError::EmptyInput) 103 | ); 104 | assert_eq!( 105 | a.weighted_std_axis(Axis(0), &weights, N64::new(0.0)), 106 | Err(MultiInputError::EmptyInput) 107 | ); 108 | 109 | // The sum methods accept 
empty arrays 110 | assert_eq!(a.weighted_sum(&weights), Ok(N64::new(0.0))); 111 | assert_eq!( 112 | a.weighted_sum_axis(Axis(0), &weights), 113 | Ok(arr0(N64::new(0.0))) 114 | ); 115 | } 116 | 117 | #[test] 118 | fn test_with_array_of_floats() { 119 | let a: Array1 = array![ 120 | 0.99889651, 0.0150731, 0.28492482, 0.83819218, 0.48413156, 0.80710412, 0.41762936, 121 | 0.22879429, 0.43997224, 0.23831807, 0.02416466, 0.6269962, 0.47420614, 0.56275487, 122 | 0.78995021, 0.16060581, 0.64635041, 0.34876609, 0.78543249, 0.19938356, 0.34429457, 123 | 0.88072369, 0.17638164, 0.60819363, 0.250392, 0.69912532, 0.78855523, 0.79140914, 124 | 0.85084218, 0.31839879, 0.63381769, 0.22421048, 0.70760302, 0.99216018, 0.80199153, 125 | 0.19239188, 0.61356023, 0.31505352, 0.06120481, 0.66417377, 0.63608897, 0.84959691, 126 | 0.43599069, 0.77867775, 0.88267754, 0.83003623, 0.67016118, 0.67547638, 0.65220036, 127 | 0.68043427 128 | ]; 129 | // Computed using NumPy 130 | let expected_mean = 0.5475494059146699; 131 | let expected_weighted_mean = 0.6782420496397121; 132 | let expected_weighted_var = 0.04306695637838332; 133 | // Computed using SciPy 134 | let expected_harmonic_mean = 0.21790094950226022; 135 | let expected_geometric_mean = 0.4345897639796527; 136 | 137 | assert_abs_diff_eq!(a.mean().unwrap(), expected_mean, epsilon = 1e-9); 138 | assert_abs_diff_eq!( 139 | a.harmonic_mean().unwrap(), 140 | expected_harmonic_mean, 141 | epsilon = 1e-7 142 | ); 143 | assert_abs_diff_eq!( 144 | a.geometric_mean().unwrap(), 145 | expected_geometric_mean, 146 | epsilon = 1e-12 147 | ); 148 | 149 | // Input array used as weights, normalized 150 | let weights = &a / a.sum(); 151 | assert_abs_diff_eq!( 152 | a.weighted_sum(&weights).unwrap(), 153 | expected_weighted_mean, 154 | epsilon = 1e-12 155 | ); 156 | assert_abs_diff_eq!( 157 | a.weighted_var(&weights, 0.0).unwrap(), 158 | expected_weighted_var, 159 | epsilon = 1e-12 160 | ); 161 | assert_abs_diff_eq!( 162 | a.weighted_std(&weights, 
0.0).unwrap(), 163 | expected_weighted_var.sqrt(), 164 | epsilon = 1e-12 165 | ); 166 | 167 | let data = a.into_shape_with_order((2, 5, 5)).unwrap(); 168 | let weights = array![0.1, 0.5, 0.25, 0.15, 0.2]; 169 | assert_abs_diff_eq!( 170 | data.weighted_mean_axis(Axis(1), &weights).unwrap(), 171 | array![ 172 | [0.50202721, 0.53347361, 0.29086033, 0.56995637, 0.37087139], 173 | [0.58028328, 0.50485216, 0.59349973, 0.70308937, 0.72280630] 174 | ], 175 | epsilon = 1e-8 176 | ); 177 | assert_abs_diff_eq!( 178 | data.weighted_mean_axis(Axis(2), &weights).unwrap(), 179 | array![ 180 | [0.33434378, 0.38365259, 0.56405781, 0.48676574, 0.55016179], 181 | [0.71112376, 0.55134174, 0.45566513, 0.74228516, 0.68405851] 182 | ], 183 | epsilon = 1e-8 184 | ); 185 | assert_abs_diff_eq!( 186 | data.weighted_sum_axis(Axis(1), &weights).unwrap(), 187 | array![ 188 | [0.60243266, 0.64016833, 0.34903240, 0.68394765, 0.44504567], 189 | [0.69633993, 0.60582259, 0.71219968, 0.84370724, 0.86736757] 190 | ], 191 | epsilon = 1e-8 192 | ); 193 | assert_abs_diff_eq!( 194 | data.weighted_sum_axis(Axis(2), &weights).unwrap(), 195 | array![ 196 | [0.40121254, 0.46038311, 0.67686937, 0.58411889, 0.66019415], 197 | [0.85334851, 0.66161009, 0.54679815, 0.89074219, 0.82087021] 198 | ], 199 | epsilon = 1e-8 200 | ); 201 | } 202 | 203 | #[test] 204 | fn weighted_sum_dimension_zero() { 205 | let a = Array2::::zeros((0, 20)); 206 | assert_eq!( 207 | a.weighted_sum_axis(Axis(0), &Array1::zeros(0)).unwrap(), 208 | Array1::from_elem(20, 0) 209 | ); 210 | assert_eq!( 211 | a.weighted_sum_axis(Axis(1), &Array1::zeros(20)).unwrap(), 212 | Array1::from_elem(0, 0) 213 | ); 214 | assert_eq!( 215 | a.weighted_sum_axis(Axis(0), &Array1::zeros(1)), 216 | Err(MultiInputError::ShapeMismatch(ShapeMismatch { 217 | first_shape: vec![0, 20], 218 | second_shape: vec![1] 219 | })) 220 | ); 221 | assert_eq!( 222 | a.weighted_sum(&Array2::zeros((10, 20))), 223 | Err(MultiInputError::ShapeMismatch(ShapeMismatch { 224 | 
first_shape: vec![0, 20], 225 | second_shape: vec![10, 20] 226 | })) 227 | ); 228 | } 229 | 230 | #[test] 231 | fn mean_eq_if_uniform_weights() { 232 | fn prop(a: Vec) -> TestResult { 233 | if a.len() < 1 { 234 | return TestResult::discard(); 235 | } 236 | let a = Array1::from(a); 237 | let weights = Array1::from_elem(a.len(), 1.0 / a.len() as f64); 238 | let m = a.mean().unwrap(); 239 | let wm = a.weighted_mean(&weights).unwrap(); 240 | let ws = a.weighted_sum(&weights).unwrap(); 241 | TestResult::from_bool( 242 | abs_diff_eq!(m, wm, epsilon = 1e-9) && abs_diff_eq!(wm, ws, epsilon = 1e-9), 243 | ) 244 | } 245 | quickcheck(prop as fn(Vec) -> TestResult); 246 | } 247 | 248 | #[test] 249 | fn mean_axis_eq_if_uniform_weights() { 250 | fn prop(mut a: Vec) -> TestResult { 251 | if a.len() < 24 { 252 | return TestResult::discard(); 253 | } 254 | let depth = a.len() / 12; 255 | a.truncate(depth * 3 * 4); 256 | let weights = Array1::from_elem(depth, 1.0 / depth as f64); 257 | let a = Array1::from(a) 258 | .into_shape_with_order((depth, 3, 4)) 259 | .unwrap(); 260 | let ma = a.mean_axis(Axis(0)).unwrap(); 261 | let wm = a.weighted_mean_axis(Axis(0), &weights).unwrap(); 262 | let ws = a.weighted_sum_axis(Axis(0), &weights).unwrap(); 263 | TestResult::from_bool( 264 | abs_diff_eq!(ma, wm, epsilon = 1e-12) && abs_diff_eq!(wm, ws, epsilon = 1e12), 265 | ) 266 | } 267 | quickcheck(prop as fn(Vec) -> TestResult); 268 | } 269 | 270 | #[test] 271 | fn weighted_var_eq_var_if_uniform_weight() { 272 | fn prop(a: Vec) -> TestResult { 273 | if a.len() < 1 { 274 | return TestResult::discard(); 275 | } 276 | let a = Array1::from(a); 277 | let weights = Array1::from_elem(a.len(), 1.0 / a.len() as f64); 278 | let weighted_var = a.weighted_var(&weights, 0.0).unwrap(); 279 | let var = a.var_axis(Axis(0), 0.0).into_scalar(); 280 | TestResult::from_bool(abs_diff_eq!(weighted_var, var, epsilon = 1e-10)) 281 | } 282 | quickcheck(prop as fn(Vec) -> TestResult); 283 | } 284 | 285 | #[test] 286 | fn 
weighted_var_algo_eq_simple_algo() { 287 | fn prop(mut a: Vec) -> TestResult { 288 | if a.len() < 24 { 289 | return TestResult::discard(); 290 | } 291 | let depth = a.len() / 12; 292 | a.truncate(depth * 3 * 4); 293 | let a = Array1::from(a) 294 | .into_shape_with_order((depth, 3, 4)) 295 | .unwrap(); 296 | let mut success = true; 297 | for axis in 0..3 { 298 | let axis = Axis(axis); 299 | 300 | let weights = Array::random(a.len_of(axis), Uniform::new(0.0, 1.0)); 301 | let mean = a 302 | .weighted_mean_axis(axis, &weights) 303 | .unwrap() 304 | .insert_axis(axis); 305 | let res_1_pass = a.weighted_var_axis(axis, &weights, 0.0).unwrap(); 306 | let res_2_pass = (&a - &mean) 307 | .mapv_into(|v| v.powi(2)) 308 | .weighted_mean_axis(axis, &weights) 309 | .unwrap(); 310 | success &= abs_diff_eq!(res_1_pass, res_2_pass, epsilon = 1e-10); 311 | } 312 | TestResult::from_bool(success) 313 | } 314 | quickcheck(prop as fn(Vec) -> TestResult); 315 | } 316 | 317 | #[test] 318 | fn test_central_moment_with_empty_array_of_floats() { 319 | let a: Array1 = array![]; 320 | for order in 0..=3 { 321 | assert_eq!(a.central_moment(order), Err(EmptyInput)); 322 | assert_eq!(a.central_moments(order), Err(EmptyInput)); 323 | } 324 | } 325 | 326 | #[test] 327 | fn test_zeroth_central_moment_is_one() { 328 | let n = 50; 329 | let bound: f64 = 200.; 330 | let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); 331 | assert_eq!(a.central_moment(0).unwrap(), 1.); 332 | } 333 | 334 | #[test] 335 | fn test_first_central_moment_is_zero() { 336 | let n = 50; 337 | let bound: f64 = 200.; 338 | let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); 339 | assert_eq!(a.central_moment(1).unwrap(), 0.); 340 | } 341 | 342 | #[test] 343 | fn test_central_moments() { 344 | let a: Array1 = array![ 345 | 0.07820559, 0.5026185, 0.80935324, 0.39384033, 0.9483038, 0.62516215, 0.90772261, 346 | 0.87329831, 0.60267392, 0.2960298, 0.02810356, 0.31911966, 0.86705506, 0.96884832, 347 | 
0.2222465, 0.42162446, 0.99909868, 0.47619762, 0.91696979, 0.9972741, 0.09891734, 348 | 0.76934818, 0.77566862, 0.7692585, 0.2235759, 0.44821286, 0.79732186, 0.04804275, 349 | 0.87863238, 0.1111003, 0.6653943, 0.44386445, 0.2133176, 0.39397086, 0.4374617, 0.95896624, 350 | 0.57850146, 0.29301706, 0.02329879, 0.2123203, 0.62005503, 0.996492, 0.5342986, 0.97822099, 351 | 0.5028445, 0.6693834, 0.14256682, 0.52724704, 0.73482372, 0.1809703, 352 | ]; 353 | // Computed using scipy.stats.moment 354 | let expected_moments = vec![ 355 | 1., 356 | 0., 357 | 0.09339920262960291, 358 | -0.0026849636727735186, 359 | 0.015403769257729755, 360 | -0.001204176487006564, 361 | 0.002976822584939186, 362 | ]; 363 | for (order, expected_moment) in expected_moments.iter().enumerate() { 364 | assert_abs_diff_eq!( 365 | a.central_moment(order as u16).unwrap(), 366 | expected_moment, 367 | epsilon = 1e-8 368 | ); 369 | } 370 | } 371 | 372 | #[test] 373 | fn test_bulk_central_moments() { 374 | // Test that the bulk method is coherent with the non-bulk method 375 | let n = 50; 376 | let bound: f64 = 200.; 377 | let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); 378 | let order = 10; 379 | let central_moments = a.central_moments(order).unwrap(); 380 | for i in 0..=order { 381 | assert_eq!(a.central_moment(i).unwrap(), central_moments[i as usize]); 382 | } 383 | } 384 | 385 | #[test] 386 | fn test_kurtosis_and_skewness_is_none_with_empty_array_of_floats() { 387 | let a: Array1 = array![]; 388 | assert_eq!(a.skewness(), Err(EmptyInput)); 389 | assert_eq!(a.kurtosis(), Err(EmptyInput)); 390 | } 391 | 392 | #[test] 393 | fn test_kurtosis_and_skewness() { 394 | let a: Array1 = array![ 395 | 0.33310096, 0.98757449, 0.9789796, 0.96738114, 0.43545674, 0.06746873, 0.23706562, 396 | 0.04241815, 0.38961714, 0.52421271, 0.93430327, 0.33911604, 0.05112372, 0.5013455, 397 | 0.05291507, 0.62511183, 0.20749633, 0.22132433, 0.14734804, 0.51960608, 0.00449208, 398 | 0.4093339, 0.2237519, 
0.28070469, 0.7887231, 0.92224523, 0.43454188, 0.18335111, 399 | 0.08646856, 0.87979847, 0.25483457, 0.99975627, 0.52712442, 0.41163279, 0.85162594, 400 | 0.52618733, 0.75815023, 0.30640695, 0.14205781, 0.59695813, 0.851331, 0.39524328, 401 | 0.73965373, 0.4007615, 0.02133069, 0.92899207, 0.79878191, 0.38947334, 0.22042183, 402 | 0.77768353, 403 | ]; 404 | // Computed using scipy.stats.kurtosis(a, fisher=False) 405 | let expected_kurtosis = 1.821933711687523; 406 | // Computed using scipy.stats.skew 407 | let expected_skewness = 0.2604785422878771; 408 | 409 | let kurtosis = a.kurtosis().unwrap(); 410 | let skewness = a.skewness().unwrap(); 411 | 412 | assert_abs_diff_eq!(kurtosis, expected_kurtosis, epsilon = 1e-12); 413 | assert_abs_diff_eq!(skewness, expected_skewness, epsilon = 1e-8); 414 | } 415 | --------------------------------------------------------------------------------