├── .gitignore ├── supply-chain ├── imports.lock ├── config.toml └── audits.toml ├── README.md ├── Cargo.toml ├── LICENSE ├── benches └── benchmarks.rs ├── CHANGELOG.md ├── Cargo.lock └── src ├── error.rs ├── stats └── tests.rs ├── stats.rs ├── tests.rs └── lib.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/*.rs.bk 3 | -------------------------------------------------------------------------------- /supply-chain/imports.lock: -------------------------------------------------------------------------------- 1 | 2 | # cargo-vet imports lock 3 | 4 | [audits] 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # linregress 2 | A Rust library providing an easy to use implementation of ordinary 3 | least squared linear regression with some basic statistics. 4 | 5 | ## Documentation 6 | 7 | [Full API documentation](https://docs.rs/linregress) 8 | 9 | ## License 10 | This project is licensed under the MIT License. 11 | See LICENSE for details. 
12 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "linregress" 3 | version = "0.5.4" 4 | authors = ["Nils Mehrtens "] 5 | edition = "2024" 6 | description = "ordinary least squared linear regression with some basic statistics" 7 | documentation = "https://docs.rs/linregress" 8 | homepage = "https://github.com/n1m3/linregress" 9 | repository = "https://github.com/n1m3/linregress" 10 | readme = "README.md" 11 | keywords = ["statistics", "regression", "ols"] 12 | license = "MIT" 13 | 14 | [dependencies] 15 | nalgebra = { version = "0.33.0", default-features = false, features = ["std"] } 16 | 17 | [dev-dependencies] 18 | tiny-bench = "0.3.0" 19 | 20 | [[bench]] 21 | name = "benchmarks" 22 | harness = false 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2019 Nils Mehrtens 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /supply-chain/config.toml: -------------------------------------------------------------------------------- 1 | 2 | # cargo-vet config file 3 | 4 | [policy.linregress] 5 | audit-as-crates-io = false 6 | 7 | [[exemptions.approx]] 8 | version = "0.5.1" 9 | criteria = "safe-to-deploy" 10 | 11 | [[exemptions.autocfg]] 12 | version = "1.1.0" 13 | criteria = "safe-to-deploy" 14 | 15 | [[exemptions.bytemuck]] 16 | version = "1.10.0" 17 | criteria = "safe-to-deploy" 18 | 19 | [[exemptions.matrixmultiply]] 20 | version = "0.3.2" 21 | criteria = "safe-to-deploy" 22 | 23 | [[exemptions.nalgebra]] 24 | version = "0.31.0" 25 | criteria = "safe-to-deploy" 26 | 27 | [[exemptions.num-complex]] 28 | version = "0.4.1" 29 | criteria = "safe-to-deploy" 30 | 31 | [[exemptions.num-integer]] 32 | version = "0.1.45" 33 | criteria = "safe-to-deploy" 34 | 35 | [[exemptions.num-rational]] 36 | version = "0.4.0" 37 | criteria = "safe-to-deploy" 38 | 39 | [[exemptions.num-traits]] 40 | version = "0.2.14" 41 | criteria = "safe-to-deploy" 42 | 43 | [[exemptions.paste]] 44 | version = "1.0.7" 45 | criteria = "safe-to-deploy" 46 | 47 | [[exemptions.rawpointer]] 48 | version = "0.2.1" 49 | criteria = "safe-to-deploy" 50 | 51 | [[exemptions.safe_arch]] 52 | version = "0.6.0" 53 | criteria = "safe-to-deploy" 54 | 55 | [[exemptions.simba]] 56 | version = "0.7.1" 57 | criteria = "safe-to-deploy" 58 | 59 | [[exemptions.typenum]] 60 | version = "1.15.0" 61 | criteria = "safe-to-deploy" 62 | 63 | [[exemptions.wide]] 64 | version = "0.7.4" 65 | criteria = "safe-to-deploy" 66 | 
-------------------------------------------------------------------------------- /benches/benchmarks.rs: -------------------------------------------------------------------------------- 1 | use std::hint::black_box; 2 | 3 | use linregress::*; 4 | 5 | fn main() { 6 | let y = vec![1., 2., 3., 4., 5.]; 7 | let x1 = vec![5., 4., 3., 2., 1.]; 8 | let x2 = vec![729.53, 439.0367, 42.054, 1., 0.]; 9 | let x3 = vec![258.589, 616.297, 215.061, 498.361, 0.]; 10 | let data = vec![("Y", y), ("X1", x1), ("X2", x2), ("X3", x3)]; 11 | let data = RegressionDataBuilder::new().build_from(data).unwrap(); 12 | let formula = "Y ~ X1 + X2 + X3"; 13 | tiny_bench::bench_labeled("formula with stats", || { 14 | FormulaRegressionBuilder::new() 15 | .data(black_box(&data)) 16 | .formula(black_box(formula)) 17 | .fit() 18 | .unwrap(); 19 | }); 20 | tiny_bench::bench_labeled("formula without stats", || { 21 | FormulaRegressionBuilder::new() 22 | .data(black_box(&data)) 23 | .formula(black_box(formula)) 24 | .fit_without_statistics() 25 | .unwrap(); 26 | }); 27 | let columns = ("Y", ["X1", "X2", "X3"]); 28 | tiny_bench::bench_labeled("data columns with stats", || { 29 | let columns = black_box(columns); 30 | FormulaRegressionBuilder::new() 31 | .data(black_box(&data)) 32 | .data_columns(columns.0, columns.1) 33 | .fit() 34 | .unwrap(); 35 | }); 36 | tiny_bench::bench_labeled("data columns without stats", || { 37 | let columns = black_box(columns); 38 | FormulaRegressionBuilder::new() 39 | .data(black_box(&data)) 40 | .data_columns(columns.0, columns.1) 41 | .fit_without_statistics() 42 | .unwrap(); 43 | }); 44 | } 45 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## Unreleased 2 | 3 | ## 0.5.4 4 | ### Changed 5 | - Update `nalgebra` to `0.33.0` 6 | - Disable `nalgebra` `macros` feature to reduce dependencies 7 | - Update project to Rust 2021 edition 8 | 9 | ## 
0.5.3 10 | ### Changed 11 | - Update `nalgebra` to `0.32.3` 12 | 13 | ## 0.5.2 14 | ### Changed 15 | - Update `nalgebra` to `0.32.2` 16 | - Replace dev dependency `criterion` with `tiny-bench` 17 | - Minor performance improvements 18 | 19 | ## 0.5.1 20 | ### Changed 21 | - Update `nalgebra` to version 0.32.1 22 | - Optimize various calculations (see [ef94ca0](https://github.com/n1m3/linregress/commit/ef94ca07ededb5d551309d581555778f71bf5136)) 23 | ### Bug fixes 24 | - Fix model fitting failure when standard error is equal to zero 25 | 26 | ## 0.5.0 27 | ### Changed 28 | - Update `nalgebra` to `0.31.0` 29 | - Fully replace `Cephes` special functions with new implementation based on implementation in `statrs` 30 | - Remove `statrs` dependency. All statistics related code is now implemented in this crate 31 | - Remove quickcheck related dev-dependencies 32 | - Port benchmarks to criterion 33 | 34 | ### Added 35 | - Added `assert_almost_eq` and `assert_slices_almost_eq` macros for use in doc tests 36 | 37 | ## 0.5.0-alpha.1 38 | ### Breaking changes 39 | - Rework API to remove `RegressionParameters` struct 40 | - `FormulaRegressionBuilder::fit_without_statistics` returns a `Vec` 41 | - The fields of `RegressionModel` and `LowLevelRegressionModel` are now private. 42 | - Appropriate accessor methods have been added. 43 | - `RegressionParameters::pairs` has been replaced with `iter_` methods on `RegressionModel` 44 | 45 | ## 0.4.4 46 | ### Added 47 | - Add `data_columns` method to `FormulaRegressionBuilder` 48 | It allows setting the regressand a regressors without using a formula string. 
49 | - Add `fit_low_level_regression_model` and `fit_low_level_regression_model_without_statistics` 50 | functions for performing a regression directly on a matrix of input data 51 | 52 | ## 0.4.3 53 | ### Changed 54 | - Update `statrs` dependency to `0.15.0` to avoid multiple versions of `nalgebra` in our dependency tree 55 | 56 | ## 0.4.2 57 | ### Changed 58 | - Update `nalgebra` to `0.27.1` in response to RUSTSEC-2021-0070 59 | - Update `statrs` to `0.14.0` 60 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 3 | version = 3 4 | 5 | [[package]] 6 | name = "approx" 7 | version = "0.5.1" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" 10 | dependencies = [ 11 | "num-traits", 12 | ] 13 | 14 | [[package]] 15 | name = "autocfg" 16 | version = "1.1.0" 17 | source = "registry+https://github.com/rust-lang/crates.io-index" 18 | checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" 19 | 20 | [[package]] 21 | name = "bytemuck" 22 | version = "1.13.1" 23 | source = "registry+https://github.com/rust-lang/crates.io-index" 24 | checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" 25 | 26 | [[package]] 27 | name = "linregress" 28 | version = "0.5.4" 29 | dependencies = [ 30 | "nalgebra", 31 | "tiny-bench", 32 | ] 33 | 34 | [[package]] 35 | name = "matrixmultiply" 36 | version = "0.3.7" 37 | source = "registry+https://github.com/rust-lang/crates.io-index" 38 | checksum = "090126dc04f95dc0d1c1c91f61bdd474b3930ca064c1edc8a849da2c6cbe1e77" 39 | dependencies = [ 40 | "autocfg", 41 | "rawpointer", 42 | ] 43 | 44 | [[package]] 45 | name = "nalgebra" 46 | version = "0.33.0" 47 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 48 | checksum = "3c4b5f057b303842cf3262c27e465f4c303572e7f6b0648f60e16248ac3397f4" 49 | dependencies = [ 50 | "approx", 51 | "matrixmultiply", 52 | "num-complex", 53 | "num-rational", 54 | "num-traits", 55 | "simba", 56 | "typenum", 57 | ] 58 | 59 | [[package]] 60 | name = "num-complex" 61 | version = "0.4.4" 62 | source = "registry+https://github.com/rust-lang/crates.io-index" 63 | checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" 64 | dependencies = [ 65 | "num-traits", 66 | ] 67 | 68 | [[package]] 69 | name = "num-integer" 70 | version = "0.1.45" 71 | source = "registry+https://github.com/rust-lang/crates.io-index" 72 | checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" 73 | dependencies = [ 74 | "autocfg", 75 | "num-traits", 76 | ] 77 | 78 | [[package]] 79 | name = "num-rational" 80 | version = "0.4.1" 81 | source = "registry+https://github.com/rust-lang/crates.io-index" 82 | checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" 83 | dependencies = [ 84 | "autocfg", 85 | "num-integer", 86 | "num-traits", 87 | ] 88 | 89 | [[package]] 90 | name = "num-traits" 91 | version = "0.2.16" 92 | source = "registry+https://github.com/rust-lang/crates.io-index" 93 | checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" 94 | dependencies = [ 95 | "autocfg", 96 | ] 97 | 98 | [[package]] 99 | name = "paste" 100 | version = "1.0.14" 101 | source = "registry+https://github.com/rust-lang/crates.io-index" 102 | checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" 103 | 104 | [[package]] 105 | name = "rawpointer" 106 | version = "0.2.1" 107 | source = "registry+https://github.com/rust-lang/crates.io-index" 108 | checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" 109 | 110 | [[package]] 111 | name = "safe_arch" 112 | version = "0.7.1" 113 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 114 | checksum = "f398075ce1e6a179b46f51bd88d0598b92b00d3551f1a2d4ac49e771b56ac354" 115 | dependencies = [ 116 | "bytemuck", 117 | ] 118 | 119 | [[package]] 120 | name = "simba" 121 | version = "0.9.0" 122 | source = "registry+https://github.com/rust-lang/crates.io-index" 123 | checksum = "b3a386a501cd104797982c15ae17aafe8b9261315b5d07e3ec803f2ea26be0fa" 124 | dependencies = [ 125 | "approx", 126 | "num-complex", 127 | "num-traits", 128 | "paste", 129 | "wide", 130 | ] 131 | 132 | [[package]] 133 | name = "tiny-bench" 134 | version = "0.3.0" 135 | source = "registry+https://github.com/rust-lang/crates.io-index" 136 | checksum = "cfda840d8557c12ecdde25485f7dc85152339a61b759f98d4f682e7cb5d75948" 137 | 138 | [[package]] 139 | name = "typenum" 140 | version = "1.16.0" 141 | source = "registry+https://github.com/rust-lang/crates.io-index" 142 | checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" 143 | 144 | [[package]] 145 | name = "wide" 146 | version = "0.7.11" 147 | source = "registry+https://github.com/rust-lang/crates.io-index" 148 | checksum = "aa469ffa65ef7e0ba0f164183697b89b854253fd31aeb92358b7b6155177d62f" 149 | dependencies = [ 150 | "bytemuck", 151 | "safe_arch", 152 | ] 153 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use std::error; 2 | use std::fmt; 3 | 4 | /// An error that can occur in this crate. 5 | /// 6 | /// Generally this error corresponds to problems with input data or fitting 7 | /// a regression model. 8 | #[derive(Debug, Clone)] 9 | #[non_exhaustive] 10 | pub enum Error { 11 | /// Number of slopes and output names is inconsistent. 12 | InconsistentSlopes(InconsistentSlopes), 13 | /// Cannot fit model without data. 14 | NoData, 15 | /// Cannot fit model without formula or data columns. 
16 | NoFormula, 17 | /// Given formula is invalid. 18 | InvalidFormula, 19 | /// Given data columns are invalid. 20 | InvalidDataColumns, 21 | /// You must specify either a formula or data columns. 22 | BothFormulaAndDataColumnsGiven, 23 | /// Requested column is not in data. (Column given as String) 24 | ColumnNotInData(String), 25 | /// A column used in the model is misising from the provided data 26 | ModelColumnNotInData(String), 27 | /// Regressor and regressand dimensions do not match. (Column given as String) 28 | RegressorRegressandDimensionMismatch(String), 29 | /// Error while processing the regression data. (Details given as String) 30 | RegressionDataError(String), 31 | /// Error while fitting the model. (Details given as String) 32 | ModelFittingError(String), 33 | /// The given vectors have inconsistent lengths 34 | InconsistentVectors, 35 | /// The RegressionModel internal state is inconsistent 36 | InconsistentRegressionModel, 37 | } 38 | 39 | #[derive(Debug, Clone, Copy)] 40 | pub struct InconsistentSlopes { 41 | output_name_count: usize, 42 | slope_count: usize, 43 | } 44 | 45 | impl InconsistentSlopes { 46 | pub(crate) fn new(output_name_count: usize, slope_count: usize) -> Self { 47 | Self { 48 | output_name_count, 49 | slope_count, 50 | } 51 | } 52 | 53 | pub fn get_output_name_count(&self) -> usize { 54 | self.output_name_count 55 | } 56 | 57 | pub fn get_slope_count(&self) -> usize { 58 | self.slope_count 59 | } 60 | } 61 | 62 | impl error::Error for Error {} 63 | 64 | impl fmt::Display for Error { 65 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 66 | match self { 67 | Error::InconsistentSlopes(inconsistent_slopes) => write!( 68 | f, 69 | "Number of slopes and output names is inconsistent. 
{} outputs != {} slopes", 70 | inconsistent_slopes.get_output_name_count(), 71 | inconsistent_slopes.get_slope_count() 72 | ), 73 | Error::NoData => write!(f, "Cannot fit model without data"), 74 | Error::NoFormula => write!(f, "Cannot fit model without formula"), 75 | Error::InvalidDataColumns => write!(f, "Invalid data columns"), 76 | Error::InvalidFormula => write!( 77 | f, 78 | "Invalid formula. Expected formula of the form 'y ~ x1 + x2'" 79 | ), 80 | Error::BothFormulaAndDataColumnsGiven => { 81 | write!(f, "You must specify either a formula or data columns") 82 | } 83 | Error::ColumnNotInData(column) => { 84 | write!(f, "Requested column {} is not in the data", column) 85 | } 86 | Error::RegressorRegressandDimensionMismatch(column) => write!( 87 | f, 88 | "Regressor dimensions for {} do not match regressand dimensions", 89 | column 90 | ), 91 | Error::RegressionDataError(detail) => { 92 | write!(f, "Error while processing the regression data: {}", detail) 93 | } 94 | Error::ModelFittingError(detail) => { 95 | write!(f, "Error while fitting the model: {}", detail) 96 | } 97 | Error::ModelColumnNotInData(column) => write!( 98 | f, 99 | "The column {} used in the model is misising from the provided data", 100 | column 101 | ), 102 | Error::InconsistentVectors => write!(f, "The given vectors have inconsistent lengths"), 103 | Error::InconsistentRegressionModel => write!( 104 | f, 105 | concat!( 106 | "The RegressionModel internal state is inconsistent:", 107 | " The number of regressor names and values differ." 
108 | ) 109 | ), 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /supply-chain/audits.toml: -------------------------------------------------------------------------------- 1 | 2 | # cargo-vet audits file 3 | 4 | [[audits.bytemuck]] 5 | who = "Nils Mehrtens" 6 | criteria = "safe-to-deploy" 7 | delta = "1.10.0 -> 1.12.1" 8 | 9 | [[audits.bytemuck]] 10 | who = "Nils Mehrtens" 11 | criteria = "safe-to-deploy" 12 | delta = "1.12.1 -> 1.12.2" 13 | 14 | [[audits.bytemuck]] 15 | who = "Nils Mehrtens" 16 | criteria = "safe-to-deploy" 17 | delta = "1.12.2 -> 1.13.1" 18 | 19 | [[audits.matrixmultiply]] 20 | who = "Nils Mehrtens" 21 | criteria = "safe-to-deploy" 22 | delta = "0.3.2 -> 0.3.7" 23 | 24 | [[audits.nalgebra]] 25 | who = "Nils Mehrtens" 26 | criteria = "safe-to-deploy" 27 | delta = "0.31.0 -> 0.31.3" 28 | 29 | [[audits.nalgebra]] 30 | who = "Nils Mehrtens" 31 | criteria = "safe-to-deploy" 32 | delta = "0.31.3 -> 0.32.1" 33 | 34 | [[audits.nalgebra]] 35 | who = "Nils Mehrtens" 36 | criteria = "safe-to-deploy" 37 | delta = "0.32.1 -> 0.32.2" 38 | 39 | [[audits.nalgebra]] 40 | who = "Nils Mehrtens" 41 | criteria = "safe-to-deploy" 42 | delta = "0.32.2 -> 0.32.3" 43 | 44 | [[audits.nalgebra]] 45 | who = "Nils Mehrtens" 46 | criteria = "safe-to-deploy" 47 | delta = "0.32.3 -> 0.33.0" 48 | 49 | [[audits.nalgebra-macros]] 50 | who = "Nils Mehrtens" 51 | criteria = "safe-to-deploy" 52 | delta = "0.1.0 -> 0.2.0" 53 | 54 | [[audits.nalgebra-macros]] 55 | who = "Nils Mehrtens" 56 | criteria = "safe-to-deploy" 57 | delta = "0.2.0 -> 0.2.1" 58 | 59 | [[audits.num-complex]] 60 | who = "Nils Mehrtens" 61 | criteria = "safe-to-deploy" 62 | delta = "0.4.1 -> 0.4.2" 63 | 64 | [[audits.num-complex]] 65 | who = "Nils Mehrtens" 66 | criteria = "safe-to-deploy" 67 | delta = "0.4.2 -> 0.4.4" 68 | 69 | [[audits.num-rational]] 70 | who = "Nils Mehrtens" 71 | criteria = "safe-to-deploy" 72 | delta = "0.4.0 -> 0.4.1" 73 | 74 | 
[[audits.num-traits]] 75 | who = "Nils Mehrtens" 76 | criteria = "safe-to-deploy" 77 | delta = "0.2.14 -> 0.2.15" 78 | 79 | [[audits.num-traits]] 80 | who = "Nils Mehrtens" 81 | criteria = "safe-to-deploy" 82 | delta = "0.2.15 -> 0.2.16" 83 | 84 | [[audits.paste]] 85 | who = "Nils Mehrtens" 86 | criteria = "safe-to-deploy" 87 | delta = "1.0.7 -> 1.0.9" 88 | 89 | [[audits.paste]] 90 | who = "Nils Mehrtens" 91 | criteria = "safe-to-deploy" 92 | delta = "1.0.9 -> 1.0.14" 93 | 94 | [[audits.proc-macro2]] 95 | who = "Nils Mehrtens" 96 | criteria = "safe-to-deploy" 97 | delta = "1.0.30 -> 1.0.47" 98 | 99 | [[audits.proc-macro2]] 100 | who = "Nils Mehrtens" 101 | criteria = "safe-to-deploy" 102 | delta = "1.0.37 -> 1.0.40" 103 | 104 | [[audits.proc-macro2]] 105 | who = "Nils Mehrtens" 106 | criteria = "safe-to-deploy" 107 | delta = "1.0.40 -> 1.0.47" 108 | 109 | [[audits.proc-macro2]] 110 | who = "Nils Mehrtens" 111 | criteria = "safe-to-deploy" 112 | delta = "1.0.47 -> 1.0.66" 113 | 114 | [[audits.quote]] 115 | who = "Nils Mehrtens" 116 | criteria = "safe-to-deploy" 117 | delta = "1.0.18 -> 1.0.20" 118 | 119 | [[audits.quote]] 120 | who = "Nils Mehrtens" 121 | criteria = "safe-to-deploy" 122 | delta = "1.0.20 -> 1.0.21" 123 | 124 | [[audits.quote]] 125 | who = "Nils Mehrtens" 126 | criteria = "safe-to-deploy" 127 | delta = "1.0.21 -> 1.0.33" 128 | 129 | [[audits.safe_arch]] 130 | who = "Nils Mehrtens" 131 | criteria = "safe-to-deploy" 132 | delta = "0.6.0 -> 0.7.1" 133 | 134 | [[audits.simba]] 135 | who = "Nils Mehrtens" 136 | criteria = "safe-to-deploy" 137 | delta = "0.7.1 -> 0.7.3" 138 | 139 | [[audits.simba]] 140 | who = "Nils Mehrtens" 141 | criteria = "safe-to-deploy" 142 | delta = "0.7.3 -> 0.8.0" 143 | 144 | [[audits.simba]] 145 | who = "Nils Mehrtens" 146 | criteria = "safe-to-deploy" 147 | delta = "0.8.0 -> 0.8.1" 148 | 149 | [[audits.simba]] 150 | who = "Nils Mehrtens" 151 | criteria = "safe-to-deploy" 152 | delta = "0.8.1 -> 0.9.0" 153 | 154 | [[audits.syn]] 
155 | who = "Nils Mehrtens" 156 | criteria = "safe-to-deploy" 157 | delta = "1.0.80 -> 1.0.103" 158 | 159 | [[audits.syn]] 160 | who = "Nils Mehrtens" 161 | criteria = "safe-to-deploy" 162 | delta = "1.0.92 -> 1.0.98" 163 | 164 | [[audits.syn]] 165 | who = "Nils Mehrtens" 166 | criteria = "safe-to-deploy" 167 | delta = "1.0.98 -> 1.0.103" 168 | 169 | [[audits.syn]] 170 | who = "Nils Mehrtens" 171 | criteria = "safe-to-deploy" 172 | delta = "1.0.103 -> 1.0.109" 173 | 174 | [[audits.tiny-bench]] 175 | who = "Nils Mehrtens" 176 | criteria = "safe-to-deploy" 177 | version = "0.3.0" 178 | 179 | [[audits.typenum]] 180 | who = "Nils Mehrtens" 181 | criteria = "safe-to-deploy" 182 | delta = "1.15.0 -> 1.16.0" 183 | 184 | [[audits.unicode-ident]] 185 | who = "Nils Mehrtens" 186 | criteria = "safe-to-deploy" 187 | version = "1.0.5" 188 | 189 | [[audits.unicode-ident]] 190 | who = "Nils Mehrtens" 191 | criteria = "safe-to-deploy" 192 | delta = "1.0.5 -> 1.0.11" 193 | 194 | [[audits.wide]] 195 | who = "Nils Mehrtens" 196 | criteria = "safe-to-deploy" 197 | delta = "0.7.4 -> 0.7.5" 198 | 199 | [[audits.wide]] 200 | who = "Nils Mehrtens" 201 | criteria = "safe-to-deploy" 202 | delta = "0.7.5 -> 0.7.11" 203 | -------------------------------------------------------------------------------- /src/stats/tests.rs: -------------------------------------------------------------------------------- 1 | // This file contains code adapted from the statrs library (https://github.com/statrs-dev/statrs/) 2 | // licensed under the MIT License: 3 | // 4 | // MIT License 5 | // 6 | // Copyright (c) 2016 Michael Ma 7 | // 8 | // Permission is hereby granted, free of charge, to any person obtaining a copy 9 | // of this software and associated documentation files (the "Software"), to deal 10 | // in the Software without restriction, including without limitation the rights 11 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | // copies of the Software, and to permit 
persons to whom the Software is 13 | // furnished to do so, subject to the following conditions: 14 | // 15 | // The above copyright notice and this permission notice shall be included in all 16 | // copies or substantial portions of the Software. 17 | // 18 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | // SOFTWARE. 25 | 26 | use std::f64::consts::{FRAC_1_SQRT_2, LN_2, PI}; 27 | 28 | use super::*; 29 | use crate::assert_almost_eq; 30 | 31 | fn beta_reg(a: f64, b: f64, x: f64) -> f64 { 32 | checked_beta_reg(a, b, x).unwrap() 33 | } 34 | 35 | #[test] 36 | fn test_students_t_cdf() { 37 | assert_eq!(students_t_cdf(0., 1).unwrap(), 0.5); 38 | assert_eq!(students_t_cdf(0., 2).unwrap(), 0.5); 39 | assert_almost_eq!(students_t_cdf(1., 1).unwrap(), 0.75, 1e-15); 40 | assert_almost_eq!(students_t_cdf(-1., 1).unwrap(), 0.25, 1e-15); 41 | assert_almost_eq!(students_t_cdf(2., 1).unwrap(), 0.852416382349567, 1e-15); 42 | assert_almost_eq!(students_t_cdf(-2., 1).unwrap(), 0.147583617650433, 1e-15); 43 | assert_almost_eq!(students_t_cdf(1., 2).unwrap(), 0.788675134594813, 1e-15); 44 | assert_almost_eq!(students_t_cdf(-1., 2).unwrap(), 0.211324865405187, 1e-15); 45 | assert_almost_eq!(students_t_cdf(2., 2).unwrap(), 0.908248290463863, 1e-15); 46 | assert_almost_eq!(students_t_cdf(-2., 2).unwrap(), 0.091751709536137, 1e-15); 47 | } 48 | 49 | #[test] 50 | fn test_beta_reg() { 51 | assert_almost_eq!(beta_reg(0.5, 0.5, 0.5), 0.5, 1e-15); 52 | assert_eq!(beta_reg(0.5, 0.5, 1.0), 1.0); 53 | assert_almost_eq!(beta_reg(1.0, 0.5, 0.5), 
0.292_893_218_813_452_5, 1e-15); 54 | assert_eq!(beta_reg(1.0, 0.5, 1.0), 1.0); 55 | assert_almost_eq!(beta_reg(2.5, 0.5, 0.5), 0.075_586_818_421_612_44, 1e-16); 56 | assert_eq!(beta_reg(2.5, 0.5, 1.0), 1.0); 57 | assert_almost_eq!(beta_reg(0.5, 1.0, 0.5), FRAC_1_SQRT_2, 1e-15); 58 | assert_eq!(beta_reg(0.5, 1.0, 1.0), 1.0); 59 | assert_almost_eq!(beta_reg(1.0, 1.0, 0.5), 0.5, 1e-15); 60 | assert_eq!(beta_reg(1.0, 1.0, 1.0), 1.0); 61 | assert_almost_eq!(beta_reg(2.5, 1.0, 0.5), 0.176_776_695_296_636_9, 1e-15); 62 | assert_eq!(beta_reg(2.5, 1.0, 1.0), 1.0); 63 | assert_eq!(beta_reg(0.5, 2.5, 0.5), 0.924_413_181_578_387_6); 64 | assert_eq!(beta_reg(0.5, 2.5, 1.0), 1.0); 65 | assert_almost_eq!(beta_reg(1.0, 2.5, 0.5), 0.823_223_304_703_363_1, 1e-15); 66 | assert_eq!(beta_reg(1.0, 2.5, 1.0), 1.0); 67 | assert_almost_eq!(beta_reg(2.5, 2.5, 0.5), 0.5, 1e-15); 68 | assert_eq!(beta_reg(2.5, 2.5, 1.0), 1.0); 69 | } 70 | 71 | #[test] 72 | #[should_panic] 73 | fn test_beta_reg_a_lte_0() { 74 | beta_reg(0.0, 1.0, 1.0); 75 | } 76 | 77 | #[test] 78 | #[should_panic] 79 | fn test_beta_reg_b_lte_0() { 80 | beta_reg(1.0, 0.0, 1.0); 81 | } 82 | 83 | #[test] 84 | #[should_panic] 85 | fn test_beta_reg_x_lt_0() { 86 | beta_reg(1.0, 1.0, -1.0); 87 | } 88 | 89 | #[test] 90 | #[should_panic] 91 | fn test_beta_reg_x_gt_1() { 92 | beta_reg(1.0, 1.0, 2.0); 93 | } 94 | 95 | #[test] 96 | fn test_checked_beta_reg_a_lte_0() { 97 | assert!(checked_beta_reg(0.0, 1.0, 1.0).is_none()); 98 | } 99 | 100 | #[test] 101 | fn test_checked_beta_reg_b_lte_0() { 102 | assert!(checked_beta_reg(1.0, 0.0, 1.0).is_none()); 103 | } 104 | 105 | #[test] 106 | fn test_checked_beta_reg_x_lt_0() { 107 | assert!(checked_beta_reg(1.0, 1.0, -1.0).is_none()); 108 | } 109 | 110 | #[test] 111 | fn test_checked_beta_reg_x_gt_1() { 112 | assert!(checked_beta_reg(1.0, 1.0, 2.0).is_none()); 113 | } 114 | #[test] 115 | fn test_ln_gamma() { 116 | assert!(ln_gamma(f64::NAN).is_nan()); 117 | assert_eq!(ln_gamma(1.000001e-35), 
80.590_477_254_792_1); 118 | assert_almost_eq!(ln_gamma(1.000001e-10), 23.025_849_929_883_236, 1e-14); 119 | assert_almost_eq!(ln_gamma(1.000001e-5), 11.512_918_692_890_553, 1e-14); 120 | assert_eq!(ln_gamma(1.000001e-2), 4.599_478_872_433_667); 121 | assert_almost_eq!(ln_gamma(0.1), 2.252_712_651_734_206, 1e-14); 122 | assert_almost_eq!(ln_gamma(1.0 - 1.0e-14), 5.772_156_649_015_411e-15, 1e-15); 123 | assert_almost_eq!(ln_gamma(1.0), 0.0, 1e-15); 124 | assert_almost_eq!(ln_gamma(1.0 + 1.0e-14), -5.772_156_649_015_246e-15, 1e-15); 125 | assert_almost_eq!(ln_gamma(1.5), -0.120_782_237_635_245_22, 1e-14); 126 | assert_almost_eq!( 127 | ln_gamma(f64::consts::PI / 2.0), 128 | -0.115_903_800_845_502_42, 129 | 1e-14 130 | ); 131 | assert_eq!(ln_gamma(2.0), 0.0); 132 | assert_almost_eq!(ln_gamma(2.5), 0.284_682_870_472_919_2, 1e-13); 133 | assert_almost_eq!(ln_gamma(3.0), LN_2, 1e-14); 134 | assert_almost_eq!(ln_gamma(PI), 0.827_694_592_323_437_1, 1e-13); 135 | assert_almost_eq!(ln_gamma(3.5), 1.200_973_602_347_074_3, 1e-14); 136 | assert_almost_eq!(ln_gamma(4.0), 1.791_759_469_228_055, 1e-14); 137 | assert_almost_eq!(ln_gamma(4.5), 2.453_736_570_842_442_3, 1e-13); 138 | assert_almost_eq!(ln_gamma(5.0 - 1.0e-14), 3.178_053_830_347_930_7, 1e-14); 139 | assert_almost_eq!(ln_gamma(5.0), 3.178_053_830_347_945_8, 1e-14); 140 | assert_almost_eq!(ln_gamma(5.0 + 1.0e-14), 3.178_053_830_347_961, 1e-13); 141 | assert_almost_eq!(ln_gamma(5.5), 3.957_813_967_618_716_5, 1e-14); 142 | assert_almost_eq!(ln_gamma(10.1), 13.027_526_738_633_238, 1e-14); 143 | assert_almost_eq!(ln_gamma(150.0 + 1.0e-12), 600.009_470_555_332_4, 1e-12); 144 | assert_almost_eq!(ln_gamma(1.001e+7), 1.513_421_353_238_179e8, 1e-13); 145 | } 146 | -------------------------------------------------------------------------------- /src/stats.rs: -------------------------------------------------------------------------------- 1 | // This file contains code adapted from the statrs library 
(https://github.com/statrs-dev/statrs/) 2 | // licensed under the MIT License: 3 | // 4 | // MIT License 5 | // 6 | // Copyright (c) 2016 Michael Ma 7 | // 8 | // Permission is hereby granted, free of charge, to any person obtaining a copy 9 | // of this software and associated documentation files (the "Software"), to deal 10 | // in the Software without restriction, including without limitation the rights 11 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | // copies of the Software, and to permit persons to whom the Software is 13 | // furnished to do so, subject to the following conditions: 14 | // 15 | // The above copyright notice and this permission notice shall be included in all 16 | // copies or substantial portions of the Software. 17 | // 18 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | // SOFTWARE. 
25 | 26 | use std::f64; 27 | 28 | use crate::ulps_eq; 29 | 30 | #[cfg(test)] 31 | mod tests; 32 | 33 | // ln(pi) 34 | const LN_PI: f64 = 1.144_729_885_849_400_2; 35 | // ln(2 * sqrt(e / pi)) 36 | const LN_2_SQRT_E_OVER_PI: f64 = 0.620_782_237_635_245_2; 37 | // Auxiliary variable when evaluating the `gamma_ln` function 38 | const GAMMA_R: f64 = 10.900511; 39 | 40 | // Polynomial coefficients for approximating the `gamma_ln` function 41 | const GAMMA_DK: [f64; 11] = [ 42 | 2.485_740_891_387_535_5e-5, 43 | 1.051_423_785_817_219_7, 44 | -3.456_870_972_220_162_5, 45 | 4.512_277_094_668_948, 46 | -2.982_852_253_235_766_4, 47 | 1.056_397_115_771_267, 48 | -1.954_287_731_916_458_7e-1, 49 | 1.709_705_434_044_412e-2, 50 | -5.719_261_174_043_057e-4, 51 | 4.633_994_733_599_057e-6, 52 | -2.719_949_084_886_077_2e-9, 53 | ]; 54 | 55 | // Standard epsilon, maximum relative precision of IEEE 754 double-precision 56 | // floating point numbers (64 bit) e.g. `2^-53` 57 | const F64_PREC: f64 = 1.1102230246251565e-16; 58 | 59 | /// Calculates the cumulative distribution function for the student's T at `x` 60 | /// with location `0` and scale `1`. 
61 | /// 62 | /// # Formula 63 | /// 64 | /// ```ignore 65 | /// if x < μ { 66 | /// (1 / 2) * I(t, v / 2, 1 / 2) 67 | /// } else { 68 | /// 1 - (1 / 2) * I(t, v / 2, 1 / 2) 69 | /// } 70 | /// ``` 71 | /// 72 | /// where `t = v / (v + k^2)`, `k = (x - μ) / σ`, `μ` is the location, 73 | /// `σ` is the scale, `v` is the freedom, and `I` is the regularized 74 | /// incomplete 75 | /// beta function 76 | pub fn students_t_cdf(x: f64, freedom: i64) -> Option { 77 | if freedom <= 0 { 78 | return None; 79 | } 80 | let location: f64 = 0.; 81 | let scale: f64 = 1.0; 82 | let freedom = freedom as f64; 83 | let k = (x - location) / scale; 84 | let h = freedom / (freedom + k * k); 85 | let ib = 0.5 * checked_beta_reg(freedom / 2.0, 0.5, h)?; 86 | if x <= location { 87 | Some(ib) 88 | } else { 89 | Some(1.0 - ib) 90 | } 91 | } 92 | 93 | /// Computes the regularized lower incomplete beta function 94 | /// `I_x(a,b) = 1/Beta(a,b) * int(t^(a-1)*(1-t)^(b-1), t=0..x)` 95 | /// `a > 0`, `b > 0`, `1 >= x >= 0` where `a` is the first beta parameter, 96 | /// `b` is the second beta parameter, and `x` is the upper limit of the 97 | /// integral. 98 | /// 99 | /// Returns `None` if `a <= 0.0`, `b <= 0.0`, `x < 0.0`, or `x > 1.0` 100 | fn checked_beta_reg(a: f64, b: f64, x: f64) -> Option { 101 | if a <= 0. || b <= 0. || !(0.0..=1.).contains(&x) { 102 | return None; 103 | } 104 | let bt = if x == 0. 
|| ulps_eq(x, 1.0, f64::EPSILON, 4) { 105 | 0.0 106 | } else { 107 | (ln_gamma(a + b) - ln_gamma(a) - ln_gamma(b) + a * x.ln() + b * (1.0 - x).ln()).exp() 108 | }; 109 | let symm_transform = x >= (a + 1.0) / (a + b + 2.0); 110 | let eps = F64_PREC; 111 | let fpmin = f64::MIN_POSITIVE / eps; 112 | 113 | let mut a = a; 114 | let mut b = b; 115 | let mut x = x; 116 | if symm_transform { 117 | let swap = a; 118 | x = 1.0 - x; 119 | a = b; 120 | b = swap; 121 | } 122 | 123 | let qab = a + b; 124 | let qap = a + 1.0; 125 | let qam = a - 1.0; 126 | let mut c = 1.0; 127 | let mut d = 1.0 - qab * x / qap; 128 | 129 | if d.abs() < fpmin { 130 | d = fpmin; 131 | } 132 | d = 1.0 / d; 133 | let mut h = d; 134 | 135 | for m in 1..141 { 136 | let m = f64::from(m); 137 | let m2 = m * 2.0; 138 | let mut aa = m * (b - m) * x / ((qam + m2) * (a + m2)); 139 | d = 1.0 + aa * d; 140 | 141 | if d.abs() < fpmin { 142 | d = fpmin; 143 | } 144 | 145 | c = 1.0 + aa / c; 146 | if c.abs() < fpmin { 147 | c = fpmin; 148 | } 149 | 150 | d = 1.0 / d; 151 | h = h * d * c; 152 | aa = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2)); 153 | d = 1.0 + aa * d; 154 | 155 | if d.abs() < fpmin { 156 | d = fpmin; 157 | } 158 | 159 | c = 1.0 + aa / c; 160 | 161 | if c.abs() < fpmin { 162 | c = fpmin; 163 | } 164 | 165 | d = 1.0 / d; 166 | let del = d * c; 167 | h *= del; 168 | 169 | if (del - 1.0).abs() <= eps { 170 | return if symm_transform { 171 | Some(1.0 - bt * h / a) 172 | } else { 173 | Some(bt * h / a) 174 | }; 175 | } 176 | } 177 | 178 | if symm_transform { 179 | Some(1.0 - bt * h / a) 180 | } else { 181 | Some(bt * h / a) 182 | } 183 | } 184 | 185 | /// Computes the logarithm of the gamma function 186 | /// with an accuracy of 16 floating point digits. 187 | /// The implementation is derived from 188 | /// "An Analysis of the Lanczos Gamma Approximation", 189 | /// Glendon Ralph Pugh, 2004 p. 
116 190 | fn ln_gamma(x: f64) -> f64 { 191 | if x < 0.5 { 192 | let s = GAMMA_DK 193 | .iter() 194 | .enumerate() 195 | .skip(1) 196 | .fold(GAMMA_DK[0], |s, t| s + t.1 / (t.0 as f64 - x)); 197 | 198 | LN_PI 199 | - (f64::consts::PI * x).sin().ln() 200 | - s.ln() 201 | - LN_2_SQRT_E_OVER_PI 202 | - (0.5 - x) * ((0.5 - x + GAMMA_R) / f64::consts::E).ln() 203 | } else { 204 | let s = GAMMA_DK 205 | .iter() 206 | .enumerate() 207 | .skip(1) 208 | .fold(GAMMA_DK[0], |s, t| s + t.1 / (x + t.0 as f64 - 1.0)); 209 | 210 | s.ln() + LN_2_SQRT_E_OVER_PI + (x - 0.5) * ((x - 0.5 + GAMMA_R) / f64::consts::E).ln() 211 | } 212 | } 213 | -------------------------------------------------------------------------------- /src/tests.rs: -------------------------------------------------------------------------------- 1 | use std::f64::consts::PI; 2 | 3 | use super::*; 4 | 5 | #[test] 6 | fn test_pinv_with_formula_builder() { 7 | use std::collections::HashMap; 8 | let inputs = vec![1., 3., 4., 5., 2., 3., 4.]; 9 | let outputs1 = vec![1., 2., 3., 4., 5., 6., 7.]; 10 | let outputs2 = vec![7., 6., 5., 4., 3., 2., 1.]; 11 | let mut data = HashMap::new(); 12 | data.insert("Y", inputs); 13 | data.insert("X1", outputs1); 14 | data.insert("X2", outputs2); 15 | let data = RegressionDataBuilder::new().build_from(data).unwrap(); 16 | let regression = FormulaRegressionBuilder::new() 17 | .data(&data) 18 | .formula("Y ~ X1 + X2") 19 | .fit() 20 | .expect("Fitting model failed"); 21 | 22 | let model_parameters = vec![0.09523809523809523, 0.5059523809523809, 0.2559523809523808]; 23 | let se = vec![ 24 | 0.015457637291218289, 25 | 0.1417242813072997, 26 | 0.14172428130729975, 27 | ]; 28 | let ssr = 9.107142857142858; 29 | let rsquared = 0.16118421052631582; 30 | let rsquared_adj = -0.006578947368421018; 31 | let scale = 1.8214285714285716; 32 | let pvalues = vec![ 33 | 0.001639031204417556, 34 | 0.016044083709847945, 35 | 0.13074580446389245, 36 | ]; 37 | let residuals = vec![ 38 | -1.392857142857142, 
39 | 0.3571428571428581, 40 | 1.1071428571428577, 41 | 1.8571428571428577, 42 | -1.3928571428571423, 43 | -0.6428571428571423, 44 | 0.10714285714285765, 45 | ]; 46 | assert_slices_almost_eq!(regression.parameters(), &model_parameters); 47 | assert_slices_almost_eq!(regression.se(), &se); 48 | assert_almost_eq!(regression.ssr(), ssr); 49 | assert_almost_eq!(regression.rsquared(), rsquared); 50 | assert_almost_eq!(regression.rsquared_adj(), rsquared_adj); 51 | assert_slices_almost_eq!(regression.p_values(), &pvalues); 52 | assert_slices_almost_eq!(regression.residuals(), &residuals); 53 | assert_almost_eq!(regression.scale(), scale); 54 | } 55 | 56 | #[test] 57 | fn test_pinv_with_data_columns() { 58 | use std::collections::HashMap; 59 | let inputs = vec![1., 3., 4., 5., 2., 3., 4.]; 60 | let outputs1 = vec![1., 2., 3., 4., 5., 6., 7.]; 61 | let outputs2 = vec![7., 6., 5., 4., 3., 2., 1.]; 62 | let mut data = HashMap::new(); 63 | data.insert("Y", inputs); 64 | data.insert("X1", outputs1); 65 | data.insert("X2", outputs2); 66 | let data = RegressionDataBuilder::new().build_from(data).unwrap(); 67 | let regression = FormulaRegressionBuilder::new() 68 | .data(&data) 69 | .data_columns("Y", ["X1", "X2"]) 70 | .fit() 71 | .expect("Fitting model failed"); 72 | 73 | let model_parameters = vec![0.09523809523809523, 0.5059523809523809, 0.2559523809523808]; 74 | let se = vec![ 75 | 0.015457637291218289, 76 | 0.1417242813072997, 77 | 0.14172428130729975, 78 | ]; 79 | let ssr = 9.107142857142858; 80 | let rsquared = 0.16118421052631582; 81 | let rsquared_adj = -0.006578947368421018; 82 | let scale = 1.8214285714285716; 83 | let pvalues = vec![ 84 | 0.001639031204417556, 85 | 0.016044083709847945, 86 | 0.13074580446389245, 87 | ]; 88 | let residuals = vec![ 89 | -1.392857142857142, 90 | 0.3571428571428581, 91 | 1.1071428571428577, 92 | 1.8571428571428577, 93 | -1.3928571428571423, 94 | -0.6428571428571423, 95 | 0.10714285714285765, 96 | ]; 97 | 
assert_slices_almost_eq!(regression.parameters(), &model_parameters); 98 | assert_slices_almost_eq!(regression.se(), &se); 99 | assert_almost_eq!(regression.ssr(), ssr); 100 | assert_almost_eq!(regression.rsquared(), rsquared); 101 | assert_almost_eq!(regression.rsquared_adj(), rsquared_adj); 102 | assert_slices_almost_eq!(regression.p_values(), &pvalues); 103 | assert_slices_almost_eq!(regression.residuals(), &residuals); 104 | assert_almost_eq!(regression.scale(), scale); 105 | } 106 | 107 | #[test] 108 | fn test_regression_standard_error_equal_to_zero_does_not_prevent_fitting() { 109 | // Regression test for underlying issue of https://github.com/n1m3/linregress/issues/9 110 | 111 | // The following input does not conform to our API (we expect that all intercepts == 1, not 0), 112 | // but Hyrum's law... 113 | let data = vec![ 114 | 0.0, 115 | 0.0, 116 | 0.0, 117 | 34059798.0, 118 | 0.0, 119 | 1.0, 120 | 66771421.0, 121 | 0.0, 122 | 2.0, 123 | 100206133.0, 124 | 0.0, 125 | 3.0, 126 | 133435943.0, 127 | 0.0, 128 | 4.0, 129 | 166028256.0, 130 | 0.0, 131 | 5.0, 132 | 199723152.0, 133 | 0.0, 134 | 6.0, 135 | 233754352.0, 136 | 0.0, 137 | 7.0, 138 | 267284084.0, 139 | 0.0, 140 | 8.0, 141 | 301756656.0, 142 | 0.0, 143 | 9.0, 144 | 331420366.0, 145 | 0.0, 146 | 10.0, 147 | 367961084.0, 148 | 0.0, 149 | 11.0, 150 | 401288216.0, 151 | 0.0, 152 | 12.0, 153 | 434555574.0, 154 | 0.0, 155 | 13.0, 156 | 469093436.0, 157 | 0.0, 158 | 14.0, 159 | 501541551.0, 160 | 0.0, 161 | 15.0, 162 | 523986797.0, 163 | 0.0, 164 | 16.0, 165 | 558792615.0, 166 | 0.0, 167 | 17.0, 168 | 631494010.0, 169 | 0.0, 170 | 18.0, 171 | 669229109.0, 172 | 0.0, 173 | 19.0, 174 | 704321427.0, 175 | 0.0, 176 | 20.0, 177 | ]; 178 | let rows = 21; 179 | let columns = 3; 180 | fit_low_level_regression_model(&data, rows, columns).unwrap(); 181 | } 182 | 183 | #[test] 184 | fn test_low_level_model_fitting() { 185 | let inputs = [1., 3., 4., 5., 2., 3., 4.]; 186 | let outputs1 = [1., 2., 3., 4., 5., 6., 7.]; 187 
#[test]
fn test_low_level_model_fitting() {
    // Same data set as `test_pinv_with_formula_builder`, fed through the low
    // level row-major API with the intercept column supplied explicitly.
    let inputs = [1., 3., 4., 5., 2., 3., 4.];
    let outputs1 = [1., 2., 3., 4., 5., 6., 7.];
    let outputs2 = [7., 6., 5., 4., 3., 2., 1.];
    let mut data_row_major = Vec::with_capacity(4 * 7);
    for n in 0..7 {
        data_row_major.extend_from_slice(&[inputs[n], 1.0, outputs1[n], outputs2[n]]);
    }
    let regression = fit_low_level_regression_model(&data_row_major, 7, 4).unwrap();

    let expected_parameters = vec![0.09523809523809523, 0.5059523809523809, 0.2559523809523808];
    let expected_se = vec![
        0.015457637291218289,
        0.1417242813072997,
        0.14172428130729975,
    ];
    let expected_pvalues = vec![
        0.001639031204417556,
        0.016044083709847945,
        0.13074580446389245,
    ];
    let expected_residuals = vec![
        -1.392857142857142,
        0.3571428571428581,
        1.1071428571428577,
        1.8571428571428577,
        -1.3928571428571423,
        -0.6428571428571423,
        0.10714285714285765,
    ];
    assert_slices_almost_eq!(regression.parameters(), &expected_parameters);
    assert_slices_almost_eq!(regression.se(), &expected_se);
    assert_almost_eq!(regression.ssr(), 9.107142857142858);
    assert_almost_eq!(regression.rsquared(), 0.16118421052631582);
    assert_almost_eq!(regression.rsquared_adj(), -0.006578947368421018);
    assert_slices_almost_eq!(regression.p_values(), &expected_pvalues);
    assert_slices_almost_eq!(regression.residuals(), &expected_residuals);
    assert_almost_eq!(regression.scale(), 1.8214285714285716);
}

#[test]
fn test_without_statistics() {
    use std::collections::HashMap;
    // `fit_without_statistics` returns only the parameter vector.
    let mut data = HashMap::new();
    data.insert("Y", vec![1., 3., 4., 5., 2., 3., 4.]);
    data.insert("X1", vec![1., 2., 3., 4., 5., 6., 7.]);
    data.insert("X2", vec![7., 6., 5., 4., 3., 2., 1.]);
    let data = RegressionDataBuilder::new().build_from(data).unwrap();
    let regression = FormulaRegressionBuilder::new()
        .data(&data)
        .formula("Y ~ X1 + X2")
        .fit_without_statistics()
        .expect("Fitting model failed");
    let expected_parameters = vec![0.09523809523809523, 0.5059523809523809, 0.2559523809523808];
    assert_slices_almost_eq!(&regression, &expected_parameters);
}

#[test]
fn test_invalid_input_empty_matrix() {
    // Columns with no rows at all must be rejected.
    let data = vec![("Y", vec![]), ("X1", vec![]), ("X2", vec![])];
    let data = RegressionDataBuilder::new().build_from(data);
    assert!(data.is_err());
}

#[test]
fn test_invalid_input_wrong_shape_x() {
    // A regressor column shorter than the regressand must be rejected.
    let data = vec![
        ("Y", vec![1., 2., 3.]),
        ("X1", vec![1., 2., 3.]),
        ("X2", vec![1., 2.]),
    ];
    let data = RegressionDataBuilder::new().build_from(data);
    assert!(data.is_err());
}

#[test]
fn test_invalid_input_wrong_shape_y() {
    // A regressand column longer than the regressors must be rejected.
    let data = vec![
        ("Y", vec![1., 2., 3., 4.]),
        ("X1", vec![1., 2., 3.]),
        ("X2", vec![1., 2., 3.]),
    ];
    let data = RegressionDataBuilder::new().build_from(data);
    assert!(data.is_err());
}
#[test]
fn test_invalid_input_nan() {
    // NaN anywhere in the data is rejected by default, but tolerated when the
    // builder is configured with `InvalidValueHandling::DropInvalid`.
    let data1 = vec![
        ("Y", vec![1., 2., 3., 4.]),
        ("X", vec![1., 2., 3., f64::NAN]),
    ];
    let data2 = vec![
        ("Y", vec![1., 2., 3., f64::NAN]),
        ("X", vec![1., 2., 3., 4.]),
    ];
    assert!(RegressionDataBuilder::new().build_from(data1.clone()).is_err());
    assert!(RegressionDataBuilder::new().build_from(data2.clone()).is_err());
    let builder =
        RegressionDataBuilder::new().invalid_value_handling(InvalidValueHandling::DropInvalid);
    assert!(builder.build_from(data1).is_ok());
    assert!(builder.build_from(data2).is_ok());
}

#[test]
fn test_invalid_input_infinity() {
    // Infinities are treated exactly like NaN: rejected unless dropped.
    let data1 = vec![
        ("Y", vec![1., 2., 3., 4.]),
        ("X", vec![1., 2., 3., f64::INFINITY]),
    ];
    let data2 = vec![
        ("Y", vec![1., 2., 3., f64::NEG_INFINITY]),
        ("X", vec![1., 2., 3., 4.]),
    ];
    assert!(RegressionDataBuilder::new().build_from(data1.clone()).is_err());
    assert!(RegressionDataBuilder::new().build_from(data2.clone()).is_err());
    let builder =
        RegressionDataBuilder::new().invalid_value_handling(InvalidValueHandling::DropInvalid);
    assert!(builder.build_from(data1).is_ok());
    assert!(builder.build_from(data2).is_ok());
}

#[test]
fn test_invalid_input_all_equal_columns() {
    // Constant-only columns would yield degenerate statistics and are rejected.
    let data = vec![("y", vec![38.0; 3]), ("x", vec![42.0; 3])];
    let data = RegressionDataBuilder::new().build_from(data);
    assert!(data.is_err());
}

#[test]
fn test_drop_invalid_values() {
    let mut data: HashMap<Cow<str>, Vec<f64>> = HashMap::new();
    data.insert("Y".into(), vec![-1., -2., -3., -4.]);
    data.insert("foo".into(), vec![1., 2., 12., 4.]);
    data.insert("bar".into(), vec![1., 1., 7., 4.]);
    data.insert("baz".into(), vec![1.3333, 2.754, 3.12, 4.11]);
    // A fully finite table passes through unchanged.
    assert_eq!(RegressionData::drop_invalid_values(data.clone()), data);
    data.insert(
        "invalid".into(),
        vec![f64::NAN, 42., f64::NEG_INFINITY, 23.11],
    );
    data.insert(
        "invalid2".into(),
        vec![1.337, PI, f64::INFINITY, 11.111111],
    );
    // Rows 0 and 2 contain non-finite values, so they are dropped from every column.
    let mut expected: HashMap<Cow<str>, Vec<f64>> = HashMap::new();
    expected.insert("Y".into(), vec![-2., -4.]);
    expected.insert("foo".into(), vec![2., 4.]);
    expected.insert("bar".into(), vec![1., 4.]);
    expected.insert("baz".into(), vec![2.754, 4.11]);
    expected.insert("invalid".into(), vec![42., 23.11]);
    expected.insert("invalid2".into(), vec![PI, 11.111111]);
    assert_eq!(expected, RegressionData::drop_invalid_values(data));
}

#[test]
fn test_all_invalid_input() {
    // If dropping invalid rows leaves nothing behind, building must fail.
    let data = vec![
        ("Y", vec![1., 2., 3.]),
        ("X", vec![f64::NAN, f64::NAN, f64::NAN]),
    ];
    let builder =
        RegressionDataBuilder::new().invalid_value_handling(InvalidValueHandling::DropInvalid);
    assert!(builder.build_from(data).is_err());
}

#[test]
fn test_invalid_column_names() {
    // `~` and `+` are formula separators and therefore banned in column names.
    let data1 = vec![("x~f", vec![1., 2., 3.]), ("foo", vec![0., 0., 0.])];
    let data2 = vec![("foo", vec![1., 2., 3.]), ("foo+", vec![0., 0., 0.])];
    let builder = RegressionDataBuilder::new();
    assert!(builder.build_from(data1).is_err());
    assert!(builder.build_from(data2).is_err());
}

#[test]
fn test_no_formula() {
    // Fitting without either a formula or explicit data columns must fail.
    let data = vec![("x", vec![1., 2., 3.]), ("foo", vec![0., 0., 0.])];
    let data = RegressionDataBuilder::new().build_from(data).unwrap();
    let res = FormulaRegressionBuilder::new().data(&data).fit();
    assert!(res.is_err());
}
"Y ~ X1 + X2 + X3"; 396 | let res = FormulaRegressionBuilder::new() 397 | .data(&data) 398 | .formula(formula) 399 | .data_columns("Y", ["X1", "X2", "X3"]) 400 | .fit(); 401 | assert!(res.is_err()); 402 | } 403 | 404 | fn build_model() -> RegressionModel { 405 | let y = vec![1., 2., 3., 4., 5.]; 406 | let x1 = vec![5., 4., 3., 2., 1.]; 407 | let x2 = vec![729.53, 439.0367, 42.054, 1., 0.]; 408 | let x3 = vec![258.589, 616.297, 215.061, 498.361, 0.]; 409 | let data = vec![("Y", y), ("X1", x1), ("X2", x2), ("X3", x3)]; 410 | let data = RegressionDataBuilder::new().build_from(data).unwrap(); 411 | let formula = "Y ~ X1 + X2 + X3"; 412 | FormulaRegressionBuilder::new() 413 | .data(&data) 414 | .formula(formula) 415 | .fit() 416 | .unwrap() 417 | } 418 | 419 | #[test] 420 | fn test_prediction_empty_vectors() { 421 | let model = build_model(); 422 | let new_data: HashMap, _> = vec![("X1", vec![]), ("X2", vec![]), ("X3", vec![])] 423 | .into_iter() 424 | .map(|(x, y)| (Cow::from(x), y)) 425 | .collect(); 426 | assert!(model.check_variables(&new_data).is_err()); 427 | } 428 | 429 | #[test] 430 | fn test_prediction_vectors_with_different_lengths() { 431 | let model = build_model(); 432 | let new_data: HashMap, _> = vec![ 433 | ("X1", vec![1.0, 2.0]), 434 | ("X2", vec![2.0, 1.0]), 435 | ("X3", vec![3.0]), 436 | ] 437 | .into_iter() 438 | .map(|(x, y)| (Cow::from(x), y)) 439 | .collect(); 440 | assert!(model.check_variables(&new_data).is_err()); 441 | } 442 | 443 | #[test] 444 | fn test_too_many_prediction_variables() { 445 | let model = build_model(); 446 | let new_data: HashMap, _> = vec![ 447 | ("X1", vec![1.0]), 448 | ("X2", vec![2.0]), 449 | ("X3", vec![3.0]), 450 | ("X4", vec![4.0]), 451 | ] 452 | .into_iter() 453 | .map(|(x, y)| (Cow::from(x), y)) 454 | .collect(); 455 | assert!(model.check_variables(&new_data).is_err()); 456 | } 457 | 458 | #[test] 459 | fn test_not_enough_prediction_variables() { 460 | let model = build_model(); 461 | let new_data: HashMap, _> = 
vec![("X1", vec![1.0]), ("X2", vec![2.0])] 462 | .into_iter() 463 | .map(|(x, y)| (Cow::from(x), y)) 464 | .collect(); 465 | assert!(model.check_variables(&new_data).is_err()); 466 | } 467 | 468 | #[test] 469 | fn test_prediction() { 470 | let model = build_model(); 471 | let new_data = vec![("X1", vec![2.5]), ("X2", vec![2.0]), ("X3", vec![2.0])]; 472 | let prediction = model.predict(new_data).unwrap(); 473 | assert_eq!(prediction.len(), 1); 474 | assert_almost_eq!(prediction[0], 3.500000000000111, 1.0E-7); 475 | } 476 | 477 | #[test] 478 | fn test_multiple_predictions() { 479 | let model = build_model(); 480 | let new_data = vec![ 481 | ("X1", vec![2.5, 3.5]), 482 | ("X2", vec![2.0, 8.0]), 483 | ("X3", vec![2.0, 1.0]), 484 | ]; 485 | let prediction = model.predict(new_data).unwrap(); 486 | assert_eq!(prediction.len(), 2); 487 | assert_almost_eq!(prediction[0], 3.500000000000111, 1.0E-7); 488 | assert_almost_eq!(prediction[1], 2.5000000000001337, 1.0E-7); 489 | } 490 | 491 | #[test] 492 | fn test_multiple_predictions_out_of_order() { 493 | let model = build_model(); 494 | let new_data = vec![ 495 | ("X1", vec![2.5, 3.5]), 496 | ("X3", vec![2.0, 1.0]), 497 | ("X2", vec![2.0, 8.0]), 498 | ]; 499 | let prediction = model.predict(new_data).unwrap(); 500 | assert_eq!(prediction.len(), 2); 501 | assert_almost_eq!(prediction[0], 3.500000000000111, 1.0E-7); 502 | assert_almost_eq!(prediction[1], 2.5000000000001337, 1.0E-7); 503 | } 504 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | /*! 2 | `linregress` provides an easy to use implementation of ordinary 3 | least squared linear regression with some basic statistics. 4 | See [`RegressionModel`] for details. 5 | 6 | The builder [`FormulaRegressionBuilder`] is used to construct a model from a 7 | table of data and an R-style formula or a list of columns to use. 
8 | Currently only very simple formulae are supported, 9 | see [`FormulaRegressionBuilder.formula`] for details. 10 | 11 | # Example 12 | 13 | ``` 14 | use linregress::{FormulaRegressionBuilder, RegressionDataBuilder}; 15 | 16 | # use linregress::Error; 17 | # fn main() -> Result<(), Error> { 18 | let y = vec![1., 2. ,3. , 4., 5.]; 19 | let x1 = vec![5., 4., 3., 2., 1.]; 20 | let x2 = vec![729.53, 439.0367, 42.054, 1., 0.]; 21 | let x3 = vec![258.589, 616.297, 215.061, 498.361, 0.]; 22 | let data = vec![("Y", y), ("X1", x1), ("X2", x2), ("X3", x3)]; 23 | let data = RegressionDataBuilder::new().build_from(data)?; 24 | let formula = "Y ~ X1 + X2 + X3"; 25 | let model = FormulaRegressionBuilder::new() 26 | .data(&data) 27 | .formula(formula) 28 | .fit()?; 29 | let parameters: Vec<_> = model.iter_parameter_pairs().collect(); 30 | let pvalues: Vec<_> = model.iter_p_value_pairs().collect(); 31 | let standard_errors: Vec<_> = model.iter_se_pairs().collect(); 32 | assert_eq!( 33 | parameters, 34 | vec![ 35 | ("X1", -1.0000000000000004), 36 | ("X2", 1.5508427875232655e-15), 37 | ("X3", -1.4502288259166107e-15), 38 | ] 39 | ); 40 | assert_eq!( 41 | standard_errors, 42 | vec![ 43 | ("X1", 9.22799842631787e-13), 44 | ("X2", 4.184801029355531e-15), 45 | ("X3", 2.5552590991720465e-15), 46 | ] 47 | ); 48 | assert_eq!( 49 | pvalues, 50 | vec![ 51 | ("X1", 5.874726257570879e-13), 52 | ("X2", 0.7740647742008093), 53 | ("X3", 0.6713674042015161), 54 | ] 55 | ); 56 | # Ok(()) 57 | # } 58 | ``` 59 | 60 | [`RegressionModel`]: struct.RegressionModel.html 61 | [`FormulaRegressionBuilder`]: struct.FormulaRegressionBuilder.html 62 | [`FormulaRegressionBuilder.formula`]: struct.FormulaRegressionBuilder.html#method.formula 63 | */ 64 | 65 | use std::borrow::Cow; 66 | use std::collections::{BTreeSet, HashMap, HashSet}; 67 | use std::iter; 68 | use std::ops::Neg; 69 | 70 | use nalgebra::{DMatrix, DVector}; 71 | 72 | pub use crate::error::{Error, InconsistentSlopes}; 73 | use 
crate::stats::students_t_cdf; 74 | 75 | mod error; 76 | mod stats; 77 | #[cfg(test)] 78 | mod tests; 79 | 80 | macro_rules! ensure { 81 | ($predicate:expr, $error:expr) => { 82 | if !$predicate { 83 | return Err($error); 84 | } 85 | }; 86 | } 87 | 88 | /// Only exposed for use in doc comments. This macro is not considered part of this crate's stable API. 89 | #[macro_export] 90 | macro_rules! assert_almost_eq { 91 | ($a:expr, $b:expr) => { 92 | $crate::assert_almost_eq!($a, $b, 1.0E-14); 93 | }; 94 | ($a:expr, $b:expr, $prec:expr) => { 95 | if !$crate::almost_equal($a, $b, $prec) { 96 | panic!("assert_almost_eq failed:\n{:?} vs\n{:?}", $a, $b); 97 | } 98 | }; 99 | } 100 | 101 | /// Only exposed for use in doc comments. This macro is not considered part of this crate's stable API. 102 | #[macro_export] 103 | macro_rules! assert_slices_almost_eq { 104 | ($a:expr, $b:expr) => { 105 | $crate::assert_slices_almost_eq!($a, $b, 1.0E-14); 106 | }; 107 | ($a:expr, $b:expr, $prec:expr) => { 108 | if !$crate::slices_almost_equal($a, $b, $prec) { 109 | panic!("assert_slices_almost_eq failed:\n{:?} vs\n{:?}", $a, $b); 110 | } 111 | }; 112 | } 113 | 114 | /// Only exposed for use in doc comments. This function is not considered part of this crate's stable API. 115 | #[doc(hidden)] 116 | pub fn almost_equal(a: f64, b: f64, precision: f64) -> bool { 117 | if a.is_infinite() || b.is_infinite() || a.is_nan() || b.is_nan() { 118 | false 119 | } else { 120 | (a - b).abs() <= precision 121 | } 122 | } 123 | 124 | /// Only exposed for use in doc comments. This function is not considered part of this crate's stable API. 125 | #[doc(hidden)] 126 | pub fn slices_almost_equal(a: &[f64], b: &[f64], precision: f64) -> bool { 127 | if a.len() != b.len() { 128 | return false; 129 | } 130 | for (&x, &y) in a.iter().zip(b.iter()) { 131 | if !almost_equal(x, y, precision) { 132 | return false; 133 | } 134 | } 135 | true 136 | } 137 | 138 | /// Compares `a` and `b` approximately. 
139 | /// 140 | /// They are considered equal if 141 | /// `(a-b).abs() <= epsilon` or they differ by at most `max_ulps` 142 | /// `units of least precision` i.e. there are at most `max_ulps` 143 | /// other representable floating point numbers between `a` and `b` 144 | fn ulps_eq(a: f64, b: f64, epsilon: f64, max_ulps: u32) -> bool { 145 | if (a - b).abs() <= epsilon { 146 | return true; 147 | } 148 | if a.signum() != b.signum() { 149 | return false; 150 | } 151 | let a: u64 = a.to_bits(); 152 | let b: u64 = b.to_bits(); 153 | a.abs_diff(b) <= max_ulps as u64 154 | } 155 | 156 | /// A builder to create and fit a linear regression model. 157 | /// 158 | /// Given a dataset and a set of columns to use this builder 159 | /// will produce an ordinary least squared linear regression model. 160 | /// 161 | /// See [`formula`] and [`data`] for details on how to configure this builder. 162 | /// 163 | /// The pseudo inverse method is used to fit the model. 164 | /// 165 | /// # Usage 166 | /// 167 | /// ``` 168 | /// use linregress::{FormulaRegressionBuilder, RegressionDataBuilder, assert_almost_eq}; 169 | /// 170 | /// # use linregress::Error; 171 | /// # fn main() -> Result<(), Error> { 172 | /// let y = vec![1., 2. 
,3., 4.]; 173 | /// let x = vec![4., 3., 2., 1.]; 174 | /// let data = vec![("Y", y), ("X", x)]; 175 | /// let data = RegressionDataBuilder::new().build_from(data)?; 176 | /// let model = FormulaRegressionBuilder::new().data(&data).formula("Y ~ X").fit()?; 177 | /// // Alternatively 178 | /// let model = FormulaRegressionBuilder::new().data(&data).data_columns("Y", ["X"]).fit()?; 179 | /// let params = model.parameters(); 180 | /// assert_almost_eq!(params[0], 4.999999999999998); 181 | /// assert_almost_eq!(params[1], -0.9999999999999989); 182 | /// assert_eq!(model.regressor_names()[0], "X"); 183 | /// # Ok(()) 184 | /// # } 185 | /// ``` 186 | /// 187 | /// [`formula`]: struct.FormulaRegressionBuilder.html#method.formula 188 | /// [`data`]: struct.FormulaRegressionBuilder.html#method.data 189 | #[derive(Debug, Clone)] 190 | pub struct FormulaRegressionBuilder<'a> { 191 | data: Option<&'a RegressionData<'a>>, 192 | formula: Option>, 193 | columns: Option<(Cow<'a, str>, Vec>)>, 194 | } 195 | 196 | impl Default for FormulaRegressionBuilder<'_> { 197 | fn default() -> Self { 198 | FormulaRegressionBuilder::new() 199 | } 200 | } 201 | 202 | impl<'a> FormulaRegressionBuilder<'a> { 203 | /// Create as new FormulaRegressionBuilder with no data or formula set. 204 | pub fn new() -> Self { 205 | FormulaRegressionBuilder { 206 | data: None, 207 | formula: None, 208 | columns: None, 209 | } 210 | } 211 | 212 | /// Set the data to be used for the regression. 213 | /// 214 | /// The data has to be given as a reference to a [`RegressionData`] struct. 215 | /// See [`RegressionDataBuilder`] for details. 216 | /// 217 | /// [`RegressionData`]: struct.RegressionData.html 218 | /// [`RegressionDataBuilder`]: struct.RegressionDataBuilder.html 219 | pub fn data(mut self, data: &'a RegressionData<'a>) -> Self { 220 | self.data = Some(data); 221 | self 222 | } 223 | 224 | /// Set the formula to use for the regression. 225 | /// 226 | /// The expected format is ` ~ + `. 
227 | /// 228 | /// E.g. for a regressand named Y and three regressors named A, B and C 229 | /// the correct format would be `Y ~ A + B + C`. 230 | /// 231 | /// Note that there is currently no special support for categorical variables. 232 | /// So if you have a categorical variable with more than two distinct values 233 | /// or values that are not `0` and `1` you will need to perform "dummy coding" yourself. 234 | /// 235 | /// Alternatively you can use [`data_columns`][Self::data_columns]. 236 | pub fn formula>>(mut self, formula: T) -> Self { 237 | self.formula = Some(formula.into()); 238 | self 239 | } 240 | 241 | /// Set the columns to be used as regressand and regressors for the regression. 242 | /// 243 | /// Note that there is currently no special support for categorical variables. 244 | /// So if you have a categorical variable with more than two distinct values 245 | /// or values that are not `0` and `1` you will need to perform "dummy coding" yourself. 246 | /// 247 | /// Alternatively you can use [`formula`][Self::formula]. 248 | pub fn data_columns(mut self, regressand: S1, regressors: I) -> Self 249 | where 250 | I: IntoIterator, 251 | S1: Into>, 252 | S2: Into>, 253 | { 254 | let regressand = regressand.into(); 255 | let regressors: Vec<_> = regressors.into_iter().map(|i| i.into()).collect(); 256 | self.columns = Some((regressand, regressors)); 257 | self 258 | } 259 | 260 | /// Fits the model and returns a [`RegressionModel`] if successful. 261 | /// You need to set the data with [`data`] and a formula with [`formula`] 262 | /// before you can use it. 
263 | /// 264 | /// [`RegressionModel`]: struct.RegressionModel.html 265 | /// [`data`]: struct.FormulaRegressionBuilder.html#method.data 266 | /// [`formula`]: struct.FormulaRegressionBuilder.html#method.formula 267 | pub fn fit(self) -> Result { 268 | let FittingData(input_vector, output_matrix, outputs) = 269 | Self::get_matrices_and_regressor_names(self)?; 270 | RegressionModel::try_from_matrices_and_regressor_names(input_vector, output_matrix, outputs) 271 | } 272 | 273 | /// Like [`fit`] but does not perfom any statistics on the resulting model. 274 | /// Returns a [`Vec`] containing the model parameters 275 | /// (in the order `intercept, column 1, column 2, …`) if successfull. 276 | /// 277 | /// This is usefull if you do not care about the statistics or the model and data 278 | /// you want to fit result in too few residual degrees of freedom to perform 279 | /// statistics. 280 | /// 281 | /// [`fit`]: struct.FormulaRegressionBuilder.html#method.fit 282 | pub fn fit_without_statistics(self) -> Result, Error> { 283 | let FittingData(input_vector, output_matrix, _output_names) = 284 | Self::get_matrices_and_regressor_names(self)?; 285 | let low_level_result = fit_ols_pinv(input_vector, output_matrix)?; 286 | let parameters = low_level_result.params; 287 | Ok(parameters.iter().copied().collect()) 288 | } 289 | 290 | fn get_matrices_and_regressor_names(self) -> Result { 291 | let (input, outputs) = self.get_data_columns()?; 292 | let data = &self.data.ok_or(Error::NoData)?.data; 293 | let input_vector: Vec = data 294 | .get(&input) 295 | .cloned() 296 | .ok_or_else(|| Error::ColumnNotInData(input.into()))?; 297 | let mut output_matrix = Vec::new(); 298 | // Add column of all ones as the first column of the matrix 299 | output_matrix.resize(input_vector.len(), 1.); 300 | // Add each input as a new column of the matrix 301 | for output in &outputs { 302 | let output_vec = data 303 | .get(output.as_ref()) 304 | .ok_or_else(|| 
Error::ColumnNotInData(output.to_string()))?; 305 | ensure!( 306 | output_vec.len() == input_vector.len(), 307 | Error::RegressorRegressandDimensionMismatch(output.to_string()) 308 | ); 309 | output_matrix.extend_from_slice(output_vec); 310 | } 311 | let output_matrix = DMatrix::from_vec(input_vector.len(), outputs.len() + 1, output_matrix); 312 | let outputs: Vec<_> = outputs.iter().map(|x| x.to_string()).collect(); 313 | Ok(FittingData(input_vector, output_matrix, outputs)) 314 | } 315 | 316 | fn get_data_columns(&self) -> Result<(Cow<'_, str>, Vec>), Error> { 317 | match (self.formula.as_ref(), self.columns.as_ref()) { 318 | (Some(..), Some(..)) => Err(Error::BothFormulaAndDataColumnsGiven), 319 | (Some(formula), None) => Self::parse_formula(formula), 320 | (None, Some((regressand, regressors))) => { 321 | ensure!(!regressors.is_empty(), Error::InvalidDataColumns); 322 | Ok((regressand.clone(), regressors.clone())) 323 | } 324 | (None, None) => Err(Error::NoFormula), 325 | } 326 | } 327 | 328 | fn parse_formula(formula: &str) -> Result<(Cow<'_, str>, Vec>), Error> { 329 | let (input, outputs) = formula.split_once('~').ok_or(Error::InvalidFormula)?; 330 | let input = input.trim(); 331 | let outputs: Vec<_> = outputs 332 | .split('+') 333 | .map(str::trim) 334 | .filter(|x| !x.is_empty()) 335 | .map(|i| i.into()) 336 | .collect(); 337 | ensure!(!outputs.is_empty(), Error::InvalidFormula); 338 | Ok((input.into(), outputs)) 339 | } 340 | } 341 | 342 | /// A simple tuple struct to reduce the type complxity of the 343 | /// return type of get_matrices_and_regressor_names. 344 | struct FittingData(Vec, DMatrix, Vec); 345 | 346 | /// A container struct for the regression data. 347 | /// 348 | /// This struct is obtained using a [`RegressionDataBuilder`]. 
///
/// [`RegressionDataBuilder`]: struct.RegressionDataBuilder.html
#[derive(Debug, Clone)]
pub struct RegressionData<'a> {
    // Column label -> column values; keys may not contain `~` or `+`
    // (enforced in `new`), all columns have equal, non-zero length.
    data: HashMap<Cow<'a, str>, Vec<f64>>,
}

