├── .github └── workflows │ └── tests.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── data ├── breast_cancer.csv ├── breast_cancer_without_target.csv ├── diabetes.csv └── diabetes_without_target.csv ├── examples ├── classification_save_best.rs ├── maximal_classification.rs ├── maximal_regression.rs ├── minimal_classification.rs ├── minimal_regression.rs └── print_settings.rs ├── src ├── algorithms │ ├── categorical_naive_bayes_classifier.rs │ ├── decision_tree_classifier.rs │ ├── decision_tree_regressor.rs │ ├── elastic_net_regressor.rs │ ├── gaussian_naive_bayes_classifier.rs │ ├── knn_classifier.rs │ ├── knn_regressor.rs │ ├── lasso_regressor.rs │ ├── linear_regressor.rs │ ├── logistic_regression.rs │ ├── mod.rs │ ├── random_forest_classifier.rs │ ├── random_forest_regressor.rs │ ├── ridge_regressor.rs │ ├── support_vector_classifier.rs │ └── support_vector_regressor.rs ├── cookbook.rs ├── lib.rs ├── settings │ ├── knn_classifier_parameters.rs │ ├── knn_regressor_parameters.rs │ ├── mod.rs │ ├── settings_struct.rs │ ├── svc_parameters.rs │ └── svr_parameters.rs └── utils.rs └── tests ├── classification.rs ├── new_from_dataset.rs └── regression.rs /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Install egui dependencies 20 | run: sudo apt-get install -y libclang-dev libgtk-3-dev libxcb-render0-dev libxcb-shape0-dev libxcb-xfixes0-dev libxkbcommon-dev libssl-dev 21 | - name: Build 22 | run: cargo build --release --verbose --all-features 23 | - name: Run tests 24 | run: cargo test --release --verbose --all-features -- --nocapture 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | .idea/ 4 | /examples/*.aml 5 | /examples/*.yaml 6 | /examples/*.sc 7 | .vscode 8 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "automl" 3 | version = "0.2.9" 4 | authors = ["Chris McComb "] 5 | description = "Automated machine learning for classification and regression" 6 | edition = "2021" 7 | readme = "README.md" 8 | repository = "https://github.com/cmccomb/rust-automl" 9 | homepage = "https://github.com/cmccomb/rust-automl" 10 | documentation = "https://docs.rs/automl" 11 | license = "MIT OR Apache-2.0" 12 | keywords = ["machine-learning", "ml", "ai", "smartcore", "automl"] 13 | categories = ["algorithms", "mathematics", "science"] 14 | 15 | [dependencies] 16 | smartcore = {version = "0.2.1", features=["serde"]} 17 | serde = {version = "^1", features=["derive"]} 18 | bincode = "^1" 19 | itertools = "^0.14" 20 | comfy-table = "^7" 21 | humantime = "^2" 22 | ndarray = {version = "^0.16", optional = true} 23 | polars = {version = "0.17.0", features = ["ndarray"], optional = true} 24 | url = {version = "^2", optional = true} 25 | temp-file = {version = "0.1.6", optional = true} 26 | csv-sniffer = { version = "^0.3", optional = true } 27 | minreq = {version = "^2", optional = true, features = ["json-using-serde", "https"]} 28 | serde_yaml = "^0.9" 29 | 30 | 
[features] 31 | default = [] 32 | nd = ["ndarray"] 33 | csv = ["polars", "nd", "url", "temp-file", "minreq", "csv-sniffer"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Chris McComb 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Github CI](https://github.com/cmccomb/rust-automl/actions/workflows/tests.yml/badge.svg)](https://github.com/cmccomb/automl/actions) 2 | [![Crates.io](https://img.shields.io/crates/v/automl.svg)](https://crates.io/crates/automl) 3 | [![docs.rs](https://img.shields.io/docsrs/automl/latest?logo=rust)](https://docs.rs/automl) 4 | 5 | # AutoML with SmartCore 6 | 7 | AutoML (_Automated Machine Learning_) streamlines machine learning workflows, making them more accessible and efficient 8 | for users of all experience levels. This crate extends the [`smartcore`](https://docs.rs/smartcore/) machine learning framework, providing utilities to 9 | quickly train, compare, and deploy models. 10 | 11 | # Install 12 | 13 | Add AutoML to your `Cargo.toml` to get started: 14 | 15 | **Stable Version** 16 | ```toml 17 | automl = "0.2.9" 18 | ``` 19 | **Latest Development Version** 20 | ```toml 21 | automl = { git = "https://github.com/cmccomb/rust-automl" } 22 | ``` 23 | 24 | # Example Usage 25 | 26 | Here’s a quick example to illustrate how AutoML can simplify model training and comparison: 27 | 28 | 29 | ```rust 30 | let dataset = smartcore::dataset::breast_cancer::load_dataset(); 31 | let settings = automl::Settings::default_classification(); 32 | let mut classifier = automl::SupervisedModel::new(dataset, settings); 33 | classifier.train(); 34 | ``` 35 | 36 | will perform a comparison of classifier models using cross-validation. 
Printing the classifier object will yield: 37 | 38 | ```text 39 | ┌────────────────────────────────┬─────────────────────┬───────────────────┬──────────────────┐ 40 | │ Model │ Time │ Training Accuracy │ Testing Accuracy │ 41 | ╞════════════════════════════════╪═════════════════════╪═══════════════════╪══════════════════╡ 42 | │ Random Forest Classifier │ 835ms 393us 583ns │ 1.00 │ 0.96 │ 43 | ├────────────────────────────────┼─────────────────────┼───────────────────┼──────────────────┤ 44 | │ Logistic Regression Classifier │ 620ms 714us 583ns │ 0.97 │ 0.95 │ 45 | ├────────────────────────────────┼─────────────────────┼───────────────────┼──────────────────┤ 46 | │ Gaussian Naive Bayes │ 6ms 529us │ 0.94 │ 0.93 │ 47 | ├────────────────────────────────┼─────────────────────┼───────────────────┼──────────────────┤ 48 | │ Categorical Naive Bayes │ 2ms 922us 250ns │ 0.96 │ 0.93 │ 49 | ├────────────────────────────────┼─────────────────────┼───────────────────┼──────────────────┤ 50 | │ Decision Tree Classifier │ 15ms 404us 750ns │ 1.00 │ 0.93 │ 51 | ├────────────────────────────────┼─────────────────────┼───────────────────┼──────────────────┤ 52 | │ KNN Classifier │ 28ms 874us 208ns │ 0.96 │ 0.92 │ 53 | ├────────────────────────────────┼─────────────────────┼───────────────────┼──────────────────┤ 54 | │ Support Vector Classifier │ 4s 187ms 61us 708ns │ 0.57 │ 0.57 │ 55 | └────────────────────────────────┴─────────────────────┴───────────────────┴──────────────────┘ 56 | ``` 57 | 58 | You can then perform inference using the best model with the `predict` method. 59 | 60 | ## Features 61 | 62 | This crate offers optional features that enable additional methods. 63 | 64 | | Feature | Description | 65 | | :------ | :------------------------------------------------------------------------------------------------------ | 66 | | `nd` | Adds methods for predicting/reading data using [`ndarray`](https://crates.io/crates/ndarray). | 67 | | `csv` | Adds methods for predicting/reading data from a .csv using [`polars`](https://crates.io/crates/polars). |
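For example, to enable the `csv` methods (which, per the feature list in this crate's `Cargo.toml`, also pulls in `nd`), the dependency entry might look like:

```toml
automl = { version = "0.2.9", features = ["csv"] }
```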
68 | 69 | ## Capabilities 70 | 71 | - Feature Engineering 72 | - PCA 73 | - SVD 74 | - Interaction terms 75 | - Polynomial terms 76 | - Regression 77 | - Decision Tree Regression 78 | - KNN Regression 79 | - Random Forest Regression 80 | - Linear Regression 81 | - Ridge Regression 82 | - LASSO 83 | - Elastic Net 84 | - Support Vector Regression 85 | - Classification 86 | - Random Forest Classification 87 | - Decision Tree Classification 88 | - Support Vector Classification 89 | - Logistic Regression 90 | - KNN Classification 91 | - Gaussian Naive Bayes 92 | - Meta-learning 93 | - Blending 94 | - Save and load settings 95 | - Save and load models -------------------------------------------------------------------------------- /examples/classification_save_best.rs: -------------------------------------------------------------------------------- 1 | use automl::settings::*; 2 | use automl::*; 3 | use smartcore::linalg::naive::dense_matrix::DenseMatrix; 4 | use smartcore::linear::logistic_regression::LogisticRegression; 5 | use std::io::Read; 6 | 7 | fn main() { 8 | // Set up and train a classification model with only one algorithm for simplicity 9 | let settings = Settings::default_classification().only(Algorithm::LogisticRegression); 10 | let dataset = smartcore::dataset::breast_cancer::load_dataset(); 11 | let mut model = SupervisedModel::new(dataset, settings); 12 | model.train(); 13 | 14 | // Save the best model 15 | let file_name = "examples/best_model_only.sc"; 16 | model.save_best(file_name); 17 | 18 | // Load that model for use directly in SmartCore 19 | let mut buf: Vec<u8> = Vec::new(); 20 | std::fs::File::open(file_name) 21 | .and_then(|mut f| f.read_to_end(&mut buf)) 22 | .expect("Cannot load model from file."); 23 | let model: LogisticRegression<f32, DenseMatrix<f32>> = 24 | bincode::deserialize(&buf).expect("Cannot deserialize the model"); 25 | 26 | // Use the model variable to prove that this works.
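// (The deserialized model is a plain SmartCore type, so the usual SmartCore API applies;
// e.g., given a DenseMatrix<f32> `x`, predictions would come from `model.predict(&x)`.)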
27 | println!("{:?}", model.coefficients()); 28 | println!("{:?}", model.intercept()); 29 | } 30 | -------------------------------------------------------------------------------- /examples/maximal_classification.rs: -------------------------------------------------------------------------------- 1 | use automl::settings::*; 2 | use automl::*; 3 | 4 | fn main() { 5 | // Totally customize settings 6 | let settings = Settings::default_classification() 7 | .with_number_of_folds(3) 8 | .shuffle_data(true) 9 | .verbose(true) 10 | .with_final_model(FinalModel::Blending { 11 | algorithm: Algorithm::CategoricalNaiveBayes, 12 | meta_training_fraction: 0.15, 13 | meta_testing_fraction: 0.15, 14 | }) 15 | .skip(Algorithm::RandomForestClassifier) 16 | .sorted_by(Metric::Accuracy) 17 | .with_preprocessing(PreProcessing::ReplaceWithPCA { 18 | number_of_components: 5, 19 | }) 20 | .with_random_forest_classifier_settings( 21 | RandomForestClassifierParameters::default() 22 | .with_m(100) 23 | .with_max_depth(5) 24 | .with_min_samples_leaf(20) 25 | .with_n_trees(100) 26 | .with_min_samples_split(20), 27 | ) 28 | .with_logistic_settings( 29 | LogisticRegressionParameters::default() 30 | .with_alpha(1.0) 31 | .with_solver(LogisticRegressionSolverName::LBFGS), 32 | ) 33 | .with_svc_settings( 34 | SVCParameters::default() 35 | .with_epoch(10) 36 | .with_tol(1e-10) 37 | .with_c(1.0) 38 | .with_kernel(Kernel::Linear), 39 | ) 40 | .with_decision_tree_classifier_settings( 41 | DecisionTreeClassifierParameters::default() 42 | .with_min_samples_split(20) 43 | .with_max_depth(5) 44 | .with_min_samples_leaf(20), 45 | ) 46 | .with_knn_classifier_settings( 47 | KNNClassifierParameters::default() 48 | .with_algorithm(KNNAlgorithmName::CoverTree) 49 | .with_k(3) 50 | .with_distance(Distance::Euclidean) 51 | .with_weight(KNNWeightFunction::Uniform), 52 | ) 53 | .with_gaussian_nb_settings(GaussianNBParameters::default().with_priors(vec![1.0, 1.0])) 54 | .with_categorical_nb_settings(CategoricalNBParameters::default().with_alpha(1.0)); 55 | 56 | // Save the settings for later use 57 | settings.save("examples/maximal_classification_settings.yaml"); 58 | 59 | // Load a dataset from smartcore and add it to the classifier 60 | let mut model = 61 | SupervisedModel::new(smartcore::dataset::breast_cancer::load_dataset(), settings); 62 | 63 | // Run a model comparison using the customized settings 64 | model.train(); 65 | 66 | // Print the results 67 | println!("{}", model); 68 | 69 | // Save the model for later 70 | model.save("examples/maximal_classification_model.aml"); 71 | } 72 | -------------------------------------------------------------------------------- /examples/maximal_regression.rs: -------------------------------------------------------------------------------- 1 | use automl::settings::*; 2 | use automl::*; 3 | 4 | fn main() { 5 | // Totally customize settings 6 | let settings = Settings::default_regression() 7 | .with_number_of_folds(3) 8 | .shuffle_data(true) 9 | .verbose(true) 10 | .with_final_model(FinalModel::Blending { 11 | algorithm: Algorithm::Linear, 12 | meta_training_fraction: 0.15, 13 | meta_testing_fraction: 0.15, 14 | }) 15 | .skip(Algorithm::RandomForestRegressor) 16 | .sorted_by(Metric::RSquared) 17 | .with_preprocessing(PreProcessing::AddInteractions) 18 | .with_linear_settings( 19 | LinearRegressionParameters::default().with_solver(LinearRegressionSolverName::QR), 20 | ) 21 | .with_lasso_settings( 22 | LassoParameters::default() 23 | .with_alpha(1.0) 24 | .with_tol(1e-4) 25 | .with_normalize(true) 26 | .with_max_iter(1000), 27 | ) 28 | .with_ridge_settings( 29 | RidgeRegressionParameters::default() 30 | .with_alpha(1.0) 31 | .with_normalize(true) 32 | .with_solver(RidgeRegressionSolverName::Cholesky), 33 | ) 34 | .with_elastic_net_settings( 35 | ElasticNetParameters::default() 36 | .with_tol(1e-4) 37 | .with_normalize(true) 38 | .with_alpha(1.0) 39 | .with_max_iter(1000) 40 | .with_l1_ratio(0.5), 41 | ) 42 | .with_knn_regressor_settings( 43 | KNNRegressorParameters::default() 44 | .with_algorithm(KNNAlgorithmName::CoverTree) 45 | .with_k(3) 46 | .with_distance(Distance::Euclidean) 47 | .with_weight(KNNWeightFunction::Uniform), 48 | ) 49 | .with_svr_settings( 50 | SVRParameters::default() 51 | .with_eps(0.1) 52 | .with_tol(1e-3) 53 | .with_c(1.0) 54 | .with_kernel(Kernel::Linear), 55 | ) 56 | .with_random_forest_regressor_settings( 57 | RandomForestRegressorParameters::default() 58 | .with_m(1) 59 | .with_max_depth(5) 60 | .with_min_samples_leaf(1) 61 | .with_n_trees(10) 62 | .with_min_samples_split(2), 63 | ) 64 | .with_decision_tree_regressor_settings( 65 | DecisionTreeRegressorParameters::default() 66 | .with_min_samples_split(2) 67 | .with_max_depth(15) 68 | .with_min_samples_leaf(1), 69 | ); 70 | 71 | // Save the settings for later use 72 | settings.save("examples/maximal_regression_settings.yaml"); 73 | 74 | // Load a dataset from smartcore and add it to the regressor along with the customized settings 75 | let mut model = SupervisedModel::new(smartcore::dataset::diabetes::load_dataset(), settings); 76 | 77 | // Run a model comparison using the customized settings 78 | model.train(); 79 | 80 | // Print the results 81 | println!("{}", model); 82 | 83 | // Save the model for later 84 | model.save("examples/maximal_regression_model.aml"); 85 | } 86 | -------------------------------------------------------------------------------- /examples/minimal_classification.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | // Define a default classifier from a dataset 3 | let mut model = automl::SupervisedModel::new( 4 | smartcore::dataset::breast_cancer::load_dataset(), 5 | automl::Settings::default_classification(), 6 | ); 7 | 8 | // Run a model comparison with all models at default settings 9 | model.train(); 10 | } 11 | -------------------------------------------------------------------------------- /examples/minimal_regression.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | // Define a default regressor from a dataset 3 | let mut model = automl::SupervisedModel::new( 4 | smartcore::dataset::diabetes::load_dataset(), 5 | automl::Settings::default_regression(), 6 | ); 7 | 8 | // Run a model comparison with all models at default settings 9 | model.train(); 10 | } 11 | -------------------------------------------------------------------------------- /examples/print_settings.rs: -------------------------------------------------------------------------------- 1 | use automl::settings::*; 2 | 3 | fn main() { 4 | let regressor_settings = automl::Settings::default_regression() 5 | .with_number_of_folds(3) 6 | .shuffle_data(true) 7 | .verbose(true) 8 | .sorted_by(Metric::RSquared) 9 | .with_preprocessing(PreProcessing::AddInteractions) 10 | .with_linear_settings( 11 | LinearRegressionParameters::default().with_solver(LinearRegressionSolverName::QR), 12 | ) 13 | .with_lasso_settings( 14 | LassoParameters::default() 15 | .with_alpha(10.0) 16 | .with_tol(1e-10) 17
| .with_normalize(true) 18 | .with_max_iter(10_000), 19 | ) 20 | .with_ridge_settings( 21 | RidgeRegressionParameters::default() 22 | .with_alpha(10.0) 23 | .with_normalize(true) 24 | .with_solver(RidgeRegressionSolverName::Cholesky), 25 | ) 26 | .with_elastic_net_settings( 27 | ElasticNetParameters::default() 28 | .with_tol(1e-10) 29 | .with_normalize(true) 30 | .with_alpha(1.0) 31 | .with_max_iter(10_000) 32 | .with_l1_ratio(0.5), 33 | ) 34 | .with_knn_regressor_settings( 35 | KNNRegressorParameters::default() 36 | .with_algorithm(KNNAlgorithmName::CoverTree) 37 | .with_k(3) 38 | .with_distance(Distance::Euclidean) 39 | .with_weight(KNNWeightFunction::Uniform), 40 | ) 41 | .with_svr_settings( 42 | SVRParameters::default() 43 | .with_eps(1e-10) 44 | .with_tol(1e-10) 45 | .with_c(1.0) 46 | .with_kernel(Kernel::Linear), 47 | ) 48 | .with_random_forest_regressor_settings( 49 | RandomForestRegressorParameters::default() 50 | .with_m(100) 51 | .with_max_depth(5) 52 | .with_min_samples_leaf(20) 53 | .with_n_trees(100) 54 | .with_min_samples_split(20), 55 | ) 56 | .with_decision_tree_regressor_settings( 57 | DecisionTreeRegressorParameters::default() 58 | .with_min_samples_split(20) 59 | .with_max_depth(5) 60 | .with_min_samples_leaf(20), 61 | ); 62 | 63 | let classifier_settings = automl::Settings::default_classification() 64 | .with_number_of_folds(3) 65 | .shuffle_data(true) 66 | .verbose(true) 67 | .sorted_by(Metric::Accuracy) 68 | .with_preprocessing(PreProcessing::AddInteractions) 69 | .with_random_forest_classifier_settings( 70 | RandomForestClassifierParameters::default() 71 | .with_m(100) 72 | .with_max_depth(5) 73 | .with_min_samples_leaf(20) 74 | .with_n_trees(100) 75 | .with_min_samples_split(20), 76 | ) 77 | .with_logistic_settings(LogisticRegressionParameters::default()) 78 | .with_svc_settings( 79 | SVCParameters::default() 80 | .with_epoch(10) 81 | .with_tol(1e-10) 82 | .with_c(1.0) 83 | .with_kernel(Kernel::Linear), 84 | ) 85 | .with_decision_tree_classifier_settings( 86 | DecisionTreeClassifierParameters::default() 87 | .with_min_samples_split(20) 88 | .with_max_depth(5) 89 | .with_min_samples_leaf(20), 90 | ) 91 | .with_knn_classifier_settings( 92 | KNNClassifierParameters::default() 93 | .with_algorithm(KNNAlgorithmName::CoverTree) 94 | .with_k(3) 95 | .with_distance(Distance::Hamming) 96 | .with_weight(KNNWeightFunction::Uniform), 97 | ) 98 | .with_gaussian_nb_settings(GaussianNBParameters::default().with_priors(vec![1.0, 1.0])) 99 | .with_categorical_nb_settings(CategoricalNBParameters::default().with_alpha(1.0)); 100 | 101 | println!("{}", regressor_settings); 102 | println!("{}", classifier_settings) 103 | } 104 | -------------------------------------------------------------------------------- /src/algorithms/categorical_naive_bayes_classifier.rs: -------------------------------------------------------------------------------- 1 | //! Categorical Naive Bayes Classifier. 2 | 3 | use smartcore::{ 4 | linalg::naive::dense_matrix::DenseMatrix, 5 | model_selection::{cross_validate, CrossValidationResult}, 6 | naive_bayes::categorical::CategoricalNB, 7 | }; 8 | 9 | use crate::{Algorithm, Settings}; 10 | 11 | /// The Categorical Naive Bayes Classifier. 12 | /// 13 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/naive_bayes.html#categorical-naive-bayes) 14 | /// for a more in-depth description of the algorithm. 
15 | pub struct CategoricalNaiveBayesClassifierWrapper {} 16 | 17 | impl super::ModelWrapper for CategoricalNaiveBayesClassifierWrapper { 18 | fn cv( 19 | x: &DenseMatrix, 20 | y: &Vec, 21 | settings: &Settings, 22 | ) -> (CrossValidationResult, Algorithm) { 23 | ( 24 | cross_validate( 25 | CategoricalNB::fit, 26 | x, 27 | y, 28 | settings.categorical_nb_settings.as_ref().unwrap().clone(), 29 | settings.get_kfolds(), 30 | settings.get_metric(), 31 | ) 32 | .unwrap(), 33 | Algorithm::CategoricalNaiveBayes, 34 | ) 35 | } 36 | 37 | fn train(x: &DenseMatrix, y: &Vec, settings: &Settings) -> Vec { 38 | bincode::serialize( 39 | &CategoricalNB::fit( 40 | x, 41 | y, 42 | settings.categorical_nb_settings.as_ref().unwrap().clone(), 43 | ) 44 | .unwrap(), 45 | ) 46 | .unwrap() 47 | } 48 | 49 | fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { 50 | let model: CategoricalNB> = 51 | bincode::deserialize(final_model).unwrap(); 52 | model.predict(x).unwrap() 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/algorithms/decision_tree_classifier.rs: -------------------------------------------------------------------------------- 1 | //! Decision Tree Classifier. 2 | 3 | use smartcore::{ 4 | linalg::naive::dense_matrix::DenseMatrix, 5 | model_selection::{cross_validate, CrossValidationResult}, 6 | tree::decision_tree_classifier::DecisionTreeClassifier, 7 | }; 8 | 9 | use crate::{Algorithm, Settings}; 10 | 11 | /// The Decision Tree Classifier. 12 | /// 13 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/tree.html#classification) 14 | /// for a more in-depth description of the algorithm. 15 | pub struct DecisionTreeClassifierWrapper {} 16 | 17 | impl super::ModelWrapper for DecisionTreeClassifierWrapper { 18 | fn cv( 19 | x: &DenseMatrix, 20 | y: &Vec, 21 | settings: &Settings, 22 | ) -> (CrossValidationResult, Algorithm) { 23 | ( 24 | cross_validate( 25 | DecisionTreeClassifier::fit, 26 | x, 27 | y, 28 | settings 29 | .decision_tree_classifier_settings 30 | .as_ref() 31 | .unwrap() 32 | .clone(), 33 | settings.get_kfolds(), 34 | settings.get_metric(), 35 | ) 36 | .unwrap(), 37 | Algorithm::DecisionTreeClassifier, 38 | ) 39 | } 40 | 41 | fn train(x: &DenseMatrix, y: &Vec, settings: &Settings) -> Vec { 42 | bincode::serialize( 43 | &DecisionTreeClassifier::fit( 44 | x, 45 | y, 46 | settings 47 | .decision_tree_classifier_settings 48 | .as_ref() 49 | .unwrap() 50 | .clone(), 51 | ) 52 | .unwrap(), 53 | ) 54 | .unwrap() 55 | } 56 | 57 | fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { 58 | let model: DecisionTreeClassifier = bincode::deserialize(final_model).unwrap(); 59 | model.predict(x).unwrap() 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/algorithms/decision_tree_regressor.rs: -------------------------------------------------------------------------------- 1 | //! Decision Tree Regressor. 2 | 3 | use smartcore::{ 4 | linalg::naive::dense_matrix::DenseMatrix, 5 | model_selection::{cross_validate, CrossValidationResult}, 6 | tree::decision_tree_regressor::DecisionTreeRegressor, 7 | }; 8 | 9 | use crate::{Algorithm, Settings}; 10 | 11 | /// The Decision Tree Regressor. 12 | /// 13 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/tree.html#regression) 14 | /// for a more in-depth description of the algorithm. 
15 | pub struct DecisionTreeRegressorWrapper {} 16 | 17 | impl super::ModelWrapper for DecisionTreeRegressorWrapper { 18 | fn cv( 19 | x: &DenseMatrix, 20 | y: &Vec, 21 | settings: &Settings, 22 | ) -> (CrossValidationResult, Algorithm) { 23 | ( 24 | cross_validate( 25 | DecisionTreeRegressor::fit, 26 | x, 27 | y, 28 | settings 29 | .decision_tree_regressor_settings 30 | .as_ref() 31 | .unwrap() 32 | .clone(), 33 | settings.get_kfolds(), 34 | settings.get_metric(), 35 | ) 36 | .unwrap(), 37 | Algorithm::DecisionTreeRegressor, 38 | ) 39 | } 40 | 41 | fn train(x: &DenseMatrix, y: &Vec, settings: &Settings) -> Vec { 42 | bincode::serialize( 43 | &DecisionTreeRegressor::fit( 44 | x, 45 | y, 46 | settings 47 | .decision_tree_regressor_settings 48 | .as_ref() 49 | .unwrap() 50 | .clone(), 51 | ) 52 | .unwrap(), 53 | ) 54 | .unwrap() 55 | } 56 | 57 | fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { 58 | let model: DecisionTreeRegressor = bincode::deserialize(final_model).unwrap(); 59 | model.predict(x).unwrap() 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/algorithms/elastic_net_regressor.rs: -------------------------------------------------------------------------------- 1 | //! Elastic Net Regressor. 2 | 3 | use smartcore::{ 4 | linalg::naive::dense_matrix::DenseMatrix, linear::elastic_net::ElasticNet, 5 | model_selection::cross_validate, model_selection::CrossValidationResult, 6 | }; 7 | 8 | use crate::{Algorithm, Settings}; 9 | 10 | /// The Elastic Net Regressor. 11 | /// 12 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/linear_model.html#elastic-net) 13 | /// for a more in-depth description of the algorithm. 14 | pub struct ElasticNetRegressorWrapper {} 15 | 16 | impl super::ModelWrapper for ElasticNetRegressorWrapper { 17 | fn cv( 18 | x: &DenseMatrix, 19 | y: &Vec, 20 | settings: &Settings, 21 | ) -> (CrossValidationResult, Algorithm) { 22 | ( 23 | cross_validate( 24 | ElasticNet::fit, 25 | x, 26 | y, 27 | settings.elastic_net_settings.as_ref().unwrap().clone(), 28 | settings.get_kfolds(), 29 | settings.get_metric(), 30 | ) 31 | .unwrap(), 32 | Algorithm::ElasticNet, 33 | ) 34 | } 35 | 36 | fn train(x: &DenseMatrix, y: &Vec, settings: &Settings) -> Vec { 37 | bincode::serialize( 38 | &ElasticNet::fit( 39 | x, 40 | y, 41 | settings.elastic_net_settings.as_ref().unwrap().clone(), 42 | ) 43 | .unwrap(), 44 | ) 45 | .unwrap() 46 | } 47 | 48 | fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { 49 | let model: ElasticNet> = bincode::deserialize(final_model).unwrap(); 50 | model.predict(x).unwrap() 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/algorithms/gaussian_naive_bayes_classifier.rs: -------------------------------------------------------------------------------- 1 | //! Gaussian Naive Bayes Classifier 2 | 3 | use smartcore::{ 4 | linalg::naive::dense_matrix::DenseMatrix, 5 | model_selection::{cross_validate, CrossValidationResult}, 6 | naive_bayes::gaussian::GaussianNB, 7 | }; 8 | 9 | use crate::{Algorithm, Settings}; 10 | 11 | /// The Gaussian Naive Bayes Classifier. 12 | /// 13 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/naive_bayes.html#gaussian-naive-bayes) 14 | /// for a more in-depth description of the algorithm. 
15 | pub struct GaussianNaiveBayesClassifierWrapper {} 16 | 17 | impl super::ModelWrapper for GaussianNaiveBayesClassifierWrapper { 18 | fn cv( 19 | x: &DenseMatrix, 20 | y: &Vec, 21 | settings: &Settings, 22 | ) -> (CrossValidationResult, Algorithm) { 23 | ( 24 | cross_validate( 25 | GaussianNB::fit, 26 | x, 27 | y, 28 | settings.gaussian_nb_settings.as_ref().unwrap().clone(), 29 | settings.get_kfolds(), 30 | settings.get_metric(), 31 | ) 32 | .unwrap(), 33 | Algorithm::GaussianNaiveBayes, 34 | ) 35 | } 36 | 37 | fn train(x: &DenseMatrix, y: &Vec, settings: &Settings) -> Vec { 38 | bincode::serialize( 39 | &GaussianNB::fit( 40 | x, 41 | y, 42 | settings.gaussian_nb_settings.as_ref().unwrap().clone(), 43 | ) 44 | .unwrap(), 45 | ) 46 | .unwrap() 47 | } 48 | 49 | fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { 50 | let model: GaussianNB> = bincode::deserialize(final_model).unwrap(); 51 | model.predict(x).unwrap() 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/algorithms/knn_classifier.rs: -------------------------------------------------------------------------------- 1 | //! KNN Classifier 2 | 3 | use smartcore::{ 4 | linalg::naive::dense_matrix::DenseMatrix, 5 | math::distance::{ 6 | euclidian::Euclidian, hamming::Hamming, mahalanobis::Mahalanobis, manhattan::Manhattan, 7 | minkowski::Minkowski, Distances, 8 | }, 9 | model_selection::{cross_validate, CrossValidationResult}, 10 | neighbors::knn_classifier::{ 11 | KNNClassifier, KNNClassifierParameters as SmartcoreKNNClassifierParameters, 12 | }, 13 | }; 14 | 15 | use crate::{Algorithm, Distance, Settings}; 16 | 17 | /// The KNN Classifier. 18 | /// 19 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/neighbors.html#classification) 20 | /// for a more in-depth description of the algorithm. 
21 | pub struct KNNClassifierWrapper {} 22 | 23 | impl super::ModelWrapper for KNNClassifierWrapper { 24 | fn cv( 25 | x: &DenseMatrix, 26 | y: &Vec, 27 | settings: &Settings, 28 | ) -> (CrossValidationResult, Algorithm) { 29 | let cv = match settings.knn_classifier_settings.as_ref().unwrap().distance { 30 | Distance::Euclidean => cross_validate( 31 | KNNClassifier::fit, 32 | x, 33 | y, 34 | SmartcoreKNNClassifierParameters::default() 35 | .with_k(settings.knn_classifier_settings.as_ref().unwrap().k) 36 | .with_weight( 37 | settings 38 | .knn_classifier_settings 39 | .as_ref() 40 | .unwrap() 41 | .weight 42 | .clone(), 43 | ) 44 | .with_algorithm( 45 | settings 46 | .knn_classifier_settings 47 | .as_ref() 48 | .unwrap() 49 | .algorithm 50 | .clone(), 51 | ) 52 | .with_distance(Distances::euclidian()), 53 | settings.get_kfolds(), 54 | settings.get_metric(), 55 | ) 56 | .unwrap(), 57 | Distance::Manhattan => cross_validate( 58 | KNNClassifier::fit, 59 | x, 60 | y, 61 | SmartcoreKNNClassifierParameters::default() 62 | .with_k(settings.knn_classifier_settings.as_ref().unwrap().k) 63 | .with_weight( 64 | settings 65 | .knn_classifier_settings 66 | .as_ref() 67 | .unwrap() 68 | .weight 69 | .clone(), 70 | ) 71 | .with_algorithm( 72 | settings 73 | .knn_classifier_settings 74 | .as_ref() 75 | .unwrap() 76 | .algorithm 77 | .clone(), 78 | ) 79 | .with_distance(Distances::manhattan()), 80 | settings.get_kfolds(), 81 | settings.get_metric(), 82 | ) 83 | .unwrap(), 84 | Distance::Minkowski(p) => cross_validate( 85 | KNNClassifier::fit, 86 | x, 87 | y, 88 | SmartcoreKNNClassifierParameters::default() 89 | .with_k(settings.knn_classifier_settings.as_ref().unwrap().k) 90 | .with_weight( 91 | settings 92 | .knn_classifier_settings 93 | .as_ref() 94 | .unwrap() 95 | .weight 96 | .clone(), 97 | ) 98 | .with_algorithm( 99 | settings 100 | .knn_classifier_settings 101 | .as_ref() 102 | .unwrap() 103 | .algorithm 104 | .clone(), 105 | ) 106 | .with_distance(Distances::minkowski(p)), 107 | settings.get_kfolds(), 108 | settings.get_metric(), 109 | ) 110 | .unwrap(), 111 | Distance::Mahalanobis => cross_validate( 112 | KNNClassifier::fit, 113 | x, 114 | y, 115 | SmartcoreKNNClassifierParameters::default() 116 | .with_k(settings.knn_classifier_settings.as_ref().unwrap().k) 117 | .with_weight( 118 | settings 119 | .knn_classifier_settings 120 | .as_ref() 121 | .unwrap() 122 | .weight 123 | .clone(), 124 | ) 125 | .with_algorithm( 126 | settings 127 | .knn_classifier_settings 128 | .as_ref() 129 | .unwrap() 130 | .algorithm 131 | .clone(), 132 | ) 133 | .with_distance(Distances::mahalanobis(x)), 134 | settings.get_kfolds(), 135 | settings.get_metric(), 136 | ) 137 | .unwrap(), 138 | Distance::Hamming => cross_validate( 139 | KNNClassifier::fit, 140 | x, 141 | y, 142 | SmartcoreKNNClassifierParameters::default() 143 | .with_k(settings.knn_classifier_settings.as_ref().unwrap().k) 144 | .with_weight( 145 | settings 146 | .knn_classifier_settings 147 | .as_ref() 148 | .unwrap() 149 | .weight 150 | .clone(), 151 | ) 152 | .with_algorithm( 153 | settings 154 | .knn_classifier_settings 155 | .as_ref() 156 | .unwrap() 157 | .algorithm 158 | .clone(), 159 | ) 160 | .with_distance(Distances::hamming()), 161 | settings.get_kfolds(), 162 | settings.get_metric(), 163 | ) 164 | .unwrap(), 165 | }; 166 | 167 | (cv, Algorithm::KNNClassifier) 168 | } 169 | 170 | fn train(x: &DenseMatrix, y: &Vec, settings: &Settings) -> Vec { 171 | match settings.knn_classifier_settings.as_ref().unwrap().distance { 172 | Distance::Euclidean => { 
173 | let params = SmartcoreKNNClassifierParameters::default() 174 | .with_k(settings.knn_classifier_settings.as_ref().unwrap().k) 175 | .with_weight( 176 | settings 177 | .knn_classifier_settings 178 | .as_ref() 179 | .unwrap() 180 | .weight 181 | .clone(), 182 | ) 183 | .with_algorithm( 184 | settings 185 | .knn_classifier_settings 186 | .as_ref() 187 | .unwrap() 188 | .algorithm 189 | .clone(), 190 | ) 191 | .with_distance(Distances::euclidian()); 192 | bincode::serialize(&KNNClassifier::fit(x, y, params).unwrap()).unwrap() 193 | } 194 | Distance::Manhattan => { 195 | let params = SmartcoreKNNClassifierParameters::default() 196 | .with_k(settings.knn_classifier_settings.as_ref().unwrap().k) 197 | .with_weight( 198 | settings 199 | .knn_classifier_settings 200 | .as_ref() 201 | .unwrap() 202 | .weight 203 | .clone(), 204 | ) 205 | .with_algorithm( 206 | settings 207 | .knn_classifier_settings 208 | .as_ref() 209 | .unwrap() 210 | .algorithm 211 | .clone(), 212 | ) 213 | .with_distance(Distances::manhattan()); 214 | bincode::serialize(&KNNClassifier::fit(x, y, params).unwrap()).unwrap() 215 | } 216 | Distance::Minkowski(p) => { 217 | let params = SmartcoreKNNClassifierParameters::default() 218 | .with_k(settings.knn_classifier_settings.as_ref().unwrap().k) 219 | .with_weight( 220 | settings 221 | .knn_classifier_settings 222 | .as_ref() 223 | .unwrap() 224 | .weight 225 | .clone(), 226 | ) 227 | .with_algorithm( 228 | settings 229 | .knn_classifier_settings 230 | .as_ref() 231 | .unwrap() 232 | .algorithm 233 | .clone(), 234 | ) 235 | .with_distance(Distances::minkowski(p)); 236 | bincode::serialize(&KNNClassifier::fit(x, y, params).unwrap()).unwrap() 237 | } 238 | Distance::Mahalanobis => { 239 | let params = SmartcoreKNNClassifierParameters::default() 240 | .with_k(settings.knn_classifier_settings.as_ref().unwrap().k) 241 | .with_weight( 242 | settings 243 | .knn_classifier_settings 244 | .as_ref() 245 | .unwrap() 246 | .weight 247 | .clone(), 248 | ) 249 | .with_algorithm( 250 | settings 251 | .knn_classifier_settings 252 | .as_ref() 253 | .unwrap() 254 | .algorithm 255 | .clone(), 256 | ) 257 | .with_distance(Distances::mahalanobis(x)); 258 | bincode::serialize(&KNNClassifier::fit(x, y, params).unwrap()).unwrap() 259 | } 260 | Distance::Hamming => { 261 | let params = SmartcoreKNNClassifierParameters::default() 262 | .with_k(settings.knn_classifier_settings.as_ref().unwrap().k) 263 | .with_weight( 264 | settings 265 | .knn_classifier_settings 266 | .as_ref() 267 | .unwrap() 268 | .weight 269 | .clone(), 270 | ) 271 | .with_algorithm( 272 | settings 273 | .knn_classifier_settings 274 | .as_ref() 275 | .unwrap() 276 | .algorithm 277 | .clone(), 278 | ) 279 | .with_distance(Distances::hamming()); 280 | bincode::serialize(&KNNClassifier::fit(x, y, params).unwrap()).unwrap() 281 | } 282 | } 283 | } 284 | 285 | fn predict(x: &DenseMatrix, final_model: &Vec, settings: &Settings) -> Vec { 286 | match settings.knn_classifier_settings.as_ref().unwrap().distance { 287 | Distance::Euclidean => { 288 | let model: KNNClassifier = 289 | bincode::deserialize(final_model).unwrap(); 290 | model.predict(x).unwrap() 291 | } 292 | Distance::Manhattan => { 293 | let model: KNNClassifier = 294 | bincode::deserialize(final_model).unwrap(); 295 | model.predict(x).unwrap() 296 | } 297 | Distance::Minkowski(_) => { 298 | let model: KNNClassifier = 299 | bincode::deserialize(final_model).unwrap(); 300 | model.predict(x).unwrap() 301 | } 302 | Distance::Mahalanobis => { 303 | let model: KNNClassifier>> = 304 | 
bincode::deserialize(final_model).unwrap(); 305 | model.predict(x).unwrap() 306 | } 307 | Distance::Hamming => { 308 | let model: KNNClassifier = bincode::deserialize(final_model).unwrap(); 309 | model.predict(x).unwrap() 310 | } 311 | } 312 | } 313 | } 314 | -------------------------------------------------------------------------------- /src/algorithms/knn_regressor.rs: -------------------------------------------------------------------------------- 1 | //! KNN Regressor 2 | 3 | use smartcore::{ 4 | linalg::naive::dense_matrix::DenseMatrix, 5 | math::distance::{ 6 | euclidian::Euclidian, hamming::Hamming, mahalanobis::Mahalanobis, manhattan::Manhattan, 7 | minkowski::Minkowski, Distances, 8 | }, 9 | model_selection::cross_validate, 10 | model_selection::CrossValidationResult, 11 | neighbors::knn_regressor::{ 12 | KNNRegressor, KNNRegressorParameters as SmartcoreKNNRegressorParameters, 13 | }, 14 | }; 15 | 16 | use crate::{Algorithm, Distance, Settings}; 17 | 18 | /// The KNN Regressor. 19 | /// 20 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/neighbors.html#regression) 21 | /// for a more in-depth description of the algorithm. 22 | pub struct KNNRegressorWrapper {} 23 | 24 | impl super::ModelWrapper for KNNRegressorWrapper { 25 | fn cv( 26 | x: &DenseMatrix, 27 | y: &Vec, 28 | settings: &Settings, 29 | ) -> (CrossValidationResult, Algorithm) { 30 | let cv = match settings.knn_regressor_settings.as_ref().unwrap().distance { 31 | Distance::Euclidean => cross_validate( 32 | KNNRegressor::fit, 33 | x, 34 | y, 35 | SmartcoreKNNRegressorParameters::default() 36 | .with_k(settings.knn_regressor_settings.as_ref().unwrap().k) 37 | .with_algorithm( 38 | settings 39 | .knn_regressor_settings 40 | .as_ref() 41 | .unwrap() 42 | .algorithm 43 | .clone(), 44 | ) 45 | .with_weight( 46 | settings 47 | .knn_regressor_settings 48 | .as_ref() 49 | .unwrap() 50 | .weight 51 | .clone(), 52 | ) 53 | .with_distance(Distances::euclidian()), 54 | settings.get_kfolds(), 55 | settings.get_metric(), 56 | ) 57 | .unwrap(), 58 | Distance::Manhattan => cross_validate( 59 | KNNRegressor::fit, 60 | x, 61 | y, 62 | SmartcoreKNNRegressorParameters::default() 63 | .with_k(settings.knn_regressor_settings.as_ref().unwrap().k) 64 | .with_algorithm( 65 | settings 66 | .knn_regressor_settings 67 | .as_ref() 68 | .unwrap() 69 | .algorithm 70 | .clone(), 71 | ) 72 | .with_weight( 73 | settings 74 | .knn_regressor_settings 75 | .as_ref() 76 | .unwrap() 77 | .weight 78 | .clone(), 79 | ) 80 | .with_distance(Distances::manhattan()), 81 | settings.get_kfolds(), 82 | settings.get_metric(), 83 | ) 84 | .unwrap(), 85 | Distance::Minkowski(p) => cross_validate( 86 | KNNRegressor::fit, 87 | x, 88 | y, 89 | SmartcoreKNNRegressorParameters::default() 90 | .with_k(settings.knn_regressor_settings.as_ref().unwrap().k) 91 | .with_algorithm( 92 | settings 93 | .knn_regressor_settings 94 | .as_ref() 95 | .unwrap() 96 | .algorithm 97 | .clone(), 98 | ) 99 | .with_weight( 100 | settings 101 | .knn_regressor_settings 102 | .as_ref() 103 | .unwrap() 104 | .weight 105 | .clone(), 106 | ) 107 | .with_distance(Distances::minkowski(p)), 108 | settings.get_kfolds(), 109 | settings.get_metric(), 110 | ) 111 | .unwrap(), 112 | Distance::Mahalanobis => cross_validate( 113 | KNNRegressor::fit, 114 | x, 115 | y, 116 | SmartcoreKNNRegressorParameters::default() 117 | .with_k(settings.knn_regressor_settings.as_ref().unwrap().k) 118 | .with_algorithm( 119 | settings 120 | .knn_regressor_settings 121 | .as_ref() 122 | .unwrap() 123 | 
.algorithm 124 | .clone(), 125 | ) 126 | .with_weight( 127 | settings 128 | .knn_regressor_settings 129 | .as_ref() 130 | .unwrap() 131 | .weight 132 | .clone(), 133 | ) 134 | .with_distance(Distances::mahalanobis(x)), 135 | settings.get_kfolds(), 136 | settings.get_metric(), 137 | ) 138 | .unwrap(), 139 | Distance::Hamming => cross_validate( 140 | KNNRegressor::fit, 141 | x, 142 | y, 143 | SmartcoreKNNRegressorParameters::default() 144 | .with_k(settings.knn_regressor_settings.as_ref().unwrap().k) 145 | .with_algorithm( 146 | settings 147 | .knn_regressor_settings 148 | .as_ref() 149 | .unwrap() 150 | .algorithm 151 | .clone(), 152 | ) 153 | .with_weight( 154 | settings 155 | .knn_regressor_settings 156 | .as_ref() 157 | .unwrap() 158 | .weight 159 | .clone(), 160 | ) 161 | .with_distance(Distances::hamming()), 162 | settings.get_kfolds(), 163 | settings.get_metric(), 164 | ) 165 | .unwrap(), 166 | }; 167 | 168 | (cv, Algorithm::KNNRegressor) 169 | } 170 | 171 | fn train(x: &DenseMatrix, y: &Vec, settings: &Settings) -> Vec { 172 | match settings.knn_regressor_settings.as_ref().unwrap().distance { 173 | Distance::Euclidean => { 174 | let params = SmartcoreKNNRegressorParameters::default() 175 | .with_k(settings.knn_regressor_settings.as_ref().unwrap().k) 176 | .with_algorithm( 177 | settings 178 | .knn_regressor_settings 179 | .as_ref() 180 | .unwrap() 181 | .algorithm 182 | .clone(), 183 | ) 184 | .with_weight( 185 | settings 186 | .knn_regressor_settings 187 | .as_ref() 188 | .unwrap() 189 | .weight 190 | .clone(), 191 | ) 192 | .with_distance(Distances::euclidian()); 193 | 194 | bincode::serialize(&KNNRegressor::fit(x, y, params).unwrap()).unwrap() 195 | } 196 | Distance::Manhattan => { 197 | let params = SmartcoreKNNRegressorParameters::default() 198 | .with_k(settings.knn_regressor_settings.as_ref().unwrap().k) 199 | .with_algorithm( 200 | settings 201 | .knn_regressor_settings 202 | .as_ref() 203 | .unwrap() 204 | .algorithm 205 | .clone(), 206 | ) 207 | .with_weight( 208 | settings 209 | .knn_regressor_settings 210 | .as_ref() 211 | .unwrap() 212 | .weight 213 | .clone(), 214 | ) 215 | .with_distance(Distances::manhattan()); 216 | 217 | bincode::serialize(&KNNRegressor::fit(x, y, params).unwrap()).unwrap() 218 | } 219 | Distance::Minkowski(p) => { 220 | let params = SmartcoreKNNRegressorParameters::default() 221 | .with_k(settings.knn_regressor_settings.as_ref().unwrap().k) 222 | .with_algorithm( 223 | settings 224 | .knn_regressor_settings 225 | .as_ref() 226 | .unwrap() 227 | .algorithm 228 | .clone(), 229 | ) 230 | .with_weight( 231 | settings 232 | .knn_regressor_settings 233 | .as_ref() 234 | .as_ref() 235 | .unwrap() 236 | .weight 237 | .clone(), 238 | ) 239 | .with_distance(Distances::minkowski(p)); 240 | 241 | bincode::serialize(&KNNRegressor::fit(x, y, params).unwrap()).unwrap() 242 | } 243 | Distance::Mahalanobis => { 244 | let params = SmartcoreKNNRegressorParameters::default() 245 | .with_k(settings.knn_regressor_settings.as_ref().unwrap().k) 246 | .with_algorithm( 247 | settings 248 | .knn_regressor_settings 249 | .as_ref() 250 | .unwrap() 251 | .algorithm 252 | .clone(), 253 | ) 254 | .with_weight( 255 | settings 256 | .knn_regressor_settings 257 | .as_ref() 258 | .as_ref() 259 | .unwrap() 260 | .weight 261 | .clone(), 262 | ) 263 | .with_distance(Distances::mahalanobis(x)); 264 | bincode::serialize(&KNNRegressor::fit(x, y, params).unwrap()).unwrap() 265 | } 266 | Distance::Hamming => { 267 | let params = SmartcoreKNNRegressorParameters::default() 268 | 
.with_k(settings.knn_regressor_settings.as_ref().unwrap().k) 269 | .with_algorithm( 270 | settings 271 | .knn_regressor_settings 272 | .as_ref() 273 | .unwrap() 274 | .algorithm 275 | .clone(), 276 | ) 277 | .with_weight( 278 | settings 279 | .knn_regressor_settings 280 | .as_ref() 281 | .as_ref() 282 | .unwrap() 283 | .weight 284 | .clone(), 285 | ) 286 | .with_distance(Distances::hamming()); 287 | 288 | bincode::serialize(&KNNRegressor::fit(x, y, params).unwrap()).unwrap() 289 | } 290 | } 291 | } 292 | 293 | fn predict(x: &DenseMatrix, final_model: &Vec, settings: &Settings) -> Vec { 294 | match settings.knn_regressor_settings.as_ref().unwrap().distance { 295 | Distance::Euclidean => { 296 | let model: KNNRegressor = 297 | bincode::deserialize(final_model).unwrap(); 298 | model.predict(x).unwrap() 299 | } 300 | Distance::Manhattan => { 301 | let model: KNNRegressor = 302 | bincode::deserialize(final_model).unwrap(); 303 | model.predict(x).unwrap() 304 | } 305 | Distance::Minkowski(_) => { 306 | let model: KNNRegressor = 307 | bincode::deserialize(final_model).unwrap(); 308 | model.predict(x).unwrap() 309 | } 310 | Distance::Mahalanobis => { 311 | let model: KNNRegressor>> = 312 | bincode::deserialize(final_model).unwrap(); 313 | model.predict(x).unwrap() 314 | } 315 | Distance::Hamming => { 316 | let model: KNNRegressor = bincode::deserialize(final_model).unwrap(); 317 | model.predict(x).unwrap() 318 | } 319 | } 320 | } 321 | } 322 | -------------------------------------------------------------------------------- /src/algorithms/lasso_regressor.rs: -------------------------------------------------------------------------------- 1 | //! LASSO regression algorithm. 2 | 3 | use smartcore::{ 4 | linalg::naive::dense_matrix::DenseMatrix, linear::lasso::Lasso, 5 | model_selection::cross_validate, model_selection::CrossValidationResult, 6 | }; 7 | 8 | use crate::{Algorithm, Settings}; 9 | 10 | /// The LASSO regression algorithm. 11 | /// 12 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/linear_model.html#lasso) 13 | /// for a more in-depth description of the algorithm. 
14 | pub struct LassoRegressorWrapper {} 15 | 16 | impl super::ModelWrapper for LassoRegressorWrapper { 17 | fn cv( 18 | x: &DenseMatrix, 19 | y: &Vec, 20 | settings: &Settings, 21 | ) -> (CrossValidationResult, Algorithm) { 22 | ( 23 | cross_validate( 24 | Lasso::fit, 25 | x, 26 | y, 27 | settings 28 | .lasso_settings 29 | .as_ref() 30 | .expect("No settings provided for the LASSO regression algorithm.") 31 | .clone(), 32 | settings.get_kfolds(), 33 | settings.get_metric(), 34 | ) 35 | .expect("Error during cross-validation."), 36 | Algorithm::Lasso, 37 | ) 38 | } 39 | 40 | fn train(x: &DenseMatrix, y: &Vec, settings: &Settings) -> Vec { 41 | bincode::serialize( 42 | &Lasso::fit( 43 | x, 44 | y, 45 | settings 46 | .lasso_settings 47 | .as_ref() 48 | .expect("No settings provided for the LASSO regression algorithm.") 49 | .clone(), 50 | ) 51 | .expect("Error during training."), 52 | ) 53 | .expect("Cannot serialize trained model.") 54 | } 55 | 56 | fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { 57 | let model: Lasso> = 58 | bincode::deserialize(final_model).expect("Cannot deserialize trained model."); 59 | model.predict(x).expect("Error during inference.") 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/algorithms/linear_regressor.rs: -------------------------------------------------------------------------------- 1 | //! Linear regression algorithm. 2 | 3 | use smartcore::{ 4 | linalg::naive::dense_matrix::DenseMatrix, 5 | linear::linear_regression::LinearRegression, 6 | model_selection::{cross_validate, CrossValidationResult}, 7 | }; 8 | 9 | use crate::{Algorithm, Settings}; 10 | 11 | /// The Linear regression algorithm. 12 | /// 13 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/linear_model.html#ordinary-least-squares) 14 | /// for a more in-depth description of the algorithm. 15 | pub struct LinearRegressorWrapper {} 16 | 17 | impl super::ModelWrapper for LinearRegressorWrapper { 18 | fn cv( 19 | x: &DenseMatrix, 20 | y: &Vec, 21 | settings: &Settings, 22 | ) -> (CrossValidationResult, Algorithm) { 23 | ( 24 | cross_validate( 25 | LinearRegression::fit, 26 | x, 27 | y, 28 | settings 29 | .linear_settings 30 | .as_ref() 31 | .expect("No settings provided for the linear regression algorithm.") 32 | .clone(), 33 | settings.get_kfolds(), 34 | settings.get_metric(), 35 | ) 36 | .expect("Error during cross-validation."), 37 | Algorithm::Linear, 38 | ) 39 | } 40 | 41 | fn train(x: &DenseMatrix, y: &Vec, settings: &Settings) -> Vec { 42 | bincode::serialize( 43 | &LinearRegression::fit( 44 | x, 45 | y, 46 | settings 47 | .linear_settings 48 | .as_ref() 49 | .expect("No settings provided for the linear regression algorithm.") 50 | .clone(), 51 | ) 52 | .expect("Error during training."), 53 | ) 54 | .expect("Cannot serialize trained model.") 55 | } 56 | 57 | fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { 58 | let model: LinearRegression> = 59 | bincode::deserialize(final_model).expect("Cannot deserialize trained model."); 60 | model.predict(x).expect("Error during inference.") 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/algorithms/logistic_regression.rs: -------------------------------------------------------------------------------- 1 | //! 
Logistic Regression 2 | 3 | use crate::{Algorithm, Settings}; 4 | use smartcore::{ 5 | linalg::naive::dense_matrix::DenseMatrix, linear::logistic_regression::LogisticRegression, 6 | model_selection::cross_validate, model_selection::CrossValidationResult, 7 | }; 8 | 9 | /// The Logistic Regression algorithm. 10 | /// 11 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression) 12 | /// for a more in-depth description of the algorithm. 13 | pub struct LogisticRegressionWrapper {} 14 | 15 | impl super::ModelWrapper for LogisticRegressionWrapper { 16 | fn cv( 17 | x: &DenseMatrix<f32>, 18 | y: &Vec<f32>, 19 | settings: &Settings, 20 | ) -> (CrossValidationResult<f32>, Algorithm) { 21 | ( 22 | cross_validate( 23 | LogisticRegression::fit, 24 | x, 25 | y, 26 | settings.logistic_settings.as_ref().unwrap().clone(), 27 | settings.get_kfolds(), 28 | settings.get_metric(), 29 | ) 30 | .unwrap(), 31 | Algorithm::LogisticRegression, 32 | ) 33 | } 34 | 35 | fn train(x: &DenseMatrix<f32>, y: &Vec<f32>, settings: &Settings) -> Vec<u8> { 36 | bincode::serialize( 37 | &LogisticRegression::fit(x, y, settings.logistic_settings.as_ref().unwrap().clone()) 38 | .unwrap(), 39 | ) 40 | .unwrap() 41 | } 42 | 43 | fn predict(x: &DenseMatrix<f32>, final_model: &Vec<u8>, _settings: &Settings) -> Vec<f32> { 44 | let model: LogisticRegression<f32, DenseMatrix<f32>> = 45 | bincode::deserialize(final_model).unwrap(); 46 | model.predict(x).unwrap() 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/algorithms/mod.rs: -------------------------------------------------------------------------------- 1 | //! # Algorithms 2 | //! 3 | //! This module contains the wrappers for the algorithms provided by this crate. 4 | //! The algorithms are all available through the common interface of the `ModelWrapper` trait. 5 | //! 6 | //! The available algorithms include: 7 | //! 8 | //! * Classification algorithms: 9 | //! - Logistic Regression 10 | //! - Random Forest Classifier 11 | //! - K-Nearest Neighbors Classifier 12 | //! - Decision Tree Classifier 13 | //! - Gaussian Naive Bayes Classifier 14 | //! - Categorical Naive Bayes Classifier 15 | //! - Support Vector Classifier 16 | //! 17 | //! * Regression algorithms: 18 | //! - Linear Regression 19 | //! - Elastic Net Regressor 20 | //! - Lasso Regressor 21 | //! - K-Nearest Neighbors Regressor 22 | //! - Ridge Regressor 23 | //! - Random Forest Regressor 24 | //! - Decision Tree Regressor 25 | //! - Support Vector Regressor 26 | 27 | mod linear_regressor; 28 | pub use linear_regressor::LinearRegressorWrapper; 29 | 30 | mod elastic_net_regressor; 31 | pub use elastic_net_regressor::ElasticNetRegressorWrapper; 32 | 33 | mod lasso_regressor; 34 | pub use lasso_regressor::LassoRegressorWrapper; 35 | 36 | mod knn_regressor; 37 | pub use knn_regressor::KNNRegressorWrapper; 38 | 39 | mod ridge_regressor; 40 | pub use ridge_regressor::RidgeRegressorWrapper; 41 | 42 | mod logistic_regression; 43 | pub use logistic_regression::LogisticRegressionWrapper; 44 | 45 | mod random_forest_classifier; 46 | pub use random_forest_classifier::RandomForestClassifierWrapper; 47 | 48 | mod random_forest_regressor; 49 | pub use random_forest_regressor::RandomForestRegressorWrapper; 50 | 51 | mod knn_classifier; 52 | pub use knn_classifier::KNNClassifierWrapper; 53 | 54 | mod decision_tree_classifier; 55 | pub use decision_tree_classifier::DecisionTreeClassifierWrapper; 56 | 57 | mod decision_tree_regressor; 58 | pub use decision_tree_regressor::DecisionTreeRegressorWrapper; 59 | 60 | mod gaussian_naive_bayes_classifier; 61 | pub use gaussian_naive_bayes_classifier::GaussianNaiveBayesClassifierWrapper; 62 | 63 | mod categorical_naive_bayes_classifier; 64 | pub use categorical_naive_bayes_classifier::CategoricalNaiveBayesClassifierWrapper; 65 | 66 | mod support_vector_classifier; 67 | pub use support_vector_classifier::SupportVectorClassifierWrapper; 68 | 69 | mod support_vector_regressor; 70 | pub use support_vector_regressor::SupportVectorRegressorWrapper; 71 | 72 | use crate::{Algorithm, Settings}; 73 | use smartcore::linalg::naive::dense_matrix::DenseMatrix; 74 | use smartcore::model_selection::CrossValidationResult; 75 | 76 | use crate::settings::FinalModel; 77 | use std::time::{Duration, Instant}; 78 | 79 | /// Trait for wrapping models 80 | pub trait ModelWrapper { 81 | /// Perform cross-validation and return the results 82 | /// 83 | /// # Arguments 84 | /// 85 | /// * `x` - The input data 86 | /// * `y` - The output data 87 | /// * `settings` - The settings for the model 88 | /// 89 | /// # Returns 90 | /// 91 | /// * `CrossValidationResult` - The cross-validation results 92 | /// * `Algorithm` - The algorithm used 93 | /// * `Duration` - The time taken to perform the cross-validation 94 | /// * `Vec<u8>` - The final model 95 | fn cv_model( 96 | x: &DenseMatrix<f32>, 97 | y: &Vec<f32>, 98 | settings: &Settings, 99 | ) -> (CrossValidationResult<f32>, Algorithm, Duration, Vec<u8>) { 100 | let start = Instant::now(); 101 | let results = Self::cv(x, y, settings); 102 | let end = Instant::now(); 103 | ( 104 | results.0, 105 | results.1, 106 | end.duration_since(start), 107 | match settings.final_model_approach { 108 | FinalModel::None => vec![], 109 | _ => Self::train(x, y, settings), 110 | }, 111 | ) 112 | } 113 | 114 | /// Perform cross-validation 115 | #[allow(clippy::ptr_arg)] 116 | fn cv( 117 | x: &DenseMatrix<f32>, 118 | y: &Vec<f32>, 119 | settings: &Settings, 120 | ) -> (CrossValidationResult<f32>, Algorithm); 121 | 122 | /// Train a model 123 | #[allow(clippy::ptr_arg)] 124 | fn train(x: &DenseMatrix<f32>, y: &Vec<f32>, settings: &Settings) -> Vec<u8>; 125 | 126 | /// Perform a prediction 127 | #[allow(clippy::ptr_arg)] 128 | fn predict(x: &DenseMatrix<f32>, final_model: &Vec<u8>, settings: &Settings) -> Vec<f32>; 129 | } 130 | -------------------------------------------------------------------------------- /src/algorithms/random_forest_classifier.rs: -------------------------------------------------------------------------------- 1 | //!
Random Forest Classifier 2 | 3 | use smartcore::{ 4 | ensemble::random_forest_classifier::RandomForestClassifier, 5 | linalg::naive::dense_matrix::DenseMatrix, 6 | model_selection::{cross_validate, CrossValidationResult}, 7 | }; 8 | 9 | use crate::{Algorithm, Settings}; 10 | 11 | /// The Random Forest Classifier. 12 | /// 13 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/ensemble.html#random-forests) 14 | /// for a more in-depth description of the algorithm. 15 | pub struct RandomForestClassifierWrapper {} 16 | 17 | impl super::ModelWrapper for RandomForestClassifierWrapper { 18 | fn cv( 19 | x: &DenseMatrix, 20 | y: &Vec, 21 | settings: &Settings, 22 | ) -> (CrossValidationResult, Algorithm) { 23 | ( 24 | cross_validate( 25 | RandomForestClassifier::fit, 26 | x, 27 | y, 28 | settings 29 | .random_forest_classifier_settings 30 | .as_ref() 31 | .unwrap() 32 | .clone(), 33 | settings.get_kfolds(), 34 | settings.get_metric(), 35 | ) 36 | .unwrap(), 37 | Algorithm::RandomForestClassifier, 38 | ) 39 | } 40 | 41 | fn train(x: &DenseMatrix, y: &Vec, settings: &Settings) -> Vec { 42 | bincode::serialize( 43 | &RandomForestClassifier::fit( 44 | x, 45 | y, 46 | settings 47 | .random_forest_classifier_settings 48 | .as_ref() 49 | .unwrap() 50 | .clone(), 51 | ) 52 | .unwrap(), 53 | ) 54 | .unwrap() 55 | } 56 | 57 | fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { 58 | let model: RandomForestClassifier = bincode::deserialize(final_model).unwrap(); 59 | model.predict(x).unwrap() 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/algorithms/random_forest_regressor.rs: -------------------------------------------------------------------------------- 1 | //! Random Forest Regressor 2 | 3 | use smartcore::{ 4 | ensemble::random_forest_regressor::RandomForestRegressor, 5 | linalg::naive::dense_matrix::DenseMatrix, 6 | model_selection::{cross_validate, CrossValidationResult}, 7 | }; 8 | 9 | use crate::{Algorithm, Settings}; 10 | 11 | /// The Random Forest Regressor. 12 | /// 13 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/ensemble.html#random-forests) 14 | /// for a more in-depth description of the algorithm. 15 | pub struct RandomForestRegressorWrapper {} 16 | 17 | impl super::ModelWrapper for RandomForestRegressorWrapper { 18 | fn cv( 19 | x: &DenseMatrix, 20 | y: &Vec, 21 | settings: &Settings, 22 | ) -> (CrossValidationResult, Algorithm) { 23 | ( 24 | cross_validate( 25 | RandomForestRegressor::fit, 26 | x, 27 | y, 28 | settings 29 | .random_forest_regressor_settings 30 | .as_ref() 31 | .unwrap() 32 | .clone(), 33 | settings.get_kfolds(), 34 | settings.get_metric(), 35 | ) 36 | .unwrap(), 37 | Algorithm::RandomForestRegressor, 38 | ) 39 | } 40 | 41 | fn train(x: &DenseMatrix, y: &Vec, settings: &Settings) -> Vec { 42 | bincode::serialize( 43 | &RandomForestRegressor::fit( 44 | x, 45 | y, 46 | settings 47 | .random_forest_regressor_settings 48 | .as_ref() 49 | .unwrap() 50 | .clone(), 51 | ) 52 | .unwrap(), 53 | ) 54 | .unwrap() 55 | } 56 | 57 | fn predict(x: &DenseMatrix, final_model: &Vec, _settings: &Settings) -> Vec { 58 | let model: RandomForestRegressor = bincode::deserialize(final_model).unwrap(); 59 | model.predict(x).unwrap() 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/algorithms/ridge_regressor.rs: -------------------------------------------------------------------------------- 1 | //! 
/src/algorithms/ridge_regressor.rs: -------------------------------------------------------------------------------- 1 | //! Ridge regression algorithm. 2 | 3 | use smartcore::{ 4 | linalg::naive::dense_matrix::DenseMatrix, linear::ridge_regression::RidgeRegression, 5 | model_selection::cross_validate, model_selection::CrossValidationResult, 6 | }; 7 | 8 | use crate::{Algorithm, Settings}; 9 | 10 | /// The Ridge regression algorithm. 11 | /// 12 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/linear_model.html#ridge-regression) 13 | /// for a more in-depth description of the algorithm. 14 | pub struct RidgeRegressorWrapper {} 15 | 16 | impl super::ModelWrapper for RidgeRegressorWrapper { 17 | fn cv( 18 | x: &DenseMatrix<f32>, 19 | y: &Vec<f32>, 20 | settings: &Settings, 21 | ) -> (CrossValidationResult<f32>, Algorithm) { 22 | ( 23 | cross_validate( 24 | RidgeRegression::fit, 25 | x, 26 | y, 27 | settings.ridge_settings.as_ref().unwrap().clone(), 28 | settings.get_kfolds(), 29 | settings.get_metric(), 30 | ) 31 | .unwrap(), 32 | Algorithm::Ridge, 33 | ) 34 | } 35 | 36 | fn train(x: &DenseMatrix<f32>, y: &Vec<f32>, settings: &Settings) -> Vec<u8> { 37 | bincode::serialize( 38 | &RidgeRegression::fit(x, y, settings.ridge_settings.as_ref().unwrap().clone()).unwrap(), 39 | ) 40 | .unwrap() 41 | } 42 | 43 | fn predict(x: &DenseMatrix<f32>, final_model: &Vec<u8>, _settings: &Settings) -> Vec<f32> { 44 | let model: RidgeRegression<f32, DenseMatrix<f32>> = 45 | bincode::deserialize(final_model).unwrap(); 46 | model.predict(x).unwrap() 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/algorithms/support_vector_classifier.rs: -------------------------------------------------------------------------------- 1 | //! Support Vector Classifier 2 | 3 | use smartcore::{ 4 | linalg::naive::dense_matrix::DenseMatrix, 5 | model_selection::cross_validate, 6 | model_selection::CrossValidationResult, 7 | svm::{ 8 | svc::{SVCParameters as SmartcoreSVCParameters, SVC}, 9 | Kernels, LinearKernel, PolynomialKernel, RBFKernel, SigmoidKernel, 10 | }, 11 | }; 12 | 13 | use crate::{Algorithm, Kernel, Settings}; 14 | 15 | /// The Support Vector Classifier. 16 | /// 17 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/svm.html#svm-classification) 18 | /// for a more in-depth description of the algorithm. 19 | pub struct SupportVectorClassifierWrapper {} 20 | 21 | impl super::ModelWrapper for SupportVectorClassifierWrapper { 22 | fn cv( 23 | x: &DenseMatrix<f32>, 24 | y: &Vec<f32>, 25 | settings: &Settings, 26 | ) -> (CrossValidationResult<f32>, Algorithm) { 27 | let cv = match settings.svc_settings.as_ref().unwrap().kernel { 28 | Kernel::Linear => cross_validate( 29 | SVC::fit, 30 | x, 31 | y, 32 | SmartcoreSVCParameters::default() 33 | .with_tol(settings.svc_settings.as_ref().unwrap().tol) 34 | .with_c(settings.svc_settings.as_ref().unwrap().c) 35 | .with_epoch(settings.svc_settings.as_ref().unwrap().epoch) 36 | .with_kernel(Kernels::linear()), 37 | settings.get_kfolds(), 38 | settings.get_metric(), 39 | ) 40 | .unwrap(), 41 | Kernel::Polynomial(degree, gamma, coef) => cross_validate( 42 | SVC::fit, 43 | x, 44 | y, 45 | SmartcoreSVCParameters::default() 46 | .with_tol(settings.svc_settings.as_ref().unwrap().tol) 47 | .with_c(settings.svc_settings.as_ref().unwrap().c) 48 | .with_epoch(settings.svc_settings.as_ref().unwrap().epoch) 49 | .with_kernel(Kernels::polynomial(degree, gamma, coef)), 50 | settings.get_kfolds(), 51 | settings.get_metric(), 52 | ) 53 | .unwrap(), 54 | Kernel::RBF(gamma) => cross_validate( 55 | SVC::fit, 56 | x, 57 | y, 58 | SmartcoreSVCParameters::default() 59 | .with_tol(settings.svc_settings.as_ref().unwrap().tol) 60 | .with_c(settings.svc_settings.as_ref().unwrap().c) 61 | .with_epoch(settings.svc_settings.as_ref().unwrap().epoch) 62 | .with_kernel(Kernels::rbf(gamma)), 63 | settings.get_kfolds(), 64 | settings.get_metric(), 65 | ) 66 | .unwrap(), 67 | Kernel::Sigmoid(gamma, coef) => cross_validate( 68 | SVC::fit, 69 | x, 70 | y, 71 | SmartcoreSVCParameters::default() 72 | .with_tol(settings.svc_settings.as_ref().unwrap().tol) 73 | .with_c(settings.svc_settings.as_ref().unwrap().c) 74 | .with_epoch(settings.svc_settings.as_ref().unwrap().epoch) 75 | .with_kernel(Kernels::sigmoid(gamma, coef)), 76 | settings.get_kfolds(), 77 | settings.get_metric(), 78 | ) 79 | .unwrap(), 80 | }; 81 | (cv, Algorithm::SVC) 82 | } 83 | 84 | fn train(x: &DenseMatrix<f32>, y: &Vec<f32>, settings: &Settings) -> Vec<u8> { 85 | match settings.svc_settings.as_ref().unwrap().kernel { 86 | Kernel::Linear => { 87 | let params = SmartcoreSVCParameters::default() 88 | .with_tol(settings.svc_settings.as_ref().unwrap().tol) 89 | .with_c(settings.svc_settings.as_ref().unwrap().c) 90 | .with_epoch(settings.svc_settings.as_ref().unwrap().epoch) 91 | .with_kernel(Kernels::linear()); 92 | 93 | bincode::serialize(&SVC::fit(x, y, params).unwrap()).unwrap() 94 | } 95 | Kernel::Polynomial(degree, gamma, coef) => { 96 | let params = SmartcoreSVCParameters::default() 97 | .with_tol(settings.svc_settings.as_ref().unwrap().tol) 98 | .with_c(settings.svc_settings.as_ref().unwrap().c) 99 | .with_epoch(settings.svc_settings.as_ref().unwrap().epoch) 100 | .with_kernel(Kernels::polynomial(degree, gamma, coef)); 101 | 102 | bincode::serialize(&SVC::fit(x, y, params).unwrap()).unwrap() 103 | } 104 | Kernel::RBF(gamma) => { 105 | let params = SmartcoreSVCParameters::default() 106 | .with_tol(settings.svc_settings.as_ref().unwrap().tol) 107 | .with_c(settings.svc_settings.as_ref().unwrap().c) 108 | .with_epoch(settings.svc_settings.as_ref().unwrap().epoch) 109 | .with_kernel(Kernels::rbf(gamma)); 110 | 111 | bincode::serialize(&SVC::fit(x, y, params).unwrap()).unwrap() 112 | } 113 | Kernel::Sigmoid(gamma, coef) => { 114 | let params = SmartcoreSVCParameters::default() 115 | .with_tol(settings.svc_settings.as_ref().unwrap().tol) 116 |
.with_c(settings.svc_settings.as_ref().unwrap().c) 117 | .with_epoch(settings.svc_settings.as_ref().unwrap().epoch) 118 | .with_kernel(Kernels::sigmoid(gamma, coef)); 119 | 120 | bincode::serialize(&SVC::fit(x, y, params).unwrap()).unwrap() 121 | } 122 | } 123 | } 124 | 125 | fn predict(x: &DenseMatrix<f32>, final_model: &Vec<u8>, settings: &Settings) -> Vec<f32> { 126 | match settings.svc_settings.as_ref().unwrap().kernel { 127 | Kernel::Linear => { 128 | let model: SVC<f32, DenseMatrix<f32>, LinearKernel> = 129 | bincode::deserialize(final_model).unwrap(); 130 | model.predict(x).unwrap() 131 | } 132 | Kernel::Polynomial(_, _, _) => { 133 | let model: SVC<f32, DenseMatrix<f32>, PolynomialKernel<f32>> = 134 | bincode::deserialize(final_model).unwrap(); 135 | model.predict(x).unwrap() 136 | } 137 | Kernel::RBF(_) => { 138 | let model: SVC<f32, DenseMatrix<f32>, RBFKernel<f32>> = 139 | bincode::deserialize(final_model).unwrap(); 140 | model.predict(x).unwrap() 141 | } 142 | Kernel::Sigmoid(_, _) => { 143 | let model: SVC<f32, DenseMatrix<f32>, SigmoidKernel<f32>> = 144 | bincode::deserialize(final_model).unwrap(); 145 | model.predict(x).unwrap() 146 | } 147 | } 148 | } 149 | } 150 | --------------------------------------------------------------------------------
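The four near-identical arms in `cv`, `train`, and `predict` above exist because smartcore's `SVCParameters` and `SVC` are generic over the kernel type: each `Kernels::...` constructor yields a different concrete type, so the kernel must also be named when deserializing. A small sketch of that consequence for one kernel (smartcore 0.2 API, toy data):

```rust
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::svm::{
    svc::{SVCParameters, SVC},
    Kernels, RBFKernel,
};

fn main() {
    let x = DenseMatrix::from_2d_vec(&vec![
        vec![0.0_f32, 0.0],
        vec![0.2, 0.1],
        vec![0.9, 1.0],
        vec![1.0, 0.8],
    ]);
    let y = vec![0.0_f32, 0.0, 1.0, 1.0];

    // The kernel choice is part of the parameter type...
    let params = SVCParameters::default().with_kernel(Kernels::rbf(0.5));

    // ...and therefore part of the model type, which is why `predict` above
    // must match on the kernel before giving bincode a type annotation.
    let model: SVC<f32, DenseMatrix<f32>, RBFKernel<f32>> = SVC::fit(&x, &y, params).unwrap();
    let _labels = model.predict(&x).unwrap();
}
```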
/src/algorithms/support_vector_regressor.rs: -------------------------------------------------------------------------------- 1 | //! Support Vector Regressor 2 | 3 | use smartcore::{ 4 | linalg::naive::dense_matrix::DenseMatrix, 5 | model_selection::cross_validate, 6 | model_selection::CrossValidationResult, 7 | svm::{ 8 | svr::{SVRParameters as SmartcoreSVRParameters, SVR}, 9 | Kernels, LinearKernel, PolynomialKernel, RBFKernel, SigmoidKernel, 10 | }, 11 | }; 12 | 13 | use crate::{Algorithm, Kernel, Settings}; 14 | 15 | /// The Support Vector Regressor. 16 | /// 17 | /// See [scikit-learn's user guide](https://scikit-learn.org/stable/modules/svm.html#svm-regression) 18 | /// for a more in-depth description of the algorithm. 19 | pub struct SupportVectorRegressorWrapper {} 20 | 21 | impl super::ModelWrapper for SupportVectorRegressorWrapper { 22 | fn cv( 23 | x: &DenseMatrix<f32>, 24 | y: &Vec<f32>, 25 | settings: &Settings, 26 | ) -> (CrossValidationResult<f32>, Algorithm) { 27 | let cv = match settings.svr_settings.as_ref().unwrap().kernel { 28 | Kernel::Linear => cross_validate( 29 | SVR::fit, 30 | x, 31 | y, 32 | SmartcoreSVRParameters::default() 33 | .with_tol(settings.svr_settings.as_ref().unwrap().tol) 34 | .with_c(settings.svr_settings.as_ref().unwrap().c) 35 | .with_eps(settings.svr_settings.as_ref().unwrap().eps) 36 | .with_kernel(Kernels::linear()), 37 | settings.get_kfolds(), 38 | settings.get_metric(), 39 | ) 40 | .unwrap(), 41 | Kernel::Polynomial(degree, gamma, coef) => cross_validate( 42 | SVR::fit, 43 | x, 44 | y, 45 | SmartcoreSVRParameters::default() 46 | .with_tol(settings.svr_settings.as_ref().unwrap().tol) 47 | .with_c(settings.svr_settings.as_ref().unwrap().c) 48 | .with_eps(settings.svr_settings.as_ref().unwrap().eps) 49 | .with_kernel(Kernels::polynomial(degree, gamma, coef)), 50 | settings.get_kfolds(), 51 | settings.get_metric(), 52 | ) 53 | .unwrap(), 54 | Kernel::RBF(gamma) => cross_validate( 55 | SVR::fit, 56 | x, 57 | y, 58 | SmartcoreSVRParameters::default() 59 | .with_tol(settings.svr_settings.as_ref().unwrap().tol) 60 | .with_c(settings.svr_settings.as_ref().unwrap().c) 61 | .with_eps(settings.svr_settings.as_ref().unwrap().eps) 62 | .with_kernel(Kernels::rbf(gamma)), 63 | settings.get_kfolds(), 64 | settings.get_metric(), 65 | ) 66 | .unwrap(), 67 | Kernel::Sigmoid(gamma, coef) => cross_validate( 68 | SVR::fit, 69 | x, 70 | y, 71 | SmartcoreSVRParameters::default() 72 | .with_tol(settings.svr_settings.as_ref().unwrap().tol) 73 | .with_c(settings.svr_settings.as_ref().unwrap().c) 74 | .with_eps(settings.svr_settings.as_ref().unwrap().eps) 75 | .with_kernel(Kernels::sigmoid(gamma, coef)), 76 | settings.get_kfolds(), 77 | settings.get_metric(), 78 | ) 79 | .unwrap(), 80 | }; 81 | (cv, Algorithm::SVR) 82 | } 83 | 84 | fn train(x: &DenseMatrix<f32>, y: &Vec<f32>, settings: &Settings) -> Vec<u8> { 85 | match settings.svr_settings.as_ref().unwrap().kernel { 86 | Kernel::Linear => { 87 | let params = SmartcoreSVRParameters::default() 88 | .with_tol(settings.svr_settings.as_ref().unwrap().tol) 89 | .with_c(settings.svr_settings.as_ref().unwrap().c) 90 | .with_eps(settings.svr_settings.as_ref().unwrap().eps) 91 | .with_kernel(Kernels::linear()); 92 | 93 | bincode::serialize(&SVR::fit(x, y, params).unwrap()).unwrap() 94 | } 95 | Kernel::Polynomial(degree, gamma, coef) => { 96 | let params = SmartcoreSVRParameters::default() 97 | .with_tol(settings.svr_settings.as_ref().unwrap().tol) 98 | .with_c(settings.svr_settings.as_ref().unwrap().c) 99 | .with_eps(settings.svr_settings.as_ref().unwrap().eps) 100 | .with_kernel(Kernels::polynomial(degree, gamma, coef)); 101 | 102 | bincode::serialize(&SVR::fit(x, y, params).unwrap()).unwrap() 103 | } 104 | Kernel::RBF(gamma) => { 105 | let params = SmartcoreSVRParameters::default() 106 | .with_tol(settings.svr_settings.as_ref().unwrap().tol) 107 | .with_c(settings.svr_settings.as_ref().unwrap().c) 108 | .with_eps(settings.svr_settings.as_ref().unwrap().eps) 109 | .with_kernel(Kernels::rbf(gamma)); 110 | 111 | bincode::serialize(&SVR::fit(x, y, params).unwrap()).unwrap() 112 | } 113 | Kernel::Sigmoid(gamma, coef) => { 114 | let params = SmartcoreSVRParameters::default() 115 | .with_tol(settings.svr_settings.as_ref().unwrap().tol) 116 | .with_c(settings.svr_settings.as_ref().unwrap().c) 117 | .with_eps(settings.svr_settings.as_ref().unwrap().eps) 118 | .with_kernel(Kernels::sigmoid(gamma, coef)); 119 | 120 | bincode::serialize(&SVR::fit(x, y, params).unwrap()).unwrap() 121 | } 122 | } 123 | } 124 | 125 | fn predict(x: &DenseMatrix<f32>, final_model: &Vec<u8>, settings: &Settings) -> Vec<f32> { 126 | match settings.svr_settings.as_ref().unwrap().kernel { 127 | Kernel::Linear => { 128 | let model: SVR<f32, DenseMatrix<f32>, LinearKernel> = 129 | bincode::deserialize(final_model).unwrap(); 130 | model.predict(x).unwrap() 131 | } 132 | Kernel::Polynomial(_, _, _) => { 133 | let model: SVR<f32, DenseMatrix<f32>, PolynomialKernel<f32>> = 134 | bincode::deserialize(final_model).unwrap(); 135 | model.predict(x).unwrap() 136 | } 137 | Kernel::RBF(_) => { 138 | let model: SVR<f32, DenseMatrix<f32>, RBFKernel<f32>> = 139 | bincode::deserialize(final_model).unwrap(); 140 | model.predict(x).unwrap() 141 | } 142 | Kernel::Sigmoid(_, _) => { 143 | let model: SVR<f32, DenseMatrix<f32>, SigmoidKernel<f32>> = 144 | bincode::deserialize(final_model).unwrap(); 145 | model.predict(x).unwrap() 146 | } 147 | } 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /src/cookbook.rs: -------------------------------------------------------------------------------- 1 | //! # A cookbook of common AutoML tasks 2 | //! ## Basic Regression 3 | //! ```rust 4 | #![doc = include_str!("../examples/minimal_regression.rs")] 5 | //! ``` 6 | //! ## Basic Classification 7 | //! ```rust 8 | #![doc = include_str!("../examples/minimal_classification.rs")] 9 | //! ``` 10 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![deny(clippy::correctness)] 2 | #![warn( 3 | clippy::all, 4 | clippy::suspicious, 5 | clippy::complexity, 6 | clippy::perf, 7 | clippy::style, 8 | clippy::pedantic, 9 | clippy::nursery, 10 | clippy::missing_docs_in_private_items 11 | )] 12 | #![allow(clippy::module_name_repetitions, clippy::too_many_lines)] 13 | #![warn(missing_docs, rustdoc::missing_doc_code_examples)] 14 | #![doc = include_str!("../README.md")] 15 | 16 | pub mod settings; 17 | pub use settings::Settings; 18 | use settings::{Algorithm, Distance, FinalModel, Kernel, Metric, PreProcessing}; 19 | 20 | pub mod cookbook; 21 | 22 | mod algorithms; 23 | use algorithms::{ 24 | CategoricalNaiveBayesClassifierWrapper, DecisionTreeClassifierWrapper, 25 | DecisionTreeRegressorWrapper, ElasticNetRegressorWrapper, GaussianNaiveBayesClassifierWrapper, 26 | KNNClassifierWrapper, KNNRegressorWrapper, LassoRegressorWrapper, LinearRegressorWrapper, 27 | LogisticRegressionWrapper, ModelWrapper, RandomForestClassifierWrapper, 28 | RandomForestRegressorWrapper, RidgeRegressorWrapper, SupportVectorClassifierWrapper, 29 | SupportVectorRegressorWrapper, 30 | }; 31 | 32 | mod utils; 33 | use utils::elementwise_multiply; 34 | 35 | use itertools::Itertools; 36 | use smartcore::{ 37 | dataset::Dataset, 38 | decomposition::{ 39 | pca::{PCAParameters, PCA}, 40 | svd::{SVDParameters, SVD}, 41 | }, 42 | linalg::{naive::dense_matrix::DenseMatrix, BaseMatrix}, 43 | model_selection::{train_test_split, CrossValidationResult}, 44 | }; 45 | use std::{ 46 | cmp::Ordering::Equal, 47 | fmt::{Display, Formatter}, 48 | io::{Read, Write}, 49 | time::Duration, 50 | }; 51 | 52 | #[cfg(any(feature = "nd"))] 53 | use ndarray::{Array1, Array2}; 54 | 55 | #[cfg(any(feature = "csv"))] 56 | use { 57 | polars::prelude::{DataFrame, Float32Type}, 58 | utils::validate_and_read, 59 | }; 60 | 61 | use { 62 | comfy_table::{
modifiers::UTF8_SOLID_INNER_BORDERS, presets::UTF8_FULL, Attribute, Cell, Table, 64 | }, 65 | humantime::format_duration, 66 | }; 67 | 68 | /// This trait must be implemented for any type passed to `SupervisedModel::new` as data. 69 | pub trait IntoSupervisedData { 70 | /// Converts the struct into paired features and labels 71 | fn to_supervised_data(self) -> (DenseMatrix<f32>, Vec<f32>); 72 | } 73 | 74 | /// Types that implement this trait can be paired in a tuple with a type implementing `IntoLabels` to 75 | /// automatically satisfy `IntoSupervisedData`. This trait is also required for data that is passed to `predict`. 76 | pub trait IntoFeatures { 77 | /// Converts the struct into a dense matrix of features 78 | fn to_dense_matrix(self) -> DenseMatrix<f32>; 79 | } 80 | 81 | /// Types that implement this trait can be paired in a tuple with a type implementing `IntoFeatures` 82 | /// to automatically satisfy `IntoSupervisedData`. 83 | pub trait IntoLabels { 84 | /// Converts the struct into a vector of labels 85 | fn into_vec(self) -> Vec<f32>; 86 | } 87 | 88 | impl IntoSupervisedData for Dataset<f32, f32> { 89 | fn to_supervised_data(self) -> (DenseMatrix<f32>, Vec<f32>) { 90 | ( 91 | DenseMatrix::from_array(self.num_samples, self.num_features, &self.data), 92 | self.target, 93 | ) 94 | } 95 | } 96 | 97 | #[cfg(any(feature = "csv"))] 98 | impl IntoSupervisedData for (&str, usize) { 99 | fn to_supervised_data(self) -> (DenseMatrix<f32>, Vec<f32>) { 100 | let (filepath, target_index) = self; 101 | let df = validate_and_read(filepath); 102 | 103 | // Get target variables 104 | let target_column_name = df.get_column_names()[target_index]; 105 | let series = df.column(target_column_name).unwrap().clone(); 106 | let target_df = DataFrame::new(vec![series]).unwrap(); 107 | let ndarray = target_df.to_ndarray::<Float32Type>().unwrap(); 108 | let y = ndarray.into_raw_vec(); 109 | 110 | // Get the rest of the data 111 | let features = df.drop(target_column_name).unwrap(); 112 | let (height, width) = features.shape(); 113 | let ndarray = features.to_ndarray::<Float32Type>().unwrap(); 114 | let x = DenseMatrix::from_array(height, width, ndarray.as_slice().unwrap()); 115 | (x, y) 116 | } 117 | } 118 | 119 | #[cfg(any(feature = "csv"))] 120 | impl IntoFeatures for &str { 121 | fn to_dense_matrix(self) -> DenseMatrix<f32> { 122 | let df = validate_and_read(self); 123 | 124 | // Convert the full frame to a feature matrix 125 | let (height, width) = df.shape(); 126 | let ndarray = df.to_ndarray::<Float32Type>().unwrap(); 127 | DenseMatrix::from_array(height, width, ndarray.as_slice().unwrap()) 128 | } 129 | } 130 | 131 | impl<X, Y> IntoSupervisedData for (X, Y) 132 | where 133 | X: IntoFeatures, 134 | Y: IntoLabels, 135 | { 136 | fn to_supervised_data(self) -> (DenseMatrix<f32>, Vec<f32>) { 137 | (self.0.to_dense_matrix(), self.1.into_vec()) 138 | } 139 | } 140 | 141 | impl IntoFeatures for Vec<Vec<f32>> { 142 | fn to_dense_matrix(self) -> DenseMatrix<f32> { 143 | DenseMatrix::from_2d_vec(&self) 144 | } 145 | } 146 | 147 | impl IntoLabels for Vec<f32> { 148 | fn into_vec(self) -> Vec<f32> { 149 | self 150 | } 151 | } 152 | 153 | #[cfg(any(feature = "nd"))] 154 | impl IntoFeatures for Array2<f32> { 155 | fn to_dense_matrix(self) -> DenseMatrix<f32> { 156 | DenseMatrix::from_array(self.shape()[0], self.shape()[1], self.as_slice().unwrap()) 157 | } 158 | } 159 | 160 | #[cfg(any(feature = "nd"))] 161 | impl IntoLabels for Array1<f32> { 162 | fn into_vec(self) -> Vec<f32> { 163 | self.to_vec() 164 | } 165 | } 166 |
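Thanks to the blanket tuple impl above, feeding a custom container into `SupervisedModel::new` only requires `IntoFeatures` (plus `IntoLabels` for a custom label type). A sketch with a hypothetical row-oriented struct — the type is illustrative, not crate API:

```rust
use automl::IntoFeatures;
use smartcore::linalg::naive::dense_matrix::DenseMatrix;

/// Hypothetical user-defined container of feature rows.
struct Measurements {
    rows: Vec<Vec<f32>>,
}

impl IntoFeatures for Measurements {
    fn to_dense_matrix(self) -> DenseMatrix<f32> {
        DenseMatrix::from_2d_vec(&self.rows)
    }
}

// A (Measurements, Vec<f32>) tuple now satisfies IntoSupervisedData:
// let model = automl::SupervisedModel::new((measurements, labels), settings);
```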
167 | /// Trains and compares supervised models 168 | #[derive(serde::Serialize, serde::Deserialize)] 169 | pub struct SupervisedModel { 170 | /// Settings for the model. 171 | settings: Settings, 172 | /// The training data. 173 | x_train: DenseMatrix<f32>, 174 | /// The training labels. 175 | y_train: Vec<f32>, 176 | /// The validation data. 177 | x_val: DenseMatrix<f32>, 178 | /// The validation labels. 179 | y_val: Vec<f32>, 180 | /// The number of classes in the data. 181 | number_of_classes: usize, 182 | /// The results of the model comparison. 183 | comparison: Vec<Model>, 184 | /// The final model. 185 | metamodel: Model, 186 | /// PCA model for preprocessing. 187 | preprocessing_pca: Option<PCA<f32, DenseMatrix<f32>>>, 188 | /// SVD model for preprocessing. 189 | preprocessing_svd: Option<SVD<f32, DenseMatrix<f32>>>, 190 | } 191 | 192 | impl SupervisedModel { 193 | /// Create a new supervised model. This function accepts various types of syntax. For instance, it will work for vectors: 194 | /// ``` 195 | /// # use automl::{SupervisedModel, Settings}; 196 | /// let model = automl::SupervisedModel::new( 197 | /// (vec![vec![1.0; 5]; 5], 198 | /// vec![1.0; 5]), 199 | /// automl::Settings::default_regression(), 200 | /// ); 201 | /// ``` 202 | /// It also works for some ndarray datatypes: 203 | /// ``` 204 | /// # use automl::{SupervisedModel, Settings}; 205 | /// #[cfg(any(feature = "nd"))] 206 | /// let model = SupervisedModel::new( 207 | /// ( 208 | /// ndarray::arr2(&[[1.0, 2.0], [3.0, 4.0]]), 209 | /// ndarray::arr1(&[1.0, 2.0]) 210 | /// ), 211 | /// automl::Settings::default_regression(), 212 | /// ); 213 | /// ``` 214 | /// But you can also create a new supervised model from a [smartcore toy dataset](https://docs.rs/smartcore/0.2.0/smartcore/dataset/index.html) 215 | /// ``` 216 | /// # use automl::{SupervisedModel, Settings}; 217 | /// let model = SupervisedModel::new( 218 | /// smartcore::dataset::diabetes::load_dataset(), 219 | /// Settings::default_regression() 220 | /// ); 221 | /// ``` 222 | /// You can even create a new supervised model directly from a CSV! 223 | /// ``` 224 | /// # use automl::{SupervisedModel, Settings}; 225 | /// #[cfg(any(feature = "csv"))] 226 | /// let model = SupervisedModel::new( 227 | /// ("data/diabetes.csv", 10), 228 | /// Settings::default_regression() 229 | /// ); 230 | /// ``` 231 | /// And that CSV can even come from a URL 232 | /// ``` 233 | /// # use automl::{SupervisedModel, Settings}; 234 | /// #[cfg(any(feature = "csv"))] 235 | /// let mut model = automl::SupervisedModel::new( 236 | /// ( 237 | /// "https://raw.githubusercontent.com/plotly/datasets/master/diabetes.csv", 238 | /// 8, 239 | /// ), 240 | /// Settings::default_regression(), 241 | /// ); /// ``` 242 | pub fn new<D>(data: D, settings: Settings) -> Self 243 | where 244 | D: IntoSupervisedData, 245 | { 246 | let (x, y) = data.to_supervised_data(); 247 | Self::build(x, y, settings) 248 | } 249 | 250 | /// Load the supervised model from a file saved previously 251 | /// ``` 252 | /// # use automl::{SupervisedModel, Settings}; 253 | /// # let mut model = SupervisedModel::new( 254 | /// # smartcore::dataset::diabetes::load_dataset(), 255 | /// # Settings::default_regression() 256 | /// # ); 257 | /// # model.save("tests/load_that_model.aml"); 258 | /// let model = SupervisedModel::new_from_file("tests/load_that_model.aml"); 259 | /// # std::fs::remove_file("tests/load_that_model.aml"); 260 | /// ``` 261 | #[must_use] 262 | pub fn new_from_file(file_name: &str) -> Self { 263 | let mut buf: Vec<u8> = Vec::new(); 264 | std::fs::File::open(file_name) 265 | .and_then(|mut f| f.read_to_end(&mut buf)) 266 | .expect("Cannot load model from file."); 267 | bincode::deserialize(&buf).expect("Cannot deserialize the model") 268 | } 269 | 270 | /// Predict values using the final model based on a vec. 271 | /// ``` 272 | /// # use automl::{SupervisedModel, Settings}; 273 | /// # let mut model = SupervisedModel::new( 274 | /// # smartcore::dataset::diabetes::load_dataset(), 275 | /// # Settings::default_regression() 276 | /// # .only(automl::settings::Algorithm::Linear) 277 | /// # ); 278 | /// # model.train(); 279 | /// model.predict(vec![vec![5.0; 10]; 5]); 280 | /// ``` 281 | /// Or predict values using the final model based on ndarray. 282 | /// ``` 283 | /// # use automl::{SupervisedModel, Settings}; 284 | /// # #[cfg(any(feature = "nd"))] 285 | /// # let mut model = SupervisedModel::new( 286 | /// # smartcore::dataset::diabetes::load_dataset(), 287 | /// # Settings::default_regression() 288 | /// # .only(automl::settings::Algorithm::Linear) 289 | /// # ); 290 | /// # #[cfg(any(feature = "nd"))] 291 | /// # model.train(); 292 | /// #[cfg(any(feature = "nd"))] 293 | /// model.predict( 294 | /// ndarray::arr2(&[ 295 | /// [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 296 | /// [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] 297 | /// ]) 298 | /// ); 299 | /// ``` 300 | /// You can also predict from a CSV file 301 | /// ``` 302 | /// # use automl::{SupervisedModel, Settings}; 303 | /// # #[cfg(any(feature = "csv"))] 304 | /// # let mut model = SupervisedModel::new( 305 | /// # ("data/diabetes.csv", 10), 306 | /// # Settings::default_regression() 307 | /// # .only(automl::settings::Algorithm::Linear) 308 | /// # ); 309 | /// # #[cfg(any(feature = "csv"))] 310 | /// # model.train(); 311 | /// #[cfg(any(feature = "csv"))] 312 | /// model.predict("data/diabetes_without_target.csv"); 313 | /// ``` 314 | /// 315 | /// # Panics 316 | /// 317 | /// If the model has not been trained, this function will panic. 318 | pub fn predict<X: IntoFeatures>(&self, x: X) -> Vec<f32> { 319 | let x = &self.preprocess(x.to_dense_matrix()); 320 | match self.settings.final_model_approach { 321 | FinalModel::None => panic!("Cannot predict: no final model was trained because the final model approach is `FinalModel::None`."), 322 | FinalModel::Best => self.predict_by_model(x, &self.comparison[0]), 323 | FinalModel::Blending { algorithm, .. } => self.predict_blended_model(x, algorithm), 324 | } 325 | } 326 |
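Combining `new`, `train`, `save`, `new_from_file`, and `predict`, the typical lifecycle is a straight line: compare models, persist the winner, reload it elsewhere, and predict. A condensed sketch built from the same calls the doc tests above use:

```rust
use automl::{Settings, SupervisedModel};

fn main() {
    let mut model = SupervisedModel::new(
        smartcore::dataset::diabetes::load_dataset(),
        Settings::default_regression().only(automl::settings::Algorithm::Linear),
    );
    model.train();
    model.save("model.aml");

    // Later, or in another process: reload and predict without retraining.
    let restored = SupervisedModel::new_from_file("model.aml");
    let predictions = restored.predict(vec![vec![5.0_f32; 10]; 5]);
    println!("{predictions:?}");
}
```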
327 | /// Runs a model comparison and trains a final model. 328 | /// ``` 329 | /// # use automl::{SupervisedModel, Settings}; 330 | /// let mut model = SupervisedModel::new( 331 | /// smartcore::dataset::diabetes::load_dataset(), 332 | /// Settings::default_regression() 333 | /// # .only(automl::settings::Algorithm::Linear) 334 | /// ); 335 | /// model.train(); 336 | /// ``` 337 | pub fn train(&mut self) { 338 | // Train any necessary preprocessing 339 | if let PreProcessing::ReplaceWithPCA { 340 | number_of_components, 341 | } = self.settings.preprocessing 342 | { 343 | self.train_pca(&self.x_train.clone(), number_of_components); 344 | } 345 | if let PreProcessing::ReplaceWithSVD { 346 | number_of_components, 347 | } = self.settings.preprocessing 348 | { 349 | self.train_svd(&self.x_train.clone(), number_of_components); 350 | } 351 | 352 | // Preprocess the data 353 | self.x_train = self.preprocess(self.x_train.clone()); 354 | 355 | // Split validation data out if blending 356 | if let FinalModel::Blending { 357 | meta_training_fraction, 358 | meta_testing_fraction: _, 359 | algorithm: _, 360 | } = &self.settings.final_model_approach 361 | { 362 | let (x_train, x_val, y_train, y_val) = train_test_split( 363 | &self.x_train, 364 | &self.y_train, 365 | *meta_training_fraction, 366 | self.settings.shuffle, 367 | ); 368 | self.x_train = x_train; 369 | self.y_train = y_train; 370 | self.y_val = y_val; 371 | self.x_val = x_val; 372 | } 373 | 374 | // Run logistic regression 375 | if !self 376 | .settings 377 | .skiplist 378 | .contains(&Algorithm::LogisticRegression) 379 | { 380 | self.record_model(LogisticRegressionWrapper::cv_model( 381 | &self.x_train, 382 | &self.y_train, 383 | &self.settings, 384 | )); 385 | } 386 | 387 | // Run random forest classification 388 | if !self 389 | .settings 390 | .skiplist 391 | .contains(&Algorithm::RandomForestClassifier) 392 | { 393 | self.record_model(RandomForestClassifierWrapper::cv_model( 394 | &self.x_train, 395 | &self.y_train, 396 | &self.settings, 397 | )); 398 | } 399 | 400 | // Run k-nearest neighbor classifier 401 | if !self.settings.skiplist.contains(&Algorithm::KNNClassifier) { 402 | self.record_model(KNNClassifierWrapper::cv_model( 403 | &self.x_train, 404 | &self.y_train, 405 | &self.settings, 406 | )); 407 | } 408 | 409 | if !self 410 | .settings 411 | .skiplist 412 | .contains(&Algorithm::DecisionTreeClassifier) 413 | { 414 | self.record_model(DecisionTreeClassifierWrapper::cv_model( 415 | &self.x_train, 416 | &self.y_train, 417 | &self.settings, 418 | )); 419 | } 420 | 421 | if !self 422 | .settings 423 | .skiplist 424 | .contains(&Algorithm::GaussianNaiveBayes) 425 | { 426 | self.record_model(GaussianNaiveBayesClassifierWrapper::cv_model( 427 | &self.x_train, 428 | &self.y_train, 429 | &self.settings, 430 | )); 431 | } 432 | 433 | if !self 434 | .settings 435 | .skiplist 436 | .contains(&Algorithm::CategoricalNaiveBayes) 437 | && std::mem::discriminant(&self.settings.preprocessing) 438 | != std::mem::discriminant(&PreProcessing::ReplaceWithPCA { 439 | number_of_components: 1, 440 | }) 441 | && std::mem::discriminant(&self.settings.preprocessing) 442 | != std::mem::discriminant(&PreProcessing::ReplaceWithSVD { 443 | number_of_components: 1, 444 | }) 445 | { 446 | self.record_model(CategoricalNaiveBayesClassifierWrapper::cv_model( 447 | &self.x_train, 448 | &self.y_train, 449 | &self.settings, 450 | )); 451 | } 452 | 453 | if self.number_of_classes == 2 && !self.settings.skiplist.contains(&Algorithm::SVC) { 454 | self.record_model(SupportVectorClassifierWrapper::cv_model( 455 | &self.x_train, 456 | &self.y_train, 457 | &self.settings, 458 | )); 459 | } 460 | 461 | if !self.settings.skiplist.contains(&Algorithm::Linear) { 462 | self.record_model(LinearRegressorWrapper::cv_model( 463 | &self.x_train, 464 | &self.y_train, 465 | &self.settings, 466 | )); 467 | } 468 | 469 | if !self.settings.skiplist.contains(&Algorithm::SVR) { 470 | self.record_model(SupportVectorRegressorWrapper::cv_model( 471 | &self.x_train, 472 | &self.y_train, 473 | &self.settings, 474 | )); 475 | } 476 | 477 | if !self.settings.skiplist.contains(&Algorithm::Lasso) { 478 | self.record_model(LassoRegressorWrapper::cv_model( 479 | &self.x_train, 480 | &self.y_train, 481 | &self.settings, 482 | )); 483 | } 484 | 485 | if !self.settings.skiplist.contains(&Algorithm::Ridge) { 486 | self.record_model(RidgeRegressorWrapper::cv_model( 487 | &self.x_train, 488 | &self.y_train, 489 | &self.settings, 490 | )); 491 | } 492 | 493 | if !self.settings.skiplist.contains(&Algorithm::ElasticNet) { 494 | self.record_model(ElasticNetRegressorWrapper::cv_model( 495 | &self.x_train, 496 | &self.y_train, 497 | &self.settings, 498 | )); 499 | } 500 | 501 | if !self 502 | .settings 503 | .skiplist 504 | .contains(&Algorithm::DecisionTreeRegressor) 505 | { 506 | self.record_model(DecisionTreeRegressorWrapper::cv_model( 507 | &self.x_train, 508 | &self.y_train, 509 | &self.settings, 510 | )); 511 | } 512 | 513 | if !self 514 | .settings 515 | .skiplist 516 | .contains(&Algorithm::RandomForestRegressor) 517 | { 518 | self.record_model(RandomForestRegressorWrapper::cv_model( 519 | &self.x_train, 520 | &self.y_train, 521 | &self.settings, 522 | )); 523 | } 524 | 525 | if !self.settings.skiplist.contains(&Algorithm::KNNRegressor) { 526 | self.record_model(KNNRegressorWrapper::cv_model( 527 | &self.x_train, 528 | &self.y_train, 529 | &self.settings, 530 | )); 531 | } 532 | 533 | if let FinalModel::Blending { 534 | algorithm, 535 | meta_training_fraction, 536 | meta_testing_fraction, 537 | } = self.settings.final_model_approach 538 | { 539 | self.train_blended_model(algorithm, meta_training_fraction, meta_testing_fraction); 540 | } 541 | } 542 | 543 | /// Save the supervised model to a file for later use 544 | /// ``` 545 | /// # use automl::{SupervisedModel, Settings}; 546 | /// let mut model = SupervisedModel::new( 547 | /// smartcore::dataset::diabetes::load_dataset(), 548 | /// Settings::default_regression() 549 | /// ); 550 | /// model.save("tests/save_that_model.aml"); 551 | /// # std::fs::remove_file("tests/save_that_model.aml"); 552 | /// ``` 553 | pub fn save(&self, file_name: &str) { 554 | let serial = bincode::serialize(&self).expect("Cannot serialize model."); 555 | std::fs::File::create(file_name) 556 | .and_then(|mut f| f.write_all(&serial)) 557 | .expect("Cannot write model to file."); 558 | } 559 |
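Each `skiplist` check in `train` above corresponds to the `Settings::skip` builder, which is the main lever for bounding the comparison's cost. A sketch, assuming `skip` can be chained repeatedly as the builder style used throughout `Settings` suggests:

```rust
use automl::settings::Algorithm;

// Compare only the fast linear-family regressors.
let settings = automl::Settings::default_regression()
    .skip(Algorithm::RandomForestRegressor)
    .skip(Algorithm::DecisionTreeRegressor)
    .skip(Algorithm::KNNRegressor)
    .skip(Algorithm::SVR);
```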
560 | /// Save the best model for later use as a smartcore native object. 561 | /// ``` 562 | /// # use automl::{SupervisedModel, Settings, settings::Algorithm}; 563 | /// use std::io::Read; 564 | /// 565 | /// let mut model = SupervisedModel::new( 566 | /// smartcore::dataset::diabetes::load_dataset(), 567 | /// Settings::default_regression() 568 | /// # .only(Algorithm::Linear) 569 | /// ); 570 | /// model.train(); 571 | /// model.save_best("tests/save_best.sc"); 572 | /// # std::fs::remove_file("tests/save_best.sc"); 573 | /// ``` 574 | pub fn save_best(&self, file_name: &str) { 575 | if matches!(self.settings.final_model_approach, FinalModel::Best) { 576 | std::fs::File::create(file_name) 577 | .and_then(|mut f| f.write_all(&self.comparison[0].model)) 578 | .expect("Cannot write model to file."); 579 | } 580 | } 581 | } 582 | 583 | /// Private functions go here 584 | impl SupervisedModel { 585 | /// Build a new supervised model 586 | /// 587 | /// # Arguments 588 | /// 589 | /// * `x` - The input data 590 | /// * `y` - The output data 591 | /// * `settings` - The settings for the model 592 | fn build(x: DenseMatrix<f32>, y: Vec<f32>, settings: Settings) -> Self { 593 | Self { 594 | settings, 595 | x_train: x, 596 | number_of_classes: Self::count_classes(&y), 597 | y_train: y, 598 | x_val: DenseMatrix::new(0, 0, vec![]), 599 | y_val: vec![], 600 | comparison: vec![], 601 | metamodel: Model::default(), 602 | preprocessing_pca: None, 603 | preprocessing_svd: None, 604 | } 605 | } 606 | 607 | /// Train the blending meta-model. 608 | /// 609 | /// # Arguments 610 | /// 611 | /// * `algo` - The algorithm to use 612 | /// * `training_fraction` - The fraction of the data to use for training 613 | /// * `testing_fraction` - The fraction of the data to use for testing 614 | fn train_blended_model( 615 | &mut self, 616 | algo: Algorithm, 617 | training_fraction: f32, 618 | testing_fraction: f32, 619 | ) { 620 | // Make the data 621 | let mut meta_x: Vec<Vec<f32>> = Vec::new(); 622 | for model in &self.comparison { 623 | meta_x.push(self.predict_by_model(&self.x_val, model)); 624 | } 625 | let xdm = DenseMatrix::from_2d_vec(&meta_x).transpose(); 626 | 627 | // Split into datasets 628 | let (x_train, x_test, y_train, y_test) = train_test_split( 629 | &xdm, 630 | &self.y_val, 631 | training_fraction / (training_fraction + testing_fraction), 632 | self.settings.shuffle, 633 | ); 634 | 635 | // Train the model 636 | let model = algo.get_trainer()(&x_train, &y_train, &self.settings); 637 | 638 | // Score the model 639 | let train_score = self.settings.get_metric()( 640 | &y_train, 641 | &algo.get_predictor()(&x_train, &model, &self.settings), 642 | ); 643 | let test_score = self.settings.get_metric()( 644 | &y_test, 645 | &algo.get_predictor()(&x_test, &model, &self.settings), 646 | ); 647 | 648 | self.metamodel = Model { 649 | score: CrossValidationResult { 650 | test_score: vec![test_score; 1], 651 | train_score: vec![train_score; 1], 652 | }, 653 | name: algo, 654 | duration: Duration::default(), 655 | model, 656 | }; 657 | } 658 |
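In `train_blended_model` above, each base model contributes one row of `meta_x` (its predictions on the held-out validation split), and the transpose turns that into one column per model, so the meta-learner fits on samples whose features are base-model predictions. The shape logic in miniature, with toy numbers:

```rust
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linalg::BaseMatrix;

fn main() {
    // Predictions from three base models on four validation samples.
    let meta_x: Vec<Vec<f32>> = vec![
        vec![1.0, 2.0, 3.0, 4.0], // base model A
        vec![1.1, 2.2, 2.9, 4.2], // base model B
        vec![0.9, 1.8, 3.1, 3.9], // base model C
    ];
    // Rows are models; transpose to one sample per row, one model per column.
    let xdm = DenseMatrix::from_2d_vec(&meta_x).transpose();
    assert_eq!(xdm.shape(), (4, 3));
}
```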
663 | /// 664 | /// # Arguments 665 | /// 666 | /// * `x` - The input data 667 | /// * `algo` - The algorithm to use 668 | /// 669 | /// # Returns 670 | /// 671 | /// * The predicted values 672 | fn predict_blended_model(&self, x: &DenseMatrix, algo: Algorithm) -> Vec { 673 | // Make the data 674 | let mut meta_x: Vec> = Vec::new(); 675 | for i in 0..self.comparison.len() { 676 | let model = &self.comparison[i]; 677 | meta_x.push(self.predict_by_model(x, model)); 678 | } 679 | 680 | // 681 | let xdm = DenseMatrix::from_2d_vec(&meta_x).transpose(); 682 | let metamodel = &self.metamodel.model; 683 | 684 | // Train the model 685 | algo.get_predictor()(&xdm, metamodel, &self.settings) 686 | } 687 | 688 | /// Predict using a single model. 689 | /// 690 | /// # Arguments 691 | /// 692 | /// * `x` - The input data 693 | /// * `model` - The model to use 694 | /// 695 | /// # Returns 696 | /// 697 | /// * The predicted values 698 | fn predict_by_model(&self, x: &DenseMatrix, model: &Model) -> Vec { 699 | model.name.get_predictor()(x, &model.model, &self.settings) 700 | } 701 | 702 | /// Get interaction features for the data. 703 | /// 704 | /// # Arguments 705 | fn interaction_features(mut x: DenseMatrix) -> DenseMatrix { 706 | let (_, width) = x.shape(); 707 | for i in 0..width { 708 | for j in (i + 1)..width { 709 | let feature = elementwise_multiply(&x.get_col_as_vec(i), &x.get_col_as_vec(j)); 710 | let new_column = DenseMatrix::from_row_vector(feature).transpose(); 711 | x = x.h_stack(&new_column); 712 | } 713 | } 714 | x 715 | } 716 | 717 | /// Get polynomial features for the data. 718 | /// 719 | /// # Arguments 720 | /// 721 | /// * `x` - The input data 722 | /// * `order` - The order of the polynomial 723 | /// 724 | /// # Returns 725 | /// 726 | /// * The data with polynomial features 727 | fn polynomial_features(mut x: DenseMatrix, order: usize) -> DenseMatrix { 728 | let (height, width) = x.shape(); 729 | for n in 2..=order { 730 | let combinations = (0..width).combinations_with_replacement(n); 731 | for combo in combinations { 732 | let mut feature = vec![1.0; height]; 733 | for column in combo { 734 | feature = elementwise_multiply(&x.get_col_as_vec(column), &feature); 735 | } 736 | let new_column = DenseMatrix::from_row_vector(feature).transpose(); 737 | x = x.h_stack(&new_column); 738 | } 739 | } 740 | x 741 | } 742 | 743 | /// Train PCA on the data for preprocessing. 744 | /// 745 | /// # Arguments 746 | /// 747 | /// * `x` - The input data 748 | /// * `n` - The number of components to use 749 | fn train_pca(&mut self, x: &DenseMatrix, n: usize) { 750 | let pca = PCA::fit( 751 | x, 752 | PCAParameters::default() 753 | .with_n_components(n) 754 | .with_use_correlation_matrix(true), 755 | ) 756 | .unwrap(); 757 | self.preprocessing_pca = Some(pca); 758 | } 759 | 760 | /// Get PCA features for the data using the trained PCA preprocessor. 761 | /// 762 | /// # Arguments 763 | /// 764 | /// * `x` - The input data 765 | fn pca_features(&self, x: &DenseMatrix, _: usize) -> DenseMatrix { 766 | self.preprocessing_pca 767 | .as_ref() 768 | .unwrap() 769 | .transform(x) 770 | .unwrap() 771 | } 772 | 773 | /// Train SVD on the data for preprocessing. 
773 | /// Train SVD on the data for preprocessing. 774 | /// 775 | /// # Arguments 776 | /// 777 | /// * `x` - The input data 778 | /// * `n` - The number of components to use 779 | fn train_svd(&mut self, x: &DenseMatrix<f32>, n: usize) { 780 | let svd = SVD::fit(x, SVDParameters::default().with_n_components(n)).unwrap(); 781 | self.preprocessing_svd = Some(svd); 782 | } 783 | 784 | /// Get SVD features for the data. 785 | fn svd_features(&self, x: &DenseMatrix<f32>, _: usize) -> DenseMatrix<f32> { 786 | self.preprocessing_svd 787 | .as_ref() 788 | .unwrap() 789 | .transform(x) 790 | .unwrap() 791 | } 792 | 793 | /// Preprocess the data. 794 | /// 795 | /// # Arguments 796 | /// 797 | /// * `x` - The input data 798 | /// 799 | /// # Returns 800 | /// 801 | /// * The preprocessed data 802 | fn preprocess(&self, x: DenseMatrix<f32>) -> DenseMatrix<f32> { 803 | match self.settings.preprocessing { 804 | PreProcessing::None => x, 805 | PreProcessing::AddInteractions => Self::interaction_features(x), 806 | PreProcessing::AddPolynomial { order } => Self::polynomial_features(x, order), 807 | PreProcessing::ReplaceWithPCA { 808 | number_of_components, 809 | } => self.pca_features(&x, number_of_components), 810 | PreProcessing::ReplaceWithSVD { 811 | number_of_components, 812 | } => self.svd_features(&x, number_of_components), 813 | } 814 | } 815 | 816 | /// Count the number of classes in the data. 817 | /// 818 | /// # Arguments 819 | /// 820 | /// * `y` - The data to count the classes in 821 | /// 822 | /// # Returns 823 | /// 824 | /// * The number of classes 825 | fn count_classes(y: &[f32]) -> usize { 826 | let mut sorted_targets = y.to_vec(); 827 | sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Equal)); 828 | sorted_targets.dedup(); 829 | sorted_targets.len() 830 | } 831 | 832 | /// Record a model in the comparison. 833 | fn record_model(&mut self, model: (CrossValidationResult<f32>, Algorithm, Duration, Vec<u8>)) { 834 | self.comparison.push(Model { 835 | score: model.0, 836 | name: model.1, 837 | duration: model.2, 838 | model: model.3, 839 | }); 840 | self.sort(); 841 | } 842 | 843 | /// Sort the models in the comparison by their mean test scores. 844 | fn sort(&mut self) { 845 | self.comparison.sort_by(|a, b| { 846 | a.score 847 | .mean_test_score() 848 | .partial_cmp(&b.score.mean_test_score()) 849 | .unwrap_or(Equal) 850 | }); 851 | if self.settings.sort_by == Metric::RSquared || self.settings.sort_by == Metric::Accuracy { 852 | self.comparison.reverse(); 853 | } 854 | } 855 | } 856 | 857 | impl Display for SupervisedModel { 858 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 859 | let mut table = Table::new(); 860 | table.load_preset(UTF8_FULL); 861 | table.apply_modifier(UTF8_SOLID_INNER_BORDERS); 862 | table.set_header(vec![ 863 | Cell::new("Model").add_attribute(Attribute::Bold), 864 | Cell::new("Time").add_attribute(Attribute::Bold), 865 | Cell::new(format!("Training {}", self.settings.sort_by)).add_attribute(Attribute::Bold), 866 | Cell::new(format!("Testing {}", self.settings.sort_by)).add_attribute(Attribute::Bold), 867 | ]); 868 | for model in &self.comparison { 869 | let mut row_vec = vec![]; 870 | row_vec.push(format!("{}", &model.name)); 871 | row_vec.push(format!("{}", format_duration(model.duration))); 872 | let decider = 873 | ((model.score.mean_train_score() + model.score.mean_test_score()) / 2.0).abs(); 874 | if decider > 0.01 && decider < 1000.0 { 875 | row_vec.push(format!("{:.2}", &model.score.mean_train_score())); 876 | row_vec.push(format!("{:.2}", &model.score.mean_test_score())); 877 | } else { 878 | row_vec.push(format!("{:.3e}", &model.score.mean_train_score())); 879 | row_vec.push(format!("{:.3e}", &model.score.mean_test_score())); 880 | } 881 | 882 | table.add_row(row_vec); 883 | } 884 | 885 | let mut meta_table = Table::new(); 886 | meta_table.load_preset(UTF8_FULL); 887 | meta_table.apply_modifier(UTF8_SOLID_INNER_BORDERS); 888 | meta_table.set_header(vec![ 889 | Cell::new("Meta Model").add_attribute(Attribute::Bold), 890 | Cell::new(format!("Training {}", self.settings.sort_by)).add_attribute(Attribute::Bold), 891 | Cell::new(format!("Testing {}", self.settings.sort_by)).add_attribute(Attribute::Bold), 892 | ]); 893 | 894 | // Populate row 895 | let mut row_vec = vec![]; 896 | row_vec.push(format!("{}", self.metamodel.name)); 897 | let decider = ((self.metamodel.score.mean_train_score() 898 | + self.metamodel.score.mean_test_score()) 899 | / 2.0) 900 | .abs(); 901 | if decider > 0.01 && decider < 1000.0 { 902 | row_vec.push(format!("{:.2}", self.metamodel.score.mean_train_score())); 903 | row_vec.push(format!("{:.2}", self.metamodel.score.mean_test_score())); 904 | } else { 905 | row_vec.push(format!("{:.3e}", self.metamodel.score.mean_train_score())); 906 | row_vec.push(format!("{:.3e}", self.metamodel.score.mean_test_score())); 907 | } 908 | 909 | // Add row to table 910 | meta_table.add_row(row_vec); 911 | 912 | // Write 913 | write!(f, "{table}\n{meta_table}") 914 | } 915 | } 916 | 917 | /// This contains the results of a single model 918 | #[derive(serde::Serialize, serde::Deserialize)] 919 | struct Model { 920 | /// The cross validation score of the model 921 | #[serde(with = "CrossValidationResultDef")] 922 | score: CrossValidationResult<f32>, 923 | /// The algorithm used 924 | name: Algorithm, 925 | /// The time it took to train the model 926 | duration: Duration, 927 | /// The final model, serialized with bincode 928 | model: Vec<u8>, 929 | } 930 | 931 | impl Default for Model { 932 | fn default() -> Self { 933 | Self { 934 | score: CrossValidationResult { 935 | test_score: vec![], 936 | train_score: vec![], 937 | }, 938 | name: Algorithm::Linear, 939 | duration: Duration::default(), 940 | model: vec![], 941 | } 942 | } 943 | } 944 | 945 | /// This is a wrapper for the `CrossValidationResult` 946 | #[derive(serde::Serialize, serde::Deserialize)] 947 | #[serde(remote = "CrossValidationResult::<f32>")] 948 | struct CrossValidationResultDef { 949 | /// Vector with test scores on each cv split 950 | pub test_score: Vec<f32>, 951 | /// Vector with training scores on each cv split 952 | pub train_score: Vec<f32>, 953 | } 954 | -------------------------------------------------------------------------------- /src/settings/knn_classifier_parameters.rs: -------------------------------------------------------------------------------- 1 | //! KNN classifier parameters 2 | 3 | use crate::utils::Distance; 4 | pub use smartcore::{algorithm::neighbour::KNNAlgorithmName, neighbors::KNNWeightFunction}; 5 | 6 | /// Parameters for k-nearest neighbors (KNN) classification 7 | #[derive(serde::Serialize, serde::Deserialize)] 8 | pub struct KNNClassifierParameters { 9 | /// Number of nearest neighbors to use 10 | pub(crate) k: usize, 11 | /// Weighting function to use with KNN classification 12 | pub(crate) weight: KNNWeightFunction, 13 | /// Search algorithm to use with KNN classification 14 | pub(crate) algorithm: KNNAlgorithmName, 15 | /// Distance metric to use with KNN classification 16 | pub(crate) distance: Distance, 17 | } 18 | 19 | impl KNNClassifierParameters { 20 | /// Define the number of nearest neighbors to use 21 | #[must_use] 22 | pub const fn with_k(mut self, k: usize) -> Self { 23 | self.k = k; 24 | self 25 | } 26 | 27 | /// Define the weighting function to use with KNN classification 28 | #[must_use] 29 | pub const fn with_weight(mut self, weight: KNNWeightFunction) -> Self { 30 | self.weight = weight; 31 | self 32 | } 33 | 34 | /// Define the search algorithm to use with KNN classification 35 | #[must_use] 36 | pub const fn with_algorithm(mut self, algorithm: KNNAlgorithmName) -> Self { 37 | self.algorithm = algorithm; 38 | self 39 | } 40 | 41 | /// Define the distance metric to use with KNN classification 42 | #[must_use] 43 | pub const fn with_distance(mut self, distance: Distance) -> Self { 44 | self.distance = distance; 45 | self 46 | } 47 | } 48 | 49 | impl Default for KNNClassifierParameters { 50 | fn default() -> Self { 51 | Self { 52 | k: 3, 53 | weight: KNNWeightFunction::Uniform, 54 | algorithm: KNNAlgorithmName::CoverTree, 55 | distance: Distance::Euclidean, 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/settings/knn_regressor_parameters.rs: -------------------------------------------------------------------------------- 1 | //!
KNN regressor parameters 2 | 3 | use crate::utils::Distance; 4 | pub use smartcore::{algorithm::neighbour::KNNAlgorithmName, neighbors::KNNWeightFunction}; 5 | 6 | /// Parameters for k-nearest neighbor (KNN) regression 7 | #[derive(serde::Serialize, serde::Deserialize)] 8 | pub struct KNNRegressorParameters { 9 | /// Number of nearest neighbors to use 10 | pub(crate) k: usize, 11 | /// Weighting function to use with KNN regression 12 | pub(crate) weight: KNNWeightFunction, 13 | /// Search algorithm to use with KNN regression 14 | pub(crate) algorithm: KNNAlgorithmName, 15 | /// Distance metric to use with KNN regression 16 | pub(crate) distance: Distance, 17 | } 18 | 19 | impl KNNRegressorParameters { 20 | /// Define the number of nearest neighbors to use 21 | #[must_use] 22 | pub const fn with_k(mut self, k: usize) -> Self { 23 | self.k = k; 24 | self 25 | } 26 | 27 | /// Define the weighting function to use with KNN regression 28 | #[must_use] 29 | pub const fn with_weight(mut self, weight: KNNWeightFunction) -> Self { 30 | self.weight = weight; 31 | self 32 | } 33 | 34 | /// Define the search algorithm to use with KNN regression 35 | #[must_use] 36 | pub const fn with_algorithm(mut self, algorithm: KNNAlgorithmName) -> Self { 37 | self.algorithm = algorithm; 38 | self 39 | } 40 | 41 | /// Define the distance metric to use with KNN regression 42 | #[must_use] 43 | pub const fn with_distance(mut self, distance: Distance) -> Self { 44 | self.distance = distance; 45 | self 46 | } 47 | } 48 | 49 | impl Default for KNNRegressorParameters { 50 | fn default() -> Self { 51 | Self { 52 | k: 3, 53 | weight: KNNWeightFunction::Uniform, 54 | algorithm: KNNAlgorithmName::CoverTree, 55 | distance: Distance::Euclidean, 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/settings/mod.rs: -------------------------------------------------------------------------------- 1 | //! # Settings customization 2 | //! This module contains capabilities for the detailed customization of algorithm settings. 3 | //! ## Complete regression customization 4 | //! ``` 5 | //! use automl::settings::{ 6 | //! Algorithm, DecisionTreeRegressorParameters, Distance, ElasticNetParameters, 7 | //! KNNAlgorithmName, KNNRegressorParameters, KNNWeightFunction, Kernel, LassoParameters, 8 | //! LinearRegressionParameters, LinearRegressionSolverName, Metric, 9 | //! RandomForestRegressorParameters, RidgeRegressionParameters, RidgeRegressionSolverName, 10 | //! SVRParameters, 11 | //! }; 12 | //! 13 | //! let settings = automl::Settings::default_regression() 14 | //! .with_number_of_folds(3) 15 | //! .shuffle_data(true) 16 | //! .verbose(true) 17 | //! .skip(Algorithm::RandomForestRegressor) 18 | //! .sorted_by(Metric::RSquared) 19 | //! .with_linear_settings( 20 | //! LinearRegressionParameters::default().with_solver(LinearRegressionSolverName::QR), 21 | //! ) 22 | //! .with_lasso_settings( 23 | //! LassoParameters::default() 24 | //! .with_alpha(10.0) 25 | //! .with_tol(1e-10) 26 | //! .with_normalize(true) 27 | //! .with_max_iter(10_000), 28 | //! ) 29 | //! .with_ridge_settings( 30 | //! RidgeRegressionParameters::default() 31 | //! .with_alpha(10.0) 32 | //! .with_normalize(true) 33 | //! .with_solver(RidgeRegressionSolverName::Cholesky), 34 | //! ) 35 | //! .with_elastic_net_settings( 36 | //! ElasticNetParameters::default() 37 | //! .with_tol(1e-10) 38 | //! .with_normalize(true) 39 | //! .with_alpha(1.0) 40 | //! .with_max_iter(10_000) 41 | //! 
.with_l1_ratio(0.5), 42 | //! ) 43 | //! .with_knn_regressor_settings( 44 | //! KNNRegressorParameters::default() 45 | //! .with_algorithm(KNNAlgorithmName::CoverTree) 46 | //! .with_k(3) 47 | //! .with_distance(Distance::Euclidean) 48 | //! .with_weight(KNNWeightFunction::Uniform), 49 | //! ) 50 | //! .with_svr_settings( 51 | //! SVRParameters::default() 52 | //! .with_eps(1e-10) 53 | //! .with_tol(1e-10) 54 | //! .with_c(1.0) 55 | //! .with_kernel(Kernel::Linear), 56 | //! ) 57 | //! .with_random_forest_regressor_settings( 58 | //! RandomForestRegressorParameters::default() 59 | //! .with_m(100) 60 | //! .with_max_depth(5) 61 | //! .with_min_samples_leaf(20) 62 | //! .with_n_trees(100) 63 | //! .with_min_samples_split(20), 64 | //! ) 65 | //! .with_decision_tree_regressor_settings( 66 | //! DecisionTreeRegressorParameters::default() 67 | //! .with_min_samples_split(20) 68 | //! .with_max_depth(5) 69 | //! .with_min_samples_leaf(20), 70 | //! ); 71 | //! ``` 72 | //! ## Complete classification customization 73 | //! ``` 74 | //! use automl::settings::{ 75 | //! Algorithm, CategoricalNBParameters, DecisionTreeClassifierParameters, Distance, 76 | //! GaussianNBParameters, KNNAlgorithmName, KNNClassifierParameters, KNNWeightFunction, Kernel, 77 | //! LogisticRegressionParameters, LogisticRegressionSolverName, Metric, 78 | //! RandomForestClassifierParameters, SVCParameters, 79 | //! }; 80 | //! 81 | //! let settings = automl::Settings::default_classification() 82 | //! .with_number_of_folds(3) 83 | //! .shuffle_data(true) 84 | //! .verbose(true) 85 | //! .skip(Algorithm::RandomForestClassifier) 86 | //! .sorted_by(Metric::Accuracy) 87 | //! .with_random_forest_classifier_settings( 88 | //! RandomForestClassifierParameters::default() 89 | //! .with_m(100) 90 | //! .with_max_depth(5) 91 | //! .with_min_samples_leaf(20) 92 | //! .with_n_trees(100) 93 | //! .with_min_samples_split(20), 94 | //! ) 95 | //! .with_logistic_settings( 96 | //! LogisticRegressionParameters::default() 97 | //! .with_alpha(1.0) 98 | //! .with_solver(LogisticRegressionSolverName::LBFGS), 99 | //! ) 100 | //! .with_svc_settings( 101 | //! SVCParameters::default() 102 | //! .with_epoch(10) 103 | //! .with_tol(1e-10) 104 | //! .with_c(1.0) 105 | //! .with_kernel(Kernel::Linear), 106 | //! ) 107 | //! .with_decision_tree_classifier_settings( 108 | //! DecisionTreeClassifierParameters::default() 109 | //! .with_min_samples_split(20) 110 | //! .with_max_depth(5) 111 | //! .with_min_samples_leaf(20), 112 | //! ) 113 | //! .with_knn_classifier_settings( 114 | //! KNNClassifierParameters::default() 115 | //! .with_algorithm(KNNAlgorithmName::CoverTree) 116 | //! .with_k(3) 117 | //! .with_distance(Distance::Euclidean) 118 | //! .with_weight(KNNWeightFunction::Uniform), 119 | //! ) 120 | //! .with_gaussian_nb_settings(GaussianNBParameters::default().with_priors(vec![1.0, 1.0])) 121 | //! .with_categorical_nb_settings(CategoricalNBParameters::default().with_alpha(1.0)); 122 | //! 
``` 123 | 124 | pub use crate::utils::{Distance, Kernel}; 125 | 126 | /// Weighting functions for k-nearest neighbor (KNN) regression (re-export from [Smartcore](https://docs.rs/smartcore/)) 127 | pub use smartcore::neighbors::KNNWeightFunction; 128 | 129 | /// Search algorithms for k-nearest neighbor (KNN) regression (re-export from [Smartcore](https://docs.rs/smartcore/)) 130 | pub use smartcore::algorithm::neighbour::KNNAlgorithmName; 131 | 132 | /// Parameters for random forest regression (re-export from [Smartcore](https://docs.rs/smartcore/)) 133 | pub use smartcore::ensemble::random_forest_regressor::RandomForestRegressorParameters; 134 | 135 | /// Parameters for decision tree regression (re-export from [Smartcore](https://docs.rs/smartcore/)) 136 | pub use smartcore::tree::decision_tree_regressor::DecisionTreeRegressorParameters; 137 | 138 | /// Parameters for elastic net regression (re-export from [Smartcore](https://docs.rs/smartcore/)) 139 | pub use smartcore::linear::elastic_net::ElasticNetParameters; 140 | 141 | /// Parameters for LASSO regression (re-export from [Smartcore](https://docs.rs/smartcore/)) 142 | pub use smartcore::linear::lasso::LassoParameters; 143 | 144 | /// Solvers for linear regression (re-export from [Smartcore](https://docs.rs/smartcore/)) 145 | pub use smartcore::linear::linear_regression::LinearRegressionSolverName; 146 | 147 | /// Parameters for linear regression (re-export from [Smartcore](https://docs.rs/smartcore/)) 148 | pub use smartcore::linear::linear_regression::LinearRegressionParameters; 149 | 150 | /// Parameters for ridge regression (re-export from [Smartcore](https://docs.rs/smartcore/)) 151 | pub use smartcore::linear::ridge_regression::RidgeRegressionParameters; 152 | 153 | /// Solvers for ridge regression (re-export from [Smartcore](https://docs.rs/smartcore/)) 154 | pub use smartcore::linear::ridge_regression::RidgeRegressionSolverName; 155 | 156 | /// Parameters for Gaussian naive bayes (re-export from [Smartcore](https://docs.rs/smartcore/)) 157 | pub use smartcore::naive_bayes::gaussian::GaussianNBParameters; 158 | 159 | /// Parameters for categorical naive bayes (re-export from [Smartcore](https://docs.rs/smartcore/)) 160 | pub use smartcore::naive_bayes::categorical::CategoricalNBParameters; 161 | 162 | /// Parameters for random forest classification (re-export from [Smartcore](https://docs.rs/smartcore/)) 163 | pub use smartcore::ensemble::random_forest_classifier::RandomForestClassifierParameters; 164 | 165 | /// Parameters for logistic regression (re-export from [Smartcore](https://docs.rs/smartcore/)) 166 | pub use smartcore::linear::logistic_regression::LogisticRegressionParameters; 167 | 168 | /// Parameters for logistic regression (re-export from [Smartcore](https://docs.rs/smartcore/)) 169 | pub use smartcore::linear::logistic_regression::LogisticRegressionSolverName; 170 | 171 | /// Parameters for decision tree classification (re-export from [Smartcore](https://docs.rs/smartcore/)) 172 | pub use smartcore::tree::decision_tree_classifier::DecisionTreeClassifierParameters; 173 | 174 | mod knn_regressor_parameters; 175 | pub use knn_regressor_parameters::KNNRegressorParameters; 176 | 177 | mod svr_parameters; 178 | pub use svr_parameters::SVRParameters; 179 | 180 | mod knn_classifier_parameters; 181 | pub use knn_classifier_parameters::KNNClassifierParameters; 182 | 183 | mod svc_parameters; 184 | pub use svc_parameters::SVCParameters; 185 | 186 | use smartcore::linalg::naive::dense_matrix::DenseMatrix; 187 | use 
std::fmt::{Display, Formatter}; 188 | 189 | use super::algorithms::{ 190 | CategoricalNaiveBayesClassifierWrapper, DecisionTreeClassifierWrapper, 191 | DecisionTreeRegressorWrapper, ElasticNetRegressorWrapper, GaussianNaiveBayesClassifierWrapper, 192 | KNNClassifierWrapper, KNNRegressorWrapper, LassoRegressorWrapper, LinearRegressorWrapper, 193 | LogisticRegressionWrapper, ModelWrapper, RandomForestClassifierWrapper, 194 | RandomForestRegressorWrapper, RidgeRegressorWrapper, SupportVectorClassifierWrapper, 195 | SupportVectorRegressorWrapper, 196 | }; 197 | 198 | mod settings_struct; 199 | #[doc(no_inline)] 200 | pub use settings_struct::Settings; 201 | 202 | /// Metrics for evaluating algorithms 203 | #[non_exhaustive] 204 | #[derive(PartialEq, Eq, serde::Serialize, serde::Deserialize)] 205 | pub enum Metric { 206 | /// Sort by R^2 207 | RSquared, 208 | /// Sort by MAE 209 | MeanAbsoluteError, 210 | /// Sort by MSE 211 | MeanSquaredError, 212 | /// Sort by Accuracy 213 | Accuracy, 214 | /// Sort by none 215 | None, 216 | } 217 | 218 | impl Display for Metric { 219 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 220 | match self { 221 | Self::RSquared => write!(f, "R^2"), 222 | Self::MeanAbsoluteError => write!(f, "MAE"), 223 | Self::MeanSquaredError => write!(f, "MSE"), 224 | Self::Accuracy => write!(f, "Accuracy"), 225 | Self::None => panic!("A metric must be set."), 226 | } 227 | } 228 | } 229 | 230 | /// Algorithm options 231 | #[derive(PartialEq, Eq, Copy, Clone, serde::Serialize, serde::Deserialize)] 232 | pub enum Algorithm { 233 | /// Decision tree regressor 234 | DecisionTreeRegressor, 235 | /// KNN Regressor 236 | KNNRegressor, 237 | /// Random forest regressor 238 | RandomForestRegressor, 239 | /// Linear regressor 240 | Linear, 241 | /// Ridge regressor 242 | Ridge, 243 | /// Lasso regressor 244 | Lasso, 245 | /// Elastic net regressor 246 | ElasticNet, 247 | /// Support vector regressor 248 | SVR, 249 | /// Decision tree classifier 250 | DecisionTreeClassifier, 251 | /// KNN classifier 252 | KNNClassifier, 253 | /// Random forest classifier 254 | RandomForestClassifier, 255 | /// Support vector classifier 256 | SVC, 257 | /// Logistic regression classifier 258 | LogisticRegression, 259 | /// Gaussian Naive Bayes classifier 260 | GaussianNaiveBayes, 261 | /// Categorical Naive Bayes classifier 262 | CategoricalNaiveBayes, 263 | } 264 | 265 | impl Algorithm { 266 | /// Get the `predict` method for the underlying algorithm. 
267 |     pub(crate) fn get_predictor(self) -> fn(&DenseMatrix<f32>, &Vec<u8>, &Settings) -> Vec<f32> {
268 |         match self {
269 |             Self::Linear => LinearRegressorWrapper::predict,
270 |             Self::Lasso => LassoRegressorWrapper::predict,
271 |             Self::Ridge => RidgeRegressorWrapper::predict,
272 |             Self::ElasticNet => ElasticNetRegressorWrapper::predict,
273 |             Self::RandomForestRegressor => RandomForestRegressorWrapper::predict,
274 |             Self::KNNRegressor => KNNRegressorWrapper::predict,
275 |             Self::SVR => SupportVectorRegressorWrapper::predict,
276 |             Self::DecisionTreeRegressor => DecisionTreeRegressorWrapper::predict,
277 |             Self::LogisticRegression => LogisticRegressionWrapper::predict,
278 |             Self::RandomForestClassifier => RandomForestClassifierWrapper::predict,
279 |             Self::DecisionTreeClassifier => DecisionTreeClassifierWrapper::predict,
280 |             Self::KNNClassifier => KNNClassifierWrapper::predict,
281 |             Self::SVC => SupportVectorClassifierWrapper::predict,
282 |             Self::GaussianNaiveBayes => GaussianNaiveBayesClassifierWrapper::predict,
283 |             Self::CategoricalNaiveBayes => CategoricalNaiveBayesClassifierWrapper::predict,
284 |         }
285 |     }
286 | 
287 |     /// Get the `train` method for the underlying algorithm.
288 |     pub(crate) fn get_trainer(self) -> fn(&DenseMatrix<f32>, &Vec<f32>, &Settings) -> Vec<u8> {
289 |         match self {
290 |             Self::Linear => LinearRegressorWrapper::train,
291 |             Self::Lasso => LassoRegressorWrapper::train,
292 |             Self::Ridge => RidgeRegressorWrapper::train,
293 |             Self::ElasticNet => ElasticNetRegressorWrapper::train,
294 |             Self::RandomForestRegressor => RandomForestRegressorWrapper::train,
295 |             Self::KNNRegressor => KNNRegressorWrapper::train,
296 |             Self::SVR => SupportVectorRegressorWrapper::train,
297 |             Self::DecisionTreeRegressor => DecisionTreeRegressorWrapper::train,
298 |             Self::LogisticRegression => LogisticRegressionWrapper::train,
299 |             Self::RandomForestClassifier => RandomForestClassifierWrapper::train,
300 |             Self::DecisionTreeClassifier => DecisionTreeClassifierWrapper::train,
301 |             Self::KNNClassifier => KNNClassifierWrapper::train,
302 |             Self::SVC => SupportVectorClassifierWrapper::train,
303 |             Self::GaussianNaiveBayes => GaussianNaiveBayesClassifierWrapper::train,
304 |             Self::CategoricalNaiveBayes => CategoricalNaiveBayesClassifierWrapper::train,
305 |         }
306 |     }
307 | }
308 | 
309 | impl Display for Algorithm {
310 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
311 |         match self {
312 |             Self::DecisionTreeRegressor => write!(f, "Decision Tree Regressor"),
313 |             Self::KNNRegressor => write!(f, "KNN Regressor"),
314 |             Self::RandomForestRegressor => write!(f, "Random Forest Regressor"),
315 |             Self::Linear => write!(f, "Linear Regressor"),
316 |             Self::Ridge => write!(f, "Ridge Regressor"),
317 |             Self::Lasso => write!(f, "LASSO Regressor"),
318 |             Self::ElasticNet => write!(f, "Elastic Net Regressor"),
319 |             Self::SVR => write!(f, "Support Vector Regressor"),
320 |             Self::DecisionTreeClassifier => write!(f, "Decision Tree Classifier"),
321 |             Self::KNNClassifier => write!(f, "KNN Classifier"),
322 |             Self::RandomForestClassifier => write!(f, "Random Forest Classifier"),
323 |             Self::LogisticRegression => write!(f, "Logistic Regression Classifier"),
324 |             Self::SVC => write!(f, "Support Vector Classifier"),
325 |             Self::GaussianNaiveBayes => write!(f, "Gaussian Naive Bayes"),
326 |             Self::CategoricalNaiveBayes => write!(f, "Categorical Naive Bayes"),
327 |         }
328 |     }
329 | }
330 | 
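// NOTE (illustrative addition, not in the original file): because `get_trainer`
// and `get_predictor` return plain function pointers, the comparison loop can
// iterate over `Algorithm` values without trait objects. Assuming the wrapper
// convention that `train` returns a serialized model which `predict` consumes,
// a round trip through one algorithm looks like:
//
//     let trainer = Algorithm::Linear.get_trainer();
//     let model_bytes = trainer(&x, &y, &settings);
//     let predictor = Algorithm::Linear.get_predictor();
//     let predictions = predictor(&x, &model_bytes, &settings);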
331 | /// Options for pre-processing the data
332 | #[derive(serde::Serialize, serde::Deserialize)]
333 | pub enum PreProcessing {
334 |     /// Don't do any preprocessing
335 |     None,
336 |     /// Add interaction terms to the data
337 |     AddInteractions,
338 |     /// Add polynomial terms of order n to the data
339 |     AddPolynomial {
340 |         /// The order of the polynomial to add (i.e., x^order)
341 |         order: usize,
342 |     },
343 |     /// Replace the data with n PCA terms
344 |     ReplaceWithPCA {
345 |         /// The number of components to use from PCA
346 |         number_of_components: usize,
347 |     },
348 |     /// Replace the data with n SVD terms
349 |     ReplaceWithSVD {
350 |         /// The number of components to use from SVD
351 |         number_of_components: usize,
352 |     },
353 | }
354 | 
355 | impl Display for PreProcessing {
356 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
357 |         match self {
358 |             Self::None => write!(f, "None"),
359 |             Self::AddInteractions => write!(f, "Interaction terms added"),
360 |             Self::AddPolynomial { order } => {
361 |                 write!(f, "Polynomial terms added (order = {order})")
362 |             }
363 |             Self::ReplaceWithPCA {
364 |                 number_of_components,
365 |             } => write!(f, "Replaced with PCA features (n = {number_of_components})"),
366 | 
367 |             Self::ReplaceWithSVD {
368 |                 number_of_components,
369 |             } => write!(f, "Replaced with SVD features (n = {number_of_components})"),
370 |         }
371 |     }
372 | }
373 | 
374 | /// Final model approach
375 | #[derive(serde::Serialize, serde::Deserialize)]
376 | pub enum FinalModel {
377 |     /// Do not train a final model
378 |     None,
379 |     /// Select the best model from the comparison set as the final model
380 |     Best,
381 |     /// Use a blending approach to produce a final model
382 |     Blending {
383 |         /// Which algorithm to use as a meta-learner
384 |         algorithm: Algorithm,
385 |         /// How much data to retain to train the blending model
386 |         meta_training_fraction: f32,
387 |         /// How much data to retain to test the blending model
388 |         meta_testing_fraction: f32,
389 |     },
390 |     // /// Use a stacking approach to produce a final model (not implemented)
391 |     // Stacking {
392 |     //     /// How much data to retain to train the blending model
393 |     //     meta_testing_fraction: f32,
394 |     // },
395 | }
396 | 
397 | impl FinalModel {
398 |     /// Default values for a blending model (linear regression as the meta-learner, with 15% of the data reserved to train and 15% to test the blending model)
399 |     #[must_use]
400 |     pub const fn default_blending() -> Self {
401 |         Self::Blending {
402 |             algorithm: Algorithm::Linear,
403 |             meta_training_fraction: 0.15,
404 |             meta_testing_fraction: 0.15,
405 |         }
406 |     }
407 | }
408 | 
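// NOTE (illustrative addition, not in the original file): a blending final model
// is configured entirely through this enum; the meta-learner is itself an
// `Algorithm`, and the two fractions describe how much data is held out for the
// blending model:
//
//     let final_model = FinalModel::Blending {
//         algorithm: Algorithm::Lasso,
//         meta_training_fraction: 0.15, // 15% of the data trains the blender
//         meta_testing_fraction: 0.15,  // 15% evaluates the blended model
//     };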
--------------------------------------------------------------------------------
/src/settings/settings_struct.rs:
--------------------------------------------------------------------------------
1 | //! Settings for the automl crate
2 | 
3 | use comfy_table::{
4 |     modifiers::UTF8_SOLID_INNER_BORDERS, presets::UTF8_FULL, Attribute, Cell, Table,
5 | };
6 | 
7 | use super::{
8 |     Algorithm, CategoricalNBParameters, DecisionTreeClassifierParameters,
9 |     DecisionTreeRegressorParameters, ElasticNetParameters, FinalModel, GaussianNBParameters,
10 |     KNNClassifierParameters, KNNRegressorParameters, LassoParameters, LinearRegressionParameters,
11 |     LinearRegressionSolverName, LogisticRegressionParameters, Metric, PreProcessing,
12 |     RandomForestClassifierParameters, RandomForestRegressorParameters, RidgeRegressionParameters,
13 |     RidgeRegressionSolverName, SVCParameters, SVRParameters,
14 | };
15 | 
16 | use crate::utils::{
17 |     debug_option, print_knn_search_algorithm, print_knn_weight_function, print_option,
18 | };
19 | 
20 | use smartcore::{
21 |     metrics::{accuracy, mean_absolute_error, mean_squared_error, r2},
22 |     model_selection::KFold,
23 |     tree::decision_tree_classifier::SplitCriterion,
24 | };
25 | 
26 | use std::fmt::{Display, Formatter};
27 | use std::io::{Read, Write};
28 | 
29 | /// Settings for supervised models
30 | ///
31 | /// Any algorithms in the `skiplist` member will be skipped during training.
32 | #[derive(serde::Serialize, serde::Deserialize)]
33 | pub struct Settings {
34 |     /// The metric to sort by
35 |     pub(crate) sort_by: Metric,
36 |     /// The type of model to train
37 |     model_type: ModelType,
38 |     /// The algorithms to skip
39 |     pub(crate) skiplist: Vec<Algorithm>,
40 |     /// The number of folds for cross-validation
41 |     number_of_folds: usize,
42 |     /// Whether or not to shuffle the data
43 |     pub(crate) shuffle: bool,
44 |     /// Whether or not to be verbose
45 |     verbose: bool,
46 |     /// The approach to use for the final model
47 |     pub(crate) final_model_approach: FinalModel,
48 |     /// The kind of preprocessing to perform
49 |     pub(crate) preprocessing: PreProcessing,
50 |     /// Optional settings for linear regression
51 |     pub(crate) linear_settings: Option<LinearRegressionParameters>,
52 |     /// Optional settings for support vector regressor
53 |     pub(crate) svr_settings: Option<SVRParameters>,
54 |     /// Optional settings for lasso regression
55 |     pub(crate) lasso_settings: Option<LassoParameters<f32>>,
56 |     /// Optional settings for ridge regression
57 |     pub(crate) ridge_settings: Option<RidgeRegressionParameters<f32>>,
58 |     /// Optional settings for elastic net
59 |     pub(crate) elastic_net_settings: Option<ElasticNetParameters<f32>>,
60 |     /// Optional settings for decision tree regressor
61 |     pub(crate) decision_tree_regressor_settings: Option<DecisionTreeRegressorParameters>,
62 |     /// Optional settings for random forest regressor
63 |     pub(crate) random_forest_regressor_settings: Option<RandomForestRegressorParameters>,
64 |     /// Optional settings for KNN regressor
65 |     pub(crate) knn_regressor_settings: Option<KNNRegressorParameters>,
66 |     /// Optional settings for logistic regression
67 |     pub(crate) logistic_settings: Option<LogisticRegressionParameters<f32>>,
68 |     /// Optional settings for random forest
69 |     pub(crate) random_forest_classifier_settings: Option<RandomForestClassifierParameters>,
70 |     /// Optional settings for KNN classifier
71 |     pub(crate) knn_classifier_settings: Option<KNNClassifierParameters>,
72 |     /// Optional settings for support vector classifier
73 |     pub(crate) svc_settings: Option<SVCParameters>,
74 |     /// Optional settings for decision tree classifier
75 |     pub(crate) decision_tree_classifier_settings: Option<DecisionTreeClassifierParameters>,
76 |     /// Optional settings for Gaussian Naive Bayes
77 |     pub(crate) gaussian_nb_settings: Option<GaussianNBParameters<f32>>,
78 |     /// Optional settings for Categorical Naive Bayes
79 |     pub(crate) categorical_nb_settings: Option<CategoricalNBParameters<f32>>,
80 | }
81 | 
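// NOTE (illustrative addition, not in the original file): the skiplist is the
// single switch that decides which algorithms run, so "regression settings" are
// simply settings whose skiplist names every classifier, and vice versa:
//
//     let settings = Settings::default_regression().skip(Algorithm::SVR);
//     // SVR is now excluded from the comparison along with all classifiers.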
82 | impl Default for Settings {
83 |     fn default() -> Self {
84 |         Self {
85 |             sort_by: Metric::RSquared,
86 |             model_type: ModelType::None,
87 |             final_model_approach: FinalModel::Best,
88 |             skiplist: vec![
89 |                 Algorithm::LogisticRegression,
90 |                 Algorithm::RandomForestClassifier,
91 |                 Algorithm::KNNClassifier,
92 |                 Algorithm::SVC,
93 |                 Algorithm::DecisionTreeClassifier,
94 |                 Algorithm::CategoricalNaiveBayes,
95 |                 Algorithm::GaussianNaiveBayes,
96 |                 Algorithm::Linear,
97 |                 Algorithm::Lasso,
98 |                 Algorithm::Ridge,
99 |                 Algorithm::ElasticNet,
100 |                 Algorithm::SVR,
101 |                 Algorithm::DecisionTreeRegressor,
102 |                 Algorithm::RandomForestRegressor,
103 |                 Algorithm::KNNRegressor,
104 |             ],
105 |             preprocessing: PreProcessing::None,
106 |             number_of_folds: 10,
107 |             shuffle: false,
108 |             verbose: false,
109 |             linear_settings: None,
110 |             svr_settings: None,
111 |             lasso_settings: None,
112 |             ridge_settings: None,
113 |             elastic_net_settings: None,
114 |             decision_tree_regressor_settings: None,
115 |             random_forest_regressor_settings: None,
116 |             knn_regressor_settings: None,
117 |             logistic_settings: None,
118 |             random_forest_classifier_settings: None,
119 |             knn_classifier_settings: None,
120 |             svc_settings: None,
121 |             decision_tree_classifier_settings: None,
122 |             gaussian_nb_settings: None,
123 |             categorical_nb_settings: None,
124 |         }
125 |     }
126 | }
127 | 
128 | impl Settings {
129 |     /// Get the k-fold cross-validator
130 |     pub(crate) fn get_kfolds(&self) -> KFold {
131 |         KFold::default()
132 |             .with_n_splits(self.number_of_folds)
133 |             .with_shuffle(self.shuffle)
134 |     }
135 | 
136 |     /// Get the metric to sort by
137 |     pub(crate) fn get_metric(&self) -> fn(&Vec<f32>, &Vec<f32>) -> f32 {
138 |         match self.sort_by {
139 |             Metric::RSquared => r2,
140 |             Metric::MeanAbsoluteError => mean_absolute_error,
141 |             Metric::MeanSquaredError => mean_squared_error,
142 |             Metric::Accuracy => accuracy,
143 |             Metric::None => panic!("A metric must be set."),
144 |         }
145 |     }
146 | 
147 |     /// Creates default settings for regression
148 |     /// ```
149 |     /// # use automl::Settings;
150 |     /// let settings = Settings::default_regression();
151 |     /// ```
152 |     #[must_use]
153 |     pub fn default_regression() -> Self {
154 |         Self {
155 |             sort_by: Metric::RSquared,
156 |             model_type: ModelType::Regression,
157 |             final_model_approach: FinalModel::Best,
158 |             skiplist: vec![
159 |                 Algorithm::LogisticRegression,
160 |                 Algorithm::RandomForestClassifier,
161 |                 Algorithm::KNNClassifier,
162 |                 Algorithm::SVC,
163 |                 Algorithm::DecisionTreeClassifier,
164 |                 Algorithm::CategoricalNaiveBayes,
165 |                 Algorithm::GaussianNaiveBayes,
166 |             ],
167 |             preprocessing: PreProcessing::None,
168 |             number_of_folds: 10,
169 |             shuffle: false,
170 |             verbose: false,
171 |             linear_settings: Some(LinearRegressionParameters::default()),
172 |             svr_settings: Some(SVRParameters::default()),
173 |             lasso_settings: Some(LassoParameters::default()),
174 |             ridge_settings: Some(RidgeRegressionParameters::default()),
175 |             elastic_net_settings: Some(ElasticNetParameters::default()),
176 |             decision_tree_regressor_settings: Some(DecisionTreeRegressorParameters::default()),
177 |             random_forest_regressor_settings: Some(RandomForestRegressorParameters::default()),
178 |             knn_regressor_settings: Some(KNNRegressorParameters::default()),
179 |             logistic_settings: None,
180 |             random_forest_classifier_settings: None,
181 |             knn_classifier_settings: None,
182 |             svc_settings: None,
183 |             decision_tree_classifier_settings: None,
184 |             gaussian_nb_settings: None,
185 |             categorical_nb_settings: None,
186 |         }
187 |     }
188 | 
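    // NOTE (illustrative addition, not in the original file): the defaults above
    // are a starting point; a typical customized regression run chains the
    // builder methods defined later in this file:
    //
    //     let settings = Settings::default_regression()
    //         .with_number_of_folds(5)
    //         .shuffle_data(true)
    //         .sorted_by(Metric::MeanSquaredError)
    //         .skip(Algorithm::ElasticNet);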
189 |     /// Creates default settings for classification
190 |     /// ```
191 |     /// # use automl::Settings;
192 |     /// let settings = Settings::default_classification();
193 |     /// ```
194 |     #[must_use]
195 |     pub fn default_classification() -> Self {
196 |         Self {
197 |             sort_by: Metric::Accuracy,
198 |             model_type: ModelType::Classification,
199 |             final_model_approach: FinalModel::Best,
200 |             skiplist: vec![
201 |                 Algorithm::Linear,
202 |                 Algorithm::Lasso,
203 |                 Algorithm::Ridge,
204 |                 Algorithm::ElasticNet,
205 |                 Algorithm::SVR,
206 |                 Algorithm::DecisionTreeRegressor,
207 |                 Algorithm::RandomForestRegressor,
208 |                 Algorithm::KNNRegressor,
209 |             ],
210 |             preprocessing: PreProcessing::None,
211 |             number_of_folds: 10,
212 |             shuffle: false,
213 |             verbose: false,
214 |             linear_settings: None,
215 |             svr_settings: None,
216 |             lasso_settings: None,
217 |             ridge_settings: None,
218 |             elastic_net_settings: None,
219 |             decision_tree_regressor_settings: None,
220 |             random_forest_regressor_settings: None,
221 |             knn_regressor_settings: None,
222 |             logistic_settings: Some(LogisticRegressionParameters::default()),
223 |             random_forest_classifier_settings: Some(RandomForestClassifierParameters::default()),
224 |             knn_classifier_settings: Some(KNNClassifierParameters::default()),
225 |             svc_settings: Some(SVCParameters::default()),
226 |             decision_tree_classifier_settings: Some(DecisionTreeClassifierParameters::default()),
227 |             gaussian_nb_settings: Some(GaussianNBParameters::default()),
228 |             categorical_nb_settings: Some(CategoricalNBParameters::default()),
229 |         }
230 |     }
231 | 
232 |     /// Load settings from a settings file
233 |     /// ```
234 |     /// # use automl::Settings;
235 |     /// # let settings = Settings::default();
236 |     /// # settings.save("tests/load_those_settings.yaml");
237 |     /// let settings = Settings::new_from_file("tests/load_those_settings.yaml");
238 |     /// # std::fs::remove_file("tests/load_those_settings.yaml");
239 |     /// ```
240 |     #[must_use]
241 |     pub fn new_from_file(file_name: &str) -> Self {
242 |         let mut buf: Vec<u8> = Vec::new();
243 |         std::fs::File::open(file_name)
244 |             .and_then(|mut f| f.read_to_end(&mut buf))
245 |             .expect("Cannot read settings file.");
246 |         serde_yaml::from_slice(&buf).expect("Cannot deserialize settings file.")
247 |     }
248 | 
249 |     /// Save the current settings to a file for later use
250 |     /// ```
251 |     /// # use automl::Settings;
252 |     /// let settings = Settings::default_regression();
253 |     /// settings.save("tests/save_those_settings.yaml");
254 |     /// # std::fs::remove_file("tests/save_those_settings.yaml");
255 |     /// ```
256 |     pub fn save(&self, file_name: &str) {
257 |         let serial = serde_yaml::to_string(&self).expect("Cannot serialize settings.");
258 |         std::fs::File::create(file_name)
259 |             .and_then(|mut f| f.write_all(serial.as_ref()))
260 |             .expect("Cannot write settings to file.");
261 |     }
262 | 
263 |     /// Specify number of folds for cross-validation
264 |     /// ```
265 |     /// # use automl::Settings;
266 |     /// let settings = Settings::default().with_number_of_folds(3);
267 |     /// ```
268 |     #[must_use]
269 |     pub const fn with_number_of_folds(mut self, n: usize) -> Self {
270 |         self.number_of_folds = n;
271 |         self
272 |     }
273 | 
274 |     /// Specify whether or not data should be shuffled
275 |     /// ```
276 |     /// # use automl::Settings;
277 |     /// let settings = Settings::default().shuffle_data(true);
278 |     /// ```
279 |     #[must_use]
280 |     pub const fn shuffle_data(mut self, shuffle: bool) -> Self {
281 |         self.shuffle = shuffle;
282 |         self
283 |     }
284 | 
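    // NOTE (illustrative addition, not in the original file): `save` and
    // `new_from_file` round-trip through YAML, so a tuned configuration can be
    // reused across runs or shared between machines:
    //
    //     let tuned = Settings::default_classification().with_number_of_folds(5);
    //     tuned.save("my_settings.yaml"); // hypothetical file name
    //     let restored = Settings::new_from_file("my_settings.yaml");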
285 |     /// Specify whether or not to be verbose
286 |     /// ```
287 |     /// # use automl::Settings;
288 |     /// let settings = Settings::default().verbose(true);
289 |     /// ```
290 |     #[must_use]
291 |     pub const fn verbose(mut self, verbose: bool) -> Self {
292 |         self.verbose = verbose;
293 |         self
294 |     }
295 | 
296 |     /// Specify what type of preprocessing should be performed
297 |     /// ```
298 |     /// # use automl::Settings;
299 |     /// use automl::settings::PreProcessing;
300 |     /// let settings = Settings::default().with_preprocessing(PreProcessing::AddInteractions);
301 |     /// ```
302 |     #[must_use]
303 |     pub const fn with_preprocessing(mut self, pre: PreProcessing) -> Self {
304 |         self.preprocessing = pre;
305 |         self
306 |     }
307 | 
308 |     /// Specify what type of final model to use
309 |     /// ```
310 |     /// # use automl::Settings;
311 |     /// use automl::settings::FinalModel;
312 |     /// let settings = Settings::default().with_final_model(FinalModel::Best);
313 |     /// ```
314 |     #[must_use]
315 |     pub const fn with_final_model(mut self, approach: FinalModel) -> Self {
316 |         self.final_model_approach = approach;
317 |         self
318 |     }
319 | 
320 |     /// Specify algorithms that shouldn't be included in comparison
321 |     /// ```
322 |     /// # use automl::Settings;
323 |     /// use automl::settings::Algorithm;
324 |     /// let settings = Settings::default().skip(Algorithm::RandomForestRegressor);
325 |     /// ```
326 |     #[must_use]
327 |     pub fn skip(mut self, skip: Algorithm) -> Self {
328 |         self.skiplist.push(skip);
329 |         self
330 |     }
331 | 
332 |     /// Specify only one algorithm to train
333 |     /// ```
334 |     /// # use automl::Settings;
335 |     /// use automl::settings::Algorithm;
336 |     /// let settings = Settings::default().only(Algorithm::RandomForestRegressor);
337 |     /// ```
338 |     #[must_use]
339 |     pub fn only(mut self, only: Algorithm) -> Self {
340 |         self.skiplist = Self::default().skiplist;
341 |         self.skiplist.retain(|&algo| algo != only);
342 |         self
343 |     }
344 | 
345 |     /// Adds a specific sorting function to the settings
346 |     /// ```
347 |     /// # use automl::Settings;
348 |     /// use automl::settings::Metric;
349 |     /// let settings = Settings::default().sorted_by(Metric::RSquared);
350 |     /// ```
351 |     #[must_use]
352 |     pub const fn sorted_by(mut self, sort_by: Metric) -> Self {
353 |         self.sort_by = sort_by;
354 |         self
355 |     }
356 | 
357 |     /// Specify settings for Random Forest Classifier
358 |     /// ```
359 |     /// # use automl::Settings;
360 |     /// use automl::settings::RandomForestClassifierParameters;
361 |     /// let settings = Settings::default()
362 |     ///     .with_random_forest_classifier_settings(RandomForestClassifierParameters::default()
363 |     ///         .with_m(100)
364 |     ///         .with_max_depth(5)
365 |     ///         .with_min_samples_leaf(20)
366 |     ///         .with_n_trees(100)
367 |     ///         .with_min_samples_split(20)
368 |     ///     );
369 |     /// ```
370 |     #[must_use]
371 |     pub const fn with_random_forest_classifier_settings(
372 |         mut self,
373 |         settings: RandomForestClassifierParameters,
374 |     ) -> Self {
375 |         self.random_forest_classifier_settings = Some(settings);
376 |         self
377 |     }
378 | 
379 |     /// Specify settings for logistic regression
380 |     /// ```
381 |     /// # use automl::Settings;
382 |     /// use automl::settings::LogisticRegressionParameters;
383 |     /// let settings = Settings::default()
384 |     ///     .with_logistic_settings(LogisticRegressionParameters::default());
385 |     /// ```
386 |     #[must_use]
387 |     pub const fn with_logistic_settings(
388 |         mut self,
389 |         settings: LogisticRegressionParameters<f32>,
390 |     ) -> Self {
391 |         self.logistic_settings = Some(settings);
392 |         self
393 |     }
394 | 
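    // NOTE (illustrative addition, not in the original file): `only` resets the
    // skiplist from the default (which names every algorithm) and then retains
    // everything except the chosen one, so it overwrites earlier `skip` calls
    // but composes with later ones:
    //
    //     let settings = Settings::default_regression()
    //         .only(Algorithm::RandomForestRegressor) // train exactly one model
    //         .sorted_by(Metric::MeanAbsoluteError);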
395 |     /// Specify settings for support vector classifier
396 |     /// ```
397 |     /// # use automl::Settings;
398 |     /// use automl::settings::{SVCParameters, Kernel};
399 |     /// let settings = Settings::default()
400 |     ///     .with_svc_settings(SVCParameters::default()
401 |     ///         .with_epoch(10)
402 |     ///         .with_tol(1e-10)
403 |     ///         .with_c(1.0)
404 |     ///         .with_kernel(Kernel::Linear)
405 |     ///     );
406 |     /// ```
407 |     #[must_use]
408 |     pub const fn with_svc_settings(mut self, settings: SVCParameters) -> Self {
409 |         self.svc_settings = Some(settings);
410 |         self
411 |     }
412 | 
413 |     /// Specify settings for decision tree classifier
414 |     /// ```
415 |     /// # use automl::Settings;
416 |     /// use automl::settings::DecisionTreeClassifierParameters;
417 |     /// let settings = Settings::default()
418 |     ///     .with_decision_tree_classifier_settings(DecisionTreeClassifierParameters::default()
419 |     ///         .with_min_samples_split(20)
420 |     ///         .with_max_depth(5)
421 |     ///         .with_min_samples_leaf(20)
422 |     ///     );
423 |     /// ```
424 |     #[must_use]
425 |     pub const fn with_decision_tree_classifier_settings(
426 |         mut self,
427 |         settings: DecisionTreeClassifierParameters,
428 |     ) -> Self {
429 |         self.decision_tree_classifier_settings = Some(settings);
430 |         self
431 |     }
432 | 
433 |     /// Specify settings for KNN classifier
434 |     /// ```
435 |     /// # use automl::Settings;
436 |     /// use automl::settings::{KNNClassifierParameters,
437 |     ///     KNNAlgorithmName, KNNWeightFunction, Distance};
438 |     /// let settings = Settings::default()
439 |     ///     .with_knn_classifier_settings(KNNClassifierParameters::default()
440 |     ///         .with_algorithm(KNNAlgorithmName::CoverTree)
441 |     ///         .with_k(3)
442 |     ///         .with_distance(Distance::Euclidean)
443 |     ///         .with_weight(KNNWeightFunction::Uniform)
444 |     ///     );
445 |     /// ```
446 |     #[must_use]
447 |     pub const fn with_knn_classifier_settings(mut self, settings: KNNClassifierParameters) -> Self {
448 |         self.knn_classifier_settings = Some(settings);
449 |         self
450 |     }
451 | 
452 |     /// Specify settings for Gaussian Naive Bayes
453 |     /// ```
454 |     /// # use automl::Settings;
455 |     /// use automl::settings::GaussianNBParameters;
456 |     /// let settings = Settings::default()
457 |     ///     .with_gaussian_nb_settings(GaussianNBParameters::default()
458 |     ///         .with_priors(vec![1.0, 1.0])
459 |     ///     );
460 |     /// ```
461 |     #[allow(clippy::missing_const_for_fn)]
462 |     #[must_use]
463 |     pub fn with_gaussian_nb_settings(mut self, settings: GaussianNBParameters<f32>) -> Self {
464 |         self.gaussian_nb_settings = Some(settings);
465 |         self
466 |     }
467 | 
468 |     /// Specify settings for Categorical Naive Bayes
469 |     /// ```
470 |     /// # use automl::Settings;
471 |     /// use automl::settings::CategoricalNBParameters;
472 |     /// let settings = Settings::default()
473 |     ///     .with_categorical_nb_settings(CategoricalNBParameters::default()
474 |     ///         .with_alpha(1.0)
475 |     ///     );
476 |     /// ```
477 |     #[must_use]
478 |     pub const fn with_categorical_nb_settings(
479 |         mut self,
480 |         settings: CategoricalNBParameters<f32>,
481 |     ) -> Self {
482 |         self.categorical_nb_settings = Some(settings);
483 |         self
484 |     }
485 | 
486 |     /// Specify settings for linear regression
487 |     /// ```
488 |     /// # use automl::Settings;
489 |     /// use automl::settings::{LinearRegressionParameters, LinearRegressionSolverName};
490 |     /// let settings = Settings::default()
491 |     ///     .with_linear_settings(LinearRegressionParameters::default()
492 |     ///         .with_solver(LinearRegressionSolverName::QR)
493 |     ///     );
494 |     /// ```
495 |     #[must_use]
496 |     pub const fn with_linear_settings(mut self, settings: LinearRegressionParameters) -> Self {
497 |         self.linear_settings = Some(settings);
498 |         self
499 |     }
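    // NOTE (illustrative addition, not in the original file): classifier-specific
    // builders compose like any other setting:
    //
    //     let settings = Settings::default_classification()
    //         .with_svc_settings(SVCParameters::default().with_c(10.0))
    //         .with_knn_classifier_settings(KNNClassifierParameters::default().with_k(7));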
500 | 
501 |     /// Specify settings for lasso regression
502 |     /// ```
503 |     /// # use automl::Settings;
504 |     /// use automl::settings::LassoParameters;
505 |     /// let settings = Settings::default()
506 |     ///     .with_lasso_settings(LassoParameters::default()
507 |     ///         .with_alpha(10.0)
508 |     ///         .with_tol(1e-10)
509 |     ///         .with_normalize(true)
510 |     ///         .with_max_iter(10_000)
511 |     ///     );
512 |     /// ```
513 |     #[must_use]
514 |     pub const fn with_lasso_settings(mut self, settings: LassoParameters<f32>) -> Self {
515 |         self.lasso_settings = Some(settings);
516 |         self
517 |     }
518 | 
519 |     /// Specify settings for ridge regression
520 |     /// ```
521 |     /// # use automl::Settings;
522 |     /// use automl::settings::{RidgeRegressionParameters, RidgeRegressionSolverName};
523 |     /// let settings = Settings::default()
524 |     ///     .with_ridge_settings(RidgeRegressionParameters::default()
525 |     ///         .with_alpha(10.0)
526 |     ///         .with_normalize(true)
527 |     ///         .with_solver(RidgeRegressionSolverName::Cholesky)
528 |     ///     );
529 |     /// ```
530 |     #[must_use]
531 |     pub const fn with_ridge_settings(mut self, settings: RidgeRegressionParameters<f32>) -> Self {
532 |         self.ridge_settings = Some(settings);
533 |         self
534 |     }
535 | 
536 |     /// Specify settings for elastic net
537 |     /// ```
538 |     /// # use automl::Settings;
539 |     /// use automl::settings::ElasticNetParameters;
540 |     /// let settings = Settings::default()
541 |     ///     .with_elastic_net_settings(ElasticNetParameters::default()
542 |     ///         .with_tol(1e-10)
543 |     ///         .with_normalize(true)
544 |     ///         .with_alpha(1.0)
545 |     ///         .with_max_iter(10_000)
546 |     ///         .with_l1_ratio(0.5)
547 |     ///     );
548 |     /// ```
549 |     #[must_use]
550 |     pub const fn with_elastic_net_settings(mut self, settings: ElasticNetParameters<f32>) -> Self {
551 |         self.elastic_net_settings = Some(settings);
552 |         self
553 |     }
554 | 
555 |     /// Specify settings for KNN regressor
556 |     /// ```
557 |     /// # use automl::Settings;
558 |     /// use automl::settings::{KNNRegressorParameters,
559 |     ///     KNNAlgorithmName, KNNWeightFunction, Distance};
560 |     /// let settings = Settings::default()
561 |     ///     .with_knn_regressor_settings(KNNRegressorParameters::default()
562 |     ///         .with_algorithm(KNNAlgorithmName::CoverTree)
563 |     ///         .with_k(3)
564 |     ///         .with_distance(Distance::Euclidean)
565 |     ///         .with_weight(KNNWeightFunction::Uniform)
566 |     ///     );
567 |     /// ```
568 |     #[must_use]
569 |     pub const fn with_knn_regressor_settings(mut self, settings: KNNRegressorParameters) -> Self {
570 |         self.knn_regressor_settings = Some(settings);
571 |         self
572 |     }
573 | 
574 |     /// Specify settings for support vector regressor
575 |     /// ```
576 |     /// # use automl::Settings;
577 |     /// use automl::settings::{SVRParameters, Kernel};
578 |     /// let settings = Settings::default()
579 |     ///     .with_svr_settings(SVRParameters::default()
580 |     ///         .with_eps(1e-10)
581 |     ///         .with_tol(1e-10)
582 |     ///         .with_c(1.0)
583 |     ///         .with_kernel(Kernel::Linear)
584 |     ///     );
585 |     /// ```
586 |     #[must_use]
587 |     pub const fn with_svr_settings(mut self, settings: SVRParameters) -> Self {
588 |         self.svr_settings = Some(settings);
589 |         self
590 |     }
591 | 
592 |     /// Specify settings for random forest
593 |     /// ```
594 |     /// # use automl::Settings;
595 |     /// use automl::settings::RandomForestRegressorParameters;
596 |     /// let settings = Settings::default()
597 |     ///     .with_random_forest_regressor_settings(RandomForestRegressorParameters::default()
598 |     ///         .with_m(100)
599 |     ///         .with_max_depth(5)
600 |     ///         .with_min_samples_leaf(20)
601 |     ///         .with_n_trees(100)
602 |     ///         .with_min_samples_split(20)
603 |     ///     );
604 |     /// ```
605 |     #[must_use]
606 |     pub const fn with_random_forest_regressor_settings(
607 |         mut self,
608 |         settings: RandomForestRegressorParameters,
609 |     ) -> Self {
610 |         self.random_forest_regressor_settings = Some(settings);
611 |         self
612 |     }
613 | 
614 |     /// Specify settings for decision tree
615 |     /// ```
616 |     /// # use automl::Settings;
617 |     /// use automl::settings::DecisionTreeRegressorParameters;
618 |     /// let settings = Settings::default()
619 |     ///     .with_decision_tree_regressor_settings(DecisionTreeRegressorParameters::default()
620 |     ///         .with_min_samples_split(20)
621 |     ///         .with_max_depth(5)
622 |     ///         .with_min_samples_leaf(20)
623 |     ///     );
624 |     /// ```
625 |     #[must_use]
626 |     pub const fn with_decision_tree_regressor_settings(
627 |         mut self,
628 |         settings: DecisionTreeRegressorParameters,
629 |     ) -> Self {
630 |         self.decision_tree_regressor_settings = Some(settings);
631 |         self
632 |     }
633 | }
634 | 
635 | impl Display for Settings {
636 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
637 |         // Prep new table
638 |         let mut table = Table::new();
639 | 
640 |         // Get list of algorithms to skip
641 |         let mut skiplist = String::new();
642 |         if self.skiplist.is_empty() {
643 |             skiplist.push_str("None ");
644 |         } else {
645 |             for algorithm_to_skip in &self.skiplist {
646 |                 skiplist.push_str(&format!("{algorithm_to_skip}\n"));
647 |             }
648 |         }
649 | 
650 |         // Build out the table
651 |         table
652 |             .load_preset(UTF8_FULL)
653 |             .apply_modifier(UTF8_SOLID_INNER_BORDERS)
654 |             .set_header(vec![
655 |                 Cell::new("Settings").add_attribute(Attribute::Bold),
656 |                 Cell::new("Value").add_attribute(Attribute::Bold),
657 |             ])
658 |             .add_row(vec![Cell::new("General").add_attribute(Attribute::Italic)])
659 |             .add_row(vec![" Model Type", &*format!("{}", self.model_type)])
660 |             .add_row(vec![" Verbose", &*format!("{}", self.verbose)])
661 |             .add_row(vec![" Sorting Metric", &*format!("{}", self.sort_by)])
662 |             .add_row(vec![" Shuffle Data", &*format!("{}", self.shuffle)])
663 |             .add_row(vec![
664 |                 " Number of CV Folds",
665 |                 &*format!("{}", self.number_of_folds),
666 |             ])
667 |             .add_row(vec![
668 |                 " Pre-Processing",
669 |                 &*format!("{}", self.preprocessing),
670 |             ])
671 |             .add_row(vec![
672 |                 " Skipped Algorithms",
673 |                 &skiplist[0..skiplist.len() - 1],
674 |             ]);
675 |         if !self.skiplist.contains(&Algorithm::Linear) {
676 |             table
677 |                 .add_row(vec![
678 |                     Cell::new(Algorithm::Linear).add_attribute(Attribute::Italic)
679 |                 ])
680 |                 .add_row(vec![
681 |                     " Solver",
682 |                     match self.linear_settings.as_ref().unwrap().solver {
683 |                         LinearRegressionSolverName::QR => "QR",
684 |                         LinearRegressionSolverName::SVD => "SVD",
685 |                     },
686 |                 ]);
687 |         }
688 |         if !self.skiplist.contains(&Algorithm::Ridge) {
689 |             table
690 |                 .add_row(vec![
691 |                     Cell::new(Algorithm::Ridge).add_attribute(Attribute::Italic)
692 |                 ])
693 |                 .add_row(vec![
694 |                     " Solver",
695 |                     match self.ridge_settings.as_ref().unwrap().solver {
696 |                         RidgeRegressionSolverName::Cholesky => "Cholesky",
697 |                         RidgeRegressionSolverName::SVD => "SVD",
698 |                     },
699 |                 ])
700 |                 .add_row(vec![
701 |                     " Alpha",
702 |                     &*format!("{}", self.ridge_settings.as_ref().unwrap().alpha),
703 |                 ])
704 |                 .add_row(vec![
705 |                     " Normalize",
706 |                     &*format!("{}", self.ridge_settings.as_ref().unwrap().normalize),
707 |                 ]);
708 |         }
709 | 
710 |         if !self.skiplist.contains(&Algorithm::Lasso) {
711 |             table
712 |                 .add_row(vec![
713 |                     Cell::new(Algorithm::Lasso).add_attribute(Attribute::Italic)
714 |                 ])
715 |                 .add_row(vec![
716 |                     " Alpha",
717 |                     &*format!("{}", self.lasso_settings.as_ref().unwrap().alpha),
718 |                 ])
719 |                 .add_row(vec![
720 |                     " Normalize",
721 |                     &*format!("{}", self.lasso_settings.as_ref().unwrap().normalize),
722 |                 ])
723 |                 .add_row(vec![
724 |                     " Maximum Iterations",
725 |                     &*format!("{}", self.lasso_settings.as_ref().unwrap().max_iter),
726 |                 ])
727 |                 .add_row(vec![
728 |                     " Tolerance",
729 |                     &*format!("{}", self.lasso_settings.as_ref().unwrap().tol),
730 |                 ]);
731 |         }
732 | 
733 |         if !self.skiplist.contains(&Algorithm::ElasticNet) {
734 |             table
735 |                 .add_row(vec![
736 |                     Cell::new(Algorithm::ElasticNet).add_attribute(Attribute::Italic)
737 |                 ])
738 |                 .add_row(vec![
739 |                     " Alpha",
740 |                     &*format!("{}", self.elastic_net_settings.as_ref().unwrap().alpha),
741 |                 ])
742 |                 .add_row(vec![
743 |                     " Normalize",
744 |                     &*format!("{}", self.elastic_net_settings.as_ref().unwrap().normalize),
745 |                 ])
746 |                 .add_row(vec![
747 |                     " Maximum Iterations",
748 |                     &*format!("{}", self.elastic_net_settings.as_ref().unwrap().max_iter),
749 |                 ])
750 |                 .add_row(vec![
751 |                     " Tolerance",
752 |                     &*format!("{}", self.elastic_net_settings.as_ref().unwrap().tol),
753 |                 ])
754 |                 .add_row(vec![
755 |                     " L1 Ratio",
756 |                     &*format!("{}", self.elastic_net_settings.as_ref().unwrap().l1_ratio),
757 |                 ]);
758 |         }
759 | 
760 |         if !self.skiplist.contains(&Algorithm::DecisionTreeRegressor) {
761 |             table
762 |                 .add_row(vec![
763 |                     Cell::new(Algorithm::DecisionTreeRegressor).add_attribute(Attribute::Italic)
764 |                 ])
765 |                 .add_row(vec![
766 |                     " Max Depth",
767 |                     &*print_option(
768 |                         self.decision_tree_regressor_settings
769 |                             .as_ref()
770 |                             .unwrap()
771 |                             .max_depth,
772 |                     ),
773 |                 ])
774 |                 .add_row(vec![
775 |                     " Min samples for leaf",
776 |                     &*format!(
777 |                         "{}",
778 |                         self.decision_tree_regressor_settings
779 |                             .as_ref()
780 |                             .unwrap()
781 |                             .min_samples_leaf
782 |                     ),
783 |                 ])
784 |                 .add_row(vec![
785 |                     " Min samples for split",
786 |                     &*format!(
787 |                         "{}",
788 |                         self.decision_tree_regressor_settings
789 |                             .as_ref()
790 |                             .unwrap()
791 |                             .min_samples_split
792 |                     ),
793 |                 ]);
794 |         }
795 | 
796 |         if !self.skiplist.contains(&Algorithm::RandomForestRegressor) {
797 |             table
798 |                 .add_row(vec![
799 |                     Cell::new(Algorithm::RandomForestRegressor).add_attribute(Attribute::Italic)
800 |                 ])
801 |                 .add_row(vec![
802 |                     " Max Depth",
803 |                     &*print_option(
804 |                         self.random_forest_regressor_settings
805 |                             .as_ref()
806 |                             .unwrap()
807 |                             .max_depth,
808 |                     ),
809 |                 ])
810 |                 .add_row(vec![
811 |                     " Min samples for leaf",
812 |                     &*format!(
813 |                         "{}",
814 |                         self.random_forest_regressor_settings
815 |                             .as_ref()
816 |                             .unwrap()
817 |                             .min_samples_leaf
818 |                     ),
819 |                 ])
820 |                 .add_row(vec![
821 |                     " Min samples for split",
822 |                     &*format!(
823 |                         "{}",
824 |                         self.random_forest_regressor_settings
825 |                             .as_ref()
826 |                             .unwrap()
827 |                             .min_samples_split
828 |                     ),
829 |                 ])
830 |                 .add_row(vec![
831 |                     " Number of trees",
832 |                     &*format!(
833 |                         "{}",
834 |                         self.random_forest_regressor_settings
835 |                             .as_ref()
836 |                             .unwrap()
837 |                             .n_trees
838 |                     ),
839 |                 ])
840 |                 .add_row(vec![
841 |                     " Number of split candidates",
842 |                     &*print_option(self.random_forest_regressor_settings.as_ref().unwrap().m),
843 |                 ]);
844 |         }
845 | 
846 |         if !self.skiplist.contains(&Algorithm::KNNRegressor) {
847 |             table
848 |                 .add_row(vec![
849 |                     Cell::new(Algorithm::KNNRegressor).add_attribute(Attribute::Italic)
850 |                 ])
851 |                 .add_row(vec![
852 |                     " Number of neighbors",
853 |                     &*format!("{}", self.knn_regressor_settings.as_ref().unwrap().k),
854 |                 ])
855 |                 .add_row(vec![
856 |                     " Search algorithm",
857 |                     &print_knn_search_algorithm(
858 |                         &self.knn_regressor_settings.as_ref().unwrap().algorithm,
859 |                     ),
860 |                 ])
861 |                 .add_row(vec![
862 |                     " Weighting function",
863 |                     &print_knn_weight_function(
864 |                         &self.knn_regressor_settings.as_ref().unwrap().weight,
865 |                     ),
866 |                 ])
867 |                 .add_row(vec![
868 |                     " Distance function",
869 |                     &*format!(
870 |                         "{}",
871 |                         &self.knn_regressor_settings.as_ref().unwrap().distance
872 |                     ),
873 |                 ]);
874 |         }
875 | 
876 |         if !self.skiplist.contains(&Algorithm::SVR) {
877 |             table
878 |                 .add_row(vec![
879 |                     Cell::new(Algorithm::SVR).add_attribute(Attribute::Italic)
880 |                 ])
881 |                 .add_row(vec![
882 |                     " Regularization parameter",
883 |                     &*format!("{}", self.svr_settings.as_ref().unwrap().c),
884 |                 ])
885 |                 .add_row(vec![
886 |                     " Tolerance",
887 |                     &*format!("{}", self.svr_settings.as_ref().unwrap().tol),
888 |                 ])
889 |                 .add_row(vec![
890 |                     " Epsilon",
891 |                     &*format!("{}", self.svr_settings.as_ref().unwrap().eps),
892 |                 ])
893 |                 .add_row(vec![
894 |                     " Kernel",
895 |                     &*format!("{}", self.svr_settings.as_ref().unwrap().kernel),
896 |                 ]);
897 |         }
898 | 
899 |         if !self.skiplist.contains(&Algorithm::LogisticRegression) {
900 |             table
901 |                 .add_row(vec![
902 |                     Cell::new(Algorithm::LogisticRegression).add_attribute(Attribute::Italic)
903 |                 ])
904 |                 .add_row(vec![" N/A", "N/A"]);
905 |         }
906 | 
907 |         if !self.skiplist.contains(&Algorithm::RandomForestClassifier) {
908 |             table
909 |                 .add_row(vec![
910 |                     Cell::new(Algorithm::RandomForestClassifier).add_attribute(Attribute::Italic)
911 |                 ])
912 |                 .add_row(vec![
913 |                     " Split Criterion",
914 |                     match self
915 |                         .random_forest_classifier_settings
916 |                         .as_ref()
917 |                         .unwrap()
918 |                         .criterion
919 |                     {
920 |                         SplitCriterion::Gini => "Gini",
921 |                         SplitCriterion::Entropy => "Entropy",
922 |                         SplitCriterion::ClassificationError => "Classification Error",
923 |                     },
924 |                 ])
925 |                 .add_row(vec![
926 |                     " Max Depth",
927 |                     &*print_option(
928 |                         self.random_forest_classifier_settings
929 |                             .as_ref()
930 |                             .unwrap()
931 |                             .max_depth,
932 |                     ),
933 |                 ])
934 |                 .add_row(vec![
935 |                     " Min samples for leaf",
936 |                     &*format!(
937 |                         "{}",
938 |                         self.random_forest_classifier_settings
939 |                             .as_ref()
940 |                             .unwrap()
941 |                             .min_samples_leaf
942 |                     ),
943 |                 ])
944 |                 .add_row(vec![
945 |                     " Min samples for split",
946 |                     &*format!(
947 |                         "{}",
948 |                         self.random_forest_classifier_settings
949 |                             .as_ref()
950 |                             .unwrap()
951 |                             .min_samples_split
952 |                     ),
953 |                 ])
954 |                 .add_row(vec![
955 |                     " Number of trees",
956 |                     &*format!(
957 |                         "{}",
958 |                         self.random_forest_classifier_settings
959 |                             .as_ref()
960 |                             .unwrap()
961 |                             .n_trees
962 |                     ),
963 |                 ])
964 |                 .add_row(vec![
965 |                     " Number of split candidates",
966 |                     &*print_option(self.random_forest_classifier_settings.as_ref().unwrap().m),
967 |                 ]);
968 |         }
969 | 
970 |         if !self.skiplist.contains(&Algorithm::KNNClassifier) {
971 |             table
972 |                 .add_row(vec![
973 |                     Cell::new(Algorithm::KNNClassifier).add_attribute(Attribute::Italic)
974 |                 ])
975 |                 .add_row(vec![
976 |                     " Number of neighbors",
977 |                     &*format!("{}", self.knn_classifier_settings.as_ref().unwrap().k),
978 |                 ])
979 |                 .add_row(vec![
980 |                     " Search algorithm",
981 |                     &print_knn_search_algorithm(
982 |                         &self.knn_classifier_settings.as_ref().unwrap().algorithm,
983 |                     ),
984 |                 ])
985 |                 .add_row(vec![
986 |                     " Weighting function",
987 |                     &print_knn_weight_function(
988 |                         &self.knn_classifier_settings.as_ref().unwrap().weight,
989 |                     ),
990 |                 ])
991 |                 .add_row(vec![
992 |                     " Distance function",
993 |                     &*format!(
994 |                         "{}",
995 |                         &self.knn_classifier_settings.as_ref().unwrap().distance
996 |                     ),
997 |                 ]);
998 |         }
999 | 
1000 |         if !self.skiplist.contains(&Algorithm::SVC) {
1001 |             table
1002 |                 .add_row(vec![
1003 |                     Cell::new(Algorithm::SVC).add_attribute(Attribute::Italic)
1004 |                 ])
1005 |                 .add_row(vec![
1006 |                     " Regularization parameter",
1007 |                     &*format!("{}", self.svc_settings.as_ref().unwrap().c),
1008 |                 ])
1009 |                 .add_row(vec![
1010 |                     " Tolerance",
1011 |                     &*format!("{}", self.svc_settings.as_ref().unwrap().tol),
1012 |                 ])
1013 |                 .add_row(vec![
1014 |                     " Epoch",
1015 |                     &*format!("{}", self.svc_settings.as_ref().unwrap().epoch),
1016 |                 ])
1017 |                 .add_row(vec![
1018 |                     " Kernel",
1019 |                     &*format!("{}", self.svc_settings.as_ref().unwrap().kernel),
1020 |                 ]);
1021 |         }
1022 | 
1023 |         if !self.skiplist.contains(&Algorithm::DecisionTreeClassifier) {
1024 |             table
1025 |                 .add_row(vec![
1026 |                     Cell::new(Algorithm::DecisionTreeClassifier).add_attribute(Attribute::Italic)
1027 |                 ])
1028 |                 .add_row(vec![
1029 |                     " Split Criterion",
1030 |                     match self
1031 |                         .decision_tree_classifier_settings
1032 |                         .as_ref()
1033 |                         .unwrap()
1034 |                         .criterion
1035 |                     {
1036 |                         SplitCriterion::Gini => "Gini",
1037 |                         SplitCriterion::Entropy => "Entropy",
1038 |                         SplitCriterion::ClassificationError => "Classification Error",
1039 |                     },
1040 |                 ])
1041 |                 .add_row(vec![
1042 |                     " Max Depth",
1043 |                     &*print_option(
1044 |                         self.decision_tree_classifier_settings
1045 |                             .as_ref()
1046 |                             .unwrap()
1047 |                             .max_depth,
1048 |                     ),
1049 |                 ])
1050 |                 .add_row(vec![
1051 |                     " Min samples for leaf",
1052 |                     &*format!(
1053 |                         "{}",
1054 |                         self.decision_tree_classifier_settings
1055 |                             .as_ref()
1056 |                             .unwrap()
1057 |                             .min_samples_leaf
1058 |                     ),
1059 |                 ])
1060 |                 .add_row(vec![
1061 |                     " Min samples for split",
1062 |                     &*format!(
1063 |                         "{}",
1064 |                         self.decision_tree_classifier_settings
1065 |                             .as_ref()
1066 |                             .unwrap()
1067 |                             .min_samples_split
1068 |                     ),
1069 |                 ]);
1070 |         }
1071 | 
1072 |         if !self.skiplist.contains(&Algorithm::CategoricalNaiveBayes) {
1073 |             table
1074 |                 .add_row(vec![
1075 |                     Cell::new(Algorithm::CategoricalNaiveBayes).add_attribute(Attribute::Italic)
1076 |                 ])
1077 |                 .add_row(vec![
1078 |                     " Smoothing parameter",
1079 |                     &*format!("{}", self.categorical_nb_settings.as_ref().unwrap().alpha),
1080 |                 ]);
1081 |         }
1082 | 
1083 |         if !self.skiplist.contains(&Algorithm::GaussianNaiveBayes) {
1084 |             table
1085 |                 .add_row(vec![
1086 |                     Cell::new(Algorithm::GaussianNaiveBayes).add_attribute(Attribute::Italic)
1087 |                 ])
1088 |                 .add_row(vec![
1089 |                     " Priors",
1090 |                     &*debug_option(self.gaussian_nb_settings.as_ref().unwrap().clone().priors),
1091 |                 ]);
1092 |         }
1093 | 
1094 |         writeln!(f, "{table}")
1095 |     }
1096 | }
1097 | 
1098 | /// Model type to train
1099 | #[derive(serde::Serialize, serde::Deserialize)]
1100 | enum ModelType {
1101 |     /// No model type specified
1102 |     None,
1103 |     /// Regression model
1104 |     Regression,
1105 |     /// Classification model
1106 |     Classification,
1107 | }
1108 | 
1109 | impl Display for ModelType {
1110 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1111 |         match self {
1112 |             Self::None => write!(f, "None"),
1113 |             Self::Regression => write!(f, "Regression"),
1114 |             Self::Classification => write!(f, "Classification"),
1115 |         }
1116 |     }
1117 | }
1118 | 
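// NOTE (illustrative addition, not in the original file): the Display impl above
// makes printing a settings summary a one-liner (this is what the
// `print_settings` example in the repository can exercise):
//
//     let settings = Settings::default_regression();
//     println!("{settings}"); // renders the comfy-table summary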
--------------------------------------------------------------------------------
/src/settings/svc_parameters.rs:
--------------------------------------------------------------------------------
1 | //! Support Vector Classification parameters
2 | 
3 | pub use crate::utils::Kernel;
4 | 
5 | /// Parameters for support vector classification
6 | #[derive(serde::Serialize, serde::Deserialize)]
7 | pub struct SVCParameters {
8 |     /// Number of epochs to use in the epsilon-SVC model
9 |     pub(crate) epoch: usize,
10 |     /// Regularization penalty to use with the SVC model
11 |     pub(crate) c: f32,
12 |     /// Convergence tolerance to use with the SVC model
13 |     pub(crate) tol: f32,
14 |     /// Kernel to use with the SVC model
15 |     pub(crate) kernel: Kernel,
16 | }
17 | 
18 | impl SVCParameters {
19 |     /// Define the number of epochs to use in the epsilon-SVC model.
20 |     #[must_use]
21 |     pub const fn with_epoch(mut self, epoch: usize) -> Self {
22 |         self.epoch = epoch;
23 |         self
24 |     }
25 | 
26 |     /// Define the regularization penalty to use with the SVC model
27 |     #[must_use]
28 |     pub const fn with_c(mut self, c: f32) -> Self {
29 |         self.c = c;
30 |         self
31 |     }
32 | 
33 |     /// Define the convergence tolerance to use with the SVC model
34 |     #[must_use]
35 |     pub const fn with_tol(mut self, tol: f32) -> Self {
36 |         self.tol = tol;
37 |         self
38 |     }
39 | 
40 |     /// Define which kernel to use with the SVC model
41 |     #[must_use]
42 |     pub const fn with_kernel(mut self, kernel: Kernel) -> Self {
43 |         self.kernel = kernel;
44 |         self
45 |     }
46 | }
47 | 
48 | impl Default for SVCParameters {
49 |     fn default() -> Self {
50 |         Self {
51 |             epoch: 2,
52 |             c: 1.0,
53 |             tol: 1e-3,
54 |             kernel: Kernel::Linear,
55 |         }
56 |     }
57 | }
58 | 
--------------------------------------------------------------------------------
/src/settings/svr_parameters.rs:
--------------------------------------------------------------------------------
1 | //! Support Vector Regression parameters
2 | 
3 | pub use crate::utils::Kernel;
4 | 
5 | /// Parameters for support vector regression
6 | #[derive(serde::Serialize, serde::Deserialize)]
7 | pub struct SVRParameters {
8 |     /// Epsilon in the epsilon-SVR model.
9 |     pub(crate) eps: f32,
10 |     /// Regularization parameter.
11 |     pub(crate) c: f32,
12 |     /// Tolerance for stopping criterion.
13 |     pub(crate) tol: f32,
14 |     /// Kernel to use for the SVR model
15 |     pub(crate) kernel: Kernel,
16 | }
17 | 
18 | impl SVRParameters {
19 |     /// Define the value of epsilon to use in the epsilon-SVR model.
20 |     #[must_use]
21 |     pub const fn with_eps(mut self, eps: f32) -> Self {
22 |         self.eps = eps;
23 |         self
24 |     }
25 | 
26 |     /// Define the regularization penalty to use with the SVR model
27 |     #[must_use]
28 |     pub const fn with_c(mut self, c: f32) -> Self {
29 |         self.c = c;
30 |         self
31 |     }
32 | 
33 |     /// Define the convergence tolerance to use with the SVR model
34 |     #[must_use]
35 |     pub const fn with_tol(mut self, tol: f32) -> Self {
36 |         self.tol = tol;
37 |         self
38 |     }
39 | 
40 |     /// Define which kernel to use with the SVR model
41 |     #[must_use]
42 |     pub const fn with_kernel(mut self, kernel: Kernel) -> Self {
43 |         self.kernel = kernel;
44 |         self
45 |     }
46 | }
47 | 
48 | impl Default for SVRParameters {
49 |     fn default() -> Self {
50 |         Self {
51 |             eps: 0.1,
52 |             c: 1.0,
53 |             tol: 1e-3,
54 |             kernel: Kernel::Linear,
55 |         }
56 |     }
57 | }
58 | 
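// NOTE (illustrative addition, not in the original file): the two parameter
// structs above are deliberately parallel; only `epoch` (SVC) vs. `eps` (SVR)
// differs. A tighter-fitting SVR, for instance, narrows the epsilon tube and
// raises the regularization parameter:
//
//     let params = SVRParameters::default().with_eps(0.01).with_c(10.0);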
--------------------------------------------------------------------------------
/src/utils.rs:
--------------------------------------------------------------------------------
1 | //! Utility functions for the crate.
2 | 
3 | use smartcore::{algorithm::neighbour::KNNAlgorithmName, neighbors::KNNWeightFunction};
4 | use std::fmt::{Debug, Display, Formatter};
5 | 
6 | /// Convert an Option to a String for printing in display mode.
7 | pub fn print_option<T: Display>(x: Option<T>) -> String {
8 |     x.map_or_else(|| "None".to_string(), |y| format!("{y}"))
9 | }
10 | 
11 | /// Convert an Option to a String for printing in debug mode.
12 | pub fn debug_option<T: Debug>(x: Option<T>) -> String {
13 |     x.map_or_else(|| "None".to_string(), |y| format!("{y:#?}"))
14 | }
15 | 
16 | /// Get the name for a knn weight function.
17 | pub fn print_knn_weight_function(f: &KNNWeightFunction) -> String {
18 |     match f {
19 |         KNNWeightFunction::Uniform => "Uniform".to_string(),
20 |         KNNWeightFunction::Distance => "Distance".to_string(),
21 |     }
22 | }
23 | 
24 | /// Get the name for a knn search algorithm.
25 | pub fn print_knn_search_algorithm(a: &KNNAlgorithmName) -> String {
26 |     match a {
27 |         KNNAlgorithmName::LinearSearch => "Linear Search".to_string(),
28 |         KNNAlgorithmName::CoverTree => "Cover Tree".to_string(),
29 |     }
30 | }
31 | 
32 | /// Kernel options for use with support vector machines
33 | #[derive(serde::Serialize, serde::Deserialize)]
34 | pub enum Kernel {
35 |     /// Linear Kernel
36 |     Linear,
37 | 
38 |     /// Polynomial kernel
39 |     Polynomial(f32, f32, f32),
40 | 
41 |     /// Radial basis function kernel
42 |     RBF(f32),
43 | 
44 |     /// Sigmoid kernel
45 |     Sigmoid(f32, f32),
46 | }
47 | 
48 | impl Display for Kernel {
49 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
50 |         match self {
51 |             Self::Linear => write!(f, "Linear"),
52 |             Self::Polynomial(degree, gamma, coef) => write!(
53 |                 f,
54 |                 "Polynomial\n degree = {degree}\n gamma = {gamma}\n coef = {coef}"
55 |             ),
56 |             Self::RBF(gamma) => write!(f, "RBF\n gamma = {gamma}"),
57 |             Self::Sigmoid(gamma, coef) => {
58 |                 write!(f, "Sigmoid\n gamma = {gamma}\n coef = {coef}")
59 |             }
60 |         }
61 |     }
62 | }
63 | 
64 | /// Distance metrics
65 | #[derive(serde::Serialize, serde::Deserialize)]
66 | pub enum Distance {
67 |     /// Euclidean distance
68 |     Euclidean,
69 | 
70 |     /// Manhattan distance
71 |     Manhattan,
72 | 
73 |     /// Minkowski distance, parameterized by p
74 |     Minkowski(u16),
75 | 
76 |     /// Mahalanobis distance
77 |     Mahalanobis,
78 | 
79 |     /// Hamming distance
80 |     Hamming,
81 | }
82 | 
83 | impl Display for Distance {
84 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
85 |         match self {
86 |             Self::Euclidean => write!(f, "Euclidean"),
87 |             Self::Manhattan => write!(f, "Manhattan"),
88 |             Self::Minkowski(n) => write!(f, "Minkowski\n p = {n}"),
89 |             Self::Mahalanobis => write!(f, "Mahalanobis"),
90 |             Self::Hamming => write!(f, "Hamming"),
91 |         }
92 |     }
93 | }
94 | 
95 | /// Function to do element-wise multiplication of two vectors
96 | pub fn elementwise_multiply(v1: &[f32], v2: &[f32]) -> Vec<f32> {
97 |     v1.iter().zip(v2).map(|(&i1, &i2)| i1 * i2).collect()
98 | }
99 | 
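// NOTE (illustrative addition, not in the original file): a worked example of
// the helper above:
//
//     assert_eq!(elementwise_multiply(&[1.0, 2.0], &[3.0, 4.0]), vec![3.0, 8.0]);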

100 | #[cfg(any(feature = "csv"))]
101 | use polars::prelude::{CsvReader, DataFrame, PolarsError, SerReader};
102 | 
103 | #[cfg(any(feature = "csv"))]
104 | /// Read and validate a csv file or URL into a polars `DataFrame`.
105 | pub fn validate_and_read<P>(file_path: P) -> DataFrame
106 | where
107 |     P: AsRef<std::path::Path>,
108 | {
109 |     let file_path_as_str = file_path.as_ref().to_str().unwrap();
110 | 
111 |     CsvReader::from_path(file_path_as_str).map_or_else(
112 |         |_| {
113 |             if url::Url::parse(file_path_as_str).is_ok() {
114 |                 let file_contents = minreq::get(file_path_as_str)
115 |                     .send()
116 |                     .expect("Could not open URL");
117 |                 let temp = temp_file::with_contents(file_contents.as_bytes());
118 |                 validate_and_read(temp.path().to_str().unwrap())
119 |             } else {
120 |                 panic!("The string {file_path_as_str} is not a valid URL or file path.")
121 |             }
122 |         },
123 |         |csv| {
124 |             csv.infer_schema(Some(10))
125 |                 .has_header(
126 |                     csv_sniffer::Sniffer::new()
127 |                         .sniff_path(file_path_as_str)
128 |                         .expect("Cannot sniff file")
129 |                         .dialect
130 |                         .header
131 |                         .has_header_row,
132 |                 )
133 |                 .finish()
134 |                 .expect("Cannot read file as CSV")
135 |                 .drop_nulls(None)
136 |                 .expect("Cannot remove null values")
137 |                 .convert_to_float()
138 |                 .expect("Cannot convert types")
139 |         },
140 |     )
141 | }
142 | 
143 | /// Trait to convert to a polars `DataFrame`.
144 | #[cfg(any(feature = "csv"))]
145 | trait Cleanup {
146 |     /// Convert to a polars `DataFrame` with all columns of type float.
147 |     fn convert_to_float(self) -> Result<DataFrame, PolarsError>;
148 | }
149 | 
150 | #[cfg(any(feature = "csv"))]
151 | impl Cleanup for DataFrame {
152 |     #[allow(unused_mut)]
153 |     fn convert_to_float(mut self) -> Result<DataFrame, PolarsError> {
154 |         // Work in progress
155 |         // for field in self.schema().fields() {
156 |         //     let name = field.name();
157 |         //     if field.data_type().to_string() == "str" {
158 |         //         let ca = self.column(name).unwrap().utf8().unwrap();
159 |         //         let vec_str: Vec<&str> = ca.into_no_null_iter().collect();
160 |         //         let mut unique = vec_str.clone();
161 |         //         unique.sort();
162 |         //         unique.dedup();
163 |         //         let mut new_encoding = vec![0; 0];
164 |         //         if unique.len() == vec_str.len() || unique.len() == 1 {
165 |         //             self.drop_in_place(name);
166 |         //         } else {
167 |         //             vec_str.into_iter().for_each(|x| {
168 |         //                 new_encoding.push(unique.iter().position(|&y| y == x).unwrap() as u64)
169 |         //             });
170 |         //             self.with_column(Series::new(name, &new_encoding));
171 |         //         }
172 |         //     }
173 |         // }
174 | 
175 |         Ok(self)
176 |     }
177 | }
178 | 
--------------------------------------------------------------------------------
/tests/classification.rs:
--------------------------------------------------------------------------------
1 | #[cfg(test)]
2 | mod classification_tests {
3 |     use automl::{settings::*, *};
4 |     use smartcore::dataset::breast_cancer::load_dataset;
5 | 
6 |     #[test]
7 |     #[cfg(feature = "csv")]
8 |     fn test_new_from_csv() {
9 |         let file_name = "data/breast_cancer.csv";
10 | 
11 |         // Set up the classifier settings and load data
12 |         let settings = Settings::default_classification().with_number_of_folds(2);
13 | 
14 |         let mut classifier = SupervisedModel::new((file_name, 30), settings);
15 | 
16 |         // Compare models
17 |         classifier.train();
18 | 
19 |         // Try to predict something
20 |         classifier.predict(vec![vec![5.0_f32; 30]; 10]);
21 |         classifier.predict("data/breast_cancer_without_target.csv");
22 |         #[cfg(feature = "nd")]
23 |         classifier.predict(ndarray::Array2::from_shape_vec((10, 30), vec![5.0; 300]).unwrap());
24 |     }
25 | 
26 |     #[test]
27 |     fn test_add_interactions_preprocessing() {
28 |         let settings =
29 |             Settings::default_classification().with_preprocessing(PreProcessing::AddInteractions);
30 |         test_from_settings(settings);
31 |     }
32 | 
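    // NOTE (illustrative addition, not in the original suite): the same harness
    // accepts any preprocessing variant, e.g. a PCA-reduced feature space:
    //
    //     let settings = Settings::default_classification()
    //         .with_preprocessing(PreProcessing::ReplaceWithPCA { number_of_components: 5 });
    //     test_from_settings(settings);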
33 |     #[test]
34 |     fn test_add_polynomial_preprocessing() {
35 |         let settings = Settings::default_classification()
36 |             .with_preprocessing(PreProcessing::AddPolynomial { order: 2 });
37 |         test_from_settings(settings);
38 |     }
39 | 
40 |     #[test]
41 |     fn test_blending() {
42 |         let settings = Settings::default_classification().with_final_model(FinalModel::Blending {
43 |             algorithm: Algorithm::LogisticRegression,
44 |             meta_training_fraction: 0.15,
45 |             meta_testing_fraction: 0.15,
46 |         });
47 |         test_from_settings(settings);
48 |     }
49 | 
50 |     fn test_from_settings(settings: Settings) {
51 |         // Check training
52 |         let dataset = load_dataset();
53 | 
54 |         // Set up the classifier settings and load data
55 |         let mut classifier = SupervisedModel::new(dataset, settings);
56 | 
57 |         // Compare models
58 |         classifier.train();
59 | 
60 |         // Try to predict something
61 |         classifier.predict(vec![vec![5.0_f32; 30]; 10]);
62 |     }
63 | }
64 | 
--------------------------------------------------------------------------------
/tests/new_from_dataset.rs:
--------------------------------------------------------------------------------
1 | #[cfg(test)]
2 | mod new_from_dataset {
3 |     use automl::{settings::*, *};
4 |     use smartcore::dataset::breast_cancer;
5 |     use smartcore::dataset::diabetes;
6 | 
7 |     #[test]
8 |     fn classification() {
9 |         // Make a model
10 |         let mut classifier = SupervisedModel::new(
11 |             breast_cancer::load_dataset(),
12 |             Settings::default_classification(),
13 |         );
14 | 
15 |         // Compare models
16 |         classifier.train();
17 | 
18 |         // Try to predict something from a vector
19 |         classifier.predict(vec![vec![5.0_f32; 30]; 10]);
20 | 
21 |         // Try to predict something from ndarray
22 |         #[cfg(feature = "nd")]
23 |         classifier.predict(ndarray::Array2::from_shape_vec((10, 30), vec![5.0; 300]).unwrap());
24 | 
25 |         // Try to predict something from a csv
26 |         #[cfg(feature = "csv")]
27 |         classifier.predict("data/breast_cancer_without_target.csv");
28 |     }
29 | 
30 |     #[test]
31 |     fn regression() {
32 |         // Make a model
33 |         let mut regressor =
34 |             SupervisedModel::new(diabetes::load_dataset(), Settings::default_regression());
35 | 
36 |         // Compare models
37 |         regressor.train();
38 | 
39 |         // Try to predict something from a vector
40 |         regressor.predict(vec![vec![5.0_f32; 10]; 10]);
41 | 
42 |         // Try to predict something from ndarray
43 |         #[cfg(feature = "nd")]
44 |         regressor.predict(ndarray::Array2::from_shape_vec((10, 10), vec![5.0; 100]).unwrap());
45 | 
46 |         // Try to predict something from a csv
47 |         #[cfg(feature = "csv")]
48 |         regressor.predict("data/diabetes_without_target.csv");
49 |     }
50 | }
51 | 
--------------------------------------------------------------------------------
/tests/regression.rs:
--------------------------------------------------------------------------------
1 | #[cfg(test)]
2 | mod regression_tests {
3 |     use automl::{settings::*, *};
4 |     use smartcore::dataset::diabetes::load_dataset;
5 | 
6 |     #[test]
7 |     #[cfg(feature = "csv")]
8 |     fn test_new_from_csv() {
9 |         let file_name = "data/diabetes.csv";
10 | 
11 |         // Set up the regressor settings and load data
12 |         let settings = Settings::default_regression().with_number_of_folds(2);
13 | 
14 |         let mut regressor = SupervisedModel::new((file_name, 10), settings);
15 | 
16 |         // Compare models
17 |         regressor.train();
18 | 
19 |         // Try to predict something
20 |         regressor.predict(vec![vec![5.0_f32; 10]; 10]);
21 |         regressor.predict("data/diabetes_without_target.csv");
22 |         #[cfg(feature = "nd")]
23 |         regressor.predict(ndarray::Array2::from_shape_vec((10, 10), vec![5.0; 100]).unwrap());
24 |     }
25 | 
26 |     #[test]
27 |     #[cfg(feature = "csv")]
28 |     fn test_new_from_csv_url() {
29 |         // let file_name = "data/diabetes.csv";
30 |         let file_name = "https://raw.githubusercontent.com/plotly/datasets/master/diabetes.csv";
31 | 
32 |         // Set up the regressor settings and load data
33 |         let settings = Settings::default_regression().with_number_of_folds(2);
34 | 
35 |         let mut regressor = SupervisedModel::new((file_name, 8), settings);
36 | 
37 |         // Compare models
38 |         regressor.train();
39 | 
40 |         // Try to predict something
41 |         regressor.predict(vec![vec![5.0_f32; 8]; 8]);
42 |     }
43 | 
44 |     #[test]
45 |     fn test_add_interactions_preprocessing() {
46 |         let settings =
47 |             Settings::default_regression().with_preprocessing(PreProcessing::AddInteractions);
48 |         test_from_settings(settings);
49 |     }
50 | 
51 |     #[test]
52 |     fn test_add_polynomial_preprocessing() {
53 |         let settings = Settings::default_regression()
54 |             .with_preprocessing(PreProcessing::AddPolynomial { order: 2 });
55 |         test_from_settings(settings);
56 |     }
57 | 
58 |     #[test]
59 |     fn test_blending() {
60 |         let settings = Settings::default_regression().with_final_model(FinalModel::Blending {
61 |             algorithm: Algorithm::Linear,
62 |             meta_training_fraction: 0.15,
63 |             meta_testing_fraction: 0.15,
64 |         });
65 |         test_from_settings(settings);
66 |     }
67 | 
68 |     fn test_from_settings(settings: Settings) {
69 |         // Check training
70 |         let dataset = load_dataset();
71 | 
72 |         // Set up the regressor settings and load data
73 |         let mut regressor = SupervisedModel::new(dataset, settings);
74 | 
75 |         // Compare models
76 |         regressor.train();
77 | 
78 |         // Try to predict something
79 |         regressor.predict(vec![vec![5.0_f32; 10]; 10]);
80 |     }
81 | }
82 | 
--------------------------------------------------------------------------------