├── codecov.yml ├── .gitignore ├── src ├── function │ ├── mod.rs │ ├── evaluate.rs │ ├── harmonic.rs │ ├── logistic.rs │ ├── exponential.rs │ ├── factorial.rs │ └── kernel.rs ├── statistics │ ├── mod.rs │ ├── traits.rs │ ├── order_statistics.rs │ ├── iter_statistics.rs │ └── statistics.rs ├── consts.rs ├── stats_tests │ ├── mod.rs │ ├── anderson_darling.rs │ ├── chisquare.rs │ ├── ttest_onesample.rs │ ├── skewtest.rs │ ├── f_oneway.rs │ └── fisher.rs ├── euclid.rs ├── lib.rs ├── distribution │ ├── ziggurat.rs │ ├── dirac.rs │ ├── bernoulli.rs │ ├── erlang.rs │ ├── chi_squared.rs │ └── mod.rs ├── prec.rs └── generate.rs ├── .github ├── dependabot.yml └── workflows │ ├── coverage.yml │ └── test.yml ├── LICENSE.md ├── tests ├── gather_nist_data.sh └── nist_tests.rs ├── Cargo.toml ├── benches └── order_statistics.rs └── README.md /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | enabled: yes 6 | target: 90% 7 | threshold: 0.5% -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled files 2 | *.o 3 | *.so 4 | *.rlib 5 | *.dll 6 | 7 | # Executables 8 | *.exe 9 | 10 | # Test data for integration tests 11 | tests/*.dat 12 | 13 | # Generated by Cargo 14 | /target/ 15 | *.lock 16 | 17 | #editor specific 18 | /.vscode/ 19 | .idea/ 20 | *.iml 21 | 22 | # macOS 23 | .DS_Store 24 | -------------------------------------------------------------------------------- /src/function/mod.rs: -------------------------------------------------------------------------------- 1 | //! Provides a host of special statistical functions (e.g. the beta function or 2 | //! the error function) 3 | 4 | pub mod beta; 5 | pub mod erf; 6 | pub mod evaluate; 7 | pub mod exponential; 8 | pub mod factorial; 9 | pub mod gamma; 10 | pub mod harmonic; 11 | pub mod kernel; 12 | pub mod logistic; 13 | -------------------------------------------------------------------------------- /src/statistics/mod.rs: -------------------------------------------------------------------------------- 1 | //! Provides traits for statistical computation 2 | 3 | pub use self::order_statistics::*; 4 | pub use self::slice_statistics::*; 5 | pub use self::statistics::*; 6 | pub use self::traits::*; 7 | 8 | mod iter_statistics; 9 | mod order_statistics; 10 | // TODO: fix later 11 | mod slice_statistics; 12 | #[allow(clippy::module_inception)] 13 | mod statistics; 14 | mod traits; 15 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "cargo" 9 | directory: "/" 10 | open-pull-requests-limit: 10 11 | schedule: 12 | interval: "monthly" 13 | - package-ecosystem: "github-actions" 14 | directory: "/" 15 | schedule: 16 | interval: "monthly" 17 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Michael Ma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/consts.rs: -------------------------------------------------------------------------------- 1 | //! Defines mathematical expressions commonly used when computing distribution 2 | //! 
values as constants 3 | 4 | /// Constant value for `sqrt(2 * pi)` 5 | pub const SQRT_2PI: f64 = 2.5066282746310005024157652848110452530069867406099; 6 | 7 | /// Constant value for `ln(pi)` 8 | pub const LN_PI: f64 = 1.1447298858494001741434273513530587116472948129153; 9 | 10 | /// Constant value for `ln(sqrt(2 * pi))` 11 | pub const LN_SQRT_2PI: f64 = 0.91893853320467274178032973640561763986139747363778; 12 | 13 | /// Constant value for `ln(sqrt(2 * pi * e))` 14 | pub const LN_SQRT_2PIE: f64 = 1.4189385332046727417803297364056176398613974736378; 15 | 16 | /// Constant value for `ln(2 * sqrt(e / pi))` 17 | pub const LN_2_SQRT_E_OVER_PI: f64 = 0.6207822376352452223455184457816472122518527279025978; 18 | 19 | /// Constant value for `2 * sqrt(e / pi)` 20 | pub const TWO_SQRT_E_OVER_PI: f64 = 1.8603827342052657173362492472666631120594218414085755; 21 | 22 | /// Constant value for the Euler-Mascheroni constant `lim(n -> inf) { sum(k=1 -> n) 23 | /// { 1/k - ln(n) } }` 24 | pub const EULER_MASCHERONI: f64 = 25 | 0.5772156649015328606065120900824024310421593359399235988057672348849; 26 | -------------------------------------------------------------------------------- /src/stats_tests/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "std")] 2 | pub mod anderson_darling; 3 | #[cfg(feature = "std")] 4 | pub mod chisquare; 5 | #[cfg(feature = "std")] 6 | pub mod f_oneway; 7 | pub mod fisher; 8 | #[cfg(feature = "std")] 9 | pub mod ks_test; 10 | #[cfg(feature = "std")] 11 | pub mod mannwhitneyu; 12 | #[cfg(feature = "std")] 13 | pub mod skewtest; 14 | #[cfg(feature = "std")] 15 | pub mod ttest_onesample; 16 | 17 | /// Specifies an [alternative hypothesis](https://en.wikipedia.org/wiki/Alternative_hypothesis) 18 | #[derive(Debug, Copy, Clone)] 19 | pub enum Alternative { 20 | #[doc(alias = "two-tailed")] 21 | #[doc(alias = "two tailed")] 22 | TwoSided, 23 | #[doc(alias = "one-tailed")] 24 | #[doc(alias = "one tailed")] 25 | Less, 26 | #[doc(alias = "one-tailed")] 27 | #[doc(alias = "one tailed")] 28 | Greater, 29 | } 30 | 31 | /// Specifies how to deal with NaNs provided in input data, 32 | /// based on scipy's treatment 33 | #[derive(Debug, Copy, Clone)] 34 | pub enum NaNPolicy { 35 | /// allow for NaNs; if any exist, the function will return NaN 36 | Propogate, 37 | /// filter out the NaNs before calculations 38 | Emit, 39 | /// if NaNs are in the input data, return an Error 40 | Error, 41 | } 42 | 43 | pub use fisher::{fishers_exact, fishers_exact_with_odds_ratio}; 44 | -------------------------------------------------------------------------------- /tests/gather_nist_data.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # this script is to download and preprocess datafiles for the nist_tests.rs 3 | # integration test for statrs; it downloads data to the directory specified by the env 4 | # var STATRS_NIST_DATA_DIR 5 | 6 | process_file() { 7 | # Define input and output file names 8 | SOURCE=$1 9 | FILENAME=$2 10 | TARGET=${STATRS_NIST_DATA_DIR-tests}/${FILENAME} 11 | echo -e ${FILENAME} '\n\tDownloading...'
12 | curl -fsSL ${SOURCE}/$FILENAME > ${TARGET} 13 | 14 | # Extract line numbers for Certified Values and Data from the header 15 | INFO=$(grep "Certified Values:" $TARGET) 16 | CERTIFIED_VALUES_START=$(echo $INFO | awk '{print $4}') 17 | CERTIFIED_VALUES_END=$(echo $INFO | awk '{print $6}') 18 | 19 | INFO=$(grep "Data :" $TARGET) 20 | DATA_START=$(echo $INFO | awk '{print $4}') 21 | DATA_END=$(echo $INFO | awk '{print $6}') 22 | 23 | echo -e '\tFormatting...' 24 | # Extract and reformat sections 25 | sed -n -i \ 26 | -e "${CERTIFIED_VALUES_START},${CERTIFIED_VALUES_END}p" \ 27 | -e "${DATA_START},${DATA_END}p" \ 28 | $TARGET 29 | } 30 | 31 | URL='https://www.itl.nist.gov/div898/strd/univ/data' 32 | for file in Lottery.dat Lew.dat Mavro.dat Michelso.dat NumAcc1.dat NumAcc2.dat NumAcc3.dat 33 | do 34 | process_file $URL $file 35 | done 36 | 37 | -------------------------------------------------------------------------------- /.github/workflows/coverage.yml: -------------------------------------------------------------------------------- 1 | name: Coverage 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | jobs: 9 | coverage: 10 | strategy: # allows pinning additional nightly 11 | fail-fast: false # allows continue past failure of one toolchain if multiple 12 | matrix: 13 | toolchain: [ nightly ] 14 | 15 | # 360 mins is gh actions default 16 | timeout-minutes: ${{ matrix.toolchain == 'nightly' && 10 || 360 }} 17 | 18 | name: Coverage 19 | runs-on: ubuntu-latest 20 | env: 21 | RUSTFLAGS: -D warnings 22 | CARGO_TERM_COLOR: always 23 | steps: 24 | - uses: actions/checkout@v4 25 | - uses: dtolnay/rust-toolchain@master 26 | with: 27 | toolchain: ${{ matrix.toolchain }} 28 | components: llvm-tools-preview 29 | 30 | - uses: taiki-e/install-action@v2 31 | with: 32 | tool: nextest 33 | - uses: taiki-e/install-action@v2 34 | with: 35 | tool: cargo-llvm-cov 36 | 37 | - name: Collect coverage 38 | run: | 39 | cargo llvm-cov --no-report nextest 40 | cargo llvm-cov --no-report --doc 41 | cargo llvm-cov report --doctests --lcov --output-path lcov.info 42 | 43 | - name: Upload to codecov.io 44 | uses: codecov/codecov-action@v5 45 | with: 46 | files: lcov.info 47 | token: ${{secrets.CODECOV_TOKEN}} 48 | fail_ci_if_error: false 49 | -------------------------------------------------------------------------------- /src/euclid.rs: -------------------------------------------------------------------------------- 1 | //! Provides number theory utility functions 2 | 3 | /// Provides a trait for the canonical modulus operation since % is technically 4 | /// the remainder operation 5 | pub trait Modulus { 6 | /// Performs a canonical modulus operation between `self` and `divisor`. 
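/// For negative operands the two differ: in Rust, `-4 % 5 == -4`, while the
/// canonical modulus of `-4` and `5` is `1`.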
7 | /// 8 | /// # Examples 9 | /// 10 | /// ``` 11 | /// use statrs::euclid::Modulus; 12 | /// 13 | /// let x = 4i64.modulus(5); 14 | /// assert_eq!(x, 4); 15 | /// 16 | /// let y = (-4i64).modulus(5); 17 | /// assert_eq!(y, 1); 18 | /// ``` 19 | fn modulus(self, divisor: Self) -> Self; 20 | } 21 | 22 | impl Modulus for f64 { 23 | fn modulus(self, divisor: f64) -> f64 { 24 | ((self % divisor) + divisor) % divisor 25 | } 26 | } 27 | 28 | impl Modulus for f32 { 29 | fn modulus(self, divisor: f32) -> f32 { 30 | ((self % divisor) + divisor) % divisor 31 | } 32 | } 33 | 34 | impl Modulus for i64 { 35 | fn modulus(self, divisor: i64) -> i64 { 36 | ((self % divisor) + divisor) % divisor 37 | } 38 | } 39 | 40 | impl Modulus for i32 { 41 | fn modulus(self, divisor: i32) -> i32 { 42 | ((self % divisor) + divisor) % divisor 43 | } 44 | } 45 | 46 | impl Modulus for u64 { 47 | fn modulus(self, divisor: u64) -> u64 { 48 | ((self % divisor) + divisor) % divisor 49 | } 50 | } 51 | 52 | impl Modulus for u32 { 53 | fn modulus(self, divisor: u32) -> u32 { 54 | ((self % divisor) + divisor) % divisor 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "statrs" 3 | version = "0.18.0" 4 | authors = ["Michael Ma"] 5 | description = "Statistical computing library for Rust" 6 | license = "MIT" 7 | keywords = ["probability", "statistics", "stats", "distribution", "math"] 8 | categories = ["science"] 9 | homepage = "https://github.com/statrs-dev/statrs" 10 | repository = "https://github.com/statrs-dev/statrs" 11 | edition = "2024" 12 | 13 | include = ["CHANGELOG.md", "LICENSE.md", "src/", "tests/"] 14 | 15 | # When changing MSRV: Also update the README 16 | rust-version = "1.87.0" 17 | 18 | [lib] 19 | name = "statrs" 20 | path = "src/lib.rs" 21 | 22 | [[bench]] 23 | name = "order_statistics" 24 | harness = false 25 | required-features = ["rand", "std"] 26 | 27 | [features] 28 | default = ["std", "nalgebra", "rand"] 29 | std = ["nalgebra?/std", "rand?/std"] 30 | # at the moment, all nalgebra features need std 31 | nalgebra = ["dep:nalgebra", "std"] 32 | rand = ["dep:rand", "nalgebra?/rand"] 33 | 34 | [dependencies] 35 | approx = "0.5.0" 36 | num-traits = "0.2.14" 37 | 38 | [dependencies.rand] 39 | version = "0.9.0" 40 | optional = true 41 | default-features = false 42 | 43 | [dependencies.nalgebra] 44 | version = "0.34" 45 | optional = true 46 | default-features = false 47 | 48 | [dev-dependencies] 49 | criterion = "0.5" 50 | anyhow = "1.0" 51 | 52 | [dev-dependencies.nalgebra] 53 | version = "0.34" 54 | default-features = false 55 | features = ["macros"] 56 | 57 | [lints.rust.unexpected_cfgs] 58 | level = "warn" 59 | # Set by cargo-llvm-cov when running on nightly 60 | check-cfg = ['cfg(coverage_nightly)'] 61 | 62 | [package.metadata.docs.rs] 63 | all-features = true 64 | rustdoc-args = ["--cfg", "docsrs"] 65 | -------------------------------------------------------------------------------- /src/function/evaluate.rs: -------------------------------------------------------------------------------- 1 | //! Provides functions that don't have a numerical solution and must 2 | //! be solved computationally (e.g.
evaluation of a polynomial) 3 | 4 | /// Evaluates a polynomial at `z` where `coeff` holds the coefficients in 5 | /// order of increasing power, i.e. the `k`th element of `coeff` is the 6 | /// coefficient 7 | /// of `z^k`. E.g. `[3, -1, 2]` equates to 8 | /// `2z^2 - z + 3` 9 | /// 10 | /// # Remarks 11 | /// 12 | /// Returns 0 for a 0 length coefficient slice 13 | pub fn polynomial(z: f64, coeff: &[f64]) -> f64 { 14 | let n = coeff.len(); 15 | if n == 0 { 16 | return 0.0; 17 | } 18 | // Horner's method: start from the highest-order coefficient and fold downward 19 | let mut sum = *coeff.last().unwrap(); 20 | for c in coeff[0..n - 1].iter().rev() { 21 | sum = *c + z * sum; 22 | } 23 | sum 24 | } 25 | 26 | #[rustfmt::skip] 27 | #[cfg(test)] 28 | mod tests { 29 | use core::f64; 30 | 31 | // these tests probably could be more robust 32 | #[test] 33 | fn test_polynomial() { 34 | let empty: [f64; 0] = []; 35 | assert_eq!(super::polynomial(2.0, &empty), 0.0); 36 | 37 | let zero = [0.0]; 38 | assert_eq!(super::polynomial(2.0, &zero), 0.0); 39 | 40 | let mut coeff = [1.0, 0.0, 5.0]; 41 | assert_eq!(super::polynomial(2.0, &coeff), 21.0); 42 | 43 | coeff = [-5.0, -2.0, 3.0]; 44 | assert_eq!(super::polynomial(2.0, &coeff), 3.0); 45 | assert_eq!(super::polynomial(-2.0, &coeff), 11.0); 46 | 47 | let large_coeff = [-1.35e3, 2.5e2, 8.0, -4.0, 1e2, 3.0]; 48 | assert_eq!(super::polynomial(5.0, &large_coeff), 71475.0); 49 | assert_eq!(super::polynomial(-5.0, &large_coeff), 51225.0); 50 | 51 | coeff = [f64::INFINITY, -2.0, 3.0]; 52 | assert_eq!(super::polynomial(2.0, &coeff), f64::INFINITY); 53 | assert_eq!(super::polynomial(-2.0, &coeff), f64::INFINITY); 54 | 55 | coeff = [f64::NEG_INFINITY, -2.0, 3.0]; 56 | assert_eq!(super::polynomial(2.0, &coeff), f64::NEG_INFINITY); 57 | assert_eq!(super::polynomial(-2.0, &coeff), f64::NEG_INFINITY); 58 | 59 | coeff = [f64::NAN, -2.0, 3.0]; 60 | assert!(super::polynomial(2.0, &coeff).is_nan()); 61 | assert!(super::polynomial(-2.0, &coeff).is_nan()); 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | env: 10 | CARGO_INCREMENTAL: 0 11 | RUSTFLAGS: "-Dwarnings" 12 | 13 | jobs: 14 | clippy: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Install Rust stable with clippy 19 | uses: dtolnay/rust-toolchain@stable 20 | with: 21 | components: clippy 22 | 23 | - name: Run cargo clippy (default features) 24 | run: cargo clippy --all-targets 25 | 26 | fmt: 27 | runs-on: ubuntu-latest 28 | steps: 29 | - uses: actions/checkout@v4 30 | - name: Install Rust stable with rustfmt 31 | uses: dtolnay/rust-toolchain@stable 32 | with: 33 | components: rustfmt 34 | 35 | - name: Run rustfmt --check 36 | run: cargo fmt -- --check 37 | 38 | msrv: 39 | runs-on: ubuntu-latest 40 | steps: 41 | - uses: actions/checkout@v4 42 | - name: Install cargo-hack 43 | uses: taiki-e/install-action@cargo-hack 44 | - uses: Swatinem/rust-cache@v2 45 | - name: Use predefined lockfile 46 | run: mv -v Cargo.lock.MSRV Cargo.lock 47 | - name: Build (lib only) 48 | run: cargo hack check --rust-version --locked 49 | 50 | test: 51 | needs: [clippy, fmt, msrv] 52 | runs-on: ${{ matrix.os }} 53 | strategy: 54 | fail-fast: false 55 | matrix: 56 | os: [ubuntu-latest, macos-latest, windows-latest] 57 | 58 | steps: 59 | - uses: actions/checkout@v4 60 | - name:
Install Rust stable 61 | uses: dtolnay/rust-toolchain@stable 62 | 63 | - name: Test no_std 64 | run: cargo test --no-default-features -F rand 65 | 66 | - name: Test default features 67 | run: cargo test 68 | 69 | features: 70 | needs: [clippy, fmt] 71 | runs-on: ubuntu-latest 72 | steps: 73 | - uses: actions/checkout@v4 74 | - name: Install Rust stable 75 | uses: dtolnay/rust-toolchain@stable 76 | - name: Install cargo-hack 77 | uses: taiki-e/install-action@cargo-hack 78 | - uses: Swatinem/rust-cache@v2 79 | - name: Check all possible feature sets 80 | run: cargo hack check --feature-powerset --no-dev-deps 81 | -------------------------------------------------------------------------------- /src/function/harmonic.rs: -------------------------------------------------------------------------------- 1 | //! Provides functions for calculating 2 | //! [harmonic](https://en.wikipedia.org/wiki/Harmonic_number) 3 | //! numbers 4 | 5 | use crate::consts; 6 | use crate::function::gamma; 7 | 8 | /// Computes the `t`-th harmonic number 9 | /// 10 | /// # Remarks 11 | /// 12 | /// Returns `1` as a special case when `t == 0` 13 | pub fn harmonic(t: u64) -> f64 { 14 | match t { 15 | 0 => 1.0, 16 | _ => consts::EULER_MASCHERONI + gamma::digamma(t as f64 + 1.0), 17 | } 18 | } 19 | 20 | /// Computes the generalized harmonic number of order `n` of `m` 21 | /// e.g. `(1 + 1/2^m + 1/3^m + ... + 1/n^m)` 22 | /// 23 | /// # Remarks 24 | /// 25 | /// Returns `1` as a special case when `n == 0` 26 | pub fn gen_harmonic(n: u64, m: f64) -> f64 { 27 | match n { 28 | 0 => 1.0, 29 | _ => (0..n).fold(0.0, |acc, x| acc + (x as f64 + 1.0).powf(-m)), 30 | } 31 | } 32 | 33 | #[rustfmt::skip] 34 | #[cfg(test)] 35 | mod tests { 36 | use core::f64; 37 | use crate::prec; 38 | use super::*; 39 | 40 | #[test] 41 | fn test_harmonic() { 42 | prec::assert_ulps_eq!(harmonic(0), 1.0, max_ulps = 0); 43 | prec::assert_abs_diff_eq!(harmonic(1), 1.0, epsilon = 1e-14); 44 | prec::assert_abs_diff_eq!(harmonic(2), 1.5, epsilon = 1e-14); 45 | prec::assert_abs_diff_eq!(harmonic(4), 2.083333333333333333333, epsilon = 1e-14); 46 | prec::assert_abs_diff_eq!(harmonic(8), 2.717857142857142857143, epsilon = 1e-14); 47 | prec::assert_abs_diff_eq!(harmonic(16), 3.380728993228993228993, epsilon = 1e-14); 48 | } 49 | 50 | #[test] 51 | fn test_gen_harmonic() { 52 | assert_eq!(gen_harmonic(0, 0.0), 1.0); 53 | assert_eq!(gen_harmonic(0, f64::INFINITY), 1.0); 54 | assert_eq!(gen_harmonic(0, f64::NEG_INFINITY), 1.0); 55 | assert_eq!(gen_harmonic(1, 0.0), 1.0); 56 | assert_eq!(gen_harmonic(1, f64::INFINITY), 1.0); 57 | assert_eq!(gen_harmonic(1, f64::NEG_INFINITY), 1.0); 58 | assert_eq!(gen_harmonic(2, 1.0), 1.5); 59 | assert_eq!(gen_harmonic(2, 3.0), 1.125); 60 | assert_eq!(gen_harmonic(2, f64::INFINITY), 1.0); 61 | assert_eq!(gen_harmonic(2, f64::NEG_INFINITY), f64::INFINITY); 62 | prec::assert_abs_diff_eq!(gen_harmonic(4, 1.0), 2.083333333333333333333, epsilon = 1e-14); 63 | assert_eq!(gen_harmonic(4, 3.0), 1.177662037037037037037); 64 | assert_eq!(gen_harmonic(4, f64::INFINITY), 1.0); 65 | assert_eq!(gen_harmonic(4, f64::NEG_INFINITY), f64::INFINITY); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! This crate aims to be a functional port of the Math.NET Numerics 2 | //! Distribution package and in doing so providing the Rust numerical computing 3 | //! 
community with a robust, well-tested statistical distribution package. This 4 | //! crate also ports over some of the special statistical functions from 5 | //! Math.NET in so far as they are used in the computation of distribution 6 | //! values. This crate depends on the `rand` crate to provide RNG. 7 | //! 8 | //! # Sampling 9 | //! The common use case is to set up the distributions and sample from them which depends on the `Rand` crate for random number generation. 10 | #![cfg_attr(feature = "rand", doc = "```")] 11 | #![cfg_attr(not(feature = "rand"), doc = "```ignore")] 12 | //! use statrs::distribution::Exp; 13 | //! use rand::distr::Distribution; 14 | //! let mut r = rand::rng(); 15 | //! let n = Exp::new(0.5).unwrap(); 16 | //! print!("{}", n.sample(&mut r)); 17 | //! ``` 18 | //! 19 | //! # Introspecting distributions 20 | //! Statrs also comes with a number of useful utility traits for more detailed introspection of distributions. 21 | //! ``` 22 | //! use statrs::distribution::{Exp, Continuous, ContinuousCDF}; // `cdf` and `pdf` 23 | //! use statrs::statistics::Distribution; // statistical moments and entropy 24 | //! 25 | //! let n = Exp::new(1.0).unwrap(); 26 | //! assert_eq!(n.mean(), Some(1.0)); 27 | //! assert_eq!(n.variance(), Some(1.0)); 28 | //! assert_eq!(n.entropy(), Some(1.0)); 29 | //! assert_eq!(n.skewness(), Some(2.0)); 30 | //! assert_eq!(n.cdf(1.0), 0.6321205588285576784045); 31 | //! assert_eq!(n.pdf(1.0), 0.3678794411714423215955); 32 | //! ``` 33 | //! 34 | //! # Utility functions 35 | //! as well as utility functions including `erf`, `gamma`, `ln_gamma`, `beta`, etc. 36 | //! 37 | //! ``` 38 | //! use statrs::distribution::FisherSnedecor; 39 | //! use statrs::statistics::Distribution; 40 | //! 41 | //! let n = FisherSnedecor::new(1.0, 1.0).unwrap(); 42 | //! assert!(n.variance().is_none()); 43 | //! ``` 44 | //! ## Distributions implemented 45 | //! Statrs comes with a number of commonly used distributions including Normal, Gamma, Student's T, Exponential, Weibull, etc. view all implemented in `distributions` module. 46 | 47 | #![crate_type = "lib"] 48 | #![crate_name = "statrs"] 49 | #![allow(clippy::excessive_precision)] 50 | #![allow(clippy::many_single_char_names)] 51 | #![forbid(unsafe_code)] 52 | #![cfg_attr(coverage_nightly, feature(coverage_attribute))] 53 | #![cfg_attr(docsrs, feature(doc_cfg))] 54 | #![cfg_attr(not(feature = "std"), no_std)] 55 | 56 | pub mod consts; 57 | pub mod distribution; 58 | pub mod euclid; 59 | pub mod function; 60 | pub mod generate; 61 | pub mod prec; 62 | pub mod statistics; 63 | pub mod stats_tests; 64 | -------------------------------------------------------------------------------- /src/function/logistic.rs: -------------------------------------------------------------------------------- 1 | //! Provides the [logistic](http://en.wikipedia.org/wiki/Logistic_function) and 2 | //! related functions 3 | 4 | /// Computes the logistic function 5 | pub fn logistic(p: f64) -> f64 { 6 | 1.0 / ((-p).exp() + 1.0) 7 | } 8 | 9 | /// Computes the logit function 10 | /// 11 | /// # Panics 12 | /// 13 | /// If `p < 0.0` or `p > 1.0` 14 | pub fn logit(p: f64) -> f64 { 15 | checked_logit(p).unwrap() 16 | } 17 | 18 | /// Computes the logit function, returning `None` if `p < 0.0` or `p > 1.0`. 
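///
/// A brief usage sketch, consistent with the implementation below: `logit(0.5) = ln(1) = 0`,
/// and inputs outside `[0, 1]` yield `None`.
///
/// ```
/// use statrs::function::logistic::checked_logit;
///
/// assert_eq!(checked_logit(0.5), Some(0.0));
/// assert!(checked_logit(2.0).is_none());
/// ```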
19 | pub fn checked_logit(p: f64) -> Option { 20 | if (0.0..=1.0).contains(&p) { 21 | Some((p / (1.0 - p)).ln()) 22 | } else { 23 | None 24 | } 25 | } 26 | 27 | #[rustfmt::skip] 28 | #[cfg(test)] 29 | mod tests { 30 | use core::f64; 31 | use crate::prec; 32 | use super::*; 33 | 34 | #[test] 35 | fn test_logistic() { 36 | assert_eq!(logistic(f64::NEG_INFINITY), 0.0); 37 | assert_eq!(logistic(-11.512915464920228103874353849992239636376994324587), 0.00001); 38 | prec::assert_abs_diff_eq!(logistic(-6.9067547786485535272274487616830597875179908939086), 0.001, epsilon = 1e-18); 39 | prec::assert_abs_diff_eq!(logistic(-2.1972245773362193134015514347727700402304323440139), 0.1, epsilon = 1e-16); 40 | assert_eq!(logistic(0.0), 0.5); 41 | prec::assert_abs_diff_eq!(logistic(2.1972245773362195801634726294284168954491240598975), 0.9, epsilon = 1e-15); 42 | prec::assert_abs_diff_eq!(logistic(6.9067547786485526081487245019905638981131702804661), 0.999, epsilon = 1e-15); 43 | assert_eq!(logistic(11.512915464924779098232747799811946290419057060965), 0.99999); 44 | assert_eq!(logistic(f64::INFINITY), 1.0); 45 | } 46 | 47 | #[test] 48 | fn test_logit() { 49 | assert_eq!(logit(0.0), f64::NEG_INFINITY); 50 | assert_eq!(logit(0.00001), -11.512915464920228103874353849992239636376994324587); 51 | assert_eq!(logit(0.001), -6.9067547786485535272274487616830597875179908939086); 52 | assert_eq!(logit(0.1), -2.1972245773362193134015514347727700402304323440139); 53 | assert_eq!(logit(0.5), 0.0); 54 | assert_eq!(logit(0.9), 2.1972245773362195801634726294284168954491240598975); 55 | assert_eq!(logit(0.999), 6.9067547786485526081487245019905638981131702804661); 56 | assert_eq!(logit(0.99999), 11.512915464924779098232747799811946290419057060965); 57 | assert_eq!(logit(1.0), f64::INFINITY); 58 | } 59 | 60 | #[test] 61 | #[should_panic] 62 | fn test_logit_p_lt_0() { 63 | logit(-1.0); 64 | } 65 | 66 | #[test] 67 | #[should_panic] 68 | fn test_logit_p_gt_1() { 69 | logit(2.0); 70 | } 71 | 72 | #[test] 73 | fn test_checked_logit_p_lt_0() { 74 | assert!(checked_logit(-1.0).is_none()); 75 | } 76 | 77 | #[test] 78 | fn test_checked_logit_p_gt_1() { 79 | assert!(checked_logit(2.0).is_none()); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/distribution/ziggurat.rs: -------------------------------------------------------------------------------- 1 | use super::ziggurat_tables; 2 | use rand::Rng; 3 | use rand::distr::Open01; 4 | 5 | pub fn sample_std_normal(rng: &mut R) -> f64 { 6 | #[inline] 7 | fn pdf(x: f64) -> f64 { 8 | (-x * x / 2.0).exp() 9 | } 10 | 11 | #[inline] 12 | fn zero_case(rng: &mut R, u: f64) -> f64 { 13 | let mut x = 1.0f64; 14 | let mut y = 0.0f64; 15 | while -2.0 * y < x * x { 16 | let x_: f64 = rng.sample(Open01); 17 | let y_: f64 = rng.sample(Open01); 18 | 19 | x = x_.ln() / ziggurat_tables::ZIG_NORM_R; 20 | y = y_.ln(); 21 | } 22 | if u < 0.0 { 23 | x - ziggurat_tables::ZIG_NORM_R 24 | } else { 25 | ziggurat_tables::ZIG_NORM_R - x 26 | } 27 | } 28 | 29 | ziggurat( 30 | rng, 31 | true, 32 | &ziggurat_tables::ZIG_NORM_X, 33 | &ziggurat_tables::ZIG_NORM_F, 34 | pdf, 35 | zero_case, 36 | ) 37 | } 38 | 39 | pub fn sample_exp_1(rng: &mut R) -> f64 { 40 | #[inline] 41 | fn pdf(x: f64) -> f64 { 42 | (-x).exp() 43 | } 44 | 45 | #[inline] 46 | fn zero_case(rng: &mut R, _u: f64) -> f64 { 47 | ziggurat_tables::ZIG_EXP_R - rng.random::().ln() 48 | } 49 | 50 | ziggurat( 51 | rng, 52 | false, 53 | &ziggurat_tables::ZIG_EXP_X, 54 | &ziggurat_tables::ZIG_EXP_F, 55 | pdf, 
56 | zero_case, 57 | ) 58 | } 59 | 60 | // Ziggurat method for sampling a random number based on the ZIGNOR 61 | // variant from Doornik 2005. Code borrowed from 62 | // https://github.com/rust-lang-nursery/rand/blob/master/src/distributions/mod. 63 | // rs#L223 64 | #[inline(always)] 65 | fn ziggurat( 66 | rng: &mut R, 67 | symmetric: bool, 68 | x_tab: ziggurat_tables::ZigTable, 69 | f_tab: ziggurat_tables::ZigTable, 70 | mut pdf: P, 71 | mut zero_case: Z, 72 | ) -> f64 73 | where 74 | P: FnMut(f64) -> f64, 75 | Z: FnMut(&mut R, f64) -> f64, 76 | { 77 | const SCALE: f64 = (1u64 << 53) as f64; 78 | loop { 79 | let bits: u64 = rng.random(); 80 | let i = (bits & 0xff) as usize; 81 | let f = (bits >> 11) as f64 / SCALE; 82 | 83 | // u is either U(-1, 1) or U(0, 1) depending on if this is a 84 | // symmetric distribution or not. 85 | let u = if symmetric { 2.0 * f - 1.0 } else { f }; 86 | let x = u * x_tab[i]; 87 | 88 | let test_x = if symmetric { x.abs() } else { x }; 89 | 90 | // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < 91 | // x_tab[i+1]/x_tab[i]) 92 | if test_x < x_tab[i + 1] { 93 | return x; 94 | } 95 | if i == 0 { 96 | return zero_case(rng, u); 97 | } 98 | // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1 99 | if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.random::() < pdf(x) { 100 | return x; 101 | } 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /benches/order_statistics.rs: -------------------------------------------------------------------------------- 1 | extern crate statrs; 2 | use criterion::{BatchSize, Criterion, black_box, criterion_group, criterion_main}; 3 | use rand::prelude::*; 4 | use statrs::statistics::*; 5 | 6 | fn bench_order_statistic(c: &mut Criterion) { 7 | let mut rng = rand::rng(); 8 | let to_random_owned = |data: &[f64]| -> Data> { 9 | let mut rng = rand::rng(); 10 | let mut owned = data.to_vec(); 11 | owned.shuffle(&mut rng); 12 | Data::new(owned) 13 | }; 14 | let k = black_box(rng.random_range(..=usize::MAX)); 15 | let tau = black_box(rng.random_range(0.0..1.0)); 16 | let mut group = c.benchmark_group("order statistic"); 17 | let data: Vec<_> = (0..100).map(|x| x as f64).collect(); 18 | group.bench_function("order_statistic", |b| { 19 | b.iter_batched( 20 | || to_random_owned(&data), 21 | |mut data| data.order_statistic(k), 22 | BatchSize::SmallInput, 23 | ) 24 | }); 25 | group.bench_function("median", |b| { 26 | b.iter_batched( 27 | || to_random_owned(&data), 28 | |data| data.median(), 29 | BatchSize::SmallInput, 30 | ) 31 | }); 32 | group.bench_function("quantile", |b| { 33 | b.iter_batched( 34 | || to_random_owned(&data), 35 | |mut data| data.quantile(tau), 36 | BatchSize::SmallInput, 37 | ) 38 | }); 39 | group.bench_function("percentile", |b| { 40 | b.iter_batched( 41 | || to_random_owned(&data), 42 | |mut data| data.percentile(k), 43 | BatchSize::SmallInput, 44 | ) 45 | }); 46 | group.bench_function("lower_quartile", |b| { 47 | b.iter_batched( 48 | || to_random_owned(&data), 49 | |mut data| data.lower_quartile(), 50 | BatchSize::SmallInput, 51 | ) 52 | }); 53 | group.bench_function("upper_quartile", |b| { 54 | b.iter_batched( 55 | || to_random_owned(&data), 56 | |mut data| data.upper_quartile(), 57 | BatchSize::SmallInput, 58 | ) 59 | }); 60 | group.bench_function("interquartile_range", |b| { 61 | b.iter_batched( 62 | || to_random_owned(&data), 63 | |mut data| data.interquartile_range(), 64 | BatchSize::SmallInput, 65 | ) 66 | }); 67 | group.bench_function("ranks: 
RankTieBreaker::First", |b| { 68 | b.iter_batched( 69 | || to_random_owned(&data), 70 | |mut data| data.ranks(RankTieBreaker::First), 71 | BatchSize::SmallInput, 72 | ) 73 | }); 74 | group.bench_function("ranks: RankTieBreaker::Average", |b| { 75 | b.iter_batched( 76 | || to_random_owned(&data), 77 | |mut data| data.ranks(RankTieBreaker::Average), 78 | BatchSize::SmallInput, 79 | ) 80 | }); 81 | group.bench_function("ranks: RankTieBreaker::Min", |b| { 82 | b.iter_batched( 83 | || to_random_owned(&data), 84 | |mut data| data.ranks(RankTieBreaker::Min), 85 | BatchSize::SmallInput, 86 | ) 87 | }); 88 | group.finish(); 89 | } 90 | 91 | criterion_group!(benches, bench_order_statistic); 92 | criterion_main!(benches); 93 | -------------------------------------------------------------------------------- /src/function/exponential.rs: -------------------------------------------------------------------------------- 1 | //! Provides functions related to exponential calculations 2 | 3 | use crate::consts; 4 | 5 | /// Computes the generalized Exponential Integral function 6 | /// where `x` is the argument and `n` is the integer power of the 7 | /// denominator term. 8 | /// 9 | /// Returns `None` if `x < 0.0` or the computation could not 10 | /// converge after 100 iterations 11 | /// 12 | /// # Remarks 13 | /// 14 | /// This implementation follows the derivation in 15 | /// 16 | /// _"Handbook of Mathematical Functions, Applied Mathematics Series, Volume 17 | /// 55"_ - Abramowitz, M., and Stegun, I.A., 1964 18 | /// 19 | /// AND 20 | /// 21 | /// _"Advanced mathematical methods for scientists and engineers"_ - Bender, 22 | /// Carl M.; Steven A. Orszag (1978), page 253 23 | /// 24 | /// The continued fraction approach is used for `x > 1.0` while the Taylor 25 | /// series expansion is used for `0.0 < x <= 1`.
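///
/// A usage sketch based on values exercised in this module's tests (an approximate
/// comparison is shown rather than exact float equality):
///
/// ```
/// use statrs::function::exponential::integral;
///
/// // E_1(1.0), the exponential integral with n = 1
/// let e1 = integral(1.0, 1).unwrap();
/// assert!((e1 - 0.219383934395520286).abs() < 1e-15);
/// ```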
26 | // TODO: Add examples 27 | pub fn integral(x: f64, n: u64) -> Option { 28 | let eps = 0.00000000000000001; 29 | let max_iter = 100; 30 | let nf64 = n as f64; 31 | let near_f64min = 1e-100; // needs very small value that is not quite as small as f64 min 32 | 33 | // special cases 34 | if n == 0 { 35 | return Some((-x).exp() / x); 36 | } 37 | if x == 0.0 { 38 | return Some(1.0 / (nf64 - 1.0)); 39 | } 40 | 41 | if x > 1.0 { 42 | let mut b = x + nf64; 43 | let mut c = 1.0 / near_f64min; 44 | let mut d = 1.0 / b; 45 | let mut h = d; 46 | for i in 1..max_iter + 1 { 47 | let a = -(i as f64) * (nf64 - 1.0 + i as f64); 48 | b += 2.0; 49 | d = 1.0 / (a * d + b); 50 | c = b + a / c; 51 | let del = c * d; 52 | h *= del; 53 | if (del - 1.0).abs() < eps { 54 | return Some(h * (-x).exp()); 55 | } 56 | } 57 | None 58 | } else { 59 | let mut factorial = 1.0; 60 | let mut result = if n - 1 != 0 { 61 | 1.0 / (nf64 - 1.0) 62 | } else { 63 | -x.ln() - consts::EULER_MASCHERONI 64 | }; 65 | for i in 1..max_iter + 1 { 66 | factorial *= -x / i as f64; 67 | let del = if i != n - 1 { 68 | -factorial / (i as f64 - nf64 + 1.0) 69 | } else { 70 | let mut psi = -consts::EULER_MASCHERONI; 71 | for ii in 1..n { 72 | psi += 1.0 / ii as f64; 73 | } 74 | factorial * (-x.ln() + psi) 75 | }; 76 | result += del; 77 | if del.abs() < result.abs() * eps { 78 | return Some(result); 79 | } 80 | } 81 | None 82 | } 83 | } 84 | 85 | #[rustfmt::skip] 86 | #[cfg(test)] 87 | mod tests { 88 | use crate::prec; 89 | use super::*; 90 | 91 | #[test] 92 | fn test_integral() { 93 | assert_eq!(integral(0.001, 1).unwrap(), 6.33153936413614904); 94 | prec::assert_abs_diff_eq!(integral(0.1, 1).unwrap(), 1.82292395841939059, epsilon = 1e-15); 95 | assert_eq!(integral(1.0, 1).unwrap(), 0.219383934395520286); 96 | prec::assert_abs_diff_eq!(integral(2.0, 1).unwrap(), 0.0489005107080611248, epsilon = 1e-15); 97 | prec::assert_abs_diff_eq!(integral(2.5, 1).unwrap(), 0.0249149178702697399, epsilon = 1e-15); 98 | prec::assert_abs_diff_eq!(integral(10.0, 1).unwrap(), 4.15696892968532464e-06, epsilon = 1e-20); 99 | assert_eq!(integral(0.001, 2).unwrap(), 0.992668960469238915); 100 | prec::assert_abs_diff_eq!(integral(0.1, 2).unwrap(), 0.722545022194020392, epsilon = 1e-15); 101 | prec::assert_abs_diff_eq!(integral(1.0, 2).unwrap(), 0.148495506775922048, epsilon = 1e-16); 102 | prec::assert_abs_diff_eq!(integral(2.0, 2).unwrap(), 0.0375342618204904527, epsilon = 1e-16); 103 | prec::assert_abs_diff_eq!(integral(10.0, 2).unwrap(), 3.830240465631608e-06, epsilon = 1e-20); 104 | assert_eq!(integral(0.001, 0).unwrap(), 999.000499833375); 105 | assert_eq!(integral(0.1, 0).unwrap(), 9.048374180359595); 106 | prec::assert_abs_diff_eq!(integral(1.0, 0).unwrap(), 0.3678794411714423, epsilon = 1e-16); 107 | assert_eq!(integral(2.0, 0).unwrap(), 0.06766764161830635); 108 | assert_eq!(integral(10.0, 0).unwrap(), 4.539992976248485e-06); 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /tests/nist_tests.rs: -------------------------------------------------------------------------------- 1 | //! This test relies on data that is reusable but not distributable by statrs as 2 | //! such, the data will need to be downloaded from the relevant NIST StRD dataset 3 | //! the parsing for testing assumes data to be of form, 4 | //! ```text 5 | //! sample mean : 6 | //! sample std_dev : 7 | //! sample correlation: 8 | //! [zero or more blank lines] 9 | //! data0 10 | //! data1 11 | //! data2 12 | //! ... 13 | //! ``` 14 | //! 
This test can be run on it's own from the shell from this folder as 15 | //! ```sh 16 | //! ./gather_nist_data.sh && cargo test -- --ignored nist_ 17 | //! ``` 18 | use anyhow::Result; 19 | use approx::assert_relative_eq; 20 | use statrs::statistics::Statistics; 21 | 22 | use std::io::{BufRead, BufReader}; 23 | use std::path::PathBuf; 24 | use std::{env, fs}; 25 | 26 | struct TestCase { 27 | certified: CertifiedValues, 28 | values: Vec, 29 | } 30 | 31 | impl core::fmt::Debug for TestCase { 32 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 33 | write!(f, "TestCase({:?}, [...]", self.certified) 34 | } 35 | } 36 | 37 | #[derive(Debug)] 38 | struct CertifiedValues { 39 | mean: f64, 40 | std_dev: f64, 41 | corr: f64, 42 | } 43 | 44 | impl core::fmt::Display for CertifiedValues { 45 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 46 | write!( 47 | f, 48 | "μ={:.3e}, σ={:.3e}, r={:.3e}", 49 | self.mean, self.std_dev, self.corr 50 | ) 51 | } 52 | } 53 | 54 | const NIST_DATA_DIR_ENV: &str = "STATRS_NIST_DATA_DIR"; 55 | const FILENAMES: [&str; 7] = [ 56 | "Lottery.dat", 57 | "Lew.dat", 58 | "Mavro.dat", 59 | "Michelso.dat", 60 | "NumAcc1.dat", 61 | "NumAcc2.dat", 62 | "NumAcc3.dat", 63 | ]; 64 | 65 | fn get_path(fname: &str, prefix: Option<&str>) -> PathBuf { 66 | if let Some(prefix) = prefix { 67 | [prefix, fname].iter().collect() 68 | } else { 69 | ["tests", fname].iter().collect() 70 | } 71 | } 72 | 73 | #[test] 74 | #[ignore = "NIST tests should not run from typical `cargo test` calls"] 75 | fn nist_strd_univariate_mean() { 76 | for fname in FILENAMES { 77 | let filepath = get_path(fname, env::var(NIST_DATA_DIR_ENV).ok().as_deref()); 78 | let case = parse_file(filepath) 79 | .unwrap_or_else(|e| panic!("failed parsing file {fname} with `{e:?}`")); 80 | assert_relative_eq!(case.values.mean(), case.certified.mean, epsilon = 1e-12); 81 | } 82 | } 83 | 84 | #[test] 85 | #[ignore] 86 | fn nist_strd_univariate_std_dev() { 87 | for fname in FILENAMES { 88 | let filepath = get_path(fname, env::var(NIST_DATA_DIR_ENV).ok().as_deref()); 89 | let case = parse_file(filepath) 90 | .unwrap_or_else(|e| panic!("failed parsing file {fname} with `{e:?}`")); 91 | assert_relative_eq!( 92 | case.values.std_dev(), 93 | case.certified.std_dev, 94 | epsilon = 1e-10 95 | ); 96 | } 97 | } 98 | 99 | fn parse_certified_value(line: String) -> Result { 100 | line.chars() 101 | .skip_while(|&c| c != ':') 102 | .skip(1) // skip through ':' delimiter 103 | .skip_while(|&c| c.is_whitespace()) // effectively `String` trim 104 | .take_while(|&c| matches!(c, '0'..='9' | '-' | '.')) 105 | .collect::() 106 | .parse::() 107 | .map_err(|e| e.into()) 108 | } 109 | 110 | fn parse_file(path: impl AsRef) -> anyhow::Result { 111 | let f = fs::File::open(path)?; 112 | let reader = BufReader::new(f); 113 | let mut lines = reader.lines(); 114 | 115 | let mean = parse_certified_value(lines.next().expect("file should not be exhausted")?)?; 116 | let std_dev = parse_certified_value(lines.next().expect("file should not be exhausted")?)?; 117 | let corr = parse_certified_value(lines.next().expect("file should not be exhausted")?)?; 118 | 119 | Ok(TestCase { 120 | certified: CertifiedValues { 121 | mean, 122 | std_dev, 123 | corr, 124 | }, 125 | values: lines 126 | .map_while(|line| line.ok()?.trim().parse().ok()) 127 | .collect(), 128 | }) 129 | } 130 | 131 | #[test] 132 | #[ignore = "NIST tests should not run from typical `cargo test` calls"] 133 | fn nist_test_covariance_consistent_with_variance() 
{} 134 | 135 | #[test] 136 | #[ignore = "NIST tests should not run from typical `cargo test` calls"] 137 | fn nist_test_covariance_is_symmetric() {} 138 | -------------------------------------------------------------------------------- /src/stats_tests/anderson_darling.rs: -------------------------------------------------------------------------------- 1 | use crate::distribution::ContinuousCDF; 2 | 3 | #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] 4 | #[non_exhaustive] 5 | pub enum AndersonDarlingError { 6 | SampleSizeInvalid, 7 | } 8 | 9 | impl core::fmt::Display for AndersonDarlingError { 10 | /// Formats the error for display. 11 | fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { 12 | match self { 13 | AndersonDarlingError::SampleSizeInvalid => { 14 | write!(f, "Sample size `n` must be greater than 0.") 15 | } 16 | } 17 | } 18 | } 19 | 20 | #[cfg(feature = "std")] 21 | impl std::error::Error for AndersonDarlingError {} 22 | 23 | pub fn anderson_darling>( 24 | f_obs: &[f64], 25 | dist: &T, 26 | ) -> Result<(f64, f64), AndersonDarlingError> { 27 | let n = f_obs.len(); 28 | if n == 0 { 29 | return Err(AndersonDarlingError::SampleSizeInvalid); 30 | } 31 | let mut f_obs = f_obs.to_vec(); 32 | f_obs.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); 33 | 34 | let n_float = n as f64; 35 | let beta: f64 = (1.0 / n_float) 36 | * (0..n) 37 | .map(|i| { 38 | (2.0 * (i + 1) as f64 - 1.0) 39 | * (f64::ln(dist.cdf(f_obs[i])) + f64::ln(1.0 - dist.cdf(f_obs[n - 1 - i]))) 40 | }) 41 | .sum::(); 42 | let a_squared = -n_float - beta; 43 | let a_squared_adjusted = a_squared * (1.0 + (0.75 / n_float) + (2.25 / n_float.powi(2))); 44 | let p_value = if a_squared_adjusted >= 0.6 { 45 | (1.2937 - 5.709 * a_squared_adjusted + 0.0186 * a_squared_adjusted.powi(2)).exp() 46 | } else if a_squared_adjusted >= 0.34 { 47 | (0.9177 - 4.279 * a_squared_adjusted - 1.38 * a_squared_adjusted.powi(2)).exp() 48 | } else if a_squared_adjusted >= 0.2 { 49 | 1.0 - (-8.318 + 42.796 * a_squared_adjusted - 59.938 * a_squared_adjusted.powi(2)).exp() 50 | } else { 51 | 1.0 - (-13.436 + 101.14 * a_squared_adjusted - 223.73 * a_squared_adjusted.powi(2)).exp() 52 | }; 53 | 54 | Ok((a_squared, p_value)) 55 | } 56 | 57 | #[cfg(test)] 58 | mod tests { 59 | use super::*; 60 | use crate::distribution::{Gamma, Normal}; 61 | 62 | #[test] 63 | fn test_normality_good_fit() { 64 | let data = vec![5.2, 4.9, 5.5, 4.8, 5.0, 5.1, 5.3, 4.7, 5.4, 4.9, 5.2, 5.0]; 65 | let n = data.len(); 66 | let n_float = n as f64; 67 | let mean = data.iter().sum::() / n_float; 68 | let std_dev = 69 | (data.iter().map(|&val| (val - mean).powi(2)).sum::() / (n_float - 1.0)).sqrt(); 70 | let normal_dist = Normal::new(mean, std_dev).unwrap(); 71 | 72 | let (stat, p_value) = anderson_darling(&data, &normal_dist).unwrap(); 73 | 74 | assert!(stat < 0.5, "Statistic should be low for a good fit"); 75 | assert!(p_value > 0.05, "P-value should be high for a good fit"); 76 | } 77 | 78 | #[test] 79 | fn test_normality_poor_fit() { 80 | let data = vec![1.0, 1.2, 1.5, 1.9, 2.0, 2.1, 2.2, 2.3, 5.0, 8.0, 12.0]; 81 | let n = data.len(); 82 | let n_float = n as f64; 83 | let mean = data.iter().sum::() / n_float; 84 | let std_dev = 85 | (data.iter().map(|&val| (val - mean).powi(2)).sum::() / (n_float - 1.0)).sqrt(); 86 | let normal_dist = Normal::new(mean, std_dev).unwrap(); 87 | 88 | let (stat, p_value) = anderson_darling(&data, &normal_dist).unwrap(); 89 | 90 | assert!(stat > 0.5, "Statistic should be high for a poor fit"); 91 | assert!(p_value < 0.05, "P-value 
should be low for a poor fit"); 92 | } 93 | 94 | #[test] 95 | fn test_gamma_distribution_good_fit() { 96 | let data = vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0, 5.0]; 97 | let gamma_dist = Gamma::new(2.0, 1.0).unwrap(); 98 | let (stat, p_value) = anderson_darling(&data, &gamma_dist).unwrap(); 99 | assert!(stat < 1.0, "Statistic should be low for a good gamma fit"); 100 | assert!(p_value > 0.05, "P-value should be high for a good fit"); 101 | } 102 | 103 | #[test] 104 | fn test_gamma_distribution_bad_fit() { 105 | let data = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]; 106 | let gamma_dist = Gamma::new(2.0, 1.0).unwrap(); 107 | let (stat, p_value) = anderson_darling(&data, &gamma_dist).unwrap(); 108 | assert!(stat > 1.0, "Statistic should be high for a bad gamma fit"); 109 | assert!(p_value < 0.05, "P-value should be low for a bad fit"); 110 | } 111 | 112 | #[test] 113 | fn test_sample_size_invalid() { 114 | let data: Vec = vec![]; 115 | let normal_dist = Normal::new(0.0, 1.0).unwrap(); 116 | let result = anderson_darling(&data, &normal_dist); 117 | assert!(result.is_err()); 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /src/statistics/traits.rs: -------------------------------------------------------------------------------- 1 | use ::num_traits::float::Float; 2 | 3 | /// The `Min` trait specifies than an object has a minimum value 4 | pub trait Min { 5 | /// Returns the minimum value in the domain of a given distribution 6 | /// if it exists, otherwise `None`. 7 | /// 8 | /// # Examples 9 | /// 10 | /// ``` 11 | /// use statrs::statistics::Min; 12 | /// use statrs::distribution::Uniform; 13 | /// 14 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 15 | /// assert_eq!(0.0, n.min()); 16 | /// ``` 17 | fn min(&self) -> T; 18 | } 19 | 20 | /// The `Max` trait specifies that an object has a maximum value 21 | pub trait Max { 22 | /// Returns the maximum value in the domain of a given distribution 23 | /// if it exists, otherwise `None`. 24 | /// 25 | /// # Examples 26 | /// 27 | /// ``` 28 | /// use statrs::statistics::Max; 29 | /// use statrs::distribution::Uniform; 30 | /// 31 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 32 | /// assert_eq!(1.0, n.max()); 33 | /// ``` 34 | fn max(&self) -> T; 35 | } 36 | pub trait DiscreteDistribution { 37 | /// Returns the mean, if it exists. 38 | fn mean(&self) -> Option { 39 | None 40 | } 41 | /// Returns the variance, if it exists. 42 | fn variance(&self) -> Option { 43 | None 44 | } 45 | /// Returns the standard deviation, if it exists. 46 | fn std_dev(&self) -> Option { 47 | self.variance().map(|var| var.sqrt()) 48 | } 49 | /// Returns the entropy, if it exists. 50 | fn entropy(&self) -> Option { 51 | None 52 | } 53 | /// Returns the skewness, if it exists. 54 | fn skewness(&self) -> Option { 55 | None 56 | } 57 | } 58 | 59 | pub trait Distribution { 60 | /// Returns the mean, if it exists. 61 | /// 62 | /// # Examples 63 | /// 64 | /// ``` 65 | /// use statrs::statistics::Distribution; 66 | /// use statrs::distribution::Uniform; 67 | /// 68 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 69 | /// assert_eq!(0.5, n.mean().unwrap()); 70 | /// ``` 71 | fn mean(&self) -> Option { 72 | None 73 | } 74 | /// Returns the variance, if it exists. 
75 | /// 76 | /// # Examples 77 | /// 78 | /// ``` 79 | /// use statrs::statistics::Distribution; 80 | /// use statrs::distribution::Uniform; 81 | /// 82 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 83 | /// assert_eq!(1.0 / 12.0, n.variance().unwrap()); 84 | /// ``` 85 | fn variance(&self) -> Option { 86 | None 87 | } 88 | /// Returns the standard deviation, if it exists. 89 | /// 90 | /// # Examples 91 | /// 92 | /// ``` 93 | /// use statrs::statistics::Distribution; 94 | /// use statrs::distribution::Uniform; 95 | /// 96 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 97 | /// assert_eq!((1f64 / 12f64).sqrt(), n.std_dev().unwrap()); 98 | /// ``` 99 | fn std_dev(&self) -> Option { 100 | self.variance().map(|var| var.sqrt()) 101 | } 102 | /// Returns the entropy, if it exists. 103 | /// 104 | /// # Examples 105 | /// 106 | /// ``` 107 | /// use statrs::statistics::Distribution; 108 | /// use statrs::distribution::Uniform; 109 | /// 110 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 111 | /// assert_eq!(0.0, n.entropy().unwrap()); 112 | /// ``` 113 | fn entropy(&self) -> Option { 114 | None 115 | } 116 | /// Returns the skewness, if it exists. 117 | /// 118 | /// # Examples 119 | /// 120 | /// ``` 121 | /// use statrs::statistics::Distribution; 122 | /// use statrs::distribution::Uniform; 123 | /// 124 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 125 | /// assert_eq!(0.0, n.skewness().unwrap()); 126 | /// ``` 127 | fn skewness(&self) -> Option { 128 | None 129 | } 130 | } 131 | 132 | /// The `Mean` trait implements the calculation of a mean. 133 | // TODO: Clarify the traits of multidimensional distributions 134 | pub trait MeanN { 135 | fn mean(&self) -> Option; 136 | } 137 | 138 | // TODO: Clarify the traits of multidimensional distributions 139 | pub trait VarianceN { 140 | fn variance(&self) -> Option; 141 | } 142 | 143 | /// The `Median` trait returns the median of the distribution. 144 | pub trait Median { 145 | /// Returns the median. 146 | /// 147 | /// # Examples 148 | /// 149 | /// ``` 150 | /// use statrs::statistics::Median; 151 | /// use statrs::distribution::Uniform; 152 | /// 153 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 154 | /// assert_eq!(0.5, n.median()); 155 | /// ``` 156 | fn median(&self) -> T; 157 | } 158 | 159 | /// The `Mode` trait specifies that an object has a closed form solution 160 | /// for its mode(s) 161 | pub trait Mode { 162 | /// Returns the mode, if one exists. 163 | /// 164 | /// # Examples 165 | /// 166 | /// ``` 167 | /// use statrs::statistics::Mode; 168 | /// use statrs::distribution::Uniform; 169 | /// 170 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 171 | /// assert_eq!(Some(0.5), n.mode()); 172 | /// ``` 173 | fn mode(&self) -> T; 174 | } 175 | -------------------------------------------------------------------------------- /src/function/factorial.rs: -------------------------------------------------------------------------------- 1 | //! Provides functions related to factorial calculations (e.g. binomial 2 | //! coefficient, factorial, multinomial) 3 | 4 | use crate::function::gamma; 5 | 6 | /// The maximum factorial representable 7 | /// by a 64-bit floating point without 8 | /// overflowing 9 | pub const MAX_FACTORIAL: usize = 170; 10 | 11 | /// Computes the factorial function `x -> x!` for 12 | /// `170 >= x >= 0`. All factorials larger than `170!` 13 | /// will overflow an `f64`. 
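///
/// A short usage sketch (values follow the precomputed cache used below: `5! = 120`,
/// and anything above `170!` saturates to infinity):
///
/// ```
/// use statrs::function::factorial::factorial;
///
/// assert_eq!(factorial(5), 120.0);
/// assert_eq!(factorial(171), f64::INFINITY);
/// ```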
14 | /// 15 | /// # Remarks 16 | /// 17 | /// Returns `f64::INFINITY` if `x > 170` 18 | pub fn factorial(x: u64) -> f64 { 19 | let x = x as usize; 20 | FCACHE.get(x).map_or(f64::INFINITY, |&fac| fac) 21 | } 22 | 23 | /// Computes the logarithmic factorial function `x -> ln(x!)` 24 | /// for `x >= 0`. 25 | /// 26 | /// # Remarks 27 | /// 28 | /// Returns `0.0` if `x <= 1` 29 | pub fn ln_factorial(x: u64) -> f64 { 30 | let x = x as usize; 31 | FCACHE 32 | .get(x) 33 | .map_or_else(|| gamma::ln_gamma(x as f64 + 1.0), |&fac| fac.ln()) 34 | } 35 | 36 | /// Computes the binomial coefficient `n choose k` 37 | /// where `k` and `n` are non-negative values. 38 | /// 39 | /// # Remarks 40 | /// 41 | /// Returns `0.0` if `k > n` 42 | pub fn binomial(n: u64, k: u64) -> f64 { 43 | if k > n { 44 | 0.0 45 | } else { 46 | (0.5 + (ln_factorial(n) - ln_factorial(k) - ln_factorial(n - k)).exp()).floor() 47 | } 48 | } 49 | 50 | /// Computes the natural logarithm of the binomial coefficient 51 | /// `ln(n choose k)` where `k` and `n` are non-negative values 52 | /// 53 | /// # Remarks 54 | /// 55 | /// Returns `f64::NEG_INFINITY` if `k > n` 56 | pub fn ln_binomial(n: u64, k: u64) -> f64 { 57 | if k > n { 58 | f64::NEG_INFINITY 59 | } else { 60 | ln_factorial(n) - ln_factorial(k) - ln_factorial(n - k) 61 | } 62 | } 63 | 64 | /// Computes the multinomial coefficient: `n choose n1, n2, n3, ...` 65 | /// 66 | /// # Panics 67 | /// 68 | /// If the elements in `ni` do not sum to `n` 69 | pub fn multinomial(n: u64, ni: &[u64]) -> f64 { 70 | checked_multinomial(n, ni).unwrap() 71 | } 72 | 73 | /// Computes the multinomial coefficient: `n choose n1, n2, n3, ...` 74 | /// 75 | /// Returns `None` if the elements in `ni` do not sum to `n`. 76 | pub fn checked_multinomial(n: u64, ni: &[u64]) -> Option { 77 | let (sum, ret) = ni.iter().fold((0, ln_factorial(n)), |acc, &x| { 78 | (acc.0 + x, acc.1 - ln_factorial(x)) 79 | }); 80 | 81 | if sum == n { 82 | Some((0.5 + ret.exp()).floor()) 83 | } else { 84 | None 85 | } 86 | } 87 | 88 | // Initialization for pre-computed cache of 171 factorial 89 | // values 0!...170! 
90 | const FCACHE: [f64; MAX_FACTORIAL + 1] = { 91 | let mut fcache = [1.0; MAX_FACTORIAL + 1]; 92 | 93 | // `const` only allow while loops 94 | let mut i = 1; 95 | while i < MAX_FACTORIAL + 1 { 96 | fcache[i] = fcache[i - 1] * i as f64; 97 | i += 1; 98 | } 99 | 100 | fcache 101 | }; 102 | 103 | #[rustfmt::skip] 104 | #[cfg(test)] 105 | mod tests { 106 | use super::*; 107 | use crate::prec; 108 | 109 | #[test] 110 | fn test_fcache() { 111 | assert!((FCACHE[0] - 1.0).abs() < f64::EPSILON); 112 | assert!((FCACHE[1] - 1.0).abs() < f64::EPSILON); 113 | assert!((FCACHE[2] - 2.0).abs() < f64::EPSILON); 114 | assert!((FCACHE[3] - 6.0).abs() < f64::EPSILON); 115 | assert!((FCACHE[4] - 24.0).abs() < f64::EPSILON); 116 | assert!((FCACHE[70] - 1197857166996989e85).abs() < f64::EPSILON); 117 | assert!((FCACHE[170] - 7257415615307994e291).abs() < f64::EPSILON); 118 | } 119 | 120 | #[test] 121 | fn test_factorial_and_ln_factorial() { 122 | let mut fac = 1.0; 123 | assert_eq!(factorial(0), fac); 124 | for i in 1..171 { 125 | fac *= i as f64; 126 | assert_eq!(factorial(i), fac); 127 | assert_eq!(ln_factorial(i), fac.ln()); 128 | } 129 | } 130 | 131 | #[test] 132 | fn test_factorial_overflow() { 133 | assert_eq!(factorial(172), f64::INFINITY); 134 | assert_eq!(factorial(u64::MAX), f64::INFINITY); 135 | } 136 | 137 | #[test] 138 | fn test_ln_factorial_does_not_overflow() { 139 | assert_eq!(ln_factorial(1 << 10), 6078.2118847500501140); 140 | prec::assert_abs_diff_eq!( 141 | ln_factorial(1 << 12), 142 | 29978.648060844048236, 143 | epsilon = 1e-11 144 | ); 145 | assert_eq!(ln_factorial(1 << 15), 307933.81973375485425); 146 | assert_eq!(ln_factorial(1 << 17), 1413421.9939462073242); 147 | } 148 | 149 | #[test] 150 | fn test_binomial() { 151 | assert_eq!(binomial(1, 1), 1.0); 152 | assert_eq!(binomial(5, 2), 10.0); 153 | assert_eq!(binomial(7, 3), 35.0); 154 | assert_eq!(binomial(1, 0), 1.0); 155 | assert_eq!(binomial(0, 1), 0.0); 156 | assert_eq!(binomial(5, 7), 0.0); 157 | } 158 | 159 | #[test] 160 | fn test_ln_binomial() { 161 | assert_eq!(ln_binomial(1, 1), 1f64.ln()); 162 | prec::assert_abs_diff_eq!(ln_binomial(5, 2), 10f64.ln(), epsilon = 1e-14); 163 | prec::assert_abs_diff_eq!(ln_binomial(7, 3), 35f64.ln(), epsilon = 1e-14); 164 | assert_eq!(ln_binomial(1, 0), 1f64.ln()); 165 | assert_eq!(ln_binomial(0, 1), 0f64.ln()); 166 | assert_eq!(ln_binomial(5, 7), 0f64.ln()); 167 | } 168 | 169 | #[test] 170 | fn test_multinomial() { 171 | assert_eq!(1.0, multinomial(1, &[1, 0])); 172 | assert_eq!(10.0, multinomial(5, &[3, 2])); 173 | assert_eq!(10.0, multinomial(5, &[2, 3])); 174 | assert_eq!(35.0, multinomial(7, &[3, 4])); 175 | } 176 | 177 | #[test] 178 | #[should_panic] 179 | fn test_multinomial_bad_ni() { 180 | multinomial(1, &[1, 1]); 181 | } 182 | 183 | #[test] 184 | fn test_checked_multinomial_bad_ni() { 185 | assert!(checked_multinomial(1, &[1, 1]).is_none()); 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /src/statistics/order_statistics.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "std")] 2 | use super::RankTieBreaker; 3 | 4 | /// The `OrderStatistics` trait provides statistical utilities 5 | /// having to do with ordering. All the algorithms are in-place thus requiring 6 | /// a mutable borrow. 7 | pub trait OrderStatistics { 8 | /// Returns the order statistic `(order 1..N)` from the data 9 | /// 10 | /// # Remarks 11 | /// 12 | /// No sorting is assumed. 
Order must be one-based (between `1` and `N` 13 | /// inclusive) 14 | /// Returns `f64::NAN` if order is outside the viable range or data is 15 | /// empty. 16 | /// 17 | /// # Examples 18 | /// 19 | /// ``` 20 | /// use statrs::statistics::OrderStatistics; 21 | /// use statrs::statistics::Data; 22 | /// 23 | /// let x = []; 24 | /// let mut x = Data::new(x); 25 | /// assert!(x.order_statistic(1).is_nan()); 26 | /// 27 | /// let y = [0.0, 3.0, -2.0]; 28 | /// let mut y = Data::new(y); 29 | /// assert!(y.order_statistic(0).is_nan()); 30 | /// assert!(y.order_statistic(4).is_nan()); 31 | /// assert_eq!(y.order_statistic(2), 0.0); 32 | /// assert!(y != Data::new([0.0, 3.0, -2.0])); 33 | /// ``` 34 | fn order_statistic(&mut self, order: usize) -> T; 35 | 36 | /// Returns the median value from the data 37 | /// 38 | /// # Remarks 39 | /// 40 | /// Returns `f64::NAN` if data is empty 41 | /// 42 | /// # Examples 43 | /// 44 | /// ``` 45 | /// use statrs::statistics::OrderStatistics; 46 | /// use statrs::statistics::Data; 47 | /// 48 | /// let x = []; 49 | /// let mut x = Data::new(x); 50 | /// assert!(x.median().is_nan()); 51 | /// 52 | /// let y = [0.0, 3.0, -2.0]; 53 | /// let mut y = Data::new(y); 54 | /// assert_eq!(y.median(), 0.0); 55 | /// assert!(y != Data::new([0.0, 3.0, -2.0])); 56 | fn median(&mut self) -> T; 57 | 58 | /// Estimates the tau-th quantile from the data. The tau-th quantile 59 | /// is the data value where the cumulative distribution function crosses 60 | /// tau. 61 | /// 62 | /// # Remarks 63 | /// 64 | /// No sorting is assumed. Tau must be between `0` and `1` inclusive. 65 | /// Returns `f64::NAN` if data is empty or tau is outside the inclusive 66 | /// range. 67 | /// 68 | /// # Examples 69 | /// 70 | /// ``` 71 | /// use statrs::statistics::OrderStatistics; 72 | /// use statrs::statistics::Data; 73 | /// 74 | /// let x = []; 75 | /// let mut x = Data::new(x); 76 | /// assert!(x.quantile(0.5).is_nan()); 77 | /// 78 | /// let y = [0.0, 3.0, -2.0]; 79 | /// let mut y = Data::new(y); 80 | /// assert!(y.quantile(-1.0).is_nan()); 81 | /// assert!(y.quantile(2.0).is_nan()); 82 | /// assert_eq!(y.quantile(0.5), 0.0); 83 | /// assert!(y != Data::new([0.0, 3.0, -2.0])); 84 | /// ``` 85 | fn quantile(&mut self, tau: f64) -> T; 86 | 87 | /// Estimates the p-Percentile value from the data. 88 | /// 89 | /// # Remarks 90 | /// 91 | /// Use quantile for non-integer percentiles. `p` must be between `0` and 92 | /// `100` inclusive. 93 | /// Returns `f64::NAN` if data is empty or `p` is outside the inclusive 94 | /// range. 95 | /// 96 | /// # Examples 97 | /// 98 | /// ``` 99 | /// use statrs::statistics::OrderStatistics; 100 | /// use statrs::statistics::Data; 101 | /// 102 | /// let x = []; 103 | /// let mut x = Data::new(x); 104 | /// assert!(x.percentile(0).is_nan()); 105 | /// 106 | /// let y = [1.0, 5.0, 3.0, 4.0, 10.0, 9.0, 6.0, 7.0, 8.0, 2.0]; 107 | /// let mut y = Data::new(y); 108 | /// assert_eq!(y.percentile(0), 1.0); 109 | /// assert_eq!(y.percentile(50), 5.5); 110 | /// assert_eq!(y.percentile(100), 10.0); 111 | /// assert!(y.percentile(105).is_nan()); 112 | /// assert!(y != Data::new([1.0, 5.0, 3.0, 4.0, 10.0, 9.0, 6.0, 7.0, 8.0, 2.0])); 113 | /// ``` 114 | fn percentile(&mut self, p: usize) -> T; 115 | 116 | /// Estimates the first quartile value from the data. 
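/// The estimate is interpolated rather than picked directly from the sample, so for
/// `[1.0, 2.0, 3.0, 4.0]` it evaluates to roughly `1.4167` (see the example below).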
117 | /// 118 | /// # Remarks 119 | /// 120 | /// Returns `f64::NAN` if data is empty 121 | /// 122 | /// # Examples 123 | /// 124 | /// ``` 125 | /// use approx::assert_abs_diff_eq; 126 | /// 127 | /// use statrs::statistics::OrderStatistics; 128 | /// use statrs::statistics::Data; 129 | /// 130 | /// # fn main() { 131 | /// let x = []; 132 | /// let mut x = Data::new(x); 133 | /// assert!(x.lower_quartile().is_nan()); 134 | /// 135 | /// let y = [2.0, 1.0, 3.0, 4.0]; 136 | /// let mut y = Data::new(y); 137 | /// assert_abs_diff_eq!(y.lower_quartile(), 1.416666666666666, epsilon = 1e-15); 138 | /// assert!(y != Data::new([2.0, 1.0, 3.0, 4.0])); 139 | /// # } 140 | /// ``` 141 | fn lower_quartile(&mut self) -> T; 142 | 143 | /// Estimates the third quartile value from the data. 144 | /// 145 | /// # Remarks 146 | /// 147 | /// Returns `f64::NAN` if data is empty 148 | /// 149 | /// # Examples 150 | /// 151 | /// ``` 152 | /// use approx::assert_abs_diff_eq; 153 | /// 154 | /// use statrs::statistics::OrderStatistics; 155 | /// use statrs::statistics::Data; 156 | /// 157 | /// # fn main() { 158 | /// let x = []; 159 | /// let mut x = Data::new(x); 160 | /// assert!(x.upper_quartile().is_nan()); 161 | /// 162 | /// let y = [2.0, 1.0, 3.0, 4.0]; 163 | /// let mut y = Data::new(y); 164 | /// assert_abs_diff_eq!(y.upper_quartile(), 3.5833333333333333, epsilon = 1e-15); 165 | /// assert!(y != Data::new([2.0, 1.0, 3.0, 4.0])); 166 | /// # } 167 | /// ``` 168 | fn upper_quartile(&mut self) -> T; 169 | 170 | /// Estimates the inter-quartile range from the data. 171 | /// 172 | /// # Remarks 173 | /// 174 | /// Returns `f64::NAN` if data is empty 175 | /// 176 | /// # Examples 177 | /// 178 | /// ``` 179 | /// use approx::assert_abs_diff_eq; 180 | /// 181 | /// use statrs::statistics::Data; 182 | /// use statrs::statistics::OrderStatistics; 183 | /// 184 | /// # fn main() { 185 | /// let x = []; 186 | /// let mut x = Data::new(x); 187 | /// assert!(x.interquartile_range().is_nan()); 188 | /// 189 | /// let y = [2.0, 1.0, 3.0, 4.0]; 190 | /// let mut y = Data::new(y); 191 | /// assert_abs_diff_eq!(y.interquartile_range(), 2.166666666666667, epsilon = 1e-15); 192 | /// assert!(y != Data::new([2.0, 1.0, 3.0, 4.0])); 193 | /// # } 194 | /// ``` 195 | fn interquartile_range(&mut self) -> T; 196 | 197 | /// Evaluates the rank of each entry of the data. 198 | /// 199 | /// # Examples 200 | /// 201 | /// ``` 202 | /// use statrs::statistics::{OrderStatistics, RankTieBreaker}; 203 | /// use statrs::statistics::Data; 204 | /// 205 | /// let x = []; 206 | /// let mut x = Data::new(x); 207 | /// assert_eq!(x.ranks(RankTieBreaker::Average).len(), 0); 208 | /// 209 | /// let y = [1.0, 3.0, 2.0, 2.0]; 210 | /// let mut y = Data::new([1.0, 3.0, 2.0, 2.0]); 211 | /// assert_eq!(y.clone().ranks(RankTieBreaker::Average), [1.0, 4.0, 212 | /// 2.5, 2.5]); 213 | /// assert_eq!(y.clone().ranks(RankTieBreaker::Min), [1.0, 4.0, 2.0, 214 | /// 2.0]); 215 | /// ``` 216 | #[cfg(feature = "std")] 217 | fn ranks(&mut self, tie_breaker: RankTieBreaker) -> Vec; 218 | } 219 | -------------------------------------------------------------------------------- /src/prec.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused_macros, unused_imports)] 2 | //! Provides utility functions for working with floating point precision. 3 | //! 4 | //! This module is intended for internal use within the `statrs` crate to ensure consistent 5 | //! 
precision checking across all statistical computations. While it is currently public 6 | //! for historical reasons, it will be made private in a future breaking release. 7 | //! 8 | //! # Usage 9 | //! 10 | //! The module provides three main types of precision checks: 11 | //! 12 | //! 1. Absolute difference checks (`abs_diff_eq!`) - Use when comparing values that should 13 | //! be close in absolute terms, e.g., when checking if a value is close to zero 14 | //! 15 | //! 2. Relative difference checks (`relative_eq!`) - Use when comparing values that scale 16 | //! with the input, e.g., when comparing probability densities or statistical moments 17 | //! 18 | //! 3. ULPs (Units in Last Place) checks (`ulps_eq!`) - Use for comparing values that 19 | //! should be close in terms of floating-point representation 20 | //! 21 | //! Each check type has both a non-asserting version (e.g., `abs_diff_eq!`) and an 22 | //! asserting version (e.g., `assert_abs_diff_eq!`). 23 | //! 24 | //! # Default Precision Levels 25 | //! 26 | //! The module defines default precision levels that are carefully chosen to balance 27 | //! correctness and performance: 28 | //! 29 | //! - `DEFAULT_RELATIVE_ACC`: 1e-14 for relative comparisons 30 | //! - `DEFAULT_EPS`: 1e-9 for absolute comparisons 31 | //! - `DEFAULT_ULPS`: 5 for ULPs comparisons 32 | //! 33 | //! These defaults should be used unless there is a specific reason to use different 34 | //! precision levels. 35 | //! 36 | //! # Module-Specific Precision 37 | //! 38 | //! Some modules may require different precision levels than the crate defaults. In such 39 | //! cases, the module should define its own precision constants using the same names as 40 | //! defined here (e.g., `MODULE_RELATIVE_ACC`, `MODULE_EPS`) to maintain consistency 41 | //! and searchability. 42 | //! 43 | //! # Deprecated Functionality 44 | //! 45 | //! The following items are deprecated and will be removed in a future release: 46 | //! - `almost_eq` function - Use `abs_diff_eq!` macro instead 47 | //! - `assert_almost_eq!` macro - Use `assert_abs_diff_eq!` macro instead 48 | 49 | /// Standard epsilon, maximum relative precision of IEEE 754 double-precision 50 | /// floating point numbers (64 bit) e.g. `2^-53` 51 | pub const F64_PREC: f64 = 0.00000000000000011102230246251565; 52 | 53 | /// Default accuracy for `f64`, equivalent to `0.0 * F64_PREC` 54 | pub const DEFAULT_F64_ACC: f64 = 0.0000000000000011102230246251565; 55 | 56 | /// Default and target relative accuracy for f64 operations 57 | pub const DEFAULT_RELATIVE_ACC: f64 = 1e-14; 58 | 59 | /// Default and target absolute accuracy for f64 operations 60 | pub const DEFAULT_EPS: f64 = 1e-9; 61 | 62 | /// Default and target ULPs accuracy for f64 operations 63 | pub const DEFAULT_ULPS: u32 = 5; 64 | 65 | /// Compares if two floats are close via `approx::abs_diff_eq` 66 | /// using a maximum absolute difference (epsilon) of `acc`. 67 | #[deprecated(since = "0.19.0", note = "Use abs_diff_eq! macro instead")] 68 | pub fn almost_eq(a: f64, b: f64, acc: f64) -> bool { 69 | use approx::AbsDiffEq; 70 | if a.is_infinite() && b.is_infinite() { 71 | return a == b; 72 | } 73 | a.abs_diff_eq(&b, acc) 74 | } 75 | 76 | /// Compares if two floats are close via `prec::relative_eq!` 77 | /// Updates first argument to value of second argument 78 | pub(crate) fn convergence(x: &mut f64, x_new: f64) -> bool { 79 | let res = relative_eq!(*x, x_new); 80 | *x = x_new; 81 | res 82 | } 83 | 84 | macro_rules! 
redefine_one_opt_approx_macro { 85 | ( 86 | $approx_macro:ident, 87 | { epsilon: $default_eps:expr } 88 | ) => { 89 | macro_rules! $approx_macro { 90 | // Caller provides an override for epsilon. 91 | ($a:expr, $b:expr, epsilon = $user_eps:expr) => { 92 | approx::$approx_macro!($a, $b, epsilon = $user_eps) 93 | }; 94 | // No override: use default. 95 | ($a:expr, $b:expr) => { 96 | approx::$approx_macro!($a, $b, epsilon = $default_eps) 97 | }; 98 | } 99 | }; 100 | } 101 | 102 | macro_rules! redefine_two_opt_approx_macro { 103 | ( 104 | $approx_macro:ident, 105 | { epsilon: $default_eps:expr, $second_key:ident: $default_second:expr } 106 | ) => { 107 | macro_rules! $approx_macro { 108 | // Caller provides both options. 109 | ($a:expr, $b:expr, epsilon = $user_eps:expr, $second_key = $user_second:expr) => { 110 | approx::$approx_macro!($a, $b, epsilon = $user_eps, $second_key = $user_second) 111 | }; 112 | // Caller provides epsilon only; use default for second. 113 | ($a:expr, $b:expr, epsilon = $user_eps:expr) => { 114 | approx::$approx_macro!($a, $b, epsilon = $user_eps, $second_key = $default_second) 115 | }; 116 | // Caller provides the second option only; use default for epsilon. 117 | ($a:expr, $b:expr, $second_key = $user_second:expr) => { 118 | approx::$approx_macro!($a, $b, epsilon = $default_eps, $second_key = $user_second) 119 | }; 120 | // Caller provides neither: use both defaults. 121 | ($a:expr, $b:expr) => { 122 | approx::$approx_macro!( 123 | $a, 124 | $b, 125 | epsilon = $default_eps, 126 | $second_key = $default_second 127 | ) 128 | }; 129 | } 130 | }; 131 | } 132 | mod macros { 133 | pub(crate) use redefine_one_opt_approx_macro; 134 | pub(crate) use redefine_two_opt_approx_macro; 135 | 136 | // Non-asserting wrappers: 137 | redefine_one_opt_approx_macro!( 138 | abs_diff_eq, 139 | { epsilon: crate::prec::DEFAULT_EPS } 140 | ); 141 | redefine_two_opt_approx_macro!( 142 | relative_eq, 143 | { epsilon: crate::prec::DEFAULT_EPS, max_relative: crate::prec::DEFAULT_RELATIVE_ACC } 144 | ); 145 | redefine_two_opt_approx_macro!( 146 | ulps_eq, 147 | { epsilon: crate::prec::DEFAULT_EPS, max_ulps: crate::prec::DEFAULT_ULPS } 148 | ); 149 | 150 | pub(crate) use abs_diff_eq; 151 | pub(crate) use relative_eq; 152 | pub(crate) use ulps_eq; 153 | 154 | // Asserting wrappers: 155 | redefine_one_opt_approx_macro!( 156 | assert_abs_diff_eq, 157 | { epsilon: crate::prec::DEFAULT_EPS } 158 | ); 159 | redefine_two_opt_approx_macro!( 160 | assert_relative_eq, 161 | { epsilon: crate::prec::DEFAULT_EPS, max_relative: crate::prec::DEFAULT_RELATIVE_ACC } 162 | ); 163 | redefine_two_opt_approx_macro!( 164 | assert_ulps_eq, 165 | { epsilon: crate::prec::DEFAULT_EPS, max_ulps: crate::prec::DEFAULT_ULPS } 166 | ); 167 | 168 | pub(crate) use assert_abs_diff_eq; 169 | pub(crate) use assert_relative_eq; 170 | pub(crate) use assert_ulps_eq; 171 | 172 | #[deprecated(since = "0.19.0", note = "Use assert_abs_diff_eq! macro instead")] 173 | macro_rules! assert_almost_eq { 174 | ($a:expr, $b:expr, $eps:expr $(,)?) => { 175 | approx::assert_abs_diff_eq!($a, $b, epsilon = $eps) 176 | }; 177 | } 178 | } 179 | 180 | pub(crate) use macros::*; 181 | -------------------------------------------------------------------------------- /src/stats_tests/chisquare.rs: -------------------------------------------------------------------------------- 1 | //! 
Provides the functions related to [Chi-Squared tests](https://en.wikipedia.org/wiki/Chi-squared_test) 2 | 3 | use crate::distribution::{ChiSquared, ContinuousCDF}; 4 | use crate::prec; 5 | 6 | /// Represents the errors that can occur when computing the chisquare function 7 | #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] 8 | #[non_exhaustive] 9 | pub enum ChiSquareTestError { 10 | /// `f_obs` must have a length (or number of categories) greater than 1 11 | FObsInvalid, 12 | /// `f_exp` must have same length and sum as `f_obs` 13 | FExpInvalid, 14 | /// for the p-value to be meaningful, `ddof` must be at least two less 15 | /// than the number of categories, k, which is the length of `f_obs` 16 | DdofInvalid, 17 | } 18 | 19 | impl core::fmt::Display for ChiSquareTestError { 20 | #[cfg_attr(coverage_nightly, coverage(off))] 21 | fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { 22 | match self { 23 | ChiSquareTestError::FObsInvalid => { 24 | write!(f, "`f_obs` must have a length greater than 1") 25 | } 26 | ChiSquareTestError::FExpInvalid => { 27 | write!(f, "`f_exp` must have same length and sum as `f_obs`") 28 | } 29 | ChiSquareTestError::DdofInvalid => { 30 | write!( 31 | f, 32 | "for the p-value to be meaningful, `ddof` must be at least two less than the number of categories, k, which is the length of `f_obs`" 33 | ) 34 | } 35 | } 36 | } 37 | } 38 | 39 | #[cfg(feature = "std")] 40 | impl std::error::Error for ChiSquareTestError {} 41 | 42 | /// Perform a Pearson's chi-square test 43 | /// 44 | /// Returns the chi-square test statistic and p-value 45 | /// 46 | /// # Remarks 47 | /// 48 | /// `ddof` represents an adjustment that can be made to the degrees of freedom where the unadjusted 49 | /// degrees of freedom is `f_obs.len() - 1`. 50 | /// 51 | /// Implementation based on [wikipedia](https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test) 52 | /// while aligning to [scipy's](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chisquare.html) 53 | /// function header where possible. The scipy implementation was also used for testing and validation. 
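///
/// # Formula
///
/// As a sketch of what the implementation below computes (here `O_i` are the
/// observed frequencies from `f_obs`, `E_i` the expected frequencies, taken as
/// uniform when `f_exp` is `None`, and `k = f_obs.len()`):
///
/// ```text
/// χ² = Σ (O_i - E_i)² / E_i
/// dof = k - 1 - ddof
/// ```
///
/// The p-value is the upper tail of a chi-squared distribution with `dof`
/// degrees of freedom evaluated at the statistic.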
54 | /// 55 | /// # Examples 56 | /// 57 | /// ``` 58 | /// use statrs::stats_tests::chisquare::chisquare; 59 | /// let (statistic, pvalue) = chisquare(&[16, 18, 16, 14, 12, 12], None, None).unwrap(); 60 | /// let (statistic, pvalue) = chisquare(&[16, 18, 16, 14, 12, 12], None, Some(1)).unwrap(); 61 | /// let (statistic, pvalue) = chisquare( 62 | /// &[16, 18, 16, 14, 12, 12], 63 | /// Some(&[16.0, 16.0, 16.0, 16.0, 16.0, 8.0]), 64 | /// None, 65 | /// ) 66 | /// .unwrap(); 67 | /// ``` 68 | pub fn chisquare( 69 | f_obs: &[usize], 70 | f_exp: Option<&[f64]>, 71 | ddof: Option, 72 | ) -> Result<(f64, f64), ChiSquareTestError> { 73 | let n: usize = f_obs.len(); 74 | if n <= 1 { 75 | return Err(ChiSquareTestError::FObsInvalid); 76 | } 77 | let stat = if let Some(f_exp) = f_exp { 78 | if f_exp.len() != n { 79 | return Err(ChiSquareTestError::FExpInvalid); 80 | } 81 | 82 | let mut total_samples = 0.0; 83 | let mut sum_expected = 0.0; 84 | 85 | let mut stat = 0.0; 86 | 87 | for (obs, exp) in f_obs.iter().zip(f_exp) { 88 | let obs = *obs as f64; 89 | 90 | stat += (obs - exp).powi(2) / exp; 91 | 92 | total_samples += obs; 93 | sum_expected += exp; 94 | } 95 | 96 | if !prec::relative_eq!(total_samples, sum_expected) { 97 | return Err(ChiSquareTestError::FExpInvalid); 98 | } 99 | 100 | stat 101 | } else { 102 | let total_samples: usize = f_obs.iter().sum(); 103 | // Assume all frequencies are equally likely 104 | let exp = total_samples as f64 / n as f64; 105 | 106 | f_obs 107 | .iter() 108 | .map(|obs| *obs as f64) 109 | .map(|obs| (obs - exp).powi(2) / exp) 110 | .sum() 111 | }; 112 | 113 | let ddof = match ddof { 114 | Some(ddof_to_validate) => { 115 | if ddof_to_validate >= (n - 1) { 116 | return Err(ChiSquareTestError::DdofInvalid); 117 | } 118 | ddof_to_validate 119 | } 120 | None => 0, 121 | }; 122 | let dof = n - 1 - ddof; 123 | 124 | let chi_dist = ChiSquared::new(dof as f64).expect("ddof validity should already be checked"); 125 | let pvalue = 1.0 - chi_dist.cdf(stat); 126 | 127 | Ok((stat, pvalue)) 128 | } 129 | 130 | #[rustfmt::skip] 131 | #[cfg(test)] 132 | mod tests { 133 | use super::*; 134 | use crate::prec; 135 | 136 | #[test] 137 | fn test_scipy_example() { 138 | let (statistic, pvalue) = chisquare(&[16, 18, 16, 14, 12, 12], None, None).unwrap(); 139 | prec::assert_abs_diff_eq!(statistic, 2.0, epsilon = 1e-1); 140 | prec::assert_abs_diff_eq!(pvalue, 0.84914503608460956, epsilon = 1e-9); 141 | 142 | let (statistic, pvalue) = chisquare( 143 | &[16, 18, 16, 14, 12, 12], 144 | Some(&[16.0, 16.0, 16.0, 16.0, 16.0, 8.0]), 145 | None, 146 | ) 147 | .unwrap(); 148 | prec::assert_abs_diff_eq!(statistic, 3.5, epsilon = 1e-1); 149 | prec::assert_abs_diff_eq!(pvalue, 0.62338762774958223, epsilon = 1e-9); 150 | 151 | let (statistic, pvalue) = chisquare(&[16, 18, 16, 14, 12, 12], None, Some(1)).unwrap(); 152 | prec::assert_abs_diff_eq!(statistic, 2.0, epsilon = 1e-1); 153 | prec::assert_abs_diff_eq!(pvalue, 0.7357588823428847, epsilon = 1e-9); 154 | } 155 | #[test] 156 | fn test_wiki_example() { 157 | // fairness of dice - p-value not provided 158 | let (statistic, _) = chisquare(&[5, 8, 9, 8, 10, 20], None, None).unwrap(); 159 | prec::assert_abs_diff_eq!(statistic, 13.4, epsilon = 1e-1); 160 | 161 | let (statistic, _) = chisquare(&[5, 8, 9, 8, 10, 20], Some(&[10.0; 6]), None).unwrap(); 162 | prec::assert_abs_diff_eq!(statistic, 13.4, epsilon = 1e-1); 163 | 164 | // chi-squared goodness of fit test 165 | let (statistic, pvalue) = chisquare(&[44, 56], Some(&[50.0, 50.0]), None).unwrap(); 166 | 
prec::assert_abs_diff_eq!(statistic, 1.44, epsilon = 1e-2); 167 | prec::assert_abs_diff_eq!(pvalue, 0.24, epsilon = 1e-2); 168 | } 169 | 170 | #[test] 171 | fn test_bad_data_f_obs_invalid() { 172 | let result = chisquare(&[16], None, None); 173 | assert_eq!(result, Err(ChiSquareTestError::FObsInvalid)); 174 | let f_obs: &[usize] = &[]; 175 | let result = chisquare(f_obs, None, None); 176 | assert_eq!(result, Err(ChiSquareTestError::FObsInvalid)); 177 | } 178 | #[test] 179 | fn test_bad_data_f_exp_invalid() { 180 | let result = chisquare(&[16, 18, 16, 14, 12, 12], Some(&[1.0, 2.0, 3.0]), None); 181 | assert_eq!(result, Err(ChiSquareTestError::FExpInvalid)); 182 | let result = chisquare(&[16, 18, 16, 14, 12, 12], Some(&[16.0; 6]), None); 183 | assert_eq!(result, Err(ChiSquareTestError::FExpInvalid)); 184 | } 185 | #[test] 186 | fn test_bad_data_ddof_invalid() { 187 | let result = chisquare(&[16, 18, 16, 14, 12, 12], None, Some(5)); 188 | assert_eq!(result, Err(ChiSquareTestError::DdofInvalid)); 189 | let result = chisquare(&[16, 18, 16, 14, 12, 12], None, Some(100)); 190 | assert_eq!(result, Err(ChiSquareTestError::DdofInvalid)); 191 | } 192 | } 193 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # statrs 2 | 3 | ![tests][actions-test-badge] 4 | [![MIT licensed][license-badge]](./LICENSE.md) 5 | [![Crate][crates-badge]][crates-url] 6 | [![docs.rs][docsrs-badge]][docs-url] 7 | [![codecov-statrs][codecov-badge]][codecov-url] 8 | ![Crates.io MSRV][crates-msrv-badge] 9 | 10 | [actions-test-badge]: https://github.com/statrs-dev/statrs/actions/workflows/test.yml/badge.svg 11 | [crates-badge]: https://img.shields.io/crates/v/statrs.svg 12 | [crates-url]: https://crates.io/crates/statrs 13 | [license-badge]: https://img.shields.io/badge/license-MIT-blue.svg 14 | [docsrs-badge]: https://img.shields.io/docsrs/statrs 15 | [docs-url]: https://docs.rs/statrs/*/statrs 16 | [codecov-badge]: https://codecov.io/gh/statrs-dev/statrs/graph/badge.svg?token=XtMSMYXvIf 17 | [codecov-url]: https://codecov.io/gh/statrs-dev/statrs 18 | [crates-msrv-badge]: https://img.shields.io/crates/msrv/statrs 19 | 20 | Statrs provides a host of statistical utilities for Rust scientific computing. 21 | 22 | Included are a number of common distributions that can be sampled (e.g. Normal, Exponential, Student's T, Gamma, Uniform, etc.) plus common statistical functions like the gamma function, beta function, and error function. 23 | 24 | This library began as a port of the statistical capabilities in the C# Math.NET library. 25 | All unit tests in the library are borrowed from Math.NET when possible and filled in when not. 26 | Planned for future releases are continued implementations of distributions as well as porting over more statistical utilities. 27 | 28 | Please check out the documentation [here][docs-url]. 29 | 30 | ## Usage 31 | 32 | Add the most recent release to your `Cargo.toml`: 33 | 34 | ```toml 35 | [dependencies] 36 | statrs = "*" # replace * by the latest version of the crate. 37 | ``` 38 | 39 | For examples, view [the docs](https://docs.rs/statrs/*/statrs/). 40 | 41 | ### Running tests 42 | 43 | If you'd like to run all suggested tests, you'll need to download some data from 44 | NIST; there is a script in the `tests/` folder that downloads and formats the data.
45 | 46 | ```sh 47 | cargo test 48 | ./tests/gather_nist_data.sh && cargo test -- --include-ignored nist_ 49 | ``` 50 | 51 | If you'd like to modify where the data is downloaded, you can use the environment variable 52 | `STATRS_NIST_DATA_DIR` when running the script and the tests. 53 | 54 | ## Minimum supported Rust version (MSRV) 55 | 56 | This crate requires a Rust version of 1.87.0 or higher. Increases in MSRV will be considered a semver non-breaking API change and require a version increase (PATCH until 1.0.0, MINOR after 1.0.0). 57 | 58 | ## Precision 59 | Floating-point numbers cannot always represent decimal values exactly, which can introduce small (and in some cases catastrophically large) errors in computations. 60 | In statistical applications, these errors can accumulate, making careful precision control important. 61 | 62 | ### For Users and Evaluators 63 | 64 | The `statrs` crate takes precision seriously: 65 | 66 | - We use standardized precision checks throughout the codebase 67 | - Default precision levels are carefully chosen to balance correctness and performance 68 | - Module-specific precision requirements are explicitly documented where they differ from defaults 69 | - Our test suite verifies numerical accuracy against common reference libraries 70 | 71 | Key precision constants in the crate are set by pub consts in the `prec` module: 72 | - Default relative accuracy: `pub const DEFAULT_RELATIVE_ACC` 73 | - Default epsilon: `pub const DEFAULT_EPS` 74 | - Default ULPs (Units in Last Place): `pub const DEFAULT_ULPS` 75 | 76 | Some modules/submodules have default precision that differs from the crate defaults; for searchability, such constants are named `MODULE_RELATIVE_ACC`, `MODULE_EPS`, and `MODULE_ULPS`. 77 | 78 | > [!IMPORTANT] 79 | > Starting from v0.19.0, the `prec` module is no longer public (`pub mod prec` → `mod prec`). This change reflects that precision handling is an internal implementation detail. 80 | > 81 | > The precision constants mentioned above remain stable and documented and will be reexported at the crate level, but direct access to the module's utilities is now restricted to maintain better API boundaries. 82 | 83 | 84 | ### For Contributors 85 | The `prec` module is intended purely as an internal tool for contributors; users of the crate should never need to rely on it directly. 86 | To help maintain consistent precision checking, `statrs` provides: 87 | 88 | 1. A `prec` module that wraps and standardizes common approximation checks from the `approx` crate with crate-specific defaults 89 | 2. Macros for common precision comparison patterns 90 | 3. Helper functions for convergence testing 91 | 92 | When contributing: 93 | - Use the provided precision utilities rather than hard-coding values (see the sketch after this list) 94 | - Maintain or improve precision in existing tests when making changes; new modules can start at lower precision than the crate defaults if need be 95 | - When doing so, use the same constant names as defined in the `prec` module; this helps with searchability. 96 | - Document any module-specific precision requirements
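For example, here is a minimal sketch of what this looks like in a unit test inside the crate (the `prec` macros are crate-internal, so this only applies to code within `statrs` itself; the values below are placeholders rather than real reference data):

```rust
use crate::prec;

#[test]
fn computed_value_matches_reference() {
    // Placeholder computation and reference value, purely for illustration.
    let computed = 2.0_f64.sqrt() * 2.0_f64.sqrt();
    let expected = 2.0;

    // Uses the crate-wide defaults (DEFAULT_EPS / DEFAULT_RELATIVE_ACC).
    prec::assert_relative_eq!(computed, expected);

    // Override a tolerance only where a module genuinely needs it
    // (shown here with a literal; prefer a named `MODULE_EPS`-style
    // constant in real code).
    prec::assert_abs_diff_eq!(computed, expected, epsilon = 1e-12);
}
```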
97 | 98 | ### Learning Resources 99 | 100 | If you're new to floating-point precision, these resources provide helpful introductions: 101 | 102 | - [Comparing Floating Point Numbers, 2012 Edition](https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/) 103 | - [The Floating Point Guide - Comparison](http://floating-point-gui.de/errors/comparison/) 104 | - [What Every Computer Scientist Should Know About Floating-Point Arithmetic](https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html) 105 | 106 | ## Contributing 107 | 108 | Thanks for helping to improve the project! 109 | **No contribution is too small and all contributions are valued.** 110 | 111 | If you're not familiar with precision in floating point operations, please read the section on [precision](#precision), specifically the [For Contributors](#for-contributors) subsection. 112 | 113 | Some suggestions if you don't know where to start: 114 | - if you're an existing user, file an issue or feature request. 115 | - [documentation][docs-url] is a great place to start, as you'll be able to identify the value of existing documentation better than its authors. 116 | - tests are valuable in demonstrating correct behavior; you can review test coverage on the [CodeCov Report][codecov-url]. 117 | - check out some of the issues marked [help wanted](https://github.com/statrs-dev/statrs/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). 118 | - look at features you'd like to see in statrs 119 | - Math.NET 120 | - [Distributions](https://github.com/mathnet/mathnet-numerics/tree/master/src/Numerics/Distributions) 121 | - [Statistics](https://github.com/mathnet/mathnet-numerics/tree/master/src/Numerics/Statistics) 122 | - scipy.stats 123 | - KDE, see [issue #193](https://github.com/statrs-dev/statrs/issues/193) 124 | 125 | ### How to contribute 126 | 127 | Clone the repo: 128 | 129 | ```sh 130 | git clone https://github.com/statrs-dev/statrs 131 | ``` 132 | 133 | Create a feature branch: 134 | 135 | ```sh 136 | git checkout -b <feature_branch> master 137 | ``` 138 | 139 | Write your code and docs, then ensure it is formatted: 140 | 141 | ```sh 142 | cargo fmt 143 | ``` 144 | 145 | Add `--check` to view the diff without making file changes. 146 | Our CI will check formatting without making changes. 147 | 148 | After committing your code: 149 | 150 | ```shell 151 | git push -u <remote> <feature_branch> # with `git` 152 | gh pr create --head <feature_branch> # with GitHub's CLI 153 | ``` 154 | 155 | Then submit a PR, preferably referencing the relevant issue, if it exists. 156 | 157 | ### Commit messages 158 | 159 | Please be explicit and purposeful with commit messages. 160 | [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/#summary) are encouraged. 161 | 162 | ### Communication Expectations 163 | 164 | Please allow at least one week before pinging issues/PRs. 165 | -------------------------------------------------------------------------------- /src/stats_tests/ttest_onesample.rs: -------------------------------------------------------------------------------- 1 | //! Provides the [one-sample t-test](https://en.wikipedia.org/wiki/Student%27s_t-test#One-sample_t-test) 2 | //!
and related functions 3 | 4 | use crate::distribution::{ContinuousCDF, StudentsT}; 5 | use crate::stats_tests::{Alternative, NaNPolicy}; 6 | 7 | /// Represents the errors that can occur when computing the ttest_onesample function 8 | #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] 9 | #[non_exhaustive] 10 | pub enum TTestOneSampleError { 11 | /// sample must be greater than length 1 12 | SampleTooSmall, 13 | /// samples can not contain NaN when `nan_policy` is set to `NaNPolicy::Error` 14 | SampleContainsNaN, 15 | } 16 | 17 | impl core::fmt::Display for TTestOneSampleError { 18 | #[cfg_attr(coverage_nightly, coverage(off))] 19 | fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { 20 | match self { 21 | TTestOneSampleError::SampleTooSmall => write!(f, "sample must be len > 1"), 22 | TTestOneSampleError::SampleContainsNaN => { 23 | write!( 24 | f, 25 | "samples can not contain NaN when nan_policy is set to NaNPolicy::Error" 26 | ) 27 | } 28 | } 29 | } 30 | } 31 | 32 | #[cfg(feature = "std")] 33 | impl std::error::Error for TTestOneSampleError {} 34 | 35 | /// Perform a one sample t-test 36 | /// 37 | /// Returns the t-statistic and p-value 38 | /// 39 | /// # Remarks 40 | /// 41 | /// `a` needs to be mutable in case needing to filter out NaNs for NaNPolicy::Emit 42 | /// 43 | /// Implementation based on [jmp](https://www.jmp.com/en_us/statistics-knowledge-portal/t-test/one-sample-t-test.html) 44 | /// while aligning to [scipy's](https://docs.scipy.org/doc/scipy-1.14.1/reference/generated/scipy.stats.ttest_1samp.html) 45 | /// function header where possible. 46 | /// 47 | /// # Examples 48 | /// 49 | /// ``` 50 | /// use statrs::stats_tests::ttest_onesample::ttest_onesample; 51 | /// use statrs::stats_tests::{Alternative, NaNPolicy}; 52 | /// let data = Vec::from([13f64, 9f64, 11f64, 8f64, 7f64, 12f64]); 53 | /// let (statistic, pvalue) = ttest_onesample(data, 13f64, Alternative::TwoSided, NaNPolicy::Error).unwrap(); 54 | /// ``` 55 | pub fn ttest_onesample( 56 | mut a: Vec, 57 | popmean: f64, 58 | alternative: Alternative, 59 | nan_policy: NaNPolicy, 60 | ) -> Result<(f64, f64), TTestOneSampleError> { 61 | let has_nans = a.iter().any(|x| x.is_nan()); 62 | if has_nans { 63 | match nan_policy { 64 | NaNPolicy::Propogate => { 65 | return Ok((f64::NAN, f64::NAN)); 66 | } 67 | NaNPolicy::Error => { 68 | return Err(TTestOneSampleError::SampleContainsNaN); 69 | } 70 | NaNPolicy::Emit => { 71 | a = a.into_iter().filter(|x| !x.is_nan()).collect::>(); 72 | } 73 | } 74 | } 75 | 76 | let n = a.len(); 77 | if n < 2 { 78 | return Err(TTestOneSampleError::SampleTooSmall); 79 | } 80 | let samplemean = a.iter().sum::() / (n as f64); 81 | let df = (n - 1) as f64; 82 | let s = a.iter().map(|x| (x - samplemean).powi(2)).sum::() / df; 83 | let se = (s / n as f64).sqrt(); 84 | 85 | let tstat = (samplemean - popmean) / se; 86 | 87 | let t_dist = 88 | StudentsT::new(0.0, 1.0, df).expect("df should always be non NaN and greater than 0"); 89 | 90 | let pvalue = match alternative { 91 | Alternative::TwoSided => 2.0 * (1.0 - t_dist.cdf(tstat.abs())), 92 | Alternative::Less => t_dist.cdf(tstat), 93 | Alternative::Greater => 1.0 - t_dist.cdf(tstat), 94 | }; 95 | 96 | Ok((tstat, pvalue)) 97 | } 98 | 99 | #[rustfmt::skip] 100 | #[cfg(test)] 101 | mod tests { 102 | use super::*; 103 | use crate::prec; 104 | 105 | /// Test one sample t-test comparing to 106 | #[test] 107 | fn test_jmp_example() { 108 | // Test against an example from jmp.com 109 | // 
https://www.jmp.com/en_us/statistics-knowledge-portal/t-test/one-sample-t-test.html 110 | let data = Vec::from([ 111 | 20.70f64, 27.46f64, 22.15f64, 19.85f64, 21.29f64, 24.75f64, 20.75f64, 22.91f64, 112 | 25.34f64, 20.33f64, 21.54f64, 21.08f64, 22.14f64, 19.56f64, 21.10f64, 18.04f64, 113 | 24.12f64, 19.95f64, 19.72f64, 18.28f64, 16.26f64, 17.46f64, 20.53f64, 22.12f64, 114 | 25.06f64, 22.44f64, 19.08f64, 19.88f64, 21.39f64, 22.33f64, 25.79f64, 115 | ]); 116 | let (statistic, pvalue) = 117 | ttest_onesample(data.clone(), 20.0, Alternative::TwoSided, NaNPolicy::Error).unwrap(); 118 | prec::assert_relative_eq!(statistic, 3.066831635284081); 119 | prec::assert_abs_diff_eq!(pvalue, 0.004552621060635401); 120 | 121 | let (statistic, pvalue) = 122 | ttest_onesample(data.clone(), 20.0, Alternative::Greater, NaNPolicy::Error).unwrap(); 123 | prec::assert_relative_eq!(statistic, 3.066831635284081); 124 | prec::assert_abs_diff_eq!(pvalue, 0.0022763105303177005); 125 | 126 | let (statistic, pvalue) = 127 | ttest_onesample(data.clone(), 20.0, Alternative::Less, NaNPolicy::Error).unwrap(); 128 | prec::assert_relative_eq!(statistic, 3.066831635284081); 129 | prec::assert_abs_diff_eq!(pvalue, 0.9977236894696823); 130 | } 131 | #[test] 132 | fn test_nan_in_data_w_emit() { 133 | // results should be the same as the example above since the NaNs should be filtered out 134 | let data = Vec::from([ 135 | 20.70f64, 136 | 27.46f64, 137 | 22.15f64, 138 | 19.85f64, 139 | 21.29f64, 140 | 24.75f64, 141 | 20.75f64, 142 | 22.91f64, 143 | 25.34f64, 144 | 20.33f64, 145 | 21.54f64, 146 | 21.08f64, 147 | 22.14f64, 148 | 19.56f64, 149 | 21.10f64, 150 | 18.04f64, 151 | 24.12f64, 152 | 19.95f64, 153 | 19.72f64, 154 | 18.28f64, 155 | 16.26f64, 156 | 17.46f64, 157 | 20.53f64, 158 | 22.12f64, 159 | 25.06f64, 160 | 22.44f64, 161 | 19.08f64, 162 | 19.88f64, 163 | 21.39f64, 164 | 22.33f64, 165 | 25.79f64, 166 | f64::NAN, 167 | f64::NAN, 168 | f64::NAN, 169 | f64::NAN, 170 | f64::NAN, 171 | ]); 172 | let (statistic, pvalue) = 173 | ttest_onesample(data.clone(), 20.0, Alternative::TwoSided, NaNPolicy::Emit).unwrap(); 174 | prec::assert_relative_eq!(statistic, 3.066831635284081); 175 | prec::assert_abs_diff_eq!(pvalue, 0.004552621060635401); 176 | } 177 | #[test] 178 | fn test_nan_in_data_w_propogate() { 179 | let sample_input = Vec::from([1.3, f64::NAN]); 180 | let (statistic, pvalue) = ttest_onesample( 181 | sample_input, 182 | 20.0, 183 | Alternative::TwoSided, 184 | NaNPolicy::Propogate, 185 | ) 186 | .unwrap(); 187 | assert!(statistic.is_nan()); 188 | assert!(pvalue.is_nan()); 189 | } 190 | #[test] 191 | fn test_nan_in_data_w_error() { 192 | let sample_input = Vec::from([0.0571, 0.0813, f64::NAN, 0.0836]); 193 | let result = ttest_onesample(sample_input, 20.0, Alternative::TwoSided, NaNPolicy::Error); 194 | assert_eq!(result, Err(TTestOneSampleError::SampleContainsNaN)); 195 | } 196 | #[test] 197 | fn test_bad_data_sample_too_small() { 198 | let sample_input = Vec::new(); 199 | let result = ttest_onesample(sample_input, 20.0, Alternative::TwoSided, NaNPolicy::Error); 200 | assert_eq!(result, Err(TTestOneSampleError::SampleTooSmall)); 201 | 202 | let sample_input = Vec::from([1.0]); 203 | let result = ttest_onesample(sample_input, 20.0, Alternative::TwoSided, NaNPolicy::Error); 204 | assert_eq!(result, Err(TTestOneSampleError::SampleTooSmall)); 205 | } 206 | } 207 | -------------------------------------------------------------------------------- /src/statistics/iter_statistics.rs: 
-------------------------------------------------------------------------------- 1 | use crate::statistics::*; 2 | use core::borrow::Borrow; 3 | use core::f64; 4 | 5 | impl Statistics for T 6 | where 7 | T: IntoIterator, 8 | T::Item: Borrow, 9 | { 10 | fn min(self) -> f64 { 11 | let mut iter = self.into_iter(); 12 | match iter.next() { 13 | None => f64::NAN, 14 | Some(x) => iter.map(|x| *x.borrow()).fold(*x.borrow(), |acc, x| { 15 | if x < acc || x.is_nan() { x } else { acc } 16 | }), 17 | } 18 | } 19 | 20 | fn max(self) -> f64 { 21 | let mut iter = self.into_iter(); 22 | match iter.next() { 23 | None => f64::NAN, 24 | Some(x) => iter.map(|x| *x.borrow()).fold(*x.borrow(), |acc, x| { 25 | if x > acc || x.is_nan() { x } else { acc } 26 | }), 27 | } 28 | } 29 | 30 | fn abs_min(self) -> f64 { 31 | let mut iter = self.into_iter(); 32 | match iter.next() { 33 | None => f64::NAN, 34 | Some(init) => iter 35 | .map(|x| x.borrow().abs()) 36 | .fold(init.borrow().abs(), |acc, x| { 37 | if x < acc || x.is_nan() { x } else { acc } 38 | }), 39 | } 40 | } 41 | 42 | fn abs_max(self) -> f64 { 43 | let mut iter = self.into_iter(); 44 | match iter.next() { 45 | None => f64::NAN, 46 | Some(init) => iter 47 | .map(|x| x.borrow().abs()) 48 | .fold(init.borrow().abs(), |acc, x| { 49 | if x > acc || x.is_nan() { x } else { acc } 50 | }), 51 | } 52 | } 53 | 54 | fn mean(self) -> f64 { 55 | let mut i = 0.0; 56 | let mut mean = 0.0; 57 | for x in self { 58 | i += 1.0; 59 | mean += (x.borrow() - mean) / i; 60 | } 61 | if i > 0.0 { mean } else { f64::NAN } 62 | } 63 | 64 | fn geometric_mean(self) -> f64 { 65 | let mut i = 0.0; 66 | let mut sum = 0.0; 67 | for x in self { 68 | i += 1.0; 69 | sum += x.borrow().ln(); 70 | } 71 | if i > 0.0 { (sum / i).exp() } else { f64::NAN } 72 | } 73 | 74 | fn harmonic_mean(self) -> f64 { 75 | let mut i = 0.0; 76 | let mut sum = 0.0; 77 | for x in self { 78 | i += 1.0; 79 | 80 | let borrow = *x.borrow(); 81 | if borrow < 0f64 { 82 | return f64::NAN; 83 | } 84 | sum += 1.0 / borrow; 85 | } 86 | if i > 0.0 { i / sum } else { f64::NAN } 87 | } 88 | 89 | fn variance(self) -> f64 { 90 | let mut iter = self.into_iter(); 91 | let mut sum = match iter.next() { 92 | None => f64::NAN, 93 | Some(x) => *x.borrow(), 94 | }; 95 | let mut i = 1.0; 96 | let mut variance = 0.0; 97 | 98 | for x in iter { 99 | i += 1.0; 100 | let borrow = *x.borrow(); 101 | sum += borrow; 102 | let diff = i * borrow - sum; 103 | variance += diff * diff / (i * (i - 1.0)) 104 | } 105 | if i > 1.0 { 106 | variance / (i - 1.0) 107 | } else { 108 | f64::NAN 109 | } 110 | } 111 | 112 | fn std_dev(self) -> f64 { 113 | self.variance().sqrt() 114 | } 115 | 116 | fn population_variance(self) -> f64 { 117 | let mut iter = self.into_iter(); 118 | let mut sum = match iter.next() { 119 | None => return f64::NAN, 120 | Some(x) => *x.borrow(), 121 | }; 122 | let mut i = 1.0; 123 | let mut variance = 0.0; 124 | 125 | for x in iter { 126 | i += 1.0; 127 | let borrow = *x.borrow(); 128 | sum += borrow; 129 | let diff = i * borrow - sum; 130 | variance += diff * diff / (i * (i - 1.0)); 131 | } 132 | variance / i 133 | } 134 | 135 | fn population_std_dev(self) -> f64 { 136 | self.population_variance().sqrt() 137 | } 138 | 139 | fn covariance(self, other: Self) -> f64 { 140 | let mut n = 0.0; 141 | let mut mean1 = 0.0; 142 | let mut mean2 = 0.0; 143 | let mut comoment = 0.0; 144 | 145 | let mut iter = other.into_iter(); 146 | for x in self { 147 | let borrow = *x.borrow(); 148 | let borrow2 = match iter.next() { 149 | None => 
panic!("Iterators must have the same length"), 150 | Some(x) => *x.borrow(), 151 | }; 152 | let old_mean2 = mean2; 153 | n += 1.0; 154 | mean1 += (borrow - mean1) / n; 155 | mean2 += (borrow2 - mean2) / n; 156 | comoment += (borrow - mean1) * (borrow2 - old_mean2); 157 | } 158 | if iter.next().is_some() { 159 | panic!("Iterators must have the same length"); 160 | } 161 | 162 | if n > 1.0 { 163 | comoment / (n - 1.0) 164 | } else { 165 | f64::NAN 166 | } 167 | } 168 | 169 | fn population_covariance(self, other: Self) -> f64 { 170 | let mut n = 0.0; 171 | let mut mean1 = 0.0; 172 | let mut mean2 = 0.0; 173 | let mut comoment = 0.0; 174 | 175 | let mut iter = other.into_iter(); 176 | for x in self { 177 | let borrow = *x.borrow(); 178 | let borrow2 = match iter.next() { 179 | None => panic!("Iterators must have the same length"), 180 | Some(x) => *x.borrow(), 181 | }; 182 | let old_mean2 = mean2; 183 | n += 1.0; 184 | mean1 += (borrow - mean1) / n; 185 | mean2 += (borrow2 - mean2) / n; 186 | comoment += (borrow - mean1) * (borrow2 - old_mean2); 187 | } 188 | if iter.next().is_some() { 189 | panic!("Iterators must have the same length") 190 | } 191 | if n > 0.0 { comoment / n } else { f64::NAN } 192 | } 193 | 194 | fn quadratic_mean(self) -> f64 { 195 | let mut i = 0.0; 196 | let mut mean = 0.0; 197 | for x in self { 198 | let borrow = *x.borrow(); 199 | i += 1.0; 200 | mean += (borrow * borrow - mean) / i; 201 | } 202 | if i > 0.0 { mean.sqrt() } else { f64::NAN } 203 | } 204 | } 205 | 206 | #[rustfmt::skip] 207 | #[cfg(test)] 208 | mod tests { 209 | use core::f64::consts; 210 | use crate::generate::{InfinitePeriodic, InfiniteSinusoidal}; 211 | use crate::prec; 212 | use crate::statistics::Statistics; 213 | 214 | #[test] 215 | fn test_empty_data_returns_nan() { 216 | let data = [0.0; 0]; 217 | assert!(data.min().is_nan()); 218 | assert!(data.max().is_nan()); 219 | assert!(data.mean().is_nan()); 220 | assert!(data.quadratic_mean().is_nan()); 221 | assert!(data.variance().is_nan()); 222 | assert!(data.population_variance().is_nan()); 223 | } 224 | 225 | // TODO: test github issue 137 (Math.NET) 226 | 227 | #[test] 228 | fn test_large_samples() { 229 | let shorter = || InfinitePeriodic::default(4.0, 1.0).take(4*4096); 230 | let longer = || InfinitePeriodic::default(4.0, 1.0).take(4*32768); 231 | let s_mean = shorter().mean(); 232 | let s_qmean = shorter().quadratic_mean(); 233 | let l_mean = longer().mean(); 234 | let l_qmean = longer().quadratic_mean(); 235 | 236 | prec::assert_abs_diff_eq!(s_mean, 0.375, epsilon = 1e-14); 237 | prec::assert_abs_diff_eq!(l_mean, 0.375, epsilon = 1e-14); 238 | prec::assert_abs_diff_eq!(s_qmean, (0.21875f64).sqrt(), epsilon = 1e-14); 239 | prec::assert_abs_diff_eq!(l_qmean, (0.21875f64).sqrt(), epsilon = 1e-14); 240 | } 241 | 242 | #[test] 243 | fn test_quadratic_mean_of_sinusoidal() { 244 | let data = InfiniteSinusoidal::default(64.0, 16.0, 2.0).take(128); 245 | let qmean = data.quadratic_mean(); 246 | 247 | prec::assert_abs_diff_eq!(qmean, 2.0 / consts::SQRT_2, epsilon = 1e-15); 248 | } 249 | } 250 | -------------------------------------------------------------------------------- /src/distribution/dirac.rs: -------------------------------------------------------------------------------- 1 | use crate::distribution::ContinuousCDF; 2 | use crate::statistics::*; 3 | 4 | /// Implements the [Dirac Delta](https://en.wikipedia.org/wiki/Dirac_delta_function#As_a_distribution) 5 | /// distribution 6 | /// 7 | /// # Examples 8 | /// 9 | /// ``` 10 | /// use 
statrs::distribution::{Dirac, Continuous}; 11 | /// use statrs::statistics::Distribution; 12 | /// 13 | /// let n = Dirac::new(3.0).unwrap(); 14 | /// assert_eq!(n.mean().unwrap(), 3.0); 15 | /// ``` 16 | #[derive(Debug, Copy, Clone, PartialEq)] 17 | pub struct Dirac(f64); 18 | 19 | /// Represents the errors that can occur when creating a [`Dirac`]. 20 | #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] 21 | #[non_exhaustive] 22 | pub enum DiracError { 23 | /// The value v is NaN. 24 | ValueInvalid, 25 | } 26 | 27 | impl core::fmt::Display for DiracError { 28 | #[cfg_attr(coverage_nightly, coverage(off))] 29 | fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { 30 | match self { 31 | DiracError::ValueInvalid => write!(f, "Value v is NaN"), 32 | } 33 | } 34 | } 35 | 36 | #[cfg(feature = "std")] 37 | impl std::error::Error for DiracError {} 38 | 39 | impl Dirac { 40 | /// Constructs a new dirac distribution function at value `v`. 41 | /// 42 | /// # Errors 43 | /// 44 | /// Returns an error if `v` is not-a-number. 45 | /// 46 | /// # Examples 47 | /// 48 | /// ``` 49 | /// use statrs::distribution::Dirac; 50 | /// 51 | /// let mut result = Dirac::new(0.0); 52 | /// assert!(result.is_ok()); 53 | /// 54 | /// result = Dirac::new(f64::NAN); 55 | /// assert!(result.is_err()); 56 | /// ``` 57 | pub fn new(v: f64) -> Result { 58 | if v.is_nan() { 59 | Err(DiracError::ValueInvalid) 60 | } else { 61 | Ok(Dirac(v)) 62 | } 63 | } 64 | 65 | /// Returns the value `v` of the dirac distribution 66 | /// 67 | /// # Examples 68 | /// 69 | /// ``` 70 | /// use statrs::distribution::Dirac; 71 | /// 72 | /// let n = Dirac::new(3.0).unwrap(); 73 | /// assert_eq!(n.v(), 3.0); 74 | /// ``` 75 | pub fn v(&self) -> f64 { 76 | self.0 77 | } 78 | } 79 | 80 | impl core::fmt::Display for Dirac { 81 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 82 | write!(f, "δ_{}", self.0) 83 | } 84 | } 85 | 86 | #[cfg(feature = "rand")] 87 | #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] 88 | impl ::rand::distr::Distribution for Dirac { 89 | fn sample(&self, _: &mut R) -> f64 { 90 | self.0 91 | } 92 | } 93 | 94 | impl ContinuousCDF for Dirac { 95 | /// Calculates the cumulative distribution function for the 96 | /// dirac distribution at `x` 97 | /// 98 | /// Where the value is 1 if x > `v`, 0 otherwise. 99 | fn cdf(&self, x: f64) -> f64 { 100 | if x < self.0 { 0.0 } else { 1.0 } 101 | } 102 | 103 | /// Calculates the survival function for the 104 | /// dirac distribution at `x` 105 | /// 106 | /// Where the value is 0 if x > `v`, 1 otherwise. 107 | fn sf(&self, x: f64) -> f64 { 108 | if x < self.0 { 1.0 } else { 0.0 } 109 | } 110 | } 111 | 112 | impl Min for Dirac { 113 | /// Returns the minimum value in the domain of the 114 | /// dirac distribution representable by a double precision float 115 | /// 116 | /// # Formula 117 | /// 118 | /// ```text 119 | /// v 120 | /// ``` 121 | fn min(&self) -> f64 { 122 | self.0 123 | } 124 | } 125 | 126 | impl Max for Dirac { 127 | /// Returns the maximum value in the domain of the 128 | /// dirac distribution representable by a double precision float 129 | /// 130 | /// # Formula 131 | /// 132 | /// ```text 133 | /// v 134 | /// ``` 135 | fn max(&self) -> f64 { 136 | self.0 137 | } 138 | } 139 | 140 | impl Distribution for Dirac { 141 | /// Returns the mean of the dirac distribution 142 | /// 143 | /// # Remarks 144 | /// 145 | /// Since the only value that can be produced by this distribution is `v` with probability 146 | /// 1, it is just `v`. 
147 | fn mean(&self) -> Option { 148 | Some(self.0) 149 | } 150 | 151 | /// Returns the variance of the dirac distribution 152 | /// 153 | /// # Formula 154 | /// 155 | /// ```text 156 | /// 0 157 | /// ``` 158 | /// 159 | /// Since only one value can be produced there is no variance. 160 | fn variance(&self) -> Option { 161 | Some(0.0) 162 | } 163 | 164 | /// Returns the entropy of the dirac distribution 165 | /// 166 | /// # Formula 167 | /// 168 | /// ```text 169 | /// 0 170 | /// ``` 171 | /// 172 | /// Since this distribution has full certainty, it encodes no information 173 | fn entropy(&self) -> Option { 174 | Some(0.0) 175 | } 176 | 177 | /// Returns the skewness of the dirac distribution 178 | /// 179 | /// # Formula 180 | /// 181 | /// ```text 182 | /// 0 183 | /// ``` 184 | fn skewness(&self) -> Option { 185 | Some(0.0) 186 | } 187 | } 188 | 189 | impl Median for Dirac { 190 | /// Returns the median of the dirac distribution 191 | /// 192 | /// # Formula 193 | /// 194 | /// ```text 195 | /// v 196 | /// ``` 197 | /// 198 | /// where `v` is the point of the dirac distribution 199 | fn median(&self) -> f64 { 200 | self.0 201 | } 202 | } 203 | 204 | impl Mode> for Dirac { 205 | /// Returns the mode of the dirac distribution 206 | /// 207 | /// # Formula 208 | /// 209 | /// ```text 210 | /// v 211 | /// ``` 212 | /// 213 | /// where `v` is the point of the dirac distribution 214 | fn mode(&self) -> Option { 215 | Some(self.0) 216 | } 217 | } 218 | 219 | #[rustfmt::skip] 220 | #[cfg(test)] 221 | mod tests { 222 | use super::*; 223 | use crate::distribution::internal::testing_boiler; 224 | 225 | testing_boiler!(v: f64; Dirac; DiracError); 226 | 227 | #[test] 228 | fn test_create() { 229 | create_ok(10.0); 230 | create_ok(-5.0); 231 | create_ok(10.0); 232 | create_ok(100.0); 233 | create_ok(f64::INFINITY); 234 | } 235 | 236 | #[test] 237 | fn test_bad_create() { 238 | create_err(f64::NAN); 239 | } 240 | 241 | #[test] 242 | fn test_variance() { 243 | let variance = |x: Dirac| x.variance().unwrap(); 244 | test_exact(0.0, 0.0, variance); 245 | test_exact(-5.0, 0.0, variance); 246 | test_exact(f64::INFINITY, 0.0, variance); 247 | } 248 | 249 | #[test] 250 | fn test_entropy() { 251 | let entropy = |x: Dirac| x.entropy().unwrap(); 252 | test_exact(0.0, 0.0, entropy); 253 | test_exact(f64::INFINITY, 0.0, entropy); 254 | } 255 | 256 | #[test] 257 | fn test_skewness() { 258 | let skewness = |x: Dirac| x.skewness().unwrap(); 259 | test_exact(0.0, 0.0, skewness); 260 | test_exact(4.0, 0.0, skewness); 261 | test_exact(0.3, 0.0, skewness); 262 | test_exact(f64::INFINITY, 0.0, skewness); 263 | } 264 | 265 | #[test] 266 | fn test_mode() { 267 | let mode = |x: Dirac| x.mode().unwrap(); 268 | test_exact(0.0, 0.0, mode); 269 | test_exact(3.0, 3.0, mode); 270 | test_exact(f64::INFINITY, f64::INFINITY, mode); 271 | } 272 | 273 | #[test] 274 | fn test_median() { 275 | let median = |x: Dirac| x.median(); 276 | test_exact(0.0, 0.0, median); 277 | test_exact(3.0, 3.0, median); 278 | test_exact(f64::INFINITY, f64::INFINITY, median); 279 | } 280 | 281 | #[test] 282 | fn test_min_max() { 283 | let min = |x: Dirac| x.min(); 284 | let max = |x: Dirac| x.max(); 285 | test_exact(0.0, 0.0, min); 286 | test_exact(3.0, 3.0, min); 287 | test_exact(f64::INFINITY, f64::INFINITY, min); 288 | 289 | test_exact(0.0, 0.0, max); 290 | test_exact(3.0, 3.0, max); 291 | test_exact(f64::NEG_INFINITY, f64::NEG_INFINITY, max); 292 | } 293 | 294 | #[test] 295 | fn test_cdf() { 296 | let cdf = |arg: f64| move |x: Dirac| x.cdf(arg); 297 
| test_exact(0.0, 1.0, cdf(0.0)); 298 | test_exact(3.0, 1.0, cdf(3.0)); 299 | test_exact(f64::INFINITY, 0.0, cdf(1.0)); 300 | test_exact(f64::INFINITY, 1.0, cdf(f64::INFINITY)); 301 | } 302 | 303 | #[test] 304 | fn test_sf() { 305 | let sf = |arg: f64| move |x: Dirac| x.sf(arg); 306 | test_exact(0.0, 0.0, sf(0.0)); 307 | test_exact(3.0, 0.0, sf(3.0)); 308 | test_exact(f64::INFINITY, 1.0, sf(1.0)); 309 | test_exact(f64::INFINITY, 0.0, sf(f64::INFINITY)); 310 | } 311 | } 312 | -------------------------------------------------------------------------------- /src/distribution/bernoulli.rs: -------------------------------------------------------------------------------- 1 | use crate::distribution::{Binomial, BinomialError, Discrete, DiscreteCDF}; 2 | use crate::statistics::*; 3 | 4 | /// Implements the 5 | /// [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution) 6 | /// distribution which is a special case of the 7 | /// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution) 8 | /// distribution where `n = 1` (referenced [Here](./struct.Binomial.html)) 9 | /// 10 | /// # Examples 11 | /// 12 | /// ``` 13 | /// use statrs::distribution::{Bernoulli, Discrete}; 14 | /// use statrs::statistics::Distribution; 15 | /// 16 | /// let n = Bernoulli::new(0.5).unwrap(); 17 | /// assert_eq!(n.mean().unwrap(), 0.5); 18 | /// assert_eq!(n.pmf(0), 0.5); 19 | /// assert_eq!(n.pmf(1), 0.5); 20 | /// ``` 21 | #[derive(Copy, Clone, PartialEq, Debug)] 22 | pub struct Bernoulli { 23 | b: Binomial, 24 | } 25 | 26 | impl Bernoulli { 27 | /// Constructs a new bernoulli distribution with 28 | /// the given `p` probability of success. 29 | /// 30 | /// # Errors 31 | /// 32 | /// Returns an error if `p` is `NaN`, less than `0.0` 33 | /// or greater than `1.0` 34 | /// 35 | /// # Examples 36 | /// 37 | /// ``` 38 | /// use statrs::distribution::Bernoulli; 39 | /// 40 | /// let mut result = Bernoulli::new(0.5); 41 | /// assert!(result.is_ok()); 42 | /// 43 | /// result = Bernoulli::new(-0.5); 44 | /// assert!(result.is_err()); 45 | /// ``` 46 | pub fn new(p: f64) -> Result { 47 | Binomial::new(p, 1).map(|b| Bernoulli { b }) 48 | } 49 | 50 | /// Returns the probability of success `p` of the 51 | /// bernoulli distribution. 52 | /// 53 | /// # Examples 54 | /// 55 | /// ``` 56 | /// use statrs::distribution::Bernoulli; 57 | /// 58 | /// let n = Bernoulli::new(0.5).unwrap(); 59 | /// assert_eq!(n.p(), 0.5); 60 | /// ``` 61 | pub fn p(&self) -> f64 { 62 | self.b.p() 63 | } 64 | 65 | /// Returns the number of trials `n` of the 66 | /// bernoulli distribution. Will always be `1.0`. 
67 | /// 68 | /// # Examples 69 | /// 70 | /// ``` 71 | /// use statrs::distribution::Bernoulli; 72 | /// 73 | /// let n = Bernoulli::new(0.5).unwrap(); 74 | /// assert_eq!(n.n(), 1); 75 | /// ``` 76 | pub fn n(&self) -> u64 { 77 | 1 78 | } 79 | } 80 | 81 | impl core::fmt::Display for Bernoulli { 82 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 83 | write!(f, "Bernoulli({})", self.p()) 84 | } 85 | } 86 | 87 | #[cfg(feature = "rand")] 88 | #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] 89 | impl ::rand::distr::Distribution for Bernoulli { 90 | fn sample(&self, rng: &mut R) -> bool { 91 | rng.random_bool(self.p()) 92 | } 93 | } 94 | 95 | #[cfg(feature = "rand")] 96 | #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] 97 | impl ::rand::distr::Distribution for Bernoulli { 98 | fn sample(&self, rng: &mut R) -> f64 { 99 | rng.sample::(self) as u8 as f64 100 | } 101 | } 102 | 103 | impl DiscreteCDF for Bernoulli { 104 | /// Calculates the cumulative distribution 105 | /// function for the bernoulli distribution at `x`. 106 | /// 107 | /// # Formula 108 | /// 109 | /// ```text 110 | /// if x >= 1 { 1 } 111 | /// else { 1 - p } 112 | /// ``` 113 | fn cdf(&self, x: u64) -> f64 { 114 | if x >= 1 { 1. } else { 1. - self.b.p() } 115 | } 116 | 117 | /// Calculates the survival function for the 118 | /// bernoulli distribution at `x`. 119 | /// 120 | /// # Formula 121 | /// 122 | /// ```text 123 | /// if x < 0 { 1 } 124 | /// else if x >= 1 { 0 } 125 | /// else { p } 126 | /// ``` 127 | fn sf(&self, x: u64) -> f64 { 128 | self.b.sf(x) 129 | } 130 | } 131 | 132 | impl Min for Bernoulli { 133 | /// Returns the minimum value in the domain of the 134 | /// bernoulli distribution representable by a 64- 135 | /// bit integer 136 | /// 137 | /// # Formula 138 | /// 139 | /// ```text 140 | /// 0 141 | /// ``` 142 | fn min(&self) -> u64 { 143 | 0 144 | } 145 | } 146 | 147 | impl Max for Bernoulli { 148 | /// Returns the maximum value in the domain of the 149 | /// bernoulli distribution representable by a 64- 150 | /// bit integer 151 | /// 152 | /// # Formula 153 | /// 154 | /// ```text 155 | /// 1 156 | /// ``` 157 | fn max(&self) -> u64 { 158 | 1 159 | } 160 | } 161 | 162 | impl Distribution for Bernoulli { 163 | /// Returns the mean of the bernoulli 164 | /// distribution 165 | /// 166 | /// # Formula 167 | /// 168 | /// ```text 169 | /// p 170 | /// ``` 171 | fn mean(&self) -> Option { 172 | self.b.mean() 173 | } 174 | 175 | /// Returns the variance of the bernoulli 176 | /// distribution 177 | /// 178 | /// # Formula 179 | /// 180 | /// ```text 181 | /// p * (1 - p) 182 | /// ``` 183 | fn variance(&self) -> Option { 184 | self.b.variance() 185 | } 186 | 187 | /// Returns the entropy of the bernoulli 188 | /// distribution 189 | /// 190 | /// # Formula 191 | /// 192 | /// ```text 193 | /// q = (1 - p) 194 | /// -q * ln(q) - p * ln(p) 195 | /// ``` 196 | fn entropy(&self) -> Option { 197 | self.b.entropy() 198 | } 199 | 200 | /// Returns the skewness of the bernoulli 201 | /// distribution 202 | /// 203 | /// # Formula 204 | /// 205 | /// ```text 206 | /// q = (1 - p) 207 | /// (1 - 2p) / sqrt(p * q) 208 | /// ``` 209 | fn skewness(&self) -> Option { 210 | self.b.skewness() 211 | } 212 | } 213 | 214 | impl Median for Bernoulli { 215 | /// Returns the median of the bernoulli 216 | /// distribution 217 | /// 218 | /// # Formula 219 | /// 220 | /// ```text 221 | /// if p < 0.5 { 0 } 222 | /// else if p > 0.5 { 1 } 223 | /// else { 0.5 } 224 | /// ``` 225 | fn median(&self) -> f64 { 226 | 
self.b.median() 227 | } 228 | } 229 | 230 | impl Mode> for Bernoulli { 231 | /// Returns the mode of the bernoulli distribution 232 | /// 233 | /// # Formula 234 | /// 235 | /// ```text 236 | /// if p < 0.5 { 0 } 237 | /// else { 1 } 238 | /// ``` 239 | fn mode(&self) -> Option { 240 | self.b.mode() 241 | } 242 | } 243 | 244 | impl Discrete for Bernoulli { 245 | /// Calculates the probability mass function for the 246 | /// bernoulli distribution at `x`. 247 | /// 248 | /// # Formula 249 | /// 250 | /// ```text 251 | /// if x == 0 { 1 - p } 252 | /// else { p } 253 | /// ``` 254 | fn pmf(&self, x: u64) -> f64 { 255 | self.b.pmf(x) 256 | } 257 | 258 | /// Calculates the log probability mass function for the 259 | /// bernoulli distribution at `x`. 260 | /// 261 | /// # Formula 262 | /// 263 | /// ```text 264 | /// else if x == 0 { ln(1 - p) } 265 | /// else { ln(p) } 266 | /// ``` 267 | fn ln_pmf(&self, x: u64) -> f64 { 268 | self.b.ln_pmf(x) 269 | } 270 | } 271 | 272 | #[rustfmt::skip] 273 | #[cfg(test)] 274 | mod test { 275 | use super::*; 276 | use crate::distribution::internal::testing_boiler; 277 | 278 | testing_boiler!(p: f64; Bernoulli; BinomialError); 279 | 280 | #[test] 281 | fn test_create() { 282 | create_ok(0.0); 283 | create_ok(0.3); 284 | create_ok(1.0); 285 | } 286 | 287 | #[test] 288 | fn test_bad_create() { 289 | create_err(f64::NAN); 290 | create_err(-1.0); 291 | create_err(2.0); 292 | } 293 | 294 | #[test] 295 | fn test_cdf_upper_bound() { 296 | let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg); 297 | test_relative(0.3, 1., cdf(1)); 298 | } 299 | 300 | #[test] 301 | fn test_sf_upper_bound() { 302 | let sf = |arg: u64| move |x: Bernoulli| x.sf(arg); 303 | test_relative(0.3, 0., sf(1)); 304 | } 305 | 306 | #[test] 307 | fn test_cdf() { 308 | let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg); 309 | test_relative(0.0, 1.0, cdf(0)); 310 | test_relative(0.0, 1.0, cdf(1)); 311 | test_absolute(0.3, 0.7, 1e-15, cdf(0)); 312 | test_absolute(0.7, 0.3, 1e-15, cdf(0)); 313 | } 314 | 315 | #[test] 316 | fn test_sf() { 317 | let sf = |arg: u64| move |x: Bernoulli| x.sf(arg); 318 | test_relative(0.0, 0.0, sf(0)); 319 | test_relative(0.0, 0.0, sf(1)); 320 | test_absolute(0.3, 0.3, 1e-15, sf(0)); 321 | test_absolute(0.7, 0.7, 1e-15, sf(0)); 322 | } 323 | 324 | #[test] 325 | fn test_inverse_cdf() { 326 | let invcdf = |arg: f64| move |x: Bernoulli| x.inverse_cdf(arg); 327 | test_exact(0., 0, invcdf(0.)); 328 | test_exact(0., 0, invcdf(0.5)); 329 | test_exact(1., 0, invcdf(0.)); 330 | test_exact(1., 1, invcdf(1.)); 331 | test_exact(1., 1, invcdf(1e-6)); 332 | test_exact(0.5, 0, invcdf(0.25)); 333 | test_exact(0.5, 0, invcdf(0.5)); 334 | } 335 | } 336 | -------------------------------------------------------------------------------- /src/distribution/erlang.rs: -------------------------------------------------------------------------------- 1 | use crate::distribution::{Continuous, ContinuousCDF, Gamma, GammaError}; 2 | use crate::statistics::*; 3 | 4 | /// Implements the [Erlang](https://en.wikipedia.org/wiki/Erlang_distribution) 5 | /// distribution 6 | /// which is a special case of the 7 | /// [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) 8 | /// distribution 9 | /// 10 | /// # Examples 11 | /// 12 | /// ``` 13 | /// use statrs::distribution::{Erlang, Continuous}; 14 | /// use statrs::statistics::Distribution; 15 | /// use approx::assert_abs_diff_eq; 16 | /// 17 | /// let n = Erlang::new(3, 1.0).unwrap(); 18 | /// assert_eq!(n.mean().unwrap(), 3.0); 19 | /// 
assert_abs_diff_eq!(n.pdf(2.0), 0.270670566473225383788, epsilon = 1e-15); 20 | /// ``` 21 | #[derive(Copy, Clone, PartialEq, Debug)] 22 | pub struct Erlang { 23 | g: Gamma, 24 | } 25 | 26 | impl Erlang { 27 | /// Constructs a new erlang distribution with a shape (k) 28 | /// of `shape` and a rate (λ) of `rate` 29 | /// 30 | /// # Errors 31 | /// 32 | /// Returns an error if `shape` or `rate` are `NaN`. 33 | /// Also returns an error if `shape == 0` or `rate <= 0.0` 34 | /// 35 | /// # Examples 36 | /// 37 | /// ``` 38 | /// use statrs::distribution::Erlang; 39 | /// 40 | /// let mut result = Erlang::new(3, 1.0); 41 | /// assert!(result.is_ok()); 42 | /// 43 | /// result = Erlang::new(0, 0.0); 44 | /// assert!(result.is_err()); 45 | /// ``` 46 | pub fn new(shape: u64, rate: f64) -> Result { 47 | Gamma::new(shape as f64, rate).map(|g| Erlang { g }) 48 | } 49 | 50 | /// Returns the shape (k) of the erlang distribution 51 | /// 52 | /// # Examples 53 | /// 54 | /// ``` 55 | /// use statrs::distribution::Erlang; 56 | /// 57 | /// let n = Erlang::new(3, 1.0).unwrap(); 58 | /// assert_eq!(n.shape(), 3); 59 | /// ``` 60 | pub fn shape(&self) -> u64 { 61 | self.g.shape() as u64 62 | } 63 | 64 | /// Returns the rate (λ) of the erlang distribution 65 | /// 66 | /// # Examples 67 | /// 68 | /// ``` 69 | /// use statrs::distribution::Erlang; 70 | /// 71 | /// let n = Erlang::new(3, 1.0).unwrap(); 72 | /// assert_eq!(n.rate(), 1.0); 73 | /// ``` 74 | pub fn rate(&self) -> f64 { 75 | self.g.rate() 76 | } 77 | } 78 | 79 | impl core::fmt::Display for Erlang { 80 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 81 | write!(f, "E({}, {})", self.rate(), self.shape()) 82 | } 83 | } 84 | 85 | #[cfg(feature = "rand")] 86 | #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] 87 | impl ::rand::distr::Distribution for Erlang { 88 | fn sample(&self, rng: &mut R) -> f64 { 89 | ::rand::distr::Distribution::sample(&self.g, rng) 90 | } 91 | } 92 | 93 | impl ContinuousCDF for Erlang { 94 | /// Calculates the cumulative distribution function for the erlang 95 | /// distribution 96 | /// at `x` 97 | /// 98 | /// # Formula 99 | /// 100 | /// ```text 101 | /// γ(k, λx) (k - 1)! 102 | /// ``` 103 | /// 104 | /// where `k` is the shape, `λ` is the rate, and `γ` is the lower 105 | /// incomplete gamma function 106 | fn cdf(&self, x: f64) -> f64 { 107 | self.g.cdf(x) 108 | } 109 | 110 | /// Calculates the cumulative distribution function for the erlang 111 | /// distribution 112 | /// at `x` 113 | /// 114 | /// # Formula 115 | /// 116 | /// ```text 117 | /// γ(k, λx) (k - 1)! 118 | /// ``` 119 | /// 120 | /// where `k` is the shape, `λ` is the rate, and `γ` is the upper 121 | /// incomplete gamma function 122 | fn sf(&self, x: f64) -> f64 { 123 | self.g.sf(x) 124 | } 125 | 126 | /// Calculates the inverse cumulative distribution function for the erlang 127 | /// distribution at `x` 128 | /// 129 | /// # Formula 130 | /// 131 | /// ```text 132 | /// γ^{-1}(k, (k - 1)! 
x) / λ 133 | /// ``` 134 | /// 135 | /// where `k` is the shape, `λ` is the rate, and `γ` is the upper 136 | /// incomplete gamma function 137 | fn inverse_cdf(&self, p: f64) -> f64 { 138 | self.g.inverse_cdf(p) 139 | } 140 | } 141 | 142 | impl Min for Erlang { 143 | /// Returns the minimum value in the domain of the 144 | /// erlang distribution representable by a double precision 145 | /// float 146 | /// 147 | /// # Formula 148 | /// 149 | /// ```text 150 | /// 0 151 | /// ``` 152 | fn min(&self) -> f64 { 153 | self.g.min() 154 | } 155 | } 156 | 157 | impl Max for Erlang { 158 | /// Returns the maximum value in the domain of the 159 | /// erlang distribution representable by a double precision 160 | /// float 161 | /// 162 | /// # Formula 163 | /// 164 | /// ```text 165 | /// f64::INFINITY 166 | /// ``` 167 | fn max(&self) -> f64 { 168 | self.g.max() 169 | } 170 | } 171 | 172 | impl Distribution for Erlang { 173 | /// Returns the mean of the erlang distribution 174 | /// 175 | /// # Remarks 176 | /// 177 | /// Returns `shape` if `rate == f64::INFINITY`. This behavior 178 | /// is borrowed from the Math.NET implementation 179 | /// 180 | /// # Formula 181 | /// 182 | /// ```text 183 | /// k / λ 184 | /// ``` 185 | /// 186 | /// where `k` is the shape and `λ` is the rate 187 | fn mean(&self) -> Option { 188 | self.g.mean() 189 | } 190 | 191 | /// Returns the variance of the erlang distribution 192 | /// 193 | /// # Formula 194 | /// 195 | /// ```text 196 | /// k / λ^2 197 | /// ``` 198 | /// 199 | /// where `α` is the shape and `λ` is the rate 200 | fn variance(&self) -> Option { 201 | self.g.variance() 202 | } 203 | 204 | /// Returns the entropy of the erlang distribution 205 | /// 206 | /// # Formula 207 | /// 208 | /// ```text 209 | /// k - ln(λ) + ln(Γ(k)) + (1 - k) * ψ(k) 210 | /// ``` 211 | /// 212 | /// where `k` is the shape, `λ` is the rate, `Γ` is the gamma function, 213 | /// and `ψ` is the digamma function 214 | fn entropy(&self) -> Option { 215 | self.g.entropy() 216 | } 217 | 218 | /// Returns the skewness of the erlang distribution 219 | /// 220 | /// # Formula 221 | /// 222 | /// ```text 223 | /// 2 / sqrt(k) 224 | /// ``` 225 | /// 226 | /// where `k` is the shape 227 | fn skewness(&self) -> Option { 228 | self.g.skewness() 229 | } 230 | } 231 | 232 | impl Mode> for Erlang { 233 | /// Returns the mode for the erlang distribution 234 | /// 235 | /// # Remarks 236 | /// 237 | /// Returns `shape` if `rate ==f64::INFINITY`. 
This behavior 238 | /// is borrowed from the Math.NET implementation 239 | /// 240 | /// # Formula 241 | /// 242 | /// ```text 243 | /// (k - 1) / λ 244 | /// ``` 245 | /// 246 | /// where `k` is the shape and `λ` is the rate 247 | fn mode(&self) -> Option { 248 | self.g.mode() 249 | } 250 | } 251 | 252 | impl Continuous for Erlang { 253 | /// Calculates the probability density function for the erlang distribution 254 | /// at `x` 255 | /// 256 | /// # Remarks 257 | /// 258 | /// Returns `NAN` if any of `shape` or `rate` are `INF` 259 | /// or if `x` is `INF` 260 | /// 261 | /// # Formula 262 | /// 263 | /// ```text 264 | /// (λ^k / Γ(k)) * x^(k - 1) * e^(-λ * x) 265 | /// ``` 266 | /// 267 | /// where `k` is the shape, `λ` is the rate, and `Γ` is the gamma function 268 | fn pdf(&self, x: f64) -> f64 { 269 | self.g.pdf(x) 270 | } 271 | 272 | /// Calculates the log probability density function for the erlang 273 | /// distribution 274 | /// at `x` 275 | /// 276 | /// # Remarks 277 | /// 278 | /// Returns `NAN` if any of `shape` or `rate` are `INF` 279 | /// or if `x` is `INF` 280 | /// 281 | /// # Formula 282 | /// 283 | /// ```text 284 | /// ln((λ^k / Γ(k)) * x^(k - 1) * e ^(-λ * x)) 285 | /// ``` 286 | /// 287 | /// where `k` is the shape, `λ` is the rate, and `Γ` is the gamma function 288 | fn ln_pdf(&self, x: f64) -> f64 { 289 | self.g.ln_pdf(x) 290 | } 291 | } 292 | 293 | #[rustfmt::skip] 294 | #[cfg(test)] 295 | mod tests { 296 | use super::*; 297 | use crate::distribution::internal::density_util; 298 | use crate::distribution::internal::testing_boiler; 299 | testing_boiler!(shape: u64, rate: f64; Erlang; GammaError); 300 | 301 | #[test] 302 | fn test_create() { 303 | create_ok(1, 0.1); 304 | create_ok(1, 1.0); 305 | create_ok(10, 10.0); 306 | create_ok(10, 1.0); 307 | create_ok(10, f64::INFINITY); 308 | } 309 | 310 | #[test] 311 | fn test_bad_create() { 312 | let invalid = [ 313 | (0, 1.0, GammaError::ShapeInvalid), 314 | (1, 0.0, GammaError::RateInvalid), 315 | (1, f64::NAN, GammaError::RateInvalid), 316 | (1, -1.0, GammaError::RateInvalid), 317 | ]; 318 | 319 | for (s, r, err) in invalid { 320 | test_create_err(s, r, err); 321 | } 322 | } 323 | 324 | #[test] 325 | fn test_continuous() { 326 | density_util::check_continuous_distribution(&create_ok(1, 2.5), 0.0, 20.0); 327 | density_util::check_continuous_distribution(&create_ok(2, 1.5), 0.0, 20.0); 328 | density_util::check_continuous_distribution(&create_ok(3, 0.5), 0.0, 20.0); 329 | } 330 | } 331 | -------------------------------------------------------------------------------- /src/distribution/chi_squared.rs: -------------------------------------------------------------------------------- 1 | use crate::distribution::{Continuous, ContinuousCDF, Gamma, GammaError}; 2 | use crate::statistics::*; 3 | use core::f64; 4 | 5 | /// Implements the 6 | /// [Chi-squared](https://en.wikipedia.org/wiki/Chi-squared_distribution) 7 | /// distribution which is a special case of the 8 | /// [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) distribution 9 | /// (referenced [Here](./struct.Gamma.html)) 10 | /// 11 | /// # Examples 12 | /// 13 | /// ``` 14 | /// use statrs::distribution::{ChiSquared, Continuous}; 15 | /// use statrs::statistics::Distribution; 16 | /// use approx::assert_abs_diff_eq; 17 | /// 18 | /// let n = ChiSquared::new(3.0).unwrap(); 19 | /// assert_eq!(n.mean().unwrap(), 3.0); 20 | /// assert_abs_diff_eq!(n.pdf(4.0), 0.107981933026376103901, epsilon = 1e-15); 21 | /// ``` 22 | #[derive(Copy, Clone, PartialEq, Debug)] 
23 | pub struct ChiSquared { 24 | freedom: f64, 25 | g: Gamma, 26 | } 27 | 28 | impl ChiSquared { 29 | /// Constructs a new chi-squared distribution with `freedom` 30 | /// degrees of freedom. This is equivalent to a Gamma distribution 31 | /// with a shape of `freedom / 2.0` and a rate of `0.5`. 32 | /// 33 | /// # Errors 34 | /// 35 | /// Returns an error if `freedom` is `NaN` or less than 36 | /// or equal to `0.0` 37 | /// 38 | /// # Examples 39 | /// 40 | /// ``` 41 | /// use statrs::distribution::ChiSquared; 42 | /// 43 | /// let mut result = ChiSquared::new(3.0); 44 | /// assert!(result.is_ok()); 45 | /// 46 | /// result = ChiSquared::new(0.0); 47 | /// assert!(result.is_err()); 48 | /// ``` 49 | pub fn new(freedom: f64) -> Result { 50 | Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { freedom, g }) 51 | } 52 | 53 | /// Returns the degrees of freedom of the chi-squared 54 | /// distribution 55 | /// 56 | /// # Examples 57 | /// 58 | /// ``` 59 | /// use statrs::distribution::ChiSquared; 60 | /// 61 | /// let n = ChiSquared::new(3.0).unwrap(); 62 | /// assert_eq!(n.freedom(), 3.0); 63 | /// ``` 64 | pub fn freedom(&self) -> f64 { 65 | self.freedom 66 | } 67 | 68 | /// Returns the shape of the underlying Gamma distribution 69 | /// 70 | /// # Examples 71 | /// 72 | /// ``` 73 | /// use statrs::distribution::ChiSquared; 74 | /// 75 | /// let n = ChiSquared::new(3.0).unwrap(); 76 | /// assert_eq!(n.shape(), 3.0 / 2.0); 77 | /// ``` 78 | pub fn shape(&self) -> f64 { 79 | self.g.shape() 80 | } 81 | 82 | /// Returns the rate of the underlying Gamma distribution 83 | /// 84 | /// # Examples 85 | /// 86 | /// ``` 87 | /// use statrs::distribution::ChiSquared; 88 | /// 89 | /// let n = ChiSquared::new(3.0).unwrap(); 90 | /// assert_eq!(n.rate(), 0.5); 91 | /// ``` 92 | pub fn rate(&self) -> f64 { 93 | self.g.rate() 94 | } 95 | } 96 | 97 | impl core::fmt::Display for ChiSquared { 98 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 99 | write!(f, "χ^2_{}", self.freedom) 100 | } 101 | } 102 | 103 | #[cfg(feature = "rand")] 104 | #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] 105 | impl ::rand::distr::Distribution for ChiSquared { 106 | fn sample(&self, r: &mut R) -> f64 { 107 | ::rand::distr::Distribution::sample(&self.g, r) 108 | } 109 | } 110 | 111 | impl ContinuousCDF for ChiSquared { 112 | /// Calculates the cumulative distribution function for the 113 | /// chi-squared distribution at `x` 114 | /// 115 | /// # Formula 116 | /// 117 | /// ```text 118 | /// (1 / Γ(k / 2)) * γ(k / 2, x / 2) 119 | /// ``` 120 | /// 121 | /// where `k` is the degrees of freedom, `Γ` is the gamma function, 122 | /// and `γ` is the lower incomplete gamma function 123 | fn cdf(&self, x: f64) -> f64 { 124 | self.g.cdf(x) 125 | } 126 | 127 | /// Calculates the cumulative distribution function for the 128 | /// chi-squared distribution at `x` 129 | /// 130 | /// # Formula 131 | /// 132 | /// ```text 133 | /// (1 / Γ(k / 2)) * γ(k / 2, x / 2) 134 | /// ``` 135 | /// 136 | /// where `k` is the degrees of freedom, `Γ` is the gamma function, 137 | /// and `γ` is the upper incomplete gamma function 138 | fn sf(&self, x: f64) -> f64 { 139 | self.g.sf(x) 140 | } 141 | 142 | /// Calculates the inverse cumulative distribution function for the 143 | /// chi-squared distribution at `x` 144 | /// 145 | /// # Formula 146 | /// 147 | /// ```text 148 | /// γ^{-1}(k / 2, x * Γ(k / 2) / 2) 149 | /// ``` 150 | /// 151 | /// where `k` is the degrees of freedom, `Γ` is the gamma function, 152 | /// and `γ` is 
the lower incomplete gamma function 153 | fn inverse_cdf(&self, p: f64) -> f64 { 154 | self.g.inverse_cdf(p) 155 | } 156 | } 157 | 158 | impl Min for ChiSquared { 159 | /// Returns the minimum value in the domain of the 160 | /// chi-squared distribution representable by a double precision 161 | /// float 162 | /// 163 | /// # Formula 164 | /// 165 | /// ```text 166 | /// 0 167 | /// ``` 168 | fn min(&self) -> f64 { 169 | 0.0 170 | } 171 | } 172 | 173 | impl Max for ChiSquared { 174 | /// Returns the maximum value in the domain of the 175 | /// chi-squared distribution representable by a double precision 176 | /// float 177 | /// 178 | /// # Formula 179 | /// 180 | /// ```text 181 | /// f64::INFINITY 182 | /// ``` 183 | fn max(&self) -> f64 { 184 | f64::INFINITY 185 | } 186 | } 187 | 188 | impl Distribution for ChiSquared { 189 | /// Returns the mean of the chi-squared distribution 190 | /// 191 | /// # Formula 192 | /// 193 | /// ```text 194 | /// k 195 | /// ``` 196 | /// 197 | /// where `k` is the degrees of freedom 198 | fn mean(&self) -> Option { 199 | self.g.mean() 200 | } 201 | 202 | /// Returns the variance of the chi-squared distribution 203 | /// 204 | /// # Formula 205 | /// 206 | /// ```text 207 | /// 2k 208 | /// ``` 209 | /// 210 | /// where `k` is the degrees of freedom 211 | fn variance(&self) -> Option { 212 | self.g.variance() 213 | } 214 | 215 | /// Returns the entropy of the chi-squared distribution 216 | /// 217 | /// # Formula 218 | /// 219 | /// ```text 220 | /// (k / 2) + ln(2 * Γ(k / 2)) + (1 - (k / 2)) * ψ(k / 2) 221 | /// ``` 222 | /// 223 | /// where `k` is the degrees of freedom, `Γ` is the gamma function, 224 | /// and `ψ` is the digamma function 225 | fn entropy(&self) -> Option { 226 | self.g.entropy() 227 | } 228 | 229 | /// Returns the skewness of the chi-squared distribution 230 | /// 231 | /// # Formula 232 | /// 233 | /// ```text 234 | /// sqrt(8 / k) 235 | /// ``` 236 | /// 237 | /// where `k` is the degrees of freedom 238 | fn skewness(&self) -> Option { 239 | self.g.skewness() 240 | } 241 | } 242 | 243 | impl Median for ChiSquared { 244 | /// Returns the median of the chi-squared distribution 245 | /// 246 | /// # Formula 247 | /// 248 | /// ```text 249 | /// k * (1 - (2 / 9k))^3 250 | /// ``` 251 | fn median(&self) -> f64 { 252 | if self.freedom < 1.0 { 253 | // if k is small, calculate using expansion of formula 254 | self.freedom - 2.0 / 3.0 + 12.0 / (81.0 * self.freedom) 255 | - 8.0 / (729.0 * self.freedom * self.freedom) 256 | } else { 257 | // if k is large enough, median heads toward k - 2/3 258 | self.freedom - 2.0 / 3.0 259 | } 260 | } 261 | } 262 | 263 | impl Mode> for ChiSquared { 264 | /// Returns the mode of the chi-squared distribution 265 | /// 266 | /// # Formula 267 | /// 268 | /// ```text 269 | /// k - 2 270 | /// ``` 271 | /// 272 | /// where `k` is the degrees of freedom 273 | fn mode(&self) -> Option { 274 | self.g.mode() 275 | } 276 | } 277 | 278 | impl Continuous for ChiSquared { 279 | /// Calculates the probability density function for the chi-squared 280 | /// distribution at `x` 281 | /// 282 | /// # Formula 283 | /// 284 | /// ```text 285 | /// 1 / (2^(k / 2) * Γ(k / 2)) * x^((k / 2) - 1) * e^(-x / 2) 286 | /// ``` 287 | /// 288 | /// where `k` is the degrees of freedom and `Γ` is the gamma function 289 | fn pdf(&self, x: f64) -> f64 { 290 | self.g.pdf(x) 291 | } 292 | 293 | /// Calculates the log probability density function for the chi-squared 294 | /// distribution at `x` 295 | /// 296 | /// # Formula 297 | /// 298 | /// 
```text 299 | /// ln(1 / (2^(k / 2) * Γ(k / 2)) * x^((k / 2) - 1) * e^(-x / 2)) 300 | /// ``` 301 | fn ln_pdf(&self, x: f64) -> f64 { 302 | self.g.ln_pdf(x) 303 | } 304 | } 305 | 306 | #[rustfmt::skip] 307 | #[cfg(test)] 308 | mod tests { 309 | use super::*; 310 | use crate::distribution::internal::density_util; 311 | crate::distribution::internal::testing_boiler!(freedom: f64; ChiSquared; GammaError); 312 | 313 | #[test] 314 | fn test_median() { 315 | let median = |x: ChiSquared| x.median(); 316 | test_absolute(0.5, 0.0857338820301783264746, 1e-16, median); 317 | test_exact(1.0, 1.0 - 2.0 / 3.0, median); 318 | test_exact(2.0, 2.0 - 2.0 / 3.0, median); 319 | test_exact(2.5, 2.5 - 2.0 / 3.0, median); 320 | test_exact(3.0, 3.0 - 2.0 / 3.0, median); 321 | } 322 | 323 | #[test] 324 | fn test_continuous() { 325 | // TODO: figure out why this test fails: 326 | //check_continuous_distribution(&create_ok(1.0), 0.0, 10.0); 327 | density_util::check_continuous_distribution(&create_ok(2.0), 0.0, 10.0); 328 | density_util::check_continuous_distribution(&create_ok(5.0), 0.0, 50.0); 329 | } 330 | } 331 | -------------------------------------------------------------------------------- /src/stats_tests/skewtest.rs: -------------------------------------------------------------------------------- 1 | //! Provides the [skewtest](https://docs.scipy.org/doc/scipy-1.15.0/reference/generated/scipy.stats.skewtest.html) 2 | //! to test whether or not provided data is different than a normal distribution 3 | 4 | use crate::distribution::{ContinuousCDF, Normal}; 5 | use crate::stats_tests::{Alternative, NaNPolicy}; 6 | 7 | /// Represents the errors that can occur when computing the skewtest function 8 | #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] 9 | #[non_exhaustive] 10 | pub enum SkewTestError { 11 | /// sample must contain at least 8 observations 12 | SampleTooSmall, 13 | /// samples can not contain NaN when `nan_policy` is set to `NaNPolicy::Error` 14 | SampleContainsNaN, 15 | } 16 | 17 | impl core::fmt::Display for SkewTestError { 18 | #[cfg_attr(coverage_nightly, coverage(off))] 19 | fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { 20 | match self { 21 | SkewTestError::SampleTooSmall => { 22 | write!(f, "sample must contain at least 8 observations") 23 | } 24 | SkewTestError::SampleContainsNaN => { 25 | write!( 26 | f, 27 | "samples can not contain NaN when nan_policy is set to NaNPolicy::Error" 28 | ) 29 | } 30 | } 31 | } 32 | } 33 | 34 | #[cfg(feature = "std")] 35 | impl std::error::Error for SkewTestError {} 36 | 37 | fn calc_root_b1(data: &[f64]) -> f64 { 38 | // Fisher's moment coefficient of skewness 39 | // https://en.wikipedia.org/wiki/Skewness#Definition 40 | let n = data.len() as f64; 41 | let mu = data.iter().sum::() / n; 42 | 43 | // NOTE: population not sample skewness 44 | (data.iter().map(|x_i| (x_i - mu).powi(3)).sum::() / n) 45 | / (data.iter().map(|x_i| (x_i - mu).powi(2)).sum::() / n).powf(1.5) 46 | } 47 | 48 | /// Perform a skewness test for whether the skew of the sample provided is different than a normal 49 | /// distribution 50 | /// 51 | /// Returns the z-score and p-value 52 | /// 53 | /// # Remarks 54 | /// 55 | /// `a` needs to be mutable in case needing to filter out NaNs for NaNPolicy::Emit 56 | /// 57 | /// Implementation based on [fintools.com](https://www.fintools.com/docs/normality_correlation.pdf) 58 | /// which indirectly uses [D'Agostino, (1970)](https://doi.org/10.2307/2684359) 59 | /// while aligning to 
[scipy's](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.skewtest.html#scipy.stats.skewtest) 60 | /// function header where possible. The scipy implementation was also used for testing and validation. 61 | /// Includes the use of [Shapiro & Wilk (1965)](https://doi.org/10.2307/2333709) for 62 | /// testing and validation. 63 | /// 64 | /// # Examples 65 | /// 66 | /// ``` 67 | /// use statrs::stats_tests::skewtest::skewtest; 68 | /// use statrs::stats_tests::{Alternative, NaNPolicy}; 69 | /// let data = Vec::from([ 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, 8.0f64, ]); 70 | /// let (statistic, pvalue) = skewtest(data, Alternative::TwoSided, NaNPolicy::Error).unwrap(); 71 | /// ``` 72 | pub fn skewtest( 73 | mut a: Vec, 74 | alternative: Alternative, 75 | nan_policy: NaNPolicy, 76 | ) -> Result<(f64, f64), SkewTestError> { 77 | let has_nans = a.iter().any(|x| x.is_nan()); 78 | if has_nans { 79 | match nan_policy { 80 | NaNPolicy::Propogate => { 81 | return Ok((f64::NAN, f64::NAN)); 82 | } 83 | NaNPolicy::Error => { 84 | return Err(SkewTestError::SampleContainsNaN); 85 | } 86 | NaNPolicy::Emit => { 87 | a = a.into_iter().filter(|x| !x.is_nan()).collect::>(); 88 | } 89 | } 90 | } 91 | 92 | let n = a.len(); 93 | if n < 8 { 94 | return Err(SkewTestError::SampleTooSmall); 95 | } 96 | let n = n as f64; 97 | 98 | let root_b1 = calc_root_b1(&a); 99 | let mut y = root_b1 * ((n + 1.0) * (n + 3.0) / (6.0 * (n - 2.0))).sqrt(); 100 | let beta2_root_b1 = 3.0 * (n.powi(2) + 27.0 * n - 70.0) * (n + 1.0) * (n + 3.0) 101 | / ((n - 2.0) * (n + 5.0) * (n + 7.0) * (n + 9.0)); 102 | let w_sq = -1.0 + (2.0 * (beta2_root_b1 - 1.0)).sqrt(); 103 | let delta = 1.0 / (0.5 * w_sq.ln()).sqrt(); 104 | let alpha = (2.0 / (w_sq - 1.0)).sqrt(); 105 | // correction from scipy version to`match scipy example results 106 | if y == 0.0 { 107 | y = 1.0; 108 | } 109 | let zscore = delta * (y / alpha + ((y / alpha).powi(2) + 1.0).sqrt()).ln(); 110 | 111 | let norm_dist = Normal::default(); 112 | 113 | let pvalue = match alternative { 114 | Alternative::TwoSided => 2.0 * (1.0 - norm_dist.cdf(zscore.abs())), 115 | Alternative::Less => norm_dist.cdf(zscore), 116 | Alternative::Greater => 1.0 - norm_dist.cdf(zscore), 117 | }; 118 | 119 | Ok((zscore, pvalue)) 120 | } 121 | 122 | #[rustfmt::skip] 123 | #[cfg(test)] 124 | mod tests { 125 | use super::*; 126 | use crate::prec; 127 | 128 | #[test] 129 | fn test_scipy_example() { 130 | let data = Vec::from([ 131 | 148.0f64, 154.0f64, 158.0f64, 160.0f64, 161.0f64, 162.0f64, 166.0f64, 170.0f64, 132 | 182.0f64, 195.0f64, 236.0f64, 133 | ]); 134 | let (statistic, pvalue) = 135 | skewtest(data.clone(), Alternative::TwoSided, NaNPolicy::Error).unwrap(); 136 | prec::assert_relative_eq!(statistic, 2.7788579769903414); 137 | prec::assert_abs_diff_eq!(pvalue, 0.005455036974740185); 138 | 139 | let (statistic, pvalue) = skewtest( 140 | Vec::from([ 141 | 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, 8.0f64, 142 | ]), 143 | Alternative::TwoSided, 144 | NaNPolicy::Error, 145 | ) 146 | .unwrap(); 147 | prec::assert_relative_eq!(statistic, 1.0108048609177787); 148 | prec::assert_abs_diff_eq!(pvalue, 0.3121098361421897); 149 | let (statistic, pvalue) = skewtest( 150 | Vec::from([ 151 | 2.0f64, 8.0f64, 0.0f64, 4.0f64, 1.0f64, 9.0f64, 9.0f64, 0.0f64, 152 | ]), 153 | Alternative::TwoSided, 154 | NaNPolicy::Error, 155 | ) 156 | .unwrap(); 157 | prec::assert_relative_eq!(statistic, 0.44626385374196975); 158 | prec::assert_abs_diff_eq!(pvalue, 0.6554066631275459); 159 | let 
(statistic, pvalue) = skewtest( 160 | Vec::from([ 161 | 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, 8000.0f64, 162 | ]), 163 | Alternative::TwoSided, 164 | NaNPolicy::Error, 165 | ) 166 | .unwrap(); 167 | prec::assert_relative_eq!(statistic, 3.571773510360407); 168 | prec::assert_abs_diff_eq!(pvalue, 0.0003545719905823133); 169 | let (statistic, pvalue) = skewtest( 170 | Vec::from([ 171 | 100.0f64, 100.0f64, 100.0f64, 100.0f64, 100.0f64, 100.0f64, 100.0f64, 101.0f64, 172 | ]), 173 | Alternative::TwoSided, 174 | NaNPolicy::Error, 175 | ) 176 | .unwrap(); 177 | prec::assert_relative_eq!(statistic, 3.5717766638478072); 178 | prec::assert_abs_diff_eq!(pvalue, 0.000354567720281634012); 179 | let (statistic, pvalue) = skewtest( 180 | Vec::from([ 181 | 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, 8.0f64, 182 | ]), 183 | Alternative::Less, 184 | NaNPolicy::Error, 185 | ) 186 | .unwrap(); 187 | prec::assert_relative_eq!(statistic, 1.0108048609177787); 188 | prec::assert_abs_diff_eq!(pvalue, 0.8439450819289052); 189 | let (statistic, pvalue) = skewtest( 190 | Vec::from([ 191 | 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, 8.0f64, 192 | ]), 193 | Alternative::Greater, 194 | NaNPolicy::Error, 195 | ) 196 | .unwrap(); 197 | prec::assert_relative_eq!(statistic, 1.0108048609177787); 198 | prec::assert_abs_diff_eq!(pvalue, 0.15605491807109484); 199 | } 200 | #[test] 201 | fn test_nan_in_data_w_emit() { 202 | // results should be the same as the example above since the NaNs should be filtered out 203 | let data = Vec::from([ 204 | 148.0f64, 205 | 154.0f64, 206 | 158.0f64, 207 | 160.0f64, 208 | 161.0f64, 209 | 162.0f64, 210 | 166.0f64, 211 | 170.0f64, 212 | 182.0f64, 213 | 195.0f64, 214 | 236.0f64, 215 | f64::NAN, 216 | ]); 217 | let (statistic, pvalue) = 218 | skewtest(data.clone(), Alternative::TwoSided, NaNPolicy::Emit).unwrap(); 219 | prec::assert_relative_eq!(statistic, 2.7788579769903414); 220 | prec::assert_abs_diff_eq!(pvalue, 0.005455036974740185); 221 | } 222 | #[test] 223 | fn test_nan_in_data_w_propogate() { 224 | let sample_input = Vec::from([1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, f64::NAN]); 225 | let (statistic, pvalue) = 226 | skewtest(sample_input, Alternative::TwoSided, NaNPolicy::Propogate).unwrap(); 227 | assert!(statistic.is_nan()); 228 | assert!(pvalue.is_nan()); 229 | } 230 | #[test] 231 | fn test_nan_in_data_w_error() { 232 | let sample_input = Vec::from([0.0571, 0.0813, f64::NAN, 0.0836]); 233 | let result = skewtest(sample_input, Alternative::TwoSided, NaNPolicy::Error); 234 | assert_eq!(result, Err(SkewTestError::SampleContainsNaN)); 235 | } 236 | #[test] 237 | fn test_bad_data_sample_too_small() { 238 | let sample_input = Vec::new(); 239 | let result = skewtest(sample_input, Alternative::TwoSided, NaNPolicy::Error); 240 | assert_eq!(result, Err(SkewTestError::SampleTooSmall)); 241 | 242 | let sample_input = Vec::from([1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, f64::NAN]); 243 | let result = skewtest(sample_input, Alternative::TwoSided, NaNPolicy::Emit); 244 | assert_eq!(result, Err(SkewTestError::SampleTooSmall)); 245 | } 246 | #[test] 247 | fn test_calc_root_b1() { 248 | // compare to https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.skew.html 249 | // since no wikipedia examples 250 | let sample_input = Vec::from([1.0, 2.0, 3.0, 4.0, 5.0]); 251 | prec::assert_ulps_eq!(calc_root_b1(&sample_input), 0.0); 252 | 253 | let sample_input = Vec::from([2.0, 8.0, 0.0, 4.0, 1.0, 9.0, 9.0, 0.0]); 254 | let result = calc_root_b1(&sample_input); 255 | 
prec::assert_abs_diff_eq!(result, 0.2650554122698573); 256 | } 257 | } 258 | -------------------------------------------------------------------------------- /src/distribution/mod.rs: -------------------------------------------------------------------------------- 1 | //! Defines common interfaces for interacting with statistical distributions 2 | //! and provides 3 | //! concrete implementations for a variety of distributions. 4 | use super::statistics::{Max, Min}; 5 | use ::num_traits::{Float, Num}; 6 | use num_traits::NumAssignOps; 7 | 8 | pub use self::bernoulli::Bernoulli; 9 | pub use self::beta::{Beta, BetaError}; 10 | pub use self::binomial::{Binomial, BinomialError}; 11 | #[cfg(feature = "std")] 12 | pub use self::categorical::{Categorical, CategoricalError}; 13 | pub use self::cauchy::{Cauchy, CauchyError}; 14 | pub use self::chi::{Chi, ChiError}; 15 | pub use self::chi_squared::ChiSquared; 16 | pub use self::dirac::{Dirac, DiracError}; 17 | #[cfg(feature = "nalgebra")] 18 | pub use self::dirichlet::{Dirichlet, DirichletError}; 19 | pub use self::discrete_uniform::{DiscreteUniform, DiscreteUniformError}; 20 | #[cfg(feature = "std")] 21 | pub use self::empirical::Empirical; 22 | pub use self::erlang::Erlang; 23 | pub use self::exponential::{Exp, ExpError}; 24 | pub use self::fisher_snedecor::{FisherSnedecor, FisherSnedecorError}; 25 | pub use self::gamma::{Gamma, GammaError}; 26 | pub use self::geometric::{Geometric, GeometricError}; 27 | pub use self::gumbel::{Gumbel, GumbelError}; 28 | pub use self::hypergeometric::{Hypergeometric, HypergeometricError}; 29 | pub use self::inverse_gamma::{InverseGamma, InverseGammaError}; 30 | pub use self::laplace::{Laplace, LaplaceError}; 31 | pub use self::levy::{Levy, LevyError}; 32 | pub use self::log_normal::{LogNormal, LogNormalError}; 33 | #[cfg(feature = "nalgebra")] 34 | pub use self::multinomial::{Multinomial, MultinomialError}; 35 | #[cfg(feature = "nalgebra")] 36 | pub use self::multivariate_normal::{MultivariateNormal, MultivariateNormalError}; 37 | #[cfg(feature = "nalgebra")] 38 | pub use self::multivariate_students_t::{MultivariateStudent, MultivariateStudentError}; 39 | pub use self::negative_binomial::{NegativeBinomial, NegativeBinomialError}; 40 | pub use self::normal::{Normal, NormalError}; 41 | pub use self::pareto::{Pareto, ParetoError}; 42 | pub use self::poisson::{Poisson, PoissonError}; 43 | pub use self::students_t::{StudentsT, StudentsTError}; 44 | pub use self::triangular::{Triangular, TriangularError}; 45 | pub use self::uniform::{Uniform, UniformError}; 46 | pub use self::weibull::{Weibull, WeibullError}; 47 | 48 | mod bernoulli; 49 | mod beta; 50 | mod binomial; 51 | #[cfg(feature = "std")] 52 | mod categorical; 53 | mod cauchy; 54 | mod chi; 55 | mod chi_squared; 56 | mod dirac; 57 | #[cfg(feature = "nalgebra")] 58 | #[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))] 59 | mod dirichlet; 60 | mod discrete_uniform; 61 | #[cfg(feature = "std")] 62 | mod empirical; 63 | mod erlang; 64 | mod exponential; 65 | mod fisher_snedecor; 66 | mod gamma; 67 | mod geometric; 68 | mod gumbel; 69 | mod hypergeometric; 70 | #[macro_use] 71 | mod internal; 72 | mod inverse_gamma; 73 | mod laplace; 74 | mod levy; 75 | mod log_normal; 76 | #[cfg(feature = "nalgebra")] 77 | #[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))] 78 | mod multinomial; 79 | #[cfg(feature = "nalgebra")] 80 | #[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))] 81 | mod multivariate_normal; 82 | #[cfg(feature = "nalgebra")] 83 | #[cfg_attr(docsrs, 
doc(cfg(feature = "nalgebra")))] 84 | mod multivariate_students_t; 85 | mod negative_binomial; 86 | mod normal; 87 | mod pareto; 88 | mod poisson; 89 | mod students_t; 90 | mod triangular; 91 | mod uniform; 92 | mod weibull; 93 | #[cfg(feature = "rand")] 94 | #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] 95 | mod ziggurat; 96 | #[cfg(feature = "rand")] 97 | #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] 98 | mod ziggurat_tables; 99 | 100 | /// The `ContinuousCDF` trait is used to specify an interface for univariate 101 | /// distributions for which cdf float arguments are sensible. 102 | pub trait ContinuousCDF: Min + Max { 103 | /// Returns the cumulative distribution function calculated 104 | /// at `x` for a given distribution. May panic depending 105 | /// on the implementor. 106 | /// 107 | /// # Examples 108 | /// 109 | /// ``` 110 | /// use statrs::distribution::{ContinuousCDF, Uniform}; 111 | /// 112 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 113 | /// assert_eq!(0.5, n.cdf(0.5)); 114 | /// ``` 115 | fn cdf(&self, x: K) -> T; 116 | 117 | /// Returns the survival function calculated 118 | /// at `x` for a given distribution. May panic depending 119 | /// on the implementor. 120 | /// 121 | /// # Examples 122 | /// 123 | /// ``` 124 | /// use statrs::distribution::{ContinuousCDF, Uniform}; 125 | /// 126 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 127 | /// assert_eq!(0.5, n.sf(0.5)); 128 | /// ``` 129 | fn sf(&self, x: K) -> T { 130 | T::one() - self.cdf(x) 131 | } 132 | 133 | /// Due to issues with rounding and floating-point accuracy the default 134 | /// implementation may be ill-behaved. 135 | /// Specialized inverse cdfs should be used whenever possible. 136 | /// Performs a binary search on the domain of `cdf` to obtain an approximation 137 | /// of `F^-1(p) := inf { x | F(x) >= p }`. Needless to say, performance may 138 | /// may be lacking. 139 | #[doc(alias = "quantile function")] 140 | #[doc(alias = "quantile")] 141 | fn inverse_cdf(&self, p: T) -> K { 142 | if p == T::zero() { 143 | return self.min(); 144 | }; 145 | if p == T::one() { 146 | return self.max(); 147 | }; 148 | let two = K::one() + K::one(); 149 | let mut high = two; 150 | let mut low = -high; 151 | while self.cdf(low) > p { 152 | low = low + low; 153 | } 154 | while self.cdf(high) < p { 155 | high = high + high; 156 | } 157 | let mut i = 16; 158 | while i != 0 { 159 | let mid = (high + low) / two; 160 | if self.cdf(mid) >= p { 161 | high = mid; 162 | } else { 163 | low = mid; 164 | } 165 | i -= 1; 166 | } 167 | (high + low) / two 168 | } 169 | } 170 | 171 | /// The `DiscreteCDF` trait is used to specify an interface for univariate 172 | /// discrete distributions. 173 | pub trait DiscreteCDF: 174 | Min + Max 175 | { 176 | /// Returns the cumulative distribution function calculated 177 | /// at `x` for a given distribution. May panic depending 178 | /// on the implementor. 179 | /// 180 | /// # Examples 181 | /// 182 | /// ``` 183 | /// use statrs::distribution::{DiscreteCDF, DiscreteUniform}; 184 | /// 185 | /// let n = DiscreteUniform::new(1, 10).unwrap(); 186 | /// assert_eq!(0.6, n.cdf(6)); 187 | /// ``` 188 | fn cdf(&self, x: K) -> T; 189 | 190 | /// Returns the survival function calculated at `x` for 191 | /// a given distribution. May panic depending on the implementor. 
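    /// As with `ContinuousCDF::sf` above, the default implementation is simply
    /// `1 - cdf(x)`, which can lose precision in the far upper tail; distributions
    /// with a closed-form survival function should override it.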
192 | /// 193 | /// # Examples 194 | /// 195 | /// ``` 196 | /// use statrs::distribution::{DiscreteCDF, DiscreteUniform}; 197 | /// 198 | /// let n = DiscreteUniform::new(1, 10).unwrap(); 199 | /// assert_eq!(0.4, n.sf(6)); 200 | /// ``` 201 | fn sf(&self, x: K) -> T { 202 | T::one() - self.cdf(x) 203 | } 204 | 205 | /// Due to issues with rounding and floating-point accuracy the default implementation may be ill-behaved 206 | /// Specialized inverse cdfs should be used whenever possible. 207 | /// 208 | /// # Panics 209 | /// this default impl panics if provided `p` not on interval [0.0, 1.0] 210 | fn inverse_cdf(&self, p: T) -> K { 211 | if p <= self.cdf(self.min()) { 212 | return self.min(); 213 | } else if p == T::one() { 214 | return self.max(); 215 | } else if !(T::zero()..=T::one()).contains(&p) { 216 | panic!("p must be on [0, 1]") 217 | } 218 | 219 | let two = K::one() + K::one(); 220 | let mut ub = two.clone(); 221 | let lb = self.min(); 222 | while self.cdf(ub.clone()) < p { 223 | ub *= two.clone(); 224 | } 225 | 226 | internal::integral_bisection_search(|p| self.cdf(p.clone()), p, lb, ub).unwrap() 227 | } 228 | } 229 | 230 | /// The `Continuous` trait provides an interface for interacting with 231 | /// continuous statistical distributions 232 | /// 233 | /// # Remarks 234 | /// 235 | /// All methods provided by the `Continuous` trait are unchecked, meaning 236 | /// they can panic if in an invalid state or encountering invalid input 237 | /// depending on the implementing distribution. 238 | pub trait Continuous { 239 | /// Returns the probability density function calculated at `x` for a given 240 | /// distribution. 241 | /// May panic depending on the implementor. 242 | /// 243 | /// # Examples 244 | /// 245 | /// ``` 246 | /// use statrs::distribution::{Continuous, Uniform}; 247 | /// 248 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 249 | /// assert_eq!(1.0, n.pdf(0.5)); 250 | /// ``` 251 | fn pdf(&self, x: K) -> T; 252 | 253 | /// Returns the log of the probability density function calculated at `x` 254 | /// for a given distribution. 255 | /// May panic depending on the implementor. 256 | /// 257 | /// # Examples 258 | /// 259 | /// ``` 260 | /// use statrs::distribution::{Continuous, Uniform}; 261 | /// 262 | /// let n = Uniform::new(0.0, 1.0).unwrap(); 263 | /// assert_eq!(0.0, n.ln_pdf(0.5)); 264 | /// ``` 265 | fn ln_pdf(&self, x: K) -> T; 266 | } 267 | 268 | /// The `Discrete` trait provides an interface for interacting with discrete 269 | /// statistical distributions 270 | /// 271 | /// # Remarks 272 | /// 273 | /// All methods provided by the `Discrete` trait are unchecked, meaning 274 | /// they can panic if in an invalid state or encountering invalid input 275 | /// depending on the implementing distribution. 276 | pub trait Discrete { 277 | /// Returns the probability mass function calculated at `x` for a given 278 | /// distribution. 279 | /// May panic depending on the implementor. 280 | /// 281 | /// # Examples 282 | /// 283 | /// ``` 284 | /// use statrs::distribution::{Discrete, Binomial}; 285 | /// use approx::assert_abs_diff_eq; 286 | /// 287 | /// let n = Binomial::new(0.5, 10).unwrap(); 288 | /// assert_abs_diff_eq!(n.pmf(5), 0.24609375, epsilon = 1e-15); 289 | /// ``` 290 | fn pmf(&self, x: K) -> T; 291 | 292 | /// Returns the log of the probability mass function calculated at `x` for 293 | /// a given distribution. 294 | /// May panic depending on the implementor. 
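    /// Working with `ln_pmf` directly is generally preferable to `pmf(x).ln()`
    /// when the probability mass is small enough to underflow to zero.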
295 | /// 296 | /// # Examples 297 | /// 298 | /// ``` 299 | /// use statrs::distribution::{Discrete, Binomial}; 300 | /// use approx::assert_abs_diff_eq; 301 | /// 302 | /// let n = Binomial::new(0.5, 10).unwrap(); 303 | /// assert_abs_diff_eq!(n.ln_pmf(5), (0.24609375f64).ln(), epsilon = 1e-15); 304 | /// ``` 305 | fn ln_pmf(&self, x: K) -> T; 306 | } 307 | -------------------------------------------------------------------------------- /src/stats_tests/f_oneway.rs: -------------------------------------------------------------------------------- 1 | //! Provides the [one-way ANOVA F-test](https://en.wikipedia.org/wiki/One-way_analysis_of_variance) 2 | //! and related functions 3 | 4 | use crate::distribution::{ContinuousCDF, FisherSnedecor}; 5 | use crate::stats_tests::NaNPolicy; 6 | 7 | /// Represents the errors that occur when computing the f_oneway function 8 | #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] 9 | #[non_exhaustive] 10 | pub enum FOneWayTestError { 11 | /// must be at least two samples 12 | NotEnoughSamples, 13 | /// one sample must be length greater than 1 14 | SampleTooSmall, 15 | /// samples must not contain all of the same values 16 | SampleContainsSameConstants, 17 | /// samples can not contain NaN when `nan_policy` is set to `NaNPolicy::Error` 18 | SampleContainsNaN, 19 | } 20 | 21 | impl core::fmt::Display for FOneWayTestError { 22 | #[cfg_attr(coverage_nightly, coverage(off))] 23 | fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { 24 | match self { 25 | FOneWayTestError::NotEnoughSamples => write!(f, "must be at least two samples"), 26 | FOneWayTestError::SampleTooSmall => { 27 | write!(f, "one sample must be length greater than 1") 28 | } 29 | FOneWayTestError::SampleContainsSameConstants => { 30 | write!(f, "samples must not contain all of the same values") 31 | } 32 | FOneWayTestError::SampleContainsNaN => { 33 | write!( 34 | f, 35 | "samples can not contain NaN when `nan_policy` is set to `NaNPolicy::Error`" 36 | ) 37 | } 38 | } 39 | } 40 | } 41 | 42 | impl std::error::Error for FOneWayTestError {} 43 | 44 | /// Perform a one-way Analysis of Variance (ANOVA) F-test 45 | /// 46 | /// Takes in a set (outer vector) of samples (inner vector) and returns the F-statistic and p-value 47 | /// 48 | /// # Remarks 49 | /// 50 | /// `samples` needs to be mutable in case needing to filter out NaNs for NaNPolicy::Emit 51 | /// 52 | /// Implementation based on [statsdirect](https://www.statsdirect.com/help/analysis_of_variance/one_way.htm) 53 | /// while aligning to [scipy's](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.f_oneway.html#scipy.stats.f_oneway) 54 | /// function header where possible. The scipy implementation was also used for testing and 55 | /// validation. Includes the use of [McDonald et al. (1991)](doi.org/10.1007/BF01319403) for 56 | /// testing and validation. 
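/// The returned statistic is the ratio of the mean square between groups to the
/// mean square within groups, `F = MST / MSE`, which is compared against a
/// `FisherSnedecor` distribution with `k - 1` and `N - k` degrees of freedom,
/// where `k` is the number of groups and `N` is the total number of observations.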
57 | /// 58 | /// # Examples 59 | /// 60 | /// ``` 61 | /// use statrs::stats_tests::f_oneway::f_oneway; 62 | /// use statrs::stats_tests::NaNPolicy; 63 | /// 64 | /// let a1 = Vec::from([6f64, 8f64, 4f64, 5f64, 3f64, 4f64]); 65 | /// let a2 = Vec::from([8f64, 12f64, 9f64, 11f64, 6f64, 8f64]); 66 | /// let a3 = Vec::from([13f64, 9f64, 11f64, 8f64, 7f64, 12f64]); 67 | /// let sample_input = Vec::from([a1, a2, a3]); 68 | /// let (statistic, pvalue) = f_oneway(sample_input, NaNPolicy::Error).unwrap(); // (9.3, 0.002) 69 | /// ``` 70 | pub fn f_oneway( 71 | mut samples: Vec>, 72 | nan_policy: NaNPolicy, 73 | ) -> Result<(f64, f64), FOneWayTestError> { 74 | let k = samples.len(); 75 | 76 | // initial input validation 77 | if k < 2 { 78 | return Err(FOneWayTestError::NotEnoughSamples); 79 | } 80 | 81 | let has_nans = samples.iter().flatten().any(|x| x.is_nan()); 82 | if has_nans { 83 | match nan_policy { 84 | NaNPolicy::Propogate => { 85 | return Ok((f64::NAN, f64::NAN)); 86 | } 87 | NaNPolicy::Error => { 88 | return Err(FOneWayTestError::SampleContainsNaN); 89 | } 90 | NaNPolicy::Emit => { 91 | samples = samples 92 | .into_iter() 93 | .map(|v| v.into_iter().filter(|x| !x.is_nan()).collect::>()) 94 | .collect::>(); 95 | } 96 | } 97 | } 98 | 99 | // do remaining input validation after potential subset from Emit 100 | let n_i: Vec = samples.iter().map(|v| v.len()).collect(); 101 | if !n_i.iter().all(|x| *x >= 1) || !n_i.iter().any(|x| *x >= 2) { 102 | return Err(FOneWayTestError::SampleTooSmall); 103 | } 104 | 105 | if samples.iter().any(|v| { 106 | if v.len() > 1 { 107 | let mut it = v.iter(); 108 | let first = it.next().unwrap(); 109 | it.all(|x| x == first) 110 | } else { 111 | false 112 | } 113 | }) { 114 | return Err(FOneWayTestError::SampleContainsSameConstants); 115 | } 116 | 117 | let n = n_i.iter().sum::(); 118 | let g = samples.iter().flatten().sum::(); 119 | 120 | let tsq = samples 121 | .iter() 122 | .map(|v| v.iter().sum::().powi(2) / v.len() as f64) 123 | .sum::(); 124 | let ysq = samples.iter().flatten().map(|x| x.powi(2)).sum::(); 125 | 126 | // Sum of Squares (SS) and Mean Square (MS) between and within groups 127 | let sst = tsq - (g.powi(2) / n as f64); 128 | let mst = sst / (k - 1) as f64; 129 | 130 | let sse = ysq - tsq; 131 | let mse = sse / (n - k) as f64; 132 | 133 | let fstat = mst / mse; 134 | 135 | // degrees of freedom for between groups (t) and within groups (e) 136 | let dft = (k - 1) as f64; 137 | let dfe = (n - k) as f64; 138 | // k >= 2 meaning dft = (k-1) > 0 or Err(NotEnoughSamples) 139 | // one group must be at least 2 and all other groups must be at least 1 or Err(SampleTooSmall) 140 | // meaning that the minimum value of n will always be at least one greater than k so dfe must 141 | // be > 0 142 | let f_dist = FisherSnedecor::new(dft, dfe).expect("degrees of freedom should always be >0 "); 143 | let pvalue = 1.0 - f_dist.cdf(fstat); 144 | 145 | Ok((fstat, pvalue)) 146 | } 147 | 148 | #[cfg(test)] 149 | mod tests { 150 | use super::*; 151 | use crate::prec; 152 | 153 | #[test] 154 | fn test_scipy_example() { 155 | let tillamook = Vec::from([ 156 | 0.0571, 0.0813, 0.0831, 0.0976, 0.0817, 0.0859, 0.0735, 0.0659, 0.0923, 0.0836, 157 | ]); 158 | let newport = Vec::from([ 159 | 0.0873, 0.0662, 0.0672, 0.0819, 0.0749, 0.0649, 0.0835, 0.0725, 160 | ]); 161 | let petersburg = Vec::from([0.0974, 0.1352, 0.0817, 0.1016, 0.0968, 0.1064, 0.105]); 162 | let magadan = Vec::from([ 163 | 0.1033, 0.0915, 0.0781, 0.0685, 0.0677, 0.0697, 0.0764, 0.0689, 164 | ]); 165 | let 
tvarminne = Vec::from([0.0703, 0.1026, 0.0956, 0.0973, 0.1039, 0.1045]); 166 | let sample_input = Vec::from([tillamook, newport, petersburg, magadan, tvarminne]); 167 | let (statistic, pvalue) = f_oneway(sample_input, NaNPolicy::Error).unwrap(); 168 | 169 | prec::assert_relative_eq!(statistic, 7.121019471642447); 170 | prec::assert_abs_diff_eq!(pvalue, 0.0002812242314534544); 171 | } 172 | 173 | #[test] 174 | fn test_nan_in_data_w_emit() { 175 | // same as scipy example above with NaNs added should give same result 176 | let tillamook = Vec::from([ 177 | 0.0571, 178 | 0.0813, 179 | 0.0831, 180 | 0.0976, 181 | 0.0817, 182 | 0.0859, 183 | 0.0735, 184 | 0.0659, 185 | 0.0923, 186 | 0.0836, 187 | f64::NAN, 188 | ]); 189 | let newport = Vec::from([ 190 | 0.0873, 0.0662, 0.0672, 0.0819, 0.0749, 0.0649, 0.0835, 0.0725, 191 | ]); 192 | let petersburg = Vec::from([0.0974, 0.1352, 0.0817, 0.1016, 0.0968, 0.1064, 0.105]); 193 | let magadan = Vec::from([ 194 | 0.1033, 195 | 0.0915, 196 | 0.0781, 197 | 0.0685, 198 | 0.0677, 199 | 0.0697, 200 | 0.0764, 201 | 0.0689, 202 | f64::NAN, 203 | ]); 204 | let tvarminne = Vec::from([0.0703, 0.1026, 0.0956, 0.0973, 0.1039, 0.1045]); 205 | let sample_input = Vec::from([tillamook, newport, petersburg, magadan, tvarminne]); 206 | let (statistic, pvalue) = f_oneway(sample_input, NaNPolicy::Emit).unwrap(); 207 | 208 | prec::assert_relative_eq!(statistic, 7.121019471642447); 209 | prec::assert_abs_diff_eq!(pvalue, 0.0002812242314534544); 210 | } 211 | 212 | #[test] 213 | fn test_group_length_one_ok() { 214 | // group length 1 doesn't result in error 215 | let group1 = Vec::from([0.5]); 216 | let group2 = Vec::from([0.25, 0.75]); 217 | let sample_input = Vec::from([group1, group2]); 218 | let (statistic, pvalue) = f_oneway(sample_input, NaNPolicy::Propogate).unwrap(); 219 | prec::assert_relative_eq!(statistic, 0.0); 220 | prec::assert_abs_diff_eq!(pvalue, 1.0); 221 | } 222 | 223 | #[test] 224 | fn test_nan_in_data_w_propogate() { 225 | let group1 = Vec::from([0.0571, 0.0813, f64::NAN, 0.0836]); 226 | let group2 = Vec::from([0.0873, 0.0662, 0.0672, 0.0819, 0.0749]); 227 | let sample_input = Vec::from([group1, group2]); 228 | let (statistic, pvalue) = f_oneway(sample_input, NaNPolicy::Propogate).unwrap(); 229 | assert!(statistic.is_nan()); 230 | assert!(pvalue.is_nan()); 231 | } 232 | 233 | #[test] 234 | fn test_nan_in_data_w_error() { 235 | let group1 = Vec::from([0.0571, 0.0813, f64::NAN, 0.0836]); 236 | let group2 = Vec::from([0.0873, 0.0662, 0.0672, 0.0819, 0.0749]); 237 | let sample_input = Vec::from([group1, group2]); 238 | let result = f_oneway(sample_input, NaNPolicy::Error); 239 | assert_eq!(result, Err(FOneWayTestError::SampleContainsNaN)); 240 | } 241 | 242 | #[test] 243 | fn test_bad_data_not_enough_samples() { 244 | let group1 = Vec::from([0.0, 0.0]); 245 | let sample_input = Vec::from([group1]); 246 | let result = f_oneway(sample_input, NaNPolicy::Propogate); 247 | assert_eq!(result, Err(FOneWayTestError::NotEnoughSamples)) 248 | } 249 | 250 | #[test] 251 | fn test_bad_data_sample_too_small() { 252 | let group1 = Vec::new(); 253 | let group2 = Vec::from([0.0873, 0.0662]); 254 | let sample_input = Vec::from([group1, group2]); 255 | let result = f_oneway(sample_input, NaNPolicy::Propogate); 256 | assert_eq!(result, Err(FOneWayTestError::SampleTooSmall)); 257 | 258 | let group1 = Vec::from([f64::NAN]); 259 | let group2 = Vec::from([0.0873, 0.0662]); 260 | let sample_input = Vec::from([group1, group2]); 261 | let result = f_oneway(sample_input, NaNPolicy::Emit); 
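        // NaNPolicy::Emit filters out the lone NaN, leaving group1 empty, so the
        // post-filter length check reports SampleTooSmall.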
262 | assert_eq!(result, Err(FOneWayTestError::SampleTooSmall)); 263 | 264 | let group1 = Vec::from([1.0]); 265 | let group2 = Vec::from([0.0873]); 266 | let sample_input = Vec::from([group1, group2]); 267 | let result = f_oneway(sample_input, NaNPolicy::Propogate); 268 | assert_eq!(result, Err(FOneWayTestError::SampleTooSmall)); 269 | 270 | let group1 = Vec::from([1.0, f64::NAN]); 271 | let group2 = Vec::from([0.0873, f64::NAN]); 272 | let sample_input = Vec::from([group1, group2]); 273 | let result = f_oneway(sample_input, NaNPolicy::Emit); 274 | assert_eq!(result, Err(FOneWayTestError::SampleTooSmall)); 275 | } 276 | 277 | #[test] 278 | fn test_bad_data_sample_contains_same_constants() { 279 | let group1 = Vec::from([1.0, 1.0]); 280 | let group2 = Vec::from([2.0, 2.0]); 281 | let sample_input = Vec::from([group1, group2]); 282 | let result = f_oneway(sample_input, NaNPolicy::Error); 283 | assert_eq!(result, Err(FOneWayTestError::SampleContainsSameConstants)); 284 | 285 | let group1 = Vec::from([1.0, 1.0, 1.0]); 286 | let group2 = Vec::from([0.0873, 0.0662, 0.0342]); 287 | let sample_input = Vec::from([group1, group2]); 288 | let result = f_oneway(sample_input, NaNPolicy::Error); 289 | assert_eq!(result, Err(FOneWayTestError::SampleContainsSameConstants)); 290 | } 291 | } 292 | -------------------------------------------------------------------------------- /src/generate.rs: -------------------------------------------------------------------------------- 1 | //! Provides utility functions for generating data sequences 2 | 3 | use crate::euclid::Modulus; 4 | use core::f64::consts; 5 | /// Generates a base 10 log spaced vector of the given length between the 6 | /// specified decade exponents (inclusive). Equivalent to MATLAB logspace 7 | /// 8 | /// # Examples 9 | /// 10 | /// ``` 11 | /// use statrs::generate; 12 | /// 13 | /// let x = generate::log_spaced(5, 0.0, 4.0); 14 | /// assert_eq!(x, [1.0, 10.0, 100.0, 1000.0, 10000.0]); 15 | /// ``` 16 | #[cfg(feature = "std")] 17 | pub fn log_spaced(length: usize, start_exp: f64, stop_exp: f64) -> Vec { 18 | match length { 19 | 0 => Vec::new(), 20 | 1 => vec![10f64.powf(stop_exp)], 21 | _ => { 22 | let step = (stop_exp - start_exp) / (length - 1) as f64; 23 | let mut vec = (0..length) 24 | .map(|x| 10f64.powf(start_exp + (x as f64) * step)) 25 | .collect::>(); 26 | vec[length - 1] = 10f64.powf(stop_exp); 27 | vec 28 | } 29 | } 30 | } 31 | 32 | /// Infinite iterator returning floats that form a periodic wave 33 | #[derive(Clone, Copy, PartialEq, Debug)] 34 | pub struct InfinitePeriodic { 35 | amplitude: f64, 36 | step: f64, 37 | phase: f64, 38 | k: f64, 39 | } 40 | 41 | impl InfinitePeriodic { 42 | /// Constructs a new infinite periodic wave generator 43 | /// 44 | /// # Examples 45 | /// 46 | /// ``` 47 | /// use statrs::generate::InfinitePeriodic; 48 | /// 49 | /// let x = InfinitePeriodic::new(8.0, 2.0, 10.0, 1.0, 50 | /// 2).take(10).collect::>(); 51 | /// assert_eq!(x, [6.0, 8.5, 1.0, 3.5, 6.0, 8.5, 1.0, 3.5, 6.0, 8.5]); 52 | /// ``` 53 | pub fn new( 54 | sampling_rate: f64, 55 | frequency: f64, 56 | amplitude: f64, 57 | phase: f64, 58 | delay: i64, 59 | ) -> InfinitePeriodic { 60 | let step = frequency / sampling_rate * amplitude; 61 | InfinitePeriodic { 62 | amplitude, 63 | step, 64 | phase: (phase - delay as f64 * step).modulus(amplitude), 65 | k: 0.0, 66 | } 67 | } 68 | 69 | /// Constructs a default infinite periodic wave generator 70 | /// 71 | /// # Examples 72 | /// 73 | /// ``` 74 | /// use statrs::generate::InfinitePeriodic; 75 | 
/// 76 | /// let x = InfinitePeriodic::default(8.0, 77 | /// 2.0).take(10).collect::>(); 78 | /// assert_eq!(x, [0.0, 0.25, 0.5, 0.75, 0.0, 0.25, 0.5, 0.75, 0.0, 0.25]); 79 | /// ``` 80 | pub fn default(sampling_rate: f64, frequency: f64) -> InfinitePeriodic { 81 | Self::new(sampling_rate, frequency, 1.0, 0.0, 0) 82 | } 83 | } 84 | 85 | impl core::fmt::Display for InfinitePeriodic { 86 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 87 | write!(f, "{self:#?}") 88 | } 89 | } 90 | 91 | impl Iterator for InfinitePeriodic { 92 | type Item = f64; 93 | 94 | fn next(&mut self) -> Option { 95 | let mut x = self.phase + self.k * self.step; 96 | if x >= self.amplitude { 97 | x %= self.amplitude; 98 | self.phase = x; 99 | self.k = 0.0; 100 | } 101 | self.k += 1.0; 102 | Some(x) 103 | } 104 | } 105 | 106 | /// Infinite iterator returning floats that form a sinusoidal wave 107 | #[derive(Debug, Clone, Copy, PartialEq)] 108 | pub struct InfiniteSinusoidal { 109 | amplitude: f64, 110 | mean: f64, 111 | step: f64, 112 | phase: f64, 113 | i: usize, 114 | } 115 | 116 | impl InfiniteSinusoidal { 117 | /// Constructs a new infinite sinusoidal wave generator 118 | /// 119 | /// # Examples 120 | /// 121 | /// ``` 122 | /// use statrs::generate::InfiniteSinusoidal; 123 | /// 124 | /// let x = InfiniteSinusoidal::new(8.0, 2.0, 1.0, 5.0, 2.0, 125 | /// 1).take(10).collect::>(); 126 | /// assert_eq!(x, 127 | /// [5.416146836547142, 5.909297426825682, 4.583853163452858, 128 | /// 4.090702573174318, 5.416146836547142, 5.909297426825682, 129 | /// 4.583853163452858, 4.090702573174318, 5.416146836547142, 130 | /// 5.909297426825682]); 131 | /// ``` 132 | pub fn new( 133 | sampling_rate: f64, 134 | frequency: f64, 135 | amplitude: f64, 136 | mean: f64, 137 | phase: f64, 138 | delay: i64, 139 | ) -> InfiniteSinusoidal { 140 | let pi2 = consts::PI * 2.0; 141 | let step = frequency / sampling_rate * pi2; 142 | InfiniteSinusoidal { 143 | amplitude, 144 | mean, 145 | step, 146 | phase: (phase - delay as f64 * step) % pi2, 147 | i: 0, 148 | } 149 | } 150 | 151 | /// Constructs a default infinite sinusoidal wave generator 152 | /// 153 | /// # Examples 154 | /// 155 | /// ``` 156 | /// use statrs::generate::InfiniteSinusoidal; 157 | /// 158 | /// let x = InfiniteSinusoidal::default(8.0, 2.0, 159 | /// 1.0).take(10).collect::>(); 160 | /// assert_eq!(x, 161 | /// [0.0, 1.0, 0.00000000000000012246467991473532, 162 | /// -1.0, -0.00000000000000024492935982947064, 1.0, 163 | /// 0.00000000000000036739403974420594, -1.0, 164 | /// -0.0000000000000004898587196589413, 1.0]); 165 | /// ``` 166 | pub fn default(sampling_rate: f64, frequency: f64, amplitude: f64) -> InfiniteSinusoidal { 167 | Self::new(sampling_rate, frequency, amplitude, 0.0, 0.0, 0) 168 | } 169 | } 170 | 171 | impl core::fmt::Display for InfiniteSinusoidal { 172 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 173 | write!(f, "{:#?}", &self) 174 | } 175 | } 176 | 177 | impl Iterator for InfiniteSinusoidal { 178 | type Item = f64; 179 | 180 | fn next(&mut self) -> Option { 181 | let x = self.mean + self.amplitude * (self.phase + self.i as f64 * self.step).sin(); 182 | self.i += 1; 183 | if self.i == 1000 { 184 | self.i = 0; 185 | self.phase = (self.phase + 1000.0 * self.step) % (consts::PI * 2.0); 186 | } 187 | Some(x) 188 | } 189 | } 190 | 191 | /// Infinite iterator returning floats forming a square wave starting 192 | /// with the high phase 193 | #[derive(Debug, Clone, Copy, PartialEq)] 194 | pub struct 
InfiniteSquare { 195 | periodic: InfinitePeriodic, 196 | high_duration: f64, 197 | high_value: f64, 198 | low_value: f64, 199 | } 200 | 201 | impl InfiniteSquare { 202 | /// Constructs a new infinite square wave generator 203 | /// 204 | /// # Examples 205 | /// 206 | /// ``` 207 | /// use statrs::generate::InfiniteSquare; 208 | /// 209 | /// let x = InfiniteSquare::new(3, 7, 1.0, -1.0, 210 | /// 1).take(12).collect::>(); 211 | /// assert_eq!(x, [-1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 212 | /// -1.0, 1.0]) 213 | /// ``` 214 | pub fn new( 215 | high_duration: i64, 216 | low_duration: i64, 217 | high_value: f64, 218 | low_value: f64, 219 | delay: i64, 220 | ) -> InfiniteSquare { 221 | let duration = (high_duration + low_duration) as f64; 222 | InfiniteSquare { 223 | periodic: InfinitePeriodic::new(1.0, 1.0 / duration, duration, 0.0, delay), 224 | high_duration: high_duration as f64, 225 | high_value, 226 | low_value, 227 | } 228 | } 229 | } 230 | 231 | impl core::fmt::Display for InfiniteSquare { 232 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 233 | write!(f, "{:#?}", &self) 234 | } 235 | } 236 | 237 | impl Iterator for InfiniteSquare { 238 | type Item = f64; 239 | 240 | fn next(&mut self) -> Option { 241 | self.periodic.next().map(|x| { 242 | if x < self.high_duration { 243 | self.high_value 244 | } else { 245 | self.low_value 246 | } 247 | }) 248 | } 249 | } 250 | 251 | /// Infinite iterator returning floats forming a triangle wave starting with 252 | /// the raise phase from the lowest sample 253 | #[derive(Debug, Clone, Copy, PartialEq)] 254 | pub struct InfiniteTriangle { 255 | periodic: InfinitePeriodic, 256 | raise_duration: f64, 257 | raise: f64, 258 | fall: f64, 259 | high_value: f64, 260 | low_value: f64, 261 | } 262 | 263 | impl InfiniteTriangle { 264 | /// Constructs a new infinite triangle wave generator 265 | /// 266 | /// # Examples 267 | /// 268 | /// ``` 269 | /// use approx::assert_abs_diff_eq; 270 | /// 271 | /// use statrs::generate::InfiniteTriangle; 272 | /// 273 | /// # fn main() { 274 | /// let x = InfiniteTriangle::new(4, 7, 1.0, -1.0, 275 | /// 1).take(12).collect::>(); 276 | /// let expected: [f64; 12] = [-0.714, -1.0, -0.5, 0.0, 0.5, 1.0, 0.714, 277 | /// 0.429, 0.143, -0.143, -0.429, -0.714]; 278 | /// for (&left, &right) in x.iter().zip(expected.iter()) { 279 | /// assert_abs_diff_eq!(left, right, epsilon = 1e-3); 280 | /// } 281 | /// # } 282 | /// ``` 283 | pub fn new( 284 | raise_duration: i64, 285 | fall_duration: i64, 286 | high_value: f64, 287 | low_value: f64, 288 | delay: i64, 289 | ) -> InfiniteTriangle { 290 | let duration = (raise_duration + fall_duration) as f64; 291 | let height = high_value - low_value; 292 | InfiniteTriangle { 293 | periodic: InfinitePeriodic::new(1.0, 1.0 / duration, duration, 0.0, delay), 294 | raise_duration: raise_duration as f64, 295 | raise: height / raise_duration as f64, 296 | fall: height / fall_duration as f64, 297 | high_value, 298 | low_value, 299 | } 300 | } 301 | } 302 | 303 | impl core::fmt::Display for InfiniteTriangle { 304 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 305 | write!(f, "{:#?}", &self) 306 | } 307 | } 308 | 309 | impl Iterator for InfiniteTriangle { 310 | type Item = f64; 311 | 312 | fn next(&mut self) -> Option { 313 | self.periodic.next().map(|x| { 314 | if x < self.raise_duration { 315 | self.low_value + x * self.raise 316 | } else { 317 | self.high_value - (x - self.raise_duration) * self.fall 318 | } 319 | }) 320 | } 321 
| } 322 | 323 | /// Infinite iterator returning floats forming a sawtooth wave 324 | /// starting with the lowest sample 325 | #[derive(Debug, Clone, Copy, PartialEq)] 326 | pub struct InfiniteSawtooth { 327 | periodic: InfinitePeriodic, 328 | low_value: f64, 329 | } 330 | 331 | impl InfiniteSawtooth { 332 | /// Constructs a new infinite sawtooth wave generator 333 | /// 334 | /// # Examples 335 | /// 336 | /// ``` 337 | /// use statrs::generate::InfiniteSawtooth; 338 | /// 339 | /// let x = InfiniteSawtooth::new(5, 1.0, -1.0, 340 | /// 1).take(12).collect::>(); 341 | /// assert_eq!(x, [1.0, -1.0, -0.5, 0.0, 0.5, 1.0, -1.0, -0.5, 0.0, 0.5, 342 | /// 1.0, -1.0]); 343 | /// ``` 344 | pub fn new(period: i64, high_value: f64, low_value: f64, delay: i64) -> InfiniteSawtooth { 345 | let height = high_value - low_value; 346 | let period = period as f64; 347 | InfiniteSawtooth { 348 | periodic: InfinitePeriodic::new( 349 | 1.0, 350 | 1.0 / period, 351 | height * period / (period - 1.0), 352 | 0.0, 353 | delay, 354 | ), 355 | low_value, 356 | } 357 | } 358 | } 359 | 360 | impl core::fmt::Display for InfiniteSawtooth { 361 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 362 | write!(f, "{:#?}", &self) 363 | } 364 | } 365 | 366 | impl Iterator for InfiniteSawtooth { 367 | type Item = f64; 368 | 369 | fn next(&mut self) -> Option { 370 | self.periodic.next().map(|x| x + self.low_value) 371 | } 372 | } 373 | -------------------------------------------------------------------------------- /src/statistics/statistics.rs: -------------------------------------------------------------------------------- 1 | /// Enumeration of possible tie-breaking strategies 2 | /// when computing ranks 3 | #[derive(Copy, Clone, Debug)] 4 | pub enum RankTieBreaker { 5 | /// Replaces ties with their mean 6 | Average, 7 | /// Replace ties with their minimum 8 | Min, 9 | /// Replace ties with their maximum 10 | Max, 11 | /// Permutation with increasing values at each index of ties 12 | First, 13 | } 14 | 15 | /// The `Statistics` trait provides a host of statistical utilities for 16 | /// analyzing 17 | /// data sets 18 | pub trait Statistics { 19 | /// Returns the minimum value in the data 20 | /// 21 | /// # Remarks 22 | /// 23 | /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN` 24 | /// 25 | /// # Examples 26 | /// 27 | /// ``` 28 | /// use core::f64; 29 | /// use statrs::statistics::Statistics; 30 | /// 31 | /// let x = &[]; 32 | /// assert!(Statistics::min(x).is_nan()); 33 | /// 34 | /// let y = &[0.0, f64::NAN, 3.0, -2.0]; 35 | /// assert!(Statistics::min(y).is_nan()); 36 | /// 37 | /// let z = &[0.0, 3.0, -2.0]; 38 | /// assert_eq!(Statistics::min(z), -2.0); 39 | /// ``` 40 | fn min(self) -> T; 41 | 42 | /// Returns the maximum value in the data 43 | /// 44 | /// # Remarks 45 | /// 46 | /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN` 47 | /// 48 | /// # Examples 49 | /// 50 | /// ``` 51 | /// use core::f64; 52 | /// use statrs::statistics::Statistics; 53 | /// 54 | /// let x = &[]; 55 | /// assert!(Statistics::max(x).is_nan()); 56 | /// 57 | /// let y = &[0.0, f64::NAN, 3.0, -2.0]; 58 | /// assert!(Statistics::max(y).is_nan()); 59 | /// 60 | /// let z = &[0.0, 3.0, -2.0]; 61 | /// assert_eq!(Statistics::max(z), 3.0); 62 | /// ``` 63 | fn max(self) -> T; 64 | 65 | /// Returns the minimum absolute value in the data 66 | /// 67 | /// # Remarks 68 | /// 69 | /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN` 70 | /// 71 | /// # Examples 72 | /// 73 | 
/// ``` 74 | /// use core::f64; 75 | /// use statrs::statistics::Statistics; 76 | /// 77 | /// let x = &[]; 78 | /// assert!(x.abs_min().is_nan()); 79 | /// 80 | /// let y = &[0.0, f64::NAN, 3.0, -2.0]; 81 | /// assert!(y.abs_min().is_nan()); 82 | /// 83 | /// let z = &[0.0, 3.0, -2.0]; 84 | /// assert_eq!(z.abs_min(), 0.0); 85 | /// ``` 86 | fn abs_min(self) -> T; 87 | 88 | /// Returns the maximum absolute value in the data 89 | /// 90 | /// # Remarks 91 | /// 92 | /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN` 93 | /// 94 | /// # Examples 95 | /// 96 | /// ``` 97 | /// use core::f64; 98 | /// use statrs::statistics::Statistics; 99 | /// 100 | /// let x = &[]; 101 | /// assert!(x.abs_max().is_nan()); 102 | /// 103 | /// let y = &[0.0, f64::NAN, 3.0, -2.0]; 104 | /// assert!(y.abs_max().is_nan()); 105 | /// 106 | /// let z = &[0.0, 3.0, -2.0, -8.0]; 107 | /// assert_eq!(z.abs_max(), 8.0); 108 | /// ``` 109 | fn abs_max(self) -> T; 110 | 111 | /// Evaluates the sample mean, an estimate of the population 112 | /// mean. 113 | /// 114 | /// # Remarks 115 | /// 116 | /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN` 117 | /// 118 | /// # Examples 119 | /// 120 | /// ``` 121 | /// use approx::assert_abs_diff_eq; 122 | /// 123 | /// use core::f64; 124 | /// use statrs::statistics::Statistics; 125 | /// 126 | /// # fn main() { 127 | /// let x = &[]; 128 | /// assert!(x.mean().is_nan()); 129 | /// 130 | /// let y = &[0.0, f64::NAN, 3.0, -2.0]; 131 | /// assert!(y.mean().is_nan()); 132 | /// 133 | /// let z = &[0.0, 3.0, -2.0]; 134 | /// assert_abs_diff_eq!(z.mean(), 1.0 / 3.0, epsilon = 1e-15); 135 | /// # } 136 | /// ``` 137 | fn mean(self) -> T; 138 | 139 | /// Evaluates the geometric mean of the data 140 | /// 141 | /// # Remarks 142 | /// 143 | /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`. 144 | /// Returns `f64::NAN` if an entry is less than `0`. Returns `0` 145 | /// if no entry is less than `0` but there are entries equal to `0`. 146 | /// 147 | /// # Examples 148 | /// 149 | /// ``` 150 | /// use approx::assert_abs_diff_eq; 151 | /// 152 | /// use core::f64; 153 | /// use statrs::statistics::Statistics; 154 | /// 155 | /// # fn main() { 156 | /// let x = &[]; 157 | /// assert!(x.geometric_mean().is_nan()); 158 | /// 159 | /// let y = &[0.0, f64::NAN, 3.0, -2.0]; 160 | /// assert!(y.geometric_mean().is_nan()); 161 | /// 162 | /// let mut z = &[0.0, 3.0, -2.0]; 163 | /// assert!(z.geometric_mean().is_nan()); 164 | /// 165 | /// z = &[0.0, 3.0, 2.0]; 166 | /// assert_eq!(z.geometric_mean(), 0.0); 167 | /// 168 | /// z = &[1.0, 2.0, 3.0]; 169 | /// // test value from online calculator, could be more accurate 170 | /// assert_abs_diff_eq!(z.geometric_mean(), 1.81712, epsilon = 1e-5); 171 | /// # } 172 | /// ``` 173 | fn geometric_mean(self) -> T; 174 | 175 | /// Evaluates the harmonic mean of the data 176 | /// 177 | /// # Remarks 178 | /// 179 | /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`, or if 180 | /// any value 181 | /// in data is less than `0`. Returns `0` if there are no values less than 182 | /// `0` but 183 | /// there exists values equal to `0`. 
184 | /// 185 | /// # Examples 186 | /// 187 | /// ``` 188 | /// use approx::assert_abs_diff_eq; 189 | /// 190 | /// use core::f64; 191 | /// use statrs::statistics::Statistics; 192 | /// 193 | /// # fn main() { 194 | /// let x = &[]; 195 | /// assert!(x.harmonic_mean().is_nan()); 196 | /// 197 | /// let y = &[0.0, f64::NAN, 3.0, -2.0]; 198 | /// assert!(y.harmonic_mean().is_nan()); 199 | /// 200 | /// let mut z = &[0.0, 3.0, -2.0]; 201 | /// assert!(z.harmonic_mean().is_nan()); 202 | /// 203 | /// z = &[0.0, 3.0, 2.0]; 204 | /// assert_eq!(z.harmonic_mean(), 0.0); 205 | /// 206 | /// z = &[1.0, 2.0, 3.0]; 207 | /// // test value from online calculator, could be more accurate 208 | /// assert_abs_diff_eq!(z.harmonic_mean(), 1.63636, epsilon = 1e-5); 209 | /// # } 210 | /// ``` 211 | fn harmonic_mean(self) -> T; 212 | 213 | /// Estimates the unbiased population variance from the provided samples 214 | /// 215 | /// # Remarks 216 | /// 217 | /// On a dataset of size `N`, `N-1` is used as a normalizer (Bessel's 218 | /// correction). 219 | /// 220 | /// Returns `f64::NAN` if data has less than two entries or if any entry is 221 | /// `f64::NAN` 222 | /// 223 | /// # Examples 224 | /// 225 | /// ``` 226 | /// use core::f64; 227 | /// use statrs::statistics::Statistics; 228 | /// 229 | /// let x = &[]; 230 | /// assert!(x.variance().is_nan()); 231 | /// 232 | /// let y = &[0.0, f64::NAN, 3.0, -2.0]; 233 | /// assert!(y.variance().is_nan()); 234 | /// 235 | /// let z = &[0.0, 3.0, -2.0]; 236 | /// assert_eq!(z.variance(), 19.0 / 3.0); 237 | /// ``` 238 | fn variance(self) -> T; 239 | 240 | /// Estimates the unbiased population standard deviation from the provided 241 | /// samples 242 | /// 243 | /// # Remarks 244 | /// 245 | /// On a dataset of size `N`, `N-1` is used as a normalizer (Bessel's 246 | /// correction). 247 | /// 248 | /// Returns `f64::NAN` if data has less than two entries or if any entry is 249 | /// `f64::NAN` 250 | /// 251 | /// # Examples 252 | /// 253 | /// ``` 254 | /// use core::f64; 255 | /// use statrs::statistics::Statistics; 256 | /// 257 | /// let x = &[]; 258 | /// assert!(x.std_dev().is_nan()); 259 | /// 260 | /// let y = &[0.0, f64::NAN, 3.0, -2.0]; 261 | /// assert!(y.std_dev().is_nan()); 262 | /// 263 | /// let z = &[0.0, 3.0, -2.0]; 264 | /// assert_eq!(z.std_dev(), (19f64 / 3.0).sqrt()); 265 | /// ``` 266 | fn std_dev(self) -> T; 267 | 268 | /// Evaluates the population variance from a full population. 269 | /// 270 | /// # Remarks 271 | /// 272 | /// On a dataset of size `N`, `N` is used as a normalizer and would thus 273 | /// be biased if applied to a subset 274 | /// 275 | /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN` 276 | /// 277 | /// # Examples 278 | /// 279 | /// ``` 280 | /// use core::f64; 281 | /// use statrs::statistics::Statistics; 282 | /// 283 | /// let x = &[]; 284 | /// assert!(x.population_variance().is_nan()); 285 | /// 286 | /// let y = &[0.0, f64::NAN, 3.0, -2.0]; 287 | /// assert!(y.population_variance().is_nan()); 288 | /// 289 | /// let z = &[0.0, 3.0, -2.0]; 290 | /// assert_eq!(z.population_variance(), 38.0 / 9.0); 291 | /// ``` 292 | fn population_variance(self) -> T; 293 | 294 | /// Evaluates the population standard deviation from a full population. 
295 | /// 296 | /// # Remarks 297 | /// 298 | /// On a dataset of size `N`, `N` is used as a normalizer and would thus 299 | /// be biased if applied to a subset 300 | /// 301 | /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN` 302 | /// 303 | /// # Examples 304 | /// 305 | /// ``` 306 | /// use core::f64; 307 | /// use statrs::statistics::Statistics; 308 | /// 309 | /// let x = &[]; 310 | /// assert!(x.population_std_dev().is_nan()); 311 | /// 312 | /// let y = &[0.0, f64::NAN, 3.0, -2.0]; 313 | /// assert!(y.population_std_dev().is_nan()); 314 | /// 315 | /// let z = &[0.0, 3.0, -2.0]; 316 | /// assert_eq!(z.population_std_dev(), (38f64 / 9.0).sqrt()); 317 | /// ``` 318 | fn population_std_dev(self) -> T; 319 | 320 | /// Estimates the unbiased population covariance between the two provided 321 | /// samples 322 | /// 323 | /// # Remarks 324 | /// 325 | /// On a dataset of size `N`, `N-1` is used as a normalizer (Bessel's 326 | /// correction). 327 | /// 328 | /// Returns `f64::NAN` if data has less than two entries or if any entry is 329 | /// `f64::NAN` 330 | /// 331 | /// # Panics 332 | /// 333 | /// If the two sample containers do not contain the same number of elements 334 | /// 335 | /// # Examples 336 | /// 337 | /// ``` 338 | /// use approx::assert_abs_diff_eq; 339 | /// 340 | /// use core::f64; 341 | /// use statrs::statistics::Statistics; 342 | /// 343 | /// # fn main() { 344 | /// let x = &[]; 345 | /// assert!(x.covariance(&[]).is_nan()); 346 | /// 347 | /// let y1 = &[0.0, f64::NAN, 3.0, -2.0]; 348 | /// let y2 = &[-5.0, 4.0, 10.0, f64::NAN]; 349 | /// assert!(y1.covariance(y2).is_nan()); 350 | /// 351 | /// let z1 = &[0.0, 3.0, -2.0]; 352 | /// let z2 = &[-5.0, 4.0, 10.0]; 353 | /// assert_abs_diff_eq!(z1.covariance(z2), -5.5, epsilon = 1e-14); 354 | /// # } 355 | /// ``` 356 | fn covariance(self, other: Self) -> T; 357 | 358 | /// Evaluates the population covariance between the two provider populations 359 | /// 360 | /// # Remarks 361 | /// 362 | /// On a dataset of size `N`, `N` is used as a normalizer and would thus be 363 | /// biased if applied to a subset 364 | /// 365 | /// Returns `f64::NAN` if data is empty or any entry is `f64::NAN` 366 | /// 367 | /// # Panics 368 | /// 369 | /// If the two sample containers do not contain the same number of elements 370 | /// 371 | /// # Examples 372 | /// 373 | /// ``` 374 | /// use approx::assert_abs_diff_eq; 375 | /// 376 | /// use core::f64; 377 | /// use statrs::statistics::Statistics; 378 | /// 379 | /// # fn main() { 380 | /// let x = &[]; 381 | /// assert!(x.population_covariance(&[]).is_nan()); 382 | /// 383 | /// let y1 = &[0.0, f64::NAN, 3.0, -2.0]; 384 | /// let y2 = &[-5.0, 4.0, 10.0, f64::NAN]; 385 | /// assert!(y1.population_covariance(y2).is_nan()); 386 | /// 387 | /// let z1 = &[0.0, 3.0, -2.0]; 388 | /// let z2 = &[-5.0, 4.0, 10.0]; 389 | /// assert_abs_diff_eq!(z1.population_covariance(z2), -11.0 / 3.0, epsilon = 1e-14); 390 | /// # } 391 | /// ``` 392 | fn population_covariance(self, other: Self) -> T; 393 | 394 | /// Estimates the quadratic mean (Root Mean Square) of the data 395 | /// 396 | /// # Remarks 397 | /// 398 | /// Returns `f64::NAN` if data is empty or any entry is `f64::NAN` 399 | /// 400 | /// # Examples 401 | /// 402 | /// ``` 403 | /// use approx::assert_abs_diff_eq; 404 | /// 405 | /// use core::f64; 406 | /// use statrs::statistics::Statistics; 407 | /// 408 | /// # fn main() { 409 | /// let x = &[]; 410 | /// assert!(x.quadratic_mean().is_nan()); 411 | /// 412 | /// let y = 
&[0.0, f64::NAN, 3.0, -2.0]; 413 | /// assert!(y.quadratic_mean().is_nan()); 414 | /// 415 | /// let z = &[0.0, 3.0, -2.0]; 416 | /// // test value from online calculator, could be more accurate 417 | /// assert_abs_diff_eq!(z.quadratic_mean(), 2.08167, epsilon = 1e-5); 418 | /// # } 419 | /// ``` 420 | fn quadratic_mean(self) -> T; 421 | } 422 | -------------------------------------------------------------------------------- /src/stats_tests/fisher.rs: -------------------------------------------------------------------------------- 1 | use super::Alternative; 2 | use crate::distribution::{Discrete, DiscreteCDF, Hypergeometric, HypergeometricError}; 3 | 4 | const EPSILON: f64 = 1.0 - 1e-4; 5 | 6 | /// Binary search in two-sided test with starting bound as argument 7 | fn binary_search( 8 | n: u64, 9 | n1: u64, 10 | n2: u64, 11 | mode: u64, 12 | p_exact: f64, 13 | epsilon: f64, 14 | upper: bool, 15 | ) -> u64 { 16 | let (mut min_val, mut max_val) = { if upper { (mode, n) } else { (0, mode) } }; 17 | 18 | let population = n1 + n2; 19 | let successes = n1; 20 | let draws = n; 21 | let dist = Hypergeometric::new(population, successes, draws).unwrap(); 22 | 23 | let mut guess = 0; 24 | loop { 25 | if max_val - min_val <= 1 { 26 | break; 27 | } 28 | guess = { 29 | if max_val == min_val + 1 && guess == min_val { 30 | max_val 31 | } else { 32 | (max_val + min_val) / 2 33 | } 34 | }; 35 | 36 | let ng = { if upper { guess - 1 } else { guess + 1 } }; 37 | 38 | let pmf_comp = dist.pmf(ng); 39 | let p_guess = dist.pmf(guess); 40 | if p_guess <= p_exact && p_exact < pmf_comp { 41 | break; 42 | } 43 | if p_guess < p_exact { 44 | max_val = guess 45 | } else { 46 | min_val = guess 47 | } 48 | } 49 | 50 | if guess == 0 { 51 | guess = min_val 52 | } 53 | if upper { 54 | loop { 55 | if guess > 0 && dist.pmf(guess) < p_exact * epsilon { 56 | guess -= 1; 57 | } else { 58 | break; 59 | } 60 | } 61 | loop { 62 | if dist.pmf(guess) > p_exact / epsilon { 63 | guess += 1; 64 | } else { 65 | break; 66 | } 67 | } 68 | } else { 69 | loop { 70 | if dist.pmf(guess) < p_exact * epsilon { 71 | guess += 1; 72 | } else { 73 | break; 74 | } 75 | } 76 | loop { 77 | if guess > 0 && dist.pmf(guess) > p_exact / epsilon { 78 | guess -= 1; 79 | } else { 80 | break; 81 | } 82 | } 83 | } 84 | guess 85 | } 86 | 87 | #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] 88 | #[non_exhaustive] 89 | pub enum FishersExactTestError { 90 | /// The table does not describe a valid [`Hypergeometric`] distribution. 91 | /// Make sure that the contingency table stores the data in row-major order. 92 | TableInvalidForHypergeometric(HypergeometricError), 93 | } 94 | 95 | impl core::fmt::Display for FishersExactTestError { 96 | #[cfg_attr(coverage_nightly, coverage(off))] 97 | fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { 98 | match self { 99 | FishersExactTestError::TableInvalidForHypergeometric(hg_err) => { 100 | writeln!( 101 | f, 102 | "Cannot create a Hypergeometric distribution from the data in the contingency table." 103 | )?; 104 | writeln!(f, "Is it in row-major order?")?; 105 | write!(f, "Inner error: '{hg_err}'") 106 | } 107 | } 108 | } 109 | } 110 | 111 | #[cfg(feature = "std")] 112 | impl std::error::Error for FishersExactTestError {} 113 | 114 | impl From for FishersExactTestError { 115 | fn from(value: HypergeometricError) -> Self { 116 | Self::TableInvalidForHypergeometric(value) 117 | } 118 | } 119 | 120 | /// Perform a Fisher exact test on a 2x2 contingency table. 
121 | /// Based on scipy's fisher test: 122 | /// Expects a table in row-major order 123 | /// Returns the [odds ratio](https://en.wikipedia.org/wiki/Odds_ratio) and p_value 124 | /// # Examples 125 | /// 126 | /// ``` 127 | /// use statrs::stats_tests::fishers_exact_with_odds_ratio; 128 | /// use statrs::stats_tests::Alternative; 129 | /// let table = [3, 5, 4, 50]; 130 | /// let (odds_ratio, p_value) = fishers_exact_with_odds_ratio(&table, Alternative::Less).unwrap(); 131 | /// ``` 132 | pub fn fishers_exact_with_odds_ratio( 133 | table: &[u64; 4], 134 | alternative: Alternative, 135 | ) -> Result<(f64, f64), FishersExactTestError> { 136 | // If both values in a row or column are zero, p-value is 1 and odds ratio is NaN. 137 | match table { 138 | [0, _, 0, _] | [_, 0, _, 0] => return Ok((f64::NAN, 1.0)), // both 0 in a column 139 | [0, 0, _, _] | [_, _, 0, 0] => return Ok((f64::NAN, 1.0)), // both 0 in a row 140 | _ => (), // continue 141 | } 142 | 143 | let odds_ratio = { 144 | if table[1] > 0 && table[2] > 0 { 145 | (table[0] * table[3]) as f64 / (table[1] * table[2]) as f64 146 | } else { 147 | f64::INFINITY 148 | } 149 | }; 150 | 151 | let p_value = fishers_exact(table, alternative)?; 152 | Ok((odds_ratio, p_value)) 153 | } 154 | 155 | /// Perform a Fisher exact test on a 2x2 contingency table. 156 | /// Based on scipy's fisher test: 157 | /// Expects a table in row-major order 158 | /// Returns only the p_value 159 | /// # Examples 160 | /// 161 | /// ``` 162 | /// use statrs::stats_tests::fishers_exact; 163 | /// use statrs::stats_tests::Alternative; 164 | /// let table = [3, 5, 4, 50]; 165 | /// let p_value = fishers_exact(&table, Alternative::Less).unwrap(); 166 | /// ``` 167 | pub fn fishers_exact( 168 | table: &[u64; 4], 169 | alternative: Alternative, 170 | ) -> Result<f64, FishersExactTestError> { 171 | // If both values in a row or column are zero, the p-value is 1 and the odds ratio is NaN.
172 | match table { 173 | [0, _, 0, _] | [_, 0, _, 0] => return Ok(1.0), // both 0 in a column 174 | [0, 0, _, _] | [_, _, 0, 0] => return Ok(1.0), // both 0 in a row 175 | _ => (), // continue 176 | } 177 | 178 | let n1 = table[0] + table[1]; 179 | let n2 = table[2] + table[3]; 180 | let n = table[0] + table[2]; 181 | 182 | let p_value = { 183 | let population = n1 + n2; 184 | let successes = n1; 185 | 186 | match alternative { 187 | Alternative::Less => { 188 | let draws = n; 189 | let dist = Hypergeometric::new(population, successes, draws)?; 190 | dist.cdf(table[0]) 191 | } 192 | Alternative::Greater => { 193 | let draws = table[1] + table[3]; 194 | let dist = Hypergeometric::new(population, successes, draws)?; 195 | dist.cdf(table[1]) 196 | } 197 | Alternative::TwoSided => { 198 | let draws = n; 199 | let dist = Hypergeometric::new(population, successes, draws)?; 200 | 201 | let p_exact = dist.pmf(table[0]); 202 | let mode = ((n + 1) * (n1 + 1)) / (n1 + n2 + 2); 203 | let p_mode = dist.pmf(mode); 204 | 205 | if (p_exact - p_mode).abs() / p_exact.max(p_mode) <= 1.0 - EPSILON { 206 | return Ok(1.0); 207 | } 208 | 209 | if table[0] < mode { 210 | let p_lower = dist.cdf(table[0]); 211 | if dist.pmf(n) > p_exact / EPSILON { 212 | return Ok(p_lower); 213 | } 214 | let guess = binary_search(n, n1, n2, mode, p_exact, EPSILON, true); 215 | return Ok(p_lower + 1.0 - dist.cdf(guess - 1)); 216 | } 217 | 218 | let p_upper = 1.0 - dist.cdf(table[0] - 1); 219 | if dist.pmf(0) > p_exact / EPSILON { 220 | return Ok(p_upper); 221 | } 222 | 223 | let guess = binary_search(n, n1, n2, mode, p_exact, EPSILON, false); 224 | p_upper + dist.cdf(guess) 225 | } 226 | } 227 | }; 228 | 229 | Ok(p_value.min(1.0)) 230 | } 231 | 232 | #[rustfmt::skip] 233 | #[cfg(test)] 234 | mod tests { 235 | use super::*; 236 | use crate::prec; 237 | 238 | /// Test fishers_exact by comparing against values from scipy.
239 | #[test] 240 | fn test_fishers_exact() { 241 | let cases = [ 242 | ( 243 | [3, 5, 4, 50], 244 | 0.9963034765672599, 245 | 0.03970749246529277, 246 | 0.03970749246529276, 247 | ), 248 | ( 249 | [61, 118, 2, 1], 250 | 0.27535061623455315, 251 | 0.9598172545684959, 252 | 0.27535061623455315, 253 | ), 254 | ( 255 | [172, 46, 90, 127], 256 | 1.0, 257 | 6.662405187351769e-16, 258 | 9.041009036528785e-16, 259 | ), 260 | ( 261 | [127, 38, 112, 43], 262 | 0.8637599357870167, 263 | 0.20040942958644145, 264 | 0.3687862842650179, 265 | ), 266 | ( 267 | [186, 177, 111, 154], 268 | 0.9918518696328176, 269 | 0.012550663906725129, 270 | 0.023439141644624434, 271 | ), 272 | ( 273 | [137, 49, 135, 183], 274 | 0.999999999998533, 275 | 5.6517533666400615e-12, 276 | 8.870999836202932e-12, 277 | ), 278 | ( 279 | [37, 115, 37, 152], 280 | 0.8834621182590621, 281 | 0.17638403366123565, 282 | 0.29400927608021704, 283 | ), 284 | ( 285 | [124, 117, 119, 175], 286 | 0.9956704915461392, 287 | 0.007134712391455461, 288 | 0.011588218284387445, 289 | ), 290 | ( 291 | [70, 114, 41, 118], 292 | 0.9945558498544903, 293 | 0.010384865876586255, 294 | 0.020438291037108678, 295 | ), 296 | ( 297 | [173, 21, 89, 7], 298 | 0.2303739114068352, 299 | 0.8808002774812677, 300 | 0.4027047267306024, 301 | ), 302 | ( 303 | [18, 147, 123, 58], 304 | 4.077820702304103e-29, 305 | 0.9999999999999817, 306 | 0.0, 307 | ), 308 | ( 309 | [116, 20, 92, 186], 310 | 0.9999999999998267, 311 | 6.598118571034892e-25, 312 | 8.164831402188242e-25, 313 | ), 314 | ( 315 | [9, 22, 44, 38], 316 | 0.01584272038710196, 317 | 0.9951463496539362, 318 | 0.021581786662999272, 319 | ), 320 | ( 321 | [9, 101, 135, 7], 322 | 3.3336213533847776e-50, 323 | 1.0, 324 | 3.3336213533847776e-50, 325 | ), 326 | ( 327 | [153, 27, 191, 144], 328 | 0.9999999999950817, 329 | 2.473736787266208e-11, 330 | 3.185816623300107e-11, 331 | ), 332 | ( 333 | [111, 195, 189, 69], 334 | 6.665245982898848e-19, 335 | 0.9999999999994574, 336 | 1.0735744915712542e-18, 337 | ), 338 | ( 339 | [125, 21, 31, 131], 340 | 0.99999999999974, 341 | 9.720661317939016e-34, 342 | 1.0352129312860277e-33, 343 | ), 344 | ( 345 | [201, 192, 69, 179], 346 | 0.9999999988714893, 347 | 3.1477232259550017e-09, 348 | 4.761075937088169e-09, 349 | ), 350 | ( 351 | [124, 138, 159, 160], 352 | 0.30153826772785475, 353 | 0.7538974235759873, 354 | 0.5601766196310243, 355 | ), 356 | ]; 357 | 358 | for (table, less_expected, greater_expected, two_sided_expected) in cases.iter() { 359 | for (alternative, expected) in [ 360 | Alternative::Less, 361 | Alternative::Greater, 362 | Alternative::TwoSided, 363 | ] 364 | .iter() 365 | .zip([less_expected, greater_expected, two_sided_expected]) 366 | { 367 | let p_value = fishers_exact(table, *alternative).unwrap(); 368 | prec::assert_relative_eq!(p_value, *expected); 369 | } 370 | } 371 | } 372 | 373 | #[test] 374 | fn test_fishers_exact_for_trivial() { 375 | let cases = [[0, 0, 1, 2], [1, 2, 0, 0], [1, 0, 2, 0], [0, 1, 0, 2]]; 376 | 377 | for table in cases.iter() { 378 | assert_eq!(fishers_exact(table, Alternative::Less).unwrap(), 1.0) 379 | } 380 | } 381 | 382 | #[test] 383 | fn test_fishers_exact_with_odds() { 384 | let table = [3, 5, 4, 50]; 385 | let (odds_ratio, p_value) = 386 | fishers_exact_with_odds_ratio(&table, Alternative::Less).unwrap(); 387 | prec::assert_abs_diff_eq!(p_value, 0.9963034765672599); 388 | prec::assert_abs_diff_eq!(odds_ratio, 7.5); 389 | } 390 | } 391 | -------------------------------------------------------------------------------- 
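As a rough usage sketch of the two functions defined above (not part of fisher.rs itself): the table is passed flattened in row-major order together with an `Alternative`, and the expected values below are taken from the crate's own test cases for the table `[3, 5, 4, 50]`.

```rust
use statrs::stats_tests::{fishers_exact, fishers_exact_with_odds_ratio, Alternative};

fn main() {
    // Contingency table [[3, 5], [4, 50]] flattened in row-major order.
    let table = [3u64, 5, 4, 50];

    // Two-sided p-value; the test table above lists ~0.0397 for this input.
    let p_two_sided = fishers_exact(&table, Alternative::TwoSided).unwrap();
    assert!((p_two_sided - 0.0397).abs() < 1e-3);

    // Odds ratio is (3 * 50) / (5 * 4) = 7.5; the one-sided "less" p-value is ~0.9963.
    let (odds_ratio, p_less) =
        fishers_exact_with_odds_ratio(&table, Alternative::Less).unwrap();
    assert!((odds_ratio - 7.5).abs() < 1e-12);
    assert!((p_less - 0.9963).abs() < 1e-3);

    println!("odds ratio = {odds_ratio}, two-sided p = {p_two_sided}");
}
```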
/src/function/kernel.rs: -------------------------------------------------------------------------------- 1 | //! Kernel functions for use in kernel-based methods such as 2 | //! kernel density estimation (KDE), local regression, and smoothing. 3 | //! 4 | //! Each kernel maps a normalized distance `x` (often `|x_i - x_0| / h`) 5 | //! to a nonnegative weight. Kernels with bounded support return zero 6 | //! outside a finite interval (e.g., `[-1, 1]`). 7 | //! 8 | //! # Implemented Kernels 9 | //! | Kernel | Formula | Support | 10 | //! |---------|----------|----------| 11 | //! | Gaussian | `exp(-0.5 * x²) / √(2π)` | (-∞, ∞) | 12 | //! | Epanechnikov | `0.75 * (1 - x²)` | [-1, 1] | 13 | //! | Triangular | `1 - |x|` | [-1, 1] | 14 | //! | Tricube | `(1 - |x|³)³` | [-1, 1] | 15 | //! | Quartic (biweight) | `(15/16) * (1 - x²)²` | [-1, 1] | 16 | //! | Uniform | `0.5` | [-1, 1] | 17 | //! | Cosine | `(π/4) * cos(πx/2)` | [-1, 1] | 18 | //! | Logistic | `1 / (2 + exp(x) + exp(-x))` | (-∞, ∞) | 19 | //! | Sigmoid | `(2 / π) * (1 / (exp(πx) + exp(-πx)))` | (-∞, ∞) | 20 | //! 21 | //! # Example 22 | //! ``` 23 | //! use statrs::function::kernel::{Kernel, Gaussian, Epanechnikov}; 24 | //! 25 | //! let g = Gaussian; 26 | //! let e = Epanechnikov; 27 | //! assert!((g.evaluate(0.0) - 0.39894).abs() < 1e-5); 28 | //! assert!((e.evaluate(0.0) - 0.75).abs() < 1e-12); 29 | //! ``` 30 | 31 | use core::f64::consts::{FRAC_PI_2, PI}; 32 | 33 | /// Common interface for kernel functions used in KDE and smoothing. 34 | pub trait Kernel { 35 | /// Evaluate the kernel at normalized distance `x`. 36 | fn evaluate(&self, x: f64) -> f64; 37 | 38 | /// Evaluate the kernel with bandwidth scaling. 39 | /// 40 | /// The result is scaled by `1 / bandwidth` to ensure 41 | /// that the kernel integrates to 1 after scaling. 42 | fn evaluate_with_bandwidth(&self, x: f64, bandwidth: f64) -> f64 { 43 | self.evaluate(x / bandwidth) / bandwidth 44 | } 45 | 46 | /// Returns the support of the kernel if bounded (e.g. `[-1, 1]`). 
47 | fn support(&self) -> Option<(f64, f64)> { 48 | None 49 | } 50 | } 51 | 52 | /// Gaussian kernel: (1 / √(2π)) * exp(-0.5 * x²) 53 | #[derive(Debug, Clone, Copy, Default, PartialEq)] 54 | pub struct Gaussian; 55 | 56 | impl Kernel for Gaussian { 57 | fn evaluate(&self, x: f64) -> f64 { 58 | (-(x * x) / 2.0).exp() / (2.0 * PI).sqrt() 59 | } 60 | } 61 | 62 | /// Epanechnikov kernel: ¾(1 - x²) for |x| ≤ 1 63 | #[derive(Debug, Clone, Copy, Default, PartialEq)] 64 | pub struct Epanechnikov; 65 | 66 | impl Kernel for Epanechnikov { 67 | fn evaluate(&self, x: f64) -> f64 { 68 | let a = x.abs(); 69 | if a <= 1.0 { 0.75 * (1.0 - a * a) } else { 0.0 } 70 | } 71 | 72 | fn support(&self) -> Option<(f64, f64)> { 73 | Some((-1.0, 1.0)) 74 | } 75 | } 76 | 77 | /// Triangular kernel: (1 - |x|) for |x| ≤ 1 78 | #[derive(Debug, Clone, Copy, Default, PartialEq)] 79 | pub struct Triangular; 80 | 81 | impl Kernel for Triangular { 82 | fn evaluate(&self, x: f64) -> f64 { 83 | let a = x.abs(); 84 | if a <= 1.0 { 1.0 - a } else { 0.0 } 85 | } 86 | 87 | fn support(&self) -> Option<(f64, f64)> { 88 | Some((-1.0, 1.0)) 89 | } 90 | } 91 | 92 | /// Tricube kernel: (1 - |x|³)³ for |x| ≤ 1 93 | #[derive(Debug, Clone, Copy, Default, PartialEq)] 94 | pub struct Tricube; 95 | 96 | impl Kernel for Tricube { 97 | fn evaluate(&self, x: f64) -> f64 { 98 | let a = x.abs(); 99 | if a <= 1.0 { 100 | let t = 1.0 - a.powi(3); 101 | t.powi(3) 102 | } else { 103 | 0.0 104 | } 105 | } 106 | 107 | fn support(&self) -> Option<(f64, f64)> { 108 | Some((-1.0, 1.0)) 109 | } 110 | } 111 | 112 | /// Quartic (biweight) kernel: (15/16) * (1 - x²)² for |x| ≤ 1 113 | #[derive(Debug, Clone, Copy, Default, PartialEq)] 114 | pub struct Quartic; 115 | 116 | impl Kernel for Quartic { 117 | fn evaluate(&self, x: f64) -> f64 { 118 | let a = x.abs(); 119 | if a <= 1.0 { 120 | let t = 1.0 - a * a; 121 | (15.0 / 16.0) * t * t 122 | } else { 123 | 0.0 124 | } 125 | } 126 | 127 | fn support(&self) -> Option<(f64, f64)> { 128 | Some((-1.0, 1.0)) 129 | } 130 | } 131 | 132 | /// Uniform (rectangular) kernel: 0.5 for |x| ≤ 1, 0 otherwise 133 | #[derive(Debug, Clone, Copy, Default, PartialEq)] 134 | pub struct Uniform; 135 | 136 | impl Kernel for Uniform { 137 | fn evaluate(&self, x: f64) -> f64 { 138 | if x.abs() <= 1.0 { 0.5 } else { 0.0 } 139 | } 140 | 141 | fn support(&self) -> Option<(f64, f64)> { 142 | Some((-1.0, 1.0)) 143 | } 144 | } 145 | 146 | /// Cosine kernel: (π/4) * cos(πx/2) for |x| ≤ 1, 0 otherwise 147 | /// 148 | /// This kernel integrates to 1 over [-1, 1]. 149 | #[derive(Debug, Clone, Copy, Default, PartialEq)] 150 | pub struct Cosine; 151 | 152 | impl Kernel for Cosine { 153 | fn evaluate(&self, x: f64) -> f64 { 154 | let a = x.abs(); 155 | if a <= 1.0 { 156 | (PI / 4.0) * (FRAC_PI_2 * a).cos() 157 | } else { 158 | 0.0 159 | } 160 | } 161 | 162 | fn support(&self) -> Option<(f64, f64)> { 163 | Some((-1.0, 1.0)) 164 | } 165 | } 166 | 167 | /// Logistic kernel: 1 / (2 + exp(x) + exp(-x)) 168 | #[derive(Debug, Clone, Copy, Default, PartialEq)] 169 | pub struct Logistic; 170 | 171 | impl Kernel for Logistic { 172 | fn evaluate(&self, x: f64) -> f64 { 173 | 1.0 / (2.0 + x.exp() + (-x).exp()) 174 | } 175 | } 176 | 177 | /// Sigmoid kernel: (1 / (π * cosh(πx))) ≈ (2 / π) * (1 / (exp(πx) + exp(-πx))) 178 | /// 179 | /// Note: Integrates to 1/π over (-∞, ∞). 
180 | #[derive(Debug, Clone, Copy, Default, PartialEq)] 181 | pub struct Sigmoid; 182 | 183 | impl Kernel for Sigmoid { 184 | fn evaluate(&self, x: f64) -> f64 { 185 | (2.0 / PI) * 1.0 / ((PI * x).exp() + (-PI * x).exp()) 186 | } 187 | } 188 | 189 | #[cfg(test)] 190 | mod tests { 191 | use super::*; 192 | use crate::prec::assert_abs_diff_eq; 193 | 194 | fn integrate<F: Fn(f64) -> f64>(f: F, a: f64, b: f64, n: usize) -> f64 { 195 | let h = (b - a) / n as f64; 196 | let mut sum = 0.0; 197 | for i in 0..n { 198 | let x0 = a + i as f64 * h; 199 | let x1 = a + (i + 1) as f64 * h; 200 | sum += 0.5 * (f(x0) + f(x1)) * h; 201 | } 202 | sum 203 | } 204 | 205 | #[test] 206 | fn uniform_behavior() { 207 | let k = Uniform; 208 | assert_eq!(k.evaluate(0.0), 0.5); 209 | assert_eq!(k.evaluate(0.8), 0.5); 210 | assert_eq!(k.evaluate(1.0), 0.5); 211 | assert_eq!(k.evaluate(1.01), 0.0); 212 | assert_eq!(k.evaluate(-1.01), 0.0); 213 | // symmetry 214 | assert_abs_diff_eq!(k.evaluate(0.5), k.evaluate(-0.5), epsilon = 1e-15); 215 | // normalization check 216 | let integral = integrate(|u| k.evaluate(u), -1.0, 1.0, 10_000); 217 | assert_abs_diff_eq!(integral, 1.0, epsilon = 1e-3); 218 | } 219 | 220 | #[test] 221 | fn cosine_behavior() { 222 | let k = Cosine; 223 | assert_abs_diff_eq!(k.evaluate(0.0), PI / 4.0, epsilon = 1e-12); 224 | assert_abs_diff_eq!(k.evaluate(1.0), 0.0, epsilon = 1e-15); 225 | assert_abs_diff_eq!(k.evaluate(-1.0), 0.0, epsilon = 1e-15); 226 | assert!(k.evaluate(0.25) > k.evaluate(0.75)); 227 | assert_abs_diff_eq!(k.evaluate(0.3), k.evaluate(-0.3), epsilon = 1e-15); 228 | let integral = integrate(|u| k.evaluate(u), -1.0, 1.0, 10_000); 229 | assert_abs_diff_eq!(integral, 1.0, epsilon = 1e-3); 230 | } 231 | 232 | #[test] 233 | fn logistic_behavior() { 234 | let k = Logistic; 235 | assert_abs_diff_eq!(k.evaluate(0.0), 0.25, epsilon = 1e-12); 236 | assert!(k.evaluate(0.0) > k.evaluate(2.0)); 237 | assert_abs_diff_eq!(k.evaluate(0.5), k.evaluate(-0.5), epsilon = 1e-15); 238 | // integral over a wide range should approximate 1.0 239 | let integral = integrate(|u| k.evaluate(u), -10.0, 10.0, 50_000); 240 | assert_abs_diff_eq!(integral, 1.0, epsilon = 1e-3); 241 | } 242 | 243 | #[test] 244 | fn sigmoid_behavior() { 245 | let k = Sigmoid; 246 | assert_abs_diff_eq!(k.evaluate(0.0), 1.0 / PI, epsilon = 1e-12); 247 | assert!(k.evaluate(0.0) > k.evaluate(2.0)); 248 | assert_abs_diff_eq!(k.evaluate(0.5), k.evaluate(-0.5), epsilon = 1e-15); 249 | let integral = integrate(|u| k.evaluate(u), -10.0, 10.0, 50_000); 250 | assert_abs_diff_eq!(integral, 1.0 / PI, epsilon = 1e-3); 251 | } 252 | 253 | #[test] 254 | fn tricube_basic_properties() { 255 | let k = Tricube; 256 | assert_abs_diff_eq!(k.evaluate(0.5), k.evaluate(-0.5), epsilon = 1e-15); 257 | assert_eq!(k.evaluate(1.0), 0.0); 258 | assert_eq!(k.evaluate(-1.0), 0.0); 259 | assert_eq!(k.evaluate(0.0), 1.0); 260 | assert!(k.evaluate(0.25) > k.evaluate(0.5)); 261 | assert!(k.evaluate(0.5) > k.evaluate(0.75)); 262 | } 263 | 264 | #[test] 265 | fn epanechnikov_behavior() { 266 | let k = Epanechnikov; 267 | assert_abs_diff_eq!(k.evaluate(0.3), k.evaluate(-0.3), epsilon = 1e-15); 268 | assert_eq!(k.evaluate(1.0), 0.0); 269 | assert_eq!(k.evaluate(-1.0), 0.0); 270 | assert!(k.evaluate(0.0) > k.evaluate(0.8)); 271 | assert!(k.evaluate(0.5) > 0.0); 272 | assert!(k.evaluate(0.5) < k.evaluate(0.0)); 273 | } 274 | 275 | #[test] 276 | fn quartic_behavior() { 277 | let k = Quartic; 278 | assert_abs_diff_eq!(k.evaluate(0.0), 15.0 / 16.0, epsilon = 1e-12); 279 |
assert_eq!(k.evaluate(1.0), 0.0); 280 | assert_eq!(k.evaluate(-1.0), 0.0); 281 | assert_abs_diff_eq!(k.evaluate(0.3), k.evaluate(-0.3), epsilon = 1e-15); 282 | assert!(k.evaluate(0.25) > k.evaluate(0.75)); 283 | assert_eq!(k.evaluate(1.1), 0.0); 284 | } 285 | 286 | #[test] 287 | fn triangular_behavior() { 288 | let k = Triangular; 289 | assert_eq!(k.evaluate(0.0), 1.0); 290 | assert_eq!(k.evaluate(1.0), 0.0); 291 | assert_eq!(k.evaluate(-1.0), 0.0); 292 | assert_abs_diff_eq!(k.evaluate(0.3), k.evaluate(-0.3), epsilon = 1e-15); 293 | assert!(k.evaluate(0.25) > k.evaluate(0.75)); 294 | assert_eq!(k.evaluate(1.2), 0.0); 295 | } 296 | 297 | #[test] 298 | fn gaussian_behavior() { 299 | let k = Gaussian; 300 | assert_abs_diff_eq!(k.evaluate(0.5), k.evaluate(-0.5), epsilon = 1e-15); 301 | assert!(k.evaluate(0.0) > k.evaluate(1.0)); 302 | assert!(k.evaluate(2.0) < 0.2); 303 | for u in [-3.0, -1.0, 0.0, 1.0, 3.0] { 304 | assert!(k.evaluate(u) >= 0.0); 305 | } 306 | } 307 | 308 | #[test] 309 | fn kernel_trait_usage() { 310 | struct Linear; 311 | impl Kernel for Linear { 312 | fn evaluate(&self, x: f64) -> f64 { 313 | (1.0 - x.abs()).max(0.0) 314 | } 315 | } 316 | 317 | let lin = Linear; 318 | assert_eq!(lin.evaluate(0.0), 1.0); 319 | assert_eq!(lin.evaluate(1.5), 0.0); 320 | 321 | let t = Tricube; 322 | let g = Gaussian; 323 | assert_eq!(t.evaluate(0.0), 1.0); 324 | assert!(g.evaluate(1.0) < 1.0); 325 | assert_abs_diff_eq!(t.evaluate(0.5), Tricube.evaluate(0.5), epsilon = 1e-15); 326 | } 327 | 328 | #[test] 329 | fn bandwidth_scaling_equivalence() { 330 | let g = Gaussian; 331 | let scaled = g.evaluate_with_bandwidth(0.5, 2.0); 332 | let manual = g.evaluate(0.25) / 2.0; 333 | assert_abs_diff_eq!(scaled, manual, epsilon = 1e-14); 334 | } 335 | 336 | #[test] 337 | fn monotonicity_samples() { 338 | let kernels: [&dyn Kernel; 4] = [&Tricube, &Epanechnikov, &Quartic, &Triangular]; 339 | let samples = [0.0_f64, 0.25, 0.5, 0.75, 0.99]; 340 | for k in kernels { 341 | let mut prev = k.evaluate(0.0); 342 | for &u in &samples[1..] { 343 | let cur = k.evaluate(u); 344 | assert!( 345 | cur <= prev + 1e-12, 346 | "kernel not nonincreasing at u={}, prev={}, cur={}", 347 | u, 348 | prev, 349 | cur 350 | ); 351 | prev = cur; 352 | } 353 | } 354 | } 355 | 356 | #[test] 357 | fn integrate_tricube_to_expected() { 358 | let k = Tricube; 359 | let integral = integrate(|u| k.evaluate(u), -1.0, 1.0, 10_000); 360 | let expected = 81.0 / 70.0; // ≈ 1.1571 361 | assert!( 362 | (integral - expected).abs() < 1e-3, 363 | "Tricube integral ≈ {}, expected ≈ {}", 364 | integral, 365 | expected 366 | ); 367 | } 368 | 369 | #[test] 370 | fn integrate_epanechnikov_to_one() { 371 | let k = Epanechnikov; 372 | let integral = integrate(|u| k.evaluate(u), -1.0, 1.0, 10_000); 373 | assert!( 374 | (integral - 1.0).abs() < 1e-3, 375 | "Epanechnikov integral ≈ {}", 376 | integral 377 | ); 378 | } 379 | 380 | #[test] 381 | fn integrate_quartic_to_one() { 382 | let k = Quartic; 383 | let integral = integrate(|u| k.evaluate(u), -1.0, 1.0, 10_000); 384 | assert!( 385 | (integral - 1.0).abs() < 1e-3, 386 | "Quartic integral ≈ {}", 387 | integral 388 | ); 389 | } 390 | 391 | #[test] 392 | fn integrate_triangular_to_one() { 393 | let k = Triangular; 394 | let integral = integrate(|u| k.evaluate(u), -1.0, 1.0, 10_000); 395 | assert!( 396 | (integral - 1.0).abs() < 1e-3, 397 | "Triangular integral ≈ {}", 398 | integral 399 | ); 400 | } 401 | } 402 | --------------------------------------------------------------------------------
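The kernels above are the building blocks for kernel density estimation: an estimate at a point averages the bandwidth-scaled kernel weight of each sample. The following is a minimal sketch of that wiring; the `kde_at` helper is hypothetical and not part of the module above, while `Kernel`, `Gaussian`, and `evaluate_with_bandwidth` come from it.

```rust
use statrs::function::kernel::{Gaussian, Kernel};

// Hypothetical helper: kernel density estimate at `x`, averaging the
// bandwidth-scaled kernel weight of each sample point.
fn kde_at<K: Kernel>(kernel: &K, samples: &[f64], bandwidth: f64, x: f64) -> f64 {
    samples
        .iter()
        .map(|&xi| kernel.evaluate_with_bandwidth(x - xi, bandwidth))
        .sum::<f64>()
        / samples.len() as f64
}

fn main() {
    let samples = [-1.2, -0.4, 0.0, 0.3, 1.1];
    let h = 0.5;
    let density = kde_at(&Gaussian, &samples, h, 0.0);
    // The Gaussian kernel is positive everywhere, so the estimate is too.
    assert!(density > 0.0);
    println!("KDE at 0.0 with bandwidth {h}: {density}");
}
```

For any kernel that integrates to 1 (the module's tests above verify this for most of them), the `1 / bandwidth` scaling in `evaluate_with_bandwidth` keeps each scaled kernel a density, so the sample average is itself a density.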