├── .gitignore
├── .vscode
│   └── settings.json
├── Cargo.lock
├── Cargo.toml
├── README.md
├── doc_assets
│   └── jiro.svg
├── examples
│   ├── housing
│   │   ├── .gitignore
│   │   ├── .vscode
│   │   │   └── settings.json
│   │   ├── Cargo.lock
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── best_model_params.gz
│   │   ├── data_and_preds.csv
│   │   ├── dataset
│   │   │   └── .gitignore
│   │   ├── model_eval.json
│   │   ├── models
│   │   │   └── test0.json
│   │   ├── models_stats
│   │   │   └── .gitignore
│   │   ├── models_weights
│   │   │   └── .gitignore
│   │   ├── src
│   │   │   ├── bin
│   │   │   │   ├── article.rs
│   │   │   │   ├── boxplots.rs
│   │   │   │   ├── genmodel.rs
│   │   │   │   ├── importance.rs
│   │   │   │   ├── model.rs
│   │   │   │   ├── plot_epochs_loss.rs
│   │   │   │   ├── plot_predictions.rs
│   │   │   │   ├── plot_r2.rs
│   │   │   │   └── predict_all.rs
│   │   │   └── lib.rs
│   │   └── visuals
│   │       └── .gitignore
│   ├── mnist
│   │   ├── .gitignore
│   │   ├── .vscode
│   │   │   └── settings.json
│   │   ├── Cargo.lock
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── dataset
│   │   │   └── .gitignore
│   │   ├── models
│   │   │   └── .gitignore
│   │   ├── models_stats
│   │   │   └── .gitignore
│   │   ├── src
│   │   │   ├── bin
│   │   │   │   ├── clean.rs
│   │   │   │   ├── plot_epochs_loss.rs
│   │   │   │   ├── predict_all.rs
│   │   │   │   ├── specify.rs
│   │   │   │   ├── test.rs
│   │   │   │   └── train.rs
│   │   │   └── lib.rs
│   │   └── visuals
│   │       └── .gitignore
│   ├── visuals
│   │   ├── C_price_over_price.png
│   │   ├── GS2L_price_over_price.png
│   │   ├── GS2L_prop_dist.png
│   │   ├── dropout_rate.png
│   │   ├── feynman.jpg
│   │   ├── filt_sqrd_6ReLU-Adam_Lin-Adam_+8_id.png
│   │   ├── filt_sqrd_6Tanh-Adam_Tanh-Adam_+8_id.png
│   │   ├── final_best
│   │   │   ├── final_bedrooms^2.png
│   │   │   ├── final_best_bathrooms.png
│   │   │   ├── final_best_bathrooms^2.png
│   │   │   ├── final_best_bedrooms.png
│   │   │   ├── final_best_bedrooms^2.png
│   │   │   ├── final_best_condition.png
│   │   │   ├── final_best_condition^2.png
│   │   │   ├── final_best_date_month.png
│   │   │   ├── final_best_date_month^2.png
│   │   │   ├── final_best_date_timestamp.png
│   │   │   ├── final_best_date_timestamp^2.png
│   │   │   ├── final_best_floors.png
│   │   │   ├── final_best_floors^2.png
│   │   │   ├── final_best_grade.png
│   │   │   ├── final_best_grade^2.png
│   │   │   ├── final_best_id.png
│   │   │   ├── final_best_lat.png
│   │   │   ├── final_best_lat^2.png
│   │   │   ├── final_best_latlong.png
│   │   │   ├── final_best_long.png
│   │   │   ├── final_best_long^2.png
│   │   │   ├── final_best_pred_price.png
│   │   │   ├── final_best_price.png
│   │   │   ├── final_best_sqft_above.png
│   │   │   ├── final_best_sqft_above^2.png
│   │   │   ├── final_best_sqft_basement.png
│   │   │   ├── final_best_sqft_basement^2.png
│   │   │   ├── final_best_sqft_living.png
│   │   │   ├── final_best_sqft_living^2.png
│   │   │   ├── final_best_sqft_lot.png
│   │   │   ├── final_best_sqft_lot^2.png
│   │   │   ├── final_best_view.png
│   │   │   ├── final_best_view^2.png
│   │   │   ├── final_best_waterfront.png
│   │   │   ├── final_best_waterfront^2.png
│   │   │   ├── final_best_yr_built.png
│   │   │   ├── final_best_yr_built^2.png
│   │   │   ├── final_best_yr_renovated.png
│   │   │   └── final_best_yr_renovated^2.png
│   │   ├── final_folds_r2.png
│   │   ├── final_loss.png
│   │   ├── final_rel_importance_cropped.png
│   │   ├── full_lt_8ReLU-Adam-Lin-Adam_boxplots.png
│   │   ├── full_lt_8ReLU-Adam-Lin-Adam_latlong.png
│   │   ├── full_lt_8ReLU-Adam-Lin-Adam_loss.png
│   │   ├── full_lt_8ReLU-Adam-Lin-Adam_price.png
│   │   ├── learning_rate_decay.png
│   │   ├── model1_L_id.png
│   │   ├── model1_L_price.png
│   │   ├── model3_id.png
│   │   ├── model4_price.png
│   │   ├── model5_price.png
│   │   ├── model7_id.png
│   │   ├── model9_loss.png
│   │   ├── network.svg
│   │   ├── neural_network.png
│   │   ├── no_outliers_filter_outliers_filter_loss.png
│   │   ├── sgd_vs_momentum_folds.jpg
│   │   ├── weights_uniform_tanh_weights_uniform_signed_tanh_weights_glorot_uniform_tanh_loss.png
│   │   ├── weights_uniform_weights_zeros_weights_glorot_uniform_loss.png
│   │   ├── with_sgd_with_momentum_loss.png
│   │   ├── with_sgd_with_momentum_with_adam_loss.png
│   │   ├── xor-example-predictions.png
│   │   ├── xor-example-predictions_2.png
│   │   └── xor_nn_graph.dot
│   └── xor
│       ├── .gitignore
│       ├── .vscode
│       │   └── settings.json
│       ├── Cargo.lock
│       ├── Cargo.toml
│       ├── README.md
│       └── src
│           ├── lib.rs
│           └── main.rs
├── src
│   ├── activation
│   │   ├── linear.rs
│   │   ├── mod.rs
│   │   ├── relu.rs
│   │   ├── sigmoid.rs
│   │   ├── softmax.rs
│   │   └── tanh.rs
│   ├── benchmarking.rs
│   ├── bin
│   │   ├── main.rs
│   │   ├── test_arrayfire.rs
│   │   ├── test_avg_pooling.rs
│   │   ├── test_backend.rs
│   │   ├── test_image.rs
│   │   ├── test_one_hot_encode.rs
│   │   └── test_softmax.rs
│   ├── dataset.rs
│   ├── datatable.rs
│   ├── initializers.rs
│   ├── layer
│   │   ├── defaults.rs
│   │   ├── dense_layer.rs
│   │   ├── full_layer.rs
│   │   └── mod.rs
│   ├── learning_rate
│   │   ├── inverse_time_decay.rs
│   │   ├── mod.rs
│   │   └── piecewise_constant.rs
│   ├── lib.rs
│   ├── linalg
│   │   ├── arrayfire_matrix.rs
│   │   ├── mod.rs
│   │   ├── nalgebra_matrix.rs
│   │   └── ndarray_matrix.rs
│   ├── loss
│   │   ├── bce.rs
│   │   ├── mod.rs
│   │   └── mse.rs
│   ├── model
│   │   ├── conv_network_model.rs
│   │   ├── full_dense_conv_layer_model.rs
│   │   ├── full_dense_layer_model.rs
│   │   ├── full_direct_conv_layer_model.rs
│   │   ├── mod.rs
│   │   └── network_model.rs
│   ├── monitor
│   │   └── mod.rs
│   ├── network
│   │   ├── mod.rs
│   │   └── params.rs
│   ├── optimizer
│   │   ├── adam.rs
│   │   ├── mod.rs
│   │   ├── momentum.rs
│   │   └── sgd.rs
│   ├── preprocessing
│   │   ├── attach_ids.rs
│   │   ├── extract_months.rs
│   │   ├── extract_timestamps.rs
│   │   ├── feature_cached.rs
│   │   ├── filter_outliers.rs
│   │   ├── log_scale.rs
│   │   ├── map.rs
│   │   ├── mod.rs
│   │   ├── normalize.rs
│   │   ├── one_hot_encode.rs
│   │   ├── sample.rs
│   │   └── square.rs
│   ├── trainers
│   │   ├── kfolds.rs
│   │   ├── mod.rs
│   │   └── split.rs
│   ├── vec_utils.rs
│   └── vision
│       ├── conv_activation
│       │   ├── linear.rs
│       │   ├── mod.rs
│       │   ├── relu.rs
│       │   ├── sigmoid.rs
│       │   └── tanh.rs
│       ├── conv_initializers.rs
│       ├── conv_layer
│       │   ├── avg_pooling_layer.rs
│       │   ├── defaults.rs
│       │   ├── dense_conv_layer.rs
│       │   ├── direct_conv_layer.rs
│       │   ├── full_conv_layer.rs
│       │   └── mod.rs
│       ├── conv_network.rs
│       ├── conv_optimizer
│       │   ├── adam.rs
│       │   ├── mod.rs
│       │   ├── momentum.rs
│       │   └── sgd.rs
│       ├── image
│       │   ├── arrayfire_image.rs
│       │   ├── mod.rs
│       │   ├── nalgebra_image.rs
│       │   └── ndarray_image.rs
│       ├── image_layer.rs
│       └── mod.rs
└── tests
    ├── matrix.rs
    └── vec_utils.rs

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
/target/

--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "rust-analyzer.linkedProjects": [
        "./Cargo.toml"
    ],
    "rust-analyzer.showUnlinkedFileNotification": false,
    "rust-analyzer.check.features": [
        "nalgebra"
    ],
    "cSpell.enableFiletypes": [
        "!rust"
    ]
}

--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "jiro_nn"
version = "0.8.1"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "Neural Networks framework with model building & data preprocessing features."
readme = "README.md"
repository = "https://github.com/AnicetNgrt/jiro-nn"
keywords = [
    "machine-learning",
    "neural-networks",
    "gradient-descent",
    "data-science",
    "data-analysis",
]
categories = ["science"]
exclude = [
    ".vscode/*",
    "examples/**/*",
]

[dependencies]
nalgebra-glm = { version = "0.18.0", optional = true }
nalgebra = { version = "0.32.2", optional = true, features = ["rand", "rayon"] }
libm = "0.2.6"
# https://pola-rs.github.io/polars-book/user-guide/installation/#rust
polars = { version = "0.28.0", optional = true, default-features = false, features = ["fmt", "json", "lazy", "streaming", "describe"] }
rand = "0.8.5"
rand_distr = "0.4.3"
serde = { version = "1.0.159", features = ["derive"] }
serde-aux = "4.2.0"
serde_json = "1.0.95"
sha2 = "0.10.6"
assert_float_eq = "1.1.3"
arrayfire = { version = "3.8.0", optional = true }
bincode = "1.3.3"
flate2 = "1.0.26"
lazy_static = "1.4.0"
ndarray = { version = "0.15.3", optional = true }
convolutions-rs = { version = "0.3.4", optional = true }

[features]
default = ["ndarray", "data"]
parquet = ["polars?/parquet"]
ipc = ["polars?/ipc"]
data = ["dep:polars"]
ndarray = ["dep:ndarray", "dep:convolutions-rs"]
nalgebra = ["dep:nalgebra", "dep:nalgebra-glm"]
arrayfire = ["dep:arrayfire"]
f64 = []

--------------------------------------------------------------------------------
/examples/housing/.gitignore:
--------------------------------------------------------------------------------
/target

--------------------------------------------------------------------------------
/examples/housing/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "cSpell.enableFiletypes": [
        "!rust"
    ]
}

--------------------------------------------------------------------------------
/examples/housing/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "housing"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
jiro_nn = { path = "../../", default-features = false, features = ["data"] }
gnuplot = "0.0.37"
rand = "0.8.5"
indicatif = "0.17.5"

[features]
f64 = ["jiro_nn/f64"]
default = ["ndarray"]
nalgebra = ["jiro_nn/nalgebra"]
arrayfire = ["jiro_nn/arrayfire"]
ndarray = ["jiro_nn/ndarray"]

--------------------------------------------------------------------------------
/examples/housing/README.md:
--------------------------------------------------------------------------------
# King County House price regression

Standard-looking results, obtained by roughly following [the same approach as a user named frederico on Kaggle using PyTorch](https://www.kaggle.com/code/chavesfm/dnn-house-price-r-0-88/notebook): data preprocessing followed by a model of 8 layers with ~20 neurons each, using ReLU & Adam, trained over 300 epochs with 8-fold cross-validation.

Charts are made with the gnuplot crate.
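
A condensed sketch of the workflow (the runnable version, with the full per-feature dataset configuration, lives in `src/bin/article.rs`; the layer width, epoch count and fold count below follow the description above rather than that file):

```rust
use jiro_nn::model::ModelBuilder;
use jiro_nn::preprocessing::Pipeline;
use jiro_nn::trainers::kfolds::KFolds;

pub fn main() {
    // Preprocess the raw CSV (article.rs also passes a dataset configuration
    // describing per-feature transformations instead of `None`)
    let mut pipeline = Pipeline::basic_single_pass();
    let (dataset_config, data) = pipeline
        .load_data("dataset/kc_house_data.csv", None)
        .run();

    // 8 hidden dense layers of ~20 neurons with ReLU, then a linear output
    let mut nn = ModelBuilder::new(dataset_config).neural_network();
    for _ in 0..8 {
        nn = nn.full_dense(22).relu().adam().end();
    }
    let model = nn
        .full_dense(1).linear().adam().end()
        .end()
        .batch_size(128)
        .epochs(300)
        .build();

    // 8-fold cross-validated training
    let mut kfold = KFolds::new(8);
    let (_preds_and_ids, model_eval) = kfold.run(&model, &data);
    model_eval.to_json_file("model_eval.json");
}
```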
![loss according to training epochs](../visuals/full_lt_8ReLU-Adam-Lin-Adam_loss.png)

![prices according to predicted prices](../visuals/full_lt_8ReLU-Adam-Lin-Adam_price.png)

![prices & predicted prices according to lat & long](../visuals/full_lt_8ReLU-Adam-Lin-Adam_latlong.png)

--------------------------------------------------------------------------------
/examples/housing/best_model_params.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/housing/best_model_params.gz

--------------------------------------------------------------------------------
/examples/housing/dataset/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore

--------------------------------------------------------------------------------
/examples/housing/models_stats/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore

--------------------------------------------------------------------------------
/examples/housing/models_weights/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore

--------------------------------------------------------------------------------
/examples/housing/src/bin/article.rs:
--------------------------------------------------------------------------------
use jiro_nn::dataset::Dataset;
use jiro_nn::model::ModelBuilder;
use jiro_nn::monitor::TM;
use jiro_nn::preprocessing::Pipeline;
use jiro_nn::preprocessing::attach_ids::AttachIds;
use jiro_nn::preprocessing::map::*;
use jiro_nn::trainers::kfolds::KFolds;
use jiro_nn::dataset::FeatureTags::*;

pub fn main() {
    let mut dataset_config = Dataset::from_file("dataset/kc_house_data.csv");
    dataset_config
        .remove_features(&["id", "zipcode", "sqft_living15", "sqft_lot15"])
        .tag_feature("date", DateFormat("%Y%m%dT%H%M%S"))
        .tag_feature("date", AddExtractedMonth)
        .tag_feature("date", AddExtractedTimestamp)
        .tag_feature("date", Not(&UsedInModel))
        .tag_feature(
            "yr_renovated",
            Mapped(
                MapSelector::equal_scalar(0.0),
                MapOp::replace_with_feature("yr_built"),
            ),
        )
        .tag_feature("price", Predicted)
        .tag_all(Log10.only(&["sqft_living", "sqft_above", "price"]))
        .tag_all(AddSquared.except(&["price", "date"]).incl_added_features())
        .tag_all(FilterOutliers.except(&["date"]).incl_added_features())
        .tag_all(Normalized.except(&["date"]).incl_added_features());

    TM::start_monitoring();

    let mut pipeline = Pipeline::basic_single_pass();
    let (dataset_config, data) = pipeline
        .prepend(AttachIds::new("id"))
        .load_data("dataset/kc_house_data.csv", Some(&dataset_config))
        .run();

    let hidden_neurons = 22;
    let hidden_layers = 8;

    let mut nn = ModelBuilder::new(dataset_config).neural_network();
    for _ in 0..hidden_layers {
        nn = nn
            .full_dense(hidden_neurons)
            .relu()
            .momentum()
            .end();
    }
    let model = nn
        .full_dense(1)
        .linear()
        .momentum()
        .end()
        .end()
        .batch_size(128)
        .epochs(100)
        .build();

    let mut kfold = KFolds::new(4);
    let (preds_and_ids, model_eval) = kfold
        .all_epochs_validation()
        .all_epochs_r2()
        .compute_best_model()
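        // The three builder calls above configure the trainer to validate on
        // every epoch, report R² for every epoch, and retain the parameters
        // of the best-performing model across folds (a summary based on the
        // method names; see the KFolds builder for details)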
        .run(&model, &data);

    TM::stop_monitoring();

    let best_model_params = kfold.take_best_model();
    best_model_params.to_binary_compressed("best_model_params.gz");

    let preds_and_ids = pipeline.revert(&preds_and_ids);
    let data = pipeline.revert(&data);
    let data_and_preds = data.inner_join(&preds_and_ids, "id", "id", Some("pred"));

    data_and_preds.to_csv_file("data_and_preds.csv");
    model_eval.to_json_file("model_eval.json");
}

--------------------------------------------------------------------------------
/examples/housing/src/bin/boxplots.rs:
--------------------------------------------------------------------------------
use gnuplot::{
    AutoOption::Fix,
    AxesCommon, Coordinate, Figure,
    LabelOption::Rotate,
    PlotOption::{Color, PointSymbol},
    MarginSide::MarginBottom,
};
use jiro_nn::{
    model::Model,
    preprocessing::{
        extract_months::ExtractMonths, extract_timestamps::ExtractTimestamps,
        normalize::Normalize, Pipeline,
    },
    vec_utils::vector_quartiles_iqr,
};

fn main() {
    let args: Vec<String> = std::env::args().collect();
    let config_name = &args[1];

    let model = Model::from_json_file(format!("models/{}.json", config_name));
    println!("model: {:#?}", model);

    let mut pipeline = Pipeline::new();
    let (_, data_before) = pipeline
        .push(ExtractMonths)
        .push(ExtractTimestamps)
        .push(Normalize::new())
        .load_data("./dataset/kc_house_data.csv", Some(&model.dataset_config))
        .run();

    let mut pipeline = Pipeline::basic_single_pass();
    let (after_config, data) = pipeline
        .load_data("./dataset/kc_house_data.csv", Some(&model.dataset_config))
        .run();

    println!("{:#?}", data);

    let mut fg = Figure::new();

    let mut axes = fg
        .axes2d()
        .set_title("Before and after preprocessing features boxes & whiskers", &[])
        .set_margins(&[MarginBottom(0.2)]);

    for (i, feature_name) in after_config.feature_names().iter().enumerate() {
        for (j, (prefix, data)) in vec![("before", &data_before), ("after", &data)]
            .iter()
            .enumerate()
        {
            if !data.has_column(feature_name) {
                continue;
            }
            let vals = data.column_to_vector(&feature_name);
            let (q1, q2, q3, min, max) = vector_quartiles_iqr(&vals);
            let outliers = vals
                .into_iter()
                .filter(|x| *x < min || *x > max)
                .collect::<Vec<_>>();

            let color = if j == 0 { "red" } else { "blue" };
            axes = axes
                .label(
                    &format!("{} {}", prefix, feature_name.replace("_", " ")),
                    Coordinate::Axis(((i * 2) + j) as f64 + 0.1),
                    Coordinate::Axis(-0.02),
                    &[Rotate(-45.0)],
                )
                .box_and_whisker_set_width(
                    [((i * 2) + j) as f32 + 0.1].iter(),
                    [q1].iter(),
                    [min].iter(),
                    [q2 + 0.003].iter(),
                    [q2 + 0.003].iter(),
                    [0.4f32].iter(),
                    &[Color(color)],
                )
                .box_and_whisker_set_width(
                    [((i * 2) + j) as f32 + 0.1].iter(),
                    [q2 - 0.003].iter(),
                    [q2 - 0.003].iter(),
                    [max].iter(),
                    [q3].iter(),
                    [0.4f32].iter(),
                    &[Color(color)],
                );

            if outliers.len() > 0 {
                axes = axes.points(
                    vec![((i * 2) + j) as f64 + 0.1; outliers.len()],
                    outliers,
                    &[Color(color), PointSymbol('*')],
                );
            }
        }
    }

    axes.set_x_ticks(None, &[], &[])
        .set_y_ticks(Some((Fix(0.1), 10)), &[], &[])
        .set_y_grid(true)
        .set_x_range(Fix(-0.2), Fix((after_config.feature_names().len() * 2) as f64))
        .set_y_range(Fix(0.0), Fix(1.0));

    fg.save_to_png(format!("visuals/{}_boxplots.png", config_name), 2448, 1224)
        .unwrap();
}

--------------------------------------------------------------------------------
/examples/housing/src/bin/genmodel.rs:
--------------------------------------------------------------------------------
use jiro_nn::{
    dataset::{Dataset, FeatureTags::*},
    model::ModelBuilder,
    preprocessing::map::{MapOp, MapSelector, MapValue},
};

fn main() {
    let args: Vec<String> = std::env::args().collect();
    let config_name = &args[1];

    let mut dataset_config = Dataset::from_file("dataset/kc_house_data.csv");
    dataset_config
        .remove_features(&["id", "zipcode", "sqft_living15", "sqft_lot15"])
        .tag_feature("date", DateFormat("%Y%m%dT%H%M%S"))
        .tag_feature("date", AddExtractedMonth)
        .tag_feature("date", AddExtractedTimestamp)
        .tag_feature("date", Not(&UsedInModel))
        .tag_feature(
            "yr_renovated",
            Mapped(
                MapSelector::equal(MapValue::scalar(0.0)),
                MapOp::replace_with(MapValue::take_from_feature("yr_built")),
            ),
        )
        .tag_feature("price", Predicted)
        .tag_all(Log10.only(&["sqft_living", "sqft_above", "price"]))
        .tag_all(AddSquared.except(&["price", "date"]).incl_added_features())
        //.tag_all(FilterOutliers.except(&["date"]).incl_added_features())
        .tag_all(Normalized.except(&["date"]).incl_added_features());

    let h_size = dataset_config.in_features_names().len() + 1;
    let nh = 8;

    let mut nn = ModelBuilder::new(dataset_config).neural_network();
    for _ in 0..nh {
        nn = nn
            .full_dense(h_size)
            .relu()
            .momentum()
            .end();
    }
    let model = nn
        .full_dense(1)
        .linear()
        .momentum()
        .end()
        .end()
        .batch_size(128)
        .epochs(100)
        .build();

    println!("{:#?}", model);

    model.to_json_file(format!("models/{}.json", config_name));
}

--------------------------------------------------------------------------------
/examples/housing/src/bin/importance.rs:
--------------------------------------------------------------------------------
use gnuplot::*;
use jiro_nn::{
    linalg::Scalar,
    model::Model,
    network::params::NetworkParams,
    preprocessing::{sample::Sample, Pipeline},
    vec_utils::{avg_vector, r2_score_vector2, shuffle_column},
};

pub fn main() {
    let args: Vec<String> = std::env::args().collect();
    let config_name = &args[1];
    let weights_file = &args[2];

    let mut model = Model::from_json_file(format!("models/{}.json", config_name));

    let mut pipeline = Pipeline::basic_single_pass();
    let (updated_dataset_config, data) = pipeline
        .prepend(Sample::new(21000, true))
        .load_data("./dataset/kc_house_data.csv", Some(&model.dataset_config))
        .run();

    println!("Data: {:#?}", data);

    let model = model.with_new_dataset(updated_dataset_config);
    let predicted_features = model.dataset_config.predicted_features_names();

    let (x_table, y_table) = data.random_order_in_out(&predicted_features);

    let x = x_table.to_vectors();
    let y = y_table.to_vectors();

    let weights = NetworkParams::from_json(format!("models_weights/{}.json", weights_file));
    let mut network = model.to_network();
    network.load_params(&weights);

    let preds = network.predict_many(&x, 1);
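    // Baseline R² on the untouched inputs; the permutation loop below
    // measures how much this score drops when each input column is shuffled,
    // which is what the per-feature importance is derived from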
    let ref_score = r2_score_vector2(&y, &preds);

    println!("r2: {:#?}", ref_score);

    let mut x_cp = x.clone();

    // Based on https://arxiv.org/pdf/1801.01489.pdf
    // and https://christophm.github.io/interpretable-ml-book/feature-importance.html

    let shuffles_count = 10;
    let ncols = x_cp[0].len();
    let mut means_list = Vec::new();

    for c in 0..ncols {
        println!("Shuffling column {}/{}", c + 1, ncols);

        let mut metric_list = Vec::new();
        for _ in 0..shuffles_count {
            shuffle_column(&mut x_cp, c);
            let preds = network.predict_many(&x_cp, 1);
            let score = r2_score_vector2(&y, &preds);
            println!("score: {:#?}", score);
            metric_list.push(ref_score - score);
            x_cp = x.clone();
        }
        means_list.push(avg_vector(&metric_list));
    }

    // Converting it all to percentages and sorting
    let mut importance_rel = Vec::new();
    let columns_names = x_table.get_columns_names();

    let means_sum = means_list.iter().sum::<Scalar>();
    for (mean, name) in means_list.iter().zip(columns_names.iter()) {
        importance_rel.push(((100.0 * mean) / means_sum, name));
    }
    importance_rel.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());

    let mut fg = Figure::new();
    let axes = fg
        .axes2d()
        .set_y_range(Fix(0.0), Auto)
        .set_x_range(Fix(-0.5), Fix(importance_rel.len() as f64 + 1.0))
        .set_title("Relative importance (%) of all the features", &[])
        .set_margins(&[MarginBottom(0.24)])
        .set_x_ticks(None, &[], &[])
        .set_y_label("importance (%)", &[]);

    axes.boxes(
        (0i32..(columns_names.len() as i32)).collect::<Vec<i32>>(),
        importance_rel.clone().into_iter().map(|(v, _)| v).collect::<Vec<_>>(),
        &[],
    );

    for i in 0..importance_rel.len() {
        println!("{}: {}", importance_rel[i].1, importance_rel[i].0);
        axes.label(
            &importance_rel[i].1.replace("_", " "),
            Coordinate::Axis(i as f64),
            Coordinate::Axis(-0.2),
            &[Rotate(-40.0)],
        );
    }

    fg.save_to_png(
        format!("visuals/{}_rel_importance.png", config_name),
        1524,
        728,
    )
    .unwrap();
}

--------------------------------------------------------------------------------
/examples/housing/src/bin/model.rs:
--------------------------------------------------------------------------------
use jiro_nn::model::Model;
use jiro_nn::monitor::TM;
use jiro_nn::preprocessing::Pipeline;
use jiro_nn::preprocessing::attach_ids::AttachIds;
use jiro_nn::trainers::kfolds::KFolds;


pub fn main() {
    let args: Vec<String> = std::env::args().collect();
    let config_name = &args[1];

    let mut model = Model::from_json_file(format!("models/{}.json", config_name));

    let mut pipeline = Pipeline::basic_single_pass();
    let (updated_dataset_config, data) = pipeline
        .push(AttachIds::new("id"))
        .load_data("./dataset/kc_house_data.csv", Some(&model.dataset_config))
        .run();

    println!("data: {:#?}", data);

    let model = model.with_new_dataset(updated_dataset_config);

    TM::start_monitoring();

    let mut kfold = KFolds::new(4);
    let (preds_and_ids, model_eval) = kfold
        // .attach_real_time_reporter(|fold, epoch, report| {
        //     println!("Perf report: {:2} {:4} {:#?}", fold, epoch, report)
        // })
        .all_epochs_validation()
        .all_epochs_r2()
        .compute_best_model()
        // .compute_avg_model()
        .run(&model, &data);

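    // What follows mirrors article.rs: extract the best fold's parameters,
    // undo the preprocessing on both predictions and data so they are back in
    // the original value ranges, then join them by id for inspection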
    TM::stop_monitoring();

    let best_model_params = kfold.take_best_model();
    //let avg_model_params = kfold.take_avg_model();

    //best_model_params.to_json_file(format!("models_weights/{}_best_params.json", config_name));
    best_model_params.to_binary_compressed(format!("models_weights/{}_best_params.gz", config_name));
    //avg_model_params.to_json_file(format!("models_stats/{}_avg_params.json", config_name));

    let preds_and_ids = pipeline.revert(&preds_and_ids);
    let data = pipeline.revert(&data);
    let data_and_preds = data.inner_join(&preds_and_ids, "id", "id", Some("pred"));

    data_and_preds.to_csv_file(format!("models_stats/{}.csv", config_name));

    println!("{:#?}", data_and_preds);

    model_eval.to_json_file(format!("models_stats/{}.json", config_name));
}

--------------------------------------------------------------------------------
/examples/housing/src/bin/plot_epochs_loss.rs:
--------------------------------------------------------------------------------
use gnuplot::{Figure, AxesCommon, PlotOption::{Color, Caption, LineStyle}};
use jiro_nn::benchmarking::ModelEvaluation;

fn main() {
    let args: Vec<String> = std::env::args().collect();
    let model_names = &args[1..];
    let mut fg = Figure::new();
    let mut axes = fg.axes2d()
        .set_title("Loss over epochs", &[])
        .set_x_label("epochs", &[])
        .set_y_label("loss", &[])
        .set_y_log(Some(10.));

    let colors = &[
        "green",
        "red",
        "blue",
        "yellow",
        "purple",
        "orange",
        "brown",
        "pink",
        "gray"
    ];

    for (i, model_name) in model_names.iter().enumerate() {
        let color = colors[i % colors.len()];

        let model_eval =
            ModelEvaluation::from_json_file(format!("models_stats/{}.json", model_name));

        let x = (0..model_eval.get_n_epochs()).collect::<Vec<usize>>();
        let y1 = model_eval.epochs_avg_train_loss();
        //let y2 = model_eval.epochs_std_train_loss();
        //let y1_minus_y2 = y1.iter().zip(y2.iter()).map(|(a, b)| a - b).collect::<Vec<_>>();
        //let y1_plus_y2 = y1.iter().zip(y2.iter()).map(|(a, b)| a + b).collect::<Vec<_>>();

        let y3 = model_eval.epochs_avg_test_loss();
        //let y4 = model_eval.epochs_std_test_loss();
        //let y3_minus_y4 = y3.iter().zip(y4.iter()).map(|(a, b)| a - b).collect::<Vec<_>>();
        //let y3_plus_y4 = y3.iter().zip(y4.iter()).map(|(a, b)| a + b).collect::<Vec<_>>();

        axes = axes
            .lines(x.clone(), y1.clone(), &[Color(color), Caption(&format!("{} train loss", model_name.replace("_", " ")))])
            .lines(x.clone(), y3.clone(), &[Color(color), Caption(&format!("{} test loss", model_name.replace("_", " "))), LineStyle(gnuplot::DashType::Dash)]);
            //.fill_between(x.clone(), y1_minus_y2, y1_plus_y2, &[Color(color), FillAlpha(0.1)])
            //.fill_between(x.clone(), y3_minus_y4, y3_plus_y4, &[Color(color), FillAlpha(0.1)]);
    }

    fg.save_to_png(format!("visuals/{}_loss.png", model_names.join("_")), 1024, 728).unwrap();
}

--------------------------------------------------------------------------------
/examples/housing/src/bin/plot_predictions.rs:
--------------------------------------------------------------------------------
use gnuplot::{Figure, PlotOption::{Color, Caption, PointSize}, AxesCommon};
use jiro_nn::datatable::DataTable;

fn main() {
    let args: Vec<String> = std::env::args().collect();
    let model_name = &args[1];
    let out_data = DataTable::from_csv_file(format!("models_stats/{}.csv", model_name));

    for col in out_data.get_columns_names().iter() {
        let mut fg = Figure::new();
        let x = out_data.column_to_vector(col);
        let y1 = out_data.column_to_vector("pred_price");
        let y2 = out_data.column_to_vector("price");

        fg.axes2d()
            .set_title(&format!("Predicted price and price according to {}", col), &[])
            .set_x_label(&col.replace("_", " "), &[])
            .set_y_label("price", &[])
            .points(x.clone(), y2.clone(), &[Color("red"), PointSize(0.2), Caption("price")])
            .points(x.clone(), y1.clone(), &[Color("blue"), PointSize(0.2), Caption("predicted price")]);

        fg.save_to_png(format!("visuals/{}_{}.png", model_name, col), 1024, 728).unwrap();
    }

    let mut fg = Figure::new();
    let x = out_data.column_to_vector("lat");
    let z1 = out_data.column_to_vector("pred_price");
    let z2 = out_data.column_to_vector("price");
    let y = out_data.column_to_vector("long");

    fg.axes3d()
        .set_title("Predicted price and price according to latitude and longitude", &[])
        .set_view(45., 15.)
        .set_x_label("latitude", &[])
        .set_y_label("longitude", &[])
        .set_z_label("price", &[])
        .points(x.clone(), y.clone(), z2.clone(), &[Color("red"), PointSize(0.2), Caption("true price")])
        .points(x.clone(), y.clone(), z1.clone(), &[Color("blue"), PointSize(0.2), Caption("predicted price")]);

    fg.save_to_png(format!("visuals/{}_latlong.png", model_name), 1024, 728).unwrap();
}

--------------------------------------------------------------------------------
/examples/housing/src/bin/plot_r2.rs:
--------------------------------------------------------------------------------
use gnuplot::*;
use jiro_nn::benchmarking::ModelEvaluation;

fn main() {
    let args: Vec<String> = std::env::args().collect();
    let model_name = &args[1];

    let model_stats = ModelEvaluation::from_json_file(format!("models_stats/{}.json", model_name));

    let mut fg = Figure::new();
    let mut axes = fg.axes2d()
        .set_title("R² over epochs", &[])
        .set_x_label("epochs", &[])
        .set_y_label("R²", &[])
        .set_y_range(Fix(0.8), Fix(1.0));

    let colors = &[
        "green",
        "red",
        "blue",
        "yellow",
        "purple",
        "orange",
        "brown",
        "pink",
        "gray",
        "spring-green"
    ];

    for (i, fold_eval) in model_stats.folds.iter().enumerate() {
        let color = colors[i % colors.len()];

        let x = (0..fold_eval.epochs.len()).collect::<Vec<usize>>();
        let y1 = fold_eval.epochs.iter().map(|epoch_eval| epoch_eval.r2);

        axes = axes
            .points(x.clone(), y1.clone(), &[Color(color), PointSymbol('+'), Caption(&format!("{} fold {} R^2", model_name.replace("_", " "), i))])
    }

    fg.save_to_png(format!("visuals/{}_folds_r2.png", model_name), 1024, 728).unwrap();
}

--------------------------------------------------------------------------------
/examples/housing/src/bin/predict_all.rs:
--------------------------------------------------------------------------------
use jiro_nn::{
    datatable::DataTable,
    model::Model,
    network::params::NetworkParams,
    preprocessing::{attach_ids::AttachIds, Pipeline},
    vec_utils::r2_score_vector2,
};

pub fn main() {
    let args: Vec<String> = std::env::args().collect();
    let config_name = &args[1];
    let weights_file = &args[2];
    let out_name = &args[3];

    let mut model = Model::from_json_file(format!("models/{}.json", config_name));

    let mut pipeline = Pipeline::basic_single_pass();
    let (updated_dataset_config, data) = pipeline
        .push(AttachIds::new("id"))
        .load_data("./dataset/kc_house_data.csv", Some(&model.dataset_config))
        .run();

    println!("data: {:#?}", data);

    let model = model.with_new_dataset(updated_dataset_config);
    let predicted_features = model.dataset_config.predicted_features_names();

    let (x_table, y_table) = data.random_order_in_out(&predicted_features);

    let x = x_table.drop_column("id").to_vectors();
    let y = y_table.to_vectors();

    let weights = NetworkParams::from_json(format!("models_weights/{}.json", weights_file));
    let mut network = model.to_network();
    network.load_params(&weights);

    let (preds, avg_loss, std_loss) =
        network.predict_evaluate_many(&x, &y, &model.loss.to_loss(), 1);

    println!("avg_loss: {:#?}", avg_loss);
    println!("std_loss: {:#?}", std_loss);

    let r2 = r2_score_vector2(&y, &preds);

    println!("r2: {:#?}", r2);

    let preds_and_ids =
        DataTable::from_vectors(&predicted_features, &preds).add_column_from(&x_table, "id");

    let preds_and_ids = pipeline.revert(&preds_and_ids);
    let data = pipeline.revert(&data);
    let data_and_preds = data.inner_join(&preds_and_ids, "id", "id", Some("pred"));

    data_and_preds.to_csv_file(format!("models_stats/{}.csv", out_name));
}

--------------------------------------------------------------------------------
/examples/housing/src/lib.rs:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/housing/src/lib.rs

--------------------------------------------------------------------------------
/examples/housing/visuals/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore

--------------------------------------------------------------------------------
/examples/mnist/.gitignore:
--------------------------------------------------------------------------------
/target

--------------------------------------------------------------------------------
/examples/mnist/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "cSpell.enableFiletypes": [
        "!rust"
    ]
}

--------------------------------------------------------------------------------
/examples/mnist/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "mnist"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
jiro_nn = { path = "../../", default-features = false, features = ["data", "parquet"] }
gnuplot = "0.0.37"
rand = "0.8.5"

[features]
f64 = ["jiro_nn/f64"]
default = ["arrayfire"]
nalgebra = ["jiro_nn/nalgebra"]
arrayfire = ["jiro_nn/arrayfire"]

--------------------------------------------------------------------------------
/examples/mnist/README.md:
--------------------------------------------------------------------------------
# MNIST

## Workflow

1. Download the dataset from [Kaggle](https://www.kaggle.com/c/digit-recognizer/data) and extract it in the `dataset` folder.
2. Clean the data with `cargo run --bin clean`.
3. (Optional) edit `specify.rs` to change the model configuration as code.
4. Generate the configuration file with `cargo run --bin specify -- <config_name>`.
5. Train the model with `cargo run --bin train -- <config_name>`.
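
The same workflow can also be driven end-to-end from a single program. Here is a condensed sketch, assuming the cleaned parquet from step 2 (`src/bin/test.rs` contains a fuller version of this flow, with a deeper network):

```rust
use jiro_nn::dataset::{Dataset, FeatureTags};
use jiro_nn::loss::Losses;
use jiro_nn::model::ModelBuilder;
use jiro_nn::preprocessing::Pipeline;
use jiro_nn::trainers::kfolds::KFolds;

pub fn main() {
    // Tag the columns: "id" identifies rows, "label" is the one-hot target,
    // everything else is a normalized pixel input
    let mut dataset_config = Dataset::from_file("dataset/train_cleaned.parquet");
    dataset_config
        .tag_feature("id", FeatureTags::IsId)
        .tag_feature("label", FeatureTags::Predicted)
        .tag_feature("label", FeatureTags::OneHotEncode)
        .tag_all(FeatureTags::Normalized.except(&["label", "id"]));

    let mut pipeline = Pipeline::basic_single_pass();
    let (dataset_config, data) = pipeline
        .load_data("dataset/train_cleaned.parquet", Some(&dataset_config))
        .run();

    // A small CNN: one conv block with average pooling, then a dense softmax head
    let model = ModelBuilder::new(dataset_config)
        .neural_network()
        .conv_network(1)
        .full_dense(32, 5).relu().adam().end()
        .avg_pooling(2)
        .end()
        .full_dense(10).softmax().adam().end()
        .end()
        .epochs(20)
        .batch_size(128)
        .loss(Losses::BCE)
        .build();

    // 10-fold cross-validated training
    let mut kfolds = KFolds::new(10);
    let (_preds, model_eval) = kfolds.run(&model, &data);
    model_eval.to_json_file("mnist_eval.json");
}
```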
--------------------------------------------------------------------------------
/examples/mnist/dataset/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore

--------------------------------------------------------------------------------
/examples/mnist/models/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore

--------------------------------------------------------------------------------
/examples/mnist/models_stats/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore

--------------------------------------------------------------------------------
/examples/mnist/src/bin/clean.rs:
--------------------------------------------------------------------------------
use jiro_nn::datatable::DataTable;
use jiro_nn::monitor::TM;
use jiro_nn::preprocessing::attach_ids::AttachIds;
use jiro_nn::preprocessing::Pipeline;

pub fn main() {
    TM::start_monitoring();

    TM::start("clean");

    if !std::path::Path::new("dataset/train.parquet").exists() {
        TM::start("convert_to_parquet");
        let data_csv = DataTable::from_csv_file("dataset/train.csv");
        data_csv.to_parquet_file("dataset/train.parquet");
        let size_ratio = std::fs::metadata("dataset/train.parquet").unwrap().len() as f64
            / std::fs::metadata("dataset/train.csv").unwrap().len() as f64;

        TM::end_with_message(format!(
            "Converted to parquet with size ratio: {}",
            size_ratio
        ));
    }

    let mut pipeline = Pipeline::new();
    pipeline.push(AttachIds::new("id"));

    pipeline.load_data("dataset/train.parquet", None);
    let (_, preprocessed_data) = pipeline.run();

    TM::start("save");
    preprocessed_data.to_parquet_file("dataset/train_cleaned.parquet");
    TM::end();

    TM::end();

    TM::stop_monitoring();
}

--------------------------------------------------------------------------------
/examples/mnist/src/bin/plot_epochs_loss.rs:
--------------------------------------------------------------------------------
use gnuplot::{Figure, AxesCommon, PlotOption::{Color, Caption, LineStyle}};
use jiro_nn::benchmarking::ModelEvaluation;

fn main() {
    let args: Vec<String> = std::env::args().collect();
    let model_names = &args[1..];
    let mut fg = Figure::new();
    let mut axes = fg.axes2d()
        .set_title("Loss over epochs", &[])
        .set_x_label("epochs", &[])
        .set_y_label("loss", &[])
        .set_y_log(Some(10.));

    let colors = &[
        "green",
        "red",
        "blue",
        "yellow",
        "purple",
        "orange",
        "brown",
        "pink",
        "gray"
    ];

    for (i, model_name) in model_names.iter().enumerate() {
        let color = colors[i % colors.len()];

        let model_eval =
            ModelEvaluation::from_json_file(format!("models_stats/{}.json", model_name));

        let x = (0..model_eval.get_n_epochs()).collect::<Vec<usize>>();
        let y1 = model_eval.epochs_avg_train_loss();
        //let y2 = model_eval.epochs_std_train_loss();
        //let y1_minus_y2 = y1.iter().zip(y2.iter()).map(|(a, b)| a - b).collect::<Vec<_>>();
        //let y1_plus_y2 = y1.iter().zip(y2.iter()).map(|(a, b)| a + b).collect::<Vec<_>>();

        let y3 = model_eval.epochs_avg_test_loss();
        //let y4 = model_eval.epochs_std_test_loss();
        //let y3_minus_y4 = y3.iter().zip(y4.iter()).map(|(a, b)| a - b).collect::<Vec<_>>();
        //let y3_plus_y4 = y3.iter().zip(y4.iter()).map(|(a, b)| a + b).collect::<Vec<_>>();

        axes = axes
            .lines(x.clone(), y1.clone(), &[Color(color), Caption(&format!("{} train loss", model_name.replace("_", " ")))])
            .lines(x.clone(), y3.clone(), &[Color(color), Caption(&format!("{} test loss", model_name.replace("_", " "))), LineStyle(gnuplot::DashType::Dash)]);
            //.fill_between(x.clone(), y1_minus_y2, y1_plus_y2, &[Color(color), FillAlpha(0.1)])
            //.fill_between(x.clone(), y3_minus_y4, y3_plus_y4, &[Color(color), FillAlpha(0.1)]);
    }

    fg.save_to_png(format!("visuals/{}_loss.png", model_names.join("_")), 1024, 728).unwrap();
}

--------------------------------------------------------------------------------
/examples/mnist/src/bin/predict_all.rs:
--------------------------------------------------------------------------------
use jiro_nn::{
    datatable::DataTable,
    model::Model,
    network::params::NetworkParams,
    preprocessing::Pipeline,
};

pub fn main() {
    let args: Vec<String> = std::env::args().collect();
    let config_name = &args[1];
    let weights_file = &args[2];

    let mut model = Model::from_json_file(format!("models/{}.json", config_name));

    let mut pipeline = Pipeline::basic_single_pass();
    let (updated_dataset_config, data) = pipeline
        .load_data("./dataset/train_cleaned.parquet", Some(&model.dataset_config))
        .run();

    println!("data: {:#?}", data);

    let model = model.with_new_dataset(updated_dataset_config);
    let predicted_features = model.dataset_config.predicted_features_names();

    let (x_table, _) = data.random_order_in_out(&predicted_features);

    let x = x_table.drop_column("id").to_vectors();

    let weights = NetworkParams::from_json(format!("models_weights/{}.json", weights_file));
    let mut network = model.to_network();
    network.load_params(&weights);

    let preds = network.predict_many(&x, model.batch_size.unwrap_or(x.len()));

    let preds_and_ids =
        DataTable::from_vectors(&predicted_features, &preds).add_column_from(&x_table, "id");

    let preds_and_ids = pipeline.revert(&preds_and_ids);
    let data = pipeline.revert(&data);

    let values = data.select_columns(&["id", "label"]);
    let data = data.drop_columns(&["label", "label-confidence"]);
    let values_and_preds = values.inner_join(&preds_and_ids, "id", "id", Some("pred"));

    data.to_csv_file(format!("models_stats/{}_data_for_values_and_preds.csv", weights_file));
    values_and_preds.to_csv_file(format!("models_stats/{}_values_and_preds.csv", weights_file));

    println!("{:#?}", values_and_preds);
}

--------------------------------------------------------------------------------
/examples/mnist/src/bin/specify.rs:
--------------------------------------------------------------------------------
use jiro_nn::{
    dataset::{Dataset, FeatureTags::*},
    model::ModelBuilder,
    loss::Losses,
};

fn main() {
    let args: Vec<String> = std::env::args().collect();
    let config_name = &args[1];

    let mut dataset_config = Dataset::from_file("dataset/train_cleaned.parquet");
    dataset_config
        .tag_all(Normalized.except(&["label"]))
        .tag_feature("id", IsId)
        .tag_feature("label", Predicted)
        .tag_feature("label", OneHotEncode);

    let model = ModelBuilder::new(dataset_config)
        .neural_network()
        // 28x28 pixels in
        .conv_network(1)
        .full_dense(32, 5)
        .relu()
        .adam()
        .dropout(0.4)
        .end()
        .avg_pooling(2)
        .full_dense(64, 5)
        .relu()
        .adam()
        .dropout(0.5)
        .end()
        .avg_pooling(2)
        .end()
        .full_dense(128)
        .relu()
        .adam()
        .end()
        .full_dense(10)
        .softmax()
        .adam()
        .end()
        .end()
        .epochs(20)
        .batch_size(128)
        .loss(Losses::BCE)
        .build();

    //println!("{:#?}", model);

    model.to_json_file(format!("models/{}.json", config_name));
}

--------------------------------------------------------------------------------
/examples/mnist/src/bin/test.rs:
--------------------------------------------------------------------------------
use jiro_nn::{dataset::{Dataset, FeatureTags}, preprocessing::Pipeline, model::ModelBuilder, loss::Losses, trainers::kfolds::KFolds};

pub fn main() {
    let dataset_path = "mnist.parquet";

    // preprocessing without messing with polars
    let mut dataset_config = Dataset::from_file(dataset_path);
    dataset_config
        .tag_feature("id", FeatureTags::IsId)
        .tag_feature("label", FeatureTags::Predicted)
        .tag_feature("label", FeatureTags::OneHotEncode)
        .tag_all(FeatureTags::Normalized.except(&["label", "id"]));

    let mut pipeline = Pipeline::basic_single_pass();
    let (dataset_config, data) = pipeline
        .load_data(dataset_path, Some(&dataset_config))
        .run();

    // model building without looking everywhere for the right structs
    let model = ModelBuilder::new(dataset_config)
        .neural_network()
        .conv_network(1)
        .full_dense(32, 5)
        .relu()
        .adam()
        .dropout(0.4)
        .end()
        .avg_pooling(2)
        .full_dense(64, 5)
        .relu()
        .adam()
        .dropout(0.5)
        .end()
        .avg_pooling(2)
        .end()
        .full_dense(128)
        .relu()
        .adam()
        .end()
        .full_dense(10)
        .softmax()
        .adam()
        .end()
        .end()
        .epochs(20)
        .batch_size(128)
        .loss(Losses::BCE)
        .build();

    // training without installing a dedicated k-folds crate
    // nor messing with hairy tensors
    let mut kfolds = KFolds::new(10);
    let (predictions_by_id, model_eval) = kfolds.run(&model, &data);

    // saving the model
    model_eval.to_json_file("mnist_eval.json");
    kfolds.take_best_model().to_json("mnist_weights.json");

    // saving the predictions alongside the original data
    let predictions_by_id = pipeline.revert(&predictions_by_id);
    pipeline.revert(&data)
        .inner_join(&predictions_by_id, "id", "id", Some("pred"))
        .to_parquet_file("mnist_values_and_preds.parquet");
}

--------------------------------------------------------------------------------
/examples/mnist/src/bin/train.rs:
--------------------------------------------------------------------------------
use jiro_nn::model::Model;
use jiro_nn::monitor::TM;
use jiro_nn::preprocessing::Pipeline;
use jiro_nn::trainers::split::SplitTraining;

pub fn main() {
    let args: Vec<String> = std::env::args().collect();
    let config_name = &args[1];

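    // Load the model configuration written to models/<name>.json by the
    // `specify` binary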
    let mut model = Model::from_json_file(format!("models/{}.json", config_name));

    TM::start_monitoring();

    let mut pipeline = Pipeline::basic_single_pass();
    let (dataset_config, data) = pipeline
        .load_data("./dataset/train_cleaned.parquet", Some(&model.dataset_config))
        .run();

    TM::start("modelinit");

    let model = model.with_new_dataset(dataset_config);

    TM::end_with_message(format!(
        "Model parameters count: {}",
        model.to_network().get_params().count()
    ));

    let mut training = SplitTraining::new(0.8);
    let (preds_and_ids, model_eval) = training.run(&model, &data);

    TM::stop_monitoring();

    let model_params = training.take_model();
    model_params.to_json(format!("models_stats/{}_params.json", config_name));

    let preds_and_ids = pipeline.revert(&preds_and_ids);
    let data = pipeline.revert(&data);

    let values = data.select_columns(&["id", "label"]).rename_column("label", "true_label");
    let values_and_preds = values.inner_join(&preds_and_ids, "id", "id", Some("pred"));

    data.to_csv_file(format!("models_stats/{}_data_for_values_and_preds.csv", config_name));
    values_and_preds.to_csv_file(format!("models_stats/{}_values_and_preds.csv", config_name));

    println!("{:#?}", values_and_preds);

    model_eval.to_json_file(format!("models_stats/{}.json", config_name));
}

--------------------------------------------------------------------------------
/examples/mnist/src/lib.rs:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/mnist/src/lib.rs

--------------------------------------------------------------------------------
/examples/mnist/visuals/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore
--------------------------------------------------------------------------------
/examples/visuals/* (binary image assets):
--------------------------------------------------------------------------------
The files under /examples/visuals/ (result charts such as C_price_over_price.png,
GS2L_price_over_price.png, GS2L_prop_dist.png, dropout_rate.png, feynman.jpg,
the final_best/*.png series, final_folds_r2.png, final_loss.png,
final_rel_importance_cropped.png and the full_lt_8ReLU-Adam-Lin-Adam_*.png
series; the full list appears in the directory tree above) are binary images.
Each is available at:
https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/<filename>
-------------------------------------------------------------------------------- /examples/visuals/learning_rate_decay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/learning_rate_decay.png -------------------------------------------------------------------------------- /examples/visuals/model1_L_id.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/model1_L_id.png -------------------------------------------------------------------------------- /examples/visuals/model1_L_price.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/model1_L_price.png -------------------------------------------------------------------------------- /examples/visuals/model3_id.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/model3_id.png -------------------------------------------------------------------------------- /examples/visuals/model4_price.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/model4_price.png -------------------------------------------------------------------------------- /examples/visuals/model5_price.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/model5_price.png -------------------------------------------------------------------------------- /examples/visuals/model7_id.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/model7_id.png -------------------------------------------------------------------------------- /examples/visuals/model9_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/model9_loss.png -------------------------------------------------------------------------------- /examples/visuals/neural_network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/neural_network.png -------------------------------------------------------------------------------- /examples/visuals/no_outliers_filter_outliers_filter_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/no_outliers_filter_outliers_filter_loss.png -------------------------------------------------------------------------------- /examples/visuals/sgd_vs_momentum_folds.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/sgd_vs_momentum_folds.jpg -------------------------------------------------------------------------------- /examples/visuals/weights_uniform_tanh_weights_uniform_signed_tanh_weights_glorot_uniform_tanh_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/weights_uniform_tanh_weights_uniform_signed_tanh_weights_glorot_uniform_tanh_loss.png -------------------------------------------------------------------------------- /examples/visuals/weights_uniform_weights_zeros_weights_glorot_uniform_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/weights_uniform_weights_zeros_weights_glorot_uniform_loss.png -------------------------------------------------------------------------------- /examples/visuals/with_sgd_with_momentum_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/with_sgd_with_momentum_loss.png -------------------------------------------------------------------------------- /examples/visuals/with_sgd_with_momentum_with_adam_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/with_sgd_with_momentum_with_adam_loss.png -------------------------------------------------------------------------------- /examples/visuals/xor-example-predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/xor-example-predictions.png -------------------------------------------------------------------------------- /examples/visuals/xor-example-predictions_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/visuals/xor-example-predictions_2.png -------------------------------------------------------------------------------- /examples/visuals/xor_nn_graph.dot: -------------------------------------------------------------------------------- 1 | digraph NeuralNetwork { 2 | rankdir=LR; 3 | 4 | // input layer 5 | node [shape=circle margin=0.1]; 6 | x1; 7 | x2; 8 | 9 | // tanh layers 10 | node [shape=circle style=filled fillcolor=yellow]; 11 | 12 | t11 [label="x11 = \nTanh(y1)"]; 13 | t12 [label="x12 = \nTanh(y2)"]; 14 | t13 [label="x13 = \nTanh(y3)"]; 15 | 16 | t21 [label="pred = \nTanh(y)"]; 17 | 18 | // hidden layer 19 | node [shape=circle style=filled fillcolor=aqua margin=0.1 width=1.0]; 20 | h1 [label="y1 =\n ∑(xi*w1i)\n+b1"]; 21 | h2 [label="y2 =\n ∑(xi*w2i)\n+b2"]; 22 | h3 [label="y3 =\n ∑(xi*w3i)\n+b3"]; 23 | 24 | // output layer 25 | node [shape=circle style=filled fillcolor=greenyellow]; 26 | y [label="y =\n ∑(x1i*wi)\n+b"]; 27 | 28 | node [shape=circle style=filled fillcolor=lightpink margin=0.01]; 29 | e [label="loss = \n(pred - true)^2\n/2"]; 30 | 31 | // weights and biases 32 | x1 -> h1 33 | x1 -> h2 34 | x1 -> h3 35 | 
36 | x2 -> h1 37 | x2 -> h2 38 | x2 -> h3 39 | 40 | h1 -> t11 41 | h2 -> t12 42 | h3 -> t13 43 | 44 | t11 -> y 45 | t12 -> y 46 | t13 -> y 47 | 48 | y -> t21 49 | 50 | t21 -> e 51 | } -------------------------------------------------------------------------------- /examples/xor/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /examples/xor/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "cSpell.enableFiletypes": [ 3 | "!rust" 4 | ] 5 | } -------------------------------------------------------------------------------- /examples/xor/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "xor" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | jiro_nn = { path = "../../", default_features = false } 10 | 11 | [features] 12 | f64 = ["jiro_nn/f64"] 13 | default = ["ndarray"] 14 | nalgebra = ["jiro_nn/nalgebra"] 15 | arrayfire = ["jiro_nn/arrayfire"] 16 | ndarray = ["jiro_nn/ndarray"] 17 | -------------------------------------------------------------------------------- /examples/xor/README.md: -------------------------------------------------------------------------------- 1 | # XOR example 2 | 3 | Showcasing bare-bones usage of `jiro-nn` on the CPU, without the dataframes/preprocessing features and with a user-made training loop. For a more in-depth example, see the [King County Houses regression example](../housing/README.md). 4 | 5 | If you replicate this on your own, don't forget to disable the default `"data"` feature and specify the backend you want to use: 6 | 7 | Example with `ndarray` backend: 8 | 9 | ```toml 10 | [dependencies] 11 | jiro_nn = { version = "*", default_features = false, features = ["ndarray"] } 12 | ``` -------------------------------------------------------------------------------- /examples/xor/src/lib.rs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnicetNgrt/jiro-nn/9b683e7344ecb87b50ee7a3c41f14ac3b94b804d/examples/xor/src/lib.rs -------------------------------------------------------------------------------- /examples/xor/src/main.rs: -------------------------------------------------------------------------------- 1 | use jiro_nn::{loss::Losses, model::network_model::NetworkModelBuilder}; 2 | 3 | fn main() { 4 | let training_data_in = vec![ 5 | vec![0.0, 0.0], 6 | vec![1.0, 0.0], 7 | vec![0.0, 1.0], 8 | vec![1.0, 1.0], 9 | ]; 10 | 11 | let training_data_out = vec![ 12 | vec![0.0], 13 | vec![1.0], 14 | vec![1.0], 15 | vec![0.0] 16 | ]; 17 | 18 | let in_size = 2; 19 | let hidden_out_size = 3; 20 | let out_size = 1; 21 | 22 | let network_model = NetworkModelBuilder::new() 23 | .full_dense(hidden_out_size) 24 | .init_glorot_uniform() 25 | .sgd() 26 | .tanh() 27 | .end() 28 | .full_dense(out_size) 29 | .init_glorot_uniform() 30 | .sgd() 31 | .tanh() 32 | .end() 33 | .build(); 34 | 35 | let mut network = network_model.to_network(in_size); 36 | 37 | let loss = Losses::MSE.to_loss(); 38 | let batch_size = 1; 39 | 40 | for epoch in 0..50000 { 41 | let error = network.train( 42 | epoch, 43 | &training_data_in, 44 | &training_data_out, 45 | &loss, 46 | batch_size, 47 | ); 48 | 49 | if epoch % 1000 == 0 { 50 | 
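// Log progress every 1000 epochs; `error` is this epoch's average training loss.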
println!("Epoch: {} Average training loss: {}", epoch, error); 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/activation/linear.rs: -------------------------------------------------------------------------------- 1 | use super::ActivationLayer; 2 | use crate::linalg::{Matrix, MatrixTrait}; 3 | 4 | pub fn new() -> ActivationLayer { 5 | ActivationLayer::new( 6 | |m| m.clone(), 7 | |m| Matrix::constant(m.dim().0, m.dim().1, 1.0), 8 | ) 9 | } 10 | -------------------------------------------------------------------------------- /src/activation/mod.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | 3 | use crate::linalg::MatrixTrait; 4 | use serde::{Deserialize, Serialize}; 5 | 6 | use crate::{layer::Layer, linalg::Matrix}; 7 | 8 | pub mod linear; 9 | pub mod relu; 10 | pub mod sigmoid; 11 | pub mod softmax; 12 | pub mod tanh; 13 | 14 | pub type ActivationFn = fn(&Matrix) -> Matrix; 15 | pub type GradDepActivationFn = fn(&Matrix, &Matrix) -> Matrix; 16 | 17 | pub enum ActivationFnPrime { 18 | ActivationFn(ActivationFn), 19 | GradDepActivationFn(GradDepActivationFn), 20 | } 21 | 22 | pub struct ActivationLayer { 23 | // i inputs = i outputs (it's just a map) 24 | input: Option, 25 | output: Option, 26 | activation: ActivationFn, 27 | derivative: ActivationFnPrime, 28 | } 29 | 30 | impl ActivationLayer { 31 | pub fn new(activation: ActivationFn, derivative: ActivationFn) -> Self { 32 | Self { 33 | input: None, 34 | output: None, 35 | activation, 36 | derivative: ActivationFnPrime::ActivationFn(derivative), 37 | } 38 | } 39 | 40 | pub fn new_grad_dep(activation: ActivationFn, derivative: GradDepActivationFn) -> Self { 41 | Self { 42 | input: None, 43 | output: None, 44 | activation, 45 | derivative: ActivationFnPrime::GradDepActivationFn(derivative), 46 | } 47 | } 48 | } 49 | 50 | impl Layer for ActivationLayer { 51 | fn forward(&mut self, input: Matrix) -> Matrix { 52 | self.input = Some(input.clone()); 53 | let output = (self.activation)(&input); 54 | self.output = Some(output.clone()); 55 | output 56 | } 57 | 58 | fn backward(&mut self, _epoch: usize, output_gradient: Matrix) -> Matrix { 59 | match self.derivative { 60 | ActivationFnPrime::ActivationFn(f) => { 61 | // ∂E/∂X = ∂E/∂Y ⊙ f'(X) 62 | let input = self.input.clone().unwrap(); 63 | let fprime_x = (f)(&input); 64 | output_gradient.component_mul(&fprime_x) 65 | }, 66 | ActivationFnPrime::GradDepActivationFn(f) => { 67 | let output = self.output.clone().unwrap(); 68 | (f)(&output, &output_gradient) 69 | }, 70 | } 71 | } 72 | } 73 | 74 | #[derive(Debug, Clone, Copy, Serialize, Deserialize)] 75 | pub enum Activation { 76 | Tanh, 77 | Sigmoid, 78 | ReLU, 79 | Linear, 80 | Softmax 81 | } 82 | 83 | impl Activation { 84 | pub fn to_layer(&self) -> ActivationLayer { 85 | match self { 86 | Self::Linear => linear::new(), 87 | Self::Tanh => tanh::new(), 88 | Self::Sigmoid => sigmoid::new(), 89 | Self::ReLU => relu::new(), 90 | Self::Softmax => softmax::new(), 91 | } 92 | } 93 | } 94 | 95 | impl fmt::Debug for ActivationLayer { 96 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 97 | write!(f, "Activation Layer") 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /src/activation/relu.rs: -------------------------------------------------------------------------------- 1 | use super::ActivationLayer; 2 | use crate::linalg::{Matrix, MatrixTrait}; 3 | 4 | pub fn new() -> 
ActivationLayer { 5 | ActivationLayer::new( 6 | |m| m.maxof(&Matrix::constant(m.dim().0, m.dim().1, 0.)), 7 | |m| m.sign().maxof(&Matrix::constant(m.dim().0, m.dim().1, 0.)), 8 | ) 9 | } 10 | -------------------------------------------------------------------------------- /src/activation/sigmoid.rs: -------------------------------------------------------------------------------- 1 | use super::ActivationLayer; 2 | use crate::linalg::{Matrix, MatrixTrait}; 3 | 4 | fn sigmoid(m: &Matrix) -> Matrix { 5 | let exp_neg = m.scalar_mul(-1.).exp(); 6 | let ones = Matrix::constant(m.dim().0, m.dim().1, 1.0); 7 | ones.component_div(&(ones.component_add(&exp_neg))) 8 | } 9 | 10 | fn sigmoid_prime(m: &Matrix) -> Matrix { 11 | let sig = sigmoid(m); 12 | let ones = Matrix::constant(sig.dim().0, sig.dim().1, 1.0); 13 | sig.component_mul(&(ones.component_sub(&sig))) 14 | } 15 | 16 | pub fn new() -> ActivationLayer { 17 | ActivationLayer::new(sigmoid, sigmoid_prime) 18 | } 19 | -------------------------------------------------------------------------------- /src/activation/softmax.rs: -------------------------------------------------------------------------------- 1 | use super::ActivationLayer; 2 | use crate::linalg::{Matrix, MatrixTrait}; 3 | use std::{thread::available_parallelism}; 4 | 5 | // Formulas references from: 6 | // https://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/ 7 | // https://www.youtube.com/watch?v=AbLvJVwySEo 8 | 9 | fn stablesoftmax_col(col: &Matrix) -> Matrix { 10 | let shiftx = col.scalar_sub(col.max()); 11 | let exps = shiftx.exp(); 12 | let sum = exps.sum(); 13 | exps.scalar_div(sum) 14 | } 15 | 16 | fn stablesoftmax(m: &Matrix) -> Matrix { 17 | let ncol = m.dim().1; 18 | let mut columns: Vec<Matrix> = Vec::with_capacity(ncol); 19 | 20 | let n_threads = available_parallelism().unwrap().get().min(ncol / 10); 21 | if Matrix::is_backend_thread_safe() && n_threads > 1 { 22 | let step_size = (ncol as f32 / n_threads as f32).ceil() as usize; 23 | 24 | let mut threads = Vec::with_capacity(ncol); 25 | for i in (0..ncol).step_by(step_size) { 26 | let mut thread_columns = Vec::with_capacity(step_size); 27 | for j in i..(i + step_size).min(ncol) { 28 | let col = m.get_column_as_matrix(j); 29 | thread_columns.push(col); 30 | } 31 | threads.push(std::thread::spawn(move || { 32 | let mut results = Vec::with_capacity(step_size); 33 | for col in thread_columns { 34 | let result = stablesoftmax_col(&col); 35 | results.push(result); 36 | } 37 | results 38 | })); 39 | } 40 | for thread in threads { 41 | columns.extend(thread.join().unwrap()); 42 | } 43 | } else { 44 | for i in 0..ncol { 45 | let col = m.get_column_as_matrix(i); 46 | let result = stablesoftmax_col(&col); 47 | columns.push(result); 48 | } 49 | } 50 | 51 | Matrix::from_column_matrices(&columns) 52 | } 53 | 54 | fn softmax_prime_col(col: &Matrix, output_gradient: &Matrix) -> Matrix { 55 | let n = col.dim().0; 56 | let ones = Matrix::constant(1, n, 1.0); 57 | let col_repeated = col.dot(&ones); 58 | let identity = Matrix::identity(n); 59 | let result = col_repeated 60 | .component_mul(&identity.component_sub(&col_repeated.transpose())) 61 | .dot(output_gradient); 62 | result 63 | } 64 | 65 | fn softmax_prime(m: &Matrix, output_gradient: &Matrix) -> Matrix { 66 | let ncol = m.dim().1; 67 | let mut columns: Vec<Matrix> = Vec::with_capacity(ncol); 68 | 69 | let n_threads = available_parallelism().unwrap().get().min(ncol / 10); 70 | if Matrix::is_backend_thread_safe() && n_threads > 1 { 71 | let step_size = (ncol as f32 / n_threads as 
f32).ceil() as usize; 72 | 73 | let mut threads = Vec::with_capacity(ncol); 74 | for i in (0..ncol).step_by(step_size) { 75 | let mut thread_columns = Vec::with_capacity(step_size); 76 | for j in i..(i + step_size).min(ncol) { 77 | let col = m.get_column_as_matrix(j); 78 | let grad_col = output_gradient.get_column_as_matrix(j); 79 | thread_columns.push((col, grad_col)); 80 | } 81 | threads.push(std::thread::spawn(move || { 82 | let mut results = Vec::with_capacity(step_size); 83 | for (col, grad_col) in thread_columns { 84 | let result = softmax_prime_col(&col, &grad_col); 85 | results.push(result); 86 | } 87 | results 88 | })); 89 | } 90 | for thread in threads { 91 | columns.extend(thread.join().unwrap()); 92 | } 93 | } else { 94 | for i in 0..ncol { 95 | let col = m.get_column_as_matrix(i); 96 | let grad_col = output_gradient.get_column_as_matrix(i); 97 | let result = softmax_prime_col(&col, &grad_col); 98 | columns.push(result); 99 | } 100 | } 101 | 102 | Matrix::from_column_matrices(&columns) 103 | } 104 | 105 | pub fn new() -> ActivationLayer { 106 | ActivationLayer::new_grad_dep(stablesoftmax, softmax_prime) 107 | } 108 | -------------------------------------------------------------------------------- /src/activation/tanh.rs: -------------------------------------------------------------------------------- 1 | use super::ActivationLayer; 2 | use crate::linalg::{Matrix, MatrixTrait}; 3 | 4 | fn tanh(m: &Matrix) -> Matrix { 5 | let exp = m.exp(); 6 | let exp_neg = m.scalar_mul(-1.).exp(); 7 | (exp.component_sub(&exp_neg)).component_div(&(exp.component_add(&exp_neg))) 8 | } 9 | 10 | fn tanh_prime(m: &Matrix) -> Matrix { 11 | let hbt = tanh(m); 12 | let hbt2 = &hbt.square(); 13 | let ones = Matrix::constant(hbt.dim().0, hbt.dim().1, 1.0); 14 | ones.component_sub(&hbt2) 15 | } 16 | 17 | pub fn new() -> ActivationLayer { 18 | ActivationLayer::new(tanh, tanh_prime) 19 | } 20 | -------------------------------------------------------------------------------- /src/benchmarking.rs: -------------------------------------------------------------------------------- 1 | use std::{fs::File, io::Write}; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | use crate::linalg::Scalar; 6 | 7 | #[derive(Clone, Serialize, Deserialize, Debug, PartialEq)] 8 | pub struct ModelEvaluation { 9 | pub folds: Vec<TrainingEvaluation>, 10 | } 11 | 12 | #[derive(Clone, Serialize, Deserialize, Debug, PartialEq)] 13 | pub struct TrainingEvaluation { 14 | pub epochs: Vec<EpochEvaluation>, 15 | } 16 | 17 | #[derive(Clone, Serialize, Deserialize, Debug, PartialEq)] 18 | pub struct EpochEvaluation { 19 | pub train_loss: Scalar, 20 | pub test_loss_avg: Scalar, 21 | pub test_loss_std: Scalar, 22 | pub r2: Scalar, 23 | } 24 | 25 | impl ModelEvaluation { 26 | pub fn new_empty() -> Self { 27 | Self { folds: vec![] } 28 | } 29 | 30 | pub fn add_fold(&mut self, fold: TrainingEvaluation) { 31 | self.folds.push(fold); 32 | } 33 | 34 | pub fn epochs_avg_train_loss(&self) -> Vec<Scalar> { 35 | let mut avg = vec![0.0; self.folds[0].epochs.len()]; 36 | for fold in &self.folds { 37 | for (i, epoch) in fold.epochs.iter().enumerate() { 38 | avg[i] += epoch.train_loss; 39 | } 40 | } 41 | for i in 0..avg.len() { 42 | avg[i] /= self.folds.len() as Scalar; 43 | } 44 | avg 45 | } 46 | 47 | pub fn epochs_std_train_loss(&self) -> Vec<Scalar> { 48 | let avg = self.epochs_avg_train_loss(); 49 | let mut std = vec![0.0; self.folds[0].epochs.len()]; 50 | for fold in &self.folds { 51 | for (i, epoch) in fold.epochs.iter().enumerate() { 52 | std[i] += (epoch.train_loss - avg[i]).powi(2); 53 | } 54 | } 55 | 
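// Second pass: divide the accumulated squared deviations by the fold count and take the square root, giving the population standard deviation of the training loss across folds for each epoch.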
for i in 0..std.len() { 56 | std[i] = (std[i] / self.folds.len() as Scalar).sqrt(); 57 | } 58 | std 59 | } 60 | 61 | pub fn epochs_avg_test_loss(&self) -> Vec<Scalar> { 62 | let mut avg = vec![0.0; self.folds[0].epochs.len()]; 63 | for fold in &self.folds { 64 | for (i, epoch) in fold.epochs.iter().enumerate() { 65 | avg[i] += epoch.test_loss_avg; 66 | } 67 | } 68 | for i in 0..avg.len() { 69 | avg[i] /= self.folds.len() as Scalar; 70 | } 71 | avg 72 | } 73 | 74 | pub fn epochs_std_test_loss(&self) -> Vec<Scalar> { 75 | let avg = self.epochs_avg_test_loss(); 76 | let mut std = vec![0.0; self.folds[0].epochs.len()]; 77 | for fold in &self.folds { 78 | for (i, epoch) in fold.epochs.iter().enumerate() { 79 | std[i] += (epoch.test_loss_avg - avg[i]).powi(2); 80 | } 81 | } 82 | for i in 0..std.len() { 83 | std[i] = (std[i] / self.folds.len() as Scalar).sqrt(); 84 | } 85 | std 86 | } 87 | 88 | pub fn from_json_file<S: AsRef<str>>(path: S) -> Self { 89 | let file = File::open(path.as_ref()).unwrap(); 90 | serde_json::from_reader(file).unwrap() 91 | } 92 | 93 | pub fn to_json_file<S: AsRef<str>>(&self, path: S) { 94 | let mut file = File::create(path.as_ref()).unwrap(); 95 | let json_string = serde_json::to_string_pretty(self).unwrap(); 96 | file.write_all(json_string.as_bytes()).unwrap(); 97 | } 98 | 99 | pub fn get_n_epochs(&self) -> usize { 100 | self.folds[0].epochs.len() 101 | } 102 | 103 | pub fn get_n_folds(&self) -> usize { 104 | self.folds.len() 105 | } 106 | } 107 | 108 | impl TrainingEvaluation { 109 | pub fn new_empty() -> Self { 110 | Self { epochs: vec![] } 111 | } 112 | 113 | pub fn add_epoch(&mut self, epoch: EpochEvaluation) { 114 | self.epochs.push(epoch); 115 | } 116 | 117 | pub fn get_final_epoch(&self) -> EpochEvaluation { 118 | self.epochs[self.epochs.len() - 1].clone() 119 | } 120 | 121 | pub fn get_final_test_loss_avg(&self) -> Scalar { 122 | self.get_final_epoch().test_loss_avg 123 | } 124 | 125 | pub fn get_final_test_loss_std(&self) -> Scalar { 126 | self.get_final_epoch().test_loss_std 127 | } 128 | 129 | pub fn get_final_r2(&self) -> Scalar { 130 | self.get_final_epoch().r2 131 | } 132 | } 133 | 134 | impl EpochEvaluation { 135 | pub fn new( 136 | train_loss: Scalar, 137 | test_loss_avg: Scalar, 138 | test_loss_std: Scalar, 139 | r2: Scalar, 140 | ) -> Self { 141 | Self { 142 | train_loss, 143 | test_loss_avg, 144 | test_loss_std, 145 | r2, 146 | } 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /src/bin/main.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | println!("Hello, world!"); 3 | } 4 | -------------------------------------------------------------------------------- /src/bin/test_arrayfire.rs: -------------------------------------------------------------------------------- 1 | 2 | #[cfg(feature = "arrayfire")] 3 | use arrayfire::{convolve3, flip, index, print, Seq}; 4 | 5 | #[allow(unused_imports)] 6 | #[allow(unused_variables)] 7 | use jiro_nn::{ 8 | linalg::{Matrix, MatrixTrait}, 9 | vision::{image::Image, image::ImageTrait}, 10 | }; 11 | 12 | #[cfg(feature = "arrayfire")] 13 | pub fn main() { 14 | let image = Image::from_samples( 15 | &Matrix::from_column_leading_vector2(&vec![ 16 | vec![ 17 | 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 18 | 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 19 | ], 20 | vec![ 21 | 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 22 | 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 
23 | ], 24 | vec![ 25 | 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 26 | 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 27 | ], 28 | ]), 29 | 3, 30 | ); 31 | 32 | println!("Original pictures"); 33 | print(&image.0); 34 | 35 | let kernel = Image::from_samples( 36 | &Matrix::from_column_leading_vector2(&vec![vec![ 37 | 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 38 | ]]), 39 | 3, 40 | ); 41 | println!("Kernels"); 42 | print(&kernel.0); 43 | 44 | let rot_kern = flip(&flip(&kernel.0, 0), 1); 45 | println!("Kernels rotated"); 46 | print(&rot_kern); 47 | 48 | let res = convolve3( 49 | &image.0, 50 | &rot_kern, 51 | arrayfire::ConvMode::DEFAULT, 52 | arrayfire::ConvDomain::AUTO, 53 | ); 54 | println!("Result"); 55 | print(&res); 56 | 57 | println!("Result cropped"); 58 | let out_size = image.image_dims().0 - kernel.image_dims().0 + 1; 59 | let res = index( 60 | &res, 61 | &[ 62 | Seq::new(0, (out_size - 1).try_into().unwrap(), 1), 63 | Seq::new(0, (out_size - 1).try_into().unwrap(), 1), 64 | Seq::new(0, (kernel.samples() - 1).try_into().unwrap(), 1), 65 | Seq::new(0, (image.samples() - 1).try_into().unwrap(), 1), 66 | ], 67 | ); 68 | 69 | print(&res); 70 | } 71 | 72 | #[cfg(not(feature = "arrayfire"))] 73 | pub fn main() { 74 | println!("This example requires the arrayfire feature to be enabled"); 75 | } -------------------------------------------------------------------------------- /src/bin/test_avg_pooling.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "arrayfire")] 2 | use arrayfire::print; 3 | 4 | #[allow(unused_imports)] 5 | #[allow(unused_variables)] 6 | 7 | use jiro_nn::vision::{ 8 | conv_layer::avg_pooling_layer::AvgPoolingLayer, image::Image, image::ImageTrait, 9 | image_layer::ImageLayer, 10 | }; 11 | 12 | pub fn main() { 13 | let image = Image::random_normal(6, 6, 3, 2, 2.0, 1.0); 14 | 15 | #[cfg(feature = "arrayfire")] 16 | print(&image.0); 17 | 18 | let mut layer = AvgPoolingLayer::new(2); 19 | 20 | #[allow(unused_variables)] 21 | let image = layer.forward(image); 22 | 23 | #[cfg(feature = "arrayfire")] 24 | print(&image.0); 25 | 26 | let gradient = Image::random_normal(3, 3, 3, 2, 0.0, 0.5); 27 | 28 | #[cfg(feature = "arrayfire")] 29 | print(&gradient.0); 30 | 31 | #[allow(unused_variables)] 32 | let gradient = layer.backward(0, gradient); 33 | 34 | #[cfg(feature = "arrayfire")] 35 | print(&gradient.0); 36 | } 37 | -------------------------------------------------------------------------------- /src/bin/test_backend.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "arrayfire")] 2 | use arrayfire::get_available_backends; 3 | use jiro_nn::linalg::{Matrix, MatrixTrait, Scalar}; 4 | 5 | pub fn print_mat(m: &Matrix) { 6 | println!("["); 7 | for row in m.get_data_row_leading() { 8 | print!(" "); 9 | for val in row { 10 | print!("{}, ", val); 11 | } 12 | println!("") 13 | } 14 | println!("]"); 15 | } 16 | 17 | pub fn main() { 18 | #[cfg(feature = "arrayfire")] 19 | let backends = get_available_backends(); 20 | #[cfg(feature = "arrayfire")] 21 | println!("{:#?}", backends); 22 | 23 | let m = Matrix::constant(10, 1000, 2.0); 24 | let m2 = Matrix::constant(1000, 10, 131.13313); 25 | let m3 = m.dot(&m2); 26 | 27 | print_mat(&m3); 28 | 29 | let m4 = m3.transpose(); 30 | 31 | print_mat(&m4); 32 | 33 | let m5 = m4.columns_sum(); 34 | 35 | print_mat(&m5); 36 | 37 | let m = Matrix::from_iter(1, 3, vec![1.0, 2.0, 
3.0].into_iter()); 38 | 39 | print_mat(&m); 40 | 41 | let m = Matrix::from_iter(3, 1, vec![1.0, 2.0, 3.0].into_iter()); 42 | 43 | print_mat(&m); 44 | 45 | let m = Matrix::from_fn(3, 4, |i, _| i as Scalar); 46 | 47 | print_mat(&m); 48 | 49 | let m = Matrix::from_fn(3, 4, |_, j| j as Scalar); 50 | 51 | print_mat(&m); 52 | 53 | let m = Matrix::from_fn(3, 4, |i, j| i as Scalar + j as Scalar); 54 | 55 | print_mat(&m); 56 | 57 | let m = Matrix::from_row_leading_vector2(&vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]); 58 | 59 | print_mat(&m); 60 | 61 | let m = Matrix::from_column_leading_vector2(&vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]); 62 | 63 | print_mat(&m); 64 | 65 | let m = Matrix::random_uniform(10, 5, 3.0, 18.4); 66 | 67 | print_mat(&m); 68 | 69 | let m = Matrix::random_normal(10, 5, 10., 2.0); 70 | 71 | print_mat(&m); 72 | } 73 | -------------------------------------------------------------------------------- /src/bin/test_image.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "arrayfire")] 2 | use arrayfire::print; 3 | 4 | use jiro_nn::{ 5 | linalg::{Matrix, MatrixTrait}, 6 | vision::{image::Image, image::ImageTrait}, 7 | }; 8 | 9 | pub fn main() { 10 | #[allow(unused_variables)] 11 | let image = Image::from_samples( 12 | &Matrix::from_column_leading_vector2(&vec![ 13 | vec![1.0, 2.0, 3.0, 4.0], 14 | vec![3.0, 2.0, 0.0, 2.0], 15 | vec![6.0, 3.0, 1.0, 7.0], 16 | ]), 17 | 1, 18 | ); 19 | 20 | #[cfg(feature = "arrayfire")] 21 | print(&image.0); 22 | 23 | let matrix = vec![ 24 | vec![vec![1.0, 2.0], vec![3.0, 4.0]], 25 | vec![vec![5.0, 6.0], vec![7.0, 8.0]], 26 | vec![vec![9.0, 8.0], vec![7.0, 6.0]], 27 | ]; 28 | 29 | #[allow(unused_variables)] 30 | let image = Image::from_fn(2, 2, 1, 3, |x, y, z, s| { 31 | println!("{} {} {} {}", x, y, z, s); 32 | matrix[s][y][x] 33 | }); 34 | 35 | #[cfg(feature = "arrayfire")] 36 | print(&image.0); 37 | 38 | let matrix = vec![ 39 | vec![ 40 | vec![vec![11.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], 41 | vec![vec![110.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], 42 | vec![vec![1100.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], 43 | ], 44 | vec![ 45 | vec![vec![21.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], 46 | vec![vec![210.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], 47 | vec![vec![2100.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], 48 | ], 49 | vec![ 50 | vec![vec![31.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], 51 | vec![vec![310.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], 52 | vec![vec![3100.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], 53 | ], 54 | vec![ 55 | vec![vec![41.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], 56 | vec![vec![410.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], 57 | vec![vec![4100.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], 58 | ], 59 | ]; 60 | 61 | #[allow(unused_variables)] 62 | let image = Image::from_fn(2, 3, 3, 4, |x, y, z, s| { 63 | //println!("{} {} {} {}", x, y, z, s); 64 | matrix[s][z][y][x] 65 | }); 66 | 67 | #[cfg(feature = "arrayfire")] 68 | print(&image.0); 69 | } 70 | -------------------------------------------------------------------------------- /src/bin/test_one_hot_encode.rs: -------------------------------------------------------------------------------- 1 | use jiro_nn::{linalg::{Matrix, MatrixTrait}, activation::Activation, layer::Layer, loss::Losses}; 2 | 3 | pub fn main() { 4 | let m = Matrix::from_column_leading_vector2(&vec![ 5 | vec![1.0, 2.0, 3.0, 4.0, 0.1], 6 | vec![3.0, 2.0, 0.0, 2.0, 0.3], 7 | vec![6.0, 3.0, 1.0, 7.0, 0.4], 8 | vec![1.0, 1.0, 1.0, 1.0, 1.0], 9 | ]); 10 | 11 | m.print(); 12 | 13 | let mut activation = 
Activation::Softmax.to_layer(); 14 | let result = activation.forward(m.clone()); 15 | 16 | result.print(); 17 | 18 | let true_m = Matrix::from_column_leading_vector2(&vec![ 19 | vec![0.0, 0.0, 0.0, 0.0, 1.0], 20 | vec![0.0, 0.0, 0.0, 1.0, 0.0], 21 | vec![0.0, 0.0, 1.0, 0.0, 0.0], 22 | vec![0.0, 1.0, 0.0, 0.0, 0.0], 23 | ]); 24 | 25 | true_m.print(); 26 | 27 | let error = Losses::BCE.to_loss().loss_prime(&true_m, &result); 28 | 29 | error.print(); 30 | 31 | let jacobian = activation.backward(0, error); 32 | 33 | jacobian.print(); 34 | } -------------------------------------------------------------------------------- /src/bin/test_softmax.rs: -------------------------------------------------------------------------------- 1 | use jiro_nn::{linalg::{Matrix, MatrixTrait}, activation::Activation, layer::Layer, loss::Losses}; 2 | 3 | pub fn main() { 4 | let m = Matrix::from_column_leading_vector2(&vec![ 5 | vec![1.0, 2.0, 3.0, 4.0, 0.1], 6 | vec![3.0, 2.0, 0.0, 2.0, 0.3], 7 | vec![6.0, 3.0, 1.0, 7.0, 0.4], 8 | vec![1.0, 1.0, 1.0, 1.0, 1.0], 9 | ]); 10 | 11 | m.print(); 12 | 13 | let mut activation = Activation::Softmax.to_layer(); 14 | let result = activation.forward(m.clone()); 15 | 16 | result.print(); 17 | 18 | let true_m = Matrix::from_column_leading_vector2(&vec![ 19 | vec![0.0, 0.0, 0.0, 0.0, 1.0], 20 | vec![0.0, 0.0, 0.0, 1.0, 0.0], 21 | vec![0.0, 0.0, 1.0, 0.0, 0.0], 22 | vec![0.0, 1.0, 0.0, 0.0, 0.0], 23 | ]); 24 | 25 | true_m.print(); 26 | 27 | let error = Losses::BCE.to_loss().loss_prime(&true_m, &result); 28 | 29 | error.print(); 30 | 31 | let jacobian = activation.backward(0, error); 32 | 33 | jacobian.print(); 34 | } -------------------------------------------------------------------------------- /src/initializers.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::linalg::{Matrix, MatrixTrait, Scalar}; 4 | 5 | #[derive(Serialize, Debug, Deserialize, Clone)] 6 | pub enum Initializers { 7 | Zeros, 8 | Uniform, 9 | UniformSigned, 10 | GlorotUniform, 11 | } 12 | 13 | impl Initializers { 14 | pub fn gen_matrix(&self, nrow: usize, ncol: usize) -> Matrix { 15 | match self { 16 | Initializers::Zeros => Matrix::zeros(nrow, ncol), 17 | Initializers::Uniform => Matrix::random_uniform(nrow, ncol, 0.0, 1.0), 18 | Initializers::UniformSigned => Matrix::random_uniform(nrow, ncol, -1.0, 1.0), 19 | Initializers::GlorotUniform => { 20 | let limit = (6. / (ncol + nrow) as Scalar).sqrt(); 21 | Matrix::random_uniform(nrow, ncol, -limit, limit) 22 | } 23 | } 24 | } 25 | 26 | pub fn gen_vector(&self, nrow: usize) -> Matrix { 27 | match self { 28 | Initializers::Zeros => Matrix::zeros(nrow, 1), 29 | Initializers::Uniform => Matrix::random_uniform(nrow, 1, 0.0, 1.0), 30 | Initializers::UniformSigned => Matrix::random_uniform(nrow, 1, -1.0, 1.0), 31 | Initializers::GlorotUniform => { 32 | // not specified for vectors in the original paper 33 | // but taken from keras' implementation 34 | let limit = (6. 
/ (nrow) as Scalar).sqrt(); 35 | Matrix::random_uniform(nrow, 1, -limit, limit) 36 | } 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/layer/defaults.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | initializers::Initializers, 3 | optimizer::{sgd, Optimizers}, 4 | }; 5 | 6 | pub fn default_biases_initializer() -> Initializers { 7 | Initializers::Zeros 8 | } 9 | 10 | pub fn default_weights_initializer() -> Initializers { 11 | Initializers::GlorotUniform 12 | } 13 | 14 | pub fn default_biases_optimizer() -> Optimizers { 15 | sgd() 16 | } 17 | 18 | pub fn default_weights_optimizer() -> Optimizers { 19 | sgd() 20 | } 21 | -------------------------------------------------------------------------------- /src/layer/dense_layer.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | 3 | use crate::linalg::{MatrixTrait, Scalar}; 4 | use crate::{ 5 | initializers::Initializers, 6 | layer::Layer, 7 | linalg::Matrix, 8 | optimizer::{Optimizers}, 9 | }; 10 | 11 | use super::LearnableLayer; 12 | 13 | pub struct DenseLayer { 14 | // i inputs, j outputs, i x j connections 15 | input: Option<Matrix>, 16 | // j x i connection weights 17 | pub weights: Matrix, 18 | // j output biases (single column) 19 | pub biases: Matrix, 20 | weights_optimizer: Optimizers, 21 | biases_optimizer: Optimizers, 22 | } 23 | 24 | impl DenseLayer { 25 | pub fn new( 26 | i: usize, 27 | j: usize, 28 | weights_optimizer: Optimizers, 29 | biases_optimizer: Optimizers, 30 | weights_initializer: Initializers, 31 | biases_initializer: Initializers, 32 | ) -> Self { 33 | // about weights initialization : http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf 34 | 35 | let weights = weights_initializer.gen_matrix(j, i); 36 | let biases = biases_initializer.gen_vector(j); 37 | 38 | Self { 39 | weights: weights, 40 | biases: biases, 41 | input: None, 42 | weights_optimizer, 43 | biases_optimizer, 44 | } 45 | } 46 | } 47 | 48 | impl Layer for DenseLayer { 49 | /// `input` has shape `(i, n)` where `i` is the number of inputs and `n` is the number of samples. 50 | /// 51 | /// Returns output which has shape `(j, n)` where `j` is the number of outputs and `n` is the number of samples. 52 | fn forward(&mut self, input: Matrix) -> Matrix { 53 | // Y = W . X + B * (1...1) 54 | 55 | // println!("WEIGHTS"); 56 | // self 57 | // .weights.print(); 58 | 59 | let res = self 60 | .weights 61 | .dot(&input) 62 | .component_add(&self.biases.dot(&Matrix::constant(1, input.dim().1, 1.0))); 63 | 64 | self.input = Some(input); 65 | res 66 | } 67 | 68 | /// `output_gradient` has shape `(j, n)` where `j` is the number of outputs and `n` is the number of samples. 69 | /// 70 | /// Returns `input_gradient` which has shape `(i, n)` where `i` is the number of inputs and `n` is the number of samples. 
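///
/// A note on the math, matching the body below: `∂E/∂W = ∂E/∂Y · Xᵀ`, `∂E/∂B` is `∂E/∂Y` summed over the sample columns, and `∂E/∂X = Wᵀ · ∂E/∂Y`; each gradient is then handed to its optimizer to update the parameters.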
71 | fn backward(&mut self, epoch: usize, output_gradient: Matrix) -> Matrix { 72 | let input = self.input.as_ref().unwrap(); 73 | 74 | let weights_gradient = &output_gradient.dot(&input.transpose()); 75 | 76 | let biases_gradient = output_gradient.columns_sum(); 77 | 78 | let input_gradient = self.weights.transpose().dot(&output_gradient); 79 | 80 | self.weights = 81 | self.weights_optimizer 82 | .update_parameters(epoch, &self.weights, &weights_gradient); 83 | self.biases = 84 | self.biases_optimizer 85 | .update_parameters(epoch, &self.biases, &biases_gradient); 86 | 87 | input_gradient 88 | } 89 | } 90 | 91 | impl LearnableLayer for DenseLayer { 92 | // returns a matrix of the (jxi) weights and the final column being the (j) biases 93 | fn get_learnable_parameters(&self) -> Vec<Vec<Scalar>> { 94 | let mut params = self.weights.get_data_col_leading(); 95 | params.push(self.biases.get_column(0)); 96 | params 97 | } 98 | 99 | // takes a matrix of the (jxi) weights and the final column being the (j) biases 100 | fn set_learnable_parameters(&mut self, params_matrix: &Vec<Vec<Scalar>>) { 101 | let mut weights = params_matrix.clone(); 102 | let biases = weights.pop().unwrap(); 103 | self.weights = Matrix::from_column_leading_vector2(&weights); 104 | self.biases = Matrix::from_column_vector(&biases); 105 | } 106 | } 107 | 108 | impl fmt::Debug for DenseLayer { 109 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 110 | write!(f, "Dense Layer") 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/layer/full_layer.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::Ordering; 2 | 3 | use rand::Rng; 4 | 5 | use crate::linalg::{Matrix, MatrixTrait, Scalar}; 6 | use crate::network::NetworkLayer; 7 | use crate::{activation::ActivationLayer, layer::dense_layer::DenseLayer, layer::Layer}; 8 | 9 | use super::{DropoutLayer, LearnableLayer, ParameterableLayer}; 10 | 11 | #[derive(Debug)] 12 | pub struct FullLayer { 13 | dense: DenseLayer, 14 | activation: ActivationLayer, 15 | // dropout resources : https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf 16 | dropout_enabled: bool, 17 | dropout_rate: Option<Scalar>, 18 | mask: Option<Matrix>, 19 | } 20 | 21 | impl FullLayer { 22 | pub fn new(dense: DenseLayer, activation: ActivationLayer, dropout: Option<Scalar>) -> Self { 23 | Self { 24 | dense, 25 | activation, 26 | dropout_rate: dropout, 27 | dropout_enabled: false, 28 | mask: None, 29 | } 30 | } 31 | 32 | fn generate_dropout_mask(&mut self, output_shape: (usize, usize)) -> Option<(Matrix, Scalar)> { 33 | if let Some(dropout_rate) = self.dropout_rate { 34 | let mut rng = rand::thread_rng(); 35 | let dropout_mask = Matrix::from_fn(output_shape.0, output_shape.1, |_, _| { 36 | if rng 37 | .gen_range((0.0 as Scalar)..(1.0 as Scalar)) 38 | .total_cmp(&self.dropout_rate.unwrap()) 39 | == Ordering::Greater 40 | { 41 | 1.0 42 | } else { 43 | 0.0 44 | } 45 | }); 46 | Some((dropout_mask, dropout_rate)) 47 | } else { 48 | None 49 | } 50 | } 51 | } 52 | 53 | impl Layer for FullLayer { 54 | fn forward(&mut self, mut input: Matrix) -> Matrix { 55 | let output = if self.dropout_enabled { 56 | if let Some((mask, _)) = self.generate_dropout_mask(input.dim()) { 57 | input = input.component_mul(&mask); 58 | self.mask = Some(mask); 59 | }; 60 | self.dense.forward(input) 61 | } else { 62 | if let Some(dropout_rate) = self.dropout_rate { 63 | self.dense.weights = self.dense.weights.scalar_mul(1.0 - dropout_rate); 64 | let output = 
self.dense.forward(input); 65 | self.dense.weights = self.dense.weights.scalar_div(1.0 - dropout_rate); 66 | output 67 | } else { 68 | self.dense.forward(input) 69 | } 70 | }; 71 | 72 | self.activation.forward(output) 73 | } 74 | 75 | fn backward(&mut self, epoch: usize, output_gradient: Matrix) -> Matrix { 76 | let activation_input_gradient = self.activation.backward(epoch, output_gradient); 77 | let input_gradient = self.dense.backward(epoch, activation_input_gradient); 78 | 79 | if let Some(mask) = &self.mask { 80 | input_gradient.component_mul(&mask) 81 | } else { 82 | input_gradient 83 | } 84 | } 85 | } 86 | 87 | impl NetworkLayer for FullLayer {} 88 | 89 | impl ParameterableLayer for FullLayer { 90 | fn as_learnable_layer(&self) -> Option<&dyn LearnableLayer> { 91 | Some(self) 92 | } 93 | 94 | fn as_learnable_layer_mut(&mut self) -> Option<&mut dyn LearnableLayer> { 95 | Some(self) 96 | } 97 | 98 | fn as_dropout_layer(&mut self) -> Option<&mut dyn DropoutLayer> { 99 | Some(self) 100 | } 101 | } 102 | 103 | impl LearnableLayer for FullLayer { 104 | // returns a matrix of the (jxi) weights and the final column being the (j) biases 105 | fn get_learnable_parameters(&self) -> Vec<Vec<Scalar>> { 106 | self.dense.get_learnable_parameters() 107 | } 108 | 109 | // takes a matrix of the (jxi) weights and the final column being the (j) biases 110 | fn set_learnable_parameters(&mut self, params_matrix: &Vec<Vec<Scalar>>) { 111 | self.dense.set_learnable_parameters(params_matrix) 112 | } 113 | } 114 | 115 | impl DropoutLayer for FullLayer { 116 | fn enable_dropout(&mut self) { 117 | self.dropout_enabled = true; 118 | } 119 | 120 | fn disable_dropout(&mut self) { 121 | self.dropout_enabled = false; 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /src/layer/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | activation::Activation, 3 | linalg::{Matrix, Scalar}, 4 | }; 5 | 6 | pub mod defaults; 7 | pub mod dense_layer; 8 | pub mod full_layer; 9 | 10 | pub enum Layers { 11 | Dense, 12 | Activation(Activation), 13 | } 14 | 15 | pub trait Layer { 16 | /// `input` has shape `(i, n)` where `i` is the number of inputs and `n` is the number of samples. 17 | /// 18 | /// Returns output which has shape `(j, n)` where `j` is the number of outputs and `n` is the number of samples. 19 | fn forward(&mut self, input: Matrix) -> Matrix; 20 | 21 | /// `output_gradient` has shape `(j, n)` where `j` is the number of outputs and `n` is the number of samples. 22 | /// 23 | /// Returns `input_gradient` which has shape `(i, n)` where `i` is the number of inputs and `n` is the number of samples. 
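///
/// A minimal sketch of a custom layer implementing this trait (hypothetical, for illustration only): a pass-through layer whose forward output equals its input, so the backward pass returns the output gradient unchanged.
///
/// ```ignore
/// struct Identity;
///
/// impl Layer for Identity {
///     // Y = X: the output is the input, unchanged
///     fn forward(&mut self, input: Matrix) -> Matrix {
///         input
///     }
///     // with Y = X, ∂E/∂X = ∂E/∂Y
///     fn backward(&mut self, _epoch: usize, output_gradient: Matrix) -> Matrix {
///         output_gradient
///     }
/// }
/// ```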
24 | fn backward(&mut self, epoch: usize, output_gradient: Matrix) -> Matrix; 25 | } 26 | 27 | pub trait ParameterableLayer { 28 | fn as_learnable_layer(&self) -> Option<&dyn LearnableLayer>; 29 | fn as_learnable_layer_mut(&mut self) -> Option<&mut dyn LearnableLayer>; 30 | fn as_dropout_layer(&mut self) -> Option<&mut dyn DropoutLayer>; 31 | } 32 | 33 | pub trait DropoutLayer { 34 | fn enable_dropout(&mut self); 35 | fn disable_dropout(&mut self); 36 | } 37 | 38 | pub trait LearnableLayer { 39 | fn get_learnable_parameters(&self) -> Vec<Vec<Scalar>>; 40 | fn set_learnable_parameters(&mut self, params_matrix: &Vec<Vec<Scalar>>); 41 | } 42 | -------------------------------------------------------------------------------- /src/learning_rate/inverse_time_decay.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::linalg::Scalar; 4 | 5 | #[derive(Debug, Clone, Serialize, Deserialize)] 6 | pub struct InverseTimeDecay { 7 | pub initial_learning_rate: Scalar, 8 | pub decay_steps: Scalar, 9 | pub decay_rate: Scalar, 10 | #[serde(default)] 11 | pub staircase: bool, 12 | } 13 | 14 | impl InverseTimeDecay { 15 | pub fn new( 16 | initial_learning_rate: Scalar, 17 | decay_steps: Scalar, 18 | decay_rate: Scalar, 19 | staircase: bool, 20 | ) -> Self { 21 | Self { 22 | initial_learning_rate, 23 | decay_steps, 24 | decay_rate, 25 | staircase, 26 | } 27 | } 28 | 29 | pub fn get_learning_rate(&self, epoch: usize) -> Scalar { 30 | let mut learning_rate = self.initial_learning_rate 31 | / (1. + self.decay_rate * (epoch as Scalar / self.decay_steps)); 32 | if self.staircase { 33 | learning_rate = self.initial_learning_rate 34 | / (1. + self.decay_rate * (epoch as Scalar / self.decay_steps).floor()); 35 | } 36 | learning_rate 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/learning_rate/mod.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use self::{inverse_time_decay::InverseTimeDecay, piecewise_constant::PiecewiseConstant}; 4 | use crate::linalg::Scalar; 5 | 6 | pub mod inverse_time_decay; 7 | pub mod piecewise_constant; 8 | 9 | pub fn default_learning_rate() -> LearningRateSchedule { 10 | LearningRateSchedule::Constant(0.001) 11 | } 12 | 13 | // https://arxiv.org/pdf/1510.04609.pdf 14 | #[derive(Clone, Debug, Serialize, Deserialize)] 15 | pub enum LearningRateSchedule { 16 | Constant(Scalar), 17 | InverseTimeDecay(InverseTimeDecay), 18 | PiecewiseConstant(PiecewiseConstant), 19 | } 20 | 21 | impl LearningRateSchedule { 22 | pub fn get_learning_rate(&self, epoch: usize) -> Scalar { 23 | match self { 24 | LearningRateSchedule::InverseTimeDecay(schedule) => schedule.get_learning_rate(epoch), 25 | LearningRateSchedule::PiecewiseConstant(schedule) => schedule.get_learning_rate(epoch), 26 | LearningRateSchedule::Constant(c) => *c, 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/learning_rate/piecewise_constant.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::linalg::Scalar; 4 | 5 | #[derive(Debug, Clone, Serialize, Deserialize)] 6 | pub struct PiecewiseConstant { 7 | pub boundaries: Vec<usize>, 8 | pub values: Vec<Scalar>, 9 | } 10 | 11 | impl PiecewiseConstant { 12 | pub fn new(boundaries: Vec<usize>, values: Vec<Scalar>) -> Self { 13 | Self { boundaries, values } 14 | } 15 | 16 | 
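/// Returns `values[0]` until `epoch` reaches the first boundary, then `values[i + 1]` once `epoch >= boundaries[i]`, so `values` is expected to hold exactly one more entry than `boundaries`. For example (illustrative numbers), boundaries `[10, 20]` with values `[0.1, 0.01, 0.001]` yield 0.1 for epochs 0-9, 0.01 for epochs 10-19, and 0.001 from epoch 20 onwards.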
pub fn get_learning_rate(&self, epoch: usize) -> Scalar { 17 | let mut learning_rate = self.values[0]; 18 | for (i, boundary) in self.boundaries.iter().enumerate() { 19 | if epoch >= *boundary { 20 | learning_rate = self.values[i + 1]; 21 | } 22 | } 23 | learning_rate 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![doc = include_str!("../README.md")] 2 | 3 | #[macro_use] 4 | extern crate assert_float_eq; 5 | 6 | /// Monitoring performance of tasks and logging 7 | pub mod monitor; 8 | /// Activation functions and abstractions (sigmoid, relu, softmax...) 9 | pub mod activation; 10 | /// Model performance benchmarking utilities 11 | pub mod benchmarking; 12 | #[cfg(feature = "data")] 13 | /// Dataset configuration (metadata, preprocessing flags...) 14 | pub mod dataset; 15 | #[cfg(feature = "data")] 16 | /// Wrapper around dataframes libraries 17 | pub mod datatable; 18 | /// Parameters initializers and abstractions (uniform, glorot...) 19 | pub mod initializers; 20 | /// Layers and abstractions (dense, full...) 21 | pub mod layer; 22 | /// Learning rate schedulers and abstractions (constant, exponential...) 23 | pub mod learning_rate; 24 | /// Basic linear algebra backends and wrappers (matrix, vector, scalar...) 25 | pub mod linalg; 26 | /// Loss functions and abstractions (mse, crossentropy...) 27 | pub mod loss; 28 | /// Model configuration 29 | pub mod model; 30 | /// Neural network abstractions 31 | pub mod network; 32 | /// Optimizers and abstractions (sgd, adam...) 33 | pub mod optimizer; 34 | #[cfg(feature = "data")] 35 | /// Preprocessing and pipelining utilities (normalization, one-hot encoding...) 36 | pub mod preprocessing; 37 | /// Training methodologies (k-fold, split...) 38 | pub mod trainers; 39 | /// Utilities for `Vec<Scalar>`, `Vec<Vec<Scalar>>`... 40 | pub mod vec_utils; 41 | /// Everything vision (CNNs, images operations & backends...) 
42 | pub mod vision; 43 | -------------------------------------------------------------------------------- /src/linalg/mod.rs: -------------------------------------------------------------------------------- 1 | pub enum Backends { 2 | ArrayFire, 3 | Nalgebra, 4 | Ndarray 5 | } 6 | 7 | #[cfg(feature = "f64")] 8 | pub type Scalar = f64; 9 | 10 | #[cfg(not(feature = "f64"))] 11 | pub type Scalar = f32; 12 | 13 | #[cfg(feature = "arrayfire")] 14 | pub mod arrayfire_matrix; 15 | 16 | #[cfg(feature = "arrayfire")] 17 | pub use arrayfire_matrix::Matrix; 18 | #[cfg(feature = "arrayfire")] 19 | pub const BACKEND: Backends = Backends::ArrayFire; 20 | 21 | #[cfg(feature = "nalgebra")] 22 | pub mod nalgebra_matrix; 23 | 24 | #[cfg(all(feature = "nalgebra", not(feature = "arrayfire")))] 25 | pub use nalgebra_matrix::Matrix; 26 | #[cfg(all(feature = "nalgebra", not(feature = "arrayfire")))] 27 | pub const BACKEND: Backends = Backends::Nalgebra; 28 | 29 | #[cfg(feature = "ndarray")] 30 | pub mod ndarray_matrix; 31 | 32 | #[cfg(all(feature = "ndarray", not(feature = "arrayfire"), not(feature = "nalgebra")))] 33 | pub use ndarray_matrix::Matrix; 34 | #[cfg(all(feature = "ndarray", not(feature = "arrayfire"), not(feature = "nalgebra")))] 35 | pub const BACKEND: Backends = Backends::Ndarray; 36 | 37 | pub trait MatrixTrait: Clone { 38 | fn is_backend_thread_safe() -> bool; 39 | 40 | fn zeros(nrow: usize, ncol: usize) -> Self; 41 | 42 | fn constant(nrow: usize, ncol: usize, value: Scalar) -> Self; 43 | 44 | fn identity(n: usize) -> Self; 45 | 46 | /// Creates a matrix with random values between min and max (excluded). 47 | fn random_uniform(nrow: usize, ncol: usize, min: Scalar, max: Scalar) -> Self; 48 | 49 | /// Creates a matrix with random values following a normal distribution. 50 | fn random_normal(nrow: usize, ncol: usize, mean: Scalar, std_dev: Scalar) -> Self; 51 | 52 | /// Fills the matrix with the iterator columns after columns by chunking the data by n_rows. 53 | /// ```txt 54 | /// Your data : [[col1: row0 row1 ... rowNrow][col2]...[colNcol]] 55 | /// Result : 56 | /// [ 57 | /// [col0: row0 row1 ... rowNrow], 58 | /// [col1: row0 row1 ... rowNrow], 59 | /// ... 60 | /// [colNcol: row0 row1 ... rowNrow], 61 | /// ] 62 | /// ``` 63 | fn from_iter(nrow: usize, ncol: usize, data: impl Iterator<Item = Scalar>) -> Self; 64 | 65 | /// ```txt 66 | /// Your data : 67 | /// [ 68 | /// [row0: col0 col1 ... colNcol], 69 | /// [row1: col0 col1 ... colNcol], 70 | /// ... 71 | /// [rowNrow: col0 col1 ... colNcol], 72 | /// ] 73 | /// 74 | /// Result : 75 | /// [ 76 | /// [col0: row0 row1 ... rowNrow], 77 | /// [col1: row0 row1 ... rowNrow], 78 | /// ... 79 | /// [colNcol: row0 row1 ... 
80 |     /// ]
81 |     /// ```
82 |     fn from_row_leading_vector2(m: &Vec<Vec<Scalar>>) -> Self;
83 | 
84 |     fn from_column_leading_vector2(m: &Vec<Vec<Scalar>>) -> Self;
85 | 
86 |     /// fills a column vector row by row with values of index 0 to v.len()
87 |     fn from_column_vector(v: &Vec<Scalar>) -> Self;
88 | 
89 |     /// fills a row vector column by column with values of index 0 to v.len()
90 |     fn from_row_vector(v: &Vec<Scalar>) -> Self;
91 | 
92 |     fn from_fn<F>(nrows: usize, ncols: usize, f: F) -> Self
93 |     where
94 |         F: FnMut(usize, usize) -> Scalar;
95 | 
96 |     fn get_column_as_matrix(&self, idx: usize) -> Self;
97 | 
98 |     fn from_column_matrices(columns: &[Self]) -> Self;
99 | 
100 |     fn columns_map(&self, f: impl Fn(usize, &Vec<Scalar>) -> Vec<Scalar>) -> Self;
101 | 
102 |     fn get_column(&self, index: usize) -> Vec<Scalar>;
103 | 
104 |     fn get_row(&self, index: usize) -> Vec<Scalar>;
105 | 
106 |     fn map(&self, f: impl Fn(Scalar) -> Scalar + Sync) -> Self;
107 | 
108 |     fn map_indexed_mut(&mut self, f: impl Fn(usize, usize, Scalar) -> Scalar + Sync) -> &mut Self;
109 | 
110 |     fn dot(&self, other: &Self) -> Self;
111 | 
112 |     fn columns_sum(&self) -> Self;
113 | 
114 |     fn transpose(&self) -> Self;
115 | 
116 |     fn get_data_col_leading(&self) -> Vec<Vec<Scalar>>;
117 | 
118 |     fn get_data_row_leading(&self) -> Vec<Vec<Scalar>>;
119 | 
120 |     /// returns the dimensions of the matrix (nrow, ncol)
121 |     fn dim(&self) -> (usize, usize);
122 | 
123 |     fn component_mul(&self, other: &Self) -> Self;
124 | 
125 |     fn component_add(&self, other: &Self) -> Self;
126 | 
127 |     fn component_sub(&self, other: &Self) -> Self;
128 | 
129 |     fn component_div(&self, other: &Self) -> Self;
130 | 
131 |     fn scalar_add(&self, scalar: Scalar) -> Self;
132 | 
133 |     fn scalar_mul(&self, scalar: Scalar) -> Self;
134 | 
135 |     fn scalar_sub(&self, scalar: Scalar) -> Self;
136 | 
137 |     fn scalar_div(&self, scalar: Scalar) -> Self;
138 | 
139 |     fn square(&self) -> Self;
140 | 
141 |     fn sum(&self) -> Scalar;
142 | 
143 |     fn mean(&self) -> Scalar;
144 | 
145 |     fn exp(&self) -> Self;
146 | 
147 |     fn max(&self) -> Scalar;
148 | 
149 |     fn min(&self) -> Scalar;
150 | 
151 |     fn maxof(&self, other: &Self) -> Self;
152 | 
153 |     fn sign(&self) -> Self;
154 | 
155 |     fn minof(&self, other: &Self) -> Self;
156 | 
157 |     fn sqrt(&self) -> Self;
158 | 
159 |     fn log(&self) -> Self;
160 | 
161 |     fn index(&self, row: usize, col: usize) -> Scalar;
162 | 
163 |     fn index_mut(&mut self, row: usize, col: usize) -> &mut Scalar;
164 | }
165 | 
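A minimal sketch of the backend-agnostic `MatrixTrait` API above, assuming one linalg backend feature (e.g. `nalgebra`) is enabled so that `Matrix` resolves; the values are illustrative:

use crate::linalg::{Matrix, MatrixTrait};

let m = Matrix::from_row_leading_vector2(&vec![
    vec![1.0, 2.0],
    vec![3.0, 4.0],
]);
let total = m.square().sum();        // 1 + 4 + 9 + 16 = 30
let product = m.dot(&m.transpose()); // 2x2 matrix product
assert_eq!(product.dim(), (2, 2));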
--------------------------------------------------------------------------------
/src/loss/bce.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     linalg::{Matrix, MatrixTrait, Scalar},
3 |     loss::Loss,
4 | };
5 | 
6 | pub fn bce_vec(y_pred: &Vec<Scalar>, y_true: &Vec<Scalar>) -> Scalar {
7 |     let n_samples = y_pred.len();
8 |     let mut sum = 0.0;
9 |     for j in 0..n_samples {
10 |         sum += y_true[j] * y_pred[j].ln() + (1.0 - y_true[j]) * (1.0 - y_pred[j]).ln();
11 |     }
12 |     -sum / ((n_samples) as Scalar) // negative mean log-likelihood, matching bce() below
13 | }
14 | 
15 | fn bce(y_true: &Matrix, y_pred: &Matrix) -> Scalar {
16 |     let ones = Matrix::constant(y_true.dim().0, y_true.dim().1, 1.);
17 |     (y_true.component_mul(&y_pred.log()).component_add(
18 |         &ones
19 |             .component_sub(&y_true)
20 |             .component_mul(&ones.component_sub(&y_pred).log()),
21 |     ))
22 |     .mean() * -1.
23 | }
24 | 
25 | fn bce_prime(y_true: &Matrix, y_pred: &Matrix) -> Matrix {
26 |     let ones = Matrix::constant(y_true.dim().0, y_true.dim().1, 1.);
27 |     let ones_m_yt = ones.component_sub(&y_true);
28 |     let ones_m_yp = ones.component_sub(&y_pred);
29 | 
30 |     ones_m_yt
31 |         .component_div(&ones_m_yp)
32 |         .component_sub(&y_true.component_div(y_pred))
33 |         .scalar_div(y_pred.dim().0 as Scalar)
34 | }
35 | 
36 | pub fn new() -> Loss {
37 |     Loss::new(bce, bce_prime)
38 | }
39 | 
--------------------------------------------------------------------------------
/src/loss/mod.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | use crate::linalg::Matrix;
4 | use crate::linalg::MatrixTrait;
5 | use crate::linalg::Scalar;
6 | 
7 | pub mod mse;
8 | pub mod bce;
9 | 
10 | #[derive(Serialize, Debug, Deserialize, Clone)]
11 | pub enum Losses {
12 |     MSE,
13 |     BCE,
14 | }
15 | 
16 | impl Losses {
17 |     pub fn to_loss(&self) -> Loss {
18 |         match self {
19 |             Losses::MSE => mse::new(),
20 |             Losses::BCE => bce::new(),
21 |         }
22 |     }
23 | }
24 | 
25 | pub type LossFn = fn(&Matrix, &Matrix) -> Scalar;
26 | pub type LossPrimeFn = fn(&Matrix, &Matrix) -> Matrix;
27 | 
28 | pub struct Loss {
29 |     loss: LossFn,
30 |     derivative: LossPrimeFn,
31 | }
32 | 
33 | impl Loss {
34 |     pub fn new(loss: LossFn, derivative: LossPrimeFn) -> Self {
35 |         Self { loss, derivative }
36 |     }
37 | }
38 | 
39 | impl Loss {
40 |     pub fn loss(&self, y_true: &Matrix, y_pred: &Matrix) -> Scalar {
41 |         (self.loss)(y_true, y_pred)
42 |     }
43 | 
44 |     pub fn loss_prime(&self, y_true: &Matrix, y_pred: &Matrix) -> Matrix {
45 |         (self.derivative)(y_true, y_pred)
46 |     }
47 | 
48 |     pub fn loss_vec(&self, y_true: &Vec<Vec<Scalar>>, y_pred: &Vec<Vec<Scalar>>) -> Scalar {
49 |         let y_true = Matrix::from_row_leading_vector2(&y_true);
50 |         let y_pred = Matrix::from_row_leading_vector2(&y_pred);
51 |         self.loss(&y_true, &y_pred)
52 |     }
53 | }
54 | 
--------------------------------------------------------------------------------
/src/loss/mse.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     linalg::{Matrix, MatrixTrait, Scalar},
3 |     loss::Loss,
4 | };
5 | 
6 | pub fn mse_vec(y_true: &Vec<Scalar>, y_pred: &Vec<Scalar>) -> Scalar {
7 |     let n_samples = y_pred.len();
8 |     let mut sum = 0.0;
9 |     for j in 0..n_samples {
10 |         let diff = y_pred[j] - y_true[j];
11 |         sum += diff * diff;
12 |     }
13 |     sum / ((n_samples) as Scalar)
14 | }
15 | 
16 | fn mse(y_true: &Matrix, y_pred: &Matrix) -> Scalar {
17 |     ((y_pred.component_sub(&y_true)).square()).mean()
18 | }
19 | 
20 | fn mse_prime(y_true: &Matrix, y_pred: &Matrix) -> Matrix {
21 |     (y_pred.component_sub(&y_true)).scalar_mul(2.0) // arguments arrive as (y_true, y_pred), per LossFn
22 | }
23 | 
24 | pub fn new() -> Loss {
25 |     Loss::new(mse, mse_prime)
26 | }
27 | 
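A short sketch of going through the `Losses` enum rather than calling a loss function directly; the numbers are illustrative:

use crate::loss::Losses;

let loss = Losses::MSE.to_loss();
let y_true = vec![vec![1.0], vec![0.0]];
let y_pred = vec![vec![0.9], vec![0.1]];
// mean of (0.9 - 1.0)^2 and (0.1 - 0.0)^2 = (0.01 + 0.01) / 2
let value = loss.loss_vec(&y_true, &y_pred);
assert!((value - 0.01).abs() < 1e-6);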
--------------------------------------------------------------------------------
/src/model/conv_network_model.rs:
--------------------------------------------------------------------------------
1 | 
2 | use serde::{Serialize, Deserialize};
3 | 
4 | use crate::{network::NetworkLayer, vision::{conv_network::{ConvNetwork, ConvNetworkLayer}, conv_layer::avg_pooling_layer::AvgPoolingLayer}};
5 | 
6 | use super::{full_dense_conv_layer_model::{FullDenseConvLayerModel, FullDenseConvLayerModelBuilder}, network_model::NetworkModelBuilder, full_direct_conv_layer_model::{FullDirectConvLayerModel, FullDirectConvLayerModelBuilder}};
7 | 
8 | pub struct ConvNetworkModelBuilder {
9 |     pub model: ConvNetworkModel,
10 |     parent: NetworkModelBuilder,
11 | }
12 | 
13 | impl ConvNetworkModelBuilder {
14 |     pub fn new(parent: NetworkModelBuilder, in_channels: usize) -> Self {
15 |         Self {
16 |             model: ConvNetworkModel { layers: vec![], in_channels },
17 |             parent,
18 |         }
19 |     }
20 | 
21 |     pub fn end(self) -> NetworkModelBuilder {
22 |         self.parent.accept_conv_network(self.model)
23 |     }
24 | 
25 |     pub fn full_dense(self, kernels_count: usize, kernels_size: usize) -> FullDenseConvLayerModelBuilder {
26 |         FullDenseConvLayerModelBuilder::new(self, kernels_count, kernels_size)
27 |     }
28 | 
29 |     pub fn full_direct(self, kernels_size: usize) -> FullDirectConvLayerModelBuilder {
30 |         FullDirectConvLayerModelBuilder::new(self, kernels_size)
31 |     }
32 | 
33 |     pub fn avg_pooling(mut self, kernel_size: usize) -> Self {
34 |         self.model.layers.push(ConvNetworkLayerModels::AvgPooling { kernel_size });
35 |         self
36 |     }
37 | 
38 |     pub fn accept_full_dense(mut self, model: FullDenseConvLayerModel) -> Self {
39 |         self.model.layers.push(ConvNetworkLayerModels::FullDenseConv(model));
40 |         self
41 |     }
42 | 
43 |     pub fn accept_full_direct(mut self, model: FullDirectConvLayerModel) -> Self {
44 |         self.model.layers.push(ConvNetworkLayerModels::FullDirectConv(model));
45 |         self
46 |     }
47 | }
48 | 
49 | #[derive(Serialize, Deserialize, Clone, Debug)]
50 | pub struct ConvNetworkModel {
51 |     pub in_channels: usize,
52 |     pub layers: Vec<ConvNetworkLayerModels>,
53 | }
54 | 
55 | impl ConvNetworkModel {
56 |     pub fn to_layer(self, in_dims: usize) -> (usize, Box<dyn NetworkLayer>) {
57 |         let mut layers = vec![];
58 |         let mut in_channels = self.in_channels;
59 |         let mut in_img_dims = (in_dims as f64 / in_channels as f64).sqrt() as usize; // assumes square input images
60 | 
61 |         for layer_config in self.layers.into_iter() {
62 |             let (out_img_dims, out_channels, conv_layer) = layer_config
63 |                 .to_conv_layer(in_img_dims, in_channels);
64 | 
65 |             in_img_dims = out_img_dims;
66 |             in_channels = out_channels;
67 |             layers.push(conv_layer);
68 |         }
69 | 
70 |         let network_layer = ConvNetwork::new(layers, self.in_channels);
71 |         (in_img_dims * in_img_dims * in_channels, Box::new(network_layer))
72 |     }
73 | }
74 | 
75 | #[derive(Serialize, Deserialize, Clone, Debug)]
76 | pub enum ConvNetworkLayerModels {
77 |     FullDenseConv(FullDenseConvLayerModel),
78 |     FullDirectConv(FullDirectConvLayerModel),
79 |     AvgPooling {
80 |         kernel_size: usize,
81 |     },
82 | }
83 | 
84 | impl ConvNetworkLayerModels {
85 |     pub fn to_conv_layer(self, in_img_dims: usize, in_channels: usize) -> (usize, usize, Box<dyn ConvNetworkLayer>) {
86 |         match self {
87 |             Self::FullDenseConv(model) => model.to_layer(in_img_dims, in_channels),
88 |             Self::FullDirectConv(model) => model.to_layer(in_img_dims, in_channels),
89 |             Self::AvgPooling { kernel_size } => {
90 |                 let out_img_dims = in_img_dims / kernel_size;
91 |                 let out_channels = in_channels;
92 |                 let network_layer = AvgPoolingLayer::new(kernel_size);
93 |                 (out_img_dims, out_channels, Box::new(network_layer))
94 |             }
95 |         }
96 |     }
97 | }
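A sketch of how one layer config transforms the running geometry in `to_conv_layer` above: a 2x2 average pooling halves each (square) image dimension and keeps the channel count:

let (img_dims, channels, _layer) =
    ConvNetworkLayerModels::AvgPooling { kernel_size: 2 }.to_conv_layer(28, 1);
assert_eq!((img_dims, channels), (14, 1));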
--------------------------------------------------------------------------------
/src/model/network_model.rs:
--------------------------------------------------------------------------------
1 | use serde::{Serialize, Deserialize};
2 | 
3 | use crate::network::{Network, NetworkLayer};
4 | 
5 | use super::{ModelBuilder, conv_network_model::{ConvNetworkModelBuilder, ConvNetworkModel}, full_dense_layer_model::{FullDenseLayerModel, FullDenseLayerModelBuilder}};
6 | 
7 | pub struct NetworkModelBuilder {
8 |     pub model: NetworkModel,
9 |     pub parent: Option<ModelBuilder>
10 | }
11 | 
12 | impl NetworkModelBuilder {
13 |     pub fn new() -> Self {
14 |         Self {
15 |             model: NetworkModel { layers: Vec::new() },
16 |             parent: None
17 |         }
18 |     }
19 | 
20 |     pub fn set_parent(mut self, parent: ModelBuilder) -> Self {
21 |         self.parent = Some(parent);
22 |         self
23 |     }
24 | 
25 |     pub fn conv_network(self, in_channels: usize) -> ConvNetworkModelBuilder {
26 |         ConvNetworkModelBuilder::new(self, in_channels)
27 |     }
28 | 
29 |     pub(crate) fn accept_conv_network(mut self, layer: ConvNetworkModel) -> Self {
30 |         self.model.layers.push(NetworkLayerModels::Convolution(layer));
31 |         self
32 |     }
33 | 
34 |     pub fn full_dense(self, size: usize) -> FullDenseLayerModelBuilder {
35 |         FullDenseLayerModelBuilder::new(self, size)
36 |     }
37 | 
38 |     pub(crate) fn accept_full_dense(mut self, layer: FullDenseLayerModel) -> Self {
39 |         self.model.layers.push(NetworkLayerModels::FullDense(layer));
40 |         self
41 |     }
42 | 
43 |     pub fn end(self) -> ModelBuilder {
44 |         match self.parent {
45 |             Some(parent) => parent.accept_neural_network(self.model),
46 |             None => panic!("No parent model builder set")
47 |         }
48 |     }
49 | 
50 |     pub fn build(self) -> NetworkModel {
51 |         self.model
52 |     }
53 | }
54 | 
55 | #[derive(Serialize, Deserialize, Clone, Debug)]
56 | pub struct NetworkModel {
57 |     pub layers: Vec<NetworkLayerModels>
58 | }
59 | 
60 | impl NetworkModel {
61 |     pub fn to_network(self, mut in_dims: usize) -> Network {
62 |         let mut layers = vec![];
63 |         for layer_config in self.layers.into_iter() {
64 |             let (out_dims, layer) = layer_config.to_layer(in_dims);
65 |             in_dims = out_dims;
66 |             layers.push(layer);
67 |         }
68 |         Network::new(layers)
69 |     }
70 | }
71 | 
72 | #[derive(Serialize, Deserialize, Clone, Debug)]
73 | pub enum NetworkLayerModels {
74 |     Convolution(ConvNetworkModel),
75 |     FullDense(FullDenseLayerModel)
76 | }
77 | 
78 | impl NetworkLayerModels {
79 |     pub fn to_layer(self, in_dims: usize) -> (usize, Box<dyn NetworkLayer>) {
80 |         match self {
81 |             Self::Convolution(network) => network.to_layer(in_dims),
82 |             Self::FullDense(layer) => layer.to_layer(in_dims)
83 |         }
84 |     }
85 | }
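A one-line sketch of turning a finished model description into a runnable network; `model` is a hypothetical `NetworkModel` produced by the builder above:

let network = model.to_network(28 * 28); // the input dimension threads through each layer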
--------------------------------------------------------------------------------
/src/network/params.rs:
--------------------------------------------------------------------------------
1 | use std::{fs::File, io::{Write, Read}, path::PathBuf};
2 | 
3 | use flate2::{write::GzEncoder, Compression, read::GzDecoder};
4 | use serde::{Deserialize, Serialize};
5 | 
6 | use crate::linalg::{Matrix, MatrixTrait, Scalar};
7 | 
8 | #[derive(Serialize, Deserialize)]
9 | pub struct NetworkParams(pub Vec<Vec<Vec<Scalar>>>);
10 | 
11 | impl NetworkParams {
12 |     pub fn average(networks: &Vec<NetworkParams>) -> Self {
13 |         let mut params = Vec::new();
14 | 
15 |         let layer_count = networks[0].0.len();
16 | 
17 |         for layer_index in 0..layer_count {
18 |             let mut layer_params = Matrix::from_column_leading_vector2(&networks[0].0[layer_index]);
19 | 
20 |             for network in networks.iter().skip(1) {
21 |                 let other_params = Matrix::from_column_leading_vector2(&network.0[layer_index]);
22 |                 layer_params = layer_params.component_add(&other_params); // accumulate the sum...
23 |             }
24 | 
25 |             params.push(layer_params.scalar_div(networks.len() as Scalar).get_data_col_leading()); // ...then divide once for a true mean
26 |         }
27 | 
28 |         NetworkParams(params)
29 |     }
30 | 
31 |     pub fn to_json<P: Into<PathBuf>>(&self, path: P) {
32 |         let json = serde_json::to_value(self).unwrap();
33 |         let mut file = File::create(path.into()).unwrap();
34 |         file.write_all(json.to_string().as_bytes()).unwrap();
35 |     }
36 | 
37 |     pub fn from_json<P: Into<PathBuf>>(path: P) -> Self {
38 |         let file = File::open(path.into()).unwrap();
39 |         let params: serde_json::Value = serde_json::from_reader(file).unwrap();
40 |         serde_json::from_value(params).unwrap()
41 |     }
42 | 
43 |     pub fn to_binary_compressed<P: Into<PathBuf>>(&self, path: P) {
44 |         let result = bincode::serialize(self).unwrap();
45 |         let mut encoder = GzEncoder::new(Vec::new(), Compression::best());
46 |         encoder.write_all(result.as_slice()).unwrap();
47 |         let compressed = encoder.finish().unwrap();
48 |         let mut file = File::create(path.into()).unwrap();
49 |         file.write_all(&compressed).unwrap();
50 |     }
51 | 
52 |     pub fn from_binary_compressed<P: Into<PathBuf>>(path: P) -> Self {
53 |         let file = File::open(path.into()).unwrap();
54 |         let mut decoder = GzDecoder::new(file);
55 |         let mut buffer = Vec::new();
56 |         decoder.read_to_end(&mut buffer).unwrap();
57 |         bincode::deserialize(buffer.as_slice()).unwrap()
58 |     }
59 | 
60 |     pub fn count(&self) -> usize {
61 |         self.0.iter().map(|l| l.iter().map(|l| l.len()).sum::<usize>()).sum()
62 |     }
63 | }
64 | 
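A sketch of persisting parameters with the gzip + bincode helpers above; `params` is a hypothetical `NetworkParams` value and the path is illustrative (it mirrors the `best_model_params.gz` artifact in the housing example):

params.to_binary_compressed("best_model_params.gz");
let restored = NetworkParams::from_binary_compressed("best_model_params.gz");
assert_eq!(restored.count(), params.count());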
--------------------------------------------------------------------------------
/src/optimizer/adam.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | use crate::{
4 |     learning_rate::{default_learning_rate, LearningRateSchedule},
5 |     linalg::{Matrix, MatrixTrait, Scalar},
6 | };
7 | 
8 | fn default_beta1() -> Scalar {
9 |     0.9
10 | }
11 | 
12 | fn default_beta2() -> Scalar {
13 |     0.999
14 | }
15 | 
16 | fn default_epsilon() -> Scalar {
17 |     1e-8
18 | }
19 | 
20 | // https://arxiv.org/pdf/1412.6980.pdf
21 | #[derive(Clone, Debug, Serialize, Deserialize)]
22 | pub struct Adam {
23 |     #[serde(default = "default_beta1")]
24 |     beta1: Scalar,
25 |     #[serde(default = "default_beta2")]
26 |     beta2: Scalar,
27 |     #[serde(default = "default_epsilon")]
28 |     epsilon: Scalar,
29 |     #[serde(default = "default_learning_rate")]
30 |     learning_rate: LearningRateSchedule,
31 |     #[serde(skip)]
32 |     m: Option<Matrix>, // first moment vector
33 |     #[serde(skip)]
34 |     v: Option<Matrix>, // second moment vector
35 | }
36 | 
37 | impl Adam {
38 |     pub fn new(
39 |         learning_rate: LearningRateSchedule,
40 |         beta1: Scalar,
41 |         beta2: Scalar,
42 |         epsilon: Scalar,
43 |     ) -> Self {
44 |         Self {
45 |             m: None,
46 |             v: None,
47 |             beta1,
48 |             beta2,
49 |             learning_rate,
50 |             epsilon,
51 |         }
52 |     }
53 | 
54 |     pub fn default() -> Self {
55 |         Self {
56 |             v: None,
57 |             m: None,
58 |             beta1: default_beta1(),
59 |             beta2: default_beta2(),
60 |             learning_rate: default_learning_rate(),
61 |             epsilon: default_epsilon(),
62 |         }
63 |     }
64 | 
65 |     pub fn update_parameters(
66 |         &mut self,
67 |         epoch: usize,
68 |         parameters: &Matrix,
69 |         parameters_gradient: &Matrix,
70 |     ) -> Matrix {
71 |         let alpha = self.learning_rate.get_learning_rate(epoch);
72 | 
73 |         let (nrow, ncol) = parameters_gradient.dim();
74 | 
75 |         if self.m.is_none() {
76 |             self.m = Some(Matrix::zeros(nrow, ncol));
77 |         }
78 |         if self.v.is_none() {
79 |             self.v = Some(Matrix::zeros(nrow, ncol));
80 |         }
81 |         let m = self.m.as_ref().unwrap();
82 |         let v = self.v.as_ref().unwrap();
83 | 
84 |         let g = parameters_gradient;
85 |         let g2 = parameters_gradient.component_mul(&parameters_gradient);
86 | 
87 |         let m = &(m.scalar_mul(self.beta1)).component_add(&g.scalar_mul(1.0 - self.beta1));
88 |         let v = &(v.scalar_mul(self.beta2)).component_add(&g2.scalar_mul(1.0 - self.beta2));
89 | 
90 |         let m_bias_corrected = m.scalar_div(1.0 - self.beta1.powi(epoch as i32 + 1)); // 1 - beta1^t, per the cited paper
91 |         let v_bias_corrected = v.scalar_div(1.0 - self.beta2.powi(epoch as i32 + 1)); // 1 - beta2^t
92 | 
93 |         let v_bias_corrected = v_bias_corrected.sqrt();
94 | 
95 |         self.m = Some(m.clone());
96 |         self.v = Some(v.clone());
97 |         parameters.component_sub(
98 |             &(m_bias_corrected.scalar_mul(alpha))
99 |                 .component_div(&v_bias_corrected.scalar_add(self.epsilon)),
100 |         )
101 |     }
102 | }
103 | 
--------------------------------------------------------------------------------
/src/optimizer/mod.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | use crate::linalg::Matrix;
4 | 
5 | use self::{adam::Adam, momentum::Momentum, sgd::SGD};
6 | 
7 | pub mod adam;
8 | pub mod momentum;
9 | pub mod sgd;
10 | 
11 | #[derive(Clone, Debug, Serialize, Deserialize)]
12 | pub enum Optimizers {
13 |     SGD(SGD),
14 |     Momentum(Momentum),
15 |     Adam(Adam),
16 | }
17 | 
18 | impl Optimizers {
19 |     pub fn update_parameters(
20 |         &mut self,
21 |         epoch: usize,
22 |         parameters: &Matrix,
23 |         parameters_gradient: &Matrix,
24 |     ) -> Matrix {
25 |         match self {
26 |             Optimizers::SGD(sgd) => sgd.update_parameters(epoch, parameters, parameters_gradient),
27 |             Optimizers::Momentum(momentum) => {
28 |                 momentum.update_parameters(epoch, parameters, parameters_gradient)
29 |             }
30 |             Optimizers::Adam(adam) => {
31 |                 adam.update_parameters(epoch, parameters, parameters_gradient)
32 |             }
33 |         }
34 |     }
35 | }
36 | 
37 | pub fn adam() -> Optimizers {
38 |     Optimizers::Adam(Adam::default())
39 | }
40 | 
41 | pub fn sgd() -> Optimizers {
42 |     Optimizers::SGD(SGD::default())
43 | }
44 | 
45 | pub fn momentum() -> Optimizers {
46 |     Optimizers::Momentum(Momentum::default())
47 | }
48 | 
--------------------------------------------------------------------------------
/src/optimizer/momentum.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | use crate::{
4 |     learning_rate::{default_learning_rate, LearningRateSchedule},
5 |     linalg::{Matrix, MatrixTrait, Scalar},
6 | };
7 | 
8 | pub(crate) fn default_momentum() -> Scalar {
9 |     0.9
10 | }
11 | 
12 | // https://arxiv.org/pdf/1207.0580.pdf
13 | #[derive(Clone, Debug, Serialize, Deserialize)]
14 | pub struct Momentum {
15 |     #[serde(default = "default_momentum")]
16 |     momentum: Scalar,
17 |     #[serde(default = "default_learning_rate")]
18 |     learning_rate: LearningRateSchedule,
19 |     #[serde(skip)]
20 |     v: Option<Matrix>,
21 | }
22 | 
23 | impl Momentum {
24 |     pub fn new(learning_rate: LearningRateSchedule, momentum: Scalar) -> Self {
25 |         Self {
26 |             v: None,
27 |             momentum,
28 |             learning_rate,
29 |         }
30 |     }
31 | 
32 |     pub fn default() -> Self {
33 |         Self {
34 |             v: None,
35 |             momentum: default_momentum(),
36 |             learning_rate: default_learning_rate(),
37 |         }
38 |     }
39 | 
40 |     pub fn update_parameters(
41 |         &mut self,
42 |         epoch: usize,
43 |         parameters: &Matrix,
44 |         parameters_gradient: &Matrix,
45 |     ) -> Matrix {
46 |         let lr = self.learning_rate.get_learning_rate(epoch);
47 | 
48 |         if self.v.is_none() {
49 |             let (nrow, ncol) = parameters_gradient.dim();
50 |             self.v = Some(Matrix::zeros(nrow, ncol));
51 |         }
52 | 
53 |         let v = self.v.as_ref().unwrap();
54 | 
55 |         let v = v
56 |             .scalar_mul(self.momentum)
57 |             .component_add(&parameters_gradient.scalar_mul(lr));
58 | 
59 |         let new_params = parameters.component_sub(&v);
60 |         self.v = Some(v);
61 |         new_params
62 |     }
63 | }
64 | 
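A sketch of one optimizer step through the `Optimizers` enum, using dummy parameters (assumes a linalg backend feature is enabled):

use crate::linalg::{Matrix, MatrixTrait};
use crate::optimizer::adam;

let mut opt = adam();
let params = Matrix::zeros(3, 3);
let grads = Matrix::constant(3, 3, 0.5);
let updated = opt.update_parameters(0, &params, &grads); // epoch 0
assert_eq!(updated.dim(), (3, 3));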
--------------------------------------------------------------------------------
/src/optimizer/sgd.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | use crate::{
4 |     learning_rate::{default_learning_rate, LearningRateSchedule},
5 |     linalg::{Matrix, MatrixTrait, Scalar},
6 | };
7 | 
8 | #[derive(Clone, Debug, Serialize, Deserialize)]
9 | pub struct SGD {
10 |     #[serde(default = "default_learning_rate")]
11 |     learning_rate: LearningRateSchedule,
12 | }
13 | 
14 | impl SGD {
15 |     pub fn default() -> Self {
16 |         Self {
17 |             learning_rate: default_learning_rate(),
18 |         }
19 |     }
20 | 
21 |     pub fn with_const_lr(learning_rate: Scalar) -> Self {
22 |         Self {
23 |             learning_rate: LearningRateSchedule::Constant(learning_rate),
24 |         }
25 |     }
26 | 
27 |     pub fn new(learning_rate: LearningRateSchedule) -> Self {
28 |         Self { learning_rate }
29 |     }
30 | 
31 |     pub fn update_parameters(
32 |         &mut self,
33 |         epoch: usize,
34 |         parameters: &Matrix,
35 |         parameters_gradient: &Matrix,
36 |     ) -> Matrix {
37 |         let lr = self.learning_rate.get_learning_rate(epoch);
38 |         parameters.component_sub(&parameters_gradient.scalar_mul(lr))
39 |     }
40 | }
41 | 
--------------------------------------------------------------------------------
/src/preprocessing/attach_ids.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     dataset::{Dataset, Feature},
3 |     datatable::DataTable,
4 | };
5 | 
6 | use super::{DataTransformation, CachedConfig};
7 | 
8 | pub struct AttachIds(pub String);
9 | 
10 | impl AttachIds {
11 |     pub fn new(id_column_name: &str) -> Self {
12 |         Self(id_column_name.to_string())
13 |     }
14 | }
15 | 
16 | impl DataTransformation for AttachIds {
17 |     fn transform(
18 |         &mut self,
19 |         _cached_config: &CachedConfig,
20 |         dataset_config: &Dataset,
21 |         data: &DataTable,
22 |     ) -> (Dataset, DataTable) {
23 |         let mut feature = Feature::default();
24 |         feature.name = self.0.clone();
25 |         feature.used_in_model = true;
26 |         feature.is_id = true;
27 |         let configuration = dataset_config.with_added_feature(feature);
28 |         let data = data.with_autoincrement_id_column(&self.0);
29 |         (configuration, data)
30 |     }
31 | 
32 |     fn reverse_columnswise(&mut self, data: &DataTable) -> DataTable {
33 |         data.clone()
34 |     }
35 | 
36 |     fn get_name(&self) -> String {
37 |         format!("attach_ids({})", self.0)
38 |     }
39 | }
40 | 
--------------------------------------------------------------------------------
/src/preprocessing/extract_months.rs:
--------------------------------------------------------------------------------
1 | use polars::export::chrono::{DateTime, Datelike, NaiveDateTime, Utc};
2 | 
3 | use crate::{
4 |     dataset::{Dataset, Feature},
5 |     datatable::DataTable,
6 |     linalg::Scalar,
7 | };
8 | 
9 | use super::{feature_cached::FeatureExtractorCached, DataTransformation, CachedConfig};
10 | 
11 | pub struct ExtractMonths;
12 | 
13 | impl DataTransformation for ExtractMonths {
14 |     fn transform(
15 |         &mut self,
16 |         cached_config: &CachedConfig,
17 |         dataset_config: &Dataset,
18 |         data: &DataTable,
19 |     ) -> (Dataset, DataTable) {
20 |         let extracted_feature_config = |feature: &Feature| {
21 |             if feature.date_format.is_some() {
22 |                 match &feature.with_extracted_month {
23 |                     Some(new_feature) => Some(*new_feature.clone()),
24 |                     None => match &feature.extract_month {
25 |                         true => {
26 |                             let mut f = feature.clone();
27 |                             f.date_format = None;
28 |                             f.extract_month = false;
29 |                             Some(f)
30 |                         }
31 |                         false => None,
32 |                     },
33 |                 }
34 |             } else {
35 |                 None
36 |             }
37 |         };
38 | 
39 |         let extract_feature = |data: &DataTable, extracted: &Feature, feature: &Feature| {
40 |             let format = feature.date_format.clone().unwrap();
41 |             data.map_str_column_to_scalar_column(&feature.name, &extracted.name, |date| {
42 |                 let datetime = NaiveDateTime::parse_from_str(date, &format).unwrap();
43 |                 let timestamp: DateTime<Utc> = DateTime::from_utc(datetime, Utc);
44 |                 timestamp.month() as Scalar
45 |             })
46 |         };
47 | 
48 |         let mut extractor = FeatureExtractorCached::new(
49 |             Box::new(extracted_feature_config),
50 |             Box::new(extract_feature),
51 |         );
52 | 
53 |         extractor.transform(cached_config, dataset_config, data)
54 |     }
55 | 
56 |     fn reverse_columnswise(&mut self, data: &DataTable) -> DataTable {
57 |         data.clone()
58 |     }
59 | 
60 |     fn get_name(&self) -> String {
61 |         "extract_months".to_string()
62 |     }
63 | }
64 | 
--------------------------------------------------------------------------------
/src/preprocessing/extract_timestamps.rs:
--------------------------------------------------------------------------------
1 | use polars::export::chrono::{DateTime, NaiveDateTime, Utc};
2 | 
3 | use crate::{
4 |     dataset::{Dataset, Feature},
5 |     datatable::DataTable,
6 |     linalg::Scalar,
7 | };
8 | 
9 | use super::{feature_cached::FeatureExtractorCached, DataTransformation, CachedConfig};
10 | 
11 | pub struct ExtractTimestamps;
12 | 
13 | impl DataTransformation for ExtractTimestamps {
14 |     fn transform(
15 |         &mut self,
16 |         cached_config: &CachedConfig,
17 |         dataset_config: &Dataset,
18 |         data: &DataTable,
19 |     ) -> (Dataset, DataTable) {
20 |         let extracted_feature_config = |feature: &Feature| {
21 |             if feature.date_format.is_some() {
22 |                 match &feature.with_extracted_timestamp {
23 |                     Some(new_feature) => Some(*new_feature.clone()),
24 |                     _ => match &feature.to_timestamp {
25 |                         true => {
26 |                             let mut f = feature.clone();
27 |                             f.date_format = None;
28 |                             f.to_timestamp = false;
29 |                             Some(f)
30 |                         }
31 |                         _ => None,
32 |                     },
33 |                 }
34 |             } else {
35 |                 None
36 |             }
37 |         };
38 | 
39 |         let extract_feature = |data: &DataTable, extracted: &Feature, feature: &Feature| {
40 |             let format = feature.date_format.clone().unwrap();
41 |             data.map_str_column_to_scalar_column(&feature.name, &extracted.name, |date| {
42 |                 let datetime = NaiveDateTime::parse_from_str(date, &format).unwrap();
43 |                 let timestamp: DateTime<Utc> = DateTime::from_utc(datetime, Utc);
44 |                 let unix_seconds = timestamp.timestamp();
45 |                 unix_seconds as Scalar
46 |             })
47 |         };
48 | 
49 |         let mut extractor = FeatureExtractorCached::new(
50 |             Box::new(extracted_feature_config),
51 |             Box::new(extract_feature),
52 |         );
53 | 
54 |         extractor.transform(cached_config, dataset_config, data)
55 |     }
56 | 
57 |     fn reverse_columnswise(&mut self, data: &DataTable) -> DataTable {
58 |         data.clone()
59 |     }
60 | 
61 |     fn get_name(&self) -> String {
62 |         "to_timestamps".to_string()
63 |     }
64 | }
65 | 
--------------------------------------------------------------------------------
/src/preprocessing/feature_cached.rs:
--------------------------------------------------------------------------------
1 | use std::{hash::{Hash, Hasher}, path::Path};
2 | 
3 | use crate::{
4 |     dataset::{Dataset, Feature},
5 |     datatable::DataTable, monitor::TM,
6 | };
7 | 
8 | use super::{CachedConfig, DataTransformation};
9 | 
10 | pub struct FeatureExtractorCached {
11 |     pub extracted_feature_config: Box<dyn Fn(&Feature) -> Option<Feature>>,
12 |     pub extract_feature: Box<dyn Fn(&DataTable, &Feature, &Feature) -> DataTable>,
13 | }
14 | 
15 | impl FeatureExtractorCached {
16 |     pub fn new(
17 |         extracted_feature_config: Box<dyn Fn(&Feature) -> Option<Feature>>,
18 |         extract_feature: Box<dyn Fn(&DataTable, &Feature, &Feature) -> DataTable>,
19 |     ) -> Self {
20 |         Self {
21 |             extracted_feature_config,
22 |             extract_feature,
23 |         }
24 |     }
25 | 
26 |     fn get_hashed_id(id: &String) -> String {
27 |         let mut hasher = std::collections::hash_map::DefaultHasher::new();
28 |         id.hash(&mut hasher);
29 |         hasher.finish().to_string()
30 |     }
31 | 
32 |     fn get_cached_feature_file_name(
33 |         &self,
34 | id: &String, 35 | working_dir: &str, 36 | feature: &Feature, 37 | ) -> String { 38 | let file_name = Path::new(&working_dir) 39 | .join("cached") 40 | .join(format!("{}_{}.csv", Self::get_hashed_id(id), feature.name)) 41 | .to_str() 42 | .unwrap() 43 | .to_string(); 44 | file_name 45 | } 46 | 47 | fn get_cached_feature( 48 | &self, 49 | id: &String, 50 | working_dir: &str, 51 | feature: &Feature, 52 | ) -> Option<(DataTable, String)> { 53 | let file_name = self.get_cached_feature_file_name(id, working_dir, feature); 54 | if std::path::Path::new(&file_name).exists() { 55 | let dataset_table = DataTable::from_csv_file(file_name.clone()); 56 | Some((dataset_table.get_column_as_table(&feature.name), file_name)) 57 | } else { 58 | None 59 | } 60 | } 61 | 62 | fn transform_no_cache( 63 | &mut self, 64 | mut dataset_table: DataTable, 65 | feature: &Feature, 66 | extracted_feature: &Feature, 67 | ) -> DataTable { 68 | let old_column = dataset_table.get_column_as_table(&feature.name); 69 | dataset_table = (self.extract_feature)(&dataset_table, &extracted_feature, feature); 70 | // if the transformation replaced the old column, we need to add it back 71 | dataset_table = 72 | if extracted_feature.name != feature.name && !dataset_table.has_column(&feature.name) { 73 | dataset_table.append_table_as_column(&old_column) 74 | } else { 75 | dataset_table 76 | }; 77 | dataset_table 78 | } 79 | } 80 | 81 | impl DataTransformation for FeatureExtractorCached { 82 | fn transform( 83 | &mut self, 84 | cached_config: &CachedConfig, 85 | dataset_config: &Dataset, 86 | data: &DataTable, 87 | ) -> (Dataset, DataTable) { 88 | let mut new_config = dataset_config.clone(); 89 | let mut dataset_table = data.clone(); 90 | 91 | if let CachedConfig::Cached { working_dir, .. } = cached_config { 92 | // create the dataset/cached directory if it does not exist 93 | std::fs::create_dir_all(format!("{}/cached/", working_dir)) 94 | .expect("Failed to create cache directory"); 95 | } 96 | 97 | for feature in &dataset_config.features { 98 | if let Some(extracted_feature) = (self.extracted_feature_config)(feature) { 99 | if let CachedConfig::Cached { id, working_dir } = cached_config { 100 | if let Some((cached_data, cachefile_name)) = 101 | self.get_cached_feature(&id, working_dir, &extracted_feature) 102 | { 103 | TM::start("loadcache"); 104 | dataset_table = if extracted_feature.name == feature.name { 105 | dataset_table 106 | .drop_column(&feature.name) 107 | .append_table_as_column(&cached_data) 108 | } else { 109 | dataset_table.append_table_as_column(&cached_data) 110 | }; 111 | TM::end_with_message(format!( 112 | "Loaded {} from cache {}", 113 | extracted_feature.name, 114 | cachefile_name 115 | )); 116 | } else { 117 | dataset_table = self.transform_no_cache(dataset_table, feature, &extracted_feature); 118 | dataset_table 119 | .get_column_as_table(&extracted_feature.name) 120 | .to_csv_file(self.get_cached_feature_file_name( 121 | &id, 122 | working_dir, 123 | &extracted_feature, 124 | )); 125 | } 126 | } else { 127 | dataset_table = self.transform_no_cache(dataset_table, feature, &extracted_feature); 128 | } 129 | 130 | new_config = if extracted_feature.name == feature.name { 131 | new_config.with_replaced_feature(&feature.name, extracted_feature) 132 | } else { 133 | new_config.with_added_feature(extracted_feature) 134 | }; 135 | } 136 | } 137 | 138 | (new_config, dataset_table) 139 | } 140 | 141 | fn reverse_columnswise(&mut self, data: &DataTable) -> DataTable { 142 | data.clone() 143 | } 144 | 145 | fn 
get_name(&self) -> String {
146 |         "".to_string()
147 |     }
148 | }
149 | 
--------------------------------------------------------------------------------
/src/preprocessing/filter_outliers.rs:
--------------------------------------------------------------------------------
1 | use crate::{dataset::Dataset, datatable::DataTable, vec_utils::vector_quartiles_iqr};
2 | 
3 | use super::{DataTransformation, CachedConfig};
4 | 
5 | pub struct FilterOutliers;
6 | 
7 | impl DataTransformation for FilterOutliers {
8 |     fn transform(
9 |         &mut self,
10 |         _cached_config: &CachedConfig,
11 |         dataset_config: &Dataset,
12 |         data: &DataTable,
13 |     ) -> (Dataset, DataTable) {
14 |         let mut data = data.clone();
15 |         for feature in dataset_config.features.iter() {
16 |             if feature.filter_outliers {
17 |                 let vals = data.column_to_vector(&feature.name);
18 |                 let (_, _, _, min, max) = vector_quartiles_iqr(&vals);
19 |                 data = data.filter_by_scalar_column(&feature.name, |x| x >= min && x <= max);
20 |             }
21 |         }
22 |         (dataset_config.clone(), data)
23 |     }
24 | 
25 |     fn reverse_columnswise(&mut self, data: &DataTable) -> DataTable {
26 |         data.clone()
27 |     }
28 | 
29 |     fn get_name(&self) -> String {
30 |         "filter_outliers".to_string()
31 |     }
32 | }
33 | 
--------------------------------------------------------------------------------
/src/preprocessing/log_scale.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 | 
3 | use crate::{
4 |     dataset::{Dataset, Feature},
5 |     datatable::DataTable,
6 |     linalg::Scalar,
7 |     vec_utils::min_vector,
8 | };
9 | 
10 | use super::{feature_cached::FeatureExtractorCached, DataTransformation, CachedConfig};
11 | 
12 | pub struct LogScale10 {
13 |     logged_features: HashMap<String, Scalar>,
14 | }
15 | 
16 | impl LogScale10 {
17 |     pub fn new() -> Self {
18 |         Self {
19 |             logged_features: HashMap::new(),
20 |         }
21 |     }
22 | }
23 | 
24 | impl DataTransformation for LogScale10 {
25 |     fn transform(
26 |         &mut self,
27 |         cached_config: &CachedConfig,
28 |         dataset_config: &Dataset,
29 |         data: &DataTable,
30 |     ) -> (Dataset, DataTable) {
31 |         let mut logged_features = HashMap::new();
32 | 
33 |         for feature in dataset_config.features.iter() {
34 |             if feature.log10 {
35 |                 let values = data.column_to_vector(&feature.name);
36 |                 let min = min_vector(&values);
37 |                 logged_features.insert(feature.name.clone(), min);
38 |             }
39 |         }
40 | 
41 |         self.logged_features = logged_features.clone();
42 | 
43 |         let mut extractor = FeatureExtractorCached::new(
44 |             Box::new(move |feature: &Feature| match &feature.with_log10 {
45 |                 Some(new_feature) => Some(*new_feature.clone()),
46 |                 _ => match &feature.log10 {
47 |                     true => {
48 |                         let mut feature = feature.clone();
49 |                         feature.log10 = false;
50 |                         Some(feature)
51 |                     }
52 |                     _ => None,
53 |                 },
54 |             }),
55 |             Box::new(
56 |                 move |data: &DataTable, extracted: &Feature, feature: &Feature| {
57 |                     data.map_scalar_column(&feature.name, |x| {
58 |                         let min = logged_features.get(&feature.name).unwrap();
59 |                         if min <= &1.0 {
60 |                             (min.abs() + x + 0.001).log10()
61 |                         } else {
62 |                             x.log10()
63 |                         }
64 |                     })
65 |                     .rename_column(&feature.name, &extracted.name)
66 |                 },
67 |             ),
68 |         );
69 | 
70 |         extractor.transform(cached_config, dataset_config, data)
71 |     }
72 | 
73 |     fn reverse_columnswise(&mut self, data: &DataTable) -> DataTable {
74 |         let mut reversed_data = data.clone();
75 | 
76 |         for (feature, min) in self.logged_features.iter() {
77 |             if reversed_data.has_column(feature) {
78 |                 reversed_data = reversed_data.map_scalar_column(feature, |x| {
79 |                     if min <= &1.0 {
80 |                         (10 as Scalar).powf(x) - min.abs() - 0.001
81 |                     } else {
82 |                         (10 as Scalar).powf(x)
83 |                     }
84 |                 });
85 |             }
86 |         }
87 | 
88 |         reversed_data
89 |     }
90 | 
91 |     fn get_name(&self) -> String {
92 |         "log10".to_string()
93 |     }
94 | }
95 | 
--------------------------------------------------------------------------------
/src/preprocessing/normalize.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 | 
3 | use crate::{
4 |     dataset::{Dataset, Feature},
5 |     datatable::DataTable,
6 | };
7 | 
8 | use super::{feature_cached::FeatureExtractorCached, DataTransformation, CachedConfig};
9 | use crate::linalg::Scalar;
10 | 
11 | pub struct Normalize {
12 |     pub features_min_max: HashMap<String, (Scalar, Scalar)>,
13 | }
14 | 
15 | impl Normalize {
16 |     pub fn new() -> Self {
17 |         Self {
18 |             features_min_max: HashMap::new(),
19 |         }
20 |     }
21 | 
22 |     pub fn same_normalization(&mut self, new_feature: &str, old_feature: &str) -> &mut Self {
23 |         let min_max = self.features_min_max.get(old_feature).unwrap();
24 |         self.features_min_max
25 |             .insert(new_feature.to_string(), *min_max);
26 |         self
27 |     }
28 | 
29 |     pub fn denormalize_data(&self, data: &DataTable) -> DataTable {
30 |         let mut denormalized_data = data.clone();
31 | 
32 |         for (feature_name, min_max) in self.features_min_max.iter() {
33 |             if denormalized_data.has_column(feature_name) {
34 |                 denormalized_data = denormalized_data.denormalize_column(feature_name, *min_max);
35 |             }
36 |         }
37 | 
38 |         denormalized_data
39 |     }
40 | }
41 | 
42 | impl DataTransformation for Normalize {
43 |     fn transform(
44 |         &mut self,
45 |         cached_config: &CachedConfig,
46 |         dataset_config: &Dataset,
47 |         data: &DataTable,
48 |     ) -> (Dataset, DataTable) {
49 |         let mut features_min_max: HashMap<String, (Scalar, Scalar)> = HashMap::new();
50 | 
51 |         for feature in dataset_config.features.iter() {
52 |             if feature.normalized {
53 |                 let min_max = data.min_max_column(&feature.name);
54 |                 features_min_max.insert(feature.name.clone(), min_max);
55 |             }
56 |         }
57 | 
58 |         self.features_min_max = features_min_max.clone();
59 | 
60 |         let mut extractor = FeatureExtractorCached::new(
61 |             Box::new(move |feature: &Feature| match &feature.with_normalized {
62 |                 Some(new_feature) => Some(*new_feature.clone()),
63 |                 _ => match &feature.normalized {
64 |                     true => {
65 |                         let mut feature = feature.clone();
66 |                         feature.normalized = false;
67 |                         Some(feature)
68 |                     }
69 |                     _ => None,
70 |                 },
71 |             }),
72 |             Box::new(
73 |                 move |data: &DataTable, extracted: &Feature, feature: &Feature| {
74 |                     data.normalize_column(
75 |                         &feature.name,
76 |                         *features_min_max.get(&feature.name).unwrap(),
77 |                     )
78 |                     .rename_column(&feature.name, &extracted.name)
79 |                 },
80 |             ),
81 |         );
82 | 
83 |         extractor.transform(cached_config, dataset_config, data)
84 |     }
85 | 
86 |     fn reverse_columnswise(&mut self, data: &DataTable) -> DataTable {
87 |         self.denormalize_data(data)
88 |     }
89 | 
90 |     fn get_name(&self) -> String {
91 |         "norm".to_string()
92 |     }
93 | }
94 | 
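A sketch of the round trip this transformation supports; `cached_config`, `dataset_config` and `data` are hypothetical values whose construction is elided here:

let mut normalize = Normalize::new();
let (new_config, normalized) = normalize.transform(&cached_config, &dataset_config, &data);
// per-feature min/max were recorded during transform, so the scaling can be undone:
let original_scale = normalize.reverse_columnswise(&normalized);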
--------------------------------------------------------------------------------
/src/preprocessing/one_hot_encode.rs:
--------------------------------------------------------------------------------
1 | use std::collections::{HashMap, HashSet};
2 | 
3 | use crate::{
4 |     dataset::{Dataset, Feature},
5 |     datatable::DataTable,
6 |     linalg::Scalar,
7 | };
8 | 
9 | use super::{CachedConfig, DataTransformation};
10 | 
11 | pub struct OneHotEncode;
12 | 
13 | impl DataTransformation for OneHotEncode {
14 |     fn transform(
15 |         &mut self,
16 |         _cached_config: &CachedConfig,
17 |         dataset_config: &Dataset,
18 |         data: &DataTable,
19 |     ) -> (Dataset, DataTable) {
20 |         let mut features_values: HashMap<Feature, HashSet<i64>> = HashMap::new();
21 | 
22 |         data.as_scalar_hashmap().iter().for_each(|(name, values)| {
23 |             let feature = dataset_config.features.iter().find(|f| f.name == *name).unwrap();
24 |             if feature.one_hot_encoded {
25 |                 let mut values_set = HashSet::new();
26 |                 values.iter().for_each(|v| {
27 |                     values_set.insert(*v as i64);
28 |                 });
29 |                 features_values.insert(feature.clone(), values_set);
30 |             }
31 |         });
32 | 
33 |         let mut new_config = dataset_config.clone();
34 | 
35 |         for (feature, values) in features_values.iter() {
36 |             for value in values.iter() {
37 |                 let mut new_feature = feature.clone();
38 |                 new_feature.name = format!("{}={}", feature.name, value);
39 |                 new_feature.one_hot_encoded = false;
40 |                 new_config = new_config.with_added_feature(new_feature);
41 |             }
42 |             new_config = new_config.without_feature(feature.name.clone());
43 |         }
44 | 
45 |         let mut new_data = data.clone();
46 |         for (feature, values) in features_values.iter() {
47 |             let column = new_data.column_to_vector(&feature.name);
48 |             new_data = new_data.drop_column(&feature.name);
49 |             let mut rows = vec![vec![0.0 as Scalar; values.len()]; column.len()];
50 |             let mut names = vec![];
51 | 
52 |             for (i, value) in values.iter().enumerate() {
53 |                 for (row, v) in column.iter().enumerate() {
54 |                     if *v as i64 == *value {
55 |                         rows[row][i] = 1.0;
56 |                     } else {
57 |                         rows[row][i] = 0.0;
58 |                     }
59 |                 }
60 |                 names.push(format!("{}={}", feature.name, value));
61 |             }
62 | 
63 |             let onehotdata = DataTable::from_vectors(names.as_slice(), &rows);
64 |             new_data = new_data.append_table_as_columns(&onehotdata);
65 |         }
66 | 
67 |         (new_config, new_data)
68 |     }
69 | 
70 |     fn reverse_columnswise(&mut self, data: &DataTable) -> DataTable {
71 |         let mut classes: HashMap<String, Vec<String>> = HashMap::new();
72 |         let mut one_hot_encoded: HashMap<String, Vec<(usize, Scalar)>> = HashMap::new();
73 | 
74 |         // extract any column name like xxxx=yyyy which is a one hot encoded value for yyyy of column xxxx
75 |         data.as_scalar_hashmap().iter().for_each(|(name, values)| {
76 |             let parts: Vec<&str> = name.split("=").collect();
77 |             if parts.len() == 2 {
78 |                 let column_name = parts[0].to_string();
79 |                 let class = parts[1].to_string();
80 | 
81 |                 if !one_hot_encoded.contains_key(&column_name) {
82 |                     classes.insert(column_name.clone(), vec![class.clone()]);
83 |                     let class_idx = 0;
84 | 
85 |                     let mut idmaxes = Vec::new();
86 |                     for i in 0..values.len() {
87 |                         idmaxes.push((class_idx, values[i]));
88 |                     }
89 |                     one_hot_encoded.insert(column_name, idmaxes);
90 |                 } else {
91 |                     let column_classes = classes.get_mut(&column_name).unwrap();
92 |                     let class_idx = column_classes.len();
93 |                     column_classes.push(class.clone());
94 | 
95 |                     let idmaxes = one_hot_encoded.get_mut(&column_name).unwrap();
96 |                     for i in 0..values.len() {
97 |                         idmaxes[i].1 = idmaxes[i].1.max(values[i]);
98 |                         if idmaxes[i].1 == values[i] {
99 |                             idmaxes[i].0 = class_idx;
100 |                         }
101 |                     }
102 |                 }
103 |             }
104 |         });
105 | 
106 |         let mut new_data = data.clone();
107 | 
108 |         // remove old column=class columns for all classes in classes for that column
109 |         for (column_name, classes) in classes.iter() {
110 |             for class in classes.iter() {
111 |                 new_data = new_data.drop_column(&format!("{}={}", column_name, class));
112 |             }
113 |         }
114 | 
115 |         // add new column columns for all columns and put the most likely class in there
116 |         for (column_name, idmaxes) in one_hot_encoded.iter() {
117 |             let mut column_class_values = Vec::new();
118 |             let mut column_confidence_value = Vec::new();
119 |             for (class_idx, confidence) in idmaxes {
120 |                 let class_name = classes.get(column_name).unwrap()[*class_idx].clone();
121 |                 column_class_values.push(class_name);
122 |                 column_confidence_value.push(*confidence);
123 |             }
124 |             new_data = new_data.with_column_string(&column_name, column_class_values.as_slice());
125 |             new_data = new_data.with_column_scalar(&format!("{}.confidence", column_name), column_confidence_value.as_slice());
126 |         }
127 | 
128 |         new_data
129 |     }
130 | 
131 |     fn get_name(&self) -> String {
132 |         "onehotencode".to_string()
133 |     }
134 | }
135 | 
--------------------------------------------------------------------------------
/src/preprocessing/sample.rs:
--------------------------------------------------------------------------------
1 | use crate::{dataset::Dataset, datatable::DataTable};
2 | 
3 | use super::{DataTransformation, CachedConfig};
4 | 
5 | pub struct Sample {
6 |     pub count: usize,
7 |     pub shuffle: bool,
8 | }
9 | 
10 | impl Sample {
11 |     pub fn new(count: usize, shuffle: bool) -> Self {
12 |         Self { count, shuffle }
13 |     }
14 | }
15 | 
16 | impl DataTransformation for Sample {
17 |     fn transform(
18 |         &mut self,
19 |         _cached_config: &CachedConfig,
20 |         dataset_config: &Dataset,
21 |         data: &DataTable,
22 |     ) -> (Dataset, DataTable) {
23 |         let data = data.sample(Some(self.count), self.shuffle);
24 |         (dataset_config.clone(), data)
25 |     }
26 | 
27 |     fn reverse_columnswise(&mut self, data: &DataTable) -> DataTable {
28 |         data.clone()
29 |     }
30 | 
31 |     fn get_name(&self) -> String {
32 |         let seed = if self.shuffle {
33 |             rand::random::<usize>()
34 |         } else {
35 |             0
36 |         };
37 | 
38 |         format!("sample({},{})", self.count, seed)
39 |     }
40 | }
41 | 
--------------------------------------------------------------------------------
/src/preprocessing/square.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashSet;
2 | 
3 | use crate::{
4 |     dataset::{Dataset, Feature},
5 |     datatable::DataTable,
6 | };
7 | 
8 | use super::{feature_cached::FeatureExtractorCached, DataTransformation, CachedConfig};
9 | 
10 | pub struct Square {
11 |     squared_features: HashSet<String>,
12 | }
13 | 
14 | impl Square {
15 |     pub fn new() -> Self {
16 |         Self {
17 |             squared_features: HashSet::new(),
18 |         }
19 |     }
20 | }
21 | 
22 | impl DataTransformation for Square {
23 |     fn transform(
24 |         &mut self,
25 |         cached_config: &CachedConfig,
26 |         dataset_config: &Dataset,
27 |         data: &DataTable,
28 |     ) -> (Dataset, DataTable) {
29 |         let mut squared_features = HashSet::new();
30 | 
31 |         for feature in dataset_config.features.iter() {
32 |             if feature.squared {
33 |                 squared_features.insert(feature.name.clone());
34 |             }
35 |         }
36 | 
37 |         self.squared_features = squared_features.clone();
38 | 
39 |         let mut extractor = FeatureExtractorCached::new(
40 |             Box::new(move |feature: &Feature| match &feature.with_squared {
41 |                 Some(new_feature) => Some(*new_feature.clone()),
42 |                 _ => match &feature.squared {
43 |                     true => {
44 |                         let mut feature = feature.clone();
45 |                         feature.squared = false;
46 |                         Some(feature)
47 |                     }
48 |                     _ => None,
49 |                 },
50 |             }),
51 |             Box::new(
52 |                 move |data: &DataTable, extracted: &Feature, feature: &Feature| {
53 |                     data.map_scalar_column(&feature.name, |x| x.powi(2))
54 |                         .rename_column(&feature.name, &extracted.name)
55 |                 },
56 |             ),
57 |         );
58 | 
59 |         extractor.transform(cached_config, dataset_config, data)
60 |     }
61 | 
62 |     fn reverse_columnswise(&mut self, data: &DataTable) -> DataTable {
63 |         let mut reversed_data = data.clone();
64 | 
65 |         for feature in self.squared_features.iter() {
66 |             if reversed_data.has_column(feature) {
67 |                 reversed_data = reversed_data.map_scalar_column(feature, |x| x.sqrt());
68 |             }
69 |         }
70 | 
71 |         reversed_data
72 |     }
73 | 
74 |     fn get_name(&self) -> String {
75 |         "square".to_string()
76 |     }
77 | }
78 | 
--------------------------------------------------------------------------------
/src/trainers/mod.rs:
--------------------------------------------------------------------------------
1 | #[cfg(feature = "data")]
2 | pub mod kfolds;
3 | 
4 | pub mod split;
--------------------------------------------------------------------------------
/src/vision/conv_activation/linear.rs:
--------------------------------------------------------------------------------
1 | use super::ConvActivationLayer;
2 | use crate::vision::{image::Image, image::ImageTrait};
3 | 
4 | pub fn new() -> ConvActivationLayer {
5 |     ConvActivationLayer::new(
6 |         |m| m.clone(),
7 |         |m| {
8 |             Image::constant(
9 |                 m.image_dims().0,
10 |                 m.image_dims().1,
11 |                 m.channels(),
12 |                 m.samples(),
13 |                 1.,
14 |             )
15 |         },
16 |     )
17 | }
18 | 
--------------------------------------------------------------------------------
/src/vision/conv_activation/mod.rs:
--------------------------------------------------------------------------------
1 | use std::fmt;
2 | 
3 | use serde::{Deserialize, Serialize};
4 | 
5 | use super::{image::Image, image::ImageTrait};
6 | 
7 | pub mod linear;
8 | pub mod relu;
9 | pub mod sigmoid;
10 | pub mod tanh;
11 | 
12 | pub type ConvActivationFn = fn(&Image) -> Image;
13 | 
14 | pub struct ConvActivationLayer {
15 |     input: Option<Image>,
16 |     activation: ConvActivationFn,
17 |     derivative: ConvActivationFn,
18 | }
19 | 
20 | impl ConvActivationLayer {
21 |     pub fn new(activation: ConvActivationFn, derivative: ConvActivationFn) -> Self {
22 |         Self {
23 |             input: None,
24 |             activation,
25 |             derivative,
26 |         }
27 |     }
28 | 
29 |     pub fn forward(&mut self, input: Image) -> Image {
30 |         self.input = Some(input.clone());
31 |         (self.activation)(&input)
32 |     }
33 | 
34 |     pub fn backward(&mut self, _epoch: usize, output_gradient: Image) -> Image {
35 |         let input = self.input.clone().unwrap();
36 |         let fprime_x = (self.derivative)(&input);
37 |         output_gradient.component_mul(&fprime_x)
38 |     }
39 | }
40 | 
41 | #[derive(Debug, Clone, Copy, Serialize, Deserialize)]
42 | pub enum ConvActivation {
43 |     ConvTanh,
44 |     ConvSigmoid,
45 |     ConvReLU,
46 |     ConvLinear,
47 | }
48 | 
49 | impl ConvActivation {
50 |     pub fn to_layer(&self) -> ConvActivationLayer {
51 |         match self {
52 |             Self::ConvLinear => linear::new(),
53 |             Self::ConvTanh => tanh::new(),
54 |             Self::ConvSigmoid => sigmoid::new(),
55 |             Self::ConvReLU => relu::new(),
56 |         }
57 |     }
58 | }
59 | 
60 | impl fmt::Debug for ConvActivationLayer {
61 |     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 |         write!(f, "Convolutional Activation Layer")
63 |     }
64 | }
65 | 
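A sketch of round-tripping an image through an activation layer built from the enum above (the `Image::random_uniform` constructor mirrors the one used by the initializers later in this crate):

use crate::vision::conv_activation::ConvActivation;
use crate::vision::image::{Image, ImageTrait};

let mut relu = ConvActivation::ConvReLU.to_layer();
let input = Image::random_uniform(8, 8, 3, 1, -1.0, 1.0);
let output = relu.forward(input);            // negatives clamped to 0
let grad = relu.backward(0, output.clone()); // gated by the stored input's sign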
--------------------------------------------------------------------------------
/src/vision/conv_activation/relu.rs:
--------------------------------------------------------------------------------
1 | use super::ConvActivationLayer;
2 | use crate::vision::{image::Image, image::ImageTrait};
3 | 
4 | pub fn new() -> ConvActivationLayer {
5 |     ConvActivationLayer::new(
6 |         |m| {
7 |             m.maxof(&Image::constant(
8 |                 m.image_dims().0,
9 |                 m.image_dims().1,
10 |                 m.channels(),
11 |                 m.samples(),
12 |                 0.,
13 |             ))
14 |         },
15 |         |m| {
16 |             m.sign().maxof(&Image::constant(
17 |                 m.image_dims().0,
18 |                 m.image_dims().1,
19 |                 m.channels(),
20 |                 m.samples(),
21 |                 0.,
22 |             ))
23 |         },
24 |     )
25 | }
26 | 
--------------------------------------------------------------------------------
/src/vision/conv_activation/sigmoid.rs:
--------------------------------------------------------------------------------
1 | use super::ConvActivationLayer;
2 | use crate::vision::{image::Image, image::ImageTrait};
3 | 
4 | fn sigmoid(m: &Image) -> Image {
5 |     let exp_neg = m.scalar_mul(-1.).exp();
6 |     let ones = Image::constant(
7 |         m.image_dims().0,
8 |         m.image_dims().1,
9 |         m.channels(),
10 |         m.samples(),
11 |         1.,
12 |     );
13 |     ones.component_div(&(ones.component_add(&exp_neg)))
14 | }
15 | 
16 | fn sigmoid_prime(m: &Image) -> Image {
17 |     let sig = sigmoid(m);
18 |     let ones = Image::constant(
19 |         m.image_dims().0,
20 |         m.image_dims().1,
21 |         m.channels(),
22 |         m.samples(),
23 |         1.,
24 |     );
25 |     sig.component_mul(&(ones.component_sub(&sig)))
26 | }
27 | 
28 | pub fn new() -> ConvActivationLayer {
29 |     ConvActivationLayer::new(sigmoid, sigmoid_prime)
30 | }
31 | 
--------------------------------------------------------------------------------
/src/vision/conv_activation/tanh.rs:
--------------------------------------------------------------------------------
1 | use super::ConvActivationLayer;
2 | use crate::vision::{image::Image, image::ImageTrait};
3 | 
4 | fn tanh(m: &Image) -> Image {
5 |     let exp = m.exp();
6 |     let exp_neg = m.scalar_mul(-1.).exp();
7 |     (exp.component_sub(&exp_neg)).component_div(&(exp.component_add(&exp_neg)))
8 | }
9 | 
10 | fn tanh_prime(m: &Image) -> Image {
11 |     let hbt = tanh(m);
12 |     let hbt2 = &hbt.square();
13 |     let ones = Image::constant(
14 |         hbt.image_dims().0,
15 |         hbt.image_dims().1,
16 |         hbt.channels(),
17 |         hbt.samples(),
18 |         1.,
19 |     );
20 |     ones.component_sub(&hbt2)
21 | }
22 | 
23 | pub fn new() -> ConvActivationLayer {
24 |     ConvActivationLayer::new(tanh, tanh_prime)
25 | }
26 | 
--------------------------------------------------------------------------------
/src/vision/conv_initializers.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | use super::image::Image;
4 | use crate::{linalg::Scalar, vision::image::ImageTrait};
5 | 
6 | #[derive(Serialize, Debug, Deserialize, Clone)]
7 | pub enum ConvInitializers {
8 |     Zeros,
9 |     Uniform,
10 |     UniformSigned,
11 |     GlorotUniform,
12 | }
13 | 
14 | impl ConvInitializers {
15 |     pub fn gen_image(&self, nrow: usize, ncol: usize, nchan: usize, nsample: usize) -> Image {
16 |         match self {
17 |             ConvInitializers::Zeros => Image::zeros(nrow, ncol, nchan, nsample),
18 |             ConvInitializers::Uniform => {
19 |                 Image::random_uniform(nrow, ncol, nchan, nsample, 0.0, 1.0)
20 |             }
21 |             ConvInitializers::UniformSigned => {
22 |                 Image::random_uniform(nrow, ncol, nchan, nsample, -1.0, 1.0)
23 |             }
24 |             ConvInitializers::GlorotUniform => {
25 |                 let limit = (6. / (ncol * nrow + ncol * nrow) as Scalar).sqrt(); // Glorot-style sqrt(6 / (fan_in + fan_out)), with the kernel area standing in for both terms
26 |                 Image::random_uniform(nrow, ncol, nchan, nsample, -limit, limit)
27 |             }
28 |         }
29 |     }
30 | }
31 | 
--------------------------------------------------------------------------------
/src/vision/conv_layer/avg_pooling_layer.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     layer::{LearnableLayer, ParameterableLayer},
3 |     linalg::Scalar,
4 |     vision::{
5 |         image::Image,
6 |         image::ImageTrait, conv_network::ConvNetworkLayer,
7 |     },
8 | };
9 | 
10 | use crate::vision::image_layer::ImageLayer;
11 | 
12 | #[derive(Debug)]
13 | pub struct AvgPoolingLayer {
14 |     pub div: usize,
15 | }
16 | 
17 | impl AvgPoolingLayer {
18 |     pub fn new(
19 |         div: usize,
20 |     ) -> Self {
21 |         Self {
22 |             div,
23 |         }
24 |     }
25 | }
26 | 
27 | impl ImageLayer for AvgPoolingLayer {
28 |     fn forward(&mut self, input: Image) -> Image {
29 |         let unwrapped = input.unwrap(self.div, self.div, self.div, self.div, 0, 0);
30 |         let meaned = unwrapped.mean_along(0);
31 |         let result = meaned.wrap(input.image_dims().0/self.div, input.image_dims().1/self.div, 1, 1, 1, 1, 0, 0);
32 |         result
33 |     }
34 | 
35 |     fn backward(&mut self, _epoch: usize, output_gradient: Image) -> Image {
36 |         let input_grad = output_gradient
37 |             .scalar_div((self.div * self.div) as Scalar)
38 |             .unwrap(1, 1, 1, 1, 0, 0)
39 |             .tile(self.div * self.div, 1, 1, 1)
40 |             .wrap(
41 |                 output_gradient.image_dims().0 * self.div,
42 |                 output_gradient.image_dims().1 * self.div,
43 |                 self.div,
44 |                 self.div,
45 |                 self.div,
46 |                 self.div,
47 |                 0,
48 |                 0,
49 |             );
50 | 
51 |         input_grad
52 |     }
53 | }
54 | 
55 | impl ParameterableLayer for AvgPoolingLayer {
56 |     fn as_learnable_layer(&self) -> Option<&dyn LearnableLayer> {
57 |         None
58 |     }
59 | 
60 |     fn as_learnable_layer_mut(&mut self) -> Option<&mut dyn LearnableLayer> {
61 |         None
62 |     }
63 | 
64 |     fn as_dropout_layer(&mut self) -> Option<&mut dyn crate::layer::DropoutLayer> {
65 |         None
66 |     }
67 | }
68 | 
69 | impl ConvNetworkLayer for AvgPoolingLayer {
70 | }
--------------------------------------------------------------------------------
/src/vision/conv_layer/defaults.rs:
--------------------------------------------------------------------------------
1 | use crate::vision::{
2 |     conv_initializers::ConvInitializers,
3 |     conv_optimizer::{conv_sgd, ConvOptimizers},
4 | };
5 | 
6 | pub fn default_biases_initializer() -> ConvInitializers {
7 |     ConvInitializers::Zeros
8 | }
9 | 
10 | pub fn default_kernels_initializer() -> ConvInitializers {
11 |     ConvInitializers::GlorotUniform
12 | }
13 | 
14 | pub fn default_biases_optimizer() -> ConvOptimizers {
15 |     conv_sgd()
16 | }
17 | 
18 | pub fn default_kernels_optimizer() -> ConvOptimizers {
19 |     conv_sgd()
20 | }
21 | 
--------------------------------------------------------------------------------
/src/vision/conv_layer/dense_conv_layer.rs:
--------------------------------------------------------------------------------
1 | use crate::{
2 |     layer::LearnableLayer,
3 |     linalg::{Matrix, MatrixTrait, Scalar},
4 |     vision::{
5 |         conv_initializers::ConvInitializers, conv_optimizer::ConvOptimizers, image::Image,
6 |         image::ImageTrait,
7 |     },
8 | };
9 | 
10 | use crate::vision::image_layer::ImageLayer;
11 | 
12 | use super::ConvLayer;
13 | 
14 | #[derive(Debug)]
15 | pub struct DenseConvLayer {
16 |     pub kernels: Image,
17 |     biases: Image,
18 |     input: Option<Image>,
19 |     kernels_optimizer: ConvOptimizers,
20 |     biases_optimizer: ConvOptimizers,
21 | }
22 | 
23 | impl DenseConvLayer {
24 |     pub fn new(
25 |         nrow: usize,
26 |         ncol: usize,
27 |         nchan: usize,
28 |         nkern: usize,
29 |         kernels_initializer: ConvInitializers,
30 |         biases_initializer: ConvInitializers,
31 |         kernels_optimizer: ConvOptimizers,
32 |         biases_optimizer: ConvOptimizers,
33 |     ) -> Self {
34 |         Self {
35 |             kernels: kernels_initializer.gen_image(nrow, ncol, nchan, nkern),
36 |             biases: biases_initializer.gen_image(1, 1, nkern, 1),
37 |             input: None,
38 |             kernels_optimizer,
39 |             biases_optimizer,
40 |         }
41 |     }
42 | 
43 |     pub fn out_img_dims_and_channels(
44 |         in_rows: usize,
45 |         in_cols: usize,
46 |         krows: usize,
47 |         kcols: usize,
48 |         kchans: usize,
49 |     ) -> (usize, usize, usize) {
50 |         let out_rows = in_rows - krows + 1;
51 |         let out_cols = in_cols - kcols + 1;
52 |         let out_chans = kchans;
53 |         (out_rows, out_cols, out_chans)
54 |     }
55 | }
56 | 
57 | impl ImageLayer for DenseConvLayer {
58 |     fn forward(&mut self, input: Image) -> Image {
59 |         let res = input
60 |             .cross_correlate(&self.kernels);
61 | 
62 |         if self.biases.image_dims() != res.image_dims() {
63 |             self.biases = self.biases.tile(res.image_dims().0, res.image_dims().1, 1, 1);
64 |         }
65 | 
66 |         let res = res
67 |             .component_add(&self.biases);
68 | 
69 |         self.input = Some(input);
70 |         res
71 |     }
72 | 
73 |     fn backward(&mut self, epoch: usize, output_gradient: Image) -> Image {
74 |         let input = self.input.as_ref().unwrap();
75 | 
76 |         let mut input_grad_channels = vec![];
77 |         for i in 0..input.channels() {
78 |             let mut sum = Image::zeros(input.image_dims().0, input.image_dims().1, 1, input.samples());
79 |             for k in 0..output_gradient.channels() {
80 |                 let kernel = self.kernels.get_sample(k).get_channel(i);
81 |                 let k_output_grad = output_gradient.get_channel_across_samples(k);
82 |                 let correlated = k_output_grad.convolve_full(&kernel);
83 |                 sum = sum.component_add(&correlated);
84 |             }
85 |             input_grad_channels.push(sum);
86 |         }
87 |         let input_grad = Image::join_channels(input_grad_channels);
88 | 
89 |         let mut kern_grad_samples = vec![];
90 |         for k in 0..self.kernels.samples() {
91 |             let mut kern_grad_channels = vec![];
92 |             let output_grad_k = output_gradient.get_channel_across_samples(k).sum_samples();
93 |             for i in 0..self.kernels.channels() {
94 |                 let input_i = input.get_channel_across_samples(i);
95 |                 let correlated = input_i.cross_correlate(&output_grad_k);
96 |                 kern_grad_channels.push(correlated.sum_samples());
97 |             }
98 |             let kern_grad_sample = Image::join_channels(kern_grad_channels);
99 |             kern_grad_samples.push(kern_grad_sample);
100 |         }
101 |         let kern_grad = Image::join_samples(kern_grad_samples);
102 | 
103 |         let mut biases_grad_channels = vec![];
104 |         for c in 0..self.biases.channels() {
105 |             let channel = output_gradient.get_channel_across_samples(c);
106 |             let channel = channel.sum_samples();
107 |             biases_grad_channels.push(channel);
108 |         }
109 |         let biases_grad = Image::join_channels(biases_grad_channels);
110 | 
111 |         self.kernels = self
112 |             .kernels_optimizer
113 |             .update_parameters(epoch, &self.kernels, &kern_grad);
114 |         self.biases = self
115 |             .biases_optimizer
116 |             .update_parameters(epoch, &self.biases, &biases_grad);
117 |         input_grad
118 |     }
119 | }
120 | 
121 | impl LearnableLayer for DenseConvLayer {
122 |     fn get_learnable_parameters(&self) -> Vec<Vec<Scalar>> {
123 |         let mut params = self.kernels.flatten().get_data_col_leading();
124 |         params.push(self.biases.flatten().get_column(0));
125 |         params
126 |     }
127 | 
128 |     fn set_learnable_parameters(&mut self, params_matrix: &Vec<Vec<Scalar>>) {
129 |         let mut kernels = params_matrix.clone();

--------------------------------------------------------------------------------
/src/vision/conv_layer/direct_conv_layer.rs:
--------------------------------------------------------------------------------
use crate::{
    layer::LearnableLayer,
    linalg::{Matrix, MatrixTrait, Scalar},
    vision::{
        conv_initializers::ConvInitializers, conv_optimizer::ConvOptimizers, image::Image,
        image::ImageTrait,
    },
};

use crate::vision::image_layer::ImageLayer;

use super::ConvLayer;

#[derive(Debug)]
pub struct DirectConvLayer {
    pub kernels: Image,
    biases: Image,
    input: Option<Image>,
    kernels_optimizer: ConvOptimizers,
    biases_optimizer: ConvOptimizers,
}

impl DirectConvLayer {
    pub fn new(
        krows: usize,
        kcols: usize,
        in_chans: usize,
        kernels_initializer: ConvInitializers,
        biases_initializer: ConvInitializers,
        kernels_optimizer: ConvOptimizers,
        biases_optimizer: ConvOptimizers,
    ) -> Self {
        Self {
            kernels: kernels_initializer.gen_image(krows, kcols, in_chans, 1),
            biases: biases_initializer.gen_image(1, 1, in_chans, 1),
            input: None,
            kernels_optimizer,
            biases_optimizer,
        }
    }

    pub fn out_img_dims_and_channels(
        in_rows: usize,
        in_cols: usize,
        in_chans: usize,
        krows: usize,
        kcols: usize,
    ) -> (usize, usize, usize) {
        // "valid" cross-correlation applied per channel: the channel count is preserved
        let out_rows = in_rows - krows + 1;
        let out_cols = in_cols - kcols + 1;
        (out_rows, out_cols, in_chans)
    }
}

impl ImageLayer for DirectConvLayer {
    fn forward(&mut self, input: Image) -> Image {
        let mut channels = vec![];
        for c in 0..input.channels() {
            let channel = input.get_channel_across_samples(c);

            let kernel = self.kernels.get_channel(c);

            let correlated = channel.cross_correlate(&kernel);

            // Lazily broadcast the per-channel biases to the output's spatial dims.
            if self.biases.image_dims().0 != correlated.image_dims().0 {
                self.biases =
                    self.biases
                        .tile(correlated.image_dims().0, correlated.image_dims().1, 1, 1);
            }

            let bias = self.biases.get_channel(c);

            let result_channel = correlated.component_add(&bias);
            channels.push(result_channel);
        }
        let res = Image::join_channels(channels);
        self.input = Some(input);
        res
    }

    fn backward(&mut self, epoch: usize, output_gradient: Image) -> Image {
        let input = self.input.as_ref().unwrap();

        let mut input_grad_channels = vec![];
        for i in 0..input.channels() {
            let kernel = self.kernels.get_channel(i);
            let output_grad_i = output_gradient.get_channel_across_samples(i);
            let convolved = output_grad_i.convolve_full(&kernel);
            input_grad_channels.push(convolved);
        }
        let input_grad = Image::join_channels(input_grad_channels);

        let mut kern_grad_channels = vec![];
        for i in 0..input.channels() {
            let output_grad_i = output_gradient.get_channel_across_samples(i).sum_samples();
            let input_i = input.get_channel_across_samples(i);
            let correlated = input_i.cross_correlate(&output_grad_i).sum_samples();
            kern_grad_channels.push(correlated);
        }
        let kern_grad = Image::join_channels(kern_grad_channels);

        let mut biases_grad_channels = vec![];
        for c in 0..self.biases.channels() {
            let channel = output_gradient.get_channel_across_samples(c);
            let channel = channel.sum_samples();
            biases_grad_channels.push(channel);
        }
        let biases_grad = Image::join_channels(biases_grad_channels);

        self.kernels = self
            .kernels_optimizer
            .update_parameters(epoch, &self.kernels, &kern_grad);
        self.biases = self
            .biases_optimizer
            .update_parameters(epoch, &self.biases, &biases_grad);
        input_grad
    }
}

impl LearnableLayer for DirectConvLayer {
    fn get_learnable_parameters(&self) -> Vec<Vec<Scalar>> {
        let mut params = self.kernels.flatten().get_data_col_leading();
        params.push(self.biases.flatten().get_column(0));
        params
    }

    fn set_learnable_parameters(&mut self, params_matrix: &Vec<Vec<Scalar>>) {
        let mut kernels = params_matrix.clone();
        // the biases were pushed last in get_learnable_parameters
        let biases = kernels.pop().unwrap();
        self.kernels = Image::from_samples(
            &Matrix::from_column_leading_vector2(&kernels),
            self.kernels.channels(),
        );
        self.biases =
            Image::from_samples(&Matrix::from_column_vector(&biases), self.biases.channels());
    }
}

impl ConvLayer for DirectConvLayer {
    fn scale_kernels(&mut self, scale: Scalar) {
        self.kernels = self.kernels.scalar_mul(scale);
    }
}
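
// Editor's note (not part of the original file): unlike DenseConvLayer, which mixes
// all input channels into each of its `nkern` output maps, DirectConvLayer is
// depthwise: it holds one kernel and one bias per input channel and filters each
// channel independently, so the channel count is preserved:
//
//     DirectConvLayer::out_img_dims_and_channels(28, 28, 3, 3, 3) == (26, 26, 3)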

--------------------------------------------------------------------------------
/src/vision/conv_layer/full_conv_layer.rs:
--------------------------------------------------------------------------------
use std::cmp::Ordering;

use rand::Rng;

use crate::layer::{DropoutLayer, LearnableLayer, ParameterableLayer};
use crate::linalg::Scalar;
use crate::vision::conv_network::ConvNetworkLayer;

use super::{ConvLayer, Image};
use crate::vision::image::ImageTrait;
use crate::vision::conv_activation::ConvActivationLayer;
use crate::vision::image_layer::ImageLayer;

#[derive(Debug)]
pub struct FullConvLayer {
    conv: Box<dyn ConvLayer>,
    activation: ConvActivationLayer,
    dropout_enabled: bool,
    dropout_rate: Option<Scalar>,
    mask: Option<Image>,
}

impl FullConvLayer {
    pub fn new(
        conv: Box<dyn ConvLayer>,
        activation: ConvActivationLayer,
        dropout: Option<Scalar>,
    ) -> Self {
        Self {
            conv,
            activation,
            dropout_rate: dropout,
            dropout_enabled: false,
            mask: None,
        }
    }

    fn generate_dropout_mask(
        &mut self,
        kern_size: (usize, usize, usize),
        nkern: usize,
    ) -> Option<(Image, Scalar)> {
        if let Some(dropout_rate) = self.dropout_rate {
            let mut rng = rand::thread_rng();
            // Bernoulli mask: a component is kept with probability 1 - dropout_rate
            let dropout_mask = Image::from_fn(
                kern_size.0,
                kern_size.1,
                kern_size.2,
                nkern,
                |_, _, _, _| {
                    if rng
                        .gen_range((0.0 as Scalar)..(1.0 as Scalar))
                        .total_cmp(&dropout_rate)
                        == Ordering::Greater
                    {
                        1.0
                    } else {
                        0.0
                    }
                },
            );
            Some((dropout_mask, dropout_rate))
        } else {
            None
        }
    }
}

impl ImageLayer for FullConvLayer {
    fn forward(&mut self, mut input: Image) -> Image {
        let output = if self.dropout_enabled {
            // training: zero out components of the input and remember the mask
            if let Some((mask, _)) =
                self.generate_dropout_mask(input.image_dims(), input.samples())
            {
                input = input.component_mul(&mask);
                self.mask = Some(mask);
            };
            self.conv.forward(input)
        } else if let Some(dropout_rate) = self.dropout_rate {
            // inference: scale the kernels down instead of masking, then restore them
            self.conv.scale_kernels(1.0 - dropout_rate);
            let output = self.conv.forward(input);
            self.conv.scale_kernels(1.0 / (1.0 - dropout_rate));
            output
        } else {
            self.conv.forward(input)
        };

        self.activation.forward(output)
    }

    fn backward(&mut self, epoch: usize, output_gradient: Image) -> Image {
        let activation_input_gradient = self.activation.backward(epoch, output_gradient);
        let input_gradient = self.conv.backward(epoch, activation_input_gradient);

        // gradients of dropped components are zeroed by the same mask
        if let Some(mask) = &self.mask {
            input_gradient.component_mul(mask)
        } else {
            input_gradient
        }
    }
}

impl LearnableLayer for FullConvLayer {
    fn get_learnable_parameters(&self) -> Vec<Vec<Scalar>> {
        self.conv.get_learnable_parameters()
    }

    fn set_learnable_parameters(&mut self, params_matrix: &Vec<Vec<Scalar>>) {
        self.conv.set_learnable_parameters(params_matrix)
    }
}

impl DropoutLayer for FullConvLayer {
    fn enable_dropout(&mut self) {
        self.dropout_enabled = true;
    }

    fn disable_dropout(&mut self) {
        self.dropout_enabled = false;
    }
}

impl ParameterableLayer for FullConvLayer {
    fn as_learnable_layer(&self) -> Option<&dyn crate::layer::LearnableLayer> {
        Some(self)
    }

    fn as_learnable_layer_mut(&mut self) -> Option<&mut dyn crate::layer::LearnableLayer> {
        Some(self)
    }

    fn as_dropout_layer(&mut self) -> Option<&mut dyn crate::layer::DropoutLayer> {
        Some(self)
    }
}

impl ConvNetworkLayer for FullConvLayer {}
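
// Editor's note (not part of the original file): sketch of the dropout scheme
// above, with `conv`, `activation` and `batch` standing in for values built
// elsewhere. During training (dropout enabled) inputs are zeroed by a Bernoulli
// mask; at inference the kernels are temporarily scaled by (1 - rate) so expected
// activations match training, then unscaled. Assumes an implemented Image backend.
//
//     let mut layer = FullConvLayer::new(conv, activation, Some(0.2));
//     layer.enable_dropout();      // training: mask applied in forward()
//     let _train_out = layer.forward(batch.clone());
//     layer.disable_dropout();     // inference: kernels scaled by 0.8 instead
//     let _eval_out = layer.forward(batch);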

--------------------------------------------------------------------------------
/src/vision/conv_layer/mod.rs:
--------------------------------------------------------------------------------
use std::fmt::Debug;

use crate::{layer::LearnableLayer, linalg::Scalar};

use super::{image::Image, image_layer::ImageLayer};

pub mod defaults;
pub mod dense_conv_layer;
pub mod direct_conv_layer;
pub mod avg_pooling_layer;
pub mod full_conv_layer;

pub trait ConvLayer: ImageLayer + LearnableLayer + Send + Debug {
    fn scale_kernels(&mut self, scale: Scalar);
}

--------------------------------------------------------------------------------
/src/vision/conv_network.rs:
--------------------------------------------------------------------------------
use std::fmt::Debug;

use crate::{
    layer::{DropoutLayer, Layer, LearnableLayer, ParameterableLayer},
    linalg::{Matrix, Scalar},
    monitor::TM,
    network::NetworkLayer,
    vision::{image::Image, image::ImageTrait},
};

use super::image_layer::ImageLayer;

#[derive(Debug)]
pub struct ConvNetwork {
    layers: Vec<Box<dyn ConvNetworkLayer>>,
    channels: usize,
    out_channels: Option<usize>,
}

impl ConvNetwork {
    pub fn new(layers: Vec<Box<dyn ConvNetworkLayer>>, channels: usize) -> Self {
        Self {
            layers,
            channels,
            out_channels: None,
        }
    }
}

impl Layer for ConvNetwork {
    fn forward(&mut self, input: Matrix) -> Matrix {
        TM::start("cnet.forw");
        let mut output = Image::from_samples(&input, self.channels);
        let n_layers = self.layers.len();

        for (i, layer) in self.layers.iter_mut().enumerate() {
            TM::start(format!("layer[{}/{}]", i + 1, n_layers));
            output = layer.forward(output);
            TM::end();
        }

        self.out_channels = Some(output.channels());

        let out = output.flatten();
        TM::end();
        out
    }

    fn backward(&mut self, epoch: usize, error_gradient: Matrix) -> Matrix {
        TM::start("cnet.back");
        let mut error_gradient = Image::from_samples(&error_gradient, self.out_channels.unwrap());

        for (i, layer) in self.layers.iter_mut().enumerate().rev() {
            TM::start(format!("layer[{}]", i + 1));
            error_gradient = layer.backward(epoch, error_gradient);
            TM::end();
        }

        let grad = error_gradient.flatten();
        TM::end();
        grad
    }
}

impl NetworkLayer for ConvNetwork {}

impl ParameterableLayer for ConvNetwork {
    fn as_learnable_layer(&self) -> Option<&dyn LearnableLayer> {
        Some(self)
    }

    fn as_learnable_layer_mut(&mut self) -> Option<&mut dyn LearnableLayer> {
        Some(self)
    }

    fn as_dropout_layer(&mut self) -> Option<&mut dyn DropoutLayer> {
        Some(self)
    }
}

impl LearnableLayer for ConvNetwork {
    fn get_learnable_parameters(&self) -> Vec<Vec<Scalar>> {
        let mut params = Vec::new();
        for layer in self.layers.iter() {
            if let Some(l) = layer.as_learnable_layer() {
                params.append(&mut l.get_learnable_parameters());
            }
            // separator row marking the end of this layer's params
            params.push(vec![-1.0]);
        }
        params
    }

    fn set_learnable_parameters(&mut self, params_matrix: &Vec<Vec<Scalar>>) {
        // each layer's params group ends at a -1.0 separator row (see
        // get_learnable_parameters); split on those rows and hand each group
        // back to its layer, in order
        let mut groups = params_matrix.split(|row| row.first() == Some(&-1.0));
        for layer in self.layers.iter_mut() {
            let layer_params: Vec<Vec<Scalar>> =
                groups.next().map(|group| group.to_vec()).unwrap_or_default();
            if let Some(l) = layer.as_learnable_layer_mut() {
                l.set_learnable_parameters(&layer_params);
            }
        }
    }
}

impl DropoutLayer for ConvNetwork {
    fn enable_dropout(&mut self) {
        self.layers.iter_mut().for_each(|l| {
            if let Some(l) = l.as_dropout_layer() {
                l.enable_dropout();
            }
        });
    }

    fn disable_dropout(&mut self) {
        self.layers.iter_mut().for_each(|l| {
            if let Some(l) = l.as_dropout_layer() {
                l.disable_dropout();
            }
        });
    }
}

pub trait ConvNetworkLayer: ImageLayer + ParameterableLayer + Debug + Send {}
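
// Editor's note (not part of the original file): the parameter (de)serialisation
// round-trip above uses a sentinel row. For a two-layer network whose layers
// expose kernel rows K and a bias row B, get_learnable_parameters yields
//
//     [K0_0, K0_1, B0, [-1.0], K1_0, B1, [-1.0]]
//
// and set_learnable_parameters splits on the [-1.0] rows to hand each group back
// to its layer in the same order.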

--------------------------------------------------------------------------------
/src/vision/conv_optimizer/adam.rs:
--------------------------------------------------------------------------------
use serde::{Deserialize, Serialize};

use crate::{
    learning_rate::{default_learning_rate, LearningRateSchedule},
    linalg::Scalar,
    vision::{image::Image, image::ImageTrait},
};

fn default_beta1() -> Scalar {
    0.9
}

fn default_beta2() -> Scalar {
    0.999
}

fn default_epsilon() -> Scalar {
    1e-8
}

// https://arxiv.org/pdf/1412.6980.pdf
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ConvAdam {
    #[serde(default = "default_beta1")]
    beta1: Scalar,
    #[serde(default = "default_beta2")]
    beta2: Scalar,
    #[serde(default = "default_epsilon")]
    epsilon: Scalar,
    #[serde(default = "default_learning_rate")]
    learning_rate: LearningRateSchedule,
    #[serde(skip)]
    m: Option<Image>, // first moment vector
    #[serde(skip)]
    v: Option<Image>, // second moment vector
}

impl ConvAdam {
    pub fn new(
        learning_rate: LearningRateSchedule,
        beta1: Scalar,
        beta2: Scalar,
        epsilon: Scalar,
    ) -> Self {
        Self {
            m: None,
            v: None,
            beta1,
            beta2,
            learning_rate,
            epsilon,
        }
    }

    pub fn default() -> Self {
        Self {
            v: None,
            m: None,
            beta1: default_beta1(),
            beta2: default_beta2(),
            learning_rate: default_learning_rate(),
            epsilon: default_epsilon(),
        }
    }

    pub fn update_parameters(
        &mut self,
        epoch: usize,
        parameters: &Image,
        parameters_gradient: &Image,
    ) -> Image {
        let alpha = self.learning_rate.get_learning_rate(epoch);

        let (nrow, ncol, nchan) = parameters_gradient.image_dims();
        let n_sample = parameters_gradient.samples();

        if self.m.is_none() {
            self.m = Some(Image::zeros(nrow, ncol, nchan, n_sample));
        }
        if self.v.is_none() {
            self.v = Some(Image::zeros(nrow, ncol, nchan, n_sample));
        }
        let m = self.m.as_ref().unwrap();
        let v = self.v.as_ref().unwrap();

        let g = parameters_gradient;
        let g2 = parameters_gradient.component_mul(parameters_gradient);

        // exponential moving averages of the gradient and its square
        let m = &(m.scalar_mul(self.beta1)).component_add(&g.scalar_mul(1.0 - self.beta1));
        let v = &(v.scalar_mul(self.beta2)).component_add(&g2.scalar_mul(1.0 - self.beta2));

        // note: constant bias correction; the paper divides by 1 - beta^t at step t
        let m_bias_corrected = m.scalar_div(1.0 - self.beta1);
        let v_bias_corrected = v.scalar_div(1.0 - self.beta2);

        let v_bias_corrected = v_bias_corrected.sqrt();

        self.m = Some(m.clone());
        self.v = Some(v.clone());
        parameters.component_sub(
            &(m_bias_corrected.scalar_mul(alpha))
                .component_div(&v_bias_corrected.scalar_add(self.epsilon)),
        )
    }
}
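
// Editor's note (not part of the original file): the update above follows
// Kingma & Ba (2014), with the moment estimates kept as Images:
//
//     m <- beta1 * m + (1 - beta1) * g
//     v <- beta2 * v + (1 - beta2) * g^2
//     theta <- theta - alpha * m_hat / (sqrt(v_hat) + epsilon)
//
// except that the bias corrections divide by the constants (1 - beta1) and
// (1 - beta2) rather than the paper's (1 - beta1^t) and (1 - beta2^t).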

--------------------------------------------------------------------------------
/src/vision/conv_optimizer/mod.rs:
--------------------------------------------------------------------------------
use serde::{Deserialize, Serialize};

use self::{adam::ConvAdam, momentum::ConvMomentum, sgd::ConvSGD};

use super::image::Image;

pub mod adam;
pub mod momentum;
pub mod sgd;

#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ConvOptimizers {
    ConvSGD(ConvSGD),
    ConvMomentum(ConvMomentum),
    ConvAdam(ConvAdam),
}

impl ConvOptimizers {
    pub fn update_parameters(
        &mut self,
        epoch: usize,
        parameters: &Image,
        parameters_gradient: &Image,
    ) -> Image {
        match self {
            ConvOptimizers::ConvSGD(sgd) => {
                sgd.update_parameters(epoch, parameters, parameters_gradient)
            }
            ConvOptimizers::ConvMomentum(momentum) => {
                momentum.update_parameters(epoch, parameters, parameters_gradient)
            }
            ConvOptimizers::ConvAdam(adam) => {
                adam.update_parameters(epoch, parameters, parameters_gradient)
            }
        }
    }
}

pub fn conv_adam() -> ConvOptimizers {
    ConvOptimizers::ConvAdam(ConvAdam::default())
}

pub fn conv_sgd() -> ConvOptimizers {
    ConvOptimizers::ConvSGD(ConvSGD::default())
}

pub fn conv_momentum() -> ConvOptimizers {
    ConvOptimizers::ConvMomentum(ConvMomentum::default())
}
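
// Editor's note (not part of the original file): the helper constructors above are
// what the layer defaults consume, e.g.
//
//     let mut opt = conv_adam();
//     // new_params = params - step, dispatched to ConvAdam::update_parameters
//     let new_params = opt.update_parameters(epoch, &params, &grads);
//
// with `epoch`, `params` and `grads` standing in for values built elsewhere.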

--------------------------------------------------------------------------------
/src/vision/conv_optimizer/momentum.rs:
--------------------------------------------------------------------------------
use serde::{Deserialize, Serialize};

use crate::{
    learning_rate::{default_learning_rate, LearningRateSchedule},
    linalg::Scalar,
    vision::{image::Image, image::ImageTrait},
};

use crate::optimizer::momentum::default_momentum;

// https://arxiv.org/pdf/1207.0580.pdf
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ConvMomentum {
    #[serde(default = "default_momentum")]
    momentum: Scalar,
    #[serde(default = "default_learning_rate")]
    learning_rate: LearningRateSchedule,
    #[serde(skip)]
    v: Option<Image>,
}

impl ConvMomentum {
    pub fn new(learning_rate: LearningRateSchedule, momentum: Scalar) -> Self {
        Self {
            v: None,
            momentum,
            learning_rate,
        }
    }

    pub fn default() -> Self {
        Self {
            v: None,
            momentum: default_momentum(),
            learning_rate: default_learning_rate(),
        }
    }

    pub fn update_parameters(
        &mut self,
        epoch: usize,
        parameters: &Image,
        parameters_gradient: &Image,
    ) -> Image {
        let lr = self.learning_rate.get_learning_rate(epoch);

        if self.v.is_none() {
            let (nrow, ncol, nchan) = parameters_gradient.image_dims();
            let n_sample = parameters_gradient.samples();
            self.v = Some(Image::zeros(nrow, ncol, nchan, n_sample));
        }

        let v = self.v.as_ref().unwrap();

        // velocity: v <- momentum * v + lr * g
        let v = v
            .scalar_mul(self.momentum)
            .component_add(&parameters_gradient.scalar_mul(lr));

        let new_params = parameters.component_sub(&v);
        self.v = Some(v);
        new_params
    }
}

--------------------------------------------------------------------------------
/src/vision/conv_optimizer/sgd.rs:
--------------------------------------------------------------------------------
use serde::{Deserialize, Serialize};

use crate::{
    learning_rate::{default_learning_rate, LearningRateSchedule},
    linalg::Scalar,
    vision::{image::Image, image::ImageTrait},
};

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ConvSGD {
    #[serde(default = "default_learning_rate")]
    learning_rate: LearningRateSchedule,
}

impl ConvSGD {
    pub fn default() -> Self {
        Self {
            learning_rate: default_learning_rate(),
        }
    }

    pub fn with_const_lr(learning_rate: Scalar) -> Self {
        Self {
            learning_rate: LearningRateSchedule::Constant(learning_rate),
        }
    }

    pub fn new(learning_rate: LearningRateSchedule) -> Self {
        Self { learning_rate }
    }

    pub fn update_parameters(
        &mut self,
        epoch: usize,
        parameters: &Image,
        parameters_gradient: &Image,
    ) -> Image {
        let lr = self.learning_rate.get_learning_rate(epoch);
        parameters.component_sub(&parameters_gradient.scalar_mul(lr))
    }
}
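
// Editor's note (not part of the original file): the three conv optimizers share
// the same update_parameters signature, so a layer can swap them freely. Plain SGD
// above is the stateless base case:
//
//     theta <- theta - lr * g
//
// while ConvMomentum keeps a velocity v <- momentum * v + lr * g and steps by -v.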

--------------------------------------------------------------------------------
/src/vision/image/mod.rs:
--------------------------------------------------------------------------------
use crate::linalg::{Matrix, Scalar};

#[cfg(feature = "arrayfire")]
pub mod arrayfire_image;
#[cfg(feature = "arrayfire")]
pub type Image = arrayfire_image::Image;

#[cfg(all(feature = "nalgebra", not(feature = "arrayfire")))]
pub mod nalgebra_image;
#[cfg(all(feature = "nalgebra", not(feature = "arrayfire")))]
pub type Image = nalgebra_image::Image;

#[cfg(all(feature = "ndarray", not(feature = "arrayfire"), not(feature = "nalgebra")))]
pub mod ndarray_image;
#[cfg(all(feature = "ndarray", not(feature = "arrayfire"), not(feature = "nalgebra")))]
pub type Image = ndarray_image::Image;

/// An image (or batch of images) of Scalars with `n` rows, `m` columns and `c`
/// channels (and `s` samples when batched).
pub trait ImageTrait {
    fn zeros(nrow: usize, ncol: usize, nchan: usize, samples: usize) -> Self;

    fn constant(nrow: usize, ncol: usize, nchan: usize, samples: usize, value: Scalar) -> Self;

    fn random_uniform(
        nrow: usize,
        ncol: usize,
        nchan: usize,
        samples: usize,
        min: Scalar,
        max: Scalar,
    ) -> Self;

    fn random_normal(
        nrow: usize,
        ncol: usize,
        nchan: usize,
        samples: usize,
        mean: Scalar,
        stddev: Scalar,
    ) -> Self;

    fn from_fn<F>(nrows: usize, ncols: usize, nchan: usize, samples: usize, f: F) -> Self
    where
        F: FnMut(usize, usize, usize, usize) -> Scalar;

    /// `samples` has shape `(i, n)` where `n` is the number of samples and `i` is the number of pixels.
    ///
    /// Pixels are assumed to be in column-leading order with channels put in their entirety one after the other.
    fn from_samples(samples: &Matrix, channels: usize) -> Self;

    /// Adds the components of self and other. Assumes both images have the same pixel sizes and channels count.
    ///
    /// If other has fewer samples than self, it will add the first sample of other to all samples of self.
    fn component_add(&self, other: &Self) -> Self;

    /// Subtracts the components of self and other. Assumes both images have the same pixel sizes and channels count.
    ///
    /// If other has fewer samples than self, it will subtract the first sample of other from all samples of self.
    fn component_sub(&self, other: &Self) -> Self;

    /// Multiplies the components of self and other. Assumes both images have the same pixel sizes and channels count.
    ///
    /// If other has fewer samples than self, it will multiply all samples of self by the first sample of other.
    fn component_mul(&self, other: &Self) -> Self;

    /// Divides the components of self and other. Assumes both images have the same pixel sizes and channels count.
    ///
    /// If other has fewer samples than self, it will divide all samples of self by the first sample of other.
    fn component_div(&self, other: &Self) -> Self;

    fn scalar_add(&self, scalar: Scalar) -> Self;

    fn scalar_sub(&self, scalar: Scalar) -> Self;

    fn scalar_mul(&self, scalar: Scalar) -> Self;

    fn scalar_div(&self, scalar: Scalar) -> Self;

    fn cross_correlate(&self, kernels: &Self) -> Self;

    fn convolve_full(&self, kernels: &Self) -> Self;

    fn flatten(&self) -> Matrix;

    /// Returns (nrow, ncol, nchan)
    fn image_dims(&self) -> (usize, usize, usize);

    fn channels(&self) -> usize;

    /// Returns the amount of samples in the batch
    fn samples(&self) -> usize;

    /// Returns a full image (pixels + channels) for the given sample
    fn get_sample(&self, sample: usize) -> Self;

    /// Returns a single channel. Assumes the image contains only 1 sample.
    fn get_channel(&self, channel: usize) -> Self;

    /// Returns all the samples with only one channel.
    fn get_channel_across_samples(&self, channel: usize) -> Self;

    fn sum_samples(&self) -> Self;

    fn join_channels(channels: Vec<Self>) -> Self
    where
        Self: Sized;

    fn join_samples(samples: Vec<Self>) -> Self
    where
        Self: Sized;

    fn wrap(&self, ox: usize, oy: usize, wx: usize, wy: usize, sx: usize, sy: usize, px: usize, py: usize) -> Self;

    fn unwrap(&self, wx: usize, wy: usize, sx: usize, sy: usize, px: usize, py: usize) -> Self;

    fn tile(&self, repetitions_row: usize, repetitions_col: usize, repetitions_chan: usize, repetition_sample: usize) -> Self;

    fn square(&self) -> Self;

    fn sum(&self) -> Scalar;

    fn mean(&self) -> Scalar;

    fn mean_along(&self, dim: usize) -> Self;

    fn exp(&self) -> Self;

    fn maxof(&self, other: &Self) -> Self;

    fn sign(&self) -> Self;

    fn minof(&self, other: &Self) -> Self;

    fn sqrt(&self) -> Self;
}
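
// Editor's note (not part of the original file): the cfg chain above resolves the
// Image alias by feature precedence: arrayfire first, then nalgebra, then ndarray.
// For example (flag shown for illustration):
//
//     cargo build --features arrayfire   # Image = arrayfire_image::Image
//
// The nalgebra and ndarray backends below are stubs whose operations panic with
// unimplemented!().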

--------------------------------------------------------------------------------
/src/vision/image/nalgebra_image.rs:
--------------------------------------------------------------------------------
use crate::linalg::{Matrix, Scalar};

use super::ImageTrait;

/// Placeholder backend: keeps the crate compiling without the `arrayfire`
/// feature; every operation is unimplemented.
#[derive(Clone, Debug)]
pub struct Image(usize);

#[allow(unused_variables)]
impl ImageTrait for Image {
    fn zeros(nrow: usize, ncol: usize, nchan: usize, samples: usize) -> Self {
        unimplemented!()
    }

    fn constant(nrow: usize, ncol: usize, nchan: usize, samples: usize, value: Scalar) -> Self {
        unimplemented!()
    }

    fn random_uniform(
        nrow: usize,
        ncol: usize,
        nchan: usize,
        samples: usize,
        min: Scalar,
        max: Scalar,
    ) -> Self {
        unimplemented!()
    }

    fn random_normal(
        nrow: usize,
        ncol: usize,
        nchan: usize,
        samples: usize,
        mean: Scalar,
        stddev: Scalar,
    ) -> Self {
        unimplemented!()
    }

    fn from_fn<F>(nrows: usize, ncols: usize, nchan: usize, samples: usize, f: F) -> Self
    where
        F: FnMut(usize, usize, usize, usize) -> Scalar,
    {
        unimplemented!()
    }

    fn from_samples(samples: &Matrix, channels: usize) -> Self {
        unimplemented!()
    }

    fn wrap(
        &self,
        ox: usize,
        oy: usize,
        wx: usize,
        wy: usize,
        sx: usize,
        sy: usize,
        px: usize,
        py: usize,
    ) -> Self {
        unimplemented!()
    }

    fn unwrap(&self, wx: usize, wy: usize, sx: usize, sy: usize, px: usize, py: usize) -> Self {
        unimplemented!()
    }

    fn tile(
        &self,
        repetitions_row: usize,
        repetitions_col: usize,
        repetitions_chan: usize,
        repetition_sample: usize,
    ) -> Self {
        unimplemented!()
    }

    fn component_add(&self, other: &Self) -> Self {
        unimplemented!()
    }

    fn component_sub(&self, other: &Self) -> Self {
        unimplemented!()
    }

    fn component_mul(&self, other: &Self) -> Self {
        unimplemented!()
    }

    fn component_div(&self, other: &Self) -> Self {
        unimplemented!()
    }

    fn scalar_add(&self, scalar: Scalar) -> Self {
        unimplemented!()
    }

    fn scalar_sub(&self, scalar: Scalar) -> Self {
        unimplemented!()
    }

    fn scalar_mul(&self, scalar: Scalar) -> Self {
        unimplemented!()
    }

    fn scalar_div(&self, scalar: Scalar) -> Self {
        unimplemented!()
    }
    fn cross_correlate(&self, kernels: &Self) -> Self {
        unimplemented!()
    }

    fn convolve_full(&self, kernels: &Self) -> Self {
        unimplemented!()
    }

    fn flatten(&self) -> Matrix {
        unimplemented!()
    }

    fn image_dims(&self) -> (usize, usize, usize) {
        unimplemented!()
    }

    fn channels(&self) -> usize {
        unimplemented!()
    }

    fn samples(&self) -> usize {
        unimplemented!()
    }

    fn get_sample(&self, sample: usize) -> Self {
        unimplemented!()
    }

    fn get_channel(&self, channel: usize) -> Self {
        unimplemented!()
    }

    fn get_channel_across_samples(&self, channel: usize) -> Self {
        unimplemented!()
    }

    fn sum_samples(&self) -> Self {
        unimplemented!()
    }

    fn join_channels(channels: Vec<Self>) -> Self {
        unimplemented!()
    }

    fn join_samples(samples: Vec<Self>) -> Self {
        unimplemented!()
    }

    fn square(&self) -> Self {
        unimplemented!()
    }

    fn sum(&self) -> Scalar {
        unimplemented!()
    }

    fn mean(&self) -> Scalar {
        unimplemented!()
    }

    fn mean_along(&self, dim: usize) -> Self {
        unimplemented!()
    }

    fn exp(&self) -> Self {
        unimplemented!()
    }

    fn maxof(&self, other: &Self) -> Self {
        unimplemented!()
    }

    fn sign(&self) -> Self {
        unimplemented!()
    }

    fn minof(&self, other: &Self) -> Self {
        unimplemented!()
    }

    fn sqrt(&self) -> Self {
        unimplemented!()
    }
}

--------------------------------------------------------------------------------
/src/vision/image/ndarray_image.rs:
--------------------------------------------------------------------------------
use crate::linalg::{Matrix, Scalar};

use super::ImageTrait;

/// Placeholder backend: keeps the crate compiling without the `arrayfire`
/// feature; every operation is unimplemented.
#[derive(Clone, Debug)]
pub struct Image(usize);

#[allow(unused_variables)]
impl ImageTrait for Image {
    fn zeros(nrow: usize, ncol: usize, nchan: usize, samples: usize) -> Self {
        unimplemented!()
    }

    fn constant(nrow: usize, ncol: usize, nchan: usize, samples: usize, value: Scalar) -> Self {
        unimplemented!()
    }

    fn random_uniform(
        nrow: usize,
        ncol: usize,
        nchan: usize,
        samples: usize,
        min: Scalar,
        max: Scalar,
    ) -> Self {
        unimplemented!()
    }

    fn random_normal(
        nrow: usize,
        ncol: usize,
        nchan: usize,
        samples: usize,
        mean: Scalar,
        stddev: Scalar,
    ) -> Self {
        unimplemented!()
    }

    fn from_fn<F>(nrows: usize, ncols: usize, nchan: usize, samples: usize, f: F) -> Self
    where
        F: FnMut(usize, usize, usize, usize) -> Scalar,
    {
        unimplemented!()
    }

    fn from_samples(samples: &Matrix, channels: usize) -> Self {
        unimplemented!()
    }

    fn wrap(
        &self,
        ox: usize,
        oy: usize,
        wx: usize,
        wy: usize,
        sx: usize,
        sy: usize,
        px: usize,
        py: usize,
    ) -> Self {
        unimplemented!()
    }

    fn unwrap(&self, wx: usize, wy: usize, sx: usize, sy: usize, px: usize, py: usize) -> Self {
        unimplemented!()
    }

    fn tile(
        &self,
        repetitions_row: usize,
        repetitions_col: usize,
        repetitions_chan: usize,
        repetition_sample: usize,
    ) -> Self {
        unimplemented!()
    }
    fn component_add(&self, other: &Self) -> Self {
        unimplemented!()
    }

    fn component_sub(&self, other: &Self) -> Self {
        unimplemented!()
    }

    fn component_mul(&self, other: &Self) -> Self {
        unimplemented!()
    }

    fn component_div(&self, other: &Self) -> Self {
        unimplemented!()
    }

    fn scalar_add(&self, scalar: Scalar) -> Self {
        unimplemented!()
    }

    fn scalar_sub(&self, scalar: Scalar) -> Self {
        unimplemented!()
    }

    fn scalar_mul(&self, scalar: Scalar) -> Self {
        unimplemented!()
    }

    fn scalar_div(&self, scalar: Scalar) -> Self {
        unimplemented!()
    }

    fn cross_correlate(&self, kernels: &Self) -> Self {
        unimplemented!()
    }

    fn convolve_full(&self, kernels: &Self) -> Self {
        unimplemented!()
    }

    fn flatten(&self) -> Matrix {
        unimplemented!()
    }

    fn image_dims(&self) -> (usize, usize, usize) {
        unimplemented!()
    }

    fn channels(&self) -> usize {
        unimplemented!()
    }

    fn samples(&self) -> usize {
        unimplemented!()
    }

    fn get_sample(&self, sample: usize) -> Self {
        unimplemented!()
    }

    fn get_channel(&self, channel: usize) -> Self {
        unimplemented!()
    }

    fn get_channel_across_samples(&self, channel: usize) -> Self {
        unimplemented!()
    }

    fn sum_samples(&self) -> Self {
        unimplemented!()
    }

    fn join_channels(channels: Vec<Self>) -> Self {
        unimplemented!()
    }

    fn join_samples(samples: Vec<Self>) -> Self {
        unimplemented!()
    }

    fn square(&self) -> Self {
        unimplemented!()
    }

    fn sum(&self) -> Scalar {
        unimplemented!()
    }

    fn mean(&self) -> Scalar {
        unimplemented!()
    }

    fn mean_along(&self, dim: usize) -> Self {
        unimplemented!()
    }

    fn exp(&self) -> Self {
        unimplemented!()
    }

    fn maxof(&self, other: &Self) -> Self {
        unimplemented!()
    }

    fn sign(&self) -> Self {
        unimplemented!()
    }

    fn minof(&self, other: &Self) -> Self {
        unimplemented!()
    }

    fn sqrt(&self) -> Self {
        unimplemented!()
    }
}

--------------------------------------------------------------------------------
/src/vision/image_layer.rs:
--------------------------------------------------------------------------------
use super::image::Image;

pub trait ImageLayer {
    fn forward(&mut self, input: Image) -> Image;
    fn backward(&mut self, epoch: usize, output_gradient: Image) -> Image;
}

--------------------------------------------------------------------------------
/src/vision/mod.rs:
--------------------------------------------------------------------------------
/// Initializers for convolutional layers
pub mod conv_initializers;
/// Convolutional layers and abstractions (Average Pooling, Dense, Direct...)
pub mod conv_layer;
/// Convolutional network abstractions
pub mod conv_network;
/// Optimizers for convolutional layers
pub mod conv_optimizer;
/// Backends for image manipulation (convolutions, channels...)
pub mod image;
/// Activation functions for convolutional layers
pub mod conv_activation;
/// Abstraction of a layer taking an image as input and outputting an image
pub mod image_layer;
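
// Editor's note (not part of the original file): a minimal sketch of a custom
// layer built on the image_layer abstraction declared above — an identity layer
// that passes images and gradients straight through:
//
//     use crate::vision::{image::Image, image_layer::ImageLayer};
//
//     struct IdentityLayer;
//
//     impl ImageLayer for IdentityLayer {
//         fn forward(&mut self, input: Image) -> Image {
//             input
//         }
//         fn backward(&mut self, _epoch: usize, output_gradient: Image) -> Image {
//             output_gradient
//         }
//     }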
--------------------------------------------------------------------------------