impl<'a> RegressionData<'a> {
    /// Constructs a new `RegressionData` struct from any collection that
    /// implements the `IntoIterator` trait.
    ///
    /// The iterator must consist of tuples of the form `(S, Vec<f64>)` where
    /// `S` is a type that can be converted to a `Cow<'a, str>`.
    ///
    /// `invalid_value_handling` specifies what to do if invalid data is encountered.
    fn new<I, S>(
        data: I,
        invalid_value_handling: InvalidValueHandling,
    ) -> Result<RegressionData<'a>, Error>
    where
        I: IntoIterator<Item = (S, Vec<f64>)>,
        S: Into<Cow<'a, str>>,
    {
        let temp: HashMap<_, _> = data
            .into_iter()
            .map(|(key, value)| (key.into(), value))
            .collect();
        ensure!(
            !temp.is_empty(),
            Error::RegressionDataError("The data contains no columns.".into())
        );
        // Validate every column: non-empty, same length as the first one seen,
        // and a name that cannot clash with the formula separators.
        let mut len: Option<usize> = None;
        for (key, val) in temp.iter() {
            let this_len = val.len();
            if len.is_none() {
                len = Some(this_len);
            }
            ensure!(
                this_len > 0,
                Error::RegressionDataError("The data contains an empty column.".into())
            );
            ensure!(
                Some(this_len) == len,
                Error::RegressionDataError(
                    "The lengths of the columns in the given data are inconsistent.".into()
                )
            );
            ensure!(
                !key.contains('~') && !key.contains('+'),
                Error::RegressionDataError(
                    "The column names may not contain `~` or `+`, because they are used \
                     as separators in the formula."
                        .into()
                )
            );
        }
        // All-constant columns would make the statistics (e.g. R²) meaningless.
        if Self::check_if_all_columns_are_equal(&temp) {
            return Err(Error::RegressionDataError(
                "All input columns contain only equal values. Fitting this model would lead \
                 to invalid statistics."
                    .into(),
            ));
        }
        if Self::check_if_data_is_valid(&temp) {
            return Ok(Self { data: temp });
        }
        // Only reached when some value is NaN/±infinity.
        match invalid_value_handling {
            InvalidValueHandling::ReturnError => Err(Error::RegressionDataError(
                "The data contains a non real value (NaN or infinity or negative infinity). \
                 If you would like to silently drop these values configure the builder with \
                 InvalidValueHandling::DropInvalid."
                    .into(),
            )),
            InvalidValueHandling::DropInvalid => {
                let temp = Self::drop_invalid_values(temp);
                let first_key = temp.keys().next().expect("Cleaned data has no columns.");
                let first_len = temp[first_key].len();
                ensure!(
                    first_len > 0,
                    Error::RegressionDataError("The cleaned data is empty.".into())
                );
                Ok(Self { data: temp })
            }
        }
    }

