├── tests ├── R │ ├── log_termination_0 │ │ ├── aic.csv │ │ ├── bic.csv │ │ ├── deviance.csv │ │ ├── null_dev.csv │ │ ├── coefficients.csv │ │ ├── covariance.csv │ │ ├── dev_resid.csv │ │ ├── student_resid.csv │ │ ├── standard_deviance_resid.csv │ │ └── standard_pearson_resid.csv │ ├── log_regularization │ │ ├── iris_setosa_l2_1e-2.csv │ │ ├── iris_setosa_l1_l2_1e-2.csv │ │ └── iris_versicolor_l1_1e-2.csv │ ├── iris_l1_l2.R │ ├── iris_l1.R │ ├── logistic.R │ └── iris_l2.R ├── edge_cases.rs ├── lr_test_sign.rs ├── linear_offset.rs ├── logistic.rs ├── custom_link.rs ├── data │ ├── iris.csv │ ├── log_regularization.csv │ ├── lr_test_sign0.csv │ ├── lr_test_sign1.csv │ ├── log_termination_1.csv │ └── log_termination_0.csv ├── common │ └── mod.rs └── regularization.rs ├── .gitignore ├── .travis.yml ├── src ├── response.rs ├── error.rs ├── standardize.rs ├── num.rs ├── utility.rs ├── math.rs ├── fit │ └── options.rs ├── response │ ├── binomial.rs │ ├── poisson.rs │ ├── linear.rs │ └── logistic.rs ├── link.rs ├── lib.rs ├── glm.rs ├── model.rs ├── regularization.rs ├── irls.rs └── fit.rs ├── LICENSE ├── Cargo.toml └── README.md /tests/R/log_termination_0/aic.csv: -------------------------------------------------------------------------------- 1 | 73.97466 2 | -------------------------------------------------------------------------------- /tests/R/log_termination_0/bic.csv: -------------------------------------------------------------------------------- 1 | 80.56127 2 | -------------------------------------------------------------------------------- /tests/R/log_termination_0/deviance.csv: -------------------------------------------------------------------------------- 1 | 69.97466 2 | -------------------------------------------------------------------------------- /tests/R/log_termination_0/null_dev.csv: -------------------------------------------------------------------------------- 1 | 70.14744 2 | -------------------------------------------------------------------------------- /tests/R/log_termination_0/coefficients.csv: -------------------------------------------------------------------------------- 1 | 0.01697901 2 | -0.2311846 3 | -------------------------------------------------------------------------------- /tests/R/log_termination_0/covariance.csv: -------------------------------------------------------------------------------- 1 | 0.119088 2 | -0.0109707 3 | -0.0109707 4 | 0.3187949 5 | -------------------------------------------------------------------------------- /tests/R/log_regularization/iris_setosa_l2_1e-2.csv: -------------------------------------------------------------------------------- 1 | 10.23981 2 | -0.564851 3 | 2.44362 4 | -4.951192 5 | -2.414187 6 | -------------------------------------------------------------------------------- /tests/R/log_regularization/iris_setosa_l1_l2_1e-2.csv: -------------------------------------------------------------------------------- 1 | -5.494414 2 | -1.66715 3 | 2.081791 4 | -4.195334 5 | -3.804016 6 | -------------------------------------------------------------------------------- /tests/R/log_regularization/iris_versicolor_l1_1e-2.csv: -------------------------------------------------------------------------------- 1 | -0.999758 2 | -0.1946723 3 | -1.221296 4 | 2.292068 5 | -2.099191 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | 
/target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 7 | Cargo.lock 8 | 9 | # These are backup files generated by rustfmt 10 | **/*.rs.bk 11 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: rust 2 | rust: 3 | - stable 4 | - beta 5 | - nightly 6 | addons: 7 | apt: 8 | packages: 9 | - libopenblas-dev 10 | jobs: 11 | allow_failures: 12 | - rust: nightly 13 | # The nightly features have been stabilized, so there's no longer need for 14 | # additional tasks here. 15 | # include: 16 | # - rust: nightly 17 | # script: 18 | # - cargo +nightly build --verbose --features nightly 19 | # - cargo +nightly test --verbose --features nightly 20 | -------------------------------------------------------------------------------- /src/response.rs: -------------------------------------------------------------------------------- 1 | //! Response functions 2 | 3 | use crate::{error::RegressionResult, glm::Glm, num::Float}; 4 | 5 | pub mod binomial; 6 | pub mod linear; 7 | pub mod logistic; 8 | pub mod poisson; 9 | 10 | /// Describes the domain of the response variable for a GLM, e.g. integer for 11 | /// Poisson, float for Linear, bool for logistic. Implementing this trait for a 12 | /// type Y shows how to convert to a floating point type and allows that type to 13 | /// be used as a response variable. 14 | pub trait Response { 15 | /// Converts the domain to a floating-point value for IRLS. 16 | fn into_float(self) -> RegressionResult; 17 | } 18 | -------------------------------------------------------------------------------- /tests/edge_cases.rs: -------------------------------------------------------------------------------- 1 | //! Handles edge cases that have caused trouble at times. 2 | use anyhow::Result; 3 | use ndarray::{array, Array2}; 4 | use ndarray_glm::{Logistic, ModelBuilder}; 5 | use num_traits::float::FloatCore; 6 | 7 | /// Ensure that a valid likelihood is returned when the initial guess is the 8 | /// best one. 9 | #[test] 10 | fn start_zero() -> Result<()> { 11 | // Exactly half of the data are true, meaning the initial guess of beta = 0 will be the best. 12 | let data_y = array![true, false, false, true]; 13 | let data_x: Array2 = array![[], [], [], []]; 14 | let model = ModelBuilder::::data(&data_y, &data_x).build()?; 15 | let fit = model.fit()?; 16 | assert!(fit.model_like > -f64::infinity()); 17 | 18 | Ok(()) 19 | } 20 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | //! 
Defines the error enum for the result of regressions 2 | 3 | use ndarray_linalg::error::LinalgError; 4 | use thiserror::Error; 5 | 6 | #[derive(Error, Debug)] 7 | pub enum RegressionError { 8 | #[error("Inconsistent input: {0}")] 9 | BadInput(String), 10 | #[error("Invalid response data: {0}")] 11 | InvalidY(String), 12 | #[error("Linear algebra")] 13 | LinalgError { 14 | #[from] 15 | source: LinalgError, 16 | }, 17 | #[error("Underconstrained data")] 18 | Underconstrained, 19 | #[error("Colinear data (X^T * X is not invertible)")] 20 | ColinearData, 21 | #[error("Maximum iterations ({0}) reached")] 22 | MaxIter(usize), 23 | } 24 | 25 | pub type RegressionResult<T> = Result<T, RegressionError>; 26 | -------------------------------------------------------------------------------- /tests/R/iris_l1_l2.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | # This script uses glmnet: 4 | # https://glmnet.stanford.edu/articles/glmnet.html 5 | library(glmnet) 6 | 7 | infile <- "../data/iris.csv" 8 | data <- read.csv(infile, header = TRUE) 9 | x_data <- data[ 10 | , 11 | c("sepal_length", "sepal_width", "petal_length", "petal_width") 12 | ] 13 | # scale externally so that we can match the operation 14 | x_data <- scale(x_data) 15 | # This class is separable, so this will test regularization 16 | y_data <- data["class"] == "setosa" 17 | l1 <- 1e-2 / length(y_data) 18 | l2 <- 1e-2 / length(y_data) 19 | lambda <- l1 + l2 20 | alpha <- l1 / lambda 21 | model <- glmnet( 22 | x_data, y_data, 23 | # Standardization is recommended particularly for L1, although the result is 24 | # re-scaled anyway. 25 | standardize = FALSE, 26 | alpha = alpha, 27 | lambda = lambda, 28 | thresh = 1e-20, 29 | family = "binomial", 30 | ) 31 | beta <- coef(model) 32 | print(beta) 33 | beta <- beta[, "s0"] 34 | write( 35 | beta, 36 | file = "log_regularization/iris_setosa_l1_l2_1e-2.csv", 37 | sep = "\n", 38 | ) 39 | -------------------------------------------------------------------------------- /src/standardize.rs: -------------------------------------------------------------------------------- 1 | //! Standardization of a design matrix. 2 | use ndarray::{Array1, Array2, Axis}; 3 | use num_traits::{Float, FromPrimitive}; 4 | 5 | /// Returns a standardization of a design matrix where rows are separate 6 | /// observations and columns are different independent variables. Each quantity 7 | /// has its mean subtracted and is then divided by the standard deviation. 
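/// For example (an illustrative sketch): a column with entries 1, 2, and 3 has mean 2 and population standard deviation sqrt(2/3) ≈ 0.8165, so it standardizes to approximately [-1.2247, 0.0, 1.2247].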
8 | pub fn standardize(mut design: Array2) -> Array2 9 | where 10 | F: Float + FromPrimitive + std::ops::DivAssign, 11 | { 12 | let n_obs: usize = design.nrows(); 13 | if n_obs >= 1 { 14 | // subtract the mean 15 | design = &design - &design.mean_axis(Axis(0)).expect("mean should succeed here"); 16 | } 17 | if n_obs >= 2 { 18 | // divide by the population standard deviation 19 | let std: Array1 = design.std_axis(Axis(0), F::zero()); 20 | // design = &design / &std; 21 | design.zip_mut_with(&std, |x, &sig| { 22 | if sig > F::zero() { 23 | *x /= sig; 24 | } 25 | }) 26 | } 27 | design 28 | } 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Felix Clark 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/R/iris_l1.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | # This script uses glmnet: 4 | # https://glmnet.stanford.edu/articles/glmnet.html 5 | library(glmnet) 6 | 7 | infile <- "../data/iris.csv" 8 | data <- read.csv(infile, header = TRUE) 9 | x_data <- data[ 10 | , 11 | c("sepal_length", "sepal_width", "petal_length", "petal_width") 12 | ] 13 | # scale externally so that we can match the operation 14 | x_data <- scale(x_data) 15 | # This class is seperable, so this will test regularization 16 | # y_data <- data["class"] == "setosa" 17 | y_data <- data["class"] == "versicolor" 18 | l1 <- 1e-2 / length(y_data) 19 | model <- glmnet( 20 | x_data, y_data, 21 | # Standardization is recommended particularly for L1, although the result is 22 | # re-scaled anyway. 23 | standardize = FALSE, 24 | # the LASSO penalty 25 | alpha = 1, 26 | # With this dataset, a large lambda zeros out all coefficients 27 | lambda = l1, 28 | # lambda = 0, 29 | thresh = 1e-20, 30 | family = "binomial", 31 | ) 32 | beta <- coef(model) 33 | print(beta) 34 | beta <- beta[, "s0"] 35 | write(beta, file = "log_regularization/iris_versicolor_l1_1e-2.csv", sep = "\n") 36 | -------------------------------------------------------------------------------- /src/num.rs: -------------------------------------------------------------------------------- 1 | //! 
numerical trait constraints 2 | use std::cmp; 3 | 4 | use ndarray::ScalarOperand; 5 | use ndarray_linalg::Lapack; 6 | 7 | pub trait Float: Sized + num_traits::Float + Lapack + ScalarOperand { 8 | // Return 1/2 = 0.5 9 | fn half() -> Self; 10 | 11 | /// A more conventional sign function, because the built-in signum treats signed zeros as 12 | /// positive and negative: https://github.com/rust-lang/rust/issues/57543 13 | fn sign(self) -> Self { 14 | if self == Self::zero() { 15 | Self::zero() 16 | } else { 17 | self.signum() 18 | } 19 | } 20 | 21 | /// total_cmp is not implemented in num_traits, so implement it here. 22 | fn total_cmp(&self, other: &Self) -> cmp::Ordering; 23 | } 24 | 25 | impl Float for f32 { 26 | fn half() -> Self { 27 | 0.5 28 | } 29 | 30 | fn total_cmp(&self, other: &Self) -> cmp::Ordering { 31 | self.total_cmp(other) 32 | } 33 | } 34 | impl Float for f64 { 35 | fn half() -> Self { 36 | 0.5 37 | } 38 | 39 | fn total_cmp(&self, other: &Self) -> cmp::Ordering { 40 | self.total_cmp(other) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /tests/R/logistic.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | infile <- "../data/log_termination_0.csv" 4 | data <- read.csv(infile, header=FALSE) 5 | model <- glm(V1 ~ V2, data, family="binomial", offset=V3) 6 | print(model) 7 | coefs <- model$coefficients 8 | write(coefs, file="log_termination_0/coefficients.csv", ncolumns=1) 9 | mod_sum <- summary(model) 10 | write(mod_sum$deviance.resid, file="log_termination_0/dev_resid.csv", ncolumns=1) 11 | # For logistic the scaled and unscaled covariances are the same 12 | # print(mod_sum$cov.scaled) 13 | cov_mat <- mod_sum$cov.unscaled 14 | write(cov_mat, file="log_termination_0/covariance.csv", ncolumns=1) 15 | write(model$deviance, file="log_termination_0/deviance.csv", ncolumns=1) 16 | write(model$null.deviance, file="log_termination_0/null_dev.csv", ncolumns=1) 17 | write(model$aic, file="log_termination_0/aic.csv", ncolumns=1) 18 | write(BIC(model), file="log_termination_0/bic.csv", ncolumns=1) 19 | write(rstandard(model, type="pearson"), file="log_termination_0/standard_pearson_resid.csv", ncolumns=1) 20 | write(rstandard(model, type="deviance"), file="log_termination_0/standard_deviance_resid.csv", ncolumns=1) 21 | write(rstudent(model), file="log_termination_0/student_resid.csv", ncolumns=1) 22 | 23 | # TODO: wald and score tests, bic, etc. 24 | # wald_score <- wald.test(cov_mat, coefs) 25 | # write(wald_score, file="log_termination_0/wald.csv") 26 | -------------------------------------------------------------------------------- /tests/R/iris_l2.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | # This script uses glmnet: 4 | # https://glmnet.stanford.edu/articles/glmnet.html 5 | library(glmnet) 6 | 7 | infile <- "../data/iris.csv" 8 | data <- read.csv(infile, header = TRUE) 9 | x_data <- data[ 10 | , 11 | c("sepal_length", "sepal_width", "petal_length", "petal_width") 12 | ] 13 | # This class is seperable, so this will test regularization 14 | y_data <- data["class"] == "setosa" 15 | # y_data <- data["class"] == "versicolor" 16 | # The glmnet package divides the squared errors by N, so to match our 17 | # convention we need to scale their lambda too. 
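# (Concretely, as the comment above implies: lambda_glmnet = lambda_ndarray_glm / N, where N = length(y_data) is the number of observations in iris.csv.)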
18 | l2 <- 1e-2 / length(y_data) 19 | model <- glmnet( 20 | x_data, y_data, 21 | # Standardization is recommended particularly for L1, but it doesn't change 22 | # the result because the result is re-scaled 23 | # standardize = TRUE, 24 | standardize = FALSE, 25 | # the ridge penalty 26 | alpha = 0, 27 | # With this dataset, a large lambda zeros out all coefficients 28 | lambda = l2, 29 | # the tolerance has to be much smaller to make this result more precise. 30 | thresh = 1e-10, 31 | family = "binomial", 32 | ) 33 | print(model) 34 | beta <- coef(model) 35 | print(beta) 36 | beta <- beta[, "s0"] 37 | # There are convenience functions to read array from single-column files 38 | write(beta, file = "log_regularization/iris_setosa_l2_1e-2.csv", sep = "\n") 39 | -------------------------------------------------------------------------------- /tests/lr_test_sign.rs: -------------------------------------------------------------------------------- 1 | //! Data that has resulted in a significantly negative LR test. 2 | mod common; 3 | use anyhow::Result; 4 | use common::y_x_off_from_csv; 5 | use ndarray_glm::{Logistic, ModelBuilder}; 6 | 7 | #[test] 8 | fn lr_test_sign0() -> Result<()> { 9 | // TODO: this assumes the tests are run from the root directory of the 10 | // crate. This might not be true in general, but it often will be. 11 | let (y, x, off) = y_x_off_from_csv::("tests/data/lr_test_sign0.csv")?; 12 | let model = ModelBuilder::::data(&y, &x) 13 | .linear_offset(off) 14 | .build()?; 15 | let fit = model.fit_options().l2_reg(2e-6).fit()?; 16 | dbg!(&fit.result); 17 | assert!(fit.lr_test() >= 0.); 18 | Ok(()) 19 | } 20 | 21 | // This test seems to have a first step that has a big jump but lands at exactly the same 22 | // likelihood, so it's useful for testing the step halving and termination logic. 23 | #[test] 24 | fn lr_test_sign1() -> Result<()> { 25 | let (y, x, off) = y_x_off_from_csv::("tests/data/lr_test_sign1.csv")?; 26 | let model = ModelBuilder::::data(&y, &x) 27 | .linear_offset(off) 28 | .build()?; 29 | // This fit failed with regularization in the range of about 3e-7 to 3e-6. 30 | // Only a single iteration was performed in this case, because step halving was not being 31 | // engaged when it should have. 32 | let fit = model.fit_options().l2_reg(1e-6).fit()?; 33 | assert!(fit.lr_test() >= 0.); 34 | Ok(()) 35 | } 36 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ndarray-glm" 3 | version = "0.0.12" 4 | authors = ["Felix Clark "] 5 | description = "Performs regression for generalized linear models using IRLS on data stored in arrays." 
6 | edition = "2021" 7 | repository = "https://github.com/felix-clark/ndarray-glm" 8 | readme = "README.md" 9 | license = "MIT" 10 | keywords = ["ndarray", "statistics", "regression", "glm", "irls"] 11 | categories = ["mathematics", "science"] 12 | 13 | [badges] 14 | maintenance = { status = "experimental" } 15 | travis-ci = { repository = "felix-clark/ndarray-glm" } 16 | 17 | [features] 18 | # Feature flags to forward to ndarray-linalg 19 | openblas-static = ["ndarray-linalg/openblas-static"] 20 | openblas-system = ["ndarray-linalg/openblas-system"] 21 | netlib-static = ["ndarray-linalg/netlib-static"] 22 | netlib-system = ["ndarray-linalg/netlib-system"] 23 | intel-mkl-static = ["ndarray-linalg/intel-mkl-static"] 24 | intel-mkl-system = ["ndarray-linalg/intel-mkl-system"] 25 | 26 | [dependencies] 27 | itertools = "0.10" 28 | ndarray = { version = "0.15", features = ["blas"] } 29 | ndarray-linalg = { version = "0.16" } 30 | num-traits = "0.2" 31 | thiserror = "1.0" 32 | 33 | [dev-dependencies] 34 | anyhow = "1.0" 35 | approx = "0.4" 36 | ndarray = { version = "0.15", features = ["blas", "approx"] } 37 | ndarray-linalg = { version = "0.16", default-features = false, features = ["openblas-system"]} 38 | 39 | # [package.metadata.docs.rs] 40 | # TODO: use ndarray-linalg example for katex in docs. This HTML file needs to be added. 41 | # rustdoc-args = ["--html-in-header", "katex-header.html"] 42 | -------------------------------------------------------------------------------- /src/utility.rs: -------------------------------------------------------------------------------- 1 | //! Utility functions for internal library use 2 | 3 | use ndarray::{concatenate, Array1, Array2, ArrayView2, Axis}; 4 | use num_traits::{ 5 | identities::One, 6 | {Float, FromPrimitive}, 7 | }; 8 | 9 | /// Prepend the input with a column of ones. 10 | /// Used to incorporate a constant intercept term in a regression. 11 | pub fn one_pad<T>(data: ArrayView2<T>) -> Array2<T> 12 | where 13 | T: Copy + One, 14 | { 15 | // create the ones column 16 | let ones: Array2<T> = Array2::ones((data.nrows(), 1)); 17 | // This should be guaranteed to succeed since we are manually specifying the dimension 18 | concatenate![Axis(1), ones, data] 19 | } 20 | 21 | /// Returns a standardization of a design matrix where rows are separate 22 | /// observations and columns are different independent variables. Each quantity 23 | /// has its mean subtracted and is then divided by the standard deviation. 24 | /// The normalization by the standard deviation is not performed if there is only 1 25 | /// observation, since such an operation is undefined. 26 | pub fn standardize<F>(mut design: Array2<F>) -> Array2<F> 27 | where 28 | F: Float + FromPrimitive + std::ops::DivAssign, 29 | { 30 | let n_obs: usize = design.nrows(); 31 | if n_obs >= 1 { 32 | // subtract the mean 33 | design = &design - &design.mean_axis(Axis(0)).expect("mean should succeed here"); 34 | } 35 | if n_obs >= 2 { 36 | // divide by the population standard deviation 37 | let std: Array1<F> = design.std_axis(Axis(0), F::zero()); 38 | // design = &design / &std; 39 | design.zip_mut_with(&std, |x, &sig| { 40 | if sig > F::zero() { 41 | *x /= sig; 42 | } 43 | }) 44 | } 45 | design 46 | } 47 | -------------------------------------------------------------------------------- /src/math.rs: -------------------------------------------------------------------------------- 1 | //! 
Mathematical helper functions 2 | use crate::num::Float; 3 | use ndarray::Array2; 4 | use ndarray_linalg::QRSquareInto; 5 | 6 | /// The product-logarithm function (not the W function) x * log(x). If x == 0, 0 is returned. 7 | pub fn prod_log(x: F) -> F 8 | where 9 | F: Float, 10 | { 11 | if x == F::zero() { 12 | return F::zero(); 13 | } 14 | x * num_traits::Float::ln(x) 15 | } 16 | 17 | /// Returns true iff the matrix is rank deficient with tolerance `eps` using QR 18 | /// decomposition. 19 | // NOTE: SVD may be faster 20 | pub fn is_rank_deficient(matrix: Array2, eps: F) -> ndarray_linalg::error::Result 21 | where 22 | F: Float, 23 | { 24 | if matrix.ncols() != matrix.nrows() { 25 | return Ok(true); 26 | } 27 | let (_, r) = matrix.qr_square_into()?; 28 | let diag = r.into_diag(); 29 | for e in diag.into_iter() { 30 | if num_traits::Float::abs(e) < eps { 31 | return Ok(true); 32 | } 33 | } 34 | Ok(false) 35 | } 36 | 37 | #[cfg(test)] 38 | mod tests { 39 | use super::*; 40 | use crate::array; 41 | use approx::assert_abs_diff_eq; 42 | 43 | #[test] 44 | fn test_prod_log() { 45 | assert_abs_diff_eq!(0., prod_log(0.)); 46 | let e: f64 = std::f64::consts::E; 47 | assert_abs_diff_eq!(e, prod_log(e)); 48 | } 49 | 50 | #[test] 51 | fn test_rank_def() { 52 | assert!(is_rank_deficient(array![[0., 1.]], 0.).unwrap()); 53 | assert!(!is_rank_deficient(array![[0., 1.], [2., 0.]], f32::EPSILON as f64).unwrap()); 54 | assert!(is_rank_deficient(array![[0., 1.], [0., 2.342]], f64::EPSILON).unwrap()); 55 | assert!(is_rank_deficient( 56 | array![[1., 1., 0.], [1., 0.5, 0.5], [1., 0.2, 0.8]], 57 | f64::EPSILON 58 | ) 59 | .unwrap()); 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/fit/options.rs: -------------------------------------------------------------------------------- 1 | //! Fit-specific configuration and fit builder 2 | use super::Fit; 3 | use crate::{error::RegressionResult, glm::Glm, model::Model, num::Float, Array1}; 4 | 5 | /// A builder struct for fit configuration 6 | pub struct FitConfig<'a, M, F> 7 | where 8 | M: Glm, 9 | F: Float, 10 | { 11 | pub(crate) model: &'a Model, 12 | pub options: FitOptions, 13 | } 14 | 15 | impl<'a, M, F> FitConfig<'a, M, F> 16 | where 17 | M: Glm, 18 | F: Float, 19 | { 20 | pub fn fit(self) -> RegressionResult> { 21 | M::regression(self.model, self.options) 22 | } 23 | 24 | /// Use a maximum number of iterations 25 | pub fn max_iter(mut self, max_iter: usize) -> Self { 26 | self.options.max_iter = max_iter; 27 | self 28 | } 29 | 30 | /// Set the tolerance of iteration 31 | pub fn tol(mut self, tol: F) -> Self { 32 | self.options.tol = tol; 33 | self 34 | } 35 | 36 | /// Use to set a L2 regularization parameter 37 | pub fn l2_reg(mut self, l2: F) -> Self { 38 | self.options.l2 = l2; 39 | self 40 | } 41 | 42 | pub fn l1_reg(mut self, l1: F) -> Self { 43 | self.options.l1 = l1; 44 | self 45 | } 46 | } 47 | 48 | /// Specifies the fitting options 49 | pub struct FitOptions 50 | where 51 | F: Float, 52 | { 53 | /// The maximum number of IRLS iterations 54 | pub max_iter: usize, 55 | /// The relative tolerance of the likelihood 56 | pub tol: F, 57 | /// The regularization of the fit 58 | pub l2: F, 59 | pub l1: F, 60 | /// An initial guess. A sensible default is selected if this is not provided. 
61 | pub init_guess: Option>, 62 | } 63 | 64 | impl Default for FitOptions 65 | where 66 | F: Float, 67 | { 68 | fn default() -> Self { 69 | Self { 70 | max_iter: 32, 71 | // This tolerance is rather small, but it is used in the context of a 72 | // fraction of the total likelihood. 73 | tol: F::epsilon(), 74 | l2: F::zero(), 75 | l1: F::zero(), 76 | init_guess: None, 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /tests/linear_offset.rs: -------------------------------------------------------------------------------- 1 | //! testing closure with a linear offset 2 | 3 | use anyhow::Result; 4 | use approx::assert_abs_diff_eq; 5 | use ndarray::{array, Array1, Array2}; 6 | use ndarray_glm::{Linear, ModelBuilder}; 7 | 8 | #[test] 9 | /// Check that the result is the same in linear regression when subtracting 10 | /// offsets from the y values as it is when adding linear offsets to the model. 11 | fn lin_off_0() -> Result<()> { 12 | let y_data: Array1 = array![0.6, 0.3, 0.5, 0.1]; 13 | let offsets: Array1 = array![0.1, -0.1, 0.2, 0.0]; 14 | let x_data: Array2 = array![[1.2, 0.7], [2.1, 0.8], [1.5, 0.6], [1.6, 0.3]]; 15 | 16 | let lin_model = ModelBuilder::::data(&y_data, &x_data) 17 | .linear_offset(offsets.clone()) 18 | .build()?; 19 | let lin_fit = lin_model.fit()?; 20 | let y_offset = y_data - offsets; 21 | let lin_model_off = ModelBuilder::::data(&y_offset, &x_data).build()?; 22 | let lin_fit_off = lin_model_off.fit()?; 23 | dbg!(&lin_fit.result); 24 | dbg!(&lin_fit_off.result); 25 | // Ensure that the two methods give consistent results 26 | assert_abs_diff_eq!( 27 | lin_fit.result, 28 | lin_fit_off.result, 29 | epsilon = 16.0 * f64::EPSILON 30 | ); 31 | 32 | Ok(()) 33 | } 34 | 35 | #[test] 36 | // Ensure that the linear offset term adjusts all values sanely. 37 | // TODO: similar test for all types of regression, to ensure they are using 38 | // linear_predictor() properly. 39 | fn lin_off_1() -> Result<()> { 40 | let data_x = array![ 41 | [-0.23, 2.1, 0.7], 42 | [1.2, 4.5, 1.3], 43 | [0.42, 1.8, 0.97], 44 | [0.4, 3.2, -0.3] 45 | ]; 46 | let data_y = array![1.23, 0.91, 2.34, 0.62]; 47 | let model = ModelBuilder::::data(&data_y, &data_x).build()?; 48 | let fit = model.fit()?; 49 | let result = fit.result; 50 | // a constant linear offset to add for easy checking 51 | let lin_off = 1.832; 52 | let lin_offsets = array![lin_off, lin_off, lin_off, lin_off]; 53 | let model_off = ModelBuilder::::data(&data_y, &data_x) 54 | .linear_offset(lin_offsets) 55 | .build()?; 56 | let off_fit = model_off.fit()?; 57 | dbg!(off_fit.n_iter); 58 | let off_result = off_fit.result; 59 | let mut compensated_offset_result = off_result; 60 | compensated_offset_result[0] += lin_off; 61 | assert_abs_diff_eq!( 62 | result, 63 | compensated_offset_result, 64 | epsilon = 32. * f64::EPSILON 65 | ); 66 | Ok(()) 67 | } 68 | -------------------------------------------------------------------------------- /src/response/binomial.rs: -------------------------------------------------------------------------------- 1 | //! Regression with a binomial response function. The N parameter must be known ahead of time. 2 | use crate::{ 3 | error::{RegressionError, RegressionResult}, 4 | glm::{DispersionType, Glm}, 5 | math::prod_log, 6 | num::Float, 7 | response::Response, 8 | }; 9 | 10 | /// Use a fixed type of u16 for the domain of the binomial distribution. 11 | type BinDom = u16; 12 | 13 | /// Binomial regression with a fixed N. 
Non-canonical link functions are not 14 | /// possible at this time due to the awkward ergonomics with the const trait 15 | /// parameter N. 16 | pub struct Binomial; 17 | 18 | impl Response> for BinDom { 19 | fn into_float(self) -> RegressionResult { 20 | F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string())) 21 | } 22 | } 23 | 24 | impl Glm for Binomial { 25 | /// Only the canonical link function is available for binomial regression. 26 | type Link = link::Logit; 27 | const DISPERSED: DispersionType = DispersionType::NoDispersion; 28 | 29 | /// The log-partition function for the binomial distribution is similar to 30 | /// that for logistic regression, but it is adjusted for the maximum value. 31 | fn log_partition(nat_par: F) -> F { 32 | let n: F = F::from(N).unwrap(); 33 | n * num_traits::Float::exp(nat_par).ln_1p() 34 | } 35 | 36 | fn variance(mean: F) -> F { 37 | let n_float: F = F::from(N).unwrap(); 38 | mean * (n_float - mean) / n_float 39 | } 40 | 41 | fn log_like_sat(y: F) -> F { 42 | let n: F = F::from(N).unwrap(); 43 | prod_log(y) + prod_log(n - y) - prod_log(n) 44 | } 45 | } 46 | 47 | pub mod link { 48 | use super::*; 49 | use crate::link::{Canonical, Link}; 50 | use num_traits::Float; 51 | 52 | pub struct Logit {} 53 | impl Canonical for Logit {} 54 | impl Link> for Logit { 55 | fn func(y: F) -> F { 56 | let n_float: F = F::from(N).unwrap(); 57 | Float::ln(y / (n_float - y)) 58 | } 59 | fn func_inv(lin_pred: F) -> F { 60 | let n_float: F = F::from(N).unwrap(); 61 | n_float / (F::one() + (-lin_pred).exp()) 62 | } 63 | } 64 | } 65 | 66 | #[cfg(test)] 67 | mod tests { 68 | use super::Binomial; 69 | use crate::{error::RegressionResult, model::ModelBuilder}; 70 | use approx::assert_abs_diff_eq; 71 | use ndarray::array; 72 | 73 | #[test] 74 | fn bin_reg() -> RegressionResult<()> { 75 | const N: u16 = 12; 76 | let ln2 = f64::ln(2.); 77 | let beta = array![0., 1.]; 78 | let data_x = array![[0.], [0.], [ln2], [ln2], [ln2]]; 79 | // the first two data points should average to 6 and the last 3 should average to 8. 80 | let data_y = array![5, 7, 9, 6, 9]; 81 | let model = ModelBuilder::>::data(&data_y, &data_x).build()?; 82 | let fit = model.fit()?; 83 | dbg!(&fit.result); 84 | dbg!(&fit.n_iter); 85 | assert_abs_diff_eq!(beta, fit.result, epsilon = 0.05 * f32::EPSILON as f64); 86 | Ok(()) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/response/poisson.rs: -------------------------------------------------------------------------------- 1 | //! Model for Poisson regression 2 | 3 | use crate::{ 4 | error::{RegressionError, RegressionResult}, 5 | glm::{DispersionType, Glm}, 6 | link::Link, 7 | math::prod_log, 8 | num::Float, 9 | response::Response, 10 | }; 11 | use num_traits::{ToPrimitive, Unsigned}; 12 | use std::marker::PhantomData; 13 | 14 | /// Poisson regression over an unsigned integer type. 15 | pub struct Poisson 16 | where 17 | L: Link>, 18 | { 19 | _link: PhantomData, 20 | } 21 | 22 | /// Poisson variables can be any unsigned integer. 23 | impl Response> for U 24 | where 25 | U: Unsigned + ToPrimitive + ToString + Copy, 26 | L: Link>, 27 | { 28 | fn into_float(self) -> RegressionResult { 29 | F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string())) 30 | } 31 | } 32 | // TODO: A floating point response for Poisson might also be do-able. 
33 | 34 | impl Glm for Poisson 35 | where 36 | L: Link>, 37 | { 38 | type Link = L; 39 | const DISPERSED: DispersionType = DispersionType::NoDispersion; 40 | 41 | /// The logarithm of the partition function for Poisson is the exponential of the natural 42 | /// parameter, which is the logarithm of the mean. 43 | fn log_partition(nat_par: F) -> F { 44 | num_traits::Float::exp(nat_par) 45 | } 46 | 47 | /// The variance of a Poisson variable is equal to its mean. 48 | fn variance(mean: F) -> F { 49 | mean 50 | } 51 | 52 | /// The saturation likelihood of the Poisson distribution is non-trivial. 53 | /// It is equal to y * (log(y) - 1). 54 | fn log_like_sat(y: F) -> F { 55 | prod_log(y) - y 56 | } 57 | } 58 | 59 | pub mod link { 60 | //! Link functions for Poisson regression 61 | use super::Poisson; 62 | use crate::{ 63 | link::{Canonical, Link}, 64 | num::Float, 65 | }; 66 | 67 | /// The canonical link function of the Poisson response is the logarithm. 68 | pub struct Log {} 69 | impl Canonical for Log {} 70 | impl Link> for Log { 71 | fn func(y: F) -> F { 72 | num_traits::Float::ln(y) 73 | } 74 | fn func_inv(lin_pred: F) -> F { 75 | num_traits::Float::exp(lin_pred) 76 | } 77 | } 78 | } 79 | 80 | #[cfg(test)] 81 | mod tests { 82 | use super::*; 83 | use crate::{error::RegressionResult, model::ModelBuilder}; 84 | use approx::assert_abs_diff_eq; 85 | use ndarray::{array, Array1}; 86 | 87 | #[test] 88 | fn poisson_reg() -> RegressionResult<()> { 89 | let ln2 = f64::ln(2.); 90 | let beta = array![0., ln2, -ln2]; 91 | let data_x = array![[1., 0.], [1., 1.], [0., 1.], [0., 1.]]; 92 | let data_y: Array1 = array![2, 1, 0, 1]; 93 | let model = ModelBuilder::::data(&data_y, &data_x).build()?; 94 | let fit = model.fit_options().max_iter(10).fit()?; 95 | dbg!(fit.n_iter); 96 | assert_abs_diff_eq!(beta, fit.result, epsilon = f32::EPSILON as f64); 97 | Ok(()) 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /tests/R/log_termination_0/dev_resid.csv: -------------------------------------------------------------------------------- 1 | 2.563145 2 | 2.71281 3 | 2.557949 4 | 2.475573 5 | 2.577566 6 | 1.634185 7 | 2.39281 8 | 2.580669 9 | 2.111554 10 | -0.2312484 11 | -0.2389869 12 | -0.3052555 13 | -0.3432332 14 | -0.2688686 15 | -0.2682756 16 | -0.3598812 17 | -0.4542541 18 | -0.3190135 19 | -0.2617458 20 | -0.2125298 21 | -0.3001661 22 | -0.3476463 23 | -0.2478706 24 | -0.2292369 25 | -0.2242134 26 | -0.2710139 27 | -0.2370418 28 | -0.2937149 29 | -0.2160693 30 | -0.3426471 31 | -0.2523456 32 | -0.3714077 33 | -0.2117829 34 | -0.3042595 35 | -0.2482059 36 | -0.2874135 37 | -0.3221059 38 | -0.2489937 39 | -0.308419 40 | -0.3376255 41 | -0.1807792 42 | -0.303479 43 | -0.1804502 44 | -0.3003403 45 | -0.2396847 46 | -0.2669215 47 | -0.1857231 48 | -0.2029178 49 | -0.2714028 50 | -0.2972798 51 | -0.1986253 52 | -0.2976784 53 | -0.231062 54 | -0.1958145 55 | -0.3974662 56 | -0.4039628 57 | -0.1921727 58 | -0.2193634 59 | -0.3240725 60 | -0.2650228 61 | -0.2664074 62 | -0.3539208 63 | -0.2301521 64 | -0.3209368 65 | -0.3683485 66 | -0.4995805 67 | -0.2696586 68 | -0.5156772 69 | -0.1891902 70 | -0.2738336 71 | -0.3753976 72 | -0.3090243 73 | -0.2291766 74 | -0.3729128 75 | -0.1987149 76 | -0.1705557 77 | -0.1944996 78 | -0.1861347 79 | -0.2618028 80 | -0.3849316 81 | -0.2927861 82 | -0.3956183 83 | -0.3285308 84 | -0.265307 85 | -0.2862236 86 | -0.5472937 87 | -0.2971533 88 | -0.3478463 89 | -0.2863305 90 | -0.2500459 91 | -0.3824633 92 | -0.4729475 93 | -0.3499525 
94 | -0.1059287 95 | -0.267188 96 | -0.2993311 97 | -0.2214057 98 | -0.3419842 99 | -0.3014065 100 | -0.359729 101 | -0.2242159 102 | -0.185147 103 | -0.4442257 104 | -0.3768091 105 | -0.2877922 106 | -0.2984893 107 | -0.2855188 108 | -0.2580203 109 | -0.2832515 110 | -0.3167523 111 | -0.1577498 112 | -0.2572238 113 | -0.3637878 114 | -0.3649873 115 | -0.4129295 116 | -0.324241 117 | -0.3669546 118 | -0.3268634 119 | -0.2859762 120 | -0.2924243 121 | -0.3804754 122 | -0.2312695 123 | -0.280733 124 | -0.2676919 125 | -0.204026 126 | -0.2211961 127 | -0.3715868 128 | -0.09556253 129 | -0.3251646 130 | -0.3593632 131 | -0.2560001 132 | -0.1933726 133 | -0.2720178 134 | -0.2755992 135 | -0.2695606 136 | -0.4143107 137 | -0.1389016 138 | -0.3704506 139 | -0.369172 140 | -0.3059059 141 | -0.2017364 142 | -0.2702863 143 | -0.2733071 144 | -0.3360871 145 | -0.573603 146 | -0.267818 147 | -0.2522235 148 | -0.26726 149 | -0.1850362 150 | -0.1110951 151 | -0.2638523 152 | -0.380087 153 | -0.2331742 154 | -0.1966642 155 | -0.2525341 156 | -0.3262789 157 | -0.2461315 158 | -0.2134938 159 | -0.2857572 160 | -0.2534688 161 | -0.2828471 162 | -0.3719114 163 | -0.199929 164 | -0.4029659 165 | -0.4392699 166 | -0.2864539 167 | -0.2778437 168 | -0.3422333 169 | -0.3956187 170 | -0.3336082 171 | -0.2686468 172 | -0.2201903 173 | -0.1748509 174 | -0.2316667 175 | -0.196696 176 | -0.3335312 177 | -0.2892762 178 | -0.2824684 179 | -0.2405036 180 | -0.5137442 181 | -0.3697195 182 | -0.3770862 183 | -0.3852902 184 | -0.1762655 185 | -0.3261795 186 | -0.26832 187 | -0.2237708 188 | -0.1379423 189 | -0.230036 190 | -0.195971 191 | -0.2524884 192 | -0.2125275 193 | -0.5141427 194 | -0.2793855 195 | -0.2335425 196 | -0.2568786 197 | -0.3590793 198 | -0.3090381 199 | -0.3398795 200 | -------------------------------------------------------------------------------- /tests/R/log_termination_0/student_resid.csv: -------------------------------------------------------------------------------- 1 | 2.587682 2 | 2.746958 3 | 2.583108 4 | 2.512852 5 | 2.608893 6 | 1.654251 7 | 2.415318 8 | 2.615712 9 | 2.134113 10 | -0.2319457 11 | -0.2392105 12 | -0.3058236 13 | -0.3462758 14 | -0.2692078 15 | -0.2685564 16 | -0.3607301 17 | -0.4557995 18 | -0.3201088 19 | -0.2624018 20 | -0.2127511 21 | -0.3009526 22 | -0.3485044 23 | -0.2484579 24 | -0.2300924 25 | -0.2243881 26 | -0.2714596 27 | -0.2375521 28 | -0.295254 29 | -0.2162214 30 | -0.3459671 31 | -0.2525779 32 | -0.3721678 33 | -0.2120097 34 | -0.3047669 35 | -0.2493472 36 | -0.2877781 37 | -0.3225911 38 | -0.2501642 39 | -0.3089554 40 | -0.3404515 41 | -0.1808865 42 | -0.3038796 43 | -0.1808051 44 | -0.3012676 45 | -0.2398962 46 | -0.2675434 47 | -0.1858893 48 | -0.2033739 49 | -0.271905 50 | -0.2980472 51 | -0.1987472 52 | -0.2980586 53 | -0.2312794 54 | -0.1960733 55 | -0.3984017 56 | -0.404877 57 | -0.1924043 58 | -0.2196824 59 | -0.3245705 60 | -0.2653371 61 | -0.2671174 62 | -0.3562378 63 | -0.2305505 64 | -0.3228536 65 | -0.3695073 66 | -0.502978 67 | -0.2711627 68 | -0.5192499 69 | -0.1892891 70 | -0.2741298 71 | -0.3761998 72 | -0.3097999 73 | -0.229356 74 | -0.3744174 75 | -0.1988427 76 | -0.1706503 77 | -0.1949097 78 | -0.186265 79 | -0.2622617 80 | -0.3897367 81 | -0.294643 82 | -0.396525 83 | -0.3291793 84 | -0.2666356 85 | -0.2866201 86 | -0.5507158 87 | -0.2975614 88 | -0.3487386 89 | -0.2871819 90 | -0.2503083 91 | -0.383297 92 | -0.4803092 93 | -0.3506951 94 | -0.1059773 95 | -0.2674639 96 | -0.2999573 97 | -0.2218031 98 | -0.34318 99 | -0.3020078 100 | -0.360933 
101 | -0.2246472 102 | -0.1854214 103 | -0.4462828 104 | -0.3812184 105 | -0.2885384 106 | -0.2994272 107 | -0.2859054 108 | -0.2583976 109 | -0.2838008 110 | -0.3178094 111 | -0.1578824 112 | -0.2575097 113 | -0.3644609 114 | -0.3665041 115 | -0.414233 116 | -0.3247601 117 | -0.3678703 118 | -0.3293673 119 | -0.2864585 120 | -0.2928615 121 | -0.3813885 122 | -0.2316296 123 | -0.2811741 124 | -0.2681558 125 | -0.2043816 126 | -0.2213806 127 | -0.3726801 128 | -0.09557555 129 | -0.3258438 130 | -0.3605564 131 | -0.2564229 132 | -0.193601 133 | -0.2733947 134 | -0.2762568 135 | -0.2698511 136 | -0.4162655 137 | -0.1390288 138 | -0.3736852 139 | -0.3707958 140 | -0.3063801 141 | -0.201857 142 | -0.2717656 143 | -0.27379 144 | -0.3369388 145 | -0.5781868 146 | -0.2684582 147 | -0.252513 148 | -0.2679379 149 | -0.1851703 150 | -0.1111294 151 | -0.2641326 152 | -0.3810984 153 | -0.2340801 154 | -0.1969205 155 | -0.2530023 156 | -0.3273567 157 | -0.2465098 158 | -0.213663 159 | -0.2861991 160 | -0.2542116 161 | -0.2836488 162 | -0.372664 163 | -0.2000606 164 | -0.4043297 165 | -0.4404808 166 | -0.2868014 167 | -0.2782634 168 | -0.342892 169 | -0.3969143 170 | -0.3342588 171 | -0.2693688 172 | -0.2209575 173 | -0.1749822 174 | -0.2318698 175 | -0.197108 176 | -0.3344105 177 | -0.2896621 178 | -0.283258 179 | -0.240743 180 | -0.515533 181 | -0.370965 182 | -0.3783169 183 | -0.3868603 184 | -0.1766237 185 | -0.3267775 186 | -0.26882 187 | -0.224594 188 | -0.1379905 189 | -0.2302183 190 | -0.1961036 191 | -0.2529098 192 | -0.2126827 193 | -0.5163152 194 | -0.2810275 195 | -0.2338765 196 | -0.2575158 197 | -0.3605888 198 | -0.3100438 199 | -0.3408504 200 | -------------------------------------------------------------------------------- /tests/R/log_termination_0/standard_deviance_resid.csv: -------------------------------------------------------------------------------- 1 | 2.569438 2 | 2.719348 3 | 2.564466 4 | 2.486821 5 | 2.585393 6 | 1.653321 7 | 2.40064 8 | 2.589377 9 | 2.123711 10 | -0.2326224 11 | -0.2394276 12 | -0.3063645 13 | -0.3491167 14 | -0.2695344 15 | -0.2688269 16 | -0.3615227 17 | -0.4571837 18 | -0.3211455 19 | -0.2630339 20 | -0.2129671 21 | -0.301702 22 | -0.3493092 23 | -0.2490258 24 | -0.2309225 25 | -0.2245584 26 | -0.2718882 27 | -0.238047 28 | -0.2967196 29 | -0.2163699 30 | -0.3490652 31 | -0.2528026 32 | -0.3728747 33 | -0.2122312 34 | -0.3052503 35 | -0.2504487 36 | -0.2881272 37 | -0.3230507 38 | -0.2512934 39 | -0.3094656 40 | -0.3430967 41 | -0.180992 42 | -0.3042613 43 | -0.1811536 44 | -0.3021508 45 | -0.2401015 46 | -0.2681419 47 | -0.1860525 48 | -0.2038196 49 | -0.272388 50 | -0.298779 51 | -0.1988667 52 | -0.2984216 53 | -0.2314907 54 | -0.1963269 55 | -0.3992622 56 | -0.4057158 57 | -0.1926314 58 | -0.2199934 59 | -0.3250419 60 | -0.2656401 61 | -0.2678007 62 | -0.3583976 63 | -0.2309377 64 | -0.324662 65 | -0.370585 66 | -0.5059416 67 | -0.2726045 68 | -0.5223382 69 | -0.1893862 70 | -0.2744146 71 | -0.3769447 72 | -0.310537 73 | -0.2295306 74 | -0.3758132 75 | -0.1989679 76 | -0.1707434 77 | -0.1953112 78 | -0.1863929 79 | -0.2627043 80 | -0.394138 81 | -0.2964099 82 | -0.3973599 83 | -0.329792 84 | -0.2679115 85 | -0.287 86 | -0.5536215 87 | -0.2979511 88 | -0.3495754 89 | -0.2879962 90 | -0.2505623 91 | -0.3840688 92 | -0.4867692 93 | -0.3513913 94 | -0.1060256 95 | -0.2677296 96 | -0.3005544 97 | -0.22219 98 | -0.3443027 99 | -0.3025809 100 | -0.3620563 101 | -0.2250669 102 | -0.1856907 103 | -0.4481322 104 | -0.3852732 105 | -0.2892522 106 | -0.300321 107 | 
-0.2862759 108 | -0.258762 109 | -0.2843272 110 | -0.3188107 111 | -0.1580132 112 | -0.2577859 113 | -0.3650888 114 | -0.3679154 115 | -0.4154234 116 | -0.3252514 117 | -0.3687231 118 | -0.3317211 119 | -0.2869205 120 | -0.2932796 121 | -0.3822342 122 | -0.2319795 123 | -0.2815973 124 | -0.2686024 125 | -0.2047293 126 | -0.2215604 127 | -0.373696 128 | -0.09558851 129 | -0.3264861 130 | -0.3616699 131 | -0.2568313 132 | -0.1938249 133 | -0.2747145 134 | -0.2768882 135 | -0.2701309 136 | -0.4180469 137 | -0.1391546 138 | -0.3766755 139 | -0.3723038 140 | -0.3068316 141 | -0.201975 142 | -0.2731836 143 | -0.2742543 144 | -0.3377408 145 | -0.5820093 146 | -0.2690741 147 | -0.252793 148 | -0.26859 149 | -0.1853021 150 | -0.1111634 151 | -0.2644029 152 | -0.3820352 153 | -0.2349582 154 | -0.1971716 155 | -0.2534548 156 | -0.3283744 157 | -0.2468762 158 | -0.2138282 159 | -0.2866224 160 | -0.2549285 161 | -0.2844165 162 | -0.3733637 163 | -0.2001896 164 | -0.4055803 165 | -0.4415738 166 | -0.2871343 167 | -0.2786664 168 | -0.3435114 169 | -0.3981061 170 | -0.3348723 171 | -0.2700631 172 | -0.2217035 173 | -0.1751114 174 | -0.2320674 175 | -0.1975113 176 | -0.3352392 177 | -0.2900315 178 | -0.2840141 179 | -0.2409753 180 | -0.5170859 181 | -0.3721226 182 | -0.3794576 183 | -0.3883095 184 | -0.1769757 185 | -0.3273429 186 | -0.2693013 187 | -0.2253938 188 | -0.1380382 189 | -0.2303957 190 | -0.1962336 191 | -0.2533173 192 | -0.2128342 193 | -0.5181995 194 | -0.2825969 195 | -0.234201 196 | -0.2581307 197 | -0.3619963 198 | -0.3109988 199 | -0.3417633 200 | -------------------------------------------------------------------------------- /tests/R/log_termination_0/standard_pearson_resid.csv: -------------------------------------------------------------------------------- 1 | 5.082471 2 | 6.230543 3 | 5.047925 4 | 4.539156 5 | 5.184214 6 | 1.693235 7 | 4.07654 8 | 5.207528 9 | 2.896449 10 | -0.1655945 11 | -0.1705168 12 | -0.2191803 13 | -0.2505432 14 | -0.1923249 15 | -0.1918123 16 | -0.2598301 17 | -0.3317981 18 | -0.2300038 19 | -0.1875973 20 | -0.1514447 21 | -0.2157609 22 | -0.2507778 23 | -0.1774489 24 | -0.1643653 25 | -0.1597898 26 | -0.1940327 27 | -0.1695139 28 | -0.2120955 29 | -0.1538938 30 | -0.2504935 31 | -0.1801908 32 | -0.2682746 33 | -0.1509154 34 | -0.2183665 35 | -0.1784665 36 | -0.2058587 37 | -0.2314261 38 | -0.1790773 39 | -0.2214531 40 | -0.2461043 41 | -0.1285053 42 | -0.217646 43 | -0.1286181 44 | -0.2160847 45 | -0.1710039 46 | -0.1913062 47 | -0.1321282 48 | -0.1448672 49 | -0.1943945 50 | -0.2136242 51 | -0.1413163 52 | -0.213375 53 | -0.1647871 54 | -0.1394921 55 | -0.287989 56 | -0.292837 57 | -0.1368422 58 | -0.1564992 59 | -0.2328899 60 | -0.1894972 61 | -0.1910561 62 | -0.2574456 63 | -0.1643848 64 | -0.2325584 65 | -0.2665509 66 | -0.3692115 67 | -0.1945259 68 | -0.3819732 69 | -0.1345177 70 | -0.1958735 71 | -0.271305 72 | -0.2222303 73 | -0.163374 74 | -0.270427 75 | -0.1413888 76 | -0.1211742 77 | -0.1387616 78 | -0.1323725 79 | -0.1873629 80 | -0.2839401 81 | -0.2118595 82 | -0.2865636 83 | -0.23638 84 | -0.1911211 85 | -0.2050357 86 | -0.4065944 87 | -0.2130302 88 | -0.2509733 89 | -0.205749 90 | -0.1785681 91 | -0.2766199 92 | -0.3540496 93 | -0.2523238 94 | -0.07507669 95 | -0.1910154 96 | -0.2149266 97 | -0.1580797 98 | -0.2470617 99 | -0.2164098 100 | -0.26021 101 | -0.1601517 102 | -0.1318678 103 | -0.3248568 104 | -0.2773366 105 | -0.2066681 106 | -0.2147461 107 | -0.204508 108 | -0.1845056 109 | -0.2030829 110 | -0.2282903 111 | -0.1120807 112 | 
-0.1838002 113 | -0.2624868 114 | -0.2645483 115 | -0.3001222 116 | -0.2330432 117 | -0.2651773 118 | -0.23773 119 | -0.2049752 120 | -0.2096165 121 | -0.2752458 122 | -0.1651371 123 | -0.2010971 124 | -0.1916446 125 | -0.145522 126 | -0.1576299 127 | -0.26887 128 | -0.06766852 129 | -0.2339455 130 | -0.2599237 131 | -0.1831051 132 | -0.137698 133 | -0.1960631 134 | -0.1976632 135 | -0.1927595 136 | -0.3020614 137 | -0.09863493 138 | -0.2709848 139 | -0.2678077 140 | -0.2195255 141 | -0.1435475 142 | -0.1949475 143 | -0.1957519 144 | -0.2422308 145 | -0.4290631 146 | -0.1919828 147 | -0.1801825 148 | -0.1916302 149 | -0.1315911 150 | -0.07872582 151 | -0.1885999 152 | -0.2750921 153 | -0.1672761 154 | -0.1400981 155 | -0.1806578 156 | -0.2353202 157 | -0.1758981 158 | -0.1520649 159 | -0.2047591 160 | -0.181719 161 | -0.2031409 162 | -0.2686391 163 | -0.1422656 164 | -0.2927094 165 | -0.3199245 166 | -0.2051351 167 | -0.1989637 168 | -0.2464991 169 | -0.2871018 170 | -0.2401232 171 | -0.1926992 172 | -0.157723 173 | -0.1242972 174 | -0.1652035 175 | -0.1403397 176 | -0.2403847 177 | -0.2072472 178 | -0.202848 179 | -0.1716347 180 | -0.3780361 181 | -0.2676911 182 | -0.2731576 183 | -0.279751 184 | -0.1256283 185 | -0.2345791 186 | -0.1921514 187 | -0.1603803 188 | -0.09784038 189 | -0.1639979 190 | -0.1394269 191 | -0.1805593 192 | -0.1513502 193 | -0.3788701 194 | -0.2017918 195 | -0.1667406 196 | -0.1840419 197 | -0.2601515 198 | -0.222561 199 | -0.2451951 200 | -------------------------------------------------------------------------------- /src/response/linear.rs: -------------------------------------------------------------------------------- 1 | //! Functions for solving linear regression 2 | 3 | use crate::{ 4 | error::{RegressionError, RegressionResult}, 5 | glm::{DispersionType, Glm}, 6 | link::Link, 7 | num::Float, 8 | response::Response, 9 | }; 10 | use num_traits::ToPrimitive; 11 | use std::marker::PhantomData; 12 | 13 | /// Linear regression with constant variance (Ordinary least squares). 14 | pub struct Linear 15 | where 16 | L: Link>, 17 | { 18 | _link: PhantomData, 19 | } 20 | 21 | /// Allow all floating point types in the linear model. 22 | impl Response> for Y 23 | where 24 | Y: Float + ToPrimitive + ToString, 25 | L: Link>, 26 | { 27 | fn into_float(self) -> RegressionResult { 28 | F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string())) 29 | } 30 | } 31 | 32 | impl Glm for Linear 33 | where 34 | L: Link>, 35 | { 36 | type Link = L; 37 | const DISPERSED: DispersionType = DispersionType::FreeDispersion; 38 | 39 | /// Logarithm of the partition function in terms of the natural parameter, 40 | /// which is mu for OLS. 41 | fn log_partition(nat_par: F) -> F { 42 | let half = F::from(0.5).unwrap(); 43 | half * nat_par * nat_par 44 | } 45 | 46 | /// variance is not a function of the mean in OLS regression. 47 | fn variance(_mean: F) -> F { 48 | F::one() 49 | } 50 | 51 | /// The saturated model likelihood is 0.5*y^2 for each observation. Note 52 | /// that if a sum of squares were used for the log-likelihood, this would be 53 | /// zero. 54 | fn log_like_sat(y: F) -> F { 55 | // Only for linear regression does this identity hold. 56 | Self::log_partition(y) 57 | } 58 | } 59 | 60 | pub mod link { 61 | //! Link functions for linear regression. 62 | use super::*; 63 | use crate::link::{Canonical, Link}; 64 | 65 | /// The identity link function, which is canonical for linear regression. 66 | pub struct Id; 67 | /// The identity is the canonical link function. 
68 | impl Canonical for Id {} 69 | impl Link for Id { 70 | #[inline] 71 | fn func(y: F) -> F { 72 | y 73 | } 74 | #[inline] 75 | fn func_inv(lin_pred: F) -> F { 76 | lin_pred 77 | } 78 | } 79 | } 80 | 81 | #[cfg(test)] 82 | mod tests { 83 | use super::Linear; 84 | use crate::{error::RegressionResult, model::ModelBuilder}; 85 | use approx::assert_abs_diff_eq; 86 | use ndarray::array; 87 | 88 | #[test] 89 | fn lin_reg() -> RegressionResult<()> { 90 | let beta = array![0.3, 1.2, -0.5]; 91 | let data_x = array![[-0.1, 0.2], [0.7, 0.5], [3.2, 0.1]]; 92 | // let data_x = array![[-0.1, 0.1], [0.7, -0.7], [3.2, -3.2]]; 93 | let data_y = array![ 94 | beta[0] + beta[1] * data_x[[0, 0]] + beta[2] * data_x[[0, 1]], 95 | beta[0] + beta[1] * data_x[[1, 0]] + beta[2] * data_x[[1, 1]], 96 | beta[0] + beta[1] * data_x[[2, 0]] + beta[2] * data_x[[2, 1]], 97 | ]; 98 | let model = ModelBuilder::::data(&data_y, &data_x).build()?; 99 | let fit = model.fit_options().max_iter(10).fit()?; 100 | dbg!(fit.n_iter); 101 | // This is failing within the default tolerance 102 | assert_abs_diff_eq!(beta, fit.result, epsilon = 64.0 * f64::EPSILON); 103 | let lr: f64 = fit.lr_test(); 104 | dbg!(&lr); 105 | dbg!(&lr.sqrt()); 106 | Ok(()) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /tests/logistic.rs: -------------------------------------------------------------------------------- 1 | //! test cases for logistic regression 2 | 3 | use anyhow::Result; 4 | 5 | use approx::assert_abs_diff_eq; 6 | use ndarray::Array; 7 | use ndarray_glm::{Logistic, ModelBuilder}; 8 | mod common; 9 | use common::{array_from_csv, y_x_off_from_csv}; 10 | 11 | #[test] 12 | // this data caused an infinite loop with step halving 13 | fn log_termination_0() -> Result<()> { 14 | let (y, x, off) = y_x_off_from_csv::("tests/data/log_termination_0.csv")?; 15 | let model = ModelBuilder::::data(&y, &x) 16 | .linear_offset(off) 17 | .build()?; 18 | let fit = model.fit()?; 19 | dbg!(&fit.result); 20 | dbg!(&fit.n_iter); 21 | 22 | let n_par = fit.result.len(); 23 | 24 | // Check consistency with R results 25 | let r_result = array_from_csv::("tests/R/log_termination_0/coefficients.csv")?; 26 | assert_abs_diff_eq!(&fit.result, &r_result, epsilon=1e-5); 27 | assert!(fit.lr_test_against(&r_result) >= 0., "make sure this is better than R's"); 28 | let r_dev_resid = array_from_csv::("tests/R/log_termination_0/dev_resid.csv")?; 29 | assert_abs_diff_eq!(fit.resid_dev(), r_dev_resid, epsilon=1e-5); 30 | let r_flat_cov = array_from_csv::("tests/R/log_termination_0/covariance.csv")?; 31 | let r_cov = Array::from_shape_vec((n_par, n_par), r_flat_cov.into_raw_vec())?; 32 | assert_abs_diff_eq!(*fit.covariance()?, r_cov, epsilon=1e-5); 33 | 34 | // We've already asserted that our fit is better according to our likelihood function, so the 35 | // epsilon doesn't have to be extremely strict. 
36 | let eps = 5e-5; 37 | let r_dev = array_from_csv::<f64>("tests/R/log_termination_0/deviance.csv")?[0]; 38 | assert_abs_diff_eq!(fit.deviance(), r_dev, epsilon=eps); 39 | let r_aic = array_from_csv::<f64>("tests/R/log_termination_0/aic.csv")?[0]; 40 | assert_abs_diff_eq!(fit.aic(), r_aic, epsilon=eps); 41 | let r_bic = array_from_csv::<f64>("tests/R/log_termination_0/bic.csv")?[0]; 42 | assert_abs_diff_eq!(fit.bic(), r_bic, epsilon=eps); 43 | let r_stand_resid_pear = array_from_csv::<f64>("tests/R/log_termination_0/standard_pearson_resid.csv")?; 44 | let r_stand_resid_dev = array_from_csv::<f64>("tests/R/log_termination_0/standard_deviance_resid.csv")?; 45 | assert_abs_diff_eq!(fit.resid_pear_std()?, r_stand_resid_pear, epsilon=0.02); 46 | assert_abs_diff_eq!(fit.resid_dev_std()?, r_stand_resid_dev, epsilon=0.02); 47 | let r_stud_resid = array_from_csv::<f64>("tests/R/log_termination_0/student_resid.csv")?; 48 | assert_abs_diff_eq!(fit.resid_student()?, r_stud_resid, epsilon=0.05); 49 | 50 | let r_null_dev = array_from_csv::<f64>("tests/R/log_termination_0/null_dev.csv")?[0]; 51 | assert_abs_diff_eq!(fit.lr_test(), r_null_dev - r_dev, epsilon=eps); 52 | 53 | 54 | // dbg!(fit.score_test()?); 55 | // dbg!(fit.lr_test()); 56 | // dbg!(fit.wald_test()); 57 | 58 | Ok(()) 59 | } 60 | 61 | #[test] 62 | // this data caused an infinite loop with step halving 63 | fn log_termination_1() -> Result<()> { 64 | let (y, x, off) = y_x_off_from_csv::<f64>("tests/data/log_termination_1.csv")?; 65 | let model = ModelBuilder::<Logistic>::data(&y, &x) 66 | .linear_offset(off) 67 | .build()?; 68 | let fit = model.fit()?; 69 | dbg!(fit.result); 70 | dbg!(fit.n_iter); 71 | Ok(()) 72 | } 73 | 74 | #[test] 75 | fn log_regularization() -> Result<()> { 76 | let (y, x, off) = y_x_off_from_csv::<f64>("tests/data/log_regularization.csv")?; 77 | // This can be terminated either by standardizing the data or by using 78 | // lambda = 2e-6 instead of 1e-6. 79 | // let x = ndarray_glm::standardize::standardize(x); 80 | let model = ModelBuilder::<Logistic>::data(&y, &x) 81 | .linear_offset(off) 82 | .build()?; 83 | let fit = model.fit_options().l2_reg(2e-6).fit()?; 84 | dbg!(fit.result); 85 | dbg!(fit.n_iter); 86 | Ok(()) 87 | } 88 | -------------------------------------------------------------------------------- /src/link.rs: -------------------------------------------------------------------------------- 1 | //! Defines traits for link functions 2 | 3 | use crate::{glm::Glm, num::Float}; 4 | use ndarray::Array1; 5 | 6 | /// Describes the functions to map to and from the linear predictors and the 7 | /// expectation of the response. It is constrained mathematically by the 8 | /// response distribution and the transformation of the linear predictor. 9 | // TODO: The link function and its inverse are independent of the response 10 | // distribution. This could be refactored to separate the function itself from 11 | // the transformation that works with the distribution. 12 | pub trait Link<M: Glm>: Transform { 13 | /// Maps the expectation value of the response variable to the linear 14 | /// predictor. In general this is determined by a composition of the inverse 15 | /// natural parameter transformation and the canonical link function. 16 | fn func<F: Float>(y: F) -> F; 17 | // fn func<F: Float>(y: Array1<F>) -> Array1<F>; 18 | /// Maps the linear predictor to the expectation value of the response. 19 | // TODO: There may not be a point in using Array versions of these functions 20 | // since clones are necessary anyway. Perhaps we could simply define the 21 | // scalar function and use mapv(). 
22 | fn func_inv<F: Float>(lin_pred: F) -> F; 23 | // fn func_inv<F: Float>(lin_pred: Array1<F>) -> Array1<F>; 24 | } 25 | 26 | pub trait Transform { 27 | /// The natural parameter(s) of the response distribution as a function 28 | /// of the linear predictor. For canonical link functions this is the 29 | /// identity. It must be monotonic, invertible, and twice-differentiable. 30 | /// For link function g and canonical link function g_0 it is equal to 31 | /// g_0 ( g^{-1}(lin_pred) ) . 32 | fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F>; 33 | /// The derivative of the transformation to the natural parameter. If it is 34 | /// zero in a region that the IRLS passes through, the algorithm may have 35 | /// difficulty converging. 36 | fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F>; 37 | /// Adjust the error and variance terms of the likelihood function based on 38 | /// the first and second derivatives of the transformation. The adjustment 39 | /// is performed simultaneously. The linear predictor must be 40 | /// un-transformed, i.e. it must be X*beta without the transformation 41 | /// applied. 42 | fn adjust_errors_variance<F: Float>( 43 | errors: Array1<F>, 44 | variance: Array1<F>, 45 | lin_pred: &Array1<F>, 46 | ) -> (Array1<F>, Array1<F>) { 47 | let eta_d = Self::d_nat_param(lin_pred); 48 | let err_adj = &eta_d * &errors; 49 | // The second-derivative term in the variance matrix can lead it to not 50 | // be positive-definite. In fact, the second term should vanish when 51 | // taking the expectation of Y to give the Fisher information. 52 | // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors; 53 | let var_adj = &eta_d * &variance * eta_d; 54 | (err_adj, var_adj) 55 | } 56 | } 57 | 58 | /// The canonical transformation by definition equates the linear predictor with 59 | /// the natural parameter of the response distribution. Implementing this trait 60 | /// for a link function automatically defines the trivial transformation 61 | /// functions. 62 | pub trait Canonical {} 63 | impl<T> Transform for T 64 | where 65 | T: Canonical, 66 | { 67 | /// By definition this function is the identity function for canonical links. 68 | #[inline] 69 | fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> { 70 | lin_pred 71 | } 72 | #[inline] 73 | fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> { 74 | Array1::<F>::ones(lin_pred.len()) 75 | } 76 | /// The canonical link function requires no transformation of the error and variance terms. 77 | #[inline] 78 | fn adjust_errors_variance<F: Float>( 79 | errors: Array1<F>, 80 | variance: Array1<F>, 81 | _lin_pred: &Array1<F>, 82 | ) -> (Array1<F>, Array1<F>) { 83 | (errors, variance) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /tests/custom_link.rs: -------------------------------------------------------------------------------- 1 | //! Test implementation of custom link functions 2 | 3 | use anyhow::Result; 4 | use approx::assert_abs_diff_eq; 5 | use ndarray::{array, Array1, Axis}; 6 | use ndarray_glm::{ 7 | link::{Link, Transform}, 8 | num::Float, 9 | Linear, ModelBuilder, 10 | }; 11 | 12 | #[test] 13 | fn linear_with_lin_transform() -> Result<()> { 14 | // A linear transformation for simplicity. 15 | struct LinTran {} 16 | impl Link<Linear<LinTran>> for LinTran { 17 | fn func<F: Float>(y: F) -> F { 18 | F::from(2.5).unwrap() * y - F::from(3.4).unwrap() 19 | } 20 | fn func_inv<F: Float>(lin_pred: F) -> F { 21 | (lin_pred + F::from(3.4).unwrap()) * F::from(0.4).unwrap() 22 | } 23 | } 24 | assert_abs_diff_eq!( 25 | LinTran::func(LinTran::func_inv(0.45)), 26 | 0.45, 27 | epsilon = 4. 
* f64::EPSILON 28 | ); 29 | impl Transform for LinTran { 30 | fn nat_param(lin_pred: Array1) -> Array1 { 31 | lin_pred.mapv_into(Self::func_inv) 32 | } 33 | fn d_nat_param(lin_pred: &Array1) -> Array1 { 34 | Array1::::from_elem(lin_pred.len(), F::from(0.4).unwrap()) 35 | } 36 | } 37 | let beta = array![-0.2, 0.7]; 38 | let data_x = array![-1.5, -1.2, -0.8, -0.8, -0.5, -0.2, -0.2, 0.3, 0.3, 0.7, 0.9, 1.2, 1.2]; 39 | let mut data_y = data_x.mapv(|x| LinTran::func_inv(beta[0] + beta[1] * x)); 40 | // some x points are redundant, and Gaussian errors are symmetric, so some 41 | // pairs of points can be moved off of the exact fit without affecting the 42 | // result. 43 | data_y[2] += 0.3; 44 | data_y[3] -= 0.3; 45 | data_y[5] -= 0.2; 46 | data_y[6] += 0.2; 47 | data_y[7] += 0.4; 48 | data_y[8] -= 0.4; 49 | data_y[11] -= 0.3; 50 | data_y[12] += 0.3; 51 | // Change X data into a 2D array 52 | let data_x = data_x.insert_axis(Axis(1)); 53 | let model = ModelBuilder::>::data(&data_y, &data_x).build()?; 54 | let fit = model.fit()?; 55 | dbg!(fit.n_iter); 56 | dbg!(&fit.result); 57 | dbg!(&beta); 58 | assert_abs_diff_eq!(fit.result, beta, epsilon = 16.0 * f64::EPSILON); 59 | Ok(()) 60 | } 61 | 62 | #[test] 63 | fn linear_with_cubic() -> Result<()> { 64 | // An adjusted cube root link function to test on Linear regression. This 65 | // fits to y ~ (a + b*x)^3. If the starting guess is zero this fails to 66 | // converge because the derivative of the link function is zero at the 67 | // origin. 68 | struct Cbrt {} 69 | impl Link> for Cbrt { 70 | fn func(y: F) -> F { 71 | y.cbrt() 72 | } 73 | fn func_inv(lin_pred: F) -> F { 74 | num_traits::Float::powi(lin_pred, 3) 75 | } 76 | } 77 | assert_abs_diff_eq!( 78 | Cbrt::func(Cbrt::func_inv(0.45)), 79 | 0.45, 80 | epsilon = 4. * f64::EPSILON 81 | ); 82 | impl Transform for Cbrt { 83 | fn nat_param(lin_pred: Array1) -> Array1 { 84 | lin_pred.mapv_into(|w| num_traits::Float::powi(w, 3)) 85 | } 86 | fn d_nat_param(lin_pred: &Array1) -> Array1 { 87 | let three = F::from(3.).unwrap(); 88 | lin_pred.mapv(|w| three * num_traits::Float::powi(w, 2)) 89 | } 90 | } 91 | 92 | type TestLink = Cbrt; 93 | let beta = array![-0.2, 0.7]; 94 | let data_x = array![-1.5, -1.2, -0.8, -0.8, -0.5, -0.2, -0.2, 0.3, 0.3, 0.7, 0.9, 1.2, 1.2]; 95 | let mut data_y = data_x.mapv(|x| TestLink::func_inv(beta[0] + beta[1] * x)); 96 | // some x points are redundant, and Gaussian errors are symmetric, so some 97 | // pairs of points can be moved off of the exact fit without affecting the 98 | // result. 
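// (Why this works: two observations that share the same x also share the same
// predicted mean mu, and (y + d - mu)^2 + (y - d - mu)^2 = 2*(y - mu)^2 + 2*d^2,
// so the extra 2*d^2 term is constant in beta and the least-squares optimum is
// unchanged.)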
99 | data_y[2] += 0.3; 100 | data_y[3] -= 0.3; 101 | data_y[5] -= 0.2; 102 | data_y[6] += 0.2; 103 | data_y[7] += 0.4; 104 | data_y[8] -= 0.4; 105 | data_y[11] -= 0.3; 106 | data_y[12] += 0.3; 107 | // Change X data into a 2D array 108 | let data_x = data_x.insert_axis(Axis(1)); 109 | let model = ModelBuilder::>::data(&data_y, &data_x).build()?; 110 | eprintln!("Built model"); 111 | let fit = model.fit()?; 112 | dbg!(fit.n_iter); 113 | dbg!(&fit.result); 114 | dbg!(&beta); 115 | assert_abs_diff_eq!(fit.result, beta, epsilon = f32::EPSILON as f64); 116 | Ok(()) 117 | } 118 | -------------------------------------------------------------------------------- /tests/data/iris.csv: -------------------------------------------------------------------------------- 1 | sepal_length,sepal_width,petal_length,petal_width,class 2 | 5.1,3.5,1.4,0.2,setosa 3 | 4.9,3.0,1.4,0.2,setosa 4 | 4.7,3.2,1.3,0.2,setosa 5 | 4.6,3.1,1.5,0.2,setosa 6 | 5.0,3.6,1.4,0.2,setosa 7 | 5.4,3.9,1.7,0.4,setosa 8 | 4.6,3.4,1.4,0.3,setosa 9 | 5.0,3.4,1.5,0.2,setosa 10 | 4.4,2.9,1.4,0.2,setosa 11 | 4.9,3.1,1.5,0.1,setosa 12 | 5.4,3.7,1.5,0.2,setosa 13 | 4.8,3.4,1.6,0.2,setosa 14 | 4.8,3.0,1.4,0.1,setosa 15 | 4.3,3.0,1.1,0.1,setosa 16 | 5.8,4.0,1.2,0.2,setosa 17 | 5.7,4.4,1.5,0.4,setosa 18 | 5.4,3.9,1.3,0.4,setosa 19 | 5.1,3.5,1.4,0.3,setosa 20 | 5.7,3.8,1.7,0.3,setosa 21 | 5.1,3.8,1.5,0.3,setosa 22 | 5.4,3.4,1.7,0.2,setosa 23 | 5.1,3.7,1.5,0.4,setosa 24 | 4.6,3.6,1.0,0.2,setosa 25 | 5.1,3.3,1.7,0.5,setosa 26 | 4.8,3.4,1.9,0.2,setosa 27 | 5.0,3.0,1.6,0.2,setosa 28 | 5.0,3.4,1.6,0.4,setosa 29 | 5.2,3.5,1.5,0.2,setosa 30 | 5.2,3.4,1.4,0.2,setosa 31 | 4.7,3.2,1.6,0.2,setosa 32 | 4.8,3.1,1.6,0.2,setosa 33 | 5.4,3.4,1.5,0.4,setosa 34 | 5.2,4.1,1.5,0.1,setosa 35 | 5.5,4.2,1.4,0.2,setosa 36 | 4.9,3.1,1.5,0.2,setosa 37 | 5.0,3.2,1.2,0.2,setosa 38 | 5.5,3.5,1.3,0.2,setosa 39 | 4.9,3.6,1.4,0.1,setosa 40 | 4.4,3.0,1.3,0.2,setosa 41 | 5.1,3.4,1.5,0.2,setosa 42 | 5.0,3.5,1.3,0.3,setosa 43 | 4.5,2.3,1.3,0.3,setosa 44 | 4.4,3.2,1.3,0.2,setosa 45 | 5.0,3.5,1.6,0.6,setosa 46 | 5.1,3.8,1.9,0.4,setosa 47 | 4.8,3.0,1.4,0.3,setosa 48 | 5.1,3.8,1.6,0.2,setosa 49 | 4.6,3.2,1.4,0.2,setosa 50 | 5.3,3.7,1.5,0.2,setosa 51 | 5.0,3.3,1.4,0.2,setosa 52 | 7.0,3.2,4.7,1.4,versicolor 53 | 6.4,3.2,4.5,1.5,versicolor 54 | 6.9,3.1,4.9,1.5,versicolor 55 | 5.5,2.3,4.0,1.3,versicolor 56 | 6.5,2.8,4.6,1.5,versicolor 57 | 5.7,2.8,4.5,1.3,versicolor 58 | 6.3,3.3,4.7,1.6,versicolor 59 | 4.9,2.4,3.3,1.0,versicolor 60 | 6.6,2.9,4.6,1.3,versicolor 61 | 5.2,2.7,3.9,1.4,versicolor 62 | 5.0,2.0,3.5,1.0,versicolor 63 | 5.9,3.0,4.2,1.5,versicolor 64 | 6.0,2.2,4.0,1.0,versicolor 65 | 6.1,2.9,4.7,1.4,versicolor 66 | 5.6,2.9,3.6,1.3,versicolor 67 | 6.7,3.1,4.4,1.4,versicolor 68 | 5.6,3.0,4.5,1.5,versicolor 69 | 5.8,2.7,4.1,1.0,versicolor 70 | 6.2,2.2,4.5,1.5,versicolor 71 | 5.6,2.5,3.9,1.1,versicolor 72 | 5.9,3.2,4.8,1.8,versicolor 73 | 6.1,2.8,4.0,1.3,versicolor 74 | 6.3,2.5,4.9,1.5,versicolor 75 | 6.1,2.8,4.7,1.2,versicolor 76 | 6.4,2.9,4.3,1.3,versicolor 77 | 6.6,3.0,4.4,1.4,versicolor 78 | 6.8,2.8,4.8,1.4,versicolor 79 | 6.7,3.0,5.0,1.7,versicolor 80 | 6.0,2.9,4.5,1.5,versicolor 81 | 5.7,2.6,3.5,1.0,versicolor 82 | 5.5,2.4,3.8,1.1,versicolor 83 | 5.5,2.4,3.7,1.0,versicolor 84 | 5.8,2.7,3.9,1.2,versicolor 85 | 6.0,2.7,5.1,1.6,versicolor 86 | 5.4,3.0,4.5,1.5,versicolor 87 | 6.0,3.4,4.5,1.6,versicolor 88 | 6.7,3.1,4.7,1.5,versicolor 89 | 6.3,2.3,4.4,1.3,versicolor 90 | 5.6,3.0,4.1,1.3,versicolor 91 | 5.5,2.5,4.0,1.3,versicolor 92 | 5.5,2.6,4.4,1.2,versicolor 93 | 
6.1,3.0,4.6,1.4,versicolor 94 | 5.8,2.6,4.0,1.2,versicolor 95 | 5.0,2.3,3.3,1.0,versicolor 96 | 5.6,2.7,4.2,1.3,versicolor 97 | 5.7,3.0,4.2,1.2,versicolor 98 | 5.7,2.9,4.2,1.3,versicolor 99 | 6.2,2.9,4.3,1.3,versicolor 100 | 5.1,2.5,3.0,1.1,versicolor 101 | 5.7,2.8,4.1,1.3,versicolor 102 | 6.3,3.3,6.0,2.5,virginica 103 | 5.8,2.7,5.1,1.9,virginica 104 | 7.1,3.0,5.9,2.1,virginica 105 | 6.3,2.9,5.6,1.8,virginica 106 | 6.5,3.0,5.8,2.2,virginica 107 | 7.6,3.0,6.6,2.1,virginica 108 | 4.9,2.5,4.5,1.7,virginica 109 | 7.3,2.9,6.3,1.8,virginica 110 | 6.7,2.5,5.8,1.8,virginica 111 | 7.2,3.6,6.1,2.5,virginica 112 | 6.5,3.2,5.1,2.0,virginica 113 | 6.4,2.7,5.3,1.9,virginica 114 | 6.8,3.0,5.5,2.1,virginica 115 | 5.7,2.5,5.0,2.0,virginica 116 | 5.8,2.8,5.1,2.4,virginica 117 | 6.4,3.2,5.3,2.3,virginica 118 | 6.5,3.0,5.5,1.8,virginica 119 | 7.7,3.8,6.7,2.2,virginica 120 | 7.7,2.6,6.9,2.3,virginica 121 | 6.0,2.2,5.0,1.5,virginica 122 | 6.9,3.2,5.7,2.3,virginica 123 | 5.6,2.8,4.9,2.0,virginica 124 | 7.7,2.8,6.7,2.0,virginica 125 | 6.3,2.7,4.9,1.8,virginica 126 | 6.7,3.3,5.7,2.1,virginica 127 | 7.2,3.2,6.0,1.8,virginica 128 | 6.2,2.8,4.8,1.8,virginica 129 | 6.1,3.0,4.9,1.8,virginica 130 | 6.4,2.8,5.6,2.1,virginica 131 | 7.2,3.0,5.8,1.6,virginica 132 | 7.4,2.8,6.1,1.9,virginica 133 | 7.9,3.8,6.4,2.0,virginica 134 | 6.4,2.8,5.6,2.2,virginica 135 | 6.3,2.8,5.1,1.5,virginica 136 | 6.1,2.6,5.6,1.4,virginica 137 | 7.7,3.0,6.1,2.3,virginica 138 | 6.3,3.4,5.6,2.4,virginica 139 | 6.4,3.1,5.5,1.8,virginica 140 | 6.0,3.0,4.8,1.8,virginica 141 | 6.9,3.1,5.4,2.1,virginica 142 | 6.7,3.1,5.6,2.4,virginica 143 | 6.9,3.1,5.1,2.3,virginica 144 | 5.8,2.7,5.1,1.9,virginica 145 | 6.8,3.2,5.9,2.3,virginica 146 | 6.7,3.3,5.7,2.5,virginica 147 | 6.7,3.0,5.2,2.3,virginica 148 | 6.3,2.5,5.0,1.9,virginica 149 | 6.5,3.0,5.2,2.0,virginica 150 | 6.2,3.4,5.4,2.3,virginica 151 | 5.9,3.0,5.1,1.8,virginica 152 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ndarray-glm 2 | 3 | Rust library for solving linear, logistic, and generalized linear models through 4 | iteratively reweighted least squares, using the 5 | [`ndarray-linalg`](https://docs.rs/crate/ndarray-linalg/) module. 6 | 7 | [![Crate](https://img.shields.io/crates/v/ndarray-glm.svg)](https://crates.io/crates/ndarray-glm) 8 | [![Documentation](https://docs.rs/ndarray-glm/badge.svg)](https://docs.rs/ndarray-glm) 9 | [![Build Status](https://travis-ci.org/felix-clark/ndarray-glm.png?branch=master)](https://travis-ci.org/felix-clark/ndarray-glm) 10 | ![Downloads](https://img.shields.io/crates/d/ndarray-glm) 11 | 12 | ## Status 13 | 14 | This package is in alpha and the interface could undergo changes. Even the 15 | return value of certain functions may change from one release to the next. 16 | Correctness is not guaranteed. 17 | 18 | The regression algorithm uses iteratively re-weighted least squares (IRLS) with 19 | a step-halving procedure applied when the next iteration of guesses does not 20 | increase the likelihood. 21 | 22 | Suggestions (via issues) and pull requests are welcome. 23 | 24 | ## Prerequisites 25 | 26 | The recommended approach is to use a system BLAS implementation. For instance, to install 27 | OpenBLAS on Debian/Ubuntu: 28 | ``` 29 | sudo apt update && sudo apt install -y libopenblas-dev 30 | ``` 31 | Then use this crate with the `openblas-system` feature. 
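As an aside on the algorithm itself, the step-halving IRLS procedure mentioned in the Status
section above can be illustrated with a short, self-contained sketch. This is not the crate's
implementation (that lives in `src/irls.rs` and operates on full design matrices); it only shows
the idea of halving a Newton/IRLS step whenever the proposed update fails to increase the
likelihood, here for a single-coefficient logistic model.

``` rust
/// Log-likelihood of a single-coefficient logistic model:
/// sum of y*eta - ln(1 + exp(eta)) with eta = beta * x.
fn log_likelihood(beta: f64, xs: &[f64], ys: &[f64]) -> f64 {
    xs.iter()
        .zip(ys)
        .map(|(&x, &y)| {
            let eta = beta * x;
            y * eta - eta.exp().ln_1p()
        })
        .sum()
}

/// Fit y ~ sigmoid(beta * x) by IRLS (Newton) steps with step halving.
fn irls_logistic(xs: &[f64], ys: &[f64]) -> f64 {
    let mut beta = 0.0_f64;
    for _ in 0..50 {
        // Score and Fisher information for the single coefficient.
        let (mut score, mut info) = (0.0, 0.0);
        for (&x, &y) in xs.iter().zip(ys) {
            let mu = 1.0 / (1.0 + (-beta * x).exp());
            score += x * (y - mu);
            info += x * x * mu * (1.0 - mu);
        }
        let mut proposal = beta + score / info;
        // Step halving: if the proposed update does not increase the
        // likelihood, move halfway back toward the previous estimate.
        let like_old = log_likelihood(beta, xs, ys);
        let mut halvings = 0;
        while log_likelihood(proposal, xs, ys) < like_old && halvings < 10 {
            proposal = 0.5 * (proposal + beta);
            halvings += 1;
        }
        if (proposal - beta).abs() < 1e-10 {
            return proposal;
        }
        beta = proposal;
    }
    beta
}

fn main() {
    // Toy, non-separable data so the maximum likelihood estimate is finite.
    let xs = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0];
    let ys = [0.0, 0.0, 1.0, 0.0, 1.0, 1.0];
    println!("estimated coefficient: {}", irls_logistic(&xs, &ys));
}
```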
32 | 33 | To use an alternative backend or to build a static BLAS implementation, refer to the 34 | `ndarray-linalg` 35 | [documentation](https://github.com/rust-ndarray/ndarray-linalg#backend-features). Use 36 | this crate with the appropriate feature flag and it will be forwarded to 37 | `ndarray-linalg`. 38 | 39 | ## Example 40 | 41 | To use in your crate, add the following to the `Cargo.toml`: 42 | 43 | ``` 44 | ndarray = { version = "0.15", features = ["blas"]} 45 | ndarray-glm = { version = "0.0.12", features = ["openblas-system"] } 46 | ``` 47 | 48 | An example for linear regression is shown below. 49 | 50 | ``` rust 51 | use ndarray_glm::{array, Linear, ModelBuilder, utility::standardize}; 52 | 53 | // define some test data 54 | let data_y = array![0.3, 1.3, 0.7]; 55 | let data_x = array![[0.1, 0.2], [-0.4, 0.1], [0.2, 0.4]]; 56 | // The design matrix can optionally be standardized, where the mean of each independent 57 | // variable is subtracted and each is then divided by the standard deviation of that variable. 58 | let data_x = standardize(data_x); 59 | let model = ModelBuilder::::data(&data_y, &data_x).build()?; 60 | // L2 (ridge) regularization can be applied with l2_reg(). 61 | let fit = model.fit_options().l2_reg(1e-5).fit()?; 62 | // Currently the result is a simple array of the MLE estimators, including the intercept term. 63 | println!("Fit result: {}", fit.result); 64 | ``` 65 | 66 | Custom non-canonical link functions can be defined by the user, although the 67 | interface is currently not particularly ergonomic. See `tests/custom_link.rs` 68 | for examples. 69 | 70 | ## Features 71 | 72 | - [X] Linear regression 73 | - [X] Logistic regression 74 | - [X] Generalized linear model IRLS 75 | - [X] Linear offsets 76 | - [X] Generic over floating point type 77 | - [X] Non-float domain types 78 | - [X] Regularization 79 | - [X] L2 (ridge) 80 | - [X] L1 (lasso) 81 | - [X] Elastic Net 82 | - [ ] Other exponential family distributions 83 | - [X] Poisson 84 | - [X] Binomial 85 | - [ ] Exponential 86 | - [ ] Gamma 87 | - [ ] Inverse Gaussian 88 | - [X] Data standardization/normalization 89 | - [X] External utility function 90 | - [ ] Automatic internal transformation 91 | - [ ] Weighted (and correlated?) regressions 92 | - [X] Non-canonical link functions 93 | - [X] Goodness-of-fit tests 94 | 95 | ## Troubleshooting 96 | 97 | Lasso/L1 regularization can converge slowly in some cases, particularly when 98 | the data is poorly behaved, separable, etc. 99 | 100 | The following tips are worth trying for convergence issues in general, but they 101 | are more likely to be necessary in an L1 regularization problem. 102 | 103 | * Standardize the feature data 104 | * Use f32 instead of f64 105 | * Increase the tolerance and/or the maximum number of iterations 106 | * Include a small L2 regularization as well 107 | 108 | If you encounter problems that persist even after these techniques are applied, 109 | please file an issue so the algorithm can be improved. 110 | 111 | ## References 112 | 113 | * [notes on generalized linear models](https://felix-clark.github.io/glm-math) 114 | * Generalized Linear Models and Extensions by Hardin & Hilbe 115 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! A rust library for performing GLM regression with data represented in 2 | //!
[`ndarray`](https://docs.rs/ndarray/)s. 3 | //! The [`ndarray-linalg`](https://docs.rs/ndarray-linalg/) crate is used to allow 4 | //! optimization of linear algebra operations with BLAS. 5 | //! 6 | //! This crate is early alpha and may change rapidly. No guarantees can be made about 7 | //! the accuracy of the fits. 8 | //! 9 | //! # Feature summary: 10 | //! 11 | //! * Linear, logistic, Poisson, and binomial regression (more to come) 12 | //! * Generic over floating-point type 13 | //! * L1 (lasso), L2 (ridge), and elastic net regularization 14 | //! * Statistical tests of fit result 15 | //! * Alternative and custom link functions 16 | //! 17 | //! 18 | //! # Setting up BLAS backend 19 | //! 20 | //! See the [backend features of 21 | //! `ndarray-linalg`](https://github.com/rust-ndarray/ndarray-linalg#backend-features) 22 | //! for a description of the available BLAS configurations. You do not need to 23 | //! include `ndarray-linalg` in your crate; simply provide the feature you need to 24 | //! `ndarray-glm` and it will be forwarded to `ndarray-linalg`. 25 | //! 26 | //! Examples using OpenBLAS are shown here. In principle you should also be able to use 27 | //! Netlib or Intel MKL, although these backends are untested. 28 | //! 29 | //! ## System OpenBLAS (recommended) 30 | //! 31 | //! Ensure that the development OpenBLAS library is installed on your system. In 32 | //! Debian/Ubuntu, for instance, this means installing `libopenblas-dev`. Then, put the 33 | //! following into your crate's `Cargo.toml`: 34 | //! ```text 35 | //! ndarray = { version = "0.15", features = ["blas"]} 36 | //! ndarray-glm = { version = "0.0.12", features = ["openblas-system"] } 37 | //! ``` 38 | //! 39 | //! ## Compile OpenBLAS from source 40 | //! 41 | //! This option does not require OpenBLAS to be installed on your system, but the 42 | //! initial compile time will be very long. Use the following lines in your crate's 43 | //! `Cargo.toml`: 44 | //! ```text 45 | //! ndarray = { version = "0.15", features = ["blas"]} 46 | //! ndarray-glm = { version = "0.0.12", features = ["openblas-static"] } 47 | //! ``` 48 | //! 49 | //! # Examples: 50 | //! 51 | //! Basic linear regression: 52 | //! ``` 53 | //! use ndarray_glm::{array, Linear, ModelBuilder}; 54 | //! 55 | //! let data_y = array![0.3, 1.3, 0.7]; 56 | //! let data_x = array![[0.1, 0.2], [-0.4, 0.1], [0.2, 0.4]]; 57 | //! let model = ModelBuilder::::data(&data_y, &data_x).build().unwrap(); 58 | //! let fit = model.fit().unwrap(); 59 | //! // The result is a flat array with the first term as the intercept. 60 | //! println!("Fit result: {}", fit.result); 61 | //! ``` 62 | //! 63 | //! Data standardization and L2 regularization: 64 | //! ``` 65 | //! use ndarray_glm::{array, Linear, ModelBuilder, utility::standardize}; 66 | //! 67 | //! let data_y = array![0.3, 1.3, 0.7]; 68 | //! let data_x = array![[0.1, 0.2], [-0.4, 0.1], [0.2, 0.4]]; 69 | //! // The design matrix can optionally be standardized, where the mean of each independent 70 | //! // variable is subtracted and each is then divided by the standard deviation of that variable. 71 | //! let data_x = standardize(data_x); 72 | //! let model = ModelBuilder::::data(&data_y, &data_x).build().unwrap(); 73 | //! // L2 (ridge) regularization can be applied with l2_reg(). 74 | //! let fit = model.fit_options().l2_reg(1e-5).fit().unwrap(); 75 | //! println!("Fit result: {}", fit.result); 76 | //! ``` 77 | //! 78 | //! Logistic regression with a non-canonical link function: 79 | //!
``` 80 | //! use ndarray_glm::{array, Logistic, logistic_link::Cloglog, ModelBuilder}; 81 | //! 82 | //! let data_y = array![true, false, false, true, true]; 83 | //! let data_x = array![[0.5, 0.2], [0.1, 0.3], [0.2, 0.6], [0.6, 0.3], [0.4, 0.4]]; 84 | //! let model = ModelBuilder::>::data(&data_y, &data_x).build().unwrap(); 85 | //! let fit = model.fit_options().l2_reg(1e-5).fit().unwrap(); 86 | //! println!("Fit result: {}", fit.result); 87 | //! ``` 88 | 89 | #![doc(html_root_url = "https://docs.rs/crate/ndarray-glm")] 90 | pub mod error; 91 | mod fit; 92 | mod glm; 93 | mod irls; 94 | pub mod link; 95 | mod math; 96 | pub mod model; 97 | pub mod num; 98 | mod regularization; 99 | mod response; 100 | pub mod utility; 101 | 102 | // Import some common names into the top-level namespace 103 | pub use { 104 | fit::Fit, 105 | model::ModelBuilder, 106 | response::logistic::link as logistic_link, 107 | response::{binomial::Binomial, linear::Linear, logistic::Logistic, poisson::Poisson}, 108 | }; 109 | 110 | // re-export common structs from ndarray 111 | pub use ndarray::{array, Array1, Array2, ArrayView1, ArrayView2}; 112 | -------------------------------------------------------------------------------- /tests/common/mod.rs: -------------------------------------------------------------------------------- 1 | //! Utility functions for testing 2 | use anyhow::{anyhow, Result}; 3 | use ndarray::{Array1, Array2}; 4 | use num_traits::Float; 5 | use std::{ 6 | error::Error, 7 | fs::File, 8 | io::{BufRead, BufReader}, 9 | str::FromStr, 10 | }; 11 | 12 | /// Read y, x pairs from a CSV. Right now it's assumed that there is only one covariate. 13 | // This function isn't used yet, but it will be. 14 | #[cfg(test)] 15 | #[allow(dead_code)] 16 | pub fn y_x_from_csv(file: &str) -> Result<(Array1, Array2)> 17 | where 18 | Y: FromStr, 19 | X: Float + FromStr, 20 | ::Err: 'static + Error + Send + Sync, 21 | ::Err: 'static + Error + Send + Sync, 22 | { 23 | let file = File::open(file)?; 24 | let reader = BufReader::new(file); 25 | let mut y_vec: Vec = Vec::new(); 26 | let mut x_vec: Vec = Vec::new(); 27 | for line_result in reader.lines() { 28 | let line = line_result?; 29 | let split_line: Vec<&str> = line.split(',').collect(); 30 | if split_line.len() != 2 { 31 | return Err(anyhow!("Expected two entries in CSV")); 32 | } 33 | let y_parsed: Y = split_line[0].parse()?; 34 | let x_parsed: X = split_line[1].parse()?; 35 | y_vec.push(y_parsed); 36 | x_vec.push(x_parsed); 37 | } 38 | let y = Array1::::from(y_vec); 39 | let x = Array2::::from_shape_vec((y.len(), 1), x_vec)?; 40 | Ok((y, x)) 41 | } 42 | 43 | /// Read y, x, and linear offsets from a CSV. Right now it's assumed that there is only one covariate. 
44 | #[cfg(test)] 45 | #[allow(dead_code)] 46 | pub fn y_x_off_from_csv(file: &str) -> Result<(Array1, Array2, Array1)> 47 | where 48 | Y: FromStr, 49 | X: Float + FromStr, 50 | ::Err: 'static + Error + Send + Sync, 51 | ::Err: 'static + Error + Send + Sync, 52 | { 53 | let file = File::open(file)?; 54 | let reader = BufReader::new(file); 55 | let mut y_vec: Vec = Vec::new(); 56 | let mut x_vec: Vec = Vec::new(); 57 | let mut off_vec: Vec = Vec::new(); 58 | for line_result in reader.lines() { 59 | let line = line_result?; 60 | let split_line: Vec<&str> = line.split(',').collect(); 61 | if split_line.len() != 3 { 62 | return Err(anyhow!("Expected three entries in CSV")); 63 | } 64 | let y_parsed: Y = split_line[0].parse()?; 65 | let x_parsed: X = split_line[1].parse()?; 66 | let off_parsed: X = split_line[2].parse()?; 67 | y_vec.push(y_parsed); 68 | x_vec.push(x_parsed); 69 | off_vec.push(off_parsed); 70 | } 71 | let y = Array1::::from(y_vec); 72 | let x = Array2::::from_shape_vec((y.len(), 1), x_vec)?; 73 | let off = Array1::::from(off_vec); 74 | Ok((y, x, off)) 75 | } 76 | 77 | /// Read a flat array from a text file 78 | #[cfg(test)] 79 | // Silence an false warning about non-use 80 | #[allow(dead_code)] 81 | pub fn array_from_csv(file: &str) -> Result> 82 | where 83 | X: Float + FromStr, 84 | ::Err: 'static + Error + Send + Sync, 85 | { 86 | let file = File::open(file)?; 87 | let reader = BufReader::new(file); 88 | let mut x_vec: Vec = Vec::new(); 89 | for line_result in reader.lines() { 90 | let line = line_result?; 91 | let x_parsed: X = line.parse()?; 92 | x_vec.push(x_parsed); 93 | } 94 | let x: Array1 = x_vec.into(); 95 | Ok(x) 96 | } 97 | 98 | /// Load data from the popular iris test dataset. 99 | /// The class will be encoded as an integer in the y data. 100 | #[allow(dead_code)] 101 | pub fn y_x_from_iris() -> Result<(Array1, Array2)> { 102 | let file = File::open("tests/data/iris.csv")?; 103 | let reader = BufReader::new(file); 104 | let mut y_vec: Vec = Vec::new(); 105 | let mut x_vec: Vec = Vec::new(); 106 | for line_result in reader.lines() { 107 | let line = line_result?; 108 | if line == "sepal_length,sepal_width,petal_length,petal_width,class" { 109 | continue; 110 | } 111 | let split_line: Vec<&str> = line.split(',').collect(); 112 | if split_line.len() != 5 { 113 | return Err(anyhow!("Expected five entries in CSV")); 114 | } 115 | for i in 0..4 { 116 | let x_val: f32 = split_line[i].parse()?; 117 | x_vec.push(x_val); 118 | } 119 | let y_parsed = match split_line[4] { 120 | "setosa" => 0, 121 | "versicolor" => 1, 122 | "virginica" => 2, 123 | _ => unreachable!("There should only be 3 classes of irises"), 124 | }; 125 | y_vec.push(y_parsed); 126 | } 127 | let y = Array1::::from(y_vec); 128 | let x = Array2::::from_shape_vec((y.len(), 4), x_vec)?; 129 | Ok((y, x)) 130 | } 131 | -------------------------------------------------------------------------------- /tests/regularization.rs: -------------------------------------------------------------------------------- 1 | //! testing regularization 2 | mod common; 3 | 4 | use anyhow::Result; 5 | use approx::assert_abs_diff_eq; 6 | use common::{array_from_csv, y_x_from_iris}; 7 | use ndarray::{array, Array1, Array2}; 8 | use ndarray_glm::{utility::standardize, Linear, Logistic, ModelBuilder}; 9 | 10 | #[test] 11 | /// Test that the intercept is not affected by regularization when the dependent 12 | /// data is centered. This is only strictly true for linear regression. 
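// (Added note: with the covariate columns centered, the intercept's normal
// equation decouples as beta_0 = mean(y) - mean(x)^T beta = mean(y), which does
// not involve the ridge penalty; this assumes the intercept itself is not
// penalized.)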
13 | fn same_lin_intercept() -> Result<()> { 14 | let y_data: Array1 = array![0.3, 0.5, 0.8, 0.2]; 15 | let x_data: Array2 = array![[1.5, 0.6], [2.1, 0.8], [1.2, 0.7], [1.6, 0.3]]; 16 | // standardize the data 17 | let x_data = standardize(x_data); 18 | 19 | let lin_model = ModelBuilder::::data(&y_data, &x_data).build()?; 20 | let lin_fit = lin_model.fit()?; 21 | let lin_model_reg = ModelBuilder::::data(&y_data, &x_data).build()?; 22 | // use a pretty large regularization term to make sure the effect is pronounced 23 | let lin_fit_reg = lin_model_reg.fit_options().l2_reg(1.0).fit()?; 24 | dbg!(&lin_fit.result); 25 | dbg!(&lin_fit_reg.result); 26 | // Ensure that the intercept terms are equal 27 | assert_abs_diff_eq!( 28 | lin_fit.result[0], 29 | lin_fit_reg.result[0], 30 | epsilon = 2.0 * f64::EPSILON 31 | ); 32 | 33 | Ok(()) 34 | } 35 | 36 | #[test] 37 | /// Test the lasso regression on underconstrained data 38 | fn lasso_underconstrained() -> Result<()> { 39 | let y_data: Array1 = array![true, false, true]; 40 | let x_data: Array2 = array![[0.1, 1.5, 8.0], [-0.1, 1.0, -12.0], [0.2, 0.5, 9.5]]; 41 | // Either standardization or 32-bit floats is needed to converge. 42 | let x_data = standardize(x_data); 43 | let model = ModelBuilder::::data(&y_data, &x_data).build()?; 44 | // The smoothing parameter needs to be relatively large in order to test 45 | let fit = model.fit_options().max_iter(64).l1_reg(1.0).fit()?; 46 | dbg!(fit.result); 47 | let like = fit.model_like; 48 | // make sure the likelihood isn't NaN 49 | assert!(like.is_normal()); 50 | Ok(()) 51 | } 52 | 53 | #[test] 54 | fn elnet_seperable() -> Result<()> { 55 | let (y_labels, x_data) = y_x_from_iris()?; 56 | let x_data = standardize(x_data); 57 | // setosa 58 | let y_data: Array1 = y_labels.mapv(|i| i == 0); 59 | let target: Array1 = array_from_csv("tests/R/log_regularization/iris_setosa_l1_l2_1e-2.csv")?; 60 | dbg!(&target); 61 | let model = ModelBuilder::::data(&y_data, &x_data).build()?; 62 | let fit = model.fit_options().l1_reg(1e-2).l2_reg(1e-2).fit()?; 63 | dbg!(&fit.result); 64 | // If this is negative then our alg hasn't converged to a good minimum 65 | assert!(fit.lr_test_against(&target) >= 0., "If it's not an exact match to the target, it should be a better result under our likelihood."); 66 | assert_abs_diff_eq!(&target, &fit.result, epsilon = 0.01); 67 | Ok(()) 68 | } 69 | 70 | #[test] 71 | fn ridge_seperable() -> Result<()> { 72 | let (y_labels, x_data) = y_x_from_iris()?; 73 | // let x_data = standardize(x_data); 74 | let y_data: Array1 = y_labels.mapv(|i| i == 0); 75 | let target: Array1 = array_from_csv("tests/R/log_regularization/iris_setosa_l2_1e-2.csv")?; 76 | let model = ModelBuilder::::data(&y_data, &x_data).build()?; 77 | // Temporarily try L2 for testing 78 | let fit = model.fit_options().l2_reg(1e-2).fit()?; 79 | // This still appears to be positive so our result is better 80 | // Ensure that our result is better, even if the parameters aren't epsilon-equivalent. 81 | assert!(fit.lr_test_against(&target) > -f32::EPSILON); 82 | // their result seems less precise, even when reducing the threshold. 83 | assert_abs_diff_eq!(&target, &fit.result, epsilon = 2e-3); 84 | Ok(()) 85 | } 86 | 87 | #[test] 88 | fn lasso_versicolor() -> Result<()> { 89 | let (y_labels, x_data) = y_x_from_iris()?; 90 | let x_data = standardize(x_data); 91 | // NOTE: It matches for versicolor, but not setosa (which is fully seperable). 
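// (Added note: a plausible explanation is that with fully separable classes the
// unpenalized likelihood only approaches its supremum as the coefficients grow
// without bound, so with a small penalty the objective is nearly flat near the
// optimum and different solvers can stop at noticeably different coefficients.)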
92 | // versicolor 93 | let y_data: Array1 = y_labels.mapv(|i| i == 1); 94 | let target: Array1 = array_from_csv("tests/R/log_regularization/iris_versicolor_l1_1e-2.csv")?; 95 | let model = ModelBuilder::::data(&y_data, &x_data).build()?; 96 | // TODO: test more harshly by increasing lambda. It passes at l1 = 1 at time of writing but 97 | // taks longer. 98 | let fit = model.fit_options().l1_reg(1e-2).fit()?; 99 | // If this is negative then our alg hasn't converged to a good minimum 100 | assert!(fit.lr_test_against(&target) >= 0., "If it's not an exact match to the target, it should be a better result under our likelihood."); 101 | // The epsilon tolerance doesn't need to be very low if we've found a better minimum 102 | assert_abs_diff_eq!(&target, &fit.result, epsilon = 0.01); 103 | Ok(()) 104 | } 105 | -------------------------------------------------------------------------------- /tests/data/log_regularization.csv: -------------------------------------------------------------------------------- 1 | true,1.2193131,-3.2022715 2 | true,1.1801388,-3.5498357 3 | true,1.1601324,-2.9067175 4 | true,1.2045205,-3.300263 5 | true,1.2824708,-1.038213 6 | true,1.1342251,-3.3903043 7 | true,1.1621742,-2.8323135 8 | true,1.2184155,-2.1300461 9 | false,-0.41074926,-2.5214505 10 | false,0.97007436,-3.2458508 11 | false,-0.24895352,-3.6581054 12 | false,0.9783293,-3.3000073 13 | false,-0.5885619,-2.903534 14 | false,-0.5007807,-3.106107 15 | false,-0.62737083,-3.6018958 16 | false,0.9031565,-3.0888348 17 | false,0.7203382,-3.8959014 18 | false,-0.55074775,-3.1264353 19 | false,0.90990645,-3.5095015 20 | false,0.8678664,-4.045617 21 | false,-0.5184523,-2.878534 22 | false,0.79999596,-3.1672046 23 | false,0.804452,-4.1845894 24 | false,1.0550954,-2.636488 25 | false,1.0733204,-2.559767 26 | false,0.82326007,-4.118101 27 | false,-0.5851728,-3.4469075 28 | false,-0.5038215,-3.6488595 29 | false,-0.573655,-3.1878405 30 | false,-0.6176456,-3.7475553 31 | false,0.9985628,-3.1413326 32 | false,-0.7065073,-4.1481404 33 | false,-0.6194027,-3.4137855 34 | false,-0.5784123,-2.8795235 35 | false,-0.57170135,-3.6843724 36 | false,-0.52790344,-2.508654 37 | false,-0.67210525,-3.870491 38 | false,-0.38833547,-2.2940629 39 | false,0.8713136,-3.9614625 40 | false,0.7462681,-3.21338 41 | false,0.8717802,-3.9081273 42 | false,-0.533901,-2.9195514 43 | false,0.9854994,-3.1786363 44 | false,-0.60825104,-3.1311865 45 | false,-0.568052,-2.6128466 46 | false,1.0630666,-2.4684453 47 | false,0.9843849,-3.0524857 48 | false,1.0813189,-2.4847744 49 | false,0.94213784,-3.251393 50 | false,-0.5005621,-3.5402033 51 | false,-0.61657226,-3.249961 52 | false,-0.4862158,-3.3581712 53 | false,0.70465565,-3.8698623 54 | false,0.9735776,-3.400121 55 | false,0.8487184,-3.7168138 56 | false,0.8504126,-3.2296853 57 | false,0.7740992,-5.3767567 58 | false,0.8039921,-3.1494188 59 | false,0.9473632,-3.878926 60 | false,0.94306403,-3.3567743 61 | false,-0.46513575,-2.1704593 62 | false,-0.6577675,-3.8597848 63 | false,0.80936736,-3.0139823 64 | false,-0.59992224,-3.156299 65 | false,1.0337565,-3.060298 66 | false,-0.4659681,-1.8616779 67 | false,-0.25259688,-2.097617 68 | false,0.60552067,-4.3096523 69 | false,-0.13298273,-2.7033253 70 | false,1.1334077,-2.2630975 71 | false,-0.5843559,-3.2043858 72 | false,0.9830567,-3.2682126 73 | false,-0.61661345,-3.3256328 74 | false,-0.18989393,-2.4832728 75 | false,-0.61569196,-3.4693508 76 | false,0.8969802,-2.6681943 77 | false,0.91982085,-3.4969087 78 | false,1.0017933,-2.7795973 79 | 
false,1.0185544,-2.15576 80 | false,0.85604364,-2.8518777 81 | false,0.91338027,-3.6118245 82 | false,-0.44871384,-1.9410766 83 | false,-0.22448498,-2.8435984 84 | false,0.854244,-4.26372 85 | false,0.9298247,-3.3350835 86 | false,0.9919857,-3.5081737 87 | false,-0.48487252,-2.549753 88 | false,1.0462288,-2.7037094 89 | false,0.967798,-3.243675 90 | false,0.68672746,-3.2842414 91 | false,-0.5376425,-2.977534 92 | false,-0.5212343,-3.7312818 93 | false,-0.5820122,-4.143177 94 | false,1.0747514,-2.8839417 95 | false,0.9287764,-3.241282 96 | false,0.8383481,-2.7857566 97 | false,0.84790874,-3.088206 98 | false,0.7018074,-4.0200224 99 | false,0.9097569,-3.8757997 100 | false,0.6970985,-3.4017804 101 | false,0.85544413,-3.1814318 102 | false,1.1275536,-2.311534 103 | false,0.977297,-3.3992295 104 | false,-0.6071262,-3.2451415 105 | false,-0.53342485,-2.517648 106 | false,-0.5792473,-2.967439 107 | false,-0.24168313,-3.7166774 108 | false,0.80485374,-3.758474 109 | false,-0.45355284,-3.0133708 110 | false,0.75524676,-4.2287397 111 | false,1.0477058,-2.8399858 112 | false,0.9120359,-4.0633545 113 | false,-0.5697641,-3.532654 114 | false,-0.55881673,-2.6308005 115 | false,0.94130284,-3.370558 116 | false,-0.62216914,-5.381933 117 | false,0.78644824,-3.1046753 118 | false,-0.6698377,-3.8812838 119 | false,-0.5053506,-3.107417 120 | false,0.92313594,-3.322509 121 | false,-0.5808335,-3.0687158 122 | false,-0.5234874,-3.2037897 123 | false,0.86294967,-3.1389453 124 | false,0.89145595,-3.3910303 125 | false,-0.78038037,-4.8554235 126 | false,1.0060856,-2.8213315 127 | false,-0.61578596,-3.1068404 128 | false,0.7889972,-3.0109434 129 | false,-0.54767704,-3.6485171 130 | false,-0.58674467,-3.4586437 131 | false,-0.30476525,-2.3913426 132 | false,0.61489,-3.0918524 133 | false,-0.3605033,-3.3683574 134 | false,0.6847831,-3.331433 135 | false,0.9967753,-2.6805768 136 | false,0.6340052,-3.2968535 137 | false,1.0213277,-2.5930195 138 | false,1.0001702,-2.746601 139 | false,0.63541365,-3.3146675 140 | false,-0.4839579,-3.7838674 141 | false,0.8460796,-2.6600602 142 | false,-0.7533034,-5.211149 143 | false,-0.43137386,-2.5453756 144 | false,-0.49925748,-3.0159621 145 | false,-0.64599,-3.4936776 146 | false,0.6755821,-4.162391 147 | false,1.010294,-2.9670782 148 | false,-0.4523279,-3.1350107 149 | false,0.82564384,-3.2064748 150 | false,-0.53777677,-2.5613034 151 | false,0.7752573,-2.510901 152 | false,0.78022814,-4.1310253 153 | false,0.6234257,-3.1576967 154 | false,-0.35609847,-1.8625667 155 | false,-0.5541089,-2.8448658 156 | false,-0.53884965,-2.3788 157 | false,-0.4491644,-1.9617666 158 | false,1.0167127,-3.4378614 159 | false,0.98980284,-3.379426 160 | false,0.9519398,-3.2246213 161 | false,-0.6513209,-3.582153 162 | false,-0.3885154,-3.8879168 163 | false,1.0918008,-2.3303013 164 | false,0.7315403,-3.4762826 165 | false,-0.5744604,-3.904035 166 | false,1.0497668,-2.8503888 167 | false,-0.4423195,-2.8007054 168 | false,-0.4119177,-2.7109041 169 | false,-0.5138235,-2.7545297 170 | false,-0.2533193,-2.0306814 171 | false,0.85990244,-2.780622 172 | false,0.90431154,-3.8353171 173 | false,-0.61061376,-3.4411867 174 | false,-0.61250573,-3.3783073 175 | false,-0.33911598,-4.5889893 176 | false,0.82879674,-3.4899902 177 | false,-0.57095784,-4.2981777 178 | false,1.0190511,-2.7300186 179 | false,0.83831835,-3.1926084 180 | false,0.97677445,-2.8796453 181 | false,-0.4917673,-3.4012911 182 | false,1.0690722,-2.8748078 183 | false,-0.64506125,-3.7131538 184 | false,-0.45198053,-2.6179054 185 | 
false,-0.5599256,-2.709443 186 | false,0.90968555,-3.4705572 187 | false,0.84510374,-3.2077503 188 | false,-0.52754027,-2.6144018 189 | false,-0.54932237,-2.9447875 190 | false,0.87772995,-3.209741 191 | false,-0.64744425,-3.65374 192 | false,0.96616846,-3.5632045 193 | false,0.8853154,-3.860868 194 | -------------------------------------------------------------------------------- /tests/data/lr_test_sign0.csv: -------------------------------------------------------------------------------- 1 | false,0.7152226,-3.6085353 2 | false,0.56037426,-3.6402123 3 | false,-1.022976,-3.643538 4 | false,0.6710763,-3.665741 5 | false,-0.8232459,-3.6468651 6 | false,0.017770052,-3.6316876 7 | false,0.6199495,-3.694242 8 | false,-0.28073478,-3.593704 9 | false,0.69234216,-3.6504192 10 | false,-0.5164597,-3.595847 11 | false,0.7872968,-3.5961988 12 | false,-0.77665204,-3.6319296 13 | false,0.62330854,-3.6856184 14 | false,-0.004978776,-3.6808743 15 | false,0.621892,-3.6891317 16 | false,-1.2630844,-3.673563 17 | false,0.959444,-3.4963577 18 | false,0.79039633,-3.583415 19 | false,0.762359,-3.6108892 20 | false,0.44988728,-3.6731575 21 | false,0.72182035,-3.6133442 22 | false,0.047882557,-3.7293506 23 | false,-0.13113725,-3.6758916 24 | false,0.7163267,-3.6376493 25 | false,-0.13001621,-3.5945144 26 | false,0.55909896,-3.7181485 27 | false,0.70205,-3.64781 28 | false,-0.0012778044,-3.6403902 29 | false,0.64563656,-3.6783702 30 | false,0.7340226,-3.5724525 31 | false,0.7735822,-3.606394 32 | false,0.56631017,-3.6794572 33 | false,0.5656501,-3.627299 34 | false,0.72041,-3.6350577 35 | false,0.32734275,-3.5956397 36 | false,0.70419395,-3.6342216 37 | false,-1.2309288,-3.632314 38 | false,0.07541752,-3.5789294 39 | false,0.72441995,-3.6325212 40 | false,0.7113774,-3.6401277 41 | false,-1.2768474,-3.637235 42 | false,0.7426175,-3.6242754 43 | false,0.8401146,-3.5678499 44 | false,0.67826724,-3.661591 45 | false,0.41560686,-3.7137117 46 | false,0.6934705,-3.6403508 47 | false,0.6689173,-3.646778 48 | false,0.702981,-3.6468165 49 | false,0.67415154,-3.6288226 50 | false,0.8378422,-3.5670893 51 | true,0.506668,-3.6840324 52 | false,0.56240034,-3.6830785 53 | false,0.67530286,-3.6560564 54 | false,0.7645917,-3.6102765 55 | false,0.6982378,-3.6480098 56 | false,0.6442516,-3.6792786 57 | false,0.68148017,-3.6043475 58 | false,0.7029536,-3.6411314 59 | false,.71576965,-3.6397545 60 | false,-1.3758832,-3.7240877 61 | false,0.56052005,-3.6340125 62 | false,0.63253975,-3.6880493 63 | false,0.56105685,-3.6864665 64 | false,0.76998997,-3.606385 65 | false,0.56831634,-3.7227163 66 | false,0.020682931,-3.6849864 67 | false,0.75811386,-3.615311 68 | false,0.8074597,-3.5845907 69 | false,.7497585,-3.6201265 70 | false,0.7496525,-3.6202145 71 | false,0.7313365,-3.6287484 72 | false,0.70235455,-3.610566 73 | false,0.7826723,-3.5939753 74 | false,0.6283232,-3.688149 75 | false,0.67756975,-3.6301572 76 | false,-0.3347485,-3.6735375 77 | false,0.67394745,-3.657558 78 | false,0.046198845,-3.6414473 79 | false,0.743494,-3.6209612 80 | false,0.674222,-3.6619494 81 | false,-0.6970977,-3.6376023 82 | false,-0.39293802,-3.63612 83 | false,-0.85064065,-3.59853 84 | false,0.6775354,-3.6592307 85 | false,0.58206046,-3.6288016 86 | false,0.5213444,-3.6743352 87 | false,0.69585,-3.6493435 88 | false,0.77308047,-3.6066277 89 | false,-0.0681716,-3.6455176 90 | false,0.74158955,-3.624875 91 | false,0.67450905,-3.6637163 92 | false,-0.1693654,-3.6829767 93 | false,0.56658673,-3.6264708 94 | false,0.8290237,-3.570659 95 | 
false,0.5178454,-3.6373298 96 | false,0.7058407,-3.6434455 97 | false,0.815573,-3.548574 98 | false,0.453143,-3.663815 99 | true,0.013307929,-3.6036236 100 | false,0.822853,-3.578124 101 | false,-0.99600863,-3.702743 102 | false,0.6475085,-3.6750708 103 | false,0.7338419,-3.6272666 104 | false,0.96412957,-3.4960144 105 | false,0.55754256,-3.6837423 106 | false,-0.058371305,-3.5962403 107 | false,0.6002399,-3.679225 108 | false,0.68433404,-3.6560805 109 | true,0.7826755,-3.5968366 110 | false,0.6070493,-3.6227078 111 | false,0.5407921,-3.6386738 112 | false,0.51439667,-3.6645045 113 | false,-0.47911608,-3.6214197 114 | false,-0.282274,-3.5073347 115 | false,0.61114645,-3.7004695 116 | false,0.43772566,-3.7098577 117 | false,0.7886317,-3.5950437 118 | false,0.5587115,-3.610912 119 | false,0.58002186,-3.6671083 120 | false,-0.08960211,-3.6685073 121 | false,0.7189616,-3.6379051 122 | false,-0.25201273,-3.6531727 123 | false,0.8323481,-3.570134 124 | false,0.68426526,-3.5954833 125 | false,0.8082104,-3.5862873 126 | false,0.82050896,-3.5768504 127 | false,-0.55750877,-3.6706944 128 | false,-0.25864017,-3.6377466 129 | false,0.6894479,-3.6274357 130 | false,0.71303976,-3.633041 131 | false,0.6909976,-3.6528401 132 | false,0.16668427,-3.5628033 133 | false,0.6740612,-3.6621964 134 | false,0.6700624,-3.6390095 135 | false,0.6896107,-3.6529715 136 | false,0.63629854,-3.64453 137 | false,0.74605155,-3.6223283 138 | false,0.73017204,-3.629482 139 | false,0.18682218,-3.6553874 140 | false,0.5488491,-3.6950738 141 | false,0.58560145,-3.6642652 142 | false,0.7751788,-3.605426 143 | false,0.6772654,-3.6602924 144 | false,0.64719903,-3.6795602 145 | false,-0.25924456,-3.6764307 146 | false,0.6669462,-3.6680992 147 | false,0.77604365,-3.5753791 148 | false,0.7085434,-3.6418722 149 | false,-0.33969414,-3.6163805 150 | false,0.61070347,-3.7006905 151 | false,0.36059976,-3.7292917 152 | false,0.40752208,-3.664181 153 | false,-0.08623636,-3.6730607 154 | false,0.5931206,-3.6858704 155 | true,-1.1504465,-3.6717937 156 | false,0.5284418,-3.653012 157 | false,0.715564,-3.6399958 158 | false,0.75680614,-3.6140022 159 | false,0.8191289,-3.579786 160 | false,0.43273652,-3.6701689 161 | false,0.013841271,-3.6684394 162 | false,0.6103896,-3.6020484 163 | false,0.7480247,-3.6211753 164 | true,0.72249484,-3.6273766 165 | false,0.3664639,-3.6982884 166 | false,-0.0988189,-3.6279793 167 | false,0.50096035,-3.6863177 168 | false,0.56772625,-3.6446884 169 | false,0.7028991,-3.6185918 170 | false,0.73033345,-3.629336 171 | false,0.54914486,-3.64696 172 | false,0.76795006,-3.5969083 173 | false,-0.7261873,-3.6019404 174 | false,0.69491184,-3.6449492 175 | false,-0.21735424,-3.577795 176 | false,0.61544144,-3.6180913 177 | false,0.6178874,-3.6420574 178 | false,0.6815157,-3.6480515 179 | false,0.80156803,-3.5901687 180 | false,0.7075745,-3.643994 181 | false,0.7511097,-3.616567 182 | false,0.6582899,-3.6618757 183 | false,-1.2481409,-3.642815 184 | false,-0.5470858,-3.7088833 185 | false,0.5214509,-3.6995325 186 | false,-0.7649454,-3.6302497 187 | false,0.900054,-3.5309136 188 | false,0.6464579,-3.677993 189 | false,-1.2727971,-3.6452026 190 | false,0.7403475,-3.6234975 191 | false,-0.23072696,-3.6246545 192 | false,0.74116254,-3.6227877 193 | false,0.6983613,-3.6499014 194 | false,0.013787508,-3.6618838 195 | false,0.68922365,-3.6433485 196 | -------------------------------------------------------------------------------- /tests/data/lr_test_sign1.csv: 
-------------------------------------------------------------------------------- 1 | false,0.60000193,-3.006097 2 | false,0.73958755,-3.0753865 3 | true,0.5902617,-2.9823687 4 | false,0.5225923,-3.0445418 5 | false,0.7422992,-3.0973408 6 | false,0.6313586,-2.996151 7 | false,0.6433028,-2.9973829 8 | false,0.6724516,-3.0129297 9 | false,0.6110176,-2.955964 10 | false,0.73995745,-3.077232 11 | false,0.6517347,-2.9936688 12 | false,0.7269701,-3.085854 13 | false,0.65412045,-2.9974234 14 | false,0.79741526,-3.1469471 15 | false,0.7078545,-3.0524921 16 | false,0.6456047,-2.9878178 17 | false,0.5355282,-3.0687723 18 | false,0.71035266,-3.0784914 19 | false,0.56887555,-2.9262736 20 | false,0.77582073,-3.124317 21 | false,0.8417636,-3.1967273 22 | false,0.69972503,-3.0384562 23 | false,0.7334889,-3.1216624 24 | false,0.63272476,-3.0253038 25 | false,0.73779917,-3.0724897 26 | false,0.67752063,-3.0426145 27 | false,0.608974,-3.030224 28 | false,0.63329184,-2.9857216 29 | false,0.6975018,-3.0398118 30 | false,0.65684855,-2.998184 31 | false,0.69776726,-3.0381372 32 | false,0.8243823,-3.1623414 33 | false,0.77510834,-3.1207707 34 | false,0.69983006,-3.0503528 35 | false,0.6940248,-3.0688574 36 | false,0.5966468,-2.9579482 37 | false,0.715245,-3.0812771 38 | false,0.6554657,-3.029903 39 | false,-1.1431051,-2.9085011 40 | false,0.72175777,-3.0712206 41 | false,0.7714646,-3.1041377 42 | false,-1.0342649,-3.280007 43 | false,0.63316846,-2.9768028 44 | false,0.6993027,-3.0587947 45 | false,0.7129332,-3.050919 46 | false,0.71866596,-3.0547364 47 | false,0.7635138,-3.1077745 48 | true,-0.6056436,-3.0790966 49 | false,0.6200999,-2.9883525 50 | false,0.5932543,-3.018548 51 | false,0.8202294,-3.149598 52 | false,0.63610005,-2.9781556 53 | false,0.6942589,-3.0380707 54 | false,0.7316766,-3.1197534 55 | false,0.5863311,-3.0464222 56 | false,0.50769436,-2.9189882 57 | false,-1.2319046,-3.0978568 58 | false,0.69078696,-3.0546613 59 | false,0.7655693,-3.1170523 60 | false,-1.288313,-3.0455718 61 | false,0.65671325,-3.0351205 62 | false,0.6772561,-3.042112 63 | false,0.79055417,-3.1219118 64 | false,0.7136332,-3.0604615 65 | false,0.56115365,-3.0502899 66 | false,0.59185874,-3.0542533 67 | false,0.6306044,-3.0019577 68 | false,0.80326915,-3.1516292 69 | false,-0.10070264,-3.1303878 70 | false,0.63442767,-2.994978 71 | false,0.8102906,-3.1392844 72 | false,-1.3158369,-3.0202413 73 | false,0.66692674,-3.106054 74 | false,0.7538556,-3.0882857 75 | false,-1.3325711,-3.0055542 76 | false,0.696327,-3.0505736 77 | false,0.7303494,-3.071169 78 | false,0.75628114,-3.0902889 79 | false,-1.36619,-2.9741197 80 | false,0.69720614,-3.0666103 81 | false,0.73967814,-3.0759912 82 | false,0.7051089,-3.0750349 83 | false,0.77216935,-3.1089287 84 | false,-1.2844682,-3.0491078 85 | false,0.6998968,-3.0400407 86 | false,0.19623339,-3.1114411 87 | false,0.6832442,-3.0229092 88 | false,0.6576929,-3.0164769 89 | false,0.58470345,-2.9541173 90 | false,0.70751,-3.0553222 91 | false,-0.48157918,-3.121332 92 | false,-1.378693,-2.9630754 93 | false,0.7919359,-3.1247494 94 | false,0.7298856,-3.1045246 95 | false,0.77479076,-3.1116133 96 | false,-0.6076642,-3.009858 97 | false,0.6050153,-2.9508438 98 | false,0.78681445,-3.1365786 99 | false,0.5827042,-2.980685 100 | false,0.74414706,-3.122613 101 | false,0.6597233,-3.0002046 102 | false,0.7060808,-3.043687 103 | false,-1.0536684,-3.2624996 104 | false,0.6928083,-3.0486038 105 | false,0.7505895,-3.088583 106 | true,-1.247789,-3.082468 107 | false,0.63798666,-2.9811478 108 | 
false,0.65624404,-3.0530274 109 | false,0.59138286,-2.9394965 110 | false,0.6885004,-3.0809355 111 | false,0.5990684,-3.092376 112 | false,0.67402446,-3.0160098 113 | true,0.6511196,-3.0304334 114 | false,0.77971506,-3.1501226 115 | true,-0.4894935,-3.0507085 116 | false,0.8002145,-3.1483133 117 | true,0.70457995,-3.0430188 118 | false,0.8872721,-3.280554 119 | false,-1.0288248,-3.098318 120 | false,0.73008263,-3.06792 121 | false,0.61170006,-2.9770405 122 | false,0.7480694,-3.123457 123 | false,0.6575941,-3.0805457 124 | false,-1.323791,-3.013087 125 | false,0.6725675,-3.01183 126 | false,-1.4372009,-2.9084072 127 | false,0.7340354,-3.070084 128 | false,0.44430757,-3.0710452 129 | false,0.7007612,-3.0479262 130 | false,-1.3656977,-2.9742785 131 | false,0.73020756,-3.0676873 132 | false,0.67472386,-3.101095 133 | false,0.32870078,-3.0641696 134 | false,0.6477268,-2.9980295 135 | false,0.7257174,-3.0620112 136 | false,0.7563257,-3.0934255 137 | false,0.6491368,-2.992809 138 | false,0.58669364,-2.9333498 139 | false,0.83817077,-3.167197 140 | false,0.7162311,-3.078364 141 | true,0.60862124,-3.0571003 142 | false,0.6730138,-3.0252657 143 | false,0.67073524,-3.0342946 144 | false,0.7872877,-3.1207042 145 | false,0.6449033,-3.0401018 146 | false,0.6720556,-3.026371 147 | false,0.74880767,-3.0826085 148 | false,0.6659467,-3.0123458 149 | false,0.5957966,-2.9410503 150 | false,0.5805838,-2.9822063 151 | true,0.6984676,-3.0399632 152 | false,0.7069485,-3.045417 153 | false,0.56867445,-2.916801 154 | false,0.70985186,-3.0513036 155 | false,0.7235371,-3.119639 156 | false,0.8386946,-3.165984 157 | false,0.730543,-3.070951 158 | false,0.5747795,-2.9777536 159 | false,0.6706457,-3.0551708 160 | false,0.81516993,-3.153976 161 | false,-1.3480839,-2.990317 162 | false,0.5834594,-2.9644012 163 | false,0.6476103,-2.9897158 164 | false,0.665419,-3.005446 165 | true,0.69104457,-3.0637836 166 | false,0.62486017,-2.987982 167 | false,0.7674829,-3.10451 168 | false,-1.3200102,-3.0159051 169 | false,0.7080008,-3.0687392 170 | false,0.81226957,-3.1411595 171 | false,0.64461267,-3.0122118 172 | false,0.66681623,-3.0631704 173 | false,0.6297997,-3.0076773 174 | false,0.7265135,-3.0628402 175 | false,0.67735183,-3.0434346 176 | false,0.6024531,-3.0360103 177 | false,0.6163956,-3.0664628 178 | false,0.79089,-3.1586432 179 | false,0.7969954,-3.2248938 180 | false,0.6590091,-3.0160227 181 | false,0.76484513,-3.0973046 182 | false,0.70781684,-3.0452693 183 | false,-1.2956144,-2.9725523 184 | false,0.6655396,-3.0155113 185 | false,0.69066393,-3.0301595 186 | false,0.7216128,-3.0601282 187 | false,0.7090354,-3.0476308 188 | false,0.7370745,-3.085056 189 | false,0.66277254,-3.002796 190 | false,0.6287874,-3.0545063 191 | false,0.72305405,-3.0604885 192 | false,0.7346107,-3.0700073 193 | false,0.71277225,-3.0501294 194 | false,0.44307065,-2.936778 195 | false,0.6936883,-3.0868723 196 | false,0.6324481,-3.0188847 197 | false,0.6700398,-3.011019 198 | false,0.70309997,-3.0726452 199 | false,0.46533906,-2.821354 200 | -------------------------------------------------------------------------------- /tests/data/log_termination_1.csv: -------------------------------------------------------------------------------- 1 | true,0.48353553,-3.2022738 2 | true,0.32687593,-3.5498385 3 | true,0.50492036,-3.3002653 4 | true,0.56624234,-2.9067192 5 | true,-0.14332628,-3.3903067 6 | true,-0.011929154,-1.0382121 7 | true,-0.4571042,-2.8323152 8 | true,0.65687394,-3.4152577 9 | true,0.015737057,-2.1300468 10 | false,0.23096406,-3.3786242 11 | 
false,0.012414932,-3.6018987 12 | false,-0.52334684,-3.1413348 13 | false,0.6178515,-2.5214515 14 | false,0.6964315,-3.2458532 15 | false,-1.0957825,-3.3000097 16 | false,-1.1722804,-2.6364894 17 | false,-1.384096,-2.2940636 18 | false,0.5960889,-3.1264374 19 | false,0.74327517,-3.540206 20 | false,0.3723656,-3.8959048 21 | false,0.6944523,-3.2296875 22 | false,-1.1790563,-2.8795252 23 | false,-0.12711275,-3.6581082 24 | false,-1.2141193,-3.3581736 25 | false,-1.2183346,-3.7168167 26 | false,0.59808135,-3.4001236 27 | false,-0.051205873,-3.747558 28 | false,0.06812978,-2.8785357 29 | false,-0.9532629,-3.782242 30 | false,0.5170653,-2.5086553 31 | false,-1.2525051,-3.44691 32 | false,-1.2220773,-2.6128478 33 | false,0.78843486,-3.9081306 34 | false,0.652954,-3.1311886 35 | false,-0.9228772,-3.1878426 36 | false,-0.46118212,-3.2133822 37 | false,-0.35449618,-2.919553 38 | false,-1.2833532,-3.1786385 39 | false,-0.8604206,-3.1061091 40 | false,0.6273012,-2.5597682 41 | false,-1.017805,-4.045621 42 | false,0.30743694,-3.0524876 43 | false,0.760401,-3.8698654 44 | false,0.5382639,-3.2513955 45 | false,-0.9086895,-3.509504 46 | false,-0.7136542,-3.1672068 47 | false,0.5777776,-4.184593 48 | false,-0.9980084,-3.6488624 49 | false,0.70232236,-3.4137878 50 | false,-0.44568497,-3.249963 51 | false,-0.92240953,-3.9614658 52 | false,0.6939018,-3.088837 53 | false,0.7389426,-3.6843753 54 | false,-1.1729985,-4.118105 55 | false,-0.6568296,-2.4684463 56 | false,0.59121525,-2.4847755 57 | false,-0.43012452,-4.1481442 58 | false,-0.8200198,-3.8704944 59 | false,-0.7762546,-2.9035356 60 | false,-1.0411963,-3.399232 61 | false,-0.45994437,-3.5081763 62 | false,0.62327325,-2.5176492 63 | false,-1.2486656,-3.4693534 64 | false,0.57346594,-2.7037108 65 | false,0.16317546,-2.7795987 66 | false,-0.7400616,-2.1704597 67 | false,-1.3026286,-3.0139842 68 | false,-0.016997218,-2.0976174 69 | false,-1.1019399,-4.020026 70 | false,0.6611141,-3.2842436 71 | false,-1.1442604,-2.6681957 72 | false,0.70348763,-3.1563013 73 | false,-1.1237589,-3.6118271 74 | false,-0.40564853,-2.7857583 75 | false,0.78321004,-3.875803 76 | false,-0.5143354,-4.309656 77 | false,-0.7337447,-3.7312846 78 | false,0.8111557,-4.143181 79 | false,-0.72404915,-3.2412841 80 | false,0.5662259,-2.2630982 81 | false,-1.3205436,-2.8518794 82 | false,0.61006343,-2.483274 83 | false,0.6439452,-2.977536 84 | false,-0.51819354,-3.0603 85 | false,0.70379484,-3.2436771 86 | false,-1.4223874,-1.9410769 87 | false,0.66888523,-3.149421 88 | false,-1.3163908,-2.8839433 89 | false,-0.58185303,-3.3567767 90 | false,-0.47290254,-3.4017828 91 | false,0.6307907,-2.5497541 92 | false,0.16947198,-1.8616779 93 | false,0.62836826,-2.7033267 94 | false,-0.9814276,-5.376762 95 | false,-1.210118,-3.3350859 96 | false,0.52871215,-3.204388 97 | false,0.7025646,-3.8789291 98 | false,0.67520285,-2.9674406 99 | false,-1.283294,-3.181434 100 | false,-1.3194449,-2.8436 101 | false,0.6918458,-3.859788 102 | false,0.8273946,-4.263724 103 | false,0.5684525,-2.1557608 104 | false,0.26348603,-2.3115346 105 | false,-0.012208343,-3.3256352 106 | false,0.6822369,-3.268215 107 | false,-0.73920596,-3.245144 108 | false,0.7346592,-3.4969113 109 | false,0.6920308,-3.0882082 110 | false,0.6592177,-3.1389475 111 | false,-0.6693195,-4.2287436 112 | false,0.72450817,-3.458646 113 | false,0.6506053,-2.6916752 114 | false,0.5953907,-2.8399875 115 | false,-1.2699009,-2.510902 116 | false,-0.30402076,-2.96708 117 | false,-0.11133349,-2.593021 118 | false,-1.3464684,-2.630802 119 | 
false,0.6972213,-3.0918543 120 | false,-0.6777657,-3.206477 121 | false,-0.71906555,-2.6600616 122 | false,0.45432138,-3.7584772 123 | false,-0.86802197,-3.1350129 124 | false,0.7041714,-3.2037919 125 | false,-0.3661058,-4.0633583 126 | false,0.7499423,-3.64852 127 | false,-1.1538891,-2.5453768 128 | false,-0.99647,-5.381939 129 | false,0.6569295,-3.010945 130 | false,0.64912844,-2.8448672 131 | false,0.7465631,-3.5326567 132 | false,0.71838236,-4.1310287 133 | false,-1.264586,-3.0159638 134 | false,0.64894676,-3.107419 135 | false,0.72622204,-3.3314354 136 | false,0.5509386,-2.5613046 137 | false,0.880937,-4.855428 138 | false,-0.922426,-2.3913436 139 | false,0.63227654,-2.8213332 140 | false,0.6954663,-3.1046772 141 | false,0.7819948,-3.8812869 142 | false,-1.3025188,-3.0133727 143 | false,0.7222867,-3.391033 144 | false,0.37153876,-2.7466025 145 | false,0.55291605,-1.8625668 146 | false,-1.2127208,-3.1576986 147 | false,0.7226932,-3.3705604 148 | false,0.7328142,-3.4936802 149 | false,0.74241686,-4.1623945 150 | false,0.93980014,-5.2111545 151 | false,0.72233117,-3.3146696 152 | false,0.64522636,-2.6805782 153 | false,-0.2788471,-3.3225114 154 | false,0.7747294,-3.7838705 155 | false,-1.1970979,-3.2968557 156 | false,-1.2959403,-3.0687177 157 | false,0.6387843,-3.36836 158 | false,-0.8153173,-3.7166803 159 | false,-0.23521876,-3.1068423 160 | false,0.6858834,-3.2246234 161 | false,-0.679168,-3.37831 162 | false,-1.3467581,-2.614403 163 | false,0.7461469,-3.860871 164 | false,-0.81385136,-2.3788009 165 | false,-1.0387502,-2.330302 166 | false,-1.1294994,-3.2077525 167 | false,-0.9035834,-3.1591618 168 | false,0.6699456,-2.8748093 169 | false,-0.009785533,-2.6179066 170 | false,-0.053070188,-2.800707 171 | false,-1.1826934,-3.4899929 172 | false,0.6465385,-3.437864 173 | false,0.82264376,-4.2981815 174 | false,0.6977775,-3.5632071 175 | false,0.53678644,-3.713157 176 | false,-0.6443374,-2.7545311 177 | false,-1.1633543,-3.2097433 178 | false,-1.0217475,-3.3794284 179 | false,0.73973286,-3.4762852 180 | false,0.56944954,-1.9617668 181 | false,0.30056667,-2.7806234 182 | false,0.08328712,-2.73002 183 | false,-1.2830048,-2.7109056 184 | false,-1.1636153,-3.9040384 185 | false,-0.81747836,-2.8503904 186 | false,0.73330843,-3.4411893 187 | false,0.69177365,-3.4012935 188 | false,0.20494008,-4.5889935 189 | false,0.7062788,-3.6537428 190 | false,0.78206015,-3.88792 191 | false,-1.0173594,-3.5660207 192 | false,-0.7045444,-3.83532 193 | false,-1.4153104,-2.0306816 194 | false,0.6429255,-2.9447892 195 | false,-0.9160576,-3.4705598 196 | false,-0.5827174,-3.5821557 197 | false,-1.1231985,-2.879647 198 | false,-0.58335847,-3.1926105 199 | false,-1.1834576,-2.7094445 200 | -------------------------------------------------------------------------------- /tests/data/log_termination_0.csv: -------------------------------------------------------------------------------- 1 | true,0.26557046,-3.2022736 2 | true,0.5244869,-3.5498385 3 | true,-0.21809632,-3.3002653 4 | true,0.5479409,-2.906719 5 | true,-0.3813186,-3.390307 6 | true,0.03792116,-1.0382099 7 | true,-0.04919088,-2.8323147 8 | true,-0.453344,-3.415258 9 | true,0.01047641,-2.1300457 10 | false,1.0668023,-3.3786244 11 | false,-0.18769538,-3.601899 12 | false,-0.35202113,-3.1413345 13 | false,1.2880116,-2.521451 14 | false,0.3166774,-3.2458532 15 | false,0.10186827,-3.3000097 16 | false,0.36787283,-2.6364887 17 | false,-0.24988532,-2.2940626 18 | false,-0.6783223,-3.1264372 19 | false,-0.7201773,-3.5402062 20 | false,-0.4314589,-3.8959053 21 | 
false,-0.5853845,-3.2296875 22 | false,-0.37470338,-2.8795247 23 | false,-0.75128376,-3.6581087 24 | false,1.2318513,-3.3581736 25 | false,-0.12531418,-3.716817 26 | false,-0.42064115,-3.4001236 27 | false,-0.74604774,-3.7475586 28 | false,1.1256697,-2.8785355 29 | false,-0.08433908,-3.7822423 30 | false,1.3585893,-2.5086546 31 | false,0.005041659,-3.4469101 32 | false,0.18818086,-2.6128473 33 | false,-0.45354,-3.9081311 34 | false,-0.27919883,-3.1311884 35 | false,1.271002,-3.1878426 36 | false,-0.13111639,-3.2133822 37 | false,0.13094604,-2.9195528 38 | false,1.2829714,-3.1786385 39 | false,-0.29096013,-3.106109 40 | false,1.2689505,-2.5597675 41 | false,0.3342812,-4.0456214 42 | false,0.08396453,-3.0524874 43 | false,1.1104087,-3.8698661 44 | false,-0.68441665,-3.2513955 45 | false,0.18637711,-3.5095043 46 | false,0.7208801,-3.1672068 47 | false,-0.50222147,-4.1845937 48 | false,1.0418514,-3.6488628 49 | false,-0.49237987,-3.413788 50 | false,-0.5876201,-3.249963 51 | false,-0.1234926,-3.9614666 52 | false,0.09748709,-3.0888367 53 | false,-0.2486701,-3.6843758 54 | false,-0.676539,-4.1181054 55 | false,0.20423794,-2.4684455 56 | false,-0.012358367,-2.4847746 57 | false,-0.642534,-4.1481447 58 | false,-0.59853345,-3.8704948 59 | false,0.14618611,-2.9035354 60 | false,-0.21990025,-3.3992321 61 | false,-0.7370253,-3.5081766 62 | false,1.0310516,-2.5176485 63 | false,0.7160105,-3.4693534 64 | false,1.0968622,-2.70371 65 | false,-0.45908076,-2.7795982 66 | false,-0.5856943,-2.1704588 67 | false,1.2937942,-3.0139837 68 | false,-0.56300896,-2.0976164 69 | false,0.04820007,-4.0200267 70 | false,-0.0106089115,-3.2842436 71 | false,-0.14692777,-2.6681952 72 | false,-0.52543974,-3.1563013 73 | false,0.13696462,-3.6118274 74 | false,-0.59596336,-2.7857578 75 | false,0.24310935,-3.8758035 76 | false,-0.30030304,-4.309657 77 | false,1.0555162,-3.731285 78 | false,-0.3424111,-4.1431813 79 | false,0.5709101,-3.2412841 80 | false,1.3804396,-2.2630973 81 | false,1.2689683,-2.851879 82 | false,0.18201947,-2.4832733 83 | false,-0.29528198,-2.9775357 84 | false,1.2367306,-3.0602999 85 | false,-0.22552526,-3.2436771 86 | false,-0.43786216,-1.9410757 87 | false,-0.14895535,-3.1494207 88 | false,-0.39894313,-2.8839428 89 | false,-0.7180385,-3.3567767 90 | false,0.28069985,-3.4017828 91 | false,0.19821835,-2.5497537 92 | false,1.252464,-1.8616767 93 | false,0.32849503,-2.7033262 94 | false,-0.7762929,-5.376764 95 | false,-0.014078617,-3.3350859 96 | false,-0.45130554,-3.2043877 97 | false,-0.71616507,-3.8789296 98 | false,-0.6086625,-2.9674404 99 | false,-0.41315138,-3.1814337 100 | false,-0.5242159,-2.8435996 101 | false,-0.7438436,-3.8597884 102 | false,-0.81740195,-4.263725 103 | false,0.55138636,-2.1557596 104 | false,1.3621999,-2.3115337 105 | false,-0.62829953,-3.3256352 106 | false,-0.70248127,-3.2682147 107 | false,-0.2101,-3.2451437 108 | false,-0.40677756,-3.4969113 109 | false,0.5391067,-3.088208 110 | false,-0.66932964,-3.1389472 111 | false,0.7294785,-4.228744 112 | false,-0.21406662,-3.4586463 113 | false,0.03267044,-2.6916747 114 | false,-0.6382973,-2.8399873 115 | false,-0.3233307,-2.5109015 116 | false,-0.13329303,-2.9670796 117 | false,0.3818912,-2.5930202 118 | false,1.2497455,-2.6308012 119 | false,0.4388262,-3.0918543 120 | false,-0.253937,-3.206477 121 | false,-0.23217905,-2.660061 122 | false,-0.5770677,-3.7584777 123 | false,0.41546327,-3.1350126 124 | false,0.53724766,-3.2037919 125 | false,-0.7986775,-4.063359 126 | false,0.28877604,-3.6485202 127 | false,0.47571462,-2.545376 128 | 
false,0.09451425,-5.381941 129 | false,-0.34829602,-3.0109448 130 | false,-0.52060723,-2.844867 131 | false,-0.49226674,-3.532657 132 | false,-0.622851,-4.1310296 133 | false,1.2084823,-3.0159636 134 | false,0.6975992,-3.1074188 135 | false,-0.07615328,-3.3314354 136 | false,-0.5714911,-2.5613039 137 | false,-0.8744174,-4.8554296 138 | false,1.1694076,-2.3913429 139 | false,-0.65958846,-2.8213327 140 | false,-0.21230477,-3.1046772 141 | false,0.08752161,-3.8812873 142 | false,1.2759559,-3.0133724 143 | false,-0.45556974,-3.391033 144 | false,0.5014289,-2.746602 145 | false,-0.5371747,-1.8625656 146 | false,0.732478,-3.1576986 147 | false,0.33955103,-3.3705604 148 | false,-0.70246065,-3.4936805 149 | false,-0.37387347,-4.1623955 150 | false,-0.4731303,-5.2111564 151 | false,0.18484354,-3.3146698 152 | false,-0.3117661,-2.6805778 153 | false,1.2368045,-3.3225114 154 | false,0.7313858,-3.783871 155 | false,0.64754677,-3.2968557 156 | false,-0.6285814,-3.0687175 157 | false,0.5638829,-3.36836 158 | false,0.3041898,-3.7166808 159 | false,0.3807578,-3.1068423 160 | false,0.9275137,-3.2246234 161 | false,-0.7031316,-3.37831 162 | false,0.1693201,-2.6144023 163 | false,0.25447702,-3.8608716 164 | false,0.4682963,-2.3788 165 | false,-0.10173762,-2.3303013 166 | false,-0.07723445,-3.2077525 167 | false,0.40226156,-3.1591616 168 | false,-0.21446574,-2.874809 169 | false,-0.4003486,-2.617906 170 | false,0.3332528,-2.8007066 171 | false,-0.7320912,-3.489993 172 | false,1.2398887,-3.437864 173 | false,-0.4674436,-4.2981825 174 | false,0.25253803,-3.5632074 175 | false,1.0358511,-3.7131572 176 | false,0.5350432,-2.7545307 177 | false,-0.17243266,-3.2097433 178 | false,-0.6961477,-3.3794284 179 | false,0.30013,-3.4762855 180 | false,0.059309006,-1.9617656 181 | false,-0.4967607,-2.780623 182 | false,-0.45456904,-2.7300196 183 | false,-0.56493205,-2.710905 184 | false,1.1671963,-3.904039 185 | false,0.31851083,-2.85039 186 | false,-0.5102673,-3.4411893 187 | false,1.2568061,-3.4012935 188 | false,0.33830458,-4.5889945 189 | false,-0.07715261,-3.653743 190 | false,0.3121574,-3.8879204 191 | false,-0.5151504,-3.566021 192 | false,-0.16930008,-3.8353205 193 | false,-0.24594662,-2.0306807 194 | false,1.2807302,-2.944789 195 | false,0.5825742,-3.4705598 196 | false,-0.7365061,-3.5821562 197 | false,-0.6639884,-2.8796465 198 | false,-0.6828925,-3.1926105 199 | false,0.5622875,-2.709444 200 | -------------------------------------------------------------------------------- /src/response/logistic.rs: -------------------------------------------------------------------------------- 1 | //! functions for solving logistic regression 2 | 3 | use crate::{ 4 | error::{RegressionError, RegressionResult}, 5 | glm::{DispersionType, Glm}, 6 | link::Link, 7 | math::prod_log, 8 | num::Float, 9 | response::Response, 10 | }; 11 | use ndarray::Array1; 12 | use std::marker::PhantomData; 13 | 14 | /// Logistic regression 15 | pub struct Logistic 16 | where 17 | L: Link>, 18 | { 19 | _link: PhantomData, 20 | } 21 | 22 | /// The logistic response variable must be boolean (at least for now). 23 | impl Response> for bool 24 | where 25 | L: Link>, 26 | { 27 | fn into_float(self) -> RegressionResult { 28 | Ok(if self { F::one() } else { F::zero() }) 29 | } 30 | } 31 | // Allow floats for the domain. We can't use num_traits::Float because of the 32 | // possibility of conflicting implementations upstream, so manually implement 33 | // for f32 and f64. 
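// Fractional values inside the closed interval [0, 1] are accepted as-is by the
// impls below; anything outside that range is rejected with `RegressionError::InvalidY`.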
34 | impl Response> for f32 35 | where 36 | L: Link>, 37 | { 38 | fn into_float(self) -> RegressionResult { 39 | if !(0.0..=1.0).contains(&self) { 40 | return Err(RegressionError::InvalidY(self.to_string())); 41 | } 42 | F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string())) 43 | } 44 | } 45 | impl Response> for f64 46 | where 47 | L: Link>, 48 | { 49 | fn into_float(self) -> RegressionResult { 50 | if !(0.0..=1.0).contains(&self) { 51 | return Err(RegressionError::InvalidY(self.to_string())); 52 | } 53 | F::from(self).ok_or_else(|| RegressionError::InvalidY(self.to_string())) 54 | } 55 | } 56 | 57 | /// Implementation of GLM functionality for logistic regression. 58 | impl Glm for Logistic 59 | where 60 | L: Link>, 61 | { 62 | type Link = L; 63 | const DISPERSED: DispersionType = DispersionType::NoDispersion; 64 | 65 | /// The log of the partition function for logistic regression. The natural 66 | /// parameter is the logit of p. 67 | fn log_partition(nat_par: F) -> F { 68 | num_traits::Float::exp(nat_par).ln_1p() 69 | } 70 | 71 | /// var = mu*(1-mu) 72 | fn variance(mean: F) -> F { 73 | mean * (F::one() - mean) 74 | } 75 | 76 | /// This function is specialized over the default provided by Glm in order 77 | /// to handle over/underflow issues more precisely. 78 | fn log_like_natural(y: F, logit_p: F) -> F 79 | where 80 | F: Float, 81 | { 82 | let (yt, xt) = if logit_p < F::zero() { 83 | (y, logit_p) 84 | } else { 85 | (F::one() - y, -logit_p) 86 | }; 87 | yt * xt - num_traits::Float::exp(xt).ln_1p() 88 | } 89 | 90 | /// The saturated likelihood is zero for logistic regression when y = 0 or 1 but is greater 91 | /// than zero for 0 < y < 1. 92 | fn log_like_sat(y: F) -> F { 93 | prod_log(y) + prod_log(F::one() - y) 94 | } 95 | } 96 | 97 | pub mod link { 98 | //! Link functions for logistic regression 99 | use super::*; 100 | use crate::link::{Canonical, Link, Transform}; 101 | use crate::num::Float; 102 | 103 | /// The canonical link function for logistic regression is the logit function g(p) = 104 | /// log(p/(1-p)). 105 | pub struct Logit {} 106 | impl Canonical for Logit {} 107 | impl Link> for Logit { 108 | fn func(y: F) -> F { 109 | num_traits::Float::ln(y / (F::one() - y)) 110 | } 111 | fn func_inv(lin_pred: F) -> F { 112 | (F::one() + num_traits::Float::exp(-lin_pred)).recip() 113 | } 114 | } 115 | 116 | /// The complementary log-log link g(p) = log(-log(1-p)) is appropriate when 117 | /// modeling the probability of non-zero counts when the counts are 118 | /// Poisson-distributed with mean lambda = exp(lin_pred). 119 | pub struct Cloglog {} 120 | impl Link> for Cloglog { 121 | fn func(y: F) -> F { 122 | num_traits::Float::ln(-F::ln_1p(-y)) 123 | } 124 | // This quickly underflows to zero for inputs greater than ~2. 125 | fn func_inv(lin_pred: F) -> F { 126 | -F::exp_m1(-num_traits::Float::exp(lin_pred)) 127 | } 128 | } 129 | impl Transform for Cloglog { 130 | fn nat_param(lin_pred: Array1) -> Array1 { 131 | lin_pred.mapv(|x| num_traits::Float::ln(num_traits::Float::exp(x).exp_m1())) 132 | } 133 | fn d_nat_param(lin_pred: &Array1) -> Array1 { 134 | let neg_exp_lin = -lin_pred.mapv(num_traits::Float::exp); 135 | &neg_exp_lin / &neg_exp_lin.mapv(F::exp_m1) 136 | } 137 | } 138 | } 139 | 140 | #[cfg(test)] 141 | mod tests { 142 | use super::*; 143 | use crate::{error::RegressionResult, model::ModelBuilder}; 144 | use approx::assert_abs_diff_eq; 145 | use ndarray::array; 146 | 147 | /// A simple test where the correct value for the data is known exactly. 
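    /// With `x` taking only the values 0 and ln(2), the MLE can be read off directly:
    /// at x = 0 one of the two outcomes is true (p = 1/2, logit = 0), and at x = ln(2)
    /// two of the three are true (p = 2/3, logit = ln 2), so the exact solution is
    /// beta = (0, 1).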
148 | #[test] 149 | fn log_reg() -> RegressionResult<()> { 150 | let beta = array![0., 1.0]; 151 | let ln2 = f64::ln(2.); 152 | let data_x = array![[0.], [0.], [ln2], [ln2], [ln2]]; 153 | let data_y = array![true, false, true, true, false]; 154 | let model = ModelBuilder::::data(&data_y, &data_x).build()?; 155 | let fit = model.fit()?; 156 | // dbg!(fit.n_iter); 157 | // NOTE: This tolerance must be higher than it would ideally be. 158 | // Only 2 iterations are completed, so more accuracy could presumably be achieved with a 159 | // lower tolerance. 160 | assert_abs_diff_eq!(beta, fit.result, epsilon = 0.5 * f32::EPSILON as f64); 161 | // let lr = fit.lr_test(); 162 | Ok(()) 163 | } 164 | 165 | // verify that the link and inverse are indeed inverses. 166 | #[test] 167 | fn cloglog_closure() { 168 | use link::Cloglog; 169 | let mu_test_vals = array![1e-8, 0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99, 0.9999999]; 170 | assert_abs_diff_eq!( 171 | mu_test_vals, 172 | mu_test_vals.mapv(|mu| Cloglog::func_inv(Cloglog::func(mu))) 173 | ); 174 | let lin_test_vals = array![-10., -2., -0.1, 0.0, 0.1, 1., 2.]; 175 | assert_abs_diff_eq!( 176 | lin_test_vals, 177 | lin_test_vals.mapv(|lin| Cloglog::func(Cloglog::func_inv(lin))), 178 | epsilon = 1e-3 * f32::EPSILON as f64 179 | ); 180 | } 181 | } 182 | -------------------------------------------------------------------------------- /src/glm.rs: -------------------------------------------------------------------------------- 1 | //! Trait defining a generalized linear model for common functionality. 2 | //! Models are fit such that = g^-1(X*B) where g is the link function. 3 | 4 | use crate::link::{Link, Transform}; 5 | use crate::{ 6 | error::RegressionResult, 7 | fit::{options::FitOptions, Fit}, 8 | irls::Irls, 9 | model::{Dataset, Model}, 10 | num::Float, 11 | }; 12 | use ndarray::{Array1, Array2}; 13 | use ndarray_linalg::SolveH; 14 | 15 | /// Whether the model's response has a free dispersion parameter (e.g. linear) or if it is fixed to 16 | /// one (e.g. logistic) 17 | pub enum DispersionType { 18 | FreeDispersion, 19 | NoDispersion, 20 | } 21 | 22 | /// Trait describing generalized linear model that enables the IRLS algorithm 23 | /// for fitting. 24 | pub trait Glm: Sized { 25 | /// The link function type of the GLM instantiation. Implementations specify 26 | /// this manually so that the provided methods can be called in this trait 27 | /// without necessitating a trait parameter. 28 | type Link: Link; 29 | 30 | /// Registers whether the dispersion is fixed at one (e.g. logistic) or free (e.g. linear) 31 | const DISPERSED: DispersionType; 32 | 33 | /// The link function which maps the expected value of the response variable 34 | /// to the linear predictor. 35 | fn link(y: Array1) -> Array1 { 36 | y.mapv(Self::Link::func) 37 | } 38 | 39 | /// The inverse of the link function which maps the linear predictors to the 40 | /// expected value of the prediction. 41 | fn mean(lin_pred: &Array1) -> Array1 { 42 | lin_pred.mapv(Self::Link::func_inv) 43 | } 44 | 45 | /// The logarithm of the partition function in terms of the natural parameter. 46 | /// This can be used to calculate the normalized likelihood. 47 | fn log_partition(nat_par: F) -> F; 48 | 49 | /// The variance as a function of the mean. This should be related to the 50 | /// Laplacian of the log-partition function, or in other words, the 51 | /// derivative of the inverse link function mu = g^{-1}(eta). This is unique 52 | /// to each response function, but should not depend on the link function. 
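    /// For logistic regression, for example, this is Var(mu) = mu * (1 - mu), the
    /// second derivative of its log-partition function log(1 + exp(eta)) evaluated at
    /// the canonical parameter eta = logit(mu).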
53 | fn variance(mean: F) -> F; 54 | 55 | /// Returns the likelihood function summed over all observations. 56 | fn log_like(data: &Dataset, regressors: &Array1) -> F 57 | where 58 | F: Float, 59 | { 60 | // the total likelihood prior to regularization 61 | Self::log_like_terms(data, regressors).sum() 62 | } 63 | 64 | /// Returns the likelihood function of the response distribution as a 65 | /// function of the response variable y and the natural parameters of each 66 | /// observation. Terms that depend only on the response variable `y` are 67 | /// dropped. This dispersion parameter is taken to be 1, as it does not 68 | /// affect the IRLS steps. 69 | /// The default implementation can be overwritten for performance or numerical 70 | /// accuracy, but should be mathematically equivalent to the default implementation. 71 | fn log_like_natural(y: F, nat: F) -> F 72 | where 73 | F: Float, 74 | { 75 | // subtracting the saturated likelihood to keep the likelihood closer to 76 | // zero, but this can complicate some fit statistics. In addition to 77 | // causing some null likelihood tests to fail as written, it would make 78 | // the current deviance calculation incorrect. 79 | y * nat - Self::log_partition(nat) 80 | } 81 | 82 | /// Returns the likelihood of a saturated model where every observation can 83 | /// be fit exactly. 84 | fn log_like_sat(y: F) -> F 85 | where 86 | F: Float; 87 | 88 | /// Returns the log-likelihood contributions for each observable given the regressor values. 89 | fn log_like_terms(data: &Dataset, regressors: &Array1) -> Array1 90 | where 91 | F: Float, 92 | { 93 | let lin_pred = data.linear_predictor(regressors); 94 | let nat_par = Self::Link::nat_param(lin_pred); 95 | // the likelihood prior to regularization 96 | ndarray::Zip::from(&data.y) 97 | .and(&nat_par) 98 | .map_collect(|&y, &eta| Self::log_like_natural(y, eta)) 99 | } 100 | 101 | /// Provide an initial guess for the parameters. This can be overridden 102 | /// but this should provide a decent general starting point. The y data is 103 | /// averaged with its mean to prevent infinities resulting from application 104 | /// of the link function: 105 | /// X * beta_0 ~ g(0.5*(y + y_avg)) 106 | /// This is equivalent to minimizing half the sum of squared differences 107 | /// between X*beta and g(0.5*(y + y_avg)). 108 | // TODO: consider incorporating weights and/or correlations. 109 | fn init_guess(data: &Dataset) -> Array1 110 | where 111 | F: Float, 112 | Array2: SolveH, 113 | { 114 | let y_bar: F = data.y.mean().unwrap_or_else(F::zero); 115 | let mu_y: Array1 = data.y.mapv(|y| F::from(0.5).unwrap() * (y + y_bar)); 116 | let link_y = mu_y.mapv(Self::Link::func); 117 | // Compensate for linear offsets if they are present 118 | let link_y: Array1 = if let Some(off) = &data.linear_offset { 119 | &link_y - off 120 | } else { 121 | link_y 122 | }; 123 | let x_mat: Array2 = data.x.t().dot(&data.x); 124 | let init_guess: Array1 = 125 | x_mat 126 | .solveh_into(data.x.t().dot(&link_y)) 127 | .unwrap_or_else(|err| { 128 | eprintln!("WARNING: failed to get initial guess for IRLS. Will begin at zero."); 129 | eprintln!("{err}"); 130 | Array1::::zeros(data.x.ncols()) 131 | }); 132 | init_guess 133 | } 134 | 135 | /// Do the regression and return a result. Returns object holding fit result. 
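    /// Normally this is not called directly; a rough sketch of the usual path
    /// (assuming a logistic model, mirroring the crate's own tests):
    ///
    /// ```ignore
    /// let model = ModelBuilder::<Logistic>::data(&data_y, &data_x).build()?;
    /// let fit = model.fit()?; // builds default fit options and ends up in this method
    /// println!("{}", fit.result);
    /// ```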
136 | fn regression( 137 | model: &Model, 138 | options: FitOptions, 139 | ) -> RegressionResult> 140 | where 141 | F: Float, 142 | Self: Sized, 143 | { 144 | let initial: Array1 = options 145 | .init_guess 146 | .clone() 147 | .unwrap_or_else(|| Self::init_guess(&model.data)); 148 | 149 | let mut irls: Irls = Irls::new(model, initial, options); 150 | 151 | for iteration in irls.by_ref() { 152 | let _it_result = iteration?; 153 | // TODO: Optionally track history 154 | } 155 | 156 | Ok(Fit::new( 157 | &model.data, 158 | model.use_intercept, 159 | irls, 160 | )) 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /src/model.rs: -------------------------------------------------------------------------------- 1 | //! Collect data for and configure a model 2 | 3 | use crate::{ 4 | error::{RegressionError, RegressionResult}, 5 | fit::{self, Fit}, 6 | glm::Glm, 7 | math::is_rank_deficient, 8 | num::Float, 9 | response::Response, 10 | utility::one_pad, 11 | }; 12 | use fit::options::{FitConfig, FitOptions}; 13 | use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix1, Ix2}; 14 | use ndarray_linalg::InverseInto; 15 | use std::{ 16 | cell::{Ref, RefCell}, 17 | marker::PhantomData, 18 | }; 19 | 20 | pub struct Dataset 21 | where 22 | F: Float, 23 | { 24 | /// the observation of response data by event 25 | pub y: Array1, 26 | /// the design matrix with events in rows and instances in columns 27 | pub x: Array2, 28 | /// The offset in the linear predictor for each data point. This can be used 29 | /// to fix the effect of control variables. 30 | // TODO: Consider making this an option of a reference. 31 | pub linear_offset: Option>, 32 | /// The weight of each observation 33 | pub weights: Option>, 34 | /// The cached projection matrix 35 | // crate-public only so that a null dataset can be created. 36 | pub(crate) hat: RefCell>>, 37 | } 38 | 39 | impl Dataset 40 | where 41 | F: Float, 42 | { 43 | /// Returns the linear predictors, i.e. the design matrix multiplied by the 44 | /// regression parameters. Each entry in the resulting array is the linear 45 | /// predictor for a given observation. If linear offsets for each 46 | /// observation are provided, these are added to the linear predictors 47 | pub fn linear_predictor(&self, regressors: &Array1) -> Array1 { 48 | let linear_predictor: Array1 = self.x.dot(regressors); 49 | // Add linear offsets to the predictors if they are set 50 | if let Some(lin_offset) = &self.linear_offset { 51 | linear_predictor + lin_offset 52 | } else { 53 | linear_predictor 54 | } 55 | } 56 | 57 | /// Returns the hat matrix of the dataset of covariate data, also known as the "projection" or 58 | /// "influence" matrix. 59 | pub fn hat(&self) -> RegressionResult>> { 60 | if self.hat.borrow().is_none() { 61 | if self.weights.is_some() { 62 | unimplemented!("Weights must be accounted for in the hat matrix") 63 | } 64 | let xt = self.x.t(); 65 | let xtx: Array2 = xt.dot(&self.x); 66 | // NOTE: invh/invh_into() are bugged and incorrect! 67 | let xtx_inv = xtx.inv_into().map_err(|_| RegressionError::ColinearData)?; 68 | *self.hat.borrow_mut() = Some(self.x.dot(&xtx_inv).dot(&xt)); 69 | } 70 | let borrowed: Ref>> = self.hat.borrow(); 71 | Ok(Ref::map(borrowed, |x| x.as_ref().unwrap())) 72 | } 73 | 74 | /// Returns the leverage for each observation. This is given by the diagonal of the projection 75 | /// matrix and indicates the sensitivity of each prediction to its corresponding observation. 
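    /// For unweighted data this is h_ii = x_i^T (X^T X)^{-1} x_i, the i-th diagonal
    /// element of the hat matrix above; values close to 1 flag observations that
    /// largely determine their own fitted value.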
76 | pub fn leverage(&self) -> RegressionResult> { 77 | let hat = self.hat()?; 78 | Ok(hat.diag().to_owned()) 79 | } 80 | } 81 | 82 | /// Holds the data and configuration settings for a regression. 83 | pub struct Model 84 | where 85 | M: Glm, 86 | F: Float, 87 | { 88 | pub(crate) model: PhantomData, 89 | /// The dataset 90 | pub data: Dataset, 91 | /// Whether the intercept term is used (commonly true) 92 | pub use_intercept: bool, 93 | } 94 | 95 | impl Model 96 | where 97 | M: Glm, 98 | F: Float, 99 | { 100 | /// Perform the regression and return a fit object holding the results. 101 | pub fn fit(&self) -> RegressionResult> { 102 | self.fit_options().fit() 103 | } 104 | 105 | /// Fit options builder interface 106 | pub fn fit_options(&self) -> FitConfig { 107 | FitConfig { 108 | model: self, 109 | options: FitOptions::default(), 110 | } 111 | } 112 | 113 | /// An experimental interface that would allow fit options to be set externally. 114 | pub fn with_options(&self, options: FitOptions) -> FitConfig { 115 | FitConfig { 116 | model: self, 117 | options, 118 | } 119 | } 120 | } 121 | 122 | /// Provides an interface to create the full model option struct with convenient 123 | /// type inference. 124 | pub struct ModelBuilder { 125 | _model: PhantomData, 126 | } 127 | 128 | impl ModelBuilder { 129 | /// Borrow the Y and X data where each row in the arrays is a new 130 | /// observation, and create the full model builder with the data to allow 131 | /// for adjusting additional options. 132 | pub fn data<'a, Y, F, YD, XD>( 133 | data_y: &'a ArrayBase, 134 | data_x: &'a ArrayBase, 135 | ) -> ModelBuilderData<'a, M, Y, F> 136 | where 137 | Y: Response, 138 | F: Float, 139 | YD: Data, 140 | XD: Data, 141 | { 142 | ModelBuilderData { 143 | model: PhantomData, 144 | data_y: data_y.view(), 145 | data_x: data_x.view(), 146 | linear_offset: None, 147 | weights: None, 148 | use_intercept_term: true, 149 | colin_tol: F::epsilon(), 150 | } 151 | } 152 | } 153 | 154 | /// Holds the data and all the specifications for the model and provides 155 | /// functions to adjust the settings. 156 | pub struct ModelBuilderData<'a, M, Y, F> 157 | where 158 | M: Glm, 159 | Y: Response, 160 | F: 'static + Float, 161 | { 162 | model: PhantomData, 163 | /// Observed response variable data where each entry is a new observation. 164 | data_y: ArrayView1<'a, Y>, 165 | /// Design matrix of observed covariate data where each row is a new 166 | /// observation and each column represents a different dependent variable. 167 | data_x: ArrayView2<'a, F>, 168 | /// The offset in the linear predictor for each data point. This can be used 169 | /// to incorporate control terms. 170 | // TODO: consider making this a reference/ArrayView. Y and X are effectively 171 | // cloned so perhaps this isn't a big deal. 172 | linear_offset: Option>, 173 | /// The weights for each observation. 174 | weights: Option>, 175 | /// Whether to use an intercept term. Defaults to `true`. 176 | use_intercept_term: bool, 177 | /// tolerance for determinant check on rank of data matrix X. 178 | colin_tol: F, 179 | } 180 | 181 | /// A builder to generate a Model object 182 | impl<'a, M, Y, F> ModelBuilderData<'a, M, Y, F> 183 | where 184 | M: Glm, 185 | Y: Response + Copy, 186 | F: Float, 187 | { 188 | /// Represents an offset added to the linear predictor for each data point. 189 | /// This can be used to control for fixed effects or in multi-level models. 
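    /// A sketch of supplying an offset (the `log_exposure` array here is hypothetical):
    ///
    /// ```ignore
    /// let model = ModelBuilder::<Poisson>::data(&counts, &covariates)
    ///     .linear_offset(log_exposure) // e.g. log(exposure) when modeling Poisson rates
    ///     .build()?;
    /// ```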
190 | pub fn linear_offset(mut self, linear_offset: Array1) -> Self { 191 | self.linear_offset = Some(linear_offset); 192 | self 193 | } 194 | 195 | /// Do not add a constant term to the design matrix 196 | pub fn no_constant(mut self) -> Self { 197 | self.use_intercept_term = false; 198 | self 199 | } 200 | 201 | /// Set the tolerance for the co-linearity check. 202 | /// The check can be effectively disabled by setting the tolerance to a negative value. 203 | pub fn colinear_tol(mut self, tol: F) -> Self { 204 | self.colin_tol = tol; 205 | self 206 | } 207 | 208 | pub fn build(self) -> RegressionResult> 209 | where 210 | M: Glm, 211 | F: Float, 212 | { 213 | let n_data = self.data_y.len(); 214 | if n_data != self.data_x.nrows() { 215 | return Err(RegressionError::BadInput( 216 | "y and x data must have same number of points".to_string(), 217 | )); 218 | } 219 | // If they are provided, check that the offsets have the correct number of entries 220 | if let Some(lin_off) = &self.linear_offset { 221 | if n_data != lin_off.len() { 222 | return Err(RegressionError::BadInput( 223 | "Offsets must have same dimension as observations".to_string(), 224 | )); 225 | } 226 | } 227 | 228 | // add constant term to X data 229 | let data_x = if self.use_intercept_term { 230 | one_pad(self.data_x) 231 | } else { 232 | self.data_x.to_owned() 233 | }; 234 | // Check if the data is under-constrained 235 | if n_data < data_x.ncols() { 236 | // The regression can find a solution if n_data == ncols, but there will be 237 | // no estimate for the uncertainty. Regularization can solve this, so keep 238 | // it to a warning. 239 | // return Err(RegressionError::Underconstrained); 240 | eprintln!("Warning: data is underconstrained"); 241 | } 242 | // Check for co-linearity up to a tolerance 243 | let xtx: Array2 = data_x.t().dot(&data_x); 244 | if is_rank_deficient(xtx, self.colin_tol)? { 245 | return Err(RegressionError::ColinearData); 246 | } 247 | 248 | // convert to floating-point 249 | let data_y: Array1 = self 250 | .data_y 251 | .iter() 252 | .map(|&y| y.into_float()) 253 | .collect::>()?; 254 | 255 | Ok(Model { 256 | model: PhantomData, 257 | data: Dataset { 258 | y: data_y, 259 | x: data_x, 260 | linear_offset: self.linear_offset, 261 | weights: self.weights, 262 | hat: RefCell::new(None), 263 | }, 264 | use_intercept: self.use_intercept_term, 265 | }) 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /src/regularization.rs: -------------------------------------------------------------------------------- 1 | //! Regularization methods and their effect on the likelihood and the matrix and 2 | //! vector components of the IRLS step. 3 | use crate::{error::RegressionResult, num::Float, Array1, Array2}; 4 | use ndarray::ArrayViewMut1; 5 | use ndarray_linalg::SolveH; 6 | 7 | /// Penalize the likelihood with a smooth function of the regression parameters. 8 | pub(crate) trait IrlsReg 9 | where 10 | F: Float, 11 | { 12 | /// Defines the impact of the regularization approach on the likelihood. It 13 | /// must be zero when the regressors are zero, otherwise some assumptions in 14 | /// the fitting statistics section may be invalidated. 15 | fn likelihood(&self, regressors: &Array1) -> F; 16 | 17 | /// Defines the regularization effect on the gradient of the likelihood with respect 18 | /// to beta. 19 | fn gradient(&self, l: Array1, regressors: &Array1) -> Array1; 20 | 21 | /// Processing to do before each step. 
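    /// The default is a no-op (used by `Null` and `Ridge`); the ADMM-based `Lasso` and
    /// `ElasticNet` regularizers override it to update the penalty parameter `rho`, the
    /// dual solution, and the cumulative residuals before each IRLS step.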
22 | fn prepare(&mut self, _guess: &Array1) {} 23 | 24 | /// For ADMM, the likelihood in the IRLS step is augmented with a rho term and does not include 25 | /// the L1 component. Without ADMM this should return the actual un-augmented likelihood. 26 | fn irls_like(&self, regressors: &Array1) -> F { 27 | self.likelihood(regressors) 28 | } 29 | 30 | /// Defines the adjustment to the vector side of the IRLS update equation. 31 | /// It is the negative gradient of the penalty plus the hessian times the 32 | /// regressors. A default implementation is provided, but this is zero for 33 | /// ridge regression so it is allowed to be overridden. 34 | fn irls_vec(&self, vec: Array1, regressors: &Array1) -> Array1; 35 | 36 | /// Defines the change to the matrix side of the IRLS update equation. It 37 | /// subtracts the Hessian of the penalty from the matrix. The difference is 38 | /// typically only on the diagonal. 39 | fn irls_mat(&self, mat: Array2, regressors: &Array1) -> Array2; 40 | 41 | /// Return the next guess under regularization given the current guess and the RHS and LHS of 42 | /// the unregularized IRLS matrix solution equation for the next guess. 43 | fn next_guess( 44 | &mut self, 45 | guess: &Array1, 46 | irls_vec: Array1, 47 | irls_mat: Array2, 48 | ) -> RegressionResult> { 49 | self.prepare(guess); 50 | // Apply the regularization effects to the Hessian (LHS) 51 | let lhs = self.irls_mat(irls_mat, guess); 52 | // Apply regularization effects to the modified Jacobian (RHS) 53 | let rhs = self.irls_vec(irls_vec, guess); 54 | let next_guess = lhs.solveh_into(rhs)?; 55 | Ok(next_guess) 56 | } 57 | 58 | fn terminate_ok(&self, _tol: F) -> bool { 59 | true 60 | } 61 | } 62 | 63 | /// Represents a lack of regularization. 64 | pub struct Null {} 65 | 66 | impl IrlsReg for Null { 67 | #[inline] 68 | fn likelihood(&self, _: &Array1) -> F { 69 | F::zero() 70 | } 71 | #[inline] 72 | fn gradient(&self, jac: Array1, _: &Array1) -> Array1 { 73 | jac 74 | } 75 | #[inline] 76 | fn irls_vec(&self, vec: Array1, _: &Array1) -> Array1 { 77 | vec 78 | } 79 | #[inline] 80 | fn irls_mat(&self, mat: Array2, _: &Array1) -> Array2 { 81 | mat 82 | } 83 | } 84 | 85 | /// Penalizes the regression by lambda/2 * |beta|^2. 86 | pub struct Ridge { 87 | l2_vec: Array1, 88 | } 89 | 90 | impl Ridge { 91 | /// Create the regularization from the diagonal. This outsources the 92 | /// question of whether to include the first term (usually the intercept) in 93 | /// the regularization. 94 | pub fn from_diag(l2: Array1) -> Self { 95 | Self { l2_vec: l2 } 96 | } 97 | } 98 | 99 | impl IrlsReg for Ridge { 100 | /// The likelihood is penalized by 0.5 * lambda_2 * |beta|^2 101 | fn likelihood(&self, beta: &Array1) -> F { 102 | -F::from(0.5).unwrap() * (&self.l2_vec * &beta.mapv(|b| b * b)).sum() 103 | } 104 | /// The gradient is penalized by lambda_2 * beta. 105 | fn gradient(&self, jac: Array1, beta: &Array1) -> Array1 { 106 | jac - (&self.l2_vec * beta) 107 | } 108 | /// Ridge regression has no effect on the vector side of the IRLS step equation, because the 109 | /// 1st and 2nd order derivative terms exactly cancel. 110 | #[inline] 111 | fn irls_vec(&self, vec: Array1, _: &Array1) -> Array1 { 112 | vec 113 | } 114 | /// Add lambda to the diagonals of the information matrix. 
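    /// In other words, each IRLS update solves (X^T W X + diag(lambda_2)) * beta_new = rhs
    /// rather than the unpenalized system; only the diagonal of the left-hand side changes.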
115 | fn irls_mat(&self, mut mat: Array2, _: &Array1) -> Array2 { 116 | let mut mat_diag: ArrayViewMut1 = mat.diag_mut(); 117 | mat_diag += &self.l2_vec; 118 | mat 119 | } 120 | } 121 | 122 | /// Penalizes the likelihood by the L1-norm of the parameters. 123 | pub struct Lasso { 124 | /// The L1 parameters for each element 125 | l1_vec: Array1, 126 | /// The dual solution 127 | dual: Array1, 128 | /// The cumulative sum of residuals for each element 129 | cum_res: Array1, 130 | /// ADMM penalty parameter 131 | rho: F, 132 | /// L2-Norm of primal residuals |r|^2 133 | r_sq: F, 134 | /// L2-Norm of dual residuals |s|^2 135 | s_sq: F, 136 | } 137 | 138 | impl Lasso { 139 | /// Create the regularization from the diagonal, outsourcing the question of whether to include 140 | /// the first term (commonly the intercept, which is left out) in the diagonal. 141 | pub fn from_diag(l1: Array1) -> Self { 142 | let n: usize = l1.len(); 143 | let gamma = Array1::zeros(n); 144 | let u = Array1::zeros(n); 145 | Self { 146 | l1_vec: l1, 147 | dual: gamma, 148 | cum_res: u, 149 | rho: F::one(), 150 | r_sq: F::infinity(), // or should it be NaN? 151 | s_sq: F::infinity(), 152 | } 153 | } 154 | 155 | fn update_rho(&mut self) { 156 | // Can these be declared const? 157 | let mu: F = F::from(8.).unwrap(); 158 | let tau: F = F::from(2.).unwrap(); 159 | if self.r_sq > mu * mu * self.s_sq { 160 | self.rho *= tau; 161 | self.cum_res /= tau; 162 | } 163 | if self.r_sq * mu * mu < self.s_sq { 164 | self.rho /= tau; 165 | self.cum_res *= tau; 166 | } 167 | } 168 | } 169 | 170 | impl IrlsReg for Lasso { 171 | fn likelihood(&self, beta: &Array1) -> F { 172 | -(&self.l1_vec * beta.mapv(num_traits::Float::abs)).sum() 173 | } 174 | 175 | // This is used in the fit's score function, for instance. Thus it includes the regularization 176 | // terms and not the augmented term. 177 | fn gradient(&self, jac: Array1, regressors: &Array1) -> Array1 { 178 | jac - &self.l1_vec * ®ressors.mapv(F::sign) 179 | } 180 | 181 | /// Update the dual solution and the cumulative residuals. 182 | fn prepare(&mut self, beta: &Array1) { 183 | // Apply adaptive penalty term updating 184 | self.update_rho(); 185 | 186 | let old_dual = self.dual.clone(); 187 | 188 | self.dual = soft_thresh(beta + &self.cum_res, &self.l1_vec / self.rho); 189 | // the primal residuals 190 | let r: Array1 = beta - &self.dual; 191 | // the dual residuals 192 | let s: Array1 = (&self.dual - old_dual) * self.rho; 193 | self.cum_res += &r; 194 | 195 | self.r_sq = r.mapv(|r| r * r).sum(); 196 | self.s_sq = s.mapv(|s| s * s).sum(); 197 | } 198 | 199 | fn irls_like(&self, regressors: &Array1) -> F { 200 | -F::from(0.5).unwrap() 201 | * self.rho 202 | * (regressors - &self.dual + &self.cum_res) 203 | .mapv(|x| x * x) 204 | .sum() 205 | } 206 | 207 | /// The beta term from the gradient is cancelled by the corresponding term from the Hessian. 208 | /// The dual and residual terms remain. 209 | fn irls_vec(&self, vec: Array1, _regressors: &Array1) -> Array1 { 210 | let d: Array1 = &self.dual - &self.cum_res; 211 | vec + d * self.rho 212 | } 213 | 214 | /// Add the constant rho to all elements of the diagonal of the Hessian. 215 | fn irls_mat(&self, mut mat: Array2, _: &Array1) -> Array2 { 216 | let mut mat_diag: ArrayViewMut1 = mat.diag_mut(); 217 | mat_diag += self.rho; 218 | mat 219 | } 220 | 221 | fn terminate_ok(&self, tol: F) -> bool { 222 | // Expressed like this, it should perhaps instead be an epsilon^2. 
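        // ADMM-style stopping check: both the squared primal residual |r|^2 and the
        // squared dual residual |s|^2 must fall below sqrt(n) * tol.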
223 | let n: usize = self.dual.len(); 224 | let n_sq = F::from((n as f64).sqrt()).unwrap(); 225 | let r_pass = self.r_sq < n_sq * tol; 226 | let s_pass = self.s_sq < n_sq * tol; 227 | r_pass && s_pass 228 | } 229 | } 230 | 231 | /// Penalizes the likelihood with both an L1-norm and L2-norm. 232 | pub struct ElasticNet { 233 | /// The L1 parameters for each element 234 | l1_vec: Array1, 235 | /// The L2 parameters for each element 236 | l2_vec: Array1, 237 | /// The dual solution 238 | dual: Array1, 239 | /// The cumulative sum of residuals for each element 240 | cum_res: Array1, 241 | /// ADMM penalty parameter 242 | rho: F, 243 | /// L2-Norm of primal residuals |r|^2 244 | r_sq: F, 245 | /// L2-Norm of dual residuals |s|^2 246 | s_sq: F, 247 | } 248 | 249 | impl ElasticNet { 250 | /// Create the regularization from the diagonal, outsourcing the question of whether to include 251 | /// the first term (commonly the intercept, which is left out) in the diagonal. 252 | pub fn from_diag(l1: Array1, l2: Array1) -> Self { 253 | let n: usize = l1.len(); 254 | let gamma = Array1::zeros(n); 255 | let u = Array1::zeros(n); 256 | Self { 257 | l1_vec: l1, 258 | l2_vec: l2, 259 | dual: gamma, 260 | cum_res: u, 261 | rho: F::one(), 262 | r_sq: F::infinity(), // or should it be NaN? 263 | s_sq: F::infinity(), 264 | } 265 | } 266 | 267 | fn update_rho(&mut self) { 268 | // Can these be declared const? 269 | let mu: F = F::from(8.).unwrap(); 270 | let tau: F = F::from(2.).unwrap(); 271 | if self.r_sq > mu * mu * self.s_sq { 272 | self.rho *= tau; 273 | self.cum_res /= tau; 274 | } 275 | if self.r_sq * mu * mu < self.s_sq { 276 | self.rho /= tau; 277 | self.cum_res *= tau; 278 | } 279 | } 280 | } 281 | 282 | impl IrlsReg for ElasticNet { 283 | fn likelihood(&self, beta: &Array1) -> F { 284 | -(&self.l1_vec * beta.mapv(num_traits::Float::abs)).sum() 285 | -F::from(0.5).unwrap() * (&self.l2_vec * &beta.mapv(|b| b * b)).sum() 286 | } 287 | 288 | // This is used in the fit's score function, for instance. Thus it includes the regularization 289 | // terms and not the augmented term. 290 | fn gradient(&self, jac: Array1, regressors: &Array1) -> Array1 { 291 | jac - &self.l1_vec * ®ressors.mapv(F::sign) - &self.l2_vec * regressors 292 | } 293 | 294 | /// Update the dual solution and the cumulative residuals. 295 | fn prepare(&mut self, beta: &Array1) { 296 | // Apply adaptive penalty term updating 297 | self.update_rho(); 298 | 299 | let old_dual = self.dual.clone(); 300 | 301 | self.dual = soft_thresh(beta + &self.cum_res, &self.l1_vec / self.rho); 302 | // the primal residuals 303 | let r: Array1 = beta - &self.dual; 304 | // the dual residuals 305 | let s: Array1 = (&self.dual - old_dual) * self.rho; 306 | self.cum_res += &r; 307 | 308 | self.r_sq = r.mapv(|r| r * r).sum(); 309 | self.s_sq = s.mapv(|s| s * s).sum(); 310 | } 311 | 312 | fn irls_like(&self, regressors: &Array1) -> F { 313 | -F::from(0.5).unwrap() 314 | * self.rho 315 | * (regressors - &self.dual + &self.cum_res) 316 | .mapv(|x| x * x) 317 | .sum() 318 | -F::from(0.5).unwrap() * (&self.l2_vec * ®ressors.mapv(|b| b * b)).sum() 319 | } 320 | 321 | /// The beta term from the gradient is cancelled by the corresponding term from the Hessian. 322 | /// The dual and residual terms remain. 323 | fn irls_vec(&self, vec: Array1, _regressors: &Array1) -> Array1 { 324 | let d: Array1 = &self.dual - &self.cum_res; 325 | vec + d * self.rho 326 | } 327 | 328 | /// Add the constant rho to all elements of the diagonal of the Hessian. 
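    /// For the elastic net the per-element L2 penalties in `l2_vec` are added to the
    /// diagonal as well, so each diagonal entry receives lambda_2 + rho.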
329 | fn irls_mat(&self, mut mat: Array2, _: &Array1) -> Array2 { 330 | let mut mat_diag: ArrayViewMut1 = mat.diag_mut(); 331 | mat_diag += &self.l2_vec; 332 | mat_diag += self.rho; 333 | mat 334 | } 335 | 336 | fn terminate_ok(&self, tol: F) -> bool { 337 | // Expressed like this, it should perhaps instead be an epsilon^2. 338 | let n: usize = self.dual.len(); 339 | let n_sq = F::from((n as f64).sqrt()).unwrap(); 340 | let r_pass = self.r_sq < n_sq * tol; 341 | let s_pass = self.s_sq < n_sq * tol; 342 | r_pass && s_pass 343 | } 344 | } 345 | 346 | /// The soft thresholding operator 347 | fn soft_thresh(x: Array1, lambda: Array1) -> Array1 { 348 | let sign_x = x.mapv(F::sign); 349 | let abs_x = x.mapv(::abs); 350 | let red_x = abs_x - lambda; 351 | let clipped = red_x.mapv(|x| if x < F::zero() { F::zero() } else { x }); 352 | sign_x * clipped 353 | } 354 | 355 | #[cfg(test)] 356 | mod tests { 357 | use super::*; 358 | use approx::assert_abs_diff_eq; 359 | use ndarray::array; 360 | 361 | #[test] 362 | fn ridge_matrix() { 363 | let l = 1e-4; 364 | let ridge = Ridge::from_diag(array![0., l]); 365 | let mat = array![[0.5, 0.1], [0.1, 0.2]]; 366 | let mut target_mat = mat.clone(); 367 | target_mat[[1, 1]] += l; 368 | let dummy_beta = array![0., 0.]; 369 | assert_eq!(ridge.irls_mat(mat, &dummy_beta), target_mat); 370 | } 371 | 372 | #[test] 373 | fn soft_thresh_correct() { 374 | let x = array![0.25, -0.1, -0.4, 0.3, 0.5, -0.5]; 375 | let lambda = array![-0., 0.0, 0.1, 0.1, 1.0, 1.0]; 376 | let target = array![0.25, -0.1, -0.3, 0.2, 0., 0.]; 377 | let output = soft_thresh(x, lambda); 378 | assert_abs_diff_eq!(target, output); 379 | } 380 | } 381 | -------------------------------------------------------------------------------- /src/irls.rs: -------------------------------------------------------------------------------- 1 | //! Iteratively re-weighed least squares algorithm 2 | use crate::glm::Glm; 3 | use crate::link::Transform; 4 | use crate::model::{Dataset, Model}; 5 | use crate::regularization::{ElasticNet, Lasso, Null, Ridge}; 6 | use crate::{ 7 | error::{RegressionError, RegressionResult}, 8 | fit::options::FitOptions, 9 | num::Float, 10 | regularization::IrlsReg, 11 | }; 12 | use ndarray::{Array1, Array2}; 13 | use ndarray_linalg::SolveH; 14 | use std::marker::PhantomData; 15 | 16 | /// Iterate over updates via iteratively re-weighted least-squares until 17 | /// reaching a specified tolerance. 18 | pub(crate) struct Irls<'a, M, F> 19 | where 20 | M: Glm, 21 | F: Float, 22 | Array2: SolveH, 23 | { 24 | model: PhantomData, 25 | data: &'a Dataset, 26 | /// The current parameter guess. 27 | pub(crate) guess: Array1, 28 | /// The options for the fit 29 | pub(crate) options: FitOptions, 30 | /// The regularizer object, which may be stateful 31 | pub(crate) reg: Box>, 32 | /// The number of iterations taken so far 33 | pub n_iter: usize, 34 | /// The data likelihood for the previous iteration, unregularized and unaugmented. 35 | /// This is cached separately from the guess because it demands expensive matrix 36 | /// multiplications. The augmented and/or regularized terms are relatively cheap, so they 37 | /// aren't stored. 38 | pub(crate) last_like_data: F, 39 | /// Sometimes the next guess is better than the previous but within 40 | /// tolerance, so we want to return the current guess but exit immediately 41 | /// in the next iteration. 
42 | done: bool, 43 | } 44 | 45 | impl<'a, M, F> Irls<'a, M, F> 46 | where 47 | M: Glm, 48 | F: Float, 49 | Array2: SolveH, 50 | { 51 | pub fn new(model: &'a Model, initial: Array1, options: FitOptions) -> Self { 52 | let data = &model.data; 53 | let reg = get_reg(&options, data.x.ncols(), model.use_intercept); 54 | let initial_like_data: F = M::log_like(data, &initial); 55 | Self { 56 | model: PhantomData, 57 | data, 58 | guess: initial, 59 | options, 60 | reg, 61 | n_iter: 0, 62 | last_like_data: initial_like_data, 63 | done: false, 64 | } 65 | } 66 | 67 | /// A helper function to step to a new guess, while incrementing the number 68 | /// of iterations and checking that it is not over the maximum. 69 | fn step_with(&mut self, next_guess: Array1, next_like_data: F) -> ::Item { 70 | self.guess.assign(&next_guess); 71 | self.last_like_data = next_like_data; 72 | let model_like = next_like_data + self.reg.likelihood(&next_guess); 73 | self.n_iter += 1; 74 | if self.n_iter > self.options.max_iter { 75 | // NOTE: This could also return the best guess so far. Including the data in the error 76 | // type would necessitate either a conversion to f32 or a parameterization. 77 | return Err(RegressionError::MaxIter(self.options.max_iter)); 78 | } 79 | Ok(IrlsStep { 80 | guess: next_guess, 81 | like: model_like, 82 | }) 83 | } 84 | 85 | /// Returns the (LHS, RHS) of the IRLS update matrix equation. This is a bit 86 | /// faster than computing the Fisher matrix and the Jacobian separately. 87 | /// The returned matrix and vector are not regularized. 88 | // TODO: re-factor to have the distributions compute the fisher information, 89 | // as that is useful in the score test as well. 90 | fn irls_mat_vec(&self) -> (Array2, Array1) { 91 | // The linear predictor without control terms 92 | let linear_predictor_no_control: Array1 = self.data.x.dot(&self.guess); 93 | // the linear predictor given the model, including offsets if present 94 | let linear_predictor = match &self.data.linear_offset { 95 | Some(off) => &linear_predictor_no_control + off, 96 | None => linear_predictor_no_control.clone(), 97 | }; 98 | // The data.linear_predictor() function is not used above because we will use 99 | // both versions, with and without the linear offset, and we don't want 100 | // to repeat the matrix multiplication. 101 | 102 | // The prediction of y given the current model. 103 | // This does cause an unnecessary clone with an identity link, but we 104 | // need the linear predictor around for the future. 105 | let predictor: Array1 = M::mean(&linear_predictor); 106 | 107 | // The variances predicted by the model. This should have weights with 108 | // it and must be non-zero. 109 | // This could become a full covariance with weights. 110 | // TODO: allow the variance conditioning to be a configurable parameter. 111 | let var_diag: Array1 = predictor.mapv(M::variance); 112 | 113 | // The errors represent the difference between observed and predicted. 114 | let errors = &self.data.y - &predictor; 115 | 116 | // Adjust the errors and variance using the appropriate derivatives of 117 | // the link function. 118 | let (errors, var_diag) = 119 | M::Link::adjust_errors_variance(errors, var_diag, &linear_predictor); 120 | // Try adjusting only the variance as if the derivative will cancel. 121 | // This might not be quite right due to the matrix multiplications. 122 | // let var_diag = M::Link::d_nat_param(&linear_predictor) * var_diag; 123 | 124 | // condition after the adjustment in case the derivatives are zero. 
Or 125 | // should the Hessian itself be conditioned? 126 | let var_diag: Array1 = var_diag.mapv_into(|v| v + F::epsilon()); 127 | 128 | // X weighted by the model variance for each observation 129 | // This is really the negative Hessian of the likelihood. 130 | // When adding correlations between observations this statement will 131 | // need to be modified. 132 | let neg_hessian: Array2 = (&self.data.x.t() * &var_diag).dot(&self.data.x); 133 | 134 | // This isn't quite the jacobian because the H*beta_old term is subtracted out. 135 | let rhs: Array1 = { 136 | // NOTE: This w*X should not include the linear offset, because it 137 | // comes from the Hessian times the last guess. 138 | let target: Array1 = (var_diag * linear_predictor_no_control) + errors; 139 | let target: Array1 = self.data.x.t().dot(&target); 140 | target 141 | }; 142 | (neg_hessian, rhs) 143 | } 144 | } 145 | 146 | /// Represents a step in the IRLS. Holds the current guess, likelihood, and the 147 | /// number of steps taken this iteration. 148 | pub struct IrlsStep { 149 | /// The current parameter guess. 150 | pub guess: Array1, 151 | /// The regularized log-likelihood of the current guess. 152 | pub like: F, 153 | // TODO: Consider tracking data likelihood, regularized likelihood, and augmented likelihood 154 | // separately. 155 | } 156 | 157 | impl<'a, M, F> Iterator for Irls<'a, M, F> 158 | where 159 | M: Glm, 160 | F: Float, 161 | Array2: SolveH, 162 | { 163 | type Item = RegressionResult>; 164 | 165 | /// Acquire the next IRLS step based on the previous one. 166 | fn next(&mut self) -> Option { 167 | // if the last step was an improvement but within tolerance, this step 168 | // has been flagged to terminate early. 169 | if self.done { 170 | return None; 171 | } 172 | 173 | let (irls_mat, irls_vec) = self.irls_mat_vec(); 174 | let next_guess: Array1 = match self.reg.next_guess(&self.guess, irls_vec, irls_mat) { 175 | Ok(solution) => solution, 176 | Err(err) => return Some(Err(err)), 177 | }; 178 | 179 | // This is the raw, unregularized and unaugmented 180 | let next_like_data = M::log_like(self.data, &next_guess); 181 | 182 | // The augmented likelihood to maximize may not be the same as the regularized model 183 | // likelihood. 184 | // NOTE: This must be computed after self.reg.next_guess() is called, because that step can 185 | // change the penalty parameter in ADMM. last_like_obj does not represent the previous 186 | // objective; it represents the current version of the objective function using the 187 | // previous guess. These may be different because the augmentation parameter and dual 188 | // variables for the regularization can change. 189 | let last_like_obj = self.last_like_data + self.reg.irls_like(&self.guess); 190 | let next_like_obj = next_like_data + self.reg.irls_like(&next_guess); 191 | 192 | // NOTE: might be optimizable by not checking the likelihood until step 193 | // = next_guess - &self.guess stops decreasing. There could be edge 194 | // cases that lead to poor convergence. 195 | // Ideally we could only check the step difference but that might not be 196 | // as stable. Some parameters might be at different scales. 197 | 198 | // If this guess is a strict improvement, return it immediately. 199 | if next_like_obj > last_like_obj { 200 | return Some(self.step_with(next_guess, next_like_data)); 201 | } 202 | 203 | // Indicates if the likelihood change is small, within tolerance, even if it is not 204 | // positive. 
205 | let small_delta_like = small_delta(next_like_obj, last_like_obj, self.options.tol); 206 | 207 | // If the parameters have changed significantly but the likelihood hasn't improved, 208 | // step halving needs to be engaged. The parameter delta should probably ideally be 209 | // tested using the spread of the covariate data, but in principle the data can be 210 | // standardized so this will just compare to the raw tolerance. 211 | let small_delta_guess = small_delta_vec(&next_guess, &self.guess, self.options.tol); 212 | 213 | // Terminate if the difference is close to zero and the parameters haven't changed 214 | // significantly. 215 | if small_delta_like && small_delta_guess { 216 | // If this guess is an improvement then go ahead and return it, but 217 | // quit early on the next iteration. The equivalence with zero is 218 | // necessary in order to return a value when the iteration starts at 219 | // the best guess. This comparison includes zero so that the 220 | // iteration terminates if the likelihood hasn't changed at all. 221 | if next_like_obj >= last_like_obj { 222 | // assert_eq!(next_like_obj, last_like_obj); // this should still hold 223 | self.done = true; 224 | return Some(self.step_with(next_guess, next_like_data)); 225 | } 226 | return None; 227 | } 228 | 229 | // Don't go through step halving if the regularization isn't convergent 230 | if !self.reg.terminate_ok(self.options.tol) { 231 | return Some(self.step_with(next_guess, next_like_data)); 232 | } 233 | 234 | // apply step halving if the new likelihood is the same or worse as the previous guess. 235 | // NOTE: It's difficult to engage the step halving because it's rarely necessary, so this 236 | // part of the algorithm is undertested. It may be more common using L1 regularization. 237 | let f_step = |x: F| { 238 | let b = &next_guess * x + &self.guess * (F::one() - x); 239 | // Using the real likelihood in the step finding avoids potential issues with the 240 | // augmentation. They should be close to equivalent at this point because the 241 | // regularization has reported that the internals have converged. 242 | M::log_like(self.data, &b) + self.reg.likelihood(&b) 243 | }; 244 | let beta_tol_factor = num_traits::Float::sqrt(self.guess.mapv(|b| F::one() + b * b).sum()); 245 | let step_mult: F = step_scale(&f_step, beta_tol_factor * self.options.tol); 246 | if step_mult.is_zero() { 247 | // can't find a better minimum if the step multiplier returns zero 248 | return None; 249 | } 250 | // If step_mult == 1, that means the guess is a good one according to the un-augmented 251 | // regularized likelihood, so go ahead and use it. 
252 | 253 | // If the step multiplier is not zero, it found a better guess 254 | let next_guess = &next_guess * step_mult + &self.guess * (F::one() - step_mult); 255 | let next_like_data = M::log_like(self.data, &next_guess); 256 | let next_like = M::log_like(self.data, &next_guess) + self.reg.likelihood(&next_guess); 257 | let last_like = self.last_like_data + self.reg.likelihood(&self.guess); 258 | if next_like < last_like { 259 | return None; 260 | } 261 | 262 | let small_delta_like = small_delta(next_like, last_like, self.options.tol); 263 | let small_delta_guess = small_delta_vec(&next_guess, &self.guess, self.options.tol); 264 | if small_delta_like && small_delta_guess { 265 | self.done = true; 266 | } 267 | 268 | Some(self.step_with(next_guess, next_like_data)) 269 | } 270 | } 271 | 272 | fn small_delta(new: F, old: F, tol: F) -> bool 273 | where 274 | F: Float, 275 | { 276 | let rel = (new - old) / (F::epsilon() + num_traits::Float::abs(new)); 277 | num_traits::Float::abs(rel) <= tol 278 | } 279 | 280 | fn small_delta_vec(new: &Array1, old: &Array1, tol: F) -> bool 281 | where 282 | F: Float, 283 | { 284 | // this method interpolates between relative and absolute differences 285 | let delta = new - old; 286 | let n = F::from(delta.len()).unwrap(); 287 | 288 | let new2: F = new.mapv(|d| d * d).sum(); 289 | let delta2: F = delta.mapv(|d| d * d).sum(); 290 | 291 | // use sum of absolute values to indicate magnitude of beta 292 | // sum of squares might be better 293 | delta2 <= (n + new2) * tol * tol 294 | } 295 | 296 | /// Zero the first element of the array `l` if `use_intercept == true` 297 | fn zero_first_maybe(mut l: Array1, use_intercept: bool) -> Array1 298 | where 299 | F: Float, 300 | { 301 | // if an intercept term is included it should not be subject to 302 | // regularization. 303 | if use_intercept { 304 | l[0] = F::zero(); 305 | } 306 | l 307 | } 308 | 309 | /// Generate a regularizer from the set of options 310 | fn get_reg( 311 | options: &FitOptions, 312 | n: usize, 313 | use_intercept: bool, 314 | ) -> Box> { 315 | if options.l1 < F::zero() || options.l2 < F::zero() { 316 | eprintln!("WARNING: regularization parameters should not be negative."); 317 | } 318 | let use_l1 = options.l1 > F::zero(); 319 | let use_l2 = options.l2 > F::zero(); 320 | 321 | if use_l1 && use_l2 { 322 | let l1_diag: Array1 = Array1::::from_elem(n, options.l1); 323 | let l1_diag: Array1 = zero_first_maybe(l1_diag, use_intercept); 324 | let l2_diag: Array1 = Array1::::from_elem(n, options.l2); 325 | let l2_diag: Array1 = zero_first_maybe(l2_diag, use_intercept); 326 | Box::new(ElasticNet::from_diag(l1_diag, l2_diag)) 327 | } else if use_l2 { 328 | let l2_diag: Array1 = Array1::::from_elem(n, options.l2); 329 | let l2_diag: Array1 = zero_first_maybe(l2_diag, use_intercept); 330 | Box::new(Ridge::from_diag(l2_diag)) 331 | } else if use_l1 { 332 | let l1_diag: Array1 = Array1::::from_elem(n, options.l1); 333 | let l1_diag: Array1 = zero_first_maybe(l1_diag, use_intercept); 334 | Box::new(Lasso::from_diag(l1_diag)) 335 | } else { 336 | Box::new(Null {}) 337 | } 338 | } 339 | 340 | /// Find a better step scale to optimize and objective function. 341 | /// Looks for a new solution better than x = 1 looking first at 0 < x < 1 and returning any value 342 | /// found to be a strict improvement. 343 | /// If none are found, it will check a single negative step. 
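/// Starting from x = 1, the scale is shrunk geometrically by the inverse golden ratio
/// (about 0.618) until either f(x) exceeds f(0) or x falls below the tolerance; failing
/// that, a single negative step of -0.618 is tried.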
344 | fn step_scale(f: &dyn Fn(F) -> F, tol: F) -> F { 345 | let tol = num_traits::Float::abs(tol); 346 | // TODO: Add list of values to explicitly try (for instance with zeroed parameters) 347 | 348 | let zero: F = F::zero(); 349 | let one: F = F::one(); 350 | // `scale = 0.5` should also work, but using the golden ratio is prettier. 351 | let scale = F::from(0.618033988749894).unwrap(); 352 | let mut x: F = one; 353 | let f0: F = f(zero); 354 | 355 | while x > tol { 356 | let fx = f(x); 357 | if fx > f0 { 358 | return x; 359 | } 360 | x *= scale; 361 | } 362 | 363 | // If f(1) > f(0), then an improvement has already been found. However, if the optimization is 364 | // languishing, it could be useful to try x > 1. It's pretty rare to get to this state, 365 | // however. 366 | 367 | // If we're here a strict improvement hasn't been found, but it's possible that the likelihoods 368 | // are equal. 369 | // check a single step in the negative direction, in case this is an improvement. 370 | if f(-scale) > f0 { 371 | return -scale; 372 | } 373 | 374 | x 375 | } 376 | -------------------------------------------------------------------------------- /src/fit.rs: -------------------------------------------------------------------------------- 1 | //! Stores the fit results of the IRLS regression and provides functions that 2 | //! depend on the MLE estimate. These include statistical tests for goodness-of-fit. 3 | 4 | pub mod options; 5 | use crate::{ 6 | error::RegressionResult, 7 | glm::{DispersionType, Glm}, 8 | irls::Irls, 9 | link::{Link, Transform}, 10 | model::{Dataset, Model}, 11 | num::Float, 12 | regularization::IrlsReg, 13 | Linear, 14 | }; 15 | use ndarray::{array, Array1, Array2, ArrayBase, ArrayView1, Axis, Data, Ix2}; 16 | use ndarray_linalg::InverseInto; 17 | use options::FitOptions; 18 | use std::{ 19 | cell::{Ref, RefCell}, 20 | marker::PhantomData, 21 | }; 22 | 23 | /// the result of a successful GLM fit 24 | pub struct Fit<'a, M, F> 25 | where 26 | M: Glm, 27 | F: Float, 28 | { 29 | model: PhantomData, 30 | /// The data and model specification used in the fit. 31 | data: &'a Dataset, 32 | /// Whether the intercept covariate is used 33 | use_intercept: bool, 34 | /// The parameter values that maximize the likelihood as given by the IRLS regression. 35 | pub result: Array1, 36 | /// The options used for this fit. 37 | pub options: FitOptions, 38 | /// The value of the likelihood function for the fit result. 39 | pub model_like: F, 40 | /// The regularizer of the fit 41 | reg: Box>, 42 | /// The number of overall iterations taken in the IRLS. 43 | pub n_iter: usize, 44 | /// The number of data points 45 | n_data: usize, 46 | /// The number of parameters 47 | n_par: usize, 48 | /// The estimated covariance matrix of the parameters. Since the calculation 49 | /// requires a matrix inversion, it is computed only when needed and the 50 | /// value is cached. Access through the `covariance()` function. 51 | cov: RefCell>>, 52 | /// The likelihood and parameters for the null model. 53 | null_model: RefCell)>>, 54 | } 55 | 56 | impl<'a, M, F> Fit<'a, M, F> 57 | where 58 | M: Glm, 59 | F: 'static + Float, 60 | { 61 | /// Returns the Akaike information criterion for the model fit. 62 | // TODO: Should an effective number of parameters that takes regularization 63 | // into acount be considered? 64 | pub fn aic(&self) -> F { 65 | F::from(2 * self.n_par).unwrap() - F::from(2.).unwrap() * self.model_like 66 | } 67 | 68 | /// Returns the Bayesian information criterion for the model fit. 
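    /// As computed here, BIC = ln(n) * k - 2 * ln(L), with n the number of observations,
    /// k the number of parameters, and ln(L) the (regularized) model log-likelihood.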
69 | // TODO: Also consider the effect of regularization on this statistic. 70 | // TODO: Wikipedia suggests that the variance should included in the number 71 | // of parameters for multiple linear regression. Should an additional 72 | // parameter be included for the dispersion parameter? This question does 73 | // not affect the difference between two models fit with the methodology in 74 | // this package. 75 | pub fn bic(&self) -> F { 76 | let logn = num_traits::Float::ln(F::from(self.data.y.len()).unwrap()); 77 | logn * F::from(self.n_par).unwrap() - F::from(2.).unwrap() * self.model_like 78 | } 79 | 80 | /// The covariance matrix estimated by the Fisher information and the dispersion parameter (for 81 | /// families with a free scale). The matrix is cached to avoid repeating the potentially 82 | /// expensive matrix inversion. 83 | pub fn covariance(&self) -> RegressionResult>> { 84 | if self.cov.borrow().is_none() { 85 | if self.data.weights.is_some() { 86 | // NOTE: Perhaps it is just the fisher matrix that must be updated. 87 | unimplemented!( 88 | "The covariance calculation must take into account weights/correlations." 89 | ); 90 | } 91 | let fisher_reg = self.fisher(&self.result); 92 | // The covariance must be multiplied by the dispersion parameter. 93 | // For logistic/poisson regression, this is identically 1. 94 | // For linear/gamma regression it is estimated from the data. 95 | let phi: F = self.dispersion(); 96 | // NOTE: invh/invh_into() are bugged and incorrect! 97 | let unscaled_cov: Array2 = fisher_reg.inv_into()?; 98 | let cov = unscaled_cov * phi; 99 | *self.cov.borrow_mut() = Some(cov); 100 | } 101 | Ok(Ref::map(self.cov.borrow(), |x| x.as_ref().unwrap())) 102 | } 103 | 104 | /// Returns the deviance of the fit: twice the difference between the 105 | /// saturated likelihood and the model likelihood. Asymptotically this fits 106 | /// a chi-squared distribution with `self.ndf()` degrees of freedom. 107 | /// Note that the regularized likelihood is used here. 108 | // TODO: This is likely sensitive to regularization because the saturated 109 | // model is not regularized but the model likelihood is. Perhaps this can be 110 | // accounted for with an effective number of degrees of freedom. 111 | pub fn deviance(&self) -> F { 112 | // Note that this must change if the GLM likelihood subtracts the 113 | // saturated one already. 114 | F::from(2.).unwrap() * (self.data.y.mapv(M::log_like_sat).sum() - self.model_like) 115 | } 116 | 117 | /// The dispersion parameter(typically denoted `phi`) which relates the variance of the `y` 118 | /// values with the variance of the response distribution: `Var[y] = phi * Var[mu]`. 119 | /// Identically one for logistic, binomial, and Poisson regression. 120 | /// For others (linear, gamma) the dispersion parameter is estimated from the data. 121 | /// This is equal to the total deviance divided by the degrees of freedom. For OLS linear 122 | /// regression this is equal to the sum of `(y_i - mu_i)^2 / (n-p)`, an estimate of `sigma^2`; 123 | /// with no covariates it is equal to the sample variance. 124 | pub fn dispersion(&self) -> F { 125 | use DispersionType::*; 126 | match M::DISPERSED { 127 | FreeDispersion => { 128 | let ndf: F = F::from(self.ndf()).unwrap(); 129 | let dev = self.deviance(); 130 | dev / ndf 131 | } 132 | NoDispersion => F::one(), 133 | } 134 | } 135 | 136 | /// Returns the errors in the response variables for the data passed as an 137 | /// argument given the current model fit. 
    fn errors(&self, data: &Dataset<F>) -> Array1<F> {
        &data.y - &self.predict(&data.x, data.linear_offset.as_ref())
    }

    #[deprecated(since = "0.0.10", note = "use predict() instead")]
    pub fn expectation<S>(
        &self,
        data_x: &ArrayBase<S, Ix2>,
        lin_off: Option<&Array1<F>>,
    ) -> Array1<F>
    where
        S: Data<Elem = F>,
    {
        self.predict(data_x, lin_off)
    }

    /// Returns the fisher information (the negative hessian of the likelihood)
    /// at the parameter values given. The regularization is included.
    pub fn fisher(&self, params: &Array1<F>) -> Array2<F> {
        let lin_pred: Array1<F> = self.data.linear_predictor(params);
        let mu: Array1<F> = M::mean(&lin_pred);
        let var_diag: Array1<F> = mu.mapv_into(M::variance);
        // adjust the variance for non-canonical link functions
        let eta_d = M::Link::d_nat_param(&lin_pred);
        let adj_var: Array1<F> = &eta_d * &var_diag * eta_d;
        // calculate the fisher matrix
        let fisher: Array2<F> = (&self.data.x.t() * &adj_var).dot(&self.data.x);
        // Regularize the fisher matrix
        self.reg.as_ref().irls_mat(fisher, params)
    }

    /// Perform a likelihood-ratio test, returning the statistic -2*ln(L_0/L)
    /// where L_0 is the likelihood of the best-fit null model (with no
    /// parameters but the intercept) and L is the likelihood of the fit result.
    /// The number of degrees of freedom of this statistic, equal to the number
    /// of parameters fixed to zero to form the null model, is `test_ndf()`. By
    /// Wilks' theorem this statistic is asymptotically chi-squared distributed
    /// with this number of degrees of freedom.
    // TODO: Should the effective number of degrees of freedom due to
    // regularization be taken into account? Should the degrees of freedom be a
    // float?
    pub fn lr_test(&self) -> F {
        // The model likelihood should include regularization terms and there
        // shouldn't be any in the null model with all non-intercept parameters
        // set to zero.
        let null_like = self.null_like();
        F::from(-2.).unwrap() * (null_like - self.model_like)
    }

    /// Perform a likelihood-ratio test against a general alternative model, not
    /// necessarily a null model. The alternative model is regularized the same
    /// way that the regression resulting in this fit was. The degrees of
    /// freedom cannot be generally inferred.
    pub fn lr_test_against(&self, alternative: &Array1<F>) -> F {
        let alt_like = M::log_like(self.data, alternative);
        let alt_like_reg = alt_like + self.reg.likelihood(alternative);
        F::from(2.).unwrap() * (self.model_like - alt_like_reg)
    }

    /// Returns the residual degrees of freedom in the model, i.e. the number
    /// of data points minus the number of parameters. Not to be confused with
    /// `test_ndf()`, the degrees of freedom in the statistical tests of the
    /// fit.
    pub fn ndf(&self) -> usize {
        self.n_data - self.n_par
    }

    pub(crate) fn new(data: &'a Dataset<F>, use_intercept: bool, irls: Irls<M, F>) -> Self {
        let Irls {
            guess: result,
            options,
            reg,
            n_iter,
            last_like_data: data_like,
            ..
        } = irls;
        assert_eq!(data_like, M::log_like(data, &result), "Unregularized likelihoods should match exactly.");
        // Cache some of these variables that will be used often.
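        // The stored model likelihood adds any regularization penalty to the data
        // likelihood reported by the IRLS loop.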
        let n_par = result.len();
        let n_data = data.y.len();
        let model_like = data_like + reg.likelihood(&result);
        Self {
            model: PhantomData,
            data,
            use_intercept,
            result,
            options,
            model_like,
            reg,
            n_iter,
            n_data,
            n_par,
            cov: RefCell::new(None),
            null_model: RefCell::new(None),
        }
    }

    /// Returns the likelihood given the null model, which fixes all parameters
    /// to zero except the intercept (if it is used). A total of `test_ndf()`
    /// parameters are constrained.
    pub fn null_like(&self) -> F {
        let (null_like, _) = self.null_model_fit();
        null_like
    }

    /// Return the likelihood and intercept for the null model. Since this can
    /// require an additional regression, the values are cached.
    fn null_model_fit(&self) -> (F, Array1<F>) {
        // TODO: make a result instead of allowing a potential panic in the borrow.
        if self.null_model.borrow().is_none() {
            let (null_like, null_intercept): (F, Array1<F>) = match &self.data.linear_offset {
                None => {
                    // If there is no linear offset, the natural parameter is
                    // identical for all observations so it is sufficient to
                    // calculate the null likelihood for a single point with y equal
                    // to the average.
                    // The average y
                    let y_bar: F = self
                        .data
                        .y
                        .mean()
                        .expect("Should be able to take average of y values");
                    // This approach assumes that the likelihood is in the natural
                    // exponential form as calculated by Glm::log_like_natural(). If that
                    // function is overridden and the values differ significantly, this
                    // approach will give incorrect results. If the likelihood has terms
                    // non-linear in y, then the likelihood must be calculated for every
                    // point rather than averaged.
                    // If the intercept is allowed to maximize the likelihood, the natural
                    // parameter is equal to the link of the expectation. Otherwise it is
                    // the transformation function of zero.
                    let intercept: F = if self.use_intercept {
                        M::Link::func(y_bar)
                    } else {
                        F::zero()
                    };
                    // this is a length-one array. This works because the
                    // likelihood contribution is the same for all observations.
                    let nat_par = M::Link::nat_param(array![intercept]);
                    // The null likelihood per observation
                    let null_like_one: F = M::log_like_natural(y_bar, nat_par[0]);
                    // just multiply the average likelihood by the number of data points, since every term is the same.
                    let null_like_total = F::from(self.n_data).unwrap() * null_like_one;
                    let null_params: Array1<F> = {
                        let mut par = Array1::<F>::zeros(self.n_par);
                        par[0] = intercept;
                        par
                    };
                    (null_like_total, null_params)
                }
                Some(off) => {
                    if self.use_intercept {
                        // If there are linear offsets and the intercept is allowed
                        // to be free, there is not a major simplification and the
                        // model needs to be re-fit.
                        // the X data is a single column of ones. Since this model
                        // isn't being created by the ModelBuilder, the X data
                        // has to be automatically padded with ones.
                        let data_x_null = Array2::<F>::ones((self.n_data, 1));
                        let null_model = Model {
                            model: std::marker::PhantomData::<M>,
                            data: Dataset::<F> {
                                y: self.data.y.clone(),
                                x: data_x_null,
                                linear_offset: Some(off.clone()),
                                weights: self.data.weights.clone(),
                                hat: RefCell::new(None),
                            },
                            // If we are in this branch it is because an intercept is needed.
                            use_intercept: true,
                        };
                        // TODO: Make this function return an error, although it's
                        // difficult to imagine this case happening.
                        // TODO: Should the tolerance of this fit be stricter?
                        // The intercept should not be regularized
                        let null_fit = null_model
                            .fit_options()
                            // There shouldn't be too much trouble fitting this
                            // single-parameter fit, but there shouldn't be harm in
                            // using the same maximum as in the original model.
                            .max_iter(self.options.max_iter)
                            .fit()
                            .expect("Could not fit null model!");
                        let null_params: Array1<F> = {
                            let mut par = Array1::<F>::zeros(self.n_par);
                            // there is only one parameter in this fit.
                            par[0] = null_fit.result[0];
                            par
                        };
                        (null_fit.model_like, null_params)
                    } else {
                        // If the intercept is fixed to zero, then no minimization is
                        // required. The natural parameters are directly known in terms
                        // of the linear offset. The likelihood must still be summed
                        // over all observations, since they have different offsets.
                        let nat_par = M::Link::nat_param(off.clone());
                        let null_like = ndarray::Zip::from(&self.data.y)
                            .and(&nat_par)
                            .map_collect(|&y, &eta| M::log_like_natural(y, eta))
                            .sum();
                        let null_params = Array1::<F>::zeros(self.n_par);
                        (null_like, null_params)
                    }
                }
            };
            *self.null_model.borrow_mut() = Some((null_like, null_intercept));
        }
        self.null_model
            .borrow()
            .as_ref()
            .expect("the null model should be cached now")
            .clone()
    }

    /// Returns the expected value of Y given the input data X. This data need
    /// not be the training data, so an option for linear offsets is provided.
    /// Panics if the number of covariates in the data matrix is not consistent
    /// with the training set. The data matrix may need to be padded by ones if
    /// it is not part of a Model. The `utility::one_pad()` function facilitates
    /// this.
    pub fn predict<S>(&self, data_x: &ArrayBase<S, Ix2>, lin_off: Option<&Array1<F>>) -> Array1<F>
    where
        S: Data<Elem = F>,
    {
        let lin_pred: Array1<F> = data_x.dot(&self.result);
        let lin_pred: Array1<F> = if let Some(off) = &lin_off {
            lin_pred + *off
        } else {
            lin_pred
        };
        lin_pred.mapv_into(M::Link::func_inv)
    }

    /// Return the deviance residuals for each point in the training data.
    /// Equal to `sign(y-E[y|x])*sqrt(-2*(L[y|x] - L_sat[y]))`.
    /// This is usually a better choice for non-linear models.
    /// NaNs might be possible if L[y|x] > L_sat[y] due to floating-point operations. These are
    /// not checked or clipped right now.
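    /// For the linear (OLS) family these reduce to the ordinary response residuals `y - E[y|x]`,
    /// as checked in the `residuals_linear` test below.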
    pub fn resid_dev(&self) -> Array1<F> {
        let signs = self.resid_resp().mapv_into(F::signum);
        let ll_terms: Array1<F> = M::log_like_terms(self.data, &self.result);
        let ll_sat: Array1<F> = self.data.y.mapv(M::log_like_sat);
        let neg_two = F::from(-2.).unwrap();
        let ll_diff = (ll_terms - ll_sat) * neg_two;
        let dev: Array1<F> = ll_diff.mapv_into(num_traits::Float::sqrt);
        signs * dev
    }

    /// Return the standardized deviance residuals, also known as the "internally studentized
    /// deviance residuals". This is generally applicable for outlier detection, although the
    /// influence of each point on the fit is only approximately accounted for.
    /// `d / sqrt(phi * (1 - h))` where `d` is the deviance residual, phi is the dispersion (e.g.
    /// sigma^2 for linear regression, 1 for logistic regression), and h is the leverage.
    pub fn resid_dev_std(&self) -> RegressionResult<Array1<F>> {
        let dev = self.resid_dev();
        let phi = self.dispersion();
        let hat: Array1<F> = self.data.leverage()?;
        let omh: Array1<F> = -hat + F::one();
        let denom: Array1<F> = (omh * phi).mapv_into(num_traits::Float::sqrt);
        Ok(dev / denom)
    }

    /// Return the partial residuals.
    pub fn resid_part(&self) -> Array1<F> {
        let x_mean = self.data.x.mean_axis(Axis(0)).expect("empty dataset");
        let x_centered = &self.data.x - x_mean.insert_axis(Axis(0));
        self.resid_work() + x_centered.dot(&self.result)
    }

    /// Return the Pearson residuals for each point in the training data.
    /// This is equal to `(y - E[y])/sqrt(V(E[y]))`, where V is the variance function.
    /// These are not scaled by the sample standard deviation for families with a free dispersion
    /// parameter like linear regression.
    pub fn resid_pear(&self) -> Array1<F> {
        let mu: Array1<F> = self.predict(&self.data.x, self.data.linear_offset.as_ref());
        let residuals = &self.data.y - &mu;
        let var_diag: Array1<F> = mu.mapv_into(M::variance);
        let std: Array1<F> = var_diag.mapv_into(num_traits::Float::sqrt);
        residuals / std
    }

    /// Return the standardized Pearson residuals for every observation.
    /// Also known as the "internally studentized Pearson residuals".
    /// (y - E[y]) / (sqrt(Var[y] * (1 - h))) where h is a vector representing the leverage for
    /// each observation.
    pub fn resid_pear_std(&self) -> RegressionResult<Array1<F>> {
        let pearson = self.resid_pear();
        let phi = self.dispersion();
        let hat = self.data.leverage()?;
        let omh = -hat + F::one();
        let denom: Array1<F> = (omh * phi).mapv_into(num_traits::Float::sqrt);
        Ok(pearson / denom)
    }

    /// Return the response residuals, or fitting deviation, for each data point in the fit; that
    /// is, the difference y - E[y|x] where the expectation value is the y value predicted by the
    /// model given x.
    pub fn resid_resp(&self) -> Array1<F> {
        self.errors(self.data)
    }

    /// Return the studentized residuals, which are the changes in the fit likelihood resulting
    /// from leaving each observation out. This is a robust and general method for outlier
    /// detection, although a one-step approximation is used to avoid re-fitting the model
    /// completely for each observation.
    /// If the linear errors are standard normally distributed then this statistic should follow a
    /// t-distribution with `self.ndf() - 1` degrees of freedom.
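    /// The one-step approximation used here combines the deviance and Pearson residuals as
    /// `sign(r_p) * sqrt((d^2 + h * r_p^2 / (1 - h)) / phi)`, where `h` is the leverage and
    /// `phi` is the dispersion, itself corrected for each point's own contribution when the
    /// family has a free dispersion parameter.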
    pub fn resid_student(&self) -> RegressionResult<Array1<F>> {
        let r_dev = self.resid_dev();
        let r_pear = self.resid_pear();
        let signs = r_pear.mapv(F::signum);
        let r_dev_sq = r_dev.mapv_into(|x| x * x);
        let r_pear_sq = r_pear.mapv_into(|x| x * x);
        let hat = self.data.leverage()?;
        let omh = -hat.clone() + F::one();
        let sum_quad = &r_dev_sq + hat * r_pear_sq / &omh;
        let sum_quad_scaled = match M::DISPERSED {
            // The dispersion is corrected for the contribution from each current point.
            // This is an approximation; the exact solution would perform a fit at each point.
            DispersionType::FreeDispersion => {
                let dev = self.deviance();
                let dof = F::from(self.ndf() - 1).unwrap();
                let phi_i: Array1<F> = (-r_dev_sq / &omh + dev) / dof;
                sum_quad / phi_i
            }
            DispersionType::NoDispersion => sum_quad,
        };
        Ok(signs * sum_quad_scaled.mapv_into(num_traits::Float::sqrt))
    }

    /// Returns the working residuals `d\eta/d\mu * (y - E{y|x})`.
    /// This should be equal to the response residuals divided by the variance function (as
    /// opposed to the square root of the variance as in the Pearson residuals).
    pub fn resid_work(&self) -> Array1<F> {
        let lin_pred: Array1<F> = self.data.linear_predictor(&self.result);
        let mu: Array1<F> = lin_pred.mapv(M::Link::func_inv);
        let resid_response: Array1<F> = &self.data.y - &mu;
        let d_eta: Array1<F> = M::Link::d_nat_param(&lin_pred);
        d_eta * resid_response
    }

    /// Returns the score function (the gradient of the likelihood) at the
    /// parameter values given. It should be zero within FPE at the minimized
    /// result.
    pub fn score(&self, params: &Array1<F>) -> Array1<F> {
        // This represents the predictions given the input parameters, not the
        // fit parameters.
        let lin_pred: Array1<F> = self.data.linear_predictor(params);
        let mu: Array1<F> = M::mean(&lin_pred);
        let resid_response = &self.data.y - mu;
        // adjust for non-canonical link functions.
        let eta_d = M::Link::d_nat_param(&lin_pred);
        let resid_working = eta_d * resid_response;
        let score_unreg = self.data.x.t().dot(&resid_working);
        self.reg.as_ref().gradient(score_unreg, params)
    }

    /// Returns the score test statistic. This statistic is asymptotically
    /// chi-squared distributed with `test_ndf()` degrees of freedom.
    pub fn score_test(&self) -> RegressionResult<F> {
        let (_, null_params) = self.null_model_fit();
        self.score_test_against(null_params)
    }

    /// Returns the score test statistic compared to another set of model
    /// parameters, not necessarily a null model. The degrees of freedom cannot
    /// be generally inferred.
    pub fn score_test_against(&self, alternative: Array1<F>) -> RegressionResult<F> {
        let score_alt = self.score(&alternative);
        let fisher_alt = self.fisher(&alternative);
        // This is not the same as the cached covariance matrix since it is
        // evaluated at the null parameters.
        // NOTE: invh/invh_into() are bugged and incorrect!
        let inv_fisher_alt = fisher_alt.inv_into()?;
        Ok(score_alt.t().dot(&inv_fisher_alt.dot(&score_alt)))
    }

    /// The degrees of freedom for the likelihood ratio test, the score test,
    /// and the Wald test. Not to be confused with `ndf()`, the degrees of
    /// freedom in the model fit.
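    /// Equal to `n_par - 1` if an intercept is used and `n_par` otherwise.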
    pub fn test_ndf(&self) -> usize {
        if self.use_intercept {
            self.n_par - 1
        } else {
            self.n_par
        }
    }

    /// Returns the Wald test statistic compared to a null model with only an
    /// intercept (if one is used). This statistic is asymptotically chi-squared
    /// distributed with `test_ndf()` degrees of freedom.
    pub fn wald_test(&self) -> F {
        // The null parameters are all zero except for a possible intercept term
        // which optimizes the null model.
        let (_, null_params) = self.null_model_fit();
        self.wald_test_against(&null_params)
    }

    /// Returns the Wald test statistic compared to another specified model fit
    /// instead of the null model. The degrees of freedom cannot be generally
    /// inferred.
    pub fn wald_test_against(&self, alternative: &Array1<F>) -> F {
        let d_params: Array1<F> = &self.result - alternative;
        let fisher_alt: Array2<F> = self.fisher(alternative);
        d_params.t().dot(&fisher_alt.dot(&d_params))
    }

    /// Returns the signed square root of the Wald test statistic for each
    /// parameter. Since it does not account for covariance between the
    /// parameters it may not be accurate.
    pub fn wald_z(&self) -> RegressionResult<Array1<F>> {
        let par_cov = self.covariance()?;
        let par_variances: ArrayView1<F> = par_cov.diag();
        Ok(&self.result / &par_variances.mapv(num_traits::Float::sqrt))
    }
}

/// Specialized functions for OLS.
impl<'a, F> Fit<'a, Linear, F>
where
    F: 'static + Float,
{
    /// Returns the coefficient of multiple correlation, R^2.
    pub fn r_sq(&self) -> F {
        let y_avg: F = self.data.y.mean().expect("Data should be non-empty");
        let total_sum_sq: F = self.data.y.mapv(|y| y - y_avg).mapv(|dy| dy * dy).sum();
        (total_sum_sq - self.resid_sum_sq()) / total_sum_sq
    }

    /// Returns the residual sum of squares, i.e. the sum of the squared residuals.
    pub fn resid_sum_sq(&self) -> F {
        self.resid_resp().mapv_into(|r| r * r).sum()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        model::ModelBuilder,
        utility::{one_pad, standardize},
        Linear, Logistic,
    };
    use anyhow::Result;
    use approx::assert_abs_diff_eq;
    use ndarray::Axis;

    /// Checks if the test statistics are invariant based upon whether the data is standardized.
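    /// Standardizing the covariates is an invertible affine change of basis, so the maximized
    /// likelihood, and hence the test statistics derived from it, should be unchanged.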
    #[test]
    fn standardization_invariance() -> Result<()> {
        let data_y = array![true, false, false, true, true, true, true, false, true];
        let data_x = array![-0.5, 0.3, -0.6, 0.2, 0.3, 1.2, 0.8, 0.6, -0.2].insert_axis(Axis(1));
        let lin_off = array![0.1, 0.0, -0.1, 0.2, 0.1, 0.3, 0.4, -0.1, 0.1];
        let data_x_std = standardize(data_x.clone());
        let model = ModelBuilder::<Logistic>::data(&data_y, &data_x)
            .linear_offset(lin_off.clone())
            .build()?;
        let fit = model.fit()?;
        let model_std = ModelBuilder::<Logistic>::data(&data_y, &data_x_std)
            .linear_offset(lin_off)
            .build()?;
        let fit_std = model_std.fit()?;
        let lr = fit.lr_test();
        let lr_std = fit_std.lr_test();
        assert_abs_diff_eq!(lr, lr_std);
        eprintln!("about to try score test");
        assert_abs_diff_eq!(
            fit.score_test()?,
            fit_std.score_test()?,
            epsilon = f32::EPSILON as f64
        );
        eprintln!("about to try wald test");
        assert_abs_diff_eq!(
            fit.wald_test(),
            fit_std.wald_test(),
            epsilon = 4.0 * f64::EPSILON
        );
        assert_abs_diff_eq!(fit.aic(), fit_std.aic());
        assert_abs_diff_eq!(fit.bic(), fit_std.bic());
        eprintln!("about to try deviance");
        assert_abs_diff_eq!(fit.deviance(), fit_std.deviance());
        // The Wald Z-score of the intercept term is not invariant under a
        // linear transformation of the data, but the parameter part seems to
        // be, at least for single-component data.
        assert_abs_diff_eq!(
            fit.wald_z()?[1],
            fit_std.wald_z()?[1],
            epsilon = 0.01 * f32::EPSILON as f64
        );

        Ok(())
    }

    #[test]
    fn null_model() -> Result<()> {
        let data_y = array![true, false, false, true, true];
        let data_x: Array2<f64> = array![[], [], [], [], []];
        let model = ModelBuilder::<Logistic>::data(&data_y, &data_x).build()?;
        let fit = model.fit()?;
        dbg!(fit.n_iter);
        dbg!(&fit.result);
        // with no offsets, the result should be the link function of the mean.
        assert_abs_diff_eq!(
            fit.result[0],
            <Logistic as Glm>::Link::func(0.6),
            epsilon = 4.0 * f64::EPSILON
        );
        let empty_null_like = fit.null_like();
        assert_eq!(fit.test_ndf(), 0);
        dbg!(&fit.model_like);
        let lr = fit.lr_test();
        // Since there is no data, the null likelihood should be identical to
        // the fit likelihood, so the likelihood ratio test should yield zero.
        assert_abs_diff_eq!(lr, 0., epsilon = 4. * f64::EPSILON);

        // Check that the assertions still hold if linear offsets are included.
        let lin_off: Array1<f64> = array![0.2, -0.1, 0.1, 0.0, 0.1];
        let model = ModelBuilder::<Logistic>::data(&data_y, &data_x)
            .linear_offset(lin_off)
            .build()?;
        let fit_off = model.fit()?;
        let empty_model_like_off = fit_off.model_like;
        let empty_null_like_off = fit_off.null_like();
        // these two assertions should be equivalent
        assert_abs_diff_eq!(fit_off.lr_test(), 0.);
        assert_abs_diff_eq!(empty_model_like_off, empty_null_like_off);

        // check consistency with data provided
        let data_x_with = array![[0.5], [-0.2], [0.3], [0.4], [-0.1]];
        let model = ModelBuilder::<Logistic>::data(&data_y, &data_x_with).build()?;
        let fit_with = model.fit()?;
        dbg!(&fit_with.result);
        // The null likelihood of the model with parameters should be the same
        // as the likelihood of the model with only the intercept.
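        // (The null model ignores the covariates entirely, so adding an X column should not
        // change the null likelihood.)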
        assert_abs_diff_eq!(empty_null_like, fit_with.null_like());

        Ok(())
    }

    #[test]
    fn null_like_logistic() -> Result<()> {
        // 6 true and 4 false for y_bar = 0.6.
        let data_y = array![true, true, true, true, true, true, false, false, false, false];
        let ybar: f64 = 0.6;
        let data_x = array![0.4, 0.2, 0.5, 0.1, 0.6, 0.7, 0.3, 0.8, -0.1, 0.1].insert_axis(Axis(1));
        let model = ModelBuilder::<Logistic>::data(&data_y, &data_x).build()?;
        let fit = model.fit()?;
        let target_null_like = fit
            .data
            .y
            .mapv(|y| {
                let eta = (ybar / (1. - ybar)).ln();
                y * eta - eta.exp().ln_1p()
            })
            .sum();
        assert_abs_diff_eq!(fit.null_like(), target_null_like);
        Ok(())
    }

    // Check that the deviance is equal to the sum of square deviations for a linear model
    #[test]
    fn deviance_linear() -> Result<()> {
        let data_y = array![0.3, -0.2, 0.5, 0.7, 0.2, 1.4, 1.1, 0.2];
        let data_x = array![0.6, 2.1, 0.4, -3.2, 0.7, 0.1, -0.3, 0.5].insert_axis(Axis(1));
        let model = ModelBuilder::<Linear>::data(&data_y, &data_x).build()?;
        let fit = model.fit()?;
        // The predicted values of Y given the model.
        let pred_y = fit.predict(&one_pad(data_x.view()), None);
        let target_dev = (data_y - pred_y).mapv(|dy| dy * dy).sum();
        assert_abs_diff_eq!(fit.deviance(), target_dev);
        Ok(())
    }

    // Check that the deviance and dispersion parameter are equal up to the number of degrees of
    // freedom for a linear model.
    #[test]
    fn deviance_dispersion_eq_linear() -> Result<()> {
        let data_y = array![0.2, -0.1, 0.4, 1.3, 0.2, -0.6, 0.9];
        let data_x = array![
            [0.4, 0.2],
            [0.1, 0.4],
            [-0.1, 0.3],
            [0.5, 0.7],
            [0.4, 0.1],
            [-0.2, -0.3],
            [0.4, -0.1]
        ];
        let model = ModelBuilder::<Linear>::data(&data_y, &data_x).build()?;
        let fit = model.fit()?;
        let dev = fit.deviance();
        let disp = fit.dispersion();
        let ndf = fit.ndf() as f64;
        assert_abs_diff_eq!(dev, disp * ndf, epsilon = 4. * f64::EPSILON);
        Ok(())
    }

    // Check that the residuals for a linear model are all consistent.
    #[test]
    fn residuals_linear() -> Result<()> {
        let data_y = array![0.1, -0.3, 0.7, 0.2, 1.2, -0.4];
        let data_x = array![0.4, 0.1, 0.3, -0.1, 0.5, 0.6].insert_axis(Axis(1));
        let model = ModelBuilder::<Linear>::data(&data_y, &data_x).build()?;
        let fit = model.fit()?;
        let response = fit.resid_resp();
        let pearson = fit.resid_pear();
        let deviance = fit.resid_dev();
        assert_abs_diff_eq!(response, pearson);
        assert_abs_diff_eq!(response, deviance);
        let pearson_std = fit.resid_pear_std()?;
        let deviance_std = fit.resid_dev_std()?;
        let _student = fit.resid_student()?;
        assert_abs_diff_eq!(pearson_std, deviance_std, epsilon = 8. * f64::EPSILON);

        // NOTE: Studentization can't be checked directly because the method used is an
        // approximation. Another approach will be needed to give exact values.
        // let orig_dev = fit.deviance();
        // let n_data = data_y.len();
        // // Check that the leave-one-out stats hold literally
        // let mut loo_dev: Vec<f64> = Vec::new();
        // for i in 0..n_data {
        //     let ya = data_y.slice(s![0..i]);
        //     let yb = data_y.slice(s![i + 1..]);
        //     let xa = data_x.slice(s![0..i, ..]);
        //     let xb = data_x.slice(s![i + 1.., ..]);
        //     let y_loo = concatenate![Axis(0), ya, yb];
        //     let x_loo = concatenate![Axis(0), xa, xb];
        //     let model_i = ModelBuilder::<Linear>::data(&y_loo, &x_loo).build()?;
        //     let fit_i = model_i.fit()?;
        //     let yi = data_y[i];
        //     let xi = data_x.slice(s![i..i + 1, ..]);
        //     let xi = crate::utility::one_pad(xi);
        //     let yi_pred: f64 = fit_i.predict(&xi, None)[0];
        //     let disp_i = fit_i.dispersion();
        //     let pear_loo = (yi - yi_pred) / disp_i.sqrt();
        //     let dev_i = fit_i.deviance();
        //     let d_dev = 2. * (orig_dev - dev_i);
        //     loo_dev.push(d_dev.sqrt() * (yi - yi_pred).signum());
        // }
        // let loo_dev: Array1<f64> = loo_dev.into();
        // This is off from 1 by a constant factor that depends on the data
        // This is only approximately true
        // assert_abs_diff_eq!(student, loo_dev);
        Ok(())
    }

    // check the null likelihood for the case where it can be counted exactly.
    #[test]
    fn null_like_linear() -> Result<()> {
        let data_y = array![0.3, -0.2, 0.5, 0.7, 0.2, 1.4, 1.1, 0.2];
        let data_x = array![0.6, 2.1, 0.4, -3.2, 0.7, 0.1, -0.3, 0.5].insert_axis(Axis(1));
        let ybar: f64 = data_y.mean().unwrap();
        let model = ModelBuilder::<Linear>::data(&data_y, &data_x).build()?;
        let fit = model.fit()?;
        // let target_null_like = data_y.mapv(|y| -0.5 * (y - ybar) * (y - ybar)).sum();
        let target_null_like = data_y.mapv(|y| y * ybar - 0.5 * ybar * ybar).sum();
        // With the saturated likelihood subtracted the null likelihood should
        // just be the sum of squared differences from the mean.
        // let target_null_like = 0.;
        // dbg!(target_null_like);
        let fit_null_like = fit.null_like();
        assert_abs_diff_eq!(2. * (fit.model_like - fit_null_like), fit.lr_test());
        assert_eq!(fit.test_ndf(), 1);
        assert_abs_diff_eq!(
            fit_null_like,
            target_null_like,
            epsilon = 4.0 * f64::EPSILON
        );
        Ok(())
    }

    // check the null likelihood where there is no dependence on the X data.
    #[test]
    fn null_like_logistic_nodep() -> Result<()> {
        let data_y = array![true, true, false, false, true, false, false, true];
        let data_x = array![0.4, 0.2, 0.4, 0.2, 0.7, 0.7, -0.1, -0.1].insert_axis(Axis(1));
        let model = ModelBuilder::<Logistic>::data(&data_y, &data_x).build()?;
        let fit = model.fit()?;
        let lr = fit.lr_test();
        assert_abs_diff_eq!(lr, 0.);
        Ok(())
    }

    // TODO: Test that the statistics behave sensibly under regularization. The
    // likelihood ratio test should yield a smaller value.
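    // A small additional consistency check (a sketch, not part of the original test suite):
    // comparing the fit against its own parameters with `lr_test_against` should give a
    // likelihood ratio statistic of exactly zero, since the alternative likelihood coincides
    // with the model likelihood.
    #[test]
    fn lr_test_against_self_is_zero() -> Result<()> {
        let data_y = array![true, false, true, false, true];
        let data_x = array![0.1, -0.4, 0.7, 0.3, -0.2].insert_axis(Axis(1));
        let model = ModelBuilder::<Logistic>::data(&data_y, &data_x).build()?;
        let fit = model.fit()?;
        assert_abs_diff_eq!(fit.lr_test_against(&fit.result), 0.);
        Ok(())
    }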
    // Test the basic caching functions.
    #[test]
    fn cached_computations() -> Result<()> {
        let data_y = array![true, true, false, true, true, false, false, false, true];
        let data_x = array![0.4, 0.1, -0.3, 0.7, -0.5, -0.1, 0.8, 1.0, 0.4].insert_axis(Axis(1));
        let model = ModelBuilder::<Logistic>::data(&data_y, &data_x).build()?;
        let fit = model.fit()?;
        let _null_like = fit.null_like();
        let _null_like = fit.null_like();
        let _cov = fit.covariance()?;
        let _wald = fit.wald_z();
        Ok(())
    }

    // Check the consistency of the various statistical tests for linear
    // regression, where they should be the most comparable.
    #[test]
    fn linear_stat_tests() -> Result<()> {
        let data_y = array![-0.3, -0.1, 0.0, 0.2, 0.4, 0.5, 0.8, 0.8, 1.1];
        let data_x = array![-0.5, -0.2, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.3].insert_axis(Axis(1));
        let model = ModelBuilder::<Linear>::data(&data_y, &data_x).build()?;
        let fit = model.fit()?;
        let lr = fit.lr_test();
        let wald = fit.wald_test();
        let score = fit.score_test()?;
        assert_abs_diff_eq!(lr, wald, epsilon = 32.0 * f64::EPSILON);
        assert_abs_diff_eq!(lr, score, epsilon = 32.0 * f64::EPSILON);
        Ok(())
    }
}
--------------------------------------------------------------------------------