    /// Returns `true` if every column consists of a single repeated value
    /// (i.e. there is no variance anywhere in the data).
    fn check_if_all_columns_are_equal(data: &HashMap<Cow<'a, str>, Vec<f64>>) -> bool {
        for column in data.values() {
            if column.is_empty() {
                return false;
            }
            // Compare the column against `column[0]` repeated `len` times.
            let first_iter = iter::repeat(&column[0]).take(column.len());
            if !first_iter.eq(column.iter()) {
                return false;
            }
        }
        true
    }

    /// Returns `true` if every value in every column is a finite `f64`
    /// (no NaN, no ±infinity).
    fn check_if_data_is_valid(data: &HashMap<Cow<'a, str>, Vec<f64>>) -> bool {
        for column in data.values() {
            if column.iter().any(|x| !x.is_finite()) {
                return false;
            }
        }
        true
    }

    /// Removes every row that contains a non-finite value in *any* column,
    /// keeping the columns aligned with each other.
    fn drop_invalid_values(
        data: HashMap<Cow<'a, str>, Vec<f64>>,
    ) -> HashMap<Cow<'a, str>, Vec<f64>> {
        // Collect the union of bad row indices across all columns.
        let mut invalid_rows: BTreeSet<usize> = BTreeSet::new();
        for column in data.values() {
            for (index, value) in column.iter().enumerate() {
                if !value.is_finite() {
                    invalid_rows.insert(index);
                }
            }
        }
        let mut cleaned_data = HashMap::new();
        for (key, mut column) in data {
            // Remove back-to-front so earlier indices stay valid.
            for index in invalid_rows.iter().rev() {
                column.remove(*index);
            }
            cleaned_data.insert(key, column);
        }
        cleaned_data
    }
}

/// A builder to create a RegressionData struct for use with a [`FormulaRegressionBuilder`].
///
/// [`FormulaRegressionBuilder`]: struct.FormulaRegressionBuilder.html
#[derive(Debug, Clone, Copy, Default)]
pub struct RegressionDataBuilder {
    handle_invalid_values: InvalidValueHandling,
}

impl RegressionDataBuilder {
    /// Create a new [`RegressionDataBuilder`].
    ///
    /// [`RegressionDataBuilder`]: struct.RegressionDataBuilder.html
    pub fn new() -> Self {
        Self::default()
    }

    /// Configure how to handle non real `f64` values (NaN or infinity or negative infinity) using
    /// a variant of the [`InvalidValueHandling`] enum.
    ///
    /// The default value is [`ReturnError`].
    ///
    /// # Example
    /// ```
    /// use linregress::{InvalidValueHandling, RegressionDataBuilder};
    ///
    /// # use linregress::Error;
    /// # fn main() -> Result<(), Error> {
    /// let builder = RegressionDataBuilder::new();
    /// let builder = builder.invalid_value_handling(InvalidValueHandling::DropInvalid);
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// [`InvalidValueHandling`]: enum.InvalidValueHandling.html
    /// [`ReturnError`]: enum.InvalidValueHandling.html#variant.ReturnError
    pub fn invalid_value_handling(mut self, setting: InvalidValueHandling) -> Self {
        self.handle_invalid_values = setting;
        self
    }

    /// Build a [`RegressionData`] struct from the given data.
    ///
    /// Any type that implements the [`IntoIterator`] trait can be used for the data.
    /// This could for example be a [`Hashmap`] or a [`Vec`].
    ///
    /// The iterator must consist of tuples of the form `(S, Vec<f64>)` where
    /// `S` is a type that implements `Into<Cow<str>>`, such as [`String`] or [`str`].
    ///
    /// You can think of this format as the representation of a table of data where
    /// each tuple `(S, Vec<f64>)` represents a column. The `S` is the header or label of the
    /// column and the `Vec<f64>` contains the data of the column.
    ///
    /// Because `~` and `+` are used as separators in the formula they may not be used in the name
    /// of a data column.
    ///
    /// # Example
    ///
    /// ```
    /// use std::collections::HashMap;
    /// use linregress::RegressionDataBuilder;
    ///
    /// # use linregress::Error;
    /// # fn main() -> Result<(), Error> {
    /// let mut data1 = HashMap::new();
    /// data1.insert("Y", vec![1., 2., 3., 4.]);
    /// data1.insert("X", vec![4., 3., 2., 1.]);
    /// let regression_data1 = RegressionDataBuilder::new().build_from(data1)?;
    ///
    /// let y = vec![1., 2., 3., 4.];
    /// let x = vec![4., 3., 2., 1.];
    /// let data2 = vec![("X", x), ("Y", y)];
    /// let regression_data2 = RegressionDataBuilder::new().build_from(data2)?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// [`RegressionData`]: struct.RegressionData.html
    /// [`IntoIterator`]: https://doc.rust-lang.org/std/iter/trait.IntoIterator.html
    /// [`Hashmap`]: https://doc.rust-lang.org/std/collections/struct.HashMap.html
    /// [`Vec`]: https://doc.rust-lang.org/std/vec/struct.Vec.html
    /// [`String`]: https://doc.rust-lang.org/std/string/struct.String.html
    /// [`str`]: https://doc.rust-lang.org/std/primitive.str.html
    pub fn build_from<'a, I, S>(self, data: I) -> Result<RegressionData<'a>, Error>
    where
        I: IntoIterator<Item = (S, Vec<f64>)>,
        S: Into<Cow<'a, str>>,
    {
        RegressionData::new(data, self.handle_invalid_values)
    }
}

/// How to proceed if given non real `f64` values (NaN or infinity or negative infinity).
///
/// Used with [`RegressionDataBuilder.invalid_value_handling`]
///
/// The default is [`ReturnError`].
///
/// [`RegressionDataBuilder.invalid_value_handling`]: struct.RegressionDataBuilder.html#method.invalid_value_handling
/// [`ReturnError`]: enum.InvalidValueHandling.html#variant.ReturnError
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub enum InvalidValueHandling {
    /// Return an error to the caller.
    #[default]
    ReturnError,
    /// Drop the rows containing the invalid values.
    DropInvalid,
}

/// A fitted regression model.
///
/// Is the result of [`FormulaRegressionBuilder.fit()`].
///
///[`FormulaRegressionBuilder.fit()`]: struct.FormulaRegressionBuilder.html#method.fit
#[derive(Debug, Clone)]
pub struct RegressionModel {
    // Names of the regressor columns, in the same order as the slope
    // entries of `model.parameters()[1..]`.
    regressor_names: Vec<String>,
    // The fitted model and its statistics; this type only adds the names.
    model: LowLevelRegressionModel,
}

impl RegressionModel {
    /// The names of the regressor columns
    #[inline]
    pub fn regressor_names(&self) -> &[String] {
        &self.regressor_names
    }

    /// The two-tailed p-values for the t-statistics of the parameters
    #[inline]
    pub fn p_values(&self) -> &[f64] {
        self.model.p_values()
    }

    /// Iterates over pairs of regressor columns and their associated p-values
    ///
    /// # Note
    ///
    /// This does not include the value for the intercept.
    ///
    /// # Usage
    ///
    /// ```
    /// # use linregress::Error;
    /// # fn main() -> Result<(), Error> {
    /// use linregress::{FormulaRegressionBuilder, RegressionDataBuilder};
    ///
    /// let y = vec![1., 2., 3., 4.];
    /// let x1 = vec![4., 3., 2., 1.];
    /// let x2 = vec![1., 2., 3., 4.];
    /// let data = vec![("Y", y), ("X1", x1), ("X2", x2)];
    /// let data = RegressionDataBuilder::new().build_from(data)?;
    /// let model = FormulaRegressionBuilder::new().data(&data).formula("Y ~ X1 + X2").fit()?;
    /// let pairs: Vec<(&str, f64)> = model.iter_p_value_pairs().collect();
    /// assert_eq!(pairs[0], ("X1", 1.7052707580549508e-28));
    /// assert_eq!(pairs[1], ("X2", 2.522589878779506e-31));
    /// # Ok(())
    /// # }
    /// ```
    #[inline]
    pub fn iter_p_value_pairs(&self) -> impl Iterator<Item = (&str, f64)> + '_ {
        // skip(1) drops the intercept's p-value so names and values line up.
        self.regressor_names
            .iter()
            .zip(self.model.p_values().iter().skip(1))
            .map(|(r, &v)| (r.as_str(), v))
    }

    /// The residuals of the model
    #[inline]
    pub fn residuals(&self) -> &[f64] {
        self.model.residuals()
    }

    /// The model's intercept and slopes (also known as betas)
    #[inline]
    pub fn parameters(&self) -> &[f64] {
        self.model.parameters()
    }

    /// Iterates over pairs of regressor columns and their associated slope values
    ///
    /// # Note
    ///
    /// This does not include the value for the intercept.
    ///
    /// # Usage
    ///
    /// ```
    /// # use linregress::Error;
    /// # fn main() -> Result<(), Error> {
    /// use linregress::{FormulaRegressionBuilder, RegressionDataBuilder};
    ///
    /// let y = vec![1., 2., 3., 4.];
    /// let x1 = vec![4., 3., 2., 1.];
    /// let x2 = vec![1., 2., 3., 4.];
    /// let data = vec![("Y", y), ("X1", x1), ("X2", x2)];
    /// let data = RegressionDataBuilder::new().build_from(data)?;
    /// let model = FormulaRegressionBuilder::new().data(&data).formula("Y ~ X1 + X2").fit()?;
    /// let pairs: Vec<(&str, f64)> = model.iter_parameter_pairs().collect();
    /// assert_eq!(pairs[0], ("X1", -0.03703703703703709));
    /// assert_eq!(pairs[1], ("X2", 0.9629629629629626));
    /// # Ok(())
    /// # }
    /// ```
    #[inline]
    pub fn iter_parameter_pairs(&self) -> impl Iterator<Item = (&str, f64)> + '_ {
        // skip(1) drops the intercept so names and slopes line up.
        self.regressor_names
            .iter()
            .zip(self.model.parameters().iter().skip(1))
            .map(|(r, &v)| (r.as_str(), v))
    }

    /// The standard errors of the parameter estimates
    #[inline]
    pub fn se(&self) -> &[f64] {
        self.model.se()
    }

    /// Iterates over pairs of regressor columns and their associated standard errors
    ///
    /// # Note
    ///
    /// This does not include the value for the intercept.
    ///
    /// # Usage
    ///
    /// ```
    /// # use linregress::Error;
    /// # fn main() -> Result<(), Error> {
    /// use linregress::{FormulaRegressionBuilder, RegressionDataBuilder};
    ///
    /// let y = vec![1., 2., 3., 4.];
    /// let x1 = vec![4., 3., 2., 1.];
    /// let x2 = vec![1., 2., 3., 4.];
    /// let data = vec![("Y", y), ("X1", x1), ("X2", x2)];
    /// let data = RegressionDataBuilder::new().build_from(data)?;
    /// let model = FormulaRegressionBuilder::new().data(&data).formula("Y ~ X1 + X2").fit()?;
    /// let pairs: Vec<(&str, f64)> = model.iter_se_pairs().collect();
    /// assert_eq!(pairs[0].0, "X1");
    /// assert_eq!(pairs[1].0, "X2");
    /// # Ok(())
    /// # }
    /// ```
    #[inline]
    pub fn iter_se_pairs(&self) -> impl Iterator<Item = (&str, f64)> + '_ {
        // skip(1) drops the intercept's standard error so names and values line up.
        self.regressor_names
            .iter()
            .zip(self.model.se().iter().skip(1))
            .map(|(r, &v)| (r.as_str(), v))
    }

    /// Sum of squared residuals
    #[inline]
    pub fn ssr(&self) -> f64 {
        self.model.ssr()
    }

    /// R-squared of the model
    #[inline]
    pub fn rsquared(&self) -> f64 {
        self.model.rsquared()
    }

    /// Adjusted R-squared of the model
    #[inline]
    pub fn rsquared_adj(&self) -> f64 {
        self.model.rsquared_adj()
    }

    /// A scale factor for the covariance matrix
    ///
    /// Note that the square root of `scale` is often
    /// called the standard error of the regression.
    #[inline]
    pub fn scale(&self) -> f64 {
        self.model.scale()
    }
    /// Evaluates the model on given new input data and returns the predicted values.
    ///
    /// The new data is expected to have the same columns as the original data.
    /// See [`RegressionDataBuilder.build`] for details on the type of the `new_data` parameter.
    ///
    /// ## Note
    ///
    /// This function does *no* special handling of non real values (NaN or infinity or negative infinity).
    /// Such a value in `new_data` will result in a corresponding meaningless prediction.
    ///
    /// ## Example
    ///
    /// ```
    /// # use linregress::{RegressionDataBuilder, FormulaRegressionBuilder, assert_slices_almost_eq};
    /// # use linregress::Error;
    /// # fn main() -> Result<(), Error> {
    /// let y = vec![1., 2., 3., 4., 5.];
    /// let x1 = vec![5., 4., 3., 2., 1.];
    /// let x2 = vec![729.53, 439.0367, 42.054, 1., 0.];
    /// let x3 = vec![258.589, 616.297, 215.061, 498.361, 0.];
    /// let data = vec![("Y", y), ("X1", x1), ("X2", x2), ("X3", x3)];
    /// let data = RegressionDataBuilder::new().build_from(data).unwrap();
    /// let formula = "Y ~ X1 + X2 + X3";
    /// let model = FormulaRegressionBuilder::new()
    ///     .data(&data)
    ///     .formula(formula)
    ///     .fit()?;
    /// let new_data = vec![
    ///     ("X1", vec![2.5, 3.5]),
    ///     ("X2", vec![2.0, 8.0]),
    ///     ("X3", vec![2.0, 1.0]),
    /// ];
    /// let prediction: Vec<f64> = model.predict(new_data)?;
    /// assert_slices_almost_eq!(&prediction, &[3.500000000000028, 2.5000000000000644], 1.0e-13);
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// [`RegressionDataBuilder.build`]: struct.RegressionDataBuilder.html#method.build_from
    pub fn predict<'a, I, S>(&self, new_data: I) -> Result<Vec<f64>, Error>
    where
        I: IntoIterator<Item = (S, Vec<f64>)>,
        S: Into<Cow<'a, str>>,
    {
        let new_data: HashMap<Cow<'_, str>, Vec<f64>> = new_data
            .into_iter()
            .map(|(key, value)| (key.into(), value))
            .collect();
        // Validates non-emptiness, consistent lengths, and an exact match
        // between the given columns and the model's regressors.
        self.check_variables(&new_data)?;
        // Safe unwrap: check_variables rejected empty data above.
        let input_len = new_data.values().next().unwrap().len();
        // Build the new design matrix column by column, in the same regressor
        // order that was used when fitting (column-major for DMatrix::from_vec).
        let mut new_data_values: Vec<f64> = vec![];
        for key in &self.regressor_names {
            new_data_values.extend_from_slice(new_data[&Cow::from(key)].as_slice());
        }

        // parameters[0] is the intercept; the rest are slopes.
        let num_regressors = self.model.parameters.len() - 1;
        let new_data_matrix = DMatrix::from_vec(input_len, num_regressors, new_data_values);
        let param_matrix = DMatrix::from_iterator(
            num_regressors,
            1,
            self.model.parameters.iter().skip(1).copied(),
        );
        let intercept = self.model.parameters[0];
        let intercept_matrix =
            DMatrix::from_iterator(input_len, 1, std::iter::repeat(intercept).take(input_len));
        // y_hat = X * beta + intercept
        let predictions = (new_data_matrix * param_matrix) + intercept_matrix;
        let predictions: Vec<f64> = predictions.into_iter().copied().collect();
        Ok(predictions)
    }

    /// Checks that `given_parameters` is usable for prediction: non-empty,
    /// all columns of equal length, and the set of column names identical to
    /// the model's regressor names (no missing and no extra columns).
    fn check_variables(
        &self,
        given_parameters: &HashMap<Cow<'_, str>, Vec<f64>>,
    ) -> Result<(), Error> {
        ensure!(!given_parameters.is_empty(), Error::NoData);
        let first_len = given_parameters.values().next().unwrap().len();
        ensure!(first_len > 0, Error::NoData);
        let model_parameters: HashSet<_> = self.regressor_names.iter().map(Cow::from).collect();
        // Every model regressor must be present in the given data...
        for param in &model_parameters {
            if !given_parameters.contains_key(param) {
                return Err(Error::ColumnNotInData(param.to_string()));
            }
        }
        // ...and the given data may not contain unknown or ragged columns.
        for (param, values) in given_parameters {
            ensure!(values.len() == first_len, Error::InconsistentVectors);
            if !model_parameters.contains(param) {
                return Err(Error::ModelColumnNotInData(param.to_string()));
            }
        }
        Ok(())
    }

    /// Fits the model and pairs the resulting slopes with their regressor names.
    ///
    /// Returns `Error::InconsistentSlopes` if the number of names does not match
    /// the number of fitted slopes (parameters minus the intercept).
    fn try_from_matrices_and_regressor_names<I: IntoIterator<Item = String>>(
        inputs: Vec<f64>,
        outputs: DMatrix<f64>,
        output_names: I,
    ) -> Result<Self, Error> {
        let low_level_result = fit_ols_pinv(inputs, outputs)?;
        let model = LowLevelRegressionModel::from_low_level_regression(low_level_result)?;
        let regressor_names: Vec<_> = output_names.into_iter().collect();
        let num_slopes = model.parameters.len() - 1;
        ensure!(
            regressor_names.len() == num_slopes,
            Error::InconsistentSlopes(InconsistentSlopes::new(regressor_names.len(), num_slopes))
        );
        Ok(Self {
            regressor_names,
            model,
        })
    }
}

/// A fitted regression model
///
/// Is the result of [`fit_low_level_regression_model`].
///
#[derive(Debug, Clone)]
pub struct LowLevelRegressionModel {
    /// The model's intercept and slopes (also known as betas).
    parameters: Vec<f64>,
    /// The standard errors of the parameter estimates.
    se: Vec<f64>,
    /// Sum of squared residuals.
    ssr: f64,
    /// R-squared of the model.
    rsquared: f64,
    /// Adjusted R-squared of the model.
    rsquared_adj: f64,
    /// The two-tailed p-values for the t-statistics of the params.
    pvalues: Vec<f64>,
    /// The residuals of the model.
    residuals: Vec<f64>,
    /// A scale factor for the covariance matrix.
    ///
    /// Note that the square root of `scale` is often
    /// called the standard error of the regression.
    scale: f64,
}

impl LowLevelRegressionModel {
    /// Computes the model statistics (residuals, R², standard errors,
    /// t-statistics and p-values) from a raw OLS fitting result.
    ///
    /// Fails if there are no residual degrees of freedom left
    /// (`n - rank < 1`) or if the p-value computation fails.
    fn from_low_level_regression(
        low_level_result: InternalLowLevelRegressionResult,
    ) -> Result<Self, Error> {
        let parameters = low_level_result.params;
        let singular_values = low_level_result.singular_values;
        let normalized_cov_params = low_level_result.normalized_cov_params;
        // The matrix rank is derived from the singular values of the design matrix.
        let diag = DMatrix::from_diagonal(&singular_values);
        let rank = &diag.rank(0.0);
        let input_vec = low_level_result.inputs.to_vec();
        let input_matrix = DMatrix::from_vec(low_level_result.inputs.len(), 1, input_vec);
        // residuals = y - X * beta
        let residuals = &input_matrix - (low_level_result.outputs * parameters.to_owned());
        let ssr = residuals.dot(&residuals);
        let n = low_level_result.inputs.len();
        // Residual degrees of freedom: observations minus the rank of the design matrix.
        let df_resid = n - rank;
        ensure!(
            df_resid >= 1,
            Error::ModelFittingError(
                "There are not enough residual degrees of freedom to perform statistics on this model".into()));
        let scale = residuals.dot(&residuals) / df_resid as f64;
        let cov_params = normalized_cov_params * scale;
        let se = get_se_from_cov_params(&cov_params);
        // Centered total sum of squares, used for R².
        let mean = input_matrix.mean();
        let mut centered_input_matrix = input_matrix;
        subtract_value_from_matrix(&mut centered_input_matrix, mean);
        let centered_tss = centered_input_matrix.dot(&centered_input_matrix);
        let rsquared = 1. - (ssr / centered_tss);
        let rsquared_adj = 1. - ((n - 1) as f64 / df_resid as f64 * (1. - rsquared));
        // t = beta / se; `max(EPSILON)` guards against division by a zero standard error.
        let tvalues = parameters
            .iter()
            .zip(se.iter())
            .map(|(&x, &y)| x / y.max(f64::EPSILON));
        // Two-tailed p-value: 2 * P(T <= -|t|) under the t-distribution with df_resid dof.
        let pvalues: Vec<f64> = tvalues
            .map(|x| students_t_cdf(x.abs().neg(), df_resid as i64).map(|i| i * 2.))
            .collect::<Option<Vec<f64>>>()
            .ok_or_else(|| {
                Error::ModelFittingError(
                    "Failed to calculate p-values: students_t_cdf failed due to invalid parameters"
                        .into(),
                )
            })?;
        // Convert these from internal Matrix types to user facing types
        let parameters: Vec<f64> = parameters.iter().copied().collect();
        let residuals: Vec<f64> = residuals.iter().copied().collect();
        Ok(Self {
            parameters,
            se,
            ssr,
            rsquared,
            rsquared_adj,
            pvalues,
            residuals,
            scale,
        })
    }

    /// The two-tailed p-values for the t-statistics of the parameters
    #[inline]
    pub fn p_values(&self) -> &[f64] {
        &self.pvalues
    }

    /// The residuals of the model
    #[inline]
    pub fn residuals(&self) -> &[f64] {
        &self.residuals
    }

    /// The model's intercept and slopes (also known as betas)
    #[inline]
    pub fn parameters(&self) -> &[f64] {
        &self.parameters
    }

    /// The standard errors of the parameter estimates
    #[inline]
    pub fn se(&self) -> &[f64] {
        &self.se
    }

    /// Sum of squared residuals
    #[inline]
    pub fn ssr(&self) -> f64 {
        self.ssr
    }

    /// R-squared of the model
    #[inline]
    pub fn rsquared(&self) -> f64 {
        self.rsquared
    }

    /// Adjusted R-squared of the model
    #[inline]
    pub fn rsquared_adj(&self) -> f64 {
        self.rsquared_adj
    }

    /// A scale factor for the covariance matrix
    ///
    /// Note that the square root of `scale` is often
    /// called the standard error of the regression.
    #[inline]
    pub fn scale(&self) -> f64 {
        self.scale
    }
}

/// Fit a regression model directly on a matrix of input data
///
/// Expects a matrix in the format
///
/// | regressand | intercept | regressor 1 | regressor 2 | …   |
/// |------------|-----------|-------------|-------------|-----|
/// | value      | 1.0       | value       | value       | …   |
/// | ⋮          | ⋮         | ⋮           | ⋮           | ⋮   |
///
/// in row major order.
///
/// # Note
/// - The matrix should already contain the `intercept` column consisting of only the value `1.0`.
/// - No validation of the data is performed, except for a simple dimension consistency check.
///
/// # Example
/// ```
/// # fn main() -> Result<(), linregress::Error> {
/// use linregress::{fit_low_level_regression_model, assert_slices_almost_eq};
///
/// let data_row_major: Vec<f64> = vec![
///     1., 1.0, 1., 7.,
///     3., 1.0, 2., 6.,
///     4., 1.0, 3., 5.,
///     5., 1.0, 4., 4.,
///     2., 1.0, 5., 3.,
///     3., 1.0, 6., 2.,
///     4., 1.0, 7., 1.,
/// ];
/// let model = fit_low_level_regression_model(&data_row_major, 7, 4)?;
/// let params = [
///     0.09523809523809518f64,
///     0.5059523809523807,
///     0.2559523809523811,
/// ];
/// assert_slices_almost_eq!(model.parameters(), &params);
/// # Ok(())
/// # }
/// ```
pub fn fit_low_level_regression_model(
    data_row_major: &[f64],
    num_rows: usize,
    num_columns: usize,
) -> Result<LowLevelRegressionModel, Error> {
    let regression = get_low_level_regression(data_row_major, num_rows, num_columns)?;
    let model = LowLevelRegressionModel::from_low_level_regression(regression)?;
    Ok(model)
}

/// Like [`fit_low_level_regression_model`] but does not compute any statistics after
/// fitting the model.
///
/// Returns a `Vec` analogous to the `parameters` field of [`LowLevelRegressionModel`].
pub fn fit_low_level_regression_model_without_statistics(
    data_row_major: &[f64],
    num_rows: usize,
    num_columns: usize,
) -> Result<Vec<f64>, Error> {
    let regression = get_low_level_regression(data_row_major, num_rows, num_columns)?;
    Ok(regression.params.iter().copied().collect())
}

/// Splits the row-major data into the regressand column (column 0) and the
/// design matrix (remaining columns, including the intercept column) and
/// runs the OLS fit on them.
fn get_low_level_regression(
    data_row_major: &[f64],
    num_rows: usize,
    num_columns: usize,
) -> Result<InternalLowLevelRegressionResult, Error> {
    ensure!(
        !data_row_major.is_empty() && num_rows * num_columns == data_row_major.len(),
        Error::InconsistentVectors
    );
    let data = DMatrix::from_row_slice(num_rows, num_columns, data_row_major);
    // First column is the regressand...
    let inputs = data.view((0, 0), (num_rows, 1));
    let inputs: Vec<f64> = inputs.iter().copied().collect();
    // ...the rest (intercept + regressors) form the design matrix.
    let outputs: DMatrix<f64> = data.view((0, 1), (num_rows, num_columns - 1)).into_owned();
    fit_ols_pinv(inputs, outputs)
}

/// Result of fitting a low level matrix based model
#[derive(Debug, Clone)]
struct InternalLowLevelRegressionResult {
    inputs: Vec<f64>,
    outputs: DMatrix<f64>,
    params: DMatrix<f64>,
    singular_values: DVector<f64>,
    normalized_cov_params: DMatrix<f64>,
}

/// Performs ordinary least squared linear regression using the pseudo inverse method
fn fit_ols_pinv(
    inputs: Vec<f64>,
    outputs: DMatrix<f64>,
) -> Result<InternalLowLevelRegressionResult, Error> {
    ensure!(
        !inputs.is_empty(),
        Error::ModelFittingError(
            "Fitting the model failed because the input vector is empty".into()
        )
    );
    ensure!(
        outputs.nrows() >= 1 && outputs.ncols() >= 1,
        Error::ModelFittingError(
            "Fitting the model failed because the output matrix is empty".into()
        )
    );
    // The singular values are kept so the caller can later determine the
    // rank of the design matrix (for the residual degrees of freedom).
    let singular_values = outputs
        .to_owned()
        .try_svd(false, false, f64::EPSILON, 0)
        .ok_or_else(|| {
            Error::ModelFittingError(
                "Computing the singular-value decomposition of the output matrix failed".into(),
            )
        })?
        .singular_values;
    let pinv = outputs.clone().pseudo_inverse(0.).map_err(|_| {
        Error::ModelFittingError("Taking the pinv of the output matrix failed".into())
    });
    let pinv = pinv?;
    // normalized covariance of the parameters: pinv(X) * pinv(X)^T
    let normalized_cov_params = &pinv * &pinv.transpose();
    // beta = pinv(X) * y
    let params = get_sum_of_products(&pinv, &inputs);
    // At least an intercept and one slope are required.
    ensure!(
        params.len() >= 2,
        Error::ModelFittingError("Invalid parameter matrix".into())
    );
    Ok(InternalLowLevelRegressionResult {
        inputs,
        outputs,
        params,
        singular_values,
        normalized_cov_params,
    })
}

/// Subtracts `sub` from every element of `matrix` in place
/// (used to center the regressand around its mean).
fn subtract_value_from_matrix(matrix: &mut DMatrix<f64>, sub: f64) {
    for i in matrix.iter_mut() {
        *i -= sub;
    }
}

/// Calculates the standard errors given a model's covariate parameters
///
/// The standard errors are the square roots of the diagonal of the
/// (square) covariance matrix.
fn get_se_from_cov_params(matrix: &DMatrix<f64>) -> Vec<f64> {
    matrix
        .row_iter()
        .enumerate()
        .map(|(n, row)| row.get(n).expect("BUG: Matrix is not square").sqrt())
        .collect()
}

/// Multiplies `matrix` by the column vector `vector`, returning the
/// result as a single-column matrix (row-wise dot products).
fn get_sum_of_products(matrix: &DMatrix<f64>, vector: &[f64]) -> DMatrix<f64> {
    DMatrix::from_iterator(
        matrix.nrows(),
        1,
        matrix
            .row_iter()
            .map(|row| row.iter().zip(vector.iter()).map(|(x, y)| x * y).sum()),
    )
}